|
|
""" |
|
|
WebSocket streaming for real-time speech pathology diagnosis. |
|
|
|
|
|
This module provides a WebSocket endpoint for streaming audio analysis
with a per-frame latency requirement of under 50 ms.
|
|
""" |
|
|
|
|
|
import logging |
|
|
import time |
|
|
import uuid |
|
|
import numpy as np |
|
|
from typing import Optional, Dict |
|
|
from collections import deque |
|
|
from datetime import datetime |
|
|
|
|
|
from fastapi import WebSocket, WebSocketDisconnect, HTTPException |
|
|
|
|
|
from api.schemas import StreamingDiagnosisResponse, FluencyInfo, ArticulationInfo, ErrorDetailSchema |
|
|
from models.phoneme_mapper import PhonemeMapper |
|
|
from models.error_taxonomy import ErrorMapper, ErrorType |
|
|
from inference.inference_pipeline import InferencePipeline |
|
|
from config import AudioConfig, default_audio_config |
|
|
|
|
|
# Module-level logger, named after this module so log output can be filtered per module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class StreamingBuffer:
    """
    Bounded sample store that backs the sliding analysis window.

    Samples are appended on the right of a deque capped at
    ``window_size_samples + hop_size_samples`` entries, so the oldest
    samples fall off automatically once the cap is reached. A frame is
    always the most recent ``window_size_samples`` samples.
    """

    def __init__(self, window_size_samples: int, hop_size_samples: int):
        """
        Create an empty buffer.

        Args:
            window_size_samples: Length of one analysis frame, in samples
            hop_size_samples: Number of samples discarded by each ``slide()``
        """
        self.frame_index = 0
        self.hop_size_samples = hop_size_samples
        self.window_size_samples = window_size_samples

        capacity = window_size_samples + hop_size_samples
        self.buffer = deque(maxlen=capacity)

        logger.debug(f"StreamingBuffer initialized: window={window_size_samples}, hop={hop_size_samples}")

    def add_chunk(self, audio_chunk: np.ndarray) -> bool:
        """
        Append incoming samples to the buffer.

        Args:
            audio_chunk: Samples to append (iterated element by element)

        Returns:
            True once at least one full window of samples is buffered,
            False otherwise
        """
        self.buffer.extend(audio_chunk)
        has_full_window = not len(self.buffer) < self.window_size_samples
        return has_full_window

    def get_frame(self) -> Optional[np.ndarray]:
        """
        Return the newest full window without consuming it.

        Returns:
            Array holding the last ``window_size_samples`` buffered samples,
            or None while fewer than that many samples are available
        """
        samples = list(self.buffer)
        if len(samples) >= self.window_size_samples:
            return np.array(samples[-self.window_size_samples:])
        return None

    def slide(self):
        """Drop up to one hop of the oldest samples and count the frame."""
        remaining = min(self.hop_size_samples, len(self.buffer))
        while remaining > 0:
            self.buffer.popleft()
            remaining -= 1
        self.frame_index += 1
|
|
|
|
|
|
|
|
|
|
|
# Module-level dependencies shared by every streaming connection.
# All three start as None and are populated by initialize_streaming().
inference_pipeline: Optional[InferencePipeline] = None
phoneme_mapper: Optional[PhonemeMapper] = None
error_mapper: Optional[ErrorMapper] = None

# Registry of currently connected sessions, keyed by session ID.
# Each value is a per-session stats dict (session_id, connected_at,
# frame_count, total_latency_ms) maintained by handle_streaming_websocket().
streaming_sessions: Dict[str, Dict] = {}
|
|
|
|
|
|
|
|
def initialize_streaming(
    pipeline: InferencePipeline,
    mapper: Optional[PhonemeMapper] = None,
    error_mapper_instance: Optional[ErrorMapper] = None
):
    """
    Install the module-level dependencies used by the streaming handler.

    Explicitly supplied instances are stored as given. When an instance is
    omitted, a default one is constructed; if that construction fails the
    corresponding global is set to None and the failure is logged, so
    streaming can still start without the optional component.

    Args:
        pipeline: InferencePipeline instance used for every streamed frame
        mapper: Optional PhonemeMapper instance; a default one built from
            ``default_audio_config`` is used when omitted
        error_mapper_instance: Optional ErrorMapper instance; a default
            ``ErrorMapper()`` is used when omitted
    """
    global inference_pipeline, phoneme_mapper, error_mapper

    inference_pipeline = pipeline

    if mapper is not None:
        # Honour the caller's instance; previously it was silently discarded.
        phoneme_mapper = mapper
    else:
        try:
            phoneme_mapper = PhonemeMapper(
                frame_duration_ms=default_audio_config.chunk_duration_ms,
                sample_rate=default_audio_config.sample_rate
            )
            logger.info("PhonemeMapper initialized for streaming")
        except Exception as e:
            # Best effort: phoneme mapping is optional for streaming.
            logger.warning(f"PhonemeMapper not available: {e}")
            phoneme_mapper = None

    if error_mapper_instance is not None:
        # Honour the caller's instance; previously it was silently discarded.
        error_mapper = error_mapper_instance
    else:
        try:
            error_mapper = ErrorMapper()
            logger.info("ErrorMapper initialized for streaming")
        except Exception as e:
            # Best effort: without an ErrorMapper, frames are reported
            # without error details.
            logger.error(f"ErrorMapper failed to initialize: {e}")
            error_mapper = None
|
|
|
|
|
|
|
|
async def handle_streaming_websocket(websocket: WebSocket, session_id: Optional[str] = None):
    """
    Handle WebSocket connection for streaming diagnosis.

    The client sends binary messages containing 16-bit integer PCM samples.
    Each message is normalized to float32 by dividing by 32768 and appended
    to a sliding buffer; whenever a full analysis window is available, one
    frame is run through the inference pipeline and a
    ``StreamingDiagnosisResponse`` is sent back as JSON. Failures while
    processing a chunk are reported to the client as
    ``{"error": ..., "frame_id": ...}`` and the stream continues.

    If the inference pipeline has not been initialized, the connection is
    closed with code 1003 without being accepted.

    Args:
        websocket: WebSocket connection
        session_id: Optional session ID (auto-generated if not provided)
    """
    if inference_pipeline is None:
        await websocket.close(code=1003, reason="Inference pipeline not loaded")
        return

    # Bind the pipeline once so the window geometry computed below and the
    # model used per frame stay consistent even if the module-level pipeline
    # is re-initialized while this session is running.
    pipeline = inference_pipeline

    if not session_id:
        session_id = str(uuid.uuid4())

    await websocket.accept()
    logger.info(f"WebSocket connected: session_id={session_id}")

    sample_rate = pipeline.audio_config.sample_rate
    hop_size_ms = pipeline.inference_config.hop_size_ms
    window_size_samples = int(pipeline.inference_config.window_size_ms * sample_rate / 1000)
    hop_size_samples = int(hop_size_ms * sample_rate / 1000)

    buffer = StreamingBuffer(window_size_samples, hop_size_samples)

    # Keep a direct reference to this connection's stats. Going through
    # streaming_sessions[session_id] on every frame would break if another
    # connection reused the same session ID and replaced or removed the entry.
    session_stats = {
        "session_id": session_id,
        "connected_at": datetime.now(),
        "frame_count": 0,
        "total_latency_ms": 0.0
    }
    streaming_sessions[session_id] = session_stats

    frame_index = 0

    try:
        while True:
            try:
                data = await websocket.receive_bytes()

                # int16 PCM -> float32 in [-1.0, 1.0)
                audio_chunk = np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
                buffer.add_chunk(audio_chunk)

                # Start timing before frame extraction so the reported
                # latency covers buffer handling as well as inference.
                frame_start_time = time.time()
                frame = buffer.get_frame()
                if frame is None:
                    # Not enough audio buffered yet; wait for the next chunk.
                    continue

                try:
                    result = pipeline.predict_phone_level(
                        frame,
                        return_timestamps=False
                    )

                    if result.frame_predictions:
                        frame_pred = result.frame_predictions[0]

                        # NOTE(review): stutter frames are offset by 4 -
                        # presumably the combined class layout expected by
                        # ErrorMapper; confirm against the taxonomy.
                        class_id = frame_pred.articulation_class
                        if frame_pred.fluency_label == 'stutter':
                            class_id += 4

                        error_detail = None
                        # No phoneme is resolved in streaming mode; an empty
                        # string is passed through to the mapper and response.
                        phoneme = ''

                        if error_mapper:
                            try:
                                error_detail_obj = error_mapper.map_classifier_output(
                                    class_id=class_id,
                                    confidence=frame_pred.confidence,
                                    phoneme=phoneme,
                                    fluency_label=frame_pred.fluency_label
                                )

                                if error_detail_obj.error_type != ErrorType.NORMAL:
                                    error_detail = ErrorDetailSchema(
                                        phoneme=error_detail_obj.phoneme,
                                        error_type=error_detail_obj.error_type.value,
                                        wrong_sound=error_detail_obj.wrong_sound,
                                        severity=error_detail_obj.severity,
                                        confidence=error_detail_obj.confidence,
                                        therapy=error_detail_obj.therapy,
                                        frame_indices=[frame_index]
                                    )
                            except Exception as e:
                                # Best effort: a mapping failure must not drop
                                # the frame; it is reported without error details.
                                logger.warning(f"Error mapping failed: {e}")

                        latency_ms = (time.time() - frame_start_time) * 1000

                        severity_level = "none"
                        if error_detail and error_mapper:
                            severity_level = error_mapper.get_severity_level(error_detail.severity).value

                        # NOTE(review): fluency_prob looks like the stutter
                        # probability, so the complement is reported for
                        # non-stutter labels - confirm with the pipeline.
                        if frame_pred.fluency_label == 'stutter':
                            fluency_confidence = frame_pred.fluency_prob
                        else:
                            fluency_confidence = 1.0 - frame_pred.fluency_prob

                        response = StreamingDiagnosisResponse(
                            session_id=session_id,
                            frame_id=frame_index,
                            timestamp=frame_index * (hop_size_ms / 1000.0),
                            phoneme=phoneme,
                            fluency=FluencyInfo(
                                label=frame_pred.fluency_label,
                                confidence=fluency_confidence
                            ),
                            articulation=ArticulationInfo(
                                label=frame_pred.articulation_label,
                                confidence=frame_pred.confidence,
                                class_id=frame_pred.articulation_class
                            ),
                            error=error_detail,
                            severity_level=severity_level,
                            confidence=frame_pred.confidence,
                            latency_ms=latency_ms
                        )

                        await websocket.send_json(response.model_dump())

                        session_stats["frame_count"] += 1
                        session_stats["total_latency_ms"] += latency_ms

                        if latency_ms > 50.0:
                            logger.warning(f"Latency exceeded 50ms: {latency_ms:.1f}ms")

                    # Advance the window once inference has run for this frame.
                    buffer.slide()
                    frame_index += 1

                except WebSocketDisconnect:
                    # A disconnect is not an inference failure; let the outer
                    # handler end the session.
                    raise
                except Exception as e:
                    logger.error(f"Inference failed: {e}", exc_info=True)
                    await websocket.send_json({
                        "error": f"Inference failed: {str(e)}",
                        "frame_id": frame_index
                    })

            except WebSocketDisconnect:
                # Must be re-raised before the generic handler below:
                # WebSocketDisconnect is an Exception subclass, and swallowing
                # it here would log a spurious error, try to send on a closed
                # socket, and keep the receive loop running.
                raise
            except Exception as e:
                logger.error(f"Error processing chunk: {e}", exc_info=True)
                await websocket.send_json({
                    "error": f"Processing failed: {str(e)}",
                    "frame_id": frame_index
                })

    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: session_id={session_id}")
    except Exception as e:
        logger.error(f"WebSocket error: {e}", exc_info=True)
    finally:
        frame_count = session_stats["frame_count"]
        avg_latency = session_stats["total_latency_ms"] / frame_count if frame_count > 0 else 0.0
        logger.info(f"Session {session_id} stats: {frame_count} frames, "
                    f"avg_latency={avg_latency:.1f}ms")
        # Only unregister the entry if it still belongs to this connection.
        if streaming_sessions.get(session_id) is session_stats:
            del streaming_sessions[session_id]
|
|
|
|
|
|