# ASR-finetuning / demo / app.py
# Author: saadmannan
# HF Space application - exclude binary PDFs (commit 5554ef1)
"""
Gradio Demo for Whisper German ASR
Interactive web interface for audio transcription
"""
import gradio as gr
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import librosa
import numpy as np
from pathlib import Path
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables
# Populated once by load_model() and read by transcribe_audio() on every
# request. They stay None until the model has finished loading.
model = None      # WhisperForConditionalGeneration, moved to `device`
processor = None  # WhisperProcessor (feature extractor + tokenizer)
device = None     # "cuda" if a GPU is available, otherwise "cpu"
def load_model(model_path="./whisper_test_tuned"):
    """Load the fine-tuned Whisper model and prepare it for inference.

    If *model_path* is a directory containing ``checkpoint-*``
    subdirectories, the numerically newest checkpoint is used. The
    processor always comes from the ``openai/whisper-small`` base model,
    decoding is forced to German transcription, and the model is moved
    to GPU when one is available.

    Args:
        model_path: Local directory (or hub identifier) holding the
            fine-tuned weights.

    Returns:
        A human-readable status string naming the device in use.
    """
    global model, processor, device

    logger.info(f"Loading model from: {model_path}")

    resolved = Path(model_path)
    # Check for checkpoint directories and prefer the most recent one.
    if resolved.is_dir():
        checkpoint_dirs = list(resolved.glob('checkpoint-*'))
        if checkpoint_dirs:
            newest = max(checkpoint_dirs, key=lambda p: int(p.name.split('-')[1]))
            resolved = newest
            logger.info(f"Using checkpoint: {newest.name}")

    model = WhisperForConditionalGeneration.from_pretrained(str(resolved))
    processor = WhisperProcessor.from_pretrained("openai/whisper-small")

    # Condition the decoder on German + transcription so generation does
    # not have to auto-detect the language.
    model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
        language="german",
        task="transcribe",
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()

    logger.info(f"βœ“ Model loaded on {device}")
    return f"Model loaded successfully on {device}"
def transcribe_audio(audio_input):
    """Transcribe audio from microphone or file upload.

    Accepts either a ``(sample_rate, samples)`` tuple, as produced by a
    gradio ``Audio`` component in ``type="numpy"`` mode, or a filesystem
    path to an audio file. Input is normalized to 16 kHz mono float32
    before being passed to the Whisper model.

    Args:
        audio_input: ``(int, np.ndarray)`` tuple, a file path string, or
            ``None`` when no audio was supplied.

    Returns:
        A markdown-formatted transcription string, or an error message
        prefixed with ❌ on failure.
    """
    if model is None:
        return "❌ Error: Model not loaded. Please wait for model to load."

    try:
        # Handle different input formats
        if audio_input is None:
            return "❌ No audio provided"

        # audio_input is a tuple (sample_rate, audio_data) from gradio
        if isinstance(audio_input, tuple):
            sr, audio = audio_input
            # Convert integer PCM to float32 in [-1, 1]; cast any other
            # dtype (e.g. float64) to float32 for the feature extractor.
            if audio.dtype == np.int16:
                audio = audio.astype(np.float32) / 32768.0
            elif audio.dtype == np.int32:
                audio = audio.astype(np.float32) / 2147483648.0
            else:
                audio = audio.astype(np.float32)
        else:
            # File path: librosa handles resampling and mono downmix.
            audio, sr = librosa.load(audio_input, sr=16000, mono=True)

        # BUGFIX: downmix to mono BEFORE resampling. Gradio delivers
        # stereo as shape (samples, channels); librosa.resample operates
        # on the last axis by default, so resampling first would run
        # along the 2-element channel axis and corrupt the audio.
        if audio.ndim > 1:
            audio = audio.mean(axis=1)

        # Resample if needed
        if sr != 16000:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)

        duration = len(audio) / 16000

        # Process audio into log-mel input features on the model device.
        input_features = processor(
            audio,
            sampling_rate=16000,
            return_tensors="pt"
        ).input_features.to(device)

        # Generate transcription (beam search; Whisper's max length is 448).
        with torch.no_grad():
            predicted_ids = model.generate(
                input_features,
                max_length=448,
                num_beams=5,
                early_stopping=True
            )

        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

        logger.info(f"Transcribed {duration:.2f}s audio: {transcription[:50]}...")

        return f"🎀 **Transcription:**\n\n{transcription}\n\nπŸ“Š Duration: {duration:.2f} seconds"

    except Exception as e:
        # Surface the error in the UI instead of crashing the request.
        logger.error(f"Transcription error: {e}")
        return f"❌ Error: {str(e)}"
# Load model on startup so the app is ready as soon as the UI is served.
try:
    load_model()
except Exception as e:
    # Keep the UI alive even if the weights are missing or broken:
    # transcribe_audio() reports "Model not loaded" until a successful load.
    logger.error(f"Failed to load model: {e}")
    logger.info("Model will need to be loaded manually")
# Create Gradio interface: three tabs (transcribe, about, examples) built
# declaratively inside a Blocks context. `demo` is launched in __main__.
with gr.Blocks(title="Whisper German ASR", theme=gr.themes.Soft()) as demo:
    # App header shown above the tabs.
    gr.Markdown(
        """
        # πŸŽ™οΈ Whisper German ASR
        Fine-tuned Whisper model for German speech recognition.
        **Features:**
        - Real-time transcription
        - Microphone or file upload support
        - Optimized for German language
        **Model:** Whisper-small fine-tuned on German MINDS14 dataset
        """
    )
    with gr.Tab("🎀 Transcribe"):
        with gr.Row():
            with gr.Column():
                # type="numpy" makes the component pass a
                # (sample_rate, ndarray) tuple to transcribe_audio.
                audio_input = gr.Audio(
                    sources=["microphone", "upload"],
                    type="numpy",
                    label="Audio Input"
                )
                transcribe_btn = gr.Button("Transcribe", variant="primary", size="lg")
            with gr.Column():
                # Markdown output so the 🎀/πŸ“Š formatting renders.
                output_text = gr.Markdown(label="Transcription")
        transcribe_btn.click(
            fn=transcribe_audio,
            inputs=audio_input,
            outputs=output_text
        )
    with gr.Tab("ℹ️ About"):
        gr.Markdown(
            """
            ## About This Model
            This is a fine-tuned version of OpenAI's Whisper-small model,
            specifically optimized for German speech recognition.
            ### Training Details
            - **Base Model:** openai/whisper-small (242M parameters)
            - **Dataset:** PolyAI/minds14 (German subset)
            - **Training Samples:** ~274 samples
            - **Performance:** ~13% Word Error Rate (WER)
            ### Technical Specifications
            - **Sample Rate:** 16kHz
            - **Max Duration:** 30 seconds
            - **Language:** German (de)
            - **Task:** Transcription
            ### Usage Tips
            - Speak clearly and at a moderate pace
            - Minimize background noise
            - Audio should be in German language
            - Best results with 1-30 second clips
            ### Links
            - [GitHub Repository](#)
            - [Model Card](#)
            - [Documentation](#)
            """
        )
    with gr.Tab("πŸ“Š Examples"):
        # Empty examples list for now; wired to the same components so
        # adding sample files later needs no other changes.
        gr.Examples(
            examples=[
                # Add example audio files here if available
            ],
            inputs=audio_input,
            outputs=output_text,
            fn=transcribe_audio,
            cache_examples=False
        )
# Launch the app when executed as a script.
if __name__ == "__main__":
    # Bind to all interfaces on 7860, the standard HF Spaces port;
    # no public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_options)