| from concurrent.futures import ThreadPoolExecutor |
| import glob |
| import json |
| import math |
| import os |
| import random |
| import time |
| from typing import Optional, Sequence, Tuple, Union |
|
|
| import numpy as np |
| import torch |
| from safetensors.torch import save_file, load_file |
| from safetensors import safe_open |
| from PIL import Image |
| import cv2 |
| import av |
|
|
| from utils import safetensors_utils |
| from utils.model_utils import dtype_to_str |
|
|
| import logging |
|
|
| logger = logging.getLogger(__name__) |
| logging.basicConfig(level=logging.INFO) |
|
|
|
|
# Supported still-image extensions; upper-case duplicates are listed because
# glob matching is case-sensitive on most filesystems.
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]

# Optional codec plugins below register extra formats with PIL when importable.
# Catch ImportError only: a bare `except` would also swallow KeyboardInterrupt/
# SystemExit and hide unrelated bugs inside the plugin import.
try:
    import pillow_avif

    IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
except ImportError:
    pass

# JPEG XL support via jxlpy ...
try:
    from jxlpy import JXLImagePlugin

    IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except ImportError:
    pass

# ... or via pillow_jxl. If both are importable, ".jxl"/".JXL" are appended
# twice, which is harmless because glob results are de-duplicated downstream.
try:
    import pillow_jxl

    IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except ImportError:
    pass
|
|
# Supported video container extensions; upper-case duplicates cover
# case-sensitive filesystems when globbing.
VIDEO_EXTENSIONS = [
    ".mp4",
    ".webm",
    ".avi",
    ".mkv",
    ".mov",
    ".flv",
    ".wmv",
    ".m4v",
    ".mpg",
    ".mpeg",
    ".MP4",
    ".WEBM",
    ".AVI",
    ".MKV",
    ".MOV",
    ".FLV",
    ".WMV",
    ".M4V",
    ".MPG",
    ".MPEG",
]

# Architecture identifiers: the short form selects per-architecture settings
# (e.g. BucketSelector resolution steps), the full form is written into the
# "architecture" metadata field of cache files.
ARCHITECTURE_HUNYUAN_VIDEO = "hv"
ARCHITECTURE_HUNYUAN_VIDEO_FULL = "hunyuan_video"
ARCHITECTURE_WAN = "wan"
ARCHITECTURE_WAN_FULL = "wan"  # short and full names coincide for Wan
ARCHITECTURE_FRAMEPACK = "fp"
ARCHITECTURE_FRAMEPACK_FULL = "framepack"
|
|
|
|
def glob_images(directory, base="*"):
    """Collect image files in `directory` matching `base` plus any supported
    image extension. Returns a sorted, de-duplicated list of paths."""
    found = set()
    for ext in IMAGE_EXTENSIONS:
        if base == "*":
            # wildcard base: escape only the directory so "*" still globs
            pattern = os.path.join(glob.escape(directory), base + ext)
        else:
            # literal base name: escape the whole path
            pattern = glob.escape(os.path.join(directory, base + ext))
        found.update(glob.glob(pattern))
    return sorted(found)
|
|
|
|
def glob_videos(directory, base="*"):
    """Collect video files in `directory` matching `base` plus any supported
    video extension. Returns a sorted, de-duplicated list of paths."""
    found = set()
    for ext in VIDEO_EXTENSIONS:
        if base == "*":
            # wildcard base: escape only the directory so "*" still globs
            pattern = os.path.join(glob.escape(directory), base + ext)
        else:
            # literal base name: escape the whole path
            pattern = glob.escape(os.path.join(directory, base + ext))
        found.update(glob.glob(pattern))
    return sorted(found)
|
|
|
|
def divisible_by(num: int, divisor: int) -> int:
    """Round `num` down to the nearest multiple of `divisor`."""
    return (num // divisor) * divisor
|
|
|
|
def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:
    """
    Resize the image to the bucket resolution.

    bucket_reso: **(width, height)**

    The image is scaled with the larger of the two scale factors (aspect-fill)
    and then center-cropped to exactly the bucket size. Always returns an
    np.ndarray, even for PIL input.
    """
    is_pil_image = isinstance(image, Image.Image)
    if is_pil_image:
        image_width, image_height = image.size
    else:
        # np.ndarray layout is (height, width[, channels])
        image_height, image_width = image.shape[:2]

    # exact match: nothing to do
    if bucket_reso == (image_width, image_height):
        return np.array(image) if is_pil_image else image

    bucket_width, bucket_height = bucket_reso
    if bucket_width == image_width or bucket_height == image_height:
        # one side already matches: skip resizing, rely on the crop below.
        # NOTE(review): if the *other* side is smaller than the bucket, the
        # crop offsets go negative and the result ends up smaller than the
        # bucket -- assumes callers only pass buckets that fit; confirm.
        image = np.array(image) if is_pil_image else image
    else:
        # aspect-fill: scale by the larger factor so both sides cover the bucket
        scale_width = bucket_width / image_width
        scale_height = bucket_height / image_height
        scale = max(scale_width, scale_height)
        image_width = int(image_width * scale + 0.5)  # round half up
        image_height = int(image_height * scale + 0.5)

        if scale > 1:
            # upscaling: PIL LANCZOS for quality
            image = Image.fromarray(image) if not is_pil_image else image
            image = image.resize((image_width, image_height), Image.LANCZOS)
            image = np.array(image)
        else:
            # downscaling: cv2 INTER_AREA
            image = np.array(image) if is_pil_image else image
            image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)

    # center-crop to exactly (bucket_height, bucket_width)
    crop_left = (image_width - bucket_width) // 2
    crop_top = (image_height - bucket_height) // 2
    image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
    return image
|
|
|
|
class ItemInfo:
    """Metadata (and optionally loaded content) for one dataset item.

    Two attributes are filled in later by the caching pipeline rather than the
    constructor:
      - text_encoder_output_cache_path: path of the text-encoder cache file
      - control_content: optional control (conditioning) pixel array
    """

    def __init__(
        self,
        item_key: str,
        caption: str,
        original_size: tuple[int, int],
        bucket_size: Optional[Union[tuple[int, int], tuple[int, int, int]]] = None,
        frame_count: Optional[int] = None,
        content: Optional[np.ndarray] = None,
        latent_cache_path: Optional[str] = None,
    ) -> None:
        self.item_key = item_key
        self.caption = caption
        self.original_size = original_size
        self.bucket_size = bucket_size
        self.frame_count = frame_count
        self.content = content
        self.latent_cache_path = latent_cache_path
        self.text_encoder_output_cache_path: Optional[str] = None
        self.control_content: Optional[np.ndarray] = None

    def __str__(self) -> str:
        # report the content's shape rather than dumping the whole array
        content_shape = self.content.shape if self.content is not None else None
        return (
            f"ItemInfo(item_key={self.item_key}, caption={self.caption}, "
            f"original_size={self.original_size}, bucket_size={self.bucket_size}, "
            f"frame_count={self.frame_count}, latent_cache_path={self.latent_cache_path}, content={content_shape})"
        )
|
|
|
|
| |
| |
|
|
| |
| |
|
|
|
|
def save_latent_cache(item_info: ItemInfo, latent: torch.Tensor):
    """HunyuanVideo architecture only. HunyuanVideo doesn't support I2V and control latents"""
    assert latent.dim() == 4, "latent should be 4D tensor (frame, channel, height, width)"

    # NOTE(review): the first dimension is discarded and F/H/W come from dims
    # 1-3, which matches a (channel, frame, height, width) layout rather than
    # the (frame, channel, ...) order the assert message suggests -- confirm.
    _, F, H, W = latent.shape
    dtype_str = dtype_to_str(latent.dtype)
    # the key encodes shape and dtype so the loader can recover them
    sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()}

    save_latent_cache_common(item_info, sd, ARCHITECTURE_HUNYUAN_VIDEO_FULL)
|
|
|
|
def save_latent_cache_wan(
    item_info: ItemInfo,
    latent: torch.Tensor,
    clip_embed: Optional[torch.Tensor],
    image_latent: Optional[torch.Tensor],
    control_latent: Optional[torch.Tensor],
):
    """Wan architecture only

    Saves the main latent plus any of the optional conditioning tensors that
    are provided (each is skipped when None). Keys encode shape and dtype so
    the loader can recover them.
    """
    assert latent.dim() == 4, "latent should be 4D tensor (frame, channel, height, width)"

    # NOTE(review): dims 1-3 are used as F/H/W while dim 0 is discarded; see
    # save_latent_cache for the same layout question -- confirm.
    _, F, H, W = latent.shape
    dtype_str = dtype_to_str(latent.dtype)
    sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()}

    # optional CLIP embedding
    if clip_embed is not None:
        sd[f"clip_{dtype_str}"] = clip_embed.detach().cpu()

    # optional image-conditioning latent (same spatial key as the main latent)
    if image_latent is not None:
        sd[f"latents_image_{F}x{H}x{W}_{dtype_str}"] = image_latent.detach().cpu()

    # optional control-conditioning latent
    if control_latent is not None:
        sd[f"latents_control_{F}x{H}x{W}_{dtype_str}"] = control_latent.detach().cpu()

    save_latent_cache_common(item_info, sd, ARCHITECTURE_WAN_FULL)
|
|
|
|
def save_latent_cache_framepack(
    item_info: ItemInfo,
    latent: torch.Tensor,
    latent_indices: torch.Tensor,
    clean_latents: torch.Tensor,
    clean_latent_indices: torch.Tensor,
    clean_latents_2x: torch.Tensor,
    clean_latent_2x_indices: torch.Tensor,
    clean_latents_4x: torch.Tensor,
    clean_latent_4x_indices: torch.Tensor,
    image_embeddings: torch.Tensor,
):
    """FramePack architecture only

    Saves the main latent together with the clean-latent pyramids (1x/2x/4x),
    their index tensors, and the image embeddings, all moved to CPU. Latent
    tensors are stored contiguous for serialization.
    """
    assert latent.dim() == 4, "latent should be 4D tensor (frame, channel, height, width)"

    # NOTE(review): dims 1-3 are used as F/H/W while dim 0 is discarded; see
    # save_latent_cache for the same layout question -- confirm.
    _, F, H, W = latent.shape
    dtype_str = dtype_to_str(latent.dtype)
    sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu().contiguous()}

    # index tensors carry their own dtype suffix (integer dtype, not latent dtype)
    indices_dtype_str = dtype_to_str(latent_indices.dtype)
    sd[f"image_embeddings_{dtype_str}"] = image_embeddings.detach().cpu()
    sd[f"latent_indices_{indices_dtype_str}"] = latent_indices.detach().cpu()
    sd[f"clean_latent_indices_{indices_dtype_str}"] = clean_latent_indices.detach().cpu()
    sd[f"clean_latent_2x_indices_{indices_dtype_str}"] = clean_latent_2x_indices.detach().cpu()
    sd[f"clean_latent_4x_indices_{indices_dtype_str}"] = clean_latent_4x_indices.detach().cpu()
    # NOTE(review): the clean-latent keys reuse the main latent's FxHxW even
    # though the 2x/4x tensors presumably have different shapes -- confirm the
    # loader does not rely on these suffixes for the clean latents.
    sd[f"latents_clean_{F}x{H}x{W}_{dtype_str}"] = clean_latents.detach().cpu().contiguous()
    sd[f"latents_clean_2x_{F}x{H}x{W}_{dtype_str}"] = clean_latents_2x.detach().cpu().contiguous()
    sd[f"latents_clean_4x_{F}x{H}x{W}_{dtype_str}"] = clean_latents_4x.detach().cpu().contiguous()

    save_latent_cache_common(item_info, sd, ARCHITECTURE_FRAMEPACK_FULL)
|
|
|
|
def save_latent_cache_common(item_info: ItemInfo, sd: dict[str, torch.Tensor], arch_fullname: str):
    """Write a latent cache file; shared by all architectures.

    Args:
        item_info: item whose `latent_cache_path` determines the output file;
            its original size and frame count are recorded in the metadata.
        sd: tensor dict to serialize; tensors are sanitized in place (NaN -> 0).
        arch_fullname: full architecture name stored in the metadata.
    """
    metadata = {
        "architecture": arch_fullname,
        "width": f"{item_info.original_size[0]}",
        "height": f"{item_info.original_size[1]}",
        "format_version": "1.0.1",
    }
    if item_info.frame_count is not None:
        metadata["frame_count"] = f"{item_info.frame_count}"

    for key, value in sd.items():
        # Replace NaNs so a corrupted tensor never reaches training.
        # Compute the mask once instead of calling torch.isnan twice.
        nan_mask = torch.isnan(value)
        if nan_mask.any():
            logger.warning(f"{key} tensor has NaN: {item_info.item_key}, replace NaN with 0")
            value[nan_mask] = 0

    latent_dir = os.path.dirname(item_info.latent_cache_path)
    os.makedirs(latent_dir, exist_ok=True)

    save_file(sd, item_info.latent_cache_path, metadata=metadata)
|
|
|
|
def save_text_encoder_output_cache(item_info: ItemInfo, embed: torch.Tensor, mask: Optional[torch.Tensor], is_llm: bool):
    """HunyuanVideo architecture only"""
    assert embed.dim() in (1, 2), f"embed should be 2D tensor (feature, hidden_size) or (hidden_size,), got {embed.shape}"
    assert mask is None or mask.dim() == 1, f"mask should be 1D tensor (feature), got {mask.shape}"

    # key prefix distinguishes the two HunyuanVideo text encoders
    encoder_name = "llm" if is_llm else "clipL"
    sd = {f"{encoder_name}_{dtype_to_str(embed.dtype)}": embed.detach().cpu()}
    if mask is not None:
        sd[f"{encoder_name}_mask"] = mask.detach().cpu()

    save_text_encoder_output_cache_common(item_info, sd, ARCHITECTURE_HUNYUAN_VIDEO_FULL)
|
|
|
|
def save_text_encoder_output_cache_wan(item_info: ItemInfo, embed: torch.Tensor):
    """Wan architecture only. Wan2.1 only has a single text encoder"""
    # "varlen_" marks a tensor whose leading dimension varies per item, so the
    # batch loader keeps it as a list instead of stacking it (see
    # BucketBatchManager.__getitem__).
    key = f"varlen_t5_{dtype_to_str(embed.dtype)}"
    sd = {key: embed.detach().cpu()}

    save_text_encoder_output_cache_common(item_info, sd, ARCHITECTURE_WAN_FULL)
|
|
|
|
def save_text_encoder_output_cache_framepack(
    item_info: ItemInfo, llama_vec: torch.Tensor, llama_attention_mask: torch.Tensor, clip_l_pooler: torch.Tensor
):
    """FramePack architecture only.

    Saves the LLaMA text embedding with its attention mask and the CLIP-L
    pooled embedding; the two embeddings carry independent dtype suffixes.
    """
    sd = {}
    llama_dtype_str = dtype_to_str(llama_vec.dtype)
    sd[f"llama_vec_{llama_dtype_str}"] = llama_vec.detach().cpu()
    # plain string: the key has no placeholder, so no f-string is needed
    sd["llama_attention_mask"] = llama_attention_mask.detach().cpu()
    pooler_dtype_str = dtype_to_str(clip_l_pooler.dtype)
    sd[f"clip_l_pooler_{pooler_dtype_str}"] = clip_l_pooler.detach().cpu()

    save_text_encoder_output_cache_common(item_info, sd, ARCHITECTURE_FRAMEPACK_FULL)
|
|
|
|
def save_text_encoder_output_cache_common(item_info: ItemInfo, sd: dict[str, torch.Tensor], arch_fullname: str):
    """Write or merge a text-encoder output cache file; shared by all architectures.

    If a cache file already exists, tensors not present in `sd` are carried
    over from it, so multiple encoders can write to the same file
    incrementally. Tensors in `sd` are sanitized in place (NaN -> 0).
    """
    for key, value in sd.items():
        # Replace NaNs; compute the mask once instead of calling torch.isnan twice.
        nan_mask = torch.isnan(value)
        if nan_mask.any():
            logger.warning(f"{key} tensor has NaN: {item_info.item_key}, replace NaN with 0")
            value[nan_mask] = 0

    metadata = {
        "architecture": arch_fullname,
        "caption1": item_info.caption,
        "format_version": "1.0.1",
    }

    if os.path.exists(item_info.text_encoder_output_cache_path):
        # merge with the existing cache: keep tensors we are not overwriting
        with safetensors_utils.MemoryEfficientSafeOpen(item_info.text_encoder_output_cache_path) as f:
            existing_metadata = f.metadata()
            for key in f.keys():
                if key not in sd:
                    sd[key] = f.get_tensor(key)

        assert existing_metadata["architecture"] == metadata["architecture"], "architecture mismatch"
        if existing_metadata["caption1"] != metadata["caption1"]:
            logger.warning(f"caption mismatch: existing={existing_metadata['caption1']}, new={metadata['caption1']}, overwrite")

        # new caption/format win; any other pre-existing metadata keys are kept
        existing_metadata.pop("caption1", None)
        existing_metadata.pop("format_version", None)
        metadata.update(existing_metadata)
    else:
        text_encoder_output_dir = os.path.dirname(item_info.text_encoder_output_cache_path)
        os.makedirs(text_encoder_output_dir, exist_ok=True)

    safetensors_utils.mem_eff_save_file(sd, item_info.text_encoder_output_cache_path, metadata=metadata)
|
|
|
|
class BucketSelector:
    """Enumerates bucket resolutions of (approximately) a target area and
    assigns an image size to the bucket with the closest aspect ratio."""

    # resolution granularity in pixels, per architecture (all currently 16)
    RESOLUTION_STEPS_HUNYUAN = 16
    RESOLUTION_STEPS_WAN = 16
    RESOLUTION_STEPS_FRAMEPACK = 16

    def __init__(
        self, resolution: Tuple[int, int], enable_bucket: bool = True, no_upscale: bool = False, architecture: str = "no_default"
    ):
        """
        resolution: (width, height) target; its area defines the bucket area.
        enable_bucket: when False, only `resolution` itself is used.
        no_upscale: when True, images smaller than the bucket area keep their
            own (step-aligned) size instead of being assigned a larger bucket.
        architecture: one of the ARCHITECTURE_* short names; selects reso_steps.
        """
        self.resolution = resolution
        self.bucket_area = resolution[0] * resolution[1]
        self.architecture = architecture

        if self.architecture == ARCHITECTURE_HUNYUAN_VIDEO:
            self.reso_steps = BucketSelector.RESOLUTION_STEPS_HUNYUAN
        elif self.architecture == ARCHITECTURE_WAN:
            self.reso_steps = BucketSelector.RESOLUTION_STEPS_WAN
        elif self.architecture == ARCHITECTURE_FRAMEPACK:
            self.reso_steps = BucketSelector.RESOLUTION_STEPS_FRAMEPACK
        else:
            raise ValueError(f"Invalid architecture: {self.architecture}")

        if not enable_bucket:
            # bucketing disabled: everything maps to the single target resolution
            self.bucket_resolutions = [resolution]
            self.no_upscale = False
        else:
            # enumerate equal-area buckets: walk widths from ~sqrt(area)/2 up to
            # sqrt(area) in reso_steps increments, pair each with the
            # step-aligned height for that area, and add both orientations
            self.no_upscale = no_upscale
            sqrt_size = int(math.sqrt(self.bucket_area))
            min_size = divisible_by(sqrt_size // 2, self.reso_steps)
            self.bucket_resolutions = []
            for w in range(min_size, sqrt_size + self.reso_steps, self.reso_steps):
                h = divisible_by(self.bucket_area // w, self.reso_steps)
                self.bucket_resolutions.append((w, h))
                self.bucket_resolutions.append((h, w))

            self.bucket_resolutions = list(set(self.bucket_resolutions))
            self.bucket_resolutions.sort()

        # aspect ratio per bucket, index-aligned with self.bucket_resolutions;
        # get_bucket_resolution relies on this alignment (argmin -> list index)
        self.aspect_ratios = np.array([w / h for w, h in self.bucket_resolutions])

    def get_bucket_resolution(self, image_size: tuple[int, int]) -> tuple[int, int]:
        """
        return the bucket resolution for the given image size, (width, height)
        """
        area = image_size[0] * image_size[1]
        if self.no_upscale and area <= self.bucket_area:
            # small image: keep its own size, aligned down to reso_steps
            w, h = image_size
            w = divisible_by(w, self.reso_steps)
            h = divisible_by(h, self.reso_steps)
            return w, h

        # otherwise pick the bucket with the closest aspect ratio
        aspect_ratio = image_size[0] / image_size[1]
        ar_errors = self.aspect_ratios - aspect_ratio
        bucket_id = np.abs(ar_errors).argmin()
        return self.bucket_resolutions[bucket_id]
|
|
|
|
def load_video(
    video_path: str,
    start_frame: Optional[int] = None,
    end_frame: Optional[int] = None,
    bucket_selector: Optional[BucketSelector] = None,
    bucket_reso: Optional[tuple[int, int]] = None,
    source_fps: Optional[float] = None,
    target_fps: Optional[float] = None,
) -> list[np.ndarray]:
    """
    Load frames from a video file (decoded with PyAV) or from a directory of
    image files (sorted, one image per frame) as a list of RGB numpy arrays.

    bucket_reso: if given, resize the video to the bucket resolution, (width, height)

    If `bucket_selector` is given and `bucket_reso` is not, the bucket is
    chosen from the first frame's size and then reused for all later frames.
    `start_frame`/`end_frame` select the half-open range [start, end).
    When both `source_fps` and `target_fps` are given, frames are resampled:
    each source frame is mapped to a target index and a frame is kept only
    when that index advances (so for target < source, duplicates are dropped).
    """
    if source_fps is None or target_fps is None:
        # no fps conversion: take every frame with index in [start, end)
        if os.path.isfile(video_path):
            container = av.open(video_path)
            video = []
            for i, frame in enumerate(container.decode(video=0)):
                if start_frame is not None and i < start_frame:
                    continue
                if end_frame is not None and i >= end_frame:
                    break
                frame = frame.to_image()

                # choose the bucket once, from the first frame that reaches here
                if bucket_selector is not None and bucket_reso is None:
                    bucket_reso = bucket_selector.get_bucket_resolution(frame.size)

                if bucket_reso is not None:
                    frame = resize_image_to_bucket(frame, bucket_reso)
                else:
                    frame = np.array(frame)

                video.append(frame)
            container.close()
        else:
            # video_path is a directory: sorted image files act as frames
            image_files = glob_images(video_path)
            image_files.sort()
            video = []
            for i in range(len(image_files)):
                if start_frame is not None and i < start_frame:
                    continue
                if end_frame is not None and i >= end_frame:
                    break

                image_file = image_files[i]
                image = Image.open(image_file).convert("RGB")

                if bucket_selector is not None and bucket_reso is None:
                    bucket_reso = bucket_selector.get_bucket_resolution(image.size)
                image = np.array(image)
                if bucket_reso is not None:
                    image = resize_image_to_bucket(image, bucket_reso)

                video.append(image)
    else:
        # fps resampling: the fractional target index advances by
        # target_fps/source_fps per source frame
        frame_index_delta = target_fps / source_fps
        if os.path.isfile(video_path):
            container = av.open(video_path)
            video = []
            frame_index_with_fraction = 0.0
            previous_frame_index = -1
            for i, frame in enumerate(container.decode(video=0)):
                target_frame_index = int(frame_index_with_fraction)
                frame_index_with_fraction += frame_index_delta

                # same resampled index as the previously kept frame: skip it
                if target_frame_index == previous_frame_index:
                    continue

                previous_frame_index = target_frame_index

                # start/end are expressed in resampled (target) frame indices
                if start_frame is not None and target_frame_index < start_frame:
                    continue
                if end_frame is not None and target_frame_index >= end_frame:
                    break
                frame = frame.to_image()

                # choose the bucket once, from the first frame that reaches here
                if bucket_selector is not None and bucket_reso is None:
                    bucket_reso = bucket_selector.get_bucket_resolution(frame.size)

                if bucket_reso is not None:
                    frame = resize_image_to_bucket(frame, bucket_reso)
                else:
                    frame = np.array(frame)

                video.append(frame)
            container.close()
        else:
            # directory of images with resampling: same index logic as above
            image_files = glob_images(video_path)
            image_files.sort()
            video = []
            frame_index_with_fraction = 0.0
            previous_frame_index = -1
            for i in range(len(image_files)):
                target_frame_index = int(frame_index_with_fraction)
                frame_index_with_fraction += frame_index_delta

                if target_frame_index == previous_frame_index:
                    continue

                previous_frame_index = target_frame_index

                if start_frame is not None and target_frame_index < start_frame:
                    continue
                if end_frame is not None and target_frame_index >= end_frame:
                    break

                image_file = image_files[i]
                image = Image.open(image_file).convert("RGB")

                if bucket_selector is not None and bucket_reso is None:
                    bucket_reso = bucket_selector.get_bucket_resolution(image.size)
                image = np.array(image)
                if bucket_reso is not None:
                    image = resize_image_to_bucket(image, bucket_reso)

                video.append(image)

    return video
|
|
|
|
class BucketBatchManager:
    """Groups cached items by bucket resolution and serves them as batches.

    Batches never mix bucket resolutions; the last batch of each bucket may be
    smaller than batch_size.
    """

    def __init__(self, bucketed_item_info: dict[Union[tuple[int, int], tuple[int, int, int]], list[ItemInfo]], batch_size: int):
        self.batch_size = batch_size
        self.buckets = bucketed_item_info
        self.bucket_resos = list(self.buckets.keys())
        self.bucket_resos.sort()

        # flat batch index: one (bucket_reso, batch_index_within_bucket) entry
        # per batch, in sorted bucket order until shuffle() is called
        self.bucket_batch_indices: list[tuple[Union[tuple[int, int], tuple[int, int, int], int]]] = []
        for bucket_reso in self.bucket_resos:
            bucket = self.buckets[bucket_reso]
            num_batches = math.ceil(len(bucket) / self.batch_size)
            for i in range(num_batches):
                self.bucket_batch_indices.append((bucket_reso, i))

    def show_bucket_info(self):
        # log the item count per bucket and the total number of batches
        for bucket_reso in self.bucket_resos:
            bucket = self.buckets[bucket_reso]
            logger.info(f"bucket: {bucket_reso}, count: {len(bucket)}")

        logger.info(f"total batches: {len(self)}")

    def shuffle(self):
        # shuffle the items inside each bucket ...
        for bucket in self.buckets.values():
            random.shuffle(bucket)

        # ... and the order in which batches are served
        random.shuffle(self.bucket_batch_indices)

    def __len__(self):
        return len(self.bucket_batch_indices)

    def __getitem__(self, idx):
        """Load and collate one batch from the latent and text-encoder caches."""
        bucket_reso, batch_idx = self.bucket_batch_indices[idx]
        bucket = self.buckets[bucket_reso]
        start = batch_idx * self.batch_size
        end = min(start + self.batch_size, len(bucket))

        batch_tensor_data = {}
        varlen_keys = set()
        for item_info in bucket[start:end]:
            # merge the two cache files into a single tensor dict per item
            sd_latent = load_file(item_info.latent_cache_path)
            sd_te = load_file(item_info.text_encoder_output_cache_path)
            sd = {**sd_latent, **sd_te}

            # normalize cache keys to content keys:
            #   "varlen_<name>_<dtype>"   -> "<name>"        (kept as list, not stacked)
            #   "<name>_mask"             -> "<name>_mask"   (no dtype suffix to strip)
            #   "<name>_<dtype>"          -> "<name>"
            #   "latents_*_<FxHxW>_<dtype>" additionally drops the shape suffix
            for key in sd.keys():
                is_varlen_key = key.startswith("varlen_")
                content_key = key

                if is_varlen_key:
                    content_key = content_key.replace("varlen_", "")

                if content_key.endswith("_mask"):
                    pass
                else:
                    content_key = content_key.rsplit("_", 1)[0]  # strip dtype suffix
                    if content_key.startswith("latents_"):
                        content_key = content_key.rsplit("_", 1)[0]  # strip FxHxW suffix

                if content_key not in batch_tensor_data:
                    batch_tensor_data[content_key] = []
                batch_tensor_data[content_key].append(sd[key])

                if is_varlen_key:
                    varlen_keys.add(content_key)

        # stack fixed-size tensors into batch tensors; varlen ones stay lists
        for key in batch_tensor_data.keys():
            if key not in varlen_keys:
                batch_tensor_data[key] = torch.stack(batch_tensor_data[key])

        return batch_tensor_data
|
|
|
|
class ContentDatasource:
    """Abstract base class for caption/image/video data sources."""

    def __init__(self):
        # when True, iteration yields caption fetchers instead of content fetchers
        self.caption_only = False
        # subclasses set this to True when a control (conditioning) source exists
        self.has_control = False

    def set_caption_only(self, caption_only: bool):
        """Switch between caption-only and full-content iteration."""
        self.caption_only = caption_only

    def is_indexable(self):
        """Whether random access (get_caption etc. by index) is supported."""
        return False

    def get_caption(self, idx: int) -> tuple[str, str]:
        """
        Returns caption. May not be called if is_indexable() returns False.
        """
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __iter__(self):
        raise NotImplementedError

    def __next__(self):
        raise NotImplementedError
|
|
|
|
class ImageDatasource(ContentDatasource):
    """Base class for data sources that yield still images."""

    def __init__(self):
        super().__init__()

    def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
        """
        Return (image path, image, caption) for the given index.

        The path doubles as the item key, so it must be unique and valid as a
        file name. May not be called if is_indexable() returns False.
        """
        raise NotImplementedError
|
|
|
|
class ImageDirectoryDatasource(ImageDatasource):
    """Image datasource backed by a directory of image files, with captions in
    sidecar text files and optional control images in a parallel directory."""

    def __init__(self, image_directory: str, caption_extension: Optional[str] = None, control_directory: Optional[str] = None):
        """
        image_directory: directory globbed for images (IMAGE_EXTENSIONS).
        caption_extension: extension of sidecar caption files (e.g. ".txt").
        control_directory: optional directory of control images; every image
            must have a counterpart there (same basename, any supported image
            extension), otherwise ValueError is raised.
        """
        super().__init__()
        self.image_directory = image_directory
        self.caption_extension = caption_extension
        self.control_directory = control_directory
        self.current_idx = 0  # iterator cursor used by __next__

        # collect all images up front
        logger.info(f"glob images in {self.image_directory}")
        self.image_paths = glob_images(self.image_directory)
        logger.info(f"found {len(self.image_paths)} images")

        # match each image to its control image when a control dir is given
        if self.control_directory is not None:
            logger.info(f"glob control images in {self.control_directory}")
            self.has_control = True
            self.control_paths = {}
            for image_path in self.image_paths:
                image_basename = os.path.basename(image_path)
                # exact filename match first
                control_path = os.path.join(self.control_directory, image_basename)
                if os.path.exists(control_path):
                    self.control_paths[image_path] = control_path
                else:
                    # fall back: same basename with any supported image extension
                    image_basename_no_ext = os.path.splitext(image_basename)[0]
                    for ext in IMAGE_EXTENSIONS:
                        potential_path = os.path.join(self.control_directory, image_basename_no_ext + ext)
                        if os.path.exists(potential_path):
                            self.control_paths[image_path] = potential_path
                            break

            logger.info(f"found {len(self.control_paths)} matching control images")
            missing_controls = len(self.image_paths) - len(self.control_paths)
            if missing_controls > 0:
                # fail fast: training with partial control data is not supported
                missing_control_paths = set(self.image_paths) - set(self.control_paths.keys())
                logger.error(f"Could not find matching control images for {missing_controls} images: {missing_control_paths}")
                raise ValueError(f"Could not find matching control images for {missing_controls} images")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.image_paths)

    def get_image_data(self, idx: int) -> tuple[str, Image.Image, str, Optional[Image.Image]]:
        """Return (path, RGB image, caption, optional RGB control image)."""
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")

        _, caption = self.get_caption(idx)

        control = None
        if self.has_control:
            control_path = self.control_paths[image_path]
            control = Image.open(control_path).convert("RGB")

        return image_path, image, caption, control

    def get_caption(self, idx: int) -> tuple[str, str]:
        """Return (image path, caption) read from the sidecar caption file."""
        image_path = self.image_paths[idx]
        # NOTE(review): when caption_extension is None/empty, caption_path is
        # "" and open() below raises -- assumes caption_extension is always
        # set for this datasource; confirm with callers.
        caption_path = os.path.splitext(image_path)[0] + self.caption_extension if self.caption_extension else ""
        with open(caption_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
        return image_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self) -> callable:
        """
        Returns a fetcher function that returns image data.

        Zero-argument fetchers are yielded instead of the data itself,
        presumably so the caller can defer the file I/O (e.g. to a worker
        pool) -- confirm with callers.
        """
        if self.current_idx >= len(self.image_paths):
            raise StopIteration

        if self.caption_only:
            # helper binds the current index into the closure (avoids the
            # late-binding pitfall of a bare lambda over self.current_idx)
            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)
        else:

            def create_image_fetcher(index):
                return lambda: self.get_image_data(index)

            fetcher = create_image_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher
|
|
|
|
class ImageJsonlDatasource(ImageDatasource):
    """Image datasource backed by a JSONL file; each line is an object with
    "image_path", "caption", and optionally "control_path" keys."""

    def __init__(self, image_jsonl_file: str):
        super().__init__()
        self.image_jsonl_file = image_jsonl_file
        self.current_idx = 0  # iterator cursor used by __next__

        # load and parse all records up front
        logger.info(f"load image jsonl from {self.image_jsonl_file}")
        self.data = []
        with open(self.image_jsonl_file, "r", encoding="utf-8") as f:
            for line in f:
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    # re-raise with context so the bad line is identifiable
                    logger.error(f"failed to load json: {line} @ {self.image_jsonl_file}")
                    raise
                self.data.append(data)
        logger.info(f"loaded {len(self.data)} images")

        # control paths are all-or-nothing: partial coverage is an error
        self.has_control = any("control_path" in item for item in self.data)
        if self.has_control:
            control_count = sum(1 for item in self.data if "control_path" in item)
            if control_count < len(self.data):
                missing_control_images = [item["image_path"] for item in self.data if "control_path" not in item]
                logger.error(f"Some images do not have control paths in JSONL data: {missing_control_images}")
                raise ValueError(f"Some images do not have control paths in JSONL data: {missing_control_images}")
            logger.info(f"found {control_count} control images in JSONL data")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.data)

    def get_image_data(self, idx: int) -> tuple[str, Image.Image, str, Optional[Image.Image]]:
        """Return (path, RGB image, caption, optional RGB control image)."""
        data = self.data[idx]
        image_path = data["image_path"]
        image = Image.open(image_path).convert("RGB")

        caption = data["caption"]

        control = None
        if self.has_control:
            control_path = data["control_path"]
            control = Image.open(control_path).convert("RGB")

        return image_path, image, caption, control

    def get_caption(self, idx: int) -> tuple[str, str]:
        """Return (image path, caption) straight from the parsed record."""
        data = self.data[idx]
        image_path = data["image_path"]
        caption = data["caption"]
        return image_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self) -> callable:
        """Return a zero-argument fetcher; see ImageDirectoryDatasource.__next__."""
        if self.current_idx >= len(self.data):
            raise StopIteration

        if self.caption_only:
            # helper binds the current index into the closure (late-binding fix)
            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)

        else:

            def create_fetcher(index):
                return lambda: self.get_image_data(index)

            fetcher = create_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher
|
|
|
|
class VideoDatasource(ContentDatasource):
    """Base class for video data sources; holds the shared frame-range,
    bucket, and fps-resampling settings applied when loading videos."""

    def __init__(self):
        super().__init__()

        # frame range [start_frame, end_frame) applied to every video
        self.start_frame = None
        self.end_frame = None

        # optional BucketSelector used to resize frames on load
        self.bucket_selector = None

        # when both are set, load_video resamples from source_fps to target_fps
        self.source_fps = None
        self.target_fps = None

    def __len__(self):
        raise NotImplementedError

    def get_video_data_from_path(
        self,
        video_path: str,
        start_frame: Optional[int] = None,
        end_frame: Optional[int] = None,
        bucket_selector: Optional[BucketSelector] = None,
    ) -> list[np.ndarray]:
        """Load frames for one video; per-call arguments override the stored
        defaults. Returns the frame list produced by load_video."""
        start_frame = start_frame if start_frame is not None else self.start_frame
        end_frame = end_frame if end_frame is not None else self.end_frame
        bucket_selector = bucket_selector if bucket_selector is not None else self.bucket_selector

        video = load_video(
            video_path, start_frame, end_frame, bucket_selector, source_fps=self.source_fps, target_fps=self.target_fps
        )
        return video

    def get_control_data_from_path(
        self,
        control_path: str,
        start_frame: Optional[int] = None,
        end_frame: Optional[int] = None,
        bucket_selector: Optional[BucketSelector] = None,
    ) -> list[np.ndarray]:
        """Load control frames; same override rules as get_video_data_from_path."""
        start_frame = start_frame if start_frame is not None else self.start_frame
        end_frame = end_frame if end_frame is not None else self.end_frame
        bucket_selector = bucket_selector if bucket_selector is not None else self.bucket_selector

        control = load_video(
            control_path, start_frame, end_frame, bucket_selector, source_fps=self.source_fps, target_fps=self.target_fps
        )
        return control

    def set_start_and_end_frame(self, start_frame: Optional[int], end_frame: Optional[int]):
        self.start_frame = start_frame
        self.end_frame = end_frame

    def set_bucket_selector(self, bucket_selector: BucketSelector):
        self.bucket_selector = bucket_selector

    def set_source_and_target_fps(self, source_fps: Optional[float], target_fps: Optional[float]):
        self.source_fps = source_fps
        self.target_fps = target_fps

    def __iter__(self):
        raise NotImplementedError

    def __next__(self):
        raise NotImplementedError
|
|
|
|
class VideoDirectoryDatasource(VideoDatasource):
    """Video datasource backed by a directory of video files, with captions in
    sidecar text files and optional control videos/image-dirs alongside."""

    def __init__(self, video_directory: str, caption_extension: Optional[str] = None, control_directory: Optional[str] = None):
        """
        video_directory: directory globbed for videos (VIDEO_EXTENSIONS).
        caption_extension: extension of sidecar caption files (e.g. ".txt").
        control_directory: optional directory of control sources; every video
            must have a counterpart (same filename, a same-named directory of
            frames, or same basename with any video extension), otherwise
            ValueError is raised.
        """
        super().__init__()
        self.video_directory = video_directory
        self.caption_extension = caption_extension
        self.control_directory = control_directory
        self.current_idx = 0  # iterator cursor used by __next__

        # collect all videos up front
        logger.info(f"glob videos in {self.video_directory}")
        self.video_paths = glob_videos(self.video_directory)
        logger.info(f"found {len(self.video_paths)} videos")

        # match each video to its control source when a control dir is given
        if self.control_directory is not None:
            logger.info(f"glob control videos in {self.control_directory}")
            self.has_control = True
            self.control_paths = {}
            for video_path in self.video_paths:
                video_basename = os.path.basename(video_path)

                # 1) exact filename match
                control_path = os.path.join(self.control_directory, video_basename)
                if os.path.exists(control_path):
                    self.control_paths[video_path] = control_path
                else:
                    base_name = os.path.splitext(video_basename)[0]

                    # 2) a directory of frame images with the same base name
                    potential_path = os.path.join(self.control_directory, base_name)
                    if os.path.isdir(potential_path):
                        self.control_paths[video_path] = potential_path
                    else:
                        # 3) same base name with any supported video extension
                        for ext in VIDEO_EXTENSIONS:
                            potential_path = os.path.join(self.control_directory, base_name + ext)
                            if os.path.exists(potential_path):
                                self.control_paths[video_path] = potential_path
                                break

            logger.info(f"found {len(self.control_paths)} matching control videos/images")
            # fail fast when any video lacks a control counterpart
            missing_controls = len(self.video_paths) - len(self.control_paths)
            if missing_controls > 0:
                missing_controls_videos = [video_path for video_path in self.video_paths if video_path not in self.control_paths]
                logger.error(
                    f"Could not find matching control videos/images for {missing_controls} videos: {missing_controls_videos}"
                )
                raise ValueError(f"Could not find matching control videos/images for {missing_controls} videos")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.video_paths)

    def get_video_data(
        self,
        idx: int,
        start_frame: Optional[int] = None,
        end_frame: Optional[int] = None,
        bucket_selector: Optional[BucketSelector] = None,
    ) -> tuple[str, list[np.ndarray], str, Optional[list[np.ndarray]]]:
        """Return (path, frames, caption, optional control frames)."""
        video_path = self.video_paths[idx]
        video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)

        _, caption = self.get_caption(idx)

        control = None
        if self.control_directory is not None and video_path in self.control_paths:
            control_path = self.control_paths[video_path]
            control = self.get_control_data_from_path(control_path, start_frame, end_frame, bucket_selector)

        return video_path, video, caption, control

    def get_caption(self, idx: int) -> tuple[str, str]:
        """Return (video path, caption) read from the sidecar caption file."""
        video_path = self.video_paths[idx]
        # NOTE(review): when caption_extension is None/empty, caption_path is
        # "" and open() below raises -- assumes caption_extension is always
        # set for this datasource; confirm with callers.
        caption_path = os.path.splitext(video_path)[0] + self.caption_extension if self.caption_extension else ""
        with open(caption_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
        return video_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self):
        """Return a zero-argument fetcher; see ImageDirectoryDatasource.__next__."""
        if self.current_idx >= len(self.video_paths):
            raise StopIteration

        if self.caption_only:
            # helper binds the current index into the closure (late-binding fix)
            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)

        else:

            def create_fetcher(index):
                return lambda: self.get_video_data(index)

            fetcher = create_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher
|
|
|
|
class VideoJsonlDatasource(VideoDatasource):
    """Video datasource backed by a JSONL file.

    Each line is a JSON object with at least ``video_path`` and ``caption``;
    an optional ``control_path`` points to a matching control video/image.
    Either every entry carries a control path or none of them may.
    """

    def __init__(self, video_jsonl_file: str):
        super().__init__()
        self.video_jsonl_file = video_jsonl_file
        self.current_idx = 0

        # load JSONL data
        logger.info(f"load video jsonl from {self.video_jsonl_file}")
        self.data = []
        with open(self.video_jsonl_file, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank/trailing lines (json.loads would raise on "")
                self.data.append(json.loads(line))
        logger.info(f"loaded {len(self.data)} videos")

        # control paths must be provided for all entries or for none
        self.has_control = any("control_path" in item for item in self.data)
        if self.has_control:
            control_count = sum(1 for item in self.data if "control_path" in item)
            if control_count < len(self.data):
                missing_control_videos = [item["video_path"] for item in self.data if "control_path" not in item]
                logger.error(f"Some videos do not have control paths in JSONL data: {missing_control_videos}")
                raise ValueError(f"Some videos do not have control paths in JSONL data: {missing_control_videos}")
            logger.info(f"found {control_count} control videos/images in JSONL data")

    def is_indexable(self):
        """JSONL-backed sources support random access by index."""
        return True

    def __len__(self):
        return len(self.data)

    def get_video_data(
        self,
        idx: int,
        start_frame: Optional[int] = None,
        end_frame: Optional[int] = None,
        bucket_selector: Optional[BucketSelector] = None,
    ) -> tuple[str, list[Image.Image], str, Optional[list[Image.Image]]]:
        """Load frames, caption and optional control frames for entry ``idx``."""
        data = self.data[idx]
        video_path = data["video_path"]
        video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)

        caption = data["caption"]

        control = None
        if "control_path" in data and data["control_path"]:
            control = self.get_control_data_from_path(data["control_path"], start_frame, end_frame, bucket_selector)

        return video_path, video, caption, control

    def get_caption(self, idx: int) -> tuple[str, str]:
        """Return (video_path, caption) for entry ``idx`` (no file I/O needed)."""
        data = self.data[idx]
        return data["video_path"], data["caption"]

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self):
        """Return a zero-argument fetcher for lazy loading of the next entry."""
        if self.current_idx >= len(self.data):
            raise StopIteration

        index = self.current_idx
        self.current_idx += 1

        # bind the index through a default argument so each lambda captures
        # its own value instead of the shared cursor
        if self.caption_only:
            return lambda i=index: self.get_caption(i)
        return lambda i=index: self.get_video_data(i)
|
|
|
|
class BaseDataset(torch.utils.data.Dataset):
    """Common base class for image/video training datasets.

    Holds the configuration shared by all datasets (resolution, bucketing,
    batch size, cache directory), builds the latent / text-encoder cache file
    names, and tracks seed/epoch state for bucket shuffling. Subclasses
    implement cache retrieval, bucketing and indexing.
    """

    def __init__(
        self,
        resolution: Tuple[int, int] = (960, 544),
        caption_extension: Optional[str] = None,
        batch_size: int = 1,
        num_repeats: int = 1,
        enable_bucket: bool = False,
        bucket_no_upscale: bool = False,
        cache_directory: Optional[str] = None,
        debug_dataset: bool = False,
        architecture: str = "no_default",
    ):
        self.resolution = resolution
        self.caption_extension = caption_extension
        self.batch_size = batch_size
        self.num_repeats = num_repeats
        self.enable_bucket = enable_bucket
        self.bucket_no_upscale = bucket_no_upscale
        self.cache_directory = cache_directory
        self.debug_dataset = debug_dataset
        self.architecture = architecture
        self.seed = None  # must be set via set_seed() before shuffle_buckets() is used
        self.current_epoch = 0

        if not self.enable_bucket:
            # no-upscale only makes sense when bucketing is enabled
            self.bucket_no_upscale = False

    def get_metadata(self) -> dict:
        """Return a JSON-serializable summary of the dataset configuration."""
        metadata = {
            "resolution": self.resolution,
            "caption_extension": self.caption_extension,
            "batch_size_per_device": self.batch_size,
            "num_repeats": self.num_repeats,
            "enable_bucket": bool(self.enable_bucket),
            "bucket_no_upscale": bool(self.bucket_no_upscale),
        }
        return metadata

    def get_all_latent_cache_files(self):
        """Glob all latent cache files for this architecture in the cache directory."""
        return glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}.safetensors"))

    def get_all_text_encoder_output_cache_files(self):
        """Glob all text encoder output cache files for this architecture."""
        return glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}_te.safetensors"))

    def get_latent_cache_path(self, item_info: ItemInfo) -> str:
        """
        Returns the cache path for the latent tensor.

        item_info: ItemInfo object

        Returns:
            str: cache path

        cache_path is based on the item_key and the resolution.
        """
        w, h = item_info.original_size
        basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
        assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
        return os.path.join(self.cache_directory, f"{basename}_{w:04d}x{h:04d}_{self.architecture}.safetensors")

    def get_text_encoder_output_cache_path(self, item_info: ItemInfo) -> str:
        """Return the text encoder output cache path for the given item."""
        basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
        assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
        return os.path.join(self.cache_directory, f"{basename}_{self.architecture}_te.safetensors")

    def retrieve_latent_cache_batches(self, num_workers: int):
        raise NotImplementedError

    def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
        raise NotImplementedError

    def prepare_for_training(self):
        pass

    def set_seed(self, seed: int):
        self.seed = seed

    def set_current_epoch(self, epoch):
        """Advance to ``epoch``, shuffling buckets once per skipped epoch."""
        if not self.current_epoch == epoch:
            if epoch > self.current_epoch:
                logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
                num_epochs = epoch - self.current_epoch
                for _ in range(num_epochs):
                    self.current_epoch += 1
                    self.shuffle_buckets()
                # NOTE(review): the epoch appears able to reset to 0 between runs
                # (e.g. via a skipped/restarted dataloader) — hence the branch below.
            else:
                logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
                self.current_epoch = epoch

    def set_current_step(self, step):
        self.current_step = step

    def set_max_train_steps(self, max_train_steps):
        self.max_train_steps = max_train_steps

    def shuffle_buckets(self):
        raise NotImplementedError

    def __len__(self):
        # fixed: previously `return NotImplementedError` returned the exception
        # class instead of raising it, silently yielding a bogus "length"
        raise NotImplementedError

    def __getitem__(self, idx):
        raise NotImplementedError

    def _default_retrieve_text_encoder_output_cache_batches(self, datasource: ContentDatasource, batch_size: int, num_workers: int):
        """Yield batches of caption-only ItemInfo fetched concurrently from ``datasource``.

        Fetch operations produced by the datasource are submitted to a thread
        pool; completed results are converted to ItemInfo (with the text
        encoder cache path filled in) and yielded in batches of ``batch_size``.
        """
        datasource.set_caption_only(True)
        executor = ThreadPoolExecutor(max_workers=num_workers)

        data: list[ItemInfo] = []
        futures = []

        # collect results from completed futures; with consume_all, drain everything
        def aggregate_future(consume_all: bool = False):
            while len(futures) >= num_workers or (consume_all and len(futures) > 0):
                completed_futures = [future for future in futures if future.done()]
                if len(completed_futures) == 0:
                    if len(futures) >= num_workers or consume_all:
                        time.sleep(0.1)  # still waiting on in-flight work
                        continue
                    else:
                        break

                for future in completed_futures:
                    item_key, caption = future.result()
                    item_info = ItemInfo(item_key, caption, (0, 0), (0, 0))
                    item_info.text_encoder_output_cache_path = self.get_text_encoder_output_cache_path(item_info)
                    data.append(item_info)

                    futures.remove(future)

        # pop at most one batch from the accumulated data per call
        def submit_batch(flush: bool = False):
            nonlocal data
            if len(data) >= batch_size or (len(data) > 0 and flush):
                batch = data[0:batch_size]
                if len(data) > batch_size:
                    data = data[batch_size:]
                else:
                    data = []
                return batch
            return None

        for fetch_op in datasource:
            future = executor.submit(fetch_op)
            futures.append(future)
            aggregate_future()
            while True:
                batch = submit_batch()
                if batch is None:
                    break
                yield batch

        aggregate_future(consume_all=True)
        while True:
            batch = submit_batch(flush=True)
            if batch is None:
                break
            yield batch

        executor.shutdown()
|
|
|
|
class ImageDataset(BaseDataset):
    """Dataset of still images loaded from a directory or a JSONL file."""

    def __init__(
        self,
        resolution: Tuple[int, int],
        caption_extension: Optional[str],
        batch_size: int,
        num_repeats: int,
        enable_bucket: bool,
        bucket_no_upscale: bool,
        image_directory: Optional[str] = None,
        image_jsonl_file: Optional[str] = None,
        control_directory: Optional[str] = None,
        cache_directory: Optional[str] = None,
        debug_dataset: bool = False,
        architecture: str = "no_default",
    ):
        super(ImageDataset, self).__init__(
            resolution,
            caption_extension,
            batch_size,
            num_repeats,
            enable_bucket,
            bucket_no_upscale,
            cache_directory,
            debug_dataset,
            architecture,
        )
        self.image_directory = image_directory
        self.image_jsonl_file = image_jsonl_file
        self.control_directory = control_directory
        if image_directory is not None:
            self.datasource = ImageDirectoryDatasource(image_directory, caption_extension, control_directory)
        elif image_jsonl_file is not None:
            self.datasource = ImageJsonlDatasource(image_jsonl_file)
        else:
            raise ValueError("image_directory or image_jsonl_file must be specified")

        if self.cache_directory is None:
            # default to caching alongside the images
            self.cache_directory = self.image_directory

        self.batch_manager = None
        self.num_train_items = 0
        self.has_control = self.datasource.has_control

    def get_metadata(self):
        """Extend the base metadata with image-source and control information."""
        metadata = super().get_metadata()
        if self.image_directory is not None:
            metadata["image_directory"] = os.path.basename(self.image_directory)
        if self.image_jsonl_file is not None:
            metadata["image_jsonl_file"] = os.path.basename(self.image_jsonl_file)
        if self.control_directory is not None:
            metadata["control_directory"] = os.path.basename(self.control_directory)
        metadata["has_control"] = self.has_control
        return metadata

    def get_total_image_count(self):
        """Number of source images, or None if the datasource is not indexable."""
        return len(self.datasource) if self.datasource.is_indexable() else None

    def retrieve_latent_cache_batches(self, num_workers: int):
        """Yield (bucket_reso, list[ItemInfo]) batches with resized image content for latent caching."""
        bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)
        executor = ThreadPoolExecutor(max_workers=num_workers)

        # key: bucket resolution (width, height), value: accumulated items
        batches: dict[tuple[int, int], list[ItemInfo]] = {}
        futures = []

        # run the fetch op and resize the image (and control) to its bucket;
        # hoisted out of the submit loop so it is defined only once
        def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, Image.Image, str, Optional[Image.Image]]:
            image_key, image, caption, control = op()
            image: Image.Image
            image_size = image.size

            bucket_reso = bucket_selector.get_bucket_resolution(image_size)
            image = resize_image_to_bucket(image, bucket_reso)
            if control is not None:
                control = resize_image_to_bucket(control, bucket_reso)
            return image_size, image_key, image, caption, control

        # collect results from completed futures into per-bucket batches
        def aggregate_future(consume_all: bool = False):
            while len(futures) >= num_workers or (consume_all and len(futures) > 0):
                completed_futures = [future for future in futures if future.done()]
                if len(completed_futures) == 0:
                    if len(futures) >= num_workers or consume_all:
                        time.sleep(0.1)
                        continue
                    else:
                        break

                for future in completed_futures:
                    original_size, item_key, image, caption, control = future.result()
                    bucket_height, bucket_width = image.shape[:2]
                    bucket_reso = (bucket_width, bucket_height)

                    item_info = ItemInfo(item_key, caption, original_size, bucket_reso, content=image)
                    item_info.latent_cache_path = self.get_latent_cache_path(item_info)

                    if control is not None:
                        item_info.control_content = control

                    if bucket_reso not in batches:
                        batches[bucket_reso] = []
                    batches[bucket_reso].append(item_info)

                    futures.remove(future)

        # pop at most one full (or flushed) batch per call
        def submit_batch(flush: bool = False):
            for key in batches:
                if len(batches[key]) >= self.batch_size or flush:
                    batch = batches[key][0 : self.batch_size]
                    if len(batches[key]) > self.batch_size:
                        batches[key] = batches[key][self.batch_size :]
                    else:
                        del batches[key]
                    return key, batch
            return None, None

        for fetch_op in self.datasource:
            future = executor.submit(fetch_and_resize, fetch_op)
            futures.append(future)
            aggregate_future()
            while True:
                key, batch = submit_batch()
                if key is None:
                    break
                yield key, batch

        aggregate_future(consume_all=True)
        while True:
            key, batch = submit_batch(flush=True)
            if key is None:
                break
            yield key, batch

        executor.shutdown()

    def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
        return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)

    def prepare_for_training(self):
        """Build bucketed training batches from existing latent / text-encoder cache files."""
        bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)

        # reuse the shared helper instead of duplicating the glob pattern
        latent_cache_files = self.get_all_latent_cache_files()

        bucketed_item_info: dict[tuple[int, int], list[ItemInfo]] = {}
        for cache_file in latent_cache_files:
            tokens = os.path.basename(cache_file).split("_")

            # filename layout: {item_key}_{WWWW}x{HHHH}_{architecture}.safetensors
            image_size = tokens[-2]
            image_width, image_height = map(int, image_size.split("x"))
            image_size = (image_width, image_height)

            item_key = "_".join(tokens[:-2])
            text_encoder_output_cache_file = os.path.join(self.cache_directory, f"{item_key}_{self.architecture}_te.safetensors")
            if not os.path.exists(text_encoder_output_cache_file):
                logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
                continue

            bucket_reso = bucket_selector.get_bucket_resolution(image_size)
            item_info = ItemInfo(item_key, "", image_size, bucket_reso, latent_cache_path=cache_file)
            item_info.text_encoder_output_cache_path = text_encoder_output_cache_file

            bucket = bucketed_item_info.get(bucket_reso, [])
            for _ in range(self.num_repeats):
                bucket.append(item_info)
            bucketed_item_info[bucket_reso] = bucket

        # prepare the batch manager over the bucketed items
        self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
        self.batch_manager.show_bucket_info()

        self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])

    def shuffle_buckets(self):
        # reseed for reproducible, epoch-dependent shuffling (set_seed() must be called first)
        random.seed(self.seed + self.current_epoch)
        self.batch_manager.shuffle()

    def __len__(self):
        if self.batch_manager is None:
            # not prepared yet; return a dummy length so DataLoader probing works
            return 100
        return len(self.batch_manager)

    def __getitem__(self, idx):
        return self.batch_manager[idx]
|
|
|
|
class VideoDataset(BaseDataset):
    """Dataset of video clips loaded from a directory or a JSONL file.

    Frames are extracted according to ``frame_extraction`` ("head", "chunk",
    "slide", "uniform" or "full") and grouped into buckets keyed by
    (width, height, frame_count).
    """

    # target FPS per architecture; source videos are resampled to this rate
    TARGET_FPS_HUNYUAN = 24.0
    TARGET_FPS_WAN = 16.0
    TARGET_FPS_FRAMEPACK = 30.0

    def __init__(
        self,
        resolution: Tuple[int, int],
        caption_extension: Optional[str],
        batch_size: int,
        num_repeats: int,
        enable_bucket: bool,
        bucket_no_upscale: bool,
        frame_extraction: Optional[str] = "head",
        frame_stride: Optional[int] = 1,
        frame_sample: Optional[int] = 1,
        target_frames: Optional[list[int]] = None,
        max_frames: Optional[int] = None,
        source_fps: Optional[float] = None,
        video_directory: Optional[str] = None,
        video_jsonl_file: Optional[str] = None,
        control_directory: Optional[str] = None,
        cache_directory: Optional[str] = None,
        debug_dataset: bool = False,
        architecture: str = "no_default",
    ):
        super(VideoDataset, self).__init__(
            resolution,
            caption_extension,
            batch_size,
            num_repeats,
            enable_bucket,
            bucket_no_upscale,
            cache_directory,
            debug_dataset,
            architecture,
        )
        self.video_directory = video_directory
        self.video_jsonl_file = video_jsonl_file
        self.control_directory = control_directory
        self.frame_extraction = frame_extraction
        self.frame_stride = frame_stride
        self.frame_sample = frame_sample
        self.max_frames = max_frames
        self.source_fps = source_fps

        if self.architecture == ARCHITECTURE_HUNYUAN_VIDEO:
            self.target_fps = VideoDataset.TARGET_FPS_HUNYUAN
        elif self.architecture == ARCHITECTURE_WAN:
            self.target_fps = VideoDataset.TARGET_FPS_WAN
        elif self.architecture == ARCHITECTURE_FRAMEPACK:
            self.target_fps = VideoDataset.TARGET_FPS_FRAMEPACK
        else:
            raise ValueError(f"Unsupported architecture: {self.architecture}")

        if target_frames is not None:
            target_frames = list(set(target_frames))
            target_frames.sort()

            # round frame counts down to the nearest N*4+1 (temporal compression of the latents)
            rounded_target_frames = [(f - 1) // 4 * 4 + 1 for f in target_frames]
            # fixed: the deduplicated+sorted result was previously assigned to a
            # misspelled variable ("rouneded_...") and silently discarded, so
            # duplicate rounded values survived into self.target_frames
            rounded_target_frames = sorted(set(rounded_target_frames))

            # warn users if any value was changed by rounding/deduplication
            if target_frames != rounded_target_frames:
                logger.warning(f"target_frames are rounded to {rounded_target_frames}")

            target_frames = tuple(rounded_target_frames)

        self.target_frames = target_frames

        if video_directory is not None:
            self.datasource = VideoDirectoryDatasource(video_directory, caption_extension, control_directory)
        elif video_jsonl_file is not None:
            self.datasource = VideoJsonlDatasource(video_jsonl_file)
        else:
            # consistent with ImageDataset: fail fast instead of leaving self.datasource unset
            raise ValueError("video_directory or video_jsonl_file must be specified")

        if self.frame_extraction == "uniform" and self.frame_sample == 1:
            self.frame_extraction = "head"
            logger.warning("frame_sample is set to 1 for frame_extraction=uniform. frame_extraction is changed to head.")
        if self.frame_extraction == "head":
            # head extraction never needs more than the largest target frame count
            self.datasource.set_start_and_end_frame(0, max(self.target_frames))

        if self.cache_directory is None:
            # default to caching alongside the videos
            self.cache_directory = self.video_directory

        self.batch_manager = None
        self.num_train_items = 0
        self.has_control = self.datasource.has_control

    def get_metadata(self):
        """Extend the base metadata with video-source and frame-extraction settings."""
        metadata = super().get_metadata()
        if self.video_directory is not None:
            metadata["video_directory"] = os.path.basename(self.video_directory)
        if self.video_jsonl_file is not None:
            metadata["video_jsonl_file"] = os.path.basename(self.video_jsonl_file)
        if self.control_directory is not None:
            metadata["control_directory"] = os.path.basename(self.control_directory)
        metadata["frame_extraction"] = self.frame_extraction
        metadata["frame_stride"] = self.frame_stride
        metadata["frame_sample"] = self.frame_sample
        metadata["target_frames"] = self.target_frames
        metadata["max_frames"] = self.max_frames
        metadata["source_fps"] = self.source_fps
        metadata["has_control"] = self.has_control
        return metadata

    def retrieve_latent_cache_batches(self, num_workers: int):
        """Yield ((width, height, frames), list[ItemInfo]) batches with cropped video content for latent caching."""
        bucket_selector = BucketSelector(self.resolution, architecture=self.architecture)
        self.datasource.set_bucket_selector(bucket_selector)
        if self.source_fps is not None:
            self.datasource.set_source_and_target_fps(self.source_fps, self.target_fps)
        else:
            self.datasource.set_source_and_target_fps(None, None)  # no FPS conversion

        executor = ThreadPoolExecutor(max_workers=num_workers)

        # key: (width, height, frame_count), value: accumulated items
        batches: dict[tuple[int, int, int], list[ItemInfo]] = {}
        futures = []

        # run the fetch op and resize frames (and control frames) to the bucket;
        # hoisted out of the submit loop so it is defined only once
        def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, list[np.ndarray], str, Optional[list[np.ndarray]]]:
            result = op()

            if len(result) == 3:
                # legacy datasources return no control data
                video_key, video, caption = result
                control = None
            else:
                video_key, video, caption, control = result

            video: list[np.ndarray]
            frame_size = (video[0].shape[1], video[0].shape[0])

            bucket_reso = bucket_selector.get_bucket_resolution(frame_size)
            video = [resize_image_to_bucket(frame, bucket_reso) for frame in video]

            if control is not None:
                control = [resize_image_to_bucket(frame, bucket_reso) for frame in control]

            return frame_size, video_key, video, caption, control

        # collect results from completed futures, crop into target windows and bucket them
        def aggregate_future(consume_all: bool = False):
            while len(futures) >= num_workers or (consume_all and len(futures) > 0):
                completed_futures = [future for future in futures if future.done()]
                if len(completed_futures) == 0:
                    if len(futures) >= num_workers or consume_all:
                        time.sleep(0.1)
                        continue
                    else:
                        break

                for future in completed_futures:
                    original_frame_size, video_key, video, caption, control = future.result()

                    frame_count = len(video)
                    video = np.stack(video, axis=0)
                    height, width = video.shape[1:3]
                    bucket_reso = (width, height)  # frames were already resized

                    # align the control video's frame count with the main video
                    control_video = None
                    if control is not None:
                        if len(control) > frame_count:
                            control = control[:frame_count]
                        elif len(control) < frame_count:
                            # pad by repeating the last frame
                            last_frame = control[-1]
                            control.extend([last_frame] * (frame_count - len(control)))
                        control_video = np.stack(control, axis=0)

                    # determine (start, length) windows per the extraction strategy
                    crop_pos_and_frames = []
                    if self.frame_extraction == "head":
                        for target_frame in self.target_frames:
                            if frame_count >= target_frame:
                                crop_pos_and_frames.append((0, target_frame))
                    elif self.frame_extraction == "chunk":
                        # split the video into consecutive chunks of target_frame
                        for target_frame in self.target_frames:
                            for i in range(0, frame_count, target_frame):
                                if i + target_frame <= frame_count:
                                    crop_pos_and_frames.append((i, target_frame))
                    elif self.frame_extraction == "slide":
                        # slide a window of target_frame frames by frame_stride
                        for target_frame in self.target_frames:
                            if frame_count >= target_frame:
                                for i in range(0, frame_count - target_frame + 1, self.frame_stride):
                                    crop_pos_and_frames.append((i, target_frame))
                    elif self.frame_extraction == "uniform":
                        # pick frame_sample windows uniformly spaced over the video
                        for target_frame in self.target_frames:
                            if frame_count >= target_frame:
                                frame_indices = np.linspace(0, frame_count - target_frame, self.frame_sample, dtype=int)
                                for i in frame_indices:
                                    crop_pos_and_frames.append((i, target_frame))
                    elif self.frame_extraction == "full":
                        # take the whole video, capped at max_frames and rounded to N*4+1
                        target_frame = min(frame_count, self.max_frames)
                        target_frame = (target_frame - 1) // 4 * 4 + 1
                        crop_pos_and_frames.append((0, target_frame))
                    else:
                        raise ValueError(f"frame_extraction {self.frame_extraction} is not supported")

                    for crop_pos, target_frame in crop_pos_and_frames:
                        cropped_video = video[crop_pos : crop_pos + target_frame]
                        body, ext = os.path.splitext(video_key)
                        item_key = f"{body}_{crop_pos:05d}-{target_frame:03d}{ext}"
                        batch_key = (*bucket_reso, target_frame)  # bucket key includes the frame count

                        cropped_control = None
                        if control_video is not None:
                            cropped_control = control_video[crop_pos : crop_pos + target_frame]

                        item_info = ItemInfo(
                            item_key, caption, original_frame_size, batch_key, frame_count=target_frame, content=cropped_video
                        )
                        item_info.latent_cache_path = self.get_latent_cache_path(item_info)
                        item_info.control_content = cropped_control  # None when no control video

                        batch = batches.get(batch_key, [])
                        batch.append(item_info)
                        batches[batch_key] = batch

                    futures.remove(future)

        # pop at most one full (or flushed) batch per call
        def submit_batch(flush: bool = False):
            for key in batches:
                if len(batches[key]) >= self.batch_size or flush:
                    batch = batches[key][0 : self.batch_size]
                    if len(batches[key]) > self.batch_size:
                        batches[key] = batches[key][self.batch_size :]
                    else:
                        del batches[key]
                    return key, batch
            return None, None

        for operator in self.datasource:
            future = executor.submit(fetch_and_resize, operator)
            futures.append(future)
            aggregate_future()
            while True:
                key, batch = submit_batch()
                if key is None:
                    break
                yield key, batch

        aggregate_future(consume_all=True)
        while True:
            key, batch = submit_batch(flush=True)
            if key is None:
                break
            yield key, batch

        executor.shutdown()

    def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
        return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)

    def prepare_for_training(self):
        """Build bucketed training batches from existing latent / text-encoder cache files."""
        bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)

        # reuse the shared helper instead of duplicating the glob pattern
        latent_cache_files = self.get_all_latent_cache_files()

        bucketed_item_info: dict[tuple[int, int, int], list[ItemInfo]] = {}
        for cache_file in latent_cache_files:
            tokens = os.path.basename(cache_file).split("_")

            # filename layout: {item_key}_{pos:05d}-{frames:03d}_{WWWW}x{HHHH}_{architecture}.safetensors
            image_size = tokens[-2]
            image_width, image_height = map(int, image_size.split("x"))
            image_size = (image_width, image_height)

            frame_pos, frame_count = tokens[-3].split("-")[:2]
            frame_pos, frame_count = int(frame_pos), int(frame_count)  # frame_pos kept for filename validation

            item_key = "_".join(tokens[:-3])
            text_encoder_output_cache_file = os.path.join(self.cache_directory, f"{item_key}_{self.architecture}_te.safetensors")
            if not os.path.exists(text_encoder_output_cache_file):
                logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
                continue

            bucket_reso = bucket_selector.get_bucket_resolution(image_size)
            bucket_reso = (*bucket_reso, frame_count)
            item_info = ItemInfo(item_key, "", image_size, bucket_reso, frame_count=frame_count, latent_cache_path=cache_file)
            item_info.text_encoder_output_cache_path = text_encoder_output_cache_file

            bucket = bucketed_item_info.get(bucket_reso, [])
            for _ in range(self.num_repeats):
                bucket.append(item_info)
            bucketed_item_info[bucket_reso] = bucket

        # prepare the batch manager over the bucketed items
        self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
        self.batch_manager.show_bucket_info()

        self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])

    def shuffle_buckets(self):
        # reseed for reproducible, epoch-dependent shuffling (set_seed() must be called first)
        random.seed(self.seed + self.current_epoch)
        self.batch_manager.shuffle()

    def __len__(self):
        if self.batch_manager is None:
            # not prepared yet; return a dummy length so DataLoader probing works
            return 100
        return len(self.batch_manager)

    def __getitem__(self, idx):
        return self.batch_manager[idx]
|
|
|
|
class DatasetGroup(torch.utils.data.ConcatDataset):
    """Concatenation of image/video datasets that forwards trainer state to every child."""

    def __init__(self, datasets: Sequence[Union[ImageDataset, VideoDataset]]):
        super().__init__(datasets)
        # materialize as a list so the attribute matches its annotation even when
        # a tuple or other Sequence is passed (ConcatDataset.__init__ stores its
        # own list copy; assigning the raw Sequence would clobber it)
        self.datasets: list[Union[ImageDataset, VideoDataset]] = list(datasets)
        self.num_train_items = 0
        for dataset in self.datasets:
            self.num_train_items += dataset.num_train_items

    def set_current_epoch(self, epoch):
        for dataset in self.datasets:
            dataset.set_current_epoch(epoch)

    def set_current_step(self, step):
        for dataset in self.datasets:
            dataset.set_current_step(step)

    def set_max_train_steps(self, max_train_steps):
        for dataset in self.datasets:
            dataset.set_max_train_steps(max_train_steps)
|