File size: 1,697 Bytes
5c831db
 
 
 
 
5561e0e
5c831db
 
 
 
 
 
112a3ee
 
5561e0e
5c831db
112a3ee
 
0663839
e6552fa
5561e0e
0663839
112a3ee
 
 
5c831db
112a3ee
 
5c831db
112a3ee
 
5c831db
112a3ee
 
 
 
 
 
 
 
 
 
5c831db
 
 
0663839
5c831db
5561e0e
 
5c831db
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import torchaudio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import gradio as gr

# Checkpoint identifier handed to `from_pretrained` for both the processor and
# the CTC model (a Hub repo id or local path).
MODEL_PATH = "nambn0321/ASR_models"

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the processor and acoustic model once at import time so every request
# reuses them; eval() puts the model in inference mode (e.g. dropout off).
processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_PATH)
model.eval()
model.to(device)

# Sample rate (Hz) audio is resampled to before feature extraction.
# NOTE(review): 16 kHz is assumed to match the checkpoint's feature extractor
# (processor.feature_extractor.sampling_rate) — confirm.
_TARGET_SAMPLE_RATE = 16000


def _to_mono_float(samples):
    """Return *samples* as a 1-D float32 mono tensor scaled to about [-1, 1].

    *samples* is a tensor built from Gradio's numpy audio: shape ``(n,)`` for
    mono or ``(n, channels)`` for multi-channel input (channels on the last
    axis). Signed integer PCM is divided by its full-scale value (32768 for
    int16, matching the previous behavior; 128 for int8, 2**31 for int32),
    unsigned 8-bit PCM is re-centred around zero, and floating-point input is
    assumed to already be in [-1, 1] and is only cast to float32.

    Raises:
        ValueError: if *samples* is not 1-D or 2-D.
    """
    if samples.dim() not in (1, 2):
        raise ValueError(f"Unsupported audio shape: {tuple(samples.shape)}")

    if samples.is_floating_point():
        audio = samples.to(torch.float32)
    elif samples.dtype == torch.uint8:
        # Unsigned 8-bit PCM is centred on 128 rather than 0.
        audio = (samples.to(torch.float32) - 128.0) / 128.0
    else:
        # Signed PCM: full scale is 2**(bits - 1), i.e. -iinfo.min.
        full_scale = float(-torch.iinfo(samples.dtype).min)
        audio = samples.to(torch.float32) / full_scale

    if audio.dim() == 2:
        # Average the channel axis (last) into a single mono track.
        audio = audio.mean(dim=1)
    return audio


def transcribe(audio):
    """Transcribe one Gradio audio input with the wav2vec2 CTC model.

    Args:
        audio: ``None`` when nothing was provided, otherwise the
            ``(sample_rate, samples)`` tuple produced by
            ``gr.Audio(type="numpy")``.

    Returns:
        The decoded transcription with surrounding whitespace stripped,
        ``"No audio provided"`` for missing or empty input, or
        ``"Error: <message>"`` if any step raises (so the UI shows the
        problem instead of failing).
    """
    try:
        if audio is None:
            return "No audio provided"

        sr, data = audio
        raw = torch.tensor(data)
        print(f"Sample rate: {sr}, Audio shape: {tuple(raw.shape)}")

        waveform = _to_mono_float(raw)
        if waveform.numel() == 0:
            # Nothing to resample or recognise; avoids an opaque failure
            # deeper in the resampler / feature extractor.
            return "No audio provided"

        if sr != _TARGET_SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sr, new_freq=_TARGET_SAMPLE_RATE
            )
            waveform = resampler(waveform)

        inputs = processor(
            waveform.numpy(),
            sampling_rate=_TARGET_SAMPLE_RATE,
            return_tensors="pt",
            padding=True,
        )
        input_values = inputs.input_values.to(device)

        with torch.no_grad():
            logits = model(input_values).logits
        # Greedy decoding: take the most likely token id for every frame.
        predicted_ids = torch.argmax(logits, dim=-1)

        transcription = processor.batch_decode(predicted_ids)[0]
        return transcription.strip()

    except Exception as e:
        # Top-level UI boundary: log the full traceback for debugging, but
        # return a message so Gradio displays it in the output textbox.
        import traceback

        traceback.print_exc()
        print("Error during transcription:", str(e))
        return f"Error: {str(e)}"

# Single-input / single-output web UI around `transcribe`. The interface is
# bound to `demo` before launch() starts serving it.
# NOTE(review): the title text looks like leftover placeholder — confirm.
demo = gr.Interface(
    title=" ASR Demo oMGMGGOMGOMGOGMOG",
    description="Upload an audio file (WAV or MP3) and get the transcription.",
    fn=transcribe,
    inputs=gr.Audio(type="numpy", label="Upload WAV/MP3 file"),
    outputs=gr.Textbox(label="Transcription"),
)
demo.launch()