import os
import json
import random

import torch
import ijson
import numpy as np
from PIL import Image
from torchvision.transforms import ToTensor
from torchvision.ops import box_convert, clip_boxes_to_image

from re_classifier import REClassifier
from utils import progressbar


def collate_fn(batch):
    # stack per-sample tensors into batch tensors
    image = torch.stack([s['image'] for s in batch], dim=0)
    image_size = torch.FloatTensor([s['image_size'] for s in batch])

    # boxes are concatenated along dim 0 (one or more boxes per sample)
    bbox = torch.cat([s['bbox'] for s in batch], dim=0)
    bbox_raw = torch.cat([s['bbox_raw'] for s in batch], dim=0)

    expr = [s['expr'] for s in batch]

    tok = None
    if batch[0]['tok'] is not None:
        # trim the padded token tensors to the longest expression in the batch
        max_length = max(s['tok']['length'] for s in batch)
        tok = {
            'input_ids': torch.cat(
                [s['tok']['input_ids'] for s in batch], dim=0
            )[:, :max_length],
            'attention_mask': torch.cat(
                [s['tok']['attention_mask'] for s in batch], dim=0
            )[:, :max_length],
        }

    mask = None
    if batch[0]['mask'] is not None:
        mask = torch.stack([s['mask'] for s in batch], dim=0)

    mask_bbox = None
    if batch[0]['mask_bbox'] is not None:
        mask_bbox = torch.stack([s['mask_bbox'] for s in batch], dim=0)

    tr_param = [s['tr_param'] for s in batch]

    return {
        'image': image,
        'image_size': image_size,
        'bbox': bbox,
        'bbox_raw': bbox_raw,
        'expr': expr,
        'tok': tok,
        'tr_param': tr_param,
        'mask': mask,
        'mask_bbox': mask_bbox,
    }
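
# Usage sketch: wire collate_fn into a DataLoader (`dataset` stands for any of
# the RECDataset subclasses defined below; batch size is arbitrary):
#
#   loader = torch.utils.data.DataLoader(
#       dataset, batch_size=16, shuffle=True,
#       collate_fn=collate_fn, num_workers=4
#   )
#   batch = next(iter(loader))
#   batch['image'].shape  # (B, 3, H, W)

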
class RECDataset(torch.utils.data.Dataset):
    """Base class for referring expression comprehension (REC) datasets."""

    def __init__(self, transform=None, tokenizer=None, max_length=32,
                 with_mask_bbox=False):
        super().__init__()
        self.samples = []
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_length = int(max_length)
        self.with_mask_bbox = bool(with_mask_bbox)

    def tokenize(self, inp, max_length):
        # tokenize an expression (or list of expressions), padded to max_length
        return self.tokenizer(
            inp,
            return_tensors='pt',
            padding='max_length',
            return_token_type_ids=False,
            return_attention_mask=True,
            add_special_tokens=True,
            truncation=True,
            max_length=max_length
        )
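
    # Example of the tokenizer contract (a sketch, assuming a HuggingFace
    # tokenizer such as BertTokenizerFast was passed to __init__):
    #
    #   tok = self.tokenize('the red car on the left', max_length=32)
    #   tok['input_ids'].shape       # torch.Size([1, 32]), padded
    #   tok['attention_mask'].shape  # torch.Size([1, 32]), 1 = real token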

    def print_stats(self):
        print(f'{len(self.samples)} samples')
        lens = [len(expr.split()) for _, expr, _ in self.samples]
        print('expression lengths stats: '
              f'min={np.min(lens):.1f}, '
              f'mean={np.mean(lens):.1f}, '
              f'median={np.median(lens):.1f}, '
              f'max={np.max(lens):.1f}, '
              f'99.9P={np.percentile(lens, 99.9):.1f}')

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        file_name, expr, bbox = self.samples[idx]

        if not os.path.exists(file_name):
            raise IOError(f'{file_name} not found')
        img = Image.open(file_name).convert('RGB')

        # original image size (PIL returns (W, H))
        W0, H0 = img.size

        sample = {
            'image': img,
            'image_size': (H0, W0),
            'bbox': bbox.clone(),
            'bbox_raw': bbox.clone(),
            'expr': expr,
            'tok': None,
            'tr_param': [],  # populated by the transform, if any
            'mask': torch.ones((1, H0, W0), dtype=torch.float32),
            'mask_bbox': None,
        }

        if self.transform is None:
            sample['image'] = ToTensor()(sample['image'])
        else:
            sample = self.transform(sample)

        if self.tokenizer is not None:
            sample['tok'] = self.tokenize(sample['expr'], self.max_length)
            sample['tok']['length'] = sample['tok']['attention_mask'].sum(1).item()

        if self.with_mask_bbox:
            _, H, W = sample['image'].size()

            # boxes are normalized to [0, 1] at this point; map them back to
            # pixel coordinates and clip to the (possibly transformed) image
            bbox = sample['bbox'].clone()
            bbox[:, (0, 2)] *= W
            bbox[:, (1, 3)] *= H
            bbox = clip_boxes_to_image((bbox + 0.5).long(), (H, W))

            # binary mask with ones inside every ground truth box
            sample['mask_bbox'] = torch.zeros((1, H, W), dtype=torch.float32)
            for x1, y1, x2, y2 in bbox.tolist():
                sample['mask_bbox'][:, y1:y2+1, x1:x2+1] = 1.0

        return sample
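
# A transform, when given, receives and returns the whole sample dict. A
# minimal sketch of the expected contract (hypothetical transform, for
# illustration only): produce a tensor image, normalize the boxes to [0, 1],
# and record the transform parameters under 'tr_param', which collate_fn
# expects. A real transform would also resize 'mask' accordingly.
#
#   def resize_transform(sample, size=640):
#       W0, H0 = sample['image'].size
#       sample['image'] = ToTensor()(sample['image'].resize((size, size)))
#       sample['bbox'][:, (0, 2)] /= W0  # normalize x coordinates
#       sample['bbox'][:, (1, 3)] /= H0  # normalize y coordinates
#       sample['tr_param'] = {'scale': (size / W0, size / H0)}
#       return sample

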
class RegionDescriptionsVisualGnome(RECDataset):
    def __init__(self, data_root, transform=None, tokenizer=None,
                 max_length=32, with_mask_bbox=False):
        super().__init__(transform=transform, tokenizer=tokenizer,
                         max_length=max_length, with_mask_bbox=with_mask_bbox)

        # COCO ids of images in the RefCOCO val/test splits; Visual Genome
        # images that overlap with them are filtered out below
        try:
            with open('./refcoco_valtest_ids.txt', 'r') as fh:
                refcoco_ids = set(int(lin.strip()) for lin in fh)
        except FileNotFoundError:
            refcoco_ids = set()

        def path_from_url(fname):
            # map the VG image URL to a local path under data_root
            return os.path.join(data_root, fname[fname.index('VG_100K'):])

        with open(os.path.join(data_root, 'image_data.json'), 'r') as f:
            image_data = {
                data['image_id']: path_from_url(data['url'])
                for data in json.load(f)
                if data['coco_id'] is None or data['coco_id'] not in refcoco_ids
            }
        print(f'{len(image_data)} images')

        self.samples = []

        # region_descriptions.json is large: stream it with ijson instead of
        # loading the whole file into memory
        with open(os.path.join(data_root, 'region_descriptions.json'), 'r') as f:
            regions = ijson.items(f, 'item.regions.item')
            for record in progressbar(regions, desc='loading data'):
                if record['image_id'] not in image_data:
                    continue
                file_name = image_data[record['image_id']]

                expr = record['phrase']

                # region boxes come as (x, y, w, h); convert to (x1, y1, x2, y2)
                bbox = [record['x'], record['y'], record['width'], record['height']]
                bbox = torch.atleast_2d(torch.FloatTensor(bbox))
                bbox = box_convert(bbox, 'xywh', 'xyxy')

                self.samples.append((file_name, expr, bbox))

        self.print_stats()
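
# Usage sketch (assumes Visual Genome is unpacked under `data_root` with
# image_data.json, region_descriptions.json and the VG_100K* image folders;
# the path below is an assumption):
#
#   dataset = RegionDescriptionsVisualGnome('./visual_genome')
#   file_name, expr, bbox = dataset.samples[0]

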
class ReferDataset(RECDataset):
    def __init__(self, data_root, dataset, split_by, split, transform=None,
                 tokenizer=None, max_length=32, with_mask_bbox=False):
        super().__init__(transform=transform, tokenizer=tokenizer,
                         max_length=max_length, with_mask_bbox=with_mask_bbox)

        try:
            import sys
            sys.path.append('refer')
            from refer import REFER
        except ImportError:
            raise RuntimeError('create a symlink named "refer" to a compiled '
                               'refer repository '
                               '(see https://github.com/lichengunc/refer)')

        refer = REFER(data_root, dataset, split_by)
        ref_ids = sorted(refer.getRefIds(split=split))

        self.samples = []

        for rid in progressbar(ref_ids, desc='loading data'):
            ref = refer.Refs[rid]
            ann = refer.refToAnn[rid]

            file_name = refer.Imgs[ref['image_id']]['file_name']
            if dataset == 'refclef':
                file_name = os.path.join(
                    'refer', 'data', 'images', 'saiapr_tc-12', file_name
                )
            else:
                # COCO file names look like COCO_train2014_000000123456.jpg;
                # the middle token selects the subset directory
                coco_set = file_name.split('_')[1]
                file_name = os.path.join(
                    'refer', 'data', 'images', 'mscoco', coco_set, file_name
                )

            # annotated boxes come as (x, y, w, h); convert to (x1, y1, x2, y2)
            bbox = torch.atleast_2d(torch.FloatTensor(ann['bbox']))
            bbox = box_convert(bbox, 'xywh', 'xyxy')

            # deduplicate expressions on training splits only
            sentences = [s['sent'] for s in ref['sentences']]
            if 'train' in split:
                sentences = sorted(set(sentences))

            self.samples += [(file_name, expr, bbox) for expr in sentences]

        self.print_stats()


class RefCLEF(ReferDataset):
    def __init__(self, *args, **kwargs):
        # the split name must be the first positional argument
        assert args[0] in ('train', 'val', 'test')
        super().__init__('refer/data', 'refclef', 'berkeley', *args, **kwargs)


class RefCOCO(ReferDataset):
    def __init__(self, *args, **kwargs):
        assert args[0] in ('train', 'val', 'trainval', 'testA', 'testB')
        super().__init__('refer/data', 'refcoco', 'unc', *args, **kwargs)


class RefCOCOp(ReferDataset):
    def __init__(self, *args, **kwargs):
        assert args[0] in ('train', 'val', 'trainval', 'testA', 'testB')
        super().__init__('refer/data', 'refcoco+', 'unc', *args, **kwargs)


class RefCOCOg(ReferDataset):
    def __init__(self, *args, **kwargs):
        assert args[0] in ('train', 'val', 'test')
        super().__init__('refer/data', 'refcocog', 'umd', *args, **kwargs)
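
# Usage sketch (assumes the refer repo is symlinked at ./refer with its data
# directory populated; the HuggingFace tokenizer is an assumption, not a
# requirement of these classes):
#
#   from transformers import BertTokenizerFast
#   tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
#   train_set = RefCOCO('train', tokenizer=tokenizer, max_length=32)
#   loader = torch.utils.data.DataLoader(
#       train_set, batch_size=16, shuffle=True, collate_fn=collate_fn
#   )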