eeuuia committed on
Commit
7011aa8
verified
1 Parent(s): c2d3ac4

Upload 2 files

Browse files
Files changed (2) hide show
  1. api/debug_utils.py +70 -0
  2. api/ltx_utils.py +207 -0
api/debug_utils.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FILE: api/utils/debug_utils.py
2
+ # DESCRIPTION: A utility for detailed function logging and debugging.
3
+
4
+ import os
5
+ import functools
6
+ import logging
7
+ import torch
8
+
9
+ # Define o n铆vel de log. Mude para "INFO" para desativar os logs detalhados.
10
+ # Voc锚 pode controlar isso com uma vari谩vel de ambiente.
11
+ LOG_LEVEL = os.environ.get("ADUC_LOG_LEVEL", "DEBUG").upper()
12
+ logging.basicConfig(level=LOG_LEVEL, format='[%(levelname)s] [%(name)s] %(message)s')
13
+ logger = logging.getLogger("AducDebug")
14
+
15
+
16
+ def _format_value(value):
17
+ """Formata os valores dos argumentos para uma exibi莽茫o concisa e informativa."""
18
+ if isinstance(value, torch.Tensor):
19
+ return f"Tensor(shape={list(value.shape)}, device='{value.device}', dtype={value.dtype})"
20
+ if isinstance(value, str) and len(value) > 70:
21
+ return f"'{value[:70]}...'"
22
+ if isinstance(value, list) and len(value) > 5:
23
+ return f"List(len={len(value)})"
24
+ if isinstance(value, dict) and len(value.keys()) > 5:
25
+ return f"Dict(keys={list(value.keys())[:5]}...)"
26
+ return repr(value)
27
+
28
def log_function_io(func):
    """Decorator that logs a function's inputs, outputs, and exceptions.

    The detailed logging is active only while the "AducDebug" logger is
    enabled for DEBUG; otherwise the wrapped function runs with no extra
    formatting overhead.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Fast path: when DEBUG logging is off, skip all formatting work.
        if not logger.isEnabledFor(logging.DEBUG):
            return func(*args, **kwargs)

        qualified_name = f"{func.__module__}.{func.__name__}"

        # Build a compact textual signature of the call site.
        positional = [_format_value(arg) for arg in args]
        keyword = [f"{key}={_format_value(val)}" for key, val in kwargs.items()]
        signature = ", ".join(positional + keyword)

        logger.debug(f"================ INÍCIO: {qualified_name} ================")
        logger.debug(f" -> ENTRADA: ({signature})")

        try:
            result = func(*args, **kwargs)
            logger.debug(f" <- SAÍDA: {_format_value(result)}")
        except Exception as exc:
            logger.error(f" <-- ERRO em {qualified_name}: {exc}", exc_info=True)
            raise  # Re-raise so the wrapper never changes program behavior.
        finally:
            # Always emit the closing marker, on success or failure.
            logger.debug(f"================ FIM: {qualified_name} ================\n")

        return result

    return wrapper
api/ltx_utils.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FILE: api/ltx/ltx_utils.py
2
+ # DESCRIPTION: Comprehensive, self-contained utility module for the LTX pipeline.
3
+ # Handles dependency path injection, model loading, data structures, and helper functions.
4
+
5
+ import os
6
+ import random
7
+ import json
8
+ import logging
9
+ import time
10
+ import sys
11
+ from pathlib import Path
12
+ from typing import Dict, Optional, Tuple, Union
13
+ from dataclasses import dataclass
14
+ from enum import Enum, auto
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torchvision.transforms.functional as TVF
19
+ from PIL import Image
20
+ from safetensors import safe_open
21
+ from transformers import T5EncoderModel, T5Tokenizer
22
+
23
# ==============================================================================
# --- CRITICAL: DEPENDENCY PATH INJECTION ---
# ==============================================================================

# Location of the cloned LTX-Video repository. Generalized to be overridable
# via the LTX_VIDEO_REPO_DIR environment variable (for deployments that clone
# the repo elsewhere) while keeping the original hard-coded default.
LTX_VIDEO_REPO_DIR = Path(os.environ.get("LTX_VIDEO_REPO_DIR", "/data/LTX-Video"))

def add_deps_to_path():
    """Prepend the LTX-Video repository to ``sys.path`` so its packages import.

    Idempotent: the path is inserted only if not already present, so repeated
    calls (e.g. on module re-import) leave ``sys.path`` unchanged.
    """
    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)
        logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {repo_path}")

# Run immediately so the environment is configured before the LTX imports below.
add_deps_to_path()
42
+
43
+
44
+ # ==============================================================================
45
+ # --- IMPORTA脟脮ES DA BIBLIOTECA LTX-VIDEO (Ap贸s configura莽茫o do path) ---
46
+ # ==============================================================================
47
+ try:
48
+ from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
49
+ from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
50
+ from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
51
+ from ltx_video.models.transformers.transformer3d import Transformer3DModel
52
+ from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
53
+ from ltx_video.schedulers.rf import RectifiedFlowScheduler
54
+ from ltx_video.models.autoencoders.vae_encode import un_normalize_latents, normalize_latents
55
+ import ltx_video.pipelines.crf_compressor as crf_compressor
56
+ except ImportError as e:
57
+ raise ImportError(f"Could not import from LTX-Video library even after setting sys.path. Check repo integrity at '{LTX_VIDEO_REPO_DIR}'. Error: {e}")
58
+
59
+
60
+ # ==============================================================================
61
+ # --- ESTRUTURAS DE DADOS E ENUMS (Centralizadas aqui) ---
62
+ # ==============================================================================
63
+
64
@dataclass
class ConditioningItem:
    """A single frame-conditioning input used to guide the generation pipeline."""

    # Conditioning media content as a tensor (shape determined by the caller).
    media_item: torch.Tensor
    # Index of the video frame this conditioning applies to.
    media_frame_number: int
    # How strongly this item influences generation (presumably 0.0-1.0 — confirm with pipeline).
    conditioning_strength: float
    # Optional spatial placement; None appears to mean the default/full-frame position.
    media_x: Optional[int] = None
    media_y: Optional[int] = None
72
+
73
+
74
class SkipLayerStrategy(Enum):
    """Strategy for applying spatio-temporal guidance across transformer blocks.

    Member semantics are defined by the consuming pipeline code (not visible
    here); the values themselves are opaque ``auto()`` identifiers.
    """

    AttentionSkip = auto()
    AttentionValues = auto()
    Residual = auto()
    TransformerBlock = auto()
80
+
81
+
82
+ # ==============================================================================
83
+ # --- FUN脟脮ES DE CONSTRU脟脙O DE MODELO E PIPELINE ---
84
+ # ==============================================================================
85
+
86
def create_latent_upsampler(latent_upsampler_model_path: str, device: str) -> LatentUpsampler:
    """Load the Latent Upsampler from a checkpoint and place it on *device*.

    The model is returned in eval mode, ready for inference.
    """
    logging.info(f"Loading Latent Upsampler from: {latent_upsampler_model_path} to device: {device}")
    upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
    upsampler.to(device)
    upsampler.eval()
    return upsampler
93
+
94
def build_ltx_pipeline_on_cpu(config: Dict) -> Tuple[LTXVideoPipeline, Optional[torch.nn.Module]]:
    """Builds the complete LTX pipeline and upsampler on the CPU.

    Args:
        config: Dict with at least "checkpoint_path" and
            "text_encoder_model_name_or_path"; optional keys are
            "precision" (defaults to "bfloat16") and
            "spatial_upscaler_model_path".

    Returns:
        Tuple of (pipeline, latent_upsampler); latent_upsampler is None when
        no spatial upscaler path is configured.

    Raises:
        FileNotFoundError: if the main checkpoint file does not exist.
    """
    t0 = time.perf_counter()
    logging.info("Building LTX pipeline on CPU...")

    ckpt_path = Path(config["checkpoint_path"])
    if not ckpt_path.is_file():
        raise FileNotFoundError(f"Main checkpoint file not found: {ckpt_path}")

    # The checkpoint's safetensors metadata may embed a JSON "config" entry;
    # it is read only to extract the allowed inference-step schedule (if any).
    with safe_open(ckpt_path, framework="pt") as f:
        metadata = f.metadata() or {}
        config_str = metadata.get("config", "{}")
        configs = json.loads(config_str)
        allowed_inference_steps = configs.get("allowed_inference_steps")

    # All sub-models are materialized on CPU; the caller decides when and
    # where to move them to an accelerator.
    vae = CausalVideoAutoencoder.from_pretrained(ckpt_path).to("cpu")
    transformer = Transformer3DModel.from_pretrained(ckpt_path).to("cpu")
    scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)

    text_encoder_path = config["text_encoder_model_name_or_path"]
    text_encoder = T5EncoderModel.from_pretrained(text_encoder_path, subfolder="text_encoder").to("cpu")
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_path, subfolder="tokenizer")
    patchifier = SymmetricPatchifier(patch_size=1)

    # Optionally down-cast the heavy modules; "bfloat16" is the default.
    precision = config.get("precision", "bfloat16")
    if precision == "bfloat16":
        vae.to(torch.bfloat16)
        transformer.to(torch.bfloat16)
        text_encoder.to(torch.bfloat16)

    pipeline = LTXVideoPipeline(
        transformer=transformer, patchifier=patchifier, text_encoder=text_encoder,
        tokenizer=tokenizer, scheduler=scheduler, vae=vae,
        allowed_inference_steps=allowed_inference_steps,
        # Prompt-enhancer components are deliberately disabled for this build.
        prompt_enhancer_image_caption_model=None, prompt_enhancer_image_caption_processor=None,
        prompt_enhancer_llm_model=None, prompt_enhancer_llm_tokenizer=None,
    )

    # The spatial upscaler is optional; it is only built when configured.
    latent_upsampler = None
    if config.get("spatial_upscaler_model_path"):
        spatial_path = config["spatial_upscaler_model_path"]
        latent_upsampler = create_latent_upsampler(spatial_path, device="cpu")
        if precision == "bfloat16":
            latent_upsampler.to(torch.bfloat16)

    logging.info(f"LTX pipeline built on CPU in {time.perf_counter() - t0:.2f}s")
    return pipeline, latent_upsampler
141
+
142
+
143
+ # ==============================================================================
144
+ # --- FUN脟脮ES AUXILIARES (Latent Processing, Seed, Image Prep) ---
145
+ # ==============================================================================
146
+
147
def adain_filter_latent(
    latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0
) -> torch.Tensor:
    """Transfer per-channel statistics (AdaIN) from a reference latent.

    Every (batch, channel) slice of *latents* is re-normalized to carry the
    mean and standard deviation of the matching reference slice; *factor*
    blends between the untouched input (0.0) and the fully stylized
    result (1.0).
    """
    stylized = latents.clone()
    batch_size, num_channels = latents.shape[0], latents.shape[1]
    for b in range(batch_size):
        for ch in range(num_channels):
            ref_std, ref_mean = torch.std_mean(reference_latents[b, ch], dim=None)
            src_std, src_mean = torch.std_mean(stylized[b, ch], dim=None)
            # Skip (near-)constant slices: dividing by ~0 would blow up.
            if src_std > 1e-6:
                stylized[b, ch] = ((stylized[b, ch] - src_mean) / src_std) * ref_std + ref_mean
    return torch.lerp(latents, stylized, factor)
159
+
160
def seed_everything(seed: int):
    """Seed every RNG in use (Python, NumPy, PyTorch) for reproducibility.

    Also forces deterministic cuDNN behavior, which can cost performance but
    makes results repeatable across runs.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe no-op on CPU-only builds.
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
169
+
170
def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, Image.Image],
    target_height: int,
    target_width: int,
) -> torch.Tensor:
    """Load an image and convert it into a 5D tensor for the LTX pipeline.

    The image is center-cropped to the target aspect ratio, resized, lightly
    blurred, passed through the CRF compressor, and normalized to [-1, 1].

    Args:
        image_input: Path to an image file, or an already-loaded PIL image.
        target_height: Output height in pixels.
        target_width: Output width in pixels.

    Returns:
        Tensor of shape (1, 3, 1, target_height, target_width).

    Raises:
        ValueError: if *image_input* is neither a path nor a PIL image.
    """
    if isinstance(image_input, str):
        pil_image = Image.open(image_input).convert("RGB")
    elif isinstance(image_input, Image.Image):
        pil_image = image_input
    else:
        raise ValueError("image_input must be a file path or a PIL Image object")

    src_width, src_height = pil_image.size
    target_aspect = target_width / target_height
    source_aspect = src_width / src_height

    # Center-crop to the target aspect ratio before resizing.
    if source_aspect > target_aspect:
        # Source is wider than the target: trim the left/right edges.
        crop_w, crop_h = int(src_height * target_aspect), src_height
        left, top = (src_width - crop_w) // 2, 0
    else:
        # Source is taller (or equal): trim the top/bottom edges.
        crop_w, crop_h = src_width, int(src_width / target_aspect)
        left, top = 0, (src_height - crop_h) // 2

    pil_image = pil_image.crop((left, top, left + crop_w, top + crop_h))
    pil_image = pil_image.resize((target_width, target_height), Image.Resampling.LANCZOS)

    tensor_chw = TVF.to_tensor(pil_image)  # (C, H, W), values in [0, 1]
    tensor_chw = TVF.gaussian_blur(tensor_chw, kernel_size=(3, 3))

    # The CRF compressor works on HWC layout; permute there and back.
    tensor_hwc = tensor_chw.permute(1, 2, 0)
    tensor_hwc = crf_compressor.compress(tensor_hwc)
    tensor_chw = tensor_hwc.permute(2, 0, 1)

    # Rescale [0, 1] -> [-1, 1].
    tensor_chw = (tensor_chw * 2.0) - 1.0

    # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
    return tensor_chw.unsqueeze(0).unsqueeze(2)