# Hugging Face Space: Whisper 2-model ASR comparison demo
# (NOTE: the original "Spaces: / Paused / Paused" lines were web-page residue
# from the Space status banner, not part of the program.)
| import gradio as gr | |
| import torch | |
| from transformers import WhisperProcessor, WhisperForConditionalGeneration | |
| import numpy as np | |
| import os | |
| import librosa | |
# --- CONFIGURATION ---
TARGET_SAMPLE_RATE = 16000  # Whisper models expect 16 kHz mono audio

# Get Hugging Face token from environment variable (needed for private/gated repos)
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    print("WARNING: HF_TOKEN not found! If your models are private/gated, loading will fail.")

# Initialize device and precision:
# float16 on GPU (faster / less memory), float32 on CPU (compatibility)
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
print(f"Using device: {device} with precision: {torch_dtype}")

print("Loading models... This may take a few minutes.")

# --- MODEL DEFINITIONS ---
# Display name -> Hub repo path for each model being compared.
MODEL_CONFIGS = {
    "Model 1: Ahmed107/hamsa-dis-saudi": {
        "path": "Ahmed107/hamsa-dis-saudi"
    },
    "Model 2: nadsoft/ASR_hamsa-finetuned": {
        "path": "nadsoft/ASR_hamsa-finetuned"
    }
}
# Derive the ordered key list from the config dict so the two can never drift apart.
MODEL_KEYS = list(MODEL_CONFIGS)

# Global dictionaries holding successfully loaded resources, keyed by display name.
models = {}
processors = {}

# --- LOAD MODELS ---
# Best-effort loading: a failure for one model is logged but does not stop the app;
# the UI will show a placeholder for any model missing from `models`.
for key in MODEL_KEYS:
    config = MODEL_CONFIGS[key]
    try:
        print(f"Loading {key}...")
        # Load processor (feature extractor + tokenizer)
        processors[key] = WhisperProcessor.from_pretrained(config['path'], token=HF_TOKEN)
        # Load model weights in the chosen precision and move to the device
        models[key] = WhisperForConditionalGeneration.from_pretrained(
            config['path'],
            token=HF_TOKEN,
            torch_dtype=torch_dtype
        ).to(device)
        print(f"✅ {key} loaded successfully!")
    except Exception as e:
        print(f"❌ Failed to load {key}: {str(e)}")

# Report what actually loaded so partial failures are visible in startup logs.
available_models = list(models.keys())
print(f"\nTotal models loaded: {len(available_models)}")
| # --- PROCESSING FUNCTION --- | |
def transcribe_both_models(audio_file, repetition_penalty, temperature, num_beams, max_length):
    """Transcribe one audio clip with both loaded Whisper models.

    Args:
        audio_file: Gradio ``type="numpy"`` audio — a ``(sample_rate, np.ndarray)``
            tuple — or ``None`` when nothing was uploaded.
        repetition_penalty: Penalty (>1.0 discourages repeats) passed to ``generate``.
        temperature: Sampling temperature; 0 selects deterministic decoding.
        num_beams: Beam-search width (sliders may deliver floats; coerced to int).
        max_length: Maximum generated token length (coerced to int).

    Returns:
        A 2-tuple of strings — one transcript (or status/error message) per model,
        in ``MODEL_KEYS`` order.
    """
    # Placeholder shown for any model that failed to load at startup.
    results = ["(Model failed or not loaded)"] * 2

    if audio_file is None:
        return ("⚠️ Please upload an audio file.",) * 2

    try:
        # Gradio 'numpy' type returns a tuple: (sample_rate, array)
        original_sr, audio_data = audio_file

        # --- AUDIO PRE-PROCESSING ---
        # 1. Ensure float32 (Gradio typically delivers int16 samples)
        audio_data = np.asarray(audio_data).astype(np.float32)

        # 2. Stereo (samples, channels) -> mono by averaging channels
        if audio_data.ndim > 1:
            audio_data = np.mean(audio_data, axis=1)

        # 3. Bail out early on empty input: np.max below raises on an empty array,
        #    which previously surfaced as a raw exception message instead of this warning.
        if audio_data.size == 0:
            return ("⚠️ Audio file is empty after processing.",) * 2

        # 4. Peak-normalize to [-1, 1] (also rescales raw int16 magnitudes)
        peak = np.max(np.abs(audio_data))
        if peak > 0:
            audio_data = audio_data / peak

        # 5. Resample to 16 kHz, the rate Whisper was trained on
        if original_sr != TARGET_SAMPLE_RATE:
            audio_data = librosa.resample(
                audio_data,
                orig_sr=original_sr,
                target_sr=TARGET_SAMPLE_RATE
            )

        # 6. Final length check (resampling a tiny clip could yield nothing)
        if len(audio_data) == 0:
            return ("⚠️ Audio file is empty after processing.",) * 2
    except Exception as e:
        return (f"❌ Error during audio processing: {str(e)}",) * 2

    # Gradio sliders can hand back floats; generate() requires integers here.
    num_beams = int(num_beams)
    max_length = int(max_length)
    do_sample = temperature > 0

    # --- INFERENCE LOOP ---
    for i, key in enumerate(MODEL_KEYS):
        if key not in models:
            continue  # keep the placeholder for models that never loaded

        try:
            model = models[key]
            processor = processors[key]

            # Prepare log-mel input features
            input_features = processor(
                audio_data,
                sampling_rate=TARGET_SAMPLE_RATE,
                return_tensors="pt"
            ).input_features
            # Match the model's device and precision (fp16 on GPU)
            input_features = input_features.to(device, dtype=torch_dtype)

            # Only pass `temperature` when actually sampling: recent transformers
            # versions warn/error when temperature is set with do_sample=False.
            gen_kwargs = {
                "repetition_penalty": repetition_penalty,
                "num_beams": num_beams,
                "max_length": max_length,
                "do_sample": do_sample,
            }
            if do_sample:
                gen_kwargs["temperature"] = temperature

            # Generate tokens without tracking gradients
            with torch.no_grad():
                predicted_ids = model.generate(input_features, **gen_kwargs)

            # Decode to text
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            results[i] = transcription if transcription.strip() else "⚠️ (No text generated)"
        except Exception as e:
            results[i] = f"❌ Error: {str(e)}"

    return results[0], results[1]
| # --- GRADIO UI --- | |
# --- GRADIO UI ---
# Layout: input + settings side by side, then the two transcripts side by side.
with gr.Blocks(title="Whisper Model Comparison") as demo:
    gr.Markdown(
        """
        # 🎤 Whisper 2-Model Comparison
        Upload audio once, compare results from 2 different models.
        *System automatically resamples audio to 16kHz to prevent distortion.*
        """
    )

    with gr.Row():
        # --- Left Column: Input ---
        with gr.Column(scale=1):
            audio_input = gr.Audio(
                label="Input Audio",
                type="numpy",  # delivers (sample_rate, np.ndarray) to the handler
                sources=["upload", "microphone"]
            )
            submit_btn = gr.Button("🚀 Transcribe", variant="primary", size="lg")

        # --- Right Column: Settings ---
        with gr.Column(scale=1):
            with gr.Accordion("⚙️ Generation Parameters", open=True):
                repetition_penalty = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
                temperature = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Temperature")
                num_beams = gr.Slider(1, 10, value=5, step=1, label="Beam Size")
                max_length = gr.Slider(50, 448, value=225, step=25, label="Max Token Length")

    # --- Output Row ---
    gr.Markdown("### 📊 Results")
    with gr.Row():
        with gr.Column():
            gr.Markdown(f"**{MODEL_KEYS[0]}**")
            out_1 = gr.Textbox(label="Transcript", lines=8, show_label=False)
        with gr.Column():
            gr.Markdown(f"**{MODEL_KEYS[1]}**")
            out_2 = gr.Textbox(label="Transcript", lines=8, show_label=False)

    # --- Event Binding ---
    submit_btn.click(
        fn=transcribe_both_models,
        inputs=[audio_input, repetition_penalty, temperature, num_beams, max_length],
        outputs=[out_1, out_2]
    )

# Launch only when executed as a script (Spaces may also import this module).
if __name__ == "__main__":
    demo.launch()