Add files using upload-large-folder tool

- CatVTON/detectron2/data/__init__.py +19 -0
- CatVTON/detectron2/data/benchmark.py +225 -0
- CatVTON/detectron2/data/build.py +694 -0
- CatVTON/detectron2/data/catalog.py +236 -0
- CatVTON/detectron2/data/common.py +339 -0
- CatVTON/detectron2/data/dataset_mapper.py +191 -0
- CatVTON/detectron2/data/datasets/README.md +9 -0
- CatVTON/detectron2/data/datasets/__init__.py +9 -0
- CatVTON/detectron2/data/datasets/builtin.py +259 -0
- CatVTON/detectron2/data/datasets/builtin_meta.py +350 -0
- CatVTON/detectron2/data/datasets/cityscapes.py +334 -0
- CatVTON/detectron2/data/datasets/cityscapes_panoptic.py +187 -0
- CatVTON/detectron2/data/datasets/coco.py +555 -0
- CatVTON/detectron2/data/datasets/coco_panoptic.py +228 -0
- CatVTON/detectron2/data/datasets/lvis.py +250 -0
- CatVTON/detectron2/data/datasets/lvis_v0_5_categories.py +0 -0
- CatVTON/detectron2/data/datasets/lvis_v1_categories.py +0 -0
- CatVTON/detectron2/data/datasets/lvis_v1_category_image_count.py +20 -0
- CatVTON/detectron2/data/datasets/pascal_voc.py +82 -0
- CatVTON/detectron2/data/datasets/register_coco.py +3 -0
- CatVTON/detectron2/data/detection_utils.py +661 -0
- CatVTON/detectron2/data/samplers/__init__.py +17 -0
- CatVTON/detectron2/data/samplers/distributed_sampler.py +287 -0
- CatVTON/detectron2/data/samplers/grouped_batch_sampler.py +47 -0
- CatVTON/detectron2/data/transforms/__init__.py +14 -0
- CatVTON/detectron2/data/transforms/augmentation.py +380 -0
- CatVTON/detectron2/data/transforms/augmentation_impl.py +736 -0
- CatVTON/detectron2/data/transforms/transform.py +351 -0
- CatVTON/detectron2/export/README.md +15 -0
- CatVTON/detectron2/export/__init__.py +30 -0
- CatVTON/detectron2/export/api.py +230 -0
- CatVTON/detectron2/export/c10.py +571 -0
- CatVTON/detectron2/export/caffe2_export.py +203 -0
- CatVTON/detectron2/export/caffe2_inference.py +161 -0
- CatVTON/detectron2/export/caffe2_modeling.py +420 -0
- CatVTON/detectron2/export/caffe2_patch.py +189 -0
- CatVTON/detectron2/export/flatten.py +330 -0
- CatVTON/detectron2/export/shared.py +1040 -0
- CatVTON/detectron2/export/torchscript.py +132 -0
- CatVTON/detectron2/export/torchscript_patch.py +406 -0
CatVTON/detectron2/data/__init__.py
ADDED
@@ -0,0 +1,19 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from . import transforms  # isort:skip

from .build import (
    build_batch_data_loader,
    build_detection_test_loader,
    build_detection_train_loader,
    get_detection_dataset_dicts,
    load_proposals_into_dataset,
    print_instances_class_histogram,
)
from .catalog import DatasetCatalog, MetadataCatalog, Metadata
from .common import DatasetFromList, MapDataset, ToIterableDataset
from .dataset_mapper import DatasetMapper

# ensure the builtin datasets are registered
from . import datasets, samplers  # isort:skip

__all__ = [k for k in globals().keys() if not k.startswith("_")]
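The re-exports above let downstream code import the data API from the package root instead of individual submodules. A minimal sketch of that import surface, assuming detectron2 is installed and a dataset named "my_dataset" (a hypothetical name used only for illustration) has already been registered:

from detectron2.data import (
    DatasetCatalog,
    MetadataCatalog,
    DatasetMapper,
    build_detection_train_loader,
)

# look up a registered dataset and its metadata by name
dicts = DatasetCatalog.get("my_dataset")   # list[dict] in Detectron2 Dataset format
meta = MetadataCatalog.get("my_dataset")   # Metadata singleton for that dataset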
CatVTON/detectron2/data/benchmark.py
ADDED
@@ -0,0 +1,225 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import numpy as np
from itertools import count
from typing import List, Tuple
import torch
import tqdm
from fvcore.common.timer import Timer

from detectron2.utils import comm

from .build import build_batch_data_loader
from .common import DatasetFromList, MapDataset
from .samplers import TrainingSampler

logger = logging.getLogger(__name__)


class _EmptyMapDataset(torch.utils.data.Dataset):
    """
    Map anything to emptiness.
    """

    def __init__(self, dataset):
        self.ds = dataset

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        _ = self.ds[idx]
        return [0]


def iter_benchmark(
    iterator, num_iter: int, warmup: int = 5, max_time_seconds: float = 60
) -> Tuple[float, List[float]]:
    """
    Benchmark an iterator/iterable for `num_iter` iterations with an extra
    `warmup` iterations of warmup.
    End early if `max_time_seconds` time is spent on iterations.

    Returns:
        float: average time (seconds) per iteration
        list[float]: time spent on each iteration. Sometimes useful for further analysis.
    """
    num_iter, warmup = int(num_iter), int(warmup)

    iterator = iter(iterator)
    for _ in range(warmup):
        next(iterator)
    timer = Timer()
    all_times = []
    for curr_iter in tqdm.trange(num_iter):
        start = timer.seconds()
        if start > max_time_seconds:
            num_iter = curr_iter
            break
        next(iterator)
        all_times.append(timer.seconds() - start)
    avg = timer.seconds() / num_iter
    return avg, all_times


class DataLoaderBenchmark:
    """
    Some common benchmarks that help understand perf bottleneck of a standard dataloader
    made of dataset, mapper and sampler.
    """

    def __init__(
        self,
        dataset,
        *,
        mapper,
        sampler=None,
        total_batch_size,
        num_workers=0,
        max_time_seconds: int = 90,
    ):
        """
        Args:
            max_time_seconds (int): maximum time to spent for each benchmark
            other args: same as in `build.py:build_detection_train_loader`
        """
        if isinstance(dataset, list):
            dataset = DatasetFromList(dataset, copy=False, serialize=True)
        if sampler is None:
            sampler = TrainingSampler(len(dataset))

        self.dataset = dataset
        self.mapper = mapper
        self.sampler = sampler
        self.total_batch_size = total_batch_size
        self.num_workers = num_workers
        self.per_gpu_batch_size = self.total_batch_size // comm.get_world_size()

        self.max_time_seconds = max_time_seconds

    def _benchmark(self, iterator, num_iter, warmup, msg=None):
        avg, all_times = iter_benchmark(iterator, num_iter, warmup, self.max_time_seconds)
        if msg is not None:
            self._log_time(msg, avg, all_times)
        return avg, all_times

    def _log_time(self, msg, avg, all_times, distributed=False):
        percentiles = [np.percentile(all_times, k, interpolation="nearest") for k in [1, 5, 95, 99]]
        if not distributed:
            logger.info(
                f"{msg}: avg={1.0/avg:.1f} it/s, "
                f"p1={percentiles[0]:.2g}s, p5={percentiles[1]:.2g}s, "
                f"p95={percentiles[2]:.2g}s, p99={percentiles[3]:.2g}s."
            )
            return
        avg_per_gpu = comm.all_gather(avg)
        percentiles_per_gpu = comm.all_gather(percentiles)
        if comm.get_rank() > 0:
            return
        for idx, avg, percentiles in zip(count(), avg_per_gpu, percentiles_per_gpu):
            logger.info(
                f"GPU{idx} {msg}: avg={1.0/avg:.1f} it/s, "
                f"p1={percentiles[0]:.2g}s, p5={percentiles[1]:.2g}s, "
                f"p95={percentiles[2]:.2g}s, p99={percentiles[3]:.2g}s."
            )

    def benchmark_dataset(self, num_iter, warmup=5):
        """
        Benchmark the speed of taking raw samples from the dataset.
        """

        def loader():
            while True:
                for k in self.sampler:
                    yield self.dataset[k]

        self._benchmark(loader(), num_iter, warmup, "Dataset Alone")

    def benchmark_mapper(self, num_iter, warmup=5):
        """
        Benchmark the speed of taking raw samples from the dataset and map
        them in a single process.
        """

        def loader():
            while True:
                for k in self.sampler:
                    yield self.mapper(self.dataset[k])

        self._benchmark(loader(), num_iter, warmup, "Single Process Mapper (sec/sample)")

    def benchmark_workers(self, num_iter, warmup=10):
        """
        Benchmark the dataloader by tuning num_workers to [0, 1, self.num_workers].
        """
        candidates = [0, 1]
        if self.num_workers not in candidates:
            candidates.append(self.num_workers)

        dataset = MapDataset(self.dataset, self.mapper)
        for n in candidates:
            loader = build_batch_data_loader(
                dataset,
                self.sampler,
                self.total_batch_size,
                num_workers=n,
            )
            self._benchmark(
                iter(loader),
                num_iter * max(n, 1),
                warmup * max(n, 1),
                f"DataLoader ({n} workers, bs={self.per_gpu_batch_size})",
            )
            del loader

    def benchmark_IPC(self, num_iter, warmup=10):
        """
        Benchmark the dataloader where each worker outputs nothing. This
        eliminates the IPC overhead compared to the regular dataloader.

        PyTorch multiprocessing's IPC only optimizes for torch tensors.
        Large numpy arrays or other data structure may incur large IPC overhead.
        """
        n = self.num_workers
        dataset = _EmptyMapDataset(MapDataset(self.dataset, self.mapper))
        loader = build_batch_data_loader(
            dataset, self.sampler, self.total_batch_size, num_workers=n
        )
        self._benchmark(
            iter(loader),
            num_iter * max(n, 1),
            warmup * max(n, 1),
            f"DataLoader ({n} workers, bs={self.per_gpu_batch_size}) w/o comm",
        )

    def benchmark_distributed(self, num_iter, warmup=10):
        """
        Benchmark the dataloader in each distributed worker, and log results of
        all workers. This helps understand the final performance as well as
        the variances among workers.

        It also prints startup time (first iter) of the dataloader.
        """
        gpu = comm.get_world_size()
        dataset = MapDataset(self.dataset, self.mapper)
        n = self.num_workers
        loader = build_batch_data_loader(
            dataset, self.sampler, self.total_batch_size, num_workers=n
        )

        timer = Timer()
        loader = iter(loader)
        next(loader)
        startup_time = timer.seconds()
        logger.info("Dataloader startup time: {:.2f} seconds".format(startup_time))

        comm.synchronize()

        avg, all_times = self._benchmark(loader, num_iter * max(n, 1), warmup * max(n, 1))
        del loader
        self._log_time(
            f"DataLoader ({gpu} GPUs x {n} workers, total bs={self.total_batch_size})",
            avg,
            all_times,
            True,
        )
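A minimal sketch of how DataLoaderBenchmark might be driven, assuming detectron2 is installed, a dataset named "my_dataset" (hypothetical) is registered in the DatasetCatalog, and an existing detectron2 CfgNode `cfg` is available to build the mapper:

from detectron2.data import DatasetCatalog, DatasetMapper
from detectron2.data.benchmark import DataLoaderBenchmark

dataset_dicts = DatasetCatalog.get("my_dataset")  # list[dict] in Detectron2 Dataset format
mapper = DatasetMapper(cfg, True)                 # same mapper the train loader would use

bench = DataLoaderBenchmark(
    dataset_dicts,
    mapper=mapper,
    total_batch_size=16,
    num_workers=4,
)
bench.benchmark_dataset(100)   # raw sample reading speed
bench.benchmark_mapper(100)    # reading + mapping in a single process
bench.benchmark_workers(100)   # full DataLoader with 0, 1, and num_workers workers

Comparing the three numbers points at the bottleneck: if the mapper benchmark is much slower than the dataset benchmark, preprocessing dominates; if the multi-worker loader does not scale, IPC or worker startup is the limit.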
CatVTON/detectron2/data/build.py
ADDED
@@ -0,0 +1,694 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import itertools
import logging
import numpy as np
import operator
import pickle
from collections import OrderedDict, defaultdict
from typing import Any, Callable, Dict, List, Optional, Union
import torch
import torch.utils.data as torchdata
from tabulate import tabulate
from termcolor import colored

from detectron2.config import configurable
from detectron2.structures import BoxMode
from detectron2.utils.comm import get_world_size
from detectron2.utils.env import seed_all_rng
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import _log_api_usage, log_first_n

from .catalog import DatasetCatalog, MetadataCatalog
from .common import AspectRatioGroupedDataset, DatasetFromList, MapDataset, ToIterableDataset
from .dataset_mapper import DatasetMapper
from .detection_utils import check_metadata_consistency
from .samplers import (
    InferenceSampler,
    RandomSubsetTrainingSampler,
    RepeatFactorTrainingSampler,
    TrainingSampler,
)

"""
This file contains the default logic to build a dataloader for training or testing.
"""

__all__ = [
    "build_batch_data_loader",
    "build_detection_train_loader",
    "build_detection_test_loader",
    "get_detection_dataset_dicts",
    "load_proposals_into_dataset",
    "print_instances_class_histogram",
]


def filter_images_with_only_crowd_annotations(dataset_dicts):
    """
    Filter out images with none annotations or only crowd annotations
    (i.e., images without non-crowd annotations).
    A common training-time preprocessing on COCO dataset.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.

    Returns:
        list[dict]: the same format, but filtered.
    """
    num_before = len(dataset_dicts)

    def valid(anns):
        for ann in anns:
            if ann.get("iscrowd", 0) == 0:
                return True
        return False

    dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"])]
    num_after = len(dataset_dicts)
    logger = logging.getLogger(__name__)
    logger.info(
        "Removed {} images with no usable annotations. {} images left.".format(
            num_before - num_after, num_after
        )
    )
    return dataset_dicts


def filter_images_with_few_keypoints(dataset_dicts, min_keypoints_per_image):
    """
    Filter out images with too few number of keypoints.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.

    Returns:
        list[dict]: the same format as dataset_dicts, but filtered.
    """
    num_before = len(dataset_dicts)

    def visible_keypoints_in_image(dic):
        # Each keypoints field has the format [x1, y1, v1, ...], where v is visibility
        annotations = dic["annotations"]
        return sum(
            (np.array(ann["keypoints"][2::3]) > 0).sum()
            for ann in annotations
            if "keypoints" in ann
        )

    dataset_dicts = [
        x for x in dataset_dicts if visible_keypoints_in_image(x) >= min_keypoints_per_image
    ]
    num_after = len(dataset_dicts)
    logger = logging.getLogger(__name__)
    logger.info(
        "Removed {} images with fewer than {} keypoints.".format(
            num_before - num_after, min_keypoints_per_image
        )
    )
    return dataset_dicts


def load_proposals_into_dataset(dataset_dicts, proposal_file):
    """
    Load precomputed object proposals into the dataset.

    The proposal file should be a pickled dict with the following keys:

    - "ids": list[int] or list[str], the image ids
    - "boxes": list[np.ndarray], each is an Nx4 array of boxes corresponding to the image id
    - "objectness_logits": list[np.ndarray], each is an N sized array of objectness scores
      corresponding to the boxes.
    - "bbox_mode": the BoxMode of the boxes array. Defaults to ``BoxMode.XYXY_ABS``.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
        proposal_file (str): file path of pre-computed proposals, in pkl format.

    Returns:
        list[dict]: the same format as dataset_dicts, but added proposal field.
    """
    logger = logging.getLogger(__name__)
    logger.info("Loading proposals from: {}".format(proposal_file))

    with PathManager.open(proposal_file, "rb") as f:
        proposals = pickle.load(f, encoding="latin1")

    # Rename the key names in D1 proposal files
    rename_keys = {"indexes": "ids", "scores": "objectness_logits"}
    for key in rename_keys:
        if key in proposals:
            proposals[rename_keys[key]] = proposals.pop(key)

    # Fetch the indexes of all proposals that are in the dataset
    # Convert image_id to str since they could be int.
    img_ids = set({str(record["image_id"]) for record in dataset_dicts})
    id_to_index = {str(id): i for i, id in enumerate(proposals["ids"]) if str(id) in img_ids}

    # Assuming default bbox_mode of precomputed proposals are 'XYXY_ABS'
    bbox_mode = BoxMode(proposals["bbox_mode"]) if "bbox_mode" in proposals else BoxMode.XYXY_ABS

    for record in dataset_dicts:
        # Get the index of the proposal
        i = id_to_index[str(record["image_id"])]

        boxes = proposals["boxes"][i]
        objectness_logits = proposals["objectness_logits"][i]
        # Sort the proposals in descending order of the scores
        inds = objectness_logits.argsort()[::-1]
        record["proposal_boxes"] = boxes[inds]
        record["proposal_objectness_logits"] = objectness_logits[inds]
        record["proposal_bbox_mode"] = bbox_mode

    return dataset_dicts


def print_instances_class_histogram(dataset_dicts, class_names):
    """
    Args:
        dataset_dicts (list[dict]): list of dataset dicts.
        class_names (list[str]): list of class names (zero-indexed).
    """
    num_classes = len(class_names)
    hist_bins = np.arange(num_classes + 1)
    histogram = np.zeros((num_classes,), dtype=int)
    for entry in dataset_dicts:
        annos = entry["annotations"]
        classes = np.asarray(
            [x["category_id"] for x in annos if not x.get("iscrowd", 0)], dtype=int
        )
        if len(classes):
            assert classes.min() >= 0, f"Got an invalid category_id={classes.min()}"
            assert (
                classes.max() < num_classes
            ), f"Got an invalid category_id={classes.max()} for a dataset of {num_classes} classes"
        histogram += np.histogram(classes, bins=hist_bins)[0]

    N_COLS = min(6, len(class_names) * 2)

    def short_name(x):
        # make long class names shorter. useful for lvis
        if len(x) > 13:
            return x[:11] + ".."
        return x

    data = list(
        itertools.chain(*[[short_name(class_names[i]), int(v)] for i, v in enumerate(histogram)])
    )
    total_num_instances = sum(data[1::2])
    data.extend([None] * (N_COLS - (len(data) % N_COLS)))
    if num_classes > 1:
        data.extend(["total", total_num_instances])
    data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)])
    table = tabulate(
        data,
        headers=["category", "#instances"] * (N_COLS // 2),
        tablefmt="pipe",
        numalign="left",
        stralign="center",
    )
    log_first_n(
        logging.INFO,
        "Distribution of instances among all {} categories:\n".format(num_classes)
        + colored(table, "cyan"),
        key="message",
    )


def get_detection_dataset_dicts(
    names,
    filter_empty=True,
    min_keypoints=0,
    proposal_files=None,
    check_consistency=True,
):
    """
    Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.

    Args:
        names (str or list[str]): a dataset name or a list of dataset names
        filter_empty (bool): whether to filter out images without instance annotations
        min_keypoints (int): filter out images with fewer keypoints than
            `min_keypoints`. Set to 0 to do nothing.
        proposal_files (list[str]): if given, a list of object proposal files
            that match each dataset in `names`.
        check_consistency (bool): whether to check if datasets have consistent metadata.

    Returns:
        list[dict]: a list of dicts following the standard dataset dict format.
    """
    if isinstance(names, str):
        names = [names]
    assert len(names), names

    available_datasets = DatasetCatalog.keys()
    names_set = set(names)
    if not names_set.issubset(available_datasets):
        logger = logging.getLogger(__name__)
        logger.warning(
            "The following dataset names are not registered in the DatasetCatalog: "
            f"{names_set - available_datasets}. "
            f"Available datasets are {available_datasets}"
        )

    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in names]

    if isinstance(dataset_dicts[0], torchdata.Dataset):
        if len(dataset_dicts) > 1:
            # ConcatDataset does not work for iterable style dataset.
            # We could support concat for iterable as well, but it's often
            # not a good idea to concat iterables anyway.
            return torchdata.ConcatDataset(dataset_dicts)
        return dataset_dicts[0]

    for dataset_name, dicts in zip(names, dataset_dicts):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)

    if proposal_files is not None:
        assert len(names) == len(proposal_files)
        # load precomputed proposals from proposal files
        dataset_dicts = [
            load_proposals_into_dataset(dataset_i_dicts, proposal_file)
            for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
        ]

    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))

    has_instances = "annotations" in dataset_dicts[0]
    if filter_empty and has_instances:
        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
    if min_keypoints > 0 and has_instances:
        dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)

    if check_consistency and has_instances:
        try:
            class_names = MetadataCatalog.get(names[0]).thing_classes
            check_metadata_consistency("thing_classes", names)
            print_instances_class_histogram(dataset_dicts, class_names)
        except AttributeError:  # class names are not available for this dataset
            pass

    assert len(dataset_dicts), "No valid data found in {}.".format(",".join(names))
    return dataset_dicts


def build_batch_data_loader(
    dataset,
    sampler,
    total_batch_size,
    *,
    aspect_ratio_grouping=False,
    num_workers=0,
    collate_fn=None,
    drop_last: bool = True,
    single_gpu_batch_size=None,
    prefetch_factor=2,
    persistent_workers=False,
    pin_memory=False,
    seed=None,
    **kwargs,
):
    """
    Build a batched dataloader. The main differences from `torch.utils.data.DataLoader` are:
    1. support aspect ratio grouping options
    2. use no "batch collation", because this is common for detection training

    Args:
        dataset (torch.utils.data.Dataset): a pytorch map-style or iterable dataset.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces indices.
            Must be provided iff. ``dataset`` is a map-style dataset.
        total_batch_size, aspect_ratio_grouping, num_workers, collate_fn: see
            :func:`build_detection_train_loader`.
        single_gpu_batch_size: You can specify either `single_gpu_batch_size` or `total_batch_size`.
            `single_gpu_batch_size` specifies the batch size that will be used for each gpu/process.
            `total_batch_size` allows you to specify the total aggregate batch size across gpus.
            It is an error to supply a value for both.
        drop_last (bool): if ``True``, the dataloader will drop incomplete batches.

    Returns:
        iterable[list]. Length of each list is the batch size of the current
        GPU. Each element in the list comes from the dataset.
    """
    if single_gpu_batch_size:
        if total_batch_size:
            raise ValueError(
                """total_batch_size and single_gpu_batch_size are mutually incompatible.
                Please specify only one. """
            )
        batch_size = single_gpu_batch_size
    else:
        world_size = get_world_size()
        assert (
            total_batch_size > 0 and total_batch_size % world_size == 0
        ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
            total_batch_size, world_size
        )
        batch_size = total_batch_size // world_size
    logger = logging.getLogger(__name__)
    logger.info("Making batched data loader with batch_size=%d", batch_size)

    if isinstance(dataset, torchdata.IterableDataset):
        assert sampler is None, "sampler must be None if dataset is IterableDataset"
    else:
        dataset = ToIterableDataset(dataset, sampler, shard_chunk_size=batch_size)

    generator = None
    if seed is not None:
        generator = torch.Generator()
        generator.manual_seed(seed)

    if aspect_ratio_grouping:
        assert drop_last, "Aspect ratio grouping will drop incomplete batches."
        data_loader = torchdata.DataLoader(
            dataset,
            num_workers=num_workers,
            collate_fn=operator.itemgetter(0),  # don't batch, but yield individual elements
            worker_init_fn=worker_init_reset_seed,
            prefetch_factor=prefetch_factor if num_workers > 0 else None,
            persistent_workers=persistent_workers,
            pin_memory=pin_memory,
            generator=generator,
            **kwargs,
        )  # yield individual mapped dict
        data_loader = AspectRatioGroupedDataset(data_loader, batch_size)
        if collate_fn is None:
            return data_loader
        return MapDataset(data_loader, collate_fn)
    else:
        return torchdata.DataLoader(
            dataset,
            batch_size=batch_size,
            drop_last=drop_last,
            num_workers=num_workers,
            collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
            worker_init_fn=worker_init_reset_seed,
            prefetch_factor=prefetch_factor if num_workers > 0 else None,
            persistent_workers=persistent_workers,
            pin_memory=pin_memory,
            generator=generator,
            **kwargs,
        )


def _get_train_datasets_repeat_factors(cfg) -> Dict[str, float]:
    repeat_factors = cfg.DATASETS.TRAIN_REPEAT_FACTOR
    assert all(len(tup) == 2 for tup in repeat_factors)
    name_to_weight = defaultdict(lambda: 1, dict(repeat_factors))
    # The sampling weights map should only contain datasets in train config
    unrecognized = set(name_to_weight.keys()) - set(cfg.DATASETS.TRAIN)
    assert not unrecognized, f"unrecognized datasets: {unrecognized}"
    logger = logging.getLogger(__name__)
    logger.info(f"Found repeat factors: {list(name_to_weight.items())}")

    # pyre-fixme[7]: Expected `Dict[str, float]` but got `DefaultDict[typing.Any, int]`.
    return name_to_weight


def _build_weighted_sampler(cfg, enable_category_balance=False):
    dataset_repeat_factors = _get_train_datasets_repeat_factors(cfg)
    # OrderedDict to guarantee order of values() consistent with repeat factors
    dataset_name_to_dicts = OrderedDict(
        {
            name: get_detection_dataset_dicts(
                [name],
                filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
                min_keypoints=(
                    cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
                    if cfg.MODEL.KEYPOINT_ON
                    else 0
                ),
                proposal_files=(
                    cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None
                ),
            )
            for name in cfg.DATASETS.TRAIN
        }
    )
    # Repeat factor for every sample in the dataset
    repeat_factors = [
        [dataset_repeat_factors[dsname]] * len(dataset_name_to_dicts[dsname])
        for dsname in cfg.DATASETS.TRAIN
    ]

    repeat_factors = list(itertools.chain.from_iterable(repeat_factors))

    repeat_factors = torch.tensor(repeat_factors)
    logger = logging.getLogger(__name__)
    if enable_category_balance:
        """
        1. Calculate repeat factors using category frequency for each dataset and then merge them.
        2. Element wise dot producting the dataset frequency repeat factors with
            the category frequency repeat factors gives the final repeat factors.
        """
        category_repeat_factors = [
            RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
                dataset_dict, cfg.DATALOADER.REPEAT_THRESHOLD, sqrt=cfg.DATALOADER.REPEAT_SQRT
            )
            for dataset_dict in dataset_name_to_dicts.values()
        ]
        # flatten the category repeat factors from all datasets
        category_repeat_factors = list(itertools.chain.from_iterable(category_repeat_factors))
        category_repeat_factors = torch.tensor(category_repeat_factors)
        repeat_factors = torch.mul(category_repeat_factors, repeat_factors)
        repeat_factors = repeat_factors / torch.min(repeat_factors)
        logger.info(
            "Using WeightedCategoryTrainingSampler with repeat_factors={}".format(
                cfg.DATASETS.TRAIN_REPEAT_FACTOR
            )
        )
    else:
        logger.info(
            "Using WeightedTrainingSampler with repeat_factors={}".format(
                cfg.DATASETS.TRAIN_REPEAT_FACTOR
            )
        )

    sampler = RepeatFactorTrainingSampler(repeat_factors)
    return sampler


def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
    if dataset is None:
        dataset = get_detection_dataset_dicts(
            cfg.DATASETS.TRAIN,
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            min_keypoints=(
                cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE if cfg.MODEL.KEYPOINT_ON else 0
            ),
            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
        )
        _log_api_usage("dataset." + cfg.DATASETS.TRAIN[0])

    if mapper is None:
        mapper = DatasetMapper(cfg, True)

    if sampler is None:
        sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
        logger = logging.getLogger(__name__)
        if isinstance(dataset, torchdata.IterableDataset):
            logger.info("Not using any sampler since the dataset is IterableDataset.")
            sampler = None
        else:
            logger.info("Using training sampler {}".format(sampler_name))
            if sampler_name == "TrainingSampler":
                sampler = TrainingSampler(len(dataset))
            elif sampler_name == "RepeatFactorTrainingSampler":
                repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
                    dataset, cfg.DATALOADER.REPEAT_THRESHOLD, sqrt=cfg.DATALOADER.REPEAT_SQRT
                )
                sampler = RepeatFactorTrainingSampler(repeat_factors, seed=cfg.SEED)
            elif sampler_name == "RandomSubsetTrainingSampler":
                sampler = RandomSubsetTrainingSampler(
                    len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO
                )
            elif sampler_name == "WeightedTrainingSampler":
                sampler = _build_weighted_sampler(cfg)
            elif sampler_name == "WeightedCategoryTrainingSampler":
                sampler = _build_weighted_sampler(cfg, enable_category_balance=True)
            else:
                raise ValueError("Unknown training sampler: {}".format(sampler_name))

    return {
        "dataset": dataset,
        "sampler": sampler,
        "mapper": mapper,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
    }


@configurable(from_config=_train_loader_from_config)
def build_detection_train_loader(
    dataset,
    *,
    mapper,
    sampler=None,
    total_batch_size,
    aspect_ratio_grouping=True,
    num_workers=0,
    collate_fn=None,
    **kwargs,
):
    """
    Build a dataloader for object detection with some default features.

    Args:
        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
            or a pytorch dataset (either map-style or iterable). It can be obtained
            by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper (callable): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model.
            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
            indices to be applied on ``dataset``.
            If ``dataset`` is map-style, the default sampler is a :class:`TrainingSampler`,
            which coordinates an infinite random shuffle sequence across all workers.
            Sampler must be None if ``dataset`` is iterable.
        total_batch_size (int): total batch size across all workers.
        aspect_ratio_grouping (bool): whether to group images with similar
            aspect ratio for efficiency. When enabled, it requires each
            element in dataset be a dict with keys "width" and "height".
        num_workers (int): number of parallel data loading workers
        collate_fn: a function that determines how to do batching, same as the argument of
            `torch.utils.data.DataLoader`. Defaults to do no collation and return a list of
            data. No collation is OK for small batch size and simple data structures.
            If your batch size is large and each sample contains too many small tensors,
            it's more efficient to collate them in data loader.

    Returns:
        torch.utils.data.DataLoader:
            a dataloader. Each output from it is a ``list[mapped_element]`` of length
            ``total_batch_size / num_workers``, where ``mapped_element`` is produced
            by the ``mapper``.
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)

    if isinstance(dataset, torchdata.IterableDataset):
        assert sampler is None, "sampler must be None if dataset is IterableDataset"
    else:
        if sampler is None:
            sampler = TrainingSampler(len(dataset))
        assert isinstance(sampler, torchdata.Sampler), f"Expect a Sampler but got {type(sampler)}"
    return build_batch_data_loader(
        dataset,
        sampler,
        total_batch_size,
        aspect_ratio_grouping=aspect_ratio_grouping,
        num_workers=num_workers,
        collate_fn=collate_fn,
        **kwargs,
    )


def _test_loader_from_config(cfg, dataset_name, mapper=None):
    """
    Uses the given `dataset_name` argument (instead of the names in cfg), because the
    standard practice is to evaluate each test set individually (not combining them).
    """
    if isinstance(dataset_name, str):
        dataset_name = [dataset_name]

    dataset = get_detection_dataset_dicts(
        dataset_name,
        filter_empty=False,
        proposal_files=(
            [
                cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)]
                for x in dataset_name
            ]
            if cfg.MODEL.LOAD_PROPOSALS
            else None
        ),
    )
    if mapper is None:
        mapper = DatasetMapper(cfg, False)
    return {
        "dataset": dataset,
        "mapper": mapper,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
        "sampler": (
            InferenceSampler(len(dataset))
            if not isinstance(dataset, torchdata.IterableDataset)
            else None
        ),
    }


@configurable(from_config=_test_loader_from_config)
def build_detection_test_loader(
    dataset: Union[List[Any], torchdata.Dataset],
    *,
    mapper: Callable[[Dict[str, Any]], Any],
    sampler: Optional[torchdata.Sampler] = None,
    batch_size: int = 1,
    num_workers: int = 0,
    collate_fn: Optional[Callable[[List[Any]], Any]] = None,
) -> torchdata.DataLoader:
    """
    Similar to `build_detection_train_loader`, with default batch size = 1,
    and sampler = :class:`InferenceSampler`. This sampler coordinates all workers
    to produce the exact set of all samples.

    Args:
        dataset: a list of dataset dicts,
            or a pytorch dataset (either map-style or iterable). They can be obtained
            by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper: a callable which takes a sample (dict) from dataset
            and returns the format to be consumed by the model.
            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
        sampler: a sampler that produces
            indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,
            which splits the dataset across all workers. Sampler must be None
            if `dataset` is iterable.
        batch_size: the batch size of the data loader to be created.
            Default to 1 image per worker since this is the standard when reporting
            inference time in papers.
        num_workers: number of parallel data loading workers
        collate_fn: same as the argument of `torch.utils.data.DataLoader`.
            Defaults to do no collation and return a list of data.

    Returns:
        DataLoader: a torch DataLoader, that loads the given detection
        dataset, with test-time transformation and batching.

    Examples:
    ::
        data_loader = build_detection_test_loader(
            DatasetRegistry.get("my_test"),
            mapper=DatasetMapper(...))

        # or, instantiate with a CfgNode:
        data_loader = build_detection_test_loader(cfg, "my_test")
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)
    if isinstance(dataset, torchdata.IterableDataset):
        assert sampler is None, "sampler must be None if dataset is IterableDataset"
    else:
        if sampler is None:
            sampler = InferenceSampler(len(dataset))
    return torchdata.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        drop_last=False,
        num_workers=num_workers,
        collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
    )


def trivial_batch_collator(batch):
    """
    A batch collator that does nothing.
    """
    return batch


def worker_init_reset_seed(worker_id):
    initial_seed = torch.initial_seed() % 2**31
    seed_all_rng(initial_seed + worker_id)
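A minimal sketch of driving the two builders above without going through the cfg path, assuming detectron2 is installed, "my_train" and "my_test" (hypothetical names) are registered datasets, and `cfg` is an existing detectron2 CfgNode used only to build the mappers:

from detectron2.data import (
    DatasetMapper,
    build_detection_train_loader,
    build_detection_test_loader,
    get_detection_dataset_dicts,
)

train_dicts = get_detection_dataset_dicts("my_train", filter_empty=True)
train_loader = build_detection_train_loader(
    train_dicts,
    mapper=DatasetMapper(cfg, True),
    total_batch_size=16,      # split evenly across GPUs/processes
    num_workers=4,
)

test_loader = build_detection_test_loader(
    get_detection_dataset_dicts("my_test", filter_empty=False),
    mapper=DatasetMapper(cfg, False),
)

# the train loader is infinite; each batch is a list[dict] of per-GPU batch size
for batch in train_loader:
    break

The train path wraps the list of dicts in DatasetFromList + MapDataset, picks a TrainingSampler by default, and hands everything to build_batch_data_loader; the test path uses InferenceSampler and batch_size=1 so every sample is seen exactly once.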
CatVTON/detectron2/data/catalog.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
import copy
|
| 3 |
+
import logging
|
| 4 |
+
import types
|
| 5 |
+
from collections import UserDict
|
| 6 |
+
from typing import List
|
| 7 |
+
|
| 8 |
+
from detectron2.utils.logger import log_first_n
|
| 9 |
+
|
| 10 |
+
__all__ = ["DatasetCatalog", "MetadataCatalog", "Metadata"]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class _DatasetCatalog(UserDict):
|
| 14 |
+
"""
|
| 15 |
+
A global dictionary that stores information about the datasets and how to obtain them.
|
| 16 |
+
|
| 17 |
+
It contains a mapping from strings
|
| 18 |
+
(which are names that identify a dataset, e.g. "coco_2014_train")
|
| 19 |
+
to a function which parses the dataset and returns the samples in the
|
| 20 |
+
format of `list[dict]`.
|
| 21 |
+
|
| 22 |
+
The returned dicts should be in Detectron2 Dataset format (See DATASETS.md for details)
|
| 23 |
+
if used with the data loader functionalities in `data/build.py,data/detection_transform.py`.
|
| 24 |
+
|
| 25 |
+
The purpose of having this catalog is to make it easy to choose
|
| 26 |
+
different datasets, by just using the strings in the config.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def register(self, name, func):
|
| 30 |
+
"""
|
| 31 |
+
Args:
|
| 32 |
+
name (str): the name that identifies a dataset, e.g. "coco_2014_train".
|
| 33 |
+
func (callable): a callable which takes no arguments and returns a list of dicts.
|
| 34 |
+
It must return the same results if called multiple times.
|
| 35 |
+
"""
|
| 36 |
+
assert callable(func), "You must register a function with `DatasetCatalog.register`!"
|
| 37 |
+
assert name not in self, "Dataset '{}' is already registered!".format(name)
|
| 38 |
+
self[name] = func
|
| 39 |
+
|
| 40 |
+
def get(self, name):
|
| 41 |
+
"""
|
| 42 |
+
Call the registered function and return its results.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
name (str): the name that identifies a dataset, e.g. "coco_2014_train".
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
list[dict]: dataset annotations.
|
| 49 |
+
"""
|
| 50 |
+
try:
|
| 51 |
+
f = self[name]
|
| 52 |
+
except KeyError as e:
|
| 53 |
+
raise KeyError(
|
| 54 |
+
"Dataset '{}' is not registered! Available datasets are: {}".format(
|
| 55 |
+
name, ", ".join(list(self.keys()))
|
| 56 |
+
)
|
| 57 |
+
) from e
|
| 58 |
+
return f()
|
| 59 |
+
|
| 60 |
+
def list(self) -> List[str]:
|
| 61 |
+
"""
|
| 62 |
+
List all registered datasets.
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
list[str]
|
| 66 |
+
"""
|
| 67 |
+
return list(self.keys())
|
| 68 |
+
|
| 69 |
+
def remove(self, name):
|
| 70 |
+
"""
|
| 71 |
+
Alias of ``pop``.
|
| 72 |
+
"""
|
| 73 |
+
self.pop(name)
|
| 74 |
+
|
| 75 |
+
def __str__(self):
|
| 76 |
+
return "DatasetCatalog(registered datasets: {})".format(", ".join(self.keys()))
|
| 77 |
+
|
| 78 |
+
__repr__ = __str__
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
DatasetCatalog = _DatasetCatalog()
|
| 82 |
+
DatasetCatalog.__doc__ = (
|
| 83 |
+
_DatasetCatalog.__doc__
|
| 84 |
+
+ """
|
| 85 |
+
.. automethod:: detectron2.data.catalog.DatasetCatalog.register
|
| 86 |
+
.. automethod:: detectron2.data.catalog.DatasetCatalog.get
|
| 87 |
+
"""
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class Metadata(types.SimpleNamespace):
|
| 92 |
+
"""
|
| 93 |
+
A class that supports simple attribute setter/getter.
|
| 94 |
+
It is intended for storing metadata of a dataset and make it accessible globally.
|
| 95 |
+
|
| 96 |
+
Examples:
|
| 97 |
+
::
|
| 98 |
+
# somewhere when you load the data:
|
| 99 |
+
MetadataCatalog.get("mydataset").thing_classes = ["person", "dog"]
|
| 100 |
+
|
| 101 |
+
        # somewhere when you print statistics or visualize:
        classes = MetadataCatalog.get("mydataset").thing_classes
    """

    # the name of the dataset
    # set default to N/A so that `self.name` in the errors will not trigger getattr again
    name: str = "N/A"

    _RENAMED = {
        "class_names": "thing_classes",
        "dataset_id_to_contiguous_id": "thing_dataset_id_to_contiguous_id",
        "stuff_class_names": "stuff_classes",
    }

    def __getattr__(self, key):
        if key in self._RENAMED:
            log_first_n(
                logging.WARNING,
                "Metadata '{}' was renamed to '{}'!".format(key, self._RENAMED[key]),
                n=10,
            )
            return getattr(self, self._RENAMED[key])

        # "name" exists in every metadata
        if len(self.__dict__) > 1:
            raise AttributeError(
                "Attribute '{}' does not exist in the metadata of dataset '{}'. Available "
                "keys are {}.".format(key, self.name, str(self.__dict__.keys()))
            )
        else:
            raise AttributeError(
                f"Attribute '{key}' does not exist in the metadata of dataset '{self.name}': "
                "metadata is empty."
            )

    def __setattr__(self, key, val):
        if key in self._RENAMED:
            log_first_n(
                logging.WARNING,
                "Metadata '{}' was renamed to '{}'!".format(key, self._RENAMED[key]),
                n=10,
            )
            setattr(self, self._RENAMED[key], val)

        # Ensure that metadata of the same name stays consistent
        try:
            oldval = getattr(self, key)
            assert oldval == val, (
                "Attribute '{}' in the metadata of '{}' cannot be set "
                "to a different value!\n{} != {}".format(key, self.name, oldval, val)
            )
        except AttributeError:
            super().__setattr__(key, val)

    def as_dict(self):
        """
        Returns all the metadata as a dict.
        Note that modifications to the returned dict will not reflect on the Metadata object.
        """
        return copy.copy(self.__dict__)

    def set(self, **kwargs):
        """
        Set multiple metadata with kwargs.
        """
        for k, v in kwargs.items():
            setattr(self, k, v)
        return self

    def get(self, key, default=None):
        """
        Access an attribute and return its value if exists.
        Otherwise return default.
        """
        try:
            return getattr(self, key)
        except AttributeError:
            return default


class _MetadataCatalog(UserDict):
    """
    MetadataCatalog is a global dictionary that provides access to
    :class:`Metadata` of a given dataset.

    The metadata associated with a certain name is a singleton: once created, the
    metadata will stay alive and will be returned by future calls to ``get(name)``.

    It's like global variables, so don't abuse it.
    It's meant for storing knowledge that's constant and shared across the execution
    of the program, e.g.: the class names in COCO.
    """

    def get(self, name):
        """
        Args:
            name (str): name of a dataset (e.g. coco_2014_train).

        Returns:
            Metadata: The :class:`Metadata` instance associated with this name,
            or create an empty one if none is available.
        """
        assert len(name)
        r = super().get(name, None)
        if r is None:
            r = self[name] = Metadata(name=name)
        return r

    def list(self):
        """
        List all registered metadata.

        Returns:
            list[str]: keys (names of datasets) of all registered metadata
        """
        return list(self.keys())

    def remove(self, name):
        """
        Alias of ``pop``.
        """
        self.pop(name)

    def __str__(self):
        return "MetadataCatalog(registered metadata: {})".format(", ".join(self.keys()))

    __repr__ = __str__


MetadataCatalog = _MetadataCatalog()
MetadataCatalog.__doc__ = (
    _MetadataCatalog.__doc__
    + """
    .. automethod:: detectron2.data.catalog.MetadataCatalog.get
"""
)
CatVTON/detectron2/data/common.py
ADDED
@@ -0,0 +1,339 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import contextlib
import copy
import itertools
import logging
import numpy as np
import pickle
import random
from typing import Callable, Union
import torch
import torch.utils.data as data
from torch.utils.data.sampler import Sampler

from detectron2.utils.serialize import PicklableWrapper

__all__ = ["MapDataset", "DatasetFromList", "AspectRatioGroupedDataset", "ToIterableDataset"]

logger = logging.getLogger(__name__)


# copied from: https://docs.python.org/3/library/itertools.html#recipes
def _roundrobin(*iterables):
    "roundrobin('ABC', 'D', 'EF') --> A D E B F C"
    # Recipe credited to George Sakkis
    num_active = len(iterables)
    nexts = itertools.cycle(iter(it).__next__ for it in iterables)
    while num_active:
        try:
            for next in nexts:
                yield next()
        except StopIteration:
            # Remove the iterator we just exhausted from the cycle.
            num_active -= 1
            nexts = itertools.cycle(itertools.islice(nexts, num_active))


def _shard_iterator_dataloader_worker(iterable, chunk_size=1):
    # Shard the iterable if we're currently inside pytorch dataloader worker.
    worker_info = data.get_worker_info()
    if worker_info is None or worker_info.num_workers == 1:
        # do nothing
        yield from iterable
    else:
        # worker0: 0, 1, ..., chunk_size-1, num_workers*chunk_size, num_workers*chunk_size+1, ...
        # worker1: chunk_size, chunk_size+1, ...
        # worker2: 2*chunk_size, 2*chunk_size+1, ...
        # ...
        yield from _roundrobin(
            *[
                itertools.islice(
                    iterable,
                    worker_info.id * chunk_size + chunk_i,
                    None,
                    worker_info.num_workers * chunk_size,
                )
                for chunk_i in range(chunk_size)
            ]
        )


class _MapIterableDataset(data.IterableDataset):
    """
    Map a function over elements in an IterableDataset.

    Similar to pytorch's MapIterDataPipe, but supports filtering when map_func
    returns None.

    This class is not public-facing. Will be called by `MapDataset`.
    """

    def __init__(self, dataset, map_func):
        self._dataset = dataset
        self._map_func = PicklableWrapper(map_func)  # wrap so that a lambda will work

    def __len__(self):
        return len(self._dataset)

    def __iter__(self):
        for x in map(self._map_func, self._dataset):
            if x is not None:
                yield x


class MapDataset(data.Dataset):
    """
    Map a function over the elements in a dataset.
    """

    def __init__(self, dataset, map_func):
        """
        Args:
            dataset: a dataset where map function is applied. Can be either
                map-style or iterable dataset. When given an iterable dataset,
                the returned object will also be an iterable dataset.
            map_func: a callable which maps the element in dataset. map_func can
                return None to skip the data (e.g. in case of errors).
                How None is handled depends on the style of `dataset`.
                If `dataset` is map-style, it randomly tries other elements.
                If `dataset` is iterable, it skips the data and tries the next.
        """
        self._dataset = dataset
        self._map_func = PicklableWrapper(map_func)  # wrap so that a lambda will work

        self._rng = random.Random(42)
        self._fallback_candidates = set(range(len(dataset)))

    def __new__(cls, dataset, map_func):
        is_iterable = isinstance(dataset, data.IterableDataset)
        if is_iterable:
            return _MapIterableDataset(dataset, map_func)
        else:
            return super().__new__(cls)

    def __getnewargs__(self):
        return self._dataset, self._map_func

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        retry_count = 0
        cur_idx = int(idx)

        while True:
            data = self._map_func(self._dataset[cur_idx])
            if data is not None:
                self._fallback_candidates.add(cur_idx)
                return data

            # _map_func fails for this idx, use a random new index from the pool
            retry_count += 1
            self._fallback_candidates.discard(cur_idx)
            cur_idx = self._rng.sample(self._fallback_candidates, k=1)[0]

            if retry_count >= 3:
                logger = logging.getLogger(__name__)
                logger.warning(
                    "Failed to apply `_map_func` for idx: {}, retry count: {}".format(
                        idx, retry_count
                    )
                )


class _TorchSerializedList:
    """
    A list-like object whose items are serialized and stored in a torch tensor. When
    launching a process that uses TorchSerializedList with "fork" start method,
    the subprocess can read the same buffer without triggering copy-on-access. When
    launching a process that uses TorchSerializedList with "spawn/forkserver" start
    method, the list will be pickled by a special ForkingPickler registered by PyTorch
    that moves data to shared memory. In both cases, this allows parent and child
    processes to share RAM for the list data, hence avoids the issue in
    https://github.com/pytorch/pytorch/issues/13246.

    See also https://ppwwyyxx.com/blog/2022/Demystify-RAM-Usage-in-Multiprocess-DataLoader/
    on how it works.
    """

    def __init__(self, lst: list):
        self._lst = lst

        def _serialize(data):
            buffer = pickle.dumps(data, protocol=-1)
            return np.frombuffer(buffer, dtype=np.uint8)

        logger.info(
            "Serializing {} elements to byte tensors and concatenating them all ...".format(
                len(self._lst)
            )
        )
        self._lst = [_serialize(x) for x in self._lst]
        self._addr = np.asarray([len(x) for x in self._lst], dtype=np.int64)
        self._addr = torch.from_numpy(np.cumsum(self._addr))
        self._lst = torch.from_numpy(np.concatenate(self._lst))
        logger.info("Serialized dataset takes {:.2f} MiB".format(len(self._lst) / 1024**2))

    def __len__(self):
        return len(self._addr)

    def __getitem__(self, idx):
        start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
        end_addr = self._addr[idx].item()
        bytes = memoryview(self._lst[start_addr:end_addr].numpy())

        # @lint-ignore PYTHONPICKLEISBAD
        return pickle.loads(bytes)


_DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = _TorchSerializedList


@contextlib.contextmanager
def set_default_dataset_from_list_serialize_method(new):
    """
    Context manager for using custom serialize function when creating DatasetFromList
    """

    global _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD
    orig = _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD
    _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = new
    yield
    _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = orig


class DatasetFromList(data.Dataset):
    """
    Wrap a list to a torch Dataset. It produces elements of the list as data.
    """

    def __init__(
        self,
        lst: list,
        copy: bool = True,
        serialize: Union[bool, Callable] = True,
    ):
        """
        Args:
            lst (list): a list which contains elements to produce.
            copy (bool): whether to deepcopy the element when producing it,
                so that the result can be modified in place without affecting the
                source in the list.
            serialize (bool or callable): whether to serialize the storage to another
                backend. If `True`, the default serialize method will be used; if given
                a callable, the callable will be used as the serialize method.
        """
        self._lst = lst
        self._copy = copy
        if not isinstance(serialize, (bool, Callable)):
            raise TypeError(f"Unsupported type for argument `serialize`: {serialize}")
        self._serialize = serialize is not False

        if self._serialize:
            serialize_method = (
                serialize
                if isinstance(serialize, Callable)
                else _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD
            )
            logger.info(f"Serializing the dataset using: {serialize_method}")
            self._lst = serialize_method(self._lst)

    def __len__(self):
        return len(self._lst)

    def __getitem__(self, idx):
        if self._copy and not self._serialize:
            return copy.deepcopy(self._lst[idx])
        else:
            return self._lst[idx]


class ToIterableDataset(data.IterableDataset):
    """
    Convert an old indices-based (also called map-style) dataset
    to an iterable-style dataset.
    """

    def __init__(
        self,
        dataset: data.Dataset,
        sampler: Sampler,
        shard_sampler: bool = True,
        shard_chunk_size: int = 1,
    ):
        """
        Args:
            dataset: an old-style dataset with ``__getitem__``
            sampler: a cheap iterable that produces indices to be applied on ``dataset``.
            shard_sampler: whether to shard the sampler based on the current pytorch data loader
                worker id. When an IterableDataset is forked by pytorch's DataLoader into multiple
                workers, it is responsible for sharding its data based on worker id so that workers
                don't produce identical data.

                Most samplers (like our TrainingSampler) do not shard based on dataloader worker id
                and this argument should be set to True. But certain samplers may be already
                sharded, in that case this argument should be set to False.
            shard_chunk_size: when sharding the sampler, each worker will take this many
                consecutive indices at a time before skipping ahead to its next chunk.
        """
        assert not isinstance(dataset, data.IterableDataset), dataset
        assert isinstance(sampler, Sampler), sampler
        self.dataset = dataset
        self.sampler = sampler
        self.shard_sampler = shard_sampler
        self.shard_chunk_size = shard_chunk_size

    def __iter__(self):
        if not self.shard_sampler:
            sampler = self.sampler
        else:
            # With map-style dataset, `DataLoader(dataset, sampler)` runs the
            # sampler in main process only. But `DataLoader(ToIterableDataset(dataset, sampler))`
            # will run sampler in every of the N worker. So we should only keep 1/N of the ids on
            # each worker. The assumption is that sampler is cheap to iterate so it's fine to
            # discard ids in workers.
            sampler = _shard_iterator_dataloader_worker(self.sampler, self.shard_chunk_size)
        for idx in sampler:
            yield self.dataset[idx]

    def __len__(self):
        return len(self.sampler)


class AspectRatioGroupedDataset(data.IterableDataset):
    """
    Batch data that have similar aspect ratio together.
    In this implementation, images whose aspect ratio < (or >) 1 will
    be batched together.
    This improves training speed because the images then need less padding
    to form a batch.

    It assumes the underlying dataset produces dicts with "width" and "height" keys.
    It will then produce a list of original dicts with length = batch_size,
    all with similar aspect ratios.
    """

    def __init__(self, dataset, batch_size):
        """
        Args:
            dataset: an iterable. Each element must be a dict with keys
                "width" and "height", which will be used to batch data.
            batch_size (int):
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self._buckets = [[] for _ in range(2)]
        # Hard-coded two aspect ratio groups: w > h and w < h.
        # Can add support for more aspect ratio groups, but doesn't seem useful

    def __iter__(self):
        for d in self.dataset:
            w, h = d["width"], d["height"]
            bucket_id = 0 if w > h else 1
            bucket = self._buckets[bucket_id]
            bucket.append(d)
            if len(bucket) == self.batch_size:
                data = bucket[:]
                # Clear bucket first, because code after yield is not
                # guaranteed to execute
                del bucket[:]
                yield data
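
For reference, a minimal sketch of how the classes above compose, mirroring the chain that detectron2's build.py sets up; the toy dicts, the identity mapper, and the batch size are illustrative only:

    from torch.utils.data import SequentialSampler

    from detectron2.data.common import (
        AspectRatioGroupedDataset,
        DatasetFromList,
        MapDataset,
        ToIterableDataset,
    )

    # Toy "dataset dicts"; real ones come from DatasetCatalog.get(...).
    dicts = [{"width": 640, "height": 480}, {"width": 480, "height": 640}] * 4

    dataset = DatasetFromList(dicts, copy=False)  # items get pickled into one flat tensor
    dataset = MapDataset(dataset, lambda d: d)    # identity mapper; returning None would skip
    iterable = ToIterableDataset(dataset, SequentialSampler(dataset))
    batched = AspectRatioGroupedDataset(iterable, batch_size=2)

    for batch in batched:
        # each batch holds dicts whose aspect ratios fall on the same side of 1
        assert len(batch) == 2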
CatVTON/detectron2/data/dataset_mapper.py
ADDED
@@ -0,0 +1,191 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import copy
import logging
import numpy as np
from typing import List, Optional, Union
import torch

from detectron2.config import configurable

from . import detection_utils as utils
from . import transforms as T

"""
This file contains the default mapping that's applied to "dataset dicts".
"""

__all__ = ["DatasetMapper"]


class DatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and maps it into a format used by the model.

    This is the default callable to be used to map your dataset dict into training data.
    You may need to follow it to implement your own one for customized logic,
    such as a different way to read or transform images.
    See :doc:`/tutorials/data_loading` for details.

    The callable currently does the following:

    1. Read the image from "file_name"
    2. Apply cropping/geometric transforms to the image and annotations
    3. Prepare data and annotations to Tensor and :class:`Instances`
    """

    @configurable
    def __init__(
        self,
        is_train: bool,
        *,
        augmentations: List[Union[T.Augmentation, T.Transform]],
        image_format: str,
        use_instance_mask: bool = False,
        use_keypoint: bool = False,
        instance_mask_format: str = "polygon",
        keypoint_hflip_indices: Optional[np.ndarray] = None,
        precomputed_proposal_topk: Optional[int] = None,
        recompute_boxes: bool = False,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            is_train: whether it's used in training or inference
            augmentations: a list of augmentations or deterministic transforms to apply
            image_format: an image format supported by :func:`detection_utils.read_image`.
            use_instance_mask: whether to process instance segmentation annotations, if available
            use_keypoint: whether to process keypoint annotations if available
            instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
                masks into this format.
            keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
            precomputed_proposal_topk: if given, will load pre-computed
                proposals from dataset_dict and keep the top k proposals for each image.
            recompute_boxes: whether to overwrite bounding box annotations
                by computing tight bounding boxes from instance mask annotations.
        """
        if recompute_boxes:
            assert use_instance_mask, "recompute_boxes requires instance masks"
        # fmt: off
        self.is_train               = is_train
        self.augmentations          = T.AugmentationList(augmentations)
        self.image_format           = image_format
        self.use_instance_mask      = use_instance_mask
        self.instance_mask_format   = instance_mask_format
        self.use_keypoint           = use_keypoint
        self.keypoint_hflip_indices = keypoint_hflip_indices
        self.proposal_topk          = precomputed_proposal_topk
        self.recompute_boxes        = recompute_boxes
        # fmt: on
        logger = logging.getLogger(__name__)
        mode = "training" if is_train else "inference"
        logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")

    @classmethod
    def from_config(cls, cfg, is_train: bool = True):
        augs = utils.build_augmentation(cfg, is_train)
        if cfg.INPUT.CROP.ENABLED and is_train:
            augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
            recompute_boxes = cfg.MODEL.MASK_ON
        else:
            recompute_boxes = False

        ret = {
            "is_train": is_train,
            "augmentations": augs,
            "image_format": cfg.INPUT.FORMAT,
            "use_instance_mask": cfg.MODEL.MASK_ON,
            "instance_mask_format": cfg.INPUT.MASK_FORMAT,
            "use_keypoint": cfg.MODEL.KEYPOINT_ON,
            "recompute_boxes": recompute_boxes,
        }

        if cfg.MODEL.KEYPOINT_ON:
            ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)

        if cfg.MODEL.LOAD_PROPOSALS:
            ret["precomputed_proposal_topk"] = (
                cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
                if is_train
                else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
            )
        return ret

    def _transform_annotations(self, dataset_dict, transforms, image_shape):
        # USER: Modify this if you want to keep them for some reason.
        for anno in dataset_dict["annotations"]:
            if not self.use_instance_mask:
                anno.pop("segmentation", None)
            if not self.use_keypoint:
                anno.pop("keypoints", None)

        # USER: Implement additional transformations if you have other types of data
        annos = [
            utils.transform_instance_annotations(
                obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
            )
            for obj in dataset_dict.pop("annotations")
            if obj.get("iscrowd", 0) == 0
        ]
        instances = utils.annotations_to_instances(
            annos, image_shape, mask_format=self.instance_mask_format
        )

        # After transforms such as cropping are applied, the bounding box may no longer
        # tightly bound the object. As an example, imagine a triangle object
        # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
        # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
        # the intersection of original bounding box and the cropping box.
        if self.recompute_boxes:
            instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
        dataset_dict["instances"] = utils.filter_empty_instances(instances)

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        # USER: Write your own image loading if it's not from a file
        image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
        utils.check_image_size(dataset_dict, image)

        # USER: Remove if you don't do semantic/panoptic segmentation.
        if "sem_seg_file_name" in dataset_dict:
            sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
        else:
            sem_seg_gt = None

        aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
        transforms = self.augmentations(aug_input)
        image, sem_seg_gt = aug_input.image, aug_input.sem_seg

        image_shape = image.shape[:2]  # h, w
        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        if sem_seg_gt is not None:
            dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))

        # USER: Remove if you don't use pre-computed proposals.
        # Most users would not need this feature.
        if self.proposal_topk is not None:
            utils.transform_proposals(
                dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
            )

        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            dataset_dict.pop("sem_seg_file_name", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            self._transform_annotations(dataset_dict, transforms, image_shape)

        return dataset_dict
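
For reference, a minimal sketch of driving the mapper outside the config system; the image path is hypothetical (it must exist on disk and match the stated size), and ResizeShortestEdge stands in for whatever augmentations a real config would build:

    from detectron2.data import DatasetMapper
    from detectron2.data import transforms as T

    mapper = DatasetMapper(
        is_train=False,
        augmentations=[T.ResizeShortestEdge(short_edge_length=800, max_size=1333)],
        image_format="BGR",
    )

    dataset_dict = {"file_name": "images/0001.jpg", "image_id": 0, "height": 480, "width": 640}
    out = mapper(dataset_dict)
    print(out["image"].shape)  # CxHxW uint8 tensor after augmentation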
CatVTON/detectron2/data/datasets/README.md
ADDED
@@ -0,0 +1,9 @@
### Common Datasets

The datasets implemented here do not need to load the data into the final format.
They should provide the minimal data structure needed to use the dataset, so loading can be very efficient.

For example, for an image dataset, just provide the file names and labels, but don't read the images.
Let the downstream decide how to read.
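
A minimal sketch of this principle (the function name, paths, and labels are made up): the loader returns only lightweight records, and reading the images is left to the downstream consumer.

    import os

    def load_my_image_dataset(image_root, labels):
        # Return only file names and labels; no image bytes are read here.
        return [
            {"file_name": os.path.join(image_root, fname), "label": label}
            for fname, label in labels.items()
        ]

    # e.g. load_my_image_dataset("datasets/pets", {"0001.jpg": "cat", "0002.jpg": "dog"})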
CatVTON/detectron2/data/datasets/__init__.py
ADDED
@@ -0,0 +1,9 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from .coco import load_coco_json, load_sem_seg, register_coco_instances, convert_to_coco_json
from .coco_panoptic import register_coco_panoptic, register_coco_panoptic_separated
from .lvis import load_lvis_json, register_lvis_instances, get_lvis_instances_meta
from .pascal_voc import load_voc_instances, register_pascal_voc
from . import builtin as _builtin  # ensure the builtin datasets are registered


__all__ = [k for k in globals().keys() if not k.startswith("_")]
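
For a COCO-format dataset that is not among the builtins, the helper re-exported above is the usual entry point. A minimal sketch with made-up paths:

    from detectron2.data.datasets import register_coco_instances

    register_coco_instances(
        "my_coco_train",                      # name to register the dataset under
        {},                                   # extra metadata (may be empty)
        "datasets/my_coco/annotations.json",  # annotations in COCO instances format
        "datasets/my_coco/images",            # root directory of the images
    )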
CatVTON/detectron2/data/datasets/builtin.py
ADDED
@@ -0,0 +1,259 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.


"""
This file registers pre-defined datasets at hard-coded paths, and their metadata.

We hard-code metadata for common datasets. This will enable:
1. Consistency check when loading the datasets
2. Use models on these standard datasets directly and run demos,
   without having to download the dataset annotations

We hard-code some paths to the dataset that's assumed to
exist in "./datasets/".

Users SHOULD NOT use this file to create new datasets / metadata for new datasets.
To add a new dataset, refer to the tutorial "docs/DATASETS.md".
"""

import os

from detectron2.data import DatasetCatalog, MetadataCatalog

from .builtin_meta import ADE20K_SEM_SEG_CATEGORIES, _get_builtin_metadata
from .cityscapes import load_cityscapes_instances, load_cityscapes_semantic
from .cityscapes_panoptic import register_all_cityscapes_panoptic
from .coco import load_sem_seg, register_coco_instances
from .coco_panoptic import register_coco_panoptic, register_coco_panoptic_separated
from .lvis import get_lvis_instances_meta, register_lvis_instances
from .pascal_voc import register_pascal_voc

# ==== Predefined datasets and splits for COCO ==========

_PREDEFINED_SPLITS_COCO = {}
_PREDEFINED_SPLITS_COCO["coco"] = {
    "coco_2014_train": ("coco/train2014", "coco/annotations/instances_train2014.json"),
    "coco_2014_val": ("coco/val2014", "coco/annotations/instances_val2014.json"),
    "coco_2014_minival": ("coco/val2014", "coco/annotations/instances_minival2014.json"),
    "coco_2014_valminusminival": (
        "coco/val2014",
        "coco/annotations/instances_valminusminival2014.json",
    ),
    "coco_2017_train": ("coco/train2017", "coco/annotations/instances_train2017.json"),
    "coco_2017_val": ("coco/val2017", "coco/annotations/instances_val2017.json"),
    "coco_2017_test": ("coco/test2017", "coco/annotations/image_info_test2017.json"),
    "coco_2017_test-dev": ("coco/test2017", "coco/annotations/image_info_test-dev2017.json"),
    "coco_2017_val_100": ("coco/val2017", "coco/annotations/instances_val2017_100.json"),
}

_PREDEFINED_SPLITS_COCO["coco_person"] = {
    "keypoints_coco_2014_train": (
        "coco/train2014",
        "coco/annotations/person_keypoints_train2014.json",
    ),
    "keypoints_coco_2014_val": ("coco/val2014", "coco/annotations/person_keypoints_val2014.json"),
    "keypoints_coco_2014_minival": (
        "coco/val2014",
        "coco/annotations/person_keypoints_minival2014.json",
    ),
    "keypoints_coco_2014_valminusminival": (
        "coco/val2014",
        "coco/annotations/person_keypoints_valminusminival2014.json",
    ),
    "keypoints_coco_2017_train": (
        "coco/train2017",
        "coco/annotations/person_keypoints_train2017.json",
    ),
    "keypoints_coco_2017_val": ("coco/val2017", "coco/annotations/person_keypoints_val2017.json"),
    "keypoints_coco_2017_val_100": (
        "coco/val2017",
        "coco/annotations/person_keypoints_val2017_100.json",
    ),
}


_PREDEFINED_SPLITS_COCO_PANOPTIC = {
    "coco_2017_train_panoptic": (
        # This is the original panoptic annotation directory
        "coco/panoptic_train2017",
        "coco/annotations/panoptic_train2017.json",
        # This directory contains semantic annotations that are
        # converted from panoptic annotations.
        # It is used by PanopticFPN.
        # You can use the script at detectron2/datasets/prepare_panoptic_fpn.py
        # to create these directories.
        "coco/panoptic_stuff_train2017",
    ),
    "coco_2017_val_panoptic": (
        "coco/panoptic_val2017",
        "coco/annotations/panoptic_val2017.json",
        "coco/panoptic_stuff_val2017",
    ),
    "coco_2017_val_100_panoptic": (
        "coco/panoptic_val2017_100",
        "coco/annotations/panoptic_val2017_100.json",
        "coco/panoptic_stuff_val2017_100",
    ),
}


def register_all_coco(root):
    for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_COCO.items():
        for key, (image_root, json_file) in splits_per_dataset.items():
            # Assume pre-defined datasets live in `./datasets`.
            register_coco_instances(
                key,
                _get_builtin_metadata(dataset_name),
                os.path.join(root, json_file) if "://" not in json_file else json_file,
                os.path.join(root, image_root),
            )

    for (
        prefix,
        (panoptic_root, panoptic_json, semantic_root),
    ) in _PREDEFINED_SPLITS_COCO_PANOPTIC.items():
        prefix_instances = prefix[: -len("_panoptic")]
        instances_meta = MetadataCatalog.get(prefix_instances)
        image_root, instances_json = instances_meta.image_root, instances_meta.json_file
        # The "separated" version of COCO panoptic segmentation dataset,
        # e.g. used by Panoptic FPN
        register_coco_panoptic_separated(
            prefix,
            _get_builtin_metadata("coco_panoptic_separated"),
            image_root,
            os.path.join(root, panoptic_root),
            os.path.join(root, panoptic_json),
            os.path.join(root, semantic_root),
            instances_json,
        )
        # The "standard" version of COCO panoptic segmentation dataset,
        # e.g. used by Panoptic-DeepLab
        register_coco_panoptic(
            prefix,
            _get_builtin_metadata("coco_panoptic_standard"),
            image_root,
            os.path.join(root, panoptic_root),
            os.path.join(root, panoptic_json),
            instances_json,
        )


# ==== Predefined datasets and splits for LVIS ==========


_PREDEFINED_SPLITS_LVIS = {
    "lvis_v1": {
        "lvis_v1_train": ("coco/", "lvis/lvis_v1_train.json"),
        "lvis_v1_val": ("coco/", "lvis/lvis_v1_val.json"),
        "lvis_v1_test_dev": ("coco/", "lvis/lvis_v1_image_info_test_dev.json"),
        "lvis_v1_test_challenge": ("coco/", "lvis/lvis_v1_image_info_test_challenge.json"),
    },
    "lvis_v0.5": {
        "lvis_v0.5_train": ("coco/", "lvis/lvis_v0.5_train.json"),
        "lvis_v0.5_val": ("coco/", "lvis/lvis_v0.5_val.json"),
        "lvis_v0.5_val_rand_100": ("coco/", "lvis/lvis_v0.5_val_rand_100.json"),
        "lvis_v0.5_test": ("coco/", "lvis/lvis_v0.5_image_info_test.json"),
    },
    "lvis_v0.5_cocofied": {
        "lvis_v0.5_train_cocofied": ("coco/", "lvis/lvis_v0.5_train_cocofied.json"),
        "lvis_v0.5_val_cocofied": ("coco/", "lvis/lvis_v0.5_val_cocofied.json"),
    },
}


def register_all_lvis(root):
    for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_LVIS.items():
        for key, (image_root, json_file) in splits_per_dataset.items():
            register_lvis_instances(
                key,
                get_lvis_instances_meta(dataset_name),
                os.path.join(root, json_file) if "://" not in json_file else json_file,
                os.path.join(root, image_root),
            )


# ==== Predefined splits for raw cityscapes images ===========
_RAW_CITYSCAPES_SPLITS = {
    "cityscapes_fine_{task}_train": ("cityscapes/leftImg8bit/train/", "cityscapes/gtFine/train/"),
    "cityscapes_fine_{task}_val": ("cityscapes/leftImg8bit/val/", "cityscapes/gtFine/val/"),
    "cityscapes_fine_{task}_test": ("cityscapes/leftImg8bit/test/", "cityscapes/gtFine/test/"),
}


def register_all_cityscapes(root):
    for key, (image_dir, gt_dir) in _RAW_CITYSCAPES_SPLITS.items():
        meta = _get_builtin_metadata("cityscapes")
        image_dir = os.path.join(root, image_dir)
        gt_dir = os.path.join(root, gt_dir)

        inst_key = key.format(task="instance_seg")
        DatasetCatalog.register(
            inst_key,
            lambda x=image_dir, y=gt_dir: load_cityscapes_instances(
                x, y, from_json=True, to_polygons=True
            ),
        )
        MetadataCatalog.get(inst_key).set(
            image_dir=image_dir, gt_dir=gt_dir, evaluator_type="cityscapes_instance", **meta
        )

        sem_key = key.format(task="sem_seg")
        DatasetCatalog.register(
            sem_key, lambda x=image_dir, y=gt_dir: load_cityscapes_semantic(x, y)
        )
        MetadataCatalog.get(sem_key).set(
            image_dir=image_dir,
            gt_dir=gt_dir,
            evaluator_type="cityscapes_sem_seg",
            ignore_label=255,
            **meta,
        )


# ==== Predefined splits for PASCAL VOC ===========
def register_all_pascal_voc(root):
    SPLITS = [
        ("voc_2007_trainval", "VOC2007", "trainval"),
        ("voc_2007_train", "VOC2007", "train"),
        ("voc_2007_val", "VOC2007", "val"),
        ("voc_2007_test", "VOC2007", "test"),
        ("voc_2012_trainval", "VOC2012", "trainval"),
        ("voc_2012_train", "VOC2012", "train"),
        ("voc_2012_val", "VOC2012", "val"),
    ]
    for name, dirname, split in SPLITS:
        year = 2007 if "2007" in name else 2012
        register_pascal_voc(name, os.path.join(root, dirname), split, year)
        MetadataCatalog.get(name).evaluator_type = "pascal_voc"


def register_all_ade20k(root):
    root = os.path.join(root, "ADEChallengeData2016")
    for name, dirname in [("train", "training"), ("val", "validation")]:
        image_dir = os.path.join(root, "images", dirname)
        gt_dir = os.path.join(root, "annotations_detectron2", dirname)
        name = f"ade20k_sem_seg_{name}"
        DatasetCatalog.register(
            name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext="png", image_ext="jpg")
        )
        MetadataCatalog.get(name).set(
            stuff_classes=ADE20K_SEM_SEG_CATEGORIES[:],
            image_root=image_dir,
            sem_seg_root=gt_dir,
            evaluator_type="sem_seg",
            ignore_label=255,
        )


# True for open source;
# Internally at fb, we register them elsewhere
if __name__.endswith(".builtin"):
    # Assume pre-defined datasets live in `./datasets`.
    _root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets"))
    register_all_coco(_root)
    register_all_lvis(_root)
    register_all_cityscapes(_root)
    register_all_cityscapes_panoptic(_root)
    register_all_pascal_voc(_root)
    register_all_ade20k(_root)
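
Since every hard-coded path above is resolved against the DETECTRON2_DATASETS root, a minimal sketch of pointing the builtin registration at a custom location; the path is a placeholder, and the variable must be set before detectron2.data is first imported:

    import os

    os.environ["DETECTRON2_DATASETS"] = "/path/to/datasets"

    from detectron2.data import MetadataCatalog  # importing triggers the registration above

    print(MetadataCatalog.get("coco_2017_val").json_file)
    # -> /path/to/datasets/coco/annotations/instances_val2017.json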
CatVTON/detectron2/data/datasets/builtin_meta.py
ADDED
@@ -0,0 +1,350 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

"""
Note:
For your custom dataset, there is no need to hard-code metadata anywhere in the code.
For example, for COCO-format datasets, metadata will be obtained automatically
when calling `load_coco_json`. For other datasets, metadata may also be obtained in other ways
during loading.

However, we hard-code metadata for a few common datasets here.
The only goal is to allow users who don't have these datasets to use pre-trained models.
Users don't have to download a COCO json (which contains metadata), in order to visualize a
COCO model (with correct class names and colors).
"""


# All coco categories, together with their nice-looking visualization colors
# It's from https://github.com/cocodataset/panopticapi/blob/master/panoptic_coco_categories.json
COCO_CATEGORIES = [
    {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
    {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
    {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
    {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
    {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
    {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
    {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
    {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
    {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
    {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
    {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
    {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
    {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
    {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
    {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
    {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
    {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
    {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
    {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
    {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
    {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
    {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
    {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
    {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
    {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
    {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
    {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
    {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
    {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
    {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
    {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
    {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
    {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
    {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
    {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
    {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
    {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
    {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
    {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
    {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
    {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
    {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
    {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
    {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
    {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
    {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
    {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
    {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
    {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
    {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
    {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
    {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
    {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
    {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
    {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
    {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
    {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
    {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
    {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
    {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
    {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
    {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
    {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
    {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
    {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
    {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
    {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
    {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
    {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
    {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
    {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
    {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
    {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
    {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
    {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
    {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
    {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
    {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
    {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
    {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
    {"color": [255, 255, 128], "isthing": 0, "id": 92, "name": "banner"},
    {"color": [147, 211, 203], "isthing": 0, "id": 93, "name": "blanket"},
    {"color": [150, 100, 100], "isthing": 0, "id": 95, "name": "bridge"},
    {"color": [168, 171, 172], "isthing": 0, "id": 100, "name": "cardboard"},
    {"color": [146, 112, 198], "isthing": 0, "id": 107, "name": "counter"},
    {"color": [210, 170, 100], "isthing": 0, "id": 109, "name": "curtain"},
    {"color": [92, 136, 89], "isthing": 0, "id": 112, "name": "door-stuff"},
    {"color": [218, 88, 184], "isthing": 0, "id": 118, "name": "floor-wood"},
    {"color": [241, 129, 0], "isthing": 0, "id": 119, "name": "flower"},
    {"color": [217, 17, 255], "isthing": 0, "id": 122, "name": "fruit"},
    {"color": [124, 74, 181], "isthing": 0, "id": 125, "name": "gravel"},
    {"color": [70, 70, 70], "isthing": 0, "id": 128, "name": "house"},
    {"color": [255, 228, 255], "isthing": 0, "id": 130, "name": "light"},
    {"color": [154, 208, 0], "isthing": 0, "id": 133, "name": "mirror-stuff"},
    {"color": [193, 0, 92], "isthing": 0, "id": 138, "name": "net"},
    {"color": [76, 91, 113], "isthing": 0, "id": 141, "name": "pillow"},
    {"color": [255, 180, 195], "isthing": 0, "id": 144, "name": "platform"},
    {"color": [106, 154, 176], "isthing": 0, "id": 145, "name": "playingfield"},
    {"color": [230, 150, 140], "isthing": 0, "id": 147, "name": "railroad"},
    {"color": [60, 143, 255], "isthing": 0, "id": 148, "name": "river"},
    {"color": [128, 64, 128], "isthing": 0, "id": 149, "name": "road"},
    {"color": [92, 82, 55], "isthing": 0, "id": 151, "name": "roof"},
    {"color": [254, 212, 124], "isthing": 0, "id": 154, "name": "sand"},
    {"color": [73, 77, 174], "isthing": 0, "id": 155, "name": "sea"},
    {"color": [255, 160, 98], "isthing": 0, "id": 156, "name": "shelf"},
    {"color": [255, 255, 255], "isthing": 0, "id": 159, "name": "snow"},
    {"color": [104, 84, 109], "isthing": 0, "id": 161, "name": "stairs"},
    {"color": [169, 164, 131], "isthing": 0, "id": 166, "name": "tent"},
    {"color": [225, 199, 255], "isthing": 0, "id": 168, "name": "towel"},
    {"color": [137, 54, 74], "isthing": 0, "id": 171, "name": "wall-brick"},
    {"color": [135, 158, 223], "isthing": 0, "id": 175, "name": "wall-stone"},
    {"color": [7, 246, 231], "isthing": 0, "id": 176, "name": "wall-tile"},
    {"color": [107, 255, 200], "isthing": 0, "id": 177, "name": "wall-wood"},
    {"color": [58, 41, 149], "isthing": 0, "id": 178, "name": "water-other"},
    {"color": [183, 121, 142], "isthing": 0, "id": 180, "name": "window-blind"},
    {"color": [255, 73, 97], "isthing": 0, "id": 181, "name": "window-other"},
    {"color": [107, 142, 35], "isthing": 0, "id": 184, "name": "tree-merged"},
    {"color": [190, 153, 153], "isthing": 0, "id": 185, "name": "fence-merged"},
    {"color": [146, 139, 141], "isthing": 0, "id": 186, "name": "ceiling-merged"},
    {"color": [70, 130, 180], "isthing": 0, "id": 187, "name": "sky-other-merged"},
    {"color": [134, 199, 156], "isthing": 0, "id": 188, "name": "cabinet-merged"},
    {"color": [209, 226, 140], "isthing": 0, "id": 189, "name": "table-merged"},
    {"color": [96, 36, 108], "isthing": 0, "id": 190, "name": "floor-other-merged"},
    {"color": [96, 96, 96], "isthing": 0, "id": 191, "name": "pavement-merged"},
    {"color": [64, 170, 64], "isthing": 0, "id": 192, "name": "mountain-merged"},
    {"color": [152, 251, 152], "isthing": 0, "id": 193, "name": "grass-merged"},
    {"color": [208, 229, 228], "isthing": 0, "id": 194, "name": "dirt-merged"},
    {"color": [206, 186, 171], "isthing": 0, "id": 195, "name": "paper-merged"},
    {"color": [152, 161, 64], "isthing": 0, "id": 196, "name": "food-other-merged"},
    {"color": [116, 112, 0], "isthing": 0, "id": 197, "name": "building-other-merged"},
    {"color": [0, 114, 143], "isthing": 0, "id": 198, "name": "rock-merged"},
    {"color": [102, 102, 156], "isthing": 0, "id": 199, "name": "wall-other-merged"},
    {"color": [250, 141, 255], "isthing": 0, "id": 200, "name": "rug-merged"},
]

# fmt: off
COCO_PERSON_KEYPOINT_NAMES = (
    "nose",
    "left_eye", "right_eye",
    "left_ear", "right_ear",
    "left_shoulder", "right_shoulder",
    "left_elbow", "right_elbow",
    "left_wrist", "right_wrist",
    "left_hip", "right_hip",
    "left_knee", "right_knee",
    "left_ankle", "right_ankle",
)
# fmt: on

# Pairs of keypoints that should be exchanged under horizontal flipping
COCO_PERSON_KEYPOINT_FLIP_MAP = (
    ("left_eye", "right_eye"),
    ("left_ear", "right_ear"),
    ("left_shoulder", "right_shoulder"),
    ("left_elbow", "right_elbow"),
    ("left_wrist", "right_wrist"),
    ("left_hip", "right_hip"),
    ("left_knee", "right_knee"),
    ("left_ankle", "right_ankle"),
)

# rules for pairs of keypoints to draw a line between, and the line color to use.
KEYPOINT_CONNECTION_RULES = [
    # face
    ("left_ear", "left_eye", (102, 204, 255)),
    ("right_ear", "right_eye", (51, 153, 255)),
    ("left_eye", "nose", (102, 0, 204)),
    ("nose", "right_eye", (51, 102, 255)),
    # upper-body
    ("left_shoulder", "right_shoulder", (255, 128, 0)),
    ("left_shoulder", "left_elbow", (153, 255, 204)),
    ("right_shoulder", "right_elbow", (128, 229, 255)),
    ("left_elbow", "left_wrist", (153, 255, 153)),
    ("right_elbow", "right_wrist", (102, 255, 224)),
    # lower-body
    ("left_hip", "right_hip", (255, 102, 0)),
    ("left_hip", "left_knee", (255, 255, 77)),
    ("right_hip", "right_knee", (153, 255, 204)),
    ("left_knee", "left_ankle", (191, 255, 128)),
    ("right_knee", "right_ankle", (255, 195, 77)),
]

# All Cityscapes categories, together with their nice-looking visualization colors
# It's from https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py  # noqa
CITYSCAPES_CATEGORIES = [
    {"color": (128, 64, 128), "isthing": 0, "id": 7, "trainId": 0, "name": "road"},
    {"color": (244, 35, 232), "isthing": 0, "id": 8, "trainId": 1, "name": "sidewalk"},
    {"color": (70, 70, 70), "isthing": 0, "id": 11, "trainId": 2, "name": "building"},
    {"color": (102, 102, 156), "isthing": 0, "id": 12, "trainId": 3, "name": "wall"},
    {"color": (190, 153, 153), "isthing": 0, "id": 13, "trainId": 4, "name": "fence"},
    {"color": (153, 153, 153), "isthing": 0, "id": 17, "trainId": 5, "name": "pole"},
    {"color": (250, 170, 30), "isthing": 0, "id": 19, "trainId": 6, "name": "traffic light"},
    {"color": (220, 220, 0), "isthing": 0, "id": 20, "trainId": 7, "name": "traffic sign"},
    {"color": (107, 142, 35), "isthing": 0, "id": 21, "trainId": 8, "name": "vegetation"},
    {"color": (152, 251, 152), "isthing": 0, "id": 22, "trainId": 9, "name": "terrain"},
    {"color": (70, 130, 180), "isthing": 0, "id": 23, "trainId": 10, "name": "sky"},
    {"color": (220, 20, 60), "isthing": 1, "id": 24, "trainId": 11, "name": "person"},
    {"color": (255, 0, 0), "isthing": 1, "id": 25, "trainId": 12, "name": "rider"},
    {"color": (0, 0, 142), "isthing": 1, "id": 26, "trainId": 13, "name": "car"},
    {"color": (0, 0, 70), "isthing": 1, "id": 27, "trainId": 14, "name": "truck"},
    {"color": (0, 60, 100), "isthing": 1, "id": 28, "trainId": 15, "name": "bus"},
    {"color": (0, 80, 100), "isthing": 1, "id": 31, "trainId": 16, "name": "train"},
    {"color": (0, 0, 230), "isthing": 1, "id": 32, "trainId": 17, "name": "motorcycle"},
    {"color": (119, 11, 32), "isthing": 1, "id": 33, "trainId": 18, "name": "bicycle"},
]

# fmt: off
ADE20K_SEM_SEG_CATEGORIES = [
+
"wall", "building", "sky", "floor", "tree", "ceiling", "road, route", "bed", "window ", "grass", "cabinet", "sidewalk, pavement", "person", "earth, ground", "door", "table", "mountain, mount", "plant", "curtain", "chair", "car", "water", "painting, picture", "sofa", "shelf", "house", "sea", "mirror", "rug", "field", "armchair", "seat", "fence", "desk", "rock, stone", "wardrobe, closet, press", "lamp", "tub", "rail", "cushion", "base, pedestal, stand", "box", "column, pillar", "signboard, sign", "chest of drawers, chest, bureau, dresser", "counter", "sand", "sink", "skyscraper", "fireplace", "refrigerator, icebox", "grandstand, covered stand", "path", "stairs", "runway", "case, display case, showcase, vitrine", "pool table, billiard table, snooker table", "pillow", "screen door, screen", "stairway, staircase", "river", "bridge, span", "bookcase", "blind, screen", "coffee table", "toilet, can, commode, crapper, pot, potty, stool, throne", "flower", "book", "hill", "bench", "countertop", "stove", "palm, palm tree", "kitchen island", "computer", "swivel chair", "boat", "bar", "arcade machine", "hovel, hut, hutch, shack, shanty", "bus", "towel", "light", "truck", "tower", "chandelier", "awning, sunshade, sunblind", "street lamp", "booth", "tv", "plane", "dirt track", "clothes", "pole", "land, ground, soil", "bannister, banister, balustrade, balusters, handrail", "escalator, moving staircase, moving stairway", "ottoman, pouf, pouffe, puff, hassock", "bottle", "buffet, counter, sideboard", "poster, posting, placard, notice, bill, card", "stage", "van", "ship", "fountain", "conveyer belt, conveyor belt, conveyer, conveyor, transporter", "canopy", "washer, automatic washer, washing machine", "plaything, toy", "pool", "stool", "barrel, cask", "basket, handbasket", "falls", "tent", "bag", "minibike, motorbike", "cradle", "oven", "ball", "food, solid food", "step, stair", "tank, storage tank", "trade name", "microwave", "pot", "animal", "bicycle", "lake", "dishwasher", "screen", "blanket, cover", "sculpture", "hood, exhaust hood", "sconce", "vase", "traffic light", "tray", "trash can", "fan", "pier", "crt screen", "plate", "monitor", "bulletin board", "shower", "radiator", "glass, drinking glass", "clock", "flag", # noqa
|
| 230 |
+
]
|
| 231 |
+
# After processed by `prepare_ade20k_sem_seg.py`, id 255 means ignore
|
| 232 |
+
# fmt: on
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
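# --- Editor's sketch (not part of the original file): how the flip map above is
# --- typically consumed by a horizontal-flip augmentation. The helper
# --- `_flip_keypoint_names` is hypothetical and exists only to make the
# --- convention concrete.
def _flip_keypoint_names(names, flip_map=COCO_PERSON_KEYPOINT_FLIP_MAP):
    swap = dict(flip_map)
    swap.update({right: left for left, right in flip_map})  # make the mapping symmetric
    # unpaired keypoints (e.g. "nose") map to themselves
    return tuple(swap.get(n, n) for n in names)

# _flip_keypoint_names(COCO_PERSON_KEYPOINT_NAMES)[1:3] == ("right_eye", "left_eye")
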
def _get_coco_instances_meta():
    thing_ids = [k["id"] for k in COCO_CATEGORIES if k["isthing"] == 1]
    thing_colors = [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 1]
    assert len(thing_ids) == 80, len(thing_ids)
    # Mapping from the incontiguous COCO category id to an id in [0, 79]
    thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)}
    thing_classes = [k["name"] for k in COCO_CATEGORIES if k["isthing"] == 1]
    ret = {
        "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
        "thing_classes": thing_classes,
        "thing_colors": thing_colors,
    }
    return ret


def _get_coco_panoptic_separated_meta():
    """
    Returns metadata for the "separated" version of the panoptic segmentation dataset.
    """
    stuff_ids = [k["id"] for k in COCO_CATEGORIES if k["isthing"] == 0]
    assert len(stuff_ids) == 53, len(stuff_ids)

    # For semantic segmentation, this mapping maps from contiguous stuff id
    # (in [0, 53], used in models) to ids in the dataset (used for processing results).
    # The id 0 is mapped to an extra category "thing".
    stuff_dataset_id_to_contiguous_id = {k: i + 1 for i, k in enumerate(stuff_ids)}
    # When converting COCO panoptic annotations to semantic annotations,
    # we label the "thing" category as 0.
    stuff_dataset_id_to_contiguous_id[0] = 0

    # 54 names for COCO stuff categories (including "things")
    stuff_classes = ["things"] + [
        k["name"].replace("-other", "").replace("-merged", "")
        for k in COCO_CATEGORIES
        if k["isthing"] == 0
    ]

    # NOTE: I randomly picked a color for things
    stuff_colors = [[82, 18, 128]] + [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 0]
    ret = {
        "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
        "stuff_classes": stuff_classes,
        "stuff_colors": stuff_colors,
    }
    ret.update(_get_coco_instances_meta())
    return ret


def _get_builtin_metadata(dataset_name):
    if dataset_name == "coco":
        return _get_coco_instances_meta()
    elif dataset_name == "coco_panoptic_separated":
        return _get_coco_panoptic_separated_meta()
    elif dataset_name == "coco_panoptic_standard":
        meta = {}
        # The following metadata maps contiguous ids from [0, #thing categories +
        # #stuff categories) to their names and colors. We keep two copies of the
        # same name and color under "thing_*" and "stuff_*" because the current
        # visualization function in D2 handles thing and stuff classes differently
        # due to some heuristic used in Panoptic FPN. We keep the same naming to
        # enable reusing existing visualization functions.
        thing_classes = [k["name"] for k in COCO_CATEGORIES]
        thing_colors = [k["color"] for k in COCO_CATEGORIES]
        stuff_classes = [k["name"] for k in COCO_CATEGORIES]
        stuff_colors = [k["color"] for k in COCO_CATEGORIES]

        meta["thing_classes"] = thing_classes
        meta["thing_colors"] = thing_colors
        meta["stuff_classes"] = stuff_classes
        meta["stuff_colors"] = stuff_colors

        # Convert category ids for training:
        # category id: like semantic segmentation, it is the class id for each
        # pixel. Since some classes are not used in evaluation, the category
        # id is not always contiguous and thus we have two sets of category ids:
        #   - original category id: category id in the original dataset, mainly
        #     used for evaluation.
        #   - contiguous category id: [0, #classes), in order to train the linear
        #     softmax classifier.
        thing_dataset_id_to_contiguous_id = {}
        stuff_dataset_id_to_contiguous_id = {}

        for i, cat in enumerate(COCO_CATEGORIES):
            if cat["isthing"]:
                thing_dataset_id_to_contiguous_id[cat["id"]] = i
            else:
                stuff_dataset_id_to_contiguous_id[cat["id"]] = i

        meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id
        meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id

        return meta
    elif dataset_name == "coco_person":
        return {
            "thing_classes": ["person"],
            "keypoint_names": COCO_PERSON_KEYPOINT_NAMES,
            "keypoint_flip_map": COCO_PERSON_KEYPOINT_FLIP_MAP,
            "keypoint_connection_rules": KEYPOINT_CONNECTION_RULES,
        }
    elif dataset_name == "cityscapes":
        # fmt: off
        CITYSCAPES_THING_CLASSES = [
            "person", "rider", "car", "truck",
            "bus", "train", "motorcycle", "bicycle",
        ]
        CITYSCAPES_STUFF_CLASSES = [
            "road", "sidewalk", "building", "wall", "fence", "pole", "traffic light",
            "traffic sign", "vegetation", "terrain", "sky", "person", "rider", "car",
            "truck", "bus", "train", "motorcycle", "bicycle",
        ]
        # fmt: on
        return {
            "thing_classes": CITYSCAPES_THING_CLASSES,
            "stuff_classes": CITYSCAPES_STUFF_CLASSES,
        }
    raise KeyError("No built-in metadata for dataset {}".format(dataset_name))
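For orientation, a minimal usage sketch of `_get_builtin_metadata`; the dataset key "coco" and the metadata field names come from the code above, while the assertions are illustrative only:

meta = _get_builtin_metadata("coco")
assert len(meta["thing_classes"]) == 80
# COCO's incontiguous dataset ids (1..90 with gaps) are squeezed into [0, 80);
# e.g. dataset id 1 ("person"), the first thing category, maps to 0.
assert meta["thing_dataset_id_to_contiguous_id"][1] == 0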
CatVTON/detectron2/data/datasets/cityscapes.py
ADDED
@@ -0,0 +1,334 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import functools
import json
import logging
import multiprocessing as mp
import numpy as np
import os
from itertools import chain
import pycocotools.mask as mask_util
from PIL import Image

from detectron2.structures import BoxMode
from detectron2.utils.comm import get_world_size
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import setup_logger

try:
    import cv2  # noqa
except ImportError:
    # OpenCV is an optional dependency at the moment
    pass


logger = logging.getLogger(__name__)


def _get_cityscapes_files(image_dir, gt_dir):
    files = []
    # scan through the directory
    cities = PathManager.ls(image_dir)
    logger.info(f"{len(cities)} cities found in '{image_dir}'.")
    for city in cities:
        city_img_dir = os.path.join(image_dir, city)
        city_gt_dir = os.path.join(gt_dir, city)
        for basename in PathManager.ls(city_img_dir):
            image_file = os.path.join(city_img_dir, basename)

            suffix = "leftImg8bit.png"
            assert basename.endswith(suffix), basename
            basename = basename[: -len(suffix)]

            instance_file = os.path.join(city_gt_dir, basename + "gtFine_instanceIds.png")
            label_file = os.path.join(city_gt_dir, basename + "gtFine_labelIds.png")
            json_file = os.path.join(city_gt_dir, basename + "gtFine_polygons.json")

            files.append((image_file, instance_file, label_file, json_file))
    assert len(files), "No images found in {}".format(image_dir)
    for f in files[0]:
        assert PathManager.isfile(f), f
    return files


def load_cityscapes_instances(image_dir, gt_dir, from_json=True, to_polygons=True):
    """
    Args:
        image_dir (str): path to the raw dataset. e.g., "~/cityscapes/leftImg8bit/train".
        gt_dir (str): path to the raw annotations. e.g., "~/cityscapes/gtFine/train".
        from_json (bool): whether to read annotations from the raw json file or the png files.
        to_polygons (bool): whether to represent the segmentation as polygons
            (COCO's format) instead of masks (cityscapes's format).

    Returns:
        list[dict]: a list of dicts in Detectron2 standard format. (See
        `Using Custom Datasets </tutorials/datasets.html>`_ )
    """
    if from_json:
        assert to_polygons, (
            "Cityscapes's json annotations are in polygon format. "
            "Converting to mask format is not supported now."
        )
    files = _get_cityscapes_files(image_dir, gt_dir)

    logger.info("Preprocessing cityscapes annotations ...")
    # This is still not fast: all workers will execute duplicate work, and it can
    # take up to 10 minutes on an 8-GPU server.
    pool = mp.Pool(processes=max(mp.cpu_count() // get_world_size() // 2, 4))

    ret = pool.map(
        functools.partial(_cityscapes_files_to_dict, from_json=from_json, to_polygons=to_polygons),
        files,
    )
    logger.info("Loaded {} images from {}".format(len(ret), image_dir))

    # Map cityscapes ids to contiguous ids
    from cityscapesscripts.helpers.labels import labels

    labels = [l for l in labels if l.hasInstances and not l.ignoreInEval]
    dataset_id_to_contiguous_id = {l.id: idx for idx, l in enumerate(labels)}
    for dict_per_image in ret:
        for anno in dict_per_image["annotations"]:
            anno["category_id"] = dataset_id_to_contiguous_id[anno["category_id"]]
    return ret


def load_cityscapes_semantic(image_dir, gt_dir):
    """
    Args:
        image_dir (str): path to the raw dataset. e.g., "~/cityscapes/leftImg8bit/train".
        gt_dir (str): path to the raw annotations. e.g., "~/cityscapes/gtFine/train".

    Returns:
        list[dict]: a list of dicts, each with "file_name" and
        "sem_seg_file_name".
    """
    ret = []
    # gt_dir is small and contains many small files; it makes sense to fetch it to local first
    gt_dir = PathManager.get_local_path(gt_dir)
    for image_file, _, label_file, json_file in _get_cityscapes_files(image_dir, gt_dir):
        label_file = label_file.replace("labelIds", "labelTrainIds")

        with PathManager.open(json_file, "r") as f:
            jsonobj = json.load(f)
        ret.append(
            {
                "file_name": image_file,
                "sem_seg_file_name": label_file,
                "height": jsonobj["imgHeight"],
                "width": jsonobj["imgWidth"],
            }
        )
    assert len(ret), f"No images found in {image_dir}!"
    assert PathManager.isfile(
        ret[0]["sem_seg_file_name"]
    ), "Please generate labelTrainIds.png with cityscapesscripts/preparation/createTrainIdLabelImgs.py"  # noqa
    return ret


def _cityscapes_files_to_dict(files, from_json, to_polygons):
    """
    Parse cityscapes annotation files into an instance segmentation dataset dict.

    Args:
        files (tuple): consists of (image_file, instance_id_file, label_id_file, json_file)
        from_json (bool): whether to read annotations from the raw json file or the png files.
        to_polygons (bool): whether to represent the segmentation as polygons
            (COCO's format) instead of masks (cityscapes's format).

    Returns:
        A dict in Detectron2 Dataset format.
    """
    from cityscapesscripts.helpers.labels import id2label, name2label

    image_file, instance_id_file, _, json_file = files

    annos = []

    if from_json:
        from shapely.geometry import MultiPolygon, Polygon

        with PathManager.open(json_file, "r") as f:
            jsonobj = json.load(f)
        ret = {
            "file_name": image_file,
            "image_id": os.path.basename(image_file),
            "height": jsonobj["imgHeight"],
            "width": jsonobj["imgWidth"],
        }

        # `polygons_union` contains the union of all valid polygons.
        polygons_union = Polygon()

        # CityscapesScripts draws the polygons in sequential order
        # and each polygon *overwrites* existing ones. See
        # (https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/preparation/json2instanceImg.py)  # noqa
        # We use reverse order, and each polygon *avoids* earlier ones.
        # This resolves the polygon overlaps in the same way as CityscapesScripts.
        for obj in jsonobj["objects"][::-1]:
            if "deleted" in obj:  # cityscapes data format specific
                continue
            label_name = obj["label"]

            try:
                label = name2label[label_name]
            except KeyError:
                if label_name.endswith("group"):  # crowd area
                    label = name2label[label_name[: -len("group")]]
                else:
                    raise
            if label.id < 0:  # cityscapes data format
                continue

            # Cityscapes's raw annotations use integer coordinates;
            # therefore +0.5 here.
            poly_coord = np.asarray(obj["polygon"], dtype="f4") + 0.5
            # CityscapesScripts uses PIL.ImageDraw.polygon to rasterize
            # polygons for evaluation. This function operates in integer space
            # and draws each pixel whose center falls into the polygon.
            # Therefore it draws a polygon which is 0.5 "fatter" in expectation.
            # We therefore dilate the input polygon by 0.5 as our input.
            poly = Polygon(poly_coord).buffer(0.5, resolution=4)

            if not label.hasInstances or label.ignoreInEval:
                # even if we won't store the polygon, it still contributes to overlap resolution
                polygons_union = polygons_union.union(poly)
                continue

            # Take the non-overlapping part of the polygon
            poly_wo_overlaps = poly.difference(polygons_union)
            if poly_wo_overlaps.is_empty:
                continue
            polygons_union = polygons_union.union(poly)

            anno = {}
            anno["iscrowd"] = label_name.endswith("group")
            anno["category_id"] = label.id

            if isinstance(poly_wo_overlaps, Polygon):
                poly_list = [poly_wo_overlaps]
            elif isinstance(poly_wo_overlaps, MultiPolygon):
                poly_list = poly_wo_overlaps.geoms
            else:
                raise NotImplementedError("Unknown geometric structure {}".format(poly_wo_overlaps))

            poly_coord = []
            for poly_el in poly_list:
                # COCO API can work only with exterior boundaries now, hence we store only them.
                # TODO: store both exterior and interior boundaries once other parts of the
                # codebase support holes in polygons.
                poly_coord.append(list(chain(*poly_el.exterior.coords)))
            anno["segmentation"] = poly_coord
            (xmin, ymin, xmax, ymax) = poly_wo_overlaps.bounds

            anno["bbox"] = (xmin, ymin, xmax, ymax)
            anno["bbox_mode"] = BoxMode.XYXY_ABS

            annos.append(anno)
    else:
        # See also the official annotation parsing scripts at
        # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/instances2dict.py  # noqa
        with PathManager.open(instance_id_file, "rb") as f:
            inst_image = np.asarray(Image.open(f), order="F")
        # ids < 24 are stuff labels (filtering them first is about 5% faster)
        flattened_ids = np.unique(inst_image[inst_image >= 24])

        ret = {
            "file_name": image_file,
            "image_id": os.path.basename(image_file),
            "height": inst_image.shape[0],
            "width": inst_image.shape[1],
        }

        for instance_id in flattened_ids:
            # For non-crowd annotations, instance_id // 1000 is the label_id.
            # Crowd annotations have <1000 instance ids.
            label_id = instance_id // 1000 if instance_id >= 1000 else instance_id
            label = id2label[label_id]
            if not label.hasInstances or label.ignoreInEval:
                continue

            anno = {}
            anno["iscrowd"] = instance_id < 1000
            anno["category_id"] = label.id

            mask = np.asarray(inst_image == instance_id, dtype=np.uint8, order="F")

            inds = np.nonzero(mask)
            ymin, ymax = inds[0].min(), inds[0].max()
            xmin, xmax = inds[1].min(), inds[1].max()
            anno["bbox"] = (xmin, ymin, xmax, ymax)
            if xmax <= xmin or ymax <= ymin:
                continue
            anno["bbox_mode"] = BoxMode.XYXY_ABS
            if to_polygons:
                # This conversion comes from D4809743 and D5171122,
                # when Mask-RCNN was first developed.
                contours = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[-2]
                polygons = [c.reshape(-1).tolist() for c in contours if len(c) >= 3]
                # opencv can produce invalid polygons
                if len(polygons) == 0:
                    continue
                anno["segmentation"] = polygons
            else:
                anno["segmentation"] = mask_util.encode(mask[:, :, None])[0]
            annos.append(anno)
    ret["annotations"] = annos
    return ret


def main() -> None:
    """
    Test the cityscapes dataset loader.

    Usage:
        python -m detectron2.data.datasets.cityscapes \
            cityscapes/leftImg8bit/train cityscapes/gtFine/train
    """
    global logger, labels
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("image_dir")
    parser.add_argument("gt_dir")
    parser.add_argument("--type", choices=["instance", "semantic"], default="instance")
    args = parser.parse_args()
    from cityscapesscripts.helpers.labels import labels
    from detectron2.data.catalog import Metadata
    from detectron2.utils.visualizer import Visualizer

    logger = setup_logger(name=__name__)

    dirname = "cityscapes-data-vis"
    os.makedirs(dirname, exist_ok=True)

    if args.type == "instance":
        dicts = load_cityscapes_instances(
            args.image_dir, args.gt_dir, from_json=True, to_polygons=True
        )
        logger.info("Done loading {} samples.".format(len(dicts)))

        thing_classes = [k.name for k in labels if k.hasInstances and not k.ignoreInEval]
        meta = Metadata().set(thing_classes=thing_classes)

    else:
        dicts = load_cityscapes_semantic(args.image_dir, args.gt_dir)
        logger.info("Done loading {} samples.".format(len(dicts)))

        stuff_classes = [k.name for k in labels if k.trainId != 255]
        stuff_colors = [k.color for k in labels if k.trainId != 255]
        meta = Metadata().set(stuff_classes=stuff_classes, stuff_colors=stuff_colors)

    for d in dicts:
        img = np.array(Image.open(PathManager.open(d["file_name"], "rb")))
        visualizer = Visualizer(img, metadata=meta)
        vis = visualizer.draw_dataset_dict(d)
        # cv2.imshow("a", vis.get_image()[:, :, ::-1])
        # cv2.waitKey()
        fpath = os.path.join(dirname, os.path.basename(d["file_name"]))
        vis.save(fpath)


if __name__ == "__main__":
    main()  # pragma: no cover
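Before the next file, a small self-contained sketch of the overlap-resolution trick used in `_cityscapes_files_to_dict` above: polygons are visited in reverse drawing order, and each one subtracts the union of everything already kept, so later-drawn (on-top) shapes win, matching CityscapesScripts' rasterization. The coordinates below are toy values; only `shapely` is assumed.

# Toy demonstration of reverse-order overlap resolution; not part of the file above.
from shapely.geometry import Polygon

drawn_in_order = [
    Polygon([(0, 0), (4, 0), (4, 4), (0, 4)]),  # drawn first (bottom)
    Polygon([(2, 2), (6, 2), (6, 6), (2, 6)]),  # drawn last (top), overwrites the first
]

union_so_far = Polygon()
kept = []
for poly in drawn_in_order[::-1]:  # reverse order: the top-most shape is processed first
    visible = poly.difference(union_so_far)  # keep only the part no earlier shape claimed
    union_so_far = union_so_far.union(poly)
    if not visible.is_empty:
        kept.append(visible)

assert kept[0].area == 16.0         # the top square is fully visible
assert kept[1].area == 16.0 - 4.0   # the bottom square loses the 2x2 overlap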
CatVTON/detectron2/data/datasets/cityscapes_panoptic.py
ADDED
@@ -0,0 +1,187 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import json
import logging
import os

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets.builtin_meta import CITYSCAPES_CATEGORIES
from detectron2.utils.file_io import PathManager

"""
This file contains functions to register the Cityscapes panoptic dataset to the DatasetCatalog.
"""


logger = logging.getLogger(__name__)


def get_cityscapes_panoptic_files(image_dir, gt_dir, json_info):
    files = []
    # scan through the directory
    cities = PathManager.ls(image_dir)
    logger.info(f"{len(cities)} cities found in '{image_dir}'.")
    image_dict = {}
    for city in cities:
        city_img_dir = os.path.join(image_dir, city)
        for basename in PathManager.ls(city_img_dir):
            image_file = os.path.join(city_img_dir, basename)

            suffix = "_leftImg8bit.png"
            assert basename.endswith(suffix), basename
            basename = os.path.basename(basename)[: -len(suffix)]

            image_dict[basename] = image_file

    for ann in json_info["annotations"]:
        image_file = image_dict.get(ann["image_id"], None)
        assert image_file is not None, "No image {} found for annotation {}".format(
            ann["image_id"], ann["file_name"]
        )
        label_file = os.path.join(gt_dir, ann["file_name"])
        segments_info = ann["segments_info"]

        files.append((image_file, label_file, segments_info))

    assert len(files), "No images found in {}".format(image_dir)
    assert PathManager.isfile(files[0][0]), files[0][0]
    assert PathManager.isfile(files[0][1]), files[0][1]
    return files


def load_cityscapes_panoptic(image_dir, gt_dir, gt_json, meta):
    """
    Args:
        image_dir (str): path to the raw dataset. e.g., "~/cityscapes/leftImg8bit/train".
        gt_dir (str): path to the raw annotations. e.g.,
            "~/cityscapes/gtFine/cityscapes_panoptic_train".
        gt_json (str): path to the json file. e.g.,
            "~/cityscapes/gtFine/cityscapes_panoptic_train.json".
        meta (dict): dictionary containing "thing_dataset_id_to_contiguous_id"
            and "stuff_dataset_id_to_contiguous_id" to map category ids to
            contiguous ids for training.

    Returns:
        list[dict]: a list of dicts in Detectron2 standard format. (See
        `Using Custom Datasets </tutorials/datasets.html>`_ )
    """

    def _convert_category_id(segment_info, meta):
        if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:
            segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][
                segment_info["category_id"]
            ]
        else:
            segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][
                segment_info["category_id"]
            ]
        return segment_info

    assert os.path.exists(
        gt_json
    ), "Please run `python cityscapesscripts/preparation/createPanopticImgs.py` to generate label files."  # noqa
    with open(gt_json) as f:
        json_info = json.load(f)
    files = get_cityscapes_panoptic_files(image_dir, gt_dir, json_info)
    ret = []
    for image_file, label_file, segments_info in files:
        sem_label_file = (
            image_file.replace("leftImg8bit", "gtFine").split(".")[0] + "_labelTrainIds.png"
        )
        segments_info = [_convert_category_id(x, meta) for x in segments_info]
        ret.append(
            {
                "file_name": image_file,
                "image_id": "_".join(
                    os.path.splitext(os.path.basename(image_file))[0].split("_")[:3]
                ),
                "sem_seg_file_name": sem_label_file,
                "pan_seg_file_name": label_file,
                "segments_info": segments_info,
            }
        )
    assert len(ret), f"No images found in {image_dir}!"
    assert PathManager.isfile(
        ret[0]["sem_seg_file_name"]
    ), "Please generate labelTrainIds.png with cityscapesscripts/preparation/createTrainIdLabelImgs.py"  # noqa
    assert PathManager.isfile(
        ret[0]["pan_seg_file_name"]
    ), "Please generate panoptic annotation with python cityscapesscripts/preparation/createPanopticImgs.py"  # noqa
    return ret


_RAW_CITYSCAPES_PANOPTIC_SPLITS = {
    "cityscapes_fine_panoptic_train": (
        "cityscapes/leftImg8bit/train",
        "cityscapes/gtFine/cityscapes_panoptic_train",
        "cityscapes/gtFine/cityscapes_panoptic_train.json",
    ),
    "cityscapes_fine_panoptic_val": (
        "cityscapes/leftImg8bit/val",
        "cityscapes/gtFine/cityscapes_panoptic_val",
        "cityscapes/gtFine/cityscapes_panoptic_val.json",
    ),
    # "cityscapes_fine_panoptic_test": not supported yet
}


def register_all_cityscapes_panoptic(root):
    meta = {}
    # The following metadata maps contiguous ids from [0, #thing categories +
    # #stuff categories) to their names and colors. We keep two copies of the
    # same name and color under "thing_*" and "stuff_*" because the current
    # visualization function in D2 handles thing and stuff classes differently
    # due to some heuristic used in Panoptic FPN. We keep the same naming to
    # enable reusing existing visualization functions.
    thing_classes = [k["name"] for k in CITYSCAPES_CATEGORIES]
    thing_colors = [k["color"] for k in CITYSCAPES_CATEGORIES]
    stuff_classes = [k["name"] for k in CITYSCAPES_CATEGORIES]
    stuff_colors = [k["color"] for k in CITYSCAPES_CATEGORIES]

    meta["thing_classes"] = thing_classes
    meta["thing_colors"] = thing_colors
    meta["stuff_classes"] = stuff_classes
    meta["stuff_colors"] = stuff_colors

    # There are three types of ids in cityscapes panoptic segmentation:
    # (1) category id: like semantic segmentation, it is the class id for each
    # pixel. Since some classes are not used in evaluation, the category
    # id is not always contiguous and thus we have two sets of category ids:
    #   - original category id: category id in the original dataset, mainly
    #     used for evaluation.
    #   - contiguous category id: [0, #classes), in order to train the classifier
    # (2) instance id: this id is used to differentiate different instances from
    # the same category. For "stuff" classes, the instance id is always 0; for
    # "thing" classes, the instance id starts from 1 and 0 is reserved for
    # ignored instances (e.g. crowd annotation).
    # (3) panoptic id: this is the compact id that encodes both category and
    # instance id, by: category_id * 1000 + instance_id.
    thing_dataset_id_to_contiguous_id = {}
    stuff_dataset_id_to_contiguous_id = {}

    for k in CITYSCAPES_CATEGORIES:
        if k["isthing"] == 1:
            thing_dataset_id_to_contiguous_id[k["id"]] = k["trainId"]
        else:
            stuff_dataset_id_to_contiguous_id[k["id"]] = k["trainId"]

    meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id
    meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id

    for key, (image_dir, gt_dir, gt_json) in _RAW_CITYSCAPES_PANOPTIC_SPLITS.items():
        image_dir = os.path.join(root, image_dir)
        gt_dir = os.path.join(root, gt_dir)
        gt_json = os.path.join(root, gt_json)

        DatasetCatalog.register(
            key, lambda x=image_dir, y=gt_dir, z=gt_json: load_cityscapes_panoptic(x, y, z, meta)
        )
        MetadataCatalog.get(key).set(
            panoptic_root=gt_dir,
            image_root=image_dir,
            panoptic_json=gt_json,
            gt_dir=gt_dir.replace("cityscapes_panoptic_", ""),
            evaluator_type="cityscapes_panoptic_seg",
            ignore_label=255,
            label_divisor=1000,
            **meta,
        )
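As a quick illustration of the panoptic id scheme documented in `register_all_cityscapes_panoptic` above (label_divisor=1000, panoptic id = category_id * 1000 + instance_id), here is a minimal encode/decode sketch; the helper names are ours, not part of the file:

LABEL_DIVISOR = 1000  # matches the label_divisor registered above

def encode_panoptic(category_id, instance_id):
    # panoptic id = category_id * 1000 + instance_id, as documented above
    return category_id * LABEL_DIVISOR + instance_id

def decode_panoptic(panoptic_id):
    return panoptic_id // LABEL_DIVISOR, panoptic_id % LABEL_DIVISOR

# e.g. Cityscapes "car" has dataset id 26; its instance 3 encodes to 26003
assert encode_panoptic(26, 3) == 26003
assert decode_panoptic(26003) == (26, 3)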
CatVTON/detectron2/data/datasets/coco.py
ADDED
@@ -0,0 +1,555 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import contextlib
import datetime
import io
import json
import logging
import numpy as np
import os
import shutil
import pycocotools.mask as mask_util
from fvcore.common.timer import Timer
from iopath.common.file_io import file_lock
from PIL import Image

from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes
from detectron2.utils.file_io import PathManager

from .. import DatasetCatalog, MetadataCatalog

"""
This file contains functions to parse COCO-format annotations into dicts in "Detectron2 format".
"""


logger = logging.getLogger(__name__)

__all__ = [
    "load_coco_json",
    "load_sem_seg",
    "convert_to_coco_json",
    "register_coco_instances",
]


def load_coco_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None):
    """
    Load a json file with COCO's instances annotation format.
    Currently supports instance detection, instance segmentation,
    and person keypoints annotations.

    Args:
        json_file (str): full path to the json file in COCO instances annotation format.
        image_root (str or path-like): the directory where the images in this json file exist.
        dataset_name (str or None): the name of the dataset (e.g., coco_2017_train).
            When provided, this function will also do the following:

            * Put "thing_classes" into the metadata associated with this dataset.
            * Map the category ids into a contiguous range (needed by standard dataset format),
              and add "thing_dataset_id_to_contiguous_id" to the metadata associated
              with this dataset.

            This option should usually be provided, unless users need to load
            the original json content and apply more processing manually.
        extra_annotation_keys (list[str]): list of per-annotation keys that should also be
            loaded into the dataset dict (besides "iscrowd", "bbox", "keypoints",
            "category_id", "segmentation"). The values for these keys will be returned as-is.
            For example, the densepose annotations are loaded in this way.

    Returns:
        list[dict]: a list of dicts in Detectron2 standard dataset dicts format (See
        `Using Custom Datasets </tutorials/datasets.html>`_ ) when `dataset_name` is not None.
        If `dataset_name` is None, the returned `category_ids` may be
        incontiguous and may not conform to the Detectron2 standard format.

    Notes:
        1. This function does not read the image files.
           The results do not have the "image" field.
    """
    from pycocotools.coco import COCO

    timer = Timer()
    json_file = PathManager.get_local_path(json_file)
    with contextlib.redirect_stdout(io.StringIO()):
        coco_api = COCO(json_file)
    if timer.seconds() > 1:
        logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))

    id_map = None
    if dataset_name is not None:
        meta = MetadataCatalog.get(dataset_name)
        cat_ids = sorted(coco_api.getCatIds())
        cats = coco_api.loadCats(cat_ids)
        # The categories in a custom json file may not be sorted.
        thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])]
        meta.thing_classes = thing_classes

        # In COCO, certain category ids are artificially removed,
        # and by convention they are always ignored.
        # We deal with COCO's id issue and translate
        # the category ids to contiguous ids in [0, 80).

        # It works by looking at the "categories" field in the json, therefore
        # if users' own json files also have incontiguous ids, we'll
        # apply this mapping as well but print a warning.
        if not (min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)):
            if "coco" not in dataset_name:
                logger.warning(
                    """
Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.
"""
                )
        id_map = {v: i for i, v in enumerate(cat_ids)}
        meta.thing_dataset_id_to_contiguous_id = id_map

    # sort indices for reproducible results
    img_ids = sorted(coco_api.imgs.keys())
    # imgs is a list of dicts, each looks something like:
    # {'license': 4,
    #  'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
    #  'file_name': 'COCO_val2014_000000001268.jpg',
    #  'height': 427,
    #  'width': 640,
    #  'date_captured': '2013-11-17 05:57:24',
    #  'id': 1268}
    imgs = coco_api.loadImgs(img_ids)
    # anns is a list[list[dict]], where each dict is an annotation
    # record for an object. The inner list enumerates the objects in an image
    # and the outer list enumerates over images. Example of anns[0]:
    # [{'segmentation': [[192.81,
    #     247.09,
    #     ...
    #     219.03,
    #     249.06]],
    #   'area': 1035.749,
    #   'iscrowd': 0,
    #   'image_id': 1268,
    #   'bbox': [192.81, 224.8, 74.73, 33.43],
    #   'category_id': 16,
    #   'id': 42986},
    #  ...]
    anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]
    total_num_valid_anns = sum([len(x) for x in anns])
    total_num_anns = len(coco_api.anns)
    if total_num_valid_anns < total_num_anns:
        logger.warning(
            f"{json_file} contains {total_num_anns} annotations, but only "
            f"{total_num_valid_anns} of them match to images in the file."
        )

    if "minival" not in json_file:
        # The popular valminusminival & minival annotations for COCO2014 contain this bug.
        # However the ratio of buggy annotations there is tiny and does not affect accuracy.
        # Therefore we explicitly white-list them.
        ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
        assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
            json_file
        )

    imgs_anns = list(zip(imgs, anns))
    logger.info("Loaded {} images in COCO format from {}".format(len(imgs_anns), json_file))

    dataset_dicts = []

    ann_keys = ["iscrowd", "bbox", "keypoints", "category_id"] + (extra_annotation_keys or [])

    num_instances_without_valid_segmentation = 0

    for img_dict, anno_dict_list in imgs_anns:
        record = {}
        record["file_name"] = os.path.join(image_root, img_dict["file_name"])
        record["height"] = img_dict["height"]
        record["width"] = img_dict["width"]
        image_id = record["image_id"] = img_dict["id"]

        objs = []
        for anno in anno_dict_list:
            # Check that the image_id in this annotation is the same as
            # the image_id we're looking at.
            # This fails only when the data parsing logic or the annotation file is buggy.

            # The original COCO valminusminival2014 & minival2014 annotation files
            # actually contain bugs that, together with certain ways of using COCO API,
            # can trigger this assertion.
            assert anno["image_id"] == image_id

            assert anno.get("ignore", 0) == 0, '"ignore" in COCO json file is not supported.'

            obj = {key: anno[key] for key in ann_keys if key in anno}
            if "bbox" in obj and len(obj["bbox"]) == 0:
                raise ValueError(
                    f"One annotation of image {image_id} contains empty 'bbox' value! "
                    "This json does not have valid COCO format."
                )

            segm = anno.get("segmentation", None)
            if segm:  # either list[list[float]] or dict(RLE)
                if isinstance(segm, dict):
                    if isinstance(segm["counts"], list):
                        # convert to compressed RLE
                        segm = mask_util.frPyObjects(segm, *segm["size"])
                else:
                    # filter out invalid polygons (< 3 points)
                    segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
                    if len(segm) == 0:
                        num_instances_without_valid_segmentation += 1
                        continue  # ignore this instance
                obj["segmentation"] = segm

            keypts = anno.get("keypoints", None)
            if keypts:  # list[int]
                for idx, v in enumerate(keypts):
                    if idx % 3 != 2:
                        # COCO's segmentation coordinates are floating points in [0, H or W],
                        # but keypoint coordinates are integers in [0, H-1 or W-1].
                        # Therefore we assume the coordinates are "pixel indices" and
                        # add 0.5 to convert to floating point coordinates.
                        keypts[idx] = v + 0.5
                obj["keypoints"] = keypts

            obj["bbox_mode"] = BoxMode.XYWH_ABS
            if id_map:
                annotation_category_id = obj["category_id"]
                try:
                    obj["category_id"] = id_map[annotation_category_id]
                except KeyError as e:
                    raise KeyError(
                        f"Encountered category_id={annotation_category_id} "
                        "but this id does not exist in 'categories' of the json file."
                    ) from e
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)

    if num_instances_without_valid_segmentation > 0:
        logger.warning(
            "Filtered out {} instances without valid segmentation. ".format(
                num_instances_without_valid_segmentation
            )
            + "There might be issues in your dataset generation process. Please "
            "check https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html carefully"
        )
    return dataset_dicts

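# --- Editor's sketch (not part of the original file): a round-trip of the `id_map`
# --- built in load_coco_json above. The cat_ids below are made-up stand-ins for a
# --- json's "categories" field; real COCO uses ids in [1, 90] with gaps.
_cat_ids = sorted([1, 2, 3, 5, 7, 9, 11])         # incontiguous dataset ids
_id_map = {v: i for i, v in enumerate(_cat_ids)}  # dataset id -> contiguous id
assert _id_map[9] == 5                            # gaps are squeezed out
_reverse = {v: k for k, v in _id_map.items()}     # contiguous id -> dataset id
assert all(_reverse[_id_map[c]] == c for c in _cat_ids)
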
def load_sem_seg(gt_root, image_root, gt_ext="png", image_ext="jpg"):
    """
    Load semantic segmentation datasets. All files under "gt_root" with "gt_ext" extension are
    treated as ground truth annotations and all files under "image_root" with "image_ext" extension
    as input images. Ground truth and input images are matched using file paths relative to
    "gt_root" and "image_root" respectively, without taking into account file extensions.
    This works for COCO as well as some other datasets.

    Args:
        gt_root (str): full path to ground truth semantic segmentation files. Semantic segmentation
            annotations are stored as images with integer values in pixels that represent
            the corresponding semantic labels.
        image_root (str): the directory where the input images are.
        gt_ext (str): file extension for ground truth annotations.
        image_ext (str): file extension for input images.

    Returns:
        list[dict]:
            a list of dicts in detectron2 standard format without instance-level
            annotation.

    Notes:
        1. This function does not read the image and ground truth files.
           The results do not have the "image" and "sem_seg" fields.
    """

    # We match input images with ground truth based on their relative filepaths (without file
    # extensions) starting from 'image_root' and 'gt_root' respectively.
    def file2id(folder_path, file_path):
        # extract relative path starting from `folder_path`
        image_id = os.path.normpath(os.path.relpath(file_path, start=folder_path))
        # remove file extension
        image_id = os.path.splitext(image_id)[0]
        return image_id

    input_files = sorted(
        (os.path.join(image_root, f) for f in PathManager.ls(image_root) if f.endswith(image_ext)),
        key=lambda file_path: file2id(image_root, file_path),
    )
    gt_files = sorted(
        (os.path.join(gt_root, f) for f in PathManager.ls(gt_root) if f.endswith(gt_ext)),
        key=lambda file_path: file2id(gt_root, file_path),
    )

    assert len(gt_files) > 0, "No annotations found in {}.".format(gt_root)

    # Use the intersection, so that val2017_100 annotations can run smoothly with val2017 images
    if len(input_files) != len(gt_files):
        logger.warning(
            "Directories {} and {} have {} and {} files, respectively.".format(
                image_root, gt_root, len(input_files), len(gt_files)
            )
        )
        input_basenames = [os.path.basename(f)[: -len(image_ext)] for f in input_files]
        gt_basenames = [os.path.basename(f)[: -len(gt_ext)] for f in gt_files]
        intersect = list(set(input_basenames) & set(gt_basenames))
        # sort, otherwise each worker may obtain a list[dict] in a different order
        intersect = sorted(intersect)
        logger.warning("Will use their intersection of {} files.".format(len(intersect)))
        input_files = [os.path.join(image_root, f + image_ext) for f in intersect]
        gt_files = [os.path.join(gt_root, f + gt_ext) for f in intersect]

    logger.info(
        "Loaded {} images with semantic segmentation from {}".format(len(input_files), image_root)
    )

    dataset_dicts = []
    for img_path, gt_path in zip(input_files, gt_files):
        record = {}
        record["file_name"] = img_path
        record["sem_seg_file_name"] = gt_path
        dataset_dicts.append(record)

    return dataset_dicts

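# --- Editor's sketch (not part of the original file): the matching rule used by
# --- load_sem_seg above. An image pairs with a ground-truth file when their paths,
# --- relative to their respective roots and stripped of extensions, coincide.
# --- The directory names are made up; `os` is imported at the top of this file.
_img = os.path.join("images", "val", "scene1", "0001.jpg")
_gt = os.path.join("labels", "val", "scene1", "0001.png")
_strip = lambda root, p: os.path.splitext(os.path.normpath(os.path.relpath(p, start=root)))[0]
assert _strip("images", _img) == _strip("labels", _gt) == os.path.join("val", "scene1", "0001")
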
def convert_to_coco_dict(dataset_name):
    """
    Convert an instance detection/segmentation or keypoint detection dataset
    in detectron2's standard format into COCO json format.

    Generic dataset description can be found here:
    https://detectron2.readthedocs.io/tutorials/datasets.html#register-a-dataset

    COCO data format description can be found here:
    http://cocodataset.org/#format-data

    Args:
        dataset_name (str):
            name of the source dataset.
            Must be registered in DatasetCatalog and in detectron2's standard format.
            Must have corresponding metadata "thing_classes".
    Returns:
        coco_dict: serializable dict in COCO json format
    """

    dataset_dicts = DatasetCatalog.get(dataset_name)
    metadata = MetadataCatalog.get(dataset_name)

    # unmap the category mapping ids for COCO
    if hasattr(metadata, "thing_dataset_id_to_contiguous_id"):
        reverse_id_mapping = {v: k for k, v in metadata.thing_dataset_id_to_contiguous_id.items()}
        reverse_id_mapper = lambda contiguous_id: reverse_id_mapping[contiguous_id]  # noqa
    else:
        reverse_id_mapper = lambda contiguous_id: contiguous_id  # noqa

    categories = [
        {"id": reverse_id_mapper(id), "name": name}
        for id, name in enumerate(metadata.thing_classes)
    ]

    logger.info("Converting dataset dicts into COCO format")
    coco_images = []
    coco_annotations = []

    for image_id, image_dict in enumerate(dataset_dicts):
        coco_image = {
            "id": image_dict.get("image_id", image_id),
            "width": int(image_dict["width"]),
            "height": int(image_dict["height"]),
            "file_name": str(image_dict["file_name"]),
        }
        coco_images.append(coco_image)

        anns_per_image = image_dict.get("annotations", [])
        for annotation in anns_per_image:
            # create a new dict with only COCO fields
            coco_annotation = {}

            # COCO requirement: XYWH box format for axis-aligned and XYWHA for rotated
            bbox = annotation["bbox"]
            if isinstance(bbox, np.ndarray):
                if bbox.ndim != 1:
                    raise ValueError(f"bbox has to be 1-dimensional. Got shape={bbox.shape}.")
                bbox = bbox.tolist()
            if len(bbox) not in [4, 5]:
                raise ValueError(f"bbox has to have length 4 or 5. Got {bbox}.")
            from_bbox_mode = annotation["bbox_mode"]
            to_bbox_mode = BoxMode.XYWH_ABS if len(bbox) == 4 else BoxMode.XYWHA_ABS
            bbox = BoxMode.convert(bbox, from_bbox_mode, to_bbox_mode)

            # COCO requirement: instance area
            if "segmentation" in annotation:
                # Computing areas for instances by counting the pixels
                segmentation = annotation["segmentation"]
                # TODO: check segmentation type: RLE, BinaryMask or Polygon
                if isinstance(segmentation, list):
                    polygons = PolygonMasks([segmentation])
                    area = polygons.area()[0].item()
                elif isinstance(segmentation, dict):  # RLE
                    area = mask_util.area(segmentation).item()
                else:
                    raise TypeError(f"Unknown segmentation type {type(segmentation)}!")
            else:
                # Computing areas using bounding boxes
                if to_bbox_mode == BoxMode.XYWH_ABS:
                    bbox_xy = BoxMode.convert(bbox, to_bbox_mode, BoxMode.XYXY_ABS)
                    area = Boxes([bbox_xy]).area()[0].item()
                else:
                    area = RotatedBoxes([bbox]).area()[0].item()

            if "keypoints" in annotation:
                keypoints = annotation["keypoints"]  # list[int]
                for idx, v in enumerate(keypoints):
                    if idx % 3 != 2:
                        # COCO's segmentation coordinates are floating points in [0, H or W],
                        # but keypoint coordinates are integers in [0, H-1 or W-1]
                        # For COCO format consistency we subtract 0.5
                        # https://github.com/facebookresearch/detectron2/pull/175#issuecomment-551202163
                        keypoints[idx] = v - 0.5
                if "num_keypoints" in annotation:
                    num_keypoints = annotation["num_keypoints"]
                else:
                    num_keypoints = sum(kp > 0 for kp in keypoints[2::3])

            # COCO requirement:
            #   linking annotations to images
            #   "id" field must start with 1
            coco_annotation["id"] = len(coco_annotations) + 1
            coco_annotation["image_id"] = coco_image["id"]
            coco_annotation["bbox"] = [round(float(x), 3) for x in bbox]
            coco_annotation["area"] = float(area)
            coco_annotation["iscrowd"] = int(annotation.get("iscrowd", 0))
            coco_annotation["category_id"] = int(reverse_id_mapper(annotation["category_id"]))

            # Add optional fields
            if "keypoints" in annotation:
                coco_annotation["keypoints"] = keypoints
                coco_annotation["num_keypoints"] = num_keypoints

            if "segmentation" in annotation:
                seg = coco_annotation["segmentation"] = annotation["segmentation"]
                if isinstance(seg, dict):  # RLE
                    counts = seg["counts"]
                    if not isinstance(counts, str):
                        # make it json-serializable
                        seg["counts"] = counts.decode("ascii")

            coco_annotations.append(coco_annotation)

    logger.info(
        "Conversion finished, "
        f"#images: {len(coco_images)}, #annotations: {len(coco_annotations)}"
    )

    info = {
        "date_created": str(datetime.datetime.now()),
        "description": "Automatically generated COCO json file for Detectron2.",
    }
    coco_dict = {
        "info": info,
        "images": coco_images,
        "categories": categories,
        "licenses": None,
    }
    if len(coco_annotations) > 0:
        coco_dict["annotations"] = coco_annotations
    return coco_dict

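# A minimal sketch of inspecting the converted dict (the dataset name below is
# hypothetical and must already be registered in DatasetCatalog):
#
#   coco_dict = convert_to_coco_dict("my_dataset_train")
#   print(coco_dict.keys())
#   # dict_keys(['info', 'images', 'categories', 'licenses', 'annotations'])
#   # note: 'annotations' appears only when the dataset has at least one annotation
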
def convert_to_coco_json(dataset_name, output_file, allow_cached=True):
    """
    Converts dataset into COCO format and saves it to a json file.
    dataset_name must be registered in DatasetCatalog and in detectron2's standard format.

    Args:
        dataset_name:
            reference from the config file to the catalogs;
            must be registered in DatasetCatalog and in detectron2's standard format
        output_file: path of json file that will be saved to
        allow_cached: if json file is already present then skip conversion
    """

    # TODO: The dataset or the conversion script *may* change,
    # a checksum would be useful for validating the cached data

    PathManager.mkdirs(os.path.dirname(output_file))
    with file_lock(output_file):
        if PathManager.exists(output_file) and allow_cached:
            logger.warning(
                f"Using previously cached COCO format annotations at '{output_file}'. "
                "You need to clear the cache file if your dataset has been modified."
            )
        else:
            logger.info(f"Converting annotations of dataset '{dataset_name}' to COCO format ...")
            coco_dict = convert_to_coco_dict(dataset_name)

            logger.info(f"Caching COCO format annotations at '{output_file}' ...")
            tmp_file = output_file + ".tmp"
            with PathManager.open(tmp_file, "w") as f:
                json.dump(coco_dict, f)
            shutil.move(tmp_file, output_file)

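# A minimal usage sketch (dataset name and output path are hypothetical):
#
#   convert_to_coco_json("my_dataset_train", "output/my_dataset_train_coco.json")
#   # re-running with allow_cached=True (the default) reuses the cached json file
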
def register_coco_instances(name, metadata, json_file, image_root):
    """
    Register a dataset in COCO's json annotation format for
    instance detection, instance segmentation and keypoint detection.
    (i.e., Type 1 and 2 in http://cocodataset.org/#format-data.
    `instances*.json` and `person_keypoints*.json` in the dataset).

    This is an example of how to register a new dataset.
    You can do something similar to this function, to register new datasets.

    Args:
        name (str): the name that identifies a dataset, e.g. "coco_2014_train".
        metadata (dict): extra metadata associated with this dataset. You can
            leave it as an empty dict.
        json_file (str): path to the json instance annotation file.
        image_root (str or path-like): directory which contains all the images.
    """
    assert isinstance(name, str), name
    assert isinstance(json_file, (str, os.PathLike)), json_file
    assert isinstance(image_root, (str, os.PathLike)), image_root
    # 1. register a function which returns dicts
    DatasetCatalog.register(name, lambda: load_coco_json(json_file, image_root, name))

    # 2. Optionally, add metadata about this dataset,
    # since they might be useful in evaluation, visualization or logging
    MetadataCatalog.get(name).set(
        json_file=json_file, image_root=image_root, evaluator_type="coco", **metadata
    )

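# A minimal usage sketch, mirroring the docstring above (the dataset name and
# paths are hypothetical):
#
#   register_coco_instances(
#       "my_coco_train", {}, "datasets/my/annotations.json", "datasets/my/images"
#   )
#   dicts = DatasetCatalog.get("my_coco_train")  # triggers load_coco_json lazily
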
def main() -> None:
    global logger
    """
    Test the COCO json dataset loader.

    Usage:
        python -m detectron2.data.datasets.coco \
            path/to/json path/to/image_root dataset_name

    "dataset_name" can be "coco_2014_minival_100", or other
    pre-registered ones
    """
    import sys

    import detectron2.data.datasets  # noqa # add pre-defined metadata
    from detectron2.utils.logger import setup_logger
    from detectron2.utils.visualizer import Visualizer

    logger = setup_logger(name=__name__)
    assert sys.argv[3] in DatasetCatalog.list()
    meta = MetadataCatalog.get(sys.argv[3])

    dicts = load_coco_json(sys.argv[1], sys.argv[2], sys.argv[3])
    logger.info("Done loading {} samples.".format(len(dicts)))

    dirname = "coco-data-vis"
    os.makedirs(dirname, exist_ok=True)
    for d in dicts:
        img = np.array(Image.open(d["file_name"]))
        visualizer = Visualizer(img, metadata=meta)
        vis = visualizer.draw_dataset_dict(d)
        fpath = os.path.join(dirname, os.path.basename(d["file_name"]))
        vis.save(fpath)


if __name__ == "__main__":
    main()  # pragma: no cover

CatVTON/detectron2/data/datasets/coco_panoptic.py ADDED
@@ -0,0 +1,228 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import copy
import json
import os

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.utils.file_io import PathManager

from .coco import load_coco_json, load_sem_seg

__all__ = ["register_coco_panoptic", "register_coco_panoptic_separated"]

def load_coco_panoptic_json(json_file, image_dir, gt_dir, meta):
    """
    Args:
        image_dir (str): path to the raw dataset. e.g., "~/coco/train2017".
        gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017".
        json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json".

    Returns:
        list[dict]: a list of dicts in Detectron2 standard format. (See
        `Using Custom Datasets </tutorials/datasets.html>`_ )
    """

    def _convert_category_id(segment_info, meta):
        if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:
            segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][
                segment_info["category_id"]
            ]
            segment_info["isthing"] = True
        else:
            segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][
                segment_info["category_id"]
            ]
            segment_info["isthing"] = False
        return segment_info

    with PathManager.open(json_file) as f:
        json_info = json.load(f)

    ret = []
    for ann in json_info["annotations"]:
        image_id = int(ann["image_id"])
        # TODO: currently we assume image and label have the same filename but
        # different extension, and images have extension ".jpg" for COCO. Need
        # to make image extension a user-provided argument if we extend this
        # function to support other COCO-like datasets.
        image_file = os.path.join(image_dir, os.path.splitext(ann["file_name"])[0] + ".jpg")
        label_file = os.path.join(gt_dir, ann["file_name"])
        segments_info = [_convert_category_id(x, meta) for x in ann["segments_info"]]
        ret.append(
            {
                "file_name": image_file,
                "image_id": image_id,
                "pan_seg_file_name": label_file,
                "segments_info": segments_info,
            }
        )
    assert len(ret), f"No images found in {image_dir}!"
    assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"]
    assert PathManager.isfile(ret[0]["pan_seg_file_name"]), ret[0]["pan_seg_file_name"]
    return ret

def register_coco_panoptic(
    name, metadata, image_root, panoptic_root, panoptic_json, instances_json=None
):
    """
    Register a "standard" version of COCO panoptic segmentation dataset named `name`.
    The dictionaries in this registered dataset follow detectron2's standard format.
    Hence it's called "standard".

    Args:
        name (str): the name that identifies a dataset,
            e.g. "coco_2017_train_panoptic"
        metadata (dict): extra metadata associated with this dataset.
        image_root (str): directory which contains all the images
        panoptic_root (str): directory which contains panoptic annotation images in COCO format
        panoptic_json (str): path to the json panoptic annotation file in COCO format
        sem_seg_root (none): not used, to be consistent with
            `register_coco_panoptic_separated`.
        instances_json (str): path to the json instance annotation file
    """
    panoptic_name = name
    DatasetCatalog.register(
        panoptic_name,
        lambda: load_coco_panoptic_json(panoptic_json, image_root, panoptic_root, metadata),
    )
    MetadataCatalog.get(panoptic_name).set(
        panoptic_root=panoptic_root,
        image_root=image_root,
        panoptic_json=panoptic_json,
        json_file=instances_json,
        evaluator_type="coco_panoptic_seg",
        ignore_label=255,
        label_divisor=1000,
        **metadata,
    )

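# A minimal usage sketch (the name and paths are hypothetical; `metadata` must
# carry the thing/stuff id mappings that load_coco_panoptic_json reads):
#
#   register_coco_panoptic(
#       "my_panoptic_train",
#       metadata,  # e.g. assembled from builtin_meta mappings
#       image_root="datasets/coco/train2017",
#       panoptic_root="datasets/coco/panoptic_train2017",
#       panoptic_json="datasets/coco/annotations/panoptic_train2017.json",
#   )
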
def register_coco_panoptic_separated(
    name, metadata, image_root, panoptic_root, panoptic_json, sem_seg_root, instances_json
):
    """
    Register a "separated" version of COCO panoptic segmentation dataset named `name`.
    The annotations in this registered dataset will contain both instance annotations and
    semantic annotations, each with its own contiguous ids. Hence it's called "separated".

    It follows the setting used by the PanopticFPN paper:

    1. The instance annotations directly come from polygons in the COCO
       instances annotation task, rather than from the masks in the COCO panoptic annotations.

       The two formats have small differences:
       Polygons in the instance annotations may have overlaps.
       The mask annotations are produced by labeling the overlapped polygons
       with depth ordering.

    2. The semantic annotations are converted from panoptic annotations, where
       all "things" are assigned a semantic id of 0.
       All semantic categories will therefore have ids in contiguous
       range [1, #stuff_categories].

    This function will also register a pure semantic segmentation dataset
    named ``name + '_stuffonly'``.

    Args:
        name (str): the name that identifies a dataset,
            e.g. "coco_2017_train_panoptic"
        metadata (dict): extra metadata associated with this dataset.
        image_root (str): directory which contains all the images
        panoptic_root (str): directory which contains panoptic annotation images
        panoptic_json (str): path to the json panoptic annotation file
        sem_seg_root (str): directory which contains all the ground truth segmentation annotations.
        instances_json (str): path to the json instance annotation file
    """
    panoptic_name = name + "_separated"
    DatasetCatalog.register(
        panoptic_name,
        lambda: merge_to_panoptic(
            load_coco_json(instances_json, image_root, panoptic_name),
            load_sem_seg(sem_seg_root, image_root),
        ),
    )
    MetadataCatalog.get(panoptic_name).set(
        panoptic_root=panoptic_root,
        image_root=image_root,
        panoptic_json=panoptic_json,
        sem_seg_root=sem_seg_root,
        json_file=instances_json,  # TODO rename
        evaluator_type="coco_panoptic_seg",
        ignore_label=255,
        **metadata,
    )

    semantic_name = name + "_stuffonly"
    DatasetCatalog.register(semantic_name, lambda: load_sem_seg(sem_seg_root, image_root))
    MetadataCatalog.get(semantic_name).set(
        sem_seg_root=sem_seg_root,
        image_root=image_root,
        evaluator_type="sem_seg",
        ignore_label=255,
        **metadata,
    )

def merge_to_panoptic(detection_dicts, sem_seg_dicts):
    """
    Create dataset dicts for panoptic segmentation, by
    merging two dicts using "file_name" field to match their entries.

    Args:
        detection_dicts (list[dict]): lists of dicts for object detection or instance segmentation.
        sem_seg_dicts (list[dict]): lists of dicts for semantic segmentation.

    Returns:
        list[dict] (one per input image): Each dict contains all (key, value) pairs from dicts in
        both detection_dicts and sem_seg_dicts that correspond to the same image.
        The function assumes that the same key in different dicts has the same value.
    """
    results = []
    sem_seg_file_to_entry = {x["file_name"]: x for x in sem_seg_dicts}
    assert len(sem_seg_file_to_entry) > 0

    for det_dict in detection_dicts:
        dic = copy.copy(det_dict)
        dic.update(sem_seg_file_to_entry[dic["file_name"]])
        results.append(dic)
    return results

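# A minimal sketch of the merge behavior on hand-written dicts (values are
# hypothetical, chosen only to show the key union):
#
#   det = [{"file_name": "a.jpg", "annotations": []}]
#   sem = [{"file_name": "a.jpg", "sem_seg_file_name": "a.png"}]
#   merge_to_panoptic(det, sem)
#   # -> [{"file_name": "a.jpg", "annotations": [], "sem_seg_file_name": "a.png"}]
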
if __name__ == "__main__":
    """
    Test the COCO panoptic dataset loader.

    Usage:
        python -m detectron2.data.datasets.coco_panoptic \
            path/to/image_root path/to/panoptic_root path/to/panoptic_json dataset_name 10

    "dataset_name" can be "coco_2017_train_panoptic", or other
    pre-registered ones
    """
    from detectron2.utils.logger import setup_logger
    from detectron2.utils.visualizer import Visualizer
    import detectron2.data.datasets  # noqa # add pre-defined metadata
    import sys
    from PIL import Image
    import numpy as np

    logger = setup_logger(name=__name__)
    assert sys.argv[4] in DatasetCatalog.list()
    meta = MetadataCatalog.get(sys.argv[4])

    dicts = load_coco_panoptic_json(sys.argv[3], sys.argv[1], sys.argv[2], meta.as_dict())
    logger.info("Done loading {} samples.".format(len(dicts)))

    dirname = "coco-data-vis"
    os.makedirs(dirname, exist_ok=True)
    num_imgs_to_vis = int(sys.argv[5])
    for i, d in enumerate(dicts):
        img = np.array(Image.open(d["file_name"]))
        visualizer = Visualizer(img, metadata=meta)
        vis = visualizer.draw_dataset_dict(d)
        fpath = os.path.join(dirname, os.path.basename(d["file_name"]))
        vis.save(fpath)
        if i + 1 >= num_imgs_to_vis:
            break

CatVTON/detectron2/data/datasets/lvis.py ADDED
@@ -0,0 +1,250 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import os
from fvcore.common.timer import Timer

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import BoxMode
from detectron2.utils.file_io import PathManager

from .builtin_meta import _get_coco_instances_meta
from .lvis_v0_5_categories import LVIS_CATEGORIES as LVIS_V0_5_CATEGORIES
from .lvis_v1_categories import LVIS_CATEGORIES as LVIS_V1_CATEGORIES
from .lvis_v1_category_image_count import LVIS_CATEGORY_IMAGE_COUNT as LVIS_V1_CATEGORY_IMAGE_COUNT

"""
This file contains functions to parse LVIS-format annotations into dicts in the
"Detectron2 format".
"""

logger = logging.getLogger(__name__)

__all__ = ["load_lvis_json", "register_lvis_instances", "get_lvis_instances_meta"]

def register_lvis_instances(name, metadata, json_file, image_root):
    """
    Register a dataset in LVIS's json annotation format for instance detection and segmentation.

    Args:
        name (str): a name that identifies the dataset, e.g. "lvis_v0.5_train".
        metadata (dict): extra metadata associated with this dataset. It can be an empty dict.
        json_file (str): path to the json instance annotation file.
        image_root (str or path-like): directory which contains all the images.
    """
    DatasetCatalog.register(name, lambda: load_lvis_json(json_file, image_root, name))
    MetadataCatalog.get(name).set(
        json_file=json_file, image_root=image_root, evaluator_type="lvis", **metadata
    )

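# A minimal usage sketch (paths are hypothetical; the name should contain a
# version string such as "v1" so get_lvis_instances_meta can resolve metadata,
# and image_root is a prefix joined with the split folder from coco_url):
#
#   register_lvis_instances(
#       "lvis_v1_train_custom", {}, "datasets/lvis/lvis_v1_train.json", "datasets/coco/"
#   )
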
def load_lvis_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None):
    """
    Load a json file in LVIS's annotation format.

    Args:
        json_file (str): full path to the LVIS json annotation file.
        image_root (str): the directory where the images in this json file exist.
        dataset_name (str): the name of the dataset (e.g., "lvis_v0.5_train").
            If provided, this function will put "thing_classes" into the metadata
            associated with this dataset.
        extra_annotation_keys (list[str]): list of per-annotation keys that should also be
            loaded into the dataset dict (besides "bbox", "bbox_mode", "category_id",
            "segmentation"). The values for these keys will be returned as-is.

    Returns:
        list[dict]: a list of dicts in Detectron2 standard format. (See
        `Using Custom Datasets </tutorials/datasets.html>`_ )

    Notes:
        1. This function does not read the image files.
           The results do not have the "image" field.
    """
    from lvis import LVIS

    json_file = PathManager.get_local_path(json_file)

    timer = Timer()
    lvis_api = LVIS(json_file)
    if timer.seconds() > 1:
        logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))

    if dataset_name is not None:
        meta = get_lvis_instances_meta(dataset_name)
        MetadataCatalog.get(dataset_name).set(**meta)

    # sort indices for reproducible results
    img_ids = sorted(lvis_api.imgs.keys())
    # imgs is a list of dicts, each looks something like:
    # {'license': 4,
    #  'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
    #  'file_name': 'COCO_val2014_000000001268.jpg',
    #  'height': 427,
    #  'width': 640,
    #  'date_captured': '2013-11-17 05:57:24',
    #  'id': 1268}
    imgs = lvis_api.load_imgs(img_ids)
    # anns is a list[list[dict]], where each dict is an annotation
    # record for an object. The inner list enumerates the objects in an image
    # and the outer list enumerates over images. Example of anns[0]:
    # [{'segmentation': [[192.81,
    #     247.09,
    #     ...
    #     219.03,
    #     249.06]],
    #   'area': 1035.749,
    #   'image_id': 1268,
    #   'bbox': [192.81, 224.8, 74.73, 33.43],
    #   'category_id': 16,
    #   'id': 42986},
    #  ...]
    anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]

    # Sanity check that each annotation has a unique id
    ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
    assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique".format(
        json_file
    )

    imgs_anns = list(zip(imgs, anns))

    logger.info("Loaded {} images in the LVIS format from {}".format(len(imgs_anns), json_file))

    if extra_annotation_keys:
        logger.info(
            "The following extra annotation keys will be loaded: {} ".format(extra_annotation_keys)
        )
    else:
        extra_annotation_keys = []

    def get_file_name(img_root, img_dict):
        # Determine the path including the split folder ("train2017", "val2017", "test2017") from
        # the coco_url field. Example:
        #   'coco_url': 'http://images.cocodataset.org/train2017/000000155379.jpg'
        split_folder, file_name = img_dict["coco_url"].split("/")[-2:]
        return os.path.join(img_root + split_folder, file_name)

    dataset_dicts = []

    for img_dict, anno_dict_list in imgs_anns:
        record = {}
        record["file_name"] = get_file_name(image_root, img_dict)
        record["height"] = img_dict["height"]
        record["width"] = img_dict["width"]
        record["not_exhaustive_category_ids"] = img_dict.get("not_exhaustive_category_ids", [])
        record["neg_category_ids"] = img_dict.get("neg_category_ids", [])
        image_id = record["image_id"] = img_dict["id"]

        objs = []
        for anno in anno_dict_list:
            # Check that the image_id in this annotation is the same as
            # the image_id we're looking at.
            # This fails only when the data parsing logic or the annotation file is buggy.
            assert anno["image_id"] == image_id
            obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS}
            # LVIS data loader can be used to load COCO dataset categories. In this case `meta`
            # variable will have a field with COCO-specific category mapping.
            if dataset_name is not None and "thing_dataset_id_to_contiguous_id" in meta:
                obj["category_id"] = meta["thing_dataset_id_to_contiguous_id"][anno["category_id"]]
            else:
                obj["category_id"] = anno["category_id"] - 1  # Convert 1-indexed to 0-indexed
            segm = anno["segmentation"]  # list[list[float]]
            # filter out invalid polygons (< 3 points)
            valid_segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
            assert len(segm) == len(
                valid_segm
            ), "Annotation contains an invalid polygon with < 3 points"
            assert len(segm) > 0
            obj["segmentation"] = segm
            for extra_ann_key in extra_annotation_keys:
                obj[extra_ann_key] = anno[extra_ann_key]
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)

    return dataset_dicts

def get_lvis_instances_meta(dataset_name):
    """
    Load LVIS metadata.

    Args:
        dataset_name (str): LVIS dataset name without the split name (e.g., "lvis_v0.5").

    Returns:
        dict: LVIS metadata with keys: thing_classes
    """
    if "cocofied" in dataset_name:
        return _get_coco_instances_meta()
    if "v0.5" in dataset_name:
        return _get_lvis_instances_meta_v0_5()
    elif "v1" in dataset_name:
        return _get_lvis_instances_meta_v1()
    raise ValueError("No built-in metadata for dataset {}".format(dataset_name))

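# A minimal sketch of resolving metadata from a name (the category lists are
# bundled with detectron2, so this should work without external data):
#
#   meta = get_lvis_instances_meta("lvis_v1")
#   len(meta["thing_classes"])  # 1203; the v1 meta also carries "class_image_count"
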
def _get_lvis_instances_meta_v0_5():
    assert len(LVIS_V0_5_CATEGORIES) == 1230
    cat_ids = [k["id"] for k in LVIS_V0_5_CATEGORIES]
    assert min(cat_ids) == 1 and max(cat_ids) == len(
        cat_ids
    ), "Category ids are not in [1, #categories], as expected"
    # Ensure that the category list is sorted by id
    lvis_categories = sorted(LVIS_V0_5_CATEGORIES, key=lambda x: x["id"])
    thing_classes = [k["synonyms"][0] for k in lvis_categories]
    meta = {"thing_classes": thing_classes}
    return meta


def _get_lvis_instances_meta_v1():
    assert len(LVIS_V1_CATEGORIES) == 1203
    cat_ids = [k["id"] for k in LVIS_V1_CATEGORIES]
    assert min(cat_ids) == 1 and max(cat_ids) == len(
        cat_ids
    ), "Category ids are not in [1, #categories], as expected"
    # Ensure that the category list is sorted by id
    lvis_categories = sorted(LVIS_V1_CATEGORIES, key=lambda x: x["id"])
    thing_classes = [k["synonyms"][0] for k in lvis_categories]
    meta = {
        "thing_classes": thing_classes,
        "class_image_count": LVIS_V1_CATEGORY_IMAGE_COUNT,
    }
    return meta

def main() -> None:
    global logger
    """
    Test the LVIS json dataset loader.

    Usage:
        python -m detectron2.data.datasets.lvis \
            path/to/json path/to/image_root dataset_name vis_limit
    """
    import sys

    import detectron2.data.datasets  # noqa # add pre-defined metadata
    import numpy as np
    from detectron2.utils.logger import setup_logger
    from detectron2.utils.visualizer import Visualizer
    from PIL import Image

    logger = setup_logger(name=__name__)
    meta = MetadataCatalog.get(sys.argv[3])

    dicts = load_lvis_json(sys.argv[1], sys.argv[2], sys.argv[3])
    logger.info("Done loading {} samples.".format(len(dicts)))

    dirname = "lvis-data-vis"
    os.makedirs(dirname, exist_ok=True)
    for d in dicts[: int(sys.argv[4])]:
        img = np.array(Image.open(d["file_name"]))
        visualizer = Visualizer(img, metadata=meta)
        vis = visualizer.draw_dataset_dict(d)
        fpath = os.path.join(dirname, os.path.basename(d["file_name"]))
        vis.save(fpath)


if __name__ == "__main__":
    main()  # pragma: no cover

CatVTON/detectron2/data/datasets/lvis_v0_5_categories.py ADDED
The diff for this file is too large to render. See raw diff.

CatVTON/detectron2/data/datasets/lvis_v1_categories.py ADDED
The diff for this file is too large to render. See raw diff.

CatVTON/detectron2/data/datasets/lvis_v1_category_image_count.py ADDED
@@ -0,0 +1,20 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Autogen with
# with open("lvis_v1_train.json", "r") as f:
#     a = json.load(f)
# c = a["categories"]
# for x in c:
#     del x["name"]
#     del x["instance_count"]
#     del x["def"]
#     del x["synonyms"]
#     del x["frequency"]
#     del x["synset"]
# LVIS_CATEGORY_IMAGE_COUNT = repr(c) + " # noqa"
# with open("/tmp/lvis_category_image_count.py", "wt") as f:
#     f.write(f"LVIS_CATEGORY_IMAGE_COUNT = {LVIS_CATEGORY_IMAGE_COUNT}")
# Then paste the contents of that file below

# fmt: off
LVIS_CATEGORY_IMAGE_COUNT = [{'id': 1, 'image_count': 64}, {'id': 2, 'image_count': 364}, {'id': 3, 'image_count': 1911}, {'id': 4, 'image_count': 149}, {'id': 5, 'image_count': 29}, {'id': 6, 'image_count': 26}, {'id': 7, 'image_count': 59}, {'id': 8, 'image_count': 22}, {'id': 9, 'image_count': 12}, {'id': 10, 'image_count': 28}, {'id': 11, 'image_count': 505}, {'id': 12, 'image_count': 1207}, {'id': 13, 'image_count': 4}, {'id': 14, 'image_count': 10}, {'id': 15, 'image_count': 500}, {'id': 16, 'image_count': 33}, {'id': 17, 'image_count': 3}, {'id': 18, 'image_count': 44}, {'id': 19, 'image_count': 561}, {'id': 20, 'image_count': 8}, {'id': 21, 'image_count': 9}, {'id': 22, 'image_count': 33}, {'id': 23, 'image_count': 1883}, {'id': 24, 'image_count': 98}, {'id': 25, 'image_count': 70}, {'id': 26, 'image_count': 46}, {'id': 27, 'image_count': 117}, {'id': 28, 'image_count': 41}, {'id': 29, 'image_count': 1395}, {'id': 30, 'image_count': 7}, {'id': 31, 'image_count': 1}, {'id': 32, 'image_count': 314}, {'id': 33, 'image_count': 31}, {'id': 34, 'image_count': 1905}, {'id': 35, 'image_count': 1859}, {'id': 36, 'image_count': 1623}, {'id': 37, 'image_count': 47}, {'id': 38, 'image_count': 3}, {'id': 39, 'image_count': 3}, {'id': 40, 'image_count': 1}, {'id': 41, 'image_count': 305}, {'id': 42, 'image_count': 6}, {'id': 43, 'image_count': 210}, {'id': 44, 'image_count': 36}, {'id': 45, 'image_count': 1787}, {'id': 46, 'image_count': 17}, {'id': 47, 'image_count': 51}, {'id': 48, 'image_count': 138}, {'id': 49, 'image_count': 3}, {'id': 50, 'image_count': 1470}, {'id': 51, 'image_count': 3}, {'id': 52, 'image_count': 2}, {'id': 53, 'image_count': 186}, {'id': 54, 'image_count': 76}, {'id': 55, 'image_count': 26}, {'id': 56, 'image_count': 303}, {'id': 57, 'image_count': 738}, {'id': 58, 'image_count': 1799}, {'id': 59, 'image_count': 1934}, {'id': 60, 'image_count': 1609}, {'id': 61, 'image_count': 1622}, {'id': 62, 'image_count': 41}, {'id': 63, 'image_count': 4}, {'id': 64, 'image_count': 11}, {'id': 65, 'image_count': 270}, {'id': 66, 'image_count': 349}, {'id': 67, 'image_count': 42}, {'id': 68, 'image_count': 823}, {'id': 69, 'image_count': 6}, {'id': 70, 'image_count': 48}, {'id': 71, 'image_count': 3}, {'id': 72, 'image_count': 42}, {'id': 73, 'image_count': 24}, {'id': 74, 'image_count': 16}, {'id': 75, 'image_count': 605}, {'id': 76, 'image_count': 646}, {'id': 77, 'image_count': 1765}, {'id': 78, 'image_count': 2}, {'id': 79, 'image_count': 125}, {'id': 80, 'image_count': 1420}, {'id': 81, 'image_count': 140}, {'id': 82, 'image_count': 4}, {'id': 83, 'image_count': 322}, {'id': 84, 'image_count': 60}, {'id': 85, 'image_count': 2}, {'id': 86, 'image_count': 231}, {'id': 87, 'image_count': 333}, {'id': 88, 'image_count': 1941}, {'id': 89, 'image_count': 367}, {'id': 90, 'image_count': 1922}, {'id': 91, 'image_count': 18}, {'id': 92, 'image_count': 81}, {'id': 93, 'image_count': 1}, {'id': 94, 'image_count': 1852}, {'id': 95, 'image_count': 430}, {'id': 96, 'image_count': 247}, {'id': 97, 'image_count': 94}, {'id': 98, 'image_count': 21}, {'id': 99, 'image_count': 1821}, {'id': 100, 'image_count': 16}, {'id': 101, 'image_count': 12}, {'id': 102, 'image_count': 25}, {'id': 103, 'image_count': 41}, {'id': 104, 'image_count': 244}, {'id': 105, 'image_count': 7}, {'id': 106, 'image_count': 1}, {'id': 107, 'image_count': 40}, {'id': 108, 'image_count': 40}, {'id': 109, 'image_count': 104}, {'id': 110, 'image_count': 1671}, {'id': 111, 'image_count': 49}, {'id': 112, 'image_count': 243}, 
{'id': 113, 'image_count': 2}, {'id': 114, 'image_count': 242}, {'id': 115, 'image_count': 271}, {'id': 116, 'image_count': 104}, {'id': 117, 'image_count': 8}, {'id': 118, 'image_count': 1758}, {'id': 119, 'image_count': 1}, {'id': 120, 'image_count': 48}, {'id': 121, 'image_count': 14}, {'id': 122, 'image_count': 40}, {'id': 123, 'image_count': 1}, {'id': 124, 'image_count': 37}, {'id': 125, 'image_count': 1510}, {'id': 126, 'image_count': 6}, {'id': 127, 'image_count': 1903}, {'id': 128, 'image_count': 70}, {'id': 129, 'image_count': 86}, {'id': 130, 'image_count': 7}, {'id': 131, 'image_count': 5}, {'id': 132, 'image_count': 1406}, {'id': 133, 'image_count': 1901}, {'id': 134, 'image_count': 15}, {'id': 135, 'image_count': 28}, {'id': 136, 'image_count': 6}, {'id': 137, 'image_count': 494}, {'id': 138, 'image_count': 234}, {'id': 139, 'image_count': 1922}, {'id': 140, 'image_count': 1}, {'id': 141, 'image_count': 35}, {'id': 142, 'image_count': 5}, {'id': 143, 'image_count': 1828}, {'id': 144, 'image_count': 8}, {'id': 145, 'image_count': 63}, {'id': 146, 'image_count': 1668}, {'id': 147, 'image_count': 4}, {'id': 148, 'image_count': 95}, {'id': 149, 'image_count': 17}, {'id': 150, 'image_count': 1567}, {'id': 151, 'image_count': 2}, {'id': 152, 'image_count': 103}, {'id': 153, 'image_count': 50}, {'id': 154, 'image_count': 1309}, {'id': 155, 'image_count': 6}, {'id': 156, 'image_count': 92}, {'id': 157, 'image_count': 19}, {'id': 158, 'image_count': 37}, {'id': 159, 'image_count': 4}, {'id': 160, 'image_count': 709}, {'id': 161, 'image_count': 9}, {'id': 162, 'image_count': 82}, {'id': 163, 'image_count': 15}, {'id': 164, 'image_count': 3}, {'id': 165, 'image_count': 61}, {'id': 166, 'image_count': 51}, {'id': 167, 'image_count': 5}, {'id': 168, 'image_count': 13}, {'id': 169, 'image_count': 642}, {'id': 170, 'image_count': 24}, {'id': 171, 'image_count': 255}, {'id': 172, 'image_count': 9}, {'id': 173, 'image_count': 1808}, {'id': 174, 'image_count': 31}, {'id': 175, 'image_count': 158}, {'id': 176, 'image_count': 80}, {'id': 177, 'image_count': 1884}, {'id': 178, 'image_count': 158}, {'id': 179, 'image_count': 2}, {'id': 180, 'image_count': 12}, {'id': 181, 'image_count': 1659}, {'id': 182, 'image_count': 7}, {'id': 183, 'image_count': 834}, {'id': 184, 'image_count': 57}, {'id': 185, 'image_count': 174}, {'id': 186, 'image_count': 95}, {'id': 187, 'image_count': 27}, {'id': 188, 'image_count': 22}, {'id': 189, 'image_count': 1391}, {'id': 190, 'image_count': 90}, {'id': 191, 'image_count': 40}, {'id': 192, 'image_count': 445}, {'id': 193, 'image_count': 21}, {'id': 194, 'image_count': 1132}, {'id': 195, 'image_count': 177}, {'id': 196, 'image_count': 4}, {'id': 197, 'image_count': 17}, {'id': 198, 'image_count': 84}, {'id': 199, 'image_count': 55}, {'id': 200, 'image_count': 30}, {'id': 201, 'image_count': 25}, {'id': 202, 'image_count': 2}, {'id': 203, 'image_count': 125}, {'id': 204, 'image_count': 1135}, {'id': 205, 'image_count': 19}, {'id': 206, 'image_count': 72}, {'id': 207, 'image_count': 1926}, {'id': 208, 'image_count': 159}, {'id': 209, 'image_count': 7}, {'id': 210, 'image_count': 1}, {'id': 211, 'image_count': 13}, {'id': 212, 'image_count': 35}, {'id': 213, 'image_count': 18}, {'id': 214, 'image_count': 8}, {'id': 215, 'image_count': 6}, {'id': 216, 'image_count': 35}, {'id': 217, 'image_count': 1222}, {'id': 218, 'image_count': 103}, {'id': 219, 'image_count': 28}, {'id': 220, 'image_count': 63}, {'id': 221, 'image_count': 28}, {'id': 222, 'image_count': 5}, {'id': 
223, 'image_count': 7}, {'id': 224, 'image_count': 14}, {'id': 225, 'image_count': 1918}, {'id': 226, 'image_count': 133}, {'id': 227, 'image_count': 16}, {'id': 228, 'image_count': 27}, {'id': 229, 'image_count': 110}, {'id': 230, 'image_count': 1895}, {'id': 231, 'image_count': 4}, {'id': 232, 'image_count': 1927}, {'id': 233, 'image_count': 8}, {'id': 234, 'image_count': 1}, {'id': 235, 'image_count': 263}, {'id': 236, 'image_count': 10}, {'id': 237, 'image_count': 2}, {'id': 238, 'image_count': 3}, {'id': 239, 'image_count': 87}, {'id': 240, 'image_count': 9}, {'id': 241, 'image_count': 71}, {'id': 242, 'image_count': 13}, {'id': 243, 'image_count': 18}, {'id': 244, 'image_count': 2}, {'id': 245, 'image_count': 5}, {'id': 246, 'image_count': 45}, {'id': 247, 'image_count': 1}, {'id': 248, 'image_count': 23}, {'id': 249, 'image_count': 32}, {'id': 250, 'image_count': 4}, {'id': 251, 'image_count': 1}, {'id': 252, 'image_count': 858}, {'id': 253, 'image_count': 661}, {'id': 254, 'image_count': 168}, {'id': 255, 'image_count': 210}, {'id': 256, 'image_count': 65}, {'id': 257, 'image_count': 4}, {'id': 258, 'image_count': 2}, {'id': 259, 'image_count': 159}, {'id': 260, 'image_count': 31}, {'id': 261, 'image_count': 811}, {'id': 262, 'image_count': 1}, {'id': 263, 'image_count': 42}, {'id': 264, 'image_count': 27}, {'id': 265, 'image_count': 2}, {'id': 266, 'image_count': 5}, {'id': 267, 'image_count': 95}, {'id': 268, 'image_count': 32}, {'id': 269, 'image_count': 1}, {'id': 270, 'image_count': 1}, {'id': 271, 'image_count': 1844}, {'id': 272, 'image_count': 897}, {'id': 273, 'image_count': 31}, {'id': 274, 'image_count': 23}, {'id': 275, 'image_count': 1}, {'id': 276, 'image_count': 202}, {'id': 277, 'image_count': 746}, {'id': 278, 'image_count': 44}, {'id': 279, 'image_count': 14}, {'id': 280, 'image_count': 26}, {'id': 281, 'image_count': 1}, {'id': 282, 'image_count': 2}, {'id': 283, 'image_count': 25}, {'id': 284, 'image_count': 238}, {'id': 285, 'image_count': 592}, {'id': 286, 'image_count': 26}, {'id': 287, 'image_count': 5}, {'id': 288, 'image_count': 42}, {'id': 289, 'image_count': 13}, {'id': 290, 'image_count': 46}, {'id': 291, 'image_count': 1}, {'id': 292, 'image_count': 8}, {'id': 293, 'image_count': 34}, {'id': 294, 'image_count': 5}, {'id': 295, 'image_count': 1}, {'id': 296, 'image_count': 1871}, {'id': 297, 'image_count': 717}, {'id': 298, 'image_count': 1010}, {'id': 299, 'image_count': 679}, {'id': 300, 'image_count': 3}, {'id': 301, 'image_count': 4}, {'id': 302, 'image_count': 1}, {'id': 303, 'image_count': 166}, {'id': 304, 'image_count': 2}, {'id': 305, 'image_count': 266}, {'id': 306, 'image_count': 101}, {'id': 307, 'image_count': 6}, {'id': 308, 'image_count': 14}, {'id': 309, 'image_count': 133}, {'id': 310, 'image_count': 2}, {'id': 311, 'image_count': 38}, {'id': 312, 'image_count': 95}, {'id': 313, 'image_count': 1}, {'id': 314, 'image_count': 12}, {'id': 315, 'image_count': 49}, {'id': 316, 'image_count': 5}, {'id': 317, 'image_count': 5}, {'id': 318, 'image_count': 16}, {'id': 319, 'image_count': 216}, {'id': 320, 'image_count': 12}, {'id': 321, 'image_count': 1}, {'id': 322, 'image_count': 54}, {'id': 323, 'image_count': 5}, {'id': 324, 'image_count': 245}, {'id': 325, 'image_count': 12}, {'id': 326, 'image_count': 7}, {'id': 327, 'image_count': 35}, {'id': 328, 'image_count': 36}, {'id': 329, 'image_count': 32}, {'id': 330, 'image_count': 1027}, {'id': 331, 'image_count': 10}, {'id': 332, 'image_count': 12}, {'id': 333, 'image_count': 1}, {'id': 334, 
'image_count': 67}, {'id': 335, 'image_count': 71}, {'id': 336, 'image_count': 30}, {'id': 337, 'image_count': 48}, {'id': 338, 'image_count': 249}, {'id': 339, 'image_count': 13}, {'id': 340, 'image_count': 29}, {'id': 341, 'image_count': 14}, {'id': 342, 'image_count': 236}, {'id': 343, 'image_count': 15}, {'id': 344, 'image_count': 1521}, {'id': 345, 'image_count': 25}, {'id': 346, 'image_count': 249}, {'id': 347, 'image_count': 139}, {'id': 348, 'image_count': 2}, {'id': 349, 'image_count': 2}, {'id': 350, 'image_count': 1890}, {'id': 351, 'image_count': 1240}, {'id': 352, 'image_count': 1}, {'id': 353, 'image_count': 9}, {'id': 354, 'image_count': 1}, {'id': 355, 'image_count': 3}, {'id': 356, 'image_count': 11}, {'id': 357, 'image_count': 4}, {'id': 358, 'image_count': 236}, {'id': 359, 'image_count': 44}, {'id': 360, 'image_count': 19}, {'id': 361, 'image_count': 1100}, {'id': 362, 'image_count': 7}, {'id': 363, 'image_count': 69}, {'id': 364, 'image_count': 2}, {'id': 365, 'image_count': 8}, {'id': 366, 'image_count': 5}, {'id': 367, 'image_count': 227}, {'id': 368, 'image_count': 6}, {'id': 369, 'image_count': 106}, {'id': 370, 'image_count': 81}, {'id': 371, 'image_count': 17}, {'id': 372, 'image_count': 134}, {'id': 373, 'image_count': 312}, {'id': 374, 'image_count': 8}, {'id': 375, 'image_count': 271}, {'id': 376, 'image_count': 2}, {'id': 377, 'image_count': 103}, {'id': 378, 'image_count': 1938}, {'id': 379, 'image_count': 574}, {'id': 380, 'image_count': 120}, {'id': 381, 'image_count': 2}, {'id': 382, 'image_count': 2}, {'id': 383, 'image_count': 13}, {'id': 384, 'image_count': 29}, {'id': 385, 'image_count': 1710}, {'id': 386, 'image_count': 66}, {'id': 387, 'image_count': 1008}, {'id': 388, 'image_count': 1}, {'id': 389, 'image_count': 3}, {'id': 390, 'image_count': 1942}, {'id': 391, 'image_count': 19}, {'id': 392, 'image_count': 1488}, {'id': 393, 'image_count': 46}, {'id': 394, 'image_count': 106}, {'id': 395, 'image_count': 115}, {'id': 396, 'image_count': 19}, {'id': 397, 'image_count': 2}, {'id': 398, 'image_count': 1}, {'id': 399, 'image_count': 28}, {'id': 400, 'image_count': 9}, {'id': 401, 'image_count': 192}, {'id': 402, 'image_count': 12}, {'id': 403, 'image_count': 21}, {'id': 404, 'image_count': 247}, {'id': 405, 'image_count': 6}, {'id': 406, 'image_count': 64}, {'id': 407, 'image_count': 7}, {'id': 408, 'image_count': 40}, {'id': 409, 'image_count': 542}, {'id': 410, 'image_count': 2}, {'id': 411, 'image_count': 1898}, {'id': 412, 'image_count': 36}, {'id': 413, 'image_count': 4}, {'id': 414, 'image_count': 1}, {'id': 415, 'image_count': 191}, {'id': 416, 'image_count': 6}, {'id': 417, 'image_count': 41}, {'id': 418, 'image_count': 39}, {'id': 419, 'image_count': 46}, {'id': 420, 'image_count': 1}, {'id': 421, 'image_count': 1451}, {'id': 422, 'image_count': 1878}, {'id': 423, 'image_count': 11}, {'id': 424, 'image_count': 82}, {'id': 425, 'image_count': 18}, {'id': 426, 'image_count': 1}, {'id': 427, 'image_count': 7}, {'id': 428, 'image_count': 3}, {'id': 429, 'image_count': 575}, {'id': 430, 'image_count': 1907}, {'id': 431, 'image_count': 8}, {'id': 432, 'image_count': 4}, {'id': 433, 'image_count': 32}, {'id': 434, 'image_count': 11}, {'id': 435, 'image_count': 4}, {'id': 436, 'image_count': 54}, {'id': 437, 'image_count': 202}, {'id': 438, 'image_count': 32}, {'id': 439, 'image_count': 3}, {'id': 440, 'image_count': 130}, {'id': 441, 'image_count': 119}, {'id': 442, 'image_count': 141}, {'id': 443, 'image_count': 29}, {'id': 444, 'image_count': 
525}, {'id': 445, 'image_count': 1323}, {'id': 446, 'image_count': 2}, {'id': 447, 'image_count': 113}, {'id': 448, 'image_count': 16}, {'id': 449, 'image_count': 7}, {'id': 450, 'image_count': 35}, {'id': 451, 'image_count': 1908}, {'id': 452, 'image_count': 353}, {'id': 453, 'image_count': 18}, {'id': 454, 'image_count': 14}, {'id': 455, 'image_count': 77}, {'id': 456, 'image_count': 8}, {'id': 457, 'image_count': 37}, {'id': 458, 'image_count': 1}, {'id': 459, 'image_count': 346}, {'id': 460, 'image_count': 19}, {'id': 461, 'image_count': 1779}, {'id': 462, 'image_count': 23}, {'id': 463, 'image_count': 25}, {'id': 464, 'image_count': 67}, {'id': 465, 'image_count': 19}, {'id': 466, 'image_count': 28}, {'id': 467, 'image_count': 4}, {'id': 468, 'image_count': 27}, {'id': 469, 'image_count': 1861}, {'id': 470, 'image_count': 11}, {'id': 471, 'image_count': 13}, {'id': 472, 'image_count': 13}, {'id': 473, 'image_count': 32}, {'id': 474, 'image_count': 1767}, {'id': 475, 'image_count': 42}, {'id': 476, 'image_count': 17}, {'id': 477, 'image_count': 128}, {'id': 478, 'image_count': 1}, {'id': 479, 'image_count': 9}, {'id': 480, 'image_count': 10}, {'id': 481, 'image_count': 4}, {'id': 482, 'image_count': 9}, {'id': 483, 'image_count': 18}, {'id': 484, 'image_count': 41}, {'id': 485, 'image_count': 28}, {'id': 486, 'image_count': 3}, {'id': 487, 'image_count': 65}, {'id': 488, 'image_count': 9}, {'id': 489, 'image_count': 23}, {'id': 490, 'image_count': 24}, {'id': 491, 'image_count': 1}, {'id': 492, 'image_count': 2}, {'id': 493, 'image_count': 59}, {'id': 494, 'image_count': 48}, {'id': 495, 'image_count': 17}, {'id': 496, 'image_count': 1877}, {'id': 497, 'image_count': 18}, {'id': 498, 'image_count': 1920}, {'id': 499, 'image_count': 50}, {'id': 500, 'image_count': 1890}, {'id': 501, 'image_count': 99}, {'id': 502, 'image_count': 1530}, {'id': 503, 'image_count': 3}, {'id': 504, 'image_count': 11}, {'id': 505, 'image_count': 19}, {'id': 506, 'image_count': 3}, {'id': 507, 'image_count': 63}, {'id': 508, 'image_count': 5}, {'id': 509, 'image_count': 6}, {'id': 510, 'image_count': 233}, {'id': 511, 'image_count': 54}, {'id': 512, 'image_count': 36}, {'id': 513, 'image_count': 10}, {'id': 514, 'image_count': 124}, {'id': 515, 'image_count': 101}, {'id': 516, 'image_count': 3}, {'id': 517, 'image_count': 363}, {'id': 518, 'image_count': 3}, {'id': 519, 'image_count': 30}, {'id': 520, 'image_count': 18}, {'id': 521, 'image_count': 199}, {'id': 522, 'image_count': 97}, {'id': 523, 'image_count': 32}, {'id': 524, 'image_count': 121}, {'id': 525, 'image_count': 16}, {'id': 526, 'image_count': 12}, {'id': 527, 'image_count': 2}, {'id': 528, 'image_count': 214}, {'id': 529, 'image_count': 48}, {'id': 530, 'image_count': 26}, {'id': 531, 'image_count': 13}, {'id': 532, 'image_count': 4}, {'id': 533, 'image_count': 11}, {'id': 534, 'image_count': 123}, {'id': 535, 'image_count': 7}, {'id': 536, 'image_count': 200}, {'id': 537, 'image_count': 91}, {'id': 538, 'image_count': 9}, {'id': 539, 'image_count': 72}, {'id': 540, 'image_count': 1886}, {'id': 541, 'image_count': 4}, {'id': 542, 'image_count': 1}, {'id': 543, 'image_count': 1}, {'id': 544, 'image_count': 1932}, {'id': 545, 'image_count': 4}, {'id': 546, 'image_count': 56}, {'id': 547, 'image_count': 854}, {'id': 548, 'image_count': 755}, {'id': 549, 'image_count': 1843}, {'id': 550, 'image_count': 96}, {'id': 551, 'image_count': 7}, {'id': 552, 'image_count': 74}, {'id': 553, 'image_count': 66}, {'id': 554, 'image_count': 57}, {'id': 555, 
'image_count': 44}, {'id': 556, 'image_count': 1905}, {'id': 557, 'image_count': 4}, {'id': 558, 'image_count': 90}, {'id': 559, 'image_count': 1635}, {'id': 560, 'image_count': 8}, {'id': 561, 'image_count': 5}, {'id': 562, 'image_count': 50}, {'id': 563, 'image_count': 545}, {'id': 564, 'image_count': 20}, {'id': 565, 'image_count': 193}, {'id': 566, 'image_count': 285}, {'id': 567, 'image_count': 3}, {'id': 568, 'image_count': 1}, {'id': 569, 'image_count': 1904}, {'id': 570, 'image_count': 294}, {'id': 571, 'image_count': 3}, {'id': 572, 'image_count': 5}, {'id': 573, 'image_count': 24}, {'id': 574, 'image_count': 2}, {'id': 575, 'image_count': 2}, {'id': 576, 'image_count': 16}, {'id': 577, 'image_count': 8}, {'id': 578, 'image_count': 154}, {'id': 579, 'image_count': 66}, {'id': 580, 'image_count': 1}, {'id': 581, 'image_count': 24}, {'id': 582, 'image_count': 1}, {'id': 583, 'image_count': 4}, {'id': 584, 'image_count': 75}, {'id': 585, 'image_count': 6}, {'id': 586, 'image_count': 126}, {'id': 587, 'image_count': 24}, {'id': 588, 'image_count': 22}, {'id': 589, 'image_count': 1872}, {'id': 590, 'image_count': 16}, {'id': 591, 'image_count': 423}, {'id': 592, 'image_count': 1927}, {'id': 593, 'image_count': 38}, {'id': 594, 'image_count': 3}, {'id': 595, 'image_count': 1945}, {'id': 596, 'image_count': 35}, {'id': 597, 'image_count': 1}, {'id': 598, 'image_count': 13}, {'id': 599, 'image_count': 9}, {'id': 600, 'image_count': 14}, {'id': 601, 'image_count': 37}, {'id': 602, 'image_count': 3}, {'id': 603, 'image_count': 4}, {'id': 604, 'image_count': 100}, {'id': 605, 'image_count': 195}, {'id': 606, 'image_count': 1}, {'id': 607, 'image_count': 12}, {'id': 608, 'image_count': 24}, {'id': 609, 'image_count': 489}, {'id': 610, 'image_count': 10}, {'id': 611, 'image_count': 1689}, {'id': 612, 'image_count': 42}, {'id': 613, 'image_count': 81}, {'id': 614, 'image_count': 894}, {'id': 615, 'image_count': 1868}, {'id': 616, 'image_count': 7}, {'id': 617, 'image_count': 1567}, {'id': 618, 'image_count': 10}, {'id': 619, 'image_count': 8}, {'id': 620, 'image_count': 7}, {'id': 621, 'image_count': 629}, {'id': 622, 'image_count': 89}, {'id': 623, 'image_count': 15}, {'id': 624, 'image_count': 134}, {'id': 625, 'image_count': 4}, {'id': 626, 'image_count': 1802}, {'id': 627, 'image_count': 595}, {'id': 628, 'image_count': 1210}, {'id': 629, 'image_count': 48}, {'id': 630, 'image_count': 418}, {'id': 631, 'image_count': 1846}, {'id': 632, 'image_count': 5}, {'id': 633, 'image_count': 221}, {'id': 634, 'image_count': 10}, {'id': 635, 'image_count': 7}, {'id': 636, 'image_count': 76}, {'id': 637, 'image_count': 22}, {'id': 638, 'image_count': 10}, {'id': 639, 'image_count': 341}, {'id': 640, 'image_count': 1}, {'id': 641, 'image_count': 705}, {'id': 642, 'image_count': 1900}, {'id': 643, 'image_count': 188}, {'id': 644, 'image_count': 227}, {'id': 645, 'image_count': 861}, {'id': 646, 'image_count': 6}, {'id': 647, 'image_count': 115}, {'id': 648, 'image_count': 5}, {'id': 649, 'image_count': 43}, {'id': 650, 'image_count': 14}, {'id': 651, 'image_count': 6}, {'id': 652, 'image_count': 15}, {'id': 653, 'image_count': 1167}, {'id': 654, 'image_count': 15}, {'id': 655, 'image_count': 994}, {'id': 656, 'image_count': 28}, {'id': 657, 'image_count': 2}, {'id': 658, 'image_count': 338}, {'id': 659, 'image_count': 334}, {'id': 660, 'image_count': 15}, {'id': 661, 'image_count': 102}, {'id': 662, 'image_count': 1}, {'id': 663, 'image_count': 8}, {'id': 664, 'image_count': 1}, {'id': 665, 'image_count': 
1}, {'id': 666, 'image_count': 28}, {'id': 667, 'image_count': 91}, {'id': 668, 'image_count': 260}, {'id': 669, 'image_count': 131}, {'id': 670, 'image_count': 128}, {'id': 671, 'image_count': 3}, {'id': 672, 'image_count': 10}, {'id': 673, 'image_count': 39}, {'id': 674, 'image_count': 2}, {'id': 675, 'image_count': 925}, {'id': 676, 'image_count': 354}, {'id': 677, 'image_count': 31}, {'id': 678, 'image_count': 10}, {'id': 679, 'image_count': 215}, {'id': 680, 'image_count': 71}, {'id': 681, 'image_count': 43}, {'id': 682, 'image_count': 28}, {'id': 683, 'image_count': 34}, {'id': 684, 'image_count': 16}, {'id': 685, 'image_count': 273}, {'id': 686, 'image_count': 2}, {'id': 687, 'image_count': 999}, {'id': 688, 'image_count': 4}, {'id': 689, 'image_count': 107}, {'id': 690, 'image_count': 2}, {'id': 691, 'image_count': 1}, {'id': 692, 'image_count': 454}, {'id': 693, 'image_count': 9}, {'id': 694, 'image_count': 1901}, {'id': 695, 'image_count': 61}, {'id': 696, 'image_count': 91}, {'id': 697, 'image_count': 46}, {'id': 698, 'image_count': 1402}, {'id': 699, 'image_count': 74}, {'id': 700, 'image_count': 421}, {'id': 701, 'image_count': 226}, {'id': 702, 'image_count': 10}, {'id': 703, 'image_count': 1720}, {'id': 704, 'image_count': 261}, {'id': 705, 'image_count': 1337}, {'id': 706, 'image_count': 293}, {'id': 707, 'image_count': 62}, {'id': 708, 'image_count': 814}, {'id': 709, 'image_count': 407}, {'id': 710, 'image_count': 6}, {'id': 711, 'image_count': 16}, {'id': 712, 'image_count': 7}, {'id': 713, 'image_count': 1791}, {'id': 714, 'image_count': 2}, {'id': 715, 'image_count': 1915}, {'id': 716, 'image_count': 1940}, {'id': 717, 'image_count': 13}, {'id': 718, 'image_count': 16}, {'id': 719, 'image_count': 448}, {'id': 720, 'image_count': 12}, {'id': 721, 'image_count': 18}, {'id': 722, 'image_count': 4}, {'id': 723, 'image_count': 71}, {'id': 724, 'image_count': 189}, {'id': 725, 'image_count': 74}, {'id': 726, 'image_count': 103}, {'id': 727, 'image_count': 3}, {'id': 728, 'image_count': 110}, {'id': 729, 'image_count': 5}, {'id': 730, 'image_count': 9}, {'id': 731, 'image_count': 15}, {'id': 732, 'image_count': 25}, {'id': 733, 'image_count': 7}, {'id': 734, 'image_count': 647}, {'id': 735, 'image_count': 824}, {'id': 736, 'image_count': 100}, {'id': 737, 'image_count': 47}, {'id': 738, 'image_count': 121}, {'id': 739, 'image_count': 731}, {'id': 740, 'image_count': 73}, {'id': 741, 'image_count': 49}, {'id': 742, 'image_count': 23}, {'id': 743, 'image_count': 4}, {'id': 744, 'image_count': 62}, {'id': 745, 'image_count': 118}, {'id': 746, 'image_count': 99}, {'id': 747, 'image_count': 40}, {'id': 748, 'image_count': 1036}, {'id': 749, 'image_count': 105}, {'id': 750, 'image_count': 21}, {'id': 751, 'image_count': 229}, {'id': 752, 'image_count': 7}, {'id': 753, 'image_count': 72}, {'id': 754, 'image_count': 9}, {'id': 755, 'image_count': 10}, {'id': 756, 'image_count': 328}, {'id': 757, 'image_count': 468}, {'id': 758, 'image_count': 1}, {'id': 759, 'image_count': 2}, {'id': 760, 'image_count': 24}, {'id': 761, 'image_count': 11}, {'id': 762, 'image_count': 72}, {'id': 763, 'image_count': 17}, {'id': 764, 'image_count': 10}, {'id': 765, 'image_count': 17}, {'id': 766, 'image_count': 489}, {'id': 767, 'image_count': 47}, {'id': 768, 'image_count': 93}, {'id': 769, 'image_count': 1}, {'id': 770, 'image_count': 12}, {'id': 771, 'image_count': 228}, {'id': 772, 'image_count': 5}, {'id': 773, 'image_count': 76}, {'id': 774, 'image_count': 71}, {'id': 775, 'image_count': 30}, 
{'id': 776, 'image_count': 109}, {'id': 777, 'image_count': 14}, {'id': 778, 'image_count': 1}, {'id': 779, 'image_count': 8}, {'id': 780, 'image_count': 26}, {'id': 781, 'image_count': 339}, {'id': 782, 'image_count': 153}, {'id': 783, 'image_count': 2}, {'id': 784, 'image_count': 3}, {'id': 785, 'image_count': 8}, {'id': 786, 'image_count': 47}, {'id': 787, 'image_count': 8}, {'id': 788, 'image_count': 6}, {'id': 789, 'image_count': 116}, {'id': 790, 'image_count': 69}, {'id': 791, 'image_count': 13}, {'id': 792, 'image_count': 6}, {'id': 793, 'image_count': 1928}, {'id': 794, 'image_count': 79}, {'id': 795, 'image_count': 14}, {'id': 796, 'image_count': 7}, {'id': 797, 'image_count': 20}, {'id': 798, 'image_count': 114}, {'id': 799, 'image_count': 221}, {'id': 800, 'image_count': 502}, {'id': 801, 'image_count': 62}, {'id': 802, 'image_count': 87}, {'id': 803, 'image_count': 4}, {'id': 804, 'image_count': 1912}, {'id': 805, 'image_count': 7}, {'id': 806, 'image_count': 186}, {'id': 807, 'image_count': 18}, {'id': 808, 'image_count': 4}, {'id': 809, 'image_count': 3}, {'id': 810, 'image_count': 7}, {'id': 811, 'image_count': 1413}, {'id': 812, 'image_count': 7}, {'id': 813, 'image_count': 12}, {'id': 814, 'image_count': 248}, {'id': 815, 'image_count': 4}, {'id': 816, 'image_count': 1881}, {'id': 817, 'image_count': 529}, {'id': 818, 'image_count': 1932}, {'id': 819, 'image_count': 50}, {'id': 820, 'image_count': 3}, {'id': 821, 'image_count': 28}, {'id': 822, 'image_count': 10}, {'id': 823, 'image_count': 5}, {'id': 824, 'image_count': 5}, {'id': 825, 'image_count': 18}, {'id': 826, 'image_count': 14}, {'id': 827, 'image_count': 1890}, {'id': 828, 'image_count': 660}, {'id': 829, 'image_count': 8}, {'id': 830, 'image_count': 25}, {'id': 831, 'image_count': 10}, {'id': 832, 'image_count': 218}, {'id': 833, 'image_count': 36}, {'id': 834, 'image_count': 16}, {'id': 835, 'image_count': 808}, {'id': 836, 'image_count': 479}, {'id': 837, 'image_count': 1404}, {'id': 838, 'image_count': 307}, {'id': 839, 'image_count': 57}, {'id': 840, 'image_count': 28}, {'id': 841, 'image_count': 80}, {'id': 842, 'image_count': 11}, {'id': 843, 'image_count': 92}, {'id': 844, 'image_count': 20}, {'id': 845, 'image_count': 194}, {'id': 846, 'image_count': 23}, {'id': 847, 'image_count': 52}, {'id': 848, 'image_count': 673}, {'id': 849, 'image_count': 2}, {'id': 850, 'image_count': 2}, {'id': 851, 'image_count': 1}, {'id': 852, 'image_count': 2}, {'id': 853, 'image_count': 8}, {'id': 854, 'image_count': 80}, {'id': 855, 'image_count': 3}, {'id': 856, 'image_count': 3}, {'id': 857, 'image_count': 15}, {'id': 858, 'image_count': 2}, {'id': 859, 'image_count': 10}, {'id': 860, 'image_count': 386}, {'id': 861, 'image_count': 65}, {'id': 862, 'image_count': 3}, {'id': 863, 'image_count': 35}, {'id': 864, 'image_count': 5}, {'id': 865, 'image_count': 180}, {'id': 866, 'image_count': 99}, {'id': 867, 'image_count': 49}, {'id': 868, 'image_count': 28}, {'id': 869, 'image_count': 1}, {'id': 870, 'image_count': 52}, {'id': 871, 'image_count': 36}, {'id': 872, 'image_count': 70}, {'id': 873, 'image_count': 6}, {'id': 874, 'image_count': 29}, {'id': 875, 'image_count': 24}, {'id': 876, 'image_count': 1115}, {'id': 877, 'image_count': 61}, {'id': 878, 'image_count': 18}, {'id': 879, 'image_count': 18}, {'id': 880, 'image_count': 665}, {'id': 881, 'image_count': 1096}, {'id': 882, 'image_count': 29}, {'id': 883, 'image_count': 8}, {'id': 884, 'image_count': 14}, {'id': 885, 'image_count': 1622}, {'id': 886, 'image_count': 
2}, {'id': 887, 'image_count': 3}, {'id': 888, 'image_count': 32}, {'id': 889, 'image_count': 55}, {'id': 890, 'image_count': 1}, {'id': 891, 'image_count': 10}, {'id': 892, 'image_count': 10}, {'id': 893, 'image_count': 47}, {'id': 894, 'image_count': 3}, {'id': 895, 'image_count': 29}, {'id': 896, 'image_count': 342}, {'id': 897, 'image_count': 25}, {'id': 898, 'image_count': 1469}, {'id': 899, 'image_count': 521}, {'id': 900, 'image_count': 347}, {'id': 901, 'image_count': 35}, {'id': 902, 'image_count': 7}, {'id': 903, 'image_count': 207}, {'id': 904, 'image_count': 108}, {'id': 905, 'image_count': 2}, {'id': 906, 'image_count': 34}, {'id': 907, 'image_count': 12}, {'id': 908, 'image_count': 10}, {'id': 909, 'image_count': 13}, {'id': 910, 'image_count': 361}, {'id': 911, 'image_count': 1023}, {'id': 912, 'image_count': 782}, {'id': 913, 'image_count': 2}, {'id': 914, 'image_count': 5}, {'id': 915, 'image_count': 247}, {'id': 916, 'image_count': 221}, {'id': 917, 'image_count': 4}, {'id': 918, 'image_count': 8}, {'id': 919, 'image_count': 158}, {'id': 920, 'image_count': 3}, {'id': 921, 'image_count': 752}, {'id': 922, 'image_count': 64}, {'id': 923, 'image_count': 707}, {'id': 924, 'image_count': 143}, {'id': 925, 'image_count': 1}, {'id': 926, 'image_count': 49}, {'id': 927, 'image_count': 126}, {'id': 928, 'image_count': 76}, {'id': 929, 'image_count': 11}, {'id': 930, 'image_count': 11}, {'id': 931, 'image_count': 4}, {'id': 932, 'image_count': 39}, {'id': 933, 'image_count': 11}, {'id': 934, 'image_count': 13}, {'id': 935, 'image_count': 91}, {'id': 936, 'image_count': 14}, {'id': 937, 'image_count': 5}, {'id': 938, 'image_count': 3}, {'id': 939, 'image_count': 10}, {'id': 940, 'image_count': 18}, {'id': 941, 'image_count': 9}, {'id': 942, 'image_count': 6}, {'id': 943, 'image_count': 951}, {'id': 944, 'image_count': 2}, {'id': 945, 'image_count': 1}, {'id': 946, 'image_count': 19}, {'id': 947, 'image_count': 1942}, {'id': 948, 'image_count': 1916}, {'id': 949, 'image_count': 139}, {'id': 950, 'image_count': 43}, {'id': 951, 'image_count': 1969}, {'id': 952, 'image_count': 5}, {'id': 953, 'image_count': 134}, {'id': 954, 'image_count': 74}, {'id': 955, 'image_count': 381}, {'id': 956, 'image_count': 1}, {'id': 957, 'image_count': 381}, {'id': 958, 'image_count': 6}, {'id': 959, 'image_count': 1826}, {'id': 960, 'image_count': 28}, {'id': 961, 'image_count': 1635}, {'id': 962, 'image_count': 1967}, {'id': 963, 'image_count': 16}, {'id': 964, 'image_count': 1926}, {'id': 965, 'image_count': 1789}, {'id': 966, 'image_count': 401}, {'id': 967, 'image_count': 1968}, {'id': 968, 'image_count': 1167}, {'id': 969, 'image_count': 1}, {'id': 970, 'image_count': 56}, {'id': 971, 'image_count': 17}, {'id': 972, 'image_count': 1}, {'id': 973, 'image_count': 58}, {'id': 974, 'image_count': 9}, {'id': 975, 'image_count': 8}, {'id': 976, 'image_count': 1124}, {'id': 977, 'image_count': 31}, {'id': 978, 'image_count': 16}, {'id': 979, 'image_count': 491}, {'id': 980, 'image_count': 432}, {'id': 981, 'image_count': 1945}, {'id': 982, 'image_count': 1899}, {'id': 983, 'image_count': 5}, {'id': 984, 'image_count': 28}, {'id': 985, 'image_count': 7}, {'id': 986, 'image_count': 146}, {'id': 987, 'image_count': 1}, {'id': 988, 'image_count': 25}, {'id': 989, 'image_count': 22}, {'id': 990, 'image_count': 1}, {'id': 991, 'image_count': 10}, {'id': 992, 'image_count': 9}, {'id': 993, 'image_count': 308}, {'id': 994, 'image_count': 4}, {'id': 995, 'image_count': 1969}, {'id': 996, 'image_count': 45}, 
{'id': 997, 'image_count': 12}, {'id': 998, 'image_count': 1}, {'id': 999, 'image_count': 85}, {'id': 1000, 'image_count': 1127}, {'id': 1001, 'image_count': 11}, {'id': 1002, 'image_count': 60}, {'id': 1003, 'image_count': 1}, {'id': 1004, 'image_count': 16}, {'id': 1005, 'image_count': 1}, {'id': 1006, 'image_count': 65}, {'id': 1007, 'image_count': 13}, {'id': 1008, 'image_count': 655}, {'id': 1009, 'image_count': 51}, {'id': 1010, 'image_count': 1}, {'id': 1011, 'image_count': 673}, {'id': 1012, 'image_count': 5}, {'id': 1013, 'image_count': 36}, {'id': 1014, 'image_count': 54}, {'id': 1015, 'image_count': 5}, {'id': 1016, 'image_count': 8}, {'id': 1017, 'image_count': 305}, {'id': 1018, 'image_count': 297}, {'id': 1019, 'image_count': 1053}, {'id': 1020, 'image_count': 223}, {'id': 1021, 'image_count': 1037}, {'id': 1022, 'image_count': 63}, {'id': 1023, 'image_count': 1881}, {'id': 1024, 'image_count': 507}, {'id': 1025, 'image_count': 333}, {'id': 1026, 'image_count': 1911}, {'id': 1027, 'image_count': 1765}, {'id': 1028, 'image_count': 1}, {'id': 1029, 'image_count': 5}, {'id': 1030, 'image_count': 1}, {'id': 1031, 'image_count': 9}, {'id': 1032, 'image_count': 2}, {'id': 1033, 'image_count': 151}, {'id': 1034, 'image_count': 82}, {'id': 1035, 'image_count': 1931}, {'id': 1036, 'image_count': 41}, {'id': 1037, 'image_count': 1895}, {'id': 1038, 'image_count': 24}, {'id': 1039, 'image_count': 22}, {'id': 1040, 'image_count': 35}, {'id': 1041, 'image_count': 69}, {'id': 1042, 'image_count': 962}, {'id': 1043, 'image_count': 588}, {'id': 1044, 'image_count': 21}, {'id': 1045, 'image_count': 825}, {'id': 1046, 'image_count': 52}, {'id': 1047, 'image_count': 5}, {'id': 1048, 'image_count': 5}, {'id': 1049, 'image_count': 5}, {'id': 1050, 'image_count': 1860}, {'id': 1051, 'image_count': 56}, {'id': 1052, 'image_count': 1582}, {'id': 1053, 'image_count': 7}, {'id': 1054, 'image_count': 2}, {'id': 1055, 'image_count': 1562}, {'id': 1056, 'image_count': 1885}, {'id': 1057, 'image_count': 1}, {'id': 1058, 'image_count': 5}, {'id': 1059, 'image_count': 137}, {'id': 1060, 'image_count': 1094}, {'id': 1061, 'image_count': 134}, {'id': 1062, 'image_count': 29}, {'id': 1063, 'image_count': 22}, {'id': 1064, 'image_count': 522}, {'id': 1065, 'image_count': 50}, {'id': 1066, 'image_count': 68}, {'id': 1067, 'image_count': 16}, {'id': 1068, 'image_count': 40}, {'id': 1069, 'image_count': 35}, {'id': 1070, 'image_count': 135}, {'id': 1071, 'image_count': 1413}, {'id': 1072, 'image_count': 772}, {'id': 1073, 'image_count': 50}, {'id': 1074, 'image_count': 1015}, {'id': 1075, 'image_count': 1}, {'id': 1076, 'image_count': 65}, {'id': 1077, 'image_count': 1900}, {'id': 1078, 'image_count': 1302}, {'id': 1079, 'image_count': 1977}, {'id': 1080, 'image_count': 2}, {'id': 1081, 'image_count': 29}, {'id': 1082, 'image_count': 36}, {'id': 1083, 'image_count': 138}, {'id': 1084, 'image_count': 4}, {'id': 1085, 'image_count': 67}, {'id': 1086, 'image_count': 26}, {'id': 1087, 'image_count': 25}, {'id': 1088, 'image_count': 33}, {'id': 1089, 'image_count': 37}, {'id': 1090, 'image_count': 50}, {'id': 1091, 'image_count': 270}, {'id': 1092, 'image_count': 12}, {'id': 1093, 'image_count': 316}, {'id': 1094, 'image_count': 41}, {'id': 1095, 'image_count': 224}, {'id': 1096, 'image_count': 105}, {'id': 1097, 'image_count': 1925}, {'id': 1098, 'image_count': 1021}, {'id': 1099, 'image_count': 1213}, {'id': 1100, 'image_count': 172}, {'id': 1101, 'image_count': 28}, {'id': 1102, 'image_count': 745}, {'id': 1103, 
'image_count': 187}, {'id': 1104, 'image_count': 147}, {'id': 1105, 'image_count': 136}, {'id': 1106, 'image_count': 34}, {'id': 1107, 'image_count': 41}, {'id': 1108, 'image_count': 636}, {'id': 1109, 'image_count': 570}, {'id': 1110, 'image_count': 1149}, {'id': 1111, 'image_count': 61}, {'id': 1112, 'image_count': 1890}, {'id': 1113, 'image_count': 18}, {'id': 1114, 'image_count': 143}, {'id': 1115, 'image_count': 1517}, {'id': 1116, 'image_count': 7}, {'id': 1117, 'image_count': 943}, {'id': 1118, 'image_count': 6}, {'id': 1119, 'image_count': 1}, {'id': 1120, 'image_count': 11}, {'id': 1121, 'image_count': 101}, {'id': 1122, 'image_count': 1909}, {'id': 1123, 'image_count': 800}, {'id': 1124, 'image_count': 1}, {'id': 1125, 'image_count': 44}, {'id': 1126, 'image_count': 3}, {'id': 1127, 'image_count': 44}, {'id': 1128, 'image_count': 31}, {'id': 1129, 'image_count': 7}, {'id': 1130, 'image_count': 20}, {'id': 1131, 'image_count': 11}, {'id': 1132, 'image_count': 13}, {'id': 1133, 'image_count': 1924}, {'id': 1134, 'image_count': 113}, {'id': 1135, 'image_count': 2}, {'id': 1136, 'image_count': 139}, {'id': 1137, 'image_count': 12}, {'id': 1138, 'image_count': 37}, {'id': 1139, 'image_count': 1866}, {'id': 1140, 'image_count': 47}, {'id': 1141, 'image_count': 1468}, {'id': 1142, 'image_count': 729}, {'id': 1143, 'image_count': 24}, {'id': 1144, 'image_count': 1}, {'id': 1145, 'image_count': 10}, {'id': 1146, 'image_count': 3}, {'id': 1147, 'image_count': 14}, {'id': 1148, 'image_count': 4}, {'id': 1149, 'image_count': 29}, {'id': 1150, 'image_count': 4}, {'id': 1151, 'image_count': 70}, {'id': 1152, 'image_count': 46}, {'id': 1153, 'image_count': 14}, {'id': 1154, 'image_count': 48}, {'id': 1155, 'image_count': 1855}, {'id': 1156, 'image_count': 113}, {'id': 1157, 'image_count': 1}, {'id': 1158, 'image_count': 1}, {'id': 1159, 'image_count': 10}, {'id': 1160, 'image_count': 54}, {'id': 1161, 'image_count': 1923}, {'id': 1162, 'image_count': 630}, {'id': 1163, 'image_count': 31}, {'id': 1164, 'image_count': 69}, {'id': 1165, 'image_count': 7}, {'id': 1166, 'image_count': 11}, {'id': 1167, 'image_count': 1}, {'id': 1168, 'image_count': 30}, {'id': 1169, 'image_count': 50}, {'id': 1170, 'image_count': 45}, {'id': 1171, 'image_count': 28}, {'id': 1172, 'image_count': 114}, {'id': 1173, 'image_count': 193}, {'id': 1174, 'image_count': 21}, {'id': 1175, 'image_count': 91}, {'id': 1176, 'image_count': 31}, {'id': 1177, 'image_count': 1469}, {'id': 1178, 'image_count': 1924}, {'id': 1179, 'image_count': 87}, {'id': 1180, 'image_count': 77}, {'id': 1181, 'image_count': 11}, {'id': 1182, 'image_count': 47}, {'id': 1183, 'image_count': 21}, {'id': 1184, 'image_count': 47}, {'id': 1185, 'image_count': 70}, {'id': 1186, 'image_count': 1838}, {'id': 1187, 'image_count': 19}, {'id': 1188, 'image_count': 531}, {'id': 1189, 'image_count': 11}, {'id': 1190, 'image_count': 941}, {'id': 1191, 'image_count': 113}, {'id': 1192, 'image_count': 26}, {'id': 1193, 'image_count': 5}, {'id': 1194, 'image_count': 56}, {'id': 1195, 'image_count': 73}, {'id': 1196, 'image_count': 32}, {'id': 1197, 'image_count': 128}, {'id': 1198, 'image_count': 623}, {'id': 1199, 'image_count': 12}, {'id': 1200, 'image_count': 52}, {'id': 1201, 'image_count': 11}, {'id': 1202, 'image_count': 1674}, {'id': 1203, 'image_count': 81}] # noqa
# fmt: on
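The `image_count` fields above are what detectron2's federated-loss weighting consumes: `get_fed_loss_cls_weights` (defined in `detection_utils.py`, added later in this commit) sorts these entries by `id` and raises each count to a configurable power. A minimal sketch of that computation, using a hypothetical three-class subset instead of the full LVIS table:

import torch

# Hypothetical subset of the category metadata above; the real list has 1203 entries.
class_image_count = [
    {"id": 1, "image_count": 64},
    {"id": 2, "image_count": 1},
    {"id": 3, "image_count": 1902},
]

# Sort by class id, then raise each count to freq_weight_power (0.5 here),
# mirroring how get_fed_loss_cls_weights turns counts into per-class loss weights.
freq_weight_power = 0.5
counts = torch.tensor(
    [c["image_count"] for c in sorted(class_image_count, key=lambda x: x["id"])]
)
weights = counts.float() ** freq_weight_power  # approximately tensor([ 8.00,  1.00, 43.61])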
CatVTON/detectron2/data/datasets/pascal_voc.py
ADDED
@@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

import numpy as np
import os
import xml.etree.ElementTree as ET
from typing import List, Tuple, Union

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import BoxMode
from detectron2.utils.file_io import PathManager

__all__ = ["load_voc_instances", "register_pascal_voc"]


# fmt: off
CLASS_NAMES = (
    "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
    "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor"
)
# fmt: on


def load_voc_instances(dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]):
    """
    Load Pascal VOC detection annotations to Detectron2 format.

    Args:
        dirname: directory containing "Annotations", "ImageSets", "JPEGImages"
        split (str): one of "train", "test", "val", "trainval"
        class_names: list or tuple of class names
    """
    with PathManager.open(os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f:
        fileids = np.loadtxt(f, dtype=str)

    # Needs to read many small annotation files. Makes sense to get a local copy first.
    annotation_dirname = PathManager.get_local_path(os.path.join(dirname, "Annotations/"))
    dicts = []
    for fileid in fileids:
        anno_file = os.path.join(annotation_dirname, fileid + ".xml")
        jpeg_file = os.path.join(dirname, "JPEGImages", fileid + ".jpg")

        with PathManager.open(anno_file) as f:
            tree = ET.parse(f)

        r = {
            "file_name": jpeg_file,
            "image_id": fileid,
            "height": int(tree.findall("./size/height")[0].text),
            "width": int(tree.findall("./size/width")[0].text),
        }
        instances = []

        for obj in tree.findall("object"):
            cls = obj.find("name").text
            # We include "difficult" samples in training.
            # Based on limited experiments, they don't hurt accuracy.
            # difficult = int(obj.find("difficult").text)
            # if difficult == 1:
            #     continue
            bbox = obj.find("bndbox")
            bbox = [float(bbox.find(x).text) for x in ["xmin", "ymin", "xmax", "ymax"]]
            # Original annotations are integers in the range [1, W or H].
            # Assuming they mean 1-based pixel indices (inclusive),
            # a box with annotation (xmin=1, xmax=W) covers the whole image.
            # In coordinate space this is represented by (xmin=0, xmax=W).
            bbox[0] -= 1.0
            bbox[1] -= 1.0
            instances.append(
                {"category_id": class_names.index(cls), "bbox": bbox, "bbox_mode": BoxMode.XYXY_ABS}
            )
        r["annotations"] = instances
        dicts.append(r)
    return dicts


def register_pascal_voc(name, dirname, split, year, class_names=CLASS_NAMES):
    DatasetCatalog.register(name, lambda: load_voc_instances(dirname, split, class_names))
    MetadataCatalog.get(name).set(
        thing_classes=list(class_names), dirname=dirname, year=year, split=split
    )
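A minimal usage sketch for the registration helper above; the dataset name and the `datasets/VOC2012` root path are placeholder assumptions, not part of this file:

from detectron2.data import DatasetCatalog
from detectron2.data.datasets import register_pascal_voc

# Hypothetical dataset name and root directory; adjust to your local layout.
register_pascal_voc("my_voc_2012_trainval", "datasets/VOC2012", "trainval", 2012)

# Registration is lazy: the VOC XML files are parsed only when the dataset is fetched.
dicts = DatasetCatalog.get("my_voc_2012_trainval")
print(len(dicts), dicts[0]["file_name"], len(dicts[0]["annotations"]))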
CatVTON/detectron2/data/datasets/register_coco.py
ADDED
@@ -0,0 +1,3 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from .coco import register_coco_instances  # noqa
from .coco_panoptic import register_coco_panoptic_separated  # noqa
CatVTON/detectron2/data/detection_utils.py
ADDED
@@ -0,0 +1,661 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

"""
Common data processing utilities that are used in a
typical object detection data pipeline.
"""
import logging
import numpy as np
from typing import List, Union
import pycocotools.mask as mask_util
import torch
from PIL import Image

from detectron2.structures import (
    BitMasks,
    Boxes,
    BoxMode,
    Instances,
    Keypoints,
    PolygonMasks,
    RotatedBoxes,
    polygons_to_bitmask,
)
from detectron2.utils.file_io import PathManager

from . import transforms as T
from .catalog import MetadataCatalog

__all__ = [
    "SizeMismatchError",
    "convert_image_to_rgb",
    "check_image_size",
    "transform_proposals",
    "transform_instance_annotations",
    "annotations_to_instances",
    "annotations_to_instances_rotated",
    "build_augmentation",
    "build_transform_gen",
    "create_keypoint_hflip_indices",
    "filter_empty_instances",
    "read_image",
]


class SizeMismatchError(ValueError):
    """
    When the loaded image has a different width/height than the annotation.
    """


# https://en.wikipedia.org/wiki/YUV#SDTV_with_BT.601
_M_RGB2YUV = [[0.299, 0.587, 0.114], [-0.14713, -0.28886, 0.436], [0.615, -0.51499, -0.10001]]
_M_YUV2RGB = [[1.0, 0.0, 1.13983], [1.0, -0.39465, -0.58060], [1.0, 2.03211, 0.0]]

# https://www.exiv2.org/tags.html
_EXIF_ORIENT = 274  # exif 'Orientation' tag


def convert_PIL_to_numpy(image, format):
    """
    Convert PIL image to numpy array of target format.

    Args:
        image (PIL.Image): a PIL image
        format (str): the format of output image

    Returns:
        (np.ndarray): also see `read_image`
    """
    if format is not None:
        # PIL only supports RGB, so convert to RGB and flip channels over below
        conversion_format = format
        if format in ["BGR", "YUV-BT.601"]:
            conversion_format = "RGB"
        image = image.convert(conversion_format)
    image = np.asarray(image)
    # PIL squeezes out the channel dimension for "L", so make it HWC
    if format == "L":
        image = np.expand_dims(image, -1)

    # handle formats not supported by PIL
    elif format == "BGR":
        # flip channels if needed
        image = image[:, :, ::-1]
    elif format == "YUV-BT.601":
        image = image / 255.0
        image = np.dot(image, np.array(_M_RGB2YUV).T)

    return image


def convert_image_to_rgb(image, format):
    """
    Convert an image from given format to RGB.

    Args:
        image (np.ndarray or Tensor): an HWC image
        format (str): the format of input image, also see `read_image`

    Returns:
        (np.ndarray): (H,W,3) RGB image in 0-255 range, can be either float or uint8
    """
    if isinstance(image, torch.Tensor):
        image = image.cpu().numpy()
    if format == "BGR":
        image = image[:, :, [2, 1, 0]]
    elif format == "YUV-BT.601":
        image = np.dot(image, np.array(_M_YUV2RGB).T)
        image = image * 255.0
    else:
        if format == "L":
            image = image[:, :, 0]
        image = image.astype(np.uint8)
        image = np.asarray(Image.fromarray(image, mode=format).convert("RGB"))
    return image


def _apply_exif_orientation(image):
    """
    Applies the exif orientation correctly.

    This code exists per the bug
    https://github.com/python-pillow/Pillow/issues/3973
    in the function `ImageOps.exif_transpose`. The Pillow source raises errors with
    various methods, especially `tobytes`.

    Function based on:
    https://github.com/wkentaro/labelme/blob/v4.5.4/labelme/utils/image.py#L59
    https://github.com/python-pillow/Pillow/blob/7.1.2/src/PIL/ImageOps.py#L527

    Args:
        image (PIL.Image): a PIL image

    Returns:
        (PIL.Image): the PIL image with exif orientation applied, if applicable
    """
    if not hasattr(image, "getexif"):
        return image

    try:
        exif = image.getexif()
    except Exception:  # https://github.com/facebookresearch/detectron2/issues/1885
        exif = None

    if exif is None:
        return image

    orientation = exif.get(_EXIF_ORIENT)

    method = {
        2: Image.FLIP_LEFT_RIGHT,
        3: Image.ROTATE_180,
        4: Image.FLIP_TOP_BOTTOM,
        5: Image.TRANSPOSE,
        6: Image.ROTATE_270,
        7: Image.TRANSVERSE,
        8: Image.ROTATE_90,
    }.get(orientation)

    if method is not None:
        return image.transpose(method)
    return image


def read_image(file_name, format=None):
    """
    Read an image into the given format.
    Will apply rotation and flipping if the image has such exif information.

    Args:
        file_name (str): image file path
        format (str): one of the supported image modes in PIL, or "BGR" or "YUV-BT.601".

    Returns:
        image (np.ndarray):
            an HWC image in the given format, which is 0-255, uint8 for
            supported image modes in PIL or "BGR"; float (0-1 for Y) for YUV-BT.601.
    """
    with PathManager.open(file_name, "rb") as f:
        image = Image.open(f)

        # work around this bug: https://github.com/python-pillow/Pillow/issues/3973
        image = _apply_exif_orientation(image)
        return convert_PIL_to_numpy(image, format)


def check_image_size(dataset_dict, image):
    """
    Raise an error if the image does not match the size specified in the dict.
    """
    if "width" in dataset_dict or "height" in dataset_dict:
        image_wh = (image.shape[1], image.shape[0])
        expected_wh = (dataset_dict["width"], dataset_dict["height"])
        if not image_wh == expected_wh:
            raise SizeMismatchError(
                "Mismatched image shape{}, got {}, expect {}.".format(
                    (
                        " for image " + dataset_dict["file_name"]
                        if "file_name" in dataset_dict
                        else ""
                    ),
                    image_wh,
                    expected_wh,
                )
                + " Please check the width/height in your annotation."
            )

    # To ensure bbox always remap to original image size
    if "width" not in dataset_dict:
        dataset_dict["width"] = image.shape[1]
    if "height" not in dataset_dict:
        dataset_dict["height"] = image.shape[0]


def transform_proposals(dataset_dict, image_shape, transforms, *, proposal_topk, min_box_size=0):
    """
    Apply transformations to the proposals in dataset_dict, if any.

    Args:
        dataset_dict (dict): a dict read from the dataset, possibly
            contains fields "proposal_boxes", "proposal_objectness_logits", "proposal_bbox_mode"
        image_shape (tuple): height, width
        transforms (TransformList):
        proposal_topk (int): only keep top-K scoring proposals
        min_box_size (int): proposals with either side smaller than this
            threshold are removed

    The input dict is modified in-place, with the above-mentioned keys removed. A new
    key "proposals" will be added. Its value is an `Instances`
    object which contains the transformed proposals in its field
    "proposal_boxes" and "objectness_logits".
    """
    if "proposal_boxes" in dataset_dict:
        # Transform proposal boxes
        boxes = transforms.apply_box(
            BoxMode.convert(
                dataset_dict.pop("proposal_boxes"),
                dataset_dict.pop("proposal_bbox_mode"),
                BoxMode.XYXY_ABS,
            )
        )
        boxes = Boxes(boxes)
        objectness_logits = torch.as_tensor(
            dataset_dict.pop("proposal_objectness_logits").astype("float32")
        )

        boxes.clip(image_shape)
        keep = boxes.nonempty(threshold=min_box_size)
        boxes = boxes[keep]
        objectness_logits = objectness_logits[keep]

        proposals = Instances(image_shape)
        proposals.proposal_boxes = boxes[:proposal_topk]
        proposals.objectness_logits = objectness_logits[:proposal_topk]
        dataset_dict["proposals"] = proposals


def get_bbox(annotation):
    """
    Get the instance bounding box from an annotation dict.

    Args:
        annotation (dict): dict of instance annotations for a single instance.
    Returns:
        bbox (ndarray): x1, y1, x2, y2 coordinates
    """
    # bbox is 1d (per-instance bounding box)
    bbox = BoxMode.convert(annotation["bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS)
    return bbox


def transform_instance_annotations(
    annotation, transforms, image_size, *, keypoint_hflip_indices=None
):
    """
    Apply transforms to box, segmentation and keypoints annotations of a single instance.

    It will use `transforms.apply_box` for the box, and
    `transforms.apply_coords` for segmentation polygons & keypoints.
    If you need anything more specially designed for each data structure,
    you'll need to implement your own version of this function or the transforms.

    Args:
        annotation (dict): dict of instance annotations for a single instance.
            It will be modified in-place.
        transforms (TransformList or list[Transform]):
        image_size (tuple): the height, width of the transformed image
        keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`.

    Returns:
        dict:
            the same input dict with fields "bbox", "segmentation", "keypoints"
            transformed according to `transforms`.
            The "bbox_mode" field will be set to XYXY_ABS.
    """
    if isinstance(transforms, (tuple, list)):
        transforms = T.TransformList(transforms)
    # bbox is 1d (per-instance bounding box)
    bbox = BoxMode.convert(annotation["bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS)
    # clip transformed bbox to image size
    bbox = transforms.apply_box(np.array([bbox]))[0].clip(min=0)
    annotation["bbox"] = np.minimum(bbox, list(image_size + image_size)[::-1])
    annotation["bbox_mode"] = BoxMode.XYXY_ABS

    if "segmentation" in annotation:
        # each instance contains 1 or more polygons
        segm = annotation["segmentation"]
        if isinstance(segm, list):
            # polygons
            polygons = [np.asarray(p).reshape(-1, 2) for p in segm]
            annotation["segmentation"] = [
                p.reshape(-1) for p in transforms.apply_polygons(polygons)
            ]
        elif isinstance(segm, dict):
            # RLE
            mask = mask_util.decode(segm)
            mask = transforms.apply_segmentation(mask)
            assert tuple(mask.shape[:2]) == image_size
            annotation["segmentation"] = mask
        else:
            raise ValueError(
                "Cannot transform segmentation of type '{}'!"
                "Supported types are: polygons as list[list[float] or ndarray],"
                " COCO-style RLE as a dict.".format(type(segm))
            )

    if "keypoints" in annotation:
        keypoints = transform_keypoint_annotations(
            annotation["keypoints"], transforms, image_size, keypoint_hflip_indices
        )
        annotation["keypoints"] = keypoints

    return annotation


def transform_keypoint_annotations(keypoints, transforms, image_size, keypoint_hflip_indices=None):
    """
    Transform keypoint annotations of an image.
    If a keypoint is transformed out of image boundary, it will be marked "unlabeled" (visibility=0)

    Args:
        keypoints (list[float]): Nx3 float in Detectron2's Dataset format.
            Each point is represented by (x, y, visibility).
        transforms (TransformList):
        image_size (tuple): the height, width of the transformed image
        keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`.
            When `transforms` includes horizontal flip, will use the index
            mapping to flip keypoints.
    """
    # (N*3,) -> (N, 3)
    keypoints = np.asarray(keypoints, dtype="float64").reshape(-1, 3)
    keypoints_xy = transforms.apply_coords(keypoints[:, :2])

    # Set all out-of-boundary points to "unlabeled"
    inside = (keypoints_xy >= np.array([0, 0])) & (keypoints_xy <= np.array(image_size[::-1]))
    inside = inside.all(axis=1)
    keypoints[:, :2] = keypoints_xy
    keypoints[:, 2][~inside] = 0

    # This assumes that HorizFlipTransform is the only one that does flip
    do_hflip = sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1

    # Alternative way: check if probe points were horizontally flipped.
    # probe = np.asarray([[0.0, 0.0], [image_width, 0.0]])
    # probe_aug = transforms.apply_coords(probe.copy())
    # do_hflip = np.sign(probe[1][0] - probe[0][0]) != np.sign(probe_aug[1][0] - probe_aug[0][0])  # noqa

    # If flipped, swap each keypoint with its opposite-handed equivalent
    if do_hflip:
        if keypoint_hflip_indices is None:
            raise ValueError("Cannot flip keypoints without providing flip indices!")
        if len(keypoints) != len(keypoint_hflip_indices):
            raise ValueError(
                "Keypoint data has {} points, but metadata "
                "contains {} points!".format(len(keypoints), len(keypoint_hflip_indices))
            )
        keypoints = keypoints[np.asarray(keypoint_hflip_indices, dtype=np.int32), :]

    # Maintain COCO convention that if visibility == 0 (unlabeled), then x, y = 0
    keypoints[keypoints[:, 2] == 0] = 0
    return keypoints


def annotations_to_instances(annos, image_size, mask_format="polygon"):
    """
    Create an :class:`Instances` object used by the models,
    from instance annotations in the dataset dict.

    Args:
        annos (list[dict]): a list of instance annotations in one image, each
            element for one instance.
        image_size (tuple): height, width

    Returns:
        Instances:
            It will contain fields "gt_boxes", "gt_classes",
            "gt_masks", "gt_keypoints", if they can be obtained from `annos`.
            This is the format that builtin models expect.
    """
    boxes = (
        np.stack(
            [BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS) for obj in annos]
        )
        if len(annos)
        else np.zeros((0, 4))
    )
    target = Instances(image_size)
    target.gt_boxes = Boxes(boxes)

    classes = [int(obj["category_id"]) for obj in annos]
    classes = torch.tensor(classes, dtype=torch.int64)
    target.gt_classes = classes

    if len(annos) and "segmentation" in annos[0]:
        segms = [obj["segmentation"] for obj in annos]
        if mask_format == "polygon":
            try:
                masks = PolygonMasks(segms)
            except ValueError as e:
                raise ValueError(
                    "Failed to use mask_format=='polygon' from the given annotations!"
                ) from e
        else:
            assert mask_format == "bitmask", mask_format
            masks = []
            for segm in segms:
                if isinstance(segm, list):
                    # polygon
                    masks.append(polygons_to_bitmask(segm, *image_size))
                elif isinstance(segm, dict):
                    # COCO RLE
                    masks.append(mask_util.decode(segm))
                elif isinstance(segm, np.ndarray):
                    assert segm.ndim == 2, "Expect segmentation of 2 dimensions, got {}.".format(
                        segm.ndim
                    )
                    # mask array
                    masks.append(segm)
                else:
                    raise ValueError(
                        "Cannot convert segmentation of type '{}' to BitMasks!"
                        "Supported types are: polygons as list[list[float] or ndarray],"
                        " COCO-style RLE as a dict, or a binary segmentation mask "
                        " in a 2D numpy array of shape HxW.".format(type(segm))
                    )
            # torch.from_numpy does not support array with negative stride.
            masks = BitMasks(
                torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
            )
        target.gt_masks = masks

    if len(annos) and "keypoints" in annos[0]:
        kpts = [obj.get("keypoints", []) for obj in annos]
        target.gt_keypoints = Keypoints(kpts)

    return target


def annotations_to_instances_rotated(annos, image_size):
    """
    Create an :class:`Instances` object used by the models,
    from instance annotations in the dataset dict.
    Compared to `annotations_to_instances`, this function is for rotated boxes only.

    Args:
        annos (list[dict]): a list of instance annotations in one image, each
            element for one instance.
        image_size (tuple): height, width

    Returns:
        Instances:
            Containing fields "gt_boxes", "gt_classes",
            if they can be obtained from `annos`.
            This is the format that builtin models expect.
    """
    boxes = [obj["bbox"] for obj in annos]
    target = Instances(image_size)
    boxes = target.gt_boxes = RotatedBoxes(boxes)
    boxes.clip(image_size)

    classes = [obj["category_id"] for obj in annos]
    classes = torch.tensor(classes, dtype=torch.int64)
    target.gt_classes = classes

    return target


def filter_empty_instances(
    instances, by_box=True, by_mask=True, box_threshold=1e-5, return_mask=False
):
    """
    Filter out empty instances in an `Instances` object.

    Args:
        instances (Instances):
        by_box (bool): whether to filter out instances with empty boxes
        by_mask (bool): whether to filter out instances with empty masks
        box_threshold (float): minimum width and height to be considered non-empty
        return_mask (bool): whether to return boolean mask of filtered instances

    Returns:
        Instances: the filtered instances.
        tensor[bool], optional: boolean mask of filtered instances
    """
    assert by_box or by_mask
    r = []
    if by_box:
        r.append(instances.gt_boxes.nonempty(threshold=box_threshold))
    if instances.has("gt_masks") and by_mask:
        r.append(instances.gt_masks.nonempty())

    # TODO: can also filter visible keypoints

    if not r:
        return instances
    m = r[0]
    for x in r[1:]:
        m = m & x
    if return_mask:
        return instances[m], m
    return instances[m]


def create_keypoint_hflip_indices(dataset_names: Union[str, List[str]]) -> List[int]:
    """
    Args:
        dataset_names: list of dataset names

    Returns:
        list[int]: a list of size=#keypoints, storing the
        horizontally-flipped keypoint indices.
    """
    if isinstance(dataset_names, str):
        dataset_names = [dataset_names]

    check_metadata_consistency("keypoint_names", dataset_names)
    check_metadata_consistency("keypoint_flip_map", dataset_names)

    meta = MetadataCatalog.get(dataset_names[0])
    names = meta.keypoint_names
    # TODO flip -> hflip
    flip_map = dict(meta.keypoint_flip_map)
    flip_map.update({v: k for k, v in flip_map.items()})
    flipped_names = [i if i not in flip_map else flip_map[i] for i in names]
    flip_indices = [names.index(i) for i in flipped_names]
    return flip_indices


def get_fed_loss_cls_weights(dataset_names: Union[str, List[str]], freq_weight_power=1.0):
    """
    Get frequency weight for each class sorted by class id.
    We now calculate frequency weight using image_count to the power freq_weight_power.

    Args:
        dataset_names: list of dataset names
        freq_weight_power: power value
    """
    if isinstance(dataset_names, str):
        dataset_names = [dataset_names]

    check_metadata_consistency("class_image_count", dataset_names)

    meta = MetadataCatalog.get(dataset_names[0])
    class_freq_meta = meta.class_image_count
    class_freq = torch.tensor(
        [c["image_count"] for c in sorted(class_freq_meta, key=lambda x: x["id"])]
    )
    class_freq_weight = class_freq.float() ** freq_weight_power
    return class_freq_weight


def gen_crop_transform_with_instance(crop_size, image_size, instance):
    """
    Generate a CropTransform so that the cropping region contains
    the center of the given instance.

    Args:
        crop_size (tuple): h, w in pixels
        image_size (tuple): h, w
        instance (dict): an annotation dict of one instance, in Detectron2's
            dataset format.
    """
    crop_size = np.asarray(crop_size, dtype=np.int32)
    bbox = BoxMode.convert(instance["bbox"], instance["bbox_mode"], BoxMode.XYXY_ABS)
    center_yx = (bbox[1] + bbox[3]) * 0.5, (bbox[0] + bbox[2]) * 0.5
    assert (
        image_size[0] >= center_yx[0] and image_size[1] >= center_yx[1]
    ), "The annotation bounding box is outside of the image!"
    assert (
        image_size[0] >= crop_size[0] and image_size[1] >= crop_size[1]
    ), "Crop size is larger than image size!"

    min_yx = np.maximum(np.floor(center_yx).astype(np.int32) - crop_size, 0)
    max_yx = np.maximum(np.asarray(image_size, dtype=np.int32) - crop_size, 0)
    max_yx = np.minimum(max_yx, np.ceil(center_yx).astype(np.int32))

    y0 = np.random.randint(min_yx[0], max_yx[0] + 1)
    x0 = np.random.randint(min_yx[1], max_yx[1] + 1)
    return T.CropTransform(x0, y0, crop_size[1], crop_size[0])


def check_metadata_consistency(key, dataset_names):
    """
    Check that the datasets have consistent metadata.

    Args:
        key (str): a metadata key
        dataset_names (list[str]): a list of dataset names

    Raises:
        AttributeError: if the key does not exist in the metadata
        ValueError: if the given datasets do not have the same metadata values defined by key
    """
    if len(dataset_names) == 0:
        return
    logger = logging.getLogger(__name__)
    entries_per_dataset = [getattr(MetadataCatalog.get(d), key) for d in dataset_names]
    for idx, entry in enumerate(entries_per_dataset):
        if entry != entries_per_dataset[0]:
            logger.error(
                "Metadata '{}' for dataset '{}' is '{}'".format(key, dataset_names[idx], str(entry))
            )
            logger.error(
                "Metadata '{}' for dataset '{}' is '{}'".format(
                    key, dataset_names[0], str(entries_per_dataset[0])
                )
            )
            raise ValueError("Datasets have different metadata '{}'!".format(key))


def build_augmentation(cfg, is_train):
    """
    Create a list of default :class:`Augmentation` from config.
    Now it includes resizing and flipping.

    Returns:
        list[Augmentation]
    """
    if is_train:
        min_size = cfg.INPUT.MIN_SIZE_TRAIN
        max_size = cfg.INPUT.MAX_SIZE_TRAIN
        sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
    else:
        min_size = cfg.INPUT.MIN_SIZE_TEST
        max_size = cfg.INPUT.MAX_SIZE_TEST
        sample_style = "choice"
    augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)]
    if is_train and cfg.INPUT.RANDOM_FLIP != "none":
        augmentation.append(
            T.RandomFlip(
                horizontal=cfg.INPUT.RANDOM_FLIP == "horizontal",
                vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
            )
        )
    return augmentation


build_transform_gen = build_augmentation
"""
Alias for backward-compatibility.
"""
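Taken together, these utilities implement the usual load, augment, and `Instances`-building pipeline. The sketch below paraphrases how detectron2's `DatasetMapper` chains them; the function name is hypothetical and error handling is omitted:

import copy

from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T


def sketch_mapper(dataset_dict, augmentations):
    # Hypothetical mapper showing the intended call order of the utilities above.
    dataset_dict = copy.deepcopy(dataset_dict)
    image = utils.read_image(dataset_dict["file_name"], format="BGR")
    utils.check_image_size(dataset_dict, image)

    # Apply the augmentations and capture the exact geometric transforms used,
    # so the same geometry can be replayed on the annotations.
    aug_input = T.AugInput(image)
    transforms = T.AugmentationList(augmentations)(aug_input)
    image = aug_input.image

    annos = [
        utils.transform_instance_annotations(obj, transforms, image.shape[:2])
        for obj in dataset_dict.pop("annotations", [])
    ]
    instances = utils.annotations_to_instances(annos, image.shape[:2])
    dataset_dict["instances"] = utils.filter_empty_instances(instances)
    return dataset_dict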
CatVTON/detectron2/data/samplers/__init__.py
ADDED
@@ -0,0 +1,17 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from .distributed_sampler import (
    InferenceSampler,
    RandomSubsetTrainingSampler,
    RepeatFactorTrainingSampler,
    TrainingSampler,
)

from .grouped_batch_sampler import GroupedBatchSampler

__all__ = [
    "GroupedBatchSampler",
    "TrainingSampler",
    "RandomSubsetTrainingSampler",
    "InferenceSampler",
    "RepeatFactorTrainingSampler",
]
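A minimal sketch of how these samplers plug into a plain PyTorch `DataLoader`; the toy list dataset and seed are assumptions for illustration. `TrainingSampler` (defined in `distributed_sampler.py` below) yields an infinite index stream, so the caller decides when to stop iterating:

import torch.utils.data as torchdata

from detectron2.data.samplers import TrainingSampler

# Hypothetical map-style dataset of 1000 items.
dataset = list(range(1000))

# The sampler shards its infinite index stream across distributed ranks,
# so this loader never raises StopIteration on its own.
sampler = TrainingSampler(len(dataset), shuffle=True, seed=42)
loader = torchdata.DataLoader(dataset, batch_size=8, sampler=sampler)

it = iter(loader)
first_batch = next(it)  # a tensor of 8 dataset items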
CatVTON/detectron2/data/samplers/distributed_sampler.py
ADDED
@@ -0,0 +1,287 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import itertools
import logging
import math
from collections import defaultdict
from typing import Optional
import torch
from torch.utils.data.sampler import Sampler

from detectron2.utils import comm

logger = logging.getLogger(__name__)


class TrainingSampler(Sampler):
    """
    In training, we only care about the "infinite stream" of training data.
    So this sampler produces an infinite stream of indices and
    all workers cooperate to correctly shuffle the indices and sample different indices.

    The samplers in each worker effectively produce `indices[worker_id::num_workers]`
    where `indices` is an infinite stream of indices consisting of
    `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
    or `range(size) + range(size) + ...` (if shuffle is False)

    Note that this sampler does not shard based on pytorch DataLoader worker id.
    A sampler passed to pytorch DataLoader is used only with a map-style dataset
    and will not be executed inside workers.
    But if this sampler is used in a way that it gets executed inside a dataloader
    worker, then extra work needs to be done to shard its outputs based on worker id.
    This is required so that workers don't produce identical data.
    :class:`ToIterableDataset` implements this logic.
    This note is true for all samplers in detectron2.
    """

    def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
            shuffle (bool): whether to shuffle the indices or not
            seed (int): the initial seed of the shuffle. Must be the same
                across all workers. If None, will use a random seed shared
                among workers (require synchronization among all workers).
        """
        if not isinstance(size, int):
            raise TypeError(f"TrainingSampler(size=) expects an int. Got type {type(size)}.")
        if size <= 0:
            raise ValueError(f"TrainingSampler(size=) expects a positive int. Got {size}.")
        self._size = size
        self._shuffle = shuffle
        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

    def __iter__(self):
        start = self._rank
        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        g = torch.Generator()
        if self._seed is not None:
            g.manual_seed(self._seed)
        while True:
            if self._shuffle:
                yield from torch.randperm(self._size, generator=g).tolist()
            else:
                yield from torch.arange(self._size).tolist()


class RandomSubsetTrainingSampler(TrainingSampler):
    """
    Similar to TrainingSampler, but only samples a random subset of indices.
    This is useful when you want to estimate the accuracy vs data-number curves by
    training the model with different subset_ratio.
    """

    def __init__(
        self,
        size: int,
        subset_ratio: float,
        shuffle: bool = True,
        seed_shuffle: Optional[int] = None,
        seed_subset: Optional[int] = None,
    ):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
            subset_ratio (float): the ratio of subset data to sample from the underlying dataset
            shuffle (bool): whether to shuffle the indices or not
            seed_shuffle (int): the initial seed of the shuffle. Must be the same
                across all workers. If None, will use a random seed shared
                among workers (require synchronization among all workers).
            seed_subset (int): the seed to randomize the subset to be sampled.
                Must be the same across all workers. If None, will use a random seed shared
                among workers (require synchronization among all workers).
        """
        super().__init__(size=size, shuffle=shuffle, seed=seed_shuffle)

        assert 0.0 < subset_ratio <= 1.0
|
| 103 |
+
self._size_subset = int(size * subset_ratio)
|
| 104 |
+
assert self._size_subset > 0
|
| 105 |
+
if seed_subset is None:
|
| 106 |
+
seed_subset = comm.shared_random_seed()
|
| 107 |
+
self._seed_subset = int(seed_subset)
|
| 108 |
+
|
| 109 |
+
# randomly generate the subset indexes to be sampled from
|
| 110 |
+
g = torch.Generator()
|
| 111 |
+
g.manual_seed(self._seed_subset)
|
| 112 |
+
indexes_randperm = torch.randperm(self._size, generator=g)
|
| 113 |
+
self._indexes_subset = indexes_randperm[: self._size_subset]
|
| 114 |
+
|
| 115 |
+
logger.info("Using RandomSubsetTrainingSampler......")
|
| 116 |
+
logger.info(f"Randomly sample {self._size_subset} data from the original {self._size} data")
|
| 117 |
+
|
| 118 |
+
def _infinite_indices(self):
|
| 119 |
+
g = torch.Generator()
|
| 120 |
+
g.manual_seed(self._seed) # self._seed equals seed_shuffle from __init__()
|
| 121 |
+
while True:
|
| 122 |
+
if self._shuffle:
|
| 123 |
+
# generate a random permutation to shuffle self._indexes_subset
|
| 124 |
+
randperm = torch.randperm(self._size_subset, generator=g)
|
| 125 |
+
yield from self._indexes_subset[randperm].tolist()
|
| 126 |
+
else:
|
| 127 |
+
yield from self._indexes_subset.tolist()
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class RepeatFactorTrainingSampler(Sampler):
|
| 131 |
+
"""
|
| 132 |
+
Similar to TrainingSampler, but a sample may appear more times than others based
|
| 133 |
+
on its "repeat factor". This is suitable for training on class imbalanced datasets like LVIS.
|
| 134 |
+
"""
|
| 135 |
+
|
| 136 |
+
def __init__(self, repeat_factors, *, shuffle=True, seed=None):
|
| 137 |
+
"""
|
| 138 |
+
Args:
|
| 139 |
+
repeat_factors (Tensor): a float vector, the repeat factor for each indice. When it's
|
| 140 |
+
full of ones, it is equivalent to ``TrainingSampler(len(repeat_factors), ...)``.
|
| 141 |
+
shuffle (bool): whether to shuffle the indices or not
|
| 142 |
+
seed (int): the initial seed of the shuffle. Must be the same
|
| 143 |
+
across all workers. If None, will use a random seed shared
|
| 144 |
+
among workers (require synchronization among all workers).
|
| 145 |
+
"""
|
| 146 |
+
self._shuffle = shuffle
|
| 147 |
+
if seed is None:
|
| 148 |
+
seed = comm.shared_random_seed()
|
| 149 |
+
self._seed = int(seed)
|
| 150 |
+
|
| 151 |
+
self._rank = comm.get_rank()
|
| 152 |
+
self._world_size = comm.get_world_size()
|
| 153 |
+
|
| 154 |
+
# Split into whole number (_int_part) and fractional (_frac_part) parts.
|
| 155 |
+
self._int_part = torch.trunc(repeat_factors)
|
| 156 |
+
self._frac_part = repeat_factors - self._int_part
|
| 157 |
+
|
| 158 |
+
@staticmethod
|
| 159 |
+
def repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh, sqrt=True):
|
| 160 |
+
"""
|
| 161 |
+
Compute (fractional) per-image repeat factors based on category frequency.
|
| 162 |
+
The repeat factor for an image is a function of the frequency of the rarest
|
| 163 |
+
category labeled in that image. The "frequency of category c" in [0, 1] is defined
|
| 164 |
+
as the fraction of images in the training set (without repeats) in which category c
|
| 165 |
+
appears.
|
| 166 |
+
See :paper:`lvis` (>= v2) Appendix B.2.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
dataset_dicts (list[dict]): annotations in Detectron2 dataset format.
|
| 170 |
+
repeat_thresh (float): frequency threshold below which data is repeated.
|
| 171 |
+
If the frequency is half of `repeat_thresh`, the image will be
|
| 172 |
+
repeated twice.
|
| 173 |
+
sqrt (bool): if True, apply :func:`math.sqrt` to the repeat factor.
|
| 174 |
+
|
| 175 |
+
Returns:
|
| 176 |
+
torch.Tensor:
|
| 177 |
+
the i-th element is the repeat factor for the dataset image at index i.
|
| 178 |
+
"""
|
| 179 |
+
# 1. For each category c, compute the fraction of images that contain it: f(c)
|
| 180 |
+
category_freq = defaultdict(int)
|
| 181 |
+
for dataset_dict in dataset_dicts: # For each image (without repeats)
|
| 182 |
+
cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
|
| 183 |
+
for cat_id in cat_ids:
|
| 184 |
+
category_freq[cat_id] += 1
|
| 185 |
+
num_images = len(dataset_dicts)
|
| 186 |
+
for k, v in category_freq.items():
|
| 187 |
+
category_freq[k] = v / num_images
|
| 188 |
+
|
| 189 |
+
# 2. For each category c, compute the category-level repeat factor:
|
| 190 |
+
# r(c) = max(1, sqrt(t / f(c)))
|
| 191 |
+
category_rep = {
|
| 192 |
+
cat_id: max(
|
| 193 |
+
1.0,
|
| 194 |
+
(math.sqrt(repeat_thresh / cat_freq) if sqrt else (repeat_thresh / cat_freq)),
|
| 195 |
+
)
|
| 196 |
+
for cat_id, cat_freq in category_freq.items()
|
| 197 |
+
}
|
| 198 |
+
for cat_id in sorted(category_rep.keys()):
|
| 199 |
+
logger.info(
|
| 200 |
+
f"Cat ID {cat_id}: freq={category_freq[cat_id]:.2f}, rep={category_rep[cat_id]:.2f}"
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
# 3. For each image I, compute the image-level repeat factor:
|
| 204 |
+
# r(I) = max_{c in I} r(c)
|
| 205 |
+
rep_factors = []
|
| 206 |
+
for dataset_dict in dataset_dicts:
|
| 207 |
+
cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
|
| 208 |
+
rep_factor = max({category_rep[cat_id] for cat_id in cat_ids}, default=1.0)
|
| 209 |
+
rep_factors.append(rep_factor)
|
| 210 |
+
|
| 211 |
+
return torch.tensor(rep_factors, dtype=torch.float32)
|
| 212 |
+
|
| 213 |
+
def _get_epoch_indices(self, generator):
|
| 214 |
+
"""
|
| 215 |
+
Create a list of dataset indices (with repeats) to use for one epoch.
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
generator (torch.Generator): pseudo random number generator used for
|
| 219 |
+
stochastic rounding.
|
| 220 |
+
|
| 221 |
+
Returns:
|
| 222 |
+
torch.Tensor: list of dataset indices to use in one epoch. Each index
|
| 223 |
+
is repeated based on its calculated repeat factor.
|
| 224 |
+
"""
|
| 225 |
+
# Since repeat factors are fractional, we use stochastic rounding so
|
| 226 |
+
# that the target repeat factor is achieved in expectation over the
|
| 227 |
+
# course of training
|
| 228 |
+
rands = torch.rand(len(self._frac_part), generator=generator)
|
| 229 |
+
rep_factors = self._int_part + (rands < self._frac_part).float()
|
| 230 |
+
# Construct a list of indices in which we repeat images as specified
|
| 231 |
+
indices = []
|
| 232 |
+
for dataset_index, rep_factor in enumerate(rep_factors):
|
| 233 |
+
indices.extend([dataset_index] * int(rep_factor.item()))
|
| 234 |
+
return torch.tensor(indices, dtype=torch.int64)
|
| 235 |
+
|
| 236 |
+
def __iter__(self):
|
| 237 |
+
start = self._rank
|
| 238 |
+
yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)
|
| 239 |
+
|
| 240 |
+
def _infinite_indices(self):
|
| 241 |
+
g = torch.Generator()
|
| 242 |
+
g.manual_seed(self._seed)
|
| 243 |
+
while True:
|
| 244 |
+
# Sample indices with repeats determined by stochastic rounding; each
|
| 245 |
+
# "epoch" may have a slightly different size due to the rounding.
|
| 246 |
+
indices = self._get_epoch_indices(g)
|
| 247 |
+
if self._shuffle:
|
| 248 |
+
randperm = torch.randperm(len(indices), generator=g)
|
| 249 |
+
yield from indices[randperm].tolist()
|
| 250 |
+
else:
|
| 251 |
+
yield from indices.tolist()
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class InferenceSampler(Sampler):
|
| 255 |
+
"""
|
| 256 |
+
Produce indices for inference across all workers.
|
| 257 |
+
Inference needs to run on the __exact__ set of samples,
|
| 258 |
+
therefore when the total number of samples is not divisible by the number of workers,
|
| 259 |
+
this sampler produces different number of samples on different workers.
|
| 260 |
+
"""
|
| 261 |
+
|
| 262 |
+
def __init__(self, size: int):
|
| 263 |
+
"""
|
| 264 |
+
Args:
|
| 265 |
+
size (int): the total number of data of the underlying dataset to sample from
|
| 266 |
+
"""
|
| 267 |
+
self._size = size
|
| 268 |
+
assert size > 0
|
| 269 |
+
self._rank = comm.get_rank()
|
| 270 |
+
self._world_size = comm.get_world_size()
|
| 271 |
+
self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
|
| 272 |
+
|
| 273 |
+
@staticmethod
|
| 274 |
+
def _get_local_indices(total_size, world_size, rank):
|
| 275 |
+
shard_size = total_size // world_size
|
| 276 |
+
left = total_size % world_size
|
| 277 |
+
shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
|
| 278 |
+
|
| 279 |
+
begin = sum(shard_sizes[:rank])
|
| 280 |
+
end = min(sum(shard_sizes[: rank + 1]), total_size)
|
| 281 |
+
return range(begin, end)
|
| 282 |
+
|
| 283 |
+
def __iter__(self):
|
| 284 |
+
yield from self._local_indices
|
| 285 |
+
|
| 286 |
+
def __len__(self):
|
| 287 |
+
return len(self._local_indices)
|
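
The sharding and repeat-factor logic above is easiest to see on toy numbers. Below is a minimal usage sketch (not part of the uploaded file), assuming detectron2 is installed and the process is non-distributed (rank 0, world size 1); the dataset dicts are made up for illustration:

import itertools

from detectron2.data.samplers import (
    InferenceSampler,
    RepeatFactorTrainingSampler,
    TrainingSampler,
)

# TrainingSampler yields an infinite shuffled stream; take the first 5 indices.
train_sampler = TrainingSampler(size=10, shuffle=True, seed=42)
print(list(itertools.islice(iter(train_sampler), 5)))

# Repeat factors: with threshold t = 0.5, a category seen in 1 of 4 images
# (f = 0.25) gets r = sqrt(0.5 / 0.25) ~= 1.41; frequent categories stay at 1.
toy_dicts = [
    {"annotations": [{"category_id": 0}]},
    {"annotations": [{"category_id": 0}]},
    {"annotations": [{"category_id": 0}]},
    {"annotations": [{"category_id": 1}]},  # rare category
]
rep = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
    toy_dicts, repeat_thresh=0.5
)
print(rep)  # tensor([1.0000, 1.0000, 1.0000, 1.4142])

# InferenceSampler partitions [0, size) exactly once across workers;
# with world size 1 it is simply range(0, 7).
print(list(InferenceSampler(size=7)))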
CatVTON/detectron2/data/samplers/grouped_batch_sampler.py
ADDED
@@ -0,0 +1,47 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
from torch.utils.data.sampler import BatchSampler, Sampler


class GroupedBatchSampler(BatchSampler):
    """
    Wraps another sampler to yield a mini-batch of indices.
    It enforces that the batch only contains elements from the same group.
    It also tries to provide mini-batches that follow an ordering as close
    as possible to the ordering from the original sampler.
    """

    def __init__(self, sampler, group_ids, batch_size):
        """
        Args:
            sampler (Sampler): Base sampler.
            group_ids (list[int]): If the sampler produces indices in range [0, N),
                `group_ids` must be a list of `N` ints which contains the group id of each sample.
                The group ids must be a set of integers in the range [0, num_groups).
            batch_size (int): Size of mini-batch.
        """
        if not isinstance(sampler, Sampler):
            raise ValueError(
                "sampler should be an instance of "
                "torch.utils.data.Sampler, but got sampler={}".format(sampler)
            )
        self.sampler = sampler
        self.group_ids = np.asarray(group_ids)
        assert self.group_ids.ndim == 1
        self.batch_size = batch_size
        groups = np.unique(self.group_ids).tolist()

        # buffer the indices of each group until batch size is reached
        self.buffer_per_group = {k: [] for k in groups}

    def __iter__(self):
        for idx in self.sampler:
            group_id = self.group_ids[idx]
            group_buffer = self.buffer_per_group[group_id]
            group_buffer.append(idx)
            if len(group_buffer) == self.batch_size:
                yield group_buffer[:]  # yield a copy of the list
                del group_buffer[:]

    def __len__(self):
        raise NotImplementedError("len() of GroupedBatchSampler is not well-defined.")
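
A minimal sketch of the buffering behavior (not part of the uploaded file; the group ids below are made up, following the usual aspect-ratio convention of 0 for wide images and 1 for tall ones). Note that leftover indices that never fill a complete batch are silently dropped:

from torch.utils.data.sampler import SequentialSampler

from detectron2.data.samplers import GroupedBatchSampler

group_ids = [0, 1, 0, 0, 1, 1]  # one group id per dataset index
batch_sampler = GroupedBatchSampler(
    SequentialSampler(range(len(group_ids))), group_ids, batch_size=2
)
for batch in batch_sampler:
    print(batch)
# Prints [0, 2] then [1, 4]; the leftover buffers [3] and [5] are dropped.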
CatVTON/detectron2/data/transforms/__init__.py
ADDED
@@ -0,0 +1,14 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from fvcore.transforms.transform import Transform, TransformList  # order them first
from fvcore.transforms.transform import *
from .transform import *
from .augmentation import *
from .augmentation_impl import *

__all__ = [k for k in globals().keys() if not k.startswith("_")]


from detectron2.utils.env import fixup_module_metadata

fixup_module_metadata(__name__, globals(), __all__)
del fixup_module_metadata
CatVTON/detectron2/data/transforms/augmentation.py
ADDED
@@ -0,0 +1,380 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

import inspect
import numpy as np
import pprint
from typing import Any, List, Optional, Tuple, Union
from fvcore.transforms.transform import Transform, TransformList

"""
See "Data Augmentation" tutorial for an overview of the system:
https://detectron2.readthedocs.io/tutorials/augmentation.html
"""


__all__ = [
    "Augmentation",
    "AugmentationList",
    "AugInput",
    "TransformGen",
    "apply_transform_gens",
    "StandardAugInput",
    "apply_augmentations",
]


def _check_img_dtype(img):
    assert isinstance(img, np.ndarray), "[Augmentation] Needs a numpy array, but got a {}!".format(
        type(img)
    )
    assert not isinstance(img.dtype, np.integer) or (
        img.dtype == np.uint8
    ), "[Augmentation] Got image of type {}, use uint8 or floating points instead!".format(
        img.dtype
    )
    assert img.ndim in [2, 3], img.ndim


def _get_aug_input_args(aug, aug_input) -> List[Any]:
    """
    Get the arguments to be passed to ``aug.get_transform`` from the input ``aug_input``.
    """
    if aug.input_args is None:
        # Decide what attributes are needed automatically
        prms = list(inspect.signature(aug.get_transform).parameters.items())
        # The default behavior is: if there is one parameter, then it's "image"
        # (work automatically for majority of use cases, and also avoid BC breaking),
        # Otherwise, use the argument names.
        if len(prms) == 1:
            names = ("image",)
        else:
            names = []
            for name, prm in prms:
                if prm.kind in (
                    inspect.Parameter.VAR_POSITIONAL,
                    inspect.Parameter.VAR_KEYWORD,
                ):
                    raise TypeError(
                        f""" \
The default implementation of `{type(aug)}.__call__` does not allow \
`{type(aug)}.get_transform` to use variable-length arguments (*args, **kwargs)! \
If arguments are unknown, reimplement `__call__` instead. \
"""
                    )
                names.append(name)
        aug.input_args = tuple(names)

    args = []
    for f in aug.input_args:
        try:
            args.append(getattr(aug_input, f))
        except AttributeError as e:
            raise AttributeError(
                f"{type(aug)}.get_transform needs input attribute '{f}', "
                f"but it is not an attribute of {type(aug_input)}!"
            ) from e
    return args


class Augmentation:
    """
    Augmentation defines (often random) policies/strategies to generate :class:`Transform`
    from data. It is often used for pre-processing of input data.

    A "policy" that generates a :class:`Transform` may, in the most general case,
    need arbitrary information from input data in order to determine what transforms
    to apply. Therefore, each :class:`Augmentation` instance defines the arguments
    needed by its :meth:`get_transform` method. When called with the positional arguments,
    the :meth:`get_transform` method executes the policy.

    Note that :class:`Augmentation` defines the policies to create a :class:`Transform`,
    but not how to execute the actual transform operations to those data.
    Its :meth:`__call__` method will use :meth:`AugInput.transform` to execute the transform.

    The returned `Transform` object is meant to describe deterministic transformation, which means
    it can be re-applied on associated data, e.g. the geometry of an image and its segmentation
    masks need to be transformed together.
    (If such re-application is not needed, then determinism is not a crucial requirement.)
    """

    input_args: Optional[Tuple[str]] = None
    """
    Stores the attribute names needed by :meth:`get_transform`, e.g. ``("image", "sem_seg")``.
    By default, it is just a tuple of argument names in :meth:`self.get_transform`, which often only
    contains "image". As long as the argument name convention is followed, there is no need for
    users to touch this attribute.
    """

    def _init(self, params=None):
        if params:
            for k, v in params.items():
                if k != "self" and not k.startswith("_"):
                    setattr(self, k, v)

    def get_transform(self, *args) -> Transform:
        """
        Execute the policy based on input data, and decide what transform to apply to inputs.

        Args:
            args: Any fixed-length positional arguments. By default, the name of the arguments
                should exist in the :class:`AugInput` to be used.

        Returns:
            Transform: Returns the deterministic transform to apply to the input.

        Examples:
        ::
            class MyAug:
                # if a policy needs to know both image and semantic segmentation
                def get_transform(image, sem_seg) -> T.Transform:
                    pass
            tfm: Transform = MyAug().get_transform(image, sem_seg)
            new_image = tfm.apply_image(image)

        Notes:
            Users can freely use arbitrary new argument names in custom
            :meth:`get_transform` method, as long as they are available in the
            input data. In detectron2 we use the following convention:

            * image: (H,W) or (H,W,C) ndarray of type uint8 in range [0, 255], or
              floating point in range [0, 1] or [0, 255].
            * boxes: (N,4) ndarray of float32. It represents the instance bounding boxes
              of N instances. Each is in XYXY format in unit of absolute coordinates.
            * sem_seg: (H,W) ndarray of type uint8. Each element is an integer label of pixel.

            We do not specify convention for other types and do not include builtin
            :class:`Augmentation` that uses other types in detectron2.
        """
        raise NotImplementedError

    def __call__(self, aug_input) -> Transform:
        """
        Augment the given `aug_input` **in-place**, and return the transform that's used.

        This method will be called to apply the augmentation. In most augmentation, it
        is enough to use the default implementation, which calls :meth:`get_transform`
        using the inputs. But a subclass can overwrite it to have more complicated logic.

        Args:
            aug_input (AugInput): an object that has attributes needed by this augmentation
                (defined by ``self.get_transform``). Its ``transform`` method will be called
                to in-place transform it.

        Returns:
            Transform: the transform that is applied on the input.
        """
        args = _get_aug_input_args(self, aug_input)
        tfm = self.get_transform(*args)
        assert isinstance(tfm, (Transform, TransformList)), (
            f"{type(self)}.get_transform must return an instance of Transform! "
            f"Got {type(tfm)} instead."
        )
        aug_input.transform(tfm)
        return tfm

    def _rand_range(self, low=1.0, high=None, size=None):
        """
        Uniform float random number between low and high.
        """
        if high is None:
            low, high = 0, low
        if size is None:
            size = []
        return np.random.uniform(low, high, size)

    def __repr__(self):
        """
        Produce something like:
        "MyAugmentation(field1={self.field1}, field2={self.field2})"
        """
        try:
            sig = inspect.signature(self.__init__)
            classname = type(self).__name__
            argstr = []
            for name, param in sig.parameters.items():
                assert (
                    param.kind != param.VAR_POSITIONAL and param.kind != param.VAR_KEYWORD
                ), "The default __repr__ doesn't support *args or **kwargs"
                assert hasattr(self, name), (
                    "Attribute {} not found! "
                    "Default __repr__ only works if attributes match the constructor.".format(name)
                )
                attr = getattr(self, name)
                default = param.default
                if default is attr:
                    continue
                attr_str = pprint.pformat(attr)
                if "\n" in attr_str:
                    # don't show it if pformat decides to use >1 lines
                    attr_str = "..."
                argstr.append("{}={}".format(name, attr_str))
            return "{}({})".format(classname, ", ".join(argstr))
        except AssertionError:
            return super().__repr__()

    __str__ = __repr__


class _TransformToAug(Augmentation):
    def __init__(self, tfm: Transform):
        self.tfm = tfm

    def get_transform(self, *args):
        return self.tfm

    def __repr__(self):
        return repr(self.tfm)

    __str__ = __repr__


def _transform_to_aug(tfm_or_aug):
    """
    Wrap Transform into Augmentation.
    Private, used internally to implement augmentations.
    """
    assert isinstance(tfm_or_aug, (Transform, Augmentation)), tfm_or_aug
    if isinstance(tfm_or_aug, Augmentation):
        return tfm_or_aug
    else:
        return _TransformToAug(tfm_or_aug)


class AugmentationList(Augmentation):
    """
    Apply a sequence of augmentations.

    It has ``__call__`` method to apply the augmentations.

    Note that :meth:`get_transform` method is impossible (will throw error if called)
    for :class:`AugmentationList`, because in order to apply a sequence of augmentations,
    the kth augmentation must be applied first, to provide inputs needed by the (k+1)th
    augmentation.
    """

    def __init__(self, augs):
        """
        Args:
            augs (list[Augmentation or Transform]):
        """
        super().__init__()
        self.augs = [_transform_to_aug(x) for x in augs]

    def __call__(self, aug_input) -> TransformList:
        tfms = []
        for x in self.augs:
            tfm = x(aug_input)
            tfms.append(tfm)
        return TransformList(tfms)

    def __repr__(self):
        msgs = [str(x) for x in self.augs]
        return "AugmentationList[{}]".format(", ".join(msgs))

    __str__ = __repr__


class AugInput:
    """
    Input that can be used with :meth:`Augmentation.__call__`.
    This is a standard implementation for the majority of use cases.
    This class provides the standard attributes **"image", "boxes", "sem_seg"**
    defined in :meth:`__init__` and they may be needed by different augmentations.
    Most augmentation policies do not need attributes beyond these three.

    After applying augmentations to these attributes (using :meth:`AugInput.transform`),
    the returned transforms can then be used to transform other data structures that users have.

    Examples:
    ::
        input = AugInput(image, boxes=boxes)
        tfms = augmentation(input)
        transformed_image = input.image
        transformed_boxes = input.boxes
        transformed_other_data = tfms.apply_other(other_data)

    An extended project that works with new data types may implement augmentation policies
    that need other inputs. An algorithm may need to transform inputs in a way different
    from the standard approach defined in this class. In those rare situations, users can
    implement a class similar to this class, that satisfies the following conditions:

    * The input must provide access to these data in the form of attribute access
      (``getattr``). For example, if an :class:`Augmentation` to be applied needs "image"
      and "sem_seg" arguments, its input must have the attribute "image" and "sem_seg".
    * The input must have a ``transform(tfm: Transform) -> None`` method which
      in-place transforms all its attributes.
    """

    # TODO maybe should support more builtin data types here
    def __init__(
        self,
        image: np.ndarray,
        *,
        boxes: Optional[np.ndarray] = None,
        sem_seg: Optional[np.ndarray] = None,
    ):
        """
        Args:
            image (ndarray): (H,W) or (H,W,C) ndarray of type uint8 in range [0, 255], or
                floating point in range [0, 1] or [0, 255]. The meaning of C is up
                to users.
            boxes (ndarray or None): Nx4 float32 boxes in XYXY_ABS mode
            sem_seg (ndarray or None): HxW uint8 semantic segmentation mask. Each element
                is an integer label of pixel.
        """
        _check_img_dtype(image)
        self.image = image
        self.boxes = boxes
        self.sem_seg = sem_seg

    def transform(self, tfm: Transform) -> None:
        """
        In-place transform all attributes of this class.

        By "in-place", it means after calling this method, accessing an attribute such
        as ``self.image`` will return transformed data.
        """
        self.image = tfm.apply_image(self.image)
        if self.boxes is not None:
            self.boxes = tfm.apply_box(self.boxes)
        if self.sem_seg is not None:
            self.sem_seg = tfm.apply_segmentation(self.sem_seg)

    def apply_augmentations(
        self, augmentations: List[Union[Augmentation, Transform]]
    ) -> TransformList:
        """
        Equivalent of ``AugmentationList(augmentations)(self)``
        """
        return AugmentationList(augmentations)(self)


def apply_augmentations(augmentations: List[Union[Transform, Augmentation]], inputs):
    """
    Use ``T.AugmentationList(augmentations)(inputs)`` instead.
    """
    if isinstance(inputs, np.ndarray):
        # handle the common case of image-only Augmentation, also for backward compatibility
        image_only = True
        inputs = AugInput(inputs)
    else:
        image_only = False
    tfms = inputs.apply_augmentations(augmentations)
    return inputs.image if image_only else inputs, tfms


apply_transform_gens = apply_augmentations
"""
Alias for backward-compatibility.
"""

TransformGen = Augmentation
"""
Alias for Augmentation, since it is something that generates :class:`Transform`s
"""

StandardAugInput = AugInput
"""
Alias for compatibility. It's not worth the complexity to have two classes.
"""
CatVTON/detectron2/data/transforms/augmentation_impl.py
ADDED
@@ -0,0 +1,736 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
"""
Implement many useful :class:`Augmentation`.
"""
import numpy as np
import sys
from numpy import random
from typing import Tuple
import torch
from fvcore.transforms.transform import (
    BlendTransform,
    CropTransform,
    HFlipTransform,
    NoOpTransform,
    PadTransform,
    Transform,
    TransformList,
    VFlipTransform,
)
from PIL import Image

from detectron2.structures import Boxes, pairwise_iou

from .augmentation import Augmentation, _transform_to_aug
from .transform import ExtentTransform, ResizeTransform, RotationTransform

__all__ = [
    "FixedSizeCrop",
    "RandomApply",
    "RandomBrightness",
    "RandomContrast",
    "RandomCrop",
    "RandomExtent",
    "RandomFlip",
    "RandomSaturation",
    "RandomLighting",
    "RandomRotation",
    "Resize",
    "ResizeScale",
    "ResizeShortestEdge",
    "RandomCrop_CategoryAreaConstraint",
    "RandomResize",
    "MinIoURandomCrop",
]


class RandomApply(Augmentation):
    """
    Randomly apply an augmentation with a given probability.
    """

    def __init__(self, tfm_or_aug, prob=0.5):
        """
        Args:
            tfm_or_aug (Transform, Augmentation): the transform or augmentation
                to be applied. It can either be a `Transform` or `Augmentation`
                instance.
            prob (float): probability between 0.0 and 1.0 that
                the wrapper transformation is applied
        """
        super().__init__()
        self.aug = _transform_to_aug(tfm_or_aug)
        assert 0.0 <= prob <= 1.0, f"Probability must be between 0.0 and 1.0 (given: {prob})"
        self.prob = prob

    def get_transform(self, *args):
        do = self._rand_range() < self.prob
        if do:
            return self.aug.get_transform(*args)
        else:
            return NoOpTransform()

    def __call__(self, aug_input):
        do = self._rand_range() < self.prob
        if do:
            return self.aug(aug_input)
        else:
            return NoOpTransform()


class RandomFlip(Augmentation):
    """
    Flip the image horizontally or vertically with the given probability.
    """

    def __init__(self, prob=0.5, *, horizontal=True, vertical=False):
        """
        Args:
            prob (float): probability of flip.
            horizontal (boolean): whether to apply horizontal flipping
            vertical (boolean): whether to apply vertical flipping
        """
        super().__init__()

        if horizontal and vertical:
            raise ValueError("Cannot do both horiz and vert. Please use two Flip instead.")
        if not horizontal and not vertical:
            raise ValueError("At least one of horiz or vert has to be True!")
        self._init(locals())

    def get_transform(self, image):
        h, w = image.shape[:2]
        do = self._rand_range() < self.prob
        if do:
            if self.horizontal:
                return HFlipTransform(w)
            elif self.vertical:
                return VFlipTransform(h)
        else:
            return NoOpTransform()


class Resize(Augmentation):
    """Resize image to a fixed target size"""

    def __init__(self, shape, interp=Image.BILINEAR):
        """
        Args:
            shape: (h, w) tuple or an int
            interp: PIL interpolation method
        """
        if isinstance(shape, int):
            shape = (shape, shape)
        shape = tuple(shape)
        self._init(locals())

    def get_transform(self, image):
        return ResizeTransform(
            image.shape[0], image.shape[1], self.shape[0], self.shape[1], self.interp
        )


class ResizeShortestEdge(Augmentation):
    """
    Resize the image while keeping the aspect ratio unchanged.
    It attempts to scale the shorter edge to the given `short_edge_length`,
    as long as the longer edge does not exceed `max_size`.
    If `max_size` is reached, then downscale so that the longer edge does not exceed max_size.
    """

    @torch.jit.unused
    def __init__(
        self, short_edge_length, max_size=sys.maxsize, sample_style="range", interp=Image.BILINEAR
    ):
        """
        Args:
            short_edge_length (list[int]): If ``sample_style=="range"``,
                a [min, max] interval from which to sample the shortest edge length.
                If ``sample_style=="choice"``, a list of shortest edge lengths to sample from.
            max_size (int): maximum allowed longest edge length.
            sample_style (str): either "range" or "choice".
        """
        super().__init__()
        assert sample_style in ["range", "choice"], sample_style

        self.is_range = sample_style == "range"
        if isinstance(short_edge_length, int):
            short_edge_length = (short_edge_length, short_edge_length)
        if self.is_range:
            assert len(short_edge_length) == 2, (
                "short_edge_length must be two values using 'range' sample style."
                f" Got {short_edge_length}!"
            )
        self._init(locals())

    @torch.jit.unused
    def get_transform(self, image):
        h, w = image.shape[:2]
        if self.is_range:
            size = np.random.randint(self.short_edge_length[0], self.short_edge_length[1] + 1)
        else:
            size = np.random.choice(self.short_edge_length)
        if size == 0:
            return NoOpTransform()

        newh, neww = ResizeShortestEdge.get_output_shape(h, w, size, self.max_size)
        return ResizeTransform(h, w, newh, neww, self.interp)

    @staticmethod
    def get_output_shape(
        oldh: int, oldw: int, short_edge_length: int, max_size: int
    ) -> Tuple[int, int]:
        """
        Compute the output size given input size and target short edge length.
        """
        h, w = oldh, oldw
        size = short_edge_length * 1.0
        scale = size / min(h, w)
        if h < w:
            newh, neww = size, scale * w
        else:
            newh, neww = scale * h, size
        if max(newh, neww) > max_size:
            scale = max_size * 1.0 / max(newh, neww)
            newh = newh * scale
            neww = neww * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)


class ResizeScale(Augmentation):
    """
    Takes target size as input and randomly scales the given target size between `min_scale`
    and `max_scale`. It then scales the input image such that it fits inside the scaled target
    box, keeping the aspect ratio constant.
    This implements the resize part of the Google's 'resize_and_crop' data augmentation:
    https://github.com/tensorflow/tpu/blob/master/models/official/detection/utils/input_utils.py#L127
    """

    def __init__(
        self,
        min_scale: float,
        max_scale: float,
        target_height: int,
        target_width: int,
        interp: int = Image.BILINEAR,
    ):
        """
        Args:
            min_scale: minimum image scale range.
            max_scale: maximum image scale range.
            target_height: target image height.
            target_width: target image width.
            interp: image interpolation method.
        """
        super().__init__()
        self._init(locals())

    def _get_resize(self, image: np.ndarray, scale: float) -> Transform:
        input_size = image.shape[:2]

        # Compute new target size given a scale.
        target_size = (self.target_height, self.target_width)
        target_scale_size = np.multiply(target_size, scale)

        # Compute actual rescaling applied to input image and output size.
        output_scale = np.minimum(
            target_scale_size[0] / input_size[0], target_scale_size[1] / input_size[1]
        )
        output_size = np.round(np.multiply(input_size, output_scale)).astype(int)

        return ResizeTransform(
            input_size[0], input_size[1], int(output_size[0]), int(output_size[1]), self.interp
        )

    def get_transform(self, image: np.ndarray) -> Transform:
        random_scale = np.random.uniform(self.min_scale, self.max_scale)
        return self._get_resize(image, random_scale)


class RandomRotation(Augmentation):
    """
    This method returns a copy of this image, rotated the given
    number of degrees counter clockwise around the given center.
    """

    def __init__(self, angle, expand=True, center=None, sample_style="range", interp=None):
        """
        Args:
            angle (list[float]): If ``sample_style=="range"``,
                a [min, max] interval from which to sample the angle (in degrees).
                If ``sample_style=="choice"``, a list of angles to sample from
            expand (bool): choose if the image should be resized to fit the whole
                rotated image (default), or simply cropped
            center (list[[float, float]]): If ``sample_style=="range"``,
                a [[minx, miny], [maxx, maxy]] relative interval from which to sample the center,
                [0, 0] being the top left of the image and [1, 1] the bottom right.
                If ``sample_style=="choice"``, a list of centers to sample from
                Default: None, which means that the center of rotation is the center of the image
                center has no effect if expand=True because it only affects shifting
        """
        super().__init__()
        assert sample_style in ["range", "choice"], sample_style
        self.is_range = sample_style == "range"
        if isinstance(angle, (float, int)):
            angle = (angle, angle)
        if center is not None and isinstance(center[0], (float, int)):
            center = (center, center)
        self._init(locals())

    def get_transform(self, image):
        h, w = image.shape[:2]
        center = None
        if self.is_range:
            angle = np.random.uniform(self.angle[0], self.angle[1])
            if self.center is not None:
                center = (
                    np.random.uniform(self.center[0][0], self.center[1][0]),
                    np.random.uniform(self.center[0][1], self.center[1][1]),
                )
        else:
            angle = np.random.choice(self.angle)
            if self.center is not None:
                center = np.random.choice(self.center)

        if center is not None:
            center = (w * center[0], h * center[1])  # Convert to absolute coordinates

        if angle % 360 == 0:
            return NoOpTransform()

        return RotationTransform(h, w, angle, expand=self.expand, center=center, interp=self.interp)


class FixedSizeCrop(Augmentation):
    """
    If `crop_size` is smaller than the input image size, then it uses a random crop of
    the crop size. If `crop_size` is larger than the input image size, then it pads
    the right and the bottom of the image to the crop size if `pad` is True, otherwise
    it returns the smaller image.
    """

    def __init__(
        self,
        crop_size: Tuple[int],
        pad: bool = True,
        pad_value: float = 128.0,
        seg_pad_value: int = 255,
    ):
        """
        Args:
            crop_size: target image (height, width).
            pad: if True, will pad images smaller than `crop_size` up to `crop_size`
            pad_value: the padding value to the image.
            seg_pad_value: the padding value to the segmentation mask.
        """
        super().__init__()
        self._init(locals())

    def _get_crop(self, image: np.ndarray) -> Transform:
        # Compute the image scale and scaled size.
        input_size = image.shape[:2]
        output_size = self.crop_size

        # Add random crop if the image is scaled up.
        max_offset = np.subtract(input_size, output_size)
        max_offset = np.maximum(max_offset, 0)
        offset = np.multiply(max_offset, np.random.uniform(0.0, 1.0))
        offset = np.round(offset).astype(int)
        return CropTransform(
            offset[1], offset[0], output_size[1], output_size[0], input_size[1], input_size[0]
        )

    def _get_pad(self, image: np.ndarray) -> Transform:
        # Compute the image scale and scaled size.
        input_size = image.shape[:2]
        output_size = self.crop_size

        # Add padding if the image is scaled down.
        pad_size = np.subtract(output_size, input_size)
        pad_size = np.maximum(pad_size, 0)
        original_size = np.minimum(input_size, output_size)
        return PadTransform(
            0,
            0,
            pad_size[1],
            pad_size[0],
            original_size[1],
            original_size[0],
            self.pad_value,
            self.seg_pad_value,
        )

    def get_transform(self, image: np.ndarray) -> TransformList:
        transforms = [self._get_crop(image)]
        if self.pad:
            transforms.append(self._get_pad(image))
        return TransformList(transforms)


class RandomCrop(Augmentation):
    """
    Randomly crop a rectangle region out of an image.
    """

    def __init__(self, crop_type: str, crop_size):
        """
        Args:
            crop_type (str): one of "relative_range", "relative", "absolute", "absolute_range".
            crop_size (tuple[float, float]): two floats, explained below.

            - "relative": crop a (H * crop_size[0], W * crop_size[1]) region from an input image of
              size (H, W). crop size should be in (0, 1]
            - "relative_range": uniformly sample two values from [crop_size[0], 1]
              and [crop_size[1], 1], and use them as in "relative" crop type.
            - "absolute": crop a (crop_size[0], crop_size[1]) region from input image.
              crop_size must be smaller than the input image size.
            - "absolute_range": for an input of size (H, W), uniformly sample H_crop in
              [crop_size[0], min(H, crop_size[1])] and W_crop in [crop_size[0], min(W, crop_size[1])].
              Then crop a region (H_crop, W_crop).
        """
        # TODO style of relative_range and absolute_range are not consistent:
        # one takes (h, w) but another takes (min, max)
        super().__init__()
        assert crop_type in ["relative_range", "relative", "absolute", "absolute_range"]
        self._init(locals())

    def get_transform(self, image):
        h, w = image.shape[:2]
        croph, cropw = self.get_crop_size((h, w))
        assert h >= croph and w >= cropw, "Shape computation in {} has bugs.".format(self)
        h0 = np.random.randint(h - croph + 1)
        w0 = np.random.randint(w - cropw + 1)
        return CropTransform(w0, h0, cropw, croph)

    def get_crop_size(self, image_size):
        """
        Args:
            image_size (tuple): height, width

        Returns:
            crop_size (tuple): height, width in absolute pixels
        """
        h, w = image_size
        if self.crop_type == "relative":
            ch, cw = self.crop_size
            return int(h * ch + 0.5), int(w * cw + 0.5)
        elif self.crop_type == "relative_range":
            crop_size = np.asarray(self.crop_size, dtype=np.float32)
            ch, cw = crop_size + np.random.rand(2) * (1 - crop_size)
            return int(h * ch + 0.5), int(w * cw + 0.5)
        elif self.crop_type == "absolute":
            return (min(self.crop_size[0], h), min(self.crop_size[1], w))
        elif self.crop_type == "absolute_range":
            assert self.crop_size[0] <= self.crop_size[1]
            ch = np.random.randint(min(h, self.crop_size[0]), min(h, self.crop_size[1]) + 1)
            cw = np.random.randint(min(w, self.crop_size[0]), min(w, self.crop_size[1]) + 1)
            return ch, cw
        else:
            raise NotImplementedError("Unknown crop type {}".format(self.crop_type))


class RandomCrop_CategoryAreaConstraint(Augmentation):
    """
    Similar to :class:`RandomCrop`, but find a cropping window such that no single category
    occupies a ratio of more than `single_category_max_area` in semantic segmentation ground
    truth, which can cause instability in training. The function attempts to find such a valid
    cropping window for at most 10 times.
    """

    def __init__(
        self,
        crop_type: str,
        crop_size,
        single_category_max_area: float = 1.0,
        ignored_category: int = None,
    ):
        """
        Args:
            crop_type, crop_size: same as in :class:`RandomCrop`
            single_category_max_area: the maximum allowed area ratio of a
                category. Set to 1.0 to disable
            ignored_category: allow this category in the semantic segmentation
                ground truth to exceed the area ratio. Usually set to the category
                that's ignored in training.
        """
        self.crop_aug = RandomCrop(crop_type, crop_size)
        self._init(locals())

    def get_transform(self, image, sem_seg):
        if self.single_category_max_area >= 1.0:
            return self.crop_aug.get_transform(image)
        else:
            h, w = sem_seg.shape
            for _ in range(10):
                crop_size = self.crop_aug.get_crop_size((h, w))
                y0 = np.random.randint(h - crop_size[0] + 1)
                x0 = np.random.randint(w - crop_size[1] + 1)
                sem_seg_temp = sem_seg[y0 : y0 + crop_size[0], x0 : x0 + crop_size[1]]
                labels, cnt = np.unique(sem_seg_temp, return_counts=True)
                if self.ignored_category is not None:
                    cnt = cnt[labels != self.ignored_category]
                if len(cnt) > 1 and np.max(cnt) < np.sum(cnt) * self.single_category_max_area:
                    break
            crop_tfm = CropTransform(x0, y0, crop_size[1], crop_size[0])
            return crop_tfm


class RandomExtent(Augmentation):
    """
    Outputs an image by cropping a random "subrect" of the source image.

    The subrect can be parameterized to include pixels outside the source image,
    in which case they will be set to zeros (i.e. black). The size of the output
    image will vary with the size of the random subrect.
    """

    def __init__(self, scale_range, shift_range):
        """
        Args:
            output_size (h, w): Dimensions of output image
            scale_range (l, h): Range of input-to-output size scaling factor
            shift_range (x, y): Range of shifts of the cropped subrect. The rect
                is shifted by [w / 2 * Uniform(-x, x), h / 2 * Uniform(-y, y)],
                where (w, h) is the (width, height) of the input image. Set each
                component to zero to crop at the image's center.
        """
        super().__init__()
        self._init(locals())

    def get_transform(self, image):
        img_h, img_w = image.shape[:2]

        # Initialize src_rect to fit the input image.
        src_rect = np.array([-0.5 * img_w, -0.5 * img_h, 0.5 * img_w, 0.5 * img_h])

        # Apply a random scaling to the src_rect.
        src_rect *= np.random.uniform(self.scale_range[0], self.scale_range[1])

        # Apply a random shift to the coordinates origin.
        src_rect[0::2] += self.shift_range[0] * img_w * (np.random.rand() - 0.5)
        src_rect[1::2] += self.shift_range[1] * img_h * (np.random.rand() - 0.5)

        # Map src_rect coordinates into image coordinates (center at corner).
        src_rect[0::2] += 0.5 * img_w
        src_rect[1::2] += 0.5 * img_h

        return ExtentTransform(
            src_rect=(src_rect[0], src_rect[1], src_rect[2], src_rect[3]),
            output_size=(int(src_rect[3] - src_rect[1]), int(src_rect[2] - src_rect[0])),
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
class RandomContrast(Augmentation):
|
| 527 |
+
"""
|
| 528 |
+
Randomly transforms image contrast.
|
| 529 |
+
|
| 530 |
+
Contrast intensity is uniformly sampled in (intensity_min, intensity_max).
|
| 531 |
+
- intensity < 1 will reduce contrast
|
| 532 |
+
- intensity = 1 will preserve the input image
|
| 533 |
+
- intensity > 1 will increase contrast
|
| 534 |
+
|
| 535 |
+
See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
|
| 536 |
+
"""
|
| 537 |
+
|
| 538 |
+
def __init__(self, intensity_min, intensity_max):
|
| 539 |
+
"""
|
| 540 |
+
Args:
|
| 541 |
+
intensity_min (float): Minimum augmentation
|
| 542 |
+
intensity_max (float): Maximum augmentation
|
| 543 |
+
"""
|
| 544 |
+
super().__init__()
|
| 545 |
+
self._init(locals())
|
| 546 |
+
|
| 547 |
+
def get_transform(self, image):
|
| 548 |
+
w = np.random.uniform(self.intensity_min, self.intensity_max)
|
| 549 |
+
return BlendTransform(src_image=image.mean(), src_weight=1 - w, dst_weight=w)
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
class RandomBrightness(Augmentation):
|
| 553 |
+
"""
|
| 554 |
+
Randomly transforms image brightness.
|
| 555 |
+
|
| 556 |
+
Brightness intensity is uniformly sampled in (intensity_min, intensity_max).
|
| 557 |
+
- intensity < 1 will reduce brightness
|
| 558 |
+
- intensity = 1 will preserve the input image
|
| 559 |
+
- intensity > 1 will increase brightness
|
| 560 |
+
|
| 561 |
+
See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
|
| 562 |
+
"""
|
| 563 |
+
|
| 564 |
+
def __init__(self, intensity_min, intensity_max):
|
| 565 |
+
"""
|
| 566 |
+
Args:
|
| 567 |
+
intensity_min (float): Minimum augmentation
|
| 568 |
+
intensity_max (float): Maximum augmentation
|
| 569 |
+
"""
|
| 570 |
+
super().__init__()
|
| 571 |
+
self._init(locals())
|
| 572 |
+
|
| 573 |
+
def get_transform(self, image):
|
| 574 |
+
w = np.random.uniform(self.intensity_min, self.intensity_max)
|
| 575 |
+
return BlendTransform(src_image=0, src_weight=1 - w, dst_weight=w)
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
class RandomSaturation(Augmentation):
|
| 579 |
+
"""
|
| 580 |
+
Randomly transforms saturation of an RGB image.
|
| 581 |
+
Input images are assumed to have 'RGB' channel order.
|
| 582 |
+
|
| 583 |
+
Saturation intensity is uniformly sampled in (intensity_min, intensity_max).
|
| 584 |
+
- intensity < 1 will reduce saturation (make the image more grayscale)
|
| 585 |
+
- intensity = 1 will preserve the input image
|
| 586 |
+
- intensity > 1 will increase saturation
|
| 587 |
+
|
| 588 |
+
See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
|
| 589 |
+
"""
|
| 590 |
+
|
| 591 |
+
def __init__(self, intensity_min, intensity_max):
|
| 592 |
+
"""
|
| 593 |
+
Args:
|
| 594 |
+
intensity_min (float): Minimum augmentation (1 preserves input).
|
| 595 |
+
intensity_max (float): Maximum augmentation (1 preserves input).
|
| 596 |
+
"""
|
| 597 |
+
super().__init__()
|
| 598 |
+
self._init(locals())
|
| 599 |
+
|
| 600 |
+
def get_transform(self, image):
|
| 601 |
+
assert image.shape[-1] == 3, "RandomSaturation only works on RGB images"
|
| 602 |
+
w = np.random.uniform(self.intensity_min, self.intensity_max)
|
| 603 |
+
grayscale = image.dot([0.299, 0.587, 0.114])[:, :, np.newaxis]
|
| 604 |
+
return BlendTransform(src_image=grayscale, src_weight=1 - w, dst_weight=w)
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
class RandomLighting(Augmentation):
|
| 608 |
+
"""
|
| 609 |
+
The "lighting" augmentation described in AlexNet, using fixed PCA over ImageNet.
|
| 610 |
+
Input images are assumed to have 'RGB' channel order.
|
| 611 |
+
|
| 612 |
+
The degree of color jittering is randomly sampled via a normal distribution,
|
| 613 |
+
with standard deviation given by the scale parameter.
|
| 614 |
+
"""
|
| 615 |
+
|
| 616 |
+
def __init__(self, scale):
|
| 617 |
+
"""
|
| 618 |
+
Args:
|
| 619 |
+
scale (float): Standard deviation of principal component weighting.
|
| 620 |
+
"""
|
| 621 |
+
super().__init__()
|
| 622 |
+
self._init(locals())
|
| 623 |
+
self.eigen_vecs = np.array(
|
| 624 |
+
[[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203]]
|
| 625 |
+
)
|
| 626 |
+
self.eigen_vals = np.array([0.2175, 0.0188, 0.0045])
|
| 627 |
+
|
| 628 |
+
def get_transform(self, image):
|
| 629 |
+
assert image.shape[-1] == 3, "RandomLighting only works on RGB images"
|
| 630 |
+
weights = np.random.normal(scale=self.scale, size=3)
|
| 631 |
+
return BlendTransform(
|
| 632 |
+
src_image=self.eigen_vecs.dot(weights * self.eigen_vals), src_weight=1.0, dst_weight=1.0
|
| 633 |
+
)
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
class RandomResize(Augmentation):
|
| 637 |
+
"""Randomly resize image to a target size in shape_list"""
|
| 638 |
+
|
| 639 |
+
def __init__(self, shape_list, interp=Image.BILINEAR):
|
| 640 |
+
"""
|
| 641 |
+
Args:
|
| 642 |
+
shape_list: a list of shapes in (h, w)
|
| 643 |
+
interp: PIL interpolation method
|
| 644 |
+
"""
|
| 645 |
+
self.shape_list = shape_list
|
| 646 |
+
self._init(locals())
|
| 647 |
+
|
| 648 |
+
def get_transform(self, image):
|
| 649 |
+
shape_idx = np.random.randint(low=0, high=len(self.shape_list))
|
| 650 |
+
h, w = self.shape_list[shape_idx]
|
| 651 |
+
return ResizeTransform(image.shape[0], image.shape[1], h, w, self.interp)
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
class MinIoURandomCrop(Augmentation):
|
| 655 |
+
"""Random crop the image & bboxes, the cropped patches have minimum IoU
|
| 656 |
+
requirement with original image & bboxes, the IoU threshold is randomly
|
| 657 |
+
selected from min_ious.
|
| 658 |
+
|
| 659 |
+
Args:
|
| 660 |
+
min_ious (tuple): minimum IoU threshold for all intersections with
|
| 661 |
+
bounding boxes
|
| 662 |
+
min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w,
|
| 663 |
+
where a >= min_crop_size)
|
| 664 |
+
mode_trials: number of trials for sampling min_ious threshold
|
| 665 |
+
crop_trials: number of trials for sampling crop_size after cropping
|
| 666 |
+
"""
|
| 667 |
+
|
| 668 |
+
def __init__(
|
| 669 |
+
self,
|
| 670 |
+
min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
|
| 671 |
+
min_crop_size=0.3,
|
| 672 |
+
mode_trials=1000,
|
| 673 |
+
crop_trials=50,
|
| 674 |
+
):
|
| 675 |
+
self.min_ious = min_ious
|
| 676 |
+
self.sample_mode = (1, *min_ious, 0)
|
| 677 |
+
self.min_crop_size = min_crop_size
|
| 678 |
+
self.mode_trials = mode_trials
|
| 679 |
+
self.crop_trials = crop_trials
|
| 680 |
+
|
| 681 |
+
def get_transform(self, image, boxes):
|
| 682 |
+
"""Call function to crop images and bounding boxes with minimum IoU
|
| 683 |
+
constraint.
|
| 684 |
+
|
| 685 |
+
Args:
|
| 686 |
+
boxes: ground truth boxes in (x1, y1, x2, y2) format
|
| 687 |
+
"""
|
| 688 |
+
if boxes is None:
|
| 689 |
+
return NoOpTransform()
|
| 690 |
+
h, w, c = image.shape
|
| 691 |
+
for _ in range(self.mode_trials):
|
| 692 |
+
mode = random.choice(self.sample_mode)
|
| 693 |
+
self.mode = mode
|
| 694 |
+
if mode == 1:
|
| 695 |
+
return NoOpTransform()
|
| 696 |
+
|
| 697 |
+
min_iou = mode
|
| 698 |
+
for _ in range(self.crop_trials):
|
| 699 |
+
new_w = random.uniform(self.min_crop_size * w, w)
|
| 700 |
+
new_h = random.uniform(self.min_crop_size * h, h)
|
| 701 |
+
|
| 702 |
+
# h / w in [0.5, 2]
|
| 703 |
+
if new_h / new_w < 0.5 or new_h / new_w > 2:
|
| 704 |
+
continue
|
| 705 |
+
|
| 706 |
+
left = random.uniform(w - new_w)
|
| 707 |
+
top = random.uniform(h - new_h)
|
| 708 |
+
|
| 709 |
+
patch = np.array((int(left), int(top), int(left + new_w), int(top + new_h)))
|
| 710 |
+
# Line or point crop is not allowed
|
| 711 |
+
if patch[2] == patch[0] or patch[3] == patch[1]:
|
| 712 |
+
continue
|
| 713 |
+
overlaps = pairwise_iou(
|
| 714 |
+
Boxes(patch.reshape(-1, 4)), Boxes(boxes.reshape(-1, 4))
|
| 715 |
+
).reshape(-1)
|
| 716 |
+
if len(overlaps) > 0 and overlaps.min() < min_iou:
|
| 717 |
+
continue
|
| 718 |
+
|
| 719 |
+
# center of boxes should inside the crop img
|
| 720 |
+
# only adjust boxes and instance masks when the gt is not empty
|
| 721 |
+
if len(overlaps) > 0:
|
| 722 |
+
# adjust boxes
|
| 723 |
+
def is_center_of_bboxes_in_patch(boxes, patch):
|
| 724 |
+
center = (boxes[:, :2] + boxes[:, 2:]) / 2
|
| 725 |
+
mask = (
|
| 726 |
+
(center[:, 0] > patch[0])
|
| 727 |
+
* (center[:, 1] > patch[1])
|
| 728 |
+
* (center[:, 0] < patch[2])
|
| 729 |
+
* (center[:, 1] < patch[3])
|
| 730 |
+
)
|
| 731 |
+
return mask
|
| 732 |
+
|
| 733 |
+
mask = is_center_of_bboxes_in_patch(boxes, patch)
|
| 734 |
+
if not mask.any():
|
| 735 |
+
continue
|
| 736 |
+
return CropTransform(int(left), int(top), int(new_w), int(new_h))
|
CatVTON/detectron2/data/transforms/transform.py
ADDED
@@ -0,0 +1,351 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

"""
See "Data Augmentation" tutorial for an overview of the system:
https://detectron2.readthedocs.io/tutorials/augmentation.html
"""

import numpy as np
import torch
import torch.nn.functional as F
from fvcore.transforms.transform import (
    CropTransform,
    HFlipTransform,
    NoOpTransform,
    Transform,
    TransformList,
)
from PIL import Image

try:
    import cv2  # noqa
except ImportError:
    # OpenCV is an optional dependency at the moment
    pass

__all__ = [
    "ExtentTransform",
    "ResizeTransform",
    "RotationTransform",
    "ColorTransform",
    "PILColorTransform",
]


class ExtentTransform(Transform):
    """
    Extracts a subregion from the source image and scales it to the output size.

    The fill color is used to map pixels from the source rect that fall outside
    the source image.

    See: https://pillow.readthedocs.io/en/latest/PIL.html#PIL.ImageTransform.ExtentTransform
    """

    def __init__(self, src_rect, output_size, interp=Image.BILINEAR, fill=0):
        """
        Args:
            src_rect (x0, y0, x1, y1): src coordinates
            output_size (h, w): dst image size
            interp: PIL interpolation methods
            fill: Fill color used when src_rect extends outside image
        """
        super().__init__()
        self._set_attributes(locals())

    def apply_image(self, img, interp=None):
        h, w = self.output_size
        if len(img.shape) > 2 and img.shape[2] == 1:
            pil_image = Image.fromarray(img[:, :, 0], mode="L")
        else:
            pil_image = Image.fromarray(img)
        pil_image = pil_image.transform(
            size=(w, h),
            method=Image.EXTENT,
            data=self.src_rect,
            resample=interp if interp else self.interp,
            fill=self.fill,
        )
        ret = np.asarray(pil_image)
        if len(img.shape) > 2 and img.shape[2] == 1:
            ret = np.expand_dims(ret, -1)
        return ret

    def apply_coords(self, coords):
        # Transform image center from source coordinates into output coordinates
        # and then map the new origin to the corner of the output image.
        h, w = self.output_size
        x0, y0, x1, y1 = self.src_rect
        new_coords = coords.astype(np.float32)
        new_coords[:, 0] -= 0.5 * (x0 + x1)
        new_coords[:, 1] -= 0.5 * (y0 + y1)
        new_coords[:, 0] *= w / (x1 - x0)
        new_coords[:, 1] *= h / (y1 - y0)
        new_coords[:, 0] += 0.5 * w
        new_coords[:, 1] += 0.5 * h
        return new_coords

    def apply_segmentation(self, segmentation):
        segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
        return segmentation


class ResizeTransform(Transform):
    """
    Resize the image to a target size.
    """

    def __init__(self, h, w, new_h, new_w, interp=None):
        """
        Args:
            h, w (int): original image size
            new_h, new_w (int): new image size
            interp: PIL interpolation methods, defaults to bilinear.
        """
        # TODO decide on PIL vs opencv
        super().__init__()
        if interp is None:
            interp = Image.BILINEAR
        self._set_attributes(locals())

    def apply_image(self, img, interp=None):
        assert img.shape[:2] == (self.h, self.w)
        assert len(img.shape) <= 4
        interp_method = interp if interp is not None else self.interp

        if img.dtype == np.uint8:
            if len(img.shape) > 2 and img.shape[2] == 1:
                pil_image = Image.fromarray(img[:, :, 0], mode="L")
            else:
                pil_image = Image.fromarray(img)
            pil_image = pil_image.resize((self.new_w, self.new_h), interp_method)
            ret = np.asarray(pil_image)
            if len(img.shape) > 2 and img.shape[2] == 1:
                ret = np.expand_dims(ret, -1)
        else:
            # PIL only supports uint8
            if any(x < 0 for x in img.strides):
                img = np.ascontiguousarray(img)
            img = torch.from_numpy(img)
            shape = list(img.shape)
            shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:]
            img = img.view(shape_4d).permute(2, 3, 0, 1)  # hw(c) -> nchw
            _PIL_RESIZE_TO_INTERPOLATE_MODE = {
                Image.NEAREST: "nearest",
                Image.BILINEAR: "bilinear",
                Image.BICUBIC: "bicubic",
            }
            mode = _PIL_RESIZE_TO_INTERPOLATE_MODE[interp_method]
            align_corners = None if mode == "nearest" else False
            img = F.interpolate(
                img, (self.new_h, self.new_w), mode=mode, align_corners=align_corners
            )
            shape[:2] = (self.new_h, self.new_w)
            ret = img.permute(2, 3, 0, 1).view(shape).numpy()  # nchw -> hw(c)

        return ret

    def apply_coords(self, coords):
        coords[:, 0] = coords[:, 0] * (self.new_w * 1.0 / self.w)
        coords[:, 1] = coords[:, 1] * (self.new_h * 1.0 / self.h)
        return coords

    def apply_segmentation(self, segmentation):
        segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
        return segmentation

    def inverse(self):
        return ResizeTransform(self.new_h, self.new_w, self.h, self.w, self.interp)


class RotationTransform(Transform):
    """
    Returns a copy of the image, rotated the given number of degrees
    counter clockwise around its center.
    """

    def __init__(self, h, w, angle, expand=True, center=None, interp=None):
        """
        Args:
            h, w (int): original image size
            angle (float): degrees for rotation
            expand (bool): choose if the image should be resized to fit the whole
                rotated image (default), or simply cropped
            center (tuple (width, height)): coordinates of the rotation center.
                If left to None, the center will be fit to the center of each image.
                center has no effect if expand=True because it only affects shifting.
            interp: cv2 interpolation method, default cv2.INTER_LINEAR
        """
        super().__init__()
        image_center = np.array((w / 2, h / 2))
        if center is None:
            center = image_center
        if interp is None:
            interp = cv2.INTER_LINEAR
        abs_cos, abs_sin = (abs(np.cos(np.deg2rad(angle))), abs(np.sin(np.deg2rad(angle))))
        if expand:
            # find the new width and height bounds
            bound_w, bound_h = np.rint(
                [h * abs_sin + w * abs_cos, h * abs_cos + w * abs_sin]
            ).astype(int)
        else:
            bound_w, bound_h = w, h

        self._set_attributes(locals())
        self.rm_coords = self.create_rotation_matrix()
        # Needed because of this problem https://github.com/opencv/opencv/issues/11784
        self.rm_image = self.create_rotation_matrix(offset=-0.5)

    def apply_image(self, img, interp=None):
        """
        img should be a numpy array, formatted as Height * Width * Nchannels
        """
        if len(img) == 0 or self.angle % 360 == 0:
            return img
        assert img.shape[:2] == (self.h, self.w)
        interp = interp if interp is not None else self.interp
        return cv2.warpAffine(img, self.rm_image, (self.bound_w, self.bound_h), flags=interp)

    def apply_coords(self, coords):
        """
        coords should be a N * 2 array-like, containing N couples of (x, y) points
        """
        coords = np.asarray(coords, dtype=float)
        if len(coords) == 0 or self.angle % 360 == 0:
            return coords
        return cv2.transform(coords[:, np.newaxis, :], self.rm_coords)[:, 0, :]

    def apply_segmentation(self, segmentation):
        segmentation = self.apply_image(segmentation, interp=cv2.INTER_NEAREST)
        return segmentation

    def create_rotation_matrix(self, offset=0):
        center = (self.center[0] + offset, self.center[1] + offset)
        rm = cv2.getRotationMatrix2D(tuple(center), self.angle, 1)
        if self.expand:
            # Find the coordinates of the center of rotation in the new image
            # The only point for which we know the future coordinates is the center of the image
            rot_im_center = cv2.transform(self.image_center[None, None, :] + offset, rm)[0, 0, :]
            new_center = np.array([self.bound_w / 2, self.bound_h / 2]) + offset - rot_im_center
            # shift the rotation center to the new coordinates
            rm[:, 2] += new_center
        return rm

    def inverse(self):
        """
        The inverse is to rotate it back with expand, and crop to get the original shape.
        """
        if not self.expand:  # Not possible to inverse if a part of the image is lost
            raise NotImplementedError()
        rotation = RotationTransform(
            self.bound_h, self.bound_w, -self.angle, True, None, self.interp
        )
        crop = CropTransform(
            (rotation.bound_w - self.w) // 2, (rotation.bound_h - self.h) // 2, self.w, self.h
        )
        return TransformList([rotation, crop])


class ColorTransform(Transform):
    """
    Generic wrapper for any photometric transforms.
    These transformations should only affect the color space and
    not the coordinate space of the image (e.g. annotation
    coordinates such as bounding boxes should not be changed)
    """

    def __init__(self, op):
        """
        Args:
            op (Callable): operation to be applied to the image,
                which takes in an ndarray and returns an ndarray.
        """
        if not callable(op):
            raise ValueError("op parameter should be callable")
        super().__init__()
        self._set_attributes(locals())

    def apply_image(self, img):
        return self.op(img)

    def apply_coords(self, coords):
        return coords

    def inverse(self):
        return NoOpTransform()

    def apply_segmentation(self, segmentation):
        return segmentation


class PILColorTransform(ColorTransform):
    """
    Generic wrapper for PIL Photometric image transforms,
    which affect the color space and not the coordinate
    space of the image
    """

    def __init__(self, op):
        """
        Args:
            op (Callable): operation to be applied to the image,
                which takes in a PIL Image and returns a transformed
                PIL Image.
                For reference on possible operations see:
                - https://pillow.readthedocs.io/en/stable/
        """
        if not callable(op):
            raise ValueError("op parameter should be callable")
        super().__init__(op)

    def apply_image(self, img):
        img = Image.fromarray(img)
        return np.asarray(super().apply_image(img))


def HFlip_rotated_box(transform, rotated_boxes):
    """
    Apply the horizontal flip transform on rotated boxes.

    Args:
        rotated_boxes (ndarray): Nx5 floating point array of
            (x_center, y_center, width, height, angle_degrees) format
            in absolute coordinates.
    """
    # Transform x_center
    rotated_boxes[:, 0] = transform.width - rotated_boxes[:, 0]
    # Transform angle
    rotated_boxes[:, 4] = -rotated_boxes[:, 4]
    return rotated_boxes


def Resize_rotated_box(transform, rotated_boxes):
    """
    Apply the resizing transform on rotated boxes. For details of how these (approximation)
    formulas are derived, please refer to :meth:`RotatedBoxes.scale`.

    Args:
        rotated_boxes (ndarray): Nx5 floating point array of
            (x_center, y_center, width, height, angle_degrees) format
            in absolute coordinates.
    """
    scale_factor_x = transform.new_w * 1.0 / transform.w
    scale_factor_y = transform.new_h * 1.0 / transform.h
    rotated_boxes[:, 0] *= scale_factor_x
    rotated_boxes[:, 1] *= scale_factor_y
    theta = rotated_boxes[:, 4] * np.pi / 180.0
    c = np.cos(theta)
    s = np.sin(theta)
    rotated_boxes[:, 2] *= np.sqrt(np.square(scale_factor_x * c) + np.square(scale_factor_y * s))
    rotated_boxes[:, 3] *= np.sqrt(np.square(scale_factor_x * s) + np.square(scale_factor_y * c))
    rotated_boxes[:, 4] = np.arctan2(scale_factor_x * s, scale_factor_y * c) * 180 / np.pi

    return rotated_boxes


HFlipTransform.register_type("rotated_box", HFlip_rotated_box)
ResizeTransform.register_type("rotated_box", Resize_rotated_box)

# not necessary any more with latest fvcore
NoOpTransform.register_type("rotated_box", lambda t, x: x)
CatVTON/detectron2/export/README.md
ADDED
@@ -0,0 +1,15 @@

This directory contains code to prepare a detectron2 model for deployment.
Currently it supports exporting a detectron2 model to TorchScript, ONNX, or (deprecated) Caffe2 format.

Please see the [documentation](https://detectron2.readthedocs.io/tutorials/deployment.html) for its usage.


### Acknowledgements

Thanks to the Mobile Vision team at Facebook for developing the Caffe2 conversion tools.

Thanks to the Computing Platform Department - PAI team at Alibaba Group (@bddpqq, @chenbohua3) who
helped export Detectron2 models to TorchScript.

Thanks to the ONNX Converter team at Microsoft who helped export Detectron2 models to ONNX.
CatVTON/detectron2/export/__init__.py
ADDED
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-

import warnings

from .flatten import TracingAdapter
from .torchscript import dump_torchscript_IR, scripting_with_instances

try:
    from caffe2.proto import caffe2_pb2 as _tmp
    from caffe2.python import core

    # caffe2 is optional
except ImportError:
    pass
else:
    from .api import *


# TODO: Update ONNX Opset version and run tests when a newer PyTorch is supported
STABLE_ONNX_OPSET_VERSION = 11


def add_export_config(cfg):
    warnings.warn(
        "add_export_config has been deprecated and behaves as no-op function.", DeprecationWarning
    )
    return cfg


__all__ = [k for k in globals().keys() if not k.startswith("_")]
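
A sketch of how the re-exported TracingAdapter is typically used, based on detectron2's deployment tutorial; `model` and `inputs` are placeholders for a weight-loaded detectron2 model and its usual [{"image": tensor}] inference inputs:

import torch
from detectron2.export import TracingAdapter

def trace_for_deployment(model, inputs):
    # TracingAdapter flattens detectron2's dict-based inputs/outputs into
    # plain tensors so that torch.jit.trace can handle them.
    adapter = TracingAdapter(model, inputs)
    traced = torch.jit.trace(adapter, adapter.flattened_inputs)
    return traced  # the traced module consumes and produces flattened tensors
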
CatVTON/detectron2/export/api.py
ADDED
@@ -0,0 +1,230 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import copy
import logging
import os
import torch
from caffe2.proto import caffe2_pb2
from torch import nn

from detectron2.config import CfgNode
from detectron2.utils.file_io import PathManager

from .caffe2_inference import ProtobufDetectionModel
from .caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP, convert_batched_inputs_to_c2_format
from .shared import get_pb_arg_vali, get_pb_arg_vals, save_graph

__all__ = [
    "Caffe2Model",
    "Caffe2Tracer",
]


class Caffe2Tracer:
    """
    Make a detectron2 model traceable with Caffe2 operators.
    This class creates a traceable version of a detectron2 model which:

    1. Rewrites parts of the model using ops in Caffe2. Note that some ops do
       not have GPU implementation in Caffe2.
    2. Removes post-processing and only produces raw layer outputs

    After making a traceable model, the class provides methods to export such a
    model to different deployment formats.
    Exported graphs produced by this class take two input tensors:

    1. (1, C, H, W) float "data" which is an image (usually in [0, 255]).
       (H, W) often has to be padded to a multiple of 32 (depending on the model
       architecture).
    2. 1x3 float "im_info", each row of which is (height, width, 1.0).
       Height and width are true image shapes before padding.

    The class currently only supports models using builtin meta architectures.
    Batch inference is not supported, and contributions are welcome.
    """

    def __init__(self, cfg: CfgNode, model: nn.Module, inputs):
        """
        Args:
            cfg (CfgNode): a detectron2 config used to construct a caffe2-compatible model.
            model (nn.Module): An original pytorch model. Must be among a few official models
                in detectron2 that can be converted to become caffe2-compatible automatically.
                Weights have to be already loaded to this model.
            inputs: sample inputs that the given model takes for inference.
                Will be used to trace the model. For most models, random inputs with
                no detected objects will not work as they lead to wrong traces.
        """
        assert isinstance(cfg, CfgNode), cfg
        assert isinstance(model, torch.nn.Module), type(model)

        # TODO make it support custom models, by passing in c2 model directly
        C2MetaArch = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[cfg.MODEL.META_ARCHITECTURE]
        self.traceable_model = C2MetaArch(cfg, copy.deepcopy(model))
        self.inputs = inputs
        self.traceable_inputs = self.traceable_model.get_caffe2_inputs(inputs)

    def export_caffe2(self):
        """
        Export the model to Caffe2's protobuf format.
        The returned object can be saved with its :meth:`.save_protobuf()` method.
        The result can be loaded and executed using Caffe2 runtime.

        Returns:
            :class:`Caffe2Model`
        """
        from .caffe2_export import export_caffe2_detection_model

        predict_net, init_net = export_caffe2_detection_model(
            self.traceable_model, self.traceable_inputs
        )
        return Caffe2Model(predict_net, init_net)

    def export_onnx(self):
        """
        Export the model to ONNX format.
        Note that the exported model contains custom ops only available in caffe2, therefore it
        cannot be directly executed by other runtimes (such as onnxruntime or TensorRT).
        Post-processing or transformation passes may be applied on the model to accommodate
        different runtimes, but we currently do not provide support for them.

        Returns:
            onnx.ModelProto: an onnx model.
        """
        from .caffe2_export import export_onnx_model as export_onnx_model_impl

        return export_onnx_model_impl(self.traceable_model, (self.traceable_inputs,))

    def export_torchscript(self):
        """
        Export the model to a ``torch.jit.TracedModule`` by tracing.
        The returned object can be saved to a file by ``.save()``.

        Returns:
            torch.jit.TracedModule: a torch TracedModule
        """
        logger = logging.getLogger(__name__)
        logger.info("Tracing the model with torch.jit.trace ...")
        with torch.no_grad():
            return torch.jit.trace(self.traceable_model, (self.traceable_inputs,))


class Caffe2Model(nn.Module):
    """
    A wrapper around the traced model in Caffe2's protobuf format.
    The exported graph has different inputs/outputs from the original Pytorch
    model, as explained in :class:`Caffe2Tracer`. This class wraps around the
    exported graph to simulate the same interface as the original Pytorch model.
    It also provides functions to save/load models in Caffe2's format.

    Examples:
    ::
        c2_model = Caffe2Tracer(cfg, torch_model, inputs).export_caffe2()
        inputs = [{"image": img_tensor_CHW}]
        outputs = c2_model(inputs)
        orig_outputs = torch_model(inputs)
    """

    def __init__(self, predict_net, init_net):
        super().__init__()
        self.eval()  # always in eval mode
        self._predict_net = predict_net
        self._init_net = init_net
        self._predictor = None

    __init__.__HIDE_SPHINX_DOC__ = True

    @property
    def predict_net(self):
        """
        caffe2.core.Net: the underlying caffe2 predict net
        """
        return self._predict_net

    @property
    def init_net(self):
        """
        caffe2.core.Net: the underlying caffe2 init net
        """
        return self._init_net

    def save_protobuf(self, output_dir):
        """
        Save the model as caffe2's protobuf format.
        It saves the following files:

        * "model.pb": definition of the graph. Can be visualized with
          tools like `netron <https://github.com/lutzroeder/netron>`_.
        * "model_init.pb": model parameters
        * "model.pbtxt": human-readable definition of the graph. Not
          needed for deployment.

        Args:
            output_dir (str): the output directory to save protobuf files.
        """
        logger = logging.getLogger(__name__)
        logger.info("Saving model to {} ...".format(output_dir))
        if not PathManager.exists(output_dir):
            PathManager.mkdirs(output_dir)

        with PathManager.open(os.path.join(output_dir, "model.pb"), "wb") as f:
            f.write(self._predict_net.SerializeToString())
        with PathManager.open(os.path.join(output_dir, "model.pbtxt"), "w") as f:
            f.write(str(self._predict_net))
        with PathManager.open(os.path.join(output_dir, "model_init.pb"), "wb") as f:
            f.write(self._init_net.SerializeToString())

    def save_graph(self, output_file, inputs=None):
        """
        Save the graph as SVG format.

        Args:
            output_file (str): a SVG file
            inputs: optional inputs given to the model.
                If given, the inputs will be used to run the graph to record
                the shape of every tensor. The shape information will be
                saved together with the graph.
        """
        from .caffe2_export import run_and_save_graph

        if inputs is None:
            save_graph(self._predict_net, output_file, op_only=False)
        else:
            size_divisibility = get_pb_arg_vali(self._predict_net, "size_divisibility", 0)
            device = get_pb_arg_vals(self._predict_net, "device", b"cpu").decode("ascii")
            inputs = convert_batched_inputs_to_c2_format(inputs, size_divisibility, device)
            inputs = [x.cpu().numpy() for x in inputs]
            run_and_save_graph(self._predict_net, self._init_net, inputs, output_file)

    @staticmethod
    def load_protobuf(dir):
        """
        Args:
            dir (str): a directory used to save Caffe2Model with
                :meth:`save_protobuf`.
                The files "model.pb" and "model_init.pb" are needed.

        Returns:
            Caffe2Model: the caffe2 model loaded from this directory.
        """
        predict_net = caffe2_pb2.NetDef()
        with PathManager.open(os.path.join(dir, "model.pb"), "rb") as f:
            predict_net.ParseFromString(f.read())

        init_net = caffe2_pb2.NetDef()
        with PathManager.open(os.path.join(dir, "model_init.pb"), "rb") as f:
            init_net.ParseFromString(f.read())

        return Caffe2Model(predict_net, init_net)

    def __call__(self, inputs):
        """
        An interface that wraps around a Caffe2 model and mimics detectron2's models'
        input/output format. See details about the format at :doc:`/tutorials/models`.
        This is used to compare the outputs of the caffe2 model with those of its
        original torch model.

        Due to the extra conversion between Pytorch/Caffe2, this method is not meant for
        benchmarking. Because of the conversion, this method also has a dependency
        on detectron2 in order to convert to detectron2's output format.
        """
        if self._predictor is None:
            self._predictor = ProtobufDetectionModel(self._predict_net, self._init_net)
        return self._predictor(inputs)
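
The typical export round trip with the API above, expanding the docstring example (a sketch; `cfg`, `torch_model`, `inputs`, and the output directory are placeholders for a real config, weight-loaded model, sample inputs, and path):

tracer = Caffe2Tracer(cfg, torch_model, inputs)
c2_model = tracer.export_caffe2()  # a Caffe2Model wrapping predict/init nets
c2_model.save_protobuf("./c2_export")  # writes model.pb, model_init.pb, model.pbtxt
reloaded = Caffe2Model.load_protobuf("./c2_export")
outputs = reloaded(inputs)  # mimics the original model's input/output format
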
CatVTON/detectron2/export/c10.py
ADDED
@@ -0,0 +1,571 @@
# Copyright (c) Facebook, Inc. and its affiliates.

import math
from typing import Dict
import torch
import torch.nn.functional as F

from detectron2.layers import ShapeSpec, cat
from detectron2.layers.roi_align_rotated import ROIAlignRotated
from detectron2.modeling import poolers
from detectron2.modeling.proposal_generator import rpn
from detectron2.modeling.roi_heads.mask_head import mask_rcnn_inference
from detectron2.structures import Boxes, ImageList, Instances, Keypoints, RotatedBoxes

from .shared import alias, to_device


"""
This file contains caffe2-compatible implementation of several detectron2 components.
"""


class Caffe2Boxes(Boxes):
    """
    Representing a list of detectron2.structures.Boxes from a minibatch, where each box
    is represented by a 5d vector (batch index + 4 coordinates), or a 6d vector
    (batch index + 5 coordinates) for RotatedBoxes.
    """

    def __init__(self, tensor):
        assert isinstance(tensor, torch.Tensor)
        assert tensor.dim() == 2 and tensor.size(-1) in [4, 5, 6], tensor.size()
        # TODO: make tensor immutable when dim is Nx5 for Boxes,
        # and Nx6 for RotatedBoxes?
        self.tensor = tensor


# TODO clean up this class, maybe just extend Instances
class InstancesList:
    """
    Tensor representation of a list of Instances objects for a batch of images.

    When dealing with a batch of images with Caffe2 ops, a list of bboxes
    (instances) is usually represented by a single Tensor with size
    (sigma(Ni), 5) or (sigma(Ni), 4) plus a batch split Tensor. This class
    provides common functions to convert between these two representations.
    """

    def __init__(self, im_info, indices, extra_fields=None):
        # [N, 3] -> (H, W, Scale)
        self.im_info = im_info
        # [N,] -> index of the batch to which the instance belongs
        self.indices = indices
        # [N, ...]
        self.batch_extra_fields = extra_fields or {}

        self.image_size = self.im_info

    def get_fields(self):
        """like `get_fields` in the Instances object,
        but return each field in tensor representations"""
        ret = {}
        for k, v in self.batch_extra_fields.items():
            # if isinstance(v, torch.Tensor):
            #     tensor_rep = v
            # elif isinstance(v, (Boxes, Keypoints)):
            #     tensor_rep = v.tensor
            # else:
            #     raise ValueError("Can't find tensor representation for: {}".format())
            ret[k] = v
        return ret

    def has(self, name):
        return name in self.batch_extra_fields

    def set(self, name, value):
        # len(tensor) is a bad practice that generates ONNX constants during tracing.
        # Although not a problem for the `assert` statement below, the torch ONNX exporter
        # still raises a misleading warning, as it does not know this call comes from `assert`.
        if isinstance(value, Boxes):
            data_len = value.tensor.shape[0]
        elif isinstance(value, torch.Tensor):
            data_len = value.shape[0]
        else:
            data_len = len(value)
        if len(self.batch_extra_fields):
            assert (
                len(self) == data_len
            ), "Adding a field of length {} to an Instances of length {}".format(
                data_len, len(self)
            )
        self.batch_extra_fields[name] = value

    def __getattr__(self, name):
        if name not in self.batch_extra_fields:
            raise AttributeError("Cannot find field '{}' in the given Instances!".format(name))
        return self.batch_extra_fields[name]

    def __len__(self):
        return len(self.indices)

    def flatten(self):
        ret = []
        for _, v in self.batch_extra_fields.items():
            if isinstance(v, (Boxes, Keypoints)):
                ret.append(v.tensor)
            else:
                ret.append(v)
        return ret

    @staticmethod
    def to_d2_instances_list(instances_list):
        """
        Convert InstancesList to List[Instances]. The input `instances_list` can
        also be a List[Instances], in which case this method is a no-op.
        """
        if not isinstance(instances_list, InstancesList):
            assert all(isinstance(x, Instances) for x in instances_list)
            return instances_list

        ret = []
        for i, info in enumerate(instances_list.im_info):
            instances = Instances(torch.Size([int(info[0].item()), int(info[1].item())]))

            ids = instances_list.indices == i
            for k, v in instances_list.batch_extra_fields.items():
                if isinstance(v, torch.Tensor):
                    instances.set(k, v[ids])
                    continue
                elif isinstance(v, Boxes):
                    instances.set(k, v[ids, -4:])
                    continue

                target_type, tensor_source = v
                assert isinstance(tensor_source, torch.Tensor)
                assert tensor_source.shape[0] == instances_list.indices.shape[0]
                tensor_source = tensor_source[ids]

                if issubclass(target_type, Boxes):
                    instances.set(k, Boxes(tensor_source[:, -4:]))
                elif issubclass(target_type, Keypoints):
                    instances.set(k, Keypoints(tensor_source))
                elif issubclass(target_type, torch.Tensor):
                    instances.set(k, tensor_source)
                else:
                    raise ValueError("Can't handle target type: {}".format(target_type))

            ret.append(instances)
        return ret


class Caffe2Compatible:
    """
    A model can inherit this class to indicate that it can be traced and deployed with caffe2.
    """

    def _get_tensor_mode(self):
        return self._tensor_mode

    def _set_tensor_mode(self, v):
        self._tensor_mode = v

    tensor_mode = property(_get_tensor_mode, _set_tensor_mode)
    """
    If true, the model expects C2-style tensor-only inputs/outputs format.
    """


class Caffe2RPN(Caffe2Compatible, rpn.RPN):
    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        ret = super(Caffe2Compatible, cls).from_config(cfg, input_shape)
        assert tuple(cfg.MODEL.RPN.BBOX_REG_WEIGHTS) == (1.0, 1.0, 1.0, 1.0) or tuple(
            cfg.MODEL.RPN.BBOX_REG_WEIGHTS
        ) == (1.0, 1.0, 1.0, 1.0, 1.0)
        return ret

    def _generate_proposals(
        self, images, objectness_logits_pred, anchor_deltas_pred, gt_instances=None
    ):
        assert isinstance(images, ImageList)
        if self.tensor_mode:
            im_info = images.image_sizes
        else:
            im_info = torch.tensor([[im_sz[0], im_sz[1], 1.0] for im_sz in images.image_sizes]).to(
                images.tensor.device
            )
        assert isinstance(im_info, torch.Tensor)

        rpn_rois_list = []
        rpn_roi_probs_list = []
        for scores, bbox_deltas, cell_anchors_tensor, feat_stride in zip(
            objectness_logits_pred,
            anchor_deltas_pred,
            [b for (n, b) in self.anchor_generator.cell_anchors.named_buffers()],
            self.anchor_generator.strides,
        ):
            scores = scores.detach()
            bbox_deltas = bbox_deltas.detach()

            rpn_rois, rpn_roi_probs = torch.ops._caffe2.GenerateProposals(
                scores,
                bbox_deltas,
                im_info,
                cell_anchors_tensor,
                spatial_scale=1.0 / feat_stride,
                pre_nms_topN=self.pre_nms_topk[self.training],
                post_nms_topN=self.post_nms_topk[self.training],
                nms_thresh=self.nms_thresh,
                min_size=self.min_box_size,
                # correct_transform_coords=True,  # deprecated argument
                angle_bound_on=True,  # Default
                angle_bound_lo=-180,
                angle_bound_hi=180,
                clip_angle_thresh=1.0,  # Default
                legacy_plus_one=False,
            )
            rpn_rois_list.append(rpn_rois)
            rpn_roi_probs_list.append(rpn_roi_probs)

        # For FPN in D2, in RPN all proposals from different levels are concatenated
        # together, ranked and picked by top post_nms_topk. Then in ROIPooler
        # it calculates level_assignments and calls the RoIAlign from
        # the corresponding level.

        if len(objectness_logits_pred) == 1:
            rpn_rois = rpn_rois_list[0]
            rpn_roi_probs = rpn_roi_probs_list[0]
        else:
            assert len(rpn_rois_list) == len(rpn_roi_probs_list)
            rpn_post_nms_topN = self.post_nms_topk[self.training]

            device = rpn_rois_list[0].device
            input_list = [to_device(x, "cpu") for x in (rpn_rois_list + rpn_roi_probs_list)]

            # TODO remove this after confirming rpn_max_level/rpn_min_level
            # is not needed in CollectRpnProposals.
            feature_strides = list(self.anchor_generator.strides)
            rpn_min_level = int(math.log2(feature_strides[0]))
            rpn_max_level = int(math.log2(feature_strides[-1]))
            assert (rpn_max_level - rpn_min_level + 1) == len(
                rpn_rois_list
            ), "CollectRpnProposals requires continuous levels"

            rpn_rois = torch.ops._caffe2.CollectRpnProposals(
                input_list,
                # NOTE: in the current implementation, rpn_max_level and rpn_min_level
                # are not needed, only the subtraction of the two matters, and it
                # can be inferred from the number of inputs. Keep them now for
                # consistency.
                rpn_max_level=2 + len(rpn_rois_list) - 1,
                rpn_min_level=2,
                rpn_post_nms_topN=rpn_post_nms_topN,
            )
            rpn_rois = to_device(rpn_rois, device)
            rpn_roi_probs = []

        proposals = self.c2_postprocess(im_info, rpn_rois, rpn_roi_probs, self.tensor_mode)
        return proposals, {}

    def forward(self, images, features, gt_instances=None):
        assert not self.training
        features = [features[f] for f in self.in_features]
        objectness_logits_pred, anchor_deltas_pred = self.rpn_head(features)
        return self._generate_proposals(
            images,
            objectness_logits_pred,
            anchor_deltas_pred,
            gt_instances,
        )

    @staticmethod
    def c2_postprocess(im_info, rpn_rois, rpn_roi_probs, tensor_mode):
        proposals = InstancesList(
            im_info=im_info,
            indices=rpn_rois[:, 0],
            extra_fields={
                "proposal_boxes": Caffe2Boxes(rpn_rois),
                "objectness_logits": (torch.Tensor, rpn_roi_probs),
            },
        )
        if not tensor_mode:
            proposals = InstancesList.to_d2_instances_list(proposals)
        else:
            proposals = [proposals]
        return proposals


class Caffe2ROIPooler(Caffe2Compatible, poolers.ROIPooler):
    @staticmethod
    def c2_preprocess(box_lists):
        assert all(isinstance(x, Boxes) for x in box_lists)
        if all(isinstance(x, Caffe2Boxes) for x in box_lists):
            # input is pure-tensor based
            assert len(box_lists) == 1
            pooler_fmt_boxes = box_lists[0].tensor
        else:
            pooler_fmt_boxes = poolers.convert_boxes_to_pooler_format(box_lists)
        return pooler_fmt_boxes

    def forward(self, x, box_lists):
        assert not self.training

        pooler_fmt_boxes = self.c2_preprocess(box_lists)
        num_level_assignments = len(self.level_poolers)

        if num_level_assignments == 1:
            if isinstance(self.level_poolers[0], ROIAlignRotated):
                c2_roi_align = torch.ops._caffe2.RoIAlignRotated
                aligned = True
            else:
                c2_roi_align = torch.ops._caffe2.RoIAlign
                aligned = self.level_poolers[0].aligned

            x0 = x[0]
            if x0.is_quantized:
                x0 = x0.dequantize()

            out = c2_roi_align(
                x0,
                pooler_fmt_boxes,
                order="NCHW",
                spatial_scale=float(self.level_poolers[0].spatial_scale),
                pooled_h=int(self.output_size[0]),
                pooled_w=int(self.output_size[1]),
                sampling_ratio=int(self.level_poolers[0].sampling_ratio),
                aligned=aligned,
            )
            return out

        device = pooler_fmt_boxes.device
        assert (
            self.max_level - self.min_level + 1 == 4
        ), "Currently DistributeFpnProposals only supports 4 levels"
        fpn_outputs = torch.ops._caffe2.DistributeFpnProposals(
            to_device(pooler_fmt_boxes, "cpu"),
            roi_canonical_scale=self.canonical_box_size,
            roi_canonical_level=self.canonical_level,
            roi_max_level=self.max_level,
            roi_min_level=self.min_level,
            legacy_plus_one=False,
        )
        fpn_outputs = [to_device(x, device) for x in fpn_outputs]
|
| 342 |
+
|
| 343 |
+
rois_fpn_list = fpn_outputs[:-1]
|
| 344 |
+
rois_idx_restore_int32 = fpn_outputs[-1]
|
| 345 |
+
|
| 346 |
+
roi_feat_fpn_list = []
|
| 347 |
+
for roi_fpn, x_level, pooler in zip(rois_fpn_list, x, self.level_poolers):
|
| 348 |
+
if isinstance(pooler, ROIAlignRotated):
|
| 349 |
+
c2_roi_align = torch.ops._caffe2.RoIAlignRotated
|
| 350 |
+
aligned = True
|
| 351 |
+
else:
|
| 352 |
+
c2_roi_align = torch.ops._caffe2.RoIAlign
|
| 353 |
+
aligned = bool(pooler.aligned)
|
| 354 |
+
|
| 355 |
+
if x_level.is_quantized:
|
| 356 |
+
x_level = x_level.dequantize()
|
| 357 |
+
|
| 358 |
+
roi_feat_fpn = c2_roi_align(
|
| 359 |
+
x_level,
|
| 360 |
+
roi_fpn,
|
| 361 |
+
order="NCHW",
|
| 362 |
+
spatial_scale=float(pooler.spatial_scale),
|
| 363 |
+
pooled_h=int(self.output_size[0]),
|
| 364 |
+
pooled_w=int(self.output_size[1]),
|
| 365 |
+
sampling_ratio=int(pooler.sampling_ratio),
|
| 366 |
+
aligned=aligned,
|
| 367 |
+
)
|
| 368 |
+
roi_feat_fpn_list.append(roi_feat_fpn)
|
| 369 |
+
|
| 370 |
+
roi_feat_shuffled = cat(roi_feat_fpn_list, dim=0)
|
| 371 |
+
assert roi_feat_shuffled.numel() > 0 and rois_idx_restore_int32.numel() > 0, (
|
| 372 |
+
"Caffe2 export requires tracing with a model checkpoint + input that can produce valid"
|
| 373 |
+
" detections. But no detections were obtained with the given checkpoint and input!"
|
| 374 |
+
)
|
| 375 |
+
roi_feat = torch.ops._caffe2.BatchPermutation(roi_feat_shuffled, rois_idx_restore_int32)
|
| 376 |
+
return roi_feat
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def caffe2_fast_rcnn_outputs_inference(tensor_mode, box_predictor, predictions, proposals):
|
| 380 |
+
"""equivalent to FastRCNNOutputLayers.inference"""
|
| 381 |
+
num_classes = box_predictor.num_classes
|
| 382 |
+
score_thresh = box_predictor.test_score_thresh
|
| 383 |
+
nms_thresh = box_predictor.test_nms_thresh
|
| 384 |
+
topk_per_image = box_predictor.test_topk_per_image
|
| 385 |
+
is_rotated = len(box_predictor.box2box_transform.weights) == 5
|
| 386 |
+
|
| 387 |
+
if is_rotated:
|
| 388 |
+
box_dim = 5
|
| 389 |
+
assert box_predictor.box2box_transform.weights[4] == 1, (
|
| 390 |
+
"The weights for Rotated BBoxTransform in C2 have only 4 dimensions,"
|
| 391 |
+
+ " thus enforcing the angle weight to be 1 for now"
|
| 392 |
+
)
|
| 393 |
+
box2box_transform_weights = box_predictor.box2box_transform.weights[:4]
|
| 394 |
+
else:
|
| 395 |
+
box_dim = 4
|
| 396 |
+
box2box_transform_weights = box_predictor.box2box_transform.weights
|
| 397 |
+
|
| 398 |
+
class_logits, box_regression = predictions
|
| 399 |
+
if num_classes + 1 == class_logits.shape[1]:
|
| 400 |
+
class_prob = F.softmax(class_logits, -1)
|
| 401 |
+
else:
|
| 402 |
+
assert num_classes == class_logits.shape[1]
|
| 403 |
+
class_prob = F.sigmoid(class_logits)
|
| 404 |
+
# BoxWithNMSLimit will infer num_classes from the shape of the class_prob
|
| 405 |
+
# So append a zero column as placeholder for the background class
|
| 406 |
+
class_prob = torch.cat((class_prob, torch.zeros(class_prob.shape[0], 1)), dim=1)
|
| 407 |
+
|
| 408 |
+
assert box_regression.shape[1] % box_dim == 0
|
| 409 |
+
cls_agnostic_bbox_reg = box_regression.shape[1] // box_dim == 1
|
| 410 |
+
|
| 411 |
+
input_tensor_mode = proposals[0].proposal_boxes.tensor.shape[1] == box_dim + 1
|
| 412 |
+
|
| 413 |
+
proposal_boxes = proposals[0].proposal_boxes
|
| 414 |
+
if isinstance(proposal_boxes, Caffe2Boxes):
|
| 415 |
+
rois = Caffe2Boxes.cat([p.proposal_boxes for p in proposals])
|
| 416 |
+
elif isinstance(proposal_boxes, RotatedBoxes):
|
| 417 |
+
rois = RotatedBoxes.cat([p.proposal_boxes for p in proposals])
|
| 418 |
+
elif isinstance(proposal_boxes, Boxes):
|
| 419 |
+
rois = Boxes.cat([p.proposal_boxes for p in proposals])
|
| 420 |
+
else:
|
| 421 |
+
raise NotImplementedError(
|
| 422 |
+
'Expected proposals[0].proposal_boxes to be type "Boxes", '
|
| 423 |
+
f"instead got {type(proposal_boxes)}"
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
device, dtype = rois.tensor.device, rois.tensor.dtype
|
| 427 |
+
if input_tensor_mode:
|
| 428 |
+
im_info = proposals[0].image_size
|
| 429 |
+
rois = rois.tensor
|
| 430 |
+
else:
|
| 431 |
+
im_info = torch.tensor([[sz[0], sz[1], 1.0] for sz in [x.image_size for x in proposals]])
|
| 432 |
+
batch_ids = cat(
|
| 433 |
+
[
|
| 434 |
+
torch.full((b, 1), i, dtype=dtype, device=device)
|
| 435 |
+
for i, b in enumerate(len(p) for p in proposals)
|
| 436 |
+
],
|
| 437 |
+
dim=0,
|
| 438 |
+
)
|
| 439 |
+
rois = torch.cat([batch_ids, rois.tensor], dim=1)
|
| 440 |
+
|
| 441 |
+
roi_pred_bbox, roi_batch_splits = torch.ops._caffe2.BBoxTransform(
|
| 442 |
+
to_device(rois, "cpu"),
|
| 443 |
+
to_device(box_regression, "cpu"),
|
| 444 |
+
to_device(im_info, "cpu"),
|
| 445 |
+
weights=box2box_transform_weights,
|
| 446 |
+
apply_scale=True,
|
| 447 |
+
rotated=is_rotated,
|
| 448 |
+
angle_bound_on=True,
|
| 449 |
+
angle_bound_lo=-180,
|
| 450 |
+
angle_bound_hi=180,
|
| 451 |
+
clip_angle_thresh=1.0,
|
| 452 |
+
legacy_plus_one=False,
|
| 453 |
+
)
|
| 454 |
+
roi_pred_bbox = to_device(roi_pred_bbox, device)
|
| 455 |
+
roi_batch_splits = to_device(roi_batch_splits, device)
|
| 456 |
+
|
| 457 |
+
nms_outputs = torch.ops._caffe2.BoxWithNMSLimit(
|
| 458 |
+
to_device(class_prob, "cpu"),
|
| 459 |
+
to_device(roi_pred_bbox, "cpu"),
|
| 460 |
+
to_device(roi_batch_splits, "cpu"),
|
| 461 |
+
score_thresh=float(score_thresh),
|
| 462 |
+
nms=float(nms_thresh),
|
| 463 |
+
detections_per_im=int(topk_per_image),
|
| 464 |
+
soft_nms_enabled=False,
|
| 465 |
+
soft_nms_method="linear",
|
| 466 |
+
soft_nms_sigma=0.5,
|
| 467 |
+
soft_nms_min_score_thres=0.001,
|
| 468 |
+
rotated=is_rotated,
|
| 469 |
+
cls_agnostic_bbox_reg=cls_agnostic_bbox_reg,
|
| 470 |
+
input_boxes_include_bg_cls=False,
|
| 471 |
+
output_classes_include_bg_cls=False,
|
| 472 |
+
legacy_plus_one=False,
|
| 473 |
+
)
|
| 474 |
+
roi_score_nms = to_device(nms_outputs[0], device)
|
| 475 |
+
roi_bbox_nms = to_device(nms_outputs[1], device)
|
| 476 |
+
roi_class_nms = to_device(nms_outputs[2], device)
|
| 477 |
+
roi_batch_splits_nms = to_device(nms_outputs[3], device)
|
| 478 |
+
roi_keeps_nms = to_device(nms_outputs[4], device)
|
| 479 |
+
roi_keeps_size_nms = to_device(nms_outputs[5], device)
|
| 480 |
+
if not tensor_mode:
|
| 481 |
+
roi_class_nms = roi_class_nms.to(torch.int64)
|
| 482 |
+
|
| 483 |
+
roi_batch_ids = cat(
|
| 484 |
+
[
|
| 485 |
+
torch.full((b, 1), i, dtype=dtype, device=device)
|
| 486 |
+
for i, b in enumerate(int(x.item()) for x in roi_batch_splits_nms)
|
| 487 |
+
],
|
| 488 |
+
dim=0,
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
roi_class_nms = alias(roi_class_nms, "class_nms")
|
| 492 |
+
roi_score_nms = alias(roi_score_nms, "score_nms")
|
| 493 |
+
roi_bbox_nms = alias(roi_bbox_nms, "bbox_nms")
|
| 494 |
+
roi_batch_splits_nms = alias(roi_batch_splits_nms, "batch_splits_nms")
|
| 495 |
+
roi_keeps_nms = alias(roi_keeps_nms, "keeps_nms")
|
| 496 |
+
roi_keeps_size_nms = alias(roi_keeps_size_nms, "keeps_size_nms")
|
| 497 |
+
|
| 498 |
+
results = InstancesList(
|
| 499 |
+
im_info=im_info,
|
| 500 |
+
indices=roi_batch_ids[:, 0],
|
| 501 |
+
extra_fields={
|
| 502 |
+
"pred_boxes": Caffe2Boxes(roi_bbox_nms),
|
| 503 |
+
"scores": roi_score_nms,
|
| 504 |
+
"pred_classes": roi_class_nms,
|
| 505 |
+
},
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
if not tensor_mode:
|
| 509 |
+
results = InstancesList.to_d2_instances_list(results)
|
| 510 |
+
batch_splits = roi_batch_splits_nms.int().tolist()
|
| 511 |
+
kept_indices = list(roi_keeps_nms.to(torch.int64).split(batch_splits))
|
| 512 |
+
else:
|
| 513 |
+
results = [results]
|
| 514 |
+
kept_indices = [roi_keeps_nms]
|
| 515 |
+
|
| 516 |
+
return results, kept_indices
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
class Caffe2FastRCNNOutputsInference:
|
| 520 |
+
def __init__(self, tensor_mode):
|
| 521 |
+
self.tensor_mode = tensor_mode # whether the output is caffe2 tensor mode
|
| 522 |
+
|
| 523 |
+
def __call__(self, box_predictor, predictions, proposals):
|
| 524 |
+
return caffe2_fast_rcnn_outputs_inference(
|
| 525 |
+
self.tensor_mode, box_predictor, predictions, proposals
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
def caffe2_mask_rcnn_inference(pred_mask_logits, pred_instances):
|
| 530 |
+
"""equivalent to mask_head.mask_rcnn_inference"""
|
| 531 |
+
if all(isinstance(x, InstancesList) for x in pred_instances):
|
| 532 |
+
assert len(pred_instances) == 1
|
| 533 |
+
mask_probs_pred = pred_mask_logits.sigmoid()
|
| 534 |
+
mask_probs_pred = alias(mask_probs_pred, "mask_fcn_probs")
|
| 535 |
+
pred_instances[0].set("pred_masks", mask_probs_pred)
|
| 536 |
+
else:
|
| 537 |
+
mask_rcnn_inference(pred_mask_logits, pred_instances)
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
class Caffe2MaskRCNNInference:
|
| 541 |
+
def __call__(self, pred_mask_logits, pred_instances):
|
| 542 |
+
return caffe2_mask_rcnn_inference(pred_mask_logits, pred_instances)
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
def caffe2_keypoint_rcnn_inference(use_heatmap_max_keypoint, pred_keypoint_logits, pred_instances):
|
| 546 |
+
# just return the keypoint heatmap for now,
|
| 547 |
+
# there will be option to call HeatmapMaxKeypointOp
|
| 548 |
+
output = alias(pred_keypoint_logits, "kps_score")
|
| 549 |
+
if all(isinstance(x, InstancesList) for x in pred_instances):
|
| 550 |
+
assert len(pred_instances) == 1
|
| 551 |
+
if use_heatmap_max_keypoint:
|
| 552 |
+
device = output.device
|
| 553 |
+
output = torch.ops._caffe2.HeatmapMaxKeypoint(
|
| 554 |
+
to_device(output, "cpu"),
|
| 555 |
+
pred_instances[0].pred_boxes.tensor,
|
| 556 |
+
should_output_softmax=True, # worth make it configerable?
|
| 557 |
+
)
|
| 558 |
+
output = to_device(output, device)
|
| 559 |
+
output = alias(output, "keypoints_out")
|
| 560 |
+
pred_instances[0].set("pred_keypoints", output)
|
| 561 |
+
return pred_keypoint_logits
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
class Caffe2KeypointRCNNInference:
|
| 565 |
+
def __init__(self, use_heatmap_max_keypoint):
|
| 566 |
+
self.use_heatmap_max_keypoint = use_heatmap_max_keypoint
|
| 567 |
+
|
| 568 |
+
def __call__(self, pred_keypoint_logits, pred_instances):
|
| 569 |
+
return caffe2_keypoint_rcnn_inference(
|
| 570 |
+
self.use_heatmap_max_keypoint, pred_keypoint_logits, pred_instances
|
| 571 |
+
)
|
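
The caffe2 box ops above (GenerateProposals, DistributeFpnProposals, BBoxTransform) exchange proposals in "pooler format", where a leading column carries the image index of each row; that is why c2_postprocess reads indices from rpn_rois[:, 0]. A minimal sketch of that convention, using only torch and made-up box values (not part of this file):

import torch

# Hypothetical proposals for a 2-image batch; each row is
# (batch_index, x1, y1, x2, y2).
rpn_rois = torch.tensor(
    [
        [0.0, 10.0, 10.0, 50.0, 60.0],  # proposal for image 0
        [0.0, 20.0, 15.0, 80.0, 90.0],  # proposal for image 0
        [1.0, 5.0, 5.0, 40.0, 40.0],    # proposal for image 1
    ]
)
# Split back into per-image (x1, y1, x2, y2) boxes using the index column:
per_image = [rpn_rois[rpn_rois[:, 0] == i, 1:] for i in range(2)]
print([t.shape for t in per_image])  # [torch.Size([2, 4]), torch.Size([1, 4])]
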
CatVTON/detectron2/export/caffe2_export.py
ADDED
@@ -0,0 +1,203 @@
# Copyright (c) Facebook, Inc. and its affiliates.

import copy
import io
import logging
import numpy as np
from typing import List
import onnx
import onnx.optimizer
import torch
from caffe2.proto import caffe2_pb2
from caffe2.python import core
from caffe2.python.onnx.backend import Caffe2Backend
from tabulate import tabulate
from termcolor import colored
from torch.onnx import OperatorExportTypes

from .shared import (
    ScopedWS,
    construct_init_net_from_params,
    fuse_alias_placeholder,
    fuse_copy_between_cpu_and_gpu,
    get_params_from_init_net,
    group_norm_replace_aten_with_caffe2,
    infer_device_type,
    remove_dead_end_ops,
    remove_reshape_for_fc,
    save_graph,
)

logger = logging.getLogger(__name__)


def export_onnx_model(model, inputs):
    """
    Trace and export a model to onnx format.

    Args:
        model (nn.Module):
        inputs (tuple[args]): the model will be called by `model(*inputs)`

    Returns:
        an onnx model
    """
    assert isinstance(model, torch.nn.Module)

    # make sure all modules are in eval mode; onnx may change the training state
    # of the module if the states are not consistent
    def _check_eval(module):
        assert not module.training

    model.apply(_check_eval)

    # Export the model to ONNX
    with torch.no_grad():
        with io.BytesIO() as f:
            torch.onnx.export(
                model,
                inputs,
                f,
                operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
                # verbose=True,  # NOTE: uncomment this for debugging
                # export_params=True,
            )
            onnx_model = onnx.load_from_string(f.getvalue())

    return onnx_model


def _op_stats(net_def):
    type_count = {}
    for t in [op.type for op in net_def.op]:
        type_count[t] = type_count.get(t, 0) + 1
    type_count_list = sorted(type_count.items(), key=lambda kv: kv[0])  # alphabet
    type_count_list = sorted(type_count_list, key=lambda kv: -kv[1])  # count
    return "\n".join("{:>4}x {}".format(count, name) for name, count in type_count_list)


def _assign_device_option(
    predict_net: caffe2_pb2.NetDef, init_net: caffe2_pb2.NetDef, tensor_inputs: List[torch.Tensor]
):
    """
    An ONNX-exported network doesn't have the concept of a device; assign the
    necessary device option for each op in order to make it runnable on a GPU runtime.
    """

    def _get_device_type(torch_tensor):
        assert torch_tensor.device.type in ["cpu", "cuda"]
        assert torch_tensor.device.index == 0
        return torch_tensor.device.type

    def _assign_op_device_option(net_proto, net_ssa, blob_device_types):
        for op, ssa_i in zip(net_proto.op, net_ssa):
            if op.type in ["CopyCPUToGPU", "CopyGPUToCPU"]:
                op.device_option.CopyFrom(core.DeviceOption(caffe2_pb2.CUDA, 0))
            else:
                devices = [blob_device_types[b] for b in ssa_i[0] + ssa_i[1]]
                assert all(d == devices[0] for d in devices)
                if devices[0] == "cuda":
                    op.device_option.CopyFrom(core.DeviceOption(caffe2_pb2.CUDA, 0))

    # update ops in predict_net
    predict_net_input_device_types = {
        (name, 0): _get_device_type(tensor)
        for name, tensor in zip(predict_net.external_input, tensor_inputs)
    }
    predict_net_device_types = infer_device_type(
        predict_net, known_status=predict_net_input_device_types, device_name_style="pytorch"
    )
    predict_net_ssa, _ = core.get_ssa(predict_net)
    _assign_op_device_option(predict_net, predict_net_ssa, predict_net_device_types)

    # update ops in init_net
    init_net_ssa, versions = core.get_ssa(init_net)
    init_net_output_device_types = {
        (name, versions[name]): predict_net_device_types[(name, 0)]
        for name in init_net.external_output
    }
    init_net_device_types = infer_device_type(
        init_net, known_status=init_net_output_device_types, device_name_style="pytorch"
    )
    _assign_op_device_option(init_net, init_net_ssa, init_net_device_types)


def export_caffe2_detection_model(model: torch.nn.Module, tensor_inputs: List[torch.Tensor]):
    """
    Export a caffe2-compatible Detectron2 model to caffe2 format via ONNX.

    Args:
        model: a caffe2-compatible version of detectron2 model, defined in caffe2_modeling.py
        tensor_inputs: a list of tensors that caffe2 model takes as input.
    """
    model = copy.deepcopy(model)
    assert isinstance(model, torch.nn.Module)
    assert hasattr(model, "encode_additional_info")

    # Export via ONNX
    logger.info(
        "Exporting a {} model via ONNX ...".format(type(model).__name__)
        + " Some warnings from ONNX are expected and are usually not to worry about."
    )
    onnx_model = export_onnx_model(model, (tensor_inputs,))
    # Convert ONNX model to Caffe2 protobuf
    init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model)
    ops_table = [[op.type, op.input, op.output] for op in predict_net.op]
    table = tabulate(ops_table, headers=["type", "input", "output"], tablefmt="pipe")
    logger.info(
        "ONNX export Done. Exported predict_net (before optimizations):\n" + colored(table, "cyan")
    )

    # Apply protobuf optimization
    fuse_alias_placeholder(predict_net, init_net)
    if any(t.device.type != "cpu" for t in tensor_inputs):
        fuse_copy_between_cpu_and_gpu(predict_net)
        remove_dead_end_ops(init_net)
        _assign_device_option(predict_net, init_net, tensor_inputs)
    params, device_options = get_params_from_init_net(init_net)
    predict_net, params = remove_reshape_for_fc(predict_net, params)
    init_net = construct_init_net_from_params(params, device_options)
    group_norm_replace_aten_with_caffe2(predict_net)

    # Record necessary information for running the pb model in Detectron2 system.
    model.encode_additional_info(predict_net, init_net)

    logger.info("Operators used in predict_net: \n{}".format(_op_stats(predict_net)))
    logger.info("Operators used in init_net: \n{}".format(_op_stats(init_net)))

    return predict_net, init_net


def run_and_save_graph(predict_net, init_net, tensor_inputs, graph_save_path):
    """
    Run the caffe2 model on given inputs, recording the blob shapes and drawing the graph.

    predict_net/init_net: caffe2 model.
    tensor_inputs: a list of tensors that caffe2 model takes as input.
    graph_save_path: path for saving graph of exported model.
    """

    logger.info("Saving graph of ONNX exported model to {} ...".format(graph_save_path))
    save_graph(predict_net, graph_save_path, op_only=False)

    # Run the exported Caffe2 net
    logger.info("Running ONNX exported model ...")
    with ScopedWS("__ws_tmp__", True) as ws:
        ws.RunNetOnce(init_net)
        initialized_blobs = set(ws.Blobs())
        uninitialized = [inp for inp in predict_net.external_input if inp not in initialized_blobs]
        for name, blob in zip(uninitialized, tensor_inputs):
            ws.FeedBlob(name, blob)

        try:
            ws.RunNetOnce(predict_net)
        except RuntimeError as e:
            logger.warning("Encountered RuntimeError: \n{}".format(str(e)))

        ws_blobs = {b: ws.FetchBlob(b) for b in ws.Blobs()}
        blob_sizes = {b: ws_blobs[b].shape for b in ws_blobs if isinstance(ws_blobs[b], np.ndarray)}

        logger.info("Saving graph with blob shapes to {} ...".format(graph_save_path))
        save_graph(predict_net, graph_save_path, op_only=False, blob_sizes=blob_sizes)

        return ws_blobs
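
A minimal usage sketch of the two functions above, assuming a caffe2-compatible wrapper model (e.g. a Caffe2GeneralizedRCNN from caffe2_modeling.py further below) has already been built; `c2_compatible_model` and the 800x800 zero input are placeholders, not part of this module:

import torch
from detectron2.export.caffe2_export import (
    export_caffe2_detection_model,
    run_and_save_graph,
)

image = torch.zeros(1, 3, 800, 800)            # padded NCHW batch (placeholder)
im_info = torch.tensor([[800.0, 800.0, 1.0]])  # (height, width, scale) per image
inputs = [image, im_info]

predict_net, init_net = export_caffe2_detection_model(c2_compatible_model, inputs)
# Optionally run the exported nets once and dump a graph with blob shapes:
run_and_save_graph(predict_net, init_net, inputs, "model_def.svg")
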
CatVTON/detectron2/export/caffe2_inference.py
ADDED
@@ -0,0 +1,161 @@
# Copyright (c) Facebook, Inc. and its affiliates.

import logging
import numpy as np
from itertools import count
import torch
from caffe2.proto import caffe2_pb2
from caffe2.python import core

from .caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP, convert_batched_inputs_to_c2_format
from .shared import ScopedWS, get_pb_arg_vali, get_pb_arg_vals, infer_device_type

logger = logging.getLogger(__name__)


# ===== ref: mobile-vision predictor's 'Caffe2Wrapper' class ======
class ProtobufModel(torch.nn.Module):
    """
    Wrapper of a caffe2 protobuf model.
    It works just like nn.Module, but runs caffe2 under the hood.
    Input/Output are tuple[tensor] that match the caffe2 net's external_input/output.
    """

    _ids = count(0)

    def __init__(self, predict_net, init_net):
        logger.info(f"Initializing ProtobufModel for: {predict_net.name} ...")
        super().__init__()
        assert isinstance(predict_net, caffe2_pb2.NetDef)
        assert isinstance(init_net, caffe2_pb2.NetDef)
        # create unique temporary workspace for each instance
        self.ws_name = "__tmp_ProtobufModel_{}__".format(next(self._ids))
        self.net = core.Net(predict_net)

        logger.info("Running init_net once to fill the parameters ...")
        with ScopedWS(self.ws_name, is_reset=True, is_cleanup=False) as ws:
            ws.RunNetOnce(init_net)
            uninitialized_external_input = []
            for blob in self.net.Proto().external_input:
                if blob not in ws.Blobs():
                    uninitialized_external_input.append(blob)
                    ws.CreateBlob(blob)
            ws.CreateNet(self.net)

        self._error_msgs = set()
        self._input_blobs = uninitialized_external_input

    def _infer_output_devices(self, inputs):
        """
        Returns:
            list[str]: list of device for each external output
        """

        def _get_device_type(torch_tensor):
            assert torch_tensor.device.type in ["cpu", "cuda"]
            assert torch_tensor.device.index == 0
            return torch_tensor.device.type

        predict_net = self.net.Proto()
        input_device_types = {
            (name, 0): _get_device_type(tensor) for name, tensor in zip(self._input_blobs, inputs)
        }
        device_type_map = infer_device_type(
            predict_net, known_status=input_device_types, device_name_style="pytorch"
        )
        ssa, versions = core.get_ssa(predict_net)
        versioned_outputs = [(name, versions[name]) for name in predict_net.external_output]
        output_devices = [device_type_map[outp] for outp in versioned_outputs]
        return output_devices

    def forward(self, inputs):
        """
        Args:
            inputs (tuple[torch.Tensor])

        Returns:
            tuple[torch.Tensor]
        """
        assert len(inputs) == len(self._input_blobs), (
            f"Length of inputs ({len(inputs)}) "
            f"doesn't match the required input blobs: {self._input_blobs}"
        )

        with ScopedWS(self.ws_name, is_reset=False, is_cleanup=False) as ws:
            for b, tensor in zip(self._input_blobs, inputs):
                ws.FeedBlob(b, tensor)

            try:
                ws.RunNet(self.net.Proto().name)
            except RuntimeError as e:
                if not str(e) in self._error_msgs:
                    self._error_msgs.add(str(e))
                    logger.warning("Encountered new RuntimeError: \n{}".format(str(e)))
                logger.warning("Catching the error and using partial results.")

            c2_outputs = [ws.FetchBlob(b) for b in self.net.Proto().external_output]
            # Remove outputs of the current run; this is necessary in order to
            # prevent fetching the result from a previous run if the model fails
            # in the middle.
            for b in self.net.Proto().external_output:
                # Needs to create an uninitialized blob to make the net runnable.
                # This is "equivalent" to: ws.RemoveBlob(b) then ws.CreateBlob(b),
                # but there's no such API.
                ws.FeedBlob(b, f"{b}, a C++ native class of type nullptr (uninitialized).")

        # Cast output to torch.Tensor on the desired device
        output_devices = (
            self._infer_output_devices(inputs)
            if any(t.device.type != "cpu" for t in inputs)
            else ["cpu" for _ in self.net.Proto().external_output]
        )

        outputs = []
        for name, c2_output, device in zip(
            self.net.Proto().external_output, c2_outputs, output_devices
        ):
            if not isinstance(c2_output, np.ndarray):
                raise RuntimeError(
                    "Invalid output for blob {}, received: {}".format(name, c2_output)
                )
            outputs.append(torch.tensor(c2_output).to(device=device))
        return tuple(outputs)


class ProtobufDetectionModel(torch.nn.Module):
    """
    A class that works just like a pytorch meta arch in terms of inference, but runs a
    caffe2 model under the hood.
    """

    def __init__(self, predict_net, init_net, *, convert_outputs=None):
        """
        Args:
            predict_net, init_net (core.Net): caffe2 nets
            convert_outputs (callable): a function that converts caffe2
                outputs to the same format as the original pytorch model.
                By default, use the one defined in the caffe2 meta_arch.
        """
        super().__init__()
        self.protobuf_model = ProtobufModel(predict_net, init_net)
        self.size_divisibility = get_pb_arg_vali(predict_net, "size_divisibility", 0)
        self.device = get_pb_arg_vals(predict_net, "device", b"cpu").decode("ascii")

        if convert_outputs is None:
            meta_arch = get_pb_arg_vals(predict_net, "meta_architecture", b"GeneralizedRCNN")
            meta_arch = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[meta_arch.decode("ascii")]
            self._convert_outputs = meta_arch.get_outputs_converter(predict_net, init_net)
        else:
            self._convert_outputs = convert_outputs

    def _convert_inputs(self, batched_inputs):
        # currently all models convert inputs in the same way
        return convert_batched_inputs_to_c2_format(
            batched_inputs, self.size_divisibility, self.device
        )

    def forward(self, batched_inputs):
        c2_inputs = self._convert_inputs(batched_inputs)
        c2_results = self.protobuf_model(c2_inputs)
        c2_results = dict(zip(self.protobuf_model.net.Proto().external_output, c2_results))
        return self._convert_outputs(batched_inputs, c2_inputs, c2_results)
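
A minimal sketch of running the wrapper above, assuming "model.pb"/"model_init.pb" were produced by caffe2_export.export_caffe2_detection_model; the file paths and the zero-image input are placeholders:

import torch
from caffe2.proto import caffe2_pb2
from detectron2.export.caffe2_inference import ProtobufDetectionModel

def _load_net(path):
    net = caffe2_pb2.NetDef()
    with open(path, "rb") as f:
        net.ParseFromString(f.read())
    return net

model = ProtobufDetectionModel(_load_net("model.pb"), _load_net("model_init.pb"))
# batched_inputs follows detectron2's standard format: a list of dicts, each
# with an "image" CHW tensor and optional "height"/"width" keys.
batched_inputs = [{"image": torch.zeros(3, 480, 640)}]
outputs = model(batched_inputs)
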
CatVTON/detectron2/export/caffe2_modeling.py
ADDED
@@ -0,0 +1,420 @@
# Copyright (c) Facebook, Inc. and its affiliates.

import functools
import io
import struct
import types
import torch

from detectron2.modeling import meta_arch
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.modeling.roi_heads import keypoint_head
from detectron2.structures import Boxes, ImageList, Instances, RotatedBoxes

from .c10 import Caffe2Compatible
from .caffe2_patch import ROIHeadsPatcher, patch_generalized_rcnn
from .shared import (
    alias,
    check_set_pb_arg,
    get_pb_arg_floats,
    get_pb_arg_valf,
    get_pb_arg_vali,
    get_pb_arg_vals,
    mock_torch_nn_functional_interpolate,
)


def assemble_rcnn_outputs_by_name(image_sizes, tensor_outputs, force_mask_on=False):
    """
    A function to assemble caffe2 model's outputs (i.e. Dict[str, Tensor])
    to detectron2's format (i.e. a list of Instances instances).
    This only works when the model follows the Caffe2 detectron's naming convention.

    Args:
        image_sizes (List[List[int, int]]): [H, W] of every image.
        tensor_outputs (Dict[str, Tensor]): external_output to its tensor.

        force_mask_on (Bool): if true, it makes sure there'll be pred_masks even
            if the mask is not found from tensor_outputs (usually due to model crash)
    """

    results = [Instances(image_size) for image_size in image_sizes]

    batch_splits = tensor_outputs.get("batch_splits", None)
    if batch_splits:
        raise NotImplementedError()
    assert len(image_sizes) == 1
    result = results[0]

    bbox_nms = tensor_outputs["bbox_nms"]
    score_nms = tensor_outputs["score_nms"]
    class_nms = tensor_outputs["class_nms"]
    # Detection will always succeed because Conv supports 0-batch
    assert bbox_nms is not None
    assert score_nms is not None
    assert class_nms is not None
    if bbox_nms.shape[1] == 5:
        result.pred_boxes = RotatedBoxes(bbox_nms)
    else:
        result.pred_boxes = Boxes(bbox_nms)
    result.scores = score_nms
    result.pred_classes = class_nms.to(torch.int64)

    mask_fcn_probs = tensor_outputs.get("mask_fcn_probs", None)
    if mask_fcn_probs is not None:
        # finish the mask pred
        mask_probs_pred = mask_fcn_probs
        num_masks = mask_probs_pred.shape[0]
        class_pred = result.pred_classes
        indices = torch.arange(num_masks, device=class_pred.device)
        mask_probs_pred = mask_probs_pred[indices, class_pred][:, None]
        result.pred_masks = mask_probs_pred
    elif force_mask_on:
        # NOTE: there's no way to know the height/width of mask here, it won't be
        # used anyway when batch size is 0, so just set them to 0.
        result.pred_masks = torch.zeros([0, 1, 0, 0], dtype=torch.uint8)

    keypoints_out = tensor_outputs.get("keypoints_out", None)
    kps_score = tensor_outputs.get("kps_score", None)
    if keypoints_out is not None:
        # keypoints_out: [N, 4, #keypoints], where 4 is in order of (x, y, score, prob)
        keypoints_tensor = keypoints_out
        # NOTE: it's possible that prob is not calculated if "should_output_softmax"
        # is set to False in HeatmapMaxKeypoint, so just using raw score, seems
        # it doesn't affect mAP. TODO: check more carefully.
        keypoint_xyp = keypoints_tensor.transpose(1, 2)[:, :, [0, 1, 2]]
        result.pred_keypoints = keypoint_xyp
    elif kps_score is not None:
        # keypoint heatmap to sparse data structure
        pred_keypoint_logits = kps_score
        keypoint_head.keypoint_rcnn_inference(pred_keypoint_logits, [result])

    return results


def _cast_to_f32(f64):
    return struct.unpack("f", struct.pack("f", f64))[0]


def set_caffe2_compatible_tensor_mode(model, enable=True):
    def _fn(m):
        if isinstance(m, Caffe2Compatible):
            m.tensor_mode = enable

    model.apply(_fn)


def convert_batched_inputs_to_c2_format(batched_inputs, size_divisibility, device):
    """
    See get_caffe2_inputs() below.
    """
    assert all(isinstance(x, dict) for x in batched_inputs)
    assert all(x["image"].dim() == 3 for x in batched_inputs)

    images = [x["image"] for x in batched_inputs]
    images = ImageList.from_tensors(images, size_divisibility)

    im_info = []
    for input_per_image, image_size in zip(batched_inputs, images.image_sizes):
        target_height = input_per_image.get("height", image_size[0])
        target_width = input_per_image.get("width", image_size[1])  # noqa
        # NOTE: The scale inside im_info is kept as convention and for providing
        # post-processing information if further processing is needed. For
        # current Caffe2 model definitions that don't include post-processing inside
        # the model, this number is not used.
        # NOTE: There can be a slight difference between width and height
        # scales; using a single number can result in a numerical difference
        # compared with D2's post-processing.
        scale = target_height / image_size[0]
        im_info.append([image_size[0], image_size[1], scale])
    im_info = torch.Tensor(im_info)

    return images.tensor.to(device), im_info.to(device)


class Caffe2MetaArch(Caffe2Compatible, torch.nn.Module):
    """
    Base class for a caffe2-compatible implementation of a meta architecture.
    The forward is traceable and its traced graph can be converted to a caffe2
    graph through ONNX.
    """

    def __init__(self, cfg, torch_model, enable_tensor_mode=True):
        """
        Args:
            cfg (CfgNode):
            torch_model (nn.Module): the detectron2 model (meta_arch) to be
                converted.
        """
        super().__init__()
        self._wrapped_model = torch_model
        self.eval()
        set_caffe2_compatible_tensor_mode(self, enable_tensor_mode)

    def get_caffe2_inputs(self, batched_inputs):
        """
        Convert pytorch-style structured inputs to caffe2-style inputs that
        are tuples of tensors.

        Args:
            batched_inputs (list[dict]): inputs to a detectron2 model
                in its standard format. Each dict has "image" (CHW tensor), and optionally
                "height" and "width".

        Returns:
            tuple[Tensor]:
                tuple of tensors that will be the inputs to the
                :meth:`forward` method. For existing models, the first
                is an NCHW tensor (padded and batched); the second is
                an im_info Nx3 tensor, where the rows are
                (height, width, unused legacy parameter)
        """
        return convert_batched_inputs_to_c2_format(
            batched_inputs,
            self._wrapped_model.backbone.size_divisibility,
            self._wrapped_model.device,
        )

    def encode_additional_info(self, predict_net, init_net):
        """
        Save extra metadata that will be used by inference in the output protobuf.
        """
        pass

    def forward(self, inputs):
        """
        Run the forward in caffe2-style. It has to use caffe2-compatible ops
        and the method will be used for tracing.

        Args:
            inputs (tuple[Tensor]): inputs defined by :meth:`get_caffe2_inputs`.
                They will be the inputs of the converted caffe2 graph.

        Returns:
            tuple[Tensor]: output tensors. They will be the outputs of the
                converted caffe2 graph.
        """
        raise NotImplementedError

    def _caffe2_preprocess_image(self, inputs):
        """
        Caffe2 implementation of preprocess_image, which is called inside each MetaArch's forward.
        It normalizes the input images, and the final caffe2 graph assumes the
        inputs have been batched already.
        """
        data, im_info = inputs
        data = alias(data, "data")
        im_info = alias(im_info, "im_info")
        mean, std = self._wrapped_model.pixel_mean, self._wrapped_model.pixel_std
        normalized_data = (data - mean) / std
        normalized_data = alias(normalized_data, "normalized_data")

        # Pack (data, im_info) into ImageList which is recognized by self.inference.
        images = ImageList(tensor=normalized_data, image_sizes=im_info)
        return images

    @staticmethod
    def get_outputs_converter(predict_net, init_net):
        """
        Creates a function that converts the outputs of the caffe2 model to
        detectron2's standard format.
        The function uses information in `predict_net` and `init_net` that are
        available at inference time. Therefore the function logic can be used in inference.

        The returned function has the following signature:

            def convert(batched_inputs, c2_inputs, c2_results) -> detectron2_outputs

        Where

            * batched_inputs (list[dict]): the original input format of the meta arch
            * c2_inputs (tuple[Tensor]): the caffe2 inputs.
            * c2_results (dict[str, Tensor]): the caffe2 output format,
                corresponding to the outputs of the :meth:`forward` function.
            * detectron2_outputs: the original output format of the meta arch.

        This function can be used to compare the outputs of the original meta arch and
        the converted caffe2 graph.

        Returns:
            callable: a callable of the above signature.
        """
        raise NotImplementedError


class Caffe2GeneralizedRCNN(Caffe2MetaArch):
    def __init__(self, cfg, torch_model, enable_tensor_mode=True):
        assert isinstance(torch_model, meta_arch.GeneralizedRCNN)
        torch_model = patch_generalized_rcnn(torch_model)
        super().__init__(cfg, torch_model, enable_tensor_mode)

        try:
            use_heatmap_max_keypoint = cfg.EXPORT_CAFFE2.USE_HEATMAP_MAX_KEYPOINT
        except AttributeError:
            use_heatmap_max_keypoint = False
        self.roi_heads_patcher = ROIHeadsPatcher(
            self._wrapped_model.roi_heads, use_heatmap_max_keypoint
        )
        if self.tensor_mode:
            self.roi_heads_patcher.patch_roi_heads()

    def encode_additional_info(self, predict_net, init_net):
        size_divisibility = self._wrapped_model.backbone.size_divisibility
        check_set_pb_arg(predict_net, "size_divisibility", "i", size_divisibility)
        check_set_pb_arg(
            predict_net, "device", "s", str.encode(str(self._wrapped_model.device), "ascii")
        )
        check_set_pb_arg(predict_net, "meta_architecture", "s", b"GeneralizedRCNN")

    @mock_torch_nn_functional_interpolate()
    def forward(self, inputs):
        if not self.tensor_mode:
            return self._wrapped_model.inference(inputs)
        images = self._caffe2_preprocess_image(inputs)
        features = self._wrapped_model.backbone(images.tensor)
        proposals, _ = self._wrapped_model.proposal_generator(images, features)
        detector_results, _ = self._wrapped_model.roi_heads(images, features, proposals)
        return tuple(detector_results[0].flatten())

    @staticmethod
    def get_outputs_converter(predict_net, init_net):
        def f(batched_inputs, c2_inputs, c2_results):
            _, im_info = c2_inputs
            image_sizes = [[int(im[0]), int(im[1])] for im in im_info]
            results = assemble_rcnn_outputs_by_name(image_sizes, c2_results)
            return meta_arch.GeneralizedRCNN._postprocess(results, batched_inputs, image_sizes)

        return f


class Caffe2RetinaNet(Caffe2MetaArch):
    def __init__(self, cfg, torch_model):
        assert isinstance(torch_model, meta_arch.RetinaNet)
        super().__init__(cfg, torch_model)

    @mock_torch_nn_functional_interpolate()
    def forward(self, inputs):
        assert self.tensor_mode
        images = self._caffe2_preprocess_image(inputs)

        # explicitly return the image sizes to avoid removing "im_info" by ONNX
        # since it's not used in the forward path
        return_tensors = [images.image_sizes]

        features = self._wrapped_model.backbone(images.tensor)
        features = [features[f] for f in self._wrapped_model.head_in_features]
        for i, feature_i in enumerate(features):
            features[i] = alias(feature_i, "feature_{}".format(i), is_backward=True)
            return_tensors.append(features[i])

        pred_logits, pred_anchor_deltas = self._wrapped_model.head(features)
        for i, (box_cls_i, box_delta_i) in enumerate(zip(pred_logits, pred_anchor_deltas)):
            return_tensors.append(alias(box_cls_i, "box_cls_{}".format(i)))
            return_tensors.append(alias(box_delta_i, "box_delta_{}".format(i)))

        return tuple(return_tensors)

    def encode_additional_info(self, predict_net, init_net):
        size_divisibility = self._wrapped_model.backbone.size_divisibility
        check_set_pb_arg(predict_net, "size_divisibility", "i", size_divisibility)
        check_set_pb_arg(
            predict_net, "device", "s", str.encode(str(self._wrapped_model.device), "ascii")
        )
        check_set_pb_arg(predict_net, "meta_architecture", "s", b"RetinaNet")

        # Inference parameters:
        check_set_pb_arg(
            predict_net, "score_threshold", "f", _cast_to_f32(self._wrapped_model.test_score_thresh)
        )
        check_set_pb_arg(
            predict_net, "topk_candidates", "i", self._wrapped_model.test_topk_candidates
        )
        check_set_pb_arg(
            predict_net, "nms_threshold", "f", _cast_to_f32(self._wrapped_model.test_nms_thresh)
        )
        check_set_pb_arg(
            predict_net,
            "max_detections_per_image",
            "i",
            self._wrapped_model.max_detections_per_image,
        )

        check_set_pb_arg(
            predict_net,
            "bbox_reg_weights",
            "floats",
            [_cast_to_f32(w) for w in self._wrapped_model.box2box_transform.weights],
        )
        self._encode_anchor_generator_cfg(predict_net)

    def _encode_anchor_generator_cfg(self, predict_net):
        # serialize anchor_generator for future use
        serialized_anchor_generator = io.BytesIO()
        torch.save(self._wrapped_model.anchor_generator, serialized_anchor_generator)
        # Ideally we can put anchor generating inside the model, then we don't
        # need to store this information.
        bytes = serialized_anchor_generator.getvalue()
        check_set_pb_arg(predict_net, "serialized_anchor_generator", "s", bytes)

    @staticmethod
    def get_outputs_converter(predict_net, init_net):
        self = types.SimpleNamespace()
        serialized_anchor_generator = io.BytesIO(
            get_pb_arg_vals(predict_net, "serialized_anchor_generator", None)
        )
        self.anchor_generator = torch.load(serialized_anchor_generator)
        bbox_reg_weights = get_pb_arg_floats(predict_net, "bbox_reg_weights", None)
        self.box2box_transform = Box2BoxTransform(weights=tuple(bbox_reg_weights))
        self.test_score_thresh = get_pb_arg_valf(predict_net, "score_threshold", None)
        self.test_topk_candidates = get_pb_arg_vali(predict_net, "topk_candidates", None)
        self.test_nms_thresh = get_pb_arg_valf(predict_net, "nms_threshold", None)
        self.max_detections_per_image = get_pb_arg_vali(
            predict_net, "max_detections_per_image", None
        )

        # hack to reuse inference code from RetinaNet
        for meth in [
            "forward_inference",
            "inference_single_image",
            "_transpose_dense_predictions",
            "_decode_multi_level_predictions",
            "_decode_per_level_predictions",
        ]:
            setattr(self, meth, functools.partial(getattr(meta_arch.RetinaNet, meth), self))

        def f(batched_inputs, c2_inputs, c2_results):
            _, im_info = c2_inputs
            image_sizes = [[int(im[0]), int(im[1])] for im in im_info]
            dummy_images = ImageList(
                torch.randn((len(im_info), 3) + tuple(image_sizes[0])), image_sizes
            )

            num_features = len([x for x in c2_results.keys() if x.startswith("box_cls_")])
            pred_logits = [c2_results["box_cls_{}".format(i)] for i in range(num_features)]
            pred_anchor_deltas = [c2_results["box_delta_{}".format(i)] for i in range(num_features)]

            # For each feature level, feature should have the same batch size and
            # spatial dimension as the box_cls and box_delta.
            dummy_features = [x.clone()[:, 0:0, :, :] for x in pred_logits]
            # self.num_classes can be inferred
            self.num_classes = pred_logits[0].shape[1] // (pred_anchor_deltas[0].shape[1] // 4)

            results = self.forward_inference(
                dummy_images, dummy_features, [pred_logits, pred_anchor_deltas]
            )
            return meta_arch.GeneralizedRCNN._postprocess(results, batched_inputs, image_sizes)

        return f


META_ARCH_CAFFE2_EXPORT_TYPE_MAP = {
    "GeneralizedRCNN": Caffe2GeneralizedRCNN,
    "RetinaNet": Caffe2RetinaNet,
}
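
A minimal sketch of how this map is typically consumed when building the caffe2-compatible wrapper; `cfg`, `torch_model`, and `batched_inputs` stand in for an already-loaded detectron2 config, a trained meta arch, and standard list-of-dict inputs:

from detectron2.export.caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP

# Pick the wrapper class matching the config's meta architecture
# ("GeneralizedRCNN" or "RetinaNet") and wrap the trained model:
C2MetaArch = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[cfg.MODEL.META_ARCHITECTURE]
c2_compatible_model = C2MetaArch(cfg, torch_model)

# get_caffe2_inputs turns the standard inputs into the (NCHW tensor, im_info)
# tuple expected by the traceable forward:
c2_inputs = c2_compatible_model.get_caffe2_inputs(batched_inputs)
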
CatVTON/detectron2/export/caffe2_patch.py
ADDED
@@ -0,0 +1,189 @@
# Copyright (c) Facebook, Inc. and its affiliates.

import contextlib
from unittest import mock
import torch

from detectron2.modeling import poolers
from detectron2.modeling.proposal_generator import rpn
from detectron2.modeling.roi_heads import keypoint_head, mask_head
from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers

from .c10 import (
    Caffe2Compatible,
    Caffe2FastRCNNOutputsInference,
    Caffe2KeypointRCNNInference,
    Caffe2MaskRCNNInference,
    Caffe2ROIPooler,
    Caffe2RPN,
    caffe2_fast_rcnn_outputs_inference,
    caffe2_keypoint_rcnn_inference,
    caffe2_mask_rcnn_inference,
)


class GenericMixin:
    pass


class Caffe2CompatibleConverter:
    """
    A GenericUpdater which implements the `create_from` interface, by modifying
    the module object and assigning it another class, replaceCls.
    """

    def __init__(self, replaceCls):
        self.replaceCls = replaceCls

    def create_from(self, module):
        # update module's class to the new class
        assert isinstance(module, torch.nn.Module)
        if issubclass(self.replaceCls, GenericMixin):
            # replaceCls should act as a mixin, create a new class on-the-fly
            new_class = type(
                "{}MixedWith{}".format(self.replaceCls.__name__, module.__class__.__name__),
                (self.replaceCls, module.__class__),
                {},  # {"new_method": lambda self: ...},
            )
            module.__class__ = new_class
        else:
            # replaceCls is a complete class, this allows arbitrary class swaps
            module.__class__ = self.replaceCls

        # initialize Caffe2Compatible
        if isinstance(module, Caffe2Compatible):
            module.tensor_mode = False

        return module


def patch(model, target, updater, *args, **kwargs):
    """
    Recursively (post-order) update all modules with the target type and its
    subclasses, performing an initialization/composition/inheritance/... via
    updater.create_from.
    """
    for name, module in model.named_children():
        model._modules[name] = patch(module, target, updater, *args, **kwargs)
    if isinstance(model, target):
        return updater.create_from(model, *args, **kwargs)
    return model


def patch_generalized_rcnn(model):
    ccc = Caffe2CompatibleConverter
    model = patch(model, rpn.RPN, ccc(Caffe2RPN))
    model = patch(model, poolers.ROIPooler, ccc(Caffe2ROIPooler))

    return model


@contextlib.contextmanager
def mock_fastrcnn_outputs_inference(
    tensor_mode, check=True, box_predictor_type=FastRCNNOutputLayers
):
    with mock.patch.object(
        box_predictor_type,
        "inference",
        autospec=True,
        side_effect=Caffe2FastRCNNOutputsInference(tensor_mode),
    ) as mocked_func:
        yield
    if check:
        assert mocked_func.call_count > 0


@contextlib.contextmanager
def mock_mask_rcnn_inference(tensor_mode, patched_module, check=True):
    with mock.patch(
        "{}.mask_rcnn_inference".format(patched_module), side_effect=Caffe2MaskRCNNInference()
    ) as mocked_func:
        yield
    if check:
        assert mocked_func.call_count > 0


@contextlib.contextmanager
def mock_keypoint_rcnn_inference(tensor_mode, patched_module, use_heatmap_max_keypoint, check=True):
    with mock.patch(
        "{}.keypoint_rcnn_inference".format(patched_module),
        side_effect=Caffe2KeypointRCNNInference(use_heatmap_max_keypoint),
    ) as mocked_func:
        yield
    if check:
        assert mocked_func.call_count > 0


class ROIHeadsPatcher:
    def __init__(self, heads, use_heatmap_max_keypoint):
        self.heads = heads
        self.use_heatmap_max_keypoint = use_heatmap_max_keypoint
        self.previous_patched = {}

    @contextlib.contextmanager
    def mock_roi_heads(self, tensor_mode=True):
        """
        Patching several inference functions inside ROIHeads and its subclasses

        Args:
            tensor_mode (bool): whether the inputs/outputs are caffe2's tensor
                format or not. Default to True.
        """
        # NOTE: this requires that `keypoint_rcnn_inference` and `mask_rcnn_inference`
        # be called inside the same file as BaseXxxHead, due to using mock.patch.
        kpt_heads_mod = keypoint_head.BaseKeypointRCNNHead.__module__
        mask_head_mod = mask_head.BaseMaskRCNNHead.__module__

        mock_ctx_managers = [
            mock_fastrcnn_outputs_inference(
                tensor_mode=tensor_mode,
                check=True,
                box_predictor_type=type(self.heads.box_predictor),
            )
        ]
        if getattr(self.heads, "keypoint_on", False):
            mock_ctx_managers += [
                mock_keypoint_rcnn_inference(
                    tensor_mode, kpt_heads_mod, self.use_heatmap_max_keypoint
                )
            ]
        if getattr(self.heads, "mask_on", False):
            mock_ctx_managers += [mock_mask_rcnn_inference(tensor_mode, mask_head_mod)]

        with contextlib.ExitStack() as stack:  # python 3.3+
            for mgr in mock_ctx_managers:
                stack.enter_context(mgr)
            yield

    def patch_roi_heads(self, tensor_mode=True):
        self.previous_patched["box_predictor"] = self.heads.box_predictor.inference
        self.previous_patched["keypoint_rcnn"] = keypoint_head.keypoint_rcnn_inference
        self.previous_patched["mask_rcnn"] = mask_head.mask_rcnn_inference

        def patched_fastrcnn_outputs_inference(predictions, proposal):
            return caffe2_fast_rcnn_outputs_inference(
                True, self.heads.box_predictor, predictions, proposal
            )

        self.heads.box_predictor.inference = patched_fastrcnn_outputs_inference

        if getattr(self.heads, "keypoint_on", False):

            def patched_keypoint_rcnn_inference(pred_keypoint_logits, pred_instances):
                return caffe2_keypoint_rcnn_inference(
                    self.use_heatmap_max_keypoint, pred_keypoint_logits, pred_instances
                )

            keypoint_head.keypoint_rcnn_inference = patched_keypoint_rcnn_inference

        if getattr(self.heads, "mask_on", False):

            def patched_mask_rcnn_inference(pred_mask_logits, pred_instances):
                return caffe2_mask_rcnn_inference(pred_mask_logits, pred_instances)

            mask_head.mask_rcnn_inference = patched_mask_rcnn_inference

    def unpatch_roi_heads(self):
        self.heads.box_predictor.inference = self.previous_patched["box_predictor"]
        keypoint_head.keypoint_rcnn_inference = self.previous_patched["keypoint_rcnn"]
        mask_head.mask_rcnn_inference = self.previous_patched["mask_rcnn"]
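
A minimal sketch of how the pieces in this file are typically combined during export: first swap RPN/ROIPooler for their Caffe2-compatible counterparts, then run inference with the ROI-head functions mocked. `model` is assumed to be a detectron2 GeneralizedRCNN with a `roi_heads` attribute and `inputs` a valid input; neither is defined in this file.

# Hedged usage sketch; `model` and `inputs` are assumptions.
model = patch_generalized_rcnn(model)
patcher = ROIHeadsPatcher(model.roi_heads, use_heatmap_max_keypoint=False)
with patcher.mock_roi_heads(tensor_mode=True):
    outputs = model(inputs)  # runs with Caffe2-style inference paths
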
CatVTON/detectron2/export/flatten.py
ADDED
@@ -0,0 +1,330 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import collections
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple
import torch
from torch import nn

from detectron2.structures import Boxes, Instances, ROIMasks
from detectron2.utils.registry import _convert_target_to_string, locate

from .torchscript_patch import patch_builtin_len


@dataclass
class Schema:
    """
    A Schema defines how to flatten a possibly hierarchical object into a tuple of
    primitive objects, so it can be used as inputs/outputs of PyTorch's tracing.

    PyTorch does not support tracing a function that produces rich output
    structures (e.g. dict, Instances, Boxes). To trace such a function, we
    flatten the rich object into a tuple of tensors, and return this tuple of tensors
    instead. Meanwhile, we also need to know how to "rebuild" the original object
    from the flattened results, so we can evaluate the flattened results.
    A Schema defines how to flatten an object, and while flattening it, it records
    the necessary schemas so that the object can be rebuilt using the flattened outputs.

    The flattened object and the schema object are returned by the ``.flatten`` classmethod.
    Then the original object can be rebuilt with the ``__call__`` method of the schema.

    A Schema is a dataclass that can be serialized easily.
    """

    # inspired by FetchMapper in tensorflow/python/client/session.py

    @classmethod
    def flatten(cls, obj):
        raise NotImplementedError

    def __call__(self, values):
        raise NotImplementedError

    @staticmethod
    def _concat(values):
        ret = ()
        sizes = []
        for v in values:
            assert isinstance(v, tuple), "Flattened results must be a tuple"
            ret = ret + v
            sizes.append(len(v))
        return ret, sizes

    @staticmethod
    def _split(values, sizes):
        if len(sizes):
            expected_len = sum(sizes)
            assert (
                len(values) == expected_len
            ), f"Values has length {len(values)} but expect length {expected_len}."
        ret = []
        for k in range(len(sizes)):
            begin, end = sum(sizes[:k]), sum(sizes[: k + 1])
            ret.append(values[begin:end])
        return ret


@dataclass
class ListSchema(Schema):
    schemas: List[Schema]  # the schemas that define how to flatten each element in the list
    sizes: List[int]  # the flattened length of each element

    def __call__(self, values):
        values = self._split(values, self.sizes)
        if len(values) != len(self.schemas):
            raise ValueError(
                f"Values has length {len(values)} but schemas " f"has length {len(self.schemas)}!"
            )
        values = [m(v) for m, v in zip(self.schemas, values)]
        return list(values)

    @classmethod
    def flatten(cls, obj):
        res = [flatten_to_tuple(k) for k in obj]
        values, sizes = cls._concat([k[0] for k in res])
        return values, cls([k[1] for k in res], sizes)


@dataclass
class TupleSchema(ListSchema):
    def __call__(self, values):
        return tuple(super().__call__(values))


@dataclass
class IdentitySchema(Schema):
    def __call__(self, values):
        return values[0]

    @classmethod
    def flatten(cls, obj):
        return (obj,), cls()


@dataclass
class DictSchema(ListSchema):
    keys: List[str]

    def __call__(self, values):
        values = super().__call__(values)
        return dict(zip(self.keys, values))

    @classmethod
    def flatten(cls, obj):
        for k in obj.keys():
            if not isinstance(k, str):
                raise KeyError("Only support flattening dictionaries if keys are str.")
        keys = sorted(obj.keys())
        values = [obj[k] for k in keys]
        ret, schema = ListSchema.flatten(values)
        return ret, cls(schema.schemas, schema.sizes, keys)


@dataclass
class InstancesSchema(DictSchema):
    def __call__(self, values):
        image_size, fields = values[-1], values[:-1]
        fields = super().__call__(fields)
        return Instances(image_size, **fields)

    @classmethod
    def flatten(cls, obj):
        ret, schema = super().flatten(obj.get_fields())
        size = obj.image_size
        if not isinstance(size, torch.Tensor):
            size = torch.tensor(size)
        return ret + (size,), schema


@dataclass
class TensorWrapSchema(Schema):
    """
    For classes that are simple wrappers of tensors, e.g.
    Boxes, RotatedBoxes, BitMasks
    """

    class_name: str

    def __call__(self, values):
        return locate(self.class_name)(values[0])

    @classmethod
    def flatten(cls, obj):
        return (obj.tensor,), cls(_convert_target_to_string(type(obj)))
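
A small sketch of the flatten/rebuild contract the classes above define (it assumes the full module is imported, since DictSchema.flatten delegates to flatten_to_tuple below): flattening a dict of tensors yields a flat tuple plus a schema whose ``__call__`` rebuilds the original structure.

# Hedged round-trip sketch of the Schema contract.
flat, schema = DictSchema.flatten({"a": torch.zeros(2), "b": torch.ones(3)})
rebuilt = schema(flat)  # flat is (tensor_a, tensor_b); rebuilt is the dict again
assert torch.equal(rebuilt["a"], torch.zeros(2))
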
# if more custom structures are needed in the future, we can allow
# passing in extra schemas for custom types
def flatten_to_tuple(obj):
    """
    Flatten an object so it can be used for PyTorch tracing.
    Also returns how to rebuild the original object from the flattened outputs.

    Returns:
        res (tuple): the flattened results that can be used as tracing outputs
        schema: an object with a ``__call__`` method such that ``schema(res) == obj``.
            It is a pure dataclass that can be serialized.
    """
    schemas = [
        ((str, bytes), IdentitySchema),
        (list, ListSchema),
        (tuple, TupleSchema),
        (collections.abc.Mapping, DictSchema),
        (Instances, InstancesSchema),
        ((Boxes, ROIMasks), TensorWrapSchema),
    ]
    for klass, schema in schemas:
        if isinstance(obj, klass):
            F = schema
            break
    else:
        F = IdentitySchema

    return F.flatten(obj)


class TracingAdapter(nn.Module):
    """
    A model may take rich input/output formats (e.g. dict or custom classes),
    but `torch.jit.trace` requires a tuple of tensors as input/output.
    This adapter flattens the input/output format of a model so it becomes traceable.

    It also records the necessary schema to rebuild the model's inputs/outputs from the
    flattened inputs/outputs.

    Example:
    ::
        outputs = model(inputs)  # inputs/outputs may be rich structure
        adapter = TracingAdapter(model, inputs)

        # can now trace the model, with adapter.flattened_inputs, or another
        # tuple of tensors with the same length and meaning
        traced = torch.jit.trace(adapter, adapter.flattened_inputs)

        # traced model can only produce flattened outputs (tuple of tensors)
        flattened_outputs = traced(*adapter.flattened_inputs)
        # adapter knows the schema to convert it back (new_outputs == outputs)
        new_outputs = adapter.outputs_schema(flattened_outputs)
    """

    flattened_inputs: Tuple[torch.Tensor] = None
    """
    Flattened version of inputs given to this class's constructor.
    """

    inputs_schema: Schema = None
    """
    Schema of the inputs given to this class's constructor.
    """

    outputs_schema: Schema = None
    """
    Schema of the output produced by calling the given model with inputs.
    """

    def __init__(
        self,
        model: nn.Module,
        inputs,
        inference_func: Optional[Callable] = None,
        allow_non_tensor: bool = False,
    ):
        """
        Args:
            model: an nn.Module
            inputs: An input argument or a tuple of input arguments used to call model.
                After flattening, it has to only consist of tensors.
            inference_func: a callable that takes (model, *inputs), calls the
                model with inputs, and returns outputs. By default it
                is ``lambda model, *inputs: model(*inputs)``. Can be overridden
                if you need to call the model differently.
            allow_non_tensor: allow inputs/outputs to contain non-tensor objects.
                This option will filter out non-tensor objects to make the
                model traceable, but ``inputs_schema``/``outputs_schema`` cannot be
                used anymore because inputs/outputs cannot be rebuilt from pure tensors.
                This is useful when you're only interested in the single trace of
                execution (e.g. for flop count), but not interested in
                generalizing the traced graph to new inputs.
        """
        super().__init__()
        if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)):
            model = model.module
        self.model = model
        if not isinstance(inputs, tuple):
            inputs = (inputs,)
        self.inputs = inputs
        self.allow_non_tensor = allow_non_tensor

        if inference_func is None:
            inference_func = lambda model, *inputs: model(*inputs)  # noqa
        self.inference_func = inference_func

        self.flattened_inputs, self.inputs_schema = flatten_to_tuple(inputs)

        if all(isinstance(x, torch.Tensor) for x in self.flattened_inputs):
            return
        if self.allow_non_tensor:
            self.flattened_inputs = tuple(
                [x for x in self.flattened_inputs if isinstance(x, torch.Tensor)]
            )
            self.inputs_schema = None
        else:
            for input in self.flattened_inputs:
                if not isinstance(input, torch.Tensor):
                    raise ValueError(
                        "Inputs for tracing must only contain tensors. "
                        f"Got a {type(input)} instead."
                    )

    def forward(self, *args: torch.Tensor):
        with torch.no_grad(), patch_builtin_len():
            if self.inputs_schema is not None:
                inputs_orig_format = self.inputs_schema(args)
            else:
                if len(args) != len(self.flattened_inputs) or any(
                    x is not y for x, y in zip(args, self.flattened_inputs)
                ):
                    raise ValueError(
                        "TracingAdapter does not contain valid inputs_schema."
                        " So it cannot generalize to other inputs and must be"
                        " traced with `.flattened_inputs`."
                    )
                inputs_orig_format = self.inputs

            outputs = self.inference_func(self.model, *inputs_orig_format)
            flattened_outputs, schema = flatten_to_tuple(outputs)

            flattened_output_tensors = tuple(
                [x for x in flattened_outputs if isinstance(x, torch.Tensor)]
            )
            if len(flattened_output_tensors) < len(flattened_outputs):
                if self.allow_non_tensor:
                    flattened_outputs = flattened_output_tensors
                    self.outputs_schema = None
                else:
                    raise ValueError(
                        "Model cannot be traced because some model outputs "
                        "cannot flatten to tensors."
                    )
            else:  # schema is valid
                if self.outputs_schema is None:
                    self.outputs_schema = schema
                else:
                    assert self.outputs_schema == schema, (
                        "Model should always return outputs with the same "
                        "structure so it can be traced!"
                    )
            return flattened_outputs

    def _create_wrapper(self, traced_model):
        """
        Return a function that has the same input/output interface as the
        original model, but calls the given traced model under the hood.
        """

        def forward(*args):
            flattened_inputs, _ = flatten_to_tuple(args)
            flattened_outputs = traced_model(*flattened_inputs)
            return self.outputs_schema(flattened_outputs)

        return forward
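
Beyond the docstring example above, a common pattern is to supply an `inference_func` so a detectron2-style model (whose forward takes a list of dicts) becomes traceable. The sketch below assumes `model` is such a model with an `inference(inputs, do_postprocess=...)` method and `image` is a CHW image tensor; both are assumptions, not defined in this file.

# Hedged sketch of tracing with a custom inference_func.
def inference(model, image):
    inputs = [{"image": image}]
    return model.inference(inputs, do_postprocess=False)[0]

adapter = TracingAdapter(model, image, inference_func=inference)
traced = torch.jit.trace(adapter, adapter.flattened_inputs)
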
CatVTON/detectron2/export/shared.py
ADDED
@@ -0,0 +1,1040 @@
# Copyright (c) Facebook, Inc. and its affiliates.

import collections
import copy
import functools
import logging
import numpy as np
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest import mock
import caffe2.python.utils as putils
import torch
import torch.nn.functional as F
from caffe2.proto import caffe2_pb2
from caffe2.python import core, net_drawer, workspace
from torch.nn.functional import interpolate as interp

logger = logging.getLogger(__name__)


# ==== torch/utils_toffee/cast.py =======================================


def to_device(t, device_str):
    """
    This function is a replacement of .to(another_device) such that it allows the
    casting to be traced properly by explicitly calling the underlying copy ops.
    It also avoids introducing unnecessary ops when casting to the same device.
    """
    src = t.device
    dst = torch.device(device_str)

    if src == dst:
        return t
    elif src.type == "cuda" and dst.type == "cpu":
        return torch.ops._caffe2.CopyGPUToCPU(t)
    elif src.type == "cpu" and dst.type == "cuda":
        return torch.ops._caffe2.CopyCPUToGPU(t)
    else:
        raise RuntimeError("Can't cast tensor from device {} to device {}".format(src, dst))


# ==== torch/utils_toffee/interpolate.py =======================================


# Note: borrowed from vision/detection/fair/detectron/detectron/modeling/detector.py
def BilinearInterpolation(tensor_in, up_scale):
    assert up_scale % 2 == 0, "Scale should be even"

    def upsample_filt(size):
        factor = (size + 1) // 2
        if size % 2 == 1:
            center = factor - 1
        else:
            center = factor - 0.5

        og = np.ogrid[:size, :size]
        return (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)

    kernel_size = int(up_scale) * 2
    bil_filt = upsample_filt(kernel_size)

    dim = int(tensor_in.shape[1])
    kernel = np.zeros((dim, dim, kernel_size, kernel_size), dtype=np.float32)
    kernel[range(dim), range(dim), :, :] = bil_filt

    tensor_out = F.conv_transpose2d(
        tensor_in,
        weight=to_device(torch.Tensor(kernel), tensor_in.device),
        bias=None,
        stride=int(up_scale),
        padding=int(up_scale / 2),
    )

    return tensor_out


# NOTE: ONNX is incompatible with traced torch.nn.functional.interpolate if
# using dynamic `scale_factor` rather than static `size`. (T43166860)
# NOTE: Caffe2 Int8 conversion might not be able to quantize `size` properly.
def onnx_compatibale_interpolate(
    input, size=None, scale_factor=None, mode="nearest", align_corners=None
):
    # NOTE: The input dimensions are interpreted in the form:
    # `mini-batch x channels x [optional depth] x [optional height] x width`.
    if size is None and scale_factor is not None:
        if input.dim() == 4:
            if isinstance(scale_factor, (int, float)):
                height_scale, width_scale = (scale_factor, scale_factor)
            else:
                assert isinstance(scale_factor, (tuple, list))
                assert len(scale_factor) == 2
                height_scale, width_scale = scale_factor

            assert not align_corners, "No matching C2 op for align_corners == True"
            if mode == "nearest":
                return torch.ops._caffe2.ResizeNearest(
                    input, order="NCHW", width_scale=width_scale, height_scale=height_scale
                )
            elif mode == "bilinear":
                logger.warning(
                    "Use F.conv_transpose2d for bilinear interpolate"
                    " because there's no such C2 op, this may cause significant"
                    " slowdown and the boundary pixels won't be as same as"
                    " using F.interpolate due to padding."
                )
                assert height_scale == width_scale
                return BilinearInterpolation(input, up_scale=height_scale)
        logger.warning("Output size is not static, it might cause ONNX conversion issue")

    return interp(input, size, scale_factor, mode, align_corners)


def mock_torch_nn_functional_interpolate():
    def decorator(func):
        @functools.wraps(func)
        def _mock_torch_nn_functional_interpolate(*args, **kwargs):
            if torch.onnx.is_in_onnx_export():
                with mock.patch(
                    "torch.nn.functional.interpolate", side_effect=onnx_compatibale_interpolate
                ):
                    return func(*args, **kwargs)
            else:
                return func(*args, **kwargs)

        return _mock_torch_nn_functional_interpolate

    return decorator


# ==== torch/utils_caffe2/ws_utils.py ==========================================


class ScopedWS:
    def __init__(self, ws_name, is_reset, is_cleanup=False):
        self.ws_name = ws_name
        self.is_reset = is_reset
        self.is_cleanup = is_cleanup
        self.org_ws = ""

    def __enter__(self):
        self.org_ws = workspace.CurrentWorkspace()
        if self.ws_name is not None:
            workspace.SwitchWorkspace(self.ws_name, True)
        if self.is_reset:
            workspace.ResetWorkspace()

        return workspace

    def __exit__(self, *args):
        if self.is_cleanup:
            workspace.ResetWorkspace()
        if self.ws_name is not None:
            workspace.SwitchWorkspace(self.org_ws)


def fetch_any_blob(name):
    bb = None
    try:
        bb = workspace.FetchBlob(name)
    except TypeError:
        bb = workspace.FetchInt8Blob(name)
    except Exception as e:
        logger.error("Get blob {} error: {}".format(name, e))

    return bb

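As a brief sketch of the workspace scoping above (assuming `init_net` is a caffe2_pb2.NetDef; it is not defined at this point), a net can be run in a scratch workspace that is reset on entry and cleaned up on exit, leaving the caller's workspace untouched:

# Hedged usage sketch; `init_net` is an assumption.
with ScopedWS("__scratch__", is_reset=True, is_cleanup=True) as ws:
    ws.RunNetOnce(init_net)
    blob = fetch_any_blob(init_net.external_output[0])
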
# ==== torch/utils_caffe2/protobuf.py ==========================================


def get_pb_arg(pb, arg_name):
    for x in pb.arg:
        if x.name == arg_name:
            return x
    return None


def get_pb_arg_valf(pb, arg_name, default_val):
    arg = get_pb_arg(pb, arg_name)
    return arg.f if arg is not None else default_val


def get_pb_arg_floats(pb, arg_name, default_val):
    arg = get_pb_arg(pb, arg_name)
    return list(map(float, arg.floats)) if arg is not None else default_val


def get_pb_arg_ints(pb, arg_name, default_val):
    arg = get_pb_arg(pb, arg_name)
    return list(map(int, arg.ints)) if arg is not None else default_val


def get_pb_arg_vali(pb, arg_name, default_val):
    arg = get_pb_arg(pb, arg_name)
    return arg.i if arg is not None else default_val


def get_pb_arg_vals(pb, arg_name, default_val):
    arg = get_pb_arg(pb, arg_name)
    return arg.s if arg is not None else default_val


def get_pb_arg_valstrings(pb, arg_name, default_val):
    arg = get_pb_arg(pb, arg_name)
    return list(arg.strings) if arg is not None else default_val


def check_set_pb_arg(pb, arg_name, arg_attr, arg_value, allow_override=False):
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        arg = putils.MakeArgument(arg_name, arg_value)
        assert hasattr(arg, arg_attr)
        pb.arg.extend([arg])
    if allow_override and getattr(arg, arg_attr) != arg_value:
        logger.warning(
            "Override argument {}: {} -> {}".format(arg_name, getattr(arg, arg_attr), arg_value)
        )
        setattr(arg, arg_attr, arg_value)
    else:
        assert arg is not None
        assert getattr(arg, arg_attr) == arg_value, "Existing value {}, new value {}".format(
            getattr(arg, arg_attr), arg_value
        )

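To illustrate the accessor/setter pair above, a hedged sketch on a hand-built operator (the op type and argument names are illustrative only):

op = core.CreateOperator("GroupNorm", ["x"], ["y"], group=32)
assert get_pb_arg_vali(op, "group", None) == 32
check_set_pb_arg(op, "epsilon", "f", 1e-5)  # appends the arg since it is absent
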
def _create_const_fill_op_from_numpy(name, tensor, device_option=None):
    assert type(tensor) is np.ndarray
    kTypeNameMapper = {
        np.dtype("float32"): "GivenTensorFill",
        np.dtype("int32"): "GivenTensorIntFill",
        np.dtype("int64"): "GivenTensorInt64Fill",
        np.dtype("uint8"): "GivenTensorStringFill",
    }

    args_dict = {}
    if tensor.dtype == np.dtype("uint8"):
        args_dict.update({"values": [str(tensor.data)], "shape": [1]})
    else:
        args_dict.update({"values": tensor, "shape": tensor.shape})

    if device_option is not None:
        args_dict["device_option"] = device_option

    return core.CreateOperator(kTypeNameMapper[tensor.dtype], [], [name], **args_dict)


def _create_const_fill_op_from_c2_int8_tensor(name, int8_tensor):
    assert type(int8_tensor) is workspace.Int8Tensor
    kTypeNameMapper = {
        np.dtype("int32"): "Int8GivenIntTensorFill",
        np.dtype("uint8"): "Int8GivenTensorFill",
    }

    tensor = int8_tensor.data
    assert tensor.dtype in [np.dtype("uint8"), np.dtype("int32")]
    values = tensor.tobytes() if tensor.dtype == np.dtype("uint8") else tensor

    return core.CreateOperator(
        kTypeNameMapper[tensor.dtype],
        [],
        [name],
        values=values,
        shape=tensor.shape,
        Y_scale=int8_tensor.scale,
        Y_zero_point=int8_tensor.zero_point,
    )


def create_const_fill_op(
    name: str,
    blob: Union[np.ndarray, workspace.Int8Tensor],
    device_option: Optional[caffe2_pb2.DeviceOption] = None,
) -> caffe2_pb2.OperatorDef:
    """
    Given a blob object, return the Caffe2 operator that creates this blob
    as a constant. Currently supports NumPy tensors and Caffe2 Int8Tensor.
    """

    tensor_type = type(blob)
    assert tensor_type in [
        np.ndarray,
        workspace.Int8Tensor,
    ], 'Error when creating const fill op for "{}", unsupported blob type: {}'.format(
        name, type(blob)
    )

    if tensor_type == np.ndarray:
        return _create_const_fill_op_from_numpy(name, blob, device_option)
    elif tensor_type == workspace.Int8Tensor:
        assert device_option is None
        return _create_const_fill_op_from_c2_int8_tensor(name, blob)


def construct_init_net_from_params(
    params: Dict[str, Any], device_options: Optional[Dict[str, caffe2_pb2.DeviceOption]] = None
) -> caffe2_pb2.NetDef:
    """
    Construct the init_net from a params dictionary
    """
    init_net = caffe2_pb2.NetDef()
    device_options = device_options or {}
    for name, blob in params.items():
        if isinstance(blob, str):
            logger.warning(
                (
                    "Blob {} with type {} is not supported in generating init net,"
                    " skipped.".format(name, type(blob))
                )
            )
            continue
        init_net.op.extend(
            [create_const_fill_op(name, blob, device_option=device_options.get(name, None))]
        )
        init_net.external_output.append(name)
    return init_net


def get_producer_map(ssa):
    """
    Return a dict from versioned blob to (i, j),
    where i is the index of the producer op and j is the index of that op's output.
    """
    producer_map = {}
    for i in range(len(ssa)):
        outputs = ssa[i][1]
        for j, outp in enumerate(outputs):
            producer_map[outp] = (i, j)
    return producer_map


def get_consumer_map(ssa):
    """
    Return a dict from versioned blob to a list of (i, j),
    where i is the index of a consumer op and j is the index of that op's input.
    """
    consumer_map = collections.defaultdict(list)
    for i in range(len(ssa)):
        inputs = ssa[i][0]
        for j, inp in enumerate(inputs):
            consumer_map[inp].append((i, j))
    return consumer_map

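For a concrete picture of the ssa format these two helpers consume: ssa (as produced by core.get_ssa) is a list of ([(input_name, version), ...], [(output_name, version), ...]) per op, so for a toy two-op chain x -> y -> z:

# Hedged sketch on a hand-written ssa list.
ssa = [([("x", 0)], [("y", 0)]), ([("y", 0)], [("z", 0)])]
assert get_producer_map(ssa)[("z", 0)] == (1, 0)   # produced by op 1, output slot 0
assert get_consumer_map(ssa)[("y", 0)] == [(1, 0)]  # consumed by op 1, input slot 0
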
def get_params_from_init_net(
    init_net: caffe2_pb2.NetDef,
) -> [Dict[str, Any], Dict[str, caffe2_pb2.DeviceOption]]:
    """
    Take the output blobs from init_net by running it.
    Outputs:
        params: dict from blob name to numpy array
        device_options: dict from blob name to the device option of its creating op
    """

    # NOTE: this assumes that the params are determined by the producer op, with the
    # only exception being CopyGPUToCPU, which is a CUDA op but returns a CPU tensor.
    def _get_device_option(producer_op):
        if producer_op.type == "CopyGPUToCPU":
            return caffe2_pb2.DeviceOption()
        else:
            return producer_op.device_option

    with ScopedWS("__get_params_from_init_net__", is_reset=True, is_cleanup=True) as ws:
        ws.RunNetOnce(init_net)
        params = {b: fetch_any_blob(b) for b in init_net.external_output}
    ssa, versions = core.get_ssa(init_net)
    producer_map = get_producer_map(ssa)
    device_options = {
        b: _get_device_option(init_net.op[producer_map[(b, versions[b])][0]])
        for b in init_net.external_output
    }
    return params, device_options


def _updater_raise(op, input_types, output_types):
    raise RuntimeError(
        "Failed to apply updater for op {} given input_types {} and"
        " output_types {}".format(op, input_types, output_types)
    )


def _generic_status_identifier(
    predict_net: caffe2_pb2.NetDef,
    status_updater: Callable,
    known_status: Dict[Tuple[str, int], Any],
) -> Dict[Tuple[str, int], Any]:
    """
    Statically infer the status of each blob; the status can be device type
    (CPU/GPU), layout (NCHW/NHWC), data type (float32/int8), etc. "Blob" here
    is a versioned blob (Tuple[str, int]) in the format compatible with ssa.
    Inputs:
        predict_net: the caffe2 network
        status_updater: a callable, given an op and the status of its input/output,
            it returns the updated status of input/output. `None` is used for
            representing unknown status.
        known_status: a dict containing known status, used as initialization.
    Outputs:
        A dict mapping from versioned blob to its status
    """
    ssa, versions = core.get_ssa(predict_net)
    versioned_ext_input = [(b, 0) for b in predict_net.external_input]
    versioned_ext_output = [(b, versions[b]) for b in predict_net.external_output]
    all_versioned_blobs = set().union(*[set(x[0] + x[1]) for x in ssa])

    allowed_vbs = all_versioned_blobs.union(versioned_ext_input).union(versioned_ext_output)
    assert all(k in allowed_vbs for k in known_status)
    assert all(v is not None for v in known_status.values())
    _known_status = copy.deepcopy(known_status)

    def _check_and_update(key, value):
        assert value is not None
        if key in _known_status:
            if not _known_status[key] == value:
                raise RuntimeError(
                    "Conflict status for {}, existing status {}, new status {}".format(
                        key, _known_status[key], value
                    )
                )
        _known_status[key] = value

    def _update_i(op, ssa_i):
        versioned_inputs = ssa_i[0]
        versioned_outputs = ssa_i[1]

        inputs_status = [_known_status.get(b, None) for b in versioned_inputs]
        outputs_status = [_known_status.get(b, None) for b in versioned_outputs]

        new_inputs_status, new_outputs_status = status_updater(op, inputs_status, outputs_status)

        for versioned_blob, status in zip(
            versioned_inputs + versioned_outputs, new_inputs_status + new_outputs_status
        ):
            if status is not None:
                _check_and_update(versioned_blob, status)

    for op, ssa_i in zip(predict_net.op, ssa):
        _update_i(op, ssa_i)
    for op, ssa_i in zip(reversed(predict_net.op), reversed(ssa)):
        _update_i(op, ssa_i)

    # NOTE: This strictly checks that every blob from predict_net is assigned
    # a known status. However sometimes it's impossible (e.g. having dead-end ops);
    # we may relax this constraint if
    for k in all_versioned_blobs:
        if k not in _known_status:
            raise NotImplementedError(
                "Can not infer the status for {}. Currently only support the case where"
                " a single forward and backward pass can identify status for all blobs.".format(k)
            )

    return _known_status


def infer_device_type(
    predict_net: caffe2_pb2.NetDef,
    known_status: Dict[Tuple[str, int], Any],
    device_name_style: str = "caffe2",
) -> Dict[Tuple[str, int], str]:
    """Return the device type ("cpu" or "gpu"/"cuda") of each (versioned) blob"""

    assert device_name_style in ["caffe2", "pytorch"]
    _CPU_STR = "cpu"
    _GPU_STR = "gpu" if device_name_style == "caffe2" else "cuda"

    def _copy_cpu_to_gpu_updater(op, input_types, output_types):
        if input_types[0] == _GPU_STR or output_types[0] == _CPU_STR:
            _updater_raise(op, input_types, output_types)
        return ([_CPU_STR], [_GPU_STR])

    def _copy_gpu_to_cpu_updater(op, input_types, output_types):
        if input_types[0] == _CPU_STR or output_types[0] == _GPU_STR:
            _updater_raise(op, input_types, output_types)
        return ([_GPU_STR], [_CPU_STR])

    def _other_ops_updater(op, input_types, output_types):
        non_none_types = [x for x in input_types + output_types if x is not None]
        if len(non_none_types) > 0:
            the_type = non_none_types[0]
            if not all(x == the_type for x in non_none_types):
                _updater_raise(op, input_types, output_types)
        else:
            the_type = None
        return ([the_type for _ in op.input], [the_type for _ in op.output])

    def _device_updater(op, *args, **kwargs):
        return {
            "CopyCPUToGPU": _copy_cpu_to_gpu_updater,
            "CopyGPUToCPU": _copy_gpu_to_cpu_updater,
        }.get(op.type, _other_ops_updater)(op, *args, **kwargs)

    return _generic_status_identifier(predict_net, _device_updater, known_status)

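A typical invocation of the inference above seeds the propagation with the known device of the network's external inputs; `predict_net` below is assumed to be a caffe2_pb2.NetDef:

# Hedged sketch; `predict_net` is an assumption.
versioned_inputs = [(b, 0) for b in predict_net.external_input]
device_types = infer_device_type(
    predict_net, known_status={vb: "cpu" for vb in versioned_inputs}
)
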
# ==== torch/utils_caffe2/vis.py ===============================================


def _modify_blob_names(ops, blob_rename_f):
    ret = []

    def _replace_list(blob_list, replaced_list):
        del blob_list[:]
        blob_list.extend(replaced_list)

    for x in ops:
        cur = copy.deepcopy(x)
        _replace_list(cur.input, list(map(blob_rename_f, cur.input)))
        _replace_list(cur.output, list(map(blob_rename_f, cur.output)))
        ret.append(cur)

    return ret


def _rename_blob(name, blob_sizes, blob_ranges):
    def _list_to_str(bsize):
        ret = ", ".join([str(x) for x in bsize])
        ret = "[" + ret + "]"
        return ret

    ret = name
    if blob_sizes is not None and name in blob_sizes:
        ret += "\n" + _list_to_str(blob_sizes[name])
    if blob_ranges is not None and name in blob_ranges:
        ret += "\n" + _list_to_str(blob_ranges[name])

    return ret


# graph_name must not contain the word 'graph'
def save_graph(net, file_name, graph_name="net", op_only=True, blob_sizes=None, blob_ranges=None):
    blob_rename_f = functools.partial(_rename_blob, blob_sizes=blob_sizes, blob_ranges=blob_ranges)
    return save_graph_base(net, file_name, graph_name, op_only, blob_rename_f)


def save_graph_base(net, file_name, graph_name="net", op_only=True, blob_rename_func=None):
    graph = None
    ops = net.op
    if blob_rename_func is not None:
        ops = _modify_blob_names(ops, blob_rename_func)
    if not op_only:
        graph = net_drawer.GetPydotGraph(ops, graph_name, rankdir="TB")
    else:
        graph = net_drawer.GetPydotGraphMinimal(
            ops, graph_name, rankdir="TB", minimal_dependency=True
        )

    try:
        par_dir = os.path.dirname(file_name)
        if not os.path.exists(par_dir):
            os.makedirs(par_dir)

        format = os.path.splitext(os.path.basename(file_name))[-1]
        if format == ".png":
            graph.write_png(file_name)
        elif format == ".pdf":
            graph.write_pdf(file_name)
        elif format == ".svg":
            graph.write_svg(file_name)
        else:
            print("Incorrect format {}".format(format))
    except Exception as e:
        print("Error when writing graph to image {}".format(e))

    return graph


# ==== torch/utils_toffee/aten_to_caffe2.py ====================================


def group_norm_replace_aten_with_caffe2(predict_net: caffe2_pb2.NetDef):
    """
    For an ONNX-exported model, GroupNorm is represented as an ATen op;
    this performs a drop-in replacement from ATen to the GroupNorm op
    """
    count = 0
    for op in predict_net.op:
        if op.type == "ATen":
            op_name = get_pb_arg_vals(op, "operator", None)  # returns bytes in py3
            if op_name and op_name.decode() == "group_norm":
                op.arg.remove(get_pb_arg(op, "operator"))

                if get_pb_arg_vali(op, "cudnn_enabled", None):
                    op.arg.remove(get_pb_arg(op, "cudnn_enabled"))

                num_groups = get_pb_arg_vali(op, "num_groups", None)
                if num_groups is not None:
                    op.arg.remove(get_pb_arg(op, "num_groups"))
                    check_set_pb_arg(op, "group", "i", num_groups)

                op.type = "GroupNorm"
                count += 1
    if count > 1:
        logger.info("Replaced {} ATen operator to GroupNormOp".format(count))


# ==== torch/utils_toffee/alias.py =============================================


def alias(x, name, is_backward=False):
    if not torch.onnx.is_in_onnx_export():
        return x
    assert isinstance(x, torch.Tensor)
    return torch.ops._caffe2.AliasWithName(x, name, is_backward=is_backward)


def fuse_alias_placeholder(predict_net, init_net):
    """Remove AliasWithName placeholders and rename their inputs/outputs"""
    # First we finish all the re-naming
    for i, op in enumerate(predict_net.op):
        if op.type == "AliasWithName":
            assert len(op.input) == 1
            assert len(op.output) == 1
            name = get_pb_arg_vals(op, "name", None).decode()
            is_backward = bool(get_pb_arg_vali(op, "is_backward", 0))
            rename_op_input(predict_net, init_net, i, 0, name, from_producer=is_backward)
            rename_op_output(predict_net, i, 0, name)

    # Remove AliasWithName, should be very safe since it's a non-op
    new_ops = []
    for op in predict_net.op:
        if op.type != "AliasWithName":
            new_ops.append(op)
        else:
            # safety check
            assert op.input == op.output
            assert op.input[0] == op.arg[0].s.decode()
    del predict_net.op[:]
    predict_net.op.extend(new_ops)

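In rough outline, the alias mechanism works in two steps: during ONNX export, alias() inserts an AliasWithName marker that pins a stable blob name onto the traced graph, and afterwards fuse_alias_placeholder() renames the surrounding blobs and strips the markers. A hedged sketch, where `x`, `predict_net` and `init_net` are assumptions:

x = alias(x, "proposal_boxes")  # no-op outside of an ONNX export pass
# ... after export:
fuse_alias_placeholder(predict_net, init_net)  # renames blobs, removes markers
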
| 630 |
+
# ==== torch/utils_caffe2/graph_transform.py ===================================
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
class IllegalGraphTransformError(ValueError):
|
| 634 |
+
"""When a graph transform function call can't be executed."""
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
def _rename_versioned_blob_in_proto(
|
| 638 |
+
proto: caffe2_pb2.NetDef,
|
| 639 |
+
old_name: str,
|
| 640 |
+
new_name: str,
|
| 641 |
+
version: int,
|
| 642 |
+
ssa: List[Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]],
|
| 643 |
+
start_versions: Dict[str, int],
|
| 644 |
+
end_versions: Dict[str, int],
|
| 645 |
+
):
|
| 646 |
+
"""In given proto, rename all blobs with matched version"""
|
| 647 |
+
# Operater list
|
| 648 |
+
for op, i_th_ssa in zip(proto.op, ssa):
|
| 649 |
+
versioned_inputs, versioned_outputs = i_th_ssa
|
| 650 |
+
for i in range(len(op.input)):
|
| 651 |
+
if versioned_inputs[i] == (old_name, version):
|
| 652 |
+
op.input[i] = new_name
|
| 653 |
+
for i in range(len(op.output)):
|
| 654 |
+
if versioned_outputs[i] == (old_name, version):
|
| 655 |
+
op.output[i] = new_name
|
| 656 |
+
# external_input
|
| 657 |
+
if start_versions.get(old_name, 0) == version:
|
| 658 |
+
for i in range(len(proto.external_input)):
|
| 659 |
+
if proto.external_input[i] == old_name:
|
| 660 |
+
proto.external_input[i] = new_name
|
| 661 |
+
# external_output
|
| 662 |
+
if end_versions.get(old_name, 0) == version:
|
| 663 |
+
for i in range(len(proto.external_output)):
|
| 664 |
+
if proto.external_output[i] == old_name:
|
| 665 |
+
proto.external_output[i] = new_name
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
def rename_op_input(
    predict_net: caffe2_pb2.NetDef,
    init_net: caffe2_pb2.NetDef,
    op_id: int,
    input_id: int,
    new_name: str,
    from_producer: bool = False,
):
    """
    Rename the op_id-th operator in predict_net, changing its input_id-th input
    to the new_name. It also does automatic re-routing and changes
    external_input and init_net if necessary.
    - It requires the input to be consumed only by this op.
    - This function modifies predict_net and init_net in-place.
    - When from_producer is enabled, this also updates other operators that consume
      the same input. Be cautious because this may trigger unintended behavior.
    """
    assert isinstance(predict_net, caffe2_pb2.NetDef)
    assert isinstance(init_net, caffe2_pb2.NetDef)

    init_net_ssa, init_net_versions = core.get_ssa(init_net)
    predict_net_ssa, predict_net_versions = core.get_ssa(
        predict_net, copy.deepcopy(init_net_versions)
    )

    versioned_inputs, versioned_outputs = predict_net_ssa[op_id]
    old_name, version = versioned_inputs[input_id]

    if from_producer:
        producer_map = get_producer_map(predict_net_ssa)
        if (old_name, version) not in producer_map:
            raise NotImplementedError(
                "Can't find producer, the input {} is probably from"
                " init_net, this is not supported yet.".format(old_name)
            )
        producer = producer_map[(old_name, version)]
        rename_op_output(predict_net, producer[0], producer[1], new_name)
        return

    def contain_targets(op_ssa):
        return (old_name, version) in op_ssa[0]

    is_consumer = [contain_targets(op_ssa) for op_ssa in predict_net_ssa]
    if sum(is_consumer) > 1:
        raise IllegalGraphTransformError(
            (
                "Input '{}' of operator(#{}) is consumed by other ops, please use"
                + " rename_op_output on the producer instead. Offending op: \n{}"
            ).format(old_name, op_id, predict_net.op[op_id])
        )

    # update init_net
    _rename_versioned_blob_in_proto(
        init_net, old_name, new_name, version, init_net_ssa, {}, init_net_versions
    )
    # update predict_net
    _rename_versioned_blob_in_proto(
        predict_net,
        old_name,
        new_name,
        version,
        predict_net_ssa,
        init_net_versions,
        predict_net_versions,
    )


def rename_op_output(predict_net: caffe2_pb2.NetDef, op_id: int, output_id: int, new_name: str):
    """
    Rename the op_id-th operator in predict_net, changing its output_id-th output
    to the new_name. It also does automatic re-routing and changes
    external_output if necessary.
    - It allows multiple consumers of its output.
    - This function modifies predict_net in-place and doesn't need init_net.
    """
    assert isinstance(predict_net, caffe2_pb2.NetDef)

    ssa, blob_versions = core.get_ssa(predict_net)

    versioned_inputs, versioned_outputs = ssa[op_id]
    old_name, version = versioned_outputs[output_id]

    # update predict_net
    _rename_versioned_blob_in_proto(
        predict_net, old_name, new_name, version, ssa, {}, blob_versions
    )


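# A minimal usage sketch (illustrative, not part of the upstream file): build a
# tiny NetDef and rename an op output. The op/blob names below are hypothetical.
def _example_rename_op_output():
    net = caffe2_pb2.NetDef()
    relu = net.op.add()
    relu.type = "Relu"
    relu.input.append("x")
    relu.output.append("y")
    net.external_input.append("x")
    net.external_output.append("y")

    # Rename output 0 of op 0 from "y" to "y_out"; external_output is
    # re-routed automatically because version 1 of "y" is its final version.
    rename_op_output(net, op_id=0, output_id=0, new_name="y_out")
    assert list(net.op[0].output) == ["y_out"]
    assert list(net.external_output) == ["y_out"]
    return net

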
def get_sub_graph_external_input_output(
    predict_net: caffe2_pb2.NetDef, sub_graph_op_indices: List[int]
) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
    """
    Return the lists of external inputs/outputs of the sub-graph;
    each element is a tuple of the blob name and its corresponding version in predict_net.

    External input/output is defined the same way as in a caffe2 NetDef.
    """
    ssa, versions = core.get_ssa(predict_net)

    all_inputs = []
    all_outputs = []
    for op_id in sub_graph_op_indices:
        all_inputs += [inp for inp in ssa[op_id][0] if inp not in all_inputs]
        all_outputs += list(ssa[op_id][1])  # ssa outputs won't repeat

    # for versioned blobs, external inputs are just those blobs in all_inputs
    # but not in all_outputs
    ext_inputs = [inp for inp in all_inputs if inp not in all_outputs]

    # external outputs are essentially outputs of this sub-graph that are used
    # outside of this sub-graph (including predict_net.external_output)
    all_other_inputs = sum(
        (ssa[i][0] for i in range(len(ssa)) if i not in sub_graph_op_indices),
        [(outp, versions[outp]) for outp in predict_net.external_output],
    )
    ext_outputs = [outp for outp in all_outputs if outp in set(all_other_inputs)]

    return ext_inputs, ext_outputs


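# Illustrative sketch (not part of the upstream file): external inputs/outputs
# of a one-op sub-graph inside a two-op chain x -Relu-> y -Relu-> z.
def _example_sub_graph_io():
    net = caffe2_pb2.NetDef()
    for inp, outp in [("x", "y"), ("y", "z")]:
        op = net.op.add()
        op.type = "Relu"
        op.input.append(inp)
        op.output.append(outp)
    net.external_input.append("x")
    net.external_output.append("z")

    # The sub-graph containing only op #0 consumes ("x", 0) from outside and
    # produces ("y", 1), which op #1 (outside the sub-graph) consumes.
    ext_inputs, ext_outputs = get_sub_graph_external_input_output(net, [0])
    assert ext_inputs == [("x", 0)] and ext_outputs == [("y", 1)]
    return ext_inputs, ext_outputs

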
class DiGraph:
    """A DAG representation of a caffe2 graph; each vertex is a versioned blob."""

    def __init__(self):
        self.vertices = set()
        self.graph = collections.defaultdict(list)

    def add_edge(self, u, v):
        self.graph[u].append(v)
        self.vertices.add(u)
        self.vertices.add(v)

    # adapted from https://www.geeksforgeeks.org/find-paths-given-source-destination/
    def get_all_paths(self, s, d):
        visited = {k: False for k in self.vertices}
        path = []
        all_paths = []

        def _get_all_paths_util(graph, u, d, visited, path):
            visited[u] = True
            path.append(u)
            if u == d:
                all_paths.append(copy.deepcopy(path))
            else:
                for i in graph[u]:
                    if not visited[i]:
                        _get_all_paths_util(graph, i, d, visited, path)
            path.pop()
            visited[u] = False

        _get_all_paths_util(self.graph, s, d, visited, path)
        return all_paths

    @staticmethod
    def from_ssa(ssa):
        graph = DiGraph()
        for op_id in range(len(ssa)):
            for inp in ssa[op_id][0]:
                for outp in ssa[op_id][1]:
                    graph.add_edge(inp, outp)
        return graph


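# Illustrative sketch (not part of the upstream file): enumerate all paths in a
# small diamond-shaped DAG with hypothetical vertex names.
def _example_digraph_paths():
    g = DiGraph()
    g.add_edge("a", "b")
    g.add_edge("b", "c")
    g.add_edge("a", "c")
    # The depth-first search finds both a->b->c and the direct a->c edge.
    paths = g.get_all_paths("a", "c")
    assert sorted(paths) == [["a", "b", "c"], ["a", "c"]]
    return paths

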
def _get_dependency_chain(ssa, versioned_target, versioned_source):
    """
    Return the index list of the operators needed to produce the target blob
    from the source blob; if there's no dependency, return an empty list.
    """

    # finding all paths between nodes can be O(N!), thus we only search
    # in the sub-graph spanning from the first consumer of the source blob
    # to the producer of the target blob.
    consumer_map = get_consumer_map(ssa)
    producer_map = get_producer_map(ssa)
    start_op = min(x[0] for x in consumer_map[versioned_source]) - 15
    end_op = (
        producer_map[versioned_target][0] + 15 if versioned_target in producer_map else start_op
    )
    sub_graph_ssa = ssa[start_op : end_op + 1]
    if len(sub_graph_ssa) > 30:
        logger.warning(
            "Subgraph between {} and {} is large (from op#{} to op#{}); it"
            " might take non-trivial time to find all paths between them.".format(
                versioned_source, versioned_target, start_op, end_op
            )
        )

    dag = DiGraph.from_ssa(sub_graph_ssa)
    paths = dag.get_all_paths(versioned_source, versioned_target)  # includes both ends
    ops_in_paths = [[producer_map[blob][0] for blob in path[1:]] for path in paths]
    return sorted(set().union(*[set(ops) for ops in ops_in_paths]))


def identify_reshape_sub_graph(predict_net: caffe2_pb2.NetDef) -> List[List[int]]:
    """
    Identify the reshape sub-graphs in a protobuf.
    A reshape sub-graph is defined as matching the following pattern:

    (input_blob) -> Op_1 -> ... -> Op_N -> (new_shape) -─┐
        └-------------------------------------------> Reshape -> (output_blob)

    Return:
        List of sub-graphs, each sub-graph is represented as a list of indices
        of the relevant ops, [Op_1, Op_2, ..., Op_N, Reshape]
    """

    ssa, _ = core.get_ssa(predict_net)

    ret = []
    for i, op in enumerate(predict_net.op):
        if op.type == "Reshape":
            assert len(op.input) == 2
            input_ssa = ssa[i][0]
            data_source = input_ssa[0]
            shape_source = input_ssa[1]
            op_indices = _get_dependency_chain(ssa, shape_source, data_source)
            ret.append(op_indices + [i])
    return ret


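# Illustrative sketch (not part of the upstream file): the pattern above in its
# smallest form, data -> Shape -> (new_shape) feeding a Reshape of the same data.
def _example_identify_reshape():
    net = caffe2_pb2.NetDef()
    shape_op = net.op.add()
    shape_op.type = "Shape"
    shape_op.input.append("data")
    shape_op.output.append("new_shape")
    reshape_op = net.op.add()
    reshape_op.type = "Reshape"
    reshape_op.input.extend(["data", "new_shape"])
    reshape_op.output.extend(["reshaped", "old_shape"])
    net.external_input.append("data")
    net.external_output.append("reshaped")

    # Op #0 (Shape) is the dependency chain from "data" to "new_shape", and
    # op #1 is the Reshape itself.
    assert identify_reshape_sub_graph(net) == [[0, 1]]
    return net

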
def remove_reshape_for_fc(predict_net, params):
    """
    In PyTorch, nn.Linear has to take a 2D tensor, which often leads to reshaping
    a 4D tensor to 2D by calling .view(). However this (dynamic) reshaping
    doesn't work well with ONNX and Int8 tools, and causes extra
    ops (e.g. ExpandDims) that might not be available on mobile.
    Luckily Caffe2 supports 4D tensors for FC, so we can remove those reshapes
    after exporting to ONNX.
    """
    from caffe2.python import core

    # find all reshape sub-graphs that can be removed, which for now means all
    # Reshape sub-graphs whose output is only consumed by FC.
    # TODO: to make it safer, we may need the actual value to better determine
    # if a Reshape before FC is removable.
    reshape_sub_graphs = identify_reshape_sub_graph(predict_net)
    sub_graphs_to_remove = []
    for reshape_sub_graph in reshape_sub_graphs:
        reshape_op_id = reshape_sub_graph[-1]
        assert predict_net.op[reshape_op_id].type == "Reshape"
        ssa, _ = core.get_ssa(predict_net)
        reshape_output = ssa[reshape_op_id][1][0]
        consumers = [i for i in range(len(ssa)) if reshape_output in ssa[i][0]]
        if all(predict_net.op[consumer].type == "FC" for consumer in consumers):
            # safety check that the sub-graph is isolated; for this reshape sub-graph,
            # it means it has one non-param external input and one external output.
            ext_inputs, ext_outputs = get_sub_graph_external_input_output(
                predict_net, reshape_sub_graph
            )
            non_params_ext_inputs = [inp for inp in ext_inputs if inp[1] != 0]
            if len(non_params_ext_inputs) == 1 and len(ext_outputs) == 1:
                sub_graphs_to_remove.append(reshape_sub_graph)

    # remove the sub-graphs by:
    # 1: renaming the Reshape's output to its input, so the sub-graph can be
    #    seen as an in-place identity, i.e. its external input/output are the same.
    # 2: simply removing those ops.
    remove_op_ids = []
    params_to_remove = []
    for sub_graph in sub_graphs_to_remove:
        logger.info(
            "Remove Reshape sub-graph:\n{}".format(
                "".join(["(#{:>4})\n{}".format(i, predict_net.op[i]) for i in sub_graph])
            )
        )
        reshape_op_id = sub_graph[-1]
        new_reshape_output = predict_net.op[reshape_op_id].input[0]
        rename_op_output(predict_net, reshape_op_id, 0, new_reshape_output)
        ext_inputs, ext_outputs = get_sub_graph_external_input_output(predict_net, sub_graph)
        non_params_ext_inputs = [inp for inp in ext_inputs if inp[1] != 0]
        params_ext_inputs = [inp for inp in ext_inputs if inp[1] == 0]
        assert len(non_params_ext_inputs) == 1 and len(ext_outputs) == 1
        assert ext_outputs[0][0] == non_params_ext_inputs[0][0]
        assert ext_outputs[0][1] == non_params_ext_inputs[0][1] + 1
        remove_op_ids.extend(sub_graph)
        params_to_remove.extend(params_ext_inputs)

    predict_net = copy.deepcopy(predict_net)
    new_ops = [op for i, op in enumerate(predict_net.op) if i not in remove_op_ids]
    del predict_net.op[:]
    predict_net.op.extend(new_ops)
    for versioned_params in params_to_remove:
        name = versioned_params[0]
        logger.info("Remove params: {} from init_net and predict_net.external_input".format(name))
        del params[name]
        predict_net.external_input.remove(name)

    return predict_net, params


def fuse_copy_between_cpu_and_gpu(predict_net: caffe2_pb2.NetDef):
    """
    In-place fuse extra copy ops between cpu/gpu for the following case:
        a -CopyAToB-> b -CopyBToA-> c1 -NextOp1-> d1
                        -CopyBToA-> c2 -NextOp2-> d2
    The fused network will look like:
        a -NextOp1-> d1
          -NextOp2-> d2
    """

    _COPY_OPS = ["CopyCPUToGPU", "CopyGPUToCPU"]

    def _fuse_once(predict_net):
        ssa, blob_versions = core.get_ssa(predict_net)
        consumer_map = get_consumer_map(ssa)
        versioned_external_output = [
            (name, blob_versions[name]) for name in predict_net.external_output
        ]

        for op_id, op in enumerate(predict_net.op):
            if op.type in _COPY_OPS:
                fw_copy_versioned_output = ssa[op_id][1][0]
                consumer_ids = [x[0] for x in consumer_map[fw_copy_versioned_output]]
                reverse_op_type = _COPY_OPS[1 - _COPY_OPS.index(op.type)]

                is_fusable = (
                    len(consumer_ids) > 0
                    and fw_copy_versioned_output not in versioned_external_output
                    and all(
                        predict_net.op[_op_id].type == reverse_op_type
                        and ssa[_op_id][1][0] not in versioned_external_output
                        for _op_id in consumer_ids
                    )
                )

                if is_fusable:
                    for rv_copy_op_id in consumer_ids:
                        # make each NextOp use "a" directly, then remove the Copy ops
                        rs_copy_versioned_output = ssa[rv_copy_op_id][1][0]
                        next_op_id, inp_id = consumer_map[rs_copy_versioned_output][0]
                        predict_net.op[next_op_id].input[inp_id] = op.input[0]
                    # remove CopyOps
                    new_ops = [
                        op
                        for i, op in enumerate(predict_net.op)
                        if i != op_id and i not in consumer_ids
                    ]
                    del predict_net.op[:]
                    predict_net.op.extend(new_ops)
                    return True

        return False

    # _fuse_once returns False if nothing can be fused
    while _fuse_once(predict_net):
        pass


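# Illustrative sketch (not part of the upstream file): a GPU->CPU copy followed
# by a CPU->GPU copy cancels out, leaving "relu_out" feeding the Relu directly.
# All names below are hypothetical.
def _example_fuse_copies():
    net = caffe2_pb2.NetDef()
    specs = [
        ("CopyGPUToCPU", "relu_out", "cpu_blob"),
        ("CopyCPUToGPU", "cpu_blob", "gpu_blob"),
        ("Relu", "gpu_blob", "out"),
    ]
    for op_type, inp, outp in specs:
        op = net.op.add()
        op.type = op_type
        op.input.append(inp)
        op.output.append(outp)
    net.external_input.append("relu_out")
    net.external_output.append("out")

    fuse_copy_between_cpu_and_gpu(net)
    # Both copies are gone and the Relu consumes "relu_out" directly.
    assert [op.type for op in net.op] == ["Relu"]
    assert list(net.op[0].input) == ["relu_out"]
    return net

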
def remove_dead_end_ops(net_def: caffe2_pb2.NetDef):
    """Remove ops whose outputs are neither consumed nor in external_output."""
    ssa, versions = core.get_ssa(net_def)
    versioned_external_output = [(name, versions[name]) for name in net_def.external_output]
    consumer_map = get_consumer_map(ssa)
    removed_op_ids = set()

    def _is_dead_end(versioned_blob):
        return not (
            versioned_blob in versioned_external_output
            or (
                len(consumer_map[versioned_blob]) > 0
                and all(x[0] not in removed_op_ids for x in consumer_map[versioned_blob])
            )
        )

    for i, ssa_i in reversed(list(enumerate(ssa))):
        versioned_outputs = ssa_i[1]
        if all(_is_dead_end(outp) for outp in versioned_outputs):
            removed_op_ids.add(i)

    # simply removing those dead-end ops should have no effect on external_output
    new_ops = [op for i, op in enumerate(net_def.op) if i not in removed_op_ids]
    del net_def.op[:]
    net_def.op.extend(new_ops)
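

# Illustrative sketch (not part of the upstream file): an op whose only output
# feeds nothing and is not an external output gets pruned.
def _example_remove_dead_end():
    net = caffe2_pb2.NetDef()
    for op_type, inp, outp in [("Relu", "x", "y"), ("Sigmoid", "x", "unused")]:
        op = net.op.add()
        op.type = op_type
        op.input.append(inp)
        op.output.append(outp)
    net.external_input.append("x")
    net.external_output.append("y")

    remove_dead_end_ops(net)
    # The Sigmoid producing "unused" is dropped; the Relu survives.
    assert [op.type for op in net.op] == ["Relu"]
    return net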
CatVTON/detectron2/export/torchscript.py
ADDED
@@ -0,0 +1,132 @@
# Copyright (c) Facebook, Inc. and its affiliates.

import os
import torch

from detectron2.utils.file_io import PathManager

from .torchscript_patch import freeze_training_mode, patch_instances

__all__ = ["scripting_with_instances", "dump_torchscript_IR"]


def scripting_with_instances(model, fields):
    """
    Run :func:`torch.jit.script` on a model that uses the :class:`Instances` class. Since
    attributes of :class:`Instances` are "dynamically" added in eager mode, it is difficult
    for scripting to support them out of the box. This function is made to support scripting
    a model that uses :class:`Instances`. It does the following:

    1. Create a scriptable ``new_Instances`` class which behaves similarly to ``Instances``,
       but with all attributes being "static".
       The attributes need to be statically declared in the ``fields`` argument.
    2. Register ``new_Instances``, and force the scripting compiler to
       use it when trying to compile ``Instances``.

    After this function, the process will be reverted. Users should be able to script another
    model using different fields.

    Example:
        Assume that ``Instances`` in the model consist of two attributes named
        ``proposal_boxes`` and ``objectness_logits`` with type :class:`Boxes` and
        :class:`Tensor` respectively during inference. You can call this function like:
        ::
            fields = {"proposal_boxes": Boxes, "objectness_logits": torch.Tensor}
            torchscript_model = scripting_with_instances(model, fields)

    Note:
        It only supports models in evaluation mode.

    Args:
        model (nn.Module): The input model to be exported by scripting.
        fields (Dict[str, type]): Attribute names and corresponding types that
            ``Instances`` will use in the model. Note that all attributes used in ``Instances``
            need to be added, regardless of whether they are inputs/outputs of the model.
            Data types not defined in detectron2 are not supported for now.

    Returns:
        torch.jit.ScriptModule: the model in torchscript format
    """
    assert (
        not model.training
    ), "Currently we only support exporting models in evaluation mode to torchscript"

    with freeze_training_mode(model), patch_instances(fields):
        scripted_model = torch.jit.script(model)
        return scripted_model


# alias for the old name
export_torchscript_with_instances = scripting_with_instances


def dump_torchscript_IR(model, dir):
    """
    Dump IR of a TracedModule/ScriptModule/Function in various formats (code, graph,
    inlined graph). Useful for debugging.

    Args:
        model (TracedModule/ScriptModule/ScriptFunction): traced or scripted module
        dir (str): output directory to dump files.
    """
    dir = os.path.expanduser(dir)
    PathManager.mkdirs(dir)

    def _get_script_mod(mod):
        if isinstance(mod, torch.jit.TracedModule):
            return mod._actual_script_module
        return mod

    # Dump pretty-printed code: https://pytorch.org/docs/stable/jit.html#inspecting-code
    with PathManager.open(os.path.join(dir, "model_ts_code.txt"), "w") as f:

        def get_code(mod):
            # Try a few ways to get code using private attributes.
            try:
                # This contains more information than just `mod.code`
                return _get_script_mod(mod)._c.code
            except AttributeError:
                pass
            try:
                return mod.code
            except AttributeError:
                return None

        def dump_code(prefix, mod):
            code = get_code(mod)
            name = prefix or "root model"
            if code is None:
                f.write(f"Could not find code for {name} (type={mod.original_name})\n")
                f.write("\n")
            else:
                f.write(f"\nCode for {name}, type={mod.original_name}:\n")
                f.write(code)
                f.write("\n")
                f.write("-" * 80)

            for name, m in mod.named_children():
                dump_code(prefix + "." + name, m)

        if isinstance(model, torch.jit.ScriptFunction):
            f.write(get_code(model))
        else:
            dump_code("", model)

    def _get_graph(model):
        try:
            # Recursively dump IR of all modules
            return _get_script_mod(model)._c.dump_to_str(True, False, False)
        except AttributeError:
            return model.graph.str()

    with PathManager.open(os.path.join(dir, "model_ts_IR.txt"), "w") as f:
        f.write(_get_graph(model))

    # Dump IR of the entire graph (all submodules inlined)
    with PathManager.open(os.path.join(dir, "model_ts_IR_inlined.txt"), "w") as f:
        f.write(str(model.inlined_graph))

    if not isinstance(model, torch.jit.ScriptFunction):
        # Dump the model structure in pytorch style
        with PathManager.open(os.path.join(dir, "model.txt"), "w") as f:
            f.write(str(model))
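

# A minimal usage sketch (illustrative, not part of the upstream file): script an
# eval-mode detectron2 model and dump its IR for inspection. The field set
# depends on the model; the names below are hypothetical examples.
def _example_export(model, output_dir="/tmp/ts_dump"):
    from detectron2.structures import Boxes

    model.eval()
    fields = {"pred_boxes": Boxes, "scores": torch.Tensor, "pred_classes": torch.Tensor}
    ts_model = scripting_with_instances(model, fields)
    # Writes model_ts_code.txt, model_ts_IR.txt, model_ts_IR_inlined.txt, model.txt
    dump_torchscript_IR(ts_model, output_dir)
    return ts_model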
CatVTON/detectron2/export/torchscript_patch.py
ADDED
@@ -0,0 +1,406 @@
# Copyright (c) Facebook, Inc. and its affiliates.

import os
import sys
import tempfile
from contextlib import ExitStack, contextmanager
from copy import deepcopy
from unittest import mock
import torch
from torch import nn

# need some explicit imports due to https://github.com/pytorch/pytorch/issues/38964
import detectron2  # noqa F401
from detectron2.structures import Boxes, Instances
from detectron2.utils.env import _import_file

_counter = 0


def _clear_jit_cache():
    from torch.jit._recursive import concrete_type_store
    from torch.jit._state import _jit_caching_layer

    concrete_type_store.type_store.clear()  # for modules
    _jit_caching_layer.clear()  # for free functions


def _add_instances_conversion_methods(newInstances):
    """
    Add a from_instances method to the scripted Instances class.
    """
    cls_name = newInstances.__name__

    @torch.jit.unused
    def from_instances(instances: Instances):
        """
        Create scripted Instances from the original Instances
        """
        fields = instances.get_fields()
        image_size = instances.image_size
        ret = newInstances(image_size)
        for name, val in fields.items():
            assert hasattr(ret, f"_{name}"), f"No attribute named {name} in {cls_name}"
            setattr(ret, name, deepcopy(val))
        return ret

    newInstances.from_instances = from_instances


@contextmanager
def patch_instances(fields):
    """
    A contextmanager, under which the Instances class in detectron2 is replaced
    by a statically-typed scriptable class, defined by `fields`.
    See more in `scripting_with_instances`.
    """

    with tempfile.TemporaryDirectory(prefix="detectron2") as dir, tempfile.NamedTemporaryFile(
        mode="w", encoding="utf-8", suffix=".py", dir=dir, delete=False
    ) as f:
        try:
            # Objects that use Instances should not reuse previously-compiled
            # results in cache, because `Instances` could be a new class each time.
            _clear_jit_cache()

            cls_name, s = _gen_instance_module(fields)
            f.write(s)
            f.flush()
            f.close()

            module = _import(f.name)
            new_instances = getattr(module, cls_name)
            _ = torch.jit.script(new_instances)
            # let torchscript think Instances was scripted already
            Instances.__torch_script_class__ = True
            # let torchscript find new_instances when looking for the jit type of Instances
            Instances._jit_override_qualname = torch._jit_internal._qualified_name(new_instances)

            _add_instances_conversion_methods(new_instances)
            yield new_instances
        finally:
            try:
                del Instances.__torch_script_class__
                del Instances._jit_override_qualname
            except AttributeError:
                pass
            sys.modules.pop(module.__name__)


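# Illustrative sketch (not part of the upstream file): inside the context, the
# statically-typed replacement class can be scripted and converted to and from
# eager Instances. The field names here are hypothetical.
def _example_patch_instances():
    fields = {"proposal_boxes": Boxes, "objectness_logits": torch.Tensor}
    with patch_instances(fields) as NewInstances:
        inst = Instances((480, 640))
        inst.proposal_boxes = Boxes(torch.rand(3, 4))
        inst.objectness_logits = torch.rand(3)
        # Convert an eager Instances into its scriptable counterpart.
        scripted_inst = NewInstances.from_instances(inst)
        assert len(scripted_inst) == 3
    return fields

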
def _gen_instance_class(fields):
    """
    Args:
        fields (dict[name: type])
    """

    class _FieldType:
        def __init__(self, name, type_):
            assert isinstance(name, str), f"Field name must be str, got {name}"
            self.name = name
            self.type_ = type_
            self.annotation = f"{type_.__module__}.{type_.__name__}"

    fields = [_FieldType(k, v) for k, v in fields.items()]

    def indent(level, s):
        return " " * 4 * level + s

    lines = []

    global _counter
    _counter += 1

    cls_name = "ScriptedInstances{}".format(_counter)

    field_names = tuple(x.name for x in fields)
    extra_args = ", ".join([f"{f.name}: Optional[{f.annotation}] = None" for f in fields])
    lines.append(
        f"""
class {cls_name}:
    def __init__(self, image_size: Tuple[int, int], {extra_args}):
        self.image_size = image_size
        self._field_names = {field_names}
"""
    )

    for f in fields:
        lines.append(
            indent(2, f"self._{f.name} = torch.jit.annotate(Optional[{f.annotation}], {f.name})")
        )

    for f in fields:
        lines.append(
            f"""
    @property
    def {f.name}(self) -> {f.annotation}:
        # has to use a local for type refinement
        # https://pytorch.org/docs/stable/jit_language_reference.html#optional-type-refinement
        t = self._{f.name}
        assert t is not None, "{f.name} is None and cannot be accessed!"
        return t

    @{f.name}.setter
    def {f.name}(self, value: {f.annotation}) -> None:
        self._{f.name} = value
"""
        )

    # support method `__len__`
    lines.append(
        """
    def __len__(self) -> int:
"""
    )
    for f in fields:
        lines.append(
            f"""
        t = self._{f.name}
        if t is not None:
            return len(t)
"""
        )
    lines.append(
        """
        raise NotImplementedError("Empty Instances does not support __len__!")
"""
    )

    # support method `has`
    lines.append(
        """
    def has(self, name: str) -> bool:
"""
    )
    for f in fields:
        lines.append(
            f"""
        if name == "{f.name}":
            return self._{f.name} is not None
"""
        )
    lines.append(
        """
        return False
"""
    )

    # support method `to`
    none_args = ", None" * len(fields)
    lines.append(
        f"""
    def to(self, device: torch.device) -> "{cls_name}":
        ret = {cls_name}(self.image_size{none_args})
"""
    )
    for f in fields:
        if hasattr(f.type_, "to"):
            lines.append(
                f"""
        t = self._{f.name}
        if t is not None:
            ret._{f.name} = t.to(device)
"""
            )
        else:
            # For now, ignore fields that cannot be moved to devices.
            # Maybe can support other tensor-like classes (e.g. __torch_function__)
            pass
    lines.append(
        """
        return ret
"""
    )

    # support method `__getitem__`
    none_args = ", None" * len(fields)
    lines.append(
        f"""
    def __getitem__(self, item) -> "{cls_name}":
        ret = {cls_name}(self.image_size{none_args})
"""
    )
    for f in fields:
        lines.append(
            f"""
        t = self._{f.name}
        if t is not None:
            ret._{f.name} = t[item]
"""
        )
    lines.append(
        """
        return ret
"""
    )

    # support method `cat`
    # this version does not contain checks that all instances have same size and fields
    none_args = ", None" * len(fields)
    lines.append(
        f"""
    def cat(self, instances: List["{cls_name}"]) -> "{cls_name}":
        ret = {cls_name}(self.image_size{none_args})
"""
    )
    for f in fields:
        lines.append(
            f"""
        t = self._{f.name}
        if t is not None:
            values: List[{f.annotation}] = [x.{f.name} for x in instances]
            if torch.jit.isinstance(t, torch.Tensor):
                ret._{f.name} = torch.cat(values, dim=0)
            else:
                ret._{f.name} = t.cat(values)
"""
        )
    lines.append(
        """
        return ret"""
    )

    # support method `get_fields()`
    lines.append(
        """
    def get_fields(self) -> Dict[str, Tensor]:
        ret = {}
"""
    )
    for f in fields:
        if f.type_ == Boxes:
            stmt = "t.tensor"
        elif f.type_ == torch.Tensor:
            stmt = "t"
        else:
            stmt = f'assert False, "unsupported type {str(f.type_)}"'
        lines.append(
            f"""
        t = self._{f.name}
        if t is not None:
            ret["{f.name}"] = {stmt}
"""
        )
    lines.append(
        """
        return ret"""
    )
    return cls_name, os.linesep.join(lines)


def _gen_instance_module(fields):
    # TODO: find a more automatic way to enable import of other classes
    s = """
from copy import deepcopy
import torch
from torch import Tensor
import typing
from typing import *

import detectron2
from detectron2.structures import Boxes, Instances

"""

    cls_name, cls_def = _gen_instance_class(fields)
    s += cls_def
    return cls_name, s


def _import(path):
    return _import_file(
        "{}{}".format(sys.modules[__name__].__name__, _counter), path, make_importable=True
    )


@contextmanager
def patch_builtin_len(modules=()):
    """
    Patch the builtin len() function of a few detectron2 modules
    to use __len__ instead, because __len__ does not convert values to
    integers and therefore is friendly to tracing.

    Args:
        modules (list[str]): names of extra modules to patch len(), in
            addition to those in detectron2.
    """

    def _new_len(obj):
        return obj.__len__()

    with ExitStack() as stack:
        MODULES = [
            "detectron2.modeling.roi_heads.fast_rcnn",
            "detectron2.modeling.roi_heads.mask_head",
            "detectron2.modeling.roi_heads.keypoint_head",
        ] + list(modules)
        ctxs = [stack.enter_context(mock.patch(mod + ".len")) for mod in MODULES]
        for m in ctxs:
            m.side_effect = _new_len
        yield


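# A minimal usage sketch (illustrative, not part of the upstream file): trace a
# model under the patch so len() calls inside the listed modules stay symbolic.
# `model` and its example `inputs` are assumed to be provided by the caller.
def _example_trace_with_patched_len(model, inputs):
    with patch_builtin_len():
        traced = torch.jit.trace(model, inputs)
    return traced

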
def patch_nonscriptable_classes():
    """
    Apply patches on a few nonscriptable detectron2 classes.
    Should not have side-effects on eager usage.
    """
    # __prepare_scriptable__ can also be added to models for easier maintenance.
    # But it complicates the clean model code.

    from detectron2.modeling.backbone import ResNet, FPN

    # Due to https://github.com/pytorch/pytorch/issues/36061,
    # we change the backbones to use ModuleList for scripting.
    # (note: this changes param names in state_dict)

    def prepare_resnet(self):
        ret = deepcopy(self)
        ret.stages = nn.ModuleList(ret.stages)
        for k in self.stage_names:
            delattr(ret, k)
        return ret

    ResNet.__prepare_scriptable__ = prepare_resnet

    def prepare_fpn(self):
        ret = deepcopy(self)
        ret.lateral_convs = nn.ModuleList(ret.lateral_convs)
        ret.output_convs = nn.ModuleList(ret.output_convs)
        for name, _ in self.named_children():
            if name.startswith("fpn_"):
                delattr(ret, name)
        return ret

    FPN.__prepare_scriptable__ = prepare_fpn

    # Annotate some attributes to be constants for the purpose of scripting,
    # even though they are not constants in eager mode.
    from detectron2.modeling.roi_heads import StandardROIHeads

    if hasattr(StandardROIHeads, "__annotations__"):
        # copy first to avoid editing annotations of the base class
        StandardROIHeads.__annotations__ = deepcopy(StandardROIHeads.__annotations__)
        StandardROIHeads.__annotations__["mask_on"] = torch.jit.Final[bool]
        StandardROIHeads.__annotations__["keypoint_on"] = torch.jit.Final[bool]


# These patches are not supposed to have side-effects.
patch_nonscriptable_classes()


@contextmanager
def freeze_training_mode(model):
    """
    A context manager that annotates the "training" attribute of every submodule
    as constant, so that the training codepath in these modules can be
    meta-compiled away. Upon exiting, the annotations are reverted.
    """
    classes = {type(x) for x in model.modules()}
    # __constants__ is the old way to annotate constants and is not compatible
    # with __annotations__ .
    classes = {x for x in classes if not hasattr(x, "__constants__")}
    for cls in classes:
        cls.__annotations__["training"] = torch.jit.Final[bool]
    yield
    for cls in classes:
        cls.__annotations__["training"] = bool
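

# Illustrative sketch (not part of the upstream file): scripting a small module
# under freeze_training_mode marks `training` as Final, letting TorchScript fold
# away the training-only branch. The module below is hypothetical.
def _example_freeze_training_mode():
    class TinyHead(nn.Module):
        def forward(self, x):
            if self.training:
                return x * 2
            return x

    m = TinyHead().eval()
    with freeze_training_mode(m):
        scripted = torch.jit.script(m)
    # Only the eval-mode branch remains after compilation.
    assert torch.equal(scripted(torch.ones(2)), torch.ones(2))
    return scripted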