# LVP/algorithms/wan/wan_i2v.py
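"""Wan image-to-video algorithm: extends the text-to-video model with
first-frame conditioning (CLIP features and VAE latents) and optional
bounding-box conditioning carried in the latent mask channels."""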
import torch
import torch.nn as nn
from einops import rearrange
from transformers import get_scheduler

from .modules.clip import clip_xlm_roberta_vit_h_14
from .wan_t2v import WanTextToVideo


class WanImageToVideo(WanTextToVideo):
    """
    Image-to-video variant of WanTextToVideo: generation is conditioned on
    the first frame and, optionally, on first/last-frame bounding boxes.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
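        # The model consumes z_dim noisy video latents concatenated with z_dim
        # conditioning latents plus 4 mask channels (built in prepare_embeds)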
        self.cfg.model.in_dim = self.cfg.vae.z_dim * 2 + 4

    def configure_model(self):
        # Call parent's configure_model first
        super().configure_model()
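        # Without a tuned checkpoint, adapt the pretrained embedding weights to
        # the enlarged input dimension (the "ckpt hack" that prepare_embeds
        # relies on when repurposing the mask channels)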
        if self.cfg.model.tuned_ckpt_path is None:
            self.model.hack_embedding_ckpt()

        # Additionally initialize CLIP for image encoding
        clip, clip_transform = clip_xlm_roberta_vit_h_14(
            pretrained=False,
            return_transforms=True,
            return_tokenizer=False,
            dtype=torch.float16 if self.is_inference else self.dtype,
            device="cpu",
        )
        if self.cfg.clip.ckpt_path is not None:
            clip.load_state_dict(
                torch.load(
                    self.cfg.clip.ckpt_path, map_location="cpu", weights_only=True
                )
            )
        if self.cfg.clip.compile:
            clip = torch.compile(clip)
        self.clip = clip
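        # Keep only the final Normalize step of the CLIP preprocessing;
        # resizing and [-1, 1] -> [0, 1] mapping are done in clip_features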
        self.clip_normalize = clip_transform.transforms[-1]

    def configure_optimizers(self):
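        # Only the diffusion model is trained; the VAE and CLIP parameter
        # groups use a zero learning rate, which keeps them frozen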
        optimizer = torch.optim.AdamW(
            [
                {"params": self.model.parameters(), "lr": self.cfg.lr},
                {"params": self.vae.parameters(), "lr": 0},
                {"params": self.clip.parameters(), "lr": 0},
            ],
            weight_decay=self.cfg.weight_decay,
            betas=self.cfg.betas,
        )
        lr_scheduler_config = {
            "scheduler": get_scheduler(
                optimizer=optimizer,
                **self.cfg.lr_scheduler,
            ),
            "interval": "step",
            "frequency": 1,
        }
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler_config,
        }

    def clip_features(self, videos):
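        """Extract CLIP visual features from videos of shape (b, t, c, h, w).

        Frames are assumed to be in [-1, 1]; they are resized to the CLIP
        input resolution and renormalized before the forward pass.
        """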
        size = (self.clip.image_size,) * 2
        videos = rearrange(videos, "b t c h w -> (b t) c h w")
        videos = nn.functional.interpolate(
            videos, size=size, mode="bicubic", align_corners=False
        )
        videos = self.clip_normalize(videos.mul_(0.5).add_(0.5))
        return self.clip.visual(videos, use_31_block=True)

    @torch.no_grad()
    def prepare_embeds(self, batch):
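        """Add image-to-video conditioning to the batch.

        On top of the parent's embeddings, this computes first-frame CLIP
        features ("clip_embeds") and a (4 + lat_c)-channel conditioning tensor
        ("image_embeds") holding the mask channels and first-frame VAE latents.
        """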
        batch = super().prepare_embeds(batch)
        videos = batch["videos"]
        images = videos[:, :1]
        has_bbox = batch["has_bbox"]  # [B, 2]
        bbox_render = batch["bbox_render"]  # [B, 2, H, W]
        batch_size, t, _, h, w = videos.shape
        lat_c, lat_t, lat_h, lat_w = self.lat_c, self.lat_t, self.lat_h, self.lat_w
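        # CLIP features of the first frame serve as image conditioning for the
        # diffusion model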
        clip_embeds = self.clip_features(images)
        batch["clip_embeds"] = clip_embeds

        mask = torch.zeros(
            batch_size,
            self.vae_stride[0],
            lat_t,
            lat_h,
            lat_w,
            device=self.device,
            dtype=self.dtype,
        )
        # After the ckpt hack, we repurpose the 4 mask channels for bounding
        # box conditioning: the second-to-last channel indicates bbox presence
        mask[:, 2, 0] = has_bbox[..., 0, None, None]
        mask[:, 2, -1] = has_bbox[..., -1, None, None]
        # Interpolate bbox_render to match the latent spatial dimensions
        bbox_render_resized = nn.functional.interpolate(
            bbox_render,
            size=(lat_h, lat_w),
            mode="bicubic",
            align_corners=False,
        )
        # The last channel holds the rendered bbox
        mask[:, 3, 0] = bbox_render_resized[:, 0]
        mask[:, 3, -1] = bbox_render_resized[:, -1]
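        # With diffusion forcing enabled, first-frame and bbox conditioning are
        # not injected here: the conditioning tensor is left all zeros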
        if self.diffusion_forcing.enabled:
            image_embeds = torch.zeros(
                batch_size,
                4 + lat_c,
                lat_t,
                lat_h,
                lat_w,
                device=self.device,
                dtype=self.dtype,
            )
        else:
            padded_images = torch.zeros(batch_size, 3, t - 1, h, w, device=self.device)
            padded_images = torch.cat(
                [rearrange(images, "b 1 c h w -> b c 1 h w"), padded_images], dim=2
            )
            image_embeds = self.encode_video(
                padded_images
            )  # b, lat_c, lat_t, lat_h, lat_w
            # Mark the first latent frame as conditioned; this must happen
            # before the concatenation below, otherwise it never reaches
            # image_embeds (torch.cat copies the data)
            mask[:, :2, 0] = 1
            image_embeds = torch.cat([mask, image_embeds], 1)
        batch["image_embeds"] = image_embeds
        return batch

    def visualize(self, video_pred, batch):
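        """Overlay translucent green boxes on the ground-truth video wherever
        bbox conditioning is active, then defer to the parent visualizer."""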
        bbox_render = batch["bbox_render"]  # b, 2, h, w (first and last frame)
        has_bbox = batch["has_bbox"]  # b, 2 (first and last frame)
        video_gt = batch["videos"]  # b, t, 3, h, w
        alpha = 0.4
        n_overlay = video_gt.shape[1] // 4
        # Overlay a translucent green bbox on the first quarter of frames where
        # the first frame has a bbox
        mask = has_bbox[:, 0].bool()
        if mask.any():
            green = torch.zeros_like(video_gt[mask, :1])
            green[:, :, 1] = 1.0
            bbox = bbox_render[:, None, 0:1][mask] * alpha  # b', 1, 1, h, w
            video_gt[mask, :n_overlay] = (1 - bbox) * video_gt[mask, :n_overlay] + bbox * green
        # Overlay a translucent green bbox on the last quarter of frames where
        # the last frame has a bbox
        mask = has_bbox[:, 1].bool()
        if mask.any():
            green = torch.zeros_like(video_gt[mask, :1])
            green[:, :, 1] = 1.0
            bbox = bbox_render[:, None, 1:2][mask] * alpha  # b', 1, 1, h, w
            video_gt[mask, -n_overlay:] = (1 - bbox) * video_gt[mask, -n_overlay:] + bbox * green
        batch["videos"] = video_gt
        return super().visualize(video_pred, batch)
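

# Minimal usage sketch (hypothetical: the real entry point and config schema
# live elsewhere in this repo, so the path and field names below are
# assumptions, not this module's API):
#
#   from omegaconf import OmegaConf
#
#   cfg = OmegaConf.load("configurations/wan_i2v.yaml")  # hypothetical path
#   algo = WanImageToVideo(cfg)
#   algo.configure_model()
#   batch = algo.prepare_embeds(batch)  # adds "clip_embeds" / "image_embeds"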