# Torgo-DSR-Lab / app.py
import gradio as gr
import os
import io
import re
import random
import librosa
import soundfile as sf
import pandas as pd
from gradio_client import Client, handle_file
from transformers import pipeline
from datasets import load_dataset, Audio
from stats_data import get_indomain_breakdown, get_experimental_summary, SPEAKER_META
# 1. Initialize Baseline ASR (Strict English, Repetition Penalty 3.0)
print("Initializing Whisper Tiny Baseline...")
whisper_asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",
    generate_kwargs={
        "language": "en",
        "task": "transcribe",
        "repetition_penalty": 3.0
    }
)
# Configuration from Environment Variables
HF_TOKEN = os.getenv("HF_TOKEN")
PRIVATE_BACKEND_URL = "st192011/Torgo-DSR-Private"
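# HF_TOKEN is expected to come from the hosting environment (e.g. a Space
# secret) and must grant read access to the private backend Space above.
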
def normalize_text(text):
    """Lowercases text and strips punctuation/extra whitespace for exact-match comparison."""
    if not text:
        return ""
    # Remove punctuation and lowercase
    text = re.sub(r'[^\w\s]', '', text).lower().strip()
    return " ".join(text.split())
def format_audio(audio_path):
    """Ensures audio is 16 kHz mono to match ASR training conditions."""
    y, sr = librosa.load(audio_path, sr=16000)
    out_path = "formatted_input.wav"
    sf.write(out_path, y, sr)
    return out_path
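# Whisper checkpoints expect 16 kHz mono input, so resampling here keeps
# recorded/uploaded audio consistent with the dataset samples below.
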
# --- Logic: Data Loading ---
def get_sample_logic(speaker_id):
    try:
        if "UA" in speaker_id:
            # UA-Speech access (direct pull for F02)
            dataset = load_dataset("resproj007/uaspeech_female", split="train", streaming=True)
            dataset = dataset.cast_column("audio", Audio(decode=False))
            # UA-Speech is small; skip a random handful of samples for variety
            sample = next(iter(dataset.skip(random.randint(0, 30))))
            gt_text = sample.get('text') or sample.get('transcription') or sample.get('sentence')
        else:
            # Torgo access (manual speaker filtering, as per the Colab fix)
            dataset = load_dataset("abnerh/TORGO-database", split="train", streaming=True)
            dataset = dataset.cast_column("audio", Audio(decode=False))

            def filter_spk(x):
                sid = str(x.get('speaker_id', '')).upper()
                if not sid or sid == "NONE":
                    # Fall back to the file-name prefix (text before the first '_')
                    sid = os.path.basename(x['audio']['path']).split('_')[0].upper()
                return sid == speaker_id

            speaker_ds = dataset.filter(filter_spk)
            sample = next(iter(speaker_ds.shuffle(buffer_size=10)))
            gt_text = sample.get('transcription') or sample.get('text')

        # Decode the bytes manually to bypass torchcodec errors
        audio_bytes = sample['audio']['bytes']
        audio_data, sr = librosa.load(io.BytesIO(audio_bytes), sr=16000)
        temp_path = "dataset_sample.wav"
        sf.write(temp_path, audio_data, sr)
        return temp_path, (gt_text or "").lower().strip(), SPEAKER_META[speaker_id]
    except Exception as e:
        return None, f"Dataset Error: {e}", {}
# --- Logic: Model Processing ---
def process_audio_step_1(audio_path):
    """Runs the Whisper baseline; returns the raw and normalized transcripts."""
    if not audio_path:
        return "No audio loaded", ""
    # Pre-process the audio to 16 kHz mono
    formatted_path = format_audio(audio_path)
    # Run Whisper
    result = whisper_asr(formatted_path)
    raw_w = result["text"]
    norm_w = normalize_text(raw_w)
    return raw_w, norm_w

def process_audio_step_2(audio_path, norm_whisper):
    """Sends the audio plus the normalized Whisper transcript to the private model API."""
    if not audio_path or not norm_whisper:
        return "Please load data and run Whisper (Step 1) first."
    try:
        # Connect to the private API
        client = Client(PRIVATE_BACKEND_URL, token=HF_TOKEN)
        # FIX: wrap audio_path with handle_file(); this sends the metadata
        # required by Pydantic ('gradio.FileData')
        prediction = client.predict(
            audio_path=handle_file(audio_path),
            whisper_norm=norm_whisper,
            api_name="/predict_dsr"
        )
        return prediction
    except Exception as e:
        return f"Backend Connection Required. Details: {e}"
# --- UI Construction ---
with gr.Blocks(theme=gr.themes.Soft(), title="Torgo DSR Lab") as demo:
    gr.Markdown("# βš—οΈ Torgo DSR Lab")
    gr.Markdown("Neural Reconstruction Layer for Torgo and UA-Speech Zero-Shot.")

    # Hidden state storing the path of the currently active audio
    active_audio_path = gr.State("")

    with gr.Tab("πŸ”¬ Laboratory"):
        with gr.Row():
            # LEFT COLUMN: Data Input
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### Channel A: Research Datasets")
                    speaker_input = gr.Dropdown(sorted(SPEAKER_META.keys()), label="Select Speaker Profile", value="F01")
                    load_btn = gr.Button("Load Sample from Dataset")
                    gt_box = gr.Textbox(label="Ground Truth (Reference)", interactive=False)
                    meta_display = gr.JSON(label="Speaker Metadata")
                gr.Markdown("---")
                with gr.Group():
                    gr.Markdown("### Channel B: Personal Input")
                    user_audio = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Record or Upload Audio")
                    user_load_btn = gr.Button("Use This Audio")

            # RIGHT COLUMN: Transcripts
            with gr.Column(scale=2):
                gr.Markdown("### Analysis & Reconstruction")
                with gr.Group():
                    gr.Markdown("#### Step 1: ASR Baseline")
                    whisper_btn = gr.Button("Run Whisper Tiny")
                    w_raw = gr.Textbox(label="Whisper Raw Transcript")
                    w_norm = gr.Textbox(label="Whisper Normalized")
                gr.Markdown("---")
                with gr.Group():
                    gr.Markdown("#### Step 2: Neural Reconstruction")
                    model_btn = gr.Button("Run Our Correction Model", variant="primary")
                    final_out = gr.Textbox(label="DSR Lab Prediction (5K Model)")

with gr.Tab("πŸ“Š Research Statistics"):
gr.Markdown("# πŸ”¬ Performance Evaluation")
with gr.Row():
with gr.Column():
gr.Markdown("""
### πŸ“ Metric: Exact Match Accuracy
Accuracy is the percentage of samples where the **normalized prediction** (lowercase, no punctuation) exactly matches the **normalized ground truth**.
""")
            with gr.Column():
                gr.Markdown("""
                ### πŸ§ͺ Model Definitions
                * **5K Pure Model:** Trained on real-world Torgo articulatory distortions. Optimized for phonetic fidelity.
                * **10K Triple-Mix Model:** Includes synthetic data and anchor samples; used for generalization (LOSO) testing.
                """)

        gr.Markdown("---")
        gr.Markdown("## 1. Torgo In-Domain Analysis (By Speaker)")
        gr.DataFrame(get_indomain_breakdown())
        gr.Markdown("## 2. Experimental Milestone Summary")
        gr.DataFrame(get_experimental_summary())
        gr.Markdown("""
        ### πŸ” Key Discovery: The Acoustic Floor
        Our research found that the **5K Pure Model** achieved higher accuracy on both in-domain and zero-shot tasks. This suggests an **'Acoustic Floor'**: real-world phonetic distortions ground the model more effectively than synthetic linguistic diversity does.
        """)

    # --- Event Handlers ---
    # Dataset channel: load sample -> update state -> update reference text/metadata
    load_btn.click(
        get_sample_logic,
        inputs=speaker_input,
        outputs=[active_audio_path, gt_box, meta_display]
    )

    # Personal channel: use audio -> update state -> clear the reference fields
    user_load_btn.click(
        lambda x: (x, "User Recorded (No Ground Truth)", {"Dataset": "Custom", "Severity": "N/A"}),
        inputs=user_audio,
        outputs=[active_audio_path, gt_box, meta_display]
    )

    # Step 1: Whisper baseline (uses the shared state)
    whisper_btn.click(
        process_audio_step_1,
        inputs=active_audio_path,
        outputs=[w_raw, w_norm]
    )

    # Step 2: correction model (uses the shared state + Whisper result)
    model_btn.click(
        process_audio_step_2,
        inputs=[active_audio_path, w_norm],
        outputs=final_out
    )

demo.launch()