Luis J Camargo committed on
Commit
72cb2ee
·
1 Parent(s): a0ff692

Replace UI labels with interactive Top-K DataFrame table

Browse files
Files changed (2) hide show
  1. app.py +54 -29
  2. requirements.txt +1 -1
app.py CHANGED
@@ -4,6 +4,7 @@ import gradio as gr
4
  import torch
5
  import numpy as np
6
  import librosa
 
7
  from transformers import WhisperProcessor, AutoConfig, AutoModel, WhisperConfig, WhisperPreTrainedModel
8
  from transformers.models.whisper.modeling_whisper import WhisperEncoder
9
  import torch.nn as nn
@@ -150,7 +151,7 @@ def get_mem_usage():
150
  return process.memory_info().rss / (1024 ** 2)
151
 
152
  # === INFERENCE FUNCTION ===
153
- def predict_language(audio_path):
154
  if not audio_path:
155
  raise gr.Error("No audio provided! Please upload or record an audio file.")
156
 
@@ -194,31 +195,47 @@ def predict_language(audio_path):
194
  super_probs = torch.softmax(outputs["super_logits"], dim=-1)
195
  code_probs = torch.softmax(outputs["code_logits"], dim=-1)
196
 
197
- fam_idx = outputs["fam_logits"].argmax(-1).item()
198
- super_idx = outputs["super_logits"].argmax(-1).item()
199
- code_idx = outputs["code_logits"].argmax(-1).item()
200
-
201
- fam_conf = fam_probs[0, fam_idx].item()
202
- super_conf = super_probs[0, super_idx].item()
203
- code_conf = code_probs[0, code_idx].item()
204
-
205
- # Map indices to human-readable strings using the LabelExtractor logic
206
- # Strip the "<|" and "|>" tags if present for a cleaner UI
207
- fam_text = label_extractor.family_labels[fam_idx].strip("<@|>") if fam_idx < len(label_extractor.family_labels) else f"Unknown Fam ({fam_idx})"
208
- super_text = label_extractor.super_labels[super_idx].strip("<|>") if super_idx < len(label_extractor.super_labels) else f"Unknown Super ({super_idx})"
209
- code_raw = label_extractor.code_labels[code_idx].strip("<|>") if code_idx < len(label_extractor.code_labels) else f"Unknown Code ({code_idx})"
210
-
211
- # Apply inali_name mapping
212
- code_text = f"{CODE_TO_NAME[code_raw]} ({code_raw})" if code_raw in CODE_TO_NAME else code_raw
213
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  print(f"[LOG] Final Memory: {get_mem_usage():.2f} MB")
215
  print(f"--- [LOG] Request Finished ---\n")
216
 
217
- return (
218
- {fam_text: fam_conf},
219
- {super_text: super_conf},
220
- {code_text: code_conf}
221
- )
222
  except Exception as e:
223
  print(f"Error during inference: {e}")
224
  raise gr.Error(f"Processing failed: {str(e)}")
@@ -242,26 +259,34 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="blue"))
242
  type="filepath", # Changed from numpy to filepath
243
  label="Upload or Record"
244
  )
 
 
 
 
245
  with gr.Row():
246
  clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary")
247
  submit_btn = gr.Button("πŸš€ Classify", variant="primary")
248
 
249
  with gr.Column(scale=1):
250
  gr.Markdown("### πŸ“Š 2. Classification Results")
251
- fam_output = gr.Label(num_top_classes=1, label="🌍 Language Family")
252
- super_output = gr.Label(num_top_classes=1, label="πŸ—£οΈ Superlanguage")
253
- code_output = gr.Label(num_top_classes=1, label="πŸ”€ Language")
 
 
 
 
254
 
255
  submit_btn.click(
256
  fn=predict_language,
257
- inputs=audio_input,
258
- outputs=[fam_output, super_output, code_output]
259
  )
260
 
261
  clear_btn.click(
262
- fn=lambda: (None, None, None, None),
263
  inputs=None,
264
- outputs=[audio_input, fam_output, super_output, code_output]
265
  )
266
 
267
  gr.Markdown(
 
4
  import torch
5
  import numpy as np
6
  import librosa
7
+ import pandas as pd
8
  from transformers import WhisperProcessor, AutoConfig, AutoModel, WhisperConfig, WhisperPreTrainedModel
9
  from transformers.models.whisper.modeling_whisper import WhisperEncoder
10
  import torch.nn as nn
 
151
  return process.memory_info().rss / (1024 ** 2)
152
 
153
  # === INFERENCE FUNCTION ===
154
+ def predict_language(audio_path, top_k=3, threshold=0.0):
155
  if not audio_path:
156
  raise gr.Error("No audio provided! Please upload or record an audio file.")
157
 
 
195
  super_probs = torch.softmax(outputs["super_logits"], dim=-1)
196
  code_probs = torch.softmax(outputs["code_logits"], dim=-1)
197
 
198
+ # Extract top-k indices and probabilities
199
+ top_k = int(top_k)
200
+ fam_top = torch.topk(fam_probs[0], min(top_k, fam_probs.shape[-1]))
201
+ super_top = torch.topk(super_probs[0], min(top_k, super_probs.shape[-1]))
202
+ code_top = torch.topk(code_probs[0], min(top_k, code_probs.shape[-1]))
 
 
 
 
 
 
 
 
 
 
 
203
 
204
+ table_data = []
205
+
206
+ # Helper to format and add results to the table
207
+ def add_to_table(category, top_vals, top_idx, labels_list, apply_mapping=False):
208
+ # top_vals and top_idx are 1D tensors
209
+ valid_rank = 1
210
+ for i in range(len(top_vals)):
211
+ score = top_vals[i].item()
212
+ if score < threshold:
213
+ continue
214
+
215
+ idx = top_idx[i].item()
216
+ raw_label = labels_list[idx].strip("<|>") if idx < len(labels_list) else f"Unknown ({idx})"
217
+
218
+ if apply_mapping:
219
+ name = f"{CODE_TO_NAME[raw_label]} ({raw_label})" if raw_label in CODE_TO_NAME else raw_label
220
+ else:
221
+ name = raw_label
222
+
223
+ table_data.append([category, valid_rank, name, f"{score:.2%}"])
224
+ valid_rank += 1
225
+
226
+ add_to_table("🌍 Family", fam_top.values, fam_top.indices, label_extractor.family_labels)
227
+ add_to_table("πŸ—£οΈ Superlanguage", super_top.values, super_top.indices, label_extractor.super_labels)
228
+ add_to_table("πŸ”€ Code", code_top.values, code_top.indices, label_extractor.code_labels, apply_mapping=True)
229
+
230
+ if not table_data:
231
+ df = pd.DataFrame(columns=["Category", "Rank", "Prediction", "Confidence"])
232
+ else:
233
+ df = pd.DataFrame(table_data, columns=["Category", "Rank", "Prediction", "Confidence"])
234
+
235
  print(f"[LOG] Final Memory: {get_mem_usage():.2f} MB")
236
  print(f"--- [LOG] Request Finished ---\n")
237
 
238
+ return df
 
 
 
 
239
  except Exception as e:
240
  print(f"Error during inference: {e}")
241
  raise gr.Error(f"Processing failed: {str(e)}")
 
259
  type="filepath", # Changed from numpy to filepath
260
  label="Upload or Record"
261
  )
262
+ with gr.Row():
263
+ top_k = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Top-K Predictions")
264
+ threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.0, label="Confidence Threshold")
265
+
266
  with gr.Row():
267
  clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary")
268
  submit_btn = gr.Button("πŸš€ Classify", variant="primary")
269
 
270
  with gr.Column(scale=1):
271
  gr.Markdown("### πŸ“Š 2. Classification Results")
272
+ results_table = gr.Dataframe(
273
+ headers=["Category", "Rank", "Prediction", "Confidence"],
274
+ datatype=["str", "number", "str", "str"],
275
+ label="Predictions",
276
+ interactive=False,
277
+ wrap=True
278
+ )
279
 
280
  submit_btn.click(
281
  fn=predict_language,
282
+ inputs=[audio_input, top_k, threshold],
283
+ outputs=[results_table]
284
  )
285
 
286
  clear_btn.click(
287
+ fn=lambda: (None, None),
288
  inputs=None,
289
+ outputs=[audio_input, results_table]
290
  )
291
 
292
  gr.Markdown(
requirements.txt CHANGED
@@ -5,4 +5,4 @@ numpy
5
  librosa
6
  huggingface_hub
7
  safetensors
8
- psutil
 
5
  librosa
6
  huggingface_hub
7
  safetensors
8
+ psutil
+ pandas