"""
WebSocket chat router — real-time streaming AI chat with:
- Prompt injection prevention
- Input sanitization
- Heartbeat/ping support
- Structured error responses
- Observability metrics tracking
"""

import json
import re

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query

from app.database.database import SessionLocal
from app.database.models import User
from app.websocket.connection_manager import ws_manager
from app.ai.chat import stream_chat_response
from app.middleware.logging import ws_logger, metrics

router = APIRouter(tags=["WebSockets"])

# ─── Prompt injection patterns ────────────────────────────────────────────────
_INJECTION_PATTERNS = [
    r"ignore\s+(all\s+)?previous\s+instructions",
    r"you\s+are\s+now\s+a",
    r"forget\s+(everything|all)",
    r"new\s+system\s+prompt",
    r"disregard\s+(your|all)",
    r"act\s+as\s+(if\s+you\s+are|a\s+different)",
    r"jailbreak",
    r"dan\s+mode",
    r"developer\s+mode",
    r"<\s*script",
    r"javascript:",
]
_INJECTION_RE = re.compile("|".join(_INJECTION_PATTERNS), re.IGNORECASE)

# C0 control characters except tab (0x09), newline (0x0a) and carriage
# return (0x0d), plus DEL (0x7f).
_CONTROL_CHARS_RE = re.compile(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]")

MAX_MESSAGE_LENGTH = 2000


def sanitize_prompt(text: str) -> tuple[str, bool]:
    """
    Clean a user prompt and screen it for prompt-injection phrases.

    Control characters (other than tab, newline and carriage return) are
    removed first, so a phrase split up by e.g. NUL bytes is still detected.
    The result is truncated to MAX_MESSAGE_LENGTH characters and then
    matched against _INJECTION_RE.

    Returns:
        (sanitized_text, is_safe): is_safe is False when any injection
        pattern matched the sanitized text.
    """
    cleaned = _CONTROL_CHARS_RE.sub("", text)[:MAX_MESSAGE_LENGTH]
    return cleaned, _INJECTION_RE.search(cleaned) is None


def _short_id(user_id) -> str:
    """Return the first 8 characters of *user_id* (any type) for log correlation."""
    # str() so non-string ids (e.g. integer primary keys) don't raise TypeError.
    return str(user_id)[:8]


async def _handle_chat(websocket: WebSocket, db, user_id, payload: dict) -> None:
    """
    Validate, sanitize and answer one ``{"type": "chat"}`` message.

    On success sends ``chat_start``, zero or more ``chat_chunk`` frames and
    ``chat_end``. Invalid, empty or blocked prompts get a single ``error``
    frame instead. WebSocketDisconnect propagates to the caller.
    """
    raw_prompt = payload.get("message", "")
    if not isinstance(raw_prompt, str):
        # Previously a non-string (number, null, list) raised AttributeError
        # on .strip() and tore down the whole connection.
        await websocket.send_json({"type": "error", "message": "Message must be a string"})
        return

    # Sanitize + injection check
    prompt, is_safe = sanitize_prompt(raw_prompt.strip())
    # Re-strip: removing control characters can expose leading/trailing
    # whitespace, and a control-only message becomes empty only *after*
    # sanitizing.
    prompt = prompt.strip()
    if not prompt:
        await websocket.send_json({"type": "error", "message": "Message cannot be empty"})
        return
    if not is_safe:
        ws_logger.warning("Prompt injection attempt blocked", extra={"user_id": _short_id(user_id)})
        await websocket.send_json({
            "type": "error",
            "message": "I can only help with financial questions about your accounts.",
        })
        return

    await websocket.send_json({"type": "chat_start"})
    try:
        # NOTE(review): stream_chat_response is iterated synchronously inside
        # an async handler; if it does blocking I/O (LLM/DB calls) it stalls
        # the event loop for every other connection — confirm and consider
        # moving it off the loop.
        for chunk in stream_chat_response(db, user_id, prompt):
            if chunk:
                await websocket.send_json({"type": "chat_chunk", "content": chunk})
    except WebSocketDisconnect:
        # Client left mid-stream: not an AI failure, and writing an error
        # frame to the closed socket would only fail again.
        raise
    except Exception as e:
        ws_logger.error("AI streaming error", extra={"error": str(e)[:100]})
        await websocket.send_json({
            "type": "error",
            "message": "AI response failed. Please try again.",
        })
    await websocket.send_json({"type": "chat_end"})


@router.websocket("/api/ai/chat/ws")
async def websocket_chat_endpoint(
    websocket: WebSocket,
    user_id: str = Query(None),
):
    """
    Stream AI chat responses over a WebSocket.

    Resolves the user (falling back to the first ``User`` row when no
    ``user_id`` query parameter is given), registers the connection with
    ``ws_manager``, then processes incoming text frames. Each frame must be a
    JSON object whose ``type`` is:

    - ``"ping"``: answered with ``{"type": "pong"}``;
    - ``"chat"`` (the default): answered by streaming the AI response;
    - anything else: answered with an ``error`` frame.

    The DB session is closed, and the connection unregistered if it was
    registered, on every exit path.

    NOTE(review): ``user_id`` is taken from the query string and is not
    authenticated in this handler — confirm that auth is enforced elsewhere.
    """
    db = SessionLocal()
    connected = False
    try:
        # Resolve user
        if not user_id:
            user = db.query(User).first()
            if user is None:
                await websocket.accept()
                await websocket.send_json({
                    "type": "error",
                    "message": "No users found. Run: python app/scripts/seed_demo.py",
                })
                await websocket.close()
                return
            user_id = user.id

        await ws_manager.connect(websocket, user_id)
        connected = True
        metrics.ws_connects += 1
        ws_logger.info("WebSocket connected", extra={"user_id": _short_id(user_id)})

        while True:
            data = await websocket.receive_text()
            try:
                payload = json.loads(data)
            except json.JSONDecodeError:
                await websocket.send_json({"type": "error", "message": "Invalid JSON"})
                continue
            if not isinstance(payload, dict):
                # Valid JSON but not an object (e.g. a list or string) would
                # otherwise crash on payload.get and drop the connection.
                await websocket.send_json({"type": "error", "message": "Message must be a JSON object"})
                continue

            msg_type = payload.get("type", "chat")

            # ── Heartbeat ────────────────────────────────────────────────────
            if msg_type == "ping":
                await websocket.send_json({"type": "pong"})
            # ── Chat message ─────────────────────────────────────────────────
            elif msg_type == "chat":
                await _handle_chat(websocket, db, user_id, payload)
            else:
                await websocket.send_json({
                    "type": "error",
                    "message": f"Unknown message type: {msg_type}",
                })
    except WebSocketDisconnect:
        ws_logger.info("WebSocket disconnected", extra={"user_id": _short_id(user_id)})
    except Exception as e:
        ws_logger.error(
            "WebSocket error",
            extra={"user_id": _short_id(user_id), "error": str(e)[:100]},
        )
    finally:
        # Single cleanup point: previously the session leaked if anything
        # before the try (user lookup, connect, first log line) raised.
        if connected:
            ws_manager.disconnect(websocket, user_id)
        db.close()