Percy3822 commited on
Commit
b9cabc3
·
verified ·
1 Parent(s): ab5c984

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -298
app.py CHANGED
@@ -1,39 +1,34 @@
 
1
  import os, json, time, asyncio, tempfile
2
- from typing import AsyncGenerator, Dict, Any, Optional, List
3
- from fastapi import FastAPI, Request, Query, UploadFile, File
4
  from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
5
 
6
- # ========== Directories ==========
7
  BASE_DIR = os.environ.get("BASE_DIR", "/tmp/brain_app")
8
  FILES_DIR = os.path.join(BASE_DIR, "files")
9
  LOGS_DIR = os.path.join(FILES_DIR, "logs")
10
  EVENTS_FILE = os.path.join(LOGS_DIR, "events.jsonl")
11
-
12
  for p in (BASE_DIR, FILES_DIR, LOGS_DIR):
13
  os.makedirs(p, exist_ok=True)
14
 
15
- # ========== TTS Config ==========
16
- TTS_BASE = os.environ.get("TTS_BASE", "https://Percy3822-ActualTTS.hf.space")
17
- BASE_WPM = int(os.environ.get("BASE_WPM", "180"))
18
- NOISE_SCALE = float(os.environ.get("NOISE_SCALE", "0.33"))
19
- NOISE_W = float(os.environ.get("NOISE_W", "0.92"))
20
- DEFAULT_VOICE = os.environ.get("DEFAULT_VOICE", "en_US-amy-medium")
21
-
22
- # ========== STT Config ==========
23
- STT_MODEL = os.environ.get("STT_MODEL", "base.en") # faster-whisper model id
24
- STT_DEVICE = os.environ.get("STT_DEVICE", "cpu") # "cpu" | "cuda"
25
- STT_COMPUTE = os.environ.get("STT_COMPUTE", "int8") # "int8"|"int8_float16"|"float32"
26
- STT_MAXLEN_S = float(os.environ.get("STT_MAXLEN_S", "600")) # refuse extremely long uploads
27
 
28
- # ========== App ==========
29
- app = FastAPI(title="Brain Space (TTS+STT)", version="2.1.0")
 
 
 
30
 
31
- # In-memory queue to fan-out logs to /stream/logs clients
 
32
  log_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
33
 
34
  def write_event(event: Dict[str, Any]) -> None:
35
  event.setdefault("ts", time.time())
36
- os.makedirs(LOGS_DIR, exist_ok=True)
37
  with open(EVENTS_FILE, "a", encoding="utf-8") as f:
38
  f.write(json.dumps(event, ensure_ascii=False) + "\n")
39
  try:
@@ -41,16 +36,13 @@ def write_event(event: Dict[str, Any]) -> None:
41
  except asyncio.QueueFull:
42
  pass
43
 
44
- def clamp_rate(rate_wpm: Optional[int]) -> int:
45
- if not isinstance(rate_wpm, int):
46
- return BASE_WPM
47
- return max(80, min(320, rate_wpm))
48
-
49
  def rate_to_length_scale(rate_wpm: Optional[int]) -> float:
50
- r = clamp_rate(rate_wpm)
51
- return round(BASE_WPM / float(r), 3)
 
 
 
52
 
53
- # ========== Health & Basics ==========
54
  @app.get("/health")
55
  def health():
56
  return {
@@ -58,24 +50,12 @@ def health():
58
  "service": "brain-space",
59
  "time": time.time(),
60
  "files_dir": FILES_DIR,
61
- "logs_dir": LOGS_DIR,
62
  "tts_base": TTS_BASE,
63
- "stt_model": STT_MODEL,
64
- "stt_device": STT_DEVICE,
65
- "stt_compute": STT_COMPUTE,
66
  }
67
 
68
- @app.post("/process")
69
- async def process(req: Request):
70
- try:
71
- payload = await req.json()
72
- except Exception:
73
- return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)
74
- event = {"type": "process", "data": payload}
75
- write_event(event)
76
- return {"ok": True, "received": payload}
77
-
78
- # ========== SSE Logs ==========
79
  @app.get("/stream/logs")
80
  async def stream_logs() -> StreamingResponse:
81
  async def gen() -> AsyncGenerator[bytes, None]:
@@ -88,72 +68,40 @@ async def stream_logs() -> StreamingResponse:
88
  pass
89
  while True:
90
  event = await log_queue.get()
91
- line = json.dumps(event, ensure_ascii=False)
92
- yield b"data: " + line.encode("utf-8") + b"\n\n"
93
- headers = {"Cache-Control": "no-cache", "Connection": "keep-alive"}
94
- return StreamingResponse(gen(), media_type="text/event-stream", headers=headers)
95
-
96
- @app.post("/log_error")
97
- async def log_error(req: Request):
98
- try:
99
- payload = await req.json()
100
- except Exception:
101
- return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)
102
- event = {"type": "error", "data": payload}
103
- write_event(event)
104
- return {"ok": True}
105
-
106
- # ========== TTS: JSON (file URL) ==========
107
- @app.post("/tts/say")
108
- async def tts_say_json(req: Request):
109
- """
110
- POST JSON -> call TTS /speak (JSON) and return audio_url and audio_url_full.
111
- Body:
112
- {
113
- "text": "Hello",
114
- "voice": "en_US-amy-medium",
115
- "rate_wpm": 165, # optional (maps to length_scale)
116
- "length_scale": 1.05, # optional (overrides rate_wpm)
117
- "noise_scale": 0.33, # optional
118
- "noise_w": 0.92 # optional
119
- }
120
- """
121
- try:
122
- body = await req.json()
123
- except Exception:
124
- return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)
125
 
126
- text = (body.get("text") or "").strip()
127
- if not text:
128
- return JSONResponse({"ok": False, "error": "Missing text"}, status_code=400)
129
-
130
- voice = (body.get("voice") or DEFAULT_VOICE).strip()
131
- length_scale = float(body["length_scale"]) if "length_scale" in body else rate_to_length_scale(int(body.get("rate_wpm", BASE_WPM)))
132
- noise_scale = float(body.get("noise_scale", NOISE_SCALE))
133
- noise_w = float(body.get("noise_w", NOISE_W))
 
 
 
 
 
134
 
 
 
135
  import httpx
136
- payload = {"text": text, "voice": voice, "length_scale": length_scale, "noise_scale": noise_scale, "noise_w": noise_w}
137
- async with httpx.AsyncClient(timeout=180) as client:
138
- resp = await client.post(f"{TTS_BASE}/speak", json=payload)
139
- ok = resp.status_code == 200
140
- try:
141
- data = resp.json()
142
- except Exception:
143
- data = None
144
-
145
- write_event({"type": "tts_say_json", "data": {"text_len": len(text), "voice": voice, "ok": ok, "resp": data}})
146
-
147
- if not ok or not data or not data.get("ok"):
148
- return JSONResponse({"ok": False, "error": (data or {}).get("error", f"TTS error {resp.status_code}")}, status_code=500)
149
-
150
- audio_url = data["audio_url"]
151
- audio_url_full = audio_url if audio_url.startswith("http") else f"{TTS_BASE}{audio_url}"
152
- return {"ok": True, "audio_url": audio_url, "audio_url_full": audio_url_full, "voice": voice, "length_scale": length_scale}
153
-
154
- # ========== TTS: Direct WAV Proxy ==========
155
- async def _proxy_tts_wav_stream(text: str, voice: str, length_scale: float, noise_scale: float, noise_w: float, save_local: bool = False) -> StreamingResponse:
156
  import httpx
 
157
  params = {
158
  "text": text,
159
  "voice": voice,
@@ -161,233 +109,79 @@ async def _proxy_tts_wav_stream(text: str, voice: str, length_scale: float, nois
161
  "noise_scale": f"{noise_scale:.3f}",
162
  "noise_w": f"{noise_w:.3f}",
163
  }
164
- ts = int(time.time() * 1000)
165
- local_path = os.path.join(FILES_DIR, f"say-{ts}.wav") if save_local else None
166
-
167
  async def gen():
168
  async with httpx.AsyncClient(timeout=None) as client:
169
  async with client.stream("GET", f"{TTS_BASE}/speak.wav", params=params) as resp:
170
  if resp.status_code != 200:
171
- err_body = await resp.aread()
172
- yield err_body
173
  return
174
- f = None
175
- try:
176
- if local_path:
177
- f = open(local_path, "wb")
178
- async for chunk in resp.aiter_bytes():
179
- if chunk:
180
- if f: f.write(chunk)
181
- yield chunk
182
- finally:
183
- if f: f.close()
184
-
185
- headers = {"Cache-Control": "no-cache"}
186
- if local_path:
187
- headers["X-Local-Path"] = local_path
188
- return StreamingResponse(gen(), media_type="audio/wav", headers=headers)
189
-
190
- @app.get("/tts/say.wav")
191
- async def tts_say_wav_get(
192
- text: str = Query(..., description="Text to synthesize"),
193
- voice: str = Query(DEFAULT_VOICE, description="Voice id"),
194
- rate_wpm: Optional[int] = Query(None, description="Words-per-minute"),
195
- length_scale: Optional[float] = Query(None, description="Override length_scale"),
196
- noise_scale: float = Query(NOISE_SCALE),
197
- noise_w: float = Query(NOISE_W),
198
- save: bool = Query(False, description="Also save under /files"),
199
- ):
200
- ls = float(length_scale) if length_scale is not None else rate_to_length_scale(rate_wpm if rate_wpm is not None else BASE_WPM)
201
- write_event({"type": "tts_say_wav_get", "data": {"len": len(text), "voice": voice, "ls": ls, "save": save}})
202
- return await _proxy_tts_wav_stream(text, voice, ls, noise_scale, noise_w, save_local=save)
203
-
204
- @app.post("/tts/say.wav")
205
- async def tts_say_wav_post(req: Request, save: bool = Query(False, description="Also save under /files")):
206
- try:
207
- body = await req.json()
208
- except Exception:
209
- return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)
210
- text = (body.get("text") or "").strip()
211
- if not text:
212
- return JSONResponse({"ok": False, "error": "Missing text"}, status_code=400)
213
- voice = (body.get("voice") or DEFAULT_VOICE).strip()
214
- ls = float(body["length_scale"]) if "length_scale" in body else rate_to_length_scale(int(body.get("rate_wpm", BASE_WPM)))
215
- ns = float(body.get("noise_scale", NOISE_SCALE))
216
- nw = float(body.get("noise_w", NOISE_W))
217
- write_event({"type": "tts_say_wav_post", "data": {"len": len(text), "voice": voice, "ls": ls, "save": save}})
218
- return await _proxy_tts_wav_stream(text, voice, ls, ns, nw, save_local=save)
219
-
220
- # ========== Serve saved files ==========
221
- @app.get("/files/{name}")
222
- def get_saved_file(name: str):
223
- path = os.path.join(FILES_DIR, name)
224
- if not os.path.exists(path):
225
- return JSONResponse({"ok": False, "error": "not found"}, status_code=404)
226
- return FileResponse(path, media_type="audio/wav", filename=name)
227
-
228
- # ========== STT (faster-whisper) ==========
229
- _model = None
230
- def _stt_model():
231
- global _model
232
- if _model is None:
233
- from faster_whisper import WhisperModel
234
- _model = WhisperModel(STT_MODEL, device=STT_DEVICE, compute_type=STT_COMPUTE)
235
- return _model
236
-
237
- async def _download_to_temp(url: str) -> str:
238
- import httpx
239
- _, ext = os.path.splitext(url.split("?")[0])
240
- if not ext: ext = ".wav"
241
- fd, tmp_path = tempfile.mkstemp(prefix="stt_", suffix=ext)
242
- os.close(fd)
243
- async with httpx.AsyncClient(timeout=300) as client:
244
- r = await client.get(url)
245
- r.raise_for_status()
246
- with open(tmp_path, "wb") as f:
247
- f.write(r.content)
248
- return tmp_path
249
-
250
- def _transcribe_path(path: str, language: Optional[str] = None) -> Dict[str, Any]:
251
- model = _stt_model()
252
- segments, info = model.transcribe(
253
- path,
254
- language=language, # "en" or None for auto
255
- beam_size=5,
256
- vad_filter=False,
257
- word_timestamps=False
258
- )
259
- out_segments: List[Dict[str, Any]] = []
260
- txt_parts: List[str] = []
261
- dur = getattr(info, "duration", None)
262
- for seg in segments:
263
- out_segments.append({"start": seg.start, "end": seg.end, "text": seg.text})
264
- txt_parts.append(seg.text)
265
- # guard against absurdly long files
266
- if STT_MAXLEN_S and dur and seg.end and float(seg.end) > STT_MAXLEN_S:
267
- break
268
- text = "".join(txt_parts).strip()
269
- return {"text": text, "language": getattr(info, "language", language or "unknown"), "duration": dur, "segments": out_segments}
270
-
271
- @app.post("/stt/transcribe")
272
- async def stt_transcribe(
273
- req: Request,
274
- language: Optional[str] = Query(None, description="ISO code like 'en' (None = auto)"),
275
- file_url: Optional[str] = Query(None, description="If provided via query")
276
- ):
277
- """
278
- POST either:
279
- - multipart/form-data with 'audio' file
280
- - or JSON: { "file_url": "https://..." }
281
- - or query param ?file_url=...
282
- Returns: { ok, text, language, duration, segments:[...] }
283
- """
284
- tmp_path = None
285
- try:
286
- content_type = req.headers.get("content-type","").lower()
287
- if "multipart/form-data" in content_type:
288
- form = await req.form()
289
- up: UploadFile = form.get("audio") # key: audio
290
- if not up:
291
- return JSONResponse({"ok": False, "error": "Missing 'audio' file"}, status_code=400)
292
- suffix = os.path.splitext(up.filename or "")[1] or ".wav"
293
- fd, tmp_path = tempfile.mkstemp(prefix="stt_", suffix=suffix)
294
- os.close(fd)
295
- with open(tmp_path, "wb") as f:
296
- f.write(await up.read())
297
- else:
298
- # JSON or query
299
- try:
300
- body = await req.json()
301
- except Exception:
302
- body = {}
303
- url = file_url or (body.get("file_url") if isinstance(body, dict) else None)
304
- if not url:
305
- return JSONResponse({"ok": False, "error": "Provide file_url (JSON/query) or multipart 'audio' file"}, status_code=400)
306
- tmp_path = await _download_to_temp(url)
307
-
308
- res = _transcribe_path(tmp_path, language=language)
309
- write_event({"type": "stt_transcribe", "data": {"ok": True, "language": res.get("language"), "dur": res.get("duration"), "text_len": len(res.get("text",""))}})
310
- return {"ok": True, **res}
311
- except Exception as e:
312
- write_event({"type": "stt_transcribe", "data": {"ok": False, "error": str(e)}})
313
- return JSONResponse({"ok": False, "error": str(e)}, status_code=500)
314
- finally:
315
- try:
316
- if tmp_path and os.path.exists(tmp_path):
317
- os.unlink(tmp_path)
318
- except Exception:
319
- pass
320
-
321
- # --- End-to-end: STT -> Brain -> TTS (streamed WAV) ---
322
- @app.post("/demo/echo.wav")
323
- async def demo_echo_wav(
324
  req: Request,
325
- voice: str = Query(DEFAULT_VOICE, description="Voice id (TTS)"),
326
- rate_wpm: Optional[int] = Query(None, description="Words-per-minute -> length_scale"),
327
- length_scale: Optional[float] = Query(None, description="Override prosody"),
328
  noise_scale: float = Query(NOISE_SCALE),
329
  noise_w: float = Query(NOISE_W),
330
- save: bool = Query(False, description="Also save output WAV under /files"),
331
  ):
332
  """
333
- POST either:
334
- - multipart/form-data with 'audio' file
335
- - or JSON: { "file_url": "https://..." }
336
- Returns: streaming audio/wav that says what it heard.
337
  """
338
  tmp_path = None
339
  try:
340
- # --- Ingest audio (multipart or JSON URL) ---
341
- content_type = req.headers.get("content-type", "").lower()
342
- if "multipart/form-data" in content_type:
343
  form = await req.form()
344
- up = form.get("audio") # UploadFile
345
  if not up:
346
  return JSONResponse({"ok": False, "error": "Missing 'audio' file"}, status_code=400)
347
- suffix = os.path.splitext(getattr(up, "filename", "") or "")[1] or ".wav"
348
- fd, tmp_path = tempfile.mkstemp(prefix="demo_echo_", suffix=suffix)
 
349
  os.close(fd)
350
  with open(tmp_path, "wb") as f:
351
  f.write(await up.read())
352
  else:
353
- # JSON with file_url
354
  try:
355
  body = await req.json()
356
  except Exception:
357
  body = {}
358
  url = (body or {}).get("file_url")
359
  if not url:
360
- return JSONResponse({"ok": False, "error": "Provide multipart 'audio' or JSON {file_url}"}, status_code=400)
 
361
  tmp_path = await _download_to_temp(url)
362
 
363
- # --- STT ---
364
- stt_res = _transcribe_path(tmp_path, language=None)
365
- text = (stt_res.get("text") or "").strip()
 
 
 
 
366
  if not text:
367
- write_event({"type": "demo_echo", "data": {"ok": False, "error": "No speech detected"}})
368
  return JSONResponse({"ok": False, "error": "No speech detected"}, status_code=422)
369
 
370
- # --- Brain reply (simple confirmation) ---
371
- reply_text = f"I heard: {text}"
372
- reply_text = reply_text[:800] # safety bound
373
 
374
- # Prosody parameters
375
- ls = float(length_scale) if length_scale is not None else rate_to_length_scale(rate_wpm if rate_wpm is not None else BASE_WPM)
376
 
377
- write_event({"type": "demo_echo", "data": {"ok": True, "heard_len": len(text), "voice": voice, "ls": ls, "save": save}})
378
-
379
- # --- TTS (stream WAV back to the caller) ---
380
- return await _proxy_tts_wav_stream(
381
- text=reply_text,
382
- voice=voice,
383
- length_scale=ls,
384
- noise_scale=noise_scale,
385
- noise_w=noise_w,
386
- save_local=save
387
- )
388
 
389
  except Exception as e:
390
- write_event({"type": "demo_echo", "data": {"ok": False, "error": str(e)}})
391
  return JSONResponse({"ok": False, "error": str(e)}, status_code=500)
392
  finally:
393
  try:
@@ -396,7 +190,14 @@ async def demo_echo_wav(
396
  except Exception:
397
  pass
398
 
399
- # ========== Optional direct runner ==========
 
 
 
 
 
 
 
400
  if __name__ == "__main__":
401
  import uvicorn
402
  uvicorn.run("brain_app:app", host="0.0.0.0", port=7861, reload=False)
 
1
+ # brain_app.py — Brain Space: STT → TTS proxy streamer
2
  import os, json, time, asyncio, tempfile
3
+ from typing import AsyncGenerator, Dict, Any, Optional
4
+ from fastapi import FastAPI, Request, Query, UploadFile
5
  from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
6
 
7
+ # === Directories ===
8
  BASE_DIR = os.environ.get("BASE_DIR", "/tmp/brain_app")
9
  FILES_DIR = os.path.join(BASE_DIR, "files")
10
  LOGS_DIR = os.path.join(FILES_DIR, "logs")
11
  EVENTS_FILE = os.path.join(LOGS_DIR, "events.jsonl")
 
12
  for p in (BASE_DIR, FILES_DIR, LOGS_DIR):
13
  os.makedirs(p, exist_ok=True)
14
 
15
+ # === External Spaces ===
16
+ TTS_BASE = os.environ.get("TTS_BASE", "https://Percy3822-ActualTTS.hf.space")
17
+ # TODO: set your STT Space base here (example):
18
+ STT_BASE = os.environ.get("STT_BASE", "https://YOUR-STT-SPACE.hf.space")
 
 
 
 
 
 
 
 
19
 
20
+ # === TTS defaults ===
21
+ DEFAULT_VOICE = os.environ.get("DEFAULT_VOICE", "en_US-amy-medium")
22
+ BASE_WPM = int(os.environ.get("BASE_WPM", "165"))
23
+ NOISE_SCALE = float(os.environ.get("NOISE_SCALE", "0.33"))
24
+ NOISE_W = float(os.environ.get("NOISE_W", "0.92"))
25
 
26
+ # === App ===
27
+ app = FastAPI(title="Brain Space (STT→TTS coordinator)", version="3.0.0")
28
  log_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
29
 
30
  def write_event(event: Dict[str, Any]) -> None:
31
  event.setdefault("ts", time.time())
 
32
  with open(EVENTS_FILE, "a", encoding="utf-8") as f:
33
  f.write(json.dumps(event, ensure_ascii=False) + "\n")
34
  try:
 
36
  except asyncio.QueueFull:
37
  pass
38
 
 
 
 
 
 
39
  def rate_to_length_scale(rate_wpm: Optional[int]) -> float:
40
+ base = BASE_WPM
41
+ if not isinstance(rate_wpm, int):
42
+ return 1.0
43
+ r = max(80, min(320, rate_wpm))
44
+ return round(base / float(r), 3)
45
 
 
46
  @app.get("/health")
47
  def health():
48
  return {
 
50
  "service": "brain-space",
51
  "time": time.time(),
52
  "files_dir": FILES_DIR,
 
53
  "tts_base": TTS_BASE,
54
+ "stt_base": STT_BASE,
55
+ "defaults": {"voice": DEFAULT_VOICE, "rate_wpm": BASE_WPM}
 
56
  }
57
 
58
+ # ========== SSE logs (optional) ==========
 
 
 
 
 
 
 
 
 
 
59
  @app.get("/stream/logs")
60
  async def stream_logs() -> StreamingResponse:
61
  async def gen() -> AsyncGenerator[bytes, None]:
 
68
  pass
69
  while True:
70
  event = await log_queue.get()
71
+ yield b"data: " + json.dumps(event, ensure_ascii=False).encode("utf-8") + b"\n\n"
72
+ return StreamingResponse(gen(), media_type="text/event-stream",
73
+ headers={"Cache-Control":"no-cache","Connection":"keep-alive"})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
+ # ---------- Helpers ----------
76
+ async def _download_to_temp(url: str) -> str:
77
+ import httpx, os
78
+ _, ext = os.path.splitext(url.split("?")[0])
79
+ if not ext: ext = ".wav"
80
+ fd, tmp_path = tempfile.mkstemp(prefix="mic_", suffix=ext)
81
+ os.close(fd)
82
+ async with httpx.AsyncClient(timeout=300) as client:
83
+ r = await client.get(url)
84
+ r.raise_for_status()
85
+ with open(tmp_path, "wb") as f:
86
+ f.write(r.content)
87
+ return tmp_path
88
 
89
+ async def _call_stt_transcribe_file(path: str) -> Dict[str, Any]:
90
+ """POST multipart 'audio' to STT /stt/transcribe and return its JSON."""
91
  import httpx
92
+ stt_url = f"{STT_BASE}/stt/transcribe"
93
+ files = {"audio": (os.path.basename(path), open(path, "rb"), "audio/wav")}
94
+ async with httpx.AsyncClient(timeout=300) as client:
95
+ r = await client.post(stt_url, files=files)
96
+ ok = r.status_code == 200
97
+ data = r.json() if ok else {"ok": False, "error": f"STT {r.status_code}"}
98
+ return data
99
+
100
+ async def _proxy_tts_wav_stream(text: str, voice: str, rate_wpm: Optional[int],
101
+ noise_scale: float, noise_w: float) -> StreamingResponse:
102
+ """Proxy stream from TTS /speak.wav based on text."""
 
 
 
 
 
 
 
 
 
103
  import httpx
104
+ length_scale = rate_to_length_scale(rate_wpm) if rate_wpm is not None else rate_to_length_scale(BASE_WPM)
105
  params = {
106
  "text": text,
107
  "voice": voice,
 
109
  "noise_scale": f"{noise_scale:.3f}",
110
  "noise_w": f"{noise_w:.3f}",
111
  }
 
 
 
112
  async def gen():
113
  async with httpx.AsyncClient(timeout=None) as client:
114
  async with client.stream("GET", f"{TTS_BASE}/speak.wav", params=params) as resp:
115
  if resp.status_code != 200:
116
+ yield (await resp.aread())
 
117
  return
118
+ async for chunk in resp.aiter_bytes():
119
+ if chunk:
120
+ yield chunk
121
+ return StreamingResponse(gen(), media_type="audio/wav", headers={"Cache-Control":"no-cache"})
122
+
123
+ # ========== The simple end-to-end endpoint ==========
124
+ @app.post("/demo/relay.wav")
125
+ async def demo_relay_wav(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  req: Request,
127
+ voice: str = Query(DEFAULT_VOICE),
128
+ rate_wpm: Optional[int] = Query(BASE_WPM),
 
129
  noise_scale: float = Query(NOISE_SCALE),
130
  noise_w: float = Query(NOISE_W),
 
131
  ):
132
  """
133
+ Accept 5s mic recording from client (multipart 'audio' or JSON {file_url}),
134
+ send to STT Space for transcription, then IMMEDIATELY proxy stream TTS WAV
135
+ that speaks back what was heard.
 
136
  """
137
  tmp_path = None
138
  try:
139
+ # Ingest audio
140
+ ctype = (req.headers.get("content-type") or "").lower()
141
+ if "multipart/form-data" in ctype:
142
  form = await req.form()
143
+ up: UploadFile = form.get("audio")
144
  if not up:
145
  return JSONResponse({"ok": False, "error": "Missing 'audio' file"}, status_code=400)
146
+ import os, tempfile
147
+ suffix = os.path.splitext(up.filename or "")[1] or ".wav"
148
+ fd, tmp_path = tempfile.mkstemp(prefix="mic_", suffix=suffix)
149
  os.close(fd)
150
  with open(tmp_path, "wb") as f:
151
  f.write(await up.read())
152
  else:
153
+ # JSON with {file_url}
154
  try:
155
  body = await req.json()
156
  except Exception:
157
  body = {}
158
  url = (body or {}).get("file_url")
159
  if not url:
160
+ return JSONResponse({"ok": False, "error": "Provide multipart 'audio' or JSON {file_url}"},
161
+ status_code=400)
162
  tmp_path = await _download_to_temp(url)
163
 
164
+ # STT
165
+ stt = await _call_stt_transcribe_file(tmp_path)
166
+ if not stt.get("ok"):
167
+ write_event({"type":"relay","ok":False,"stage":"stt","err":stt.get("error")})
168
+ return JSONResponse({"ok": False, "error": f"STT failed: {stt.get('error')}"}, status_code=502)
169
+
170
+ text = (stt.get("text") or "").strip()
171
  if not text:
172
+ write_event({"type":"relay","ok":False,"stage":"stt","err":"empty transcript"})
173
  return JSONResponse({"ok": False, "error": "No speech detected"}, status_code=422)
174
 
175
+ # Brain reply (for demo we just echo; you can replace with actual brain logic later)
176
+ reply_text = f"I heard: {text}"[:800]
 
177
 
178
+ write_event({"type":"relay","ok":True,"heard_len":len(text),"voice":voice,"rate_wpm":rate_wpm})
 
179
 
180
+ # TTS proxy stream (immediate)
181
+ return await _proxy_tts_wav_stream(reply_text, voice, rate_wpm, noise_scale, noise_w)
 
 
 
 
 
 
 
 
 
182
 
183
  except Exception as e:
184
+ write_event({"type":"relay","ok":False,"err":str(e)})
185
  return JSONResponse({"ok": False, "error": str(e)}, status_code=500)
186
  finally:
187
  try:
 
190
  except Exception:
191
  pass
192
 
193
+ # Optional: serve saved files if you decide to persist later
194
+ @app.get("/files/{name}")
195
+ def get_file(name: str):
196
+ path = os.path.join(FILES_DIR, name)
197
+ if not os.path.exists(path):
198
+ return JSONResponse({"ok": False, "error": "not found"}, status_code=404)
199
+ return FileResponse(path, media_type="application/octet-stream", filename=name)
200
+
201
  if __name__ == "__main__":
202
  import uvicorn
203
  uvicorn.run("brain_app:app", host="0.0.0.0", port=7861, reload=False)