LucaCappelletti94 commited on
Commit
2886be7
ยท
1 Parent(s): 19933fe

Add auto device detection with CPU fallback

Browse files
src/talking_snake/__main__.py CHANGED
@@ -45,9 +45,9 @@ def main() -> int:
45
  parser.add_argument(
46
  "--device",
47
  type=str,
48
- default="cuda",
49
- choices=["cuda", "cpu"],
50
- help="Device to run the TTS model on (default: cuda)",
51
  )
52
  parser.add_argument(
53
  "--reload",
@@ -57,10 +57,26 @@ def main() -> int:
57
 
58
  args = parser.parse_args()
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  print("๐Ÿš€ Starting Reader server...")
61
  print(f" Language: {args.language}")
62
  print(f" Voice: {args.voice or 'auto'}")
63
- print(f" Device: {args.device}")
64
  print(f" URL: http://{args.host}:{args.port}")
65
  print()
66
 
@@ -76,7 +92,7 @@ def main() -> int:
76
  tts_engine = QwenTTSEngine(
77
  voice=args.voice,
78
  language=args.language,
79
- device=args.device,
80
  )
81
  except Exception as e:
82
  print(f"โŒ Failed to load TTS model: {e}", file=sys.stderr)
 
45
  parser.add_argument(
46
  "--device",
47
  type=str,
48
+ default="auto",
49
+ choices=["auto", "cuda", "cpu"],
50
+ help="Device to run the TTS model on (default: auto, detects GPU)",
51
  )
52
  parser.add_argument(
53
  "--reload",
 
57
 
58
  args = parser.parse_args()
59
 
60
+ # Auto-detect device if set to 'auto'
61
+ device = args.device
62
+ if device == "auto":
63
+ try:
64
+ import torch
65
+
66
+ if torch.cuda.is_available():
67
+ device = "cuda"
68
+ print("๐ŸŽฎ GPU detected, using CUDA")
69
+ else:
70
+ device = "cpu"
71
+ print("๐Ÿ’ป No GPU detected, using CPU (slower but works!)")
72
+ except ImportError:
73
+ device = "cpu"
74
+ print("๐Ÿ’ป PyTorch not available for detection, using CPU")
75
+
76
  print("๐Ÿš€ Starting Reader server...")
77
  print(f" Language: {args.language}")
78
  print(f" Voice: {args.voice or 'auto'}")
79
+ print(f" Device: {device}")
80
  print(f" URL: http://{args.host}:{args.port}")
81
  print()
82
 
 
92
  tts_engine = QwenTTSEngine(
93
  voice=args.voice,
94
  language=args.language,
95
+ device=device,
96
  )
97
  except Exception as e:
98
  print(f"โŒ Failed to load TTS model: {e}", file=sys.stderr)
src/talking_snake/tts.py CHANGED
@@ -3,6 +3,8 @@
3
  from __future__ import annotations
4
 
5
  import io
 
 
6
  import wave
7
  from abc import ABC, abstractmethod
8
  from collections.abc import Iterator
@@ -66,9 +68,13 @@ LANGUAGE_VOICES: dict[str, str] = {
66
  # 1200 chars provides good balance for natural speech flow
67
  DEFAULT_CHUNK_SIZE = 1200
68
 
 
 
 
 
69
 
70
  class QwenTTSEngine(TTSEngineProtocol):
71
- """TTS engine using Qwen3-TTS model."""
72
 
73
  # Available voices for CustomVoice model:
74
  # Chinese: Vivian, Serena, Uncle_Fu, Dylan (Beijing), Eric (Sichuan)
@@ -94,6 +100,7 @@ class QwenTTSEngine(TTSEngineProtocol):
94
  device: str = "cuda",
95
  chunk_size: int = DEFAULT_CHUNK_SIZE,
96
  model_name: str = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
 
97
  ) -> None:
98
  """Initialize the TTS engine.
99
 
@@ -114,7 +121,6 @@ class QwenTTSEngine(TTSEngineProtocol):
114
  import warnings
115
 
116
  import torch
117
- from qwen_tts import Qwen3TTSModel
118
 
119
  # Suppress the pad_token_id warning from transformers
120
  logging.getLogger("transformers.generation.utils").setLevel(logging.ERROR)
@@ -126,33 +132,122 @@ class QwenTTSEngine(TTSEngineProtocol):
126
  self.chunk_size = chunk_size
127
  self._sample_rate = 24000
128
  self._batch_size = 1 # Will be calculated after model loads
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
- # Determine dtype based on device
131
- dtype = torch.bfloat16 if device == "cuda" else torch.float32
 
132
 
133
- # Try to use flash attention on CUDA
134
- attn_impl = "flash_attention_2" if device == "cuda" else "eager"
 
 
 
 
 
135
 
136
  try:
137
  self.model = Qwen3TTSModel.from_pretrained(
138
- model_name,
139
- device_map=device,
140
- dtype=dtype,
141
- attn_implementation=attn_impl,
142
  )
143
  except Exception:
144
  # Fallback without flash attention
145
  self.model = Qwen3TTSModel.from_pretrained(
146
- model_name,
147
- device_map=device,
148
- dtype=dtype,
149
  )
150
 
 
 
151
  # Calculate optimal batch size based on available VRAM
152
- if device == "cuda":
153
  self._batch_size = self._calculate_batch_size()
154
  print(f" Batch size: {self._batch_size} (based on available VRAM)")
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  def _calculate_batch_size(self) -> int:
157
  """Calculate optimal batch size based on available GPU memory.
158
 
@@ -206,46 +301,56 @@ class QwenTTSEngine(TTSEngineProtocol):
206
  if not text.strip():
207
  return
208
 
209
- # Split text into chunks for streaming
210
- chunks = self._split_text(text)
211
-
212
- # First chunk includes WAV header
213
- first_chunk = True
214
-
215
- # Process chunks in batches for GPU efficiency
216
- batch_size = self._batch_size
217
 
218
- for i in range(0, len(chunks), batch_size):
219
- batch = chunks[i : i + batch_size]
220
 
221
- # Filter empty chunks
222
- batch = [c for c in batch if c.strip()]
223
- if not batch:
224
- continue
225
-
226
- # Always use batched call for consistent GPU memory allocation
227
- # Use professional narration style for clear, authoritative delivery
228
- batch_instruct = (
229
- [PROFESSIONAL_STYLE] * len(batch) if len(batch) > 1 else PROFESSIONAL_STYLE
230
- )
231
- audios, sr = self.model.generate_custom_voice(
232
- text=batch if len(batch) > 1 else batch[0],
233
- speaker=[self.voice] * len(batch) if len(batch) > 1 else self.voice,
234
- instruct=batch_instruct,
235
- # Use lower temperature for more stable, consistent voice
236
- temperature=0.7,
237
- repetition_penalty=1.1,
238
- )
239
-
240
- # Ensure audios is a list for consistent iteration
241
- if len(batch) == 1:
242
- audios = [audios]
243
-
244
- # Yield each audio chunk in order
245
- for audio in audios:
246
- wav_bytes = self._audio_to_wav(audio, sr, include_header=first_chunk)
247
- first_chunk = False
248
- yield wav_bytes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
 
250
  def _split_text(self, text: str, max_chars: int | None = None) -> list[str]:
251
  """Split text into chunks suitable for TTS.
 
3
  from __future__ import annotations
4
 
5
  import io
6
+ import threading
7
+ import time
8
  import wave
9
  from abc import ABC, abstractmethod
10
  from collections.abc import Iterator
 
68
  # 1200 chars provides good balance for natural speech flow
69
  DEFAULT_CHUNK_SIZE = 1200
70
 
71
+ # Idle timeout before unloading model from GPU (seconds)
72
+ # Set to 0 to disable auto-unloading
73
+ IDLE_TIMEOUT = 300 # 5 minutes
74
+
75
 
76
  class QwenTTSEngine(TTSEngineProtocol):
77
+ """TTS engine using Qwen3-TTS model with automatic GPU memory management."""
78
 
79
  # Available voices for CustomVoice model:
80
  # Chinese: Vivian, Serena, Uncle_Fu, Dylan (Beijing), Eric (Sichuan)
 
100
  device: str = "cuda",
101
  chunk_size: int = DEFAULT_CHUNK_SIZE,
102
  model_name: str = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
103
+ idle_timeout: int = IDLE_TIMEOUT,
104
  ) -> None:
105
  """Initialize the TTS engine.
106
 
 
121
  import warnings
122
 
123
  import torch
 
124
 
125
  # Suppress the pad_token_id warning from transformers
126
  logging.getLogger("transformers.generation.utils").setLevel(logging.ERROR)
 
132
  self.chunk_size = chunk_size
133
  self._sample_rate = 24000
134
  self._batch_size = 1 # Will be calculated after model loads
135
+ self._model_name = model_name
136
+ self._dtype = torch.bfloat16 if device == "cuda" else torch.float32
137
+ self._attn_impl = "flash_attention_2" if device == "cuda" else "eager"
138
+
139
+ # Idle timeout management
140
+ self._idle_timeout = idle_timeout
141
+ self._last_activity = time.time()
142
+ self._model_loaded = False
143
+ self._lock = threading.Lock()
144
+ self._unload_timer: threading.Timer | None = None
145
+
146
+ # Model will be loaded on first request (lazy loading)
147
+ self.model = None
148
+
149
+ # Load model immediately if no idle timeout (always keep loaded)
150
+ if idle_timeout == 0:
151
+ self._load_model()
152
+
153
+ def _load_model(self) -> None:
154
+ """Load the model onto GPU or CPU."""
155
+ if self._model_loaded:
156
+ return
157
+
158
+ import torch
159
+ from qwen_tts import Qwen3TTSModel
160
 
161
+ device_name = "GPU" if self.device == "cuda" else "CPU"
162
+ print(f"๐Ÿ”„ Loading TTS model onto {device_name}...")
163
+ start = time.time()
164
 
165
+ # Check if CUDA is actually available when requested
166
+ if self.device == "cuda" and not torch.cuda.is_available():
167
+ print("โš ๏ธ CUDA requested but not available, falling back to CPU")
168
+ self.device = "cpu"
169
+ self._dtype = torch.float32
170
+ self._attn_impl = "eager"
171
+ device_name = "CPU"
172
 
173
  try:
174
  self.model = Qwen3TTSModel.from_pretrained(
175
+ self._model_name,
176
+ device_map=self.device,
177
+ dtype=self._dtype,
178
+ attn_implementation=self._attn_impl,
179
  )
180
  except Exception:
181
  # Fallback without flash attention
182
  self.model = Qwen3TTSModel.from_pretrained(
183
+ self._model_name,
184
+ device_map=self.device,
185
+ dtype=self._dtype,
186
  )
187
 
188
+ self._model_loaded = True
189
+
190
  # Calculate optimal batch size based on available VRAM
191
+ if self.device == "cuda":
192
  self._batch_size = self._calculate_batch_size()
193
  print(f" Batch size: {self._batch_size} (based on available VRAM)")
194
 
195
+ elapsed = time.time() - start
196
+ print(f"โœ… Model loaded in {elapsed:.1f}s")
197
+
198
+ def _unload_model(self) -> None:
199
+ """Unload the model from GPU to free memory."""
200
+ with self._lock:
201
+ if not self._model_loaded or self.model is None:
202
+ return
203
+
204
+ import gc
205
+
206
+ import torch
207
+
208
+ print("๐Ÿ’ค Unloading TTS model from GPU (idle timeout)...")
209
+
210
+ # Delete model and clear references
211
+ del self.model
212
+ self.model = None
213
+ self._model_loaded = False
214
+
215
+ # Force garbage collection and clear CUDA cache
216
+ gc.collect()
217
+ if torch.cuda.is_available():
218
+ torch.cuda.empty_cache()
219
+ torch.cuda.synchronize()
220
+
221
+ print("โœ… GPU memory freed")
222
+
223
+ def _schedule_unload(self) -> None:
224
+ """Schedule model unload after idle timeout."""
225
+ if self._idle_timeout <= 0:
226
+ return
227
+
228
+ # Cancel existing timer
229
+ if self._unload_timer is not None:
230
+ self._unload_timer.cancel()
231
+
232
+ # Schedule new unload
233
+ self._unload_timer = threading.Timer(self._idle_timeout, self._unload_model)
234
+ self._unload_timer.daemon = True
235
+ self._unload_timer.start()
236
+
237
+ def _ensure_model_loaded(self) -> None:
238
+ """Ensure model is loaded before use."""
239
+ with self._lock:
240
+ self._last_activity = time.time()
241
+
242
+ # Cancel any pending unload
243
+ if self._unload_timer is not None:
244
+ self._unload_timer.cancel()
245
+ self._unload_timer = None
246
+
247
+ # Load model if not loaded
248
+ if not self._model_loaded:
249
+ self._load_model()
250
+
251
  def _calculate_batch_size(self) -> int:
252
  """Calculate optimal batch size based on available GPU memory.
253
 
 
301
  if not text.strip():
302
  return
303
 
304
+ # Ensure model is loaded (lazy loading with idle timeout)
305
+ self._ensure_model_loaded()
 
 
 
 
 
 
306
 
307
+ # Type guard - model is guaranteed to be loaded after _ensure_model_loaded
308
+ assert self.model is not None, "Model failed to load"
309
 
310
+ try:
311
+ # Split text into chunks for streaming
312
+ chunks = self._split_text(text)
313
+
314
+ # First chunk includes WAV header
315
+ first_chunk = True
316
+
317
+ # Process chunks in batches for GPU efficiency
318
+ batch_size = self._batch_size
319
+
320
+ for i in range(0, len(chunks), batch_size):
321
+ batch = chunks[i : i + batch_size]
322
+
323
+ # Filter empty chunks
324
+ batch = [c for c in batch if c.strip()]
325
+ if not batch:
326
+ continue
327
+
328
+ # Always use batched call for consistent GPU memory allocation
329
+ # Use professional narration style for clear, authoritative delivery
330
+ batch_instruct = (
331
+ [PROFESSIONAL_STYLE] * len(batch) if len(batch) > 1 else PROFESSIONAL_STYLE
332
+ )
333
+ audios, sr = self.model.generate_custom_voice(
334
+ text=batch if len(batch) > 1 else batch[0],
335
+ speaker=[self.voice] * len(batch) if len(batch) > 1 else self.voice,
336
+ instruct=batch_instruct,
337
+ # Use lower temperature for more stable, consistent voice
338
+ temperature=0.7,
339
+ repetition_penalty=1.1,
340
+ )
341
+
342
+ # Ensure audios is a list for consistent iteration
343
+ if len(batch) == 1:
344
+ audios = [audios]
345
+
346
+ # Yield each audio chunk in order
347
+ for audio in audios:
348
+ wav_bytes = self._audio_to_wav(audio, sr, include_header=first_chunk)
349
+ first_chunk = False
350
+ yield wav_bytes
351
+ finally:
352
+ # Schedule model unload after idle timeout
353
+ self._schedule_unload()
354
 
355
  def _split_text(self, text: str, max_chars: int | None = None) -> list[str]:
356
  """Split text into chunks suitable for TTS.