File size: 7,111 Bytes
14b7228
bdf62c5
 
 
 
 
 
c9a654c
 
bdf62c5
 
48cf750
 
 
14b7228
bdf62c5
14b7228
bdf62c5
14b7228
c9a654c
48cf750
b17fd2f
48cf750
bdf62c5
14b7228
 
bdf62c5
 
c9a654c
bdf62c5
48cf750
bdf62c5
48cf750
bdf62c5
 
14b7228
b17fd2f
48cf750
bdf62c5
 
 
 
48cf750
 
 
 
 
 
 
 
 
b17fd2f
48cf750
 
 
14b7228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdf62c5
 
 
 
c9a654c
bdf62c5
b17fd2f
bdf62c5
 
 
 
 
 
 
 
 
 
 
 
14b7228
bdf62c5
b17fd2f
bdf62c5
b17fd2f
 
 
 
 
 
 
bdf62c5
 
14b7228
b17fd2f
 
bdf62c5
 
 
 
c9a654c
 
bdf62c5
 
48cf750
 
 
 
 
 
bdf62c5
 
b17fd2f
 
48cf750
 
bdf62c5
 
48cf750
b17fd2f
48cf750
bdf62c5
48cf750
bdf62c5
 
 
 
b17fd2f
 
bdf62c5
 
 
 
 
 
 
 
 
 
 
14b7228
 
 
5f7f23e
14b7228
 
 
 
 
 
 
 
 
 
 
 
 
 
b17fd2f
14b7228
 
 
 
 
 
 
 
 
b17fd2f
14b7228
 
 
 
b17fd2f
14b7228
 
5f7f23e
14b7228
5f7f23e
b17fd2f
14b7228
 
bdf62c5
b17fd2f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
from fastapi import FastAPI, UploadFile, File, HTTPException, WebSocket, WebSocketDisconnect
import uvicorn
import shutil
import os
import torch
import librosa
import numpy as np
from optimum.onnxruntime import ORTModelForAudioClassification
from transformers import AutoFeatureExtractor
from typing import List, Dict
import tempfile
import soundfile as sf
import uuid
from datetime import datetime
import asyncio

app = FastAPI(title="VigilAudio: Optimized API with Real-time Streaming")

# --- CONFIG ---
MODEL_PATH = "models/onnx_quantized"
UPLOAD_DIR = "data/uploads/weak_predictions"
MAX_DURATION_SEC = 60.0  # Limit batch analysis to 60s for stability
os.makedirs(UPLOAD_DIR, exist_ok=True)

# --- MODEL LOADING ---
# Load the INT8-quantized ONNX model once at startup. All three globals are
# pre-set to safe sentinels so that a failed load leaves the module importable
# and endpoints can detect the degraded state via `model is None` (previously
# `feature_extractor` and `id2label` stayed undefined on failure, turning any
# later request into a NameError instead of a clean 500).
print("Loading optimized INT8 model...")
feature_extractor = None
model = None
id2label = {}
try:
    feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_PATH)
    model = ORTModelForAudioClassification.from_pretrained(MODEL_PATH, file_name="model_quantized.onnx")
    id2label = model.config.id2label
    print(f"API Ready. Labels: {list(id2label.values())}")
except Exception as e:
    print(f"Failed to load model: {e}")
    model = None

# --- HELPER FUNCTIONS ---
def segment_audio(audio, sr, window_size=2.0):
    """Yield consecutive fixed-duration slices of *audio*.

    Args:
        audio: 1-D sample sequence (any sliceable container).
        sr: Sample rate in Hz.
        window_size: Window length in seconds (default 2.0).

    Yields:
        Successive chunks of at most ``window_size * sr`` samples; the final
        chunk may be shorter than a full window.
    """
    step = int(window_size * sr)
    total = len(audio)
    start = 0
    while start < total:
        yield audio[start:start + step]
        start += step

def save_training_sample(audio_chunk, sr, predicted_emotion, confidence):
    """Persist a low-confidence audio chunk for future labeling (Active Learning).

    Args:
        audio_chunk: 1-D float sample array for the segment.
        sr: Sample rate in Hz.
        predicted_emotion: Label the model assigned to this chunk.
        confidence: Softmax probability of that label, in [0, 1].

    The file name encodes timestamp, label, confidence and a short UUID so
    saved samples are collision-free and self-describing. Write failures are
    logged but never propagate — saving is best-effort.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_id = str(uuid.uuid4())[:8]
    filename = f"{timestamp}_{predicted_emotion}_{confidence:.2f}_{unique_id}.wav"
    path = os.path.join(UPLOAD_DIR, filename)
    
    try:
        sf.write(path, audio_chunk, sr)
        # Bug fix: previously printed the literal placeholder "(unknown)"
        # instead of the destination path.
        print(f"Saved weak prediction: {path}")
    except Exception as e:
        print(f"Failed to save sample: {e}")

# --- STREAMING MANAGER ---
class AudioStreamBuffer:
    """Rolling PCM buffer that retains the most recent analysis window.

    Incoming little-endian int16 byte chunks are normalized to float32 in
    [-1, 1) and appended; only the newest ``window_size`` samples are kept,
    so the buffer behaves as a sliding window over the stream.
    """

    def __init__(self, sample_rate=16000, window_size_sec=2.0):
        self.sr = sample_rate
        # Window length expressed in samples rather than seconds.
        self.window_size = int(sample_rate * window_size_sec)
        self.buffer = np.array([], dtype=np.float32)

    def add_chunk(self, chunk_bytes):
        """Append raw int16 PCM bytes, trimming to the newest full window."""
        samples = np.frombuffer(chunk_bytes, dtype=np.int16).astype(np.float32) / 32768.0
        combined = np.append(self.buffer, samples)
        # Drop the oldest samples once more than one window has accumulated.
        self.buffer = combined[-self.window_size:] if len(combined) > self.window_size else combined

    def is_ready(self):
        """Return True once a complete analysis window has accumulated."""
        return len(self.buffer) >= self.window_size

# --- ENDPOINTS ---
@app.get("/health")
def health():
    return {
        "status": "online",
        "engine": "ONNX Runtime (INT8)",
        "model_loaded": model is not None,
        "max_duration_limit": MAX_DURATION_SEC
    }

@app.post("/predict")
async def predict_emotion(file: UploadFile = File(...)):
    if model is None:
        raise HTTPException(status_code=500, detail="Model weights missing on server.")

    with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
        shutil.copyfileobj(file.file, tmp)
        tmp_path = tmp.name

    try:
        # 2. Load and Resample
        speech, sr = librosa.load(tmp_path, sr=16000)
        original_duration = librosa.get_duration(y=speech, sr=sr)
        
        # --- DURATION LIMIT ---
        is_truncated = False
        if original_duration > MAX_DURATION_SEC:
            speech = speech[:int(MAX_DURATION_SEC * sr)]
            is_truncated = True
        
        duration = librosa.get_duration(y=speech, sr=sr)
        timeline = []
        
        # 3. Process segments
        for i, chunk in enumerate(segment_audio(speech, sr, window_size=2.0)):
            if len(chunk) < 8000: continue
                
            inputs = feature_extractor(chunk, sampling_rate=16000, return_tensors="pt", padding=True)
            
            with torch.no_grad():
                outputs = model(**inputs)
                logits = outputs.logits
                probs = torch.nn.functional.softmax(logits, dim=-1)
                pred_id = torch.argmax(logits, dim=-1).item()
                confidence = float(probs[0][pred_id])
            
            emotion_label = id2label[pred_id]
            
            if confidence < 0.60:
                save_training_sample(chunk, sr, emotion_label, confidence)
            
            timeline.append({
                "start_sec": i * 2.0,
                "end_sec": min((i + 1) * 2.0, duration),
                "emotion": emotion_label,
                "confidence": round(confidence, 4)
            })

        if not timeline:
            raise HTTPException(status_code=400, detail="Audio content too short.")

        emotions_list = [seg["emotion"] for seg in timeline]
        dominant = max(set(emotions_list), key=emotions_list.count)

        return {
            "filename": file.filename,
            "duration_seconds": round(duration, 2),
            "original_duration": round(original_duration, 2),
            "is_truncated": is_truncated,
            "dominant_emotion": dominant,
            "timeline": timeline
        }

    except Exception as e:
        print(f"Prediction error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

@app.websocket("/stream/audio")
async def stream_audio(websocket: WebSocket, rate: int = 16000):
    await websocket.accept()
    print(f"WebSocket Connected (Input Rate: {rate}Hz)")
    buffer = AudioStreamBuffer()
    
    resampler = None
    if rate != 16000:
        import torchaudio.transforms as T
        resampler = T.Resample(rate, 16000)

    try:
        while True:
            data = await websocket.receive_bytes()
            chunk = torch.from_numpy(np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0)
            if resampler:
                chunk = resampler(chunk)
            
            buffer.add_chunk(chunk.numpy().tobytes())
            
            if buffer.is_ready():
                inputs = feature_extractor(buffer.buffer, sampling_rate=16000, return_tensors="pt", padding=True)
                with torch.no_grad():
                    outputs = model(**inputs)
                    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
                    pred_id = torch.argmax(outputs.logits, dim=-1).item()
                    confidence = float(probs[0][pred_id])
                
                await websocket.send_json({
                    "emotion": id2label[pred_id],
                    "confidence": confidence,
                    "timestamp": datetime.now().isoformat(),
                    "status": "high_confidence" if confidence > 0.85 else "low_confidence"
                })
                
    except WebSocketDisconnect:
        print("WebSocket Disconnected")
    except Exception as e:
        print(f"WebSocket Error: {e}")
        try: await websocket.close()
        except: pass

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)