| from concurrent.futures import ThreadPoolExecutor |
| import glob |
| import json |
| import math |
| import os |
| import random |
| import time |
| from typing import Optional, Sequence, Tuple, Union |
|
|
| import numpy as np |
| import torch |
| from safetensors.torch import save_file, load_file |
| from safetensors import safe_open |
| from PIL import Image |
| import cv2 |
| import av |
|
|
| from utils import safetensors_utils |
| from utils.model_utils import dtype_to_str |
|
|
| import logging |
|
|
| logger = logging.getLogger(__name__) |
| logging.basicConfig(level=logging.INFO) |
|
|
|
|
# Supported still-image extensions. Upper-case variants are listed explicitly
# because glob matching is case-sensitive on most filesystems.
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]

# Optional Pillow plugins extend the supported formats. Import failures are
# expected when a plugin is not installed; catch Exception (not a bare except)
# so KeyboardInterrupt/SystemExit still propagate, while plugins that fail to
# load for environment reasons (e.g. missing native libs) stay best-effort.
try:
    import pillow_avif

    IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
except Exception:
    pass

# JPEG XL support can come from either of two optional packages; if both are
# installed the extension appears twice, which is harmless because
# glob_images() de-duplicates its results.
try:
    from jxlpy import JXLImagePlugin

    IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except Exception:
    pass

try:
    import pillow_jxl

    IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except Exception:
    pass

# Supported video container extensions (decoded via PyAV).
VIDEO_EXTENSIONS = [".mp4", ".avi", ".mov", ".webm", ".MP4", ".AVI", ".MOV", ".WEBM"]

# Architecture tag embedded in cache file names.
ARCHITECTURE_HUNYUAN_VIDEO = "hv"
|
|
|
|
def glob_images(directory, base="*"):
    """Return a sorted, de-duplicated list of image files in *directory*.

    When *base* is "*" only the directory part is glob-escaped so the
    wildcard stays live; otherwise the whole joined path is escaped.
    """
    found = set()
    for ext in IMAGE_EXTENSIONS:
        if base == "*":
            pattern = os.path.join(glob.escape(directory), base + ext)
        else:
            pattern = glob.escape(os.path.join(directory, base + ext))
        found.update(glob.glob(pattern))
    return sorted(found)
|
|
|
|
def glob_videos(directory, base="*"):
    """Return a sorted, de-duplicated list of video files in *directory*.

    When *base* is "*" only the directory part is glob-escaped so the
    wildcard stays live; otherwise the whole joined path is escaped.
    """
    found = set()
    for ext in VIDEO_EXTENSIONS:
        if base == "*":
            pattern = os.path.join(glob.escape(directory), base + ext)
        else:
            pattern = glob.escape(os.path.join(directory, base + ext))
        found.update(glob.glob(pattern))
    return sorted(found)
|
|
|
|
def divisible_by(num: int, divisor: int) -> int:
    """Round *num* down to the nearest multiple of *divisor*."""
    return (num // divisor) * divisor
|
|
|
|
def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:
    """
    Resize the image to the bucket resolution.

    The image is scaled so it covers the bucket, then center-cropped to the
    exact bucket size. Accepts a PIL image or an HWC numpy array; always
    returns a numpy array of shape (bucket_height, bucket_width, channels).
    `bucket_reso` is (width, height).
    """
    is_pil_image = isinstance(image, Image.Image)
    if is_pil_image:
        image_width, image_height = image.size
    else:
        # numpy arrays are (height, width, channels)
        image_height, image_width = image.shape[:2]

    # Fast path: already exactly the bucket size, no resampling or cropping.
    if bucket_reso == (image_width, image_height):
        return np.array(image) if is_pil_image else image

    bucket_width, bucket_height = bucket_reso
    if bucket_width == image_width or bucket_height == image_height:
        # One dimension already matches: skip resampling and center-crop only.
        # NOTE(review): if the other dimension is *smaller* than the bucket,
        # the crop offsets below go negative and the result is smaller than
        # the bucket — callers appear to guarantee coverage; confirm.
        image = np.array(image) if is_pil_image else image
    else:
        # Scale uniformly so the image fully covers the bucket (max of the two
        # per-axis scale factors), rounding to the nearest pixel.
        scale_width = bucket_width / image_width
        scale_height = bucket_height / image_height
        scale = max(scale_width, scale_height)
        image_width = int(image_width * scale + 0.5)
        image_height = int(image_height * scale + 0.5)

        if scale > 1:
            # Upscaling: PIL LANCZOS.
            image = Image.fromarray(image) if not is_pil_image else image
            image = image.resize((image_width, image_height), Image.LANCZOS)
            image = np.array(image)
        else:
            # Downscaling: cv2 INTER_AREA (better for shrinking).
            image = np.array(image) if is_pil_image else image
            image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)

    # Center-crop the (possibly resized) image to the exact bucket size.
    crop_left = (image_width - bucket_width) // 2
    crop_top = (image_height - bucket_height) // 2
    image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
    return image
|
|
|
|
class ItemInfo:
    """Record describing one training item (image or video clip).

    Holds the item's key and caption, its original and bucketed sizes,
    optionally the decoded pixel content, and the paths of its latent and
    text-encoder-output cache files.
    """

    def __init__(
        self,
        item_key: str,
        caption: str,
        original_size: tuple[int, int],
        bucket_size: Optional[Union[tuple[int, int], tuple[int, int, int]]] = None,
        frame_count: Optional[int] = None,
        content: Optional[np.ndarray] = None,
        latent_cache_path: Optional[str] = None,
    ) -> None:
        # Identity and text.
        self.item_key = item_key
        self.caption = caption
        # Geometry: original (width, height); bucket is (w, h) for images or
        # (w, h, frames) for videos.
        self.original_size = original_size
        self.bucket_size = bucket_size
        self.frame_count = frame_count
        # Decoded pixels, kept only while caching latents.
        self.content = content
        # Cache file locations; the text-encoder path is filled in later.
        self.latent_cache_path = latent_cache_path
        self.text_encoder_output_cache_path: Optional[str] = None

    def __str__(self) -> str:
        return (
            f"ItemInfo(item_key={self.item_key}, caption={self.caption}, "
            f"original_size={self.original_size}, bucket_size={self.bucket_size}, "
            f"frame_count={self.frame_count}, latent_cache_path={self.latent_cache_path})"
        )
|
|
|
|
def save_latent_cache(item_info: ItemInfo, latent: torch.Tensor):
    """Save a VAE latent for one item into its safetensors cache file.

    The tensor is stored under a key encoding its trailing dims and dtype
    ("latents_{F}x{H}x{W}_{dtype}") so the loader can find it by prefix, and
    the original image size / frame count go into the file metadata.
    """
    assert latent.dim() == 4, "latent should be 4D tensor (frame, channel, height, width)"
    metadata = {
        "architecture": "hunyuan_video",
        "width": f"{item_info.original_size[0]}",
        "height": f"{item_info.original_size[1]}",
        "format_version": "1.0.0",
    }
    if item_info.frame_count is not None:
        metadata["frame_count"] = f"{item_info.frame_count}"

    # NOTE(review): the assert message above says the dim order is
    # (frame, channel, height, width), but this unpack skips dim 0 and names
    # dim 1 "F" — the actual layout looks like (channel, frame, height,
    # width); confirm against the VAE caller before relying on either.
    _, F, H, W = latent.shape
    dtype_str = dtype_to_str(latent.dtype)
    sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()}

    # Ensure the cache directory exists before writing.
    latent_dir = os.path.dirname(item_info.latent_cache_path)
    os.makedirs(latent_dir, exist_ok=True)

    save_file(sd, item_info.latent_cache_path, metadata=metadata)
|
|
|
|
def save_text_encoder_output_cache(item_info: ItemInfo, embed: torch.Tensor, mask: Optional[torch.Tensor], is_llm: bool):
    """Save (or merge) one text-encoder output into the item's cache file.

    A single cache file can hold both the LLM and the CLIP-L outputs: if the
    file already exists, its tensors are loaded first and re-saved together
    with the new entry, so calling this once per encoder accumulates both.

    Args:
        item_info: item whose text_encoder_output_cache_path is written.
        embed: encoder output, (feature, hidden_size) or (hidden_size,).
        mask: optional attention mask, (feature,).
        is_llm: True selects the "llm" key prefix, False selects "clipL".
    """
    assert (
        embed.dim() == 1 or embed.dim() == 2
    ), f"embed should be 2D tensor (feature, hidden_size) or (hidden_size,), got {embed.shape}"
    assert mask is None or mask.dim() == 1, f"mask should be 1D tensor (feature), got {mask.shape}"
    metadata = {
        "architecture": "hunyuan_video",
        "caption1": item_info.caption,
        "format_version": "1.0.0",
    }

    sd = {}
    if os.path.exists(item_info.text_encoder_output_cache_path):
        # Load existing tensors so the other encoder's output is preserved.
        with safetensors_utils.MemoryEfficientSafeOpen(item_info.text_encoder_output_cache_path) as f:
            existing_metadata = f.metadata()
            for key in f.keys():
                sd[key] = f.get_tensor(key)

        assert existing_metadata["architecture"] == metadata["architecture"], "architecture mismatch"
        if existing_metadata["caption1"] != metadata["caption1"]:
            logger.warning(f"caption mismatch: existing={existing_metadata['caption1']}, new={metadata['caption1']}, overwrite")
        # Merge metadata: pop caption1/format_version from the old metadata so
        # the new values win, then carry over any other existing entries.
        existing_metadata.pop("caption1", None)
        existing_metadata.pop("format_version", None)
        metadata.update(existing_metadata)
    else:
        # First write for this item: make sure the directory exists.
        text_encoder_output_dir = os.path.dirname(item_info.text_encoder_output_cache_path)
        os.makedirs(text_encoder_output_dir, exist_ok=True)

    # Key encodes encoder type and dtype, e.g. "llm_bf16" / "clipL_fp16".
    dtype_str = dtype_to_str(embed.dtype)
    text_encoder_type = "llm" if is_llm else "clipL"
    sd[f"{text_encoder_type}_{dtype_str}"] = embed.detach().cpu()
    if mask is not None:
        sd[f"{text_encoder_type}_mask"] = mask.detach().cpu()

    safetensors_utils.mem_eff_save_file(sd, item_info.text_encoder_output_cache_path, metadata=metadata)
|
|
|
|
class BucketSelector:
    """Chooses a bucket resolution for each image by nearest aspect ratio.

    With bucketing enabled, candidate (width, height) pairs are generated so
    that every bucket has roughly the same pixel area as the requested
    resolution; otherwise the single requested resolution is the only bucket.
    """

    RESOLUTION_STEPS_HUNYUAN = 16

    def __init__(self, resolution: Tuple[int, int], enable_bucket: bool = True, no_upscale: bool = False):
        self.resolution = resolution
        self.bucket_area = resolution[0] * resolution[1]
        self.reso_steps = BucketSelector.RESOLUTION_STEPS_HUNYUAN

        if not enable_bucket:
            # Single fixed bucket; no_upscale has no meaning in this mode.
            self.no_upscale = False
            self.bucket_resolutions = [resolution]
        else:
            self.no_upscale = no_upscale
            # Sweep widths from half the square side up to the square side,
            # pairing each width with the height that keeps the area constant;
            # add the transposed bucket as well, then de-duplicate.
            sqrt_size = int(math.sqrt(self.bucket_area))
            min_size = divisible_by(sqrt_size // 2, self.reso_steps)
            candidates = set()
            for w in range(min_size, sqrt_size + self.reso_steps, self.reso_steps):
                h = divisible_by(self.bucket_area // w, self.reso_steps)
                candidates.add((w, h))
                candidates.add((h, w))
            self.bucket_resolutions = sorted(candidates)

        # Aspect ratio per bucket, in the same order as bucket_resolutions.
        self.aspect_ratios = np.array([w / h for w, h in self.bucket_resolutions])

    def get_bucket_resolution(self, image_size: tuple[int, int]) -> tuple[int, int]:
        """
        return the bucket resolution for the given image size, (width, height)
        """
        width, height = image_size
        if self.no_upscale and width * height <= self.bucket_area:
            # Small image: keep its own size, just snapped down to the grid.
            return divisible_by(width, self.reso_steps), divisible_by(height, self.reso_steps)

        # Otherwise pick the bucket whose aspect ratio is closest.
        ar_diff = np.abs(self.aspect_ratios - (width / height))
        return self.bucket_resolutions[int(ar_diff.argmin())]
|
|
|
|
def load_video(
    video_path: str,
    start_frame: Optional[int] = None,
    end_frame: Optional[int] = None,
    bucket_selector: Optional[BucketSelector] = None,
) -> list[np.ndarray]:
    """Decode frames [start_frame, end_frame) of a video as numpy arrays.

    If a bucket_selector is given, the bucket resolution is chosen from the
    first kept frame and every frame is resized/cropped to it.

    Args:
        video_path: path to the video file (decoded with PyAV).
        start_frame: first frame index to keep (inclusive); None means 0.
        end_frame: frame index to stop at (exclusive); None means all frames.
        bucket_selector: optional selector used to fit frames to a bucket.

    Returns:
        List of HWC numpy arrays, one per kept frame.
    """
    container = av.open(video_path)
    try:
        video = []
        bucket_reso = None
        for i, frame in enumerate(container.decode(video=0)):
            if start_frame is not None and i < start_frame:
                continue
            if end_frame is not None and i >= end_frame:
                break
            frame = frame.to_image()

            # The bucket resolution is decided once, from the first kept frame,
            # so all frames of one video share the same size.
            if bucket_selector is not None and bucket_reso is None:
                bucket_reso = bucket_selector.get_bucket_resolution(frame.size)

            if bucket_reso is not None:
                frame = resize_image_to_bucket(frame, bucket_reso)
            else:
                frame = np.array(frame)

            video.append(frame)
    finally:
        # Fixed: close the container even if decoding raises (previously the
        # handle leaked on error).
        container.close()
    return video
|
|
|
|
class BucketBatchManager:
    """Groups cached items by bucket resolution and serves shuffled batches.

    Each batch contains items from a single bucket (so tensors stack cleanly);
    __getitem__ loads the latent and text-encoder caches from disk on demand.
    """

    def __init__(self, bucketed_item_info: dict[tuple[int, int], list[ItemInfo]], batch_size: int):
        self.batch_size = batch_size
        self.buckets = bucketed_item_info
        self.bucket_resos = list(self.buckets.keys())
        self.bucket_resos.sort()

        # Flat index of (bucket_reso, batch index within that bucket) pairs —
        # one entry per batch, so any batch can be addressed by a single idx.
        self.bucket_batch_indices = []
        for bucket_reso in self.bucket_resos:
            bucket = self.buckets[bucket_reso]
            num_batches = math.ceil(len(bucket) / self.batch_size)
            for i in range(num_batches):
                self.bucket_batch_indices.append((bucket_reso, i))

        self.shuffle()

    def show_bucket_info(self):
        """Log per-bucket item counts and the total batch count."""
        for bucket_reso in self.bucket_resos:
            bucket = self.buckets[bucket_reso]
            logger.info(f"bucket: {bucket_reso}, count: {len(bucket)}")

        logger.info(f"total batches: {len(self)}")

    def shuffle(self):
        # Shuffle item order within each bucket and the order in which batches
        # (across buckets) are served.
        for bucket in self.buckets.values():
            random.shuffle(bucket)
        random.shuffle(self.bucket_batch_indices)

    def __len__(self):
        # Total number of batches across all buckets.
        return len(self.bucket_batch_indices)

    def __getitem__(self, idx):
        """Load one batch from disk.

        Returns stacked tensors (latents, llm_embeds, llm_masks, clip_l_embeds),
        each with a leading batch dimension.
        """
        bucket_reso, batch_idx = self.bucket_batch_indices[idx]
        bucket = self.buckets[bucket_reso]
        start = batch_idx * self.batch_size
        end = min(start + self.batch_size, len(bucket))

        latents = []
        llm_embeds = []
        llm_masks = []
        clip_l_embeds = []
        for item_info in bucket[start:end]:
            # Latent cache keys look like "latents_{F}x{H}x{W}_{dtype}";
            # take the first matching key.
            sd = load_file(item_info.latent_cache_path)
            latent = None
            for key in sd.keys():
                if key.startswith("latents_"):
                    latent = sd[key]
                    break
            latents.append(latent)

            # Text-encoder cache: the "llm_mask" prefix must be tested before
            # the broader "llm_" prefix (same for clipL) — order matters.
            sd = load_file(item_info.text_encoder_output_cache_path)
            llm_embed = llm_mask = clip_l_embed = None
            for key in sd.keys():
                if key.startswith("llm_mask"):
                    llm_mask = sd[key]
                elif key.startswith("llm_"):
                    llm_embed = sd[key]
                elif key.startswith("clipL_mask"):
                    pass  # clipL mask is present in the cache but unused here
                elif key.startswith("clipL_"):
                    clip_l_embed = sd[key]
            llm_embeds.append(llm_embed)
            llm_masks.append(llm_mask)
            clip_l_embeds.append(clip_l_embed)

        latents = torch.stack(latents)
        llm_embeds = torch.stack(llm_embeds)
        llm_masks = torch.stack(llm_masks)
        clip_l_embeds = torch.stack(clip_l_embeds)

        return latents, llm_embeds, llm_masks, clip_l_embeds
|
|
|
|
class ContentDatasource:
    """Abstract base class for caption/content data sources.

    Subclasses provide items either by index (when is_indexable() is True) or
    through the iterator protocol, and can be switched into caption-only mode
    when only text data is needed.
    """

    def __init__(self):
        # When True, fetchers should yield only (key, caption) data.
        self.caption_only: bool = False

    def set_caption_only(self, caption_only: bool):
        """Enable or disable caption-only mode."""
        self.caption_only = caption_only

    def is_indexable(self):
        """Whether random access via the get_* methods is supported."""
        return False

    def get_caption(self, idx: int) -> tuple[str, str]:
        """
        Returns caption. May not be called if is_indexable() returns False.
        """
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __iter__(self):
        raise NotImplementedError

    def __next__(self):
        raise NotImplementedError
|
|
|
|
class ImageDatasource(ContentDatasource):
    """Abstract datasource that yields image data in addition to captions."""

    def __init__(self):
        super().__init__()

    def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
        """
        Returns image data as a tuple of image path, image, and caption for the given index.
        Key must be unique and valid as a file name.
        May not be called if is_indexable() returns False.
        """
        raise NotImplementedError
|
|
|
|
class ImageDirectoryDatasource(ImageDatasource):
    """Image datasource backed by a directory of image files.

    Captions are read from sidecar text files that share the image's basename
    plus `caption_extension`.
    """

    def __init__(self, image_directory: str, caption_extension: Optional[str] = None):
        super().__init__()
        self.image_directory = image_directory
        self.caption_extension = caption_extension
        # Iteration cursor for the __next__ protocol.
        self.current_idx = 0

        # Enumerate images once, up front.
        logger.info(f"glob images in {self.image_directory}")
        self.image_paths = glob_images(self.image_directory)
        logger.info(f"found {len(self.image_paths)} images")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.image_paths)

    def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
        """Return (image path, RGB image, caption) for index `idx`."""
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")

        _, caption = self.get_caption(idx)

        return image_path, image, caption

    def get_caption(self, idx: int) -> tuple[str, str]:
        """Return (image path, caption) read from the sidecar caption file.

        NOTE(review): when caption_extension is None, caption_path is "" and
        open() raises FileNotFoundError — presumably an extension is always
        configured for this datasource; confirm with callers.
        """
        image_path = self.image_paths[idx]
        caption_path = os.path.splitext(image_path)[0] + self.caption_extension if self.caption_extension else ""
        with open(caption_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
        return image_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self) -> callable:
        """
        Returns a fetcher function that returns image data.
        """
        # Fetchers defer the actual file I/O so callers can run them in a
        # thread pool; the index is bound eagerly via the factory functions
        # (avoids the late-binding closure pitfall).
        if self.current_idx >= len(self.image_paths):
            raise StopIteration

        if self.caption_only:

            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)
        else:

            def create_image_fetcher(index):
                return lambda: self.get_image_data(index)

            fetcher = create_image_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher
|
|
|
|
class ImageJsonlDatasource(ImageDatasource):
    """Image datasource backed by a JSONL file.

    Each line is a JSON object with at least "image_path" and "caption" keys.
    """

    def __init__(self, image_jsonl_file: str):
        super().__init__()
        self.image_jsonl_file = image_jsonl_file
        # Iteration cursor for the __next__ protocol.
        self.current_idx = 0

        # Load and parse the whole JSONL file up front.
        logger.info(f"load image jsonl from {self.image_jsonl_file}")
        self.data = []
        with open(self.image_jsonl_file, "r", encoding="utf-8") as f:
            for line in f:
                # Robustness fix: skip blank lines (trailing newline, stray
                # empty lines) instead of crashing in json.loads.
                line = line.strip()
                if not line:
                    continue
                data = json.loads(line)
                self.data.append(data)
        logger.info(f"loaded {len(self.data)} images")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.data)

    def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
        """Return (image path, RGB image, caption) for index `idx`."""
        data = self.data[idx]
        image_path = data["image_path"]
        image = Image.open(image_path).convert("RGB")

        caption = data["caption"]

        return image_path, image, caption

    def get_caption(self, idx: int) -> tuple[str, str]:
        """Return (image path, caption) without loading the image."""
        data = self.data[idx]
        image_path = data["image_path"]
        caption = data["caption"]
        return image_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self) -> callable:
        # Yield deferred fetchers so the actual I/O can run in a worker pool;
        # the index is bound eagerly via the factory closures.
        if self.current_idx >= len(self.data):
            raise StopIteration

        if self.caption_only:

            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)

        else:

            def create_fetcher(index):
                return lambda: self.get_image_data(index)

            fetcher = create_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher
|
|
|
|
class VideoDatasource(ContentDatasource):
    """Abstract base for video data sources.

    Adds datasource-level defaults for the frame range and bucket selector
    that are applied when loading any video.
    """

    def __init__(self):
        super().__init__()

        # Default frame range [start_frame, end_frame) for every video;
        # None means unbounded on that side.
        self.start_frame = None
        self.end_frame = None

        # When set, frames are resized/cropped to the selected bucket.
        self.bucket_selector = None

    def __len__(self):
        raise NotImplementedError

    def get_video_data_from_path(
        self,
        video_path: str,
        start_frame: Optional[int] = None,
        end_frame: Optional[int] = None,
        bucket_selector: Optional[BucketSelector] = None,
    ) -> list[np.ndarray]:
        """Load the frames of one video as numpy arrays.

        Explicit arguments override the datasource-level defaults set via
        set_start_and_end_frame / set_bucket_selector.

        Fixed: the return annotation previously claimed
        tuple[str, list[Image.Image], str], but only the frame list is
        returned (subclasses wrap it with path and caption themselves).
        """
        start_frame = start_frame if start_frame is not None else self.start_frame
        end_frame = end_frame if end_frame is not None else self.end_frame
        bucket_selector = bucket_selector if bucket_selector is not None else self.bucket_selector

        video = load_video(video_path, start_frame, end_frame, bucket_selector)
        return video

    def set_start_and_end_frame(self, start_frame: Optional[int], end_frame: Optional[int]):
        """Set the default frame range applied to every video."""
        self.start_frame = start_frame
        self.end_frame = end_frame

    def set_bucket_selector(self, bucket_selector: BucketSelector):
        """Set the default bucket selector applied to every video."""
        self.bucket_selector = bucket_selector

    def __iter__(self):
        raise NotImplementedError

    def __next__(self):
        raise NotImplementedError
|
|
|
|
class VideoDirectoryDatasource(VideoDatasource):
    """Video datasource backed by a directory of video files.

    Captions are read from sidecar text files that share the video's basename
    plus `caption_extension`.
    """

    def __init__(self, video_directory: str, caption_extension: Optional[str] = None):
        super().__init__()
        self.video_directory = video_directory
        self.caption_extension = caption_extension
        # Iteration cursor for the __next__ protocol.
        self.current_idx = 0

        # Enumerate videos once, up front.
        # Fixed: the log message previously said "glob images".
        logger.info(f"glob videos in {self.video_directory}")
        self.video_paths = glob_videos(self.video_directory)
        logger.info(f"found {len(self.video_paths)} videos")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.video_paths)

    def get_video_data(
        self,
        idx: int,
        start_frame: Optional[int] = None,
        end_frame: Optional[int] = None,
        bucket_selector: Optional[BucketSelector] = None,
    ) -> tuple[str, list[Image.Image], str]:
        """Return (video path, frames, caption) for index `idx`."""
        video_path = self.video_paths[idx]
        video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)

        _, caption = self.get_caption(idx)

        return video_path, video, caption

    def get_caption(self, idx: int) -> tuple[str, str]:
        """Return (video path, caption) read from the sidecar caption file.

        NOTE(review): when caption_extension is None, caption_path is "" and
        open() raises FileNotFoundError — presumably an extension is always
        configured for this datasource; confirm with callers.
        """
        video_path = self.video_paths[idx]
        caption_path = os.path.splitext(video_path)[0] + self.caption_extension if self.caption_extension else ""
        with open(caption_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
        return video_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self):
        # Yield deferred fetchers so the actual I/O can run in a worker pool;
        # the index is bound eagerly via the factory closures.
        if self.current_idx >= len(self.video_paths):
            raise StopIteration

        if self.caption_only:

            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)

        else:

            def create_fetcher(index):
                return lambda: self.get_video_data(index)

            fetcher = create_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher
|
|
|
|
class VideoJsonlDatasource(VideoDatasource):
    """Video datasource backed by a JSONL file.

    Each line is a JSON object with at least "video_path" and "caption" keys.
    """

    def __init__(self, video_jsonl_file: str):
        super().__init__()
        self.video_jsonl_file = video_jsonl_file
        # Iteration cursor for the __next__ protocol.
        self.current_idx = 0

        # Load and parse the whole JSONL file up front.
        logger.info(f"load video jsonl from {self.video_jsonl_file}")
        self.data = []
        with open(self.video_jsonl_file, "r", encoding="utf-8") as f:
            for line in f:
                # Robustness fix: skip blank lines (trailing newline, stray
                # empty lines) instead of crashing in json.loads.
                line = line.strip()
                if not line:
                    continue
                data = json.loads(line)
                self.data.append(data)
        logger.info(f"loaded {len(self.data)} videos")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.data)

    def get_video_data(
        self,
        idx: int,
        start_frame: Optional[int] = None,
        end_frame: Optional[int] = None,
        bucket_selector: Optional[BucketSelector] = None,
    ) -> tuple[str, list[Image.Image], str]:
        """Return (video path, frames, caption) for index `idx`."""
        data = self.data[idx]
        video_path = data["video_path"]
        video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)

        caption = data["caption"]

        return video_path, video, caption

    def get_caption(self, idx: int) -> tuple[str, str]:
        """Return (video path, caption) without loading the video."""
        data = self.data[idx]
        video_path = data["video_path"]
        caption = data["caption"]
        return video_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self):
        # Yield deferred fetchers so the actual I/O can run in a worker pool;
        # the index is bound eagerly via the factory closures.
        if self.current_idx >= len(self.data):
            raise StopIteration

        if self.caption_only:

            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)

        else:

            def create_fetcher(index):
                return lambda: self.get_video_data(index)

            fetcher = create_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher
|
|
|
|
class BaseDataset(torch.utils.data.Dataset):
    """Common base class for cached-latent datasets (images and videos).

    Stores bucket/batch configuration, defines the cache-file naming scheme,
    and provides the shared threaded producer loop used when caching text
    encoder outputs. Subclasses implement retrieval and shuffling.
    """

    def __init__(
        self,
        resolution: Tuple[int, int] = (960, 544),
        caption_extension: Optional[str] = None,
        batch_size: int = 1,
        enable_bucket: bool = False,
        bucket_no_upscale: bool = False,
        cache_directory: Optional[str] = None,
        debug_dataset: bool = False,
    ):
        self.resolution = resolution
        self.caption_extension = caption_extension
        self.batch_size = batch_size
        self.enable_bucket = enable_bucket
        self.bucket_no_upscale = bucket_no_upscale
        self.cache_directory = cache_directory
        self.debug_dataset = debug_dataset
        self.seed = None
        self.current_epoch = 0

        # bucket_no_upscale is only meaningful when bucketing is enabled.
        if not self.enable_bucket:
            self.bucket_no_upscale = False

    def get_metadata(self) -> dict:
        """Return the dataset configuration for embedding in saved metadata."""
        metadata = {
            "resolution": self.resolution,
            "caption_extension": self.caption_extension,
            "batch_size_per_device": self.batch_size,
            "enable_bucket": bool(self.enable_bucket),
            "bucket_no_upscale": bool(self.bucket_no_upscale),
        }
        return metadata

    def get_latent_cache_path(self, item_info: ItemInfo) -> str:
        """Path of the latent cache: {basename}_{W:04d}x{H:04d}_{arch}.safetensors."""
        w, h = item_info.original_size
        basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
        assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
        return os.path.join(self.cache_directory, f"{basename}_{w:04d}x{h:04d}_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors")

    def get_text_encoder_output_cache_path(self, item_info: ItemInfo) -> str:
        """Path of the text-encoder cache: {basename}_{arch}_te.safetensors."""
        basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
        assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
        return os.path.join(self.cache_directory, f"{basename}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors")

    def retrieve_latent_cache_batches(self, num_workers: int):
        raise NotImplementedError

    def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
        raise NotImplementedError

    def prepare_for_training(self):
        pass

    def set_seed(self, seed: int):
        self.seed = seed

    def set_current_epoch(self, epoch):
        """Advance the epoch counter, shuffling buckets once per epoch crossed."""
        if not self.current_epoch == epoch:
            if epoch > self.current_epoch:
                logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
                num_epochs = epoch - self.current_epoch
                # Shuffle once for each epoch advanced, so each epoch gets its
                # own shuffle even when several epochs are skipped at once.
                for _ in range(num_epochs):
                    self.current_epoch += 1
                    self.shuffle_buckets()
            else:
                logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
                self.current_epoch = epoch

    def set_current_step(self, step):
        self.current_step = step

    def set_max_train_steps(self, max_train_steps):
        self.max_train_steps = max_train_steps

    def shuffle_buckets(self):
        raise NotImplementedError

    def __len__(self):
        # Fixed: this previously did `return NotImplementedError` (returning
        # the exception class), which made len() fail with a confusing
        # TypeError instead of signalling "not implemented".
        raise NotImplementedError

    def __getitem__(self, idx):
        raise NotImplementedError

    def _default_retrieve_text_encoder_output_cache_batches(self, datasource: ContentDatasource, batch_size: int, num_workers: int):
        """Yield ItemInfo batches (caption only) for text-encoder caching.

        Fetchers from `datasource` run concurrently in a thread pool;
        completed results are drained into `data` and emitted in batches of
        `batch_size`, with a final flush of any remainder.
        """
        datasource.set_caption_only(True)
        executor = ThreadPoolExecutor(max_workers=num_workers)

        data: list[ItemInfo] = []
        futures = []

        def aggregate_future(consume_all: bool = False):
            # Drain completed futures; poll while the pool is saturated, or
            # until everything finishes when consume_all is set.
            while len(futures) >= num_workers or (consume_all and len(futures) > 0):
                completed_futures = [future for future in futures if future.done()]
                if len(completed_futures) == 0:
                    if len(futures) >= num_workers or consume_all:
                        time.sleep(0.1)
                        continue
                    else:
                        break

                for future in completed_futures:
                    item_key, caption = future.result()
                    # Sizes are placeholders — only the caption matters here.
                    item_info = ItemInfo(item_key, caption, (0, 0), (0, 0))
                    item_info.text_encoder_output_cache_path = self.get_text_encoder_output_cache_path(item_info)
                    data.append(item_info)

                    futures.remove(future)

        def submit_batch(flush: bool = False):
            nonlocal data
            if len(data) >= batch_size or (len(data) > 0 and flush):
                batch = data[0:batch_size]
                if len(data) > batch_size:
                    data = data[batch_size:]
                else:
                    data = []
                return batch
            return None

        try:
            for fetch_op in datasource:
                future = executor.submit(fetch_op)
                futures.append(future)
                aggregate_future()
                while True:
                    batch = submit_batch()
                    if batch is None:
                        break
                    yield batch

            aggregate_future(consume_all=True)
            while True:
                batch = submit_batch(flush=True)
                if batch is None:
                    break
                yield batch
        finally:
            # Fixed: shut the pool down even when the consumer abandons this
            # generator mid-iteration (previously worker threads leaked).
            executor.shutdown()
|
|
|
|
class ImageDataset(BaseDataset):
    """Dataset of still images backed by precomputed latent / text-encoder caches."""

    def __init__(
        self,
        resolution: Tuple[int, int],
        caption_extension: Optional[str],
        batch_size: int,
        enable_bucket: bool,
        bucket_no_upscale: bool,
        image_directory: Optional[str] = None,
        image_jsonl_file: Optional[str] = None,
        cache_directory: Optional[str] = None,
        debug_dataset: bool = False,
    ):
        super(ImageDataset, self).__init__(
            resolution, caption_extension, batch_size, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset
        )
        self.image_directory = image_directory
        self.image_jsonl_file = image_jsonl_file
        # Exactly one of image_directory / image_jsonl_file selects the datasource.
        if image_directory is not None:
            self.datasource = ImageDirectoryDatasource(image_directory, caption_extension)
        elif image_jsonl_file is not None:
            self.datasource = ImageJsonlDatasource(image_jsonl_file)
        else:
            raise ValueError("image_directory or image_jsonl_file must be specified")

        # Default cache location: next to the images themselves.
        if self.cache_directory is None:
            self.cache_directory = self.image_directory

        self.batch_manager = None
        self.num_train_items = 0

    def get_metadata(self):
        metadata = super().get_metadata()
        if self.image_directory is not None:
            metadata["image_directory"] = os.path.basename(self.image_directory)
        if self.image_jsonl_file is not None:
            metadata["image_jsonl_file"] = os.path.basename(self.image_jsonl_file)
        return metadata

    def get_total_image_count(self):
        # None when the datasource cannot report a count up front.
        return len(self.datasource) if self.datasource.is_indexable() else None

    def retrieve_latent_cache_batches(self, num_workers: int):
        """Yield (bucket_reso, [ItemInfo with resized image content]) batches.

        Images are loaded and resized concurrently in a thread pool; items are
        grouped by the bucket resolution chosen for each image and emitted as
        soon as a bucket accumulates a full batch (remainders are flushed at
        the end).
        """
        buckset_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)
        executor = ThreadPoolExecutor(max_workers=num_workers)

        # Items accumulated per bucket resolution until a full batch forms.
        batches: dict[tuple[int, int], list[ItemInfo]] = {}
        futures = []

        def aggregate_future(consume_all: bool = False):
            # Drain completed futures; poll while the pool is saturated, or
            # until everything finishes when consume_all is set.
            while len(futures) >= num_workers or (consume_all and len(futures) > 0):
                completed_futures = [future for future in futures if future.done()]
                if len(completed_futures) == 0:
                    if len(futures) >= num_workers or consume_all:
                        time.sleep(0.1)
                        continue
                    else:
                        break

                for future in completed_futures:
                    original_size, item_key, image, caption = future.result()
                    # The resized array's shape IS the bucket resolution.
                    bucket_height, bucket_width = image.shape[:2]
                    bucket_reso = (bucket_width, bucket_height)

                    item_info = ItemInfo(item_key, caption, original_size, bucket_reso, content=image)
                    item_info.latent_cache_path = self.get_latent_cache_path(item_info)

                    if bucket_reso not in batches:
                        batches[bucket_reso] = []
                    batches[bucket_reso].append(item_info)

                    futures.remove(future)

        def submit_batch(flush: bool = False):
            # Return one ready batch (any non-empty bucket when flushing),
            # or (None, None) when nothing is ready. Deleting a key and
            # returning immediately avoids mutating `batches` while the
            # for-loop is still iterating it.
            for key in batches:
                if len(batches[key]) >= self.batch_size or flush:
                    batch = batches[key][0 : self.batch_size]
                    if len(batches[key]) > self.batch_size:
                        batches[key] = batches[key][self.batch_size :]
                    else:
                        del batches[key]
                    return key, batch
            return None, None

        for fetch_op in self.datasource:

            def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, Image.Image, str]:
                # Runs in a worker thread: load the image and fit it to its bucket.
                image_key, image, caption = op()
                image: Image.Image
                image_size = image.size

                bucket_reso = buckset_selector.get_bucket_resolution(image_size)
                image = resize_image_to_bucket(image, bucket_reso)
                return image_size, image_key, image, caption

            future = executor.submit(fetch_and_resize, fetch_op)
            futures.append(future)
            aggregate_future()
            while True:
                key, batch = submit_batch()
                if key is None:
                    break
                yield key, batch

        # All fetches submitted: wait for the stragglers and flush partials.
        aggregate_future(consume_all=True)
        while True:
            key, batch = submit_batch(flush=True)
            if key is None:
                break
            yield key, batch

        executor.shutdown()

    def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
        return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)

    def prepare_for_training(self):
        """Scan the cache directory and build the bucket/batch index for training."""
        bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)

        # Latent cache files follow {basename}_{W:04d}x{H:04d}_{arch}.safetensors.
        latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors"))

        # Pair each latent cache with its text-encoder cache and bucket it.
        bucketed_item_info: dict[tuple[int, int], list[ItemInfo]] = {}
        for cache_file in latent_cache_files:
            tokens = os.path.basename(cache_file).split("_")

            # Second-to-last token is "{width}x{height}" (last is the arch tag);
            # joining the leading tokens recovers basenames with underscores.
            image_size = tokens[-2]
            image_width, image_height = map(int, image_size.split("x"))
            image_size = (image_width, image_height)

            item_key = "_".join(tokens[:-2])
            text_encoder_output_cache_file = os.path.join(
                self.cache_directory, f"{item_key}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors"
            )
            # Items without a text-encoder cache are skipped (with a warning).
            if not os.path.exists(text_encoder_output_cache_file):
                logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
                continue

            bucket_reso = bucket_selector.get_bucket_resolution(image_size)
            item_info = ItemInfo(item_key, "", image_size, bucket_reso, latent_cache_path=cache_file)
            item_info.text_encoder_output_cache_path = text_encoder_output_cache_file

            bucket = bucketed_item_info.get(bucket_reso, [])
            bucket.append(item_info)
            bucketed_item_info[bucket_reso] = bucket

        # Build the batch manager over the bucketed items.
        self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
        self.batch_manager.show_bucket_info()

        self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])

    def shuffle_buckets(self):
        # Deterministic per-epoch shuffle.
        # NOTE(review): assumes set_seed() was called beforehand — self.seed
        # is None otherwise and the addition raises TypeError; confirm callers.
        random.seed(self.seed + self.current_epoch)
        self.batch_manager.shuffle()

    def __len__(self):
        if self.batch_manager is None:
            # Placeholder length before prepare_for_training() builds the index.
            return 100
        return len(self.batch_manager)

    def __getitem__(self, idx):
        return self.batch_manager[idx]
|
|
|
class VideoDataset(BaseDataset):
    """Dataset of videos cropped into fixed-frame-count clips, bucketed by resolution.

    Videos come either from a directory or a JSONL manifest. During cache
    generation each video is decoded, resized to a bucket resolution, and cut
    into clips according to ``frame_extraction``:

    - ``"head"``: one clip per target frame count, from the start of the video.
    - ``"chunk"``: consecutive non-overlapping clips.
    - ``"slide"``: overlapping clips advanced by ``frame_stride``.
    - ``"uniform"``: ``frame_sample`` clips whose start positions are spread
      evenly across the video.

    During training, previously written latent/text-encoder cache files are
    discovered and grouped into (width, height, frame_count) buckets.
    """

    def __init__(
        self,
        resolution: Tuple[int, int],
        caption_extension: Optional[str],
        batch_size: int,
        enable_bucket: bool,
        bucket_no_upscale: bool,
        frame_extraction: Optional[str] = "head",
        frame_stride: Optional[int] = 1,
        frame_sample: Optional[int] = 1,
        target_frames: Optional[list[int]] = None,
        video_directory: Optional[str] = None,
        video_jsonl_file: Optional[str] = None,
        cache_directory: Optional[str] = None,
        debug_dataset: bool = False,
    ):
        super(VideoDataset, self).__init__(
            resolution, caption_extension, batch_size, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset
        )
        self.video_directory = video_directory
        self.video_jsonl_file = video_jsonl_file
        self.target_frames = target_frames
        self.frame_extraction = frame_extraction
        self.frame_stride = frame_stride
        self.frame_sample = frame_sample

        if video_directory is not None:
            self.datasource = VideoDirectoryDatasource(video_directory, caption_extension)
        elif video_jsonl_file is not None:
            self.datasource = VideoJsonlDatasource(video_jsonl_file)
        else:
            # Fail fast: previously self.datasource was silently left unset here,
            # which surfaced much later as an obscure AttributeError.
            raise ValueError("either video_directory or video_jsonl_file must be specified")

        if self.frame_extraction == "uniform" and self.frame_sample == 1:
            # uniform sampling with a single sample degenerates to taking the head
            self.frame_extraction = "head"
            logger.warning("frame_sample is set to 1 for frame_extraction=uniform. frame_extraction is changed to head.")
        if self.frame_extraction == "head":
            if self.target_frames is None:
                # max() below would raise a bare TypeError; give a clear error instead
                raise ValueError("target_frames must be specified when frame_extraction is 'head'")
            # "head" only ever reads the first max(target_frames) frames, so limit decoding
            self.datasource.set_start_and_end_frame(0, max(self.target_frames))

        if self.cache_directory is None:
            # default the cache next to the videos when no explicit cache dir is given
            self.cache_directory = self.video_directory

        self.batch_manager = None
        self.num_train_items = 0

    def get_metadata(self):
        """Return dataset metadata (base metadata plus video source and frame-extraction settings)."""
        metadata = super().get_metadata()
        if self.video_directory is not None:
            metadata["video_directory"] = os.path.basename(self.video_directory)
        if self.video_jsonl_file is not None:
            metadata["video_jsonl_file"] = os.path.basename(self.video_jsonl_file)
        metadata["frame_extraction"] = self.frame_extraction
        metadata["frame_stride"] = self.frame_stride
        metadata["frame_sample"] = self.frame_sample
        metadata["target_frames"] = self.target_frames
        return metadata

    def retrieve_latent_cache_batches(self, num_workers: int):
        """Decode videos, crop them into clips, and yield batches for latent caching.

        Videos are fetched and resized in a thread pool; completed futures are
        folded into per-(width, height, frame_count) buckets, and batches are
        yielded as soon as they are full, with a final flush at the end.

        Args:
            num_workers: number of decode/resize worker threads.

        Yields:
            ``((width, height, frame_count), list[ItemInfo])`` tuples with at most
            ``batch_size`` items each.
        """
        bucket_selector = BucketSelector(self.resolution)
        self.datasource.set_bucket_selector(bucket_selector)

        executor = ThreadPoolExecutor(max_workers=num_workers)

        # pending items, keyed by (width, height, frame_count)
        batches: dict[tuple[int, int, int], list[ItemInfo]] = {}
        futures = []

        def aggregate_future(consume_all: bool = False):
            """Drain completed futures into ``batches``; block while the pool is saturated."""
            while len(futures) >= num_workers or (consume_all and len(futures) > 0):
                completed_futures = [future for future in futures if future.done()]
                if len(completed_futures) == 0:
                    if len(futures) >= num_workers or consume_all:  # keep waiting while saturated or draining
                        time.sleep(0.1)
                        continue
                    else:
                        break

                for future in completed_futures:
                    original_frame_size, video_key, video, caption = future.result()

                    frame_count = len(video)
                    video = np.stack(video, axis=0)
                    height, width = video.shape[1:3]
                    bucket_reso = (width, height)  # frames were already resized to a bucket resolution

                    # (start_frame, clip_length) pairs to cut from this video
                    crop_pos_and_frames = []
                    if self.frame_extraction == "head":
                        for target_frame in self.target_frames:
                            if frame_count >= target_frame:
                                crop_pos_and_frames.append((0, target_frame))
                    elif self.frame_extraction == "chunk":
                        # consecutive non-overlapping clips from the start
                        for target_frame in self.target_frames:
                            for i in range(0, frame_count, target_frame):
                                if i + target_frame <= frame_count:
                                    crop_pos_and_frames.append((i, target_frame))
                    elif self.frame_extraction == "slide":
                        # overlapping clips, start advanced by frame_stride
                        for target_frame in self.target_frames:
                            if frame_count >= target_frame:
                                for i in range(0, frame_count - target_frame + 1, self.frame_stride):
                                    crop_pos_and_frames.append((i, target_frame))
                    elif self.frame_extraction == "uniform":
                        # frame_sample clips with start positions spread evenly over the video
                        for target_frame in self.target_frames:
                            if frame_count >= target_frame:
                                frame_indices = np.linspace(0, frame_count - target_frame, self.frame_sample, dtype=int)
                                for i in frame_indices:
                                    crop_pos_and_frames.append((i, target_frame))
                    else:
                        raise ValueError(f"frame_extraction {self.frame_extraction} is not supported")

                    for crop_pos, target_frame in crop_pos_and_frames:
                        cropped_video = video[crop_pos : crop_pos + target_frame]
                        body, ext = os.path.splitext(video_key)
                        # encode crop position and clip length into the item key so each clip
                        # gets a distinct cache file
                        item_key = f"{body}_{crop_pos:05d}-{target_frame:03d}{ext}"
                        batch_key = (*bucket_reso, target_frame)  # bucket includes the frame count

                        item_info = ItemInfo(
                            item_key, caption, original_frame_size, batch_key, frame_count=target_frame, content=cropped_video
                        )
                        item_info.latent_cache_path = self.get_latent_cache_path(item_info)

                        batch = batches.get(batch_key, [])
                        batch.append(item_info)
                        batches[batch_key] = batch

                    futures.remove(future)

        def submit_batch(flush: bool = False):
            """Pop one full batch (or, when flushing, any non-empty batch) from ``batches``."""
            for key in batches:
                if len(batches[key]) >= self.batch_size or flush:
                    batch = batches[key][0 : self.batch_size]
                    if len(batches[key]) > self.batch_size:
                        batches[key] = batches[key][self.batch_size :]
                    else:
                        del batches[key]  # safe: we return immediately, so iteration stops here
                    return key, batch
            return None, None

        try:
            for operator in self.datasource:

                def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, list[np.ndarray], str]:
                    """Load one video and resize every frame to its bucket resolution."""
                    video_key, video, caption = op()
                    video: list[np.ndarray]
                    frame_size = (video[0].shape[1], video[0].shape[0])

                    # resize all frames to the bucket resolution chosen for the original size
                    bucket_reso = bucket_selector.get_bucket_resolution(frame_size)
                    video = [resize_image_to_bucket(frame, bucket_reso) for frame in video]

                    return frame_size, video_key, video, caption

                future = executor.submit(fetch_and_resize, operator)
                futures.append(future)
                aggregate_future()
                while True:
                    key, batch = submit_batch()
                    if key is None:
                        break
                    yield key, batch

            aggregate_future(consume_all=True)
            while True:
                key, batch = submit_batch(flush=True)
                if key is None:
                    break
                yield key, batch
        finally:
            # ensure worker threads are torn down even if the consumer abandons the
            # generator or an exception escapes (previously shutdown could be skipped)
            executor.shutdown()

    def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
        """Yield batches of items that need text-encoder output caching (shared base implementation)."""
        return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)

    def prepare_for_training(self):
        """Discover existing latent/text-encoder cache files and build the bucket batch manager.

        Cache file names are expected to look like
        ``{item_key}_{frame_pos:05d}-{frame_count:03d}_{width}x{height}_{arch}.safetensors``;
        items whose text-encoder cache is missing are skipped with a warning.
        """
        bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)

        # glob cache files
        latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors"))

        # assign cache files to item info
        bucketed_item_info: dict[tuple[int, int, int], list[ItemInfo]] = {}
        for cache_file in latent_cache_files:
            tokens = os.path.basename(cache_file).split("_")

            # second-to-last token is "{width}x{height}"
            image_size = tokens[-2]
            image_width, image_height = map(int, image_size.split("x"))
            image_size = (image_width, image_height)

            # third-to-last token is "{frame_pos}-{frame_count}"
            frame_pos, frame_count = tokens[-3].split("-")
            frame_pos, frame_count = int(frame_pos), int(frame_count)

            item_key = "_".join(tokens[:-3])
            text_encoder_output_cache_file = os.path.join(
                self.cache_directory, f"{item_key}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors"
            )
            if not os.path.exists(text_encoder_output_cache_file):
                logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
                continue

            bucket_reso = bucket_selector.get_bucket_resolution(image_size)
            bucket_reso = (*bucket_reso, frame_count)  # video buckets also key on frame count
            item_info = ItemInfo(item_key, "", image_size, bucket_reso, frame_count=frame_count, latent_cache_path=cache_file)
            item_info.text_encoder_output_cache_path = text_encoder_output_cache_file

            bucket = bucketed_item_info.get(bucket_reso, [])
            bucket.append(item_info)
            bucketed_item_info[bucket_reso] = bucket

        # prepare batch manager
        self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
        self.batch_manager.show_bucket_info()

        self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])

    def shuffle_buckets(self):
        """Deterministically reshuffle all buckets for the current epoch (seeded by seed + epoch)."""
        random.seed(self.seed + self.current_epoch)
        self.batch_manager.shuffle()

    def __len__(self):
        """Return the number of batches; 100 is a placeholder before prepare_for_training runs."""
        if self.batch_manager is None:
            return 100
        return len(self.batch_manager)

    def __getitem__(self, idx):
        """Return the pre-collated batch at index ``idx``."""
        return self.batch_manager[idx]
|
|
|
|
class DatasetGroup(torch.utils.data.ConcatDataset):
    """Concatenation of image/video datasets that also broadcasts training state to every member."""

    def __init__(self, datasets: Sequence[Union[ImageDataset, VideoDataset]]):
        super().__init__(datasets)
        # Materialize as a list so the annotation holds even when a tuple/other
        # Sequence is passed (ConcatDataset itself stores list(datasets); the old
        # code shadowed that with the raw Sequence).
        self.datasets: list[Union[ImageDataset, VideoDataset]] = list(datasets)
        # total number of training items across all member datasets
        self.num_train_items = sum(dataset.num_train_items for dataset in self.datasets)

    def set_current_epoch(self, epoch):
        """Propagate the epoch number (used for per-epoch shuffle seeding) to every dataset."""
        for dataset in self.datasets:
            dataset.set_current_epoch(epoch)

    def set_current_step(self, step):
        """Propagate the global training step to every dataset."""
        for dataset in self.datasets:
            dataset.set_current_step(step)

    def set_max_train_steps(self, max_train_steps):
        """Propagate the total number of training steps to every dataset."""
        for dataset in self.datasets:
            dataset.set_max_train_steps(max_train_steps)
|
|