| | import argparse |
| | import os |
| | import glob |
| | from typing import Optional, Union |
| |
|
| | import numpy as np |
| | import torch |
| | from tqdm import tqdm |
| |
|
| | from dataset import config_utils |
| | from dataset.config_utils import BlueprintGenerator, ConfigSanitizer |
| | from PIL import Image |
| |
|
| | import logging |
| |
|
| | from dataset.image_video_dataset import BaseDataset, ItemInfo, save_latent_cache, ARCHITECTURE_HUNYUAN_VIDEO |
| | from hunyuan_model.vae import load_vae |
| | from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D |
| | from utils.model_utils import str_to_dtype |
| |
|
# Configure root logging once at import time, then grab this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
| |
|
| |
|
def show_image(image: Union[list[Union[Image.Image, np.ndarray]], Union[Image.Image, np.ndarray]]) -> int:
    """Display a single image, or the first and last frames of a frame sequence, with OpenCV.

    Each displayed image is shown in a window named "image". The call blocks on
    ``cv2.waitKey(0)`` until a key is pressed, then closes the window.

    Args:
        image: A single image (a PIL image or a 3-D ``np.ndarray``), or a sequence of
            frames. For a sequence, only ``image[0]`` and ``image[-1]`` are shown. Input is
            treated as RGB (it is converted with ``COLOR_RGB2BGR`` before display).

    Returns:
        The key code from ``cv2.waitKey`` for the last window shown. Returns early with
        that code if the user pressed ``q`` or ``d``.
    """
    import cv2  # imported lazily: only needed in "image" debug mode

    is_single = (isinstance(image, np.ndarray) and image.ndim == 3) or isinstance(image, Image.Image)
    imgs = [image] if is_single else [image[0], image[-1]]
    if len(imgs) > 1:
        print(f"Number of images: {len(image)}")

    k = None
    for i, img in enumerate(imgs):
        # Convert PIL images before reading .shape: PIL images have no .shape attribute.
        arr = np.array(img) if isinstance(img, Image.Image) else img
        if len(imgs) > 1:
            print(f"{'First' if i == 0 else 'Last'} image: {arr.shape}")
        else:
            print(f"Image: {arr.shape}")
        cv2_img = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)  # OpenCV displays BGR channel order
        cv2.imshow("image", cv2_img)
        k = cv2.waitKey(0)
        cv2.destroyAllWindows()
        if k == ord("q") or k == ord("d"):
            return k
    return k
| |
|
| |
|
def show_console(
    image: Union[list[Union[Image.Image, np.ndarray]], Union[Image.Image, np.ndarray]],
    width: int,
    back: Optional[str],
    interactive: bool = False,
) -> int:
    """Render a single image, or the first and last frames of a sequence, as ASCII art.

    Args:
        image: A single image (a PIL image or a 3-D ``np.ndarray``) or a sequence of
            frames. For a sequence, only ``image[0]`` and ``image[-1]`` are rendered.
        width: Terminal column count passed to ``to_terminal``.
        back: Name of an ``ascii_magic.Back`` member (case-insensitive), or ``None`` for
            no background colour.
        interactive: If True, prompt for input after each rendered image.

    Returns:
        ``ord("q")`` or ``ord("d")`` if the user entered that key. In interactive mode,
        otherwise returns the code of the single character entered, or ``ord(" ")`` for
        empty or multi-character input. Always ``ord(" ")`` in non-interactive mode.
    """
    from ascii_magic import from_pillow_image, Back

    # Resolve the caller's colour name. Previously the parameter was overwritten with
    # None before being read, so --console_back never had any effect.
    back_color = getattr(Back, back.upper()) if back is not None else None

    is_single = (isinstance(image, np.ndarray) and image.ndim == 3) or isinstance(image, Image.Image)
    imgs = [image] if is_single else [image[0], image[-1]]
    if len(imgs) > 1:
        print(f"Number of images: {len(image)}")

    k = None
    for i, img in enumerate(imgs):
        # PIL images have no .shape; report the array shape for both input kinds.
        shape = img.shape if isinstance(img, np.ndarray) else np.asarray(img).shape
        if len(imgs) > 1:
            print(f"{'First' if i == 0 else 'Last'} image: {shape}")
        else:
            print(f"Image: {shape}")
        pil_img = img if isinstance(img, Image.Image) else Image.fromarray(img)
        ascii_img = from_pillow_image(pil_img)
        ascii_img.to_terminal(columns=width, back=back_color)

        if interactive:
            k = input("Press q to quit, d to next dataset, other key to next: ")
            if k == "q" or k == "d":
                return ord(k)

    if not interactive:
        return ord(" ")
    # ord() only accepts a single character; treat empty or longer input as "next".
    return ord(k) if k is not None and len(k) == 1 else ord(" ")
| |
|
| |
|
def save_video(image: Union[list[Union[Image.Image, np.ndarray]], Union[Image.Image, np.ndarray]], cache_path: str, fps: int = 24):
    """Write debug media next to a latent cache path.

    A single image (a PIL image or a 3-D ``np.ndarray``) is saved as ``<stem>.jpg``. A
    sequence of frames is encoded with libx264 (yuv420p, 1 Mbit/s) into ``<stem>.mp4``.
    Here ``<stem>`` is ``cache_path`` with its extension removed.

    Args:
        image: A single image or a sequence of RGB frames (PIL images or arrays accepted
            by ``av.VideoFrame.from_ndarray(..., format="rgb24")``).
        cache_path: Latent cache path from which the output file name is derived.
        fps: Frame rate of the encoded video.
    """
    import av

    directory = os.path.dirname(cache_path)
    # dirname() is "" for a bare file name, and os.makedirs("") would raise.
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Strip only the real extension; str.replace could match elsewhere in the path or,
    # when there is no ".safetensors" suffix, leave the cache path itself as the target.
    stem = os.path.splitext(cache_path)[0]

    if (isinstance(image, np.ndarray) and image.ndim == 3) or isinstance(image, Image.Image):
        image_path = stem + ".jpg"
        img = image if isinstance(image, Image.Image) else Image.fromarray(image)
        img.save(image_path)
        print(f"Saved image: {image_path}")
        return

    imgs = image
    print(f"Number of images: {len(imgs)}")

    video_path = stem + ".mp4"
    first = imgs[0]
    # PIL images have no .shape; their .size is (width, height).
    if isinstance(first, Image.Image):
        width, height = first.size
    else:
        height, width = first.shape[0:2]

    # Context manager guarantees the container is closed even if encoding fails.
    with av.open(video_path, mode="w") as container:
        stream = container.add_stream("libx264", rate=fps)
        stream.width = width
        stream.height = height
        stream.pix_fmt = "yuv420p"
        stream.bit_rate = 1000000

        for frame_img in imgs:
            if isinstance(frame_img, Image.Image):
                frame = av.VideoFrame.from_image(frame_img)
            else:
                frame = av.VideoFrame.from_ndarray(frame_img, format="rgb24")
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush frames still buffered inside the encoder.
        for packet in stream.encode():
            container.mux(packet)

    print(f"Saved video: {video_path}")
| |
|
| |
|
def show_datasets(
    datasets: list[BaseDataset],
    debug_mode: str,
    console_width: int,
    console_back: Optional[str],
    console_num_images: Optional[int],
    fps: int = 24,
):
    """Browse dataset items for debugging instead of encoding them.

    Args:
        datasets: Datasets whose latent-cache batches are iterated.
        debug_mode: "image" (OpenCV window via ``show_image``), "console" (ASCII art via
            ``show_console``) or "video" (files written via ``save_video``).
        console_width: Terminal width for console mode.
        console_back: Background colour name for console mode, or None.
        console_num_images: Console mode only. If set, this many items per dataset are
            shown non-interactively before moving on to the next dataset.
        fps: Frame rate passed to ``save_video`` in "video" mode.
    """
    if debug_mode != "video":
        print("d: next dataset, q: quit")

    # os.cpu_count() returns None when the CPU count cannot be determined.
    num_workers = max(1, (os.cpu_count() or 1) - 1)
    for i, dataset in enumerate(datasets):
        print(f"Dataset [{i}]")
        batch_index = 0
        num_images_to_show = console_num_images
        k = None
        for key, batch in dataset.retrieve_latent_cache_batches(num_workers):
            print(f"bucket resolution: {key}, count: {len(batch)}")
            for j, item_info in enumerate(batch):
                item_info: ItemInfo
                print(f"{batch_index}-{j}: {item_info}")
                if debug_mode == "image":
                    k = show_image(item_info.content)
                elif debug_mode == "console":
                    k = show_console(item_info.content, console_width, console_back, console_num_images is None)
                    if num_images_to_show is not None:
                        num_images_to_show -= 1
                        if num_images_to_show == 0:
                            k = ord("d")  # quota reached: behave as if "d" was pressed
                elif debug_mode == "video":
                    save_video(item_info.content, item_info.latent_cache_path, fps)
                    k = None

                if k == ord("q"):
                    return
                elif k == ord("d"):
                    break
            if k == ord("d"):
                break
            batch_index += 1
| |
|
| |
|
def encode_and_save_batch(vae: AutoencoderKLCausal3D, batch: list[ItemInfo]):
    """Encode a batch with the VAE and write one latent cache per item.

    Each ``item.content`` is converted with ``torch.from_numpy`` and stacked. A 4-D stack
    gets a singleton axis at position 1 so that it follows the same 5-D path as video.
    The last axis is moved to position 1 (channel-first for the VAE), and values are
    mapped linearly by ``x / 127.5 - 1`` (0 -> -1, 255 -> 1).

    Raises:
        ValueError: if either spatial dimension of the batch is smaller than 8.
    """
    stacked = torch.stack([torch.from_numpy(item.content) for item in batch])
    if stacked.dim() == 4:
        # assumed (B, H, W, C) single images -> (B, 1, H, W, C); TODO confirm layout
        stacked = stacked.unsqueeze(1)

    # (B, F, H, W, C) -> (B, C, F, H, W)
    pixels = stacked.permute(0, 4, 1, 2, 3).contiguous()
    pixels = pixels.to(vae.device, dtype=vae.dtype)
    pixels = pixels / 127.5 - 1.0

    height, width = pixels.shape[-2:]
    if height < 8 or width < 8:
        first = batch[0]
        raise ValueError(f"Image or video size too small: {first.item_key} and {len(batch) - 1} more, size: {first.original_size}")

    with torch.no_grad():
        latents = vae.encode(pixels).latent_dist.sample()

    for item, latent in zip(batch, latents):
        save_latent_cache(item, latent)
| |
|
| |
|
def encode_datasets(datasets: list[BaseDataset], encode: callable, args: argparse.Namespace):
    """Encode every dataset's items into latent caches and prune stale cache files.

    Args:
        datasets: Datasets to encode.
        encode: Called with a list of ItemInfo (at most ``args.batch_size`` long) to
            encode and save those items.
        args: Parsed CLI arguments. Reads ``num_workers``, ``skip_existing``,
            ``batch_size`` and ``keep_cache``.

    Raises:
        ValueError: if ``args.batch_size`` is set but not a positive integer.
    """
    # A negative step would make range() empty (silently encoding nothing) and 0 raises.
    if args.batch_size is not None and args.batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {args.batch_size}")

    if args.num_workers is not None:
        num_workers = args.num_workers
    else:
        # os.cpu_count() returns None when the CPU count cannot be determined.
        num_workers = max(1, (os.cpu_count() or 1) - 1)

    for i, dataset in enumerate(datasets):
        logger.info(f"Encoding dataset [{i}]")
        all_latent_cache_paths = []
        for _, batch in tqdm(dataset.retrieve_latent_cache_batches(num_workers)):
            # Record every expected path before filtering, so caches skipped below are
            # not mistaken for stale files and deleted.
            all_latent_cache_paths.extend(item.latent_cache_path for item in batch)

            if args.skip_existing:
                batch = [item for item in batch if not os.path.exists(item.latent_cache_path)]
            if not batch:
                continue  # also avoids a zero range() step when batch_size is None

            bs = args.batch_size if args.batch_size is not None else len(batch)
            for start in range(0, len(batch), bs):
                encode(batch[start : start + bs])

        # Normalise so different spellings of the same path compare equal.
        expected_paths = {os.path.normpath(p) for p in all_latent_cache_paths}

        # Remove (or just report) cache files on disk that no dataset item produced.
        for cache_file in dataset.get_all_latent_cache_files():
            if os.path.normpath(cache_file) in expected_paths:
                continue
            if args.keep_cache:
                logger.info(f"Keep cache file not in the dataset: {cache_file}")
            else:
                os.remove(cache_file)
                logger.info(f"Removed old cache file: {cache_file}")
| |
|
| |
|
def main(args):
    """Load the dataset config, then either preview the datasets or encode them into latent caches.

    If ``--debug_mode`` is set, items are only previewed via ``show_datasets`` and no VAE
    is loaded. Otherwise the VAE is loaded from ``--vae``, optionally configured for
    chunking or tiling, and every dataset is encoded with ``encode_and_save_batch``.

    Raises:
        ValueError: if encoding is requested but ``--vae`` was not given.
    """
    device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Build the dataset group from the user's config file.
    blueprint_generator = BlueprintGenerator(ConfigSanitizer())
    logger.info(f"Load dataset config from {args.dataset_config}")
    user_config = config_utils.load_user_config(args.dataset_config)
    blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_HUNYUAN_VIDEO)
    train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    datasets = train_dataset_group.datasets

    if args.debug_mode is not None:
        show_datasets(datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images)
        return

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if args.vae is None:
        raise ValueError("vae checkpoint is required")

    # float16 unless --vae_dtype overrides it.
    vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
    vae, _, _, _ = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
    vae.eval()
    logger.info(f"Loaded VAE: {vae.config}, dtype: {vae.dtype}")

    if args.vae_chunk_size is not None:
        vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size)
        logger.info(f"Set chunk_size to {args.vae_chunk_size} for CausalConv3d in VAE")
    if args.vae_spatial_tile_sample_min_size is not None:
        vae.enable_spatial_tiling(True)
        vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
        # NOTE(review): the hard-coded 8 presumably matches the VAE's spatial downscale factor; confirm.
        vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
    elif args.vae_tiling:
        vae.enable_spatial_tiling(True)

    def encode(one_batch: list[ItemInfo]):
        encode_and_save_batch(vae, one_batch)

    encode_datasets(datasets, encode, args)
| |
|
| |
|
def setup_parser_common() -> argparse.ArgumentParser:
    """Create a parser with the architecture-independent cache options.

    Covers dataset/VAE/device selection, batch size and worker count, cache handling
    (``--skip_existing``, ``--keep_cache``) and the debug-preview options.
    Architecture-specific options are added separately, e.g. by ``hv_setup_parser``.
    """
    p = argparse.ArgumentParser()
    add = p.add_argument

    # inputs and runtime
    add("--dataset_config", type=str, required=True, help="path to dataset config .toml file")
    add("--vae", type=str, required=False, default=None, help="path to vae checkpoint")
    add("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
    add("--device", type=str, default=None, help="device to use, default is cuda if available")
    add("--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this")
    add("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")

    # cache handling
    add("--skip_existing", action="store_true", help="skip existing cache files")
    add("--keep_cache", action="store_true", help="keep cache files not in dataset")

    # debug preview
    add("--debug_mode", type=str, default=None, choices=["image", "console", "video"], help="debug mode")
    add("--console_width", type=int, default=80, help="debug mode: console width")
    add("--console_back", type=str, default=None, help="debug mode: console background color, one of ascii_magic.Back")
    add(
        "--console_num_images",
        type=int,
        default=None,
        help="debug mode: not interactive, number of images to show for each dataset",
    )
    return p
| |
|
| |
|
def hv_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register the HunyuanVideo VAE chunking/tiling options on ``parser``.

    The parser is modified in place, and the same object is returned for chaining.
    """
    vae_options = [
        (
            "--vae_tiling",
            dict(
                action="store_true",
                help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled",
            ),
        ),
        ("--vae_chunk_size", dict(type=int, default=None, help="chunk size for CausalConv3d in VAE")),
        (
            "--vae_spatial_tile_sample_min_size",
            dict(type=int, default=None, help="spatial tile sample min size for VAE, default 256"),
        ),
    ]
    for flag, options in vae_options:
        parser.add_argument(flag, **options)
    return parser
| |
|
| |
|
if __name__ == "__main__":
    # Common options first, then the HunyuanVideo-specific VAE options.
    parser = hv_setup_parser(setup_parser_common())
    args = parser.parse_args()
    main(args)
| |
|