Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| import os | |
| from os.path import join, isdir, isfile, expanduser | |
| from PIL import Image | |
| from torchvision import transforms | |
| from torchvision.transforms.transforms import Resize | |
| from torch.nn import functional as nnf | |
| from general_utils import get_from_repository | |
| from skimage.draw import polygon2mask | |
def random_crop_slices(origin_size, target_size):
    """Pick a uniformly random ``target_size`` window inside ``origin_size``.

    Args:
        origin_size: (height, width) of the array being cropped, e.g. ``arr.shape``.
        target_size: (height, width) of the crop; each side must not exceed the
            corresponding side of ``origin_size``.

    Returns:
        A ``(rows, cols)`` tuple of slices, usable as ``arr[rows, cols]``.

    Raises:
        AssertionError: if the crop does not fit (only while assertions are enabled).
    """
    assert origin_size[0] >= target_size[0] and origin_size[1] >= target_size[1], \
        f'actual size: {origin_size}, target size: {target_size}'

    def draw_start(axis):
        # torch.randint excludes its upper bound, so "+ 1" also allows a crop
        # that ends exactly at the far edge of the axis.
        slack = origin_size[axis] - target_size[axis]
        return torch.randint(0, slack + 1, (1,)).item()

    starts = [draw_start(axis) for axis in (0, 1)]  # rows are drawn first, then columns
    return tuple(slice(start, start + target_size[axis]) for axis, start in enumerate(starts))
def find_crop(seg, image_size, iterations=1000, min_frac=None, best_of=None):
    """Randomly search for a crop of ``seg`` that contains enough foreground.

    Crops of size ``image_size`` are drawn with ``random_crop_slices``. A crop
    is "good" when its number of foreground pixels is strictly greater than
    ``min_sum``. ``min_sum`` is ``min_frac`` times the pixel count of the
    *whole* ``seg``, not of its foreground, or 0 when ``min_frac`` is None.

    Args:
        seg: 2-D mask array; cast to bool, so any non-zero value is foreground.
        image_size: (height, width) of the crop; must fit inside ``seg``.
        iterations: maximum number of random crops to try (must be >= 1).
        min_frac: fraction of the full mask area that a good crop must exceed.
        best_of: if None, the first good crop is returned. Otherwise sampling
            continues until ``best_of`` good crops were collected (or the
            iterations run out), and the one with the most foreground wins.

    Returns:
        ``(sl_y, sl_x, failed)``: row and column slices of the chosen crop, and
        ``failed`` is True when no good crop was found. In that case the
        sampled crop with the most foreground is returned.

    Raises:
        ValueError: if ``iterations`` < 1, since no crop could be produced.
    """
    if iterations < 1:
        raise ValueError(f'iterations must be >= 1, got {iterations}')

    seg = seg.astype('bool')
    min_sum = 0 if min_frac is None else seg.shape[0] * seg.shape[1] * min_frac

    good_crops = []  # (foreground, sl_y, sl_x) of crops above min_sum
    fallback = float('-inf'), None, None  # best crop at or below min_sum

    for _ in range(iterations):
        sl_y, sl_x = random_crop_slices(seg.shape, image_size)
        foreground = seg[sl_y, sl_x].sum()

        if foreground > min_sum:
            if best_of is None:
                return sl_y, sl_x, False
            good_crops.append((foreground, sl_y, sl_x))
            if len(good_crops) >= best_of:
                break
        elif foreground > fallback[0]:
            fallback = foreground, sl_y, sl_x

    if good_crops:
        # Any good crop beats the fallback, even when fewer than best_of were
        # found. max() keeps the first of equally good crops, like a stable
        # descending sort would.
        _, sl_y, sl_x = max(good_crops, key=lambda crop: crop[0])
        return sl_y, sl_x, False

    # No crop exceeded min_sum: return the best one seen, flagged as failed.
    return fallback[1], fallback[2], True
class PhraseCut(object):
    """Referring-image-segmentation dataset backed by VGPhraseCut annotations.

    Every sample is one ``(image id, phrase index)`` pair read through
    ``RefVGLoader``. ``__getitem__`` returns the image plus text and/or visual
    conditioning (selected by ``mask``) and the binary target mask.

    Args:
        split: split name forwarded to ``RefVGLoader``.
        image_size: side length that images and masks are resized to.
        negative_prob: probability of replacing the phrase with a different
            random phrase and zeroing the target mask.
        aug: accepted for interface compatibility; not used by this class.
        aug_color: if True, apply ``ColorJitter(0.5, 0.5, 0.2, 0.05)`` to images.
        aug_crop: if True, take a square crop (side = shorter image side)
            chosen by ``find_crop`` instead of using the full image.
        min_size: keep only samples whose summed box area divided by the image
            area is strictly greater than this value.
        remove_classes: optional ``(mode, arg)`` with mode ``'pas5i'``, ``'zs'``
            or ``'aff'``; removes samples whose phrase mentions the selected
            classes/words.
        with_visual: also return a support sample for the phrase.
        only_visual: keep only samples whose phrase has more than one sample;
            requires ``with_visual``.
        mask: conditioning mode, see ``__getitem__``. NOTE(review): the default
            ``None`` makes ``__getitem__`` fail in every branch, so callers are
            expected to pass a value.
    """

    def __init__(self, split, image_size=400, negative_prob=0, aug=None, aug_color=False, aug_crop=True,
                 min_size=0, remove_classes=None, with_visual=False, only_visual=False, mask=None):
        super().__init__()

        self.negative_prob = negative_prob
        self.image_size = image_size
        self.with_visual = with_visual
        self.only_visual = only_visual
        self.phrase_form = '{}'
        self.mask = mask
        self.aug_crop = aug_crop

        if aug_color:
            self.aug_color = transforms.Compose([
                transforms.ColorJitter(0.5, 0.5, 0.2, 0.05),
            ])
        else:
            self.aug_color = None

        def integrity_check(local_dir):
            root = join(local_dir, 'VGPhraseCut_v0')
            images = join(root, 'images')
            # Short-circuit so os.listdir is only reached once the directory is
            # known to exist (an eager list would raise instead of returning False).
            return (isdir(root) and isdir(images)
                    and isfile(join(root, 'refer_train.json'))
                    and len(os.listdir(images)) in {108250, 108249})

        # Presumably fetches PhraseCut.tar unless integrity_check accepts the
        # local copy -- TODO confirm against general_utils.get_from_repository.
        get_from_repository('PhraseCut', ['PhraseCut.tar'], integrity_check=integrity_check)

        from third_party.PhraseCutDataset.utils.refvg_loader import RefVGLoader
        self.refvg_loader = RefVGLoader(split=split)
        get_data = self.refvg_loader.get_img_ref_data  # shortcut

        # img_ids where the size in the annotations does not match actual size
        invalid_img_ids = {150417, 285665, 498246, 61564, 285743, 498269, 498010, 150516, 150344, 286093, 61530,
                           150333, 286065, 285814, 498187, 285761, 498042}

        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        self.normalize = transforms.Normalize(mean, std)

        # One sample per (image id, phrase index). Invalid images are skipped
        # before their annotations are looked up.
        self.sample_ids = [(i, j)
                           for i in self.refvg_loader.img_ids if i not in invalid_img_ids
                           for j in range(len(get_data(i)['phrases']))]

        # Filter by class (if remove_classes is set)
        if remove_classes is not None:
            self.sample_ids = self._filter_by_classes(remove_classes, get_data)

        # phrase -> [(img_id, phrase_idx), ...], keys and lists in sorted order.
        from itertools import groupby
        keyed = sorted((get_data(i)['phrases'][j], (i, j)) for i, j in self.sample_ids)
        self.samples_by_phrase = {phrase: [sid for _, sid in group]
                                  for phrase, group in groupby(keyed, key=lambda x: x[0])}

        # The keys are already unique. Taking them in (sorted) dict order keeps
        # negative sampling reproducible; going through a set would make the
        # order depend on string-hash randomization.
        self.all_phrases = list(self.samples_by_phrase)

        if self.only_visual:
            assert self.with_visual
            self.sample_ids = [(i, j) for i, j in self.sample_ids
                               if len(self.samples_by_phrase[get_data(i)['phrases'][j]]) > 1]

        # Relative object size per sample: summed gt box area / image area.
        # NOTE(review): boxes are read as s[2] * s[3], i.e. presumably
        # (x, y, w, h) -- confirm against the RefVG annotation format.
        self.sizes = []
        for i, j in self.sample_ids:
            ref = get_data(i)
            box_area = sum(s[2] * s[3] for s in ref['gt_boxes'][j])
            self.sizes.append(box_area / (ref['width'] * ref['height']))

        if min_size:
            print('filter by size')

        # Applied even when min_size == 0, which drops samples with zero box area.
        # NOTE: self.sizes stays indexed like sample_ids *before* this filter.
        self.sample_ids = [sid for sid, size in zip(self.sample_ids, self.sizes) if size > min_size]

        self.base_path = expanduser('~/datasets/PhraseCut/VGPhraseCut_v0/images/')

    def _filter_by_classes(self, remove_classes, get_data):
        """Return the subset of ``self.sample_ids`` kept after removing classes.

        ``remove_classes`` is ``(mode, arg)``:

        * ``'pas5i'``: ``arg`` is a fold id. Every synset from
          ``PASCAL_5I_SYNSETS_ORDERED`` whose 1-based index is not in
          ``PASCAL_5I_CLASS_IDS[arg]`` is avoided.
        * ``'zs'``: ``arg`` is a count. All entries of the first ``arg`` class
          sets of ``PASCAL_VOC_CLASSES_ZS`` are avoided.
        * ``'aff'``: a fixed word list is matched verbatim against the
          space-separated words of each phrase.

        For ``'pas5i'``/``'zs'`` the avoided entries are treated as WordNet
        synset names and expanded via ``traverse_lemmas_hypo``. A phrase is
        removed if it contains a multi-word lemma as a substring, or if one of
        its lemmatized words equals a single-word lemma. Prints a summary and
        up to 20 removed examples.

        Raises:
            ValueError: for an unknown mode.
        """
        from datasets.generate_lvis_oneshot import traverse_lemmas_hypo
        from nltk.corpus import wordnet
        from nltk.stem import WordNetLemmatizer

        print('remove pascal classes...')

        mode = remove_classes[0]
        keep_sids = None

        if mode == 'pas5i':
            subset_id = remove_classes[1]
            from datasets.generate_lvis_oneshot import PASCAL_5I_SYNSETS_ORDERED, PASCAL_5I_CLASS_IDS
            avoid = [PASCAL_5I_SYNSETS_ORDERED[i] for i in range(20) if i + 1 not in PASCAL_5I_CLASS_IDS[subset_id]]
        elif mode == 'zs':
            stop = remove_classes[1]
            from datasets.pascal_zeroshot import PASCAL_VOC_CLASSES_ZS
            avoid = [c for class_set in PASCAL_VOC_CLASSES_ZS[:stop] for c in class_set]
            print(avoid)
        elif mode == 'aff':
            avoid = ['drink', 'drinks', 'drinking', 'sit', 'sits', 'sitting',
                     'ride', 'rides', 'riding',
                     'fly', 'flies', 'flying', 'drive', 'drives', 'driving', 'driven',
                     'swim', 'swims', 'swimming',
                     'wheels', 'wheel', 'legs', 'leg', 'ear', 'ears']
            keep_sids = [(i, j) for i, j in self.sample_ids
                         if all(x not in avoid for x in get_data(i)['phrases'][j].split(' '))]
        else:
            raise ValueError(f'unknown remove_classes mode: {mode!r}')

        print('avoid classes:', avoid)

        if keep_sids is None:
            # Presumably the lemmas of each synset and its hyponyms -- see
            # traverse_lemmas_hypo. Normalized to lower case with spaces.
            all_lemmas = {s.replace('_', ' ').lower()
                          for ps in avoid
                          for s in traverse_lemmas_hypo(wordnet.synset(ps), max_depth=None)}

            # divide into multi word and single word
            single_word = {l for l in all_lemmas if ' ' not in l}
            multi_word = all_lemmas - single_word

            wnl = WordNetLemmatizer()
            remove_sids = set()
            for i, j in self.sample_ids:
                phrase = get_data(i)['phrases'][j]
                if any(l in phrase for l in multi_word) or \
                        not single_word.isdisjoint(wnl.lemmatize(w) for w in phrase.split(' ')):
                    remove_sids.add((i, j))
            keep_sids = [sid for sid in self.sample_ids if sid not in remove_sids]

        # max(..., 1) avoids a ZeroDivisionError on an empty dataset.
        print(f'Reduced to {len(keep_sids) / max(len(self.sample_ids), 1):.3f}')
        removed_ids = set(self.sample_ids) - set(keep_sids)

        print('Examples of removed', len(removed_ids))
        for i, j in list(removed_ids)[:20]:
            print(i, get_data(i)['phrases'][j])

        return keep_sids

    def __len__(self):
        return len(self.sample_ids)

    def load_sample(self, sample_i, j):
        """Load phrase ``j`` of image ``sample_i`` as ``(img, seg, phrase)``.

        ``img`` is a normalized float tensor of shape (3, image_size, image_size).
        ``seg`` is a uint8 tensor of shape (image_size, image_size) holding the
        union of the phrase's polygons. ``phrase`` is the phrase formatted with
        ``self.phrase_form``.
        """
        img_ref_data = self.refvg_loader.get_img_ref_data(sample_i)
        height, width = img_ref_data['height'], img_ref_data['width']

        phrase = self.phrase_form.format(img_ref_data['phrases'][j])

        # Union of all polygons of this phrase. Points are swapped because
        # polygon2mask expects (row, col) coordinates.
        masks = [polygon2mask((height, width), [p[::-1] for p in poly])
                 for polys in img_ref_data['gt_Polygons'][j]
                 for poly in polys]
        # An empty polygon list yields an empty mask instead of np.stack's ValueError.
        seg = np.stack(masks).max(0) if masks else np.zeros((height, width), dtype=bool)

        path = join(self.base_path, str(img_ref_data['image_id']) + '.jpg')
        # The context manager releases the file handle. convert('RGB') guarantees
        # three uint8 channels: grayscale is replicated, other modes are converted.
        with Image.open(path) as pil_img:
            img = np.array(pil_img.convert('RGB'))

        if self.aug_crop:
            side = min(img.shape[:2])
            # Square crop whose foreground exceeds 5% of the full mask area, if
            # one is found within 50 tries. The "failed" flag is not used here.
            sly, slx, _ = find_crop(seg, (side, side), iterations=50, min_frac=0.05)
        else:
            sly, slx = slice(0, None), slice(0, None)

        seg = seg[sly, slx]
        img = img[sly, slx]

        seg = seg.astype('uint8')
        seg = torch.from_numpy(seg).view(1, 1, *seg.shape)
        img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float()

        size = (self.image_size, self.image_size)
        seg = nnf.interpolate(seg, size, mode='nearest')[0, 0]
        img = nnf.interpolate(img, size, mode='bilinear', align_corners=True)[0]

        img = img / 255.0
        if self.aug_color is not None:
            img = self.aug_color(img)
        img = self.normalize(img)

        return img, seg, phrase

    def __getitem__(self, i):
        """Return ``(data_x, (seg, torch.zeros(0), i))`` for sample ``i``.

        ``data_x`` starts with the image, followed by conditioning that depends
        on ``self.with_visual`` and ``self.mask``:

        * not ``with_visual``: ``(img, phrase)``; requires ``mask == 'text'``.
        * ``'separate'``: ``(img, support_img, support_seg, has_support)``.
        * ``'text_and_separate'``: ``(img, phrase, support_img, support_seg, has_support)``.
        * ``'text_and_<mode>'``: ``(img, phrase, blended_support, has_support)``.
        * any other mode: ``(img, blended_support, has_support)``.

        When fewer than two samples carry the phrase, the support tensors are
        zeros and ``has_support`` is False. ``seg`` is returned as a float
        tensor with an extra leading dimension.
        """
        sample_i, j = self.sample_ids[i]
        img, seg, phrase = self.load_sample(sample_i, j)

        if self.negative_prob > 0 and torch.rand((1,)).item() < self.negative_prob:
            n_phrases = len(self.all_phrases)
            # all_phrases holds unique strings, so a different phrase exists iff
            # there are 2+ phrases or the only one differs from ours. Without
            # this check the loop below could never terminate.
            if n_phrases > 1 or (n_phrases == 1 and self.all_phrases[0] != phrase):
                new_phrase = phrase
                while new_phrase == phrase:
                    new_phrase = self.all_phrases[torch.randint(0, n_phrases, (1,)).item()]
                phrase = new_phrase
                seg = torch.zeros_like(seg)  # target mask is emptied for the swapped phrase

        if self.with_visual:
            candidates = self.samples_by_phrase.get(phrase, [])
            if len(candidates) > 1:
                # Support sample with the same phrase, drawn from all of the
                # phrase's samples (so it can coincide with the query sample).
                other_sample = candidates[torch.randint(0, len(candidates), (1,)).item()]
                img_s, seg_s, _ = self.load_sample(*other_sample)

                from datasets.utils import blend_image_segmentation

                if self.mask in {'separate', 'text_and_separate'}:
                    add_phrase = [phrase] if self.mask == 'text_and_separate' else []
                    vis_s = add_phrase + [img_s, seg_s, True]
                else:
                    if self.mask.startswith('text_and_'):
                        mask_mode = self.mask[len('text_and_'):]
                        label_add = [phrase]
                    else:
                        mask_mode = self.mask
                        label_add = []
                    masked_img_s = torch.from_numpy(blend_image_segmentation(
                        img_s, seg_s, mode=mask_mode, image_size=self.image_size)[0])
                    vis_s = label_add + [masked_img_s, True]
            else:
                # phrase is unique: zero support, flagged with False
                vis_s = torch.zeros_like(img)
                if self.mask in {'separate', 'text_and_separate'}:
                    add_phrase = [phrase] if self.mask == 'text_and_separate' else []
                    vis_s = add_phrase + [vis_s, torch.zeros(*vis_s.shape[1:], dtype=torch.uint8), False]
                elif self.mask.startswith('text_and_'):
                    vis_s = [phrase, vis_s, False]
                else:
                    vis_s = [vis_s, False]
        else:
            assert self.mask == 'text'
            vis_s = [phrase]

        seg = seg.unsqueeze(0).float()
        data_x = (img,) + tuple(vis_s)
        return data_x, (seg, torch.zeros(0), i)
class PhraseCutPlus(PhraseCut):
    """``PhraseCut`` preset that always provides visual support samples.

    Identical to ``PhraseCut`` except that ``with_visual`` is forced to True and
    ``negative_prob`` is fixed at 0.2. Every other argument is forwarded
    unchanged.
    """

    def __init__(self, split, image_size=400, aug=None, aug_color=False, aug_crop=True, min_size=0, remove_classes=None, only_visual=False, mask=None):
        preset = {'negative_prob': 0.2, 'with_visual': True}
        forwarded = {
            'image_size': image_size,
            'aug': aug,
            'aug_color': aug_color,
            'aug_crop': aug_crop,
            'min_size': min_size,
            'remove_classes': remove_classes,
            'only_visual': only_visual,
            'mask': mask,
        }
        super().__init__(split, **preset, **forwarded)