Luis J Camargo committed on
Commit
30e19d7
·
1 Parent(s): 1b87263

Refactor UI: separate tables and add Advanced Options accordion

Browse files
Files changed (1) hide show
  1. app.py +40 -40
app.py CHANGED
@@ -151,7 +151,7 @@ def get_mem_usage():
151
  return process.memory_info().rss / (1024 ** 2)
152
 
153
  # === INFERENCE FUNCTION ===
154
- def predict_language(audio_path, top_k=3, threshold=0.0):
155
  if not audio_path:
156
  raise gr.Error("No audio provided! Please upload or record an audio file.")
157
 
@@ -195,21 +195,14 @@ def predict_language(audio_path, top_k=3, threshold=0.0):
195
  super_probs = torch.softmax(outputs["super_logits"], dim=-1)
196
  code_probs = torch.softmax(outputs["code_logits"], dim=-1)
197
 
198
- # Extract top-k indices and probabilities
199
- top_k = int(top_k)
200
- fam_top = torch.topk(fam_probs[0], min(top_k, fam_probs.shape[-1]))
201
- super_top = torch.topk(super_probs[0], min(top_k, super_probs.shape[-1]))
202
- code_top = torch.topk(code_probs[0], min(top_k, code_probs.shape[-1]))
203
-
204
- table_data = []
205
-
206
- # Helper to format and add results to the table
207
- def add_to_table(category, top_vals, top_idx, labels_list, apply_mapping=False):
208
- # top_vals and top_idx are 1D tensors
209
- valid_rank = 1
210
  for i in range(len(top_vals)):
211
  score = top_vals[i].item()
212
- if score < threshold:
213
  continue
214
 
215
  idx = top_idx[i].item()
@@ -220,22 +213,20 @@ def predict_language(audio_path, top_k=3, threshold=0.0):
220
  else:
221
  name = raw_label
222
 
223
- table_data.append([category, valid_rank, name, f"{score:.2%}"])
224
- valid_rank += 1
225
-
226
- add_to_table("🌍 Family", fam_top.values, fam_top.indices, label_extractor.family_labels)
227
- add_to_table("🗣️ Superlanguage", super_top.values, super_top.indices, label_extractor.super_labels)
228
- add_to_table("🔤 Code", code_top.values, code_top.indices, label_extractor.code_labels, apply_mapping=True)
229
 
230
- if not table_data:
231
- df = pd.DataFrame(columns=["Category", "Rank", "Prediction", "Confidence"])
232
- else:
233
- df = pd.DataFrame(table_data, columns=["Category", "Rank", "Prediction", "Confidence"])
234
 
235
  print(f"[LOG] Final Memory: {get_mem_usage():.2f} MB")
236
  print(f"--- [LOG] Request Finished ---\n")
237
 
238
- return df
239
  except Exception as e:
240
  print(f"Error during inference: {e}")
241
  raise gr.Error(f"Processing failed: {str(e)}")
@@ -259,34 +250,43 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="blue"))
259
  type="filepath", # Changed from numpy to filepath
260
  label="Upload or Record"
261
  )
262
- with gr.Row():
263
- top_k = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Top-K Predictions")
264
- threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.0, label="Confidence Threshold")
265
-
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  with gr.Row():
267
  clear_btn = gr.Button("🗑️ Clear", variant="secondary")
268
  submit_btn = gr.Button("🚀 Classify", variant="primary")
269
 
270
  with gr.Column(scale=1):
271
  gr.Markdown("### 📊 2. Classification Results")
272
- results_table = gr.Dataframe(
273
- headers=["Category", "Rank", "Prediction", "Confidence"],
274
- datatype=["str", "number", "str", "str"],
275
- label="Predictions",
276
- interactive=False,
277
- wrap=True
278
- )
279
 
280
  submit_btn.click(
281
  fn=predict_language,
282
- inputs=[audio_input, top_k, threshold],
283
- outputs=[results_table]
284
  )
285
 
286
  clear_btn.click(
287
- fn=lambda: (None, None),
288
  inputs=None,
289
- outputs=[audio_input, results_table]
290
  )
291
 
292
  gr.Markdown(
 
151
  return process.memory_info().rss / (1024 ** 2)
152
 
153
  # === INFERENCE FUNCTION ===
154
+ def predict_language(audio_path, fam_k=1, fam_thresh=0.0, super_k=1, super_thresh=0.0, code_k=3, code_thresh=0.0):
155
  if not audio_path:
156
  raise gr.Error("No audio provided! Please upload or record an audio file.")
157
 
 
195
  super_probs = torch.softmax(outputs["super_logits"], dim=-1)
196
  code_probs = torch.softmax(outputs["code_logits"], dim=-1)
197
 
198
+ def build_df(probs_tensor, k, thresh, labels_list, apply_mapping=False):
199
+ k = int(k)
200
+ top_vals, top_idx = torch.topk(probs_tensor[0], min(k, probs_tensor.shape[-1]))
201
+
202
+ table_data = []
 
 
 
 
 
 
 
203
  for i in range(len(top_vals)):
204
  score = top_vals[i].item()
205
+ if score < thresh:
206
  continue
207
 
208
  idx = top_idx[i].item()
 
213
  else:
214
  name = raw_label
215
 
216
+ table_data.append([name, f"{score:.2%}"])
217
+
218
+ if not table_data:
219
+ return pd.DataFrame(columns=["Prediction", "Confidence"])
220
+ return pd.DataFrame(table_data, columns=["Prediction", "Confidence"])
 
221
 
222
+ df_fam = build_df(fam_probs, fam_k, fam_thresh, label_extractor.family_labels)
223
+ df_super = build_df(super_probs, super_k, super_thresh, label_extractor.super_labels)
224
+ df_code = build_df(code_probs, code_k, code_thresh, label_extractor.code_labels, apply_mapping=True)
 
225
 
226
  print(f"[LOG] Final Memory: {get_mem_usage():.2f} MB")
227
  print(f"--- [LOG] Request Finished ---\n")
228
 
229
+ return df_fam, df_super, df_code
230
  except Exception as e:
231
  print(f"Error during inference: {e}")
232
  raise gr.Error(f"Processing failed: {str(e)}")
 
250
  type="filepath", # Changed from numpy to filepath
251
  label="Upload or Record"
252
  )
253
+ with gr.Accordion("⚙️ Advanced Options", open=False):
254
+ with gr.Group():
255
+ gr.Markdown("#### Language Family")
256
+ with gr.Row():
257
+ fam_k = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Top-K")
258
+ fam_thresh = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.0, label="Threshold")
259
+ with gr.Group():
260
+ gr.Markdown("#### Superlanguage")
261
+ with gr.Row():
262
+ super_k = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Top-K")
263
+ super_thresh = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.0, label="Threshold")
264
+ with gr.Group():
265
+ gr.Markdown("#### Language Code")
266
+ with gr.Row():
267
+ code_k = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Top-K")
268
+ code_thresh = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.0, label="Threshold")
269
+
270
  with gr.Row():
271
  clear_btn = gr.Button("🗑️ Clear", variant="secondary")
272
  submit_btn = gr.Button("🚀 Classify", variant="primary")
273
 
274
  with gr.Column(scale=1):
275
  gr.Markdown("### 📊 2. Classification Results")
276
+ fam_table = gr.Dataframe(headers=["Prediction", "Confidence"], datatype=["str", "str"], label="🌍 Language Family", interactive=False, wrap=True)
277
+ super_table = gr.Dataframe(headers=["Prediction", "Confidence"], datatype=["str", "str"], label="🗣️ Superlanguage", interactive=False, wrap=True)
278
+ code_table = gr.Dataframe(headers=["Prediction", "Confidence"], datatype=["str", "str"], label="🔤 Language Code", interactive=False, wrap=True)
 
 
 
 
279
 
280
  submit_btn.click(
281
  fn=predict_language,
282
+ inputs=[audio_input, fam_k, fam_thresh, super_k, super_thresh, code_k, code_thresh],
283
+ outputs=[fam_table, super_table, code_table]
284
  )
285
 
286
  clear_btn.click(
287
+ fn=lambda: (None, None, None, None),
288
  inputs=None,
289
+ outputs=[audio_input, fam_table, super_table, code_table]
290
  )
291
 
292
  gr.Markdown(