# Gradio ASR demo Space: Wav2Vec2 CTC speech-to-text.
# --- Model setup ----------------------------------------------------------
import torch
import torchaudio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import gradio as gr

# Hugging Face Hub repo holding the fine-tuned Wav2Vec2 checkpoint.
MODEL_PATH = "nambn0321/ASR_models"

# Load once at startup; .eval() disables dropout for inference.
processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_PATH).eval()

# Run on GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def transcribe(audio):
    """Transcribe a Gradio audio input to text with the Wav2Vec2 CTC model.

    Args:
        audio: Gradio "numpy" audio payload — a ``(sample_rate, np.ndarray)``
            tuple, or ``None`` when nothing was recorded/uploaded.
            The array is 1-D for mono, (samples, channels) for stereo.

    Returns:
        The decoded transcription string, or a human-readable error message.
    """
    try:
        if audio is None:
            return "No audio provided"
        sr, data = audio
        print(f"Sample rate: {sr}, Audio shape: {len(data)}")

        waveform = torch.tensor(data, dtype=torch.float32)
        # Gradio delivers int16 PCM samples; scale into [-1, 1].
        # NOTE(review): assumes int16 input — confirm the Audio component
        # is not configured to emit float arrays.
        waveform = waveform / 32768.0

        # BUG FIX: the original unsqueezed to (1, n[, c]) first, so its
        # "shape[0] > 1" stereo check could never fire and stereo audio was
        # passed through as a 3-D tensor. Downmix channels BEFORE adding the
        # batch dimension and before resampling.
        if waveform.ndim > 1:
            # (samples, channels) -> (channels, samples) -> mono (1, samples)
            waveform = waveform.t().mean(dim=0, keepdim=True)
        else:
            waveform = waveform.unsqueeze(0)

        # Wav2Vec2 was trained on 16 kHz audio; resample anything else.
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
            waveform = resampler(waveform)

        inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000,
                           return_tensors="pt", padding=True)
        input_values = inputs.input_values.to(device)
        with torch.no_grad():
            logits = model(input_values).logits
        # Greedy CTC decode: most likely token at each frame.
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
        return transcription.strip()
    except Exception as e:
        # Surface the error in the UI instead of crashing the Space.
        print("Error during transcription:", str(e))
        return f"Error: {str(e)}"
# Build and launch the web UI. type="numpy" makes Gradio hand transcribe()
# a (sample_rate, np.ndarray) tuple.
gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type="numpy", label="Upload WAV/MP3 file"),
    outputs=gr.Textbox(label="Transcription"),
    # Fixed garbled placeholder title from the original.
    title="ASR Demo",
    description="Upload an audio file (WAV or MP3) and get the transcription.",
).launch()