# Torgo DSR Lab — Gradio demo application (Hugging Face Space front-end).
# Standard library
import os
import random
import re
import tempfile

# Third-party
import gradio as gr
import soundfile as sf
from datasets import load_dataset
from gradio_client import Client
from transformers import pipeline

# Local
from stats_data import get_indomain_breakdown, get_experimental_summary, SPEAKER_META
# 1. Initialize Local Whisper (Baseline)
# Loaded once at import time; "whisper-tiny" keeps the public demo lightweight.
whisper_asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

# 2. Setup Private Backend Connection (Hidden logic)
# HF_TOKEN may be None when the env var is unset; the gradio_client Client
# will then attempt an anonymous connection to the private Space.
HF_TOKEN = os.getenv("HF_TOKEN")
PRIVATE_BACKEND_URL = "st192011/Torgo-DSR-Private"  # Update with your private space name
def normalize_text(text):
    """Normalize *text* for exact-match comparison.

    Removes every character that is not a word character or whitespace,
    lowercases the result, and trims surrounding whitespace.
    """
    depunctuated = re.sub(r'[^\w\s]', '', text)
    return depunctuated.lower().strip()
def get_sample(speaker_id):
    """Fetch one random utterance for *speaker_id* from the HF Hub.

    Streams the dataset (avoids a full download), filters to the requested
    speaker, writes the audio to a unique temporary WAV file, and returns
    ``(audio_path, transcript, speaker_metadata)``.  On any failure a
    ``(None, error_message, None)`` triple is returned so the UI can
    surface the problem instead of crashing.
    """
    try:
        if "UA" in speaker_id:
            # Note: UA-Speech ID logic (this dataset maps to speaker F02).
            path = "ngdiana/uaspeech_severity_high"
            actual_spk = "F02"
        else:
            path = "unsw-cse/torgo"
            actual_spk = speaker_id

        # Stream dataset to avoid huge downloads
        ds = load_dataset(path, split="test", streaming=True)
        # Filter for the chosen speaker
        speaker_ds = ds.filter(lambda x: x["speaker_id"] == actual_spk)

        # Take a small buffer and pick a random sample.
        samples = list(speaker_ds.take(20))
        if not samples:
            # Previously random.choice([]) raised IndexError and surfaced
            # only as a generic exception string; report it clearly instead.
            return None, f"No samples found for speaker '{actual_spk}'", None
        sample = random.choice(samples)

        # Unique temp file per request: a fixed "sample_audio.wav" would be
        # clobbered when multiple users share the same Gradio worker.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            audio_path = tmp.name
        sf.write(audio_path, sample["audio"]["array"], sample["audio"]["sampling_rate"])
        return audio_path, sample["text"], SPEAKER_META[speaker_id]
    except Exception as e:
        # Best-effort UI: report the failure in the transcript box.
        return None, f"Error accessing dataset: {e}", None
def run_correction(audio_path, gt_text):
    """Run the baseline ASR and the private correction backend.

    Returns a ``(whisper_raw, pred_5k, pred_10k)`` triple feeding the three
    result textboxes.  ``gt_text`` is accepted for event-wiring symmetry but
    is not used inside this function.
    """
    # Guard: nothing has been recorded or uploaded yet.
    if audio_path is None:
        return "No audio input", "", ""

    # A. Local Whisper Inference
    raw_transcript = whisper_asr(audio_path)["text"]
    normalized = normalize_text(raw_transcript)

    # B. Call Private Backend for the 5K and 10K results
    try:
        backend = Client(PRIVATE_BACKEND_URL, hf_token=HF_TOKEN)
        # Private app receives audio + normalized whisper, returns (5k_pred, 10k_pred)
        pred_5k, pred_10k = backend.predict(audio_path, normalized, api_name="/predict_dsr_dual")
    except Exception as e:
        pred_5k, pred_10k = "Backend Connection Required", f"Details: {e}"

    return raw_transcript, pred_5k, pred_10k
# UI Layout
# NOTE(review): nesting below was reconstructed from a flattened paste — the
# run button is assumed to sit in the results column; confirm against the
# original layout before relying on exact widget placement.
with gr.Blocks(theme=gr.themes.Default(), title="Torgo DSR Lab") as demo:
    gr.Markdown("# βοΈ Torgo DSR Lab")
    gr.Markdown("### Neural Reconstruction and ASR Correction for Torgo and UA-Speech")

    # Tab 1: interactive demo — load a dataset sample or record audio,
    # then run the baseline + correction models.
    with gr.Tab("π¬ Laboratory"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("#### 1. Dataset Explorer")
                spk_input = gr.Dropdown(list(SPEAKER_META.keys()), label="Select Speaker Profile")
                load_btn = gr.Button("π² Load Random Dataset Sample")
                gr.Markdown("---")
                audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Input Audio")
            with gr.Column(scale=2):
                gr.Markdown("#### 2. Metadata & Ground Truth")
                gt_box = gr.Textbox(label="Ground Truth (Human Label)", interactive=False)
                meta_box = gr.JSON(label="Speaker Characteristics")
                gr.Markdown("#### 3. Comparison Results")
                w_out = gr.Textbox(label="Whisper Tiny Baseline (Raw Transcript)")
                with gr.Row():
                    out_5k = gr.Textbox(label="5K Pure Model (Acoustic Focus)")
                    out_10k = gr.Textbox(label="10K Triple-Mix Model (Linguistic Focus)")
                run_btn = gr.Button("π Run Correction Layer", variant="primary")

    # Tab 2: static evaluation tables sourced from stats_data.
    with gr.Tab("π Research Statistics"):
        gr.Markdown("# π¬ Evaluation Metrics")
        gr.Markdown("""
        **Metric:** Exact Match Accuracy.
        Calculated by comparing the **normalized prediction** (lowercase, no punctuation) against the **normalized ground truth**.
        """)
        gr.Markdown("### 1. In-Domain Torgo Breakdown (By Speaker)")
        gr.DataFrame(get_indomain_breakdown())
        gr.Markdown("### 2. Experimental Milestone Summary")
        gr.Markdown("_Note: The 10K model was utilized to test generalization via LOSO on unseen speaker F01._")
        gr.DataFrame(get_experimental_summary())

    # Event Logic: wire the buttons to the handlers defined above.
    load_btn.click(get_sample, inputs=spk_input, outputs=[audio_input, gt_box, meta_box])
    run_btn.click(run_correction, inputs=[audio_input, gt_box], outputs=[w_out, out_5k, out_10k])

demo.launch()