Spaces:
Build error
Build error
| """ | |
| WebSocket chat router — real-time streaming AI chat with: | |
| - Prompt injection prevention | |
| - Input sanitization | |
| - Heartbeat/ping support | |
| - Structured error responses | |
| - Observability metrics tracking | |
| """ | |
| import json | |
| import re | |
| from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query | |
| from app.database.database import SessionLocal | |
| from app.database.models import User | |
| from app.websocket.connection_manager import ws_manager | |
| from app.ai.chat import stream_chat_response | |
| from app.middleware.logging import ws_logger, metrics | |
# Module-level router object for this file; its endpoints are tagged "WebSockets".
router = APIRouter(
    tags=["WebSockets"],
)
# --- Prompt injection patterns ------------------------------------------------
# Regex fragments for phrasings that are treated as prompt-injection attempts.
# They are OR-ed together into a single case-insensitive pattern below, so the
# order here is also the order of the alternatives in the compiled regex.
_INJECTION_PATTERNS = [
    # Attempts to override, discard, or replace the assistant's instructions/role
    r"ignore\s+(all\s+)?previous\s+instructions",
    r"you\s+are\s+now\s+a",
    r"forget\s+(everything|all)",
    r"new\s+system\s+prompt",
    r"disregard\s+(your|all)",
    r"act\s+as\s+(if\s+you\s+are|a\s+different)",
    # Well-known jailbreak keywords
    r"jailbreak",
    r"dan\s+mode",
    r"developer\s+mode",
    # Markup / script payloads
    r"<\s*script",
    r"javascript:",
]

# One combined pattern: a match on any alternative flags the text.
_INJECTION_RE = re.compile(
    "|".join(_INJECTION_PATTERNS),
    flags=re.IGNORECASE,
)

# Maximum number of characters kept from a chat message by sanitize_prompt.
MAX_MESSAGE_LENGTH = 2000
def sanitize_prompt(text: str) -> tuple[str, bool]:
    """Clean a chat message and report whether it looks safe to forward.

    The text is stripped of NUL and other control characters (tab, newline and
    carriage return are kept), cut to ``MAX_MESSAGE_LENGTH`` characters, and
    then scanned with ``_INJECTION_RE``.

    Returns:
        A ``(sanitized_text, is_safe)`` pair. ``is_safe`` is ``False`` when the
        sanitized text matches any injection pattern; the sanitized text is
        returned in both cases.
    """
    # Remove control characters first so they cannot be used to split a phrase
    # that the injection patterns would otherwise match.
    without_controls = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]", "", text)

    # Enforce the length cap before scanning; only the kept part is checked.
    bounded = without_controls[:MAX_MESSAGE_LENGTH]

    looks_safe = _INJECTION_RE.search(bounded) is None
    return bounded, looks_safe
async def websocket_chat_endpoint(
    websocket: WebSocket,
    user_id: str = Query(None),
):
    """Run one WebSocket chat session: heartbeats plus streamed AI replies.

    Incoming frames are JSON objects with an optional ``"type"`` field
    (defaults to ``"chat"``):

    * ``{"type": "ping"}`` is answered with ``{"type": "pong"}``.
    * ``{"type": "chat", "message": "..."}`` is sanitized and, if it passes
      the injection check, answered with ``chat_start``, zero or more
      ``chat_chunk`` frames, and ``chat_end``.
    * Anything else is answered with ``{"type": "error", "message": ...}``;
      malformed input never terminates the session.

    Args:
        websocket: The client connection.
        user_id: User to chat as. When omitted, the first ``User`` row is
            used; if there are none, an error frame is sent and the socket
            is closed.

    NOTE(review): no route decorator is visible on this function here -
    confirm it is registered on ``router`` (or the app) elsewhere.
    NOTE(review): ``user_id`` comes straight from the query string with no
    visible authentication - confirm that is acceptable outside demo use.
    """
    db = SessionLocal()
    try:
        # Resolve user: fall back to the first user when none was supplied.
        if not user_id:
            user = db.query(User).first()
            if not user:
                await websocket.accept()
                await websocket.send_json({
                    "type": "error",
                    "message": "No users found. Run: python app/scripts/seed_demo.py"
                })
                await websocket.close()
                return
            user_id = user.id

        # The fallback user.id is not guaranteed to be a str, so stringify
        # before slicing; only this shortened tag is ever written to the logs.
        user_tag = str(user_id)[:8]

        await ws_manager.connect(websocket, user_id)
        metrics.ws_connects += 1
        ws_logger.info("WebSocket connected", extra={"user_id": user_tag})

        try:
            while True:
                data = await websocket.receive_text()

                try:
                    payload = json.loads(data)
                except json.JSONDecodeError:
                    await websocket.send_json({"type": "error", "message": "Invalid JSON"})
                    continue

                # Valid JSON that is not an object (list, string, number...)
                # has no .get(); reject it instead of crashing the session.
                if not isinstance(payload, dict):
                    await websocket.send_json({
                        "type": "error",
                        "message": "Message must be a JSON object"
                    })
                    continue

                msg_type = payload.get("type", "chat")

                if msg_type == "ping":
                    # Heartbeat
                    await websocket.send_json({"type": "pong"})
                elif msg_type == "chat":
                    await _handle_chat_message(websocket, db, user_id, user_tag, payload)
                else:
                    await websocket.send_json({
                        "type": "error",
                        "message": f"Unknown message type: {msg_type}"
                    })
        except WebSocketDisconnect:
            ws_logger.info("WebSocket disconnected", extra={"user_id": user_tag})
        except Exception as e:
            ws_logger.error("WebSocket error", extra={"user_id": user_tag, "error": str(e)[:100]})
        finally:
            # Always unregister once connect() has succeeded, including on
            # task cancellation, which the except clauses above do not catch.
            ws_manager.disconnect(websocket, user_id)
    finally:
        # Single cleanup point: covers the early return and a failing connect().
        db.close()


async def _handle_chat_message(websocket, db, user_id, user_tag, payload):
    """Validate one chat payload and stream the AI reply back to the client."""
    raw_message = payload.get("message", "")

    # JSON allows numbers, null, objects, etc. here; only strings are usable.
    if not isinstance(raw_message, str):
        await websocket.send_json({"type": "error", "message": "Message must be a string"})
        return

    raw_prompt = raw_message.strip()
    if not raw_prompt:
        await websocket.send_json({"type": "error", "message": "Message cannot be empty"})
        return

    # Sanitize + injection check
    prompt, is_safe = sanitize_prompt(raw_prompt)
    if not is_safe:
        ws_logger.warning("Prompt injection attempt blocked", extra={"user_id": user_tag})
        await websocket.send_json({
            "type": "error",
            "message": "I can only help with financial questions about your accounts."
        })
        return

    # A message made only of control characters is non-empty before
    # sanitization but empty after it; do not forward an empty prompt.
    if not prompt.strip():
        await websocket.send_json({"type": "error", "message": "Message cannot be empty"})
        return

    await websocket.send_json({"type": "chat_start"})
    try:
        # NOTE(review): this is a synchronous iteration inside a coroutine; if
        # stream_chat_response blocks between chunks it will stall the event
        # loop - confirm, and consider moving it to a worker thread.
        for chunk in stream_chat_response(db, user_id, prompt):
            if chunk:
                await websocket.send_json({"type": "chat_chunk", "content": chunk})
    except WebSocketDisconnect:
        # The client went away mid-stream: not an AI failure, and nothing more
        # can be sent. Let the session loop handle the disconnect.
        raise
    except Exception as e:
        ws_logger.error("AI streaming error", extra={"error": str(e)[:100]})
        await websocket.send_json({
            "type": "error",
            "message": "AI response failed. Please try again."
        })
    await websocket.send_json({"type": "chat_end"})