st192011 committed on
Commit
034cefa
·
verified ·
1 Parent(s): b3a0889

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -49
app.py CHANGED
@@ -5,103 +5,96 @@ import re
5
  import random
6
  import librosa
7
  import soundfile as sf
8
- import pandas as pd
9
  from transformers import pipeline
10
  from datasets import load_dataset, Audio
11
  from gradio_client import Client
12
  from stats_data import get_indomain_breakdown, get_experimental_summary, SPEAKER_META
13
 
14
- # 1. Configuration & Indices
15
- TORGO_INDICES = {'FC01': 0, 'FC02': 302, 'FC03': 2489, 'MC02': 4411, 'MC01': 5534, 'MC03': 7689, 'MC04': 9358, 'M05': 10978, 'M02': 11565, 'M04': 12337, 'M01': 13003, 'F01': 13746, 'M03': 13982, 'F04': 14792, 'F03': 15465}
16
-
17
- HF_TOKEN = os.getenv("HF_TOKEN")
18
- PRIVATE_BACKEND_URL = "st192011/Torgo-DSR-Private"
19
-
20
- # 2. Local Whisper Baseline
21
  print("Loading Whisper Tiny...")
22
  whisper_asr = pipeline(
23
  "automatic-speech-recognition",
24
  model="openai/whisper-tiny",
25
- generate_kwargs={
26
- "language": "en",
27
- "task": "transcribe",
28
- "repetition_penalty": 3.0,
29
- "max_new_tokens": 64
30
- }
31
  )
32
 
33
- def normalize_text(text):
 
 
 
34
  if not text: return ""
35
  return re.sub(r'[^\w\s]', '', text).lower().strip()
36
 
37
- # --- Logic Functions ---
38
-
39
  def get_sample_logic(speaker_id):
40
- """Bypasses internal decoders for both Torgo and UA to avoid environment errors."""
41
  try:
42
- if speaker_id == "F02":
 
43
  dataset = load_dataset("resproj007/uaspeech_female", split="train", streaming=True)
44
  dataset = dataset.cast_column("audio", Audio(decode=False))
45
- # UA dataset is usually smaller; iterate to find variety or use F02 specifically
46
- sample = next(iter(dataset.shuffle(buffer_size=50)))
 
47
  else:
 
48
  dataset = load_dataset("abnerh/TORGO-database", split="train", streaming=True)
49
  dataset = dataset.cast_column("audio", Audio(decode=False))
50
 
51
- start_idx = TORGO_INDICES.get(speaker_id, 0)
52
- # Jump directly to speaker start + random offset within speaker range
53
- sample = next(iter(dataset.skip(start_idx + random.randint(0, 15))))
 
 
 
 
 
 
54
 
55
- # Process Ground Truth
56
- gt_text = sample.get('transcription') or sample.get('text') or sample.get('sentence') or "Unknown"
57
-
58
- # Manual Decode via Librosa to ensure stability on CPU tier
59
  audio_bytes = sample['audio']['bytes']
60
- audio_data, sample_rate = librosa.load(io.BytesIO(audio_bytes), sr=16000)
61
-
62
- temp_path = "current_sample.wav"
63
- sf.write(temp_path, audio_data, sample_rate)
64
 
65
- return temp_path, gt_text.lower().strip(), SPEAKER_META.get(speaker_id, {})
66
-
67
  except Exception as e:
68
- return None, f"Dataset Access Error: {e}", {}
69
 
 
70
  def run_whisper_step(audio_path):
71
  if not audio_path: return "No audio loaded", ""
72
  result = whisper_asr(audio_path)
73
  raw_w = result["text"]
74
- norm_w = normalize_text(raw_w)
75
  return raw_w, norm_w
76
 
77
  def run_model_step(audio_path, norm_whisper):
78
- if not audio_path or not norm_whisper: return "Load data and run Whisper first."
79
-
80
  try:
 
81
  client = Client(PRIVATE_BACKEND_URL, hf_token=HF_TOKEN)
82
- # Calls private app for Gemma 3 5K Model prediction
83
  prediction = client.predict(audio_path, norm_whisper, api_name="/predict_dsr")
84
  return prediction
85
  except Exception as e:
86
  return f"Backend Offline. Research Details: {e}"
87
 
88
- # --- UI Layout ---
89
  with gr.Blocks(theme=gr.themes.Soft(), title="Torgo DSR Lab") as demo:
90
  gr.Markdown("# ⚗️ Torgo DSR Lab")
91
- gr.Markdown("Reconstruction and Correction layer for severe dysarthric speech.")
92
 
93
  current_audio_path = gr.State("")
94
 
95
  with gr.Tab("🔬 Laboratory"):
96
  with gr.Row():
97
  with gr.Column(scale=1):
98
- gr.Markdown("### Step 1: Select Speaker")
99
- # Removed 'FC' control speakers from dropdown as requested
100
- dysarthric_speakers = ["F01", "F03", "F04", "M01", "M02", "M03", "M04", "M05", "F02"]
101
- speaker_input = gr.Dropdown(sorted(dysarthric_speakers), label="Speaker ID", value="F01")
102
  load_btn = gr.Button("Load Data")
103
  meta_display = gr.JSON(label="Speaker Meta")
104
  gt_box = gr.Textbox(label="Ground Truth")
 
 
105
 
106
  with gr.Column(scale=2):
107
  gr.Markdown("### Step 2: ASR Baseline")
@@ -121,24 +114,29 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Torgo DSR Lab") as demo:
121
  with gr.Column():
122
  gr.Markdown("""
123
  ### 📏 Metric: Exact Match Accuracy
124
- Accuracy is calculated as the percentage of samples where the **normalized prediction** (lowercase, no punctuation) exactly matches the **ground truth**.
125
  """)
126
 
127
  with gr.Column():
128
  gr.Markdown("""
129
  ### 🧪 Model Definitions
130
  * **5K Pure Model:** Trained on real articulatory distortions. Optimized for phonetic fidelity.
131
- * **10K Triple-Mix Model:** Includes anchors and synthetic data. Used to test **generalization (LOSO)** on unseen speakers.
132
  """)
133
 
134
- gr.Markdown("## 1. Torgo In-Domain Breakdown (By Speaker)")
135
  gr.DataFrame(get_indomain_breakdown())
136
 
137
  gr.Markdown("## 2. Experimental Summary")
138
  gr.DataFrame(get_experimental_summary())
139
 
140
  # Event Mapping
141
- load_btn.click(get_sample_logic, inputs=speaker_input, outputs=[current_audio_path, gt_box, meta_display])
 
 
 
 
 
142
  whisper_btn.click(run_whisper_step, inputs=current_audio_path, outputs=[w_raw, w_norm])
143
  model_btn.click(run_model_step, inputs=[current_audio_path, w_norm], outputs=final_out)
144
 
 
5
  import random
6
  import librosa
7
  import soundfile as sf
 
8
  from transformers import pipeline
9
  from datasets import load_dataset, Audio
10
  from gradio_client import Client
11
  from stats_data import get_indomain_breakdown, get_experimental_summary, SPEAKER_META
12
 
13
# 1. Initialize Baseline ASR (Strict English, Repetition Penalty 3.0)
# whisper-tiny serves as the uncorrected ASR baseline that the DSR model
# later corrects; loaded eagerly at module import time.
print("Loading Whisper Tiny...")
whisper_asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",
    # Pin English transcription (no auto language detection / translation)
    # and penalise repeated tokens; presumably this curbs Whisper's
    # repetition loops on distorted dysarthric audio — TODO confirm.
    generate_kwargs={"language": "en", "task": "transcribe", "repetition_penalty": 3.0}
)
20
 
21
# Credentials/endpoint for the private inference Space that hosts the DSR model.
# HF_TOKEN may be None if the secret is not configured; Client() below will
# then fail and run_model_step reports "Backend Offline".
HF_TOKEN = os.getenv("HF_TOKEN")
PRIVATE_BACKEND_URL = "st192011/Torgo-DSR-Private"
23
+
24
def normalize(text):
    """Strip punctuation (any non word/space character), lowercase, and trim.

    Returns "" for falsy input (None or empty string).
    """
    if not text:
        return ""
    cleaned = re.sub(r'[^\w\s]', '', text)
    return cleaned.lower().strip()
27
 
28
# --- Logic: Data Loading ---
def get_sample_logic(speaker_id):
    """Stream one audio sample for *speaker_id* and decode it to a local WAV.

    Returns a 3-tuple ``(wav_path, ground_truth_text, speaker_metadata)``;
    on any failure returns ``(None, error_message, {})`` so the UI can show
    the error in the ground-truth textbox instead of crashing.
    """
    try:
        if "UA" in speaker_id:
            # UA-Speech Access (Direct pull for F02)
            dataset = load_dataset("resproj007/uaspeech_female", split="train", streaming=True)
            # decode=False: keep raw bytes, we decode manually below.
            dataset = dataset.cast_column("audio", Audio(decode=False))
            # UA is small, skip slightly for variety
            sample = next(iter(dataset.skip(random.randint(0, 30))))
            gt_text = sample.get('text') or sample.get('transcription') or sample.get('sentence')
        else:
            # Torgo Access (Manual filtering as per Colab fix)
            dataset = load_dataset("abnerh/TORGO-database", split="train", streaming=True)
            dataset = dataset.cast_column("audio", Audio(decode=False))

            # Hoist the case-normalised target out of the per-row filter.
            target = speaker_id.upper()

            def filter_spk(x):
                # Prefer the explicit speaker_id column; fall back to the
                # filename prefix (e.g. "F01_....wav") when it is missing.
                sid = str(x.get('speaker_id', '')).upper()
                if not sid or sid == "NONE":
                    sid = os.path.basename(x['audio']['path']).split('_')[0].upper()
                return sid == target

            speaker_ds = dataset.filter(filter_spk)
            sample = next(iter(speaker_ds.shuffle(buffer_size=10)))
            gt_text = sample.get('transcription') or sample.get('text')

        # Decode Bytes manually to bypass torchcodec errors
        audio_bytes = sample['audio']['bytes']
        audio_data, sr = librosa.load(io.BytesIO(audio_bytes), sr=16000)
        temp_path = "sample.wav"
        sf.write(temp_path, audio_data, sr)

        # FIX: gt_text may be None when no transcription field exists —
        # calling .lower() on None raised AttributeError and surfaced as a
        # misleading "Dataset Error". Fall back to "unknown" instead.
        gt_text = (gt_text or "unknown").lower().strip()
        # FIX: use .get() so an unseen speaker_id yields {} rather than KeyError.
        return temp_path, gt_text, SPEAKER_META.get(speaker_id, {})
    except Exception as e:
        return None, f"Dataset Error: {e}", {}
62
 
63
# --- Logic: Model Steps ---
def run_whisper_step(audio_path):
    """Run the Whisper baseline on *audio_path*.

    Returns ``(raw_transcript, normalized_transcript)``; when no audio is
    loaded yet, returns a placeholder message and an empty string.
    """
    if not audio_path:
        return "No audio loaded", ""
    transcription = whisper_asr(audio_path)["text"]
    return transcription, normalize(transcription)
70
 
71
def run_model_step(audio_path, norm_whisper):
    """Send the audio and baseline transcript to the private DSR backend.

    Returns the backend's corrected prediction string, or a status message
    when prerequisites are missing or the backend is unreachable.
    """
    if not (audio_path and norm_whisper):
        return "Complete Steps 1 & 2 first."
    try:
        # Call the private space for the 5K Gemma Model prediction
        backend = Client(PRIVATE_BACKEND_URL, hf_token=HF_TOKEN)
        return backend.predict(audio_path, norm_whisper, api_name="/predict_dsr")
    except Exception as e:
        return f"Backend Offline. Research Details: {e}"
80
 
81
+ # --- UI Construction ---
82
  with gr.Blocks(theme=gr.themes.Soft(), title="Torgo DSR Lab") as demo:
83
  gr.Markdown("# ⚗️ Torgo DSR Lab")
84
+ gr.Markdown("Neural Reconstruction for Severe Dysarthria benchmarked on Torgo and UA-Speech.")
85
 
86
  current_audio_path = gr.State("")
87
 
88
  with gr.Tab("🔬 Laboratory"):
89
  with gr.Row():
90
  with gr.Column(scale=1):
91
+ gr.Markdown("### Step 1: Load Sample")
92
+ speaker_input = gr.Dropdown(sorted(list(SPEAKER_META.keys())), label="Speaker ID", value="F01")
 
 
93
  load_btn = gr.Button("Load Data")
94
  meta_display = gr.JSON(label="Speaker Meta")
95
  gt_box = gr.Textbox(label="Ground Truth")
96
+ # Added visible audio for user verification
97
+ audio_preview = gr.Audio(label="Audio Preview", type="filepath")
98
 
99
  with gr.Column(scale=2):
100
  gr.Markdown("### Step 2: ASR Baseline")
 
114
  with gr.Column():
115
  gr.Markdown("""
116
  ### 📏 Metric: Exact Match Accuracy
117
+ Accuracy is the percentage of samples where the **normalized prediction** (lowercase, no punctuation) matches the **ground truth**.
118
  """)
119
 
120
  with gr.Column():
121
  gr.Markdown("""
122
  ### 🧪 Model Definitions
123
  * **5K Pure Model:** Trained on real articulatory distortions. Optimized for phonetic fidelity.
124
+ * **10K Triple-Mix Model:** Includes synthetic data and anchors; utilized for generalization testing.
125
  """)
126
 
127
+ gr.Markdown("## 1. Torgo In-Domain Analysis")
128
  gr.DataFrame(get_indomain_breakdown())
129
 
130
  gr.Markdown("## 2. Experimental Summary")
131
  gr.DataFrame(get_experimental_summary())
132
 
133
  # Event Mapping
134
+ load_btn.click(
135
+ get_sample_logic,
136
+ inputs=speaker_input,
137
+ outputs=[current_audio_path, gt_box, meta_display]
138
+ ).then(lambda x: x, inputs=current_audio_path, outputs=audio_preview)
139
+
140
  whisper_btn.click(run_whisper_step, inputs=current_audio_path, outputs=[w_raw, w_norm])
141
  model_btn.click(run_model_step, inputs=[current_audio_path, w_norm], outputs=final_out)
142