""" WebSocket streaming for real-time speech pathology diagnosis. This module provides WebSocket endpoint for streaming audio analysis with <50ms latency per frame requirement. """ import logging import time import uuid import numpy as np from typing import Optional, Dict from collections import deque from datetime import datetime from fastapi import WebSocket, WebSocketDisconnect, HTTPException from api.schemas import StreamingDiagnosisResponse, FluencyInfo, ArticulationInfo, ErrorDetailSchema from models.phoneme_mapper import PhonemeMapper from models.error_taxonomy import ErrorMapper, ErrorType from inference.inference_pipeline import InferencePipeline from config import AudioConfig, default_audio_config logger = logging.getLogger(__name__) class StreamingBuffer: """ Buffer for managing sliding window in streaming audio. Maintains a buffer of audio samples and provides frames for processing with overlap management. """ def __init__(self, window_size_samples: int, hop_size_samples: int): """ Initialize streaming buffer. Args: window_size_samples: Size of analysis window in samples hop_size_samples: Hop size between windows in samples """ self.window_size_samples = window_size_samples self.hop_size_samples = hop_size_samples self.buffer = deque(maxlen=window_size_samples + hop_size_samples) self.frame_index = 0 logger.debug(f"StreamingBuffer initialized: window={window_size_samples}, hop={hop_size_samples}") def add_chunk(self, audio_chunk: np.ndarray) -> bool: """ Add audio chunk to buffer. Args: audio_chunk: Audio samples to add Returns: True if buffer has enough data for a frame, False otherwise """ self.buffer.extend(audio_chunk) return len(self.buffer) >= self.window_size_samples def get_frame(self) -> Optional[np.ndarray]: """ Get current frame from buffer. Returns: Audio frame array if ready, None otherwise """ if len(self.buffer) < self.window_size_samples: return None # Extract window (last window_size_samples) frame = np.array(list(self.buffer)[-self.window_size_samples:]) return frame def slide(self): """Advance buffer by hop size.""" # Remove oldest hop_size_samples for _ in range(min(self.hop_size_samples, len(self.buffer))): if self.buffer: self.buffer.popleft() self.frame_index += 1 # Global instances (will be injected) inference_pipeline: Optional[InferencePipeline] = None phoneme_mapper: Optional[PhonemeMapper] = None error_mapper: Optional[ErrorMapper] = None # Active streaming sessions streaming_sessions: Dict[str, Dict] = {} def initialize_streaming( pipeline: InferencePipeline, mapper: Optional[PhonemeMapper] = None, error_mapper_instance: Optional[ErrorMapper] = None ): """ Initialize streaming with dependencies. Args: pipeline: InferencePipeline instance mapper: Optional PhonemeMapper instance error_mapper_instance: Optional ErrorMapper instance """ global inference_pipeline, phoneme_mapper, error_mapper inference_pipeline = pipeline if mapper is None: try: phoneme_mapper = PhonemeMapper( frame_duration_ms=default_audio_config.chunk_duration_ms, sample_rate=default_audio_config.sample_rate ) logger.info("✅ PhonemeMapper initialized for streaming") except Exception as e: logger.warning(f"⚠️ PhonemeMapper not available: {e}") phoneme_mapper = None if error_mapper_instance is None: try: error_mapper = ErrorMapper() logger.info("✅ ErrorMapper initialized for streaming") except Exception as e: logger.error(f"❌ ErrorMapper failed to initialize: {e}") error_mapper = None async def handle_streaming_websocket(websocket: WebSocket, session_id: Optional[str] = None): """ Handle WebSocket connection for streaming diagnosis. Args: websocket: WebSocket connection session_id: Optional session ID (auto-generated if not provided) """ if inference_pipeline is None: await websocket.close(code=1003, reason="Inference pipeline not loaded") return # Generate session ID if not session_id: session_id = str(uuid.uuid4()) # Accept connection await websocket.accept() logger.info(f"🔌 WebSocket connected: session_id={session_id}") # Initialize buffer window_size_samples = int( inference_pipeline.inference_config.window_size_ms * inference_pipeline.audio_config.sample_rate / 1000 ) hop_size_samples = int( inference_pipeline.inference_config.hop_size_ms * inference_pipeline.audio_config.sample_rate / 1000 ) buffer = StreamingBuffer(window_size_samples, hop_size_samples) # Session metadata streaming_sessions[session_id] = { "session_id": session_id, "connected_at": datetime.now(), "frame_count": 0, "total_latency_ms": 0.0 } frame_index = 0 start_time = time.time() try: while True: # Receive audio chunk try: data = await websocket.receive_bytes() # Convert bytes to numpy array # Assuming 16-bit PCM, mono, 16kHz audio_chunk = np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0 # Add to buffer buffer.add_chunk(audio_chunk) # Process if buffer is ready if buffer.get_frame() is not None: frame_start_time = time.time() # Get frame frame = buffer.get_frame() # Run inference try: result = inference_pipeline.predict_phone_level( frame, return_timestamps=False ) if result.frame_predictions: frame_pred = result.frame_predictions[0] # Single frame result # Map to error detail class_id = frame_pred.articulation_class if frame_pred.fluency_label == 'stutter': class_id += 4 error_detail = None phoneme = '' # Streaming doesn't have text input if error_mapper: try: error_detail_obj = error_mapper.map_classifier_output( class_id=class_id, confidence=frame_pred.confidence, phoneme=phoneme, fluency_label=frame_pred.fluency_label ) if error_detail_obj.error_type != ErrorType.NORMAL: error_detail = ErrorDetailSchema( phoneme=error_detail_obj.phoneme, error_type=error_detail_obj.error_type.value, wrong_sound=error_detail_obj.wrong_sound, severity=error_detail_obj.severity, confidence=error_detail_obj.confidence, therapy=error_detail_obj.therapy, frame_indices=[frame_index] ) except Exception as e: logger.warning(f"Error mapping failed: {e}") # Calculate latency latency_ms = (time.time() - frame_start_time) * 1000 # Get severity level severity_level = "none" if error_detail and error_mapper: severity_level = error_mapper.get_severity_level(error_detail.severity).value # Create response response = StreamingDiagnosisResponse( session_id=session_id, frame_id=frame_index, timestamp=frame_index * (inference_pipeline.inference_config.hop_size_ms / 1000.0), phoneme=phoneme, fluency=FluencyInfo( label=frame_pred.fluency_label, confidence=frame_pred.fluency_prob if frame_pred.fluency_label == 'stutter' else (1.0 - frame_pred.fluency_prob) ), articulation=ArticulationInfo( label=frame_pred.articulation_label, confidence=frame_pred.confidence, class_id=frame_pred.articulation_class ), error=error_detail, severity_level=severity_level, confidence=frame_pred.confidence, latency_ms=latency_ms ) # Send response await websocket.send_json(response.model_dump()) # Update session stats streaming_sessions[session_id]["frame_count"] += 1 streaming_sessions[session_id]["total_latency_ms"] += latency_ms # Check latency requirement if latency_ms > 50.0: logger.warning(f"⚠️ Latency exceeded 50ms: {latency_ms:.1f}ms") # Slide buffer buffer.slide() frame_index += 1 except Exception as e: logger.error(f"❌ Inference failed: {e}", exc_info=True) await websocket.send_json({ "error": f"Inference failed: {str(e)}", "frame_id": frame_index }) except Exception as e: logger.error(f"❌ Error processing chunk: {e}", exc_info=True) await websocket.send_json({ "error": f"Processing failed: {str(e)}", "frame_id": frame_index }) except WebSocketDisconnect: logger.info(f"🔌 WebSocket disconnected: session_id={session_id}") except Exception as e: logger.error(f"❌ WebSocket error: {e}", exc_info=True) finally: # Cleanup session if session_id in streaming_sessions: session_data = streaming_sessions[session_id] avg_latency = session_data["total_latency_ms"] / session_data["frame_count"] if session_data["frame_count"] > 0 else 0.0 logger.info(f"📊 Session {session_id} stats: {session_data['frame_count']} frames, " f"avg_latency={avg_latency:.1f}ms") del streaming_sessions[session_id]