import os

import torch
import torch.amp as amp
import torch.nn.functional as F
from einops import rearrange
from safetensors.torch import load_file


def expand_tensor_dims(tensor, ndim):
    """Right-pad `tensor` with trailing singleton dims until it has `ndim` dims."""
    while len(tensor.shape) < ndim:
        tensor = tensor.unsqueeze(-1)
    return tensor
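

# Illustrative usage sketch (not part of the original module): expand a
# per-sample weighting so it broadcasts against a 5-D video tensor.
def _demo_expand_tensor_dims():
    w = torch.rand(4)                    # per-sample weights, shape (b,)
    w5 = expand_tensor_dims(w, 5)        # shape (b, 1, 1, 1, 1)
    video = torch.rand(4, 3, 8, 32, 32)  # (b, c, t, h, w)
    return video * w5                    # broadcasts over c, t, h, w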


def vae_encode(vae, images, dtype=torch.bfloat16, vae_type="wanx"):
    if vae_type in ["wanx"]:
        # The wanx VAE consumes a list of per-sample tensors.
        images = batch2list(images)
        latents = vae.encode(images)
        latents = list2batch(latents)
    elif vae_type in ["ltx"]:
        with amp.autocast("cuda", dtype=dtype):
            latents = vae.encode(images).latent_dist.sample()

        # Normalize latents with the per-channel statistics stored on the VAE.
        latents_mean = vae.latents_mean
        latents_std = vae.latents_std
        scaling_factor = 1.0

        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        latents = (latents - latents_mean) * scaling_factor / latents_std

    return latents


def vae_decode(vae, latents, dtype=torch.bfloat16, vae_type="wanx"):
    if vae_type in ["wanx"]:
        latents = batch2list(latents)
        images = vae.decode(latents)
        images = list2batch(images)
    elif vae_type in ["ltx"]:
        # Invert the normalization applied in vae_encode before decoding.
        latents_mean = vae.latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        latents_std = vae.latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        scaling_factor = 1.0
        latents = latents * latents_std / scaling_factor + latents_mean

        with amp.autocast("cuda", dtype=dtype):
            images = vae.decode(latents, return_dict=False)[0]

    return images
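

# Illustrative sketch (not part of the original module): the LTX branch of
# vae_encode/vae_decode applies an affine normalization and its inverse, so
# the two transforms round-trip exactly. Shown here with stand-in statistics.
def _demo_ltx_latent_normalization():
    z = torch.randn(1, 16, 4, 8, 8)
    mean = torch.randn(16).view(1, -1, 1, 1, 1)
    std = torch.rand(16).view(1, -1, 1, 1, 1) + 0.5
    scaling_factor = 1.0
    normed = (z - mean) * scaling_factor / std       # encode-side transform
    restored = normed * std / scaling_factor + mean  # decode-side inverse
    assert torch.allclose(restored, z, atol=1e-5)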


def image_encode(
    image_encoder,
    image,
    last_image=None,
    image_encoder_type="wanx",
):
    if image_encoder_type in ["wanx"]:
        if image.ndim == 5:
            # Take the first frame of a (b, c, t, h, w) video.
            image = image[:, :, 0]
        image = rearrange(image, "b c h w -> c b h w")

        if last_image is not None:
            if last_image.ndim == 5:
                last_image = last_image[:, :, 0]
            last_image = rearrange(last_image, "b c h w -> c b h w")
            image_embeds = image_encoder.visual([image, last_image])
        else:
            image_embeds = image_encoder.visual([image])

    return image_embeds


def pack_latents(latents, patch_size=1, patch_size_t=1):
    """Patchify (b, c, t, h, w) latents into a (b, tokens, features) sequence."""
    batch_size, num_channels, num_frames, height, width = latents.shape
    post_patch_num_frames = num_frames // patch_size_t
    post_patch_height = height // patch_size
    post_patch_width = width // patch_size
    latents = latents.reshape(
        batch_size,
        -1,
        post_patch_num_frames,
        patch_size_t,
        post_patch_height,
        patch_size,
        post_patch_width,
        patch_size,
    )
    latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
    latents = latents.contiguous()
    return latents


def unpack_latents(latents, num_frames, height, width, patch_size=1, patch_size_t=1):
    """Inverse of pack_latents. `num_frames`, `height`, `width` are post-patch sizes."""
    batch_size = latents.size(0)
    latents = latents.reshape(
        batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size
    )
    latents = (
        latents.permute(0, 4, 1, 5, 2, 6, 3, 7)
        .flatten(6, 7)
        .flatten(4, 5)
        .flatten(2, 3)
    )
    latents = latents.contiguous()
    return latents
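

# Illustrative round-trip sketch (not part of the original module):
# pack_latents and unpack_latents are mutual inverses when unpack_latents is
# given the post-patch grid sizes.
def _demo_pack_unpack_roundtrip():
    latents = torch.randn(2, 16, 8, 12, 12)
    packed = pack_latents(latents, patch_size=2, patch_size_t=2)
    # tokens = (8//2) * (12//2) * (12//2) = 144, features = 16 * 2 * 2 * 2 = 128
    assert packed.shape == (2, 144, 128)
    restored = unpack_latents(packed, 4, 6, 6, patch_size=2, patch_size_t=2)
    assert torch.equal(restored, latents)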


def prompt2states(
    prompt,
    text_encoder,
    device="cuda:0",
    tokenizer=None,
    max_length=128,
    text_encoder_type="wanx",
):
    if isinstance(prompt, str):
        prompt = [prompt]

    if text_encoder_type in ["wanx"]:
        text_states = text_encoder(prompt, device)[0]
        text_states = text_states.unsqueeze(0)
        return text_states
    elif text_encoder_type in ["ltx"]:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_ids = text_inputs.input_ids.to(device)
        text_mask = text_inputs.attention_mask.bool().to(device)
        text_states = text_encoder(text_ids)[0]

        return text_states, text_mask


def load_lora_for_pipeline(
    pipeline,
    lora_path,
    LORA_PREFIX_TRANSFORMER="",
    LORA_PREFIX_TEXT_ENCODER="",
    alpha=1.0,
    rank=0,
):
    state_dict = load_file(lora_path, device=rank)

    visited = []

    for key in state_dict:
        # Skip alpha scalars and keys already merged as part of an up/down pair.
        if "alpha" in key or key in visited:
            continue

        if "text" in key:
            layer_infos = (
                key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            )
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = (
                key.split(".")[0].split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
            )
            curr_layer = pipeline.transformer

        # Walk the module tree; an underscore in the key may be either an
        # attribute separator or part of an attribute name, so on failure
        # retry with the next segment merged into the current name.
        temp_name = layer_infos.pop(0)
        while len(layer_infos) > -1:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                elif len(layer_infos) == 0:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        # Order the pair as (up, down) so the matmul below is up @ down.
        pair_keys = []
        if "lora_down" in key:
            pair_keys.append(key.replace("lora_down", "lora_up"))
            pair_keys.append(key)
        else:
            pair_keys.append(key)
            pair_keys.append(key.replace("lora_up", "lora_down"))

        # Merge the low-rank update into the base weight: W += alpha * up @ down.
        if len(state_dict[pair_keys[0]].shape) == 4:
            # 1x1-conv LoRA weights: drop the spatial dims for the matmul.
            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
            weight_down = (
                state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
            )
            curr_layer.weight.data += alpha * torch.mm(
                weight_up, weight_down
            ).unsqueeze(2).unsqueeze(3)
        else:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)

        for item in pair_keys:
            visited.append(item)
    del state_dict

    return pipeline
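

# Illustrative sketch (not part of the original module): the merge performed
# above is W' = W + alpha * up @ down, so applying the merged layer equals
# applying the base layer plus the scaled low-rank update.
def _demo_lora_merge_math():
    torch.manual_seed(0)
    base = torch.nn.Linear(8, 8, bias=False)
    weight_up = torch.randn(8, 4)    # (out_features, rank)
    weight_down = torch.randn(4, 8)  # (rank, in_features)
    alpha = 1.0
    x = torch.randn(2, 8)
    expected = base(x) + alpha * (x @ weight_down.t() @ weight_up.t())
    base.weight.data += alpha * torch.mm(weight_up, weight_down)
    assert torch.allclose(base(x), expected, atol=1e-5)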


def load_lora_for_model(
    model,
    lora_path,
    LORA_PREFIX_TRANSFORMER="",
    LORA_PREFIX_TEXT_ENCODER="",
    alpha=1.0,
    rank=0,
):
    state_dict = load_file(lora_path, device="cpu")

    visited = []

    for key in state_dict:
        # Skip alpha scalars and keys already merged as part of an up/down pair.
        if "alpha" in key or key in visited:
            continue

        layer_infos = (
            key.split(".")[0].split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
        )
        curr_layer = model

        # Same attribute walk as in load_lora_for_pipeline.
        temp_name = layer_infos.pop(0)
        while len(layer_infos) > -1:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                elif len(layer_infos) == 0:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        pair_keys = []
        if "lora_down" in key:
            pair_keys.append(key.replace("lora_down", "lora_up"))
            pair_keys.append(key)
        else:
            pair_keys.append(key)
            pair_keys.append(key.replace("lora_up", "lora_down"))

        # Merge the low-rank update into the base weight: W += alpha * up @ down.
        if len(state_dict[pair_keys[0]].shape) == 4:
            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
            weight_down = (
                state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
            )
            curr_layer.weight.data += alpha * torch.mm(
                weight_up, weight_down
            ).unsqueeze(2).unsqueeze(3)
        else:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)

        for item in pair_keys:
            visited.append(item)
    del state_dict

    return model


def load_lora_state_dict(lora_dir):
    lora_path = os.path.join(lora_dir, "pytorch_lora_transformers_weights.safetensors")
    lora_weights = load_file(lora_path)
    load_lora_weights = {}
    for key in lora_weights:
        # Rename keys to the ".default" adapter naming used by PEFT-wrapped modules.
        load_lora_weights[key.replace(".weight", ".default.weight")] = lora_weights[key]

    return load_lora_weights


def transformer_zero_init(transformer):
    # Zero every matrix-shaped weight; re-initialize 1-D parameters
    # (biases, norm scales) from a standard normal.
    for p in transformer.parameters():
        if p.dim() > 1:
            torch.nn.init.zeros_(p.data)
        else:
            torch.nn.init.normal_(p.data)

    return transformer
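

# Illustrative sketch (not part of the original module): after
# transformer_zero_init, every parameter with more than one dimension is zero,
# so a bias-free module maps any input to zero.
def _demo_transformer_zero_init():
    layer = torch.nn.Linear(8, 8, bias=False)
    layer = transformer_zero_init(layer)
    x = torch.randn(2, 8)
    assert torch.all(layer(x) == 0)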


def prepare_video_condition_wanx(
    vae,
    video,
    mask_strategy=(0.4, 0.25, 0.3, 0.05),
):
    # Sample one of four frame-masking strategies with the given probabilities.
    mask_id = torch.multinomial(torch.tensor(mask_strategy), num_samples=1).item()
    bsz, _, num_frames, height, width = video.shape
    latents_height, latents_width = height // 8, width // 8

    if mask_id == 0:
        # Keep only the first frame (image-to-video).
        mask = torch.cat([
            torch.ones(bsz, 1, 1, height, width),
            torch.zeros(bsz, 1, num_frames - 1, height, width)
        ], dim=2)
    elif mask_id == 1:
        # Keep the leading half of the frames (video continuation).
        mid_frame = (num_frames - 1) // 2 + 1
        mask = torch.cat([
            torch.ones(bsz, 1, mid_frame, height, width),
            torch.zeros(bsz, 1, num_frames - mid_frame, height, width)
        ], dim=2)
    elif mask_id == 2:
        # Keep the first and last frames (interpolation).
        mask = torch.cat([
            torch.ones(bsz, 1, 1, height, width),
            torch.zeros(bsz, 1, num_frames - 2, height, width),
            torch.ones(bsz, 1, 1, height, width)
        ], dim=2)
    elif mask_id == 3:
        # Keep a random subset of frames, shared across the batch.
        # (A single count is drawn; the original `(bsz,)` shape would make
        # .item() fail for batch sizes greater than one.)
        num_masked = torch.randint(1, num_frames, (1,)).item()
        indices = torch.randperm(num_frames)[:num_masked].sort().values
        mask = torch.zeros(bsz, 1, num_frames, height, width)
        mask[:, :, indices] = 1

    # Repeat the first mask frame 4x to mirror the VAE's temporal compression,
    # then downsample spatially by 8 and fold time into groups of 4 channels.
    mask = mask.to(video.device, dtype=video.dtype)
    mask_lat_size = torch.cat([
        torch.repeat_interleave(mask[:, :, :1, :, :], dim=2, repeats=4),
        mask[:, :, 1:, :, :],
    ], dim=2)
    mask_lat_size = mask_lat_size[:, :, :, ::8, ::8]
    mask_lat_size = mask_lat_size.view(bsz, -1, 4, latents_height, latents_width).transpose(1, 2)

    # Encode the masked video and concatenate the latent-space mask channels.
    # (Passed as a keyword: the original positional "wanx" landed in the
    # dtype slot of vae_encode.)
    video_condition = video * mask
    latents_condition = torch.cat([
        mask_lat_size,
        vae_encode(vae, video_condition, vae_type="wanx")
    ], dim=1)

    return latents_condition
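

# Illustrative sketch (not part of the original module): the temporal folding
# in prepare_video_condition_wanx turns a (b, 1, t, h, w) pixel-space mask into
# a (b, 4, (t + 3) // 4, h // 8, w // 8) latent-space mask, matching the wanx
# VAE's 4x temporal and 8x spatial compression.
def _demo_wanx_mask_folding():
    bsz, num_frames, height, width = 1, 21, 64, 64
    mask = torch.ones(bsz, 1, num_frames, height, width)
    mask_lat = torch.cat([
        torch.repeat_interleave(mask[:, :, :1], dim=2, repeats=4),
        mask[:, :, 1:],
    ], dim=2)
    mask_lat = mask_lat[:, :, :, ::8, ::8]
    mask_lat = mask_lat.view(bsz, -1, 4, height // 8, width // 8).transpose(1, 2)
    assert mask_lat.shape == (1, 4, 6, 8, 8)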


def batch2list(batch):
    return [item for item in batch]


def list2batch(tensors):
    return torch.stack(tensors)


def stable_mse_loss(model_pred, target, weighting=None, threshold=50):
    if weighting is None:
        weighting = torch.ones_like(target)

    # Zero out elements whose residual exceeds the threshold so that outliers
    # do not dominate the loss.
    diff = model_pred - target
    mask = (diff.abs() <= threshold).float()
    loss = F.mse_loss(model_pred, target, reduction="none")
    masked_loss = weighting * mask * loss
    masked_loss = masked_loss.mean()

    return masked_loss
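

# Illustrative sketch (not part of the original module): with a huge outlier
# in the prediction, stable_mse_loss masks that element out, while plain MSE
# is dominated by it.
def _demo_stable_mse_loss():
    target = torch.zeros(4)
    pred = torch.tensor([0.1, -0.2, 0.3, 1000.0])
    plain = F.mse_loss(pred, target)                      # dominated by the outlier
    stable = stable_mse_loss(pred, target, threshold=50)  # outlier masked out
    assert stable < plain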