| | import argparse |
| | import logging |
| | import math |
| | import os |
| | from typing import List, Optional |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from tqdm import tqdm |
| | from transformers import SiglipImageProcessor, SiglipVisionModel |
| | from PIL import Image |
| |
|
| | from dataset import config_utils |
| | from dataset.config_utils import BlueprintGenerator, ConfigSanitizer |
| | from dataset.image_video_dataset import BaseDataset, ItemInfo, save_latent_cache_framepack, ARCHITECTURE_FRAMEPACK |
| | from frame_pack import hunyuan |
| | from frame_pack.framepack_utils import load_image_encoders, load_vae |
| | from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D |
| | from frame_pack.clip_vision import hf_clip_vision_encode |
| | import cache_latents |
| |
|
# Module-level logger; basicConfig makes INFO-level messages visible when run as a script.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
| |
|
| |
|
def encode_and_save_batch(
    vae: AutoencoderKLCausal3D,
    feature_extractor: SiglipImageProcessor,
    image_encoder: SiglipVisionModel,
    batch: List[ItemInfo],
    vanilla_sampling: bool = False,
    one_frame: bool = False,
    one_frame_no_2x: bool = False,
    one_frame_no_4x: bool = False,
):
    """Encode a batch of original RGB videos and save FramePack section caches.

    Each item's video is VAE-encoded once, then sliced into per-section training
    caches (one cache file per section, suffixed with the section index).

    Args:
        vae: Causal 3D VAE used to encode pixel frames into latents.
        feature_extractor / image_encoder: SigLIP pair used to embed the first frame.
        batch: Items to encode; all items are assumed to share frame_count and
            fp_latent_window_size (only batch[0]'s values are consulted).
        vanilla_sampling: If True, use F1 (autoregressive, forward-order) section
            layout; otherwise use Inverted anti-drifting (reverse-order) layout.
        one_frame: Delegate to the one-frame cache writer and return.
        one_frame_no_2x / one_frame_no_4x: Passed through to the one-frame writer.

    Raises:
        ValueError: If the image is smaller than 8px per side or the video is too
            short to fill a single latent section.
    """
    if one_frame:
        encode_and_save_batch_one_frame(
            vae, feature_extractor, image_encoder, batch, vanilla_sampling, one_frame_no_2x, one_frame_no_4x
        )
        return

    latent_window_size = batch[0].fp_latent_window_size

    # Stack per-item frame arrays; a 4D result means single images, so add a frame axis.
    contents = torch.stack([torch.from_numpy(item.content) for item in batch])
    if len(contents.shape) == 4:
        contents = contents.unsqueeze(1)

    # (B, F, H, W, C) -> (B, C, F, H, W), then normalize uint8 [0,255] to [-1,1] for the VAE.
    contents = contents.permute(0, 4, 1, 2, 3).contiguous()
    contents = contents.to(vae.device, dtype=vae.dtype)
    contents = contents / 127.5 - 1.0

    height, width = contents.shape[3], contents.shape[4]
    if height < 8 or width < 8:
        item = batch[0]
        raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")

    # Number of latent frames: 4x temporal compression plus the initial frame.
    latent_f = (batch[0].frame_count - 1) // 4 + 1

    total_latent_sections = math.floor((latent_f - 1) / latent_window_size)
    if total_latent_sections < 1:
        min_frames_needed = latent_window_size * 4 + 1
        raise ValueError(
            f"Not enough frames for FramePack: {batch[0].frame_count} frames ({latent_f} latent frames), minimum required: {min_frames_needed} frames ({latent_window_size+1} latent frames)"
        )

    # NOTE(review): one_frame is always False here (handled by the early return above),
    # so the `else 1` branch is dead — kept for safety.
    latent_f_aligned = total_latent_sections * latent_window_size + 1 if not one_frame else 1

    # Trim the pixel video so it maps exactly onto whole latent sections.
    frame_count_aligned = (latent_f_aligned - 1) * 4 + 1
    if frame_count_aligned != batch[0].frame_count:
        logger.info(
            f"Frame count mismatch: required={frame_count_aligned} != actual={batch[0].frame_count}, trimming to {frame_count_aligned}"
        )
        contents = contents[:, :, :frame_count_aligned, :, :]

    latent_f = latent_f_aligned

    # VAE-encode the whole batch at once; keep latents on CPU for slicing/saving.
    latents = hunyuan.vae_encode(contents, vae)
    latents = latents.to("cpu")

    # First frame of each item is embedded with SigLIP as the image conditioning.
    images = np.stack([item.content[0] for item in batch], axis=0)

    image_embeddings = []
    with torch.no_grad():
        for image in images:
            image_encoder_output = hf_clip_vision_encode(image, feature_extractor, image_encoder)
            image_embeddings.append(image_encoder_output.last_hidden_state)
    image_embeddings = torch.cat(image_embeddings, dim=0)
    image_embeddings = image_embeddings.to("cpu")

    if not vanilla_sampling:
        # Inverted anti-drifting: sections are generated back-to-front, each padded
        # by a decreasing number of blank latent windows.
        latent_paddings = list(reversed(range(total_latent_sections)))

        # For long videos, cap paddings to the pattern used at inference time.
        if total_latent_sections > 4:
            latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]

        for b, item in enumerate(batch):
            original_latent_cache_path = item.latent_cache_path
            video_lat = latents[b : b + 1]

            # Rolling history of "already generated" latents, initialized to zeros;
            # layout along dim 2 is [post(1), 2x(2), 4x(16)] = 19 frames.
            history_latents = torch.zeros(
                (1, video_lat.shape[1], 1 + 2 + 16, video_lat.shape[3], video_lat.shape[4]), dtype=video_lat.dtype
            )

            # Walk sections from the end of the video toward the start.
            latent_f_index = latent_f - latent_window_size
            section_index = total_latent_sections - 1

            for latent_padding in latent_paddings:
                is_last_section = section_index == 0
                latent_padding_size = latent_padding * latent_window_size
                if is_last_section:
                    assert latent_f_index == 1, "Last section should be starting from frame 1"

                # Build the flat index layout: [pre(1), blank(pad), target(window), post(1), 2x(2), 4x(16)].
                indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
                (
                    clean_latent_indices_pre,
                    blank_indices,
                    latent_indices,
                    clean_latent_indices_post,
                    clean_latent_2x_indices,
                    clean_latent_4x_indices,
                ) = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)

                clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)

                # Clean conditioning: the very first latent frame plus the head of the history.
                clean_latents_pre = video_lat[:, :, 0:1, :, :]
                clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, : 1 + 2 + 16, :, :].split(
                    [1, 2, 16], dim=2
                )
                clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)

                # Ground-truth latents for this section.
                target_latents = video_lat[:, :, latent_f_index : latent_f_index + latent_window_size, :, :]

                # One cache file per section, path suffixed with the section index.
                item.latent_cache_path = append_section_idx_to_latent_cache_path(original_latent_cache_path, section_index)
                save_latent_cache_framepack(
                    item_info=item,
                    latent=target_latents.squeeze(0),
                    latent_indices=latent_indices.squeeze(0),
                    clean_latents=clean_latents.squeeze(0),
                    clean_latent_indices=clean_latent_indices.squeeze(0),
                    clean_latents_2x=clean_latents_2x.squeeze(0),
                    clean_latent_2x_indices=clean_latent_2x_indices.squeeze(0),
                    clean_latents_4x=clean_latents_4x.squeeze(0),
                    clean_latent_4x_indices=clean_latent_4x_indices.squeeze(0),
                    image_embeddings=image_embeddings[b],
                )

                if is_last_section:
                    # The final (earliest) section also contributes the start frame to history.
                    generated_latents_for_history = video_lat[:, :, : latent_window_size + 1, :, :]
                else:
                    generated_latents_for_history = target_latents

                # Prepend this section's latents, emulating inference-time history growth.
                history_latents = torch.cat([generated_latents_for_history, history_latents], dim=2)

                section_index -= 1
                latent_f_index -= latent_window_size

    else:
        # F1 (vanilla, autoregressive) sampling: sections run forward in time.
        for b, item in enumerate(batch):
            original_latent_cache_path = item.latent_cache_path
            video_lat = latents[b : b + 1]
            img_emb = image_embeddings[b]

            for section_index in range(total_latent_sections):
                # Target window starts after the initial latent frame.
                target_start_f = section_index * latent_window_size + 1
                target_end_f = target_start_f + latent_window_size
                target_latents = video_lat[:, :, target_start_f:target_end_f, :, :]
                start_latent = video_lat[:, :, 0:1, :, :]

                # History holds the 19 latent frames preceding the target window,
                # laid out as [4x(16), 2x(2), 1x(1)]; zero-padded near the video start.
                clean_latents_total_count = 1 + 2 + 16
                history_latents = torch.zeros(
                    size=(1, 16, clean_latents_total_count, video_lat.shape[-2], video_lat.shape[-1]),
                    device=video_lat.device,
                    dtype=video_lat.dtype,
                )

                # Copy as many real preceding latents as exist; the rest stay zero.
                history_start_f = 0
                video_start_f = target_start_f - clean_latents_total_count
                copy_count = clean_latents_total_count
                if video_start_f < 0:
                    history_start_f = -video_start_f
                    copy_count = clean_latents_total_count - history_start_f
                    video_start_f = 0
                if copy_count > 0:
                    history_latents[:, :, history_start_f:] = video_lat[:, :, video_start_f : video_start_f + copy_count, :, :]

                # Flat index layout for F1: [start(1), 4x(16), 2x(2), 1x(1), target(window)].
                indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
                (
                    clean_latent_indices_start,
                    clean_latent_4x_indices,
                    clean_latent_2x_indices,
                    clean_latent_1x_indices,
                    latent_indices,
                ) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
                clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)

                clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents.split([16, 2, 1], dim=2)
                clean_latents = torch.cat([start_latent, clean_latents_1x], dim=2)

                # One cache file per section, path suffixed with the section index.
                item.latent_cache_path = append_section_idx_to_latent_cache_path(original_latent_cache_path, section_index)
                save_latent_cache_framepack(
                    item_info=item,
                    latent=target_latents.squeeze(0),
                    latent_indices=latent_indices.squeeze(0),
                    clean_latents=clean_latents.squeeze(0),
                    clean_latent_indices=clean_latent_indices.squeeze(0),
                    clean_latents_2x=clean_latents_2x.squeeze(0),
                    clean_latent_2x_indices=clean_latent_2x_indices.squeeze(0),
                    clean_latents_4x=clean_latents_4x.squeeze(0),
                    clean_latent_4x_indices=clean_latent_4x_indices.squeeze(0),
                    image_embeddings=img_emb,
                )
| |
|
| |
|
def encode_and_save_batch_one_frame(
    vae: AutoencoderKLCausal3D,
    feature_extractor: SiglipImageProcessor,
    image_encoder: SiglipVisionModel,
    batch: List[ItemInfo],
    vanilla_sampling: bool = False,
    one_frame_no_2x: bool = False,
    one_frame_no_4x: bool = False,
):
    """Encode a batch for one-frame training and save a single FramePack cache per item.

    Each item contributes its control images plus the target image; RGBA inputs have
    their alpha channel converted into a latent-space mask that attenuates the
    corresponding latent frame.

    Args:
        vae: Causal 3D VAE used to encode pixel frames into latents.
        feature_extractor / image_encoder: SigLIP pair used to embed the first control image.
        batch: Items to encode. Per-item fields fp_1f_clean_indices, fp_1f_target_index,
            fp_1f_no_post and fp_latent_window_size drive the index layout.
        vanilla_sampling: Unused here; kept for signature parity with encode_and_save_batch.
        one_frame_no_2x: If True, omit clean_latent_2x_indices from the cache.
        one_frame_no_4x: If True, omit clean_latent_4x_indices from the cache.

    Raises:
        ValueError: If the image is smaller than 8px per side.
    """
    # Gather per-item frames (control images + target image) and optional alpha masks.
    contents = []
    content_masks: list[list[Optional[torch.Tensor]]] = []
    for item in batch:
        item_contents = item.control_content + [item.content]

        item_masks = []
        for i, c in enumerate(item_contents):
            if c.shape[-1] == 4:
                # RGBA: keep RGB for the VAE, turn alpha into a latent-resolution mask.
                item_contents[i] = c[..., :3]

                alpha = c[..., 3]
                mask_image = Image.fromarray(alpha, mode="L")
                width, height = mask_image.size
                # Downscale by 8 to match the VAE's spatial compression.
                mask_image = mask_image.resize((width // 8, height // 8), Image.LANCZOS)
                mask_image = np.array(mask_image)
                mask_image = torch.from_numpy(mask_image).float() / 255.0
                mask_image = mask_image.squeeze(-1)
                # Add batch/channel/frame axes so the mask broadcasts over latents.
                mask_image = mask_image.unsqueeze(0).unsqueeze(0).unsqueeze(0)
                mask_image = mask_image.to(torch.float32)
                content_mask = mask_image
            else:
                content_mask = None

            item_masks.append(content_mask)

        item_contents = [torch.from_numpy(c) for c in item_contents]
        contents.append(torch.stack(item_contents, dim=0))
        content_masks.append(item_masks)

    contents = torch.stack(contents, dim=0)

    # (B, F, H, W, C) -> (B, C, F, H, W), then normalize uint8 [0,255] to [-1,1].
    contents = contents.permute(0, 4, 1, 2, 3).contiguous()
    contents = contents.to(vae.device, dtype=vae.dtype)
    contents = contents / 127.5 - 1.0

    height, width = contents.shape[-2], contents.shape[-1]
    if height < 8 or width < 8:
        item = batch[0]
        raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")

    # Encode each frame independently (one-frame training treats frames as separate latents).
    latents = [hunyuan.vae_encode(contents[:, :, idx : idx + 1], vae).to("cpu") for idx in range(contents.shape[2])]
    latents = torch.cat(latents, dim=2)

    # Apply alpha-derived masks in latent space.
    for b, item in enumerate(batch):
        for i, content_mask in enumerate(content_masks[b]):
            if content_mask is not None:
                latents[b : b + 1, :, i : i + 1] *= content_mask

    # First control image is the conditioning image for SigLIP.
    images = [item.control_content[0] for item in batch]
    # Fix: control images may carry an alpha channel; the image encoder expects RGB,
    # so strip alpha here (the VAE path above already does this for its inputs).
    images = [img[..., :3] if img.shape[-1] == 4 else img for img in images]

    image_embeddings = []
    with torch.no_grad():
        for image in images:
            image_encoder_output = hf_clip_vision_encode(image, feature_extractor, image_encoder)
            image_embeddings.append(image_encoder_output.last_hidden_state)
    image_embeddings = torch.cat(image_embeddings, dim=0)
    image_embeddings = image_embeddings.to("cpu")

    for b, item in enumerate(batch):
        # Per-item index layout for the clean (conditioning) latents.
        clean_latent_indices = item.fp_1f_clean_indices
        if clean_latent_indices is None or len(clean_latent_indices) == 0:
            logger.warning(
                f"Item {item.item_key} has no clean_latent_indices defined, using default indices for one frame training."
            )
            clean_latent_indices = [0]

        if not item.fp_1f_no_post:
            # Append the post-frame slot right after the latent window.
            clean_latent_indices = clean_latent_indices + [1 + item.fp_latent_window_size]
        clean_latent_indices = torch.Tensor(clean_latent_indices).long()

        latent_index = torch.Tensor([item.fp_1f_target_index]).long()

        # One-frame training never provides 2x/4x latents, only (optionally) their indices.
        clean_latents_2x = None
        clean_latents_4x = None

        if one_frame_no_2x:
            clean_latent_2x_indices = None
        else:
            index = 1 + item.fp_latent_window_size + 1
            clean_latent_2x_indices = torch.arange(index, index + 2)

        if one_frame_no_4x:
            clean_latent_4x_indices = None
        else:
            index = 1 + item.fp_latent_window_size + 1 + 2
            clean_latent_4x_indices = torch.arange(index, index + 16)

        # All but the last latent frame are conditioning; optionally pad a zero post-frame.
        clean_latents = latents[b, :, :-1]
        if not item.fp_1f_no_post:
            clean_latents = F.pad(clean_latents, (0, 0, 0, 0, 0, 1), value=0.0)

        # Last latent frame is the training target.
        target_latents = latents[b, :, -1:]

        # Use the module logger (not print) so verbosity follows logging configuration.
        logger.info(f"Saving cache for item {item.item_key} at {item.latent_cache_path}. no_post: {item.fp_1f_no_post}")
        logger.info(f" Clean latent indices: {clean_latent_indices}, latent index: {latent_index}")
        logger.info(f" Clean latents: {clean_latents.shape}, target latents: {target_latents.shape}")
        logger.info(f" Clean latents 2x indices: {clean_latent_2x_indices}, clean latents 4x indices: {clean_latent_4x_indices}")
        logger.info(
            f" Clean latents 2x: {clean_latents_2x.shape if clean_latents_2x is not None else 'None'}, "
            f"Clean latents 4x: {clean_latents_4x.shape if clean_latents_4x is not None else 'None'}"
        )
        logger.info(f" Image embeddings: {image_embeddings[b].shape}")

        save_latent_cache_framepack(
            item_info=item,
            latent=target_latents,
            latent_indices=latent_index,
            clean_latents=clean_latents,
            clean_latent_indices=clean_latent_indices,
            clean_latents_2x=clean_latents_2x,
            clean_latent_2x_indices=clean_latent_2x_indices,
            clean_latents_4x=clean_latents_4x,
            clean_latent_4x_indices=clean_latent_4x_indices,
            image_embeddings=image_embeddings[b],
        )
| |
|
| |
|
def framepack_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register FramePack-specific cache-generation options on *parser* and return it."""
    parser.add_argument(
        "--image_encoder",
        type=str,
        required=True,
        help="Image encoder (CLIP) checkpoint path or directory",
    )

    # Boolean flags share action="store_true"; register them from a (flag, help) table.
    boolean_flags = (
        (
            "--f1",
            "Generate cache for F1 model (vanilla (autoregressive) sampling) instead of Inverted anti-drifting (plain FramePack)",
        ),
        (
            "--one_frame",
            "Generate cache for one frame training (single frame, single section). latent_window_size is used as the index of the target frame.",
        ),
        (
            "--one_frame_no_2x",
            "Do not use clean_latents_2x and clean_latent_2x_indices for one frame training.",
        ),
        (
            "--one_frame_no_4x",
            "Do not use clean_latents_4x and clean_latent_4x_indices for one frame training.",
        ),
    )
    for flag, help_text in boolean_flags:
        parser.add_argument(flag, action="store_true", help=help_text)

    return parser
| |
|
| |
|
def main(args: argparse.Namespace):
    """Build datasets from config, load VAE and image encoder, and write FramePack latent caches."""
    # Resolve the compute device: an explicit --device wins, otherwise prefer CUDA.
    if hasattr(args, "device") and args.device:
        device = torch.device(args.device)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Turn the user's dataset config into a concrete dataset group.
    generator = BlueprintGenerator(ConfigSanitizer())
    logger.info(f"Load dataset config from {args.dataset_config}")
    user_config = config_utils.load_user_config(args.dataset_config)
    blueprint = generator.generate(user_config, args, architecture=ARCHITECTURE_FRAMEPACK)
    dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    dataset_list = dataset_group.datasets

    # Debug mode only visualizes the datasets; no caching is performed.
    if args.debug_mode is not None:
        cache_latents.show_datasets(
            dataset_list, args.debug_mode, args.console_width, args.console_back, args.console_num_images, fps=16
        )
        return

    assert args.vae is not None, "vae checkpoint is required"

    logger.info(f"Loading VAE model from {args.vae}")
    vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device=device)
    vae.to(device)

    logger.info(f"Loading image encoder from {args.image_encoder}")
    feature_extractor, image_encoder = load_image_encoders(args)
    image_encoder.eval()
    image_encoder.to(device)

    logger.info(f"Cache generation mode: {'Vanilla Sampling' if args.f1 else 'Inference Emulation'}")

    def encode_batch(batch: List[ItemInfo]):
        # Bind the loaded models and CLI flags into the per-batch encode callback.
        encode_and_save_batch(
            vae, feature_extractor, image_encoder, batch, args.f1, args.one_frame, args.one_frame_no_2x, args.one_frame_no_4x
        )

    encode_datasets_framepack(dataset_list, encode_batch, args)
| |
|
| |
|
def append_section_idx_to_latent_cache_path(latent_cache_path: str, section_idx: int) -> str:
    """Return *latent_cache_path* with a zero-padded section index appended to the
    third-from-last underscore-separated token (the frame-count token)."""
    pieces = latent_cache_path.split("_")
    suffix = format(section_idx, "04d")
    pieces[-3] = "-".join((pieces[-3], suffix))
    return "_".join(pieces)
| |
|
| |
|
def encode_datasets_framepack(datasets: list[BaseDataset], encode: callable, args: argparse.Namespace):
    """Iterate datasets, encode latent caches in batches, and prune stale cache files.

    For each item, the expected per-section cache paths are computed; when
    --skip_existing is set, items whose every section cache already exists are
    skipped. After encoding, cache files not belonging to the dataset are removed
    unless --keep_cache is set.

    Args:
        datasets: Datasets to process.
        encode: Callback taking a list of ItemInfo and writing their caches.
        args: Parsed CLI arguments (num_workers, skip_existing, batch_size, keep_cache).
    """
    num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1)
    # Fix: use a distinct loop variable for the dataset index — the original reused
    # `i` for the inner batching loop, shadowing the dataset index.
    for dataset_index, dataset in enumerate(datasets):
        logger.info(f"Encoding dataset [{dataset_index}]")
        all_latent_cache_paths = []
        for _, batch in tqdm(dataset.retrieve_latent_cache_batches(num_workers)):
            batch: list[ItemInfo] = batch

            # Collect every expected cache path and find items with missing caches.
            filtered_batch = []
            for item in batch:
                if item.frame_count is None:
                    # Single-image item: exactly one cache file (no section suffix).
                    all_latent_cache_paths.append(item.latent_cache_path)
                    all_existing = os.path.exists(item.latent_cache_path)
                else:
                    # Video item: one cache file per section.
                    latent_f = (item.frame_count - 1) // 4 + 1
                    num_sections = max(1, math.floor((latent_f - 1) / item.fp_latent_window_size))
                    all_existing = True
                    for sec in range(num_sections):
                        p = append_section_idx_to_latent_cache_path(item.latent_cache_path, sec)
                        all_latent_cache_paths.append(p)
                        all_existing = all_existing and os.path.exists(p)

                if not all_existing:
                    filtered_batch.append(item)

            if args.skip_existing:
                if len(filtered_batch) == 0:
                    logger.info(f"All sections exist for {batch[0].item_key}, skipping")
                    continue
                batch = filtered_batch

            # Encode in sub-batches of at most --batch_size items.
            bs = args.batch_size if args.batch_size is not None else len(batch)
            for start in range(0, len(batch), bs):
                encode(batch[start : start + bs])

        # Normalize for comparison against files found on disk.
        all_latent_cache_paths = [os.path.normpath(p) for p in all_latent_cache_paths]
        all_latent_cache_paths = set(all_latent_cache_paths)

        # Remove (or keep, per --keep_cache) cache files not produced by this dataset.
        all_cache_files = dataset.get_all_latent_cache_files()
        for cache_file in all_cache_files:
            if os.path.normpath(cache_file) not in all_latent_cache_paths:
                if args.keep_cache:
                    logger.info(f"Keep cache file not in the dataset: {cache_file}")
                else:
                    os.remove(cache_file)
                    logger.info(f"Removed old cache file: {cache_file}")
|
| |
|
if __name__ == "__main__":
    # Compose the CLI: common cache options, HunyuanVideo options, then FramePack options.
    parser = cache_latents.setup_parser_common()
    parser = cache_latents.hv_setup_parser(parser)
    parser = framepack_setup_parser(parser)

    args = parser.parse_args()

    # FramePack fixes the VAE dtype internally; reject an explicit override.
    if args.vae_dtype is not None:
        raise ValueError("VAE dtype is not supported in FramePack")

    main(args)
| |
|