Spaces:
Running
Running
File size: 2,188 Bytes
bddec1e c17b8b3 bddec1e 9622244 c17b8b3 bddec1e c17b8b3 bddec1e 8f2047c c17b8b3 bddec1e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 | import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from app.audio_processing import load_and_resample
# Lazily-populated module state: both names stay None until load_model() assigns
# them, after which transcribe_audio() reuses the same objects on every call.
model = None
processor = None
def load_model(model_name: str = "facebook/wav2vec2-large-960h-lv60-self"):
    """
    Load a pretrained Wav2Vec2 checkpoint and publish it as module state.

    Both the processor and the CTC model are fetched from ``model_name`` and
    stored in the module-level ``processor`` and ``model`` globals; nothing is
    returned. The default is an English checkpoint. For multilingual use,
    candidate checkpoints include:
    - 'facebook/mms-1b-all'
    - 'jonatasgrosman/wav2vec2-large-xlsr-53-english' (or other languages)
    """
    global processor, model

    print(f"Loading model {model_name}...")

    # The processor is assigned first, so it remains set even if loading the
    # model weights raises afterwards.
    processor = Wav2Vec2Processor.from_pretrained(model_name)
    model = Wav2Vec2ForCTC.from_pretrained(model_name)

    # Temporary workaround: inference is pinned to the CPU.
    # PyTorch's MPS (Apple Silicon GPU) backend has known bugs with Wav2Vec2
    # that can make the forward pass hang indefinitely. Once that is resolved,
    # the intended selection is "mps" when torch.backends.mps.is_available(),
    # otherwise "cuda" when torch.cuda.is_available(), otherwise "cpu".
    target_device = "cpu"
    model.to(target_device)
def transcribe_audio(audio_filepath: str) -> str:
    """
    Transcribe the audio file at ``audio_filepath`` with the loaded Wav2Vec2 model.

    The model and processor are loaded on first use via ``load_model()``. The
    audio is loaded and resampled to 16 kHz, passed through the model with
    gradient tracking disabled, and decoded from the argmax over the last
    logits axis.

    Args:
        audio_filepath: Path of the audio file handed to ``load_and_resample``.

    Returns:
        The lower-cased transcription of the first (only) batch item. If loading
        the audio, preprocessing, inference or decoding raises, the string
        ``"Error during transcription: <message>"`` is returned instead of
        raising, and the failure is logged together with its traceback.

    Note:
        Exceptions raised by ``load_model()`` are not caught here and propagate
        to the caller.
    """
    # Function-scope import: only needed on the failure path below.
    import logging

    if model is None or processor is None:
        load_model()

    # Single source of truth for the rate used by both resampling and the
    # processor, so the two can never drift apart.
    target_rate = 16000

    try:
        # 1. Load the waveform and resample it to the target rate.
        speech = load_and_resample(audio_filepath, target_sr=target_rate)

        # 2. Turn the raw waveform into model-ready tensors.
        inputs = processor(
            speech,
            sampling_rate=target_rate,
            return_tensors="pt",
            padding=True,
        )

        # Place every input tensor on whichever device the model weights live on.
        device = next(model.parameters()).device
        inputs = {key: tensor.to(device) for key, tensor in inputs.items()}

        # 3. Forward pass; no gradients are needed for inference.
        with torch.no_grad():
            logits = model(**inputs).logits

        # 4. Pick the highest-scoring token id along the last axis and decode.
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
        return transcription.lower()
    except Exception as e:
        # Keep the best-effort contract (callers receive a string), but record
        # the traceback so the failure is not silently lost.
        logging.getLogger(__name__).exception(
            "Transcription failed for %r", audio_filepath
        )
        return f"Error during transcription: {str(e)}"
|