caarleexx commited on
Commit
cd55bb4
·
verified ·
1 Parent(s): 854c213

Update api/ltx_server_refactored.py

Browse files
Files changed (1) hide show
  1. api/ltx_server_refactored.py +268 -41
api/ltx_server_refactored.py CHANGED
@@ -214,6 +214,273 @@ GPU_CONFIG = {
214
 
215
 
216
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  class VideoService:
218
  def __init__(self):
219
  """Inicializa o serviço com 4 workers especializados."""
@@ -265,46 +532,6 @@ class VideoService:
265
  RESULTS_DIR.mkdir(exist_ok=True)
266
  print(f"[INFO] VideoService 4-Workers pronto. Tempo: {time.perf_counter()-t0:.2f}s")
267
 
268
- def _setup_4gpu_workers(self):
269
- """Configura 4 workers especializados sem deepcopy."""
270
- if self.multi_gpu_enabled:
271
- print("[INFO] Distribuindo modelos em 4 workers...")
272
-
273
- # Workers 0 e 1: Transformer + Text Encenger completos
274
- # Mover os modelos completos para cada GPU transformer
275
- for i, device in enumerate(self.transformer_devices):
276
- print(f"[INFO] Worker {i} (Transformer): {device}")
277
- # Para evitar deepcopy, movemos o modelo principal para a primeira GPU
278
- # e para as outras usamos o modelo já carregado mas movemos para a GPU
279
- if i == 0:
280
- self.pipeline.transformer.to(device)
281
- self.pipeline.text_encoder.to(device)
282
- #self.pipeline.patchifier.to(device)
283
- else:
284
- # Para GPUs adicionais, usamos o mesmo modelo mas movemos entre GPUs quando necessário
285
- # Na prática, vamos usar apenas uma GPU transformer por vez
286
- pass
287
-
288
- # Workers 2 e 3: Apenas VAE
289
- # Não usar deepcopy - vamos compartilhar o mesmo VAE entre GPUs
290
- for i, device in enumerate(self.vae_devices):
291
- print(f"[INFO] Worker {i+2} (VAE): {device}")
292
- # Movemos o VAE para a GPU quando for usar
293
- # Inicialmente fica na primeira GPU VAE
294
- if i == 0:
295
- self.pipeline.vae.to(device)
296
-
297
- # Upscaler - manter na primeira GPU VAE
298
- if self.latent_upsampler:
299
- self.latent_upsampler.to(self.vae_devices[0])
300
-
301
- print("[INFO] Distribuição 4-Workers concluída.")
302
- else:
303
- # Fallback para single GPU
304
- self.pipeline.to(self.device_ltx)
305
- if self.latent_upsampler:
306
- self.latent_upsampler.to(self.device_ltx)
307
-
308
  def _set_generation_environment(self):
309
  """Prepara o ambiente para geração (LTX pipeline)."""
310
  if not ENABLE_MEMORY_OPTIMIZATION:
@@ -594,7 +821,7 @@ class VideoService:
594
  print(f"[DEBUG] Cond shape={tuple(out.shape)} dtype={out.dtype} device={out.device}")
595
  return out
596
 
597
- def generate_low_resolution(
598
  self,
599
  prompt: str,
600
  negative_prompt: str,
 
214
 
215
 
216
 
217
+
218
+
219
+ # Adicione/modifique estas configurações no início do arquivo
220
+ PRECISION_CONFIG = {
221
+ "enable_fp8": False, # Desabilitar FP8 devido a problemas de compatibilidade
222
+ "default_dtype": torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16,
223
+ "fallback_dtype": torch.float16
224
+ }
225
+
226
+ # Modifique a classe VideoService
227
+ class VideoService:
228
+
229
+
230
+ def _get_safe_precision_dtype(self):
231
+ """Configuração de precisão mais segura para evitar conflitos de dtype."""
232
+ if not torch.cuda.is_available():
233
+ return torch.float32
234
+
235
+ # Verificar suporte a bfloat16
236
+ if torch.cuda.is_bf16_supported():
237
+ print("[INFO] Usando bfloat16 (suportado pela GPU)")
238
+ return torch.bfloat16
239
+ else:
240
+ print("[INFO] Usando float16 (bfloat16 não suportado)")
241
+ return torch.float16
242
+
243
+ def _load_models_with_safe_dtype(self):
244
+ """Carrega modelos com dtype seguro e verifica compatibilidade."""
245
+ print("[INFO] Carregando modelos com dtype seguro...")
246
+
247
+ # CORREÇÃO: Usar dtype seguro explicitamente
248
+ torch_dtype = self.runtime_autocast_dtype
249
+
250
+ try:
251
+ pipeline = LTXVideoPipeline.from_pretrained(
252
+ "Lightricks/LTX-Video",
253
+ torch_dtype=torch_dtype,
254
+ variant="fp8" if PRECISION_CONFIG["enable_fp8"] else None,
255
+ cache_dir=MODEL_CACHE_DIR
256
+ )
257
+ except Exception as e:
258
+ print(f"[WARNING] Erro ao carregar com {torch_dtype}: {e}")
259
+ print("[INFO] Tentando carregar com float16...")
260
+ torch_dtype = torch.float16
261
+ pipeline = LTXVideoPipeline.from_pretrained(
262
+ "Lightricks/LTX-Video",
263
+ torch_dtype=torch_dtype,
264
+ cache_dir=MODEL_CACHE_DIR
265
+ )
266
+
267
+ # CORREÇÃO: Verificar e ajustar dtypes dos componentes do modelo
268
+ self._ensure_consistent_dtypes(pipeline, torch_dtype)
269
+
270
+ # Carregar upscaler com o mesmo dtype
271
+ latent_upsampler = self._load_latent_upsampler(torch_dtype)
272
+
273
+ return pipeline, latent_upsampler
274
+
275
+ def _ensure_consistent_dtypes(self, pipeline, expected_dtype):
276
+ """Garante que todos os componentes do pipeline tenham dtypes consistentes."""
277
+ print("[INFO] Verificando consistência de dtypes...")
278
+
279
+ components = [
280
+ (pipeline.transformer, "transformer"),
281
+ (pipeline.text_encoder, "text_encoder"),
282
+ (pipeline.vae, "vae"),
283
+ (pipeline.patchifier, "patchifier")
284
+ ]
285
+
286
+ for component, name in components:
287
+ if hasattr(component, 'parameters') and next(component.parameters(), None) is not None:
288
+ actual_dtype = next(component.parameters()).dtype
289
+ if actual_dtype != expected_dtype:
290
+ print(f"[INFO] Convertendo {name} de {actual_dtype} para {expected_dtype}")
291
+ component.to(dtype=expected_dtype)
292
+
293
+ print("[INFO] Verificação de dtypes concluída.")
294
+
295
+ def _load_latent_upsampler(self, torch_dtype):
296
+ """Carrega o latent upscaler com dtype seguro."""
297
+ try:
298
+ from ltx_video.models import LatentUpscaler
299
+ upscaler = LatentUpscaler.from_pretrained(
300
+ "Lightricks/LTX-Video",
301
+ subfolder="ltxv-spatial-upscaler-0.9.8",
302
+ torch_dtype=torch_dtype,
303
+ cache_dir=MODEL_CACHE_DIR
304
+ )
305
+ return upscaler
306
+ except Exception as e:
307
+ print(f"[WARNING] Não foi possível carregar o latent upscaler: {e}")
308
+ return None
309
+
310
+ def _setup_4gpu_workers(self):
311
+ """Configura 4 workers com verificação de dtype."""
312
+ if self.multi_gpu_enabled:
313
+ print("[INFO] Distribuindo modelos em 4 workers...")
314
+
315
+ # Workers 0 e 1: Transformer + Text Encoder
316
+ for i, device in enumerate(self.transformer_devices):
317
+ print(f"[INFO] Worker {i} (Transformer): {device}")
318
+ if i == 0:
319
+ self.pipeline.transformer.to(device)
320
+ self.pipeline.text_encoder.to(device)
321
+ self.pipeline.patchifier.to(device)
322
+ # Nota: Para multi-worker transformer, precisaríamos de cópias do modelo
323
+
324
+ # Workers 2 e 3: VAE
325
+ for i, device in enumerate(self.vae_devices):
326
+ print(f"[INFO] Worker {i+2} (VAE): {device}")
327
+ if i == 0:
328
+ self.pipeline.vae.to(device)
329
+
330
+ # Upscaler
331
+ if self.latent_upsampler:
332
+ self.latent_upsampler.to(self.vae_devices[0])
333
+
334
+ print("[INFO] Distribuição 4-Workers concluída.")
335
+
336
+ # CORREÇÃO: Verificar dtypes após mover para GPU
337
+ self._verify_gpu_dtypes()
338
+ else:
339
+ self.pipeline.to(self.device_ltx)
340
+ if self.latent_upsampler:
341
+ self.latent_upsampler.to(self.device_ltx)
342
+
343
+ def _verify_gpu_dtypes(self):
344
+ """Verifica se os dtypes estão consistentes após mover para GPU."""
345
+ print("[INFO] Verificando dtypes nas GPUs...")
346
+
347
+ components = [
348
+ (self.pipeline.transformer, "transformer"),
349
+ (self.pipeline.vae, "vae")
350
+ ]
351
+
352
+ for component, name in components:
353
+ if hasattr(component, 'parameters') and next(component.parameters(), None) is not None:
354
+ param = next(component.parameters())
355
+ print(f" {name}: dtype={param.dtype}, device={param.device}")
356
+
357
+ print("[INFO] Verificação de GPU dtypes concluída.")
358
+
359
+ def generate_low_resolution(self, prompt: str, negative_prompt: str,
360
+ height: int, width: int, num_frames: int,
361
+ guidance_scale: float, seed: Optional[int] = None,
362
+ conditioning_items: Optional[List[ConditioningItem]] = None) -> Tuple[str, str, int]:
363
+ """Geração de baixa resolução com dtype seguro."""
364
+ print("\n[INFO] Iniciando ETAPA 1: Geração de Baixa Resolução...")
365
+ self._set_generation_environment()
366
+
367
+ temp_dir = tempfile.mkdtemp(prefix="ltxv_low_")
368
+ self._register_tmp_dir(temp_dir)
369
+ used_seed = random.randint(0, 2**32 - 1) if seed is None else int(seed)
370
+
371
+ # Determinar dispositivo
372
+ if self.multi_gpu_enabled:
373
+ device = self.transformer_devices[0]
374
+ else:
375
+ device = self.device_ltx
376
+
377
+ print(f" - Usando Seed: {used_seed}")
378
+ print(f" - Frames: {num_frames}, Duração: {num_frames/DEFAULT_FPS:.1f}s")
379
+ print(f" - Dimensões de Saída: {height}x{width}")
380
+ print(f" - Dispositivo: {device}, Dtype: {self.runtime_autocast_dtype}")
381
+
382
+ # CORREÇÃO: Configuração de autocast mais robusta
383
+ device_type = device.split(':')[0] if ':' in device else device
384
+ enabled_autocast = 'cuda' in device and self.runtime_autocast_dtype in [torch.float16, torch.bfloat16]
385
+
386
+ print(f" - Autocast habilitado: {enabled_autocast}")
387
+
388
+ try:
389
+ with torch.autocast(device_type=device_type,
390
+ dtype=self.runtime_autocast_dtype,
391
+ enabled=enabled_autocast):
392
+
393
+ first_pass_kwargs = {
394
+ "prompt": prompt, "negative_prompt": negative_prompt,
395
+ "height": height, "width": width,
396
+ "frame_rate": int(DEFAULT_FPS), "num_frames": num_frames,
397
+ "guidance_scale": float(guidance_scale),
398
+ "output_type": "latent",
399
+ "generator": torch.Generator(device=device).manual_seed(used_seed),
400
+ "conditioning_items": conditioning_items,
401
+ **(self.config.get("first_pass", {}))
402
+ }
403
+
404
+ print(" - Enviando para a pipeline LTX...")
405
+ latents = self.pipeline(**first_pass_kwargs).images
406
+ print(f" [LOG] Latentes gerados. Shape: {latents.shape}, Dtype: {latents.dtype}")
407
+
408
+ except RuntimeError as e:
409
+ print(f"[ERROR] Erro durante a geração: {e}")
410
+ print("[INFO] Tentando fallback para float32...")
411
+ # Fallback para float32
412
+ with torch.autocast(device_type=device_type, dtype=torch.float32, enabled=False):
413
+ first_pass_kwargs = {
414
+ "prompt": prompt, "negative_prompt": negative_prompt,
415
+ "height": height, "width": width,
416
+ "frame_rate": int(DEFAULT_FPS), "num_frames": num_frames,
417
+ "guidance_scale": float(guidance_scale),
418
+ "output_type": "latent",
419
+ "generator": torch.Generator(device=device).manual_seed(used_seed),
420
+ "conditioning_items": conditioning_items,
421
+ **(self.config.get("first_pass", {}))
422
+ }
423
+ latents = self.pipeline(**first_pass_kwargs).images
424
+
425
+ # Resto do método permanece igual...
426
+ latents_cpu = latents.cpu()
427
+ del latents
428
+ torch.cuda.empty_cache()
429
+
430
+ # ... (decodificação e salvamento)
431
+ latents_path = self._save_latents_to_disk(latents_cpu, "latents_low", used_seed)
432
+
433
+ print("\n[INFO] Decodificando vídeo de baixa resolução...")
434
+ self._set_decode_environment()
435
+
436
+ # Decodificação (similar ao código anterior)
437
+ total_latents = latents_cpu.shape[2]
438
+ pontos_de_corte, segment_sizes = self._calculate_dynamic_cuts(total_latents)
439
+
440
+ if len(pontos_de_corte) == 1:
441
+ vae_device = self.vae_devices[0] if self.multi_gpu_enabled else self.device_vae
442
+ latents_for_decode = latents_cpu.to(vae_device)
443
+ vae_manager = self._get_vae_manager(vae_device)
444
+
445
+ pixel_tensor = vae_manager.decode(
446
+ latents_for_decode,
447
+ decode_timestep=float(self.config.get("decode_timestep", 0.05))
448
+ ).cpu()
449
+ else:
450
+ print(f" [LOG] Decodificação em {len(pontos_de_corte)} chunks...")
451
+ pixel_chunks_list = []
452
+ for i, (start, end) in enumerate(pontos_de_corte):
453
+ start, end = max(0, start), min(total_latents, end)
454
+ if start >= end:
455
+ continue
456
+
457
+ latent_chunk = latents_cpu[:, :, start:end, :, :]
458
+ vae_device = self.vae_devices[0] if self.multi_gpu_enabled else self.device_vae
459
+ latent_chunk = latent_chunk.to(vae_device)
460
+ vae_manager = self._get_vae_manager(vae_device)
461
+
462
+ print(f" -> Decodificando Grupo {i+1} (latentes {start} a {end-1})")
463
+
464
+ pixel_chunk = vae_manager.decode(
465
+ latent_chunk,
466
+ decode_timestep=float(self.config.get("decode_timestep", 0.05))
467
+ )
468
+ pixel_chunks_list.append(pixel_chunk.cpu())
469
+ torch.cuda.empty_cache()
470
+
471
+ pixel_tensor = self._stitch_dynamic_chunks(pixel_chunks_list, segment_sizes)
472
+
473
+ video_path = self._save_video_from_tensor(pixel_tensor, "video_low", used_seed, temp_dir)
474
+ self._set_generation_environment()
475
+
476
+ del latents_cpu
477
+ self._finalize()
478
+
479
+ print("\n[SUCCESS] Geração de Baixa Resolução Concluída")
480
+ return video_path, latents_path, used_seed
481
+
482
+
483
+
484
  class VideoService:
485
  def __init__(self):
486
  """Inicializa o serviço com 4 workers especializados."""
 
532
  RESULTS_DIR.mkdir(exist_ok=True)
533
  print(f"[INFO] VideoService 4-Workers pronto. Tempo: {time.perf_counter()-t0:.2f}s")
534
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
535
  def _set_generation_environment(self):
536
  """Prepara o ambiente para geração (LTX pipeline)."""
537
  if not ENABLE_MEMORY_OPTIMIZATION:
 
821
  print(f"[DEBUG] Cond shape={tuple(out.shape)} dtype={out.dtype} device={out.device}")
822
  return out
823
 
824
+ def generate_low_resolution1(
825
  self,
826
  prompt: str,
827
  negative_prompt: str,