st192011 committed
Commit 07ead75 · verified · 1 Parent(s): 034cefa

Update app.py

Files changed (1)
  1. app.py +55 -62
app.py CHANGED
@@ -5,57 +5,59 @@ import re
 import random
 import librosa
 import soundfile as sf
+import pandas as pd
 from transformers import pipeline
 from datasets import load_dataset, Audio
 from gradio_client import Client
 from stats_data import get_indomain_breakdown, get_experimental_summary, SPEAKER_META
 
-# 1. Initialize Baseline ASR (Strict English, Repetition Penalty 3.0)
-print("Loading Whisper Tiny...")
+# 1. Configuration
+TORGO_INDICES = {'FC01': 0, 'FC02': 302, 'FC03': 2489, 'MC02': 4411, 'MC01': 5534, 'MC03': 7689, 'MC04': 9358, 'M05': 10978, 'M02': 11565, 'M04': 12337, 'M01': 13003, 'F01': 13746, 'M03': 13982, 'F04': 14792, 'F03': 15465}
+HF_TOKEN = os.getenv("HF_TOKEN")
+PRIVATE_BACKEND_URL = "st192011/Torgo-DSR-Private"
+
+# 2. Local Whisper Baseline (Strict English, Repetition Penalty 3.0)
 whisper_asr = pipeline(
     "automatic-speech-recognition",
     model="openai/whisper-tiny",
-    generate_kwargs={"language": "en", "task": "transcribe", "repetition_penalty": 3.0}
+    generate_kwargs={"language": "en", "task": "transcribe", "repetition_penalty": 3.0, "max_new_tokens": 64}
 )
 
-HF_TOKEN = os.getenv("HF_TOKEN")
-PRIVATE_BACKEND_URL = "st192011/Torgo-DSR-Private"
-
-def normalize(text):
+def normalize_text(text):
     if not text: return ""
     return re.sub(r'[^\w\s]', '', text).lower().strip()
 
+def standardize_audio(input_path):
+    """Ensures audio is 16kHz, Mono, and compatible with all models."""
+    if not input_path: return None
+    audio, sr = librosa.load(input_path, sr=16000, mono=True)
+    out_path = "processed_audio.wav"
+    sf.write(out_path, audio, 16000)
+    return out_path
+
 # --- Logic: Data Loading ---
 def get_sample_logic(speaker_id):
     try:
-        if "UA" in speaker_id:
-            # UA-Speech Access (Direct pull for F02)
+        if speaker_id == "F02 (UA)":
             dataset = load_dataset("resproj007/uaspeech_female", split="train", streaming=True)
             dataset = dataset.cast_column("audio", Audio(decode=False))
-            # UA is small, skip slightly for variety
             sample = next(iter(dataset.skip(random.randint(0, 30))))
-            gt_text = sample.get('text') or sample.get('transcription') or sample.get('sentence')
+            gt_text = sample.get('text') or sample.get('transcription') or "Unknown"
         else:
-            # Torgo Access (Manual filtering as per Colab fix)
             dataset = load_dataset("abnerh/TORGO-database", split="train", streaming=True)
             dataset = dataset.cast_column("audio", Audio(decode=False))
-
-            def filter_spk(x):
-                sid = str(x.get('speaker_id', '')).upper()
-                if not sid or sid == "NONE":
-                    sid = os.path.basename(x['audio']['path']).split('_')[0].upper()
-                return sid == speaker_id
-
-            speaker_ds = dataset.filter(filter_spk)
-            sample = next(iter(speaker_ds.shuffle(buffer_size=10)))
-            gt_text = sample.get('transcription') or sample.get('text')
+            start_idx = TORGO_INDICES.get(speaker_id, 0)
+            sample = next(iter(dataset.skip(start_idx + random.randint(0, 15))))
+            gt_text = sample.get('transcription') or sample.get('text') or "Unknown"
 
-        # Decode Bytes manually to bypass torchcodec errors
+        # Decode Bytes manually and Standardize
         audio_bytes = sample['audio']['bytes']
-        audio_data, sr = librosa.load(io.BytesIO(audio_bytes), sr=16000)
+        audio_data, _ = librosa.load(io.BytesIO(audio_bytes), sr=16000, mono=True)
+
         temp_path = "sample.wav"
-        sf.write(temp_path, audio_data, sr)
+        sf.write(temp_path, audio_data, 16000)
 
+        # We return the path to the gr.Audio component (which stores it in State)
         return temp_path, gt_text.lower().strip(), SPEAKER_META[speaker_id]
     except Exception as e:
         return None, f"Dataset Error: {e}", {}
@@ -63,38 +65,42 @@ def get_sample_logic(speaker_id):
 # --- Logic: Model Steps ---
 def run_whisper_step(audio_path):
     if not audio_path: return "No audio loaded", ""
-    result = whisper_asr(audio_path)
+    # Standardize format before Whisper
+    clean_audio = standardize_audio(audio_path)
+    result = whisper_asr(clean_audio)
     raw_w = result["text"]
-    norm_w = normalize(raw_w)
+    norm_w = normalize_text(raw_w)
     return raw_w, norm_w
 
 def run_model_step(audio_path, norm_whisper):
-    if not audio_path or not norm_whisper: return "Complete Steps 1 & 2 first."
+    if not audio_path or not norm_whisper: return "Complete Step 1 & 2 first."
+    # Standardize format before sending to Private Backend
+    clean_audio = standardize_audio(audio_path)
     try:
-        # Call the private space for the 5K Gemma Model prediction
         client = Client(PRIVATE_BACKEND_URL, hf_token=HF_TOKEN)
-        prediction = client.predict(audio_path, norm_whisper, api_name="/predict_dsr")
+        prediction = client.predict(clean_audio, norm_whisper, api_name="/predict_dsr")
         return prediction
     except Exception as e:
-        return f"Backend Offline. Research Details: {e}"
+        return f"Backend Offline. Details: {e}"
 
 # --- UI Construction ---
 with gr.Blocks(theme=gr.themes.Soft(), title="Torgo DSR Lab") as demo:
     gr.Markdown("# ⚗️ Torgo DSR Lab")
-    gr.Markdown("Neural Reconstruction for Severe Dysarthria benchmarked on Torgo and UA-Speech.")
+    gr.Markdown("Neural Reconstruction for Severe Dysarthria. Load samples from Torgo/UA or record your own.")
 
-    current_audio_path = gr.State("")
-
     with gr.Tab("🔬 Laboratory"):
         with gr.Row():
             with gr.Column(scale=1):
-                gr.Markdown("### Step 1: Load Sample")
-                speaker_input = gr.Dropdown(sorted(list(SPEAKER_META.keys())), label="Speaker ID", value="F01")
-                load_btn = gr.Button("Load Data")
-                meta_display = gr.JSON(label="Speaker Meta")
-                gt_box = gr.Textbox(label="Ground Truth")
-                # Added visible audio for user verification
-                audio_preview = gr.Audio(label="Audio Preview", type="filepath")
+                gr.Markdown("### Step 1: Input Audio")
+                speaker_input = gr.Dropdown(sorted(list(SPEAKER_META.keys())), label="Speaker Profile (for Dataset Samples)")
+                load_btn = gr.Button("🎲 Load Dataset Sample")
+
+                gr.Markdown("---")
+                # Unified Input: Handles both Dataset Samples and User Input
+                audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio Input (Record/Upload/Dataset)")
+
+                meta_display = gr.JSON(label="Speaker Metadata")
+                gt_box = gr.Textbox(label="Ground Truth (if from dataset)")
 
             with gr.Column(scale=2):
                 gr.Markdown("### Step 2: ASR Baseline")
@@ -109,35 +115,22 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Torgo DSR Lab") as demo:
 
     with gr.Tab("📊 Research Statistics"):
         gr.Markdown("# 🔬 Performance Evaluation")
-
        with gr.Row():
            with gr.Column():
-                gr.Markdown("""
-                ### 📏 Metric: Exact Match Accuracy
-                Accuracy is the percentage of samples where the **normalized prediction** (lowercase, no punctuation) matches the **ground truth**.
-                """)
-
+                gr.Markdown("### 📏 Metric: Exact Match Accuracy")
+                gr.Markdown("Accuracy is calculated on normalized text (lowercase, no punctuation).")
            with gr.Column():
-                gr.Markdown("""
-                ### 🧪 Model Definitions
-                * **5K Pure Model:** Trained on real articulatory distortions. Optimized for phonetic fidelity.
-                * **10K Triple-Mix Model:** Includes synthetic data and anchors; utilized for generalization testing.
-                """)
+                gr.Markdown("### 🧪 Model Definitions")
+                gr.Markdown("* **5K Pure Model:** Real data focus. \n* **10K Triple-Mix Model:** LOSO Generalization focus.")
 
        gr.Markdown("## 1. Torgo In-Domain Analysis")
        gr.DataFrame(get_indomain_breakdown())
-
       gr.Markdown("## 2. Experimental Summary")
       gr.DataFrame(get_experimental_summary())
 
-    # Event Mapping
-    load_btn.click(
-        get_sample_logic,
-        inputs=speaker_input,
-        outputs=[current_audio_path, gt_box, meta_display]
-    ).then(lambda x: x, inputs=current_audio_path, outputs=audio_preview)
-
-    whisper_btn.click(run_whisper_step, inputs=current_audio_path, outputs=[w_raw, w_norm])
-    model_btn.click(run_model_step, inputs=[current_audio_path, w_norm], outputs=final_out)
+    # Logic connections
+    load_btn.click(get_sample_logic, inputs=speaker_input, outputs=[audio_input, gt_box, meta_display])
+    whisper_btn.click(run_whisper_step, inputs=audio_input, outputs=[w_raw, w_norm])
+    model_btn.click(run_model_step, inputs=[audio_input, w_norm], outputs=final_out)
 
 demo.launch()
 
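Context for the new sampling scheme above: instead of filtering the streamed TORGO split on every request (the removed `filter_spk`), the commit jumps straight to a speaker's block with `dataset.skip(TORGO_INDICES[speaker_id] + ...)`. Below is a minimal sketch of one way such start offsets could be precomputed offline, assuming the streamed split is grouped by speaker and exposes a `speaker_id` column (the removed code fell back to parsing the audio path when it did not); `first_index_per_speaker` is a hypothetical helper, not part of app.py:

```python
from datasets import Audio, load_dataset

def first_index_per_speaker(name="abnerh/TORGO-database", split="train"):
    """One offline pass over the streamed split, recording the first row index seen per speaker."""
    ds = load_dataset(name, split=split, streaming=True)
    ds = ds.cast_column("audio", Audio(decode=False))  # avoid decoding audio while scanning
    offsets = {}
    for i, row in enumerate(ds):
        sid = str(row.get("speaker_id", "")).upper()
        if sid and sid not in offsets:
            offsets[sid] = i
    return offsets

# offsets = first_index_per_speaker()  # expected shape: {'FC01': 0, 'FC02': 302, ...}
```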
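The statistics tab describes accuracy as an exact match between the normalized prediction and the normalized ground truth (lowercase, punctuation stripped). A minimal sketch of that metric, reusing the normalization rule from app.py; the `pairs` list of (prediction, reference) tuples and the `exact_match_accuracy` name are hypothetical, not something the app itself computes:

```python
import re

def normalize_text(text):
    # Same rule as app.py: drop punctuation, lowercase, trim whitespace.
    if not text:
        return ""
    return re.sub(r'[^\w\s]', '', text).lower().strip()

def exact_match_accuracy(pairs):
    """Fraction of (prediction, reference) pairs whose normalized forms are identical."""
    if not pairs:
        return 0.0
    hits = sum(normalize_text(pred) == normalize_text(ref) for pred, ref in pairs)
    return hits / len(pairs)

# Hypothetical usage:
print(exact_match_accuracy([("The cup.", "the cup"), ("a cop", "the cup")]))  # 0.5
```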