import os
import json
import asyncio
import uuid
from typing import Optional, Dict, Any
from enum import Enum
from urllib.parse import urlencode

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from dotenv import load_dotenv
import websockets

# ------------------ ENV ------------------
# Kept ahead of the security imports below, exactly as in the original
# ordering. NOTE(review): presumably those modules read environment
# variables at import time -- confirm before reordering.
load_dotenv()

# ------------------ SECURITY IMPORTS ------------------
from pii_masking import mask_pii, contains_pii_regex
from prompt_injection import is_prompt_injection
from llm_pii_detector import contains_pii_llm
from base64_utils import (
    extract_base64_segments,
    replace_base64_with_decoded,
)


# ------------------ ENUMS ------------------
class SecurityCheckResult(Enum):
    """Outcome categories reported by the security middleware."""

    SAFE = "safe"
    PII_DETECTED_MASKABLE = "pii_detected_maskable"
    PII_DETECTED_NON_MASKABLE = "pii_detected_non_maskable"
    PROMPT_INJECTION = "prompt_injection"
    UNSAFE_ENCODED_CONTENT = "unsafe_encoded_content"


# (result, message for the user, text to forward -- None unless result is SAFE)
CheckOutcome = tuple[SecurityCheckResult, str, Optional[str]]


# ------------------ SECURITY MIDDLEWARE ------------------
class SecurityMiddleware:
    """Screen user prompts for PII, prompt injection and unsafe base64 payloads.

    Per-session state is kept in ``self.sessions`` so that a "mask or
    rewrite?" question can be answered by the user's next message.
    """

    def __init__(self):
        self.sessions: Dict[str, Dict[str, Any]] = {}

    def _get_session(self, session_id: str) -> Dict[str, Any]:
        """Return the state dict for *session_id*, creating it on first use."""
        if session_id not in self.sessions:
            self.sessions[session_id] = {
                "awaiting_pii_choice": False,
                "pii_pending_maskable": False,
                "pii_pending_masked_text": "",
                "original_message": "",
            }
        return self.sessions[session_id]

    def _clear_session(self, session_id: str):
        """Forget all state for *session_id* (no-op when unknown)."""
        self.sessions.pop(session_id, None)

    @staticmethod
    def _reset_pending(session: Dict[str, Any]) -> None:
        """Drop any pending mask-or-rewrite question and its stored text.

        Resetting all three fields together keeps a masked prompt from an
        earlier exchange from being forwarded by a later "mask" reply.
        """
        session["awaiting_pii_choice"] = False
        session["pii_pending_maskable"] = False
        session["pii_pending_masked_text"] = ""

    def _looks_like_mask_choice(self, text: str) -> bool:
        """Return True when *text* reads as "yes, mask the PII"."""
        t = text.strip().lower()
        return ("mask" in t) or (t in {"m", "yes", "y"})

    def _looks_like_rewrite_choice(self, text: str) -> bool:
        """Return True when *text* reads as "no, I will rewrite the prompt"."""
        t = text.strip().lower()
        return (
            ("new" in t and "prompt" in t)
            or ("rewrite" in t)
            or ("rephrase" in t)
            or (t in {"n", "no"})
        )

    async def check_security(
        self,
        user_input: str,
        session_id: str
    ) -> tuple[SecurityCheckResult, str, Optional[str]]:
        """Evaluate one user message for the given session.

        Returns:
            ``(result, message, processed_input)``. ``processed_input`` is the
            text to forward and is only set when ``result`` is ``SAFE``;
            otherwise ``message`` carries the warning for the user.
        """
        session = self._get_session(session_id)

        # A pending question means this message is the answer, not a prompt.
        if session["awaiting_pii_choice"]:
            return await self._handle_pii_choice(user_input, session)

        return await self._run_security_checks(user_input, session)

    async def _handle_pii_choice(
        self,
        user_input: str,
        session: Dict[str, Any]
    ) -> CheckOutcome:
        """Interpret *user_input* as the answer to "mask it or rewrite?"."""
        if self._looks_like_mask_choice(user_input):
            maskable = session["pii_pending_maskable"]
            masked = session["pii_pending_masked_text"]
            self._reset_pending(session)

            if not maskable:
                return (
                    SecurityCheckResult.PII_DETECTED_NON_MASKABLE,
                    "Cannot safely mask this PII. Rewrite the prompt.",
                    None
                )

            # The stored text is already masked, so skip the PII stage and run
            # only the remaining checks.
            return await self._run_security_checks(masked, session, skip_pii=True)

        if self._looks_like_rewrite_choice(user_input):
            self._reset_pending(session)
            return (
                SecurityCheckResult.PII_DETECTED_NON_MASKABLE,
                "Please rewrite your prompt without PII.",
                None
            )

        # Neither answer recognised: keep the question open and ask again.
        return (
            SecurityCheckResult.PII_DETECTED_MASKABLE,
            "PII detected. Mask it or rewrite?",
            None
        )

    async def _run_security_checks(
        self,
        user_input: str,
        session: Dict[str, Any],
        skip_pii: bool = False
    ) -> CheckOutcome:
        """Run the PII, base64 and prompt-injection checks on *user_input*.

        Args:
            user_input: Text to screen.
            session: Mutable session state; updated when a mask-or-rewrite
                question is raised.
            skip_pii: When True the PII detection stage is skipped (used for
                text that was already masked).
        """
        # NOTE(review): the detector helpers are called synchronously; if any
        # of them performs network I/O it will block the event loop -- confirm.
        if not skip_pii:
            if contains_pii_regex(user_input):
                masked = mask_pii(user_input, user_tier="free")
                session.update({
                    "awaiting_pii_choice": True,
                    "pii_pending_maskable": True,
                    "pii_pending_masked_text": masked
                })
                return (
                    SecurityCheckResult.PII_DETECTED_MASKABLE,
                    "PII detected. Mask it or rewrite?",
                    None
                )

            if contains_pii_llm(user_input):
                # There is nothing to choose here: the user was told to
                # rewrite, so the next message must be screened as a fresh
                # prompt rather than parsed as a mask/rewrite answer.
                self._reset_pending(session)
                return (
                    SecurityCheckResult.PII_DETECTED_NON_MASKABLE,
                    "Sensitive PII detected. Rewrite prompt.",
                    None
                )

        masked_text = mask_pii(user_input, user_tier="free")

        base64_segments = extract_base64_segments(masked_text)
        if base64_segments:
            decoded = replace_base64_with_decoded(masked_text, base64_segments)
            if is_prompt_injection(decoded) or contains_pii_llm(decoded):
                return (
                    SecurityCheckResult.UNSAFE_ENCODED_CONTENT,
                    "Unsafe encoded content detected.",
                    None
                )

        if is_prompt_injection(masked_text):
            return (
                SecurityCheckResult.PROMPT_INJECTION,
                "Prompt injection detected. Rephrase.",
                None
            )

        return (SecurityCheckResult.SAFE, "", masked_text)

    def sanitize_output(self, text: str) -> str:
        """Mask PII in backend output before it is sent to the client."""
        return mask_pii(text, user_tier="free")


# ------------------ INIT ------------------
security_middleware = SecurityMiddleware()
app = FastAPI()


def _backend_is_open(ws) -> bool:
    """Return True when *ws* is a backend connection that is still open.

    Supports both connection flavours of the ``websockets`` package: objects
    exposing a boolean ``closed`` attribute and objects exposing a ``state``
    enum instead. Anything else is treated as not open.
    """
    if ws is None:
        return False
    closed = getattr(ws, "closed", None)
    if closed is not None:
        return not closed
    state = getattr(ws, "state", None)
    return getattr(state, "name", None) == "OPEN"


async def _close_backend(ws) -> None:
    """Best-effort close of a backend connection; failures are only logged."""
    if ws is None:
        return
    try:
        await ws.close()
    except Exception as close_err:
        print(f"Backend close error: {close_err}")


# ------------------ WEBSOCKET ------------------
@app.websocket("/chat")
async def chat_ws(websocket: WebSocket):
    """Proxy a client chat WebSocket to the backend with security screening.

    Each incoming message goes through ``security_middleware.check_security``;
    unsafe messages produce a ``security_warning`` frame, safe ones are
    forwarded and the backend's streamed reply is relayed after output masking.
    """
    await websocket.accept()
    backend_ws = None

    # A missing/empty session_id gets a unique id so that unrelated clients
    # never share pending PII state.
    session_id = (
        websocket.query_params.get("session_id")
        or f"anon-{uuid.uuid4().hex}"
    )
    backend_url = os.getenv(
        "BACKEND_AI_URL",
        "wss://partha181098-backend.hf.space/chat"
    )

    try:
        while True:
            user_input = await websocket.receive_text()

            result, message, processed_input = await security_middleware.check_security(
                user_input, session_id
            )

            if result != SecurityCheckResult.SAFE:
                await websocket.send_json({
                    "type": "security_warning",
                    "message": message,
                    "status": result.value
                })
                continue

            if not _backend_is_open(backend_ws):
                # Release a dead connection before replacing it.
                await _close_backend(backend_ws)
                # session_id is client-controlled: URL-encode it so it cannot
                # inject extra query parameters into the backend URL.
                query = urlencode({"session_id": session_id})
                backend_ws = await websockets.connect(f"{backend_url}?{query}")

            await backend_ws.send(processed_input)

            # Relay streamed chunks until the backend signals completion.
            # NOTE(review): masking is applied per chunk; PII split across two
            # chunks may not be recognised by mask_pii -- confirm.
            while True:
                data = await backend_ws.recv()
                payload = json.loads(data)
                frame_type = payload.get("type")

                if frame_type == "response":
                    sanitized = security_middleware.sanitize_output(
                        payload.get("content", "")
                    )
                    await websocket.send_json({
                        "type": "response",
                        "content": sanitized
                    })

                elif frame_type == "complete":
                    await websocket.send_json({"type": "complete"})
                    break

    except WebSocketDisconnect:
        print("Client disconnected")

    except Exception as e:
        print(f"Middleware error: {e}")
        try:
            await websocket.send_json({
                "type": "error",
                "message": "Middleware failure"
            })
        except Exception:
            # Best effort only: the client socket may already be gone.
            pass

    finally:
        await _close_backend(backend_ws)
        # Drop per-session state so the sessions dict does not grow forever.
        security_middleware._clear_session(session_id)


# ------------------ HEALTH ------------------
@app.get("/health")
async def health():
    """Liveness probe."""
    return {"status": "ok"}


# ------------------ RUN ------------------
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)