import argparse
from typing import Optional
from PIL import Image

import numpy as np
import torch
import torchvision.transforms.functional as TF
from tqdm import tqdm
from accelerate import Accelerator, init_empty_weights

from dataset.image_video_dataset import ARCHITECTURE_WAN, ARCHITECTURE_WAN_FULL, load_video
from hv_generate_video import resize_image_to_bucket
from hv_train_network import NetworkTrainer, load_prompts, clean_memory_on_device, setup_parser_common, read_config_from_file

import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

from utils import model_utils
from utils.safetensors_utils import load_safetensors, MemoryEfficientSafeOpen
from wan.configs import WAN_CONFIGS
from wan.modules.clip import CLIPModel
from wan.modules.model import WanModel, detect_wan_sd_dtype, load_wan_model
from wan.modules.t5 import T5EncoderModel
from wan.modules.vae import WanVAE
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler


class WanNetworkTrainer(NetworkTrainer):
    def __init__(self):
        super().__init__()

    @property
    def architecture(self) -> str:
        return ARCHITECTURE_WAN

    @property
    def architecture_full_name(self) -> str:
        return ARCHITECTURE_WAN_FULL

    def handle_model_specific_args(self, args):
        self.config = WAN_CONFIGS[args.task]
        self._i2v_training = "i2v" in args.task
        self._control_training = self.config.is_fun_control

        self.dit_dtype = detect_wan_sd_dtype(args.dit)

        if self.dit_dtype == torch.float16:
            assert args.mixed_precision in ["fp16", "no"], "DiT weights are in fp16, mixed precision must be fp16 or no"
        elif self.dit_dtype == torch.bfloat16:
            assert args.mixed_precision in ["bf16", "no"], "DiT weights are in bf16, mixed precision must be bf16 or no"

        if args.fp8_scaled and self.dit_dtype.itemsize == 1:
            raise ValueError(
                "DiT weights are already in fp8 format, cannot scale to fp8. Please use fp16/bf16 weights / DiTの重みはすでにfp8形式です。fp8にスケーリングできません。fp16/bf16の重みを使用してください"
            )

        # if the DiT weights are fp8, compute in the dtype selected by mixed precision
        if self.dit_dtype.itemsize == 1:
            self.dit_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16

        args.dit_dtype = model_utils.dtype_to_str(self.dit_dtype)

        self.default_guidance_scale = 1.0

    def process_sample_prompts(
        self,
        args: argparse.Namespace,
        accelerator: Accelerator,
        sample_prompts: str,
    ):
        config = self.config
        device = accelerator.device
        t5_path, clip_path, fp8_t5 = args.t5, args.clip, args.fp8_t5

        logger.info(f"caching Text Encoder outputs for sample prompts: {sample_prompts}")
        prompts = load_prompts(sample_prompts)

        def encode_for_text_encoder(text_encoder):
            sample_prompts_te_outputs = {}  # prompt -> Text Encoder outputs
            t5_dtype = config.t5_dtype
            with torch.amp.autocast(device_type=device.type, dtype=t5_dtype), torch.no_grad():
                for prompt_dict in prompts:
                    if "negative_prompt" not in prompt_dict:
                        prompt_dict["negative_prompt"] = self.config["sample_neg_prompt"]
                    for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", None)]:
                        if p is None:
                            continue
                        if p not in sample_prompts_te_outputs:
                            logger.info(f"caching Text Encoder outputs for prompt: {p}")
                            prompt_outputs = text_encoder([p], device)
                            sample_prompts_te_outputs[p] = prompt_outputs

            return sample_prompts_te_outputs

        # load T5, encode all sample prompts, then free it
        logger.info(f"loading T5: {t5_path}")
        t5 = T5EncoderModel(text_len=config.text_len, dtype=config.t5_dtype, device=device, weight_path=t5_path, fp8=fp8_t5)

        logger.info("encoding with Text Encoder 1")
        te_outputs_1 = encode_for_text_encoder(t5)
        del t5

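        # for I2V training, cache CLIP image embeddings for the sample images referenced by the prompts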
        sample_prompts_image_embs = {}
        for prompt_dict in prompts:
            if prompt_dict.get("image_path", None) is not None and self.i2v_training:
                sample_prompts_image_embs[prompt_dict["image_path"]] = None

        if len(sample_prompts_image_embs) > 0:
            logger.info(f"loading CLIP: {clip_path}")
            assert clip_path is not None, "CLIP path is required for I2V training / I2V学習にはCLIPのパスが必要です"
            clip = CLIPModel(dtype=config.clip_dtype, device=device, weight_path=clip_path)
            clip.model.to(device)

            logger.info("Encoding image to CLIP context")
            with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad():
                for image_path in sample_prompts_image_embs:
                    logger.info(f"Encoding image: {image_path}")
                    img = Image.open(image_path).convert("RGB")
                    img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device)  # 0-1 -> -1 to 1
                    clip_context = clip.visual([img[:, None, :, :]])
                    sample_prompts_image_embs[image_path] = clip_context

            del clip
            clean_memory_on_device(device)

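        # attach the cached embeddings to a copy of each prompt dict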
        sample_parameters = []
        for prompt_dict in prompts:
            prompt_dict_copy = prompt_dict.copy()

            p = prompt_dict.get("prompt", "")
            prompt_dict_copy["t5_embeds"] = te_outputs_1[p][0]

            p = prompt_dict.get("negative_prompt", None)
            if p is not None:
                prompt_dict_copy["negative_t5_embeds"] = te_outputs_1[p][0]

            p = prompt_dict.get("image_path", None)
            if p is not None and self.i2v_training:
                prompt_dict_copy["clip_embeds"] = sample_prompts_image_embs[p]

            sample_parameters.append(prompt_dict_copy)

        clean_memory_on_device(accelerator.device)

        return sample_parameters

    def do_inference(
        self,
        accelerator,
        args,
        sample_parameter,
        vae,
        dit_dtype,
        transformer,
        discrete_flow_shift,
        sample_steps,
        width,
        height,
        frame_count,
        generator,
        do_classifier_free_guidance,
        guidance_scale,
        cfg_scale,
        image_path=None,
        control_video_path=None,
    ):
        """architecture dependent inference"""
        model: WanModel = transformer
        device = accelerator.device
        if cfg_scale is None:
            cfg_scale = 5.0
        do_classifier_free_guidance = do_classifier_free_guidance and cfg_scale != 1.0

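        # number of latent frames after the VAE's temporal compression (vae_stride[0])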
        latent_video_length = (frame_count - 1) // self.config["vae_stride"][0] + 1

        # get the text embeddings cached in process_sample_prompts
        context = sample_parameter["t5_embeds"].to(device=device)
        if do_classifier_free_guidance:
            context_null = sample_parameter["negative_t5_embeds"].to(device=device)
        else:
            context_null = None

        num_channels_latents = 16
        vae_scale_factor = self.config["vae_stride"][1]

        # initialize latents frame by frame and concatenate along the frame axis
        lat_h = height // vae_scale_factor
        lat_w = width // vae_scale_factor
        shape_or_frame = (1, num_channels_latents, 1, lat_h, lat_w)
        latents = []
        for _ in range(latent_video_length):
            latents.append(torch.randn(shape_or_frame, generator=generator, device=device, dtype=torch.float32))
        latents = torch.cat(latents, dim=2)

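        # prepare conditioning latents for I2V and Fun-Control models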
        image_latents = None
        if self.i2v_training or self.control_training:
            # encode the conditioning image / control video with the VAE
            vae.to(device)
            vae.eval()

            if self.i2v_training:
                image = Image.open(image_path)
                image = resize_image_to_bucket(image, (width, height))  # numpy array, H, W, C
                image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(1).float()  # C, 1, H, W
                image = image / 127.5 - 1  # 0-255 -> -1 to 1

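                # conditioning mask: the first (image) frame is given, the remaining frames are generated;
                # the first frame's mask is repeated 4 times to match the temporal compression of the first latent frame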
                msk = torch.ones(1, frame_count, lat_h, lat_w, device=device)
                msk[:, 1:] = 0
                msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
                msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
                msk = msk.transpose(1, 2)

                with torch.amp.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad():
                    # pad the single conditioning image with zero frames up to the full frame count before encoding
                    padding_frames = frame_count - 1
                    image = torch.concat([image, torch.zeros(3, padding_frames, height, width)], dim=1).to(device=device)
                    y = vae.encode([image])[0]

                y = y[:, :latent_video_length]
                y = y.unsqueeze(0)
                image_latents = torch.concat([msk, y], dim=1)

            if self.control_training:
                # encode the control video for the Fun-Control model
                video = load_video(control_video_path, 0, frame_count, bucket_reso=(width, height))  # list of frames
                video = np.stack(video, axis=0)  # F, H, W, C
                video = torch.from_numpy(video).permute(3, 0, 1, 2).float()  # C, F, H, W
                video = video / 127.5 - 1  # 0-255 -> -1 to 1
                video = video.to(device=device)

                with torch.amp.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad():
                    control_latents = vae.encode([video])[0]
                    control_latents = control_latents[:, :latent_video_length]
                    control_latents = control_latents.unsqueeze(0)

                # the Fun-Control model takes the control latents concatenated in front of the image latents (without the mask)
                if image_latents is not None:
                    image_latents = image_latents[:, 4:]  # remove the 4 mask channels
                    image_latents[:, :, 1:] = 0  # keep only the first frame
                else:
                    image_latents = torch.zeros_like(control_latents)  # no image conditioning

                image_latents = torch.concat([control_latents, image_latents], dim=1)

            vae.to("cpu")
            clean_memory_on_device(device)

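        # UniPC flow-matching scheduler with the requested discrete flow shift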
        scheduler = FlowUniPCMultistepScheduler(shift=1, use_dynamic_shifting=False)
        scheduler.set_timesteps(sample_steps, device=device, shift=discrete_flow_shift)
        timesteps = scheduler.timesteps

        # initial noise; kept on CPU and moved to the device each step
        noise = torch.randn(16, latent_video_length, lat_h, lat_w, dtype=torch.float32, generator=generator, device=device).to(
            "cpu"
        )

        # prepare the model inputs
        max_seq_len = latent_video_length * lat_h * lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
        arg_c = {"context": [context], "seq_len": max_seq_len}
        arg_null = {"context": [context_null], "seq_len": max_seq_len}

        if self.i2v_training:
            arg_c["clip_fea"] = sample_parameter["clip_embeds"].to(device=device, dtype=dit_dtype)
            arg_null["clip_fea"] = arg_c["clip_fea"]
        if self.i2v_training or self.control_training:
            arg_c["y"] = image_latents
            arg_null["y"] = image_latents

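        # denoising loop with optional classifier-free guidance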
        prompt_idx = sample_parameter.get("enum", 0)
        latent = noise
        with torch.no_grad():
            for i, t in enumerate(tqdm(timesteps, desc=f"Sampling timesteps for prompt {prompt_idx+1}")):
                latent_model_input = [latent.to(device=device)]
                timestep = t.unsqueeze(0)

                with accelerator.autocast():
                    noise_pred_cond = model(latent_model_input, t=timestep, **arg_c)[0].to("cpu")
                    if do_classifier_free_guidance:
                        noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0].to("cpu")
                    else:
                        noise_pred_uncond = None

                if do_classifier_free_guidance:
                    noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond)
                else:
                    noise_pred = noise_pred_cond

                temp_x0 = scheduler.step(noise_pred.unsqueeze(0), t, latent.unsqueeze(0), return_dict=False, generator=generator)[0]
                latent = temp_x0.squeeze(0)

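        # decode the sampled latents back to a pixel video with the VAE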
        vae.to(device)
        vae.eval()

        logger.info(f"Decoding video from latents: {latent.shape}")
        latent = latent.unsqueeze(0)  # add batch dimension
        latent = latent.to(device=device)

        with torch.amp.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad():
            video = vae.decode(latent)[0]
        video = video.unsqueeze(0)
        del latent

        logger.info("Decoding complete")
        video = video.to(torch.float32).cpu()
        video = (video / 2 + 0.5).clamp(0, 1)  # -1 to 1 -> 0 to 1

        vae.to("cpu")
        clean_memory_on_device(device)

        return video

    def load_vae(self, args: argparse.Namespace, vae_dtype: torch.dtype, vae_path: str):
        vae_path = args.vae  # args.vae takes precedence over the vae_path argument

        logger.info(f"Loading VAE model from {vae_path}")
        cache_device = torch.device("cpu") if args.vae_cache_cpu else None
        vae = WanVAE(vae_path=vae_path, device="cpu", dtype=vae_dtype, cache_device=cache_device)
        return vae

    def load_transformer(
        self,
        accelerator: Accelerator,
        args: argparse.Namespace,
        dit_path: str,
        attn_mode: str,
        split_attn: bool,
        loading_device: str,
        dit_weight_dtype: Optional[torch.dtype],
    ):
        model = load_wan_model(
            self.config, accelerator.device, dit_path, attn_mode, split_attn, loading_device, dit_weight_dtype, args.fp8_scaled
        )
        return model

    def scale_shift_latents(self, latents):
        # Wan latents are used as-is; no scaling or shifting is applied
        return latents

    def call_dit(
        self,
        args: argparse.Namespace,
        accelerator: Accelerator,
        transformer,
        latents: torch.Tensor,
        batch: dict[str, torch.Tensor],
        noise: torch.Tensor,
        noisy_model_input: torch.Tensor,
        timesteps: torch.Tensor,
        network_dtype: torch.dtype,
    ):
        model: WanModel = transformer

        # I2V training and Fun-Control training: take conditioning latents and CLIP features from the batch
        image_latents = None
        clip_fea = None
        if self.i2v_training:
            image_latents = batch["latents_image"]
            image_latents = image_latents.to(device=accelerator.device, dtype=network_dtype)
            clip_fea = batch["clip"]
            clip_fea = clip_fea.to(device=accelerator.device, dtype=network_dtype)
        if self.control_training:
            control_latents = batch["latents_control"]
            control_latents = control_latents.to(device=accelerator.device, dtype=network_dtype)
            if image_latents is not None:
                image_latents = image_latents[:, 4:]  # remove the 4 mask channels
                image_latents[:, :, 1:] = 0  # keep only the first frame
            else:
                image_latents = torch.zeros_like(control_latents)  # no image conditioning
            image_latents = torch.concat([control_latents, image_latents], dim=1)
            control_latents = None

        context = [t.to(device=accelerator.device, dtype=network_dtype) for t in batch["t5"]]

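        # when gradient checkpointing is enabled, the model inputs must require grad so the checkpointed blocks participate in the backward pass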
        if args.gradient_checkpointing:
            noisy_model_input.requires_grad_(True)
            for t in context:
                t.requires_grad_(True)
            if image_latents is not None:
                image_latents.requires_grad_(True)
            if clip_fea is not None:
                clip_fea.requires_grad_(True)

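        # sequence length = number of patches in the latent video (latent frames x height x width / patch size)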
        lat_f, lat_h, lat_w = latents.shape[2:5]
        seq_len = lat_f * lat_h * lat_w // (self.config.patch_size[0] * self.config.patch_size[1] * self.config.patch_size[2])
        latents = latents.to(device=accelerator.device, dtype=network_dtype)
        noisy_model_input = noisy_model_input.to(device=accelerator.device, dtype=network_dtype)
        with accelerator.autocast():
            model_pred = model(noisy_model_input, t=timesteps, context=context, clip_fea=clip_fea, seq_len=seq_len, y=image_latents)
            model_pred = torch.stack(model_pred, dim=0)  # list of tensors to tensor

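        # flow matching objective: the target is the velocity from data to noise (noise - latents)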
        target = noise - latents

        return model_pred, target


def wan_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Wan2.1 specific parser setup"""
    parser.add_argument("--task", type=str, default="t2v-14B", choices=list(WAN_CONFIGS.keys()), help="The task to run.")
    parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
    parser.add_argument("--t5", type=str, default=None, help="text encoder (T5) checkpoint path")
    parser.add_argument("--fp8_t5", action="store_true", help="use fp8 for Text Encoder model")
    parser.add_argument(
        "--clip",
        type=str,
        default=None,
        help="text encoder (CLIP) checkpoint path, optional. Required when training an I2V model",
    )
    parser.add_argument("--vae_cache_cpu", action="store_true", help="cache features in VAE on CPU")
    return parser


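# Example invocation (a sketch, not an exact command): the Wan-specific flags below are defined in
# wan_setup_parser above; the remaining flags (e.g. --dit, --vae, --mixed_precision, dataset and
# network options) come from setup_parser_common in hv_train_network and are assumed here.
#
#   accelerate launch <this script> \
#       --task t2v-14B --dit <path/to/dit> --vae <path/to/vae> --t5 <path/to/t5> \
#       --mixed_precision bf16 --fp8_scaled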
if __name__ == "__main__":
    parser = setup_parser_common()
    parser = wan_setup_parser(parser)

    args = parser.parse_args()
    args = read_config_from_file(args, parser)

    args.dit_dtype = None  # detected from the DiT checkpoint in handle_model_specific_args
    if args.vae_dtype is None:
        args.vae_dtype = "bfloat16"

    trainer = WanNetworkTrainer()
    trainer.train(args)