Automatic Speech Recognition
Safetensors
Chinese
whisper
File size: 8,621 Bytes
fbcb7d1
 
 
 
 
 
 
df83b8b
 
 
 
 
fbcb7d1
 
 
 
 
 
 
 
df83b8b
 
 
 
fbcb7d1
df83b8b
 
 
 
fbcb7d1
 
 
 
 
 
 
 
 
df83b8b
fbcb7d1
df83b8b
 
fbcb7d1
 
df83b8b
 
 
 
fbcb7d1
 
 
 
df83b8b
fbcb7d1
df83b8b
 
 
 
 
 
 
 
 
fbcb7d1
df83b8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbcb7d1
 
df83b8b
fbcb7d1
 
 
 
 
 
 
 
df83b8b
fbcb7d1
 
 
df83b8b
fbcb7d1
 
 
 
df83b8b
 
 
 
 
 
 
 
 
 
 
fbcb7d1
 
 
df83b8b
fbcb7d1
 
 
 
 
 
 
 
 
 
df83b8b
fbcb7d1
 
 
 
 
df83b8b
 
fbcb7d1
 
 
 
 
df83b8b
 
 
 
 
 
 
 
 
 
 
fbcb7d1
 
 
df83b8b
 
fbcb7d1
 
 
 
 
df83b8b
 
 
fbcb7d1
df83b8b
fbcb7d1
df83b8b
 
 
 
 
 
 
fbcb7d1
 
 
 
 
 
 
 
 
 
 
 
 
df83b8b
 
fbcb7d1
 
 
 
df83b8b
fbcb7d1
 
 
 
 
 
df83b8b
fbcb7d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df83b8b
fbcb7d1
 
 
 
 
 
 
 
 
 
 
 
 
 
df83b8b
 
 
 
 
fbcb7d1
 
 
 
 
 
 
 
 
df83b8b
fbcb7d1
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
"""
Speech-to-Text Model Arena
A Gradio demo for comparing multiple STT models side-by-side.
"""

import gradio as gr
import logging
import os
import requests
from dotenv import load_dotenv

load_dotenv()

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("stt_arena")

# Settings are read from the process environment, which load_dotenv() above
# may have populated from a .env file.
# NOTE(review): huggingface_hub also reads HF_ENDPOINT (as the Hub base URL);
# pointing it at an inference-endpoint URL may redirect Hub traffic made by
# libraries in this process — consider a dedicated variable name.
HF_ENDPOINT = os.getenv("HF_ENDPOINT")  # URL of the StutteredSpeechASR endpoint
HF_API_KEY = os.getenv("HF_API_KEY")  # bearer token sent with every API request
WHISPER_API_URL = "https://router.huggingface.co/hf-inference/models/openai/whisper-large-v3"
WHISPER_TURBO_API_URL = "https://router.huggingface.co/hf-inference/models/openai/whisper-large-v3-turbo"

if HF_ENDPOINT:
    logger.info("Using Hugging Face Endpoint: %s", HF_ENDPOINT)
else:
    # There is no local-model fallback in this app: without HF_ENDPOINT the
    # stuttered model cannot be queried at all, so say so plainly.
    logger.warning("HF_ENDPOINT not set, StutteredSpeechASR will be unavailable")

if not HF_API_KEY:
    # Every model is served through an authenticated HTTP endpoint; surface
    # the misconfiguration at startup rather than on the first request.
    logger.warning("HF_API_KEY not set, all models will be unavailable")

# Models compared side by side, in the order their result cards are shown.
# "id" is what run_inference() dispatches on and "name" is the card title;
# "hf_id" and "description" are metadata not read by the visible code.
MODELS = [
    dict(
        name="πŸ—£οΈ StutteredSpeechASR",
        id="stuttered",
        hf_id="AImpower/StutteredSpeechASR",
        description="Whisper fine-tuned for stuttered speech (Mandarin)",
    ),
    dict(
        name="πŸŽ™οΈ Whisper Large V3",
        id="whisper",
        hf_id="openai/whisper-large-v3",
        description="OpenAI Whisper Large V3 model (via HF Inference API)",
    ),
    dict(
        name="πŸ”Š Whisper Large V3 Turbo",
        id="whisper_turbo",
        hf_id="openai/whisper-large-v3-turbo",
        description="OpenAI Whisper Large V3 Turbo (via HF Inference API)",
    ),
]


def run_api_inference(audio_path: str, api_url: str, model_name: str) -> str:
    """
    Run inference using any Hugging Face API endpoint.

    The raw bytes of ``audio_path`` are POSTed to ``api_url`` with
    ``HF_API_KEY`` as a bearer token.

    Args:
        audio_path: Path to the audio file
        api_url: The API endpoint URL
        model_name: Name of the model, used in log and error messages

    Returns:
        The whitespace-stripped transcription on HTTP 200. For any other
        status, a user-facing message describing the failure (endpoint
        paused, model loading, service unavailable, or the server's error
        text / HTTP status).

    Raises:
        ValueError: If HF_API_KEY is not set, or (as a ValueError subclass
            raised by ``response.json()``) if a 200 body is not valid JSON.
        OSError: If the audio file cannot be read.
        requests.RequestException: On connection errors or the 120 s timeout.
    """
    if not HF_API_KEY:
        raise ValueError("HF_API_KEY must be set in environment variables")

    logger.info("Running inference via %s", model_name)

    with open(audio_path, "rb") as f:
        audio_bytes = f.read()

    headers = {
        "Authorization": f"Bearer {HF_API_KEY}",
        # NOTE(review): every upload is labelled audio/wav even when the user
        # supplied another format — confirm the endpoints sniff the content.
        "Content-Type": "audio/wav",
    }

    response = requests.post(
        api_url,
        headers=headers,
        data=audio_bytes,
        timeout=120,
    )

    if response.status_code != 200:
        logger.error(
            "%s error: %s - %s", model_name, response.status_code, response.text
        )

        # The body is expected to be JSON like {"error": "..."}; parse it
        # defensively so a non-JSON, empty, or non-dict body still produces a
        # meaningful message instead of masking the status code.
        try:
            error_data = response.json()
        except ValueError:
            error_data = None
        error_msg = ""
        if isinstance(error_data, dict):
            error_msg = str(error_data.get("error") or "")
        lowered = error_msg.lower()

        if "paused" in lowered:
            return f"⏸️ The {model_name} endpoint is currently paused. Please contact the maintainer to restart it."
        if "loading" in lowered:
            return f"⏳ {model_name} is loading. Please wait and try again."
        if response.status_code == 503:
            return f"πŸ”„ {model_name} service is temporarily unavailable. Please try again."
        if error_msg:
            return f"❌ {model_name} Error: {error_msg}"
        return f"❌ {model_name} Error: HTTP {response.status_code}"

    result = response.json()
    logger.debug("%s response: %s", model_name, result)

    # Accept either a dict carrying "text" (or "transcription"), or a list
    # whose first item holds the text; anything else is stringified.
    if isinstance(result, dict):
        transcription = result.get("text") or result.get("transcription") or ""
    elif isinstance(result, list) and result:
        first = result[0]
        transcription = (first.get("text") or "") if isinstance(first, dict) else first
    else:
        transcription = result

    return str(transcription).strip()


def run_inference(audio_path: str, model_config: dict) -> str:
    """
    Run inference on a single model.

    Args:
        audio_path: Path to the audio file, or None when no audio was given.
        model_config: One entry of MODELS; its "id" selects the endpoint and
            its "name" is used in log messages.

    Returns:
        The transcription, or a user-facing warning/error string. Any
        exception raised while querying the model is logged and converted to
        an "❌ Error: ..." message instead of propagating to the caller.
    """
    if audio_path is None:
        logger.warning("No audio provided")
        return "⚠️ No audio provided. Please record or upload audio first."

    try:
        logger.info("Running inference with model: %s", model_config["name"])
        logger.debug("Audio path: %s", audio_path)

        model_id = model_config["id"]
        # model id -> (endpoint URL, display name). Built per call so it
        # reflects the current module-level endpoint settings.
        endpoints = {
            "stuttered": (HF_ENDPOINT, "StutteredSpeechASR"),
            "whisper": (WHISPER_API_URL, "Whisper Large V3"),
            "whisper_turbo": (WHISPER_TURBO_API_URL, "Whisper Large V3 Turbo"),
        }
        if model_id not in endpoints:
            raise ValueError(f"Unknown model id: {model_id!r}")
        if not HF_API_KEY:
            raise ValueError("HF_API_KEY must be set to use this model")

        api_url, display_name = endpoints[model_id]
        if not api_url:
            # HF_ENDPOINT comes from the environment and may be unset.
            raise ValueError("HF_ENDPOINT must be set to use this model")

        return run_api_inference(audio_path, api_url, display_name)

    except Exception as e:
        # Boundary handler: report the failure in this model's output box
        # rather than raising into the UI callback.
        logger.exception("Error during inference with %s: %s", model_config["name"], e)
        return f"❌ Error: {str(e)}"


def run_all_models(audio):
    """
    Run inference on all models concurrently.

    Each model is queried through run_inference(), which issues a blocking
    HTTP request; running them on a thread pool makes the total latency
    roughly that of the slowest model instead of the sum of all of them.

    Args:
        audio: Audio input from Gradio component (a file path; presumably
            None when nothing was recorded or uploaded).

    Returns:
        List of transcription results, one per entry in MODELS and in the
        same order as MODELS.
    """
    from concurrent.futures import ThreadPoolExecutor

    logger.info("Starting inference on %d models", len(MODELS))

    # ThreadPoolExecutor rejects max_workers=0, so guard an empty MODELS.
    with ThreadPoolExecutor(max_workers=max(1, len(MODELS))) as pool:
        # map() yields results in input order regardless of completion order,
        # which keeps them aligned with the output components.
        results = list(pool.map(lambda cfg: run_inference(audio, cfg), MODELS))

    logger.info("All models completed")
    return results


def load_css():
    """Return the text of ``style.css`` located beside this module.

    If the stylesheet does not exist, a warning is logged and an empty
    string is returned so the interface can still be built without it.
    """
    module_dir = os.path.dirname(__file__)
    css_path = os.path.join(module_dir, "style.css")
    try:
        stylesheet = open(css_path, "r", encoding="utf-8")
    except FileNotFoundError:
        logger.warning(f"CSS file not found at {css_path}")
        return ""
    with stylesheet:
        return stylesheet.read()


# Build the Gradio interface
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="StutteredSpeechASR Research Demo",
    css=load_css()
) as demo:

    # Title and Description
    gr.Markdown(
        """
        <div style="text-align: center; max-width: 800px; margin: 0 auto;">
        
        # πŸ—£οΈ StutteredSpeechASR Research Demo
        
        ### Fine-tuned Whisper model for stuttered speech recognition
        
        This demo showcases our **StutteredSpeechASR** model, a Whisper model fine-tuned specifically 
        for stuttered speech (Mandarin). Compare its performance against baseline Whisper models 
        to see the improvement on stuttered speech patterns.
        
        Upload an audio file or record using your microphone to test the models.
        
        </div>
        """,
        elem_classes=["title-text"]
    )

    gr.Markdown("---")

    # Audio Input Section
    with gr.Group():
        gr.Markdown("### 🎀 Audio Input")
        audio_input = gr.Audio(
            sources=["microphone", "upload"],
            type="filepath",
            label="Record or Upload Audio",
            streaming=False,
            editable=True,
        )

    # Run Button
    run_button = gr.Button(
        "πŸš€ Compare Models",
        variant="primary",
        size="lg",
        elem_classes=["run-button"]
    )

    gr.Markdown("---")
    gr.Markdown("### πŸ“Š Model Comparison Results")

    # Model Output Cards
    with gr.Row(equal_height=True):
        output_components = []

        for model in MODELS:
            with gr.Column(elem_classes=["model-card"]):
                gr.Markdown(f"## {model['name']}")

                text_output = gr.Textbox(
                    label="Transcription",
                    placeholder="Transcribed text will appear here...",
                    lines=4,
                    interactive=False,
                )

                output_components.append(text_output)

    run_button.click(
        fn=run_all_models,
        inputs=[audio_input],
        outputs=output_components,
        show_progress=True,
    )

    # Footer
    gr.Markdown("---")
    gr.Markdown(
        """
        <center>
        
        **πŸ’‘ Research Note:**
        - The StutteredSpeechASR model is designed to better handle stuttered speech patterns
        - For best results, use clear audio recordings
        
        *Research Demo | AImpower StutteredSpeechASR*
        
        </center>
        """,
        elem_classes=["footer"]
    )


# Launch the app
if __name__ == "__main__":
    configured_names = [cfg['name'] for cfg in MODELS]
    launch_options = {
        "share": False,  # no public gradio.live tunnel
        "server_name": "0.0.0.0",  # bind on all network interfaces
        "server_port": 7860,
        "show_error": True,
    }

    logger.info("Starting StutteredSpeechASR Research Demo")
    logger.info(f"Models configured: {configured_names}")
    demo.launch(**launch_options)
    logger.info("Application shutdown")