webapp / server.py
jenkins1122's picture
Upload server.py with huggingface_hub
2113bcf verified
Raw
History Blame Contribute Delete
6.25 kB
"""
Babelbit SN59 miner server.
Implements the validator's S2S audio protocol on a single endpoint
(default POST /v1/predict). The validator sends:
1) one init request (kind="init") -> reply ready=true + session_id (echo audio params)
2) many predict requests (kind="predict") streaming 80 ms input frames
-> reply target-language (English) audio frames; set out_eos=true when done
3) possibly a few drain predict requests (empty audio, in_eos=true) until out_eos
Audio wire format (from babelbit.utils.predict_audio):
- input frames: float32 little-endian PCM, 24 kHz, mono, 1920 samples/frame (7680 bytes)
- your output: float32 little-endian PCM, 24 kHz (validator down-converts to int16 WAV)
Run:
pip install -r requirements.txt
uvicorn server:app --host 0.0.0.0 --port 8091
# endpoint then lives at POST http://<ip>:8091/v1/predict
"""
from __future__ import annotations
import base64
import os
import uuid
from typing import Any, Dict
import numpy as np
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from model import load_model, TARGET_SAMPLE_RATE_HZ
# Model identifier, read from BB_MINER_MODEL (default "echo"). It is handed to
# load_model() below and reported back to callers as "miner_id" / "model".
MODEL_NAME = os.getenv("BB_MINER_MODEL", "echo")
# Endpoint path from BB_MINER_PREDICT_ENDPOINT (default "v1/predict"); any leading
# slashes are stripped and exactly one is prepended.
PREDICT_PATH = "/" + os.getenv("BB_MINER_PREDICT_ENDPOINT", "v1/predict").lstrip("/")
app = FastAPI(title="babelbit-sn59-miner")
# Loaded once, at import time, and shared by the handlers in this module.
_model = load_model(MODEL_NAME)
# In-memory session table: an entry is added by an init request and removed once
# the model reports that the session is done.
_sessions: Dict[str, Any] = {} # session_id -> model state
def _f32_from_b64(audio_b64: str) -> np.ndarray:
    """Decode base64-encoded little-endian float32 PCM into a sample array.

    A falsy payload (empty string, ``None``) yields a zero-length float32 array.
    """
    if audio_b64:
        payload = base64.b64decode(audio_b64)
        samples = np.frombuffer(payload, dtype="<f4")
        # copy=False: no copy is made when "<f4" already is the float32 layout.
        return samples.astype(np.float32, copy=False)
    return np.zeros(0, np.float32)
def _b64_from_f32(pcm: np.ndarray) -> tuple[str, int]:
    """Encode PCM samples as base64 little-endian float32.

    Returns ``(audio_b64, n_bytes)`` where ``n_bytes`` is the size of the raw
    payload before base64 encoding. ``None`` or an empty array gives ``("", 0)``.
    """
    has_audio = pcm is not None and pcm.size != 0
    if not has_audio:
        return "", 0
    little_endian = pcm.astype("<f4", copy=False)
    payload = np.ascontiguousarray(little_endian).tobytes()
    encoded = base64.b64encode(payload).decode("ascii")
    return encoded, len(payload)
# Debug switch: any BB_CAPTURE value other than a "false-like" one turns on
# recording of each session's incoming audio. The value is stripped and
# lower-cased first so that "False", "NO" or " 0 " disable capture as intended
# instead of silently enabling it.
CAPTURE = os.getenv("BB_CAPTURE", "0").strip().lower() not in {"0", "", "false", "no"}
# session_id -> {"meta": copy of the init request body, "frames": list of input PCM arrays}
_capture: Dict[str, Dict[str, Any]] = {}
def _handle_init(body: Dict[str, Any]) -> JSONResponse:
    """Open a new model session and echo the validator's audio parameters.

    Args:
        body: Parsed JSON object of a ``kind="init"`` request.

    Returns:
        A ``ready=True`` response carrying the new ``session_id`` plus the
        echoed request parameters, or a 400 response when a required field is
        missing or cannot be converted to the expected numeric type.
    """
    # Read and coerce every required field BEFORE creating the session, so a
    # malformed request cannot leave an orphaned entry behind in _sessions.
    try:
        echoed = {
            "challenge_uid": body["challenge_uid"],
            "utterance_id": body["utterance_id"],
            "language": body.get("language"),
            "sample_rate_hz": int(body["sample_rate_hz"]),
            "frame_rate_hz": float(body["frame_rate_hz"]),
            "frame_samples": int(body["frame_samples"]),
            "dtype": body["dtype"],
            "channels": int(body["channels"]),
        }
    except KeyError as exc:
        return JSONResponse(
            {"error": f"missing init field {exc.args[0]!r}"}, status_code=400
        )
    except (TypeError, ValueError) as exc:
        return JSONResponse({"error": f"invalid init field: {exc}"}, status_code=400)

    session_id = uuid.uuid4().hex
    _sessions[session_id] = _model.start_session(
        language=echoed["language"],
        sample_rate_hz=echoed["sample_rate_hz"],
        channels=echoed["channels"],
    )
    if CAPTURE:
        # The validator just told us the REAL source language + audio params.
        print(f"[CAPTURE] init challenge={body.get('challenge_uid')} utterance={body.get('utterance_id')} "
              f"SOURCE_language={body.get('language')} sr={body.get('sample_rate_hz')} "
              f"frame_rate={body.get('frame_rate_hz')} dtype={body.get('dtype')}", flush=True)
        _capture[session_id] = {"meta": dict(body), "frames": []}
    # Echo the audio params back EXACTLY — the validator rejects mismatches
    # (sample_rate_hz, channels, dtype must match; frame_samples/frame_rate_hz > 0).
    return JSONResponse(
        {
            "ready": True,
            "miner_id": MODEL_NAME,
            "session_id": session_id,
            **echoed,
        }
    )
def _save_capture(session_id: str, cap: Dict[str, Any]) -> None:
    """Write one captured utterance to a 16-bit mono WAV under /tmp and log it."""
    import wave

    frames = cap["frames"]
    sample_rate = int(cap["meta"].get("sample_rate_hz", 24000))
    pcm = np.concatenate(frames) if frames else np.zeros(1, np.float32)
    clipped = np.clip(pcm, -1, 1)
    # Asymmetric scaling maps [-1, 1] onto the full int16 range [-32768, 32767].
    i16 = np.where(clipped < 0, clipped * 32768, clipped * 32767).astype("<i2")
    path = f"/tmp/bb_capture_{session_id[:8]}.wav"
    with wave.open(path, "wb") as w:
        w.setnchannels(1)
        w.setsampwidth(2)
        w.setframerate(sample_rate)
        w.writeframes(i16.tobytes())
    # Duration uses the same rate the WAV header was written with, not a
    # hard-coded 24000, so the log stays correct for other sample rates.
    print(f"[CAPTURE] saved source audio -> {path} ({pcm.size / sample_rate:.2f}s) "
          f"lang={cap['meta'].get('language')}", flush=True)


def _handle_predict(body: Dict[str, Any]) -> JSONResponse:
    """Feed one input frame to the session's model and return its output audio.

    Args:
        body: Parsed JSON object of a ``kind="predict"`` request. Reads
            ``session_id``, ``audio_b64`` (base64 float32 PCM, may be empty)
            and ``in_eos``.

    Returns:
        A response with ``session_id``, ``audio_b64``, ``out_eos`` and
        ``n_bytes``; or a 400 response when the session is unknown or the
        audio payload cannot be decoded.
    """
    session_id = str(body.get("session_id") or "")
    state = _sessions.get(session_id)
    if state is None:
        return JSONResponse({"error": "unknown session_id"}, status_code=400)
    try:
        in_pcm = _f32_from_b64(body.get("audio_b64", ""))
    except (TypeError, ValueError) as exc:
        # Bad base64 (binascii.Error is a ValueError) or a byte count that is
        # not a whole number of float32 samples: the client's fault, not ours.
        return JSONResponse({"error": f"invalid audio_b64: {exc}"}, status_code=400)
    is_final = bool(body.get("in_eos", False))
    if CAPTURE and session_id in _capture:
        cap = _capture[session_id]
        if in_pcm.size:
            cap["frames"].append(in_pcm)
        if is_final:
            # Drop the buffer first so it is released even if the dump fails.
            _capture.pop(session_id, None)
            try:
                _save_capture(session_id, cap)
            except Exception as exc:
                # Capture is a debugging aid; it must never fail a real request.
                print(f"[CAPTURE] failed to save source audio: {exc}", flush=True)
    out_pcm, done = _model.push(state, in_pcm, is_final)
    audio_b64, n_bytes = _b64_from_f32(out_pcm)
    if done:
        # The model has emitted everything; forget the session state.
        _sessions.pop(session_id, None)
    return JSONResponse(
        {
            "session_id": session_id,
            "audio_b64": audio_b64,
            "out_eos": bool(done),
            "n_bytes": n_bytes,
        }
    )
@app.on_event("startup")
def _warm_on_startup():
    """Open one throwaway model session at boot to warm the model up.

    Intended to keep validators from hitting a cold load (~25s per the original
    note) mid-query. Any failure is printed and swallowed so startup continues.
    """
    warmup_args = {"language": "fr", "sample_rate_hz": 24000, "channels": 1}
    try:
        _model.start_session(**warmup_args)
        print("[startup] model warmed:", MODEL_NAME, flush=True)
    except Exception as exc:
        print("[startup] warmup error:", exc, flush=True)
@app.get("/healthz")
@app.get("/health")
def health():
    """Health probe: report status, the configured model name and TARGET_SAMPLE_RATE_HZ."""
    payload = {
        "status": "ok",
        "model": MODEL_NAME,
        "target_sample_rate_hz": TARGET_SAMPLE_RATE_HZ,
    }
    return payload
@app.post("/v1/predict")  # canonical (qualifying + arena)
@app.post("/predict")  # alias (arena managed contract)
async def predict(request: Request):
    """Single protocol endpoint: dispatch on the request's ``kind`` field.

    Args:
        request: Incoming HTTP request whose body is a JSON object with a
            ``kind`` of ``"init"`` or ``"predict"`` (case-insensitive).

    Returns:
        The response of the matching handler, or a 400 response when the body
        is not valid JSON, is not a JSON object, or has an unknown ``kind``.
    """
    try:
        body = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        return JSONResponse({"error": "request body is not valid JSON"}, status_code=400)
    if not isinstance(body, dict):
        # A JSON array/scalar has no .get(); reject it instead of crashing.
        return JSONResponse({"error": "request body must be a JSON object"}, status_code=400)
    kind = str(body.get("kind", "")).lower()
    # NOTE: validator signs requests with bt_header_dendrite_* headers. For a
    # production miner, verify the signature here and check the hotkey is a
    # registered validator before serving. Left open for the scaffold.
    if kind == "init":
        return _handle_init(body)
    if kind == "predict":
        return _handle_predict(body)
    return JSONResponse({"error": f"unknown kind '{kind}'"}, status_code=400)


# Honour BB_MINER_PREDICT_ENDPOINT: also serve the configured path when it is
# not one of the two fixed routes above (skipped otherwise to avoid a duplicate).
if PREDICT_PATH not in {"/v1/predict", "/predict"}:
    app.post(PREDICT_PATH)(predict)