import gradio as gr
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import logging
import os
import numpy as np
from scipy import signal
import math
import time
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# CPU tuning knobs (helpful on Windows)
try:
_torch_threads = int(os.getenv("TORCH_NUM_THREADS", "0"))
if _torch_threads > 0:
torch.set_num_threads(_torch_threads)
except Exception:
pass
try:
_torch_interop_threads = int(os.getenv("TORCH_NUM_INTEROP_THREADS", "0"))
if _torch_interop_threads > 0:
torch.set_num_interop_threads(_torch_interop_threads)
except Exception:
pass
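# Usage sketch (illustrative values, not tuned recommendations): set the knobs
# above before launching, e.g.
#   TORCH_NUM_THREADS=8 TORCH_NUM_INTEROP_THREADS=2 python app.py    (Linux/macOS)
#   set TORCH_NUM_THREADS=8 && python app.py                         (Windows cmd)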
MODEL_NAME = "jmshd/whisper-uz"
# Log environment info
logger.info("Starting Whisper Uzbek STT application")
logger.info(f"PyTorch version: {torch.__version__}")
logger.info(f"CUDA available: {torch.cuda.is_available()}")
logger.info(f"Model: {MODEL_NAME}")
# Load model and processor
try:
logger.info("Loading processor...")
processor = WhisperProcessor.from_pretrained(MODEL_NAME)
logger.info("Loading model...")
model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME)
model.eval()
# Force Uzbek transcription when supported by the installed Transformers version.
# Some versions do not accept `forced_decoder_ids` as a generate() kwarg, so we
# prefer setting it on the config/generation_config.
try:
uz_forced_decoder_ids = processor.get_decoder_prompt_ids(language="uz", task="transcribe")
if hasattr(model, "generation_config") and hasattr(model.generation_config, "forced_decoder_ids"):
model.generation_config.forced_decoder_ids = uz_forced_decoder_ids
elif hasattr(model, "config"):
model.config.forced_decoder_ids = uz_forced_decoder_ids
logger.info("Configured Uzbek forced decoder prompt IDs")
except Exception as e:
logger.warning(f"Could not configure Uzbek decoding prompt IDs (continuing without forcing language): {e}")
logger.info("Model and processor loaded successfully")
except Exception as e:
logger.error(f"Error loading model: {str(e)}")
raise
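# Note: recent Transformers releases also let Whisper's generate() take the
# language/task directly, e.g. model.generate(features, language="uz",
# task="transcribe"); the config-based forcing above is kept because it works
# across a wider range of versions.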
def resample_audio(audio_data, orig_sr, target_sr=16000):
"""
Resample audio to target sample rate
Args:
audio_data: Audio array
orig_sr: Original sample rate
target_sr: Target sample rate (default 16000 for Whisper)
Returns:
Resampled audio array
"""
if orig_sr == target_sr:
return audio_data
# Normalize integer PCM to [-1, 1] float32
if audio_data.dtype.kind in {"i", "u"}:
max_abs = float(np.iinfo(audio_data.dtype).max)
audio_data = audio_data.astype(np.float32) / max_abs
else:
audio_data = audio_data.astype(np.float32)
# Resample using polyphase filtering (better than FFT resample for speech)
gcd = math.gcd(int(orig_sr), int(target_sr))
up = int(target_sr // gcd)
down = int(orig_sr // gcd)
resampled = signal.resample_poly(audio_data, up, down)
return resampled.astype(np.float32, copy=False)
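# Worked example: for a 44100 Hz recording, gcd(44100, 16000) = 100, so
# resample_poly upsamples by 160 and downsamples by 441; one second of input
# (44100 samples) comes out as exactly 16000 samples.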
def _to_mono(audio_data: np.ndarray) -> np.ndarray:
if audio_data.ndim != 2:
return audio_data
# Gradio typically returns (samples, channels); be defensive for (channels, samples)
if audio_data.shape[0] in (1, 2) and audio_data.shape[0] < audio_data.shape[1]:
return np.mean(audio_data, axis=0)
return np.mean(audio_data, axis=1)
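# e.g. Gradio's numpy audio delivers a stereo clip as shape (n_samples, 2),
# which averages down to (n_samples,); a transposed (2, n_samples) layout is
# caught by the defensive branch above instead.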
def _normalize_audio(audio_data: np.ndarray) -> np.ndarray:
# Convert to float32 and scale integers to [-1, 1]
if audio_data.dtype.kind in {"i", "u"}:
max_abs = float(np.iinfo(audio_data.dtype).max)
audio_data = audio_data.astype(np.float32) / max_abs
else:
audio_data = audio_data.astype(np.float32, copy=False)
# Remove DC offset
if audio_data.size:
audio_data = audio_data - float(np.mean(audio_data))
# Prevent accidental clipping if input is out of expected range
peak = float(np.max(np.abs(audio_data))) if audio_data.size else 0.0
if peak > 1.0:
audio_data = audio_data / peak
return audio_data
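# e.g. for int16 input, 32767 scales to ~1.0 while -32768 scales to about
# -1.00003 (two's complement is asymmetric); the peak check above then brings
# such values back into [-1, 1].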
def _trim_silence(audio_data: np.ndarray, sample_rate: int) -> np.ndarray:
"""Conservatively trim leading/trailing near-silence to reduce compute.
Uses a threshold relative to peak amplitude to avoid chopping quiet speech.
"""
if audio_data.size == 0:
return audio_data
peak = float(np.max(np.abs(audio_data)))
if peak <= 0.0:
return audio_data
threshold = max(1e-4, 0.02 * peak)
active = np.where(np.abs(audio_data) > threshold)[0]
if active.size == 0:
return audio_data
pad = int(0.10 * sample_rate) # 100ms context
start = max(int(active[0]) - pad, 0)
end = min(int(active[-1]) + pad, audio_data.shape[0] - 1)
return audio_data[start : end + 1]
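# e.g. a clip peaking at 0.5 gets a gate of max(1e-4, 0.02 * 0.5) = 0.01, and
# at 16 kHz the 100 ms pad keeps 1600 samples of context on each side.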
def transcribe(audio, progress=gr.Progress()):
"""
Transcribe audio to text using Whisper model
Args:
audio: Audio input from Gradio (sample_rate, audio_data)
progress: Gradio progress tracker
Returns:
str: Transcribed text
"""
try:
if audio is None:
logger.warning("No audio input provided")
            return "⚠️ No audio provided. Please upload or record audio."
progress(0.1, desc="Processing audio...")
sample_rate, audio_data = audio
audio_data = np.asarray(audio_data)
logger.info(f"Processing audio - Sample rate: {sample_rate}, Shape: {audio_data.shape}, Dtype: {audio_data.dtype}")
# Handle stereo to mono conversion
if audio_data.ndim > 1:
logger.info("Converting multi-channel audio to mono")
audio_data = _to_mono(audio_data)
# Normalize / sanitize audio prior to resampling
audio_data = _normalize_audio(audio_data)
# Resample to 16000 Hz if needed
target_sr = 16000
if sample_rate != target_sr:
logger.info(f"Resampling from {sample_rate} Hz to {target_sr} Hz")
progress(0.2, desc=f"Resampling audio from {sample_rate} Hz to {target_sr} Hz...")
audio_data = resample_audio(audio_data, sample_rate, target_sr)
sample_rate = target_sr
# Optionally trim silence to reduce compute on short clips
if os.getenv("WHISPER_TRIM_SILENCE", "1") not in {"0", "false", "False"}:
audio_data = _trim_silence(audio_data, sample_rate)
progress(0.3, desc="Preparing input features...")
t0 = time.perf_counter()
# Make padding/truncation explicit and request attention_mask when supported.
processor_kwargs = {
"sampling_rate": sample_rate,
"return_tensors": "pt",
# IMPORTANT: avoid padding everything to 30s. This dramatically speeds up short clips on CPU.
"padding": False,
"truncation": False,
}
        # The feature extractor accepts this flag directly; a plain dict
        # assignment cannot raise, so no try/except is needed here.
        processor_kwargs["return_attention_mask"] = True
inputs = processor(audio_data, **processor_kwargs)
t1 = time.perf_counter()
logger.info(f"Feature extraction took {(t1 - t0):.2f}s")
progress(0.5, desc="Generating transcription (CPU may take a while)...")
max_new_tokens = int(os.getenv("WHISPER_MAX_NEW_TOKENS", "128"))
num_beams = int(os.getenv("WHISPER_NUM_BEAMS", "1"))
generate_kwargs = {
"max_new_tokens": max_new_tokens,
"num_beams": num_beams,
"do_sample": False,
}
# Pass an attention_mask for encoder-decoder generate().
# Whisper's processor may not return it; in that case construct an all-ones mask.
        # BatchFeature is dict-like, so .get() returns None when the extractor
        # did not produce an attention mask.
        attention_mask = inputs.get("attention_mask")
if attention_mask is None:
# input_features: [batch, n_mels, n_frames] -> mask over frames
frames = int(inputs.input_features.shape[-1])
attention_mask = torch.ones((inputs.input_features.shape[0], frames), dtype=torch.long)
generate_kwargs["attention_mask"] = attention_mask
logger.info(f"Generating with num_beams={num_beams}, max_new_tokens={max_new_tokens}")
t2 = time.perf_counter()
with torch.inference_mode():
predicted_ids = model.generate(inputs.input_features, **generate_kwargs)
t3 = time.perf_counter()
logger.info(f"Generation took {(t3 - t2):.2f}s")
progress(0.8, desc="Decoding text...")
text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
progress(1.0, desc="Complete!")
logger.info(f"Transcription successful - Length: {len(text)} characters")
return text
except Exception as e:
        error_msg = f"❌ Error during transcription: {str(e)}"
logger.error(error_msg)
return error_msg
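# Local smoke test: a minimal sketch, left commented out so it never runs on
# Spaces. It pushes one second of a 440 Hz tone at 48 kHz through the same
# (sample_rate, data) tuple Gradio would pass, with a no-op stand-in for the
# progress callback. The text produced for a pure tone is meaningless; the
# point is only to confirm that resampling, normalization, trimming and
# generation run end to end.
# if os.getenv("WHISPER_SMOKE_TEST") == "1":
#     _sr = 48000
#     _t = np.linspace(0.0, 1.0, _sr, endpoint=False)
#     _tone = (0.1 * np.sin(2.0 * np.pi * 440.0 * _t)).astype(np.float32)
#     print(transcribe((_sr, _tone), progress=lambda *args, **kwargs: None))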
# Enhanced Gradio interface
theme = gr.themes.Soft()
with gr.Blocks(theme=theme) as iface:
gr.Markdown(
"""
        # 🎙️ Whisper Uzbek Speech-to-Text
        Transcribe Uzbek audio to text using the Whisper model. This application runs on CPU and supports the Uzbek language.
**Model:** `jmshd/whisper-uz`
"""
)
with gr.Row():
with gr.Column():
audio_input = gr.Audio(
label="Upload or Record Audio",
type="numpy",
sources=["microphone", "upload"]
)
            transcribe_btn = gr.Button("🎯 Transcribe", variant="primary")
clear_btn = gr.ClearButton([audio_input])
with gr.Column():
output_text = gr.Textbox(
label="Transcription",
placeholder="Your transcribed text will appear here...",
lines=10
)
gr.Markdown(
"""
        ### 📝 Usage Instructions:
1. Click the microphone icon to record audio or upload an audio file
2. Click the "Transcribe" button to convert speech to text
3. The transcribed text will appear in the output box
        ### 🔌 API Access:
        This Space provides a REST API for programmatic access. Click the "Use via API" button below for details.
**Quick API Example (Python):**
```python
        from gradio_client import Client, handle_file

        client = Client("YOUR_SPACE_URL")
        result = client.predict(handle_file("path/to/audio.mp3"), api_name="/transcribe")
        print(result)
```
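        *(The example assumes `gradio_client>=1.0`, where file inputs are wrapped with `handle_file`; older clients pass the file path string directly.)*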
        ### ℹ️ Information:
- Supported language: Uzbek
- Processing: CPU-only (may be slower than GPU)
- Model size: Small
- API: Enabled via Gradio Client
"""
)
    transcribe_btn.click(
        fn=transcribe,
        inputs=audio_input,
        outputs=output_text,
        api_name="/transcribe",  # pins the endpoint name used in the API example above
    )
# Launch configuration for Hugging Face Spaces
if __name__ == "__main__":
logger.info("Launching Gradio interface...")
    logger.info("API endpoint exposed as /transcribe (see 'Use via API' in the footer)")
iface.queue() # Enable queue for better API performance
iface.launch(
share=False,
show_error=True,
server_name="0.0.0.0",
server_port=7860,
)