nice-bill committed on
Commit
c9a654c
·
1 Parent(s): bdf62c5

changed fastapi script to use onnx int8

Browse files
Files changed (1) hide show
  1. src/api/app.py +23 -16
src/api/app.py CHANGED
@@ -5,39 +5,45 @@ import os
5
  import torch
6
  import librosa
7
  import numpy as np
8
- from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification
 
9
  from typing import List, Dict
10
  import tempfile
11
 
12
- app = FastAPI(title="VigilAudio Emotion API")
13
 
14
- MODEL_PATH = "models/wav2vec2-finetuned"
15
- DEVICE = torch.device("cpu")
 
16
 
17
- print(f"Loading model into API memory...")
 
18
  try:
19
  feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_PATH)
20
- model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_PATH)
21
- model.to(DEVICE)
22
- model.eval()
 
23
  id2label = model.config.id2label
24
- print(f"API Model Ready. Labels: {list(id2label.values())}")
25
  except Exception as e:
26
  print(f"API Failed to load model: {e}")
27
  model = None
28
 
 
29
  def segment_audio(audio, sr, window_size=3.0):
30
- """Splits audio into fixed-size windows."""
31
  chunk_len = int(window_size * sr)
32
  for i in range(0, len(audio), chunk_len):
33
  yield audio[i:i + chunk_len]
34
 
 
35
  @app.get("/health")
36
  def health():
37
  return {
38
  "status": "online",
 
39
  "model_loaded": model is not None,
40
- "device": str(DEVICE)
41
  }
42
 
43
  @app.post("/predict")
@@ -56,13 +62,14 @@ async def predict_emotion(file: UploadFile = File(...)):
56
  timeline = []
57
 
58
  for i, chunk in enumerate(segment_audio(speech, sr, window_size=3.0)):
59
- if len(chunk) < 8000:
60
- continue
61
 
62
  inputs = feature_extractor(chunk, sampling_rate=16000, return_tensors="pt", padding=True)
63
 
 
64
  with torch.no_grad():
65
- logits = model(inputs.input_values.to(DEVICE)).logits
 
66
  probs = torch.nn.functional.softmax(logits, dim=-1)
67
  pred_id = torch.argmax(logits, dim=-1).item()
68
 
@@ -78,6 +85,7 @@ async def predict_emotion(file: UploadFile = File(...)):
78
 
79
  return {
80
  "filename": file.filename,
 
81
  "duration_seconds": round(duration, 2),
82
  "dominant_emotion": dominant,
83
  "timeline": timeline
@@ -87,9 +95,8 @@ async def predict_emotion(file: UploadFile = File(...)):
87
  print(f"Prediction error: {e}")
88
  raise HTTPException(status_code=500, detail=str(e))
89
  finally:
90
- # Cleanup temp file
91
  if os.path.exists(tmp_path):
92
  os.remove(tmp_path)
93
 
94
  if __name__ == "__main__":
95
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
5
  import torch
6
  import librosa
7
  import numpy as np
8
+ from optimum.onnxruntime import ORTModelForAudioClassification
9
+ from transformers import AutoFeatureExtractor
10
  from typing import List, Dict
11
  import tempfile
12
 
13
+ app = FastAPI(title="VigilAudio Optimized API")
14
 
15
+ # --- CONFIG ---
16
+ # We use the INT8 model which proved to be the fastest in benchmarks
17
+ MODEL_PATH = "models/onnx_quantized"
18
 
19
+ # --- MODEL LOADING (Optimized with ONNX) ---
20
+ print(f"Loading OPTIMIZED INT8 ONNX model into memory...")
21
  try:
22
  feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_PATH)
23
+ # Note: we explicitly pass file_name since optimum expects model.onnx by default
24
+ model = ORTModelForAudioClassification.from_pretrained(MODEL_PATH, file_name="model_quantized.onnx")
25
+
26
+ # Label mapping from config
27
  id2label = model.config.id2label
28
+ print(f"Optimized API Ready. Speedup expected: ~1.8x")
29
  except Exception as e:
30
  print(f"API Failed to load model: {e}")
31
  model = None
32
 
33
+ # --- UTILS ---
34
  def segment_audio(audio, sr, window_size=3.0):
 
35
  chunk_len = int(window_size * sr)
36
  for i in range(0, len(audio), chunk_len):
37
  yield audio[i:i + chunk_len]
38
 
39
+ # --- ENDPOINTS ---
40
  @app.get("/health")
41
  def health():
42
  return {
43
  "status": "online",
44
+ "engine": "ONNX Runtime (INT8)",
45
  "model_loaded": model is not None,
46
+ "labels": list(id2label.values()) if model else []
47
  }
48
 
49
  @app.post("/predict")
 
62
  timeline = []
63
 
64
  for i, chunk in enumerate(segment_audio(speech, sr, window_size=3.0)):
65
+ if len(chunk) < 8000: continue
 
66
 
67
  inputs = feature_extractor(chunk, sampling_rate=16000, return_tensors="pt", padding=True)
68
 
69
+ # ONNX Inference
70
  with torch.no_grad():
71
+ outputs = model(**inputs)
72
+ logits = outputs.logits
73
  probs = torch.nn.functional.softmax(logits, dim=-1)
74
  pred_id = torch.argmax(logits, dim=-1).item()
75
 
 
85
 
86
  return {
87
  "filename": file.filename,
88
+ "engine": "ONNX_INT8",
89
  "duration_seconds": round(duration, 2),
90
  "dominant_emotion": dominant,
91
  "timeline": timeline
 
95
  print(f"Prediction error: {e}")
96
  raise HTTPException(status_code=500, detail=str(e))
97
  finally:
 
98
  if os.path.exists(tmp_path):
99
  os.remove(tmp_path)
100
 
101
  if __name__ == "__main__":
102
+ uvicorn.run(app, host="0.0.0.0", port=8000)