Hugging Face Space — MMS-1b-bbl ASR demo (status: Running)
| import gradio as gr | |
| import torch | |
| import torchaudio | |
| from transformers import AutoProcessor, Wav2Vec2ForCTC | |
# Checkpoint to serve: an MMS (Wav2Vec2 CTC) model fine-tuned for bbl.
MODEL_ID = "sb-x/mms-1b-bbl"

# Run on GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the feature-extractor/tokenizer pair and the acoustic model once at
# startup, then put the model in inference mode on the chosen device.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID).to(device)
model.eval()
def transcribe(audio):
    """Transcribe a Gradio audio input with the MMS CTC model.

    Args:
        audio: ``(sample_rate, waveform)`` tuple as produced by
            ``gr.Audio(type="numpy")``, or ``None`` when no audio was
            recorded/uploaded.

    Returns:
        The decoded transcription string, or ``""`` for empty input.
    """
    if audio is None:
        return ""
    sr, wav = audio
    wav = torch.tensor(wav)
    # BUG FIX: Gradio's numpy audio arrives as integer PCM (typically
    # int16). Wav2Vec2 expects float waveforms normalized to [-1, 1];
    # the original code merely cast to float, feeding values in the
    # +/-32768 range to the model and degrading the transcription.
    if not torch.is_floating_point(wav):
        scale = float(torch.iinfo(wav.dtype).max)
        wav = wav.float() / scale
    else:
        wav = wav.float()
    # Downmix multi-channel audio to mono (assumes (samples, channels)
    # layout, which is what Gradio emits for stereo input).
    if wav.ndim > 1:
        wav = wav.mean(dim=1)
    # The model was trained on 16 kHz audio; resample anything else.
    if sr != 16000:
        wav = torchaudio.functional.resample(wav, sr, 16000)
    inputs = processor(
        wav.numpy(),
        sampling_rate=16000,
        return_tensors="pt"
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    # Greedy CTC decoding: argmax per frame, then collapse via the processor.
    pred_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(pred_ids)[0]
# Wire the transcription function into a simple audio-in / text-out UI.
audio_in = gr.Audio(type="numpy")
text_out = gr.Textbox(label="Transcription", lines=10)

demo = gr.Interface(
    fn=transcribe,
    inputs=audio_in,
    outputs=text_out,
    title="MMS-1b-bbl ASR Demo",
    description="Fine-tuned MMS ASR model on bbl data from mozilla (CC BY-NC 4.0)",
)

demo.launch()