st192011 committed on
Commit
0353a67
·
verified ·
1 Parent(s): 1c0c2a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -73
app.py CHANGED
@@ -2,118 +2,110 @@ import gradio as gr
2
  import os
3
  import random
4
  import soundfile as sf
 
5
  from transformers import pipeline
6
  from datasets import load_dataset
7
  from gradio_client import Client
8
  from stats_data import get_indomain_breakdown, get_experimental_summary, SPEAKER_META
9
 
10
- # 1. Setup Local Whisper (Baseline)
11
- # Running locally ensures the user gets an immediate baseline result
12
  whisper_asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
13
 
14
- # 2. Setup Private Backend Connection
15
  HF_TOKEN = os.getenv("HF_TOKEN")
16
- # Change this to your actual private space URL when ready
17
- PRIVATE_BACKEND_URL = "st192011/Torgo-DSR-Private"
18
 
19
- def get_sample_from_dataset(speaker_id):
20
- """Streams a sample from Hugging Face datasets."""
 
 
 
 
21
  try:
22
  if "UA" in speaker_id:
23
- return None, "UA-Speech samples are currently static in this lab. Please use the record function for custom UA testing.", SPEAKER_META[speaker_id]
24
-
25
- # Stream Torgo test set
26
- ds = load_dataset("unsw-cse/torgo", split="test", streaming=True)
27
- speaker_ds = ds.filter(lambda x: x["speaker_id"] == speaker_id)
 
 
 
 
 
 
28
 
29
- # Get a random sample from the first few available
30
- sample = next(iter(speaker_ds.shuffle(buffer_size=5)))
 
31
 
32
- audio_path = "temp_sample.wav"
33
  sf.write(audio_path, sample["audio"]["array"], sample["audio"]["sampling_rate"])
34
 
35
  return audio_path, sample["text"], SPEAKER_META[speaker_id]
36
  except Exception as e:
37
- return None, f"Error streaming dataset: {e}", None
38
 
39
- def run_lab_comparison(audio):
40
- if audio is None:
41
- return "Please provide audio.", "", ""
42
 
43
  # A. Local Whisper Inference
44
- w_raw = whisper_asr(audio)["text"]
45
- w_norm = w_raw.lower().strip().replace(".", "").replace("?", "")
46
 
47
- # B. Call Private Backend
48
- # This keeps your specific stacking, Allosaurus, and Gemma logic secret
49
  try:
50
  client = Client(PRIVATE_BACKEND_URL, hf_token=HF_TOKEN)
51
- # We expect the private app to return two strings: result_5k and result_10k
52
- res_5k, res_10k = client.predict(audio, w_norm, api_name="/predict_dsr_dual")
53
  except Exception as e:
54
- res_5k = "Backend Connection Required"
55
- res_10k = f"Error: {e}"
56
 
57
  return w_raw, res_5k, res_10k
58
 
59
- # UI Construction
60
- with gr.Blocks(theme=gr.themes.Soft(), title="Torgo DSR Lab") as demo:
61
  gr.Markdown("# βš—οΈ Torgo DSR Lab")
62
- gr.Markdown("### Neural Reconstruction and Correction for Severe Dysarthric Speech")
63
 
64
- with gr.Tab("πŸ”¬ Interactive Lab"):
65
- gr.Markdown("Select a speaker from the Torgo or UA-Speech datasets to compare standard ASR with our reconstruction layer.")
66
-
67
  with gr.Row():
68
  with gr.Column(scale=1):
69
- speaker_drop = gr.Dropdown(list(SPEAKER_META.keys()), label="Choose Speaker Profile")
70
- load_btn = gr.Button("🎲 Load Dataset Sample")
 
71
  gr.Markdown("---")
72
- audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Input Audio (Real-time or Dataset)")
73
-
74
  with gr.Column(scale=2):
75
- with gr.Group():
76
- gr.Markdown("#### Speaker Metadata & Ground Truth")
77
- gt_display = gr.Textbox(label="Ground Truth (Human Verified)", interactive=False)
78
- meta_display = gr.JSON(label="Speaker Characteristics")
 
 
 
 
 
79
 
80
- with gr.Group():
81
- gr.Markdown("#### Reconstruction Comparison")
82
- w_out = gr.Textbox(label="Whisper Tiny Baseline (Uncorrected)")
83
- with gr.Row():
84
- out_5k = gr.Textbox(label="5K Pure Model (Acoustic Expert)")
85
- out_10k = gr.Textbox(label="10K Triple-Mix Model (Linguistic Assistant)")
86
-
87
- run_btn = gr.Button("πŸš€ Run Reconstruction Layer", variant="primary")
88
 
89
- with gr.Tab("πŸ“Š Research & Statistics"):
90
- gr.Markdown("## In-Domain Accuracy (Torgo Dataset)")
91
- gr.Markdown("This table shows the performance gain of our models across different severity levels when trained on speaker-specific data.")
 
 
 
 
 
92
  gr.DataFrame(get_indomain_breakdown())
93
 
94
- gr.Markdown("## Cross-Speaker & Cross-Domain Summary")
95
- gr.Markdown("Evaluation of the model's ability to generalize to unseen speakers (LOSO) and entirely different datasets (UA-Speech Zero-Shot).")
96
  gr.DataFrame(get_experimental_summary())
97
-
98
- gr.Markdown("""
99
- ### Key Scientific Findings
100
- * **Severity Correlation:** Standard ASR performance drops significantly as severity increases. Our models provide the highest relative gain (+100%) in the 'Severe' category.
101
- * **The Acoustic Floor:** The **5K Pure Model** (trained only on real data) provides the highest raw accuracy, proving that real-world articulatory distortions are essential for model grounding.
102
- * **Linguistic Fluency:** The **10K Triple-Mix Model** incorporates synthetic data to provide grammatically structured output, making it more suitable for assistant-based communication.
103
- * **Transfer Ability:** Our zero-shot tests on **UA-Speech (F02)** prove that the model has learned a generalized phonetic dictionary, outperforming Whisper on a completely foreign dataset.
104
- """)
105
 
106
  # Event Logic
107
- load_btn.click(
108
- get_sample_from_dataset,
109
- inputs=speaker_drop,
110
- outputs=[audio_input, gt_display, meta_display]
111
- )
112
-
113
- run_btn.click(
114
- run_lab_comparison,
115
- inputs=audio_input,
116
- outputs=[w_out, out_5k, out_10k]
117
- )
118
 
119
  demo.launch()
 
2
  import os
3
  import random
4
  import soundfile as sf
5
+ import re
6
  from transformers import pipeline
7
  from datasets import load_dataset
8
  from gradio_client import Client
9
  from stats_data import get_indomain_breakdown, get_experimental_summary, SPEAKER_META
10
 
11
+ # 1. Initialize Local Whisper (Baseline)
 
12
  whisper_asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
13
 
14
+ # 2. Setup Private Backend Connection (Hidden logic)
15
  HF_TOKEN = os.getenv("HF_TOKEN")
16
+ PRIVATE_BACKEND_URL = "st192011/Torgo-DSR-Private" # Update with your private space name
 
17
 
18
def normalize_text(text):
    """Normalize a transcript for exact-match comparison.

    Removes punctuation (anything that is not a word character or
    whitespace), lowercases, and trims surrounding whitespace.
    """
    stripped = re.sub(r'[^\w\s]', '', text)
    return stripped.lower().strip()
21
+
22
def get_sample(speaker_id):
    """Stream one random test sample for the chosen speaker profile.

    Args:
        speaker_id: A key of SPEAKER_META chosen from the UI dropdown.

    Returns:
        (audio_path, transcript_text, speaker_metadata) on success, or
        (None, error_message, None) on any failure so the UI can show it.
    """
    try:
        if "UA" in speaker_id:
            # All UA-Speech profiles map to the single hosted speaker F02.
            path = "ngdiana/uaspeech_severity_high"
            actual_spk = "F02"
        else:
            path = "unsw-cse/torgo"
            actual_spk = speaker_id

        # Stream the dataset to avoid downloading it fully into the Space.
        ds = load_dataset(path, split="test", streaming=True)
        # Keep only rows belonging to the chosen speaker.
        speaker_ds = ds.filter(lambda x: x["speaker_id"] == actual_spk)

        # Buffer a small window of matching rows and pick one at random.
        samples = list(speaker_ds.take(20))
        if not samples:
            # Explicit, user-readable message instead of letting
            # random.choice raise IndexError on an empty sequence.
            return None, f"No test samples found for speaker {actual_spk}.", None
        sample = random.choice(samples)

        # Write the audio to a local wav so gr.Audio (type="filepath") can load it.
        audio_path = "sample_audio.wav"
        sf.write(audio_path, sample["audio"]["array"], sample["audio"]["sampling_rate"])

        return audio_path, sample["text"], SPEAKER_META[speaker_id]
    except Exception as e:
        # Dataset access can fail for many reasons (network, auth, schema);
        # surface the error in the UI rather than crashing the app.
        return None, f"Error accessing dataset: {e}", None
48
 
49
def run_correction(audio_path, gt_text):
    """Compare the local Whisper baseline against the private backend models.

    Args:
        audio_path: Path to the input audio file, or None if nothing was provided.
        gt_text: Ground-truth textbox value; accepted for UI wiring but not
            used by this function.

    Returns:
        (raw_whisper_transcript, 5k_model_output, 10k_model_output).
    """
    # Guard clause: nothing to transcribe.
    if audio_path is None:
        return "No audio input", "", ""

    # A. Local Whisper inference (immediate baseline for the user).
    baseline = whisper_asr(audio_path)["text"]
    normalized = normalize_text(baseline)

    # B. Call the private backend for the 5K and 10K model predictions.
    try:
        backend = Client(PRIVATE_BACKEND_URL, hf_token=HF_TOKEN)
        # The private app receives audio + normalized whisper text and
        # returns a (5k_prediction, 10k_prediction) pair.
        pred_5k, pred_10k = backend.predict(
            audio_path, normalized, api_name="/predict_dsr_dual"
        )
    except Exception as err:
        pred_5k = "Backend Connection Required"
        pred_10k = f"Details: {err}"

    return baseline, pred_5k, pred_10k
65
 
66
# UI Layout
# Two tabs: an interactive lab for running the models, and a static
# research-statistics page built from stats_data helpers.
with gr.Blocks(theme=gr.themes.Default(), title="Torgo DSR Lab") as demo:
    gr.Markdown("# βš—οΈ Torgo DSR Lab")
    gr.Markdown("### Neural Reconstruction and ASR Correction for Torgo and UA-Speech")

    with gr.Tab("πŸ”¬ Laboratory"):
        with gr.Row():
            # Left column: speaker selection + audio input.
            with gr.Column(scale=1):
                gr.Markdown("#### 1. Dataset Explorer")
                spk_input = gr.Dropdown(list(SPEAKER_META.keys()), label="Select Speaker Profile")
                load_btn = gr.Button("🎲 Load Random Dataset Sample")
                gr.Markdown("---")
                # type="filepath" so handlers receive a wav path, matching
                # get_sample's output and run_correction's input.
                audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Input Audio")

            # Right column: ground truth, metadata, and model outputs.
            with gr.Column(scale=2):
                gr.Markdown("#### 2. Metadata & Ground Truth")
                gt_box = gr.Textbox(label="Ground Truth (Human Label)", interactive=False)
                meta_box = gr.JSON(label="Speaker Characteristics")

                gr.Markdown("#### 3. Comparison Results")
                w_out = gr.Textbox(label="Whisper Tiny Baseline (Raw Transcript)")
                with gr.Row():
                    out_5k = gr.Textbox(label="5K Pure Model (Acoustic Focus)")
                    out_10k = gr.Textbox(label="10K Triple-Mix Model (Linguistic Focus)")

        # NOTE(review): the diff rendering obscures the exact nesting here;
        # the button is placed at tab level below the row — confirm layout.
        run_btn = gr.Button("πŸš€ Run Correction Layer", variant="primary")

    with gr.Tab("πŸ“Š Research Statistics"):
        gr.Markdown("# πŸ”¬ Evaluation Metrics")
        gr.Markdown("""
        **Metric:** Exact Match Accuracy.
        Calculated by comparing the **normalized prediction** (lowercase, no punctuation) against the **normalized ground truth**.
        """)

        gr.Markdown("### 1. In-Domain Torgo Breakdown (By Speaker)")
        gr.DataFrame(get_indomain_breakdown())

        gr.Markdown("### 2. Experimental Milestone Summary")
        gr.Markdown("_Note: The 10K model was utilized to test generalization via LOSO on unseen speaker F01._")
        gr.DataFrame(get_experimental_summary())

    # Event Logic
    # Dataset loading fills the audio player, ground-truth box, and metadata.
    load_btn.click(get_sample, inputs=spk_input, outputs=[audio_input, gt_box, meta_box])
    # Correction run produces the baseline transcript and both model outputs.
    run_btn.click(run_correction, inputs=[audio_input, gt_box], outputs=[w_out, out_5k, out_10k])

demo.launch()