st192011 committed on
Commit
08dd52c
·
verified ·
1 Parent(s): b160197

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -63
app.py CHANGED
@@ -10,89 +10,81 @@ from datasets import load_dataset, Audio
10
  from gradio_client import Client
11
  from stats_data import get_indomain_breakdown, get_experimental_summary, SPEAKER_META
12
 
13
- # 1. Initialize Baseline ASR with Generation Constraints
14
- # Set max_new_tokens to 64 to prevent infinite "L-O-O-O" loops
15
  whisper_asr = pipeline(
16
  "automatic-speech-recognition",
17
  model="openai/whisper-tiny",
18
  generate_kwargs={
19
  "language": "en",
20
- "task": "transcribe",
21
- "max_new_tokens": 64,
22
- "repetition_penalty": 1.5 # Discourages token looping
23
  }
24
  )
25
 
 
26
  HF_TOKEN = os.getenv("HF_TOKEN")
27
- PRIVATE_BACKEND_URL = os.getenv("PRIVATE_BACKEND_URL")
28
 
29
- def normalize_text(text):
30
  if not text: return ""
31
  return re.sub(r'[^\w\s]', '', text).lower().strip()
32
 
33
  def get_sample_logic(speaker_id):
 
34
  try:
35
- # PATH A: UA-SPEECH (Strictly following your provided running block)
36
  if speaker_id == "F02 (UA)":
37
  dataset = load_dataset("resproj007/uaspeech_female", split="train", streaming=True)
38
- # Shuffle helps pick a different word each time
39
- sample = next(iter(dataset.shuffle(buffer_size=100)))
40
- gt_text = sample.get('text') or sample.get('transcription') or sample.get('sentence', 'Unknown')
41
- audio_data = sample['audio']['array']
42
- sample_rate = sample['audio']['sampling_rate']
43
-
44
- # PATH B: TORGO (Optimized for speed)
45
  else:
46
  dataset = load_dataset("abnerh/TORGO-database", split="train", streaming=True)
47
  dataset = dataset.cast_column("audio", Audio(decode=False))
48
 
49
- # Speed Hack: Shuffle the stream buffer to find the speaker faster
50
- # This avoids starting from speaker MC01 every time
51
- shuffled_ds = dataset.shuffle(buffer_size=1000)
52
-
53
- # Find first match in shuffled stream
54
- found_sample = None
55
- for item in shuffled_ds:
56
- sid = str(item.get('speaker_id', '')).upper()
57
  if not sid or sid == "NONE":
58
- sid = os.path.basename(item['audio']['path']).split('_')[0].upper()
59
-
60
- if sid == speaker_id:
61
- found_sample = item
62
- break
63
 
64
- if not found_sample:
65
- return None, "Speaker search timeout. Try again.", {}
66
 
67
- gt_text = found_sample.get('transcription') or found_sample.get('text', 'Unknown')
68
- audio_bytes = found_sample['audio']['bytes']
69
- audio_data, sample_rate = librosa.load(io.BytesIO(audio_bytes), sr=16000)
70
-
71
- temp_path = "current_sample.wav"
72
- sf.write(temp_path, audio_data, sample_rate)
 
 
 
73
  return temp_path, gt_text.lower().strip(), SPEAKER_META[speaker_id]
74
-
75
  except Exception as e:
76
- return None, f"Dataset Error: {e}", {}
77
-
78
- def run_whisper_step(audio_path):
79
- if not audio_path: return "No audio loaded", ""
80
- result = whisper_asr(audio_path)
81
- raw_w = result["text"]
82
- norm_w = normalize_text(raw_w)
83
- return raw_w, norm_w
84
-
85
- def run_model_step(audio_path, norm_whisper):
86
- if not audio_path or not norm_whisper: return "Incomplete steps"
87
  try:
88
  client = Client(PRIVATE_BACKEND_URL, hf_token=HF_TOKEN)
89
- # Calling Private App which uses repetition_penalty=3.0
90
- prediction = client.predict(audio_path, norm_whisper, api_name="/predict_dsr")
91
- return prediction
92
  except Exception as e:
93
- return f"Backend Offline. Research Model requires Private Space access."
 
 
94
 
95
- # UI
96
  with gr.Blocks(theme=gr.themes.Soft(), title="Torgo DSR Lab") as demo:
97
  gr.Markdown("# ⚗️ Torgo DSR Lab")
98
  current_audio_path = gr.State("")
@@ -100,11 +92,12 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Torgo DSR Lab") as demo:
100
  with gr.Tab("🔬 Laboratory"):
101
  with gr.Row():
102
  with gr.Column(scale=1):
103
- gr.Markdown("### Step 1: Load Sample")
104
- speaker_input = gr.Dropdown(sorted(list(SPEAKER_META.keys())), label="Speaker ID", value="F01")
105
  load_btn = gr.Button("Load Data")
106
- meta_display = gr.JSON(label="Speaker Meta")
107
  gt_box = gr.Textbox(label="Ground Truth")
 
108
 
109
  with gr.Column(scale=2):
110
  gr.Markdown("### Step 2: ASR Baseline")
@@ -129,17 +122,20 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Torgo DSR Lab") as demo:
129
  gr.Markdown("""
130
  ### 🧪 Model Definitions
131
  * **5K Pure Model:** Trained on 5,000 real Torgo samples. Optimized for articulatory fidelity.
132
- * **10K Triple-Mix Model:** Includes synthetic data and anchors. Tested on **unseen speakers (LOSO)** to prove generalization.
133
  """)
134
-
135
- gr.Markdown("---")
136
- gr.Markdown("## 1. Torgo In-Domain Analysis")
137
  gr.DataFrame(get_indomain_breakdown())
 
138
  gr.Markdown("## 2. Experimental Summary")
139
  gr.DataFrame(get_experimental_summary())
140
 
141
- load_btn.click(get_sample_logic, inputs=speaker_input, outputs=[current_audio_path, gt_box, meta_display])
142
- whisper_btn.click(run_whisper_step, inputs=current_audio_path, outputs=[w_raw, w_norm])
143
- model_btn.click(run_model_step, inputs=[current_audio_path, w_norm], outputs=final_out)
 
 
 
144
 
145
  demo.launch()
 
10
  from gradio_client import Client
11
  from stats_data import get_indomain_breakdown, get_experimental_summary, SPEAKER_META
12
 
13
+ # 1. Setup Local Whisper (Forced to English, High Repetition Penalty)
14
+ print("Initializing ASR Baseline...")
15
  whisper_asr = pipeline(
16
  "automatic-speech-recognition",
17
  model="openai/whisper-tiny",
18
  generate_kwargs={
19
  "language": "en",
20
+ "task": "transcribe",
21
+ "repetition_penalty": 3.0,
22
+ "max_new_tokens": 64
23
  }
24
  )
25
 
26
+ # 2. Private Backend Config
27
  HF_TOKEN = os.getenv("HF_TOKEN")
28
+ PRIVATE_BACKEND_URL = "st192011/Torgo-DSR-Private"
29
 
30
+ def normalize(text):
31
  if not text: return ""
32
  return re.sub(r'[^\w\s]', '', text).lower().strip()
33
 
34
  def get_sample_logic(speaker_id):
35
+ """Optimized data loader: Skips normal control speakers to find targets faster."""
36
  try:
 
37
  if speaker_id == "F02 (UA)":
38
  dataset = load_dataset("resproj007/uaspeech_female", split="train", streaming=True)
39
+ dataset = dataset.cast_column("audio", Audio(decode=False))
40
+ # F02 is the primary dysarthric speaker in this split
41
+ speaker_ds = dataset.filter(lambda x: x["speaker_id"] == "F02")
 
 
 
 
42
  else:
43
  dataset = load_dataset("abnerh/TORGO-database", split="train", streaming=True)
44
  dataset = dataset.cast_column("audio", Audio(decode=False))
45
 
46
+ # Skip logic: ignore samples with 'control' status to speed up stream
47
+ def is_target_dysarthric(x):
48
+ sid = str(x.get('speaker_id', '')).upper()
 
 
 
 
 
49
  if not sid or sid == "NONE":
50
+ sid = os.path.basename(x['audio']['path']).split('_')[0].upper()
51
+ status = str(x.get('speech_status', '')).lower()
52
+ return sid == speaker_id and "control" not in status
 
 
53
 
54
+ speaker_ds = dataset.filter(is_target_dysarthric)
 
55
 
56
+ # Get sample and decode
57
+ sample = next(iter(speaker_ds.shuffle(buffer_size=10)))
58
+ gt_text = sample.get('transcription') or sample.get('text') or sample.get('sentence') or "Unknown"
59
+
60
+ audio_bytes = sample['audio']['bytes']
61
+ audio_data, sr = librosa.load(io.BytesIO(audio_bytes), sr=16000)
62
+
63
+ temp_path = "sample.wav"
64
+ sf.write(temp_path, audio_data, sr)
65
  return temp_path, gt_text.lower().strip(), SPEAKER_META[speaker_id]
 
66
  except Exception as e:
67
+ return None, f"Loading error: {e}", {}
68
+
69
+ def run_lab(audio_path):
70
+ if not audio_path: return "", "", "Error: No Audio"
71
+
72
+ # Baseline
73
+ w_res = whisper_asr(audio_path)
74
+ w_raw = w_res["text"]
75
+ w_norm = normalize(w_raw)
76
+
77
+ # Private Model Call
78
  try:
79
  client = Client(PRIVATE_BACKEND_URL, hf_token=HF_TOKEN)
80
+ # Assuming private backend returns the 5K prediction string
81
+ prediction = client.predict(audio_path, w_norm, api_name="/predict_dsr")
 
82
  except Exception as e:
83
+ prediction = f"Backend offline or Error: {e}"
84
+
85
+ return w_raw, w_norm, prediction
86
 
87
+ # UI Construction
88
  with gr.Blocks(theme=gr.themes.Soft(), title="Torgo DSR Lab") as demo:
89
  gr.Markdown("# ⚗️ Torgo DSR Lab")
90
  current_audio_path = gr.State("")
 
92
  with gr.Tab("🔬 Laboratory"):
93
  with gr.Row():
94
  with gr.Column(scale=1):
95
+ gr.Markdown("### Step 1: Data Selection")
96
+ spk_input = gr.Dropdown(sorted(list(SPEAKER_META.keys())), label="Speaker ID", value="F01")
97
  load_btn = gr.Button("Load Data")
98
+ meta_json = gr.JSON(label="Speaker Metadata")
99
  gt_box = gr.Textbox(label="Ground Truth")
100
+ audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Input Audio")
101
 
102
  with gr.Column(scale=2):
103
  gr.Markdown("### Step 2: ASR Baseline")
 
122
  gr.Markdown("""
123
  ### 🧪 Model Definitions
124
  * **5K Pure Model:** Trained on 5,000 real Torgo samples. Optimized for articulatory fidelity.
125
+ * **10K Triple-Mix Model:** Includes phonetic anchors and synthetic data. Utilized to test **generalization (LOSO)** on unseen speakers.
126
  """)
127
+
128
+ gr.Markdown("## 1. Torgo In-Domain Breakdown (By Speaker)")
 
129
  gr.DataFrame(get_indomain_breakdown())
130
+
131
  gr.Markdown("## 2. Experimental Summary")
132
  gr.DataFrame(get_experimental_summary())
133
 
134
+ # Connection logic
135
+ load_btn.click(get_sample_logic, inputs=spk_input, outputs=[current_audio_path, gt_box, meta_json]).then(
136
+ lambda x: x, inputs=current_audio_path, outputs=audio_input
137
+ )
138
+ whisper_btn.click(run_whisper_step if 'run_whisper_step' in globals() else run_lab, inputs=current_audio_path, outputs=[w_raw, w_norm, final_out])
139
+ model_btn.click(run_lab, inputs=current_audio_path, outputs=[w_raw, w_norm, final_out])
140
 
141
  demo.launch()