import io  # NOTE(review): unused here — safe to drop if nothing else relies on it
import torch
import torchaudio
import numpy as np
import gradio as gr
import soundfile as sf  # NOTE(review): unused here — safe to drop if nothing else relies on it
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# ===== CONFIG =====
MODEL_ID = "vinai/PhoWhisper-small"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TARGET_SR = 16000  # Whisper expects 16kHz

# ===== LOAD MODEL =====
processor = WhisperProcessor.from_pretrained(MODEL_ID)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID).to(DEVICE)
model.eval()

# Prepare forced decoder ids so generation is pinned to Vietnamese
# transcription; fall back to the model defaults if the API is unavailable.
try:
    forced_decoder_ids = processor.get_decoder_prompt_ids(language="vi", task="transcribe")
except Exception:
    forced_decoder_ids = None


# ===== HELPERS =====
def _read_audio_tuple(audio):
    """Normalize a gradio audio payload to mono float32.

    Parameters
    ----------
    audio : tuple | None
        ``(sample_rate, np.ndarray)`` as produced by ``gr.Audio(type="numpy")``,
        or ``None`` when nothing was uploaded/recorded.

    Returns
    -------
    tuple
        ``(data, sr)`` where ``data`` is a mono float32 ndarray scaled to
        roughly [-1, 1] and ``sr`` is the original sample rate, or
        ``(None, None)`` when no audio was provided.
    """
    if audio is None:
        return None, None
    sr, data = audio
    data = np.asarray(data)
    # Multi-channel -> mono by averaging channels.
    # NOTE(review): assumes (samples, channels) layout — gradio's default.
    if data.ndim > 1:
        data = data.mean(axis=1)
    kind = data.dtype.kind
    if kind == "i":
        # Signed integer PCM -> scale by the positive full-scale value.
        maxv = float(np.iinfo(data.dtype).max)
        data = data.astype("float32") / maxv
    elif kind == "u":
        # Unsigned PCM (e.g. uint8 WAV) is offset-binary: re-center around
        # zero before scaling. The original code left these samples in
        # [0, max], which feeds a heavily biased signal to the model.
        half = (float(np.iinfo(data.dtype).max) + 1.0) / 2.0
        data = (data.astype("float32") - half) / half
    else:
        data = data.astype("float32")
    return data, sr


# ===== INFERENCE =====
def s2t(audio):
    """Transcribe Vietnamese speech to text with PhoWhisper.

    Parameters
    ----------
    audio : tuple | None
        ``(sample_rate, np.ndarray)`` from gradio's Audio component.

    Returns
    -------
    str
        The transcription, or ``"No audio provided"`` when input is missing.
    """
    data, sr = _read_audio_tuple(audio)
    if data is None:
        return "No audio provided"
    # Whisper is trained on 16 kHz audio; resample anything else.
    if sr != TARGET_SR:
        waveform = torch.from_numpy(data)
        waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=TARGET_SR)
        data = waveform.numpy()
    # Feature extraction -> log-mel input features on the model device.
    inputs = processor(data, sampling_rate=TARGET_SR, return_tensors="pt")
    input_features = inputs.input_features.to(DEVICE)
    with torch.no_grad():
        if forced_decoder_ids is not None:
            pred_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
        else:
            pred_ids = model.generate(input_features)
    transcription = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]
    return transcription.strip()


# ===== GRADIO APP =====
title = "Vietnamese Speech-to-Text — PhoWhisper-small"
desc = "Upload or record audio (wav/mp3). Model: vinai/PhoWhisper-small. Resamples to 16 kHz."

app = gr.Interface(
    fn=s2t,
    inputs=gr.Audio(type="numpy", label="Upload or record audio (.wav/.mp3)"),
    outputs=gr.Textbox(label="Transcription"),
    # Reuse the module-level strings: the original defined `title`/`desc`
    # but duplicated (and partially diverged from) them in literals here.
    title=title,
    description=desc,
)

if __name__ == "__main__":
    app.launch()