Robust GPU detection: TTSEngine checks torch.cuda.is_available() at model load time, with CPU-safe float32 fallback.
Browse files- backend.py +12 -9
backend.py
CHANGED
|
@@ -420,11 +420,12 @@ class TTSEngine:
|
|
| 420 |
try:
|
| 421 |
from qwen_tts import Qwen3TTSModel
|
| 422 |
import torch
|
| 423 |
-
|
|
|
|
| 424 |
self._custom_voice_model = Qwen3TTSModel.from_pretrained(
|
| 425 |
self._model_ids["custom"],
|
| 426 |
-
device_map=
|
| 427 |
-
dtype=torch.bfloat16,
|
| 428 |
)
|
| 429 |
print("[TTS] CustomVoice ready.")
|
| 430 |
except Exception as e:
|
|
@@ -438,11 +439,12 @@ class TTSEngine:
|
|
| 438 |
try:
|
| 439 |
from qwen_tts import Qwen3TTSModel
|
| 440 |
import torch
|
| 441 |
-
|
|
|
|
| 442 |
self._base_model = Qwen3TTSModel.from_pretrained(
|
| 443 |
self._model_ids["base"],
|
| 444 |
-
device_map=
|
| 445 |
-
dtype=torch.bfloat16,
|
| 446 |
)
|
| 447 |
print("[TTS] Base ready.")
|
| 448 |
except Exception as e:
|
|
@@ -456,11 +458,12 @@ class TTSEngine:
|
|
| 456 |
try:
|
| 457 |
from qwen_tts import Qwen3TTSModel
|
| 458 |
import torch
|
| 459 |
-
|
|
|
|
| 460 |
self._design_model = Qwen3TTSModel.from_pretrained(
|
| 461 |
self._model_ids["design"],
|
| 462 |
-
device_map=
|
| 463 |
-
dtype=torch.bfloat16,
|
| 464 |
)
|
| 465 |
print("[TTS] VoiceDesign ready.")
|
| 466 |
except Exception as e:
|
|
|
|
| 420 |
try:
|
| 421 |
from qwen_tts import Qwen3TTSModel
|
| 422 |
import torch
|
| 423 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 424 |
+
print(f"[TTS] Loading CustomVoice model on {device}...")
|
| 425 |
self._custom_voice_model = Qwen3TTSModel.from_pretrained(
|
| 426 |
self._model_ids["custom"],
|
| 427 |
+
device_map=device,
|
| 428 |
+
dtype=torch.bfloat16 if device == "cuda" else torch.float32,
|
| 429 |
)
|
| 430 |
print("[TTS] CustomVoice ready.")
|
| 431 |
except Exception as e:
|
|
|
|
| 439 |
try:
|
| 440 |
from qwen_tts import Qwen3TTSModel
|
| 441 |
import torch
|
| 442 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 443 |
+
print(f"[TTS] Loading Base (clone) model on {device}...")
|
| 444 |
self._base_model = Qwen3TTSModel.from_pretrained(
|
| 445 |
self._model_ids["base"],
|
| 446 |
+
device_map=device,
|
| 447 |
+
dtype=torch.bfloat16 if device == "cuda" else torch.float32,
|
| 448 |
)
|
| 449 |
print("[TTS] Base ready.")
|
| 450 |
except Exception as e:
|
|
|
|
| 458 |
try:
|
| 459 |
from qwen_tts import Qwen3TTSModel
|
| 460 |
import torch
|
| 461 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 462 |
+
print(f"[TTS] Loading VoiceDesign model on {device}...")
|
| 463 |
self._design_model = Qwen3TTSModel.from_pretrained(
|
| 464 |
self._model_ids["design"],
|
| 465 |
+
device_map=device,
|
| 466 |
+
dtype=torch.bfloat16 if device == "cuda" else torch.float32,
|
| 467 |
)
|
| 468 |
print("[TTS] VoiceDesign ready.")
|
| 469 |
except Exception as e:
|