Percy3822 commited on
Commit
3ac22d2
·
verified ·
1 Parent(s): 00a7c4e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +185 -13
app.py CHANGED
@@ -1,20 +1,29 @@
1
- import os, json, time, asyncio
2
- from typing import AsyncGenerator, Dict, Any
3
- from fastapi import FastAPI, Request, Response
4
- from fastapi.responses import JSONResponse, StreamingResponse
5
 
6
- # Directories (HF Spaces writable path)
7
- BASE_DIR = os.environ.get("BASE_DIR", "/tmp/brain_app")
8
  FILES_DIR = os.path.join(BASE_DIR, "files")
9
- LOGS_DIR = os.path.join(FILES_DIR, "logs")
10
  EVENTS_FILE = os.path.join(LOGS_DIR, "events.jsonl")
11
 
12
  for p in (BASE_DIR, FILES_DIR, LOGS_DIR):
13
  os.makedirs(p, exist_ok=True)
14
 
15
- app = FastAPI(title="Brain Skeleton", version="1.0.0")
 
 
 
 
 
 
16
 
17
- # Simple in-memory queue to fan-out logs to /stream/logs clients
 
 
 
18
  log_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
19
 
20
  def write_event(event: Dict[str, Any]) -> None:
@@ -23,12 +32,21 @@ def write_event(event: Dict[str, Any]) -> None:
23
  os.makedirs(LOGS_DIR, exist_ok=True)
24
  with open(EVENTS_FILE, "a", encoding="utf-8") as f:
25
  f.write(json.dumps(event, ensure_ascii=False) + "\n")
26
- # Put to queue without awaiting (called from sync context)
27
  try:
28
  log_queue.put_nowait(event)
29
  except asyncio.QueueFull:
30
  pass
31
 
 
 
 
 
 
 
 
 
 
 
32
  @app.get("/health")
33
  def health():
34
  return {
@@ -37,6 +55,7 @@ def health():
37
  "time": time.time(),
38
  "files_dir": FILES_DIR,
39
  "logs_dir": LOGS_DIR,
 
40
  }
41
 
42
  @app.post("/process")
@@ -51,11 +70,12 @@ async def process(req: Request):
51
  write_event(event)
52
  return {"ok": True, "received": payload}
53
 
 
54
  @app.get("/stream/logs")
55
  async def stream_logs() -> StreamingResponse:
56
  """Server-Sent Events stream of log events (one per line)."""
57
  async def gen() -> AsyncGenerator[bytes, None]:
58
- # On connect, tail recent file lines so client sees immediate data (optional)
59
  try:
60
  if os.path.exists(EVENTS_FILE):
61
  with open(EVENTS_FILE, "r", encoding="utf-8") as f:
@@ -64,7 +84,7 @@ async def stream_logs() -> StreamingResponse:
64
  except Exception:
65
  pass
66
 
67
- # Now live stream
68
  while True:
69
  event = await log_queue.get()
70
  line = json.dumps(event, ensure_ascii=False)
@@ -82,4 +102,156 @@ async def log_error(req: Request):
82
  return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)
83
  event = {"type": "error", "data": payload}
84
  write_event(event)
85
- return {"ok": True}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, json, time, asyncio, base64
2
+ from typing import AsyncGenerator, Dict, Any, Optional
3
+ from fastapi import FastAPI, Request, Response, Query, BackgroundTasks
4
+ from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
5
 
6
+ # ========== Directories ==========
7
+ BASE_DIR = os.environ.get("BASE_DIR", "/tmp/brain_app")
8
  FILES_DIR = os.path.join(BASE_DIR, "files")
9
+ LOGS_DIR = os.path.join(FILES_DIR, "logs")
10
  EVENTS_FILE = os.path.join(LOGS_DIR, "events.jsonl")
11
 
12
  for p in (BASE_DIR, FILES_DIR, LOGS_DIR):
13
  os.makedirs(p, exist_ok=True)
14
 
15
+ # ========== TTS Config ==========
16
+ TTS_BASE = os.environ.get("TTS_BASE", "https://Percy3822-ActualTTS.hf.space") # your Space
17
+ # prosody baseline: length_scale = BASE_WPM / rate_wpm (clamped)
18
+ BASE_WPM = int(os.environ.get("BASE_WPM", "180"))
19
+ NOISE_SCALE = float(os.environ.get("NOISE_SCALE", "0.33"))
20
+ NOISE_W = float(os.environ.get("NOISE_W", "0.92"))
21
+ DEFAULT_VOICE = os.environ.get("DEFAULT_VOICE", "en_US-amy-medium")
22
 
23
+ # ========== App ==========
24
+ app = FastAPI(title="Brain Skeleton", version="1.1.0 (with TTS)")
25
+
26
+ # In-memory queue to fan-out logs to /stream/logs clients
27
  log_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
28
 
29
  def write_event(event: Dict[str, Any]) -> None:
 
32
  os.makedirs(LOGS_DIR, exist_ok=True)
33
  with open(EVENTS_FILE, "a", encoding="utf-8") as f:
34
  f.write(json.dumps(event, ensure_ascii=False) + "\n")
 
35
  try:
36
  log_queue.put_nowait(event)
37
  except asyncio.QueueFull:
38
  pass
39
 
40
+ def clamp_rate(rate_wpm: Optional[int]) -> int:
41
+ if not isinstance(rate_wpm, int):
42
+ return BASE_WPM
43
+ return max(80, min(320, rate_wpm))
44
+
45
+ def rate_to_length_scale(rate_wpm: Optional[int]) -> float:
46
+ r = clamp_rate(rate_wpm)
47
+ return round(BASE_WPM / float(r), 3)
48
+
49
+ # ========== Health & Basics ==========
50
  @app.get("/health")
51
  def health():
52
  return {
 
55
  "time": time.time(),
56
  "files_dir": FILES_DIR,
57
  "logs_dir": LOGS_DIR,
58
+ "tts_base": TTS_BASE,
59
  }
60
 
61
  @app.post("/process")
 
70
  write_event(event)
71
  return {"ok": True, "received": payload}
72
 
73
+ # ========== SSE Logs ==========
74
  @app.get("/stream/logs")
75
  async def stream_logs() -> StreamingResponse:
76
  """Server-Sent Events stream of log events (one per line)."""
77
  async def gen() -> AsyncGenerator[bytes, None]:
78
+ # Send recent lines on connect (optional)
79
  try:
80
  if os.path.exists(EVENTS_FILE):
81
  with open(EVENTS_FILE, "r", encoding="utf-8") as f:
 
84
  except Exception:
85
  pass
86
 
87
+ # Live stream
88
  while True:
89
  event = await log_queue.get()
90
  line = json.dumps(event, ensure_ascii=False)
 
102
  return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)
103
  event = {"type": "error", "data": payload}
104
  write_event(event)
105
+ return {"ok": True}
106
+
107
+ # ========== TTS: JSON (file URL) ==========
108
+ @app.post("/tts/say")
109
+ async def tts_say_json(req: Request):
110
+ """
111
+ POST JSON → call TTS /speak (JSON) and return audio_url.
112
+ Body:
113
+ {
114
+ "text": "Hello world",
115
+ "voice": "en_US-amy-medium", # optional
116
+ "rate_wpm": 165, # optional (maps to length_scale)
117
+ "length_scale": 1.05, # optional (overrides rate_wpm)
118
+ "noise_scale": 0.33, # optional
119
+ "noise_w": 0.92 # optional
120
+ }
121
+ """
122
+ try:
123
+ body = await req.json()
124
+ except Exception:
125
+ return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)
126
+
127
+ text = (body.get("text") or "").strip()
128
+ if not text:
129
+ return JSONResponse({"ok": False, "error": "Missing text"}, status_code=400)
130
+
131
+ voice = (body.get("voice") or DEFAULT_VOICE).strip()
132
+ length_scale = float(body["length_scale"]) if "length_scale" in body else rate_to_length_scale(int(body.get("rate_wpm", BASE_WPM)))
133
+ noise_scale = float(body.get("noise_scale", NOISE_SCALE))
134
+ noise_w = float(body.get("noise_w", NOISE_W))
135
+
136
+ # Call TTS Space /speak (JSON)
137
+ import httpx
138
+ payload = {
139
+ "text": text,
140
+ "voice": voice,
141
+ "length_scale": length_scale,
142
+ "noise_scale": noise_scale,
143
+ "noise_w": noise_w,
144
+ }
145
+
146
+ async with httpx.AsyncClient(timeout=180) as client:
147
+ resp = await client.post(f"{TTS_BASE}/speak", json=payload)
148
+ ok = resp.status_code == 200
149
+ data = {}
150
+ try:
151
+ data = resp.json()
152
+ except Exception:
153
+ pass
154
+
155
+ event = {"type": "tts_say_json", "data": {"text_len": len(text), "voice": voice, "ok": ok, "tts_resp": data}}
156
+ write_event(event)
157
+
158
+ if not ok or not data.get("ok"):
159
+ return JSONResponse({"ok": False, "error": data.get("error") if data else f"TTS error {resp.status_code}"}, status_code=500)
160
+
161
+ # Return TTS audio_url directly
162
+ return {"ok": True, "audio_url": data["audio_url"], "voice": voice, "length_scale": length_scale}
163
+
164
+ # ========== TTS: Direct WAV Proxy ==========
165
+ async def _proxy_tts_wav_stream(
166
+ text: str,
167
+ voice: str,
168
+ length_scale: float,
169
+ noise_scale: float,
170
+ noise_w: float,
171
+ save_local: bool = False
172
+ ) -> StreamingResponse:
173
+ """
174
+ GET TTS /speak.wav and stream the WAV to the caller.
175
+ If save_local is True, also tee to a local file under FILES_DIR.
176
+ """
177
+ import httpx
178
+ params = {
179
+ "text": text,
180
+ "voice": voice,
181
+ "length_scale": f"{length_scale:.3f}",
182
+ "noise_scale": f"{noise_scale:.3f}",
183
+ "noise_w": f"{noise_w:.3f}",
184
+ }
185
+ ts = int(time.time() * 1000)
186
+ local_path = os.path.join(FILES_DIR, f"say-{ts}.wav") if save_local else None
187
+
188
+ async def gen():
189
+ async with httpx.AsyncClient(timeout=None) as client:
190
+ async with client.stream("GET", f"{TTS_BASE}/speak.wav", params=params) as resp:
191
+ if resp.status_code != 200:
192
+ # bubble JSON error from TTS if any
193
+ err_body = await resp.aread()
194
+ yield err_body # still return something; caller will see non-wav
195
+ return
196
+ # stream body and optionally save
197
+ f = None
198
+ try:
199
+ if local_path:
200
+ f = open(local_path, "wb")
201
+ async for chunk in resp.aiter_bytes():
202
+ if chunk:
203
+ if f: f.write(chunk)
204
+ yield chunk
205
+ finally:
206
+ if f: f.close()
207
+
208
+ headers = {"Cache-Control": "no-cache"}
209
+ if local_path:
210
+ headers["X-Local-Path"] = local_path
211
+ return StreamingResponse(gen(), media_type="audio/wav", headers=headers)
212
+
213
+ @app.get("/tts/say.wav")
214
+ async def tts_say_wav_get(
215
+ text: str = Query(..., description="Text to synthesize"),
216
+ voice: str = Query(DEFAULT_VOICE, description="Voice id from the TTS Space"),
217
+ rate_wpm: Optional[int] = Query(None, description="Words-per-minute; maps to length_scale"),
218
+ length_scale: Optional[float] = Query(None, description="Override prosody (else derived from rate_wpm)"),
219
+ noise_scale: float = Query(NOISE_SCALE),
220
+ noise_w: float = Query(NOISE_W),
221
+ save: bool = Query(False, description="Also save under /files")
222
+ ):
223
+ ls = float(length_scale) if length_scale is not None else rate_to_length_scale(rate_wpm if rate_wpm is not None else BASE_WPM)
224
+
225
+ write_event({"type": "tts_say_wav_get", "data": {"len": len(text), "voice": voice, "ls": ls, "save": save}})
226
+ return await _proxy_tts_wav_stream(text, voice, ls, noise_scale, noise_w, save_local=save)
227
+
228
+ @app.post("/tts/say.wav")
229
+ async def tts_say_wav_post(req: Request, save: bool = Query(False, description="Also save under /files")):
230
+ """
231
+ POST JSON → stream back audio/wav
232
+ { "text": "...", "voice": "en_US-amy-medium", "rate_wpm": 165 }
233
+ """
234
+ try:
235
+ body = await req.json()
236
+ except Exception:
237
+ return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)
238
+
239
+ text = (body.get("text") or "").strip()
240
+ if not text:
241
+ return JSONResponse({"ok": False, "error": "Missing text"}, status_code=400)
242
+
243
+ voice = (body.get("voice") or DEFAULT_VOICE).strip()
244
+ ls = float(body["length_scale"]) if "length_scale" in body else rate_to_length_scale(int(body.get("rate_wpm", BASE_WPM)))
245
+ ns = float(body.get("noise_scale", NOISE_SCALE))
246
+ nw = float(body.get("noise_w", NOISE_W))
247
+
248
+ write_event({"type": "tts_say_wav_post", "data": {"len": len(text), "voice": voice, "ls": ls, "save": save}})
249
+ return await _proxy_tts_wav_stream(text, voice, ls, ns, nw, save_local=save)
250
+
251
+ # ========== Serve saved files (if you used save=true) ==========
252
+ @app.get("/files/{name}")
253
+ def get_saved_file(name: str):
254
+ path = os.path.join(FILES_DIR, name)
255
+ if not os.path.exists(path):
256
+ return JSONResponse({"ok": False, "error": "not found"}, status_code=404)
257
+ return FileResponse(path, media_type="audio/wav", filename=name)