mohsin-devs's picture
Deploy to HF
a282d4b
"""
WebSocket chat router — real-time streaming AI chat with:
- Prompt injection prevention
- Input sanitization
- Heartbeat/ping support
- Structured error responses
- Observability metrics tracking
"""
import json
import re
import unicodedata

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query

from app.database.database import SessionLocal
from app.database.models import User
from app.websocket.connection_manager import ws_manager
from app.ai.chat import stream_chat_response
from app.middleware.logging import ws_logger, metrics
# Router on which this module's chat WebSocket endpoint is registered; the
# "WebSockets" tag is a FastAPI route tag (presumably for grouping in API docs).
router = APIRouter(tags=["WebSockets"])
# ─── Prompt injection patterns ────────────────────────────────────────────────
# Case-insensitive phrases typical of attempts to override the assistant's
# instructions, plus two markers of HTML/JS payloads. Patterns that begin or
# end with a short, common word are anchored with \b so ordinary questions
# such as "you are now able to ..." or "add a contact as a different payee"
# are not mistaken for injection attempts.
_INJECTION_PATTERNS = [
    r"ignore\s+(all\s+)?previous\s+instructions",
    r"you\s+are\s+now\s+an?\b",
    r"forget\s+(everything|all)",
    r"new\s+system\s+prompt",
    r"disregard\s+(your|all)",
    r"\bact\s+as\s+(if\s+you\s+are|a\s+different)",
    r"jailbreak",
    r"\bdan\s+mode",
    r"developer\s+mode",
    r"<\s*script",
    r"javascript:",
]
_INJECTION_RE = re.compile("|".join(_INJECTION_PATTERNS), re.IGNORECASE)

# Prompts longer than this many characters are truncated by sanitize_prompt.
MAX_MESSAGE_LENGTH = 2000


def sanitize_prompt(text: str) -> tuple[str, bool]:
    """
    Returns (sanitized_text, is_safe).

    C0 control characters other than tab, newline and carriage return (plus
    DEL) are removed, and the result is truncated to ``MAX_MESSAGE_LENGTH``
    characters. That cleaned text is what gets returned.

    The injection check runs on a normalized *copy* of the cleaned text:
    NFKC folding maps look-alike forms (e.g. fullwidth letters) to plain
    ones, and invisible format characters (Unicode category ``Cf``, e.g.
    zero-width spaces) are dropped, so neither can be used to split or
    disguise a blocked phrase. ``is_safe`` is False when any pattern in
    ``_INJECTION_RE`` matches that copy.
    """
    # Strip null bytes and control characters (keep \t, \n, \r)
    cleaned = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]", "", text)
    cleaned = cleaned[:MAX_MESSAGE_LENGTH]

    # Scan-only normalization; the caller still receives `cleaned` verbatim.
    folded = unicodedata.normalize("NFKC", cleaned)
    folded = "".join(ch for ch in folded if unicodedata.category(ch) != "Cf")
    return cleaned, not _INJECTION_RE.search(folded)
def _short_id(user_id) -> str:
    """Return the first 8 characters of *user_id* for log output.

    ``str()`` is applied first because an id taken from ``User.id`` may not be
    a string (the model is not visible here); slicing a non-str would raise
    ``TypeError`` inside the logging / exception-handling paths.
    """
    return str(user_id)[:8]


async def _resolve_user_id(websocket: WebSocket, db, user_id):
    """Return the user id to serve, or ``None`` after rejecting the socket.

    A non-empty ``user_id`` from the query string is used as-is. Otherwise the
    first ``User`` row is used; if there is none, the socket is accepted only
    to send an error frame and is then closed.

    NOTE(security): the query-string id is neither authenticated nor checked
    against the database, and the fallback gives any anonymous client the
    first user's data. Presumably fine for a seeded demo — review before
    exposing this endpoint.
    """
    if user_id:
        return user_id
    user = db.query(User).first()
    if user is not None:
        return user.id
    await websocket.accept()
    await websocket.send_json({
        "type": "error",
        "message": "No users found. Run: python app/scripts/seed_demo.py"
    })
    await websocket.close()
    return None


async def _handle_chat(websocket: WebSocket, db, user_id, payload: dict) -> None:
    """Validate one ``chat`` payload and stream the AI reply to the client.

    Sends ``chat_start``, one ``chat_chunk`` per non-empty chunk, then
    ``chat_end``. Validation failures and AI errors are reported as
    ``error`` frames; ``WebSocketDisconnect`` propagates to the caller.
    """
    raw_prompt = payload.get("message", "")
    if not isinstance(raw_prompt, str):
        await websocket.send_json({"type": "error", "message": "Message must be a string"})
        return

    # Check emptiness AFTER sanitizing: a message made only of control
    # characters survives .strip() but is empty once they are removed.
    prompt, is_safe = sanitize_prompt(raw_prompt.strip())
    if not prompt.strip():
        await websocket.send_json({"type": "error", "message": "Message cannot be empty"})
        return
    if not is_safe:
        ws_logger.warning("Prompt injection attempt blocked", extra={"user_id": _short_id(user_id)})
        await websocket.send_json({
            "type": "error",
            "message": "I can only help with financial questions about your accounts."
        })
        return

    await websocket.send_json({"type": "chat_start"})
    try:
        # NOTE(review): stream_chat_response is consumed with a plain `for`, so
        # any blocking work it does between chunks runs on the event loop and
        # stalls other connections. Consider
        # fastapi.concurrency.iterate_in_threadpool if the DB session may be
        # used off-thread — TODO confirm.
        for chunk in stream_chat_response(db, user_id, prompt):
            if chunk:
                await websocket.send_json({"type": "chat_chunk", "content": chunk})
    except WebSocketDisconnect:
        # The client left mid-stream; that is not an AI failure.
        raise
    except Exception as e:
        ws_logger.error("AI streaming error", extra={"error": str(e)[:100]})
        await websocket.send_json({
            "type": "error",
            "message": "AI response failed. Please try again."
        })
    await websocket.send_json({"type": "chat_end"})


@router.websocket("/api/ai/chat/ws")
async def websocket_chat_endpoint(
    websocket: WebSocket,
    user_id: str = Query(None),
):
    """Serve a streaming AI chat session over a WebSocket.

    Client frames are JSON objects with a ``type`` field:
      * ``{"type": "ping"}`` -> ``{"type": "pong"}``
      * ``{"type": "chat", "message": "..."}`` (``type`` defaults to
        ``"chat"``) -> ``chat_start``, ``chat_chunk``..., ``chat_end``
    Per-message problems (bad JSON, unknown type, empty or blocked prompt)
    are answered with ``{"type": "error", "message": ...}`` and the loop
    continues.

    The DB session is closed, and once registered the socket is removed from
    ``ws_manager``, on every exit path (disconnect, error, cancellation).
    """
    db = SessionLocal()
    try:
        resolved_id = await _resolve_user_id(websocket, db, user_id)
        if resolved_id is None:
            return
        user_id = resolved_id

        await ws_manager.connect(websocket, user_id)
        try:
            metrics.ws_connects += 1
            ws_logger.info("WebSocket connected", extra={"user_id": _short_id(user_id)})
            while True:
                data = await websocket.receive_text()
                try:
                    payload = json.loads(data)
                except json.JSONDecodeError:
                    await websocket.send_json({"type": "error", "message": "Invalid JSON"})
                    continue
                if not isinstance(payload, dict):
                    # Valid JSON but not an object (e.g. a list): .get() would fail.
                    await websocket.send_json({"type": "error", "message": "Message must be a JSON object"})
                    continue

                msg_type = payload.get("type", "chat")
                if msg_type == "ping":
                    # Heartbeat
                    await websocket.send_json({"type": "pong"})
                elif msg_type == "chat":
                    await _handle_chat(websocket, db, user_id, payload)
                else:
                    await websocket.send_json({
                        "type": "error",
                        "message": f"Unknown message type: {msg_type}"
                    })
        except WebSocketDisconnect:
            ws_logger.info("WebSocket disconnected", extra={"user_id": _short_id(user_id)})
        except Exception as e:
            ws_logger.error("WebSocket error", extra={"user_id": _short_id(user_id), "error": str(e)[:100]})
        finally:
            # In `finally` so cancellation (a BaseException) also unregisters.
            ws_manager.disconnect(websocket, user_id)
    finally:
        db.close()