# Brain / app.py  (Hugging Face Space by Percy3822)
# Last commit: 3816405 "Update app.py" (verified)
# brain_app.py — Brain Space: STT → TTS coordinator + LIVE TTS streaming proxy
import os, json, time, asyncio, tempfile
from typing import AsyncGenerator, Dict, Any, Optional
from fastapi import FastAPI, Request, Query, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
import httpx
import websockets
# === Directories ===
# Everything the service writes lives under BASE_DIR (overridable via env).
BASE_DIR = os.environ.get("BASE_DIR", "/tmp/brain_app")
FILES_DIR = os.path.join(BASE_DIR, "files")
LOGS_DIR = os.path.join(FILES_DIR, "logs")
EVENTS_FILE = os.path.join(LOGS_DIR, "events.jsonl")
# Create the tree at import time so later writes never hit a missing directory.
for directory in [BASE_DIR, FILES_DIR, LOGS_DIR]:
    os.makedirs(directory, exist_ok=True)
# === External Spaces ===
# Base URLs of the upstream Spaces; both can be overridden through the environment.
TTS_BASE = os.getenv("TTS_BASE", "https://Percy3822-ActualTTS.hf.space")
STT_BASE = os.getenv("STT_BASE", "https://Percy3822-ActualSTT.hf.space")  # point this at your STT Space
# === TTS defaults ===
DEFAULT_VOICE = os.getenv("DEFAULT_VOICE", "en_US-amy-medium")
# Reference speaking rate in words per minute (used as the numerator of the length_scale ratio).
BASE_WPM = int(os.getenv("BASE_WPM", "165"))
NOISE_SCALE = float(os.getenv("NOISE_SCALE", "0.33"))
NOISE_W = float(os.getenv("NOISE_W", "0.92"))
app = FastAPI(title="Brain Space (STT→TTS coordinator)", version="3.1.0")
# Live event feed consumed by /stream/logs. Bounded so that events do not
# accumulate in memory forever when no SSE client is draining the queue;
# write_event() already tolerates asyncio.QueueFull.
log_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue(maxsize=1000)
def write_event(event: Dict[str, Any]) -> None:
    """Append *event* to the JSONL event log and publish it to the live SSE queue.

    A ``ts`` field (``time.time()``) is added when missing; note this mutates
    the caller's dict. If the live queue is full (no subscriber is draining
    it), the oldest queued event is discarded so the feed stays current
    instead of silently dropping the newest event.
    """
    event.setdefault("ts", time.time())
    with open(EVENTS_FILE, "a", encoding="utf-8") as f:
        f.write(json.dumps(event, ensure_ascii=False) + "\n")
    try:
        log_queue.put_nowait(event)
    except asyncio.QueueFull:
        # Make room by evicting the stalest event; the file log still has it.
        try:
            log_queue.get_nowait()
            log_queue.put_nowait(event)
        except (asyncio.QueueEmpty, asyncio.QueueFull):
            pass
def rate_to_length_scale(rate_wpm: Optional[float], base_wpm: Optional[int] = None) -> float:
    """Convert a speaking rate (words per minute) into a TTS ``length_scale``.

    The result is ``base / clamp(rate_wpm, 80, 320)`` rounded to 3 decimals,
    so rates faster than the base give values < 1.0 and slower ones > 1.0.

    Args:
        rate_wpm: Requested rate. Ints and floats are accepted; ``None``,
            booleans, non-numbers and NaN fall back to a neutral ``1.0``.
        base_wpm: Reference rate; defaults to the module's ``BASE_WPM``.
    """
    # bool is a subclass of int but is never a meaningful rate.
    if isinstance(rate_wpm, bool) or not isinstance(rate_wpm, (int, float)):
        return 1.0
    if rate_wpm != rate_wpm:  # NaN would otherwise clamp to an arbitrary bound
        return 1.0
    base = BASE_WPM if base_wpm is None else base_wpm
    r = max(80, min(320, rate_wpm))
    return round(base / float(r), 3)
def _tts_ws_url() -> str:
"""
Build the TTS WebSocket URL from TTS_BASE.
e.g. https://Percy3822-ActualTTS.hf.space -> wss://Percy3822-ActualTTS.hf.space/ws/tts
"""
base = (TTS_BASE or "").rstrip("/")
if base.startswith("https://"):
return "wss://" + base[len("https://"):] + "/ws/tts"
if base.startswith("http://"):
return "ws://" + base[len("http://"):] + "/ws/tts"
return (base + "/ws/tts") if not base.endswith("/ws/tts") else base
def _wav_header(sr: int, ch: int) -> bytes:
"""Minimal PCM16 WAV header with large data size for streaming."""
bits = 16
byte_rate = sr * ch * (bits // 8)
block_align = ch * (bits // 8)
data_size = 0x7FFFFFFF
riff_size = (36 + data_size) & 0xFFFFFFFF
return (
b"RIFF" +
riff_size.to_bytes(4, "little") +
b"WAVE" +
b"fmt " + (16).to_bytes(4, "little") +
(1).to_bytes(2, "little") + # PCM
(ch).to_bytes(2, "little") +
(sr).to_bytes(4, "little") +
(byte_rate).to_bytes(4, "little") +
(block_align).to_bytes(2, "little") +
(bits).to_bytes(2, "little") +
b"data" + data_size.to_bytes(4, "little")
)
# ---------- Health ----------
@app.get("/health")
def health():
    """Liveness probe: service identity, current time, storage path, upstreams and TTS defaults."""
    tts_defaults = dict(voice=DEFAULT_VOICE, rate_wpm=BASE_WPM)
    report = dict(
        ok=True,
        service="brain-space",
        time=time.time(),
        files_dir=FILES_DIR,
        tts_base=TTS_BASE,
        stt_base=STT_BASE,
        defaults=tts_defaults,
    )
    return report
# ---------- SSE logs (optional) ----------
@app.get("/stream/logs")
async def stream_logs() -> StreamingResponse:
    """Server-Sent Events feed of brain events.

    First replays the last 50 non-empty lines of EVENTS_FILE, then forwards
    events from ``log_queue`` as they arrive. After 15 s without events an SSE
    comment line is sent so idle proxies do not drop the connection.

    NOTE(review): ``log_queue`` is one shared queue, so with several concurrent
    subscribers each event is delivered to only one of them.
    """
    from collections import deque  # local: only needed for the tail replay

    async def gen() -> AsyncGenerator[bytes, None]:
        try:
            with open(EVENTS_FILE, "r", encoding="utf-8", errors="replace") as f:
                # Keep only the tail in memory instead of reading the whole,
                # ever-growing log with readlines().
                tail = deque(f, maxlen=50)
        except OSError:  # missing/unreadable log: nothing to replay
            tail = deque()
        for line in tail:
            line = line.rstrip("\r\n")
            if line:
                yield b"data: " + line.encode("utf-8") + b"\n\n"
        while True:
            try:
                event = await asyncio.wait_for(log_queue.get(), timeout=15.0)
            except asyncio.TimeoutError:
                yield b": keepalive\n\n"
                continue
            yield b"data: " + json.dumps(event, ensure_ascii=False).encode("utf-8") + b"\n\n"

    return StreamingResponse(gen(), media_type="text/event-stream",
                             headers={"Cache-Control": "no-cache", "Connection": "keep-alive"})
# ---------- TTS proxy streaming (/tts/say.wav) ----------
# GET: /tts/say.wav?text=...&voice=...&rate_wpm=165
# POST: JSON {"text": "...", "voice": "...", "rate_wpm": 165}
async def _proxy_tts_wav_stream(text: str, voice: str, rate_wpm: Optional[int],
                                noise_scale: float, noise_w: float) -> "StreamingResponse | JSONResponse":
    """Proxy ``GET {TTS_BASE}/speak.wav`` and stream the WAV bytes back to the caller.

    ``rate_wpm`` (``None`` -> ``BASE_WPM``) is converted with
    ``rate_to_length_scale``; all tuning values are sent to the upstream as
    3-decimal strings.

    The upstream request is opened *before* the response is returned, so an
    unreachable TTS or a non-200 answer yields a 502 JSON error instead of
    error text inside a 200 ``audio/wav`` body. The upstream response and
    client are closed when streaming ends, and also via a post-response
    background task in case the client disconnects before streaming starts.
    """
    from fastapi import BackgroundTasks  # local: post-response cleanup hook

    length_scale = rate_to_length_scale(rate_wpm if rate_wpm is not None else BASE_WPM)
    params = {
        "text": text,
        "voice": voice,
        "length_scale": f"{length_scale:.3f}",
        "noise_scale": f"{noise_scale:.3f}",
        "noise_w": f"{noise_w:.3f}",
    }
    # No read/overall timeout (synthesis can be slow and the body is streamed),
    # but bound the connect phase so a dead upstream cannot hang the request.
    client = httpx.AsyncClient(timeout=httpx.Timeout(None, connect=30.0))
    try:
        resp = await client.send(
            client.build_request("GET", f"{TTS_BASE}/speak.wav", params=params),
            stream=True,
        )
    except httpx.HTTPError as exc:
        await client.aclose()
        write_event({"type": "tts_proxy_err", "err": str(exc)})
        return JSONResponse({"ok": False, "error": f"TTS upstream unreachable: {exc}"}, status_code=502)

    if resp.status_code != 200:
        try:
            detail = (await resp.aread()).decode("utf-8", errors="replace")
        except httpx.HTTPError:
            detail = ""
        finally:
            await resp.aclose()
            await client.aclose()
        write_event({"type": "tts_proxy_err", "status": resp.status_code})
        return JSONResponse(
            {"ok": False, "error": "TTS upstream error", "status": resp.status_code, "detail": detail[:2000]},
            status_code=502,
        )

    async def gen():
        try:
            async for chunk in resp.aiter_bytes():
                if chunk:
                    yield chunk
        except httpx.HTTPError as exc:
            # Headers are already sent; log and end the stream cleanly.
            write_event({"type": "tts_proxy_err", "err": str(exc)})
        finally:
            await resp.aclose()
            await client.aclose()

    cleanup = BackgroundTasks()
    cleanup.add_task(resp.aclose)
    cleanup.add_task(client.aclose)
    return StreamingResponse(gen(), media_type="audio/wav", headers={"Cache-Control": "no-cache"},
                             background=cleanup)
@app.get("/tts/say.wav")
async def tts_say_wav_get(
    text: str = Query(..., description="Text to synthesize"),
    voice: str = Query(DEFAULT_VOICE),
    rate_wpm: Optional[int] = Query(BASE_WPM),
    noise_scale: float = Query(NOISE_SCALE),
    noise_w: float = Query(NOISE_W),
):
    """GET flavour of /tts/say.wav: record the request, then stream the proxied WAV."""
    request_event = {"type": "tts_get", "len": len(text), "voice": voice, "rate_wpm": rate_wpm}
    write_event(request_event)
    response = await _proxy_tts_wav_stream(
        text=text,
        voice=voice,
        rate_wpm=rate_wpm,
        noise_scale=noise_scale,
        noise_w=noise_w,
    )
    return response
@app.post("/tts/say.wav")
async def tts_say_wav_post(req: Request):
    """POST flavour of /tts/say.wav: synthesize from a JSON body and stream the WAV.

    Body: ``{"text": str, "voice"?: str, "rate_wpm"?: int, "noise_scale"?: float,
    "noise_w"?: float}``. Missing, null or blank optional fields fall back to
    the module defaults. Malformed input (non-object body, non-string
    text/voice, non-numeric tuning values) returns a 400 JSON error rather
    than an unhandled 500.
    """
    try:
        body = await req.json()
    except Exception:
        return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)
    if not isinstance(body, dict):
        return JSONResponse({"ok": False, "error": "JSON body must be an object"}, status_code=400)

    raw_text = body.get("text")
    if raw_text is not None and not isinstance(raw_text, str):
        return JSONResponse({"ok": False, "error": "text must be a string"}, status_code=400)
    text = (raw_text or "").strip()
    if not text:
        return JSONResponse({"ok": False, "error": "Missing text"}, status_code=400)

    raw_voice = body.get("voice")
    if raw_voice is not None and not isinstance(raw_voice, str):
        return JSONResponse({"ok": False, "error": "voice must be a string"}, status_code=400)
    # Strip first so a whitespace-only voice also falls back to the default.
    voice = (raw_voice or "").strip() or DEFAULT_VOICE

    try:
        rate_wpm = int(body["rate_wpm"]) if body.get("rate_wpm") is not None else BASE_WPM
        noise_s = float(body["noise_scale"]) if body.get("noise_scale") is not None else NOISE_SCALE
        noise_wgt = float(body["noise_w"]) if body.get("noise_w") is not None else NOISE_W
    except (TypeError, ValueError, OverflowError):
        return JSONResponse(
            {"ok": False, "error": "rate_wpm, noise_scale and noise_w must be numeric"},
            status_code=400,
        )

    write_event({"type": "tts_post", "len": len(text), "voice": voice, "rate_wpm": rate_wpm})
    return await _proxy_tts_wav_stream(text, voice, rate_wpm, noise_s, noise_wgt)
# ---------- LIVE TTS WS → HTTP WAV streaming ----------
# GET: /tts/say.stream.wav?text=...&voice=...&rate_wpm=165


async def _close_ws_quietly(ws) -> None:
    """Best-effort close of a TTS WebSocket; errors are ignored (cleanup path)."""
    try:
        await ws.close()
    except Exception:
        pass


async def _open_tts_ws(voice: str, ls: float, noise_scale: float, noise_w: float):
    """Connect to the TTS WebSocket, send ``init``, and wait for ``ready``.

    Returns ``(ws, sr, ch)`` taken from the ``ready`` event (defaults 22050 / 1).
    Raises ``RuntimeError`` with the upstream detail on an ``error`` event and
    re-raises connection errors; the socket is closed on every failure path.
    """
    ws = await websockets.connect(_tts_ws_url(), ping_interval=None, max_size=8_000_000)
    try:
        await ws.send(json.dumps({
            "event": "init",
            "voice": voice,
            "length_scale": ls,
            "noise_scale": noise_scale,
            "noise_w": noise_w,
        }))
        while True:
            m = await ws.recv()
            if isinstance(m, (bytes, bytearray)):
                # Audio before "ready" is dropped: sr/ch are not known yet.
                continue
            try:
                evt = json.loads(m)
            except ValueError:
                continue
            if not isinstance(evt, dict):
                continue
            if evt.get("event") == "ready":
                return ws, int(evt.get("sr", 22050)), int(evt.get("channels", 1))
            if evt.get("event") == "error":
                raise RuntimeError(str(evt.get("detail", "tts init error")))
    except BaseException:
        await _close_ws_quietly(ws)
        raise


@app.get("/tts/say.stream.wav")
async def tts_say_stream_wav(
    text: str = Query(..., description="Text to synthesize (live)"),
    voice: str = Query(DEFAULT_VOICE),
    rate_wpm: Optional[int] = Query(BASE_WPM),
    length_scale: Optional[float] = Query(None),
    noise_scale: float = Query(NOISE_SCALE),
    noise_w: float = Query(NOISE_W),
):
    """
    LIVE streaming proxy: TTS WS (raw PCM16) -> HTTP chunked WAV.

    The TTS session (connect + ``init`` + ``ready``) is opened before the
    response starts, so upstream failures come back as a 502 JSON error instead
    of text inside a 200 ``audio/wav`` body. After the WAV header is sent, audio
    frames are forwarded as soon as the TTS produces them. Mid-stream upstream
    errors are logged and simply end the stream, because text written after the
    header would be played back as PCM noise.
    """
    from fastapi import BackgroundTasks  # local: cleanup hook for early client disconnects

    ls = float(length_scale) if length_scale is not None else rate_to_length_scale(rate_wpm or BASE_WPM)
    write_event({"type": "tts_stream_get", "len": len(text), "voice": voice, "ls": ls})

    try:
        ws, sr, ch = await _open_tts_ws(voice, ls, noise_scale, noise_w)
    except Exception as e:
        write_event({"type": "tts_stream_err", "err": str(e)})
        return JSONResponse({"ok": False, "error": f"TTS stream init failed: {e}"}, status_code=502)

    async def gen():
        try:
            yield _wav_header(sr, ch)
            await ws.send(json.dumps({"event": "speak", "text": text}))
            while True:
                try:
                    msg = await ws.recv()
                except websockets.exceptions.ConnectionClosed:
                    break
                if isinstance(msg, (bytes, bytearray)):
                    if msg:
                        # Starlette only passes bytes/memoryview through unchanged.
                        yield bytes(msg)
                    continue
                try:
                    evt = json.loads(msg)
                except ValueError:
                    continue
                if not isinstance(evt, dict):
                    continue
                kind = evt.get("event")
                if kind in ("done", "end"):
                    break
                if kind == "error":
                    write_event({"type": "tts_stream_err", "err": str(evt.get("detail", "tts error"))})
                    break
                # other events (e.g. logs) are ignored
        except Exception as e:
            write_event({"type": "tts_stream_err", "err": str(e)})
        finally:
            await _close_ws_quietly(ws)

    # Runs after the response even if the client disconnects before the
    # generator starts (in which case its finally block never executes).
    cleanup = BackgroundTasks()
    cleanup.add_task(_close_ws_quietly, ws)
    return StreamingResponse(gen(), media_type="audio/wav",
                             headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
                             background=cleanup)
# ---------- Optional: serve saved files later ----------
@app.get("/files/{name}")
def get_file(name: str):
    """Serve a regular file located directly or indirectly inside FILES_DIR.

    The requested path is resolved (following ``..`` and symlinks) and must
    stay inside FILES_DIR and be a regular file; anything else, including
    directories and names containing NUL bytes, returns a 404 JSON error.
    """
    if "\x00" in name:
        return JSONResponse({"ok": False, "error": "not found"}, status_code=404)
    root = os.path.realpath(FILES_DIR)
    path = os.path.realpath(os.path.join(root, name))
    # Containment check blocks traversal such as ".." or symlinks out of FILES_DIR.
    if not path.startswith(root + os.sep) or not os.path.isfile(path):
        return JSONResponse({"ok": False, "error": "not found"}, status_code=404)
    return FileResponse(path, media_type="application/octet-stream", filename=name)
if __name__ == "__main__":
    import uvicorn
    # Pass the app object instead of an import string: the module may not be
    # named "brain_app" (the file header names it app.py), and reload is off,
    # so an import string is not required. PORT env var overrides the default.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7861")), reload=False)