caarleexx commited on
Commit
8d53dc7
·
verified ·
1 Parent(s): 5a15d3e

Delete api/aduc_ltx_latent_patch.py

Browse files
Files changed (1) hide show
  1. api/aduc_ltx_latent_patch.py +0 -268
api/aduc_ltx_latent_patch.py DELETED
@@ -1,268 +0,0 @@
1
- # aduc_ltx_latent_patch.py
2
- #
3
- # Este módulo fornece um monkey patch para a classe LTXVideoPipeline da biblioteca ltx_video.
4
- # A principal funcionalidade deste patch é otimizar o processo de condicionamento, permitindo
5
- # que a pipeline aceite tensores de latentes pré-calculados diretamente através de um
6
- # `ConditioningItem` modificado. Isso evita a re-codificação desnecessária de mídias (imagens/vídeos)
7
- # pela VAE, resultando em um ganho de performance significativo quando os latentes já estão disponíveis.
8
-
9
- import torch
10
- from torch import Tensor
11
- from typing import Optional, List, Tuple
12
- from pathlib import Path
13
- import os
14
- import sys
15
-
16
- DEPS_DIR = Path("/data")
17
- LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
18
- def add_deps_to_path(repo_path: Path):
19
- """Adiciona o diretório do repositório ao sys.path para importações locais."""
20
- resolved_path = str(repo_path.resolve())
21
- if resolved_path not in sys.path:
22
- sys.path.insert(0, resolved_path)
23
- if LTXV_DEBUG:
24
- print(f"[DEBUG] Adicionado ao sys.path: {resolved_path}")
25
-
26
- # --- Execução da configuração inicial ---
27
- if not LTX_VIDEO_REPO_DIR.exists():
28
- _run_setup_script()
29
- add_deps_to_path(LTX_VIDEO_REPO_DIR)
30
-
31
-
32
- # Tenta importar as dependências necessárias do módulo original que será modificado.
33
- # Isso requer que o ambiente Python tenha o pacote `ltx_video` acessível em seu sys.path.
34
- try:
35
- from ltx_video.pipelines.pipeline_ltx_video import (
36
- LTXVideoPipeline,
37
- ConditioningItem as OriginalConditioningItem
38
- )
39
- from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
40
- from ltx_video.models.autoencoders.vae_encode import vae_encode, latent_to_pixel_coords
41
- from diffusers.utils.torch_utils import randn_tensor
42
- except ImportError as e:
43
- print(f"FATAL ERROR: Could not import dependencies from 'ltx_video'. "
44
- f"Please ensure the environment is correctly set up. Error: {e}")
45
- # Interrompe a execução se as dependências essenciais não puderem ser encontradas.
46
- raise
47
-
48
- print("[INFO] Patch module 'aduc_ltx_latent_patch' loaded successfully.")
49
-
50
- # ==============================================================================
51
- # 1. NOVA DEFINIÇÃO DA DATACLASS `ConditioningItem`
52
- # ==============================================================================
53
-
54
- from dataclasses import dataclass
55
-
56
- @dataclass
57
- class PatchedConditioningItem:
58
- """
59
- Versão modificada do `ConditioningItem` que aceita tensores de pixel (`media_item`)
60
- ou tensores de latentes pré-codificados (`latents`).
61
-
62
- Attributes:
63
- media_frame_number (int): Quadro inicial do item de condicionamento no vídeo.
64
- conditioning_strength (float): Força do condicionamento (0.0 a 1.0).
65
- media_item (Optional[Tensor]): Tensor de mídia (pixels). Usado se `latents` for None.
66
- media_x (Optional[int]): Coordenada X (esquerda) para posicionamento espacial.
67
- media_y (Optional[int]): Coordenada Y (topo) para posicionamento espacial.
68
- latents (Optional[Tensor]): Tensor de latentes pré-codificado. Terá precedência sobre `media_item`.
69
- """
70
- media_frame_number: int
71
- conditioning_strength: float
72
- media_item: Optional[Tensor] = None
73
- media_x: Optional[int] = None
74
- media_y: Optional[int] = None
75
- latents: Optional[Tensor] = None
76
-
77
- def __post_init__(self):
78
- """Valida o estado do objeto após a inicialização."""
79
- if self.media_item is None and self.latents is None:
80
- raise ValueError("A `PatchedConditioningItem` must have either 'media_item' or 'latents' defined.")
81
- if self.media_item is not None and self.latents is not None:
82
- print("[WARNING] `PatchedConditioningItem` received both 'media_item' and 'latents'. "
83
- "The 'latents' tensor will take precedence.")
84
-
85
- # ==============================================================================
86
- # 2. NOVA IMPLEMENTAÇÃO DA FUNÇÃO `prepare_conditioning`
87
- # ==============================================================================
88
-
89
- def prepare_conditioning_with_latents(
90
- self: LTXVideoPipeline,
91
- conditioning_items: Optional[List[PatchedConditioningItem]],
92
- init_latents: Tensor,
93
- num_frames: int,
94
- height: int,
95
- width: int,
96
- vae_per_channel_normalize: bool = False,
97
- generator: Optional[torch.Generator] = None,
98
- ) -> Tuple[Tensor, Tensor, Optional[Tensor], int]:
99
- """
100
- Versão modificada de `prepare_conditioning` que prioriza o uso de latentes pré-calculados
101
- dos `conditioning_items`, evitando a re-codificação desnecessária pela VAE.
102
- """
103
- assert isinstance(self, LTXVideoPipeline), "This function must be called as a method of LTXVideoPipeline."
104
- assert isinstance(self.vae, CausalVideoAutoencoder), "VAE must be of type CausalVideoAutoencoder."
105
-
106
- # Se não há itens de condicionamento, apenas patchifica os latentes e retorna.
107
- if not conditioning_items:
108
- init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
109
- init_pixel_coords = latent_to_pixel_coords(
110
- init_latent_coords, self.vae,
111
- causal_fix=self.transformer.config.causal_temporal_positioning
112
- )
113
- return init_latents, init_pixel_coords, None, 0
114
-
115
- # Inicializa tensores para acumular resultados
116
- init_conditioning_mask = torch.zeros(
117
- init_latents[:, 0, :, :, :].shape, dtype=torch.float32, device=init_latents.device
118
- )
119
- extra_conditioning_latents = []
120
- extra_conditioning_pixel_coords = []
121
- extra_conditioning_mask = []
122
- extra_conditioning_num_latents = 0
123
-
124
- for item in conditioning_items:
125
- item_latents: Tensor
126
-
127
- # --- LÓGICA CENTRAL DO PATCH ---
128
- if item.latents is not None:
129
- # 1. Se latentes pré-calculados existem, use-os diretamente.
130
- item_latents = item.latents.to(dtype=init_latents.dtype, device=init_latents.device)
131
- if item_latents.ndim != 5:
132
- raise ValueError(f"Latents must have 5 dimensions (b, c, f, h, w), but got {item_latents.ndim}")
133
- elif item.media_item is not None:
134
- # 2. Caso contrário, volte para o fluxo original de codificação da VAE.
135
- resized_item = self._resize_conditioning_item(item, height, width)
136
- media_item = resized_item.media_item
137
- assert media_item.ndim == 5, f"media_item must have 5 dims, but got {media_item.ndim}"
138
-
139
- item_latents = vae_encode(
140
- media_item.to(dtype=self.vae.dtype, device=self.vae.device),
141
- self.vae,
142
- vae_per_channel_normalize=vae_per_channel_normalize,
143
- ).to(dtype=init_latents.dtype)
144
- else:
145
- # Este caso é prevenido pelo __post_init__ do dataclass, mas é bom ter uma checagem.
146
- raise ValueError("ConditioningItem is invalid: it has neither 'latents' nor 'media_item'.")
147
- # --- FIM DA LÓGICA DO PATCH ---
148
-
149
- media_frame_number = item.media_frame_number
150
- strength = item.conditioning_strength
151
-
152
- # O resto da lógica da função original é aplicado sobre `item_latents`.
153
- if media_frame_number == 0:
154
- item_latents, l_x, l_y = self._get_latent_spatial_position(
155
- item_latents, item, height, width, strip_latent_border=True
156
- )
157
- _, _, f_l, h_l, w_l = item_latents.shape
158
- init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = torch.lerp(
159
- init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l], item_latents, strength
160
- )
161
- init_conditioning_mask[:, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = strength
162
- else:
163
- if item_latents.shape[2] > 1:
164
- (init_latents, init_conditioning_mask, item_latents) = self._handle_non_first_conditioning_sequence(
165
- init_latents, init_conditioning_mask, item_latents, media_frame_number, strength
166
- )
167
-
168
- if item_latents is not None:
169
- noise = randn_tensor(
170
- item_latents.shape, generator=generator,
171
- device=item_latents.device, dtype=item_latents.dtype
172
- )
173
- item_latents = torch.lerp(noise, item_latents, strength)
174
- item_latents, latent_coords = self.patchifier.patchify(latents=item_latents)
175
- pixel_coords = latent_to_pixel_coords(
176
- latent_coords, self.vae,
177
- causal_fix=self.transformer.config.causal_temporal_positioning
178
- )
179
- pixel_coords[:, 0] += media_frame_number
180
- extra_conditioning_num_latents += item_latents.shape[1]
181
- conditioning_mask = torch.full(
182
- item_latents.shape[:2], strength,
183
- dtype=torch.float32, device=init_latents.device
184
- )
185
- extra_conditioning_latents.append(item_latents)
186
- extra_conditioning_pixel_coords.append(pixel_coords)
187
- extra_conditioning_mask.append(conditioning_mask)
188
-
189
- # Patchifica os latentes principais e a máscara de condicionamento
190
- init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
191
- init_pixel_coords = latent_to_pixel_coords(
192
- init_latent_coords, self.vae,
193
- causal_fix=self.transformer.config.causal_temporal_positioning
194
- )
195
- init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
196
- init_conditioning_mask = init_conditioning_mask.squeeze(-1)
197
-
198
- # Concatena os latentes extras (se houver)
199
- if extra_conditioning_latents:
200
- init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
201
- init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
202
- init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)
203
-
204
- if self.transformer.use_tpu_flash_attention:
205
- init_latents = init_latents[:, :-extra_conditioning_num_latents]
206
- init_pixel_coords = init_pixel_coords[:, :, :-extra_conditioning_num_latents]
207
- init_conditioning_mask = init_conditioning_mask[:, :-extra_conditioning_num_latents]
208
-
209
- return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
210
-
211
-
212
- # ==============================================================================
213
- # 3. CLASSE DO MONKEY PATCHER
214
- # ==============================================================================
215
-
216
- class LTXLatentConditioningPatch:
217
- """
218
- Classe estática para aplicar e reverter o monkey patch na pipeline LTX-Video.
219
-
220
- Esta classe substitui o método `prepare_conditioning` da `LTXVideoPipeline`
221
- pela versão otimizada que suporta latentes pré-calculados, e implicitamente
222
- requer o uso da `PatchedConditioningItem`.
223
- """
224
- _original_prepare_conditioning = None
225
- _is_patched = False
226
-
227
- @staticmethod
228
- def apply():
229
- """
230
- Aplica o monkey patch à classe `LTXVideoPipeline`.
231
-
232
- Guarda o método original e o substitui pela nova implementação.
233
- É idempotente; aplicar múltiplas vezes não causa efeito adicional.
234
- """
235
- if LTXLatentConditioningPatch._is_patched:
236
- print("[WARNING] LTXLatentConditioningPatch has already been applied. Ignoring.")
237
- return
238
-
239
- print("[INFO] Applying monkey patch for latent-based conditioning...")
240
-
241
- # Guarda a implementação original para permitir a reversão.
242
- LTXLatentConditioningPatch._original_prepare_conditioning = LTXVideoPipeline.prepare_conditioning
243
-
244
- # Substitui o método na classe LTXVideoPipeline.
245
- # Todas as instâncias futuras e existentes da classe usarão este novo método.
246
- LTXVideoPipeline.prepare_conditioning = prepare_conditioning_with_latents
247
-
248
- LTXLatentConditioningPatch._is_patched = True
249
- print("[SUCCESS] Monkey patch applied successfully.")
250
- print(" - `LTXVideoPipeline.prepare_conditioning` has been updated.")
251
- print(" - NOTE: Remember to use `aduc_ltx_latent_patch.PatchedConditioningItem` when creating conditioning items.")
252
-
253
- @staticmethod
254
- def revert():
255
- """
256
- Reverte o monkey patch, restaurando a implementação original.
257
- """
258
- if not LTXLatentConditioningPatch._is_patched:
259
- print("[WARNING] Patch is not currently applied. No action taken.")
260
- return
261
-
262
- if LTXLatentConditioningPatch._original_prepare_conditioning:
263
- print("[INFO] Reverting LTXLatentConditioningPatch...")
264
- LTXVideoPipeline.prepare_conditioning = LTXLatentConditioningPatch._original_prepare_conditioning
265
- LTXLatentConditioningPatch._is_patched = False
266
- print("[SUCCESS] Patch reverted successfully. Original functionality restored.")
267
- else:
268
- print("[ERROR] Cannot revert: original implementation was not saved.")