Spaces:
Paused
Paused
| import argparse | |
| from datetime import datetime | |
| import gc | |
| import json | |
| import random | |
| import os | |
| import re | |
| import time | |
| import math | |
| import copy | |
| from typing import Tuple, Optional, List, Union, Any, Dict | |
| from rich.traceback import install as install_rich_tracebacks | |
| import torch | |
| from safetensors.torch import load_file, save_file | |
| from safetensors import safe_open | |
| from PIL import Image | |
| import cv2 | |
| import numpy as np | |
| import torchvision.transforms.functional as TF | |
| from transformers import LlamaModel | |
| from tqdm import tqdm | |
| from rich_argparse import RichHelpFormatter | |
| from networks import lora_framepack | |
| from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D | |
| from frame_pack import hunyuan | |
| from frame_pack.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked, load_packed_model | |
| from frame_pack.utils import crop_or_pad_yield_mask, resize_and_center_crop, soft_append_bcthw | |
| from frame_pack.bucket_tools import find_nearest_bucket | |
| from frame_pack.clip_vision import hf_clip_vision_encode | |
| from frame_pack.k_diffusion_hunyuan import sample_hunyuan | |
| from dataset import image_video_dataset | |
# Optional dependency: lycoris is only required when --lycoris is passed.
try:
    from lycoris.kohya import create_network_from_weights
except ImportError:
    # FIX: was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit
    # and real errors inside the lycoris package. Only a missing package is
    # acceptable here; --lycoris runs will fail later with a clear NameError.
    pass
| from utils.device_utils import clean_memory_on_device | |
| from base_hv_generate_video import save_images_grid, save_videos_grid, synchronize_device | |
| from base_wan_generate_video import merge_lora_weights | |
| from frame_pack.framepack_utils import load_vae, load_text_encoder1, load_text_encoder2, load_image_encoders | |
| from dataset.image_video_dataset import load_video | |
| from blissful_tuner.blissful_args import add_blissful_args, parse_blissful_args | |
| from blissful_tuner.video_processing_common import save_videos_grid_advanced | |
| from blissful_tuner.latent_preview import LatentPreviewer | |
| import logging | |
| from diffusers_helper.utils import save_bcthw_as_mp4 | |
# Module-level logger; INFO-level basicConfig is applied because this file is
# run as a standalone inference script.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class GenerationSettings:
    """Container for options shared by the generation helpers.

    Attributes:
        device: device inference runs on.
        dit_weight_dtype: dtype the DiT weights are kept in, or None to keep
            them in whatever dtype they were loaded as.
    """

    def __init__(self, device: torch.device, dit_weight_dtype: Optional[torch.dtype] = None):
        self.device = device
        self.dit_weight_dtype = dit_weight_dtype
def parse_args() -> argparse.Namespace:
    """Parse command line arguments for FramePack inference.

    Returns:
        argparse.Namespace: parsed arguments, after post-processing by
        parse_blissful_args and validation of the prompt-source options.

    Raises:
        ValueError: if both --from_file and --interactive are given, or if none
            of --prompt / --from_file / --interactive is specified.
    """
    install_rich_tracebacks()
    parser = argparse.ArgumentParser(description="Framepack inference script", formatter_class=RichHelpFormatter)
    # WAN arguments
    # parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory (Wan 2.1 official).")
    parser.add_argument("--is_f1", action="store_true", help="Use the FramePack F1 model specific logic.")
    parser.add_argument(
        "--sample_solver", type=str, default="unipc", choices=["unipc", "dpm++", "vanilla"], help="The solver used to sample."
    )
    # model paths; --dit wins over --model_version when both are given
    parser.add_argument("--dit", type=str, default=None, help="DiT directory or path. Overrides --model_version if specified.")
    parser.add_argument(
        "--model_version", type=str, default="original", choices=["original", "f1"], help="Select the FramePack model version to use ('original' or 'f1'). Ignored if --dit is specified."
    )
    parser.add_argument("--vae", type=str, default=None, help="VAE directory or path")
    parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory or path")
    parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory or path")
    parser.add_argument("--image_encoder", type=str, required=True, help="Image Encoder directory or path")
    # LoRA
    parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
    parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier")
    parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns")
    parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns")
    parser.add_argument(
        "--save_merged_model",
        type=str,
        default=None,
        help="Save merged model to path. If specified, no inference will be performed.",
    )
    # inference
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="prompt for generation. If `;;;` is used, it will be split into sections. Example: `section_index:prompt` or "
        "`section_index:prompt;;;section_index:prompt;;;...`, section_index can be `0` or `-1` or `0-2`, `-1` means last section, `0-2` means from 0 to 2 (inclusive).",
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default=None,
        help="negative prompt for generation, default is empty string. should not change.",
    )
    parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size, height and width")
    parser.add_argument("--video_seconds", type=float, default=5.0, help="video length, Default is 5.0 seconds")
    parser.add_argument("--fps", type=int, default=30, help="video fps, Default is 30")
    parser.add_argument("--infer_steps", type=int, default=25, help="number of inference steps, Default is 25")
    parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
    # NOTE: seed is a str here, not int — presumably parsed downstream; verify against caller
    parser.add_argument("--seed", type=str, default=None, help="Seed for evaluation.")
    # parser.add_argument(
    #     "--cpu_noise", action="store_true", help="Use CPU to generate noise (compatible with ComfyUI). Default is False."
    # )
    parser.add_argument("--latent_window_size", type=int, default=9, help="latent window size, default is 9. should not change.")
    parser.add_argument(
        "--embedded_cfg_scale", type=float, default=10.0, help="Embeded CFG scale (distilled CFG Scale), default is 10.0"
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=1.0,
        help="Guidance scale for classifier free guidance. Default is 1.0, should not change.",
    )
    parser.add_argument("--guidance_rescale", type=float, default=0.0, help="CFG Re-scale, default is 0.0. Should not change.")
    # parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference")
    parser.add_argument(
        "--image_path",
        type=str,
        default=None,
        help="path to image for image2video inference. If `;;;` is used, it will be used as section images. The notation is same as `--prompt`.",
    )
    parser.add_argument("--end_image_path", type=str, default=None, help="path to end image for image2video inference")
    # parser.add_argument(
    #     "--control_path",
    #     type=str,
    #     default=None,
    #     help="path to control video for inference with controlnet. video file or directory with images",
    # )
    # parser.add_argument("--trim_tail_frames", type=int, default=0, help="trim tail N frames from the video before saving")
    # # Flow Matching
    # parser.add_argument(
    #     "--flow_shift",
    #     type=float,
    #     default=None,
    #     help="Shift factor for flow matching schedulers. Default depends on task.",
    # )
    # precision / fp8 options
    parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
    parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8")
    parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+), only for fp8_scaled mode and can degrade quality slightly but offers noticeable speedup")
    parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
    parser.add_argument(
        "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
    )
    parser.add_argument(
        "--attn_mode",
        type=str,
        default="torch",
        choices=["flash", "torch", "sageattn", "xformers", "sdpa"],  # "flash2", "flash3",
        help="attention mode",
    )
    # VAE memory-saving options
    parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
    parser.add_argument(
        "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
    )
    parser.add_argument("--bulk_decode", action="store_true", help="decode all frames at once")
    parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model")
    parser.add_argument(
        "--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type"
    )
    parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
    parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
    parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference")
    # torch.compile: DYNAMIC and FULLGRAPH are passed as the strings "True"/"False"
    parser.add_argument("--compile", action="store_true", help="Enable torch.compile")
    parser.add_argument(
        "--compile_args",
        nargs=4,
        metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"),
        default=["inductor", "max-autotune-no-cudagraphs", "False", "False"],
        help="Torch.compile settings",
    )
    # New arguments for batch and interactive modes
    parser.add_argument("--from_file", type=str, default=None, help="Read prompts from a file")
    parser.add_argument("--interactive", action="store_true", help="Interactive mode: read prompts from console")
    #parser.add_argument("--preview_latent_every", type=int, default=None, help="Preview latent every N sections")
    parser.add_argument("--preview_suffix", type=str, default=None, help="Unique suffix for preview files to avoid conflicts in concurrent runs.")
    parser.add_argument("--full_preview", action="store_true", help="Save full intermediate video previews instead of latent previews.")
    # TeaCache arguments
    parser.add_argument("--use_teacache", action="store_true", help="Enable TeaCache for faster generation.")
    parser.add_argument("--teacache_steps", type=int, default=25, help="Number of steps for TeaCache initialization (should match --infer_steps).")
    parser.add_argument("--teacache_thresh", type=float, default=0.15, help="Relative L1 distance threshold for TeaCache skipping.")
    parser.add_argument(
        "--video_sections",
        type=int,
        default=None,
        help="number of video sections, Default is None (auto calculate from video seconds). Overrides --video_seconds if set.",
    )
    # let the blissful_tuner extension register its own flags, then post-process
    parser = add_blissful_args(parser)
    args = parser.parse_args()
    args = parse_blissful_args(args)
    # Validate arguments: exactly one prompt source must be usable
    if args.from_file and args.interactive:
        raise ValueError("Cannot use both --from_file and --interactive at the same time")
    if args.prompt is None and not args.from_file and not args.interactive:
        raise ValueError("Either --prompt, --from_file or --interactive must be specified")
    return args
def parse_prompt_line(line: str) -> Dict[str, Any]:
    """Parse one prompt line (prompt text plus ` --opt value` overrides).

    Args:
        line: Prompt line, e.g. ``"a cat --w 512 --d 42"``

    Returns:
        Dict[str, Any]: argument overrides; always contains ``"prompt"``.
        Unrecognized option flags are silently ignored.
    """
    # TODO common function with hv_train_network.line_to_prompt_dict
    # flag -> (override key, value converter); "g" and "l" are aliases
    converters = {
        "w": ("video_size_width", int),
        "h": ("video_size_height", int),
        "f": ("video_seconds", float),
        "d": ("seed", int),
        "s": ("infer_steps", int),
        "g": ("guidance_scale", float),
        "l": ("guidance_scale", float),
        "i": ("image_path", str),
        "cn": ("control_path", str),
        "n": ("negative_prompt", str),
    }
    chunks = line.split(" --")
    overrides: Dict[str, Any] = {"prompt": chunks[0].strip()}
    for chunk in chunks[1:]:
        if not chunk.strip():
            continue
        flag, _, raw_value = chunk.partition(" ")
        entry = converters.get(flag.strip())
        if entry is not None:
            key, convert = entry
            overrides[key] = convert(raw_value.strip())
    return overrides
def apply_overrides(args: argparse.Namespace, overrides: Dict[str, Any]) -> argparse.Namespace:
    """Return a deep copy of *args* with *overrides* applied.

    The pseudo-keys ``video_size_width`` / ``video_size_height`` update the
    corresponding slot of ``args.video_size`` ([height, width]) instead of
    creating new attributes; every other key is assigned verbatim.

    Args:
        args: Original arguments (left unmodified)
        overrides: Dictionary of overrides

    Returns:
        argparse.Namespace: New arguments with overrides applied
    """
    updated = copy.deepcopy(args)
    # pseudo-key -> index into the [height, width] video_size list
    size_slots = {"video_size_width": 1, "video_size_height": 0}
    for name, value in overrides.items():
        slot = size_slots.get(name)
        if slot is None:
            setattr(updated, name, value)
        else:
            updated.video_size[slot] = value
    return updated
def check_inputs(args: argparse.Namespace) -> Tuple[int, int, float]:
    """Validate video dimensions and resolve the clip length.

    When --video_sections is set, the length in seconds is derived from the
    section count and latent window size, and written back to
    ``args.video_seconds`` so downstream consumers see the resolved value.

    Args:
        args: command line arguments

    Returns:
        Tuple[int, int, float]: (height, width, video_seconds)

    Raises:
        ValueError: if height or width is not a multiple of 8.
    """
    height, width = args.video_size
    if args.video_sections is not None:
        # each section yields latent_window_size * 4 frames, plus one extra frame overall
        video_seconds = (args.video_sections * (args.latent_window_size * 4) + 1) / args.fps
        logger.info(f"--video_sections is set to {args.video_sections}. Calculated video_seconds: {video_seconds:.2f}s")
        args.video_seconds = video_seconds
    else:
        video_seconds = args.video_seconds
    if (height % 8) or (width % 8):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
    return height, width, video_seconds
| # region DiT model | |
def get_dit_dtype(args: argparse.Namespace) -> torch.dtype:
    """Map the --precision flag to a torch dtype (bfloat16 unless fp16/fp32)."""
    precision_map = {
        "fp16": torch.float16,
        "fp32": torch.float32,
    }
    return precision_map.get(args.precision, torch.bfloat16)
def load_dit_model(args: argparse.Namespace, device: torch.device) -> HunyuanVideoTransformer3DModelPacked:
    """load DiT model

    Args:
        args: command line arguments
        device: device to use

    Returns:
        HunyuanVideoTransformer3DModelPacked: DiT model
    """
    # Weights may be materialized directly on the target device only when we
    # neither swap blocks, nor re-quantize to scaled fp8, nor merge LoRA weights.
    can_load_on_device = args.blocks_to_swap == 0 and not args.fp8_scaled and args.lora_weight is None
    loading_device = device if can_load_on_device else "cpu"
    # F1 model expects bfloat16 according to demo.
    # However, load_packed_model might handle dtype internally based on checkpoint.
    logger.info(f"Loading DiT model (Class: HunyuanVideoTransformer3DModelPacked) for {'F1' if args.is_f1 else 'Standard'} mode.")
    return load_packed_model(
        device=device,
        dit_path=args.dit,
        attn_mode=args.attn_mode,
        loading_device=loading_device,
    )
def optimize_model(model: HunyuanVideoTransformer3DModelPacked, args: argparse.Namespace, device: torch.device) -> None:
    """Optimize the model in place (FP8 conversion, torch.compile, device placement).

    Args:
        model: dit model
        args: command line arguments
        device: device to use
    """
    if args.fp8_scaled:
        # load state dict as-is and optimize to fp8
        state_dict = model.state_dict()
        # if no blocks to swap, we can move the weights to GPU after optimization on GPU (omit redundant CPU->GPU copy)
        move_to_device = args.blocks_to_swap == 0  # if blocks_to_swap > 0, we will keep the model on CPU
        state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=args.fp8_fast)
        info = model.load_state_dict(state_dict, strict=True, assign=True)
        logger.info(f"Loaded FP8 optimized weights: {info}")
        if args.blocks_to_swap == 0:
            model.to(device)  # make sure all parameters are on the right device (e.g. RoPE etc.)
    else:
        # simple cast to dit_dtype
        target_dtype = None  # load as-is (dit_weight_dtype == dtype of the weights in state_dict)
        target_device = None
        if args.fp8:
            # FIX: the torch attribute is float8_e4m3fn; the previous "torch.float8e4m3fn"
            # raised AttributeError whenever --fp8 was used without --fp8_scaled.
            target_dtype = torch.float8_e4m3fn
        if args.blocks_to_swap == 0:
            logger.info(f"Move model to device: {device}")
            target_device = device
        if target_device is not None and target_dtype is not None:
            model.to(target_device, target_dtype)  # move and cast at the same time. this reduces redundant copy operations
    if args.compile:
        compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args
        logger.info(
            f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]"
        )
        torch._dynamo.config.cache_size_limit = 32
        for i in range(len(model.transformer_blocks)):
            model.transformer_blocks[i] = torch.compile(
                model.transformer_blocks[i],
                backend=compile_backend,
                mode=compile_mode,
                # FIX: exact comparison; the old `x.lower() in "true"` substring test
                # evaluated True for "" / "t" / "ru" and could enable these flags by accident.
                dynamic=compile_dynamic.lower() == "true",
                fullgraph=compile_fullgraph.lower() == "true",
            )
    if args.blocks_to_swap > 0:
        logger.info(f"Enable swap {args.blocks_to_swap} blocks to CPU from device: {device}")
        model.enable_block_swap(args.blocks_to_swap, device, supports_backward=False)
        model.move_to_device_except_swap_blocks(device)
        model.prepare_block_swap_before_forward()
    else:
        # make sure the model is on the right device
        model.to(device)
    model.eval().requires_grad_(False)
    clean_memory_on_device(device)
| # endregion | |
def decode_latent(
    latent_window_size: int,
    total_latent_sections: int,
    bulk_decode: bool,
    vae: AutoencoderKLCausal3D,
    latent: torch.Tensor,
    device: torch.device,
) -> torch.Tensor:
    """Decode a generated video latent into pixel space with the VAE.

    Args:
        latent_window_size: latent frames per generation section (default 9).
        total_latent_sections: number of sections the latent history contains.
        bulk_decode: if True, decode the entire latent in one VAE call instead
            of section by section with overlap blending.
        vae: causal 3D VAE; temporarily moved to ``device``, returned to CPU.
        latent: latent tensor, (B)CTHW; a batch dim is added when ndim == 4.
        device: device used for VAE decoding.

    Returns:
        torch.Tensor: decoded pixels with the batch dimension removed.
    """
    logger.info(f"Decoding video...")
    if latent.ndim == 4:
        latent = latent.unsqueeze(0)  # add batch dimension
    vae.to(device)
    if not bulk_decode:
        latent_window_size = latent_window_size  # default is 9
        # total_latent_sections = (args.video_seconds * 30) / (latent_window_size * 4)
        # total_latent_sections = int(max(round(total_latent_sections), 1))
        num_frames = latent_window_size * 4 - 3
        latents_to_decode = []
        latent_frame_index = 0
        # Walk the sections from the last index down to 0.
        # NOTE(review): this assumes the latent history stores the final section
        # first and that the last section carries one extra latent frame —
        # confirm against the sampler that produced `latent`.
        for i in range(total_latent_sections - 1, -1, -1):
            is_last_section = i == total_latent_sections - 1
            generated_latent_frames = (num_frames + 3) // 4 + (1 if is_last_section else 0)
            # Slice a double-window span so neighboring sections overlap for blending.
            section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
            section_latent = latent[:, :, latent_frame_index : latent_frame_index + section_latent_frames, :, :]
            latents_to_decode.append(section_latent)
            latent_frame_index += generated_latent_frames
        latents_to_decode = latents_to_decode[::-1]  # reverse the order of latents to decode
        history_pixels = None
        for latent in tqdm(latents_to_decode):
            if history_pixels is None:
                history_pixels = hunyuan.vae_decode(latent, vae).cpu()
            else:
                # Cross-fade the overlapping frames between consecutive sections.
                overlapped_frames = latent_window_size * 4 - 3
                current_pixels = hunyuan.vae_decode(latent, vae).cpu()
                history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
            clean_memory_on_device(device)
    else:
        # bulk decode: single VAE pass over the whole latent
        logger.info(f"Bulk decoding")
        history_pixels = hunyuan.vae_decode(latent, vae).cpu()
    vae.to("cpu")  # free device memory once decoding is done
    logger.info(f"Decoded. Pixel shape {history_pixels.shape}")
    return history_pixels[0]  # remove batch dimension
| def prepare_i2v_inputs( | |
| args: argparse.Namespace, | |
| device: torch.device, | |
| vae: AutoencoderKLCausal3D, | |
| encoded_context: Optional[Dict] = None, | |
| encoded_context_n: Optional[Dict] = None, | |
| ) -> Tuple[int, int, float, dict, dict, dict, torch.Tensor]: # Adjusted return type annotation | |
| """Prepare inputs for I2V | |
| Args: | |
| args: command line arguments | |
| device: device to use | |
| vae: VAE model, used for image encoding | |
| encoded_context: Pre-encoded text context | |
| encoded_context_n: Pre-encoded negative text context | |
| Returns: | |
| Tuple[int, int, float, dict, dict, dict, torch.Tensor]: | |
| (height, width, video_seconds, context, context_null, context_img, end_latent) | |
| """ | |
| def parse_section_strings(input_string: str) -> dict[int, str]: | |
| section_strings = {} | |
| if not input_string: # Handle empty input string | |
| return {0: ""} | |
| if ";;;" in input_string: | |
| split_section_strings = input_string.split(";;;") | |
| for section_str in split_section_strings: | |
| if ":" not in section_str: | |
| start = end = 0 | |
| section_str_val = section_str.strip() | |
| else: | |
| index_str, section_str_val = section_str.split(":", 1) | |
| index_str = index_str.strip() | |
| section_str_val = section_str_val.strip() | |
| m = re.match(r"^(-?\d+)(-\d+)?$", index_str) | |
| if m: | |
| start = int(m.group(1)) | |
| end = int(m.group(2)[1:]) if m.group(2) is not None else start | |
| else: | |
| start = end = 0 # Default to 0 if index format is invalid | |
| # Handle negative indices relative to a hypothetical 'last section' (-1) | |
| # This part is tricky without knowing the total sections beforehand. | |
| # For now, treat negative indices directly. A better approach might involve | |
| # resolving them later in the generation loop. | |
| for i in range(start, end + 1): | |
| section_strings[i] = section_str_val | |
| else: | |
| # If no section specifiers, assume section 0 | |
| section_strings[0] = input_string.strip() | |
| # Ensure section 0 exists if any sections are defined | |
| if section_strings and 0 not in section_strings: | |
| indices = list(section_strings.keys()) | |
| # Prefer smallest non-negative index, otherwise smallest negative index | |
| try: | |
| first_positive_index = min(i for i in indices if i >= 0) | |
| section_index = first_positive_index | |
| except ValueError: # No non-negative indices | |
| section_index = min(indices) if indices else 0 # Fallback to 0 if empty | |
| if section_index in section_strings: | |
| section_strings[0] = section_strings[section_index] | |
| elif section_strings: # If section_index wasn't valid somehow, pick first available | |
| section_strings[0] = next(iter(section_strings.values())) | |
| else: # If section_strings was empty initially | |
| section_strings[0] = "" # Default empty prompt | |
| # If still no section 0 (e.g., empty input string initially) | |
| if 0 not in section_strings: | |
| section_strings[0] = "" | |
| return section_strings | |
| # prepare image preprocessing function | |
| def preprocess_image(image_path: str, target_height: int, target_width: int, is_f1: bool): # is_f1 is kept for signature, but not used differently here | |
| image = Image.open(image_path).convert("RGB") | |
| image_np = np.array(image) # PIL to numpy, HWC | |
| # Consistent image preprocessing for both F1 and standard mode, | |
| # using target_height/target_width which come from args.video_size | |
| image_np = image_video_dataset.resize_image_to_bucket(image_np, (target_width, target_height)) | |
| processed_height, processed_width = image_np.shape[0], image_np.shape[1] # Get actual size after resize | |
| image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0 # -1 to 1.0, HWC | |
| image_tensor = image_tensor.permute(2, 0, 1)[None, :, None] # HWC -> CHW -> NCFHW, N=1, C=3, F=1 | |
| return image_tensor, image_np, processed_height, processed_width | |
| # Initial height/width check. These dimensions will be used for image processing and generation. | |
| height, width, video_seconds = check_inputs(args) | |
| logger.info(f"Video dimensions for processing and generation set to: {height}x{width} (from --video_size or default).") | |
| section_image_paths = parse_section_strings(args.image_path) | |
| section_images = {} | |
| first_image_processed = False | |
| for index, image_path in section_image_paths.items(): | |
| img_tensor, img_np, proc_h, proc_w = preprocess_image(image_path, height, width, args.is_f1) | |
| section_images[index] = (img_tensor, img_np) | |
| if not first_image_processed and image_path: | |
| default_video_size_used = (args.video_size[0] == 256 and args.video_size[1] == 256) # Check if default was used | |
| if default_video_size_used and (proc_h != height or proc_w != width): | |
| logger.info(f"Video dimensions updated to {proc_h}x{proc_w} based on first image processing (as default --video_size was used).") | |
| height, width = proc_h, proc_w | |
| args.video_size = [height, width] # Update args for consistency for downstream logging/metadata. | |
| elif not default_video_size_used and (proc_h != height or proc_w != width): | |
| logger.warning(f"User specified --video_size {height}x{width}, but first image processed to {proc_h}x{proc_w}. " | |
| f"Generation will use {height}x{width}. Conditioning image aspect might differ.") | |
| first_image_processed = True | |
| # Process end image if provided | |
| if args.end_image_path is not None: | |
| end_img_tensor, end_img_np, _, _ = preprocess_image(args.end_image_path, height, width, args.is_f1) | |
| else: | |
| end_img_tensor, end_img_np = None, None | |
| # configure negative prompt | |
| n_prompt = args.negative_prompt if args.negative_prompt else "" | |
| if encoded_context is None or encoded_context_n is None: # Regenerate if either is missing | |
| # parse section prompts | |
| section_prompts = parse_section_strings(args.prompt) | |
| # load text encoder | |
| # Assuming load_text_encoder1/2 are compatible | |
| tokenizer1, text_encoder1 = load_text_encoder1(args, args.fp8_llm, device) | |
| tokenizer2, text_encoder2 = load_text_encoder2(args) | |
| text_encoder2.to(device) | |
| logger.info(f"Encoding prompts...") | |
| llama_vecs = {} | |
| llama_attention_masks = {} | |
| clip_l_poolers = {} | |
| # Use a common dtype for text encoders if possible, respecting fp8 flag | |
| text_encoder_dtype = torch.float8_e4m3fn if args.fp8_llm else torch.float16 # text_encoder1.dtype | |
| # Pre-allocate negative prompt tensors only if needed | |
| llama_vec_n, clip_l_pooler_n = None, None | |
| llama_attention_mask_n = None | |
| # Encode positive prompts first | |
| with torch.autocast(device_type=device.type, dtype=text_encoder_dtype), torch.no_grad(): | |
| for index, prompt in section_prompts.items(): | |
| # Ensure prompt is not empty before encoding | |
| current_prompt = prompt if prompt else "" # Use empty string if prompt is None or empty | |
| llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(current_prompt, text_encoder1, text_encoder2, tokenizer1, tokenizer2) | |
| # Pad/crop and store | |
| llama_vec_padded, llama_attention_mask = crop_or_pad_yield_mask(llama_vec.cpu(), length=512) # Move to CPU before padding | |
| llama_vecs[index] = llama_vec_padded | |
| llama_attention_masks[index] = llama_attention_mask | |
| clip_l_poolers[index] = clip_l_pooler.cpu() # Move to CPU | |
| # Use the encoding of section 0 as fallback for negative if needed | |
| if index == 0 and args.guidance_scale == 1.0: | |
| llama_vec_n = torch.zeros_like(llama_vec_padded) | |
| llama_attention_mask_n = torch.zeros_like(llama_attention_mask) | |
| clip_l_pooler_n = torch.zeros_like(clip_l_poolers[0]) | |
| # Encode negative prompt if needed | |
| if args.guidance_scale != 1.0: | |
| with torch.autocast(device_type=device.type, dtype=text_encoder_dtype), torch.no_grad(): | |
| current_n_prompt = n_prompt if n_prompt else "" | |
| llama_vec_n_raw, clip_l_pooler_n_raw = hunyuan.encode_prompt_conds( | |
| current_n_prompt, text_encoder1, text_encoder2, tokenizer1, tokenizer2 | |
| ) | |
| llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n_raw.cpu(), length=512) # Move to CPU | |
| clip_l_pooler_n = clip_l_pooler_n_raw.cpu() # Move to CPU | |
| # Check if negative prompt was generated (handles guidance_scale=1.0 case) | |
| if llama_vec_n is None: | |
| logger.warning("Negative prompt tensors not generated (likely guidance_scale=1.0). Using zeros.") | |
| # Assuming section 0 exists and was processed | |
| llama_vec_n = torch.zeros_like(llama_vecs[0]) | |
| llama_attention_mask_n = torch.zeros_like(llama_attention_masks[0]) | |
| clip_l_pooler_n = torch.zeros_like(clip_l_poolers[0]) | |
| # free text encoder and clean memory | |
| del text_encoder1, text_encoder2, tokenizer1, tokenizer2 | |
| clean_memory_on_device(device) | |
| # load image encoder (Handles SigLIP via framepack_utils) | |
| feature_extractor, image_encoder = load_image_encoders(args) | |
| image_encoder.to(device) | |
| # encode image with image encoder | |
| logger.info(f"Encoding images with {'SigLIP' if args.is_f1 else 'Image Encoder'}...") | |
| section_image_encoder_last_hidden_states = {} | |
| img_encoder_dtype = image_encoder.dtype # Get dtype from loaded model | |
| end_image_embedding_for_f1 = None # Initialize for F1 end image | |
| with torch.autocast(device_type=device.type, dtype=img_encoder_dtype), torch.no_grad(): | |
| for index, (img_tensor, img_np) in section_images.items(): | |
| # Use hf_clip_vision_encode (works for SigLIP too) | |
| image_encoder_output = hf_clip_vision_encode(img_np, feature_extractor, image_encoder) | |
| image_encoder_last_hidden_state = image_encoder_output.last_hidden_state.cpu() # Move to CPU | |
| section_image_encoder_last_hidden_states[index] = image_encoder_last_hidden_state | |
| if args.is_f1 and end_img_np is not None: # end_img_np is from args.end_image_path | |
| logger.info("F1 Mode: Encoding end image for potential conditioning.") | |
| end_image_encoder_output_f1 = hf_clip_vision_encode(end_img_np, feature_extractor, image_encoder) | |
| end_image_embedding_for_f1 = end_image_encoder_output_f1.last_hidden_state.cpu() | |
| # free image encoder and clean memory | |
| del image_encoder, feature_extractor | |
| clean_memory_on_device(device) | |
| # --- Store encoded contexts for potential reuse --- | |
| # Positive context (bundle per unique prompt string if needed, or just section 0) | |
| # For simplicity, let's assume we only cache based on args.prompt for now | |
| encoded_context = { | |
| "llama_vecs": llama_vecs, | |
| "llama_attention_masks": llama_attention_masks, | |
| "clip_l_poolers": clip_l_poolers, | |
| "image_encoder_last_hidden_states": section_image_encoder_last_hidden_states # Store all section states | |
| } | |
| # Negative context | |
| encoded_context_n = { | |
| "llama_vec": llama_vec_n, | |
| "llama_attention_mask": llama_attention_mask_n, | |
| "clip_l_pooler": clip_l_pooler_n, | |
| } | |
| # --- End context caching --- | |
| else: | |
| # Use pre-encoded context | |
| logger.info("Using pre-encoded context.") | |
| llama_vecs = encoded_context["llama_vecs"] | |
| llama_attention_masks = encoded_context["llama_attention_masks"] | |
| clip_l_poolers = encoded_context["clip_l_poolers"] | |
| section_image_encoder_last_hidden_states = encoded_context["image_encoder_last_hidden_states"] # Retrieve all sections | |
| llama_vec_n = encoded_context_n["llama_vec"] | |
| llama_attention_mask_n = encoded_context_n["llama_attention_mask"] | |
| clip_l_pooler_n = encoded_context_n["clip_l_pooler"] | |
| # Need to re-parse section prompts if using cached context | |
| section_prompts = parse_section_strings(args.prompt) | |
| # VAE encoding | |
| logger.info(f"Encoding image(s) to latent space...") | |
| vae.to(device) | |
| vae_dtype = vae.dtype # Get VAE dtype | |
| section_start_latents = {} | |
| with torch.autocast(device_type=device.type, dtype=vae_dtype), torch.no_grad(): | |
| for index, (img_tensor, img_np) in section_images.items(): | |
| start_latent = hunyuan.vae_encode(img_tensor, vae).cpu() # Move to CPU | |
| section_start_latents[index] = start_latent | |
| end_latent = hunyuan.vae_encode(end_img_tensor, vae).cpu() if end_img_tensor is not None else None # Move to CPU | |
| vae.to("cpu") # move VAE to CPU to save memory | |
| clean_memory_on_device(device) | |
| # prepare model input arguments | |
| arg_c = {} # Positive text conditioning per section | |
| arg_c_img = {} # Positive image conditioning per section | |
| # Ensure section_prompts is available (parsed earlier) | |
| if 'section_prompts' not in locals(): | |
| section_prompts = parse_section_strings(args.prompt) | |
| # Populate positive text args | |
| for index in llama_vecs.keys(): | |
| # Get corresponding prompt, defaulting to empty string if index missing | |
| prompt_text = section_prompts.get(index, "") | |
| arg_c_i = { | |
| "llama_vec": llama_vecs[index], | |
| "llama_attention_mask": llama_attention_masks[index], | |
| "clip_l_pooler": clip_l_poolers[index], | |
| "prompt": prompt_text, # Include the actual prompt text | |
| } | |
| arg_c[index] = arg_c_i | |
| # Populate negative text args (only one needed) | |
| arg_null = { | |
| "llama_vec": llama_vec_n, | |
| "llama_attention_mask": llama_attention_mask_n, | |
| "clip_l_pooler": clip_l_pooler_n, | |
| "prompt": n_prompt, # Include negative prompt text | |
| } | |
| # Populate positive image args | |
| for index in section_start_latents.keys(): # Use latents keys as reference | |
| # Check if corresponding hidden state exists, fallback to section 0 if needed | |
| image_encoder_last_hidden_state = section_image_encoder_last_hidden_states.get(index, section_image_encoder_last_hidden_states.get(0)) | |
| if image_encoder_last_hidden_state is None and section_image_encoder_last_hidden_states: | |
| # Absolute fallback if index and 0 are missing but others exist | |
| image_encoder_last_hidden_state = next(iter(section_image_encoder_last_hidden_states.values())) | |
| elif image_encoder_last_hidden_state is None: | |
| raise ValueError(f"Cannot find image encoder state for section {index} or fallback section 0.") | |
| arg_c_img_i = { | |
| "image_encoder_last_hidden_state": image_encoder_last_hidden_state, | |
| "start_latent": section_start_latents[index] | |
| } | |
| arg_c_img[index] = arg_c_img_i | |
| # Ensure fallback section 0 exists in arg_c and arg_c_img if needed later | |
| if 0 not in arg_c and arg_c: | |
| arg_c[0] = next(iter(arg_c.values())) | |
| if 0 not in arg_c_img and arg_c_img: | |
| arg_c_img[0] = next(iter(arg_c_img.values())) | |
| # Final check for minimal context existence | |
| if not arg_c or not arg_c_img: | |
| raise ValueError("Failed to prepare conditioning arguments. Check prompts and image paths.") | |
| return height, width, video_seconds, arg_c, arg_null, arg_c_img, end_latent, end_image_embedding_for_f1 | |
| # def setup_scheduler(args: argparse.Namespace, config, device: torch.device) -> Tuple[Any, torch.Tensor]: | |
| # """setup scheduler for sampling | |
| # Args: | |
| # args: command line arguments | |
| # config: model configuration | |
| # device: device to use | |
| # Returns: | |
| # Tuple[Any, torch.Tensor]: (scheduler, timesteps) | |
| # """ | |
| # if args.sample_solver == "unipc": | |
| # scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False) | |
| # scheduler.set_timesteps(args.infer_steps, device=device, shift=args.flow_shift) | |
| # timesteps = scheduler.timesteps | |
| # elif args.sample_solver == "dpm++": | |
| # scheduler = FlowDPMSolverMultistepScheduler( | |
| # num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False | |
| # ) | |
| # sampling_sigmas = get_sampling_sigmas(args.infer_steps, args.flow_shift) | |
| # timesteps, _ = retrieve_timesteps(scheduler, device=device, sigmas=sampling_sigmas) | |
| # elif args.sample_solver == "vanilla": | |
| # scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=config.num_train_timesteps, shift=args.flow_shift) | |
| # scheduler.set_timesteps(args.infer_steps, device=device) | |
| # timesteps = scheduler.timesteps | |
| # # FlowMatchDiscreteScheduler does not support generator argument in step method | |
| # org_step = scheduler.step | |
| # def step_wrapper( | |
| # model_output: torch.Tensor, | |
| # timestep: Union[int, torch.Tensor], | |
| # sample: torch.Tensor, | |
| # return_dict: bool = True, | |
| # generator=None, | |
| # ): | |
| # return org_step(model_output, timestep, sample, return_dict=return_dict) | |
| # scheduler.step = step_wrapper | |
| # else: | |
| # raise NotImplementedError("Unsupported solver.") | |
| # return scheduler, timesteps | |
| # In fpack_generate_video.py | |
def generate(args: argparse.Namespace, gen_settings: GenerationSettings, shared_models: Optional[Dict] = None) -> Tuple[AutoencoderKLCausal3D, torch.Tensor]:  # Return VAE too
    """main function for generation

    Drives both sampling modes: the F1 model (forward, section-by-section
    generation appending to history) and the standard model (reverse-order
    section generation prepending to history). Also handles optional shared
    (pre-loaded) models, LoRA merging, TeaCache initialization and latent
    previews.

    Args:
        args: command line arguments
        gen_settings: Generation settings object (device / DiT weight dtype)
        shared_models: dictionary containing pre-loaded models and encoded data

    Returns:
        Tuple[AutoencoderKLCausal3D, torch.Tensor]: (vae, generated latents as a
        CPU tensor of shape (B, C, T_total, H//8, W//8)), or (None, None) when
        --save_merged_model short-circuits generation.
    """
    device, dit_weight_dtype = (gen_settings.device, gen_settings.dit_weight_dtype)
    # prepare seed: coerce to int, falling back to a random seed on bad input
    seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
    if isinstance(seed, str):
        try:
            seed = int(seed)
        except ValueError:
            logger.warning(f"Invalid seed string: {seed}. Generating random seed.")
            seed = random.randint(0, 2**32 - 1)
    elif not isinstance(seed, int):
        logger.warning(f"Invalid seed type: {type(seed)}. Generating random seed.")
        seed = random.randint(0, 2**32 - 1)
    args.seed = seed  # set seed to args for saving
    vae = None  # Initialize VAE
    end_image_embedding_for_f1 = None  # set by prepare_i2v_inputs below; default kept for safety
    # Check if we have shared models
    if shared_models is not None:
        # Use shared models and encoded data
        vae = shared_models.get("vae")
        model = shared_models.get("model")  # NOTE(review): may be None if caller did not pre-load the DiT — confirm caller contract
        # --- Retrieve cached context ---
        # Try to get context based on the full prompt string first
        prompt_key = args.prompt if args.prompt else ""
        n_prompt_key = args.negative_prompt if args.negative_prompt else ""
        encoded_context = shared_models.get("encoded_contexts", {}).get(prompt_key)
        encoded_context_n = shared_models.get("encoded_contexts", {}).get(n_prompt_key)
        # If not found, maybe the cache uses a simpler key (like just section 0?) - needs alignment with prepare_i2v_inputs caching logic
        # For now, assume prepare_i2v_inputs handles regeneration if cache miss
        if encoded_context is None or encoded_context_n is None:
            logger.info("Cached context not found or incomplete, preparing inputs.")
            # Need VAE for preparation if regenerating context
            if vae is None:
                vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device)
            # FIX: prepare_i2v_inputs returns 8 values (including end_image_embedding_for_f1);
            # unpacking only 7 here raised "too many values to unpack" in this path.
            height, width, video_seconds, context, context_null, context_img, end_latent, end_image_embedding_for_f1 = prepare_i2v_inputs(
                args, device, vae  # Pass VAE here
            )
            # Store newly generated context back? (Requires shared_models to be mutable and handled carefully)
            # shared_models["encoded_contexts"][prompt_key] = context # Simplified example
            # shared_models["encoded_contexts"][n_prompt_key] = context_null # Simplified example
        else:
            logger.info("Using cached context from shared models.")
            # Need VAE if decoding later, load if not present
            if vae is None:
                vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device)
            # FIX: same 8-value unpack as above (was 7, which crashed at runtime)
            height, width, video_seconds, context, context_null, context_img, end_latent, end_image_embedding_for_f1 = prepare_i2v_inputs(
                args, device, vae, encoded_context, encoded_context_n
            )
        # --- End context retrieval ---
    else:
        # prepare inputs without shared models
        vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device)
        height, width, video_seconds, context, context_null, context_img, end_latent, end_image_embedding_for_f1 = prepare_i2v_inputs(args, device, vae)
        # load DiT model
        model = load_dit_model(args, device)  # Handles F1 class loading implicitly
        # merge LoRA weights
        if args.lora_weight is not None and len(args.lora_weight) > 0:
            # Ensure merge_lora_weights can handle HunyuanVideoTransformer3DModelPacked
            # It might need adjustments depending on its implementation.
            logger.info("Merging LoRA weights...")
            # Assuming lora_framepack is the correct network type definition
            # Make sure merge_lora_weights exists and is imported
            try:
                from base_wan_generate_video import merge_lora_weights  # Example import path
                merge_lora_weights(lora_framepack, model, args, device)
            except ImportError:
                logger.error("merge_lora_weights function not found. Skipping LoRA merge.")
            except Exception as e:
                logger.error(f"Error merging LoRA weights: {e}")
            # if we only want to save the model, we can skip the rest
            if args.save_merged_model:
                # Implement saving logic here if merge_lora_weights doesn't handle it
                logger.info(f"Saving merged model to {args.save_merged_model} and exiting.")
                # Example: save_model(model, args.save_merged_model)
                return None, None  # Indicate no generation occurred
        # optimize model: fp8 conversion, block swap etc.
        optimize_model(model, args, device)
    # TeaCache configuration applies per generation run regardless of model source
    if args.use_teacache:
        logger.info(f"Initializing TeaCache: steps={args.teacache_steps}, threshold={args.teacache_thresh}")
        # The model's initialize_teacache expects num_steps and rel_l1_thresh
        model.initialize_teacache(
            enable_teacache=True,
            num_steps=args.teacache_steps,
            rel_l1_thresh=args.teacache_thresh
        )
    else:
        logger.info("TeaCache is disabled.")
        # Ensure it's explicitly disabled in the model too, just in case
        model.initialize_teacache(enable_teacache=False)
    # --- Sampling ---
    latent_window_size = args.latent_window_size  # default is 9 (consistent with F1 demo)
    if args.video_sections is not None:
        total_latent_sections = args.video_sections
        logger.info(f"Using --video_sections: {total_latent_sections} sections.")
    else:
        # Each section covers latent_window_size * 4 pixel frames (approximately)
        total_latent_sections = (video_seconds * args.fps) / (latent_window_size * 4)
        total_latent_sections = int(max(round(total_latent_sections), 1))
        logger.info(f"Calculated total_latent_sections from video_seconds: {total_latent_sections} sections.")
    # set random generator
    seed_g = torch.Generator(device="cpu")  # Keep noise on CPU initially
    seed_g.manual_seed(seed)
    # F1 expects frames = latent_window_size * 4 - 3
    # Our script's default decode uses latent_window_size * 4 - 3 overlap
    # Let's calculate F1 frames per section explicitly
    f1_frames_per_section = latent_window_size * 4 - 3
    logger.info(
        f"Mode: {'F1' if args.is_f1 else 'Standard'}, "
        f"Video size: {height}x{width}@{video_seconds:.2f}s, fps: {args.fps}, num sections: {total_latent_sections}, "
        f"infer_steps: {args.infer_steps}, frames per generation step: {f1_frames_per_section}"
    )
    # Determine compute dtype based on model/args
    compute_dtype = model.dtype if hasattr(model, 'dtype') else torch.bfloat16  # Default for F1
    if args.fp8 or args.fp8_scaled:
        # FP8 might still use bfloat16/float16 for some operations
        logger.info("FP8 enabled, using bfloat16 for intermediate computations.")
        compute_dtype = torch.bfloat16  # Or potentially float16 depending on model/ops
    logger.info(f"Using compute dtype: {compute_dtype}")
    # --- F1 Model Specific Sampling Logic ---
    if args.is_f1:  # Renamed from args.f1 in simpler script to args.is_f1
        logger.info("Starting F1 model sampling process.")
        logger.info(f"F1 Mode: Using video dimensions {height}x{width} for latent operations and generation.")
        # 19 = 16 (4x) + 2 (2x) + 1 (1x) clean-latent context frames; start latent appended below
        history_latents = torch.zeros((1, 16, 19, height // 8, width // 8), dtype=torch.float32, device='cpu')
        start_latent_0 = context_img.get(0, {}).get("start_latent")
        if start_latent_0 is None:
            raise ValueError("Cannot find start_latent for section 0 in context_img.")
        if start_latent_0.shape[3] != (height // 8) or start_latent_0.shape[4] != (width // 8):
            logger.error(f"Mismatch between start_latent_0 dimensions ({start_latent_0.shape[3]}x{start_latent_0.shape[4]}) "
                         f"and history_latents dimensions ({height//8}x{width//8}). This should not happen with current logic.")
        history_latents = torch.cat([history_latents, start_latent_0.cpu().float()], dim=2)
        history_pixels_for_preview_f1_cpu = None
        if args.full_preview and args.preview_latent_every is not None:
            if vae is None:
                logger.error("VAE not available for initial F1 preview setup.")
            else:
                logger.info("F1 Full Preview: Decoding initial start_latent for preview history.")
                vae.to(device)
                initial_latent_for_preview = start_latent_0.to(device, dtype=vae.dtype if hasattr(vae, 'dtype') else torch.float16)
                # Assuming vae_decode returns BCTHW or CTHW. Ensure BCTHW for history_pixels.
                decoded_initial = hunyuan.vae_decode(initial_latent_for_preview, vae).cpu()
                if decoded_initial.ndim == 4:  # CTHW
                    history_pixels_for_preview_f1_cpu = decoded_initial.unsqueeze(0)
                elif decoded_initial.ndim == 5:  # BCTHW
                    history_pixels_for_preview_f1_cpu = decoded_initial
                else:
                    logger.error(f"Unexpected dimensions from initial VAE decode: {decoded_initial.shape}")
                vae.to("cpu")
                clean_memory_on_device(device)
        total_generated_latent_frames = 1  # Account for the initial start_latent_0 in history_latents
        if args.preview_latent_every and not args.full_preview:
            previewer = LatentPreviewer(args, vae, None, gen_settings.device, compute_dtype, model_type="framepack")
        else:
            previewer = None
        for section_index in range(total_latent_sections):
            logger.info(f"--- F1 Section {section_index + 1} / {total_latent_sections} ---")
            # Index layout: [start(1), 4x(16), 2x(2), 1x(1), new latents(window)]
            f1_split_sizes = [1, 16, 2, 1, args.latent_window_size]
            f1_indices = torch.arange(0, sum(f1_split_sizes)).unsqueeze(0).to(device)
            (
                f1_clean_latent_indices_start,
                f1_clean_latent_4x_indices,
                f1_clean_latent_2x_indices,
                f1_clean_latent_1x_indices,
                f1_latent_indices,
            ) = f1_indices.split(f1_split_sizes, dim=1)
            f1_clean_latent_indices = torch.cat([f1_clean_latent_indices_start, f1_clean_latent_1x_indices], dim=1)
            # Per-section image conditioning, falling back to section 0
            current_image_context_section_idx = section_index if section_index in context_img else 0
            current_start_latent = context_img[current_image_context_section_idx]["start_latent"].to(device, dtype=torch.float32)
            current_history_for_f1_clean = history_latents[:, :, -sum([16, 2, 1]):, :, :].to(device, dtype=torch.float32)
            f1_clean_latents_4x, f1_clean_latents_2x, f1_clean_latents_1x = current_history_for_f1_clean.split([16, 2, 1], dim=2)
            f1_clean_latents_combined = torch.cat([current_start_latent, f1_clean_latents_1x], dim=2)
            # Per-section text conditioning, falling back to section 0
            context_section_idx = section_index if section_index in context else 0
            llama_vec = context[context_section_idx]["llama_vec"].to(device, dtype=compute_dtype)
            llama_attention_mask = context[context_section_idx]["llama_attention_mask"].to(device)
            clip_l_pooler = context[context_section_idx]["clip_l_pooler"].to(device, dtype=compute_dtype)
            image_encoder_last_hidden_state = context_img[current_image_context_section_idx]["image_encoder_last_hidden_state"].to(device, dtype=compute_dtype)
            llama_vec_n = context_null["llama_vec"].to(device, dtype=compute_dtype)
            llama_attention_mask_n = context_null["llama_attention_mask"].to(device)
            clip_l_pooler_n = context_null["clip_l_pooler"].to(device, dtype=compute_dtype)
            # generated_latents_step is on GPU after sample_hunyuan
            generated_latents_step = sample_hunyuan(
                transformer=model, sampler=args.sample_solver, width=width, height=height,
                frames=f1_frames_per_section, real_guidance_scale=args.guidance_scale,
                distilled_guidance_scale=args.embedded_cfg_scale, guidance_rescale=args.guidance_rescale,
                num_inference_steps=args.infer_steps, generator=seed_g,
                prompt_embeds=llama_vec, prompt_embeds_mask=llama_attention_mask, prompt_poolers=clip_l_pooler,
                negative_prompt_embeds=llama_vec_n, negative_prompt_embeds_mask=llama_attention_mask_n, negative_prompt_poolers=clip_l_pooler_n,
                device=device, dtype=compute_dtype, image_embeddings=image_encoder_last_hidden_state,
                latent_indices=f1_latent_indices, clean_latents=f1_clean_latents_combined, clean_latent_indices=f1_clean_latent_indices,
                clean_latents_2x=f1_clean_latents_2x, clean_latent_2x_indices=f1_clean_latent_2x_indices,
                clean_latents_4x=f1_clean_latents_4x, clean_latent_4x_indices=f1_clean_latent_4x_indices,
            )
            newly_generated_latent_frames_count_this_step = int(generated_latents_step.shape[2])
            # F1 appends forward in time; history is kept on CPU to save VRAM
            history_latents = torch.cat([history_latents, generated_latents_step.cpu().float()], dim=2)
            total_generated_latent_frames += newly_generated_latent_frames_count_this_step
            if args.preview_latent_every is not None and (section_index + 1) % args.preview_latent_every == 0:
                if args.full_preview:
                    logger.info(f"Saving full F1 preview at section {section_index + 1}")
                    if vae is None:
                        logger.error("VAE not available for full F1 preview.")
                    else:
                        preview_filename_full = os.path.join(args.save_path, f"latent_preview_{args.preview_suffix if args.preview_suffix else section_index + 1}.mp4")
                        latents_this_step_for_decode = generated_latents_step.to(device, dtype=vae.dtype if hasattr(vae, 'dtype') else torch.float16)
                        vae.to(device)
                        pixels_this_step_decoded_cpu = hunyuan.vae_decode(latents_this_step_for_decode, vae).cpu()
                        vae.to("cpu")
                        if pixels_this_step_decoded_cpu.ndim == 4:
                            pixels_this_step_decoded_cpu = pixels_this_step_decoded_cpu.unsqueeze(0)
                        if history_pixels_for_preview_f1_cpu is None:
                            history_pixels_for_preview_f1_cpu = pixels_this_step_decoded_cpu
                        else:
                            overlap_pixels = args.latent_window_size * 4 - 3
                            history_pixels_for_preview_f1_cpu = soft_append_bcthw(
                                history_pixels_for_preview_f1_cpu,
                                pixels_this_step_decoded_cpu,
                                overlap=overlap_pixels
                            )
                        save_bcthw_as_mp4(history_pixels_for_preview_f1_cpu, preview_filename_full, fps=args.fps, crf=getattr(args, 'mp4_crf', 16))
                        logger.info(f"Full F1 preview saved to {preview_filename_full}")
                        del latents_this_step_for_decode, pixels_this_step_decoded_cpu
                        clean_memory_on_device(device)
                elif previewer is not None:
                    logger.info(f"Previewing latents at F1 section {section_index + 1}")
                    preview_latents_f1_for_pv = history_latents[:, :, -total_generated_latent_frames:, :, :].to(gen_settings.device)
                    previewer.preview(preview_latents_f1_for_pv, section_index, preview_suffix=args.preview_suffix)
                    del preview_latents_f1_for_pv
                    clean_memory_on_device(gen_settings.device)
            # Free per-section GPU tensors before the next iteration
            del generated_latents_step, current_history_for_f1_clean, f1_clean_latents_combined
            del f1_clean_latents_1x, f1_clean_latents_2x, f1_clean_latents_4x, current_start_latent
            del llama_vec, llama_attention_mask, clip_l_pooler, image_encoder_last_hidden_state
            del llama_vec_n, llama_attention_mask_n, clip_l_pooler_n
            clean_memory_on_device(device)
        real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
        # No resizing needed as generation happened at target dimensions.
    # --- Standard Model Sampling Logic ---
    else:  # Standard mode
        logger.info("Starting standard model sampling process.")
        # 1 + 2 + 16 = post/2x/4x clean-latent context frames for reverse generation
        history_latents = torch.zeros((1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=torch.float32, device='cpu')
        if end_latent is not None:
            logger.info(f"Using end image: {args.end_image_path}")
            history_latents[:, :, 0:1] = end_latent.cpu().float()
        total_generated_latent_frames = 0
        history_pixels_for_preview_std_cpu = None  # Initialize pixel history
        # For standard mode (backward generation), the first chunk generated is the "end" of the video.
        # If end_latent is provided and previews are on, we should decode it to start the preview history.
        if args.full_preview and args.preview_latent_every is not None and end_latent is not None:
            if vae is None:
                logger.error("VAE not available for initial Standard mode preview setup with end_latent.")
            else:
                logger.info("Standard Full Preview: Decoding initial end_latent for preview history.")
                vae.to(device)
                initial_latent_for_preview = end_latent.to(device, dtype=vae.dtype if hasattr(vae, 'dtype') else torch.float16)
                decoded_initial = hunyuan.vae_decode(initial_latent_for_preview, vae).cpu()
                if decoded_initial.ndim == 4:  # CTHW
                    history_pixels_for_preview_std_cpu = decoded_initial.unsqueeze(0)
                elif decoded_initial.ndim == 5:  # BCTHW
                    history_pixels_for_preview_std_cpu = decoded_initial
                else:
                    logger.error(f"Unexpected dimensions from initial VAE decode for end_latent: {decoded_initial.shape}")
                vae.to("cpu")
                clean_memory_on_device(device)
        latent_paddings = list(reversed(range(total_latent_sections)))
        if total_latent_sections > 4:
            logger.info("Using F1-style latent padding heuristic for > 4 sections.")
            latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
        if args.preview_latent_every and not args.full_preview:
            previewer = LatentPreviewer(args, vae, None, gen_settings.device, compute_dtype, model_type="framepack")
        else:
            previewer = None
        for section_index_reverse, latent_padding in enumerate(latent_paddings):
            section_index = total_latent_sections - 1 - section_index_reverse
            section_index_from_last = -(section_index_reverse + 1)
            logger.info(f"--- Standard Section {section_index + 1} / {total_latent_sections} (Reverse Index {section_index_reverse}, Padding {latent_padding}) ---")
            is_last_section = latent_padding == 0
            latent_padding_size = latent_padding * latent_window_size
            # Resolve per-section conditioning image: negative index first, then positive, else 0
            apply_section_image = False
            if section_index_from_last in context_img:
                image_index = section_index_from_last
                if not is_last_section: apply_section_image = True
            elif section_index in context_img:
                image_index = section_index
                if not is_last_section: apply_section_image = True
            else:
                image_index = 0
            start_latent_section = context_img[image_index]["start_latent"].to(device, dtype=torch.float32)
            if apply_section_image:
                latent_padding_size = 0
                logger.info(f"Applying experimental section image, forcing latent_padding_size = 0")
            # Index layout: [pre(1), blank padding, new latents(window), post(1), 2x(2), 4x(16)]
            split_sizes_std = [1, latent_padding_size, latent_window_size, 1, 2, 16]
            indices_std = torch.arange(0, sum(split_sizes_std)).unsqueeze(0).to(device)
            (
                clean_latent_indices_pre, blank_indices, latent_indices,
                clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices,
            ) = indices_std.split(split_sizes_std, dim=1)
            clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
            current_history_std = history_latents[:, :, :19].to(device, dtype=torch.float32)
            clean_latents_post, clean_latents_2x, clean_latents_4x = current_history_std.split([1, 2, 16], dim=2)
            clean_latents = torch.cat([start_latent_section, clean_latents_post], dim=2)
            # Resolve per-section prompt, same precedence as the image lookup
            if section_index_from_last in context: prompt_index = section_index_from_last
            elif section_index in context: prompt_index = section_index
            else: prompt_index = 0
            context_for_index = context[prompt_index]
            logger.info(f"Using prompt from section {prompt_index}: '{context_for_index['prompt'][:100]}...'")
            llama_vec = context_for_index["llama_vec"].to(device, dtype=compute_dtype)
            llama_attention_mask = context_for_index["llama_attention_mask"].to(device)
            clip_l_pooler = context_for_index["clip_l_pooler"].to(device, dtype=compute_dtype)
            image_encoder_last_hidden_state = context_img[image_index]["image_encoder_last_hidden_state"].to(device, dtype=compute_dtype)
            llama_vec_n = context_null["llama_vec"].to(device, dtype=compute_dtype)
            llama_attention_mask_n = context_null["llama_attention_mask"].to(device)
            clip_l_pooler_n = context_null["clip_l_pooler"].to(device, dtype=compute_dtype)
            sampler_to_use = args.sample_solver
            guidance_scale_to_use = args.guidance_scale
            embedded_cfg_scale_to_use = args.embedded_cfg_scale
            guidance_rescale_to_use = args.guidance_rescale
            # generated_latents_step is on GPU after sample_hunyuan
            # (f1_frames_per_section == latent_window_size * 4 - 3 also applies in standard mode)
            generated_latents_step_gpu = sample_hunyuan(
                transformer=model, sampler=sampler_to_use, width=width, height=height,
                frames=f1_frames_per_section, real_guidance_scale=guidance_scale_to_use,
                distilled_guidance_scale=embedded_cfg_scale_to_use, guidance_rescale=guidance_rescale_to_use,
                num_inference_steps=args.infer_steps, generator=seed_g,
                prompt_embeds=llama_vec, prompt_embeds_mask=llama_attention_mask, prompt_poolers=clip_l_pooler,
                negative_prompt_embeds=llama_vec_n, negative_prompt_embeds_mask=llama_attention_mask_n, negative_prompt_poolers=clip_l_pooler_n,
                device=device, dtype=compute_dtype, image_embeddings=image_encoder_last_hidden_state,
                latent_indices=latent_indices, clean_latents=clean_latents, clean_latent_indices=clean_latent_indices,
                clean_latents_2x=clean_latents_2x, clean_latent_2x_indices=clean_latent_2x_indices,
                clean_latents_4x=clean_latents_4x, clean_latent_4x_indices=clean_latent_4x_indices,
            )
            # Move to CPU for history accumulation and potential preview decode
            generated_latents_step = generated_latents_step_gpu.cpu().float()
            if is_last_section:  # This is the first iteration in reverse, corresponds to earliest part of generated video
                logger.info("Standard Mode: Last section (first in reverse loop), prepending start_latent_section for this chunk.")
                generated_latents_step = torch.cat([start_latent_section.cpu().float(), generated_latents_step], dim=2)
            current_step_latents_cpu = generated_latents_step.clone()  # This is what was generated/prepended in this step
            total_generated_latent_frames += int(generated_latents_step.shape[2])
            history_latents = torch.cat([generated_latents_step, history_latents], dim=2)  # Prepend to full latent history
            real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]
            if args.preview_latent_every is not None and (section_index_reverse + 1) % args.preview_latent_every == 0:
                if args.full_preview:
                    logger.info(f"Saving full preview at standard section {section_index + 1} (Reverse Index {section_index_reverse})")
                    if vae is None:
                        logger.error("VAE not available for full standard preview.")
                    else:
                        preview_filename_full_std = os.path.join(args.save_path, f"latent_preview_{args.preview_suffix if args.preview_suffix else section_index_reverse + 1}.mp4")
                        latents_this_step_for_decode = current_step_latents_cpu.to(device, dtype=vae.dtype if hasattr(vae, 'dtype') else torch.float16)
                        vae.to(device)
                        pixels_this_step_decoded_cpu = hunyuan.vae_decode(latents_this_step_for_decode, vae).cpu()
                        vae.to("cpu")
                        if pixels_this_step_decoded_cpu.ndim == 4:
                            pixels_this_step_decoded_cpu = pixels_this_step_decoded_cpu.unsqueeze(0)
                        if history_pixels_for_preview_std_cpu is None:
                            history_pixels_for_preview_std_cpu = pixels_this_step_decoded_cpu
                        else:
                            overlap_pixels = args.latent_window_size * 4 - 3
                            # Standard mode prepends, so new pixels are first arg for soft_append
                            history_pixels_for_preview_std_cpu = soft_append_bcthw(
                                pixels_this_step_decoded_cpu,
                                history_pixels_for_preview_std_cpu,
                                overlap=overlap_pixels
                            )
                        save_bcthw_as_mp4(history_pixels_for_preview_std_cpu, preview_filename_full_std, fps=args.fps, crf=getattr(args, 'mp4_crf', 16))
                        logger.info(f"Full standard preview saved to {preview_filename_full_std}")
                        del latents_this_step_for_decode, pixels_this_step_decoded_cpu
                        clean_memory_on_device(device)
                elif previewer is not None:
                    logger.info(f"Previewing latents at standard section {section_index + 1} (Reverse Index {section_index_reverse})")
                    preview_latents_std_for_pv = real_history_latents.to(gen_settings.device)
                    previewer.preview(preview_latents_std_for_pv, section_index, preview_suffix=args.preview_suffix)
                    del preview_latents_std_for_pv
                    clean_memory_on_device(gen_settings.device)
            logger.info(f"Section {section_index + 1} finished. Total latent frames: {total_generated_latent_frames}. History shape: {history_latents.shape}")
            # Free per-section GPU tensors before the next iteration
            del generated_latents_step, current_history_std, clean_latents, clean_latents_post, clean_latents_2x, clean_latents_4x
            del llama_vec, llama_attention_mask, clip_l_pooler, image_encoder_last_hidden_state, start_latent_section
            del llama_vec_n, llama_attention_mask_n, clip_l_pooler_n
            # Explicitly delete the GPU tensor if it was created
            if 'generated_latents_step_gpu' in locals(): del generated_latents_step_gpu
            clean_memory_on_device(device)
    gc.collect()
    clean_memory_on_device(device)
    # Return the final generated latents (CPU tensor) and the VAE
    # The shape should be (B, C, T_total, H, W)
    logger.info(f"Generation complete. Final latent shape: {real_history_latents.shape}")
    return vae, real_history_latents  # Return VAE along with latents
def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int, original_base_name: Optional[str] = None) -> str:  # Add original_base_name
    """Persist a latent tensor to a safetensors file.

    Args:
        latent: Latent tensor (CTHW expected)
        args: command line arguments
        height: height of frame
        width: width of frame
        original_base_name: Optional base name from loaded file

    Returns:
        str: Path to saved latent file
    """
    out_dir = args.save_path
    os.makedirs(out_dir, exist_ok=True)
    timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
    name_suffix = f"_{original_base_name}" if original_base_name is not None else ""
    # File name carries timestamp, seed and (if any) the original base name
    latent_path = f"{out_dir}/{timestamp}_{args.seed}{name_suffix}_latent.safetensors"
    # safetensors requires CPU tensors; detach in case gradients are still tracked
    latent = latent.detach().cpu()
    metadata = None
    if not args.no_metadata:
        # All values must be strings for safetensors metadata
        metadata = {
            "seeds": str(args.seed),
            "prompt": str(args.prompt),
            "height": str(height),
            "width": str(width),
            "video_seconds": str(args.video_seconds),
            "infer_steps": str(args.infer_steps),
            "guidance_scale": str(args.guidance_scale),
            "latent_window_size": str(args.latent_window_size),
            "embedded_cfg_scale": str(args.embedded_cfg_scale),
            "guidance_rescale": str(args.guidance_rescale),
            "sample_solver": str(args.sample_solver),
            "fps": str(args.fps),
            "is_f1": str(args.is_f1),  # F1 flag for later inspection
        }
        if args.negative_prompt is not None:
            metadata["negative_prompt"] = str(args.negative_prompt)
    save_file({"latent": latent.contiguous()}, latent_path, metadata=metadata)
    logger.info(f"Latent saved to: {latent_path}")
    return latent_path
def save_video(
    video: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None, latent_frames: Optional[int] = None
) -> str:
    """Encode a video tensor to an .mp4 file under args.save_path.

    Args:
        video: Video tensor (a batch dimension is added before encoding)
        args: command line arguments
        original_base_name: Original base name (if latents are loaded from files)
        latent_frames: Optional latent frame count embedded into the filename

    Returns:
        str: Path to saved video file
    """
    out_dir = args.save_path
    os.makedirs(out_dir, exist_ok=True)

    time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
    name_part = "" if original_base_name is None else f"_{original_base_name}"
    frames_part = "" if latent_frames is None else f"_{latent_frames}"
    video_path = f"{out_dir}/{time_flag}_{args.seed}{name_part}{frames_part}.mp4"

    batched = video.unsqueeze(0)
    # Use the advanced writer only when an explicit codec was requested.
    if args.codec is not None:
        save_videos_grid_advanced(batched, video_path, args.codec, args.container, rescale=True, fps=args.fps, keep_frames=args.keep_pngs)
    else:
        save_videos_grid(batched, video_path, fps=args.fps, rescale=True)

    logger.info(f"Video saved to: {video_path}")
    return video_path
def save_images(sample: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str:
    """Dump a video tensor as a grid of image frames in args.save_path.

    Args:
        sample: Video tensor
        args: command line arguments
        original_base_name: Original base name (if latents are loaded from files)

    Returns:
        str: Path to saved images directory
    """
    out_dir = args.save_path
    os.makedirs(out_dir, exist_ok=True)

    time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
    name_part = "" if original_base_name is None else f"_{original_base_name}"
    image_name = f"{time_flag}_{args.seed}{name_part}"

    save_images_grid(sample.unsqueeze(0), out_dir, image_name, rescale=True)
    logger.info(f"Sample images saved to: {out_dir}/{image_name}")
    return f"{out_dir}/{image_name}"
| # In fpack_generate_video.py | |
| def save_output( | |
| args: argparse.Namespace, | |
| vae: AutoencoderKLCausal3D, | |
| latent: torch.Tensor, | |
| device: torch.device, | |
| original_base_names: Optional[List[str]] = None, | |
| ) -> None: | |
| """save output | |
| Args: | |
| args: command line arguments | |
| vae: VAE model | |
| latent: latent tensor (should be BCTHW or CTHW) | |
| device: device to use | |
| original_base_names: original base names (if latents are loaded from files) | |
| """ | |
| if latent.ndim == 4: # Add batch dim if missing (CTHW -> BCTHW) | |
| latent = latent.unsqueeze(0) | |
| elif latent.ndim != 5: | |
| raise ValueError(f"Unexpected latent dimensions: {latent.ndim}. Expected 4 or 5.") | |
| # Latent shape is BCTHW | |
| batch_size, channels, latent_frames, latent_height, latent_width = latent.shape | |
| height = latent_height * 8 | |
| width = latent_width * 8 | |
| logger.info(f"Saving output. Latent shape: {latent.shape}; Target pixel shape: {height}x{width}") | |
| if args.output_type == "latent" or args.output_type == "both": | |
| # save latent (use first name if multiple originals) | |
| base_name = original_base_names[0] if original_base_names else None | |
| save_latent(latent[0], args, height, width, original_base_name=base_name) # Save first batch item if B > 1 | |
| if args.output_type == "latent": | |
| return | |
| if args.video_sections is not None: | |
| total_latent_sections = args.video_sections | |
| else: | |
| total_latent_sections = (args.video_seconds * args.fps) / (args.latent_window_size * 4) | |
| total_latent_sections = int(max(round(total_latent_sections), 1)) | |
| logger.info(f"Decoding using total_latent_sections = {total_latent_sections} (derived from {'--video_sections' if args.video_sections is not None else '--video_seconds'}).") | |
| # Decode (handle potential batch > 1?) | |
| # decode_latent expects BCTHW or CTHW, and returns CTHW | |
| # Currently process only the first item in the batch for saving video/images | |
| video = decode_latent(args.latent_window_size, total_latent_sections, args.bulk_decode, vae, latent[0], device) | |
| if args.output_type == "video" or args.output_type == "both": | |
| # save video | |
| original_name = original_base_names[0] if original_base_names else None | |
| save_video(video, args, original_name, latent_frames=latent_frames) # Pass latent frames count | |
| elif args.output_type == "images": | |
| # save images | |
| original_name = original_base_names[0] if original_base_names else None | |
| save_images(video, args, original_name) | |
def preprocess_prompts_for_batch(prompt_lines: List[str], base_args: argparse.Namespace) -> List[Dict]:
    """Parse prompt-file lines into per-prompt override dictionaries for batch mode.

    Args:
        prompt_lines: List of prompt lines
        base_args: Base command line arguments (currently unread; kept for API stability)

    Returns:
        List[Dict]: List of prompt data dictionaries
    """
    parsed: List[Dict] = []
    for raw_line in prompt_lines:
        stripped = raw_line.strip()
        # Blank lines and '#' comment lines carry no prompt data.
        if not stripped or stripped.startswith("#"):
            continue
        prompt_data = parse_prompt_line(stripped)
        logger.info(f"Parsed prompt data: {prompt_data}")
        parsed.append(prompt_data)
    return parsed
def get_generation_settings(args: argparse.Namespace) -> GenerationSettings:
    """Build a GenerationSettings object from parsed CLI arguments.

    Args:
        args: command line arguments (reads ``device``, ``fp8_scaled``, ``fp8``)

    Returns:
        GenerationSettings: target device and DiT weight dtype for generation
    """
    device = torch.device(args.device)

    # fp8_scaled checkpoints carry mixed-precision weights, so no single target
    # dtype applies; plain fp8 casts weights to float8_e4m3fn; otherwise leave
    # the dtype unset and keep the checkpoint's native precision.
    dit_weight_dtype = None
    if args.fp8 and not args.fp8_scaled:
        dit_weight_dtype = torch.float8_e4m3fn

    # Fixed duplicated word in the log message ("weight weight" -> "weight").
    logger.info(f"Using device: {device}, DiT weight precision: {dit_weight_dtype}")
    return GenerationSettings(device=device, dit_weight_dtype=dit_weight_dtype)
| # In fpack_generate_video.py | |
def main():
    """CLI entry point.

    Dispatches among four modes based on the parsed arguments:
    latent-decode (``--latent_path``), batch-from-file (``--from_file``,
    not implemented), interactive (``--interactive``, not implemented),
    and the default single-prompt generation path. Mutates ``args`` in
    place (device, and in latent mode: seed/size/seconds/fps/is_f1 read
    back from saved metadata).
    """
    # Parse arguments
    args = parse_args()
    # Check if latents are provided
    latents_mode = args.latent_path is not None and len(args.latent_path) > 0
    # Set device: explicit flag wins, otherwise CUDA when available
    device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    logger.info(f"Using device: {device}")
    args.device = device  # Ensure args has the final device
    if latents_mode:
        # --- Latent Decode Mode ---
        # (Keep existing logic, but maybe add F1 flag reading from metadata?)
        original_base_names = []
        latents_list = []
        seeds = []  # NOTE(review): collected but not used after the loop
        is_f1_from_metadata = False  # Default
        # Allow only one latent file for simplicity now
        if len(args.latent_path) > 1:
            logger.warning("Loading multiple latents is not fully supported for metadata consistency. Using first latent's metadata.")
        for i, latent_path in enumerate(args.latent_path):
            logger.info(f"Loading latent from: {latent_path}")
            base_name = os.path.splitext(os.path.basename(latent_path))[0]
            original_base_names.append(base_name)
            seed = 0  # Default seed
            if not latent_path.lower().endswith(".safetensors"):
                # Non-safetensors path: torch.load, no metadata available
                logger.warning(f"Loading from non-safetensors file {latent_path}. Metadata might be missing.")
                latents = torch.load(latent_path, map_location="cpu")
                if isinstance(latents, dict) and "latent" in latents:  # Handle potential dict structure
                    latents = latents["latent"]
            else:
                try:
                    # Load latent tensor
                    loaded_data = load_file(latent_path, device="cpu")  # Load to CPU
                    latents = loaded_data["latent"]
                    # Load metadata
                    metadata = {}
                    with safe_open(latent_path, framework="pt", device="cpu") as f:
                        metadata = f.metadata()
                    if metadata is None:
                        metadata = {}
                    logger.info(f"Loaded metadata: {metadata}")
                    # Apply metadata only from the first file for consistency;
                    # each field is parsed defensively and skipped on failure.
                    if i == 0:
                        if "seeds" in metadata:
                            try:
                                seed = int(metadata["seeds"])
                            except ValueError:
                                logger.warning(f"Could not parse seed from metadata: {metadata['seeds']}")
                        if "height" in metadata and "width" in metadata:
                            try:
                                height = int(metadata["height"])
                                width = int(metadata["width"])
                                args.video_size = [height, width]
                                logger.info(f"Set video size from metadata: {height}x{width}")
                            except ValueError:
                                logger.warning(f"Could not parse height/width from metadata.")
                        if "video_seconds" in metadata:
                            try:
                                args.video_seconds = float(metadata["video_seconds"])
                                logger.info(f"Set video seconds from metadata: {args.video_seconds}")
                            except ValueError:
                                logger.warning(f"Could not parse video_seconds from metadata.")
                        if "fps" in metadata:
                            try:
                                args.fps = int(metadata["fps"])
                                logger.info(f"Set fps from metadata: {args.fps}")
                            except ValueError:
                                logger.warning(f"Could not parse fps from metadata.")
                        if "is_f1" in metadata:
                            # Metadata wins over the CLI flag so the decode
                            # matches how the latent was generated.
                            is_f1_from_metadata = metadata["is_f1"].lower() == 'true'
                            if args.is_f1 != is_f1_from_metadata:
                                logger.warning(f"Metadata indicates is_f1={is_f1_from_metadata}, overriding command line argument --is_f1={args.is_f1}")
                                args.is_f1 = is_f1_from_metadata
                except Exception as e:
                    logger.error(f"Error loading safetensors file {latent_path}: {e}")
                    continue  # Skip this file
            # Use seed from first file for all if multiple latents are somehow processed
            if i == 0:
                args.seed = seed
            seeds.append(seed)  # Store all seeds read
            logger.info(f"Loaded latent shape: {latents.shape}")
            if latents.ndim == 5:  # [BCTHW]
                if latents.shape[0] > 1:
                    logger.warning("Latent file contains batch size > 1. Using only the first item.")
                latents = latents[0]  # Use first item -> [CTHW]
            elif latents.ndim != 4:
                logger.error(f"Unexpected latent dimension {latents.ndim} in {latent_path}. Skipping.")
                continue
            latents_list.append(latents)
        if not latents_list:
            logger.error("No valid latents loaded. Exiting.")
            return
        # Stack latents into a batch if multiple were loaded (BCTHW)
        # Note: Saving output currently only processes the first batch item.
        latent_batch = torch.stack(latents_list, dim=0)
        # Load VAE needed for decoding
        vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device)
        # Call save_output with the batch
        save_output(args, vae, latent_batch, device, original_base_names)
    elif args.from_file:
        # Batch mode from file (Not Implemented)
        logger.error("Batch mode (--from_file) is not implemented yet.")
        # with open(args.from_file, "r", encoding="utf-8") as f:
        #     prompt_lines = f.readlines()
        # prompts_data = preprocess_prompts_for_batch(prompt_lines, args)
        # process_batch_prompts(prompts_data, args)  # Needs implementation
        raise NotImplementedError("Batch mode is not implemented yet.")
    elif args.interactive:
        # Interactive mode (Not Implemented)
        logger.error("Interactive mode (--interactive) is not implemented yet.")
        # process_interactive(args)  # Needs implementation
        raise NotImplementedError("Interactive mode is not implemented yet.")
    else:
        # --- Single prompt mode (original behavior + F1 support) ---
        gen_settings = get_generation_settings(args)
        # Generate returns (vae, latent)
        vae, latent = generate(args, gen_settings)  # VAE might be loaded inside generate
        if latent is None:  # Handle cases like --save_merged_model
            logger.info("Generation did not produce latents (e.g., --save_merged_model used). Exiting.")
            return
        # Ensure VAE is available (it should be returned by generate)
        if vae is None:
            logger.error("VAE not available after generation. Cannot save output.")
            return
        # Save output expects BCTHW or CTHW, generate returns BCTHW
        # save_output handles the batch dimension internally now.
        save_output(args, vae, latent, device)
    # Clean up VAE if it was loaded here
    del vae
    gc.collect()
    clean_memory_on_device(device)
    logger.info("Done!")
# Run the CLI entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()