# Multilingual-ASR / app/asr_model.py
# Last commit 9622244 (adiitya29): "fix: loaded the larger model for better
# results and added the downloading functionality for transcribed text"
import logging

import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

from app.audio_processing import load_and_resample
# Module-level cache for the Wav2Vec2 processor and model. Both start unset
# (None) and are populated by load_model(); transcribe_audio() triggers that
# load lazily when either is still None.
processor = model = None
def load_model(
    model_name: str = "facebook/wav2vec2-large-960h-lv60-self",
    device: str = "cpu",
):
    """
    Load a Hugging Face Wav2Vec2 CTC model and its processor into the module globals.

    The default checkpoint is the large English model
    ``facebook/wav2vec2-large-960h-lv60-self``. For multilingual transcription,
    consider checkpoints such as:
    - 'facebook/mms-1b-all'
    - 'jonatasgrosman/wav2vec2-large-xlsr-53-english' (or other languages)

    Args:
        model_name: Hugging Face Hub model id (or local path) passed to
            ``from_pretrained`` for both the processor and the model.
        device: torch device string the model is moved to (e.g. "cpu",
            "cuda"). Defaults to "cpu"; see the note on MPS below.

    Side effects:
        Sets the module-level ``processor`` and ``model`` globals. Both are
        assigned only after loading and the device move have succeeded, so
        an exception leaves the previously loaded pair untouched.
    """
    global processor, model
    print(f"Loading model {model_name}...")

    new_processor = Wav2Vec2Processor.from_pretrained(model_name)
    new_model = Wav2Vec2ForCTC.from_pretrained(model_name)

    # NOTE: the default stays "cpu" on purpose. The original author reported
    # that the MPS (Apple Silicon GPU) backend in PyTorch can freeze
    # indefinitely during the Wav2Vec2 forward pass, so "mps" should only be
    # passed explicitly once that has been verified on the target machine.
    new_model.to(device)

    # Publish both together so readers never see a mismatched pair.
    processor, model = new_processor, new_model
def transcribe_audio(audio_filepath: str) -> str:
    """
    Transcribe an audio file with the loaded Wav2Vec2 CTC model.

    The model and processor are loaded lazily with ``load_model()``'s
    default arguments if either module global is still None.

    Args:
        audio_filepath: Path to the audio file. It is loaded and resampled
            to 16 kHz by ``load_and_resample``.

    Returns:
        The argmax-decoded transcription, lower-cased. If any step of the
        load/featurize/inference/decode pipeline raises, the exception is
        logged with its traceback and the string
        ``"Error during transcription: <message>"`` is returned instead.
        Callers that need to detect failure must check for that prefix.

    Raises:
        Exceptions from ``load_model()`` are not caught: the lazy load runs
        before the error-to-string handling.
    """
    if model is None or processor is None:
        load_model()

    try:
        # 1. Load and resample audio to the 16 kHz rate the processor is told below.
        speech = load_and_resample(audio_filepath, target_sr=16000)
        # 2. Featurize into PyTorch tensors.
        inputs = processor(speech, sampling_rate=16000, return_tensors="pt", padding=True)
        # Move every input tensor to whichever device the model weights live on.
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # 3. Inference only: no autograd bookkeeping needed.
        with torch.no_grad():
            logits = model(**inputs).logits
        # 4. Pick the most likely token id per frame and decode to text.
        #    A single clip was passed, so take the first (only) batch entry.
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
    except Exception as e:
        # Preserve the best-effort contract (return an error string) but keep
        # the traceback so failures are diagnosable instead of silently lost.
        logging.getLogger(__name__).exception(
            "Transcription failed for %s", audio_filepath
        )
        return f"Error during transcription: {str(e)}"
    else:
        return transcription.lower()