import copy
import json
import os
import pickle
import random
import shutil
import sys

import cv2
import numpy as np
import torch
from PIL import Image
from pycocotools import mask as maskUtils
from pycocotools.coco import COCO
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm

from mask_image import ImageNet_Masked

# CLIP pixel statistics; masked-out regions are filled with the dataset mean.
PIXEL_MEAN = (0.48145466, 0.4578275, 0.40821073)
PIXEL_STD = (0.26862954, 0.26130258, 0.27577711)
MASK_FILL = [int(255 * c) for c in PIXEL_MEAN]

def get_file(path):
    """Return the raw bytes stored at `path`.

    The original left this as a TODO; this is a minimal local-filesystem
    implementation. Swap in your own storage backend (object store, cache,
    etc.) if the GRIT files live elsewhere.
    """
    with open(path, 'rb') as f:
        return f.read()

# 224px pipeline for the standard CLIP input resolution. Note the deprecated
# `Image.BICUBIC` constant is replaced by `InterpolationMode.BICUBIC`, which
# also works when Resize is applied to tensors (Pillow 10 removed the former).
clip_standard_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.Normalize(PIXEL_MEAN, PIXEL_STD),
])

# 336px pipelines for the high-resolution variants (identical settings).
hi_clip_standard_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.Normalize(PIXEL_MEAN, PIXEL_STD),
])

res_clip_standard_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.Normalize(PIXEL_MEAN, PIXEL_STD),
])

# Single-channel mask pipelines, normalized with scalar statistics.
mask_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize(0.5, 0.26),
])

hi_mask_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((336, 336)),
    transforms.Normalize(0.5, 0.26),
])

res_mask_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((336, 336)),
    transforms.Normalize(0.5, 0.26),
])

def crop_center(img, croph, cropw):
    """Crop an (H, W, C) array to (croph, cropw) around its center."""
    h, w = img.shape[:2]
    starth = h // 2 - croph // 2
    startw = w // 2 - cropw // 2
    return img[starth:starth + croph, startw:startw + cropw, :]

class Alpha_GRIT(Dataset):
    """GRIT-1M referring-expression dataset: each sample pairs an image with
    the binary mask of one randomly chosen referred region and its text span.

    Args:
        ids_file: pickle file holding the list of sample ids.
        root_pth: directory containing `<id>.json` / `<id>.jpg` pairs.
        common_pair: probability of returning the full caption with an
            all-ones mask instead of a region/phrase pair.
        hi_res: use the 336px transforms instead of the 224px ones.
        subnum: if given, truncate the dataset to the first `subnum` ids.
    """

    def __init__(self, ids_file='grit_1m_ids.pkl', root_pth='grit-1m/', common_pair=0.0, hi_res=False, subnum=None):
        with open(ids_file, 'rb') as f:
            self.ids = pickle.load(f)
        if subnum is not None:
            self.ids = self.ids[:subnum]
        self.root_pth = root_pth
        self.with_common_pair_prop = common_pair
        if hi_res:
            self.mask_transform = res_mask_transform
            self.clip_standard_transform = res_clip_standard_transform
        else:
            self.mask_transform = mask_transform
            self.clip_standard_transform = clip_standard_transform
        
    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        sample_id = self.ids[index]
        ann = json.loads(get_file(self.root_pth + str(sample_id) + '.json'))
        image_data = get_file(self.root_pth + str(sample_id) + '.jpg')
        img = np.frombuffer(image_data, dtype=np.uint8)
        img = cv2.imdecode(img, cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        ref_exps = ann['ref_exps']
        # Randomly choose a single referring expression and its matching mask.
        choice = random.randint(0, len(ref_exps) - 1)
        ref_exp = ref_exps[choice]
        # ref_exp[0]:ref_exp[1] is the character span of the phrase in the caption.
        text = ann['caption'][int(ref_exp[0]): int(ref_exp[1])]
        mask = maskUtils.decode(ann['seudo_masks'][choice])  # key name as stored in the annotations
        if mask.shape != img.shape[:2]:
            # Some images are stored rotated relative to their masks; rotate to match.
            img = np.rot90(img)
        rgba = np.concatenate((img, np.expand_dims(mask, axis=-1)), axis=-1)
        h, w = rgba.shape[:2]
        # Square the sample. The center-crop branch below is currently disabled:
        # `choice` is forced to 0, so we always zero-pad to a square instead.
        choice = 0
        if choice == 0:
            if max(h, w) == w:
                pad = (w - h) // 2
                l, r = pad, w - h - pad
                rgba = np.pad(rgba, ((l, r), (0, 0), (0, 0)), 'constant', constant_values=0)
            else:
                pad = (h - w) // 2
                l, r = pad, h - w - pad
                rgba = np.pad(rgba, ((0, 0), (l, r), (0, 0)), 'constant', constant_values=0)
        else:
            if min(h, w) == h:
                rgba = crop_center(rgba, h, h)
            else:
                rgba = crop_center(rgba, w, w)
        rgb = rgba[:, :, :-1]
        mask = rgba[:, :, -1]
        image_torch = self.clip_standard_transform(rgb)

        # With probability `common_pair`, return the whole image (all-ones mask)
        # with the full caption; otherwise return the region mask and its phrase.
        choice = random.random()
        if choice >= self.with_common_pair_prop:
            mask_torch = self.mask_transform(mask * 255)  # {0,1} -> {0,255} so ToTensor maps back to {0,1}
            return image_torch, mask_torch, text
        else:
            mask_torch = self.mask_transform(np.ones_like(mask) * 255)
            return image_torch, mask_torch, ann['caption']
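
# Minimal usage sketch. Assumptions: `grit_1m_ids.pkl` and a local `grit-1m/`
# directory with `<id>.json`/`<id>.jpg` pairs exist; these paths come from the
# constructor defaults above, not from any guaranteed layout.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = Alpha_GRIT(common_pair=0.1, subnum=16)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
    images, masks, texts = next(iter(loader))
    # Expected: torch.Size([4, 3, 224, 224]) torch.Size([4, 1, 224, 224])
    print(images.shape, masks.shape)
    print(texts[:2])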