Luis J Camargo committed on
Commit
ec249fb
Β·
1 Parent(s): 84dac14

test pre return

Browse files
Files changed (1) hide show
  1. app.py +83 -91
app.py CHANGED
@@ -5,6 +5,8 @@ import numpy as np
5
  from transformers import WhisperProcessor, AutoConfig, AutoModel, WhisperConfig, WhisperPreTrainedModel
6
  from transformers.models.whisper.modeling_whisper import WhisperEncoder
7
  import torch.nn as nn
 
 
8
 
9
  # === CUSTOM MODEL CLASSES ===
10
  class WhisperEncoderOnlyConfig(WhisperConfig):
@@ -81,9 +83,6 @@ model.eval()
81
 
82
  print("Model loaded successfully!")
83
 
84
- import psutil
85
- import gc
86
-
87
  def get_mem_usage():
88
  process = psutil.Process(os.getpid())
89
  return process.memory_info().rss / (1024 ** 2) # In MB
@@ -91,87 +90,85 @@ def get_mem_usage():
91
  # === INFERENCE FUNCTION ===
92
  def predict_language(audio):
93
  if audio is None:
94
- yield "⚠️ No audio provided", {}, {}, {}
95
- return
96
-
97
- log_buffer = "--- [LOG] New Request ---\n"
98
- yield log_buffer, {}, {}, {}
99
 
100
- try:
101
- gc.collect()
102
- start_mem = get_mem_usage()
103
- sample_rate, audio_array = audio
104
- audio_len_sec = len(audio_array) / sample_rate
105
-
106
- log_buffer += f"RAM: {start_mem:.2f} MB | Len: {audio_len_sec:.2f}s | SR: {sample_rate}\n"
107
- yield log_buffer, {}, {}, {}
108
-
109
- # Normalization
110
- log_buffer += "Step 1: Normalizing...\n"
111
- yield log_buffer, {}, {}, {}
112
- if audio_array.dtype == np.int16:
113
- audio_array = audio_array.astype(np.float32) / 32768.0
114
- elif audio_array.dtype == np.int32:
115
- audio_array = audio_array.astype(np.float32) / 2147483648.0
116
-
117
- # Resampling
118
- if sample_rate != 16000:
119
- log_buffer += f"Step 2: Resampling {sample_rate}Hz -> 16kHz...\n"
120
- yield log_buffer, {}, {}, {}
121
- import librosa
122
- audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=16000)
123
- log_buffer += f"Mem post-resample: {get_mem_usage():.2f} MB\n"
124
- yield log_buffer, {}, {}, {}
125
-
126
- # Preprocessing
127
- log_buffer += "Step 3: Extracting features...\n"
128
- yield log_buffer, {}, {}, {}
129
- inputs = processor(
130
- audio_array,
131
- sampling_rate=16000,
132
- return_tensors="pt"
133
- )
134
- del audio_array
135
- gc.collect()
136
- log_buffer += f"Mem post-features: {get_mem_usage():.2f} MB\n"
137
- yield log_buffer, {}, {}, {}
138
-
139
- # Inference
140
- log_buffer += "Step 4: Running Model (CPU)... \n"
141
- yield log_buffer, {}, {}, {}
142
- with torch.no_grad():
143
- outputs = model(input_features=inputs.input_features)
144
-
145
- del inputs
146
- gc.collect()
147
- log_buffer += f"Mem post-inference: {get_mem_usage():.2f} MB\n"
148
- yield log_buffer, {}, {}, {}
149
-
150
- # Post-processing
151
- log_buffer += "Step 5: Formatting results...\n"
152
- yield log_buffer, {}, {}, {}
153
- fam_probs = torch.softmax(outputs["fam_logits"], dim=-1)
154
- super_probs = torch.softmax(outputs["super_logits"], dim=-1)
155
- code_probs = torch.softmax(outputs["code_logits"], dim=-1)
156
-
157
- fam_idx = outputs["fam_logits"].argmax(-1).item()
158
- super_idx = outputs["super_logits"].argmax(-1).item()
159
- code_idx = outputs["code_logits"].argmax(-1).item()
160
-
161
- fam_conf = fam_probs[0, fam_idx].item()
162
- super_conf = super_probs[0, super_idx].item()
163
- code_conf = code_probs[0, code_idx].item()
164
-
165
- log_buffer += "--- [LOG] Finished Successfully ---"
166
- yield (
167
- log_buffer,
168
- {f"{fam_idx}": fam_conf},
169
- {f"{super_idx}": super_conf},
170
- {f"{code_idx}": code_conf}
171
- )
172
- except Exception as e:
173
- log_buffer += f"\n❌ CRASH: {str(e)}"
174
- yield log_buffer, {}, {}, {}
 
 
175
 
176
  # === UI COMPONENTS ===
177
  with gr.Blocks() as demo:
@@ -196,9 +193,6 @@ with gr.Blocks() as demo:
196
  clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary")
197
  submit_btn = gr.Button("πŸš€ Classify", variant="primary")
198
 
199
- # Persistent Log Output
200
- status_logs = gr.Textbox(label="πŸ” Persistent Status Log (Visible after crash)", interactive=False, lines=10)
201
-
202
  with gr.Column(scale=1):
203
  gr.Markdown("### πŸ“Š 2. Classification Results")
204
  fam_output = gr.Label(num_top_classes=1, label="🌍 Language Family")
@@ -208,16 +202,15 @@ with gr.Blocks() as demo:
208
  submit_btn.click(
209
  fn=predict_language,
210
  inputs=audio_input,
211
- outputs=[status_logs, fam_output, super_output, code_output]
212
  )
213
 
214
  clear_btn.click(
215
- fn=lambda: ("", None, None, None, None),
216
  inputs=None,
217
- outputs=[status_logs, audio_input, fam_output, super_output, code_output]
218
  )
219
 
220
-
221
  gr.Markdown(
222
  """
223
  ---
@@ -234,7 +227,6 @@ with gr.Blocks() as demo:
234
  )
235
 
236
  if __name__ == "__main__":
237
- # Increased concurrency for CPU stability
238
  demo.launch(
239
  theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="blue"),
240
  ssr_mode=False,
 
5
  from transformers import WhisperProcessor, AutoConfig, AutoModel, WhisperConfig, WhisperPreTrainedModel
6
  from transformers.models.whisper.modeling_whisper import WhisperEncoder
7
  import torch.nn as nn
8
+ import psutil
9
+ import gc
10
 
11
  # === CUSTOM MODEL CLASSES ===
12
  class WhisperEncoderOnlyConfig(WhisperConfig):
 
83
 
84
  print("Model loaded successfully!")
85
 
 
 
 
86
def get_mem_usage():
    """Return this process's resident set size (RSS) in megabytes."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / (1024 ** 2)
 
90
  # === INFERENCE FUNCTION ===
def predict_language(audio):
    """Classify the language of a recorded audio clip.

    Args:
        audio: Gradio audio payload as a ``(sample_rate, np.ndarray)`` tuple,
            or ``None`` when no recording was provided.

    Returns:
        A 3-tuple of ``{label: confidence}`` dicts feeding the family,
        supergroup and language-code ``gr.Label`` outputs. All three are
        empty dicts when no audio is given (an empty dict renders as a
        blank Label).
    """
    if audio is None:
        print("[LOG] No audio provided")
        # gr.Label expects a {label: confidence} dict, not a bare string.
        return {}, {}, {}

    gc.collect()  # start from a clean heap so the memory log is meaningful
    start_mem = get_mem_usage()

    sample_rate, audio_array = audio
    audio_len_sec = len(audio_array) / sample_rate

    print("\n--- [LOG] New Request ---")
    print(f"[LOG] Start Memory: {start_mem:.2f} MB")
    print(f"[LOG] Audio duration: {audio_len_sec:.2f}s, SR: {sample_rate}")

    # Step 1: normalize integer PCM to float32 in [-1.0, 1.0].
    print("[LOG] Step 1: Normalizing audio...")
    if audio_array.dtype == np.int16:
        audio_array = audio_array.astype(np.float32) / 32768.0
    elif audio_array.dtype == np.int32:
        audio_array = audio_array.astype(np.float32) / 2147483648.0
    print(f"[LOG] Memory after normalization: {get_mem_usage():.2f} MB")

    # Step 2: the Whisper feature extractor expects 16 kHz input.
    # NOTE: the debug "DID RESAMPLE" print + early `return None` from the
    # "test pre return" commit are removed so the pipeline runs to completion.
    if sample_rate != 16_000:
        print(f"[LOG] Step 2: Resampling {sample_rate}Hz -> 16000Hz...")
        import librosa  # lazy import: only needed for non-16 kHz input
        audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=16_000)
        print(f"[LOG] Memory after resampling: {get_mem_usage():.2f} MB")

    # Step 3: feature extraction.
    print("[LOG] Step 3: Extracting features...")
    inputs = processor(
        audio_array,
        sampling_rate=16_000,
        do_normalize=True,
        device="cpu",
        return_tensors="pt",
    )
    # The raw waveform now lives inside `inputs`; free it eagerly to keep RSS low.
    del audio_array
    gc.collect()
    print(f"[LOG] Memory after preprocessing: {get_mem_usage():.2f} MB")

    # Step 4: forward pass on CPU, no gradients needed for inference.
    print("[LOG] Step 4: Running model inference...")
    with torch.no_grad():
        outputs = model(input_features=inputs.input_features)

    # Cleanup inputs before post-processing to minimize peak memory.
    del inputs
    gc.collect()
    print(f"[LOG] Memory after inference: {get_mem_usage():.2f} MB")

    # Step 5: convert each head's logits into a {index: confidence} dict.
    print("[LOG] Step 5: Post-processing results...")
    fam_probs = torch.softmax(outputs["fam_logits"], dim=-1)
    super_probs = torch.softmax(outputs["super_logits"], dim=-1)
    code_probs = torch.softmax(outputs["code_logits"], dim=-1)

    fam_idx = outputs["fam_logits"].argmax(-1).item()
    super_idx = outputs["super_logits"].argmax(-1).item()
    code_idx = outputs["code_logits"].argmax(-1).item()

    fam_conf = fam_probs[0, fam_idx].item()
    super_conf = super_probs[0, super_idx].item()
    code_conf = code_probs[0, code_idx].item()

    print(f"[LOG] Final Memory: {get_mem_usage():.2f} MB")
    print("--- [LOG] Request Finished ---\n")

    return (
        {f"{fam_idx}": fam_conf},
        {f"{super_idx}": super_conf},
        {f"{code_idx}": code_conf},
    )
172
 
173
  # === UI COMPONENTS ===
174
  with gr.Blocks() as demo:
 
193
  clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary")
194
  submit_btn = gr.Button("πŸš€ Classify", variant="primary")
195
 
 
 
 
196
  with gr.Column(scale=1):
197
  gr.Markdown("### πŸ“Š 2. Classification Results")
198
  fam_output = gr.Label(num_top_classes=1, label="🌍 Language Family")
 
202
  submit_btn.click(
203
  fn=predict_language,
204
  inputs=audio_input,
205
+ outputs=[fam_output, super_output, code_output]
206
  )
207
 
208
  clear_btn.click(
209
+ fn=lambda: (None, None, None, None),
210
  inputs=None,
211
+ outputs=[audio_input, fam_output, super_output, code_output]
212
  )
213
 
 
214
  gr.Markdown(
215
  """
216
  ---
 
227
  )
228
 
229
  if __name__ == "__main__":
 
230
  demo.launch(
231
  theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="blue"),
232
  ssr_mode=False,