#!/usr/bin/env python3
# Commit cd96775: Reduce Whisper hallucinations: condition_on_previous_text=False, temperature fallback, light VAD, no_speech filter
"""
Tahkik Inference Server — Hugging Face Space entry point.

Loads the Whisper model ONCE at startup via faster-whisper (CTranslate2),
then serves:
  - POST /evaluate  — batch transcription (upload a full audio file)
  - WS   /ws/stream — real-time streaming transcription (send PCM chunks)
"""
import asyncio
import json
import math
import os
import tempfile
import time
import traceback

# Redirect model caches to /tmp (only writable dir in HF Spaces).
# Kept ahead of the third-party imports below so these settings are already
# in place when those libraries are loaded.
os.environ.setdefault("HF_HOME", "/tmp/huggingface_cache")
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
os.environ.setdefault("CT2_VERBOSE", "0")

import numpy as np
from fastapi import FastAPI, File, UploadFile, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
from faster_whisper import WhisperModel
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
TAHKIK_MODEL = "benhadjermed/tahkik-small-warsh-ct2"
SAMPLE_RATE = 16000
CHUNK_LENGTH_S = 30
OVERLAP_S = 1

# Minimum seconds of audio before running partial inference (reduces hallucinations)
MIN_AUDIO_FOR_INFERENCE_S = 1.0
MIN_SAMPLES_FOR_INFERENCE = int(MIN_AUDIO_FOR_INFERENCE_S * SAMPLE_RATE)

SILENCE_THRESHOLD = 0.02   # RMS threshold for silence
SILENCE_DURATION_S = 0.8   # seconds of trailing silence to trigger finalization
SILENCE_SAMPLES = int(SILENCE_DURATION_S * SAMPLE_RATE)

# faster-whisper transcribe options shared by every inference call.
# Standard anti-hallucination knobs — see openai/whisper#679.
WHISPER_OPTS = {
    "language": "ar",
    "task": "transcribe",
    # Lightweight VAD strips long silence chunks the model would
    # otherwise hallucinate into, while keeping word endings intact.
    "vad_filter": True,
    "vad_parameters": {"min_silence_duration_ms": 300, "threshold": 0.35},
    # Standard temperature-fallback chain. Decodes that fail the
    # compression-ratio or log-prob check are retried at the next
    # temperature, then dropped if still bad.
    "temperature": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
    "compression_ratio_threshold": 2.4,
    "log_prob_threshold": -1.0,
    "no_speech_threshold": 0.6,
    # Each window is a fresh decode — kills loop hallucinations.
    "condition_on_previous_text": False,
}

# Drop any segment the model itself flagged as likely non-speech.
NO_SPEECH_PROB_DROP_THRESHOLD = 0.7

# File extensions accepted by the batch upload endpoint.
ALLOWED_EXTS = {".wav", ".m4a", ".mp3", ".flac", ".ogg"}
# ---------------------------------------------------------------------------
# Model loading (happens once at module import / server startup)
# ---------------------------------------------------------------------------
print("[inference] loading faster-whisper model...", flush=True)
model = WhisperModel(
    TAHKIK_MODEL,
    device="cpu",
    compute_type="int8",
    download_root="/tmp/huggingface_cache",
)
print("[inference] model ready", flush=True)

# Global inference lock — one inference at a time to avoid resource contention.
# NOTE(review): the lock is created at import time, before any event loop is
# running. That is fine on Python 3.10+; on older interpreters asyncio
# primitives bind to a loop at construction — confirm the runtime version.
_inference_lock = asyncio.Lock()

# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------
app = FastAPI(title="Tahkik Inference API")
# NOTE(review): no route decorator was present on this handler, so it was
# never reachable. The "/health" path is inferred from the function name —
# confirm it against whatever probes this service.
@app.get("/health")
def health():
    """Liveness probe: report that the process is up and serving requests."""
    return {"status": "ok"}
# ---------------------------------------------------------------------------
# POST /evaluate — batch transcription (backward compatible)
# ---------------------------------------------------------------------------
@app.post("/evaluate")
async def evaluate(audio: UploadFile = File(...)):
    """
    Transcribe a complete uploaded audio file.

    The upload is spooled to a temporary file under /tmp, transcribed with
    ``_transcribe_file`` and the temporary file is removed afterwards.

    Args:
        audio: Uploaded audio file. Its extension (defaulting to ``.wav`` when
            the filename has none) must be one of ``ALLOWED_EXTS``.

    Returns:
        JSONResponse wrapping the dict produced by ``_transcribe_file``
        (``transcription``, ``confidence_score``, ``processing_time_ms``).

    Raises:
        HTTPException: 400 for an unsupported extension, 500 when writing the
            temporary file or running the transcription fails.
    """
    filename = audio.filename or "recording.wav"
    ext = os.path.splitext(filename)[1].lower() or ".wav"
    if ext not in ALLOWED_EXTS:
        raise HTTPException(status_code=400, detail=f"unsupported audio format: {ext}")

    data = await audio.read()

    # delete=False because the decoder re-opens the file by path; cleanup is
    # done explicitly in the ``finally`` below, including when the write fails.
    tmp = tempfile.NamedTemporaryFile(suffix=ext, delete=False, dir="/tmp")
    tmp_path = tmp.name
    try:
        with tmp:
            tmp.write(data)
        # Transcription is CPU-bound: run it in the default executor so the
        # event loop (and any open WebSocket sessions) stays responsive, and
        # take the global lock so it never overlaps another inference.
        async with _inference_lock:
            result = await asyncio.get_running_loop().run_in_executor(
                None, _transcribe_file, tmp_path
            )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    finally:
        try:
            os.unlink(tmp_path)
        except FileNotFoundError:
            pass  # Already gone; nothing left to clean up.
    return JSONResponse(result)
# ---------------------------------------------------------------------------
# WS /ws/stream — real-time streaming transcription
# ---------------------------------------------------------------------------
@app.websocket("/ws/stream")
async def stream_transcribe(ws: WebSocket):
    """
    Real-time streaming transcription over WebSocket.

    Protocol:
        Client → Server:
            - Binary frames: raw PCM 16-bit signed LE, 16 kHz, mono
            - Text frame: JSON {"type": "stop"} to signal end of recording
        Server → Client:
            - Text frames: JSON messages
                {"type": "partial", "text": "..."}  — intermediate transcription
                {"type": "final", "text": "...", "confidence": 0.94, "processing_time_ms": 1234}
                {"type": "error", "message": "..."}

    Audio is accumulated in a buffer. When the buffer ends in a stretch of
    silence that follows speech, it is transcribed, appended to the session
    text and cleared. Otherwise a best-effort partial transcription is started
    in the background roughly every half second of new audio, but only while
    the global inference lock is free. The connection is closed after the
    final result has been sent.
    """
    await ws.accept()
    print("[ws] client connected", flush=True)

    loop = asyncio.get_running_loop()
    audio_buffer = bytearray()   # raw PCM bytes received since the last reset
    session_text = ""            # text of chunks already finalized on silence
    last_inference_len = 0       # buffer size (bytes) when the last partial was started
    # Strong references to in-flight partial tasks: the event loop only keeps
    # weak references, so an unreferenced task may be garbage-collected mid-run.
    partial_tasks = set()

    async def _send_json_safely(payload: dict) -> bool:
        """Send *payload*; return False instead of raising if the socket is closed."""
        try:
            await ws.send_json(payload)
        except (RuntimeError, WebSocketDisconnect):
            return False
        return True

    async def _run_partial(pcm_data: bytes):
        """Transcribe a buffer snapshot and push it as a 'partial' message."""
        try:
            async with _inference_lock:
                text = await loop.run_in_executor(None, _transcribe_pcm_buffer, pcm_data)
            # session_text is read after inference so the newest finalized
            # text is the prefix.
            full_text = (session_text + " " + text).strip()
            await _send_json_safely({"type": "partial", "text": full_text})
        except Exception:
            # Partials are best-effort: log and carry on with the session.
            err_msg = traceback.format_exc()
            print(f"[ws] partial inference error:\n{err_msg}", flush=True)

    try:
        while True:
            message = await ws.receive()

            # ws.receive() hands back the raw ASGI message and does NOT raise
            # on disconnect, so the disconnect event must be checked by type.
            if message.get("type") == "websocket.disconnect":
                print("[ws] client disconnected", flush=True)
                break

            chunk = message.get("bytes")
            raw_text = message.get("text")

            # --- Binary frame: audio chunk --------------------------------
            if chunk is not None:
                audio_buffer.extend(chunk)

                buffer_samples = len(audio_buffer) // 2  # 16-bit = 2 bytes/sample
                new_samples = buffer_samples - (last_inference_len // 2)
                if buffer_samples < MIN_SAMPLES_FOR_INFERENCE:
                    continue

                snapshot = bytes(audio_buffer)

                if _has_trailing_silence(snapshot, SILENCE_THRESHOLD, SILENCE_SAMPLES):
                    print("[ws] auto-finalizing chunk due to silence", flush=True)
                    async with _inference_lock:
                        chunk_text = await loop.run_in_executor(
                            None, _transcribe_pcm_buffer, snapshot
                        )
                    session_text = (session_text + " " + chunk_text).strip()
                    if not await _send_json_safely({"type": "partial", "text": session_text}):
                        # Client closed connection while we were running inference
                        break
                    audio_buffer = bytearray()
                    last_inference_len = 0
                    continue

                # Prevent OOM if mic is left on but user is entirely silent for 10s
                if buffer_samples > SAMPLE_RATE * 10:
                    audio_array = _pcm_bytes_to_float32(snapshot)
                    if np.sqrt(np.mean(audio_array ** 2)) < SILENCE_THRESHOLD * 2:
                        print("[ws] buffer full of purely silence, dropping...", flush=True)
                        audio_buffer = bytearray()
                        last_inference_len = 0
                        continue

                # Run partial inference ONLY if the lock is free. This prevents
                # thousands of requests from queuing and timing out the final run.
                if new_samples >= (SAMPLE_RATE // 2) and not _inference_lock.locked():
                    last_inference_len = len(audio_buffer)
                    # Run in background so ws.receive() is not blocked.
                    task = asyncio.create_task(_run_partial(snapshot))
                    partial_tasks.add(task)
                    task.add_done_callback(partial_tasks.discard)

            # --- Text frame: control message ------------------------------
            elif raw_text is not None:
                try:
                    msg = json.loads(raw_text)
                except json.JSONDecodeError:
                    await _send_json_safely({"type": "error", "message": "invalid JSON"})
                    continue

                # Valid JSON that is not an object (e.g. a bare number) has no
                # "type" key to inspect; report it instead of crashing on .get().
                if not isinstance(msg, dict):
                    await _send_json_safely(
                        {"type": "error", "message": "expected a JSON object"}
                    )
                    continue

                if msg.get("type") != "stop":
                    continue  # Unknown control messages are ignored.

                print(f"[ws] stop received, buffer size: {len(audio_buffer)} bytes", flush=True)
                buffer_samples = len(audio_buffer) // 2

                if buffer_samples < MIN_SAMPLES_FOR_INFERENCE:
                    # Too little audio left to transcribe: return what was
                    # already finalized.
                    await _send_json_safely({
                        "type": "final",
                        "text": session_text,
                        "confidence": 1.0,
                        "processing_time_ms": 0,
                    })
                else:
                    t_start = time.time()
                    async with _inference_lock:
                        text, confidence = await loop.run_in_executor(
                            None, _transcribe_pcm_buffer_with_confidence, bytes(audio_buffer)
                        )
                    elapsed = int((time.time() - t_start) * 1000)
                    final_text = (session_text + " " + text).strip()
                    await _send_json_safely({
                        "type": "final",
                        "text": final_text,
                        "confidence": confidence,
                        "processing_time_ms": elapsed,
                    })
                break  # Close after final result.

    except WebSocketDisconnect:
        print("[ws] client disconnected", flush=True)
    except Exception as exc:
        print(f"[ws] error:\n{traceback.format_exc()}", flush=True)
        try:
            await ws.send_json({"type": "error", "message": str(exc)})
        except Exception:
            pass  # Socket is already unusable; nothing more to report.
    finally:
        # Let any in-flight partial finish instead of cancelling it: cancelling
        # would release the inference lock while the worker thread is still
        # running the model.
        if partial_tasks:
            await asyncio.gather(*partial_tasks, return_exceptions=True)
        try:
            await ws.close()
        except Exception:
            pass  # Already closed by the peer or by the server.
        print("[ws] connection closed", flush=True)
# ---------------------------------------------------------------------------
# Inference helpers
# ---------------------------------------------------------------------------
def _pcm_bytes_to_float32(pcm_bytes: bytes) -> np.ndarray:
    """
    Convert raw PCM 16-bit signed little-endian bytes to float32 in [-1, 1).

    Args:
        pcm_bytes: Raw sample bytes. A dangling odd byte at the end (half a
            sample, e.g. from a frame split mid-sample) is ignored rather than
            rejected.

    Returns:
        One float32 value per complete 16-bit sample, scaled by 1/32768.
        Empty input yields an empty array.
    """
    n_samples = len(pcm_bytes) // 2
    if n_samples == 0:
        return np.zeros(0, dtype=np.float32)
    # "<i2" pins little-endian as the wire format requires, independent of the
    # host byte order; ``count`` limits the read to whole samples.
    int16_array = np.frombuffer(pcm_bytes, dtype="<i2", count=n_samples)
    return int16_array.astype(np.float32) / 32768.0
def _has_trailing_silence(pcm_bytes: bytes, threshold: float, duration_samples: int) -> bool:
    """
    Tell whether the buffer ends in silence that was preceded by speech.

    Args:
        pcm_bytes: Raw PCM 16-bit buffer (2 bytes per sample).
        threshold: RMS level below which the tail counts as silent.
        duration_samples: Length of the tail window, in samples.

    Returns:
        True only when the RMS of the last ``duration_samples`` samples is
        below ``threshold`` AND the RMS of everything before that window is
        above ``1.5 * threshold``. False when the buffer is shorter than the
        window or holds nothing before it.
    """
    if len(pcm_bytes) < duration_samples * 2:
        return False

    audio_array = _pcm_bytes_to_float32(pcm_bytes)
    trailing = audio_array[-duration_samples:]
    leading = audio_array[:-duration_samples]

    # Without audio ahead of the window there is no speech to finalize.
    if len(leading) == 0:
        return False

    trailing_rms = np.sqrt(np.mean(trailing ** 2))
    if trailing_rms >= threshold:
        return False

    # Require some actual speech before the quiet tail, so a buffer that is
    # silent throughout does not count as "trailing silence".
    leading_rms = np.sqrt(np.mean(leading ** 2))
    return bool(leading_rms > threshold * 1.5)
def _logprob_to_confidence(avg_logprob: float) -> float:
    """
    Convert faster-whisper's avg_logprob to a 0-1 confidence score via exp().

    The log-probability is clamped to [-5.0, 0.0] before exponentiation, so
    the result always lies in [exp(-5), 1.0]: very poor decodes bottom out at
    a small positive floor and a stray positive value cannot push the score
    above 1. NaN maps to 0.0 so the score stays JSON-serializable.
    """
    if math.isnan(avg_logprob):
        return 0.0
    clamped = min(max(avg_logprob, -5.0), 0.0)
    return math.exp(clamped)
def _transcribe_pcm_buffer(pcm_bytes: bytes) -> str:
    """
    Run faster-whisper inference on a raw PCM buffer and return the text only.

    Only the most recent ``CHUNK_LENGTH_S`` seconds of the buffer are decoded.
    Segments whose ``no_speech_prob`` reaches ``NO_SPEECH_PROB_DROP_THRESHOLD``
    are discarded, and the remaining non-empty segment texts are joined with
    single spaces.
    """
    audio_array = _pcm_bytes_to_float32(pcm_bytes)

    # Limit to last 30 seconds (Whisper's context window).
    max_samples = CHUNK_LENGTH_S * SAMPLE_RATE
    if len(audio_array) > max_samples:
        audio_array = audio_array[-max_samples:]

    segments, _ = model.transcribe(audio_array, **WHISPER_OPTS)
    kept_texts = (
        seg.text.strip()
        for seg in segments
        if seg.no_speech_prob < NO_SPEECH_PROB_DROP_THRESHOLD
    )
    return " ".join(text for text in kept_texts if text)
def _transcribe_pcm_buffer_with_confidence(pcm_bytes: bytes) -> tuple:
    """
    Run faster-whisper inference on a raw PCM buffer, return (text, confidence).

    The whole buffer is decoded, split into overlapping windows by
    ``_split_audio``. Within each window, segments whose ``no_speech_prob``
    reaches ``NO_SPEECH_PROB_DROP_THRESHOLD`` are dropped.

    Returns:
        ``(transcription, confidence)``: the non-empty window texts joined by
        single spaces, and the mean per-window confidence rounded to 4
        decimals (0.0 when there are no windows). A window with no kept
        segments contributes a confidence of 1.0.
    """
    audio_array = _pcm_bytes_to_float32(pcm_bytes)

    window_texts = []
    window_scores = []
    for window in _split_audio(audio_array):
        segments, _ = model.transcribe(window, **WHISPER_OPTS)
        kept = [seg for seg in segments if seg.no_speech_prob < NO_SPEECH_PROB_DROP_THRESHOLD]

        stripped = (seg.text.strip() for seg in kept)
        window_texts.append(" ".join(text for text in stripped if text))

        if kept:
            mean_logprob = sum(seg.avg_logprob for seg in kept) / len(kept)
            window_scores.append(_logprob_to_confidence(mean_logprob))
        else:
            # Nothing survived the no-speech filter for this window.
            window_scores.append(1.0)

    transcription = " ".join(text for text in window_texts if text)
    confidence = round(sum(window_scores) / len(window_scores), 4) if window_scores else 0.0
    return transcription, confidence
def _split_audio(audio_array, sr=SAMPLE_RATE, chunk_s=CHUNK_LENGTH_S, overlap_s=OVERLAP_S):
    """
    Split audio into consecutive windows that overlap by ``overlap_s`` seconds.

    Args:
        audio_array: Sliceable sequence of samples.
        sr: Sample rate in Hz.
        chunk_s: Nominal window length in seconds.
        overlap_s: Overlap between neighbouring windows in seconds; must be
            smaller than ``chunk_s``.

    Returns:
        List of slices of ``audio_array``. Each window is ``chunk_s`` long,
        except the last: when fewer than 2 seconds would remain for one more
        window, that tail is folded into the current window, so the last
        window may be shorter or slightly longer than ``chunk_s``. Empty
        input yields an empty list.

    Raises:
        ValueError: If the step between windows would not be positive.
    """
    chunk_len = int(chunk_s * sr)
    step_len = int((chunk_s - overlap_s) * sr)
    if step_len <= 0:
        # A non-positive step never advances and would loop forever.
        raise ValueError("overlap_s must be smaller than chunk_s")

    total = len(audio_array)
    chunks = []
    start = 0
    while start < total:
        next_start = start + step_len
        remaining = total - next_start
        if 0 < remaining < 2 * sr:
            # Too little audio left for a window of its own: extend this
            # window to the end of the signal and stop.
            chunks.append(audio_array[start:])
            break
        chunks.append(audio_array[start:start + chunk_len])
        start = next_start
    return chunks
def _transcribe_file(audio_path: str) -> dict:
    """
    Transcribe an audio file from disk.

    The file is loaded and resampled to ``SAMPLE_RATE`` with librosa, split
    into overlapping windows by ``_split_audio`` and decoded window by window.
    Segments whose ``no_speech_prob`` reaches ``NO_SPEECH_PROB_DROP_THRESHOLD``
    are dropped.

    Returns:
        Dict with:
          - ``transcription``: non-empty window texts joined by single spaces;
          - ``confidence_score``: mean per-window confidence rounded to 4
            decimals (0.0 when there are no windows); a window with no kept
            segments contributes 1.0;
          - ``processing_time_ms``: wall-clock time including file loading.
    """
    import librosa  # Imported lazily: only the batch path needs it.

    t_start = time.time()
    audio_array, _ = librosa.load(audio_path, sr=SAMPLE_RATE)

    window_texts = []
    window_scores = []
    for window in _split_audio(audio_array):
        segments, _ = model.transcribe(window, **WHISPER_OPTS)
        kept = [seg for seg in segments if seg.no_speech_prob < NO_SPEECH_PROB_DROP_THRESHOLD]

        stripped = (seg.text.strip() for seg in kept)
        window_texts.append(" ".join(text for text in stripped if text))

        if kept:
            mean_logprob = sum(seg.avg_logprob for seg in kept) / len(kept)
            window_scores.append(_logprob_to_confidence(mean_logprob))
        else:
            # Nothing survived the no-speech filter for this window.
            window_scores.append(1.0)

    return {
        # Skip windows that produced no text so silent stretches do not leave
        # stray double spaces (same rule as the streaming final result).
        "transcription": " ".join(text for text in window_texts if text),
        "confidence_score": round(sum(window_scores) / len(window_scores), 4) if window_scores else 0.0,
        "processing_time_ms": int((time.time() - t_start) * 1000),
    }