ShadowHunter222 commited on
Commit
5f92103
Β·
verified Β·
1 Parent(s): 06dac54

Upload 6 files

Browse files
Files changed (4) hide show
  1. app.py +4 -42
  2. chatterbox_wrapper.py +198 -6
  3. config.py +1 -1
  4. text_processor.py +7 -41
app.py CHANGED
@@ -1,20 +1,4 @@
1
- """
2
- Chatterbox Turbo TTS -- FastAPI Server
3
- ======================================
4
- Production-ready API with true real-time MP3 streaming,
5
- in-memory voice cloning, and fully non-blocking inference.
6
-
7
- Endpoints:
8
- GET /health -> health check + optional warmup
9
- GET /info -> model info, supported tags, parameters
10
- POST /tts -> full audio response (WAV/MP3/FLAC)
11
- POST /tts/stream -> chunked MP3 streaming (MediaSource-ready)
12
- POST /tts/true-stream -> alias for /tts/stream (Kokoro compat)
13
- POST /tts/stop/{stream_id}-> cancel a specific active stream
14
- POST /tts/stop -> cancel ALL active streams
15
- POST /v1/audio/speech -> OpenAI-compatible streaming
16
- """
17
- import asyncio
18
  import io
19
  import json
20
  import logging
@@ -259,31 +243,6 @@ async def health(warm_up: bool = False):
259
  return status
260
 
261
 
262
- @app.get("/info")
263
- async def info():
264
- return {
265
- "model": Config.MODEL_ID,
266
- "dtype": Config.MODEL_DTYPE,
267
- "sample_rate": Config.SAMPLE_RATE,
268
- "paralinguistic_tags": list(Config.PARALINGUISTIC_TAGS),
269
- "tag_usage": "Insert tags directly in text, e.g. 'That is so funny! [laugh] Anyway…'",
270
- "parameters": {
271
- "max_new_tokens": {"default": Config.MAX_NEW_TOKENS, "range": "64–2048"},
272
- "repetition_penalty": {"default": Config.REPETITION_PENALTY, "range": "1.0–2.0"},
273
- },
274
- "voice_cloning": {
275
- "description": "Upload 3–30s reference WAV/MP3 as 'voice_ref' field",
276
- "max_upload_mb": Config.MAX_VOICE_UPLOAD_BYTES // (1024 * 1024),
277
- },
278
- "parallel_mode": {
279
- "enabled": Config.ENABLE_PARALLEL_MODE,
280
- "helper_configured": bool(Config.HELPER_BASE_URL),
281
- "helper_base_url": Config.HELPER_BASE_URL or None,
282
- "supports_voice_ref": True,
283
- },
284
- }
285
-
286
-
287
  # ── POST /tts ─────────────────────────────────────────────────────
288
 
289
  @app.post("/tts", response_class=Response)
@@ -805,6 +764,9 @@ async def internal_chunk_synthesize(
805
  voice_profile = wrapper.default_voice
806
  if request.voice_key:
807
  cached_voice = wrapper._voice_cache.get(request.voice_key)
 
 
 
808
  if cached_voice is None:
809
  raise HTTPException(409, "Voice key expired or not found")
810
  voice_profile = cached_voice
 
1
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import io
3
  import json
4
  import logging
 
243
  return status
244
 
245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  # ── POST /tts ─────────────────────────────────────────────────────
247
 
248
  @app.post("/tts", response_class=Response)
 
764
  voice_profile = wrapper.default_voice
765
  if request.voice_key:
766
  cached_voice = wrapper._voice_cache.get(request.voice_key)
767
+ if cached_voice is None:
768
+ # Built-in voices are permanent in wrapper registry even if TTL cache entry expired.
769
+ cached_voice = wrapper.get_builtin_voice_by_hash(request.voice_key)
770
  if cached_voice is None:
771
  raise HTTPException(409, "Voice key expired or not found")
772
  voice_profile = cached_voice
chatterbox_wrapper.py CHANGED
@@ -27,6 +27,7 @@ import tempfile
27
  import time
28
  from collections import OrderedDict
29
  from dataclasses import dataclass
 
30
  from typing import Callable, Generator, Optional
31
 
32
  import librosa
@@ -48,6 +49,21 @@ _SUPPORTED_AUDIO_EXTENSIONS = {
48
  }
49
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  # ═══════════════════════════════════════════════════════════════════
52
  # Data Structures
53
  # ═══════════════════════════════════════════════════════════════════
@@ -203,8 +219,15 @@ class ChatterboxWrapper:
203
  ttl_seconds=self.cfg.VOICE_CACHE_TTL_SEC,
204
  )
205
 
206
- logger.info("Encoding default reference voice …")
207
- self.default_voice = self._load_default_voice()
 
 
 
 
 
 
 
208
 
209
  logger.info("βœ… ChatterboxWrapper ready")
210
 
@@ -260,16 +283,185 @@ class ChatterboxWrapper:
260
  opts.enable_mem_reuse = True
261
  return opts
262
 
263
- # ─── Default voice ────────────────────────────────────────────
264
 
265
- def _load_default_voice(self) -> VoiceProfile:
266
  path = hf_hub_download(
267
  self.cfg.DEFAULT_VOICE_REPO,
268
  filename=self.cfg.DEFAULT_VOICE_FILE,
269
  cache_dir=self.cfg.MODELS_DIR,
270
  )
271
- audio, _ = librosa.load(path, sr=self.cfg.SAMPLE_RATE)
272
- return self._encode_audio_array(audio, audio_hash="__default__")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
  # ─── Voice encoding ──────────────────────────────────────────
275
 
 
27
  import time
28
  from collections import OrderedDict
29
  from dataclasses import dataclass
30
+ from pathlib import Path
31
  from typing import Callable, Generator, Optional
32
 
33
  import librosa
 
49
  }
50
 
51
 
52
+ def _slugify(text: str) -> str:
53
+ buf = []
54
+ prev_underscore = False
55
+ for ch in text.strip().lower():
56
+ if ch.isalnum():
57
+ buf.append(ch)
58
+ prev_underscore = False
59
+ else:
60
+ if not prev_underscore:
61
+ buf.append("_")
62
+ prev_underscore = True
63
+ slug = "".join(buf).strip("_")
64
+ return slug or "voice"
65
+
66
+
67
  # ═══════════════════════════════════════════════════════════════════
68
  # Data Structures
69
  # ═══════════════════════════════════════════════════════════════════
 
219
  ttl_seconds=self.cfg.VOICE_CACHE_TTL_SEC,
220
  )
221
 
222
+ self._builtin_voice_profiles: dict[str, VoiceProfile] = {}
223
+ self._builtin_voice_bytes: dict[str, bytes] = {}
224
+ self._builtin_voice_by_hash: dict[str, VoiceProfile] = {}
225
+ self._voice_alias_to_id: dict[str, str] = {}
226
+ self._builtin_voice_catalog: list[dict] = []
227
+ self._default_voice_id: str = "default"
228
+
229
+ logger.info("Loading built-in voices (HF default + local samples) …")
230
+ self.default_voice = self._load_builtin_voices()
231
 
232
  logger.info("βœ… ChatterboxWrapper ready")
233
 
 
283
  opts.enable_mem_reuse = True
284
  return opts
285
 
286
+ # ─── Built-in voices (HF default + local samples) ────────────
287
 
288
+ def _download_hf_default_voice_bytes(self) -> bytes:
289
  path = hf_hub_download(
290
  self.cfg.DEFAULT_VOICE_REPO,
291
  filename=self.cfg.DEFAULT_VOICE_FILE,
292
  cache_dir=self.cfg.MODELS_DIR,
293
  )
294
+ return Path(path).read_bytes()
295
+
296
+ def _list_local_voice_paths(self) -> list[Path]:
297
+ wrapper_dir = Path(__file__).resolve().parent
298
+
299
+ # Support both module-level and repo-root deployment layouts.
300
+ candidates = []
301
+ for d in (wrapper_dir, Path.cwd().resolve(), wrapper_dir.parent):
302
+ try:
303
+ resolved = d.resolve()
304
+ except Exception:
305
+ continue
306
+ if resolved.is_dir() and resolved not in candidates:
307
+ candidates.append(resolved)
308
+
309
+ voices: list[Path] = []
310
+ seen_real_paths: set[str] = set()
311
+ for root in candidates:
312
+ try:
313
+ entries = sorted(root.iterdir(), key=lambda x: x.name.lower())
314
+ except Exception:
315
+ continue
316
+
317
+ for p in entries:
318
+ if not p.is_file():
319
+ continue
320
+ if p.suffix.lower() not in _SUPPORTED_AUDIO_EXTENSIONS:
321
+ continue
322
+ real_path = str(p.resolve())
323
+ if real_path in seen_real_paths:
324
+ continue
325
+ seen_real_paths.add(real_path)
326
+ voices.append(p)
327
+
328
+ return voices
329
+
330
+ def _make_unique_voice_id(self, preferred: str) -> str:
331
+ base = _slugify(preferred)
332
+ candidate = base
333
+ idx = 2
334
+ while candidate in self._builtin_voice_profiles:
335
+ candidate = f"{base}_{idx}"
336
+ idx += 1
337
+ return candidate
338
+
339
+ def _register_builtin_voice(
340
+ self,
341
+ *,
342
+ preferred_id: str,
343
+ display_name: str,
344
+ source: str,
345
+ source_ref: str,
346
+ audio_bytes: bytes,
347
+ is_default: bool = False,
348
+ ) -> str:
349
+ if not audio_bytes:
350
+ raise ValueError("Voice file is empty")
351
+
352
+ voice_id = self._make_unique_voice_id(preferred_id)
353
+ audio_hash = hashlib.md5(audio_bytes).hexdigest()
354
+
355
+ profile = self._voice_cache.get(audio_hash)
356
+ if profile is None:
357
+ audio = _load_audio_bytes(audio_bytes, sr=self.cfg.SAMPLE_RATE)
358
+ profile = self._encode_audio_array(audio, audio_hash=audio_hash)
359
+ self._voice_cache.put(audio_hash, profile)
360
+ else:
361
+ # Keep hash attached to cached profile for metadata/voice-key usage.
362
+ profile.audio_hash = audio_hash
363
+
364
+ self._builtin_voice_profiles[voice_id] = profile
365
+ self._builtin_voice_bytes[voice_id] = audio_bytes
366
+ if audio_hash:
367
+ self._builtin_voice_by_hash[audio_hash] = profile
368
+
369
+ aliases: list[str] = []
370
+ for alias in (voice_id, _slugify(Path(display_name).stem)):
371
+ if alias not in self._voice_alias_to_id:
372
+ self._voice_alias_to_id[alias] = voice_id
373
+ aliases.append(alias)
374
+
375
+ if is_default:
376
+ self._default_voice_id = voice_id
377
+ self._voice_alias_to_id["default"] = voice_id
378
+ if "default" not in aliases:
379
+ aliases.append("default")
380
+
381
+ self._builtin_voice_catalog.append(
382
+ {
383
+ "id": voice_id,
384
+ "display_name": display_name,
385
+ "source": source,
386
+ "source_ref": source_ref,
387
+ "aliases": aliases,
388
+ "voice_key": audio_hash,
389
+ }
390
+ )
391
+ return voice_id
392
+
393
+ def _load_builtin_voices(self) -> VoiceProfile:
394
+ # 1) HF default voice (kept as true default fallback)
395
+ hf_bytes = self._download_hf_default_voice_bytes()
396
+ self._register_builtin_voice(
397
+ preferred_id="default_hf_voice",
398
+ display_name=self.cfg.DEFAULT_VOICE_FILE,
399
+ source="huggingface",
400
+ source_ref=f"{self.cfg.DEFAULT_VOICE_REPO}:{self.cfg.DEFAULT_VOICE_FILE}",
401
+ audio_bytes=hf_bytes,
402
+ is_default=True,
403
+ )
404
+
405
+ # 2) Local voice samples placed next to app files
406
+ for path in self._list_local_voice_paths():
407
+ # Avoid duplicate entry if someone also copied default_voice.wav locally.
408
+ if path.name == self.cfg.DEFAULT_VOICE_FILE:
409
+ continue
410
+ try:
411
+ self._register_builtin_voice(
412
+ preferred_id=path.stem,
413
+ display_name=path.name,
414
+ source="local",
415
+ source_ref=str(path.name),
416
+ audio_bytes=path.read_bytes(),
417
+ is_default=False,
418
+ )
419
+ except Exception as e:
420
+ logger.warning(f"Skipping local voice {path.name}: {e}")
421
+
422
+ default_profile = self._builtin_voice_profiles.get(self._default_voice_id)
423
+ if default_profile is None:
424
+ raise RuntimeError("Default built-in voice could not be initialized")
425
+
426
+ logger.info(
427
+ f"Built-in voices loaded: {len(self._builtin_voice_catalog)} "
428
+ f"(default={self._default_voice_id})"
429
+ )
430
+ return default_profile
431
+
432
+ def list_builtin_voices(self) -> list[dict]:
433
+ """Return metadata for startup-preloaded voices."""
434
+ return [dict(v) for v in self._builtin_voice_catalog]
435
+
436
+ @property
437
+ def default_voice_name(self) -> str:
438
+ return self._default_voice_id
439
+
440
+ def resolve_voice_id(self, voice_name: Optional[str]) -> str:
441
+ if voice_name is None:
442
+ return self._default_voice_id
443
+ key = _slugify(str(voice_name))
444
+ if not key:
445
+ return self._default_voice_id
446
+ voice_id = self._voice_alias_to_id.get(key)
447
+ if voice_id is None:
448
+ available = ", ".join(sorted(self._voice_alias_to_id.keys()))
449
+ raise ValueError(f"Unknown voice '{voice_name}'. Available: {available}")
450
+ return voice_id
451
+
452
+ def get_builtin_voice(self, voice_name: Optional[str]) -> VoiceProfile:
453
+ voice_id = self.resolve_voice_id(voice_name)
454
+ profile = self._builtin_voice_profiles[voice_id]
455
+ if profile.audio_hash:
456
+ self._voice_cache.put(profile.audio_hash, profile)
457
+ return profile
458
+
459
+ def get_builtin_voice_bytes(self, voice_name: Optional[str]) -> Optional[bytes]:
460
+ voice_id = self.resolve_voice_id(voice_name)
461
+ return self._builtin_voice_bytes.get(voice_id)
462
+
463
+ def get_builtin_voice_by_hash(self, audio_hash: str) -> Optional[VoiceProfile]:
464
+ return self._builtin_voice_by_hash.get((audio_hash or "").strip())
465
 
466
  # ─── Voice encoding ──────────────────────────────────────────
467
 
config.py CHANGED
@@ -88,7 +88,7 @@ class Config:
88
  # ── Streaming ────────────────────────────────────────────────
89
  # Smaller chunks = faster TTFB (first audio arrives sooner)
90
  # ~200 chars β‰ˆ 1–2 sentences β‰ˆ fastest first-chunk on 2 vCPU
91
- MAX_CHUNK_CHARS: int = int(os.getenv("CB_MAX_CHUNK_CHARS", "150"))
92
  # Additive parallel mode (3-way split: primary + helper1 + helper2).
93
  ENABLE_PARALLEL_MODE: bool = _get_bool("CB_ENABLE_PARALLEL_MODE", True)
94
  HELPER_BASE_URL: str = os.getenv("CB_HELPER_BASE_URL", "https://shadowhunter222-chab2.hf.space").strip()
 
88
  # ── Streaming ────────────────────────────────────────────────
89
  # Smaller chunks = faster TTFB (first audio arrives sooner)
90
  # ~200 chars β‰ˆ 1–2 sentences β‰ˆ fastest first-chunk on 2 vCPU
91
+ MAX_CHUNK_CHARS: int = int(os.getenv("CB_MAX_CHUNK_CHARS", "100"))
92
  # Additive parallel mode (3-way split: primary + helper1 + helper2).
93
  ENABLE_PARALLEL_MODE: bool = _get_bool("CB_ENABLE_PARALLEL_MODE", True)
94
  HELPER_BASE_URL: str = os.getenv("CB_HELPER_BASE_URL", "https://shadowhunter222-chab2.hf.space").strip()
text_processor.py CHANGED
@@ -231,23 +231,6 @@ def sanitize(text: str) -> str:
231
  for idx, original in tags_found:
232
  text = text.replace(f"Β§TAG{idx}Β§", original)
233
 
234
- # 11. Ensure paralinguistic tags have spaces around them.
235
- # The model needs whitespace boundaries to properly render tags like
236
- # [clear throat]. Without spaces (e.g. "Jerry.[clear throat]I'm"),
237
- # the tag gets swallowed or produces silence instead of the sound.
238
- text = re.sub(
239
- r"(\w)(\[(?:" + _TAG_NAMES + r")\])",
240
- r"\1 \2",
241
- text,
242
- flags=re.IGNORECASE,
243
- )
244
- text = re.sub(
245
- r"(\[(?:" + _TAG_NAMES + r")\])(\w)",
246
- r"\1 \2",
247
- text,
248
- flags=re.IGNORECASE,
249
- )
250
-
251
  return text
252
 
253
 
@@ -303,42 +286,25 @@ def split_for_streaming(text: str, max_chars: int = Config.MAX_CHUNK_CHARS) -> L
303
  # ═══════════════════════════════════════════════════════════════════
304
 
305
  def _break_long_chunk(text: str, max_chars: int) -> List[str]:
306
- """Break a chunk longer than max_chars on natural pause boundaries.
307
-
308
- Priority order for break points:
309
- 1. Ellipsis '...' β€” strongest natural pause within a long sentence
310
- 2. Punctuation (comma, semicolon, colon, dash, !, ?)
311
- 3. Nearest space before limit
312
- 4. Look ahead slightly to avoid mid-word cuts
313
- """
314
  parts: List[str] = []
315
  remaining = text
316
  while len(remaining) > max_chars:
317
  break_pos = -1
318
  include_break_char = False
319
 
320
- # First try: break at ellipsis '...' β€” the strongest internal pause.
321
- ellipsis_pos = remaining.rfind("...", 0, max_chars)
322
- if ellipsis_pos > 0:
323
- # Include all three dots in the current segment
324
- break_pos = ellipsis_pos + 3
325
- include_break_char = False # already moved past the dots
326
-
327
- # Then try punctuation markers (only upgrade if at a later position).
328
  for marker in (",", ";", ":", "β€”", "-", "!", "?"):
329
  pos = remaining.rfind(marker, 0, max_chars)
330
  if pos > break_pos:
331
  break_pos = pos
332
  include_break_char = True
333
 
334
- # Space is a FALLBACK only β€” never override a punctuation/ellipsis break.
335
- # Cutting at punctuation gives the model proper prosody cues;
336
- # cutting at a random space creates mid-phrase fragments ("handle the").
337
- if break_pos <= 0:
338
- space_pos = remaining.rfind(" ", 0, max_chars)
339
- if space_pos > 0:
340
- break_pos = space_pos
341
- include_break_char = False
342
 
343
  # If nothing before limit, look slightly ahead to avoid mid-word cuts.
344
  if break_pos == -1:
 
231
  for idx, original in tags_found:
232
  text = text.replace(f"Β§TAG{idx}Β§", original)
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  return text
235
 
236
 
 
286
  # ═══════════════════════════════════════════════════════════════════
287
 
288
  def _break_long_chunk(text: str, max_chars: int) -> List[str]:
289
+ """Break a chunk longer than max_chars on commas or word boundaries."""
 
 
 
 
 
 
 
290
  parts: List[str] = []
291
  remaining = text
292
  while len(remaining) > max_chars:
293
  break_pos = -1
294
  include_break_char = False
295
 
296
+ # Prefer punctuation/pauses first to keep prosody natural.
 
 
 
 
 
 
 
297
  for marker in (",", ";", ":", "β€”", "-", "!", "?"):
298
  pos = remaining.rfind(marker, 0, max_chars)
299
  if pos > break_pos:
300
  break_pos = pos
301
  include_break_char = True
302
 
303
+ # Then prefer nearest space before limit.
304
+ space_pos = remaining.rfind(" ", 0, max_chars)
305
+ if space_pos > break_pos:
306
+ break_pos = space_pos
307
+ include_break_char = False
 
 
 
308
 
309
  # If nothing before limit, look slightly ahead to avoid mid-word cuts.
310
  if break_pos == -1: