| | import os |
| | import numpy as np |
| | import PIL |
| | import torch |
| | from PIL import Image |
| | from torch.utils.data import Dataset, DataLoader, Sampler |
| | from torchvision import transforms |
| | from collections import defaultdict |
| | from random import shuffle, choices |
| |
|
| | import random |
| | import tqdm |
| | from modules import devices, shared |
| | import re |
| |
|
| | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution |
| |
|
# Leading run of digits and/or hyphens, plus any whitespace right after it.
# Removed from filename-derived captions so "0012-red hat" becomes "red hat".
re_numbers_at_start = re.compile(r"^[-\d]+\s*")
| |
|
| |
|
class DatasetEntry:
    """Plain record holding everything stored for one training image.

    Every field is optional and defaults to ``None``; the constructor performs
    no validation and simply stores what it is given.
    """

    def __init__(
        self,
        filename=None,
        filename_text=None,
        latent_dist=None,
        latent_sample=None,
        cond=None,
        cond_text=None,
        pixel_values=None,
        weight=None,
    ):
        # Source file and the caption text associated with it.
        self.filename = filename
        self.filename_text = filename_text

        # Optional per-element loss weight.
        self.weight = weight

        # Encoded image: either the distribution, a sample drawn from it, or both.
        self.latent_dist = latent_dist
        self.latent_sample = latent_sample

        # Prompt text and its encoded form.
        self.cond = cond
        self.cond_text = cond_text

        # Stored as given; nothing else in this record reads it.
        self.pixel_values = pixel_values
| |
|
| |
|
class PersonalizedBase(Dataset):
    """Dataset that pre-encodes every image of a directory into latents.

    Construction walks ``data_root``, encodes each readable image with
    ``model.encode_first_stage`` and keeps one ``DatasetEntry`` per image in
    ``self.dataset``.  Entry indices are additionally grouped by image size in
    ``self.groups`` so that a batch sampler can build same-size batches.
    """

    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
        """Load, encode and index all images found in ``data_root``.

        Args:
            data_root: Directory whose entries are each tried as an image.
            width, height: Size every image is resized to unless ``varsize``.
            repeats: Accepted for interface compatibility; not used here.
            flip_p: Probability for the ``RandomHorizontalFlip`` stored in
                ``self.flip`` (the transform is not applied in this class).
            placeholder_token: Replacement for ``[name]`` in prompt templates.
            model: Object providing ``encode_first_stage`` and
                ``get_first_stage_encoding``.
            cond_model: Callable applied to a one-element prompt list when
                ``include_cond`` is set.
            device: Device image tensors are moved to before encoding.
            template_file: Text file with one prompt template per line.
            include_cond: Pre-compute ``entry.cond`` (only done when neither
                tag shuffling nor tag drop-out is active).
            batch_size: Stored after being clamped to the dataset length.
            gradient_step: Stored after being clamped to
                ``length // batch_size``.
            shuffle_tags: Shuffle comma-separated tags each time a prompt is built.
            tag_drop_out: Probability of dropping each tag; ``0`` disables it.
            latent_sampling_method: ``"random"`` keeps the latent distribution
                and re-samples on every access; any other value stores one
                sample.  ``"deterministic"`` zeroes the distribution's std
                first and falls back to ``"once"`` if the encoder did not
                return a ``DiagonalGaussianDistribution``.
            varsize: Keep each image at its own size instead of resizing.
            use_weight: Build a per-latent loss weight from the image's alpha
                band (all ones for images without one).

        Raises:
            AssertionError: If ``data_root`` is unset, missing or empty, or if
                no image could be loaded.
            Exception: With message ``"interrupted"`` when
                ``shared.state.interrupted`` is set during preparation.
        """
        re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None

        self.placeholder_token = placeholder_token

        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

        self.dataset = []

        # Read templates as UTF-8, matching how caption files are read below,
        # instead of depending on the platform's default encoding.
        with open(template_file, "r", encoding="utf8") as file:
            lines = [x.strip() for x in file.readlines()]

        self.lines = lines

        assert data_root, 'dataset directory not specified'
        assert os.path.isdir(data_root), "Dataset directory doesn't exist"
        dir_entries = os.listdir(data_root)
        assert dir_entries, "Dataset directory is empty"

        self.image_paths = [os.path.join(data_root, file_path) for file_path in dir_entries]

        self.shuffle_tags = shuffle_tags
        self.tag_drop_out = tag_drop_out
        # Maps image size -> indices into self.dataset.
        groups = defaultdict(list)

        print("Preparing dataset...")
        for path in tqdm.tqdm(self.image_paths):
            alpha_channel = None
            if shared.state.interrupted:
                raise Exception("interrupted")
            try:
                image = Image.open(path)
                # Only an explicit alpha band is used as a weight source.
                if use_weight and 'A' in image.getbands():
                    alpha_channel = image.getchannel('A')
                image = image.convert('RGB')
                if not varsize:
                    image = image.resize((width, height), PIL.Image.BICUBIC)
            except Exception:
                # Deliberate best effort: every directory entry is tried, so
                # caption files and other non-images are simply skipped.
                continue

            text_filename = os.path.splitext(path)[0] + ".txt"
            filename = os.path.basename(path)

            if os.path.exists(text_filename):
                # A sibling .txt file takes precedence as the caption.
                with open(text_filename, "r", encoding="utf8") as file:
                    filename_text = file.read()
            else:
                # Otherwise derive the caption from the file name.
                filename_text = os.path.splitext(filename)[0]
                filename_text = re.sub(re_numbers_at_start, '', filename_text)
                if re_word:
                    tokens = re_word.findall(filename_text)
                    filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)

            # Scale 8-bit pixel values from [0, 255] to [-1, 1].
            npimage = np.array(image).astype(np.uint8)
            npimage = (npimage / 127.5 - 1.0).astype(np.float32)

            # HWC -> CHW before handing the image to the encoder.
            torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
            latent_sample = None

            with devices.autocast():
                latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))

            # A sample is drawn even for "random" sampling because its shape is
            # needed to size the weight map.
            if latent_sampling_method == "deterministic":
                if isinstance(latent_dist, DiagonalGaussianDistribution):
                    # Zero std makes sampling return the mean.
                    latent_dist.std = 0
                else:
                    latent_sampling_method = "once"
            latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)

            if use_weight and alpha_channel is not None:
                weight = self._weight_from_alpha(alpha_channel, latent_sample)
            elif use_weight:
                # No alpha band: use a ones map so weights can still be stacked.
                weight = torch.ones(latent_sample.shape)
            else:
                weight = None

            if latent_sampling_method == "random":
                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
            else:
                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)

            # With tag shuffling/drop-out the prompt is rebuilt on every access
            # in __getitem__, so nothing is cached here.
            if not (self.tag_drop_out != 0 or self.shuffle_tags):
                entry.cond_text = self.create_text(filename_text)

            if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
                with devices.autocast():
                    entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
            groups[image.size].append(len(self.dataset))
            self.dataset.append(entry)
            # Drop per-image references before the next iteration allocates.
            del torchdata
            del latent_dist
            del latent_sample
            del weight

        self.length = len(self.dataset)
        self.groups = list(groups.values())
        assert self.length > 0, "No images have been found in the dataset."
        self.batch_size = min(batch_size, self.length)
        self.gradient_step = min(gradient_step, self.length // self.batch_size)
        self.latent_sampling_method = latent_sampling_method

        if len(groups) > 1:
            print("Buckets:")
            for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
                print(f"  {w}x{h}: {len(ids)}")
            print()

    @staticmethod
    def _normalized_weight(npweight, channels):
        """Return a ``(channels, H, W)`` float32 weight with minimum 0 and mean 1.

        ``npweight`` is a 2-D ``(H, W)`` array that is repeated for every
        channel.  A constant input has nothing left after the minimum is
        subtracted; dividing by its zero mean would give NaNs, so a uniform
        all-ones weight is returned in that case.
        """
        plane = torch.from_numpy(np.ascontiguousarray(npweight, dtype=np.float32))
        weight = plane.unsqueeze(0).repeat(channels, 1, 1)
        weight -= weight.min()
        mean = float(weight.mean())
        if mean > 0:
            weight /= mean
        else:
            weight = torch.ones_like(weight)
        return weight

    @staticmethod
    def _weight_from_alpha(alpha_channel, latent_sample):
        """Resize a PIL alpha band to the latent's spatial size and normalize it."""
        # Assumes the latent sample is 3-D with layout (C, H, W), i.e. the
        # encoder keeps the axis order of its CHW input.
        channels, latent_height, latent_width = latent_sample.shape
        # PIL sizes are (width, height), the reverse of the tensor's (H, W).
        weight_img = alpha_channel.resize((latent_width, latent_height))
        return PersonalizedBase._normalized_weight(np.array(weight_img), channels)

    def create_text(self, filename_text):
        """Build a prompt from a randomly chosen template line.

        ``[filewords]`` is replaced by the comma-separated tags of
        ``filename_text`` (after optional random drop-out and shuffling) and
        ``[name]`` by ``self.placeholder_token``.
        """
        text = random.choice(self.lines)
        tags = filename_text.split(',')
        if self.tag_drop_out != 0:
            tags = [t for t in tags if random.random() > self.tag_drop_out]
        if self.shuffle_tags:
            random.shuffle(tags)
        text = text.replace("[filewords]", ','.join(tags))
        text = text.replace("[name]", self.placeholder_token)
        return text

    def __len__(self):
        """Return the number of successfully loaded images."""
        return self.length

    def __getitem__(self, i):
        """Return entry ``i``, refreshing its prompt and/or latent sample if configured.

        The stored entry object itself is updated in place and returned.
        """
        entry = self.dataset[i]
        if self.tag_drop_out != 0 or self.shuffle_tags:
            entry.cond_text = self.create_text(entry.filename_text)
        if self.latent_sampling_method == "random":
            # Draw a fresh sample from the stored distribution on every access.
            entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
        return entry
| |
|
| |
|
class GroupedBatchSampler(Sampler):
    """Batch sampler whose batches only contain indices from a single group.

    ``data_source.groups`` is a list of index lists.  Each group contributes a
    fixed number of full, non-overlapping batches (``self.base``); the batches
    still missing to reach ``len(data_source) // batch_size`` are drawn at
    random, with replacement, from groups picked according to ``self.probs``.
    """

    def __init__(self, data_source: "PersonalizedBase", batch_size: int):
        """Pre-compute per-group batch counts and random-batch probabilities.

        Args:
            data_source: Sized object exposing ``groups`` (list of index lists).
            batch_size: Number of indices per yielded batch.
        """
        super().__init__(data_source)

        n = len(data_source)
        self.groups = data_source.groups
        self.len = n_batch = n // batch_size
        # Share of the n_batch * batch_size sampled items attributed to each group.
        expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
        # Whole batches each group is guaranteed to contribute.
        self.base = [int(e) // batch_size for e in expected]
        self.n_rand_batches = nrb = n_batch - sum(self.base)
        # Leftover share of each group, used to pick the group of a random batch.
        self.probs = [e % batch_size / nrb / batch_size if nrb > 0 else 0 for e in expected]
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of batches produced by one iteration."""
        return self.len

    def __iter__(self):
        """Yield ``len(self)`` batches (lists of indices) in random order.

        Note that the group lists are shuffled in place on every iteration.
        """
        b = self.batch_size

        for g in self.groups:
            shuffle(g)

        batches = []
        # Emit exactly the pre-computed number of full batches per group so the
        # total (with the random batches below) always equals __len__.
        for g, n_full in zip(self.groups, self.base):
            batches.extend(g[i * b:(i + 1) * b] for i in range(n_full))
        for _ in range(self.n_rand_batches):
            rand_group = choices(self.groups, self.probs)[0]
            batches.append(choices(rand_group, k=b))

        shuffle(batches)

        yield from batches
| |
|
| |
|
class PersonalizedDataLoader(DataLoader):
    """DataLoader that batches by group and collates entries into batch objects."""

    def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
        """Wrap ``dataset`` with a ``GroupedBatchSampler`` of ``batch_size``.

        The collate function is chosen from ``latent_sampling_method``:
        ``"random"`` selects ``collate_wrapper_random``, anything else
        ``collate_wrapper``.
        """
        grouped_sampler = GroupedBatchSampler(dataset, batch_size)
        super().__init__(dataset, batch_sampler=grouped_sampler, pin_memory=pin_memory)
        # Assigned after construction, replacing the loader's default collate.
        self.collate_fn = collate_wrapper_random if latent_sampling_method == "random" else collate_wrapper
| |
|
| |
|
class BatchLoader:
    """Collated batch built from a sequence of dataset entries.

    Attributes set by the constructor:
        cond_text: list with each entry's ``cond_text``.
        cond: list with each entry's ``cond``.
        latent_sample: entries' latent samples stacked along a new first
            axis, with axis 1 squeezed out when it has size 1.
        weight: entries' weights stacked the same way, or ``None`` as soon as
            any entry has no weight.
    """

    def __init__(self, data):
        self.cond_text = [item.cond_text for item in data]
        self.cond = [item.cond for item in data]

        stacked_latents = torch.stack([item.latent_sample for item in data])
        self.latent_sample = stacked_latents.squeeze(1)

        weights = [item.weight for item in data]
        if any(w is None for w in weights):
            self.weight = None
        else:
            self.weight = torch.stack(weights).squeeze(1)

    def pin_memory(self):
        """Pin the stacked latent sample and return this batch."""
        pinned = self.latent_sample.pin_memory()
        self.latent_sample = pinned
        return self
| |
|
def collate_wrapper(batch):
    """Collate a batch of dataset entries into a ``BatchLoader``."""
    collated = BatchLoader(batch)
    return collated
| |
|
class BatchLoaderRandom(BatchLoader):
    """``BatchLoader`` variant whose ``pin_memory`` leaves the batch untouched.

    Construction is inherited unchanged from ``BatchLoader``.
    """

    def pin_memory(self):
        """Return this batch as-is, without pinning any tensor."""
        return self
| |
|
def collate_wrapper_random(batch):
    """Collate a batch of dataset entries into a ``BatchLoaderRandom``."""
    collated = BatchLoaderRandom(batch)
    return collated