sohamchitimali committed on
Commit
b997d93
·
1 Parent(s): 55bd939
Files changed (2) hide show
  1. app.py +19 -209
  2. requirements.txt +6 -18
app.py CHANGED
@@ -1,8 +1,5 @@
1
  import os
2
  import tempfile
3
- import glob
4
- import shutil
5
- import subprocess
6
  import torch
7
  import whisper
8
  import gradio as gr
@@ -10,34 +7,12 @@ from fastapi import FastAPI, File, Form, UploadFile, HTTPException
10
  from fastapi.middleware.cors import CORSMiddleware
11
  import uvicorn
12
 
13
- # -----------------------
14
- # Configuration / tuning
15
- # -----------------------
16
-
17
- # Use all CPU cores for PyTorch (must be set before loading the model)
18
- NUM_CPU = os.cpu_count() or 1
19
- torch.set_num_threads(NUM_CPU)
20
- torch.set_num_interop_threads(NUM_CPU)
21
-
22
  MODEL_NAME = os.getenv("WHISPER_MODEL", "base")
23
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
24
- # On CPU, always use fp16=False
25
- FP16 = (DEVICE == "cuda")
26
-
27
- # ffmpeg presence check (used to normalize & chunk audio)
28
- FFMPEG = shutil.which("ffmpeg")
29
-
30
- # chunk duration (seconds) - smaller chunks help with long audio on CPU
31
- CHUNK_SECONDS = int(os.getenv("WHUNK_CHUNK_SECONDS", "30"))
32
-
33
- # -----------------------
34
- # Load model (after threads set)
35
- # -----------------------
36
  MODEL = whisper.load_model(MODEL_NAME, device=DEVICE)
37
 
38
- # -----------------------
39
- # FastAPI app
40
- # -----------------------
41
  app = FastAPI(title="Whisper API")
42
 
43
  app.add_middleware(
@@ -48,192 +23,38 @@ app.add_middleware(
48
  allow_headers=["*"],
49
  )
50
 
51
- # -----------------------
52
- # Utilities
53
- # -----------------------
54
  def _save_temp(upload: UploadFile) -> str:
55
- """Save UploadFile to a temp file and return path."""
56
  suffix = os.path.splitext(upload.filename or "audio")[1] or ".wav"
57
  with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
58
  tmp.write(upload.file.read())
59
  return tmp.name
60
 
61
- def _ensure_wav_mono_16k(src_path: str) -> str:
62
- """
63
- Use ffmpeg to convert src_path to mono 16k WAV.
64
- If ffmpeg not present, return src_path (best-effort).
65
- Returns path to standardized wav (temp file) that caller must remove.
66
- """
67
- if not FFMPEG:
68
- # ffmpeg not available — rely on caller-provided file (best-effort)
69
- return src_path
70
-
71
- out_tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
72
- out_path = out_tmp.name
73
- out_tmp.close()
74
-
75
- cmd = [
76
- FFMPEG,
77
- "-y",
78
- "-i", src_path,
79
- "-ar", "16000", # sample rate 16 kHz
80
- "-ac", "1", # mono
81
- "-sample_fmt", "s16",# PCM16
82
- out_path
83
- ]
84
- try:
85
- subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
86
- return out_path
87
- except Exception:
88
- # conversion failed — cleanup and fallback
89
- try:
90
- os.remove(out_path)
91
- except Exception:
92
- pass
93
- return src_path
94
-
95
- def _split_into_chunks(wav_path: str, chunk_seconds: int = CHUNK_SECONDS) -> list:
96
- """
97
- Split a WAV into chunk files using ffmpeg segmenter.
98
- Returns list of chunk file paths (sorted).
99
- If ffmpeg missing or splitting fails, returns [wav_path].
100
- Caller must remove chunk files after use.
101
- """
102
- if not FFMPEG:
103
- return [wav_path]
104
-
105
- tmpdir = tempfile.mkdtemp(prefix="whisper_chunks_")
106
- # segment into re-encoded WAVs to guarantee compatibility
107
- out_pattern = os.path.join(tmpdir, "chunk_%03d.wav")
108
- cmd = [
109
- FFMPEG,
110
- "-y",
111
- "-i", wav_path,
112
- "-ar", "16000",
113
- "-ac", "1",
114
- "-f", "segment",
115
- "-segment_time", str(chunk_seconds),
116
- "-reset_timestamps", "1",
117
- out_pattern
118
- ]
119
- try:
120
- subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
121
- # collect chunk files
122
- chunks = sorted(glob.glob(os.path.join(tmpdir, "chunk_*.wav")))
123
- return chunks or [wav_path]
124
- except Exception:
125
- # on failure, cleanup and fallback
126
- shutil.rmtree(tmpdir, ignore_errors=True)
127
- return [wav_path]
128
-
129
- def _cleanup_paths(paths: list):
130
- for p in paths:
131
- try:
132
- if os.path.isdir(p):
133
- shutil.rmtree(p, ignore_errors=True)
134
- else:
135
- os.remove(p)
136
- except Exception:
137
- pass
138
-
139
- # -----------------------
140
- # Transcription core
141
- # -----------------------
142
- def transcribe_file(path: str, task: str = "transcribe") -> dict:
143
- """
144
- Transcribe (or translate) the provided file path.
145
- This:
146
- - normalizes audio to mono-16k WAV (via ffmpeg if available),
147
- - splits into CHUNK_SECONDS segments (if ffmpeg present),
148
- - transcribes segments sequentially with whisper and concatenates text.
149
- Returns a dict: {"text": ..., "language": ..., "duration": ...}
150
- """
151
- temp_to_cleanup = []
152
- try:
153
- # Ensure WAV 16k mono
154
- std_wav = _ensure_wav_mono_16k(path)
155
- if std_wav != path:
156
- temp_to_cleanup.append(std_wav)
157
-
158
- # Split into chunks
159
- chunks = _split_into_chunks(std_wav, CHUNK_SECONDS)
160
- # if chunking created files in a tempdir, ensure that dir removed later
161
- if len(chunks) > 1:
162
- temp_to_cleanup.extend(chunks)
163
- # note: we added chunk files individually; _split_into_chunks will have created a tmpdir.
164
- # we'll remove the chunk files and directory in cleanup below.
165
-
166
- full_text_parts = []
167
- language_detected = None
168
- duration_total = 0.0
169
-
170
- for idx, cpath in enumerate(chunks):
171
- # call model.transcribe on each chunk
172
- # We use same task (transcribe/translate) and FP16 flag accordingly.
173
- try:
174
- result = MODEL.transcribe(cpath, task=task, language=None, fp16=FP16)
175
- except Exception as e:
176
- # If a chunk fails, try once with fp16=False (safe fallback)
177
- try:
178
- result = MODEL.transcribe(cpath, task=task, language=None, fp16=False)
179
- except Exception as e2:
180
- # give up on this chunk but continue
181
- result = {"text": "", "language": None, "duration": 0.0}
182
-
183
- text = (result.get("text") or "").strip()
184
- if text:
185
- full_text_parts.append(text)
186
-
187
- # populate top-level language/duration from the last successful chunk if available
188
- if not language_detected and result.get("language"):
189
- language_detected = result.get("language")
190
- try:
191
- duration_total += float(result.get("duration") or 0.0)
192
- except Exception:
193
- pass
194
-
195
- # join with sensible spacing
196
- full_text = " ".join([p for p in full_text_parts if p])
197
-
198
- return {
199
- "text": full_text.strip(),
200
- "language": language_detected or "",
201
- "duration": duration_total
202
- }
203
- finally:
204
- # cleanup any temp files and chunk dirs
205
- _cleanup_paths(list(set(temp_to_cleanup)))
206
-
207
- # -----------------------
208
- # FastAPI endpoints
209
- # -----------------------
210
  @app.post("/api/transcribe")
211
  async def transcribe(audio: UploadFile = File(...)):
212
  if not audio:
213
  raise HTTPException(status_code=400, detail="No audio provided")
214
  path = _save_temp(audio)
215
  try:
216
- result = transcribe_file(path, task="transcribe")
217
  return {
218
  "text": result.get("text", "").strip(),
219
  "language": result.get("language", ""),
220
  "duration": float(result.get("duration") or 0.0)
221
  }
222
  finally:
223
- try:
224
- os.remove(path)
225
- except Exception:
226
- pass
227
 
228
  @app.post("/api/translate")
229
  async def translate(audio: UploadFile = File(...), target_language: str = Form(...)):
230
  if not audio:
231
  raise HTTPException(status_code=400, detail="No audio provided")
232
- if target_language.strip().lower() not in {"en", "eng", "english"}:
233
  raise HTTPException(status_code=400, detail="Whisper only translates to English")
234
  path = _save_temp(audio)
235
  try:
236
- result = transcribe_file(path, task="translate")
237
  return {
238
  "text": result.get("text", "").strip(),
239
  "source_language": result.get("language", ""),
@@ -241,46 +62,35 @@ async def translate(audio: UploadFile = File(...), target_language: str = Form(.
241
  "duration": float(result.get("duration") or 0.0)
242
  }
243
  finally:
244
- try:
245
- os.remove(path)
246
- except Exception:
247
- pass
248
 
249
  @app.get("/")
250
  async def root():
251
  return {"message": "Whisper API is running. Use /api/transcribe or /api/translate."}
252
 
253
- # -----------------------
254
- # Gradio UI
255
- # -----------------------
256
  def gradio_ui():
257
  with gr.Blocks() as demo:
258
  gr.Markdown("## 🎙️ Whisper API Demo")
259
  with gr.Row():
260
- audio_input = gr.Audio(label="Upload audio or record", type="filepath")
261
- translate_checkbox = gr.Checkbox(label="Translate to English", value=False)
262
- output = gr.Textbox(label="Transcription / Translation", lines=6)
263
  btn = gr.Button("Transcribe")
264
 
265
- def transcribe_gr(audio_path, do_translate):
266
- if audio_path is None or audio_path == "":
 
267
  return "No audio provided."
268
- task = "translate" if do_translate else "transcribe"
269
- # use the same internal function used by the API endpoints
270
- result = transcribe_file(audio_path, task=task)
271
  return result.get("text", "").strip()
272
 
273
- btn.click(fn=transcribe_gr, inputs=[audio_input, translate_checkbox], outputs=output)
274
  return demo
275
 
276
- # -----------------------
277
- # Mount Gradio inside FastAPI
278
- # -----------------------
279
  demo = gradio_ui()
280
  gr.mount_gradio_app(app, demo, path="/")
281
 
282
- # -----------------------
283
- # Run server (local)
284
- # -----------------------
285
  if __name__ == "__main__":
286
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import os
2
  import tempfile
 
 
 
3
  import torch
4
  import whisper
5
  import gradio as gr
 
7
  from fastapi.middleware.cors import CORSMiddleware
8
  import uvicorn
9
 
10
+ # 🔹 Load Whisper model
 
 
 
 
 
 
 
 
11
  MODEL_NAME = os.getenv("WHISPER_MODEL", "base")
12
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
13
  MODEL = whisper.load_model(MODEL_NAME, device=DEVICE)
14
 
15
+ # 🔹 FastAPI app
 
 
16
  app = FastAPI(title="Whisper API")
17
 
18
  app.add_middleware(
 
23
  allow_headers=["*"],
24
  )
25
 
26
+ # 🔹 Utility to save uploaded files temporarily
 
 
27
  def _save_temp(upload: UploadFile) -> str:
 
28
  suffix = os.path.splitext(upload.filename or "audio")[1] or ".wav"
29
  with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
30
  tmp.write(upload.file.read())
31
  return tmp.name
32
 
33
+ # 🔹 API endpoints
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  @app.post("/api/transcribe")
35
  async def transcribe(audio: UploadFile = File(...)):
36
  if not audio:
37
  raise HTTPException(status_code=400, detail="No audio provided")
38
  path = _save_temp(audio)
39
  try:
40
+ result = MODEL.transcribe(path, task="transcribe", language=None, fp16=(DEVICE=="cuda"))
41
  return {
42
  "text": result.get("text", "").strip(),
43
  "language": result.get("language", ""),
44
  "duration": float(result.get("duration") or 0.0)
45
  }
46
  finally:
47
+ os.remove(path)
 
 
 
48
 
49
  @app.post("/api/translate")
50
  async def translate(audio: UploadFile = File(...), target_language: str = Form(...)):
51
  if not audio:
52
  raise HTTPException(status_code=400, detail="No audio provided")
53
+ if target_language.strip().lower() not in {"en","eng","english"}:
54
  raise HTTPException(status_code=400, detail="Whisper only translates to English")
55
  path = _save_temp(audio)
56
  try:
57
+ result = MODEL.transcribe(path, task="translate", language=None, fp16=(DEVICE=="cuda"))
58
  return {
59
  "text": result.get("text", "").strip(),
60
  "source_language": result.get("language", ""),
 
62
  "duration": float(result.get("duration") or 0.0)
63
  }
64
  finally:
65
+ os.remove(path)
 
 
 
66
 
67
  @app.get("/")
68
  async def root():
69
  return {"message": "Whisper API is running. Use /api/transcribe or /api/translate."}
70
 
71
+ # 🔹 Gradio UI
 
 
72
  def gradio_ui():
73
  with gr.Blocks() as demo:
74
  gr.Markdown("## 🎙️ Whisper API Demo")
75
  with gr.Row():
76
+ audio_input = gr.Audio(label="Upload audio", type="filepath") # fixed: no 'source'
77
+ output = gr.Textbox(label="Transcription")
 
78
  btn = gr.Button("Transcribe")
79
 
80
+ # Directly call Whisper model, no internal HTTP request
81
+ def transcribe_gr(audio_path):
82
+ if audio_path is None:
83
  return "No audio provided."
84
+ result = MODEL.transcribe(audio_path, task="transcribe", language=None, fp16=(DEVICE=="cuda"))
 
 
85
  return result.get("text", "").strip()
86
 
87
+ btn.click(fn=transcribe_gr, inputs=audio_input, outputs=output)
88
  return demo
89
 
90
+ # 🔹 Mount Gradio inside FastAPI
 
 
91
  demo = gradio_ui()
92
  gr.mount_gradio_app(app, demo, path="/")
93
 
94
+ # 🔹 Run server locally
 
 
95
  if __name__ == "__main__":
96
  uvicorn.run(app, host="0.0.0.0", port=7860)
requirements.txt CHANGED
@@ -1,18 +1,6 @@
1
- # Core app
2
- fastapi==0.116.1
3
- uvicorn[standard]==0.35.0
4
- gradio==5.42.0
5
-
6
- # Whisper (OpenAI's official package)
7
- openai-whisper==20250625
8
-
9
- # FastAPI file/form parsing
10
- python-multipart==0.0.20
11
-
12
- # Useful libs
13
- numpy>=1.25
14
- soundfile>=0.12.1
15
- requests>=2.31.0
16
-
17
- # Optional but HIGHLY recommended for CPU performance (see notes)
18
- faster-whisper>=0.8.0
 
1
+ fastapi
2
+ uvicorn
3
+ gradio
4
+ openai-whisper
5
+ torch
6
+ python-multipart