# Torgo-DSR-Lab / app.py
import os
import random
import re
import tempfile

import gradio as gr
import soundfile as sf
from transformers import pipeline
from datasets import load_dataset
from gradio_client import Client
from stats_data import get_indomain_breakdown, get_experimental_summary, SPEAKER_META
# 1. Initialize Local Whisper (Baseline)
whisper_asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
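# whisper-tiny is the smallest Whisper checkpoint and keeps the demo responsive
# on CPU; a larger checkpoint (e.g. "openai/whisper-small") could be swapped in
# here for a stronger baseline at the cost of latency.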
# 2. Set up the connection to the private inference backend
HF_TOKEN = os.getenv("HF_TOKEN")
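# HF_TOKEN must be available in the environment (e.g. as a Space secret) so
# gradio_client can authenticate against the private backend Space below.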
PRIVATE_BACKEND_URL = "st192011/Torgo-DSR-Private" # Update with your private space name
def normalize_text(text):
"""Simple normalization for comparison: lowercase and strip punctuation."""
return re.sub(r'[^\w\s]', '', text).lower().strip()
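
# Illustrative only: the Exact Match Accuracy metric described in the
# "Research Statistics" tab reduces to this comparison. exact_match is a
# hypothetical helper (not wired into the app) sketching that computation.
def exact_match(prediction, ground_truth):
    """Return True when prediction and ground truth agree after normalization."""
    return normalize_text(prediction) == normalize_text(ground_truth)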
def get_sample(speaker_id):
"""Accesses HF Datasets via Streaming to get a sample for the UI."""
try:
if "UA" in speaker_id:
# Note: UA-Speech ID logic (Speaker F02)
path = "ngdiana/uaspeech_severity_high"
actual_spk = "F02"
else:
path = "unsw-cse/torgo"
actual_spk = speaker_id
# Stream dataset to avoid huge downloads
ds = load_dataset(path, split="test", streaming=True)
# Filter for the chosen speaker
speaker_ds = ds.filter(lambda x: x["speaker_id"] == actual_spk)
# Take a small buffer and pick a random sample
        samples = list(speaker_ds.take(20))
        if not samples:
            return None, f"No test samples found for speaker {actual_spk}", None
        sample = random.choice(samples)
        # Write to a unique temp file so concurrent sessions don't overwrite each other
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        sf.write(tmp.name, sample["audio"]["array"], sample["audio"]["sampling_rate"])
        return tmp.name, sample["text"], SPEAKER_META[speaker_id]
except Exception as e:
return None, f"Error accessing dataset: {e}", None
def run_correction(audio_path, gt_text):
    # gt_text is displayed for reference only; it is not used during inference
    if audio_path is None:
        return "No audio input", "", ""
# A. Local Whisper Inference
w_raw = whisper_asr(audio_path)["text"]
w_norm = normalize_text(w_raw)
# B. Call Private Backend for the 5K and 10K results
try:
client = Client(PRIVATE_BACKEND_URL, hf_token=HF_TOKEN)
# Private app receives audio + normalized whisper, returns (5k_pred, 10k_pred)
res_5k, res_10k = client.predict(audio_path, w_norm, api_name="/predict_dsr_dual")
except Exception as e:
res_5k, res_10k = "Backend Connection Required", f"Details: {e}"
return w_raw, res_5k, res_10k
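
# For reference, a minimal sketch of the endpoint the private Space is assumed
# to expose (names and internals are illustrative; only the
# (audio, normalized_text) -> (pred_5k, pred_10k) contract is implied by the
# client.predict call above):
#
#     def predict_dsr_dual(audio_path, whisper_norm):
#         pred_5k = model_5k(audio_path, whisper_norm)    # 5K Pure (acoustic focus)
#         pred_10k = model_10k(audio_path, whisper_norm)  # 10K Triple-Mix (linguistic focus)
#         return pred_5k, pred_10k
#
# exposed in the private app under api_name="/predict_dsr_dual".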
# UI Layout
with gr.Blocks(theme=gr.themes.Default(), title="Torgo DSR Lab") as demo:
gr.Markdown("# βš—οΈ Torgo DSR Lab")
gr.Markdown("### Neural Reconstruction and ASR Correction for Torgo and UA-Speech")
with gr.Tab("πŸ”¬ Laboratory"):
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("#### 1. Dataset Explorer")
spk_input = gr.Dropdown(list(SPEAKER_META.keys()), label="Select Speaker Profile")
load_btn = gr.Button("🎲 Load Random Dataset Sample")
gr.Markdown("---")
audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Input Audio")
with gr.Column(scale=2):
gr.Markdown("#### 2. Metadata & Ground Truth")
gt_box = gr.Textbox(label="Ground Truth (Human Label)", interactive=False)
meta_box = gr.JSON(label="Speaker Characteristics")
gr.Markdown("#### 3. Comparison Results")
w_out = gr.Textbox(label="Whisper Tiny Baseline (Raw Transcript)")
with gr.Row():
out_5k = gr.Textbox(label="5K Pure Model (Acoustic Focus)")
out_10k = gr.Textbox(label="10K Triple-Mix Model (Linguistic Focus)")
run_btn = gr.Button("πŸš€ Run Correction Layer", variant="primary")
with gr.Tab("πŸ“Š Research Statistics"):
gr.Markdown("# πŸ”¬ Evaluation Metrics")
gr.Markdown("""
        **Metric:** Exact Match Accuracy, computed by comparing the **normalized prediction** (lowercased, punctuation stripped) against the **normalized ground truth**.
""")
gr.Markdown("### 1. In-Domain Torgo Breakdown (By Speaker)")
gr.DataFrame(get_indomain_breakdown())
gr.Markdown("### 2. Experimental Milestone Summary")
gr.Markdown("_Note: The 10K model was utilized to test generalization via LOSO on unseen speaker F01._")
gr.DataFrame(get_experimental_summary())
# Event Logic
load_btn.click(get_sample, inputs=spk_input, outputs=[audio_input, gt_box, meta_box])
run_btn.click(run_correction, inputs=[audio_input, gt_box], outputs=[w_out, out_5k, out_10k])
if __name__ == "__main__":
    demo.launch()