| import os |
| import json |
| import albumentations |
| import numpy as np |
| from PIL import Image |
| from tqdm import tqdm |
| from torch.utils.data import Dataset |
|
|
| from taming.data.sflckr import SegmentationBase |
|
|
|
|
class Examples(SegmentationBase):
    """Segmentation dataset over the bundled COCO example files.

    Forwards fixed paths (``data/coco_examples.txt``, ``data/coco_images``,
    ``data/coco_segmentations``) together with ``n_labels=183`` and
    ``shift_segmentation=True`` to ``SegmentationBase``; only ``size``,
    ``random_crop`` and ``interpolation`` are configurable by the caller.
    """

    def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
        # Fixed locations and label settings of the example data.
        fixed_settings = dict(
            data_csv="data/coco_examples.txt",
            data_root="data/coco_images",
            segmentation_root="data/coco_segmentations",
            n_labels=183,
            shift_segmentation=True,
        )
        # Options the caller is allowed to override.
        caller_settings = dict(
            size=size,
            random_crop=random_crop,
            interpolation=interpolation,
        )
        super().__init__(**fixed_settings, **caller_settings)
|
|
|
|
class CocoBase(Dataset):
    """Base dataset for COCO (image, caption, segmentation) examples.

    Reads a COCO captions annotation file, maps every image id to its image
    path, its segmentation path and its list of captions, and builds an
    albumentations pipeline that is applied jointly to image and segmentation.

    Subclasses must implement ``get_split()``. A return value of
    ``"validation"`` selects a center crop; any other value selects a random
    crop.

    Args:
        size: Passed as ``max_size`` to ``albumentations.SmallestMaxSize``
            (or as height and width to ``albumentations.Resize`` when
            ``force_no_crop`` is set).
        dataroot: Directory that is joined with every image ``file_name``.
        datajson: Path to ``captions_train2017.json`` or
            ``captions_val2017.json``.
        onehot_segmentation: Return the segmentation one-hot encoded instead
            of as an RGB image scaled to ``[-1, 1]``. Requires
            ``use_stuffthing``.
        use_stuffthing: Read segmentation maps from
            ``data/cocostuffthings/...`` instead of
            ``data/coco/annotations/stuff_*_pixelmaps``.
        crop_size: Side length of the square crop; defaults to ``size``.
        force_no_crop: Resize to ``size`` x ``size`` and skip cropping.
        given_files: Optional container of segmentation PNG file names. When
            given, only images whose derived PNG name is contained in it are
            kept.

    Raises:
        NotImplementedError: If one-hot mode is requested without
            ``use_stuffthing``, or if ``get_split()`` is not overridden.
    """

    def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
                 crop_size=None, force_no_crop=False, given_files=None):
        self.split = self.get_split()
        self.size = size
        self.crop_size = size if crop_size is None else crop_size

        self.onehot = onehot_segmentation
        self.stuffthing = use_stuffthing
        if self.onehot and not self.stuffthing:
            # NotImplemented is a non-callable singleton; the exception class
            # is NotImplementedError.
            raise NotImplementedError("One hot mode is only supported for the "
                                      "stuffthings version because labels are stored "
                                      "a bit different.")

        data_json = datajson
        with open(data_json) as json_file:
            self.json_data = json.load(json_file)
        self.img_id_to_captions = dict()
        self.img_id_to_filepath = dict()
        self.img_id_to_segmentation_filepath = dict()

        assert data_json.split("/")[-1] in ["captions_train2017.json",
                                            "captions_val2017.json"]
        # The segmentation directory depends on both the split encoded in the
        # annotation file name and on the stuff/stuffthings variant.
        is_val = data_json.endswith("captions_val2017.json")
        if self.stuffthing:
            self.segmentation_prefix = (
                "data/cocostuffthings/val2017" if is_val else
                "data/cocostuffthings/train2017")
        else:
            self.segmentation_prefix = (
                "data/coco/annotations/stuff_val2017_pixelmaps" if is_val else
                "data/coco/annotations/stuff_train2017_pixelmaps")

        imagedirs = self.json_data["images"]
        self.labels = {"image_ids": list()}
        for imgdir in tqdm(imagedirs, desc="ImgToPath"):
            img_id = imgdir["id"]
            file_name = imgdir["file_name"]
            self.img_id_to_filepath[img_id] = os.path.join(dataroot, file_name)
            self.img_id_to_captions[img_id] = list()
            # Segmentation maps share the image's name but are stored as PNG.
            pngfilename = file_name.replace("jpg", "png")
            self.img_id_to_segmentation_filepath[img_id] = os.path.join(
                self.segmentation_prefix, pngfilename)
            # Without a filter every image is kept; otherwise only listed ones.
            if given_files is None or pngfilename in given_files:
                self.labels["image_ids"].append(img_id)

        capdirs = self.json_data["annotations"]
        for capdir in tqdm(capdirs, desc="ImgToCaptions"):
            # Each caption is stored as a one-element numpy string array.
            self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]]))

        self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
        if self.split == "validation":
            self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
        else:
            self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
        # Registering "segmentation" as an additional image target makes the
        # pipeline apply the same transform to image and segmentation.
        self.preprocessor = albumentations.Compose(
            [self.rescaler, self.cropper],
            additional_targets={"segmentation": "image"})
        if force_no_crop:
            # Replace rescale+crop by a plain resize to size x size.
            self.rescaler = albumentations.Resize(height=self.size, width=self.size)
            self.preprocessor = albumentations.Compose(
                [self.rescaler],
                additional_targets={"segmentation": "image"})

    def get_split(self):
        """Return the split name; must be overridden by subclasses."""
        raise NotImplementedError("Subclasses of CocoBase must implement get_split().")

    def __len__(self):
        """Return the number of selected image ids."""
        return len(self.labels["image_ids"])

    @staticmethod
    def _to_onehot(segmentation, n_labels=183):
        """One-hot encode an integer label map along a new trailing axis.

        Returns an ``int`` array of shape ``segmentation.shape + (n_labels,)``.
        """
        flatseg = np.ravel(segmentation)
        # np.bool was removed in NumPy 1.24; the builtin bool is the
        # equivalent dtype.
        onehot = np.zeros((flatseg.size, n_labels), dtype=bool)
        onehot[np.arange(flatseg.size), flatseg] = True
        return onehot.reshape(segmentation.shape + (n_labels,)).astype(int)

    def preprocess_image(self, image_path, segmentation_path):
        """Load, transform and normalize one image/segmentation pair.

        Args:
            image_path: Path of the image file.
            segmentation_path: Path of the segmentation PNG.

        Returns:
            Tuple ``(image, segmentation)``. ``image`` is float32 scaled from
            ``[0, 255]`` to ``[-1, 1]``. ``segmentation`` is scaled the same
            way in RGB mode, or is an ``int`` one-hot array with 183 channels
            in one-hot mode.
        """
        image = Image.open(image_path)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        image = np.array(image).astype(np.uint8)

        segmentation = Image.open(segmentation_path)
        # In one-hot mode the file's pixel values are used as labels, so the
        # segmentation must not be converted to RGB.
        if not self.onehot and not segmentation.mode == "RGB":
            segmentation = segmentation.convert("RGB")
        segmentation = np.array(segmentation).astype(np.uint8)
        if self.onehot:
            assert self.stuffthing
            assert segmentation.dtype == np.uint8
            # Shift every label up by one before encoding (uint8 arithmetic).
            segmentation = segmentation + 1

        processed = self.preprocessor(image=image, segmentation=segmentation)
        image, segmentation = processed["image"], processed["segmentation"]
        image = (image / 127.5 - 1.0).astype(np.float32)

        if self.onehot:
            assert segmentation.dtype == np.uint8
            segmentation = self._to_onehot(segmentation)
        else:
            segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
        return image, segmentation

    def __getitem__(self, i):
        """Return example ``i`` as a dict.

        Keys: ``image``, ``caption`` (a one-element list holding one caption
        picked at random with ``np.random.randint``), ``segmentation``,
        ``img_path``, ``seg_path`` and ``filename_`` (last path component of
        ``img_path``).
        """
        img_id = self.labels["image_ids"][i]
        img_path = self.img_id_to_filepath[img_id]
        seg_path = self.img_id_to_segmentation_filepath[img_id]
        image, segmentation = self.preprocess_image(img_path, seg_path)
        captions = self.img_id_to_captions[img_id]
        # Pick one of the image's captions at random.
        caption = captions[np.random.randint(0, len(captions))]
        example = {"image": image,
                   "caption": [str(caption[0])],
                   "segmentation": segmentation,
                   "img_path": img_path,
                   "seg_path": seg_path,
                   "filename_": img_path.split(os.sep)[-1]
                   }
        return example
|
|
|
|
class CocoImagesAndCaptionsTrain(CocoBase):
    """COCO 2017 training split served through ``CocoBase``.

    Binds the dataset to ``data/coco/train2017`` and
    ``data/coco/annotations/captions_train2017.json``; all remaining options
    are forwarded unchanged to ``CocoBase``.
    """

    def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False):
        # Collect the base-class arguments first, then forward them in one go.
        base_kwargs = dict(
            size=size,
            dataroot="data/coco/train2017",
            datajson="data/coco/annotations/captions_train2017.json",
            onehot_segmentation=onehot_segmentation,
            use_stuffthing=use_stuffthing,
            crop_size=crop_size,
            force_no_crop=force_no_crop,
        )
        super().__init__(**base_kwargs)

    def get_split(self):
        """Return the split name read by ``CocoBase.__init__``."""
        split_name = "train"
        return split_name
|
|
|
|
class CocoImagesAndCaptionsValidation(CocoBase):
    """COCO 2017 validation split served through ``CocoBase``.

    Binds the dataset to ``data/coco/val2017`` and
    ``data/coco/annotations/captions_val2017.json``; all remaining options,
    including the optional ``given_files`` filter, are forwarded unchanged to
    ``CocoBase``.
    """

    def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
                 given_files=None):
        # Collect the base-class arguments first, then forward them in one go.
        base_kwargs = dict(
            size=size,
            dataroot="data/coco/val2017",
            datajson="data/coco/annotations/captions_val2017.json",
            onehot_segmentation=onehot_segmentation,
            use_stuffthing=use_stuffthing,
            crop_size=crop_size,
            force_no_crop=force_no_crop,
            given_files=given_files,
        )
        super().__init__(**base_kwargs)

    def get_split(self):
        """Return the split name read by ``CocoBase.__init__``."""
        split_name = "validation"
        return split_name
|
|