# HY-Video-PRFL/diffusers_lite/utils/diffusion_utils.py
import os
import torch
import torch.amp as amp
import torch.nn.functional as F
from einops import rearrange
from safetensors.torch import load_file
# tensor
def expand_tensor_dims(tensor, ndim):
while len(tensor.shape) < ndim:
tensor = tensor.unsqueeze(-1)
return tensor
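# Usage sketch (illustrative values, not part of the original module):
# broadcasting a per-sample weight over [B, C, F, H, W] latents.
#   w = expand_tensor_dims(torch.rand(4), ndim=5)  # shape [4, 1, 1, 1, 1]
#   weighted = w * torch.randn(4, 16, 8, 32, 32)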
# vae
def vae_encode(vae, images, dtype=torch.bfloat16, vae_type="wanx"):
if vae_type in ["wanx"]:
images = batch2list(images)
latents = vae.encode(images)
latents = list2batch(latents)
elif vae_type in ["ltx"]:
with amp.autocast("cuda", dtype=dtype):
latents = vae.encode(images).latent_dist.sample()
latents_mean = vae.latents_mean
latents_std = vae.latents_std
scaling_factor = 1.0
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = (latents - latents_mean) * scaling_factor / latents_std
return latents
def vae_decode(vae, latents, dtype=torch.bfloat16, vae_type="wanx"):
if vae_type in ["wanx"]:
latents = batch2list(latents)
images = vae.decode(latents)
images = list2batch(images)
elif vae_type in ["ltx"]:
latents_mean = vae.latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = vae.latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
scaling_factor = 1.0
latents = latents * latents_std / scaling_factor + latents_mean
with amp.autocast("cuda", dtype=dtype):
images = vae.decode(latents, return_dict=False)[0]
return images
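# Hedged usage sketch for the encode/decode pair (assumes a hypothetical `vae`
# object with the interface used above; not part of the original module):
#   latents = vae_encode(vae, videos, vae_type="wanx")   # videos: [B, C, F, H, W]
#   videos_rec = vae_decode(vae, latents, vae_type="wanx")
# On the "ltx" path, latents are normalized with the VAE's stored mean/std on
# encode and denormalized on decode, so the two calls are inverses up to the
# autoencoder's reconstruction error.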
def image_encode(
image_encoder,
image,
last_image=None,
image_encoder_type="wanx"
):
if image_encoder_type in ["wanx"]:
if image.ndim == 5:
image = image[:,:,0]
image = rearrange(image, "b c h w -> c b h w")
if last_image is not None:
if last_image.ndim == 5:
last_image = last_image[:,:,0]
last_image = rearrange(last_image, "b c h w -> c b h w")
image_embeds = image_encoder.visual([image, last_image])
else:
image_embeds = image_encoder.visual([image])
return image_embeds
def pack_latents(latents, patch_size=1, patch_size_t=1):
    # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
# The patch dimensions are then permuted and collapsed into the channel dimension of shape:
# [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
# dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
batch_size, num_channels, num_frames, height, width = latents.shape
post_patch_num_frames = num_frames // patch_size_t
post_patch_height = height // patch_size
post_patch_width = width // patch_size
latents = latents.reshape(
batch_size,
-1,
post_patch_num_frames,
patch_size_t,
post_patch_height,
patch_size,
post_patch_width,
patch_size,
)
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
latents = latents.contiguous()
return latents
def unpack_latents(latents, num_frames, height, width, patch_size=1, patch_size_t=1):
    # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature
    # dimension) are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse
    # of `pack_latents` above; `num_frames`, `height`, and `width` are the post-patch sizes.
batch_size = latents.size(0)
latents = latents.reshape(
batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size
)
latents = (
latents.permute(0, 4, 1, 5, 2, 6, 3, 7)
.flatten(6, 7)
.flatten(4, 5)
.flatten(2, 3)
)
latents = latents.contiguous()
return latents
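# A minimal self-check sketch (not part of the original module) showing that
# unpack_latents inverts pack_latents when given the post-patch frame/height/
# width counts; all shapes are illustrative.
def _example_pack_unpack_roundtrip():
    latents = torch.randn(2, 16, 4, 8, 8)  # [B, C, F, H, W]
    tokens = pack_latents(latents, patch_size=2, patch_size_t=2)
    assert tokens.shape == (2, 2 * 4 * 4, 16 * 2 * 2 * 2)  # [B, S, D]
    restored = unpack_latents(tokens, num_frames=2, height=4, width=4, patch_size=2, patch_size_t=2)
    assert torch.equal(restored, latents)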
# text encoder
def prompt2states(
prompt,
text_encoder,
device="cuda:0",
tokenizer=None,
max_length=128,
text_encoder_type="wanx",
):
if isinstance(prompt, str):
prompt = [prompt]
if text_encoder_type in ["wanx"]:
text_states = text_encoder(prompt, device)[0]
text_states = text_states.unsqueeze(0)
return text_states
elif text_encoder_type in ["ltx"]:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_ids = text_inputs.input_ids.to(device)
text_mask = text_inputs.attention_mask
text_mask = text_mask.bool().to(device)
text_states = text_encoder(text_ids)[0]
return text_states, text_mask
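# Hedged usage sketch (hypothetical encoder objects; not part of the original
# module). The "wanx" path returns hidden states only; the "ltx" path returns
# (hidden states, attention mask):
#   states = prompt2states("a cat", wanx_text_encoder, device="cuda:0")
#   states, mask = prompt2states("a cat", t5_encoder, tokenizer=t5_tokenizer,
#                                max_length=128, text_encoder_type="ltx")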
def load_lora_for_pipeline(
pipeline,
lora_path,
LORA_PREFIX_TRANSFORMER="",
LORA_PREFIX_TEXT_ENCODER="",
alpha=1.0,
rank=0,
):
# load LoRA weight from .safetensors
state_dict = load_file(lora_path, device=rank)
visited = []
# directly update weight in diffusers model
for key in state_dict:
        # It is suggested to print the key; it usually looks like
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight".
        # Alpha is supplied as an argument, so skip stored alpha entries.
if "alpha" in key or key in visited:
continue
if "text" in key:
layer_infos = (
key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
)
curr_layer = pipeline.text_encoder
else:
layer_infos = (
key.split(".")[0].split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
)
curr_layer = pipeline.transformer
        # Find the target layer: LoRA keys join submodule names with "_", so
        # greedily extend the candidate name until getattr succeeds.
        temp_name = layer_infos.pop(0)
        while True:  # exits via break once the target submodule is resolved
try:
curr_layer = curr_layer.__getattr__(temp_name)
if len(layer_infos) > 0:
temp_name = layer_infos.pop(0)
elif len(layer_infos) == 0:
break
except Exception:
if len(temp_name) > 0:
temp_name += "_" + layer_infos.pop(0)
else:
temp_name = layer_infos.pop(0)
pair_keys = []
if "lora_down" in key:
pair_keys.append(key.replace("lora_down", "lora_up"))
pair_keys.append(key)
else:
pair_keys.append(key)
pair_keys.append(key.replace("lora_up", "lora_down"))
        # update weight; cast the LoRA delta to the layer's device/dtype so the
        # in-place add also works when the model is not fp32 or lives on a
        # different device than the state dict
        if len(state_dict[pair_keys[0]].shape) == 4:
            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
            weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
            delta = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
        else:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            delta = alpha * torch.mm(weight_up, weight_down)
        curr_layer.weight.data += delta.to(curr_layer.weight.device, curr_layer.weight.dtype)
# update visited list
for item in pair_keys:
visited.append(item)
del state_dict
return pipeline
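# Hedged usage sketch (hypothetical pipeline and path; not part of the original
# module). LoRA deltas are merged into the base weights in place, so the call
# mutates the pipeline rather than returning a copy:
#   pipe = load_lora_for_pipeline(pipe, "/path/to/lora.safetensors", alpha=0.8)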
def load_lora_for_model(
model,
lora_path,
LORA_PREFIX_TRANSFORMER="",
LORA_PREFIX_TEXT_ENCODER="",
alpha=1.0,
rank=0,
):
# load LoRA weight from .safetensors
state_dict = load_file(lora_path, device="cpu")
visited = []
# directly update weight in diffusers model
for key in state_dict:
        # It is suggested to print the key; it usually looks like
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight".
        # Alpha is supplied as an argument, so skip stored alpha entries.
if "alpha" in key or key in visited:
continue
layer_infos = (
key.split(".")[0].split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
)
curr_layer = model
        # Find the target layer: LoRA keys join submodule names with "_", so
        # greedily extend the candidate name until getattr succeeds.
        temp_name = layer_infos.pop(0)
        while True:  # exits via break once the target submodule is resolved
try:
curr_layer = curr_layer.__getattr__(temp_name)
if len(layer_infos) > 0:
temp_name = layer_infos.pop(0)
elif len(layer_infos) == 0:
break
except Exception:
if len(temp_name) > 0:
temp_name += "_" + layer_infos.pop(0)
else:
temp_name = layer_infos.pop(0)
pair_keys = []
if "lora_down" in key:
pair_keys.append(key.replace("lora_down", "lora_up"))
pair_keys.append(key)
else:
pair_keys.append(key)
pair_keys.append(key.replace("lora_up", "lora_down"))
        # update weight; cast the LoRA delta to the layer's device/dtype so the
        # in-place add also works when the model is not fp32 or lives on a
        # different device than the state dict (which is loaded on CPU here)
        if len(state_dict[pair_keys[0]].shape) == 4:
            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
            weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
            delta = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
        else:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            delta = alpha * torch.mm(weight_up, weight_down)
        curr_layer.weight.data += delta.to(curr_layer.weight.device, curr_layer.weight.dtype)
# update visited list
for item in pair_keys:
visited.append(item)
del state_dict
return model
def load_lora_state_dict(lora_dir):
lora_path = os.path.join(lora_dir, 'pytorch_lora_transformers_weights.safetensors')
lora_weights = load_file(lora_path)
load_lora_weights = {}
for key in lora_weights:
load_lora_weights[key.replace('.weight','.default.weight')] = lora_weights[key]
return load_lora_weights
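# Hedged usage sketch (hypothetical directory; not part of the original
# module). The ".weight" -> ".default.weight" renaming targets PEFT-style
# modules, which nest the active adapter's parameters under an adapter name
# ("default" here), so the dict can be loaded into a PEFT-wrapped transformer:
#   lora_sd = load_lora_state_dict("/path/to/lora_dir")
#   transformer.load_state_dict(lora_sd, strict=False)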
def transformer_zero_init(transformer):
for p in transformer.parameters():
if p.dim() > 1:
torch.nn.init.zeros_(p.data)
else:
torch.nn.init.normal_(p.data)
return transformer
def prepare_video_condition_wanx(
    vae,
    video,
    mask_strategy=(0.4, 0.25, 0.3, 0.05),
):
    # Sample one of four masking strategies with the given probabilities:
    # 0) first frame only, 1) first half, 2) first + last frame, 3) random subset.
    mask_id = torch.multinomial(torch.tensor(mask_strategy), num_samples=1).item()
bsz, _, num_frames, height, width = video.shape
latents_height, latents_width = height // 8, width // 8
# Get video mask
if mask_id == 0:
mask = torch.cat([
torch.ones(bsz, 1, 1, height, width),
torch.zeros(bsz, 1, num_frames-1, height, width)
], dim=2)
elif mask_id == 1:
mid_frame = (num_frames - 1) // 2 + 1
mask = torch.cat([
torch.ones(bsz, 1, mid_frame, height, width),
torch.zeros(bsz, 1, num_frames-mid_frame, height, width)
], dim=2)
elif mask_id == 2:
mask = torch.cat([
torch.ones(bsz, 1, 1, height, width),
torch.zeros(bsz, 1, num_frames-2, height, width),
torch.ones(bsz, 1, 1, height, width)
], dim=2)
elif mask_id == 3:
        num_masked = torch.randint(1, num_frames, (1,)).item()  # single count, shared across the batch
indices = torch.randperm(num_frames)[:num_masked].sort().values
mask = torch.zeros(bsz, 1, num_frames, height, width)
mask[:,:, indices] = 1
# Encode video mask
mask = mask.to(video.device, dtype=video.dtype)
mask_lat_size = torch.cat([
torch.repeat_interleave(mask[:,:,:1,:,:], dim=2, repeats=4),
mask[:,:,1:,:,:],
], dim=2)
mask_lat_size = mask_lat_size[:,:,:,::8,::8]
mask_lat_size = mask_lat_size.view(bsz, -1, 4, latents_height, latents_width).transpose(1,2)
# Encode video condition
video_condition = video * mask
latents_condition = torch.cat([
mask_lat_size,
        vae_encode(vae, video_condition, vae_type="wanx")
], dim=1)
return latents_condition
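# Hedged usage sketch (hypothetical `vae`; illustrative shapes, not part of the
# original module). The first latent frame covers 4 pixel frames, hence the
# repeat_interleave above; the result stacks 4 mask channels in front of the
# VAE latent channels along dim=1:
#   video = torch.randn(1, 3, 81, 480, 832)  # [B, C, F, H, W]
#   cond = prepare_video_condition_wanx(vae, video)  # [1, 4 + C_lat, F_lat, 60, 104]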
def batch2list(batch):
return [item for item in batch]
def list2batch(items):
    return torch.stack(items)
def stable_mse_loss(model_pred, target, weighting=None, threshold=50):
if weighting is None:
weighting = torch.ones_like(target)
diff = model_pred - target
mask = (diff.abs() <= threshold).float()
loss = F.mse_loss(model_pred, target, reduction="none")
masked_loss = weighting * mask * loss
masked_loss = masked_loss.mean()
return masked_loss
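# A minimal self-check sketch (not part of the original module): residuals with
# |pred - target| > threshold are zeroed before averaging, so a single outlier
# cannot blow up the loss; note the mean still divides by the full element
# count, masked entries included.
def _example_stable_mse_loss():
    pred = torch.tensor([0.0, 1.0, 100.0])
    target = torch.zeros(3)
    loss = stable_mse_loss(pred, target, threshold=50)
    assert torch.isclose(loss, torch.tensor(1.0 / 3.0))  # mean of [0, 1, 0]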