| import json | |
| import os | |
| import random | |
| from tqdm import tqdm | |
| from torch.utils.data import Dataset | |
| from pycocotools.coco import COCO | |
| from pycocotools import mask as maskUtils | |
| from PIL import Image | |
| import cv2 | |
| import random | |
| from torchvision import transforms | |
| from tqdm import tqdm | |
| import pickle | |
| import torch | |
| import numpy as np | |
| import copy | |
| import sys | |
| import shutil | |
| from PIL import Image | |
| from nltk.corpus import wordnet | |
# CLIP image normalisation statistics (per RGB channel).
PIXEL_MEAN = (0.48145466, 0.4578275, 0.40821073)
PIXEL_STD = (0.26862954, 0.26130258, 0.27577711)
# PIXEL_MEAN scaled to 0-255 ints; not referenced in this module's visible code.
MASK_FILL = [int(255 * c) for c in PIXEL_MEAN]


def _clip_transform(size):
    """Build the image pipeline: to tensor, bicubic resize to size x size, CLIP-normalise."""
    return transforms.Compose([
        transforms.ToTensor(),
        # InterpolationMode instead of the PIL int constant: torchvision only accepts
        # ints here for backward compatibility (with a deprecation warning).
        transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.Normalize(PIXEL_MEAN, PIXEL_STD),
    ])


def _mask_transform(size):
    """Build the mask pipeline: to tensor, default-interpolation resize, fixed normalise."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((size, size)),
        # Maps 0 -> about -1.92 and 1 -> about +1.92 ((x - 0.5) / 0.26).
        transforms.Normalize(0.5, 0.26),
    ])


clip_standard_transform = _clip_transform(224)
hi_clip_standard_transform = _clip_transform(336)
res_clip_standard_transform = _clip_transform(336)
mask_transform = _mask_transform(224)
hi_mask_transform = _mask_transform(336)
res_mask_transform = _mask_transform(336)
def crop_center(img, croph, cropw):
    """Return the centered ``croph`` x ``cropw`` window of ``img``.

    The first two axes are treated as height and width. Any trailing axes
    (e.g. channels) are kept whole, so both 2-D and 3-D arrays are accepted.
    If the requested size exceeds the image along an axis, the window is
    clamped to the image, so the result is never larger than the input.

    Args:
        img: Array supporting NumPy-style slicing, at least 2-D.
        croph: Requested crop height in pixels.
        cropw: Requested crop width in pixels.

    Returns:
        The sliced window (for a NumPy array, a view of ``img``).
    """
    h, w = img.shape[:2]
    # Clamp at 0: a negative start would wrap around to the end of the axis
    # and select the wrong region instead of the whole axis.
    starth = max(0, h // 2 - croph // 2)
    startw = max(0, w // 2 - cropw // 2)
    # `...` keeps every trailing axis, so 2-D inputs no longer raise IndexError.
    return img[starth:starth + croph, startw:startw + cropw, ...]
class Imagenet_S(Dataset):
    """ImageNet-S samples as ``(image tensor, mask tensor, category index)`` triples.

    The annotation file is a JSON list of dicts. The keys read here are
    ``'category_word'``, ``'image_pth'`` and ``'mask'``. Each distinct
    ``category_word`` gets an integer index in order of first appearance, and
    that index is written back onto every annotation as ``'cat_index'``.
    ``self.classes[i]`` is the first WordNet lemma name of category ``i``.

    Args:
        ann_file: Path to the JSON annotation list.
        hi_res: Use the 336x336 transforms instead of the 224x224 ones.
        all_one: Replace every mask with an all-ones mask before transforming.
        root_pth: Prefix prepended (by plain string concatenation) to each
            ``'image_pth'``; keep the trailing separator.
    """

    def __init__(self, ann_file='data/imagenet_s/imagenet_919.json', hi_res=False, all_one=False,
                 root_pth='data/imagenet_s/'):
        with open(ann_file, 'r') as f:
            self.anns = json.load(f)
        self.root_pth = root_pth
        cats = self._assign_category_indices(self.anns)
        self.classes = []
        for cat_word in cats:
            # NOTE(review): assumes ids look like 'n01440764' (a POS letter followed
            # by a WordNet noun offset) - confirm against the annotation file.
            synset = wordnet.synset_from_pos_and_offset('n', int(cat_word[1:]))
            self.classes.append(synset.lemmas()[0].name())
        # How _make_square squares each sample: "center_crop" or "padding".
        self.choice = "center_crop"
        if hi_res:
            self.mask_transform = res_mask_transform
            self.clip_standard_transform = res_clip_standard_transform
        else:
            self.mask_transform = mask_transform
            self.clip_standard_transform = clip_standard_transform
        self.all_one = all_one

    @staticmethod
    def _assign_category_indices(anns):
        """Set ``ann['cat_index']`` on every annotation; return category words in index order.

        Indices are assigned in order of first appearance. Each annotation gets
        the index of its *own* category, whatever order the list is in.
        """
        cat_to_index = {}
        for ann in anns:
            # setdefault's default is evaluated before insertion, so a new key gets the next index.
            ann['cat_index'] = cat_to_index.setdefault(ann['category_word'], len(cat_to_index))
        return list(cat_to_index)

    def _make_square(self, rgba):
        """Make an H x W x C array square, by zero padding or by center cropping.

        In "padding" mode the shorter side is zero-padded on both sides; when
        the difference is odd, the extra pixel goes to the bottom/right.
        Otherwise a centered min(H, W) square is cropped out.
        """
        h, w = rgba.shape[:2]
        if self.choice == "padding":
            if w >= h:
                top = (w - h) // 2
                pad_width = ((top, w - h - top), (0, 0), (0, 0))
            else:
                left = (h - w) // 2
                pad_width = ((0, 0), (left, h - w - left), (0, 0))
            return np.pad(rgba, pad_width, 'constant', constant_values=0)
        side = min(h, w)
        return crop_center(rgba, side, side)

    def __len__(self):
        """Return the number of annotations."""
        return len(self.anns)

    def __getitem__(self, index):
        """Load, square and transform sample ``index``.

        Returns:
            ``(image_torch, mask_torch, cat_index)``. The image and mask are
            resized to 224x224, or 336x336 when ``hi_res`` was set.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        ann = self.anns[index]
        image_path = self.root_pth + ann['image_pth']
        image = cv2.imread(image_path)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise FileNotFoundError(f"could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # NOTE(review): presumably a COCO RLE dict decoding to an H x W 0/1 array - confirm.
        mask = maskUtils.decode(ann['mask'])
        # Stack the mask as a 4th channel so cropping/padding moves image and mask together.
        rgba = np.concatenate((image, np.expand_dims(mask, axis=-1)), axis=-1)
        rgba = self._make_square(rgba)
        rgb = rgba[:, :, :-1]
        mask = rgba[:, :, -1]
        image_torch = self.clip_standard_transform(rgb)
        if self.all_one:
            mask_torch = self.mask_transform(np.ones_like(mask) * 255)
        else:
            mask_torch = self.mask_transform(mask * 255)
        return image_torch, mask_torch, ann['cat_index']
if __name__ == "__main__":
    # Smoke test: load and transform every sample once, so unreadable images
    # or malformed annotations surface as exceptions.
    dataset = Imagenet_S()
    for index in tqdm(range(len(dataset))):
        dataset[index]