Update app.py
Browse files
app.py
CHANGED
|
@@ -26,7 +26,7 @@ STT_COMPUTE = os.environ.get("STT_COMPUTE", "int8") # "int8"|"int8_float16"|"flo
|
|
| 26 |
STT_MAXLEN_S = float(os.environ.get("STT_MAXLEN_S", "600")) # refuse extremely long uploads
|
| 27 |
|
| 28 |
# ========== App ==========
|
| 29 |
-
app = FastAPI(title="Brain Space (TTS+STT)", version="2.
|
| 30 |
|
| 31 |
# In-memory queue to fan-out logs to /stream/logs clients
|
| 32 |
log_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
|
|
@@ -249,7 +249,6 @@ async def _download_to_temp(url: str) -> str:
|
|
| 249 |
|
| 250 |
def _transcribe_path(path: str, language: Optional[str] = None) -> Dict[str, Any]:
|
| 251 |
model = _stt_model()
|
| 252 |
-
# NOTE: sticking to CPU-friendly settings; adjust if you move to GPU
|
| 253 |
segments, info = model.transcribe(
|
| 254 |
path,
|
| 255 |
language=language, # "en" or None for auto
|
|
@@ -263,8 +262,8 @@ def _transcribe_path(path: str, language: Optional[str] = None) -> Dict[str, Any
|
|
| 263 |
for seg in segments:
|
| 264 |
out_segments.append({"start": seg.start, "end": seg.end, "text": seg.text})
|
| 265 |
txt_parts.append(seg.text)
|
| 266 |
-
# guard against absurdly long files
|
| 267 |
-
if STT_MAXLEN_S and
|
| 268 |
break
|
| 269 |
text = "".join(txt_parts).strip()
|
| 270 |
return {"text": text, "language": getattr(info, "language", language or "unknown"), "duration": dur, "segments": out_segments}
|
|
@@ -319,6 +318,84 @@ async def stt_transcribe(
|
|
| 319 |
except Exception:
|
| 320 |
pass
|
| 321 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
# ========== Optional direct runner ==========
|
| 323 |
if __name__ == "__main__":
|
| 324 |
import uvicorn
|
|
|
|
| 26 |
# Upper bound for audio handled by STT, read from env STT_MAXLEN_S (default "600");
# presumably seconds, since it is compared against segment end times.
# NOTE(review): the visible use in _transcribe_path stops collecting segments once
# a segment ends past this bound (and only when `dur` is truthy) -- that truncates
# the transcript rather than rejecting the upload. Confirm whether another code
# path actually refuses over-long uploads.
STT_MAXLEN_S = float(os.environ.get("STT_MAXLEN_S", "600")) # refuse extremely long uploads
|
| 27 |
|
| 28 |
# ========== App ==========
|
| 29 |
+
app = FastAPI(title="Brain Space (TTS+STT)", version="2.1.0")
|
| 30 |
|
| 31 |
# In-memory queue to fan-out logs to /stream/logs clients
|
| 32 |
log_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
|
|
|
|
| 249 |
|
| 250 |
def _transcribe_path(path: str, language: Optional[str] = None) -> Dict[str, Any]:
|
| 251 |
model = _stt_model()
|
|
|
|
| 252 |
segments, info = model.transcribe(
|
| 253 |
path,
|
| 254 |
language=language, # "en" or None for auto
|
|
|
|
| 262 |
for seg in segments:
|
| 263 |
out_segments.append({"start": seg.start, "end": seg.end, "text": seg.text})
|
| 264 |
txt_parts.append(seg.text)
|
| 265 |
+
# guard against absurdly long files
|
| 266 |
+
if STT_MAXLEN_S and dur and seg.end and float(seg.end) > STT_MAXLEN_S:
|
| 267 |
break
|
| 268 |
text = "".join(txt_parts).strip()
|
| 269 |
return {"text": text, "language": getattr(info, "language", language or "unknown"), "duration": dur, "segments": out_segments}
|
|
|
|
| 318 |
except Exception:
|
| 319 |
pass
|
| 320 |
|
| 321 |
+
# --- End-to-end: STT -> Brain -> TTS (streamed WAV) ---
|
| 322 |
+
@app.post("/demo/echo.wav")
async def demo_echo_wav(
    req: Request,
    voice: str = Query(DEFAULT_VOICE, description="Voice id (TTS)"),
    rate_wpm: Optional[int] = Query(None, description="Words-per-minute -> length_scale"),
    length_scale: Optional[float] = Query(None, description="Override prosody"),
    noise_scale: float = Query(NOISE_SCALE),
    noise_w: float = Query(NOISE_W),
    save: bool = Query(False, description="Also save output WAV under /files"),
):
    """
    End-to-end demo: transcribe the posted audio and speak it back as WAV.

    POST either:
    - multipart/form-data with 'audio' file
    - or JSON: { "file_url": "https://..." }

    Returns: streaming audio/wav that says what it heard.

    Failures are reported as JSON {"ok": false, "error": "..."}:
    - 400 when neither a usable 'audio' upload nor a 'file_url' is supplied
    - 422 when the transcription comes back empty
    - 500 for any other error raised while handling the request
    """
    tmp_path = None
    try:
        # --- Ingest audio (multipart or JSON URL) ---
        content_type = req.headers.get("content-type", "").lower()
        if "multipart/form-data" in content_type:
            form = await req.form()
            up = form.get("audio")  # UploadFile when sent as a file part
            # A plain text form field arrives as `str`; treat it like a missing
            # file (400) instead of failing later on `.read()` with a 500.
            if not up or not hasattr(up, "read"):
                return JSONResponse({"ok": False, "error": "Missing 'audio' file"}, status_code=400)
            suffix = os.path.splitext(getattr(up, "filename", "") or "")[1] or ".wav"
            fd, tmp_path = tempfile.mkstemp(prefix="demo_echo_", suffix=suffix)
            # Copy in bounded chunks so a large upload is never held in memory whole.
            chunk_size = 1024 * 1024
            with os.fdopen(fd, "wb") as f:
                while True:
                    chunk = await up.read(chunk_size)
                    if not chunk:
                        break
                    f.write(chunk)
        else:
            # JSON with file_url
            try:
                body = await req.json()
            except Exception:
                # Missing or malformed JSON is reported as a 400 below.
                body = {}
            # Valid JSON that is not an object (list, string, number) has no
            # 'file_url'; handle it as an empty body rather than crashing on .get().
            if not isinstance(body, dict):
                body = {}
            url = body.get("file_url")
            if not url:
                return JSONResponse({"ok": False, "error": "Provide multipart 'audio' or JSON {file_url}"}, status_code=400)
            tmp_path = await _download_to_temp(url)

        # --- STT ---
        # NOTE(review): this is a synchronous call inside an async handler, so the
        # event loop is blocked while it runs. Consider offloading to an executor
        # once the STT model's thread-safety has been confirmed.
        stt_res = _transcribe_path(tmp_path, language=None)
        text = (stt_res.get("text") or "").strip()
        if not text:
            write_event({"type": "demo_echo", "data": {"ok": False, "error": "No speech detected"}})
            return JSONResponse({"ok": False, "error": "No speech detected"}, status_code=422)

        # --- Brain reply (simple confirmation) ---
        reply_text = f"I heard: {text}"
        reply_text = reply_text[:800]  # safety bound

        # Prosody parameters: an explicit length_scale wins; otherwise derive it
        # from rate_wpm, falling back to BASE_WPM.
        if length_scale is not None:
            ls = float(length_scale)
        else:
            ls = rate_to_length_scale(rate_wpm if rate_wpm is not None else BASE_WPM)

        write_event({"type": "demo_echo", "data": {"ok": True, "heard_len": len(text), "voice": voice, "ls": ls, "save": save}})

        # --- TTS (stream WAV back to the caller) ---
        return await _proxy_tts_wav_stream(
            text=reply_text,
            voice=voice,
            length_scale=ls,
            noise_scale=noise_scale,
            noise_w=noise_w,
            save_local=save,
        )

    except Exception as e:
        # Top-level boundary for the endpoint: log the failure and answer with JSON.
        write_event({"type": "demo_echo", "data": {"ok": False, "error": str(e)}})
        return JSONResponse({"ok": False, "error": str(e)}, status_code=500)
    finally:
        # Best-effort removal of the temporary input audio; cleanup problems must
        # not mask the response that is already being returned.
        if tmp_path:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
|
| 398 |
+
|
| 399 |
# ========== Optional direct runner ==========
|
| 400 |
if __name__ == "__main__":
|
| 401 |
import uvicorn
|