import functools
import time

import librosa
import torch
from transformers import Wav2Vec2ForCTC, AutoProcessor
from transformers import set_seed
@functools.lru_cache(maxsize=4)
def _load_asr_components(target_lang: str):
    """Load (and cache) the MMS processor and CTC model for *target_lang*.

    Caching avoids re-instantiating (and potentially re-downloading) the
    ~1B-parameter checkpoint on every transcribe() call; repeated calls
    with the same language reuse the loaded pair.
    """
    model_id = "facebook/mms-1b-all"
    processor = AutoProcessor.from_pretrained(model_id, target_lang=target_lang)
    # ignore_mismatched_sizes lets the CTC head be resized to the
    # target language adapter's vocabulary.
    model = Wav2Vec2ForCTC.from_pretrained(
        model_id, target_lang=target_lang, ignore_mismatched_sizes=True
    )
    model.eval()  # inference only; disables dropout for deterministic output
    return processor, model


def transcribe(fp: str, target_lang: str) -> str:
    """Transcribe the audio file at *fp* into text.

    Parameters
    ----------
    fp : str
        Path to the audio file (any format librosa can read).
    target_lang : str
        ISO 639-3 code of the target language (MMS adapter name).

    Returns
    -------
    str
        The transcribed text.
    """
    # Ensure replicability
    set_seed(555)
    start_time = time.time()
    # Load transcription model (cached across calls)
    processor, model = _load_asr_components(target_lang)
    # Process the audio; MMS expects 16 kHz mono input, so resample on load.
    signal, _ = librosa.load(fp, sr=16000)
    inputs = processor(signal, sampling_rate=16_000, return_tensors="pt")
    # Inference: greedy CTC decoding of the argmax token ids
    with torch.no_grad():
        logits = model(**inputs).logits
    ids = torch.argmax(logits, dim=-1)[0]
    transcript = processor.decode(ids)
    print("Time elapsed: ", int(time.time() - start_time), " seconds")
    return transcript