| import os |
| import sys |
| import torch.utils.data as data |
| import torch |
| from torchvision import transforms |
| from torch.autograd import Variable |
| import numpy as np |
| from PIL import Image |
| import torchvision.transforms.functional as TF |
| import random |
|
|
| from bert.tokenization_bert import BertTokenizer |
|
|
| import h5py |
| from refer.refer import REFER |
|
|
| from args import get_parser |
|
|
| |
# Build the project's argument parser and parse the command line as soon as this
# module is imported, leaving `parser` and `args` available as module globals.
# NOTE(review): importing this module therefore consumes sys.argv; nothing in
# this file reads the global `args` (ReferDataset receives its own `args`).
parser = get_parser()
args = parser.parse_args()
|
|
| from hfai.datasets import CocoDetection |
|
|
| from PIL import Image |
| import numpy as np |
| |
| import ffrecord |
| import pickle |
|
|
| _EXIF_ORIENT = 274 |
| def _apply_exif_orientation(image): |
| """ |
| Applies the exif orientation correctly. |
| |
| This code exists per the bug: |
| https://github.com/python-pillow/Pillow/issues/3973 |
| with the function `ImageOps.exif_transpose`. The Pillow source raises errors with |
| various methods, especially `tobytes` |
| |
| Function based on: |
| https://github.com/wkentaro/labelme/blob/v4.5.4/labelme/utils/image.py#L59 |
| https://github.com/python-pillow/Pillow/blob/7.1.2/src/PIL/ImageOps.py#L527 |
| |
| Args: |
| image (PIL.Image): a PIL image |
| |
| Returns: |
| (PIL.Image): the PIL image with exif orientation applied, if applicable |
| """ |
| if not hasattr(image, "getexif"): |
| return image |
|
|
| try: |
| exif = image.getexif() |
| except Exception: |
| exif = None |
|
|
| if exif is None: |
| return image |
|
|
| orientation = exif.get(_EXIF_ORIENT) |
|
|
| method = { |
| 2: Image.FLIP_LEFT_RIGHT, |
| 3: Image.ROTATE_180, |
| 4: Image.FLIP_TOP_BOTTOM, |
| 5: Image.TRANSPOSE, |
| 6: Image.ROTATE_270, |
| 7: Image.TRANSVERSE, |
| 8: Image.ROTATE_90, |
| }.get(orientation) |
|
|
| if method is not None: |
| return image.transpose(method) |
| return image |
|
|
# RGB -> YUV (BT.601) conversion matrix; each row produces one output channel
# (Y, U, V) from an RGB pixel scaled to [0, 1]. Defined here because the
# "YUV-BT.601" branch below needs it and no other definition exists in this file.
_M_RGB2YUV = [
    [0.299, 0.587, 0.114],
    [-0.14713, -0.28886, 0.436],
    [0.615, -0.51499, -0.10001],
]


def convert_PIL_to_numpy(image, format):
    """
    Convert PIL image to numpy array of target format.

    Args:
        image (PIL.Image): a PIL image
        format (str): the format of output image. ``None`` skips the mode
            conversion; "BGR" and "YUV-BT.601" are derived from an RGB
            conversion; any other value is passed to ``image.convert``.

    Returns:
        (np.ndarray): the image data. "L" gets a trailing channel axis added,
        "BGR" is the RGB array with its last axis reversed, and "YUV-BT.601"
        is a float array computed from RGB values divided by 255.
    """
    if format is not None:
        # "BGR" and "YUV-BT.601" are not handed to PIL directly: convert to
        # RGB first and post-process the array below.
        conversion_format = format
        if format in ["BGR", "YUV-BT.601"]:
            conversion_format = "RGB"
        image = image.convert(conversion_format)
    image = np.asarray(image)

    if format == "L":
        # Give the single-channel result an explicit trailing channel axis.
        image = np.expand_dims(image, -1)
    elif format == "BGR":
        # Reverse the channel axis: RGB -> BGR.
        image = image[:, :, ::-1]
    elif format == "YUV-BT.601":
        image = image / 255.0
        image = np.dot(image, np.array(_M_RGB2YUV).T)

    return image
|
|
class ReferDataset(data.Dataset):
    """Dataset of referring expressions paired with images and binary masks.

    Reference annotations come from ``REFER``; image pixels are read through
    the hfai ``CocoDetection`` reader; sentences are tokenized once in the
    constructor with ``BertTokenizer``. One item corresponds to one reference
    id of the requested split.
    """

    # Pickle file loaded into ``self.mixed_masks`` by the constructor.
    # Exposed as a class attribute so it can be overridden without editing
    # the constructor; the default is the original hard-coded location.
    MIXED_MASKS_PATH = '/ceph-jd/pub/jupyter/zhuangrongxian/notebooks/LAVT-RIS-bidirectional-refactor-mask2former/LAVT-RIS-fuckddp/refcoco.pkl'

    def __init__(self,
                 args,
                 image_transforms=None,
                 target_transforms=None,
                 split='train',
                 eval_mode=False):
        """Load annotations, pre-tokenize every sentence and open the image reader.

        Args:
            args: options object; ``refer_data_root``, ``dataset``, ``splitBy``
                and ``bert_tokenizer`` are read from it.
            image_transforms: optional callable invoked as
                ``image_transforms(img, annot)`` and expected to return an
                ``(img, target)`` pair.
            target_transforms: stored as ``self.target_transform``; not used
                anywhere else in this class.
            split: REFER split whose reference ids make up the dataset.
            eval_mode: when true, ``__getitem__`` returns all sentences of a
                reference instead of one random sentence.
        """
        self.classes = []
        self.image_transforms = image_transforms
        self.target_transform = target_transforms
        self.split = split
        self.refer = REFER(args.refer_data_root, args.dataset, args.splitBy)

        # Every sentence is truncated / zero-padded to this many token ids.
        self.max_tokens = 20

        ref_ids = self.refer.getRefIds(split=self.split)
        img_ids = self.refer.getImgIds(ref_ids)

        all_imgs = self.refer.Imgs
        self.imgs = [all_imgs[i] for i in img_ids]
        self.ref_ids = ref_ids

        self.input_ids = []
        self.attention_masks = []
        self.tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer)

        self.eval_mode = eval_mode

        # self.input_ids[k] / self.attention_masks[k] hold one (1, max_tokens)
        # tensor per sentence of the k-th reference id.
        for r in ref_ids:
            ref = self.refer.Refs[r]

            sentences_for_ref = []
            attentions_for_ref = []
            for el, _sent_id in zip(ref['sentences'], ref['sent_ids']):
                ids_tensor, mask_tensor = self._encode_sentence(el['raw'])
                sentences_for_ref.append(ids_tensor)
                attentions_for_ref.append(mask_tensor)

            self.input_ids.append(sentences_for_ref)
            self.attention_masks.append(attentions_for_ref)

        # The image reader always opens the 'train' split, whatever `split`
        # was requested. A separate local is used so the `split` argument is
        # not overwritten.
        coco_split = 'train'
        print(coco_split)
        self.hfai_dataset = CocoDetection(coco_split, transform=None)

        # Map each id exposed by the reader to its position, for read_imgs().
        self.keys = {
            img_id: position
            for position, img_id in enumerate(self.hfai_dataset.reader.ids)
        }

        # NOTE(security): pickle.load executes arbitrary code embedded in the
        # file; only point MIXED_MASKS_PATH at trusted data.
        with open(self.MIXED_MASKS_PATH, 'rb') as handle:
            self.mixed_masks = pickle.load(handle)

    def _encode_sentence(self, sentence_raw):
        """Tokenize one sentence into (input_ids, attention_mask) tensors of shape (1, max_tokens)."""
        token_ids = self.tokenizer.encode(text=sentence_raw, add_special_tokens=True)

        # Truncate overly long sentences, then right-pad with zeros.
        token_ids = token_ids[:self.max_tokens]
        n_tokens = len(token_ids)
        n_pad = self.max_tokens - n_tokens

        padded_input_ids = list(token_ids) + [0] * n_pad
        attention_mask = [1] * n_tokens + [0] * n_pad

        return (torch.tensor(padded_input_ids).unsqueeze(0),
                torch.tensor(attention_mask).unsqueeze(0))

    def get_classes(self):
        """Return ``self.classes`` (initialised to an empty list in the constructor)."""
        return self.classes

    def __len__(self):
        """Return the number of reference ids in the selected split."""
        return len(self.ref_ids)

    def __getitem__(self, index):
        """Build the sample for the reference id at position ``index``.

        Returns:
            ``(img, target, tensor_embeddings, attention_mask)`` when
            ``eval_mode`` is set (all sentences concatenated on a new last
            axis) or when ``split == 'val'`` (one randomly chosen sentence).
            Otherwise a fifth element is appended: a tensor built from
            ``self.mixed_masks[image_id]['masks']``.
            ``img`` / ``target`` are the outputs of ``image_transforms`` when
            it is set; without it they are the PIL image and PIL mask.
        """
        this_ref_id = self.ref_ids[index]
        this_img_id = self.refer.getImgIds(this_ref_id)
        img_id = this_img_id[0]

        # Read the image through the hfai reader, apply the EXIF orientation
        # and normalise it to an RGB PIL image.
        img = self.hfai_dataset.reader.read_imgs([self.keys[img_id]])[0]
        img = _apply_exif_orientation(img)
        img = convert_PIL_to_numpy(img, 'RGB')
        img = Image.fromarray(img)

        ref = self.refer.loadRefs(this_ref_id)

        # Binary mask: 1 wherever the REFER mask equals 1, 0 elsewhere.
        ref_mask = np.array(self.refer.getMask(ref[0])['mask'])
        annot = (ref_mask == 1).astype(np.uint8)
        annot = Image.fromarray(annot, mode="P")

        # Default to the untransformed mask so `target` is always defined,
        # even when no transform was supplied.
        target = annot
        if self.image_transforms is not None:
            img, target = self.image_transforms(img, annot)

        if self.eval_mode:
            # Stack every sentence of this reference along a new last axis.
            tensor_embeddings = torch.cat(
                [e.unsqueeze(-1) for e in self.input_ids[index]], dim=-1)
            attention_mask = torch.cat(
                [a.unsqueeze(-1) for a in self.attention_masks[index]], dim=-1)
            return img, target, tensor_embeddings, attention_mask

        # Outside eval mode, pick one sentence at random.
        choice_sent = np.random.choice(len(self.input_ids[index]))
        tensor_embeddings = self.input_ids[index][choice_sent]
        attention_mask = self.attention_masks[index][choice_sent]

        if self.split == 'val':
            return img, target, tensor_embeddings, attention_mask

        # NOTE(review): the contents of the 'masks' entry are defined by the
        # pickle file and are not visible here.
        return (img, target, tensor_embeddings, attention_mask,
                torch.tensor(self.mixed_masks[img_id]['masks']))
|
|