Spaces:
Sleeping
Sleeping
import asyncio
import json
import os
import uuid
from enum import Enum
from typing import Any, Dict, Optional
from urllib.parse import urlencode

import websockets
from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect

# ------------------ ENV ------------------
# Load .env before the project modules below are imported, in case they
# read configuration at import time.
load_dotenv()

# ------------------ SECURITY IMPORTS ------------------
from pii_masking import mask_pii, contains_pii_regex
from prompt_injection import is_prompt_injection
from llm_pii_detector import contains_pii_llm
from base64_utils import (
    extract_base64_segments,
    replace_base64_with_decoded,
)
# ------------------ ENUMS ------------------
class SecurityCheckResult(Enum):
    """Outcome of screening one user message; ``.value`` is sent to clients as ``status``."""

    SAFE = "safe"
    PII_DETECTED_MASKABLE = "pii_detected_maskable"
    PII_DETECTED_NON_MASKABLE = "pii_detected_non_maskable"
    PROMPT_INJECTION = "prompt_injection"
    UNSAFE_ENCODED_CONTENT = "unsafe_encoded_content"


# ------------------ SECURITY MIDDLEWARE ------------------
class SecurityMiddleware:
    """Screen user input (PII, prompt injection, encoded payloads) per session.

    Each session carries a small state machine. When maskable PII is found,
    the masked text is parked in the session and the user is asked to choose
    "mask" or "rewrite". Their next message is then interpreted as that
    choice instead of being screened as a new prompt.
    """

    # Replies that mention "mask" but refuse it. They must be ruled out
    # before a bare "mask" substring is taken as consent.
    _MASK_NEGATIONS = ("n't mask", "dont mask", "not mask", "no mask", "unmask")

    def __init__(self):
        # session_id -> per-session state dict (keys: see _get_session).
        self.sessions: Dict[str, Dict[str, Any]] = {}

    def _get_session(self, session_id: str) -> Dict[str, Any]:
        """Return the state dict for *session_id*, creating it on first use.

        Keys:
            awaiting_pii_choice: the next message is a mask/rewrite answer.
            pii_pending_maskable: masking is an available option.
            pii_pending_masked_text: masked version of the flagged message.
            original_message: not read or written elsewhere in this class.
        """
        if session_id not in self.sessions:
            self.sessions[session_id] = {
                "awaiting_pii_choice": False,
                "pii_pending_maskable": False,
                "pii_pending_masked_text": "",
                "original_message": "",
            }
        return self.sessions[session_id]

    def _clear_session(self, session_id: str):
        """Drop all state for *session_id* (no-op if it does not exist)."""
        self.sessions.pop(session_id, None)

    @staticmethod
    def _reset_pii_state(session: Dict[str, Any]) -> None:
        """Forget any pending PII decision so stale masked text can never be reused."""
        session.update({
            "awaiting_pii_choice": False,
            "pii_pending_maskable": False,
            "pii_pending_masked_text": "",
            "original_message": "",
        })

    def _looks_like_mask_choice(self, text: str) -> bool:
        """Heuristically decide whether *text* answers "mask it"."""
        # Normalise the typographic apostrophe so "don’t mask" is caught too.
        t = text.strip().lower().replace("\u2019", "'")
        if any(neg in t for neg in self._MASK_NEGATIONS):
            return False
        return ("mask" in t) or (t in {"m", "yes", "y"})

    def _looks_like_rewrite_choice(self, text: str) -> bool:
        """Heuristically decide whether *text* answers "I'll rewrite"."""
        t = text.strip().lower()
        return (
            ("new" in t and "prompt" in t)
            or ("rewrite" in t)
            or ("rephrase" in t)
            or (t in {"n", "no"})
        )

    async def check_security(
        self,
        user_input: str,
        session_id: str
    ) -> tuple[SecurityCheckResult, str, Optional[str]]:
        """Screen *user_input* for *session_id*.

        Returns:
            ``(result, user_message, forward_text)``. ``forward_text`` is the
            (masked) text to send on when ``result`` is SAFE, otherwise None.
        """
        session = self._get_session(session_id)
        if session["awaiting_pii_choice"]:
            return await self._handle_pii_choice(user_input, session)
        return await self._run_security_checks(user_input, session)

    async def _handle_pii_choice(
        self,
        user_input: str,
        session: Dict[str, Any]
    ) -> tuple[SecurityCheckResult, str, Optional[str]]:
        """Interpret *user_input* as the answer to a pending PII prompt."""
        if self._looks_like_mask_choice(user_input):
            if not session["pii_pending_maskable"]:
                self._reset_pii_state(session)
                return (
                    SecurityCheckResult.PII_DETECTED_NON_MASKABLE,
                    "Cannot safely mask this PII. Rewrite the prompt.",
                    None
                )
            masked = session["pii_pending_masked_text"]
            self._reset_pii_state(session)
            # PII was already masked with the user's consent; still run the
            # encoded-content and injection checks on the masked text.
            return await self._run_security_checks(masked, session, skip_pii=True)

        if self._looks_like_rewrite_choice(user_input):
            # Clear the parked masked text too, otherwise a later "mask"
            # reply could forward this old message.
            self._reset_pii_state(session)
            return (
                SecurityCheckResult.PII_DETECTED_NON_MASKABLE,
                "Please rewrite your prompt without PII.",
                None
            )

        if not session["pii_pending_maskable"]:
            # Masking was never an option and the user was told to rewrite,
            # so treat this message as the rewritten prompt and screen it.
            self._reset_pii_state(session)
            return await self._run_security_checks(user_input, session)

        return (
            SecurityCheckResult.PII_DETECTED_MASKABLE,
            "PII detected. Mask it or rewrite?",
            None
        )

    async def _run_security_checks(
        self,
        user_input: str,
        session: Dict[str, Any],
        skip_pii: bool = False
    ) -> tuple[SecurityCheckResult, str, Optional[str]]:
        """Run PII, encoded-content and prompt-injection checks on *user_input*.

        Args:
            user_input: text to screen.
            session: this user's state dict (updated when PII is found).
            skip_pii: skip the top-level PII detection (input is already masked).
        """
        if not skip_pii:
            if contains_pii_regex(user_input):
                masked = mask_pii(user_input, user_tier="free")
                session.update({
                    "awaiting_pii_choice": True,
                    "pii_pending_maskable": True,
                    "pii_pending_masked_text": masked
                })
                return (
                    SecurityCheckResult.PII_DETECTED_MASKABLE,
                    "PII detected. Mask it or rewrite?",
                    None
                )
            if contains_pii_llm(user_input):
                # Not maskable: make sure nothing maskable from an earlier
                # turn is still pending.
                session.update({
                    "awaiting_pii_choice": True,
                    "pii_pending_maskable": False,
                    "pii_pending_masked_text": ""
                })
                return (
                    SecurityCheckResult.PII_DETECTED_NON_MASKABLE,
                    "Sensitive PII detected. Rewrite prompt.",
                    None
                )

        masked_text = mask_pii(user_input, user_tier="free")

        base64_segments = extract_base64_segments(masked_text)
        if base64_segments:
            decoded = replace_base64_with_decoded(masked_text, base64_segments)
            # Apply the same PII screening as plain input (cheap regex first),
            # so encoding cannot be used to smuggle structured PII past it.
            if (
                is_prompt_injection(decoded)
                or contains_pii_regex(decoded)
                or contains_pii_llm(decoded)
            ):
                return (
                    SecurityCheckResult.UNSAFE_ENCODED_CONTENT,
                    "Unsafe encoded content detected.",
                    None
                )

        if is_prompt_injection(masked_text):
            return (
                SecurityCheckResult.PROMPT_INJECTION,
                "Prompt injection detected. Rephrase.",
                None
            )

        return (SecurityCheckResult.SAFE, "", masked_text)

    def sanitize_output(self, text: str) -> str:
        """Mask PII in backend output before it is sent to the client."""
        return mask_pii(text, user_tier="free")
# ------------------ INIT ------------------
# The FastAPI application object.
app = FastAPI()

# One process-wide middleware instance; every websocket connection screens
# its messages through it, so its session table is shared across connections.
security_middleware = SecurityMiddleware()
# ------------------ WEBSOCKET ------------------
def _backend_is_open(ws) -> bool:
    """Return True if *ws* is a backend websocket that is still OPEN.

    Works across websockets client APIs. The legacy protocol exposes a
    boolean ``closed`` property, while the newer asyncio ``ClientConnection``
    (returned by ``websockets.connect`` in recent releases) only exposes
    ``state``. Both expose a ``state`` enum whose member name is "OPEN" while
    the connection is usable.
    """
    if ws is None:
        return False
    state = getattr(ws, "state", None)
    if state is not None:
        return getattr(state, "name", None) == "OPEN"
    return not getattr(ws, "closed", True)


def _backend_url_for(base_url: str, session_id: str) -> str:
    """Append *session_id* to *base_url* as a URL-encoded query parameter.

    Encoding prevents a client-chosen session id from injecting extra query
    parameters. An existing query string on *base_url* is preserved.
    """
    separator = "&" if "?" in base_url else "?"
    return f"{base_url}{separator}{urlencode({'session_id': session_id})}"


async def chat_ws(websocket: WebSocket):
    """Relay a client websocket chat to the AI backend, screening every message.

    Each incoming text message goes through ``security_middleware``. Anything
    not SAFE is answered with a ``security_warning`` frame and never reaches
    the backend. SAFE (masked) text is forwarded to the backend websocket.
    Its ``response`` chunks are re-masked and streamed back until a
    ``complete`` frame arrives.

    Clients that omit ``session_id`` get a private, per-connection session id,
    and that session's state is dropped on disconnect. Anonymous users
    therefore never share pending PII state or a backend session.
    """
    # NOTE(review): no route decorator (e.g. ``@app.websocket(...)``) is
    # visible for this handler; confirm it is actually registered on ``app``.
    await websocket.accept()
    backend_ws = None

    requested_session = websocket.query_params.get("session_id")
    ephemeral_session = not requested_session
    session_id = requested_session or f"anon-{uuid.uuid4().hex}"

    try:
        backend_url = os.getenv(
            "BACKEND_AI_URL",
            "wss://partha181098-backend.hf.space/chat"
        )

        while True:
            user_input = await websocket.receive_text()

            result, message, processed_input = await security_middleware.check_security(
                user_input, session_id
            )

            if result != SecurityCheckResult.SAFE:
                await websocket.send_json({
                    "type": "security_warning",
                    "message": message,
                    "status": result.value
                })
                continue

            # (Re)connect lazily; the backend socket may have been closed
            # since the previous message.
            if not _backend_is_open(backend_ws):
                backend_ws = await websockets.connect(
                    _backend_url_for(backend_url, session_id)
                )

            await backend_ws.send(processed_input)

            # Stream backend chunks to the client until "complete".
            # NOTE: each chunk is masked on its own, so PII split across two
            # chunks may not be detected by sanitize_output.
            while True:
                data = await backend_ws.recv()
                payload = json.loads(data)

                if payload["type"] == "response":
                    sanitized = security_middleware.sanitize_output(
                        payload["content"]
                    )
                    await websocket.send_json({
                        "type": "response",
                        "content": sanitized
                    })
                elif payload["type"] == "complete":
                    await websocket.send_json({"type": "complete"})
                    break

    except WebSocketDisconnect:
        print("Client disconnected")

    except Exception as e:
        print(f"Middleware error: {e}")
        try:
            await websocket.send_json({
                "type": "error",
                "message": "Middleware failure"
            })
        except Exception:
            # Best effort: the client socket may already be gone.
            pass

    finally:
        if _backend_is_open(backend_ws):
            await backend_ws.close()
        if ephemeral_session:
            # Nobody else can know this generated id, so its state is garbage now.
            security_middleware._clear_session(session_id)
# ------------------ HEALTH ------------------
async def health():
    """Return a constant status payload for health checks."""
    # NOTE(review): no route decorator is visible for this handler; confirm
    # it is registered on ``app``.
    payload = dict(status="ok")
    return payload
# ------------------ RUN ------------------
if __name__ == "__main__":
    import uvicorn

    # The listening port can be overridden with $PORT; it defaults to 8000
    # exactly as before when the variable is unset.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))