Spaces:
Runtime error
Runtime error
Delete clipseg/datasets/pfe_dataset.py
Browse files- clipseg/datasets/pfe_dataset.py +0 -129
clipseg/datasets/pfe_dataset.py
DELETED
|
@@ -1,129 +0,0 @@
|
|
| 1 |
-
from os.path import expanduser
|
| 2 |
-
import torch
|
| 3 |
-
import json
|
| 4 |
-
from general_utils import get_from_repository
|
| 5 |
-
from datasets.lvis_oneshot3 import blend_image_segmentation
|
| 6 |
-
from general_utils import log
|
| 7 |
-
|
| 8 |
-
PASCAL_CLASSES = {a['id']: a['synonyms'] for a in json.load(open('datasets/pascal_classes.json'))}
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
class PFEPascalWrapper(object):
|
| 12 |
-
|
| 13 |
-
def __init__(self, mode, split, mask='separate', image_size=473, label_support=None, size=None, p_negative=0, aug=None):
|
| 14 |
-
import sys
|
| 15 |
-
# sys.path.append(expanduser('~/projects/new_one_shot'))
|
| 16 |
-
from third_party.PFENet.util.dataset import SemData
|
| 17 |
-
|
| 18 |
-
get_from_repository('PascalVOC2012', ['Pascal5i.tar'])
|
| 19 |
-
|
| 20 |
-
self.p_negative = p_negative
|
| 21 |
-
self.size = size
|
| 22 |
-
self.mode = mode
|
| 23 |
-
self.image_size = image_size
|
| 24 |
-
|
| 25 |
-
if label_support in {True, False}:
|
| 26 |
-
log.warning('label_support argument is deprecated. Use mask instead.')
|
| 27 |
-
#raise ValueError()
|
| 28 |
-
|
| 29 |
-
self.mask = mask
|
| 30 |
-
|
| 31 |
-
value_scale = 255
|
| 32 |
-
mean = [0.485, 0.456, 0.406]
|
| 33 |
-
mean = [item * value_scale for item in mean]
|
| 34 |
-
std = [0.229, 0.224, 0.225]
|
| 35 |
-
std = [item * value_scale for item in std]
|
| 36 |
-
|
| 37 |
-
import third_party.PFENet.util.transform as transform
|
| 38 |
-
|
| 39 |
-
if mode == 'val':
|
| 40 |
-
data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/val.txt')
|
| 41 |
-
|
| 42 |
-
data_transform = [transform.test_Resize(size=image_size)] if image_size != 'original' else []
|
| 43 |
-
data_transform += [
|
| 44 |
-
transform.ToTensor(),
|
| 45 |
-
transform.Normalize(mean=mean, std=std)
|
| 46 |
-
]
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
elif mode == 'train':
|
| 50 |
-
data_list = expanduser('~/projects/old_one_shot/PFENet/lists/pascal/voc_sbd_merge_noduplicate.txt')
|
| 51 |
-
|
| 52 |
-
assert image_size != 'original'
|
| 53 |
-
|
| 54 |
-
data_transform = [
|
| 55 |
-
transform.RandScale([0.9, 1.1]),
|
| 56 |
-
transform.RandRotate([-10, 10], padding=mean, ignore_label=255),
|
| 57 |
-
transform.RandomGaussianBlur(),
|
| 58 |
-
transform.RandomHorizontalFlip(),
|
| 59 |
-
transform.Crop((image_size, image_size), crop_type='rand', padding=mean, ignore_label=255),
|
| 60 |
-
transform.ToTensor(),
|
| 61 |
-
transform.Normalize(mean=mean, std=std)
|
| 62 |
-
]
|
| 63 |
-
|
| 64 |
-
data_transform = transform.Compose(data_transform)
|
| 65 |
-
|
| 66 |
-
self.dataset = SemData(split=split, mode=mode, data_root=expanduser('~/datasets/PascalVOC2012/VOC2012'),
|
| 67 |
-
data_list=data_list, shot=1, transform=data_transform, use_coco=False, use_split_coco=False)
|
| 68 |
-
|
| 69 |
-
self.class_list = self.dataset.sub_val_list if mode == 'val' else self.dataset.sub_list
|
| 70 |
-
|
| 71 |
-
# verify that subcls_list always has length 1
|
| 72 |
-
# assert len(set([len(d[4]) for d in self.dataset])) == 1
|
| 73 |
-
|
| 74 |
-
print('actual length', len(self.dataset.data_list))
|
| 75 |
-
|
| 76 |
-
def __len__(self):
|
| 77 |
-
if self.mode == 'val':
|
| 78 |
-
return len(self.dataset.data_list)
|
| 79 |
-
else:
|
| 80 |
-
return len(self.dataset.data_list)
|
| 81 |
-
|
| 82 |
-
def __getitem__(self, index):
|
| 83 |
-
if self.dataset.mode == 'train':
|
| 84 |
-
image, label, s_x, s_y, subcls_list = self.dataset[index % len(self.dataset.data_list)]
|
| 85 |
-
elif self.dataset.mode == 'val':
|
| 86 |
-
image, label, s_x, s_y, subcls_list, ori_label = self.dataset[index % len(self.dataset.data_list)]
|
| 87 |
-
ori_label = torch.from_numpy(ori_label).unsqueeze(0)
|
| 88 |
-
|
| 89 |
-
if self.image_size != 'original':
|
| 90 |
-
longerside = max(ori_label.size(1), ori_label.size(2))
|
| 91 |
-
backmask = torch.ones(ori_label.size(0), longerside, longerside).cuda()*255
|
| 92 |
-
backmask[0, :ori_label.size(1), :ori_label.size(2)] = ori_label
|
| 93 |
-
label = backmask.clone().long()
|
| 94 |
-
else:
|
| 95 |
-
label = label.unsqueeze(0)
|
| 96 |
-
|
| 97 |
-
# assert label.shape == (473, 473)
|
| 98 |
-
|
| 99 |
-
if self.p_negative > 0:
|
| 100 |
-
if torch.rand(1).item() < self.p_negative:
|
| 101 |
-
while True:
|
| 102 |
-
idx = torch.randint(0, len(self.dataset.data_list), (1,)).item()
|
| 103 |
-
_, _, s_x, s_y, subcls_list_tmp, _ = self.dataset[idx]
|
| 104 |
-
if subcls_list[0] != subcls_list_tmp[0]:
|
| 105 |
-
break
|
| 106 |
-
|
| 107 |
-
s_x = s_x[0]
|
| 108 |
-
s_y = (s_y == 1)[0]
|
| 109 |
-
label_fg = (label == 1).float()
|
| 110 |
-
val_mask = (label != 255).float()
|
| 111 |
-
|
| 112 |
-
class_id = self.class_list[subcls_list[0]]
|
| 113 |
-
|
| 114 |
-
label_name = PASCAL_CLASSES[class_id][0]
|
| 115 |
-
label_add = ()
|
| 116 |
-
mask = self.mask
|
| 117 |
-
|
| 118 |
-
if mask == 'text':
|
| 119 |
-
support = ('a photo of a ' + label_name + '.',)
|
| 120 |
-
elif mask == 'separate':
|
| 121 |
-
support = (s_x, s_y)
|
| 122 |
-
else:
|
| 123 |
-
if mask.startswith('text_and_'):
|
| 124 |
-
label_add = (label_name,)
|
| 125 |
-
mask = mask[9:]
|
| 126 |
-
|
| 127 |
-
support = (blend_image_segmentation(s_x, s_y.float(), mask)[0],)
|
| 128 |
-
|
| 129 |
-
return (image,) + label_add + support, (label_fg.unsqueeze(0), val_mask.unsqueeze(0), subcls_list[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|