import gradio as gr
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import numpy as np
import os
import librosa

# --- CONFIGURATION ---
TARGET_SAMPLE_RATE = 16000

# Get Hugging Face token from environment variable
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    print("WARNING: HF_TOKEN not found! If your models are private/gated, loading will fail.")

# Initialize device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Use float16 for GPU (faster / less memory), float32 for CPU (compatibility)
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

print(f"Using device: {device} with precision: {torch_dtype}")
print("Loading models... This may take a few minutes.")

# --- MODEL DEFINITIONS ---
MODEL_CONFIGS = {
    "Model 1: Ahmed107/hamsa-dis-saudi": {
        "path": "Ahmed107/hamsa-dis-saudi"
    },
    "Model 2: nadsoft/ASR_hamsa-finetuned": {
        "path": "nadsoft/ASR_hamsa-finetuned"
    }
}

MODEL_KEYS = [
    "Model 1: Ahmed107/hamsa-dis-saudi",
    "Model 2: nadsoft/ASR_hamsa-finetuned"
]

# Global dictionaries to store loaded resources
models = {}
processors = {}

# --- LOAD MODELS ---
for key in MODEL_KEYS:
    config = MODEL_CONFIGS[key]
    try:
        print(f"Loading {key}...")
        # Load processor
        processors[key] = WhisperProcessor.from_pretrained(config["path"], token=HF_TOKEN)
        # Load model
        models[key] = WhisperForConditionalGeneration.from_pretrained(
            config["path"],
            token=HF_TOKEN,
            torch_dtype=torch_dtype
        ).to(device)
        print(f"✓ {key} loaded successfully!")
    except Exception as e:
        print(f"✗ Failed to load {key}: {str(e)}")

# Check what loaded successfully
available_models = list(models.keys())
print(f"\nTotal models loaded: {len(available_models)}")


# --- PROCESSING FUNCTION ---
def transcribe_both_models(audio_file, repetition_penalty, temperature, num_beams, max_length):
    """
    1. Resamples audio to 16 kHz.
    2. Runs inference on both models.
    """
    # Initialize output list with placeholders
    results = ["(Model failed or not loaded)"] * 2

    if audio_file is None:
        return ["⚠️ Please upload an audio file."] * 2

    try:
        # Gradio's 'numpy' type returns a tuple: (sample_rate, array)
        original_sr, audio_data = audio_file

        # --- AUDIO PRE-PROCESSING ---
        # 1. Ensure float32
        audio_data = audio_data.astype(np.float32)

        # 2. Check for stereo (2 channels) and convert to mono
        if len(audio_data.shape) > 1:
            audio_data = np.mean(audio_data, axis=1)

        # 3. Normalize to range [-1, 1]
        if np.max(np.abs(audio_data)) > 0:
            audio_data = audio_data / np.max(np.abs(audio_data))

        # 4. Resample to 16,000 Hz if necessary
        if original_sr != TARGET_SAMPLE_RATE:
            audio_data = librosa.resample(
                audio_data,
                orig_sr=original_sr,
                target_sr=TARGET_SAMPLE_RATE
            )

        # 5. Final check on length
        if len(audio_data) == 0:
            return ["⚠️ Audio file is empty after processing."] * 2
    except Exception as e:
        return [f"❌ Error during audio processing: {str(e)}"] * 2

    # --- INFERENCE LOOP ---
    for i, key in enumerate(MODEL_KEYS):
        if key not in models:
            continue
        try:
            model = models[key]
            processor = processors[key]

            # Prepare inputs
            input_features = processor(
                audio_data,
                sampling_rate=TARGET_SAMPLE_RATE,
                return_tensors="pt"
            ).input_features

            # Move to GPU/CPU and correct precision
            input_features = input_features.to(device, dtype=torch_dtype)

            # Generate tokens. Gradio sliders deliver floats, so cast the
            # integer-valued parameters before passing them to generate().
            with torch.no_grad():
                predicted_ids = model.generate(
                    input_features,
                    repetition_penalty=repetition_penalty,
                    temperature=temperature,
                    num_beams=int(num_beams),
                    max_length=int(max_length),
                    do_sample=(temperature > 0)
                )

            # Decode to text
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

            # Store result
            results[i] = transcription if transcription.strip() else "⚠️ (No text generated)"
        except Exception as e:
            results[i] = f"❌ Error: {str(e)}"

    return results[0], results[1]


# --- GRADIO UI ---
with gr.Blocks(title="Whisper Model Comparison") as demo:
    gr.Markdown(
        """
        # 🎤 Whisper 2-Model Comparison
        Upload audio once and compare the results from two different models.
        *Audio is automatically resampled to 16 kHz to match Whisper's expected input rate.*
        """
    )

    with gr.Row():
        # --- Left column: input ---
        with gr.Column(scale=1):
            audio_input = gr.Audio(
                label="Input Audio",
                type="numpy",
                sources=["upload", "microphone"]
            )
            submit_btn = gr.Button("🚀 Transcribe", variant="primary", size="lg")

        # --- Right column: settings ---
        with gr.Column(scale=1):
            with gr.Accordion("⚙️ Generation Parameters", open=True):
                repetition_penalty = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
                temperature = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Temperature")
                num_beams = gr.Slider(1, 10, value=5, step=1, label="Beam Size")
                max_length = gr.Slider(50, 448, value=225, step=25, label="Max Token Length")

    # --- Output row ---
    gr.Markdown("### 📝 Results")
    with gr.Row():
        with gr.Column():
            gr.Markdown(f"**{MODEL_KEYS[0]}**")
            out_1 = gr.Textbox(label="Transcript", lines=8, show_label=False)
        with gr.Column():
            gr.Markdown(f"**{MODEL_KEYS[1]}**")
            out_2 = gr.Textbox(label="Transcript", lines=8, show_label=False)

    # --- Event binding ---
    submit_btn.click(
        fn=transcribe_both_models,
        inputs=[audio_input, repetition_penalty, temperature, num_beams, max_length],
        outputs=[out_1, out_2]
    )

if __name__ == "__main__":
    demo.launch()
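
# --- USAGE NOTE (assumption: this script is saved as app.py, e.g. for a
# Hugging Face Space; adjust the filename to match your setup) ---
#   export HF_TOKEN=...   # only needed if the models are private/gated
#   python app.py
# Gradio serves the UI at http://127.0.0.1:7860 by default.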