|
|
|
|
|
|
|
|
import os, json, time, asyncio, tempfile |
|
|
from typing import AsyncGenerator, Dict, Any, Optional |
|
|
from fastapi import FastAPI, Request, Query, UploadFile |
|
|
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse |
|
|
|
|
|
import httpx |
|
|
import websockets |
|
|
|
|
|
|
|
|
# Local storage layout; the root can be relocated with the BASE_DIR env var.
BASE_DIR = os.getenv("BASE_DIR", "/tmp/brain_app")
FILES_DIR = os.path.join(BASE_DIR, "files")
LOGS_DIR = os.path.join(FILES_DIR, "logs")
EVENTS_FILE = os.path.join(LOGS_DIR, "events.jsonl")

# Create the whole directory tree at import time (idempotent).
for _storage_dir in (BASE_DIR, FILES_DIR, LOGS_DIR):
    os.makedirs(_storage_dir, exist_ok=True)
|
|
|
|
|
|
|
|
# Upstream speech services (Hugging Face Space URLs unless overridden).
TTS_BASE = os.getenv("TTS_BASE", "https://Percy3822-ActualTTS.hf.space")
STT_BASE = os.getenv("STT_BASE", "https://Percy3822-ActualSTT.hf.space")

# Synthesis defaults, each overridable through the environment.
DEFAULT_VOICE = os.getenv("DEFAULT_VOICE", "en_US-amy-medium")
BASE_WPM = int(os.getenv("BASE_WPM", "165"))          # speaking rate that maps to length_scale 1.0
NOISE_SCALE = float(os.getenv("NOISE_SCALE", "0.33"))
NOISE_W = float(os.getenv("NOISE_W", "0.92"))
|
|
|
|
|
app = FastAPI(title="Brain Space (STT→TTS coordinator)", version="3.1.0")

# Live event feed consumed by /stream/logs. Bounded so that events emitted while
# no SSE client is connected cannot grow memory without limit (write_event
# handles asyncio.QueueFull). A value <= 0 makes the queue unbounded again.
# NOTE(review): this is one shared queue, so with several concurrent
# /stream/logs clients each event reaches only one of them.
LOG_QUEUE_MAXSIZE = int(os.environ.get("LOG_QUEUE_MAXSIZE", "1000"))
log_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue(maxsize=LOG_QUEUE_MAXSIZE)
|
|
|
|
|
def write_event(event: Dict[str, Any]) -> None:
    """Append *event* to the JSONL event log and publish it to ``log_queue``.

    A ``ts`` field (epoch seconds from ``time.time()``) is added when absent;
    note that this mutates the caller's dict.

    Logging is best-effort: a failed file write or a full queue does not raise
    into the request handler that emitted the event.
    """
    event.setdefault("ts", time.time())
    # default=str: an unexpected value type must not break request handling.
    line = json.dumps(event, ensure_ascii=False, default=str)
    try:
        with open(EVENTS_FILE, "a", encoding="utf-8") as f:
            f.write(line + "\n")
    except OSError:
        # Disk/permission problems should not fail the TTS request; the event
        # is still published to live subscribers below.
        pass
    try:
        log_queue.put_nowait(event)
    except asyncio.QueueFull:
        # Drop the oldest queued event so live viewers see the most recent ones.
        try:
            log_queue.get_nowait()
            log_queue.put_nowait(event)
        except (asyncio.QueueEmpty, asyncio.QueueFull):
            pass
|
|
|
|
|
def rate_to_length_scale(rate_wpm: Optional[int]) -> float:
    """Map a speaking rate in words per minute to a TTS ``length_scale``.

    The result is ``BASE_WPM / rate`` rounded to 3 decimals, so faster speech
    gives a smaller scale. The rate is clamped to the range [80, 320] first.
    Anything that is not an ``int`` (including ``None``) yields 1.0.
    """
    if isinstance(rate_wpm, int):
        clamped = min(max(rate_wpm, 80), 320)
        return round(BASE_WPM / float(clamped), 3)
    return 1.0
|
|
|
|
|
def _tts_ws_url() -> str: |
|
|
""" |
|
|
Build the TTS WebSocket URL from TTS_BASE. |
|
|
e.g. https://Percy3822-ActualTTS.hf.space -> wss://Percy3822-ActualTTS.hf.space/ws/tts |
|
|
""" |
|
|
base = (TTS_BASE or "").rstrip("/") |
|
|
if base.startswith("https://"): |
|
|
return "wss://" + base[len("https://"):] + "/ws/tts" |
|
|
if base.startswith("http://"): |
|
|
return "ws://" + base[len("http://"):] + "/ws/tts" |
|
|
return (base + "/ws/tts") if not base.endswith("/ws/tts") else base |
|
|
|
|
|
def _wav_header(sr: int, ch: int) -> bytes: |
|
|
"""Minimal PCM16 WAV header with large data size for streaming.""" |
|
|
bits = 16 |
|
|
byte_rate = sr * ch * (bits // 8) |
|
|
block_align = ch * (bits // 8) |
|
|
data_size = 0x7FFFFFFF |
|
|
riff_size = (36 + data_size) & 0xFFFFFFFF |
|
|
return ( |
|
|
b"RIFF" + |
|
|
riff_size.to_bytes(4, "little") + |
|
|
b"WAVE" + |
|
|
b"fmt " + (16).to_bytes(4, "little") + |
|
|
(1).to_bytes(2, "little") + |
|
|
(ch).to_bytes(2, "little") + |
|
|
(sr).to_bytes(4, "little") + |
|
|
(byte_rate).to_bytes(4, "little") + |
|
|
(block_align).to_bytes(2, "little") + |
|
|
(bits).to_bytes(2, "little") + |
|
|
b"data" + data_size.to_bytes(4, "little") |
|
|
) |
|
|
|
|
|
|
|
|
@app.get("/health")
def health():
    """Liveness probe: report service identity, current time and effective config."""
    defaults = {"voice": DEFAULT_VOICE, "rate_wpm": BASE_WPM}
    return dict(
        ok=True,
        service="brain-space",
        time=time.time(),
        files_dir=FILES_DIR,
        tts_base=TTS_BASE,
        stt_base=STT_BASE,
        defaults=defaults,
    )
|
|
|
|
|
|
|
|
@app.get("/stream/logs")
async def stream_logs() -> StreamingResponse:
    """Server-Sent Events feed of brain events.

    First replays up to the last 50 lines of the JSONL event log, then forwards
    live events from ``log_queue`` as they arrive, one ``data:`` frame each.
    """
    from collections import deque

    async def gen() -> AsyncGenerator[bytes, None]:
        try:
            with open(EVENTS_FILE, "r", encoding="utf-8") as f:
                # Keep only the tail while iterating, instead of materializing
                # the whole (ever-growing) log with readlines().
                tail = deque(f, maxlen=50)
        except (OSError, ValueError):
            # Missing/unreadable log or undecodable bytes: skip the replay, keep live feed.
            tail = ()
        for line in tail:
            payload = line.rstrip("\r\n")
            if payload:  # blank lines would become empty SSE events
                yield b"data: " + payload.encode("utf-8") + b"\n\n"

        while True:
            event = await log_queue.get()
            data = json.dumps(event, ensure_ascii=False, default=str)
            yield b"data: " + data.encode("utf-8") + b"\n\n"

    return StreamingResponse(gen(), media_type="text/event-stream",
                             headers={"Cache-Control": "no-cache", "Connection": "keep-alive"})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _proxy_tts_wav_stream(text: str, voice: str, rate_wpm: Optional[int],
                                noise_scale: float, noise_w: float) -> "StreamingResponse | JSONResponse":
    """Proxy a synthesis request to the TTS service's ``/speak.wav`` and stream the WAV back.

    The upstream status is checked *before* the response is returned. An
    upstream failure therefore becomes a JSON error (4xx passed through, other
    failures as 502). Without this check the client would get HTTP 200
    ``audio/wav`` with an error message as the body.

    Args:
        text: text to synthesize.
        voice: TTS voice id.
        rate_wpm: speaking rate; ``None`` means ``BASE_WPM``.
        noise_scale, noise_w: forwarded to the TTS service (3 decimals).
    """
    length_scale = rate_to_length_scale(BASE_WPM if rate_wpm is None else rate_wpm)
    params = {
        "text": text,
        "voice": voice,
        "length_scale": f"{length_scale:.3f}",
        "noise_scale": f"{noise_scale:.3f}",
        "noise_w": f"{noise_w:.3f}",
    }
    url = f"{(TTS_BASE or '').rstrip('/')}/speak.wav"

    # No read timeout (long texts can take a while to synthesize), but do not
    # hang forever trying to connect.
    client = httpx.AsyncClient(timeout=httpx.Timeout(None, connect=30.0))
    try:
        resp = await client.send(client.build_request("GET", url, params=params), stream=True)
    except httpx.HTTPError as e:
        await client.aclose()
        write_event({"type": "tts_upstream_err", "err": str(e)})
        return JSONResponse({"ok": False, "error": f"TTS upstream unreachable: {e}"}, status_code=502)

    if resp.status_code != 200:
        upstream_status = resp.status_code
        try:
            detail = (await resp.aread()).decode("utf-8", errors="replace")
        except httpx.HTTPError:
            detail = ""
        finally:
            await resp.aclose()
            await client.aclose()
        status = upstream_status if 400 <= upstream_status < 500 else 502
        return JSONResponse(
            {"ok": False, "error": "TTS upstream error",
             "upstream_status": upstream_status, "detail": detail[:2000]},
            status_code=status,
        )

    async def gen() -> AsyncGenerator[bytes, None]:
        # The client/response were opened outside the generator, so close them here.
        try:
            async for chunk in resp.aiter_bytes():
                if chunk:
                    yield chunk
        finally:
            await resp.aclose()
            await client.aclose()

    return StreamingResponse(gen(), media_type="audio/wav", headers={"Cache-Control": "no-cache"})
|
|
|
|
|
@app.get("/tts/say.wav")
async def tts_say_wav_get(
    text: str = Query(..., description="Text to synthesize"),
    voice: str = Query(DEFAULT_VOICE),
    rate_wpm: Optional[int] = Query(BASE_WPM),
    noise_scale: float = Query(NOISE_SCALE),
    noise_w: float = Query(NOISE_W),
):
    """Synthesize *text* via the TTS service and stream the WAV back (GET form).

    Like the POST variant, surrounding whitespace is stripped and empty text is
    rejected with 400 instead of being forwarded upstream.
    """
    text = text.strip()
    if not text:
        return JSONResponse({"ok": False, "error": "Missing text"}, status_code=400)
    write_event({"type":"tts_get","len":len(text),"voice":voice,"rate_wpm":rate_wpm})
    return await _proxy_tts_wav_stream(text, voice, rate_wpm, noise_scale, noise_w)
|
|
|
|
|
@app.post("/tts/say.wav")
async def tts_say_wav_post(req: Request):
    """Synthesize speech from a JSON body and stream the WAV back (POST form).

    Expected body: ``{"text": str, "voice"?: str, "rate_wpm"?: int,
    "noise_scale"?: float, "noise_w"?: float}``. Missing or ``null`` optional
    fields fall back to the service defaults. Malformed input returns 400
    instead of surfacing as a 500.
    """
    try:
        body = await req.json()
    except Exception:
        return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)
    if not isinstance(body, dict):
        return JSONResponse({"ok": False, "error": "JSON body must be an object"}, status_code=400)

    raw_text = body.get("text")
    if raw_text is not None and not isinstance(raw_text, str):
        return JSONResponse({"ok": False, "error": "text must be a string"}, status_code=400)
    text = (raw_text or "").strip()
    if not text:
        return JSONResponse({"ok": False, "error": "Missing text"}, status_code=400)

    raw_voice = body.get("voice")
    if raw_voice is not None and not isinstance(raw_voice, str):
        return JSONResponse({"ok": False, "error": "voice must be a string"}, status_code=400)
    voice = (raw_voice or "").strip() or DEFAULT_VOICE

    try:
        raw_rate = body.get("rate_wpm")
        rate_wpm = BASE_WPM if raw_rate is None else int(raw_rate)
        raw_ns = body.get("noise_scale")
        noise_s = NOISE_SCALE if raw_ns is None else float(raw_ns)
        raw_nw = body.get("noise_w")
        noise_wgt = NOISE_W if raw_nw is None else float(raw_nw)
    except (TypeError, ValueError):
        return JSONResponse(
            {"ok": False, "error": "rate_wpm, noise_scale and noise_w must be numeric"},
            status_code=400,
        )

    write_event({"type":"tts_post","len":len(text),"voice":voice,"rate_wpm":rate_wpm})
    return await _proxy_tts_wav_stream(text, voice, rate_wpm, noise_s, noise_wgt)
|
|
|
|
|
|
|
|
|
|
|
async def _tts_ws_handshake(ws, voice: str, length_scale: float,
                            noise_scale: float, noise_w: float) -> "tuple[int, int]":
    """Send the TTS ``init`` message and wait for the server's ``ready`` event.

    Returns ``(sample_rate, channels)`` taken from the ``ready`` event,
    defaulting to 22050 Hz / 1 channel when those fields are absent. Binary
    frames, non-JSON text and non-object JSON received before ``ready`` are
    ignored.

    Raises:
        RuntimeError: the server sent an ``error`` event during init.
        websockets.exceptions.ConnectionClosed: the socket closed before ``ready``.
    """
    await ws.send(json.dumps({
        "event": "init",
        "voice": voice,
        "length_scale": length_scale,
        "noise_scale": noise_scale,
        "noise_w": noise_w,
    }))
    while True:
        frame = await ws.recv()
        if isinstance(frame, (bytes, bytearray)):
            continue
        try:
            evt = json.loads(frame)
        except ValueError:
            continue
        if not isinstance(evt, dict):
            continue
        kind = evt.get("event")
        if kind == "ready":
            return int(evt.get("sr", 22050)), int(evt.get("channels", 1))
        if kind == "error":
            raise RuntimeError(str(evt.get("detail", "tts init error")))


@app.get("/tts/say.stream.wav")
async def tts_say_stream_wav(
    text: str = Query(..., description="Text to synthesize (live)"),
    voice: str = Query(DEFAULT_VOICE),
    rate_wpm: Optional[int] = Query(BASE_WPM),
    length_scale: Optional[float] = Query(None),
    noise_scale: float = Query(NOISE_SCALE),
    noise_w: float = Query(NOISE_W),
):
    """
    LIVE streaming proxy: TTS WS (raw PCM16) -> HTTP chunked WAV.
    Starts emitting audio as soon as the TTS starts producing frames.

    ``length_scale`` overrides ``rate_wpm`` when given. The WebSocket connect,
    init handshake and ``speak`` request all happen *before* the HTTP response
    starts. Any failure there returns a 502 JSON error rather than a 200
    ``audio/wav`` carrying error text. Errors after audio has started are logged
    and end the stream; they are not written into the PCM payload, where they
    would play as noise.
    """
    ls = float(length_scale) if length_scale is not None else rate_to_length_scale(rate_wpm or BASE_WPM)
    write_event({"type":"tts_stream_get","len":len(text),"voice":voice,"ls":ls})

    ws = None
    try:
        ws = await websockets.connect(_tts_ws_url(), ping_interval=None, max_size=8_000_000)
        sr, ch = await _tts_ws_handshake(ws, voice, ls, noise_scale, noise_w)
        await ws.send(json.dumps({"event": "speak", "text": text}))
    except Exception as e:
        write_event({"type":"tts_stream_err","err":str(e)})
        if ws is not None:
            try:
                await ws.close()
            except Exception:
                pass
        return JSONResponse({"ok": False, "error": f"TTS stream init failed: {e}"}, status_code=502)

    async def gen() -> AsyncGenerator[bytes, None]:
        try:
            yield _wav_header(sr, ch)
            while True:
                try:
                    msg = await ws.recv()
                except websockets.exceptions.ConnectionClosed:
                    break
                if isinstance(msg, (bytes, bytearray)):
                    if msg:
                        # StreamingResponse accepts bytes/memoryview, not bytearray.
                        yield bytes(msg)
                    continue
                try:
                    evt = json.loads(msg)
                except ValueError:
                    continue
                if not isinstance(evt, dict):
                    continue
                kind = evt.get("event")
                if kind in ("done", "end"):
                    break
                if kind == "error":
                    write_event({"type":"tts_stream_err","err":str(evt.get("detail", "tts error"))})
                    break
        except Exception as e:
            write_event({"type":"tts_stream_err","err":str(e)})
        finally:
            # Runs on normal end, error, or client disconnect (generator closed).
            try:
                await ws.close()
            except Exception:
                pass

    return StreamingResponse(gen(), media_type="audio/wav",
                             headers={"Cache-Control":"no-cache","Connection":"keep-alive"})
|
|
|
|
|
|
|
|
@app.get("/files/{name}")
def get_file(name: str):
    """Download a regular file stored directly under ``FILES_DIR``.

    Returns 404 JSON for anything else. This includes names that resolve
    outside ``FILES_DIR`` (``..``, backslash separators on Windows, symlinks
    pointing elsewhere) and names that resolve to directories such as ``logs``.
    Without these checks, such names would crash ``FileResponse`` or leak files.
    """
    root = os.path.realpath(FILES_DIR)
    path = os.path.realpath(os.path.join(root, name))
    if os.path.dirname(path) != root or not os.path.isfile(path):
        return JSONResponse({"ok": False, "error": "not found"}, status_code=404)
    return FileResponse(path, media_type="application/octet-stream", filename=name)
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Run the already-imported app object. A "module:attr" import string would
    # depend on this file's name, and uvicorn would re-import the module,
    # creating a second app and log queue. PORT may override the default port.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7861")), reload=False)