# middleware / app.py
# Author: PARTHA181098 — last commit "Update app.py" (8d543c5, verified)
import asyncio
import json
import logging
import os
import uuid
from enum import Enum
from typing import Optional, Dict, Any
from urllib.parse import urlencode

import websockets
from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect

# ------------------ ENV ------------------
# Load .env before importing the local security modules below, in case any of
# them reads environment variables at import time.
load_dotenv()

# ------------------ SECURITY IMPORTS ------------------
from pii_masking import mask_pii, contains_pii_regex
from prompt_injection import is_prompt_injection
from llm_pii_detector import contains_pii_llm
from base64_utils import (
    extract_base64_segments,
    replace_base64_with_decoded,
)
# ------------------ ENUMS ------------------
class SecurityCheckResult(Enum):
    """Outcome of screening one user message.

    The ``.value`` strings are what the WebSocket handler sends to clients
    in the ``status`` field of a ``security_warning`` frame.
    """

    # Message passed every check and may be forwarded.
    SAFE = "safe"
    # Regex-detectable PII found; the user may choose to mask it.
    PII_DETECTED_MASKABLE = "pii_detected_maskable"
    # PII that cannot be masked automatically; the user must rewrite.
    PII_DETECTED_NON_MASKABLE = "pii_detected_non_maskable"
    # The injection detector flagged the (masked) message.
    PROMPT_INJECTION = "prompt_injection"
    # Decoded base64 content was flagged as injection or PII.
    UNSAFE_ENCODED_CONTENT = "unsafe_encoded_content"
# ------------------ SECURITY MIDDLEWARE ------------------
class SecurityMiddleware:
    """Screen user input and model output, keeping per-session PII-choice state.

    When regex-detectable PII is found, the user is asked whether to mask it
    or rewrite the prompt, and the next message in that session is
    interpreted as the answer (see ``_handle_pii_choice``).
    """

    def __init__(self):
        # session_id -> per-session state dict (keys created in _get_session).
        self.sessions: Dict[str, Dict[str, Any]] = {}

    def _get_session(self, session_id: str) -> Dict[str, Any]:
        """Return the state dict for ``session_id``, creating it on first use."""
        if session_id not in self.sessions:
            self.sessions[session_id] = {
                "awaiting_pii_choice": False,
                "pii_pending_maskable": False,
                "pii_pending_masked_text": "",
                # Kept for compatibility; not written anywhere in this class.
                "original_message": "",
            }
        return self.sessions[session_id]

    def _clear_session(self, session_id: str) -> None:
        """Drop all state for ``session_id`` (no-op if it is unknown)."""
        self.sessions.pop(session_id, None)

    @staticmethod
    def _reset_pii_choice(session: Dict[str, Any]) -> None:
        """Leave the "awaiting PII choice" state and discard pending data.

        All pending fields are cleared together so that a later exchange can
        never pick up a masked message left over from an earlier one.
        """
        session["awaiting_pii_choice"] = False
        session["pii_pending_maskable"] = False
        session["pii_pending_masked_text"] = ""

    def _looks_like_mask_choice(self, text: str) -> bool:
        """Heuristically decide whether ``text`` means "mask my PII".

        NOTE(review): plain substring matching, so e.g. "don't mask" or
        "unmask" also count as a mask choice.
        """
        t = text.strip().lower()
        return ("mask" in t) or (t in {"m", "yes", "y"})

    def _looks_like_rewrite_choice(self, text: str) -> bool:
        """Heuristically decide whether ``text`` means "I'll rewrite the prompt"."""
        t = text.strip().lower()
        return (
            ("new" in t and "prompt" in t)
            or ("rewrite" in t)
            or ("rephrase" in t)
            or (t in {"n", "no"})
        )

    async def check_security(
        self,
        user_input: str,
        session_id: str
    ) -> tuple[SecurityCheckResult, str, Optional[str]]:
        """Screen one user message for the given session.

        Returns ``(result, message, processed_input)``. ``message`` is the
        user-facing warning (empty when safe). ``processed_input`` is the
        PII-masked text to forward, and is only non-None when ``result`` is
        ``SecurityCheckResult.SAFE``.
        """
        session = self._get_session(session_id)
        if session["awaiting_pii_choice"]:
            return await self._handle_pii_choice(user_input, session)
        return await self._run_security_checks(user_input, session)

    async def _handle_pii_choice(
        self,
        user_input: str,
        session: Dict[str, Any]
    ) -> tuple[SecurityCheckResult, str, Optional[str]]:
        """Interpret ``user_input`` as the answer to a pending PII prompt."""
        if self._looks_like_mask_choice(user_input):
            if not session["pii_pending_maskable"]:
                self._reset_pii_choice(session)
                return (
                    SecurityCheckResult.PII_DETECTED_NON_MASKABLE,
                    "Cannot safely mask this PII. Rewrite the prompt.",
                    None
                )
            masked = session["pii_pending_masked_text"]
            self._reset_pii_choice(session)
            return await self._run_security_checks(masked, session, skip_pii=True)

        if self._looks_like_rewrite_choice(user_input):
            self._reset_pii_choice(session)
            return (
                SecurityCheckResult.PII_DETECTED_NON_MASKABLE,
                "Please rewrite your prompt without PII.",
                None
            )

        if not session["pii_pending_maskable"]:
            # Non-maskable PII: the user was only told to rewrite, so this
            # message is their rewritten prompt. Screen it from scratch instead
            # of re-asking a mask/rewrite question that does not apply.
            self._reset_pii_choice(session)
            return await self._run_security_checks(user_input, session)

        return (
            SecurityCheckResult.PII_DETECTED_MASKABLE,
            "PII detected. Mask it or rewrite?",
            None
        )

    async def _run_security_checks(
        self,
        user_input: str,
        session: Dict[str, Any],
        skip_pii: bool = False
    ) -> tuple[SecurityCheckResult, str, Optional[str]]:
        """Run PII, encoded-content and prompt-injection checks on ``user_input``.

        ``skip_pii=True`` skips both the regex and the LLM PII checks. It is
        used after the user has accepted masking.
        NOTE(review): this also skips the LLM PII check on the masked text;
        confirm that is intended.
        """
        if not skip_pii:
            if contains_pii_regex(user_input):
                masked = mask_pii(user_input, user_tier="free")
                session.update({
                    "awaiting_pii_choice": True,
                    "pii_pending_maskable": True,
                    "pii_pending_masked_text": masked
                })
                return (
                    SecurityCheckResult.PII_DETECTED_MASKABLE,
                    "PII detected. Mask it or rewrite?",
                    None
                )
            if contains_pii_llm(user_input):
                # Not maskable. Record that explicitly (and drop any older
                # pending text) so a later "mask" reply is refused rather than
                # forwarding a stale masked message.
                session.update({
                    "awaiting_pii_choice": True,
                    "pii_pending_maskable": False,
                    "pii_pending_masked_text": ""
                })
                return (
                    SecurityCheckResult.PII_DETECTED_NON_MASKABLE,
                    "Sensitive PII detected. Rewrite prompt.",
                    None
                )

        masked_text = mask_pii(user_input, user_tier="free")

        # Decode any base64 payloads and screen the decoded form too, so that
        # encoding cannot be used to smuggle injections or PII past the checks.
        base64_segments = extract_base64_segments(masked_text)
        if base64_segments:
            decoded = replace_base64_with_decoded(masked_text, base64_segments)
            if is_prompt_injection(decoded) or contains_pii_llm(decoded):
                return (
                    SecurityCheckResult.UNSAFE_ENCODED_CONTENT,
                    "Unsafe encoded content detected.",
                    None
                )

        if is_prompt_injection(masked_text):
            return (
                SecurityCheckResult.PROMPT_INJECTION,
                "Prompt injection detected. Rephrase.",
                None
            )

        return (SecurityCheckResult.SAFE, "", masked_text)

    def sanitize_output(self, text: str) -> str:
        """Mask PII in model output before it is sent to the client."""
        return mask_pii(text, user_tier="free")
# ------------------ INIT ------------------
# The ASGI application object, plus one security filter shared by every
# connection (its session state is keyed by session_id).
app = FastAPI()
security_middleware = SecurityMiddleware()
# ------------------ WEBSOCKET ------------------
logger = logging.getLogger(__name__)


def _backend_is_open(ws: Any) -> bool:
    """Return True if ``ws`` is an open backend connection.

    Legacy ``websockets`` connections expose a boolean ``closed`` attribute;
    the newer asyncio implementation (websockets 14+) only exposes ``state``.
    Both are supported so the check cannot raise AttributeError.
    """
    if ws is None:
        return False
    closed = getattr(ws, "closed", None)
    if isinstance(closed, bool):
        return not closed
    state = getattr(ws, "state", None)
    return getattr(state, "name", None) == "OPEN"


def _backend_url_for(base_url: str, session_id: str) -> str:
    """Append ``session_id`` to ``base_url`` as a URL-encoded query parameter."""
    separator = "&" if "?" in base_url else "?"
    return f"{base_url}{separator}{urlencode({'session_id': session_id})}"


@app.websocket("/chat")
async def chat_ws(websocket: WebSocket):
    """Proxy a client chat WebSocket to the backend AI, filtering both ways.

    Each client text frame is screened with
    ``security_middleware.check_security``. Unsafe input is answered with a
    ``security_warning`` frame and is not forwarded. Safe (masked) input is
    sent to the backend, and backend ``response`` frames are re-masked and
    relayed until a ``complete`` frame arrives.
    """
    await websocket.accept()
    backend_ws = None

    # Clients that do not supply a session_id get a private, per-connection
    # id. A shared "default" id would let anonymous clients see each other's
    # pending PII choices and backend conversation history.
    requested_session = websocket.query_params.get("session_id")
    session_id = requested_session or f"anon-{uuid.uuid4().hex}"
    backend_url = os.getenv(
        "BACKEND_AI_URL",
        "wss://partha181098-backend.hf.space/chat"
    )

    try:
        while True:
            user_input = await websocket.receive_text()
            result, message, processed_input = await security_middleware.check_security(
                user_input, session_id
            )
            if result != SecurityCheckResult.SAFE:
                await websocket.send_json({
                    "type": "security_warning",
                    "message": message,
                    "status": result.value
                })
                continue

            if not _backend_is_open(backend_ws):
                backend_ws = await websockets.connect(
                    _backend_url_for(backend_url, session_id)
                )
            await backend_ws.send(processed_input)

            # Relay streamed backend frames one by one until the backend
            # signals that the answer is complete.
            while True:
                payload = json.loads(await backend_ws.recv())
                frame_type = payload.get("type")
                if frame_type == "response":
                    sanitized = security_middleware.sanitize_output(
                        payload.get("content", "")
                    )
                    await websocket.send_json({
                        "type": "response",
                        "content": sanitized
                    })
                elif frame_type == "complete":
                    await websocket.send_json({"type": "complete"})
                    break
    except WebSocketDisconnect:
        logger.info("Client disconnected (session %s)", session_id)
    except Exception:
        # Top-level boundary: log with traceback, then tell the client.
        logger.exception("Middleware error (session %s)", session_id)
        try:
            await websocket.send_json({
                "type": "error",
                "message": "Middleware failure"
            })
        except Exception:
            # Best effort only: the client socket may already be gone.
            pass
    finally:
        if _backend_is_open(backend_ws):
            try:
                await backend_ws.close()
            except Exception:
                logger.warning("Failed to close backend connection", exc_info=True)
        if not requested_session:
            # The generated id cannot be reused by another connection, so its
            # state would otherwise stay in memory forever.
            security_middleware._clear_session(session_id)
# ------------------ HEALTH ------------------
@app.get("/health")
async def health():
    """Liveness probe that always reports the service as up."""
    status_report = dict(status="ok")
    return status_report
# ------------------ RUN ------------------
if __name__ == "__main__":
    import uvicorn

    # The defaults match the previous hard-coded values; the HOST and PORT
    # environment variables allow overriding them per deployment.
    uvicorn.run(
        app,
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "8000")),
    )