CherithCutestory committed on
Commit
447c73c
·
1 Parent(s): 7fd8b08

Added caching for voice clone conditionals

Browse files
Files changed (3) hide show
  1. app.py +85 -59
  2. index.html +10 -1
  3. requirements.txt +1 -0
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import os
 
2
  os.environ.setdefault("OMP_NUM_THREADS", "4")
3
 
 
4
  import io
5
  import base64
6
  import tempfile
@@ -9,6 +11,7 @@ import wave
9
  import numpy as np
10
  import torch
11
  import pyrubberband as pyrb
 
12
  from contextlib import asynccontextmanager
13
  from pathlib import Path
14
  from fastapi import FastAPI, Request, HTTPException
@@ -19,7 +22,8 @@ from typing import Optional
19
  logging.basicConfig(level=logging.INFO)
20
  logger = logging.getLogger("chatterbox-engine")
21
 
22
- BEARER_TOKEN = os.environ.get("API_KEY", "124CC717-7517-47A2-BBD6-54FCAE310297")
 
23
  SAMPLE_RATE = 24000
24
  BIT_DEPTH = 16
25
  CHANNELS = 1
@@ -112,12 +116,24 @@ EMOTION_PITCH_MAP = {
112
  }
113
 
114
  CANONICAL_EMOTIONS = [
115
- "neutral", "happy", "sad", "angry", "fear",
116
- "surprise", "disgust", "excited", "calm", "confused",
117
- "anxious", "hopeful", "melancholy", "fearful",
 
 
 
 
 
 
 
 
 
 
 
118
  ]
119
 
120
  tts_model = None
 
121
 
122
 
123
  def load_model():
@@ -172,8 +188,10 @@ def estimate_speech_duration(text: str) -> float:
172
  return max(1.0, base_seconds)
173
 
174
 
175
- def find_speech_end(audio_np: np.ndarray, sample_rate: int, threshold_db: float = SILENCE_THRESHOLD_DB) -> int:
176
- threshold_linear = 10.0 ** (threshold_db / 20.0)
 
 
177
 
178
  window_size = int(sample_rate * 0.02)
179
  abs_audio = np.abs(audio_np)
@@ -181,7 +199,7 @@ def find_speech_end(audio_np: np.ndarray, sample_rate: int, threshold_db: float
181
  i = len(abs_audio) - 1
182
  while i >= window_size:
183
  window = abs_audio[max(0, i - window_size):i]
184
- rms = np.sqrt(np.mean(window ** 2))
185
  if rms > threshold_linear:
186
  return i
187
  i -= window_size // 2
@@ -189,11 +207,13 @@ def find_speech_end(audio_np: np.ndarray, sample_rate: int, threshold_db: float
189
  return len(audio_np)
190
 
191
 
192
- def find_last_silence_gap(audio_np: np.ndarray, sample_rate: int,
193
- min_expected_samples: int,
194
- threshold_db: float = SILENCE_THRESHOLD_DB,
195
- min_gap_sec: float = MIN_SILENCE_DURATION_SEC) -> int:
196
- threshold_linear = 10.0 ** (threshold_db / 20.0)
 
 
197
  min_gap_samples = int(sample_rate * min_gap_sec)
198
  window_size = int(sample_rate * 0.02)
199
  abs_audio = np.abs(audio_np)
@@ -206,7 +226,7 @@ def find_last_silence_gap(audio_np: np.ndarray, sample_rate: int,
206
 
207
  while i >= search_start:
208
  window = abs_audio[max(0, i - window_size):i]
209
- rms = np.sqrt(np.mean(window ** 2))
210
  if rms <= threshold_linear:
211
  silent_run += window_size // 2
212
  if silent_run >= min_gap_samples:
@@ -221,22 +241,24 @@ def find_last_silence_gap(audio_np: np.ndarray, sample_rate: int,
221
  return best_gap_end
222
 
223
 
224
- def smart_trim_audio(audio_np: np.ndarray, sample_rate: int, text: str) -> np.ndarray:
 
225
  expected_sec = estimate_speech_duration(text)
226
  actual_sec = len(audio_np) / sample_rate
227
 
228
  logger.info(
229
  f"Audio trim: expected={expected_sec:.1f}s, actual={actual_sec:.1f}s, "
230
- f"samples={len(audio_np)}"
231
- )
232
 
233
  speech_end = find_speech_end(audio_np, sample_rate)
234
  speech_end_sec = speech_end / sample_rate
235
- logger.info(f"Speech end detected at {speech_end_sec:.2f}s (sample {speech_end})")
 
236
 
237
  if actual_sec > expected_sec * 1.5:
238
  min_expected_samples = int(expected_sec * 0.7 * sample_rate)
239
- gap_end = find_last_silence_gap(audio_np, sample_rate, min_expected_samples)
 
240
  gap_end_sec = gap_end / sample_rate
241
  logger.info(f"Last silence gap boundary at {gap_end_sec:.2f}s")
242
 
@@ -250,8 +272,7 @@ def smart_trim_audio(audio_np: np.ndarray, sample_rate: int, text: str) -> np.nd
250
  if trim_point < len(audio_np) * 0.3:
251
  logger.warning(
252
  f"Trim point ({trim_point / sample_rate:.2f}s) is less than 30% of audio, "
253
- f"keeping full audio to avoid cutting real speech"
254
- )
255
  trim_point = len(audio_np)
256
 
257
  if trim_point < len(audio_np):
@@ -263,10 +284,8 @@ def smart_trim_audio(audio_np: np.ndarray, sample_rate: int, text: str) -> np.nd
263
  tail_pad = np.zeros(int(sample_rate * TAIL_PAD_SEC), dtype=np.float32)
264
  result = np.concatenate([result, tail_pad])
265
 
266
- logger.info(
267
- f"Final audio: {len(result) / sample_rate:.2f}s "
268
- f"(trimmed from {actual_sec:.2f}s)"
269
- )
270
 
271
  return result
272
 
@@ -312,26 +331,27 @@ async def convert_text_to_speech(request: Request):
312
  body = await request.json()
313
  req = ConvertRequest(**body)
314
  except Exception as e:
315
- return JSONResponse(
316
- status_code=400,
317
- content={"error": str(e), "error_code": "INVALID_REQUEST"}
318
- )
 
319
 
320
  if not req.input_text.strip():
321
- return JSONResponse(
322
- status_code=400,
323
- content={"error": "Input text is empty", "error_code": "INVALID_REQUEST"}
324
- )
 
325
 
326
  if not req.voice_to_clone_sample:
327
  return JSONResponse(
328
  status_code=400,
329
  content={
330
  "error": "Chatterbox requires a voice sample for cloning. "
331
- "Please provide a voice_to_clone_sample.",
332
  "error_code": "CLONING_NOT_SUPPORTED"
333
- }
334
- )
335
 
336
  if req.random_seed is not None and req.random_seed > 0:
337
  torch.manual_seed(req.random_seed)
@@ -342,30 +362,40 @@ async def convert_text_to_speech(request: Request):
342
 
343
  try:
344
  try:
345
- wav_bytes = base64.b64decode(req.voice_to_clone_sample, validate=True)
 
346
  except Exception:
347
  return JSONResponse(
348
  status_code=400,
349
  content={
350
  "error": "Invalid voice_to_clone_sample: not valid base64",
351
  "error_code": "INVALID_REQUEST"
352
- }
353
- )
354
 
355
  if len(wav_bytes) < 44:
356
  return JSONResponse(
357
  status_code=400,
358
  content={
359
- "error": "Invalid voice_to_clone_sample: file too small to be valid audio",
 
360
  "error_code": "INVALID_REQUEST"
361
- }
362
- )
363
 
364
- tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
365
- tmp.write(wav_bytes)
366
- tmp.close()
367
- speaker_wav_path = tmp.name
368
- temp_files.append(tmp.name)
 
 
 
 
 
 
 
 
 
 
369
 
370
  text = req.input_text.strip()
371
  if len(text) > MAX_CHARS:
@@ -379,7 +409,8 @@ async def convert_text_to_speech(request: Request):
379
  if text and text[-1] not in '.!?;:':
380
  text += '.'
381
 
382
- dominant_emotion = req.emotion_set[0].lower() if req.emotion_set else "neutral"
 
383
  base_exaggeration = EMOTION_EXAGGERATION_MAP.get(dominant_emotion, 0.5)
384
  intensity_factor = req.intensity / 50.0
385
  exaggeration = min(1.0, max(0.0, base_exaggeration * intensity_factor))
@@ -398,12 +429,10 @@ async def convert_text_to_speech(request: Request):
398
  f"Generating with Chatterbox: emotion={dominant_emotion}, "
399
  f"exaggeration={exaggeration:.2f}, cfg={cfg_weight:.2f}, "
400
  f"temperature={temperature:.2f}, emotion_speed={emotion_speed:.3f}, "
401
- f"emotion_pitch={emotion_pitch:.2f}, text_len={len(text)}"
402
- )
403
 
404
  wav = tts_model.generate(
405
  text,
406
- audio_prompt_path=speaker_wav_path,
407
  exaggeration=exaggeration,
408
  temperature=temperature,
409
  cfg_weight=cfg_weight,
@@ -436,14 +465,12 @@ async def convert_text_to_speech(request: Request):
436
 
437
  except Exception as e:
438
  logger.exception("TTS generation failed")
439
- return JSONResponse(
440
- status_code=500,
441
- content={
442
- "error": "Audio generation failed",
443
- "error_code": "GENERATION_FAILED",
444
- "details": str(e)
445
- }
446
- )
447
  finally:
448
  for f in temp_files:
449
  try:
@@ -488,4 +515,3 @@ async def health():
488
  if __name__ == "__main__":
489
  import uvicorn
490
  uvicorn.run(app, host="0.0.0.0", port=7860)
491
-
 
1
  import os
2
+
3
  os.environ.setdefault("OMP_NUM_THREADS", "4")
4
 
5
+ import hashlib
6
  import io
7
  import base64
8
  import tempfile
 
11
  import numpy as np
12
  import torch
13
  import pyrubberband as pyrb
14
+ from cachetools import LRUCache
15
  from contextlib import asynccontextmanager
16
  from pathlib import Path
17
  from fastapi import FastAPI, Request, HTTPException
 
22
  logging.basicConfig(level=logging.INFO)
23
  logger = logging.getLogger("chatterbox-engine")
24
 
25
+ BEARER_TOKEN = os.environ.get("API_KEY", "")
26
+ VOICE_COND_CACHE_MAXSIZE = 20
27
  SAMPLE_RATE = 24000
28
  BIT_DEPTH = 16
29
  CHANNELS = 1
 
116
  }
117
 
118
  CANONICAL_EMOTIONS = [
119
+ "neutral",
120
+ "happy",
121
+ "sad",
122
+ "angry",
123
+ "fear",
124
+ "surprise",
125
+ "disgust",
126
+ "excited",
127
+ "calm",
128
+ "confused",
129
+ "anxious",
130
+ "hopeful",
131
+ "melancholy",
132
+ "fearful",
133
  ]
134
 
135
  tts_model = None
136
+ _voice_cond_cache: LRUCache = LRUCache(maxsize=VOICE_COND_CACHE_MAXSIZE)
137
 
138
 
139
  def load_model():
 
188
  return max(1.0, base_seconds)
189
 
190
 
191
+ def find_speech_end(audio_np: np.ndarray,
192
+ sample_rate: int,
193
+ threshold_db: float = SILENCE_THRESHOLD_DB) -> int:
194
+ threshold_linear = 10.0**(threshold_db / 20.0)
195
 
196
  window_size = int(sample_rate * 0.02)
197
  abs_audio = np.abs(audio_np)
 
199
  i = len(abs_audio) - 1
200
  while i >= window_size:
201
  window = abs_audio[max(0, i - window_size):i]
202
+ rms = np.sqrt(np.mean(window**2))
203
  if rms > threshold_linear:
204
  return i
205
  i -= window_size // 2
 
207
  return len(audio_np)
208
 
209
 
210
+ def find_last_silence_gap(
211
+ audio_np: np.ndarray,
212
+ sample_rate: int,
213
+ min_expected_samples: int,
214
+ threshold_db: float = SILENCE_THRESHOLD_DB,
215
+ min_gap_sec: float = MIN_SILENCE_DURATION_SEC) -> int:
216
+ threshold_linear = 10.0**(threshold_db / 20.0)
217
  min_gap_samples = int(sample_rate * min_gap_sec)
218
  window_size = int(sample_rate * 0.02)
219
  abs_audio = np.abs(audio_np)
 
226
 
227
  while i >= search_start:
228
  window = abs_audio[max(0, i - window_size):i]
229
+ rms = np.sqrt(np.mean(window**2))
230
  if rms <= threshold_linear:
231
  silent_run += window_size // 2
232
  if silent_run >= min_gap_samples:
 
241
  return best_gap_end
242
 
243
 
244
+ def smart_trim_audio(audio_np: np.ndarray, sample_rate: int,
245
+ text: str) -> np.ndarray:
246
  expected_sec = estimate_speech_duration(text)
247
  actual_sec = len(audio_np) / sample_rate
248
 
249
  logger.info(
250
  f"Audio trim: expected={expected_sec:.1f}s, actual={actual_sec:.1f}s, "
251
+ f"samples={len(audio_np)}")
 
252
 
253
  speech_end = find_speech_end(audio_np, sample_rate)
254
  speech_end_sec = speech_end / sample_rate
255
+ logger.info(
256
+ f"Speech end detected at {speech_end_sec:.2f}s (sample {speech_end})")
257
 
258
  if actual_sec > expected_sec * 1.5:
259
  min_expected_samples = int(expected_sec * 0.7 * sample_rate)
260
+ gap_end = find_last_silence_gap(audio_np, sample_rate,
261
+ min_expected_samples)
262
  gap_end_sec = gap_end / sample_rate
263
  logger.info(f"Last silence gap boundary at {gap_end_sec:.2f}s")
264
 
 
272
  if trim_point < len(audio_np) * 0.3:
273
  logger.warning(
274
  f"Trim point ({trim_point / sample_rate:.2f}s) is less than 30% of audio, "
275
+ f"keeping full audio to avoid cutting real speech")
 
276
  trim_point = len(audio_np)
277
 
278
  if trim_point < len(audio_np):
 
284
  tail_pad = np.zeros(int(sample_rate * TAIL_PAD_SEC), dtype=np.float32)
285
  result = np.concatenate([result, tail_pad])
286
 
287
+ logger.info(f"Final audio: {len(result) / sample_rate:.2f}s "
288
+ f"(trimmed from {actual_sec:.2f}s)")
 
 
289
 
290
  return result
291
 
 
331
  body = await request.json()
332
  req = ConvertRequest(**body)
333
  except Exception as e:
334
+ return JSONResponse(status_code=400,
335
+ content={
336
+ "error": str(e),
337
+ "error_code": "INVALID_REQUEST"
338
+ })
339
 
340
  if not req.input_text.strip():
341
+ return JSONResponse(status_code=400,
342
+ content={
343
+ "error": "Input text is empty",
344
+ "error_code": "INVALID_REQUEST"
345
+ })
346
 
347
  if not req.voice_to_clone_sample:
348
  return JSONResponse(
349
  status_code=400,
350
  content={
351
  "error": "Chatterbox requires a voice sample for cloning. "
352
+ "Please provide a voice_to_clone_sample.",
353
  "error_code": "CLONING_NOT_SUPPORTED"
354
+ })
 
355
 
356
  if req.random_seed is not None and req.random_seed > 0:
357
  torch.manual_seed(req.random_seed)
 
362
 
363
  try:
364
  try:
365
+ wav_bytes = base64.b64decode(req.voice_to_clone_sample,
366
+ validate=True)
367
  except Exception:
368
  return JSONResponse(
369
  status_code=400,
370
  content={
371
  "error": "Invalid voice_to_clone_sample: not valid base64",
372
  "error_code": "INVALID_REQUEST"
373
+ })
 
374
 
375
  if len(wav_bytes) < 44:
376
  return JSONResponse(
377
  status_code=400,
378
  content={
379
+ "error":
380
+ "Invalid voice_to_clone_sample: file too small to be valid audio",
381
  "error_code": "INVALID_REQUEST"
382
+ })
 
383
 
384
+ cache_key = hashlib.sha256(wav_bytes).hexdigest()
385
+ cached_conds = _voice_cond_cache.get(cache_key)
386
+
387
+ if cached_conds is not None:
388
+ logger.info(f"Voice conditioning cache hit ({cache_key[:8]}...), skipping prepare_conditionals")
389
+ tts_model.conds = cached_conds
390
+ else:
391
+ tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
392
+ tmp.write(wav_bytes)
393
+ tmp.close()
394
+ temp_files.append(tmp.name)
395
+ logger.info(f"Voice conditioning cache miss ({cache_key[:8]}...), running prepare_conditionals")
396
+ tts_model.prepare_conditionals(tmp.name)
397
+ _voice_cond_cache[cache_key] = tts_model.conds
398
+ logger.info(f"Voice conditionals cached (cache size: {len(_voice_cond_cache)}/{VOICE_COND_CACHE_MAXSIZE})")
399
 
400
  text = req.input_text.strip()
401
  if len(text) > MAX_CHARS:
 
409
  if text and text[-1] not in '.!?;:':
410
  text += '.'
411
 
412
+ dominant_emotion = req.emotion_set[0].lower(
413
+ ) if req.emotion_set else "neutral"
414
  base_exaggeration = EMOTION_EXAGGERATION_MAP.get(dominant_emotion, 0.5)
415
  intensity_factor = req.intensity / 50.0
416
  exaggeration = min(1.0, max(0.0, base_exaggeration * intensity_factor))
 
429
  f"Generating with Chatterbox: emotion={dominant_emotion}, "
430
  f"exaggeration={exaggeration:.2f}, cfg={cfg_weight:.2f}, "
431
  f"temperature={temperature:.2f}, emotion_speed={emotion_speed:.3f}, "
432
+ f"emotion_pitch={emotion_pitch:.2f}, text_len={len(text)}")
 
433
 
434
  wav = tts_model.generate(
435
  text,
 
436
  exaggeration=exaggeration,
437
  temperature=temperature,
438
  cfg_weight=cfg_weight,
 
465
 
466
  except Exception as e:
467
  logger.exception("TTS generation failed")
468
+ return JSONResponse(status_code=500,
469
+ content={
470
+ "error": "Audio generation failed",
471
+ "error_code": "GENERATION_FAILED",
472
+ "details": str(e)
473
+ })
 
 
474
  finally:
475
  for f in temp_files:
476
  try:
 
515
  if __name__ == "__main__":
516
  import uvicorn
517
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
index.html CHANGED
@@ -281,6 +281,12 @@
281
  </div>
282
  </div>
283
 
 
 
 
 
 
 
284
  <button class="generate" id="generateBtn" onclick="generate()">Generate Speech</button>
285
 
286
  <div class="result-area hidden" id="resultArea">
@@ -399,9 +405,12 @@
399
  };
400
 
401
  try {
 
 
 
402
  const resp = await fetch("/ConvertTextToSpeech", {
403
  method: "POST",
404
- headers: { "Content-Type": "application/json" },
405
  body: JSON.stringify(payload),
406
  });
407
 
 
281
  </div>
282
  </div>
283
 
284
+ <div class="card">
285
+ <div class="card-title">Authentication</div>
286
+ <label for="apiKey">API Key (if set on server)</label>
287
+ <input type="text" id="apiKey" placeholder="Leave empty if no auth required">
288
+ </div>
289
+
290
  <button class="generate" id="generateBtn" onclick="generate()">Generate Speech</button>
291
 
292
  <div class="result-area hidden" id="resultArea">
 
405
  };
406
 
407
  try {
408
+ const hdrs = { "Content-Type": "application/json" };
409
+ const apiKey = document.getElementById("apiKey").value.trim();
410
+ if (apiKey) hdrs["Authorization"] = "Bearer " + apiKey;
411
  const resp = await fetch("/ConvertTextToSpeech", {
412
  method: "POST",
413
+ headers: hdrs,
414
  body: JSON.stringify(payload),
415
  });
416
 
requirements.txt CHANGED
@@ -7,3 +7,4 @@ numpy
7
  pydantic>=2.0.0
8
  pyrubberband>=0.3.0
9
  soundfile>=0.12.0
 
 
7
  pydantic>=2.0.0
8
  pyrubberband>=0.3.0
9
  soundfile>=0.12.0
10
+ cachetools>=5.0.0