import gradio as gr
from transformers import WhisperProcessor, WhisperForConditionalGeneration, GenerationConfig
import torch
import torchaudio
import numpy as np
import av
# --- Configuration and Model Loading ---
model_id = "OvozifyLabs/whisper-small-uz-v1"

# Check for a GPU and set the device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model on device: {device}")

# Load the processor and model (runs once at startup)
try:
    processor = WhisperProcessor.from_pretrained(model_id)
    model = WhisperForConditionalGeneration.from_pretrained(model_id).to(device)
except Exception as e:
    print(f"Error loading model or processor: {e}")
    processor = None
    model = None
# --- Audio Loading Helper Function ---
def load_audio_file(file_path):
    """
    Loads an audio file (M4A, MP3, WAV, etc.) and ensures it is resampled
    to 16000 Hz and converted to mono, as Whisper models require.
    """
    sr_target = 16000
    if not file_path:
        raise FileNotFoundError("Audio file path is empty.")
    try:
        # Try torchaudio's built-in loader first
        audio, sr = torchaudio.load(file_path)
        # Resample if needed
        if sr != sr_target:
            if audio.dtype != torch.float32:
                audio = audio.float()
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sr_target)
            audio = resampler(audio)
        # Convert to mono if necessary
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)
        return audio, sr_target
    except Exception as torchaudio_e:
        # Fall back to PyAV for formats torchaudio cannot open (e.g. M4A)
        try:
            audio_data_list = []
            with av.open(file_path) as container:
                stream = container.streams.audio[0]
                resampler = av.AudioResampler(
                    format='fltp',
                    layout='mono',
                    rate=sr_target
                )
                for frame in container.decode(stream):
                    for resampled_frame in resampler.resample(frame):
                        audio_data_list.append(resampled_frame.to_ndarray()[0])
            if not audio_data_list:
                raise RuntimeError("Could not decode audio frames using PyAV.")
            audio_np = np.concatenate(audio_data_list, axis=0)
            audio = torch.from_numpy(audio_np).unsqueeze(0).float()
            return audio, sr_target
        except Exception as av_e:
            raise RuntimeError(
                f"Failed to load audio file with both torchaudio ({torchaudio_e}) "
                f"and PyAV ({av_e})."
            )
# --- Audio Chunking Function ---
def chunk_audio(audio_tensor, sampling_rate, chunk_length_s=30, overlap_s=5):
    """
    Splits audio into overlapping chunks.

    Args:
        audio_tensor: torch.Tensor of shape (1, num_samples) - mono audio
        sampling_rate: int - sampling rate of the audio
        chunk_length_s: float - length of each chunk in seconds
        overlap_s: float - overlap between consecutive chunks in seconds

    Returns:
        List of audio chunks (torch.Tensors)
    """
    chunk_samples = int(chunk_length_s * sampling_rate)
    overlap_samples = int(overlap_s * sampling_rate)
    stride = chunk_samples - overlap_samples
    audio_length = audio_tensor.shape[1]

    # If the audio is shorter than one chunk, return it as a single chunk
    if audio_length <= chunk_samples:
        return [audio_tensor]

    # Split into overlapping chunks
    chunks = []
    start = 0
    while start < audio_length:
        end = min(start + chunk_samples, audio_length)
        chunks.append(audio_tensor[:, start:end])
        # Stop once the final chunk reaches the end of the audio
        if end >= audio_length:
            break
        start += stride
    return chunks
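# Worked example: a 70 s mono clip at 16 kHz with the defaults above
# (chunk_length_s=30, overlap_s=5) gives stride = (30 - 5) * 16000 =
# 400000 samples, so the chunks cover 0-30 s, 25-55 s, and 50-70 s:
#   clip = torch.zeros(1, 70 * 16000)  # hypothetical silent clip
#   [c.shape[1] for c in chunk_audio(clip, 16000)]
#   # -> [480000, 480000, 320000]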
# --- Transcription Function ---
def transcribe_audio(audio_file_path, language):
    """
    Transcribes an audio file using the pre-loaded Whisper model.
    Audio longer than 30 seconds is chunked automatically.
    """
    if model is None:
        return "Error: Model was not loaded successfully at startup."
    if audio_file_path is None:
        return "Error: No audio file provided."
    lang_dict = {"Uzbek": "uz", "Russian": "ru", "English": "en"}
    language = lang_dict[language]
    try:
        # Load audio using the robust loader
        audio, sr = load_audio_file(audio_file_path)
        duration_s = audio.shape[1] / sr

        # chunk_audio already returns a short clip as a single chunk, so
        # short and long audio share the same transcription loop below.
        chunks = chunk_audio(audio, sr, chunk_length_s=30, overlap_s=5)
        if len(chunks) > 1:
            print(f"Audio duration: {duration_s:.2f}s - Chunking into segments...")
        else:
            print(f"Audio duration: {duration_s:.2f}s - Processing as single segment...")

        # Force the decoder to the selected language; the prompt ids do not
        # depend on the chunk, so the generation config is built once.
        forced_ids = processor.get_decoder_prompt_ids(language=language, task="transcribe")
        gen_config = GenerationConfig(
            forced_decoder_ids=forced_ids,
            max_length=448
        )

        transcriptions = []
        for i, chunk in enumerate(chunks):
            print(f"Processing chunk {i + 1}/{len(chunks)}...")
            inputs = processor(chunk.squeeze().numpy(), sampling_rate=sr, return_tensors="pt")
            input_features = inputs.input_features.to(device)
            with torch.no_grad():
                predicted_ids = model.generate(
                    input_features,
                    generation_config=gen_config
                )
            text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            transcriptions.append(text)

        # Combine all chunk transcriptions
        full_transcription = " ".join(transcriptions)
        if len(chunks) > 1:
            return (f"[Audio duration: {duration_s:.2f}s - "
                    f"Processed in {len(chunks)} chunks]\n\n{full_transcription}")
        return full_transcription
    except Exception as e:
        return f"An error occurred during transcription: {e}"
# --- Gradio Interface Setup ---
title = "Whisper Small Uz v1: Multilingual audio transcription"
description = """A Gradio demo for the **OvozifyLabs/whisper-small-uz-v1** model for Uzbek ASR.
Upload an audio file (M4A, MP3, WAV supported) or record directly."""

language_input = gr.Dropdown(
    label="Select Language",
    choices=["Uzbek", "English", "Russian"],
    value="Uzbek"
)
audio_input = gr.Audio(
    sources=["microphone", "upload"],
    type="filepath",
    label="Input Audio (M4A/MP3/WAV, etc.)"
)
text_output = gr.Textbox(label="Transcription Result", lines=6, max_lines=25)

demo = gr.Interface(
    fn=transcribe_audio,
    inputs=[audio_input, language_input],
    outputs=text_output,
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch()
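# Note: demo.launch(share=True) would additionally create a temporary
# public link when running outside a Space; plain launch() is enough here.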