""" Speech-to-Text Model Arena A Gradio demo for comparing multiple STT models side-by-side. """ import gradio as gr import logging import os import requests from dotenv import load_dotenv load_dotenv() logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) logger = logging.getLogger("stt_arena") HF_ENDPOINT = os.getenv("HF_ENDPOINT") HF_API_KEY = os.getenv("HF_API_KEY") WHISPER_API_URL = "https://router.huggingface.co/hf-inference/models/openai/whisper-large-v3" WHISPER_TURBO_API_URL = "https://router.huggingface.co/hf-inference/models/openai/whisper-large-v3-turbo" if HF_ENDPOINT: logger.info(f"Using Hugging Face Endpoint: {HF_ENDPOINT}") else: logger.warning("HF_ENDPOINT not set, StutteredSpeechASR will use local model") MODELS = [ { "name": "π£οΈ StutteredSpeechASR", "id": "stuttered", "hf_id": "AImpower/StutteredSpeechASR", "description": "Whisper fine-tuned for stuttered speech (Mandarin)", }, { "name": "ποΈ Whisper Large V3", "id": "whisper", "hf_id": "openai/whisper-large-v3", "description": "OpenAI Whisper Large V3 model (via HF Inference API)", }, { "name": "π Whisper Large V3 Turbo", "id": "whisper_turbo", "hf_id": "openai/whisper-large-v3-turbo", "description": "OpenAI Whisper Large V3 Turbo (via HF Inference API)", }, ] def run_api_inference(audio_path: str, api_url: str, model_name: str) -> str: """ Run inference using any Hugging Face API endpoint. Args: audio_path: Path to the audio file api_url: The API endpoint URL model_name: Name of the model for error messages Returns: Transcribed text """ if not HF_API_KEY: raise ValueError("HF_API_KEY must be set in environment variables") logger.info(f"Running inference via {model_name}") with open(audio_path, "rb") as f: audio_bytes = f.read() headers = { "Authorization": f"Bearer {HF_API_KEY}", "Content-Type": "audio/wav", } response = requests.post( api_url, headers=headers, data=audio_bytes, timeout=120, ) if response.status_code != 200: logger.error(f"{model_name} error: {response.status_code} - {response.text}") try: error_data = response.json() error_msg = error_data.get("error", "") if "paused" in error_msg.lower(): return f"βΈοΈ The {model_name} endpoint is currently paused. Please contact the maintainer to restart it." elif "loading" in error_msg.lower(): return f"β³ {model_name} is loading. Please wait and try again." elif response.status_code == 503: return f"π {model_name} service is temporarily unavailable. Please try again." else: return f"β {model_name} Error: {error_msg}" except: return f"β {model_name} Error: HTTP {response.status_code}" result = response.json() logger.debug(f"{model_name} response: {result}") if isinstance(result, dict): transcription = result.get("text", "") or result.get("transcription", "") elif isinstance(result, list) and len(result) > 0: transcription = result[0].get("text", "") if isinstance(result[0], dict) else str(result[0]) else: transcription = str(result) return transcription.strip() def run_inference(audio_path: str, model_config: dict) -> str: """ Run inference on a single model. Args: audio_path: Path to the audio file model_config: Model configuration dictionary Returns: Transcribed text """ if audio_path is None: logger.warning("No audio provided") return "β οΈ No audio provided. Please record or upload audio first." try: logger.info(f"Running inference with model: {model_config['name']}") logger.debug(f"Audio path: {audio_path}") if model_config["id"] == "stuttered" and HF_ENDPOINT and HF_API_KEY: return run_api_inference(audio_path, HF_ENDPOINT, "StutteredSpeechASR") if model_config["id"] == "whisper" and HF_API_KEY: return run_api_inference(audio_path, WHISPER_API_URL, "Whisper Large V3") if model_config["id"] == "whisper_turbo" and HF_API_KEY: return run_api_inference(audio_path, WHISPER_TURBO_API_URL, "Whisper Large V3 Turbo") raise ValueError("HF_API_KEY must be set to use this model") except Exception as e: logger.error(f"Error during inference with {model_config['name']}: {str(e)}", exc_info=True) return f"β Error: {str(e)}" def run_all_models(audio): """ Run inference on all models sequentially. Args: audio: Audio input from Gradio component Returns: List of transcription results for each model """ logger.info(f"Starting inference on {len(MODELS)} models") results = [] for model_config in MODELS: text = run_inference(audio, model_config) results.append(text) logger.info("All models completed") return results def load_css(): """Load CSS from external file""" css_path = os.path.join(os.path.dirname(__file__), "style.css") try: with open(css_path, "r", encoding="utf-8") as f: return f.read() except FileNotFoundError: logger.warning(f"CSS file not found at {css_path}") return "" # Build the Gradio interface with gr.Blocks( theme=gr.themes.Soft(), title="StutteredSpeechASR Research Demo", css=load_css() ) as demo: # Title and Description gr.Markdown( """