| from __future__ import annotations |
|
|
| import os |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| from mmgp import offload as mmgp_offload |
| from shared.attention import pay_attention |
|
|
# Built-in defaults for the EditAnything reference module. Each constant is the
# fallback that install_editanything_modules() uses when the matching
# "ltx2_edit_anything_*" key is missing from the model definition.

# Fallback for "ltx2_edit_anything_ref_start_block" (stored as an int).
EDITANYTHING_REF_START_BLOCK = 12
# Fallback for "ltx2_edit_anything_ref_end_block" (stored as an int).
EDITANYTHING_REF_END_BLOCK = 35
# Fallback for "ltx2_edit_anything_ref_context_scale" (stored as a float).
EDITANYTHING_REF_CONTEXT_SCALE = 0.01
# Fallback for "ltx2_edit_anything_ref_token_scale"; also the default
# ``token_scale`` multiplier of EditAnythingRefVisualProj.forward().
EDITANYTHING_REF_TOKEN_SCALE = 0.25
# Fallback for "ltx2_edit_anything_adaln_scale"; also the default
# ``adaln_scale`` multiplier of EditAnythingRefAdaLNProj.forward().
EDITANYTHING_ADALN_SCALE = 2.0
|
|
|
|
def _module_state(module_paths) -> dict[str, torch.Tensor]:
    """Merge the tensors of every EditAnything checkpoint among ``module_paths``.

    ``module_paths`` may be a single path or a list/tuple of paths. An entry is
    loaded only when it is truthy and its lower-cased base name contains
    ``"edit_anything"``; every other entry is ignored. Qualifying files are read
    with ``mmgp_offload.load_sd(..., writable_tensors=False)`` and merged in
    order, so a later file overrides any key it shares with an earlier one.

    Returns:
        The merged mapping, or an empty dict when no entry qualifies.
    """
    if isinstance(module_paths, (list, tuple)):
        candidates = module_paths
    else:
        candidates = [module_paths]

    merged = {}
    for candidate in candidates:
        if not candidate:
            continue
        file_name = os.path.basename(str(candidate)).lower()
        if "edit_anything" in file_name:
            # load_sd yields a 3-tuple; only its first item (the tensors) is used.
            tensors, _, _ = mmgp_offload.load_sd(candidate, writable_tensors=False)
            merged.update(tensors)
    return merged
|
|
|
|
def _strip_prefix(state: dict[str, torch.Tensor], prefix: str) -> dict[str, torch.Tensor]:
    """Return the entries of ``state`` whose key starts with ``prefix``, with the prefix removed.

    Keys that do not start with ``prefix`` are dropped. Values are passed
    through untouched and the original iteration order is kept.
    """
    cut = len(prefix)
    selected = {}
    for name, value in state.items():
        if name.startswith(prefix):
            selected[name[cut:]] = value
    return selected
|
|
|
|
class _LoRALinear(torch.nn.Module):
    """Wrap an existing linear layer and add a low-rank (LoRA) term to its output.

    The output is ``base_linear(x) + (x @ lora_A.T) @ lora_B.T``. No extra
    scaling factor is applied to the low-rank term.
    """

    def __init__(self, base_linear: torch.nn.Linear, lora_a: torch.Tensor, lora_b: torch.Tensor) -> None:
        """Store the wrapped layer and register both LoRA factors as frozen parameters."""
        super().__init__()
        # object.__setattr__ bypasses torch.nn.Module.__setattr__, so the wrapped
        # layer stays a plain attribute and is not registered as a child module.
        object.__setattr__(self, "base_linear", base_linear)
        self.lora_A = torch.nn.Parameter(lora_a, requires_grad=False)
        self.lora_B = torch.nn.Parameter(lora_b, requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the wrapped layer's output plus the LoRA correction.

        The low-rank path runs in the dtype of ``lora_A``; its result is moved
        to the device and dtype of the base output before the two are summed.
        """
        base_out = self.base_linear(x)
        down_weight = self.lora_A
        low_rank = F.linear(x.to(dtype=down_weight.dtype), down_weight)
        delta = F.linear(low_rank, self.lora_B)
        return base_out + delta.to(device=base_out.device, dtype=base_out.dtype)
|
|
|
|
class EditAnythingRefAttention(torch.nn.Module):
    """Attention layer built on another attention module's projections plus LoRA deltas.

    The q/k/v/out linear layers of ``base_attn`` are each wrapped in a
    ``_LoRALinear`` whose factors are read from ``state``. The norms and the
    attention-override lookup are borrowed from ``base_attn`` at call time.
    """

    def __init__(self, base_attn: torch.nn.Module, state: dict[str, torch.Tensor], prefix: str) -> None:
        """Wrap the projections of ``base_attn`` with the LoRA factors stored under ``prefix``.

        Args:
            base_attn: Module exposing ``heads``, ``dim_head``, ``to_q``,
                ``to_k``, ``to_v`` and an indexable ``to_out`` whose first item
                is the output projection.
            state: Mapping that must contain ``"<prefix><proj>.lora_A.weight"``
                and ``"<prefix><proj>.lora_B.weight"`` for ``to_q``, ``to_k``,
                ``to_v`` and ``to_out.0``.
            prefix: Key prefix of this layer's entries in ``state``.

        Raises:
            KeyError: If any expected LoRA entry is missing from ``state``.
        """
        super().__init__()
        # Bypass Module.__setattr__ so the shared base attention is kept as a
        # plain attribute instead of being registered as a child module.
        object.__setattr__(self, "base_attn", base_attn)
        self.heads = int(base_attn.heads)
        self.dim_head = int(base_attn.dim_head)
        for proj_name in ("to_q", "to_k", "to_v"):
            wrapped = self._wrap(getattr(base_attn, proj_name), state, f"{prefix}{proj_name}")
            setattr(self, proj_name, wrapped)
        self.to_out = self._wrap(base_attn.to_out[0], state, f"{prefix}to_out.0")

    @staticmethod
    def _wrap(linear: torch.nn.Module, state: dict[str, torch.Tensor], key_prefix: str) -> _LoRALinear:
        """Build the LoRA wrapper for ``linear`` from the two factors stored under ``key_prefix``."""
        down = state[f"{key_prefix}.lora_A.weight"]
        up = state[f"{key_prefix}.lora_B.weight"]
        return _LoRALinear(linear, down, up)

    def forward(self, x_list: list[torch.Tensor], context_list: list[torch.Tensor] | None = None) -> torch.Tensor:
        """Attend from ``x_list[0]`` to ``context_list[0]`` (or to itself when no context is given).

        Both argument lists are emptied by this call, so the caller's lists no
        longer reference the input tensors afterwards (presumably to let memory
        be reclaimed early — confirm against the callers).

        Returns:
            The output projection applied to the attention result with the
            head and head-dim axes merged.
        """
        hidden = x_list[0]
        x_list.clear()
        if context_list is None:
            context = hidden
        else:
            context = context_list[0]
            context_list.clear()

        q = self.to_q(hidden)
        k = self.to_k(context)
        v = self.to_v(context)

        # NOTE(review): the return values are intentionally not used here, which
        # only normalizes q/k if these norms work in place — confirm.
        self.base_attn.q_norm(q)
        self.base_attn.k_norm(k)

        # Split the last axis into (heads, dim_head) for the attention kernel.
        head_shape = (self.heads, self.dim_head)
        q = q.view(q.shape[0], -1, *head_shape)
        k = k.view(k.shape[0], -1, *head_shape)
        v = v.view(v.shape[0], -1, *head_shape)

        force_attention, attention_version = self.base_attn._resolve_attention_override()
        attended = pay_attention(
            [q, k, v],
            force_attention=force_attention,
            version=attention_version,
            recycle_q=True,
        )
        # Merge axes 2 and 3 back into one feature axis.
        return self.to_out(attended.flatten(2, 3))
|
|
|
|
class EditAnythingRefVisualProj(torch.nn.Module):
    """Project a reference latent into a fixed set of 32 context tokens.

    The layer sizes are taken from the shapes of the checkpoint tensors, which
    are then loaded strictly; all parameters are frozen.
    """

    def __init__(self, state: dict[str, torch.Tensor]) -> None:
        """Build the layers from ``state`` and load it with ``strict=True``.

        ``state`` must provide ``fc1.weight``, ``proj.weight`` and
        ``pos_embed``; biases are created only when ``fc1.bias`` /
        ``proj.bias`` are present. Because loading is strict, the LayerNorm
        entries must be present too and no unexpected key is tolerated.
        """
        super().__init__()
        hidden_shape = state["fc1.weight"].shape
        output_shape = state["proj.weight"].shape
        # Weight tensors are laid out as (out_features, in_features).
        self.fc1 = torch.nn.Linear(
            in_features=hidden_shape[1],
            out_features=hidden_shape[0],
            bias="fc1.bias" in state,
        )
        self.proj = torch.nn.Linear(
            in_features=output_shape[1],
            out_features=output_shape[0],
            bias="proj.bias" in state,
        )
        self.norm = torch.nn.LayerNorm(output_shape[0])
        self.pos_embed = torch.nn.Parameter(state["pos_embed"], requires_grad=False)
        self.load_state_dict(state, strict=True)
        self.requires_grad_(False)

    def forward(self, ref_latent: torch.Tensor, token_scale: float = EDITANYTHING_REF_TOKEN_SCALE) -> torch.Tensor:
        """Turn ``ref_latent`` into scaled reference tokens.

        ``ref_latent`` is averaged over dim 2 and the result is treated as a
        4-D map (assumes a 5-D input such as batch, channels, frames, height,
        width — TODO confirm the axis meaning). Each of the 4x8 pooled cells
        becomes one token made of the cell's channel values concatenated with
        the per-channel spatial mean and (biased) standard deviation.

        Returns:
            ``(norm(proj(silu(fc1(tokens)))) + pos_embed) * token_scale``.
        """
        frame = ref_latent.mean(dim=2)
        batch = frame.shape[0]

        # Local descriptors: one token per cell of a 4x8 grid (32 tokens).
        grid = F.adaptive_avg_pool2d(frame, (4, 8))
        local = grid.permute(0, 2, 3, 1).reshape(batch, 32, -1)

        # Global descriptors, repeated for every token.
        spatial_axes = (-2, -1)
        channel_mean = frame.mean(dim=spatial_axes)
        channel_std = frame.std(dim=spatial_axes, unbiased=False)
        summary = torch.cat([channel_mean, channel_std], dim=-1)
        summary = summary.unsqueeze(1).expand(-1, local.shape[1], -1)

        tokens = torch.cat([local, summary], dim=-1)
        hidden = F.silu(self.fc1(tokens))
        tokens = self.norm(self.proj(hidden))
        positions = self.pos_embed[:, : tokens.shape[1]].to(device=tokens.device, dtype=tokens.dtype)
        return (tokens + positions) * float(token_scale)
|
|
|
|
class EditAnythingRefAdaLNProj(torch.nn.Module):
    """Project pooled statistics of a reference latent into one vector per sample.

    The two linear layers are sized from the checkpoint tensors, loaded
    strictly and frozen.
    """

    def __init__(self, state: dict[str, torch.Tensor]) -> None:
        """Build ``fc1`` and ``proj`` from ``state`` and load it with ``strict=True``.

        ``state`` must provide ``fc1.weight`` and ``proj.weight``; biases are
        created only when ``fc1.bias`` / ``proj.bias`` are present.
        """
        super().__init__()
        hidden_shape = state["fc1.weight"].shape
        output_shape = state["proj.weight"].shape
        # Weight tensors are laid out as (out_features, in_features).
        self.fc1 = torch.nn.Linear(
            in_features=hidden_shape[1],
            out_features=hidden_shape[0],
            bias="fc1.bias" in state,
        )
        self.proj = torch.nn.Linear(
            in_features=output_shape[1],
            out_features=output_shape[0],
            bias="proj.bias" in state,
        )
        self.load_state_dict(state, strict=True)
        self.requires_grad_(False)

    def forward(self, ref_latent: torch.Tensor, adaln_scale: float = EDITANYTHING_ADALN_SCALE) -> torch.Tensor:
        """Return ``proj(silu(fc1(pooled))) * adaln_scale``.

        ``ref_latent`` is averaged over dim 2, then three pooled views of the
        result are flattened from dim 1 and concatenated in this order: 1x1
        average pool, 2x2 average pool, 1x1 max pool.
        """
        frame = ref_latent.mean(dim=2)
        pooled_views = [
            F.adaptive_avg_pool2d(frame, (1, 1)),
            F.adaptive_avg_pool2d(frame, (2, 2)),
            F.adaptive_max_pool2d(frame, (1, 1)),
        ]
        pooled = torch.cat([view.flatten(1) for view in pooled_views], dim=-1)
        hidden = F.silu(self.fc1(pooled))
        return self.proj(hidden) * float(adaln_scale)
|
|
|
|
def install_editanything_modules(velocity_model: torch.nn.Module, module_paths, model_def: dict | None = None) -> None:
    """Attach the EditAnything reference modules found in ``module_paths`` to ``velocity_model``.

    Does nothing when no EditAnything checkpoint is found among
    ``module_paths``. Otherwise it:

    * stores the five tuning values on ``velocity_model`` (taken from
      ``model_def`` when the ``ltx2_edit_anything_*`` key is present, else
      from the module-level defaults);
    * installs ``editanything_ref_visual_proj`` / ``editanything_ref_adaln_proj``
      when the checkpoint has ``ref_visual_proj.*`` / ``ref_adaln_proj.*`` keys;
    * installs ``editanything_role_embedding`` when the checkpoint has
      ``role_embedding.embedding.weight``;
    * adds a ``ref_attn`` layer, built around the block's ``attn2``, to every
      transformer block that has LoRA weights in the checkpoint, and copies
      the start/end/context-scale settings onto that block;
    * sets ``velocity_model.editanything_module_loaded = True`` and prints a
      confirmation line.

    Args:
        velocity_model: Model to extend in place.
        module_paths: A path or list/tuple of paths, filtered by ``_module_state``.
        model_def: Optional settings mapping; ``None`` is treated as empty.
    """
    state = _module_state(module_paths)
    if not state:
        return
    model_def = model_def or {}

    # (attribute on the model, key in model_def, type conversion, fallback)
    settings = (
        ("editanything_ref_start_block", "ltx2_edit_anything_ref_start_block", int, EDITANYTHING_REF_START_BLOCK),
        ("editanything_ref_end_block", "ltx2_edit_anything_ref_end_block", int, EDITANYTHING_REF_END_BLOCK),
        ("editanything_ref_context_scale", "ltx2_edit_anything_ref_context_scale", float, EDITANYTHING_REF_CONTEXT_SCALE),
        ("editanything_ref_token_scale", "ltx2_edit_anything_ref_token_scale", float, EDITANYTHING_REF_TOKEN_SCALE),
        ("editanything_adaln_scale", "ltx2_edit_anything_adaln_scale", float, EDITANYTHING_ADALN_SCALE),
    )
    for attr_name, def_key, convert, fallback in settings:
        setattr(velocity_model, attr_name, convert(model_def.get(def_key, fallback)))

    visual_state = _strip_prefix(state, "ref_visual_proj.")
    if visual_state:
        velocity_model.editanything_ref_visual_proj = EditAnythingRefVisualProj(visual_state)

    adaln_state = _strip_prefix(state, "ref_adaln_proj.")
    if adaln_state:
        velocity_model.editanything_ref_adaln_proj = EditAnythingRefAdaLNProj(adaln_state)

    role_weight = state.get("role_embedding.embedding.weight")
    if role_weight is not None:
        # from_pretrained wraps the checkpoint tensor directly as a frozen
        # weight. Constructing Embedding(n, d) first would allocate and randomly
        # initialize a weight (advancing the global RNG) only to throw it away.
        velocity_model.editanything_role_embedding = torch.nn.Embedding.from_pretrained(role_weight, freeze=True)

    # Settings mirrored onto every block that receives a reference-attention layer.
    block_settings = (
        "editanything_ref_start_block",
        "editanything_ref_end_block",
        "editanything_ref_context_scale",
    )
    for block in getattr(velocity_model, "transformer_blocks", []):
        prefix = f"diffusion_model.transformer_blocks.{block.idx}.ref_attn."
        # The to_q LoRA factor is used as the marker for "this block has weights".
        if f"{prefix}to_q.lora_A.weight" not in state:
            continue
        block.ref_attn = EditAnythingRefAttention(block.attn2, state, prefix)
        for attr_name in block_settings:
            setattr(block, attr_name, getattr(velocity_model, attr_name))

    velocity_model.editanything_module_loaded = True
    print("[WAN2GP][LTX2] EditAnything reference module installed.")
|
|
|
|
def build_editanything_reference_conditioning(
    transformer: torch.nn.Module,
    ref_images,
    height: int,
    width: int,
    video_encoder: torch.nn.Module,
    dtype: torch.dtype,
    device: torch.device,
    tiling_config=None,
):
    """Encode a reference image and derive the EditAnything conditioning inputs from it.

    The EditAnything attributes are read from ``transformer.velocity_model``
    when that attribute exists, otherwise from ``transformer`` itself. Only the
    first entry of ``ref_images`` is used when a list/tuple is given.

    Returns:
        A 3-tuple ``(conditionings, ref_context, ref_adaln)``:

        * ``conditionings`` — a one-item list holding a
          ``VideoConditionByReferenceLatent`` with ``strength=1.0``;
        * ``ref_context`` — detached output of ``editanything_ref_visual_proj``,
          or ``None`` when that projector is not installed;
        * ``ref_adaln`` — detached output of ``editanything_ref_adaln_proj``,
          or ``None`` when that projector is not installed.

        ``([], None, None)`` is returned when the EditAnything module is not
        loaded or ``ref_images`` is falsy.
    """
    from .ltx_core.conditioning import VideoConditionByReferenceLatent
    from .ltx_core.model.video_vae import encode_video as vae_encode_video
    from .ltx_pipelines.utils.media_io import load_image_conditioning

    velocity_model = getattr(transformer, "velocity_model", transformer)
    module_ready = getattr(velocity_model, "editanything_module_loaded", False)
    if not (module_ready and ref_images):
        return [], None, None

    if isinstance(ref_images, (list, tuple)):
        first_ref = ref_images[0]
    else:
        first_ref = ref_images

    pixels = load_image_conditioning(
        first_ref,
        height=height,
        width=width,
        dtype=dtype,
        device=device,
        resample="lanczos",
    )
    ref_latent = vae_encode_video(pixels, video_encoder, tiling_config).to(dtype=dtype)
    conditionings = [VideoConditionByReferenceLatent(ref_latent, strength=1.0)]

    def _project(projector_attr: str, scale_attr: str):
        """Run one optional projector on the latent; return ``None`` when it is not installed."""
        projector = getattr(velocity_model, projector_attr, None)
        if projector is None:
            return None
        # Feed the projector in the dtype of its own (first) parameter.
        param_dtype = next(projector.parameters()).dtype
        latent_in = ref_latent.to(device=device, dtype=param_dtype)
        return projector(latent_in, getattr(velocity_model, scale_attr)).detach()

    ref_context = _project("editanything_ref_visual_proj", "editanything_ref_token_scale")
    ref_adaln = _project("editanything_ref_adaln_proj", "editanything_adaln_scale")
    return conditionings, ref_context, ref_adaln
|
|