Percy3822 committed on
Commit
5cd7d81
·
verified ·
1 Parent(s): 5ea3089

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -14
app.py CHANGED
@@ -1,4 +1,7 @@
 
 
1
  # app.py
 
2
  import asyncio
3
  import json
4
  import os
@@ -9,7 +12,7 @@ from pathlib import Path
9
  from typing import Dict, Optional, Tuple
10
 
11
  import uvicorn
12
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, BackgroundTasks, Query
13
  from fastapi.responses import JSONResponse, FileResponse, PlainTextResponse
14
 
15
  # -------------------------
@@ -28,7 +31,7 @@ def pick_writable_dir(candidates):
28
  probe.unlink(missing_ok=True)
29
  return p
30
  except Exception as e:
31
- errs.append(f"{p}: {type(e).__name__}({e})")
32
  raise RuntimeError("No writable dir. Tried:\n " + "\n ".join(errs))
33
 
34
  ENV_DIR = os.getenv("TTS_DATA_DIR")
@@ -80,6 +83,15 @@ STREAM_BATCH_MS = int(os.getenv("STREAM_BATCH_MS", "100")) # ~100 ms
80
 
81
  DEFAULT_CH = 1 # mono
82
 
 
 
 
 
 
 
 
 
 
83
  # -------------------------
84
  # Voice download & checks
85
  # -------------------------
@@ -88,7 +100,7 @@ HF_REPO_BASE = "https://huggingface.co/rhasspy/piper-voices/resolve"
88
  HF_REV = os.getenv("PIPER_VOICES_REV", "main") # optionally pin a commit hash
89
 
90
  # sanity thresholds (bytes)
91
- MIN_ONNX_BYTES = int(os.getenv("MIN_ONNX_BYTES", "5000000")) # >= ~5MB (real models are much larger)
92
  MIN_JSON_BYTES = int(os.getenv("MIN_JSON_BYTES", "1000")) # >= 1KB
93
 
94
  # (lang, country, family, quality, basename)
@@ -325,8 +337,8 @@ def health():
325
  # optional environment versions
326
  try:
327
  import numpy, onnxruntime as ort
328
- numpy_version = numpy.__version__
329
- onnxruntime_version = ort.__version__
330
  except Exception:
331
  numpy_version = onnxruntime_version = None
332
 
@@ -345,6 +357,25 @@ def health():
345
  def root():
346
  return PlainTextResponse("ActualTTS (CPU) — use POST /speak, GET/POST /speak.wav, or WS /ws/tts")
347
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
  @app.get("/file/{name}")
349
  def get_file(name: str):
350
  path = FILES_DIR / name
@@ -352,22 +383,32 @@ def get_file(name: str):
352
  return JSONResponse({"ok": False, "error": "not found"}, status_code=404)
353
  return FileResponse(path)
354
 
 
 
 
 
 
 
 
355
  @app.post("/speak")
356
- async def speak(request: Request):
357
  """
358
  POST JSON:
359
  { "text": "Hello", "voice": "en_US-libritts-high",
360
  "length_scale": 1.08, "noise_scale": 0.35, "noise_w": 0.90 }
361
  Returns: { "ok": true, "audio_url": "/file/tts-XXXX.wav" }
362
  """
 
 
363
  try:
364
  body = await request.json()
365
  except Exception:
366
  return JSONResponse({"detail": "Invalid JSON"}, status_code=400)
367
 
368
  text = (body.get("text") or "").strip()
369
- if not text:
370
- return JSONResponse({"detail": "Missing text"}, status_code=400)
 
371
 
372
  voice = (body.get("voice") or DEFAULT_VOICE).strip()
373
  length_scale = float(body.get("length_scale", 1.08))
@@ -386,16 +427,19 @@ async def speak(request: Request):
386
  return {"ok": True, "audio_url": f"/file/{out_path.name}"}
387
 
388
  @app.post("/speak.wav")
389
- async def speak_wav_post(request: Request, background_tasks: BackgroundTasks):
390
  """POST JSON -> returns audio/wav directly"""
 
 
391
  try:
392
  body = await request.json()
393
  except Exception:
394
  return JSONResponse({"detail": "Invalid JSON"}, status_code=400)
395
 
396
  text = (body.get("text") or "").strip()
397
- if not text:
398
- return JSONResponse({"detail": "Missing text"}, status_code=400)
 
399
 
400
  voice = (body.get("voice") or DEFAULT_VOICE).strip()
401
  length_scale = float(body.get("length_scale", 1.08))
@@ -422,11 +466,16 @@ async def speak_wav_get(
422
  noise_scale: float = 0.35,
423
  noise_w: float = 0.90,
424
  background_tasks: BackgroundTasks = None,
 
425
  ):
426
  """GET query -> returns audio/wav directly"""
 
 
 
427
  text = (text or "").strip()
428
- if not text:
429
- return JSONResponse({"detail": "Missing text"}, status_code=400)
 
430
 
431
  ts = int(time.time() * 1000)
432
  out_path = FILES_DIR / f"tts-{ts}.wav"
@@ -490,10 +539,25 @@ async def ws_tts(ws: WebSocket):
490
  continue
491
  ev = data.get("event")
492
  if ev == "init":
 
 
 
 
 
 
493
  voice = (data.get("voice") or voice).strip()
 
494
  if "length_scale" in data: length_scale = float(data["length_scale"])
495
  if "noise_scale" in data: noise_scale = float(data["noise_scale"])
496
  if "noise_w" in data: noise_w = float(data["noise_w"])
 
 
 
 
 
 
 
 
497
  try:
498
  info = ensure_voice(voice)
499
  voice_sr = int(info.get("sr", 22050))
@@ -508,6 +572,9 @@ async def ws_tts(ws: WebSocket):
508
  if not text:
509
  await ws.send_text(json.dumps({"event": "error", "detail": "empty text"}))
510
  continue
 
 
 
511
  await piper_stream_raw(text, voice, ws, voice_sr, DEFAULT_CH, length_scale, noise_scale, noise_w)
512
  # ignore others
513
  except WebSocketDisconnect:
@@ -523,4 +590,4 @@ async def ws_tts(ws: WebSocket):
523
  pass
524
 
525
  if __name__ == "__main__":
526
- uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")), reload=False)
 
1
+ ---
2
+
3
  # app.py
4
+ ```python
5
  import asyncio
6
  import json
7
  import os
 
12
  from typing import Dict, Optional, Tuple
13
 
14
  import uvicorn
15
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, BackgroundTasks, Query, Header
16
  from fastapi.responses import JSONResponse, FileResponse, PlainTextResponse
17
 
18
  # -------------------------
 
31
  probe.unlink(missing_ok=True)
32
  return p
33
  except Exception as e:
34
+ errs.append(f"{p}: {type(e).__name__}({e})")
35
  raise RuntimeError("No writable dir. Tried:\n " + "\n ".join(errs))
36
 
37
  ENV_DIR = os.getenv("TTS_DATA_DIR")
 
83
 
84
  DEFAULT_CH = 1 # mono
85
 
86
+ # Input clamp (basic DoS protection)
87
+ MAX_TEXT_CHARS = int(os.getenv("MAX_TEXT_CHARS", "800"))
88
+
89
+ # Optional shared secret (x-auth header) for internal/protected calls
90
+ AUTH_SHARED_SECRET = (os.getenv("AUTH_SHARED_SECRET") or "").strip()
91
+
92
+ def _auth_ok(x_auth: Optional[str]) -> bool:
93
+ return (not AUTH_SHARED_SECRET) or (x_auth == AUTH_SHARED_SECRET)
94
+
95
  # -------------------------
96
  # Voice download & checks
97
  # -------------------------
 
100
  HF_REV = os.getenv("PIPER_VOICES_REV", "main") # optionally pin a commit hash
101
 
102
  # sanity thresholds (bytes)
103
+ MIN_ONNX_BYTES = int(os.getenv("MIN_ONNX_BYTES", "5000000")) # >= ~5MB
104
  MIN_JSON_BYTES = int(os.getenv("MIN_JSON_BYTES", "1000")) # >= 1KB
105
 
106
  # (lang, country, family, quality, basename)
 
337
  # optional environment versions
338
  try:
339
  import numpy, onnxruntime as ort
340
+ numpy_version = numpy.__version__
341
+ onnxruntime_version = ort.__version__
342
  except Exception:
343
  numpy_version = onnxruntime_version = None
344
 
 
357
  def root():
358
  return PlainTextResponse("ActualTTS (CPU) — use POST /speak, GET/POST /speak.wav, or WS /ws/tts")
359
 
360
+ @app.post("/provision")
361
+ async def provision(request: Request, x_auth: Optional[str] = Header(None)):
362
+ """
363
+ POST JSON: { "voice": "en_US-amy-medium" }
364
+ Downloads voice assets if missing. Returns {ok, voice, sr}.
365
+ """
366
+ if not _auth_ok(x_auth):
367
+ return JSONResponse({"ok": False, "error": "unauthorized"}, status_code=401)
368
+ try:
369
+ body = await request.json()
370
+ except Exception:
371
+ return JSONResponse({"ok": False, "error": "invalid json"}, status_code=400)
372
+ voice = (body.get("voice") or DEFAULT_VOICE).strip()
373
+ try:
374
+ info = ensure_voice(voice)
375
+ return {"ok": True, "voice": voice, "sr": int(info.get("sr", 22050))}
376
+ except Exception as e:
377
+ return JSONResponse({"ok": False, "error": str(e)}, status_code=500)
378
+
379
  @app.get("/file/{name}")
380
  def get_file(name: str):
381
  path = FILES_DIR / name
 
383
  return JSONResponse({"ok": False, "error": "not found"}, status_code=404)
384
  return FileResponse(path)
385
 
386
+ def _validate_text(text: str) -> Optional[str]:
387
+ if not text:
388
+ return "Missing text"
389
+ if len(text) > MAX_TEXT_CHARS:
390
+ return f"text too long (>{MAX_TEXT_CHARS} chars)"
391
+ return None
392
+
393
  @app.post("/speak")
394
+ async def speak(request: Request, x_auth: Optional[str] = Header(None)):
395
  """
396
  POST JSON:
397
  { "text": "Hello", "voice": "en_US-libritts-high",
398
  "length_scale": 1.08, "noise_scale": 0.35, "noise_w": 0.90 }
399
  Returns: { "ok": true, "audio_url": "/file/tts-XXXX.wav" }
400
  """
401
+ if not _auth_ok(x_auth):
402
+ return JSONResponse({"ok": False, "error": "unauthorized"}, status_code=401)
403
  try:
404
  body = await request.json()
405
  except Exception:
406
  return JSONResponse({"detail": "Invalid JSON"}, status_code=400)
407
 
408
  text = (body.get("text") or "").strip()
409
+ err = _validate_text(text)
410
+ if err:
411
+ return JSONResponse({"detail": err}, status_code=400)
412
 
413
  voice = (body.get("voice") or DEFAULT_VOICE).strip()
414
  length_scale = float(body.get("length_scale", 1.08))
 
427
  return {"ok": True, "audio_url": f"/file/{out_path.name}"}
428
 
429
  @app.post("/speak.wav")
430
+ async def speak_wav_post(request: Request, background_tasks: BackgroundTasks, x_auth: Optional[str] = Header(None)):
431
  """POST JSON -> returns audio/wav directly"""
432
+ if not _auth_ok(x_auth):
433
+ return JSONResponse({"ok": False, "error": "unauthorized"}, status_code=401)
434
  try:
435
  body = await request.json()
436
  except Exception:
437
  return JSONResponse({"detail": "Invalid JSON"}, status_code=400)
438
 
439
  text = (body.get("text") or "").strip()
440
+ err = _validate_text(text)
441
+ if err:
442
+ return JSONResponse({"detail": err}, status_code=400)
443
 
444
  voice = (body.get("voice") or DEFAULT_VOICE).strip()
445
  length_scale = float(body.get("length_scale", 1.08))
 
466
  noise_scale: float = 0.35,
467
  noise_w: float = 0.90,
468
  background_tasks: BackgroundTasks = None,
469
+ x_auth: Optional[str] = Header(None),
470
  ):
471
  """GET query -> returns audio/wav directly"""
472
+ if not _auth_ok(x_auth):
473
+ return JSONResponse({"ok": False, "error": "unauthorized"}, status_code=401)
474
+
475
  text = (text or "").strip()
476
+ err = _validate_text(text)
477
+ if err:
478
+ return JSONResponse({"detail": err}, status_code=400)
479
 
480
  ts = int(time.time() * 1000)
481
  out_path = FILES_DIR / f"tts-{ts}.wav"
 
539
  continue
540
  ev = data.get("event")
541
  if ev == "init":
542
+ # optional shared-secret over WS: read from the 'token' field of the init message (no querystring handling is visible here)
543
+ token = (data.get("token") or "")
544
+ if AUTH_SHARED_SECRET and token != AUTH_SHARED_SECRET:
545
+ await ws.send_text(json.dumps({"event": "error", "detail": "unauthorized"}))
546
+ await ws.close(); return
547
+
548
  voice = (data.get("voice") or voice).strip()
549
+ # Accept explicit params first
550
  if "length_scale" in data: length_scale = float(data["length_scale"])
551
  if "noise_scale" in data: noise_scale = float(data["noise_scale"])
552
  if "noise_w" in data: noise_w = float(data["noise_w"])
553
+ # Optional: map rate_wpm → length_scale if user didn't set a custom length_scale
554
+ if "length_scale" not in data and "rate_wpm" in data:
555
+ try:
556
+ rate_wpm = int(data.get("rate_wpm", 165))
557
+ # crude monotonic mapping: faster WPM → smaller length_scale
558
+ length_scale = max(0.70, min(1.40, 165.0 / max(100, rate_wpm)))
559
+ except Exception:
560
+ pass
561
  try:
562
  info = ensure_voice(voice)
563
  voice_sr = int(info.get("sr", 22050))
 
572
  if not text:
573
  await ws.send_text(json.dumps({"event": "error", "detail": "empty text"}))
574
  continue
575
+ if len(text) > MAX_TEXT_CHARS:
576
+ await ws.send_text(json.dumps({"event":"error","detail": f"text too long (>{MAX_TEXT_CHARS})"}))
577
+ continue
578
  await piper_stream_raw(text, voice, ws, voice_sr, DEFAULT_CH, length_scale, noise_scale, noise_w)
579
  # ignore others
580
  except WebSocketDisconnect:
 
590
  pass
591
 
592
  if __name__ == "__main__":
593
+ uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")), reload=False)