File size: 2,188 Bytes
bddec1e
 
 
c17b8b3
bddec1e
 
 
 
9622244
c17b8b3
 
bddec1e
 
 
c17b8b3
bddec1e
 
 
 
 
8f2047c
 
 
 
 
 
 
 
 
 
c17b8b3
 
 
 
 
 
bddec1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from app.audio_processing import load_and_resample

# Module-level cache for the loaded processor/model pair.
# Both names are None until load_model() rebinds them.
processor = model = None

def load_model(
    model_name: str = "facebook/wav2vec2-large-960h-lv60-self",
    device: str = "cpu",
) -> None:
    """
    Load a Hugging Face Wav2Vec2 processor and CTC model into the module globals.

    The default checkpoint is an English model. For multilingual use, consider
    models like:
    - 'facebook/mms-1b-all'
    - 'jonatasgrosman/wav2vec2-large-xlsr-53-english' (or other languages)

    Args:
        model_name: Checkpoint identifier handed to ``from_pretrained`` for both
            the processor and the model.
        device: Torch device the model is moved to. Defaults to "cpu" as a
            temporary workaround: the MPS (Apple Silicon GPU) backend in PyTorch
            has been observed to freeze indefinitely in the Wav2Vec2 forward
            pass. Pass "cuda" or "mps" explicitly to opt in to an accelerator.

    Side effects:
        Prints a progress message and rebinds the module-level ``processor``
        and ``model``. The globals are only rebound after both objects loaded
        and the model was moved to ``device``; if anything raises, the
        previous values are left untouched.

    Raises:
        Any exception raised by ``from_pretrained`` or ``model.to`` propagates
        unchanged.
    """
    global processor, model
    print(f"Loading model {model_name}...")

    # Load into locals first so a failure cannot leave the globals half-set.
    loaded_processor = Wav2Vec2Processor.from_pretrained(model_name)
    loaded_model = Wav2Vec2ForCTC.from_pretrained(model_name)
    loaded_model.to(device)

    processor, model = loaded_processor, loaded_model

def transcribe_audio(audio_filepath: str) -> str:
    """
    Transcribe the audio file at ``audio_filepath`` with the Wav2Vec model.

    The model and processor are loaded with the default settings on first use
    if either is still unset. The audio is loaded at 16 kHz, run through the
    model without gradient tracking, greedily decoded (argmax over the last
    logits axis) and returned lower-cased.

    Any exception raised while processing the audio is not propagated; the
    function returns a string starting with "Error during transcription: "
    instead. A failure inside the lazy ``load_model()`` call is raised normally.
    """
    if processor is None or model is None:
        load_model()

    sample_rate = 16000  # rate used for both resampling and the processor call
    try:
        # Load the waveform at the target rate and turn it into model inputs.
        waveform = load_and_resample(audio_filepath, target_sr=sample_rate)
        features = processor(
            waveform,
            sampling_rate=sample_rate,
            return_tensors="pt",
            padding=True,
        )

        # Inputs must live on whatever device the model parameters are on.
        target_device = next(model.parameters()).device
        features = {
            name: tensor.to(target_device) for name, tensor in features.items()
        }

        # Forward pass only; no gradients are needed.
        with torch.no_grad():
            output = model(**features)

        # Greedy decoding: most likely token id at each position.
        token_ids = output.logits.argmax(dim=-1)
        decoded = processor.batch_decode(token_ids)
        return decoded[0].lower()
    except Exception as err:
        return f"Error during transcription: {str(err)}"