import gradio as gr
from transformers import WhisperProcessor, WhisperForConditionalGeneration, GenerationConfig
import torch
import torchaudio
import numpy as np
import av  # Ensure you have installed this: pip install av

# --- Configuration and Model Loading ---
model_id = "OvozifyLabs/whisper-small-uz-v1"

# Check for GPU and set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model on device: {device}")

# Load the processor and model (only runs once at startup)
try:
    processor = WhisperProcessor.from_pretrained(model_id)
    model = WhisperForConditionalGeneration.from_pretrained(model_id).to(device)
except Exception as e:
    print(f"Error loading model or processor: {e}")
    # Degrade gracefully if the model cannot be loaded
    processor = None
    model = None


# --- Audio Loading Helper Function ---
def load_audio_file(file_path):
    """
    Loads an audio file (M4A, MP3, WAV, etc.) and ensures it is resampled
    to 16000 Hz and converted to mono, which Whisper models require.
    """
    sr_target = 16000  # Target sampling rate for the Whisper model

    if not file_path:
        raise FileNotFoundError("Audio file path is empty.")

    try:
        # 1. Try torchaudio's built-in loader first (usually handles WAV, FLAC well)
        audio, current_sr = torchaudio.load(file_path)

        # Resample if needed
        if current_sr != sr_target:
            if audio.dtype != torch.float32:
                audio = audio.float()
            resampler = torchaudio.transforms.Resample(orig_freq=current_sr, new_freq=sr_target)
            audio = resampler(audio)

        # Convert to mono if necessary (take the mean across channels)
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)

        return audio, sr_target

    except Exception as torchaudio_e:
        # 2. Fall back to PyAV (FFmpeg bindings) for formats like M4A and MP3
        # print(f"torchaudio failed, falling back to PyAV: {torchaudio_e}")
        try:
            audio_data_list = []
            with av.open(file_path) as container:
                stream = container.streams.audio[0]

                # Set up a resampler that outputs 16 kHz floating-point mono frames
                resampler = av.AudioResampler(
                    format="fltp",   # 32-bit floating point (planar)
                    layout="mono",   # force mono output
                    rate=sr_target,  # target sampling rate: 16000 Hz
                )

                # Decode the audio stream and resample each frame
                for frame in container.decode(stream):
                    for resampled_frame in resampler.resample(frame):
                        # to_ndarray() returns shape (1, N) for a mono frame;
                        # [0] selects the single channel's samples
                        audio_data_list.append(resampled_frame.to_ndarray()[0])

                # Flush any samples still buffered inside the resampler
                for resampled_frame in resampler.resample(None):
                    audio_data_list.append(resampled_frame.to_ndarray()[0])

            if not audio_data_list:
                raise RuntimeError("Could not decode audio frames using PyAV.")

            # Concatenate the 1D NumPy chunks into a single continuous array
            audio_np = np.concatenate(audio_data_list, axis=0)

            # Convert back to a (1, N) float32 PyTorch tensor (mono)
            audio = torch.from_numpy(audio_np).unsqueeze(0).float()
            return audio, sr_target

        except Exception as av_e:
            raise RuntimeError(
                f"Failed to load audio file using both torchaudio and PyAV. Error: {av_e}"
            )
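
# Example usage of the loader (hypothetical file name): regardless of the input
# container/codec, it returns a (1, N) float32 tensor at 16 kHz:
#     waveform, sr = load_audio_file("recording.m4a")
#     assert sr == 16000 and waveform.shape[0] == 1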

# --- Transcription Function ---
def transcribe_audio(audio_file_path):
    """
    Transcribes an audio file using the pre-loaded Whisper model.
    """
    if model is None:
        return "Error: Model was not loaded successfully at startup."
    if audio_file_path is None:
        return "Error: No audio file provided."

    try:
        # Load audio with the robust loader; returns a 16 kHz mono tensor
        audio, sr = load_audio_file(audio_file_path)

        # The processor expects a 1D NumPy array of raw audio:
        # audio.squeeze().numpy() turns the (1, N) torch tensor into an (N,) array
        inputs = processor(audio.squeeze().numpy(), sampling_rate=sr, return_tensors="pt")

        # Move inputs to the same device as the model
        input_features = inputs.input_features.to(device)

        # Force Uzbek transcription so the model does not auto-detect the language
        forced_ids = processor.get_decoder_prompt_ids(language="uz", task="transcribe")
        gen_config = GenerationConfig(
            forced_decoder_ids=forced_ids,
            max_length=448,
        )

        with torch.no_grad():
            predicted_ids = model.generate(
                input_features,
                generation_config=gen_config,
            )

        # Decode the generated token IDs into the text transcript
        text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        return text

    except Exception as e:
        return f"An error occurred during transcription: {e}"


# --- Gradio Interface Setup ---

# 🖼️ Interface description
title = "🇺🇿 Whisper Uz-Small v1: Audio Transcription"
description = (
    "A Gradio demo for the **OvozifyLabs/whisper-small-uz-v1** model for Uzbek ASR. "
    "Upload an audio file (M4A, MP3, WAV supported) or record directly."
)

# 🎤 Input component
audio_input = gr.Audio(
    sources=["microphone", "upload"],
    type="filepath",
    label="Input Audio (M4A/MP3/WAV, etc.)",
)

# 📝 Output component
text_output = gr.Textbox(label="Transcription Result")

# 🚀 Create the interface
demo = gr.Interface(
    fn=transcribe_audio,
    inputs=audio_input,
    outputs=text_output,
    title=title,
    description=description,
    # `flagging_mode` replaces the deprecated `allow_flagging` argument,
    # which raises a TypeError on recent Gradio versions
    flagging_mode="never",
    # theme=gr.themes.Soft(),
)

# 💻 Launch the app
if __name__ == "__main__":
    demo.launch()
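
# Optional smoke test (hypothetical file name): call transcribe_audio directly,
# bypassing the UI, to verify the model and audio loader before deploying:
#     print(transcribe_audio("sample_uz.m4a"))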