nice-bill commited on
Commit
b17fd2f
·
1 Parent(s): 557bf9a

added 60 sec limit

Browse files
Files changed (1) hide show
  1. src/api/app.py +24 -34
src/api/app.py CHANGED
@@ -19,6 +19,7 @@ app = FastAPI(title="VigilAudio: Optimized API with Real-time Streaming")
19
  # --- CONFIG ---
20
  MODEL_PATH = "models/onnx_quantized"
21
  UPLOAD_DIR = "data/uploads/weak_predictions"
 
22
  os.makedirs(UPLOAD_DIR, exist_ok=True)
23
 
24
  # --- MODEL LOADING ---
@@ -33,7 +34,7 @@ except Exception as e:
33
  model = None
34
 
35
  # --- HELPER FUNCTIONS ---
36
- def segment_audio(audio, sr, window_size=3.0):
37
  """Splits audio into fixed-size windows."""
38
  chunk_len = int(window_size * sr)
39
  for i in range(0, len(audio), chunk_len):
@@ -48,7 +49,7 @@ def save_training_sample(audio_chunk, sr, predicted_emotion, confidence):
48
 
49
  try:
50
  sf.write(path, audio_chunk, sr)
51
- print(f"Saved weak prediction for review: {filename}")
52
  except Exception as e:
53
  print(f"Failed to save sample: {e}")
54
 
@@ -60,12 +61,8 @@ class AudioStreamBuffer:
60
  self.buffer = np.array([], dtype=np.float32)
61
 
62
  def add_chunk(self, chunk_bytes):
63
- # Convert raw bytes to float32 array (assuming 16-bit PCM for now)
64
- # Note: Ideally, we should resample here if input is not 16kHz
65
  chunk = np.frombuffer(chunk_bytes, dtype=np.int16).astype(np.float32) / 32768.0
66
  self.buffer = np.append(self.buffer, chunk)
67
-
68
- # Keep only the last window_size samples (Sliding Window)
69
  if len(self.buffer) > self.window_size:
70
  self.buffer = self.buffer[-self.window_size:]
71
 
@@ -79,7 +76,7 @@ def health():
79
  "status": "online",
80
  "engine": "ONNX Runtime (INT8)",
81
  "model_loaded": model is not None,
82
- "active_learning_path": UPLOAD_DIR
83
  }
84
 
85
  @app.post("/predict")
@@ -87,7 +84,6 @@ async def predict_emotion(file: UploadFile = File(...)):
87
  if model is None:
88
  raise HTTPException(status_code=500, detail="Model weights missing on server.")
89
 
90
- # 1. Save uploaded file to temp
91
  with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
92
  shutil.copyfileobj(file.file, tmp)
93
  tmp_path = tmp.name
@@ -95,13 +91,20 @@ async def predict_emotion(file: UploadFile = File(...)):
95
  try:
96
  # 2. Load and Resample
97
  speech, sr = librosa.load(tmp_path, sr=16000)
98
- duration = librosa.get_duration(y=speech, sr=sr)
99
 
 
 
 
 
 
 
 
100
  timeline = []
101
 
102
  # 3. Process segments
103
- for i, chunk in enumerate(segment_audio(speech, sr, window_size=3.0)):
104
- if len(chunk) < 8000: continue # Skip very small fragments
105
 
106
  inputs = feature_extractor(chunk, sampling_rate=16000, return_tensors="pt", padding=True)
107
 
@@ -114,27 +117,27 @@ async def predict_emotion(file: UploadFile = File(...)):
114
 
115
  emotion_label = id2label[pred_id]
116
 
117
- # --- DATA FLYWHEEL (Active Learning) ---
118
  if confidence < 0.60:
119
  save_training_sample(chunk, sr, emotion_label, confidence)
120
 
121
  timeline.append({
122
- "start_sec": i * 3.0,
123
- "end_sec": min((i + 1) * 3.0, duration),
124
  "emotion": emotion_label,
125
  "confidence": round(confidence, 4)
126
  })
127
 
128
  if not timeline:
129
- raise HTTPException(status_code=400, detail="Audio file too short or empty.")
130
 
131
- # 4. Overall Summary
132
  emotions_list = [seg["emotion"] for seg in timeline]
133
  dominant = max(set(emotions_list), key=emotions_list.count)
134
 
135
  return {
136
  "filename": file.filename,
137
  "duration_seconds": round(duration, 2),
 
 
138
  "dominant_emotion": dominant,
139
  "timeline": timeline
140
  }
@@ -143,7 +146,6 @@ async def predict_emotion(file: UploadFile = File(...)):
143
  print(f"Prediction error: {e}")
144
  raise HTTPException(status_code=500, detail=str(e))
145
  finally:
146
- # 5. Cleanup
147
  if os.path.exists(tmp_path):
148
  os.remove(tmp_path)
149
 
@@ -153,7 +155,6 @@ async def stream_audio(websocket: WebSocket, rate: int = 16000):
153
  print(f"WebSocket Connected (Input Rate: {rate}Hz)")
154
  buffer = AudioStreamBuffer()
155
 
156
- # Pre-configure resampler if rate != 16000
157
  resampler = None
158
  if rate != 16000:
159
  import torchaudio.transforms as T
@@ -162,44 +163,33 @@ async def stream_audio(websocket: WebSocket, rate: int = 16000):
162
  try:
163
  while True:
164
  data = await websocket.receive_bytes()
165
-
166
- # 1. Convert to tensor
167
  chunk = torch.from_numpy(np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0)
168
-
169
- # 2. Resample if necessary
170
  if resampler:
171
  chunk = resampler(chunk)
172
 
173
- # 3. Add to buffer
174
- buffer.add_chunk(chunk.numpy().tobytes()) # Convert back to bytes for the buffer manager
175
 
176
  if buffer.is_ready():
177
  inputs = feature_extractor(buffer.buffer, sampling_rate=16000, return_tensors="pt", padding=True)
178
-
179
  with torch.no_grad():
180
  outputs = model(**inputs)
181
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
182
  pred_id = torch.argmax(outputs.logits, dim=-1).item()
183
  confidence = float(probs[0][pred_id])
184
 
185
- # 4. Confidence Threshold (0.85)
186
- # We only send if we are confident, or send a 'low_confidence' status
187
- response = {
188
  "emotion": id2label[pred_id],
189
  "confidence": confidence,
190
  "timestamp": datetime.now().isoformat(),
191
  "status": "high_confidence" if confidence > 0.85 else "low_confidence"
192
- }
193
-
194
- await websocket.send_json(response)
195
 
196
  except WebSocketDisconnect:
197
  print("WebSocket Disconnected")
198
  except Exception as e:
199
  print(f"WebSocket Error: {e}")
200
- try:
201
- await websocket.close()
202
  except: pass
203
 
204
  if __name__ == "__main__":
205
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
19
  # --- CONFIG ---
20
  MODEL_PATH = "models/onnx_quantized"
21
  UPLOAD_DIR = "data/uploads/weak_predictions"
22
+ MAX_DURATION_SEC = 60.0 # Limit batch analysis to 60s for stability
23
  os.makedirs(UPLOAD_DIR, exist_ok=True)
24
 
25
  # --- MODEL LOADING ---
 
34
  model = None
35
 
36
  # --- HELPER FUNCTIONS ---
37
+ def segment_audio(audio, sr, window_size=2.0):
38
  """Splits audio into fixed-size windows."""
39
  chunk_len = int(window_size * sr)
40
  for i in range(0, len(audio), chunk_len):
 
49
 
50
  try:
51
  sf.write(path, audio_chunk, sr)
52
+ print(f"Saved weak prediction: {filename}")
53
  except Exception as e:
54
  print(f"Failed to save sample: {e}")
55
 
 
61
  self.buffer = np.array([], dtype=np.float32)
62
 
63
  def add_chunk(self, chunk_bytes):
 
 
64
  chunk = np.frombuffer(chunk_bytes, dtype=np.int16).astype(np.float32) / 32768.0
65
  self.buffer = np.append(self.buffer, chunk)
 
 
66
  if len(self.buffer) > self.window_size:
67
  self.buffer = self.buffer[-self.window_size:]
68
 
 
76
  "status": "online",
77
  "engine": "ONNX Runtime (INT8)",
78
  "model_loaded": model is not None,
79
+ "max_duration_limit": MAX_DURATION_SEC
80
  }
81
 
82
  @app.post("/predict")
 
84
  if model is None:
85
  raise HTTPException(status_code=500, detail="Model weights missing on server.")
86
 
 
87
  with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
88
  shutil.copyfileobj(file.file, tmp)
89
  tmp_path = tmp.name
 
91
  try:
92
  # 2. Load and Resample
93
  speech, sr = librosa.load(tmp_path, sr=16000)
94
+ original_duration = librosa.get_duration(y=speech, sr=sr)
95
 
96
+ # --- DURATION LIMIT ---
97
+ is_truncated = False
98
+ if original_duration > MAX_DURATION_SEC:
99
+ speech = speech[:int(MAX_DURATION_SEC * sr)]
100
+ is_truncated = True
101
+
102
+ duration = librosa.get_duration(y=speech, sr=sr)
103
  timeline = []
104
 
105
  # 3. Process segments
106
+ for i, chunk in enumerate(segment_audio(speech, sr, window_size=2.0)):
107
+ if len(chunk) < 8000: continue
108
 
109
  inputs = feature_extractor(chunk, sampling_rate=16000, return_tensors="pt", padding=True)
110
 
 
117
 
118
  emotion_label = id2label[pred_id]
119
 
 
120
  if confidence < 0.60:
121
  save_training_sample(chunk, sr, emotion_label, confidence)
122
 
123
  timeline.append({
124
+ "start_sec": i * 2.0,
125
+ "end_sec": min((i + 1) * 2.0, duration),
126
  "emotion": emotion_label,
127
  "confidence": round(confidence, 4)
128
  })
129
 
130
  if not timeline:
131
+ raise HTTPException(status_code=400, detail="Audio content too short.")
132
 
 
133
  emotions_list = [seg["emotion"] for seg in timeline]
134
  dominant = max(set(emotions_list), key=emotions_list.count)
135
 
136
  return {
137
  "filename": file.filename,
138
  "duration_seconds": round(duration, 2),
139
+ "original_duration": round(original_duration, 2),
140
+ "is_truncated": is_truncated,
141
  "dominant_emotion": dominant,
142
  "timeline": timeline
143
  }
 
146
  print(f"Prediction error: {e}")
147
  raise HTTPException(status_code=500, detail=str(e))
148
  finally:
 
149
  if os.path.exists(tmp_path):
150
  os.remove(tmp_path)
151
 
 
155
  print(f"WebSocket Connected (Input Rate: {rate}Hz)")
156
  buffer = AudioStreamBuffer()
157
 
 
158
  resampler = None
159
  if rate != 16000:
160
  import torchaudio.transforms as T
 
163
  try:
164
  while True:
165
  data = await websocket.receive_bytes()
 
 
166
  chunk = torch.from_numpy(np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0)
 
 
167
  if resampler:
168
  chunk = resampler(chunk)
169
 
170
+ buffer.add_chunk(chunk.numpy().tobytes())
 
171
 
172
  if buffer.is_ready():
173
  inputs = feature_extractor(buffer.buffer, sampling_rate=16000, return_tensors="pt", padding=True)
 
174
  with torch.no_grad():
175
  outputs = model(**inputs)
176
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
177
  pred_id = torch.argmax(outputs.logits, dim=-1).item()
178
  confidence = float(probs[0][pred_id])
179
 
180
+ await websocket.send_json({
 
 
181
  "emotion": id2label[pred_id],
182
  "confidence": confidence,
183
  "timestamp": datetime.now().isoformat(),
184
  "status": "high_confidence" if confidence > 0.85 else "low_confidence"
185
+ })
 
 
186
 
187
  except WebSocketDisconnect:
188
  print("WebSocket Disconnected")
189
  except Exception as e:
190
  print(f"WebSocket Error: {e}")
191
+ try: await websocket.close()
 
192
  except: pass
193
 
194
  if __name__ == "__main__":
195
+ uvicorn.run(app, host="0.0.0.0", port=8000)