Automatic Speech Recognition
Safetensors
Chinese
whisper
gpric024's picture
Changed for API calls instead of locally running models
df83b8b
raw
history blame
8.62 kB
"""
Speech-to-Text Model Arena
A Gradio demo for comparing multiple STT models side-by-side.
"""
import gradio as gr
import logging
import os
import requests
from dotenv import load_dotenv
# Load settings from a local .env file (via python-dotenv) before the
# environment variables below are read.
load_dotenv()

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("stt_arena")

# URL used for the StutteredSpeechASR model, and the bearer token sent with
# every inference request. Both are None when unset.
HF_ENDPOINT = os.getenv("HF_ENDPOINT")
HF_API_KEY = os.getenv("HF_API_KEY")

# Hosted inference URLs for the two baseline Whisper models.
WHISPER_API_URL = "https://router.huggingface.co/hf-inference/models/openai/whisper-large-v3"
WHISPER_TURBO_API_URL = "https://router.huggingface.co/hf-inference/models/openai/whisper-large-v3-turbo"

if HF_ENDPOINT:
    logger.info(f"Using Hugging Face Endpoint: {HF_ENDPOINT}")
else:
    # This file has no local-model fallback: without an endpoint the
    # StutteredSpeechASR column can only show an error message.
    logger.warning("HF_ENDPOINT not set, StutteredSpeechASR will be unavailable")

if not HF_API_KEY:
    # Every model is called through an authenticated API request.
    logger.warning("HF_API_KEY not set, all models will return an error")
# Models compared in the arena, in display order.
#   name        - heading shown above the model's output box
#   id          - internal key used to pick the API endpoint
#   hf_id       - Hugging Face repository id of the model
#   description - short human-readable summary
# The emoji prefixes were previously stored as mis-decoded UTF-8 (mojibake)
# and rendered as garbage in the UI; they are restored here.
MODELS = [
    {
        "name": "🗣️ StutteredSpeechASR",
        "id": "stuttered",
        "hf_id": "AImpower/StutteredSpeechASR",
        "description": "Whisper fine-tuned for stuttered speech (Mandarin)",
    },
    {
        "name": "🎙️ Whisper Large V3",
        "id": "whisper",
        "hf_id": "openai/whisper-large-v3",
        "description": "OpenAI Whisper Large V3 model (via HF Inference API)",
    },
    {
        "name": "🔊 Whisper Large V3 Turbo",
        "id": "whisper_turbo",
        "hf_id": "openai/whisper-large-v3-turbo",
        "description": "OpenAI Whisper Large V3 Turbo (via HF Inference API)",
    },
]
def _api_error_message(response, model_name: str) -> str:
    """Turn a non-200 endpoint response into a user-facing message."""
    http_fallback = f"❌ {model_name} Error: HTTP {response.status_code}"

    try:
        error_data = response.json()
    except ValueError:
        # Body is not JSON; requests' JSON decode error is a ValueError subclass.
        return http_fallback

    # Only a dict with a string "error" field can be interpreted further.
    if not isinstance(error_data, dict):
        return http_fallback
    error_msg = error_data.get("error", "")
    if not isinstance(error_msg, str):
        return http_fallback

    lowered = error_msg.lower()
    if "paused" in lowered:
        return f"⏸️ The {model_name} endpoint is currently paused. Please contact the maintainer to restart it."
    if "loading" in lowered:
        return f"⏳ {model_name} is loading. Please wait and try again."
    if response.status_code == 503:
        return f"🔄 {model_name} service is temporarily unavailable. Please try again."
    if not error_msg:
        # Avoid a dangling "Error: " with nothing after it.
        return http_fallback
    return f"❌ {model_name} Error: {error_msg}"


def _extract_transcription(result) -> str:
    """Pull the transcription text out of a decoded JSON response body."""
    if isinstance(result, dict):
        text = result.get("text") or result.get("transcription") or ""
    elif isinstance(result, list) and result:
        first = result[0]
        text = (first.get("text") or "") if isinstance(first, dict) else first
    else:
        # Unrecognised shape: show whatever came back rather than nothing.
        text = result
    return str(text).strip()


def run_api_inference(audio_path: str, api_url: str, model_name: str) -> str:
    """
    Run inference using any Hugging Face API endpoint.

    The raw bytes of the audio file are POSTed to ``api_url`` with a bearer
    token. A non-200 response is logged and converted into a user-facing
    message string instead of raising, so the return value can be displayed
    directly.

    Args:
        audio_path: Path to the audio file
        api_url: The API endpoint URL
        model_name: Name of the model, used in log and error messages

    Returns:
        Transcribed text with surrounding whitespace removed, or a
        human-readable error message when the endpoint answers with a
        non-200 status

    Raises:
        ValueError: If HF_API_KEY is not set. Errors from reading the file,
            from the HTTP request itself, or from decoding a 200 response
            body also propagate to the caller.
    """
    if not HF_API_KEY:
        raise ValueError("HF_API_KEY must be set in environment variables")

    logger.info(f"Running inference via {model_name}")

    with open(audio_path, "rb") as f:
        audio_bytes = f.read()

    headers = {
        "Authorization": f"Bearer {HF_API_KEY}",
        # NOTE(review): sent for every file regardless of its actual format;
        # confirm the endpoints accept non-WAV uploads with this header.
        "Content-Type": "audio/wav",
    }
    response = requests.post(
        api_url,
        headers=headers,
        data=audio_bytes,
        timeout=120,
    )

    if response.status_code != 200:
        logger.error(f"{model_name} error: {response.status_code} - {response.text}")
        return _api_error_message(response, model_name)

    result = response.json()
    logger.debug(f"{model_name} response: {result}")
    return _extract_transcription(result)
def run_inference(audio_path: str, model_config: dict) -> str:
    """
    Run inference on a single model.

    Never raises for inference failures: any exception is logged with its
    traceback and returned as a user-facing error string.

    Args:
        audio_path: Path to the audio file (None or empty when no audio was given)
        model_config: Model configuration dictionary; its "id" selects the
            endpoint and its "name" is used in log messages

    Returns:
        Transcribed text, or a message describing why no transcription
        could be produced
    """
    if not audio_path:
        logger.warning("No audio provided")
        return "⚠️ No audio provided. Please record or upload audio first."

    # Resolved once so the error handler cannot fail on a missing "name" key.
    model_label = model_config.get("name", "unknown model")

    # model id -> (endpoint URL, display name used in logs and error messages)
    endpoints = {
        "stuttered": (HF_ENDPOINT, "StutteredSpeechASR"),
        "whisper": (WHISPER_API_URL, "Whisper Large V3"),
        "whisper_turbo": (WHISPER_TURBO_API_URL, "Whisper Large V3 Turbo"),
    }

    try:
        logger.info(f"Running inference with model: {model_label}")
        logger.debug(f"Audio path: {audio_path}")

        model_id = model_config["id"]
        if model_id not in endpoints:
            raise ValueError(f"Unknown model id: {model_id!r}")
        if not HF_API_KEY:
            raise ValueError("HF_API_KEY must be set to use this model")

        api_url, api_name = endpoints[model_id]
        if not api_url:
            # Only the StutteredSpeechASR URL comes from the environment.
            raise ValueError("HF_ENDPOINT must be set to use StutteredSpeechASR")

        return run_api_inference(audio_path, api_url, api_name)
    except Exception as e:
        # Boundary handler: the UI shows the message instead of crashing.
        logger.error(f"Error during inference with {model_label}: {str(e)}", exc_info=True)
        return f"❌ Error: {str(e)}"
def run_all_models(audio):
    """
    Transcribe one audio input with every configured model, one after another.

    Args:
        audio: Audio input from Gradio component

    Returns:
        List with one transcription result per entry of MODELS, in the same order
    """
    logger.info(f"Starting inference on {len(MODELS)} models")
    transcripts = [run_inference(audio, config) for config in MODELS]
    logger.info("All models completed")
    return transcripts
def load_css():
    """Return the text of style.css located next to this file, or "" if it is missing."""
    module_dir = os.path.dirname(__file__)
    css_path = os.path.join(module_dir, "style.css")
    try:
        stylesheet = open(css_path, "r", encoding="utf-8")
    except FileNotFoundError:
        # A missing stylesheet is not fatal; the caller gets an empty string.
        logger.warning(f"CSS file not found at {css_path}")
        return ""
    with stylesheet:
        return stylesheet.read()
# Build the Gradio interface
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="StutteredSpeechASR Research Demo",
    css=load_css(),
) as demo:
    # Title and Description
    # The blank lines after the opening tag and before the closing tag are
    # deliberate: Markdown placed directly inside an HTML block is otherwise
    # treated as raw HTML and the headings/bold markers show up literally.
    gr.Markdown(
        """
<div style="text-align: center; max-width: 800px; margin: 0 auto;">

# 🗣️ StutteredSpeechASR Research Demo
### Fine-tuned Whisper model for stuttered speech recognition
This demo showcases our **StutteredSpeechASR** model, a Whisper model fine-tuned specifically
for stuttered speech (Mandarin). Compare its performance against baseline Whisper models
to see the improvement on stuttered speech patterns.
Upload an audio file or record using your microphone to test the models.

</div>
""",
        elem_classes=["title-text"],
    )
    gr.Markdown("---")

    # Audio Input Section
    with gr.Group():
        gr.Markdown("### 🎤 Audio Input")
        # type="filepath": the handler receives a path to the audio file.
        audio_input = gr.Audio(
            sources=["microphone", "upload"],
            type="filepath",
            label="Record or Upload Audio",
            streaming=False,
            editable=True,
        )

    # Run Button
    run_button = gr.Button(
        "🚀 Compare Models",
        variant="primary",
        size="lg",
        elem_classes=["run-button"],
    )
    gr.Markdown("---")
    gr.Markdown("### 📊 Model Comparison Results")

    # Model Output Cards: one column per entry of MODELS. The textboxes are
    # collected in MODELS order so they line up with run_all_models' results.
    with gr.Row(equal_height=True):
        output_components = []
        for model in MODELS:
            with gr.Column(elem_classes=["model-card"]):
                gr.Markdown(f"## {model['name']}")
                text_output = gr.Textbox(
                    label="Transcription",
                    placeholder="Transcribed text will appear here...",
                    lines=4,
                    interactive=False,
                )
                output_components.append(text_output)

    run_button.click(
        fn=run_all_models,
        inputs=[audio_input],
        outputs=output_components,
        show_progress=True,
    )

    # Footer
    gr.Markdown("---")
    gr.Markdown(
        """
<center>

**💡 Research Note:**
- The StutteredSpeechASR model is designed to better handle stuttered speech patterns
- For best results, use clear audio recordings

*Research Demo | AImpower StutteredSpeechASR*

</center>
""",
        elem_classes=["footer"],
    )
# Launch the app
if __name__ == "__main__":
    logger.info("Starting StutteredSpeechASR Research Demo")
    configured_names = [entry["name"] for entry in MODELS]
    logger.info(f"Models configured: {configured_names}")

    # Listen on all interfaces, port 7860, with no public share link;
    # errors are surfaced in the UI.
    launch_options = {
        "share": False,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)

    logger.info("Application shutdown")