Percy3822 commited on
Commit
ab5c984
·
verified ·
1 Parent(s): 2dda111

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -4
app.py CHANGED
@@ -26,7 +26,7 @@ STT_COMPUTE = os.environ.get("STT_COMPUTE", "int8") # "int8"|"int8_float16"|"flo
26
  STT_MAXLEN_S = float(os.environ.get("STT_MAXLEN_S", "600")) # refuse extremely long uploads
27
 
28
  # ========== App ==========
29
- app = FastAPI(title="Brain Space (TTS+STT)", version="2.0.0")
30
 
31
  # In-memory queue to fan-out logs to /stream/logs clients
32
  log_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
@@ -249,7 +249,6 @@ async def _download_to_temp(url: str) -> str:
249
 
250
  def _transcribe_path(path: str, language: Optional[str] = None) -> Dict[str, Any]:
251
  model = _stt_model()
252
- # NOTE: sticking to CPU-friendly settings; adjust if you move to GPU
253
  segments, info = model.transcribe(
254
  path,
255
  language=language, # "en" or None for auto
@@ -263,8 +262,8 @@ def _transcribe_path(path: str, language: Optional[str] = None) -> Dict[str, Any
263
  for seg in segments:
264
  out_segments.append({"start": seg.start, "end": seg.end, "text": seg.text})
265
  txt_parts.append(seg.text)
266
- # guard against absurdly long files if decoder doesn't report duration
267
- if STT_MAXLEN_S and len(out_segments) > 0 and dur and seg.end and float(seg.end) > STT_MAXLEN_S:
268
  break
269
  text = "".join(txt_parts).strip()
270
  return {"text": text, "language": getattr(info, "language", language or "unknown"), "duration": dur, "segments": out_segments}
@@ -319,6 +318,84 @@ async def stt_transcribe(
319
  except Exception:
320
  pass
321
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
  # ========== Optional direct runner ==========
323
  if __name__ == "__main__":
324
  import uvicorn
 
26
  STT_MAXLEN_S = float(os.environ.get("STT_MAXLEN_S", "600")) # refuse extremely long uploads
27
 
28
  # ========== App ==========
29
+ app = FastAPI(title="Brain Space (TTS+STT)", version="2.1.0")
30
 
31
  # In-memory queue to fan-out logs to /stream/logs clients
32
  log_queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
 
249
 
250
  def _transcribe_path(path: str, language: Optional[str] = None) -> Dict[str, Any]:
251
  model = _stt_model()
 
252
  segments, info = model.transcribe(
253
  path,
254
  language=language, # "en" or None for auto
 
262
  for seg in segments:
263
  out_segments.append({"start": seg.start, "end": seg.end, "text": seg.text})
264
  txt_parts.append(seg.text)
265
+ # guard against absurdly long files
266
+ if STT_MAXLEN_S and dur and seg.end and float(seg.end) > STT_MAXLEN_S:
267
  break
268
  text = "".join(txt_parts).strip()
269
  return {"text": text, "language": getattr(info, "language", language or "unknown"), "duration": dur, "segments": out_segments}
 
318
  except Exception:
319
  pass
320
 
321
+ # --- End-to-end: STT -> Brain -> TTS (streamed WAV) ---
322
+ @app.post("/demo/echo.wav")
323
+ async def demo_echo_wav(
324
+ req: Request,
325
+ voice: str = Query(DEFAULT_VOICE, description="Voice id (TTS)"),
326
+ rate_wpm: Optional[int] = Query(None, description="Words-per-minute -> length_scale"),
327
+ length_scale: Optional[float] = Query(None, description="Override prosody"),
328
+ noise_scale: float = Query(NOISE_SCALE),
329
+ noise_w: float = Query(NOISE_W),
330
+ save: bool = Query(False, description="Also save output WAV under /files"),
331
+ ):
332
+ """
333
+ POST either:
334
+ - multipart/form-data with 'audio' file
335
+ - or JSON: { "file_url": "https://..." }
336
+ Returns: streaming audio/wav that says what it heard.
337
+ """
338
+ tmp_path = None
339
+ try:
340
+ # --- Ingest audio (multipart or JSON URL) ---
341
+ content_type = req.headers.get("content-type", "").lower()
342
+ if "multipart/form-data" in content_type:
343
+ form = await req.form()
344
+ up = form.get("audio") # UploadFile
345
+ if not up:
346
+ return JSONResponse({"ok": False, "error": "Missing 'audio' file"}, status_code=400)
347
+ suffix = os.path.splitext(getattr(up, "filename", "") or "")[1] or ".wav"
348
+ fd, tmp_path = tempfile.mkstemp(prefix="demo_echo_", suffix=suffix)
349
+ os.close(fd)
350
+ with open(tmp_path, "wb") as f:
351
+ f.write(await up.read())
352
+ else:
353
+ # JSON with file_url
354
+ try:
355
+ body = await req.json()
356
+ except Exception:
357
+ body = {}
358
+ url = (body or {}).get("file_url")
359
+ if not url:
360
+ return JSONResponse({"ok": False, "error": "Provide multipart 'audio' or JSON {file_url}"}, status_code=400)
361
+ tmp_path = await _download_to_temp(url)
362
+
363
+ # --- STT ---
364
+ stt_res = _transcribe_path(tmp_path, language=None)
365
+ text = (stt_res.get("text") or "").strip()
366
+ if not text:
367
+ write_event({"type": "demo_echo", "data": {"ok": False, "error": "No speech detected"}})
368
+ return JSONResponse({"ok": False, "error": "No speech detected"}, status_code=422)
369
+
370
+ # --- Brain reply (simple confirmation) ---
371
+ reply_text = f"I heard: {text}"
372
+ reply_text = reply_text[:800] # safety bound
373
+
374
+ # Prosody parameters
375
+ ls = float(length_scale) if length_scale is not None else rate_to_length_scale(rate_wpm if rate_wpm is not None else BASE_WPM)
376
+
377
+ write_event({"type": "demo_echo", "data": {"ok": True, "heard_len": len(text), "voice": voice, "ls": ls, "save": save}})
378
+
379
+ # --- TTS (stream WAV back to the caller) ---
380
+ return await _proxy_tts_wav_stream(
381
+ text=reply_text,
382
+ voice=voice,
383
+ length_scale=ls,
384
+ noise_scale=noise_scale,
385
+ noise_w=noise_w,
386
+ save_local=save
387
+ )
388
+
389
+ except Exception as e:
390
+ write_event({"type": "demo_echo", "data": {"ok": False, "error": str(e)}})
391
+ return JSONResponse({"ok": False, "error": str(e)}, status_code=500)
392
+ finally:
393
+ try:
394
+ if tmp_path and os.path.exists(tmp_path):
395
+ os.unlink(tmp_path)
396
+ except Exception:
397
+ pass
398
+
399
  # ========== Optional direct runner ==========
400
  if __name__ == "__main__":
401
  import uvicorn