st192011 committed on
Commit 77940c5 · verified · 1 Parent(s): 07ead75

Update app.py

Files changed (1)
  1. app.py +87 -85
app.py CHANGED
@@ -5,132 +5,134 @@ import re
 import random
 import librosa
 import soundfile as sf
-import pandas as pd
 from transformers import pipeline
 from datasets import load_dataset, Audio
 from gradio_client import Client
 from stats_data import get_indomain_breakdown, get_experimental_summary, SPEAKER_META
 
-# 1. Configuration
-TORGO_INDICES = {'FC01': 0, 'FC02': 302, 'FC03': 2489, 'MC02': 4411, 'MC01': 5534, 'MC03': 7689, 'MC04': 9358, 'M05': 10978, 'M02': 11565, 'M04': 12337, 'M01': 13003, 'F01': 13746, 'M03': 13982, 'F04': 14792, 'F03': 15465}
-HF_TOKEN = os.getenv("HF_TOKEN")
-PRIVATE_BACKEND_URL = "st192011/Torgo-DSR-Private"
-
-# 2. Local Whisper Baseline (Strict English, Repetition Penalty 3.0)
+# 1. Setup Local Whisper Baseline (English, Strict Generation)
 whisper_asr = pipeline(
     "automatic-speech-recognition",
     model="openai/whisper-tiny",
-    generate_kwargs={"language": "en", "task": "transcribe", "repetition_penalty": 3.0, "max_new_tokens": 64}
+    generate_kwargs={"language": "en", "task": "transcribe", "repetition_penalty": 3.0}
 )
 
+HF_TOKEN = os.getenv("HF_TOKEN")
+PRIVATE_BACKEND_URL = "st192011/Torgo-DSR-Private"
+
 def normalize_text(text):
     if not text: return ""
     return re.sub(r'[^\w\s]', '', text).lower().strip()
 
-def standardize_audio(input_path):
-    """Ensures audio is 16kHz, Mono, and compatible with all models."""
-    if not input_path: return None
-    audio, sr = librosa.load(input_path, sr=16000, mono=True)
-    out_path = "processed_audio.wav"
-    sf.write(out_path, audio, 16000)
-    return out_path
-
-# --- Logic: Data Loading ---
-def get_sample_logic(speaker_id):
-    try:
-        if speaker_id == "F02 (UA)":
-            dataset = load_dataset("resproj007/uaspeech_female", split="train", streaming=True)
-            dataset = dataset.cast_column("audio", Audio(decode=False))
-            sample = next(iter(dataset.skip(random.randint(0, 30))))
-            gt_text = sample.get('text') or sample.get('transcription') or "Unknown"
-        else:
-            dataset = load_dataset("abnerh/TORGO-database", split="train", streaming=True)
-            dataset = dataset.cast_column("audio", Audio(decode=False))
-            start_idx = TORGO_INDICES.get(speaker_id, 0)
-            sample = next(iter(dataset.skip(start_idx + random.randint(0, 15))))
-            gt_text = sample.get('transcription') or sample.get('text') or "Unknown"
+# --- Shared Processing Logic ---
 
-        # Decode Bytes manually and Standardize
-        audio_bytes = sample['audio']['bytes']
-        audio_data, _ = librosa.load(io.BytesIO(audio_bytes), sr=16000, mono=True)
-
-        temp_path = "sample.wav"
-        sf.write(temp_path, audio_data, 16000)
-
-        # We return the path to the gr.Audio component (which stores it in State)
-        return temp_path, gt_text.lower().strip(), SPEAKER_META[speaker_id]
-    except Exception as e:
-        return None, f"Dataset Error: {e}", {}
+def process_audio_file(audio_path):
+    """Ensures any input audio is formatted correctly for ASR systems (16kHz Mono)."""
+    y, sr = librosa.load(audio_path, sr=16000)
+    fixed_path = "processed_audio.wav"
+    sf.write(fixed_path, y, sr)
+    return fixed_path
 
-# --- Logic: Model Steps ---
-def run_whisper_step(audio_path):
+def run_whisper_logic(audio_path):
     if not audio_path: return "No audio loaded", ""
-    # Standardize format before Whisper
-    clean_audio = standardize_audio(audio_path)
-    result = whisper_asr(clean_audio)
+    formatted_path = process_audio_file(audio_path)
+    result = whisper_asr(formatted_path)
     raw_w = result["text"]
     norm_w = normalize_text(raw_w)
     return raw_w, norm_w
 
-def run_model_step(audio_path, norm_whisper):
-    if not audio_path or not norm_whisper: return "Complete Step 1 & 2 first."
-    # Standardize format before sending to Private Backend
-    clean_audio = standardize_audio(audio_path)
+def run_reconstruction_logic(audio_path, norm_whisper):
+    if not audio_path or not norm_whisper: return "Run Whisper step first."
     try:
         client = Client(PRIVATE_BACKEND_URL, hf_token=HF_TOKEN)
-        prediction = client.predict(clean_audio, norm_whisper, api_name="/predict_dsr")
+        # Private backend handles Wav2Vec, Allosaurus, and Gemma 3 arbitration
+        prediction = client.predict(audio_path, norm_whisper, api_name="/predict_dsr")
        return prediction
     except Exception as e:
-        return f"Backend Offline. Details: {e}"
+        return f"Backend Offline. Error: {e}"
+
+# --- Channel 1: Dataset Loader ---
+def get_dataset_sample(speaker_id):
+    try:
+        if speaker_id == "F02":
+            ds = load_dataset("resproj007/uaspeech_female", split="train", streaming=True)
+            ds = ds.cast_column("audio", Audio(decode=False))
+            sample = next(iter(ds.skip(random.randint(0, 50))))
+            gt_text = sample.get('text') or sample.get('transcription') or "Unknown"
+        else:
+            ds = load_dataset("abnerh/TORGO-database", split="train", streaming=True)
+            ds = ds.cast_column("audio", Audio(decode=False))
+            indices = {'M05': 10978, 'M02': 11565, 'M04': 12337, 'M01': 13003, 'F01': 13746, 'M03': 13982, 'F04': 14792, 'F03': 15465}
+            start_idx = indices.get(speaker_id, 0)
+            sample = next(iter(ds.skip(start_idx + random.randint(0, 10))))
+            gt_text = sample.get('transcription') or sample.get('text') or "Unknown"
+
+        audio_data, sr = librosa.load(io.BytesIO(sample['audio']['bytes']), sr=16000)
+        temp_path = f"sample_{speaker_id}.wav"
+        sf.write(temp_path, audio_data, sr)
+        return temp_path, gt_text.lower().strip(), SPEAKER_META.get(speaker_id, {})
+    except Exception as e:
+        return None, f"Dataset Error: {e}", {}
 
-# --- UI Construction ---
+# --- UI Layout ---
 with gr.Blocks(theme=gr.themes.Soft(), title="Torgo DSR Lab") as demo:
     gr.Markdown("# ⚗️ Torgo DSR Lab")
-    gr.Markdown("Neural Reconstruction for Severe Dysarthria. Load samples from Torgo/UA or record your own.")
+    gr.Markdown("ASR Correction and Reconstruction Layer for Torgo and UA-Speech.")
 
-    with gr.Tab("🔬 Laboratory"):
+    # States for audio paths
+    lab_audio_state = gr.State("")
+    user_audio_state = gr.State("")
+
+    with gr.Tab("🔬 Research Samples"):
+        gr.Markdown("Select clinical samples from the Torgo or UA-Speech datasets.")
         with gr.Row():
             with gr.Column(scale=1):
-                gr.Markdown("### Step 1: Input Audio")
-                speaker_input = gr.Dropdown(sorted(list(SPEAKER_META.keys())), label="Speaker Profile (for Dataset Samples)")
-                load_btn = gr.Button("🎲 Load Dataset Sample")
+                speaker_input = gr.Dropdown(sorted(list(SPEAKER_META.keys())), label="Speaker ID", value="F01")
+                load_btn = gr.Button("Load Sample Data")
+                meta_display = gr.JSON(label="Sample Metadata")
+                gt_box = gr.Textbox(label="Ground Truth")
+
+            with gr.Column(scale=2):
+                whisper_btn_lab = gr.Button("1. Generate Whisper Baseline")
+                w_raw_lab = gr.Textbox(label="Whisper Raw")
+                w_norm_lab = gr.Textbox(label="Whisper Normalized")
 
                 gr.Markdown("---")
-                # Unified Input: Handles both Dataset Samples and User Input
-                audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio Input (Record/Upload/Dataset)")
-
-                meta_display = gr.JSON(label="Speaker Metadata")
-                gt_box = gr.Textbox(label="Ground Truth (if from dataset)")
+                model_btn_lab = gr.Button("2. Run Neural Reconstruction", variant="primary")
+                final_out_lab = gr.Textbox(label="DSR Lab Prediction")
 
+    with gr.Tab("🎤 Personal Test"):
+        gr.Markdown("Record or upload your own audio to test the reconstruction layer.")
+        with gr.Row():
+            with gr.Column(scale=1):
+                user_audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="User Audio")
+                process_user_btn = gr.Button("Prepare Audio")
+
             with gr.Column(scale=2):
-                gr.Markdown("### Step 2: ASR Baseline")
-                whisper_btn = gr.Button("Run Whisper Tiny")
-                w_raw = gr.Textbox(label="Whisper Raw")
-                w_norm = gr.Textbox(label="Whisper Normalized")
+                whisper_btn_user = gr.Button("1. Generate Whisper Baseline")
+                w_raw_user = gr.Textbox(label="Whisper Raw")
+                w_norm_user = gr.Textbox(label="Whisper Normalized")
 
                 gr.Markdown("---")
-                gr.Markdown("### Step 3: Neural Reconstruction")
-                model_btn = gr.Button("Run Our Model", variant="primary")
-                final_out = gr.Textbox(label="DSR Lab Prediction")
+                model_btn_user = gr.Button("2. Run Neural Reconstruction", variant="primary")
+                final_out_user = gr.Textbox(label="DSR Lab Prediction")
 
     with gr.Tab("📊 Research Statistics"):
-        gr.Markdown("# 🔬 Performance Evaluation")
-        with gr.Row():
-            with gr.Column():
-                gr.Markdown("### 📏 Metric: Exact Match Accuracy")
-                gr.Markdown("Accuracy is calculated on normalized text (lowercase, no punctuation).")
-            with gr.Column():
-                gr.Markdown("### 🧪 Model Definitions")
-                gr.Markdown("* **5K Pure Model:** Real data focus. \n* **10K Triple-Mix Model:** LOSO Generalization focus.")
-
-        gr.Markdown("## 1. Torgo In-Domain Analysis")
+        gr.Markdown("# 🔬 Scientific Evaluation")
+        gr.Markdown("**Metric:** Exact Match Accuracy on normalized text (lowercase, no punctuation).")
+        gr.Markdown("## 1. Torgo In-Domain Breakdown")
         gr.DataFrame(get_indomain_breakdown())
-        gr.Markdown("## 2. Experimental Summary")
+        gr.Markdown("## 2. Experimental Milestone Summary")
         gr.DataFrame(get_experimental_summary())
 
-    # Logic connections
-    load_btn.click(get_sample_logic, inputs=speaker_input, outputs=[audio_input, gt_box, meta_display])
-    whisper_btn.click(run_whisper_step, inputs=audio_input, outputs=[w_raw, w_norm])
-    model_btn.click(run_model_step, inputs=[audio_input, w_norm], outputs=final_out)
+    # --- Events: Research Tab ---
+    load_btn.click(get_dataset_sample, inputs=speaker_input, outputs=[lab_audio_state, gt_box, meta_display])
+    whisper_btn_lab.click(run_whisper_logic, inputs=lab_audio_state, outputs=[w_raw_lab, w_norm_lab])
+    model_btn_lab.click(run_reconstruction_logic, inputs=[lab_audio_state, w_norm_lab], outputs=final_out_lab)
+
+    # --- Events: Personal Tab ---
+    process_user_btn.click(lambda x: x, inputs=user_audio_input, outputs=user_audio_state)
+    whisper_btn_user.click(run_whisper_logic, inputs=user_audio_state, outputs=[w_raw_user, w_norm_user])
+    model_btn_user.click(run_reconstruction_logic, inputs=[user_audio_state, w_norm_user], outputs=final_out_user)
 
 demo.launch()
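
For reference, the `/predict_dsr` call that `run_reconstruction_logic` wraps can also be exercised directly with `gradio_client`. A minimal sketch, assuming a valid HF_TOKEN with read access to the private Space and the same two-argument endpoint signature (audio filepath, normalized Whisper text); the file path and text below are hypothetical placeholders, not values from this commit:

import os
from gradio_client import Client

# Sketch: query the private DSR backend the same way app.py does.
# The endpoint name and argument order are taken from run_reconstruction_logic above.
client = Client("st192011/Torgo-DSR-Private", hf_token=os.getenv("HF_TOKEN"))
prediction = client.predict(
    "sample_F01.wav",        # 16 kHz mono WAV, e.g. as written by get_dataset_sample (hypothetical path)
    "call an ambulance",     # normalized Whisper baseline text (hypothetical example)
    api_name="/predict_dsr",
)
print(prediction)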