benhadjermed committed on
Commit
3f4cf11
·
verified ·
1 Parent(s): 0785bb2

feat: migrate to streaming transcriptions via WebSockets

Browse files
Files changed (4) hide show
  1. Dockerfile +31 -0
  2. README.md +48 -8
  3. main.py +346 -0
  4. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# ── Tahkik Inference Space ──────────────────────────────────────────────────
# CPU image. To enable GPU (T4/L4/A100), change the base image to a CUDA
# runtime image, e.g.:
#     FROM nvidia/cuda:12.1.1-runtime-ubuntu22.04
# (check Docker Hub for the exact tag; nvidia/cuda tags normally include the
# patch version) and replace the pip torch line with the CUDA-specific wheel URL.
# ---------------------------------------------------------------------------

FROM python:3.10-slim

# ffmpeg is needed so that librosa's audioread fallback can decode compressed
# uploads (.m4a / .mp3) that libsndfile may not handle on its own.
RUN apt-get update \
    && apt-get install -y --no-install-recommends ffmpeg \
    && rm -rf /var/lib/apt/lists/*

# HF Spaces requires a non-root user with UID 1000.
RUN useradd -m -u 1000 user

WORKDIR /home/user/app

# Install dependencies as root (before switching user). Copying only
# requirements.txt first keeps this layer cached when just the code changes.
COPY --chown=user requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code.
COPY --chown=user . .

# Redirect all model/cache downloads to /tmp (only writable path in Spaces).
ENV HF_HOME=/tmp/huggingface_cache
ENV TORCH_HOME=/tmp/torch_cache
ENV TRANSFORMERS_VERBOSITY=error
ENV HF_HUB_DISABLE_PROGRESS_BARS=1

USER user

EXPOSE 7860

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,12 +1,52 @@
1
  ---
2
  title: Tahkik Basic Warsh
3
- emoji: 📊
4
- colorFrom: blue
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 6.12.0
8
- app_file: app.py
9
- pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: Tahkik Basic Warsh
3
+ emoji: 📖
4
+ colorFrom: green
5
+ colorTo: blue
6
+ sdk: docker
7
+ app_port: 7860
 
 
8
  ---
9
 
10
+ # Tahkik Inference API
11
+
12
+ FastAPI inference server for the `benhadjermed/tahkik-basic-warsh` Whisper model.
13
+ Accepts Arabic Quranic audio — either as an uploaded file (`POST /evaluate`) or as a live 16 kHz PCM stream (`WS /ws/stream`) — and returns a transcription with a confidence score.
14
+
15
+ ## Endpoints
16
+
17
+ | Method | Path | Description |
18
+ |--------|-------------|------------------------------|
19
+ | GET | `/health` | Liveness check |
20
+ | POST | `/evaluate` | Transcribe an audio file |
21
+
22
+ ## POST /evaluate
23
+
24
+ **Request** — `multipart/form-data`
25
+
26
+ | Field | Type | Required | Notes |
27
+ |---------|------|----------|--------------------------------------------|
28
+ | `audio` | file | yes | `.wav`, `.mp3`, `.m4a`, `.flac`, or `.ogg` |
29
+
30
+ **Response** — `application/json`
31
+
32
+ ```json
33
+ {
34
+ "transcription": "الحمد لله رب العالمين",
35
+ "confidence_score": 0.9423,
36
+ "processing_time_ms": 1350
37
+ }
38
+ ```
39
+
40
+ **Error** — non-200 status
41
+
42
+ ```json
43
+ {
44
+ "detail": "unsupported audio format: .xyz"
45
+ }
46
+ ```
47
+
48
+ ## Environment / Secrets
49
+
50
+ | Name | Where to set | Purpose |
51
+ |------------|-------------------|------------------------------------------------|
52
+ | `HF_TOKEN` | Space secret | Required if `tahkik-basic-warsh` is private |
main.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Tahkik Inference Server — Hugging Face Space entry point.
4
+
5
+ Loads the Whisper model ONCE at startup, then serves:
6
+ - POST /evaluate — batch transcription (upload a full audio file)
7
+ - WS /ws/stream — real-time streaming transcription (send PCM chunks)
8
+ """
9
+
10
+ import asyncio
11
+ import json
12
+ import os
13
+ import sys
14
+ import struct
15
+ import time
16
+ import tempfile
17
+
18
# Redirect model caches to /tmp (only writable dir in HF Spaces) and quiet the
# Hugging Face libraries. Values already present in the environment are kept,
# so settings baked into the container image take precedence.
for _env_name, _env_default in (
    ("HF_HOME", "/tmp/huggingface_cache"),
    ("TORCH_HOME", "/tmp/torch_cache"),
    ("TRANSFORMERS_VERBOSITY", "error"),
    ("HF_HUB_DISABLE_PROGRESS_BARS", "1"),
):
    if _env_name not in os.environ:
        os.environ[_env_name] = _env_default
del _env_name, _env_default
23
+
24
+ import numpy as np
25
+ from fastapi import FastAPI, File, UploadFile, HTTPException, WebSocket, WebSocketDisconnect
26
+ from fastapi.responses import JSONResponse
27
+ import torch
28
+ import torch.nn.functional as F
29
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor
30
+
31
+ # ---------------------------------------------------------------------------
32
+ # Constants
33
+ # ---------------------------------------------------------------------------
34
+
35
TAHKIK_MODEL = "benhadjermed/tahkik-basic-warsh"
SAMPLE_RATE = 16000  # Hz; all audio is processed at this rate
CHUNK_LENGTH_S = 30  # seconds of audio per inference window
OVERLAP_S = 1  # seconds shared between consecutive windows

# Minimum seconds of audio before running partial inference (reduces hallucinations)
MIN_AUDIO_FOR_INFERENCE_S = 1.5
MIN_SAMPLES_FOR_INFERENCE = int(MIN_AUDIO_FOR_INFERENCE_S * SAMPLE_RATE)

ALLOWED_EXTS = {".wav", ".m4a", ".mp3", ".flac", ".ogg"}

# ---------------------------------------------------------------------------
# Model loading (happens once at module import / server startup)
# ---------------------------------------------------------------------------

print("[inference] importing torch / transformers...", flush=True)
_cuda_available = torch.cuda.is_available()
device = "cuda:0" if _cuda_available else "cpu"
# Half precision only on GPU; CPU inference stays in float32.
torch_dtype = torch.float16 if _cuda_available else torch.float32
print(f"[inference] device: {device}", flush=True)

# The processor (feature extractor + tokenizer) is taken from the base
# checkpoint rather than from TAHKIK_MODEL.
print("[inference] loading processor (openai/whisper-base)...", flush=True)
processor = WhisperProcessor.from_pretrained(
    "openai/whisper-base", language="Arabic", task="transcribe"
)

print(f"[inference] loading model ({TAHKIK_MODEL})...", flush=True)
model = WhisperForConditionalGeneration.from_pretrained(
    TAHKIK_MODEL, torch_dtype=torch_dtype
).to(device)

# Patch missing generation config fields that some fine-tuned checkpoints omit.
# Only the base checkpoint's generation config is fetched: loading the full
# base model just to copy three dictionaries would cost a second weight
# download and extra memory at startup.
if getattr(model.generation_config, "lang_to_id", None) is None:
    print("[inference] patching generation config from base model...", flush=True)
    from transformers import GenerationConfig

    _base_config = GenerationConfig.from_pretrained("openai/whisper-base")
    model.generation_config.lang_to_id = _base_config.lang_to_id
    model.generation_config.id_to_lang = {
        v: k for k, v in _base_config.lang_to_id.items()
    }
    model.generation_config.task_to_id = _base_config.task_to_id
    del _base_config

print("[inference] model ready", flush=True)

# Global inference lock — one inference at a time to avoid GPU OOM.
_inference_lock = asyncio.Lock()
78
+
79
+ # ---------------------------------------------------------------------------
80
+ # FastAPI app
81
+ # ---------------------------------------------------------------------------
82
+
83
app = FastAPI(title="Tahkik Inference API")


@app.get("/health")
def health():
    """Liveness check: report that the server process is up and answering."""
    liveness = {"status": "ok"}
    return liveness
89
+
90
+
91
+ # ---------------------------------------------------------------------------
92
+ # POST /evaluate — batch transcription (backward compatible)
93
+ # ---------------------------------------------------------------------------
94
+
95
@app.post("/evaluate")
async def evaluate(audio: UploadFile = File(...)):
    """Transcribe one uploaded audio file (batch mode).

    The upload is spooled to a temporary file under /tmp, transcribed with
    ``_transcribe_file`` and the temporary file is removed afterwards.

    Args:
        audio: Multipart upload. Its filename extension must be one of
            ``ALLOWED_EXTS``; a missing filename or extension is treated
            as ``.wav``.

    Returns:
        JSONResponse wrapping the dict produced by ``_transcribe_file``.

    Raises:
        HTTPException: 400 for an unsupported extension, 500 (with the
            underlying error message as detail) if spooling or
            transcription fails.
    """
    filename = audio.filename or "recording.wav"
    ext = os.path.splitext(filename)[1].lower() or ".wav"
    if ext not in ALLOWED_EXTS:
        raise HTTPException(status_code=400, detail=f"unsupported audio format: {ext}")

    data = await audio.read()
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=ext, delete=False, dir="/tmp") as f:
            # Record the path before writing so a failed write is still cleaned up.
            tmp_path = f.name
            f.write(data)

        # Inference is CPU/GPU-bound and synchronous: run it in a worker thread
        # so the event loop keeps serving other clients, and take the global
        # lock so it never overlaps with a streaming inference.
        async with _inference_lock:
            result = await asyncio.get_running_loop().run_in_executor(
                None, _transcribe_file, tmp_path
            )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Best-effort cleanup; never mask the real outcome of the request.
                pass

    return JSONResponse(result)
115
+
116
+
117
+ # ---------------------------------------------------------------------------
118
+ # WS /ws/stream — real-time streaming transcription
119
+ # ---------------------------------------------------------------------------
120
+
121
@app.websocket("/ws/stream")
async def stream_transcribe(ws: WebSocket):
    """
    Real-time streaming transcription over WebSocket.

    Protocol:
        Client → Server:
            - Binary frames: raw PCM 16-bit signed LE, 16 kHz, mono
            - Text frame: JSON {"type": "stop"} to signal end of recording

        Server → Client:
            - Text frames: JSON messages
                {"type": "partial", "text": "..."}  — intermediate transcription
                {"type": "final", "text": "...", "confidence": 0.94, "processing_time_ms": 1234}
                {"type": "error", "message": "..."}

    Behaviour:
        - A partial result is produced once at least MIN_SAMPLES_FOR_INFERENCE
          samples are buffered and at least half a second of new audio has
          arrived since the previous partial. Partials cover only the most
          recent CHUNK_LENGTH_S seconds.
        - On "stop" the whole buffer is transcribed, the final message is sent
          and the connection is closed. Buffers shorter than the minimum yield
          an empty final result.
        - Text frames that are valid JSON objects of any other type are ignored.
    """
    await ws.accept()
    print("[ws] client connected", flush=True)

    loop = asyncio.get_running_loop()
    window_samples = CHUNK_LENGTH_S * SAMPLE_RATE
    min_new_samples = SAMPLE_RATE // 2

    # Accumulate raw PCM bytes from the client.
    audio_buffer = bytearray()
    samples_at_last_inference = 0  # avoids re-running inference without new audio

    try:
        while True:
            message = await ws.receive()

            # ws.receive() hands back the raw ASGI message; a disconnect arrives
            # as a message rather than as an exception, so handle it here
            # instead of calling receive() again on a dead socket.
            if message.get("type") == "websocket.disconnect":
                print("[ws] client disconnected", flush=True)
                break

            # --- Binary frame: audio chunk --------------------------------
            chunk = message.get("bytes")
            if chunk is not None:
                audio_buffer.extend(chunk)

                # 16-bit = 2 bytes/sample; a trailing odd byte is an incomplete
                # sample and is left in the buffer until its other half arrives.
                buffer_samples = len(audio_buffer) // 2
                new_samples = buffer_samples - samples_at_last_inference

                if buffer_samples >= MIN_SAMPLES_FOR_INFERENCE and new_samples >= min_new_samples:
                    # The partial helper only looks at the last window, so copy
                    # just that tail (on a sample boundary) instead of everything.
                    tail_start = max(0, buffer_samples - window_samples) * 2
                    pcm_tail = bytes(audio_buffer[tail_start:buffer_samples * 2])
                    async with _inference_lock:
                        text = await loop.run_in_executor(
                            None, _transcribe_pcm_buffer, pcm_tail
                        )
                    samples_at_last_inference = buffer_samples

                    await ws.send_json({"type": "partial", "text": text})
                continue

            # --- Text frame: control message ------------------------------
            raw_text = message.get("text")
            if raw_text is None:
                continue

            try:
                msg = json.loads(raw_text)
            except json.JSONDecodeError:
                await ws.send_json({"type": "error", "message": "invalid JSON"})
                continue

            if not isinstance(msg, dict):
                # Valid JSON but not an object (e.g. a list or a string).
                await ws.send_json({"type": "error", "message": "expected a JSON object"})
                continue

            if msg.get("type") != "stop":
                continue

            print(f"[ws] stop received, buffer size: {len(audio_buffer)} bytes", flush=True)

            buffer_samples = len(audio_buffer) // 2
            if buffer_samples < MIN_SAMPLES_FOR_INFERENCE:
                final_message = {
                    "type": "final",
                    "text": "",
                    "confidence": 0.0,
                    "processing_time_ms": 0,
                }
            else:
                # Timing includes any wait for the inference lock.
                t_start = time.perf_counter()
                pcm_all = bytes(audio_buffer[:buffer_samples * 2])
                async with _inference_lock:
                    text, confidence = await loop.run_in_executor(
                        None, _transcribe_pcm_buffer_with_confidence, pcm_all
                    )
                elapsed = int((time.perf_counter() - t_start) * 1000)
                final_message = {
                    "type": "final",
                    "text": text,
                    "confidence": confidence,
                    "processing_time_ms": elapsed,
                }

            await ws.send_json(final_message)
            break  # Close after final result.

    except WebSocketDisconnect:
        print("[ws] client disconnected", flush=True)
    except Exception as exc:
        print(f"[ws] error: {exc}", flush=True)
        try:
            await ws.send_json({"type": "error", "message": str(exc)})
        except Exception:
            # The socket may already be gone; reporting the error is best-effort.
            pass
    finally:
        try:
            await ws.close()
        except Exception:
            # Closing an already-closed socket raises; nothing left to do.
            pass
        print("[ws] connection closed", flush=True)
219
+
220
+
221
+ # ---------------------------------------------------------------------------
222
+ # Inference helpers
223
+ # ---------------------------------------------------------------------------
224
+
225
+ def _pcm_bytes_to_float32(pcm_bytes: bytes) -> np.ndarray:
226
+ """Convert raw PCM 16-bit signed LE bytes to float32 numpy array in [-1, 1]."""
227
+ int16_array = np.frombuffer(pcm_bytes, dtype=np.int16)
228
+ return int16_array.astype(np.float32) / 32768.0
229
+
230
+
231
def _transcribe_pcm_buffer(pcm_bytes: bytes) -> str:
    """Transcribe a raw PCM buffer and return only the text.

    Only the most recent CHUNK_LENGTH_S seconds are considered (Whisper's
    context window); anything earlier in the buffer is dropped before
    feature extraction.
    """
    samples = _pcm_bytes_to_float32(pcm_bytes)

    # Keep just the trailing window when the buffer is longer than one chunk.
    window = CHUNK_LENGTH_S * SAMPLE_RATE
    if len(samples) > window:
        samples = samples[-window:]

    features = processor(
        samples, sampling_rate=SAMPLE_RATE, return_tensors="pt"
    ).input_features
    features = features.to(device, dtype=torch_dtype)

    with torch.no_grad():
        token_ids = model.generate(features, language="ar", task="transcribe")

    decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0].strip()
253
+
254
+
255
def _transcribe_pcm_buffer_with_confidence(pcm_bytes: bytes) -> tuple:
    """Transcribe a raw PCM buffer chunk by chunk, return (text, confidence).

    The audio is cut with ``_split_audio``. Each chunk is decoded separately
    and the chunk texts are joined with single spaces. A chunk's score is the
    mean, over generation steps, of the highest softmax probability at that
    step (1.0 when generation returns no scores). The overall confidence is
    the plain average of the chunk scores rounded to 4 decimals, or 0.0 when
    there are no chunks.
    """
    samples = _pcm_bytes_to_float32(pcm_bytes)

    texts = []
    scores = []
    for piece in _split_audio(samples):
        features = processor(
            piece, sampling_rate=SAMPLE_RATE, return_tensors="pt"
        ).input_features
        features = features.to(device, dtype=torch_dtype)

        with torch.no_grad():
            result = model.generate(
                features,
                language="ar",
                task="transcribe",
                return_dict_in_generate=True,
                output_scores=True,
            )

        decoded = processor.batch_decode(result.sequences, skip_special_tokens=True)
        texts.append(decoded[0].strip())

        if result.scores:
            # One entry per generation step: mean of the per-sequence top probability.
            step_means = [
                F.softmax(step_logits, dim=-1).max(dim=-1).values.mean().item()
                for step_logits in result.scores
            ]
            piece_score = float(sum(step_means) / len(step_means))
        else:
            piece_score = 1.0
        scores.append(piece_score)

    if scores:
        confidence = round(sum(scores) / len(scores), 4)
    else:
        confidence = 0.0
    return " ".join(texts), confidence
290
+
291
+
292
def _split_audio(audio_array, sr=SAMPLE_RATE, chunk_s=CHUNK_LENGTH_S, overlap_s=OVERLAP_S):
    """Cut ``audio_array`` into overlapping windows for chunked inference.

    Windows are ``chunk_s`` seconds long and start every
    ``chunk_s - overlap_s`` seconds. If fewer than 2 seconds of audio would
    remain after the next window start, that tail is folded into the current
    window and splitting stops, so no very short trailing chunk is emitted.

    Args:
        audio_array: 1-D sliceable sequence of samples.
        sr: Sample rate in Hz.
        chunk_s: Window length in seconds.
        overlap_s: Overlap between consecutive windows in seconds; must be
            smaller than ``chunk_s``.

    Returns:
        List of slices of ``audio_array`` (empty list for empty input).

    Raises:
        ValueError: If the step between windows is not positive, which would
            otherwise make the loop below run forever.
    """
    chunk_len = int(chunk_s * sr)
    step_len = int((chunk_s - overlap_s) * sr)
    if step_len <= 0:
        raise ValueError("overlap_s must be smaller than chunk_s (step between chunks must be positive)")

    total = len(audio_array)
    chunks = []
    start = 0
    while start < total:
        # Slicing clamps at the end of the array, so no explicit min() is needed.
        chunks.append(audio_array[start:start + chunk_len])

        next_start = start + step_len
        remaining = total - next_start
        if 0 < remaining < 2 * sr:
            # Fold the short tail into this window. The merged window can be
            # longer than chunk_len when the tail exceeds the overlap.
            # NOTE(review): confirm the feature extractor does not truncate
            # that excess and silently drop the end of the recording.
            chunks[-1] = audio_array[start:]
            break
        start = next_start
    return chunks
306
+
307
+
308
def _transcribe_file(audio_path: str) -> dict:
    """Transcribe an audio file from disk and report text, confidence and timing.

    The file is loaded with librosa at SAMPLE_RATE, cut with ``_split_audio``
    and decoded chunk by chunk. Returns a dict with:

    - "transcription": chunk texts joined with single spaces
    - "confidence_score": average chunk score rounded to 4 decimals
      (0.0 when there are no chunks)
    - "processing_time_ms": wall-clock time including file loading
    """
    import librosa

    started = time.time()
    samples, _ = librosa.load(audio_path, sr=SAMPLE_RATE)

    texts = []
    scores = []
    for piece in _split_audio(samples):
        features = processor(
            piece, sampling_rate=SAMPLE_RATE, return_tensors="pt"
        ).input_features
        features = features.to(device, dtype=torch_dtype)

        with torch.no_grad():
            result = model.generate(
                features,
                language="ar",
                task="transcribe",
                return_dict_in_generate=True,
                output_scores=True,
            )

        decoded = processor.batch_decode(result.sequences, skip_special_tokens=True)
        texts.append(decoded[0].strip())

        if result.scores:
            # Per generation step: mean of the per-sequence top softmax probability.
            step_means = [
                F.softmax(step_logits, dim=-1).max(dim=-1).values.mean().item()
                for step_logits in result.scores
            ]
            piece_score = float(sum(step_means) / len(step_means))
        else:
            piece_score = 1.0
        scores.append(piece_score)

    if scores:
        confidence_score = round(sum(scores) / len(scores), 4)
    else:
        confidence_score = 0.0

    return {
        "transcription": " ".join(texts),
        "confidence_score": confidence_score,
        "processing_time_ms": int((time.time() - started) * 1000),
    }
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn[standard]
3
+ torch
4
+ transformers
5
+ librosa
6
+ soundfile
7
+ accelerate
8
+ python-multipart
9
+ numpy