|
|
|
|
|
|
|
|
from pathlib import Path |
|
|
|
|
|
import torch |
|
|
from PIL import Image |
|
|
from torch.utils.data import Dataset |
|
|
from torchvision import transforms |
|
|
|
|
|
|
|
|
class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.

    Args:
        instance_data_root: Directory containing the instance images.
        instance_prompt: Prompt tokenized alongside every instance image.
        tokenizer: Callable invoked as ``tokenizer(prompt, truncation=True,
            padding="max_length", max_length=tokenizer.model_max_length,
            return_tensors="pt")``; its result must expose ``input_ids``.
        class_data_root: Optional directory containing class images. It is
            created (with parents) when it does not exist yet.
        class_prompt: Prompt tokenized alongside every class image.
        size: Target size passed to the resize and crop transforms.
        center_crop: Use ``CenterCrop`` when true, ``RandomCrop`` otherwise.

    Raises:
        ValueError: If ``instance_data_root`` does not exist, contains no
            files, or if ``class_data_root`` is given but contains no files.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError(f"Instance images root doesn't exist: {self.instance_data_root}")

        self.instance_images_path = self._list_files(self.instance_data_root)
        self.num_instance_images = len(self.instance_images_path)
        # Fail fast: __getitem__ indexes with `index % num_instance_images`,
        # which would otherwise surface later as a ZeroDivisionError.
        if self.num_instance_images == 0:
            raise ValueError(f"No instance images found in {self.instance_data_root}")
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = self._list_files(self.class_data_root)
            self.num_class_images = len(self.class_images_path)
            # Same reasoning as above: an empty class directory can never be
            # sampled from, so report it here instead of on first access.
            if self.num_class_images == 0:
                raise ValueError(f"No class images found in {self.class_data_root}")
            # The longer of the two lists defines the epoch length; the
            # shorter one wraps around through the modulo in __getitem__.
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    @staticmethod
    def _list_files(root):
        """Return the regular files directly inside *root*, sorted by path."""
        # iterdir() yields entries in arbitrary order and includes
        # sub-directories; sorting makes the index -> image mapping
        # reproducible and the filter avoids opening a directory as an image.
        return sorted(entry for entry in root.iterdir() if entry.is_file())

    @staticmethod
    def _load_rgb(path):
        """Open the image at *path* and return a detached RGB copy."""
        # convert() returns a new, fully loaded image, so the underlying file
        # handle can be released as soon as the `with` block exits.
        with Image.open(path) as image:
            return image.convert("RGB")

    def _tokenize(self, prompt):
        """Tokenize *prompt* to ``model_max_length`` and return its ``input_ids``."""
        return self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

    def __len__(self):
        """Return the larger of the instance and class image counts."""
        return self._length

    def __getitem__(self, index):
        """
        Build the training example for *index*.

        Returns a dict with ``instance_images`` and ``instance_prompt_ids``
        and, when a class data root was given, ``class_images`` and
        ``class_prompt_ids``. Indices wrap around each image list separately.
        """
        example = {}
        instance_path = self.instance_images_path[index % self.num_instance_images]
        example["instance_images"] = self.image_transforms(self._load_rgb(instance_path))
        example["instance_prompt_ids"] = self._tokenize(self.instance_prompt)

        if self.class_data_root is not None:
            class_path = self.class_images_path[index % self.num_class_images]
            example["class_images"] = self.image_transforms(self._load_rgb(class_path))
            example["class_prompt_ids"] = self._tokenize(self.class_prompt)

        return example
|
|
|
|
|
|
|
|
def collate_fn(examples, with_prior_preservation=False):
    """
    Merge per-sample dicts into a single training batch.

    Instance entries always come first. When ``with_prior_preservation`` is
    true, the class entries of every example are appended after them, so
    both halves travel through the model in one forward pass.

    Args:
        examples: Sequence of dicts with ``instance_prompt_ids`` and
            ``instance_images`` (plus ``class_prompt_ids`` and
            ``class_images`` when prior preservation is enabled).
        with_prior_preservation: Whether to append the class entries.

    Returns:
        Dict with ``input_ids`` (prompt id tensors concatenated along dim 0)
        and ``pixel_values`` (image tensors stacked along a new dim 0, as
        contiguous float tensors).
    """
    key_pairs = [("instance_prompt_ids", "instance_images")]
    if with_prior_preservation:
        key_pairs.append(("class_prompt_ids", "class_images"))

    token_chunks = []
    image_chunks = []
    for ids_key, image_key in key_pairs:
        token_chunks.extend(sample[ids_key] for sample in examples)
        image_chunks.extend(sample[image_key] for sample in examples)

    stacked_images = torch.stack(image_chunks)
    stacked_images = stacked_images.to(memory_format=torch.contiguous_format).float()

    merged_ids = torch.cat(token_chunks, dim=0)

    return {"input_ids": merged_ids, "pixel_values": stacked_images}
|
|
|
|
|
|
|
|
class PromptDataset(Dataset):
    """
    A simple dataset to prepare the prompts to generate class images on multiple GPUs.

    Every item carries the same prompt together with the index it was
    requested at.
    """

    def __init__(self, prompt, num_samples):
        """Store the shared *prompt* and the number of items to expose."""
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, index):
        """Return ``{"prompt": <prompt>, "index": index}`` for any *index*."""
        return {"prompt": self.prompt, "index": index}
|
|
|