Update app.py
Browse files
app.py
CHANGED
|
@@ -1,39 +1,34 @@
|
|
|
|
|
| 1 |
import os, json, time, asyncio, tempfile
|
| 2 |
-
from typing import AsyncGenerator, Dict, Any, Optional
|
| 3 |
-
from fastapi import FastAPI, Request, Query, UploadFile
|
| 4 |
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
|
| 5 |
|
| 6 |
-
#
|
| 7 |
BASE_DIR = os.environ.get("BASE_DIR", "/tmp/brain_app")
|
| 8 |
FILES_DIR = os.path.join(BASE_DIR, "files")
|
| 9 |
LOGS_DIR = os.path.join(FILES_DIR, "logs")
|
| 10 |
EVENTS_FILE = os.path.join(LOGS_DIR, "events.jsonl")
|
| 11 |
-
|
| 12 |
for p in (BASE_DIR, FILES_DIR, LOGS_DIR):
|
| 13 |
os.makedirs(p, exist_ok=True)
|
| 14 |
|
| 15 |
-
#
|
| 16 |
-
TTS_BASE = os.environ.get("TTS_BASE",
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
NOISE_W = float(os.environ.get("NOISE_W", "0.92"))
|
| 20 |
-
DEFAULT_VOICE = os.environ.get("DEFAULT_VOICE", "en_US-amy-medium")
|
| 21 |
-
|
| 22 |
-
# ========== STT Config ==========
|
| 23 |
-
STT_MODEL = os.environ.get("STT_MODEL", "base.en") # faster-whisper model id
|
| 24 |
-
STT_DEVICE = os.environ.get("STT_DEVICE", "cpu") # "cpu" | "cuda"
|
| 25 |
-
STT_COMPUTE = os.environ.get("STT_COMPUTE", "int8") # "int8"|"int8_float16"|"float32"
|
| 26 |
-
STT_MAXLEN_S = float(os.environ.get("STT_MAXLEN_S", "600")) # refuse extremely long uploads
|
| 27 |
|
| 28 |
-
#
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
-
#
|
|
|
|
| 32 |
log_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
|
| 33 |
|
| 34 |
def write_event(event: Dict[str, Any]) -> None:
|
| 35 |
event.setdefault("ts", time.time())
|
| 36 |
-
os.makedirs(LOGS_DIR, exist_ok=True)
|
| 37 |
with open(EVENTS_FILE, "a", encoding="utf-8") as f:
|
| 38 |
f.write(json.dumps(event, ensure_ascii=False) + "\n")
|
| 39 |
try:
|
|
@@ -41,16 +36,13 @@ def write_event(event: Dict[str, Any]) -> None:
|
|
| 41 |
except asyncio.QueueFull:
|
| 42 |
pass
|
| 43 |
|
| 44 |
-
def clamp_rate(rate_wpm: Optional[int]) -> int:
    """Return a usable words-per-minute value.

    Integers are limited to the inclusive range 80..320. Anything that is
    not an ``int`` (including ``None``) yields ``BASE_WPM``.
    """
    if isinstance(rate_wpm, int):
        return min(320, max(80, rate_wpm))
    return BASE_WPM
|
| 48 |
-
|
| 49 |
def rate_to_length_scale(rate_wpm: Optional[int]) -> float:
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
-
# ========== Health & Basics ==========
|
| 54 |
@app.get("/health")
|
| 55 |
def health():
|
| 56 |
return {
|
|
@@ -58,24 +50,12 @@ def health():
|
|
| 58 |
"service": "brain-space",
|
| 59 |
"time": time.time(),
|
| 60 |
"files_dir": FILES_DIR,
|
| 61 |
-
"logs_dir": LOGS_DIR,
|
| 62 |
"tts_base": TTS_BASE,
|
| 63 |
-
"
|
| 64 |
-
"
|
| 65 |
-
"stt_compute": STT_COMPUTE,
|
| 66 |
}
|
| 67 |
|
| 68 |
-
|
| 69 |
-
async def process(req: Request):
    """Log the request's JSON body as a "process" event and echo it back.

    Returns a 400 JSON error when the body cannot be read as JSON; otherwise
    returns ``{"ok": True, "received": <payload>}``.
    """
    try:
        received = await req.json()
    except Exception:
        # Any failure while reading/parsing the body is reported as a client error.
        return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)

    write_event({"type": "process", "data": received})
    return {"ok": True, "received": received}
|
| 77 |
-
|
| 78 |
-
# ========== SSE Logs ==========
|
| 79 |
@app.get("/stream/logs")
|
| 80 |
async def stream_logs() -> StreamingResponse:
|
| 81 |
async def gen() -> AsyncGenerator[bytes, None]:
|
|
@@ -88,72 +68,40 @@ async def stream_logs() -> StreamingResponse:
|
|
| 88 |
pass
|
| 89 |
while True:
|
| 90 |
event = await log_queue.get()
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
return StreamingResponse(gen(), media_type="text/event-stream", headers=headers)
|
| 95 |
-
|
| 96 |
-
@app.post("/log_error")
async def log_error(req: Request):
    """Log the request's JSON body as an "error" event.

    Returns a 400 JSON error when the body cannot be read as JSON; otherwise
    returns ``{"ok": True}``.
    """
    try:
        reported = await req.json()
    except Exception:
        # Any failure while reading/parsing the body is reported as a client error.
        return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)

    write_event({"type": "error", "data": reported})
    return {"ok": True}
|
| 105 |
-
|
| 106 |
-
# ========== TTS: JSON (file URL) ==========
|
| 107 |
-
@app.post("/tts/say")
|
| 108 |
-
async def tts_say_json(req: Request):
|
| 109 |
-
"""
|
| 110 |
-
POST JSON -> call TTS /speak (JSON) and return audio_url and audio_url_full.
|
| 111 |
-
Body:
|
| 112 |
-
{
|
| 113 |
-
"text": "Hello",
|
| 114 |
-
"voice": "en_US-amy-medium",
|
| 115 |
-
"rate_wpm": 165, # optional (maps to length_scale)
|
| 116 |
-
"length_scale": 1.05, # optional (overrides rate_wpm)
|
| 117 |
-
"noise_scale": 0.33, # optional
|
| 118 |
-
"noise_w": 0.92 # optional
|
| 119 |
-
}
|
| 120 |
-
"""
|
| 121 |
-
try:
|
| 122 |
-
body = await req.json()
|
| 123 |
-
except Exception:
|
| 124 |
-
return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)
|
| 125 |
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
|
|
|
|
|
|
| 135 |
import httpx
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
if not ok or not data or not data.get("ok"):
|
| 148 |
-
return JSONResponse({"ok": False, "error": (data or {}).get("error", f"TTS error {resp.status_code}")}, status_code=500)
|
| 149 |
-
|
| 150 |
-
audio_url = data["audio_url"]
|
| 151 |
-
audio_url_full = audio_url if audio_url.startswith("http") else f"{TTS_BASE}{audio_url}"
|
| 152 |
-
return {"ok": True, "audio_url": audio_url, "audio_url_full": audio_url_full, "voice": voice, "length_scale": length_scale}
|
| 153 |
-
|
| 154 |
-
# ========== TTS: Direct WAV Proxy ==========
|
| 155 |
-
async def _proxy_tts_wav_stream(text: str, voice: str, length_scale: float, noise_scale: float, noise_w: float, save_local: bool = False) -> StreamingResponse:
|
| 156 |
import httpx
|
|
|
|
| 157 |
params = {
|
| 158 |
"text": text,
|
| 159 |
"voice": voice,
|
|
@@ -161,233 +109,79 @@ async def _proxy_tts_wav_stream(text: str, voice: str, length_scale: float, nois
|
|
| 161 |
"noise_scale": f"{noise_scale:.3f}",
|
| 162 |
"noise_w": f"{noise_w:.3f}",
|
| 163 |
}
|
| 164 |
-
ts = int(time.time() * 1000)
|
| 165 |
-
local_path = os.path.join(FILES_DIR, f"say-{ts}.wav") if save_local else None
|
| 166 |
-
|
| 167 |
async def gen():
|
| 168 |
async with httpx.AsyncClient(timeout=None) as client:
|
| 169 |
async with client.stream("GET", f"{TTS_BASE}/speak.wav", params=params) as resp:
|
| 170 |
if resp.status_code != 200:
|
| 171 |
-
|
| 172 |
-
yield err_body
|
| 173 |
return
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
finally:
|
| 183 |
-
if f: f.close()
|
| 184 |
-
|
| 185 |
-
headers = {"Cache-Control": "no-cache"}
|
| 186 |
-
if local_path:
|
| 187 |
-
headers["X-Local-Path"] = local_path
|
| 188 |
-
return StreamingResponse(gen(), media_type="audio/wav", headers=headers)
|
| 189 |
-
|
| 190 |
-
@app.get("/tts/say.wav")
async def tts_say_wav_get(
    text: str = Query(..., description="Text to synthesize"),
    voice: str = Query(DEFAULT_VOICE, description="Voice id"),
    rate_wpm: Optional[int] = Query(None, description="Words-per-minute"),
    length_scale: Optional[float] = Query(None, description="Override length_scale"),
    noise_scale: float = Query(NOISE_SCALE),
    noise_w: float = Query(NOISE_W),
    save: bool = Query(False, description="Also save under /files"),
):
    """Synthesize ``text`` and return the result of ``_proxy_tts_wav_stream``.

    An explicit ``length_scale`` takes precedence; otherwise it is derived
    from ``rate_wpm`` (or ``BASE_WPM`` when no rate was given). The call is
    recorded as a ``tts_say_wav_get`` event before proxying.
    """
    if length_scale is not None:
        ls = float(length_scale)
    else:
        effective_rate = BASE_WPM if rate_wpm is None else rate_wpm
        ls = rate_to_length_scale(effective_rate)

    details = {"len": len(text), "voice": voice, "ls": ls, "save": save}
    write_event({"type": "tts_say_wav_get", "data": details})
    return await _proxy_tts_wav_stream(text, voice, ls, noise_scale, noise_w, save_local=save)
|
| 203 |
-
|
| 204 |
-
@app.post("/tts/say.wav")
async def tts_say_wav_post(req: Request, save: bool = Query(False, description="Also save under /files")):
    """Synthesize the text given in a JSON body and proxy the WAV stream.

    Body fields read here: ``text`` (required), ``voice``, ``length_scale``
    (takes precedence over ``rate_wpm``), ``rate_wpm``, ``noise_scale`` and
    ``noise_w``. Fields that are absent or ``null`` fall back to the module
    defaults.

    Returns a 400 JSON error when the body is not a JSON object, when
    ``text`` is missing/blank, or when a numeric field cannot be converted.
    Otherwise returns the result of ``_proxy_tts_wav_stream``.
    """
    try:
        body = await req.json()
    except Exception:
        return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)
    # Valid JSON that is not an object (list, string, number) has no fields to read.
    if not isinstance(body, dict):
        return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)

    raw_text = body.get("text")
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    if not text:
        return JSONResponse({"ok": False, "error": "Missing text"}, status_code=400)

    raw_voice = body.get("voice")
    voice = (raw_voice if isinstance(raw_voice, str) and raw_voice else DEFAULT_VOICE).strip()

    try:
        if body.get("length_scale") is not None:
            ls = float(body["length_scale"])
        else:
            raw_rate = body.get("rate_wpm")
            ls = rate_to_length_scale(BASE_WPM if raw_rate is None else int(raw_rate))
        raw_ns = body.get("noise_scale")
        ns = NOISE_SCALE if raw_ns is None else float(raw_ns)
        raw_nw = body.get("noise_w")
        nw = NOISE_W if raw_nw is None else float(raw_nw)
    except (TypeError, ValueError):
        # Bad client input must not surface as an unhandled 500.
        return JSONResponse({"ok": False, "error": "Invalid numeric parameter"}, status_code=400)

    write_event({"type": "tts_say_wav_post", "data": {"len": len(text), "voice": voice, "ls": ls, "save": save}})
    return await _proxy_tts_wav_stream(text, voice, ls, ns, nw, save_local=save)
|
| 219 |
-
|
| 220 |
-
# ========== Serve saved files ==========
|
| 221 |
-
@app.get("/files/{name}")
def get_saved_file(name: str):
    """Serve a regular file that sits directly inside ``FILES_DIR``.

    Returns a 404 JSON error when ``name`` does not resolve to a regular
    file whose parent directory is ``FILES_DIR``; this rejects ``..``,
    sub-directories such as ``logs``, and anything resolving elsewhere.
    """
    root = os.path.realpath(FILES_DIR)
    path = os.path.realpath(os.path.join(root, name))
    # Comparing the resolved parent with the root keeps lookups inside FILES_DIR;
    # isfile() rejects directories, which exists() would have accepted.
    if os.path.dirname(path) != root or not os.path.isfile(path):
        return JSONResponse({"ok": False, "error": "not found"}, status_code=404)
    return FileResponse(path, media_type="audio/wav", filename=name)
|
| 227 |
-
|
| 228 |
-
# ========== STT (faster-whisper) ==========
|
| 229 |
-
# Cached WhisperModel instance; stays None until _stt_model() is first called.
_model = None


def _stt_model():
    """Return the shared ``WhisperModel``, constructing it on first use.

    The model is built from ``STT_MODEL``, ``STT_DEVICE`` and ``STT_COMPUTE``
    and stored in the module-level ``_model`` for later calls.
    """
    global _model
    if _model is not None:
        return _model

    # Imported here, so faster_whisper is only loaded when a model is built.
    from faster_whisper import WhisperModel

    _model = WhisperModel(STT_MODEL, device=STT_DEVICE, compute_type=STT_COMPUTE)
    return _model
|
| 236 |
-
|
| 237 |
-
async def _download_to_temp(url: str) -> str:
    """Download ``url`` into a new temporary file and return that file's path.

    The suffix is taken from the URL's path component (``.wav`` when it has
    none). The body is streamed to disk in chunks. The caller owns the
    returned file and must delete it.

    Raises whatever httpx raises for connection problems or, through
    ``raise_for_status``, for an error status; the temporary file is removed
    before the exception propagates.
    """
    import httpx
    from urllib.parse import urlsplit

    # NOTE(review): callers in this file pass a client-supplied URL straight in;
    # consider restricting schemes/hosts before fetching server-side.
    # Use only the path part so the query, fragment and host name cannot be
    # mistaken for a file extension.
    ext = os.path.splitext(urlsplit(url).path)[1] or ".wav"
    fd, tmp_path = tempfile.mkstemp(prefix="stt_", suffix=ext)
    try:
        with os.fdopen(fd, "wb") as out:
            async with httpx.AsyncClient(timeout=300) as client:
                async with client.stream("GET", url) as resp:
                    resp.raise_for_status()
                    async for chunk in resp.aiter_bytes():
                        out.write(chunk)
    except BaseException:
        # Nobody else knows this path yet, so remove the partial file here
        # (also on task cancellation) rather than leaking it.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
    return tmp_path
|
| 249 |
-
|
| 250 |
-
def _transcribe_path(path: str, language: Optional[str] = None) -> Dict[str, Any]:
    """Transcribe the audio file at ``path`` with the shared STT model.

    Args:
        path: Location of the audio file handed to ``model.transcribe``.
        language: Language code passed through to the model ("en", or None
            for auto-detection).

    Returns:
        A dict with ``text`` (all segment texts joined and stripped),
        ``language`` (from the model's info, else the requested language or
        "unknown"), ``duration`` (from the model's info, may be None) and
        ``segments`` (a list of ``{"start", "end", "text"}`` dicts).
    """
    model = _stt_model()
    segments, info = model.transcribe(
        path,
        language=language,  # "en" or None for auto
        beam_size=5,
        vad_filter=False,
        word_timestamps=False,
    )
    # Builtin generics are used in the annotations because typing.List is not
    # imported by this file.
    out_segments: list[Dict[str, Any]] = []
    txt_parts: list[str] = []
    dur = getattr(info, "duration", None)
    for seg in segments:
        out_segments.append({"start": seg.start, "end": seg.end, "text": seg.text})
        txt_parts.append(seg.text)
        # Guard against absurdly long files: stop after the first segment that
        # ends beyond STT_MAXLEN_S. The guard only applies when the model
        # reported a duration.
        if STT_MAXLEN_S and dur and seg.end and float(seg.end) > STT_MAXLEN_S:
            break
    text = "".join(txt_parts).strip()
    return {
        "text": text,
        "language": getattr(info, "language", language or "unknown"),
        "duration": dur,
        "segments": out_segments,
    }
|
| 270 |
-
|
| 271 |
-
@app.post("/stt/transcribe")
async def stt_transcribe(
    req: Request,
    language: Optional[str] = Query(None, description="ISO code like 'en' (None = auto)"),
    file_url: Optional[str] = Query(None, description="If provided via query")
):
    """
    POST either:
    - multipart/form-data with 'audio' file
    - or JSON: { "file_url": "https://..." }
    - or query param ?file_url=...
    Returns: { ok, text, language, duration, segments:[...] }

    Responds 400 when no usable audio source is supplied and 500 (with the
    exception text) when download or transcription fails. Both outcomes of
    the transcription are recorded with write_event. The temporary audio
    file is removed before returning.
    """
    tmp_path = None
    try:
        content_type = req.headers.get("content-type", "").lower()
        if "multipart/form-data" in content_type:
            form = await req.form()
            up = form.get("audio")  # key: audio
            # A plain text field named 'audio' arrives as a str, which has no
            # .filename/.read(); treat it like a missing upload.
            if not up or isinstance(up, str):
                return JSONResponse({"ok": False, "error": "Missing 'audio' file"}, status_code=400)
            suffix = os.path.splitext(up.filename or "")[1] or ".wav"
            fd, tmp_path = tempfile.mkstemp(prefix="stt_", suffix=suffix)
            with os.fdopen(fd, "wb") as f:
                f.write(await up.read())
        else:
            # JSON or query
            try:
                body = await req.json()
            except Exception:
                body = {}
            url = file_url or (body.get("file_url") if isinstance(body, dict) else None)
            if not url:
                return JSONResponse({"ok": False, "error": "Provide file_url (JSON/query) or multipart 'audio' file"}, status_code=400)
            tmp_path = await _download_to_temp(url)

        # NOTE(review): this is a synchronous call inside a coroutine; if the
        # transcription is slow it holds the event loop for its duration.
        res = _transcribe_path(tmp_path, language=language)
        write_event({"type": "stt_transcribe", "data": {"ok": True, "language": res.get("language"), "dur": res.get("duration"), "text_len": len(res.get("text", ""))}})
        return {"ok": True, **res}
    except Exception as e:
        write_event({"type": "stt_transcribe", "data": {"ok": False, "error": str(e)}})
        return JSONResponse({"ok": False, "error": str(e)}, status_code=500)
    finally:
        # Best-effort cleanup: only filesystem errors are ignored here.
        if tmp_path:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
|
| 320 |
-
|
| 321 |
-
# --- End-to-end: STT -> Brain -> TTS (streamed WAV) ---
|
| 322 |
-
@app.post("/demo/echo.wav")
|
| 323 |
-
async def demo_echo_wav(
|
| 324 |
req: Request,
|
| 325 |
-
voice: str = Query(DEFAULT_VOICE
|
| 326 |
-
rate_wpm: Optional[int] = Query(
|
| 327 |
-
length_scale: Optional[float] = Query(None, description="Override prosody"),
|
| 328 |
noise_scale: float = Query(NOISE_SCALE),
|
| 329 |
noise_w: float = Query(NOISE_W),
|
| 330 |
-
save: bool = Query(False, description="Also save output WAV under /files"),
|
| 331 |
):
|
| 332 |
"""
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
Returns: streaming audio/wav that says what it heard.
|
| 337 |
"""
|
| 338 |
tmp_path = None
|
| 339 |
try:
|
| 340 |
-
#
|
| 341 |
-
|
| 342 |
-
if "multipart/form-data" in
|
| 343 |
form = await req.form()
|
| 344 |
-
up = form.get("audio")
|
| 345 |
if not up:
|
| 346 |
return JSONResponse({"ok": False, "error": "Missing 'audio' file"}, status_code=400)
|
| 347 |
-
|
| 348 |
-
|
|
|
|
| 349 |
os.close(fd)
|
| 350 |
with open(tmp_path, "wb") as f:
|
| 351 |
f.write(await up.read())
|
| 352 |
else:
|
| 353 |
-
# JSON with file_url
|
| 354 |
try:
|
| 355 |
body = await req.json()
|
| 356 |
except Exception:
|
| 357 |
body = {}
|
| 358 |
url = (body or {}).get("file_url")
|
| 359 |
if not url:
|
| 360 |
-
return JSONResponse({"ok": False, "error": "Provide multipart 'audio' or JSON {file_url}"},
|
|
|
|
| 361 |
tmp_path = await _download_to_temp(url)
|
| 362 |
|
| 363 |
-
#
|
| 364 |
-
|
| 365 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
if not text:
|
| 367 |
-
write_event({"type":
|
| 368 |
return JSONResponse({"ok": False, "error": "No speech detected"}, status_code=422)
|
| 369 |
|
| 370 |
-
#
|
| 371 |
-
reply_text = f"I heard: {text}"
|
| 372 |
-
reply_text = reply_text[:800] # safety bound
|
| 373 |
|
| 374 |
-
|
| 375 |
-
ls = float(length_scale) if length_scale is not None else rate_to_length_scale(rate_wpm if rate_wpm is not None else BASE_WPM)
|
| 376 |
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
# --- TTS (stream WAV back to the caller) ---
|
| 380 |
-
return await _proxy_tts_wav_stream(
|
| 381 |
-
text=reply_text,
|
| 382 |
-
voice=voice,
|
| 383 |
-
length_scale=ls,
|
| 384 |
-
noise_scale=noise_scale,
|
| 385 |
-
noise_w=noise_w,
|
| 386 |
-
save_local=save
|
| 387 |
-
)
|
| 388 |
|
| 389 |
except Exception as e:
|
| 390 |
-
write_event({"type":
|
| 391 |
return JSONResponse({"ok": False, "error": str(e)}, status_code=500)
|
| 392 |
finally:
|
| 393 |
try:
|
|
@@ -396,7 +190,14 @@ async def demo_echo_wav(
|
|
| 396 |
except Exception:
|
| 397 |
pass
|
| 398 |
|
| 399 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
if __name__ == "__main__":
|
| 401 |
import uvicorn
|
| 402 |
uvicorn.run("brain_app:app", host="0.0.0.0", port=7861, reload=False)
|
|
|
|
| 1 |
+
# brain_app.py — Brain Space: STT → TTS proxy streamer
|
| 2 |
import os, json, time, asyncio, tempfile
|
| 3 |
+
from typing import AsyncGenerator, Dict, Any, Optional
|
| 4 |
+
from fastapi import FastAPI, Request, Query, UploadFile
|
| 5 |
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
|
| 6 |
|
| 7 |
+
# === Directories ===
# Every path is derived from BASE_DIR, which the environment may override.
BASE_DIR = os.getenv("BASE_DIR", "/tmp/brain_app")
FILES_DIR = os.path.join(BASE_DIR, "files")
LOGS_DIR = os.path.join(FILES_DIR, "logs")
EVENTS_FILE = os.path.join(LOGS_DIR, "events.jsonl")
# Create the directory tree at import time; directories that exist are left alone.
for p in (BASE_DIR, FILES_DIR, LOGS_DIR):
    os.makedirs(p, exist_ok=True)

# === External Spaces ===
TTS_BASE = os.getenv("TTS_BASE", "https://Percy3822-ActualTTS.hf.space")
# TODO: set your STT Space base here (example):
STT_BASE = os.getenv("STT_BASE", "https://YOUR-STT-SPACE.hf.space")

# === TTS defaults ===
# Each default can be overridden through an environment variable of the same name.
DEFAULT_VOICE = os.getenv("DEFAULT_VOICE", "en_US-amy-medium")
BASE_WPM = int(os.getenv("BASE_WPM", "165"))
NOISE_SCALE = float(os.getenv("NOISE_SCALE", "0.33"))
NOISE_W = float(os.getenv("NOISE_W", "0.92"))

# === App ===
app = FastAPI(title="Brain Space (STT→TTS coordinator)", version="3.0.0")
# Events written by write_event are awaited from this queue by /stream/logs.
log_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
|
| 29 |
|
| 30 |
def write_event(event: Dict[str, Any]) -> None:
|
| 31 |
event.setdefault("ts", time.time())
|
|
|
|
| 32 |
with open(EVENTS_FILE, "a", encoding="utf-8") as f:
|
| 33 |
f.write(json.dumps(event, ensure_ascii=False) + "\n")
|
| 34 |
try:
|
|
|
|
| 36 |
except asyncio.QueueFull:
|
| 37 |
pass
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
def rate_to_length_scale(rate_wpm: Optional[int]) -> float:
    """Convert a words-per-minute rate into a length-scale factor.

    The rate is limited to 80..320 and the result is ``BASE_WPM / rate``
    rounded to three decimals, so rates above ``BASE_WPM`` give values
    below 1.0. Anything that is not an ``int`` (including ``None``)
    yields 1.0.
    """
    reference = BASE_WPM
    if isinstance(rate_wpm, int):
        clamped = min(320, max(80, rate_wpm))
        return round(reference / float(clamped), 3)
    return 1.0
|
| 45 |
|
|
|
|
| 46 |
@app.get("/health")
|
| 47 |
def health():
|
| 48 |
return {
|
|
|
|
| 50 |
"service": "brain-space",
|
| 51 |
"time": time.time(),
|
| 52 |
"files_dir": FILES_DIR,
|
|
|
|
| 53 |
"tts_base": TTS_BASE,
|
| 54 |
+
"stt_base": STT_BASE,
|
| 55 |
+
"defaults": {"voice": DEFAULT_VOICE, "rate_wpm": BASE_WPM}
|
|
|
|
| 56 |
}
|
| 57 |
|
| 58 |
+
# ========== SSE logs (optional) ==========
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
@app.get("/stream/logs")
|
| 60 |
async def stream_logs() -> StreamingResponse:
|
| 61 |
async def gen() -> AsyncGenerator[bytes, None]:
|
|
|
|
| 68 |
pass
|
| 69 |
while True:
|
| 70 |
event = await log_queue.get()
|
| 71 |
+
yield b"data: " + json.dumps(event, ensure_ascii=False).encode("utf-8") + b"\n\n"
|
| 72 |
+
return StreamingResponse(gen(), media_type="text/event-stream",
|
| 73 |
+
headers={"Cache-Control":"no-cache","Connection":"keep-alive"})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
+
# ---------- Helpers ----------
|
| 76 |
+
async def _download_to_temp(url: str) -> str:
    """Download ``url`` into a new temporary file and return that file's path.

    The suffix is taken from the URL's path component (``.wav`` when it has
    none). The body is streamed to disk in chunks. The caller owns the
    returned file and must delete it.

    Raises whatever httpx raises for connection problems or, through
    ``raise_for_status``, for an error status; the temporary file is removed
    before the exception propagates.
    """
    import httpx
    from urllib.parse import urlsplit

    # NOTE(review): the caller in this file passes a client-supplied URL
    # straight in; consider restricting schemes/hosts before fetching.
    # Use only the path part so the query, fragment and host name cannot be
    # mistaken for a file extension.
    ext = os.path.splitext(urlsplit(url).path)[1] or ".wav"
    fd, tmp_path = tempfile.mkstemp(prefix="mic_", suffix=ext)
    try:
        with os.fdopen(fd, "wb") as out:
            async with httpx.AsyncClient(timeout=300) as client:
                async with client.stream("GET", url) as resp:
                    resp.raise_for_status()
                    async for chunk in resp.aiter_bytes():
                        out.write(chunk)
    except BaseException:
        # Nobody else knows this path yet, so remove the partial file here
        # (also on task cancellation) rather than leaking it.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
    return tmp_path
|
| 88 |
|
| 89 |
+
async def _call_stt_transcribe_file(path: str) -> Dict[str, Any]:
    """POST multipart 'audio' to STT /stt/transcribe and return its JSON.

    Args:
        path: Local audio file to upload; it is sent under the form field
            ``audio`` with content type ``audio/wav``.

    Returns:
        The decoded JSON object on HTTP 200. For any other status, or when
        the 200 body is not a JSON object, a dict of the form
        ``{"ok": False, "error": "..."}``.
    """
    import httpx

    stt_url = f"{STT_BASE}/stt/transcribe"
    # The context manager closes the upload handle once the request is done.
    with open(path, "rb") as audio:
        files = {"audio": (os.path.basename(path), audio, "audio/wav")}
        async with httpx.AsyncClient(timeout=300) as client:
            r = await client.post(stt_url, files=files)

    if r.status_code != 200:
        return {"ok": False, "error": f"STT {r.status_code}"}
    try:
        data = r.json()
    except ValueError:
        # A 200 with a non-JSON body must not crash the caller.
        return {"ok": False, "error": "STT returned a non-JSON response"}
    if not isinstance(data, dict):
        # Callers use data.get(...), so anything but an object is an error.
        return {"ok": False, "error": "STT returned an unexpected JSON payload"}
    return data
|
| 99 |
+
|
| 100 |
+
async def _proxy_tts_wav_stream(text: str, voice: str, rate_wpm: Optional[int],
|
| 101 |
+
noise_scale: float, noise_w: float) -> StreamingResponse:
|
| 102 |
+
"""Proxy stream from TTS /speak.wav based on text."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
import httpx
|
| 104 |
+
length_scale = rate_to_length_scale(rate_wpm) if rate_wpm is not None else rate_to_length_scale(BASE_WPM)
|
| 105 |
params = {
|
| 106 |
"text": text,
|
| 107 |
"voice": voice,
|
|
|
|
| 109 |
"noise_scale": f"{noise_scale:.3f}",
|
| 110 |
"noise_w": f"{noise_w:.3f}",
|
| 111 |
}
|
|
|
|
|
|
|
|
|
|
| 112 |
async def gen():
|
| 113 |
async with httpx.AsyncClient(timeout=None) as client:
|
| 114 |
async with client.stream("GET", f"{TTS_BASE}/speak.wav", params=params) as resp:
|
| 115 |
if resp.status_code != 200:
|
| 116 |
+
yield (await resp.aread())
|
|
|
|
| 117 |
return
|
| 118 |
+
async for chunk in resp.aiter_bytes():
|
| 119 |
+
if chunk:
|
| 120 |
+
yield chunk
|
| 121 |
+
return StreamingResponse(gen(), media_type="audio/wav", headers={"Cache-Control":"no-cache"})
|
| 122 |
+
|
| 123 |
+
# ========== The simple end-to-end endpoint ==========
|
| 124 |
+
@app.post("/demo/relay.wav")
|
| 125 |
+
async def demo_relay_wav(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
req: Request,
|
| 127 |
+
voice: str = Query(DEFAULT_VOICE),
|
| 128 |
+
rate_wpm: Optional[int] = Query(BASE_WPM),
|
|
|
|
| 129 |
noise_scale: float = Query(NOISE_SCALE),
|
| 130 |
noise_w: float = Query(NOISE_W),
|
|
|
|
| 131 |
):
|
| 132 |
"""
|
| 133 |
+
Accept 5s mic recording from client (multipart 'audio' or JSON {file_url}),
|
| 134 |
+
send to STT Space for transcription, then IMMEDIATELY proxy stream TTS WAV
|
| 135 |
+
that speaks back what was heard.
|
|
|
|
| 136 |
"""
|
| 137 |
tmp_path = None
|
| 138 |
try:
|
| 139 |
+
# Ingest audio
|
| 140 |
+
ctype = (req.headers.get("content-type") or "").lower()
|
| 141 |
+
if "multipart/form-data" in ctype:
|
| 142 |
form = await req.form()
|
| 143 |
+
up: UploadFile = form.get("audio")
|
| 144 |
if not up:
|
| 145 |
return JSONResponse({"ok": False, "error": "Missing 'audio' file"}, status_code=400)
|
| 146 |
+
import os, tempfile
|
| 147 |
+
suffix = os.path.splitext(up.filename or "")[1] or ".wav"
|
| 148 |
+
fd, tmp_path = tempfile.mkstemp(prefix="mic_", suffix=suffix)
|
| 149 |
os.close(fd)
|
| 150 |
with open(tmp_path, "wb") as f:
|
| 151 |
f.write(await up.read())
|
| 152 |
else:
|
| 153 |
+
# JSON with {file_url}
|
| 154 |
try:
|
| 155 |
body = await req.json()
|
| 156 |
except Exception:
|
| 157 |
body = {}
|
| 158 |
url = (body or {}).get("file_url")
|
| 159 |
if not url:
|
| 160 |
+
return JSONResponse({"ok": False, "error": "Provide multipart 'audio' or JSON {file_url}"},
|
| 161 |
+
status_code=400)
|
| 162 |
tmp_path = await _download_to_temp(url)
|
| 163 |
|
| 164 |
+
# STT
|
| 165 |
+
stt = await _call_stt_transcribe_file(tmp_path)
|
| 166 |
+
if not stt.get("ok"):
|
| 167 |
+
write_event({"type":"relay","ok":False,"stage":"stt","err":stt.get("error")})
|
| 168 |
+
return JSONResponse({"ok": False, "error": f"STT failed: {stt.get('error')}"}, status_code=502)
|
| 169 |
+
|
| 170 |
+
text = (stt.get("text") or "").strip()
|
| 171 |
if not text:
|
| 172 |
+
write_event({"type":"relay","ok":False,"stage":"stt","err":"empty transcript"})
|
| 173 |
return JSONResponse({"ok": False, "error": "No speech detected"}, status_code=422)
|
| 174 |
|
| 175 |
+
# Brain reply (for demo we just echo; you can replace with actual brain logic later)
|
| 176 |
+
reply_text = f"I heard: {text}"[:800]
|
|
|
|
| 177 |
|
| 178 |
+
write_event({"type":"relay","ok":True,"heard_len":len(text),"voice":voice,"rate_wpm":rate_wpm})
|
|
|
|
| 179 |
|
| 180 |
+
# TTS proxy stream (immediate)
|
| 181 |
+
return await _proxy_tts_wav_stream(reply_text, voice, rate_wpm, noise_scale, noise_w)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
except Exception as e:
|
| 184 |
+
write_event({"type":"relay","ok":False,"err":str(e)})
|
| 185 |
return JSONResponse({"ok": False, "error": str(e)}, status_code=500)
|
| 186 |
finally:
|
| 187 |
try:
|
|
|
|
| 190 |
except Exception:
|
| 191 |
pass
|
| 192 |
|
| 193 |
+
# Optional: serve saved files if you decide to persist later
|
| 194 |
+
@app.get("/files/{name}")
def get_file(name: str):
    """Serve a regular file that sits directly inside ``FILES_DIR``.

    Returns a 404 JSON error when ``name`` does not resolve to a regular
    file whose parent directory is ``FILES_DIR``; this rejects ``..``,
    sub-directories such as ``logs``, and anything resolving elsewhere.
    """
    root = os.path.realpath(FILES_DIR)
    path = os.path.realpath(os.path.join(root, name))
    # Comparing the resolved parent with the root keeps lookups inside FILES_DIR;
    # isfile() rejects directories, which exists() would have accepted.
    if os.path.dirname(path) != root or not os.path.isfile(path):
        return JSONResponse({"ok": False, "error": "not found"}, status_code=404)
    return FileResponse(path, media_type="application/octet-stream", filename=name)
|
| 200 |
+
|
| 201 |
if __name__ == "__main__":
|
| 202 |
import uvicorn
|
| 203 |
uvicorn.run("brain_app:app", host="0.0.0.0", port=7861, reload=False)
|