# vigilaudio/src/api/app.py
# Last commit b17fd2f (nice-bill): added 60 sec limit
from fastapi import FastAPI, UploadFile, File, HTTPException, WebSocket, WebSocketDisconnect
import uvicorn
import shutil
import os
import torch
import librosa
import numpy as np
from optimum.onnxruntime import ORTModelForAudioClassification
from transformers import AutoFeatureExtractor
from typing import List, Dict
import tempfile
import soundfile as sf
import uuid
from datetime import datetime
import asyncio
# FastAPI application instance; the route decorators further down register on it.
app = FastAPI(title="VigilAudio: Optimized API with Real-time Streaming")
# --- CONFIG ---
# Directory handed to the feature-extractor and ONNX model loaders below.
MODEL_PATH = "models/onnx_quantized"
# Directory where save_training_sample() writes low-confidence chunks.
UPLOAD_DIR = "data/uploads/weak_predictions"
# /predict truncates any upload longer than this many seconds before analysis.
MAX_DURATION_SEC = 60.0 # Limit batch analysis to 60s for stability
# Create the upload directory at import time; exist_ok makes repeated starts harmless.
os.makedirs(UPLOAD_DIR, exist_ok=True)
# --- MODEL LOADING ---
# Pre-bind every model-related global so a failed load leaves all of them
# defined. Endpoints can then test `model is None` instead of running into a
# NameError on `feature_extractor` or `id2label`.
feature_extractor = None
model = None
id2label = {}
print("Loading optimized INT8 model...")
try:
    feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_PATH)
    model = ORTModelForAudioClassification.from_pretrained(MODEL_PATH, file_name="model_quantized.onnx")
    id2label = model.config.id2label
    print(f"API Ready. Labels: {list(id2label.values())}")
except Exception as e:
    # Best-effort startup: keep the module importable so /health can still
    # report that the model is unavailable.
    print(f"Failed to load model: {e}")
    # Reset everything so a partial load never leaves inconsistent state.
    feature_extractor = None
    model = None
    id2label = {}
# --- HELPER FUNCTIONS ---
def segment_audio(audio, sr, window_size=2.0):
    """Yield consecutive, non-overlapping windows of ``audio``.

    Every window holds ``int(window_size * sr)`` samples; the last one may be
    shorter when the audio length is not an exact multiple of that size.
    """
    samples_per_window = int(window_size * sr)
    window_starts = range(0, len(audio), samples_per_window)
    yield from (audio[start:start + samples_per_window] for start in window_starts)
def save_training_sample(audio_chunk, sr, predicted_emotion, confidence):
    """Write a low-confidence chunk to UPLOAD_DIR as a WAV file (active learning).

    The file name combines a second-resolution timestamp, the predicted
    label, the confidence with two decimals, and an 8-character random id.
    Any failure is printed and swallowed so the caller is never interrupted.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # A short random suffix keeps names distinct within the same second.
    short_id = str(uuid.uuid4())[:8]
    filename = f"{stamp}_{predicted_emotion}_{confidence:.2f}_{short_id}.wav"
    destination = os.path.join(UPLOAD_DIR, filename)
    try:
        sf.write(destination, audio_chunk, sr)
        print(f"Saved weak prediction: {filename}")
    except Exception as e:
        print(f"Failed to save sample: {e}")
# --- STREAMING MANAGER ---
class AudioStreamBuffer:
    """Rolling buffer holding the most recent window of streamed audio.

    Chunks arrive as 16-bit PCM bytes and are stored as float32 samples
    scaled by 1/32768. Once more than one window of samples is held, only
    the newest ``window_size`` samples are kept.
    """

    def __init__(self, sample_rate=16000, window_size_sec=2.0):
        self.sr = sample_rate
        # Window length expressed in samples rather than seconds.
        self.window_size = int(sample_rate * window_size_sec)
        self.buffer = np.array([], dtype=np.float32)

    def add_chunk(self, chunk_bytes):
        """Decode int16 PCM bytes, append them, and trim to the window size."""
        pcm = np.frombuffer(chunk_bytes, dtype=np.int16)
        samples = pcm.astype(np.float32) / 32768.0
        combined = np.concatenate((self.buffer, samples))
        if len(combined) > self.window_size:
            # Drop the oldest samples, keeping only the newest window.
            combined = combined[-self.window_size:]
        self.buffer = combined

    def is_ready(self):
        """Return True once a full window of samples is buffered."""
        return not len(self.buffer) < self.window_size
# --- ENDPOINTS ---
@app.get("/health")
def health():
    """Report service status, inference engine, model availability and the duration cap."""
    model_available = model is not None
    payload = {
        "status": "online",
        "engine": "ONNX Runtime (INT8)",
        "model_loaded": model_available,
        "max_duration_limit": MAX_DURATION_SEC,
    }
    return payload
@app.post("/predict")
async def predict_emotion(file: UploadFile = File(...)):
    """Classify the emotion of an uploaded audio file in 2-second windows.

    The upload is spooled to a temporary file, loaded at 16 kHz, truncated to
    ``MAX_DURATION_SEC`` seconds, and split into 2-second windows. Windows
    shorter than 8000 samples (0.5 s) are skipped. Windows predicted with a
    confidence below 0.60 are additionally stored via ``save_training_sample``.

    Returns:
        A dict with the file name, analysed and original durations, a
        truncation flag, the most frequent emotion and the per-window timeline.

    Raises:
        HTTPException: 500 when the model is not loaded or processing fails;
            400 when no window is long enough to analyse.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model weights missing on server.")

    target_sr = 16000
    window_sec = 2.0
    min_chunk_samples = 8000  # 0.5 s at 16 kHz
    low_confidence = 0.60

    # 1. Spool the upload to disk so it can be loaded by path.
    # `filename` may be None for some clients; fall back to no suffix.
    suffix = os.path.splitext(file.filename or "")[1]
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    tmp_path = tmp.name
    try:
        # The copy happens inside the try so a failed write still reaches the
        # cleanup in `finally` instead of leaking the temp file.
        with tmp:
            shutil.copyfileobj(file.file, tmp)

        # 2. Load and resample
        speech, sr = librosa.load(tmp_path, sr=target_sr)
        original_duration = librosa.get_duration(y=speech, sr=sr)

        # --- DURATION LIMIT ---
        is_truncated = bool(original_duration > MAX_DURATION_SEC)
        if is_truncated:
            speech = speech[:int(MAX_DURATION_SEC * sr)]
        duration = librosa.get_duration(y=speech, sr=sr)

        # 3. Process segments
        timeline = []
        for i, chunk in enumerate(segment_audio(speech, sr, window_size=window_sec)):
            if len(chunk) < min_chunk_samples:
                continue
            inputs = feature_extractor(chunk, sampling_rate=target_sr, return_tensors="pt", padding=True)
            with torch.no_grad():
                logits = model(**inputs).logits
            probs = torch.nn.functional.softmax(logits, dim=-1)
            pred_id = torch.argmax(logits, dim=-1).item()
            confidence = float(probs[0][pred_id])
            emotion_label = id2label[pred_id]
            if confidence < low_confidence:
                save_training_sample(chunk, sr, emotion_label, confidence)
            timeline.append({
                "start_sec": i * window_sec,
                "end_sec": min((i + 1) * window_sec, duration),
                "emotion": emotion_label,
                "confidence": round(confidence, 4)
            })

        if not timeline:
            raise HTTPException(status_code=400, detail="Audio content too short.")

        emotions_list = [seg["emotion"] for seg in timeline]
        # Iterate the list (not a set) so ties resolve deterministically to
        # the emotion that appears first in the timeline.
        dominant = max(emotions_list, key=emotions_list.count)
        return {
            "filename": file.filename,
            "duration_seconds": round(duration, 2),
            "original_duration": round(original_duration, 2),
            "is_truncated": is_truncated,
            "dominant_emotion": dominant,
            "timeline": timeline
        }
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must keep their status
        # code rather than being rewrapped as a 500 below.
        raise
    except Exception as e:
        print(f"Prediction error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
@app.websocket("/stream/audio")
async def stream_audio(websocket: WebSocket, rate: int = 16000):
    """Stream emotion predictions for raw 16-bit PCM audio sent over a WebSocket.

    Every binary message is treated as int16 PCM sampled at ``rate`` Hz. Audio
    at other rates is resampled to 16 kHz first. Samples accumulate in a
    rolling 2-second ``AudioStreamBuffer``; once it is full, each further
    message triggers one prediction, sent back as JSON with the emotion,
    confidence, timestamp and a high/low confidence status.
    """
    await websocket.accept()
    print(f"WebSocket Connected (Input Rate: {rate}Hz)")
    if model is None:
        # Fail fast with an internal-error close code instead of erroring on
        # the first inference attempt.
        print("WebSocket Error: model not loaded")
        await websocket.close(code=1011)
        return
    buffer = AudioStreamBuffer()
    resampler = None
    if rate != 16000:
        import torchaudio.transforms as T
        resampler = T.Resample(rate, 16000)
    try:
        while True:
            data = await websocket.receive_bytes()
            if resampler is None:
                # Already 16 kHz int16 PCM, which is exactly what add_chunk decodes.
                buffer.add_chunk(data)
            else:
                samples = torch.from_numpy(np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0)
                resampled = resampler(samples)
                # add_chunk decodes int16 PCM bytes, so re-encode the float
                # samples; passing float32 bytes would be misread as int16.
                # Clamp first because resampling can overshoot [-1, 1].
                pcm = (resampled.clamp(-1.0, 1.0) * 32767.0).round().to(torch.int16)
                buffer.add_chunk(pcm.numpy().tobytes())
            if buffer.is_ready():
                inputs = feature_extractor(buffer.buffer, sampling_rate=16000, return_tensors="pt", padding=True)
                with torch.no_grad():
                    outputs = model(**inputs)
                probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
                pred_id = torch.argmax(outputs.logits, dim=-1).item()
                confidence = float(probs[0][pred_id])
                await websocket.send_json({
                    "emotion": id2label[pred_id],
                    "confidence": confidence,
                    "timestamp": datetime.now().isoformat(),
                    "status": "high_confidence" if confidence > 0.85 else "low_confidence"
                })
    except WebSocketDisconnect:
        print("WebSocket Disconnected")
    except Exception as e:
        print(f"WebSocket Error: {e}")
        try:
            await websocket.close()
        except Exception:
            # Best effort: the connection may already be closed.
            pass
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on port 8000.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )