ShadowHunter222 commited on
Commit
9ea9ec8
Β·
verified Β·
1 Parent(s): 1ff75e7

Upload 10 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ 3cpo_prompt.wav filter=lfs diff=lfs merge=lfs -text
37
+ aave_female_01_prompt.wav filter=lfs diff=lfs merge=lfs -text
38
+ her_prompt.wav filter=lfs diff=lfs merge=lfs -text
39
+ ivr_female_02_prompt.wav filter=lfs diff=lfs merge=lfs -text
3cpo_prompt.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a830bbf5494096e593dcfb6e099cfa334cb8b0b34d1403c69d36c02649c5ab15
3
+ size 513452
Dockerfile ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ # Audio codec libraries for soundfile/librosa
4
+ RUN apt-get update && \
5
+ apt-get install -y --no-install-recommends libsndfile1 ffmpeg && \
6
+ rm -rf /var/lib/apt/lists/*
7
+
8
+ WORKDIR /app
9
+
10
+ # Install PyTorch CPU first (from dedicated index for smaller size)
11
+ RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu
12
+
13
+ # Install remaining dependencies
14
+ COPY requirements.txt .
15
+ RUN pip install --no-cache-dir -r requirements.txt
16
+
17
+ # Copy application code
18
+ COPY config.py text_processor.py chatterbox_wrapper.py app.py ./
19
+ COPY 3cpo_prompt.wav aave_female_01_prompt.wav her_prompt.wav ivr_female_02_prompt.wav ./
20
+ # Pre-download ONNX models + tokenizer at build time
21
+ RUN python -c "\
22
+ from chatterbox_wrapper import ChatterboxWrapper; \
23
+ ChatterboxWrapper(download_only=True); \
24
+ print('Models pre-downloaded successfully')"
25
+
26
+ # Prevent thread oversubscription in production
27
+ ENV OMP_NUM_THREADS=1
28
+ ENV MKL_NUM_THREADS=1
29
+ ENV OPENBLAS_NUM_THREADS=1
30
+
31
+ EXPOSE 7860
32
+
33
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
aave_female_01_prompt.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:971a3568a5a1521612bff565ed416aea62e30da3e00a53d771ff2c26da78276d
3
+ size 1217636
app.py ADDED
@@ -0,0 +1,962 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chatterbox Turbo TTS -- FastAPI Server
3
+ ======================================
4
+ Production-ready API with true real-time MP3 streaming,
5
+ in-memory voice cloning, and fully non-blocking inference.
6
+
7
+ Endpoints:
8
+ GET /health -> health check + optional warmup
9
+ GET /info -> model info, supported tags, parameters
10
+ POST /tts -> full audio response (WAV/MP3/FLAC)
11
+ POST /tts/stream -> chunked MP3 streaming (MediaSource-ready)
12
+ POST /tts/true-stream -> alias for /tts/stream (Kokoro compat)
13
+ POST /tts/stop/{stream_id}-> cancel a specific active stream
14
+ POST /tts/stop -> cancel ALL active streams
15
+ POST /v1/audio/speech -> OpenAI-compatible streaming
16
+ """
17
+ import asyncio
18
+ import io
19
+ import json
20
+ import logging
21
+ import queue as stdlib_queue
22
+ import threading
23
+ import time
24
+ import urllib.error
25
+ import urllib.parse
26
+ import urllib.request
27
+ import uuid
28
+ from concurrent.futures import ThreadPoolExecutor
29
+ from typing import Generator, Optional
30
+
31
+ import numpy as np
32
+ import soundfile as sf
33
+ from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
34
+ from fastapi.responses import Response, StreamingResponse
35
+ from contextlib import asynccontextmanager
36
+
37
+ from config import Config
38
+ from chatterbox_wrapper import ChatterboxWrapper, GenerationCancelled, VoiceProfile
39
+ import text_processor
40
+
41
+ # ── Logging ───────────────────────────────────────────────────────
42
+ logging.basicConfig(
43
+ level=logging.INFO,
44
+ format="%(asctime)s β”‚ %(levelname)-7s β”‚ %(name)s β”‚ %(message)s",
45
+ datefmt="%H:%M:%S",
46
+ )
47
+ logger = logging.getLogger(__name__)
48
+
49
+ # ── Thread pool for CPU-bound inference ───────────────────────────
50
+ tts_executor = ThreadPoolExecutor(max_workers=Config.MAX_WORKERS)
51
+
52
+
53
+ # ── Lifespan ──────────────────────────────────────────────────────
54
+
55
+ @asynccontextmanager
56
+ async def lifespan(app: FastAPI):
57
+ try:
58
+ wrapper = ChatterboxWrapper()
59
+ app.state.wrapper = wrapper
60
+ logger.info("βœ… Model loaded, server ready")
61
+ except Exception as e:
62
+ logger.error(f"❌ Model loading failed: {e}")
63
+ raise
64
+ yield
65
+ tts_executor.shutdown(wait=False)
66
+
67
+
68
+ app = FastAPI(
69
+ title="Chatterbox Turbo TTS API",
70
+ version="1.0.0",
71
+ docs_url="/docs",
72
+ lifespan=lifespan,
73
+ )
74
+
75
+
76
+ # ── CORS Middleware ───────────────────────────────────────────────
77
+
78
+ @app.middleware("http")
79
+ async def cors_middleware(request: Request, call_next):
80
+ origin = request.headers.get("origin")
81
+
82
+ # Preflight
83
+ if request.method == "OPTIONS" and origin in Config.ALLOWED_ORIGINS:
84
+ return Response(
85
+ status_code=200,
86
+ headers={
87
+ "Access-Control-Allow-Origin": origin,
88
+ "Access-Control-Allow-Methods": "*",
89
+ "Access-Control-Allow-Headers": "*",
90
+ "Access-Control-Allow-Credentials": "true",
91
+ },
92
+ )
93
+
94
+ if not origin or origin in Config.ALLOWED_ORIGINS:
95
+ response = await call_next(request)
96
+ if origin:
97
+ response.headers["Access-Control-Allow-Origin"] = origin
98
+ response.headers["Access-Control-Allow-Credentials"] = "true"
99
+ response.headers["Access-Control-Allow-Methods"] = "*"
100
+ response.headers["Access-Control-Allow-Headers"] = "*"
101
+ response.headers["Access-Control-Expose-Headers"] = "X-Stream-Id"
102
+ return response
103
+
104
+ logger.warning(f"🚫 Blocked origin: {origin}")
105
+ return Response(status_code=403, content="Forbidden: Origin not allowed")
106
+
107
+
108
+ # ═══════════════════════════════════════════════════════════════════
109
+ # Helper: resolve voice from optional upload
110
+ # ═══════════════════════════════════════════════════════════════════
111
+
112
+ async def _resolve_voice(
113
+ voice_ref: Optional[UploadFile],
114
+ voice_name: str,
115
+ wrapper: ChatterboxWrapper,
116
+ ) -> VoiceProfile:
117
+ """Return a VoiceProfile from uploaded audio, built-in voice name, or default."""
118
+ # 1) If a file was uploaded, encode it (highest priority)
119
+ if voice_ref is not None and voice_ref.filename:
120
+ audio_bytes = await voice_ref.read()
121
+ if len(audio_bytes) > Config.MAX_VOICE_UPLOAD_BYTES:
122
+ raise HTTPException(status_code=413, detail="Voice file too large (max 10 MB)")
123
+ if len(audio_bytes) == 0:
124
+ raise HTTPException(status_code=400, detail="Empty voice file")
125
+
126
+ loop = asyncio.get_running_loop()
127
+ try:
128
+ return await loop.run_in_executor(
129
+ tts_executor, wrapper.encode_voice_from_bytes, audio_bytes
130
+ )
131
+ except ValueError as e:
132
+ raise HTTPException(status_code=400, detail=str(e))
133
+ except Exception as e:
134
+ logger.error(f"Voice encoding failed: {e}")
135
+ raise HTTPException(
136
+ status_code=400,
137
+ detail=f"Could not process voice file: {str(e)}. "
138
+ f"Supported formats: WAV, MP3, MPEG, M4A, OGG, FLAC, WebM."
139
+ )
140
+
141
+ # 2) Resolve by built-in voice name (returns cached profile β€” no encoding)
142
+ try:
143
+ return wrapper.get_builtin_voice(voice_name)
144
+ except (ValueError, KeyError) as e:
145
+ raise HTTPException(status_code=400, detail=str(e))
146
+
147
+
148
+ # ═══════════════════════════════════════════════════════════════════
149
+ # Helper: encode numpy audio to bytes in given format
150
+ # ═══════════════════════════════════════════════════════════════════
151
+
152
+ def _encode_audio(audio: np.ndarray, fmt: str = "wav") -> tuple[bytes, str]:
153
+ buf = io.BytesIO()
154
+ fmt_lower = fmt.lower()
155
+ if fmt_lower == "mp3":
156
+ sf.write(buf, audio, Config.SAMPLE_RATE, format="mp3")
157
+ media = "audio/mpeg"
158
+ elif fmt_lower == "flac":
159
+ sf.write(buf, audio, Config.SAMPLE_RATE, format="flac")
160
+ media = "audio/flac"
161
+ else:
162
+ sf.write(buf, audio, Config.SAMPLE_RATE, format="wav")
163
+ media = "audio/wav"
164
+ return buf.getvalue(), media
165
+
166
+
167
+ def _encode_mp3_chunk(audio: np.ndarray) -> bytes:
168
+ """Encode one numpy chunk to MP3 bytes (same encoder path as current server)."""
169
+ data, _ = _encode_audio(audio, fmt="mp3")
170
+ return data
171
+
172
+
173
+ def _build_helper_endpoint(base_url: str, path: str) -> str:
174
+ return f"{base_url.rstrip('/')}{path}"
175
+
176
+
177
+ def _internal_headers() -> dict[str, str]:
178
+ headers = {"Content-Type": "application/json", "Accept": "audio/mpeg"}
179
+ if Config.INTERNAL_SHARED_SECRET:
180
+ headers["X-Internal-Secret"] = Config.INTERNAL_SHARED_SECRET
181
+ return headers
182
+
183
+
184
+ def _helper_request_chunk(
185
+ helper_base_url: str,
186
+ payload: dict,
187
+ timeout_sec: float,
188
+ ) -> bytes:
189
+ url = _build_helper_endpoint(helper_base_url, "/internal/chunk/synthesize")
190
+ body = json.dumps(payload).encode("utf-8")
191
+ req = urllib.request.Request(
192
+ url=url,
193
+ data=body,
194
+ headers=_internal_headers(),
195
+ method="POST",
196
+ )
197
+ with urllib.request.urlopen(req, timeout=timeout_sec) as resp:
198
+ return resp.read()
199
+
200
+
201
+ def _helper_register_voice(
202
+ helper_base_url: str,
203
+ stream_id: str,
204
+ audio_bytes: bytes,
205
+ timeout_sec: float,
206
+ ) -> str:
207
+ """Register reference voice on helper once, return voice_key for chunk calls."""
208
+ query = urllib.parse.urlencode({"stream_id": stream_id})
209
+ url = _build_helper_endpoint(helper_base_url, f"/internal/voice/register?{query}")
210
+ headers = {"Content-Type": "application/octet-stream", "Accept": "application/json"}
211
+ if Config.INTERNAL_SHARED_SECRET:
212
+ headers["X-Internal-Secret"] = Config.INTERNAL_SHARED_SECRET
213
+
214
+ req = urllib.request.Request(
215
+ url=url,
216
+ data=audio_bytes,
217
+ headers=headers,
218
+ method="POST",
219
+ )
220
+ with urllib.request.urlopen(req, timeout=timeout_sec) as resp:
221
+ data = json.loads(resp.read().decode("utf-8"))
222
+ voice_key = (data.get("voice_key") or "").strip()
223
+ if not voice_key:
224
+ raise RuntimeError("helper voice registration returned no voice_key")
225
+ return voice_key
226
+
227
+
228
+ def _helper_cancel_stream(helper_base_url: str, stream_id: str):
229
+ """Best-effort cancellation signal to helper."""
230
+ try:
231
+ url = _build_helper_endpoint(helper_base_url, f"/internal/chunk/cancel/{stream_id}")
232
+ req = urllib.request.Request(
233
+ url=url,
234
+ data=b"",
235
+ headers=_internal_headers(),
236
+ method="POST",
237
+ )
238
+ with urllib.request.urlopen(req, timeout=3.0):
239
+ pass
240
+ except Exception:
241
+ pass
242
+
243
+
244
+ # ═══════════════════════════════════════════════════════════════════
245
+ # Endpoints
246
+ # ═══════════════════════════════════════════════════════════════════
247
+
248
+ @app.get("/health")
249
+ async def health(warm_up: bool = False):
250
+ wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
251
+ status = {
252
+ "status": "healthy" if wrapper else "loading",
253
+ "model_loaded": wrapper is not None,
254
+ "model_dtype": Config.MODEL_DTYPE,
255
+ "streaming_supported": True,
256
+ "voice_cache_entries": wrapper._voice_cache.size if wrapper else 0,
257
+ }
258
+ if warm_up and wrapper:
259
+ try:
260
+ loop = asyncio.get_running_loop()
261
+ await loop.run_in_executor(tts_executor, wrapper.warmup)
262
+ status["warm_up"] = "success"
263
+ except Exception as e:
264
+ status["warm_up"] = f"failed: {e}"
265
+ return status
266
+
267
+
268
+ @app.get("/info")
269
+ async def info():
270
+ return {
271
+ "model": Config.MODEL_ID,
272
+ "dtype": Config.MODEL_DTYPE,
273
+ "sample_rate": Config.SAMPLE_RATE,
274
+ "paralinguistic_tags": list(Config.PARALINGUISTIC_TAGS),
275
+ "tag_usage": "Insert tags directly in text, e.g. 'That is so funny! [laugh] Anyway…'",
276
+ "parameters": {
277
+ "max_new_tokens": {"default": Config.MAX_NEW_TOKENS, "range": "64–2048"},
278
+ "repetition_penalty": {"default": Config.REPETITION_PENALTY, "range": "1.0–2.0"},
279
+ },
280
+ "voice_cloning": {
281
+ "description": "Upload 3–30s reference WAV/MP3 as 'voice_ref' field",
282
+ "max_upload_mb": Config.MAX_VOICE_UPLOAD_BYTES // (1024 * 1024),
283
+ },
284
+ "parallel_mode": {
285
+ "enabled": Config.ENABLE_PARALLEL_MODE,
286
+ "helper_configured": bool(Config.HELPER_BASE_URL),
287
+ "helper_base_url": Config.HELPER_BASE_URL or None,
288
+ "supports_voice_ref": True,
289
+ },
290
+ }
291
+
292
+
293
+ @app.get("/voices")
294
+ async def list_voices():
295
+ """Return all built-in voices available for selection."""
296
+ wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
297
+ if not wrapper:
298
+ raise HTTPException(503, "Model not loaded")
299
+ return {
300
+ "default": wrapper.default_voice_name,
301
+ "voices": wrapper.list_builtin_voices(),
302
+ }
303
+
304
+
305
+ # ── POST /tts ─────────────────────────────────────────────────────
306
+
307
+ @app.post("/tts", response_class=Response)
308
+ async def text_to_speech(
309
+ text: str = Form(...),
310
+ voice_ref: Optional[UploadFile] = File(None),
311
+ voice_name: str = Form("default"),
312
+ output_format: str = Form("wav"),
313
+ max_new_tokens: int = Form(Config.MAX_NEW_TOKENS),
314
+ repetition_penalty: float = Form(Config.REPETITION_PENALTY),
315
+ ):
316
+ """Generate complete audio for the given text."""
317
+ wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
318
+ if not wrapper:
319
+ raise HTTPException(503, "Model not loaded")
320
+
321
+ if not text or not text.strip():
322
+ raise HTTPException(400, "Text is required")
323
+
324
+ voice = await _resolve_voice(voice_ref, voice_name, wrapper)
325
+
326
+ loop = asyncio.get_running_loop()
327
+ try:
328
+ audio = await loop.run_in_executor(
329
+ tts_executor,
330
+ wrapper.generate_speech,
331
+ text, voice, max_new_tokens, repetition_penalty,
332
+ )
333
+ except ValueError as e:
334
+ raise HTTPException(400, str(e))
335
+ except Exception as e:
336
+ logger.error(f"TTS error: {e}")
337
+ raise HTTPException(500, "Internal server error")
338
+
339
+ data, media_type = _encode_audio(audio, output_format)
340
+ return Response(
341
+ content=data,
342
+ media_type=media_type,
343
+ headers={"Content-Disposition": f"attachment; filename=tts_output.{output_format}"},
344
+ )
345
+
346
+ # ═══════════════════════════════════════════════════════════════════
347
+ # Active Stream Registry (for cancellation)
348
+ # ═══════════════════════════════════════════════════════════════════
349
+
350
+ _active_streams: dict[str, threading.Event] = {}
351
+ _internal_cancelled_streams: set[str] = set()
352
+ _internal_cancel_lock = threading.Lock()
353
+ _internal_stream_voice_keys: dict[str, set[str]] = {}
354
+
355
+
356
+ # ═══════════════════════════════════════════════════════════════════
357
+ # Pipeline Streaming Generator
358
+ # ═══════════════════════════════════════════════════════════════════
359
+
360
+ def _pipeline_stream_generator(
361
+ wrapper: ChatterboxWrapper,
362
+ text: str,
363
+ voice: VoiceProfile,
364
+ max_new_tokens: int,
365
+ repetition_penalty: float,
366
+ stream_id: str,
367
+ ) -> Generator[bytes, None, None]:
368
+ """Two-stage producer-consumer pipeline for minimal inter-chunk gaps.
369
+
370
+ Architecture:
371
+ Producer thread (heavyweight, ~80% CPU):
372
+ ONNX token generation β†’ audio decoding β†’ raw numpy arrays β†’ queue
373
+
374
+ Consumer (this generator, lightweight, ~20% CPU):
375
+ queue β†’ MP3 encode β†’ yield to HTTP response
376
+
377
+ Why this helps:
378
+ - ONNX model runs CONTINUOUSLY without waiting for MP3 encode or HTTP
379
+ - MP3 encoding (libsndfile, C code) releases GIL β†’ true parallelism
380
+ - ONNX inference (C++ code) also releases GIL β†’ both run simultaneously
381
+ - Queue(maxsize=2) lets producer stay 1-2 chunks ahead
382
+
383
+ Cancellation:
384
+ - cancel_event checked between chunks + every 25 autoregressive steps
385
+ - Client disconnect triggers GeneratorExit β†’ finally sets cancel
386
+ - /tts/stop endpoint sets cancel externally
387
+ """
388
+ cancel_event = threading.Event()
389
+ _active_streams[stream_id] = cancel_event
390
+
391
+ # Raw audio buffer: producer puts numpy arrays, consumer takes them
392
+ audio_buffer: stdlib_queue.Queue = stdlib_queue.Queue(maxsize=2)
393
+
394
+ def _producer():
395
+ """Heavyweight worker: runs ONNX model continuously."""
396
+ try:
397
+ for audio_chunk in wrapper.stream_speech(
398
+ text, voice,
399
+ max_new_tokens=max_new_tokens,
400
+ repetition_penalty=repetition_penalty,
401
+ is_cancelled=cancel_event.is_set,
402
+ ):
403
+ if cancel_event.is_set():
404
+ break
405
+ while not cancel_event.is_set():
406
+ try:
407
+ audio_buffer.put(audio_chunk, timeout=0.1)
408
+ break
409
+ except stdlib_queue.Full:
410
+ continue
411
+ except GenerationCancelled:
412
+ logger.info(f"[{stream_id}] Generation cancelled")
413
+ except Exception as e:
414
+ while not cancel_event.is_set():
415
+ try:
416
+ audio_buffer.put(e, timeout=0.1)
417
+ break
418
+ except stdlib_queue.Full:
419
+ continue
420
+ finally:
421
+ while not cancel_event.is_set():
422
+ try:
423
+ audio_buffer.put(None, timeout=0.1)
424
+ break
425
+ except stdlib_queue.Full:
426
+ continue
427
+
428
+ producer = threading.Thread(target=_producer, daemon=True)
429
+ producer.start()
430
+
431
+ try:
432
+ # Consumer: lightweight MP3 encoding + yield
433
+ while True:
434
+ item = audio_buffer.get()
435
+ if item is None:
436
+ break
437
+ if isinstance(item, Exception):
438
+ logger.error(f"[{stream_id}] Stream error: {item}")
439
+ break
440
+ if cancel_event.is_set():
441
+ break
442
+
443
+ # MP3 encode (C code, releases GIL, runs parallel with next ONNX step)
444
+ buf = io.BytesIO()
445
+ sf.write(buf, item, Config.SAMPLE_RATE, format="mp3")
446
+ yield buf.getvalue()
447
+ finally:
448
+ # Cleanup: signal producer to stop + deregister
449
+ cancel_event.set()
450
+ _active_streams.pop(stream_id, None)
451
+
452
+
453
+ def _parallel_odd_even_stream_generator(
454
+ wrapper: ChatterboxWrapper,
455
+ text: str,
456
+ local_voice: VoiceProfile,
457
+ helper_voice_bytes: Optional[bytes],
458
+ max_new_tokens: int,
459
+ repetition_penalty: float,
460
+ stream_id: str,
461
+ helper_base_url: str,
462
+ ) -> Generator[bytes, None, None]:
463
+ """Additive odd/even split streamer (primary handles odd, helper handles even)."""
464
+ cancel_event = threading.Event()
465
+ _active_streams[stream_id] = cancel_event
466
+
467
+ clean_text = text_processor.sanitize(text.strip()[: Config.MAX_TEXT_LENGTH])
468
+ chunks = text_processor.split_for_streaming(clean_text)
469
+ total_chunks = len(chunks)
470
+ if total_chunks == 0:
471
+ _active_streams.pop(stream_id, None)
472
+ return
473
+
474
+ lock = threading.Lock()
475
+ cond = threading.Condition(lock)
476
+ ready: dict[int, bytes] = {}
477
+ first_error: Optional[Exception] = None
478
+ workers_done = 0
479
+
480
+ def _publish(idx: int, data: bytes):
481
+ with cond:
482
+ ready[idx] = data
483
+ cond.notify_all()
484
+
485
+ def _set_error(err: Exception):
486
+ nonlocal first_error
487
+ with cond:
488
+ if first_error is None:
489
+ first_error = err
490
+ cond.notify_all()
491
+
492
+ def _worker_done():
493
+ nonlocal workers_done
494
+ with cond:
495
+ workers_done += 1
496
+ cond.notify_all()
497
+
498
+ def _synth_local(chunk_text: str) -> bytes:
499
+ audio = wrapper.generate_speech(
500
+ chunk_text,
501
+ local_voice,
502
+ max_new_tokens=max_new_tokens,
503
+ repetition_penalty=repetition_penalty,
504
+ )
505
+ return _encode_mp3_chunk(audio)
506
+
507
+ def _odd_worker():
508
+ try:
509
+ for idx in range(0, total_chunks, 2):
510
+ if cancel_event.is_set():
511
+ break
512
+ data = _synth_local(chunks[idx])
513
+ _publish(idx, data)
514
+ except Exception as e:
515
+ _set_error(e)
516
+ finally:
517
+ _worker_done()
518
+
519
+ def _even_worker():
520
+ helper_available = True
521
+ helper_voice_key: Optional[str] = None
522
+ try:
523
+ if helper_voice_bytes:
524
+ attempts = 2 if Config.HELPER_RETRY_ONCE else 1
525
+ last_err: Optional[Exception] = None
526
+ for _ in range(attempts):
527
+ try:
528
+ helper_voice_key = _helper_register_voice(
529
+ helper_base_url=helper_base_url,
530
+ stream_id=stream_id,
531
+ audio_bytes=helper_voice_bytes,
532
+ timeout_sec=max(1.0, Config.HELPER_TIMEOUT_SEC),
533
+ )
534
+ last_err = None
535
+ break
536
+ except Exception as reg_err:
537
+ last_err = reg_err
538
+ continue
539
+ if last_err is not None:
540
+ helper_available = False
541
+ logger.warning(
542
+ f"[{stream_id}] Helper voice registration failed; "
543
+ "falling back to local synthesis for even chunks"
544
+ )
545
+
546
+ for idx in range(1, total_chunks, 2):
547
+ if cancel_event.is_set():
548
+ break
549
+
550
+ if helper_available:
551
+ payload = {
552
+ "stream_id": stream_id,
553
+ "chunk_index": idx,
554
+ "text": chunks[idx],
555
+ "max_new_tokens": max_new_tokens,
556
+ "repetition_penalty": repetition_penalty,
557
+ "output_format": "mp3",
558
+ }
559
+ if helper_voice_key:
560
+ payload["voice_key"] = helper_voice_key
561
+
562
+ attempts = 2 if Config.HELPER_RETRY_ONCE else 1
563
+ last_err: Optional[Exception] = None
564
+ for _ in range(attempts):
565
+ try:
566
+ helper_data = _helper_request_chunk(
567
+ helper_base_url=helper_base_url,
568
+ payload=payload,
569
+ timeout_sec=max(1.0, Config.HELPER_TIMEOUT_SEC),
570
+ )
571
+ _publish(idx, helper_data)
572
+ last_err = None
573
+ break
574
+ except Exception as helper_err:
575
+ last_err = helper_err
576
+ continue
577
+
578
+ if last_err is None:
579
+ continue
580
+
581
+ helper_available = False
582
+ logger.warning(
583
+ f"[{stream_id}] Helper failed at chunk {idx}; "
584
+ "falling back to local synthesis for remaining even chunks"
585
+ )
586
+
587
+ # Local fallback for even chunks
588
+ data = _synth_local(chunks[idx])
589
+ _publish(idx, data)
590
+ except Exception as e:
591
+ _set_error(e)
592
+ finally:
593
+ _worker_done()
594
+
595
+ odd_thread = threading.Thread(target=_odd_worker, daemon=True)
596
+ even_thread = threading.Thread(target=_even_worker, daemon=True)
597
+ odd_thread.start()
598
+ even_thread.start()
599
+
600
+ next_idx = 0
601
+ try:
602
+ while next_idx < total_chunks:
603
+ with cond:
604
+ while (
605
+ next_idx not in ready
606
+ and first_error is None
607
+ and not cancel_event.is_set()
608
+ and workers_done < 2
609
+ ):
610
+ cond.wait(timeout=0.1)
611
+
612
+ if cancel_event.is_set():
613
+ break
614
+
615
+ if next_idx in ready:
616
+ data = ready.pop(next_idx)
617
+ elif first_error is not None:
618
+ logger.error(f"[{stream_id}] Parallel stream error: {first_error}")
619
+ break
620
+ elif workers_done >= 2:
621
+ logger.error(
622
+ f"[{stream_id}] Parallel stream ended with missing chunk index {next_idx}"
623
+ )
624
+ break
625
+ else:
626
+ continue
627
+
628
+ yield data
629
+ next_idx += 1
630
+ finally:
631
+ cancel_event.set()
632
+ _helper_cancel_stream(helper_base_url, stream_id)
633
+ odd_thread.join(timeout=1.0)
634
+ even_thread.join(timeout=1.0)
635
+ _active_streams.pop(stream_id, None)
636
+
637
+
638
+ # ── POST /tts/stream & /tts/true-stream ──────────────────────────
639
+
640
+ @app.post("/tts/stream")
641
+ @app.post("/tts/true-stream")
642
+ async def stream_text_to_speech(
643
+ text: str = Form(...),
644
+ voice_ref: Optional[UploadFile] = File(None),
645
+ voice_name: str = Form("default"),
646
+ max_new_tokens: int = Form(Config.MAX_NEW_TOKENS),
647
+ repetition_penalty: float = Form(Config.REPETITION_PENALTY),
648
+ ):
649
+ """True real-time streaming: yields MP3 chunks as each sentence finishes.
650
+
651
+ Response includes X-Stream-Id header for cancellation via /tts/stop.
652
+ Compatible with frontend's MediaSource + ReadableStream pattern.
653
+ """
654
+ wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
655
+ if not wrapper:
656
+ raise HTTPException(503, "Model not loaded")
657
+
658
+ if not text or not text.strip():
659
+ raise HTTPException(400, "Text is required")
660
+
661
+ voice = await _resolve_voice(voice_ref, voice_name, wrapper)
662
+ stream_id = uuid.uuid4().hex[:12]
663
+
664
+ return StreamingResponse(
665
+ _pipeline_stream_generator(
666
+ wrapper, text, voice, max_new_tokens, repetition_penalty, stream_id,
667
+ ),
668
+ media_type="audio/mpeg",
669
+ headers={
670
+ "Content-Disposition": "attachment; filename=tts_stream.mp3",
671
+ "Transfer-Encoding": "chunked",
672
+ "X-Stream-Id": stream_id,
673
+ "X-Streaming-Type": "true-realtime",
674
+ "Cache-Control": "no-cache",
675
+ },
676
+ )
677
+
678
+
679
+ @app.post("/tts/parallel-stream")
680
+ async def parallel_stream_text_to_speech(
681
+ text: str = Form(...),
682
+ voice_ref: Optional[UploadFile] = File(None),
683
+ voice_name: str = Form("default"),
684
+ max_new_tokens: int = Form(Config.MAX_NEW_TOKENS),
685
+ repetition_penalty: float = Form(Config.REPETITION_PENALTY),
686
+ helper_url: Optional[str] = Form(None),
687
+ ):
688
+ """Additive odd/even split stream mode (primary + helper)."""
689
+ wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
690
+ if not wrapper:
691
+ raise HTTPException(503, "Model not loaded")
692
+ if not Config.ENABLE_PARALLEL_MODE:
693
+ raise HTTPException(503, "Parallel mode is disabled")
694
+ if not text or not text.strip():
695
+ raise HTTPException(400, "Text is required")
696
+
697
+ local_voice: VoiceProfile = wrapper.default_voice
698
+ helper_voice_bytes: Optional[bytes] = None
699
+ if voice_ref is not None and voice_ref.filename:
700
+ helper_voice_bytes = await voice_ref.read()
701
+ if len(helper_voice_bytes) > Config.MAX_VOICE_UPLOAD_BYTES:
702
+ raise HTTPException(status_code=413, detail="Voice file too large (max 10 MB)")
703
+ if len(helper_voice_bytes) == 0:
704
+ raise HTTPException(status_code=400, detail="Empty voice file")
705
+ loop = asyncio.get_running_loop()
706
+ try:
707
+ local_voice = await loop.run_in_executor(
708
+ tts_executor, wrapper.encode_voice_from_bytes, helper_voice_bytes
709
+ )
710
+ except Exception as e:
711
+ logger.error(f"Parallel voice encoding failed: {e}")
712
+ raise HTTPException(400, "Could not process voice file for parallel mode")
713
+ else:
714
+ # Built-in voice selected by name β€” resolve locally and prepare
715
+ # bytes for helper registration so helpers cache the same hash.
716
+ try:
717
+ selected_voice_id = wrapper.resolve_voice_id(voice_name)
718
+ local_voice = wrapper.get_builtin_voice(selected_voice_id)
719
+ except ValueError as e:
720
+ raise HTTPException(status_code=400, detail=str(e))
721
+
722
+ # Only send bytes to helper if a non-default voice was selected,
723
+ # because the helper's own default is already loaded.
724
+ if selected_voice_id != wrapper.default_voice_name:
725
+ helper_voice_bytes = wrapper.get_builtin_voice_bytes(selected_voice_id)
726
+ if not helper_voice_bytes:
727
+ raise HTTPException(
728
+ status_code=400,
729
+ detail=f"Selected voice '{voice_name}' is unavailable for helper registration",
730
+ )
731
+
732
+ resolved_helper = (helper_url or Config.HELPER_BASE_URL).strip()
733
+ if not resolved_helper:
734
+ raise HTTPException(
735
+ 400,
736
+ "Helper URL not configured. Set CB_HELPER_BASE_URL or pass helper_url.",
737
+ )
738
+
739
+ stream_id = uuid.uuid4().hex[:12]
740
+ return StreamingResponse(
741
+ _parallel_odd_even_stream_generator(
742
+ wrapper=wrapper,
743
+ text=text,
744
+ local_voice=local_voice,
745
+ helper_voice_bytes=helper_voice_bytes,
746
+ max_new_tokens=max_new_tokens,
747
+ repetition_penalty=repetition_penalty,
748
+ stream_id=stream_id,
749
+ helper_base_url=resolved_helper,
750
+ ),
751
+ media_type="audio/mpeg",
752
+ headers={
753
+ "Content-Disposition": "attachment; filename=tts_parallel_stream.mp3",
754
+ "Transfer-Encoding": "chunked",
755
+ "X-Stream-Id": stream_id,
756
+ "X-Streaming-Type": "parallel-odd-even",
757
+ "Cache-Control": "no-cache",
758
+ },
759
+ )
760
+
761
+
762
+ # ── JSON body variant (Kokoro/OpenAI compatibility) ───────────────
763
+
764
+ from pydantic import BaseModel, Field
765
+
766
+
767
+ class InternalChunkRequest(BaseModel):
768
+ stream_id: str = Field(..., min_length=1, max_length=64)
769
+ chunk_index: int = Field(..., ge=0)
770
+ text: str = Field(..., min_length=1, max_length=10000)
771
+ max_new_tokens: int = Field(default=Config.MAX_NEW_TOKENS, ge=64, le=2048)
772
+ repetition_penalty: float = Field(default=Config.REPETITION_PENALTY, ge=1.0, le=2.0)
773
+ output_format: str = Field(default="mp3")
774
+ voice_key: Optional[str] = Field(default=None, min_length=1, max_length=64)
775
+
776
+
777
+ class TTSJsonRequest(BaseModel):
778
+ text: str = Field(..., min_length=1, max_length=50000)
779
+ voice: str = Field(default="default")
780
+ speed: float = Field(default=1.0, ge=0.5, le=2.0) # reserved for future use
781
+ max_new_tokens: int = Field(default=Config.MAX_NEW_TOKENS, ge=64, le=2048)
782
+ repetition_penalty: float = Field(default=Config.REPETITION_PENALTY, ge=1.0, le=2.0)
783
+
784
+
785
+ @app.post("/internal/voice/register")
786
+ async def internal_voice_register(http_request: Request):
787
+ """Register voice once for a stream; returns reusable voice_key."""
788
+ if Config.INTERNAL_SHARED_SECRET:
789
+ provided = http_request.headers.get("X-Internal-Secret", "")
790
+ if provided != Config.INTERNAL_SHARED_SECRET:
791
+ raise HTTPException(403, "Forbidden")
792
+
793
+ wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
794
+ if not wrapper:
795
+ raise HTTPException(503, "Model not loaded")
796
+
797
+ audio_bytes = await http_request.body()
798
+ if len(audio_bytes) > Config.MAX_VOICE_UPLOAD_BYTES:
799
+ raise HTTPException(status_code=413, detail="Voice file too large (max 10 MB)")
800
+ if len(audio_bytes) == 0:
801
+ raise HTTPException(status_code=400, detail="Empty voice file")
802
+
803
+ loop = asyncio.get_running_loop()
804
+ try:
805
+ voice = await loop.run_in_executor(
806
+ tts_executor, wrapper.encode_voice_from_bytes, audio_bytes
807
+ )
808
+ except Exception as e:
809
+ logger.error(f"[internal] voice register failed: {e}")
810
+ raise HTTPException(400, "Voice registration failed")
811
+
812
+ voice_key = (voice.audio_hash or "").strip()
813
+ if not voice_key:
814
+ raise HTTPException(500, "Voice key unavailable")
815
+
816
+ stream_id = (http_request.query_params.get("stream_id") or "").strip()
817
+ if stream_id:
818
+ with _internal_cancel_lock:
819
+ keys = _internal_stream_voice_keys.setdefault(stream_id, set())
820
+ keys.add(voice_key)
821
+
822
+ return {"status": "registered", "voice_key": voice_key}
823
+
824
+
825
+ @app.post("/internal/chunk/synthesize")
826
+ async def internal_chunk_synthesize(
827
+ request: InternalChunkRequest,
828
+ http_request: Request,
829
+ ):
830
+ """Internal endpoint used by primary/helper parallel routing."""
831
+ if Config.INTERNAL_SHARED_SECRET:
832
+ provided = http_request.headers.get("X-Internal-Secret", "")
833
+ if provided != Config.INTERNAL_SHARED_SECRET:
834
+ raise HTTPException(403, "Forbidden")
835
+
836
+ with _internal_cancel_lock:
837
+ if request.stream_id in _internal_cancelled_streams:
838
+ raise HTTPException(409, "Stream already cancelled")
839
+
840
+ wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
841
+ if not wrapper:
842
+ raise HTTPException(503, "Model not loaded")
843
+
844
+ voice_profile = wrapper.default_voice
845
+ if request.voice_key:
846
+ cached_voice = wrapper._voice_cache.get(request.voice_key)
847
+ if cached_voice is None:
848
+ # Built-in voices are permanent in wrapper registry even if TTL cache entry expired.
849
+ cached_voice = wrapper.get_builtin_voice_by_hash(request.voice_key)
850
+ if cached_voice is None:
851
+ raise HTTPException(409, "Voice key expired or not found")
852
+ voice_profile = cached_voice
853
+
854
+ loop = asyncio.get_running_loop()
855
+ try:
856
+ audio = await loop.run_in_executor(
857
+ tts_executor,
858
+ wrapper.generate_speech,
859
+ request.text,
860
+ voice_profile,
861
+ request.max_new_tokens,
862
+ request.repetition_penalty,
863
+ )
864
+ except Exception as e:
865
+ logger.error(f"[internal] chunk {request.chunk_index} failed: {e}")
866
+ raise HTTPException(500, "Chunk synthesis failed")
867
+
868
+ fmt = (request.output_format or "mp3").lower()
869
+ if fmt not in {"mp3", "wav", "flac"}:
870
+ fmt = "mp3"
871
+ data, media_type = _encode_audio(audio, fmt=fmt)
872
+ return Response(
873
+ content=data,
874
+ media_type=media_type,
875
+ headers={
876
+ "X-Stream-Id": request.stream_id,
877
+ "X-Chunk-Index": str(request.chunk_index),
878
+ },
879
+ )
880
+
881
+
882
+ @app.post("/internal/chunk/cancel/{stream_id}")
883
+ async def internal_chunk_cancel(stream_id: str, http_request: Request):
884
+ if Config.INTERNAL_SHARED_SECRET:
885
+ provided = http_request.headers.get("X-Internal-Secret", "")
886
+ if provided != Config.INTERNAL_SHARED_SECRET:
887
+ raise HTTPException(403, "Forbidden")
888
+
889
+ with _internal_cancel_lock:
890
+ _internal_cancelled_streams.add(stream_id)
891
+ _internal_stream_voice_keys.pop(stream_id, None)
892
+ return {"status": "cancelled", "stream_id": stream_id}
893
+
894
+
895
+ @app.post("/v1/audio/speech")
896
+ async def openai_compatible_tts(request: TTSJsonRequest):
897
+ """OpenAI-compatible streaming endpoint (JSON body, no file upload).
898
+
899
+ Uses built-in voice selection via `voice`. For voice cloning, use /tts/stream with FormData.
900
+ """
901
+ wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
902
+ if not wrapper:
903
+ raise HTTPException(503, "Model not loaded")
904
+
905
+ try:
906
+ selected_voice = wrapper.get_builtin_voice(request.voice)
907
+ except ValueError as e:
908
+ raise HTTPException(400, str(e))
909
+
910
+ stream_id = uuid.uuid4().hex[:12]
911
+
912
+ return StreamingResponse(
913
+ _pipeline_stream_generator(
914
+ wrapper, request.text, selected_voice,
915
+ request.max_new_tokens, request.repetition_penalty, stream_id,
916
+ ),
917
+ media_type="audio/mpeg",
918
+ headers={
919
+ "Transfer-Encoding": "chunked",
920
+ "X-Stream-Id": stream_id,
921
+ "Cache-Control": "no-cache",
922
+ },
923
+ )
924
+
925
+
926
+ # ═══════════════════════════════════════════════════════════════════
927
+ # Stop / Cancel Endpoint
928
+ # ═══════════════════════════════════════════════════════════════════
929
+
930
+ @app.post("/tts/stop/{stream_id}")
931
+ async def stop_stream(stream_id: str):
932
+ """Stop an active TTS stream by its ID (from X-Stream-Id header).
933
+
934
+ Cancels the ONNX generation loop mid-token, freeing CPU immediately.
935
+ """
936
+ event = _active_streams.get(stream_id)
937
+ if event:
938
+ event.set()
939
+ logger.info(f"Stream {stream_id} cancelled by client")
940
+ return {"status": "stopped", "stream_id": stream_id}
941
+ return {"status": "not_found", "stream_id": stream_id}
942
+
943
+
944
+ @app.post("/tts/stop")
945
+ async def stop_all_streams():
946
+ """Emergency stop: cancel ALL active TTS streams."""
947
+ count = len(_active_streams)
948
+ for sid, event in list(_active_streams.items()):
949
+ event.set()
950
+ _active_streams.clear()
951
+ logger.info(f"Stopped all streams ({count} active)")
952
+ return {"status": "stopped_all", "count": count}
953
+
954
+
955
+ # ═══════════════════════════════════════════════════════════════════
956
+ # Entrypoint
957
+ # ═══════════════════════════════════════════════════════════════════
958
+
959
+ if __name__ == "__main__":
960
+ import uvicorn
961
+
962
+ uvicorn.run(app, host=Config.HOST, port=Config.PORT)
chatterbox_wrapper.py ADDED
@@ -0,0 +1,733 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chatterbox Turbo TTS β€” ONNX Inference Wrapper
3
+ ═══════════════════════════════════════════════
4
+ Orchestrates the 4-component ONNX pipeline:
5
+ embed_tokens β†’ speech_encoder β†’ language_model β†’ conditional_decoder
6
+
7
+ Optimised for lowest-latency CPU inference on 2 vCPU:
8
+ β€’ Sequential execution, thread count = physical cores, no spinning
9
+ β€’ Token list pre-allocation (avoids O(nΒ²) np.concatenate in loop)
10
+ β€’ In-memory voice caching (no disk writes for uploads)
11
+ β€’ Robust audio loading: WAV, MP3, MPEG, M4A, OGG, FLAC, WebM
12
+ β€’ Sentence-level streaming for real-time playback
13
+ """
14
+
15
+ # ── Suppress harmless transformers warnings BEFORE import ─────────
16
+ import os
17
+ import warnings
18
+
19
+ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
20
+ warnings.filterwarnings("ignore", message=".*model of type.*chatterbox.*")
21
+
22
+ import hashlib
23
+ import io
24
+ import logging
25
+ import subprocess
26
+ import tempfile
27
+ import time
28
+ from collections import OrderedDict
29
+ from dataclasses import dataclass
30
+ from pathlib import Path
31
+ from typing import Callable, Generator, Optional
32
+
33
+ import librosa
34
+ import numpy as np
35
+ import onnxruntime as ort
36
+ import soundfile as soundfile_lib
37
+ from huggingface_hub import hf_hub_download
38
+ from transformers import AutoTokenizer
39
+
40
+ from config import Config
41
+ import text_processor
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+ # ── Supported audio MIME types for voice upload ───────────────────
46
+ _SUPPORTED_AUDIO_EXTENSIONS = {
47
+ ".wav", ".mp3", ".mpeg", ".mpga", ".m4a", ".mp4",
48
+ ".ogg", ".oga", ".opus", ".flac", ".webm", ".aac", ".wma",
49
+ }
50
+
51
+
52
+ def _slugify(text: str) -> str:
53
+ """Convert a display name to a safe, lowercase identifier."""
54
+ buf = []
55
+ prev_underscore = False
56
+ for ch in text.strip().lower():
57
+ if ch.isalnum():
58
+ buf.append(ch)
59
+ prev_underscore = False
60
+ else:
61
+ if not prev_underscore:
62
+ buf.append("_")
63
+ prev_underscore = True
64
+ slug = "".join(buf).strip("_")
65
+ return slug or "voice"
66
+
67
+
68
+
69
+ # ═══════════════════════════════════════════════════════════════════
70
+ # Data Structures
71
+ # ═══════════════════════════════════════════════════════════════════
72
+
73
+ @dataclass
74
+ class VoiceProfile:
75
+ """Cached speaker embedding extracted from reference audio."""
76
+ cond_emb: np.ndarray
77
+ prompt_token: np.ndarray
78
+ speaker_embeddings: np.ndarray
79
+ speaker_features: np.ndarray
80
+ audio_hash: str = ""
81
+
82
+
83
+ class GenerationCancelled(Exception):
84
+ """Raised when inference is cancelled by the client."""
85
+ pass
86
+
87
+
88
+ # ═══════════════════════════════════════════════════════════════════
89
+ # LRU Voice Cache
90
+ # ═══════════════════════════════════════════════════════════════════
91
+
92
+ class _VoiceCache:
93
+ """LRU cache for VoiceProfile objects with TTL-based expiration.
94
+
95
+ Entries auto-expire after `ttl_seconds` (default: 1 hour).
96
+ Re-uploading the same voice file within the TTL window returns
97
+ the cached profile instantly β€” no re-encoding needed.
98
+ """
99
+
100
+ def __init__(self, maxsize: int, ttl_seconds: int = 3600):
101
+ self._cache: OrderedDict[str, tuple[VoiceProfile, float]] = OrderedDict()
102
+ self._maxsize = maxsize
103
+ self._ttl = ttl_seconds
104
+
105
+ def _evict_expired(self):
106
+ """Remove all entries older than TTL."""
107
+ now = time.time()
108
+ expired = [k for k, (_, ts) in self._cache.items() if now - ts > self._ttl]
109
+ for k in expired:
110
+ del self._cache[k]
111
+ logger.debug(f"Voice cache expired: {k[:8]}…")
112
+
113
+ def get(self, key: str) -> Optional[VoiceProfile]:
114
+ self._evict_expired()
115
+ if key in self._cache:
116
+ profile, ts = self._cache[key]
117
+ remaining = self._ttl - (time.time() - ts)
118
+ self._cache.move_to_end(key)
119
+ logger.info(f"Voice cache HIT: {key[:8]}… (expires in {remaining:.0f}s)")
120
+ return profile
121
+ return None
122
+
123
+ def put(self, key: str, profile: VoiceProfile):
124
+ self._evict_expired()
125
+ if key in self._cache:
126
+ self._cache.move_to_end(key)
127
+ else:
128
+ if len(self._cache) >= self._maxsize:
129
+ evicted_key, _ = self._cache.popitem(last=False)
130
+ logger.debug(f"Voice cache evicted (LRU): {evicted_key[:8]}…")
131
+ self._cache[key] = (profile, time.time())
132
+ logger.info(f"Voice cache STORED: {key[:8]}… (TTL: {self._ttl}s, size: {len(self._cache)})")
133
+
134
+ @property
135
+ def size(self) -> int:
136
+ return len(self._cache)
137
+
138
+
139
+ # ═══════════════════════════════════════════════════════════════════
140
+ # Audio Loading (robust multi-format support)
141
+ # ═══════════════════════════════════════════════════════════════════
142
+
143
+ def _load_audio_bytes(audio_bytes: bytes, sr: int = 24000) -> np.ndarray:
144
+ """Load audio from raw bytes, supporting WAV/MP3/MPEG/M4A/OGG/FLAC/WebM.
145
+
146
+ Strategy: try soundfile (fast, native) β†’ librosa (ffmpeg backend) β†’ ffmpeg CLI.
147
+ """
148
+ buf = io.BytesIO(audio_bytes)
149
+
150
+ # 1) Try soundfile (handles WAV, FLAC, OGG natively β€” fastest)
151
+ try:
152
+ audio, file_sr = soundfile_lib.read(buf)
153
+ if audio.ndim > 1:
154
+ audio = audio.mean(axis=1) # stereo β†’ mono
155
+ if file_sr != sr:
156
+ audio = librosa.resample(audio.astype(np.float32), orig_sr=file_sr, target_sr=sr)
157
+ return audio.astype(np.float32)
158
+ except Exception:
159
+ buf.seek(0)
160
+
161
+ # 2) Try librosa (handles MP3 via audioread + ffmpeg backend)
162
+ try:
163
+ audio, _ = librosa.load(buf, sr=sr, mono=True)
164
+ return audio.astype(np.float32)
165
+ except Exception:
166
+ buf.seek(0)
167
+
168
+ # 3) Fallback: use ffmpeg CLI to convert to WAV in memory
169
+ try:
170
+ proc = subprocess.run(
171
+ ["ffmpeg", "-i", "pipe:0", "-f", "wav", "-ac", "1", "-ar", str(sr), "pipe:1"],
172
+ input=audio_bytes, capture_output=True, timeout=30,
173
+ )
174
+ if proc.returncode == 0 and len(proc.stdout) > 44:
175
+ wav_buf = io.BytesIO(proc.stdout)
176
+ audio, _ = soundfile_lib.read(wav_buf)
177
+ return audio.astype(np.float32)
178
+ except Exception:
179
+ pass
180
+
181
+ raise ValueError(
182
+ "Could not decode audio file. Supported formats: "
183
+ "WAV, MP3, MPEG, M4A, OGG, FLAC, WebM, AAC. "
184
+ "Please upload a valid audio file."
185
+ )
186
+
187
+
188
+ # ═══════════════════════════════════════════════════════════════════
189
+ # Main Wrapper
190
+ # ═══════════════════════════════════════════════════════════════════
191
+
192
+ class ChatterboxWrapper:
193
+
194
+ def __init__(self, download_only: bool = False):
195
+ self.cfg = Config
196
+ os.makedirs(self.cfg.MODELS_DIR, exist_ok=True)
197
+
198
+ logger.info(f"Downloading ONNX models (dtype={self.cfg.MODEL_DTYPE}) …")
199
+ self._model_paths = self._download_models()
200
+
201
+ if download_only:
202
+ return
203
+
204
+ logger.info(
205
+ f"Creating ONNX Runtime sessions "
206
+ f"(intra_op_threads={self.cfg.CPU_THREADS}, workers={self.cfg.MAX_WORKERS}) …"
207
+ )
208
+ opts = self._make_session_options()
209
+ providers = ["CPUExecutionProvider"]
210
+
211
+ self.embed_session = ort.InferenceSession(self._model_paths["embed_tokens"], sess_options=opts, providers=providers)
212
+ self.encoder_session = ort.InferenceSession(self._model_paths["speech_encoder"], sess_options=opts, providers=providers)
213
+ self.lm_session = ort.InferenceSession(self._model_paths["language_model"], sess_options=opts, providers=providers)
214
+ self.decoder_session = ort.InferenceSession(self._model_paths["conditional_decoder"], sess_options=opts, providers=providers)
215
+
216
+ logger.info("Loading tokenizer …")
217
+ self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.MODEL_ID)
218
+
219
+ self._voice_cache = _VoiceCache(
220
+ maxsize=self.cfg.VOICE_CACHE_SIZE,
221
+ ttl_seconds=self.cfg.VOICE_CACHE_TTL_SEC,
222
+ )
223
+
224
+ self._builtin_voice_profiles: dict[str, VoiceProfile] = {}
225
+ self._builtin_voice_bytes: dict[str, bytes] = {}
226
+ self._builtin_voice_by_hash: dict[str, VoiceProfile] = {}
227
+ self._voice_alias_to_id: dict[str, str] = {}
228
+ self._builtin_voice_catalog: list[dict] = []
229
+ self._default_voice_id: str = "default"
230
+
231
+ logger.info("Loading built-in voices (HF default + local samples) …")
232
+ self.default_voice = self._load_builtin_voices()
233
+
234
+ logger.info("βœ… ChatterboxWrapper ready")
235
+
236
+ # ─── Model download ──────────────────────────────────────────
237
+
238
+ def _download_models(self) -> dict:
239
+ """Download all 4 ONNX components + weight files from HuggingFace."""
240
+ components = ("conditional_decoder", "speech_encoder", "embed_tokens", "language_model")
241
+ paths = {}
242
+ for name in components:
243
+ paths[name] = self._download_component(name, self.cfg.MODEL_DTYPE)
244
+ return paths
245
+
246
+ def _download_component(self, name: str, dtype: str) -> str:
247
+ if dtype == "fp32":
248
+ filename = f"{name}.onnx"
249
+ elif dtype == "q8":
250
+ filename = f"{name}_quantized.onnx"
251
+ else:
252
+ filename = f"{name}_{dtype}.onnx"
253
+
254
+ graph = hf_hub_download(
255
+ self.cfg.MODEL_ID, subfolder="onnx", filename=filename,
256
+ cache_dir=self.cfg.MODELS_DIR,
257
+ )
258
+ # Download companion weight file
259
+ try:
260
+ hf_hub_download(
261
+ self.cfg.MODEL_ID, subfolder="onnx", filename=f"{filename}_data",
262
+ cache_dir=self.cfg.MODELS_DIR,
263
+ )
264
+ except Exception:
265
+ pass # Some quantized variants embed weights in-graph
266
+ return graph
267
+
268
+ # ─── Session configuration (optimised for 2 vCPU) ─────────────
269
+
270
+ def _make_session_options(self) -> ort.SessionOptions:
271
+ opts = ort.SessionOptions()
272
+ # Sequential execution: no parallel graph scheduling overhead
273
+ opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
274
+ # Match physical cores exactly (2 for HF Space free tier)
275
+ opts.intra_op_num_threads = self.cfg.CPU_THREADS
276
+ opts.inter_op_num_threads = 1
277
+ # Full graph optimisations (constant folding, fusion, etc.)
278
+ opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
279
+ # Disable thread spinning β€” wastes CPU cycles on busy-wait
280
+ opts.add_session_config_entry("session.intra_op.allow_spinning", "0")
281
+ opts.add_session_config_entry("session.inter_op.allow_spinning", "0")
282
+ # Enable memory optimisations
283
+ opts.enable_cpu_mem_arena = True
284
+ opts.enable_mem_pattern = True
285
+ opts.enable_mem_reuse = True
286
+ return opts
287
+
288
+ # ─── Built-in voices (HF default + local samples) ────────────
289
+
290
+ def _download_hf_default_voice_bytes(self) -> bytes:
291
+ path = hf_hub_download(
292
+ self.cfg.DEFAULT_VOICE_REPO,
293
+ filename=self.cfg.DEFAULT_VOICE_FILE,
294
+ cache_dir=self.cfg.MODELS_DIR,
295
+ )
296
+ return Path(path).read_bytes()
297
+
298
+ def _list_local_voice_paths(self) -> list[Path]:
299
+ wrapper_dir = Path(__file__).resolve().parent
300
+
301
+ # Support both module-level and repo-root deployment layouts.
302
+ candidates = []
303
+ for d in (wrapper_dir, Path.cwd().resolve(), wrapper_dir.parent):
304
+ try:
305
+ resolved = d.resolve()
306
+ except Exception:
307
+ continue
308
+ if resolved.is_dir() and resolved not in candidates:
309
+ candidates.append(resolved)
310
+
311
+ voices: list[Path] = []
312
+ seen_real_paths: set[str] = set()
313
+ for root in candidates:
314
+ try:
315
+ entries = sorted(root.iterdir(), key=lambda x: x.name.lower())
316
+ except Exception:
317
+ continue
318
+
319
+ for p in entries:
320
+ if not p.is_file():
321
+ continue
322
+ if p.suffix.lower() not in _SUPPORTED_AUDIO_EXTENSIONS:
323
+ continue
324
+ real_path = str(p.resolve())
325
+ if real_path in seen_real_paths:
326
+ continue
327
+ seen_real_paths.add(real_path)
328
+ voices.append(p)
329
+
330
+ logger.info(
331
+ "Local voice scan complete: %s files across %s",
332
+ len(voices),
333
+ [str(x) for x in candidates],
334
+ )
335
+ return voices
336
+
337
+ def _make_unique_voice_id(self, preferred: str) -> str:
338
+ base = _slugify(preferred)
339
+ candidate = base
340
+ idx = 2
341
+ while candidate in self._builtin_voice_profiles:
342
+ candidate = f"{base}_{idx}"
343
+ idx += 1
344
+ return candidate
345
+
346
+ def _register_builtin_voice(
347
+ self,
348
+ *,
349
+ preferred_id: str,
350
+ display_name: str,
351
+ source: str,
352
+ source_ref: str,
353
+ audio_bytes: bytes,
354
+ is_default: bool = False,
355
+ ) -> str:
356
+ if not audio_bytes:
357
+ raise ValueError("Voice file is empty")
358
+
359
+ voice_id = self._make_unique_voice_id(preferred_id)
360
+ audio_hash = hashlib.md5(audio_bytes).hexdigest()
361
+
362
+ profile = self._voice_cache.get(audio_hash)
363
+ if profile is None:
364
+ audio = _load_audio_bytes(audio_bytes, sr=self.cfg.SAMPLE_RATE)
365
+ profile = self._encode_audio_array(audio, audio_hash=audio_hash)
366
+ self._voice_cache.put(audio_hash, profile)
367
+ else:
368
+ # Keep hash attached to cached profile for metadata/voice-key usage.
369
+ profile.audio_hash = audio_hash
370
+
371
+ self._builtin_voice_profiles[voice_id] = profile
372
+ self._builtin_voice_bytes[voice_id] = audio_bytes
373
+ if audio_hash:
374
+ self._builtin_voice_by_hash[audio_hash] = profile
375
+
376
+ aliases: list[str] = []
377
+ for alias in (voice_id, _slugify(Path(display_name).stem)):
378
+ if alias not in self._voice_alias_to_id:
379
+ self._voice_alias_to_id[alias] = voice_id
380
+ aliases.append(alias)
381
+
382
+ if is_default:
383
+ self._default_voice_id = voice_id
384
+ self._voice_alias_to_id["default"] = voice_id
385
+ if "default" not in aliases:
386
+ aliases.append("default")
387
+
388
+ self._builtin_voice_catalog.append(
389
+ {
390
+ "id": voice_id,
391
+ "display_name": display_name,
392
+ "source": source,
393
+ "source_ref": source_ref,
394
+ "aliases": aliases,
395
+ "voice_key": audio_hash,
396
+ }
397
+ )
398
+ return voice_id
399
+
400
+ def _load_builtin_voices(self) -> VoiceProfile:
401
+ # 1) HF default voice (kept as true default fallback)
402
+ hf_bytes = self._download_hf_default_voice_bytes()
403
+ self._register_builtin_voice(
404
+ preferred_id="default_hf_voice",
405
+ display_name=self.cfg.DEFAULT_VOICE_FILE,
406
+ source="huggingface",
407
+ source_ref=f"{self.cfg.DEFAULT_VOICE_REPO}:{self.cfg.DEFAULT_VOICE_FILE}",
408
+ audio_bytes=hf_bytes,
409
+ is_default=True,
410
+ )
411
+
412
+ # 2) Local voice samples placed next to app files
413
+ for path in self._list_local_voice_paths():
414
+ # Avoid duplicate entry if someone also copied default_voice.wav locally.
415
+ if path.name == self.cfg.DEFAULT_VOICE_FILE:
416
+ continue
417
+ try:
418
+ self._register_builtin_voice(
419
+ preferred_id=path.stem,
420
+ display_name=path.name,
421
+ source="local",
422
+ source_ref=str(path.name),
423
+ audio_bytes=path.read_bytes(),
424
+ is_default=False,
425
+ )
426
+ except Exception as e:
427
+ logger.warning(f"Skipping local voice {path.name}: {e}")
428
+
429
+ default_profile = self._builtin_voice_profiles.get(self._default_voice_id)
430
+ if default_profile is None:
431
+ raise RuntimeError("Default built-in voice could not be initialized")
432
+
433
+ logger.info(
434
+ f"Built-in voices loaded: {len(self._builtin_voice_catalog)} "
435
+ f"(default={self._default_voice_id})"
436
+ )
437
+ return default_profile
438
+
439
+ def list_builtin_voices(self) -> list[dict]:
440
+ """Return metadata for startup-preloaded voices."""
441
+ return [dict(v) for v in self._builtin_voice_catalog]
442
+
443
+ @property
444
+ def default_voice_name(self) -> str:
445
+ return self._default_voice_id
446
+
447
+ def resolve_voice_id(self, voice_name: Optional[str]) -> str:
448
+ if voice_name is None:
449
+ return self._default_voice_id
450
+ key = _slugify(str(voice_name))
451
+ if not key:
452
+ return self._default_voice_id
453
+ voice_id = self._voice_alias_to_id.get(key)
454
+ if voice_id is None:
455
+ available = ", ".join(sorted(self._voice_alias_to_id.keys()))
456
+ raise ValueError(f"Unknown voice '{voice_name}'. Available: {available}")
457
+ return voice_id
458
+
459
+ def get_builtin_voice(self, voice_name: Optional[str]) -> VoiceProfile:
460
+ voice_id = self.resolve_voice_id(voice_name)
461
+ profile = self._builtin_voice_profiles[voice_id]
462
+ if profile.audio_hash:
463
+ self._voice_cache.put(profile.audio_hash, profile)
464
+ return profile
465
+
466
+ def get_builtin_voice_bytes(self, voice_name: Optional[str]) -> Optional[bytes]:
467
+ voice_id = self.resolve_voice_id(voice_name)
468
+ return self._builtin_voice_bytes.get(voice_id)
469
+
470
+ def get_builtin_voice_by_hash(self, audio_hash: str) -> Optional[VoiceProfile]:
471
+ return self._builtin_voice_by_hash.get((audio_hash or "").strip())
472
+
473
+ # ─── Voice encoding ──────────────────────────────────────────
474
+
475
+ def encode_voice_from_bytes(self, audio_bytes: bytes) -> VoiceProfile:
476
+ """Encode reference audio from raw bytes (in-memory, no disk write).
477
+
478
+ Accepts: WAV, MP3, MPEG, M4A, OGG, FLAC, WebM, AAC, WMA, Opus.
479
+ """
480
+ audio_hash = hashlib.md5(audio_bytes).hexdigest()
481
+ cached = self._voice_cache.get(audio_hash)
482
+ if cached is not None:
483
+ logger.info(f"Voice cache hit: {audio_hash[:8]}…")
484
+ return cached
485
+
486
+ # Robust multi-format audio loading
487
+ audio = _load_audio_bytes(audio_bytes, sr=self.cfg.SAMPLE_RATE)
488
+
489
+ # Validate duration
490
+ duration = len(audio) / self.cfg.SAMPLE_RATE
491
+ if duration < self.cfg.MIN_REF_DURATION_SEC:
492
+ raise ValueError(
493
+ f"Reference audio too short ({duration:.1f}s). "
494
+ f"Minimum: {self.cfg.MIN_REF_DURATION_SEC}s"
495
+ )
496
+ if duration > self.cfg.MAX_REF_DURATION_SEC:
497
+ audio = audio[: int(self.cfg.MAX_REF_DURATION_SEC * self.cfg.SAMPLE_RATE)]
498
+
499
+ profile = self._encode_audio_array(audio, audio_hash=audio_hash)
500
+ self._voice_cache.put(audio_hash, profile)
501
+ return profile
502
+
503
+ def _encode_audio_array(self, audio: np.ndarray, audio_hash: str = "") -> VoiceProfile:
504
+ """Run speech_encoder on a float32 mono audio array."""
505
+ audio_input = audio[np.newaxis, :].astype(np.float32)
506
+ cond_emb, prompt_token, speaker_emb, speaker_feat = self.encoder_session.run(
507
+ None, {"audio_values": audio_input}
508
+ )
509
+ return VoiceProfile(
510
+ cond_emb=cond_emb,
511
+ prompt_token=prompt_token,
512
+ speaker_embeddings=speaker_emb,
513
+ speaker_features=speaker_feat,
514
+ audio_hash=audio_hash,
515
+ )
516
+
517
+ # ─── Full generation (non-streaming) ──────────────────────────
518
+
519
+ def generate_speech(
520
+ self,
521
+ text: str,
522
+ voice: Optional[VoiceProfile] = None,
523
+ max_new_tokens: Optional[int] = None,
524
+ repetition_penalty: Optional[float] = None,
525
+ ) -> np.ndarray:
526
+ """Generate complete audio for the given text."""
527
+ voice = voice or self.default_voice
528
+ text = text_processor.sanitize(text.strip()[: self.cfg.MAX_TEXT_LENGTH])
529
+ if not text:
530
+ raise ValueError("Text is empty after sanitization")
531
+
532
+ tokens = self._generate_tokens(
533
+ text, voice,
534
+ max_new_tokens or self.cfg.MAX_NEW_TOKENS,
535
+ repetition_penalty or self.cfg.REPETITION_PENALTY,
536
+ )
537
+ return self._decode_tokens(tokens, voice)
538
+
539
+ # ─── Streaming generation ─────────────────────────────────────
540
+
541
+ def stream_speech(
542
+ self,
543
+ text: str,
544
+ voice: Optional[VoiceProfile] = None,
545
+ max_new_tokens: Optional[int] = None,
546
+ repetition_penalty: Optional[float] = None,
547
+ is_cancelled: Optional[Callable[[], bool]] = None,
548
+ ) -> Generator[np.ndarray, None, None]:
549
+ """Yield audio chunks sentence-by-sentence for real-time streaming.
550
+
551
+ Each sentence is independently processed through the full pipeline
552
+ so the first chunk arrives as fast as possible (low TTFB).
553
+
554
+ Args:
555
+ is_cancelled: Optional callable that returns True to abort generation.
556
+ Checked between chunks and every 25 autoregressive steps.
557
+ """
558
+ voice = voice or self.default_voice
559
+ text = text_processor.sanitize(text.strip()[: self.cfg.MAX_TEXT_LENGTH])
560
+ if not text:
561
+ return
562
+
563
+ sentences = text_processor.split_for_streaming(text)
564
+ _max = max_new_tokens or self.cfg.MAX_NEW_TOKENS
565
+ _rep = repetition_penalty or self.cfg.REPETITION_PENALTY
566
+ _check = is_cancelled or (lambda: False)
567
+
568
+ for i, sentence in enumerate(sentences):
569
+ # Check cancellation between chunks
570
+ if _check():
571
+ logger.info("Generation cancelled by client (between chunks)")
572
+ return
573
+ if not sentence.strip():
574
+ continue
575
+ t0 = time.perf_counter()
576
+ try:
577
+ tokens = self._generate_tokens(sentence, voice, _max, _rep, _check)
578
+ if _check():
579
+ return
580
+ audio = self._decode_tokens(tokens, voice)
581
+ elapsed = time.perf_counter() - t0
582
+ audio_duration = len(audio) / self.cfg.SAMPLE_RATE
583
+ rtf = elapsed / audio_duration if audio_duration > 0 else 0
584
+ logger.info(
585
+ f"Chunk {i + 1}/{len(sentences)}: "
586
+ f"{len(sentence)} chars β†’ {audio_duration:.1f}s audio "
587
+ f"in {elapsed:.2f}s (RTF: {rtf:.2f}x)"
588
+ )
589
+ yield audio
590
+ except GenerationCancelled:
591
+ logger.info(f"Generation cancelled mid-token at chunk {i + 1}")
592
+ return
593
+ except Exception as e:
594
+ logger.error(f"Error on chunk {i + 1}: {e}")
595
+ raise
596
+
597
+ # ─── Autoregressive token generation (OPTIMISED) ──────────────
598
+
599
+ def _generate_tokens(
600
+ self,
601
+ text: str,
602
+ voice: VoiceProfile,
603
+ max_new_tokens: int,
604
+ repetition_penalty: float,
605
+ is_cancelled: Callable[[], bool] = lambda: False,
606
+ ) -> np.ndarray:
607
+ """Run embed β†’ LM autoregressive loop. Returns raw token array.
608
+
609
+ Optimisations:
610
+ β€’ Token list instead of repeated np.concatenate (O(n) β†’ O(1) append)
611
+ β€’ Unique tokens set for inline repetition penalty (avoids exponential penalty bug)
612
+ β€’ Pre-allocated attention mask for zero-copy slicing
613
+ β€’ Correct dimensional routing for step 0 prompt processing
614
+ """
615
+ input_ids = self.tokenizer(text, return_tensors="np")["input_ids"].astype(np.int64)
616
+
617
+ # Pre-allocate collections
618
+ token_list: list[int] = [self.cfg.START_SPEECH_TOKEN]
619
+ unique_tokens: set[int] = {self.cfg.START_SPEECH_TOKEN}
620
+ penalty = repetition_penalty
621
+
622
+ past_key_values = None
623
+ attention_mask_full = None
624
+ seq_len = 0
625
+
626
+ for step in range(max_new_tokens):
627
+ if step > 0 and step % 25 == 0 and is_cancelled():
628
+ raise GenerationCancelled()
629
+
630
+ embeds = self.embed_session.run(None, {"input_ids": input_ids})[0]
631
+
632
+ if step == 0:
633
+ # Prepend speaker conditioning
634
+ embeds = np.concatenate((voice.cond_emb, embeds), axis=1)
635
+ batch, seq_len, _ = embeds.shape
636
+
637
+ past_key_values = {
638
+ inp.name: np.zeros(
639
+ [batch, self.cfg.NUM_KV_HEADS, 0, self.cfg.HEAD_DIM],
640
+ dtype=np.float16 if inp.type == "tensor(float16)" else np.float32,
641
+ )
642
+ for inp in self.lm_session.get_inputs()
643
+ if "past_key_values" in inp.name
644
+ }
645
+
646
+ # Pre-allocate full attention mask
647
+ attention_mask_full = np.ones((batch, seq_len + max_new_tokens), dtype=np.int64)
648
+ attention_mask = attention_mask_full[:, :seq_len]
649
+
650
+ # Step 0 requires position_ids matching prompt sequence length
651
+ position_ids = np.arange(seq_len, dtype=np.int64).reshape(batch, -1)
652
+ else:
653
+ # O(1) zero-copy slice for subsequent steps
654
+ attention_mask = attention_mask_full[:, : seq_len + step]
655
+ # Single position ID for the single new token
656
+ position_ids = np.array([[seq_len + step - 1]], dtype=np.int64)
657
+
658
+ # Language model forward pass
659
+ logits, *present_kv = self.lm_session.run(
660
+ None,
661
+ dict(
662
+ inputs_embeds=embeds,
663
+ attention_mask=attention_mask,
664
+ position_ids=position_ids,
665
+ **past_key_values,
666
+ ),
667
+ )
668
+
669
+ # ── Inline repetition penalty + token selection ───────
670
+ last_logits = logits[0, -1, :].copy() # shape: (vocab_size,)
671
+
672
+ # Apply repetition penalty strictly to unique tokens to prevent over-penalization
673
+ for tok_id in unique_tokens:
674
+ if last_logits[tok_id] < 0:
675
+ last_logits[tok_id] *= penalty
676
+ else:
677
+ last_logits[tok_id] /= penalty
678
+
679
+ next_token = int(np.argmax(last_logits))
680
+ token_list.append(next_token)
681
+ unique_tokens.add(next_token)
682
+
683
+ if next_token == self.cfg.STOP_SPEECH_TOKEN:
684
+ break
685
+
686
+ # Update state for next step
687
+ input_ids = np.array([[next_token]], dtype=np.int64)
688
+ for j, key in enumerate(past_key_values):
689
+ past_key_values[key] = present_kv[j]
690
+
691
+ return np.array([token_list], dtype=np.int64)
692
+
693
+ # ─── Token β†’ audio decoding ───────────────────────────────────
694
+
695
+ def _decode_tokens(self, generated: np.ndarray, voice: VoiceProfile) -> np.ndarray:
696
+ """Decode speech tokens to a float32 waveform at 24 kHz."""
697
+ # Strip START token; strip STOP token if present
698
+ tokens = generated[:, 1:]
699
+ if tokens.shape[1] > 0 and tokens[0, -1] == self.cfg.STOP_SPEECH_TOKEN:
700
+ tokens = tokens[:, :-1]
701
+
702
+ if tokens.shape[1] == 0:
703
+ return np.zeros(0, dtype=np.float32)
704
+
705
+ # Prepend prompt token + append silence
706
+ silence = np.full(
707
+ (tokens.shape[0], 3), self.cfg.SILENCE_TOKEN, dtype=np.int64
708
+ )
709
+ full_tokens = np.concatenate(
710
+ [voice.prompt_token, tokens, silence], axis=1
711
+ )
712
+
713
+ wav = self.decoder_session.run(
714
+ None,
715
+ {
716
+ "speech_tokens": full_tokens,
717
+ "speaker_embeddings": voice.speaker_embeddings,
718
+ "speaker_features": voice.speaker_features,
719
+ },
720
+ )[0].squeeze(axis=0)
721
+
722
+ return wav
723
+
724
+ # ─── Warmup ───────────────────────────────────────────────────
725
+
726
+ def warmup(self):
727
+ """Run a short inference to warm up ONNX sessions and JIT paths."""
728
+ try:
729
+ t0 = time.perf_counter()
730
+ _ = self.generate_speech("Hello.", self.default_voice, max_new_tokens=32)
731
+ logger.info(f"Warmup done in {time.perf_counter() - t0:.2f}s")
732
+ except Exception as e:
733
+ logger.warning(f"Warmup failed (non-critical): {e}")
config.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chatterbox Turbo TTS β€” Centralized Configuration
3
+ ═══════════════════════════════════════════════════
4
+ Optimised for HF Space free tier (2 vCPU).
5
+ Adjust MODEL_DTYPE to switch quantization (q8/q4/fp16/fp32).
6
+ All settings overridable via environment variables prefixed CB_.
7
+ """
8
+ import os
9
+
10
+ _HERE = os.path.dirname(os.path.abspath(__file__))
11
+
12
+
13
+ def _get_bool(name: str, default: bool) -> bool:
14
+ raw = os.getenv(name)
15
+ if raw is None:
16
+ return default
17
+ return raw.strip().lower() in {"1", "true", "yes", "on"}
18
+
19
+
20
+ class Config:
21
+ # ── Model ────────────────────────────────────────────────────
22
+ MODEL_ID: str = os.getenv("CB_MODEL_ID", "ResembleAI/chatterbox-turbo-ONNX")
23
+
24
+ # fp32 β†’ highest quality, ~1.4 GB, slowest
25
+ # fp16 β†’ good quality, ~0.7 GB
26
+ # q8 β†’ β˜… recommended, ~0.35 GB, best balance
27
+ # q4 β†’ smallest, ~0.17 GB, fastest, slight loss
28
+ # q4f16 β†’ q4 weights + fp16 activations
29
+ MODEL_DTYPE: str = os.getenv("CB_MODEL_DTYPE", "q4")
30
+
31
+ MODELS_DIR: str = os.getenv("CB_MODELS_DIR", os.path.join(_HERE, "models"))
32
+
33
+ # ── ONNX Runtime CPU tuning (optimised for 2 vCPU) ───────────
34
+ #
35
+ # KEY RULE: intra_op threads MUST match physical cores.
36
+ # β†’ 4 threads on 2 cores = oversubscription = SLOWER.
37
+ # β†’ 2 threads on 2 cores = each op uses both cores perfectly.
38
+ #
39
+ # MAX_WORKERS = 1 ensures ONE inference gets both cores.
40
+ # β†’ 2 workers would split 2 cores = both requests slow.
41
+ #
42
+ CPU_THREADS: int = int(os.getenv("CB_CPU_THREADS", "2"))
43
+ MAX_WORKERS: int = int(os.getenv("CB_MAX_WORKERS", "1"))
44
+
45
+ # ── Generation defaults ──────────────────────────────────────
46
+ SAMPLE_RATE: int = 24000
47
+ MAX_NEW_TOKENS: int = int(os.getenv("CB_MAX_NEW_TOKENS", "768"))
48
+ REPETITION_PENALTY: float = float(os.getenv("CB_REPETITION_PENALTY", "1.2"))
49
+ MAX_TEXT_LENGTH: int = int(os.getenv("CB_MAX_TEXT_LENGTH", "50000"))
50
+
51
+ # ── Model constants (official card β€” do not change) ──────────
52
+ START_SPEECH_TOKEN: int = 6561
53
+ STOP_SPEECH_TOKEN: int = 6562
54
+ SILENCE_TOKEN: int = 4299
55
+ NUM_KV_HEADS: int = 16
56
+ HEAD_DIM: int = 64
57
+
58
+ # ── Paralinguistic tags (Turbo native) ───────────────────────
59
+ PARALINGUISTIC_TAGS: tuple = (
60
+ "laugh", "chuckle", "cough", "sigh", "gasp",
61
+ "shush", "groan", "sniff", "clear throat",
62
+ )
63
+
64
+ # ── Voice / reference audio ──────────────────────────────────
65
+ # NOTE: Official ResembleAI/chatterbox-turbo-ONNX has no bundled voice.
66
+ # The default_voice.wav is a plain audio sample from community repo
67
+ # (not a model β€” just a reference WAV, safe to use from any source).
68
+ DEFAULT_VOICE_REPO: str = "onnx-community/chatterbox-ONNX"
69
+ DEFAULT_VOICE_FILE: str = "default_voice.wav"
70
+ MAX_VOICE_UPLOAD_BYTES: int = 10 * 1024 * 1024 # 10 MB
71
+ MIN_REF_DURATION_SEC: float = 1.5
72
+ MAX_REF_DURATION_SEC: float = 30.0
73
+ VOICE_CACHE_SIZE: int = int(os.getenv("CB_VOICE_CACHE_SIZE", "20"))
74
+ VOICE_CACHE_TTL_SEC: int = int(os.getenv("CB_VOICE_CACHE_TTL", "3600")) # 1 hour
75
+
76
+ # ── Streaming ────────────────────────────────────────────────
77
+ # Smaller chunks = faster TTFB (first audio arrives sooner)
78
+ # ~200 chars β‰ˆ 1–2 sentences β‰ˆ fastest first-chunk on 2 vCPU
79
+ MAX_CHUNK_CHARS: int = int(os.getenv("CB_MAX_CHUNK_CHARS", "100"))
80
+ # Additive parallel mode (odd/even split across primary/helper).
81
+ ENABLE_PARALLEL_MODE: bool = _get_bool("CB_ENABLE_PARALLEL_MODE", True)
82
+ HELPER_BASE_URL: str = os.getenv("CB_HELPER_BASE_URL", "https://shadowhunter222-chab2.hf.space").strip()
83
+ HELPER_TIMEOUT_SEC: float = float(os.getenv("CB_HELPER_TIMEOUT_SEC", "45"))
84
+ HELPER_RETRY_ONCE: bool = _get_bool("CB_HELPER_RETRY_ONCE", True)
85
+ # Optional shared secret for internal chunk endpoints.
86
+ INTERNAL_SHARED_SECRET: str = os.getenv("CB_INTERNAL_SHARED_SECRET", "").strip()
87
+
88
+ # ── Server ───────────────────────────────────────────────────
89
+ HOST: str = os.getenv("CB_HOST", "0.0.0.0")
90
+ PORT: int = int(os.getenv("CB_PORT", "7860"))
91
+
92
+ ALLOWED_ORIGINS: list = [
93
+ "https://toolboxesai.com",
94
+ "www.toolboxesai.com",
95
+ "https://www.toolboxesai.com",
96
+ "http://localhost:8788", "http://127.0.0.1:8788",
97
+ "http://localhost:5502", "http://127.0.0.1:5502",
98
+ "http://localhost:5501", "http://127.0.0.1:5501",
99
+ "http://localhost:5500", "http://127.0.0.1:5500",
100
+ "http://localhost:5173", "http://127.0.0.1:5173",
101
+ "http://localhost:7860", "http://127.0.0.1:7860",
102
+ ]
her_prompt.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8eaabbeafe26ad6f78b56dcc32608763eeb69485db074c7136c6818f04a93ced
3
+ size 725328
ivr_female_02_prompt.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64953bf94657c4334532319fd4f20e9859c31af4445940916b04f129ef1f89e6
3
+ size 2779278
requirements.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =========================================================
2
+ # Chatterbox Turbo TTS - Dependencies (CPU-only)
3
+ # =========================================================
4
+
5
+ # PyTorch CPU (required by transformers tokenizer internals)
6
+ torch --index-url https://download.pytorch.org/whl/cpu
7
+
8
+ # Core API
9
+ fastapi>=0.104.1
10
+ uvicorn[standard]>=0.24.0
11
+ pydantic>=2.5.0
12
+ python-multipart>=0.0.6
13
+
14
+ # ONNX Runtime (CPU inference)
15
+ onnxruntime>=1.17.0
16
+
17
+ # Audio processing
18
+ numpy>=1.24.0
19
+ librosa>=0.10.0
20
+ soundfile>=0.12.0
21
+
22
+ # Tokenizer + model download
23
+ transformers>=4.46.0
24
+ huggingface-hub>=0.19.0
text_processor.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chatterbox Turbo TTS β€” Text Processor
3
+ ═══════════════════════════════════════
4
+ Sanitizes raw input text and splits it into sentence-level chunks
5
+ for streaming TTS. Paralinguistic tags ([laugh], [cough], …) are
6
+ explicitly preserved so the model can render them.
7
+
8
+ Punctuation Philosophy (based on Resemble AI recommendations):
9
+ βœ… PRESERVE (benefits prosody):
10
+ β€’ Ellipsis ... β†’ dramatic pause, trailing thought, hesitation
11
+ β€’ Em dash β€” β†’ abrupt transition, dramatic break
12
+ β€’ Comma , β†’ short natural pause / breathing point
13
+ β€’ Period . β†’ full stop, pitch drop, sentence boundary
14
+ β€’ ! and ? β†’ exclamatory / interrogative inflection
15
+ β€’ Semicolon ; β†’ medium pause, clause bridge (NOT a split point)
16
+ β€’ Colon : β†’ medium pause, introduces explanation (NOT a split point)
17
+ β€’ Parentheses () β†’ quieter/explanatory tone shift
18
+ β€’ Quotes "" β†’ dialogue cue
19
+ β€’ Apostrophe ' β†’ contractions (don't, it's)
20
+ β€’ CAPS words β†’ emphasis / volume increase
21
+
22
+ ❌ FILTER (harms output):
23
+ β€’ Excessive repeated punctuation (!!!! β†’ !, ???? β†’ ?, ,,, β†’ ,)
24
+ β€’ 4+ dots (.... β†’ ...)
25
+ β€’ Emojis, URLs, markdown, HTML tags
26
+ β€’ Non-standard Unicode punctuation (guillemets, etc.)
27
+ """
28
+ import re
29
+ from typing import List
30
+
31
+ from config import Config
32
+
33
+ # ═══════════════════════════════════════════════════════════════════
34
+ # Pre-compiled regex patterns (compiled once at import β†’ zero cost)
35
+ # ═══════════════════════════════════════════════════════════════════
36
+
37
+ # β€” Paralinguistic tag protector (matches [laugh], [clear throat], etc.)
38
+ _TAG_NAMES = "|".join(re.escape(t) for t in Config.PARALINGUISTIC_TAGS)
39
+ _RE_PARA_TAG = re.compile(rf"\[(?:{_TAG_NAMES})\]", re.IGNORECASE)
40
+
41
+ # β€” Markdown / structural noise
42
+ _RE_CODE_BLOCK = re.compile(r"```[\s\S]*?```")
43
+ _RE_INLINE_CODE = re.compile(r"`([^`]+)`")
44
+ _RE_IMAGE = re.compile(r"!\[([^\]]*)\]\([^)]+\)")
45
+ _RE_LINK = re.compile(r"\[([^\]]+)\]\([^)]+\)")
46
+ _RE_BOLD_AST = re.compile(r"\*\*(.+?)\*\*")
47
+ _RE_BOLD_UND = re.compile(r"__(.+?)__")
48
+ _RE_STRIKE = re.compile(r"~~(.+?)~~")
49
+ _RE_ITALIC_AST = re.compile(r"\*(.+?)\*")
50
+ _RE_ITALIC_UND = re.compile(r"(?<!\w)_(.+?)_(?!\w)")
51
+ _RE_HEADER = re.compile(r"^#{1,6}\s+", re.MULTILINE)
52
+ _RE_BLOCKQUOTE = re.compile(r"^>+\s?", re.MULTILINE)
53
+ _RE_HR = re.compile(r"^[-*_]{3,}$", re.MULTILINE)
54
+ _RE_BULLET = re.compile(r"^\s*[-*+]\s+", re.MULTILINE)
55
+ _RE_ORDERED = re.compile(r"^\s*\d+\.\s+", re.MULTILINE)
56
+
57
+ # β€” URLs, emojis, HTML entities
58
+ _RE_URL = re.compile(r"https?://\S+")
59
+ _RE_EMOJI = re.compile(
60
+ r"["
61
+ r"\U0001F600-\U0001F64F\U0001F300-\U0001F5FF"
62
+ r"\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF"
63
+ r"\U00002702-\U000027B0\U0001F900-\U0001F9FF"
64
+ r"\U0001FA00-\U0001FA6F\U0001FA70-\U0001FAFF"
65
+ r"\U00002600-\U000026FF\U0000FE00-\U0000FE0F"
66
+ r"\U0000200D"
67
+ r"]+", re.UNICODE,
68
+ )
69
+ _RE_HTML_ENTITY = re.compile(r"&(?:#x?[\da-fA-F]+|\w+);")
70
+
71
+ # HTML entities β†’ speakable replacements
72
+ # NOTE: &hellip; β†’ "..." (preserves dramatic pause), &mdash;/&ndash; β†’ "β€”" (preserves dramatic break)
73
+ _HTML_ENTITIES = {
74
+ "&amp;": " and ", "&lt;": " less than ", "&gt;": " greater than ",
75
+ "&nbsp;": " ", "&quot;": '"', "&apos;": "'",
76
+ "&mdash;": "β€”", "&ndash;": "β€”", "&hellip;": "...",
77
+ }
78
+
79
+ # β€” Smart/curly quote normalization β†’ ASCII equivalents
80
+ # These Unicode variants may confuse the tokenizer; normalizing ensures clean input.
81
+ _SMART_QUOTE_MAP = str.maketrans({
82
+ "\u201c": '"', # " left double quotation mark
83
+ "\u201d": '"', # " right double quotation mark
84
+ "\u2018": "'", # ' left single quotation mark
85
+ "\u2019": "'", # ' right single quotation mark
86
+ "\u00ab": '"', # Β« left guillemet
87
+ "\u00bb": '"', # Β» right guillemet
88
+ "\u201e": '"', # β€ž double low quotation mark
89
+ "\u201f": '"', # β€Ÿ double high reversed quotation mark
90
+ "\u2032": "'", # β€² prime
91
+ "\u2033": '"', # β€³ double prime
92
+ "\u2013": "β€”", # – en dash β†’ em dash (dramatic pause)
93
+ "\u2014": "β€”", # β€” em dash (keep as-is after mapping)
94
+ "\u2026": "...", # … horizontal ellipsis β†’ three dots
95
+ })
96
+
97
+ # β€” ALL CAPS normalization
98
+ # Words entirely in caps (length >= 4) often get spelled out by the TTS engine (e.g. NOTHING).
99
+ # By converting them to Title Case, they'll be processed naturally as words.
100
+ _RE_ALL_CAPS = re.compile(r"\b[A-Z]{4,}\b")
101
+
102
+ # β€” Punctuation normalization
103
+ # Ellipsis (... / ..) is PRESERVED β€” it creates dramatic pauses in Chatterbox.
104
+ # Only 4+ dots are excessive and get capped to standard ellipsis.
105
+ _RE_EXCESSIVE_DOTS = re.compile(r"\.{4,}") # ....+ β†’ ... (cap excessive)
106
+ _RE_NORMALIZE_DOTS = re.compile(r"\.{2,3}") # .. or ... β†’ ... (standardize)
107
+ _RE_REPEATED_EXCLAM = re.compile(r"!{2,}") # !! β†’ !
108
+ _RE_REPEATED_QUEST = re.compile(r"\?{2,}") # ?? β†’ ?
109
+ _RE_REPEATED_SEMI = re.compile(r";{2,}") # ;; β†’ ;
110
+ _RE_REPEATED_COLON = re.compile(r":{2,}") # :: β†’ :
111
+ _RE_REPEATED_COMMA = re.compile(r",{2,}") # ,, β†’ ,
112
+ _RE_REPEATED_DASH = re.compile(r"-{3,}") # --- β†’ β€” (em dash)
113
+
114
+ # β€” Abbreviation protection
115
+ # Common abbreviations ending in "." that should NOT trigger sentence splitting.
116
+ # These get a placeholder before splitting, then get restored.
117
+ _ABBREVIATIONS = (
118
+ "Mr", "Mrs", "Ms", "Dr", "Prof", "Sr", "Jr", "St", "Ave", "Blvd",
119
+ "vs", "etc", "approx", "dept", "est", "govt", "inc", "corp", "ltd",
120
+ "Jan", "Feb", "Mar", "Apr", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
121
+ "Gen", "Col", "Sgt", "Capt", "Lt", "Cmdr", "Adm",
122
+ "Fig", "Vol", "No", "Ref", "Rev", "Ph",
123
+ )
124
+ _RE_ABBREV = re.compile(
125
+ r"\b(" + "|".join(re.escape(a) for a in _ABBREVIATIONS) + r")\.",
126
+ re.IGNORECASE,
127
+ )
128
+
129
+ # β€” Whitespace
130
+ _RE_MULTI_SPACE = re.compile(r"[ \t]+")
131
+ _RE_MULTI_NEWLINE = re.compile(r"\n{3,}")
132
+ _RE_SPACE_BEFORE_PUN = re.compile(r"\s+([.!?,;:])")
133
+
134
+ # β€” Sentence boundary (split point)
135
+ # Split ONLY on true sentence-ending punctuation: . ! ?
136
+ # Semicolons and colons are clause connectors β€” they bridge related thoughts
137
+ # and should NOT be used as split points (creates choppy, unnatural fragments).
138
+ # Ellipsis (...) is also intentionally excluded from splitting: letting it split the stream
139
+ # creates a compound lag between chunks, making the pause artificially excessive.
140
+ _RE_SENTENCE_SPLIT = re.compile(
141
+ r"""(?:(?<=[.!?])(?<!\.\.\.)|(?<=[.!?][)\]"'])(?<!\.\.\.\.))\s+"""
142
+ )
143
+
144
+ _MIN_MERGE_WORDS = 5
145
+
146
+
147
+ # ═══════════════════════════════════════════════════════════════════
148
+ # Public API
149
+ # ═══════════════════════════════════════════════════════════════════
150
+
151
+ def sanitize(text: str) -> str:
152
+ """Clean raw input for TTS while preserving prosody-beneficial punctuation.
153
+
154
+ Preserves: ellipsis (...), em dashes (β€”), commas, periods, !, ?, ;, :, quotes.
155
+ Removes: emojis, URLs, markdown, HTML, excessive repeated punctuation.
156
+ """
157
+ if not text:
158
+ return text
159
+
160
+ # 0. Normalize smart/curly quotes and Unicode punctuation FIRST
161
+ # This ensures downstream regex works on clean ASCII-like punctuation.
162
+ text = text.translate(_SMART_QUOTE_MAP)
163
+
164
+ # 1. Normalize ALL CAPS words to Title Case to prevent spelling out
165
+ text = _RE_ALL_CAPS.sub(lambda m: m.group(0).capitalize(), text)
166
+
167
+ # 2. Protect paralinguistic tags by replacing with placeholders
168
+ tags_found: list[tuple[int, str]] = []
169
+ def _protect_tag(m):
170
+ idx = len(tags_found)
171
+ tags_found.append((idx, m.group(0)))
172
+ return f"Β§TAG{idx}Β§"
173
+ text = _RE_PARA_TAG.sub(_protect_tag, text)
174
+
175
+ # 3. Protect abbreviations from sentence-boundary splitting
176
+ # "Dr. Smith" β†’ "DrΒ§ Smith" (restored later)
177
+ abbrevs_found: list[tuple[int, str]] = []
178
+ def _protect_abbrev(m):
179
+ idx = len(abbrevs_found)
180
+ abbrevs_found.append((idx, m.group(0)))
181
+ return f"{m.group(1)}Β§ABR{idx}Β§"
182
+ text = _RE_ABBREV.sub(_protect_abbrev, text)
183
+
184
+ # 4. Strip non-speakable structures
185
+ text = _RE_URL.sub("", text)
186
+ text = _RE_CODE_BLOCK.sub("", text)
187
+ text = _RE_IMAGE.sub(lambda m: m.group(1) if m.group(1) else "", text)
188
+ text = _RE_LINK.sub(r"\1", text)
189
+ text = _RE_BOLD_AST.sub(r"\1", text)
190
+ text = _RE_BOLD_UND.sub(r"\1", text)
191
+ text = _RE_STRIKE.sub(r"\1", text)
192
+ text = _RE_ITALIC_AST.sub(r"\1", text)
193
+ text = _RE_ITALIC_UND.sub(r"\1", text)
194
+ text = _RE_INLINE_CODE.sub(r"\1", text)
195
+ text = _RE_HEADER.sub("", text)
196
+ text = _RE_BLOCKQUOTE.sub("", text)
197
+ text = _RE_HR.sub("", text)
198
+ text = _RE_BULLET.sub("", text)
199
+ text = _RE_ORDERED.sub("", text)
200
+
201
+ # 5. Emojis, hashtags
202
+ text = _RE_EMOJI.sub("", text)
203
+ text = re.sub(r"#(\w+)", r"\1", text)
204
+
205
+ # 6. HTML entities β†’ speakable text
206
+ text = _RE_HTML_ENTITY.sub(lambda m: _HTML_ENTITIES.get(m.group(0), ""), text)
207
+
208
+ # 7. Normalize punctuation (PRESERVE prosody-beneficial marks)
209
+ # Order matters: handle excessive dots first, then standardize ellipsis.
210
+ text = _RE_EXCESSIVE_DOTS.sub("...", text) # ....+ β†’ ... (cap)
211
+ text = _RE_NORMALIZE_DOTS.sub("...", text) # .. or ... β†’ ... (standardize)
212
+ text = _RE_REPEATED_EXCLAM.sub("!", text) # !! β†’ !
213
+ text = _RE_REPEATED_QUEST.sub("?", text) # ?? β†’ ?
214
+ text = _RE_REPEATED_SEMI.sub(";", text) # ;; β†’ ;
215
+ text = _RE_REPEATED_COLON.sub(":", text) # :: β†’ :
216
+ text = _RE_REPEATED_COMMA.sub(",", text) # ,, β†’ ,
217
+ text = _RE_REPEATED_DASH.sub("β€”", text) # --- β†’ em dash
218
+
219
+ # 8. Whitespace cleanup
220
+ text = _RE_SPACE_BEFORE_PUN.sub(r"\1", text)
221
+ text = _RE_MULTI_SPACE.sub(" ", text)
222
+ text = _RE_MULTI_NEWLINE.sub("\n\n", text)
223
+ text = text.strip()
224
+
225
+ # 9. Strip abbreviation dots (Mr. β†’ Mr, Dr. β†’ Dr, etc.)
226
+ # The dot is not needed for correct TTS pronunciation and removing it
227
+ # prevents false sentence-boundary splits in split_for_streaming().
228
+ for idx, original in abbrevs_found:
229
+ text = text.replace(f"Β§ABR{idx}Β§", "")
230
+
231
+ # 10. Restore paralinguistic tags
232
+ for idx, original in tags_found:
233
+ text = text.replace(f"Β§TAG{idx}Β§", original)
234
+
235
+ return text
236
+
237
+
238
+ def split_for_streaming(text: str, max_chars: int = Config.MAX_CHUNK_CHARS) -> List[str]:
239
+ """Split sanitized text into sentence-level chunks for streaming.
240
+
241
+ Strategy:
242
+ 1. Split on sentence-ending punctuation boundaries (. ! ?)
243
+ β€” NOT on semicolons, colons, or ellipsis (those are non-breaking boundaries)
244
+ 2. Enforce max_chars per chunk (split long sentences on commas / spaces)
245
+ 3. Merge short chunks (≀5 words) with the next to avoid tiny segments
246
+ """
247
+ if not text:
248
+ return []
249
+
250
+ # Step 1: sentence split
251
+ raw_chunks = _RE_SENTENCE_SPLIT.split(text)
252
+ raw_chunks = [c.strip() for c in raw_chunks if c.strip()]
253
+
254
+ # Step 2: enforce max length per chunk
255
+ sized: List[str] = []
256
+ for chunk in raw_chunks:
257
+ if len(chunk) <= max_chars:
258
+ sized.append(chunk)
259
+ else:
260
+ sized.extend(_break_long_chunk(chunk, max_chars))
261
+
262
+ # Step 3: merge short chunks
263
+ if len(sized) <= 1:
264
+ return sized
265
+
266
+ merged: List[str] = []
267
+ carry = ""
268
+ for i, chunk in enumerate(sized):
269
+ if carry:
270
+ chunk = carry + " " + chunk
271
+ carry = ""
272
+ if len(chunk.split()) <= _MIN_MERGE_WORDS and i < len(sized) - 1:
273
+ carry = chunk
274
+ else:
275
+ merged.append(chunk)
276
+ if carry:
277
+ if merged:
278
+ merged[-1] += " " + carry
279
+ else:
280
+ merged.append(carry)
281
+
282
+ return merged
283
+
284
+
285
+ # ═══════════════════════════════════════════════════════════════════
286
+ # Internal helpers
287
+ # ═══════════════════════════════════════════════════════════════════
288
+
289
+ def _break_long_chunk(text: str, max_chars: int) -> List[str]:
290
+ """Break a chunk longer than max_chars on commas or word boundaries."""
291
+ parts: List[str] = []
292
+ remaining = text
293
+ while len(remaining) > max_chars:
294
+ break_pos = -1
295
+ include_break_char = False
296
+
297
+ # Prefer punctuation/pauses first to keep prosody natural.
298
+ for marker in (",", ";", ":", "β€”", "-", "!", "?"):
299
+ pos = remaining.rfind(marker, 0, max_chars)
300
+ if pos > break_pos:
301
+ break_pos = pos
302
+ include_break_char = True
303
+
304
+ # Then prefer nearest space before limit.
305
+ space_pos = remaining.rfind(" ", 0, max_chars)
306
+ if space_pos > break_pos:
307
+ break_pos = space_pos
308
+ include_break_char = False
309
+
310
+ # If nothing before limit, look slightly ahead to avoid mid-word cuts.
311
+ if break_pos == -1:
312
+ forward_limit = min(len(remaining), max_chars + 24)
313
+ m = re.search(r"[\s,;:!?]", remaining[max_chars:forward_limit])
314
+ if m:
315
+ break_pos = max_chars + m.start()
316
+ include_break_char = remaining[break_pos] in ",;:!?"
317
+ else:
318
+ break_pos = max_chars
319
+ include_break_char = False
320
+
321
+ cut_at = break_pos + (1 if include_break_char else 0)
322
+ if cut_at <= 0:
323
+ cut_at = min(max_chars, len(remaining))
324
+
325
+ segment = remaining[:cut_at].strip()
326
+ if segment:
327
+ parts.append(segment)
328
+ remaining = remaining[cut_at:].lstrip()
329
+ if remaining.strip():
330
+ parts.append(remaining.strip())
331
+ return parts