import torch
import torch.nn as nn
from einops import rearrange
from transformers import get_scheduler

from .modules.clip import clip_xlm_roberta_vit_h_14
from .wan_t2v import WanTextToVideo


class WanImageToVideo(WanTextToVideo):
    """Image-to-video variant of Wan, inheriting the text-to-video pipeline."""

    def __init__(self, cfg):
        super().__init__(cfg)
        # Input channels: z_dim for the noisy video latent, plus z_dim for the
        # first-frame condition latent and 4 mask channels (built in prepare_embeds).
        self.cfg.model.in_dim = self.cfg.vae.z_dim * 2 + 4
    def configure_model(self):
        # Set up the diffusion model and VAE in the parent class first.
        super().configure_model()
        if self.cfg.model.tuned_ckpt_path is None:
            # No tuned checkpoint to load: patch the pretrained input-embedding
            # weights to accept the widened in_dim (see the "ckpt hack" note in
            # prepare_embeds).
            self.model.hack_embedding_ckpt()
        # Additionally initialize CLIP for image encoding.
        clip, clip_transform = clip_xlm_roberta_vit_h_14(
            pretrained=False,
            return_transforms=True,
            return_tokenizer=False,
            dtype=torch.float16 if self.is_inference else self.dtype,
            device="cpu",
        )
        if self.cfg.clip.ckpt_path is not None:
            clip.load_state_dict(
                torch.load(
                    self.cfg.clip.ckpt_path, map_location="cpu", weights_only=True
                )
            )
        if self.cfg.clip.compile:
            clip = torch.compile(clip)
        self.clip = clip
        # The last transform in the CLIP preprocessing pipeline is the Normalize
        # step; resizing is handled manually in clip_features below.
        self.clip_normalize = clip_transform.transforms[-1]
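    # A hypothetical cfg.clip for the code above (field names follow their usage
    # here; the path and values are illustrative, not the project's actual settings):
    #
    #   clip:
    #     ckpt_path: /path/to/clip_xlm_roberta_vit_h_14.pth
    #     compile: false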
    def configure_optimizers(self):
        # VAE and CLIP parameters are registered with lr=0, so they stay
        # effectively frozen while remaining visible to the optimizer.
        optimizer = torch.optim.AdamW(
            [
                {"params": self.model.parameters(), "lr": self.cfg.lr},
                {"params": self.vae.parameters(), "lr": 0},
                {"params": self.clip.parameters(), "lr": 0},
            ],
            weight_decay=self.cfg.weight_decay,
            betas=self.cfg.betas,
        )
        lr_scheduler_config = {
            "scheduler": get_scheduler(
                optimizer=optimizer,
                **self.cfg.lr_scheduler,
            ),
            "interval": "step",
            "frequency": 1,
        }
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler_config,
        }
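    # A hypothetical cfg.lr_scheduler that the get_scheduler call above would
    # accept; the keys mirror transformers.get_scheduler's keyword arguments
    # (values are examples, not the project's actual settings):
    #
    #   lr_scheduler:
    #     name: cosine
    #     num_warmup_steps: 1000
    #     num_training_steps: 100000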
    def clip_features(self, videos):
        # Resize every frame to the CLIP input resolution, remap pixel values
        # from [-1, 1] to [0, 1], then apply CLIP's own normalization.
        size = (self.clip.image_size,) * 2
        videos = rearrange(videos, "b t c h w -> (b t) c h w")
        videos = nn.functional.interpolate(
            videos, size=size, mode="bicubic", align_corners=False
        )
        videos = self.clip_normalize(videos.mul_(0.5).add_(0.5))
        return self.clip.visual(videos, use_31_block=True)
    def prepare_embeds(self, batch):
        batch = super().prepare_embeds(batch)
        videos = batch["videos"]
        images = videos[:, :1]  # condition on the first frame
        has_bbox = batch["has_bbox"]  # [B, 2]
        bbox_render = batch["bbox_render"]  # [B, 2, H, W]
        batch_size, t, _, h, w = videos.shape
        lat_c, lat_t, lat_h, lat_w = self.lat_c, self.lat_t, self.lat_h, self.lat_w
        clip_embeds = self.clip_features(images)
        batch["clip_embeds"] = clip_embeds
        mask = torch.zeros(
            batch_size,
            self.vae_stride[0],
            lat_t,
            lat_h,
            lat_w,
            device=self.device,
            dtype=self.dtype,
        )
        # After the ckpt hack, the 4 mask channels are repurposed for bounding-box
        # conditioning. The second-to-last channel is the bbox indicator.
        mask[:, 2, 0] = has_bbox[..., 0, None, None]
        mask[:, 2, -1] = has_bbox[..., -1, None, None]
        # Interpolate bbox_render to match the latent spatial dimensions.
        bbox_render_resized = nn.functional.interpolate(
            bbox_render,
            size=(lat_h, lat_w),
            mode="bicubic",
            align_corners=False,
        )
        # The last channel is the rendered bbox (first and last latent frames).
        mask[:, 3, 0] = bbox_render_resized[:, 0]
        mask[:, 3, -1] = bbox_render_resized[:, -1]
        if self.diffusion_forcing.enabled:
            # With diffusion forcing, no first-frame conditioning is injected
            # here; the conditioning slot is filled with zeros.
            image_embeds = torch.zeros(
                batch_size,
                4 + lat_c,
                lat_t,
                lat_h,
                lat_w,
                device=self.device,
                dtype=self.dtype,
            )
        else:
            # Pad the conditioning image with zeros along time so the VAE sees a
            # full-length clip, then encode it.
            padded_images = torch.zeros(batch_size, 3, t - 1, h, w, device=self.device)
            padded_images = torch.cat(
                [rearrange(images, "b 1 c h w -> b c 1 h w"), padded_images], dim=2
            )
            # The first two mask channels flag the first frame as observed; set
            # them before concatenation so the flags actually reach image_embeds.
            mask[:, :2, 0] = 1
            image_embeds = self.encode_video(
                padded_images
            )  # b, lat_c, lat_t, lat_h, lat_w
            image_embeds = torch.cat([mask, image_embeds], 1)
        batch["image_embeds"] = image_embeds
        return batch
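    # Shape sketch of the tensors assembled above (illustrative; lat_* come from
    # the parent class, lat_c equals cfg.vae.z_dim, and vae_stride[0] is assumed
    # to be 4 to match the `4 + lat_c` branch):
    #   mask:         [B, 4, lat_t, lat_h, lat_w]
    #   image_embeds: [B, 4 + lat_c, lat_t, lat_h, lat_w]
    #   clip_embeds:  CLIP visual features of the first frame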
    def visualize(self, video_pred, batch):
        bbox_render = batch["bbox_render"]  # b, 2, h, w for first and last frames
        has_bbox = batch["has_bbox"]  # b, 2 for first and last frames
        video_gt = batch["videos"]  # b, t, 3, h, w
        alpha = 0.4
        l = video_gt.shape[1] // 4  # overlay spans the first/last quarter of the clip
        # Alpha-blend a green bbox overlay onto the first frames of samples that
        # have a first-frame bbox.
        mask = has_bbox[:, 0].bool()
        if mask.any():
            green = torch.zeros_like(video_gt[mask, :1])
            green[:, :, 1] = 1.0
            bbox = bbox_render[:, None, 0:1][mask] * alpha  # b', 1, 1, h, w
            video_gt[mask, :l] = (1 - bbox) * video_gt[mask, :l] + bbox * green
        # Likewise for the last frames of samples that have a last-frame bbox.
        mask = has_bbox[:, 1].bool()
        if mask.any():
            green = torch.zeros_like(video_gt[mask, :1])
            green[:, :, 1] = 1.0
            bbox = bbox_render[:, None, 1:2][mask] * alpha  # b', 1, 1, h, w
            video_gt[mask, -l:] = (1 - bbox) * video_gt[mask, -l:] + bbox * green
        batch["videos"] = video_gt
        return super().visualize(video_pred, batch)
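

# Minimal standalone sketch of the alpha-blend used in visualize() above. Shapes
# and values are illustrative only, not part of the training pipeline.
def _overlay_bbox_demo():
    frames = torch.rand(2, 4, 3, 8, 8)  # b', l, c, h, w in [0, 1]
    # Soft bbox mask scaled by alpha = 0.4, broadcast over time and channels.
    bbox = torch.randint(0, 2, (2, 1, 1, 8, 8)).float() * 0.4
    green = torch.zeros_like(frames[:, :1])
    green[:, :, 1] = 1.0  # pure green in the G channel
    # Per-pixel convex combination: where bbox > 0, frames are tinted green.
    blended = (1 - bbox) * frames + bbox * green
    assert blended.shape == frames.shape
    return blended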