| | import argparse |
| | import os |
| | import glob |
| | from typing import Optional, Union |
| |
|
| | import numpy as np |
| | import torch |
| | from tqdm import tqdm |
| |
|
| | from dataset import config_utils |
| | from dataset.config_utils import BlueprintGenerator, ConfigSanitizer |
| | from PIL import Image |
| |
|
| | import logging |
| |
|
| | from dataset.image_video_dataset import ItemInfo, save_latent_cache_wan, ARCHITECTURE_WAN |
| | from utils.model_utils import str_to_dtype |
| | from wan.configs import wan_i2v_14B |
| | from wan.modules.vae import WanVAE |
| | from wan.modules.clip import CLIPModel |
| | import cache_latents |
| |
|
# Configure the root logger first so INFO-level messages from this script are shown,
# then create the module-level logger used by the functions below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
| |
|
| |
|
def _stack_to_vae_input(vae: WanVAE, arrays: list) -> torch.Tensor:
    """Stack per-item arrays into one normalized tensor on the VAE's device.

    Each array is converted with ``torch.from_numpy`` and stacked along a new
    batch axis. A 4-D stack gets a singleton axis inserted at position 1, then
    the last axis is moved to position 1, giving a ``B, C, F, H, W`` layout
    (assumes the inputs are channel-last ``H, W, C`` or ``F, H, W, C`` -- TODO confirm).
    Values are scaled with ``x / 127.5 - 1.0`` (presumably 0..255 -> [-1, 1]).
    """
    stacked = torch.stack([torch.from_numpy(array) for array in arrays])
    if stacked.ndim == 4:
        stacked = stacked.unsqueeze(1)
    stacked = stacked.permute(0, 4, 1, 2, 3).contiguous()
    stacked = stacked.to(vae.device, dtype=vae.dtype)
    return stacked / 127.5 - 1.0


def _encode_with_vae(vae: WanVAE, pixels: torch.Tensor) -> torch.Tensor:
    """Run ``vae.encode`` under autocast/no-grad and stack its per-item outputs along dim 0."""
    with torch.amp.autocast(device_type=vae.device.type, dtype=vae.dtype), torch.no_grad():
        encoded = vae.encode(pixels)
    return torch.stack(encoded, dim=0)


def encode_and_save_batch(vae: WanVAE, clip: Optional[CLIPModel], batch: list[ItemInfo]):
    """Encode one batch of items and write each item's latent cache.

    For every batch the item contents are encoded with ``vae``. When ``clip`` is
    given, the first frame of each item is additionally passed through
    ``clip.visual`` and a conditioning tensor ``y`` (frame mask concatenated with
    the VAE encoding of the zero-padded first frame) is built. When the first
    item carries ``control_content``, the control contents are VAE-encoded too.
    Results are handed item by item to ``save_latent_cache_wan``.

    Args:
        vae: VAE used for encoding; its ``device`` and ``dtype`` drive all casts.
        clip: Optional CLIP model; ``None`` disables the CLIP context and ``y``.
        batch: Items to encode. They must be stackable (same array shape).

    Raises:
        ValueError: If the height or width of the batch contents is below 8.
    """
    pixels = _stack_to_vae_input(vae, [entry.content for entry in batch])

    h, w = pixels.shape[3], pixels.shape[4]
    if h < 8 or w < 8:
        # Only the first item is reported; the rest are summarized by count.
        item = batch[0]
        raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")

    latent = _encode_with_vae(vae, pixels).to(vae.dtype)

    clip_context = None
    y = None
    if clip is not None:
        # First frame of every item, kept as a 1-frame clip (a view is sufficient here).
        first_frames = pixels[:, :, 0:1, :, :]

        with torch.amp.autocast(device_type=clip.device.type, dtype=torch.float16), torch.no_grad():
            clip_context = clip.visual(first_frames)
        clip_context = clip_context.to(torch.float16)

        batch_size, _, _, lat_h, lat_w = latent.shape
        num_frames = pixels.shape[2]

        # Frame mask: 1 for the first frame, 0 for all others. The first frame is
        # repeated 4 times so the frame axis (num_frames + 3) can be folded into
        # groups of 4; the view below requires that length to be a multiple of 4.
        mask = torch.ones(1, num_frames, lat_h, lat_w, dtype=vae.dtype, device=vae.device)
        mask[:, 1:] = 0
        leading = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1)
        mask = torch.concat([leading, mask[:, 1:]], dim=1)
        mask = mask.view(1, mask.shape[1] // 4, 4, lat_h, lat_w)
        mask = mask.transpose(1, 2)
        mask = mask.repeat(batch_size, 1, 1, 1, 1)

        # Pad the first frame with zero frames up to num_frames before encoding.
        # No dtype is passed, so the padding uses torch's default dtype.
        zero_tail = torch.zeros(batch_size, 3, num_frames - 1, h, w, device=vae.device)
        padded_frames = torch.concat([first_frames, zero_tail], dim=2)
        image_latent = _encode_with_vae(vae, padded_frames)

        image_latent = image_latent[:, :, :num_frames]
        image_latent = image_latent.to(vae.dtype)
        y = torch.concat([mask, image_latent], dim=1)

    # Control contents are optional; only the first item is inspected to decide.
    control_latent = None
    if batch[0].control_content is not None:
        control_pixels = _stack_to_vae_input(vae, [entry.control_content for entry in batch])
        control_latent = _encode_with_vae(vae, control_pixels).to(vae.dtype)

    for index, item in enumerate(batch):
        save_latent_cache_wan(
            item,
            latent[index],
            None if clip is None else clip_context[index],
            None if clip is None else y[index],
            None if control_latent is None else control_latent[index],
        )
| |
|
| |
|
def main(args):
    """Build the dataset group from the config and cache latents for every dataset.

    Reads these attributes from ``args``: ``device``, ``dataset_config``,
    ``debug_mode``, ``console_width``, ``console_back``, ``console_num_images``,
    ``vae``, ``vae_dtype``, ``vae_cache_cpu`` and ``clip``. ``args`` is also
    forwarded to the blueprint generator and to ``cache_latents.encode_datasets``.

    When ``args.debug_mode`` is set, the datasets are only displayed via
    ``cache_latents.show_datasets`` and no model is loaded.

    Raises:
        ValueError: If ``args.vae`` is ``None`` when encoding is requested.
    """
    # Explicit device wins; otherwise prefer CUDA when it is available.
    if args.device is not None:
        device_name = args.device
    elif torch.cuda.is_available():
        device_name = "cuda"
    else:
        device_name = "cpu"
    device = torch.device(device_name)

    # Load the user's dataset config and turn it into concrete datasets.
    blueprint_generator = BlueprintGenerator(ConfigSanitizer())
    logger.info(f"Load dataset config from {args.dataset_config}")
    user_config = config_utils.load_user_config(args.dataset_config)
    blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_WAN)
    train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    datasets = train_dataset_group.datasets

    if args.debug_mode is not None:
        cache_latents.show_datasets(
            datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images, fps=16
        )
        return

    # Validate with an explicit exception: an ``assert`` would be stripped under
    # ``python -O`` and let a missing checkpoint path reach the model loader.
    if args.vae is None:
        raise ValueError("vae checkpoint is required")

    vae_path = args.vae

    logger.info(f"Loading VAE model from {vae_path}")
    vae_dtype = torch.bfloat16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
    cache_device = torch.device("cpu") if args.vae_cache_cpu else None
    vae = WanVAE(vae_path=vae_path, device=device, dtype=vae_dtype, cache_device=cache_device)

    # The CLIP model is only loaded when a checkpoint path was supplied.
    if args.clip is not None:
        clip_dtype = wan_i2v_14B.i2v_14B["clip_dtype"]
        clip = CLIPModel(dtype=clip_dtype, device=device, weight_path=args.clip)
    else:
        clip = None

    # Bind the loaded models so encode_datasets only has to pass the batch.
    def encode(one_batch: list[ItemInfo]):
        encode_and_save_batch(vae, clip, one_batch)

    cache_latents.encode_datasets(datasets, encode, args)
| |
|
| |
|
def wan_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register the WAN-specific options on ``parser`` and return the same parser.

    Adds ``--vae_cache_cpu`` (boolean flag, off by default) and ``--clip``
    (string path, ``None`` by default).
    """
    clip_help = "text encoder (CLIP) checkpoint path, optional. If training I2V model, this is required"

    parser.add_argument("--vae_cache_cpu", help="cache features in VAE on CPU", action="store_true")
    parser.add_argument("--clip", default=None, type=str, help=clip_help)
    return parser
| |
|
| |
|
if __name__ == "__main__":
    # Start from the shared caching options, add the WAN-specific ones,
    # then parse the command line and run.
    cli_parser = wan_setup_parser(cache_latents.setup_parser_common())
    main(cli_parser.parse_args())
| |
|