from fastapi import FastAPI, UploadFile, File, HTTPException, WebSocket, WebSocketDisconnect
import uvicorn
import shutil
import os
import torch
import librosa
import numpy as np
from optimum.onnxruntime import ORTModelForAudioClassification
from transformers import AutoFeatureExtractor
from typing import List, Dict
import tempfile
import soundfile as sf
import uuid
from datetime import datetime
import asyncio
app = FastAPI(title="VigilAudio: Optimized API with Real-time Streaming")
# --- CONFIG ---
MODEL_PATH = "models/onnx_quantized"  # directory holding the quantized ONNX model + feature extractor
UPLOAD_DIR = "data/uploads/weak_predictions"  # low-confidence chunks are archived here for active learning
MAX_DURATION_SEC = 60.0 # Limit batch analysis to 60s for stability
os.makedirs(UPLOAD_DIR, exist_ok=True)
# --- MODEL LOADING ---
# Load the INT8-quantized ONNX model and its feature extractor once at startup.
# All globals are pre-set to safe sentinels so that a failed load leaves the
# endpoints able to report "model not loaded" instead of raising NameError
# on `feature_extractor` / `id2label` at first request.
print("Loading optimized INT8 model...")
feature_extractor = None
model = None
id2label = {}
try:
    feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_PATH)
    model = ORTModelForAudioClassification.from_pretrained(MODEL_PATH, file_name="model_quantized.onnx")
    id2label = model.config.id2label
    print(f"API Ready. Labels: {list(id2label.values())}")
except Exception as e:
    # Keep the server up; /health exposes model_loaded=False.
    print(f"Failed to load model: {e}")
    model = None
# --- HELPER FUNCTIONS ---
def segment_audio(audio, sr, window_size=2.0):
    """Yield consecutive fixed-length windows of *audio*.

    Windows are ``window_size`` seconds long (``int(window_size * sr)``
    samples); the final window may be shorter when the audio does not
    divide evenly.
    """
    samples_per_window = int(window_size * sr)
    start = 0
    while start < len(audio):
        yield audio[start:start + samples_per_window]
        start += samples_per_window
def save_training_sample(audio_chunk, sr, predicted_emotion, confidence):
    """Persist a low-confidence audio chunk for later labelling (Active Learning).

    The filename encodes timestamp, predicted label, confidence and a short
    random id so samples are collision-free and self-describing. Failures are
    logged and swallowed on purpose: archiving is best-effort and must never
    break the prediction path.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_id = str(uuid.uuid4())[:8]
    filename = f"{timestamp}_{predicted_emotion}_{confidence:.2f}_{unique_id}.wav"
    path = os.path.join(UPLOAD_DIR, filename)
    try:
        sf.write(path, audio_chunk, sr)
        # BUG FIX: the original printed a literal placeholder instead of the path.
        print(f"Saved weak prediction: {path}")
    except Exception as e:
        print(f"Failed to save sample: {e}")
# --- STREAMING MANAGER ---
class AudioStreamBuffer:
    """Rolling window over the most recent PCM samples of a live stream.

    Incoming bytes are decoded as int16 PCM, normalised to float32 in
    [-1, 1), and appended; only the newest ``window_size`` samples are kept.
    """

    def __init__(self, sample_rate=16000, window_size_sec=2.0):
        self.sr = sample_rate
        # Window length expressed in samples, not seconds.
        self.window_size = int(sample_rate * window_size_sec)
        self.buffer = np.array([], dtype=np.float32)

    def add_chunk(self, chunk_bytes):
        """Decode, normalise and append a chunk, trimming to the window size."""
        samples = np.frombuffer(chunk_bytes, dtype=np.int16).astype(np.float32) / 32768.0
        combined = np.concatenate((self.buffer, samples))
        if combined.size > self.window_size:
            combined = combined[-self.window_size:]
        self.buffer = combined

    def is_ready(self):
        """True once a full analysis window has been accumulated."""
        return len(self.buffer) >= self.window_size
# --- ENDPOINTS ---
@app.get("/health")
def health():
    """Liveness probe: reports the inference engine and whether the model loaded."""
    payload = {
        "status": "online",
        "engine": "ONNX Runtime (INT8)",
        "model_loaded": model is not None,
        "max_duration_limit": MAX_DURATION_SEC,
    }
    return payload
@app.post("/predict")
async def predict_emotion(file: UploadFile = File(...)):
    """Classify emotion over time for an uploaded audio file.

    The upload is spooled to disk, decoded and resampled to 16 kHz, truncated
    to MAX_DURATION_SEC, split into 2-second windows and classified per
    window. Windows below 0.60 confidence are archived for active learning.

    Returns the per-window timeline plus the dominant (most frequent) emotion.
    Raises 500 when the model is unavailable or inference fails, 400 when the
    audio yields no usable window (all windows shorter than 0.5 s).
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model weights missing on server.")
    # 1. Spool the upload to a temp file so librosa can decode any container.
    with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
        shutil.copyfileobj(file.file, tmp)
        tmp_path = tmp.name
    try:
        # 2. Load and resample to the model's expected 16 kHz.
        speech, sr = librosa.load(tmp_path, sr=16000)
        original_duration = librosa.get_duration(y=speech, sr=sr)
        # --- DURATION LIMIT: cap batch analysis length for stability ---
        is_truncated = False
        if original_duration > MAX_DURATION_SEC:
            speech = speech[:int(MAX_DURATION_SEC * sr)]
            is_truncated = True
        duration = librosa.get_duration(y=speech, sr=sr)
        timeline = []
        # 3. Classify each 2-second window.
        for i, chunk in enumerate(segment_audio(speech, sr, window_size=2.0)):
            if len(chunk) < 8000:  # skip trailing fragments shorter than 0.5 s @ 16 kHz
                continue
            inputs = feature_extractor(chunk, sampling_rate=16000, return_tensors="pt", padding=True)
            with torch.no_grad():
                outputs = model(**inputs)
            logits = outputs.logits
            probs = torch.nn.functional.softmax(logits, dim=-1)
            pred_id = torch.argmax(logits, dim=-1).item()
            confidence = float(probs[0][pred_id])
            emotion_label = id2label[pred_id]
            # Archive weak predictions for future re-training (best-effort).
            if confidence < 0.60:
                save_training_sample(chunk, sr, emotion_label, confidence)
            timeline.append({
                "start_sec": i * 2.0,
                "end_sec": min((i + 1) * 2.0, duration),
                "emotion": emotion_label,
                "confidence": round(confidence, 4)
            })
        if not timeline:
            raise HTTPException(status_code=400, detail="Audio content too short.")
        emotions_list = [seg["emotion"] for seg in timeline]
        dominant = max(set(emotions_list), key=emotions_list.count)
        return {
            "filename": file.filename,
            "duration_seconds": round(duration, 2),
            "original_duration": round(original_duration, 2),
            "is_truncated": is_truncated,
            "dominant_emotion": dominant,
            "timeline": timeline
        }
    except HTTPException:
        # BUG FIX: the intended 400 above was previously swallowed by the
        # generic handler below and re-raised as a 500 — propagate it as-is.
        raise
    except Exception as e:
        print(f"Prediction error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always remove the spooled temp file, success or failure.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
@app.websocket("/stream/audio")
async def stream_audio(websocket: WebSocket, rate: int = 16000):
    """Real-time emotion classification over a WebSocket.

    The client streams raw int16 PCM chunks sampled at ``rate`` Hz. Chunks
    are resampled to 16 kHz when needed, accumulated into a rolling 2 s
    window, and a prediction JSON is pushed back whenever the window is full.
    """
    await websocket.accept()
    # Mirror /predict: refuse to serve when the model never loaded.
    if model is None:
        await websocket.close(code=1011)  # 1011 = internal error
        return
    print(f"WebSocket Connected (Input Rate: {rate}Hz)")
    buffer = AudioStreamBuffer()
    resampler = None
    if rate != 16000:
        import torchaudio.transforms as T
        resampler = T.Resample(rate, 16000)
    try:
        while True:
            data = await websocket.receive_bytes()
            chunk = torch.from_numpy(np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0)
            if resampler:
                chunk = resampler(chunk)
            # BUG FIX: AudioStreamBuffer.add_chunk decodes int16 bytes, but the
            # original passed raw float32 bytes (chunk.numpy().tobytes()), which
            # were then misread as int16 — the buffer always held garbage.
            # Re-quantize the (possibly resampled) floats back to int16 first.
            pcm16 = np.clip(chunk.numpy() * 32768.0, -32768, 32767).astype(np.int16)
            buffer.add_chunk(pcm16.tobytes())
            if buffer.is_ready():
                inputs = feature_extractor(buffer.buffer, sampling_rate=16000, return_tensors="pt", padding=True)
                with torch.no_grad():
                    outputs = model(**inputs)
                probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
                pred_id = torch.argmax(outputs.logits, dim=-1).item()
                confidence = float(probs[0][pred_id])
                await websocket.send_json({
                    "emotion": id2label[pred_id],
                    "confidence": confidence,
                    "timestamp": datetime.now().isoformat(),
                    "status": "high_confidence" if confidence > 0.85 else "low_confidence"
                })
    except WebSocketDisconnect:
        print("WebSocket Disconnected")
    except Exception as e:
        print(f"WebSocket Error: {e}")
        try:
            await websocket.close()
        except Exception:
            pass
# Dev convenience: run the ASGI app directly when executed as a script.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)