CDOM201 committed on
Commit
e00ded0
·
verified ·
1 Parent(s): f1f26fc

Mat1 and Mat2 fixes: use FP16 autocast during inference to avoid the "mat1 and mat2 must have the same dtype" error

Browse files
Files changed (1) hide show
  1. main.py +10 -5
main.py CHANGED
@@ -28,11 +28,9 @@ print("Loading TTS model...")
28
  tts_model = ChatterboxMultilingualTTS.from_pretrained(device=device_map)
29
 
30
  # Optimize for T4 GPU using half-precision (FP16)
31
- # FP16 provides a significant speed boost with negligible quality loss
32
  if device_map == "cuda":
33
- tts_model.t3.to(torch.float16)
34
- tts_model.s3gen.to(torch.float16)
35
- tts_model.ve.to(torch.float16)
36
 
37
  print("Model loaded.")
38
 
@@ -58,7 +56,14 @@ def generate_audio(req: TTSRequest) -> str:
58
  filename = os.path.join("outputs", f"{req.channelID}-{req.username}-{req.messageid}.wav")
59
 
60
  try:
61
- audio_tensor = tts_model.generate(req.message, language_id=req.language)
 
 
 
 
 
 
 
62
  ta.save(filename, audio_tensor, tts_model.sr)
63
  return filename
64
  except Exception as e:
 
28
  tts_model = ChatterboxMultilingualTTS.from_pretrained(device=device_map)
29
 
30
  # Optimize for T4 GPU using half-precision (FP16)
31
+ # We use autocast during inference for the best balance of speed and stability
32
  if device_map == "cuda":
33
+ print("GPU optimization: FP16 Autocast enabled.")
 
 
34
 
35
  print("Model loaded.")
36
 
 
56
  filename = os.path.join("outputs", f"{req.channelID}-{req.username}-{req.messageid}.wav")
57
 
58
  try:
59
+ # Use autocast to automatically handle float16/float32 mixing
60
+ # This prevents the "mat1 and mat2 must have the same dtype" error
61
+ if device_map == "cuda":
62
+ with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
63
+ audio_tensor = tts_model.generate(req.message, language_id=req.language)
64
+ else:
65
+ audio_tensor = tts_model.generate(req.message, language_id=req.language)
66
+
67
  ta.save(filename, audio_tensor, tts_model.sr)
68
  return filename
69
  except Exception as e: