b3h-young123 committed on
Commit cc6f633 · verified · 1 Parent(s): eabb8da

Add files using upload-large-folder tool

Files changed (40)
  1. CatVTON/detectron2/data/__init__.py +19 -0
  2. CatVTON/detectron2/data/benchmark.py +225 -0
  3. CatVTON/detectron2/data/build.py +694 -0
  4. CatVTON/detectron2/data/catalog.py +236 -0
  5. CatVTON/detectron2/data/common.py +339 -0
  6. CatVTON/detectron2/data/dataset_mapper.py +191 -0
  7. CatVTON/detectron2/data/datasets/README.md +9 -0
  8. CatVTON/detectron2/data/datasets/__init__.py +9 -0
  9. CatVTON/detectron2/data/datasets/builtin.py +259 -0
  10. CatVTON/detectron2/data/datasets/builtin_meta.py +350 -0
  11. CatVTON/detectron2/data/datasets/cityscapes.py +334 -0
  12. CatVTON/detectron2/data/datasets/cityscapes_panoptic.py +187 -0
  13. CatVTON/detectron2/data/datasets/coco.py +555 -0
  14. CatVTON/detectron2/data/datasets/coco_panoptic.py +228 -0
  15. CatVTON/detectron2/data/datasets/lvis.py +250 -0
  16. CatVTON/detectron2/data/datasets/lvis_v0_5_categories.py +0 -0
  17. CatVTON/detectron2/data/datasets/lvis_v1_categories.py +0 -0
  18. CatVTON/detectron2/data/datasets/lvis_v1_category_image_count.py +20 -0
  19. CatVTON/detectron2/data/datasets/pascal_voc.py +82 -0
  20. CatVTON/detectron2/data/datasets/register_coco.py +3 -0
  21. CatVTON/detectron2/data/detection_utils.py +661 -0
  22. CatVTON/detectron2/data/samplers/__init__.py +17 -0
  23. CatVTON/detectron2/data/samplers/distributed_sampler.py +287 -0
  24. CatVTON/detectron2/data/samplers/grouped_batch_sampler.py +47 -0
  25. CatVTON/detectron2/data/transforms/__init__.py +14 -0
  26. CatVTON/detectron2/data/transforms/augmentation.py +380 -0
  27. CatVTON/detectron2/data/transforms/augmentation_impl.py +736 -0
  28. CatVTON/detectron2/data/transforms/transform.py +351 -0
  29. CatVTON/detectron2/export/README.md +15 -0
  30. CatVTON/detectron2/export/__init__.py +30 -0
  31. CatVTON/detectron2/export/api.py +230 -0
  32. CatVTON/detectron2/export/c10.py +571 -0
  33. CatVTON/detectron2/export/caffe2_export.py +203 -0
  34. CatVTON/detectron2/export/caffe2_inference.py +161 -0
  35. CatVTON/detectron2/export/caffe2_modeling.py +420 -0
  36. CatVTON/detectron2/export/caffe2_patch.py +189 -0
  37. CatVTON/detectron2/export/flatten.py +330 -0
  38. CatVTON/detectron2/export/shared.py +1040 -0
  39. CatVTON/detectron2/export/torchscript.py +132 -0
  40. CatVTON/detectron2/export/torchscript_patch.py +406 -0
CatVTON/detectron2/data/__init__.py ADDED
@@ -0,0 +1,19 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ from . import transforms  # isort:skip
+
+ from .build import (
+     build_batch_data_loader,
+     build_detection_test_loader,
+     build_detection_train_loader,
+     get_detection_dataset_dicts,
+     load_proposals_into_dataset,
+     print_instances_class_histogram,
+ )
+ from .catalog import DatasetCatalog, MetadataCatalog, Metadata
+ from .common import DatasetFromList, MapDataset, ToIterableDataset
+ from .dataset_mapper import DatasetMapper
+
+ # ensure the builtin datasets are registered
+ from . import datasets, samplers  # isort:skip
+
+ __all__ = [k for k in globals().keys() if not k.startswith("_")]
CatVTON/detectron2/data/benchmark.py ADDED
@@ -0,0 +1,225 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ import logging
+ import numpy as np
+ from itertools import count
+ from typing import List, Tuple
+ import torch
+ import tqdm
+ from fvcore.common.timer import Timer
+
+ from detectron2.utils import comm
+
+ from .build import build_batch_data_loader
+ from .common import DatasetFromList, MapDataset
+ from .samplers import TrainingSampler
+
+ logger = logging.getLogger(__name__)
+
+
+ class _EmptyMapDataset(torch.utils.data.Dataset):
+     """
+     Map anything to emptiness.
+     """
+
+     def __init__(self, dataset):
+         self.ds = dataset
+
+     def __len__(self):
+         return len(self.ds)
+
+     def __getitem__(self, idx):
+         _ = self.ds[idx]
+         return [0]
+
+
+ def iter_benchmark(
+     iterator, num_iter: int, warmup: int = 5, max_time_seconds: float = 60
+ ) -> Tuple[float, List[float]]:
+     """
+     Benchmark an iterator/iterable for `num_iter` iterations with an extra
+     `warmup` iterations of warmup.
+     End early if `max_time_seconds` time is spent on iterations.
+
+     Returns:
+         float: average time (seconds) per iteration
+         list[float]: time spent on each iteration. Sometimes useful for further analysis.
+     """
+     num_iter, warmup = int(num_iter), int(warmup)
+
+     iterator = iter(iterator)
+     for _ in range(warmup):
+         next(iterator)
+     timer = Timer()
+     all_times = []
+     for curr_iter in tqdm.trange(num_iter):
+         start = timer.seconds()
+         if start > max_time_seconds:
+             num_iter = curr_iter
+             break
+         next(iterator)
+         all_times.append(timer.seconds() - start)
+     avg = timer.seconds() / num_iter
+     return avg, all_times
+
+
+ class DataLoaderBenchmark:
+     """
+     Some common benchmarks that help understand perf bottleneck of a standard dataloader
+     made of dataset, mapper and sampler.
+     """
+
+     def __init__(
+         self,
+         dataset,
+         *,
+         mapper,
+         sampler=None,
+         total_batch_size,
+         num_workers=0,
+         max_time_seconds: int = 90,
+     ):
+         """
+         Args:
+             max_time_seconds (int): maximum time to spent for each benchmark
+             other args: same as in `build.py:build_detection_train_loader`
+         """
+         if isinstance(dataset, list):
+             dataset = DatasetFromList(dataset, copy=False, serialize=True)
+         if sampler is None:
+             sampler = TrainingSampler(len(dataset))
+
+         self.dataset = dataset
+         self.mapper = mapper
+         self.sampler = sampler
+         self.total_batch_size = total_batch_size
+         self.num_workers = num_workers
+         self.per_gpu_batch_size = self.total_batch_size // comm.get_world_size()
+
+         self.max_time_seconds = max_time_seconds
+
+     def _benchmark(self, iterator, num_iter, warmup, msg=None):
+         avg, all_times = iter_benchmark(iterator, num_iter, warmup, self.max_time_seconds)
+         if msg is not None:
+             self._log_time(msg, avg, all_times)
+         return avg, all_times
+
+     def _log_time(self, msg, avg, all_times, distributed=False):
+         percentiles = [np.percentile(all_times, k, interpolation="nearest") for k in [1, 5, 95, 99]]
+         if not distributed:
+             logger.info(
+                 f"{msg}: avg={1.0/avg:.1f} it/s, "
+                 f"p1={percentiles[0]:.2g}s, p5={percentiles[1]:.2g}s, "
+                 f"p95={percentiles[2]:.2g}s, p99={percentiles[3]:.2g}s."
+             )
+             return
+         avg_per_gpu = comm.all_gather(avg)
+         percentiles_per_gpu = comm.all_gather(percentiles)
+         if comm.get_rank() > 0:
+             return
+         for idx, avg, percentiles in zip(count(), avg_per_gpu, percentiles_per_gpu):
+             logger.info(
+                 f"GPU{idx} {msg}: avg={1.0/avg:.1f} it/s, "
+                 f"p1={percentiles[0]:.2g}s, p5={percentiles[1]:.2g}s, "
+                 f"p95={percentiles[2]:.2g}s, p99={percentiles[3]:.2g}s."
+             )
+
+     def benchmark_dataset(self, num_iter, warmup=5):
+         """
+         Benchmark the speed of taking raw samples from the dataset.
+         """
+
+         def loader():
+             while True:
+                 for k in self.sampler:
+                     yield self.dataset[k]
+
+         self._benchmark(loader(), num_iter, warmup, "Dataset Alone")
+
+     def benchmark_mapper(self, num_iter, warmup=5):
+         """
+         Benchmark the speed of taking raw samples from the dataset and map
+         them in a single process.
+         """
+
+         def loader():
+             while True:
+                 for k in self.sampler:
+                     yield self.mapper(self.dataset[k])
+
+         self._benchmark(loader(), num_iter, warmup, "Single Process Mapper (sec/sample)")
+
+     def benchmark_workers(self, num_iter, warmup=10):
+         """
+         Benchmark the dataloader by tuning num_workers to [0, 1, self.num_workers].
+         """
+         candidates = [0, 1]
+         if self.num_workers not in candidates:
+             candidates.append(self.num_workers)
+
+         dataset = MapDataset(self.dataset, self.mapper)
+         for n in candidates:
+             loader = build_batch_data_loader(
+                 dataset,
+                 self.sampler,
+                 self.total_batch_size,
+                 num_workers=n,
+             )
+             self._benchmark(
+                 iter(loader),
+                 num_iter * max(n, 1),
+                 warmup * max(n, 1),
+                 f"DataLoader ({n} workers, bs={self.per_gpu_batch_size})",
+             )
+             del loader
+
+     def benchmark_IPC(self, num_iter, warmup=10):
+         """
+         Benchmark the dataloader where each worker outputs nothing. This
+         eliminates the IPC overhead compared to the regular dataloader.
+
+         PyTorch multiprocessing's IPC only optimizes for torch tensors.
+         Large numpy arrays or other data structure may incur large IPC overhead.
+         """
+         n = self.num_workers
+         dataset = _EmptyMapDataset(MapDataset(self.dataset, self.mapper))
+         loader = build_batch_data_loader(
+             dataset, self.sampler, self.total_batch_size, num_workers=n
+         )
+         self._benchmark(
+             iter(loader),
+             num_iter * max(n, 1),
+             warmup * max(n, 1),
+             f"DataLoader ({n} workers, bs={self.per_gpu_batch_size}) w/o comm",
+         )
+
+     def benchmark_distributed(self, num_iter, warmup=10):
+         """
+         Benchmark the dataloader in each distributed worker, and log results of
+         all workers. This helps understand the final performance as well as
+         the variances among workers.
+
+         It also prints startup time (first iter) of the dataloader.
+         """
+         gpu = comm.get_world_size()
+         dataset = MapDataset(self.dataset, self.mapper)
+         n = self.num_workers
+         loader = build_batch_data_loader(
+             dataset, self.sampler, self.total_batch_size, num_workers=n
+         )
+
+         timer = Timer()
+         loader = iter(loader)
+         next(loader)
+         startup_time = timer.seconds()
+         logger.info("Dataloader startup time: {:.2f} seconds".format(startup_time))
+
+         comm.synchronize()
+
+         avg, all_times = self._benchmark(loader, num_iter * max(n, 1), warmup * max(n, 1))
+         del loader
+         self._log_time(
+             f"DataLoader ({gpu} GPUs x {n} workers, total bs={self.total_batch_size})",
+             avg,
+             all_times,
+             True,
+         )
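For orientation, a minimal sketch of how DataLoaderBenchmark might be driven; the dataset name "coco_2017_train", batch size, and worker count are illustrative assumptions rather than part of this commit:

# Hypothetical usage sketch for DataLoaderBenchmark (assumed dataset name and sizes).
from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, DatasetMapper
from detectron2.data.benchmark import DataLoaderBenchmark

cfg = get_cfg()
dataset_dicts = DatasetCatalog.get("coco_2017_train")  # assumes this dataset is registered
mapper = DatasetMapper(cfg, is_train=True)

bench = DataLoaderBenchmark(
    dataset_dicts,
    mapper=mapper,
    total_batch_size=16,
    num_workers=4,
)
bench.benchmark_dataset(100)  # raw dataset indexing speed
bench.benchmark_mapper(100)   # dataset + mapper in a single process
bench.benchmark_workers(100)  # full dataloader with 0, 1 and num_workers workers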
CatVTON/detectron2/data/build.py ADDED
@@ -0,0 +1,694 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ import itertools
+ import logging
+ import numpy as np
+ import operator
+ import pickle
+ from collections import OrderedDict, defaultdict
+ from typing import Any, Callable, Dict, List, Optional, Union
+ import torch
+ import torch.utils.data as torchdata
+ from tabulate import tabulate
+ from termcolor import colored
+
+ from detectron2.config import configurable
+ from detectron2.structures import BoxMode
+ from detectron2.utils.comm import get_world_size
+ from detectron2.utils.env import seed_all_rng
+ from detectron2.utils.file_io import PathManager
+ from detectron2.utils.logger import _log_api_usage, log_first_n
+
+ from .catalog import DatasetCatalog, MetadataCatalog
+ from .common import AspectRatioGroupedDataset, DatasetFromList, MapDataset, ToIterableDataset
+ from .dataset_mapper import DatasetMapper
+ from .detection_utils import check_metadata_consistency
+ from .samplers import (
+     InferenceSampler,
+     RandomSubsetTrainingSampler,
+     RepeatFactorTrainingSampler,
+     TrainingSampler,
+ )
+
+ """
+ This file contains the default logic to build a dataloader for training or testing.
+ """
+
+ __all__ = [
+     "build_batch_data_loader",
+     "build_detection_train_loader",
+     "build_detection_test_loader",
+     "get_detection_dataset_dicts",
+     "load_proposals_into_dataset",
+     "print_instances_class_histogram",
+ ]
+
+
+ def filter_images_with_only_crowd_annotations(dataset_dicts):
+     """
+     Filter out images with none annotations or only crowd annotations
+     (i.e., images without non-crowd annotations).
+     A common training-time preprocessing on COCO dataset.
+
+     Args:
+         dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
+
+     Returns:
+         list[dict]: the same format, but filtered.
+     """
+     num_before = len(dataset_dicts)
+
+     def valid(anns):
+         for ann in anns:
+             if ann.get("iscrowd", 0) == 0:
+                 return True
+         return False
+
+     dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"])]
+     num_after = len(dataset_dicts)
+     logger = logging.getLogger(__name__)
+     logger.info(
+         "Removed {} images with no usable annotations. {} images left.".format(
+             num_before - num_after, num_after
+         )
+     )
+     return dataset_dicts
+
+
+ def filter_images_with_few_keypoints(dataset_dicts, min_keypoints_per_image):
+     """
+     Filter out images with too few number of keypoints.
+
+     Args:
+         dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
+
+     Returns:
+         list[dict]: the same format as dataset_dicts, but filtered.
+     """
+     num_before = len(dataset_dicts)
+
+     def visible_keypoints_in_image(dic):
+         # Each keypoints field has the format [x1, y1, v1, ...], where v is visibility
+         annotations = dic["annotations"]
+         return sum(
+             (np.array(ann["keypoints"][2::3]) > 0).sum()
+             for ann in annotations
+             if "keypoints" in ann
+         )
+
+     dataset_dicts = [
+         x for x in dataset_dicts if visible_keypoints_in_image(x) >= min_keypoints_per_image
+     ]
+     num_after = len(dataset_dicts)
+     logger = logging.getLogger(__name__)
+     logger.info(
+         "Removed {} images with fewer than {} keypoints.".format(
+             num_before - num_after, min_keypoints_per_image
+         )
+     )
+     return dataset_dicts
+
+
+ def load_proposals_into_dataset(dataset_dicts, proposal_file):
+     """
+     Load precomputed object proposals into the dataset.
+
+     The proposal file should be a pickled dict with the following keys:
+
+     - "ids": list[int] or list[str], the image ids
+     - "boxes": list[np.ndarray], each is an Nx4 array of boxes corresponding to the image id
+     - "objectness_logits": list[np.ndarray], each is an N sized array of objectness scores
+       corresponding to the boxes.
+     - "bbox_mode": the BoxMode of the boxes array. Defaults to ``BoxMode.XYXY_ABS``.
+
+     Args:
+         dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
+         proposal_file (str): file path of pre-computed proposals, in pkl format.
+
+     Returns:
+         list[dict]: the same format as dataset_dicts, but added proposal field.
+     """
+     logger = logging.getLogger(__name__)
+     logger.info("Loading proposals from: {}".format(proposal_file))
+
+     with PathManager.open(proposal_file, "rb") as f:
+         proposals = pickle.load(f, encoding="latin1")
+
+     # Rename the key names in D1 proposal files
+     rename_keys = {"indexes": "ids", "scores": "objectness_logits"}
+     for key in rename_keys:
+         if key in proposals:
+             proposals[rename_keys[key]] = proposals.pop(key)
+
+     # Fetch the indexes of all proposals that are in the dataset
+     # Convert image_id to str since they could be int.
+     img_ids = set({str(record["image_id"]) for record in dataset_dicts})
+     id_to_index = {str(id): i for i, id in enumerate(proposals["ids"]) if str(id) in img_ids}
+
+     # Assuming default bbox_mode of precomputed proposals are 'XYXY_ABS'
+     bbox_mode = BoxMode(proposals["bbox_mode"]) if "bbox_mode" in proposals else BoxMode.XYXY_ABS
+
+     for record in dataset_dicts:
+         # Get the index of the proposal
+         i = id_to_index[str(record["image_id"])]
+
+         boxes = proposals["boxes"][i]
+         objectness_logits = proposals["objectness_logits"][i]
+         # Sort the proposals in descending order of the scores
+         inds = objectness_logits.argsort()[::-1]
+         record["proposal_boxes"] = boxes[inds]
+         record["proposal_objectness_logits"] = objectness_logits[inds]
+         record["proposal_bbox_mode"] = bbox_mode
+
+     return dataset_dicts
+
+
+ def print_instances_class_histogram(dataset_dicts, class_names):
+     """
+     Args:
+         dataset_dicts (list[dict]): list of dataset dicts.
+         class_names (list[str]): list of class names (zero-indexed).
+     """
+     num_classes = len(class_names)
+     hist_bins = np.arange(num_classes + 1)
+     histogram = np.zeros((num_classes,), dtype=int)
+     for entry in dataset_dicts:
+         annos = entry["annotations"]
+         classes = np.asarray(
+             [x["category_id"] for x in annos if not x.get("iscrowd", 0)], dtype=int
+         )
+         if len(classes):
+             assert classes.min() >= 0, f"Got an invalid category_id={classes.min()}"
+             assert (
+                 classes.max() < num_classes
+             ), f"Got an invalid category_id={classes.max()} for a dataset of {num_classes} classes"
+         histogram += np.histogram(classes, bins=hist_bins)[0]
+
+     N_COLS = min(6, len(class_names) * 2)
+
+     def short_name(x):
+         # make long class names shorter. useful for lvis
+         if len(x) > 13:
+             return x[:11] + ".."
+         return x
+
+     data = list(
+         itertools.chain(*[[short_name(class_names[i]), int(v)] for i, v in enumerate(histogram)])
+     )
+     total_num_instances = sum(data[1::2])
+     data.extend([None] * (N_COLS - (len(data) % N_COLS)))
+     if num_classes > 1:
+         data.extend(["total", total_num_instances])
+     data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)])
+     table = tabulate(
+         data,
+         headers=["category", "#instances"] * (N_COLS // 2),
+         tablefmt="pipe",
+         numalign="left",
+         stralign="center",
+     )
+     log_first_n(
+         logging.INFO,
+         "Distribution of instances among all {} categories:\n".format(num_classes)
+         + colored(table, "cyan"),
+         key="message",
+     )
+
+
+ def get_detection_dataset_dicts(
+     names,
+     filter_empty=True,
+     min_keypoints=0,
+     proposal_files=None,
+     check_consistency=True,
+ ):
+     """
+     Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.
+
+     Args:
+         names (str or list[str]): a dataset name or a list of dataset names
+         filter_empty (bool): whether to filter out images without instance annotations
+         min_keypoints (int): filter out images with fewer keypoints than
+             `min_keypoints`. Set to 0 to do nothing.
+         proposal_files (list[str]): if given, a list of object proposal files
+             that match each dataset in `names`.
+         check_consistency (bool): whether to check if datasets have consistent metadata.
+
+     Returns:
+         list[dict]: a list of dicts following the standard dataset dict format.
+     """
+     if isinstance(names, str):
+         names = [names]
+     assert len(names), names
+
+     available_datasets = DatasetCatalog.keys()
+     names_set = set(names)
+     if not names_set.issubset(available_datasets):
+         logger = logging.getLogger(__name__)
+         logger.warning(
+             "The following dataset names are not registered in the DatasetCatalog: "
+             f"{names_set - available_datasets}. "
+             f"Available datasets are {available_datasets}"
+         )
+
+     dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in names]
+
+     if isinstance(dataset_dicts[0], torchdata.Dataset):
+         if len(dataset_dicts) > 1:
+             # ConcatDataset does not work for iterable style dataset.
+             # We could support concat for iterable as well, but it's often
+             # not a good idea to concat iterables anyway.
+             return torchdata.ConcatDataset(dataset_dicts)
+         return dataset_dicts[0]
+
+     for dataset_name, dicts in zip(names, dataset_dicts):
+         assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
+
+     if proposal_files is not None:
+         assert len(names) == len(proposal_files)
+         # load precomputed proposals from proposal files
+         dataset_dicts = [
+             load_proposals_into_dataset(dataset_i_dicts, proposal_file)
+             for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
+         ]
+
+     dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))
+
+     has_instances = "annotations" in dataset_dicts[0]
+     if filter_empty and has_instances:
+         dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
+     if min_keypoints > 0 and has_instances:
+         dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)
+
+     if check_consistency and has_instances:
+         try:
+             class_names = MetadataCatalog.get(names[0]).thing_classes
+             check_metadata_consistency("thing_classes", names)
+             print_instances_class_histogram(dataset_dicts, class_names)
+         except AttributeError:  # class names are not available for this dataset
+             pass
+
+     assert len(dataset_dicts), "No valid data found in {}.".format(",".join(names))
+     return dataset_dicts
+
+
+ def build_batch_data_loader(
+     dataset,
+     sampler,
+     total_batch_size,
+     *,
+     aspect_ratio_grouping=False,
+     num_workers=0,
+     collate_fn=None,
+     drop_last: bool = True,
+     single_gpu_batch_size=None,
+     prefetch_factor=2,
+     persistent_workers=False,
+     pin_memory=False,
+     seed=None,
+     **kwargs,
+ ):
+     """
+     Build a batched dataloader. The main differences from `torch.utils.data.DataLoader` are:
+     1. support aspect ratio grouping options
+     2. use no "batch collation", because this is common for detection training
+
+     Args:
+         dataset (torch.utils.data.Dataset): a pytorch map-style or iterable dataset.
+         sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces indices.
+             Must be provided iff. ``dataset`` is a map-style dataset.
+         total_batch_size, aspect_ratio_grouping, num_workers, collate_fn: see
+             :func:`build_detection_train_loader`.
+         single_gpu_batch_size: You can specify either `single_gpu_batch_size` or `total_batch_size`.
+             `single_gpu_batch_size` specifies the batch size that will be used for each gpu/process.
+             `total_batch_size` allows you to specify the total aggregate batch size across gpus.
+             It is an error to supply a value for both.
+         drop_last (bool): if ``True``, the dataloader will drop incomplete batches.
+
+     Returns:
+         iterable[list]. Length of each list is the batch size of the current
+         GPU. Each element in the list comes from the dataset.
+     """
+     if single_gpu_batch_size:
+         if total_batch_size:
+             raise ValueError(
+                 """total_batch_size and single_gpu_batch_size are mutually incompatible.
+                 Please specify only one. """
+             )
+         batch_size = single_gpu_batch_size
+     else:
+         world_size = get_world_size()
+         assert (
+             total_batch_size > 0 and total_batch_size % world_size == 0
+         ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
+             total_batch_size, world_size
+         )
+         batch_size = total_batch_size // world_size
+     logger = logging.getLogger(__name__)
+     logger.info("Making batched data loader with batch_size=%d", batch_size)
+
+     if isinstance(dataset, torchdata.IterableDataset):
+         assert sampler is None, "sampler must be None if dataset is IterableDataset"
+     else:
+         dataset = ToIterableDataset(dataset, sampler, shard_chunk_size=batch_size)
+
+     generator = None
+     if seed is not None:
+         generator = torch.Generator()
+         generator.manual_seed(seed)
+
+     if aspect_ratio_grouping:
+         assert drop_last, "Aspect ratio grouping will drop incomplete batches."
+         data_loader = torchdata.DataLoader(
+             dataset,
+             num_workers=num_workers,
+             collate_fn=operator.itemgetter(0),  # don't batch, but yield individual elements
+             worker_init_fn=worker_init_reset_seed,
+             prefetch_factor=prefetch_factor if num_workers > 0 else None,
+             persistent_workers=persistent_workers,
+             pin_memory=pin_memory,
+             generator=generator,
+             **kwargs,
+         )  # yield individual mapped dict
+         data_loader = AspectRatioGroupedDataset(data_loader, batch_size)
+         if collate_fn is None:
+             return data_loader
+         return MapDataset(data_loader, collate_fn)
+     else:
+         return torchdata.DataLoader(
+             dataset,
+             batch_size=batch_size,
+             drop_last=drop_last,
+             num_workers=num_workers,
+             collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
+             worker_init_fn=worker_init_reset_seed,
+             prefetch_factor=prefetch_factor if num_workers > 0 else None,
+             persistent_workers=persistent_workers,
+             pin_memory=pin_memory,
+             generator=generator,
+             **kwargs,
+         )
+
+
+ def _get_train_datasets_repeat_factors(cfg) -> Dict[str, float]:
+     repeat_factors = cfg.DATASETS.TRAIN_REPEAT_FACTOR
+     assert all(len(tup) == 2 for tup in repeat_factors)
+     name_to_weight = defaultdict(lambda: 1, dict(repeat_factors))
+     # The sampling weights map should only contain datasets in train config
+     unrecognized = set(name_to_weight.keys()) - set(cfg.DATASETS.TRAIN)
+     assert not unrecognized, f"unrecognized datasets: {unrecognized}"
+     logger = logging.getLogger(__name__)
+     logger.info(f"Found repeat factors: {list(name_to_weight.items())}")
+
+     # pyre-fixme[7]: Expected `Dict[str, float]` but got `DefaultDict[typing.Any, int]`.
+     return name_to_weight
+
+
+ def _build_weighted_sampler(cfg, enable_category_balance=False):
+     dataset_repeat_factors = _get_train_datasets_repeat_factors(cfg)
+     # OrderedDict to guarantee order of values() consistent with repeat factors
+     dataset_name_to_dicts = OrderedDict(
+         {
+             name: get_detection_dataset_dicts(
+                 [name],
+                 filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
+                 min_keypoints=(
+                     cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
+                     if cfg.MODEL.KEYPOINT_ON
+                     else 0
+                 ),
+                 proposal_files=(
+                     cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None
+                 ),
+             )
+             for name in cfg.DATASETS.TRAIN
+         }
+     )
+     # Repeat factor for every sample in the dataset
+     repeat_factors = [
+         [dataset_repeat_factors[dsname]] * len(dataset_name_to_dicts[dsname])
+         for dsname in cfg.DATASETS.TRAIN
+     ]
+
+     repeat_factors = list(itertools.chain.from_iterable(repeat_factors))
+
+     repeat_factors = torch.tensor(repeat_factors)
+     logger = logging.getLogger(__name__)
+     if enable_category_balance:
+         """
+         1. Calculate repeat factors using category frequency for each dataset and then merge them.
+         2. Element wise dot producting the dataset frequency repeat factors with
+             the category frequency repeat factors gives the final repeat factors.
+         """
+         category_repeat_factors = [
+             RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
+                 dataset_dict, cfg.DATALOADER.REPEAT_THRESHOLD, sqrt=cfg.DATALOADER.REPEAT_SQRT
+             )
+             for dataset_dict in dataset_name_to_dicts.values()
+         ]
+         # flatten the category repeat factors from all datasets
+         category_repeat_factors = list(itertools.chain.from_iterable(category_repeat_factors))
+         category_repeat_factors = torch.tensor(category_repeat_factors)
+         repeat_factors = torch.mul(category_repeat_factors, repeat_factors)
+         repeat_factors = repeat_factors / torch.min(repeat_factors)
+         logger.info(
+             "Using WeightedCategoryTrainingSampler with repeat_factors={}".format(
+                 cfg.DATASETS.TRAIN_REPEAT_FACTOR
+             )
+         )
+     else:
+         logger.info(
+             "Using WeightedTrainingSampler with repeat_factors={}".format(
+                 cfg.DATASETS.TRAIN_REPEAT_FACTOR
+             )
+         )
+
+     sampler = RepeatFactorTrainingSampler(repeat_factors)
+     return sampler
+
+
+ def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
+     if dataset is None:
+         dataset = get_detection_dataset_dicts(
+             cfg.DATASETS.TRAIN,
+             filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
+             min_keypoints=(
+                 cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE if cfg.MODEL.KEYPOINT_ON else 0
+             ),
+             proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
+         )
+         _log_api_usage("dataset." + cfg.DATASETS.TRAIN[0])
+
+     if mapper is None:
+         mapper = DatasetMapper(cfg, True)
+
+     if sampler is None:
+         sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
+         logger = logging.getLogger(__name__)
+         if isinstance(dataset, torchdata.IterableDataset):
+             logger.info("Not using any sampler since the dataset is IterableDataset.")
+             sampler = None
+         else:
+             logger.info("Using training sampler {}".format(sampler_name))
+             if sampler_name == "TrainingSampler":
+                 sampler = TrainingSampler(len(dataset))
+             elif sampler_name == "RepeatFactorTrainingSampler":
+                 repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
+                     dataset, cfg.DATALOADER.REPEAT_THRESHOLD, sqrt=cfg.DATALOADER.REPEAT_SQRT
+                 )
+                 sampler = RepeatFactorTrainingSampler(repeat_factors, seed=cfg.SEED)
+             elif sampler_name == "RandomSubsetTrainingSampler":
+                 sampler = RandomSubsetTrainingSampler(
+                     len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO
+                 )
+             elif sampler_name == "WeightedTrainingSampler":
+                 sampler = _build_weighted_sampler(cfg)
+             elif sampler_name == "WeightedCategoryTrainingSampler":
+                 sampler = _build_weighted_sampler(cfg, enable_category_balance=True)
+             else:
+                 raise ValueError("Unknown training sampler: {}".format(sampler_name))
+
+     return {
+         "dataset": dataset,
+         "sampler": sampler,
+         "mapper": mapper,
+         "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
+         "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
+         "num_workers": cfg.DATALOADER.NUM_WORKERS,
+     }
+
+
+ @configurable(from_config=_train_loader_from_config)
+ def build_detection_train_loader(
+     dataset,
+     *,
+     mapper,
+     sampler=None,
+     total_batch_size,
+     aspect_ratio_grouping=True,
+     num_workers=0,
+     collate_fn=None,
+     **kwargs,
+ ):
+     """
+     Build a dataloader for object detection with some default features.
+
+     Args:
+         dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
+             or a pytorch dataset (either map-style or iterable). It can be obtained
+             by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
+         mapper (callable): a callable which takes a sample (dict) from dataset and
+             returns the format to be consumed by the model.
+             When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
+         sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
+             indices to be applied on ``dataset``.
+             If ``dataset`` is map-style, the default sampler is a :class:`TrainingSampler`,
+             which coordinates an infinite random shuffle sequence across all workers.
+             Sampler must be None if ``dataset`` is iterable.
+         total_batch_size (int): total batch size across all workers.
+         aspect_ratio_grouping (bool): whether to group images with similar
+             aspect ratio for efficiency. When enabled, it requires each
+             element in dataset be a dict with keys "width" and "height".
+         num_workers (int): number of parallel data loading workers
+         collate_fn: a function that determines how to do batching, same as the argument of
+             `torch.utils.data.DataLoader`. Defaults to do no collation and return a list of
+             data. No collation is OK for small batch size and simple data structures.
+             If your batch size is large and each sample contains too many small tensors,
+             it's more efficient to collate them in data loader.
+
+     Returns:
+         torch.utils.data.DataLoader:
+             a dataloader. Each output from it is a ``list[mapped_element]`` of length
+             ``total_batch_size / num_workers``, where ``mapped_element`` is produced
+             by the ``mapper``.
+     """
+     if isinstance(dataset, list):
+         dataset = DatasetFromList(dataset, copy=False)
+     if mapper is not None:
+         dataset = MapDataset(dataset, mapper)
+
+     if isinstance(dataset, torchdata.IterableDataset):
+         assert sampler is None, "sampler must be None if dataset is IterableDataset"
+     else:
+         if sampler is None:
+             sampler = TrainingSampler(len(dataset))
+         assert isinstance(sampler, torchdata.Sampler), f"Expect a Sampler but got {type(sampler)}"
+     return build_batch_data_loader(
+         dataset,
+         sampler,
+         total_batch_size,
+         aspect_ratio_grouping=aspect_ratio_grouping,
+         num_workers=num_workers,
+         collate_fn=collate_fn,
+         **kwargs,
+     )
+
+
+ def _test_loader_from_config(cfg, dataset_name, mapper=None):
+     """
+     Uses the given `dataset_name` argument (instead of the names in cfg), because the
+     standard practice is to evaluate each test set individually (not combining them).
+     """
+     if isinstance(dataset_name, str):
+         dataset_name = [dataset_name]
+
+     dataset = get_detection_dataset_dicts(
+         dataset_name,
+         filter_empty=False,
+         proposal_files=(
+             [
+                 cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)]
+                 for x in dataset_name
+             ]
+             if cfg.MODEL.LOAD_PROPOSALS
+             else None
+         ),
+     )
+     if mapper is None:
+         mapper = DatasetMapper(cfg, False)
+     return {
+         "dataset": dataset,
+         "mapper": mapper,
+         "num_workers": cfg.DATALOADER.NUM_WORKERS,
+         "sampler": (
+             InferenceSampler(len(dataset))
+             if not isinstance(dataset, torchdata.IterableDataset)
+             else None
+         ),
+     }
+
+
+ @configurable(from_config=_test_loader_from_config)
+ def build_detection_test_loader(
+     dataset: Union[List[Any], torchdata.Dataset],
+     *,
+     mapper: Callable[[Dict[str, Any]], Any],
+     sampler: Optional[torchdata.Sampler] = None,
+     batch_size: int = 1,
+     num_workers: int = 0,
+     collate_fn: Optional[Callable[[List[Any]], Any]] = None,
+ ) -> torchdata.DataLoader:
+     """
+     Similar to `build_detection_train_loader`, with default batch size = 1,
+     and sampler = :class:`InferenceSampler`. This sampler coordinates all workers
+     to produce the exact set of all samples.
+
+     Args:
+         dataset: a list of dataset dicts,
+             or a pytorch dataset (either map-style or iterable). They can be obtained
+             by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
+         mapper: a callable which takes a sample (dict) from dataset
+             and returns the format to be consumed by the model.
+             When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
+         sampler: a sampler that produces
+             indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,
+             which splits the dataset across all workers. Sampler must be None
+             if `dataset` is iterable.
+         batch_size: the batch size of the data loader to be created.
+             Default to 1 image per worker since this is the standard when reporting
+             inference time in papers.
+         num_workers: number of parallel data loading workers
+         collate_fn: same as the argument of `torch.utils.data.DataLoader`.
+             Defaults to do no collation and return a list of data.
+
+     Returns:
+         DataLoader: a torch DataLoader, that loads the given detection
+         dataset, with test-time transformation and batching.
+
+     Examples:
+     ::
+         data_loader = build_detection_test_loader(
+             DatasetRegistry.get("my_test"),
+             mapper=DatasetMapper(...))
+
+         # or, instantiate with a CfgNode:
+         data_loader = build_detection_test_loader(cfg, "my_test")
+     """
+     if isinstance(dataset, list):
+         dataset = DatasetFromList(dataset, copy=False)
+     if mapper is not None:
+         dataset = MapDataset(dataset, mapper)
+     if isinstance(dataset, torchdata.IterableDataset):
+         assert sampler is None, "sampler must be None if dataset is IterableDataset"
+     else:
+         if sampler is None:
+             sampler = InferenceSampler(len(dataset))
+     return torchdata.DataLoader(
+         dataset,
+         batch_size=batch_size,
+         sampler=sampler,
+         drop_last=False,
+         num_workers=num_workers,
+         collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
+     )
+
+
+ def trivial_batch_collator(batch):
+     """
+     A batch collator that does nothing.
+     """
+     return batch
+
+
+ def worker_init_reset_seed(worker_id):
+     initial_seed = torch.initial_seed() % 2**31
+     seed_all_rng(initial_seed + worker_id)
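The docstring above illustrates the test loader; a comparable sketch for the explicit (non-cfg) call path of build_detection_train_loader, where the dataset name "my_dataset" is an assumed, already-registered name:

# Hypothetical sketch: building a training loader without going through the cfg-based path.
# "my_dataset" is assumed to be registered in DatasetCatalog beforehand.
from detectron2.config import get_cfg
from detectron2.data import (
    DatasetMapper,
    build_detection_train_loader,
    get_detection_dataset_dicts,
)

cfg = get_cfg()
dataset_dicts = get_detection_dataset_dicts("my_dataset", filter_empty=True)
loader = build_detection_train_loader(
    dataset_dicts,
    mapper=DatasetMapper(cfg, True),
    total_batch_size=8,
    num_workers=2,
)
batch = next(iter(loader))  # a list[dict] of length 8 // world_size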
CatVTON/detectron2/data/catalog.py ADDED
@@ -0,0 +1,236 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ import copy
+ import logging
+ import types
+ from collections import UserDict
+ from typing import List
+
+ from detectron2.utils.logger import log_first_n
+
+ __all__ = ["DatasetCatalog", "MetadataCatalog", "Metadata"]
+
+
+ class _DatasetCatalog(UserDict):
+     """
+     A global dictionary that stores information about the datasets and how to obtain them.
+
+     It contains a mapping from strings
+     (which are names that identify a dataset, e.g. "coco_2014_train")
+     to a function which parses the dataset and returns the samples in the
+     format of `list[dict]`.
+
+     The returned dicts should be in Detectron2 Dataset format (See DATASETS.md for details)
+     if used with the data loader functionalities in `data/build.py,data/detection_transform.py`.
+
+     The purpose of having this catalog is to make it easy to choose
+     different datasets, by just using the strings in the config.
+     """
+
+     def register(self, name, func):
+         """
+         Args:
+             name (str): the name that identifies a dataset, e.g. "coco_2014_train".
+             func (callable): a callable which takes no arguments and returns a list of dicts.
+                 It must return the same results if called multiple times.
+         """
+         assert callable(func), "You must register a function with `DatasetCatalog.register`!"
+         assert name not in self, "Dataset '{}' is already registered!".format(name)
+         self[name] = func
+
+     def get(self, name):
+         """
+         Call the registered function and return its results.
+
+         Args:
+             name (str): the name that identifies a dataset, e.g. "coco_2014_train".
+
+         Returns:
+             list[dict]: dataset annotations.
+         """
+         try:
+             f = self[name]
+         except KeyError as e:
+             raise KeyError(
+                 "Dataset '{}' is not registered! Available datasets are: {}".format(
+                     name, ", ".join(list(self.keys()))
+                 )
+             ) from e
+         return f()
+
+     def list(self) -> List[str]:
+         """
+         List all registered datasets.
+
+         Returns:
+             list[str]
+         """
+         return list(self.keys())
+
+     def remove(self, name):
+         """
+         Alias of ``pop``.
+         """
+         self.pop(name)
+
+     def __str__(self):
+         return "DatasetCatalog(registered datasets: {})".format(", ".join(self.keys()))
+
+     __repr__ = __str__
+
+
+ DatasetCatalog = _DatasetCatalog()
+ DatasetCatalog.__doc__ = (
+     _DatasetCatalog.__doc__
+     + """
+     .. automethod:: detectron2.data.catalog.DatasetCatalog.register
+     .. automethod:: detectron2.data.catalog.DatasetCatalog.get
+ """
+ )
+
+
+ class Metadata(types.SimpleNamespace):
+     """
+     A class that supports simple attribute setter/getter.
+     It is intended for storing metadata of a dataset and make it accessible globally.
+
+     Examples:
+     ::
+         # somewhere when you load the data:
+         MetadataCatalog.get("mydataset").thing_classes = ["person", "dog"]
+
+         # somewhere when you print statistics or visualize:
+         classes = MetadataCatalog.get("mydataset").thing_classes
+     """
+
+     # the name of the dataset
+     # set default to N/A so that `self.name` in the errors will not trigger getattr again
+     name: str = "N/A"
+
+     _RENAMED = {
+         "class_names": "thing_classes",
+         "dataset_id_to_contiguous_id": "thing_dataset_id_to_contiguous_id",
+         "stuff_class_names": "stuff_classes",
+     }
+
+     def __getattr__(self, key):
+         if key in self._RENAMED:
+             log_first_n(
+                 logging.WARNING,
+                 "Metadata '{}' was renamed to '{}'!".format(key, self._RENAMED[key]),
+                 n=10,
+             )
+             return getattr(self, self._RENAMED[key])
+
+         # "name" exists in every metadata
+         if len(self.__dict__) > 1:
+             raise AttributeError(
+                 "Attribute '{}' does not exist in the metadata of dataset '{}'. Available "
+                 "keys are {}.".format(key, self.name, str(self.__dict__.keys()))
+             )
+         else:
+             raise AttributeError(
+                 f"Attribute '{key}' does not exist in the metadata of dataset '{self.name}': "
+                 "metadata is empty."
+             )
+
+     def __setattr__(self, key, val):
+         if key in self._RENAMED:
+             log_first_n(
+                 logging.WARNING,
+                 "Metadata '{}' was renamed to '{}'!".format(key, self._RENAMED[key]),
+                 n=10,
+             )
+             setattr(self, self._RENAMED[key], val)
+
+         # Ensure that metadata of the same name stays consistent
+         try:
+             oldval = getattr(self, key)
+             assert oldval == val, (
+                 "Attribute '{}' in the metadata of '{}' cannot be set "
+                 "to a different value!\n{} != {}".format(key, self.name, oldval, val)
+             )
+         except AttributeError:
+             super().__setattr__(key, val)
+
+     def as_dict(self):
+         """
+         Returns all the metadata as a dict.
+         Note that modifications to the returned dict will not reflect on the Metadata object.
+         """
+         return copy.copy(self.__dict__)
+
+     def set(self, **kwargs):
+         """
+         Set multiple metadata with kwargs.
+         """
+         for k, v in kwargs.items():
+             setattr(self, k, v)
+         return self
+
+     def get(self, key, default=None):
+         """
+         Access an attribute and return its value if exists.
+         Otherwise return default.
+         """
+         try:
+             return getattr(self, key)
+         except AttributeError:
+             return default
+
+
+ class _MetadataCatalog(UserDict):
+     """
+     MetadataCatalog is a global dictionary that provides access to
+     :class:`Metadata` of a given dataset.
+
+     The metadata associated with a certain name is a singleton: once created, the
+     metadata will stay alive and will be returned by future calls to ``get(name)``.
+
+     It's like global variables, so don't abuse it.
+     It's meant for storing knowledge that's constant and shared across the execution
+     of the program, e.g.: the class names in COCO.
+     """
+
+     def get(self, name):
+         """
+         Args:
+             name (str): name of a dataset (e.g. coco_2014_train).
+
+         Returns:
+             Metadata: The :class:`Metadata` instance associated with this name,
+             or create an empty one if none is available.
+         """
+         assert len(name)
+         r = super().get(name, None)
+         if r is None:
+             r = self[name] = Metadata(name=name)
+         return r
+
+     def list(self):
+         """
+         List all registered metadata.
+
+         Returns:
+             list[str]: keys (names of datasets) of all registered metadata
+         """
+         return list(self.keys())
+
+     def remove(self, name):
+         """
+         Alias of ``pop``.
+         """
+         self.pop(name)
+
+     def __str__(self):
+         return "MetadataCatalog(registered metadata: {})".format(", ".join(self.keys()))
+
+     __repr__ = __str__
+
+
+ MetadataCatalog = _MetadataCatalog()
+ MetadataCatalog.__doc__ = (
+     _MetadataCatalog.__doc__
+     + """
+     .. automethod:: detectron2.data.catalog.MetadataCatalog.get
+ """
+ )
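For reference, a minimal sketch of registering a custom dataset and its metadata through these catalogs; the dataset name and the loader function below are illustrative assumptions:

# Hypothetical registration sketch; names and contents are made up for illustration.
from detectron2.data import DatasetCatalog, MetadataCatalog

def load_my_dataset():
    # Return a list of dicts in the Detectron2 Dataset format.
    return [
        {"file_name": "img0.jpg", "image_id": 0, "height": 480, "width": 640, "annotations": []},
    ]

DatasetCatalog.register("my_dataset", load_my_dataset)
MetadataCatalog.get("my_dataset").thing_classes = ["person", "dog"]

dicts = DatasetCatalog.get("my_dataset")                   # calls load_my_dataset()
classes = MetadataCatalog.get("my_dataset").thing_classes  # ["person", "dog"]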
CatVTON/detectron2/data/common.py ADDED
@@ -0,0 +1,339 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import contextlib
3
+ import copy
4
+ import itertools
5
+ import logging
6
+ import numpy as np
7
+ import pickle
8
+ import random
9
+ from typing import Callable, Union
10
+ import torch
11
+ import torch.utils.data as data
12
+ from torch.utils.data.sampler import Sampler
13
+
14
+ from detectron2.utils.serialize import PicklableWrapper
15
+
16
+ __all__ = ["MapDataset", "DatasetFromList", "AspectRatioGroupedDataset", "ToIterableDataset"]
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ # copied from: https://docs.python.org/3/library/itertools.html#recipes
22
+ def _roundrobin(*iterables):
23
+ "roundrobin('ABC', 'D', 'EF') --> A D E B F C"
24
+ # Recipe credited to George Sakkis
25
+ num_active = len(iterables)
26
+ nexts = itertools.cycle(iter(it).__next__ for it in iterables)
27
+ while num_active:
28
+ try:
29
+ for next in nexts:
30
+ yield next()
31
+ except StopIteration:
32
+ # Remove the iterator we just exhausted from the cycle.
33
+ num_active -= 1
34
+ nexts = itertools.cycle(itertools.islice(nexts, num_active))
35
+
36
+
37
+ def _shard_iterator_dataloader_worker(iterable, chunk_size=1):
38
+ # Shard the iterable if we're currently inside pytorch dataloader worker.
39
+ worker_info = data.get_worker_info()
40
+ if worker_info is None or worker_info.num_workers == 1:
41
+ # do nothing
42
+ yield from iterable
43
+ else:
44
+ # worker0: 0, 1, ..., chunk_size-1, num_workers*chunk_size, num_workers*chunk_size+1, ...
45
+ # worker1: chunk_size, chunk_size+1, ...
46
+ # worker2: 2*chunk_size, 2*chunk_size+1, ...
47
+ # ...
48
+ yield from _roundrobin(
49
+ *[
50
+ itertools.islice(
51
+ iterable,
52
+ worker_info.id * chunk_size + chunk_i,
53
+ None,
54
+ worker_info.num_workers * chunk_size,
55
+ )
56
+ for chunk_i in range(chunk_size)
57
+ ]
58
+ )
59
+
60
+
61
+ class _MapIterableDataset(data.IterableDataset):
62
+ """
63
+ Map a function over elements in an IterableDataset.
64
+
65
+ Similar to pytorch's MapIterDataPipe, but support filtering when map_func
66
+ returns None.
67
+
68
+ This class is not public-facing. Will be called by `MapDataset`.
69
+ """
70
+
71
+ def __init__(self, dataset, map_func):
72
+ self._dataset = dataset
73
+ self._map_func = PicklableWrapper(map_func) # wrap so that a lambda will work
74
+
75
+ def __len__(self):
76
+ return len(self._dataset)
77
+
78
+ def __iter__(self):
79
+ for x in map(self._map_func, self._dataset):
80
+ if x is not None:
81
+ yield x
82
+
83
+
84
+ class MapDataset(data.Dataset):
85
+ """
86
+ Map a function over the elements in a dataset.
87
+ """
88
+
89
+ def __init__(self, dataset, map_func):
90
+ """
91
+ Args:
92
+ dataset: a dataset where map function is applied. Can be either
93
+ map-style or iterable dataset. When given an iterable dataset,
94
+ the returned object will also be an iterable dataset.
95
+ map_func: a callable which maps the element in dataset. map_func can
96
+ return None to skip the data (e.g. in case of errors).
97
+ How None is handled depends on the style of `dataset`.
98
+ If `dataset` is map-style, it randomly tries other elements.
99
+ If `dataset` is iterable, it skips the data and tries the next.
100
+ """
101
+ self._dataset = dataset
102
+ self._map_func = PicklableWrapper(map_func) # wrap so that a lambda will work
103
+
104
+ self._rng = random.Random(42)
105
+ self._fallback_candidates = set(range(len(dataset)))
106
+
107
+ def __new__(cls, dataset, map_func):
108
+ is_iterable = isinstance(dataset, data.IterableDataset)
109
+ if is_iterable:
110
+ return _MapIterableDataset(dataset, map_func)
111
+ else:
112
+ return super().__new__(cls)
113
+
114
+ def __getnewargs__(self):
115
+ return self._dataset, self._map_func
116
+
117
+ def __len__(self):
118
+ return len(self._dataset)
119
+
120
+ def __getitem__(self, idx):
121
+ retry_count = 0
122
+ cur_idx = int(idx)
123
+
124
+ while True:
125
+ data = self._map_func(self._dataset[cur_idx])
126
+ if data is not None:
127
+ self._fallback_candidates.add(cur_idx)
128
+ return data
129
+
130
+ # _map_func fails for this idx, use a random new index from the pool
131
+ retry_count += 1
132
+ self._fallback_candidates.discard(cur_idx)
133
+ cur_idx = self._rng.sample(self._fallback_candidates, k=1)[0]
134
+
135
+ if retry_count >= 3:
136
+ logger = logging.getLogger(__name__)
137
+ logger.warning(
138
+ "Failed to apply `_map_func` for idx: {}, retry count: {}".format(
139
+ idx, retry_count
140
+ )
141
+ )
142
+
143
+
144
+ class _TorchSerializedList:
145
+ """
146
+ A list-like object whose items are serialized and stored in a torch tensor. When
147
+ launching a process that uses TorchSerializedList with "fork" start method,
148
+ the subprocess can read the same buffer without triggering copy-on-access. When
149
+ launching a process that uses TorchSerializedList with "spawn/forkserver" start
150
+ method, the list will be pickled by a special ForkingPickler registered by PyTorch
151
+ that moves data to shared memory. In both cases, this allows parent and child
152
+ processes to share RAM for the list data, hence avoids the issue in
153
+ https://github.com/pytorch/pytorch/issues/13246.
154
+
155
+ See also https://ppwwyyxx.com/blog/2022/Demystify-RAM-Usage-in-Multiprocess-DataLoader/
156
+ on how it works.
157
+ """
158
+
159
+ def __init__(self, lst: list):
160
+ self._lst = lst
161
+
162
+ def _serialize(data):
163
+ buffer = pickle.dumps(data, protocol=-1)
164
+ return np.frombuffer(buffer, dtype=np.uint8)
165
+
166
+ logger.info(
167
+ "Serializing {} elements to byte tensors and concatenating them all ...".format(
168
+ len(self._lst)
169
+ )
170
+ )
171
+ self._lst = [_serialize(x) for x in self._lst]
172
+ self._addr = np.asarray([len(x) for x in self._lst], dtype=np.int64)
173
+ self._addr = torch.from_numpy(np.cumsum(self._addr))
174
+ self._lst = torch.from_numpy(np.concatenate(self._lst))
175
+ logger.info("Serialized dataset takes {:.2f} MiB".format(len(self._lst) / 1024**2))
176
+
177
+ def __len__(self):
178
+ return len(self._addr)
179
+
180
+ def __getitem__(self, idx):
181
+ start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
182
+ end_addr = self._addr[idx].item()
183
+ bytes = memoryview(self._lst[start_addr:end_addr].numpy())
184
+
185
+ # @lint-ignore PYTHONPICKLEISBAD
186
+ return pickle.loads(bytes)
187
+
188
+
189
+ _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = _TorchSerializedList
190
+
191
+
192
+ @contextlib.contextmanager
193
+ def set_default_dataset_from_list_serialize_method(new):
194
+ """
195
+ Context manager for using custom serialize function when creating DatasetFromList
196
+ """
197
+
198
+ global _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD
199
+ orig = _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD
200
+ _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = new
201
+ yield
202
+ _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = orig
203
+
204
+
205
+ class DatasetFromList(data.Dataset):
206
+ """
207
+ Wrap a list to a torch Dataset. It produces elements of the list as data.
208
+ """
209
+
210
+ def __init__(
211
+ self,
212
+ lst: list,
213
+ copy: bool = True,
214
+ serialize: Union[bool, Callable] = True,
215
+ ):
216
+ """
217
+ Args:
218
+ lst (list): a list which contains elements to produce.
219
+ copy (bool): whether to deepcopy the element when producing it,
220
+ so that the result can be modified in place without affecting the
221
+ source in the list.
222
+ serialize (bool or callable): whether to serialize the stroage to other
223
+ backend. If `True`, the default serialize method will be used, if given
224
+ a callable, the callable will be used as serialize method.
225
+ """
226
+ self._lst = lst
227
+ self._copy = copy
228
+ if not isinstance(serialize, (bool, Callable)):
229
+ raise TypeError(f"Unsupported type for argument `serailzie`: {serialize}")
230
+ self._serialize = serialize is not False
231
+
232
+ if self._serialize:
233
+ serialize_method = (
234
+ serialize
235
+ if isinstance(serialize, Callable)
236
+ else _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD
237
+ )
238
+ logger.info(f"Serializing the dataset using: {serialize_method}")
239
+ self._lst = serialize_method(self._lst)
240
+
241
+ def __len__(self):
242
+ return len(self._lst)
243
+
244
+ def __getitem__(self, idx):
245
+ if self._copy and not self._serialize:
246
+ return copy.deepcopy(self._lst[idx])
247
+ else:
248
+ return self._lst[idx]
249
+
250
+
251
+ class ToIterableDataset(data.IterableDataset):
252
+ """
253
+ Convert an old indices-based (also called map-style) dataset
254
+ to an iterable-style dataset.
255
+ """
256
+
257
+ def __init__(
258
+ self,
259
+ dataset: data.Dataset,
260
+ sampler: Sampler,
261
+ shard_sampler: bool = True,
262
+ shard_chunk_size: int = 1,
263
+ ):
264
+ """
265
+ Args:
266
+ dataset: an old-style dataset with ``__getitem__``
267
+ sampler: a cheap iterable that produces indices to be applied on ``dataset``.
268
+ shard_sampler: whether to shard the sampler based on the current pytorch data loader
269
+ worker id. When an IterableDataset is forked by pytorch's DataLoader into multiple
270
+ workers, it is responsible for sharding its data based on worker id so that workers
271
+ don't produce identical data.
272
+
273
+ Most samplers (like our TrainingSampler) do not shard based on dataloader worker id
274
+ and this argument should be set to True. But certain samplers may be already
275
+ sharded, in that case this argument should be set to False.
276
+ shard_chunk_size: when sharding the sampler, each worker will take contiguous chunks of this many indices from the sampler at a time.
277
+ """
278
+ assert not isinstance(dataset, data.IterableDataset), dataset
279
+ assert isinstance(sampler, Sampler), sampler
280
+ self.dataset = dataset
281
+ self.sampler = sampler
282
+ self.shard_sampler = shard_sampler
283
+ self.shard_chunk_size = shard_chunk_size
284
+
285
+ def __iter__(self):
286
+ if not self.shard_sampler:
287
+ sampler = self.sampler
288
+ else:
289
+ # With map-style dataset, `DataLoader(dataset, sampler)` runs the
290
+ # sampler in main process only. But `DataLoader(ToIterableDataset(dataset, sampler))`
291
+ # will run the sampler in each of the N workers. So we should only keep 1/N of the ids on
292
+ # each worker. The assumption is that sampler is cheap to iterate so it's fine to
293
+ # discard ids in workers.
294
+ sampler = _shard_iterator_dataloader_worker(self.sampler, self.shard_chunk_size)
295
+ for idx in sampler:
296
+ yield self.dataset[idx]
297
+
298
+ def __len__(self):
299
+ return len(self.sampler)
300
+
301
+
302
+ class AspectRatioGroupedDataset(data.IterableDataset):
303
+ """
304
+ Batch data that have similar aspect ratio together.
305
+ In this implementation, images whose aspect ratio < (or >) 1 will
306
+ be batched together.
307
+ This improves training speed because the images then need less padding
308
+ to form a batch.
309
+
310
+ It assumes the underlying dataset produces dicts with "width" and "height" keys.
311
+ It will then produce a list of original dicts with length = batch_size,
312
+ all with similar aspect ratios.
313
+ """
314
+
315
+ def __init__(self, dataset, batch_size):
316
+ """
317
+ Args:
318
+ dataset: an iterable. Each element must be a dict with keys
319
+ "width" and "height", which will be used to batch data.
320
+ batch_size (int):
321
+ """
322
+ self.dataset = dataset
323
+ self.batch_size = batch_size
324
+ self._buckets = [[] for _ in range(2)]
325
+ # Hard-coded two aspect ratio groups: w > h and w < h.
326
+ # Can add support for more aspect ratio groups, but doesn't seem useful
327
+
328
+ def __iter__(self):
329
+ for d in self.dataset:
330
+ w, h = d["width"], d["height"]
331
+ bucket_id = 0 if w > h else 1
332
+ bucket = self._buckets[bucket_id]
333
+ bucket.append(d)
334
+ if len(bucket) == self.batch_size:
335
+ data = bucket[:]
336
+ # Clear bucket first, because code after yield is not
337
+ # guaranteed to execute
338
+ del bucket[:]
339
+ yield data
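For reference, the pieces above compose into a small data pipeline: wrap a list of dicts, iterate it with a sampler, then group by aspect ratio. A minimal sketch, assuming the upstream import path detectron2.data.common (the toy dicts and batch size below are placeholders):

    import torch.utils.data as torchdata

    from detectron2.data.common import (
        AspectRatioGroupedDataset,
        DatasetFromList,
        ToIterableDataset,
    )

    # Toy "dataset dicts"; real ones would also carry file_name, annotations, etc.
    dicts = [{"width": 640, "height": 480}, {"width": 480, "height": 640}] * 4

    dataset = DatasetFromList(dicts, copy=False, serialize=False)
    sampler = torchdata.SequentialSampler(dataset)   # any torch Sampler works here
    iterable = ToIterableDataset(dataset, sampler)    # map-style -> iterable-style
    batched = AspectRatioGroupedDataset(iterable, batch_size=2)

    for batch in batched:
        # each batch holds dicts from one aspect-ratio bucket (w > h vs. w <= h)
        print([(d["width"], d["height"]) for d in batch])
        break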
CatVTON/detectron2/data/dataset_mapper.py ADDED
@@ -0,0 +1,191 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import copy
3
+ import logging
4
+ import numpy as np
5
+ from typing import List, Optional, Union
6
+ import torch
7
+
8
+ from detectron2.config import configurable
9
+
10
+ from . import detection_utils as utils
11
+ from . import transforms as T
12
+
13
+ """
14
+ This file contains the default mapping that's applied to "dataset dicts".
15
+ """
16
+
17
+ __all__ = ["DatasetMapper"]
18
+
19
+
20
+ class DatasetMapper:
21
+ """
22
+ A callable which takes a dataset dict in Detectron2 Dataset format,
23
+ and maps it into a format used by the model.
24
+
25
+ This is the default callable to be used to map your dataset dict into training data.
26
+ You may want to follow it to implement your own mapper for customized logic,
27
+ such as a different way to read or transform images.
28
+ See :doc:`/tutorials/data_loading` for details.
29
+
30
+ The callable currently does the following:
31
+
32
+ 1. Reads the image from "file_name"
33
+ 2. Applies cropping/geometric transforms to the image and annotations
34
+ 3. Converts data and annotations to Tensor and :class:`Instances`
35
+ """
36
+
37
+ @configurable
38
+ def __init__(
39
+ self,
40
+ is_train: bool,
41
+ *,
42
+ augmentations: List[Union[T.Augmentation, T.Transform]],
43
+ image_format: str,
44
+ use_instance_mask: bool = False,
45
+ use_keypoint: bool = False,
46
+ instance_mask_format: str = "polygon",
47
+ keypoint_hflip_indices: Optional[np.ndarray] = None,
48
+ precomputed_proposal_topk: Optional[int] = None,
49
+ recompute_boxes: bool = False,
50
+ ):
51
+ """
52
+ NOTE: this interface is experimental.
53
+
54
+ Args:
55
+ is_train: whether it's used in training or inference
56
+ augmentations: a list of augmentations or deterministic transforms to apply
57
+ image_format: an image format supported by :func:`detection_utils.read_image`.
58
+ use_instance_mask: whether to process instance segmentation annotations, if available
59
+ use_keypoint: whether to process keypoint annotations if available
60
+ instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
61
+ masks into this format.
62
+ keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
63
+ precomputed_proposal_topk: if given, will load pre-computed
64
+ proposals from dataset_dict and keep the top k proposals for each image.
65
+ recompute_boxes: whether to overwrite bounding box annotations
66
+ by computing tight bounding boxes from instance mask annotations.
67
+ """
68
+ if recompute_boxes:
69
+ assert use_instance_mask, "recompute_boxes requires instance masks"
70
+ # fmt: off
71
+ self.is_train = is_train
72
+ self.augmentations = T.AugmentationList(augmentations)
73
+ self.image_format = image_format
74
+ self.use_instance_mask = use_instance_mask
75
+ self.instance_mask_format = instance_mask_format
76
+ self.use_keypoint = use_keypoint
77
+ self.keypoint_hflip_indices = keypoint_hflip_indices
78
+ self.proposal_topk = precomputed_proposal_topk
79
+ self.recompute_boxes = recompute_boxes
80
+ # fmt: on
81
+ logger = logging.getLogger(__name__)
82
+ mode = "training" if is_train else "inference"
83
+ logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")
84
+
85
+ @classmethod
86
+ def from_config(cls, cfg, is_train: bool = True):
87
+ augs = utils.build_augmentation(cfg, is_train)
88
+ if cfg.INPUT.CROP.ENABLED and is_train:
89
+ augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
90
+ recompute_boxes = cfg.MODEL.MASK_ON
91
+ else:
92
+ recompute_boxes = False
93
+
94
+ ret = {
95
+ "is_train": is_train,
96
+ "augmentations": augs,
97
+ "image_format": cfg.INPUT.FORMAT,
98
+ "use_instance_mask": cfg.MODEL.MASK_ON,
99
+ "instance_mask_format": cfg.INPUT.MASK_FORMAT,
100
+ "use_keypoint": cfg.MODEL.KEYPOINT_ON,
101
+ "recompute_boxes": recompute_boxes,
102
+ }
103
+
104
+ if cfg.MODEL.KEYPOINT_ON:
105
+ ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
106
+
107
+ if cfg.MODEL.LOAD_PROPOSALS:
108
+ ret["precomputed_proposal_topk"] = (
109
+ cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
110
+ if is_train
111
+ else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
112
+ )
113
+ return ret
114
+
115
+ def _transform_annotations(self, dataset_dict, transforms, image_shape):
116
+ # USER: Modify this if you want to keep them for some reason.
117
+ for anno in dataset_dict["annotations"]:
118
+ if not self.use_instance_mask:
119
+ anno.pop("segmentation", None)
120
+ if not self.use_keypoint:
121
+ anno.pop("keypoints", None)
122
+
123
+ # USER: Implement additional transformations if you have other types of data
124
+ annos = [
125
+ utils.transform_instance_annotations(
126
+ obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
127
+ )
128
+ for obj in dataset_dict.pop("annotations")
129
+ if obj.get("iscrowd", 0) == 0
130
+ ]
131
+ instances = utils.annotations_to_instances(
132
+ annos, image_shape, mask_format=self.instance_mask_format
133
+ )
134
+
135
+ # After transforms such as cropping are applied, the bounding box may no longer
136
+ # tightly bound the object. As an example, imagine a triangle object
137
+ # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
138
+ # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
139
+ # the intersection of original bounding box and the cropping box.
140
+ if self.recompute_boxes:
141
+ instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
142
+ dataset_dict["instances"] = utils.filter_empty_instances(instances)
143
+
144
+ def __call__(self, dataset_dict):
145
+ """
146
+ Args:
147
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
148
+
149
+ Returns:
150
+ dict: a format that builtin models in detectron2 accept
151
+ """
152
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
153
+ # USER: Write your own image loading if it's not from a file
154
+ image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
155
+ utils.check_image_size(dataset_dict, image)
156
+
157
+ # USER: Remove if you don't do semantic/panoptic segmentation.
158
+ if "sem_seg_file_name" in dataset_dict:
159
+ sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
160
+ else:
161
+ sem_seg_gt = None
162
+
163
+ aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
164
+ transforms = self.augmentations(aug_input)
165
+ image, sem_seg_gt = aug_input.image, aug_input.sem_seg
166
+
167
+ image_shape = image.shape[:2] # h, w
168
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
169
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
170
+ # Therefore it's important to use torch.Tensor.
171
+ dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
172
+ if sem_seg_gt is not None:
173
+ dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
174
+
175
+ # USER: Remove if you don't use pre-computed proposals.
176
+ # Most users would not need this feature.
177
+ if self.proposal_topk is not None:
178
+ utils.transform_proposals(
179
+ dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
180
+ )
181
+
182
+ if not self.is_train:
183
+ # USER: Modify this if you want to keep them for some reason.
184
+ dataset_dict.pop("annotations", None)
185
+ dataset_dict.pop("sem_seg_file_name", None)
186
+ return dataset_dict
187
+
188
+ if "annotations" in dataset_dict:
189
+ self._transform_annotations(dataset_dict, transforms, image_shape)
190
+
191
+ return dataset_dict
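For reference, the mapper can also be built without a config object by passing the constructor arguments above explicitly. A minimal sketch (detectron2 is assumed to be installed; the temporary image, the resize augmentation, and all values are placeholders chosen for illustration):

    import os
    import tempfile

    import numpy as np
    from PIL import Image

    from detectron2.data.dataset_mapper import DatasetMapper
    from detectron2.data import transforms as T

    # Write a small dummy image so that "file_name" points at a real file.
    path = os.path.join(tempfile.mkdtemp(), "toy.png")
    Image.fromarray(np.zeros((48, 64, 3), dtype=np.uint8)).save(path)

    mapper = DatasetMapper(
        is_train=False,
        augmentations=[T.ResizeShortestEdge(short_edge_length=32, max_size=64)],
        image_format="RGB",
    )
    out = mapper({"file_name": path, "height": 48, "width": 64, "image_id": 0})
    print(out["image"].shape)  # CHW uint8 tensor after the resize augmentation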
CatVTON/detectron2/data/datasets/README.md ADDED
@@ -0,0 +1,9 @@
1
+
2
+
3
+ ### Common Datasets
4
+
5
+ The datasets implemented here do not need to load the data into the final format.
6
+ They should provide the minimal data structure needed to use the dataset, so loading can be very efficient.
7
+
8
+ For example, for an image dataset, just provide the file names and labels, but don't read the images.
9
+ Let the downstream decide how to read.
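Concretely, that convention looks like the sketch below: the dataset function returns lightweight dicts only and is registered by name; image bytes are read later by the mapper. The file path and dataset name are placeholders:

    from detectron2.data import DatasetCatalog

    def get_my_dicts():
        # Only file names and lightweight labels here; no image bytes are read.
        return [
            {
                "file_name": "/data/my_dataset/img_0001.jpg",
                "height": 480,
                "width": 640,
                "image_id": 0,
                "annotations": [],
            }
        ]

    DatasetCatalog.register("my_dataset_train", get_my_dicts)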
CatVTON/detectron2/data/datasets/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from .coco import load_coco_json, load_sem_seg, register_coco_instances, convert_to_coco_json
3
+ from .coco_panoptic import register_coco_panoptic, register_coco_panoptic_separated
4
+ from .lvis import load_lvis_json, register_lvis_instances, get_lvis_instances_meta
5
+ from .pascal_voc import load_voc_instances, register_pascal_voc
6
+ from . import builtin as _builtin # ensure the builtin datasets are registered
7
+
8
+
9
+ __all__ = [k for k in globals().keys() if not k.startswith("_")]
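The re-exported loaders can also be called directly; for example, load_coco_json turns a COCO-format json into a list of dataset dicts. A small sketch (both paths are placeholders):

    from detectron2.data.datasets import load_coco_json

    dicts = load_coco_json("/path/to/instances_val.json", "/path/to/val_images")
    print(len(dicts), dicts[0]["file_name"])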
CatVTON/detectron2/data/datasets/builtin.py ADDED
@@ -0,0 +1,259 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+
5
+ """
6
+ This file registers pre-defined datasets at hard-coded paths, and their metadata.
7
+
8
+ We hard-code metadata for common datasets. This will enable:
9
+ 1. Consistency check when loading the datasets
10
+ 2. Use models on these standard datasets directly and run demos,
11
+ without having to download the dataset annotations
12
+
13
+ We hard-code some paths to the dataset that's assumed to
14
+ exist in "./datasets/".
15
+
16
+ Users SHOULD NOT use this file to create new datasets / metadata for new datasets.
17
+ To add new datasets, refer to the tutorial "docs/DATASETS.md".
18
+ """
19
+
20
+ import os
21
+
22
+ from detectron2.data import DatasetCatalog, MetadataCatalog
23
+
24
+ from .builtin_meta import ADE20K_SEM_SEG_CATEGORIES, _get_builtin_metadata
25
+ from .cityscapes import load_cityscapes_instances, load_cityscapes_semantic
26
+ from .cityscapes_panoptic import register_all_cityscapes_panoptic
27
+ from .coco import load_sem_seg, register_coco_instances
28
+ from .coco_panoptic import register_coco_panoptic, register_coco_panoptic_separated
29
+ from .lvis import get_lvis_instances_meta, register_lvis_instances
30
+ from .pascal_voc import register_pascal_voc
31
+
32
+ # ==== Predefined datasets and splits for COCO ==========
33
+
34
+ _PREDEFINED_SPLITS_COCO = {}
35
+ _PREDEFINED_SPLITS_COCO["coco"] = {
36
+ "coco_2014_train": ("coco/train2014", "coco/annotations/instances_train2014.json"),
37
+ "coco_2014_val": ("coco/val2014", "coco/annotations/instances_val2014.json"),
38
+ "coco_2014_minival": ("coco/val2014", "coco/annotations/instances_minival2014.json"),
39
+ "coco_2014_valminusminival": (
40
+ "coco/val2014",
41
+ "coco/annotations/instances_valminusminival2014.json",
42
+ ),
43
+ "coco_2017_train": ("coco/train2017", "coco/annotations/instances_train2017.json"),
44
+ "coco_2017_val": ("coco/val2017", "coco/annotations/instances_val2017.json"),
45
+ "coco_2017_test": ("coco/test2017", "coco/annotations/image_info_test2017.json"),
46
+ "coco_2017_test-dev": ("coco/test2017", "coco/annotations/image_info_test-dev2017.json"),
47
+ "coco_2017_val_100": ("coco/val2017", "coco/annotations/instances_val2017_100.json"),
48
+ }
49
+
50
+ _PREDEFINED_SPLITS_COCO["coco_person"] = {
51
+ "keypoints_coco_2014_train": (
52
+ "coco/train2014",
53
+ "coco/annotations/person_keypoints_train2014.json",
54
+ ),
55
+ "keypoints_coco_2014_val": ("coco/val2014", "coco/annotations/person_keypoints_val2014.json"),
56
+ "keypoints_coco_2014_minival": (
57
+ "coco/val2014",
58
+ "coco/annotations/person_keypoints_minival2014.json",
59
+ ),
60
+ "keypoints_coco_2014_valminusminival": (
61
+ "coco/val2014",
62
+ "coco/annotations/person_keypoints_valminusminival2014.json",
63
+ ),
64
+ "keypoints_coco_2017_train": (
65
+ "coco/train2017",
66
+ "coco/annotations/person_keypoints_train2017.json",
67
+ ),
68
+ "keypoints_coco_2017_val": ("coco/val2017", "coco/annotations/person_keypoints_val2017.json"),
69
+ "keypoints_coco_2017_val_100": (
70
+ "coco/val2017",
71
+ "coco/annotations/person_keypoints_val2017_100.json",
72
+ ),
73
+ }
74
+
75
+
76
+ _PREDEFINED_SPLITS_COCO_PANOPTIC = {
77
+ "coco_2017_train_panoptic": (
78
+ # This is the original panoptic annotation directory
79
+ "coco/panoptic_train2017",
80
+ "coco/annotations/panoptic_train2017.json",
81
+ # This directory contains semantic annotations that are
82
+ # converted from panoptic annotations.
83
+ # It is used by PanopticFPN.
84
+ # You can use the script at detectron2/datasets/prepare_panoptic_fpn.py
85
+ # to create these directories.
86
+ "coco/panoptic_stuff_train2017",
87
+ ),
88
+ "coco_2017_val_panoptic": (
89
+ "coco/panoptic_val2017",
90
+ "coco/annotations/panoptic_val2017.json",
91
+ "coco/panoptic_stuff_val2017",
92
+ ),
93
+ "coco_2017_val_100_panoptic": (
94
+ "coco/panoptic_val2017_100",
95
+ "coco/annotations/panoptic_val2017_100.json",
96
+ "coco/panoptic_stuff_val2017_100",
97
+ ),
98
+ }
99
+
100
+
101
+ def register_all_coco(root):
102
+ for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_COCO.items():
103
+ for key, (image_root, json_file) in splits_per_dataset.items():
104
+ # Assume pre-defined datasets live in `./datasets`.
105
+ register_coco_instances(
106
+ key,
107
+ _get_builtin_metadata(dataset_name),
108
+ os.path.join(root, json_file) if "://" not in json_file else json_file,
109
+ os.path.join(root, image_root),
110
+ )
111
+
112
+ for (
113
+ prefix,
114
+ (panoptic_root, panoptic_json, semantic_root),
115
+ ) in _PREDEFINED_SPLITS_COCO_PANOPTIC.items():
116
+ prefix_instances = prefix[: -len("_panoptic")]
117
+ instances_meta = MetadataCatalog.get(prefix_instances)
118
+ image_root, instances_json = instances_meta.image_root, instances_meta.json_file
119
+ # The "separated" version of COCO panoptic segmentation dataset,
120
+ # e.g. used by Panoptic FPN
121
+ register_coco_panoptic_separated(
122
+ prefix,
123
+ _get_builtin_metadata("coco_panoptic_separated"),
124
+ image_root,
125
+ os.path.join(root, panoptic_root),
126
+ os.path.join(root, panoptic_json),
127
+ os.path.join(root, semantic_root),
128
+ instances_json,
129
+ )
130
+ # The "standard" version of COCO panoptic segmentation dataset,
131
+ # e.g. used by Panoptic-DeepLab
132
+ register_coco_panoptic(
133
+ prefix,
134
+ _get_builtin_metadata("coco_panoptic_standard"),
135
+ image_root,
136
+ os.path.join(root, panoptic_root),
137
+ os.path.join(root, panoptic_json),
138
+ instances_json,
139
+ )
140
+
141
+
142
+ # ==== Predefined datasets and splits for LVIS ==========
143
+
144
+
145
+ _PREDEFINED_SPLITS_LVIS = {
146
+ "lvis_v1": {
147
+ "lvis_v1_train": ("coco/", "lvis/lvis_v1_train.json"),
148
+ "lvis_v1_val": ("coco/", "lvis/lvis_v1_val.json"),
149
+ "lvis_v1_test_dev": ("coco/", "lvis/lvis_v1_image_info_test_dev.json"),
150
+ "lvis_v1_test_challenge": ("coco/", "lvis/lvis_v1_image_info_test_challenge.json"),
151
+ },
152
+ "lvis_v0.5": {
153
+ "lvis_v0.5_train": ("coco/", "lvis/lvis_v0.5_train.json"),
154
+ "lvis_v0.5_val": ("coco/", "lvis/lvis_v0.5_val.json"),
155
+ "lvis_v0.5_val_rand_100": ("coco/", "lvis/lvis_v0.5_val_rand_100.json"),
156
+ "lvis_v0.5_test": ("coco/", "lvis/lvis_v0.5_image_info_test.json"),
157
+ },
158
+ "lvis_v0.5_cocofied": {
159
+ "lvis_v0.5_train_cocofied": ("coco/", "lvis/lvis_v0.5_train_cocofied.json"),
160
+ "lvis_v0.5_val_cocofied": ("coco/", "lvis/lvis_v0.5_val_cocofied.json"),
161
+ },
162
+ }
163
+
164
+
165
+ def register_all_lvis(root):
166
+ for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_LVIS.items():
167
+ for key, (image_root, json_file) in splits_per_dataset.items():
168
+ register_lvis_instances(
169
+ key,
170
+ get_lvis_instances_meta(dataset_name),
171
+ os.path.join(root, json_file) if "://" not in json_file else json_file,
172
+ os.path.join(root, image_root),
173
+ )
174
+
175
+
176
+ # ==== Predefined splits for raw cityscapes images ===========
177
+ _RAW_CITYSCAPES_SPLITS = {
178
+ "cityscapes_fine_{task}_train": ("cityscapes/leftImg8bit/train/", "cityscapes/gtFine/train/"),
179
+ "cityscapes_fine_{task}_val": ("cityscapes/leftImg8bit/val/", "cityscapes/gtFine/val/"),
180
+ "cityscapes_fine_{task}_test": ("cityscapes/leftImg8bit/test/", "cityscapes/gtFine/test/"),
181
+ }
182
+
183
+
184
+ def register_all_cityscapes(root):
185
+ for key, (image_dir, gt_dir) in _RAW_CITYSCAPES_SPLITS.items():
186
+ meta = _get_builtin_metadata("cityscapes")
187
+ image_dir = os.path.join(root, image_dir)
188
+ gt_dir = os.path.join(root, gt_dir)
189
+
190
+ inst_key = key.format(task="instance_seg")
191
+ DatasetCatalog.register(
192
+ inst_key,
193
+ lambda x=image_dir, y=gt_dir: load_cityscapes_instances(
194
+ x, y, from_json=True, to_polygons=True
195
+ ),
196
+ )
197
+ MetadataCatalog.get(inst_key).set(
198
+ image_dir=image_dir, gt_dir=gt_dir, evaluator_type="cityscapes_instance", **meta
199
+ )
200
+
201
+ sem_key = key.format(task="sem_seg")
202
+ DatasetCatalog.register(
203
+ sem_key, lambda x=image_dir, y=gt_dir: load_cityscapes_semantic(x, y)
204
+ )
205
+ MetadataCatalog.get(sem_key).set(
206
+ image_dir=image_dir,
207
+ gt_dir=gt_dir,
208
+ evaluator_type="cityscapes_sem_seg",
209
+ ignore_label=255,
210
+ **meta,
211
+ )
212
+
213
+
214
+ # ==== Predefined splits for PASCAL VOC ===========
215
+ def register_all_pascal_voc(root):
216
+ SPLITS = [
217
+ ("voc_2007_trainval", "VOC2007", "trainval"),
218
+ ("voc_2007_train", "VOC2007", "train"),
219
+ ("voc_2007_val", "VOC2007", "val"),
220
+ ("voc_2007_test", "VOC2007", "test"),
221
+ ("voc_2012_trainval", "VOC2012", "trainval"),
222
+ ("voc_2012_train", "VOC2012", "train"),
223
+ ("voc_2012_val", "VOC2012", "val"),
224
+ ]
225
+ for name, dirname, split in SPLITS:
226
+ year = 2007 if "2007" in name else 2012
227
+ register_pascal_voc(name, os.path.join(root, dirname), split, year)
228
+ MetadataCatalog.get(name).evaluator_type = "pascal_voc"
229
+
230
+
231
+ def register_all_ade20k(root):
232
+ root = os.path.join(root, "ADEChallengeData2016")
233
+ for name, dirname in [("train", "training"), ("val", "validation")]:
234
+ image_dir = os.path.join(root, "images", dirname)
235
+ gt_dir = os.path.join(root, "annotations_detectron2", dirname)
236
+ name = f"ade20k_sem_seg_{name}"
237
+ DatasetCatalog.register(
238
+ name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext="png", image_ext="jpg")
239
+ )
240
+ MetadataCatalog.get(name).set(
241
+ stuff_classes=ADE20K_SEM_SEG_CATEGORIES[:],
242
+ image_root=image_dir,
243
+ sem_seg_root=gt_dir,
244
+ evaluator_type="sem_seg",
245
+ ignore_label=255,
246
+ )
247
+
248
+
249
+ # True for open source;
250
+ # Internally at fb, we register them elsewhere
251
+ if __name__.endswith(".builtin"):
252
+ # Assume pre-defined datasets live in `./datasets`.
253
+ _root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets"))
254
+ register_all_coco(_root)
255
+ register_all_lvis(_root)
256
+ register_all_cityscapes(_root)
257
+ register_all_cityscapes_panoptic(_root)
258
+ register_all_pascal_voc(_root)
259
+ register_all_ade20k(_root)
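As the docstring above notes, new datasets should be registered from user code rather than added to this file; for COCO-format data that typically looks like the sketch below (dataset name and paths are placeholders):

    from detectron2.data import MetadataCatalog
    from detectron2.data.datasets import register_coco_instances

    register_coco_instances(
        "my_coco_train",                    # dataset name
        {},                                 # extra metadata (optional)
        "/path/to/annotations/train.json",  # COCO-style json file
        "/path/to/images/train",            # image root
    )
    print(MetadataCatalog.get("my_coco_train").json_file)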
CatVTON/detectron2/data/datasets/builtin_meta.py ADDED
@@ -0,0 +1,350 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ """
5
+ Note:
6
+ For your custom dataset, there is no need to hard-code metadata anywhere in the code.
7
+ For example, for COCO-format dataset, metadata will be obtained automatically
8
+ when calling `load_coco_json`. For other datasets, metadata may also be obtained in other ways
9
+ during loading.
10
+
11
+ However, we hard-code metadata for a few common datasets here.
12
+ The only goal is to allow users who don't have these datasets to use pre-trained models.
13
+ Users don't have to download a COCO json (which contains metadata), in order to visualize a
14
+ COCO model (with correct class names and colors).
15
+ """
16
+
17
+
18
+ # All coco categories, together with their nice-looking visualization colors
19
+ # It's from https://github.com/cocodataset/panopticapi/blob/master/panoptic_coco_categories.json
20
+ COCO_CATEGORIES = [
21
+ {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
22
+ {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
23
+ {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
24
+ {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
25
+ {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
26
+ {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
27
+ {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
28
+ {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
29
+ {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
30
+ {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
31
+ {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
32
+ {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
33
+ {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
34
+ {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
35
+ {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
36
+ {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
37
+ {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
38
+ {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
39
+ {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
40
+ {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
41
+ {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
42
+ {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
43
+ {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
44
+ {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
45
+ {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
46
+ {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
47
+ {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
48
+ {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
49
+ {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
50
+ {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
51
+ {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
52
+ {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
53
+ {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
54
+ {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
55
+ {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
56
+ {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
57
+ {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
58
+ {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
59
+ {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
60
+ {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
61
+ {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
62
+ {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
63
+ {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
64
+ {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
65
+ {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
66
+ {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
67
+ {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
68
+ {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
69
+ {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
70
+ {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
71
+ {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
72
+ {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
73
+ {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
74
+ {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
75
+ {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
76
+ {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
77
+ {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
78
+ {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
79
+ {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
80
+ {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
81
+ {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
82
+ {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
83
+ {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
84
+ {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
85
+ {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
86
+ {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
87
+ {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
88
+ {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
89
+ {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
90
+ {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
91
+ {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
92
+ {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
93
+ {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
94
+ {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
95
+ {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
96
+ {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
97
+ {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
98
+ {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
99
+ {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
100
+ {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
101
+ {"color": [255, 255, 128], "isthing": 0, "id": 92, "name": "banner"},
102
+ {"color": [147, 211, 203], "isthing": 0, "id": 93, "name": "blanket"},
103
+ {"color": [150, 100, 100], "isthing": 0, "id": 95, "name": "bridge"},
104
+ {"color": [168, 171, 172], "isthing": 0, "id": 100, "name": "cardboard"},
105
+ {"color": [146, 112, 198], "isthing": 0, "id": 107, "name": "counter"},
106
+ {"color": [210, 170, 100], "isthing": 0, "id": 109, "name": "curtain"},
107
+ {"color": [92, 136, 89], "isthing": 0, "id": 112, "name": "door-stuff"},
108
+ {"color": [218, 88, 184], "isthing": 0, "id": 118, "name": "floor-wood"},
109
+ {"color": [241, 129, 0], "isthing": 0, "id": 119, "name": "flower"},
110
+ {"color": [217, 17, 255], "isthing": 0, "id": 122, "name": "fruit"},
111
+ {"color": [124, 74, 181], "isthing": 0, "id": 125, "name": "gravel"},
112
+ {"color": [70, 70, 70], "isthing": 0, "id": 128, "name": "house"},
113
+ {"color": [255, 228, 255], "isthing": 0, "id": 130, "name": "light"},
114
+ {"color": [154, 208, 0], "isthing": 0, "id": 133, "name": "mirror-stuff"},
115
+ {"color": [193, 0, 92], "isthing": 0, "id": 138, "name": "net"},
116
+ {"color": [76, 91, 113], "isthing": 0, "id": 141, "name": "pillow"},
117
+ {"color": [255, 180, 195], "isthing": 0, "id": 144, "name": "platform"},
118
+ {"color": [106, 154, 176], "isthing": 0, "id": 145, "name": "playingfield"},
119
+ {"color": [230, 150, 140], "isthing": 0, "id": 147, "name": "railroad"},
120
+ {"color": [60, 143, 255], "isthing": 0, "id": 148, "name": "river"},
121
+ {"color": [128, 64, 128], "isthing": 0, "id": 149, "name": "road"},
122
+ {"color": [92, 82, 55], "isthing": 0, "id": 151, "name": "roof"},
123
+ {"color": [254, 212, 124], "isthing": 0, "id": 154, "name": "sand"},
124
+ {"color": [73, 77, 174], "isthing": 0, "id": 155, "name": "sea"},
125
+ {"color": [255, 160, 98], "isthing": 0, "id": 156, "name": "shelf"},
126
+ {"color": [255, 255, 255], "isthing": 0, "id": 159, "name": "snow"},
127
+ {"color": [104, 84, 109], "isthing": 0, "id": 161, "name": "stairs"},
128
+ {"color": [169, 164, 131], "isthing": 0, "id": 166, "name": "tent"},
129
+ {"color": [225, 199, 255], "isthing": 0, "id": 168, "name": "towel"},
130
+ {"color": [137, 54, 74], "isthing": 0, "id": 171, "name": "wall-brick"},
131
+ {"color": [135, 158, 223], "isthing": 0, "id": 175, "name": "wall-stone"},
132
+ {"color": [7, 246, 231], "isthing": 0, "id": 176, "name": "wall-tile"},
133
+ {"color": [107, 255, 200], "isthing": 0, "id": 177, "name": "wall-wood"},
134
+ {"color": [58, 41, 149], "isthing": 0, "id": 178, "name": "water-other"},
135
+ {"color": [183, 121, 142], "isthing": 0, "id": 180, "name": "window-blind"},
136
+ {"color": [255, 73, 97], "isthing": 0, "id": 181, "name": "window-other"},
137
+ {"color": [107, 142, 35], "isthing": 0, "id": 184, "name": "tree-merged"},
138
+ {"color": [190, 153, 153], "isthing": 0, "id": 185, "name": "fence-merged"},
139
+ {"color": [146, 139, 141], "isthing": 0, "id": 186, "name": "ceiling-merged"},
140
+ {"color": [70, 130, 180], "isthing": 0, "id": 187, "name": "sky-other-merged"},
141
+ {"color": [134, 199, 156], "isthing": 0, "id": 188, "name": "cabinet-merged"},
142
+ {"color": [209, 226, 140], "isthing": 0, "id": 189, "name": "table-merged"},
143
+ {"color": [96, 36, 108], "isthing": 0, "id": 190, "name": "floor-other-merged"},
144
+ {"color": [96, 96, 96], "isthing": 0, "id": 191, "name": "pavement-merged"},
145
+ {"color": [64, 170, 64], "isthing": 0, "id": 192, "name": "mountain-merged"},
146
+ {"color": [152, 251, 152], "isthing": 0, "id": 193, "name": "grass-merged"},
147
+ {"color": [208, 229, 228], "isthing": 0, "id": 194, "name": "dirt-merged"},
148
+ {"color": [206, 186, 171], "isthing": 0, "id": 195, "name": "paper-merged"},
149
+ {"color": [152, 161, 64], "isthing": 0, "id": 196, "name": "food-other-merged"},
150
+ {"color": [116, 112, 0], "isthing": 0, "id": 197, "name": "building-other-merged"},
151
+ {"color": [0, 114, 143], "isthing": 0, "id": 198, "name": "rock-merged"},
152
+ {"color": [102, 102, 156], "isthing": 0, "id": 199, "name": "wall-other-merged"},
153
+ {"color": [250, 141, 255], "isthing": 0, "id": 200, "name": "rug-merged"},
154
+ ]
155
+
156
+ # fmt: off
157
+ COCO_PERSON_KEYPOINT_NAMES = (
158
+ "nose",
159
+ "left_eye", "right_eye",
160
+ "left_ear", "right_ear",
161
+ "left_shoulder", "right_shoulder",
162
+ "left_elbow", "right_elbow",
163
+ "left_wrist", "right_wrist",
164
+ "left_hip", "right_hip",
165
+ "left_knee", "right_knee",
166
+ "left_ankle", "right_ankle",
167
+ )
168
+ # fmt: on
169
+
170
+ # Pairs of keypoints that should be exchanged under horizontal flipping
171
+ COCO_PERSON_KEYPOINT_FLIP_MAP = (
172
+ ("left_eye", "right_eye"),
173
+ ("left_ear", "right_ear"),
174
+ ("left_shoulder", "right_shoulder"),
175
+ ("left_elbow", "right_elbow"),
176
+ ("left_wrist", "right_wrist"),
177
+ ("left_hip", "right_hip"),
178
+ ("left_knee", "right_knee"),
179
+ ("left_ankle", "right_ankle"),
180
+ )
181
+
182
+ # rules for pairs of keypoints to draw a line between, and the line color to use.
183
+ KEYPOINT_CONNECTION_RULES = [
184
+ # face
185
+ ("left_ear", "left_eye", (102, 204, 255)),
186
+ ("right_ear", "right_eye", (51, 153, 255)),
187
+ ("left_eye", "nose", (102, 0, 204)),
188
+ ("nose", "right_eye", (51, 102, 255)),
189
+ # upper-body
190
+ ("left_shoulder", "right_shoulder", (255, 128, 0)),
191
+ ("left_shoulder", "left_elbow", (153, 255, 204)),
192
+ ("right_shoulder", "right_elbow", (128, 229, 255)),
193
+ ("left_elbow", "left_wrist", (153, 255, 153)),
194
+ ("right_elbow", "right_wrist", (102, 255, 224)),
195
+ # lower-body
196
+ ("left_hip", "right_hip", (255, 102, 0)),
197
+ ("left_hip", "left_knee", (255, 255, 77)),
198
+ ("right_hip", "right_knee", (153, 255, 204)),
199
+ ("left_knee", "left_ankle", (191, 255, 128)),
200
+ ("right_knee", "right_ankle", (255, 195, 77)),
201
+ ]
202
+
203
+ # All Cityscapes categories, together with their nice-looking visualization colors
204
+ # It's from https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py # noqa
205
+ CITYSCAPES_CATEGORIES = [
206
+ {"color": (128, 64, 128), "isthing": 0, "id": 7, "trainId": 0, "name": "road"},
207
+ {"color": (244, 35, 232), "isthing": 0, "id": 8, "trainId": 1, "name": "sidewalk"},
208
+ {"color": (70, 70, 70), "isthing": 0, "id": 11, "trainId": 2, "name": "building"},
209
+ {"color": (102, 102, 156), "isthing": 0, "id": 12, "trainId": 3, "name": "wall"},
210
+ {"color": (190, 153, 153), "isthing": 0, "id": 13, "trainId": 4, "name": "fence"},
211
+ {"color": (153, 153, 153), "isthing": 0, "id": 17, "trainId": 5, "name": "pole"},
212
+ {"color": (250, 170, 30), "isthing": 0, "id": 19, "trainId": 6, "name": "traffic light"},
213
+ {"color": (220, 220, 0), "isthing": 0, "id": 20, "trainId": 7, "name": "traffic sign"},
214
+ {"color": (107, 142, 35), "isthing": 0, "id": 21, "trainId": 8, "name": "vegetation"},
215
+ {"color": (152, 251, 152), "isthing": 0, "id": 22, "trainId": 9, "name": "terrain"},
216
+ {"color": (70, 130, 180), "isthing": 0, "id": 23, "trainId": 10, "name": "sky"},
217
+ {"color": (220, 20, 60), "isthing": 1, "id": 24, "trainId": 11, "name": "person"},
218
+ {"color": (255, 0, 0), "isthing": 1, "id": 25, "trainId": 12, "name": "rider"},
219
+ {"color": (0, 0, 142), "isthing": 1, "id": 26, "trainId": 13, "name": "car"},
220
+ {"color": (0, 0, 70), "isthing": 1, "id": 27, "trainId": 14, "name": "truck"},
221
+ {"color": (0, 60, 100), "isthing": 1, "id": 28, "trainId": 15, "name": "bus"},
222
+ {"color": (0, 80, 100), "isthing": 1, "id": 31, "trainId": 16, "name": "train"},
223
+ {"color": (0, 0, 230), "isthing": 1, "id": 32, "trainId": 17, "name": "motorcycle"},
224
+ {"color": (119, 11, 32), "isthing": 1, "id": 33, "trainId": 18, "name": "bicycle"},
225
+ ]
226
+
227
+ # fmt: off
228
+ ADE20K_SEM_SEG_CATEGORIES = [
229
+ "wall", "building", "sky", "floor", "tree", "ceiling", "road, route", "bed", "window ", "grass", "cabinet", "sidewalk, pavement", "person", "earth, ground", "door", "table", "mountain, mount", "plant", "curtain", "chair", "car", "water", "painting, picture", "sofa", "shelf", "house", "sea", "mirror", "rug", "field", "armchair", "seat", "fence", "desk", "rock, stone", "wardrobe, closet, press", "lamp", "tub", "rail", "cushion", "base, pedestal, stand", "box", "column, pillar", "signboard, sign", "chest of drawers, chest, bureau, dresser", "counter", "sand", "sink", "skyscraper", "fireplace", "refrigerator, icebox", "grandstand, covered stand", "path", "stairs", "runway", "case, display case, showcase, vitrine", "pool table, billiard table, snooker table", "pillow", "screen door, screen", "stairway, staircase", "river", "bridge, span", "bookcase", "blind, screen", "coffee table", "toilet, can, commode, crapper, pot, potty, stool, throne", "flower", "book", "hill", "bench", "countertop", "stove", "palm, palm tree", "kitchen island", "computer", "swivel chair", "boat", "bar", "arcade machine", "hovel, hut, hutch, shack, shanty", "bus", "towel", "light", "truck", "tower", "chandelier", "awning, sunshade, sunblind", "street lamp", "booth", "tv", "plane", "dirt track", "clothes", "pole", "land, ground, soil", "bannister, banister, balustrade, balusters, handrail", "escalator, moving staircase, moving stairway", "ottoman, pouf, pouffe, puff, hassock", "bottle", "buffet, counter, sideboard", "poster, posting, placard, notice, bill, card", "stage", "van", "ship", "fountain", "conveyer belt, conveyor belt, conveyer, conveyor, transporter", "canopy", "washer, automatic washer, washing machine", "plaything, toy", "pool", "stool", "barrel, cask", "basket, handbasket", "falls", "tent", "bag", "minibike, motorbike", "cradle", "oven", "ball", "food, solid food", "step, stair", "tank, storage tank", "trade name", "microwave", "pot", "animal", "bicycle", "lake", "dishwasher", "screen", "blanket, cover", "sculpture", "hood, exhaust hood", "sconce", "vase", "traffic light", "tray", "trash can", "fan", "pier", "crt screen", "plate", "monitor", "bulletin board", "shower", "radiator", "glass, drinking glass", "clock", "flag", # noqa
230
+ ]
231
+ # After processed by `prepare_ade20k_sem_seg.py`, id 255 means ignore
232
+ # fmt: on
233
+
234
+
235
+ def _get_coco_instances_meta():
236
+ thing_ids = [k["id"] for k in COCO_CATEGORIES if k["isthing"] == 1]
237
+ thing_colors = [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 1]
238
+ assert len(thing_ids) == 80, len(thing_ids)
239
+ # Mapping from the non-contiguous COCO category id to an id in [0, 79]
240
+ thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)}
241
+ thing_classes = [k["name"] for k in COCO_CATEGORIES if k["isthing"] == 1]
242
+ ret = {
243
+ "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
244
+ "thing_classes": thing_classes,
245
+ "thing_colors": thing_colors,
246
+ }
247
+ return ret
248
+
249
+
250
+ def _get_coco_panoptic_separated_meta():
251
+ """
252
+ Returns metadata for "separated" version of the panoptic segmentation dataset.
253
+ """
254
+ stuff_ids = [k["id"] for k in COCO_CATEGORIES if k["isthing"] == 0]
255
+ assert len(stuff_ids) == 53, len(stuff_ids)
256
+
257
+ # For semantic segmentation, this mapping maps from contiguous stuff id
258
+ # (in [0, 53], used in models) to ids in the dataset (used for processing results)
259
+ # The id 0 is mapped to an extra category "thing".
260
+ stuff_dataset_id_to_contiguous_id = {k: i + 1 for i, k in enumerate(stuff_ids)}
261
+ # When converting COCO panoptic annotations to semantic annotations
262
+ # We label the "thing" category to 0
263
+ stuff_dataset_id_to_contiguous_id[0] = 0
264
+
265
+ # 54 names for COCO stuff categories (including "things")
266
+ stuff_classes = ["things"] + [
267
+ k["name"].replace("-other", "").replace("-merged", "")
268
+ for k in COCO_CATEGORIES
269
+ if k["isthing"] == 0
270
+ ]
271
+
272
+ # NOTE: I randomly picked a color for things
273
+ stuff_colors = [[82, 18, 128]] + [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 0]
274
+ ret = {
275
+ "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
276
+ "stuff_classes": stuff_classes,
277
+ "stuff_colors": stuff_colors,
278
+ }
279
+ ret.update(_get_coco_instances_meta())
280
+ return ret
281
+
282
+
283
+ def _get_builtin_metadata(dataset_name):
284
+ if dataset_name == "coco":
285
+ return _get_coco_instances_meta()
286
+ if dataset_name == "coco_panoptic_separated":
287
+ return _get_coco_panoptic_separated_meta()
288
+ elif dataset_name == "coco_panoptic_standard":
289
+ meta = {}
290
+ # The following metadata maps contiguous id from [0, #thing categories +
291
+ # #stuff categories) to their names and colors. We have two replicas of the
292
+ # same name and color under "thing_*" and "stuff_*" because the current
293
+ # visualization function in D2 handles thing and stuff classes differently
294
+ # due to some heuristic used in Panoptic FPN. We keep the same naming to
295
+ # enable reusing existing visualization functions.
296
+ thing_classes = [k["name"] for k in COCO_CATEGORIES]
297
+ thing_colors = [k["color"] for k in COCO_CATEGORIES]
298
+ stuff_classes = [k["name"] for k in COCO_CATEGORIES]
299
+ stuff_colors = [k["color"] for k in COCO_CATEGORIES]
300
+
301
+ meta["thing_classes"] = thing_classes
302
+ meta["thing_colors"] = thing_colors
303
+ meta["stuff_classes"] = stuff_classes
304
+ meta["stuff_colors"] = stuff_colors
305
+
306
+ # Convert category id for training:
307
+ # category id: like semantic segmentation, it is the class id for each
308
+ # pixel. Since there are some classes not used in evaluation, the category
309
+ # id is not always contiguous and thus we have two sets of category ids:
310
+ # - original category id: category id in the original dataset, mainly
311
+ # used for evaluation.
312
+ # - contiguous category id: [0, #classes), in order to train the linear
313
+ # softmax classifier.
314
+ thing_dataset_id_to_contiguous_id = {}
315
+ stuff_dataset_id_to_contiguous_id = {}
316
+
317
+ for i, cat in enumerate(COCO_CATEGORIES):
318
+ if cat["isthing"]:
319
+ thing_dataset_id_to_contiguous_id[cat["id"]] = i
320
+ else:
321
+ stuff_dataset_id_to_contiguous_id[cat["id"]] = i
322
+
323
+ meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id
324
+ meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id
325
+
326
+ return meta
327
+ elif dataset_name == "coco_person":
328
+ return {
329
+ "thing_classes": ["person"],
330
+ "keypoint_names": COCO_PERSON_KEYPOINT_NAMES,
331
+ "keypoint_flip_map": COCO_PERSON_KEYPOINT_FLIP_MAP,
332
+ "keypoint_connection_rules": KEYPOINT_CONNECTION_RULES,
333
+ }
334
+ elif dataset_name == "cityscapes":
335
+ # fmt: off
336
+ CITYSCAPES_THING_CLASSES = [
337
+ "person", "rider", "car", "truck",
338
+ "bus", "train", "motorcycle", "bicycle",
339
+ ]
340
+ CITYSCAPES_STUFF_CLASSES = [
341
+ "road", "sidewalk", "building", "wall", "fence", "pole", "traffic light",
342
+ "traffic sign", "vegetation", "terrain", "sky", "person", "rider", "car",
343
+ "truck", "bus", "train", "motorcycle", "bicycle",
344
+ ]
345
+ # fmt: on
346
+ return {
347
+ "thing_classes": CITYSCAPES_THING_CLASSES,
348
+ "stuff_classes": CITYSCAPES_STUFF_CLASSES,
349
+ }
350
+ raise KeyError("No built-in metadata for dataset {}".format(dataset_name))
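The id remapping built above can be sanity-checked directly: COCO category ids are not contiguous (id 12, for instance, is unused), so dataset id 13 ("stop sign") lands on contiguous id 11. A small check, assuming the upstream module path:

    from detectron2.data.datasets.builtin_meta import _get_builtin_metadata

    meta = _get_builtin_metadata("coco")
    assert len(meta["thing_classes"]) == 80
    assert meta["thing_dataset_id_to_contiguous_id"][1] == 0    # "person"
    assert meta["thing_dataset_id_to_contiguous_id"][13] == 11  # "stop sign"; id 12 is skipped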
CatVTON/detectron2/data/datasets/cityscapes.py ADDED
@@ -0,0 +1,334 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import functools
3
+ import json
4
+ import logging
5
+ import multiprocessing as mp
6
+ import numpy as np
7
+ import os
8
+ from itertools import chain
9
+ import pycocotools.mask as mask_util
10
+ from PIL import Image
11
+
12
+ from detectron2.structures import BoxMode
13
+ from detectron2.utils.comm import get_world_size
14
+ from detectron2.utils.file_io import PathManager
15
+ from detectron2.utils.logger import setup_logger
16
+
17
+ try:
18
+ import cv2 # noqa
19
+ except ImportError:
20
+ # OpenCV is an optional dependency at the moment
21
+ pass
22
+
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ def _get_cityscapes_files(image_dir, gt_dir):
28
+ files = []
29
+ # scan through the directory
30
+ cities = PathManager.ls(image_dir)
31
+ logger.info(f"{len(cities)} cities found in '{image_dir}'.")
32
+ for city in cities:
33
+ city_img_dir = os.path.join(image_dir, city)
34
+ city_gt_dir = os.path.join(gt_dir, city)
35
+ for basename in PathManager.ls(city_img_dir):
36
+ image_file = os.path.join(city_img_dir, basename)
37
+
38
+ suffix = "leftImg8bit.png"
39
+ assert basename.endswith(suffix), basename
40
+ basename = basename[: -len(suffix)]
41
+
42
+ instance_file = os.path.join(city_gt_dir, basename + "gtFine_instanceIds.png")
43
+ label_file = os.path.join(city_gt_dir, basename + "gtFine_labelIds.png")
44
+ json_file = os.path.join(city_gt_dir, basename + "gtFine_polygons.json")
45
+
46
+ files.append((image_file, instance_file, label_file, json_file))
47
+ assert len(files), "No images found in {}".format(image_dir)
48
+ for f in files[0]:
49
+ assert PathManager.isfile(f), f
50
+ return files
51
+
52
+
53
+ def load_cityscapes_instances(image_dir, gt_dir, from_json=True, to_polygons=True):
54
+ """
55
+ Args:
56
+ image_dir (str): path to the raw dataset. e.g., "~/cityscapes/leftImg8bit/train".
57
+ gt_dir (str): path to the raw annotations. e.g., "~/cityscapes/gtFine/train".
58
+ from_json (bool): whether to read annotations from the raw json file or the png files.
59
+ to_polygons (bool): whether to represent the segmentation as polygons
60
+ (COCO's format) instead of masks (cityscapes's format).
61
+
62
+ Returns:
63
+ list[dict]: a list of dicts in Detectron2 standard format. (See
64
+ `Using Custom Datasets </tutorials/datasets.html>`_ )
65
+ """
66
+ if from_json:
67
+ assert to_polygons, (
68
+ "Cityscapes's json annotations are in polygon format. "
69
+ "Converting to mask format is not supported now."
70
+ )
71
+ files = _get_cityscapes_files(image_dir, gt_dir)
72
+
73
+ logger.info("Preprocessing cityscapes annotations ...")
74
+ # This is still not fast: all workers will execute duplicate work and will
75
+ # take up to 10m on an 8-GPU server.
76
+ pool = mp.Pool(processes=max(mp.cpu_count() // get_world_size() // 2, 4))
77
+
78
+ ret = pool.map(
79
+ functools.partial(_cityscapes_files_to_dict, from_json=from_json, to_polygons=to_polygons),
80
+ files,
81
+ )
82
+ logger.info("Loaded {} images from {}".format(len(ret), image_dir))
83
+
84
+ # Map cityscape ids to contiguous ids
85
+ from cityscapesscripts.helpers.labels import labels
86
+
87
+ labels = [l for l in labels if l.hasInstances and not l.ignoreInEval]
88
+ dataset_id_to_contiguous_id = {l.id: idx for idx, l in enumerate(labels)}
89
+ for dict_per_image in ret:
90
+ for anno in dict_per_image["annotations"]:
91
+ anno["category_id"] = dataset_id_to_contiguous_id[anno["category_id"]]
92
+ return ret
93
+
94
+
95
+ def load_cityscapes_semantic(image_dir, gt_dir):
96
+ """
97
+ Args:
98
+ image_dir (str): path to the raw dataset. e.g., "~/cityscapes/leftImg8bit/train".
99
+ gt_dir (str): path to the raw annotations. e.g., "~/cityscapes/gtFine/train".
100
+
101
+ Returns:
102
+ list[dict]: a list of dicts, each with "file_name" and
103
+ "sem_seg_file_name".
104
+ """
105
+ ret = []
106
+ # gt_dir is small and contains many small files; it makes sense to fetch it to local storage first
107
+ gt_dir = PathManager.get_local_path(gt_dir)
108
+ for image_file, _, label_file, json_file in _get_cityscapes_files(image_dir, gt_dir):
109
+ label_file = label_file.replace("labelIds", "labelTrainIds")
110
+
111
+ with PathManager.open(json_file, "r") as f:
112
+ jsonobj = json.load(f)
113
+ ret.append(
114
+ {
115
+ "file_name": image_file,
116
+ "sem_seg_file_name": label_file,
117
+ "height": jsonobj["imgHeight"],
118
+ "width": jsonobj["imgWidth"],
119
+ }
120
+ )
121
+ assert len(ret), f"No images found in {image_dir}!"
122
+ assert PathManager.isfile(
123
+ ret[0]["sem_seg_file_name"]
124
+ ), "Please generate labelTrainIds.png with cityscapesscripts/preparation/createTrainIdLabelImgs.py" # noqa
125
+ return ret
126
+
127
+
128
+ def _cityscapes_files_to_dict(files, from_json, to_polygons):
129
+ """
130
+ Parse cityscapes annotation files to an instance segmentation dataset dict.
131
+
132
+ Args:
133
+ files (tuple): consists of (image_file, instance_id_file, label_id_file, json_file)
134
+ from_json (bool): whether to read annotations from the raw json file or the png files.
135
+ to_polygons (bool): whether to represent the segmentation as polygons
136
+ (COCO's format) instead of masks (cityscapes's format).
137
+
138
+ Returns:
139
+ A dict in Detectron2 Dataset format.
140
+ """
141
+ from cityscapesscripts.helpers.labels import id2label, name2label
142
+
143
+ image_file, instance_id_file, _, json_file = files
144
+
145
+ annos = []
146
+
147
+ if from_json:
148
+ from shapely.geometry import MultiPolygon, Polygon
149
+
150
+ with PathManager.open(json_file, "r") as f:
151
+ jsonobj = json.load(f)
152
+ ret = {
153
+ "file_name": image_file,
154
+ "image_id": os.path.basename(image_file),
155
+ "height": jsonobj["imgHeight"],
156
+ "width": jsonobj["imgWidth"],
157
+ }
158
+
159
+ # `polygons_union` contains the union of all valid polygons.
160
+ polygons_union = Polygon()
161
+
162
+ # CityscapesScripts draw the polygons in sequential order
163
+ # and each polygon *overwrites* existing ones. See
164
+ # (https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/preparation/json2instanceImg.py) # noqa
165
+ # We use reverse order, and each polygon *avoids* early ones.
166
+ # This will resolve the polygon overlaps in the same way as CityscapesScripts.
167
+ for obj in jsonobj["objects"][::-1]:
168
+ if "deleted" in obj: # cityscapes data format specific
169
+ continue
170
+ label_name = obj["label"]
171
+
172
+ try:
173
+ label = name2label[label_name]
174
+ except KeyError:
175
+ if label_name.endswith("group"): # crowd area
176
+ label = name2label[label_name[: -len("group")]]
177
+ else:
178
+ raise
179
+ if label.id < 0: # cityscapes data format
180
+ continue
181
+
182
+ # Cityscapes's raw annotations use integer coordinates
183
+ # Therefore +0.5 here
184
+ poly_coord = np.asarray(obj["polygon"], dtype="f4") + 0.5
185
+ # CityscapesScripts uses PIL.ImageDraw.polygon to rasterize
186
+ # polygons for evaluation. This function operates in integer space
187
+ # and draws each pixel whose center falls into the polygon.
188
+ # Therefore it draws a polygon which is 0.5 "fatter" in expectation.
189
+ # We therefore dilate the input polygon by 0.5 as our input.
190
+ poly = Polygon(poly_coord).buffer(0.5, resolution=4)
191
+
192
+ if not label.hasInstances or label.ignoreInEval:
193
+ # even if we won't store the polygon it still contributes to overlaps resolution
194
+ polygons_union = polygons_union.union(poly)
195
+ continue
196
+
197
+ # Take non-overlapping part of the polygon
198
+ poly_wo_overlaps = poly.difference(polygons_union)
199
+ if poly_wo_overlaps.is_empty:
200
+ continue
201
+ polygons_union = polygons_union.union(poly)
202
+
203
+ anno = {}
204
+ anno["iscrowd"] = label_name.endswith("group")
205
+ anno["category_id"] = label.id
206
+
207
+ if isinstance(poly_wo_overlaps, Polygon):
208
+ poly_list = [poly_wo_overlaps]
209
+ elif isinstance(poly_wo_overlaps, MultiPolygon):
210
+ poly_list = poly_wo_overlaps.geoms
211
+ else:
212
+ raise NotImplementedError("Unknown geometric structure {}".format(poly_wo_overlaps))
213
+
214
+ poly_coord = []
215
+ for poly_el in poly_list:
216
+ # COCO API can work only with exterior boundaries now, hence we store only them.
217
+ # TODO: store both exterior and interior boundaries once other parts of the
218
+ # codebase support holes in polygons.
219
+ poly_coord.append(list(chain(*poly_el.exterior.coords)))
220
+ anno["segmentation"] = poly_coord
221
+ (xmin, ymin, xmax, ymax) = poly_wo_overlaps.bounds
222
+
223
+ anno["bbox"] = (xmin, ymin, xmax, ymax)
224
+ anno["bbox_mode"] = BoxMode.XYXY_ABS
225
+
226
+ annos.append(anno)
227
+ else:
228
+ # See also the official annotation parsing scripts at
229
+ # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/instances2dict.py # noqa
230
+ with PathManager.open(instance_id_file, "rb") as f:
231
+ inst_image = np.asarray(Image.open(f), order="F")
232
+ # ids < 24 are stuff labels (filtering them first is about 5% faster)
233
+ flattened_ids = np.unique(inst_image[inst_image >= 24])
234
+
235
+ ret = {
236
+ "file_name": image_file,
237
+ "image_id": os.path.basename(image_file),
238
+ "height": inst_image.shape[0],
239
+ "width": inst_image.shape[1],
240
+ }
241
+
242
+ for instance_id in flattened_ids:
243
+ # For non-crowd annotations, instance_id // 1000 is the label_id
244
+ # Crowd annotations have <1000 instance ids
245
+ label_id = instance_id // 1000 if instance_id >= 1000 else instance_id
246
+ label = id2label[label_id]
247
+ if not label.hasInstances or label.ignoreInEval:
248
+ continue
249
+
250
+ anno = {}
251
+ anno["iscrowd"] = instance_id < 1000
252
+ anno["category_id"] = label.id
253
+
254
+ mask = np.asarray(inst_image == instance_id, dtype=np.uint8, order="F")
255
+
256
+ inds = np.nonzero(mask)
257
+ ymin, ymax = inds[0].min(), inds[0].max()
258
+ xmin, xmax = inds[1].min(), inds[1].max()
259
+ anno["bbox"] = (xmin, ymin, xmax, ymax)
260
+ if xmax <= xmin or ymax <= ymin:
261
+ continue
262
+ anno["bbox_mode"] = BoxMode.XYXY_ABS
263
+ if to_polygons:
264
+ # This conversion comes from D4809743 and D5171122,
265
+ # when Mask-RCNN was first developed.
266
+ contours = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[
267
+ -2
268
+ ]
269
+ polygons = [c.reshape(-1).tolist() for c in contours if len(c) >= 3]
270
+ # opencv can produce invalid polygons
271
+ if len(polygons) == 0:
272
+ continue
273
+ anno["segmentation"] = polygons
274
+ else:
275
+ anno["segmentation"] = mask_util.encode(mask[:, :, None])[0]
276
+ annos.append(anno)
277
+ ret["annotations"] = annos
278
+ return ret
279
+
280
+
281
+ def main() -> None:
282
+ global logger, labels
283
+ """
284
+ Test the cityscapes dataset loader.
285
+
286
+ Usage:
287
+ python -m detectron2.data.datasets.cityscapes \
288
+ cityscapes/leftImg8bit/train cityscapes/gtFine/train
289
+ """
290
+ import argparse
291
+
292
+ parser = argparse.ArgumentParser()
293
+ parser.add_argument("image_dir")
294
+ parser.add_argument("gt_dir")
295
+ parser.add_argument("--type", choices=["instance", "semantic"], default="instance")
296
+ args = parser.parse_args()
297
+ from cityscapesscripts.helpers.labels import labels
298
+ from detectron2.data.catalog import Metadata
299
+ from detectron2.utils.visualizer import Visualizer
300
+
301
+ logger = setup_logger(name=__name__)
302
+
303
+ dirname = "cityscapes-data-vis"
304
+ os.makedirs(dirname, exist_ok=True)
305
+
306
+ if args.type == "instance":
307
+ dicts = load_cityscapes_instances(
308
+ args.image_dir, args.gt_dir, from_json=True, to_polygons=True
309
+ )
310
+ logger.info("Done loading {} samples.".format(len(dicts)))
311
+
312
+ thing_classes = [k.name for k in labels if k.hasInstances and not k.ignoreInEval]
313
+ meta = Metadata().set(thing_classes=thing_classes)
314
+
315
+ else:
316
+ dicts = load_cityscapes_semantic(args.image_dir, args.gt_dir)
317
+ logger.info("Done loading {} samples.".format(len(dicts)))
318
+
319
+ stuff_classes = [k.name for k in labels if k.trainId != 255]
320
+ stuff_colors = [k.color for k in labels if k.trainId != 255]
321
+ meta = Metadata().set(stuff_classes=stuff_classes, stuff_colors=stuff_colors)
322
+
323
+ for d in dicts:
324
+ img = np.array(Image.open(PathManager.open(d["file_name"], "rb")))
325
+ visualizer = Visualizer(img, metadata=meta)
326
+ vis = visualizer.draw_dataset_dict(d)
327
+ # cv2.imshow("a", vis.get_image()[:, :, ::-1])
328
+ # cv2.waitKey()
329
+ fpath = os.path.join(dirname, os.path.basename(d["file_name"]))
330
+ vis.save(fpath)
331
+
332
+
333
+ if __name__ == "__main__":
334
+ main() # pragma: no cover
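
For reference, the instance-id convention that the non-JSON branch above decodes (ids >= 1000 encode label_id * 1000 + an instance index, ids below 1000 mark crowd regions, ids below 24 are stuff) can be exercised in isolation. A minimal, self-contained sketch; the helper name and toy array are illustrative only:

import numpy as np

def decode_cityscapes_instance_ids(inst_image):
    # Yield (label_id, is_crowd, mask) for every instance in an instanceIds.png array.
    for instance_id in np.unique(inst_image[inst_image >= 24]):
        label_id = instance_id // 1000 if instance_id >= 1000 else instance_id
        yield int(label_id), bool(instance_id < 1000), inst_image == instance_id

# A 2x2 toy image: one "rider" instance (25001), one "person" crowd region (24),
# and one stuff pixel (7, "road") that gets filtered out.
toy = np.array([[25001, 25001], [24, 7]], dtype=np.int32)
assert [(lbl, crowd) for lbl, crowd, _ in decode_cityscapes_instance_ids(toy)] == [(24, True), (25, False)]
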
CatVTON/detectron2/data/datasets/cityscapes_panoptic.py ADDED
@@ -0,0 +1,187 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import json
3
+ import logging
4
+ import os
5
+
6
+ from detectron2.data import DatasetCatalog, MetadataCatalog
7
+ from detectron2.data.datasets.builtin_meta import CITYSCAPES_CATEGORIES
8
+ from detectron2.utils.file_io import PathManager
9
+
10
+ """
11
+ This file contains functions to register the Cityscapes panoptic dataset to the DatasetCatalog.
12
+ """
13
+
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ def get_cityscapes_panoptic_files(image_dir, gt_dir, json_info):
19
+ files = []
20
+ # scan through the directory
21
+ cities = PathManager.ls(image_dir)
22
+ logger.info(f"{len(cities)} cities found in '{image_dir}'.")
23
+ image_dict = {}
24
+ for city in cities:
25
+ city_img_dir = os.path.join(image_dir, city)
26
+ for basename in PathManager.ls(city_img_dir):
27
+ image_file = os.path.join(city_img_dir, basename)
28
+
29
+ suffix = "_leftImg8bit.png"
30
+ assert basename.endswith(suffix), basename
31
+ basename = os.path.basename(basename)[: -len(suffix)]
32
+
33
+ image_dict[basename] = image_file
34
+
35
+ for ann in json_info["annotations"]:
36
+ image_file = image_dict.get(ann["image_id"], None)
37
+ assert image_file is not None, "No image {} found for annotation {}".format(
38
+ ann["image_id"], ann["file_name"]
39
+ )
40
+ label_file = os.path.join(gt_dir, ann["file_name"])
41
+ segments_info = ann["segments_info"]
42
+
43
+ files.append((image_file, label_file, segments_info))
44
+
45
+ assert len(files), "No images found in {}".format(image_dir)
46
+ assert PathManager.isfile(files[0][0]), files[0][0]
47
+ assert PathManager.isfile(files[0][1]), files[0][1]
48
+ return files
49
+
50
+
51
+ def load_cityscapes_panoptic(image_dir, gt_dir, gt_json, meta):
52
+ """
53
+ Args:
54
+ image_dir (str): path to the raw dataset. e.g., "~/cityscapes/leftImg8bit/train".
55
+ gt_dir (str): path to the raw annotations. e.g.,
56
+ "~/cityscapes/gtFine/cityscapes_panoptic_train".
57
+ gt_json (str): path to the json file. e.g.,
58
+ "~/cityscapes/gtFine/cityscapes_panoptic_train.json".
59
+ meta (dict): dictionary containing "thing_dataset_id_to_contiguous_id"
60
+ and "stuff_dataset_id_to_contiguous_id" to map category ids to
61
+ contiguous ids for training.
62
+
63
+ Returns:
64
+ list[dict]: a list of dicts in Detectron2 standard format. (See
65
+ `Using Custom Datasets </tutorials/datasets.html>`_ )
66
+ """
67
+
68
+ def _convert_category_id(segment_info, meta):
69
+ if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:
70
+ segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][
71
+ segment_info["category_id"]
72
+ ]
73
+ else:
74
+ segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][
75
+ segment_info["category_id"]
76
+ ]
77
+ return segment_info
78
+
79
+ assert os.path.exists(
80
+ gt_json
81
+ ), "Please run `python cityscapesscripts/preparation/createPanopticImgs.py` to generate label files." # noqa
82
+ with open(gt_json) as f:
83
+ json_info = json.load(f)
84
+ files = get_cityscapes_panoptic_files(image_dir, gt_dir, json_info)
85
+ ret = []
86
+ for image_file, label_file, segments_info in files:
87
+ sem_label_file = (
88
+ image_file.replace("leftImg8bit", "gtFine").split(".")[0] + "_labelTrainIds.png"
89
+ )
90
+ segments_info = [_convert_category_id(x, meta) for x in segments_info]
91
+ ret.append(
92
+ {
93
+ "file_name": image_file,
94
+ "image_id": "_".join(
95
+ os.path.splitext(os.path.basename(image_file))[0].split("_")[:3]
96
+ ),
97
+ "sem_seg_file_name": sem_label_file,
98
+ "pan_seg_file_name": label_file,
99
+ "segments_info": segments_info,
100
+ }
101
+ )
102
+ assert len(ret), f"No images found in {image_dir}!"
103
+ assert PathManager.isfile(
104
+ ret[0]["sem_seg_file_name"]
105
+ ), "Please generate labelTrainIds.png with cityscapesscripts/preparation/createTrainIdLabelImgs.py" # noqa
106
+ assert PathManager.isfile(
107
+ ret[0]["pan_seg_file_name"]
108
+ ), "Please generate panoptic annotation with python cityscapesscripts/preparation/createPanopticImgs.py" # noqa
109
+ return ret
110
+
111
+
112
+ _RAW_CITYSCAPES_PANOPTIC_SPLITS = {
113
+ "cityscapes_fine_panoptic_train": (
114
+ "cityscapes/leftImg8bit/train",
115
+ "cityscapes/gtFine/cityscapes_panoptic_train",
116
+ "cityscapes/gtFine/cityscapes_panoptic_train.json",
117
+ ),
118
+ "cityscapes_fine_panoptic_val": (
119
+ "cityscapes/leftImg8bit/val",
120
+ "cityscapes/gtFine/cityscapes_panoptic_val",
121
+ "cityscapes/gtFine/cityscapes_panoptic_val.json",
122
+ ),
123
+ # "cityscapes_fine_panoptic_test": not supported yet
124
+ }
125
+
126
+
127
+ def register_all_cityscapes_panoptic(root):
128
+ meta = {}
129
+ # The following metadata maps contiguous ids from [0, #thing categories +
130
+ # #stuff categories) to their names and colors. We keep two replicas of the
131
+ # same name and color under "thing_*" and "stuff_*" because the current
132
+ # visualization function in D2 handles thing and stuff classes differently
133
+ # due to some heuristic used in Panoptic FPN. We keep the same naming to
134
+ # enable reusing existing visualization functions.
135
+ thing_classes = [k["name"] for k in CITYSCAPES_CATEGORIES]
136
+ thing_colors = [k["color"] for k in CITYSCAPES_CATEGORIES]
137
+ stuff_classes = [k["name"] for k in CITYSCAPES_CATEGORIES]
138
+ stuff_colors = [k["color"] for k in CITYSCAPES_CATEGORIES]
139
+
140
+ meta["thing_classes"] = thing_classes
141
+ meta["thing_colors"] = thing_colors
142
+ meta["stuff_classes"] = stuff_classes
143
+ meta["stuff_colors"] = stuff_colors
144
+
145
+ # There are three types of ids in cityscapes panoptic segmentation:
146
+ # (1) category id: like semantic segmentation, it is the class id for each
147
+ # pixel. Since there are some classes not used in evaluation, the category
148
+ # id is not always contiguous and thus we have two sets of category ids:
149
+ # - original category id: category id in the original dataset, mainly
150
+ # used for evaluation.
151
+ # - contiguous category id: [0, #classes), in order to train the classifier
152
+ # (2) instance id: this id is used to differentiate different instances from
153
+ # the same category. For "stuff" classes, the instance id is always 0; for
154
+ # "thing" classes, the instance id starts from 1 and 0 is reserved for
155
+ # ignored instances (e.g. crowd annotation).
156
+ # (3) panoptic id: this is the compact id that encodes both category and
157
+ # instance id by: category_id * 1000 + instance_id.
158
+ thing_dataset_id_to_contiguous_id = {}
159
+ stuff_dataset_id_to_contiguous_id = {}
160
+
161
+ for k in CITYSCAPES_CATEGORIES:
162
+ if k["isthing"] == 1:
163
+ thing_dataset_id_to_contiguous_id[k["id"]] = k["trainId"]
164
+ else:
165
+ stuff_dataset_id_to_contiguous_id[k["id"]] = k["trainId"]
166
+
167
+ meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id
168
+ meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id
169
+
170
+ for key, (image_dir, gt_dir, gt_json) in _RAW_CITYSCAPES_PANOPTIC_SPLITS.items():
171
+ image_dir = os.path.join(root, image_dir)
172
+ gt_dir = os.path.join(root, gt_dir)
173
+ gt_json = os.path.join(root, gt_json)
174
+
175
+ DatasetCatalog.register(
176
+ key, lambda x=image_dir, y=gt_dir, z=gt_json: load_cityscapes_panoptic(x, y, z, meta)
177
+ )
178
+ MetadataCatalog.get(key).set(
179
+ panoptic_root=gt_dir,
180
+ image_root=image_dir,
181
+ panoptic_json=gt_json,
182
+ gt_dir=gt_dir.replace("cityscapes_panoptic_", ""),
183
+ evaluator_type="cityscapes_panoptic_seg",
184
+ ignore_label=255,
185
+ label_divisor=1000,
186
+ **meta,
187
+ )
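
The comment block in register_all_cityscapes_panoptic describes the panoptic id as category_id * label_divisor + instance_id, with the same label_divisor=1000 that is stored in the metadata above. A small sketch of that arithmetic (the helper name and example ids are illustrative only):

def split_panoptic_id(panoptic_id, label_divisor=1000):
    # Invert panoptic_id = category_id * label_divisor + instance_id.
    category_id, instance_id = divmod(panoptic_id, label_divisor)
    return category_id, instance_id

assert split_panoptic_id(26002) == (26, 2)   # third instance of category 26 ("car")
assert split_panoptic_id(23000) == (23, 0)   # a stuff segment; its instance id is always 0
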
CatVTON/detectron2/data/datasets/coco.py ADDED
@@ -0,0 +1,555 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import contextlib
3
+ import datetime
4
+ import io
5
+ import json
6
+ import logging
7
+ import numpy as np
8
+ import os
9
+ import shutil
10
+ import pycocotools.mask as mask_util
11
+ from fvcore.common.timer import Timer
12
+ from iopath.common.file_io import file_lock
13
+ from PIL import Image
14
+
15
+ from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes
16
+ from detectron2.utils.file_io import PathManager
17
+
18
+ from .. import DatasetCatalog, MetadataCatalog
19
+
20
+ """
21
+ This file contains functions to parse COCO-format annotations into dicts in "Detectron2 format".
22
+ """
23
+
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+ __all__ = [
28
+ "load_coco_json",
29
+ "load_sem_seg",
30
+ "convert_to_coco_json",
31
+ "register_coco_instances",
32
+ ]
33
+
34
+
35
+ def load_coco_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None):
36
+ """
37
+ Load a json file with COCO's instances annotation format.
38
+ Currently supports instance detection, instance segmentation,
39
+ and person keypoints annotations.
40
+
41
+ Args:
42
+ json_file (str): full path to the json file in COCO instances annotation format.
43
+ image_root (str or path-like): the directory where the images in this json file exist.
44
+ dataset_name (str or None): the name of the dataset (e.g., coco_2017_train).
45
+ When provided, this function will also do the following:
46
+
47
+ * Put "thing_classes" into the metadata associated with this dataset.
48
+ * Map the category ids into a contiguous range (needed by standard dataset format),
49
+ and add "thing_dataset_id_to_contiguous_id" to the metadata associated
50
+ with this dataset.
51
+
52
+ This option should usually be provided, unless users need to load
53
+ the original json content and apply more processing manually.
54
+ extra_annotation_keys (list[str]): list of per-annotation keys that should also be
55
+ loaded into the dataset dict (besides "iscrowd", "bbox", "keypoints",
56
+ "category_id", "segmentation"). The values for these keys will be returned as-is.
57
+ For example, the densepose annotations are loaded in this way.
58
+
59
+ Returns:
60
+ list[dict]: a list of dicts in Detectron2 standard dataset dicts format (See
61
+ `Using Custom Datasets </tutorials/datasets.html>`_ ) when `dataset_name` is not None.
62
+ If `dataset_name` is None, the returned `category_ids` may be
63
+ incontiguous and may not conform to the Detectron2 standard format.
64
+
65
+ Notes:
66
+ 1. This function does not read the image files.
67
+ The results do not have the "image" field.
68
+ """
69
+ from pycocotools.coco import COCO
70
+
71
+ timer = Timer()
72
+ json_file = PathManager.get_local_path(json_file)
73
+ with contextlib.redirect_stdout(io.StringIO()):
74
+ coco_api = COCO(json_file)
75
+ if timer.seconds() > 1:
76
+ logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
77
+
78
+ id_map = None
79
+ if dataset_name is not None:
80
+ meta = MetadataCatalog.get(dataset_name)
81
+ cat_ids = sorted(coco_api.getCatIds())
82
+ cats = coco_api.loadCats(cat_ids)
83
+ # The categories in a custom json file may not be sorted.
84
+ thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])]
85
+ meta.thing_classes = thing_classes
86
+
87
+ # In COCO, certain category ids are artificially removed,
88
+ # and by convention they are always ignored.
89
+ # We deal with COCO's id issue and translate
90
+ # the category ids to contiguous ids in [0, 80).
91
+
92
+ # It works by looking at the "categories" field in the json, therefore
93
+ # if users' own json also have incontiguous ids, we'll
94
+ # apply this mapping as well but print a warning.
95
+ if not (min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)):
96
+ if "coco" not in dataset_name:
97
+ logger.warning(
98
+ """
99
+ Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.
100
+ """
101
+ )
102
+ id_map = {v: i for i, v in enumerate(cat_ids)}
103
+ meta.thing_dataset_id_to_contiguous_id = id_map
104
+
105
+ # sort indices for reproducible results
106
+ img_ids = sorted(coco_api.imgs.keys())
107
+ # imgs is a list of dicts, each looks something like:
108
+ # {'license': 4,
109
+ # 'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
110
+ # 'file_name': 'COCO_val2014_000000001268.jpg',
111
+ # 'height': 427,
112
+ # 'width': 640,
113
+ # 'date_captured': '2013-11-17 05:57:24',
114
+ # 'id': 1268}
115
+ imgs = coco_api.loadImgs(img_ids)
116
+ # anns is a list[list[dict]], where each dict is an annotation
117
+ # record for an object. The inner list enumerates the objects in an image
118
+ # and the outer list enumerates over images. Example of anns[0]:
119
+ # [{'segmentation': [[192.81,
120
+ # 247.09,
121
+ # ...
122
+ # 219.03,
123
+ # 249.06]],
124
+ # 'area': 1035.749,
125
+ # 'iscrowd': 0,
126
+ # 'image_id': 1268,
127
+ # 'bbox': [192.81, 224.8, 74.73, 33.43],
128
+ # 'category_id': 16,
129
+ # 'id': 42986},
130
+ # ...]
131
+ anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]
132
+ total_num_valid_anns = sum([len(x) for x in anns])
133
+ total_num_anns = len(coco_api.anns)
134
+ if total_num_valid_anns < total_num_anns:
135
+ logger.warning(
136
+ f"{json_file} contains {total_num_anns} annotations, but only "
137
+ f"{total_num_valid_anns} of them match to images in the file."
138
+ )
139
+
140
+ if "minival" not in json_file:
141
+ # The popular valminusminival & minival annotations for COCO2014 contain this bug.
142
+ # However the ratio of buggy annotations there is tiny and does not affect accuracy.
143
+ # Therefore we explicitly white-list them.
144
+ ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
145
+ assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
146
+ json_file
147
+ )
148
+
149
+ imgs_anns = list(zip(imgs, anns))
150
+ logger.info("Loaded {} images in COCO format from {}".format(len(imgs_anns), json_file))
151
+
152
+ dataset_dicts = []
153
+
154
+ ann_keys = ["iscrowd", "bbox", "keypoints", "category_id"] + (extra_annotation_keys or [])
155
+
156
+ num_instances_without_valid_segmentation = 0
157
+
158
+ for img_dict, anno_dict_list in imgs_anns:
159
+ record = {}
160
+ record["file_name"] = os.path.join(image_root, img_dict["file_name"])
161
+ record["height"] = img_dict["height"]
162
+ record["width"] = img_dict["width"]
163
+ image_id = record["image_id"] = img_dict["id"]
164
+
165
+ objs = []
166
+ for anno in anno_dict_list:
167
+ # Check that the image_id in this annotation is the same as
168
+ # the image_id we're looking at.
169
+ # This fails only when the data parsing logic or the annotation file is buggy.
170
+
171
+ # The original COCO valminusminival2014 & minival2014 annotation files
172
+ # actually contains bugs that, together with certain ways of using COCO API,
173
+ # can trigger this assertion.
174
+ assert anno["image_id"] == image_id
175
+
176
+ assert anno.get("ignore", 0) == 0, '"ignore" in COCO json file is not supported.'
177
+
178
+ obj = {key: anno[key] for key in ann_keys if key in anno}
179
+ if "bbox" in obj and len(obj["bbox"]) == 0:
180
+ raise ValueError(
181
+ f"One annotation of image {image_id} contains empty 'bbox' value! "
182
+ "This json does not have valid COCO format."
183
+ )
184
+
185
+ segm = anno.get("segmentation", None)
186
+ if segm: # either list[list[float]] or dict(RLE)
187
+ if isinstance(segm, dict):
188
+ if isinstance(segm["counts"], list):
189
+ # convert to compressed RLE
190
+ segm = mask_util.frPyObjects(segm, *segm["size"])
191
+ else:
192
+ # filter out invalid polygons (< 3 points)
193
+ segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
194
+ if len(segm) == 0:
195
+ num_instances_without_valid_segmentation += 1
196
+ continue # ignore this instance
197
+ obj["segmentation"] = segm
198
+
199
+ keypts = anno.get("keypoints", None)
200
+ if keypts: # list[int]
201
+ for idx, v in enumerate(keypts):
202
+ if idx % 3 != 2:
203
+ # COCO's segmentation coordinates are floating points in [0, H or W],
204
+ # but keypoint coordinates are integers in [0, H-1 or W-1]
205
+ # Therefore we assume the coordinates are "pixel indices" and
206
+ # add 0.5 to convert to floating point coordinates.
207
+ keypts[idx] = v + 0.5
208
+ obj["keypoints"] = keypts
209
+
210
+ obj["bbox_mode"] = BoxMode.XYWH_ABS
211
+ if id_map:
212
+ annotation_category_id = obj["category_id"]
213
+ try:
214
+ obj["category_id"] = id_map[annotation_category_id]
215
+ except KeyError as e:
216
+ raise KeyError(
217
+ f"Encountered category_id={annotation_category_id} "
218
+ "but this id does not exist in 'categories' of the json file."
219
+ ) from e
220
+ objs.append(obj)
221
+ record["annotations"] = objs
222
+ dataset_dicts.append(record)
223
+
224
+ if num_instances_without_valid_segmentation > 0:
225
+ logger.warning(
226
+ "Filtered out {} instances without valid segmentation. ".format(
227
+ num_instances_without_valid_segmentation
228
+ )
229
+ + "There might be issues in your dataset generation process. Please "
230
+ "check https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html carefully"
231
+ )
232
+ return dataset_dicts
233
+
234
+
235
+ def load_sem_seg(gt_root, image_root, gt_ext="png", image_ext="jpg"):
236
+ """
237
+ Load semantic segmentation datasets. All files under "gt_root" with "gt_ext" extension are
238
+ treated as ground truth annotations and all files under "image_root" with "image_ext" extension
239
+ as input images. Ground truth and input images are matched using file paths relative to
240
+ "gt_root" and "image_root" respectively without taking into account file extensions.
241
+ This works for COCO as well as some other datasets.
242
+
243
+ Args:
244
+ gt_root (str): full path to ground truth semantic segmentation files. Semantic segmentation
245
+ annotations are stored as images with integer values in pixels that represent
246
+ corresponding semantic labels.
247
+ image_root (str): the directory where the input images are.
248
+ gt_ext (str): file extension for ground truth annotations.
249
+ image_ext (str): file extension for input images.
250
+
251
+ Returns:
252
+ list[dict]:
253
+ a list of dicts in detectron2 standard format without instance-level
254
+ annotation.
255
+
256
+ Notes:
257
+ 1. This function does not read the image and ground truth files.
258
+ The results do not have the "image" and "sem_seg" fields.
259
+ """
260
+
261
+ # We match input images with ground truth based on their relative filepaths (without file
262
+ # extensions) starting from 'image_root' and 'gt_root' respectively.
263
+ def file2id(folder_path, file_path):
264
+ # extract relative path starting from `folder_path`
265
+ image_id = os.path.normpath(os.path.relpath(file_path, start=folder_path))
266
+ # remove file extension
267
+ image_id = os.path.splitext(image_id)[0]
268
+ return image_id
269
+
270
+ input_files = sorted(
271
+ (os.path.join(image_root, f) for f in PathManager.ls(image_root) if f.endswith(image_ext)),
272
+ key=lambda file_path: file2id(image_root, file_path),
273
+ )
274
+ gt_files = sorted(
275
+ (os.path.join(gt_root, f) for f in PathManager.ls(gt_root) if f.endswith(gt_ext)),
276
+ key=lambda file_path: file2id(gt_root, file_path),
277
+ )
278
+
279
+ assert len(gt_files) > 0, "No annotations found in {}.".format(gt_root)
280
+
281
+ # Use the intersection, so that val2017_100 annotations can run smoothly with val2017 images
282
+ if len(input_files) != len(gt_files):
283
+ logger.warn(
284
+ "Directory {} and {} has {} and {} files, respectively.".format(
285
+ image_root, gt_root, len(input_files), len(gt_files)
286
+ )
287
+ )
288
+ input_basenames = [os.path.basename(f)[: -len(image_ext)] for f in input_files]
289
+ gt_basenames = [os.path.basename(f)[: -len(gt_ext)] for f in gt_files]
290
+ intersect = list(set(input_basenames) & set(gt_basenames))
291
+ # sort, otherwise each worker may obtain a list[dict] in different order
292
+ intersect = sorted(intersect)
293
+ logger.warn("Will use their intersection of {} files.".format(len(intersect)))
294
+ input_files = [os.path.join(image_root, f + image_ext) for f in intersect]
295
+ gt_files = [os.path.join(gt_root, f + gt_ext) for f in intersect]
296
+
297
+ logger.info(
298
+ "Loaded {} images with semantic segmentation from {}".format(len(input_files), image_root)
299
+ )
300
+
301
+ dataset_dicts = []
302
+ for img_path, gt_path in zip(input_files, gt_files):
303
+ record = {}
304
+ record["file_name"] = img_path
305
+ record["sem_seg_file_name"] = gt_path
306
+ dataset_dicts.append(record)
307
+
308
+ return dataset_dicts
309
+
310
+
311
+ def convert_to_coco_dict(dataset_name):
312
+ """
313
+ Convert an instance detection/segmentation or keypoint detection dataset
314
+ in detectron2's standard format into COCO json format.
315
+
316
+ Generic dataset description can be found here:
317
+ https://detectron2.readthedocs.io/tutorials/datasets.html#register-a-dataset
318
+
319
+ COCO data format description can be found here:
320
+ http://cocodataset.org/#format-data
321
+
322
+ Args:
323
+ dataset_name (str):
324
+ name of the source dataset
325
+ Must be registered in DatastCatalog and in detectron2's standard format.
326
+ Must have corresponding metadata "thing_classes"
327
+ Returns:
328
+ coco_dict: serializable dict in COCO json format
329
+ """
330
+
331
+ dataset_dicts = DatasetCatalog.get(dataset_name)
332
+ metadata = MetadataCatalog.get(dataset_name)
333
+
334
+ # unmap the category mapping ids for COCO
335
+ if hasattr(metadata, "thing_dataset_id_to_contiguous_id"):
336
+ reverse_id_mapping = {v: k for k, v in metadata.thing_dataset_id_to_contiguous_id.items()}
337
+ reverse_id_mapper = lambda contiguous_id: reverse_id_mapping[contiguous_id] # noqa
338
+ else:
339
+ reverse_id_mapper = lambda contiguous_id: contiguous_id # noqa
340
+
341
+ categories = [
342
+ {"id": reverse_id_mapper(id), "name": name}
343
+ for id, name in enumerate(metadata.thing_classes)
344
+ ]
345
+
346
+ logger.info("Converting dataset dicts into COCO format")
347
+ coco_images = []
348
+ coco_annotations = []
349
+
350
+ for image_id, image_dict in enumerate(dataset_dicts):
351
+ coco_image = {
352
+ "id": image_dict.get("image_id", image_id),
353
+ "width": int(image_dict["width"]),
354
+ "height": int(image_dict["height"]),
355
+ "file_name": str(image_dict["file_name"]),
356
+ }
357
+ coco_images.append(coco_image)
358
+
359
+ anns_per_image = image_dict.get("annotations", [])
360
+ for annotation in anns_per_image:
361
+ # create a new dict with only COCO fields
362
+ coco_annotation = {}
363
+
364
+ # COCO requirement: XYWH box format for axis-align and XYWHA for rotated
365
+ bbox = annotation["bbox"]
366
+ if isinstance(bbox, np.ndarray):
367
+ if bbox.ndim != 1:
368
+ raise ValueError(f"bbox has to be 1-dimensional. Got shape={bbox.shape}.")
369
+ bbox = bbox.tolist()
370
+ if len(bbox) not in [4, 5]:
371
+ raise ValueError(f"bbox has to have length 4 or 5. Got {bbox}.")
372
+ from_bbox_mode = annotation["bbox_mode"]
373
+ to_bbox_mode = BoxMode.XYWH_ABS if len(bbox) == 4 else BoxMode.XYWHA_ABS
374
+ bbox = BoxMode.convert(bbox, from_bbox_mode, to_bbox_mode)
375
+
376
+ # COCO requirement: instance area
377
+ if "segmentation" in annotation:
378
+ # Computing areas for instances by counting the pixels
379
+ segmentation = annotation["segmentation"]
380
+ # TODO: check segmentation type: RLE, BinaryMask or Polygon
381
+ if isinstance(segmentation, list):
382
+ polygons = PolygonMasks([segmentation])
383
+ area = polygons.area()[0].item()
384
+ elif isinstance(segmentation, dict): # RLE
385
+ area = mask_util.area(segmentation).item()
386
+ else:
387
+ raise TypeError(f"Unknown segmentation type {type(segmentation)}!")
388
+ else:
389
+ # Computing areas using bounding boxes
390
+ if to_bbox_mode == BoxMode.XYWH_ABS:
391
+ bbox_xy = BoxMode.convert(bbox, to_bbox_mode, BoxMode.XYXY_ABS)
392
+ area = Boxes([bbox_xy]).area()[0].item()
393
+ else:
394
+ area = RotatedBoxes([bbox]).area()[0].item()
395
+
396
+ if "keypoints" in annotation:
397
+ keypoints = annotation["keypoints"] # list[int]
398
+ for idx, v in enumerate(keypoints):
399
+ if idx % 3 != 2:
400
+ # COCO's segmentation coordinates are floating points in [0, H or W],
401
+ # but keypoint coordinates are integers in [0, H-1 or W-1]
402
+ # For COCO format consistency we subtract 0.5
403
+ # https://github.com/facebookresearch/detectron2/pull/175#issuecomment-551202163
404
+ keypoints[idx] = v - 0.5
405
+ if "num_keypoints" in annotation:
406
+ num_keypoints = annotation["num_keypoints"]
407
+ else:
408
+ num_keypoints = sum(kp > 0 for kp in keypoints[2::3])
409
+
410
+ # COCO requirement:
411
+ # linking annotations to images
412
+ # "id" field must start with 1
413
+ coco_annotation["id"] = len(coco_annotations) + 1
414
+ coco_annotation["image_id"] = coco_image["id"]
415
+ coco_annotation["bbox"] = [round(float(x), 3) for x in bbox]
416
+ coco_annotation["area"] = float(area)
417
+ coco_annotation["iscrowd"] = int(annotation.get("iscrowd", 0))
418
+ coco_annotation["category_id"] = int(reverse_id_mapper(annotation["category_id"]))
419
+
420
+ # Add optional fields
421
+ if "keypoints" in annotation:
422
+ coco_annotation["keypoints"] = keypoints
423
+ coco_annotation["num_keypoints"] = num_keypoints
424
+
425
+ if "segmentation" in annotation:
426
+ seg = coco_annotation["segmentation"] = annotation["segmentation"]
427
+ if isinstance(seg, dict): # RLE
428
+ counts = seg["counts"]
429
+ if not isinstance(counts, str):
430
+ # make it json-serializable
431
+ seg["counts"] = counts.decode("ascii")
432
+
433
+ coco_annotations.append(coco_annotation)
434
+
435
+ logger.info(
436
+ "Conversion finished, "
437
+ f"#images: {len(coco_images)}, #annotations: {len(coco_annotations)}"
438
+ )
439
+
440
+ info = {
441
+ "date_created": str(datetime.datetime.now()),
442
+ "description": "Automatically generated COCO json file for Detectron2.",
443
+ }
444
+ coco_dict = {
445
+ "info": info,
446
+ "images": coco_images,
447
+ "categories": categories,
448
+ "licenses": None,
449
+ }
450
+ if len(coco_annotations) > 0:
451
+ coco_dict["annotations"] = coco_annotations
452
+ return coco_dict
453
+
454
+
455
+ def convert_to_coco_json(dataset_name, output_file, allow_cached=True):
456
+ """
457
+ Converts dataset into COCO format and saves it to a json file.
458
+ dataset_name must be registered in DatasetCatalog and in detectron2's standard format.
459
+
460
+ Args:
461
+ dataset_name:
462
+ reference from the config file to the catalogs
463
+ must be registered in DatasetCatalog and in detectron2's standard format
464
+ output_file: path of json file that will be saved to
465
+ allow_cached: if json file is already present then skip conversion
466
+ """
467
+
468
+ # TODO: The dataset or the conversion script *may* change,
469
+ # a checksum would be useful for validating the cached data
470
+
471
+ PathManager.mkdirs(os.path.dirname(output_file))
472
+ with file_lock(output_file):
473
+ if PathManager.exists(output_file) and allow_cached:
474
+ logger.warning(
475
+ f"Using previously cached COCO format annotations at '{output_file}'. "
476
+ "You need to clear the cache file if your dataset has been modified."
477
+ )
478
+ else:
479
+ logger.info(f"Converting annotations of dataset '{dataset_name}' to COCO format ...")
480
+ coco_dict = convert_to_coco_dict(dataset_name)
481
+
482
+ logger.info(f"Caching COCO format annotations at '{output_file}' ...")
483
+ tmp_file = output_file + ".tmp"
484
+ with PathManager.open(tmp_file, "w") as f:
485
+ json.dump(coco_dict, f)
486
+ shutil.move(tmp_file, output_file)
487
+
488
+
489
+ def register_coco_instances(name, metadata, json_file, image_root):
490
+ """
491
+ Register a dataset in COCO's json annotation format for
492
+ instance detection, instance segmentation and keypoint detection.
493
+ (i.e., Type 1 and 2 in http://cocodataset.org/#format-data.
494
+ `instances*.json` and `person_keypoints*.json` in the dataset).
495
+
496
+ This is an example of how to register a new dataset.
497
+ You can do something similar to this function, to register new datasets.
498
+
499
+ Args:
500
+ name (str): the name that identifies a dataset, e.g. "coco_2014_train".
501
+ metadata (dict): extra metadata associated with this dataset. You can
502
+ leave it as an empty dict.
503
+ json_file (str): path to the json instance annotation file.
504
+ image_root (str or path-like): directory which contains all the images.
505
+ """
506
+ assert isinstance(name, str), name
507
+ assert isinstance(json_file, (str, os.PathLike)), json_file
508
+ assert isinstance(image_root, (str, os.PathLike)), image_root
509
+ # 1. register a function which returns dicts
510
+ DatasetCatalog.register(name, lambda: load_coco_json(json_file, image_root, name))
511
+
512
+ # 2. Optionally, add metadata about this dataset,
513
+ # since they might be useful in evaluation, visualization or logging
514
+ MetadataCatalog.get(name).set(
515
+ json_file=json_file, image_root=image_root, evaluator_type="coco", **metadata
516
+ )
517
+
518
+
519
+ def main() -> None:
520
+ global logger
521
+ """
522
+ Test the COCO json dataset loader.
523
+
524
+ Usage:
525
+ python -m detectron2.data.datasets.coco \
526
+ path/to/json path/to/image_root dataset_name
527
+
528
+ "dataset_name" can be "coco_2014_minival_100", or other
529
+ pre-registered ones
530
+ """
531
+ import sys
532
+
533
+ import detectron2.data.datasets # noqa # add pre-defined metadata
534
+ from detectron2.utils.logger import setup_logger
535
+ from detectron2.utils.visualizer import Visualizer
536
+
537
+ logger = setup_logger(name=__name__)
538
+ assert sys.argv[3] in DatasetCatalog.list()
539
+ meta = MetadataCatalog.get(sys.argv[3])
540
+
541
+ dicts = load_coco_json(sys.argv[1], sys.argv[2], sys.argv[3])
542
+ logger.info("Done loading {} samples.".format(len(dicts)))
543
+
544
+ dirname = "coco-data-vis"
545
+ os.makedirs(dirname, exist_ok=True)
546
+ for d in dicts:
547
+ img = np.array(Image.open(d["file_name"]))
548
+ visualizer = Visualizer(img, metadata=meta)
549
+ vis = visualizer.draw_dataset_dict(d)
550
+ fpath = os.path.join(dirname, os.path.basename(d["file_name"]))
551
+ vis.save(fpath)
552
+
553
+
554
+ if __name__ == "__main__":
555
+ main() # pragma: no cover
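
register_coco_instances is described above as the template for registering new datasets; a minimal usage sketch follows (the dataset name and paths are placeholders):

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets import register_coco_instances

# Register once at startup; the actual json parsing is lazy.
register_coco_instances("my_dataset_train", {}, "path/to/annotations.json", "path/to/images")

dataset_dicts = DatasetCatalog.get("my_dataset_train")  # calls load_coco_json under the hood
metadata = MetadataCatalog.get("my_dataset_train")      # now carries thing_classes, json_file, image_root, ...
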
CatVTON/detectron2/data/datasets/coco_panoptic.py ADDED
@@ -0,0 +1,228 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import copy
3
+ import json
4
+ import os
5
+
6
+ from detectron2.data import DatasetCatalog, MetadataCatalog
7
+ from detectron2.utils.file_io import PathManager
8
+
9
+ from .coco import load_coco_json, load_sem_seg
10
+
11
+ __all__ = ["register_coco_panoptic", "register_coco_panoptic_separated"]
12
+
13
+
14
+ def load_coco_panoptic_json(json_file, image_dir, gt_dir, meta):
15
+ """
16
+ Args:
17
+ image_dir (str): path to the raw dataset. e.g., "~/coco/train2017".
18
+ gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017".
19
+ json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json".
20
+
21
+ Returns:
22
+ list[dict]: a list of dicts in Detectron2 standard format. (See
23
+ `Using Custom Datasets </tutorials/datasets.html>`_ )
24
+ """
25
+
26
+ def _convert_category_id(segment_info, meta):
27
+ if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:
28
+ segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][
29
+ segment_info["category_id"]
30
+ ]
31
+ segment_info["isthing"] = True
32
+ else:
33
+ segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][
34
+ segment_info["category_id"]
35
+ ]
36
+ segment_info["isthing"] = False
37
+ return segment_info
38
+
39
+ with PathManager.open(json_file) as f:
40
+ json_info = json.load(f)
41
+
42
+ ret = []
43
+ for ann in json_info["annotations"]:
44
+ image_id = int(ann["image_id"])
45
+ # TODO: currently we assume image and label has the same filename but
46
+ # different extension, and images have extension ".jpg" for COCO. Need
47
+ # to make image extension a user-provided argument if we extend this
48
+ # function to support other COCO-like datasets.
49
+ image_file = os.path.join(image_dir, os.path.splitext(ann["file_name"])[0] + ".jpg")
50
+ label_file = os.path.join(gt_dir, ann["file_name"])
51
+ segments_info = [_convert_category_id(x, meta) for x in ann["segments_info"]]
52
+ ret.append(
53
+ {
54
+ "file_name": image_file,
55
+ "image_id": image_id,
56
+ "pan_seg_file_name": label_file,
57
+ "segments_info": segments_info,
58
+ }
59
+ )
60
+ assert len(ret), f"No images found in {image_dir}!"
61
+ assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"]
62
+ assert PathManager.isfile(ret[0]["pan_seg_file_name"]), ret[0]["pan_seg_file_name"]
63
+ return ret
64
+
65
+
66
+ def register_coco_panoptic(
67
+ name, metadata, image_root, panoptic_root, panoptic_json, instances_json=None
68
+ ):
69
+ """
70
+ Register a "standard" version of COCO panoptic segmentation dataset named `name`.
71
+ The dictionaries in this registered dataset follows detectron2's standard format.
72
+ Hence it's called "standard".
73
+
74
+ Args:
75
+ name (str): the name that identifies a dataset,
76
+ e.g. "coco_2017_train_panoptic"
77
+ metadata (dict): extra metadata associated with this dataset.
78
+ image_root (str): directory which contains all the images
79
+ panoptic_root (str): directory which contains panoptic annotation images in COCO format
80
+ panoptic_json (str): path to the json panoptic annotation file in COCO format
81
+ sem_seg_root (none): not used, to be consistent with
82
+ `register_coco_panoptic_separated`.
83
+ instances_json (str): path to the json instance annotation file
84
+ """
85
+ panoptic_name = name
86
+ DatasetCatalog.register(
87
+ panoptic_name,
88
+ lambda: load_coco_panoptic_json(panoptic_json, image_root, panoptic_root, metadata),
89
+ )
90
+ MetadataCatalog.get(panoptic_name).set(
91
+ panoptic_root=panoptic_root,
92
+ image_root=image_root,
93
+ panoptic_json=panoptic_json,
94
+ json_file=instances_json,
95
+ evaluator_type="coco_panoptic_seg",
96
+ ignore_label=255,
97
+ label_divisor=1000,
98
+ **metadata,
99
+ )
100
+
101
+
102
+ def register_coco_panoptic_separated(
103
+ name, metadata, image_root, panoptic_root, panoptic_json, sem_seg_root, instances_json
104
+ ):
105
+ """
106
+ Register a "separated" version of COCO panoptic segmentation dataset named `name`.
107
+ The annotations in this registered dataset will contain both instance annotations and
108
+ semantic annotations, each with its own contiguous ids. Hence it's called "separated".
109
+
110
+ It follows the setting used by the PanopticFPN paper:
111
+
112
+ 1. The instance annotations directly come from polygons in the COCO
113
+ instances annotation task, rather than from the masks in the COCO panoptic annotations.
114
+
115
+ The two format have small differences:
116
+ Polygons in the instance annotations may have overlaps.
117
+ The mask annotations are produced by labeling the overlapped polygons
118
+ with depth ordering.
119
+
120
+ 2. The semantic annotations are converted from panoptic annotations, where
121
+ all "things" are assigned a semantic id of 0.
122
+ All semantic categories will therefore have ids in contiguous
123
+ range [1, #stuff_categories].
124
+
125
+ This function will also register a pure semantic segmentation dataset
126
+ named ``name + '_stuffonly'``.
127
+
128
+ Args:
129
+ name (str): the name that identifies a dataset,
130
+ e.g. "coco_2017_train_panoptic"
131
+ metadata (dict): extra metadata associated with this dataset.
132
+ image_root (str): directory which contains all the images
133
+ panoptic_root (str): directory which contains panoptic annotation images
134
+ panoptic_json (str): path to the json panoptic annotation file
135
+ sem_seg_root (str): directory which contains all the ground truth segmentation annotations.
136
+ instances_json (str): path to the json instance annotation file
137
+ """
138
+ panoptic_name = name + "_separated"
139
+ DatasetCatalog.register(
140
+ panoptic_name,
141
+ lambda: merge_to_panoptic(
142
+ load_coco_json(instances_json, image_root, panoptic_name),
143
+ load_sem_seg(sem_seg_root, image_root),
144
+ ),
145
+ )
146
+ MetadataCatalog.get(panoptic_name).set(
147
+ panoptic_root=panoptic_root,
148
+ image_root=image_root,
149
+ panoptic_json=panoptic_json,
150
+ sem_seg_root=sem_seg_root,
151
+ json_file=instances_json, # TODO rename
152
+ evaluator_type="coco_panoptic_seg",
153
+ ignore_label=255,
154
+ **metadata,
155
+ )
156
+
157
+ semantic_name = name + "_stuffonly"
158
+ DatasetCatalog.register(semantic_name, lambda: load_sem_seg(sem_seg_root, image_root))
159
+ MetadataCatalog.get(semantic_name).set(
160
+ sem_seg_root=sem_seg_root,
161
+ image_root=image_root,
162
+ evaluator_type="sem_seg",
163
+ ignore_label=255,
164
+ **metadata,
165
+ )
166
+
167
+
168
+ def merge_to_panoptic(detection_dicts, sem_seg_dicts):
169
+ """
170
+ Create dataset dicts for panoptic segmentation, by
171
+ merging two dicts using "file_name" field to match their entries.
172
+
173
+ Args:
174
+ detection_dicts (list[dict]): lists of dicts for object detection or instance segmentation.
175
+ sem_seg_dicts (list[dict]): lists of dicts for semantic segmentation.
176
+
177
+ Returns:
178
+ list[dict] (one per input image): Each dict contains all (key, value) pairs from dicts in
179
+ both detection_dicts and sem_seg_dicts that correspond to the same image.
180
+ The function assumes that the same key in different dicts has the same value.
181
+ """
182
+ results = []
183
+ sem_seg_file_to_entry = {x["file_name"]: x for x in sem_seg_dicts}
184
+ assert len(sem_seg_file_to_entry) > 0
185
+
186
+ for det_dict in detection_dicts:
187
+ dic = copy.copy(det_dict)
188
+ dic.update(sem_seg_file_to_entry[dic["file_name"]])
189
+ results.append(dic)
190
+ return results
191
+
192
+
193
+ if __name__ == "__main__":
194
+ """
195
+ Test the COCO panoptic dataset loader.
196
+
197
+ Usage:
198
+ python -m detectron2.data.datasets.coco_panoptic \
199
+ path/to/image_root path/to/panoptic_root path/to/panoptic_json dataset_name 10
200
+
201
+ "dataset_name" can be "coco_2017_train_panoptic", or other
202
+ pre-registered ones
203
+ """
204
+ from detectron2.utils.logger import setup_logger
205
+ from detectron2.utils.visualizer import Visualizer
206
+ import detectron2.data.datasets # noqa # add pre-defined metadata
207
+ import sys
208
+ from PIL import Image
209
+ import numpy as np
210
+
211
+ logger = setup_logger(name=__name__)
212
+ assert sys.argv[4] in DatasetCatalog.list()
213
+ meta = MetadataCatalog.get(sys.argv[4])
214
+
215
+ dicts = load_coco_panoptic_json(sys.argv[3], sys.argv[1], sys.argv[2], meta.as_dict())
216
+ logger.info("Done loading {} samples.".format(len(dicts)))
217
+
218
+ dirname = "coco-data-vis"
219
+ os.makedirs(dirname, exist_ok=True)
220
+ num_imgs_to_vis = int(sys.argv[5])
221
+ for i, d in enumerate(dicts):
222
+ img = np.array(Image.open(d["file_name"]))
223
+ visualizer = Visualizer(img, metadata=meta)
224
+ vis = visualizer.draw_dataset_dict(d)
225
+ fpath = os.path.join(dirname, os.path.basename(d["file_name"]))
226
+ vis.save(fpath)
227
+ if i + 1 >= num_imgs_to_vis:
228
+ break
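
As a toy illustration of merge_to_panoptic above: entries are matched on "file_name" and the semantic-segmentation keys are folded into a copy of the detection dict (the file names below are placeholders):

from detectron2.data.datasets.coco_panoptic import merge_to_panoptic

det = [{"file_name": "a.jpg", "annotations": []}]
sem = [{"file_name": "a.jpg", "sem_seg_file_name": "a.png"}]
merged = merge_to_panoptic(det, sem)
assert merged[0] == {"file_name": "a.jpg", "annotations": [], "sem_seg_file_name": "a.png"}
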
CatVTON/detectron2/data/datasets/lvis.py ADDED
@@ -0,0 +1,250 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import logging
3
+ import os
4
+ from fvcore.common.timer import Timer
5
+
6
+ from detectron2.data import DatasetCatalog, MetadataCatalog
7
+ from detectron2.structures import BoxMode
8
+ from detectron2.utils.file_io import PathManager
9
+
10
+ from .builtin_meta import _get_coco_instances_meta
11
+ from .lvis_v0_5_categories import LVIS_CATEGORIES as LVIS_V0_5_CATEGORIES
12
+ from .lvis_v1_categories import LVIS_CATEGORIES as LVIS_V1_CATEGORIES
13
+ from .lvis_v1_category_image_count import LVIS_CATEGORY_IMAGE_COUNT as LVIS_V1_CATEGORY_IMAGE_COUNT
14
+
15
+ """
16
+ This file contains functions to parse LVIS-format annotations into dicts in the
17
+ "Detectron2 format".
18
+ """
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+ __all__ = ["load_lvis_json", "register_lvis_instances", "get_lvis_instances_meta"]
23
+
24
+
25
+ def register_lvis_instances(name, metadata, json_file, image_root):
26
+ """
27
+ Register a dataset in LVIS's json annotation format for instance detection and segmentation.
28
+
29
+ Args:
30
+ name (str): a name that identifies the dataset, e.g. "lvis_v0.5_train".
31
+ metadata (dict): extra metadata associated with this dataset. It can be an empty dict.
32
+ json_file (str): path to the json instance annotation file.
33
+ image_root (str or path-like): directory which contains all the images.
34
+ """
35
+ DatasetCatalog.register(name, lambda: load_lvis_json(json_file, image_root, name))
36
+ MetadataCatalog.get(name).set(
37
+ json_file=json_file, image_root=image_root, evaluator_type="lvis", **metadata
38
+ )
39
+
40
+
41
+ def load_lvis_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None):
42
+ """
43
+ Load a json file in LVIS's annotation format.
44
+
45
+ Args:
46
+ json_file (str): full path to the LVIS json annotation file.
47
+ image_root (str): the directory where the images in this json file exist.
48
+ dataset_name (str): the name of the dataset (e.g., "lvis_v0.5_train").
49
+ If provided, this function will put "thing_classes" into the metadata
50
+ associated with this dataset.
51
+ extra_annotation_keys (list[str]): list of per-annotation keys that should also be
52
+ loaded into the dataset dict (besides "bbox", "bbox_mode", "category_id",
53
+ "segmentation"). The values for these keys will be returned as-is.
54
+
55
+ Returns:
56
+ list[dict]: a list of dicts in Detectron2 standard format. (See
57
+ `Using Custom Datasets </tutorials/datasets.html>`_ )
58
+
59
+ Notes:
60
+ 1. This function does not read the image files.
61
+ The results do not have the "image" field.
62
+ """
63
+ from lvis import LVIS
64
+
65
+ json_file = PathManager.get_local_path(json_file)
66
+
67
+ timer = Timer()
68
+ lvis_api = LVIS(json_file)
69
+ if timer.seconds() > 1:
70
+ logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
71
+
72
+ if dataset_name is not None:
73
+ meta = get_lvis_instances_meta(dataset_name)
74
+ MetadataCatalog.get(dataset_name).set(**meta)
75
+
76
+ # sort indices for reproducible results
77
+ img_ids = sorted(lvis_api.imgs.keys())
78
+ # imgs is a list of dicts, each looks something like:
79
+ # {'license': 4,
80
+ # 'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
81
+ # 'file_name': 'COCO_val2014_000000001268.jpg',
82
+ # 'height': 427,
83
+ # 'width': 640,
84
+ # 'date_captured': '2013-11-17 05:57:24',
85
+ # 'id': 1268}
86
+ imgs = lvis_api.load_imgs(img_ids)
87
+ # anns is a list[list[dict]], where each dict is an annotation
88
+ # record for an object. The inner list enumerates the objects in an image
89
+ # and the outer list enumerates over images. Example of anns[0]:
90
+ # [{'segmentation': [[192.81,
91
+ # 247.09,
92
+ # ...
93
+ # 219.03,
94
+ # 249.06]],
95
+ # 'area': 1035.749,
96
+ # 'image_id': 1268,
97
+ # 'bbox': [192.81, 224.8, 74.73, 33.43],
98
+ # 'category_id': 16,
99
+ # 'id': 42986},
100
+ # ...]
101
+ anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]
102
+
103
+ # Sanity check that each annotation has a unique id
104
+ ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
105
+ assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique".format(
106
+ json_file
107
+ )
108
+
109
+ imgs_anns = list(zip(imgs, anns))
110
+
111
+ logger.info("Loaded {} images in the LVIS format from {}".format(len(imgs_anns), json_file))
112
+
113
+ if extra_annotation_keys:
114
+ logger.info(
115
+ "The following extra annotation keys will be loaded: {} ".format(extra_annotation_keys)
116
+ )
117
+ else:
118
+ extra_annotation_keys = []
119
+
120
+ def get_file_name(img_root, img_dict):
121
+ # Determine the path including the split folder ("train2017", "val2017", "test2017") from
122
+ # the coco_url field. Example:
123
+ # 'coco_url': 'http://images.cocodataset.org/train2017/000000155379.jpg'
124
+ split_folder, file_name = img_dict["coco_url"].split("/")[-2:]
125
+ return os.path.join(img_root + split_folder, file_name)
126
+
127
+ dataset_dicts = []
128
+
129
+ for img_dict, anno_dict_list in imgs_anns:
130
+ record = {}
131
+ record["file_name"] = get_file_name(image_root, img_dict)
132
+ record["height"] = img_dict["height"]
133
+ record["width"] = img_dict["width"]
134
+ record["not_exhaustive_category_ids"] = img_dict.get("not_exhaustive_category_ids", [])
135
+ record["neg_category_ids"] = img_dict.get("neg_category_ids", [])
136
+ image_id = record["image_id"] = img_dict["id"]
137
+
138
+ objs = []
139
+ for anno in anno_dict_list:
140
+ # Check that the image_id in this annotation is the same as
141
+ # the image_id we're looking at.
142
+ # This fails only when the data parsing logic or the annotation file is buggy.
143
+ assert anno["image_id"] == image_id
144
+ obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS}
145
+ # LVIS data loader can be used to load COCO dataset categories. In this case `meta`
146
+ # variable will have a field with COCO-specific category mapping.
147
+ if dataset_name is not None and "thing_dataset_id_to_contiguous_id" in meta:
148
+ obj["category_id"] = meta["thing_dataset_id_to_contiguous_id"][anno["category_id"]]
149
+ else:
150
+ obj["category_id"] = anno["category_id"] - 1 # Convert 1-indexed to 0-indexed
151
+ segm = anno["segmentation"] # list[list[float]]
152
+ # filter out invalid polygons (< 3 points)
153
+ valid_segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
154
+ assert len(segm) == len(
155
+ valid_segm
156
+ ), "Annotation contains an invalid polygon with < 3 points"
157
+ assert len(segm) > 0
158
+ obj["segmentation"] = segm
159
+ for extra_ann_key in extra_annotation_keys:
160
+ obj[extra_ann_key] = anno[extra_ann_key]
161
+ objs.append(obj)
162
+ record["annotations"] = objs
163
+ dataset_dicts.append(record)
164
+
165
+ return dataset_dicts
166
+
167
+
168
+ def get_lvis_instances_meta(dataset_name):
169
+ """
170
+ Load LVIS metadata.
171
+
172
+ Args:
173
+ dataset_name (str): LVIS dataset name without the split name (e.g., "lvis_v0.5").
174
+
175
+ Returns:
176
+ dict: LVIS metadata with keys: thing_classes
177
+ """
178
+ if "cocofied" in dataset_name:
179
+ return _get_coco_instances_meta()
180
+ if "v0.5" in dataset_name:
181
+ return _get_lvis_instances_meta_v0_5()
182
+ elif "v1" in dataset_name:
183
+ return _get_lvis_instances_meta_v1()
184
+ raise ValueError("No built-in metadata for dataset {}".format(dataset_name))
185
+
186
+
187
+ def _get_lvis_instances_meta_v0_5():
188
+ assert len(LVIS_V0_5_CATEGORIES) == 1230
189
+ cat_ids = [k["id"] for k in LVIS_V0_5_CATEGORIES]
190
+ assert min(cat_ids) == 1 and max(cat_ids) == len(
191
+ cat_ids
192
+ ), "Category ids are not in [1, #categories], as expected"
193
+ # Ensure that the category list is sorted by id
194
+ lvis_categories = sorted(LVIS_V0_5_CATEGORIES, key=lambda x: x["id"])
195
+ thing_classes = [k["synonyms"][0] for k in lvis_categories]
196
+ meta = {"thing_classes": thing_classes}
197
+ return meta
198
+
199
+
200
+ def _get_lvis_instances_meta_v1():
201
+ assert len(LVIS_V1_CATEGORIES) == 1203
202
+ cat_ids = [k["id"] for k in LVIS_V1_CATEGORIES]
203
+ assert min(cat_ids) == 1 and max(cat_ids) == len(
204
+ cat_ids
205
+ ), "Category ids are not in [1, #categories], as expected"
206
+ # Ensure that the category list is sorted by id
207
+ lvis_categories = sorted(LVIS_V1_CATEGORIES, key=lambda x: x["id"])
208
+ thing_classes = [k["synonyms"][0] for k in lvis_categories]
209
+ meta = {
210
+ "thing_classes": thing_classes,
211
+ "class_image_count": LVIS_V1_CATEGORY_IMAGE_COUNT,
212
+ }
213
+ return meta
214
+
215
+
216
+ def main() -> None:
217
+ global logger
218
+ """
219
+ Test the LVIS json dataset loader.
220
+
221
+ Usage:
222
+ python -m detectron2.data.datasets.lvis \
223
+ path/to/json path/to/image_root dataset_name vis_limit
224
+ """
225
+ import sys
226
+
227
+ import detectron2.data.datasets # noqa # add pre-defined metadata
228
+ import numpy as np
229
+ from detectron2.utils.logger import setup_logger
230
+ from detectron2.utils.visualizer import Visualizer
231
+ from PIL import Image
232
+
233
+ logger = setup_logger(name=__name__)
234
+ meta = MetadataCatalog.get(sys.argv[3])
235
+
236
+ dicts = load_lvis_json(sys.argv[1], sys.argv[2], sys.argv[3])
237
+ logger.info("Done loading {} samples.".format(len(dicts)))
238
+
239
+ dirname = "lvis-data-vis"
240
+ os.makedirs(dirname, exist_ok=True)
241
+ for d in dicts[: int(sys.argv[4])]:
242
+ img = np.array(Image.open(d["file_name"]))
243
+ visualizer = Visualizer(img, metadata=meta)
244
+ vis = visualizer.draw_dataset_dict(d)
245
+ fpath = os.path.join(dirname, os.path.basename(d["file_name"]))
246
+ vis.save(fpath)
247
+
248
+
249
+ if __name__ == "__main__":
250
+ main() # pragma: no cover
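
get_file_name above derives each image path from the LVIS record's coco_url field; because image_root is string-concatenated with the split folder, it is expected to end with a path separator. A standalone restatement of that convention (helper name and paths are illustrative, POSIX separators assumed):

import os

def file_name_from_coco_url(img_root, coco_url):
    # e.g. 'http://images.cocodataset.org/train2017/000000155379.jpg' -> '<root>train2017/000000155379.jpg'
    split_folder, file_name = coco_url.split("/")[-2:]
    return os.path.join(img_root + split_folder, file_name)

assert file_name_from_coco_url(
    "datasets/coco/", "http://images.cocodataset.org/train2017/000000155379.jpg"
) == "datasets/coco/train2017/000000155379.jpg"
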
CatVTON/detectron2/data/datasets/lvis_v0_5_categories.py ADDED
The diff for this file is too large to render. See raw diff
 
CatVTON/detectron2/data/datasets/lvis_v1_categories.py ADDED
The diff for this file is too large to render. See raw diff
 
CatVTON/detectron2/data/datasets/lvis_v1_category_image_count.py ADDED
@@ -0,0 +1,20 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Autogen with
3
+ # with open("lvis_v1_train.json", "r") as f:
4
+ # a = json.load(f)
5
+ # c = a["categories"]
6
+ # for x in c:
7
+ # del x["name"]
8
+ # del x["instance_count"]
9
+ # del x["def"]
10
+ # del x["synonyms"]
11
+ # del x["frequency"]
12
+ # del x["synset"]
13
+ # LVIS_CATEGORY_IMAGE_COUNT = repr(c) + " # noqa"
14
+ # with open("/tmp/lvis_category_image_count.py", "wt") as f:
15
+ # f.write(f"LVIS_CATEGORY_IMAGE_COUNT = {LVIS_CATEGORY_IMAGE_COUNT}")
16
+ # Then paste the contents of that file below
17
+
18
+ # fmt: off
19
+ LVIS_CATEGORY_IMAGE_COUNT = [{'id': 1, 'image_count': 64}, {'id': 2, 'image_count': 364}, {'id': 3, 'image_count': 1911}, {'id': 4, 'image_count': 149}, {'id': 5, 'image_count': 29}, {'id': 6, 'image_count': 26}, {'id': 7, 'image_count': 59}, {'id': 8, 'image_count': 22}, {'id': 9, 'image_count': 12}, {'id': 10, 'image_count': 28}, {'id': 11, 'image_count': 505}, {'id': 12, 'image_count': 1207}, {'id': 13, 'image_count': 4}, {'id': 14, 'image_count': 10}, {'id': 15, 'image_count': 500}, {'id': 16, 'image_count': 33}, {'id': 17, 'image_count': 3}, {'id': 18, 'image_count': 44}, {'id': 19, 'image_count': 561}, {'id': 20, 'image_count': 8}, {'id': 21, 'image_count': 9}, {'id': 22, 'image_count': 33}, {'id': 23, 'image_count': 1883}, {'id': 24, 'image_count': 98}, {'id': 25, 'image_count': 70}, {'id': 26, 'image_count': 46}, {'id': 27, 'image_count': 117}, {'id': 28, 'image_count': 41}, {'id': 29, 'image_count': 1395}, {'id': 30, 'image_count': 7}, {'id': 31, 'image_count': 1}, {'id': 32, 'image_count': 314}, {'id': 33, 'image_count': 31}, {'id': 34, 'image_count': 1905}, {'id': 35, 'image_count': 1859}, {'id': 36, 'image_count': 1623}, {'id': 37, 'image_count': 47}, {'id': 38, 'image_count': 3}, {'id': 39, 'image_count': 3}, {'id': 40, 'image_count': 1}, {'id': 41, 'image_count': 305}, {'id': 42, 'image_count': 6}, {'id': 43, 'image_count': 210}, {'id': 44, 'image_count': 36}, {'id': 45, 'image_count': 1787}, {'id': 46, 'image_count': 17}, {'id': 47, 'image_count': 51}, {'id': 48, 'image_count': 138}, {'id': 49, 'image_count': 3}, {'id': 50, 'image_count': 1470}, {'id': 51, 'image_count': 3}, {'id': 52, 'image_count': 2}, {'id': 53, 'image_count': 186}, {'id': 54, 'image_count': 76}, {'id': 55, 'image_count': 26}, {'id': 56, 'image_count': 303}, {'id': 57, 'image_count': 738}, {'id': 58, 'image_count': 1799}, {'id': 59, 'image_count': 1934}, {'id': 60, 'image_count': 1609}, {'id': 61, 'image_count': 1622}, {'id': 62, 'image_count': 41}, {'id': 63, 'image_count': 4}, {'id': 64, 'image_count': 11}, {'id': 65, 'image_count': 270}, {'id': 66, 'image_count': 349}, {'id': 67, 'image_count': 42}, {'id': 68, 'image_count': 823}, {'id': 69, 'image_count': 6}, {'id': 70, 'image_count': 48}, {'id': 71, 'image_count': 3}, {'id': 72, 'image_count': 42}, {'id': 73, 'image_count': 24}, {'id': 74, 'image_count': 16}, {'id': 75, 'image_count': 605}, {'id': 76, 'image_count': 646}, {'id': 77, 'image_count': 1765}, {'id': 78, 'image_count': 2}, {'id': 79, 'image_count': 125}, {'id': 80, 'image_count': 1420}, {'id': 81, 'image_count': 140}, {'id': 82, 'image_count': 4}, {'id': 83, 'image_count': 322}, {'id': 84, 'image_count': 60}, {'id': 85, 'image_count': 2}, {'id': 86, 'image_count': 231}, {'id': 87, 'image_count': 333}, {'id': 88, 'image_count': 1941}, {'id': 89, 'image_count': 367}, {'id': 90, 'image_count': 1922}, {'id': 91, 'image_count': 18}, {'id': 92, 'image_count': 81}, {'id': 93, 'image_count': 1}, {'id': 94, 'image_count': 1852}, {'id': 95, 'image_count': 430}, {'id': 96, 'image_count': 247}, {'id': 97, 'image_count': 94}, {'id': 98, 'image_count': 21}, {'id': 99, 'image_count': 1821}, {'id': 100, 'image_count': 16}, {'id': 101, 'image_count': 12}, {'id': 102, 'image_count': 25}, {'id': 103, 'image_count': 41}, {'id': 104, 'image_count': 244}, {'id': 105, 'image_count': 7}, {'id': 106, 'image_count': 1}, {'id': 107, 'image_count': 40}, {'id': 108, 'image_count': 40}, {'id': 109, 'image_count': 104}, {'id': 110, 'image_count': 1671}, {'id': 111, 'image_count': 49}, {'id': 112, 'image_count': 243}, 
{'id': 113, 'image_count': 2}, {'id': 114, 'image_count': 242}, {'id': 115, 'image_count': 271}, {'id': 116, 'image_count': 104}, {'id': 117, 'image_count': 8}, {'id': 118, 'image_count': 1758}, {'id': 119, 'image_count': 1}, {'id': 120, 'image_count': 48}, {'id': 121, 'image_count': 14}, {'id': 122, 'image_count': 40}, {'id': 123, 'image_count': 1}, {'id': 124, 'image_count': 37}, {'id': 125, 'image_count': 1510}, {'id': 126, 'image_count': 6}, {'id': 127, 'image_count': 1903}, {'id': 128, 'image_count': 70}, {'id': 129, 'image_count': 86}, {'id': 130, 'image_count': 7}, {'id': 131, 'image_count': 5}, {'id': 132, 'image_count': 1406}, {'id': 133, 'image_count': 1901}, {'id': 134, 'image_count': 15}, {'id': 135, 'image_count': 28}, {'id': 136, 'image_count': 6}, {'id': 137, 'image_count': 494}, {'id': 138, 'image_count': 234}, {'id': 139, 'image_count': 1922}, {'id': 140, 'image_count': 1}, {'id': 141, 'image_count': 35}, {'id': 142, 'image_count': 5}, {'id': 143, 'image_count': 1828}, {'id': 144, 'image_count': 8}, {'id': 145, 'image_count': 63}, {'id': 146, 'image_count': 1668}, {'id': 147, 'image_count': 4}, {'id': 148, 'image_count': 95}, {'id': 149, 'image_count': 17}, {'id': 150, 'image_count': 1567}, {'id': 151, 'image_count': 2}, {'id': 152, 'image_count': 103}, {'id': 153, 'image_count': 50}, {'id': 154, 'image_count': 1309}, {'id': 155, 'image_count': 6}, {'id': 156, 'image_count': 92}, {'id': 157, 'image_count': 19}, {'id': 158, 'image_count': 37}, {'id': 159, 'image_count': 4}, {'id': 160, 'image_count': 709}, {'id': 161, 'image_count': 9}, {'id': 162, 'image_count': 82}, {'id': 163, 'image_count': 15}, {'id': 164, 'image_count': 3}, {'id': 165, 'image_count': 61}, {'id': 166, 'image_count': 51}, {'id': 167, 'image_count': 5}, {'id': 168, 'image_count': 13}, {'id': 169, 'image_count': 642}, {'id': 170, 'image_count': 24}, {'id': 171, 'image_count': 255}, {'id': 172, 'image_count': 9}, {'id': 173, 'image_count': 1808}, {'id': 174, 'image_count': 31}, {'id': 175, 'image_count': 158}, {'id': 176, 'image_count': 80}, {'id': 177, 'image_count': 1884}, {'id': 178, 'image_count': 158}, {'id': 179, 'image_count': 2}, {'id': 180, 'image_count': 12}, {'id': 181, 'image_count': 1659}, {'id': 182, 'image_count': 7}, {'id': 183, 'image_count': 834}, {'id': 184, 'image_count': 57}, {'id': 185, 'image_count': 174}, {'id': 186, 'image_count': 95}, {'id': 187, 'image_count': 27}, {'id': 188, 'image_count': 22}, {'id': 189, 'image_count': 1391}, {'id': 190, 'image_count': 90}, {'id': 191, 'image_count': 40}, {'id': 192, 'image_count': 445}, {'id': 193, 'image_count': 21}, {'id': 194, 'image_count': 1132}, {'id': 195, 'image_count': 177}, {'id': 196, 'image_count': 4}, {'id': 197, 'image_count': 17}, {'id': 198, 'image_count': 84}, {'id': 199, 'image_count': 55}, {'id': 200, 'image_count': 30}, {'id': 201, 'image_count': 25}, {'id': 202, 'image_count': 2}, {'id': 203, 'image_count': 125}, {'id': 204, 'image_count': 1135}, {'id': 205, 'image_count': 19}, {'id': 206, 'image_count': 72}, {'id': 207, 'image_count': 1926}, {'id': 208, 'image_count': 159}, {'id': 209, 'image_count': 7}, {'id': 210, 'image_count': 1}, {'id': 211, 'image_count': 13}, {'id': 212, 'image_count': 35}, {'id': 213, 'image_count': 18}, {'id': 214, 'image_count': 8}, {'id': 215, 'image_count': 6}, {'id': 216, 'image_count': 35}, {'id': 217, 'image_count': 1222}, {'id': 218, 'image_count': 103}, {'id': 219, 'image_count': 28}, {'id': 220, 'image_count': 63}, {'id': 221, 'image_count': 28}, {'id': 222, 'image_count': 5}, {'id': 
223, 'image_count': 7}, {'id': 224, 'image_count': 14}, {'id': 225, 'image_count': 1918}, {'id': 226, 'image_count': 133}, {'id': 227, 'image_count': 16}, {'id': 228, 'image_count': 27}, {'id': 229, 'image_count': 110}, {'id': 230, 'image_count': 1895}, {'id': 231, 'image_count': 4}, {'id': 232, 'image_count': 1927}, {'id': 233, 'image_count': 8}, {'id': 234, 'image_count': 1}, {'id': 235, 'image_count': 263}, {'id': 236, 'image_count': 10}, {'id': 237, 'image_count': 2}, {'id': 238, 'image_count': 3}, {'id': 239, 'image_count': 87}, {'id': 240, 'image_count': 9}, {'id': 241, 'image_count': 71}, {'id': 242, 'image_count': 13}, {'id': 243, 'image_count': 18}, {'id': 244, 'image_count': 2}, {'id': 245, 'image_count': 5}, {'id': 246, 'image_count': 45}, {'id': 247, 'image_count': 1}, {'id': 248, 'image_count': 23}, {'id': 249, 'image_count': 32}, {'id': 250, 'image_count': 4}, {'id': 251, 'image_count': 1}, {'id': 252, 'image_count': 858}, {'id': 253, 'image_count': 661}, {'id': 254, 'image_count': 168}, {'id': 255, 'image_count': 210}, {'id': 256, 'image_count': 65}, {'id': 257, 'image_count': 4}, {'id': 258, 'image_count': 2}, {'id': 259, 'image_count': 159}, {'id': 260, 'image_count': 31}, {'id': 261, 'image_count': 811}, {'id': 262, 'image_count': 1}, {'id': 263, 'image_count': 42}, {'id': 264, 'image_count': 27}, {'id': 265, 'image_count': 2}, {'id': 266, 'image_count': 5}, {'id': 267, 'image_count': 95}, {'id': 268, 'image_count': 32}, {'id': 269, 'image_count': 1}, {'id': 270, 'image_count': 1}, {'id': 271, 'image_count': 1844}, {'id': 272, 'image_count': 897}, {'id': 273, 'image_count': 31}, {'id': 274, 'image_count': 23}, {'id': 275, 'image_count': 1}, {'id': 276, 'image_count': 202}, {'id': 277, 'image_count': 746}, {'id': 278, 'image_count': 44}, {'id': 279, 'image_count': 14}, {'id': 280, 'image_count': 26}, {'id': 281, 'image_count': 1}, {'id': 282, 'image_count': 2}, {'id': 283, 'image_count': 25}, {'id': 284, 'image_count': 238}, {'id': 285, 'image_count': 592}, {'id': 286, 'image_count': 26}, {'id': 287, 'image_count': 5}, {'id': 288, 'image_count': 42}, {'id': 289, 'image_count': 13}, {'id': 290, 'image_count': 46}, {'id': 291, 'image_count': 1}, {'id': 292, 'image_count': 8}, {'id': 293, 'image_count': 34}, {'id': 294, 'image_count': 5}, {'id': 295, 'image_count': 1}, {'id': 296, 'image_count': 1871}, {'id': 297, 'image_count': 717}, {'id': 298, 'image_count': 1010}, {'id': 299, 'image_count': 679}, {'id': 300, 'image_count': 3}, {'id': 301, 'image_count': 4}, {'id': 302, 'image_count': 1}, {'id': 303, 'image_count': 166}, {'id': 304, 'image_count': 2}, {'id': 305, 'image_count': 266}, {'id': 306, 'image_count': 101}, {'id': 307, 'image_count': 6}, {'id': 308, 'image_count': 14}, {'id': 309, 'image_count': 133}, {'id': 310, 'image_count': 2}, {'id': 311, 'image_count': 38}, {'id': 312, 'image_count': 95}, {'id': 313, 'image_count': 1}, {'id': 314, 'image_count': 12}, {'id': 315, 'image_count': 49}, {'id': 316, 'image_count': 5}, {'id': 317, 'image_count': 5}, {'id': 318, 'image_count': 16}, {'id': 319, 'image_count': 216}, {'id': 320, 'image_count': 12}, {'id': 321, 'image_count': 1}, {'id': 322, 'image_count': 54}, {'id': 323, 'image_count': 5}, {'id': 324, 'image_count': 245}, {'id': 325, 'image_count': 12}, {'id': 326, 'image_count': 7}, {'id': 327, 'image_count': 35}, {'id': 328, 'image_count': 36}, {'id': 329, 'image_count': 32}, {'id': 330, 'image_count': 1027}, {'id': 331, 'image_count': 10}, {'id': 332, 'image_count': 12}, {'id': 333, 'image_count': 1}, {'id': 334, 
'image_count': 67}, {'id': 335, 'image_count': 71}, {'id': 336, 'image_count': 30}, {'id': 337, 'image_count': 48}, {'id': 338, 'image_count': 249}, {'id': 339, 'image_count': 13}, {'id': 340, 'image_count': 29}, {'id': 341, 'image_count': 14}, {'id': 342, 'image_count': 236}, {'id': 343, 'image_count': 15}, {'id': 344, 'image_count': 1521}, {'id': 345, 'image_count': 25}, {'id': 346, 'image_count': 249}, {'id': 347, 'image_count': 139}, {'id': 348, 'image_count': 2}, {'id': 349, 'image_count': 2}, {'id': 350, 'image_count': 1890}, {'id': 351, 'image_count': 1240}, {'id': 352, 'image_count': 1}, {'id': 353, 'image_count': 9}, {'id': 354, 'image_count': 1}, {'id': 355, 'image_count': 3}, {'id': 356, 'image_count': 11}, {'id': 357, 'image_count': 4}, {'id': 358, 'image_count': 236}, {'id': 359, 'image_count': 44}, {'id': 360, 'image_count': 19}, {'id': 361, 'image_count': 1100}, {'id': 362, 'image_count': 7}, {'id': 363, 'image_count': 69}, {'id': 364, 'image_count': 2}, {'id': 365, 'image_count': 8}, {'id': 366, 'image_count': 5}, {'id': 367, 'image_count': 227}, {'id': 368, 'image_count': 6}, {'id': 369, 'image_count': 106}, {'id': 370, 'image_count': 81}, {'id': 371, 'image_count': 17}, {'id': 372, 'image_count': 134}, {'id': 373, 'image_count': 312}, {'id': 374, 'image_count': 8}, {'id': 375, 'image_count': 271}, {'id': 376, 'image_count': 2}, {'id': 377, 'image_count': 103}, {'id': 378, 'image_count': 1938}, {'id': 379, 'image_count': 574}, {'id': 380, 'image_count': 120}, {'id': 381, 'image_count': 2}, {'id': 382, 'image_count': 2}, {'id': 383, 'image_count': 13}, {'id': 384, 'image_count': 29}, {'id': 385, 'image_count': 1710}, {'id': 386, 'image_count': 66}, {'id': 387, 'image_count': 1008}, {'id': 388, 'image_count': 1}, {'id': 389, 'image_count': 3}, {'id': 390, 'image_count': 1942}, {'id': 391, 'image_count': 19}, {'id': 392, 'image_count': 1488}, {'id': 393, 'image_count': 46}, {'id': 394, 'image_count': 106}, {'id': 395, 'image_count': 115}, {'id': 396, 'image_count': 19}, {'id': 397, 'image_count': 2}, {'id': 398, 'image_count': 1}, {'id': 399, 'image_count': 28}, {'id': 400, 'image_count': 9}, {'id': 401, 'image_count': 192}, {'id': 402, 'image_count': 12}, {'id': 403, 'image_count': 21}, {'id': 404, 'image_count': 247}, {'id': 405, 'image_count': 6}, {'id': 406, 'image_count': 64}, {'id': 407, 'image_count': 7}, {'id': 408, 'image_count': 40}, {'id': 409, 'image_count': 542}, {'id': 410, 'image_count': 2}, {'id': 411, 'image_count': 1898}, {'id': 412, 'image_count': 36}, {'id': 413, 'image_count': 4}, {'id': 414, 'image_count': 1}, {'id': 415, 'image_count': 191}, {'id': 416, 'image_count': 6}, {'id': 417, 'image_count': 41}, {'id': 418, 'image_count': 39}, {'id': 419, 'image_count': 46}, {'id': 420, 'image_count': 1}, {'id': 421, 'image_count': 1451}, {'id': 422, 'image_count': 1878}, {'id': 423, 'image_count': 11}, {'id': 424, 'image_count': 82}, {'id': 425, 'image_count': 18}, {'id': 426, 'image_count': 1}, {'id': 427, 'image_count': 7}, {'id': 428, 'image_count': 3}, {'id': 429, 'image_count': 575}, {'id': 430, 'image_count': 1907}, {'id': 431, 'image_count': 8}, {'id': 432, 'image_count': 4}, {'id': 433, 'image_count': 32}, {'id': 434, 'image_count': 11}, {'id': 435, 'image_count': 4}, {'id': 436, 'image_count': 54}, {'id': 437, 'image_count': 202}, {'id': 438, 'image_count': 32}, {'id': 439, 'image_count': 3}, {'id': 440, 'image_count': 130}, {'id': 441, 'image_count': 119}, {'id': 442, 'image_count': 141}, {'id': 443, 'image_count': 29}, {'id': 444, 'image_count': 
525}, {'id': 445, 'image_count': 1323}, {'id': 446, 'image_count': 2}, {'id': 447, 'image_count': 113}, {'id': 448, 'image_count': 16}, {'id': 449, 'image_count': 7}, {'id': 450, 'image_count': 35}, {'id': 451, 'image_count': 1908}, {'id': 452, 'image_count': 353}, {'id': 453, 'image_count': 18}, {'id': 454, 'image_count': 14}, {'id': 455, 'image_count': 77}, {'id': 456, 'image_count': 8}, {'id': 457, 'image_count': 37}, {'id': 458, 'image_count': 1}, {'id': 459, 'image_count': 346}, {'id': 460, 'image_count': 19}, {'id': 461, 'image_count': 1779}, {'id': 462, 'image_count': 23}, {'id': 463, 'image_count': 25}, {'id': 464, 'image_count': 67}, {'id': 465, 'image_count': 19}, {'id': 466, 'image_count': 28}, {'id': 467, 'image_count': 4}, {'id': 468, 'image_count': 27}, {'id': 469, 'image_count': 1861}, {'id': 470, 'image_count': 11}, {'id': 471, 'image_count': 13}, {'id': 472, 'image_count': 13}, {'id': 473, 'image_count': 32}, {'id': 474, 'image_count': 1767}, {'id': 475, 'image_count': 42}, {'id': 476, 'image_count': 17}, {'id': 477, 'image_count': 128}, {'id': 478, 'image_count': 1}, {'id': 479, 'image_count': 9}, {'id': 480, 'image_count': 10}, {'id': 481, 'image_count': 4}, {'id': 482, 'image_count': 9}, {'id': 483, 'image_count': 18}, {'id': 484, 'image_count': 41}, {'id': 485, 'image_count': 28}, {'id': 486, 'image_count': 3}, {'id': 487, 'image_count': 65}, {'id': 488, 'image_count': 9}, {'id': 489, 'image_count': 23}, {'id': 490, 'image_count': 24}, {'id': 491, 'image_count': 1}, {'id': 492, 'image_count': 2}, {'id': 493, 'image_count': 59}, {'id': 494, 'image_count': 48}, {'id': 495, 'image_count': 17}, {'id': 496, 'image_count': 1877}, {'id': 497, 'image_count': 18}, {'id': 498, 'image_count': 1920}, {'id': 499, 'image_count': 50}, {'id': 500, 'image_count': 1890}, {'id': 501, 'image_count': 99}, {'id': 502, 'image_count': 1530}, {'id': 503, 'image_count': 3}, {'id': 504, 'image_count': 11}, {'id': 505, 'image_count': 19}, {'id': 506, 'image_count': 3}, {'id': 507, 'image_count': 63}, {'id': 508, 'image_count': 5}, {'id': 509, 'image_count': 6}, {'id': 510, 'image_count': 233}, {'id': 511, 'image_count': 54}, {'id': 512, 'image_count': 36}, {'id': 513, 'image_count': 10}, {'id': 514, 'image_count': 124}, {'id': 515, 'image_count': 101}, {'id': 516, 'image_count': 3}, {'id': 517, 'image_count': 363}, {'id': 518, 'image_count': 3}, {'id': 519, 'image_count': 30}, {'id': 520, 'image_count': 18}, {'id': 521, 'image_count': 199}, {'id': 522, 'image_count': 97}, {'id': 523, 'image_count': 32}, {'id': 524, 'image_count': 121}, {'id': 525, 'image_count': 16}, {'id': 526, 'image_count': 12}, {'id': 527, 'image_count': 2}, {'id': 528, 'image_count': 214}, {'id': 529, 'image_count': 48}, {'id': 530, 'image_count': 26}, {'id': 531, 'image_count': 13}, {'id': 532, 'image_count': 4}, {'id': 533, 'image_count': 11}, {'id': 534, 'image_count': 123}, {'id': 535, 'image_count': 7}, {'id': 536, 'image_count': 200}, {'id': 537, 'image_count': 91}, {'id': 538, 'image_count': 9}, {'id': 539, 'image_count': 72}, {'id': 540, 'image_count': 1886}, {'id': 541, 'image_count': 4}, {'id': 542, 'image_count': 1}, {'id': 543, 'image_count': 1}, {'id': 544, 'image_count': 1932}, {'id': 545, 'image_count': 4}, {'id': 546, 'image_count': 56}, {'id': 547, 'image_count': 854}, {'id': 548, 'image_count': 755}, {'id': 549, 'image_count': 1843}, {'id': 550, 'image_count': 96}, {'id': 551, 'image_count': 7}, {'id': 552, 'image_count': 74}, {'id': 553, 'image_count': 66}, {'id': 554, 'image_count': 57}, {'id': 555, 
'image_count': 44}, {'id': 556, 'image_count': 1905}, {'id': 557, 'image_count': 4}, {'id': 558, 'image_count': 90}, {'id': 559, 'image_count': 1635}, {'id': 560, 'image_count': 8}, {'id': 561, 'image_count': 5}, {'id': 562, 'image_count': 50}, {'id': 563, 'image_count': 545}, {'id': 564, 'image_count': 20}, {'id': 565, 'image_count': 193}, {'id': 566, 'image_count': 285}, {'id': 567, 'image_count': 3}, {'id': 568, 'image_count': 1}, {'id': 569, 'image_count': 1904}, {'id': 570, 'image_count': 294}, {'id': 571, 'image_count': 3}, {'id': 572, 'image_count': 5}, {'id': 573, 'image_count': 24}, {'id': 574, 'image_count': 2}, {'id': 575, 'image_count': 2}, {'id': 576, 'image_count': 16}, {'id': 577, 'image_count': 8}, {'id': 578, 'image_count': 154}, {'id': 579, 'image_count': 66}, {'id': 580, 'image_count': 1}, {'id': 581, 'image_count': 24}, {'id': 582, 'image_count': 1}, {'id': 583, 'image_count': 4}, {'id': 584, 'image_count': 75}, {'id': 585, 'image_count': 6}, {'id': 586, 'image_count': 126}, {'id': 587, 'image_count': 24}, {'id': 588, 'image_count': 22}, {'id': 589, 'image_count': 1872}, {'id': 590, 'image_count': 16}, {'id': 591, 'image_count': 423}, {'id': 592, 'image_count': 1927}, {'id': 593, 'image_count': 38}, {'id': 594, 'image_count': 3}, {'id': 595, 'image_count': 1945}, {'id': 596, 'image_count': 35}, {'id': 597, 'image_count': 1}, {'id': 598, 'image_count': 13}, {'id': 599, 'image_count': 9}, {'id': 600, 'image_count': 14}, {'id': 601, 'image_count': 37}, {'id': 602, 'image_count': 3}, {'id': 603, 'image_count': 4}, {'id': 604, 'image_count': 100}, {'id': 605, 'image_count': 195}, {'id': 606, 'image_count': 1}, {'id': 607, 'image_count': 12}, {'id': 608, 'image_count': 24}, {'id': 609, 'image_count': 489}, {'id': 610, 'image_count': 10}, {'id': 611, 'image_count': 1689}, {'id': 612, 'image_count': 42}, {'id': 613, 'image_count': 81}, {'id': 614, 'image_count': 894}, {'id': 615, 'image_count': 1868}, {'id': 616, 'image_count': 7}, {'id': 617, 'image_count': 1567}, {'id': 618, 'image_count': 10}, {'id': 619, 'image_count': 8}, {'id': 620, 'image_count': 7}, {'id': 621, 'image_count': 629}, {'id': 622, 'image_count': 89}, {'id': 623, 'image_count': 15}, {'id': 624, 'image_count': 134}, {'id': 625, 'image_count': 4}, {'id': 626, 'image_count': 1802}, {'id': 627, 'image_count': 595}, {'id': 628, 'image_count': 1210}, {'id': 629, 'image_count': 48}, {'id': 630, 'image_count': 418}, {'id': 631, 'image_count': 1846}, {'id': 632, 'image_count': 5}, {'id': 633, 'image_count': 221}, {'id': 634, 'image_count': 10}, {'id': 635, 'image_count': 7}, {'id': 636, 'image_count': 76}, {'id': 637, 'image_count': 22}, {'id': 638, 'image_count': 10}, {'id': 639, 'image_count': 341}, {'id': 640, 'image_count': 1}, {'id': 641, 'image_count': 705}, {'id': 642, 'image_count': 1900}, {'id': 643, 'image_count': 188}, {'id': 644, 'image_count': 227}, {'id': 645, 'image_count': 861}, {'id': 646, 'image_count': 6}, {'id': 647, 'image_count': 115}, {'id': 648, 'image_count': 5}, {'id': 649, 'image_count': 43}, {'id': 650, 'image_count': 14}, {'id': 651, 'image_count': 6}, {'id': 652, 'image_count': 15}, {'id': 653, 'image_count': 1167}, {'id': 654, 'image_count': 15}, {'id': 655, 'image_count': 994}, {'id': 656, 'image_count': 28}, {'id': 657, 'image_count': 2}, {'id': 658, 'image_count': 338}, {'id': 659, 'image_count': 334}, {'id': 660, 'image_count': 15}, {'id': 661, 'image_count': 102}, {'id': 662, 'image_count': 1}, {'id': 663, 'image_count': 8}, {'id': 664, 'image_count': 1}, {'id': 665, 'image_count': 
1}, {'id': 666, 'image_count': 28}, {'id': 667, 'image_count': 91}, {'id': 668, 'image_count': 260}, {'id': 669, 'image_count': 131}, {'id': 670, 'image_count': 128}, {'id': 671, 'image_count': 3}, {'id': 672, 'image_count': 10}, {'id': 673, 'image_count': 39}, {'id': 674, 'image_count': 2}, {'id': 675, 'image_count': 925}, {'id': 676, 'image_count': 354}, {'id': 677, 'image_count': 31}, {'id': 678, 'image_count': 10}, {'id': 679, 'image_count': 215}, {'id': 680, 'image_count': 71}, {'id': 681, 'image_count': 43}, {'id': 682, 'image_count': 28}, {'id': 683, 'image_count': 34}, {'id': 684, 'image_count': 16}, {'id': 685, 'image_count': 273}, {'id': 686, 'image_count': 2}, {'id': 687, 'image_count': 999}, {'id': 688, 'image_count': 4}, {'id': 689, 'image_count': 107}, {'id': 690, 'image_count': 2}, {'id': 691, 'image_count': 1}, {'id': 692, 'image_count': 454}, {'id': 693, 'image_count': 9}, {'id': 694, 'image_count': 1901}, {'id': 695, 'image_count': 61}, {'id': 696, 'image_count': 91}, {'id': 697, 'image_count': 46}, {'id': 698, 'image_count': 1402}, {'id': 699, 'image_count': 74}, {'id': 700, 'image_count': 421}, {'id': 701, 'image_count': 226}, {'id': 702, 'image_count': 10}, {'id': 703, 'image_count': 1720}, {'id': 704, 'image_count': 261}, {'id': 705, 'image_count': 1337}, {'id': 706, 'image_count': 293}, {'id': 707, 'image_count': 62}, {'id': 708, 'image_count': 814}, {'id': 709, 'image_count': 407}, {'id': 710, 'image_count': 6}, {'id': 711, 'image_count': 16}, {'id': 712, 'image_count': 7}, {'id': 713, 'image_count': 1791}, {'id': 714, 'image_count': 2}, {'id': 715, 'image_count': 1915}, {'id': 716, 'image_count': 1940}, {'id': 717, 'image_count': 13}, {'id': 718, 'image_count': 16}, {'id': 719, 'image_count': 448}, {'id': 720, 'image_count': 12}, {'id': 721, 'image_count': 18}, {'id': 722, 'image_count': 4}, {'id': 723, 'image_count': 71}, {'id': 724, 'image_count': 189}, {'id': 725, 'image_count': 74}, {'id': 726, 'image_count': 103}, {'id': 727, 'image_count': 3}, {'id': 728, 'image_count': 110}, {'id': 729, 'image_count': 5}, {'id': 730, 'image_count': 9}, {'id': 731, 'image_count': 15}, {'id': 732, 'image_count': 25}, {'id': 733, 'image_count': 7}, {'id': 734, 'image_count': 647}, {'id': 735, 'image_count': 824}, {'id': 736, 'image_count': 100}, {'id': 737, 'image_count': 47}, {'id': 738, 'image_count': 121}, {'id': 739, 'image_count': 731}, {'id': 740, 'image_count': 73}, {'id': 741, 'image_count': 49}, {'id': 742, 'image_count': 23}, {'id': 743, 'image_count': 4}, {'id': 744, 'image_count': 62}, {'id': 745, 'image_count': 118}, {'id': 746, 'image_count': 99}, {'id': 747, 'image_count': 40}, {'id': 748, 'image_count': 1036}, {'id': 749, 'image_count': 105}, {'id': 750, 'image_count': 21}, {'id': 751, 'image_count': 229}, {'id': 752, 'image_count': 7}, {'id': 753, 'image_count': 72}, {'id': 754, 'image_count': 9}, {'id': 755, 'image_count': 10}, {'id': 756, 'image_count': 328}, {'id': 757, 'image_count': 468}, {'id': 758, 'image_count': 1}, {'id': 759, 'image_count': 2}, {'id': 760, 'image_count': 24}, {'id': 761, 'image_count': 11}, {'id': 762, 'image_count': 72}, {'id': 763, 'image_count': 17}, {'id': 764, 'image_count': 10}, {'id': 765, 'image_count': 17}, {'id': 766, 'image_count': 489}, {'id': 767, 'image_count': 47}, {'id': 768, 'image_count': 93}, {'id': 769, 'image_count': 1}, {'id': 770, 'image_count': 12}, {'id': 771, 'image_count': 228}, {'id': 772, 'image_count': 5}, {'id': 773, 'image_count': 76}, {'id': 774, 'image_count': 71}, {'id': 775, 'image_count': 30}, 
{'id': 776, 'image_count': 109}, {'id': 777, 'image_count': 14}, {'id': 778, 'image_count': 1}, {'id': 779, 'image_count': 8}, {'id': 780, 'image_count': 26}, {'id': 781, 'image_count': 339}, {'id': 782, 'image_count': 153}, {'id': 783, 'image_count': 2}, {'id': 784, 'image_count': 3}, {'id': 785, 'image_count': 8}, {'id': 786, 'image_count': 47}, {'id': 787, 'image_count': 8}, {'id': 788, 'image_count': 6}, {'id': 789, 'image_count': 116}, {'id': 790, 'image_count': 69}, {'id': 791, 'image_count': 13}, {'id': 792, 'image_count': 6}, {'id': 793, 'image_count': 1928}, {'id': 794, 'image_count': 79}, {'id': 795, 'image_count': 14}, {'id': 796, 'image_count': 7}, {'id': 797, 'image_count': 20}, {'id': 798, 'image_count': 114}, {'id': 799, 'image_count': 221}, {'id': 800, 'image_count': 502}, {'id': 801, 'image_count': 62}, {'id': 802, 'image_count': 87}, {'id': 803, 'image_count': 4}, {'id': 804, 'image_count': 1912}, {'id': 805, 'image_count': 7}, {'id': 806, 'image_count': 186}, {'id': 807, 'image_count': 18}, {'id': 808, 'image_count': 4}, {'id': 809, 'image_count': 3}, {'id': 810, 'image_count': 7}, {'id': 811, 'image_count': 1413}, {'id': 812, 'image_count': 7}, {'id': 813, 'image_count': 12}, {'id': 814, 'image_count': 248}, {'id': 815, 'image_count': 4}, {'id': 816, 'image_count': 1881}, {'id': 817, 'image_count': 529}, {'id': 818, 'image_count': 1932}, {'id': 819, 'image_count': 50}, {'id': 820, 'image_count': 3}, {'id': 821, 'image_count': 28}, {'id': 822, 'image_count': 10}, {'id': 823, 'image_count': 5}, {'id': 824, 'image_count': 5}, {'id': 825, 'image_count': 18}, {'id': 826, 'image_count': 14}, {'id': 827, 'image_count': 1890}, {'id': 828, 'image_count': 660}, {'id': 829, 'image_count': 8}, {'id': 830, 'image_count': 25}, {'id': 831, 'image_count': 10}, {'id': 832, 'image_count': 218}, {'id': 833, 'image_count': 36}, {'id': 834, 'image_count': 16}, {'id': 835, 'image_count': 808}, {'id': 836, 'image_count': 479}, {'id': 837, 'image_count': 1404}, {'id': 838, 'image_count': 307}, {'id': 839, 'image_count': 57}, {'id': 840, 'image_count': 28}, {'id': 841, 'image_count': 80}, {'id': 842, 'image_count': 11}, {'id': 843, 'image_count': 92}, {'id': 844, 'image_count': 20}, {'id': 845, 'image_count': 194}, {'id': 846, 'image_count': 23}, {'id': 847, 'image_count': 52}, {'id': 848, 'image_count': 673}, {'id': 849, 'image_count': 2}, {'id': 850, 'image_count': 2}, {'id': 851, 'image_count': 1}, {'id': 852, 'image_count': 2}, {'id': 853, 'image_count': 8}, {'id': 854, 'image_count': 80}, {'id': 855, 'image_count': 3}, {'id': 856, 'image_count': 3}, {'id': 857, 'image_count': 15}, {'id': 858, 'image_count': 2}, {'id': 859, 'image_count': 10}, {'id': 860, 'image_count': 386}, {'id': 861, 'image_count': 65}, {'id': 862, 'image_count': 3}, {'id': 863, 'image_count': 35}, {'id': 864, 'image_count': 5}, {'id': 865, 'image_count': 180}, {'id': 866, 'image_count': 99}, {'id': 867, 'image_count': 49}, {'id': 868, 'image_count': 28}, {'id': 869, 'image_count': 1}, {'id': 870, 'image_count': 52}, {'id': 871, 'image_count': 36}, {'id': 872, 'image_count': 70}, {'id': 873, 'image_count': 6}, {'id': 874, 'image_count': 29}, {'id': 875, 'image_count': 24}, {'id': 876, 'image_count': 1115}, {'id': 877, 'image_count': 61}, {'id': 878, 'image_count': 18}, {'id': 879, 'image_count': 18}, {'id': 880, 'image_count': 665}, {'id': 881, 'image_count': 1096}, {'id': 882, 'image_count': 29}, {'id': 883, 'image_count': 8}, {'id': 884, 'image_count': 14}, {'id': 885, 'image_count': 1622}, {'id': 886, 'image_count': 
2}, {'id': 887, 'image_count': 3}, {'id': 888, 'image_count': 32}, {'id': 889, 'image_count': 55}, {'id': 890, 'image_count': 1}, {'id': 891, 'image_count': 10}, {'id': 892, 'image_count': 10}, {'id': 893, 'image_count': 47}, {'id': 894, 'image_count': 3}, {'id': 895, 'image_count': 29}, {'id': 896, 'image_count': 342}, {'id': 897, 'image_count': 25}, {'id': 898, 'image_count': 1469}, {'id': 899, 'image_count': 521}, {'id': 900, 'image_count': 347}, {'id': 901, 'image_count': 35}, {'id': 902, 'image_count': 7}, {'id': 903, 'image_count': 207}, {'id': 904, 'image_count': 108}, {'id': 905, 'image_count': 2}, {'id': 906, 'image_count': 34}, {'id': 907, 'image_count': 12}, {'id': 908, 'image_count': 10}, {'id': 909, 'image_count': 13}, {'id': 910, 'image_count': 361}, {'id': 911, 'image_count': 1023}, {'id': 912, 'image_count': 782}, {'id': 913, 'image_count': 2}, {'id': 914, 'image_count': 5}, {'id': 915, 'image_count': 247}, {'id': 916, 'image_count': 221}, {'id': 917, 'image_count': 4}, {'id': 918, 'image_count': 8}, {'id': 919, 'image_count': 158}, {'id': 920, 'image_count': 3}, {'id': 921, 'image_count': 752}, {'id': 922, 'image_count': 64}, {'id': 923, 'image_count': 707}, {'id': 924, 'image_count': 143}, {'id': 925, 'image_count': 1}, {'id': 926, 'image_count': 49}, {'id': 927, 'image_count': 126}, {'id': 928, 'image_count': 76}, {'id': 929, 'image_count': 11}, {'id': 930, 'image_count': 11}, {'id': 931, 'image_count': 4}, {'id': 932, 'image_count': 39}, {'id': 933, 'image_count': 11}, {'id': 934, 'image_count': 13}, {'id': 935, 'image_count': 91}, {'id': 936, 'image_count': 14}, {'id': 937, 'image_count': 5}, {'id': 938, 'image_count': 3}, {'id': 939, 'image_count': 10}, {'id': 940, 'image_count': 18}, {'id': 941, 'image_count': 9}, {'id': 942, 'image_count': 6}, {'id': 943, 'image_count': 951}, {'id': 944, 'image_count': 2}, {'id': 945, 'image_count': 1}, {'id': 946, 'image_count': 19}, {'id': 947, 'image_count': 1942}, {'id': 948, 'image_count': 1916}, {'id': 949, 'image_count': 139}, {'id': 950, 'image_count': 43}, {'id': 951, 'image_count': 1969}, {'id': 952, 'image_count': 5}, {'id': 953, 'image_count': 134}, {'id': 954, 'image_count': 74}, {'id': 955, 'image_count': 381}, {'id': 956, 'image_count': 1}, {'id': 957, 'image_count': 381}, {'id': 958, 'image_count': 6}, {'id': 959, 'image_count': 1826}, {'id': 960, 'image_count': 28}, {'id': 961, 'image_count': 1635}, {'id': 962, 'image_count': 1967}, {'id': 963, 'image_count': 16}, {'id': 964, 'image_count': 1926}, {'id': 965, 'image_count': 1789}, {'id': 966, 'image_count': 401}, {'id': 967, 'image_count': 1968}, {'id': 968, 'image_count': 1167}, {'id': 969, 'image_count': 1}, {'id': 970, 'image_count': 56}, {'id': 971, 'image_count': 17}, {'id': 972, 'image_count': 1}, {'id': 973, 'image_count': 58}, {'id': 974, 'image_count': 9}, {'id': 975, 'image_count': 8}, {'id': 976, 'image_count': 1124}, {'id': 977, 'image_count': 31}, {'id': 978, 'image_count': 16}, {'id': 979, 'image_count': 491}, {'id': 980, 'image_count': 432}, {'id': 981, 'image_count': 1945}, {'id': 982, 'image_count': 1899}, {'id': 983, 'image_count': 5}, {'id': 984, 'image_count': 28}, {'id': 985, 'image_count': 7}, {'id': 986, 'image_count': 146}, {'id': 987, 'image_count': 1}, {'id': 988, 'image_count': 25}, {'id': 989, 'image_count': 22}, {'id': 990, 'image_count': 1}, {'id': 991, 'image_count': 10}, {'id': 992, 'image_count': 9}, {'id': 993, 'image_count': 308}, {'id': 994, 'image_count': 4}, {'id': 995, 'image_count': 1969}, {'id': 996, 'image_count': 45}, 
{'id': 997, 'image_count': 12}, {'id': 998, 'image_count': 1}, {'id': 999, 'image_count': 85}, {'id': 1000, 'image_count': 1127}, {'id': 1001, 'image_count': 11}, {'id': 1002, 'image_count': 60}, {'id': 1003, 'image_count': 1}, {'id': 1004, 'image_count': 16}, {'id': 1005, 'image_count': 1}, {'id': 1006, 'image_count': 65}, {'id': 1007, 'image_count': 13}, {'id': 1008, 'image_count': 655}, {'id': 1009, 'image_count': 51}, {'id': 1010, 'image_count': 1}, {'id': 1011, 'image_count': 673}, {'id': 1012, 'image_count': 5}, {'id': 1013, 'image_count': 36}, {'id': 1014, 'image_count': 54}, {'id': 1015, 'image_count': 5}, {'id': 1016, 'image_count': 8}, {'id': 1017, 'image_count': 305}, {'id': 1018, 'image_count': 297}, {'id': 1019, 'image_count': 1053}, {'id': 1020, 'image_count': 223}, {'id': 1021, 'image_count': 1037}, {'id': 1022, 'image_count': 63}, {'id': 1023, 'image_count': 1881}, {'id': 1024, 'image_count': 507}, {'id': 1025, 'image_count': 333}, {'id': 1026, 'image_count': 1911}, {'id': 1027, 'image_count': 1765}, {'id': 1028, 'image_count': 1}, {'id': 1029, 'image_count': 5}, {'id': 1030, 'image_count': 1}, {'id': 1031, 'image_count': 9}, {'id': 1032, 'image_count': 2}, {'id': 1033, 'image_count': 151}, {'id': 1034, 'image_count': 82}, {'id': 1035, 'image_count': 1931}, {'id': 1036, 'image_count': 41}, {'id': 1037, 'image_count': 1895}, {'id': 1038, 'image_count': 24}, {'id': 1039, 'image_count': 22}, {'id': 1040, 'image_count': 35}, {'id': 1041, 'image_count': 69}, {'id': 1042, 'image_count': 962}, {'id': 1043, 'image_count': 588}, {'id': 1044, 'image_count': 21}, {'id': 1045, 'image_count': 825}, {'id': 1046, 'image_count': 52}, {'id': 1047, 'image_count': 5}, {'id': 1048, 'image_count': 5}, {'id': 1049, 'image_count': 5}, {'id': 1050, 'image_count': 1860}, {'id': 1051, 'image_count': 56}, {'id': 1052, 'image_count': 1582}, {'id': 1053, 'image_count': 7}, {'id': 1054, 'image_count': 2}, {'id': 1055, 'image_count': 1562}, {'id': 1056, 'image_count': 1885}, {'id': 1057, 'image_count': 1}, {'id': 1058, 'image_count': 5}, {'id': 1059, 'image_count': 137}, {'id': 1060, 'image_count': 1094}, {'id': 1061, 'image_count': 134}, {'id': 1062, 'image_count': 29}, {'id': 1063, 'image_count': 22}, {'id': 1064, 'image_count': 522}, {'id': 1065, 'image_count': 50}, {'id': 1066, 'image_count': 68}, {'id': 1067, 'image_count': 16}, {'id': 1068, 'image_count': 40}, {'id': 1069, 'image_count': 35}, {'id': 1070, 'image_count': 135}, {'id': 1071, 'image_count': 1413}, {'id': 1072, 'image_count': 772}, {'id': 1073, 'image_count': 50}, {'id': 1074, 'image_count': 1015}, {'id': 1075, 'image_count': 1}, {'id': 1076, 'image_count': 65}, {'id': 1077, 'image_count': 1900}, {'id': 1078, 'image_count': 1302}, {'id': 1079, 'image_count': 1977}, {'id': 1080, 'image_count': 2}, {'id': 1081, 'image_count': 29}, {'id': 1082, 'image_count': 36}, {'id': 1083, 'image_count': 138}, {'id': 1084, 'image_count': 4}, {'id': 1085, 'image_count': 67}, {'id': 1086, 'image_count': 26}, {'id': 1087, 'image_count': 25}, {'id': 1088, 'image_count': 33}, {'id': 1089, 'image_count': 37}, {'id': 1090, 'image_count': 50}, {'id': 1091, 'image_count': 270}, {'id': 1092, 'image_count': 12}, {'id': 1093, 'image_count': 316}, {'id': 1094, 'image_count': 41}, {'id': 1095, 'image_count': 224}, {'id': 1096, 'image_count': 105}, {'id': 1097, 'image_count': 1925}, {'id': 1098, 'image_count': 1021}, {'id': 1099, 'image_count': 1213}, {'id': 1100, 'image_count': 172}, {'id': 1101, 'image_count': 28}, {'id': 1102, 'image_count': 745}, {'id': 1103, 
'image_count': 187}, {'id': 1104, 'image_count': 147}, {'id': 1105, 'image_count': 136}, {'id': 1106, 'image_count': 34}, {'id': 1107, 'image_count': 41}, {'id': 1108, 'image_count': 636}, {'id': 1109, 'image_count': 570}, {'id': 1110, 'image_count': 1149}, {'id': 1111, 'image_count': 61}, {'id': 1112, 'image_count': 1890}, {'id': 1113, 'image_count': 18}, {'id': 1114, 'image_count': 143}, {'id': 1115, 'image_count': 1517}, {'id': 1116, 'image_count': 7}, {'id': 1117, 'image_count': 943}, {'id': 1118, 'image_count': 6}, {'id': 1119, 'image_count': 1}, {'id': 1120, 'image_count': 11}, {'id': 1121, 'image_count': 101}, {'id': 1122, 'image_count': 1909}, {'id': 1123, 'image_count': 800}, {'id': 1124, 'image_count': 1}, {'id': 1125, 'image_count': 44}, {'id': 1126, 'image_count': 3}, {'id': 1127, 'image_count': 44}, {'id': 1128, 'image_count': 31}, {'id': 1129, 'image_count': 7}, {'id': 1130, 'image_count': 20}, {'id': 1131, 'image_count': 11}, {'id': 1132, 'image_count': 13}, {'id': 1133, 'image_count': 1924}, {'id': 1134, 'image_count': 113}, {'id': 1135, 'image_count': 2}, {'id': 1136, 'image_count': 139}, {'id': 1137, 'image_count': 12}, {'id': 1138, 'image_count': 37}, {'id': 1139, 'image_count': 1866}, {'id': 1140, 'image_count': 47}, {'id': 1141, 'image_count': 1468}, {'id': 1142, 'image_count': 729}, {'id': 1143, 'image_count': 24}, {'id': 1144, 'image_count': 1}, {'id': 1145, 'image_count': 10}, {'id': 1146, 'image_count': 3}, {'id': 1147, 'image_count': 14}, {'id': 1148, 'image_count': 4}, {'id': 1149, 'image_count': 29}, {'id': 1150, 'image_count': 4}, {'id': 1151, 'image_count': 70}, {'id': 1152, 'image_count': 46}, {'id': 1153, 'image_count': 14}, {'id': 1154, 'image_count': 48}, {'id': 1155, 'image_count': 1855}, {'id': 1156, 'image_count': 113}, {'id': 1157, 'image_count': 1}, {'id': 1158, 'image_count': 1}, {'id': 1159, 'image_count': 10}, {'id': 1160, 'image_count': 54}, {'id': 1161, 'image_count': 1923}, {'id': 1162, 'image_count': 630}, {'id': 1163, 'image_count': 31}, {'id': 1164, 'image_count': 69}, {'id': 1165, 'image_count': 7}, {'id': 1166, 'image_count': 11}, {'id': 1167, 'image_count': 1}, {'id': 1168, 'image_count': 30}, {'id': 1169, 'image_count': 50}, {'id': 1170, 'image_count': 45}, {'id': 1171, 'image_count': 28}, {'id': 1172, 'image_count': 114}, {'id': 1173, 'image_count': 193}, {'id': 1174, 'image_count': 21}, {'id': 1175, 'image_count': 91}, {'id': 1176, 'image_count': 31}, {'id': 1177, 'image_count': 1469}, {'id': 1178, 'image_count': 1924}, {'id': 1179, 'image_count': 87}, {'id': 1180, 'image_count': 77}, {'id': 1181, 'image_count': 11}, {'id': 1182, 'image_count': 47}, {'id': 1183, 'image_count': 21}, {'id': 1184, 'image_count': 47}, {'id': 1185, 'image_count': 70}, {'id': 1186, 'image_count': 1838}, {'id': 1187, 'image_count': 19}, {'id': 1188, 'image_count': 531}, {'id': 1189, 'image_count': 11}, {'id': 1190, 'image_count': 941}, {'id': 1191, 'image_count': 113}, {'id': 1192, 'image_count': 26}, {'id': 1193, 'image_count': 5}, {'id': 1194, 'image_count': 56}, {'id': 1195, 'image_count': 73}, {'id': 1196, 'image_count': 32}, {'id': 1197, 'image_count': 128}, {'id': 1198, 'image_count': 623}, {'id': 1199, 'image_count': 12}, {'id': 1200, 'image_count': 52}, {'id': 1201, 'image_count': 11}, {'id': 1202, 'image_count': 1674}, {'id': 1203, 'image_count': 81}] # noqa
20
+ # fmt: on
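For context, a minimal sketch (not part of this diff) of how per-category image counts such as LVIS_CATEGORY_IMAGE_COUNT are typically turned into frequency-based class weights, mirroring the get_fed_loss_cls_weights helper added in detection_utils.py further down in this commit; the truncated list and the power 0.5 are illustrative assumptions:

import torch

# Truncated example; the real LVIS_CATEGORY_IMAGE_COUNT above has 1203 entries.
counts = [{"id": 1, "image_count": 64}, {"id": 2, "image_count": 364}, {"id": 3, "image_count": 1911}]
class_freq = torch.tensor([c["image_count"] for c in sorted(counts, key=lambda x: x["id"])])
class_freq_weight = class_freq.float() ** 0.5  # image_count ** freq_weight_power, here 0.5
# Rare categories get small weights, frequent ones larger; a downstream loss can normalize as needed.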
CatVTON/detectron2/data/datasets/pascal_voc.py ADDED
@@ -0,0 +1,82 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ import numpy as np
5
+ import os
6
+ import xml.etree.ElementTree as ET
7
+ from typing import List, Tuple, Union
8
+
9
+ from detectron2.data import DatasetCatalog, MetadataCatalog
10
+ from detectron2.structures import BoxMode
11
+ from detectron2.utils.file_io import PathManager
12
+
13
+ __all__ = ["load_voc_instances", "register_pascal_voc"]
14
+
15
+
16
+ # fmt: off
17
+ CLASS_NAMES = (
18
+ "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
19
+ "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
20
+ "pottedplant", "sheep", "sofa", "train", "tvmonitor"
21
+ )
22
+ # fmt: on
23
+
24
+
25
+ def load_voc_instances(dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]):
26
+ """
27
+ Load Pascal VOC detection annotations to Detectron2 format.
28
+
29
+ Args:
30
+ dirname: Contains "Annotations", "ImageSets", "JPEGImages"
31
+ split (str): one of "train", "test", "val", "trainval"
32
+ class_names: list or tuple of class names
33
+ """
34
+ with PathManager.open(os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f:
35
+ fileids = np.loadtxt(f, dtype=str)
36
+
37
+ # Needs to read many small annotation files. Makes sense to read them from a local path.
38
+ annotation_dirname = PathManager.get_local_path(os.path.join(dirname, "Annotations/"))
39
+ dicts = []
40
+ for fileid in fileids:
41
+ anno_file = os.path.join(annotation_dirname, fileid + ".xml")
42
+ jpeg_file = os.path.join(dirname, "JPEGImages", fileid + ".jpg")
43
+
44
+ with PathManager.open(anno_file) as f:
45
+ tree = ET.parse(f)
46
+
47
+ r = {
48
+ "file_name": jpeg_file,
49
+ "image_id": fileid,
50
+ "height": int(tree.findall("./size/height")[0].text),
51
+ "width": int(tree.findall("./size/width")[0].text),
52
+ }
53
+ instances = []
54
+
55
+ for obj in tree.findall("object"):
56
+ cls = obj.find("name").text
57
+ # We include "difficult" samples in training.
58
+ # Based on limited experiments, they don't hurt accuracy.
59
+ # difficult = int(obj.find("difficult").text)
60
+ # if difficult == 1:
61
+ # continue
62
+ bbox = obj.find("bndbox")
63
+ bbox = [float(bbox.find(x).text) for x in ["xmin", "ymin", "xmax", "ymax"]]
64
+ # Original annotations are integers in the range [1, W or H]
65
+ # Assuming they mean 1-based pixel indices (inclusive),
66
+ # a box with annotation (xmin=1, xmax=W) covers the whole image.
67
+ # In coordinate space this is represented by (xmin=0, xmax=W)
68
+ bbox[0] -= 1.0
69
+ bbox[1] -= 1.0
70
+ instances.append(
71
+ {"category_id": class_names.index(cls), "bbox": bbox, "bbox_mode": BoxMode.XYXY_ABS}
72
+ )
73
+ r["annotations"] = instances
74
+ dicts.append(r)
75
+ return dicts
76
+
77
+
78
+ def register_pascal_voc(name, dirname, split, year, class_names=CLASS_NAMES):
79
+ DatasetCatalog.register(name, lambda: load_voc_instances(dirname, split, class_names))
80
+ MetadataCatalog.get(name).set(
81
+ thing_classes=list(class_names), dirname=dirname, year=year, split=split
82
+ )
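For context, a minimal usage sketch (not part of this diff) of the helpers above; the dataset name and directory are placeholder assumptions:

from detectron2.data import DatasetCatalog
from detectron2.data.datasets.pascal_voc import register_pascal_voc

# "datasets/VOC2007" is a placeholder; it should contain "Annotations", "ImageSets" and "JPEGImages".
register_pascal_voc("voc_2007_trainval_demo", dirname="datasets/VOC2007", split="trainval", year=2007)
dataset_dicts = DatasetCatalog.get("voc_2007_trainval_demo")  # lazily calls load_voc_instances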
CatVTON/detectron2/data/datasets/register_coco.py ADDED
@@ -0,0 +1,3 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from .coco import register_coco_instances # noqa
3
+ from .coco_panoptic import register_coco_panoptic_separated # noqa
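For context, a minimal usage sketch (not part of this diff) of the re-exported register_coco_instances for a custom COCO-format dataset; the name and paths are placeholder assumptions:

from detectron2.data.datasets import register_coco_instances

# Placeholder dataset name, COCO-format annotation file, and image root.
register_coco_instances("my_coco_train", {}, "path/to/annotations.json", "path/to/images")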
CatVTON/detectron2/data/detection_utils.py ADDED
@@ -0,0 +1,661 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ """
5
+ Common data processing utilities that are used in a
6
+ typical object detection data pipeline.
7
+ """
8
+ import logging
9
+ import numpy as np
10
+ from typing import List, Union
11
+ import pycocotools.mask as mask_util
12
+ import torch
13
+ from PIL import Image
14
+
15
+ from detectron2.structures import (
16
+ BitMasks,
17
+ Boxes,
18
+ BoxMode,
19
+ Instances,
20
+ Keypoints,
21
+ PolygonMasks,
22
+ RotatedBoxes,
23
+ polygons_to_bitmask,
24
+ )
25
+ from detectron2.utils.file_io import PathManager
26
+
27
+ from . import transforms as T
28
+ from .catalog import MetadataCatalog
29
+
30
+ __all__ = [
31
+ "SizeMismatchError",
32
+ "convert_image_to_rgb",
33
+ "check_image_size",
34
+ "transform_proposals",
35
+ "transform_instance_annotations",
36
+ "annotations_to_instances",
37
+ "annotations_to_instances_rotated",
38
+ "build_augmentation",
39
+ "build_transform_gen",
40
+ "create_keypoint_hflip_indices",
41
+ "filter_empty_instances",
42
+ "read_image",
43
+ ]
44
+
45
+
46
+ class SizeMismatchError(ValueError):
47
+ """
48
+ When the loaded image has a different width/height compared with the annotation.
49
+ """
50
+
51
+
52
+ # https://en.wikipedia.org/wiki/YUV#SDTV_with_BT.601
53
+ _M_RGB2YUV = [[0.299, 0.587, 0.114], [-0.14713, -0.28886, 0.436], [0.615, -0.51499, -0.10001]]
54
+ _M_YUV2RGB = [[1.0, 0.0, 1.13983], [1.0, -0.39465, -0.58060], [1.0, 2.03211, 0.0]]
55
+
56
+ # https://www.exiv2.org/tags.html
57
+ _EXIF_ORIENT = 274 # exif 'Orientation' tag
58
+
59
+
60
+ def convert_PIL_to_numpy(image, format):
61
+ """
62
+ Convert PIL image to numpy array of target format.
63
+
64
+ Args:
65
+ image (PIL.Image): a PIL image
66
+ format (str): the format of output image
67
+
68
+ Returns:
69
+ (np.ndarray): also see `read_image`
70
+ """
71
+ if format is not None:
72
+ # PIL only supports RGB, so convert to RGB and flip channels over below
73
+ conversion_format = format
74
+ if format in ["BGR", "YUV-BT.601"]:
75
+ conversion_format = "RGB"
76
+ image = image.convert(conversion_format)
77
+ image = np.asarray(image)
78
+ # PIL squeezes out the channel dimension for "L", so make it HWC
79
+ if format == "L":
80
+ image = np.expand_dims(image, -1)
81
+
82
+ # handle formats not supported by PIL
83
+ elif format == "BGR":
84
+ # flip channels if needed
85
+ image = image[:, :, ::-1]
86
+ elif format == "YUV-BT.601":
87
+ image = image / 255.0
88
+ image = np.dot(image, np.array(_M_RGB2YUV).T)
89
+
90
+ return image
91
+
92
+
93
+ def convert_image_to_rgb(image, format):
94
+ """
95
+ Convert an image from given format to RGB.
96
+
97
+ Args:
98
+ image (np.ndarray or Tensor): an HWC image
99
+ format (str): the format of input image, also see `read_image`
100
+
101
+ Returns:
102
+ (np.ndarray): (H,W,3) RGB image in 0-255 range, can be either float or uint8
103
+ """
104
+ if isinstance(image, torch.Tensor):
105
+ image = image.cpu().numpy()
106
+ if format == "BGR":
107
+ image = image[:, :, [2, 1, 0]]
108
+ elif format == "YUV-BT.601":
109
+ image = np.dot(image, np.array(_M_YUV2RGB).T)
110
+ image = image * 255.0
111
+ else:
112
+ if format == "L":
113
+ image = image[:, :, 0]
114
+ image = image.astype(np.uint8)
115
+ image = np.asarray(Image.fromarray(image, mode=format).convert("RGB"))
116
+ return image
117
+
118
+
119
+ def _apply_exif_orientation(image):
120
+ """
121
+ Applies the exif orientation correctly.
122
+
123
+ This code exists per the bug:
124
+ https://github.com/python-pillow/Pillow/issues/3973
125
+ with the function `ImageOps.exif_transpose`. The Pillow source raises errors with
126
+ various methods, especially `tobytes`
127
+
128
+ Function based on:
129
+ https://github.com/wkentaro/labelme/blob/v4.5.4/labelme/utils/image.py#L59
130
+ https://github.com/python-pillow/Pillow/blob/7.1.2/src/PIL/ImageOps.py#L527
131
+
132
+ Args:
133
+ image (PIL.Image): a PIL image
134
+
135
+ Returns:
136
+ (PIL.Image): the PIL image with exif orientation applied, if applicable
137
+ """
138
+ if not hasattr(image, "getexif"):
139
+ return image
140
+
141
+ try:
142
+ exif = image.getexif()
143
+ except Exception: # https://github.com/facebookresearch/detectron2/issues/1885
144
+ exif = None
145
+
146
+ if exif is None:
147
+ return image
148
+
149
+ orientation = exif.get(_EXIF_ORIENT)
150
+
151
+ method = {
152
+ 2: Image.FLIP_LEFT_RIGHT,
153
+ 3: Image.ROTATE_180,
154
+ 4: Image.FLIP_TOP_BOTTOM,
155
+ 5: Image.TRANSPOSE,
156
+ 6: Image.ROTATE_270,
157
+ 7: Image.TRANSVERSE,
158
+ 8: Image.ROTATE_90,
159
+ }.get(orientation)
160
+
161
+ if method is not None:
162
+ return image.transpose(method)
163
+ return image
164
+
165
+
166
+ def read_image(file_name, format=None):
167
+ """
168
+ Read an image into the given format.
169
+ Will apply rotation and flipping if the image has such exif information.
170
+
171
+ Args:
172
+ file_name (str): image file path
173
+ format (str): one of the supported image modes in PIL, or "BGR" or "YUV-BT.601".
174
+
175
+ Returns:
176
+ image (np.ndarray):
177
+ an HWC image in the given format, which is 0-255, uint8 for
178
+ supported image modes in PIL or "BGR"; float (0-1 for Y) for YUV-BT.601.
179
+ """
180
+ with PathManager.open(file_name, "rb") as f:
181
+ image = Image.open(f)
182
+
183
+ # work around this bug: https://github.com/python-pillow/Pillow/issues/3973
184
+ image = _apply_exif_orientation(image)
185
+ return convert_PIL_to_numpy(image, format)
186
+
187
+
188
+ def check_image_size(dataset_dict, image):
189
+ """
190
+ Raise an error if the image does not match the size specified in the dict.
191
+ """
192
+ if "width" in dataset_dict or "height" in dataset_dict:
193
+ image_wh = (image.shape[1], image.shape[0])
194
+ expected_wh = (dataset_dict["width"], dataset_dict["height"])
195
+ if not image_wh == expected_wh:
196
+ raise SizeMismatchError(
197
+ "Mismatched image shape{}, got {}, expect {}.".format(
198
+ (
199
+ " for image " + dataset_dict["file_name"]
200
+ if "file_name" in dataset_dict
201
+ else ""
202
+ ),
203
+ image_wh,
204
+ expected_wh,
205
+ )
206
+ + " Please check the width/height in your annotation."
207
+ )
208
+
209
+ # To ensure bbox always remap to original image size
210
+ if "width" not in dataset_dict:
211
+ dataset_dict["width"] = image.shape[1]
212
+ if "height" not in dataset_dict:
213
+ dataset_dict["height"] = image.shape[0]
214
+
215
+
216
+ def transform_proposals(dataset_dict, image_shape, transforms, *, proposal_topk, min_box_size=0):
217
+ """
218
+ Apply transformations to the proposals in dataset_dict, if any.
219
+
220
+ Args:
221
+ dataset_dict (dict): a dict read from the dataset, possibly
222
+ contains fields "proposal_boxes", "proposal_objectness_logits", "proposal_bbox_mode"
223
+ image_shape (tuple): height, width
224
+ transforms (TransformList):
225
+ proposal_topk (int): only keep top-K scoring proposals
226
+ min_box_size (int): proposals with either side smaller than this
227
+ threshold are removed
228
+
229
+ The input dict is modified in-place, with abovementioned keys removed. A new
230
+ key "proposals" will be added. Its value is an `Instances`
231
+ object which contains the transformed proposals in its field
232
+ "proposal_boxes" and "objectness_logits".
233
+ """
234
+ if "proposal_boxes" in dataset_dict:
235
+ # Transform proposal boxes
236
+ boxes = transforms.apply_box(
237
+ BoxMode.convert(
238
+ dataset_dict.pop("proposal_boxes"),
239
+ dataset_dict.pop("proposal_bbox_mode"),
240
+ BoxMode.XYXY_ABS,
241
+ )
242
+ )
243
+ boxes = Boxes(boxes)
244
+ objectness_logits = torch.as_tensor(
245
+ dataset_dict.pop("proposal_objectness_logits").astype("float32")
246
+ )
247
+
248
+ boxes.clip(image_shape)
249
+ keep = boxes.nonempty(threshold=min_box_size)
250
+ boxes = boxes[keep]
251
+ objectness_logits = objectness_logits[keep]
252
+
253
+ proposals = Instances(image_shape)
254
+ proposals.proposal_boxes = boxes[:proposal_topk]
255
+ proposals.objectness_logits = objectness_logits[:proposal_topk]
256
+ dataset_dict["proposals"] = proposals
257
+
258
+
259
+ def get_bbox(annotation):
260
+ """
261
+ Get bbox from data
262
+ Args:
263
+ annotation (dict): dict of instance annotations for a single instance.
264
+ Returns:
265
+ bbox (ndarray): x1, y1, x2, y2 coordinates
266
+ """
267
+ # bbox is 1d (per-instance bounding box)
268
+ bbox = BoxMode.convert(annotation["bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS)
269
+ return bbox
270
+
271
+
272
+ def transform_instance_annotations(
273
+ annotation, transforms, image_size, *, keypoint_hflip_indices=None
274
+ ):
275
+ """
276
+ Apply transforms to box, segmentation and keypoints annotations of a single instance.
277
+
278
+ It will use `transforms.apply_box` for the box, and
279
+ `transforms.apply_coords` for segmentation polygons & keypoints.
280
+ If you need anything more specially designed for each data structure,
281
+ you'll need to implement your own version of this function or the transforms.
282
+
283
+ Args:
284
+ annotation (dict): dict of instance annotations for a single instance.
285
+ It will be modified in-place.
286
+ transforms (TransformList or list[Transform]):
287
+ image_size (tuple): the height, width of the transformed image
288
+ keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`.
289
+
290
+ Returns:
291
+ dict:
292
+ the same input dict with fields "bbox", "segmentation", "keypoints"
293
+ transformed according to `transforms`.
294
+ The "bbox_mode" field will be set to XYXY_ABS.
295
+ """
296
+ if isinstance(transforms, (tuple, list)):
297
+ transforms = T.TransformList(transforms)
298
+ # bbox is 1d (per-instance bounding box)
299
+ bbox = BoxMode.convert(annotation["bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS)
300
+ # clip transformed bbox to image size
301
+ bbox = transforms.apply_box(np.array([bbox]))[0].clip(min=0)
302
+ annotation["bbox"] = np.minimum(bbox, list(image_size + image_size)[::-1])
303
+ annotation["bbox_mode"] = BoxMode.XYXY_ABS
304
+
305
+ if "segmentation" in annotation:
306
+ # each instance contains 1 or more polygons
307
+ segm = annotation["segmentation"]
308
+ if isinstance(segm, list):
309
+ # polygons
310
+ polygons = [np.asarray(p).reshape(-1, 2) for p in segm]
311
+ annotation["segmentation"] = [
312
+ p.reshape(-1) for p in transforms.apply_polygons(polygons)
313
+ ]
314
+ elif isinstance(segm, dict):
315
+ # RLE
316
+ mask = mask_util.decode(segm)
317
+ mask = transforms.apply_segmentation(mask)
318
+ assert tuple(mask.shape[:2]) == image_size
319
+ annotation["segmentation"] = mask
320
+ else:
321
+ raise ValueError(
322
+ "Cannot transform segmentation of type '{}'!"
323
+ "Supported types are: polygons as list[list[float] or ndarray],"
324
+ " COCO-style RLE as a dict.".format(type(segm))
325
+ )
326
+
327
+ if "keypoints" in annotation:
328
+ keypoints = transform_keypoint_annotations(
329
+ annotation["keypoints"], transforms, image_size, keypoint_hflip_indices
330
+ )
331
+ annotation["keypoints"] = keypoints
332
+
333
+ return annotation
334
+
335
+
336
+ def transform_keypoint_annotations(keypoints, transforms, image_size, keypoint_hflip_indices=None):
337
+ """
338
+ Transform keypoint annotations of an image.
339
+ If a keypoint is transformed out of the image boundary, it will be marked "unlabeled" (visibility=0)
340
+
341
+ Args:
342
+ keypoints (list[float]): Nx3 float in Detectron2's Dataset format.
343
+ Each point is represented by (x, y, visibility).
344
+ transforms (TransformList):
345
+ image_size (tuple): the height, width of the transformed image
346
+ keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`.
347
+ When `transforms` includes horizontal flip, will use the index
348
+ mapping to flip keypoints.
349
+ """
350
+ # (N*3,) -> (N, 3)
351
+ keypoints = np.asarray(keypoints, dtype="float64").reshape(-1, 3)
352
+ keypoints_xy = transforms.apply_coords(keypoints[:, :2])
353
+
354
+ # Set all out-of-boundary points to "unlabeled"
355
+ inside = (keypoints_xy >= np.array([0, 0])) & (keypoints_xy <= np.array(image_size[::-1]))
356
+ inside = inside.all(axis=1)
357
+ keypoints[:, :2] = keypoints_xy
358
+ keypoints[:, 2][~inside] = 0
359
+
360
+ # This assumes that HorizFlipTransform is the only one that does flip
361
+ do_hflip = sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1
362
+
363
+ # Alternative way: check if the probe points were horizontally flipped.
364
+ # probe = np.asarray([[0.0, 0.0], [image_width, 0.0]])
365
+ # probe_aug = transforms.apply_coords(probe.copy())
366
+ # do_hflip = np.sign(probe[1][0] - probe[0][0]) != np.sign(probe_aug[1][0] - probe_aug[0][0]) # noqa
367
+
368
+ # If flipped, swap each keypoint with its opposite-handed equivalent
369
+ if do_hflip:
370
+ if keypoint_hflip_indices is None:
371
+ raise ValueError("Cannot flip keypoints without providing flip indices!")
372
+ if len(keypoints) != len(keypoint_hflip_indices):
373
+ raise ValueError(
374
+ "Keypoint data has {} points, but metadata "
375
+ "contains {} points!".format(len(keypoints), len(keypoint_hflip_indices))
376
+ )
377
+ keypoints = keypoints[np.asarray(keypoint_hflip_indices, dtype=np.int32), :]
378
+
379
+ # Maintain COCO convention that if visibility == 0 (unlabeled), then x, y = 0
380
+ keypoints[keypoints[:, 2] == 0] = 0
381
+ return keypoints
382
+
383
+
384
+ def annotations_to_instances(annos, image_size, mask_format="polygon"):
385
+ """
386
+ Create an :class:`Instances` object used by the models,
387
+ from instance annotations in the dataset dict.
388
+
389
+ Args:
390
+ annos (list[dict]): a list of instance annotations in one image, each
391
+ element for one instance.
392
+ image_size (tuple): height, width
393
+
394
+ Returns:
395
+ Instances:
396
+ It will contain fields "gt_boxes", "gt_classes",
397
+ "gt_masks", "gt_keypoints", if they can be obtained from `annos`.
398
+ This is the format that builtin models expect.
399
+ """
400
+ boxes = (
401
+ np.stack(
402
+ [BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS) for obj in annos]
403
+ )
404
+ if len(annos)
405
+ else np.zeros((0, 4))
406
+ )
407
+ target = Instances(image_size)
408
+ target.gt_boxes = Boxes(boxes)
409
+
410
+ classes = [int(obj["category_id"]) for obj in annos]
411
+ classes = torch.tensor(classes, dtype=torch.int64)
412
+ target.gt_classes = classes
413
+
414
+ if len(annos) and "segmentation" in annos[0]:
415
+ segms = [obj["segmentation"] for obj in annos]
416
+ if mask_format == "polygon":
417
+ try:
418
+ masks = PolygonMasks(segms)
419
+ except ValueError as e:
420
+ raise ValueError(
421
+ "Failed to use mask_format=='polygon' from the given annotations!"
422
+ ) from e
423
+ else:
424
+ assert mask_format == "bitmask", mask_format
425
+ masks = []
426
+ for segm in segms:
427
+ if isinstance(segm, list):
428
+ # polygon
429
+ masks.append(polygons_to_bitmask(segm, *image_size))
430
+ elif isinstance(segm, dict):
431
+ # COCO RLE
432
+ masks.append(mask_util.decode(segm))
433
+ elif isinstance(segm, np.ndarray):
434
+ assert segm.ndim == 2, "Expect segmentation of 2 dimensions, got {}.".format(
435
+ segm.ndim
436
+ )
437
+ # mask array
438
+ masks.append(segm)
439
+ else:
440
+ raise ValueError(
441
+ "Cannot convert segmentation of type '{}' to BitMasks!"
442
+ "Supported types are: polygons as list[list[float] or ndarray],"
443
+ " COCO-style RLE as a dict, or a binary segmentation mask "
444
+ " in a 2D numpy array of shape HxW.".format(type(segm))
445
+ )
446
+ # torch.from_numpy does not support array with negative stride.
447
+ masks = BitMasks(
448
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
449
+ )
450
+ target.gt_masks = masks
451
+
452
+ if len(annos) and "keypoints" in annos[0]:
453
+ kpts = [obj.get("keypoints", []) for obj in annos]
454
+ target.gt_keypoints = Keypoints(kpts)
455
+
456
+ return target
457
+
458
+
459
+ def annotations_to_instances_rotated(annos, image_size):
460
+ """
461
+ Create an :class:`Instances` object used by the models,
462
+ from instance annotations in the dataset dict.
463
+ Compared to `annotations_to_instances`, this function is for rotated boxes only
464
+
465
+ Args:
466
+ annos (list[dict]): a list of instance annotations in one image, each
467
+ element for one instance.
468
+ image_size (tuple): height, width
469
+
470
+ Returns:
471
+ Instances:
472
+ Containing fields "gt_boxes", "gt_classes",
473
+ if they can be obtained from `annos`.
474
+ This is the format that builtin models expect.
475
+ """
476
+ boxes = [obj["bbox"] for obj in annos]
477
+ target = Instances(image_size)
478
+ boxes = target.gt_boxes = RotatedBoxes(boxes)
479
+ boxes.clip(image_size)
480
+
481
+ classes = [obj["category_id"] for obj in annos]
482
+ classes = torch.tensor(classes, dtype=torch.int64)
483
+ target.gt_classes = classes
484
+
485
+ return target
486
+
487
+
488
+ def filter_empty_instances(
489
+ instances, by_box=True, by_mask=True, box_threshold=1e-5, return_mask=False
490
+ ):
491
+ """
492
+ Filter out empty instances in an `Instances` object.
493
+
494
+ Args:
495
+ instances (Instances):
496
+ by_box (bool): whether to filter out instances with empty boxes
497
+ by_mask (bool): whether to filter out instances with empty masks
498
+ box_threshold (float): minimum width and height to be considered non-empty
499
+ return_mask (bool): whether to return boolean mask of filtered instances
500
+
501
+ Returns:
502
+ Instances: the filtered instances.
503
+ tensor[bool], optional: boolean mask of filtered instances
504
+ """
505
+ assert by_box or by_mask
506
+ r = []
507
+ if by_box:
508
+ r.append(instances.gt_boxes.nonempty(threshold=box_threshold))
509
+ if instances.has("gt_masks") and by_mask:
510
+ r.append(instances.gt_masks.nonempty())
511
+
512
+ # TODO: can also filter visible keypoints
513
+
514
+ if not r:
515
+ return instances
516
+ m = r[0]
517
+ for x in r[1:]:
518
+ m = m & x
519
+ if return_mask:
520
+ return instances[m], m
521
+ return instances[m]
522
+
523
+
524
+ def create_keypoint_hflip_indices(dataset_names: Union[str, List[str]]) -> List[int]:
525
+ """
526
+ Args:
527
+ dataset_names: list of dataset names
528
+
529
+ Returns:
530
+ list[int]: a list of size=#keypoints, storing the
531
+ horizontally-flipped keypoint indices.
532
+ """
533
+ if isinstance(dataset_names, str):
534
+ dataset_names = [dataset_names]
535
+
536
+ check_metadata_consistency("keypoint_names", dataset_names)
537
+ check_metadata_consistency("keypoint_flip_map", dataset_names)
538
+
539
+ meta = MetadataCatalog.get(dataset_names[0])
540
+ names = meta.keypoint_names
541
+ # TODO flip -> hflip
542
+ flip_map = dict(meta.keypoint_flip_map)
543
+ flip_map.update({v: k for k, v in flip_map.items()})
544
+ flipped_names = [i if i not in flip_map else flip_map[i] for i in names]
545
+ flip_indices = [names.index(i) for i in flipped_names]
546
+ return flip_indices
547
+
548
+
549
+ def get_fed_loss_cls_weights(dataset_names: Union[str, List[str]], freq_weight_power=1.0):
550
+ """
551
+ Get frequency weight for each class sorted by class id.
552
+ We now calculate frequency weight using image_count to the power freq_weight_power.
553
+
554
+ Args:
555
+ dataset_names: list of dataset names
556
+ freq_weight_power: power value
557
+ """
558
+ if isinstance(dataset_names, str):
559
+ dataset_names = [dataset_names]
560
+
561
+ check_metadata_consistency("class_image_count", dataset_names)
562
+
563
+ meta = MetadataCatalog.get(dataset_names[0])
564
+ class_freq_meta = meta.class_image_count
565
+ class_freq = torch.tensor(
566
+ [c["image_count"] for c in sorted(class_freq_meta, key=lambda x: x["id"])]
567
+ )
568
+ class_freq_weight = class_freq.float() ** freq_weight_power
569
+ return class_freq_weight
570
+
571
+
572
+ def gen_crop_transform_with_instance(crop_size, image_size, instance):
573
+ """
574
+ Generate a CropTransform so that the cropping region contains
575
+ the center of the given instance.
576
+
577
+ Args:
578
+ crop_size (tuple): h, w in pixels
579
+ image_size (tuple): h, w
580
+ instance (dict): an annotation dict of one instance, in Detectron2's
581
+ dataset format.
582
+ """
583
+ crop_size = np.asarray(crop_size, dtype=np.int32)
584
+ bbox = BoxMode.convert(instance["bbox"], instance["bbox_mode"], BoxMode.XYXY_ABS)
585
+ center_yx = (bbox[1] + bbox[3]) * 0.5, (bbox[0] + bbox[2]) * 0.5
586
+ assert (
587
+ image_size[0] >= center_yx[0] and image_size[1] >= center_yx[1]
588
+ ), "The annotation bounding box is outside of the image!"
589
+ assert (
590
+ image_size[0] >= crop_size[0] and image_size[1] >= crop_size[1]
591
+ ), "Crop size is larger than image size!"
592
+
593
+ min_yx = np.maximum(np.floor(center_yx).astype(np.int32) - crop_size, 0)
594
+ max_yx = np.maximum(np.asarray(image_size, dtype=np.int32) - crop_size, 0)
595
+ max_yx = np.minimum(max_yx, np.ceil(center_yx).astype(np.int32))
596
+
597
+ y0 = np.random.randint(min_yx[0], max_yx[0] + 1)
598
+ x0 = np.random.randint(min_yx[1], max_yx[1] + 1)
599
+ return T.CropTransform(x0, y0, crop_size[1], crop_size[0])
600
+
601
+
602
+ def check_metadata_consistency(key, dataset_names):
603
+ """
604
+ Check that the datasets have consistent metadata.
605
+
606
+ Args:
607
+ key (str): a metadata key
608
+ dataset_names (list[str]): a list of dataset names
609
+
610
+ Raises:
611
+ AttributeError: if the key does not exist in the metadata
612
+ ValueError: if the given datasets do not have the same metadata values defined by key
613
+ """
614
+ if len(dataset_names) == 0:
615
+ return
616
+ logger = logging.getLogger(__name__)
617
+ entries_per_dataset = [getattr(MetadataCatalog.get(d), key) for d in dataset_names]
618
+ for idx, entry in enumerate(entries_per_dataset):
619
+ if entry != entries_per_dataset[0]:
620
+ logger.error(
621
+ "Metadata '{}' for dataset '{}' is '{}'".format(key, dataset_names[idx], str(entry))
622
+ )
623
+ logger.error(
624
+ "Metadata '{}' for dataset '{}' is '{}'".format(
625
+ key, dataset_names[0], str(entries_per_dataset[0])
626
+ )
627
+ )
628
+ raise ValueError("Datasets have different metadata '{}'!".format(key))
629
+
630
+
631
+ def build_augmentation(cfg, is_train):
632
+ """
633
+ Create a list of default :class:`Augmentation` from config.
634
+ Now it includes resizing and flipping.
635
+
636
+ Returns:
637
+ list[Augmentation]
638
+ """
639
+ if is_train:
640
+ min_size = cfg.INPUT.MIN_SIZE_TRAIN
641
+ max_size = cfg.INPUT.MAX_SIZE_TRAIN
642
+ sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
643
+ else:
644
+ min_size = cfg.INPUT.MIN_SIZE_TEST
645
+ max_size = cfg.INPUT.MAX_SIZE_TEST
646
+ sample_style = "choice"
647
+ augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)]
648
+ if is_train and cfg.INPUT.RANDOM_FLIP != "none":
649
+ augmentation.append(
650
+ T.RandomFlip(
651
+ horizontal=cfg.INPUT.RANDOM_FLIP == "horizontal",
652
+ vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
653
+ )
654
+ )
655
+ return augmentation
656
+
657
+
658
+ build_transform_gen = build_augmentation
659
+ """
660
+ Alias for backward-compatibility.
661
+ """
CatVTON/detectron2/data/samplers/__init__.py ADDED
@@ -0,0 +1,17 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ from .distributed_sampler import (
+     InferenceSampler,
+     RandomSubsetTrainingSampler,
+     RepeatFactorTrainingSampler,
+     TrainingSampler,
+ )
+
+ from .grouped_batch_sampler import GroupedBatchSampler
+
+ __all__ = [
+     "GroupedBatchSampler",
+     "TrainingSampler",
+     "RandomSubsetTrainingSampler",
+     "InferenceSampler",
+     "RepeatFactorTrainingSampler",
+ ]
CatVTON/detectron2/data/samplers/distributed_sampler.py ADDED
@@ -0,0 +1,287 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import itertools
3
+ import logging
4
+ import math
5
+ from collections import defaultdict
6
+ from typing import Optional
7
+ import torch
8
+ from torch.utils.data.sampler import Sampler
9
+
10
+ from detectron2.utils import comm
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class TrainingSampler(Sampler):
16
+ """
17
+ In training, we only care about the "infinite stream" of training data.
18
+ So this sampler produces an infinite stream of indices and
19
+ all workers cooperate to correctly shuffle the indices and sample different indices.
20
+
21
+ The sampler in each worker effectively produces `indices[worker_id::num_workers]`
22
+ where `indices` is an infinite stream of indices consisting of
23
+ `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
24
+ or `range(size) + range(size) + ...` (if shuffle is False)
25
+
26
+ Note that this sampler does not shard based on pytorch DataLoader worker id.
27
+ A sampler passed to pytorch DataLoader is used only with map-style dataset
28
+ and will not be executed inside workers.
29
+ But if this sampler is used in a way that it gets executed inside a dataloader
30
+ worker, then extra work needs to be done to shard its outputs based on worker id.
31
+ This is required so that workers don't produce identical data.
32
+ :class:`ToIterableDataset` implements this logic.
33
+ This note is true for all samplers in detectron2.
34
+ """
35
+
36
+ def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None):
37
+ """
38
+ Args:
39
+ size (int): the total number of data of the underlying dataset to sample from
40
+ shuffle (bool): whether to shuffle the indices or not
41
+ seed (int): the initial seed of the shuffle. Must be the same
42
+ across all workers. If None, will use a random seed shared
43
+ among workers (require synchronization among all workers).
44
+ """
45
+ if not isinstance(size, int):
46
+ raise TypeError(f"TrainingSampler(size=) expects an int. Got type {type(size)}.")
47
+ if size <= 0:
48
+ raise ValueError(f"TrainingSampler(size=) expects a positive int. Got {size}.")
49
+ self._size = size
50
+ self._shuffle = shuffle
51
+ if seed is None:
52
+ seed = comm.shared_random_seed()
53
+ self._seed = int(seed)
54
+
55
+ self._rank = comm.get_rank()
56
+ self._world_size = comm.get_world_size()
57
+
58
+ def __iter__(self):
59
+ start = self._rank
60
+ yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)
61
+
62
+ def _infinite_indices(self):
63
+ g = torch.Generator()
64
+ if self._seed is not None:
65
+ g.manual_seed(self._seed)
66
+ while True:
67
+ if self._shuffle:
68
+ yield from torch.randperm(self._size, generator=g).tolist()
69
+ else:
70
+ yield from torch.arange(self._size).tolist()
71
+
72
+
73
+ class RandomSubsetTrainingSampler(TrainingSampler):
74
+ """
75
+ Similar to TrainingSampler, but only sample a random subset of indices.
76
+ This is useful when you want to estimate the accuracy vs data-number curves by
77
+ training the model with different subset_ratio.
78
+ """
79
+
80
+ def __init__(
81
+ self,
82
+ size: int,
83
+ subset_ratio: float,
84
+ shuffle: bool = True,
85
+ seed_shuffle: Optional[int] = None,
86
+ seed_subset: Optional[int] = None,
87
+ ):
88
+ """
89
+ Args:
90
+ size (int): the total number of data of the underlying dataset to sample from
91
+ subset_ratio (float): the ratio of subset data to sample from the underlying dataset
92
+ shuffle (bool): whether to shuffle the indices or not
93
+ seed_shuffle (int): the initial seed of the shuffle. Must be the same
94
+ across all workers. If None, will use a random seed shared
95
+ among workers (require synchronization among all workers).
96
+ seed_subset (int): the seed to randomize the subset to be sampled.
97
+ Must be the same across all workers. If None, will use a random seed shared
98
+ among workers (require synchronization among all workers).
99
+ """
100
+ super().__init__(size=size, shuffle=shuffle, seed=seed_shuffle)
101
+
102
+ assert 0.0 < subset_ratio <= 1.0
103
+ self._size_subset = int(size * subset_ratio)
104
+ assert self._size_subset > 0
105
+ if seed_subset is None:
106
+ seed_subset = comm.shared_random_seed()
107
+ self._seed_subset = int(seed_subset)
108
+
109
+ # randomly generate the subset indexes to be sampled from
110
+ g = torch.Generator()
111
+ g.manual_seed(self._seed_subset)
112
+ indexes_randperm = torch.randperm(self._size, generator=g)
113
+ self._indexes_subset = indexes_randperm[: self._size_subset]
114
+
115
+ logger.info("Using RandomSubsetTrainingSampler......")
116
+ logger.info(f"Randomly sample {self._size_subset} data from the original {self._size} data")
117
+
118
+ def _infinite_indices(self):
119
+ g = torch.Generator()
120
+ g.manual_seed(self._seed) # self._seed equals seed_shuffle from __init__()
121
+ while True:
122
+ if self._shuffle:
123
+ # generate a random permutation to shuffle self._indexes_subset
124
+ randperm = torch.randperm(self._size_subset, generator=g)
125
+ yield from self._indexes_subset[randperm].tolist()
126
+ else:
127
+ yield from self._indexes_subset.tolist()
128
+
129
+
130
+ class RepeatFactorTrainingSampler(Sampler):
131
+ """
132
+ Similar to TrainingSampler, but a sample may appear more times than others based
133
+ on its "repeat factor". This is suitable for training on class imbalanced datasets like LVIS.
134
+ """
135
+
136
+ def __init__(self, repeat_factors, *, shuffle=True, seed=None):
137
+ """
138
+ Args:
139
+ repeat_factors (Tensor): a float vector, the repeat factor for each index. When it's
140
+ full of ones, it is equivalent to ``TrainingSampler(len(repeat_factors), ...)``.
141
+ shuffle (bool): whether to shuffle the indices or not
142
+ seed (int): the initial seed of the shuffle. Must be the same
143
+ across all workers. If None, will use a random seed shared
144
+ among workers (require synchronization among all workers).
145
+ """
146
+ self._shuffle = shuffle
147
+ if seed is None:
148
+ seed = comm.shared_random_seed()
149
+ self._seed = int(seed)
150
+
151
+ self._rank = comm.get_rank()
152
+ self._world_size = comm.get_world_size()
153
+
154
+ # Split into whole number (_int_part) and fractional (_frac_part) parts.
155
+ self._int_part = torch.trunc(repeat_factors)
156
+ self._frac_part = repeat_factors - self._int_part
157
+
158
+ @staticmethod
159
+ def repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh, sqrt=True):
160
+ """
161
+ Compute (fractional) per-image repeat factors based on category frequency.
162
+ The repeat factor for an image is a function of the frequency of the rarest
163
+ category labeled in that image. The "frequency of category c" in [0, 1] is defined
164
+ as the fraction of images in the training set (without repeats) in which category c
165
+ appears.
166
+ See :paper:`lvis` (>= v2) Appendix B.2.
167
+
168
+ Args:
169
+ dataset_dicts (list[dict]): annotations in Detectron2 dataset format.
170
+ repeat_thresh (float): frequency threshold below which data is repeated.
171
+ If the frequency is half of `repeat_thresh`, the image will be
172
+ repeated twice.
173
+ sqrt (bool): if True, apply :func:`math.sqrt` to the repeat factor.
174
+
175
+ Returns:
176
+ torch.Tensor:
177
+ the i-th element is the repeat factor for the dataset image at index i.
178
+ """
179
+ # 1. For each category c, compute the fraction of images that contain it: f(c)
180
+ category_freq = defaultdict(int)
181
+ for dataset_dict in dataset_dicts: # For each image (without repeats)
182
+ cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
183
+ for cat_id in cat_ids:
184
+ category_freq[cat_id] += 1
185
+ num_images = len(dataset_dicts)
186
+ for k, v in category_freq.items():
187
+ category_freq[k] = v / num_images
188
+
189
+ # 2. For each category c, compute the category-level repeat factor:
190
+ # r(c) = max(1, sqrt(t / f(c)))
191
+ category_rep = {
192
+ cat_id: max(
193
+ 1.0,
194
+ (math.sqrt(repeat_thresh / cat_freq) if sqrt else (repeat_thresh / cat_freq)),
195
+ )
196
+ for cat_id, cat_freq in category_freq.items()
197
+ }
198
+ for cat_id in sorted(category_rep.keys()):
199
+ logger.info(
200
+ f"Cat ID {cat_id}: freq={category_freq[cat_id]:.2f}, rep={category_rep[cat_id]:.2f}"
201
+ )
202
+
203
+ # 3. For each image I, compute the image-level repeat factor:
204
+ # r(I) = max_{c in I} r(c)
205
+ rep_factors = []
206
+ for dataset_dict in dataset_dicts:
207
+ cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
208
+ rep_factor = max({category_rep[cat_id] for cat_id in cat_ids}, default=1.0)
209
+ rep_factors.append(rep_factor)
210
+
211
+ return torch.tensor(rep_factors, dtype=torch.float32)
212
+
213
+ def _get_epoch_indices(self, generator):
214
+ """
215
+ Create a list of dataset indices (with repeats) to use for one epoch.
216
+
217
+ Args:
218
+ generator (torch.Generator): pseudo random number generator used for
219
+ stochastic rounding.
220
+
221
+ Returns:
222
+ torch.Tensor: list of dataset indices to use in one epoch. Each index
223
+ is repeated based on its calculated repeat factor.
224
+ """
225
+ # Since repeat factors are fractional, we use stochastic rounding so
226
+ # that the target repeat factor is achieved in expectation over the
227
+ # course of training
228
+ rands = torch.rand(len(self._frac_part), generator=generator)
229
+ rep_factors = self._int_part + (rands < self._frac_part).float()
230
+ # Construct a list of indices in which we repeat images as specified
231
+ indices = []
232
+ for dataset_index, rep_factor in enumerate(rep_factors):
233
+ indices.extend([dataset_index] * int(rep_factor.item()))
234
+ return torch.tensor(indices, dtype=torch.int64)
235
+
236
+ def __iter__(self):
237
+ start = self._rank
238
+ yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)
239
+
240
+ def _infinite_indices(self):
241
+ g = torch.Generator()
242
+ g.manual_seed(self._seed)
243
+ while True:
244
+ # Sample indices with repeats determined by stochastic rounding; each
245
+ # "epoch" may have a slightly different size due to the rounding.
246
+ indices = self._get_epoch_indices(g)
247
+ if self._shuffle:
248
+ randperm = torch.randperm(len(indices), generator=g)
249
+ yield from indices[randperm].tolist()
250
+ else:
251
+ yield from indices.tolist()
252
+
253
+
254
+ class InferenceSampler(Sampler):
255
+ """
256
+ Produce indices for inference across all workers.
257
+ Inference needs to run on the __exact__ set of samples,
258
+ therefore when the total number of samples is not divisible by the number of workers,
259
+ this sampler produces different number of samples on different workers.
260
+ """
261
+
262
+ def __init__(self, size: int):
263
+ """
264
+ Args:
265
+ size (int): the total number of data of the underlying dataset to sample from
266
+ """
267
+ self._size = size
268
+ assert size > 0
269
+ self._rank = comm.get_rank()
270
+ self._world_size = comm.get_world_size()
271
+ self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
272
+
273
+ @staticmethod
274
+ def _get_local_indices(total_size, world_size, rank):
275
+ shard_size = total_size // world_size
276
+ left = total_size % world_size
277
+ shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
278
+
279
+ begin = sum(shard_sizes[:rank])
280
+ end = min(sum(shard_sizes[: rank + 1]), total_size)
281
+ return range(begin, end)
282
+
283
+ def __iter__(self):
284
+ yield from self._local_indices
285
+
286
+ def __len__(self):
287
+ return len(self._local_indices)
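
As a quick illustration of the sampler API above, a sketch assuming a dataset such as "lvis_v1_train" is already registered in the DatasetCatalog:

    from detectron2.data import DatasetCatalog
    from detectron2.data.samplers import RepeatFactorTrainingSampler, TrainingSampler

    dataset_dicts = DatasetCatalog.get("lvis_v1_train")  # assumed to be registered
    repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
        dataset_dicts, repeat_thresh=0.001
    )
    sampler = RepeatFactorTrainingSampler(repeat_factors)  # infinite stream of indices
    # For balanced datasets, TrainingSampler(len(dataset_dicts)) is the usual choice.
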
CatVTON/detectron2/data/samplers/grouped_batch_sampler.py ADDED
@@ -0,0 +1,47 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ import numpy as np
+ from torch.utils.data.sampler import BatchSampler, Sampler
+
+
+ class GroupedBatchSampler(BatchSampler):
+     """
+     Wraps another sampler to yield a mini-batch of indices.
+     It enforces that the batch only contains elements from the same group.
+     It also tries to provide mini-batches which follow an ordering that is
+     as close as possible to the ordering from the original sampler.
+     """
+
+     def __init__(self, sampler, group_ids, batch_size):
+         """
+         Args:
+             sampler (Sampler): Base sampler.
+             group_ids (list[int]): If the sampler produces indices in range [0, N),
+                 `group_ids` must be a list of `N` ints which contains the group id of each sample.
+                 The group ids must be a set of integers in the range [0, num_groups).
+             batch_size (int): Size of mini-batch.
+         """
+         if not isinstance(sampler, Sampler):
+             raise ValueError(
+                 "sampler should be an instance of "
+                 "torch.utils.data.Sampler, but got sampler={}".format(sampler)
+             )
+         self.sampler = sampler
+         self.group_ids = np.asarray(group_ids)
+         assert self.group_ids.ndim == 1
+         self.batch_size = batch_size
+         groups = np.unique(self.group_ids).tolist()
+
+         # buffer the indices of each group until batch size is reached
+         self.buffer_per_group = {k: [] for k in groups}
+
+     def __iter__(self):
+         for idx in self.sampler:
+             group_id = self.group_ids[idx]
+             group_buffer = self.buffer_per_group[group_id]
+             group_buffer.append(idx)
+             if len(group_buffer) == self.batch_size:
+                 yield group_buffer[:]  # yield a copy of the list
+                 del group_buffer[:]
+
+     def __len__(self):
+         raise NotImplementedError("len() of GroupedBatchSampler is not well-defined.")
CatVTON/detectron2/data/transforms/__init__.py ADDED
@@ -0,0 +1,14 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ from fvcore.transforms.transform import Transform, TransformList  # order them first
+ from fvcore.transforms.transform import *
+ from .transform import *
+ from .augmentation import *
+ from .augmentation_impl import *
+
+ __all__ = [k for k in globals().keys() if not k.startswith("_")]
+
+
+ from detectron2.utils.env import fixup_module_metadata
+
+ fixup_module_metadata(__name__, globals(), __all__)
+ del fixup_module_metadata
CatVTON/detectron2/data/transforms/augmentation.py ADDED
@@ -0,0 +1,380 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ import inspect
5
+ import numpy as np
6
+ import pprint
7
+ from typing import Any, List, Optional, Tuple, Union
8
+ from fvcore.transforms.transform import Transform, TransformList
9
+
10
+ """
11
+ See "Data Augmentation" tutorial for an overview of the system:
12
+ https://detectron2.readthedocs.io/tutorials/augmentation.html
13
+ """
14
+
15
+
16
+ __all__ = [
17
+ "Augmentation",
18
+ "AugmentationList",
19
+ "AugInput",
20
+ "TransformGen",
21
+ "apply_transform_gens",
22
+ "StandardAugInput",
23
+ "apply_augmentations",
24
+ ]
25
+
26
+
27
+ def _check_img_dtype(img):
28
+ assert isinstance(img, np.ndarray), "[Augmentation] Needs a numpy array, but got a {}!".format(
29
+ type(img)
30
+ )
31
+ assert not isinstance(img.dtype, np.integer) or (
32
+ img.dtype == np.uint8
33
+ ), "[Augmentation] Got image of type {}, use uint8 or floating points instead!".format(
34
+ img.dtype
35
+ )
36
+ assert img.ndim in [2, 3], img.ndim
37
+
38
+
39
+ def _get_aug_input_args(aug, aug_input) -> List[Any]:
40
+ """
41
+ Get the arguments to be passed to ``aug.get_transform`` from the input ``aug_input``.
42
+ """
43
+ if aug.input_args is None:
44
+ # Decide what attributes are needed automatically
45
+ prms = list(inspect.signature(aug.get_transform).parameters.items())
46
+ # The default behavior is: if there is one parameter, then it's "image"
47
+ # (work automatically for majority of use cases, and also avoid BC breaking),
48
+ # Otherwise, use the argument names.
49
+ if len(prms) == 1:
50
+ names = ("image",)
51
+ else:
52
+ names = []
53
+ for name, prm in prms:
54
+ if prm.kind in (
55
+ inspect.Parameter.VAR_POSITIONAL,
56
+ inspect.Parameter.VAR_KEYWORD,
57
+ ):
58
+ raise TypeError(
59
+ f""" \
60
+ The default implementation of `{type(aug)}.__call__` does not allow \
61
+ `{type(aug)}.get_transform` to use variable-length arguments (*args, **kwargs)! \
62
+ If arguments are unknown, reimplement `__call__` instead. \
63
+ """
64
+ )
65
+ names.append(name)
66
+ aug.input_args = tuple(names)
67
+
68
+ args = []
69
+ for f in aug.input_args:
70
+ try:
71
+ args.append(getattr(aug_input, f))
72
+ except AttributeError as e:
73
+ raise AttributeError(
74
+ f"{type(aug)}.get_transform needs input attribute '{f}', "
75
+ f"but it is not an attribute of {type(aug_input)}!"
76
+ ) from e
77
+ return args
78
+
79
+
80
+ class Augmentation:
81
+ """
82
+ Augmentation defines (often random) policies/strategies to generate :class:`Transform`
83
+ from data. It is often used for pre-processing of input data.
84
+
85
+ A "policy" that generates a :class:`Transform` may, in the most general case,
86
+ need arbitrary information from input data in order to determine what transforms
87
+ to apply. Therefore, each :class:`Augmentation` instance defines the arguments
88
+ needed by its :meth:`get_transform` method. When called with the positional arguments,
89
+ the :meth:`get_transform` method executes the policy.
90
+
91
+ Note that :class:`Augmentation` defines the policies to create a :class:`Transform`,
92
+ but not how to execute the actual transform operations to those data.
93
+ Its :meth:`__call__` method will use :meth:`AugInput.transform` to execute the transform.
94
+
95
+ The returned `Transform` object is meant to describe deterministic transformation, which means
96
+ it can be re-applied on associated data, e.g. the geometry of an image and its segmentation
97
+ masks need to be transformed together.
98
+ (If such re-application is not needed, then determinism is not a crucial requirement.)
99
+ """
100
+
101
+ input_args: Optional[Tuple[str]] = None
102
+ """
103
+ Stores the attribute names needed by :meth:`get_transform`, e.g. ``("image", "sem_seg")``.
104
+ By default, it is just a tuple of argument names in :meth:`self.get_transform`, which often only
105
+ contain "image". As long as the argument name convention is followed, there is no need for
106
+ users to touch this attribute.
107
+ """
108
+
109
+ def _init(self, params=None):
110
+ if params:
111
+ for k, v in params.items():
112
+ if k != "self" and not k.startswith("_"):
113
+ setattr(self, k, v)
114
+
115
+ def get_transform(self, *args) -> Transform:
116
+ """
117
+ Execute the policy based on input data, and decide what transform to apply to inputs.
118
+
119
+ Args:
120
+ args: Any fixed-length positional arguments. By default, the name of the arguments
121
+ should exist in the :class:`AugInput` to be used.
122
+
123
+ Returns:
124
+ Transform: Returns the deterministic transform to apply to the input.
125
+
126
+ Examples:
127
+ ::
128
+ class MyAug:
129
+ # if a policy needs to know both image and semantic segmentation
130
+ def get_transform(image, sem_seg) -> T.Transform:
131
+ pass
132
+ tfm: Transform = MyAug().get_transform(image, sem_seg)
133
+ new_image = tfm.apply_image(image)
134
+
135
+ Notes:
136
+ Users can freely use arbitrary new argument names in custom
137
+ :meth:`get_transform` method, as long as they are available in the
138
+ input data. In detectron2 we use the following convention:
139
+
140
+ * image: (H,W) or (H,W,C) ndarray of type uint8 in range [0, 255], or
141
+ floating point in range [0, 1] or [0, 255].
142
+ * boxes: (N,4) ndarray of float32. It represents the instance bounding boxes
143
+ of N instances. Each is in XYXY format in unit of absolute coordinates.
144
+ * sem_seg: (H,W) ndarray of type uint8. Each element is an integer label of pixel.
145
+
146
+ We do not specify convention for other types and do not include builtin
147
+ :class:`Augmentation` that uses other types in detectron2.
148
+ """
149
+ raise NotImplementedError
150
+
151
+ def __call__(self, aug_input) -> Transform:
152
+ """
153
+ Augment the given `aug_input` **in-place**, and return the transform that's used.
154
+
155
+ This method will be called to apply the augmentation. In most augmentation, it
156
+ is enough to use the default implementation, which calls :meth:`get_transform`
157
+ using the inputs. But a subclass can overwrite it to have more complicated logic.
158
+
159
+ Args:
160
+ aug_input (AugInput): an object that has attributes needed by this augmentation
161
+ (defined by ``self.get_transform``). Its ``transform`` method will be called
162
+ to in-place transform it.
163
+
164
+ Returns:
165
+ Transform: the transform that is applied on the input.
166
+ """
167
+ args = _get_aug_input_args(self, aug_input)
168
+ tfm = self.get_transform(*args)
169
+ assert isinstance(tfm, (Transform, TransformList)), (
170
+ f"{type(self)}.get_transform must return an instance of Transform! "
171
+ f"Got {type(tfm)} instead."
172
+ )
173
+ aug_input.transform(tfm)
174
+ return tfm
175
+
176
+ def _rand_range(self, low=1.0, high=None, size=None):
177
+ """
178
+ Uniform float random number between low and high.
179
+ """
180
+ if high is None:
181
+ low, high = 0, low
182
+ if size is None:
183
+ size = []
184
+ return np.random.uniform(low, high, size)
185
+
186
+ def __repr__(self):
187
+ """
188
+ Produce something like:
189
+ "MyAugmentation(field1={self.field1}, field2={self.field2})"
190
+ """
191
+ try:
192
+ sig = inspect.signature(self.__init__)
193
+ classname = type(self).__name__
194
+ argstr = []
195
+ for name, param in sig.parameters.items():
196
+ assert (
197
+ param.kind != param.VAR_POSITIONAL and param.kind != param.VAR_KEYWORD
198
+ ), "The default __repr__ doesn't support *args or **kwargs"
199
+ assert hasattr(self, name), (
200
+ "Attribute {} not found! "
201
+ "Default __repr__ only works if attributes match the constructor.".format(name)
202
+ )
203
+ attr = getattr(self, name)
204
+ default = param.default
205
+ if default is attr:
206
+ continue
207
+ attr_str = pprint.pformat(attr)
208
+ if "\n" in attr_str:
209
+ # don't show it if pformat decides to use >1 lines
210
+ attr_str = "..."
211
+ argstr.append("{}={}".format(name, attr_str))
212
+ return "{}({})".format(classname, ", ".join(argstr))
213
+ except AssertionError:
214
+ return super().__repr__()
215
+
216
+ __str__ = __repr__
217
+
218
+
219
+ class _TransformToAug(Augmentation):
220
+ def __init__(self, tfm: Transform):
221
+ self.tfm = tfm
222
+
223
+ def get_transform(self, *args):
224
+ return self.tfm
225
+
226
+ def __repr__(self):
227
+ return repr(self.tfm)
228
+
229
+ __str__ = __repr__
230
+
231
+
232
+ def _transform_to_aug(tfm_or_aug):
233
+ """
234
+ Wrap Transform into Augmentation.
235
+ Private, used internally to implement augmentations.
236
+ """
237
+ assert isinstance(tfm_or_aug, (Transform, Augmentation)), tfm_or_aug
238
+ if isinstance(tfm_or_aug, Augmentation):
239
+ return tfm_or_aug
240
+ else:
241
+ return _TransformToAug(tfm_or_aug)
242
+
243
+
244
+ class AugmentationList(Augmentation):
245
+ """
246
+ Apply a sequence of augmentations.
247
+
248
+ It has ``__call__`` method to apply the augmentations.
249
+
250
+ Note that :meth:`get_transform` method is impossible (will throw error if called)
251
+ for :class:`AugmentationList`, because in order to apply a sequence of augmentations,
252
+ the kth augmentation must be applied first, to provide inputs needed by the (k+1)th
253
+ augmentation.
254
+ """
255
+
256
+ def __init__(self, augs):
257
+ """
258
+ Args:
259
+ augs (list[Augmentation or Transform]):
260
+ """
261
+ super().__init__()
262
+ self.augs = [_transform_to_aug(x) for x in augs]
263
+
264
+ def __call__(self, aug_input) -> TransformList:
265
+ tfms = []
266
+ for x in self.augs:
267
+ tfm = x(aug_input)
268
+ tfms.append(tfm)
269
+ return TransformList(tfms)
270
+
271
+ def __repr__(self):
272
+ msgs = [str(x) for x in self.augs]
273
+ return "AugmentationList[{}]".format(", ".join(msgs))
274
+
275
+ __str__ = __repr__
276
+
277
+
278
+ class AugInput:
279
+ """
280
+ Input that can be used with :meth:`Augmentation.__call__`.
281
+ This is a standard implementation for the majority of use cases.
282
+ This class provides the standard attributes **"image", "boxes", "sem_seg"**
283
+ defined in :meth:`__init__` and they may be needed by different augmentations.
284
+ Most augmentation policies do not need attributes beyond these three.
285
+
286
+ After applying augmentations to these attributes (using :meth:`AugInput.transform`),
287
+ the returned transforms can then be used to transform other data structures that users have.
288
+
289
+ Examples:
290
+ ::
291
+ input = AugInput(image, boxes=boxes)
292
+ tfms = augmentation(input)
293
+ transformed_image = input.image
294
+ transformed_boxes = input.boxes
295
+ transformed_other_data = tfms.apply_other(other_data)
296
+
297
+ An extended project that works with new data types may implement augmentation policies
298
+ that need other inputs. An algorithm may need to transform inputs in a way different
299
+ from the standard approach defined in this class. In those rare situations, users can
300
+ implement a class similar to this class, that satisfies the following conditions:
301
+
302
+ * The input must provide access to these data in the form of attribute access
303
+ (``getattr``). For example, if an :class:`Augmentation` to be applied needs "image"
304
+ and "sem_seg" arguments, its input must have the attribute "image" and "sem_seg".
305
+ * The input must have a ``transform(tfm: Transform) -> None`` method which
306
+ in-place transforms all its attributes.
307
+ """
308
+
309
+ # TODO maybe should support more builtin data types here
310
+ def __init__(
311
+ self,
312
+ image: np.ndarray,
313
+ *,
314
+ boxes: Optional[np.ndarray] = None,
315
+ sem_seg: Optional[np.ndarray] = None,
316
+ ):
317
+ """
318
+ Args:
319
+ image (ndarray): (H,W) or (H,W,C) ndarray of type uint8 in range [0, 255], or
320
+ floating point in range [0, 1] or [0, 255]. The meaning of C is up
321
+ to users.
322
+ boxes (ndarray or None): Nx4 float32 boxes in XYXY_ABS mode
323
+ sem_seg (ndarray or None): HxW uint8 semantic segmentation mask. Each element
324
+ is an integer label of pixel.
325
+ """
326
+ _check_img_dtype(image)
327
+ self.image = image
328
+ self.boxes = boxes
329
+ self.sem_seg = sem_seg
330
+
331
+ def transform(self, tfm: Transform) -> None:
332
+ """
333
+ In-place transform all attributes of this class.
334
+
335
+ By "in-place", it means after calling this method, accessing an attribute such
336
+ as ``self.image`` will return transformed data.
337
+ """
338
+ self.image = tfm.apply_image(self.image)
339
+ if self.boxes is not None:
340
+ self.boxes = tfm.apply_box(self.boxes)
341
+ if self.sem_seg is not None:
342
+ self.sem_seg = tfm.apply_segmentation(self.sem_seg)
343
+
344
+ def apply_augmentations(
345
+ self, augmentations: List[Union[Augmentation, Transform]]
346
+ ) -> TransformList:
347
+ """
348
+ Equivalent of ``AugmentationList(augmentations)(self)``
349
+ """
350
+ return AugmentationList(augmentations)(self)
351
+
352
+
353
+ def apply_augmentations(augmentations: List[Union[Transform, Augmentation]], inputs):
354
+ """
355
+ Use ``T.AugmentationList(augmentations)(inputs)`` instead.
356
+ """
357
+ if isinstance(inputs, np.ndarray):
358
+ # handle the common case of image-only Augmentation, also for backward compatibility
359
+ image_only = True
360
+ inputs = AugInput(inputs)
361
+ else:
362
+ image_only = False
363
+ tfms = inputs.apply_augmentations(augmentations)
364
+ return inputs.image if image_only else inputs, tfms
365
+
366
+
367
+ apply_transform_gens = apply_augmentations
368
+ """
369
+ Alias for backward-compatibility.
370
+ """
371
+
372
+ TransformGen = Augmentation
373
+ """
374
+ Alias for Augmentation, since it is something that generates :class:`Transform`s
375
+ """
376
+
377
+ StandardAugInput = AugInput
378
+ """
379
+ Alias for compatibility. It's not worth the complexity to have two classes.
380
+ """
CatVTON/detectron2/data/transforms/augmentation_impl.py ADDED
@@ -0,0 +1,736 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ """
4
+ Implement many useful :class:`Augmentation`.
5
+ """
6
+ import numpy as np
7
+ import sys
8
+ from numpy import random
9
+ from typing import Tuple
10
+ import torch
11
+ from fvcore.transforms.transform import (
12
+ BlendTransform,
13
+ CropTransform,
14
+ HFlipTransform,
15
+ NoOpTransform,
16
+ PadTransform,
17
+ Transform,
18
+ TransformList,
19
+ VFlipTransform,
20
+ )
21
+ from PIL import Image
22
+
23
+ from detectron2.structures import Boxes, pairwise_iou
24
+
25
+ from .augmentation import Augmentation, _transform_to_aug
26
+ from .transform import ExtentTransform, ResizeTransform, RotationTransform
27
+
28
+ __all__ = [
29
+ "FixedSizeCrop",
30
+ "RandomApply",
31
+ "RandomBrightness",
32
+ "RandomContrast",
33
+ "RandomCrop",
34
+ "RandomExtent",
35
+ "RandomFlip",
36
+ "RandomSaturation",
37
+ "RandomLighting",
38
+ "RandomRotation",
39
+ "Resize",
40
+ "ResizeScale",
41
+ "ResizeShortestEdge",
42
+ "RandomCrop_CategoryAreaConstraint",
43
+ "RandomResize",
44
+ "MinIoURandomCrop",
45
+ ]
46
+
47
+
48
+ class RandomApply(Augmentation):
49
+ """
50
+ Randomly apply an augmentation with a given probability.
51
+ """
52
+
53
+ def __init__(self, tfm_or_aug, prob=0.5):
54
+ """
55
+ Args:
56
+ tfm_or_aug (Transform, Augmentation): the transform or augmentation
57
+ to be applied. It can either be a `Transform` or `Augmentation`
58
+ instance.
59
+ prob (float): probability between 0.0 and 1.0 that
60
+ the wrapper transformation is applied
61
+ """
62
+ super().__init__()
63
+ self.aug = _transform_to_aug(tfm_or_aug)
64
+ assert 0.0 <= prob <= 1.0, f"Probability must be between 0.0 and 1.0 (given: {prob})"
65
+ self.prob = prob
66
+
67
+ def get_transform(self, *args):
68
+ do = self._rand_range() < self.prob
69
+ if do:
70
+ return self.aug.get_transform(*args)
71
+ else:
72
+ return NoOpTransform()
73
+
74
+ def __call__(self, aug_input):
75
+ do = self._rand_range() < self.prob
76
+ if do:
77
+ return self.aug(aug_input)
78
+ else:
79
+ return NoOpTransform()
80
+
81
+
82
+ class RandomFlip(Augmentation):
83
+ """
84
+ Flip the image horizontally or vertically with the given probability.
85
+ """
86
+
87
+ def __init__(self, prob=0.5, *, horizontal=True, vertical=False):
88
+ """
89
+ Args:
90
+ prob (float): probability of flip.
91
+ horizontal (boolean): whether to apply horizontal flipping
92
+ vertical (boolean): whether to apply vertical flipping
93
+ """
94
+ super().__init__()
95
+
96
+ if horizontal and vertical:
97
+ raise ValueError("Cannot do both horiz and vert. Please use two Flip instead.")
98
+ if not horizontal and not vertical:
99
+ raise ValueError("At least one of horiz or vert has to be True!")
100
+ self._init(locals())
101
+
102
+ def get_transform(self, image):
103
+ h, w = image.shape[:2]
104
+ do = self._rand_range() < self.prob
105
+ if do:
106
+ if self.horizontal:
107
+ return HFlipTransform(w)
108
+ elif self.vertical:
109
+ return VFlipTransform(h)
110
+ else:
111
+ return NoOpTransform()
112
+
113
+
114
+ class Resize(Augmentation):
115
+ """Resize image to a fixed target size"""
116
+
117
+ def __init__(self, shape, interp=Image.BILINEAR):
118
+ """
119
+ Args:
120
+ shape: (h, w) tuple or a int
121
+ interp: PIL interpolation method
122
+ """
123
+ if isinstance(shape, int):
124
+ shape = (shape, shape)
125
+ shape = tuple(shape)
126
+ self._init(locals())
127
+
128
+ def get_transform(self, image):
129
+ return ResizeTransform(
130
+ image.shape[0], image.shape[1], self.shape[0], self.shape[1], self.interp
131
+ )
132
+
133
+
134
+ class ResizeShortestEdge(Augmentation):
135
+ """
136
+ Resize the image while keeping the aspect ratio unchanged.
137
+ It attempts to scale the shorter edge to the given `short_edge_length`,
138
+ as long as the longer edge does not exceed `max_size`.
139
+ If `max_size` is reached, then downscale so that the longer edge does not exceed max_size.
140
+ """
141
+
142
+ @torch.jit.unused
143
+ def __init__(
144
+ self, short_edge_length, max_size=sys.maxsize, sample_style="range", interp=Image.BILINEAR
145
+ ):
146
+ """
147
+ Args:
148
+ short_edge_length (list[int]): If ``sample_style=="range"``,
149
+ a [min, max] interval from which to sample the shortest edge length.
150
+ If ``sample_style=="choice"``, a list of shortest edge lengths to sample from.
151
+ max_size (int): maximum allowed longest edge length.
152
+ sample_style (str): either "range" or "choice".
153
+ """
154
+ super().__init__()
155
+ assert sample_style in ["range", "choice"], sample_style
156
+
157
+ self.is_range = sample_style == "range"
158
+ if isinstance(short_edge_length, int):
159
+ short_edge_length = (short_edge_length, short_edge_length)
160
+ if self.is_range:
161
+ assert len(short_edge_length) == 2, (
162
+ "short_edge_length must be two values using 'range' sample style."
163
+ f" Got {short_edge_length}!"
164
+ )
165
+ self._init(locals())
166
+
167
+ @torch.jit.unused
168
+ def get_transform(self, image):
169
+ h, w = image.shape[:2]
170
+ if self.is_range:
171
+ size = np.random.randint(self.short_edge_length[0], self.short_edge_length[1] + 1)
172
+ else:
173
+ size = np.random.choice(self.short_edge_length)
174
+ if size == 0:
175
+ return NoOpTransform()
176
+
177
+ newh, neww = ResizeShortestEdge.get_output_shape(h, w, size, self.max_size)
178
+ return ResizeTransform(h, w, newh, neww, self.interp)
179
+
180
+ @staticmethod
181
+ def get_output_shape(
182
+ oldh: int, oldw: int, short_edge_length: int, max_size: int
183
+ ) -> Tuple[int, int]:
184
+ """
185
+ Compute the output size given input size and target short edge length.
186
+ """
187
+ h, w = oldh, oldw
188
+ size = short_edge_length * 1.0
189
+ scale = size / min(h, w)
190
+ if h < w:
191
+ newh, neww = size, scale * w
192
+ else:
193
+ newh, neww = scale * h, size
194
+ if max(newh, neww) > max_size:
195
+ scale = max_size * 1.0 / max(newh, neww)
196
+ newh = newh * scale
197
+ neww = neww * scale
198
+ neww = int(neww + 0.5)
199
+ newh = int(newh + 0.5)
200
+ return (newh, neww)
201
+
202
+
203
+ class ResizeScale(Augmentation):
204
+ """
205
+ Takes target size as input and randomly scales the given target size between `min_scale`
206
+ and `max_scale`. It then scales the input image such that it fits inside the scaled target
207
+ box, keeping the aspect ratio constant.
208
+ This implements the resize part of the Google's 'resize_and_crop' data augmentation:
209
+ https://github.com/tensorflow/tpu/blob/master/models/official/detection/utils/input_utils.py#L127
210
+ """
211
+
212
+ def __init__(
213
+ self,
214
+ min_scale: float,
215
+ max_scale: float,
216
+ target_height: int,
217
+ target_width: int,
218
+ interp: int = Image.BILINEAR,
219
+ ):
220
+ """
221
+ Args:
222
+ min_scale: minimum image scale range.
223
+ max_scale: maximum image scale range.
224
+ target_height: target image height.
225
+ target_width: target image width.
226
+ interp: image interpolation method.
227
+ """
228
+ super().__init__()
229
+ self._init(locals())
230
+
231
+ def _get_resize(self, image: np.ndarray, scale: float) -> Transform:
232
+ input_size = image.shape[:2]
233
+
234
+ # Compute new target size given a scale.
235
+ target_size = (self.target_height, self.target_width)
236
+ target_scale_size = np.multiply(target_size, scale)
237
+
238
+ # Compute actual rescaling applied to input image and output size.
239
+ output_scale = np.minimum(
240
+ target_scale_size[0] / input_size[0], target_scale_size[1] / input_size[1]
241
+ )
242
+ output_size = np.round(np.multiply(input_size, output_scale)).astype(int)
243
+
244
+ return ResizeTransform(
245
+ input_size[0], input_size[1], int(output_size[0]), int(output_size[1]), self.interp
246
+ )
247
+
248
+ def get_transform(self, image: np.ndarray) -> Transform:
249
+ random_scale = np.random.uniform(self.min_scale, self.max_scale)
250
+ return self._get_resize(image, random_scale)
251
+
252
+
253
+ class RandomRotation(Augmentation):
254
+ """
255
+ This method returns a copy of this image, rotated the given
256
+ number of degrees counter clockwise around the given center.
257
+ """
258
+
259
+ def __init__(self, angle, expand=True, center=None, sample_style="range", interp=None):
260
+ """
261
+ Args:
262
+ angle (list[float]): If ``sample_style=="range"``,
263
+ a [min, max] interval from which to sample the angle (in degrees).
264
+ If ``sample_style=="choice"``, a list of angles to sample from
265
+ expand (bool): choose if the image should be resized to fit the whole
266
+ rotated image (default), or simply cropped
267
+ center (list[[float, float]]): If ``sample_style=="range"``,
268
+ a [[minx, miny], [maxx, maxy]] relative interval from which to sample the center,
269
+ [0, 0] being the top left of the image and [1, 1] the bottom right.
270
+ If ``sample_style=="choice"``, a list of centers to sample from
271
+ Default: None, which means that the center of rotation is the center of the image
272
+ center has no effect if expand=True because it only affects shifting
273
+ """
274
+ super().__init__()
275
+ assert sample_style in ["range", "choice"], sample_style
276
+ self.is_range = sample_style == "range"
277
+ if isinstance(angle, (float, int)):
278
+ angle = (angle, angle)
279
+ if center is not None and isinstance(center[0], (float, int)):
280
+ center = (center, center)
281
+ self._init(locals())
282
+
283
+ def get_transform(self, image):
284
+ h, w = image.shape[:2]
285
+ center = None
286
+ if self.is_range:
287
+ angle = np.random.uniform(self.angle[0], self.angle[1])
288
+ if self.center is not None:
289
+ center = (
290
+ np.random.uniform(self.center[0][0], self.center[1][0]),
291
+ np.random.uniform(self.center[0][1], self.center[1][1]),
292
+ )
293
+ else:
294
+ angle = np.random.choice(self.angle)
295
+ if self.center is not None:
296
+ center = np.random.choice(self.center)
297
+
298
+ if center is not None:
299
+ center = (w * center[0], h * center[1]) # Convert to absolute coordinates
300
+
301
+ if angle % 360 == 0:
302
+ return NoOpTransform()
303
+
304
+ return RotationTransform(h, w, angle, expand=self.expand, center=center, interp=self.interp)
305
+
306
+
307
+ class FixedSizeCrop(Augmentation):
308
+ """
309
+ If `crop_size` is smaller than the input image size, then it uses a random crop of
310
+ the crop size. If `crop_size` is larger than the input image size, then it pads
311
+ the right and the bottom of the image to the crop size if `pad` is True, otherwise
312
+ it returns the smaller image.
313
+ """
314
+
315
+ def __init__(
316
+ self,
317
+ crop_size: Tuple[int],
318
+ pad: bool = True,
319
+ pad_value: float = 128.0,
320
+ seg_pad_value: int = 255,
321
+ ):
322
+ """
323
+ Args:
324
+ crop_size: target image (height, width).
325
+ pad: if True, will pad images smaller than `crop_size` up to `crop_size`
326
+ pad_value: the padding value to the image.
327
+ seg_pad_value: the padding value to the segmentation mask.
328
+ """
329
+ super().__init__()
330
+ self._init(locals())
331
+
332
+ def _get_crop(self, image: np.ndarray) -> Transform:
333
+ # Compute the image scale and scaled size.
334
+ input_size = image.shape[:2]
335
+ output_size = self.crop_size
336
+
337
+ # Add random crop if the image is scaled up.
338
+ max_offset = np.subtract(input_size, output_size)
339
+ max_offset = np.maximum(max_offset, 0)
340
+ offset = np.multiply(max_offset, np.random.uniform(0.0, 1.0))
341
+ offset = np.round(offset).astype(int)
342
+ return CropTransform(
343
+ offset[1], offset[0], output_size[1], output_size[0], input_size[1], input_size[0]
344
+ )
345
+
346
+ def _get_pad(self, image: np.ndarray) -> Transform:
347
+ # Compute the image scale and scaled size.
348
+ input_size = image.shape[:2]
349
+ output_size = self.crop_size
350
+
351
+ # Add padding if the image is scaled down.
352
+ pad_size = np.subtract(output_size, input_size)
353
+ pad_size = np.maximum(pad_size, 0)
354
+ original_size = np.minimum(input_size, output_size)
355
+ return PadTransform(
356
+ 0,
357
+ 0,
358
+ pad_size[1],
359
+ pad_size[0],
360
+ original_size[1],
361
+ original_size[0],
362
+ self.pad_value,
363
+ self.seg_pad_value,
364
+ )
365
+
366
+ def get_transform(self, image: np.ndarray) -> TransformList:
367
+ transforms = [self._get_crop(image)]
368
+ if self.pad:
369
+ transforms.append(self._get_pad(image))
370
+ return TransformList(transforms)
371
+
372
+
373
+ class RandomCrop(Augmentation):
374
+ """
375
+ Randomly crop a rectangle region out of an image.
376
+ """
377
+
378
+ def __init__(self, crop_type: str, crop_size):
379
+ """
380
+ Args:
381
+ crop_type (str): one of "relative_range", "relative", "absolute", "absolute_range".
382
+ crop_size (tuple[float, float]): two floats, explained below.
383
+
384
+ - "relative": crop a (H * crop_size[0], W * crop_size[1]) region from an input image of
385
+ size (H, W). crop size should be in (0, 1]
386
+ - "relative_range": uniformly sample two values from [crop_size[0], 1]
387
+ and [crop_size[1]], 1], and use them as in "relative" crop type.
388
+ - "absolute" crop a (crop_size[0], crop_size[1]) region from input image.
389
+ crop_size must be smaller than the input image size.
390
+ - "absolute_range", for an input of size (H, W), uniformly sample H_crop in
391
+ [crop_size[0], min(H, crop_size[1])] and W_crop in [crop_size[0], min(W, crop_size[1])].
392
+ Then crop a region (H_crop, W_crop).
393
+ """
394
+ # TODO style of relative_range and absolute_range are not consistent:
395
+ # one takes (h, w) but another takes (min, max)
396
+ super().__init__()
397
+ assert crop_type in ["relative_range", "relative", "absolute", "absolute_range"]
398
+ self._init(locals())
399
+
400
+ def get_transform(self, image):
401
+ h, w = image.shape[:2]
402
+ croph, cropw = self.get_crop_size((h, w))
403
+ assert h >= croph and w >= cropw, "Shape computation in {} has bugs.".format(self)
404
+ h0 = np.random.randint(h - croph + 1)
405
+ w0 = np.random.randint(w - cropw + 1)
406
+ return CropTransform(w0, h0, cropw, croph)
407
+
408
+ def get_crop_size(self, image_size):
409
+ """
410
+ Args:
411
+ image_size (tuple): height, width
412
+
413
+ Returns:
414
+ crop_size (tuple): height, width in absolute pixels
415
+ """
416
+ h, w = image_size
417
+ if self.crop_type == "relative":
418
+ ch, cw = self.crop_size
419
+ return int(h * ch + 0.5), int(w * cw + 0.5)
420
+ elif self.crop_type == "relative_range":
421
+ crop_size = np.asarray(self.crop_size, dtype=np.float32)
422
+ ch, cw = crop_size + np.random.rand(2) * (1 - crop_size)
423
+ return int(h * ch + 0.5), int(w * cw + 0.5)
424
+ elif self.crop_type == "absolute":
425
+ return (min(self.crop_size[0], h), min(self.crop_size[1], w))
426
+ elif self.crop_type == "absolute_range":
427
+ assert self.crop_size[0] <= self.crop_size[1]
428
+ ch = np.random.randint(min(h, self.crop_size[0]), min(h, self.crop_size[1]) + 1)
429
+ cw = np.random.randint(min(w, self.crop_size[0]), min(w, self.crop_size[1]) + 1)
430
+ return ch, cw
431
+ else:
432
+ raise NotImplementedError("Unknown crop type {}".format(self.crop_type))
433
+
434
+
435
+ class RandomCrop_CategoryAreaConstraint(Augmentation):
436
+ """
437
+ Similar to :class:`RandomCrop`, but find a cropping window such that no single category
438
+ occupies a ratio of more than `single_category_max_area` in semantic segmentation ground
439
+ truth, which can cause unstability in training. The function attempts to find such a valid
440
+ cropping window for at most 10 times.
441
+ """
442
+
443
+ def __init__(
444
+ self,
445
+ crop_type: str,
446
+ crop_size,
447
+ single_category_max_area: float = 1.0,
448
+ ignored_category: int = None,
449
+ ):
450
+ """
451
+ Args:
452
+ crop_type, crop_size: same as in :class:`RandomCrop`
453
+ single_category_max_area: the maximum allowed area ratio of a
454
+ category. Set to 1.0 to disable
455
+ ignored_category: allow this category in the semantic segmentation
456
+ ground truth to exceed the area ratio. Usually set to the category
457
+ that's ignored in training.
458
+ """
459
+ self.crop_aug = RandomCrop(crop_type, crop_size)
460
+ self._init(locals())
461
+
462
+ def get_transform(self, image, sem_seg):
463
+ if self.single_category_max_area >= 1.0:
464
+ return self.crop_aug.get_transform(image)
465
+ else:
466
+ h, w = sem_seg.shape
467
+ for _ in range(10):
468
+ crop_size = self.crop_aug.get_crop_size((h, w))
469
+ y0 = np.random.randint(h - crop_size[0] + 1)
470
+ x0 = np.random.randint(w - crop_size[1] + 1)
471
+ sem_seg_temp = sem_seg[y0 : y0 + crop_size[0], x0 : x0 + crop_size[1]]
472
+ labels, cnt = np.unique(sem_seg_temp, return_counts=True)
473
+ if self.ignored_category is not None:
474
+ cnt = cnt[labels != self.ignored_category]
475
+ if len(cnt) > 1 and np.max(cnt) < np.sum(cnt) * self.single_category_max_area:
476
+ break
477
+ crop_tfm = CropTransform(x0, y0, crop_size[1], crop_size[0])
478
+ return crop_tfm
479
+
480
+
481
+ class RandomExtent(Augmentation):
482
+ """
483
+ Outputs an image by cropping a random "subrect" of the source image.
484
+
485
+ The subrect can be parameterized to include pixels outside the source image,
486
+ in which case they will be set to zeros (i.e. black). The size of the output
487
+ image will vary with the size of the random subrect.
488
+ """
489
+
490
+ def __init__(self, scale_range, shift_range):
491
+ """
492
+ Args:
493
+ output_size (h, w): Dimensions of output image
494
+ scale_range (l, h): Range of input-to-output size scaling factor
495
+ shift_range (x, y): Range of shifts of the cropped subrect. The rect
496
+ is shifted by [w / 2 * Uniform(-x, x), h / 2 * Uniform(-y, y)],
497
+ where (w, h) is the (width, height) of the input image. Set each
498
+ component to zero to crop at the image's center.
499
+ """
500
+ super().__init__()
501
+ self._init(locals())
502
+
503
+ def get_transform(self, image):
504
+ img_h, img_w = image.shape[:2]
505
+
506
+ # Initialize src_rect to fit the input image.
507
+ src_rect = np.array([-0.5 * img_w, -0.5 * img_h, 0.5 * img_w, 0.5 * img_h])
508
+
509
+ # Apply a random scaling to the src_rect.
510
+ src_rect *= np.random.uniform(self.scale_range[0], self.scale_range[1])
511
+
512
+ # Apply a random shift to the coordinates origin.
513
+ src_rect[0::2] += self.shift_range[0] * img_w * (np.random.rand() - 0.5)
514
+ src_rect[1::2] += self.shift_range[1] * img_h * (np.random.rand() - 0.5)
515
+
516
+ # Map src_rect coordinates into image coordinates (center at corner).
517
+ src_rect[0::2] += 0.5 * img_w
518
+ src_rect[1::2] += 0.5 * img_h
519
+
520
+ return ExtentTransform(
521
+ src_rect=(src_rect[0], src_rect[1], src_rect[2], src_rect[3]),
522
+ output_size=(int(src_rect[3] - src_rect[1]), int(src_rect[2] - src_rect[0])),
523
+ )
524
+
525
+
526
+ class RandomContrast(Augmentation):
527
+ """
528
+ Randomly transforms image contrast.
529
+
530
+ Contrast intensity is uniformly sampled in (intensity_min, intensity_max).
531
+ - intensity < 1 will reduce contrast
532
+ - intensity = 1 will preserve the input image
533
+ - intensity > 1 will increase contrast
534
+
535
+ See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
536
+ """
537
+
538
+ def __init__(self, intensity_min, intensity_max):
539
+ """
540
+ Args:
541
+ intensity_min (float): Minimum augmentation
542
+ intensity_max (float): Maximum augmentation
543
+ """
544
+ super().__init__()
545
+ self._init(locals())
546
+
547
+ def get_transform(self, image):
548
+ w = np.random.uniform(self.intensity_min, self.intensity_max)
549
+ return BlendTransform(src_image=image.mean(), src_weight=1 - w, dst_weight=w)
550
+
551
+
552
+ class RandomBrightness(Augmentation):
553
+ """
554
+ Randomly transforms image brightness.
555
+
556
+ Brightness intensity is uniformly sampled in (intensity_min, intensity_max).
557
+ - intensity < 1 will reduce brightness
558
+ - intensity = 1 will preserve the input image
559
+ - intensity > 1 will increase brightness
560
+
561
+ See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
562
+ """
563
+
564
+ def __init__(self, intensity_min, intensity_max):
565
+ """
566
+ Args:
567
+ intensity_min (float): Minimum augmentation
568
+ intensity_max (float): Maximum augmentation
569
+ """
570
+ super().__init__()
571
+ self._init(locals())
572
+
573
+ def get_transform(self, image):
574
+ w = np.random.uniform(self.intensity_min, self.intensity_max)
575
+ return BlendTransform(src_image=0, src_weight=1 - w, dst_weight=w)
576
+
577
+
578
+ class RandomSaturation(Augmentation):
579
+ """
580
+ Randomly transforms saturation of an RGB image.
581
+ Input images are assumed to have 'RGB' channel order.
582
+
583
+ Saturation intensity is uniformly sampled in (intensity_min, intensity_max).
584
+ - intensity < 1 will reduce saturation (make the image more grayscale)
585
+ - intensity = 1 will preserve the input image
586
+ - intensity > 1 will increase saturation
587
+
588
+ See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
589
+ """
590
+
591
+ def __init__(self, intensity_min, intensity_max):
592
+ """
593
+ Args:
594
+ intensity_min (float): Minimum augmentation (1 preserves input).
595
+ intensity_max (float): Maximum augmentation (1 preserves input).
596
+ """
597
+ super().__init__()
598
+ self._init(locals())
599
+
600
+ def get_transform(self, image):
601
+ assert image.shape[-1] == 3, "RandomSaturation only works on RGB images"
602
+ w = np.random.uniform(self.intensity_min, self.intensity_max)
603
+ grayscale = image.dot([0.299, 0.587, 0.114])[:, :, np.newaxis]
604
+ return BlendTransform(src_image=grayscale, src_weight=1 - w, dst_weight=w)
605
+
606
+
607
+ class RandomLighting(Augmentation):
608
+ """
609
+ The "lighting" augmentation described in AlexNet, using fixed PCA over ImageNet.
610
+ Input images are assumed to have 'RGB' channel order.
611
+
612
+ The degree of color jittering is randomly sampled via a normal distribution,
613
+ with standard deviation given by the scale parameter.
614
+ """
615
+
616
+ def __init__(self, scale):
617
+ """
618
+ Args:
619
+ scale (float): Standard deviation of principal component weighting.
620
+ """
621
+ super().__init__()
622
+ self._init(locals())
623
+ self.eigen_vecs = np.array(
624
+ [[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203]]
625
+ )
626
+ self.eigen_vals = np.array([0.2175, 0.0188, 0.0045])
627
+
628
+ def get_transform(self, image):
629
+ assert image.shape[-1] == 3, "RandomLighting only works on RGB images"
630
+ weights = np.random.normal(scale=self.scale, size=3)
631
+ return BlendTransform(
632
+ src_image=self.eigen_vecs.dot(weights * self.eigen_vals), src_weight=1.0, dst_weight=1.0
633
+ )
634
+
635
+
636
+ class RandomResize(Augmentation):
637
+ """Randomly resize image to a target size in shape_list"""
638
+
639
+ def __init__(self, shape_list, interp=Image.BILINEAR):
640
+ """
641
+ Args:
642
+ shape_list: a list of shapes in (h, w)
643
+ interp: PIL interpolation method
644
+ """
645
+ self.shape_list = shape_list
646
+ self._init(locals())
647
+
648
+ def get_transform(self, image):
649
+ shape_idx = np.random.randint(low=0, high=len(self.shape_list))
650
+ h, w = self.shape_list[shape_idx]
651
+ return ResizeTransform(image.shape[0], image.shape[1], h, w, self.interp)
652
+
653
+
654
+ class MinIoURandomCrop(Augmentation):
655
+ """Random crop the image & bboxes, the cropped patches have minimum IoU
656
+ requirement with the original image & bboxes, where the IoU threshold is randomly
657
+ selected from min_ious.
658
+
659
+ Args:
660
+ min_ious (tuple): minimum IoU threshold for all intersections with
661
+ bounding boxes
662
+ min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w,
663
+ where a >= min_crop_size)
664
+ mode_trials: number of trials for sampling min_ious threshold
665
+ crop_trials: number of trials for sampling crop_size after cropping
666
+ """
667
+
668
+ def __init__(
669
+ self,
670
+ min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
671
+ min_crop_size=0.3,
672
+ mode_trials=1000,
673
+ crop_trials=50,
674
+ ):
675
+ self.min_ious = min_ious
676
+ self.sample_mode = (1, *min_ious, 0)
677
+ self.min_crop_size = min_crop_size
678
+ self.mode_trials = mode_trials
679
+ self.crop_trials = crop_trials
680
+
681
+ def get_transform(self, image, boxes):
682
+ """Call function to crop images and bounding boxes with minimum IoU
683
+ constraint.
684
+
685
+ Args:
686
+ boxes: ground truth boxes in (x1, y1, x2, y2) format
687
+ """
688
+ if boxes is None:
689
+ return NoOpTransform()
690
+ h, w, c = image.shape
691
+ for _ in range(self.mode_trials):
692
+ mode = random.choice(self.sample_mode)
693
+ self.mode = mode
694
+ if mode == 1:
695
+ return NoOpTransform()
696
+
697
+ min_iou = mode
698
+ for _ in range(self.crop_trials):
699
+ new_w = random.uniform(self.min_crop_size * w, w)
700
+ new_h = random.uniform(self.min_crop_size * h, h)
701
+
702
+ # h / w in [0.5, 2]
703
+ if new_h / new_w < 0.5 or new_h / new_w > 2:
704
+ continue
705
+
706
+ left = random.uniform(w - new_w)
707
+ top = random.uniform(h - new_h)
708
+
709
+ patch = np.array((int(left), int(top), int(left + new_w), int(top + new_h)))
710
+ # Line or point crop is not allowed
711
+ if patch[2] == patch[0] or patch[3] == patch[1]:
712
+ continue
713
+ overlaps = pairwise_iou(
714
+ Boxes(patch.reshape(-1, 4)), Boxes(boxes.reshape(-1, 4))
715
+ ).reshape(-1)
716
+ if len(overlaps) > 0 and overlaps.min() < min_iou:
717
+ continue
718
+
719
+ # centers of boxes should be inside the cropped image
720
+ # only adjust boxes and instance masks when the gt is not empty
721
+ if len(overlaps) > 0:
722
+ # adjust boxes
723
+ def is_center_of_bboxes_in_patch(boxes, patch):
724
+ center = (boxes[:, :2] + boxes[:, 2:]) / 2
725
+ mask = (
726
+ (center[:, 0] > patch[0])
727
+ * (center[:, 1] > patch[1])
728
+ * (center[:, 0] < patch[2])
729
+ * (center[:, 1] < patch[3])
730
+ )
731
+ return mask
732
+
733
+ mask = is_center_of_bboxes_in_patch(boxes, patch)
734
+ if not mask.any():
735
+ continue
736
+ return CropTransform(int(left), int(top), int(new_w), int(new_h))
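
For context on how these augmentations are typically wired together, below is a minimal sketch (not part of this file), assuming this vendored detectron2 package is importable as `detectron2`; the image sizes and parameter values are illustrative only:

```python
import numpy as np
from detectron2.data import transforms as T

# Illustrative inputs: an HWC uint8 image and a per-pixel semantic-segmentation label map.
image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
sem_seg = np.zeros((480, 640), dtype=np.uint8)

augs = T.AugmentationList(
    [
        T.RandomBrightness(0.9, 1.1),
        T.RandomContrast(0.9, 1.1),
        # keep any single category below 90% of the cropped area
        T.RandomCrop_CategoryAreaConstraint("relative", (0.8, 0.8), single_category_max_area=0.9),
    ]
)

aug_input = T.AugInput(image, sem_seg=sem_seg)
transforms = augs(aug_input)  # applies the augmentations in order and mutates aug_input in place
new_image, new_sem_seg = aug_input.image, aug_input.sem_seg
```

`AugInput` carries the image together with `sem_seg`/`boxes`, which is what lets augmentations such as `RandomCrop_CategoryAreaConstraint.get_transform(image, sem_seg)` and `MinIoURandomCrop.get_transform(image, boxes)` receive the extra arguments they declare.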
CatVTON/detectron2/data/transforms/transform.py ADDED
@@ -0,0 +1,351 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ """
5
+ See "Data Augmentation" tutorial for an overview of the system:
6
+ https://detectron2.readthedocs.io/tutorials/augmentation.html
7
+ """
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from fvcore.transforms.transform import (
13
+ CropTransform,
14
+ HFlipTransform,
15
+ NoOpTransform,
16
+ Transform,
17
+ TransformList,
18
+ )
19
+ from PIL import Image
20
+
21
+ try:
22
+ import cv2 # noqa
23
+ except ImportError:
24
+ # OpenCV is an optional dependency at the moment
25
+ pass
26
+
27
+ __all__ = [
28
+ "ExtentTransform",
29
+ "ResizeTransform",
30
+ "RotationTransform",
31
+ "ColorTransform",
32
+ "PILColorTransform",
33
+ ]
34
+
35
+
36
+ class ExtentTransform(Transform):
37
+ """
38
+ Extracts a subregion from the source image and scales it to the output size.
39
+
40
+ The fill color is used to map pixels from the source rect that fall outside
41
+ the source image.
42
+
43
+ See: https://pillow.readthedocs.io/en/latest/PIL.html#PIL.ImageTransform.ExtentTransform
44
+ """
45
+
46
+ def __init__(self, src_rect, output_size, interp=Image.BILINEAR, fill=0):
47
+ """
48
+ Args:
49
+ src_rect (x0, y0, x1, y1): src coordinates
50
+ output_size (h, w): dst image size
51
+ interp: PIL interpolation methods
52
+ fill: Fill color used when src_rect extends outside image
53
+ """
54
+ super().__init__()
55
+ self._set_attributes(locals())
56
+
57
+ def apply_image(self, img, interp=None):
58
+ h, w = self.output_size
59
+ if len(img.shape) > 2 and img.shape[2] == 1:
60
+ pil_image = Image.fromarray(img[:, :, 0], mode="L")
61
+ else:
62
+ pil_image = Image.fromarray(img)
63
+ pil_image = pil_image.transform(
64
+ size=(w, h),
65
+ method=Image.EXTENT,
66
+ data=self.src_rect,
67
+ resample=interp if interp else self.interp,
68
+ fill=self.fill,
69
+ )
70
+ ret = np.asarray(pil_image)
71
+ if len(img.shape) > 2 and img.shape[2] == 1:
72
+ ret = np.expand_dims(ret, -1)
73
+ return ret
74
+
75
+ def apply_coords(self, coords):
76
+ # Transform image center from source coordinates into output coordinates
77
+ # and then map the new origin to the corner of the output image.
78
+ h, w = self.output_size
79
+ x0, y0, x1, y1 = self.src_rect
80
+ new_coords = coords.astype(np.float32)
81
+ new_coords[:, 0] -= 0.5 * (x0 + x1)
82
+ new_coords[:, 1] -= 0.5 * (y0 + y1)
83
+ new_coords[:, 0] *= w / (x1 - x0)
84
+ new_coords[:, 1] *= h / (y1 - y0)
85
+ new_coords[:, 0] += 0.5 * w
86
+ new_coords[:, 1] += 0.5 * h
87
+ return new_coords
88
+
89
+ def apply_segmentation(self, segmentation):
90
+ segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
91
+ return segmentation
92
+
93
+
94
+ class ResizeTransform(Transform):
95
+ """
96
+ Resize the image to a target size.
97
+ """
98
+
99
+ def __init__(self, h, w, new_h, new_w, interp=None):
100
+ """
101
+ Args:
102
+ h, w (int): original image size
103
+ new_h, new_w (int): new image size
104
+ interp: PIL interpolation methods, defaults to bilinear.
105
+ """
106
+ # TODO decide on PIL vs opencv
107
+ super().__init__()
108
+ if interp is None:
109
+ interp = Image.BILINEAR
110
+ self._set_attributes(locals())
111
+
112
+ def apply_image(self, img, interp=None):
113
+ assert img.shape[:2] == (self.h, self.w)
114
+ assert len(img.shape) <= 4
115
+ interp_method = interp if interp is not None else self.interp
116
+
117
+ if img.dtype == np.uint8:
118
+ if len(img.shape) > 2 and img.shape[2] == 1:
119
+ pil_image = Image.fromarray(img[:, :, 0], mode="L")
120
+ else:
121
+ pil_image = Image.fromarray(img)
122
+ pil_image = pil_image.resize((self.new_w, self.new_h), interp_method)
123
+ ret = np.asarray(pil_image)
124
+ if len(img.shape) > 2 and img.shape[2] == 1:
125
+ ret = np.expand_dims(ret, -1)
126
+ else:
127
+ # PIL only supports uint8
128
+ if any(x < 0 for x in img.strides):
129
+ img = np.ascontiguousarray(img)
130
+ img = torch.from_numpy(img)
131
+ shape = list(img.shape)
132
+ shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:]
133
+ img = img.view(shape_4d).permute(2, 3, 0, 1) # hw(c) -> nchw
134
+ _PIL_RESIZE_TO_INTERPOLATE_MODE = {
135
+ Image.NEAREST: "nearest",
136
+ Image.BILINEAR: "bilinear",
137
+ Image.BICUBIC: "bicubic",
138
+ }
139
+ mode = _PIL_RESIZE_TO_INTERPOLATE_MODE[interp_method]
140
+ align_corners = None if mode == "nearest" else False
141
+ img = F.interpolate(
142
+ img, (self.new_h, self.new_w), mode=mode, align_corners=align_corners
143
+ )
144
+ shape[:2] = (self.new_h, self.new_w)
145
+ ret = img.permute(2, 3, 0, 1).view(shape).numpy() # nchw -> hw(c)
146
+
147
+ return ret
148
+
149
+ def apply_coords(self, coords):
150
+ coords[:, 0] = coords[:, 0] * (self.new_w * 1.0 / self.w)
151
+ coords[:, 1] = coords[:, 1] * (self.new_h * 1.0 / self.h)
152
+ return coords
153
+
154
+ def apply_segmentation(self, segmentation):
155
+ segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
156
+ return segmentation
157
+
158
+ def inverse(self):
159
+ return ResizeTransform(self.new_h, self.new_w, self.h, self.w, self.interp)
160
+
161
+
162
+ class RotationTransform(Transform):
163
+ """
164
+ This transform returns a copy of the image, rotated the given
165
+ number of degrees counter clockwise around its center.
166
+ """
167
+
168
+ def __init__(self, h, w, angle, expand=True, center=None, interp=None):
169
+ """
170
+ Args:
171
+ h, w (int): original image size
172
+ angle (float): degrees for rotation
173
+ expand (bool): choose if the image should be resized to fit the whole
174
+ rotated image (default), or simply cropped
175
+ center (tuple (width, height)): coordinates of the rotation center
176
+ if left as None, the center will be set to the center of each image
177
+ center has no effect if expand=True because it only affects shifting
178
+ interp: cv2 interpolation method, default cv2.INTER_LINEAR
179
+ """
180
+ super().__init__()
181
+ image_center = np.array((w / 2, h / 2))
182
+ if center is None:
183
+ center = image_center
184
+ if interp is None:
185
+ interp = cv2.INTER_LINEAR
186
+ abs_cos, abs_sin = (abs(np.cos(np.deg2rad(angle))), abs(np.sin(np.deg2rad(angle))))
187
+ if expand:
188
+ # find the new width and height bounds
189
+ bound_w, bound_h = np.rint(
190
+ [h * abs_sin + w * abs_cos, h * abs_cos + w * abs_sin]
191
+ ).astype(int)
192
+ else:
193
+ bound_w, bound_h = w, h
194
+
195
+ self._set_attributes(locals())
196
+ self.rm_coords = self.create_rotation_matrix()
197
+ # Needed because of this problem https://github.com/opencv/opencv/issues/11784
198
+ self.rm_image = self.create_rotation_matrix(offset=-0.5)
199
+
200
+ def apply_image(self, img, interp=None):
201
+ """
202
+ img should be a numpy array, formatted as Height * Width * Nchannels
203
+ """
204
+ if len(img) == 0 or self.angle % 360 == 0:
205
+ return img
206
+ assert img.shape[:2] == (self.h, self.w)
207
+ interp = interp if interp is not None else self.interp
208
+ return cv2.warpAffine(img, self.rm_image, (self.bound_w, self.bound_h), flags=interp)
209
+
210
+ def apply_coords(self, coords):
211
+ """
212
+ coords should be a N * 2 array-like, containing N pairs of (x, y) points
213
+ """
214
+ coords = np.asarray(coords, dtype=float)
215
+ if len(coords) == 0 or self.angle % 360 == 0:
216
+ return coords
217
+ return cv2.transform(coords[:, np.newaxis, :], self.rm_coords)[:, 0, :]
218
+
219
+ def apply_segmentation(self, segmentation):
220
+ segmentation = self.apply_image(segmentation, interp=cv2.INTER_NEAREST)
221
+ return segmentation
222
+
223
+ def create_rotation_matrix(self, offset=0):
224
+ center = (self.center[0] + offset, self.center[1] + offset)
225
+ rm = cv2.getRotationMatrix2D(tuple(center), self.angle, 1)
226
+ if self.expand:
227
+ # Find the coordinates of the center of rotation in the new image
228
+ # The only point for which we know the future coordinates is the center of the image
229
+ rot_im_center = cv2.transform(self.image_center[None, None, :] + offset, rm)[0, 0, :]
230
+ new_center = np.array([self.bound_w / 2, self.bound_h / 2]) + offset - rot_im_center
231
+ # shift the rotation center to the new coordinates
232
+ rm[:, 2] += new_center
233
+ return rm
234
+
235
+ def inverse(self):
236
+ """
237
+ The inverse is to rotate it back with expand, and crop to get the original shape.
238
+ """
239
+ if not self.expand: # Not possible to inverse if a part of the image is lost
240
+ raise NotImplementedError()
241
+ rotation = RotationTransform(
242
+ self.bound_h, self.bound_w, -self.angle, True, None, self.interp
243
+ )
244
+ crop = CropTransform(
245
+ (rotation.bound_w - self.w) // 2, (rotation.bound_h - self.h) // 2, self.w, self.h
246
+ )
247
+ return TransformList([rotation, crop])
248
+
249
+
250
+ class ColorTransform(Transform):
251
+ """
252
+ Generic wrapper for any photometric transforms.
253
+ These transformations should only affect the color space and
254
+ not the coordinate space of the image (e.g. annotation
255
+ coordinates such as bounding boxes should not be changed)
256
+ """
257
+
258
+ def __init__(self, op):
259
+ """
260
+ Args:
261
+ op (Callable): operation to be applied to the image,
262
+ which takes in an ndarray and returns an ndarray.
263
+ """
264
+ if not callable(op):
265
+ raise ValueError("op parameter should be callable")
266
+ super().__init__()
267
+ self._set_attributes(locals())
268
+
269
+ def apply_image(self, img):
270
+ return self.op(img)
271
+
272
+ def apply_coords(self, coords):
273
+ return coords
274
+
275
+ def inverse(self):
276
+ return NoOpTransform()
277
+
278
+ def apply_segmentation(self, segmentation):
279
+ return segmentation
280
+
281
+
282
+ class PILColorTransform(ColorTransform):
283
+ """
284
+ Generic wrapper for PIL Photometric image transforms,
285
+ which affect the color space and not the coordinate
286
+ space of the image
287
+ """
288
+
289
+ def __init__(self, op):
290
+ """
291
+ Args:
292
+ op (Callable): operation to be applied to the image,
293
+ which takes in a PIL Image and returns a transformed
294
+ PIL Image.
295
+ For reference on possible operations see:
296
+ - https://pillow.readthedocs.io/en/stable/
297
+ """
298
+ if not callable(op):
299
+ raise ValueError("op parameter should be callable")
300
+ super().__init__(op)
301
+
302
+ def apply_image(self, img):
303
+ img = Image.fromarray(img)
304
+ return np.asarray(super().apply_image(img))
305
+
306
+
307
+ def HFlip_rotated_box(transform, rotated_boxes):
308
+ """
309
+ Apply the horizontal flip transform on rotated boxes.
310
+
311
+ Args:
312
+ rotated_boxes (ndarray): Nx5 floating point array of
313
+ (x_center, y_center, width, height, angle_degrees) format
314
+ in absolute coordinates.
315
+ """
316
+ # Transform x_center
317
+ rotated_boxes[:, 0] = transform.width - rotated_boxes[:, 0]
318
+ # Transform angle
319
+ rotated_boxes[:, 4] = -rotated_boxes[:, 4]
320
+ return rotated_boxes
321
+
322
+
323
+ def Resize_rotated_box(transform, rotated_boxes):
324
+ """
325
+ Apply the resizing transform on rotated boxes. For details of how these (approximation)
326
+ formulas are derived, please refer to :meth:`RotatedBoxes.scale`.
327
+
328
+ Args:
329
+ rotated_boxes (ndarray): Nx5 floating point array of
330
+ (x_center, y_center, width, height, angle_degrees) format
331
+ in absolute coordinates.
332
+ """
333
+ scale_factor_x = transform.new_w * 1.0 / transform.w
334
+ scale_factor_y = transform.new_h * 1.0 / transform.h
335
+ rotated_boxes[:, 0] *= scale_factor_x
336
+ rotated_boxes[:, 1] *= scale_factor_y
337
+ theta = rotated_boxes[:, 4] * np.pi / 180.0
338
+ c = np.cos(theta)
339
+ s = np.sin(theta)
340
+ rotated_boxes[:, 2] *= np.sqrt(np.square(scale_factor_x * c) + np.square(scale_factor_y * s))
341
+ rotated_boxes[:, 3] *= np.sqrt(np.square(scale_factor_x * s) + np.square(scale_factor_y * c))
342
+ rotated_boxes[:, 4] = np.arctan2(scale_factor_x * s, scale_factor_y * c) * 180 / np.pi
343
+
344
+ return rotated_boxes
345
+
346
+
347
+ HFlipTransform.register_type("rotated_box", HFlip_rotated_box)
348
+ ResizeTransform.register_type("rotated_box", Resize_rotated_box)
349
+
350
+ # not necessary any more with latest fvcore
351
+ NoOpTransform.register_type("rotated_box", lambda t, x: x)
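
As a quick illustration of the deterministic `Transform` objects defined above, here is a sketch assuming OpenCV is installed (since `RotationTransform` relies on `cv2`); the shapes and points are made up:

```python
import numpy as np
from detectron2.data.transforms import ResizeTransform, RotationTransform

h, w = 480, 640
image = np.zeros((h, w, 3), dtype=np.uint8)

# Resize the image and push a point through the same transform.
resize = ResizeTransform(h, w, new_h=240, new_w=320)
small = resize.apply_image(image)                         # shape (240, 320, 3)
point = resize.apply_coords(np.array([[100.0, 200.0]]))   # (x, y) scaled by the same factors

# Rotate with expand=True, then invert (rotate back and center-crop to the original size).
rotation = RotationTransform(h, w, angle=30, expand=True)
rotated = rotation.apply_image(image)
restored = rotation.inverse().apply_image(rotated)
assert restored.shape[:2] == (h, w)
```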
CatVTON/detectron2/export/README.md ADDED
@@ -0,0 +1,15 @@
1
+
2
+ This directory contains code to prepare a detectron2 model for deployment.
3
+ Currently it supports exporting a detectron2 model to TorchScript, ONNX, or (deprecated) Caffe2 format.
4
+
5
+ Please see [documentation](https://detectron2.readthedocs.io/tutorials/deployment.html) for its usage.
6
+
7
+
8
+ ### Acknowledgements
9
+
10
+ Thanks to Mobile Vision team at Facebook for developing the Caffe2 conversion tools.
11
+
12
+ Thanks to Computing Platform Department - PAI team at Alibaba Group (@bddpqq, @chenbohua3) who
13
+ helped export Detectron2 models to TorchScript.
14
+
15
+ Thanks to the ONNX Converter team at Microsoft who helped export Detectron2 models to ONNX.
CatVTON/detectron2/export/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import warnings
4
+
5
+ from .flatten import TracingAdapter
6
+ from .torchscript import dump_torchscript_IR, scripting_with_instances
7
+
8
+ try:
9
+ from caffe2.proto import caffe2_pb2 as _tmp
10
+ from caffe2.python import core
11
+
12
+ # caffe2 is optional
13
+ except ImportError:
14
+ pass
15
+ else:
16
+ from .api import *
17
+
18
+
19
+ # TODO: Update ONNX Opset version and run tests when a newer PyTorch is supported
20
+ STABLE_ONNX_OPSET_VERSION = 11
21
+
22
+
23
+ def add_export_config(cfg):
24
+ warnings.warn(
25
+ "add_export_config has been deprecated and behaves as no-op function.", DeprecationWarning
26
+ )
27
+ return cfg
28
+
29
+
30
+ __all__ = [k for k in globals().keys() if not k.startswith("_")]
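
For orientation, a sketch of the TorchScript tracing path built on `TracingAdapter`: the config name, placeholder input, and output path below are illustrative assumptions, and a representative real image should be used in practice so the trace covers the detection branches.

```python
import torch
from detectron2 import model_zoo
from detectron2.export import TracingAdapter

# Illustrative: any weight-loaded detectron2 model in eval mode.
model = model_zoo.get("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml", trained=True).eval()

# Placeholder input; for a faithful trace use a real image that produces detections.
image = torch.rand(3, 480, 640) * 255
inputs = [{"image": image}]

# TracingAdapter flattens detectron2's dict/Instances inputs and outputs into tuples of
# tensors so that torch.jit.trace (or torch.onnx.export) can handle them.
adapter = TracingAdapter(model, inputs)
traced = torch.jit.trace(adapter, adapter.flattened_inputs)
traced.save("model_traced.pt")  # illustrative output path
```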
CatVTON/detectron2/export/api.py ADDED
@@ -0,0 +1,230 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import copy
3
+ import logging
4
+ import os
5
+ import torch
6
+ from caffe2.proto import caffe2_pb2
7
+ from torch import nn
8
+
9
+ from detectron2.config import CfgNode
10
+ from detectron2.utils.file_io import PathManager
11
+
12
+ from .caffe2_inference import ProtobufDetectionModel
13
+ from .caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP, convert_batched_inputs_to_c2_format
14
+ from .shared import get_pb_arg_vali, get_pb_arg_vals, save_graph
15
+
16
+ __all__ = [
17
+ "Caffe2Model",
18
+ "Caffe2Tracer",
19
+ ]
20
+
21
+
22
+ class Caffe2Tracer:
23
+ """
24
+ Make a detectron2 model traceable with Caffe2 operators.
25
+ This class creates a traceable version of a detectron2 model which:
26
+
27
+ 1. Rewrites parts of the model using ops in Caffe2. Note that some ops do
28
+ not have GPU implementation in Caffe2.
29
+ 2. Removes post-processing and only produces raw layer outputs
30
+
31
+ After making a traceable model, the class provides methods to export such a
32
+ model to different deployment formats.
33
+ The exported graph produced by this class takes two input tensors:
34
+
35
+ 1. (1, C, H, W) float "data" which is an image (usually in [0, 255]).
36
+ (H, W) often has to be padded to a multiple of 32 (depending on the model
37
+ architecture).
38
+ 2. 1x3 float "im_info", each row of which is (height, width, 1.0).
39
+ Height and width are true image shapes before padding.
40
+
41
+ The class currently only supports models using builtin meta architectures.
42
+ Batch inference is not supported, and contributions are welcome.
43
+ """
44
+
45
+ def __init__(self, cfg: CfgNode, model: nn.Module, inputs):
46
+ """
47
+ Args:
48
+ cfg (CfgNode): a detectron2 config used to construct caffe2-compatible model.
49
+ model (nn.Module): An original pytorch model. Must be among a few official models
50
+ in detectron2 that can be converted to become caffe2-compatible automatically.
51
+ Weights have to be already loaded to this model.
52
+ inputs: sample inputs that the given model takes for inference.
53
+ Will be used to trace the model. For most models, random inputs with
54
+ no detected objects will not work as they lead to wrong traces.
55
+ """
56
+ assert isinstance(cfg, CfgNode), cfg
57
+ assert isinstance(model, torch.nn.Module), type(model)
58
+
59
+ # TODO make it support custom models, by passing in c2 model directly
60
+ C2MetaArch = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[cfg.MODEL.META_ARCHITECTURE]
61
+ self.traceable_model = C2MetaArch(cfg, copy.deepcopy(model))
62
+ self.inputs = inputs
63
+ self.traceable_inputs = self.traceable_model.get_caffe2_inputs(inputs)
64
+
65
+ def export_caffe2(self):
66
+ """
67
+ Export the model to Caffe2's protobuf format.
68
+ The returned object can be saved with its :meth:`.save_protobuf()` method.
69
+ The result can be loaded and executed using Caffe2 runtime.
70
+
71
+ Returns:
72
+ :class:`Caffe2Model`
73
+ """
74
+ from .caffe2_export import export_caffe2_detection_model
75
+
76
+ predict_net, init_net = export_caffe2_detection_model(
77
+ self.traceable_model, self.traceable_inputs
78
+ )
79
+ return Caffe2Model(predict_net, init_net)
80
+
81
+ def export_onnx(self):
82
+ """
83
+ Export the model to ONNX format.
84
+ Note that the exported model contains custom ops only available in caffe2, therefore it
85
+ cannot be directly executed by other runtimes (such as onnxruntime or TensorRT).
86
+ Post-processing or transformation passes may be applied on the model to accommodate
87
+ different runtimes, but we currently do not provide support for them.
88
+
89
+ Returns:
90
+ onnx.ModelProto: an onnx model.
91
+ """
92
+ from .caffe2_export import export_onnx_model as export_onnx_model_impl
93
+
94
+ return export_onnx_model_impl(self.traceable_model, (self.traceable_inputs,))
95
+
96
+ def export_torchscript(self):
97
+ """
98
+ Export the model to a ``torch.jit.TracedModule`` by tracing.
99
+ The returned object can be saved to a file by ``.save()``.
100
+
101
+ Returns:
102
+ torch.jit.TracedModule: a torch TracedModule
103
+ """
104
+ logger = logging.getLogger(__name__)
105
+ logger.info("Tracing the model with torch.jit.trace ...")
106
+ with torch.no_grad():
107
+ return torch.jit.trace(self.traceable_model, (self.traceable_inputs,))
108
+
109
+
110
+ class Caffe2Model(nn.Module):
111
+ """
112
+ A wrapper around the traced model in Caffe2's protobuf format.
113
+ The exported graph has different inputs/outputs from the original Pytorch
114
+ model, as explained in :class:`Caffe2Tracer`. This class wraps around the
115
+ exported graph to simulate the same interface as the original Pytorch model.
116
+ It also provides functions to save/load models in Caffe2's format.
117
+
118
+ Examples:
119
+ ::
120
+ c2_model = Caffe2Tracer(cfg, torch_model, inputs).export_caffe2()
121
+ inputs = [{"image": img_tensor_CHW}]
122
+ outputs = c2_model(inputs)
123
+ orig_outputs = torch_model(inputs)
124
+ """
125
+
126
+ def __init__(self, predict_net, init_net):
127
+ super().__init__()
128
+ self.eval() # always in eval mode
129
+ self._predict_net = predict_net
130
+ self._init_net = init_net
131
+ self._predictor = None
132
+
133
+ __init__.__HIDE_SPHINX_DOC__ = True
134
+
135
+ @property
136
+ def predict_net(self):
137
+ """
138
+ caffe2.core.Net: the underlying caffe2 predict net
139
+ """
140
+ return self._predict_net
141
+
142
+ @property
143
+ def init_net(self):
144
+ """
145
+ caffe2.core.Net: the underlying caffe2 init net
146
+ """
147
+ return self._init_net
148
+
149
+ def save_protobuf(self, output_dir):
150
+ """
151
+ Save the model as caffe2's protobuf format.
152
+ It saves the following files:
153
+
154
+ * "model.pb": definition of the graph. Can be visualized with
155
+ tools like `netron <https://github.com/lutzroeder/netron>`_.
156
+ * "model_init.pb": model parameters
157
+ * "model.pbtxt": human-readable definition of the graph. Not
158
+ needed for deployment.
159
+
160
+ Args:
161
+ output_dir (str): the output directory to save protobuf files.
162
+ """
163
+ logger = logging.getLogger(__name__)
164
+ logger.info("Saving model to {} ...".format(output_dir))
165
+ if not PathManager.exists(output_dir):
166
+ PathManager.mkdirs(output_dir)
167
+
168
+ with PathManager.open(os.path.join(output_dir, "model.pb"), "wb") as f:
169
+ f.write(self._predict_net.SerializeToString())
170
+ with PathManager.open(os.path.join(output_dir, "model.pbtxt"), "w") as f:
171
+ f.write(str(self._predict_net))
172
+ with PathManager.open(os.path.join(output_dir, "model_init.pb"), "wb") as f:
173
+ f.write(self._init_net.SerializeToString())
174
+
175
+ def save_graph(self, output_file, inputs=None):
176
+ """
177
+ Save the graph as SVG format.
178
+
179
+ Args:
180
+ output_file (str): a SVG file
181
+ inputs: optional inputs given to the model.
182
+ If given, the inputs will be used to run the graph to record
183
+ shape of every tensor. The shape information will be
184
+ saved together with the graph.
185
+ """
186
+ from .caffe2_export import run_and_save_graph
187
+
188
+ if inputs is None:
189
+ save_graph(self._predict_net, output_file, op_only=False)
190
+ else:
191
+ size_divisibility = get_pb_arg_vali(self._predict_net, "size_divisibility", 0)
192
+ device = get_pb_arg_vals(self._predict_net, "device", b"cpu").decode("ascii")
193
+ inputs = convert_batched_inputs_to_c2_format(inputs, size_divisibility, device)
194
+ inputs = [x.cpu().numpy() for x in inputs]
195
+ run_and_save_graph(self._predict_net, self._init_net, inputs, output_file)
196
+
197
+ @staticmethod
198
+ def load_protobuf(dir):
199
+ """
200
+ Args:
201
+ dir (str): a directory used to save Caffe2Model with
202
+ :meth:`save_protobuf`.
203
+ The files "model.pb" and "model_init.pb" are needed.
204
+
205
+ Returns:
206
+ Caffe2Model: the caffe2 model loaded from this directory.
207
+ """
208
+ predict_net = caffe2_pb2.NetDef()
209
+ with PathManager.open(os.path.join(dir, "model.pb"), "rb") as f:
210
+ predict_net.ParseFromString(f.read())
211
+
212
+ init_net = caffe2_pb2.NetDef()
213
+ with PathManager.open(os.path.join(dir, "model_init.pb"), "rb") as f:
214
+ init_net.ParseFromString(f.read())
215
+
216
+ return Caffe2Model(predict_net, init_net)
217
+
218
+ def __call__(self, inputs):
219
+ """
220
+ An interface that wraps around a Caffe2 model and mimics detectron2's models'
221
+ input/output format. See details about the format at :doc:`/tutorials/models`.
222
+ This is used to compare the outputs of caffe2 model with its original torch model.
223
+
224
+ Due to the extra conversion between Pytorch/Caffe2, this method is not meant for
225
+ benchmarking. Because of the conversion, this method also has a dependency
226
+ on detectron2 in order to convert to detectron2's output format.
227
+ """
228
+ if self._predictor is None:
229
+ self._predictor = ProtobufDetectionModel(self._predict_net, self._init_net)
230
+ return self._predictor(inputs)
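
Putting `Caffe2Tracer` and `Caffe2Model` together, here is a sketch of the (deprecated) Caffe2 export round trip; it assumes a Caffe2-enabled PyTorch build, and the config name, image path, and output directory are placeholders:

```python
import cv2  # used only to load a sample image; the path below is a placeholder
import torch
from detectron2 import model_zoo
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.export import Caffe2Model, Caffe2Tracer
from detectron2.modeling import build_model

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")
cfg.MODEL.DEVICE = "cpu"

model = build_model(cfg).eval()
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)

# A real image matters here: traces made from inputs with no detections are wrong.
img = cv2.imread("sample.jpg")  # BGR HWC, placeholder path
inputs = [{"image": torch.as_tensor(img.astype("float32").transpose(2, 0, 1))}]

c2_model = Caffe2Tracer(cfg, model, inputs).export_caffe2()
c2_model.save_protobuf("./caffe2_export")  # writes model.pb, model_init.pb, model.pbtxt

reloaded = Caffe2Model.load_protobuf("./caffe2_export")
outputs = reloaded(inputs)  # same input/output format as the original torch model
```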
CatVTON/detectron2/export/c10.py ADDED
@@ -0,0 +1,571 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ import math
4
+ from typing import Dict
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from detectron2.layers import ShapeSpec, cat
9
+ from detectron2.layers.roi_align_rotated import ROIAlignRotated
10
+ from detectron2.modeling import poolers
11
+ from detectron2.modeling.proposal_generator import rpn
12
+ from detectron2.modeling.roi_heads.mask_head import mask_rcnn_inference
13
+ from detectron2.structures import Boxes, ImageList, Instances, Keypoints, RotatedBoxes
14
+
15
+ from .shared import alias, to_device
16
+
17
+
18
+ """
19
+ This file contains caffe2-compatible implementation of several detectron2 components.
20
+ """
21
+
22
+
23
+ class Caffe2Boxes(Boxes):
24
+ """
25
+ Representing a list of detectron2.structures.Boxes from a minibatch; each box
26
+ is represented by a 5d vector (batch index + 4 coordinates), or a 6d vector
27
+ (batch index + 5 coordinates) for RotatedBoxes.
28
+ """
29
+
30
+ def __init__(self, tensor):
31
+ assert isinstance(tensor, torch.Tensor)
32
+ assert tensor.dim() == 2 and tensor.size(-1) in [4, 5, 6], tensor.size()
33
+ # TODO: make tensor immutable when dim is Nx5 for Boxes,
34
+ # and Nx6 for RotatedBoxes?
35
+ self.tensor = tensor
36
+
37
+
38
+ # TODO clean up this class, maybe just extend Instances
39
+ class InstancesList:
40
+ """
41
+ Tensor representation of a list of Instances objects for a batch of images.
42
+
43
+ When dealing with a batch of images with Caffe2 ops, a list of bboxes
44
+ (instances) is usually represented by a single Tensor with size
45
+ (sigma(Ni), 5) or (sigma(Ni), 4) plus a batch split Tensor. This class is
46
+ for providing common functions to convert between these two representations.
47
+ """
48
+
49
+ def __init__(self, im_info, indices, extra_fields=None):
50
+ # [N, 3] -> (H, W, Scale)
51
+ self.im_info = im_info
52
+ # [N,] -> index of the batch to which the instance belongs
53
+ self.indices = indices
54
+ # [N, ...]
55
+ self.batch_extra_fields = extra_fields or {}
56
+
57
+ self.image_size = self.im_info
58
+
59
+ def get_fields(self):
60
+ """like `get_fields` in the Instances object,
61
+ but return each field in tensor representations"""
62
+ ret = {}
63
+ for k, v in self.batch_extra_fields.items():
64
+ # if isinstance(v, torch.Tensor):
65
+ # tensor_rep = v
66
+ # elif isinstance(v, (Boxes, Keypoints)):
67
+ # tensor_rep = v.tensor
68
+ # else:
69
+ # raise ValueError("Can't find tensor representation for: {}".format())
70
+ ret[k] = v
71
+ return ret
72
+
73
+ def has(self, name):
74
+ return name in self.batch_extra_fields
75
+
76
+ def set(self, name, value):
77
+ # len(tensor) is a bad practice that generates ONNX constants during tracing.
78
+ # Although not a problem for the `assert` statement below, torch ONNX exporter
79
+ # still raises a misleading warning as it does not know this call comes from `assert`
80
+ if isinstance(value, Boxes):
81
+ data_len = value.tensor.shape[0]
82
+ elif isinstance(value, torch.Tensor):
83
+ data_len = value.shape[0]
84
+ else:
85
+ data_len = len(value)
86
+ if len(self.batch_extra_fields):
87
+ assert (
88
+ len(self) == data_len
89
+ ), "Adding a field of length {} to a Instances of length {}".format(data_len, len(self))
90
+ self.batch_extra_fields[name] = value
91
+
92
+ def __getattr__(self, name):
93
+ if name not in self.batch_extra_fields:
94
+ raise AttributeError("Cannot find field '{}' in the given Instances!".format(name))
95
+ return self.batch_extra_fields[name]
96
+
97
+ def __len__(self):
98
+ return len(self.indices)
99
+
100
+ def flatten(self):
101
+ ret = []
102
+ for _, v in self.batch_extra_fields.items():
103
+ if isinstance(v, (Boxes, Keypoints)):
104
+ ret.append(v.tensor)
105
+ else:
106
+ ret.append(v)
107
+ return ret
108
+
109
+ @staticmethod
110
+ def to_d2_instances_list(instances_list):
111
+ """
112
+ Convert InstancesList to List[Instances]. The input `instances_list` can
113
+ also be a List[Instances], in this case this method is a non-op.
114
+ """
115
+ if not isinstance(instances_list, InstancesList):
116
+ assert all(isinstance(x, Instances) for x in instances_list)
117
+ return instances_list
118
+
119
+ ret = []
120
+ for i, info in enumerate(instances_list.im_info):
121
+ instances = Instances(torch.Size([int(info[0].item()), int(info[1].item())]))
122
+
123
+ ids = instances_list.indices == i
124
+ for k, v in instances_list.batch_extra_fields.items():
125
+ if isinstance(v, torch.Tensor):
126
+ instances.set(k, v[ids])
127
+ continue
128
+ elif isinstance(v, Boxes):
129
+ instances.set(k, v[ids, -4:])
130
+ continue
131
+
132
+ target_type, tensor_source = v
133
+ assert isinstance(tensor_source, torch.Tensor)
134
+ assert tensor_source.shape[0] == instances_list.indices.shape[0]
135
+ tensor_source = tensor_source[ids]
136
+
137
+ if issubclass(target_type, Boxes):
138
+ instances.set(k, Boxes(tensor_source[:, -4:]))
139
+ elif issubclass(target_type, Keypoints):
140
+ instances.set(k, Keypoints(tensor_source))
141
+ elif issubclass(target_type, torch.Tensor):
142
+ instances.set(k, tensor_source)
143
+ else:
144
+ raise ValueError("Can't handle targe type: {}".format(target_type))
145
+
146
+ ret.append(instances)
147
+ return ret
148
+
149
+
150
+ class Caffe2Compatible:
151
+ """
152
+ A model can inherit this class to indicate that it can be traced and deployed with caffe2.
153
+ """
154
+
155
+ def _get_tensor_mode(self):
156
+ return self._tensor_mode
157
+
158
+ def _set_tensor_mode(self, v):
159
+ self._tensor_mode = v
160
+
161
+ tensor_mode = property(_get_tensor_mode, _set_tensor_mode)
162
+ """
163
+ If true, the model expects C2-style tensor only inputs/outputs format.
164
+ """
165
+
166
+
167
+ class Caffe2RPN(Caffe2Compatible, rpn.RPN):
168
+ @classmethod
169
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
170
+ ret = super(Caffe2Compatible, cls).from_config(cfg, input_shape)
171
+ assert tuple(cfg.MODEL.RPN.BBOX_REG_WEIGHTS) == (1.0, 1.0, 1.0, 1.0) or tuple(
172
+ cfg.MODEL.RPN.BBOX_REG_WEIGHTS
173
+ ) == (1.0, 1.0, 1.0, 1.0, 1.0)
174
+ return ret
175
+
176
+ def _generate_proposals(
177
+ self, images, objectness_logits_pred, anchor_deltas_pred, gt_instances=None
178
+ ):
179
+ assert isinstance(images, ImageList)
180
+ if self.tensor_mode:
181
+ im_info = images.image_sizes
182
+ else:
183
+ im_info = torch.tensor([[im_sz[0], im_sz[1], 1.0] for im_sz in images.image_sizes]).to(
184
+ images.tensor.device
185
+ )
186
+ assert isinstance(im_info, torch.Tensor)
187
+
188
+ rpn_rois_list = []
189
+ rpn_roi_probs_list = []
190
+ for scores, bbox_deltas, cell_anchors_tensor, feat_stride in zip(
191
+ objectness_logits_pred,
192
+ anchor_deltas_pred,
193
+ [b for (n, b) in self.anchor_generator.cell_anchors.named_buffers()],
194
+ self.anchor_generator.strides,
195
+ ):
196
+ scores = scores.detach()
197
+ bbox_deltas = bbox_deltas.detach()
198
+
199
+ rpn_rois, rpn_roi_probs = torch.ops._caffe2.GenerateProposals(
200
+ scores,
201
+ bbox_deltas,
202
+ im_info,
203
+ cell_anchors_tensor,
204
+ spatial_scale=1.0 / feat_stride,
205
+ pre_nms_topN=self.pre_nms_topk[self.training],
206
+ post_nms_topN=self.post_nms_topk[self.training],
207
+ nms_thresh=self.nms_thresh,
208
+ min_size=self.min_box_size,
209
+ # correct_transform_coords=True, # deprecated argument
210
+ angle_bound_on=True, # Default
211
+ angle_bound_lo=-180,
212
+ angle_bound_hi=180,
213
+ clip_angle_thresh=1.0, # Default
214
+ legacy_plus_one=False,
215
+ )
216
+ rpn_rois_list.append(rpn_rois)
217
+ rpn_roi_probs_list.append(rpn_roi_probs)
218
+
219
+ # For FPN in D2, in RPN all proposals from different levels are concatenated
220
+ # together, ranked and picked by top post_nms_topk. Then in ROIPooler
221
+ # it calculates level_assignments and calls the RoIAlign from
222
+ # the corresponding level.
223
+
224
+ if len(objectness_logits_pred) == 1:
225
+ rpn_rois = rpn_rois_list[0]
226
+ rpn_roi_probs = rpn_roi_probs_list[0]
227
+ else:
228
+ assert len(rpn_rois_list) == len(rpn_roi_probs_list)
229
+ rpn_post_nms_topN = self.post_nms_topk[self.training]
230
+
231
+ device = rpn_rois_list[0].device
232
+ input_list = [to_device(x, "cpu") for x in (rpn_rois_list + rpn_roi_probs_list)]
233
+
234
+ # TODO remove this after confirming rpn_max_level/rpn_min_level
235
+ # is not needed in CollectRpnProposals.
236
+ feature_strides = list(self.anchor_generator.strides)
237
+ rpn_min_level = int(math.log2(feature_strides[0]))
238
+ rpn_max_level = int(math.log2(feature_strides[-1]))
239
+ assert (rpn_max_level - rpn_min_level + 1) == len(
240
+ rpn_rois_list
241
+ ), "CollectRpnProposals requires continuous levels"
242
+
243
+ rpn_rois = torch.ops._caffe2.CollectRpnProposals(
244
+ input_list,
245
+ # NOTE: in current implementation, rpn_max_level and rpn_min_level
246
+ # are not needed, only the subtraction of two matters and it
247
+ # can be inferred from the number of inputs. Keep them now for
248
+ # consistency.
249
+ rpn_max_level=2 + len(rpn_rois_list) - 1,
250
+ rpn_min_level=2,
251
+ rpn_post_nms_topN=rpn_post_nms_topN,
252
+ )
253
+ rpn_rois = to_device(rpn_rois, device)
254
+ rpn_roi_probs = []
255
+
256
+ proposals = self.c2_postprocess(im_info, rpn_rois, rpn_roi_probs, self.tensor_mode)
257
+ return proposals, {}
258
+
259
+ def forward(self, images, features, gt_instances=None):
260
+ assert not self.training
261
+ features = [features[f] for f in self.in_features]
262
+ objectness_logits_pred, anchor_deltas_pred = self.rpn_head(features)
263
+ return self._generate_proposals(
264
+ images,
265
+ objectness_logits_pred,
266
+ anchor_deltas_pred,
267
+ gt_instances,
268
+ )
269
+
270
+ @staticmethod
271
+ def c2_postprocess(im_info, rpn_rois, rpn_roi_probs, tensor_mode):
272
+ proposals = InstancesList(
273
+ im_info=im_info,
274
+ indices=rpn_rois[:, 0],
275
+ extra_fields={
276
+ "proposal_boxes": Caffe2Boxes(rpn_rois),
277
+ "objectness_logits": (torch.Tensor, rpn_roi_probs),
278
+ },
279
+ )
280
+ if not tensor_mode:
281
+ proposals = InstancesList.to_d2_instances_list(proposals)
282
+ else:
283
+ proposals = [proposals]
284
+ return proposals
285
+
286
+
287
+ class Caffe2ROIPooler(Caffe2Compatible, poolers.ROIPooler):
288
+ @staticmethod
289
+ def c2_preprocess(box_lists):
290
+ assert all(isinstance(x, Boxes) for x in box_lists)
291
+ if all(isinstance(x, Caffe2Boxes) for x in box_lists):
292
+ # input is pure-tensor based
293
+ assert len(box_lists) == 1
294
+ pooler_fmt_boxes = box_lists[0].tensor
295
+ else:
296
+ pooler_fmt_boxes = poolers.convert_boxes_to_pooler_format(box_lists)
297
+ return pooler_fmt_boxes
298
+
299
+ def forward(self, x, box_lists):
300
+ assert not self.training
301
+
302
+ pooler_fmt_boxes = self.c2_preprocess(box_lists)
303
+ num_level_assignments = len(self.level_poolers)
304
+
305
+ if num_level_assignments == 1:
306
+ if isinstance(self.level_poolers[0], ROIAlignRotated):
307
+ c2_roi_align = torch.ops._caffe2.RoIAlignRotated
308
+ aligned = True
309
+ else:
310
+ c2_roi_align = torch.ops._caffe2.RoIAlign
311
+ aligned = self.level_poolers[0].aligned
312
+
313
+ x0 = x[0]
314
+ if x0.is_quantized:
315
+ x0 = x0.dequantize()
316
+
317
+ out = c2_roi_align(
318
+ x0,
319
+ pooler_fmt_boxes,
320
+ order="NCHW",
321
+ spatial_scale=float(self.level_poolers[0].spatial_scale),
322
+ pooled_h=int(self.output_size[0]),
323
+ pooled_w=int(self.output_size[1]),
324
+ sampling_ratio=int(self.level_poolers[0].sampling_ratio),
325
+ aligned=aligned,
326
+ )
327
+ return out
328
+
329
+ device = pooler_fmt_boxes.device
330
+ assert (
331
+ self.max_level - self.min_level + 1 == 4
332
+ ), "Currently DistributeFpnProposals only support 4 levels"
333
+ fpn_outputs = torch.ops._caffe2.DistributeFpnProposals(
334
+ to_device(pooler_fmt_boxes, "cpu"),
335
+ roi_canonical_scale=self.canonical_box_size,
336
+ roi_canonical_level=self.canonical_level,
337
+ roi_max_level=self.max_level,
338
+ roi_min_level=self.min_level,
339
+ legacy_plus_one=False,
340
+ )
341
+ fpn_outputs = [to_device(x, device) for x in fpn_outputs]
342
+
343
+ rois_fpn_list = fpn_outputs[:-1]
344
+ rois_idx_restore_int32 = fpn_outputs[-1]
345
+
346
+ roi_feat_fpn_list = []
347
+ for roi_fpn, x_level, pooler in zip(rois_fpn_list, x, self.level_poolers):
348
+ if isinstance(pooler, ROIAlignRotated):
349
+ c2_roi_align = torch.ops._caffe2.RoIAlignRotated
350
+ aligned = True
351
+ else:
352
+ c2_roi_align = torch.ops._caffe2.RoIAlign
353
+ aligned = bool(pooler.aligned)
354
+
355
+ if x_level.is_quantized:
356
+ x_level = x_level.dequantize()
357
+
358
+ roi_feat_fpn = c2_roi_align(
359
+ x_level,
360
+ roi_fpn,
361
+ order="NCHW",
362
+ spatial_scale=float(pooler.spatial_scale),
363
+ pooled_h=int(self.output_size[0]),
364
+ pooled_w=int(self.output_size[1]),
365
+ sampling_ratio=int(pooler.sampling_ratio),
366
+ aligned=aligned,
367
+ )
368
+ roi_feat_fpn_list.append(roi_feat_fpn)
369
+
370
+ roi_feat_shuffled = cat(roi_feat_fpn_list, dim=0)
371
+ assert roi_feat_shuffled.numel() > 0 and rois_idx_restore_int32.numel() > 0, (
372
+ "Caffe2 export requires tracing with a model checkpoint + input that can produce valid"
373
+ " detections. But no detections were obtained with the given checkpoint and input!"
374
+ )
375
+ roi_feat = torch.ops._caffe2.BatchPermutation(roi_feat_shuffled, rois_idx_restore_int32)
376
+ return roi_feat
377
+
378
+
379
+ def caffe2_fast_rcnn_outputs_inference(tensor_mode, box_predictor, predictions, proposals):
380
+ """equivalent to FastRCNNOutputLayers.inference"""
381
+ num_classes = box_predictor.num_classes
382
+ score_thresh = box_predictor.test_score_thresh
383
+ nms_thresh = box_predictor.test_nms_thresh
384
+ topk_per_image = box_predictor.test_topk_per_image
385
+ is_rotated = len(box_predictor.box2box_transform.weights) == 5
386
+
387
+ if is_rotated:
388
+ box_dim = 5
389
+ assert box_predictor.box2box_transform.weights[4] == 1, (
390
+ "The weights for Rotated BBoxTransform in C2 have only 4 dimensions,"
391
+ + " thus enforcing the angle weight to be 1 for now"
392
+ )
393
+ box2box_transform_weights = box_predictor.box2box_transform.weights[:4]
394
+ else:
395
+ box_dim = 4
396
+ box2box_transform_weights = box_predictor.box2box_transform.weights
397
+
398
+ class_logits, box_regression = predictions
399
+ if num_classes + 1 == class_logits.shape[1]:
400
+ class_prob = F.softmax(class_logits, -1)
401
+ else:
402
+ assert num_classes == class_logits.shape[1]
403
+ class_prob = F.sigmoid(class_logits)
404
+ # BoxWithNMSLimit will infer num_classes from the shape of the class_prob
405
+ # So append a zero column as placeholder for the background class
406
+ class_prob = torch.cat((class_prob, torch.zeros(class_prob.shape[0], 1)), dim=1)
407
+
408
+ assert box_regression.shape[1] % box_dim == 0
409
+ cls_agnostic_bbox_reg = box_regression.shape[1] // box_dim == 1
410
+
411
+ input_tensor_mode = proposals[0].proposal_boxes.tensor.shape[1] == box_dim + 1
412
+
413
+ proposal_boxes = proposals[0].proposal_boxes
414
+ if isinstance(proposal_boxes, Caffe2Boxes):
415
+ rois = Caffe2Boxes.cat([p.proposal_boxes for p in proposals])
416
+ elif isinstance(proposal_boxes, RotatedBoxes):
417
+ rois = RotatedBoxes.cat([p.proposal_boxes for p in proposals])
418
+ elif isinstance(proposal_boxes, Boxes):
419
+ rois = Boxes.cat([p.proposal_boxes for p in proposals])
420
+ else:
421
+ raise NotImplementedError(
422
+ 'Expected proposals[0].proposal_boxes to be type "Boxes", '
423
+ f"instead got {type(proposal_boxes)}"
424
+ )
425
+
426
+ device, dtype = rois.tensor.device, rois.tensor.dtype
427
+ if input_tensor_mode:
428
+ im_info = proposals[0].image_size
429
+ rois = rois.tensor
430
+ else:
431
+ im_info = torch.tensor([[sz[0], sz[1], 1.0] for sz in [x.image_size for x in proposals]])
432
+ batch_ids = cat(
433
+ [
434
+ torch.full((b, 1), i, dtype=dtype, device=device)
435
+ for i, b in enumerate(len(p) for p in proposals)
436
+ ],
437
+ dim=0,
438
+ )
439
+ rois = torch.cat([batch_ids, rois.tensor], dim=1)
440
+
441
+ roi_pred_bbox, roi_batch_splits = torch.ops._caffe2.BBoxTransform(
442
+ to_device(rois, "cpu"),
443
+ to_device(box_regression, "cpu"),
444
+ to_device(im_info, "cpu"),
445
+ weights=box2box_transform_weights,
446
+ apply_scale=True,
447
+ rotated=is_rotated,
448
+ angle_bound_on=True,
449
+ angle_bound_lo=-180,
450
+ angle_bound_hi=180,
451
+ clip_angle_thresh=1.0,
452
+ legacy_plus_one=False,
453
+ )
454
+ roi_pred_bbox = to_device(roi_pred_bbox, device)
455
+ roi_batch_splits = to_device(roi_batch_splits, device)
456
+
457
+ nms_outputs = torch.ops._caffe2.BoxWithNMSLimit(
458
+ to_device(class_prob, "cpu"),
459
+ to_device(roi_pred_bbox, "cpu"),
460
+ to_device(roi_batch_splits, "cpu"),
461
+ score_thresh=float(score_thresh),
462
+ nms=float(nms_thresh),
463
+ detections_per_im=int(topk_per_image),
464
+ soft_nms_enabled=False,
465
+ soft_nms_method="linear",
466
+ soft_nms_sigma=0.5,
467
+ soft_nms_min_score_thres=0.001,
468
+ rotated=is_rotated,
469
+ cls_agnostic_bbox_reg=cls_agnostic_bbox_reg,
470
+ input_boxes_include_bg_cls=False,
471
+ output_classes_include_bg_cls=False,
472
+ legacy_plus_one=False,
473
+ )
474
+ roi_score_nms = to_device(nms_outputs[0], device)
475
+ roi_bbox_nms = to_device(nms_outputs[1], device)
476
+ roi_class_nms = to_device(nms_outputs[2], device)
477
+ roi_batch_splits_nms = to_device(nms_outputs[3], device)
478
+ roi_keeps_nms = to_device(nms_outputs[4], device)
479
+ roi_keeps_size_nms = to_device(nms_outputs[5], device)
480
+ if not tensor_mode:
481
+ roi_class_nms = roi_class_nms.to(torch.int64)
482
+
483
+ roi_batch_ids = cat(
484
+ [
485
+ torch.full((b, 1), i, dtype=dtype, device=device)
486
+ for i, b in enumerate(int(x.item()) for x in roi_batch_splits_nms)
487
+ ],
488
+ dim=0,
489
+ )
490
+
491
+ roi_class_nms = alias(roi_class_nms, "class_nms")
492
+ roi_score_nms = alias(roi_score_nms, "score_nms")
493
+ roi_bbox_nms = alias(roi_bbox_nms, "bbox_nms")
494
+ roi_batch_splits_nms = alias(roi_batch_splits_nms, "batch_splits_nms")
495
+ roi_keeps_nms = alias(roi_keeps_nms, "keeps_nms")
496
+ roi_keeps_size_nms = alias(roi_keeps_size_nms, "keeps_size_nms")
497
+
498
+ results = InstancesList(
499
+ im_info=im_info,
500
+ indices=roi_batch_ids[:, 0],
501
+ extra_fields={
502
+ "pred_boxes": Caffe2Boxes(roi_bbox_nms),
503
+ "scores": roi_score_nms,
504
+ "pred_classes": roi_class_nms,
505
+ },
506
+ )
507
+
508
+ if not tensor_mode:
509
+ results = InstancesList.to_d2_instances_list(results)
510
+ batch_splits = roi_batch_splits_nms.int().tolist()
511
+ kept_indices = list(roi_keeps_nms.to(torch.int64).split(batch_splits))
512
+ else:
513
+ results = [results]
514
+ kept_indices = [roi_keeps_nms]
515
+
516
+ return results, kept_indices
517
+
518
+
519
+ class Caffe2FastRCNNOutputsInference:
520
+ def __init__(self, tensor_mode):
521
+ self.tensor_mode = tensor_mode # whether the output is caffe2 tensor mode
522
+
523
+ def __call__(self, box_predictor, predictions, proposals):
524
+ return caffe2_fast_rcnn_outputs_inference(
525
+ self.tensor_mode, box_predictor, predictions, proposals
526
+ )
527
+
528
+
529
+ def caffe2_mask_rcnn_inference(pred_mask_logits, pred_instances):
530
+ """equivalent to mask_head.mask_rcnn_inference"""
531
+ if all(isinstance(x, InstancesList) for x in pred_instances):
532
+ assert len(pred_instances) == 1
533
+ mask_probs_pred = pred_mask_logits.sigmoid()
534
+ mask_probs_pred = alias(mask_probs_pred, "mask_fcn_probs")
535
+ pred_instances[0].set("pred_masks", mask_probs_pred)
536
+ else:
537
+ mask_rcnn_inference(pred_mask_logits, pred_instances)
538
+
539
+
540
+ class Caffe2MaskRCNNInference:
541
+ def __call__(self, pred_mask_logits, pred_instances):
542
+ return caffe2_mask_rcnn_inference(pred_mask_logits, pred_instances)
543
+
544
+
545
+ def caffe2_keypoint_rcnn_inference(use_heatmap_max_keypoint, pred_keypoint_logits, pred_instances):
546
+ # just return the keypoint heatmap for now,
547
+ # there will be an option to call HeatmapMaxKeypointOp
548
+ output = alias(pred_keypoint_logits, "kps_score")
549
+ if all(isinstance(x, InstancesList) for x in pred_instances):
550
+ assert len(pred_instances) == 1
551
+ if use_heatmap_max_keypoint:
552
+ device = output.device
553
+ output = torch.ops._caffe2.HeatmapMaxKeypoint(
554
+ to_device(output, "cpu"),
555
+ pred_instances[0].pred_boxes.tensor,
556
+ should_output_softmax=True, # worth making it configurable?
557
+ )
558
+ output = to_device(output, device)
559
+ output = alias(output, "keypoints_out")
560
+ pred_instances[0].set("pred_keypoints", output)
561
+ return pred_keypoint_logits
562
+
563
+
564
+ class Caffe2KeypointRCNNInference:
565
+ def __init__(self, use_heatmap_max_keypoint):
566
+ self.use_heatmap_max_keypoint = use_heatmap_max_keypoint
567
+
568
+ def __call__(self, pred_keypoint_logits, pred_instances):
569
+ return caffe2_keypoint_rcnn_inference(
570
+ self.use_heatmap_max_keypoint, pred_keypoint_logits, pred_instances
571
+ )
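
To make the `InstancesList` representation above concrete, here is a small sketch (assuming a Caffe2-enabled PyTorch build, since importing this module pulls in caffe2 utilities) that builds a tensor-style batch and converts it back to the usual per-image `Instances`; the boxes and scores are made-up values:

```python
import torch
from detectron2.export.c10 import Caffe2Boxes, InstancesList

# Two detections for one 480x640 image, in caffe2's (batch_idx, x1, y1, x2, y2) layout.
rois = torch.tensor([[0.0, 10.0, 10.0, 50.0, 80.0],
                     [0.0, 20.0, 30.0, 120.0, 200.0]])

il = InstancesList(
    im_info=torch.tensor([[480.0, 640.0, 1.0]]),  # one (H, W, scale) row per image
    indices=rois[:, 0],                           # which image each box belongs to
    extra_fields={
        "pred_boxes": Caffe2Boxes(rois),
        "scores": torch.tensor([0.9, 0.7]),
    },
)

# Back to the usual List[Instances] representation (one Instances per image).
instances_per_image = InstancesList.to_d2_instances_list(il)
print(instances_per_image[0].pred_boxes, instances_per_image[0].scores)
```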
CatVTON/detectron2/export/caffe2_export.py ADDED
@@ -0,0 +1,203 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ import copy
4
+ import io
5
+ import logging
6
+ import numpy as np
7
+ from typing import List
8
+ import onnx
9
+ import onnx.optimizer
10
+ import torch
11
+ from caffe2.proto import caffe2_pb2
12
+ from caffe2.python import core
13
+ from caffe2.python.onnx.backend import Caffe2Backend
14
+ from tabulate import tabulate
15
+ from termcolor import colored
16
+ from torch.onnx import OperatorExportTypes
17
+
18
+ from .shared import (
19
+ ScopedWS,
20
+ construct_init_net_from_params,
21
+ fuse_alias_placeholder,
22
+ fuse_copy_between_cpu_and_gpu,
23
+ get_params_from_init_net,
24
+ group_norm_replace_aten_with_caffe2,
25
+ infer_device_type,
26
+ remove_dead_end_ops,
27
+ remove_reshape_for_fc,
28
+ save_graph,
29
+ )
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
+ def export_onnx_model(model, inputs):
35
+ """
36
+ Trace and export a model to onnx format.
37
+
38
+ Args:
39
+ model (nn.Module):
40
+ inputs (tuple[args]): the model will be called by `model(*inputs)`
41
+
42
+ Returns:
43
+ an onnx model
44
+ """
45
+ assert isinstance(model, torch.nn.Module)
46
+
47
+ # make sure all modules are in eval mode, onnx may change the training state
48
+ # of the module if the states are not consistent
49
+ def _check_eval(module):
50
+ assert not module.training
51
+
52
+ model.apply(_check_eval)
53
+
54
+ # Export the model to ONNX
55
+ with torch.no_grad():
56
+ with io.BytesIO() as f:
57
+ torch.onnx.export(
58
+ model,
59
+ inputs,
60
+ f,
61
+ operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
62
+ # verbose=True, # NOTE: uncomment this for debugging
63
+ # export_params=True,
64
+ )
65
+ onnx_model = onnx.load_from_string(f.getvalue())
66
+
67
+ return onnx_model
68
+
69
+
70
+ def _op_stats(net_def):
71
+ type_count = {}
72
+ for t in [op.type for op in net_def.op]:
73
+ type_count[t] = type_count.get(t, 0) + 1
74
+ type_count_list = sorted(type_count.items(), key=lambda kv: kv[0]) # alphabet
75
+ type_count_list = sorted(type_count_list, key=lambda kv: -kv[1]) # count
76
+ return "\n".join("{:>4}x {}".format(count, name) for name, count in type_count_list)
77
+
78
+
79
+ def _assign_device_option(
80
+ predict_net: caffe2_pb2.NetDef, init_net: caffe2_pb2.NetDef, tensor_inputs: List[torch.Tensor]
81
+ ):
82
+ """
83
+ ONNX exported network doesn't have a concept of device; assign the necessary
84
+ device option for each op in order to make it runnable on the GPU runtime.
85
+ """
86
+
87
+ def _get_device_type(torch_tensor):
88
+ assert torch_tensor.device.type in ["cpu", "cuda"]
89
+ assert torch_tensor.device.index == 0
90
+ return torch_tensor.device.type
91
+
92
+ def _assign_op_device_option(net_proto, net_ssa, blob_device_types):
93
+ for op, ssa_i in zip(net_proto.op, net_ssa):
94
+ if op.type in ["CopyCPUToGPU", "CopyGPUToCPU"]:
95
+ op.device_option.CopyFrom(core.DeviceOption(caffe2_pb2.CUDA, 0))
96
+ else:
97
+ devices = [blob_device_types[b] for b in ssa_i[0] + ssa_i[1]]
98
+ assert all(d == devices[0] for d in devices)
99
+ if devices[0] == "cuda":
100
+ op.device_option.CopyFrom(core.DeviceOption(caffe2_pb2.CUDA, 0))
101
+
102
+ # update ops in predict_net
103
+ predict_net_input_device_types = {
104
+ (name, 0): _get_device_type(tensor)
105
+ for name, tensor in zip(predict_net.external_input, tensor_inputs)
106
+ }
107
+ predict_net_device_types = infer_device_type(
108
+ predict_net, known_status=predict_net_input_device_types, device_name_style="pytorch"
109
+ )
110
+ predict_net_ssa, _ = core.get_ssa(predict_net)
111
+ _assign_op_device_option(predict_net, predict_net_ssa, predict_net_device_types)
112
+
113
+ # update ops in init_net
114
+ init_net_ssa, versions = core.get_ssa(init_net)
115
+ init_net_output_device_types = {
116
+ (name, versions[name]): predict_net_device_types[(name, 0)]
117
+ for name in init_net.external_output
118
+ }
119
+ init_net_device_types = infer_device_type(
120
+ init_net, known_status=init_net_output_device_types, device_name_style="pytorch"
121
+ )
122
+ _assign_op_device_option(init_net, init_net_ssa, init_net_device_types)
123
+
124
+
125
+ def export_caffe2_detection_model(model: torch.nn.Module, tensor_inputs: List[torch.Tensor]):
126
+ """
127
+ Export a caffe2-compatible Detectron2 model to caffe2 format via ONNX.
128
+
129
+ Args:
130
+ model: a caffe2-compatible version of detectron2 model, defined in caffe2_modeling.py
131
+ tensor_inputs: a list of tensors that caffe2 model takes as input.
132
+ """
133
+ model = copy.deepcopy(model)
134
+ assert isinstance(model, torch.nn.Module)
135
+ assert hasattr(model, "encode_additional_info")
136
+
137
+ # Export via ONNX
138
+ logger.info(
139
+ "Exporting a {} model via ONNX ...".format(type(model).__name__)
140
+ + " Some warnings from ONNX are expected and are usually not to worry about."
141
+ )
142
+ onnx_model = export_onnx_model(model, (tensor_inputs,))
143
+ # Convert ONNX model to Caffe2 protobuf
144
+ init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model)
145
+ ops_table = [[op.type, op.input, op.output] for op in predict_net.op]
146
+ table = tabulate(ops_table, headers=["type", "input", "output"], tablefmt="pipe")
147
+ logger.info(
148
+ "ONNX export Done. Exported predict_net (before optimizations):\n" + colored(table, "cyan")
149
+ )
150
+
151
+ # Apply protobuf optimization
152
+ fuse_alias_placeholder(predict_net, init_net)
153
+ if any(t.device.type != "cpu" for t in tensor_inputs):
154
+ fuse_copy_between_cpu_and_gpu(predict_net)
155
+ remove_dead_end_ops(init_net)
156
+ _assign_device_option(predict_net, init_net, tensor_inputs)
157
+ params, device_options = get_params_from_init_net(init_net)
158
+ predict_net, params = remove_reshape_for_fc(predict_net, params)
159
+ init_net = construct_init_net_from_params(params, device_options)
160
+ group_norm_replace_aten_with_caffe2(predict_net)
161
+
162
+ # Record necessary information for running the pb model in Detectron2 system.
163
+ model.encode_additional_info(predict_net, init_net)
164
+
165
+ logger.info("Operators used in predict_net: \n{}".format(_op_stats(predict_net)))
166
+ logger.info("Operators used in init_net: \n{}".format(_op_stats(init_net)))
167
+
168
+ return predict_net, init_net
169
+
170
+
171
+ def run_and_save_graph(predict_net, init_net, tensor_inputs, graph_save_path):
172
+ """
173
+ Run the caffe2 model on given inputs, recording the shape and draw the graph.
174
+
175
+ predict_net/init_net: caffe2 model.
176
+ tensor_inputs: a list of tensors that caffe2 model takes as input.
177
+ graph_save_path: path for saving graph of exported model.
178
+ """
179
+
180
+ logger.info("Saving graph of ONNX exported model to {} ...".format(graph_save_path))
181
+ save_graph(predict_net, graph_save_path, op_only=False)
182
+
183
+ # Run the exported Caffe2 net
184
+ logger.info("Running ONNX exported model ...")
185
+ with ScopedWS("__ws_tmp__", True) as ws:
186
+ ws.RunNetOnce(init_net)
187
+ initialized_blobs = set(ws.Blobs())
188
+ uninitialized = [inp for inp in predict_net.external_input if inp not in initialized_blobs]
189
+ for name, blob in zip(uninitialized, tensor_inputs):
190
+ ws.FeedBlob(name, blob)
191
+
192
+ try:
193
+ ws.RunNetOnce(predict_net)
194
+ except RuntimeError as e:
195
+ logger.warning("Encountered RuntimeError: \n{}".format(str(e)))
196
+
197
+ ws_blobs = {b: ws.FetchBlob(b) for b in ws.Blobs()}
198
+ blob_sizes = {b: ws_blobs[b].shape for b in ws_blobs if isinstance(ws_blobs[b], np.ndarray)}
199
+
200
+ logger.info("Saving graph with blob shapes to {} ...".format(graph_save_path))
201
+ save_graph(predict_net, graph_save_path, op_only=False, blob_sizes=blob_sizes)
202
+
203
+ return ws_blobs
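For orientation, here is a minimal sketch (not part of the uploaded file) of how the two entry points above might be combined; `c2_compatible_model` and `tensor_inputs` are assumptions, standing for a caffe2-compatible wrapper such as those defined in caffe2_modeling.py below and the tuple returned by its get_caffe2_inputs().

# Hedged usage sketch; assumes the bundled detectron2/ package is importable.
from detectron2.export.caffe2_export import (
    export_caffe2_detection_model,
    run_and_save_graph,
)

# `c2_compatible_model` / `tensor_inputs` are assumed, see the note above.
predict_net, init_net = export_caffe2_detection_model(c2_compatible_model, tensor_inputs)
# Optionally run the exported nets once and save a graph annotated with blob shapes.
run_and_save_graph(predict_net, init_net, tensor_inputs, "caffe2_model_graph.svg")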
CatVTON/detectron2/export/caffe2_inference.py ADDED
@@ -0,0 +1,161 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ import logging
4
+ import numpy as np
5
+ from itertools import count
6
+ import torch
7
+ from caffe2.proto import caffe2_pb2
8
+ from caffe2.python import core
9
+
10
+ from .caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP, convert_batched_inputs_to_c2_format
11
+ from .shared import ScopedWS, get_pb_arg_vali, get_pb_arg_vals, infer_device_type
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ # ===== ref: mobile-vision predictor's 'Caffe2Wrapper' class ======
17
+ class ProtobufModel(torch.nn.Module):
18
+ """
19
+ Wrapper of a caffe2's protobuf model.
20
+ It works just like nn.Module, but running caffe2 under the hood.
21
+ Input/Output are tuple[tensor] that match the caffe2 net's external_input/output.
22
+ """
23
+
24
+ _ids = count(0)
25
+
26
+ def __init__(self, predict_net, init_net):
27
+ logger.info(f"Initializing ProtobufModel for: {predict_net.name} ...")
28
+ super().__init__()
29
+ assert isinstance(predict_net, caffe2_pb2.NetDef)
30
+ assert isinstance(init_net, caffe2_pb2.NetDef)
31
+ # create unique temporary workspace for each instance
32
+ self.ws_name = "__tmp_ProtobufModel_{}__".format(next(self._ids))
33
+ self.net = core.Net(predict_net)
34
+
35
+ logger.info("Running init_net once to fill the parameters ...")
36
+ with ScopedWS(self.ws_name, is_reset=True, is_cleanup=False) as ws:
37
+ ws.RunNetOnce(init_net)
38
+ uninitialized_external_input = []
39
+ for blob in self.net.Proto().external_input:
40
+ if blob not in ws.Blobs():
41
+ uninitialized_external_input.append(blob)
42
+ ws.CreateBlob(blob)
43
+ ws.CreateNet(self.net)
44
+
45
+ self._error_msgs = set()
46
+ self._input_blobs = uninitialized_external_input
47
+
48
+ def _infer_output_devices(self, inputs):
49
+ """
50
+ Returns:
51
+ list[str]: list of device for each external output
52
+ """
53
+
54
+ def _get_device_type(torch_tensor):
55
+ assert torch_tensor.device.type in ["cpu", "cuda"]
56
+ assert torch_tensor.device.index == 0
57
+ return torch_tensor.device.type
58
+
59
+ predict_net = self.net.Proto()
60
+ input_device_types = {
61
+ (name, 0): _get_device_type(tensor) for name, tensor in zip(self._input_blobs, inputs)
62
+ }
63
+ device_type_map = infer_device_type(
64
+ predict_net, known_status=input_device_types, device_name_style="pytorch"
65
+ )
66
+ ssa, versions = core.get_ssa(predict_net)
67
+ versioned_outputs = [(name, versions[name]) for name in predict_net.external_output]
68
+ output_devices = [device_type_map[outp] for outp in versioned_outputs]
69
+ return output_devices
70
+
71
+ def forward(self, inputs):
72
+ """
73
+ Args:
74
+ inputs (tuple[torch.Tensor])
75
+
76
+ Returns:
77
+ tuple[torch.Tensor]
78
+ """
79
+ assert len(inputs) == len(self._input_blobs), (
80
+ f"Length of inputs ({len(inputs)}) "
81
+ f"doesn't match the required input blobs: {self._input_blobs}"
82
+ )
83
+
84
+ with ScopedWS(self.ws_name, is_reset=False, is_cleanup=False) as ws:
85
+ for b, tensor in zip(self._input_blobs, inputs):
86
+ ws.FeedBlob(b, tensor)
87
+
88
+ try:
89
+ ws.RunNet(self.net.Proto().name)
90
+ except RuntimeError as e:
91
+ if str(e) not in self._error_msgs:
92
+ self._error_msgs.add(str(e))
93
+ logger.warning("Encountered new RuntimeError: \n{}".format(str(e)))
94
+ logger.warning("Catch the error and use partial results.")
95
+
96
+ c2_outputs = [ws.FetchBlob(b) for b in self.net.Proto().external_output]
97
+ # Remove outputs of current run, this is necessary in order to
98
+ # prevent fetching the result from previous run if the model fails
99
+ # in the middle.
100
+ for b in self.net.Proto().external_output:
101
+ # Need to create an uninitialized blob to make the net runnable.
102
+ # This is "equivalent" to: ws.RemoveBlob(b) then ws.CreateBlob(b),
103
+ # but there's no such API.
104
+ ws.FeedBlob(b, f"{b}, a C++ native class of type nullptr (uninitialized).")
105
+
106
+ # Cast output to torch.Tensor on the desired device
107
+ output_devices = (
108
+ self._infer_output_devices(inputs)
109
+ if any(t.device.type != "cpu" for t in inputs)
110
+ else ["cpu" for _ in self.net.Proto().external_output]
111
+ )
112
+
113
+ outputs = []
114
+ for name, c2_output, device in zip(
115
+ self.net.Proto().external_output, c2_outputs, output_devices
116
+ ):
117
+ if not isinstance(c2_output, np.ndarray):
118
+ raise RuntimeError(
119
+ "Invalid output for blob {}, received: {}".format(name, c2_output)
120
+ )
121
+ outputs.append(torch.tensor(c2_output).to(device=device))
122
+ return tuple(outputs)
123
+
124
+
125
+ class ProtobufDetectionModel(torch.nn.Module):
126
+ """
127
+ A class that works just like a pytorch meta arch in terms of inference, but runs a
128
+ caffe2 model under the hood.
129
+ """
130
+
131
+ def __init__(self, predict_net, init_net, *, convert_outputs=None):
132
+ """
133
+ Args:
134
+ predict_net, init_net (core.Net): caffe2 nets
135
+ convert_outputs (callable): a function that converts caffe2
136
+ outputs to the same format of the original pytorch model.
137
+ By default, use the one defined in the caffe2 meta_arch.
138
+ """
139
+ super().__init__()
140
+ self.protobuf_model = ProtobufModel(predict_net, init_net)
141
+ self.size_divisibility = get_pb_arg_vali(predict_net, "size_divisibility", 0)
142
+ self.device = get_pb_arg_vals(predict_net, "device", b"cpu").decode("ascii")
143
+
144
+ if convert_outputs is None:
145
+ meta_arch = get_pb_arg_vals(predict_net, "meta_architecture", b"GeneralizedRCNN")
146
+ meta_arch = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[meta_arch.decode("ascii")]
147
+ self._convert_outputs = meta_arch.get_outputs_converter(predict_net, init_net)
148
+ else:
149
+ self._convert_outputs = convert_outputs
150
+
151
+ def _convert_inputs(self, batched_inputs):
152
+ # currently all models convert inputs in the same way
153
+ return convert_batched_inputs_to_c2_format(
154
+ batched_inputs, self.size_divisibility, self.device
155
+ )
156
+
157
+ def forward(self, batched_inputs):
158
+ c2_inputs = self._convert_inputs(batched_inputs)
159
+ c2_results = self.protobuf_model(c2_inputs)
160
+ c2_results = dict(zip(self.protobuf_model.net.Proto().external_output, c2_results))
161
+ return self._convert_outputs(batched_inputs, c2_inputs, c2_results)
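A minimal, assumption-laden sketch (not part of the file) of running the exported nets with detectron2-style inputs; `predict_net` and `init_net` are assumed to be the NetDefs produced by caffe2_export.py above, with the extra metadata already encoded.

import torch
from detectron2.export.caffe2_inference import ProtobufDetectionModel

wrapper = ProtobufDetectionModel(predict_net, init_net)
# One image in detectron2's standard list-of-dict format (values are made up).
batched_inputs = [{"image": torch.rand(3, 480, 640) * 255, "height": 480, "width": 640}]
outputs = wrapper(batched_inputs)  # same output format as the original pytorch model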
CatVTON/detectron2/export/caffe2_modeling.py ADDED
@@ -0,0 +1,420 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ import functools
4
+ import io
5
+ import struct
6
+ import types
7
+ import torch
8
+
9
+ from detectron2.modeling import meta_arch
10
+ from detectron2.modeling.box_regression import Box2BoxTransform
11
+ from detectron2.modeling.roi_heads import keypoint_head
12
+ from detectron2.structures import Boxes, ImageList, Instances, RotatedBoxes
13
+
14
+ from .c10 import Caffe2Compatible
15
+ from .caffe2_patch import ROIHeadsPatcher, patch_generalized_rcnn
16
+ from .shared import (
17
+ alias,
18
+ check_set_pb_arg,
19
+ get_pb_arg_floats,
20
+ get_pb_arg_valf,
21
+ get_pb_arg_vali,
22
+ get_pb_arg_vals,
23
+ mock_torch_nn_functional_interpolate,
24
+ )
25
+
26
+
27
+ def assemble_rcnn_outputs_by_name(image_sizes, tensor_outputs, force_mask_on=False):
28
+ """
29
+ A function to assemble caffe2 model's outputs (i.e. Dict[str, Tensor])
30
+ to detectron2's format (i.e. a list of Instances instances).
31
+ This only works when the model follows the Caffe2 detectron's naming convention.
32
+
33
+ Args:
34
+ image_sizes (List[List[int, int]]): [H, W] of every image.
35
+ tensor_outputs (Dict[str, Tensor]): external_output to its tensor.
36
+
37
+ force_mask_on (bool): if true, it makes sure there'll be pred_masks even
38
+ if the mask is not found from tensor_outputs (usually due to model crash)
39
+ """
40
+
41
+ results = [Instances(image_size) for image_size in image_sizes]
42
+
43
+ batch_splits = tensor_outputs.get("batch_splits", None)
44
+ if batch_splits:
45
+ raise NotImplementedError()
46
+ assert len(image_sizes) == 1
47
+ result = results[0]
48
+
49
+ bbox_nms = tensor_outputs["bbox_nms"]
50
+ score_nms = tensor_outputs["score_nms"]
51
+ class_nms = tensor_outputs["class_nms"]
52
+ # Detection will always succeed because Conv supports 0-batch
53
+ assert bbox_nms is not None
54
+ assert score_nms is not None
55
+ assert class_nms is not None
56
+ if bbox_nms.shape[1] == 5:
57
+ result.pred_boxes = RotatedBoxes(bbox_nms)
58
+ else:
59
+ result.pred_boxes = Boxes(bbox_nms)
60
+ result.scores = score_nms
61
+ result.pred_classes = class_nms.to(torch.int64)
62
+
63
+ mask_fcn_probs = tensor_outputs.get("mask_fcn_probs", None)
64
+ if mask_fcn_probs is not None:
65
+ # finish the mask pred
66
+ mask_probs_pred = mask_fcn_probs
67
+ num_masks = mask_probs_pred.shape[0]
68
+ class_pred = result.pred_classes
69
+ indices = torch.arange(num_masks, device=class_pred.device)
70
+ mask_probs_pred = mask_probs_pred[indices, class_pred][:, None]
71
+ result.pred_masks = mask_probs_pred
72
+ elif force_mask_on:
73
+ # NOTE: there's no way to know the height/width of mask here, it won't be
74
+ # used anyway when batch size is 0, so just set them to 0.
75
+ result.pred_masks = torch.zeros([0, 1, 0, 0], dtype=torch.uint8)
76
+
77
+ keypoints_out = tensor_outputs.get("keypoints_out", None)
78
+ kps_score = tensor_outputs.get("kps_score", None)
79
+ if keypoints_out is not None:
80
+ # keypoints_out: [N, 4, #keypoints], where 4 is in order of (x, y, score, prob)
81
+ keypoints_tensor = keypoints_out
82
+ # NOTE: it's possible that prob is not calculated if "should_output_softmax"
83
+ # is set to False in HeatmapMaxKeypoint, so just using raw score, seems
84
+ # it doesn't affect mAP. TODO: check more carefully.
85
+ keypoint_xyp = keypoints_tensor.transpose(1, 2)[:, :, [0, 1, 2]]
86
+ result.pred_keypoints = keypoint_xyp
87
+ elif kps_score is not None:
88
+ # keypoint heatmap to sparse data structure
89
+ pred_keypoint_logits = kps_score
90
+ keypoint_head.keypoint_rcnn_inference(pred_keypoint_logits, [result])
91
+
92
+ return results
93
+
94
+
95
+ def _cast_to_f32(f64):
96
+ return struct.unpack("f", struct.pack("f", f64))[0]
97
+
98
+
99
+ def set_caffe2_compatible_tensor_mode(model, enable=True):
100
+ def _fn(m):
101
+ if isinstance(m, Caffe2Compatible):
102
+ m.tensor_mode = enable
103
+
104
+ model.apply(_fn)
105
+
106
+
107
+ def convert_batched_inputs_to_c2_format(batched_inputs, size_divisibility, device):
108
+ """
109
+ See get_caffe2_inputs() below.
110
+ """
111
+ assert all(isinstance(x, dict) for x in batched_inputs)
112
+ assert all(x["image"].dim() == 3 for x in batched_inputs)
113
+
114
+ images = [x["image"] for x in batched_inputs]
115
+ images = ImageList.from_tensors(images, size_divisibility)
116
+
117
+ im_info = []
118
+ for input_per_image, image_size in zip(batched_inputs, images.image_sizes):
119
+ target_height = input_per_image.get("height", image_size[0])
120
+ target_width = input_per_image.get("width", image_size[1]) # noqa
121
+ # NOTE: The scale inside im_info is kept as convention and for providing
122
+ # post-processing information if further processing is needed. For
123
+ # current Caffe2 model definitions that don't include post-processing inside
124
+ # the model, this number is not used.
125
+ # NOTE: There can be a slight difference between width and height
126
+ # scales; using a single number can result in numerical differences
127
+ # compared with D2's post-processing.
128
+ scale = target_height / image_size[0]
129
+ im_info.append([image_size[0], image_size[1], scale])
130
+ im_info = torch.Tensor(im_info)
131
+
132
+ return images.tensor.to(device), im_info.to(device)
133
+
134
+
135
+ class Caffe2MetaArch(Caffe2Compatible, torch.nn.Module):
136
+ """
137
+ Base class for caffe2-compatible implementation of a meta architecture.
138
+ The forward is traceable and its traced graph can be converted to caffe2
139
+ graph through ONNX.
140
+ """
141
+
142
+ def __init__(self, cfg, torch_model, enable_tensor_mode=True):
143
+ """
144
+ Args:
145
+ cfg (CfgNode):
146
+ torch_model (nn.Module): the detectron2 model (meta_arch) to be
147
+ converted.
148
+ """
149
+ super().__init__()
150
+ self._wrapped_model = torch_model
151
+ self.eval()
152
+ set_caffe2_compatible_tensor_mode(self, enable_tensor_mode)
153
+
154
+ def get_caffe2_inputs(self, batched_inputs):
155
+ """
156
+ Convert pytorch-style structured inputs to caffe2-style inputs that
157
+ are tuples of tensors.
158
+
159
+ Args:
160
+ batched_inputs (list[dict]): inputs to a detectron2 model
161
+ in its standard format. Each dict has "image" (CHW tensor), and optionally
162
+ "height" and "width".
163
+
164
+ Returns:
165
+ tuple[Tensor]:
166
+ tuple of tensors that will be the inputs to the
167
+ :meth:`forward` method. For existing models, the first
168
+ is an NCHW tensor (padded and batched); the second is
169
+ a im_info Nx3 tensor, where the rows are
170
+ (height, width, unused legacy parameter)
171
+ """
172
+ return convert_batched_inputs_to_c2_format(
173
+ batched_inputs,
174
+ self._wrapped_model.backbone.size_divisibility,
175
+ self._wrapped_model.device,
176
+ )
177
+
178
+ def encode_additional_info(self, predict_net, init_net):
179
+ """
180
+ Save extra metadata that will be used by inference in the output protobuf.
181
+ """
182
+ pass
183
+
184
+ def forward(self, inputs):
185
+ """
186
+ Run the forward in caffe2-style. It has to use caffe2-compatible ops
187
+ and the method will be used for tracing.
188
+
189
+ Args:
190
+ inputs (tuple[Tensor]): inputs defined by :meth:`get_caffe2_input`.
191
+ They will be the inputs of the converted caffe2 graph.
192
+
193
+ Returns:
194
+ tuple[Tensor]: output tensors. They will be the outputs of the
195
+ converted caffe2 graph.
196
+ """
197
+ raise NotImplementedError
198
+
199
+ def _caffe2_preprocess_image(self, inputs):
200
+ """
201
+ Caffe2 implementation of preprocess_image, which is called inside each MetaArch's forward.
202
+ It normalizes the input images, and the final caffe2 graph assumes the
203
+ inputs have been batched already.
204
+ """
205
+ data, im_info = inputs
206
+ data = alias(data, "data")
207
+ im_info = alias(im_info, "im_info")
208
+ mean, std = self._wrapped_model.pixel_mean, self._wrapped_model.pixel_std
209
+ normalized_data = (data - mean) / std
210
+ normalized_data = alias(normalized_data, "normalized_data")
211
+
212
+ # Pack (data, im_info) into ImageList which is recognized by self.inference.
213
+ images = ImageList(tensor=normalized_data, image_sizes=im_info)
214
+ return images
215
+
216
+ @staticmethod
217
+ def get_outputs_converter(predict_net, init_net):
218
+ """
219
+ Creates a function that converts outputs of the caffe2 model to
220
+ detectron2's standard format.
221
+ The function uses information in `predict_net` and `init_net` that are
222
+ available at inference time. Therefore the function logic can be used in inference.
223
+
224
+ The returned function has the following signature:
225
+
226
+ def convert(batched_inputs, c2_inputs, c2_results) -> detectron2_outputs
227
+
228
+ Where
229
+
230
+ * batched_inputs (list[dict]): the original input format of the meta arch
231
+ * c2_inputs (tuple[Tensor]): the caffe2 inputs.
232
+ * c2_results (dict[str, Tensor]): the caffe2 output format,
233
+ corresponding to the outputs of the :meth:`forward` function.
234
+ * detectron2_outputs: the original output format of the meta arch.
235
+
236
+ This function can be used to compare the outputs of the original meta arch and
237
+ the converted caffe2 graph.
238
+
239
+ Returns:
240
+ callable: a callable of the above signature.
241
+ """
242
+ raise NotImplementedError
243
+
244
+
245
+ class Caffe2GeneralizedRCNN(Caffe2MetaArch):
246
+ def __init__(self, cfg, torch_model, enable_tensor_mode=True):
247
+ assert isinstance(torch_model, meta_arch.GeneralizedRCNN)
248
+ torch_model = patch_generalized_rcnn(torch_model)
249
+ super().__init__(cfg, torch_model, enable_tensor_mode)
250
+
251
+ try:
252
+ use_heatmap_max_keypoint = cfg.EXPORT_CAFFE2.USE_HEATMAP_MAX_KEYPOINT
253
+ except AttributeError:
254
+ use_heatmap_max_keypoint = False
255
+ self.roi_heads_patcher = ROIHeadsPatcher(
256
+ self._wrapped_model.roi_heads, use_heatmap_max_keypoint
257
+ )
258
+ if self.tensor_mode:
259
+ self.roi_heads_patcher.patch_roi_heads()
260
+
261
+ def encode_additional_info(self, predict_net, init_net):
262
+ size_divisibility = self._wrapped_model.backbone.size_divisibility
263
+ check_set_pb_arg(predict_net, "size_divisibility", "i", size_divisibility)
264
+ check_set_pb_arg(
265
+ predict_net, "device", "s", str.encode(str(self._wrapped_model.device), "ascii")
266
+ )
267
+ check_set_pb_arg(predict_net, "meta_architecture", "s", b"GeneralizedRCNN")
268
+
269
+ @mock_torch_nn_functional_interpolate()
270
+ def forward(self, inputs):
271
+ if not self.tensor_mode:
272
+ return self._wrapped_model.inference(inputs)
273
+ images = self._caffe2_preprocess_image(inputs)
274
+ features = self._wrapped_model.backbone(images.tensor)
275
+ proposals, _ = self._wrapped_model.proposal_generator(images, features)
276
+ detector_results, _ = self._wrapped_model.roi_heads(images, features, proposals)
277
+ return tuple(detector_results[0].flatten())
278
+
279
+ @staticmethod
280
+ def get_outputs_converter(predict_net, init_net):
281
+ def f(batched_inputs, c2_inputs, c2_results):
282
+ _, im_info = c2_inputs
283
+ image_sizes = [[int(im[0]), int(im[1])] for im in im_info]
284
+ results = assemble_rcnn_outputs_by_name(image_sizes, c2_results)
285
+ return meta_arch.GeneralizedRCNN._postprocess(results, batched_inputs, image_sizes)
286
+
287
+ return f
288
+
289
+
290
+ class Caffe2RetinaNet(Caffe2MetaArch):
291
+ def __init__(self, cfg, torch_model):
292
+ assert isinstance(torch_model, meta_arch.RetinaNet)
293
+ super().__init__(cfg, torch_model)
294
+
295
+ @mock_torch_nn_functional_interpolate()
296
+ def forward(self, inputs):
297
+ assert self.tensor_mode
298
+ images = self._caffe2_preprocess_image(inputs)
299
+
300
+ # explicitly return the images sizes to avoid removing "im_info" by ONNX
301
+ # since it's not used in the forward path
302
+ return_tensors = [images.image_sizes]
303
+
304
+ features = self._wrapped_model.backbone(images.tensor)
305
+ features = [features[f] for f in self._wrapped_model.head_in_features]
306
+ for i, feature_i in enumerate(features):
307
+ features[i] = alias(feature_i, "feature_{}".format(i), is_backward=True)
308
+ return_tensors.append(features[i])
309
+
310
+ pred_logits, pred_anchor_deltas = self._wrapped_model.head(features)
311
+ for i, (box_cls_i, box_delta_i) in enumerate(zip(pred_logits, pred_anchor_deltas)):
312
+ return_tensors.append(alias(box_cls_i, "box_cls_{}".format(i)))
313
+ return_tensors.append(alias(box_delta_i, "box_delta_{}".format(i)))
314
+
315
+ return tuple(return_tensors)
316
+
317
+ def encode_additional_info(self, predict_net, init_net):
318
+ size_divisibility = self._wrapped_model.backbone.size_divisibility
319
+ check_set_pb_arg(predict_net, "size_divisibility", "i", size_divisibility)
320
+ check_set_pb_arg(
321
+ predict_net, "device", "s", str.encode(str(self._wrapped_model.device), "ascii")
322
+ )
323
+ check_set_pb_arg(predict_net, "meta_architecture", "s", b"RetinaNet")
324
+
325
+ # Inference parameters:
326
+ check_set_pb_arg(
327
+ predict_net, "score_threshold", "f", _cast_to_f32(self._wrapped_model.test_score_thresh)
328
+ )
329
+ check_set_pb_arg(
330
+ predict_net, "topk_candidates", "i", self._wrapped_model.test_topk_candidates
331
+ )
332
+ check_set_pb_arg(
333
+ predict_net, "nms_threshold", "f", _cast_to_f32(self._wrapped_model.test_nms_thresh)
334
+ )
335
+ check_set_pb_arg(
336
+ predict_net,
337
+ "max_detections_per_image",
338
+ "i",
339
+ self._wrapped_model.max_detections_per_image,
340
+ )
341
+
342
+ check_set_pb_arg(
343
+ predict_net,
344
+ "bbox_reg_weights",
345
+ "floats",
346
+ [_cast_to_f32(w) for w in self._wrapped_model.box2box_transform.weights],
347
+ )
348
+ self._encode_anchor_generator_cfg(predict_net)
349
+
350
+ def _encode_anchor_generator_cfg(self, predict_net):
351
+ # serialize anchor_generator for future use
352
+ serialized_anchor_generator = io.BytesIO()
353
+ torch.save(self._wrapped_model.anchor_generator, serialized_anchor_generator)
354
+ # Ideally we can put anchor generating inside the model, then we don't
355
+ # need to store this information.
356
+ bytes = serialized_anchor_generator.getvalue()
357
+ check_set_pb_arg(predict_net, "serialized_anchor_generator", "s", bytes)
358
+
359
+ @staticmethod
360
+ def get_outputs_converter(predict_net, init_net):
361
+ self = types.SimpleNamespace()
362
+ serialized_anchor_generator = io.BytesIO(
363
+ get_pb_arg_vals(predict_net, "serialized_anchor_generator", None)
364
+ )
365
+ self.anchor_generator = torch.load(serialized_anchor_generator)
366
+ bbox_reg_weights = get_pb_arg_floats(predict_net, "bbox_reg_weights", None)
367
+ self.box2box_transform = Box2BoxTransform(weights=tuple(bbox_reg_weights))
368
+ self.test_score_thresh = get_pb_arg_valf(predict_net, "score_threshold", None)
369
+ self.test_topk_candidates = get_pb_arg_vali(predict_net, "topk_candidates", None)
370
+ self.test_nms_thresh = get_pb_arg_valf(predict_net, "nms_threshold", None)
371
+ self.max_detections_per_image = get_pb_arg_vali(
372
+ predict_net, "max_detections_per_image", None
373
+ )
374
+
375
+ # hack to reuse inference code from RetinaNet
376
+ for meth in [
377
+ "forward_inference",
378
+ "inference_single_image",
379
+ "_transpose_dense_predictions",
380
+ "_decode_multi_level_predictions",
381
+ "_decode_per_level_predictions",
382
+ ]:
383
+ setattr(self, meth, functools.partial(getattr(meta_arch.RetinaNet, meth), self))
384
+
385
+ def f(batched_inputs, c2_inputs, c2_results):
386
+ _, im_info = c2_inputs
387
+ image_sizes = [[int(im[0]), int(im[1])] for im in im_info]
388
+ dummy_images = ImageList(
389
+ torch.randn(
390
+ (
391
+ len(im_info),
392
+ 3,
393
+ )
394
+ + tuple(image_sizes[0])
395
+ ),
396
+ image_sizes,
397
+ )
398
+
399
+ num_features = len([x for x in c2_results.keys() if x.startswith("box_cls_")])
400
+ pred_logits = [c2_results["box_cls_{}".format(i)] for i in range(num_features)]
401
+ pred_anchor_deltas = [c2_results["box_delta_{}".format(i)] for i in range(num_features)]
402
+
403
+ # For each feature level, feature should have the same batch size and
404
+ # spatial dimension as the box_cls and box_delta.
405
+ dummy_features = [x.clone()[:, 0:0, :, :] for x in pred_logits]
406
+ # self.num_classes can be inferred
407
+ self.num_classes = pred_logits[0].shape[1] // (pred_anchor_deltas[0].shape[1] // 4)
408
+
409
+ results = self.forward_inference(
410
+ dummy_images, dummy_features, [pred_logits, pred_anchor_deltas]
411
+ )
412
+ return meta_arch.GeneralizedRCNN._postprocess(results, batched_inputs, image_sizes)
413
+
414
+ return f
415
+
416
+
417
+ META_ARCH_CAFFE2_EXPORT_TYPE_MAP = {
418
+ "GeneralizedRCNN": Caffe2GeneralizedRCNN,
419
+ "RetinaNet": Caffe2RetinaNet,
420
+ }
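As a rough sketch of how a wrapper from this table might be instantiated (`cfg`, `torch_model` and `batched_inputs` below are assumptions: an ordinary detectron2 config, its built GeneralizedRCNN in eval mode, and its usual list-of-dict inputs):

# Hedged sketch only; mirrors the constructor and get_caffe2_inputs() defined above.
C2MetaArch = META_ARCH_CAFFE2_EXPORT_TYPE_MAP["GeneralizedRCNN"]
c2_compatible_model = C2MetaArch(cfg, torch_model)
tensor_inputs = c2_compatible_model.get_caffe2_inputs(batched_inputs)
# These two objects are what caffe2_export.export_caffe2_detection_model expects.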
CatVTON/detectron2/export/caffe2_patch.py ADDED
@@ -0,0 +1,189 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ import contextlib
4
+ from unittest import mock
5
+ import torch
6
+
7
+ from detectron2.modeling import poolers
8
+ from detectron2.modeling.proposal_generator import rpn
9
+ from detectron2.modeling.roi_heads import keypoint_head, mask_head
10
+ from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
11
+
12
+ from .c10 import (
13
+ Caffe2Compatible,
14
+ Caffe2FastRCNNOutputsInference,
15
+ Caffe2KeypointRCNNInference,
16
+ Caffe2MaskRCNNInference,
17
+ Caffe2ROIPooler,
18
+ Caffe2RPN,
19
+ caffe2_fast_rcnn_outputs_inference,
20
+ caffe2_keypoint_rcnn_inference,
21
+ caffe2_mask_rcnn_inference,
22
+ )
23
+
24
+
25
+ class GenericMixin:
26
+ pass
27
+
28
+
29
+ class Caffe2CompatibleConverter:
30
+ """
31
+ A GenericUpdater which implements the `create_from` interface by modifying the
32
+ module object and assigning it another class, replaceCls.
33
+ """
34
+
35
+ def __init__(self, replaceCls):
36
+ self.replaceCls = replaceCls
37
+
38
+ def create_from(self, module):
39
+ # update module's class to the new class
40
+ assert isinstance(module, torch.nn.Module)
41
+ if issubclass(self.replaceCls, GenericMixin):
42
+ # replaceCls should act as mixin, create a new class on-the-fly
43
+ new_class = type(
44
+ "{}MixedWith{}".format(self.replaceCls.__name__, module.__class__.__name__),
45
+ (self.replaceCls, module.__class__),
46
+ {}, # {"new_method": lambda self: ...},
47
+ )
48
+ module.__class__ = new_class
49
+ else:
50
+ # replaceCls is complete class, this allow arbitrary class swap
51
+ module.__class__ = self.replaceCls
52
+
53
+ # initialize Caffe2Compatible
54
+ if isinstance(module, Caffe2Compatible):
55
+ module.tensor_mode = False
56
+
57
+ return module
58
+
59
+
60
+ def patch(model, target, updater, *args, **kwargs):
61
+ """
62
+ recursively (post-order) update all modules with the target type and its
63
+ subclasses, making an initialization/composition/inheritance/... via the
64
+ updater.create_from.
65
+ """
66
+ for name, module in model.named_children():
67
+ model._modules[name] = patch(module, target, updater, *args, **kwargs)
68
+ if isinstance(model, target):
69
+ return updater.create_from(model, *args, **kwargs)
70
+ return model
71
+
72
+
73
+ def patch_generalized_rcnn(model):
74
+ ccc = Caffe2CompatibleConverter
75
+ model = patch(model, rpn.RPN, ccc(Caffe2RPN))
76
+ model = patch(model, poolers.ROIPooler, ccc(Caffe2ROIPooler))
77
+
78
+ return model
79
+
80
+
81
+ @contextlib.contextmanager
82
+ def mock_fastrcnn_outputs_inference(
83
+ tensor_mode, check=True, box_predictor_type=FastRCNNOutputLayers
84
+ ):
85
+ with mock.patch.object(
86
+ box_predictor_type,
87
+ "inference",
88
+ autospec=True,
89
+ side_effect=Caffe2FastRCNNOutputsInference(tensor_mode),
90
+ ) as mocked_func:
91
+ yield
92
+ if check:
93
+ assert mocked_func.call_count > 0
94
+
95
+
96
+ @contextlib.contextmanager
97
+ def mock_mask_rcnn_inference(tensor_mode, patched_module, check=True):
98
+ with mock.patch(
99
+ "{}.mask_rcnn_inference".format(patched_module), side_effect=Caffe2MaskRCNNInference()
100
+ ) as mocked_func:
101
+ yield
102
+ if check:
103
+ assert mocked_func.call_count > 0
104
+
105
+
106
+ @contextlib.contextmanager
107
+ def mock_keypoint_rcnn_inference(tensor_mode, patched_module, use_heatmap_max_keypoint, check=True):
108
+ with mock.patch(
109
+ "{}.keypoint_rcnn_inference".format(patched_module),
110
+ side_effect=Caffe2KeypointRCNNInference(use_heatmap_max_keypoint),
111
+ ) as mocked_func:
112
+ yield
113
+ if check:
114
+ assert mocked_func.call_count > 0
115
+
116
+
117
+ class ROIHeadsPatcher:
118
+ def __init__(self, heads, use_heatmap_max_keypoint):
119
+ self.heads = heads
120
+ self.use_heatmap_max_keypoint = use_heatmap_max_keypoint
121
+ self.previous_patched = {}
122
+
123
+ @contextlib.contextmanager
124
+ def mock_roi_heads(self, tensor_mode=True):
125
+ """
126
+ Patching several inference functions inside ROIHeads and its subclasses
127
+
128
+ Args:
129
+ tensor_mode (bool): whether the inputs/outputs are caffe2's tensor
130
+ format or not. Defaults to True.
131
+ """
132
+ # NOTE: this requires the `keypoint_rcnn_inference` and `mask_rcnn_inference`
133
+ # to be called inside the same file as BaseXxxHead due to using mock.patch.
134
+ kpt_heads_mod = keypoint_head.BaseKeypointRCNNHead.__module__
135
+ mask_head_mod = mask_head.BaseMaskRCNNHead.__module__
136
+
137
+ mock_ctx_managers = [
138
+ mock_fastrcnn_outputs_inference(
139
+ tensor_mode=tensor_mode,
140
+ check=True,
141
+ box_predictor_type=type(self.heads.box_predictor),
142
+ )
143
+ ]
144
+ if getattr(self.heads, "keypoint_on", False):
145
+ mock_ctx_managers += [
146
+ mock_keypoint_rcnn_inference(
147
+ tensor_mode, kpt_heads_mod, self.use_heatmap_max_keypoint
148
+ )
149
+ ]
150
+ if getattr(self.heads, "mask_on", False):
151
+ mock_ctx_managers += [mock_mask_rcnn_inference(tensor_mode, mask_head_mod)]
152
+
153
+ with contextlib.ExitStack() as stack: # python 3.3+
154
+ for mgr in mock_ctx_managers:
155
+ stack.enter_context(mgr)
156
+ yield
157
+
158
+ def patch_roi_heads(self, tensor_mode=True):
159
+ self.previous_patched["box_predictor"] = self.heads.box_predictor.inference
160
+ self.previous_patched["keypoint_rcnn"] = keypoint_head.keypoint_rcnn_inference
161
+ self.previous_patched["mask_rcnn"] = mask_head.mask_rcnn_inference
162
+
163
+ def patched_fastrcnn_outputs_inference(predictions, proposal):
164
+ return caffe2_fast_rcnn_outputs_inference(
165
+ True, self.heads.box_predictor, predictions, proposal
166
+ )
167
+
168
+ self.heads.box_predictor.inference = patched_fastrcnn_outputs_inference
169
+
170
+ if getattr(self.heads, "keypoint_on", False):
171
+
172
+ def patched_keypoint_rcnn_inference(pred_keypoint_logits, pred_instances):
173
+ return caffe2_keypoint_rcnn_inference(
174
+ self.use_heatmap_max_keypoint, pred_keypoint_logits, pred_instances
175
+ )
176
+
177
+ keypoint_head.keypoint_rcnn_inference = patched_keypoint_rcnn_inference
178
+
179
+ if getattr(self.heads, "mask_on", False):
180
+
181
+ def patched_mask_rcnn_inference(pred_mask_logits, pred_instances):
182
+ return caffe2_mask_rcnn_inference(pred_mask_logits, pred_instances)
183
+
184
+ mask_head.mask_rcnn_inference = patched_mask_rcnn_inference
185
+
186
+ def unpatch_roi_heads(self):
187
+ self.heads.box_predictor.inference = self.previous_patched["box_predictor"]
188
+ keypoint_head.keypoint_rcnn_inference = self.previous_patched["keypoint_rcnn"]
189
+ mask_head.mask_rcnn_inference = self.previous_patched["mask_rcnn"]
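A hedged sketch of how the patching utilities above fit together, mirroring what Caffe2GeneralizedRCNN does in caffe2_modeling.py; `torch_model` is an assumed GeneralizedRCNN instance, not defined in this file.

# Swap RPN / ROIPooler for caffe2-compatible classes, then patch the ROI-head
# inference functions in place (mock_roi_heads() offers a temporary alternative).
model = patch_generalized_rcnn(torch_model)
patcher = ROIHeadsPatcher(model.roi_heads, use_heatmap_max_keypoint=False)
patcher.patch_roi_heads()    # box/mask/keypoint inference now emit caffe2-style tensors
# ... trace or export the patched model here ...
patcher.unpatch_roi_heads()  # restore the original inference functions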
CatVTON/detectron2/export/flatten.py ADDED
@@ -0,0 +1,330 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import collections
3
+ from dataclasses import dataclass
4
+ from typing import Callable, List, Optional, Tuple
5
+ import torch
6
+ from torch import nn
7
+
8
+ from detectron2.structures import Boxes, Instances, ROIMasks
9
+ from detectron2.utils.registry import _convert_target_to_string, locate
10
+
11
+ from .torchscript_patch import patch_builtin_len
12
+
13
+
14
+ @dataclass
15
+ class Schema:
16
+ """
17
+ A Schema defines how to flatten a possibly hierarchical object into a tuple of
18
+ primitive objects, so it can be used as inputs/outputs of PyTorch's tracing.
19
+
20
+ PyTorch does not support tracing a function that produces rich output
21
+ structures (e.g. dict, Instances, Boxes). To trace such a function, we
22
+ flatten the rich object into a tuple of tensors, and return this tuple of tensors
23
+ instead. Meanwhile, we also need to know how to "rebuild" the original object
24
+ from the flattened results, so we can evaluate the flattened results.
25
+ A Schema defines how to flatten an object, and while flattening it, it records
26
+ necessary schemas so that the object can be rebuilt using the flattened outputs.
27
+
28
+ The flattened object and the schema object are returned by the ``.flatten`` classmethod.
29
+ Then the original object can be rebuilt with the ``__call__`` method of schema.
30
+
31
+ A Schema is a dataclass that can be serialized easily.
32
+ """
33
+
34
+ # inspired by FetchMapper in tensorflow/python/client/session.py
35
+
36
+ @classmethod
37
+ def flatten(cls, obj):
38
+ raise NotImplementedError
39
+
40
+ def __call__(self, values):
41
+ raise NotImplementedError
42
+
43
+ @staticmethod
44
+ def _concat(values):
45
+ ret = ()
46
+ sizes = []
47
+ for v in values:
48
+ assert isinstance(v, tuple), "Flattened results must be a tuple"
49
+ ret = ret + v
50
+ sizes.append(len(v))
51
+ return ret, sizes
52
+
53
+ @staticmethod
54
+ def _split(values, sizes):
55
+ if len(sizes):
56
+ expected_len = sum(sizes)
57
+ assert (
58
+ len(values) == expected_len
59
+ ), f"Values has length {len(values)} but expect length {expected_len}."
60
+ ret = []
61
+ for k in range(len(sizes)):
62
+ begin, end = sum(sizes[:k]), sum(sizes[: k + 1])
63
+ ret.append(values[begin:end])
64
+ return ret
65
+
66
+
67
+ @dataclass
68
+ class ListSchema(Schema):
69
+ schemas: List[Schema] # the schemas that define how to flatten each element in the list
70
+ sizes: List[int] # the flattened length of each element
71
+
72
+ def __call__(self, values):
73
+ values = self._split(values, self.sizes)
74
+ if len(values) != len(self.schemas):
75
+ raise ValueError(
76
+ f"Values has length {len(values)} but schemas " f"has length {len(self.schemas)}!"
77
+ )
78
+ values = [m(v) for m, v in zip(self.schemas, values)]
79
+ return list(values)
80
+
81
+ @classmethod
82
+ def flatten(cls, obj):
83
+ res = [flatten_to_tuple(k) for k in obj]
84
+ values, sizes = cls._concat([k[0] for k in res])
85
+ return values, cls([k[1] for k in res], sizes)
86
+
87
+
88
+ @dataclass
89
+ class TupleSchema(ListSchema):
90
+ def __call__(self, values):
91
+ return tuple(super().__call__(values))
92
+
93
+
94
+ @dataclass
95
+ class IdentitySchema(Schema):
96
+ def __call__(self, values):
97
+ return values[0]
98
+
99
+ @classmethod
100
+ def flatten(cls, obj):
101
+ return (obj,), cls()
102
+
103
+
104
+ @dataclass
105
+ class DictSchema(ListSchema):
106
+ keys: List[str]
107
+
108
+ def __call__(self, values):
109
+ values = super().__call__(values)
110
+ return dict(zip(self.keys, values))
111
+
112
+ @classmethod
113
+ def flatten(cls, obj):
114
+ for k in obj.keys():
115
+ if not isinstance(k, str):
116
+ raise KeyError("Only support flattening dictionaries if keys are str.")
117
+ keys = sorted(obj.keys())
118
+ values = [obj[k] for k in keys]
119
+ ret, schema = ListSchema.flatten(values)
120
+ return ret, cls(schema.schemas, schema.sizes, keys)
121
+
122
+
123
+ @dataclass
124
+ class InstancesSchema(DictSchema):
125
+ def __call__(self, values):
126
+ image_size, fields = values[-1], values[:-1]
127
+ fields = super().__call__(fields)
128
+ return Instances(image_size, **fields)
129
+
130
+ @classmethod
131
+ def flatten(cls, obj):
132
+ ret, schema = super().flatten(obj.get_fields())
133
+ size = obj.image_size
134
+ if not isinstance(size, torch.Tensor):
135
+ size = torch.tensor(size)
136
+ return ret + (size,), schema
137
+
138
+
139
+ @dataclass
140
+ class TensorWrapSchema(Schema):
141
+ """
142
+ For classes that are simple wrapper of tensors, e.g.
143
+ Boxes, RotatedBoxes, BitMasks
144
+ """
145
+
146
+ class_name: str
147
+
148
+ def __call__(self, values):
149
+ return locate(self.class_name)(values[0])
150
+
151
+ @classmethod
152
+ def flatten(cls, obj):
153
+ return (obj.tensor,), cls(_convert_target_to_string(type(obj)))
154
+
155
+
156
+ # if more custom structures needed in the future, can allow
157
+ # passing in extra schemas for custom types
158
+ def flatten_to_tuple(obj):
159
+ """
160
+ Flatten an object so it can be used for PyTorch tracing.
161
+ Also returns how to rebuild the original object from the flattened outputs.
162
+
163
+ Returns:
164
+ res (tuple): the flattened results that can be used as tracing outputs
165
+ schema: an object with a ``__call__`` method such that ``schema(res) == obj``.
166
+ It is a pure dataclass that can be serialized.
167
+ """
168
+ schemas = [
169
+ ((str, bytes), IdentitySchema),
170
+ (list, ListSchema),
171
+ (tuple, TupleSchema),
172
+ (collections.abc.Mapping, DictSchema),
173
+ (Instances, InstancesSchema),
174
+ ((Boxes, ROIMasks), TensorWrapSchema),
175
+ ]
176
+ for klass, schema in schemas:
177
+ if isinstance(obj, klass):
178
+ F = schema
179
+ break
180
+ else:
181
+ F = IdentitySchema
182
+
183
+ return F.flatten(obj)
184
+
185
+
186
+ class TracingAdapter(nn.Module):
187
+ """
188
+ A model may take rich input/output format (e.g. dict or custom classes),
189
+ but `torch.jit.trace` requires tuple of tensors as input/output.
190
+ This adapter flattens input/output format of a model so it becomes traceable.
191
+
192
+ It also records the necessary schema to rebuild model's inputs/outputs from flattened
193
+ inputs/outputs.
194
+
195
+ Example:
196
+ ::
197
+ outputs = model(inputs) # inputs/outputs may be rich structure
198
+ adapter = TracingAdapter(model, inputs)
199
+
200
+ # can now trace the model, with adapter.flattened_inputs, or another
201
+ # tuple of tensors with the same length and meaning
202
+ traced = torch.jit.trace(adapter, adapter.flattened_inputs)
203
+
204
+ # traced model can only produce flattened outputs (tuple of tensors)
205
+ flattened_outputs = traced(*adapter.flattened_inputs)
206
+ # adapter knows the schema to convert it back (new_outputs == outputs)
207
+ new_outputs = adapter.outputs_schema(flattened_outputs)
208
+ """
209
+
210
+ flattened_inputs: Tuple[torch.Tensor] = None
211
+ """
212
+ Flattened version of inputs given to this class's constructor.
213
+ """
214
+
215
+ inputs_schema: Schema = None
216
+ """
217
+ Schema of the inputs given to this class's constructor.
218
+ """
219
+
220
+ outputs_schema: Schema = None
221
+ """
222
+ Schema of the output produced by calling the given model with inputs.
223
+ """
224
+
225
+ def __init__(
226
+ self,
227
+ model: nn.Module,
228
+ inputs,
229
+ inference_func: Optional[Callable] = None,
230
+ allow_non_tensor: bool = False,
231
+ ):
232
+ """
233
+ Args:
234
+ model: an nn.Module
235
+ inputs: An input argument or a tuple of input arguments used to call model.
236
+ After flattening, it has to only consist of tensors.
237
+ inference_func: a callable that takes (model, *inputs), calls the
238
+ model with inputs, and return outputs. By default it
239
+ is ``lambda model, *inputs: model(*inputs)``. Can be overridden
240
+ if you need to call the model differently.
241
+ allow_non_tensor: allow inputs/outputs to contain non-tensor objects.
242
+ This option will filter out non-tensor objects to make the
243
+ model traceable, but ``inputs_schema``/``outputs_schema`` cannot be
244
+ used anymore because inputs/outputs cannot be rebuilt from pure tensors.
245
+ This is useful when you're only interested in the single trace of
246
+ execution (e.g. for flop count), but not interested in
247
+ generalizing the traced graph to new inputs.
248
+ """
249
+ super().__init__()
250
+ if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)):
251
+ model = model.module
252
+ self.model = model
253
+ if not isinstance(inputs, tuple):
254
+ inputs = (inputs,)
255
+ self.inputs = inputs
256
+ self.allow_non_tensor = allow_non_tensor
257
+
258
+ if inference_func is None:
259
+ inference_func = lambda model, *inputs: model(*inputs) # noqa
260
+ self.inference_func = inference_func
261
+
262
+ self.flattened_inputs, self.inputs_schema = flatten_to_tuple(inputs)
263
+
264
+ if all(isinstance(x, torch.Tensor) for x in self.flattened_inputs):
265
+ return
266
+ if self.allow_non_tensor:
267
+ self.flattened_inputs = tuple(
268
+ [x for x in self.flattened_inputs if isinstance(x, torch.Tensor)]
269
+ )
270
+ self.inputs_schema = None
271
+ else:
272
+ for input in self.flattened_inputs:
273
+ if not isinstance(input, torch.Tensor):
274
+ raise ValueError(
275
+ "Inputs for tracing must only contain tensors. "
276
+ f"Got a {type(input)} instead."
277
+ )
278
+
279
+ def forward(self, *args: torch.Tensor):
280
+ with torch.no_grad(), patch_builtin_len():
281
+ if self.inputs_schema is not None:
282
+ inputs_orig_format = self.inputs_schema(args)
283
+ else:
284
+ if len(args) != len(self.flattened_inputs) or any(
285
+ x is not y for x, y in zip(args, self.flattened_inputs)
286
+ ):
287
+ raise ValueError(
288
+ "TracingAdapter does not contain valid inputs_schema."
289
+ " So it cannot generalize to other inputs and must be"
290
+ " traced with `.flattened_inputs`."
291
+ )
292
+ inputs_orig_format = self.inputs
293
+
294
+ outputs = self.inference_func(self.model, *inputs_orig_format)
295
+ flattened_outputs, schema = flatten_to_tuple(outputs)
296
+
297
+ flattened_output_tensors = tuple(
298
+ [x for x in flattened_outputs if isinstance(x, torch.Tensor)]
299
+ )
300
+ if len(flattened_output_tensors) < len(flattened_outputs):
301
+ if self.allow_non_tensor:
302
+ flattened_outputs = flattened_output_tensors
303
+ self.outputs_schema = None
304
+ else:
305
+ raise ValueError(
306
+ "Model cannot be traced because some model outputs "
307
+ "cannot flatten to tensors."
308
+ )
309
+ else: # schema is valid
310
+ if self.outputs_schema is None:
311
+ self.outputs_schema = schema
312
+ else:
313
+ assert self.outputs_schema == schema, (
314
+ "Model should always return outputs with the same "
315
+ "structure so it can be traced!"
316
+ )
317
+ return flattened_outputs
318
+
319
+ def _create_wrapper(self, traced_model):
320
+ """
321
+ Return a function that has an input/output interface the same as the
322
+ original model, but it calls the given traced model under the hood.
323
+ """
324
+
325
+ def forward(*args):
326
+ flattened_inputs, _ = flatten_to_tuple(args)
327
+ flattened_outputs = traced_model(*flattened_inputs)
328
+ return self.outputs_schema(flattened_outputs)
329
+
330
+ return forward
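A small sketch of the flatten/rebuild round trip that Schema and flatten_to_tuple promise (``schema(res) == obj``); the dict below is an arbitrary example object, not anything required by the file.

import torch
from detectron2.export.flatten import flatten_to_tuple

obj = {"scores": torch.rand(5), "boxes": torch.rand(5, 4)}
res, schema = flatten_to_tuple(obj)   # res is a flat tuple of tensors
rebuilt = schema(res)                 # a dict with the same keys and tensors as obj
assert set(rebuilt.keys()) == set(obj.keys())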
CatVTON/detectron2/export/shared.py ADDED
@@ -0,0 +1,1040 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ import collections
4
+ import copy
5
+ import functools
6
+ import logging
7
+ import numpy as np
8
+ import os
9
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
10
+ from unittest import mock
11
+ import caffe2.python.utils as putils
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from caffe2.proto import caffe2_pb2
15
+ from caffe2.python import core, net_drawer, workspace
16
+ from torch.nn.functional import interpolate as interp
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ # ==== torch/utils_toffee/cast.py =======================================
22
+
23
+
24
+ def to_device(t, device_str):
25
+ """
26
+ This function is a replacement for .to(another_device) such that it allows the
27
+ casting to be traced properly by explicitly calling the underlying copy ops.
28
+ It also avoids introducing an unnecessary op when casting to the same device.
29
+ """
30
+ src = t.device
31
+ dst = torch.device(device_str)
32
+
33
+ if src == dst:
34
+ return t
35
+ elif src.type == "cuda" and dst.type == "cpu":
36
+ return torch.ops._caffe2.CopyGPUToCPU(t)
37
+ elif src.type == "cpu" and dst.type == "cuda":
38
+ return torch.ops._caffe2.CopyCPUToGPU(t)
39
+ else:
40
+ raise RuntimeError("Can't cast tensor from device {} to device {}".format(src, dst))
41
+
42
+
43
+ # ==== torch/utils_toffee/interpolate.py =======================================
44
+
45
+
46
+ # Note: borrowed from vision/detection/fair/detectron/detectron/modeling/detector.py
47
+ def BilinearInterpolation(tensor_in, up_scale):
48
+ assert up_scale % 2 == 0, "Scale should be even"
49
+
50
+ def upsample_filt(size):
51
+ factor = (size + 1) // 2
52
+ if size % 2 == 1:
53
+ center = factor - 1
54
+ else:
55
+ center = factor - 0.5
56
+
57
+ og = np.ogrid[:size, :size]
58
+ return (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
59
+
60
+ kernel_size = int(up_scale) * 2
61
+ bil_filt = upsample_filt(kernel_size)
62
+
63
+ dim = int(tensor_in.shape[1])
64
+ kernel = np.zeros((dim, dim, kernel_size, kernel_size), dtype=np.float32)
65
+ kernel[range(dim), range(dim), :, :] = bil_filt
66
+
67
+ tensor_out = F.conv_transpose2d(
68
+ tensor_in,
69
+ weight=to_device(torch.Tensor(kernel), tensor_in.device),
70
+ bias=None,
71
+ stride=int(up_scale),
72
+ padding=int(up_scale / 2),
73
+ )
74
+
75
+ return tensor_out
76
+
77
+
78
+ # NOTE: ONNX is incompatible with traced torch.nn.functional.interpolate if
79
+ # using dynamic `scale_factor` rather than static `size`. (T43166860)
80
+ # NOTE: Caffe2 Int8 conversion might not be able to quantize `size` properly.
81
+ def onnx_compatibale_interpolate(
82
+ input, size=None, scale_factor=None, mode="nearest", align_corners=None
83
+ ):
84
+ # NOTE: The input dimensions are interpreted in the form:
85
+ # `mini-batch x channels x [optional depth] x [optional height] x width`.
86
+ if size is None and scale_factor is not None:
87
+ if input.dim() == 4:
88
+ if isinstance(scale_factor, (int, float)):
89
+ height_scale, width_scale = (scale_factor, scale_factor)
90
+ else:
91
+ assert isinstance(scale_factor, (tuple, list))
92
+ assert len(scale_factor) == 2
93
+ height_scale, width_scale = scale_factor
94
+
95
+ assert not align_corners, "No matching C2 op for align_corners == True"
96
+ if mode == "nearest":
97
+ return torch.ops._caffe2.ResizeNearest(
98
+ input, order="NCHW", width_scale=width_scale, height_scale=height_scale
99
+ )
100
+ elif mode == "bilinear":
101
+ logger.warning(
102
+ "Use F.conv_transpose2d for bilinear interpolate"
103
+ " because there's no such C2 op, this may cause significant"
104
+ " slowdown and the boundary pixels won't be as same as"
105
+ " using F.interpolate due to padding."
106
+ )
107
+ assert height_scale == width_scale
108
+ return BilinearInterpolation(input, up_scale=height_scale)
109
+ logger.warning("Output size is not static, it might cause ONNX conversion issue")
110
+
111
+ return interp(input, size, scale_factor, mode, align_corners)
112
+
113
+
114
+ def mock_torch_nn_functional_interpolate():
115
+ def decorator(func):
116
+ @functools.wraps(func)
117
+ def _mock_torch_nn_functional_interpolate(*args, **kwargs):
118
+ if torch.onnx.is_in_onnx_export():
119
+ with mock.patch(
120
+ "torch.nn.functional.interpolate", side_effect=onnx_compatibale_interpolate
121
+ ):
122
+ return func(*args, **kwargs)
123
+ else:
124
+ return func(*args, **kwargs)
125
+
126
+ return _mock_torch_nn_functional_interpolate
127
+
128
+ return decorator
129
+
130
+
131
+ # ==== torch/utils_caffe2/ws_utils.py ==========================================
132
+
133
+
134
+ class ScopedWS:
135
+ def __init__(self, ws_name, is_reset, is_cleanup=False):
136
+ self.ws_name = ws_name
137
+ self.is_reset = is_reset
138
+ self.is_cleanup = is_cleanup
139
+ self.org_ws = ""
140
+
141
+ def __enter__(self):
142
+ self.org_ws = workspace.CurrentWorkspace()
143
+ if self.ws_name is not None:
144
+ workspace.SwitchWorkspace(self.ws_name, True)
145
+ if self.is_reset:
146
+ workspace.ResetWorkspace()
147
+
148
+ return workspace
149
+
150
+ def __exit__(self, *args):
151
+ if self.is_cleanup:
152
+ workspace.ResetWorkspace()
153
+ if self.ws_name is not None:
154
+ workspace.SwitchWorkspace(self.org_ws)
155
+
156
+
157
+ def fetch_any_blob(name):
158
+ bb = None
159
+ try:
160
+ bb = workspace.FetchBlob(name)
161
+ except TypeError:
162
+ bb = workspace.FetchInt8Blob(name)
163
+ except Exception as e:
164
+ logger.error("Get blob {} error: {}".format(name, e))
165
+
166
+ return bb
167
+
168
+
169
+ # ==== torch/utils_caffe2/protobuf.py ==========================================
170
+
171
+
172
+ def get_pb_arg(pb, arg_name):
173
+ for x in pb.arg:
174
+ if x.name == arg_name:
175
+ return x
176
+ return None
177
+
178
+
179
+ def get_pb_arg_valf(pb, arg_name, default_val):
180
+ arg = get_pb_arg(pb, arg_name)
181
+ return arg.f if arg is not None else default_val
182
+
183
+
184
+ def get_pb_arg_floats(pb, arg_name, default_val):
185
+ arg = get_pb_arg(pb, arg_name)
186
+ return list(map(float, arg.floats)) if arg is not None else default_val
187
+
188
+
189
+ def get_pb_arg_ints(pb, arg_name, default_val):
190
+ arg = get_pb_arg(pb, arg_name)
191
+ return list(map(int, arg.ints)) if arg is not None else default_val
192
+
193
+
194
+ def get_pb_arg_vali(pb, arg_name, default_val):
195
+ arg = get_pb_arg(pb, arg_name)
196
+ return arg.i if arg is not None else default_val
197
+
198
+
199
+ def get_pb_arg_vals(pb, arg_name, default_val):
200
+ arg = get_pb_arg(pb, arg_name)
201
+ return arg.s if arg is not None else default_val
202
+
203
+
204
+ def get_pb_arg_valstrings(pb, arg_name, default_val):
205
+ arg = get_pb_arg(pb, arg_name)
206
+ return list(arg.strings) if arg is not None else default_val
207
+
208
+
209
+ def check_set_pb_arg(pb, arg_name, arg_attr, arg_value, allow_override=False):
210
+ arg = get_pb_arg(pb, arg_name)
211
+ if arg is None:
212
+ arg = putils.MakeArgument(arg_name, arg_value)
213
+ assert hasattr(arg, arg_attr)
214
+ pb.arg.extend([arg])
215
+ if allow_override and getattr(arg, arg_attr) != arg_value:
216
+ logger.warning(
217
+ "Override argument {}: {} -> {}".format(arg_name, getattr(arg, arg_attr), arg_value)
218
+ )
219
+ setattr(arg, arg_attr, arg_value)
220
+ else:
221
+ assert arg is not None
222
+ assert getattr(arg, arg_attr) == arg_value, "Existing value {}, new value {}".format(
223
+ getattr(arg, arg_attr), arg_value
224
+ )
225
+
226
+
227
+ def _create_const_fill_op_from_numpy(name, tensor, device_option=None):
228
+ assert type(tensor) is np.ndarray
229
+ kTypeNameMapper = {
230
+ np.dtype("float32"): "GivenTensorFill",
231
+ np.dtype("int32"): "GivenTensorIntFill",
232
+ np.dtype("int64"): "GivenTensorInt64Fill",
233
+ np.dtype("uint8"): "GivenTensorStringFill",
234
+ }
235
+
236
+ args_dict = {}
237
+ if tensor.dtype == np.dtype("uint8"):
238
+ args_dict.update({"values": [str(tensor.data)], "shape": [1]})
239
+ else:
240
+ args_dict.update({"values": tensor, "shape": tensor.shape})
241
+
242
+ if device_option is not None:
243
+ args_dict["device_option"] = device_option
244
+
245
+ return core.CreateOperator(kTypeNameMapper[tensor.dtype], [], [name], **args_dict)
246
+
247
+
248
+ def _create_const_fill_op_from_c2_int8_tensor(name, int8_tensor):
249
+ assert type(int8_tensor) is workspace.Int8Tensor
250
+ kTypeNameMapper = {
251
+ np.dtype("int32"): "Int8GivenIntTensorFill",
252
+ np.dtype("uint8"): "Int8GivenTensorFill",
253
+ }
254
+
255
+ tensor = int8_tensor.data
256
+ assert tensor.dtype in [np.dtype("uint8"), np.dtype("int32")]
257
+ values = tensor.tobytes() if tensor.dtype == np.dtype("uint8") else tensor
258
+
259
+ return core.CreateOperator(
260
+ kTypeNameMapper[tensor.dtype],
261
+ [],
262
+ [name],
263
+ values=values,
264
+ shape=tensor.shape,
265
+ Y_scale=int8_tensor.scale,
266
+ Y_zero_point=int8_tensor.zero_point,
267
+ )
268
+
269
+
270
+ def create_const_fill_op(
271
+ name: str,
272
+ blob: Union[np.ndarray, workspace.Int8Tensor],
273
+ device_option: Optional[caffe2_pb2.DeviceOption] = None,
274
+ ) -> caffe2_pb2.OperatorDef:
275
+ """
276
+ Given a blob object, return the Caffe2 operator that creates this blob
277
+ as a constant. Currently supports NumPy tensors and Caffe2 Int8Tensor.
278
+ """
279
+
280
+ tensor_type = type(blob)
281
+ assert tensor_type in [
282
+ np.ndarray,
283
+ workspace.Int8Tensor,
284
+ ], 'Error when creating const fill op for "{}", unsupported blob type: {}'.format(
285
+ name, type(blob)
286
+ )
287
+
288
+ if tensor_type == np.ndarray:
289
+ return _create_const_fill_op_from_numpy(name, blob, device_option)
290
+ elif tensor_type == workspace.Int8Tensor:
291
+ assert device_option is None
292
+ return _create_const_fill_op_from_c2_int8_tensor(name, blob)
293
+
294
+
295
+ def construct_init_net_from_params(
296
+ params: Dict[str, Any], device_options: Optional[Dict[str, caffe2_pb2.DeviceOption]] = None
297
+ ) -> caffe2_pb2.NetDef:
298
+ """
299
+ Construct the init_net from params dictionary
300
+ """
301
+ init_net = caffe2_pb2.NetDef()
302
+ device_options = device_options or {}
303
+ for name, blob in params.items():
304
+ if isinstance(blob, str):
305
+ logger.warning(
306
+ (
307
+ "Blob {} with type {} is not supported in generating init net,"
308
+ " skipped.".format(name, type(blob))
309
+ )
310
+ )
311
+ continue
312
+ init_net.op.extend(
313
+ [create_const_fill_op(name, blob, device_option=device_options.get(name, None))]
314
+ )
315
+ init_net.external_output.append(name)
316
+ return init_net
317
+
318
+
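+ # NOTE: a minimal usage sketch (hypothetical parameter names, for illustration only):
+ #
+ #     params = {
+ #         "conv1_w": np.zeros((8, 3, 3, 3), dtype=np.float32),
+ #         "conv1_b": np.zeros((8,), dtype=np.float32),
+ #     }
+ #     init_net = construct_init_net_from_params(params)
+ #
+ # Each blob becomes one const-fill op in init_net and is listed in external_output.
+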
319
+ def get_producer_map(ssa):
320
+ """
321
+ Return dict from versioned blob to (i, j),
322
+ where i is index of producer op, j is the index of output of that op.
323
+ """
324
+ producer_map = {}
325
+ for i in range(len(ssa)):
326
+ outputs = ssa[i][1]
327
+ for j, outp in enumerate(outputs):
328
+ producer_map[outp] = (i, j)
329
+ return producer_map
330
+
331
+
332
+ def get_consumer_map(ssa):
333
+ """
334
+ Return dict from versioned blob to list of (i, j),
335
+ where i is index of consumer op, j is the index of input of that op.
336
+ """
337
+ consumer_map = collections.defaultdict(list)
338
+ for i in range(len(ssa)):
339
+ inputs = ssa[i][0]
340
+ for j, inp in enumerate(inputs):
341
+ consumer_map[inp].append((i, j))
342
+ return consumer_map
343
+
344
+
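+ # NOTE: core.get_ssa(net) returns one ([(in_name, ver), ...], [(out_name, ver), ...])
+ # pair per op; the two helpers above invert it into lookup tables, roughly:
+ #
+ #     ssa, versions = core.get_ssa(predict_net)
+ #     producer_map = get_producer_map(ssa)   # {versioned_blob: (op_index, output_index)}
+ #     consumer_map = get_consumer_map(ssa)   # {versioned_blob: [(op_index, input_index), ...]}
+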
345
+ def get_params_from_init_net(
346
+ init_net: caffe2_pb2.NetDef,
347
+ ) -> [Dict[str, Any], Dict[str, caffe2_pb2.DeviceOption]]:
348
+ """
349
+ Take the output blobs from init_net by running it.
350
+ Outputs:
351
+ params: dict from blob name to numpy array
352
+ device_options: dict from blob name to the device option of its creating op
353
+ """
354
+
355
+ # NOTE: this assumes that the params are determined by the producer op, with the
356
+ # only exception being CopyGPUToCPU, which is a CUDA op but returns a CPU tensor.
357
+ def _get_device_option(producer_op):
358
+ if producer_op.type == "CopyGPUToCPU":
359
+ return caffe2_pb2.DeviceOption()
360
+ else:
361
+ return producer_op.device_option
362
+
363
+ with ScopedWS("__get_params_from_init_net__", is_reset=True, is_cleanup=True) as ws:
364
+ ws.RunNetOnce(init_net)
365
+ params = {b: fetch_any_blob(b) for b in init_net.external_output}
366
+ ssa, versions = core.get_ssa(init_net)
367
+ producer_map = get_producer_map(ssa)
368
+ device_options = {
369
+ b: _get_device_option(init_net.op[producer_map[(b, versions[b])][0]])
370
+ for b in init_net.external_output
371
+ }
372
+ return params, device_options
373
+
374
+
375
+ def _updater_raise(op, input_types, output_types):
376
+ raise RuntimeError(
377
+ "Failed to apply updater for op {} given input_types {} and"
378
+ " output_types {}".format(op, input_types, output_types)
379
+ )
380
+
381
+
382
+ def _generic_status_identifier(
383
+ predict_net: caffe2_pb2.NetDef,
384
+ status_updater: Callable,
385
+ known_status: Dict[Tuple[str, int], Any],
386
+ ) -> Dict[Tuple[str, int], Any]:
387
+ """
388
+ Statically infer the status of each blob; the status can be, e.g., device type
389
+ (CPU/GPU), layout (NCHW/NHWC), data type (float32/int8), etc. "Blob" here
390
+ is versioned blob (Tuple[str, int]) in the format compatible with ssa.
391
+ Inputs:
392
+ predict_net: the caffe2 network
393
+ status_updater: a callable, given an op and the status of its input/output,
394
+ it returns the updated status of input/output. `None` is used for
395
+ representing unknown status.
396
+ known_status: a dict containing known status, used as initialization.
397
+ Outputs:
398
+ A dict mapping from versioned blob to its status
399
+ """
400
+ ssa, versions = core.get_ssa(predict_net)
401
+ versioned_ext_input = [(b, 0) for b in predict_net.external_input]
402
+ versioned_ext_output = [(b, versions[b]) for b in predict_net.external_output]
403
+ all_versioned_blobs = set().union(*[set(x[0] + x[1]) for x in ssa])
404
+
405
+ allowed_vbs = all_versioned_blobs.union(versioned_ext_input).union(versioned_ext_output)
406
+ assert all(k in allowed_vbs for k in known_status)
407
+ assert all(v is not None for v in known_status.values())
408
+ _known_status = copy.deepcopy(known_status)
409
+
410
+ def _check_and_update(key, value):
411
+ assert value is not None
412
+ if key in _known_status:
413
+ if not _known_status[key] == value:
414
+ raise RuntimeError(
415
+ "Confilict status for {}, existing status {}, new status {}".format(
416
+ key, _known_status[key], value
417
+ )
418
+ )
419
+ _known_status[key] = value
420
+
421
+ def _update_i(op, ssa_i):
422
+ versioned_inputs = ssa_i[0]
423
+ versioned_outputs = ssa_i[1]
424
+
425
+ inputs_status = [_known_status.get(b, None) for b in versioned_inputs]
426
+ outputs_status = [_known_status.get(b, None) for b in versioned_outputs]
427
+
428
+ new_inputs_status, new_outputs_status = status_updater(op, inputs_status, outputs_status)
429
+
430
+ for versioned_blob, status in zip(
431
+ versioned_inputs + versioned_outputs, new_inputs_status + new_outputs_status
432
+ ):
433
+ if status is not None:
434
+ _check_and_update(versioned_blob, status)
435
+
436
+ for op, ssa_i in zip(predict_net.op, ssa):
437
+ _update_i(op, ssa_i)
438
+ for op, ssa_i in zip(reversed(predict_net.op), reversed(ssa)):
439
+ _update_i(op, ssa_i)
440
+
441
+ # NOTE: This strictly checks that every blob from predict_net must be assigned
442
+ # a known status. However sometimes it's impossible (e.g. having dead-end ops),
443
+ # so we may relax this constraint if needed.
444
+ for k in all_versioned_blobs:
445
+ if k not in _known_status:
446
+ raise NotImplementedError(
447
+ "Can not infer the status for {}. Currently only support the case where"
448
+ " a single forward and backward pass can identify status for all blobs.".format(k)
449
+ )
450
+
451
+ return _known_status
452
+
453
+
454
+ def infer_device_type(
455
+ predict_net: caffe2_pb2.NetDef,
456
+ known_status: Dict[Tuple[str, int], Any],
457
+ device_name_style: str = "caffe2",
458
+ ) -> Dict[Tuple[str, int], str]:
459
+ """Return the device type ("cpu" or "gpu"/"cuda") of each (versioned) blob"""
460
+
461
+ assert device_name_style in ["caffe2", "pytorch"]
462
+ _CPU_STR = "cpu"
463
+ _GPU_STR = "gpu" if device_name_style == "caffe2" else "cuda"
464
+
465
+ def _copy_cpu_to_gpu_updater(op, input_types, output_types):
466
+ if input_types[0] == _GPU_STR or output_types[0] == _CPU_STR:
467
+ _updater_raise(op, input_types, output_types)
468
+ return ([_CPU_STR], [_GPU_STR])
469
+
470
+ def _copy_gpu_to_cpu_updater(op, input_types, output_types):
471
+ if input_types[0] == _CPU_STR or output_types[0] == _GPU_STR:
472
+ _updater_raise(op, input_types, output_types)
473
+ return ([_GPU_STR], [_CPU_STR])
474
+
475
+ def _other_ops_updater(op, input_types, output_types):
476
+ non_none_types = [x for x in input_types + output_types if x is not None]
477
+ if len(non_none_types) > 0:
478
+ the_type = non_none_types[0]
479
+ if not all(x == the_type for x in non_none_types):
480
+ _updater_raise(op, input_types, output_types)
481
+ else:
482
+ the_type = None
483
+ return ([the_type for _ in op.input], [the_type for _ in op.output])
484
+
485
+ def _device_updater(op, *args, **kwargs):
486
+ return {
487
+ "CopyCPUToGPU": _copy_cpu_to_gpu_updater,
488
+ "CopyGPUToCPU": _copy_gpu_to_cpu_updater,
489
+ }.get(op.type, _other_ops_updater)(op, *args, **kwargs)
490
+
491
+ return _generic_status_identifier(predict_net, _device_updater, known_status)
492
+
493
+
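+ # NOTE: a typical call (sketch, for illustration only) seeds the version-0 external
+ # inputs with a known device and lets the updaters propagate it through the copy ops:
+ #
+ #     known = {(name, 0): "cpu" for name in predict_net.external_input}
+ #     blob_devices = infer_device_type(predict_net, known, device_name_style="pytorch")
+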
494
+ # ==== torch/utils_caffe2/vis.py ===============================================
495
+
496
+
497
+ def _modify_blob_names(ops, blob_rename_f):
498
+ ret = []
499
+
500
+ def _replace_list(blob_list, replaced_list):
501
+ del blob_list[:]
502
+ blob_list.extend(replaced_list)
503
+
504
+ for x in ops:
505
+ cur = copy.deepcopy(x)
506
+ _replace_list(cur.input, list(map(blob_rename_f, cur.input)))
507
+ _replace_list(cur.output, list(map(blob_rename_f, cur.output)))
508
+ ret.append(cur)
509
+
510
+ return ret
511
+
512
+
513
+ def _rename_blob(name, blob_sizes, blob_ranges):
514
+ def _list_to_str(bsize):
515
+ ret = ", ".join([str(x) for x in bsize])
516
+ ret = "[" + ret + "]"
517
+ return ret
518
+
519
+ ret = name
520
+ if blob_sizes is not None and name in blob_sizes:
521
+ ret += "\n" + _list_to_str(blob_sizes[name])
522
+ if blob_ranges is not None and name in blob_ranges:
523
+ ret += "\n" + _list_to_str(blob_ranges[name])
524
+
525
+ return ret
526
+
527
+
528
+ # graph_name must not contain the word 'graph'
529
+ def save_graph(net, file_name, graph_name="net", op_only=True, blob_sizes=None, blob_ranges=None):
530
+ blob_rename_f = functools.partial(_rename_blob, blob_sizes=blob_sizes, blob_ranges=blob_ranges)
531
+ return save_graph_base(net, file_name, graph_name, op_only, blob_rename_f)
532
+
533
+
534
+ def save_graph_base(net, file_name, graph_name="net", op_only=True, blob_rename_func=None):
535
+ graph = None
536
+ ops = net.op
537
+ if blob_rename_func is not None:
538
+ ops = _modify_blob_names(ops, blob_rename_func)
539
+ if not op_only:
540
+ graph = net_drawer.GetPydotGraph(ops, graph_name, rankdir="TB")
541
+ else:
542
+ graph = net_drawer.GetPydotGraphMinimal(
543
+ ops, graph_name, rankdir="TB", minimal_dependency=True
544
+ )
545
+
546
+ try:
547
+ par_dir = os.path.dirname(file_name)
548
+ if not os.path.exists(par_dir):
549
+ os.makedirs(par_dir)
550
+
551
+ format = os.path.splitext(os.path.basename(file_name))[-1]
552
+ if format == ".png":
553
+ graph.write_png(file_name)
554
+ elif format == ".pdf":
555
+ graph.write_pdf(file_name)
556
+ elif format == ".svg":
557
+ graph.write_svg(file_name)
558
+ else:
559
+ print("Incorrect format {}".format(format))
560
+ except Exception as e:
561
+ print("Error when writing graph to image {}".format(e))
562
+
563
+ return graph
564
+
565
+
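+ # NOTE: usage sketch (assumes pydot/graphviz are installed so net_drawer can render):
+ #
+ #     save_graph(predict_net, "/tmp/predict_net.svg", graph_name="predict", op_only=True)
+ #
+ # The output format follows the file extension (.png, .pdf or .svg).
+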
566
+ # ==== torch/utils_toffee/aten_to_caffe2.py ====================================
567
+
568
+
569
+ def group_norm_replace_aten_with_caffe2(predict_net: caffe2_pb2.NetDef):
570
+ """
571
+ For an ONNX-exported model, GroupNorm will be represented as an ATen op;
572
+ this function is a drop-in replacement of that ATen op with Caffe2's GroupNorm op.
573
+ """
574
+ count = 0
575
+ for op in predict_net.op:
576
+ if op.type == "ATen":
577
+ op_name = get_pb_arg_vals(op, "operator", None) # return byte in py3
578
+ if op_name and op_name.decode() == "group_norm":
579
+ op.arg.remove(get_pb_arg(op, "operator"))
580
+
581
+ if get_pb_arg_vali(op, "cudnn_enabled", None):
582
+ op.arg.remove(get_pb_arg(op, "cudnn_enabled"))
583
+
584
+ num_groups = get_pb_arg_vali(op, "num_groups", None)
585
+ if num_groups is not None:
586
+ op.arg.remove(get_pb_arg(op, "num_groups"))
587
+ check_set_pb_arg(op, "group", "i", num_groups)
588
+
589
+ op.type = "GroupNorm"
590
+ count += 1
591
+ if count > 1:
592
+ logger.info("Replaced {} ATen operator to GroupNormOp".format(count))
593
+
594
+
595
+ # ==== torch/utils_toffee/alias.py =============================================
596
+
597
+
598
+ def alias(x, name, is_backward=False):
599
+ if not torch.onnx.is_in_onnx_export():
600
+ return x
601
+ assert isinstance(x, torch.Tensor)
602
+ return torch.ops._caffe2.AliasWithName(x, name, is_backward=is_backward)
603
+
604
+
605
+ def fuse_alias_placeholder(predict_net, init_net):
606
+ """Remove AliasWithName placeholder and rename the input/output of it"""
607
+ # First we finish all the re-naming
608
+ for i, op in enumerate(predict_net.op):
609
+ if op.type == "AliasWithName":
610
+ assert len(op.input) == 1
611
+ assert len(op.output) == 1
612
+ name = get_pb_arg_vals(op, "name", None).decode()
613
+ is_backward = bool(get_pb_arg_vali(op, "is_backward", 0))
614
+ rename_op_input(predict_net, init_net, i, 0, name, from_producer=is_backward)
615
+ rename_op_output(predict_net, i, 0, name)
616
+
617
+ # Remove AliasWithName, should be very safe since it's a non-op
618
+ new_ops = []
619
+ for op in predict_net.op:
620
+ if op.type != "AliasWithName":
621
+ new_ops.append(op)
622
+ else:
623
+ # safety check
624
+ assert op.input == op.output
625
+ assert op.input[0] == op.arg[0].s.decode()
626
+ del predict_net.op[:]
627
+ predict_net.op.extend(new_ops)
628
+
629
+
630
+ # ==== torch/utils_caffe2/graph_transform.py ===================================
631
+
632
+
633
+ class IllegalGraphTransformError(ValueError):
634
+ """When a graph transform function call can't be executed."""
635
+
636
+
637
+ def _rename_versioned_blob_in_proto(
638
+ proto: caffe2_pb2.NetDef,
639
+ old_name: str,
640
+ new_name: str,
641
+ version: int,
642
+ ssa: List[Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]],
643
+ start_versions: Dict[str, int],
644
+ end_versions: Dict[str, int],
645
+ ):
646
+ """In given proto, rename all blobs with matched version"""
647
+ # Operator list
648
+ for op, i_th_ssa in zip(proto.op, ssa):
649
+ versioned_inputs, versioned_outputs = i_th_ssa
650
+ for i in range(len(op.input)):
651
+ if versioned_inputs[i] == (old_name, version):
652
+ op.input[i] = new_name
653
+ for i in range(len(op.output)):
654
+ if versioned_outputs[i] == (old_name, version):
655
+ op.output[i] = new_name
656
+ # external_input
657
+ if start_versions.get(old_name, 0) == version:
658
+ for i in range(len(proto.external_input)):
659
+ if proto.external_input[i] == old_name:
660
+ proto.external_input[i] = new_name
661
+ # external_output
662
+ if end_versions.get(old_name, 0) == version:
663
+ for i in range(len(proto.external_output)):
664
+ if proto.external_output[i] == old_name:
665
+ proto.external_output[i] = new_name
666
+
667
+
668
+ def rename_op_input(
669
+ predict_net: caffe2_pb2.NetDef,
670
+ init_net: caffe2_pb2.NetDef,
671
+ op_id: int,
672
+ input_id: int,
673
+ new_name: str,
674
+ from_producer: bool = False,
675
+ ):
676
+ """
677
+ For the op_id-th operator in predict_net, change its input_id-th input's
678
+ name to the new_name. It also does automatic re-routing and changes
679
+ external_input and init_net if necessary.
680
+ - It requires that the input is only consumed by this op.
681
+ - This function modifies predict_net and init_net in-place.
682
+ - When from_producer is enabled, this also updates other operators that consume
683
+ the same input. Be cautious because this may trigger unintended behavior.
684
+ """
685
+ assert isinstance(predict_net, caffe2_pb2.NetDef)
686
+ assert isinstance(init_net, caffe2_pb2.NetDef)
687
+
688
+ init_net_ssa, init_net_versions = core.get_ssa(init_net)
689
+ predict_net_ssa, predict_net_versions = core.get_ssa(
690
+ predict_net, copy.deepcopy(init_net_versions)
691
+ )
692
+
693
+ versioned_inputs, versioned_outputs = predict_net_ssa[op_id]
694
+ old_name, version = versioned_inputs[input_id]
695
+
696
+ if from_producer:
697
+ producer_map = get_producer_map(predict_net_ssa)
698
+ if not (old_name, version) in producer_map:
699
+ raise NotImplementedError(
700
+ "Can't find producer, the input {} is probably from"
701
+ " init_net, this is not supported yet.".format(old_name)
702
+ )
703
+ producer = producer_map[(old_name, version)]
704
+ rename_op_output(predict_net, producer[0], producer[1], new_name)
705
+ return
706
+
707
+ def contain_targets(op_ssa):
708
+ return (old_name, version) in op_ssa[0]
709
+
710
+ is_consumer = [contain_targets(op_ssa) for op_ssa in predict_net_ssa]
711
+ if sum(is_consumer) > 1:
712
+ raise IllegalGraphTransformError(
713
+ (
714
+ "Input '{}' of operator(#{}) are consumed by other ops, please use"
715
+ + " rename_op_output on the producer instead. Offending op: \n{}"
716
+ ).format(old_name, op_id, predict_net.op[op_id])
717
+ )
718
+
719
+ # update init_net
720
+ _rename_versioned_blob_in_proto(
721
+ init_net, old_name, new_name, version, init_net_ssa, {}, init_net_versions
722
+ )
723
+ # update predict_net
724
+ _rename_versioned_blob_in_proto(
725
+ predict_net,
726
+ old_name,
727
+ new_name,
728
+ version,
729
+ predict_net_ssa,
730
+ init_net_versions,
731
+ predict_net_versions,
732
+ )
733
+
734
+
735
+ def rename_op_output(predict_net: caffe2_pb2.NetDef, op_id: int, output_id: int, new_name: str):
736
+ """
737
+ For the op_id-th operator in predict_net, change its output_id-th output's
738
+ name to the new_name. It also does automatic re-routing and changes
739
+ external_output if necessary.
740
+ - It allows multiple consumers of its output.
741
+ - This function modifies predict_net in-place, doesn't need init_net.
742
+ """
743
+ assert isinstance(predict_net, caffe2_pb2.NetDef)
744
+
745
+ ssa, blob_versions = core.get_ssa(predict_net)
746
+
747
+ versioned_inputs, versioned_outputs = ssa[op_id]
748
+ old_name, version = versioned_outputs[output_id]
749
+
750
+ # update predict_net
751
+ _rename_versioned_blob_in_proto(
752
+ predict_net, old_name, new_name, version, ssa, {}, blob_versions
753
+ )
754
+
755
+
756
+ def get_sub_graph_external_input_output(
757
+ predict_net: caffe2_pb2.NetDef, sub_graph_op_indices: List[int]
758
+ ) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
759
+ """
760
+ Return the list of external input/output of sub-graph,
761
+ each element is tuple of the name and corresponding version in predict_net.
762
+
763
+ External input/output is defined the same way as in a caffe2 NetDef.
764
+ """
765
+ ssa, versions = core.get_ssa(predict_net)
766
+
767
+ all_inputs = []
768
+ all_outputs = []
769
+ for op_id in sub_graph_op_indices:
770
+ all_inputs += [inp for inp in ssa[op_id][0] if inp not in all_inputs]
771
+ all_outputs += list(ssa[op_id][1]) # ssa output won't repeat
772
+
773
+ # for versioned blobs, external inputs are just those blob in all_inputs
774
+ # but not in all_outputs
775
+ ext_inputs = [inp for inp in all_inputs if inp not in all_outputs]
776
+
777
+ # external outputs are essentially outputs of this subgraph that are used
778
+ # outside of this sub-graph (including predict_net.external_output)
779
+ all_other_inputs = sum(
780
+ (ssa[i][0] for i in range(len(ssa)) if i not in sub_graph_op_indices),
781
+ [(outp, versions[outp]) for outp in predict_net.external_output],
782
+ )
783
+ ext_outputs = [outp for outp in all_outputs if outp in set(all_other_inputs)]
784
+
785
+ return ext_inputs, ext_outputs
786
+
787
+
788
+ class DiGraph:
789
+ """A DAG representation of caffe2 graph, each vertice is a versioned blob."""
790
+
791
+ def __init__(self):
792
+ self.vertices = set()
793
+ self.graph = collections.defaultdict(list)
794
+
795
+ def add_edge(self, u, v):
796
+ self.graph[u].append(v)
797
+ self.vertices.add(u)
798
+ self.vertices.add(v)
799
+
800
+ # grab from https://www.geeksforgeeks.org/find-paths-given-source-destination/
801
+ def get_all_paths(self, s, d):
802
+ visited = {k: False for k in self.vertices}
803
+ path = []
804
+ all_paths = []
805
+
806
+ def _get_all_paths_util(graph, u, d, visited, path):
807
+ visited[u] = True
808
+ path.append(u)
809
+ if u == d:
810
+ all_paths.append(copy.deepcopy(path))
811
+ else:
812
+ for i in graph[u]:
813
+ if not visited[i]:
814
+ _get_all_paths_util(graph, i, d, visited, path)
815
+ path.pop()
816
+ visited[u] = False
817
+
818
+ _get_all_paths_util(self.graph, s, d, visited, path)
819
+ return all_paths
820
+
821
+ @staticmethod
822
+ def from_ssa(ssa):
823
+ graph = DiGraph()
824
+ for op_id in range(len(ssa)):
825
+ for inp in ssa[op_id][0]:
826
+ for outp in ssa[op_id][1]:
827
+ graph.add_edge(inp, outp)
828
+ return graph
829
+
830
+
831
+ def _get_dependency_chain(ssa, versioned_target, versioned_source):
832
+ """
833
+ Return the indices of the relevant operators needed to produce the target blob from the source blob;
834
+ if there's no dependency, return an empty list.
835
+ """
836
+
837
+ # finding all paths between nodes can be O(N!), thus we can only search
838
+ # in the sub-graph of ops starting from the first consumer of the source blob
839
+ # to the producer of the target blob.
840
+ consumer_map = get_consumer_map(ssa)
841
+ producer_map = get_producer_map(ssa)
842
+ start_op = min(x[0] for x in consumer_map[versioned_source]) - 15
843
+ end_op = (
844
+ producer_map[versioned_target][0] + 15 if versioned_target in producer_map else start_op
845
+ )
846
+ sub_graph_ssa = ssa[start_op : end_op + 1]
847
+ if len(sub_graph_ssa) > 30:
848
+ logger.warning(
849
+ "Subgraph bebetween {} and {} is large (from op#{} to op#{}), it"
850
+ " might take non-trival time to find all paths between them.".format(
851
+ versioned_source, versioned_target, start_op, end_op
852
+ )
853
+ )
854
+
855
+ dag = DiGraph.from_ssa(sub_graph_ssa)
856
+ paths = dag.get_all_paths(versioned_source, versioned_target) # include two ends
857
+ ops_in_paths = [[producer_map[blob][0] for blob in path[1:]] for path in paths]
858
+ return sorted(set().union(*[set(ops) for ops in ops_in_paths]))
859
+
860
+
861
+ def identify_reshape_sub_graph(predict_net: caffe2_pb2.NetDef) -> List[List[int]]:
862
+ """
863
+ Identify the reshape sub-graphs in a protobuf.
864
+ The reshape sub-graph is defined as matching the following pattern:
865
+
866
+ (input_blob) -> Op_1 -> ... -> Op_N -> (new_shape) -─┐
867
+ └-------------------------------------------> Reshape -> (output_blob)
868
+
869
+ Return:
870
+ List of sub-graphs; each sub-graph is represented as a list of indices
871
+ of the relevant ops, [Op_1, Op_2, ..., Op_N, Reshape]
872
+ """
873
+
874
+ ssa, _ = core.get_ssa(predict_net)
875
+
876
+ ret = []
877
+ for i, op in enumerate(predict_net.op):
878
+ if op.type == "Reshape":
879
+ assert len(op.input) == 2
880
+ input_ssa = ssa[i][0]
881
+ data_source = input_ssa[0]
882
+ shape_source = input_ssa[1]
883
+ op_indices = _get_dependency_chain(ssa, shape_source, data_source)
884
+ ret.append(op_indices + [i])
885
+ return ret
886
+
887
+
888
+ def remove_reshape_for_fc(predict_net, params):
889
+ """
890
+ In PyTorch, nn.Linear has to take a 2D tensor, which often leads to reshaping
891
+ a 4D tensor to 2D by calling .view(). However, this (dynamic) reshaping
892
+ doesn't work well with ONNX and Int8 tools, and causes extra
893
+ ops (e.g. ExpandDims) that might not be available on mobile.
894
+ Luckily, Caffe2 supports 4D tensors for FC, so we can remove those reshapes
895
+ after exporting the ONNX model.
896
+ """
897
+ from caffe2.python import core
898
+
899
+ # find all reshape sub-graph that can be removed, which is now all Reshape
900
+ # sub-graph whose output is only consumed by FC.
901
+ # TODO: to make it safer, we may need the actual values to better determine
902
+ # if a Reshape before FC is removable.
903
+ reshape_sub_graphs = identify_reshape_sub_graph(predict_net)
904
+ sub_graphs_to_remove = []
905
+ for reshape_sub_graph in reshape_sub_graphs:
906
+ reshape_op_id = reshape_sub_graph[-1]
907
+ assert predict_net.op[reshape_op_id].type == "Reshape"
908
+ ssa, _ = core.get_ssa(predict_net)
909
+ reshape_output = ssa[reshape_op_id][1][0]
910
+ consumers = [i for i in range(len(ssa)) if reshape_output in ssa[i][0]]
911
+ if all(predict_net.op[consumer].type == "FC" for consumer in consumers):
912
+ # safety check if the sub-graph is isolated, for this reshape sub-graph,
913
+ # it means it has one non-param external input and one external output.
914
+ ext_inputs, ext_outputs = get_sub_graph_external_input_output(
915
+ predict_net, reshape_sub_graph
916
+ )
917
+ non_params_ext_inputs = [inp for inp in ext_inputs if inp[1] != 0]
918
+ if len(non_params_ext_inputs) == 1 and len(ext_outputs) == 1:
919
+ sub_graphs_to_remove.append(reshape_sub_graph)
920
+
921
+ # perform removing subgraph by:
922
+ # 1: rename the Reshape's output to its input, then the graph can be
923
+ # seen as an in-place identity, meaning its external input/output are the same.
924
+ # 2: simply remove those ops.
925
+ remove_op_ids = []
926
+ params_to_remove = []
927
+ for sub_graph in sub_graphs_to_remove:
928
+ logger.info(
929
+ "Remove Reshape sub-graph:\n{}".format(
930
+ "".join(["(#{:>4})\n{}".format(i, predict_net.op[i]) for i in sub_graph])
931
+ )
932
+ )
933
+ reshape_op_id = sub_graph[-1]
934
+ new_reshap_output = predict_net.op[reshape_op_id].input[0]
935
+ rename_op_output(predict_net, reshape_op_id, 0, new_reshap_output)
936
+ ext_inputs, ext_outputs = get_sub_graph_external_input_output(predict_net, sub_graph)
937
+ non_params_ext_inputs = [inp for inp in ext_inputs if inp[1] != 0]
938
+ params_ext_inputs = [inp for inp in ext_inputs if inp[1] == 0]
939
+ assert len(non_params_ext_inputs) == 1 and len(ext_outputs) == 1
940
+ assert ext_outputs[0][0] == non_params_ext_inputs[0][0]
941
+ assert ext_outputs[0][1] == non_params_ext_inputs[0][1] + 1
942
+ remove_op_ids.extend(sub_graph)
943
+ params_to_remove.extend(params_ext_inputs)
944
+
945
+ predict_net = copy.deepcopy(predict_net)
946
+ new_ops = [op for i, op in enumerate(predict_net.op) if i not in remove_op_ids]
947
+ del predict_net.op[:]
948
+ predict_net.op.extend(new_ops)
949
+ for versioned_params in params_to_remove:
950
+ name = versioned_params[0]
951
+ logger.info("Remove params: {} from init_net and predict_net.external_input".format(name))
952
+ del params[name]
953
+ predict_net.external_input.remove(name)
954
+
955
+ return predict_net, params
956
+
957
+
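+ # NOTE: usage sketch (for illustration only); this pass runs on the exported protobufs
+ # together with the weights extracted from init_net, roughly:
+ #
+ #     params, device_options = get_params_from_init_net(init_net)
+ #     predict_net, params = remove_reshape_for_fc(predict_net, params)
+ #
+ # after which init_net can be re-built from the pruned params via
+ # construct_init_net_from_params above.
+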
958
+ def fuse_copy_between_cpu_and_gpu(predict_net: caffe2_pb2.NetDef):
959
+ """
960
+ In-place fuse extra copy ops between cpu/gpu for the following case:
961
+ a -CopyAToB-> b -CopyBToA> c1 -NextOp1-> d1
962
+ -CopyBToA> c2 -NextOp2-> d2
963
+ The fused network will look like:
964
+ a -NextOp1-> d1
965
+ -NextOp2-> d2
966
+ """
967
+
968
+ _COPY_OPS = ["CopyCPUToGPU", "CopyGPUToCPU"]
969
+
970
+ def _fuse_once(predict_net):
971
+ ssa, blob_versions = core.get_ssa(predict_net)
972
+ consumer_map = get_consumer_map(ssa)
973
+ versioned_external_output = [
974
+ (name, blob_versions[name]) for name in predict_net.external_output
975
+ ]
976
+
977
+ for op_id, op in enumerate(predict_net.op):
978
+ if op.type in _COPY_OPS:
979
+ fw_copy_versioned_output = ssa[op_id][1][0]
980
+ consumer_ids = [x[0] for x in consumer_map[fw_copy_versioned_output]]
981
+ reverse_op_type = _COPY_OPS[1 - _COPY_OPS.index(op.type)]
982
+
983
+ is_fusable = (
984
+ len(consumer_ids) > 0
985
+ and fw_copy_versioned_output not in versioned_external_output
986
+ and all(
987
+ predict_net.op[_op_id].type == reverse_op_type
988
+ and ssa[_op_id][1][0] not in versioned_external_output
989
+ for _op_id in consumer_ids
990
+ )
991
+ )
992
+
993
+ if is_fusable:
994
+ for rv_copy_op_id in consumer_ids:
995
+ # make each NextOp use "a" directly and remove the Copy ops
996
+ rs_copy_versioned_output = ssa[rv_copy_op_id][1][0]
997
+ next_op_id, inp_id = consumer_map[rs_copy_versioned_output][0]
998
+ predict_net.op[next_op_id].input[inp_id] = op.input[0]
999
+ # remove CopyOps
1000
+ new_ops = [
1001
+ op
1002
+ for i, op in enumerate(predict_net.op)
1003
+ if i != op_id and i not in consumer_ids
1004
+ ]
1005
+ del predict_net.op[:]
1006
+ predict_net.op.extend(new_ops)
1007
+ return True
1008
+
1009
+ return False
1010
+
1011
+ # _fuse_once returns False if nothing can be fused
1012
+ while _fuse_once(predict_net):
1013
+ pass
1014
+
1015
+
1016
+ def remove_dead_end_ops(net_def: caffe2_pb2.NetDef):
1017
+ """remove ops if its output is not used or not in external_output"""
1018
+ ssa, versions = core.get_ssa(net_def)
1019
+ versioned_external_output = [(name, versions[name]) for name in net_def.external_output]
1020
+ consumer_map = get_consumer_map(ssa)
1021
+ removed_op_ids = set()
1022
+
1023
+ def _is_dead_end(versioned_blob):
1024
+ return not (
1025
+ versioned_blob in versioned_external_output
1026
+ or (
1027
+ len(consumer_map[versioned_blob]) > 0
1028
+ and all(x[0] not in removed_op_ids for x in consumer_map[versioned_blob])
1029
+ )
1030
+ )
1031
+
1032
+ for i, ssa_i in reversed(list(enumerate(ssa))):
1033
+ versioned_outputs = ssa_i[1]
1034
+ if all(_is_dead_end(outp) for outp in versioned_outputs):
1035
+ removed_op_ids.add(i)
1036
+
1037
+ # simply removing those dead-end ops should have no effect on external_output
1038
+ new_ops = [op for i, op in enumerate(net_def.op) if i not in removed_op_ids]
1039
+ del net_def.op[:]
1040
+ net_def.op.extend(new_ops)
CatVTON/detectron2/export/torchscript.py ADDED
@@ -0,0 +1,132 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ import os
4
+ import torch
5
+
6
+ from detectron2.utils.file_io import PathManager
7
+
8
+ from .torchscript_patch import freeze_training_mode, patch_instances
9
+
10
+ __all__ = ["scripting_with_instances", "dump_torchscript_IR"]
11
+
12
+
13
+ def scripting_with_instances(model, fields):
14
+ """
15
+ Run :func:`torch.jit.script` on a model that uses the :class:`Instances` class. Since
16
+ attributes of :class:`Instances` are "dynamically" added in eager mode, it is difficult
17
+ for scripting to support it out of the box. This function is made to support scripting
18
+ a model that uses :class:`Instances`. It does the following:
19
+
20
+ 1. Create a scriptable ``new_Instances`` class which behaves similarly to ``Instances``,
21
+ but with all attributes being "static".
22
+ The attributes need to be statically declared in the ``fields`` argument.
23
+ 2. Register ``new_Instances``, and force scripting compiler to
24
+ use it when trying to compile ``Instances``.
25
+
26
+ After this function returns, the patching is reverted. The user should be able to script another model
27
+ using different fields.
28
+
29
+ Example:
30
+ Assume that ``Instances`` in the model consist of two attributes named
31
+ ``proposal_boxes`` and ``objectness_logits`` with type :class:`Boxes` and
32
+ :class:`Tensor` respectively during inference. You can call this function like:
33
+ ::
34
+ fields = {"proposal_boxes": Boxes, "objectness_logits": torch.Tensor}
35
+ torchscript_model = scripting_with_instances(model, fields)
36
+
37
+ Note:
38
+ It only supports models in evaluation mode.
39
+
40
+ Args:
41
+ model (nn.Module): The input model to be exported by scripting.
42
+ fields (Dict[str, type]): Attribute names and corresponding type that
43
+ ``Instances`` will use in the model. Note that all attributes used in ``Instances``
44
+ need to be added, regardless of whether they are inputs/outputs of the model.
45
+ Data type not defined in detectron2 is not supported for now.
46
+
47
+ Returns:
48
+ torch.jit.ScriptModule: the model in torchscript format
49
+ """
50
+ assert (
51
+ not model.training
52
+ ), "Currently we only support exporting models in evaluation mode to torchscript"
53
+
54
+ with freeze_training_mode(model), patch_instances(fields):
55
+ scripted_model = torch.jit.script(model)
56
+ return scripted_model
57
+
58
+
59
+ # alias for old name
60
+ export_torchscript_with_instances = scripting_with_instances
61
+
62
+
63
+ def dump_torchscript_IR(model, dir):
64
+ """
65
+ Dump IR of a TracedModule/ScriptModule/ScriptFunction in various formats (code, graph,
66
+ inlined graph). Useful for debugging.
67
+
68
+ Args:
69
+ model (TracedModule/ScriptModule/ScriptFunction): traced or scripted module
70
+ dir (str): output directory to dump files.
71
+ """
72
+ dir = os.path.expanduser(dir)
73
+ PathManager.mkdirs(dir)
74
+
75
+ def _get_script_mod(mod):
76
+ if isinstance(mod, torch.jit.TracedModule):
77
+ return mod._actual_script_module
78
+ return mod
79
+
80
+ # Dump pretty-printed code: https://pytorch.org/docs/stable/jit.html#inspecting-code
81
+ with PathManager.open(os.path.join(dir, "model_ts_code.txt"), "w") as f:
82
+
83
+ def get_code(mod):
84
+ # Try a few ways to get code using private attributes.
85
+ try:
86
+ # This contains more information than just `mod.code`
87
+ return _get_script_mod(mod)._c.code
88
+ except AttributeError:
89
+ pass
90
+ try:
91
+ return mod.code
92
+ except AttributeError:
93
+ return None
94
+
95
+ def dump_code(prefix, mod):
96
+ code = get_code(mod)
97
+ name = prefix or "root model"
98
+ if code is None:
99
+ f.write(f"Could not found code for {name} (type={mod.original_name})\n")
100
+ f.write("\n")
101
+ else:
102
+ f.write(f"\nCode for {name}, type={mod.original_name}:\n")
103
+ f.write(code)
104
+ f.write("\n")
105
+ f.write("-" * 80)
106
+
107
+ for name, m in mod.named_children():
108
+ dump_code(prefix + "." + name, m)
109
+
110
+ if isinstance(model, torch.jit.ScriptFunction):
111
+ f.write(get_code(model))
112
+ else:
113
+ dump_code("", model)
114
+
115
+ def _get_graph(model):
116
+ try:
117
+ # Recursively dump IR of all modules
118
+ return _get_script_mod(model)._c.dump_to_str(True, False, False)
119
+ except AttributeError:
120
+ return model.graph.str()
121
+
122
+ with PathManager.open(os.path.join(dir, "model_ts_IR.txt"), "w") as f:
123
+ f.write(_get_graph(model))
124
+
125
+ # Dump IR of the entire graph (all submodules inlined)
126
+ with PathManager.open(os.path.join(dir, "model_ts_IR_inlined.txt"), "w") as f:
127
+ f.write(str(model.inlined_graph))
128
+
129
+ if not isinstance(model, torch.jit.ScriptFunction):
130
+ # Dump the model structure in pytorch style
131
+ with PathManager.open(os.path.join(dir, "model.txt"), "w") as f:
132
+ f.write(str(model))
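+
+ # NOTE: usage sketch (hypothetical field dict, for illustration only); this writes
+ # model_ts_code.txt, model_ts_IR.txt, model_ts_IR_inlined.txt and model.txt into `dir`:
+ #
+ #     ts_model = scripting_with_instances(model, {"scores": torch.Tensor, "pred_classes": torch.Tensor})
+ #     dump_torchscript_IR(ts_model, "./torchscript_dump")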
CatVTON/detectron2/export/torchscript_patch.py ADDED
@@ -0,0 +1,406 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ import os
4
+ import sys
5
+ import tempfile
6
+ from contextlib import ExitStack, contextmanager
7
+ from copy import deepcopy
8
+ from unittest import mock
9
+ import torch
10
+ from torch import nn
11
+
12
+ # need some explicit imports due to https://github.com/pytorch/pytorch/issues/38964
13
+ import detectron2 # noqa F401
14
+ from detectron2.structures import Boxes, Instances
15
+ from detectron2.utils.env import _import_file
16
+
17
+ _counter = 0
18
+
19
+
20
+ def _clear_jit_cache():
21
+ from torch.jit._recursive import concrete_type_store
22
+ from torch.jit._state import _jit_caching_layer
23
+
24
+ concrete_type_store.type_store.clear() # for modules
25
+ _jit_caching_layer.clear() # for free functions
26
+
27
+
28
+ def _add_instances_conversion_methods(newInstances):
29
+ """
30
+ Add a from_instances method to the scripted Instances class.
31
+ """
32
+ cls_name = newInstances.__name__
33
+
34
+ @torch.jit.unused
35
+ def from_instances(instances: Instances):
36
+ """
37
+ Create scripted Instances from original Instances
38
+ """
39
+ fields = instances.get_fields()
40
+ image_size = instances.image_size
41
+ ret = newInstances(image_size)
42
+ for name, val in fields.items():
43
+ assert hasattr(ret, f"_{name}"), f"No attribute named {name} in {cls_name}"
44
+ setattr(ret, name, deepcopy(val))
45
+ return ret
46
+
47
+ newInstances.from_instances = from_instances
48
+
49
+
50
+ @contextmanager
51
+ def patch_instances(fields):
52
+ """
53
+ A contextmanager, under which the Instances class in detectron2 is replaced
54
+ by a statically-typed scriptable class, defined by `fields`.
55
+ See more in `scripting_with_instances`.
56
+ """
57
+
58
+ with tempfile.TemporaryDirectory(prefix="detectron2") as dir, tempfile.NamedTemporaryFile(
59
+ mode="w", encoding="utf-8", suffix=".py", dir=dir, delete=False
60
+ ) as f:
61
+ try:
62
+ # Objects that use Instances should not reuse previously-compiled
63
+ # results in cache, because `Instances` could be a new class each time.
64
+ _clear_jit_cache()
65
+
66
+ cls_name, s = _gen_instance_module(fields)
67
+ f.write(s)
68
+ f.flush()
69
+ f.close()
70
+
71
+ module = _import(f.name)
72
+ new_instances = getattr(module, cls_name)
73
+ _ = torch.jit.script(new_instances)
74
+ # let torchscript think Instances was scripted already
75
+ Instances.__torch_script_class__ = True
76
+ # let torchscript find new_instances when looking for the jit type of Instances
77
+ Instances._jit_override_qualname = torch._jit_internal._qualified_name(new_instances)
78
+
79
+ _add_instances_conversion_methods(new_instances)
80
+ yield new_instances
81
+ finally:
82
+ try:
83
+ del Instances.__torch_script_class__
84
+ del Instances._jit_override_qualname
85
+ except AttributeError:
86
+ pass
87
+ sys.modules.pop(module.__name__)
88
+
89
+
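+ # NOTE: patch_instances is what `scripting_with_instances` (torchscript.py) uses under
+ # the hood, roughly:
+ #
+ #     with freeze_training_mode(model), patch_instances({"scores": torch.Tensor}):
+ #         scripted = torch.jit.script(model)
+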
90
+ def _gen_instance_class(fields):
91
+ """
92
+ Args:
93
+ fields (dict[name: type])
94
+ """
95
+
96
+ class _FieldType:
97
+ def __init__(self, name, type_):
98
+ assert isinstance(name, str), f"Field name must be str, got {name}"
99
+ self.name = name
100
+ self.type_ = type_
101
+ self.annotation = f"{type_.__module__}.{type_.__name__}"
102
+
103
+ fields = [_FieldType(k, v) for k, v in fields.items()]
104
+
105
+ def indent(level, s):
106
+ return " " * 4 * level + s
107
+
108
+ lines = []
109
+
110
+ global _counter
111
+ _counter += 1
112
+
113
+ cls_name = "ScriptedInstances{}".format(_counter)
114
+
115
+ field_names = tuple(x.name for x in fields)
116
+ extra_args = ", ".join([f"{f.name}: Optional[{f.annotation}] = None" for f in fields])
117
+ lines.append(
118
+ f"""
119
+ class {cls_name}:
120
+ def __init__(self, image_size: Tuple[int, int], {extra_args}):
121
+ self.image_size = image_size
122
+ self._field_names = {field_names}
123
+ """
124
+ )
125
+
126
+ for f in fields:
127
+ lines.append(
128
+ indent(2, f"self._{f.name} = torch.jit.annotate(Optional[{f.annotation}], {f.name})")
129
+ )
130
+
131
+ for f in fields:
132
+ lines.append(
133
+ f"""
134
+ @property
135
+ def {f.name}(self) -> {f.annotation}:
136
+ # has to use a local for type refinement
137
+ # https://pytorch.org/docs/stable/jit_language_reference.html#optional-type-refinement
138
+ t = self._{f.name}
139
+ assert t is not None, "{f.name} is None and cannot be accessed!"
140
+ return t
141
+
142
+ @{f.name}.setter
143
+ def {f.name}(self, value: {f.annotation}) -> None:
144
+ self._{f.name} = value
145
+ """
146
+ )
147
+
148
+ # support method `__len__`
149
+ lines.append(
150
+ """
151
+ def __len__(self) -> int:
152
+ """
153
+ )
154
+ for f in fields:
155
+ lines.append(
156
+ f"""
157
+ t = self._{f.name}
158
+ if t is not None:
159
+ return len(t)
160
+ """
161
+ )
162
+ lines.append(
163
+ """
164
+ raise NotImplementedError("Empty Instances does not support __len__!")
165
+ """
166
+ )
167
+
168
+ # support method `has`
169
+ lines.append(
170
+ """
171
+ def has(self, name: str) -> bool:
172
+ """
173
+ )
174
+ for f in fields:
175
+ lines.append(
176
+ f"""
177
+ if name == "{f.name}":
178
+ return self._{f.name} is not None
179
+ """
180
+ )
181
+ lines.append(
182
+ """
183
+ return False
184
+ """
185
+ )
186
+
187
+ # support method `to`
188
+ none_args = ", None" * len(fields)
189
+ lines.append(
190
+ f"""
191
+ def to(self, device: torch.device) -> "{cls_name}":
192
+ ret = {cls_name}(self.image_size{none_args})
193
+ """
194
+ )
195
+ for f in fields:
196
+ if hasattr(f.type_, "to"):
197
+ lines.append(
198
+ f"""
199
+ t = self._{f.name}
200
+ if t is not None:
201
+ ret._{f.name} = t.to(device)
202
+ """
203
+ )
204
+ else:
205
+ # For now, ignore fields that cannot be moved to devices.
206
+ # Maybe can support other tensor-like classes (e.g. __torch_function__)
207
+ pass
208
+ lines.append(
209
+ """
210
+ return ret
211
+ """
212
+ )
213
+
214
+ # support method `getitem`
215
+ none_args = ", None" * len(fields)
216
+ lines.append(
217
+ f"""
218
+ def __getitem__(self, item) -> "{cls_name}":
219
+ ret = {cls_name}(self.image_size{none_args})
220
+ """
221
+ )
222
+ for f in fields:
223
+ lines.append(
224
+ f"""
225
+ t = self._{f.name}
226
+ if t is not None:
227
+ ret._{f.name} = t[item]
228
+ """
229
+ )
230
+ lines.append(
231
+ """
232
+ return ret
233
+ """
234
+ )
235
+
236
+ # support method `cat`
237
+ # this version does not contain checks that all instances have same size and fields
238
+ none_args = ", None" * len(fields)
239
+ lines.append(
240
+ f"""
241
+ def cat(self, instances: List["{cls_name}"]) -> "{cls_name}":
242
+ ret = {cls_name}(self.image_size{none_args})
243
+ """
244
+ )
245
+ for f in fields:
246
+ lines.append(
247
+ f"""
248
+ t = self._{f.name}
249
+ if t is not None:
250
+ values: List[{f.annotation}] = [x.{f.name} for x in instances]
251
+ if torch.jit.isinstance(t, torch.Tensor):
252
+ ret._{f.name} = torch.cat(values, dim=0)
253
+ else:
254
+ ret._{f.name} = t.cat(values)
255
+ """
256
+ )
257
+ lines.append(
258
+ """
259
+ return ret"""
260
+ )
261
+
262
+ # support method `get_fields()`
263
+ lines.append(
264
+ """
265
+ def get_fields(self) -> Dict[str, Tensor]:
266
+ ret = {}
267
+ """
268
+ )
269
+ for f in fields:
270
+ if f.type_ == Boxes:
271
+ stmt = "t.tensor"
272
+ elif f.type_ == torch.Tensor:
273
+ stmt = "t"
274
+ else:
275
+ stmt = f'assert False, "unsupported type {str(f.type_)}"'
276
+ lines.append(
277
+ f"""
278
+ t = self._{f.name}
279
+ if t is not None:
280
+ ret["{f.name}"] = {stmt}
281
+ """
282
+ )
283
+ lines.append(
284
+ """
285
+ return ret"""
286
+ )
287
+ return cls_name, os.linesep.join(lines)
288
+
289
+
290
+ def _gen_instance_module(fields):
291
+ # TODO: find a more automatic way to enable import of other classes
292
+ s = """
293
+ from copy import deepcopy
294
+ import torch
295
+ from torch import Tensor
296
+ import typing
297
+ from typing import *
298
+
299
+ import detectron2
300
+ from detectron2.structures import Boxes, Instances
301
+
302
+ """
303
+
304
+ cls_name, cls_def = _gen_instance_class(fields)
305
+ s += cls_def
306
+ return cls_name, s
307
+
308
+
309
+ def _import(path):
310
+ return _import_file(
311
+ "{}{}".format(sys.modules[__name__].__name__, _counter), path, make_importable=True
312
+ )
313
+
314
+
315
+ @contextmanager
316
+ def patch_builtin_len(modules=()):
317
+ """
318
+ Patch the builtin len() function of a few detectron2 modules
319
+ to use __len__ instead, because __len__ does not convert values to
320
+ integers and therefore is friendly to tracing.
321
+
322
+ Args:
323
+ modules (list[str]): names of extra modules to patch len(), in
324
+ addition to those in detectron2.
325
+ """
326
+
327
+ def _new_len(obj):
328
+ return obj.__len__()
329
+
330
+ with ExitStack() as stack:
331
+ MODULES = [
332
+ "detectron2.modeling.roi_heads.fast_rcnn",
333
+ "detectron2.modeling.roi_heads.mask_head",
334
+ "detectron2.modeling.roi_heads.keypoint_head",
335
+ ] + list(modules)
336
+ ctxs = [stack.enter_context(mock.patch(mod + ".len")) for mod in MODULES]
337
+ for m in ctxs:
338
+ m.side_effect = _new_len
339
+ yield
340
+
341
+
342
+ def patch_nonscriptable_classes():
343
+ """
344
+ Apply patches on a few nonscriptable detectron2 classes.
345
+ Should not have side-effects on eager usage.
346
+ """
347
+ # __prepare_scriptable__ can also be added to models for easier maintenance.
348
+ # But it complicates the clean model code.
349
+
350
+ from detectron2.modeling.backbone import ResNet, FPN
351
+
352
+ # Due to https://github.com/pytorch/pytorch/issues/36061,
353
+ # we change backbone to use ModuleList for scripting.
354
+ # (note: this changes param names in state_dict)
355
+
356
+ def prepare_resnet(self):
357
+ ret = deepcopy(self)
358
+ ret.stages = nn.ModuleList(ret.stages)
359
+ for k in self.stage_names:
360
+ delattr(ret, k)
361
+ return ret
362
+
363
+ ResNet.__prepare_scriptable__ = prepare_resnet
364
+
365
+ def prepare_fpn(self):
366
+ ret = deepcopy(self)
367
+ ret.lateral_convs = nn.ModuleList(ret.lateral_convs)
368
+ ret.output_convs = nn.ModuleList(ret.output_convs)
369
+ for name, _ in self.named_children():
370
+ if name.startswith("fpn_"):
371
+ delattr(ret, name)
372
+ return ret
373
+
374
+ FPN.__prepare_scriptable__ = prepare_fpn
375
+
376
+ # Annotate some attributes to be constants for the purpose of scripting,
377
+ # even though they are not constants in eager mode.
378
+ from detectron2.modeling.roi_heads import StandardROIHeads
379
+
380
+ if hasattr(StandardROIHeads, "__annotations__"):
381
+ # copy first to avoid editing annotations of base class
382
+ StandardROIHeads.__annotations__ = deepcopy(StandardROIHeads.__annotations__)
383
+ StandardROIHeads.__annotations__["mask_on"] = torch.jit.Final[bool]
384
+ StandardROIHeads.__annotations__["keypoint_on"] = torch.jit.Final[bool]
385
+
386
+
387
+ # These patches are not supposed to have side-effects.
388
+ patch_nonscriptable_classes()
389
+
390
+
391
+ @contextmanager
392
+ def freeze_training_mode(model):
393
+ """
394
+ A context manager that annotates the "training" attribute of every submodule
395
+ to constant, so that the training codepath in these modules can be
396
+ meta-compiled away. Upon exiting, the annotations are reverted.
397
+ """
398
+ classes = {type(x) for x in model.modules()}
399
+ # __constants__ is the old way to annotate constants and not compatible
400
+ # with __annotations__ .
401
+ classes = {x for x in classes if not hasattr(x, "__constants__")}
402
+ for cls in classes:
403
+ cls.__annotations__["training"] = torch.jit.Final[bool]
404
+ yield
405
+ for cls in classes:
406
+ cls.__annotations__["training"] = bool