Mr7Explorer commited on
Commit
e251440
·
verified ·
1 Parent(s): 5f73188

Create dataset.py

Browse files
Files changed (1) hide show
  1. dataset.py +173 -0
dataset.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ import cv2
5
+ from tqdm import tqdm
6
+ from PIL import Image
7
+ from torch.utils import data
8
+ from torchvision import transforms
9
+
10
+ from image_proc import preproc
11
+ from config import Config
12
+ from utils import path_to_image
13
+
14
+
15
+ Image.MAX_IMAGE_PIXELS = None # remove DecompressionBombWarning
16
+ config = Config()
17
+ _class_labels_TR_sorted = (
18
+ 'Airplane, Ant, Antenna, Archery, Axe, BabyCarriage, Bag, BalanceBeam, Balcony, Balloon, Basket, BasketballHoop, Beatle, Bed, Bee, Bench, Bicycle, '
19
+ 'BicycleFrame, BicycleStand, Boat, Bonsai, BoomLift, Bridge, BunkBed, Butterfly, Button, Cable, CableLift, Cage, Camcorder, Cannon, Canoe, Car, '
20
+ 'CarParkDropArm, Carriage, Cart, Caterpillar, CeilingLamp, Centipede, Chair, Clip, Clock, Clothes, CoatHanger, Comb, ConcretePumpTruck, Crack, Crane, '
21
+ 'Cup, DentalChair, Desk, DeskChair, Diagram, DishRack, DoorHandle, Dragonfish, Dragonfly, Drum, Earphone, Easel, ElectricIron, Excavator, Eyeglasses, '
22
+ 'Fan, Fence, Fencing, FerrisWheel, FireExtinguisher, Fishing, Flag, FloorLamp, Forklift, GasStation, Gate, Gear, Goal, Golf, GymEquipment, Hammock, '
23
+ 'Handcart, Handcraft, Handrail, HangGlider, Harp, Harvester, Headset, Helicopter, Helmet, Hook, HorizontalBar, Hydrovalve, IroningTable, Jewelry, Key, '
24
+ 'KidsPlayground, Kitchenware, Kite, Knife, Ladder, LaundryRack, Lightning, Lobster, Locust, Machine, MachineGun, MagazineRack, Mantis, Medal, MemorialArchway, '
25
+ 'Microphone, Missile, MobileHolder, Monitor, Mosquito, Motorcycle, MovingTrolley, Mower, MusicPlayer, MusicStand, ObservationTower, Octopus, OilWell, '
26
+ 'OlympicLogo, OperatingTable, OutdoorFitnessEquipment, Parachute, Pavilion, Piano, Pipe, PlowHarrow, PoleVault, Punchbag, Rack, Racket, Rifle, Ring, Robot, '
27
+ 'RockClimbing, Rope, Sailboat, Satellite, Scaffold, Scale, Scissor, Scooter, Sculpture, Seadragon, Seahorse, Seal, SewingMachine, Ship, Shoe, ShoppingCart, '
28
+ 'ShoppingTrolley, Shower, Shrimp, Signboard, Skateboarding, Skeleton, Skiing, Spade, SpeedBoat, Spider, Spoon, Stair, Stand, Stationary, SteeringWheel, '
29
+ 'Stethoscope, Stool, Stove, StreetLamp, SweetStand, Swing, Sword, TV, Table, TableChair, TableLamp, TableTennis, Tank, Tapeline, Teapot, Telescope, Tent, '
30
+ 'TobaccoPipe, Toy, Tractor, TrafficLight, TrafficSign, Trampoline, TransmissionTower, Tree, Tricycle, TrimmerCover, Tripod, Trombone, Truck, Trumpet, Tuba, '
31
+ 'UAV, Umbrella, UnevenBars, UtilityPole, VacuumCleaner, Violin, Wakesurfing, Watch, WaterTower, WateringPot, Well, WellLid, Wheel, Wheelchair, WindTurbine, Windmill, WineGlass, WireWhisk, Yacht'
32
+ )
33
+ class_labels_TR_sorted = _class_labels_TR_sorted.split(', ')
34
+
35
+
36
+ class MyData(data.Dataset):
37
+ def __init__(self, datasets, data_size, is_train=True):
38
+ # data_size is None when using dynamic_size or data_size is manually set to None (for inference in the original size).
39
+ self.is_train = is_train
40
+ self.data_size = data_size
41
+ self.load_all = config.load_all
42
+ self.device = config.device
43
+ valid_extensions = ['.png', '.jpg', '.PNG', '.JPG', '.JPEG']
44
+
45
+ if self.is_train and config.auxiliary_classification:
46
+ self.cls_name2id = {_name: _id for _id, _name in enumerate(class_labels_TR_sorted)}
47
+ self.transform_image = transforms.Compose([
48
+ transforms.ToTensor(),
49
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
50
+ ])
51
+ self.transform_label = transforms.Compose([
52
+ transforms.ToTensor(),
53
+ ])
54
+ dataset_root = os.path.join(config.data_root_dir, config.task)
55
+ # datasets can be a list of different datasets for training on combined sets.
56
+ self.image_paths = []
57
+ for dataset in datasets.split('+'):
58
+ image_root = os.path.join(dataset_root, dataset, 'im')
59
+ self.image_paths += [os.path.join(image_root, p) for p in os.listdir(image_root) if any(p.endswith(ext) for ext in valid_extensions)]
60
+ self.label_paths = []
61
+ for p in self.image_paths:
62
+ for ext in valid_extensions:
63
+ ## 'im' and 'gt' may need modifying
64
+ p_gt = p.replace('/im/', '/gt/')[:-(len(p.split('.')[-1])+1)] + ext
65
+ file_exists = False
66
+ if os.path.exists(p_gt):
67
+ self.label_paths.append(p_gt)
68
+ file_exists = True
69
+ break
70
+ if not file_exists:
71
+ print('Not exists:', p_gt)
72
+
73
+ if len(self.label_paths) != len(self.image_paths):
74
+ set_image_paths = set([os.path.splitext(p.split(os.sep)[-1])[0] for p in self.image_paths])
75
+ set_label_paths = set([os.path.splitext(p.split(os.sep)[-1])[0] for p in self.label_paths])
76
+ print('Path diff:', set_image_paths - set_label_paths)
77
+ raise ValueError(f"There are different numbers of images ({len(self.label_paths)}) and labels ({len(self.image_paths)})")
78
+
79
+ if self.load_all:
80
+ self.images_loaded, self.labels_loaded = [], []
81
+ self.class_labels_loaded = []
82
+ # for image_path, label_path in zip(self.image_paths, self.label_paths):
83
+ for image_path, label_path in tqdm(zip(self.image_paths, self.label_paths), total=len(self.image_paths)):
84
+ _image = path_to_image(image_path, size=self.data_size, color_type='rgb')
85
+ _label = path_to_image(label_path, size=self.data_size, color_type='gray')
86
+ self.images_loaded.append(_image)
87
+ self.labels_loaded.append(_label)
88
+ self.class_labels_loaded.append(
89
+ self.cls_name2id[label_path.split('/')[-1].split('#')[3]] if self.is_train and config.auxiliary_classification else -1
90
+ )
91
+
92
+ def __getitem__(self, index):
93
+ if self.load_all:
94
+ image = self.images_loaded[index]
95
+ label = self.labels_loaded[index]
96
+ class_label = self.class_labels_loaded[index] if self.is_train and config.auxiliary_classification else -1
97
+ else:
98
+ image = path_to_image(self.image_paths[index], size=self.data_size, color_type='rgb')
99
+ label = path_to_image(self.label_paths[index], size=self.data_size, color_type='gray')
100
+ class_label = self.cls_name2id[self.label_paths[index].split('/')[-1].split('#')[3]] if self.is_train and config.auxiliary_classification else -1
101
+
102
+ # loading image and label
103
+ if self.is_train:
104
+ if config.background_color_synthesis:
105
+ image.putalpha(label)
106
+ array_image = np.array(image)
107
+ array_foreground = array_image[:, :, :3].astype(np.float32)
108
+ array_mask = (array_image[:, :, 3:] / 255).astype(np.float32)
109
+ array_background = np.zeros_like(array_foreground)
110
+ choice = random.random()
111
+ if choice < 0.4:
112
+ # Black/Gray/White backgrounds
113
+ array_background[:, :, :] = random.randint(0, 255)
114
+ elif choice < 0.8:
115
+ # Background color that similar to the foreground object. Hard negative samples.
116
+ foreground_pixel_number = np.sum(array_mask > 0)
117
+ color_foreground_mean = np.mean(array_foreground * array_mask, axis=(0, 1)) * (np.prod(array_foreground.shape[:2]) / foreground_pixel_number)
118
+ color_up_or_down = random.choice((-1, 1))
119
+ # Up or down for 20% range from 255 or 0, respectively.
120
+ color_foreground_mean += (255 - color_foreground_mean if color_up_or_down == 1 else color_foreground_mean) * (random.random() * 0.2) * color_up_or_down
121
+ array_background[:, :, :] = color_foreground_mean
122
+ else:
123
+ # Any color
124
+ for idx_channel in range(3):
125
+ array_background[:, :, idx_channel] = random.randint(0, 255)
126
+ array_foreground_background = array_foreground * array_mask + array_background * (1 - array_mask)
127
+ image = Image.fromarray(array_foreground_background.astype(np.uint8))
128
+ image, label = preproc(image, label, preproc_methods=config.preproc_methods)
129
+ # else:
130
+ # if _label.shape[0] > 2048 or _label.shape[1] > 2048:
131
+ # _image = cv2.resize(_image, (2048, 2048), interpolation=cv2.INTER_LINEAR)
132
+ # _label = cv2.resize(_label, (2048, 2048), interpolation=cv2.INTER_LINEAR)
133
+
134
+ # At present, we use fixed sizes in inference, instead of consistent dynamic size with training.
135
+ if self.is_train:
136
+ if config.dynamic_size is None:
137
+ image, label = self.transform_image(image), self.transform_label(label)
138
+ else:
139
+ size_div_32 = (int(image.size[0] // 32 * 32), int(image.size[1] // 32 * 32))
140
+ if image.size != size_div_32:
141
+ image = image.resize(size_div_32)
142
+ label = label.resize(size_div_32)
143
+ image, label = self.transform_image(image), self.transform_label(label)
144
+
145
+ if self.is_train:
146
+ return image, label, class_label
147
+ else:
148
+ return image, label, self.label_paths[index]
149
+
150
+ def __len__(self):
151
+ return len(self.image_paths)
152
+
153
+
154
+ def custom_collate_fn(batch):
155
+ if config.dynamic_size:
156
+ dynamic_size = tuple(sorted(config.dynamic_size))
157
+ dynamic_size_batch = (random.randint(dynamic_size[0][0], dynamic_size[0][1]) // 32 * 32, random.randint(dynamic_size[1][0], dynamic_size[1][1]) // 32 * 32) # select a value randomly in the range of [dynamic_size[0/1][0], dynamic_size[0/1][1]].
158
+ data_size = dynamic_size_batch
159
+ else:
160
+ data_size = config.size
161
+ new_batch = []
162
+ transform_image = transforms.Compose([
163
+ transforms.Resize(data_size[::-1]),
164
+ transforms.ToTensor(),
165
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
166
+ ])
167
+ transform_label = transforms.Compose([
168
+ transforms.Resize(data_size[::-1]),
169
+ transforms.ToTensor(),
170
+ ])
171
+ for image, label, class_label in batch:
172
+ new_batch.append((transform_image(image), transform_label(label), class_label))
173
+ return data._utils.collate.default_collate(new_batch)