# ActualSTT / app.py — streaming Whisper speech-to-text service (Hugging Face Space)
import asyncio, json, os, time
from typing import Optional, List
import numpy as np
from faster_whisper import WhisperModel
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
import httpx
import uvicorn
# =======================
# Config (env overrides)
# =======================
# Whisper model/runtime selection (CTranslate2 backend via faster-whisper).
MODEL_NAME = os.environ.get("WHISPER_MODEL", "Systran/faster-whisper-tiny.en")
DEVICE = os.environ.get("WHISPER_DEVICE", "cpu")
COMPUTE_TYPE = os.environ.get("WHISPER_COMPUTE", "int8")  # int8 keeps CPU inference cheap
# Audio stream parameters: clients must send PCM16 at SAMPLE_RATE.
SAMPLE_RATE = int(os.environ.get("STT_SAMPLE_RATE", "16000"))
CHUNK_MS = int(os.environ.get("STT_CHUNK_MS", "20"))  # expected client frame size in ms
FINAL_SILENCE_SEC = float(os.environ.get("FINAL_SILENCE_SEC", "0.8"))  # silence gap that finalizes an utterance
MAX_BUFFER_SEC = float(os.environ.get("MAX_BUFFER_SEC", "30.0"))  # rolling audio buffer cap
# Optional: notify Brain on finals
BRAIN_URL = (os.environ.get("BRAIN_PROCESS_URL") or "").strip() # e.g. https://Percy3822-Brain_v2.hf.space/process
BRAIN_SECRET = (os.environ.get("BRAIN_SHARED_SECRET") or "").strip()
_NOTIFY = bool(BRAIN_URL)  # notifications are enabled only when a URL is configured
# =======================
# HTTP client (fixed timeout)
# =======================
# Single long-lived AsyncClient shared by all Brain notifications; per-phase
# timeouts keep a slow/unreachable Brain endpoint from wedging notify tasks.
_client = httpx.AsyncClient(
    timeout=httpx.Timeout(connect=8.0, read=20.0, write=8.0, pool=8.0)
)
# =======================
# Utils
# =======================
def pcm16_to_float32(b: bytes) -> np.ndarray:
    """Decode raw PCM16 bytes into float32 samples scaled into [-1.0, 1.0)."""
    samples = np.frombuffer(b, dtype=np.int16).astype(np.float32)
    return samples / 32768.0
def float32_concat(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Concatenate two float32 arrays, skipping the copy when one is empty."""
    if a.size and b.size:
        return np.concatenate([a, b])
    return b if a.size == 0 else a
# =======================
# Tiny speaker embedder (placeholder)
# =======================
class SpeakerEmbedder:
    """Toy speaker-embedding stub producing a 2-D (mean, std) vector.

    Placeholder for a real embedding model; currently only tracks a single
    enrolled embedding. No actual model is loaded.
    """

    # BUGFIX: the constructor was spelled `_init_` (single underscores), so
    # Python never called it and fresh instances lacked model_name/device/
    # enrolled attributes. Correct dunder spelling is `__init__`.
    def __init__(self, model_name: Optional[str] = None, device: str = "cpu"):
        self.model_name = model_name or "toy"  # label only; nothing is loaded
        self.device = device
        self.enrolled: Optional[np.ndarray] = None  # most recent enrolled vector

    def embed(self, audio_f32: np.ndarray, sr: int) -> Optional[np.ndarray]:
        """Return a (2,) float32 [mean, std] vector, or None for empty audio."""
        if audio_f32 is None or audio_f32.size == 0:
            return None
        v = audio_f32.astype(np.float32)
        return np.array([float(v.mean()), float(v.std())], dtype=np.float32)

    def enroll(self, audio_f32: np.ndarray, sr: int) -> bool:
        """Embed `audio_f32` and store it as the enrolled speaker.

        Returns False (and leaves state unchanged) when the audio is empty.
        """
        e = self.embed(audio_f32, sr)
        if e is None:
            return False
        self.enrolled = e
        return True


# Module-level singleton (not yet wired into the WS path).
_embedder: Optional[SpeakerEmbedder] = SpeakerEmbedder()
# =======================
# Whisper model
# =======================
# Loaded once at import time so every websocket connection shares a single
# model instance (hence workers=1 in the entrypoint).
print(f"[STT] Loading {MODEL_NAME} on {DEVICE} ({COMPUTE_TYPE})", flush=True)
model = WhisperModel(MODEL_NAME, device=DEVICE, compute_type=COMPUTE_TYPE)
# =======================
# FastAPI
# =======================
# Application object; routes are registered below via decorators.
app = FastAPI(title="ActualSTT (Whisper)")
@app.get("/")
def root():
    """Landing route: points callers at the websocket endpoint."""
    payload = {
        "ok": True,
        "tip": "Use WS /ws/stt; send {'event':'init','rate':16000} then PCM16 frames",
    }
    return JSONResponse(payload)
@app.get("/health")
def health():
    """Liveness/config probe exposing engine, model, and stream settings."""
    stream_cfg = {
        "sample_rate": SAMPLE_RATE,
        "chunk_ms": CHUNK_MS,
        "final_silence_sec": FINAL_SILENCE_SEC,
        "max_buffer_sec": MAX_BUFFER_SEC,
    }
    return {
        "ok": True,
        "engine": "faster-whisper",
        "model": MODEL_NAME,
        "device": DEVICE,
        "compute": COMPUTE_TYPE,
        "config": stream_cfg,
        "tip": "Use WS /ws/stt; send init then raw PCM16 LE frames",
    }
# =======================
# Transcribe helpers
# =======================
def transcribe_block(audio_f32: np.ndarray, language: str, prompt: Optional[str]) -> str:
    """Synchronously decode one float32 audio buffer and return the joined text.

    Greedy decode (beam_size=1), no VAD, temperature 0 — tuned for
    low-latency interim results rather than best accuracy. Blocking CPU
    work; callers on the event loop should offload it.
    """
    segments, _info = model.transcribe(
        audio_f32,
        language=(language or "en"),
        task="transcribe",
        beam_size=1,
        vad_filter=False,
        temperature=0.0,
        no_speech_threshold=0.3,
        initial_prompt=(prompt or None),
    )
    pieces = [seg.text.strip() for seg in segments if seg.text]
    return " ".join(pieces).strip()
async def _notify_brain(text: str):
    """POST a final transcript to the Brain service (best effort, errors logged)."""
    body = text.strip()
    if not (_NOTIFY and body):
        return  # nothing configured or nothing to send
    headers = {"x-auth": BRAIN_SECRET} if BRAIN_SECRET else {}
    try:
        await _client.post(BRAIN_URL, json={"text": body}, headers=headers)
    except Exception as e:
        print(f"[STT->Brain] notify failed: {e}", flush=True)
# Helper: treat Starlette's "disconnect" RuntimeError as a clean disconnect
async def safe_receive(ws: WebSocket):
    """Await ws.receive(), normalizing Starlette's post-disconnect RuntimeError
    into WebSocketDisconnect so callers handle one exception type."""
    try:
        return await ws.receive()
    except WebSocketDisconnect:
        raise
    except RuntimeError as err:
        if "disconnect" not in str(err).lower():
            raise
        raise WebSocketDisconnect()
# =======================
# WebSocket
# =======================
@app.websocket("/ws/stt")
async def ws_stt(ws: WebSocket):
    """Streaming STT endpoint.

    Protocol:
      1. Client sends a JSON text frame: {"event":"init","rate":16000,
         "language":"en","prompt":"..."} (language/prompt optional).
      2. Server replies {"event":"ready","sr":16000}.
      3. Client streams raw little-endian PCM16 binary frames.
      4. Server pushes {"event":"interim","text":...} while audio flows and
         {"event":"final","text":...} after FINAL_SILENCE_SEC of silence.
    """
    await ws.accept()

    # ---- Parse the mandatory init message --------------------------------
    try:
        init_msg = await ws.receive_text()
    except WebSocketDisconnect:
        return
    except Exception:
        await ws.close(code=1002)
        return
    try:
        init = json.loads(init_msg)
        # Explicit check instead of `assert`: asserts vanish under -O.
        if init.get("event") != "init":
            raise ValueError("first message must be {'event':'init','rate':16000}")
        client_sr = int(init.get("rate", SAMPLE_RATE))
        language = (init.get("language") or "en").strip()
        prompt = (init.get("prompt") or "").strip() or None
        if client_sr != SAMPLE_RATE:
            await ws.send_json({"event": "error", "detail": f"Expected rate {SAMPLE_RATE}, got {client_sr}"})
            await ws.close(code=1002)
            return
    except Exception as e:
        try:
            await ws.send_json({"event": "error", "detail": f"Bad init: {e}"})
            await ws.close(code=1002)
        except Exception:
            pass  # client may already be gone
        return

    # BUGFIX: the original wrapped payloads in json.dumps() before
    # ws.send_json(), double-encoding them so clients received a JSON
    # *string* instead of an object (here and for interim/final below).
    await ws.send_json({"event": "ready", "sr": SAMPLE_RATE})

    # ---- Per-connection state --------------------------------------------
    audio_f32 = np.zeros((0,), dtype=np.float32)  # rolling decoded audio
    pcm_pending = bytearray()                     # bytes awaiting int16 alignment
    last_emit_text = ""                           # last interim text decoded
    last_audio_ts = time.time()                   # wall clock of last audio
    closed = False                                # set once the stream ends

    async def producer():
        # Watches for FINAL_SILENCE_SEC of silence and promotes the last
        # interim text to a "final" event (plus optional Brain notify).
        nonlocal last_emit_text, last_audio_ts
        try:
            while not closed:  # BUGFIX: the original looped forever after disconnect
                if last_emit_text and time.time() - last_audio_ts >= FINAL_SILENCE_SEC:
                    try:
                        await ws.send_json({"event": "final", "text": last_emit_text})
                    except Exception:
                        break  # socket is gone; stop instead of spinning
                    # Fire-and-forget; errors are logged inside the helper.
                    asyncio.create_task(_notify_brain(last_emit_text))
                    last_emit_text = ""
                await asyncio.sleep(0.05)
        except WebSocketDisconnect:
            pass
        except Exception:
            pass

    async def consumer():
        nonlocal audio_f32, last_emit_text, last_audio_ts, closed
        max_len = int(MAX_BUFFER_SEC * SAMPLE_RATE)  # buffer cap in samples
        loop = asyncio.get_running_loop()
        try:
            while True:
                msg = await safe_receive(ws)
                if msg.get("bytes") is not None:
                    # Accumulate raw bytes and consume every complete int16
                    # sample. BUGFIX: the original only consumed exact
                    # FRAME_BYTES-sized slices and silently dropped any
                    # trailing partial frame of each message.
                    pcm_pending.extend(msg["bytes"])
                    usable = len(pcm_pending) - (len(pcm_pending) % 2)
                    if usable:
                        audio_f32 = float32_concat(
                            audio_f32, pcm16_to_float32(bytes(pcm_pending[:usable]))
                        )
                        del pcm_pending[:usable]
                        last_audio_ts = time.time()
                    # Cap buffer to avoid runaway memory.
                    if audio_f32.size > max_len:
                        audio_f32 = audio_f32[-max_len:]
                    # Decode a ~1 s tail for INTERIM results once enough audio
                    # exists. BUGFIX: run the blocking Whisper decode in a
                    # worker thread so it doesn't stall the event loop.
                    tail = audio_f32[-SAMPLE_RATE:]
                    if tail.size >= int(0.8 * SAMPLE_RATE):
                        try:
                            txt = await loop.run_in_executor(
                                None, transcribe_block, tail, language, prompt
                            )
                        except Exception:
                            txt = ""  # swallow interim decode errors
                        if txt and txt != last_emit_text:
                            # Only send the new suffix when the text grows
                            # monotonically; otherwise resend the whole text.
                            new_part = (
                                txt[len(last_emit_text):].strip()
                                if txt.startswith(last_emit_text)
                                else txt
                            )
                            if new_part:
                                try:
                                    await ws.send_json({"event": "interim", "text": new_part})
                                except Exception:
                                    pass
                            last_emit_text = txt
                else:
                    break  # text frame or disconnect sentinel: end of stream
        except WebSocketDisconnect:
            pass
        except Exception:
            pass
        finally:
            closed = True  # lets producer() exit cleanly
            try:
                await ws.close()
            except Exception:
                pass

    await asyncio.gather(producer(), consumer())
# ================
# Entrypoint
# ================
if __name__ == "__main__":
    # Hugging Face Spaces supply the port via $PORT (default 7860); a single
    # worker ensures the Whisper model is loaded exactly once.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")), workers=1)