st192011 committed on
Commit
712d6bb
·
verified ·
1 Parent(s): baa22ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -58
app.py CHANGED
@@ -3,108 +3,134 @@ import os
3
  import random
4
  import soundfile as sf
5
  import re
 
 
 
6
  from transformers import pipeline
7
- from datasets import load_dataset
8
  from gradio_client import Client
9
  from stats_data import get_indomain_breakdown, get_experimental_summary, SPEAKER_META
10
 
11
- # 1. Initialize Local Whisper (Baseline)
 
12
  whisper_asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
13
 
14
- # 2. Setup Private Backend Connection (Hidden logic)
15
  HF_TOKEN = os.getenv("HF_TOKEN")
16
- PRIVATE_BACKEND_URL = "st192011/Torgo-DSR-Private" # Update with your private space name
17
-
18
- def normalize_text(text):
19
- """Simple normalization for comparison: lowercase and strip punctuation."""
20
- return re.sub(r'[^\w\s]', '', text).lower().strip()
21
 
22
  def get_sample(speaker_id):
23
- """Accesses HF Datasets via Streaming to get a sample for the UI."""
24
  try:
25
- if "UA" in speaker_id:
26
- # Note: UA-Speech ID logic (Speaker F02)
27
- path = "ngdiana/uaspeech_severity_high"
28
- actual_spk = "F02"
 
 
 
 
29
  else:
30
- path = "unsw-cse/torgo"
31
- actual_spk = speaker_id
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- # Stream dataset to avoid huge downloads
34
- ds = load_dataset(path, split="test", streaming=True)
35
- # Filter for the chosen speaker
36
- speaker_ds = ds.filter(lambda x: x["speaker_id"] == actual_spk)
37
 
38
- # Take a small buffer and pick a random sample
39
- samples = list(speaker_ds.take(20))
40
- sample = random.choice(samples)
41
-
42
- audio_path = "sample_audio.wav"
43
- sf.write(audio_path, sample["audio"]["array"], sample["audio"]["sampling_rate"])
44
-
45
- return audio_path, sample["text"], SPEAKER_META[speaker_id]
46
  except Exception as e:
47
  return None, f"Error accessing dataset: {e}", None
48
 
49
  def run_correction(audio_path, gt_text):
50
- if audio_path is None: return "No audio input", "", ""
 
51
 
52
  # A. Local Whisper Inference
53
- w_raw = whisper_asr(audio_path)["text"]
54
- w_norm = normalize_text(w_raw)
 
 
 
 
55
 
56
- # B. Call Private Backend for the 5K and 10K results
 
57
  try:
58
  client = Client(PRIVATE_BACKEND_URL, hf_token=HF_TOKEN)
59
- # Private app receives audio + normalized whisper, returns (5k_pred, 10k_pred)
60
  res_5k, res_10k = client.predict(audio_path, w_norm, api_name="/predict_dsr_dual")
61
  except Exception as e:
62
- res_5k, res_10k = "Backend Connection Required", f"Details: {e}"
 
63
 
64
  return w_raw, res_5k, res_10k
65
 
66
- # UI Layout
67
- with gr.Blocks(theme=gr.themes.Default(), title="Torgo DSR Lab") as demo:
68
  gr.Markdown("# ⚗️ Torgo DSR Lab")
69
- gr.Markdown("### Neural Reconstruction and ASR Correction for Torgo and UA-Speech")
70
 
71
- with gr.Tab("🔬 Laboratory"):
72
  with gr.Row():
73
  with gr.Column(scale=1):
74
- gr.Markdown("#### 1. Dataset Explorer")
75
- spk_input = gr.Dropdown(list(SPEAKER_META.keys()), label="Select Speaker Profile")
76
- load_btn = gr.Button("🎲 Load Random Dataset Sample")
77
  gr.Markdown("---")
78
- audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Input Audio")
79
 
80
  with gr.Column(scale=2):
81
- gr.Markdown("#### 2. Metadata & Ground Truth")
82
- gt_box = gr.Textbox(label="Ground Truth (Human Label)", interactive=False)
83
- meta_box = gr.JSON(label="Speaker Characteristics")
 
84
 
85
- gr.Markdown("#### 3. Comparison Results")
86
  w_out = gr.Textbox(label="Whisper Tiny Baseline (Raw Transcript)")
87
  with gr.Row():
88
- out_5k = gr.Textbox(label="5K Pure Model (Acoustic Focus)")
89
- out_10k = gr.Textbox(label="10K Triple-Mix Model (Linguistic Focus)")
90
-
91
- run_btn = gr.Button("🚀 Run Correction Layer", variant="primary")
92
 
93
  with gr.Tab("📊 Research Statistics"):
94
- gr.Markdown("# 🔬 Evaluation Metrics")
95
- gr.Markdown("""
96
- **Metric:** Exact Match Accuracy.
97
- Calculated by comparing the **normalized prediction** (lowercase, no punctuation) against the **normalized ground truth**.
98
- """)
99
 
100
- gr.Markdown("### 1. In-Domain Torgo Breakdown (By Speaker)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  gr.DataFrame(get_indomain_breakdown())
102
 
103
- gr.Markdown("### 2. Experimental Milestone Summary")
104
- gr.Markdown("_Note: The 10K model was utilized to test generalization via LOSO on unseen speaker F01._")
105
  gr.DataFrame(get_experimental_summary())
106
 
107
- # Event Logic
108
  load_btn.click(get_sample, inputs=spk_input, outputs=[audio_input, gt_box, meta_box])
109
  run_btn.click(run_correction, inputs=[audio_input, gt_box], outputs=[w_out, out_5k, out_10k])
110
 
 
3
  import random
4
  import soundfile as sf
5
  import re
6
+ import io
7
+ import librosa
8
+ import torch
9
  from transformers import pipeline
10
+ from datasets import load_dataset, Audio
11
  from gradio_client import Client
12
  from stats_data import get_indomain_breakdown, get_experimental_summary, SPEAKER_META
13
 
14
+ # 1. Initialize Local Whisper Tiny (Baseline)
15
+ # CPU friendly, fast inference
16
  whisper_asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
17
 
18
+ # 2. Private Backend Config
19
  HF_TOKEN = os.getenv("HF_TOKEN")
20
+ PRIVATE_BACKEND_URL = "st192011/Torgo-DSR-Private"
 
 
 
 
21
 
22
  def get_sample(speaker_id):
23
+ """Integrated loading logic from your research code."""
24
  try:
25
+ if speaker_id == "F02":
26
+ # UA-Speech loading logic
27
+ dataset = load_dataset("resproj007/uaspeech_female", split="test", streaming=True)
28
+ # F02 is usually the primary speaker in this slice
29
+ sample = next(iter(dataset.shuffle(buffer_size=20)))
30
+ gt_text = sample.get('text') or sample.get('transcription') or sample.get('sentence', 'Unknown')
31
+ audio_data = sample['audio']['array']
32
+ sample_rate = sample['audio']['sampling_rate']
33
  else:
34
+ # Torgo loading logic
35
+ dataset = load_dataset("abnerh/TORGO-database", split="train", streaming=True)
36
+ # Cast for manual decoding as per your training script
37
+ dataset = dataset.cast_column("audio", Audio(decode=False))
38
+
39
+ # Filter by speaker
40
+ speaker_ds = dataset.filter(lambda x: str(x.get('speaker_id', '')).upper() == speaker_id)
41
+ sample = next(iter(speaker_ds.shuffle(buffer_size=20)))
42
+
43
+ # Extract ground truth
44
+ gt_text = sample.get('transcription') or sample.get('text', 'Unknown')
45
+
46
+ # Decode Audio bytes
47
+ audio_bytes = sample['audio']['bytes']
48
+ audio_data, sample_rate = librosa.load(io.BytesIO(audio_bytes), sr=16000)
49
 
50
+ # Save to temporary file for Gradio and Whisper
51
+ temp_path = "temp_sample.wav"
52
+ sf.write(temp_path, audio_data, sample_rate)
 
53
 
54
+ return temp_path, gt_text.lower().strip(), SPEAKER_META[speaker_id]
55
+
 
 
 
 
 
 
56
  except Exception as e:
57
  return None, f"Error accessing dataset: {e}", None
58
 
59
  def run_correction(audio_path, gt_text):
60
+ if audio_path is None:
61
+ return "No audio provided", "", "Please load a sample or record audio."
62
 
63
  # A. Local Whisper Inference
64
+ try:
65
+ w_res = whisper_asr(audio_path)
66
+ w_raw = w_res["text"]
67
+ w_norm = re.sub(r'[^\w\s]', '', w_raw).lower().strip()
68
+ except Exception as e:
69
+ return f"Whisper Error: {e}", "", ""
70
 
71
+ # B. Call Private Backend
72
+ # This sends the audio and the whisper transcript to your private Gemma model
73
  try:
74
  client = Client(PRIVATE_BACKEND_URL, hf_token=HF_TOKEN)
75
+ # Note: Your private backend should expect (audio_file, whisper_text)
76
  res_5k, res_10k = client.predict(audio_path, w_norm, api_name="/predict_dsr_dual")
77
  except Exception as e:
78
+ res_5k = "Backend Offline"
79
+ res_10k = "Please ensure the Private Space is running."
80
 
81
  return w_raw, res_5k, res_10k
82
 
83
+ # UI Construction
84
+ with gr.Blocks(theme=gr.themes.Soft(), title="Torgo DSR Lab") as demo:
85
  gr.Markdown("# ⚗️ Torgo DSR Lab")
86
+ gr.Markdown("### Neural Reconstruction Layer for Torgo and UA-Speech Zero-Shot")
87
 
88
+ with gr.Tab("🔬 Interactive Lab"):
89
  with gr.Row():
90
  with gr.Column(scale=1):
91
+ gr.Markdown("#### 1. Select and Load Sample")
92
+ spk_input = gr.Dropdown(list(SPEAKER_META.keys()), label="Speaker ID", value="F01")
93
+ load_btn = gr.Button("🎲 Get Random Sample", variant="secondary")
94
  gr.Markdown("---")
95
+ audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio Input")
96
 
97
  with gr.Column(scale=2):
98
+ gr.Markdown("#### 2. Metadata & Comparison")
99
+ with gr.Row():
100
+ gt_box = gr.Textbox(label="Ground Truth", interactive=False)
101
+ meta_box = gr.JSON(label="Speaker Meta")
102
 
 
103
  w_out = gr.Textbox(label="Whisper Tiny Baseline (Raw Transcript)")
104
  with gr.Row():
105
+ out_5k = gr.Textbox(label="5K Pure Model Prediction")
106
+ out_10k = gr.Textbox(label="10K Triple-Mix Prediction")
107
+
108
+ run_btn = gr.Button("🚀 Run ASR & Reconstruction", variant="primary")
109
 
110
  with gr.Tab("📊 Research Statistics"):
111
+ gr.Markdown("# 🔬 Performance Evaluation")
 
 
 
 
112
 
113
+ with gr.Row():
114
+ with gr.Column():
115
+ gr.Markdown("""
116
+ ### 📏 Metric: Exact Match Accuracy
117
+ Accuracy is calculated as the percentage of samples where the **normalized prediction** (lowercase, no punctuation) exactly matches the **ground truth**.
118
+ """)
119
+
120
+ with gr.Column():
121
+ gr.Markdown("""
122
+ ### 🧪 Model Definitions
123
+ * **5K Pure Model:** Trained on 5,000 real Torgo samples. Optimized for articulatory fidelity.
124
+ * **10K Triple-Mix Model:** Includes phonetic anchors and synthetic data. Used for Generalization (LOSO) testing.
125
+ """)
126
+
127
+ gr.Markdown("## 1. Torgo In-Domain Breakdown (By Speaker)")
128
  gr.DataFrame(get_indomain_breakdown())
129
 
130
+ gr.Markdown("## 2. Experimental Condition Summary")
 
131
  gr.DataFrame(get_experimental_summary())
132
 
133
+ # Event Handlers
134
  load_btn.click(get_sample, inputs=spk_input, outputs=[audio_input, gt_box, meta_box])
135
  run_btn.click(run_correction, inputs=[audio_input, gt_box], outputs=[w_out, out_5k, out_10k])
136