import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

from app.audio_processing import load_and_resample

# Sampling rate (Hz) that audio is resampled to and that is declared to the
# processor. Kept in one place so the two call sites cannot drift apart.
TARGET_SAMPLE_RATE = 16000

# Global variables to hold the model and processor.
# Both stay None until load_model() has completed successfully.
processor = None
model = None


def load_model(
    model_name: str = "facebook/wav2vec2-large-960h-lv60-self",
    device: str = "cpu",
) -> None:
    """
    Load the Hugging Face Wav2Vec2 model and processor into the module globals.

    The default checkpoint is the English
    'facebook/wav2vec2-large-960h-lv60-self' model.
    For multilingual use, consider models like:
    - 'facebook/mms-1b-all'
    - 'jonatasgrosman/wav2vec2-large-xlsr-53-english' (or other languages)

    Args:
        model_name: Hub identifier (or local path) passed to
            ``from_pretrained`` for both the processor and the model.
        device: Torch device the model is moved to. Defaults to "cpu" as a
            temporary workaround: the MPS (Apple Silicon GPU) backend in
            PyTorch currently has known bugs with Wav2Vec2 that can cause the
            forward pass to freeze indefinitely. Pass "cuda" or "mps"
            explicitly to opt in to an accelerator.

    Raises:
        Whatever ``from_pretrained`` or ``Module.to`` raise. In that case the
        module globals are left untouched.
    """
    global processor, model

    print(f"Loading model {model_name}...")
    loaded_processor = Wav2Vec2Processor.from_pretrained(model_name)
    loaded_model = Wav2Vec2ForCTC.from_pretrained(model_name)
    loaded_model.to(device)

    # Publish both globals only after everything succeeded, so a failure
    # part-way through never leaves a processor without a matching model.
    processor = loaded_processor
    model = loaded_model


def transcribe_audio(audio_filepath: str) -> str:
    """
    Transcribe the audio file at ``audio_filepath`` with the Wav2Vec2 model.

    The model is loaded lazily with its defaults on first use. Errors raised
    by that lazy load propagate to the caller; errors raised while reading the
    audio or running inference are caught and reported in the return value.

    Args:
        audio_filepath: Path of the audio file to transcribe.

    Returns:
        The lower-cased transcription, or a string starting with
        "Error during transcription: " if processing failed.
    """
    if model is None or processor is None:
        load_model()

    try:
        # 1. Load and resample audio to the target sampling rate
        speech = load_and_resample(audio_filepath, target_sr=TARGET_SAMPLE_RATE)

        # 2. Prepare inputs
        inputs = processor(
            speech,
            sampling_rate=TARGET_SAMPLE_RATE,
            return_tensors="pt",
            padding=True,
        )

        # Move inputs to the same device as the model
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # 3. Perform inference (no gradients needed)
        with torch.no_grad():
            logits = model(**inputs).logits

        # 4. Decode the output: most likely token per frame, then decode
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]

        return transcription.lower()
    except Exception as e:
        # Callers receive failures as a string rather than an exception.
        return f"Error during transcription: {str(e)}"