CDOM201 commited on
Commit
15a3086
·
verified ·
1 Parent(s): e8ceee6

Upload 2 files

Browse files
Files changed (2) hide show
  1. download_model.py +1 -1
  2. main.py +18 -3
download_model.py CHANGED
@@ -1,7 +1,7 @@
1
  from huggingface_hub import snapshot_download
2
 
3
  print("Downloading model weights (cache only)...")
4
- # The correct repository for chatterbox-tts
5
  snapshot_download(
6
  repo_id="ResembleAI/chatterbox",
7
  allow_patterns=["ve.pt", "t3_mtl23ls_v2.safetensors", "s3gen.pt", "grapheme_mtl_merged_expanded_v1.json", "conds.pt", "Cangjie5_TC.json"]
 
1
  from huggingface_hub import snapshot_download
2
 
3
  print("Downloading model weights (cache only)...")
4
+ # The correct repository for chatterbox-tts (Multilingual)
5
  snapshot_download(
6
  repo_id="ResembleAI/chatterbox",
7
  allow_patterns=["ve.pt", "t3_mtl23ls_v2.safetensors", "s3gen.pt", "grapheme_mtl_merged_expanded_v1.json", "conds.pt", "Cangjie5_TC.json"]
main.py CHANGED
@@ -7,6 +7,7 @@ from pydantic import BaseModel
7
  from chatterbox.mtl_tts import ChatterboxMultilingualTTS
8
  import functools
9
  import uvicorn
 
10
 
11
  # Patch torch.load for CPU if necessary (as in app.py)
12
  # torch.load = functools.partial(torch.load, map_location='cpu')
@@ -16,11 +17,21 @@ app = FastAPI()
16
  # 1. Determine device dynamically
17
  device_map = "cuda" if torch.cuda.is_available() else "cpu"
18
 
 
 
 
19
  print(f"CUDA Available: {torch.cuda.is_available()}")
20
  print(f"Using device: {device_map} with name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
21
 
22
  print("Loading TTS model...")
 
23
  tts_model = ChatterboxMultilingualTTS.from_pretrained(device=device_map)
 
 
 
 
 
 
24
  print("Model loaded.")
25
 
26
  class TTSRequest(BaseModel):
@@ -43,6 +54,7 @@ def generate_audio(req: TTSRequest) -> str:
43
  """Generates audio and returns the filename."""
44
  os.makedirs("outputs", exist_ok=True)
45
  filename = os.path.join("outputs", f"{req.channelID}-{req.username}-{req.messageid}.wav")
 
46
  try:
47
  audio_tensor = tts_model.generate(req.message, language_id=req.language)
48
  ta.save(filename, audio_tensor, tts_model.sr)
@@ -52,20 +64,23 @@ def generate_audio(req: TTSRequest) -> str:
52
 
53
  @app.post("/tts")
54
  async def tts_endpoint(req: TTSRequest, background_tasks: BackgroundTasks):
55
- filename = generate_audio(req)
 
56
  background_tasks.add_task(cleanup_file, filename)
57
  return FileResponse(path=filename, filename=filename, media_type='audio/wav')
58
 
59
  @app.post("/stream")
60
  async def stream_endpoint(req: TTSRequest, background_tasks: BackgroundTasks):
61
- filename = generate_audio(req)
 
62
  background_tasks.add_task(cleanup_file, filename)
63
  # FileResponse handles streaming efficiently for large files
64
  return FileResponse(path=filename, media_type='audio/wav')
65
 
66
  @app.post("/test")
67
  async def test_endpoint(req: TTSRequest):
68
- filename = generate_audio(req)
 
69
  # For /test, we don't delete the file and just return "ok"
70
  return {"status": "ok", "filename": filename}
71
 
 
7
  from chatterbox.mtl_tts import ChatterboxMultilingualTTS
8
  import functools
9
  import uvicorn
10
+ import asyncio
11
 
12
  # Patch torch.load for CPU if necessary (as in app.py)
13
  # torch.load = functools.partial(torch.load, map_location='cpu')
 
17
  # 1. Determine device dynamically
18
  device_map = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
+ # Create a lock to ensure only one generation happens at a time (important for GPU)
21
+ model_lock = asyncio.Lock()
22
+
23
  print(f"CUDA Available: {torch.cuda.is_available()}")
24
  print(f"Using device: {device_map} with name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
25
 
26
  print("Loading TTS model...")
27
+ # Using Multilingual model as requested
28
  tts_model = ChatterboxMultilingualTTS.from_pretrained(device=device_map)
29
+
30
+ # Optimize for T4 GPU using half-precision (FP16)
31
+ # FP16 provides a significant speed boost with negligible quality loss
32
+ if device_map == "cuda":
33
+ tts_model.to(torch.float16)
34
+
35
  print("Model loaded.")
36
 
37
  class TTSRequest(BaseModel):
 
54
  """Generates audio and returns the filename."""
55
  os.makedirs("outputs", exist_ok=True)
56
  filename = os.path.join("outputs", f"{req.channelID}-{req.username}-{req.messageid}.wav")
57
+
58
  try:
59
  audio_tensor = tts_model.generate(req.message, language_id=req.language)
60
  ta.save(filename, audio_tensor, tts_model.sr)
 
64
 
65
  @app.post("/tts")
66
  async def tts_endpoint(req: TTSRequest, background_tasks: BackgroundTasks):
67
+ async with model_lock:
68
+ filename = await asyncio.to_thread(generate_audio, req)
69
  background_tasks.add_task(cleanup_file, filename)
70
  return FileResponse(path=filename, filename=filename, media_type='audio/wav')
71
 
72
  @app.post("/stream")
73
  async def stream_endpoint(req: TTSRequest, background_tasks: BackgroundTasks):
74
+ async with model_lock:
75
+ filename = await asyncio.to_thread(generate_audio, req)
76
  background_tasks.add_task(cleanup_file, filename)
77
  # FileResponse handles streaming efficiently for large files
78
  return FileResponse(path=filename, media_type='audio/wav')
79
 
80
  @app.post("/test")
81
  async def test_endpoint(req: TTSRequest):
82
+ async with model_lock:
83
+ filename = await asyncio.to_thread(generate_audio, req)
84
  # For /test, we don't delete the file and just return "ok"
85
  return {"status": "ok", "filename": filename}
86