vorkna committed on
Commit
3831d63
·
verified ·
1 Parent(s): 4d05703

Upload 7 files

Browse files
Files changed (6) hide show
  1. coco_eval.py +192 -0
  2. coco_utils.py +234 -0
  3. engine.py +115 -0
  4. torchvision.ipynb +0 -0
  5. transforms.py +601 -0
  6. utils.py +282 -0
coco_eval.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import io
3
+ from contextlib import redirect_stdout
4
+
5
+ import numpy as np
6
+ import pycocotools.mask as mask_util
7
+ import torch
8
+ import utils
9
+ from pycocotools.coco import COCO
10
+ from pycocotools.cocoeval import COCOeval
11
+
12
+
13
class CocoEvaluator:
    """Accumulate detection predictions and compute COCO metrics.

    Keeps one pycocotools ``COCOeval`` per requested IoU type ("bbox",
    "segm", "keypoints"); batches are fed via :meth:`update`, merged across
    processes with :meth:`synchronize_between_processes`, then scored with
    :meth:`accumulate` / :meth:`summarize`.
    """

    def __init__(self, coco_gt, iou_types):
        if not isinstance(iou_types, (list, tuple)):
            raise TypeError(f"This constructor expects iou_types of type list or tuple, instead got {type(iou_types)}")
        # deep copy so evaluation cannot mutate the caller's ground-truth COCO object
        coco_gt = copy.deepcopy(coco_gt)
        self.coco_gt = coco_gt

        self.iou_types = iou_types
        self.coco_eval = {}
        for iou_type in iou_types:
            self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)

        self.img_ids = []
        self.eval_imgs = {k: [] for k in iou_types}

    def update(self, predictions):
        """Add one batch of predictions (mapping image_id -> model output dict)."""
        img_ids = list(np.unique(list(predictions.keys())))
        self.img_ids.extend(img_ids)

        for iou_type in self.iou_types:
            results = self.prepare(predictions, iou_type)
            # suppress pycocotools' console output while loading results
            with redirect_stdout(io.StringIO()):
                coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
            coco_eval = self.coco_eval[iou_type]

            coco_eval.cocoDt = coco_dt
            coco_eval.params.imgIds = list(img_ids)
            img_ids, eval_imgs = evaluate(coco_eval)

            self.eval_imgs[iou_type].append(eval_imgs)

    def synchronize_between_processes(self):
        """Merge per-process per-image results into each COCOeval object."""
        for iou_type in self.iou_types:
            self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
            create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])

    def accumulate(self):
        """Run ``COCOeval.accumulate`` for every IoU type."""
        for coco_eval in self.coco_eval.values():
            coco_eval.accumulate()

    def summarize(self):
        """Print the standard COCO metric table for every IoU type."""
        for iou_type, coco_eval in self.coco_eval.items():
            print(f"IoU metric: {iou_type}")
            coco_eval.summarize()

    def prepare(self, predictions, iou_type):
        """Dispatch to the COCO-result formatting helper for ``iou_type``."""
        if iou_type == "bbox":
            return self.prepare_for_coco_detection(predictions)
        if iou_type == "segm":
            return self.prepare_for_coco_segmentation(predictions)
        if iou_type == "keypoints":
            return self.prepare_for_coco_keypoint(predictions)
        raise ValueError(f"Unknown iou type {iou_type}")

    def prepare_for_coco_detection(self, predictions):
        """Flatten box predictions into COCO result dicts (boxes in xywh)."""
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            boxes = prediction["boxes"]
            boxes = convert_to_xywh(boxes).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "bbox": box,
                        "score": scores[k],
                    }
                    for k, box in enumerate(boxes)
                ]
            )
        return coco_results

    def prepare_for_coco_segmentation(self, predictions):
        """Flatten mask predictions into COCO result dicts with RLE masks."""
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            scores = prediction["scores"]
            labels = prediction["labels"]
            masks = prediction["masks"]

            # binarize the predicted soft masks
            masks = masks > 0.5

            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            # run-length encode each mask; pycocotools expects Fortran-ordered uint8
            rles = [
                mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] for mask in masks
            ]
            for rle in rles:
                rle["counts"] = rle["counts"].decode("utf-8")

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "segmentation": rle,
                        "score": scores[k],
                    }
                    for k, rle in enumerate(rles)
                ]
            )
        return coco_results

    def prepare_for_coco_keypoint(self, predictions):
        """Flatten keypoint predictions into COCO result dicts."""
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            boxes = prediction["boxes"]
            boxes = convert_to_xywh(boxes).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()
            keypoints = prediction["keypoints"]
            # flatten (K, 3) keypoints per object into a single list, as COCO expects
            keypoints = keypoints.flatten(start_dim=1).tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "keypoints": keypoint,
                        "score": scores[k],
                    }
                    for k, keypoint in enumerate(keypoints)
                ]
            )
        return coco_results
150
+
151
+
152
def convert_to_xywh(boxes):
    """Convert boxes from (xmin, ymin, xmax, ymax) to COCO's (x, y, w, h)."""
    x_min, y_min, x_max, y_max = boxes.unbind(1)
    widths = x_max - x_min
    heights = y_max - y_min
    return torch.stack((x_min, y_min, widths, heights), dim=1)
155
+
156
+
157
def merge(img_ids, eval_imgs):
    """Gather per-process image ids and eval arrays, then deduplicate.

    Returns the sorted unique image ids and the eval array restricted
    (along its last axis) to the first occurrence of each id.
    """
    gathered_ids = utils.all_gather(img_ids)
    gathered_evals = utils.all_gather(eval_imgs)

    flat_ids = []
    for per_process_ids in gathered_ids:
        flat_ids.extend(per_process_ids)

    per_process_evals = []
    for per_process_eval in gathered_evals:
        per_process_evals.append(per_process_eval)

    flat_ids = np.array(flat_ids)
    stacked_evals = np.concatenate(per_process_evals, 2)

    # keep only unique (and in sorted order) images
    unique_ids, first_occurrence = np.unique(flat_ids, return_index=True)
    stacked_evals = stacked_evals[..., first_occurrence]

    return unique_ids, stacked_evals
177
+
178
+
179
def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
    """Install merged (cross-process) results into ``coco_eval`` so that
    accumulate()/summarize() see every evaluated image exactly once."""
    merged_ids, merged_evals = merge(img_ids, eval_imgs)
    coco_eval.evalImgs = list(merged_evals.flatten())
    coco_eval.params.imgIds = list(merged_ids)
    # snapshot the params, as COCOeval.accumulate reads them from _paramsEval
    coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
187
+
188
+
189
def evaluate(imgs):
    """Run ``COCOeval.evaluate`` with its stdout suppressed.

    Returns the evaluated image ids together with the per-image results
    reshaped to ``(-1, num_area_ranges, num_images)``.
    """
    with redirect_stdout(io.StringIO()):
        imgs.evaluate()
    params = imgs.params
    eval_array = np.asarray(imgs.evalImgs).reshape(-1, len(params.areaRng), len(params.imgIds))
    return params.imgIds, eval_array
coco_utils.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import torch.utils.data
5
+ import torchvision
6
+ import transforms as T
7
+ from pycocotools import mask as coco_mask
8
+ from pycocotools.coco import COCO
9
+
10
+
11
def convert_coco_poly_to_mask(segmentations, height, width):
    """Rasterize COCO polygon segmentations into one mask tensor per object.

    Returns a (num_objects, height, width) tensor; objects whose polygons
    decode to several parts are merged with a per-pixel ``any``.
    """
    per_object_masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        decoded = coco_mask.decode(rles)
        if len(decoded.shape) < 3:
            # single-part masks decode to (H, W); add the parts axis
            decoded = decoded[..., None]
        tensor_mask = torch.as_tensor(decoded, dtype=torch.uint8)
        per_object_masks.append(tensor_mask.any(dim=2))
    if not per_object_masks:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_object_masks, dim=0)
26
+
27
+
28
class ConvertCocoPolysToMask:
    """Transform that converts raw COCO annotations into tensor targets.

    Produces "boxes" (xyxy, float32), "labels", "masks", "image_id" and,
    when present in the annotations, "keypoints". Crowd annotations and
    degenerate (non-positive width/height) boxes are dropped; "area" and
    "iscrowd" are kept for conversion back to the COCO api.
    """

    def __call__(self, image, target):
        w, h = image.size

        image_id = target["image_id"]

        anno = target["annotations"]

        # crowd regions are not used as training targets
        anno = [obj for obj in anno if obj["iscrowd"] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        # COCO stores (x, y, width, height); convert to (xmin, ymin, xmax, ymax)
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        classes = [obj["category_id"] for obj in anno]
        classes = torch.tensor(classes, dtype=torch.int64)

        segmentations = [obj["segmentation"] for obj in anno]
        masks = convert_coco_poly_to_mask(segmentations, h, w)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                # reshape flat (x, y, visibility) triplets to (num_objs, K, 3)
                keypoints = keypoints.view(num_keypoints, -1, 3)

        # drop boxes with non-positive width or height
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        masks = masks[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        target["masks"] = masks
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
        target["area"] = area
        target["iscrowd"] = iscrowd

        return image, target
81
+
82
+
83
+ def _coco_remove_images_without_annotations(dataset, cat_list=None):
84
+ def _has_only_empty_bbox(anno):
85
+ return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
86
+
87
+ def _count_visible_keypoints(anno):
88
+ return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
89
+
90
+ min_keypoints_per_image = 10
91
+
92
+ def _has_valid_annotation(anno):
93
+ # if it's empty, there is no annotation
94
+ if len(anno) == 0:
95
+ return False
96
+ # if all boxes have close to zero area, there is no annotation
97
+ if _has_only_empty_bbox(anno):
98
+ return False
99
+ # keypoints task have a slight different criteria for considering
100
+ # if an annotation is valid
101
+ if "keypoints" not in anno[0]:
102
+ return True
103
+ # for keypoint detection tasks, only consider valid images those
104
+ # containing at least min_keypoints_per_image
105
+ if _count_visible_keypoints(anno) >= min_keypoints_per_image:
106
+ return True
107
+ return False
108
+
109
+ ids = []
110
+ for ds_idx, img_id in enumerate(dataset.ids):
111
+ ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
112
+ anno = dataset.coco.loadAnns(ann_ids)
113
+ if cat_list:
114
+ anno = [obj for obj in anno if obj["category_id"] in cat_list]
115
+ if _has_valid_annotation(anno):
116
+ ids.append(ds_idx)
117
+
118
+ dataset = torch.utils.data.Subset(dataset, ids)
119
+ return dataset
120
+
121
+
122
def convert_to_coco_api(ds):
    """Build an in-memory pycocotools ``COCO`` object from a detection dataset.

    Iterates the dataset once, converting each (image, target) pair into
    COCO "images"/"annotations"/"categories" entries, then indexes the
    result so it can serve as ground truth for ``COCOeval``.
    """
    coco_ds = COCO()
    # annotation IDs need to start at 1, not 0, see torchvision issue #1530
    ann_id = 1
    dataset = {"images": [], "categories": [], "annotations": [], "info": {}}
    categories = set()
    for img_idx in range(len(ds)):
        # find better way to get target
        # targets = ds.get_annotations(img_idx)
        img, targets = ds[img_idx]
        image_id = targets["image_id"]
        img_dict = {}
        img_dict["id"] = image_id
        # assumes img is a (..., H, W) tensor — TODO confirm for non-tensor images
        img_dict["height"] = img.shape[-2]
        img_dict["width"] = img.shape[-1]
        dataset["images"].append(img_dict)
        # convert boxes back from (xmin, ymin, xmax, ymax) to COCO (x, y, w, h)
        bboxes = targets["boxes"].clone()
        bboxes[:, 2:] -= bboxes[:, :2]
        bboxes = bboxes.tolist()
        labels = targets["labels"].tolist()
        areas = targets["area"].tolist()
        iscrowd = targets["iscrowd"].tolist()
        if "masks" in targets:
            masks = targets["masks"]
            # make masks Fortran contiguous for coco_mask
            masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
        if "keypoints" in targets:
            keypoints = targets["keypoints"]
            # flatten per-object (K, 3) keypoints into COCO's flat triplet list
            keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
        num_objs = len(bboxes)
        for i in range(num_objs):
            ann = {}
            ann["image_id"] = image_id
            ann["bbox"] = bboxes[i]
            ann["category_id"] = labels[i]
            categories.add(labels[i])
            ann["area"] = areas[i]
            ann["iscrowd"] = iscrowd[i]
            ann["id"] = ann_id
            if "masks" in targets:
                ann["segmentation"] = coco_mask.encode(masks[i].numpy())
            if "keypoints" in targets:
                ann["keypoints"] = keypoints[i]
                # COCO counts a keypoint as labeled when its visibility flag != 0
                ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])
            dataset["annotations"].append(ann)
            ann_id += 1
    dataset["categories"] = [{"id": i} for i in sorted(categories)]
    coco_ds.dataset = dataset
    coco_ds.createIndex()
    return coco_ds
172
+
173
+
174
def get_coco_api_from_dataset(dataset):
    """Return the pycocotools COCO object backing ``dataset``.

    Unwraps (up to 10 levels of) ``Subset`` wrappers looking for a
    ``CocoDetection``; if none is found, builds a COCO object from scratch.
    """
    # FIXME: This is... awful?
    attempts = 0
    while attempts < 10 and not isinstance(dataset, torchvision.datasets.CocoDetection):
        if isinstance(dataset, torch.utils.data.Subset):
            dataset = dataset.dataset
        attempts += 1
    if isinstance(dataset, torchvision.datasets.CocoDetection):
        return dataset.coco
    return convert_to_coco_api(dataset)
184
+
185
+
186
class CocoDetection(torchvision.datasets.CocoDetection):
    """``torchvision`` CocoDetection variant whose targets carry the image id
    and whose transforms operate jointly on the (image, target) pair."""

    def __init__(self, img_folder, ann_file, transforms):
        super().__init__(img_folder, ann_file)
        self._transforms = transforms

    def __getitem__(self, idx):
        image, raw_annotations = super().__getitem__(idx)
        target = {"image_id": self.ids[idx], "annotations": raw_annotations}
        if self._transforms is not None:
            image, target = self._transforms(image, target)
        return image, target
198
+
199
+
200
def get_coco(root, image_set, transforms, mode="instances", use_v2=False, with_masks=False):
    """Create the COCO-2017 detection dataset for ``image_set`` ("train"/"val").

    With ``use_v2`` the dataset is wrapped for torchvision transforms v2;
    otherwise annotations are converted via ``ConvertCocoPolysToMask``.
    Training images without usable annotations are filtered out.
    """
    split_paths = {
        "train": ("train2017", os.path.join("annotations", f"{mode}_train2017.json")),
        "val": ("val2017", os.path.join("annotations", f"{mode}_val2017.json")),
    }

    split_folder, split_ann = split_paths[image_set]
    img_folder = os.path.join(root, split_folder)
    ann_file = os.path.join(root, split_ann)

    if use_v2:
        from torchvision.datasets import wrap_dataset_for_transforms_v2

        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
        target_keys = ["boxes", "labels", "image_id"]
        if with_masks:
            target_keys += ["masks"]
        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
    else:
        # TODO: handle with_masks for V1?
        pipeline = [ConvertCocoPolysToMask()]
        if transforms is not None:
            pipeline.append(transforms)
        dataset = CocoDetection(img_folder, ann_file, transforms=T.Compose(pipeline))

    if image_set == "train":
        dataset = _coco_remove_images_without_annotations(dataset)

    return dataset
engine.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import sys
3
+ import time
4
+
5
+ import torch
6
+ import torchvision.models.detection.mask_rcnn
7
+ import utils
8
+ from coco_eval import CocoEvaluator
9
+ from coco_utils import get_coco_api_from_dataset
10
+
11
+
12
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
    """Train ``model`` for one epoch over ``data_loader``.

    Applies linear LR warmup during epoch 0, optional AMP via ``scaler``,
    and aborts the whole process if the (all-reduced) loss is non-finite.
    Returns the MetricLogger with the epoch's running statistics.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = f"Epoch: [{epoch}]"

    lr_scheduler = None
    if epoch == 0:
        # linear warmup over (up to) the first 1000 iterations of the run
        warmup_factor = 1.0 / 1000
        warmup_iters = min(1000, len(data_loader) - 1)

        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
        )

    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
        # autocast only when a GradScaler was supplied (mixed-precision training)
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        loss_value = losses_reduced.item()

        if not math.isfinite(loss_value):
            # diverged — print the per-loss breakdown and stop the process
            print(f"Loss is {loss_value}, stopping training")
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()
            optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    return metric_logger
61
+
62
+
63
def _get_iou_types(model):
    """Decide which COCO metrics apply to ``model``: "bbox" always, plus
    "segm" for Mask R-CNN and "keypoints" for Keypoint R-CNN."""
    core_model = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
    iou_types = ["bbox"]
    if isinstance(core_model, torchvision.models.detection.MaskRCNN):
        iou_types.append("segm")
    if isinstance(core_model, torchvision.models.detection.KeypointRCNN):
        iou_types.append("keypoints")
    return iou_types
73
+
74
+
75
@torch.inference_mode()
def evaluate(model, data_loader, device):
    """Evaluate ``model`` on ``data_loader`` and print COCO metrics.

    Runs inference on ``device``, moves outputs to CPU for pycocotools,
    merges results across processes and returns the populated CocoEvaluator.
    """
    n_threads = torch.get_num_threads()
    # FIXME remove this and make paste_masks_in_image run on the GPU
    torch.set_num_threads(1)
    cpu_device = torch.device("cpu")
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = "Test:"

    coco = get_coco_api_from_dataset(data_loader.dataset)
    iou_types = _get_iou_types(model)
    coco_evaluator = CocoEvaluator(coco, iou_types)

    for images, targets in metric_logger.log_every(data_loader, 100, header):
        images = list(img.to(device) for img in images)

        # synchronize so model_time measures the GPU work, not just the launch
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        model_time = time.time()
        outputs = model(images)

        # pycocotools evaluation is CPU-bound; move predictions off the GPU
        outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
        model_time = time.time() - model_time

        res = {target["image_id"]: output for target, output in zip(targets, outputs)}
        evaluator_time = time.time()
        coco_evaluator.update(res)
        evaluator_time = time.time() - evaluator_time
        metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    coco_evaluator.synchronize_between_processes()

    # accumulate predictions from all images
    coco_evaluator.accumulate()
    coco_evaluator.summarize()
    torch.set_num_threads(n_threads)
    return coco_evaluator
torchvision.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
transforms.py ADDED
@@ -0,0 +1,601 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torchvision
5
+ from torch import nn, Tensor
6
+ from torchvision import ops
7
+ from torchvision.transforms import functional as F, InterpolationMode, transforms as T
8
+
9
+
10
+ def _flip_coco_person_keypoints(kps, width):
11
+ flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
12
+ flipped_data = kps[:, flip_inds]
13
+ flipped_data[..., 0] = width - flipped_data[..., 0]
14
+ # Maintain COCO convention that if visibility == 0, then x, y = 0
15
+ inds = flipped_data[..., 2] == 0
16
+ flipped_data[inds] = 0
17
+ return flipped_data
18
+
19
+
20
class Compose:
    """Chain transforms that each accept and return an (image, target) pair."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for transform in self.transforms:
            image, target = transform(image, target)
        return image, target
28
+
29
+
30
class RandomHorizontalFlip(T.RandomHorizontalFlip):
    """Horizontal flip applied (with probability ``p``) jointly to the image
    and its detection target (boxes, masks, keypoints)."""

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        if torch.rand(1) < self.p:
            image = F.hflip(image)
            if target is not None:
                _, _, width = F.get_dimensions(image)
                # mirror boxes: new xmin/xmax come from the old xmax/xmin
                target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
                if "masks" in target:
                    target["masks"] = target["masks"].flip(-1)
                if "keypoints" in target:
                    keypoints = target["keypoints"]
                    keypoints = _flip_coco_person_keypoints(keypoints, width)
                    target["keypoints"] = keypoints
        return image, target
46
+
47
+
48
class PILToTensor(nn.Module):
    """Convert the PIL image to a uint8 tensor; the target passes through."""

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        return F.pil_to_tensor(image), target
54
+
55
+
56
class ToDtype(nn.Module):
    """Cast the image to ``dtype``; with ``scale=True`` also rescale values
    to the new dtype's canonical range. The target passes through."""

    def __init__(self, dtype: torch.dtype, scale: bool = False) -> None:
        super().__init__()
        self.dtype = dtype
        self.scale = scale

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        if self.scale:
            # value-preserving conversion (e.g. uint8 [0,255] -> float [0,1])
            return F.convert_image_dtype(image, self.dtype), target
        return image.to(dtype=self.dtype), target
69
+
70
+
71
class RandomIoUCrop(nn.Module):
    """SSD-style random crop constrained by IoU with the ground-truth boxes.

    Repeatedly samples a minimum-overlap option and up to ``trials`` crop
    rectangles until one satisfies the aspect-ratio, box-center and IoU
    constraints; boxes/labels are filtered and shifted into crop coordinates.
    """

    def __init__(
        self,
        min_scale: float = 0.3,
        max_scale: float = 1.0,
        min_aspect_ratio: float = 0.5,
        max_aspect_ratio: float = 2.0,
        sampler_options: Optional[List[float]] = None,
        trials: int = 40,
    ):
        super().__init__()
        # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.min_aspect_ratio = min_aspect_ratio
        self.max_aspect_ratio = max_aspect_ratio
        if sampler_options is None:
            sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
        self.options = sampler_options
        self.trials = trials

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        if target is None:
            raise ValueError("The targets can't be None for this transform.")

        if isinstance(image, torch.Tensor):
            if image.ndimension() not in {2, 3}:
                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
            elif image.ndimension() == 2:
                image = image.unsqueeze(0)

        _, orig_h, orig_w = F.get_dimensions(image)

        while True:
            # sample an option
            idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
            min_jaccard_overlap = self.options[idx]
            if min_jaccard_overlap >= 1.0:  # a value larger than 1 encodes the leave as-is option
                return image, target

            for _ in range(self.trials):
                # check the aspect ratio limitations
                r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
                new_w = int(orig_w * r[0])
                new_h = int(orig_h * r[1])
                aspect_ratio = new_w / new_h
                if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
                    continue

                # check for 0 area crops
                r = torch.rand(2)
                left = int((orig_w - new_w) * r[0])
                top = int((orig_h - new_h) * r[1])
                right = left + new_w
                bottom = top + new_h
                if left == right or top == bottom:
                    continue

                # check for any valid boxes with centers within the crop area
                cx = 0.5 * (target["boxes"][:, 0] + target["boxes"][:, 2])
                cy = 0.5 * (target["boxes"][:, 1] + target["boxes"][:, 3])
                is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
                if not is_within_crop_area.any():
                    continue

                # check at least 1 box with jaccard limitations
                boxes = target["boxes"][is_within_crop_area]
                ious = torchvision.ops.boxes.box_iou(
                    boxes, torch.tensor([[left, top, right, bottom]], dtype=boxes.dtype, device=boxes.device)
                )
                if ious.max() < min_jaccard_overlap:
                    continue

                # keep only valid boxes and perform cropping
                target["boxes"] = boxes
                target["labels"] = target["labels"][is_within_crop_area]
                # shift boxes into crop coordinates and clip to the crop bounds
                target["boxes"][:, 0::2] -= left
                target["boxes"][:, 1::2] -= top
                target["boxes"][:, 0::2].clamp_(min=0, max=new_w)
                target["boxes"][:, 1::2].clamp_(min=0, max=new_h)
                image = F.crop(image, top, left, new_h, new_w)

                return image, target
156
+
157
+
158
class RandomZoomOut(nn.Module):
    """With probability ``p``, place the image on a larger canvas ("zoom out")
    filled with ``fill``, shifting target boxes by the placement offset."""

    def __init__(
        self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
    ):
        super().__init__()
        if fill is None:
            fill = [0.0, 0.0, 0.0]
        self.fill = fill
        self.side_range = side_range
        # the canvas must be at least as large as the image
        if side_range[0] < 1.0 or side_range[0] > side_range[1]:
            raise ValueError(f"Invalid canvas side range provided {side_range}.")
        self.p = p

    @torch.jit.unused
    def _get_fill_value(self, is_pil):
        # type: (bool) -> int
        # We fake the type to make it work on JIT
        return tuple(int(x) for x in self.fill) if is_pil else 0

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        if isinstance(image, torch.Tensor):
            if image.ndimension() not in {2, 3}:
                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
            elif image.ndimension() == 2:
                image = image.unsqueeze(0)

        if torch.rand(1) >= self.p:
            return image, target

        _, orig_h, orig_w = F.get_dimensions(image)

        # sample a canvas scale factor, then a random placement of the image on it
        r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
        canvas_width = int(orig_w * r)
        canvas_height = int(orig_h * r)

        r = torch.rand(2)
        left = int((canvas_width - orig_w) * r[0])
        top = int((canvas_height - orig_h) * r[1])
        right = canvas_width - (left + orig_w)
        bottom = canvas_height - (top + orig_h)

        if torch.jit.is_scripting():
            fill = 0
        else:
            fill = self._get_fill_value(F._is_pil_image(image))

        image = F.pad(image, [left, top, right, bottom], fill=fill)
        if isinstance(image, torch.Tensor):
            # PyTorch's pad supports only integers on fill. So we need to overwrite the colour
            v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1)
            image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h) :, :] = image[
                ..., :, (left + orig_w) :
            ] = v

        if target is not None:
            # translate boxes by the top-left padding offset
            target["boxes"][:, 0::2] += left
            target["boxes"][:, 1::2] += top

        return image, target
219
+
220
+
221
class RandomPhotometricDistort(nn.Module):
    """SSD-style photometric augmentation: random brightness, contrast,
    saturation, hue and channel permutation, each applied with probability
    ``p`` (contrast is randomly applied before or after saturation/hue)."""

    def __init__(
        self,
        contrast: Tuple[float, float] = (0.5, 1.5),
        saturation: Tuple[float, float] = (0.5, 1.5),
        hue: Tuple[float, float] = (-0.05, 0.05),
        brightness: Tuple[float, float] = (0.875, 1.125),
        p: float = 0.5,
    ):
        super().__init__()
        self._brightness = T.ColorJitter(brightness=brightness)
        self._contrast = T.ColorJitter(contrast=contrast)
        self._hue = T.ColorJitter(hue=hue)
        self._saturation = T.ColorJitter(saturation=saturation)
        self.p = p

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        if isinstance(image, torch.Tensor):
            if image.ndimension() not in {2, 3}:
                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
            elif image.ndimension() == 2:
                image = image.unsqueeze(0)

        # one random draw per possible distortion
        r = torch.rand(7)

        if r[0] < self.p:
            image = self._brightness(image)

        # apply contrast either before or after saturation/hue (50/50)
        contrast_before = r[1] < 0.5
        if contrast_before:
            if r[2] < self.p:
                image = self._contrast(image)

        if r[3] < self.p:
            image = self._saturation(image)

        if r[4] < self.p:
            image = self._hue(image)

        if not contrast_before:
            if r[5] < self.p:
                image = self._contrast(image)

        if r[6] < self.p:
            # randomly shuffle the colour channels
            channels, _, _ = F.get_dimensions(image)
            permutation = torch.randperm(channels)

            is_pil = F._is_pil_image(image)
            if is_pil:
                # channel indexing needs a tensor; round-trip PIL images
                image = F.pil_to_tensor(image)
                image = F.convert_image_dtype(image)
            image = image[..., permutation, :, :]
            if is_pil:
                image = F.to_pil_image(image)

        return image, target
279
+
280
+
281
class ScaleJitter(nn.Module):
    """Randomly resizes the image and its bounding boxes within the specified scale range.
    The class implements the Scale Jitter augmentation as described in the paper
    `"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" <https://arxiv.org/abs/2012.07177>`_.

    Args:
        target_size (tuple of ints): The target size for the transform provided in (height, weight) format.
        scale_range (tuple of ints): scaling factor interval, e.g (a, b), then scale is randomly sampled from the
            range a <= scale <= b.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
    """

    def __init__(
        self,
        target_size: Tuple[int, int],
        scale_range: Tuple[float, float] = (0.1, 2.0),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        antialias=True,
    ):
        super().__init__()
        self.target_size = target_size
        self.scale_range = scale_range
        self.interpolation = interpolation
        self.antialias = antialias

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        if isinstance(image, torch.Tensor):
            if image.ndimension() not in {2, 3}:
                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
            elif image.ndimension() == 2:
                image = image.unsqueeze(0)

        _, orig_height, orig_width = F.get_dimensions(image)

        # sample a jitter factor, then fit the image into target_size at that scale
        scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
        r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale
        new_width = int(orig_width * r)
        new_height = int(orig_height * r)

        image = F.resize(image, [new_height, new_width], interpolation=self.interpolation, antialias=self.antialias)

        if target is not None:
            # rescale box coordinates to the new resolution
            target["boxes"][:, 0::2] *= new_width / orig_width
            target["boxes"][:, 1::2] *= new_height / orig_height
            if "masks" in target:
                # masks are label maps — nearest-neighbour keeps them binary
                target["masks"] = F.resize(
                    target["masks"],
                    [new_height, new_width],
                    interpolation=InterpolationMode.NEAREST,
                    antialias=self.antialias,
                )

        return image, target
337
+
338
+
339
class FixedSizeCrop(nn.Module):
    """Produce an output of exactly ``size`` (h, w): images larger than the
    target are randomly cropped, and any remaining deficit is padded on the
    bottom/right. Boxes, labels and masks in ``target`` are adjusted to match.
    """

    def __init__(self, size, fill=0, padding_mode="constant"):
        super().__init__()
        # _setup_size normalizes an int or (h, w) pair into a 2-tuple.
        size = tuple(T._setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
        self.crop_height = size[0]
        self.crop_width = size[1]
        self.fill = fill  # TODO: Fill is currently respected only on PIL. Apply tensor patch.
        self.padding_mode = padding_mode

    def _pad(self, img, target, padding):
        """Pad image and annotations; ``padding`` is an int or a 1/2/4-element list."""
        # Taken from the functional_tensor.py pad: expand `padding` into
        # explicit left/top/right/bottom amounts.
        if isinstance(padding, int):
            pad_left = pad_right = pad_top = pad_bottom = padding
        elif len(padding) == 1:
            pad_left = pad_right = pad_top = pad_bottom = padding[0]
        elif len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        else:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]

        padding = [pad_left, pad_top, pad_right, pad_bottom]
        img = F.pad(img, padding, self.fill, self.padding_mode)
        if target is not None:
            # Shift box coordinates by the top-left padding offset
            # (x coords sit at even indices, y at odd indices).
            target["boxes"][:, 0::2] += pad_left
            target["boxes"][:, 1::2] += pad_top
            if "masks" in target:
                # Masks are always zero-padded regardless of self.fill.
                target["masks"] = F.pad(target["masks"], padding, 0, "constant")

        return img, target

    def _crop(self, img, target, top, left, height, width):
        """Crop image and annotations, dropping boxes that end up empty."""
        img = F.crop(img, top, left, height, width)
        if target is not None:
            boxes = target["boxes"]
            # Translate boxes into crop coordinates and clip to the crop bounds
            # (in-place on the target's tensor).
            boxes[:, 0::2] -= left
            boxes[:, 1::2] -= top
            boxes[:, 0::2].clamp_(min=0, max=width)
            boxes[:, 1::2].clamp_(min=0, max=height)

            # Keep only boxes that retain positive width and height.
            is_valid = (boxes[:, 0] < boxes[:, 2]) & (boxes[:, 1] < boxes[:, 3])

            target["boxes"] = boxes[is_valid]
            target["labels"] = target["labels"][is_valid]
            if "masks" in target:
                target["masks"] = F.crop(target["masks"][is_valid], top, left, height, width)

        return img, target

    def forward(self, img, target=None):
        _, height, width = F.get_dimensions(img)
        new_height = min(height, self.crop_height)
        new_width = min(width, self.crop_width)

        if new_height != height or new_width != width:
            offset_height = max(height - self.crop_height, 0)
            offset_width = max(width - self.crop_width, 0)

            # NOTE(review): a single random draw is reused for both axes, so the
            # vertical and horizontal crop offsets are correlated — confirm
            # whether independent offsets were intended.
            r = torch.rand(1)
            top = int(offset_height * r)
            left = int(offset_width * r)

            img, target = self._crop(img, target, top, left, new_height, new_width)

        # Pad only on the bottom/right so existing coordinates stay valid.
        pad_bottom = max(self.crop_height - new_height, 0)
        pad_right = max(self.crop_width - new_width, 0)
        if pad_bottom != 0 or pad_right != 0:
            img, target = self._pad(img, target, [0, 0, pad_right, pad_bottom])

        return img, target
412
+
413
+
414
class RandomShortestSize(nn.Module):
    """Resize so the image's shorter side matches a randomly chosen value from
    ``min_size``, while capping the longer side at ``max_size``; boxes and
    masks in ``target`` are rescaled accordingly.
    """

    def __init__(
        self,
        min_size: Union[List[int], Tuple[int], int],
        max_size: int,
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ):
        super().__init__()
        self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
        self.max_size = max_size
        self.interpolation = interpolation

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        _, orig_height, orig_width = F.get_dimensions(image)

        # Pick one candidate short-side length uniformly at random.
        idx = torch.randint(len(self.min_size), (1,)).item()
        chosen_min = self.min_size[idx]
        short_side = min(orig_height, orig_width)
        long_side = max(orig_height, orig_width)
        # Scale the short side to `chosen_min`, but never let the long side
        # exceed `max_size`.
        r = min(chosen_min / short_side, self.max_size / long_side)

        new_height = int(orig_height * r)
        new_width = int(orig_width * r)

        image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)

        if target is not None:
            # Even box indices are x coordinates, odd indices are y.
            target["boxes"][:, 0::2] *= new_width / orig_width
            target["boxes"][:, 1::2] *= new_height / orig_height
            if "masks" in target:
                # Nearest-neighbour keeps mask values discrete.
                target["masks"] = F.resize(
                    target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
                )

        return image, target
448
+
449
+
450
def _copy_paste(
    image: torch.Tensor,
    target: Dict[str, Tensor],
    paste_image: torch.Tensor,
    paste_target: Dict[str, Tensor],
    blending: bool = True,
    resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR,
) -> Tuple[torch.Tensor, Dict[str, Tensor]]:
    """Paste a random subset of objects from ``(paste_image, paste_target)``
    onto ``(image, target)`` (Simple Copy-Paste augmentation).

    Both targets must provide "masks", "boxes" and "labels"; "area" and
    "iscrowd" are updated when present. Returns the composited image and a new
    target dict; pre-existing objects fully occluded by the paste are dropped,
    and degenerate boxes are filtered out at the end.
    """

    # Random paste targets selection:
    num_masks = len(paste_target["masks"])

    if num_masks < 1:
        # Such degenerate case with num_masks=0 can happen with LSJ
        # Let's just return (image, target)
        return image, target

    # We have to please torch script by explicitly specifying dtype as torch.long
    # Sampling with replacement and de-duplicating yields a random-sized subset.
    random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
    random_selection = torch.unique(random_selection).to(torch.long)

    paste_masks = paste_target["masks"][random_selection]
    paste_boxes = paste_target["boxes"][random_selection]
    paste_labels = paste_target["labels"][random_selection]

    masks = target["masks"]

    # We resize source and paste data if they have different sizes
    # This is something we introduced here as originally the algorithm works
    # on equal-sized data (for example, coming from LSJ data augmentations)
    size1 = image.shape[-2:]
    size2 = paste_image.shape[-2:]
    if size1 != size2:
        paste_image = F.resize(paste_image, size1, interpolation=resize_interpolation)
        paste_masks = F.resize(paste_masks, size1, interpolation=F.InterpolationMode.NEAREST)
        # resize bboxes: scale (x, y) pairs by (width_ratio, height_ratio)
        ratios = torch.tensor((size1[1] / size2[1], size1[0] / size2[0]), device=paste_boxes.device)
        paste_boxes = paste_boxes.view(-1, 2, 2).mul(ratios).view(paste_boxes.shape)

    # Union of all pasted object masks: where True, the paste image wins.
    paste_alpha_mask = paste_masks.sum(dim=0) > 0

    if blending:
        # Soften the paste boundary with a Gaussian blur of the alpha mask.
        paste_alpha_mask = F.gaussian_blur(
            paste_alpha_mask.unsqueeze(0),
            kernel_size=(5, 5),
            sigma=[
                2.0,
            ],
        )

    # Copy-paste images:
    image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)

    # Copy-paste masks: erase original objects under the pasted region and
    # drop any that became completely hidden.
    masks = masks * (~paste_alpha_mask)
    non_all_zero_masks = masks.sum((-1, -2)) > 0
    masks = masks[non_all_zero_masks]

    # Do a shallow copy of the target dict
    out_target = {k: v for k, v in target.items()}

    out_target["masks"] = torch.cat([masks, paste_masks])

    # Copy-paste boxes and labels: recompute surviving boxes from the
    # occluded masks, then append the pasted ones.
    boxes = ops.masks_to_boxes(masks)
    out_target["boxes"] = torch.cat([boxes, paste_boxes])

    labels = target["labels"][non_all_zero_masks]
    out_target["labels"] = torch.cat([labels, paste_labels])

    # Update additional optional keys: area and iscrowd if exist
    if "area" in target:
        out_target["area"] = out_target["masks"].sum((-1, -2)).to(torch.float32)

    if "iscrowd" in target and "iscrowd" in paste_target:
        # target['iscrowd'] size can be differ from mask size (non_all_zero_masks)
        # For example, if previous transforms geometrically modifies masks/boxes/labels but
        # does not update "iscrowd"
        if len(target["iscrowd"]) == len(non_all_zero_masks):
            iscrowd = target["iscrowd"][non_all_zero_masks]
            paste_iscrowd = paste_target["iscrowd"][random_selection]
            out_target["iscrowd"] = torch.cat([iscrowd, paste_iscrowd])

    # Check for degenerated boxes and remove them
    boxes = out_target["boxes"]
    degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
    if degenerate_boxes.any():
        valid_targets = ~degenerate_boxes.any(dim=1)

        out_target["boxes"] = boxes[valid_targets]
        out_target["masks"] = out_target["masks"][valid_targets]
        out_target["labels"] = out_target["labels"][valid_targets]

        if "area" in out_target:
            out_target["area"] = out_target["area"][valid_targets]
        if "iscrowd" in out_target and len(out_target["iscrowd"]) == len(valid_targets):
            out_target["iscrowd"] = out_target["iscrowd"][valid_targets]

    return image, out_target
549
+
550
+
551
class SimpleCopyPaste(torch.nn.Module):
    """Batch-level Simple Copy-Paste augmentation: every image receives
    objects pasted from another image in the same batch via ``_copy_paste``."""

    def __init__(self, blending=True, resize_interpolation=F.InterpolationMode.BILINEAR):
        super().__init__()
        self.resize_interpolation = resize_interpolation
        self.blending = blending

    def forward(
        self, images: List[torch.Tensor], targets: List[Dict[str, Tensor]]
    ) -> Tuple[List[torch.Tensor], List[Dict[str, Tensor]]]:
        # Validate inputs with torch._assert so the checks survive scripting.
        torch._assert(
            isinstance(images, (list, tuple)) and all([isinstance(v, torch.Tensor) for v in images]),
            "images should be a list of tensors",
        )
        torch._assert(
            isinstance(targets, (list, tuple)) and len(images) == len(targets),
            "targets should be a list of the same size as images",
        )
        for target in targets:
            # Can not check for instance type dict with inside torch.jit.script
            # torch._assert(isinstance(target, dict), "targets item should be a dict")
            for k in ["masks", "boxes", "labels"]:
                torch._assert(k in target, f"Key {k} should be present in targets")
                torch._assert(isinstance(target[k], torch.Tensor), f"Value for the key {k} should be a tensor")

        # Rotate the batch by one position so that sample i pastes objects
        # from sample i-1 (sample 0 takes from the last sample).
        # FYI: in TF they mix data on the dataset level
        paste_images = images[-1:] + images[:-1]
        paste_targets = targets[-1:] + targets[:-1]

        out_images: List[torch.Tensor] = []
        out_targets: List[Dict[str, Tensor]] = []

        for image, target, paste_image, paste_target in zip(images, targets, paste_images, paste_targets):
            composited_image, composited_target = _copy_paste(
                image,
                target,
                paste_image,
                paste_target,
                blending=self.blending,
                resize_interpolation=self.resize_interpolation,
            )
            out_images.append(composited_image)
            out_targets.append(composited_target)

        return out_images, out_targets

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(blending={self.blending}, resize_interpolation={self.resize_interpolation})"
utils.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import errno
3
+ import os
4
+ import time
5
+ from collections import defaultdict, deque
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+
10
+
11
class SmoothedValue:
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        # Default display: windowed median followed by the global average.
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        """Record `value` (counted `n` times) in the window and global totals."""
        self.deque.append(value)
        self.total += value * n
        self.count += n

    def synchronize_between_processes(self):
        """Sum `count` and `total` across all distributed workers.

        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        stats = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
        dist.barrier()
        dist.all_reduce(stats)
        synced_count, synced_total = stats.tolist()
        self.count = int(synced_count)
        self.total = synced_total

    @property
    def median(self):
        """Median of the values currently in the window."""
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window."""
        return torch.tensor(list(self.deque), dtype=torch.float32).mean().item()

    @property
    def global_avg(self):
        """Mean over every value ever recorded (weighted by `n`)."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
        )
68
+
69
+
70
def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    # Single-process shortcut: nothing to gather from other ranks.
    if world_size == 1:
        return [data]
    gathered = [None for _ in range(world_size)]
    dist.all_gather_object(gathered, data)
    return gathered
84
+
85
+
86
def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    # Nothing to reduce in single-process mode.
    if world_size < 2:
        return input_dict
    with torch.inference_mode():
        # Sort the keys so every process stacks values in the same order.
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[k] for k in names], dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        return dict(zip(names, values))
111
+
112
+
113
class MetricLogger:
    """Collect named ``SmoothedValue`` meters and pretty-print training
    progress (ETA, meter values, timing, GPU memory) at a fixed frequency."""

    def __init__(self, delimiter="\t"):
        # Unknown meter names get a default SmoothedValue on first use.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword; tensors are unwrapped via .item()."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Enables `logger.loss`-style access to meters; only invoked when
        # normal attribute lookup fails.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(f"{name}: {str(meter)}")
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Reduce every meter's count/total across distributed workers."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register a meter with a custom window/format under `name`."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items from `iterable`, printing a progress line every
        `print_freq` iterations (and for the final one) plus a total-time
        summary at the end. Requires `iterable` to support len().
        """
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        # Pad the iteration counter to the width of the total count.
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        # Include max GPU memory in the line only when CUDA is present.
        if torch.cuda.is_available():
            log_msg = self.delimiter.join(
                [
                    header,
                    "[{0" + space_fmt + "}/{1}]",
                    "eta: {eta}",
                    "{meters}",
                    "time: {time}",
                    "data: {data}",
                    "max mem: {memory:.0f}",
                ]
            )
        else:
            log_msg = self.delimiter.join(
                [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
            )
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent waiting on the data source for this item.
            data_time.update(time.time() - end)
            yield obj
            # Full iteration time: data wait plus the caller's work between yields.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(f"{header} Total time: {total_time_str} ({total_time / len(iterable):.4f} s / it)")
201
+
202
+
203
def collate_fn(batch):
    """Transpose a batch of (image, target) pairs into (images, targets)."""
    transposed = zip(*batch)
    return tuple(transposed)
205
+
206
+
207
def mkdir(path):
    """Create `path` (including any missing parents) if it does not exist.

    An already-existing directory is not an error; any other OS failure
    (permissions, a file occupying the path, ...) propagates to the caller.
    """
    # exist_ok=True is the stdlib idiom for the manual errno.EEXIST check the
    # original hand-rolled; unlike the old code it still raises when `path`
    # exists but is not a directory, surfacing that misconfiguration early.
    os.makedirs(path, exist_ok=True)
213
+
214
+
215
def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__

    original_print = __builtin__.print

    def print(*args, **kwargs):
        # A caller may pass force=True to print from any rank.
        force = kwargs.pop("force", False)
        if force or is_master:
            original_print(*args, **kwargs)

    # Replace the builtin so every subsequent print() goes through the filter.
    __builtin__.print = print
229
+
230
+
231
def is_dist_avail_and_initialized():
    """True iff torch.distributed is usable and a process group has been set up."""
    return dist.is_available() and dist.is_initialized()
237
+
238
+
239
def get_world_size():
    """Number of distributed processes, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
243
+
244
+
245
def get_rank():
    """Rank of the current process, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
249
+
250
+
251
def is_main_process():
    """True on the rank-0 (master) process, and always in non-distributed runs."""
    rank = get_rank()
    return rank == 0
253
+
254
+
255
def save_on_master(*args, **kwargs):
    """torch.save, executed only on the master process to avoid duplicate writes."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
258
+
259
+
260
def init_distributed_mode(args):
    """Initialize torch.distributed from launcher environment variables
    (torchrun or SLURM), pin the process to its CUDA device, and silence
    printing on non-master ranks; falls back to non-distributed mode when no
    launcher environment is detected.

    Mutates `args`: sets rank, world_size (torchrun path only), gpu,
    distributed, dist_backend; reads args.dist_url.
    """
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        # Launched via torchrun / torch.distributed.launch.
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        # Launched under SLURM: derive rank and local GPU from the process id.
        # NOTE(review): this branch never sets args.world_size — presumably it
        # must come from the command line; confirm before relying on SLURM.
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
    torch.distributed.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
    )
    # Synchronize all ranks before handing control back to the caller.
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)