Spaces:
Sleeping
Sleeping
shivam0897-i
fix(backend): Convert PyTorch thread execution to bounded async pool to prevent OOM on HF Spaces
4eae08d | """ | |
| FastAPI application for AI-Generated Voice Detection. | |
| Endpoint: POST /api/voice-detection | |
| - Accepts Base64-encoded MP3 audio | |
| - Returns classification (AI_GENERATED or HUMAN) with confidence score | |
| """ | |
| import logging | |
| import asyncio | |
| import uuid | |
| import time | |
| import json | |
| import io | |
| import hmac | |
| from dataclasses import dataclass, field, asdict | |
| from datetime import datetime, timezone | |
| from typing import Optional, Any, Dict, List | |
| from contextlib import asynccontextmanager | |
| import numpy as np | |
| from fastapi import FastAPI, HTTPException, Request, Depends, WebSocket, WebSocketDisconnect, Security | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse, RedirectResponse | |
| from fastapi.security import APIKeyHeader | |
| from fastapi.exceptions import RequestValidationError | |
| from pydantic import BaseModel, Field, field_validator, ValidationError | |
| from slowapi import Limiter, _rate_limit_exceeded_handler | |
| from slowapi.util import get_remote_address | |
| from slowapi.errors import RateLimitExceeded | |
# Configure root logging once at import time so every module logger inherits it.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Global rate limiter keyed on client IP; generous default guards against floods.
limiter = Limiter(key_func=get_remote_address, default_limits=["1000/minute"])
| from audio_utils import decode_base64_audio, load_audio_from_bytes | |
| from model import analyze_voice, AnalysisResult | |
| from speech_to_text import transcribe_audio | |
| from fraud_language import analyze_transcript | |
| from llm_semantic_analyzer import analyze_semantic_with_llm, is_llm_semantic_provider_ready | |
| from privacy_utils import mask_sensitive_entities, sanitize_for_logging | |
| from config import settings | |
# redis is optional: when absent the session store falls back to memory
# (see initialize_session_store_backend).
try:
    import redis  # type: ignore
except Exception:  # pragma: no cover - optional dependency
    redis = None
# Base64 inflates payloads by 4/3, so allow that much over the raw-audio cap.
MAX_AUDIO_BASE64_LENGTH = settings.MAX_AUDIO_SIZE_MB * 1024 * 1024 * 4 // 3
@dataclass
class SessionState:
    """In-memory state for a real-time analysis session (derived data only).

    Must be a dataclass: the mutable defaults below use dataclasses.field
    factories, _serialize_session() relies on asdict(), and
    _deserialize_session() rebuilds instances via SessionState(**payload).
    Without the @dataclass decorator none of that works.
    """
    session_id: str
    language: str
    started_at: str
    status: str = "active"  # "active" until ended; drives retention TTL
    chunks_processed: int = 0
    alerts_triggered: int = 0
    max_risk_score: int = 0
    max_cpi: float = 0.0
    final_call_label: str = "UNCERTAIN"
    final_voice_classification: str = "UNCERTAIN"
    final_voice_confidence: float = 0.0
    max_voice_ai_confidence: float = 0.0
    max_voice_human_confidence: float = 0.0
    voice_ai_chunks: int = 0
    voice_human_chunks: int = 0
    llm_checks_performed: int = 0
    risk_policy_version: str = settings.RISK_POLICY_VERSION
    # Per-chunk risk scores, in arrival order.
    risk_history: List[int] = field(default_factory=list)
    # Frequency counters used by update_session_behaviour_state().
    transcript_counts: Dict[str, int] = field(default_factory=dict)
    semantic_flag_counts: Dict[str, int] = field(default_factory=dict)
    keyword_category_counts: Dict[str, int] = field(default_factory=dict)
    behaviour_score: int = 0
    session_behaviour_signals: List[str] = field(default_factory=list)
    last_transcript: str = ""
    last_update: Optional[str] = None
    alert_history: List[Dict[str, Any]] = field(default_factory=list)
    llm_last_engine: Optional[str] = None
# In-memory session store (used when the redis backend is inactive).
SESSION_STORE: Dict[str, SessionState] = {}
# Guards store-wide operations such as the periodic purge.
SESSION_LOCK = asyncio.Lock()
# Per-session locks serializing chunk processing within one session.
SESSION_LOCKS: Dict[str, asyncio.Lock] = {}
# Which backend is live: "memory" or "redis" (set at startup).
SESSION_STORE_BACKEND_ACTIVE = "memory"
REDIS_CLIENT: Any = None
# Bounded in-flight task registries: shed work instead of piling up threads.
ASR_INFLIGHT_TASKS: set[asyncio.Task] = set()
ASR_INFLIGHT_LOCK = asyncio.Lock()
VOICE_INFLIGHT_TASKS: set[asyncio.Task] = set()
VOICE_INFLIGHT_LOCK = asyncio.Lock()
def use_redis_session_store() -> bool:
    """Return whether redis-backed session store is active."""
    backend_is_redis = SESSION_STORE_BACKEND_ACTIVE == "redis"
    client_ready = REDIS_CLIENT is not None
    return backend_is_redis and client_ready
def initialize_session_store_backend() -> None:
    """Initialize configured session backend with safe fallback to memory."""
    global SESSION_STORE_BACKEND_ACTIVE, REDIS_CLIENT

    def _activate_memory(warning: Optional[str] = None) -> None:
        # Shared fallback path: optionally warn, then select the memory store.
        global SESSION_STORE_BACKEND_ACTIVE, REDIS_CLIENT
        if warning:
            logger.warning(warning)
        SESSION_STORE_BACKEND_ACTIVE = "memory"
        REDIS_CLIENT = None

    configured = str(getattr(settings, "SESSION_STORE_BACKEND", "memory") or "memory").strip().lower()
    if configured != "redis":
        _activate_memory()
        logger.info("Session store backend: memory")
        return
    if redis is None:
        _activate_memory("Redis backend requested but redis package is not installed. Falling back to memory store.")
        return
    redis_url = getattr(settings, "REDIS_URL", None)
    if not redis_url:
        _activate_memory("Redis backend requested but REDIS_URL is empty. Falling back to memory store.")
        return
    try:
        # Timeouts are configured in ms; redis-py expects seconds (floor 100ms).
        client = redis.Redis.from_url(
            redis_url,
            decode_responses=True,
            socket_connect_timeout=max(0.1, float(settings.REDIS_CONNECT_TIMEOUT_MS) / 1000.0),
            socket_timeout=max(0.1, float(settings.REDIS_IO_TIMEOUT_MS) / 1000.0),
        )
        client.ping()  # fail fast if the server is unreachable
    except Exception as exc:
        logger.warning("Failed to initialize redis session store (%s). Falling back to memory store.", exc)
        SESSION_STORE_BACKEND_ACTIVE = "memory"
        REDIS_CLIENT = None
        return
    REDIS_CLIENT = client
    SESSION_STORE_BACKEND_ACTIVE = "redis"
    logger.info("Session store backend: redis")
def _session_redis_key(session_id: str) -> str:
    """Build the namespaced redis key for one session."""
    prefix = settings.REDIS_PREFIX
    return "{}:session:{}".format(prefix, session_id)
def _serialize_session(session: SessionState) -> str:
    """Serialize a session to compact JSON for redis storage."""
    payload = asdict(session)
    return json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
def _deserialize_session(raw: Optional[str]) -> Optional[SessionState]:
    """Rebuild a SessionState from its JSON form; None for any bad payload."""
    if not raw:
        return None
    try:
        payload = json.loads(raw)
        # Only a JSON object maps onto the dataclass fields.
        return SessionState(**payload) if isinstance(payload, dict) else None
    except Exception as exc:
        logger.warning("Failed to deserialize session payload: %s", exc)
        return None
def get_session_state(session_id: str) -> Optional[SessionState]:
    """Fetch session state from active backend."""
    if not use_redis_session_store():
        return SESSION_STORE.get(session_id)
    raw = REDIS_CLIENT.get(_session_redis_key(session_id))
    return _deserialize_session(raw)
def save_session_state(session: SessionState) -> None:
    """Persist session state to active backend."""
    if not use_redis_session_store():
        SESSION_STORE[session.session_id] = session
        return
    # Redis TTL mirrors the retention policy so keys expire on their own.
    ttl_seconds = max(1, int(session_retention_seconds(session)))
    key = _session_redis_key(session.session_id)
    REDIS_CLIENT.set(key, _serialize_session(session), ex=ttl_seconds)
def delete_session_state(session_id: str) -> None:
    """Delete session from active backend."""
    if use_redis_session_store():
        REDIS_CLIENT.delete(_session_redis_key(session_id))
    else:
        SESSION_STORE.pop(session_id, None)
| def _asr_fallback_result(engine: str) -> Dict[str, Any]: | |
| return { | |
| "transcript": "", | |
| "confidence": 0.0, | |
| "engine": engine, | |
| "available": False, | |
| } | |
def _discard_asr_task(task: asyncio.Task) -> None:
    """Done-callback: drop a finished ASR task from the in-flight set."""
    if task in ASR_INFLIGHT_TASKS:
        ASR_INFLIGHT_TASKS.remove(task)
async def transcribe_audio_guarded(
    audio: np.ndarray,
    sr: int,
    language: str,
    timeout_seconds: float,
    request_id: str,
) -> Dict[str, Any]:
    """Run ASR with timeout and bounded in-flight tasks to avoid thread pileups.

    Returns transcribe_audio()'s result dict, or a fallback payload whose
    "engine" field is "busy" / "timeout" / "error" when transcription is
    skipped, exceeds the deadline, or raises.
    """
    max_inflight = max(1, int(getattr(settings, "ASR_MAX_INFLIGHT_TASKS", 1)))
    async with ASR_INFLIGHT_LOCK:
        # Reap finished tasks first so the cap only counts live work.
        stale_tasks = [task for task in ASR_INFLIGHT_TASKS if task.done()]
        for stale in stale_tasks:
            ASR_INFLIGHT_TASKS.discard(stale)
        if len(ASR_INFLIGHT_TASKS) >= max_inflight:
            # At capacity: shed this request rather than queueing more ASR threads.
            logger.warning(
                "[%s] Realtime ASR skipped (inflight=%s, max=%s); continuing without transcript",
                request_id,
                len(ASR_INFLIGHT_TASKS),
                max_inflight,
            )
            return _asr_fallback_result("busy")
        asr_task = asyncio.create_task(asyncio.to_thread(transcribe_audio, audio, sr, language))
        ASR_INFLIGHT_TASKS.add(asr_task)
        asr_task.add_done_callback(_discard_asr_task)
    try:
        # shield() keeps the worker-thread task alive on timeout (the thread
        # cannot be interrupted anyway); the done-callback above still removes
        # it from the in-flight set when it eventually finishes.
        return await asyncio.wait_for(asyncio.shield(asr_task), timeout=timeout_seconds)
    except asyncio.TimeoutError:
        logger.warning(
            "[%s] Realtime ASR timed out after %.0fms; continuing without transcript",
            request_id,
            timeout_seconds * 1000,
        )
        return _asr_fallback_result("timeout")
    except Exception as exc:
        logger.warning("[%s] Realtime ASR path failed: %s; continuing without transcript", request_id, exc)
        return _asr_fallback_result("error")
def _voice_fallback_result() -> AnalysisResult:
    """Neutral UNCERTAIN result returned when analysis is shed or times out."""
    fallback_kwargs = dict(
        label="UNCERTAIN",
        confidence=0.5,
        ai_confidence=0.0,
        human_confidence=0.0,
        error="Server busy or timeout.",
        processing_time_ms=0.0,
    )
    return AnalysisResult(**fallback_kwargs)
def _discard_voice_task(task: asyncio.Task) -> None:
    """Done-callback: drop a finished voice-analysis task from the in-flight set."""
    if task in VOICE_INFLIGHT_TASKS:
        VOICE_INFLIGHT_TASKS.remove(task)
async def analyze_voice_guarded(
    audio: np.ndarray,
    sr: int,
    timeout_seconds: float,
    request_id: str,
    language: str = "English",
    realtime: bool = False,
    source: str = "file"
) -> AnalysisResult:
    """Run voice analysis with bounded in-flight tasks to prevent OOM thread pileups.

    Returns analyze_voice()'s result, or an UNCERTAIN fallback when the
    server is at capacity, the call exceeds the deadline, or analysis raises.
    """
    max_inflight = max(1, int(getattr(settings, "VOICE_MAX_INFLIGHT_TASKS", 2)))
    async with VOICE_INFLIGHT_LOCK:
        # Reap finished tasks first so the cap only counts live work.
        stale_tasks = [task for task in VOICE_INFLIGHT_TASKS if task.done()]
        for stale in stale_tasks:
            VOICE_INFLIGHT_TASKS.discard(stale)
        if len(VOICE_INFLIGHT_TASKS) >= max_inflight:
            # Shed load instead of spawning another inference thread.
            logger.warning(
                "[%s] Voice analysis skipped (inflight=%s, max=%s); preventing OOM",
                request_id,
                len(VOICE_INFLIGHT_TASKS),
                max_inflight,
            )
            return _voice_fallback_result()
        voice_task = asyncio.create_task(asyncio.to_thread(analyze_voice, audio, sr, language, realtime, source))
        VOICE_INFLIGHT_TASKS.add(voice_task)
        voice_task.add_done_callback(_discard_voice_task)
    try:
        # shield() lets the worker thread finish after a timeout; the
        # done-callback still removes the task from the in-flight set.
        return await asyncio.wait_for(asyncio.shield(voice_task), timeout=timeout_seconds)
    except asyncio.TimeoutError:
        logger.warning(
            "[%s] Voice analysis timed out after %.0fms",
            request_id,
            timeout_seconds * 1000,
        )
        return _voice_fallback_result()
    except Exception as exc:
        logger.warning("[%s] Voice analysis failed: %s", request_id, exc)
        return _voice_fallback_result()
def warmup_audio_pipeline() -> None:
    """Warm audio decoding stack to reduce first-request latency spikes."""
    if not settings.AUDIO_PIPELINE_WARMUP_ENABLED:
        return
    try:
        import soundfile as sf
        # One second of silence at 16 kHz, round-tripped through the WAV codec.
        silence = np.zeros(16000, dtype=np.float32)
        buffer = io.BytesIO()
        sf.write(buffer, silence, 16000, format="WAV", subtype="PCM_16")
        load_audio_from_bytes(buffer.getvalue(), 16000, "wav")
        logger.info("Audio pipeline warm-up complete")
    except Exception as exc:
        logger.warning("Audio pipeline warm-up skipped: %s", exc)
def warmup_asr_pipeline() -> None:
    """Warm ASR model and transcription path on startup."""
    warmup_wanted = settings.ASR_ENABLED and settings.ASR_WARMUP_ENABLED
    if not warmup_wanted:
        return
    try:
        # One second of 16 kHz silence is enough to load the model weights.
        transcribe_audio(np.zeros(16000, dtype=np.float32), 16000, "English")
        logger.info("ASR warm-up complete")
    except Exception as exc:
        logger.warning("ASR warm-up skipped: %s", exc)
def warmup_voice_pipeline() -> None:
    """Run one inference pass to avoid first realtime-model cold latency spike."""
    if not settings.VOICE_WARMUP_ENABLED:
        return
    try:
        sample_rate = 16000
        seconds = 1.0
        n_samples = max(1, int(sample_rate * seconds))
        timeline = np.linspace(0.0, seconds, n_samples, endpoint=False, dtype=np.float32)
        # A quiet 220 Hz tone: non-silent audio avoids edge-case feature
        # paths and mirrors short speech chunks.
        tone = (0.08 * np.sin(2 * np.pi * 220 * timeline)).astype(np.float32)
        analyze_voice(tone, sample_rate, "English", True)
        logger.info("Voice model warm-up complete")
    except Exception as exc:
        logger.warning("Voice model warm-up skipped: %s", exc)
def run_startup_warmups() -> None:
    """Run non-critical startup warm-ups for latency-sensitive paths."""
    # Each warm-up is independent and internally best-effort.
    for warmup in (warmup_audio_pipeline, warmup_voice_pipeline, warmup_asr_pipeline):
        warmup()
# Detect environment: SPACE_ID is set by the HuggingFace Spaces runtime,
# so its presence means we are deployed there.
if settings.SPACE_ID:
    logger.info("Running on HuggingFace Spaces: %s", settings.SPACE_ID)
def get_session_lock(session_id: str) -> asyncio.Lock:
    """Return a per-session lock, creating one if needed."""
    # Safe without synchronization: the event loop is single-threaded and
    # there is no await between lookup and insertion.
    try:
        return SESSION_LOCKS[session_id]
    except KeyError:
        lock = asyncio.Lock()
        SESSION_LOCKS[session_id] = lock
        return lock
async def _periodic_session_purge(interval: int = 60) -> None:
    """Background task: purge expired sessions every *interval* seconds.

    Runs until cancelled at shutdown; any other error is logged and the
    loop continues with the next cycle.
    """
    while True:
        try:
            await asyncio.sleep(interval)
            async with SESSION_LOCK:
                removed = purge_expired_sessions()
                # Drop per-session locks whose sessions are gone so
                # SESSION_LOCKS cannot grow without bound.
                stale_lock_keys = [k for k in SESSION_LOCKS if k not in SESSION_STORE]
                for k in stale_lock_keys:
                    del SESSION_LOCKS[k]
            if removed:
                logger.info("Periodic purge removed %d expired session(s)", removed)
        except asyncio.CancelledError:
            # Propagate cancellation so shutdown can await this task cleanly.
            raise
        except Exception as exc:
            logger.warning("Periodic purge error: %s", exc)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan events.

    FastAPI's ``lifespan=`` hook requires an async context manager, so this
    generator must be wrapped with @asynccontextmanager (imported above) —
    without it the app fails at startup.
    """
    logger.info("Starting up - preloading ML model...")
    initialize_session_store_backend()
    # Model preload is best-effort: a failure is logged, and the first
    # request will pay the load cost instead.
    try:
        from model import preload_model
        preload_model()
        logger.info("ML model loaded successfully")
    except Exception as e:
        logger.error("Failed to preload model: %s", e)
    # Warm-ups run in a worker thread and never block or abort startup.
    try:
        await asyncio.to_thread(run_startup_warmups)
    except Exception as exc:
        logger.warning("Startup warm-ups encountered an issue: %s", exc)
    purge_task = asyncio.create_task(_periodic_session_purge())
    yield
    # Shutdown: stop the background purge loop cleanly.
    purge_task.cancel()
    try:
        await purge_task
    except asyncio.CancelledError:
        pass
    logger.info("Shutting down...")
# Public FastAPI application; interactive docs at /docs (Swagger) and /redoc.
app = FastAPI(
    title="AI Voice Detection API",
    description="Detects whether a voice sample is AI-generated or spoken by a real human",
    version="1.0.0",
    contact={
        "name": "Shivam",
        "url": settings.WEBSITE_URL,
    },
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan
)
# Wire the slowapi limiter into the app and register its 429 handler.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Per the CORS spec, wildcard origins cannot be combined with credentials,
# so credentials are disabled whenever ALLOWED_ORIGINS contains "*".
_cors_origins = settings.ALLOWED_ORIGINS
_cors_credentials = "*" not in _cors_origins
if not _cors_credentials:
    logger.warning("CORS allow_origins='*' — credentials disabled. Set ALLOWED_ORIGINS for production.")
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=_cors_credentials,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["Content-Type", "x-api-key", "Authorization"],
)
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """HTTP middleware: request-ID tagging, latency logging, security headers.

    Must be registered via @app.middleware("http"); otherwise FastAPI never
    invokes it and responses lose the X-Request-ID and security headers.
    """
    request_id = str(uuid.uuid4())[:8]
    request.state.request_id = request_id
    start_time = time.perf_counter()
    method = request.method
    path = request.url.path
    # Only POST endpoints (the analysis APIs) are verbose-logged to limit noise.
    if method == "POST":
        logger.info("[%s] [START] %s %s", request_id, method, path)
    response = await call_next(request)
    duration_ms = (time.perf_counter() - start_time) * 1000
    status_code = response.status_code
    if method == "POST":
        status_label = "[OK]" if status_code == 200 else "[ERR]" if status_code >= 400 else "[WARN]"
        logger.info("[%s] %s END %s %s -> %s (%0.fms)", request_id, status_label, method, path, status_code, duration_ms)
    response.headers["X-Request-ID"] = request_id
    response.headers["X-Response-Time"] = f"{duration_ms:.0f}ms"
    response.headers["X-Content-Type-Options"] = "nosniff"
    response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
    # Relax CSP to allow standard API documentation via CDNs (ReDoc/Swagger)
    response.headers["Content-Security-Policy"] = (
        "default-src 'self'; "
        "script-src 'self' 'unsafe-inline' 'unsafe-eval' https://cdn.jsdelivr.net; "
        "style-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net https://fonts.googleapis.com; "
        "font-src 'self' https://fonts.gstatic.com; "
        "img-src 'self' data: https://fastapi.tiangolo.com;"
    )
    return response
class VoiceDetectionRequest(BaseModel):
    """Request body for voice detection."""
    language: str = Field(default="Auto", description="Language hint (Auto, English, Hindi, Hinglish, Tamil, Malayalam, Telugu). Defaults to auto-detect.")
    audioFormat: str = Field(default="wav", description="Audio format (mp3, wav, flac, ogg, m4a, mp4, webm)")
    audioBase64: str = Field(..., description="Base64-encoded audio data")

    # Pydantic v2 validator: without these decorators the method is never
    # run and oversized/empty payloads reach the audio decoder unchecked.
    @field_validator("audioBase64")
    @classmethod
    def validate_audio_size(cls, v: str) -> str:
        """Validate audio data is not too small or too large."""
        if len(v) < 100:
            raise ValueError("Audio data too small - provide valid audio content")
        if len(v) > MAX_AUDIO_BASE64_LENGTH:
            raise ValueError(f"Audio data too large - maximum {settings.MAX_AUDIO_SIZE_MB}MB allowed")
        return v
class ForensicMetrics(BaseModel):
    """Detailed forensic analysis metrics (each on a 0-100 scale)."""
    authenticity_score: float = Field(..., description="Overall voice naturalness score (0-100)")
    pitch_naturalness: float = Field(..., description="Pitch stability and jitter score (0-100)")
    spectral_naturalness: float = Field(..., description="Spectral entropy and flatness score (0-100)")
    temporal_naturalness: float = Field(..., description="Rhythm and silence score (0-100)")
class VoiceDetectionResponse(BaseModel):
    """Successful response from voice detection."""
    status: str = "success"
    language: str
    classification: str  # AI_GENERATED or HUMAN
    confidenceScore: float = Field(..., ge=0.0, le=1.0)
    explanation: str
    # Optional detailed metrics block (may be None).
    forensic_metrics: Optional[ForensicMetrics] = None
    # Set when the model cannot classify with confidence.
    modelUncertain: bool = False
    recommendedAction: Optional[str] = None
class ErrorResponse(BaseModel):
    """Error response envelope: fixed "error" status plus a human-readable message."""
    status: str = "error"
    message: str
class SessionStartRequest(BaseModel):
    """Request body for creating a real-time analysis session."""
    language: str = Field(default="Auto", description="Language hint (Auto, English, Hindi, Hinglish, Tamil, Malayalam, Telugu). Defaults to auto-detect.")
class SessionStartResponse(BaseModel):
    """Response body after creating a session."""
    status: str = "success"
    session_id: str  # opaque ID the client sends with each chunk
    language: str
    started_at: str  # ISO-8601 UTC timestamp
    message: str
class SessionChunkRequest(BaseModel):
    """Audio chunk request for real-time analysis."""
    audioFormat: str = Field(default="wav", description="Audio format (mp3, wav, flac, ogg, m4a, mp4, webm)")
    audioBase64: str = Field(..., description="Base64-encoded audio chunk")
    language: Optional[str] = Field(default=None, description="Optional override. Defaults to session language")
    source: Optional[str] = Field(default="file", description="Audio source: 'mic' for browser microphone, 'file' for uploaded file")

    # Pydantic v2 validator: without these decorators the method is never
    # run and oversized/empty chunks reach the audio decoder unchecked.
    @field_validator("audioBase64")
    @classmethod
    def validate_chunk_size(cls, v: str) -> str:
        """Reject chunks that are implausibly small or exceed the size cap."""
        if len(v) < 100:
            raise ValueError("Audio data too small - provide valid audio content")
        if len(v) > MAX_AUDIO_BASE64_LENGTH:
            raise ValueError(f"Audio data too large - maximum {settings.MAX_AUDIO_SIZE_MB}MB allowed")
        return v
class RiskEvidence(BaseModel):
    """Model evidence used to produce risk score."""
    audio_patterns: List[str] = Field(default_factory=list)  # acoustic indicators
    keywords: List[str] = Field(default_factory=list)        # matched fraud keywords
    behaviour: List[str] = Field(default_factory=list)       # session-level behaviour signals
class RealTimeLanguageAnalysis(BaseModel):
    """Transcript and language risk signals for the current chunk."""
    transcript: str = ""
    transcript_confidence: float = Field(default=0.0, ge=0.0, le=1.0)
    asr_engine: str = "unavailable"  # which ASR produced the transcript
    keyword_hits: List[str] = Field(default_factory=list)
    keyword_categories: List[str] = Field(default_factory=list)
    semantic_flags: List[str] = Field(default_factory=list)
    # Component scores, each clamped to 0-100.
    keyword_score: int = Field(default=0, ge=0, le=100)
    semantic_score: int = Field(default=0, ge=0, le=100)
    behaviour_score: int = Field(default=0, ge=0, le=100)
    session_behaviour_signals: List[str] = Field(default_factory=list)
    # LLM-assisted semantic pass metadata (only when the optional LLM ran).
    llm_semantic_used: bool = False
    llm_semantic_confidence: float = Field(default=0.0, ge=0.0, le=1.0)
    llm_semantic_model: Optional[str] = None
class RealTimeAlert(BaseModel):
    """Alert details emitted by the risk engine; fields other than
    *triggered* are populated only when an alert fired."""
    triggered: bool
    alert_type: Optional[str] = None
    severity: Optional[str] = None
    reason_summary: Optional[str] = None
    recommended_action: Optional[str] = None
class ExplainabilitySignal(BaseModel):
    """Per-signal contribution to fused risk score."""
    signal: str
    raw_score: int = Field(..., ge=0, le=100)
    weight: float = Field(..., ge=0.0, le=1.0)
    # raw_score scaled by weight; bounded by the raw range.
    weighted_score: float = Field(..., ge=0.0, le=100.0)
class RealTimeExplainability(BaseModel):
    """Human-readable explainability block for chunk risk output."""
    summary: str
    top_indicators: List[str] = Field(default_factory=list)
    signal_contributions: List[ExplainabilitySignal] = Field(default_factory=list)
    # Present when the model flagged uncertainty about this chunk.
    uncertainty_note: Optional[str] = None
class RealTimeUpdateResponse(BaseModel):
    """Chunk-by-chunk update response."""
    status: str = "success"
    session_id: str
    timestamp: str  # ISO-8601 UTC
    risk_score: int = Field(..., ge=0, le=100)
    cpi: float = Field(..., ge=0.0, le=100.0, description="Conversational Pressure Index")
    risk_level: str  # LOW / MEDIUM / HIGH / CRITICAL
    call_label: str  # SAFE / SPAM / FRAUD / UNCERTAIN
    model_uncertain: bool = False
    voice_classification: str = "UNCERTAIN"
    voice_confidence: float = Field(default=0.0, ge=0.0, le=1.0)
    evidence: RiskEvidence
    language_analysis: RealTimeLanguageAnalysis
    alert: RealTimeAlert
    explainability: RealTimeExplainability
    chunks_processed: int = Field(..., ge=1)
    # Lets clients correlate scores with the policy that produced them.
    risk_policy_version: str = settings.RISK_POLICY_VERSION
class SessionSummaryResponse(BaseModel):
    """Summary response for a completed or active session."""
    status: str = "success"
    session_id: str
    language: str
    session_status: str  # "active" or "ended"
    started_at: str
    last_update: Optional[str] = None
    chunks_processed: int = 0
    alerts_triggered: int = 0
    # Session-level maxima across all processed chunks.
    max_risk_score: int = 0
    max_cpi: float = 0.0
    risk_level: str = "LOW"
    risk_label: str = "SAFE"
    final_call_label: str = "UNCERTAIN"
    final_voice_classification: str = "UNCERTAIN"
    final_voice_confidence: float = 0.0
    max_voice_ai_confidence: float = 0.0
    voice_ai_chunks: int = 0
    voice_human_chunks: int = 0
    llm_checks_performed: int = 0
    risk_policy_version: str = settings.RISK_POLICY_VERSION
    alert_history: List[Dict[str, Any]] = Field(default_factory=list)
class AlertHistoryItem(BaseModel):
    """One alert event emitted during session analysis."""
    timestamp: str  # ISO-8601 UTC, time the alert fired
    risk_score: int = Field(..., ge=0, le=100)
    risk_level: str
    call_label: str
    alert_type: str
    severity: str
    reason_summary: str
    recommended_action: str
class AlertHistoryResponse(BaseModel):
    """Paginated alert history for one session."""
    status: str = "success"
    session_id: str
    total_alerts: int  # total count, regardless of page size
    alerts: List[AlertHistoryItem] = Field(default_factory=list)
class RetentionPolicyResponse(BaseModel):
    """Explicit privacy and retention behavior for session processing."""
    status: str = "success"
    # Raw audio is never stored; only derived fields below are retained.
    raw_audio_storage: str = "not_persisted"
    active_session_retention_seconds: int
    ended_session_retention_seconds: int
    stored_derived_fields: List[str]
def utc_now_iso() -> str:
    """Return a UTC ISO-8601 timestamp (trailing 'Z' instead of '+00:00')."""
    now = datetime.now(timezone.utc)
    return now.isoformat().replace("+00:00", "Z")
# Derived (non-raw-audio) fields persisted per session; surfaced via the
# retention-policy response so clients can audit exactly what is stored.
STORED_DERIVED_FIELDS = [
    "risk_history",
    "behaviour_score",
    "session_behaviour_signals",
    "transcript_counts",
    "semantic_flag_counts",
    "keyword_category_counts",
    "chunks_processed",
    "alerts_triggered",
    "max_risk_score",
    "final_call_label",
    "voice_ai_chunks",
    "voice_human_chunks",
    "max_voice_ai_confidence",
    "final_voice_classification",
    "llm_checks_performed",
]
def parse_iso_timestamp(value: Optional[str]) -> Optional[float]:
    """Convert ISO timestamp to epoch seconds; None when absent or unparseable."""
    if value is None:
        return None
    # fromisoformat() on older Pythons rejects 'Z'; normalize it to an offset.
    normalized = value.replace("Z", "+00:00")
    try:
        parsed = datetime.fromisoformat(normalized)
    except ValueError:
        return None
    return parsed.timestamp()
def session_reference_timestamp(session: SessionState) -> Optional[float]:
    """Return the best available timestamp for retention checks."""
    # Prefer last_update; fall back to started_at when missing or unparseable.
    candidate = parse_iso_timestamp(session.last_update)
    if candidate:
        return candidate
    return parse_iso_timestamp(session.started_at)
def session_retention_seconds(session: SessionState) -> int:
    """Resolve retention policy from session status."""
    ended = session.status == "ended"
    return settings.SESSION_ENDED_RETENTION_SECONDS if ended else settings.SESSION_ACTIVE_RETENTION_SECONDS
def is_session_expired(session: SessionState, now_ts: Optional[float] = None) -> bool:
    """Check if a session exceeded status-specific retention TTL."""
    reference_ts = session_reference_timestamp(session)
    if reference_ts is None:
        # No usable timestamp: never expire rather than guess.
        return False
    current = time.time() if now_ts is None else now_ts
    age_seconds = current - reference_ts
    return age_seconds > session_retention_seconds(session)
def purge_expired_sessions(now_ts: Optional[float] = None) -> int:
    """Best-effort retention purge for stale sessions (memory backend)."""
    if use_redis_session_store():
        # Redis keys self-expire by TTL; no in-process purge needed.
        return 0
    current = time.time() if now_ts is None else now_ts
    expired_ids = [
        session_id
        for session_id, state in SESSION_STORE.items()
        if is_session_expired(state, current)
    ]
    for session_id in expired_ids:
        delete_session_state(session_id)
    return len(expired_ids)
def validate_supported_language(language: str) -> str:
    """Validate supported language. Falls back to 'Auto' for unknown languages so the
    evaluator never gets a 400 for an unexpected language hint."""
    if language in settings.SUPPORTED_LANGUAGES:
        return language
    logger.warning("Unsupported language '%s' — falling back to 'Auto'", language)
    return "Auto"
def validate_supported_format(audio_format: str) -> None:
    """Validate supported audio format; raise HTTP 400 for unknown formats."""
    if audio_format.lower() in settings.SUPPORTED_FORMATS:
        return
    supported = ', '.join(settings.SUPPORTED_FORMATS)
    raise HTTPException(
        status_code=400,
        detail={
            "status": "error",
            "message": f"Unsupported audio format. Must be one of: {supported}"
        }
    )
def normalize_transcript_for_behavior(transcript: str) -> str:
    """Normalize transcript for repetition and trend analysis."""
    # Lowercase, map punctuation to spaces, then collapse all whitespace runs.
    replaced = [c if c.isalnum() or c.isspace() else " " for c in transcript.lower()]
    return " ".join("".join(replaced).split())
def token_overlap_ratio(text_a: str, text_b: str) -> float:
    """Compute Jaccard overlap between token sets."""
    tokens_a, tokens_b = set(text_a.split()), set(text_b.split())
    if tokens_a and tokens_b:
        return len(tokens_a & tokens_b) / len(tokens_a | tokens_b)
    # Either side empty: no meaningful overlap.
    return 0.0
def dedupe_preserve_order(items: List[str]) -> List[str]:
    """Return unique string items while preserving first-seen order.

    Uses dict.fromkeys: dict preserves insertion order (Python 3.7+), so
    this replaces the manual seen-set/append loop with a single O(n) idiom.
    """
    return list(dict.fromkeys(items))
def update_session_behaviour_state(session: SessionState, language_analysis: Dict[str, Any]) -> Dict[str, Any]:
    """Update session-level behaviour score from transcript and semantic trends.

    Mutates *session* counters in place and returns the refreshed
    behaviour_score / session_behaviour_signals pair.
    """
    # Prefer the raw transcript key when present, else the (masked) transcript.
    transcript_source = str(language_analysis.get("transcript_raw", language_analysis.get("transcript", "")))
    transcript = normalize_transcript_for_behavior(transcript_source)
    semantic_flags = list(language_analysis.get("semantic_flags", []))
    keyword_categories = list(language_analysis.get("keyword_categories", []))
    # Accumulate per-session frequency counts for trend detection.
    for flag in semantic_flags:
        session.semantic_flag_counts[flag] = session.semantic_flag_counts.get(flag, 0) + 1
    for category in keyword_categories:
        session.keyword_category_counts[category] = session.keyword_category_counts.get(category, 0) + 1
    behavior_signals: List[str] = []
    if transcript:
        # Exact-repeat detection via normalized transcript counts.
        count = session.transcript_counts.get(transcript, 0) + 1
        session.transcript_counts[transcript] = count
        if count >= 2:
            behavior_signals.append("repetition_loop")
        # Near-repeat detection: high token overlap with the previous chunk
        # (only for transcripts of at least 4 tokens).
        if session.last_transcript:
            overlap = token_overlap_ratio(transcript, session.last_transcript)
            if overlap >= 0.75 and len(transcript.split()) >= 4:
                behavior_signals.append("repetition_loop")
        session.last_transcript = transcript
    urgency_count = session.semantic_flag_counts.get("urgency_language", 0)
    if urgency_count >= 2:
        behavior_signals.append("sustained_urgency")
    # Combination signals: flag pairs that are strong fraud indicators.
    has_impersonation = session.semantic_flag_counts.get("authority_impersonation", 0) > 0
    has_credentials = session.semantic_flag_counts.get("credential_request", 0) > 0
    has_payment = session.semantic_flag_counts.get("payment_redirection", 0) > 0
    has_threat = session.semantic_flag_counts.get("coercive_threat_language", 0) > 0
    has_urgency = urgency_count > 0
    if has_impersonation and has_credentials:
        behavior_signals.append("impersonation_plus_credential_request")
    if has_payment and has_urgency:
        behavior_signals.append("persistent_payment_pressure")
    if has_threat and has_urgency:
        behavior_signals.append("repeated_threat_urgency")
    repeated_categories = sum(1 for count in session.keyword_category_counts.values() if count >= 2)
    if repeated_categories >= 2:
        behavior_signals.append("repeated_fraud_categories")
    behavior_signals = sorted(set(behavior_signals))
    # Additive scoring with per-signal caps; the total is clamped to [0, 100].
    score = 0
    if "repetition_loop" in behavior_signals:
        max_repetition = max(session.transcript_counts.values()) if session.transcript_counts else 2
        score += 25 + min(15, (max_repetition - 2) * 5)
    if "sustained_urgency" in behavior_signals:
        score += 15 + min(10, (urgency_count - 2) * 5)
    if "impersonation_plus_credential_request" in behavior_signals:
        score += 30
    if "persistent_payment_pressure" in behavior_signals:
        score += 20
    if "repeated_threat_urgency" in behavior_signals:
        score += 15
    if "repeated_fraud_categories" in behavior_signals:
        score += 10
    session.behaviour_score = max(0, min(100, score))
    session.session_behaviour_signals = behavior_signals
    return {
        "behaviour_score": session.behaviour_score,
        "session_behaviour_signals": session.session_behaviour_signals,
    }
def map_score_to_level(score: int) -> str:
    """Translate a 0-100 numeric risk score into a discrete risk level."""
    # Ordered (exclusive upper bound, level) pairs; first bound that the
    # score falls under wins. Scores >= 80 fall through to CRITICAL.
    for upper_bound, level in ((35, "LOW"), (60, "MEDIUM"), (80, "HIGH")):
        if score < upper_bound:
            return level
    return "CRITICAL"
def map_level_to_label(risk_level: str, model_uncertain: bool) -> str:
    """Map risk level to user-friendly label."""
    # Model uncertainty overrides every level-derived label.
    if model_uncertain:
        return "UNCERTAIN"
    # HIGH and CRITICAL (and any unknown level) both collapse to FRAUD.
    return {"LOW": "SAFE", "MEDIUM": "SPAM"}.get(risk_level, "FRAUD")
def recommendation_for_level(risk_level: str, model_uncertain: bool) -> str:
    """Return a user action recommendation based on severity."""
    # Uncertainty takes precedence over the level-specific guidance.
    if model_uncertain:
        return "Model uncertainty detected. Avoid sharing OTP/PIN and verify caller via official channel."
    advice_by_level = {
        "CRITICAL": "High fraud risk. End the call and verify through an official support number.",
        "HIGH": "Fraud indicators detected. Do not share OTP, PIN, passwords, or UPI credentials.",
        "MEDIUM": "Suspicious call behavior detected. Verify caller identity before taking action.",
    }
    # LOW (and any unrecognized level) gets the benign default message.
    return advice_by_level.get(risk_level, "No high-risk fraud indicators detected in current chunk.")
def should_invoke_llm_semantic(
    provisional_scored: Dict[str, Any],
    transcript: str,
    transcript_confidence: float,
    next_chunk_index: int,
) -> bool:
    """Gate optional LLM semantic calls for ambiguous/uncertain chunks."""
    # Hard prerequisites: feature flag enabled and provider reachable.
    if not settings.LLM_SEMANTIC_ENABLED:
        return False
    if not is_llm_semantic_provider_ready():
        return False
    # Require minimally useful ASR text (empty/whitespace counts as too short).
    stripped = transcript.strip()
    if len(stripped) < 8:
        return False
    if transcript_confidence < settings.LLM_SEMANTIC_MIN_ASR_CONFIDENCE:
        return False
    # Rate-limit by chunk index: the first chunk is always eligible,
    # afterwards only every Nth chunk (N = configured interval, min 1).
    interval = max(1, settings.LLM_SEMANTIC_CHUNK_INTERVAL)
    if next_chunk_index > 1 and next_chunk_index % interval != 0:
        return False
    # Only spend an LLM call when the provisional score sits in the
    # ambiguous 35-79 band or the audio model flagged uncertainty.
    risk_score = int(provisional_scored.get("risk_score", 0))
    model_uncertain = bool(provisional_scored.get("model_uncertain", False))
    return model_uncertain or (35 <= risk_score < 80)
def normalize_voice_classification(classification: str, model_uncertain: bool) -> str:
    """Normalize realtime voice-authenticity classification."""
    if model_uncertain:
        return "UNCERTAIN"
    # Falsy input (None/empty) defaults to HUMAN before upper-casing.
    candidate = str(classification or "HUMAN").upper()
    # Anything outside the two recognized labels is treated as HUMAN.
    return candidate if candidate in ("AI_GENERATED", "HUMAN") else "HUMAN"
def build_explainability_payload(
    risk_level: str,
    call_label: str,
    model_uncertain: bool,
    cpi: float,
    audio_score: int,
    keyword_score: int,
    semantic_score: int,
    behaviour_score: int,
    has_language_signals: bool,
    behaviour_signals: List[str],
    keyword_hits: List[str],
    acoustic_anomaly: float,
    risk_score: int = 0,
    delta_boost: int = 0,
    voice_classification: str = "UNCERTAIN",
    voice_confidence: float = 0.0,
    authenticity_score: float = 50.0,
) -> RealTimeExplainability:
    """Build explicit explainability signals and concise summary."""
    # Weighting: blend all four signals when language evidence exists;
    # otherwise the assessment is purely audio-driven.
    if has_language_signals:
        configured = {
            "audio": settings.RISK_WEIGHT_AUDIO,
            "keywords": settings.RISK_WEIGHT_KEYWORD,
            "semantic": settings.RISK_WEIGHT_SEMANTIC,
            "behaviour": settings.RISK_WEIGHT_BEHAVIOUR,
        }
        weight_total = sum(configured.values()) or 1.0
        weights = {name: value / weight_total for name, value in configured.items()}
    else:
        weights = {
            "audio": 1.00,
            "keywords": 0.00,
            "semantic": 0.00,
            "behaviour": 0.00,
        }
    # Per-signal contribution rows: (public label, weight key, raw score).
    signal_rows = [
        ("audio", "audio", audio_score),
        ("keywords", "keywords", keyword_score),
        ("semantic_intent", "semantic", semantic_score),
        ("behaviour", "behaviour", behaviour_score),
    ]
    contributions = [
        ExplainabilitySignal(
            signal=label,
            raw_score=raw,
            weight=weights[key],
            weighted_score=round(raw * weights[key], 2),
        )
        for label, key, raw in signal_rows
    ]
    # Build top indicators from actual detection signals
    indicators: List[str] = []
    if voice_classification == "AI_GENERATED":
        indicators.append("ai_voice_detected")
        if voice_confidence >= 0.90:
            indicators.append("high_confidence_synthetic")
    if authenticity_score < 40:
        indicators.append("low_authenticity_score")
    if acoustic_anomaly >= 45:
        indicators.append("acoustic_anomaly_detected")
    indicators.extend(behaviour_signals)
    indicators.extend(keyword_hits[:3])
    # dict.fromkeys preserves first-seen order while deduplicating; cap at 6.
    deduped_indicators = list(dict.fromkeys(indicators))[:6]
    base_from_signals = sum(entry.weighted_score for entry in contributions)
    summary_parts: List[str] = [
        f"{risk_level.title()} risk classified as {call_label}."
    ]
    if delta_boost > 0:
        summary_parts.append(f"Score {risk_score} (base {int(base_from_signals)} + {delta_boost} trend boost).")
    summary_parts.append(f"CPI at {cpi:.1f}/100.")
    # Conditional commentary, appended in a fixed priority order.
    extra_notes = (
        (acoustic_anomaly >= 60, "Audio anomalies are materially elevated."),
        (keyword_score >= 45, "Fraud-related keywords contribute to the score."),
        (semantic_score >= 45, "Semantic coercion patterns were detected."),
        (behaviour_score >= 40, "Session behavior trend increases risk."),
        (cpi >= 70, "Pressure escalation velocity is high; early warning triggered."),
        (not has_language_signals, "Assessment is currently audio-dominant."),
    )
    summary_parts.extend(text for applies, text in extra_notes if applies)
    uncertainty_note = (
        "Model confidence is limited for this chunk. Treat this result conservatively and verify through trusted channels."
        if model_uncertain
        else None
    )
    return RealTimeExplainability(
        summary=" ".join(summary_parts),
        top_indicators=deduped_indicators,
        signal_contributions=contributions,
        uncertainty_note=uncertainty_note,
    )
def build_risk_update(
    result_features: Dict[str, float],
    classification: str,
    confidence: float,
    language_analysis: Dict[str, Any],
    previous_score: Optional[int],
    llm_semantic: Optional[Dict[str, Any]] = None,
    session_ai_context: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Build risk score, evidence and alert from model outputs and session trend.

    Combines four signal families (audio, keywords, semantic intent,
    behaviour) into a single 0-100 risk score, then applies trend-based
    boosts, dampening guards and escalation floors before assembling the
    evidence, language-analysis, alert and explainability payloads.

    Args:
        result_features: Numeric feature map from the audio model
            (authenticity_score, acoustic_anomaly_score, ml_fallback,
            realtime_heuristic_mode, audio_source).
        classification: Audio verdict string; "AI_GENERATED" triggers the
            confidence-driven audio scoring branch.
        confidence: Model confidence for the classification, expected in [0, 1].
        language_analysis: Transcript-derived scores, keyword hits and
            semantic/behaviour flags for the current chunk.
        previous_score: Last chunk's risk score for delta/trend detection,
            or None for the first chunk.
        llm_semantic: Optional LLM semantic result; blended into the
            semantic score when it reports available=True.
        session_ai_context: Optional dict with session-level AI detection state:
            - voice_ai_chunks (int): Number of chunks classified as AI so far
            - max_ai_confidence (float): Highest AI confidence seen in session
            - chunks_processed (int): Total chunks processed so far
            - risk_history (list[int]): Previous risk scores in session

    Returns:
        Dict with risk_score, cpi, risk_level, call_label, model_uncertain,
        evidence (RiskEvidence), language_analysis (RealTimeLanguageAnalysis),
        alert (RealTimeAlert) and explainability (RealTimeExplainability).
    """
    # --- Unpack session context (safe defaults for the first chunk). ---
    _ai_ctx = session_ai_context or {}
    _voice_ai_chunks = int(_ai_ctx.get("voice_ai_chunks", 0))
    _max_ai_conf = float(_ai_ctx.get("max_ai_confidence", 0.0))
    _chunks_processed = int(_ai_ctx.get("chunks_processed", 0))
    _risk_history: List[int] = list(_ai_ctx.get("risk_history", []))
    authenticity = float(result_features.get("authenticity_score", 50.0))
    acoustic_anomaly = float(result_features.get("acoustic_anomaly_score", 0.0))
    ml_fallback = bool(result_features.get("ml_fallback", 0.0))
    realtime_heuristic_mode = bool(result_features.get("realtime_heuristic_mode", 0.0))
    _audio_source = str(result_features.get("audio_source", "file"))
    normalized_classification = str(classification or "").upper()
    # --- Uncertainty detection. ---
    # A non-AI verdict with confidence below 0.50 and zero language signals
    # gives nothing to corroborate the classification.
    low_confidence_uncertain = bool(
        normalized_classification != "AI_GENERATED"
        and float(confidence) < 0.50
        and int(language_analysis.get("keyword_score", 0)) == 0
        and int(language_analysis.get("semantic_score", 0)) == 0
        and int(language_analysis.get("behaviour_score", 0)) == 0
    )
    # Heuristic-mode (non-ML) non-AI verdicts need >= 0.90 confidence to be trusted.
    heuristic_uncertain = bool(
        realtime_heuristic_mode
        and normalized_classification != "AI_GENERATED"
        and float(confidence) < 0.90
    )
    model_uncertain = ml_fallback or low_confidence_uncertain or heuristic_uncertain
    # --- Language-analysis inputs for this chunk. ---
    keyword_score = int(language_analysis.get("keyword_score", 0))
    semantic_score = int(language_analysis.get("semantic_score", 0))
    behaviour_score = int(language_analysis.get("behaviour_score", 0))
    keyword_hits = dedupe_preserve_order(list(language_analysis.get("keyword_hits", [])))
    behavior_from_language = dedupe_preserve_order(list(language_analysis.get("behaviour_signals", [])))
    behavior_from_session = dedupe_preserve_order(list(language_analysis.get("session_behaviour_signals", [])))
    keyword_categories = dedupe_preserve_order(list(language_analysis.get("keyword_categories", [])))
    semantic_flags = dedupe_preserve_order(list(language_analysis.get("semantic_flags", [])))
    transcript = str(language_analysis.get("transcript", "")).strip()
    # --- Optional LLM semantic blend. ---
    llm_semantic_used = False
    llm_semantic_confidence = 0.0
    llm_semantic_model: Optional[str] = None
    if llm_semantic and llm_semantic.get("available"):
        # Weighted average of local and LLM semantic scores, both clamped to [0, 100].
        blend_weight = max(0.0, min(1.0, settings.LLM_SEMANTIC_BLEND_WEIGHT))
        llm_score = int(max(0, min(100, llm_semantic.get("semantic_score", semantic_score))))
        semantic_score = int(round((semantic_score * (1.0 - blend_weight)) + (llm_score * blend_weight)))
        llm_semantic_confidence = float(max(0.0, min(1.0, llm_semantic.get("confidence", 0.0))))
        llm_semantic_model = str(llm_semantic.get("model") or settings.LLM_SEMANTIC_MODEL)
        llm_semantic_used = True
        # LLM keyword hints add up to +18 to the keyword score (6 per hint, capped).
        keyword_hints = dedupe_preserve_order([str(x) for x in llm_semantic.get("keyword_hints", [])])
        if keyword_hints:
            keyword_hits = dedupe_preserve_order(keyword_hits + keyword_hints)
            keyword_score = min(100, keyword_score + min(18, len(keyword_hints) * 6))
        llm_flags = dedupe_preserve_order([str(x) for x in llm_semantic.get("semantic_flags", [])])
        if llm_flags:
            semantic_flags = dedupe_preserve_order(semantic_flags + llm_flags)
        llm_behaviour = dedupe_preserve_order([str(x) for x in llm_semantic.get("behaviour_signals", [])])
        if llm_behaviour:
            behavior_from_language = dedupe_preserve_order(behavior_from_language + llm_behaviour)
    # Audio signal risk.
    if classification == "AI_GENERATED":
        # AI branch: score is the stronger of confidence vs scaled anomaly.
        confidence_audio = int(round(confidence * 100))
        anomaly_audio = int(max(0.0, min(100.0, acoustic_anomaly * 0.85)))
        audio_score = max(confidence_audio, anomaly_audio)
        # Dampen audio_score when signal forensics contradict AI classification
        # for mic-source audio (browser mic has authenticity 34-60 naturally).
        if authenticity > 35 and _audio_source == "mic":
            auth_dampen = max(0.50, 1.0 - (authenticity - 35) / 100.0)
            audio_score = int(round(audio_score * auth_dampen))
    else:
        # Non-AI branch: derive risk from low authenticity and anomaly level.
        authenticity_audio_score = int(max(0, min(100, (50.0 - authenticity) * 1.2)))
        # Mic audio has higher spectral anomaly (40-78); use lower multiplier.
        _anomaly_mult = 0.55 if _audio_source == "mic" else 0.90
        anomaly_audio_score = int(max(0.0, min(100.0, acoustic_anomaly * _anomaly_mult)))
        audio_score = max(authenticity_audio_score, anomaly_audio_score)
    # --- Weighted base score (language-aware) or audio-only fallback. ---
    has_language_signals = bool(transcript) or keyword_score > 0 or semantic_score > 0 or behaviour_score > 0
    if has_language_signals:
        raw_weights = {
            "audio": settings.RISK_WEIGHT_AUDIO,
            "keywords": settings.RISK_WEIGHT_KEYWORD,
            "semantic": settings.RISK_WEIGHT_SEMANTIC,
            "behaviour": settings.RISK_WEIGHT_BEHAVIOUR,
        }
        total_weight = sum(raw_weights.values())
        # Guard against misconfigured (zero/negative) weights.
        if total_weight <= 0:
            raw_weights = {"audio": 0.45, "keywords": 0.20, "semantic": 0.15, "behaviour": 0.20}
            total_weight = 1.0
        normalized = {k: v / total_weight for k, v in raw_weights.items()}
        base_score = int(
            round(
                (audio_score * normalized["audio"])
                + (keyword_score * normalized["keywords"])
                + (semantic_score * normalized["semantic"])
                + (behaviour_score * normalized["behaviour"])
            )
        )
    else:
        base_score = audio_score
    # ML fallback mode is conservative: never score below 55.
    if ml_fallback:
        base_score = max(base_score, 55)
    risk_score = max(0, min(100, base_score))
    # --- Behaviour signals from per-chunk thresholds and session trend. ---
    behaviour_signals: List[str] = list(behavior_from_language) + list(behavior_from_session)
    if keyword_score >= 60:
        behaviour_signals.append("keyword_cluster_detected")
    if semantic_score >= 60:
        behaviour_signals.append("semantic_coercion_detected")
    if behaviour_score >= 40:
        behaviour_signals.append("behaviour_risk_elevated")
    if acoustic_anomaly >= 60:
        behaviour_signals.append("acoustic_anomaly_detected")
    if previous_score is not None:
        delta = risk_score - previous_score
        if delta >= 15:
            behaviour_signals.append("rapid_risk_escalation")
        if risk_score >= 70 and previous_score >= 70:
            behaviour_signals.append("sustained_high_risk")
    else:
        delta = 0
    # Positive trend adds a configurable boost proportional to the delta.
    delta_boost = 0
    if delta > 0:
        delta_boost = int(delta * settings.RISK_DELTA_BOOST_FACTOR)
        risk_score = min(100, risk_score + delta_boost)
    # Risk dampening: prevent single-chunk LOW→CRITICAL jumps.
    if previous_score is not None and previous_score < 60 and risk_score >= 80:
        recent_high = sum(1 for s in _risk_history[-5:] if s >= 60)
        if recent_high < 2:
            risk_score = min(risk_score, 79)
            behaviour_signals.append("risk_dampened_no_prior_high")
    # First-chunk guard: cap noise-only first chunks at MEDIUM.
    if _chunks_processed == 0 and risk_score > 60:
        has_strong_signal = (
            (classification == "AI_GENERATED" and confidence >= 0.80)
            or acoustic_anomaly >= 60
            or keyword_score >= 40
            or semantic_score >= 40
        )
        if not has_strong_signal:
            risk_score = 60
            behaviour_signals.append("first_chunk_capped")
    # Cumulative escalation for sustained moderate signals.
    if len(_risk_history) >= 3 and risk_score >= 40:
        recent_moderate = sum(1 for s in _risk_history[-5:] if s >= 40)
        if recent_moderate >= 3:
            cumulative_boost = min(15, recent_moderate * 3)
            risk_score = min(100, risk_score + cumulative_boost)
            behaviour_signals.append("sustained_moderate_risk")
    # Sustained AI voice floor escalation.
    # floor = 70 + min(20, ai_chunks * 5)
    if classification == "AI_GENERATED" and confidence >= 0.92:
        ai_floor = 70 + min(20, _voice_ai_chunks * 5)
        risk_score = max(risk_score, ai_floor)
        if _voice_ai_chunks >= 2:
            behaviour_signals.append("sustained_ai_voice")
    # AI-voice-aware CPI includes synthetic voice ratio.
    _ai_ratio = (_voice_ai_chunks / max(1, _chunks_processed)) if _chunks_processed > 0 else 0.0
    if previous_score is None:
        # First chunk: no delta term available for the pressure index.
        cpi = min(100.0, max(0.0,
            (behaviour_score * 0.35)
            + (semantic_score * 0.20)
            + (_ai_ratio * _max_ai_conf * 40.0)
        ))
    else:
        cpi = min(100.0, max(0.0,
            (max(0, delta) * 3.2)
            + (behaviour_score * 0.35)
            + (semantic_score * 0.15)
            + (_ai_ratio * _max_ai_conf * 40.0)
        ))
    if cpi >= 70:
        behaviour_signals.append("cpi_spike_detected")
    behaviour_signals = dedupe_preserve_order(behaviour_signals)
    # --- Level, label and human-readable evidence. ---
    risk_level = map_score_to_level(risk_score)
    call_label = map_level_to_label(risk_level, model_uncertain)
    audio_patterns = [
        f"classification:{classification.lower()}",
        f"model_confidence:{confidence:.2f}",
        f"authenticity_score:{authenticity:.1f}",
        f"acoustic_anomaly_score:{acoustic_anomaly:.1f}",
        f"audio_score:{audio_score}",
    ]
    if ml_fallback:
        audio_patterns.append("model_fallback:true")
    audio_patterns = dedupe_preserve_order(audio_patterns)
    # Signals that on their own justify an alert regardless of risk level.
    strong_intent = {
        "authority_with_credential_request",
        "urgent_payment_pressure",
        "threat_plus_urgency",
        "impersonation_plus_credential_request",
        "persistent_payment_pressure",
        "repeated_threat_urgency",
    }
    alert_triggered = (
        risk_level in {"HIGH", "CRITICAL"}
        or "rapid_risk_escalation" in behaviour_signals
        or cpi >= 70
        or any(signal in behaviour_signals for signal in strong_intent)
    )
    # First-chunk alert guard: suppress unless CRITICAL or strong intent.
    if alert_triggered and _chunks_processed == 0:
        has_strong_intent = any(s in behaviour_signals for s in strong_intent)
        if risk_level != "CRITICAL" and not has_strong_intent:
            alert_triggered = False
            behaviour_signals.append("first_chunk_alert_suppressed")
    # --- Alert payload (only populated when triggered). ---
    alert_type = None
    severity = None
    reason_summary = None
    recommended_action = None
    if alert_triggered:
        if risk_level == "CRITICAL":
            alert_type = "FRAUD_RISK_CRITICAL"
        elif cpi >= 70:
            alert_type = "EARLY_PRESSURE_WARNING"
        elif "rapid_risk_escalation" in behaviour_signals:
            alert_type = "RISK_ESCALATION"
        else:
            alert_type = "FRAUD_RISK_HIGH"
        severity = risk_level.lower()
        reasons: List[str] = []
        if keyword_hits:
            reasons.append(f"fraud keywords detected ({', '.join(keyword_hits[:3])})")
        if semantic_score >= 45:
            reasons.append("coercive intent patterns detected")
        if behaviour_score >= 40:
            reasons.append("session behavior risk elevated")
        if "repetition_loop" in behaviour_signals:
            reasons.append("repetition loop detected")
        if "rapid_risk_escalation" in behaviour_signals:
            reasons.append(f"risk escalated rapidly (+{delta} points)")
        if cpi >= 70:
            reasons.append(f"conversational pressure index spiked ({cpi:.0f})")
        if "sustained_ai_voice" in behaviour_signals:
            reasons.append(f"sustained AI-generated voice across {_voice_ai_chunks} chunks")
        elif classification == "AI_GENERATED" and confidence >= 0.92:
            reasons.append(f"AI-generated voice detected ({confidence:.0%} confidence)")
        if "risk_dampened_no_prior_high" in behaviour_signals:
            reasons.append("risk capped — awaiting corroboration from additional chunks")
        if not reasons:
            reasons.append("high-risk audio pattern detected")
        reason_summary = ". ".join(reasons).capitalize() + "."
        recommended_action = recommendation_for_level(risk_level, model_uncertain)
    explainability = build_explainability_payload(
        risk_level=risk_level,
        call_label=call_label,
        model_uncertain=model_uncertain,
        cpi=cpi,
        audio_score=audio_score,
        keyword_score=keyword_score,
        semantic_score=semantic_score,
        behaviour_score=behaviour_score,
        has_language_signals=has_language_signals,
        behaviour_signals=behaviour_signals,
        keyword_hits=keyword_hits,
        acoustic_anomaly=acoustic_anomaly,
        risk_score=risk_score,
        delta_boost=delta_boost,
        voice_classification=classification,
        voice_confidence=confidence,
        authenticity_score=authenticity,
    )
    return {
        "risk_score": risk_score,
        "cpi": round(cpi, 1),
        "risk_level": risk_level,
        "call_label": call_label,
        "model_uncertain": model_uncertain,
        "evidence": RiskEvidence(
            audio_patterns=audio_patterns,
            keywords=keyword_hits,
            behaviour=behaviour_signals
        ),
        "language_analysis": RealTimeLanguageAnalysis(
            transcript=transcript,
            transcript_confidence=float(language_analysis.get("transcript_confidence", 0.0)),
            asr_engine=str(language_analysis.get("asr_engine", "unavailable")),
            keyword_hits=keyword_hits,
            keyword_categories=keyword_categories,
            semantic_flags=semantic_flags,
            keyword_score=keyword_score,
            semantic_score=semantic_score,
            behaviour_score=behaviour_score,
            session_behaviour_signals=behavior_from_session,
            llm_semantic_used=llm_semantic_used,
            llm_semantic_confidence=llm_semantic_confidence,
            llm_semantic_model=llm_semantic_model,
        ),
        "alert": RealTimeAlert(
            triggered=alert_triggered,
            alert_type=alert_type,
            severity=severity,
            reason_summary=reason_summary,
            recommended_action=recommended_action
        ),
        "explainability": explainability,
    }
| async def process_audio_chunk( | |
| session_id: str, | |
| chunk_request: SessionChunkRequest, | |
| default_language: str, | |
| request_id: str | |
| ) -> RealTimeUpdateResponse: | |
| """Decode, analyze and score a real-time audio chunk.""" | |
| chunk_language = chunk_request.language or default_language | |
| validate_supported_language(chunk_language) | |
| validate_supported_format(chunk_request.audioFormat) | |
| audio_size_kb = len(chunk_request.audioBase64) * 3 / 4 / 1024 | |
| logger.info( | |
| f"[{request_id}] Realtime chunk: session={session_id}, language={chunk_language}, " | |
| f"format={chunk_request.audioFormat}, size~{audio_size_kb:.1f}KB" | |
| ) | |
| decode_start = time.perf_counter() | |
| audio_bytes = await asyncio.to_thread(decode_base64_audio, chunk_request.audioBase64) | |
| decode_ms = (time.perf_counter() - decode_start) * 1000 | |
| load_start = time.perf_counter() | |
| audio, sr = await asyncio.to_thread(load_audio_from_bytes, audio_bytes, 16000, chunk_request.audioFormat) | |
| load_ms = (time.perf_counter() - load_start) * 1000 | |
| duration_sec = len(audio) / sr | |
| logger.info( | |
| f"[{request_id}] Realtime analyze {duration_sec:.2f}s (decode {decode_ms:.0f}ms, load {load_ms:.0f}ms)" | |
| ) | |
| analyze_start = time.perf_counter() | |
| try: | |
| analysis_result = await analyze_voice_guarded( | |
| audio, sr, 3.0, request_id, chunk_language, True, chunk_request.source or "file" | |
| ) | |
| except Exception as exc: | |
| logger.warning("[%s] Realtime model path failed: %s; using conservative fallback", request_id, exc) | |
| analysis_result = AnalysisResult( | |
| classification="HUMAN", | |
| confidence_score=0.5, | |
| explanation="Realtime model path unavailable; conservative fallback applied.", | |
| features={ | |
| "ml_fallback": 1.0, | |
| "authenticity_score": 50.0, | |
| "pitch_naturalness": 50.0, | |
| "spectral_naturalness": 50.0, | |
| "temporal_naturalness": 50.0, | |
| "acoustic_anomaly_score": 50.0, | |
| }, | |
| ) | |
| analyze_ms = (time.perf_counter() - analyze_start) * 1000 | |
| logger.info( | |
| f"[{request_id}] Realtime result: {analysis_result.classification} " | |
| f"({analysis_result.confidence_score:.0%}) in {analyze_ms:.0f}ms" | |
| ) | |
| # Short-chunk guard: sub-2s segments are unreliable; carry forward | |
| # the session's majority classification instead. | |
| MIN_RELIABLE_DURATION = 2.0 | |
| if duration_sec < MIN_RELIABLE_DURATION: | |
| async with SESSION_LOCK: | |
| _sess = get_session_state(session_id) | |
| if _sess is not None and _sess.chunks_processed > 0: | |
| # Case 1: Session is mostly AI → carry forward AI | |
| if ( | |
| _sess.voice_ai_chunks > _sess.voice_human_chunks | |
| and _sess.max_voice_ai_confidence >= 0.85 | |
| and analysis_result.classification != "AI_GENERATED" | |
| ): | |
| logger.info( | |
| f"[{request_id}] Short chunk ({duration_sec:.2f}s < {MIN_RELIABLE_DURATION}s): " | |
| f"overriding {analysis_result.classification} → AI_GENERATED " | |
| f"(session has {_sess.voice_ai_chunks} AI vs {_sess.voice_human_chunks} human chunks)" | |
| ) | |
| analysis_result = AnalysisResult( | |
| classification="AI_GENERATED", | |
| confidence_score=_sess.max_voice_ai_confidence * 0.90, | |
| explanation=f"Short chunk ({duration_sec:.1f}s) — AI classification carried forward from session history.", | |
| features=analysis_result.features, | |
| ) | |
| # Case 2: Session is mostly HUMAN → carry forward HUMAN | |
| elif ( | |
| _sess.voice_human_chunks > _sess.voice_ai_chunks | |
| and _sess.final_voice_confidence >= 0.80 | |
| and analysis_result.classification != "HUMAN" | |
| ): | |
| logger.info( | |
| f"[{request_id}] Short chunk ({duration_sec:.2f}s < {MIN_RELIABLE_DURATION}s): " | |
| f"overriding {analysis_result.classification} → HUMAN " | |
| f"(session has {_sess.voice_human_chunks} human vs {_sess.voice_ai_chunks} AI chunks)" | |
| ) | |
| analysis_result = AnalysisResult( | |
| classification="HUMAN", | |
| confidence_score=_sess.final_voice_confidence * 0.90, | |
| explanation=f"Short chunk ({duration_sec:.1f}s) — HUMAN classification carried forward from session history.", | |
| features=analysis_result.features, | |
| ) | |
| asr_start = time.perf_counter() | |
| # Format-aware timeout: container formats (mp4, m4a, ogg, webm) need extra | |
| # demux time vs raw audio formats (mp3, wav, flac). | |
| _container_formats = {"mp4", "m4a", "ogg", "webm"} | |
| _base_timeout_ms = float(settings.ASR_TIMEOUT_MS) | |
| if chunk_request.audioFormat.lower() in _container_formats: | |
| asr_timeout_seconds = max(0.1, (_base_timeout_ms * 2.5) / 1000.0) | |
| else: | |
| asr_timeout_seconds = max(0.1, _base_timeout_ms / 1000.0) | |
| asr_result = await transcribe_audio_guarded( | |
| audio=audio, | |
| sr=sr, | |
| language=chunk_language, | |
| timeout_seconds=asr_timeout_seconds, | |
| request_id=request_id, | |
| ) | |
| asr_ms = (time.perf_counter() - asr_start) * 1000 | |
| raw_transcript = str(asr_result.get("transcript", "")) | |
| response_transcript = ( | |
| mask_sensitive_entities(raw_transcript) | |
| if settings.MASK_TRANSCRIPT_OUTPUT | |
| else raw_transcript | |
| ) | |
| language_result = analyze_transcript(raw_transcript, chunk_language) | |
| language_result["transcript_raw"] = raw_transcript | |
| language_result["transcript"] = response_transcript | |
| language_result["transcript_confidence"] = asr_result.get("confidence", 0.0) | |
| language_result["asr_engine"] = asr_result.get("engine", "unavailable") | |
| transcript_preview = sanitize_for_logging(raw_transcript, max_chars=90) | |
| logger.info( | |
| f"[{request_id}] Realtime ASR: engine={language_result['asr_engine']}, " | |
| f"confidence={language_result['transcript_confidence']:.2f}, " | |
| f"text_len={len(raw_transcript)}, preview='{transcript_preview}', asr={asr_ms:.0f}ms" | |
| ) | |
| # Read-only session snapshot for scoring and optional LLM gating. | |
| async with SESSION_LOCK: | |
| session = get_session_state(session_id) | |
| if session is None: | |
| raise HTTPException( | |
| status_code=404, | |
| detail={"status": "error", "message": "Session not found or expired"} | |
| ) | |
| if session.status != "active": | |
| raise HTTPException( | |
| status_code=409, | |
| detail={"status": "error", "message": "Session is not active. Start a new session to continue."} | |
| ) | |
| previous_score_snapshot = session.risk_history[-1] if session.risk_history else None | |
| next_chunk_index = session.chunks_processed + 1 | |
| # Snapshot session AI context under the lock for provisional scoring | |
| _provisional_ai_ctx = { | |
| "voice_ai_chunks": session.voice_ai_chunks, | |
| "max_ai_confidence": session.max_voice_ai_confidence, | |
| "chunks_processed": session.chunks_processed, | |
| "risk_history": list(session.risk_history), | |
| } | |
| provisional_scored = build_risk_update( | |
| analysis_result.features or {}, | |
| analysis_result.classification, | |
| analysis_result.confidence_score, | |
| language_result, | |
| previous_score_snapshot, | |
| session_ai_context=_provisional_ai_ctx, | |
| ) | |
| llm_semantic: Optional[Dict[str, Any]] = None | |
| llm_invoked = should_invoke_llm_semantic( | |
| provisional_scored=provisional_scored, | |
| transcript=raw_transcript, | |
| transcript_confidence=float(language_result.get("transcript_confidence", 0.0)), | |
| next_chunk_index=next_chunk_index, | |
| ) | |
| if llm_invoked: | |
| llm_semantic = await asyncio.to_thread( | |
| analyze_semantic_with_llm, | |
| raw_transcript, | |
| chunk_language, | |
| settings.LLM_SEMANTIC_TIMEOUT_MS, | |
| ) | |
| async with SESSION_LOCK: | |
| session = get_session_state(session_id) | |
| if session is None: | |
| raise HTTPException( | |
| status_code=404, | |
| detail={"status": "error", "message": "Session not found or expired"} | |
| ) | |
| if session.status != "active": | |
| raise HTTPException( | |
| status_code=409, | |
| detail={"status": "error", "message": "Session is not active. Start a new session to continue."} | |
| ) | |
| if llm_invoked: | |
| session.llm_checks_performed += 1 | |
| if llm_semantic and llm_semantic.get("available"): | |
| session.llm_last_engine = str(llm_semantic.get("engine", "openai-chat-completions")) | |
| else: | |
| reason = str((llm_semantic or {}).get("reason", "unavailable")) | |
| session.llm_last_engine = f"skipped:{reason}" | |
| behaviour_snapshot = update_session_behaviour_state(session, language_result) | |
| language_result.update(behaviour_snapshot) | |
| previous_score = session.risk_history[-1] if session.risk_history else None | |
| scored = build_risk_update( | |
| analysis_result.features or {}, | |
| analysis_result.classification, | |
| analysis_result.confidence_score, | |
| language_result, | |
| previous_score, | |
| llm_semantic=llm_semantic, | |
| session_ai_context={ | |
| "voice_ai_chunks": session.voice_ai_chunks, | |
| "max_ai_confidence": session.max_voice_ai_confidence, | |
| "chunks_processed": session.chunks_processed, | |
| "risk_history": list(session.risk_history), | |
| }, | |
| ) | |
| voice_classification = normalize_voice_classification( | |
| analysis_result.classification, | |
| scored["model_uncertain"], | |
| ) | |
| voice_confidence = float(max(0.0, min(1.0, analysis_result.confidence_score))) | |
| session.chunks_processed += 1 | |
| session.last_update = utc_now_iso() | |
| session.risk_history.append(scored["risk_score"]) | |
| if scored["risk_score"] >= session.max_risk_score: | |
| session.final_call_label = scored["call_label"] | |
| session.max_risk_score = max(session.max_risk_score, scored["risk_score"]) | |
| session.max_cpi = max(session.max_cpi, float(scored["cpi"])) | |
| if voice_classification == "AI_GENERATED": | |
| session.voice_ai_chunks += 1 | |
| session.max_voice_ai_confidence = max(session.max_voice_ai_confidence, voice_confidence) | |
| elif voice_classification == "HUMAN": | |
| session.voice_human_chunks += 1 | |
| session.max_voice_human_confidence = max(session.max_voice_human_confidence, voice_confidence) | |
| # Use majority vote for session-level classification instead of last-chunk | |
| if session.voice_ai_chunks > session.voice_human_chunks: | |
| session.final_voice_classification = "AI_GENERATED" | |
| session.final_voice_confidence = session.max_voice_ai_confidence | |
| elif session.voice_human_chunks > session.voice_ai_chunks: | |
| session.final_voice_classification = "HUMAN" | |
| session.final_voice_confidence = session.max_voice_human_confidence | |
| else: | |
| # Tie — use the latest chunk's decision | |
| session.final_voice_classification = voice_classification | |
| session.final_voice_confidence = voice_confidence | |
| # Reconcile final_call_label with majority vote. | |
| if session.final_voice_classification == "HUMAN" and session.final_call_label == "FRAUD": | |
| avg_risk = sum(session.risk_history) / max(1, len(session.risk_history)) | |
| session.final_call_label = "SPAM" if avg_risk >= 30 else "SAFE" | |
| # If majority says AI_GENERATED but label is SAFE, upgrade to | |
| # at least SPAM so the label reflects the AI detection. | |
| elif session.final_voice_classification == "AI_GENERATED" and session.final_call_label == "SAFE": | |
| session.final_call_label = "SPAM" | |
| # Average risk sanity check: downgrade FRAUD when most chunks are LOW. | |
| if session.final_call_label == "FRAUD" and session.chunks_processed >= 5: | |
| avg_risk = sum(session.risk_history) / max(1, len(session.risk_history)) | |
| if avg_risk < 35: | |
| session.final_call_label = "SPAM" | |
| logger.info( | |
| "Sanity: downgraded FRAUD -> SPAM (avg_risk=%.1f, chunks=%d)", | |
| avg_risk, session.chunks_processed, | |
| ) | |
| if scored["alert"].triggered: | |
| alert_obj = scored["alert"] | |
| alert_entry = { | |
| "timestamp": session.last_update, | |
| "risk_score": scored["risk_score"], | |
| "risk_level": scored["risk_level"], | |
| "call_label": scored["call_label"], | |
| "alert_type": alert_obj.alert_type or "FRAUD_RISK_HIGH", | |
| "severity": alert_obj.severity or scored["risk_level"].lower(), | |
| "reason_summary": alert_obj.reason_summary or "Fraud indicators detected.", | |
| "recommended_action": alert_obj.recommended_action | |
| or recommendation_for_level(scored["risk_level"], scored["model_uncertain"]), | |
| "occurrences": 1, | |
| } | |
| last_alert = session.alert_history[-1] if session.alert_history else None | |
| duplicate_keys = ("alert_type", "severity", "reason_summary", "recommended_action", "call_label", "risk_level") | |
| is_duplicate = bool( | |
| last_alert | |
| and all(last_alert.get(key) == alert_entry.get(key) for key in duplicate_keys) | |
| ) | |
| if is_duplicate: | |
| last_alert["timestamp"] = session.last_update | |
| last_alert["risk_score"] = max(int(last_alert.get("risk_score", 0)), scored["risk_score"]) | |
| last_alert["occurrences"] = last_alert.get("occurrences", 1) + 1 | |
| else: | |
| session.alerts_triggered += 1 | |
| session.alert_history.append(alert_entry) | |
| if len(session.alert_history) > 100: | |
| session.alert_history = session.alert_history[-100:] | |
| save_session_state(session) | |
| return RealTimeUpdateResponse( | |
| status="success", | |
| session_id=session_id, | |
| timestamp=session.last_update, | |
| risk_score=scored["risk_score"], | |
| cpi=scored["cpi"], | |
| risk_level=scored["risk_level"], | |
| call_label=scored["call_label"], | |
| model_uncertain=scored["model_uncertain"], | |
| voice_classification=voice_classification, | |
| voice_confidence=voice_confidence, | |
| evidence=scored["evidence"], | |
| language_analysis=scored["language_analysis"], | |
| alert=scored["alert"], | |
| explainability=scored["explainability"], | |
| chunks_processed=session.chunks_processed, | |
| risk_policy_version=settings.RISK_POLICY_VERSION, | |
| ) | |
def session_to_summary(session: SessionState) -> SessionSummaryResponse:
    """Build a SessionSummaryResponse snapshot from the current session state.

    Risk level/label are derived from the peak risk score seen so far;
    floats are rounded for presentation only.
    """
    peak_level = map_score_to_level(session.max_risk_score)
    peak_label = map_level_to_label(peak_level, model_uncertain=False)
    payload = dict(
        status="success",
        session_id=session.session_id,
        language=session.language,
        session_status=session.status,
        started_at=session.started_at,
        last_update=session.last_update,
        chunks_processed=session.chunks_processed,
        alerts_triggered=session.alerts_triggered,
        max_risk_score=session.max_risk_score,
        max_cpi=round(session.max_cpi, 1),
        risk_level=peak_level,
        risk_label=peak_label,
        final_call_label=session.final_call_label,
        final_voice_classification=session.final_voice_classification,
        final_voice_confidence=round(session.final_voice_confidence, 2),
        max_voice_ai_confidence=round(session.max_voice_ai_confidence, 2),
        voice_ai_chunks=session.voice_ai_chunks,
        voice_human_chunks=session.voice_human_chunks,
        llm_checks_performed=session.llm_checks_performed,
        risk_policy_version=settings.RISK_POLICY_VERSION,
        alert_history=list(session.alert_history),
    )
    return SessionSummaryResponse(**payload)
# Authentication
# auto_error=False: FastAPI will not raise its own 403 when the header is
# absent, letting verify_api_key return a clearer, custom 401 payload instead.
api_key_header = APIKeyHeader(name="x-api-key", auto_error=False)  # Changed to False for better error messages
async def verify_api_key(x_api_key: str = Security(api_key_header)) -> str:
    """Dependency to verify API key. Raises 401 if invalid or missing.

    Uses hmac.compare_digest for a constant-time comparison (avoids timing
    side channels). Also guards against a missing server-side key
    (settings.API_KEY is None): previously that raised TypeError inside
    compare_digest and surfaced as a 500; the websocket auth path already
    checked for None, so this makes the two paths consistent.
    """
    if x_api_key is None:
        logger.warning("API request without x-api-key header")
        raise HTTPException(
            status_code=401,
            detail={"status": "error", "message": "Missing API key. Include 'x-api-key' header."}
        )
    expected = settings.API_KEY
    if expected is None or not hmac.compare_digest(x_api_key, expected):
        # Log only the key suffix to avoid leaking credentials into logs.
        logger.warning("API request with invalid key: ***%s", x_api_key[-4:] if len(x_api_key) >= 4 else "****")
        raise HTTPException(
            status_code=401,
            detail={"status": "error", "message": "Invalid API key"}
        )
    return x_api_key
def verify_websocket_api_key(websocket: WebSocket) -> bool:
    """Validate API key for websocket connections.

    Accepts the key from the 'x-api-key' header or the 'api_key' query
    parameter. Returns False (never raises) when either the client key or
    the server-side key is missing — previously a None settings.API_KEY
    caused hmac.compare_digest to raise TypeError.
    """
    key = websocket.headers.get("x-api-key") or websocket.query_params.get("api_key")
    expected = settings.API_KEY
    if key is None or expected is None:
        return False
    return hmac.compare_digest(key, expected)
# Routes
async def root():
    """Send visitors at '/' to the interactive API documentation."""
    return RedirectResponse(url="/docs")
async def health_check():
    """Liveness probe: reports whether the ML model loaded successfully.

    'degraded' (not an error status) is returned when the model is absent
    so monitors can distinguish partial availability from a crash.
    """
    model_loaded = False
    try:
        from model import _model
        model_loaded = _model is not None
    except Exception:
        # Import failure means the model module itself is broken/unavailable.
        model_loaded = False
    overall = "healthy" if model_loaded else "degraded"
    return {
        "status": overall,
        "model_loaded": model_loaded,
        "session_store_backend": SESSION_STORE_BACKEND_ACTIVE,
    }
async def start_realtime_session(
    session_request: SessionStartRequest,
    api_key: str = Depends(verify_api_key)
):
    """Create a new real-time fraud analysis session.

    Validates the requested language, mints a UUID session id, and
    persists the fresh session state under the session lock.
    """
    validate_supported_language(session_request.language)
    new_session_id = str(uuid.uuid4())
    created_at = utc_now_iso()
    async with SESSION_LOCK:
        save_session_state(
            SessionState(
                session_id=new_session_id,
                language=session_request.language,
                started_at=created_at,
            )
        )
    return SessionStartResponse(
        status="success",
        session_id=new_session_id,
        language=session_request.language,
        started_at=created_at,
        message="Session created. Send chunks using /api/voice-detection/v1/session/{session_id}/chunk or websocket stream."
    )
async def analyze_realtime_chunk(
    request: Request,
    session_id: str,
    chunk_request: SessionChunkRequest,
    api_key: str = Depends(verify_api_key)
):
    """Analyze one chunk for an active real-time session.

    Session existence/liveness is checked under the lock; the (slow)
    chunk processing itself runs outside it.
    """
    request_id = getattr(request.state, "request_id", f"sess-{session_id[:8]}")
    async with SESSION_LOCK:
        state = get_session_state(session_id)
        if state is None:
            raise HTTPException(
                status_code=404,
                detail={"status": "error", "message": "Session not found or expired"},
            )
        if state.status != "active":
            raise HTTPException(
                status_code=409,
                detail={"status": "error", "message": "Session is not active. Start a new session to continue."},
            )
        chunk_language = state.language
    try:
        return await process_audio_chunk(session_id, chunk_request, chunk_language, request_id)
    except ValueError as e:
        raise HTTPException(status_code=400, detail={"status": "error", "message": str(e)}) from e
async def stream_realtime_session(websocket: WebSocket, session_id: str):
    """WebSocket endpoint for continuous chunk-based analysis.

    Authentication: either a header/query-param API key (checked before the
    handshake is accepted) or, failing that, a first-message
    {"type": "auth", "api_key": ...} payload within 10 seconds of accept.

    Per-connection limits enforced server-side:
    - settings.WS_MAX_DURATION_SECONDS: total connection lifetime
    - settings.WS_IDLE_TIMEOUT_SECONDS: max gap between client messages
    """
    has_query_key = verify_websocket_api_key(websocket)
    async with SESSION_LOCK:
        session = get_session_state(session_id)
        if session is None:
            # Closing before accept() rejects the handshake outright.
            await websocket.close(code=1008, reason="Session not found or expired")
            return
        if session.status != "active":
            await websocket.close(code=1008, reason="Session is not active")
            return
        session_language = session.language
    await websocket.accept()
    request_id = f"ws-{session_id[:8]}"
    ws_start = time.time()
    # Fall back to first-message authentication
    if not has_query_key:
        try:
            auth_msg = await asyncio.wait_for(websocket.receive_json(), timeout=10.0)
        # Broad catch is intentional: timeout, disconnect, or malformed JSON
        # all end the unauthenticated connection. (The original
        # `except (asyncio.TimeoutError, Exception)` was redundant — Exception
        # already subsumes TimeoutError.)
        except Exception:
            await websocket.close(code=1008, reason="Auth timeout")
            return
        if auth_msg.get("type") != "auth" or not auth_msg.get("api_key"):
            await websocket.close(code=1008, reason="Invalid auth message")
            return
        expected = settings.API_KEY
        provided = str(auth_msg["api_key"])
        if expected is None or not hmac.compare_digest(expected, provided):
            await websocket.close(code=1008, reason="Invalid API key")
            return
    try:
        while True:
            # Enforce max connection duration
            elapsed = time.time() - ws_start
            if elapsed >= settings.WS_MAX_DURATION_SECONDS:
                await websocket.send_json({
                    "status": "error",
                    "message": f"WebSocket max duration ({settings.WS_MAX_DURATION_SECONDS}s) exceeded"
                })
                await websocket.close(code=1000, reason="Max duration exceeded")
                break
            # Enforce idle timeout
            try:
                payload = await asyncio.wait_for(
                    websocket.receive_json(),
                    timeout=settings.WS_IDLE_TIMEOUT_SECONDS
                )
            except asyncio.TimeoutError:
                await websocket.send_json({
                    "status": "error",
                    "message": f"Idle timeout ({settings.WS_IDLE_TIMEOUT_SECONDS}s) — no data received"
                })
                await websocket.close(code=1000, reason="Idle timeout")
                break
            try:
                chunk_request = SessionChunkRequest.model_validate(payload)
            except ValidationError as e:
                await websocket.send_json({
                    "status": "error",
                    "message": "Invalid chunk payload",
                    "details": e.errors()
                })
                continue
            try:
                update = await process_audio_chunk(session_id, chunk_request, session_language, request_id)
                await websocket.send_json(update.model_dump())
            except HTTPException as e:
                detail = e.detail if isinstance(e.detail, dict) else {"status": "error", "message": str(e.detail)}
                await websocket.send_json(detail)
            except ValueError as e:
                await websocket.send_json({"status": "error", "message": str(e)})
    except WebSocketDisconnect:
        logger.info("[%s] WebSocket disconnected", request_id)
async def get_session_summary(
    session_id: str,
    api_key: str = Depends(verify_api_key)
):
    """Fetch the live summary for an existing real-time session."""
    async with SESSION_LOCK:
        state = get_session_state(session_id)
        if state is None:
            raise HTTPException(
                status_code=404,
                detail={"status": "error", "message": "Session not found or expired"}
            )
        return session_to_summary(state)
async def get_session_alerts(
    session_id: str,
    limit: int = 20,
    api_key: str = Depends(verify_api_key),
):
    """Return up to `limit` most-recent alerts for a real-time session."""
    if not 1 <= limit <= 100:
        raise HTTPException(
            status_code=400,
            detail={"status": "error", "message": "limit must be between 1 and 100"},
        )
    async with SESSION_LOCK:
        state = get_session_state(session_id)
        if state is None:
            raise HTTPException(
                status_code=404,
                detail={"status": "error", "message": "Session not found or expired"},
            )
        recent = [AlertHistoryItem(**entry) for entry in state.alert_history[-limit:]]
        return AlertHistoryResponse(
            status="success",
            session_id=session_id,
            total_alerts=len(state.alert_history),
            alerts=recent,
        )
async def get_retention_policy(api_key: str = Depends(verify_api_key)):
    """Expose the service's privacy defaults: raw audio is never persisted,
    and session-derived data expires per the configured retention windows."""
    policy = RetentionPolicyResponse(
        status="success",
        raw_audio_storage="not_persisted",
        active_session_retention_seconds=settings.SESSION_ACTIVE_RETENTION_SECONDS,
        ended_session_retention_seconds=settings.SESSION_ENDED_RETENTION_SECONDS,
        stored_derived_fields=STORED_DERIVED_FIELDS,
    )
    return policy
async def end_realtime_session(
    session_id: str,
    api_key: str = Depends(verify_api_key)
):
    """Flag the session as ended, persist it, and return the final summary."""
    async with SESSION_LOCK:
        state = get_session_state(session_id)
        if state is None:
            raise HTTPException(
                status_code=404,
                detail={"status": "error", "message": "Session not found or expired"}
            )
        state.status = "ended"
        state.last_update = utc_now_iso()
        save_session_state(state)
        return session_to_summary(state)
# Rate limit: 1000 requests per minute per IP
async def detect_voice(
    request: Request,  # Required for rate limiter
    voice_request: VoiceDetectionRequest,
    api_key: str = Depends(verify_api_key)  # Use dependency injection
):
    """
    Classify a Base64-encoded audio clip as AI_GENERATED or HUMAN.

    Returns classification result with confidence score and explanation.
    On timeout the endpoint degrades gracefully: it returns a cautionary
    HUMAN classification with modelUncertain=True instead of failing.
    """
    request_id = getattr(request.state, 'request_id', 'unknown')
    # Estimate decoded payload size: Base64 inflates data by 4/3.
    audio_size_kb = len(voice_request.audioBase64) * 3 / 4 / 1024
    logger.info("[%s] Voice detection: lang=%s, fmt=%s, size~%.1fKB",
                request_id, voice_request.language, voice_request.audioFormat, audio_size_kb)
    voice_request.language = validate_supported_language(voice_request.language)
    validate_supported_format(voice_request.audioFormat)
    LEGACY_TIMEOUT_SECONDS = 20  # total request budget (decode + analysis)
    MAX_AUDIO_SECONDS = 20       # hard cap on analyzed audio length
    try:
        decode_start = time.perf_counter()
        # Decode and load off the event loop — both steps are CPU-bound.
        audio_bytes = await asyncio.to_thread(decode_base64_audio, voice_request.audioBase64)
        audio, sr = await asyncio.to_thread(load_audio_from_bytes, audio_bytes, 16000, voice_request.audioFormat)
        max_samples = sr * MAX_AUDIO_SECONDS
        if len(audio) > max_samples:
            logger.warning("[%s] Truncating audio from %.1fs to 20s", request_id, len(audio) / sr)
            audio = audio[:max_samples]
        # Spend only what remains of the request budget on model analysis.
        remaining_budget = LEGACY_TIMEOUT_SECONDS - (time.perf_counter() - decode_start)
        if remaining_budget < 2:
            raise asyncio.TimeoutError("Insufficient time budget for analysis")
        result = await analyze_voice_guarded(
            audio, sr, max(2.0, remaining_budget), request_id, voice_request.language
        )
        analyze_time = (time.perf_counter() - decode_start) * 1000
        logger.info("[%s] Analysis complete: %s (%.0f%%) in %.0fms",
                    request_id, result.classification, result.confidence_score * 100, analyze_time)
        metrics = None
        if result.features:
            metrics = ForensicMetrics(
                authenticity_score=result.features.get("authenticity_score", 0),
                pitch_naturalness=result.features.get("pitch_naturalness", 0),
                spectral_naturalness=result.features.get("spectral_naturalness", 0),
                temporal_naturalness=result.features.get("temporal_naturalness", 0)
            )
        # 'ml_fallback' signals the fallback inference path ran instead of the model.
        model_uncertain = bool((result.features or {}).get("ml_fallback", 0.0))
        explanation = result.explanation
        recommended_action = None
        response_classification = result.classification
        if model_uncertain:
            # Fix: original text read "due fallback" (missing "to").
            explanation = (
                "Model uncertainty detected due to fallback inference. "
                "Treat result as cautionary and verify through trusted channels. "
                f"{result.explanation}"
            )
            recommended_action = (
                "Do not share OTP, PIN, passwords, or payment credentials. "
                "Verify caller identity through official support channels."
            )
        elif response_classification == "AI_GENERATED":
            recommended_action = (
                "AI-generated voice detected. Do not share OTP, PIN, or payment "
                "credentials. Verify caller identity through official channels."
            )
        return VoiceDetectionResponse(
            status="success",
            language=voice_request.language,
            classification=response_classification,
            confidenceScore=result.confidence_score,
            explanation=explanation,
            forensic_metrics=metrics,
            modelUncertain=model_uncertain,
            recommendedAction=recommended_action,
        )
    except ValueError as e:
        logger.warning("[%s] Validation error: %s", request_id, e)
        raise HTTPException(
            status_code=400,
            detail={"status": "error", "message": str(e)}
        )
    except asyncio.TimeoutError:
        logger.warning("[%s] Legacy endpoint exceeded %ds budget", request_id, LEGACY_TIMEOUT_SECONDS)
        return VoiceDetectionResponse(
            status="success",
            language=voice_request.language,
            classification="HUMAN",
            confidenceScore=0.50,
            explanation="Analysis timed out. Returning cautionary HUMAN classification.",
            forensic_metrics=None,
            modelUncertain=True,
            recommendedAction="Analysis took too long. Verify caller identity through official channels.",
        )
    except Exception as e:
        logger.error("[%s] Processing error: %s", request_id, e, exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={"status": "error", "message": "Internal Server Error"}
        )
# Exception handlers
def to_json_safe(value: Any) -> Any:
    """Recursively convert values to JSON-safe primitives.

    Scalars (None/str/int/float/bool) pass through unchanged, exceptions
    become their message string, mappings are rebuilt with stringified
    keys, sequences/sets become lists, and everything else falls back to
    its str() representation.
    """
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, BaseException):
        return str(value)
    if isinstance(value, dict):
        converted = {}
        for key, item in value.items():
            converted[str(key)] = to_json_safe(item)
        return converted
    if isinstance(value, (list, tuple, set)):
        return [to_json_safe(element) for element in value]
    return str(value)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """
    Custom handler for 422 Validation Errors.
    Provides clearer error messages for common issues.
    """
    errors = to_json_safe(exc.errors())
    logger.warning("Validation error: %s", errors)
    # Build user-friendly error message
    error_messages = []
    for err in errors:
        location = " -> ".join(str(part) for part in err.get("loc", []))
        message = err.get("msg", "Invalid value")
        error_messages.append(f"{location}: {message}")
    # Common issue detection
    hint = ""
    if any("audioBase64" in str(err.get("loc", [])) for err in errors):
        hint = " Hint: Ensure 'audioBase64' is a valid Base64-encoded string."
    elif any("language" in str(err.get("loc", [])) for err in errors):
        hint = f" Hint: 'language' must be one of: {', '.join(settings.SUPPORTED_LANGUAGES)}."
    return JSONResponse(
        status_code=422,
        content={
            "status": "error",
            "message": f"Request validation failed: {'; '.join(error_messages)}.{hint}",
            "details": errors
        }
    )
async def http_exception_handler(request: Request, exc: HTTPException):
    """Custom exception handler to ensure consistent error format.

    Dict details (the app's own structured payloads) are passed through
    verbatim; anything else is wrapped in the standard error envelope.
    """
    detail = exc.detail
    if isinstance(detail, dict):
        body = detail
    else:
        body = {"status": "error", "message": str(detail)}
    return JSONResponse(status_code=exc.status_code, content=body)
async def global_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: log the full traceback server-side and return a
    generic 500 so stack traces never leak to clients."""
    logger.error("Unhandled error: %s", exc, exc_info=True)
    body = {"status": "error", "message": "Internal Server Error"}
    return JSONResponse(status_code=500, content=body)
if __name__ == "__main__":
    # Local/dev entry point; production deployments run uvicorn externally.
    import uvicorn
    # Bind on all interfaces; port comes from config (settings.PORT).
    uvicorn.run(app, host="0.0.0.0", port=settings.PORT)