marshal-yash committed on
Commit
5159013
·
verified ·
1 Parent(s): 4fcb97a

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +43 -16
server.py CHANGED
@@ -18,19 +18,26 @@ app.add_middleware(
18
  )
19
 
20
  FFMPEG_BIN = os.environ.get('FFMPEG_BIN', 'ffmpeg')
21
- MODEL_DIR = os.environ.get('MODEL_DIR', os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'best_model')))
22
 
 
 
 
 
23
  model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_DIR)
24
  fe = AutoFeatureExtractor.from_pretrained(MODEL_DIR)
 
25
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
26
  model.to(device)
27
  model.eval()
28
 
 
29
  def to_wav16k_mono(data: bytes) -> np.ndarray:
30
  try:
31
  p = subprocess.run(
32
- [FFMPEG_BIN, '-hide_banner', '-loglevel', 'error', '-i', 'pipe:0', '-ar', str(fe.sampling_rate), '-ac', '1', '-f', 'wav', 'pipe:1'],
33
- input=data, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True,
 
 
34
  )
35
  audio, sr = sf.read(io.BytesIO(p.stdout), dtype='float32', always_2d=False)
36
  if isinstance(audio, np.ndarray):
@@ -41,6 +48,7 @@ def to_wav16k_mono(data: bytes) -> np.ndarray:
41
  out = np.pad(out, (0, max(0, fe.sampling_rate - out.size)), mode='constant')
42
  return out
43
  return np.array(audio, dtype=np.float32)
 
44
  except Exception:
45
  try:
46
  audio, sr = sf.read(io.BytesIO(data), dtype='float32', always_2d=False)
@@ -53,12 +61,14 @@ def to_wav16k_mono(data: bytes) -> np.ndarray:
53
  if out.size < fe.sampling_rate // 10:
54
  out = np.pad(out, (0, max(0, fe.sampling_rate - out.size)), mode='constant')
55
  return out
 
56
  arr = np.array(audio, dtype=np.float32)
57
  if sr and sr != fe.sampling_rate:
58
  arr = librosa.resample(arr, orig_sr=sr, target_sr=fe.sampling_rate)
59
  if arr.size < fe.sampling_rate // 10:
60
  arr = np.pad(arr, (0, max(0, fe.sampling_rate - arr.size)), mode='constant')
61
  return arr
 
62
  except Exception:
63
  try:
64
  with tempfile.NamedTemporaryFile(delete=True, suffix='.audio') as tmp:
@@ -69,30 +79,47 @@ def to_wav16k_mono(data: bytes) -> np.ndarray:
69
  except Exception:
70
  return np.zeros(fe.sampling_rate, dtype=np.float32)
71
 
 
72
  @app.post('/predict')
73
  async def predict(file: UploadFile = File(...)):
74
  try:
75
  data = await file.read()
76
  audio = to_wav16k_mono(data)
 
77
  inputs = fe(audio, sampling_rate=fe.sampling_rate, return_tensors='pt')
78
  inputs = {k: v.to(device) for k, v in inputs.items()}
 
79
  with torch.no_grad():
80
  logits = model(**inputs).logits
 
81
  probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()
 
82
  label_map = model.config.id2label
83
- label_keys = list(label_map.keys())
84
- use_str = bool(label_keys) and isinstance(label_keys[0], str)
85
- labels = []
86
- for i in range(len(probs)):
87
- if use_str:
88
- labels.append(label_map.get(str(i), f"class_{i}"))
89
- else:
90
- labels.append(label_map.get(i, f"class_{i}"))
91
- pairs = sorted([(labels[i], float(probs[i])) for i in range(len(probs))], key=lambda x: x[1], reverse=True)
92
- dominant = { 'label': pairs[0][0], 'score': pairs[0][1] } if pairs else { 'label': '', 'score': 0.0 }
93
- return { 'results': [ { 'label': l, 'score': s } for l, s in pairs ], 'dominant': dominant }
 
 
 
 
 
 
 
94
  except Exception as e:
95
- return JSONResponse(status_code=400, content={ 'error': 'failed to process audio', 'message': f"{e.__class__.__name__}: {e}" })
 
 
 
 
 
96
  @app.get('/')
97
  def root():
98
- return { 'status': 'ok' }
 
18
  )
19
 
20
  FFMPEG_BIN = os.environ.get('FFMPEG_BIN', 'ffmpeg')
 
21
 
22
+ # ✅ FIX: MODEL IS IN THE SAME FOLDER AS server.py
23
+ MODEL_DIR = os.path.dirname(__file__)
24
+
25
+ # Load model + feature extractor
26
  model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_DIR)
27
  fe = AutoFeatureExtractor.from_pretrained(MODEL_DIR)
28
+
29
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
30
  model.to(device)
31
  model.eval()
32
 
33
+
34
  def to_wav16k_mono(data: bytes) -> np.ndarray:
35
  try:
36
  p = subprocess.run(
37
+ [FFMPEG_BIN, '-hide_banner', '-loglevel', 'error',
38
+ '-i', 'pipe:0', '-ar', str(fe.sampling_rate), '-ac', '1',
39
+ '-f', 'wav', 'pipe:1'],
40
+ input=data, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
41
  )
42
  audio, sr = sf.read(io.BytesIO(p.stdout), dtype='float32', always_2d=False)
43
  if isinstance(audio, np.ndarray):
 
48
  out = np.pad(out, (0, max(0, fe.sampling_rate - out.size)), mode='constant')
49
  return out
50
  return np.array(audio, dtype=np.float32)
51
+
52
  except Exception:
53
  try:
54
  audio, sr = sf.read(io.BytesIO(data), dtype='float32', always_2d=False)
 
61
  if out.size < fe.sampling_rate // 10:
62
  out = np.pad(out, (0, max(0, fe.sampling_rate - out.size)), mode='constant')
63
  return out
64
+
65
  arr = np.array(audio, dtype=np.float32)
66
  if sr and sr != fe.sampling_rate:
67
  arr = librosa.resample(arr, orig_sr=sr, target_sr=fe.sampling_rate)
68
  if arr.size < fe.sampling_rate // 10:
69
  arr = np.pad(arr, (0, max(0, fe.sampling_rate - arr.size)), mode='constant')
70
  return arr
71
+
72
  except Exception:
73
  try:
74
  with tempfile.NamedTemporaryFile(delete=True, suffix='.audio') as tmp:
 
79
  except Exception:
80
  return np.zeros(fe.sampling_rate, dtype=np.float32)
81
 
82
+
83
  @app.post('/predict')
84
  async def predict(file: UploadFile = File(...)):
85
  try:
86
  data = await file.read()
87
  audio = to_wav16k_mono(data)
88
+
89
  inputs = fe(audio, sampling_rate=fe.sampling_rate, return_tensors='pt')
90
  inputs = {k: v.to(device) for k, v in inputs.items()}
91
+
92
  with torch.no_grad():
93
  logits = model(**inputs).logits
94
+
95
  probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()
96
+
97
  label_map = model.config.id2label
98
+ labels = [label_map.get(str(i), f"class_{i}") for i in range(len(probs))]
99
+
100
+ pairs = sorted(
101
+ [(labels[i], float(probs[i])) for i in range(len(probs))],
102
+ key=lambda x: x[1],
103
+ reverse=True
104
+ )
105
+
106
+ dominant = {
107
+ 'label': pairs[0][0],
108
+ 'score': pairs[0][1]
109
+ } if pairs else {'label': '', 'score': 0.0}
110
+
111
+ return {
112
+ 'results': [{'label': l, 'score': s} for l, s in pairs],
113
+ 'dominant': dominant
114
+ }
115
+
116
  except Exception as e:
117
+ return JSONResponse(
118
+ status_code=400,
119
+ content={'error': 'failed to process audio', 'message': f"{e.__class__.__name__}: {e}"}
120
+ )
121
+
122
+
123
  @app.get('/')
124
  def root():
125
+ return {'status': 'ok'}