# KuyaToto's picture
# Update app.py
# 9c52375 verified
import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
# Load the pretrained processor and CTC model once at import time so that
# every call to transcribe() reuses the same objects instead of reloading them.
# Both are created from the same checkpoint identifier.
model_id = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(model_id)
def transcribe(audio_file, progress=gr.Progress()):
    """Transcribe a recorded audio file to lowercase text.

    Args:
        audio_file: Path of the audio file to transcribe, or ``None`` when
            nothing was recorded.
        progress: Gradio progress tracker. Kept in the signature for the UI
            wiring; it is not used inside this function.

    Returns:
        The lowercased transcription, or a warning message when there is no
        audio to transcribe.
    """
    if audio_file is None:
        return "⚠️ No audio received."

    target_rate = 16000
    waveform, sample_rate = torchaudio.load(audio_file)

    # A recording with zero samples has nothing to transcribe and would only
    # fail further down, so report it the same way as a missing recording.
    if waveform.numel() == 0:
        return "⚠️ No audio received."

    # Downmix to a single channel first so that only one channel has to be
    # resampled afterwards.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # The processor is told the audio is sampled at target_rate, so bring the
    # waveform to that rate before handing it over.
    if sample_rate != target_rate:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=target_rate
        )
        sample_rate = target_rate

    # Drop only the channel axis; a bare squeeze() would also collapse a
    # length-1 time axis into a 0-d array.
    samples = waveform.squeeze(0).numpy()
    input_values = processor(
        samples, sampling_rate=sample_rate, return_tensors="pt"
    ).input_values

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(input_values).logits

    # Pick the highest-scoring token per frame and let the processor turn the
    # id sequence into text.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]
    return transcription.lower()
# Input widget: microphone-only recorder that passes a file path to transcribe().
mic_input = gr.Audio(sources=["microphone"], type="filepath", label="🎤 Speak now")

# Output widget: text box showing the string returned by transcribe().
transcript_box = gr.Textbox(label="📝 Transcription")

# Wire the two widgets to transcribe() in a single-function interface with
# flagging turned off.
demo = gr.Interface(
    fn=transcribe,
    inputs=mic_input,
    outputs=transcript_box,
    title="Wav2Vec2 Speech Transcription",
    description="Speak into the microphone and get a transcription using Wav2Vec2-base.",
    flagging_mode="never",
)

# Start serving the demo.
demo.launch()