import argparse
from datetime import datetime
import gc
import json
import random
import os
import re
import time
import math
import copy
from typing import Tuple, Optional, List, Union, Any, Dict

import torch
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from PIL import Image
import cv2
import numpy as np
import torchvision.transforms.functional as TF
from transformers import LlamaModel
from tqdm import tqdm

from networks import lora_framepack
from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
from frame_pack import hunyuan
from frame_pack.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked, load_packed_model
from frame_pack.utils import crop_or_pad_yield_mask, resize_and_center_crop, soft_append_bcthw
from frame_pack.bucket_tools import find_nearest_bucket
from frame_pack.clip_vision import hf_clip_vision_encode
from frame_pack.k_diffusion_hunyuan import sample_hunyuan
from dataset import image_video_dataset

try:
    from lycoris.kohya import create_network_from_weights
except ImportError:
    # LyCORIS is optional; it is only needed when --lycoris is specified
    pass

from utils.device_utils import clean_memory_on_device
from hv_generate_video import save_images_grid, save_videos_grid, synchronize_device
from wan_generate_video import merge_lora_weights
from frame_pack.framepack_utils import load_vae, load_text_encoder1, load_text_encoder2, load_image_encoders
from dataset.image_video_dataset import load_video

import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

class GenerationSettings:
    def __init__(self, device: torch.device, dit_weight_dtype: Optional[torch.dtype] = None):
        self.device = device
        self.dit_weight_dtype = dit_weight_dtype

def parse_args() -> argparse.Namespace:
    """parse command line arguments"""
    parser = argparse.ArgumentParser(description="FramePack inference script")

    parser.add_argument(
        "--sample_solver", type=str, default="unipc", choices=["unipc", "dpm++", "vanilla"], help="The solver used to sample."
    )

    parser.add_argument("--dit", type=str, default=None, help="DiT directory or path")
    parser.add_argument("--vae", type=str, default=None, help="VAE directory or path")
    parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory or path")
    parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory or path")
    parser.add_argument("--image_encoder", type=str, required=True, help="Image Encoder directory or path")

    # LoRA
    parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
    parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier")
    parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns")
    parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns")
    parser.add_argument(
        "--save_merged_model",
        type=str,
        default=None,
        help="Save merged model to path. If specified, no inference will be performed.",
    )

    # inference
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="prompt for generation. If `;;;` is used, it will be split into sections. Example: `section_index:prompt` or "
        "`section_index:prompt;;;section_index:prompt;;;...`, section_index can be `0` or `-1` or `0-2`, `-1` means last section, `0-2` means from 0 to 2 (inclusive).",
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default=None,
        help="negative prompt for generation, default is empty string. should not change.",
    )
    parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size, height and width")
    parser.add_argument("--video_seconds", type=float, default=5.0, help="video length, Default is 5.0 seconds")
    parser.add_argument("--fps", type=int, default=30, help="video fps, Default is 30")
    parser.add_argument("--infer_steps", type=int, default=25, help="number of inference steps, Default is 25")
    parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
    parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")

    parser.add_argument("--latent_window_size", type=int, default=9, help="latent window size, default is 9. should not change.")
    parser.add_argument(
        "--embedded_cfg_scale", type=float, default=10.0, help="Embedded CFG scale (distilled CFG Scale), default is 10.0"
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=1.0,
        help="Guidance scale for classifier free guidance. Default is 1.0, should not change.",
    )
    parser.add_argument("--guidance_rescale", type=float, default=0.0, help="CFG Re-scale, default is 0.0. Should not change.")

    parser.add_argument("--image_path", type=str, default=None, help="path to image for image2video inference")
    parser.add_argument("--end_image_path", type=str, default=None, help="path to end image for image2video inference")

    # model optimization
    parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
    parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8")
    parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
    parser.add_argument(
        "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
    )
    parser.add_argument(
        "--attn_mode",
        type=str,
        default="torch",
        choices=["flash", "torch", "sageattn", "xformers", "sdpa"],
        help="attention mode",
    )
    parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
    parser.add_argument(
        "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
    )
    parser.add_argument("--bulk_decode", action="store_true", help="decode all frames at once")
    parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model")
    parser.add_argument(
        "--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type"
    )
    parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
    parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
    parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference")

    # batch and interactive modes
    parser.add_argument("--from_file", type=str, default=None, help="Read prompts from a file")
    parser.add_argument("--interactive", action="store_true", help="Interactive mode: read prompts from console")

    args = parser.parse_args()

    # validate mutually exclusive options
    if args.from_file and args.interactive:
        raise ValueError("Cannot use both --from_file and --interactive at the same time")

    if args.prompt is None and not args.from_file and not args.interactive:
        raise ValueError("Either --prompt, --from_file or --interactive must be specified")

    return args

def parse_prompt_line(line: str) -> Dict[str, Any]:
    """Parse a prompt line into a dictionary of argument overrides

    Args:
        line: Prompt line with options

    Returns:
        Dict[str, Any]: Dictionary of argument overrides
    """
    # split the line into the prompt and the trailing "--x value" options
    parts = line.split(" --")
    prompt = parts[0].strip()

    overrides = {"prompt": prompt}

    for part in parts[1:]:
        if not part.strip():
            continue
        option_parts = part.split(" ", 1)
        option = option_parts[0].strip()
        value = option_parts[1].strip() if len(option_parts) > 1 else ""

        # map single-letter options to argument overrides
        if option == "w":
            overrides["video_size_width"] = int(value)
        elif option == "h":
            overrides["video_size_height"] = int(value)
        elif option == "f":
            overrides["video_seconds"] = float(value)
        elif option == "d":
            overrides["seed"] = int(value)
        elif option == "s":
            overrides["infer_steps"] = int(value)
        elif option == "g" or option == "l":
            overrides["guidance_scale"] = float(value)
        elif option == "i":
            overrides["image_path"] = value
        elif option == "cn":
            overrides["control_path"] = value
        elif option == "n":
            overrides["negative_prompt"] = value

    return overrides

def apply_overrides(args: argparse.Namespace, overrides: Dict[str, Any]) -> argparse.Namespace:
    """Apply overrides to args

    Args:
        args: Original arguments
        overrides: Dictionary of overrides

    Returns:
        argparse.Namespace: New arguments with overrides applied
    """
    args_copy = copy.deepcopy(args)

    for key, value in overrides.items():
        if key == "video_size_width":
            args_copy.video_size[1] = value
        elif key == "video_size_height":
            args_copy.video_size[0] = value
        else:
            setattr(args_copy, key, value)

    return args_copy

def check_inputs(args: argparse.Namespace) -> Tuple[int, int, float]:
    """Validate video size and length

    Args:
        args: command line arguments

    Returns:
        Tuple[int, int, float]: (height, width, video_seconds)
    """
    height = args.video_size[0]
    width = args.video_size[1]

    video_seconds = args.video_seconds

    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    return height, width, video_seconds

def load_dit_model(args: argparse.Namespace, device: torch.device) -> HunyuanVideoTransformer3DModelPacked:
    """load DiT model

    Args:
        args: command line arguments
        device: device to use

    Returns:
        HunyuanVideoTransformer3DModelPacked: DiT model
    """
    # load on CPU first when block swap, scaled fp8 or LoRA merging needs the weights on CPU
    loading_device = "cpu"
    if args.blocks_to_swap == 0 and not args.fp8_scaled and args.lora_weight is None:
        loading_device = device

    model = load_packed_model(device, args.dit, args.attn_mode, loading_device)
    return model

def optimize_model(model: HunyuanVideoTransformer3DModelPacked, args: argparse.Namespace, device: torch.device) -> None:
    """optimize the model (FP8 conversion, device move etc.)

    Args:
        model: dit model
        args: command line arguments
        device: device to use
    """
    if args.fp8_scaled:
        # quantize the model weights to scaled FP8
        state_dict = model.state_dict()

        # if no blocks are swapped, the optimized weights can be moved to the device directly
        move_to_device = args.blocks_to_swap == 0
        state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=False)

        info = model.load_state_dict(state_dict, strict=True, assign=True)
        logger.info(f"Loaded FP8 optimized weights: {info}")

        if args.blocks_to_swap == 0:
            model.to(device)
    else:
        # simple cast to FP8 (not scaled) and/or move to device
        target_dtype = None
        target_device = None

        if args.fp8:
            target_dtype = torch.float8_e4m3fn

        if args.blocks_to_swap == 0:
            logger.info(f"Move model to device: {device}")
            target_device = device

        if target_device is not None and target_dtype is not None:
            model.to(target_device, target_dtype)

    if args.blocks_to_swap > 0:
        logger.info(f"Enable swap {args.blocks_to_swap} blocks to CPU from device: {device}")
        model.enable_block_swap(args.blocks_to_swap, device, supports_backward=False)
        model.move_to_device_except_swap_blocks(device)
        model.prepare_block_swap_before_forward()
    else:
        # make sure the model is on the target device
        model.to(device)

    model.eval().requires_grad_(False)
    clean_memory_on_device(device)

def decode_latent(
    latent_window_size: int,
    total_latent_sections: int,
    bulk_decode: bool,
    vae: AutoencoderKLCausal3D,
    latent: torch.Tensor,
    device: torch.device,
) -> torch.Tensor:
    logger.info("Decoding video...")
    if latent.ndim == 4:
        latent = latent.unsqueeze(0)  # add batch dimension

    vae.to(device)
    if not bulk_decode:
        # decode section by section to reduce memory usage
        num_frames = latent_window_size * 4 - 3

        latents_to_decode = []
        latent_frame_index = 0
        for i in range(total_latent_sections - 1, -1, -1):
            is_last_section = i == total_latent_sections - 1
            generated_latent_frames = (num_frames + 3) // 4 + (1 if is_last_section else 0)
            section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)

            section_latent = latent[:, :, latent_frame_index : latent_frame_index + section_latent_frames, :, :]
            latents_to_decode.append(section_latent)

            latent_frame_index += generated_latent_frames

        latents_to_decode = latents_to_decode[::-1]  # decode in temporal order

        history_pixels = None
        for latent in tqdm(latents_to_decode):
            if history_pixels is None:
                history_pixels = hunyuan.vae_decode(latent, vae).cpu()
            else:
                overlapped_frames = latent_window_size * 4 - 3
                current_pixels = hunyuan.vae_decode(latent, vae).cpu()
                history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
            clean_memory_on_device(device)
    else:
        # decode the whole latent in one pass
        logger.info("Bulk decoding")
        history_pixels = hunyuan.vae_decode(latent, vae).cpu()
    vae.to("cpu")

    logger.info(f"Decoded. Pixel shape {history_pixels.shape}")
    return history_pixels[0]

def prepare_i2v_inputs(
    args: argparse.Namespace,
    device: torch.device,
    vae: AutoencoderKLCausal3D,
    encoded_context: Optional[Dict] = None,
    encoded_context_n: Optional[Dict] = None,
) -> Tuple[int, int, float, torch.Tensor, Optional[torch.Tensor], Dict[int, dict], dict]:
    """Prepare inputs for I2V

    Args:
        args: command line arguments
        device: device to use
        vae: VAE model, used for image encoding
        encoded_context: pre-encoded text context for the prompt
        encoded_context_n: pre-encoded text context for the negative prompt

    Returns:
        Tuple[int, int, float, torch.Tensor, Optional[torch.Tensor], Dict[int, dict], dict]:
            (height, width, video_seconds, start_latent, end_latent, arg_c, arg_null)
    """
    height, width, video_seconds = check_inputs(args)

    # prepare image
    def preprocess_image(image_path: str):
        image = Image.open(image_path).convert("RGB")

        image_np = np.array(image)  # PIL to numpy, HWC

        image_np = image_video_dataset.resize_image_to_bucket(image_np, (width, height))
        image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0  # -1 to 1, HWC
        image_tensor = image_tensor.permute(2, 0, 1)[None, :, None]  # HWC -> CHW -> NCFHW, N=1, F=1
        return image_tensor, image_np

    img_tensor, img_np = preprocess_image(args.image_path)
    if args.end_image_path is not None:
        end_img_tensor, end_img_np = preprocess_image(args.end_image_path)
    else:
        end_img_tensor, end_img_np = None, None

    # negative prompt
    n_prompt = args.negative_prompt if args.negative_prompt else ""

    if encoded_context is None:
        # load text encoders
        tokenizer1, text_encoder1 = load_text_encoder1(args, args.fp8_llm, device)
        tokenizer2, text_encoder2 = load_text_encoder2(args)
        text_encoder2.to(device)

        # parse section prompts: `INDEX:PROMPT` or `START-END:PROMPT`, separated by `;;;`
        # example: "0:a cat sits;;;1-2:the cat walks;;;-1:the cat jumps"
        section_prompts = {}
        if ";;;" in args.prompt:
            section_prompt_strs = args.prompt.split(";;;")
            for section_prompt_str in section_prompt_strs:
                if ":" not in section_prompt_str:
                    start = end = 0
                    prompt_str = section_prompt_str.strip()
                else:
                    index_str, prompt_str = section_prompt_str.split(":", 1)
                    index_str = index_str.strip()
                    prompt_str = prompt_str.strip()

                    m = re.match(r"^(-?\d+)(-\d+)?$", index_str)
                    if m:
                        start = int(m.group(1))
                        end = int(m.group(2)[1:]) if m.group(2) is not None else start
                    else:
                        start = end = 0
                        prompt_str = section_prompt_str.strip()
                for i in range(start, end + 1):
                    section_prompts[i] = prompt_str
        else:
            section_prompts[0] = args.prompt

        # the prompt for section 0 is required as the fallback
        if 0 not in section_prompts:
            # use the smallest non-negative index, or the smallest negative index if all are negative
            indices = list(section_prompts.keys())
            if all(i < 0 for i in indices):
                section_index = min(indices)
            else:
                section_index = min(i for i in indices if i >= 0)
            section_prompts[0] = section_prompts[section_index]
        logger.info(f"Section prompts: {section_prompts}")

        logger.info("Encoding prompt")
        llama_vecs = {}
        llama_attention_masks = {}
        clip_l_poolers = {}
        with torch.autocast(device_type=device.type, dtype=text_encoder1.dtype), torch.no_grad():
            for index, prompt in section_prompts.items():
                llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(prompt, text_encoder1, text_encoder2, tokenizer1, tokenizer2)
                llama_vec = llama_vec.cpu()
                clip_l_pooler = clip_l_pooler.cpu()

                llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)

                llama_vecs[index] = llama_vec
                llama_attention_masks[index] = llama_attention_mask
                clip_l_poolers[index] = clip_l_pooler

        if args.guidance_scale == 1.0:
            # the negative prompt is not used when CFG is disabled
            llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vecs[0]), torch.zeros_like(clip_l_poolers[0])
        else:
            with torch.autocast(device_type=device.type, dtype=text_encoder1.dtype), torch.no_grad():
                llama_vec_n, clip_l_pooler_n = hunyuan.encode_prompt_conds(
                    n_prompt, text_encoder1, text_encoder2, tokenizer1, tokenizer2
                )
            llama_vec_n = llama_vec_n.cpu()
            clip_l_pooler_n = clip_l_pooler_n.cpu()

        llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)

        # free text encoders and clean memory
        del text_encoder1, text_encoder2, tokenizer1, tokenizer2
        clean_memory_on_device(device)

        # load image encoder
        feature_extractor, image_encoder = load_image_encoders(args)
        image_encoder.to(device)

        # encode image with image encoder
        with torch.no_grad():
            image_encoder_output = hf_clip_vision_encode(img_np, feature_extractor, image_encoder)
        image_encoder_last_hidden_state = image_encoder_output.last_hidden_state.cpu()

        if end_img_np is not None:
            with torch.no_grad():
                end_image_encoder_output = hf_clip_vision_encode(end_img_np, feature_extractor, image_encoder)
            end_image_encoder_last_hidden_state = end_image_encoder_output.last_hidden_state.cpu()
        else:
            end_image_encoder_last_hidden_state = None

        # free image encoder and clean memory
        del image_encoder, feature_extractor
        clean_memory_on_device(device)
    else:
        # use pre-encoded context
        llama_vecs = encoded_context["llama_vecs"]
        llama_attention_masks = encoded_context["llama_attention_masks"]
        clip_l_poolers = encoded_context["clip_l_poolers"]
        llama_vec_n = encoded_context_n["llama_vec"]
        llama_attention_mask_n = encoded_context_n["llama_attention_mask"]
        clip_l_pooler_n = encoded_context_n["clip_l_pooler"]
        image_encoder_last_hidden_state = encoded_context["image_encoder_last_hidden_state"]
        # the pre-encoded path does not carry per-section prompts or an end image embedding,
        # so fall back to the raw prompt and to None (assumption to keep this path runnable)
        end_image_encoder_last_hidden_state = encoded_context.get("end_image_encoder_last_hidden_state")
        section_prompts = {index: args.prompt for index in llama_vecs.keys()}

    # VAE encoding
    logger.info("Encoding image to latent space")
    vae.to(device)
    start_latent = hunyuan.vae_encode(img_tensor, vae).cpu()
    if end_img_tensor is not None:
        end_latent = hunyuan.vae_encode(end_img_tensor, vae).cpu()
    else:
        end_latent = None
    vae.to("cpu")
    clean_memory_on_device(device)

    # prepare model input arguments, per section index
    arg_c = {}
    for index in llama_vecs.keys():
        llama_vec = llama_vecs[index]
        llama_attention_mask = llama_attention_masks[index]
        clip_l_pooler = clip_l_poolers[index]
        arg_c_i = {
            "llama_vec": llama_vec,
            "llama_attention_mask": llama_attention_mask,
            "clip_l_pooler": clip_l_pooler,
            "image_encoder_last_hidden_state": image_encoder_last_hidden_state,
            "end_image_encoder_last_hidden_state": end_image_encoder_last_hidden_state,
            "prompt": section_prompts[index],
        }
        arg_c[index] = arg_c_i

    arg_null = {
        "llama_vec": llama_vec_n,
        "llama_attention_mask": llama_attention_mask_n,
        "clip_l_pooler": clip_l_pooler_n,
        "image_encoder_last_hidden_state": image_encoder_last_hidden_state,
        "end_image_encoder_last_hidden_state": end_image_encoder_last_hidden_state,
    }

    return height, width, video_seconds, start_latent, end_latent, arg_c, arg_null

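# FramePack generates the video back to front: sections are sampled in reverse temporal order
# and prepended to `history_latents`. For each section the conditioning latents are laid out as
# [clean_pre (1, reference image), blank padding, window (latent_window_size), clean_post (1),
# 2x (2), 4x (16)], matching the `indices.split(...)` in generate() below.
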
def generate(
    args: argparse.Namespace, gen_settings: GenerationSettings, shared_models: Optional[Dict] = None
) -> Optional[Tuple[AutoencoderKLCausal3D, torch.Tensor]]:
    """main function for generation

    Args:
        args: command line arguments
        gen_settings: generation settings (device and DiT weight dtype)
        shared_models: dictionary containing pre-loaded models and encoded data

    Returns:
        Optional[Tuple[AutoencoderKLCausal3D, torch.Tensor]]: (VAE, generated latent), or None if only the merged model is saved
    """
    device, dit_weight_dtype = (gen_settings.device, gen_settings.dit_weight_dtype)

    # prepare seed
    seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
    args.seed = seed  # store the seed in args so it can be saved in metadata

    # prepare inputs
    if shared_models is not None:
        # use shared models and pre-encoded data
        vae = shared_models.get("vae")
        model = shared_models.get("model")
        encoded_context = shared_models.get("encoded_contexts", {}).get(args.prompt)
        n_prompt = args.negative_prompt if args.negative_prompt else ""
        encoded_context_n = shared_models.get("encoded_contexts", {}).get(n_prompt)

        height, width, video_seconds, start_latent, end_latent, context, context_null = prepare_i2v_inputs(
            args, device, vae, encoded_context, encoded_context_n
        )
    else:
        # load models and encode everything in this call
        vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device)
        height, width, video_seconds, start_latent, end_latent, context, context_null = prepare_i2v_inputs(args, device, vae)

        # load DiT model
        model = load_dit_model(args, device)

        # merge LoRA weights into the model
        if args.lora_weight is not None and len(args.lora_weight) > 0:
            merge_lora_weights(lora_framepack, model, args, device)

            # if we only want to save the merged model, skip inference
            if args.save_merged_model:
                return None

        # optimize model: fp8 conversion, block swap etc.
        optimize_model(model, args, device)

    # sampling
    latent_window_size = args.latent_window_size  # default is 9
    # number of sections, assuming 30 fps: e.g. 5 seconds -> (5 * 30) / (9 * 4) ~= 4 sections
    total_latent_sections = (video_seconds * 30) / (latent_window_size * 4)
    total_latent_sections = int(max(round(total_latent_sections), 1))

    # set random generator
    seed_g = torch.Generator(device="cpu")
    seed_g.manual_seed(seed)
    num_frames = latent_window_size * 4 - 3

    logger.info(
        f"Video size: {height}x{width}@{video_seconds} (HxW@seconds), fps: {args.fps}, "
        f"infer_steps: {args.infer_steps}, frames per generation: {num_frames}"
    )

    history_latents = torch.zeros((1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=torch.float32)
    total_generated_latent_frames = 0

    latent_paddings = reversed(range(total_latent_sections))

    if total_latent_sections > 4:
        # when there are more than 4 sections, repeat the middle padding value instead of
        # expanding the sequence; this tends to give better-looking results
        latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]

    for section_index_reverse, latent_padding in enumerate(latent_paddings):
        section_index = total_latent_sections - 1 - section_index_reverse

        is_last_section = latent_padding == 0
        is_first_section = section_index_reverse == 0
        latent_padding_size = latent_padding * latent_window_size

        logger.info(f"latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}")

        reference_start_latent = start_latent
        apply_end_image = args.end_image_path is not None and is_first_section
        if apply_end_image:
            latent_padding_size = 0
            reference_start_latent = end_latent
            logger.info(f"Apply experimental end image, latent_padding_size = {latent_padding_size}")

        # indices of the conditioning latents: [clean_pre, blank padding, window, clean_post, 2x, 4x]
        indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
        (
            clean_latent_indices_pre,
            blank_indices,
            latent_indices,
            clean_latent_indices_post,
            clean_latent_2x_indices,
            clean_latent_4x_indices,
        ) = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
        clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)

        clean_latents_pre = reference_start_latent.to(history_latents)
        clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, : 1 + 2 + 16, :, :].split([1, 2, 16], dim=2)
        clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)

        # select the prompt for this section: prefer an index counted from the last section,
        # then the absolute section index, falling back to index 0
        section_index_from_last = -(section_index_reverse + 1)
        if section_index_from_last in context:
            prompt_index = section_index_from_last
        elif section_index in context:
            prompt_index = section_index
        else:
            prompt_index = 0
        context_for_index = context[prompt_index]
        logger.info(f"Section {section_index}: {context_for_index['prompt']}")

        llama_vec = context_for_index["llama_vec"].to(device, dtype=torch.bfloat16)
        llama_attention_mask = context_for_index["llama_attention_mask"].to(device)
        clip_l_pooler = context_for_index["clip_l_pooler"].to(device, dtype=torch.bfloat16)

        if not apply_end_image:
            image_encoder_last_hidden_state = context_for_index["image_encoder_last_hidden_state"].to(device, dtype=torch.bfloat16)
        else:
            image_encoder_last_hidden_state = context_for_index["end_image_encoder_last_hidden_state"].to(
                device, dtype=torch.bfloat16
            )

        llama_vec_n = context_null["llama_vec"].to(device, dtype=torch.bfloat16)
        llama_attention_mask_n = context_null["llama_attention_mask"].to(device)
        clip_l_pooler_n = context_null["clip_l_pooler"].to(device, dtype=torch.bfloat16)

        generated_latents = sample_hunyuan(
            transformer=model,
            sampler=args.sample_solver,
            width=width,
            height=height,
            frames=num_frames,
            real_guidance_scale=args.guidance_scale,
            distilled_guidance_scale=args.embedded_cfg_scale,
            guidance_rescale=args.guidance_rescale,
            num_inference_steps=args.infer_steps,
            generator=seed_g,
            prompt_embeds=llama_vec,
            prompt_embeds_mask=llama_attention_mask,
            prompt_poolers=clip_l_pooler,
            negative_prompt_embeds=llama_vec_n,
            negative_prompt_embeds_mask=llama_attention_mask_n,
            negative_prompt_poolers=clip_l_pooler_n,
            device=device,
            dtype=torch.bfloat16,
            image_embeddings=image_encoder_last_hidden_state,
            latent_indices=latent_indices,
            clean_latents=clean_latents,
            clean_latent_indices=clean_latent_indices,
            clean_latents_2x=clean_latents_2x,
            clean_latent_2x_indices=clean_latent_2x_indices,
            clean_latents_4x=clean_latents_4x,
            clean_latent_4x_indices=clean_latent_4x_indices,
        )

        if is_last_section:
            # the last generated (temporally first) section also gets the start image latent prepended
            generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)

        total_generated_latent_frames += int(generated_latents.shape[2])
        history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)

        real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]

        logger.info(f"Generated. Latent shape {real_history_latents.shape}")

    # release the DiT model if it was loaded in this function
    if shared_models is None:
        del model
        synchronize_device(device)

        # wait for block swap to finish before freeing memory
        logger.info("Waiting for 5 seconds to finish block swap")
        time.sleep(5)

    gc.collect()
    clean_memory_on_device(device)

    return vae, real_history_latents

def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int) -> str:
    """Save latent to file

    Args:
        latent: Latent tensor
        args: command line arguments
        height: height of frame
        width: width of frame

    Returns:
        str: Path to saved latent file
    """
    save_path = args.save_path
    os.makedirs(save_path, exist_ok=True)
    time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")

    seed = args.seed
    video_seconds = args.video_seconds
    latent_path = f"{save_path}/{time_flag}_{seed}_latent.safetensors"

    if args.no_metadata:
        metadata = None
    else:
        metadata = {
            "seeds": f"{seed}",
            "prompt": f"{args.prompt}",
            "height": f"{height}",
            "width": f"{width}",
            "video_seconds": f"{video_seconds}",
            "infer_steps": f"{args.infer_steps}",
            "guidance_scale": f"{args.guidance_scale}",
            "latent_window_size": f"{args.latent_window_size}",
            "embedded_cfg_scale": f"{args.embedded_cfg_scale}",
            "guidance_rescale": f"{args.guidance_rescale}",
            "sample_solver": f"{args.sample_solver}",
            "fps": f"{args.fps}",
        }
        if args.negative_prompt is not None:
            metadata["negative_prompt"] = f"{args.negative_prompt}"

    sd = {"latent": latent.contiguous()}
    save_file(sd, latent_path, metadata=metadata)
    logger.info(f"Latent saved to: {latent_path}")

    return latent_path

def save_video(
    video: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None, latent_frames: Optional[int] = None
) -> str:
    """Save video to file

    Args:
        video: Video tensor
        args: command line arguments
        original_base_name: Original base name (if latents are loaded from files)
        latent_frames: Number of latent frames (appended to the file name if provided)

    Returns:
        str: Path to saved video file
    """
    save_path = args.save_path
    os.makedirs(save_path, exist_ok=True)
    time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")

    seed = args.seed
    original_name = "" if original_base_name is None else f"_{original_base_name}"
    latent_frames = "" if latent_frames is None else f"_{latent_frames}"
    video_path = f"{save_path}/{time_flag}_{seed}{original_name}{latent_frames}.mp4"

    video = video.unsqueeze(0)
    save_videos_grid(video, video_path, fps=args.fps, rescale=True)
    logger.info(f"Video saved to: {video_path}")

    return video_path

def save_images(sample: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str:
    """Save images to directory

    Args:
        sample: Video tensor
        args: command line arguments
        original_base_name: Original base name (if latents are loaded from files)

    Returns:
        str: Path to saved images directory
    """
    save_path = args.save_path
    os.makedirs(save_path, exist_ok=True)
    time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")

    seed = args.seed
    original_name = "" if original_base_name is None else f"_{original_base_name}"
    image_name = f"{time_flag}_{seed}{original_name}"
    sample = sample.unsqueeze(0)
    save_images_grid(sample, save_path, image_name, rescale=True)
    logger.info(f"Sample images saved to: {save_path}/{image_name}")

    return f"{save_path}/{image_name}"

def save_output(
    args: argparse.Namespace,
    vae: AutoencoderKLCausal3D,
    latent: torch.Tensor,
    device: torch.device,
    original_base_names: Optional[List[str]] = None,
) -> None:
    """save output

    Args:
        args: command line arguments
        vae: VAE model
        latent: latent tensor
        device: device to use
        original_base_names: original base names (if latents are loaded from files)
    """
    height, width = latent.shape[-2], latent.shape[-1]
    height *= 8
    width *= 8

    if args.output_type == "latent" or args.output_type == "both":
        # save latent
        save_latent(latent, args, height, width)
    if args.output_type == "latent":
        return

    total_latent_sections = (args.video_seconds * 30) / (args.latent_window_size * 4)
    total_latent_sections = int(max(round(total_latent_sections), 1))
    video = decode_latent(args.latent_window_size, total_latent_sections, args.bulk_decode, vae, latent, device)

    if args.output_type == "video" or args.output_type == "both":
        # save video
        original_name = None if original_base_names is None else original_base_names[0]
        save_video(video, args, original_name)

    elif args.output_type == "images":
        # save images
        original_name = None if original_base_names is None else original_base_names[0]
        save_images(video, args, original_name)

def preprocess_prompts_for_batch(prompt_lines: List[str], base_args: argparse.Namespace) -> List[Dict]:
    """Process multiple prompts for batch mode

    Args:
        prompt_lines: List of prompt lines
        base_args: Base command line arguments

    Returns:
        List[Dict]: List of prompt data dictionaries
    """
    prompts_data = []

    for line in prompt_lines:
        line = line.strip()
        if not line or line.startswith("#"):  # skip empty lines and comments
            continue

        # parse prompt line and create override dictionary
        prompt_data = parse_prompt_line(line)
        logger.info(f"Parsed prompt data: {prompt_data}")
        prompts_data.append(prompt_data)

    return prompts_data

def get_generation_settings(args: argparse.Namespace) -> GenerationSettings:
    device = torch.device(args.device)

    dit_weight_dtype = None  # keep weights as-is by default
    if args.fp8_scaled:
        dit_weight_dtype = None  # various precisions are used with scaled fp8
    elif args.fp8:
        dit_weight_dtype = torch.float8_e4m3fn

    logger.info(f"Using device: {device}, DiT weight dtype: {dit_weight_dtype}")

    gen_settings = GenerationSettings(device=device, dit_weight_dtype=dit_weight_dtype)
    return gen_settings

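# Example invocation (illustrative only; the script name and all paths are placeholders,
# but the flags are the ones defined in parse_args above):
#   python <this_script>.py --dit path/to/dit.safetensors --vae path/to/vae.safetensors \
#       --text_encoder1 path/to/text_encoder1 --text_encoder2 path/to/text_encoder2 \
#       --image_encoder path/to/image_encoder --image_path start.png \
#       --prompt "a cat walks on the grass" --video_size 512 768 --video_seconds 3.0 \
#       --infer_steps 25 --save_path outputs --output_type video
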
def main():
    # parse arguments
    args = parse_args()

    # check if latents are provided for decode-only mode
    latents_mode = args.latent_path is not None and len(args.latent_path) > 0

    # set device
    device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    logger.info(f"Using device: {device}")
    args.device = device

    if latents_mode:
        # decode previously saved latents to video, no inference
        original_base_names = []
        latents_list = []
        seeds = []

        assert len(args.latent_path) == 1, "Only one latent path is supported for now"

        for latent_path in args.latent_path:
            original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0])
            seed = 0

            if os.path.splitext(latent_path)[1] != ".safetensors":
                latents = torch.load(latent_path, map_location="cpu")
            else:
                latents = load_file(latent_path)["latent"]
                with safe_open(latent_path, framework="pt") as f:
                    metadata = f.metadata()
                if metadata is None:
                    metadata = {}
                logger.info(f"Loaded metadata: {metadata}")

                if "seeds" in metadata:
                    seed = int(metadata["seeds"])
                if "height" in metadata and "width" in metadata:
                    height = int(metadata["height"])
                    width = int(metadata["width"])
                    args.video_size = [height, width]
                if "video_seconds" in metadata:
                    args.video_seconds = float(metadata["video_seconds"])

            seeds.append(seed)
            logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}")

            if latents.ndim == 5:  # [B, C, T, H, W]
                latents = latents.squeeze(0)  # -> [C, T, H, W]

            latents_list.append(latents)

        latent = torch.stack(latents_list, dim=0)  # [N, C, T, H, W]

        args.seed = seeds[0]

        vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device)
        save_output(args, vae, latent, device, original_base_names)

    elif args.from_file:
        # batch mode: read prompt lines from a file
        with open(args.from_file, "r", encoding="utf-8") as f:
            prompt_lines = f.readlines()

        # process prompts
        prompts_data = preprocess_prompts_for_batch(prompt_lines, args)
        # TODO: run generation for each entry in prompts_data
        raise NotImplementedError("Batch mode is not implemented yet.")

    elif args.interactive:
        # interactive mode: read prompts from console
        raise NotImplementedError("Interactive mode is not implemented yet.")

    else:
        # single generation

        # generate latent
        gen_settings = get_generation_settings(args)
        result = generate(args, gen_settings)
        if result is None:
            # inference was skipped (e.g. --save_merged_model was specified)
            return
        vae, latent = result

        # decode and save output
        save_output(args, vae, latent[0], device)

    logger.info("Done!")


if __name__ == "__main__":
    main()