File size: 6,249 Bytes
2113bcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""
Babelbit SN59 miner server.

Implements the validator's S2S audio protocol on a single endpoint
(default POST /v1/predict). The validator sends:

  1) one  init  request (kind="init")    -> reply ready=true + session_id (echo audio params)
  2) many predict requests (kind="predict") streaming 80 ms input frames
       -> reply target-language (English) audio frames; set out_eos=true when done
  3) possibly a few drain predict requests (empty audio, in_eos=true) until out_eos

Audio wire format (from babelbit.utils.predict_audio):
  - input frames:  float32 little-endian PCM, 24 kHz, mono, 1920 samples/frame (7680 bytes)
  - your output:   float32 little-endian PCM, 24 kHz (validator down-converts to int16 WAV)

Run:
  pip install -r requirements.txt
  uvicorn server:app --host 0.0.0.0 --port 8091
  # endpoint then lives at  POST http://<ip>:8091/v1/predict
"""
from __future__ import annotations

import base64
import os
import uuid
from typing import Any, Dict

import numpy as np
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

from model import load_model, TARGET_SAMPLE_RATE_HZ

# Name handed to load_model(); overridable through the BB_MINER_MODEL env var.
MODEL_NAME = os.environ.get("BB_MINER_MODEL", "echo")
# Predict route taken from BB_MINER_PREDICT_ENDPOINT, normalised so that it
# always starts with exactly one "/".
PREDICT_PATH = "/{}".format(
    os.environ.get("BB_MINER_PREDICT_ENDPOINT", "v1/predict").lstrip("/")
)

app = FastAPI(title="babelbit-sn59-miner")
# The model object is created once, at import time, and shared by all requests.
_model = load_model(MODEL_NAME)
# session_id -> model state, filled by init requests and emptied when the
# model reports that a session is done.
_sessions: Dict[str, Any] = dict()


def _f32_from_b64(audio_b64: str) -> np.ndarray:
    """Decode a base64 string of little-endian float32 PCM into a float32 array.

    A falsy payload (empty string or ``None``) yields an empty float32 array.
    Decoding errors from ``base64`` / ``numpy`` propagate to the caller.
    """
    if audio_b64:
        decoded = base64.b64decode(audio_b64)
        samples = np.frombuffer(decoded, dtype="<f4")
        # No copy is made when "<f4" already is the native float32 layout.
        return samples.astype(np.float32, copy=False)
    return np.zeros(0, np.float32)


def _b64_from_f32(pcm: np.ndarray) -> tuple[str, int]:
    """Encode PCM samples as base64 little-endian float32.

    Returns ``(audio_b64, n_bytes)`` where ``n_bytes`` is the size of the raw
    (pre-base64) payload. ``None`` or an empty array yields ``("", 0)``.
    """
    has_audio = pcm is not None and pcm.size != 0
    if not has_audio:
        return "", 0
    # Convert to the wire dtype first, then force a contiguous buffer so that
    # tobytes() serialises the samples in order.
    le_samples = pcm.astype("<f4", copy=False)
    payload = np.ascontiguousarray(le_samples).tobytes()
    encoded = base64.b64encode(payload).decode("ascii")
    return encoded, len(payload)


# Opt-in debug capture of incoming source audio, switched on via BB_CAPTURE.
# The value is trimmed and lower-cased before comparison so that spellings such
# as "False", "NO" or " off " disable capture instead of silently enabling it.
CAPTURE = os.getenv("BB_CAPTURE", "0").strip().lower() not in {"0", "", "false", "no", "off"}
_capture: Dict[str, Dict[str, Any]] = {}  # session_id -> {meta, frames}


def _handle_init(body: Dict[str, Any]) -> JSONResponse:
    """Open a model session for an ``init`` request and acknowledge it.

    Args:
        body: Parsed JSON body of the request. ``sample_rate_hz``, ``channels``,
            ``frame_rate_hz``, ``frame_samples``, ``dtype``, ``challenge_uid``
            and ``utterance_id`` are required; ``language`` is optional.

    Returns:
        A ``ready`` response carrying the new ``session_id`` and echoing the
        audio parameters, or a 400 response if a required field is missing or
        cannot be converted to its numeric type.
    """
    # Parse every required field BEFORE creating the session: otherwise a
    # malformed request would fail halfway through, surface as a 500 and leave
    # an orphaned entry in _sessions.
    try:
        sample_rate_hz = int(body["sample_rate_hz"])
        channels = int(body["channels"])
        frame_rate_hz = float(body["frame_rate_hz"])
        frame_samples = int(body["frame_samples"])
        dtype = body["dtype"]
        challenge_uid = body["challenge_uid"]
        utterance_id = body["utterance_id"]
    except (KeyError, TypeError, ValueError) as exc:
        return JSONResponse({"error": f"invalid init request: {exc!r}"}, status_code=400)
    language = body.get("language")

    session_id = uuid.uuid4().hex
    _sessions[session_id] = _model.start_session(
        language=language,
        sample_rate_hz=sample_rate_hz,
        channels=channels,
    )
    if CAPTURE:
        # The validator just told us the REAL source language + audio params.
        print(f"[CAPTURE] init challenge={body.get('challenge_uid')} utterance={body.get('utterance_id')} "
              f"SOURCE_language={body.get('language')} sr={body.get('sample_rate_hz')} "
              f"frame_rate={body.get('frame_rate_hz')} dtype={body.get('dtype')}", flush=True)
        _capture[session_id] = {"meta": dict(body), "frames": []}
    # Echo the audio params back EXACTLY — the validator rejects mismatches
    # (sample_rate_hz, channels, dtype must match; frame_samples/frame_rate_hz > 0).
    return JSONResponse(
        {
            "ready": True,
            "miner_id": MODEL_NAME,
            "session_id": session_id,
            "challenge_uid": challenge_uid,
            "utterance_id": utterance_id,
            "language": language,
            "sample_rate_hz": sample_rate_hz,
            "frame_rate_hz": frame_rate_hz,
            "frame_samples": frame_samples,
            "dtype": dtype,
            "channels": channels,
        }
    )


def _record_capture(session_id: str, in_pcm: np.ndarray, is_final: bool) -> None:
    """Buffer one input frame of a captured session; on end-of-input save a WAV."""
    cap = _capture.get(session_id)
    if cap is None:
        return
    if in_pcm.size:
        cap["frames"].append(in_pcm)
    if not is_final:
        return

    import wave as _wave

    # The capture entry is finished either way; drop it before any I/O.
    _capture.pop(session_id, None)
    pcm = np.concatenate(cap["frames"]) if cap["frames"] else np.zeros(1, np.float32)
    # float [-1, 1] -> int16, scaling negatives by 32768 and positives by 32767.
    i16 = np.clip(pcm, -1, 1)
    i16 = np.where(i16 < 0, i16 * 32768, i16 * 32767).astype("<i2")
    sample_rate_hz = int(cap["meta"].get("sample_rate_hz", 24000))
    path = f"/tmp/bb_capture_{session_id[:8]}.wav"
    try:
        with _wave.open(path, "wb") as w:
            w.setnchannels(1)
            w.setsampwidth(2)
            w.setframerate(sample_rate_hz)
            w.writeframes(i16.tobytes())
    except (OSError, _wave.Error) as exc:
        # Capture is a debugging aid: a failed write must not fail the request.
        print(f"[CAPTURE] failed to save source audio -> {path}: {exc}", flush=True)
    else:
        # Duration uses the same sample rate that was written to the WAV header.
        print(f"[CAPTURE] saved source audio -> {path}  ({pcm.size/sample_rate_hz:.2f}s)  "
              f"lang={cap['meta'].get('language')}", flush=True)


def _handle_predict(body: Dict[str, Any]) -> JSONResponse:
    """Feed one input frame to the session's model and return its output audio.

    Args:
        body: Parsed JSON body of the request. Reads ``session_id``,
            ``audio_b64`` (base64 float32 PCM, may be empty) and ``in_eos``.

    Returns:
        A response with ``session_id``, ``audio_b64``, ``out_eos`` and
        ``n_bytes``; or a 400 response when the session is unknown or the
        audio payload cannot be decoded.
    """
    session_id = str(body.get("session_id") or "")
    state = _sessions.get(session_id)
    if state is None:
        return JSONResponse({"error": "unknown session_id"}, status_code=400)

    try:
        in_pcm = _f32_from_b64(body.get("audio_b64", ""))
    except (ValueError, TypeError) as exc:
        # Bad base64, or a byte count that is not a whole number of float32s.
        return JSONResponse({"error": f"invalid audio_b64: {exc}"}, status_code=400)
    is_final = bool(body.get("in_eos", False))

    if CAPTURE:
        _record_capture(session_id, in_pcm, is_final)

    out_pcm, done = _model.push(state, in_pcm, is_final)
    audio_b64, n_bytes = _b64_from_f32(out_pcm)

    if done:
        _sessions.pop(session_id, None)
        # Once the session is gone no further frame can reach the capture
        # buffer, so release it too instead of leaking the buffered audio.
        _capture.pop(session_id, None)

    return JSONResponse(
        {
            "session_id": session_id,
            "audio_b64": audio_b64,
            "out_eos": bool(done),
            "n_bytes": n_bytes,
        }
    )


@app.on_event("startup")
def _warm_on_startup():
    """Start one throwaway model session at boot.

    The intent, per the original author's note, is that validators never hit a
    cold-load (~25s) mid-query. The returned session state is discarded, and
    any failure is only logged so that the server still starts.
    """
    warm_params = {"language": "fr", "sample_rate_hz": 24000, "channels": 1}
    try:
        _model.start_session(**warm_params)
        print("[startup] model warmed:", MODEL_NAME, flush=True)
    except Exception as warm_err:
        print("[startup] warmup error:", warm_err, flush=True)


@app.get("/healthz")
@app.get("/health")
def health():
    """Report liveness together with the model name and target sample rate."""
    payload = {
        "status": "ok",
        "model": MODEL_NAME,
        "target_sample_rate_hz": TARGET_SAMPLE_RATE_HZ,
    }
    return payload


@app.post("/v1/predict")  # canonical (qualifying + arena)
@app.post("/predict")     # alias (arena managed contract)
async def predict(request: Request):
    """Dispatch a validator request to the init or predict handler by ``kind``.

    Returns a 400 response when the body is not valid JSON, is not a JSON
    object, or carries an unknown ``kind``.
    """
    try:
        body = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        return JSONResponse({"error": "request body is not valid JSON"}, status_code=400)
    if not isinstance(body, dict):
        # A JSON array or scalar has no .get(); reject it instead of crashing.
        return JSONResponse({"error": "request body must be a JSON object"}, status_code=400)
    kind = str(body.get("kind", "")).lower()
    # NOTE: validator signs requests with bt_header_dendrite_* headers. For a
    # production miner, verify the signature here and check the hotkey is a
    # registered validator before serving. Left open for the scaffold.
    if kind == "init":
        return _handle_init(body)
    if kind == "predict":
        return _handle_predict(body)
    return JSONResponse({"error": f"unknown kind '{kind}'"}, status_code=400)


# Honour BB_MINER_PREDICT_ENDPOINT: serve the configured path as well, unless
# it is one of the two routes already registered above.
if PREDICT_PATH not in {"/v1/predict", "/predict"}:
    app.post(PREDICT_PATH)(predict)