# brain_app.py — Brain Space: STT → TTS coordinator + LIVE TTS streaming proxy
import asyncio
import json
import os
import tempfile
import time
from collections import deque
from typing import Any, AsyncGenerator, Dict, Optional

import httpx
import websockets
from fastapi import FastAPI, Query, Request, UploadFile
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse

# === Directories ===
BASE_DIR = os.environ.get("BASE_DIR", "/tmp/brain_app")
FILES_DIR = os.path.join(BASE_DIR, "files")
LOGS_DIR = os.path.join(FILES_DIR, "logs")
EVENTS_FILE = os.path.join(LOGS_DIR, "events.jsonl")
for p in (BASE_DIR, FILES_DIR, LOGS_DIR):
    os.makedirs(p, exist_ok=True)

# === External Spaces ===
TTS_BASE = os.environ.get("TTS_BASE", "https://Percy3822-ActualTTS.hf.space")
STT_BASE = os.environ.get("STT_BASE", "https://Percy3822-ActualSTT.hf.space")  # set to your STT Space

# === TTS defaults ===
DEFAULT_VOICE = os.environ.get("DEFAULT_VOICE", "en_US-amy-medium")
BASE_WPM = int(os.environ.get("BASE_WPM", "165"))
NOISE_SCALE = float(os.environ.get("NOISE_SCALE", "0.33"))
NOISE_W = float(os.environ.get("NOISE_W", "0.92"))

# Upper bound on buffered live-log events. Without a bound the queue grows
# forever while no client is connected to /stream/logs.
LOG_QUEUE_MAX = int(os.environ.get("LOG_QUEUE_MAX", "1000"))

app = FastAPI(title="Brain Space (STT→TTS coordinator)", version="3.1.0")
log_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue(maxsize=LOG_QUEUE_MAX)


def write_event(event: Dict[str, Any]) -> None:
    """Append *event* to the JSONL events file and offer it to the live-log queue.

    A ``ts`` key (``time.time()``) is added in place when the caller did not
    supply one. When the queue is full the oldest buffered event is dropped so
    that the newest one is kept; the events file always receives every event.
    """
    event.setdefault("ts", time.time())
    with open(EVENTS_FILE, "a", encoding="utf-8") as f:
        f.write(json.dumps(event, ensure_ascii=False) + "\n")
    try:
        log_queue.put_nowait(event)
    except asyncio.QueueFull:
        # Make room by discarding the oldest entry, then retry once.
        try:
            log_queue.get_nowait()
            log_queue.put_nowait(event)
        except (asyncio.QueueEmpty, asyncio.QueueFull):
            pass


def rate_to_length_scale(rate_wpm: Optional[int]) -> float:
    """Convert a speaking rate in words-per-minute to a length scale.

    Non-int input (including ``None``) yields ``1.0``. Otherwise the rate is
    clamped to ``[80, 320]`` and the result is ``BASE_WPM / rate`` rounded to
    three decimals.
    """
    if not isinstance(rate_wpm, int):
        return 1.0
    clamped = max(80, min(320, rate_wpm))
    return round(BASE_WPM / float(clamped), 3)


def _tts_ws_url() -> str:
    """Build the TTS WebSocket URL from ``TTS_BASE``.

    e.g. https://Percy3822-ActualTTS.hf.space -> wss://Percy3822-ActualTTS.hf.space/ws/tts

    ``http(s)://`` is mapped to ``ws(s)://``; any other prefix is left as is.
    The ``/ws/tts`` path is appended only when the base does not already end
    with it, whatever the scheme.
    """
    base = (TTS_BASE or "").rstrip("/")
    if base.startswith("https://"):
        base = "wss://" + base[len("https://"):]
    elif base.startswith("http://"):
        base = "ws://" + base[len("http://"):]
    return base if base.endswith("/ws/tts") else base + "/ws/tts"


def _wav_header(sr: int, ch: int) -> bytes:
    """Minimal PCM16 WAV header with large data size for streaming.

    ``sr`` is the sample rate and ``ch`` the channel count. The data chunk
    size is fixed at ``0x7FFFFFFF`` because the real length is unknown when
    the header is emitted.
    """
    bits = 16
    byte_rate = sr * ch * (bits // 8)
    block_align = ch * (bits // 8)
    data_size = 0x7FFFFFFF
    riff_size = (36 + data_size) & 0xFFFFFFFF
    return (
        b"RIFF" + riff_size.to_bytes(4, "little") + b"WAVE" +
        b"fmt " + (16).to_bytes(4, "little") +
        (1).to_bytes(2, "little") +  # PCM
        (ch).to_bytes(2, "little") +
        (sr).to_bytes(4, "little") +
        (byte_rate).to_bytes(4, "little") +
        (block_align).to_bytes(2, "little") +
        (bits).to_bytes(2, "little") +
        b"data" + data_size.to_bytes(4, "little")
    )


# ---------- Health ----------
@app.get("/health")
def health():
    """Return a static status payload with the current configuration."""
    return {
        "ok": True,
        "service": "brain-space",
        "time": time.time(),
        "files_dir": FILES_DIR,
        "tts_base": TTS_BASE,
        "stt_base": STT_BASE,
        "defaults": {"voice": DEFAULT_VOICE, "rate_wpm": BASE_WPM},
    }


# ---------- SSE logs (optional) ----------
@app.get("/stream/logs")
async def stream_logs() -> StreamingResponse:
    """Stream events as Server-Sent Events.

    The last 50 lines of the events file are replayed first, then events are
    forwarded from ``log_queue`` as they arrive. Every connected client reads
    from the same queue, so each live event reaches only one of them.
    """
    async def gen() -> AsyncGenerator[bytes, None]:
        # deque(maxlen=50) keeps only the tail while iterating the file, so a
        # large log is never held in memory in full.
        try:
            with open(EVENTS_FILE, "r", encoding="utf-8") as f:
                backlog = deque(f, maxlen=50)
        except (OSError, UnicodeError):
            # Best effort: a missing or unreadable log just means no replay.
            backlog = deque()
        for line in backlog:
            yield b"data: " + line.encode("utf-8").rstrip(b"\n") + b"\n\n"

        while True:
            event = await log_queue.get()
            yield b"data: " + json.dumps(event, ensure_ascii=False).encode("utf-8") + b"\n\n"

    return StreamingResponse(
        gen(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )


# ---------- TTS proxy streaming (/tts/say.wav) ----------
# GET:  /tts/say.wav?text=...&voice=...&rate_wpm=165
# POST: JSON {"text": "...", "voice": "...", "rate_wpm": 165}
async def _proxy_tts_wav_stream(text: str, voice: str, rate_wpm: Optional[int],
                                noise_scale: float, noise_w: float) -> StreamingResponse:
    """Proxy ``GET {TTS_BASE}/speak.wav`` and relay its body chunk by chunk.

    ``rate_wpm`` is converted with ``rate_to_length_scale`` (``BASE_WPM`` when
    ``None``). The float parameters are sent formatted to three decimals.

    NOTE: the upstream request is opened lazily inside the generator, after
    the response status has been chosen, so an upstream non-200 body is still
    delivered under this response's own status and media type. Such failures
    are recorded with ``write_event``.
    """
    length_scale = rate_to_length_scale(rate_wpm if rate_wpm is not None else BASE_WPM)
    params = {
        "text": text,
        "voice": voice,
        "length_scale": f"{length_scale:.3f}",
        "noise_scale": f"{noise_scale:.3f}",
        "noise_w": f"{noise_w:.3f}",
    }
    # Strip a trailing slash so a TTS_BASE of ".../" does not yield "//speak.wav".
    url = f"{(TTS_BASE or '').rstrip('/')}/speak.wav"

    async def gen():
        try:
            async with httpx.AsyncClient(timeout=None) as client:
                async with client.stream("GET", url, params=params) as resp:
                    if resp.status_code != 200:
                        write_event({"type": "tts_proxy_err", "status": resp.status_code})
                        yield (await resp.aread())
                        return
                    async for chunk in resp.aiter_bytes():
                        if chunk:
                            yield chunk
        except httpx.HTTPError as exc:
            # Log for diagnosis, then re-raise so the failure is not hidden.
            write_event({"type": "tts_proxy_err", "err": str(exc)})
            raise

    return StreamingResponse(gen(), media_type="audio/wav", headers={"Cache-Control": "no-cache"})


@app.get("/tts/say.wav")
async def tts_say_wav_get(
    text: str = Query(..., description="Text to synthesize"),
    voice: str = Query(DEFAULT_VOICE),
    rate_wpm: Optional[int] = Query(BASE_WPM),
    noise_scale: float = Query(NOISE_SCALE),
    noise_w: float = Query(NOISE_W),
):
    """Synthesize *text* through the TTS proxy (query-string variant).

    Blank text is rejected with HTTP 400, as the POST variant does.
    """
    write_event({"type": "tts_get", "len": len(text), "voice": voice, "rate_wpm": rate_wpm})
    if not text.strip():
        return JSONResponse({"ok": False, "error": "Missing text"}, status_code=400)
    return await _proxy_tts_wav_stream(text, voice, rate_wpm, noise_scale, noise_w)


@app.post("/tts/say.wav")
async def tts_say_wav_post(req: Request):
    """Synthesize text through the TTS proxy (JSON-body variant).

    Expects a JSON object with ``text`` and optional ``voice``, ``rate_wpm``,
    ``noise_scale`` and ``noise_w``. Missing or ``null`` optional fields fall
    back to the module defaults. Malformed input yields HTTP 400 rather than
    an unhandled exception.
    """
    try:
        body = await req.json()
    except Exception:
        return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)
    if not isinstance(body, dict):
        # Valid JSON that is not an object (list, string, number, ...).
        return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)

    raw_text = body.get("text")
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    if not text:
        return JSONResponse({"ok": False, "error": "Missing text"}, status_code=400)

    raw_voice = body.get("voice")
    voice = raw_voice.strip() if isinstance(raw_voice, str) else ""
    if not voice:
        voice = DEFAULT_VOICE

    try:
        rate_wpm = BASE_WPM if body.get("rate_wpm") is None else int(body["rate_wpm"])
        noise_s = NOISE_SCALE if body.get("noise_scale") is None else float(body["noise_scale"])
        noise_wgt = NOISE_W if body.get("noise_w") is None else float(body["noise_w"])
    except (TypeError, ValueError):
        return JSONResponse({"ok": False, "error": "Invalid numeric parameter"}, status_code=400)

    write_event({"type": "tts_post", "len": len(text), "voice": voice, "rate_wpm": rate_wpm})
    return await _proxy_tts_wav_stream(text, voice, rate_wpm, noise_s, noise_wgt)


# ---------- LIVE TTS WS → HTTP WAV streaming ----------
# GET: /tts/say.stream.wav?text=...&voice=...&rate_wpm=165
def _parse_ws_event(message: Any) -> Optional[Dict[str, Any]]:
    """Decode a text WebSocket frame into an event dict, or return ``None``.

    Binary frames, invalid JSON and JSON values that are not objects all give
    ``None`` so callers can skip them instead of failing on ``.get``.
    """
    if isinstance(message, (bytes, bytearray)):
        return None
    try:
        decoded = json.loads(message)
    except ValueError:
        return None
    return decoded if isinstance(decoded, dict) else None


@app.get("/tts/say.stream.wav")
async def tts_say_stream_wav(
    text: str = Query(..., description="Text to synthesize (live)"),
    voice: str = Query(DEFAULT_VOICE),
    rate_wpm: Optional[int] = Query(BASE_WPM),
    length_scale: Optional[float] = Query(None),
    noise_scale: float = Query(NOISE_SCALE),
    noise_w: float = Query(NOISE_W),
):
    """
    LIVE streaming proxy: TTS WS (raw PCM16) -> HTTP chunked WAV.
    Starts emitting audio as soon as the TTS starts producing frames.

    An explicit ``length_scale`` wins over ``rate_wpm``; otherwise the scale
    is derived from ``rate_wpm`` (``BASE_WPM`` when missing or 0). Blank text
    is rejected with HTTP 400. Failures while streaming are recorded with
    ``write_event`` and end the stream.
    """
    ls = float(length_scale) if length_scale is not None else rate_to_length_scale(rate_wpm or BASE_WPM)
    write_event({"type": "tts_stream_get", "len": len(text), "voice": voice, "ls": ls})
    if not text.strip():
        return JSONResponse({"ok": False, "error": "Missing text"}, status_code=400)

    async def gen():
        ws = None
        try:
            ws = await websockets.connect(_tts_ws_url(), ping_interval=None, max_size=8_000_000)

            # init
            await ws.send(json.dumps({
                "event": "init",
                "voice": voice,
                "length_scale": ls,
                "noise_scale": noise_scale,
                "noise_w": noise_w,
            }))

            # wait for ready -> send WAV header immediately
            while True:
                evt = _parse_ws_event(await ws.recv())
                if evt is None:
                    # Binary or unparseable frame before sr/ch are known: skip.
                    continue
                kind = evt.get("event")
                if kind == "ready":
                    sr = int(evt.get("sr", 22050))
                    ch = int(evt.get("channels", 1))
                    yield _wav_header(sr, ch)
                    break
                if kind == "error":
                    detail = evt.get("detail", "tts init error")
                    write_event({"type": "tts_stream_err", "err": str(detail)})
                    yield f'ERROR: {detail}'.encode("utf-8")
                    return

            # speak
            await ws.send(json.dumps({"event": "speak", "text": text}))

            # pump frames
            while True:
                try:
                    msg = await ws.recv()
                except websockets.exceptions.ConnectionClosed:
                    break
                if isinstance(msg, (bytes, bytearray)):
                    if msg:
                        yield msg
                    continue
                evt = _parse_ws_event(msg)
                if evt is None:
                    continue
                kind = evt.get("event")
                if kind in ("done", "end"):
                    break
                if kind == "error":
                    detail = evt.get("detail", "tts error")
                    write_event({"type": "tts_stream_err", "err": str(detail)})
                    yield f'ERROR: {detail}'.encode("utf-8")
                    break
                # ignore logs
        except Exception as e:
            write_event({"type": "tts_stream_err", "err": str(e)})
            yield b""
        finally:
            # Closing is best effort; the socket may already be gone.
            try:
                if ws is not None:
                    await ws.close()
            except Exception:
                pass

    return StreamingResponse(
        gen(),
        media_type="audio/wav",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )


# ---------- Optional: serve saved files later ----------
@app.get("/files/{name}")
def get_file(name: str):
    """Serve a regular file located inside ``FILES_DIR``.

    The requested name is resolved to a real path and must stay within
    ``FILES_DIR``. Anything else (a missing file, a directory, or a path that
    escapes the directory) is answered with HTTP 404.
    """
    root = os.path.realpath(FILES_DIR)
    path = os.path.realpath(os.path.join(root, name))
    try:
        inside_root = os.path.commonpath([root, path]) == root
    except ValueError:
        # commonpath raises when the paths cannot be compared (e.g. different drives).
        inside_root = False
    if not inside_root or not os.path.isfile(path):
        return JSONResponse({"ok": False, "error": "not found"}, status_code=404)
    return FileResponse(path, media_type="application/octet-stream", filename=name)


if __name__ == "__main__":
    import uvicorn
    uvicorn.run("brain_app:app", host="0.0.0.0", port=7861, reload=False)