# Authors: Hui Ren (rhfeiyang.github.io)
import random
import torch.utils.data as data
from PIL import Image
import os
import torch
# from tqdm import tqdm


class ImageSet(data.Dataset):
    def __init__(self, folder, transform=None, keep_in_mem=True, caption=None):
        self.path = folder
        self.transform = transform
        self.caption_path = None
        self.images = []
        self.captions = []
        self.keep_in_mem = keep_in_mem
        if not isinstance(folder, list):
            # Collect all png/jpg image files in the folder
            self.image_files = [file for file in os.listdir(folder) if file.endswith((".png", ".jpg"))]
            self.image_files.sort()
        else:
            # A list of pre-loaded images was passed in directly
            self.images = folder
        if not isinstance(caption, list):
            if caption not in [None, "", "None"]:
                # Captions live in a parallel folder of .txt files named after the images
                self.caption_path = caption
                self.caption_files = [os.path.join(caption, file.replace(".png", ".txt").replace(".jpg", ".txt"))
                                      for file in self.image_files]
                self.caption_files.sort()
        else:
            # A list of captions was passed in directly
            self.caption_path = True
            self.captions = caption
        if keep_in_mem:
            # Pre-load all images (and captions, if any) into memory
            if len(self.images) == 0:
                for file in self.image_files:
                    img = self.load_image(os.path.join(self.path, file))
                    self.images.append(img)
            if len(self.captions) == 0 and self.caption_path is not None:
                self.captions = [self.load_caption(file) for file in self.caption_files]
        else:
            self.images = None

    def limit_num(self, n):
        # Intentionally disabled; the code below is unreachable until the raise is removed
        raise NotImplementedError
        assert n <= len(self), f"n should be less than the length of the dataset {len(self)}"
        self.image_files = self.image_files[:n]
        self.caption_files = self.caption_files[:n]
        if self.keep_in_mem:
            self.images = self.images[:n]
            self.captions = self.captions[:n]
        print(f"Dataset limited to {n}")

    def __len__(self):
        # self.images may be None (keep_in_mem=False) or empty; fall back to the file list
        if self.images:
            return len(self.images)
        return len(self.image_files)

    def load_image(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f).convert('RGB')
        return img

    def load_caption(self, path):
        # One caption per non-empty line in the .txt file
        with open(path, 'r') as f:
            caption = f.readlines()
        caption = [line.strip() for line in caption if len(line.strip()) > 0]
        return caption

    def __getitem__(self, index):
        if self.images:
            img = self.images[index]
        else:
            img = self.load_image(os.path.join(self.path, self.image_files[index]))
        if self.caption_path is not None or len(self.captions) != 0:
            if len(self.captions) != 0:
                caption = self.captions[index]
            else:
                caption = self.load_caption(self.caption_files[index])
            ret = {"image": img, "caption": caption, "id": index}
        else:
            ret = {"image": img, "id": index}
        # The transform receives the whole sample dict, not just the image
        if self.transform is not None:
            ret = self.transform(ret)
        return ret

    def subsample(self, n: int = 10):
        if n is None or n == -1:
            return self
        ori_len = len(self)
        assert n <= ori_len
        # Equal-interval subsample; keep caption lists in sync with the image lists
        step = ori_len // n
        self.image_files = self.image_files[::step][:n]
        if hasattr(self, "caption_files"):
            self.caption_files = self.caption_files[::step][:n]
        if self.keep_in_mem:
            self.images = self.images[::step][:n]
            if self.captions:
                self.captions = self.captions[::step][:n]
        print(f"Dataset subsampled from {ori_len} to {len(self)}")
        return self

    def with_transform(self, transform):
        self.transform = transform
        return self

    @staticmethod
    def collate_fn(examples):
        images = [example["image"] for example in examples]
        ids = [example["id"] for example in examples]
        if "caption" in examples[0]:
            # Each image may have several captions; sample one at random
            captions = [random.choice(example["caption"]) for example in examples]
            return {"images": images, "captions": captions, "id": ids}
        return {"images": images, "id": ids}
class ImagePair(ImageSet):
    """Pairs of images (e.g. photo/sketch) with identical file names in two folders."""

    def __init__(self, folder1, folder2, transform=None, keep_in_mem=True):
        self.path1 = folder1
        self.path2 = folder2
        self.transform = transform
        # Collect all png/jpg files from folder1; folder2 is assumed to mirror it
        self.image_files = [file for file in os.listdir(folder1) if file.endswith((".png", ".jpg"))]
        self.image_files.sort()
        self.keep_in_mem = keep_in_mem
        if keep_in_mem:
            self.images = []
            for file in self.image_files:
                img1 = self.load_image(os.path.join(self.path1, file))
                img2 = self.load_image(os.path.join(self.path2, file))
                self.images.append((img1, img2))
        else:
            self.images = None

    def __getitem__(self, index):
        if self.keep_in_mem:
            img1, img2 = self.images[index]
        else:
            img1 = self.load_image(os.path.join(self.path1, self.image_files[index]))
            img2 = self.load_image(os.path.join(self.path2, self.image_files[index]))
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return {"image1": img1, "image2": img2, "id": index}

    @staticmethod
    def collate_fn(examples):
        images1 = [example["image1"] for example in examples]
        images2 = [example["image2"] for example in examples]
        # images1 = torch.stack(images1)
        # images2 = torch.stack(images2)
        ids = [example["id"] for example in examples]
        return {"image1": images1, "image2": images2, "id": ids}

    def push_to_huggingface(self, hug_folder):
        from datasets import Dataset
        from datasets import Image as HugImage
        photo_path = [os.path.join(self.path1, file) for file in self.image_files]
        sketch_path = [os.path.join(self.path2, file) for file in self.image_files]
        dataset = Dataset.from_dict({"photo": photo_path, "sketch": sketch_path, "file_name": self.image_files})
        dataset = dataset.cast_column("photo", HugImage())
        dataset = dataset.cast_column("sketch", HugImage())
        dataset.push_to_hub(hug_folder, private=True)
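
# Usage sketch (illustrative only: the folder paths are hypothetical and must
# contain identically named files; pushing requires `datasets` and HF credentials):
#
#   pairs = ImagePair("./data/photos", "./data/sketches", keep_in_mem=False)
#   sample = pairs[0]  # {"image1": <photo>, "image2": <sketch>, "id": 0}
#   # pairs.push_to_huggingface("user/photo-sketch-pairs")
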
class ImageClass(ImageSet):
    """Image classification set: each folder in `folders` holds one class."""

    def __init__(self, folders: list, transform=None, keep_in_mem=True):
        self.paths = folders
        self.transform = transform
        self.image_files = []
        self.keep_in_mem = keep_in_mem
        # Collect (path, label) pairs; the label is the folder's index in `folders`
        for i, folder in enumerate(folders):
            self.image_files += [(os.path.join(folder, file), i)
                                 for file in os.listdir(folder) if file.endswith((".png", ".jpg"))]
        if keep_in_mem:
            self.images = []
            print("Loading images to memory")
            for file in self.image_files:
                img = self.load_image(file[0])
                self.images.append((img, file[1]))
            print("Loading images to memory done")
        else:
            self.images = None

    def __getitem__(self, index):
        if self.keep_in_mem:
            img, label = self.images[index]
        else:
            img_path, label = self.image_files[index]
            img = self.load_image(img_path)
        if self.transform is not None:
            img = self.transform(img)
        return {"image": img, "label": label, "id": index}

    @staticmethod
    def collate_fn(examples):
        images = [example["image"] for example in examples]
        labels = [example["label"] for example in examples]
        ids = [example["id"] for example in examples]
        return {"images": images, "labels": labels, "id": ids}

if __name__ == "__main__":
    # dataset = ImagePair("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_50",
    #                     "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/sketch_50", keep_in_mem=False)
    # dataset.push_to_huggingface("rhfeiyang/photo-sketch-pair-50")
    dataset = ImagePair("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500",
                        "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/sketch_500",
                        keep_in_mem=True)
    # dataset.push_to_huggingface("rhfeiyang/photo-sketch-pair-500")
    # ret = dataset[0]
    # print(len(dataset))
    from torchvision import transforms

    train_transforms = transforms.Compose(
        [
            transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(256),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    dataset = dataset.with_transform(train_transforms)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True,
                                             num_workers=4, collate_fn=ImagePair.collate_fn)
    batch = next(iter(dataloader))