import torch
import torchaudio
import gradio as gr
from transformers import Wav2Vec2BertProcessor, Wav2Vec2BertForCTC
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the processor and the original (full-precision) model
model_name = "cdactvm/w2v-bert-punjabi"  # Punjabi Wav2Vec2-BERT checkpoint
processor = Wav2Vec2BertProcessor.from_pretrained(model_name)
original_model = Wav2Vec2BertForCTC.from_pretrained(model_name).to(device)
original_model.eval()

# Explicitly allow Wav2Vec2BertForCTC during unpickling (PyTorch 2.6+
# defaults torch.load to weights_only=True, which rejects custom classes)
torch.serialization.add_safe_globals([Wav2Vec2BertForCTC])
# Load the full quantized model (saved as a complete pickled module)
quantized_model = torch.load(
    "cdactvm/w2v-bert-punjabi/wav2vec2_bert_qint8.pth",
    map_location="cpu",  # dynamically quantized modules run on CPU
    weights_only=False,
)
quantized_model.eval()
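
# NOTE (assumption): the checkpoint above is presumed to have been produced
# offline with PyTorch dynamic quantization, along the lines of:
#   quantized = torch.ao.quantization.quantize_dynamic(
#       original_model.cpu(), {torch.nn.Linear}, dtype=torch.qint8
#   )
#   torch.save(quantized, "wav2vec2_bert_qint8.pth")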
#####################################################
# Recognize speech using the original model
def transcribe_original_model(audio_path):
    # Load audio file
    waveform, sample_rate = torchaudio.load(audio_path)
    # Convert stereo to mono (if needed)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    # Resample to 16 kHz, the rate the model expects
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    # Extract input features and move them to the model's device
    inputs = processor(waveform.squeeze(0), sampling_rate=16000, return_tensors="pt")
    inputs = {key: val.to(device) for key, val in inputs.items()}
    # Get logits and greedily decode
    with torch.no_grad():
        logits = original_model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]
    return transcription
# Recognize speech using the quantized model
def transcribe_quantized_model(audio_path):
    # Load audio file
    waveform, sample_rate = torchaudio.load(audio_path)
    # Convert stereo to mono (if needed)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    # Resample to 16 kHz
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    # Extract input features; qint8 dynamic quantization executes on CPU,
    # so the inputs stay on the CPU as well
    inputs = processor(waveform.squeeze(0), sampling_rate=16000, return_tensors="pt")
    # Get logits and greedily decode
    with torch.no_grad():
        logits = quantized_model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]
    return transcription
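
# Quick sanity check (assumption: a local audio file named sample.wav exists):
#   print(transcribe_original_model("sample.wav"))
#   print(transcribe_quantized_model("sample.wav"))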
def select_lng(model_choice, audio):
    # A single gr.Audio component covers both microphone and file upload,
    # so this handler receives one audio filepath (or None)
    if audio is None:
        return "You must either provide a mic recording or a file"
    if model_choice == "original_model":
        return transcribe_original_model(audio)
    elif model_choice == "quantized_model":
        return transcribe_quantized_model(audio)
    return "Please select a model"
# Gradio interface
demo = gr.Interface(
    fn=select_lng,
    inputs=[
        gr.Dropdown(["original_model", "quantized_model"], label="Select Model"),
        gr.Audio(sources=["microphone", "upload"], type="filepath"),
    ],
    outputs=["textbox"],
    title="Automatic Speech Recognition",
    description="Upload an audio file and get the transcription in Punjabi.",
)
if __name__ == "__main__":
    demo.launch()  # the Interface object is named demo, not app
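
# To try this locally (assuming the file is saved as app.py):
#   python app.py
# then open the URL Gradio prints in a browser.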