# speech_to_text / app.py
# Vietnamese speech-to-text demo (Gradio + vinai/PhoWhisper-small).
import io
import torch
import torchaudio
import numpy as np
import gradio as gr
import soundfile as sf
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# ===== CONFIG =====
MODEL_ID = "vinai/PhoWhisper-small"  # Vietnamese-tuned Whisper checkpoint
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TARGET_SR = 16000 # Whisper expects 16kHz

# ===== LOAD MODEL =====
# Downloads weights on first run; .to(DEVICE) moves them onto GPU if available.
processor = WhisperProcessor.from_pretrained(MODEL_ID)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID).to(DEVICE)
model.eval()  # inference only — disable dropout / training behavior

# prepare forced decoder ids for Vietnamese transcription
try:
    forced_decoder_ids = processor.get_decoder_prompt_ids(language="vi", task="transcribe")
except Exception:
    # Helper may be unavailable depending on the installed transformers
    # version; fall back to the model's default decoding prompt.
    forced_decoder_ids = None
# ===== HELPERS =====
def _read_audio_tuple(audio):
"""
audio: (sr, np.ndarray) coming from gr.Audio(type="numpy")
returns mono float32 numpy array and original sr
"""
if audio is None:
return None, None
sr, data = audio
# ensure numpy
data = np.asarray(data)
# stereo -> mono
if data.ndim > 1:
data = data.mean(axis=1)
# convert to float32 in range [-1, 1] if needed
if data.dtype.kind == "i":
# integer PCM -> normalize
maxv = float(np.iinfo(data.dtype).max)
data = data.astype("float32") / maxv
else:
data = data.astype("float32")
return data, sr
# ===== INFERENCE =====
def s2t(audio):
    """Transcribe a Gradio audio payload to Vietnamese text.

    Parameters
    ----------
    audio : tuple[int, np.ndarray] | None
        (sample_rate, samples) tuple from ``gr.Audio``, or None.

    Returns
    -------
    str
        The stripped transcription, or a message when no audio was given.
    """
    samples, sample_rate = _read_audio_tuple(audio)
    if samples is None:
        return "No audio provided"

    # Whisper's feature extractor requires 16 kHz input.
    if sample_rate != TARGET_SR:
        resampled = torchaudio.functional.resample(
            torch.from_numpy(samples), orig_freq=sample_rate, new_freq=TARGET_SR
        )
        samples = resampled.numpy()

    # Feature extraction -> log-mel input features on the model's device.
    features = processor(samples, sampling_rate=TARGET_SR, return_tensors="pt")
    features = features.input_features.to(DEVICE)

    with torch.no_grad():
        if forced_decoder_ids is None:
            predicted_ids = model.generate(features)
        else:
            # Pin the decoder prompt to Vietnamese transcription.
            predicted_ids = model.generate(features, forced_decoder_ids=forced_decoder_ids)

    text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return text.strip()
# ===== GRADIO APP =====
# NOTE: the original duplicated these strings inline in gr.Interface and
# never used the variables; pass them through so there is one source of
# truth. The title's mojibake ("β€”") is fixed to a real em dash.
title = "Vietnamese Speech-to-Text — PhoWhisper-small"
desc = "Upload or record audio (wav/mp3). Model: vinai/PhoWhisper-small. Resamples to 16 kHz."

app = gr.Interface(
    fn=s2t,
    inputs=gr.Audio(type="numpy", label="Upload or record audio (.wav/.mp3)"),
    outputs=gr.Textbox(label="Transcription"),
    title=title,
    description=desc,
)

if __name__ == "__main__":
    app.launch()