ColabWan / models /ltx2 /editanything.py
1ripon1's picture
Upload folder using huggingface_hub
7344bef verified
Raw
History Blame Contribute Delete
9.83 kB
from __future__ import annotations
import os
import torch
import torch.nn.functional as F
from mmgp import offload as mmgp_offload
from shared.attention import pay_attention
# Default transformer-block range for the EditAnything reference attention. The values
# are copied onto the model and its blocks by install_editanything_modules and can be
# overridden through model_def ("ltx2_edit_anything_ref_start_block" / "..._end_block").
# How the range is applied (inclusive/exclusive) is decided outside this file — TODO confirm.
EDITANYTHING_REF_START_BLOCK: int = 12
EDITANYTHING_REF_END_BLOCK: int = 35
# Default value stored as ``editanything_ref_context_scale`` on each block that receives a
# reference attention (its use lives outside this file).
EDITANYTHING_REF_CONTEXT_SCALE: float = 0.01
# Default multiplier applied to the tokens produced by EditAnythingRefVisualProj.
EDITANYTHING_REF_TOKEN_SCALE: float = 0.25
# Default multiplier applied to the EditAnythingRefAdaLNProj output.
EDITANYTHING_ADALN_SCALE: float = 2.0
def _module_state(module_paths) -> dict[str, torch.Tensor]:
paths = module_paths if isinstance(module_paths, (list, tuple)) else [module_paths]
state = {}
for path in paths:
if not path or "edit_anything" not in os.path.basename(str(path)).lower():
continue
sd, _, _ = mmgp_offload.load_sd(path, writable_tensors=False)
state.update(sd)
return state
def _strip_prefix(state: dict[str, torch.Tensor], prefix: str) -> dict[str, torch.Tensor]:
return {key[len(prefix) :]: value for key, value in state.items() if key.startswith(prefix)}
class _LoRALinear(torch.nn.Module):
def __init__(self, base_linear: torch.nn.Linear, lora_a: torch.Tensor, lora_b: torch.Tensor) -> None:
super().__init__()
object.__setattr__(self, "base_linear", base_linear)
self.lora_A = torch.nn.Parameter(lora_a, requires_grad=False)
self.lora_B = torch.nn.Parameter(lora_b, requires_grad=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self.base_linear(x)
lora_dtype = self.lora_A.dtype
lora_out = F.linear(F.linear(x.to(dtype=lora_dtype), self.lora_A), self.lora_B)
return out.add(lora_out.to(device=out.device, dtype=out.dtype))
class EditAnythingRefAttention(torch.nn.Module):
    """LoRA-adapted attention built on top of an existing attention module.

    Reuses ``base_attn``'s q/k/v/out projections (each wrapped with LoRA deltas from
    the checkpoint), its ``q_norm``/``k_norm`` layers, its head layout and its
    attention-backend selection. ``base_attn`` itself is stored with
    ``object.__setattr__`` so it is not registered as a submodule of this wrapper.
    """

    def __init__(self, base_attn: torch.nn.Module, state: dict[str, torch.Tensor], prefix: str) -> None:
        """Wrap ``base_attn``'s projections with the LoRA tensors found under ``prefix``.

        Args:
            base_attn: attention module providing ``heads``, ``dim_head``, ``to_q``,
                ``to_k``, ``to_v``, ``to_out[0]``, ``q_norm``, ``k_norm`` and
                ``_resolve_attention_override``.
            state: flat checkpoint dict holding the LoRA tensors.
            prefix: key prefix for this block, e.g.
                ``"diffusion_model.transformer_blocks.<idx>.ref_attn."``.

        Raises:
            KeyError: if any of the eight ``{to_q,to_k,to_v,to_out.0}.lora_{A,B}.weight``
                keys is missing from ``state``.
        """
        super().__init__()
        # Keep a plain reference so base_attn is not registered as a child module.
        object.__setattr__(self, "base_attn", base_attn)
        self.heads = int(base_attn.heads)
        self.dim_head = int(base_attn.dim_head)
        self.to_q = _LoRALinear(base_attn.to_q, state[f"{prefix}to_q.lora_A.weight"], state[f"{prefix}to_q.lora_B.weight"])
        self.to_k = _LoRALinear(base_attn.to_k, state[f"{prefix}to_k.lora_A.weight"], state[f"{prefix}to_k.lora_B.weight"])
        self.to_v = _LoRALinear(base_attn.to_v, state[f"{prefix}to_v.lora_A.weight"], state[f"{prefix}to_v.lora_B.weight"])
        self.to_out = _LoRALinear(base_attn.to_out[0], state[f"{prefix}to_out.0.lora_A.weight"], state[f"{prefix}to_out.0.lora_B.weight"])

    def forward(self, x_list: list[torch.Tensor], context_list: list[torch.Tensor] | None = None) -> torch.Tensor:
        """Attend from ``x_list[0]`` to ``context_list[0]``, or to itself when ``context_list`` is None.

        Both lists are emptied in place, so the caller no longer holds references to the
        input tensors (presumably to allow early memory release — confirm with the caller).
        Returns the LoRA-adapted output projection of the attention result.
        """
        x = x_list[0]
        x_list.clear()
        context = context_list[0] if context_list is not None else x
        if context_list is not None:
            context_list.clear()
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)
        # NOTE(review): the return values are discarded, so this relies on q_norm/k_norm
        # normalizing their input in place. If base_attn's norms are out-of-place, q and k
        # reach the attention unnormalized — confirm against base_attn's implementation.
        self.base_attn.q_norm(q)
        self.base_attn.k_norm(k)
        # Split the last dim into heads: (shape[0], -1, heads, dim_head).
        q = q.view(q.shape[0], -1, self.heads, self.dim_head)
        k = k.view(k.shape[0], -1, self.heads, self.dim_head)
        v = v.view(v.shape[0], -1, self.heads, self.dim_head)
        force_attention, attention_version = self.base_attn._resolve_attention_override()
        # recycle_q=True: presumably lets pay_attention reuse q's buffer — TODO confirm.
        out = pay_attention([q, k, v], force_attention=force_attention, version=attention_version, recycle_q=True)
        # Merge axes 2 and 3 (assumed heads, dim_head) back into one hidden dim.
        out = out.flatten(2, 3)
        return self.to_out(out)
class EditAnythingRefVisualProj(torch.nn.Module):
    """Turn a reference latent into a short, fixed-length sequence of context tokens.

    The latent is averaged over dim 2 and pooled to a 4x8 grid (32 tokens). Each token
    is concatenated with the per-channel spatial mean and (population) std, passed
    through ``fc1 -> SiLU -> proj -> LayerNorm``, offset by a learned positional
    embedding and multiplied by ``token_scale``. All weights come from ``state`` and
    are frozen.
    """

    def __init__(self, state: dict[str, torch.Tensor]) -> None:
        super().__init__()
        hidden_weight = state["fc1.weight"]
        out_weight = state["proj.weight"]
        out_features = out_weight.shape[0]
        # Registration order (fc1, proj, norm, pos_embed) is kept as-is: callers read
        # next(parameters()) to pick the input dtype.
        self.fc1 = torch.nn.Linear(hidden_weight.shape[1], hidden_weight.shape[0], bias="fc1.bias" in state)
        self.proj = torch.nn.Linear(out_weight.shape[1], out_features, bias="proj.bias" in state)
        self.norm = torch.nn.LayerNorm(out_features)
        self.pos_embed = torch.nn.Parameter(state["pos_embed"], requires_grad=False)
        # strict=True: state must contain exactly the keys of the layers above (incl. norm.*).
        self.load_state_dict(state, strict=True)
        self.requires_grad_(False)

    def forward(self, ref_latent: torch.Tensor, token_scale: float = EDITANYTHING_REF_TOKEN_SCALE) -> torch.Tensor:
        """Return 32 scaled context tokens computed from ``ref_latent``."""
        frame = ref_latent.mean(dim=2)
        batch = frame.shape[0]
        grid = F.adaptive_avg_pool2d(frame, (4, 8))
        local_tokens = grid.permute(0, 2, 3, 1).reshape(batch, 32, -1)
        spatial_dims = (-2, -1)
        channel_mean = frame.mean(dim=spatial_dims)
        channel_std = frame.std(dim=spatial_dims, unbiased=False)
        summary = torch.cat([channel_mean, channel_std], dim=-1)
        summary = summary.unsqueeze(1).expand(-1, local_tokens.shape[1], -1)
        hidden = F.silu(self.fc1(torch.cat([local_tokens, summary], dim=-1)))
        tokens = self.norm(self.proj(hidden))
        position = self.pos_embed[:, : tokens.shape[1]].to(device=tokens.device, dtype=tokens.dtype)
        return (tokens + position) * float(token_scale)
class EditAnythingRefAdaLNProj(torch.nn.Module):
    """Map a reference latent to one conditioning vector via pooled statistics and a 2-layer MLP.

    The latent is averaged over dim 2, then summarised as [avg-pool 1x1, avg-pool 2x2,
    max-pool 1x1] (6 values per channel), fed through ``fc1 -> SiLU -> proj`` and
    multiplied by ``adaln_scale``. All weights come from ``state`` and are frozen.
    """

    def __init__(self, state: dict[str, torch.Tensor]) -> None:
        super().__init__()
        hidden_weight = state["fc1.weight"]
        out_weight = state["proj.weight"]
        self.fc1 = torch.nn.Linear(hidden_weight.shape[1], hidden_weight.shape[0], bias="fc1.bias" in state)
        self.proj = torch.nn.Linear(out_weight.shape[1], out_weight.shape[0], bias="proj.bias" in state)
        # strict=True: state must contain exactly fc1.* and proj.* keys.
        self.load_state_dict(state, strict=True)
        self.requires_grad_(False)

    def forward(self, ref_latent: torch.Tensor, adaln_scale: float = EDITANYTHING_ADALN_SCALE) -> torch.Tensor:
        """Return the scaled projection of the pooled reference statistics."""
        frame = ref_latent.mean(dim=2)
        features = [F.adaptive_avg_pool2d(frame, size).flatten(1) for size in ((1, 1), (2, 2))]
        features.append(F.adaptive_max_pool2d(frame, (1, 1)).flatten(1))
        hidden = F.silu(self.fc1(torch.cat(features, dim=-1)))
        return self.proj(hidden) * float(adaln_scale)
def _editanything_setting(model_def: dict, key: str, default, cast):
    """Return ``cast(model_def[key])``, or ``cast(default)`` when the key is absent or None."""
    value = model_def.get(key)
    return cast(default if value is None else value)


def install_editanything_modules(velocity_model: torch.nn.Module, module_paths, model_def: dict | None = None) -> None:
    """Attach the EditAnything reference modules to ``velocity_model`` in place.

    Does nothing when ``module_paths`` yields no EditAnything checkpoint. Otherwise it:

    * stores the block-range and scale settings on ``velocity_model``. They can be
      overridden through ``ltx2_edit_anything_*`` keys in ``model_def``; a missing or
      ``None`` value falls back to the module-level default;
    * installs ``editanything_ref_visual_proj``, ``editanything_ref_adaln_proj`` and
      ``editanything_role_embedding`` when their weights are present;
    * gives every transformer block that has ``ref_attn`` LoRA weights a ``ref_attn``
      module built on its ``attn2``, plus copies of the block-range/context-scale settings;
    * sets ``editanything_module_loaded = True`` and prints a confirmation.

    Raises:
        KeyError: if a block has ``to_q`` LoRA weights but is missing other ref_attn keys.
        TypeError / ValueError: if a ``model_def`` override cannot be converted to int/float.
    """
    state = _module_state(module_paths)
    if not state:
        return
    model_def = model_def or {}
    velocity_model.editanything_ref_start_block = _editanything_setting(
        model_def, "ltx2_edit_anything_ref_start_block", EDITANYTHING_REF_START_BLOCK, int
    )
    velocity_model.editanything_ref_end_block = _editanything_setting(
        model_def, "ltx2_edit_anything_ref_end_block", EDITANYTHING_REF_END_BLOCK, int
    )
    velocity_model.editanything_ref_context_scale = _editanything_setting(
        model_def, "ltx2_edit_anything_ref_context_scale", EDITANYTHING_REF_CONTEXT_SCALE, float
    )
    velocity_model.editanything_ref_token_scale = _editanything_setting(
        model_def, "ltx2_edit_anything_ref_token_scale", EDITANYTHING_REF_TOKEN_SCALE, float
    )
    velocity_model.editanything_adaln_scale = _editanything_setting(
        model_def, "ltx2_edit_anything_adaln_scale", EDITANYTHING_ADALN_SCALE, float
    )

    visual_state = _strip_prefix(state, "ref_visual_proj.")
    if visual_state:
        velocity_model.editanything_ref_visual_proj = EditAnythingRefVisualProj(visual_state)
    adaln_state = _strip_prefix(state, "ref_adaln_proj.")
    if adaln_state:
        velocity_model.editanything_ref_adaln_proj = EditAnythingRefAdaLNProj(adaln_state)

    role_weight = state.get("role_embedding.embedding.weight")
    if role_weight is not None:
        # from_pretrained wraps the checkpoint tensor directly as a frozen weight instead of
        # first allocating and randomly initialising a throwaway matrix that is then replaced.
        velocity_model.editanything_role_embedding = torch.nn.Embedding.from_pretrained(role_weight, freeze=True)

    for block in getattr(velocity_model, "transformer_blocks", []):
        prefix = f"diffusion_model.transformer_blocks.{block.idx}.ref_attn."
        # Only blocks whose ref_attn weights exist in the checkpoint get a reference attention.
        if f"{prefix}to_q.lora_A.weight" not in state:
            continue
        block.ref_attn = EditAnythingRefAttention(block.attn2, state, prefix)
        block.editanything_ref_start_block = velocity_model.editanything_ref_start_block
        block.editanything_ref_end_block = velocity_model.editanything_ref_end_block
        block.editanything_ref_context_scale = velocity_model.editanything_ref_context_scale
    velocity_model.editanything_module_loaded = True
    print("[WAN2GP][LTX2] EditAnything reference module installed.")
def build_editanything_reference_conditioning(
    transformer: torch.nn.Module,
    ref_images,
    height: int,
    width: int,
    video_encoder: torch.nn.Module,
    dtype: torch.dtype,
    device: torch.device,
    tiling_config=None,
):
    """Encode the first reference image and derive the EditAnything conditioning from it.

    Returns ``(conditionings, ref_context, ref_adaln)``:
      * ``conditionings``: a one-element list holding a ``VideoConditionByReferenceLatent``
        (strength 1.0) built from the VAE-encoded reference;
      * ``ref_context``: detached output of ``editanything_ref_visual_proj``, or None if absent;
      * ``ref_adaln``: detached output of ``editanything_ref_adaln_proj``, or None if absent.

    Returns ``([], None, None)`` when the EditAnything modules are not installed on the
    (velocity) model or ``ref_images`` is empty. Only the first entry of a list/tuple is used.
    """
    from .ltx_core.conditioning import VideoConditionByReferenceLatent
    from .ltx_core.model.video_vae import encode_video as vae_encode_video
    from .ltx_pipelines.utils.media_io import load_image_conditioning

    model = getattr(transformer, "velocity_model", transformer)
    installed = getattr(model, "editanything_module_loaded", False)
    if not installed or not ref_images:
        return [], None, None

    first_ref = ref_images[0] if isinstance(ref_images, (list, tuple)) else ref_images
    pixels = load_image_conditioning(first_ref, height=height, width=width, dtype=dtype, device=device, resample="lanczos")
    ref_latent = vae_encode_video(pixels, video_encoder, tiling_config).to(dtype=dtype)

    def _project(module_attr: str, scale_attr: str):
        # Run the optional projection ``model.<module_attr>`` on the reference latent moved
        # to ``device`` and to the projection's parameter dtype; None when it is not installed.
        projection = getattr(model, module_attr, None)
        if projection is None:
            return None
        param_dtype = next(projection.parameters()).dtype
        latent_in = ref_latent.to(device=device, dtype=param_dtype)
        return projection(latent_in, getattr(model, scale_attr)).detach()

    conditionings = [VideoConditionByReferenceLatent(ref_latent, strength=1.0)]
    ref_context = _project("editanything_ref_visual_proj", "editanything_ref_token_scale")
    ref_adaln = _project("editanything_ref_adaln_proj", "editanything_adaln_scale")
    return conditionings, ref_context, ref_adaln