jkorstad commited on
Commit
5df26db
·
1 Parent(s): 95170f7

Robust GPU detection: TTSEngine checks torch.cuda.is_available() at model load time, with CPU-safe float32 fallback.

Browse files
Files changed (1) hide show
  1. backend.py +12 -9
backend.py CHANGED
@@ -420,11 +420,12 @@ class TTSEngine:
420
  try:
421
  from qwen_tts import Qwen3TTSModel
422
  import torch
423
- print("[TTS] Loading CustomVoice model...")
 
424
  self._custom_voice_model = Qwen3TTSModel.from_pretrained(
425
  self._model_ids["custom"],
426
- device_map=self.device,
427
- dtype=torch.bfloat16,
428
  )
429
  print("[TTS] CustomVoice ready.")
430
  except Exception as e:
@@ -438,11 +439,12 @@ class TTSEngine:
438
  try:
439
  from qwen_tts import Qwen3TTSModel
440
  import torch
441
- print("[TTS] Loading Base (clone) model...")
 
442
  self._base_model = Qwen3TTSModel.from_pretrained(
443
  self._model_ids["base"],
444
- device_map=self.device,
445
- dtype=torch.bfloat16,
446
  )
447
  print("[TTS] Base ready.")
448
  except Exception as e:
@@ -456,11 +458,12 @@ class TTSEngine:
456
  try:
457
  from qwen_tts import Qwen3TTSModel
458
  import torch
459
- print("[TTS] Loading VoiceDesign model...")
 
460
  self._design_model = Qwen3TTSModel.from_pretrained(
461
  self._model_ids["design"],
462
- device_map=self.device,
463
- dtype=torch.bfloat16,
464
  )
465
  print("[TTS] VoiceDesign ready.")
466
  except Exception as e:
 
420
  try:
421
  from qwen_tts import Qwen3TTSModel
422
  import torch
423
+ device = "cuda" if torch.cuda.is_available() else "cpu"
424
+ print(f"[TTS] Loading CustomVoice model on {device}...")
425
  self._custom_voice_model = Qwen3TTSModel.from_pretrained(
426
  self._model_ids["custom"],
427
+ device_map=device,
428
+ dtype=torch.bfloat16 if device == "cuda" else torch.float32,
429
  )
430
  print("[TTS] CustomVoice ready.")
431
  except Exception as e:
 
439
  try:
440
  from qwen_tts import Qwen3TTSModel
441
  import torch
442
+ device = "cuda" if torch.cuda.is_available() else "cpu"
443
+ print(f"[TTS] Loading Base (clone) model on {device}...")
444
  self._base_model = Qwen3TTSModel.from_pretrained(
445
  self._model_ids["base"],
446
+ device_map=device,
447
+ dtype=torch.bfloat16 if device == "cuda" else torch.float32,
448
  )
449
  print("[TTS] Base ready.")
450
  except Exception as e:
 
458
  try:
459
  from qwen_tts import Qwen3TTSModel
460
  import torch
461
+ device = "cuda" if torch.cuda.is_available() else "cpu"
462
+ print(f"[TTS] Loading VoiceDesign model on {device}...")
463
  self._design_model = Qwen3TTSModel.from_pretrained(
464
  self._model_ids["design"],
465
+ device_map=device,
466
+ dtype=torch.bfloat16 if device == "cuda" else torch.float32,
467
  )
468
  print("[TTS] VoiceDesign ready.")
469
  except Exception as e: