eeuuia committed on
Commit
0d5ddf4
·
verified ·
1 Parent(s): 2869224

Update api/ltx/ltx_utils.py

Browse files
Files changed (1) hide show
  1. api/ltx/ltx_utils.py +34 -27
api/ltx/ltx_utils.py CHANGED
@@ -1,7 +1,6 @@
1
  # FILE: api/ltx/ltx_utils.py
2
  # DESCRIPTION: A pure utility library for the LTX ecosystem.
3
- # REFACTORED to contain only the official, low-level builder function for core components
4
- # and other stateless helper functions.
5
 
6
  import os
7
  import random
@@ -13,7 +12,7 @@ from typing import Dict, Tuple
13
 
14
  import torch
15
  from safetensors import safe_open
16
- from transformers import T5EncoderModel, T5Tokenizer, AutoModelForCausalLM, AutoProcessor, AutoTokenizer
17
 
18
  # ==============================================================================
19
  # --- CONFIGURAÇÃO DE PATH E IMPORTS DA BIBLIOTECA LTX ---
@@ -33,64 +32,73 @@ add_deps_to_path()
33
  try:
34
  from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
35
  from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
36
- from ltx_video.models.transformers.transformer3d import create_transformer
37
  from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
38
  from ltx_video.schedulers.rf import RectifiedFlowScheduler
39
  except ImportError as e:
 
40
  raise ImportError(f"Could not import from LTX-Video library. Check repo integrity at '{LTX_VIDEO_REPO_DIR}'. Error: {e}")
41
 
42
  # ==============================================================================
43
- # --- BUILDER DE BAIXO NÍVEL OFICIAL ---
44
- # (Esta é a única função de construção, usada pelo LTXAducManager)
45
  # ==============================================================================
46
 
47
- def build_components_on_cpu(checkpoint_path: str, config: Dict) -> Tuple[LTXVideoPipeline, CausalVideoAutoencoder]:
 
 
 
48
  """
49
- Constrói o pipeline LTX principal (sem VAE) e o modelo VAE separadamente,
50
- mantendo ambos os componentes na CPU. Esta é a função de construção fundamental
51
- usada pelo Manager antes de distribuir os modelos para as GPUs.
 
 
 
 
 
 
 
 
 
52
 
53
- Args:
54
- checkpoint_path (str): Caminho absoluto para o arquivo de checkpoint principal.
55
- config (Dict): O dicionário de configuração carregado do arquivo YAML.
56
 
57
- Returns:
58
- Tuple[LTXVideoPipeline, CausalVideoAutoencoder]: Uma tupla contendo o pipeline principal
59
- e o modelo VAE, ambos na CPU.
 
60
  """
61
  logging.info(f"Building LTX components from checkpoint: {Path(checkpoint_path).name}")
62
 
63
  with safe_open(checkpoint_path, framework="pt") as f:
64
  metadata = f.metadata() or {}
65
  config_str = metadata.get("config", "{}")
66
- configs = json.loads(config_str)
67
- allowed_inference_steps = configs.get("allowed_inference_steps")
68
 
69
- # --- Construir componentes na CPU ---
70
  precision = config.get("precision", "bfloat16")
 
 
71
  transformer = create_transformer(checkpoint_path, precision).to("cpu")
 
72
  scheduler = RectifiedFlowScheduler.from_pretrained(checkpoint_path)
73
  text_encoder = T5EncoderModel.from_pretrained(config["text_encoder_model_name_or_path"], subfolder="text_encoder").to("cpu")
74
  tokenizer = T5Tokenizer.from_pretrained(config["text_encoder_model_name_or_path"], subfolder="tokenizer")
75
  patchifier = SymmetricPatchifier(patch_size=1)
76
-
77
- # Construir o VAE separadamente, também na CPU
78
  vae = CausalVideoAutoencoder.from_pretrained(checkpoint_path).to("cpu")
79
 
80
- # Aplicar precisão bfloat16 se configurado
81
  if precision == "bfloat16":
82
- transformer.to(torch.bfloat16)
83
  text_encoder.to(torch.bfloat16)
84
  vae.to(torch.bfloat16)
85
 
86
- # Montar o pipeline principal, passando 'vae=None' para garantir o desacoplamento
87
  pipeline = LTXVideoPipeline(
88
  transformer=transformer,
89
  patchifier=patchifier,
90
  text_encoder=text_encoder,
91
  tokenizer=tokenizer,
92
  scheduler=scheduler,
93
- vae=None, # VAE é explicitamente desacoplado do pipeline principal
94
  allowed_inference_steps=allowed_inference_steps,
95
  prompt_enhancer_image_caption_model=None,
96
  prompt_enhancer_image_caption_processor=None,
@@ -106,8 +114,7 @@ def build_components_on_cpu(checkpoint_path: str, config: Dict) -> Tuple[LTXVide
106
 
107
  def seed_everything(seed: int):
108
  """
109
- Define a semente para PyTorch, NumPy e Python para garantir reprodutibilidade
110
- em experimentos ou gerações.
111
  """
112
  random.seed(seed)
113
  os.environ['PYTHONHASHSEED'] = str(seed)
 
1
  # FILE: api/ltx/ltx_utils.py
2
  # DESCRIPTION: A pure utility library for the LTX ecosystem.
3
+ # Contains the official low-level builder function for core components and other stateless helpers.
 
4
 
5
  import os
6
  import random
 
12
 
13
  import torch
14
  from safetensors import safe_open
15
+ from transformers import T5EncoderModel, T5Tokenizer
16
 
17
  # ==============================================================================
18
  # --- CONFIGURAÇÃO DE PATH E IMPORTS DA BIBLIOTECA LTX ---
 
32
  try:
33
  from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
34
  from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
35
+ from ltx_video.models.transformers.transformer3d import Transformer3DModel
36
  from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
37
  from ltx_video.schedulers.rf import RectifiedFlowScheduler
38
  except ImportError as e:
39
+ logging.critical("Failed to import a core LTX-Video library component.", exc_info=True)
40
  raise ImportError(f"Could not import from LTX-Video library. Check repo integrity at '{LTX_VIDEO_REPO_DIR}'. Error: {e}")
41
 
42
  # ==============================================================================
43
+ # --- FUNÇÃO HELPER 'create_transformer' (Essencial) ---
 
44
  # ==============================================================================
45
 
46
def create_transformer(ckpt_path: str, precision: str) -> Transformer3DModel:
    """
    Create and load the Transformer3D model with the correct precision logic,
    including support for the float8_e4m3fn (FP8) optimization.

    Args:
        ckpt_path (str): Path to the checkpoint to load the transformer from.
        precision (str): One of "float8_e4m3fn", "bfloat16", or anything else
            (treated as the default precision of the checkpoint).

    Returns:
        Transformer3DModel: The loaded (and, for FP8, Q8-patched) transformer.

    Raises:
        ValueError: If precision is "float8_e4m3fn" but the optional
            q8_kernels package is not installed.
    """
    if precision == "float8_e4m3fn":
        # Keep the try body minimal: only the optional-dependency import may
        # legitimately raise ImportError here. A broader try would misreport
        # unrelated ImportErrors from the model load as a missing Q8 install.
        try:
            from q8_kernels.integration.patch_transformer import (
                patch_diffusers_transformer as patch_transformer_for_q8_kernels,
            )
        except ImportError as err:
            # Chain the cause so the original import failure stays visible.
            raise ValueError(
                "Q8-Kernels not found. To use FP8 checkpoint, please install Q8 kernels from the project's wheels."
            ) from err
        transformer = Transformer3DModel.from_pretrained(ckpt_path, dtype=torch.float8_e4m3fn)
        patch_transformer_for_q8_kernels(transformer)
        return transformer
    if precision == "bfloat16":
        return Transformer3DModel.from_pretrained(ckpt_path).to(torch.bfloat16)
    # Fallback: load with the checkpoint's native precision.
    return Transformer3DModel.from_pretrained(ckpt_path)
63
 
64
+ # ==============================================================================
65
+ # --- BUILDER DE BAIXO NÍVEL OFICIAL ---
66
+ # ==============================================================================
67
 
68
+ def build_components_on_cpu(checkpoint_path: str, config: Dict) -> Tuple[LTXVideoPipeline, CausalVideoAutoencoder]:
69
+ """
70
+ Constrói o pipeline LTX principal (sem VAE) e o modelo VAE separadamente, na CPU.
71
+ Esta é a função de construção fundamental usada pelo LTXAducManager.
72
  """
73
  logging.info(f"Building LTX components from checkpoint: {Path(checkpoint_path).name}")
74
 
75
  with safe_open(checkpoint_path, framework="pt") as f:
76
  metadata = f.metadata() or {}
77
  config_str = metadata.get("config", "{}")
78
+ allowed_inference_steps = json.loads(config_str).get("allowed_inference_steps")
 
79
 
 
80
  precision = config.get("precision", "bfloat16")
81
+
82
+ # Usa a função helper correta para criar o transformer
83
  transformer = create_transformer(checkpoint_path, precision).to("cpu")
84
+
85
  scheduler = RectifiedFlowScheduler.from_pretrained(checkpoint_path)
86
  text_encoder = T5EncoderModel.from_pretrained(config["text_encoder_model_name_or_path"], subfolder="text_encoder").to("cpu")
87
  tokenizer = T5Tokenizer.from_pretrained(config["text_encoder_model_name_or_path"], subfolder="tokenizer")
88
  patchifier = SymmetricPatchifier(patch_size=1)
 
 
89
  vae = CausalVideoAutoencoder.from_pretrained(checkpoint_path).to("cpu")
90
 
 
91
  if precision == "bfloat16":
 
92
  text_encoder.to(torch.bfloat16)
93
  vae.to(torch.bfloat16)
94
 
 
95
  pipeline = LTXVideoPipeline(
96
  transformer=transformer,
97
  patchifier=patchifier,
98
  text_encoder=text_encoder,
99
  tokenizer=tokenizer,
100
  scheduler=scheduler,
101
+ vae=None, # VAE é desacoplado para ser gerenciado por um worker separado
102
  allowed_inference_steps=allowed_inference_steps,
103
  prompt_enhancer_image_caption_model=None,
104
  prompt_enhancer_image_caption_processor=None,
 
114
 
115
  def seed_everything(seed: int):
116
  """
117
+ Define a semente para PyTorch, NumPy e Python para garantir reprodutibilidade.
 
118
  """
119
  random.seed(seed)
120
  os.environ['PYTHONHASHSEED'] = str(seed)