dmatekenya committed on
Commit
256ade0
·
1 Parent(s): ad594d3

Added app code

Browse files
Files changed (2) hide show
  1. app.py +244 -0
  2. requirements.txt +84 -0
app.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from typing import Optional, Tuple
4
+
5
+ import gradio as gr
6
+ import librosa
7
+ import numpy as np
8
+ import torch
9
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor
10
+ from openai import OpenAI
11
+
12
# -----------------------------
# Models / Config
# -----------------------------
# Base (off-the-shelf) checkpoint used for the "Open Source" column.
BASE_REPO = "openai/whisper-large-v3"

# Fine-tuned checkpoint used for the middle column, pinned to an exact commit
# so a later push to the repo cannot silently change demo behavior.
FINETUNED_REPO = "dmatekenya/whisper-large-v3-chichewa"
FINETUNED_REVISION = "bff60fb08ba9f294e05bfcab4306f30b6a0cfc0a" # pinned commit hash

# Local WhisperProcessor language hint (keep consistent with how you evaluated)
# NOTE(review): "shona" in a Chichewa app looks surprising — presumably it
# matches the language token used during fine-tuning/evaluation; confirm
# before changing.
LOCAL_LANGUAGE = "shona"

# OpenAI language hint (ISO-639-1). Chichewa/Nyanja often "ny". Set to None to auto-detect.
OPENAI_LANGUAGE = "ny"

# Audio constraints
TARGET_SR = 16000   # sample rate the audio is resampled to before inference
MAX_SECONDS = 30.0  # clips longer than this are rejected in transcribe_all

# OpenAI transcription model (commercial)
OPENAI_MODEL = "whisper-1" # simple + stable
32
+
33
# -----------------------------
# UI Text / Styling
# -----------------------------
# Static HTML fragments rendered through gr.Markdown in the Blocks layout
# below. These are user-visible strings; edit text with care.
LOGO_HTML = """
<div style="text-align:center; margin-bottom: 25px;">
<img src="https://i.ibb.co/5nQdGSs/logo.png"
style="max-width: 100%; height: auto; border-radius: 12px;">
</div>
"""

TITLE_HTML = """
<h1 style="text-align:center; font-size:34px; margin-bottom:10px;">
Chichewa Speech2Text: How Custom Data Improves Performance
</h1>
"""

HIGHLIGHT_TEXT = """
<p style="text-align:center; font-size:20px; font-weight:600; color:#1F3A5F; margin-bottom:20px;">
Observe how the fine-tuned model provides better transcription quality.
</p>
"""

DESCRIPTION_HTML = """
<p style="text-align:center; font-size:18px; margin-bottom: 18px;">
Upload or record a short Chichewa voice note (≤30 seconds). The same audio will be transcribed by three systems.
</p>
"""

# Footer with project write-up and the voice-note donation sign-up form.
ARTICLE_HTML = """
<p style="text-align:center; margin-top: 10px;">
Read more about the <a href="https://dmatekenya.github.io/Chichewa-Speech2Text/README.html" target="_blank">ChichewaSpeech2Text</a> project
and sign up for our voice note donation event:
<a href="https://forms.gle/fHLESutofVvb2YFM9" target="_blank">Google Form</a>.
</p>
"""
68
+
69
# -----------------------------
# Load local models once
# -----------------------------
# Module-level so both Whisper checkpoints are loaded a single time at app
# startup and then shared by every Gradio request.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

print(f"Using device: {DEVICE}", flush=True)

# One shared processor (feature extractor + tokenizer) for both local models;
# the language/task hints set the default decoder prompt for generation.
PROCESSOR = WhisperProcessor.from_pretrained(
BASE_REPO,
language=LOCAL_LANGUAGE,
task="transcribe",
)

MODEL_BASE = WhisperForConditionalGeneration.from_pretrained(BASE_REPO).to(DEVICE).eval()
MODEL_FT = WhisperForConditionalGeneration.from_pretrained(
FINETUNED_REPO,
revision=FINETUNED_REVISION,  # pin to a known-good commit
).to(DEVICE).eval()

# Cast weights to half precision on GPU (matches the fp16 features built in
# transcribe_local).
if DEVICE == "cuda":
MODEL_BASE = MODEL_BASE.to(dtype=DTYPE)
MODEL_FT = MODEL_FT.to(dtype=DTYPE)

# Reads OPENAI_API_KEY from the environment; transcribe_openai checks for the
# key itself before any request is made.
OPENAI_CLIENT = OpenAI()
94
+
95
+
96
# -----------------------------
# Helpers
# -----------------------------
def load_audio(audio_path: str) -> Tuple[np.ndarray, int, float]:
    """Decode *audio_path* as 16 kHz mono and return (samples, sample_rate, seconds)."""
    samples, rate = librosa.load(audio_path, sr=TARGET_SR, mono=True)
    if rate:
        seconds = len(samples) / float(rate)
    else:
        # Defensive: avoid division by zero if the decoder reports no rate.
        seconds = 0.0
    return samples, rate, seconds
103
+
104
+
105
@torch.inference_mode()
def transcribe_local(model: WhisperForConditionalGeneration, audio_16k: np.ndarray) -> str:
    """Run one local Whisper checkpoint on a 16 kHz mono waveform and return the text."""
    features = PROCESSOR(audio_16k, sampling_rate=TARGET_SR, return_tensors="pt").input_features
    features = features.to(DEVICE)
    if DEVICE == "cuda":
        # Match the half-precision weights set up at load time.
        features = features.to(dtype=DTYPE)

    token_ids = model.generate(inputs=features)
    decoded = PROCESSOR.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0].strip()
115
+
116
+
117
def transcribe_openai(audio_path: str) -> str:
    """Transcribe *audio_path* with OpenAI's hosted Whisper.

    Returns the transcript text, or a human-readable notice when the
    OPENAI_API_KEY secret is not configured. API/network errors propagate
    to the caller (transcribe_all reports them per-engine).
    """
    if not os.getenv("OPENAI_API_KEY"):
        return "OpenAI ASR disabled: OPENAI_API_KEY not set in Space Secrets."

    # Bias decoding toward Chichewa and Malawian place-name spellings.
    prompt = "Chichewa transcription. Malawi names like Lilongwe, Blantyre, Zomba. Keep local names as spoken."

    kwargs = {
        "model": OPENAI_MODEL,
        "prompt": prompt,
        "temperature": 0.0,
        "response_format": "json",
    }
    # Fix: omit the language hint entirely when unset so the API auto-detects.
    # The original passed language=None explicitly, which is not a valid
    # ISO-639-1 hint for the transcriptions endpoint; the parameter must be
    # left out to get auto-detection.
    if OPENAI_LANGUAGE:
        kwargs["language"] = OPENAI_LANGUAGE

    with open(audio_path, "rb") as f:
        resp = OPENAI_CLIENT.audio.transcriptions.create(file=f, **kwargs)

    # Be defensive: treat a missing/None text field as an empty transcript.
    return (resp.text or "").strip()
134
+
135
+
136
def transcribe_all(audio_path: Optional[str]) -> Tuple[str, str, str, str]:
    """Transcribe one clip with all three engines.

    Returns:
        status, base_text, finetuned_text, openai_text
    """
    if not audio_path:
        return "Please record or upload an audio clip.", "", "", ""

    # Decode the file a single time; both local models share the waveform.
    try:
        waveform, _, seconds = load_audio(audio_path)
    except Exception as e:
        return f"❌ Failed to load audio: {e}", "", "", ""

    if seconds > MAX_SECONDS:
        return f"⚠️ Audio is {seconds:.1f}s. Please keep clips ≤ {MAX_SECONDS:.0f}s.", "", "", ""

    status_lines = []

    def run(ok_label, fail_label, fn):
        # Time one engine; log success/failure in the status box and surface
        # the error text in that engine's transcript slot on failure.
        started = time.time()
        try:
            text = fn()
            status_lines.append(f"✅ {ok_label} {time.time()-started:.2f}s")
        except Exception as e:
            text = f"[ERROR] {fail_label}: {e}"
            status_lines.append(f"❌ {fail_label}")
        return text

    base_text = run("Open Source (base)", "Base failed",
                    lambda: transcribe_local(MODEL_BASE, waveform))
    ft_text = run("Fine-tuned", "Fine-tuned failed",
                  lambda: transcribe_local(MODEL_FT, waveform))
    openai_text = run(f"OpenAI ({OPENAI_MODEL})", "OpenAI failed",
                      lambda: transcribe_openai(audio_path))

    return "\n".join(status_lines), base_text, ft_text, openai_text
183
+
184
+
185
# -----------------------------
# Warm-up (local models only)
# -----------------------------
def warmup():
    """Prime both local models with one second of silence so the first real
    request does not pay first-inference initialization costs. Best-effort:
    any failure is logged and ignored so startup never blocks on it."""
    try:
        silence = np.zeros(int(TARGET_SR * 1.0), dtype=np.float32)
        for model in (MODEL_BASE, MODEL_FT):
            transcribe_local(model, silence)
        print("Warm-up complete.", flush=True)
    except Exception as e:
        print(f"Warm-up skipped/failed: {e}", flush=True)


warmup()
199
+
200
+
201
# -----------------------------
# UI
# -----------------------------
# Statement order defines render order: logo, title, blurb, audio input,
# button, status box, then the three transcript panes side by side.
with gr.Blocks(theme="grass", title="Chichewa Speech2Text") as demo:
    gr.Markdown(LOGO_HTML)
    gr.Markdown(TITLE_HTML)
    gr.Markdown(HIGHLIGHT_TEXT)
    gr.Markdown(DESCRIPTION_HTML)

    # type="filepath" hands transcribe_all a path on disk, which the OpenAI
    # branch re-opens directly as a file.
    audio_in = gr.Audio(
        sources=["microphone", "upload"],
        type="filepath",
        label="Audio Input (Record or Upload)",
    )

    run_btn = gr.Button("Transcribe & Compare", variant="primary")

    # Per-engine success/failure lines and timings from transcribe_all.
    status_out = gr.Textbox(label="Status / timing", lines=3)

    with gr.Row():
        base_out = gr.Textbox(
            label="Open Source ASR Model",
            lines=12,
        )
        ft_out = gr.Textbox(
            label="Open Source Model Fine-Tuned with Custom Chichewa Speech",
            lines=12,
        )
        commercial_out = gr.Textbox(
            label="Frontier Commercial ASR Model (OpenAI)",
            lines=12,
        )

    # One click fans the same clip out to all three engines.
    run_btn.click(
        fn=transcribe_all,
        inputs=[audio_in],
        outputs=[status_out, base_out, ft_out, commercial_out],
    )

    gr.Markdown(ARTICLE_HTML)
241
+
242
if __name__ == "__main__":
    # Queue helps when multiple people test at once during your seminar
    # (caps concurrent jobs at 2; extra requests wait in line).
    demo.queue(default_concurrency_limit=2).launch()
requirements.txt ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==24.1.0
2
+ annotated-doc==0.0.4
3
+ annotated-types==0.7.0
4
+ anyio==4.12.1
5
+ audioread==3.1.0
6
+ brotli==1.2.0
7
+ certifi==2026.1.4
8
+ cffi==2.0.0
9
+ charset-normalizer==3.4.4
10
+ click==8.3.1
11
+ decorator==5.2.1
12
+ distro==1.9.0
13
+ fastapi==0.129.2
14
+ ffmpy==1.0.0
15
+ filelock==3.24.3
16
+ fsspec==2026.2.0
17
+ gradio==6.6.0
18
+ gradio-client==2.1.0
19
+ groovy==0.1.2
20
+ groq==1.0.0
21
+ h11==0.16.0
22
+ hf-xet==1.2.0
23
+ httpcore==1.0.9
24
+ httpx==0.28.1
25
+ huggingface-hub==1.4.1
26
+ idna==3.11
27
+ jinja2==3.1.6
28
+ jiter==0.13.0
29
+ joblib==1.5.3
30
+ lazy-loader==0.4
31
+ librosa==0.11.0
32
+ llvmlite==0.46.0
33
+ markdown-it-py==4.0.0
34
+ markupsafe==3.0.3
35
+ mdurl==0.1.2
36
+ mpmath==1.3.0
37
+ msgpack==1.1.2
38
+ networkx==3.6.1
39
+ numba==0.64.0
40
+ numpy==2.4.2
41
+ openai==2.21.0
42
+ orjson==3.11.7
43
+ packaging==26.0
44
+ pandas==3.0.1
45
+ pillow==12.1.1
46
+ platformdirs==4.9.2
47
+ pooch==1.9.0
48
+ pycparser==3.0
49
+ pydantic==2.12.5
50
+ pydantic-core==2.41.5
51
+ pydub==0.25.1
52
+ pygments==2.19.2
53
+ python-dateutil==2.9.0.post0
54
+ python-multipart==0.0.22
55
+ pytz==2025.2
56
+ pyyaml==6.0.3
57
+ regex==2026.2.19
58
+ requests==2.32.5
59
+ rich==14.3.3
60
+ safehttpx==0.1.7
61
+ safetensors==0.7.0
62
+ scikit-learn==1.8.0
63
+ scipy==1.17.0
64
+ semantic-version==2.10.0
65
+ setuptools==82.0.0
66
+ shellingham==1.5.4
67
+ six==1.17.0
68
+ sniffio==1.3.1
69
+ soundfile==0.13.1
70
+ soxr==1.0.0
71
+ starlette==0.52.1
72
+ sympy==1.14.0
73
+ threadpoolctl==3.6.0
74
+ tokenizers==0.22.2
75
+ tomlkit==0.13.3
76
+ torch==2.10.0
77
+ tqdm==4.67.3
78
+ transformers==5.2.0
79
+ typer==0.24.1
80
+ typer-slim==0.24.0
81
+ typing-extensions==4.15.0
82
+ typing-inspection==0.4.2
83
+ urllib3==2.6.3
84
+ uvicorn==0.41.0