# benhadjermed's picture
# Reduce Whisper hallucinations: condition_on_previous_text=False, temperature fallback, light VAD, no_speech filter
# cd96775 verified
#!/usr/bin/env python3
"""
Tahkik Inference Server — Hugging Face Space entry point.
Loads the Whisper model ONCE at startup via faster-whisper (CTranslate2),
then serves:
- POST /evaluate — batch transcription (upload a full audio file)
- WS /ws/stream — real-time streaming transcription (send PCM chunks)
"""
import asyncio
import json
import math
import os
import time
import tempfile
# Redirect model caches to /tmp (only writable dir in HF Spaces)
os.environ.setdefault("HF_HOME", "/tmp/huggingface_cache")
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
os.environ.setdefault("CT2_VERBOSE", "0")
import numpy as np
from fastapi import FastAPI, File, UploadFile, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
from faster_whisper import WhisperModel
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Model repository id and audio framing parameters.
TAHKIK_MODEL = "benhadjermed/tahkik-small-warsh-ct2"
SAMPLE_RATE = 16000
CHUNK_LENGTH_S = 30
OVERLAP_S = 1

# Partial inference is skipped until at least this much audio is buffered
# (short inputs are where hallucinations show up most).
MIN_AUDIO_FOR_INFERENCE_S = 1.0
MIN_SAMPLES_FOR_INFERENCE = int(MIN_AUDIO_FOR_INFERENCE_S * SAMPLE_RATE)

# Silence detection: RMS level treated as silence, and how much trailing
# silence (in seconds / samples) triggers finalization of the buffer.
SILENCE_THRESHOLD = 0.02
SILENCE_DURATION_S = 0.8
SILENCE_SAMPLES = int(SILENCE_DURATION_S * SAMPLE_RATE)

# Keyword arguments passed to every faster-whisper ``transcribe`` call.
# These are the standard anti-hallucination knobs — see openai/whisper#679.
WHISPER_OPTS = {
    "language": "ar",
    "task": "transcribe",
    # Lightweight VAD: removes long silent stretches the model would
    # otherwise hallucinate into, while leaving word endings intact.
    "vad_filter": True,
    "vad_parameters": {"min_silence_duration_ms": 300, "threshold": 0.35},
    # Temperature-fallback chain: a decode that fails the compression-ratio
    # or log-prob check is retried at the next temperature, and dropped if
    # it is still bad.
    "temperature": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
    "compression_ratio_threshold": 2.4,
    "log_prob_threshold": -1.0,
    "no_speech_threshold": 0.6,
    # Every window is decoded from scratch — this kills loop hallucinations.
    "condition_on_previous_text": False,
}

# Segments whose no_speech_prob reaches this value are discarded.
NO_SPEECH_PROB_DROP_THRESHOLD = 0.7

# Upload extensions accepted by POST /evaluate.
ALLOWED_EXTS = {".wav", ".m4a", ".mp3", ".flac", ".ogg"}
# ---------------------------------------------------------------------------
# Model loading (happens once at module import / server startup)
# ---------------------------------------------------------------------------
# The model is instantiated exactly once, while this module is imported,
# and the resulting object is shared by every request handler below.
print("[inference] loading faster-whisper model...", flush=True)
model = WhisperModel(
    model_size_or_path=TAHKIK_MODEL,
    device="cpu",
    compute_type="int8",
    download_root="/tmp/huggingface_cache",
)
print("[inference] model ready", flush=True)

# Process-wide lock: handlers acquire it around model calls so that only one
# inference runs at a time (avoids resource contention).
_inference_lock = asyncio.Lock()
# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------
app = FastAPI(title="Tahkik Inference API")


@app.get("/health")
def health():
    """Liveness probe: always answers with a fixed ``status: ok`` payload."""
    payload = {"status": "ok"}
    return payload
# ---------------------------------------------------------------------------
# POST /evaluate — batch transcription (backward compatible)
# ---------------------------------------------------------------------------
@app.post("/evaluate")
async def evaluate(audio: UploadFile = File(...)):
    """
    Transcribe one uploaded audio file and return the result as JSON.

    The upload is spooled to a temporary file under /tmp, transcribed with
    ``_transcribe_file`` and the temporary file is always removed afterwards.

    Raises:
        HTTPException(400): the file extension is not in ALLOWED_EXTS.
        HTTPException(500): writing the upload or transcribing it failed.
    """
    filename = audio.filename or "recording.wav"
    ext = os.path.splitext(filename)[1].lower() or ".wav"
    if ext not in ALLOWED_EXTS:
        raise HTTPException(status_code=400, detail=f"unsupported audio format: {ext}")

    data = await audio.read()

    # Create the file first so that tmp_path is known before anything can
    # fail; the finally-clause below then cleans up even a failed write.
    tmp = tempfile.NamedTemporaryFile(suffix=ext, delete=False, dir="/tmp")
    tmp_path = tmp.name
    try:
        with tmp:
            tmp.write(data)
        # Transcription is blocking: run it in the default
        # executor so the event loop (and every WebSocket session) stays
        # responsive, and hold the global lock so it never overlaps with a
        # streaming inference.
        async with _inference_lock:
            result = await asyncio.get_running_loop().run_in_executor(
                None, _transcribe_file, tmp_path
            )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    finally:
        os.unlink(tmp_path)
    return JSONResponse(result)
# ---------------------------------------------------------------------------
# WS /ws/stream — real-time streaming transcription
# ---------------------------------------------------------------------------
@app.websocket("/ws/stream")
async def stream_transcribe(ws: WebSocket):
    """
    Real-time streaming transcription over WebSocket.

    Protocol:
      Client → Server:
        - Binary frames: raw PCM 16-bit signed LE, 16 kHz, mono
        - Text frame: JSON {"type": "stop"} to signal end of recording
          (any other well-formed JSON message is ignored)
      Server → Client:
        - Text frames: JSON messages
          {"type": "partial", "text": "..."} — intermediate transcription
          {"type": "final", "text": "...", "confidence": 0.94, "processing_time_ms": 1234}
          {"type": "error", "message": "..."}

    The connection is closed after the final result has been sent, when the
    client disconnects, or after an unexpected error.
    """
    await ws.accept()
    print("[ws] client connected", flush=True)

    loop = asyncio.get_running_loop()
    # Raw PCM bytes received since the last finalization / drop.
    audio_buffer = bytearray()
    # Text of all utterances already finalized in this session.
    session_text = ""
    # Buffer size (bytes) at the last partial inference, to avoid redundant runs.
    last_inference_len = 0
    # Bumped whenever audio_buffer is finalized or dropped. A partial that was
    # started for an older value is stale and must not be sent.
    utterance_id = 0
    # Strong references to background tasks: the event loop itself only keeps
    # weak references, so an unreferenced task may be garbage-collected.
    partial_tasks = set()

    async def _run_partial(pcm_data: bytes, started_in: int):
        """Transcribe a buffer snapshot and send it as a "partial" message."""
        try:
            async with _inference_lock:
                if started_in != utterance_id:
                    return  # Buffer was finalized while we waited for the lock.
                text = await loop.run_in_executor(
                    None, _transcribe_pcm_buffer, pcm_data
                )
            if started_in != utterance_id:
                # This audio is already part of session_text (or was dropped);
                # sending it now would duplicate text on the client.
                return
            full_text = (session_text + " " + text).strip()
            try:
                await ws.send_json({"type": "partial", "text": full_text})
            except Exception:
                pass  # Connection likely closed
        except Exception:
            import traceback
            print(f"[ws] partial inference error:\n{traceback.format_exc()}", flush=True)

    try:
        while True:
            message = await ws.receive()

            # ws.receive() hands back the disconnect message instead of
            # raising WebSocketDisconnect; calling receive() again after it
            # would raise RuntimeError, so leave the loop here.
            if message.get("type") == "websocket.disconnect":
                print("[ws] client disconnected", flush=True)
                break

            # --- Binary frame: audio chunk --------------------------------
            frame = message.get("bytes")
            if frame is not None:
                audio_buffer.extend(frame)
                buffer_samples = len(audio_buffer) // 2  # 16-bit = 2 bytes/sample
                if buffer_samples < MIN_SAMPLES_FOR_INFERENCE:
                    continue

                # One immutable copy per frame, shared by every check below.
                snapshot = bytes(audio_buffer)

                if _has_trailing_silence(snapshot, SILENCE_THRESHOLD, SILENCE_SAMPLES):
                    print("[ws] auto-finalizing chunk due to silence", flush=True)
                    async with _inference_lock:
                        chunk_text = await loop.run_in_executor(
                            None, _transcribe_pcm_buffer, snapshot
                        )
                    session_text = (session_text + " " + chunk_text).strip()
                    # Invalidate pending partials before the next await.
                    utterance_id += 1
                    audio_buffer = bytearray()
                    last_inference_len = 0
                    try:
                        await ws.send_json({"type": "partial", "text": session_text})
                    except RuntimeError:
                        # Client closed connection while we were running inference
                        break
                    continue

                # Prevent OOM if mic is left on but user is entirely silent for 10s
                if buffer_samples > SAMPLE_RATE * 10:
                    audio_array = _pcm_bytes_to_float32(snapshot)
                    if np.sqrt(np.mean(audio_array ** 2)) < SILENCE_THRESHOLD * 2:
                        print("[ws] buffer full of purely silence, dropping...", flush=True)
                        utterance_id += 1
                        audio_buffer = bytearray()
                        last_inference_len = 0
                        continue

                # Run a partial only when half a second of new audio arrived
                # AND the lock is free; otherwise requests would pile up and
                # delay the final run.
                new_samples = buffer_samples - (last_inference_len // 2)
                if new_samples >= (SAMPLE_RATE // 2) and not _inference_lock.locked():
                    last_inference_len = len(audio_buffer)
                    # Run in background so ws.receive() is not blocked.
                    task = asyncio.create_task(_run_partial(snapshot, utterance_id))
                    partial_tasks.add(task)
                    task.add_done_callback(partial_tasks.discard)
                continue

            # --- Text frame: control message ------------------------------
            raw_text = message.get("text")
            if raw_text is None:
                continue
            try:
                msg = json.loads(raw_text)
            except json.JSONDecodeError:
                try:
                    await ws.send_json({"type": "error", "message": "invalid JSON"})
                except RuntimeError:
                    pass
                continue

            # Valid JSON that is not an object (e.g. a list or a number) has
            # no "type"; treat it like any other unknown message and ignore it.
            if not isinstance(msg, dict) or msg.get("type") != "stop":
                continue

            print(f"[ws] stop received, buffer size: {len(audio_buffer)} bytes", flush=True)
            if len(audio_buffer) // 2 < MIN_SAMPLES_FOR_INFERENCE:
                # Too little audio left to transcribe: report what we have.
                final_message = {
                    "type": "final",
                    "text": session_text,
                    "confidence": 1.0,
                    "processing_time_ms": 0,
                }
            else:
                t_start = time.time()
                async with _inference_lock:
                    text, confidence = await loop.run_in_executor(
                        None, _transcribe_pcm_buffer_with_confidence, bytes(audio_buffer)
                    )
                elapsed = int((time.time() - t_start) * 1000)
                final_message = {
                    "type": "final",
                    "text": (session_text + " " + text).strip(),
                    "confidence": confidence,
                    "processing_time_ms": elapsed,
                }

            # No partial may follow the final message.
            utterance_id += 1
            try:
                await ws.send_json(final_message)
            except RuntimeError:
                pass
            break  # Close after final result.
    except WebSocketDisconnect:
        print("[ws] client disconnected", flush=True)
    except Exception as exc:
        import traceback
        print(f"[ws] error:\n{traceback.format_exc()}", flush=True)
        try:
            await ws.send_json({"type": "error", "message": str(exc)})
        except Exception:
            pass
    finally:
        try:
            await ws.close()
        except Exception:
            pass
        print("[ws] connection closed", flush=True)
# ---------------------------------------------------------------------------
# Inference helpers
# ---------------------------------------------------------------------------
def _pcm_bytes_to_float32(pcm_bytes: bytes) -> np.ndarray:
    """
    Convert raw PCM 16-bit signed little-endian bytes to float32 in [-1, 1).

    A trailing odd byte (half of a sample, e.g. from a frame that was split
    mid-sample) is ignored instead of making ``np.frombuffer`` raise
    ``ValueError``. Accepts any bytes-like object.

    Returns:
        1-D float32 array with one element per complete 16-bit sample.
    """
    whole_bytes = len(pcm_bytes) - (len(pcm_bytes) % 2)
    # "<i2" pins the little-endian wire format regardless of host byte order.
    int16_array = np.frombuffer(memoryview(pcm_bytes)[:whole_bytes], dtype="<i2")
    return int16_array.astype(np.float32) / 32768.0
def _has_trailing_silence(pcm_bytes: bytes, threshold: float, duration_samples: int) -> bool:
    """
    Tell whether the buffer ends in silence that follows actual speech.

    True only when the RMS of the last ``duration_samples`` samples is below
    ``threshold`` AND the RMS of everything before them is above
    ``threshold * 1.5``. Buffers shorter than ``duration_samples`` samples
    (2 bytes each) always yield False.
    """
    if len(pcm_bytes) < duration_samples * 2:
        return False

    samples = _pcm_bytes_to_float32(pcm_bytes)
    head = samples[:-duration_samples]
    tail = samples[-duration_samples:]

    tail_rms = np.sqrt(np.mean(tail ** 2))
    if not tail_rms < threshold:
        return False

    # Silence only counts as "trailing" when something louder came before it.
    if len(head) == 0:
        return False
    head_rms = np.sqrt(np.mean(head ** 2))
    return bool(head_rms > threshold * 1.5)
def _logprob_to_confidence(avg_logprob: float) -> float:
    """
    Map an average log-probability to a confidence score via ``exp()``.

    Inputs below -5.0 are raised to -5.0 first, so the result never drops
    below ``exp(-5.0)`` (and ``-inf`` does not collapse to 0).
    """
    floor = -5.0
    clamped = floor if floor > avg_logprob else avg_logprob
    return math.exp(clamped)
def _transcribe_pcm_buffer(pcm_bytes: bytes) -> str:
    """
    Transcribe a raw PCM buffer and return only the text.

    Only the last CHUNK_LENGTH_S seconds are decoded (Whisper's context
    window). Segments that fail the no-speech filter and segments with empty
    text are left out; the rest are joined with single spaces.
    """
    samples = _pcm_bytes_to_float32(pcm_bytes)

    # Keep the tail of the buffer; a shorter buffer is returned unchanged.
    window = CHUNK_LENGTH_S * SAMPLE_RATE
    samples = samples[-window:]

    segments, _info = model.transcribe(samples, **WHISPER_OPTS)

    kept = []
    for segment in segments:
        if not segment.no_speech_prob < NO_SPEECH_PROB_DROP_THRESHOLD:
            continue
        text = segment.text.strip()
        if text:
            kept.append(text)
    return " ".join(kept)
def _transcribe_pcm_buffer_with_confidence(pcm_bytes: bytes) -> tuple:
    """
    Transcribe a raw PCM buffer and return ``(text, confidence)``.

    The audio is cut into chunks by ``_split_audio``; each chunk is decoded
    separately. A chunk's score is the confidence derived from the mean
    ``avg_logprob`` of its kept segments, or 1.0 when no segment was kept.
    The returned confidence is the mean chunk score rounded to 4 decimals
    (0.0 when there were no chunks at all).
    """
    samples = _pcm_bytes_to_float32(pcm_bytes)

    texts = []
    scores = []
    for chunk in _split_audio(samples):
        segments, _info = model.transcribe(chunk, **WHISPER_OPTS)
        # Discard everything the model itself flagged as likely non-speech.
        kept = [
            seg for seg in segments
            if not seg.no_speech_prob >= NO_SPEECH_PROB_DROP_THRESHOLD
        ]

        stripped = (seg.text.strip() for seg in kept)
        texts.append(" ".join(piece for piece in stripped if piece))

        if kept:
            mean_logprob = sum(seg.avg_logprob for seg in kept) / len(kept)
            scores.append(_logprob_to_confidence(mean_logprob))
        else:
            scores.append(1.0)

    transcription = " ".join(piece for piece in texts if piece)
    if scores:
        confidence = round(sum(scores) / len(scores), 4)
    else:
        confidence = 0.0
    return transcription, confidence
def _split_audio(audio_array, sr=SAMPLE_RATE, chunk_s=CHUNK_LENGTH_S, overlap_s=OVERLAP_S):
    """
    Cut ``audio_array`` into overlapping chunks of ``chunk_s`` seconds.

    Consecutive chunks start ``chunk_s - overlap_s`` seconds apart. If the
    audio left after the next start position would be shorter than 2
    seconds, the current chunk is extended to the end of the audio instead
    of producing a tiny final chunk.

    Args:
        audio_array: sliceable sequence of samples (e.g. a 1-D numpy array).
        sr: sample rate in Hz.
        chunk_s: chunk length in seconds.
        overlap_s: overlap between consecutive chunks in seconds.

    Returns:
        List of slices of ``audio_array``; empty when the input is empty.

    Raises:
        ValueError: if the chunk length or the step between chunks is not
            positive (previously a non-positive step looped forever).
    """
    chunk_len = int(chunk_s * sr)
    step_len = int((chunk_s - overlap_s) * sr)
    if chunk_len <= 0 or step_len <= 0:
        raise ValueError(
            "chunk_s must be positive and larger than overlap_s "
            f"(got chunk_s={chunk_s}, overlap_s={overlap_s}, sr={sr})"
        )

    total = len(audio_array)
    min_tail = 2 * sr
    chunks = []
    start = 0
    while start < total:
        next_start = start + step_len
        remaining = total - next_start
        if 0 < remaining < min_tail:
            # Short leftover: absorb it into this chunk and stop.
            chunks.append(audio_array[start:])
            break
        chunks.append(audio_array[start:start + chunk_len])
        start = next_start
    return chunks
def _transcribe_file(audio_path: str) -> dict:
    """
    Transcribe an audio file on disk, chunk by chunk.

    The file is loaded with librosa at SAMPLE_RATE, split by
    ``_split_audio`` and every chunk is decoded separately. Segments whose
    ``no_speech_prob`` reaches NO_SPEECH_PROB_DROP_THRESHOLD are discarded.

    Returns:
        dict with keys
        - "transcription": chunk texts joined by single spaces (chunks that
          produced no text are skipped, so no stray double spaces appear),
        - "confidence_score": mean per-chunk confidence rounded to 4
          decimals; a chunk without kept segments counts as 1.0, and the
          value is 0.0 when there were no chunks,
        - "processing_time_ms": wall-clock time of load + inference.
    """
    import librosa

    t_start = time.time()
    audio_array, _ = librosa.load(audio_path, sr=SAMPLE_RATE)

    all_texts = []
    all_scores = []
    for chunk in _split_audio(audio_array):
        segments, _ = model.transcribe(chunk, **WHISPER_OPTS)
        chunk_texts = []
        chunk_logprobs = []
        for seg in segments:
            if seg.no_speech_prob >= NO_SPEECH_PROB_DROP_THRESHOLD:
                continue
            chunk_texts.append(seg.text.strip())
            chunk_logprobs.append(seg.avg_logprob)
        all_texts.append(" ".join(t for t in chunk_texts if t))
        if chunk_logprobs:
            avg = sum(chunk_logprobs) / len(chunk_logprobs)
            all_scores.append(_logprob_to_confidence(avg))
        else:
            all_scores.append(1.0)

    return {
        # Skip empty chunk texts, matching the streaming code path.
        "transcription": " ".join(t for t in all_texts if t),
        "confidence_score": round(sum(all_scores) / len(all_scores), 4) if all_scores else 0.0,
        "processing_time_ms": int((time.time() - t_start) * 1000),
    }