Spaces:
Running
Running
| import torch | |
| from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC | |
| from app.audio_processing import load_and_resample | |
# Module-level cache for the Wav2Vec2 processor and model. Both start out
# unset (None) and are filled in by load_model().
processor = model = None
def load_model(
    model_name: str = "facebook/wav2vec2-large-960h-lv60-self",
    device: str = "cpu",
):
    """
    Load a Hugging Face Wav2Vec2 CTC model and its processor into the module globals.

    Defaults to an English model. For multilingual use, consider models like:
    - 'facebook/mms-1b-all'
    - 'jonatasgrosman/wav2vec2-large-xlsr-53-english' (or other languages)

    Args:
        model_name: Model id (or local path) passed to ``from_pretrained`` for
            both the processor and the model.
        device: Torch device string the model is moved to. Defaults to "cpu"
            as a temporary workaround: the MPS (Apple Silicon GPU) backend in
            PyTorch has known bugs with Wav2Vec2 that can make the forward pass
            freeze indefinitely. Callers may pass e.g. "cuda" to opt in to a GPU.

    The module-level ``processor`` and ``model`` are only replaced after both
    objects have loaded and the model has been moved to ``device``. A failure
    part-way through therefore leaves the previous globals untouched instead of
    a half-updated pair.
    """
    global processor, model
    print(f"Loading model {model_name}...")
    new_processor = Wav2Vec2Processor.from_pretrained(model_name)
    new_model = Wav2Vec2ForCTC.from_pretrained(model_name)
    new_model.to(device)
    # Inference only: make sure dropout and similar training-time layers are off.
    new_model.eval()
    # Publish both together so callers never observe a mismatched pair.
    processor, model = new_processor, new_model
def transcribe_audio(audio_filepath: str) -> str:
    """
    Transcribe an audio file with the module's Wav2Vec2 CTC model.

    The model and processor are loaded lazily via ``load_model()`` if either
    module global is still unset.

    Args:
        audio_filepath: Path to the audio file, passed to ``load_and_resample``.

    Returns:
        The greedy (argmax-per-frame) decoded transcription, lower-cased.
        If any step after model loading raises, returns the string
        "Error during transcription: <message>" instead. In that case the
        full traceback is also printed to stderr so the cause is not lost.

    Raises:
        Any exception raised by ``load_model()``. Model loading happens before
        the error-to-string handling below, so it is not converted.
    """
    if model is None or processor is None:
        load_model()
    try:
        # 1. Load and resample audio to 16 kHz (the same rate given to the processor).
        speech = load_and_resample(audio_filepath, target_sr=16000)
        # 2. Prepare model inputs as PyTorch tensors.
        inputs = processor(speech, sampling_rate=16000, return_tensors="pt", padding=True)
        # Move inputs to whatever device the model's parameters live on.
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # 3. Forward pass without building an autograd graph.
        with torch.no_grad():
            logits = model(**inputs).logits
        # 4. Greedy decode: most likely token per frame, then CTC decode via the processor.
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
        return transcription.lower()
    except Exception as e:
        # Preserve the original "return an error string" contract, but keep the
        # traceback visible; otherwise the root cause of a failure is discarded.
        import traceback

        traceback.print_exc()
        return f"Error during transcription: {str(e)}"