Percy3822 committed on
Commit
d6a8b1c
·
verified ·
1 Parent(s): f137cdc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -28
app.py CHANGED
@@ -7,7 +7,7 @@ from pathlib import Path
7
  from typing import Optional, Dict
8
 
9
  import uvicorn
10
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request
11
  from fastapi.responses import JSONResponse, FileResponse, PlainTextResponse
12
 
13
  # -------------------------
@@ -45,6 +45,12 @@ FILE_CANDIDATES = [
45
  VOICES_DIR = pick_writable_dir([p for p in VOICE_CANDIDATES if p])
46
  FILES_DIR = pick_writable_dir([p for p in FILE_CANDIDATES if p])
47
 
 
 
 
 
 
 
48
  # -------------------------
49
  # Piper CLI integration
50
  # -------------------------
@@ -127,10 +133,12 @@ def build_piper_cmd(text: str, voice_id: str, to_stdout: bool, out_path: Optiona
127
  "--noise_w", str(noise_w),
128
  ]
129
  if to_stdout:
130
- cmd += ["-f", "-"] # write WAV to stdout
 
131
  else:
132
  if out_path is None:
133
  raise ValueError("out_path required when to_stdout=False")
 
134
  cmd += ["-f", str(out_path)]
135
  return cmd
136
 
@@ -140,8 +148,8 @@ async def piper_to_file(text, voice, out_path, length_scale, noise_scale, noise_
140
  proc = await asyncio.create_subprocess_exec(
141
  *cmd, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
142
  )
143
- # IMPORTANT: .write() is not awaitable; keep drain() awaited.
144
- proc.stdin.write(text.encode("utf-8"))
145
  await proc.stdin.drain()
146
  proc.stdin.close()
147
  await proc.wait()
@@ -150,7 +158,7 @@ async def piper_to_file(text, voice, out_path, length_scale, noise_scale, noise_
150
  raise RuntimeError(f"Piper failed (code {proc.returncode}).\n{stderr}")
151
 
152
  async def piper_stream_stdout(text, voice, ws: WebSocket, length_scale, noise_scale, noise_w):
153
- """Stream WAV from Piper stdout over WS, stripping the WAV header once even if split."""
154
  await ws.send_text(json.dumps({"event": "ready", "sr": DEFAULT_SR, "channels": DEFAULT_CH}))
155
 
156
  cmd = build_piper_cmd(text, voice, to_stdout=True,
@@ -158,46 +166,59 @@ async def piper_stream_stdout(text, voice, ws: WebSocket, length_scale, noise_sc
158
  proc = await asyncio.create_subprocess_exec(
159
  *cmd, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
160
  )
161
- # IMPORTANT: .write() is not awaitable; keep drain() awaited.
162
- proc.stdin.write(text.encode("utf-8"))
 
163
  await proc.stdin.drain()
164
  proc.stdin.close()
165
 
166
- header_needed = True
167
- header_buf = bytearray()
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
 
169
  try:
 
170
  while True:
171
  chunk = await proc.stdout.read(4096)
172
  if not chunk:
173
  break
174
-
175
- if header_needed:
176
- header_buf.extend(chunk)
177
- if len(header_buf) < 44:
178
- continue
179
- if header_buf[:4] == b"RIFF" and header_buf[8:12] == b"WAVE":
180
- payload = header_buf[44:]
181
- else:
182
- payload = bytes(header_buf)
183
- header_buf.clear()
184
- header_needed = False
185
- if payload:
186
- await ws.send_bytes(payload)
187
- else:
188
- await ws.send_bytes(chunk)
189
 
190
  await proc.wait()
 
 
191
  if proc.returncode != 0:
192
- stderr = (await proc.stderr.read()).decode("utf-8", "ignore")
193
- await ws.send_text(json.dumps({"event": "error", "detail": stderr}))
 
194
  else:
195
- await ws.send_text(json.dumps({"event": "done"}))
 
 
 
196
  except WebSocketDisconnect:
197
  try:
198
  proc.kill()
199
  except Exception:
200
  pass
 
 
 
 
201
 
202
  # ---------------
203
  # FastAPI wiring
@@ -275,6 +296,65 @@ async def speak(request: Request):
275
 
276
  return {"ok": True, "audio_url": f"/file/{out_path.name}"}
277
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  @app.websocket("/ws/tts")
279
  async def ws_tts(ws: WebSocket):
280
  await ws.accept()
@@ -321,4 +401,5 @@ async def ws_tts(ws: WebSocket):
321
  pass
322
 
323
  if __name__ == "__main__":
324
- uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
 
 
7
  from typing import Optional, Dict
8
 
9
  import uvicorn
10
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, BackgroundTasks
11
  from fastapi.responses import JSONResponse, FileResponse, PlainTextResponse
12
 
13
  # -------------------------
 
45
  VOICES_DIR = pick_writable_dir([p for p in VOICE_CANDIDATES if p])
46
  FILES_DIR = pick_writable_dir([p for p in FILE_CANDIDATES if p])
47
 
48
def _safe_unlink(path: Path) -> None:
    """Delete *path* on a best-effort basis.

    The file being absent is not an error. Filesystem failures (permissions,
    the path being a directory, ...) are logged and swallowed so that a
    cleanup hook built on this helper never raises into its caller.

    Args:
        path: File to remove. A ``str`` or other path-like value is accepted
            as well as a ``Path``.
    """
    import logging

    try:
        Path(path).unlink(missing_ok=True)
    except OSError as exc:
        # Only filesystem errors are expected here; anything else is a
        # programming error and should surface instead of being hidden.
        logging.getLogger(__name__).warning("Could not remove %s: %s", path, exc)
53
+
54
  # -------------------------
55
  # Piper CLI integration
56
  # -------------------------
 
133
  "--noise_w", str(noise_w),
134
  ]
135
  if to_stdout:
136
+ # Stream RAW PCM (no WAV header) → simpler, no header parsing bugs.
137
+ cmd += ["--raw", "-f", "-"]
138
  else:
139
  if out_path is None:
140
  raise ValueError("out_path required when to_stdout=False")
141
+ # When writing to file, Piper writes WAV by default.
142
  cmd += ["-f", str(out_path)]
143
  return cmd
144
 
 
148
  proc = await asyncio.create_subprocess_exec(
149
  *cmd, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
150
  )
151
+ # Send a final newline so Piper treats it as a complete utterance.
152
+ proc.stdin.write((text + "\n").encode("utf-8"))
153
  await proc.stdin.drain()
154
  proc.stdin.close()
155
  await proc.wait()
 
158
  raise RuntimeError(f"Piper failed (code {proc.returncode}).\n{stderr}")
159
 
160
  async def piper_stream_stdout(text, voice, ws: WebSocket, length_scale, noise_scale, noise_w):
161
+ """Stream RAW PCM from Piper stdout over WS (no WAV header), with stderr logs."""
162
  await ws.send_text(json.dumps({"event": "ready", "sr": DEFAULT_SR, "channels": DEFAULT_CH}))
163
 
164
  cmd = build_piper_cmd(text, voice, to_stdout=True,
 
166
  proc = await asyncio.create_subprocess_exec(
167
  *cmd, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
168
  )
169
+
170
+ # Feed text + newline then close stdin
171
+ proc.stdin.write((text + "\n").encode("utf-8"))
172
  await proc.stdin.drain()
173
  proc.stdin.close()
174
 
175
+ # Forward stderr lines to client (debug visibility)
176
+ async def pump_stderr():
177
+ try:
178
+ while True:
179
+ line = await proc.stderr.readline()
180
+ if not line:
181
+ break
182
+ try:
183
+ await ws.send_text(json.dumps({"event": "log", "stderr": line.decode("utf-8", "ignore").rstrip()}))
184
+ except Exception:
185
+ break
186
+ except Exception:
187
+ pass
188
+
189
+ stderr_task = asyncio.create_task(pump_stderr())
190
 
191
+ total_bytes = 0
192
  try:
193
+ # RAW PCM passthrough
194
  while True:
195
  chunk = await proc.stdout.read(4096)
196
  if not chunk:
197
  break
198
+ total_bytes += len(chunk)
199
+ await ws.send_bytes(chunk)
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
  await proc.wait()
202
+ await stderr_task
203
+
204
  if proc.returncode != 0:
205
+ rem = await proc.stderr.read()
206
+ detail = rem.decode("utf-8", "ignore").strip()
207
+ await ws.send_text(json.dumps({"event": "error", "detail": detail or f'piper exited {proc.returncode}'}))
208
  else:
209
+ if total_bytes == 0:
210
+ await ws.send_text(json.dumps({"event": "error", "detail": "No audio produced"}))
211
+ else:
212
+ await ws.send_text(json.dumps({"event": "done"}))
213
  except WebSocketDisconnect:
214
  try:
215
  proc.kill()
216
  except Exception:
217
  pass
218
+ try:
219
+ await stderr_task
220
+ except Exception:
221
+ pass
222
 
223
  # ---------------
224
  # FastAPI wiring
 
296
 
297
  return {"ok": True, "audio_url": f"/file/{out_path.name}"}
298
 
299
+ # --- Direct-file endpoints (audio/wav response) ---
300
@app.post("/speak.wav")
async def speak_wav_post(request: Request, background_tasks: BackgroundTasks):
    """Synthesize speech from a JSON body and return it as ``audio/wav``.

    The body must be a JSON object. ``text`` is required; ``voice``,
    ``length_scale``, ``noise_scale`` and ``noise_w`` are optional and fall
    back to ``DEFAULT_VOICE``, 1.08, 0.35 and 0.90 respectively.

    Returns:
        A ``FileResponse`` streaming the generated WAV on success; the file
        is scheduled for deletion via ``background_tasks``.
        A 400 ``JSONResponse`` for malformed input, or a 500 ``JSONResponse``
        when voice lookup or Piper synthesis fails.
    """
    import uuid

    try:
        body = await request.json()
    except Exception:
        return JSONResponse({"detail": "Invalid JSON"}, status_code=400)
    # Valid JSON that is not an object (list, number, ...) has no .get().
    if not isinstance(body, dict):
        return JSONResponse({"detail": "Invalid JSON"}, status_code=400)

    raw_text = body.get("text")
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    if not text:
        return JSONResponse({"detail": "Missing text"}, status_code=400)

    raw_voice = body.get("voice") or DEFAULT_VOICE
    if not isinstance(raw_voice, str):
        return JSONResponse({"detail": "Invalid voice"}, status_code=400)
    voice = raw_voice.strip()

    # Client-supplied numbers are untrusted: reject bad values with a 400
    # instead of letting float() raise out of the handler.
    try:
        length_scale = float(body.get("length_scale", 1.08))
        noise_scale = float(body.get("noise_scale", 0.35))
        noise_w = float(body.get("noise_w", 0.90))
    except (TypeError, ValueError):
        return JSONResponse({"detail": "Invalid synthesis parameter"}, status_code=400)

    # A millisecond timestamp alone can collide between concurrent requests,
    # so add a random suffix to keep each request's file private.
    ts = int(time.time() * 1000)
    out_path = FILES_DIR / f"tts-{ts}-{uuid.uuid4().hex[:8]}.wav"

    try:
        ensure_voice(voice)
        await piper_to_file(text, voice, out_path, length_scale, noise_scale, noise_w)
    except Exception as e:
        # Do not leave a partial/empty WAV behind when synthesis fails.
        _safe_unlink(out_path)
        return JSONResponse({"ok": False, "error": str(e)}, status_code=500)

    background_tasks.add_task(_safe_unlink, out_path)
    return FileResponse(out_path, media_type="audio/wav", filename=out_path.name, background=background_tasks)
330
+
331
@app.get("/speak.wav")
async def speak_wav_get(
    text: str,
    voice: str = DEFAULT_VOICE,
    length_scale: float = 1.08,
    noise_scale: float = 0.35,
    noise_w: float = 0.90,
    background_tasks: BackgroundTasks | None = None,
):
    """Synthesize speech from query parameters and return it as ``audio/wav``.

    Args:
        text: Text to speak; must be non-empty after stripping whitespace.
        voice: Voice identifier handed to ``ensure_voice`` and Piper.
        length_scale: Piper length scale.
        noise_scale: Piper noise scale.
        noise_w: Piper noise width.
        background_tasks: Task container used to delete the temporary WAV;
            a fresh one is created when none is supplied.

    Returns:
        A ``FileResponse`` with the generated WAV, a 400 ``JSONResponse`` when
        ``text`` is empty, or a 500 ``JSONResponse`` when synthesis fails.
    """
    # NOTE(review): FastAPI's documented injection uses a plain
    # ``BackgroundTasks`` annotation - confirm the optional union above is
    # resolved as intended by the installed FastAPI version.
    import uuid

    text = (text or "").strip()
    if not text:
        return JSONResponse({"detail": "Missing text"}, status_code=400)

    voice_id = voice.strip()

    # A millisecond timestamp alone can collide between concurrent requests,
    # so add a random suffix to keep each request's file private.
    ts = int(time.time() * 1000)
    out_path = FILES_DIR / f"tts-{ts}-{uuid.uuid4().hex[:8]}.wav"

    try:
        ensure_voice(voice_id)
        await piper_to_file(
            text, voice_id, out_path,
            float(length_scale), float(noise_scale), float(noise_w),
        )
    except Exception as e:
        # Do not leave a partial/empty WAV behind when synthesis fails.
        _safe_unlink(out_path)
        return JSONResponse({"ok": False, "error": str(e)}, status_code=500)

    background_tasks = background_tasks or BackgroundTasks()
    background_tasks.add_task(_safe_unlink, out_path)
    return FileResponse(out_path, media_type="audio/wav", filename=out_path.name, background=background_tasks)
357
+
358
  @app.websocket("/ws/tts")
359
  async def ws_tts(ws: WebSocket):
360
  await ws.accept()
 
401
  pass
402
 
403
  if __name__ == "__main__":
404
+ import os
405
+ uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")), reload=False)