# NOTE(review): the original capture of this file began with page-extraction
# residue ("Spaces:" / "Configuration error"); replaced with this header.
# Module: video frame-sequence dataset with per-layout segmentation masks.
| import os | |
| import numpy as np | |
| from PIL import Image | |
| from einops import rearrange | |
| from pathlib import Path | |
| import torch | |
| from torch.utils.data import Dataset | |
| from .transform import short_size_scale, random_crop, center_crop, offset_crop | |
| from ..common.image_util import IMAGE_EXTENSION | |
| import cv2 | |
class ImageSequenceDataset(Dataset):
    """Map-style dataset yielding clips of video frames plus per-layout masks.

    Each item contains:
      - ``images``:      float tensor, shape (c, f, h, w), values in [-1, 1]
      - ``masks``:       half tensor, union of all layout masks, (1, f, h/8, w/8)
      - ``layouts``:     half tensor, per-layout masks, (f, n_layouts, 1, h/8, w/8)
      - ``prompt_ids``:  the tokenized prompt passed at construction
    and, when ``class_data_root`` is given, ``class_images`` /
    ``class_prompt_ids`` for prior-preservation regularization.
    """

    def __init__(
        self,
        path: str,
        layout_mask_dir: str,
        layout_mask_order: list,
        prompt_ids: torch.Tensor,
        prompt: str,
        start_sample_frame: int = 0,
        n_sample_frame: int = 8,
        sampling_rate: int = 1,
        stride: int = -1,  # only used during tuning to sample a long video
        image_mode: str = "RGB",
        image_size: int = 512,
        crop: str = "center",
        class_data_root: str = None,
        class_prompt_ids: torch.Tensor = None,
        offset: dict = None,
        **args,
    ):
        """Index the frame directory and the per-layout mask directories.

        Args:
            path: directory of video frames (sorted image files).
            layout_mask_dir: root directory holding one sub-directory of mask
                frames per layout name.
            layout_mask_order: layout sub-directory names; stacking order of
                the returned ``layouts`` tensor.
            prompt_ids: tokenized prompt returned with every item.
            prompt: raw prompt text (kept for reference).
            start_sample_frame: first frame index of clip 0.
            n_sample_frame: frames per clip; negative means "use all frames".
            sampling_rate: temporal gap between consecutive sampled frames.
            stride: global hop between clips when tuning on a long video;
                non-positive disables striding (a single clip).
            image_mode: PIL conversion mode for loaded frames.
            image_size: output spatial size after scaling and cropping.
            crop: "center" or "random".
            class_data_root: optional directory of regularization images.
            class_prompt_ids: tokenized prompt for the class images.
            offset: pixel crop offsets, keys "left"/"right"/"top"/"bottom".

        Raises:
            ValueError: if the video is shorter than one clip, or ``crop``
                is not a known method.
        """
        self.path = path
        self.images = self.get_image_list(path)

        self.layout_mask_dir = layout_mask_dir
        self.layout_mask_order = list(layout_mask_order)
        # Frame indexing is driven by the first layout's mask directory.
        layout_mask_dir0 = os.path.join(self.layout_mask_dir, self.layout_mask_order[0])
        self.masks_index = self.get_image_list(layout_mask_dir0)

        self.n_images = len(self.images)
        # None-sentinel instead of a mutable default dict shared across calls.
        if offset is None:
            offset = {"left": 0, "right": 0, "top": 0, "bottom": 0}
        self.offset = offset

        self.start_sample_frame = start_sample_frame
        if n_sample_frame < 0:
            n_sample_frame = len(self.images)
        self.n_sample_frame = n_sample_frame

        # Local sampling rate from the video: a clip of n frames spans
        # (n - 1) * sampling_rate + 1 source frames.
        self.sampling_rate = sampling_rate
        self.sequence_length = (n_sample_frame - 1) * sampling_rate + 1
        if self.n_images < self.sequence_length:
            raise ValueError(
                f"self.n_images {self.n_images} < self.sequence_length {self.sequence_length}: "
                f"Required number of frames {self.sequence_length} larger than "
                f"total frames in the dataset {self.n_images}"
            )

        # During tuning, if the video is too long we sample it every
        # self.stride frames globally; a non-positive stride collapses the
        # dataset to a single clip.
        self.stride = stride if stride > 0 else (self.n_images + 1)
        self.video_len = (self.n_images - self.sequence_length) // self.stride + 1

        self.image_mode = image_mode
        self.image_size = image_size
        crop_methods = {
            "center": center_crop,
            "random": random_crop,
        }
        if crop not in crop_methods:
            raise ValueError(f"Unknown crop method {crop!r}; expected one of {sorted(crop_methods)}")
        self.crop = crop_methods[crop]

        self.prompt = prompt
        self.prompt_ids = prompt_ids

        # Negative (class) images for regularization to avoid overfitting
        # during one-shot tuning.
        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_images_path = sorted(list(self.class_data_root.iterdir()))
            self.num_class_images = len(self.class_images_path)
            self.class_prompt_ids = class_prompt_ids

    def __len__(self):
        # Number of clips obtainable with the global stride; when class
        # images are present the dataset is at least that long so every
        # class image is visited.
        max_len = (self.n_images - self.sequence_length) // self.stride + 1
        if hasattr(self, "num_class_images"):
            max_len = max(max_len, self.num_class_images)
        return max_len

    def __getitem__(self, index):
        """Return one clip with its merged mask and per-layout masks."""
        return_batch = {}
        # Materialize once: get_frame_indices returns a one-shot generator,
        # and the same indices are reused for frames and every layout.
        frame_indices = list(self.get_frame_indices(index % self.video_len))
        frames = [self.load_frame(i) for i in frame_indices]
        frames = self.transform(frames)

        # Stack masks per layout: (n_layouts, f, 1, h/8, w/8).
        layout_ = []
        for layout_name in self.layout_mask_order:
            layout_mask_dir = os.path.join(self.layout_mask_dir, layout_name)
            mask = [self._read_mask(layout_mask_dir, i) for i in frame_indices]
            layout_.append(np.stack(mask))
        layout_ = np.stack(layout_)

        # Union of all layouts per frame -> binary foreground mask.
        merged_masks = []
        for i in range(int(self.n_sample_frame)):
            merged_mask_frame = np.sum(layout_[:, i, :, :, :], axis=0)
            merged_mask_frame = (merged_mask_frame > 0).astype(np.uint8)
            merged_masks.append(merged_mask_frame)
        masks = rearrange(np.stack(merged_masks), "f c h w -> c f h w")
        masks = torch.from_numpy(masks).half()

        layouts = rearrange(layout_, "s f c h w -> f s c h w")
        layouts = torch.from_numpy(layouts).half()

        return_batch.update(
            {
                "images": frames,
                "masks": masks,
                "layouts": layouts,
                "prompt_ids": self.prompt_ids,
            }
        )

        if hasattr(self, "class_data_root"):
            class_index = index % (self.num_class_images - self.n_sample_frame)
            class_indices = self.get_class_indices(class_index)
            frames = [self.load_class_frame(i) for i in class_indices]
            return_batch["class_images"] = self.tensorize_frames(frames)
            return_batch["class_prompt_ids"] = self.class_prompt_ids
        return return_batch

    def transform(self, frames):
        """Tensorize, offset-crop, scale, and crop frames to ``image_size``."""
        frames = self.tensorize_frames(frames)
        frames = offset_crop(frames, **self.offset)
        frames = short_size_scale(frames, size=self.image_size)
        frames = self.crop(frames, height=self.image_size, width=self.image_size)
        return frames

    @staticmethod
    def tensorize_frames(frames):
        """Stack HWC uint8 frames into a (c, f, h, w) float tensor in [-1, 1].

        Declared ``@staticmethod``: the original definition lacked both
        ``self`` and the decorator, so ``self.tensorize_frames(...)`` raised
        a TypeError.
        """
        # Same permutation as rearrange("f h w c -> c f h w"), without
        # needing einops inside this method.
        stacked = np.stack(frames).transpose(3, 0, 1, 2)
        return torch.from_numpy(stacked).div(255) * 2 - 1

    def _read_mask(self, mask_path, index: int):
        """Read a binary mask frame via cv2 and downsample 8x to latent size."""
        mask_path = os.path.join(mask_path, f"{index:05d}.png")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # cv2.imread silently returns None on missing/corrupt files;
            # fail loudly with the offending path instead.
            raise FileNotFoundError(f"Could not read mask image: {mask_path}")
        mask = (mask > 0).astype(np.uint8)
        # Destination size follows the source dynamically (1/8 resolution,
        # matching the latent spatial scale).
        height, width = mask.shape
        dest_size = (width // 8, height // 8)
        # Nearest neighbour keeps the mask strictly binary after resizing.
        mask = cv2.resize(mask, dest_size, interpolation=cv2.INTER_NEAREST)
        return mask[np.newaxis, ...]

    def load_frame(self, index):
        """Open frame ``index`` from the video directory as a PIL image."""
        image_path = os.path.join(self.path, self.images[index])
        return Image.open(image_path).convert(self.image_mode)

    def load_class_frame(self, index):
        """Open class (regularization) image ``index`` as a PIL image."""
        image_path = self.class_images_path[index]
        return Image.open(image_path).convert(self.image_mode)

    def get_frame_indices(self, index):
        """Yield source-frame indices for clip ``index`` (one-shot generator)."""
        if self.start_sample_frame is not None:
            frame_start = self.start_sample_frame + self.stride * index
        else:
            frame_start = self.stride * index
        return (frame_start + i * self.sampling_rate for i in range(self.n_sample_frame))

    def get_class_indices(self, index):
        """Yield consecutive class-image indices starting at ``index``."""
        frame_start = index
        return (frame_start + i for i in range(self.n_sample_frame))

    @staticmethod
    def get_image_list(path):
        """Return sorted image filenames in ``path``.

        Declared ``@staticmethod``: the original definition lacked both
        ``self`` and the decorator, so ``self.get_image_list(path)`` raised
        a TypeError.
        """
        return [f for f in sorted(os.listdir(path)) if f.endswith(IMAGE_EXTENSION)]