ZidanePMSE committed on
Commit
31266e7
·
verified ·
1 Parent(s): 17d198f

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +94 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import torch
3
+ import torchaudio
4
+ import numpy as np
5
+ import gradio as gr
6
+ import soundfile as sf
7
+
8
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
9
+
10
# ===== CONFIG =====
# Hugging Face model id: a Whisper checkpoint fine-tuned for Vietnamese ASR.
MODEL_ID = "vinai/PhoWhisper-small"
# Prefer GPU when available; all tensors/model weights are moved here.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TARGET_SR = 16000 # Whisper expects 16kHz

# ===== LOAD MODEL =====
# Loaded once at import time so every request reuses the same weights.
processor = WhisperProcessor.from_pretrained(MODEL_ID)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID).to(DEVICE)
# Inference only: disables dropout/batch-norm training behavior.
model.eval()

# prepare forced decoder ids for Vietnamese transcription
# (forces language="vi", task="transcribe" tokens at decode time; some
# processor versions lack this helper, hence the fallback to None)
try:
    forced_decoder_ids = processor.get_decoder_prompt_ids(language="vi", task="transcribe")
except Exception:
    forced_decoder_ids = None
25
+
26
+ # ===== HELPERS =====
27
+ def _read_audio_tuple(audio):
28
+ """
29
+ audio: (sr, np.ndarray) coming from gr.Audio(type="numpy")
30
+ returns mono float32 numpy array and original sr
31
+ """
32
+ if audio is None:
33
+ return None, None
34
+ sr, data = audio
35
+ # ensure numpy
36
+ data = np.asarray(data)
37
+ # stereo -> mono
38
+ if data.ndim > 1:
39
+ data = data.mean(axis=1)
40
+ # convert to float32 in range [-1, 1] if needed
41
+ if data.dtype.kind == "i":
42
+ # integer PCM -> normalize
43
+ maxv = float(np.iinfo(data.dtype).max)
44
+ data = data.astype("float32") / maxv
45
+ else:
46
+ data = data.astype("float32")
47
+ return data, sr
48
+
49
# ===== INFERENCE =====
def s2t(audio):
    """Transcribe a Gradio audio input to Vietnamese text.

    Args:
        audio: ``(sr, numpy array)`` tuple from the Gradio Audio component.

    Returns:
        The decoded transcription string, or a placeholder message when no
        audio was supplied.
    """
    samples, rate = _read_audio_tuple(audio)
    if samples is None:
        return "No audio provided"

    # Whisper models expect 16 kHz input; resample anything else.
    if rate != TARGET_SR:
        resampled = torchaudio.functional.resample(
            torch.from_numpy(samples), orig_freq=rate, new_freq=TARGET_SR
        )
        samples = resampled.numpy()

    # Feature extraction -> log-mel input features on the model's device.
    features = processor(samples, sampling_rate=TARGET_SR, return_tensors="pt")
    features = features.input_features.to(DEVICE)

    # Pass the Vietnamese/transcribe prompt only when it was available.
    generate_kwargs = {}
    if forced_decoder_ids is not None:
        generate_kwargs["forced_decoder_ids"] = forced_decoder_ids

    with torch.no_grad():
        pred_ids = model.generate(features, **generate_kwargs)

    # decode token ids back to text
    text = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]
    return text.strip()
78
+
79
# ===== GRADIO APP =====
title = "Vietnamese Speech-to-Text — PhoWhisper-small"
desc = "Upload or record audio (wav/mp3). Model: vinai/PhoWhisper-small. Resamples to 16 kHz."

app = gr.Interface(
    fn=s2t,
    # BUG FIX: Gradio 4.x renamed the 3.x ``source="upload"`` argument to
    # ``sources=[...]``; with an unpinned ``gradio`` requirement the old
    # keyword raises TypeError at startup. Both upload and microphone are
    # enabled to match the "Upload or record" label.
    inputs=gr.Audio(
        sources=["upload", "microphone"],
        type="numpy",
        label="Upload or record audio (.wav/.mp3)",
    ),
    outputs=gr.Textbox(label="Transcription"),
    title=title,
    description=desc,
    # ``allow_flagging`` was deprecated in Gradio 4 and removed in 5;
    # ``flagging_mode`` is the supported equivalent.
    flagging_mode="never",
    examples=[],
)

if __name__ == "__main__":
    app.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchaudio
3
+ transformers
4
+ sentencepiece
5
+ gradio
6
+ soundfile
7
+ numpy