import gradio as gr
from transformers import WhisperProcessor, WhisperForConditionalGeneration, GenerationConfig
import torch
import torchaudio
import numpy as np
import av

# --- Configuration and Model Loading ---
model_id = "OvozifyLabs/whisper-small-uz-v1"

# Check for GPU and set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model on device: {device}")

# Load the processor and model (only runs once at startup)
try:
    processor = WhisperProcessor.from_pretrained(model_id)
    model = WhisperForConditionalGeneration.from_pretrained(model_id).to(device)
except Exception as e:
    print(f"Error loading model or processor: {e}")
    processor = None
    model = None


# --- Audio Loading Helper Function ---
def load_audio_file(file_path):
    """
    Loads an audio file (handles M4A, MP3, WAV, etc.) and ensures it is
    resampled to 16000 Hz and converted to mono, which Whisper models require.
    """
    sr_target = 16000

    if not file_path:
        raise FileNotFoundError("Audio file path is empty.")

    audio_data_list = []
    current_sr = sr_target

    try:
        # Try torchaudio's built-in loader first
        audio, sr = torchaudio.load(file_path)
        current_sr = sr

        # Resample if needed
        if current_sr != sr_target:
            if audio.dtype != torch.float32:
                audio = audio.float()
            resampler = torchaudio.transforms.Resample(orig_freq=current_sr, new_freq=sr_target)
            audio = resampler(audio)
            current_sr = sr_target

        # Convert to mono if necessary
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)

        return audio, current_sr

    except Exception as torchaudio_e:
        # Fall back to PyAV for formats torchaudio cannot open (e.g. M4A, MP3)
        try:
            with av.open(file_path) as container:
                stream = container.streams.audio[0]
                # Decode, downmix to mono, and resample to 16 kHz in one pass
                resampler = av.AudioResampler(
                    format='fltp',
                    layout='mono',
                    rate=sr_target
                )
                for frame in container.decode(stream):
                    for resampled_frame in resampler.resample(frame):
                        audio_data_list.append(resampled_frame.to_ndarray()[0])

            if not audio_data_list:
                raise RuntimeError("Could not decode audio frames using PyAV.")

            audio_np = np.concatenate(audio_data_list, axis=0)
            audio = torch.from_numpy(audio_np).unsqueeze(0).float()
            return audio, sr_target
        except Exception as av_e:
            raise RuntimeError(
                f"Failed to load audio file with both torchaudio ({torchaudio_e}) "
                f"and PyAV ({av_e})."
            )


# --- Audio Chunking Function ---
def chunk_audio(audio_tensor, sampling_rate, chunk_length_s=30, overlap_s=5):
    """
    Splits audio into overlapping chunks.

    Args:
        audio_tensor: torch.Tensor of shape (1, num_samples) - mono audio
        sampling_rate: int - sampling rate of the audio
        chunk_length_s: float - length of each chunk in seconds
        overlap_s: float - overlap between chunks in seconds

    Returns:
        List of audio chunks (torch.Tensors)
    """
    chunk_samples = int(chunk_length_s * sampling_rate)
    overlap_samples = int(overlap_s * sampling_rate)
    stride = chunk_samples - overlap_samples
    audio_length = audio_tensor.shape[1]

    chunks = []

    # If audio is shorter than chunk length, return it as a single chunk
    if audio_length <= chunk_samples:
        return [audio_tensor]

    # Split into chunks with overlap
    start = 0
    while start < audio_length:
        end = min(start + chunk_samples, audio_length)
        chunk = audio_tensor[:, start:end]
        chunks.append(chunk)

        # Stop once we've reached the end of the audio
        if end >= audio_length:
            break

        start += stride

    return chunks
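
# A quick sanity check of the chunking arithmetic (illustrative numbers only,
# not part of the app): at 16 kHz with 30 s chunks and a 5 s overlap, the
# stride is 25 s, so a 75 s clip yields chunks starting at 0 s, 25 s, and
# 50 s, i.e. three chunks. Uncomment to verify:
#
#   _dummy = torch.zeros(1, 75 * 16000)
#   assert len(chunk_audio(_dummy, 16000)) == 3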

# --- Transcription Function ---
def transcribe_audio(audio_file_path, language):
    """
    Transcribes an audio file using the pre-loaded Whisper model.
    Automatically chunks audio longer than 30 seconds.
    """
    if model is None:
        return "Error: Model was not loaded successfully at startup."
    if audio_file_path is None:
        return "Error: No audio file provided."

    # Map the dropdown label to the Whisper language code
    lang_dict = {
        "Uzbek": "uz",
        "Russian": "ru",
        "English": "en"
    }
    language = lang_dict[language]

    try:
        # Load audio using the robust loader
        audio, sr = load_audio_file(audio_file_path)

        # Calculate audio duration
        duration_s = audio.shape[1] / sr

        # Check if chunking is needed
        if duration_s > 30:
            print(f"Audio duration: {duration_s:.2f}s - Chunking into segments...")
            chunks = chunk_audio(audio, sr, chunk_length_s=30, overlap_s=5)

            # Transcribe each chunk
            transcriptions = []
            for i, chunk in enumerate(chunks):
                print(f"Processing chunk {i+1}/{len(chunks)}...")
                inputs = processor(chunk.squeeze().numpy(), sampling_rate=sr, return_tensors="pt")
                input_features = inputs.input_features.to(device)

                # Note: newer transformers releases deprecate forced_decoder_ids
                # in favour of model.generate(..., language=..., task=...)
                forced_ids = processor.get_decoder_prompt_ids(language=language, task="transcribe")
                gen_config = GenerationConfig(
                    forced_decoder_ids=forced_ids,
                    max_length=448
                )

                with torch.no_grad():
                    predicted_ids = model.generate(
                        input_features,
                        generation_config=gen_config
                    )

                text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
                transcriptions.append(text)

            # Combine all transcriptions. With a 5 s overlap, naive joining may
            # repeat a few words at chunk boundaries.
            full_transcription = " ".join(transcriptions)
            return f"[Audio duration: {duration_s:.2f}s - Processed in {len(chunks)} chunks]\n\n{full_transcription}"
        else:
            # Process normally for short audio
            print(f"Audio duration: {duration_s:.2f}s - Processing as single segment...")
            inputs = processor(audio.squeeze().numpy(), sampling_rate=sr, return_tensors="pt")
            input_features = inputs.input_features.to(device)

            forced_ids = processor.get_decoder_prompt_ids(language=language, task="transcribe")
            gen_config = GenerationConfig(
                forced_decoder_ids=forced_ids,
                max_length=448
            )

            with torch.no_grad():
                predicted_ids = model.generate(
                    input_features,
                    generation_config=gen_config
                )

            text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            return text

    except Exception as e:
        return f"An error occurred during transcription: {e}"


# --- Gradio Interface Setup ---
title = "Whisper Small Uz v1: Multilingual audio transcription"
description = """A Gradio demo for the **OvozifyLabs/whisper-small-uz-v1** model for Uzbek ASR.
Upload an audio file (M4A, MP3, and WAV are supported) or record directly.
"""

language_input = gr.Dropdown(
    label="Select Language",
    choices=["Uzbek", "English", "Russian"],
    value="Uzbek"
)

audio_input = gr.Audio(
    sources=["microphone", "upload"],
    type="filepath",
    label="Input Audio (M4A/MP3/WAV, etc.)"
)

text_output = gr.Textbox(label="Transcription Result", lines=6, max_lines=25)

demo = gr.Interface(
    fn=transcribe_audio,
    inputs=[audio_input, language_input],
    outputs=text_output,
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch()
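
# transcribe_audio can also be exercised without the UI, e.g. from a Python
# REPL; "sample.m4a" below is a hypothetical file name, not one shipped with
# this repo:
#
#   print(transcribe_audio("sample.m4a", "Uzbek"))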