st192011 committed
Commit 43fb18f · verified · 1 Parent(s): 77940c5

Update app.py

Files changed (1)
  1. app.py +88 -83
app.py CHANGED
@@ -10,7 +10,8 @@ from datasets import load_dataset, Audio
 from gradio_client import Client
 from stats_data import get_indomain_breakdown, get_experimental_summary, SPEAKER_META
 
-# 1. Setup Local Whisper Baseline (English, Strict Generation)
 whisper_asr = pipeline(
     "automatic-speech-recognition",
     model="openai/whisper-tiny",
@@ -20,119 +21,123 @@ whisper_asr = pipeline(
 HF_TOKEN = os.getenv("HF_TOKEN")
 PRIVATE_BACKEND_URL = "st192011/Torgo-DSR-Private"
 
-def normalize_text(text):
     if not text: return ""
     return re.sub(r'[^\w\s]', '', text).lower().strip()
 
-# --- Shared Processing Logic ---
-
-def process_audio_file(audio_path):
-    """Ensures any input audio is formatted correctly for ASR systems (16kHz Mono)."""
-    y, sr = librosa.load(audio_path, sr=16000)
-    fixed_path = "processed_audio.wav"
-    sf.write(fixed_path, y, sr)
-    return fixed_path
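Note: `process_audio_file` above exists only to standardize inputs to the 16 kHz mono format that Whisper-family models expect. A minimal standalone sketch of that resampling step (the file names here are illustrative):

```python
# Force an arbitrary audio file to 16 kHz mono WAV.
# librosa.load resamples and downmixes to mono by default;
# soundfile writes the result as a PCM WAV file.
import librosa
import soundfile as sf

y, sr = librosa.load("input_audio.wav", sr=16000)  # resample + mono downmix
sf.write("processed_audio.wav", y, sr)
print(f"Wrote {len(y) / sr:.2f}s of audio at {sr} Hz")
```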
 
-def run_whisper_logic(audio_path):
     if not audio_path: return "No audio loaded", ""
-    formatted_path = process_audio_file(audio_path)
-    result = whisper_asr(formatted_path)
     raw_w = result["text"]
-    norm_w = normalize_text(raw_w)
     return raw_w, norm_w
 
-def run_reconstruction_logic(audio_path, norm_whisper):
-    if not audio_path or not norm_whisper: return "Run Whisper step first."
     try:
         client = Client(PRIVATE_BACKEND_URL, hf_token=HF_TOKEN)
-        # Private backend handles Wav2Vec, Allosaurus, and Gemma 3 arbitration
         prediction = client.predict(audio_path, norm_whisper, api_name="/predict_dsr")
         return prediction
     except Exception as e:
-        return f"Backend Offline. Error: {e}"
-
-# --- Channel 1: Dataset Loader ---
-def get_dataset_sample(speaker_id):
-    try:
-        if speaker_id == "F02":
-            ds = load_dataset("resproj007/uaspeech_female", split="train", streaming=True)
-            ds = ds.cast_column("audio", Audio(decode=False))
-            sample = next(iter(ds.skip(random.randint(0, 50))))
-            gt_text = sample.get('text') or sample.get('transcription') or "Unknown"
-        else:
-            ds = load_dataset("abnerh/TORGO-database", split="train", streaming=True)
-            ds = ds.cast_column("audio", Audio(decode=False))
-            indices = {'M05': 10978, 'M02': 11565, 'M04': 12337, 'M01': 13003, 'F01': 13746, 'M03': 13982, 'F04': 14792, 'F03': 15465}
-            start_idx = indices.get(speaker_id, 0)
-            sample = next(iter(ds.skip(start_idx + random.randint(0, 10))))
-            gt_text = sample.get('transcription') or sample.get('text') or "Unknown"
-
-        audio_data, sr = librosa.load(io.BytesIO(sample['audio']['bytes']), sr=16000)
-        temp_path = f"sample_{speaker_id}.wav"
-        sf.write(temp_path, audio_data, sr)
-        return temp_path, gt_text.lower().strip(), SPEAKER_META.get(speaker_id, {})
-    except Exception as e:
-        return None, f"Dataset Error: {e}", {}
 
-# --- UI Layout ---
 with gr.Blocks(theme=gr.themes.Soft(), title="Torgo DSR Lab") as demo:
     gr.Markdown("# ⚗️ Torgo DSR Lab")
-    gr.Markdown("ASR Correction and Reconstruction Layer for Torgo and UA-Speech.")
 
-    # States for audio paths
-    lab_audio_state = gr.State("")
-    user_audio_state = gr.State("")
 
-    with gr.Tab("🔬 Research Samples"):
-        gr.Markdown("Select clinical samples from the Torgo or UA-Speech datasets.")
         with gr.Row():
             with gr.Column(scale=1):
                 speaker_input = gr.Dropdown(sorted(list(SPEAKER_META.keys())), label="Speaker ID", value="F01")
-                load_btn = gr.Button("Load Sample Data")
-                meta_display = gr.JSON(label="Sample Metadata")
                 gt_box = gr.Textbox(label="Ground Truth")
-
-            with gr.Column(scale=2):
-                whisper_btn_lab = gr.Button("1. Generate Whisper Baseline")
-                w_raw_lab = gr.Textbox(label="Whisper Raw")
-                w_norm_lab = gr.Textbox(label="Whisper Normalized")
-
-                gr.Markdown("---")
-                model_btn_lab = gr.Button("2. Run Neural Reconstruction", variant="primary")
-                final_out_lab = gr.Textbox(label="DSR Lab Prediction")
 
-    with gr.Tab("🎤 Personal Test"):
-        gr.Markdown("Record or upload your own audio to test the reconstruction layer.")
-        with gr.Row():
-            with gr.Column(scale=1):
-                user_audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="User Audio")
-                process_user_btn = gr.Button("Prepare Audio")
-
             with gr.Column(scale=2):
-                whisper_btn_user = gr.Button("1. Generate Whisper Baseline")
-                w_raw_user = gr.Textbox(label="Whisper Raw")
-                w_norm_user = gr.Textbox(label="Whisper Normalized")
 
                 gr.Markdown("---")
-                model_btn_user = gr.Button("2. Run Neural Reconstruction", variant="primary")
-                final_out_user = gr.Textbox(label="DSR Lab Prediction")
 
     with gr.Tab("📊 Research Statistics"):
-        gr.Markdown("# 🔬 Scientific Evaluation")
-        gr.Markdown("**Metric:** Exact Match Accuracy on normalized text (lowercase, no punctuation).")
-        gr.Markdown("## 1. Torgo In-Domain Breakdown")
         gr.DataFrame(get_indomain_breakdown())
-        gr.Markdown("## 2. Experimental Milestone Summary")
         gr.DataFrame(get_experimental_summary())
 
-    # --- Events: Research Tab ---
-    load_btn.click(get_dataset_sample, inputs=speaker_input, outputs=[lab_audio_state, gt_box, meta_display])
-    whisper_btn_lab.click(run_whisper_logic, inputs=lab_audio_state, outputs=[w_raw_lab, w_norm_lab])
-    model_btn_lab.click(run_reconstruction_logic, inputs=[lab_audio_state, w_norm_lab], outputs=final_out_lab)
 
-    # --- Events: Personal Tab ---
-    process_user_btn.click(lambda x: x, inputs=user_audio_input, outputs=user_audio_state)
-    whisper_btn_user.click(run_whisper_logic, inputs=user_audio_state, outputs=[w_raw_user, w_norm_user])
-    model_btn_user.click(run_reconstruction_logic, inputs=[user_audio_state, w_norm_user], outputs=final_out_user)
 
 demo.launch()
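The hunks above cut off the `pipeline(...)` call before its keyword arguments, so the "Strict Generation" settings referenced in the comments are not visible in the diff. For orientation only, a sketch of what a strict-English Whisper baseline of this kind typically looks like; the `generate_kwargs` values are assumptions inferred from the comments, not the commit's actual settings:

```python
# Hypothetical sketch (not the commit's exact code): Whisper Tiny pinned to
# English transcription with a repetition penalty, per the diff's comments
# ("English, Strict Generation" / "Strict English, Repetition Penalty 3.0").
from transformers import pipeline

whisper_asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",
    generate_kwargs={
        "language": "english",      # force English decoding (assumed)
        "task": "transcribe",       # transcribe rather than translate (assumed)
        "repetition_penalty": 3.0,  # per the new version's comment
    },
)
print(whisper_asr("processed_audio.wav")["text"])
```

The new version of `app.py` follows.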
 
 from gradio_client import Client
 from stats_data import get_indomain_breakdown, get_experimental_summary, SPEAKER_META
 
+# 1. Initialize Baseline ASR (Strict English, Repetition Penalty 3.0)
+print("Loading Whisper Tiny...")
 whisper_asr = pipeline(
     "automatic-speech-recognition",
     model="openai/whisper-tiny",
 
 HF_TOKEN = os.getenv("HF_TOKEN")
 PRIVATE_BACKEND_URL = "st192011/Torgo-DSR-Private"
 
+def normalize(text):
     if not text: return ""
     return re.sub(r'[^\w\s]', '', text).lower().strip()
 
+# --- Logic: Data Loading ---
+def get_sample_logic(speaker_id):
+    try:
+        if "UA" in speaker_id:
+            # UA-Speech Access (Direct pull for F02)
+            dataset = load_dataset("resproj007/uaspeech_female", split="train", streaming=True)
+            dataset = dataset.cast_column("audio", Audio(decode=False))
+            # UA is small, skip slightly for variety
+            sample = next(iter(dataset.skip(random.randint(0, 30))))
+            gt_text = sample.get('text') or sample.get('transcription') or sample.get('sentence')
+        else:
+            # Torgo Access (Manual filtering as per Colab fix)
+            dataset = load_dataset("abnerh/TORGO-database", split="train", streaming=True)
+            dataset = dataset.cast_column("audio", Audio(decode=False))
+
+            def filter_spk(x):
+                sid = str(x.get('speaker_id', '')).upper()
+                if not sid or sid == "NONE":
+                    sid = os.path.basename(x['audio']['path']).split('_')[0].upper()
+                return sid == speaker_id
+
+            speaker_ds = dataset.filter(filter_spk)
+            sample = next(iter(speaker_ds.shuffle(buffer_size=10)))
+            gt_text = sample.get('transcription') or sample.get('text')
+
+        # Decode Bytes manually to bypass torchcodec errors
+        audio_bytes = sample['audio']['bytes']
+        audio_data, sr = librosa.load(io.BytesIO(audio_bytes), sr=16000)
+        temp_path = "sample.wav"
+        sf.write(temp_path, audio_data, sr)
+
+        return temp_path, gt_text.lower().strip(), SPEAKER_META[speaker_id]
+    except Exception as e:
+        return None, f"Dataset Error: {e}", {}
 
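The loader above replaces the old hard-coded skip offsets with a streaming `.filter()` keyed on speaker ID, falling back to the audio filename when the `speaker_id` column is empty. A minimal standalone sketch of that streaming-filter pattern (the `speaker_id` and `transcription` columns are assumed from the diff):

```python
# Stream the TORGO dataset and pull one random-ish sample for speaker F01
# without downloading the full split.
from datasets import load_dataset, Audio

ds = load_dataset("abnerh/TORGO-database", split="train", streaming=True)
ds = ds.cast_column("audio", Audio(decode=False))  # keep raw bytes, no decoder needed

f01 = ds.filter(lambda x: str(x.get("speaker_id", "")).upper() == "F01")
sample = next(iter(f01.shuffle(buffer_size=10)))   # small buffer = cheap shuffling
print(sample.get("transcription") or sample.get("text"))
```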
+# --- Logic: Model Steps ---
+def run_whisper_step(audio_path):
     if not audio_path: return "No audio loaded", ""
+    result = whisper_asr(audio_path)
     raw_w = result["text"]
+    norm_w = normalize(raw_w)
     return raw_w, norm_w
 
+def run_model_step(audio_path, norm_whisper):
+    if not audio_path or not norm_whisper: return "Complete Steps 1 & 2 first."
     try:
+        # Call the private space for the 5K Gemma Model prediction
         client = Client(PRIVATE_BACKEND_URL, hf_token=HF_TOKEN)
         prediction = client.predict(audio_path, norm_whisper, api_name="/predict_dsr")
         return prediction
     except Exception as e:
+        return f"Backend Offline. Research Details: {e}"
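`run_model_step` keeps the heavy models behind a private Space and reaches them through `gradio_client`. A minimal sketch of that call pattern, reusing the repo id and `api_name` from the diff (the argument values are placeholders):

```python
# Call a (private) Gradio Space endpoint as an API.
# HF_TOKEN must belong to an account with access to the Space.
import os
from gradio_client import Client

client = Client("st192011/Torgo-DSR-Private", hf_token=os.getenv("HF_TOKEN"))
prediction = client.predict(
    "sample.wav",               # path to the prepared 16 kHz audio (placeholder)
    "normalized whisper text",  # Whisper hypothesis (placeholder)
    api_name="/predict_dsr",
)
print(prediction)
```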
 
+# --- UI Construction ---
 with gr.Blocks(theme=gr.themes.Soft(), title="Torgo DSR Lab") as demo:
     gr.Markdown("# ⚗️ Torgo DSR Lab")
+    gr.Markdown("Neural Reconstruction for Severe Dysarthria benchmarked on Torgo and UA-Speech.")
 
+    current_audio_path = gr.State("")
 
+    with gr.Tab("🔬 Laboratory"):
         with gr.Row():
             with gr.Column(scale=1):
+                gr.Markdown("### Step 1: Load Sample")
                 speaker_input = gr.Dropdown(sorted(list(SPEAKER_META.keys())), label="Speaker ID", value="F01")
+                load_btn = gr.Button("Load Data")
+                meta_display = gr.JSON(label="Speaker Meta")
                 gt_box = gr.Textbox(label="Ground Truth")
+                # Added visible audio for user verification
+                audio_preview = gr.Audio(label="Audio Preview", type="filepath")
 
             with gr.Column(scale=2):
+                gr.Markdown("### Step 2: ASR Baseline")
+                whisper_btn = gr.Button("Run Whisper Tiny")
+                w_raw = gr.Textbox(label="Whisper Raw")
+                w_norm = gr.Textbox(label="Whisper Normalized")
 
                 gr.Markdown("---")
+                gr.Markdown("### Step 3: Neural Reconstruction")
+                model_btn = gr.Button("Run Our Model", variant="primary")
+                final_out = gr.Textbox(label="DSR Lab Prediction")
 
  with gr.Tab("πŸ“Š Research Statistics"):
111
+ gr.Markdown("# πŸ”¬ Performance Evaluation")
112
+
113
+ with gr.Row():
114
+ with gr.Column():
115
+ gr.Markdown("""
116
+ ### πŸ“ Metric: Exact Match Accuracy
117
+ Accuracy is the percentage of samples where the **normalized prediction** (lowercase, no punctuation) matches the **ground truth**.
118
+ """)
119
+
120
+ with gr.Column():
121
+ gr.Markdown("""
122
+ ### πŸ§ͺ Model Definitions
123
+ * **5K Pure Model:** Trained on real articulatory distortions. Optimized for phonetic fidelity.
124
+ * **10K Triple-Mix Model:** Includes synthetic data and anchors; utilized for generalization testing.
125
+ """)
126
+
127
+ gr.Markdown("## 1. Torgo In-Domain Analysis")
128
  gr.DataFrame(get_indomain_breakdown())
129
+
130
+ gr.Markdown("## 2. Experimental Summary")
131
  gr.DataFrame(get_experimental_summary())
132
 
+    # Event Mapping
+    load_btn.click(
+        get_sample_logic,
+        inputs=speaker_input,
+        outputs=[current_audio_path, gt_box, meta_display]
+    ).then(lambda x: x, inputs=current_audio_path, outputs=audio_preview)
 
+    whisper_btn.click(run_whisper_step, inputs=current_audio_path, outputs=[w_raw, w_norm])
+    model_btn.click(run_model_step, inputs=[current_audio_path, w_norm], outputs=final_out)
 
 demo.launch()
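The stats tab defines its metric only in prose. For concreteness, a tiny self-contained sketch of that exact-match computation, reusing the diff's `normalize` regex (the sample pairs are illustrative, not evaluation data):

```python
# Exact Match Accuracy on normalized text: lowercase, strip punctuation,
# then count predictions that equal the ground truth exactly.
import re

def normalize(text):
    if not text:
        return ""
    return re.sub(r'[^\w\s]', '', text).lower().strip()

pairs = [("The cup.", "the cup"), ("a spoon", "the spoon")]  # illustrative only
hits = sum(normalize(pred) == normalize(gt) for pred, gt in pairs)
print(f"Exact match accuracy: {hits / len(pairs):.0%}")  # -> 50%
```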