import os
import random
import shutil
from io import BytesIO
from pathlib import Path

import numpy as np
import openai
import regex as re
import requests
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm

from diffusers import DPMSolverMultistepScheduler
# ImageNet-style preprocessing used to feed images to the SSCD
# copy-detection model in filter() below
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],
)
small_288 = transforms.Compose([
    transforms.Resize(288),
    transforms.ToTensor(),
    normalize,
])
def collate_fn(examples, with_prior_preservation):
    input_ids = [example["instance_prompt_ids"] for example in examples]
    input_anchor_ids = [example["instance_anchor_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]
    mask = [example["mask"] for example in examples]
    # Concat class and instance examples for prior preservation.
    # We do this to avoid doing two forward passes.
    if with_prior_preservation:
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]
        mask += [example["class_mask"] for example in examples]

    input_ids = torch.cat(input_ids, dim=0)
    input_anchor_ids = torch.cat(input_anchor_ids, dim=0)
    pixel_values = torch.stack(pixel_values)
    mask = torch.stack(mask)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    mask = mask.to(memory_format=torch.contiguous_format).float()

    batch = {
        "input_ids": input_ids,
        "input_anchor_ids": input_anchor_ids,
        "pixel_values": pixel_values,
        "mask": mask.unsqueeze(1),
    }
    return batch
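
# Usage sketch (hypothetical; `dataset` would be a CustomDiffusionDataset as
# defined below, with prior preservation enabled):
#
#   loader = torch.utils.data.DataLoader(
#       dataset, batch_size=2, shuffle=True,
#       collate_fn=lambda examples: collate_fn(examples, True),
#   )
#   batch = next(iter(loader))
#   # batch["pixel_values"]: (4, 3, size, size) -- instance rows, then class rows
#   # batch["mask"]:         (4, 1, size // 8, size // 8)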
class PromptDataset(Dataset):
    """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt[index % len(self.prompt)]
        example["index"] = index
        return example
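
# Usage sketch (hypothetical): PromptDataset just wraps a list of prompts so
# that class-image generation can be sharded across GPUs, e.g. via accelerate:
#
#   sample_dataset = PromptDataset(["a photo of a cat"], num_samples=200)
#   sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=4)
#   sample_dataloader = accelerator.prepare(sample_dataloader)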
class CustomDiffusionDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        concepts_list,
        concept_type,
        tokenizer,
        size=512,
        center_crop=False,
        with_prior_preservation=False,
        num_class_images=200,
        hflip=False,
        aug=True,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer
        self.interpolation = Image.BILINEAR
        self.aug = aug
        self.concept_type = concept_type

        self.instance_images_path = []
        self.class_images_path = []
        self.with_prior_preservation = with_prior_preservation
        for concept in concepts_list:
            # "instance_data_dir" and "instance_prompt" are text files listing
            # one image path / one caption per line
            with open(concept["instance_data_dir"], "r") as f:
                inst_images_path = f.read().splitlines()
            with open(concept["instance_prompt"], "r") as f:
                inst_prompt = f.read().splitlines()
            inst_img_path = [
                (x, y, concept["caption_target"]) for (x, y) in zip(inst_images_path, inst_prompt)
            ]
            self.instance_images_path.extend(inst_img_path)

            if with_prior_preservation:
                class_data_root = Path(concept["class_data_dir"])
                if os.path.isdir(class_data_root):
                    # a directory of class images sharing a single prompt string
                    class_images_path = list(class_data_root.iterdir())
                    class_prompt = [concept["class_prompt"] for _ in range(len(class_images_path))]
                else:
                    # text files listing class image paths and per-image prompts
                    with open(class_data_root, "r") as f:
                        class_images_path = f.read().splitlines()
                    with open(concept["class_prompt"], "r") as f:
                        class_prompt = f.read().splitlines()
                class_img_path = [(x, y) for (x, y) in zip(class_images_path, class_prompt)]
                self.class_images_path.extend(class_img_path[:num_class_images])

        random.shuffle(self.instance_images_path)
        self.num_instance_images = len(self.instance_images_path)
        self.num_class_images = len(self.class_images_path)
        self._length = max(self.num_class_images, self.num_instance_images)
        self.flip = transforms.RandomHorizontalFlip(0.5 * hflip)

        self.image_transforms = transforms.Compose(
            [
                self.flip,
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length
    def preprocess(self, image, scale, resample):
        outer, inner = self.size, scale
        if scale > self.size:
            outer, inner = scale, self.size
        top, left = np.random.randint(0, outer - inner + 1), np.random.randint(0, outer - inner + 1)
        image = image.resize((scale, scale), resample=resample)
        image = np.array(image).astype(np.uint8)
        image = (image / 127.5 - 1.0).astype(np.float32)
        instance_image = np.zeros((self.size, self.size, 3), dtype=np.float32)
        # the mask is kept at latent resolution (the VAE downsamples by 8x)
        mask = np.zeros((self.size // 8, self.size // 8))
        if scale > self.size:
            # zoom-in: crop a random size x size window out of the upsampled image
            instance_image = image[top: top + inner, left: left + inner, :]
            mask = np.ones((self.size // 8, self.size // 8))
        else:
            # zoom-out: paste the resized image at a random offset on a black canvas
            instance_image[top: top + inner, left: left + inner, :] = image
            mask[top // 8 + 1: (top + scale) // 8 - 1, left // 8 + 1: (left + scale) // 8 - 1] = 1.
        return instance_image, mask
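
    # Worked example of preprocess (size=512, so the latent mask is 64x64):
    # with scale=256 the resized image is pasted at a random (top, left) in
    # [0, 256]^2 on a black 512x512 canvas, and ones cover the pasted 32x32
    # latent region shrunk by one cell per side; with scale=640 a random
    # 512x512 window is cropped out of the upsampled image and the mask is
    # all ones.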
    def __getprompt__(self, instance_prompt, instance_target):
        if self.concept_type == 'style':
            # sample one of the style templates; as originally written, options
            # 1 and 2 were identical, so that phrasing is picked 2/3 of the time
            r = np.random.choice([0, 1, 2])
            if r == 0:
                instance_prompt = f'{instance_prompt}, in the style of {instance_target}'
            else:
                instance_prompt = f"in {instance_target}'s style, {instance_prompt}"
        elif self.concept_type == 'object':
            # caption_target is "anchor+target": swap the anchor word for the target
            anchor, target = instance_target.split('+')
            instance_prompt = instance_prompt.replace(anchor, target)
        elif self.concept_type == 'memorization':
            instance_prompt = instance_target.split('+')[1]
        return instance_prompt
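
    # Example rewrites (illustrative values only):
    #   style:        "a painting of a river" -> "a painting of a river, in the style of van gogh"
    #   object:       caption_target "cat+grumpy cat" turns "photo of a cat"
    #                 into "photo of a grumpy cat"
    #   memorization: caption_target "anchor+target prompt" returns "target prompt"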
    def __getitem__(self, index):
        example = {}
        instance_image, instance_prompt, instance_target = self.instance_images_path[index % self.num_instance_images]
        instance_image = Image.open(instance_image)
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        instance_image = self.flip(instance_image)

        # modify instance prompt according to the concept_type to include target concept
        # multiple style/object fine-tuning
        if ';' in instance_target:
            instance_target = instance_target.split(';')
            instance_target = instance_target[index % len(instance_target)]
        instance_anchor_prompt = instance_prompt
        instance_prompt = self.__getprompt__(instance_prompt, instance_target)

        # apply resize augmentation and create a valid image region mask
        random_scale = self.size
        if self.aug:
            random_scale = (
                np.random.randint(self.size // 3, self.size + 1)
                if np.random.uniform() < 0.66
                else np.random.randint(int(1.2 * self.size), int(1.4 * self.size))
            )
        instance_image, mask = self.preprocess(instance_image, random_scale, self.interpolation)

        if random_scale < 0.6 * self.size:
            instance_prompt = np.random.choice(["a far away ", "very small "]) + instance_prompt
        elif random_scale > self.size:
            instance_prompt = np.random.choice(["zoomed in ", "close up "]) + instance_prompt

        example["instance_images"] = torch.from_numpy(instance_image).permute(2, 0, 1)
        example["mask"] = torch.from_numpy(mask)
        example["instance_prompt_ids"] = self.tokenizer(
            instance_prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids
        example["instance_anchor_prompt_ids"] = self.tokenizer(
            instance_anchor_prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        if self.with_prior_preservation:
            class_image, class_prompt = self.class_images_path[index % self.num_class_images]
            class_image = Image.open(class_image)
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_mask"] = torch.ones_like(example["mask"])
            example["class_prompt_ids"] = self.tokenizer(
                class_prompt,
                truncation=True,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                return_tensors="pt",
            ).input_ids
        return example
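
# Shape of a concepts_list entry, as __init__ reads it (paths are
# hypothetical): "instance_data_dir" and "instance_prompt" are text files
# with one image path / one caption per line; "class_data_dir" may instead
# be a directory of images, in which case "class_prompt" is a single shared
# prompt string rather than a caption file.
#
#   concepts_list = [{
#       "instance_data_dir": "data/painting/images.txt",
#       "instance_prompt": "data/painting/caption.txt",
#       "caption_target": "van gogh",
#       "class_data_dir": "data/painting_gen/images.txt",
#       "class_prompt": "data/painting_gen/caption.txt",
#   }]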
def isimage(path):
    # return an explicit boolean (the original fell through to None)
    path = path.lower()
    return 'png' in path or 'jpg' in path or 'jpeg' in path
def filter(folder, impath, outpath=None, unfiltered_path=None, threshold=0.15,
           image_threshold=0.5, anchor_size=10, target_size=3, return_score=False):
    # SSCD copy-detection model (https://github.com/facebookresearch/sscd-copy-detection)
    model = torch.jit.load("./assets/sscd_imagenet_mixup.torchscript.pt")

    # collect candidate images and their captions
    if isinstance(folder, list):
        image_paths = folder
        image_captions = ["None" for _ in range(len(image_paths))]
    elif (Path(folder) / 'images.txt').exists():
        with open(f'{folder}/images.txt', "r") as f:
            image_paths = f.read().splitlines()
        with open(f'{folder}/caption.txt', "r") as f:
            image_captions = f.read().splitlines()
    else:
        image_paths = [os.path.join(str(folder), file_path)
                       for file_path in os.listdir(folder) if isimage(file_path)]
        image_captions = ["None" for _ in range(len(image_paths))]

    # embed the reference (memorized) image
    batch = small_288(Image.open(impath).convert('RGB')).unsqueeze(0)
    embedding_target = model(batch)[0, :]

    filtered_paths = []
    filtered_captions = []
    unfiltered_paths = []
    unfiltered_captions = []
    count_dict = {}
    for im, c in zip(image_paths, image_captions):
        if c not in count_dict:
            count_dict[c] = 0
        if isinstance(folder, list):
            batch = small_288(im).unsqueeze(0)
        else:
            batch = small_288(Image.open(im).convert('RGB')).unsqueeze(0)
        embedding = model(batch)[0, :]
        # dot product of SSCD descriptors as a similarity score
        diff_sscd = (embedding * embedding_target).sum()
        if diff_sscd <= image_threshold:
            filtered_paths.append(im)
            filtered_captions.append(c)
            count_dict[c] += 1
        else:
            unfiltered_paths.append(im)
            unfiltered_captions.append(c)

    # only return the fraction of images flagged as copies
    if return_score:
        score = len(unfiltered_paths) / (len(unfiltered_paths) + len(filtered_paths))
        return score

    os.makedirs(outpath, exist_ok=True)
    os.makedirs(f'{outpath}/samples', exist_ok=True)
    with open(f'{outpath}/caption.txt', 'w') as f:
        for each in filtered_captions:
            f.write(each.strip() + '\n')
    with open(f'{outpath}/images.txt', 'w') as f:
        for each in filtered_paths:
            f.write(each.strip() + '\n')
            imbase = Path(each).name
            shutil.copy(each, f'{outpath}/samples/{imbase}')

    print('++++++++++++++++++++++++++++++++++++++++++++++++')
    print('+ Filter Summary +')
    print(f'+ Retained images: {len(filtered_paths)}')
    print(f'+ Filtered images: {len(unfiltered_paths)}')
    print('++++++++++++++++++++++++++++++++++++++++++++++++')

    # the most frequent surviving captions become anchors, the rarest become targets
    sorted_list = sorted(list(count_dict.items()), key=lambda x: x[1], reverse=True)
    anchor_prompts = [c[0] for c in sorted_list[:anchor_size]]
    target_prompts = [c[0] for c in sorted_list[-target_size:]]
    return anchor_prompts, target_prompts, len(filtered_paths)
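
# Usage sketch for filter() (hypothetical paths). The dot product above acts
# as a similarity score (SSCD descriptors are typically compared this way),
# so image_threshold=0.5 keeps images whose similarity to the reference image
# is at most 0.5:
#
#   anchors, targets, n_kept = filter(
#       Path("generated_samples"),      # folder with images.txt / caption.txt
#       "assets/memorized.jpg",         # reference (memorized) image
#       outpath="generated_samples_filtered",
#   )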
def getanchorprompts(pipeline, accelerator, class_prompt, concept_type,
                     class_images_dir, api_key, num_class_images=200, mem_impath=None):
    openai.api_key = api_key
    class_prompt_collection = []
    caption_target = []
    if concept_type == 'object':
        # the original overwrote the system message with the user message;
        # keep both so the system prompt actually takes effect
        messages = [
            {"role": "system", "content": "You can describe any image via text and provide captions for a wide variety of images that are possible to generate."},
            {"role": "user", "content": f'Generate {num_class_images} captions for images containing a {class_prompt}. The caption should also contain the word "{class_prompt}".'},
        ]
        while True:
            completion = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages=messages,
            )
            class_prompt_collection += [
                x for x in completion.choices[0].message.content.lower().split('\n')
                if class_prompt in x
            ]
            messages.append({"role": "assistant", "content": completion.choices[0].message.content})
            messages.append({"role": "user", "content": f"Generate {num_class_images - len(class_prompt_collection)} more captions"})
            if len(class_prompt_collection) >= num_class_images:
                break
        class_prompt_collection = clean_prompt(class_prompt_collection)[:num_class_images]

    elif concept_type == 'memorization':
        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
        num_prompts_firstpass = 5
        num_prompts_secondpass = 2
        threshold = 0.3
        # Generate num_prompts_firstpass paraphrases which generate different
        # content at least (1 - threshold) of the time.
        os.makedirs(class_images_dir / 'temp/', exist_ok=True)
        class_prompt_collection_counter = []
        caption_target = []
        prev_captions = []
        messages = [{"role": "user", "content": f"Generate {4 * num_prompts_firstpass} different paraphrases of the caption: {class_prompt}. Preserve the meaning when paraphrasing."}]
        while True:
            completion = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages=messages,
            )
            class_prompt_collection_ = [
                x.strip() for x in completion.choices[0].message.content.lower().split('\n')
                if x.strip() != ''
            ]
            class_prompt_collection_ = clean_prompt(class_prompt_collection_)
            for prompt in tqdm(
                class_prompt_collection_,
                desc="Generating anchor and target prompts",
                disable=not accelerator.is_local_main_process,
            ):
                print(f'Prompt: {prompt}')
                images = pipeline([prompt] * 10, num_inference_steps=25).images
                score = filter(images, mem_impath, return_score=True)
                print(f'Memorization rate: {score}')
                if score <= threshold and prompt not in class_prompt_collection and len(class_prompt_collection) < num_prompts_firstpass:
                    # low memorization rate -> usable anchor prompt
                    class_prompt_collection += [prompt]
                    class_prompt_collection_counter += [score]
                elif score >= 0.6 and prompt not in caption_target and len(caption_target) < 2:
                    # high memorization rate -> candidate target prompt
                    caption_target += [prompt]
                if len(class_prompt_collection) >= num_prompts_firstpass and len(caption_target) >= 2:
                    break
            if len(class_prompt_collection) >= num_prompts_firstpass:
                break
            prev_captions += class_prompt_collection_
            prev_captions_ = ','.join(prev_captions[-40:])
            messages = [{"role": "user", "content": f"Generate {4 * (num_prompts_firstpass - len(class_prompt_collection))} different paraphrases of the caption: {class_prompt}. Preserve the meaning the most when paraphrasing. Also make sure that the new captions are different from the following captions: {prev_captions_[:4000]}"}]

        # Generate more paraphrases using the captions we retrieved above.
        for prompt in class_prompt_collection[:num_prompts_firstpass]:
            completion = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": f"Generate {num_prompts_secondpass} different paraphrases of: {prompt}."}],
            )
            class_prompt_collection += clean_prompt(
                [x.strip() for x in completion.choices[0].message.content.lower().split('\n') if x.strip() != ''])
        for prompt in tqdm(class_prompt_collection[num_prompts_firstpass:], desc="Memorization rate for final prompts"):
            images = pipeline([prompt] * 10, num_inference_steps=25).images
            class_prompt_collection_counter += [filter(images, mem_impath, return_score=True)]

        # select the ten least memorized prompts as anchors and the most memorized as targets
        class_prompt_collection = sorted(
            zip(class_prompt_collection, class_prompt_collection_counter), key=lambda x: x[1])
        caption_target += [x for (x, y) in class_prompt_collection if y >= 0.6]
        class_prompt_collection = [x for (x, y) in class_prompt_collection if y <= threshold][:10]

    print("Anchor prompts:", class_prompt_collection)
    print("Target prompts:", caption_target)
    return class_prompt_collection, ';*+'.join(caption_target)
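
# Usage sketch (hypothetical objects; `pipeline` is a diffusers text-to-image
# pipeline and `accelerator` comes from accelerate):
#
#   anchor_prompts, target_prompts = getanchorprompts(
#       pipeline, accelerator,
#       class_prompt="new york city, painting",
#       concept_type="memorization",
#       class_images_dir=Path("class_images"),
#       api_key=os.environ["OPENAI_API_KEY"],
#       mem_impath="assets/memorized.jpg",
#   )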
def clean_prompt(class_prompt_collection):
    # strip digit runs (e.g. list numbering from GPT output) and leading dots,
    # then tidy whitespace and stray quotes
    class_prompt_collection = [re.sub(r"[0-9]+", "", prompt) for prompt in class_prompt_collection]
    class_prompt_collection = [re.sub(r"^\.+", "", prompt) for prompt in class_prompt_collection]
    class_prompt_collection = [x.strip() for x in class_prompt_collection]
    class_prompt_collection = [x.replace('"', '') for x in class_prompt_collection]
    return class_prompt_collection
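
# Example: clean_prompt(['1. "a photo of a dog"']) -> ['a photo of a dog']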
def safe_dir(dir):
    if not dir.exists():
        dir.mkdir()
    return dir