import gradio as gr
import os
import io
import re
import random
import librosa
import soundfile as sf
from gradio_client import Client, handle_file
from transformers import pipeline
from datasets import load_dataset, Audio
from stats_data import get_indomain_breakdown, get_experimental_summary, SPEAKER_META

# Initialize the baseline ASR (strict English decoding, repetition penalty 3.0)
print("Initializing Whisper Tiny Baseline...")
whisper_asr = pipeline(
    "automatic-speech-recognition", 
    model="openai/whisper-tiny",
    generate_kwargs={
        "language": "en", 
        "task": "transcribe", 
        "repetition_penalty": 3.0
    }
)
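# Note: pinning language/task disables Whisper's automatic language detection,
# and the high repetition penalty suppresses the looping transcripts that small
# Whisper models tend to produce on atypical (dysarthric) speech.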

# Configuration from Environment Variables
HF_TOKEN = os.getenv("HF_TOKEN")
PRIVATE_BACKEND_URL = "st192011/Torgo-DSR-Private"
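# HF_TOKEN must grant read access to the private backend Space; if it is
# missing or invalid, the Client(...) call in Step 2 raises and the UI shows
# the fallback error message instead of a prediction.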

def normalize_text(text):
    """Lowercases, strips punctuation, and collapses whitespace for exact-match scoring."""
    if not text:
        return ""
    text = re.sub(r'[^\w\s]', '', text).lower().strip()
    return " ".join(text.split())

def format_audio(audio_path):
    """Ensures audio is 16kHz mono to match ASR training conditions."""
    y, sr = librosa.load(audio_path, sr=16000)
    out_path = "formatted_input.wav"
    sf.write(out_path, y, sr)
    return out_path
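# Note: the fixed output filename means concurrent sessions would overwrite
# each other's audio; acceptable for a single-user demo Space.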

# --- Logic: Data Loading ---
def get_sample_logic(speaker_id):
    try:
        if "UA" in speaker_id:
            # UA-Speech Access (Direct pull for F02)
            dataset = load_dataset("resproj007/uaspeech_female", split="train", streaming=True)
            dataset = dataset.cast_column("audio", Audio(decode=False))
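            # streaming=True avoids downloading the full corpus up front, and
            # decode=False yields the raw bytes so we can decode with librosa
            # below instead of the dataset's default audio decoder.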
            # UA is small, skip slightly for variety
            sample = next(iter(dataset.skip(random.randint(0, 30))))
            gt_text = sample.get('text') or sample.get('transcription') or sample.get('sentence')
        else:
            # Torgo Access (Manual filtering as per Colab fix)
            dataset = load_dataset("abnerh/TORGO-database", split="train", streaming=True)
            dataset = dataset.cast_column("audio", Audio(decode=False))
            
            def filter_spk(x):
                sid = str(x.get('speaker_id', '')).upper()
                if not sid or sid == "NONE":
                    sid = os.path.basename(x['audio']['path']).split('_')[0].upper()
                return sid == speaker_id
            
            speaker_ds = dataset.filter(filter_spk)
            sample = next(iter(speaker_ds.shuffle(buffer_size=10)))
            gt_text = sample.get('transcription') or sample.get('text')

        # Decode Bytes manually to bypass torchcodec errors
        audio_bytes = sample['audio']['bytes']
        audio_data, sr = librosa.load(io.BytesIO(audio_bytes), sr=16000)
        temp_path = "dataset_sample.wav"
        sf.write(temp_path, audio_data, sr)
        
        return temp_path, gt_text.lower().strip(), SPEAKER_META[speaker_id]
    except Exception as e:
        return None, f"Dataset Error: {e}", {}

# --- Logic: Model Processing ---
def process_audio_step_1(audio_path):
    """Runs Whisper Baseline and returns normalized text."""
    if not audio_path: return "No audio loaded", ""
    
    # Pre-process audio format to 16k
    formatted_path = format_audio(audio_path)
    
    # Run Whisper
    result = whisper_asr(formatted_path)
    raw_w = result["text"]
    norm_w = normalize_text(raw_w)
    return raw_w, norm_w

def process_audio_step_2(audio_path, norm_whisper):
    """Sends audio + normalized whisper to the Private Model API."""
    if not audio_path or not norm_whisper: 
        return "Please load data and run Whisper (Step 1) first."
    
    try:
        # Connect to the private API
        client = Client(PRIVATE_BACKEND_URL, token=HF_TOKEN)
        
        # FIX: Wrap audio_path with handle_file()
        # This sends the metadata required by Pydantic ('gradio.FileData')
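        # The keyword names and api_name below must match the endpoint that the
        # private Space exposes; any mismatch surfaces as the exception handled
        # further down.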
        prediction = client.predict(
            audio_path=handle_file(audio_path), 
            whisper_norm=norm_whisper, 
            api_name="/predict_dsr"
        )
        return prediction
    except Exception as e:
        return f"Backend Connection Required. Details: {e}"

# --- UI Construction ---
with gr.Blocks(theme=gr.themes.Soft(), title="Torgo DSR Lab") as demo:
    gr.Markdown("# βš—οΈ Torgo DSR Lab")
    gr.Markdown("Neural Reconstruction Layer for Torgo and UA-Speech Zero-Shot.")
    
    # Hidden state to store the path of the currently active audio
    active_audio_path = gr.State("")

    with gr.Tab("πŸ”¬ Laboratory"):
        with gr.Row():
            # LEFT COLUMN: Data Input
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### Channel A: Research Datasets")
                    speaker_input = gr.Dropdown(sorted(list(SPEAKER_META.keys())), label="Select Speaker Profile", value="F01")
                    load_btn = gr.Button("Load Sample from Dataset")
                    gt_box = gr.Textbox(label="Ground Truth (Reference)", interactive=False)
                    meta_display = gr.JSON(label="Speaker Metadata")
                
                gr.Markdown("---")
                
                with gr.Group():
                    gr.Markdown("### Channel B: Personal Input")
                    user_audio = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Record or Upload Audio")
                    user_load_btn = gr.Button("Use This Audio")

            # RIGHT COLUMN: Transcripts
            with gr.Column(scale=2):
                gr.Markdown("### Analysis & Reconstruction")
                
                with gr.Group():
                    gr.Markdown("#### Step 1: ASR Baseline")
                    whisper_btn = gr.Button("Run Whisper Tiny")
                    w_raw = gr.Textbox(label="Whisper Raw Transcript")
                    w_norm = gr.Textbox(label="Whisper Normalized")
                
                gr.Markdown("---")
                
                with gr.Group():
                    gr.Markdown("#### Step 2: Neural Reconstruction")
                    model_btn = gr.Button("Run Our Correction Model", variant="primary")
                    final_out = gr.Textbox(label="DSR Lab Prediction (5K Model)")

    with gr.Tab("πŸ“Š Research Statistics"):
        gr.Markdown("# πŸ”¬ Performance Evaluation")
        
        with gr.Row():
            with gr.Column():
                gr.Markdown("""
                ### πŸ“ Metric: Exact Match Accuracy
                Accuracy is the percentage of samples where the **normalized prediction** (lowercase, no punctuation) exactly matches the **normalized ground truth**.
                """)
            
            with gr.Column():
                gr.Markdown("""
                ### πŸ§ͺ Model Definitions
                * **5K Pure Model:** Trained on real-world Torgo articulatory distortions. Optimized for phonetic fidelity.
                * **10K Triple-Mix Model:** Includes synthetic data and anchors; utilized for generalization (LOSO) testing.
                """)

        gr.Markdown("---")
        gr.Markdown("## 1. Torgo In-Domain Analysis (By Speaker)")
        gr.DataFrame(get_indomain_breakdown())
        
        gr.Markdown("## 2. Experimental Milestone Summary")
        gr.DataFrame(get_experimental_summary())

        gr.Markdown("""
        ### πŸ” Key Discovery: The Acoustic Floor
        Our research found that the **5K Pure Model** achieved higher accuracy in both in-domain and zero-shot tasks. This suggests an **'Acoustic Floor'** exists where real-world phonetic distortions are more valuable for model grounding than synthetic linguistic diversity.
        """)

    # --- Event Handlers ---

    # Dataset Channel: Load -> Update State -> Update UI Text/Meta
    load_btn.click(
        get_sample_logic, 
        inputs=speaker_input, 
        outputs=[active_audio_path, gt_box, meta_display]
    )

    # Personal Channel: Use Audio -> Update State -> Clear Reference
    user_load_btn.click(
        lambda x: (x, "User Recorded (No Ground Truth)", {"Dataset": "Custom", "Severity": "N/A"}), 
        inputs=user_audio, 
        outputs=[active_audio_path, gt_box, meta_display]
    )

    # Step 1: Whisper (Uses State)
    whisper_btn.click(
        process_audio_step_1, 
        inputs=active_audio_path, 
        outputs=[w_raw, w_norm]
    )

    # Step 2: Model (Uses State + Whisper result)
    model_btn.click(
        process_audio_step_2, 
        inputs=[active_audio_path, w_norm], 
        outputs=final_out
    )

demo.launch()
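
# launch() relies on Gradio defaults; on Hugging Face Spaces the host/port are
# injected by the platform, so no server arguments are needed here. For local
# debugging, demo.launch(debug=True) blocks and prints tracebacks to the console.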