import gradio as gr
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torch
import torchaudio
import numpy as np
import av  # Ensure you have installed this: pip install av
# --- Configuration and Model Loading ---
model_id = "OvozifyLabs/whisper-small-uz-v1"

# Check for GPU and set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model on device: {device}")

# Load the processor and model (only runs once at startup)
try:
    processor = WhisperProcessor.from_pretrained(model_id)
    model = WhisperForConditionalGeneration.from_pretrained(model_id).to(device)
except Exception as e:
    print(f"Error loading model or processor: {e}")
    # Handle the error gracefully if the model cannot be loaded
    processor = None
    model = None
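
# Optional (an assumption, not part of the original app): on a CUDA device the
# model can be loaded in half precision to roughly halve GPU memory use.
# A minimal sketch:
#
#     model = WhisperForConditionalGeneration.from_pretrained(
#         model_id, torch_dtype=torch.float16
#     ).to(device)
#
# The input features would then also need .to(device, dtype=torch.float16)
# before calling model.generate().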
# --- Audio Loading Helper Function ---
def load_audio_file(file_path):
    """
    Loads an audio file (M4A, MP3, WAV, etc.) and ensures it is resampled
    to 16000 Hz and converted to mono, as Whisper models require.
    """
    sr_target = 16000  # Target sampling rate for the Whisper model

    if not file_path:
        raise FileNotFoundError("Audio file path is empty.")

    try:
        # 1. Try torchaudio's built-in loader first (usually handles WAV/FLAC well)
        audio, current_sr = torchaudio.load(file_path)

        # Resample to 16 kHz if needed
        if current_sr != sr_target:
            if audio.dtype != torch.float32:
                audio = audio.float()
            resampler = torchaudio.transforms.Resample(orig_freq=current_sr, new_freq=sr_target)
            audio = resampler(audio)
            current_sr = sr_target

        # Convert to mono if necessary (take the mean across channels)
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)

        return audio, current_sr
    except Exception as torchaudio_e:
        # 2. Fall back to PyAV (FFmpeg wrapper) for formats like M4A and MP3
        # print(f"torchaudio failed ({torchaudio_e}); falling back to PyAV")
        try:
            audio_data_list = []
            with av.open(file_path) as container:
                stream = container.streams.audio[0]
                # Set up a resampler to ensure 16 kHz float mono output
                resampler = av.AudioResampler(
                    format='fltp',   # 32-bit floating point (planar)
                    layout='mono',   # Force mono output
                    rate=sr_target   # Target sampling rate: 16000 Hz
                )
                # Decode the audio stream and resample each frame
                for frame in container.decode(stream):
                    for resampled_frame in resampler.resample(frame):
                        # to_ndarray() returns a (channels, samples) array;
                        # [0] selects the single mono channel's samples.
                        audio_data_list.append(resampled_frame.to_ndarray()[0])
                # Flush any samples still buffered inside the resampler
                for resampled_frame in resampler.resample(None):
                    audio_data_list.append(resampled_frame.to_ndarray()[0])

            if not audio_data_list:
                raise RuntimeError("Could not decode audio frames using PyAV.")

            # Concatenate the per-frame 1D arrays into one continuous array
            audio_np = np.concatenate(audio_data_list, axis=0)
            # Convert back to a (1, N) float32 PyTorch tensor (mono)
            audio = torch.from_numpy(audio_np).unsqueeze(0).float()
            return audio, sr_target
        except Exception as av_e:
            raise RuntimeError(
                f"Failed to load audio with both torchaudio ({torchaudio_e}) "
                f"and PyAV ({av_e})."
            )
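
# Quick sanity check for the loader (the file name here is hypothetical, not
# shipped with this Space): both branches should yield a (1, N) float32 tensor
# at 16 kHz.
#
#     audio, sr = load_audio_file("sample.m4a")
#     assert sr == 16000 and audio.shape[0] == 1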
# --- Transcription Function ---
def transcribe_audio(audio_file_path):
    """
    Transcribes an audio file using the pre-loaded Whisper model.
    """
    if model is None:
        return "Error: Model was not loaded successfully at startup."
    if audio_file_path is None:
        return "Error: No audio file provided."

    try:
        # Load audio with the robust loader and get a 16 kHz mono tensor
        audio, sr = load_audio_file(audio_file_path)

        # The processor expects a 1D NumPy array of raw audio;
        # audio.squeeze().numpy() converts the (1, N) tensor to shape (N,)
        inputs = processor(audio.squeeze().numpy(), sampling_rate=sr, return_tensors="pt")

        # Move inputs to the appropriate device
        input_features = inputs.input_features.to(device)

        with torch.no_grad():
            # Pin the language and task so the model transcribes Uzbek.
            # (Recent transformers versions also accept language=/task=
            # keyword arguments on generate() instead of forced_decoder_ids.)
            predicted_ids = model.generate(
                input_features,
                forced_decoder_ids=processor.get_decoder_prompt_ids(language="uz", task="transcribe"),
                max_length=448  # Cap output length for speed/resource management
            )

        # Decode the generated token IDs into the text transcript
        text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        return text
    except Exception as e:
        return f"An error occurred during transcription: {e}"
# --- Gradio Interface Setup ---
# 🖼️ Interface Description
title = "🇺🇿 Whisper Uz-Small v1: Audio Transcription"
description = "A Gradio demo of the **OvozifyLabs/whisper-small-uz-v1** model for Uzbek ASR. Upload an audio file (M4A, MP3, and WAV are supported) or record directly."

# 🎤 Input Component
audio_input = gr.Audio(
    sources=["microphone", "upload"],
    type="filepath",
    label="Input Audio (M4A/MP3/WAV, etc.)"
)

# 📝 Output Component
text_output = gr.Textbox(label="Transcription Result")
# 🚀 Create the Interface
demo = gr.Interface(
    fn=transcribe_audio,
    inputs=audio_input,
    outputs=text_output,
    title=title,
    description=description,
    # 'allow_flagging' was removed in newer Gradio releases and raises a
    # TypeError; there, flagging_mode="never" hides the flagging button.
    # flagging_mode="never",
    # theme=gr.themes.Soft()
)
# 💻 Launch the App
if __name__ == "__main__":
    demo.launch()
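
# Optional launch tweaks (assumptions, not in the original app):
# demo.launch(share=True) creates a temporary public link, and calling
# demo.queue() before launch serializes requests so concurrent users
# don't contend for the single model instance.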