Update app.py
Browse files
app.py
CHANGED
|
@@ -1,20 +1,29 @@
|
|
| 1 |
-
import os, json, time, asyncio
|
| 2 |
-
from typing import AsyncGenerator, Dict, Any
|
| 3 |
-
from fastapi import FastAPI, Request, Response
|
| 4 |
-
from fastapi.responses import JSONResponse, StreamingResponse
|
| 5 |
|
| 6 |
-
# Directories
|
| 7 |
-
BASE_DIR
|
| 8 |
FILES_DIR = os.path.join(BASE_DIR, "files")
|
| 9 |
-
LOGS_DIR
|
| 10 |
EVENTS_FILE = os.path.join(LOGS_DIR, "events.jsonl")
|
| 11 |
|
| 12 |
for p in (BASE_DIR, FILES_DIR, LOGS_DIR):
|
| 13 |
os.makedirs(p, exist_ok=True)
|
| 14 |
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
-
#
|
|
|
|
|
|
|
|
|
|
| 18 |
log_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
|
| 19 |
|
| 20 |
def write_event(event: Dict[str, Any]) -> None:
|
|
@@ -23,12 +32,21 @@ def write_event(event: Dict[str, Any]) -> None:
|
|
| 23 |
os.makedirs(LOGS_DIR, exist_ok=True)
|
| 24 |
with open(EVENTS_FILE, "a", encoding="utf-8") as f:
|
| 25 |
f.write(json.dumps(event, ensure_ascii=False) + "\n")
|
| 26 |
-
# Put to queue without awaiting (called from sync context)
|
| 27 |
try:
|
| 28 |
log_queue.put_nowait(event)
|
| 29 |
except asyncio.QueueFull:
|
| 30 |
pass
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
@app.get("/health")
|
| 33 |
def health():
|
| 34 |
return {
|
|
@@ -37,6 +55,7 @@ def health():
|
|
| 37 |
"time": time.time(),
|
| 38 |
"files_dir": FILES_DIR,
|
| 39 |
"logs_dir": LOGS_DIR,
|
|
|
|
| 40 |
}
|
| 41 |
|
| 42 |
@app.post("/process")
|
|
@@ -51,11 +70,12 @@ async def process(req: Request):
|
|
| 51 |
write_event(event)
|
| 52 |
return {"ok": True, "received": payload}
|
| 53 |
|
|
|
|
| 54 |
@app.get("/stream/logs")
|
| 55 |
async def stream_logs() -> StreamingResponse:
|
| 56 |
"""Server-Sent Events stream of log events (one per line)."""
|
| 57 |
async def gen() -> AsyncGenerator[bytes, None]:
|
| 58 |
-
#
|
| 59 |
try:
|
| 60 |
if os.path.exists(EVENTS_FILE):
|
| 61 |
with open(EVENTS_FILE, "r", encoding="utf-8") as f:
|
|
@@ -64,7 +84,7 @@ async def stream_logs() -> StreamingResponse:
|
|
| 64 |
except Exception:
|
| 65 |
pass
|
| 66 |
|
| 67 |
-
#
|
| 68 |
while True:
|
| 69 |
event = await log_queue.get()
|
| 70 |
line = json.dumps(event, ensure_ascii=False)
|
|
@@ -82,4 +102,156 @@ async def log_error(req: Request):
|
|
| 82 |
return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)
|
| 83 |
event = {"type": "error", "data": payload}
|
| 84 |
write_event(event)
|
| 85 |
-
return {"ok": True}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, json, time, asyncio, base64
|
| 2 |
+
from typing import AsyncGenerator, Dict, Any, Optional
|
| 3 |
+
from fastapi import FastAPI, Request, Response, Query, BackgroundTasks
|
| 4 |
+
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
|
| 5 |
|
| 6 |
+
# ========== Directories ==========
|
| 7 |
+
# Root of all on-disk state; overridable through the BASE_DIR environment variable.
BASE_DIR = os.environ.get("BASE_DIR", "/tmp/brain_app")
# Files (including saved audio) live under BASE_DIR/files.
FILES_DIR = os.path.join(BASE_DIR, "files")
# Logs are nested inside the files directory.
LOGS_DIR = os.path.join(FILES_DIR, "logs")
# Append-only JSON-lines event log.
EVENTS_FILE = os.path.join(LOGS_DIR, "events.jsonl")

# Create the directory tree at import time, outermost directory first.
for p in (BASE_DIR, FILES_DIR, LOGS_DIR):
    os.makedirs(p, exist_ok=True)
|
| 14 |
|
| 15 |
+
# ========== TTS Config ==========
|
| 16 |
+
# Base URL of the upstream TTS Space. A trailing slash is stripped so that
# f"{TTS_BASE}/speak" never contains a double slash when the environment
# variable is given as "https://host/".
TTS_BASE = os.environ.get("TTS_BASE", "https://Percy3822-ActualTTS.hf.space").rstrip("/")
# Prosody baseline: length_scale = BASE_WPM / rate_wpm (clamped)
BASE_WPM = int(os.environ.get("BASE_WPM", "180"))
# Default synthesis noise parameters forwarded to the TTS Space.
NOISE_SCALE = float(os.environ.get("NOISE_SCALE", "0.33"))
NOISE_W = float(os.environ.get("NOISE_W", "0.92"))
# Voice id used when a request does not name one.
DEFAULT_VOICE = os.environ.get("DEFAULT_VOICE", "en_US-amy-medium")
|
| 22 |
|
| 23 |
+
# ========== App ==========
|
| 24 |
+
app = FastAPI(title="Brain Skeleton", version="1.1.0 (with TTS)")

# In-memory queue feeding /stream/logs clients.
# The queue is bounded so that events do not accumulate without limit while
# no client is consuming them; producers that use put_nowait() and ignore
# asyncio.QueueFull simply drop events once the bound is reached.
# NOTE(review): a single shared queue hands each event to only one consumer,
# so with several concurrent /stream/logs clients this is not a true fan-out.
LOG_QUEUE_MAXSIZE = 1000
log_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue(maxsize=LOG_QUEUE_MAXSIZE)
|
| 28 |
|
| 29 |
def write_event(event: Dict[str, Any]) -> None:
|
|
|
|
| 32 |
os.makedirs(LOGS_DIR, exist_ok=True)
|
| 33 |
with open(EVENTS_FILE, "a", encoding="utf-8") as f:
|
| 34 |
f.write(json.dumps(event, ensure_ascii=False) + "\n")
|
|
|
|
| 35 |
try:
|
| 36 |
log_queue.put_nowait(event)
|
| 37 |
except asyncio.QueueFull:
|
| 38 |
pass
|
| 39 |
|
| 40 |
+
def clamp_rate(rate_wpm: Optional[int]) -> int:
    """Return a usable words-per-minute value.

    Integer inputs are limited to the inclusive range 80..320; any
    non-integer input (including ``None``) yields ``BASE_WPM``.
    """
    if isinstance(rate_wpm, int):
        lowest, highest = 80, 320
        return min(highest, max(lowest, rate_wpm))
    return BASE_WPM
|
| 44 |
+
|
| 45 |
+
def rate_to_length_scale(rate_wpm: Optional[int]) -> float:
    """Convert a words-per-minute rate into a length_scale factor.

    The rate is first normalised by ``clamp_rate``; the result is
    ``BASE_WPM`` divided by that rate, rounded to three decimals.
    """
    effective_wpm = float(clamp_rate(rate_wpm))
    scale = BASE_WPM / effective_wpm
    return round(scale, 3)
|
| 48 |
+
|
| 49 |
+
# ========== Health & Basics ==========
|
| 50 |
@app.get("/health")
|
| 51 |
def health():
|
| 52 |
return {
|
|
|
|
| 55 |
"time": time.time(),
|
| 56 |
"files_dir": FILES_DIR,
|
| 57 |
"logs_dir": LOGS_DIR,
|
| 58 |
+
"tts_base": TTS_BASE,
|
| 59 |
}
|
| 60 |
|
| 61 |
@app.post("/process")
|
|
|
|
| 70 |
write_event(event)
|
| 71 |
return {"ok": True, "received": payload}
|
| 72 |
|
| 73 |
+
# ========== SSE Logs ==========
|
| 74 |
@app.get("/stream/logs")
|
| 75 |
async def stream_logs() -> StreamingResponse:
|
| 76 |
"""Server-Sent Events stream of log events (one per line)."""
|
| 77 |
async def gen() -> AsyncGenerator[bytes, None]:
|
| 78 |
+
# Send recent lines on connect (optional)
|
| 79 |
try:
|
| 80 |
if os.path.exists(EVENTS_FILE):
|
| 81 |
with open(EVENTS_FILE, "r", encoding="utf-8") as f:
|
|
|
|
| 84 |
except Exception:
|
| 85 |
pass
|
| 86 |
|
| 87 |
+
# Live stream
|
| 88 |
while True:
|
| 89 |
event = await log_queue.get()
|
| 90 |
line = json.dumps(event, ensure_ascii=False)
|
|
|
|
| 102 |
return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)
|
| 103 |
event = {"type": "error", "data": payload}
|
| 104 |
write_event(event)
|
| 105 |
+
return {"ok": True}
|
| 106 |
+
|
| 107 |
+
# ========== TTS: JSON (file URL) ==========
|
| 108 |
+
@app.post("/tts/say")
async def tts_say_json(req: Request):
    """
    POST JSON → call TTS /speak (JSON) and return audio_url.
    Body:
    {
      "text": "Hello world",
      "voice": "en_US-amy-medium",   # optional
      "rate_wpm": 165,               # optional (maps to length_scale)
      "length_scale": 1.05,          # optional (overrides rate_wpm)
      "noise_scale": 0.33,           # optional
      "noise_w": 0.92                # optional
    }

    Responses:
      200 {"ok": true, "audio_url": ..., "voice": ..., "length_scale": ...}
      400 when the body is not a JSON object, "text" is missing/blank, or a
          numeric field cannot be converted
      500 when the TTS service cannot be reached or reports a failure
    """
    try:
        body = await req.json()
    except Exception:
        return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)
    # Valid JSON that is not an object (list, string, number) has no .get().
    if not isinstance(body, dict):
        return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)

    raw_text = body.get("text")
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    if not text:
        return JSONResponse({"ok": False, "error": "Missing text"}, status_code=400)

    # A missing, non-string or blank voice falls back to the default voice.
    raw_voice = body.get("voice")
    voice = (raw_voice.strip() if isinstance(raw_voice, str) else "") or DEFAULT_VOICE

    # Explicit nulls are treated like absent keys; unconvertible values are a
    # client error rather than an unhandled exception.
    try:
        raw_ls = body.get("length_scale")
        if raw_ls is not None:
            length_scale = float(raw_ls)
        else:
            raw_rate = body.get("rate_wpm")
            length_scale = rate_to_length_scale(BASE_WPM if raw_rate is None else int(raw_rate))
        raw_ns = body.get("noise_scale")
        noise_scale = float(NOISE_SCALE if raw_ns is None else raw_ns)
        raw_nw = body.get("noise_w")
        noise_w = float(NOISE_W if raw_nw is None else raw_nw)
    except (TypeError, ValueError, OverflowError):
        return JSONResponse({"ok": False, "error": "Invalid numeric parameter"}, status_code=400)

    # Call TTS Space /speak (JSON)
    import httpx
    payload = {
        "text": text,
        "voice": voice,
        "length_scale": length_scale,
        "noise_scale": noise_scale,
        "noise_w": noise_w,
    }

    try:
        async with httpx.AsyncClient(timeout=180) as client:
            resp = await client.post(f"{TTS_BASE}/speak", json=payload)
    except httpx.HTTPError as exc:
        # Connection/timeout problems are logged and reported as JSON.
        write_event({"type": "tts_say_json", "data": {"text_len": len(text), "voice": voice, "ok": False, "error": str(exc)}})
        return JSONResponse({"ok": False, "error": f"TTS request failed: {exc}"}, status_code=500)

    ok = resp.status_code == 200
    try:
        parsed = resp.json()
    except ValueError:
        parsed = {}

    event = {"type": "tts_say_json", "data": {"text_len": len(text), "voice": voice, "ok": ok, "tts_resp": parsed}}
    write_event(event)

    # Only a JSON object can carry the fields read below.
    data = parsed if isinstance(parsed, dict) else {}
    audio_url = data.get("audio_url")
    if not ok or not data.get("ok") or not audio_url:
        error = data.get("error") or f"TTS error {resp.status_code}"
        return JSONResponse({"ok": False, "error": error}, status_code=500)

    # Return TTS audio_url directly
    return {"ok": True, "audio_url": audio_url, "voice": voice, "length_scale": length_scale}
|
| 163 |
+
|
| 164 |
+
# ========== TTS: Direct WAV Proxy ==========
|
| 165 |
+
async def _proxy_tts_wav_stream(
    text: str,
    voice: str,
    length_scale: float,
    noise_scale: float,
    noise_w: float,
    save_local: bool = False
) -> Response:
    """
    GET TTS /speak.wav and stream the WAV to the caller.
    If save_local is True, also tee to a local file under FILES_DIR.

    The upstream request is opened before the response object is built, so a
    failing upstream is reported with an error status instead of being sent
    as a 200 "audio/wav" body:
      - connection/timeout errors  -> 502 JSON error
      - upstream status != 200     -> upstream body, with the upstream status
                                      when it is >= 400, otherwise 502
    On success a StreamingResponse with media type audio/wav is returned.
    """
    import httpx
    params = {
        "text": text,
        "voice": voice,
        "length_scale": f"{length_scale:.3f}",
        "noise_scale": f"{noise_scale:.3f}",
        "noise_w": f"{noise_w:.3f}",
    }
    ts = int(time.time() * 1000)
    local_path = os.path.join(FILES_DIR, f"say-{ts}.wav") if save_local else None

    client = httpx.AsyncClient(timeout=None)
    try:
        upstream = await client.send(
            client.build_request("GET", f"{TTS_BASE}/speak.wav", params=params),
            stream=True,
        )
    except httpx.HTTPError as exc:
        await client.aclose()
        return JSONResponse({"ok": False, "error": f"TTS request failed: {exc}"}, status_code=502)

    if upstream.status_code != 200:
        # bubble the error body from TTS, keeping an error status
        try:
            err_body = await upstream.aread()
        except httpx.HTTPError:
            err_body = b""
        finally:
            await upstream.aclose()
            await client.aclose()
        status = upstream.status_code if upstream.status_code >= 400 else 502
        return Response(
            content=err_body,
            status_code=status,
            media_type=upstream.headers.get("content-type", "application/json"),
        )

    async def gen():
        # stream body and optionally save
        sink = None
        try:
            if local_path:
                sink = open(local_path, "wb")
            async for chunk in upstream.aiter_bytes():
                if chunk:
                    if sink:
                        sink.write(chunk)
                    yield chunk
        finally:
            # Release the file and the upstream connection however the stream ends.
            if sink:
                sink.close()
            await upstream.aclose()
            await client.aclose()

    headers = {"Cache-Control": "no-cache"}
    if local_path:
        # NOTE(review): this exposes the server-side absolute path to the client.
        headers["X-Local-Path"] = local_path
    return StreamingResponse(gen(), media_type="audio/wav", headers=headers)
|
| 212 |
+
|
| 213 |
+
@app.get("/tts/say.wav")
async def tts_say_wav_get(
    text: str = Query(..., description="Text to synthesize"),
    voice: str = Query(DEFAULT_VOICE, description="Voice id from the TTS Space"),
    rate_wpm: Optional[int] = Query(None, description="Words-per-minute; maps to length_scale"),
    length_scale: Optional[float] = Query(None, description="Override prosody (else derived from rate_wpm)"),
    noise_scale: float = Query(NOISE_SCALE),
    noise_w: float = Query(NOISE_W),
    save: bool = Query(False, description="Also save under /files")
):
    """
    GET query parameters → stream back audio/wav.

    "text" is stripped and must not be blank (400 otherwise), matching the
    POST variant of this endpoint; a blank "voice" falls back to
    DEFAULT_VOICE. "length_scale" wins over "rate_wpm" when both are given.
    """
    text = text.strip()
    if not text:
        return JSONResponse({"ok": False, "error": "Missing text"}, status_code=400)
    voice = voice.strip() or DEFAULT_VOICE

    if length_scale is not None:
        ls = float(length_scale)
    else:
        ls = rate_to_length_scale(rate_wpm if rate_wpm is not None else BASE_WPM)

    write_event({"type": "tts_say_wav_get", "data": {"len": len(text), "voice": voice, "ls": ls, "save": save}})
    return await _proxy_tts_wav_stream(text, voice, ls, noise_scale, noise_w, save_local=save)
|
| 227 |
+
|
| 228 |
+
@app.post("/tts/say.wav")
async def tts_say_wav_post(req: Request, save: bool = Query(False, description="Also save under /files")):
    """
    POST JSON → stream back audio/wav
    { "text": "...", "voice": "en_US-amy-medium", "rate_wpm": 165 }

    Optional "length_scale" (overrides "rate_wpm"), "noise_scale" and
    "noise_w" are accepted as well. Returns 400 when the body is not a JSON
    object, "text" is missing/blank, or a numeric field cannot be converted.
    """
    try:
        body = await req.json()
    except Exception:
        return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)
    # Valid JSON that is not an object (list, string, number) has no .get().
    if not isinstance(body, dict):
        return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)

    raw_text = body.get("text")
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    if not text:
        return JSONResponse({"ok": False, "error": "Missing text"}, status_code=400)

    # A missing, non-string or blank voice falls back to the default voice.
    raw_voice = body.get("voice")
    voice = (raw_voice.strip() if isinstance(raw_voice, str) else "") or DEFAULT_VOICE

    # Explicit nulls are treated like absent keys; unconvertible values are a
    # client error rather than an unhandled exception.
    try:
        raw_ls = body.get("length_scale")
        if raw_ls is not None:
            ls = float(raw_ls)
        else:
            raw_rate = body.get("rate_wpm")
            ls = rate_to_length_scale(BASE_WPM if raw_rate is None else int(raw_rate))
        raw_ns = body.get("noise_scale")
        ns = float(NOISE_SCALE if raw_ns is None else raw_ns)
        raw_nw = body.get("noise_w")
        nw = float(NOISE_W if raw_nw is None else raw_nw)
    except (TypeError, ValueError, OverflowError):
        return JSONResponse({"ok": False, "error": "Invalid numeric parameter"}, status_code=400)

    write_event({"type": "tts_say_wav_post", "data": {"len": len(text), "voice": voice, "ls": ls, "save": save}})
    return await _proxy_tts_wav_stream(text, voice, ls, ns, nw, save_local=save)
|
| 250 |
+
|
| 251 |
+
# ========== Serve saved files (if you used save=true) ==========
|
| 252 |
+
@app.get("/files/{name}")
def get_saved_file(name: str):
    """
    Serve a file stored directly under FILES_DIR as audio/wav.

    Returns a 404 JSON error when ``name`` is not a plain file name or does
    not refer to an existing regular file. Directories (for example the
    ``logs`` directory that lives under FILES_DIR) are treated as not found
    instead of being handed to FileResponse.
    """
    not_found = JSONResponse({"ok": False, "error": "not found"}, status_code=404)

    # Only plain file names are accepted: no directory components and no
    # "." / ".." entries that would resolve outside the intended file set.
    if name in ("", ".", "..") or os.path.basename(name) != name:
        return not_found

    path = os.path.join(FILES_DIR, name)
    if not os.path.isfile(path):
        return not_found
    return FileResponse(path, media_type="audio/wav", filename=name)
|