| | """ |
| | Speech-to-Text Model Arena |
| | A Gradio demo for comparing multiple STT models side-by-side. |
| | """ |
| |
|
| | import gradio as gr |
| | import logging |
| | import os |
| | import requests |
| | from dotenv import load_dotenv |
| |
|
# Pull settings from a local .env file (if present) into the process
# environment before any of them are read below.
load_dotenv()

# Application-wide logging: timestamped records at INFO level and above.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("stt_arena")

# HF_ENDPOINT is the URL used for StutteredSpeechASR requests; HF_API_KEY is
# the bearer token sent with every inference request.
HF_ENDPOINT = os.getenv("HF_ENDPOINT")
HF_API_KEY = os.getenv("HF_API_KEY")
WHISPER_API_URL = "https://router.huggingface.co/hf-inference/models/openai/whisper-large-v3"
WHISPER_TURBO_API_URL = "https://router.huggingface.co/hf-inference/models/openai/whisper-large-v3-turbo"

if HF_ENDPOINT:
    logger.info("Using Hugging Face Endpoint: %s", HF_ENDPOINT)
else:
    # This module has no local-model fallback: without an endpoint the
    # StutteredSpeechASR column can only report an error.
    logger.warning("HF_ENDPOINT not set, StutteredSpeechASR will be unavailable")

if not HF_API_KEY:
    # Every model is reached through an authenticated HTTP API.
    logger.warning("HF_API_KEY not set, all models will return an error")
| |
|
# Registry of the models compared in the arena, in display order.
# Each entry carries:
#   name        - heading shown above the model's output box
#   id          - internal key used to pick the inference route
#   hf_id       - Hugging Face repository identifier
#   description - short human-readable summary
MODELS = [
    {"name": name, "id": model_id, "hf_id": hf_id, "description": description}
    for name, model_id, hf_id, description in (
        (
            "๐ฃ๏ธ StutteredSpeechASR",
            "stuttered",
            "AImpower/StutteredSpeechASR",
            "Whisper fine-tuned for stuttered speech (Mandarin)",
        ),
        (
            "๐๏ธ Whisper Large V3",
            "whisper",
            "openai/whisper-large-v3",
            "OpenAI Whisper Large V3 model (via HF Inference API)",
        ),
        (
            "๐ Whisper Large V3 Turbo",
            "whisper_turbo",
            "openai/whisper-large-v3-turbo",
            "OpenAI Whisper Large V3 Turbo (via HF Inference API)",
        ),
    )
]
| |
|
| |
|
def _describe_api_error(response: "requests.Response", model_name: str) -> str:
    """Turn a non-200 API response into a short user-facing message."""
    try:
        payload = response.json()
    except ValueError:
        # Error bodies are not always JSON (e.g. an HTML gateway page).
        payload = None

    raw_error = payload.get("error", "") if isinstance(payload, dict) else ""
    # The "error" field is not guaranteed to be a string; normalise it so the
    # keyword checks below cannot raise.
    error_msg = "" if raw_error is None else str(raw_error)
    lowered = error_msg.lower()

    if "paused" in lowered:
        return f"โธ๏ธ The {model_name} endpoint is currently paused. Please contact the maintainer to restart it."
    if "loading" in lowered:
        return f"โณ {model_name} is loading. Please wait and try again."
    if response.status_code == 503:
        return f"๐ {model_name} service is temporarily unavailable. Please try again."
    if error_msg:
        return f"โ {model_name} Error: {error_msg}"
    return f"โ {model_name} Error: HTTP {response.status_code}"


def _extract_transcription(result) -> str:
    """Pull the transcription text out of a decoded JSON response body."""
    if isinstance(result, dict):
        text = result.get("text") or result.get("transcription") or ""
    elif isinstance(result, list) and result:
        first = result[0]
        text = (first.get("text") or "") if isinstance(first, dict) else first
    else:
        text = result
    # str() guards against non-string payload values before stripping.
    return str(text).strip()


def run_api_inference(audio_path: str, api_url: str, model_name: str) -> str:
    """
    Run inference using any Hugging Face API endpoint.

    The audio file is read fully into memory and POSTed as the request body
    with a bearer token taken from ``HF_API_KEY``.

    Args:
        audio_path: Path to the audio file
        api_url: The API endpoint URL
        model_name: Name of the model for log and error messages

    Returns:
        The stripped transcription on HTTP 200; otherwise a user-facing
        message describing the failure (paused, loading, unavailable, or the
        error reported by the API).

    Raises:
        ValueError: If ``HF_API_KEY`` is not set.

    Exceptions raised while opening the file, sending the request, or decoding
    a successful response body are not caught here and propagate to the caller.
    """
    if not HF_API_KEY:
        raise ValueError("HF_API_KEY must be set in environment variables")

    logger.info(f"Running inference via {model_name}")

    with open(audio_path, "rb") as f:
        audio_bytes = f.read()

    # NOTE(review): the content type is always declared as WAV, even though the
    # UI also accepts uploads - confirm the endpoints tolerate other formats.
    headers = {
        "Authorization": f"Bearer {HF_API_KEY}",
        "Content-Type": "audio/wav",
    }

    response = requests.post(
        api_url,
        headers=headers,
        data=audio_bytes,
        timeout=120,
    )

    if response.status_code != 200:
        logger.error(f"{model_name} error: {response.status_code} - {response.text}")
        return _describe_api_error(response, model_name)

    result = response.json()
    logger.debug(f"{model_name} response: {result}")

    return _extract_transcription(result)
| |
|
| |
|
def run_inference(audio_path: str, model_config: dict) -> str:
    """
    Run inference on a single model.

    The model's ``id`` selects which API endpoint is called. Failures are
    never raised to the caller; they are logged and returned as a message so
    that one failing model does not break the whole comparison.

    Args:
        audio_path: Path to the audio file (``None`` or empty when the user
            has not provided audio)
        model_config: Model configuration dictionary (an entry of ``MODELS``)

    Returns:
        Transcribed text, or a user-facing warning/error message
    """
    if not audio_path:
        logger.warning("No audio provided")
        return "โ ๏ธ No audio provided. Please record or upload audio first."

    # Looked up defensively so the error handler below cannot itself fail.
    display_name = model_config.get("name", "<unnamed model>")

    try:
        logger.info(f"Running inference with model: {display_name}")
        logger.debug(f"Audio path: {audio_path}")

        # Model id -> (endpoint URL, name used in API log/error messages).
        routes = {
            "stuttered": (HF_ENDPOINT, "StutteredSpeechASR"),
            "whisper": (WHISPER_API_URL, "Whisper Large V3"),
            "whisper_turbo": (WHISPER_TURBO_API_URL, "Whisper Large V3 Turbo"),
        }

        model_id = model_config["id"]
        if model_id not in routes:
            raise ValueError(f"Unknown model id: {model_id!r}")

        if not HF_API_KEY:
            raise ValueError("HF_API_KEY must be set to use this model")

        api_url, api_name = routes[model_id]
        if not api_url:
            # Only the dedicated endpoint is configurable, so report the
            # setting that is actually missing instead of blaming the API key.
            raise ValueError(f"HF_ENDPOINT must be set to use {api_name}")

        return run_api_inference(audio_path, api_url, api_name)

    except Exception as e:
        # Top-level boundary for one model: log with traceback, report as text.
        logger.exception(f"Error during inference with {display_name}: {str(e)}")
        return f"โ Error: {str(e)}"
| |
|
| |
|
def run_all_models(audio):
    """
    Transcribe one audio input with every configured model, one after another.

    Args:
        audio: Audio input from the Gradio component, forwarded unchanged to
            ``run_inference`` for each model

    Returns:
        List with one result string per entry of ``MODELS``, in the same order
    """
    logger.info(f"Starting inference on {len(MODELS)} models")

    # One call per model, in registry order.
    transcriptions = [run_inference(audio, config) for config in MODELS]

    logger.info("All models completed")
    return transcriptions
| |
|
| |
|
def load_css():
    """Return the text of ``style.css`` stored next to this module.

    A missing stylesheet is logged as a warning and an empty string is
    returned instead.
    """
    module_dir = os.path.dirname(__file__)
    css_path = os.path.join(module_dir, "style.css")
    try:
        with open(css_path, "r", encoding="utf-8") as stylesheet:
            css_text = stylesheet.read()
    except FileNotFoundError:
        logger.warning(f"CSS file not found at {css_path}")
        css_text = ""
    return css_text
| |
|
| |
|
| | |
# Build the single-page interface: header, audio input, run button, one
# result column per model, and a footer.
with gr.Blocks(
    title="StutteredSpeechASR Research Demo",
    theme=gr.themes.Soft(),
    css=load_css(),
) as demo:

    # Header
    gr.Markdown(
        value="""
        <div style="text-align: center; max-width: 800px; margin: 0 auto;">

        # ๐ฃ๏ธ StutteredSpeechASR Research Demo

        ### Fine-tuned Whisper model for stuttered speech recognition

        This demo showcases our **StutteredSpeechASR** model, a Whisper model fine-tuned specifically
        for stuttered speech (Mandarin). Compare its performance against baseline Whisper models
        to see the improvement on stuttered speech patterns.

        Upload an audio file or record using your microphone to test the models.

        </div>
        """,
        elem_classes=["title-text"],
    )

    gr.Markdown("---")

    # Audio input: the component hands a file path to the callback.
    with gr.Group():
        gr.Markdown("### ๐ค Audio Input")
        audio_input = gr.Audio(
            label="Record or Upload Audio",
            sources=["microphone", "upload"],
            type="filepath",
            streaming=False,
            editable=True,
        )

    # Trigger for running every model on the current audio.
    run_button = gr.Button(
        value="๐ Compare Models",
        variant="primary",
        size="lg",
        elem_classes=["run-button"],
    )

    gr.Markdown("---")
    gr.Markdown("### ๐ Model Comparison Results")

    # One column per model; the textboxes are collected in registry order so
    # they line up with the list returned by run_all_models.
    output_components = []
    with gr.Row(equal_height=True):
        for model_config in MODELS:
            with gr.Column(elem_classes=["model-card"]):
                gr.Markdown(f"## {model_config['name']}")
                output_components.append(
                    gr.Textbox(
                        label="Transcription",
                        placeholder="Transcribed text will appear here...",
                        lines=4,
                        interactive=False,
                    )
                )

    run_button.click(
        fn=run_all_models,
        inputs=[audio_input],
        outputs=output_components,
        show_progress=True,
    )

    # Footer
    gr.Markdown("---")
    gr.Markdown(
        value="""
        <center>

        **๐ก Research Note:**
        - The StutteredSpeechASR model is designed to better handle stuttered speech patterns
        - For best results, use clear audio recordings

        *Research Demo | AImpower StutteredSpeechASR*

        </center>
        """,
        elem_classes=["footer"],
    )
| |
|
| |
|
| | |
# Script entry point: announce the configuration, then serve the demo.
if __name__ == "__main__":
    configured_names = [entry["name"] for entry in MODELS]
    logger.info("Starting StutteredSpeechASR Research Demo")
    logger.info(f"Models configured: {configured_names}")

    # Listen on every interface at port 7860, with no public share link.
    launch_options = {
        "share": False,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)

    logger.info("Application shutdown")
| |
|