Luis J Camargo committed on
Commit
69358b9
·
1 Parent(s): 68fabc6
Files changed (1) hide show
  1. app.py +79 -92
app.py CHANGED
@@ -1,12 +1,16 @@
1
  import os
 
2
  import gradio as gr
3
  import torch
4
  import numpy as np
 
5
  from transformers import WhisperProcessor, AutoConfig, AutoModel, WhisperConfig, WhisperPreTrainedModel
6
  from transformers.models.whisper.modeling_whisper import WhisperEncoder
7
  import torch.nn as nn
8
  import psutil
9
- import gc
 
 
10
 
11
  torch.set_num_threads(1)
12
 
@@ -25,22 +29,17 @@ class WhisperEncoderOnlyForClassification(WhisperPreTrainedModel):
25
 
26
  def __init__(self, config):
27
  super().__init__(config)
28
-
29
  self.encoder = WhisperEncoder(config)
30
-
31
  hidden = config.d_model
32
  self.fam_head = nn.Linear(hidden, config.n_fam)
33
  self.super_head = nn.Linear(hidden, config.n_super)
34
  self.code_head = nn.Linear(hidden, config.n_code)
35
-
36
  self.post_init()
37
 
38
  def get_input_embeddings(self):
39
- """Whisper doesn't have token embeddings"""
40
  return None
41
 
42
  def set_input_embeddings(self, value):
43
- """Ignore"""
44
  pass
45
 
46
  def enable_input_require_grads(self):
@@ -80,101 +79,93 @@ MODEL_REPO = "tachiwin/language_classification_enconly_model_2"
80
 
81
  print("Loading model on CPU...")
82
  processor = WhisperProcessor.from_pretrained(MODEL_REPO)
83
- model = WhisperEncoderOnlyForClassification.from_pretrained(MODEL_REPO, low_cpu_mem_usage=True)
 
 
 
84
  model.eval()
85
 
86
  print("Model loaded successfully!")
87
 
88
  def get_mem_usage():
89
  process = psutil.Process(os.getpid())
90
- return process.memory_info().rss / (1024 ** 2) # In MB
91
 
92
  # === INFERENCE FUNCTION ===
93
- def predict_language(audio):
94
- if audio is None:
95
- return "⚠️ No audio provided", "", ""
96
 
97
- gc.collect() # Start clean
98
  start_mem = get_mem_usage()
99
-
100
- sample_rate, audio_array = audio
101
- audio_len_sec = len(audio_array) / sample_rate
102
-
103
  print(f"\n--- [LOG] New Request ---")
104
  print(f"[LOG] Start Memory: {start_mem:.2f} MB")
105
- print(f"[LOG] Audio duration: {audio_len_sec:.2f}s, SR: {sample_rate}")
106
-
107
- # Normalization
108
- print("[LOG] Step 1: Normalizing audio...")
109
- if audio_array.dtype == np.int16:
110
- print("was npint16")
111
- audio_array = audio_array.astype(np.float32) / 32768.0
112
- elif audio_array.dtype == np.int32:
113
- print("was npint32")
114
- audio_array = audio_array.astype(np.float32) / 2147483648.0
115
- print(f"[LOG] Memory after normalization: {get_mem_usage():.2f} MB")
116
-
117
- # Resampling
118
- if sample_rate != 16000:
119
- print(f"[LOG] Step 2: Resampling {sample_rate}Hz -> 16000Hz...")
120
- import librosa
121
- audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=16_000)
122
- print(f"[LOG] Memory after resampling: {get_mem_usage():.2f} MB")
123
-
124
- print("[LOG] DID RESAMPLE")
125
-
126
- # Preprocessing
127
- print("[LOG] Step 3: Extracting features...")
128
- inputs = processor(
129
- audio_array,
130
- sampling_rate=16_000,
131
- do_normalize=True,
132
- device="cpu",
133
- return_tensors="pt",
134
- )
135
- print("[LOG] DID EXTRACT")
136
-
137
- # Delete raw audio array immediately as it's now in 'inputs'
138
- del audio_array
139
- gc.collect()
140
- print(f"[LOG] Memory after preprocessing: {get_mem_usage():.2f} MB")
141
-
142
- # Inference
143
- print("[LOG] Step 4: Running model inference...")
144
- with torch.no_grad():
145
- outputs = model(input_features=inputs.input_features)
146
-
147
- # Cleanup inputs
148
- del inputs
149
- gc.collect()
150
- print(f"[LOG] Memory after inference: {get_mem_usage():.2f} MB")
151
-
152
- # Post-processing
153
- print("[LOG] Step 5: Post-processing results...")
154
- fam_probs = torch.softmax(outputs["fam_logits"], dim=-1)
155
- super_probs = torch.softmax(outputs["super_logits"], dim=-1)
156
- code_probs = torch.softmax(outputs["code_logits"], dim=-1)
157
 
158
- fam_idx = outputs["fam_logits"].argmax(-1).item()
159
- super_idx = outputs["super_logits"].argmax(-1).item()
160
- code_idx = outputs["code_logits"].argmax(-1).item()
161
-
162
- fam_conf = fam_probs[0, fam_idx].item()
163
- super_conf = super_probs[0, super_idx].item()
164
- code_conf = code_probs[0, code_idx].item()
165
-
166
- print(f"[LOG] Final Memory: {get_mem_usage():.2f} MB")
167
- print(f"--- [LOG] Request Finished ---\n")
168
-
169
- # Formatting results
170
- return (
171
- {f"{fam_idx}": fam_conf},
172
- {f"{super_idx}": super_conf},
173
- {f"{code_idx}": code_conf}
174
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  # === UI COMPONENTS ===
177
- with gr.Blocks() as demo:
178
  gr.HTML(
179
  """
180
  <div style="text-align: center; padding: 30px; background: linear-gradient(135deg, #4f46e5 0%, #3b82f6 100%); color: white; border-radius: 15px; margin-bottom: 25px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);">
@@ -189,7 +180,7 @@ with gr.Blocks() as demo:
189
  gr.Markdown("### 🎙️ 1. Input Audio")
190
  audio_input = gr.Audio(
191
  sources=["upload", "microphone"],
192
- type="numpy",
193
  label="Upload or Record"
194
  )
195
  with gr.Row():
@@ -230,8 +221,4 @@ with gr.Blocks() as demo:
230
  )
231
 
232
  if __name__ == "__main__":
233
- demo.launch(
234
- theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="blue"),
235
- ssr_mode=False,
236
- show_error=True
237
- )
 
1
  import os
2
+ import gc
3
  import gradio as gr
4
  import torch
5
  import numpy as np
6
+ import librosa
7
  from transformers import WhisperProcessor, AutoConfig, AutoModel, WhisperConfig, WhisperPreTrainedModel
8
  from transformers.models.whisper.modeling_whisper import WhisperEncoder
9
  import torch.nn as nn
10
  import psutil
11
+
12
+ # --- CONFIGURATION ---
13
+ MAX_AUDIO_SECONDS = 30
14
 
15
  torch.set_num_threads(1)
16
 
 
29
 
30
  def __init__(self, config):
31
  super().__init__(config)
 
32
  self.encoder = WhisperEncoder(config)
 
33
  hidden = config.d_model
34
  self.fam_head = nn.Linear(hidden, config.n_fam)
35
  self.super_head = nn.Linear(hidden, config.n_super)
36
  self.code_head = nn.Linear(hidden, config.n_code)
 
37
  self.post_init()
38
 
39
  def get_input_embeddings(self):
 
40
  return None
41
 
42
  def set_input_embeddings(self, value):
 
43
  pass
44
 
45
  def enable_input_require_grads(self):
 
79
 
80
  print("Loading model on CPU...")
81
  processor = WhisperProcessor.from_pretrained(MODEL_REPO)
82
+ model = WhisperEncoderOnlyForClassification.from_pretrained(
83
+ MODEL_REPO,
84
+ low_cpu_mem_usage=True
85
+ )
86
  model.eval()
87
 
88
  print("Model loaded successfully!")
89
 
90
  def get_mem_usage():
91
  process = psutil.Process(os.getpid())
92
+ return process.memory_info().rss / (1024 ** 2)
93
 
94
  # === INFERENCE FUNCTION ===
95
def _top1_prediction(logits):
    """Return ``{str(class_index): probability}`` for the top-scoring class.

    The class index is the argmax of the raw logits and the confidence is the
    softmax probability at that index. ``.item()`` on the argmax means the
    tensor must hold exactly one row (a single request).
    """
    probs = torch.softmax(logits, dim=-1)
    idx = logits.argmax(-1).item()
    return {f"{idx}": probs[0, idx].item()}


def predict_language(audio_path):
    """Classify the language of one audio file with the encoder-only model.

    Args:
        audio_path: Path to the uploaded/recorded audio file, as delivered by
            the ``gr.Audio(type="filepath")`` input. Falsy when nothing was
            provided.

    Returns:
        A 3-tuple of single-entry dicts ``{str(index): confidence}`` for the
        family, super and code heads, in that order.

    Raises:
        gr.Error: If no audio was provided, if the clip is longer than
            ``MAX_AUDIO_SECONDS``, or if loading / feature extraction /
            inference fails (message prefixed with "Processing failed:").
    """
    if not audio_path:
        raise gr.Error("No audio provided! Please upload or record an audio file.")

    gc.collect()
    start_mem = get_mem_usage()
    print("\n--- [LOG] New Request ---")
    print(f"[LOG] Start Memory: {start_mem:.2f} MB")

    try:
        # librosa decodes the file, resamples to 16 kHz and returns float32 samples.
        # NOTE(review): the whole file is decoded before the length check below;
        # passing `duration=` to librosa.load would bound memory earlier.
        print("[LOG] Step 1: Loading and resampling audio from file...")
        audio_array, _ = librosa.load(audio_path, sr=16000)

        audio_len_sec = len(audio_array) / 16000
        print(f"[LOG] Audio duration: {audio_len_sec:.2f}s, SR: 16000")
        print(f"[LOG] Memory after load: {get_mem_usage():.2f} MB")

        # Enforce length limit to prevent OOM
        if audio_len_sec > MAX_AUDIO_SECONDS:
            del audio_array
            gc.collect()
            raise gr.Error(f"Audio too long ({audio_len_sec:.1f}s). Please upload or record up to {MAX_AUDIO_SECONDS} seconds.")

        # Preprocessing
        print("[LOG] Step 3: Extracting features...")
        inputs = processor(
            audio_array,
            sampling_rate=16000,
            return_tensors="pt",
        )

        # Free up the raw audio array; the features now live in `inputs`.
        del audio_array
        gc.collect()
        print(f"[LOG] Memory after preprocessing: {get_mem_usage():.2f} MB")

        # Inference
        print("[LOG] Step 4: Running model inference...")
        with torch.no_grad():
            outputs = model(input_features=inputs.input_features)

        # Free up inputs
        del inputs
        gc.collect()

        # Post-processing: one top-1 {index: confidence} dict per head.
        print("[LOG] Step 5: Post-processing results...")
        fam_result = _top1_prediction(outputs["fam_logits"])
        super_result = _top1_prediction(outputs["super_logits"])
        code_result = _top1_prediction(outputs["code_logits"])

        print(f"[LOG] Final Memory: {get_mem_usage():.2f} MB")
        print("--- [LOG] Request Finished ---\n")

        return (fam_result, super_result, code_result)
    except gr.Error:
        # Deliberate user-facing errors (e.g. "Audio too long") must reach the
        # UI unchanged instead of being re-wrapped as "Processing failed: ...".
        raise
    except Exception as e:
        print(f"Error during inference: {e}")
        raise gr.Error(f"Processing failed: {str(e)}") from e
166
 
167
  # === UI COMPONENTS ===
168
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="blue")) as demo:
169
  gr.HTML(
170
  """
171
  <div style="text-align: center; padding: 30px; background: linear-gradient(135deg, #4f46e5 0%, #3b82f6 100%); color: white; border-radius: 15px; margin-bottom: 25px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);">
 
180
  gr.Markdown("### 🎙️ 1. Input Audio")
181
  audio_input = gr.Audio(
182
  sources=["upload", "microphone"],
183
+ type="filepath", # Changed from numpy to filepath
184
  label="Upload or Record"
185
  )
186
  with gr.Row():
 
221
  )
222
 
223
  if __name__ == "__main__":
224
+ demo.launch(ssr_mode=False)