| |
| |
|
|
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
| |
| |
|
|
import json
import os

import numpy as np
import torch
import torchvision.transforms.functional as TVF
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Normalize, ToTensor
|
|
def bucket_images(images: list[torch.Tensor], resolution: int = 512):
    """Resize a batch of image tensors into one shared aspect-ratio bucket and stack them.

    A fixed bucket table (defined at a 512 base resolution) is rescaled to
    `resolution` and each side snapped down to a multiple of 16; the bucket
    whose h/w ratio is closest to the batch's mean aspect ratio wins.

    Args:
        images: tensors shaped (..., H, W); every entry is resized to the
            chosen bucket.
        resolution: base resolution used to scale the bucket table.

    Returns:
        One tensor of shape (N, ..., new_h, new_w), stacked along dim 0.
    """
    base_buckets = [
        (256, 768),
        (320, 768),
        (320, 704),
        (384, 640),
        (448, 576),
        (512, 512),
        (576, 448),
        (640, 384),
        (704, 320),
        (768, 320),
        (768, 256),
    ]
    # Rescale the 512-based table, then snap each side down to a multiple of 16.
    buckets = [(int(h / 512 * resolution), int(w / 512 * resolution)) for h, w in base_buckets]
    buckets = [(h // 16 * 16, w // 16 * 16) for h, w in buckets]

    mean_ratio = np.mean([img.shape[-2] / img.shape[-1] for img in images])

    # Closest-ratio bucket; min() keeps the first entry on ties, matching a
    # strict-less-than scan of the table.
    new_h, new_w = min(buckets, key=lambda hw: np.abs(hw[0] / hw[1] - mean_ratio))

    resized = [TVF.resize(img, (new_h, new_w)) for img in images]
    return torch.stack(resized, dim=0)
|
|
class FluxPairedDatasetV2(Dataset):
    """Paired image/prompt dataset driven by a JSON manifest.

    Each manifest entry must contain "prompt" and either "image_path"
    (single reference) or "image_paths" (multiple references), plus an
    optional "image_tgt_path" target image. All paths are resolved relative
    to the directory containing the JSON file.
    """

    def __init__(self, json_file: str, resolution: int, resolution_ref: int | None = None):
        """
        Args:
            json_file: path to the JSON manifest (a list of dicts).
            resolution: bucket base resolution for target images.
            resolution_ref: bucket base resolution for reference images;
                defaults to `resolution` when None.
        """
        super().__init__()
        self.json_file = json_file
        self.resolution = resolution
        self.resolution_ref = resolution_ref if resolution_ref is not None else resolution
        # Manifest image paths are relative to the manifest's directory.
        self.image_root = os.path.dirname(json_file)

        with open(self.json_file, "rt", encoding="utf-8") as f:
            self.data_dicts = json.load(f)

        # PIL image -> float tensor in [-1, 1].
        self.transform = Compose([
            ToTensor(),
            Normalize([0.5], [0.5]),
        ])

    def __getitem__(self, idx):
        """Return {"img": tensor | None, "txt": str, "ref_imgs": list[tensor]}."""
        data_dict = self.data_dicts[idx]
        # Accept both the single-path and the multi-path manifest forms.
        image_paths = [data_dict["image_path"]] if "image_path" in data_dict else data_dict["image_paths"]
        txt = data_dict["prompt"]
        image_tgt_path = data_dict.get("image_tgt_path", None)

        ref_imgs = [
            Image.open(os.path.join(self.image_root, path)).convert("RGB")
            for path in image_paths
        ]
        ref_imgs = [self.transform(img) for img in ref_imgs]

        # The target image is optional (e.g. inference-only manifests).
        img = None
        if image_tgt_path is not None:
            img = Image.open(os.path.join(self.image_root, image_tgt_path)).convert("RGB")
            img = self.transform(img)

        return {
            "img": img,
            "txt": txt,
            "ref_imgs": ref_imgs,
        }

    def __len__(self):
        return len(self.data_dicts)

    def collate_fn(self, batch):
        """Collate samples: bucket-resize targets and each reference slot.

        Every sample in the batch must carry the same number of reference
        images; reference slot i across the batch is bucketed independently.
        Returns "img" as None when any sample lacks a target image.
        """
        img = [data["img"] for data in batch]
        txt = [data["txt"] for data in batch]
        ref_imgs = [data["ref_imgs"] for data in batch]
        assert all(len(ref_imgs[0]) == len(refs) for refs in ref_imgs), \
            "all samples in a batch must have the same number of reference images"

        n_ref = len(ref_imgs[0])

        # __getitem__ may return img=None (no "image_tgt_path"); stacking
        # None would crash, so propagate None for target-less batches.
        if any(i is None for i in img):
            img = None
        else:
            img = bucket_images(img, self.resolution)

        ref_imgs_new = []
        for i in range(n_ref):
            # Gather the i-th reference from every sample, bucket them together.
            ref_imgs_i = bucket_images([refs[i] for refs in ref_imgs], self.resolution_ref)
            ref_imgs_new.append(ref_imgs_i)

        return {
            "txt": txt,
            "img": img,
            "ref_imgs": ref_imgs_new,
        }
|
|
if __name__ == '__main__':
    import argparse
    from pprint import pprint

    parser = argparse.ArgumentParser()
    parser.add_argument("--json_file", type=str, default="datasets/fake_train_data.json")
    args = parser.parse_args()

    dataset = FluxPairedDatasetV2(args.json_file, 512)
    dataloader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)

    # Smoke test: print every collated batch, then pause in the debugger
    # so the tensors can be inspected interactively.
    for batch_idx, batch in enumerate(dataloader):
        pprint(batch_idx)
        pprint(batch)
        breakpoint()
|
|