| from PIL import Image, ImageFilter, ImageDraw |
| import cv2 |
| import numpy as np |
| from torch.utils.data import Dataset |
| import torchvision.transforms as T |
| import random |
|
|
|
|
class Subject200KDataset(Dataset):
    """Subject-driven pairs cut out of side-by-side composite images.

    Each base item provides ``item["image"]`` (a composite holding two square
    panels of ``image_size`` pixels laid out horizontally with ``padding``
    pixels around/between them) and ``item["description"]`` with the keys
    ``"description_0"`` and ``"description_1"``.

    Every base item yields two samples: an even index uses the left panel as
    the target (right panel as condition, caption ``description_0``); an odd
    index swaps the roles (caption ``description_1``).

    Args:
        base_dataset: Indexable source of composite items.
        condition_size: Side length of the square condition image.
        target_size: Side length of the square target image.
        image_size: Side length of one panel inside the composite.
        padding: Border width around the panels, in pixels.
        condition_type: Tag copied into every sample.
        drop_text_prob: Probability of replacing the caption with ``""``.
        drop_image_prob: Probability of replacing the condition with a black
            image.
        return_pil_image: If true, also return the raw composite under
            ``"pil_image"``.
    """

    def __init__(
        self,
        base_dataset,
        condition_size: int = 512,
        target_size: int = 512,
        image_size: int = 512,
        padding: int = 0,
        condition_type: str = "subject",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
    ):
        self.base_dataset = base_dataset

        # Geometry of the composite and of the produced images.
        self.image_size = image_size
        self.padding = padding
        self.condition_size = condition_size
        self.target_size = target_size

        # Sample-level options.
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image

        self.to_tensor = T.ToTensor()

    def __len__(self):
        # One sample per panel-as-target orientation.
        return 2 * len(self.base_dataset)

    def __getitem__(self, idx):
        base_idx, right_is_target = divmod(idx, 2)
        item = self.base_dataset[base_idx]
        composite = item["image"]

        pad = self.padding
        side = self.image_size
        left_panel = composite.crop((pad, pad, side + pad, side + pad))
        right_panel = composite.crop(
            (side + 2 * pad, pad, 2 * side + 2 * pad, side + pad)
        )

        if right_is_target:
            target_image, condition_img = right_panel, left_panel
            caption_key = "description_1"
        else:
            target_image, condition_img = left_panel, right_panel
            caption_key = "description_0"

        cond_edge = self.condition_size
        tgt_edge = self.target_size
        condition_img = condition_img.resize((cond_edge, cond_edge)).convert("RGB")
        target_image = target_image.resize((tgt_edge, tgt_edge)).convert("RGB")

        caption = item["description"][caption_key]

        # Both draws always happen, text first, so the RNG stream is fixed.
        drop_text = random.random() < self.drop_text_prob
        drop_image = random.random() < self.drop_image_prob
        if drop_text:
            caption = ""
        if drop_image:
            condition_img = Image.new("RGB", (cond_edge, cond_edge), (0, 0, 0))

        sample = {
            "image": self.to_tensor(target_image),
            "condition": self.to_tensor(condition_img),
            "condition_type": self.condition_type,
            "description": caption,
            # Floor division of the negated size, i.e. (-condition_size) // 16.
            "position_delta": np.array([0, (-self.condition_size) // 16]),
        }
        if self.return_pil_image:
            sample["pil_image"] = composite
        return sample
|
|
|
|
class ImageConditionDataset(Dataset):
    """Derive a conditioning image from each target image on the fly.

    Every base sample is read as ``sample["jpg"]`` (the image, used through
    PIL-style ``resize``/``convert``) and ``sample["json"]["prompt"]`` (the
    caption). The target is resized to ``target_size`` x ``target_size`` RGB
    and the condition is produced according to ``condition_type``:

    * ``"canny"``      - Canny edge map (thresholds 100/200).
    * ``"coloring"``   - grayscale copy of the image, stored as RGB.
    * ``"deblurring"`` - Gaussian-blurred copy, radius drawn from [1, 10].
    * ``"depth"``      - depth map produced by :attr:`depth_pipe`.
    * ``"depth_pred"`` - inverted task: the RGB image is the condition and the
      predicted depth map becomes the target; the caption gets a ``[depth] ``
      prefix.
    * ``"fill"``       - a random rectangle (or, with probability 1/2, its
      complement) of the image is kept and the rest is blacked out.
    * ``"sr"``         - resized copy; ``position_delta`` becomes
      ``[0, -condition_size // 16]``.

    All condition images are ``condition_size`` x ``condition_size``.

    Args:
        base_dataset: Indexable source of samples (see above).
        condition_size: Side length of the square condition image.
        target_size: Side length of the square target image.
        condition_type: One of the types listed above.
        drop_text_prob: Probability of replacing the caption with ``""``.
        drop_image_prob: Probability of replacing the condition with a black
            image.
        return_pil_image: If true, also return ``[target, condition]`` under
            ``"pil_image"``.
        position_scale: Returned as ``"position_scale"`` when not 1.0.

    Raises:
        ValueError: From ``__getitem__`` for an unknown ``condition_type``.
    """

    def __init__(
        self,
        base_dataset,
        condition_size: int = 512,
        target_size: int = 512,
        condition_type: str = "canny",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
        position_scale=1.0,
    ):
        self.base_dataset = base_dataset
        self.condition_size = condition_size
        self.target_size = target_size
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image
        self.position_scale = position_scale

        self.to_tensor = T.ToTensor()

    def __len__(self):
        return len(self.base_dataset)

    @property
    def depth_pipe(self):
        """Depth-estimation pipeline, built on first access and cached.

        Uses ``transformers.pipeline`` with the
        ``LiheYoung/depth-anything-small-hf`` model on CPU; the import is
        deferred so ``transformers`` is only needed for depth condition types.
        """
        if not hasattr(self, "_depth_pipe"):
            from transformers import pipeline

            self._depth_pipe = pipeline(
                task="depth-estimation",
                model="LiheYoung/depth-anything-small-hf",
                device="cpu",
            )
        return self._depth_pipe

    def _get_canny_edge(self, img, size=None):
        """Return a 3-channel Canny edge map of ``img``.

        ``img`` is first rescaled (aspect ratio preserved) so that its longer
        side equals ``size``; ``size`` defaults to ``self.condition_size``.
        """
        if size is None:
            size = self.condition_size
        resize_ratio = size / max(img.size)
        img = img.resize(
            (int(img.size[0] * resize_ratio), int(img.size[1] * resize_ratio))
        )
        img_np = np.array(img)
        img_gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        edges = cv2.Canny(img_gray, 100, 200)
        return Image.fromarray(edges).convert("RGB")

    def __getitem__(self, idx):
        # Index the base dataset once: each access may load/decode the image.
        sample = self.base_dataset[idx]
        image = sample["jpg"].resize((self.target_size, self.target_size)).convert("RGB")
        description = sample["json"]["prompt"]

        # NOTE(review): random.random() lies in [0, 1), so enable_scale is
        # always True and the first branch is unreachable. The draw is kept so
        # the RNG stream (and therefore seeded runs) is unchanged.
        enable_scale = random.random() < 1
        if not enable_scale:
            condition_size = int(self.condition_size * self.position_scale)
            position_scale = 1.0
        else:
            condition_size = self.condition_size
            position_scale = self.position_scale

        position_delta = np.array([0, 0])
        if self.condition_type == "canny":
            condition_img = self._get_canny_edge(image, condition_size)
        elif self.condition_type == "coloring":
            condition_img = (
                image.resize((condition_size, condition_size))
                .convert("L")
                .convert("RGB")
            )
        elif self.condition_type == "deblurring":
            blur_radius = random.randint(1, 10)
            condition_img = (
                image.convert("RGB")
                .filter(ImageFilter.GaussianBlur(blur_radius))
                .resize((condition_size, condition_size))
                .convert("RGB")
            )
        elif self.condition_type == "depth":
            condition_img = self.depth_pipe(image)["depth"].convert("RGB")
            condition_img = condition_img.resize((condition_size, condition_size))
        elif self.condition_type == "depth_pred":
            # Inverted task: the RGB image conditions, the depth map is the target.
            condition_img = image.resize((condition_size, condition_size))
            # Resize explicitly so the target keeps target_size whatever
            # resolution the pipeline returns.
            image = (
                self.depth_pipe(image)["depth"]
                .convert("RGB")
                .resize((self.target_size, self.target_size))
            )
            description = f"[depth] {description}"
        elif self.condition_type == "fill":
            w, h = image.size
            x1, x2 = sorted([random.randint(0, w), random.randint(0, w)])
            y1, y2 = sorted([random.randint(0, h), random.randint(0, h)])
            mask = Image.new("L", image.size, 0)
            ImageDraw.Draw(mask).rectangle([x1, y1, x2, y2], fill=255)
            if random.random() > 0.5:
                # Keep everything except the rectangle instead.
                mask = Image.eval(mask, lambda a: 255 - a)
            # Keep pixels where mask == 255, black elsewhere; then bring the
            # result to condition_size like every other condition type.
            condition_img = Image.composite(
                image, Image.new("RGB", image.size, (0, 0, 0)), mask
            ).resize((condition_size, condition_size))
        elif self.condition_type == "sr":
            condition_img = image.resize((condition_size, condition_size)).convert(
                "RGB"
            )
            position_delta = np.array([0, -condition_size // 16])
        else:
            raise ValueError(f"Condition type {self.condition_type} not implemented")

        # Classifier-free-guidance style dropout; both draws always happen.
        drop_text = random.random() < self.drop_text_prob
        drop_image = random.random() < self.drop_image_prob
        if drop_text:
            description = ""
        if drop_image:
            condition_img = Image.new(
                "RGB", (condition_size, condition_size), (0, 0, 0)
            )

        return {
            "image": self.to_tensor(image),
            "condition": self.to_tensor(condition_img),
            "condition_type": self.condition_type,
            "description": description,
            "position_delta": position_delta,
            **({"pil_image": [image, condition_img]} if self.return_pil_image else {}),
            **({"position_scale": position_scale} if position_scale != 1.0 else {}),
        }
|
|
| import os |
| from PIL import Image |
| from torch.utils.data import Dataset |
|
|
class SRBaseDataset(Dataset):
    """Paired low-/high-resolution images read from two sibling folders.

    Samples are enumerated from the regular files in ``<root_dir>/<lr_dir>``
    (sorted by name). For each LR file the ground-truth file is looked up in
    ``<root_dir>/<gt_dir>`` with the same stem (``lr_suffix`` removed,
    ``gt_suffix`` appended) and the same extension.

    Args:
        root_dir: Dataset root containing the LR and GT sub-folders.
        lr_dir: Name of the low-resolution sub-folder.
        gt_dir: Name of the ground-truth sub-folder.
        lr_suffix: Suffix carried by LR file stems (e.g. ``"x4"`` for
            ``0001x4.png``); it is stripped before building the GT name.
        gt_suffix: Suffix appended to the shared stem to form the GT name.

    Each item is ``{"lr": <PIL image>, "gt": <PIL image>}`` with pixel data
    already loaded and the underlying files closed.
    """

    def __init__(self, root_dir, lr_dir='LR', gt_dir='GT', lr_suffix='', gt_suffix=''):
        self.lr_root = os.path.join(root_dir, lr_dir)
        self.gt_root = os.path.join(root_dir, gt_dir)

        # Sorted so that indices are stable across runs and worker processes.
        self.filenames = sorted(
            f for f in os.listdir(self.lr_root)
            if os.path.isfile(os.path.join(self.lr_root, f))
        )

        self.lr_suffix = lr_suffix
        self.gt_suffix = gt_suffix

    def __len__(self):
        return len(self.filenames)

    def _resolve_paths(self, idx):
        """Return ``(lr_path, gt_path)`` for the ``idx``-th LR file."""
        filename = self.filenames[idx]
        stem, ext = os.path.splitext(filename)
        # Names come from the LR folder, so they already carry lr_suffix:
        # use the listed name for LR and strip the suffix for the GT lookup.
        if self.lr_suffix and stem.endswith(self.lr_suffix):
            stem = stem[:-len(self.lr_suffix)]
        lr_path = os.path.join(self.lr_root, filename)
        gt_path = os.path.join(self.gt_root, stem + self.gt_suffix + ext)
        return lr_path, gt_path

    @staticmethod
    def _load_image(path):
        """Open ``path``, read its pixel data and close the file handle."""
        # Image.open is lazy and keeps the file open; loading inside the
        # context manager avoids leaking descriptors in DataLoader workers.
        with Image.open(path) as img:
            img.load()
        return img

    def __getitem__(self, idx):
        lr_path, gt_path = self._resolve_paths(idx)
        lr = self._load_image(lr_path)
        gt = self._load_image(gt_path)
        return {'lr': lr, 'gt': gt}
|
|
|
|
class SRDataset(Dataset):
    """Super-resolution pairs: ground truth as target, upsampled LR as condition.

    Each base sample is read as ``sample["gt"]`` and ``sample["lr"]`` (images
    used through PIL-style ``resize``/``convert``, such as the dicts returned
    by ``SRBaseDataset``). The GT image is resized to ``target_size`` and the
    LR image is resized with bicubic resampling to ``condition_size``; both
    are converted to RGB. Captions are always empty.

    Args:
        base_dataset: Indexable source of ``{"lr", "gt"}`` samples.
        condition_size: Side length of the square condition image.
        target_size: Side length of the square target image.
        condition_type: Tag copied into every sample.
        drop_text_prob: Probability of the text-drop draw (the caption is
            already empty, so this only advances the RNG).
        drop_image_prob: Probability of replacing the condition with a black
            image.
        return_pil_image: If true, also return ``[target, condition]`` under
            ``"pil_image"``.
        position_scale: Returned as ``"position_scale"`` when not 1.0.
    """

    def __init__(
        self,
        base_dataset,
        condition_size: int = 512,
        target_size: int = 512,
        condition_type: str = "sr",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
        position_scale=1.0,
    ):
        self.base_dataset = base_dataset
        self.condition_size = condition_size
        self.target_size = target_size
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image
        self.position_scale = position_scale

        self.to_tensor = T.ToTensor()

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        # Fetch the pair once; indexing the base dataset may open/decode both
        # images, so doing it per field would double the I/O.
        sample = self.base_dataset[idx]
        image = sample["gt"].resize((self.target_size, self.target_size)).convert("RGB")
        description = ""

        # NOTE(review): random.random() lies in [0, 1), so enable_scale is
        # always True and the first branch is unreachable. The draw is kept so
        # the RNG stream (and therefore seeded runs) is unchanged.
        enable_scale = random.random() < 1
        if not enable_scale:
            condition_size = int(self.condition_size * self.position_scale)
            position_scale = 1.0
        else:
            condition_size = self.condition_size
            position_scale = self.position_scale

        position_delta = np.array([0, 0])
        condition_img = sample["lr"].resize(
            (condition_size, condition_size), resample=Image.BICUBIC
        ).convert("RGB")

        # Both draws always happen, text first, so the RNG stream is fixed.
        drop_text = random.random() < self.drop_text_prob
        drop_image = random.random() < self.drop_image_prob
        if drop_text:
            description = ""
        if drop_image:
            condition_img = Image.new(
                "RGB", (condition_size, condition_size), (0, 0, 0)
            )

        return {
            "image": self.to_tensor(image),
            "condition": self.to_tensor(condition_img),
            "condition_type": self.condition_type,
            "description": description,
            "position_delta": position_delta,
            **({"pil_image": [image, condition_img]} if self.return_pil_image else {}),
            **({"position_scale": position_scale} if position_scale != 1.0 else {}),
        }
|
|
class CartoonDataset(Dataset):
    """Pairs a cartoon-character condition image with a target image.

    Each base sample is read as ``data["condition"]`` and ``data["target"]``
    (images used through PIL-style ``resize``/``convert``). The caption is
    ``data["description"]`` when present and not ``None``; otherwise it is
    built from ``data["tags"][0]`` (looked up in ``TAG_DESCRIPTIONS``) and the
    ``facing_direction``/``pose`` entries of ``data["target_description"]``.

    Args:
        base_dataset: Indexable source of samples (see above).
        condition_size: Side length of the square condition image.
        target_size: Side length of the square target image.
        image_size: Stored but not used by this class.
        padding: Stored but not used by this class.
        condition_type: Tag copied into every sample.
        drop_text_prob: Probability of replacing the caption with ``""``.
        drop_image_prob: Probability of replacing the condition with a black
            image.
        return_pil_image: If true, also return ``[target, condition]`` under
            ``"pil_image"``.
    """

    # Character tag -> phrase inserted into the fallback caption.
    TAG_DESCRIPTIONS = {
        "lion": "lion like animal",
        "bear": "bear like animal",
        "gorilla": "gorilla like animal",
        "dog": "dog like animal",
        "elephant": "elephant like animal",
        "eagle": "eagle like bird",
        "tiger": "tiger like animal",
        "owl": "owl like bird",
        "woman": "woman",
        "parrot": "parrot like bird",
        "mouse": "mouse like animal",
        "man": "man",
        "pigeon": "pigeon like bird",
        "girl": "girl",
        "panda": "panda like animal",
        "crocodile": "crocodile like animal",
        "rabbit": "rabbit like animal",
        "boy": "boy",
        "monkey": "monkey like animal",
        "cat": "cat like animal",
    }

    def __init__(
        self,
        base_dataset,
        condition_size: int = 1024,
        target_size: int = 1024,
        image_size: int = 1024,
        padding: int = 0,
        condition_type: str = "cartoon",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
    ):
        self.base_dataset = base_dataset
        self.condition_size = condition_size
        self.target_size = target_size
        self.image_size = image_size
        self.padding = padding
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image

        self.to_tensor = T.ToTensor()

    def __len__(self):
        return len(self.base_dataset)

    def _default_description(self, data):
        """Build the templated caption from the sample's tag and pose metadata.

        Raises:
            KeyError: If the tag is not in ``TAG_DESCRIPTIONS`` or
                ``target_description`` lacks ``facing_direction``/``pose``.
        """
        tag = data["tags"][0]
        target_description = data["target_description"]
        return (
            f"Photo of a {self.TAG_DESCRIPTIONS[tag]} cartoon character in a white background. "
            f"Character is facing {target_description['facing_direction']}. "
            f"Character pose is {target_description['pose']}."
        )

    def __getitem__(self, idx):
        data = self.base_dataset[idx]
        condition_img = data["condition"].resize(
            (self.condition_size, self.condition_size)
        ).convert("RGB")
        target_image = data["target"].resize(
            (self.target_size, self.target_size)
        ).convert("RGB")

        # Build the templated caption only when the sample has none of its
        # own, so samples with unknown tags but explicit captions still work.
        description = data.get("description")
        if description is None:
            description = self._default_description(data)

        # Both draws always happen, text first, so the RNG stream is fixed.
        drop_text = random.random() < self.drop_text_prob
        drop_image = random.random() < self.drop_image_prob
        if drop_text:
            description = ""
        if drop_image:
            condition_img = Image.new(
                "RGB", (self.condition_size, self.condition_size), (0, 0, 0)
            )

        return {
            "image": self.to_tensor(target_image),
            "condition": self.to_tensor(condition_img),
            "condition_type": self.condition_type,
            "description": description,
            # NOTE(review): fixed offset, not derived from condition_size;
            # confirm this is intended.
            "position_delta": np.array([0, -16]),
            **({"pil_image": [target_image, condition_img]} if self.return_pil_image else {}),
        }
|
|