#!/usr/bin/env python3
"""
Tahkik Inference Server — Hugging Face Space entry point.

The Whisper model is loaded exactly once, at import/startup time, through
faster-whisper (CTranslate2). The app then exposes:

  - POST /evaluate   — batch transcription of an uploaded audio file
  - WS   /ws/stream  — real-time streaming transcription of PCM chunks
"""

import asyncio
import json
import math
import os
import tempfile
import time

# ---------------------------------------------------------------------------
# Environment (must run BEFORE the third-party imports below)
# ---------------------------------------------------------------------------
# On Hugging Face Spaces /tmp is the only writable directory, so model caches
# are redirected there; setdefault() keeps any value the deployment already
# provides. These are set ahead of the faster_whisper import so the libraries
# it loads can see them. NOTE(review): presumably they are read at import
# time — keep this ordering unless that is confirmed otherwise.
_CACHE_DIR = "/tmp/huggingface_cache"
_ENV_DEFAULTS = {
    "HF_HOME": _CACHE_DIR,
    "HF_HUB_DISABLE_PROGRESS_BARS": "1",
    "CT2_VERBOSE": "0",
}
for _env_name, _env_value in _ENV_DEFAULTS.items():
    os.environ.setdefault(_env_name, _env_value)
del _env_name, _env_value

import numpy as np  # noqa: E402
from fastapi import (  # noqa: E402
    FastAPI,
    File,
    HTTPException,
    UploadFile,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.responses import JSONResponse  # noqa: E402
from faster_whisper import WhisperModel  # noqa: E402

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
TAHKIK_MODEL = "benhadjermed/tahkik-small-warsh-ct2"

SAMPLE_RATE = 16000  # Hz
CHUNK_LENGTH_S = 30  # window length (s) used when splitting long audio
OVERLAP_S = 1        # overlap (s) between consecutive windows

# Minimum buffered audio (s) before any streaming inference is attempted;
# very short snippets tend to make the model hallucinate.
MIN_AUDIO_FOR_INFERENCE_S = 1.0
MIN_SAMPLES_FOR_INFERENCE = int(SAMPLE_RATE * MIN_AUDIO_FOR_INFERENCE_S)

SILENCE_THRESHOLD = 0.02   # RMS level below which audio counts as silence
SILENCE_DURATION_S = 0.8   # trailing silence (s) that triggers finalization
SILENCE_SAMPLES = int(SAMPLE_RATE * SILENCE_DURATION_S)

# faster-whisper transcribe() options shared by every inference call.
# These are the usual anti-hallucination knobs (see openai/whisper#679).
WHISPER_OPTS = {
    "language": "ar",
    "task": "transcribe",
    # Light VAD removes long silent stretches the model would otherwise
    # hallucinate into, while leaving word endings intact.
    "vad_filter": True,
    "vad_parameters": {"min_silence_duration_ms": 300, "threshold": 0.35},
    # Temperature fallback: a decode that fails the compression-ratio or
    # log-prob check is retried at the next temperature, and dropped if it
    # still fails at the last one.
    "temperature": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
    "compression_ratio_threshold": 2.4,
    "log_prob_threshold": -1.0,
    "no_speech_threshold": 0.6,
    # Every window is decoded from scratch, which prevents loop hallucinations.
    "condition_on_previous_text": False,
}

# Segments the model itself scores at or above this no-speech probability
# are discarded.
NO_SPEECH_PROB_DROP_THRESHOLD = 0.7

ALLOWED_EXTS = {".wav", ".m4a", ".mp3", ".flac", ".ogg"}

# ---------------------------------------------------------------------------
# Model loading (happens once at module import / server startup)
# ---------------------------------------------------------------------------
print("[inference] loading faster-whisper model...", flush=True)
model = WhisperModel(
    TAHKIK_MODEL,
    device="cpu",
    compute_type="int8",
    download_root=_CACHE_DIR,
)
print("[inference] model ready", flush=True)

# Global inference lock — one inference at a time to avoid resource contention.
# Global inference lock: at most one model inference runs at a time, across
# all HTTP requests and WebSocket sessions, to avoid CPU contention.
# NOTE(review): created at import time. On Python < 3.10 asyncio.Lock binds to
# the loop current at construction, which may differ from the server's loop —
# confirm the deployment runs Python 3.10+.
_inference_lock = asyncio.Lock()

# Strong references to fire-and-forget partial-inference tasks. asyncio only
# keeps weak references to tasks, so an unreferenced task can be garbage
# collected before it finishes.
_background_tasks = set()


# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------
app = FastAPI(title="Tahkik Inference API")


@app.get("/health")
def health():
    """Liveness probe: always returns ``{"status": "ok"}``."""
    return {"status": "ok"}


# ---------------------------------------------------------------------------
# POST /evaluate — batch transcription (backward compatible)
# ---------------------------------------------------------------------------
@app.post("/evaluate")
async def evaluate(audio: UploadFile = File(...)):
    """Transcribe an uploaded audio file in one shot.

    The upload is written to a temporary file under /tmp (its extension must
    be in ``ALLOWED_EXTS``; a missing extension defaults to ``.wav``), then
    decoded and transcribed by ``_transcribe_file``. The temp file is always
    removed.

    Returns:
        JSON ``{"transcription", "confidence_score", "processing_time_ms"}``.

    Raises:
        HTTPException(400): unsupported file extension.
        HTTPException(500): any error while decoding or transcribing.
    """
    filename = audio.filename or "recording.wav"
    ext = os.path.splitext(filename)[1].lower() or ".wav"
    if ext not in ALLOWED_EXTS:
        raise HTTPException(status_code=400, detail=f"unsupported audio format: {ext}")

    data = await audio.read()
    with tempfile.NamedTemporaryFile(suffix=ext, delete=False, dir="/tmp") as f:
        f.write(data)
        tmp_path = f.name

    try:
        # Decoding + inference are CPU-bound and slow: run them in a worker
        # thread so the event loop (WebSocket streams, /health) stays
        # responsive, and serialize with streaming inference via the lock.
        async with _inference_lock:
            result = await asyncio.get_running_loop().run_in_executor(
                None, _transcribe_file, tmp_path
            )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    finally:
        os.unlink(tmp_path)

    return JSONResponse(result)


# ---------------------------------------------------------------------------
# WS /ws/stream — real-time streaming transcription
# ---------------------------------------------------------------------------
@app.websocket("/ws/stream")
async def stream_transcribe(ws: WebSocket):
    """
    Real-time streaming transcription over WebSocket.

    Protocol:
      Client → Server:
        - Binary frames: raw PCM 16-bit signed LE, 16 kHz, mono
        - Text frame: JSON {"type": "stop"} to signal end of recording

      Server → Client:
        - Text frames: JSON messages
          {"type": "partial", "text": "..."}  — intermediate transcription
          {"type": "final", "text": "...", "confidence": 0.94, "processing_time_ms": 1234}
          {"type": "error", "message": "..."}

    Buffered audio is finalized into the session text whenever it ends with
    SILENCE_DURATION_S of silence that follows speech. "stop" transcribes what
    is left, sends a single "final" message, and closes the connection.
    """
    await ws.accept()
    print("[ws] client connected", flush=True)
    loop = asyncio.get_running_loop()

    audio_buffer = bytearray()  # PCM bytes not yet finalized into session_text
    session_text = ""           # text of chunks already finalized this session
    last_inference_len = 0      # len(audio_buffer) at the last partial inference
    # Bumped whenever audio_buffer is finalized or discarded. A partial computed
    # from an older buffer must not be sent: its text is already part of
    # session_text (or was dropped), so sending it would duplicate text.
    generation = 0

    async def _run_partial(pcm_data: bytes, started_generation: int):
        """Transcribe a buffer snapshot and send it as a best-effort "partial"."""
        try:
            async with _inference_lock:
                if started_generation != generation:
                    return  # buffer finalized while we waited for the lock
                text = await loop.run_in_executor(None, _transcribe_pcm_buffer, pcm_data)
                if started_generation != generation:
                    return  # buffer discarded/finalized during inference
                full_text = (session_text + " " + text).strip()
                try:
                    await ws.send_json({"type": "partial", "text": full_text})
                except Exception:
                    pass  # Connection likely closed; partials are best-effort.
        except Exception:
            import traceback
            print(f"[ws] partial inference error:\n{traceback.format_exc()}", flush=True)

    try:
        while True:
            message = await ws.receive()

            # Starlette's receive() returns the raw ASGI disconnect message
            # instead of raising, and raises RuntimeError if called again
            # afterwards — route it through the normal disconnect handler.
            if message["type"] == "websocket.disconnect":
                raise WebSocketDisconnect(message.get("code", 1000))

            pcm_chunk = message.get("bytes")
            raw_text = message.get("text")

            # --- Binary frame: audio chunk --------------------------------
            if pcm_chunk is not None:
                audio_buffer.extend(pcm_chunk)
                buffer_samples = len(audio_buffer) // 2  # 16-bit = 2 bytes/sample
                if buffer_samples < MIN_SAMPLES_FOR_INFERENCE:
                    continue

                if _has_trailing_silence(bytes(audio_buffer), SILENCE_THRESHOLD, SILENCE_SAMPLES):
                    print("[ws] auto-finalizing chunk due to silence", flush=True)
                    async with _inference_lock:
                        # Chunked transcription: unlike _transcribe_pcm_buffer
                        # it does not discard audio older than CHUNK_LENGTH_S.
                        chunk_text, _ = await loop.run_in_executor(
                            None, _transcribe_pcm_buffer_with_confidence, bytes(audio_buffer)
                        )
                        # Update state while still holding the lock so a queued
                        # partial sees the new generation before it runs.
                        session_text = (session_text + " " + chunk_text).strip()
                        audio_buffer = bytearray()
                        last_inference_len = 0
                        generation += 1
                    try:
                        await ws.send_json({"type": "partial", "text": session_text})
                    except RuntimeError:
                        # Client closed connection while we were running inference
                        break
                    continue

                # Prevent unbounded growth if the mic is left on but the user
                # is entirely silent for 10s.
                if buffer_samples > SAMPLE_RATE * 10:
                    if _rms(_pcm_bytes_to_float32(bytes(audio_buffer))) < SILENCE_THRESHOLD * 2:
                        print("[ws] buffer full of purely silence, dropping...", flush=True)
                        audio_buffer = bytearray()
                        last_inference_len = 0
                        generation += 1
                        continue

                # Partial inference for every >= 0.5 s of new audio, but ONLY
                # when the lock is free: queuing partials would delay the
                # final run and could make it time out.
                new_samples = buffer_samples - last_inference_len // 2
                if new_samples >= SAMPLE_RATE // 2 and not _inference_lock.locked():
                    last_inference_len = len(audio_buffer)
                    # Background task so ws.receive() keeps draining audio.
                    task = asyncio.create_task(_run_partial(bytes(audio_buffer), generation))
                    _background_tasks.add(task)
                    task.add_done_callback(_background_tasks.discard)

            # --- Text frame: control message ------------------------------
            elif raw_text is not None:
                try:
                    msg = json.loads(raw_text)
                except json.JSONDecodeError:
                    error = "invalid JSON"
                else:
                    error = None if isinstance(msg, dict) else "expected a JSON object"
                if error is not None:
                    try:
                        await ws.send_json({"type": "error", "message": error})
                    except RuntimeError:
                        pass
                    continue

                if msg.get("type") != "stop":
                    continue  # other control messages are ignored

                print(f"[ws] stop received, buffer size: {len(audio_buffer)} bytes", flush=True)
                generation += 1  # any in-flight partial is superseded by "final"
                if len(audio_buffer) // 2 < MIN_SAMPLES_FOR_INFERENCE:
                    # Too little remaining audio to transcribe reliably.
                    final = {
                        "type": "final",
                        "text": session_text,
                        "confidence": 1.0,
                        "processing_time_ms": 0,
                    }
                else:
                    t_start = time.time()
                    async with _inference_lock:
                        text, confidence = await loop.run_in_executor(
                            None, _transcribe_pcm_buffer_with_confidence, bytes(audio_buffer)
                        )
                    final = {
                        "type": "final",
                        "text": (session_text + " " + text).strip(),
                        "confidence": confidence,
                        "processing_time_ms": int((time.time() - t_start) * 1000),
                    }
                try:
                    await ws.send_json(final)
                except RuntimeError:
                    pass
                break  # Close after final result.

    except WebSocketDisconnect:
        print("[ws] client disconnected", flush=True)
    except Exception as exc:
        import traceback
        print(f"[ws] error:\n{traceback.format_exc()}", flush=True)
        try:
            await ws.send_json({"type": "error", "message": str(exc)})
        except Exception:
            pass  # best-effort: the socket may already be unusable
    finally:
        try:
            await ws.close()
        except Exception:
            pass  # already closed / disconnected
        print("[ws] connection closed", flush=True)


# ---------------------------------------------------------------------------
# Inference helpers
# ---------------------------------------------------------------------------
def _pcm_bytes_to_float32(pcm_bytes: bytes) -> np.ndarray:
    """Convert raw PCM 16-bit signed little-endian bytes to float32 in [-1, 1).

    A trailing odd byte (half a sample) is ignored instead of making
    ``np.frombuffer`` raise, since clients may split frames at any byte.
    """
    usable = len(pcm_bytes) - (len(pcm_bytes) % 2)
    # "<i2" pins little-endian as the protocol requires, regardless of host.
    int16_array = np.frombuffer(memoryview(pcm_bytes)[:usable], dtype="<i2")
    return int16_array.astype(np.float32) / 32768.0


def _rms(samples: np.ndarray) -> float:
    """Return the root-mean-square level of a non-empty sample array."""
    return float(np.sqrt(np.mean(samples ** 2)))


def _has_trailing_silence(pcm_bytes: bytes, threshold: float, duration_samples: int) -> bool:
    """Return True if the buffer ends in silence that follows speech.

    "Silence" means the last ``duration_samples`` samples have RMS below
    ``threshold``. "Speech" means everything before that has RMS above
    ``1.5 * threshold``. Returns False when the buffer is shorter than
    ``duration_samples`` or ``duration_samples`` is not positive.
    """
    if duration_samples <= 0 or len(pcm_bytes) < duration_samples * 2:
        return False
    audio_array = _pcm_bytes_to_float32(pcm_bytes)
    trailing = audio_array[-duration_samples:]
    leading = audio_array[:-duration_samples]
    if leading.size == 0 or _rms(trailing) >= threshold:
        return False
    return _rms(leading) > threshold * 1.5


def _logprob_to_confidence(avg_logprob: float) -> float:
    """Convert faster-whisper's avg_logprob to a 0-1 confidence score via exp()."""
    return math.exp(max(avg_logprob, -5.0))  # clamp to avoid exp(-inf) = 0


def _transcribe_window(audio_array: np.ndarray) -> tuple:
    """Transcribe one audio array.

    Returns:
        ``(text, logprobs)``. ``text`` joins the non-empty stripped texts of
        the segments kept. ``logprobs`` holds the ``avg_logprob`` of every kept
        segment. A segment is kept when its ``no_speech_prob`` is below
        NO_SPEECH_PROB_DROP_THRESHOLD.
    """
    segments, _ = model.transcribe(audio_array, **WHISPER_OPTS)
    texts = []
    logprobs = []
    for seg in segments:
        if seg.no_speech_prob >= NO_SPEECH_PROB_DROP_THRESHOLD:
            continue
        texts.append(seg.text.strip())
        logprobs.append(seg.avg_logprob)
    return " ".join(t for t in texts if t), logprobs


def _transcribe_chunks(audio_array: np.ndarray) -> tuple:
    """Transcribe audio of any length by windows from ``_split_audio``.

    Returns:
        ``(transcription, confidence)``. The confidence is the mean of the
        per-window confidences, rounded to 4 decimals. Each window's
        confidence is ``_logprob_to_confidence`` of its mean segment
        log-prob, or 1.0 when no segment was kept. Empty audio gives 0.0.
    """
    texts = []
    scores = []
    for chunk in _split_audio(audio_array):
        text, logprobs = _transcribe_window(chunk)
        texts.append(text)
        if logprobs:
            scores.append(_logprob_to_confidence(sum(logprobs) / len(logprobs)))
        else:
            scores.append(1.0)
    transcription = " ".join(t for t in texts if t)
    confidence = round(sum(scores) / len(scores), 4) if scores else 0.0
    return transcription, confidence


def _transcribe_pcm_buffer(pcm_bytes: bytes) -> str:
    """Transcribe only the last CHUNK_LENGTH_S seconds of a PCM buffer.

    Used for cheap streaming previews. Older audio is ignored here; use
    ``_transcribe_pcm_buffer_with_confidence`` when no audio may be lost.
    """
    audio_array = _pcm_bytes_to_float32(pcm_bytes)
    max_samples = CHUNK_LENGTH_S * SAMPLE_RATE
    text, _ = _transcribe_window(audio_array[-max_samples:])
    return text


def _transcribe_pcm_buffer_with_confidence(pcm_bytes: bytes) -> tuple:
    """Transcribe an entire PCM buffer (any length); return (text, confidence)."""
    return _transcribe_chunks(_pcm_bytes_to_float32(pcm_bytes))


def _split_audio(audio_array, sr=SAMPLE_RATE, chunk_s=CHUNK_LENGTH_S, overlap_s=OVERLAP_S):
    """Split audio into overlapping windows of ``chunk_s`` seconds.

    Consecutive windows start ``chunk_s - overlap_s`` seconds apart. If less
    than 2 s would remain for a further window, the last window is instead
    extended to the end of the audio. That window can therefore be up to
    ``chunk_s - overlap_s + 2`` seconds long. Empty input yields ``[]``.

    Raises:
        ValueError: if ``overlap_s >= chunk_s`` (the window would never advance).
    """
    chunk_len = int(chunk_s * sr)
    step_len = int((chunk_s - overlap_s) * sr)
    if step_len <= 0:
        raise ValueError("overlap_s must be smaller than chunk_s")

    chunks = []
    start = 0
    while start < len(audio_array):
        end = min(start + chunk_len, len(audio_array))
        chunks.append(audio_array[start:end])
        start += step_len
        remaining = len(audio_array) - start
        if 0 < remaining < 2 * sr:
            # Fold a too-short tail into the last window.
            chunks[-1] = audio_array[start - step_len:]
            break
    return chunks


def _transcribe_file(audio_path: str) -> dict:
    """Decode an audio file at SAMPLE_RATE and transcribe it window by window.

    Returns:
        ``{"transcription", "confidence_score", "processing_time_ms"}``.
    """
    import librosa  # imported lazily: only needed for file uploads

    t_start = time.time()
    audio_array, _ = librosa.load(audio_path, sr=SAMPLE_RATE)
    transcription, confidence = _transcribe_chunks(audio_array)
    return {
        "transcription": transcription,
        "confidence_score": confidence,
        "processing_time_ms": int((time.time() - t_start) * 1000),
    }