Percy3822 commited on
Commit
0bb83c8
·
verified ·
1 Parent(s): b9cabc3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -98
app.py CHANGED
@@ -1,4 +1,3 @@
1
- # brain_app.py — Brain Space: STT → TTS proxy streamer
2
  import os, json, time, asyncio, tempfile
3
  from typing import AsyncGenerator, Dict, Any, Optional
4
  from fastapi import FastAPI, Request, Query, UploadFile
@@ -14,8 +13,7 @@ for p in (BASE_DIR, FILES_DIR, LOGS_DIR):
14
 
15
  # === External Spaces ===
16
  TTS_BASE = os.environ.get("TTS_BASE", "https://Percy3822-ActualTTS.hf.space")
17
- # TODO: set your STT Space base here (example):
18
- STT_BASE = os.environ.get("STT_BASE", "https://YOUR-STT-SPACE.hf.space")
19
 
20
  # === TTS defaults ===
21
  DEFAULT_VOICE = os.environ.get("DEFAULT_VOICE", "en_US-amy-medium")
@@ -23,8 +21,7 @@ BASE_WPM = int(os.environ.get("BASE_WPM", "165"))
23
  NOISE_SCALE = float(os.environ.get("NOISE_SCALE", "0.33"))
24
  NOISE_W = float(os.environ.get("NOISE_W", "0.92"))
25
 
26
- # === App ===
27
- app = FastAPI(title="Brain Space (STT→TTS coordinator)", version="3.0.0")
28
  log_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
29
 
30
  def write_event(event: Dict[str, Any]) -> None:
@@ -43,6 +40,7 @@ def rate_to_length_scale(rate_wpm: Optional[int]) -> float:
43
  r = max(80, min(320, rate_wpm))
44
  return round(base / float(r), 3)
45
 
 
46
  @app.get("/health")
47
  def health():
48
  return {
@@ -55,7 +53,7 @@ def health():
55
  "defaults": {"voice": DEFAULT_VOICE, "rate_wpm": BASE_WPM}
56
  }
57
 
58
- # ========== SSE logs (optional) ==========
59
  @app.get("/stream/logs")
60
  async def stream_logs() -> StreamingResponse:
61
  async def gen() -> AsyncGenerator[bytes, None]:
@@ -72,34 +70,11 @@ async def stream_logs() -> StreamingResponse:
72
  return StreamingResponse(gen(), media_type="text/event-stream",
73
  headers={"Cache-Control":"no-cache","Connection":"keep-alive"})
74
 
75
- # ---------- Helpers ----------
76
- async def _download_to_temp(url: str) -> str:
77
- import httpx, os
78
- _, ext = os.path.splitext(url.split("?")[0])
79
- if not ext: ext = ".wav"
80
- fd, tmp_path = tempfile.mkstemp(prefix="mic_", suffix=ext)
81
- os.close(fd)
82
- async with httpx.AsyncClient(timeout=300) as client:
83
- r = await client.get(url)
84
- r.raise_for_status()
85
- with open(tmp_path, "wb") as f:
86
- f.write(r.content)
87
- return tmp_path
88
-
89
- async def _call_stt_transcribe_file(path: str) -> Dict[str, Any]:
90
- """POST multipart 'audio' to STT /stt/transcribe and return its JSON."""
91
- import httpx
92
- stt_url = f"{STT_BASE}/stt/transcribe"
93
- files = {"audio": (os.path.basename(path), open(path, "rb"), "audio/wav")}
94
- async with httpx.AsyncClient(timeout=300) as client:
95
- r = await client.post(stt_url, files=files)
96
- ok = r.status_code == 200
97
- data = r.json() if ok else {"ok": False, "error": f"STT {r.status_code}"}
98
- return data
99
-
100
  async def _proxy_tts_wav_stream(text: str, voice: str, rate_wpm: Optional[int],
101
  noise_scale: float, noise_w: float) -> StreamingResponse:
102
- """Proxy stream from TTS /speak.wav based on text."""
103
  import httpx
104
  length_scale = rate_to_length_scale(rate_wpm) if rate_wpm is not None else rate_to_length_scale(BASE_WPM)
105
  params = {
@@ -113,6 +88,7 @@ async def _proxy_tts_wav_stream(text: str, voice: str, rate_wpm: Optional[int],
113
  async with httpx.AsyncClient(timeout=None) as client:
114
  async with client.stream("GET", f"{TTS_BASE}/speak.wav", params=params) as resp:
115
  if resp.status_code != 200:
 
116
  yield (await resp.aread())
117
  return
118
  async for chunk in resp.aiter_bytes():
@@ -120,77 +96,37 @@ async def _proxy_tts_wav_stream(text: str, voice: str, rate_wpm: Optional[int],
120
  yield chunk
121
  return StreamingResponse(gen(), media_type="audio/wav", headers={"Cache-Control":"no-cache"})
122
 
123
- # ========== The simple end-to-end endpoint ==========
124
- @app.post("/demo/relay.wav")
125
- async def demo_relay_wav(
126
- req: Request,
127
  voice: str = Query(DEFAULT_VOICE),
128
  rate_wpm: Optional[int] = Query(BASE_WPM),
129
  noise_scale: float = Query(NOISE_SCALE),
130
  noise_w: float = Query(NOISE_W),
131
  ):
132
- """
133
- Accept 5s mic recording from client (multipart 'audio' or JSON {file_url}),
134
- send to STT Space for transcription, then IMMEDIATELY proxy stream TTS WAV
135
- that speaks back what was heard.
136
- """
137
- tmp_path = None
138
- try:
139
- # Ingest audio
140
- ctype = (req.headers.get("content-type") or "").lower()
141
- if "multipart/form-data" in ctype:
142
- form = await req.form()
143
- up: UploadFile = form.get("audio")
144
- if not up:
145
- return JSONResponse({"ok": False, "error": "Missing 'audio' file"}, status_code=400)
146
- import os, tempfile
147
- suffix = os.path.splitext(up.filename or "")[1] or ".wav"
148
- fd, tmp_path = tempfile.mkstemp(prefix="mic_", suffix=suffix)
149
- os.close(fd)
150
- with open(tmp_path, "wb") as f:
151
- f.write(await up.read())
152
- else:
153
- # JSON with {file_url}
154
- try:
155
- body = await req.json()
156
- except Exception:
157
- body = {}
158
- url = (body or {}).get("file_url")
159
- if not url:
160
- return JSONResponse({"ok": False, "error": "Provide multipart 'audio' or JSON {file_url}"},
161
- status_code=400)
162
- tmp_path = await _download_to_temp(url)
163
-
164
- # STT
165
- stt = await _call_stt_transcribe_file(tmp_path)
166
- if not stt.get("ok"):
167
- write_event({"type":"relay","ok":False,"stage":"stt","err":stt.get("error")})
168
- return JSONResponse({"ok": False, "error": f"STT failed: {stt.get('error')}"}, status_code=502)
169
-
170
- text = (stt.get("text") or "").strip()
171
- if not text:
172
- write_event({"type":"relay","ok":False,"stage":"stt","err":"empty transcript"})
173
- return JSONResponse({"ok": False, "error": "No speech detected"}, status_code=422)
174
-
175
- # Brain reply (for demo we just echo; you can replace with actual brain logic later)
176
- reply_text = f"I heard: {text}"[:800]
177
 
178
- write_event({"type":"relay","ok":True,"heard_len":len(text),"voice":voice,"rate_wpm":rate_wpm})
179
-
180
- # TTS proxy stream (immediate)
181
- return await _proxy_tts_wav_stream(reply_text, voice, rate_wpm, noise_scale, noise_w)
182
-
183
- except Exception as e:
184
- write_event({"type":"relay","ok":False,"err":str(e)})
185
- return JSONResponse({"ok": False, "error": str(e)}, status_code=500)
186
- finally:
187
- try:
188
- if tmp_path and os.path.exists(tmp_path):
189
- os.unlink(tmp_path)
190
- except Exception:
191
- pass
192
-
193
- # Optional: serve saved files if you decide to persist later
 
 
 
 
194
  @app.get("/files/{name}")
195
  def get_file(name: str):
196
  path = os.path.join(FILES_DIR, name)
@@ -200,4 +136,4 @@ def get_file(name: str):
200
 
201
  if __name__ == "__main__":
202
  import uvicorn
203
- uvicorn.run("brain_app:app", host="0.0.0.0", port=7861, reload=False)
 
 
1
  import os, json, time, asyncio, tempfile
2
  from typing import AsyncGenerator, Dict, Any, Optional
3
  from fastapi import FastAPI, Request, Query, UploadFile
 
13
 
14
  # === External Spaces ===
15
  TTS_BASE = os.environ.get("TTS_BASE", "https://Percy3822-ActualTTS.hf.space")
16
+ STT_BASE = os.environ.get("STT_BASE", "https://Percy3822-ActualSTT.hf.space") # set to your STT Space
 
17
 
18
  # === TTS defaults ===
19
  DEFAULT_VOICE = os.environ.get("DEFAULT_VOICE", "en_US-amy-medium")
 
21
  NOISE_SCALE = float(os.environ.get("NOISE_SCALE", "0.33"))
22
  NOISE_W = float(os.environ.get("NOISE_W", "0.92"))
23
 
24
+ app = FastAPI(title="Brain Space (STT→TTS coordinator)", version="3.1.0")
 
25
  log_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
26
 
27
  def write_event(event: Dict[str, Any]) -> None:
 
40
  r = max(80, min(320, rate_wpm))
41
  return round(base / float(r), 3)
42
 
43
+ # ---------- Health ----------
44
  @app.get("/health")
45
  def health():
46
  return {
 
53
  "defaults": {"voice": DEFAULT_VOICE, "rate_wpm": BASE_WPM}
54
  }
55
 
56
+ # ---------- SSE logs (optional) ----------
57
  @app.get("/stream/logs")
58
  async def stream_logs() -> StreamingResponse:
59
  async def gen() -> AsyncGenerator[bytes, None]:
 
70
  return StreamingResponse(gen(), media_type="text/event-stream",
71
  headers={"Cache-Control":"no-cache","Connection":"keep-alive"})
72
 
73
+ # ---------- TTS proxy streaming (/tts/say.wav) ----------
74
+ # GET: /tts/say.wav?text=...&voice=...&rate_wpm=165
75
+ # POST: JSON {"text": "...", "voice": "...", "rate_wpm": 165}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  async def _proxy_tts_wav_stream(text: str, voice: str, rate_wpm: Optional[int],
77
  noise_scale: float, noise_w: float) -> StreamingResponse:
 
78
  import httpx
79
  length_scale = rate_to_length_scale(rate_wpm) if rate_wpm is not None else rate_to_length_scale(BASE_WPM)
80
  params = {
 
88
  async with httpx.AsyncClient(timeout=None) as client:
89
  async with client.stream("GET", f"{TTS_BASE}/speak.wav", params=params) as resp:
90
  if resp.status_code != 200:
91
+ # bubble up exact error body from TTS
92
  yield (await resp.aread())
93
  return
94
  async for chunk in resp.aiter_bytes():
 
96
  yield chunk
97
  return StreamingResponse(gen(), media_type="audio/wav", headers={"Cache-Control":"no-cache"})
98
 
99
+ @app.get("/tts/say.wav")
100
+ async def tts_say_wav_get(
101
+ text: str = Query(..., description="Text to synthesize"),
 
102
  voice: str = Query(DEFAULT_VOICE),
103
  rate_wpm: Optional[int] = Query(BASE_WPM),
104
  noise_scale: float = Query(NOISE_SCALE),
105
  noise_w: float = Query(NOISE_W),
106
  ):
107
+ write_event({"type":"tts_get","len":len(text),"voice":voice,"rate_wpm":rate_wpm})
108
+ return await _proxy_tts_wav_stream(text, voice, rate_wpm, noise_scale, noise_w)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
+ @app.post("/tts/say.wav")
111
+ async def tts_say_wav_post(req: Request):
112
+ try:
113
+ body = await req.json()
114
+ except Exception:
115
+ return JSONResponse({"ok": False, "error": "Invalid JSON body"}, status_code=400)
116
+ text = (body.get("text") or "").strip()
117
+ if not text:
118
+ return JSONResponse({"ok": False, "error": "Missing text"}, status_code=400)
119
+ voice = (body.get("voice") or DEFAULT_VOICE).strip()
120
+ rate_wpm = int(body.get("rate_wpm", BASE_WPM)) if body.get("rate_wpm") is not None else BASE_WPM
121
+ noise_s = float(body.get("noise_scale", NOISE_SCALE))
122
+ noise_wgt = float(body.get("noise_w", NOISE_W))
123
+ write_event({"type":"tts_post","len":len(text),"voice":voice,"rate_wpm":rate_wpm})
124
+ return await _proxy_tts_wav_stream(text, voice, rate_wpm, noise_s, noise_wgt)
125
+
126
+ # ---------- (Optional) simple relay demo kept for later ----------
127
+ # You can keep your /demo/relay.wav here if you still want the file-upload STT→TTS demo.
128
+
129
+ # ---------- Optional: serve saved files later ----------
130
  @app.get("/files/{name}")
131
  def get_file(name: str):
132
  path = os.path.join(FILES_DIR, name)
 
136
 
137
  if __name__ == "__main__":
138
  import uvicorn
139
+ uvicorn.run("brain_app:app", host="0.0.0.0", port=7861, reload=False)