File size: 12,411 Bytes
278e294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
"""
WebSocket streaming for real-time speech pathology diagnosis.

This module provides a WebSocket endpoint for streaming audio analysis,
with a per-frame latency target of under 50 ms.
"""

import logging
import time
import uuid
from collections import deque
from datetime import datetime
from itertools import islice
from typing import Optional, Dict

import numpy as np
from fastapi import WebSocket, WebSocketDisconnect, HTTPException

from api.schemas import StreamingDiagnosisResponse, FluencyInfo, ArticulationInfo, ErrorDetailSchema
from models.phoneme_mapper import PhonemeMapper
from models.error_taxonomy import ErrorMapper, ErrorType
from inference.inference_pipeline import InferencePipeline
from config import AudioConfig, default_audio_config

logger = logging.getLogger(__name__)


class StreamingBuffer:
    """
    Sliding-window buffer for streaming audio.

    Samples are appended with add_chunk(). get_frame() returns the OLDEST
    window_size_samples samples, and slide() discards the oldest
    hop_size_samples. Together they make consecutive frames start exactly one
    hop apart. The underlying deque is bounded to window + hop samples: if more
    audio arrives before the buffer is slid, the oldest samples are silently
    dropped.
    """

    def __init__(self, window_size_samples: int, hop_size_samples: int):
        """
        Initialize streaming buffer.

        Args:
            window_size_samples: Size of analysis window in samples
            hop_size_samples: Hop size between windows in samples
        """
        self.window_size_samples = window_size_samples
        self.hop_size_samples = hop_size_samples
        # Bounded deque: once full, extend() evicts the oldest samples, so memory
        # cannot grow without limit if audio arrives faster than it is consumed.
        self.buffer = deque(maxlen=window_size_samples + hop_size_samples)
        # Number of times slide() has been called.
        self.frame_index = 0

        logger.debug(f"StreamingBuffer initialized: window={window_size_samples}, hop={hop_size_samples}")

    def add_chunk(self, audio_chunk: np.ndarray) -> bool:
        """
        Append audio samples to the buffer.

        Args:
            audio_chunk: Audio samples to add

        Returns:
            True if the buffer now holds at least one full window, False otherwise
        """
        self.buffer.extend(audio_chunk)
        return len(self.buffer) >= self.window_size_samples

    def get_frame(self) -> Optional[np.ndarray]:
        """
        Return the oldest full analysis window without consuming it.

        Returns:
            Array of the first window_size_samples buffered samples, or None if
            fewer samples than that are buffered
        """
        if len(self.buffer) < self.window_size_samples:
            return None

        # Read from the FRONT: slide() removes from the front, so frame k then
        # starts at sample k * hop_size_samples. Reading the newest samples
        # instead misaligns frames whenever more than one window is buffered.
        # islice copies only the window rather than the whole deque.
        return np.array(list(islice(self.buffer, self.window_size_samples)))

    def slide(self):
        """Drop the oldest hop_size_samples samples (fewer if not buffered) and count one frame."""
        for _ in range(min(self.hop_size_samples, len(self.buffer))):
            self.buffer.popleft()
        self.frame_index += 1


# Bookkeeping for open connections, keyed by session id.
streaming_sessions: Dict[str, Dict] = dict()

# Module-level dependencies: all start as None and are assigned by
# initialize_streaming().
inference_pipeline: Optional[InferencePipeline] = None
phoneme_mapper: Optional[PhonemeMapper] = None
error_mapper: Optional[ErrorMapper] = None


def initialize_streaming(
    pipeline: InferencePipeline,
    mapper: Optional[PhonemeMapper] = None,
    error_mapper_instance: Optional[ErrorMapper] = None
):
    """
    Install the module-level dependencies used by the streaming handler.

    Assigns the module globals inference_pipeline, phoneme_mapper and
    error_mapper.

    Args:
        pipeline: InferencePipeline instance used for per-frame inference
        mapper: PhonemeMapper to use. If None, one is built from
            default_audio_config; if construction fails, phoneme_mapper is
            set to None and a warning is logged.
        error_mapper_instance: ErrorMapper to use. If None, a default
            ErrorMapper() is built; if construction fails, error_mapper is
            set to None and an error is logged.
    """
    global inference_pipeline, phoneme_mapper, error_mapper

    inference_pipeline = pipeline

    # Caller-supplied instances are used as-is; defaults are only built when
    # nothing was injected.
    if mapper is not None:
        phoneme_mapper = mapper
    else:
        try:
            phoneme_mapper = PhonemeMapper(
                frame_duration_ms=default_audio_config.chunk_duration_ms,
                sample_rate=default_audio_config.sample_rate
            )
            logger.info("✅ PhonemeMapper initialized for streaming")
        except Exception as e:
            logger.warning(f"⚠️ PhonemeMapper not available: {e}")
            phoneme_mapper = None

    if error_mapper_instance is not None:
        error_mapper = error_mapper_instance
    else:
        try:
            error_mapper = ErrorMapper()
            logger.info("✅ ErrorMapper initialized for streaming")
        except Exception as e:
            logger.error(f"❌ ErrorMapper failed to initialize: {e}")
            error_mapper = None


# Per-frame latency budget in milliseconds (the module's "<50ms per frame" target).
_LATENCY_BUDGET_MS = 50.0


def _map_error_detail(frame_pred, frame_index: int, phoneme: str = '') -> Optional[ErrorDetailSchema]:
    """
    Map one frame prediction to an ErrorDetailSchema.

    Args:
        frame_pred: First element of InferencePipeline.predict_phone_level(...).frame_predictions
        frame_index: Index of the analysed window; stored in frame_indices
        phoneme: Target phoneme (empty for streaming, which has no text input)

    Returns:
        The error detail, or None when no ErrorMapper is configured, when the
        mapped error type is ErrorType.NORMAL, or when mapping raises. Mapping
        failures are logged as warnings so one bad frame does not abort the
        stream.
    """
    if not error_mapper:
        return None

    # Stuttered frames use class ids offset by 4 from the articulation class.
    # Presumably this mirrors the classifier's combined label layout; confirm
    # against ErrorMapper.map_classifier_output.
    class_id = frame_pred.articulation_class
    if frame_pred.fluency_label == 'stutter':
        class_id += 4

    try:
        mapped = error_mapper.map_classifier_output(
            class_id=class_id,
            confidence=frame_pred.confidence,
            phoneme=phoneme,
            fluency_label=frame_pred.fluency_label
        )
        if mapped.error_type != ErrorType.NORMAL:
            return ErrorDetailSchema(
                phoneme=mapped.phoneme,
                error_type=mapped.error_type.value,
                wrong_sound=mapped.wrong_sound,
                severity=mapped.severity,
                confidence=mapped.confidence,
                therapy=mapped.therapy,
                frame_indices=[frame_index]
            )
    except Exception as e:
        logger.warning(f"Error mapping failed: {e}")
    return None


def _analyze_frame(frame: np.ndarray, frame_index: int, session_id: str) -> Optional[StreamingDiagnosisResponse]:
    """
    Run inference on one analysis window and build its streaming response.

    Args:
        frame: One window of audio samples from StreamingBuffer.get_frame()
        frame_index: Index of this window within the session
        session_id: Session the response belongs to

    Returns:
        A StreamingDiagnosisResponse, or None if the pipeline returned no frame
        predictions. Exceptions raised by inference propagate to the caller.
    """
    frame_start_time = time.time()

    result = inference_pipeline.predict_phone_level(
        frame,
        return_timestamps=False
    )
    if not result.frame_predictions:
        return None
    frame_pred = result.frame_predictions[0]  # Single frame result

    phoneme = ''  # Streaming doesn't have text input
    error_detail = _map_error_detail(frame_pred, frame_index, phoneme)

    # Latency covers inference plus error mapping, not the network send.
    latency_ms = (time.time() - frame_start_time) * 1000

    severity_level = "none"
    if error_detail and error_mapper:
        severity_level = error_mapper.get_severity_level(error_detail.severity).value

    # Presumably fluency_prob is P(stutter): report confidence in whichever
    # label was predicted.
    if frame_pred.fluency_label == 'stutter':
        fluency_confidence = frame_pred.fluency_prob
    else:
        fluency_confidence = 1.0 - frame_pred.fluency_prob

    hop_seconds = inference_pipeline.inference_config.hop_size_ms / 1000.0

    return StreamingDiagnosisResponse(
        session_id=session_id,
        frame_id=frame_index,
        timestamp=frame_index * hop_seconds,
        phoneme=phoneme,
        fluency=FluencyInfo(
            label=frame_pred.fluency_label,
            confidence=fluency_confidence
        ),
        articulation=ArticulationInfo(
            label=frame_pred.articulation_label,
            confidence=frame_pred.confidence,
            class_id=frame_pred.articulation_class
        ),
        error=error_detail,
        severity_level=severity_level,
        confidence=frame_pred.confidence,
        latency_ms=latency_ms
    )


async def handle_streaming_websocket(websocket: WebSocket, session_id: Optional[str] = None):
    """
    Handle WebSocket connection for streaming diagnosis.

    Each binary message is decoded as 16-bit PCM and appended to a sliding
    window buffer. Whenever a full window is available:
      - one window is analysed;
      - a StreamingDiagnosisResponse, or an {"error", "frame_id"} dict on
        failure, is sent back;
      - the buffer advances by one hop.
    Per-session counters are kept in streaming_sessions while the connection
    is open.

    NOTE(review): at most one window is analysed per received message. Chunks
    longer than the hop size overflow StreamingBuffer's bounded deque, and
    audio is dropped.
    NOTE(review): predict_phone_level is called synchronously inside this
    coroutine and blocks the event loop while it runs. Consider offloading it
    to a worker thread if the pipeline is thread-safe.

    Args:
        websocket: WebSocket connection
        session_id: Optional session ID (auto-generated if not provided)
    """
    if inference_pipeline is None:
        await websocket.close(code=1003, reason="Inference pipeline not loaded")
        return

    if not session_id:
        session_id = str(uuid.uuid4())

    await websocket.accept()
    logger.info(f"🔌 WebSocket connected: session_id={session_id}")

    window_size_samples = int(
        inference_pipeline.inference_config.window_size_ms *
        inference_pipeline.audio_config.sample_rate / 1000
    )
    hop_size_samples = int(
        inference_pipeline.inference_config.hop_size_ms *
        inference_pipeline.audio_config.sample_rate / 1000
    )
    buffer = StreamingBuffer(window_size_samples, hop_size_samples)

    # Keep a local handle so cleanup removes only this connection's entry, even
    # if another connection registers the same (client-supplied) session_id.
    session_stats = {
        "session_id": session_id,
        "connected_at": datetime.now(),
        "frame_count": 0,
        "total_latency_ms": 0.0
    }
    streaming_sessions[session_id] = session_stats

    frame_index = 0

    try:
        while True:
            try:
                data = await websocket.receive_bytes()

                # Assumes 16-bit PCM, mono, at the pipeline's sample rate (native byte order)
                audio_chunk = np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0

                if not buffer.add_chunk(audio_chunk):
                    continue  # Not enough audio for a full window yet
                frame = buffer.get_frame()

                try:
                    response = _analyze_frame(frame, frame_index, session_id)
                except Exception as e:
                    logger.error(f"❌ Inference failed: {e}", exc_info=True)
                    await websocket.send_json({
                        "error": f"Inference failed: {str(e)}",
                        "frame_id": frame_index
                    })
                else:
                    if response is not None:
                        await websocket.send_json(response.model_dump())
                        session_stats["frame_count"] += 1
                        session_stats["total_latency_ms"] += response.latency_ms
                        if response.latency_ms > _LATENCY_BUDGET_MS:
                            logger.warning(
                                f"⚠️ Latency exceeded {_LATENCY_BUDGET_MS:.0f}ms: {response.latency_ms:.1f}ms"
                            )
                finally:
                    # Always advance: otherwise an empty result or a failing
                    # frame would be re-analysed on the next message, and
                    # frame_index (hence timestamps) would stop tracking the audio.
                    buffer.slide()
                    frame_index += 1

            except WebSocketDisconnect:
                # Must not be handled as a processing error: replying would
                # target a closed socket. Let the outer handler end the session.
                raise
            except Exception as e:
                logger.error(f"❌ Error processing chunk: {e}", exc_info=True)
                await websocket.send_json({
                    "error": f"Processing failed: {str(e)}",
                    "frame_id": frame_index
                })

    except WebSocketDisconnect:
        logger.info(f"🔌 WebSocket disconnected: session_id={session_id}")
    except Exception as e:
        logger.error(f"❌ WebSocket error: {e}", exc_info=True)
    finally:
        if streaming_sessions.get(session_id) is session_stats:
            del streaming_sessions[session_id]
        frame_count = session_stats["frame_count"]
        avg_latency = session_stats["total_latency_ms"] / frame_count if frame_count > 0 else 0.0
        logger.info(f"📊 Session {session_id} stats: {frame_count} frames, "
                    f"avg_latency={avg_latency:.1f}ms")