chmielvu commited on
Commit
b088dbf
·
verified ·
1 Parent(s): 519a7bc

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +300 -614
app.py CHANGED
@@ -1,22 +1,20 @@
1
  """
2
- HF Spaces (Docker) CPU-only TTS API - FIXED VERSION v1.2.0
3
- - Separate endpoints per service: XTTS v2, Parler-TTS mini multilingual, Piper.
4
- - CPU-friendly defaults for 2 vCPU / 16 GB RAM:
5
- - Sentence chunking (default ON)
6
- - Streaming via SSE (each chunk returned as standalone WAV)
7
- - Optional torch.compile, optional dynamic int8 quantization hooks
8
-
9
- FIXES APPLIED (v1.2.0):
10
- 1. Error tracking: Models that fail to load return None gracefully (no retries)
11
- 2. Health endpoint: Reports actual service availability per backend
12
- 3. Better error messages: Piper 404 shows available voices
13
- 4. Service flags: XTTS_ENABLED, PARLER_ENABLED, PIPER_ENABLED env vars
14
- 5. Parler-TTS v1.1: TWO tokenizers (prompt + description) with attention masks
15
  """
16
  from __future__ import annotations
17
 
18
  import asyncio
19
  import base64
 
20
  import io
21
  import json
22
  import os
@@ -25,37 +23,18 @@ import tempfile
25
  import threading
26
  import time
27
  from dataclasses import dataclass
28
- from functools import lru_cache
29
- from typing import Dict, Generator, Iterable, List, Optional, Tuple
30
 
31
  import numpy as np
32
  import soundfile as sf
33
  import torch
34
  from fastapi import FastAPI, File, Form, HTTPException, UploadFile
35
- from fastapi.responses import Response, StreamingResponse
36
  from pydantic import BaseModel, Field
37
 
38
- # --- Optional deps (import lazily where possible) ---
39
- # XTTS (Coqui TTS)
40
- from TTS.api import TTS
41
-
42
- # Parler-TTS (transformers)
43
- from transformers import AutoTokenizer, set_seed
44
- try:
45
- from parler_tts import ParlerTTSForConditionalGeneration
46
- except Exception:
47
- ParlerTTSForConditionalGeneration = None # type: ignore
48
-
49
- # Piper fallback
50
- try:
51
- from piper.voice import PiperVoice
52
- except Exception:
53
- PiperVoice = None # type: ignore
54
-
55
-
56
- # -----------------------
57
- # Settings / knobs
58
- # -----------------------
59
  def _env_bool(name: str, default: bool = False) -> bool:
60
  v = os.getenv(name)
61
  if v is None:
@@ -63,67 +42,61 @@ def _env_bool(name: str, default: bool = False) -> bool:
63
  return v.strip().lower() in {"1", "true", "yes", "y", "on"}
64
 
65
 
66
- @dataclass(frozen=True)
67
- class Settings:
68
- # Service toggles (NEW in v1.2.0)
69
- xtts_enabled: bool = _env_bool("XTTS_ENABLED", True)
70
- parler_enabled: bool = _env_bool("PARLER_ENABLED", True)
71
- piper_enabled: bool = _env_bool("PIPER_ENABLED", True)
72
- fallback_enabled: bool = _env_bool("ENABLE_FALLBACK", True)
73
-
74
- # XTTS v2
75
- xtts_model_name: str = os.getenv("XTTS_MODEL_NAME", "tts_models/multilingual/multi-dataset/xtts_v2")
76
- xtts_default_language: str = os.getenv("XTTS_DEFAULT_LANGUAGE", "pl")
77
- xtts_torch_compile: bool = _env_bool("XTTS_TORCH_COMPILE", False)
78
- xtts_dynamic_int8: bool = _env_bool("XTTS_DYNAMIC_INT8", False)
79
-
80
- # Parler
81
- parler_model_name: str = os.getenv("PARLER_MODEL_NAME", "parler-tts/parler-tts-mini-multilingual-v1.1")
82
- parler_default_description: str = os.getenv(
83
- "PARLER_DEFAULT_DESCRIPTION",
84
- "A clear, natural, studio-recorded voice speaking Polish with steady pacing.",
85
- )
86
- parler_seed: int = int(os.getenv("PARLER_SEED", "0"))
87
- parler_torch_compile: bool = _env_bool("PARLER_TORCH_COMPILE", False)
88
- parler_dynamic_int8: bool = _env_bool("PARLER_DYNAMIC_INT8", False)
89
 
90
- # Piper
91
- piper_voices_json: str = os.getenv("PIPER_VOICES_JSON", "")
92
- piper_voices_dir: str = os.getenv("PIPER_VOICES_DIR", "/data/piper")
93
 
94
- # Chunking / streaming defaults
95
- chunk_max_chars: int = int(os.getenv("CHUNK_MAX_CHARS", "260"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  chunk_max_words: int = int(os.getenv("CHUNK_MAX_WORDS", "40"))
97
- chunk_max_sentences: int = int(os.getenv("CHUNK_MAX_SENTENCES", "8"))
98
  join_silence_ms: int = int(os.getenv("JOIN_SILENCE_MS", "60"))
99
 
 
 
 
100
  # Runtime
101
  num_threads: int = int(os.getenv("OMP_NUM_THREADS", "2"))
102
- request_timeout_s: int = int(os.getenv("REQUEST_TIMEOUT_S", "240"))
103
 
104
 
105
  S = Settings()
106
 
107
- # Conservative CPU threading.
108
  torch.set_num_threads(S.num_threads)
109
  torch.set_num_interop_threads(max(1, S.num_threads // 2))
110
 
111
- # -----------------------
112
- # Utilities
113
- # -----------------------
114
  _SENT_SPLIT_RE = re.compile(r"(?<=[\.\!\?\:\;])\s+|\n+")
115
  _WS_RE = re.compile(r"\s+")
116
 
 
117
  def normalize_text(text: str) -> str:
118
- text = text.strip()
119
- text = _WS_RE.sub(" ", text)
120
- return text
121
 
122
  def split_text_into_chunks(
123
  text: str,
124
  max_chars: int = S.chunk_max_chars,
125
  max_words: int = S.chunk_max_words,
126
- max_sentences: int = S.chunk_max_sentences,
127
  ) -> List[str]:
128
  text = normalize_text(text)
129
  if not text:
@@ -139,648 +112,361 @@ def split_text_into_chunks(
139
  nonlocal cur, cur_chars, cur_words
140
  if cur:
141
  chunks.append(" ".join(cur).strip())
142
- cur = []
143
- cur_chars = 0
144
- cur_words = 0
145
 
146
  for sent in sents:
147
- w = sent.split()
148
- sent_words = len(w)
149
- sent_chars = len(sent)
150
- if (cur_chars + sent_chars > max_chars) or (cur_words + sent_words > max_words):
151
  flush()
152
  cur.append(sent)
153
- cur_chars += sent_chars + 1
154
- cur_words += sent_words
155
- if max_sentences and len(chunks) + (1 if cur else 0) >= max_sentences:
156
- flush()
157
- break
158
 
159
  flush()
160
  return chunks
161
 
 
162
  def wav_bytes_from_audio(audio: np.ndarray, sr: int) -> bytes:
163
- audio = np.asarray(audio, dtype=np.float32)
164
  buf = io.BytesIO()
165
- sf.write(buf, audio, sr, format="WAV", subtype="PCM_16")
166
  return buf.getvalue()
167
 
 
168
  def concat_audio(chunks: List[np.ndarray], sr: int, silence_ms: int = S.join_silence_ms) -> np.ndarray:
169
  if not chunks:
170
  return np.zeros((1,), dtype=np.float32)
171
  if len(chunks) == 1:
172
  return np.asarray(chunks[0], dtype=np.float32)
173
 
174
- silence = np.zeros((int(sr * (silence_ms / 1000.0)),), dtype=np.float32) if silence_ms > 0 else None
175
- out = []
176
  for i, ch in enumerate(chunks):
177
- out.append(np.asarray(ch, dtype=np.float32))
178
- if silence is not None and i != len(chunks) - 1:
179
- out.append(silence)
180
- return np.concatenate(out, axis=0)
 
181
 
182
  def b64encode_bytes(b: bytes) -> str:
183
  return base64.b64encode(b).decode("ascii")
184
 
185
- def safe_filename(prefix: str = "audio", ext: str = ".wav") -> str:
186
- return f"{prefix}_{int(time.time() * 1000)}{ext}"
187
-
188
- def _filter_kwargs(fn, kwargs: Dict) -> Dict:
189
- import inspect
190
- try:
191
- sig = inspect.signature(fn)
192
- except Exception:
193
- return kwargs
194
- accepted = set(sig.parameters.keys())
195
- return {k: v for k, v in kwargs.items() if k in accepted}
196
-
197
- # -----------------------
198
- # Model manager (lazy + locked) - FIXED v1.2.0
199
- # -----------------------
200
- class _Locks:
201
- xtts = threading.Lock()
202
- xtts_infer = threading.Lock()
203
- parler = threading.Lock()
204
- parler_infer = threading.Lock()
205
- piper = threading.Lock()
206
-
207
- class ModelManager:
208
- def __init__(self) -> None:
209
- self._xtts: Optional[TTS] = None
210
- self._xtts_error: Optional[str] = None # NEW: Track loading errors
211
-
212
- self._parler = None
213
- self._parler_prompt_tok = None
214
- self._parler_desc_tok = None
215
- self._parler_error: Optional[str] = None # NEW: Track loading errors
216
-
217
- self._piper_voices: Dict[str, str] = {}
218
- self._piper_loaded: Dict[str, "PiperVoice"] = {}
219
- self._piper_error: Optional[str] = None # NEW: Track loading errors
220
-
221
- def _maybe_torch_compile(self, module: torch.nn.Module) -> torch.nn.Module:
222
- if not hasattr(torch, "compile"):
223
- return module
224
- try:
225
- return torch.compile(module) # type: ignore
226
- except Exception:
227
- return module
228
 
229
- def _maybe_dynamic_int8(self, module: torch.nn.Module) -> torch.nn.Module:
230
- try:
231
- from torch.ao.quantization import quantize_dynamic
232
- return quantize_dynamic(module, {torch.nn.Linear}, dtype=torch.qint8)
233
- except Exception:
234
- return module
235
-
236
- def get_xtts(self) -> Optional[TTS]:
237
- """FIXED: Returns None on failure instead of crashing"""
238
- if not S.xtts_enabled:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  return None
 
 
240
 
241
- with _Locks.xtts:
242
- # If we already tried and failed, return None immediately
243
- if self._xtts_error is not None:
244
- return None
245
-
246
- if self._xtts is None:
247
- try:
248
- print("[XTTS] Loading model...")
249
- tts = TTS(model_name=S.xtts_model_name, progress_bar=False, gpu=False)
 
250
  try:
251
- inner = getattr(getattr(tts, "synthesizer", None), "tts_model", None)
252
- if isinstance(inner, torch.nn.Module):
253
- if S.xtts_dynamic_int8:
254
- inner = self._maybe_dynamic_int8(inner)
255
- tts.synthesizer.tts_model = inner
256
- if S.xtts_torch_compile:
257
- inner = self._maybe_torch_compile(inner)
258
- tts.synthesizer.tts_model = inner
259
  except Exception as e:
260
- print(f"[XTTS] Warning: Optimization failed: {e}")
261
-
262
- self._xtts = tts
263
- print("[XTTS] Model loaded successfully")
264
- except Exception as e:
265
- self._xtts_error = str(e)
266
- print(f"[XTTS] Failed to load: {e}")
267
- return None
268
-
269
- return self._xtts
270
-
271
- def get_parler(self) -> Optional[Tuple]:
272
- """
273
- FIXED: Returns None on failure instead of crashing.
274
- Returns (model, prompt_tokenizer, description_tokenizer) or None.
275
- """
276
- if not S.parler_enabled:
277
- return None
278
-
279
- with _Locks.parler:
280
- if self._parler_error is not None:
281
- return None
282
-
283
- if ParlerTTSForConditionalGeneration is None:
284
- self._parler_error = "parler_tts not installed"
285
- print("[Parler] ❌ parler_tts is not installed")
286
- return None
287
 
288
- if self._parler is None or self._parler_prompt_tok is None or self._parler_desc_tok is None:
289
- try:
290
- print("[Parler] Loading model...")
291
- # Load model
292
- model = ParlerTTSForConditionalGeneration.from_pretrained(S.parler_model_name).to("cpu")
293
- model.eval()
294
-
295
- # CRITICAL FIX: Load BOTH tokenizers for v1.1
296
- prompt_tokenizer = AutoTokenizer.from_pretrained(S.parler_model_name)
297
- description_tokenizer = AutoTokenizer.from_pretrained(
298
- model.config.text_encoder._name_or_path
299
- )
300
-
301
- # Best-effort compile/quantize
302
- if isinstance(model, torch.nn.Module):
303
- if S.parler_dynamic_int8:
304
- model = self._maybe_dynamic_int8(model)
305
- if S.parler_torch_compile:
306
- model = self._maybe_torch_compile(model)
307
-
308
- self._parler = model
309
- self._parler_prompt_tok = prompt_tokenizer
310
- self._parler_desc_tok = description_tokenizer
311
- print("[Parler] ✅ Model loaded successfully")
312
- except Exception as e:
313
- self._parler_error = str(e)
314
- print(f"[Parler] ❌ Failed to load: {e}")
315
- return None
316
-
317
- return self._parler, self._parler_prompt_tok, self._parler_desc_tok
318
-
319
- def _load_piper_registry(self) -> Dict[str, str]:
320
- """Load Piper voice registry from JSON env var and/or directory scan"""
321
- reg: Dict[str, str] = {}
322
- if S.piper_voices_json:
323
- try:
324
- reg.update(json.loads(S.piper_voices_json))
325
- except Exception as e:
326
- print(f"[Piper] Warning: Failed to parse PIPER_VOICES_JSON: {e}")
327
- try:
328
- if os.path.isdir(S.piper_voices_dir):
329
- for fn in os.listdir(S.piper_voices_dir):
330
- if fn.endswith(".onnx"):
331
- voice_id = os.path.splitext(fn)[0]
332
- reg.setdefault(voice_id, os.path.join(S.piper_voices_dir, fn))
333
  except Exception as e:
334
- print(f"[Piper] Warning: Failed to scan {S.piper_voices_dir}: {e}")
335
- return reg
 
336
 
337
- def list_piper_voices(self) -> Dict[str, str]:
338
- """FIXED: Returns empty dict on error instead of crashing"""
339
- if not S.piper_enabled:
340
- return {}
341
 
342
- with _Locks.piper:
343
- if self._piper_error is not None:
344
- return {}
345
 
346
- if not self._piper_voices:
347
- try:
348
- self._piper_voices = self._load_piper_registry()
349
- if self._piper_voices:
350
- print(f"[Piper] Found {len(self._piper_voices)} voices")
351
- else:
352
- print("[Piper] ⚠️ No voices found in registry")
353
- except Exception as e:
354
- self._piper_error = str(e)
355
- print(f"[Piper] ❌ Failed to load registry: {e}")
356
- return {}
357
- return dict(self._piper_voices)
358
-
359
- def get_piper(self, voice_id: str) -> Optional["PiperVoice"]:
360
- """FIXED: Returns None on failure with better error messages"""
361
- if not S.piper_enabled:
362
- return None
363
 
364
- if PiperVoice is None:
365
- self._piper_error = "piper not installed"
366
- return None
367
 
368
- with _Locks.piper:
369
- if self._piper_error is not None:
370
- return None
 
 
 
 
 
371
 
372
- voices = self.list_piper_voices()
373
- if voice_id not in voices:
374
- return None
375
 
376
- if voice_id not in self._piper_loaded:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
  try:
378
- print(f"[Piper] Loading voice: {voice_id}")
379
- path = voices[voice_id]
380
- voice = PiperVoice.load(path, use_cuda=False)
381
- self._piper_loaded[voice_id] = voice
382
- print(f"[Piper] ✅ Voice loaded: {voice_id}")
383
- except Exception as e:
384
- print(f"[Piper] ❌ Failed to load voice {voice_id}: {e}")
385
- return None
386
-
387
- return self._piper_loaded.get(voice_id)
388
-
389
-
390
- _manager = ModelManager()
391
-
392
- # -----------------------
393
- # Request / response models
394
- # -----------------------
395
- class XTTSSynthRequest(BaseModel):
396
- text: str = Field(..., min_length=1, max_length=5000, description="Text to synthesize")
397
- language: Optional[str] = Field(None, description="Language code (e.g. 'pl', 'en')")
398
- speaker_wav_b64: Optional[str] = Field(None, description="Base64-encoded speaker WAV for voice cloning")
399
- stream: bool = Field(False, description="If True, stream chunks via SSE")
400
-
401
- class XTTSStreamRequest(BaseModel):
402
- text: str = Field(..., min_length=1, max_length=5000)
403
- language: Optional[str] = None
404
- speaker_wav_b64: Optional[str] = None
405
 
406
- class ParlerSynthRequest(BaseModel):
407
- text: str = Field(..., min_length=1, max_length=5000)
408
- description: Optional[str] = Field(None, description="Voice description (overrides default)")
409
- stream: bool = Field(False, description="If True, stream chunks via SSE")
410
 
411
- class ParlerStreamRequest(BaseModel):
412
- text: str = Field(..., min_length=1, max_length=5000)
413
- description: Optional[str] = None
414
 
415
- class PiperSynthRequest(BaseModel):
 
 
 
416
  text: str = Field(..., min_length=1, max_length=5000)
417
- voice_id: str = Field(..., description="Piper voice ID (from /v1/piper/voices)")
418
- stream: bool = Field(False, description="If True, stream chunks via SSE")
419
 
420
- class PiperStreamRequest(BaseModel):
 
421
  text: str = Field(..., min_length=1, max_length=5000)
422
- voice_id: str
 
 
423
 
424
  class AudioResponse(BaseModel):
425
- audio_b64: str = Field(..., description="Base64-encoded WAV file")
426
  sample_rate: int
427
  duration_s: float
 
428
  text: str
429
 
 
430
  class HealthResponse(BaseModel):
431
  status: str = "ok"
432
- version: str = "1.2.0-fallback"
433
- services: Dict[str, bool] = Field(default_factory=dict)
434
- piper_voices: int = 0
435
- fallback: bool = True
 
 
436
 
437
- # -----------------------
438
  # FastAPI app
439
- # -----------------------
440
- app = FastAPI(title="Forge-TTS API", version="1.2.0")
 
441
 
442
  @app.get("/health", response_model=HealthResponse)
443
  def health():
444
- """FIXED: Reports actual service availability"""
445
- voices = _manager.list_piper_voices()
446
-
447
- # Check service availability
448
- xtts_available = S.xtts_enabled and _manager._xtts_error is None
449
- parler_available = S.parler_enabled and _manager._parler_error is None
450
- piper_available = S.piper_enabled and _manager._piper_error is None and len(voices) > 0
451
-
452
  return HealthResponse(
453
- status="ok",
454
- version="1.2.0-fallback",
455
- services={
456
- "xtts": xtts_available,
457
- "parler": parler_available,
458
- "piper": piper_available,
459
- },
460
- piper_voices=len(voices),
461
- fallback=S.fallback_enabled,
462
  )
463
 
464
- # -----------------------
465
- # XTTS endpoints
466
- # -----------------------
467
- def _do_xtts_synth(text: str, language: str, speaker_wav_bytes: Optional[bytes]) -> Tuple[np.ndarray, int]:
468
- """Internal XTTS synthesis with proper error handling"""
469
- tts = _manager.get_xtts()
470
- if tts is None:
471
- raise HTTPException(status_code=503, detail="XTTS service unavailable. Check /health for status.")
472
-
473
- with _Locks.xtts_infer:
474
- kwargs = {
475
- "text": text,
476
- "language": language,
477
- "speaker_wav": None,
478
- }
479
-
480
- if speaker_wav_bytes:
481
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
482
- tmp.write(speaker_wav_bytes)
483
- tmp.flush()
484
- tmp_path = tmp.name
485
- try:
486
- kwargs["speaker_wav"] = tmp_path
487
- audio_np = tts.tts(**_filter_kwargs(tts.tts, kwargs))
488
- finally:
489
- try:
490
- os.remove(tmp_path)
491
- except Exception:
492
- pass
493
- else:
494
- audio_np = tts.tts(**_filter_kwargs(tts.tts, kwargs))
495
-
496
- sr = getattr(tts, "synthesizer", None)
497
- sr = getattr(sr, "output_sample_rate", 22050) if sr else 22050
498
- return np.asarray(audio_np, dtype=np.float32), sr
499
 
500
  @app.post("/v1/xtts/synthesize", response_model=AudioResponse)
501
- def xtts_synthesize(req: XTTSSynthRequest):
502
- """FIXED: Proper error handling for model loading failures"""
503
- if req.stream:
504
- raise HTTPException(status_code=400, detail="Use /v1/xtts/stream for streaming synthesis")
505
-
506
  speaker_bytes = None
507
  if req.speaker_wav_b64:
508
  try:
509
  speaker_bytes = base64.b64decode(req.speaker_wav_b64)
510
  except Exception as e:
511
- raise HTTPException(status_code=400, detail=f"Invalid base64 speaker_wav: {e}")
512
-
513
- lang = req.language or S.xtts_default_language
514
-
515
- try:
516
- audio, sr = _do_xtts_synth(req.text, lang, speaker_bytes)
517
- wav_bytes = wav_bytes_from_audio(audio, sr)
518
- duration = len(audio) / sr
519
-
520
- return AudioResponse(
521
- audio_b64=b64encode_bytes(wav_bytes),
522
- sample_rate=sr,
523
- duration_s=round(duration, 3),
524
- text=req.text,
525
- )
526
- except HTTPException:
527
- raise
528
- except Exception as e:
529
- raise HTTPException(status_code=500, detail=f"XTTS synthesis failed: {str(e)}")
 
 
 
 
 
 
 
 
530
 
531
  @app.post("/v1/xtts/stream")
532
- async def xtts_stream(req: XTTSStreamRequest):
533
- """Stream XTTS synthesis as SSE chunks"""
534
  speaker_bytes = None
535
  if req.speaker_wav_b64:
536
  try:
537
  speaker_bytes = base64.b64decode(req.speaker_wav_b64)
538
  except Exception as e:
539
- raise HTTPException(status_code=400, detail=f"Invalid base64: {e}")
540
 
541
  chunks = split_text_into_chunks(req.text)
542
  if not chunks:
543
- raise HTTPException(status_code=400, detail="No text to synthesize after chunking")
544
 
545
- lang = req.language or S.xtts_default_language
546
 
547
  async def generate():
548
  for i, chunk_text in enumerate(chunks):
549
  try:
550
- audio, sr = await asyncio.to_thread(_do_xtts_synth, chunk_text, lang, speaker_bytes)
 
 
551
  wav_bytes = wav_bytes_from_audio(audio, sr)
552
-
553
  payload = {
554
  "chunk_index": i,
555
  "total_chunks": len(chunks),
556
  "text": chunk_text,
557
  "audio_b64": b64encode_bytes(wav_bytes),
558
  "sample_rate": sr,
 
559
  }
560
  yield f"data: {json.dumps(payload)}\n\n"
561
  except Exception as e:
562
- error_payload = {
563
- "error": str(e),
564
- "chunk_index": i,
565
- "text": chunk_text,
566
- }
567
- yield f"data: {json.dumps(error_payload)}\n\n"
568
  break
569
-
570
  yield "data: [DONE]\n\n"
571
 
572
  return StreamingResponse(generate(), media_type="text/event-stream")
573
 
574
- # -----------------------
575
- # Parler-TTS endpoints
576
- # -----------------------
577
- def _do_parler_synth(text: str, description: str) -> Tuple[np.ndarray, int]:
578
- """Internal Parler synthesis with FIXED dual tokenizer handling"""
579
- result = _manager.get_parler()
580
- if result is None:
581
- raise HTTPException(status_code=503, detail="Parler service unavailable. Check /health for status.")
582
-
583
- model, prompt_tok, desc_tok = result
584
-
585
- with _Locks.parler_infer:
586
- if S.parler_seed > 0:
587
- set_seed(S.parler_seed)
588
-
589
- # FIXED: Use correct tokenizers with attention masks
590
- input_ids = prompt_tok(text, return_tensors="pt", padding=True).input_ids
591
- attention_mask = prompt_tok(text, return_tensors="pt", padding=True).attention_mask
592
-
593
- prompt_input_ids = desc_tok(description, return_tensors="pt", padding=True).input_ids
594
- prompt_attention_mask = desc_tok(description, return_tensors="pt", padding=True).attention_mask
595
-
596
- with torch.no_grad():
597
- generation = model.generate(
598
- input_ids=input_ids,
599
- attention_mask=attention_mask,
600
- prompt_input_ids=prompt_input_ids,
601
- prompt_attention_mask=prompt_attention_mask,
602
- )
603
 
604
- audio_arr = generation.cpu().numpy().squeeze()
605
-
606
- sr = getattr(model.config, "sampling_rate", 44100)
607
- return audio_arr.astype(np.float32), sr
608
-
609
- @app.post("/v1/parler/synthesize", response_model=AudioResponse)
610
- def parler_synthesize(req: ParlerSynthRequest):
611
- """FIXED: Proper error handling for model loading failures"""
612
- if req.stream:
613
- raise HTTPException(status_code=400, detail="Use /v1/parler/stream for streaming")
614
-
615
- desc = req.description or S.parler_default_description
616
-
617
- try:
618
- audio, sr = _do_parler_synth(req.text, desc)
619
- wav_bytes = wav_bytes_from_audio(audio, sr)
620
- duration = len(audio) / sr
621
-
622
- return AudioResponse(
623
- audio_b64=b64encode_bytes(wav_bytes),
624
- sample_rate=sr,
625
- duration_s=round(duration, 3),
626
- text=req.text,
627
- )
628
- except HTTPException:
629
- raise
630
- except Exception as e:
631
- raise HTTPException(status_code=500, detail=f"Parler synthesis failed: {str(e)}")
632
-
633
- @app.post("/v1/parler/stream")
634
- async def parler_stream(req: ParlerStreamRequest):
635
- """Stream Parler synthesis as SSE chunks"""
636
- chunks = split_text_into_chunks(req.text)
637
  if not chunks:
638
- raise HTTPException(status_code=400, detail="No text after chunking")
639
-
640
- desc = req.description or S.parler_default_description
641
-
642
- async def generate():
643
- for i, chunk_text in enumerate(chunks):
644
- try:
645
- audio, sr = await asyncio.to_thread(_do_parler_synth, chunk_text, desc)
646
- wav_bytes = wav_bytes_from_audio(audio, sr)
647
-
648
- payload = {
649
- "chunk_index": i,
650
- "total_chunks": len(chunks),
651
- "text": chunk_text,
652
- "audio_b64": b64encode_bytes(wav_bytes),
653
- "sample_rate": sr,
654
- }
655
- yield f"data: {json.dumps(payload)}\n\n"
656
- except Exception as e:
657
- error_payload = {
658
- "error": str(e),
659
- "chunk_index": i,
660
- "text": chunk_text,
661
- }
662
- yield f"data: {json.dumps(error_payload)}\n\n"
663
- break
664
-
665
- yield "data: [DONE]\n\n"
666
-
667
- return StreamingResponse(generate(), media_type="text/event-stream")
668
-
669
- # -----------------------
670
- # Piper endpoints
671
- # -----------------------
672
- @app.get("/v1/piper/voices")
673
- def piper_list_voices():
674
- """FIXED: Returns helpful empty response when no voices available"""
675
- voices = _manager.list_piper_voices()
676
- if not voices:
677
- return {
678
- "voices": {},
679
- "message": f"No Piper voices found. Check {S.piper_voices_dir} directory or PIPER_VOICES_JSON env var.",
680
- }
681
- return {"voices": voices}
682
-
683
- def _do_piper_synth(text: str, voice_id: str) -> Tuple[np.ndarray, int]:
684
- """Internal Piper synthesis with proper error handling"""
685
- voice = _manager.get_piper(voice_id)
686
- if voice is None:
687
- available = list(_manager.list_piper_voices().keys())
688
- if not available:
689
- raise HTTPException(
690
- status_code=404,
691
- detail=f"Piper voice '{voice_id}' not found. No voices available. Check /v1/piper/voices",
692
- )
693
- raise HTTPException(
694
- status_code=404,
695
- detail=f"Piper voice '{voice_id}' not found. Available: {available}. See /v1/piper/voices",
696
- )
697
-
698
- with _Locks.piper:
699
- audio_bytes = io.BytesIO()
700
- voice.synthesize(text, audio_bytes)
701
- audio_bytes.seek(0)
702
-
703
- audio_np, sr = sf.read(audio_bytes)
704
- return audio_np.astype(np.float32), sr
705
-
706
- @app.post("/v1/piper/synthesize", response_model=AudioResponse)
707
- def piper_synthesize(req: PiperSynthRequest):
708
- """FIXED: Better error messages showing available voices"""
709
- if req.stream:
710
- raise HTTPException(status_code=400, detail="Use /v1/piper/stream for streaming")
711
-
712
- try:
713
- audio, sr = _do_piper_synth(req.text, req.voice_id)
714
- wav_bytes = wav_bytes_from_audio(audio, sr)
715
- duration = len(audio) / sr
716
-
717
- return AudioResponse(
718
- audio_b64=b64encode_bytes(wav_bytes),
719
- sample_rate=sr,
720
- duration_s=round(duration, 3),
721
- text=req.text,
722
- )
723
- except HTTPException:
724
- raise
725
- except Exception as e:
726
- raise HTTPException(status_code=500, detail=f"Piper synthesis failed: {str(e)}")
727
-
728
- @app.post("/v1/piper/stream")
729
- async def piper_stream(req: PiperStreamRequest):
730
- """Stream Piper synthesis as SSE chunks"""
731
- chunks = split_text_into_chunks(req.text)
732
- if not chunks:
733
- raise HTTPException(status_code=400, detail="No text after chunking")
734
-
735
- async def generate():
736
- for i, chunk_text in enumerate(chunks):
737
- try:
738
- audio, sr = await asyncio.to_thread(_do_piper_synth, chunk_text, req.voice_id)
739
- wav_bytes = wav_bytes_from_audio(audio, sr)
740
-
741
- payload = {
742
- "chunk_index": i,
743
- "total_chunks": len(chunks),
744
- "text": chunk_text,
745
- "audio_b64": b64encode_bytes(wav_bytes),
746
- "sample_rate": sr,
747
- }
748
- yield f"data: {json.dumps(payload)}\n\n"
749
- except Exception as e:
750
- error_payload = {
751
- "error": str(e),
752
- "chunk_index": i,
753
- "text": chunk_text,
754
- }
755
- yield f"data: {json.dumps(error_payload)}\n\n"
756
- break
757
-
758
- yield "data: [DONE]\n\n"
759
 
760
- return StreamingResponse(generate(), media_type="text/event-stream")
761
 
762
- # -----------------------
763
- # Startup logging
764
- # -----------------------
765
  @app.on_event("startup")
766
  async def startup_event():
767
- print("\n" + "="*60)
768
- print("Forge-TTS API v1.2.0 - Starting")
769
- print("="*60)
770
- print(f"XTTS Enabled: {S.xtts_enabled}")
771
- print(f"Parler Enabled: {S.parler_enabled}")
772
- print(f"Piper Enabled: {S.piper_enabled}")
773
- print(f"Fallback Chain: {S.fallback_enabled}")
774
- print(f"Piper Voices Dir: {S.piper_voices_dir}")
775
- print("="*60 + "\n")
776
-
777
- # Trigger lazy loading to catch errors early
778
- if S.xtts_enabled:
779
- _manager.get_xtts()
780
- if S.parler_enabled:
781
- _manager.get_parler()
782
- if S.piper_enabled:
783
- _manager.list_piper_voices()
784
 
785
  if __name__ == "__main__":
786
  import uvicorn
 
1
  """
2
+ Forge-TTS v2.0.0 — XTTS-v2 Only
3
+ CPU-optimized TTS API with Polish voice cloning.
4
+ Single backend: Coqui XTTS-v2 via idiap fork (coqui-tts>=0.27.0).
5
+
6
+ Features:
7
+ - Speaker latent caching (LRU, keyed by WAV hash)
8
+ - Text chunking + audio concatenation
9
+ - SSE streaming endpoint
10
+ - Multipart WAV upload for cloning convenience
11
+ - Configurable via env vars
 
 
 
12
  """
13
  from __future__ import annotations
14
 
15
  import asyncio
16
  import base64
17
+ import hashlib
18
  import io
19
  import json
20
  import os
 
23
  import threading
24
  import time
25
  from dataclasses import dataclass
26
+ from typing import Dict, List, Optional, Tuple
 
27
 
28
  import numpy as np
29
  import soundfile as sf
30
  import torch
31
  from fastapi import FastAPI, File, Form, HTTPException, UploadFile
32
+ from fastapi.responses import StreamingResponse
33
  from pydantic import BaseModel, Field
34
 
35
+ # ---------------------------------------------------------------------------
36
+ # Settings (env-configurable)
37
+ # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def _env_bool(name: str, default: bool = False) -> bool:
39
  v = os.getenv(name)
40
  if v is None:
 
42
  return v.strip().lower() in {"1", "true", "yes", "y", "on"}
43
 
44
 
45
+ def _env_float(name: str, default: float) -> float:
46
+ v = os.getenv(name)
47
+ return float(v) if v else default
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
 
 
 
49
 
50
@dataclass(frozen=True)
class Settings:
    # Immutable runtime configuration, read once from environment variables
    # at import time.  Frozen so request handlers cannot mutate shared state.

    # Model
    # Coqui XTTS-v2 model identifier and default synthesis language.
    model_name: str = os.getenv("XTTS_MODEL_NAME", "tts_models/multilingual/multi-dataset/xtts_v2")
    default_language: str = os.getenv("XTTS_DEFAULT_LANGUAGE", "pl")

    # Generation params
    # Sampling knobs forwarded to XTTS generation.
    temperature: float = _env_float("XTTS_TEMPERATURE", 0.65)
    speed: float = _env_float("XTTS_SPEED", 1.0)
    top_p: float = _env_float("XTTS_TOP_P", 0.85)
    top_k: int = int(os.getenv("XTTS_TOP_K", "50"))
    repetition_penalty: float = _env_float("XTTS_REPETITION_PENALTY", 5.0)

    # Optimizations
    # Both default off; CPU-only deployments rarely benefit — TODO confirm.
    torch_compile: bool = _env_bool("XTTS_TORCH_COMPILE", False)
    use_fp16: bool = _env_bool("XTTS_USE_FP16", False)

    # Chunking
    # Limits for sentence-based text splitting and the silence gap (ms)
    # inserted between chunks when concatenating audio.
    chunk_max_chars: int = int(os.getenv("CHUNK_MAX_CHARS", "250"))
    chunk_max_words: int = int(os.getenv("CHUNK_MAX_WORDS", "40"))
    join_silence_ms: int = int(os.getenv("JOIN_SILENCE_MS", "60"))

    # Speaker cache
    # Max number of cached speaker-latent entries (see SpeakerCache).
    speaker_cache_size: int = int(os.getenv("SPEAKER_CACHE_SIZE", "8"))

    # Runtime
    # Torch intra-op thread count; kept low for small CPU instances.
    num_threads: int = int(os.getenv("OMP_NUM_THREADS", "2"))
 
77
 
78
 
79
  S = Settings()
80
 
81
+ # Conservative CPU threading
82
  torch.set_num_threads(S.num_threads)
83
  torch.set_num_interop_threads(max(1, S.num_threads // 2))
84
 
85
+ # ---------------------------------------------------------------------------
86
+ # Text utilities (kept from v1)
87
+ # ---------------------------------------------------------------------------
88
  _SENT_SPLIT_RE = re.compile(r"(?<=[\.\!\?\:\;])\s+|\n+")
89
  _WS_RE = re.compile(r"\s+")
90
 
91
+
92
def normalize_text(text: str) -> str:
    """Trim the ends and collapse every internal whitespace run to one space."""
    return re.sub(r"\s+", " ", text.strip())
94
+
 
95
 
96
  def split_text_into_chunks(
97
  text: str,
98
  max_chars: int = S.chunk_max_chars,
99
  max_words: int = S.chunk_max_words,
 
100
  ) -> List[str]:
101
  text = normalize_text(text)
102
  if not text:
 
112
  nonlocal cur, cur_chars, cur_words
113
  if cur:
114
  chunks.append(" ".join(cur).strip())
115
+ cur, cur_chars, cur_words = [], 0, 0
 
 
116
 
117
  for sent in sents:
118
+ w = len(sent.split())
119
+ c = len(sent)
120
+ if cur and (cur_chars + c > max_chars or cur_words + w > max_words):
 
121
  flush()
122
  cur.append(sent)
123
+ cur_chars += c + 1
124
+ cur_words += w
 
 
 
125
 
126
  flush()
127
  return chunks
128
 
129
+
130
def wav_bytes_from_audio(audio: np.ndarray, sr: int) -> bytes:
    """Encode a waveform as a 16-bit PCM WAV file and return its raw bytes."""
    samples = np.asarray(audio, dtype=np.float32)
    with io.BytesIO() as buf:
        sf.write(buf, samples, sr, format="WAV", subtype="PCM_16")
        return buf.getvalue()
134
 
135
+
136
def concat_audio(chunks: List[np.ndarray], sr: int, silence_ms: int = S.join_silence_ms) -> np.ndarray:
    """Join audio chunks into one float32 array.

    Inserts ``silence_ms`` milliseconds of silence between consecutive
    chunks (never after the last one).  An empty input yields a single
    zero sample; a single chunk is returned unchanged (as float32).
    """
    if not chunks:
        return np.zeros((1,), dtype=np.float32)
    if len(chunks) == 1:
        return np.asarray(chunks[0], dtype=np.float32)

    gap_len = int(sr * silence_ms / 1000)
    last = len(chunks) - 1
    pieces: List[np.ndarray] = []
    for idx, segment in enumerate(chunks):
        pieces.append(np.asarray(segment, dtype=np.float32))
        if silence_ms > 0 and idx != last:
            pieces.append(np.zeros(gap_len, dtype=np.float32))
    return np.concatenate(pieces)
149
+
150
 
151
def b64encode_bytes(b: bytes) -> str:
    """Return *b* as an ASCII base64 string."""
    return str(base64.b64encode(b), "ascii")
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
+ # ---------------------------------------------------------------------------
156
+ # Speaker latent cache (keyed by SHA-256 of WAV bytes)
157
+ # ---------------------------------------------------------------------------
158
class SpeakerCache:
    """Small thread-safe LRU cache mapping speaker WAV bytes to latents.

    Keys are truncated SHA-256 digests of the raw WAV bytes.  Fixes over
    the original: eviction was pure FIFO (`get` never refreshed recency
    and re-`put` of an existing key silently dropped the update), so
    frequently used speakers could be evicted before cold ones.  Both
    accessors now refresh recency, and `put` updates the stored value.
    The default capacity resolves from settings at construction time
    instead of being frozen at class-definition time.
    """

    def __init__(self, maxsize: Optional[int] = None):
        self._cache: Dict[str, Tuple] = {}
        self._order: List[str] = []  # least-recently-used first
        self._maxsize = S.speaker_cache_size if maxsize is None else maxsize
        self._lock = threading.Lock()

    def _key(self, wav_bytes: bytes) -> str:
        # 16 hex chars (64 bits) is ample for a handful of speakers.
        return hashlib.sha256(wav_bytes).hexdigest()[:16]

    def get(self, wav_bytes: bytes) -> Optional[Tuple]:
        """Return cached latents for this WAV, refreshing its recency."""
        key = self._key(wav_bytes)
        with self._lock:
            val = self._cache.get(key)
            if val is not None:
                self._order.remove(key)
                self._order.append(key)
            return val

    def put(self, wav_bytes: bytes, latents: Tuple) -> None:
        """Store latents, evicting the least-recently-used entry if full."""
        key = self._key(wav_bytes)
        with self._lock:
            if key in self._cache:
                # Refresh both the stored value and the recency order.
                self._cache[key] = latents
                self._order.remove(key)
                self._order.append(key)
                return
            if len(self._order) >= self._maxsize:
                evict = self._order.pop(0)
                self._cache.pop(evict, None)
            self._cache[key] = latents
            self._order.append(key)
183
+
184
+
185
# Module-level singleton shared by all request handlers.
# NOTE(review): not consulted by _synthesize in this file — verify intended wiring.
_speaker_cache = SpeakerCache()
186
+
187
+ # ---------------------------------------------------------------------------
188
+ # Model manager (lazy, thread-safe)
189
+ # ---------------------------------------------------------------------------
190
# _model_lock guards one-time model construction (double-checked locking in
# _get_model); _infer_lock serializes synthesis calls on the shared model.
_model_lock = threading.Lock()
_infer_lock = threading.Lock()

# Lazily-created TTS instance plus a latched error message.  Once a load
# attempt fails, _tts_error stays set and _get_model never retries.
_tts_model = None
_tts_error: Optional[str] = None
195
+
196
+
197
def _get_model():
    """Lazily load the Coqui XTTS model (thread-safe, load-once).

    Returns the cached TTS instance, or None when a previous load attempt
    failed — the error is latched in _tts_error so loading is never retried.
    """
    global _tts_model, _tts_error
    # Lock-free fast path: latched failure or already loaded.
    if _tts_error is not None:
        return None
    if _tts_model is not None:
        return _tts_model

    with _model_lock:
        # Re-check under the lock (double-checked locking).
        if _tts_error is not None:
            return None
        if _tts_model is not None:
            return _tts_model

        try:
            from TTS.api import TTS
            print(f"[XTTS] Loading {S.model_name} ...")
            t0 = time.time()
            tts = TTS(model_name=S.model_name, progress_bar=False, gpu=False)

            # Optional optimizations: both mutate the inner torch module and
            # write it back into the synthesizer; fp16 is applied before
            # compile so the compiled graph sees the final dtype.
            inner = getattr(getattr(tts, "synthesizer", None), "tts_model", None)
            if isinstance(inner, torch.nn.Module):
                if S.use_fp16:
                    try:
                        inner = inner.half()
                        tts.synthesizer.tts_model = inner
                        print("[XTTS] FP16 enabled")
                    except Exception as e:
                        # Best-effort: fall back to fp32 on any failure.
                        print(f"[XTTS] FP16 failed: {e}")
                if S.torch_compile:
                    try:
                        inner = torch.compile(inner)
                        tts.synthesizer.tts_model = inner
                        print("[XTTS] torch.compile enabled")
                    except Exception as e:
                        print(f"[XTTS] torch.compile failed: {e}")

            _tts_model = tts
            print(f"[XTTS] Model loaded in {time.time() - t0:.1f}s")
        except Exception as e:
            # Latch the failure; subsequent callers get None immediately.
            _tts_error = str(e)
            print(f"[XTTS] FAILED to load: {e}")
            return None

    return _tts_model
 
 
 
242
 
 
 
 
243
 
244
def _get_sample_rate() -> int:
    """Return the synthesizer's output sample rate, defaulting to 22050 Hz."""
    model = _get_model()
    if model is None:
        return 22050
    synthesizer = getattr(model, "synthesizer", None)
    if not synthesizer:
        return 22050
    return getattr(synthesizer, "output_sample_rate", 22050)
 
 
 
 
 
 
 
 
 
 
 
250
 
 
 
 
251
 
252
+ # ---------------------------------------------------------------------------
253
+ # Core synthesis function
254
+ # ---------------------------------------------------------------------------
255
def _synthesize(text: str, language: str, speaker_wav_bytes: Optional[bytes] = None) -> Tuple[np.ndarray, int, float]:
    """Run one blocking synthesis call.

    Returns ``(audio_np, sample_rate, generation_time_s)``.  Raises
    HTTPException(503) when the model is unavailable.
    """
    model = _get_model()
    if model is None:
        raise HTTPException(503, f"XTTS unavailable: {_tts_error or 'model not loaded'}")

    started = time.time()

    # The shared model instance is protected by a single inference lock.
    with _infer_lock:
        clone_path = None
        try:
            reference = None
            if speaker_wav_bytes:
                # The TTS API expects a file path for the reference voice, so
                # the bytes are spilled to a throwaway temp file for the call.
                # (coqui-tts handles latent caching internally in >=0.27.)
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as handle:
                    handle.write(speaker_wav_bytes)
                    handle.flush()
                    clone_path = handle.name
                reference = clone_path

            raw = model.tts(
                text=text,
                language=language,
                speaker_wav=reference,
            )
        finally:
            if clone_path:
                try:
                    os.remove(clone_path)
                except OSError:
                    pass

    rate = _get_sample_rate()
    elapsed = time.time() - started
    return np.asarray(raw, dtype=np.float32), rate, elapsed
 
291
 
 
 
 
292
 
293
+ # ---------------------------------------------------------------------------
294
+ # Pydantic models
295
+ # ---------------------------------------------------------------------------
296
class SynthRequest(BaseModel):
    """Request body for POST /v1/xtts/synthesize."""
    # Hard length cap protects the CPU-only backend from runaway jobs.
    text: str = Field(..., min_length=1, max_length=5000)
    language: Optional[str] = Field(None, description="Language code (default: pl)")
    speaker_wav_b64: Optional[str] = Field(None, description="Base64-encoded WAV for voice cloning")
300
 
301
+
302
class StreamRequest(BaseModel):
    """Request body for POST /v1/xtts/stream — same fields as SynthRequest."""
    text: str = Field(..., min_length=1, max_length=5000)
    language: Optional[str] = None
    speaker_wav_b64: Optional[str] = None
306
+
307
 
308
class AudioResponse(BaseModel):
    """Synthesis result: one complete WAV file plus timing metadata."""
    # Base64-encoded WAV payload (16-bit PCM).
    audio_b64: str
    sample_rate: int
    # Audio length in seconds.
    duration_s: float
    # Wall-clock time spent generating, in seconds.
    generation_time_s: float
    # Original input text, echoed back.
    text: str
314
 
315
+
316
class HealthResponse(BaseModel):
    """Payload for GET /health."""
    # "ok", or "degraded: <error>" when the model failed to load.
    status: str = "ok"
    version: str = "2.0.0"
    model: str = S.model_name
    language: str = S.default_language
    # False once a model load attempt has failed (latched).
    xtts_available: bool = True
    speaker_cache_size: int = S.speaker_cache_size
323
+
324
 
325
+ # ---------------------------------------------------------------------------
326
  # FastAPI app
327
+ # ---------------------------------------------------------------------------
328
# Single application instance; endpoints register via decorators below.
app = FastAPI(title="Forge-TTS API", version="2.0.0")
329
+
330
 
331
@app.get("/health", response_model=HealthResponse)
def health():
    """Report service health; status degrades when the model failed to load."""
    model_ok = _tts_error is None
    if model_ok:
        state = "ok"
    else:
        state = f"degraded: {_tts_error}"
    return HealthResponse(status=state, xtts_available=model_ok)
338
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
 
340
@app.post("/v1/xtts/synthesize", response_model=AudioResponse)
def xtts_synthesize(req: SynthRequest):
    """Synthesize the whole text and return a single WAV (base64) response."""
    speaker_bytes = None
    if req.speaker_wav_b64:
        try:
            speaker_bytes = base64.b64decode(req.speaker_wav_b64)
        except Exception as e:
            raise HTTPException(400, f"Invalid base64 speaker_wav: {e}")

    lang = req.language or S.default_language
    chunks = split_text_into_chunks(req.text)
    if not chunks:
        raise HTTPException(400, "Empty text after normalization")

    rendered: List[np.ndarray] = []
    total_gen = 0.0
    sr = 22050  # overwritten by the first synthesis call
    for segment in chunks:
        audio, sr, took = _synthesize(segment, lang, speaker_bytes)
        rendered.append(audio)
        total_gen += took

    merged = concat_audio(rendered, sr)
    wav = wav_bytes_from_audio(merged, sr)

    return AudioResponse(
        audio_b64=b64encode_bytes(wav),
        sample_rate=sr,
        duration_s=round(len(merged) / sr, 3),
        generation_time_s=round(total_gen, 3),
        text=req.text,
    )
373
+
374
 
375
@app.post("/v1/xtts/stream")
async def xtts_stream(req: StreamRequest):
    """Stream synthesis chunk-by-chunk as Server-Sent Events (one WAV each)."""
    speaker_bytes = None
    if req.speaker_wav_b64:
        try:
            speaker_bytes = base64.b64decode(req.speaker_wav_b64)
        except Exception as e:
            raise HTTPException(400, f"Invalid base64: {e}")

    chunks = split_text_into_chunks(req.text)
    if not chunks:
        raise HTTPException(400, "Empty text after chunking")

    lang = req.language or S.default_language
    total = len(chunks)

    async def event_stream():
        for index, segment in enumerate(chunks):
            try:
                # Run the blocking synthesis off the event loop.
                audio, sr, took = await asyncio.to_thread(
                    _synthesize, segment, lang, speaker_bytes
                )
                payload = {
                    "chunk_index": index,
                    "total_chunks": total,
                    "text": segment,
                    "audio_b64": b64encode_bytes(wav_bytes_from_audio(audio, sr)),
                    "sample_rate": sr,
                    "generation_time_s": round(took, 3),
                }
                yield f"data: {json.dumps(payload)}\n\n"
            except Exception as e:
                # Emit the error as an SSE event and stop streaming.
                yield f"data: {json.dumps({'error': str(e), 'chunk_index': index})}\n\n"
                break
        yield "data: [DONE]\n\n"

    return StreamingResponse(event_stream(), media_type="text/event-stream")
412
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
413
 
414
@app.post("/v1/xtts/clone", response_model=AudioResponse)
async def xtts_clone(
    text: str = Form(..., min_length=1, max_length=5000),
    language: str = Form(default=S.default_language),
    speaker_wav: UploadFile = File(..., description="WAV file for voice cloning"),
):
    """Convenience endpoint: multipart form with WAV file upload (not base64)."""
    wav_bytes = await speaker_wav.read()
    # 44 bytes is the size of a minimal canonical WAV header.
    if len(wav_bytes) < 44:
        raise HTTPException(400, "WAV file too small or empty")
    if len(wav_bytes) > 10 * 1024 * 1024:
        raise HTTPException(400, "WAV file too large (max 10MB)")

    chunks = split_text_into_chunks(text)
    if not chunks:
        raise HTTPException(400, "Empty text after normalization")

    rendered: List[np.ndarray] = []
    elapsed_total = 0.0
    sr = 22050  # overwritten by the first synthesis call
    for segment in chunks:
        audio, sr, took = _synthesize(segment, language, wav_bytes)
        rendered.append(audio)
        elapsed_total += took

    merged = concat_audio(rendered, sr)
    wav_out = wav_bytes_from_audio(merged, sr)

    return AudioResponse(
        audio_b64=b64encode_bytes(wav_out),
        sample_rate=sr,
        duration_s=round(len(merged) / sr, 3),
        generation_time_s=round(elapsed_total, 3),
        text=text,
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450
 
 
451
 
452
+ # ---------------------------------------------------------------------------
453
+ # Startup
454
+ # ---------------------------------------------------------------------------
455
@app.on_event("startup")
async def startup_event():
    """Print the effective configuration and eagerly load the XTTS model."""
    print("\n" + "=" * 60)
    print("Forge-TTS v2.0.0 — XTTS-v2 Only")
    print("=" * 60)
    print(f"Model: {S.model_name}")
    print(f"Language: {S.default_language}")
    print(f"Threads: {S.num_threads}")
    print(f"FP16: {S.use_fp16}")
    print(f"Compile: {S.torch_compile}")
    print("=" * 60 + "\n")

    # Eager load to catch errors at startup
    _get_model()
469
+
 
 
 
 
470
 
471
  if __name__ == "__main__":
472
  import uvicorn