marshal-yash committed
Commit 6317c91 · verified · 1 Parent(s): 77029a5

Update server.py

Files changed (1):
  server.py +55 -61
server.py CHANGED
@@ -1,13 +1,22 @@
-import os, io, subprocess, tempfile
+import os
+import io
+import tempfile
+import subprocess
+import numpy as np
+import torch
+import librosa
+import soundfile as sf
 from fastapi import FastAPI, UploadFile, File
 from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
-import librosa
-import torch, numpy as np, soundfile as sf
 
+# ---------------------------
+# FastAPI setup
+# ---------------------------
 app = FastAPI()
 
+# CORS config
 origins = os.environ.get('CORS_ORIGINS', '*').split(',') if os.environ.get('CORS_ORIGINS') else ['*']
 app.add_middleware(
     CORSMiddleware,
@@ -17,22 +26,26 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+# ---------------------------
+# Model setup
+# ---------------------------
 FFMPEG_BIN = os.environ.get('FFMPEG_BIN', 'ffmpeg')
+MODEL_REPO = "marshal-yash/SER_wav2vec"
 
-# FIX: MODEL IS IN THE SAME FOLDER AS server.py
-MODEL_DIR = os.path.dirname(__file__)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# Load model + feature extractor
-model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_DIR)
-fe = AutoFeatureExtractor.from_pretrained(MODEL_DIR)
-
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+# Load model and feature extractor from Hugging Face
+model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_REPO)
+fe = AutoFeatureExtractor.from_pretrained(MODEL_REPO)
 model.to(device)
 model.eval()
 
-
+# ---------------------------
+# Utility: Convert audio to 16kHz mono
+# ---------------------------
 def to_wav16k_mono(data: bytes) -> np.ndarray:
     try:
+        # Use ffmpeg if available
         p = subprocess.run(
             [FFMPEG_BIN, '-hide_banner', '-loglevel', 'error',
              '-i', 'pipe:0', '-ar', str(fe.sampling_rate), '-ac', '1',
@@ -40,60 +53,49 @@ def to_wav16k_mono(data: bytes) -> np.ndarray:
             input=data, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
         )
         audio, sr = sf.read(io.BytesIO(p.stdout), dtype='float32', always_2d=False)
-        if isinstance(audio, np.ndarray):
-            out = audio.astype(np.float32)
-            if sr != fe.sampling_rate:
-                out = librosa.resample(out, orig_sr=sr, target_sr=fe.sampling_rate)
-            if out.size < fe.sampling_rate // 10:
-                out = np.pad(out, (0, max(0, fe.sampling_rate - out.size)), mode='constant')
-            return out
-        return np.array(audio, dtype=np.float32)
-
+        if sr != fe.sampling_rate:
+            audio = librosa.resample(audio, orig_sr=sr, target_sr=fe.sampling_rate)
+        return audio.astype(np.float32)
     except Exception:
+        # fallback: try reading directly with soundfile / librosa
         try:
             audio, sr = sf.read(io.BytesIO(data), dtype='float32', always_2d=False)
-            if isinstance(audio, np.ndarray):
-                if audio.ndim > 1:
-                    audio = np.mean(audio, axis=1)
-                out = audio.astype(np.float32)
-                if sr != fe.sampling_rate:
-                    out = librosa.resample(out, orig_sr=sr, target_sr=fe.sampling_rate)
-                if out.size < fe.sampling_rate // 10:
-                    out = np.pad(out, (0, max(0, fe.sampling_rate - out.size)), mode='constant')
-                return out
-
-            arr = np.array(audio, dtype=np.float32)
-            if sr and sr != fe.sampling_rate:
-                arr = librosa.resample(arr, orig_sr=sr, target_sr=fe.sampling_rate)
-            if arr.size < fe.sampling_rate // 10:
-                arr = np.pad(arr, (0, max(0, fe.sampling_rate - arr.size)), mode='constant')
-            return arr
-
+            if audio.ndim > 1:
+                audio = np.mean(audio, axis=1)
+            if sr != fe.sampling_rate:
+                audio = librosa.resample(audio, orig_sr=sr, target_sr=fe.sampling_rate)
+            return audio.astype(np.float32)
         except Exception:
-            try:
-                with tempfile.NamedTemporaryFile(delete=True, suffix='.audio') as tmp:
-                    tmp.write(data)
-                    tmp.flush()
-                    y, _sr = librosa.load(tmp.name, sr=fe.sampling_rate, mono=True)
-                    return y.astype(np.float32)
-            except Exception:
-                return np.zeros(fe.sampling_rate, dtype=np.float32)
-
+            # last fallback
+            with tempfile.NamedTemporaryFile(delete=True, suffix='.audio') as tmp:
+                tmp.write(data)
+                tmp.flush()
+                y, _ = librosa.load(tmp.name, sr=fe.sampling_rate, mono=True)
+                return y.astype(np.float32)
+
+# ---------------------------
+# Routes
+# ---------------------------
+@app.get("/")
+def root():
+    return {"status": "ok"}
 
-@app.post('/predict')
+@app.post("/predict")
 async def predict(file: UploadFile = File(...)):
     try:
+        # Read audio file
         data = await file.read()
         audio = to_wav16k_mono(data)
 
-        inputs = fe(audio, sampling_rate=fe.sampling_rate, return_tensors='pt')
+        # Extract features
+        inputs = fe(audio, sampling_rate=fe.sampling_rate, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
 
+        # Forward pass
        with torch.no_grad():
             logits = model(**inputs).logits
 
         probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()
-
         label_map = model.config.id2label
         labels = [label_map.get(str(i), f"class_{i}") for i in range(len(probs))]
 
@@ -103,23 +105,15 @@ async def predict(file: UploadFile = File(...)):
             reverse=True
         )
 
-        dominant = {
-            'label': pairs[0][0],
-            'score': pairs[0][1]
-        } if pairs else {'label': '', 'score': 0.0}
+        dominant = {"label": pairs[0][0], "score": pairs[0][1]} if pairs else {"label": "", "score": 0.0}
 
         return {
-            'results': [{'label': l, 'score': s} for l, s in pairs],
-            'dominant': dominant
+            "results": [{"label": l, "score": s} for l, s in pairs],
+            "dominant": dominant
         }
 
     except Exception as e:
         return JSONResponse(
             status_code=400,
-            content={'error': 'failed to process audio', 'message': f"{e.__class__.__name__}: {e}"}
+            content={"error": "failed to process audio", "message": f"{e.__class__.__name__}: {e}"}
         )
-
-
-@app.get('/')
-def root():
-    return {'status': 'ok'}
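
Below is a minimal client-side sketch of how the updated endpoints could be exercised once this revision of server.py is running. The uvicorn command, base URL, port, and sample file name are assumptions for illustration, not part of the commit.

# Minimal usage sketch (assumptions: server started with `uvicorn server:app --port 8000`,
# and a local file named sample.wav exists; neither is specified by the commit itself).
import requests

BASE_URL = "http://localhost:8000"  # assumed host and port

# Health check against the GET / route added in this revision
print(requests.get(f"{BASE_URL}/").json())  # expected: {"status": "ok"}

# POST an audio file to /predict; FastAPI expects the multipart field to be named "file"
with open("sample.wav", "rb") as f:
    resp = requests.post(f"{BASE_URL}/predict", files={"file": ("sample.wav", f, "audio/wav")})

payload = resp.json()
# Per the diff, the response carries a score-sorted "results" list and a "dominant" entry
print(payload["dominant"])
for item in payload["results"]:
    print(item["label"], round(item["score"], 3))

If this file is served from a Hugging Face Space, the base URL would be the Space's public endpoint rather than localhost.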