# test-dis-hamsa / app.py
import gradio as gr
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import numpy as np
import os
import librosa
# --- CONFIGURATION ---
TARGET_SAMPLE_RATE = 16000
# Get Hugging Face token from environment variable
HF_TOKEN = os.environ.get("HF_TOKEN")
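# In a Hugging Face Space, HF_TOKEN is normally provided as a repository secret
# (Settings -> Variables and secrets) rather than hard-coded in the app.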
if not HF_TOKEN:
    print("WARNING: HF_TOKEN not found! If your models are private/gated, loading will fail.")
# Initialize device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Use float16 for GPU (faster/less memory), float32 for CPU (compatibility)
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
print(f"Using device: {device} with precision: {torch_dtype}")
print("Loading models... This may take a few minutes.")
# --- MODEL DEFINITIONS ---
MODEL_CONFIGS = {
    "Model 1: Ahmed107/hamsa-dis-saudi": {
        "path": "Ahmed107/hamsa-dis-saudi"
    },
    "Model 2: nadsoft/ASR_hamsa-finetuned": {
        "path": "nadsoft/ASR_hamsa-finetuned"
    }
}
# Display order for the UI, derived from the config to avoid duplicating the keys
MODEL_KEYS = list(MODEL_CONFIGS.keys())
# Global dictionaries to store loaded resources
models = {}
processors = {}
# --- LOAD MODELS ---
for key in MODEL_KEYS:
    config = MODEL_CONFIGS[key]
    try:
        print(f"Loading {key}...")
        # Load the processor (feature extractor + tokenizer)
        processors[key] = WhisperProcessor.from_pretrained(config['path'], token=HF_TOKEN)
        # Load the model weights and move them to the target device
        models[key] = WhisperForConditionalGeneration.from_pretrained(
            config['path'],
            token=HF_TOKEN,
            torch_dtype=torch_dtype
        ).to(device)
        print(f"✓ {key} loaded successfully!")
    except Exception as e:
        print(f"✗ Failed to load {key}: {str(e)}")
# Check what loaded
available_models = list(models.keys())
print(f"\nTotal models loaded: {len(available_models)}")
# --- PROCESSING FUNCTION ---
def transcribe_both_models(audio_file, repetition_penalty, temperature, num_beams, max_length):
    """
    1. Converts the input to mono float32, peak-normalizes it, and resamples to 16 kHz.
    2. Runs inference on both models and returns one transcript per model.
    """
    # Initialize the output list with placeholders
    results = ["(Model failed or not loaded)"] * 2
    if audio_file is None:
        return ["⚠️ Please upload an audio file."] * 2
    try:
        # Gradio's 'numpy' type returns a tuple: (sample_rate, array)
        original_sr, audio_data = audio_file
        # --- AUDIO PRE-PROCESSING ---
        # 1. Ensure float32
        audio_data = audio_data.astype(np.float32)
        # 2. Down-mix stereo (2 channels) to mono
        if len(audio_data.shape) > 1:
            audio_data = np.mean(audio_data, axis=1)
        # 3. Peak-normalize to the range [-1, 1]
        if np.max(np.abs(audio_data)) > 0:
            audio_data = audio_data / np.max(np.abs(audio_data))
        # 4. Resample to 16,000 Hz if necessary
        if original_sr != TARGET_SAMPLE_RATE:
            audio_data = librosa.resample(
                audio_data,
                orig_sr=original_sr,
                target_sr=TARGET_SAMPLE_RATE
            )
        # 5. Final check on length
        if len(audio_data) == 0:
            return ["⚠️ Audio file is empty after processing."] * 2
    except Exception as e:
        return [f"❌ Error during audio processing: {str(e)}"] * 2
    # --- INFERENCE LOOP ---
    # Gradio sliders may deliver floats; generation expects integer counts
    num_beams = int(num_beams)
    max_length = int(max_length)
    for i, key in enumerate(MODEL_KEYS):
        if key not in models:
            continue
        try:
            model = models[key]
            processor = processors[key]
            # Compute log-mel input features
            input_features = processor(
                audio_data,
                sampling_rate=TARGET_SAMPLE_RATE,
                return_tensors="pt"
            ).input_features
            # Move to the model's device and precision
            input_features = input_features.to(device, dtype=torch_dtype)
            # Generate token IDs; sampling is enabled only when temperature > 0,
            # otherwise decoding is plain beam search
            with torch.no_grad():
                predicted_ids = model.generate(
                    input_features,
                    repetition_penalty=repetition_penalty,
                    temperature=temperature,
                    num_beams=num_beams,
                    max_length=max_length,
                    do_sample=(temperature > 0)
                )
            # Decode token IDs to text
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            # Store result
            results[i] = transcription if transcription.strip() else "⚠️ (No text generated)"
        except Exception as e:
            results[i] = f"❌ Error: {str(e)}"
    return results[0], results[1]
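# Illustrative direct call, bypassing the UI (the tone and settings below are
# hypothetical test values, not part of the app's flow):
#
#   sr = TARGET_SAMPLE_RATE
#   tone = 0.5 * np.sin(2 * np.pi * 440.0 * np.arange(sr) / sr).astype(np.float32)
#   text_1, text_2 = transcribe_both_models((sr, tone), 1.2, 0.0, 5, 225)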
# --- GRADIO UI ---
with gr.Blocks(title="Whisper Model Comparison") as demo:
    gr.Markdown(
        """
        # 🎤 Whisper 2-Model Comparison
        Upload audio once and compare the transcripts from two different models.
        *Audio is automatically resampled to 16 kHz to prevent distortion.*
        """
    )
    with gr.Row():
        # --- Left Column: Input ---
        with gr.Column(scale=1):
            audio_input = gr.Audio(
                label="Input Audio",
                type="numpy",
                sources=["upload", "microphone"]
            )
            submit_btn = gr.Button("🚀 Transcribe", variant="primary", size="lg")
        # --- Right Column: Settings ---
        with gr.Column(scale=1):
            with gr.Accordion("⚙️ Generation Parameters", open=True):
                repetition_penalty = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
                temperature = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Temperature")
                num_beams = gr.Slider(1, 10, value=5, step=1, label="Beam Size")
                max_length = gr.Slider(50, 448, value=225, step=25, label="Max Token Length")
    # --- Output Row ---
    gr.Markdown("### 📝 Results")
    with gr.Row():
        with gr.Column():
            gr.Markdown(f"**{MODEL_KEYS[0]}**")
            out_1 = gr.Textbox(label="Transcript", lines=8, show_label=False)
        with gr.Column():
            gr.Markdown(f"**{MODEL_KEYS[1]}**")
            out_2 = gr.Textbox(label="Transcript", lines=8, show_label=False)
    # --- Event Binding ---
    submit_btn.click(
        fn=transcribe_both_models,
        inputs=[audio_input, repetition_penalty, temperature, num_beams, max_length],
        outputs=[out_1, out_2]
    )
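# Note: on shared GPU hardware, calling demo.queue() before launch() is a common
# way to serialize concurrent requests; whether it's needed depends on the Space.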
if __name__ == "__main__":
    demo.launch()