Spaces:
Sleeping
Sleeping
Commit ยท
2886be7
1
Parent(s): 19933fe
Add auto device detection with CPU fallback
Browse files- src/talking_snake/__main__.py +21 -5
- src/talking_snake/tts.py +157 -52
src/talking_snake/__main__.py
CHANGED
|
@@ -45,9 +45,9 @@ def main() -> int:
|
|
| 45 |
parser.add_argument(
|
| 46 |
"--device",
|
| 47 |
type=str,
|
| 48 |
-
default="
|
| 49 |
-
choices=["cuda", "cpu"],
|
| 50 |
-
help="Device to run the TTS model on (default:
|
| 51 |
)
|
| 52 |
parser.add_argument(
|
| 53 |
"--reload",
|
|
@@ -57,10 +57,26 @@ def main() -> int:
|
|
| 57 |
|
| 58 |
args = parser.parse_args()
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
print("๐ Starting Reader server...")
|
| 61 |
print(f" Language: {args.language}")
|
| 62 |
print(f" Voice: {args.voice or 'auto'}")
|
| 63 |
-
print(f" Device: {
|
| 64 |
print(f" URL: http://{args.host}:{args.port}")
|
| 65 |
print()
|
| 66 |
|
|
@@ -76,7 +92,7 @@ def main() -> int:
|
|
| 76 |
tts_engine = QwenTTSEngine(
|
| 77 |
voice=args.voice,
|
| 78 |
language=args.language,
|
| 79 |
-
device=
|
| 80 |
)
|
| 81 |
except Exception as e:
|
| 82 |
print(f"โ Failed to load TTS model: {e}", file=sys.stderr)
|
|
|
|
| 45 |
parser.add_argument(
|
| 46 |
"--device",
|
| 47 |
type=str,
|
| 48 |
+
default="auto",
|
| 49 |
+
choices=["auto", "cuda", "cpu"],
|
| 50 |
+
help="Device to run the TTS model on (default: auto, detects GPU)",
|
| 51 |
)
|
| 52 |
parser.add_argument(
|
| 53 |
"--reload",
|
|
|
|
| 57 |
|
| 58 |
args = parser.parse_args()
|
| 59 |
|
| 60 |
+
# Auto-detect device if set to 'auto'
|
| 61 |
+
device = args.device
|
| 62 |
+
if device == "auto":
|
| 63 |
+
try:
|
| 64 |
+
import torch
|
| 65 |
+
|
| 66 |
+
if torch.cuda.is_available():
|
| 67 |
+
device = "cuda"
|
| 68 |
+
print("๐ฎ GPU detected, using CUDA")
|
| 69 |
+
else:
|
| 70 |
+
device = "cpu"
|
| 71 |
+
print("๐ป No GPU detected, using CPU (slower but works!)")
|
| 72 |
+
except ImportError:
|
| 73 |
+
device = "cpu"
|
| 74 |
+
print("๐ป PyTorch not available for detection, using CPU")
|
| 75 |
+
|
| 76 |
print("๐ Starting Reader server...")
|
| 77 |
print(f" Language: {args.language}")
|
| 78 |
print(f" Voice: {args.voice or 'auto'}")
|
| 79 |
+
print(f" Device: {device}")
|
| 80 |
print(f" URL: http://{args.host}:{args.port}")
|
| 81 |
print()
|
| 82 |
|
|
|
|
| 92 |
tts_engine = QwenTTSEngine(
|
| 93 |
voice=args.voice,
|
| 94 |
language=args.language,
|
| 95 |
+
device=device,
|
| 96 |
)
|
| 97 |
except Exception as e:
|
| 98 |
print(f"โ Failed to load TTS model: {e}", file=sys.stderr)
|
src/talking_snake/tts.py
CHANGED
|
@@ -3,6 +3,8 @@
|
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import io
|
|
|
|
|
|
|
| 6 |
import wave
|
| 7 |
from abc import ABC, abstractmethod
|
| 8 |
from collections.abc import Iterator
|
|
@@ -66,9 +68,13 @@ LANGUAGE_VOICES: dict[str, str] = {
|
|
| 66 |
# 1200 chars provides good balance for natural speech flow
|
| 67 |
DEFAULT_CHUNK_SIZE = 1200
|
| 68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
class QwenTTSEngine(TTSEngineProtocol):
|
| 71 |
-
"""TTS engine using Qwen3-TTS model."""
|
| 72 |
|
| 73 |
# Available voices for CustomVoice model:
|
| 74 |
# Chinese: Vivian, Serena, Uncle_Fu, Dylan (Beijing), Eric (Sichuan)
|
|
@@ -94,6 +100,7 @@ class QwenTTSEngine(TTSEngineProtocol):
|
|
| 94 |
device: str = "cuda",
|
| 95 |
chunk_size: int = DEFAULT_CHUNK_SIZE,
|
| 96 |
model_name: str = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
|
|
|
|
| 97 |
) -> None:
|
| 98 |
"""Initialize the TTS engine.
|
| 99 |
|
|
@@ -114,7 +121,6 @@ class QwenTTSEngine(TTSEngineProtocol):
|
|
| 114 |
import warnings
|
| 115 |
|
| 116 |
import torch
|
| 117 |
-
from qwen_tts import Qwen3TTSModel
|
| 118 |
|
| 119 |
# Suppress the pad_token_id warning from transformers
|
| 120 |
logging.getLogger("transformers.generation.utils").setLevel(logging.ERROR)
|
|
@@ -126,33 +132,122 @@ class QwenTTSEngine(TTSEngineProtocol):
|
|
| 126 |
self.chunk_size = chunk_size
|
| 127 |
self._sample_rate = 24000
|
| 128 |
self._batch_size = 1 # Will be calculated after model loads
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
-
|
| 131 |
-
|
|
|
|
| 132 |
|
| 133 |
-
#
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
try:
|
| 137 |
self.model = Qwen3TTSModel.from_pretrained(
|
| 138 |
-
|
| 139 |
-
device_map=device,
|
| 140 |
-
dtype=
|
| 141 |
-
attn_implementation=
|
| 142 |
)
|
| 143 |
except Exception:
|
| 144 |
# Fallback without flash attention
|
| 145 |
self.model = Qwen3TTSModel.from_pretrained(
|
| 146 |
-
|
| 147 |
-
device_map=device,
|
| 148 |
-
dtype=
|
| 149 |
)
|
| 150 |
|
|
|
|
|
|
|
| 151 |
# Calculate optimal batch size based on available VRAM
|
| 152 |
-
if device == "cuda":
|
| 153 |
self._batch_size = self._calculate_batch_size()
|
| 154 |
print(f" Batch size: {self._batch_size} (based on available VRAM)")
|
| 155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
def _calculate_batch_size(self) -> int:
|
| 157 |
"""Calculate optimal batch size based on available GPU memory.
|
| 158 |
|
|
@@ -206,46 +301,56 @@ class QwenTTSEngine(TTSEngineProtocol):
|
|
| 206 |
if not text.strip():
|
| 207 |
return
|
| 208 |
|
| 209 |
-
#
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
# First chunk includes WAV header
|
| 213 |
-
first_chunk = True
|
| 214 |
-
|
| 215 |
-
# Process chunks in batches for GPU efficiency
|
| 216 |
-
batch_size = self._batch_size
|
| 217 |
|
| 218 |
-
|
| 219 |
-
|
| 220 |
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
|
| 250 |
def _split_text(self, text: str, max_chars: int | None = None) -> list[str]:
|
| 251 |
"""Split text into chunks suitable for TTS.
|
|
|
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import io
|
| 6 |
+
import threading
|
| 7 |
+
import time
|
| 8 |
import wave
|
| 9 |
from abc import ABC, abstractmethod
|
| 10 |
from collections.abc import Iterator
|
|
|
|
| 68 |
# 1200 chars provides good balance for natural speech flow
|
| 69 |
DEFAULT_CHUNK_SIZE = 1200
|
| 70 |
|
| 71 |
+
# Idle timeout before unloading model from GPU (seconds)
|
| 72 |
+
# Set to 0 to disable auto-unloading
|
| 73 |
+
IDLE_TIMEOUT = 300 # 5 minutes
|
| 74 |
+
|
| 75 |
|
| 76 |
class QwenTTSEngine(TTSEngineProtocol):
|
| 77 |
+
"""TTS engine using Qwen3-TTS model with automatic GPU memory management."""
|
| 78 |
|
| 79 |
# Available voices for CustomVoice model:
|
| 80 |
# Chinese: Vivian, Serena, Uncle_Fu, Dylan (Beijing), Eric (Sichuan)
|
|
|
|
| 100 |
device: str = "cuda",
|
| 101 |
chunk_size: int = DEFAULT_CHUNK_SIZE,
|
| 102 |
model_name: str = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
|
| 103 |
+
idle_timeout: int = IDLE_TIMEOUT,
|
| 104 |
) -> None:
|
| 105 |
"""Initialize the TTS engine.
|
| 106 |
|
|
|
|
| 121 |
import warnings
|
| 122 |
|
| 123 |
import torch
|
|
|
|
| 124 |
|
| 125 |
# Suppress the pad_token_id warning from transformers
|
| 126 |
logging.getLogger("transformers.generation.utils").setLevel(logging.ERROR)
|
|
|
|
| 132 |
self.chunk_size = chunk_size
|
| 133 |
self._sample_rate = 24000
|
| 134 |
self._batch_size = 1 # Will be calculated after model loads
|
| 135 |
+
self._model_name = model_name
|
| 136 |
+
self._dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
| 137 |
+
self._attn_impl = "flash_attention_2" if device == "cuda" else "eager"
|
| 138 |
+
|
| 139 |
+
# Idle timeout management
|
| 140 |
+
self._idle_timeout = idle_timeout
|
| 141 |
+
self._last_activity = time.time()
|
| 142 |
+
self._model_loaded = False
|
| 143 |
+
self._lock = threading.Lock()
|
| 144 |
+
self._unload_timer: threading.Timer | None = None
|
| 145 |
+
|
| 146 |
+
# Model will be loaded on first request (lazy loading)
|
| 147 |
+
self.model = None
|
| 148 |
+
|
| 149 |
+
# Load model immediately if no idle timeout (always keep loaded)
|
| 150 |
+
if idle_timeout == 0:
|
| 151 |
+
self._load_model()
|
| 152 |
+
|
| 153 |
+
def _load_model(self) -> None:
|
| 154 |
+
"""Load the model onto GPU or CPU."""
|
| 155 |
+
if self._model_loaded:
|
| 156 |
+
return
|
| 157 |
+
|
| 158 |
+
import torch
|
| 159 |
+
from qwen_tts import Qwen3TTSModel
|
| 160 |
|
| 161 |
+
device_name = "GPU" if self.device == "cuda" else "CPU"
|
| 162 |
+
print(f"๐ Loading TTS model onto {device_name}...")
|
| 163 |
+
start = time.time()
|
| 164 |
|
| 165 |
+
# Check if CUDA is actually available when requested
|
| 166 |
+
if self.device == "cuda" and not torch.cuda.is_available():
|
| 167 |
+
print("โ ๏ธ CUDA requested but not available, falling back to CPU")
|
| 168 |
+
self.device = "cpu"
|
| 169 |
+
self._dtype = torch.float32
|
| 170 |
+
self._attn_impl = "eager"
|
| 171 |
+
device_name = "CPU"
|
| 172 |
|
| 173 |
try:
|
| 174 |
self.model = Qwen3TTSModel.from_pretrained(
|
| 175 |
+
self._model_name,
|
| 176 |
+
device_map=self.device,
|
| 177 |
+
dtype=self._dtype,
|
| 178 |
+
attn_implementation=self._attn_impl,
|
| 179 |
)
|
| 180 |
except Exception:
|
| 181 |
# Fallback without flash attention
|
| 182 |
self.model = Qwen3TTSModel.from_pretrained(
|
| 183 |
+
self._model_name,
|
| 184 |
+
device_map=self.device,
|
| 185 |
+
dtype=self._dtype,
|
| 186 |
)
|
| 187 |
|
| 188 |
+
self._model_loaded = True
|
| 189 |
+
|
| 190 |
# Calculate optimal batch size based on available VRAM
|
| 191 |
+
if self.device == "cuda":
|
| 192 |
self._batch_size = self._calculate_batch_size()
|
| 193 |
print(f" Batch size: {self._batch_size} (based on available VRAM)")
|
| 194 |
|
| 195 |
+
elapsed = time.time() - start
|
| 196 |
+
print(f"โ
Model loaded in {elapsed:.1f}s")
|
| 197 |
+
|
| 198 |
+
def _unload_model(self) -> None:
|
| 199 |
+
"""Unload the model from GPU to free memory."""
|
| 200 |
+
with self._lock:
|
| 201 |
+
if not self._model_loaded or self.model is None:
|
| 202 |
+
return
|
| 203 |
+
|
| 204 |
+
import gc
|
| 205 |
+
|
| 206 |
+
import torch
|
| 207 |
+
|
| 208 |
+
print("๐ค Unloading TTS model from GPU (idle timeout)...")
|
| 209 |
+
|
| 210 |
+
# Delete model and clear references
|
| 211 |
+
del self.model
|
| 212 |
+
self.model = None
|
| 213 |
+
self._model_loaded = False
|
| 214 |
+
|
| 215 |
+
# Force garbage collection and clear CUDA cache
|
| 216 |
+
gc.collect()
|
| 217 |
+
if torch.cuda.is_available():
|
| 218 |
+
torch.cuda.empty_cache()
|
| 219 |
+
torch.cuda.synchronize()
|
| 220 |
+
|
| 221 |
+
print("โ
GPU memory freed")
|
| 222 |
+
|
| 223 |
+
def _schedule_unload(self) -> None:
|
| 224 |
+
"""Schedule model unload after idle timeout."""
|
| 225 |
+
if self._idle_timeout <= 0:
|
| 226 |
+
return
|
| 227 |
+
|
| 228 |
+
# Cancel existing timer
|
| 229 |
+
if self._unload_timer is not None:
|
| 230 |
+
self._unload_timer.cancel()
|
| 231 |
+
|
| 232 |
+
# Schedule new unload
|
| 233 |
+
self._unload_timer = threading.Timer(self._idle_timeout, self._unload_model)
|
| 234 |
+
self._unload_timer.daemon = True
|
| 235 |
+
self._unload_timer.start()
|
| 236 |
+
|
| 237 |
+
def _ensure_model_loaded(self) -> None:
|
| 238 |
+
"""Ensure model is loaded before use."""
|
| 239 |
+
with self._lock:
|
| 240 |
+
self._last_activity = time.time()
|
| 241 |
+
|
| 242 |
+
# Cancel any pending unload
|
| 243 |
+
if self._unload_timer is not None:
|
| 244 |
+
self._unload_timer.cancel()
|
| 245 |
+
self._unload_timer = None
|
| 246 |
+
|
| 247 |
+
# Load model if not loaded
|
| 248 |
+
if not self._model_loaded:
|
| 249 |
+
self._load_model()
|
| 250 |
+
|
| 251 |
def _calculate_batch_size(self) -> int:
|
| 252 |
"""Calculate optimal batch size based on available GPU memory.
|
| 253 |
|
|
|
|
| 301 |
if not text.strip():
|
| 302 |
return
|
| 303 |
|
| 304 |
+
# Ensure model is loaded (lazy loading with idle timeout)
|
| 305 |
+
self._ensure_model_loaded()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
|
| 307 |
+
# Type guard - model is guaranteed to be loaded after _ensure_model_loaded
|
| 308 |
+
assert self.model is not None, "Model failed to load"
|
| 309 |
|
| 310 |
+
try:
|
| 311 |
+
# Split text into chunks for streaming
|
| 312 |
+
chunks = self._split_text(text)
|
| 313 |
+
|
| 314 |
+
# First chunk includes WAV header
|
| 315 |
+
first_chunk = True
|
| 316 |
+
|
| 317 |
+
# Process chunks in batches for GPU efficiency
|
| 318 |
+
batch_size = self._batch_size
|
| 319 |
+
|
| 320 |
+
for i in range(0, len(chunks), batch_size):
|
| 321 |
+
batch = chunks[i : i + batch_size]
|
| 322 |
+
|
| 323 |
+
# Filter empty chunks
|
| 324 |
+
batch = [c for c in batch if c.strip()]
|
| 325 |
+
if not batch:
|
| 326 |
+
continue
|
| 327 |
+
|
| 328 |
+
# Always use batched call for consistent GPU memory allocation
|
| 329 |
+
# Use professional narration style for clear, authoritative delivery
|
| 330 |
+
batch_instruct = (
|
| 331 |
+
[PROFESSIONAL_STYLE] * len(batch) if len(batch) > 1 else PROFESSIONAL_STYLE
|
| 332 |
+
)
|
| 333 |
+
audios, sr = self.model.generate_custom_voice(
|
| 334 |
+
text=batch if len(batch) > 1 else batch[0],
|
| 335 |
+
speaker=[self.voice] * len(batch) if len(batch) > 1 else self.voice,
|
| 336 |
+
instruct=batch_instruct,
|
| 337 |
+
# Use lower temperature for more stable, consistent voice
|
| 338 |
+
temperature=0.7,
|
| 339 |
+
repetition_penalty=1.1,
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
# Ensure audios is a list for consistent iteration
|
| 343 |
+
if len(batch) == 1:
|
| 344 |
+
audios = [audios]
|
| 345 |
+
|
| 346 |
+
# Yield each audio chunk in order
|
| 347 |
+
for audio in audios:
|
| 348 |
+
wav_bytes = self._audio_to_wav(audio, sr, include_header=first_chunk)
|
| 349 |
+
first_chunk = False
|
| 350 |
+
yield wav_bytes
|
| 351 |
+
finally:
|
| 352 |
+
# Schedule model unload after idle timeout
|
| 353 |
+
self._schedule_unload()
|
| 354 |
|
| 355 |
def _split_text(self, text: str, max_chars: int | None = None) -> list[str]:
|
| 356 |
"""Split text into chunks suitable for TTS.
|