AnishaNaik03 commited on
Commit
c667f42
·
verified ·
1 Parent(s): 697091d

Upload 30 files

Browse files
detectron2/utils/README.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Utility functions
2
+
3
+ This folder contains utility functions that are not used in the
4
+ core library, but are useful for building models or training
5
+ code using the config system.
detectron2/utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
detectron2/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (171 Bytes). View file
 
detectron2/utils/__pycache__/collect_env.cpython-311.pyc ADDED
Binary file (14.7 kB). View file
 
detectron2/utils/__pycache__/comm.cpython-311.pyc ADDED
Binary file (10.5 kB). View file
 
detectron2/utils/__pycache__/develop.cpython-311.pyc ADDED
Binary file (3.01 kB). View file
 
detectron2/utils/__pycache__/env.cpython-311.pyc ADDED
Binary file (8.3 kB). View file
 
detectron2/utils/__pycache__/events.cpython-311.pyc ADDED
Binary file (26.9 kB). View file
 
detectron2/utils/__pycache__/file_io.cpython-311.pyc ADDED
Binary file (2.11 kB). View file
 
detectron2/utils/__pycache__/logger.cpython-311.pyc ADDED
Binary file (12 kB). View file
 
detectron2/utils/__pycache__/memory.cpython-311.pyc ADDED
Binary file (4.38 kB). View file
 
detectron2/utils/__pycache__/registry.cpython-311.pyc ADDED
Binary file (2.06 kB). View file
 
detectron2/utils/__pycache__/serialize.cpython-311.pyc ADDED
Binary file (1.92 kB). View file
 
detectron2/utils/__pycache__/tracing.cpython-311.pyc ADDED
Binary file (3.72 kB). View file
 
detectron2/utils/analysis.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import typing
5
+ from typing import Any, List
6
+ import fvcore
7
+ from fvcore.nn import activation_count, flop_count, parameter_count, parameter_count_table
8
+ from torch import nn
9
+
10
+ from detectron2.export import TracingAdapter
11
+
12
+ __all__ = [
13
+ "activation_count_operators",
14
+ "flop_count_operators",
15
+ "parameter_count_table",
16
+ "parameter_count",
17
+ "FlopCountAnalysis",
18
+ ]
19
+
20
+ FLOPS_MODE = "flops"
21
+ ACTIVATIONS_MODE = "activations"
22
+
23
+
24
+ # Some extra ops to ignore from counting, including elementwise and reduction ops
25
+ _IGNORED_OPS = {
26
+ "aten::add",
27
+ "aten::add_",
28
+ "aten::argmax",
29
+ "aten::argsort",
30
+ "aten::batch_norm",
31
+ "aten::constant_pad_nd",
32
+ "aten::div",
33
+ "aten::div_",
34
+ "aten::exp",
35
+ "aten::log2",
36
+ "aten::max_pool2d",
37
+ "aten::meshgrid",
38
+ "aten::mul",
39
+ "aten::mul_",
40
+ "aten::neg",
41
+ "aten::nonzero_numpy",
42
+ "aten::reciprocal",
43
+ "aten::repeat_interleave",
44
+ "aten::rsub",
45
+ "aten::sigmoid",
46
+ "aten::sigmoid_",
47
+ "aten::softmax",
48
+ "aten::sort",
49
+ "aten::sqrt",
50
+ "aten::sub",
51
+ "torchvision::nms", # TODO estimate flop for nms
52
+ }
53
+
54
+
55
class FlopCountAnalysis(fvcore.nn.FlopCountAnalysis):
    """
    Same as :class:`fvcore.nn.FlopCountAnalysis`, but supports detectron2 models.
    """

    def __init__(self, model, inputs):
        """
        Args:
            model (nn.Module):
            inputs (Any): inputs of the given model. Does not have to be tuple of tensors.
        """
        # TracingAdapter flattens arbitrary (possibly non-tensor) inputs into
        # a tuple of tensors that jit tracing in the parent class can consume.
        wrapper = TracingAdapter(model, inputs, allow_non_tensor=True)
        super().__init__(wrapper, wrapper.flattened_inputs)
        # Register a None handle for each cheap elementwise/reduction op in
        # _IGNORED_OPS so they are excluded from the flop count.
        self.set_op_handle(**{k: None for k in _IGNORED_OPS})
69
+
70
+
71
def flop_count_operators(model: nn.Module, inputs: list) -> typing.DefaultDict[str, float]:
    """
    Implement operator-level flops counting using jit.
    This is a wrapper of :func:`fvcore.nn.flop_count` and adds supports for standard
    detection models in detectron2.
    Please use :class:`FlopCountAnalysis` for more advanced functionalities.

    Note:
        The function runs the input through the model to compute flops.
        The flops of a detection model is often input-dependent, for example,
        the flops of box & mask head depends on the number of proposals &
        the number of detected objects.
        Therefore, the flops counting using a single input may not accurately
        reflect the computation cost of a model. It's recommended to average
        across a number of inputs.

    Args:
        model: a detectron2 model that takes `list[dict]` as input.
        inputs (list[dict]): inputs to model, in detectron2's standard format.
            Only "image" key will be used.

    Returns:
        Counter: Gflop count per operator
    """
    # Count in eval mode, then restore the caller's training flag.
    old_train = model.training
    model.eval()
    ret = FlopCountAnalysis(model, inputs).by_operator()
    model.train(old_train)
    # by_operator() reports raw flop counts; convert to Gflops.
    return {k: v / 1e9 for k, v in ret.items()}
101
+
102
+
103
def activation_count_operators(
    model: nn.Module, inputs: list, **kwargs
) -> typing.DefaultDict[str, float]:
    """
    Implement operator-level activations counting using jit.
    This is a wrapper of fvcore.nn.activation_count, that supports standard detection models
    in detectron2.

    Note:
        The function runs the input through the model to compute activations.
        The activations of a detection model is often input-dependent, for example,
        the activations of box & mask head depends on the number of proposals &
        the number of detected objects.

    Args:
        model: a detectron2 model that takes `list[dict]` as input.
        inputs (list[dict]): inputs to model, in detectron2's standard format.
            Only "image" key will be used.
        **kwargs: forwarded to :func:`fvcore.nn.activation_count`
            (e.g. ``supported_ops``).

    Returns:
        Counter: activation count per operator
    """
    return _wrapper_count_operators(model=model, inputs=inputs, mode=ACTIVATIONS_MODE, **kwargs)
126
+
127
+
128
def _wrapper_count_operators(
    model: nn.Module, inputs: list, mode: str, **kwargs
) -> typing.DefaultDict[str, float]:
    """
    Shared implementation for operator-level flop/activation counting.

    Traces ``model`` on a single detectron2-format input and dispatches to the
    fvcore counter selected by ``mode`` (FLOPS_MODE or ACTIVATIONS_MODE).
    """
    # ignore some ops: register a handle returning an empty count for every
    # op in _IGNORED_OPS; caller-supplied handles take precedence.
    supported_ops = {k: lambda *args, **kwargs: {} for k in _IGNORED_OPS}
    supported_ops.update(kwargs.pop("supported_ops", {}))
    kwargs["supported_ops"] = supported_ops

    assert len(inputs) == 1, "Please use batch size=1"
    tensor_input = inputs[0]["image"]
    inputs = [{"image": tensor_input}]  # remove other keys, in case there are any

    old_train = model.training
    # Unwrap (Distributed)DataParallel so tracing sees the actual module.
    if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)):
        model = model.module
    wrapper = TracingAdapter(model, inputs)
    wrapper.eval()
    if mode == FLOPS_MODE:
        ret = flop_count(wrapper, (tensor_input,), **kwargs)
    elif mode == ACTIVATIONS_MODE:
        ret = activation_count(wrapper, (tensor_input,), **kwargs)
    else:
        raise NotImplementedError("Count for mode {} is not supported yet.".format(mode))
    # compatible with change in fvcore: some versions return a tuple whose
    # first element is the counter — keep only the counter.
    if isinstance(ret, tuple):
        ret = ret[0]
    model.train(old_train)
    return ret
156
+
157
+
158
def find_unused_parameters(model: nn.Module, inputs: Any) -> List[str]:
    """
    Given a model, find parameters that do not contribute
    to the loss.

    Args:
        model: a model in training mode that returns losses
        inputs: argument or a tuple of arguments. Inputs of the model

    Returns:
        list[str]: the name of unused parameters
    """
    assert model.training
    # Clear any stale gradients so "no grad after backward" is meaningful.
    for _, param in model.named_parameters():
        param.grad = None

    # Forward pass: a tuple is splatted as positional args, anything else is
    # passed as the single argument.
    losses = model(*inputs) if isinstance(inputs, tuple) else model(inputs)

    if isinstance(losses, dict):
        losses = sum(losses.values())
    losses.backward()

    # A parameter that received no gradient did not contribute to the loss.
    unused_names: List[str] = []
    for name, param in model.named_parameters():
        if param.grad is None:
            unused_names.append(name)
        param.grad = None  # leave the model gradient-free, as we found it
    return unused_names
detectron2/utils/collect_env.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import importlib
3
+ import numpy as np
4
+ import os
5
+ import re
6
+ import subprocess
7
+ import sys
8
+ from collections import defaultdict
9
+ import PIL
10
+ import torch
11
+ import torchvision
12
+ from tabulate import tabulate
13
+
14
+ __all__ = ["collect_env_info"]
15
+
16
+
17
def collect_torch_env():
    """Return a human-readable description of the PyTorch build configuration."""
    try:
        import torch.__config__ as torch_config
    except ImportError:
        # compatible with older versions of pytorch, which lack torch.__config__
        from torch.utils.collect_env import get_pretty_env_info

        return get_pretty_env_info()
    return torch_config.show()
27
+
28
+
29
def get_env_module():
    """Return ("DETECTRON2_ENV_MODULE", its value or "<not set>") for the env report."""
    key = "DETECTRON2_ENV_MODULE"
    value = os.environ.get(key, "<not set>")
    return key, value
32
+
33
+
34
def detect_compute_compatibility(CUDA_HOME, so_file):
    """
    Use cuobjdump (found under CUDA_HOME) to list the SM architectures a
    shared library was compiled for.

    Returns:
        str: comma-separated architectures (e.g. "7.0, 7.5"), or the library
        path itself (possibly annotated) when detection is not possible.
    """
    try:
        cuobjdump = os.path.join(CUDA_HOME, "bin", "cuobjdump")
        if not os.path.isfile(cuobjdump):
            return so_file + "; cannot find cuobjdump"
        raw = subprocess.check_output(
            "'{}' --list-elf '{}'".format(cuobjdump, so_file), shell=True
        )
        archs = []
        for elf_line in raw.decode("utf-8").strip().split("\n"):
            sm = re.findall(r"\.sm_([0-9]*)\.", elf_line)[0]
            archs.append(".".join(sm))  # e.g. "75" -> "7.5"
        return ", ".join(sorted(set(archs)))
    except Exception:
        # unhandled failure — fall back to just naming the file
        return so_file
53
+
54
+
55
def collect_env_info():
    """
    Collect version/build information of the environment (Python, PyTorch,
    CUDA/ROCm, detectron2, torchvision, ...) into a human-readable table
    followed by the PyTorch build config string.
    """
    has_gpu = torch.cuda.is_available()  # true for both CUDA & ROCM
    torch_version = torch.__version__

    # NOTE that CUDA_HOME/ROCM_HOME could be None even when CUDA runtime libs are functional
    from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME

    has_rocm = False
    if (getattr(torch.version, "hip", None) is not None) and (ROCM_HOME is not None):
        has_rocm = True
    has_cuda = has_gpu and (not has_rocm)

    # (label, value) rows rendered with tabulate() at the end
    data = []
    data.append(("sys.platform", sys.platform))  # check-template.yml depends on it
    data.append(("Python", sys.version.replace("\n", "")))
    data.append(("numpy", np.__version__))

    try:
        import detectron2  # noqa

        data.append(
            (
                "detectron2",
                detectron2.__version__ + " @" + os.path.dirname(detectron2.__file__),
            )
        )
    except ImportError:
        data.append(("detectron2", "failed to import"))
    except AttributeError:
        data.append(("detectron2", "imported a wrong installation"))

    try:
        import detectron2._C as _C
    except ImportError as e:
        data.append(("detectron2._C", f"not built correctly: {e}"))

        # print system compilers when extension fails to build
        if sys.platform != "win32":  # don't know what to do for windows
            try:
                # this is how torch/utils/cpp_extensions.py choose compiler
                cxx = os.environ.get("CXX", "c++")
                cxx = subprocess.check_output("'{}' --version".format(cxx), shell=True)
                cxx = cxx.decode("utf-8").strip().split("\n")[0]
            except subprocess.SubprocessError:
                cxx = "Not found"
            data.append(("Compiler ($CXX)", cxx))

            if has_cuda and CUDA_HOME is not None:
                try:
                    nvcc = os.path.join(CUDA_HOME, "bin", "nvcc")
                    nvcc = subprocess.check_output("'{}' -V".format(nvcc), shell=True)
                    nvcc = nvcc.decode("utf-8").strip().split("\n")[-1]
                except subprocess.SubprocessError:
                    nvcc = "Not found"
                data.append(("CUDA compiler", nvcc))
        if has_cuda and sys.platform != "win32":
            try:
                so_file = importlib.util.find_spec("detectron2._C").origin
            except (ImportError, AttributeError):
                pass
            else:
                data.append(
                    (
                        "detectron2 arch flags",
                        detect_compute_compatibility(CUDA_HOME, so_file),
                    )
                )
    else:
        # print compilers that are used to build extension
        data.append(("Compiler", _C.get_compiler_version()))
        data.append(("CUDA compiler", _C.get_cuda_version()))  # cuda or hip
        if has_cuda and getattr(_C, "has_cuda", lambda: True)():
            data.append(
                (
                    "detectron2 arch flags",
                    detect_compute_compatibility(CUDA_HOME, _C.__file__),
                )
            )

    data.append(get_env_module())
    data.append(("PyTorch", torch_version + " @" + os.path.dirname(torch.__file__)))
    data.append(("PyTorch debug build", torch.version.debug))
    try:
        data.append(("torch._C._GLIBCXX_USE_CXX11_ABI", torch._C._GLIBCXX_USE_CXX11_ABI))
    except Exception:
        pass

    if not has_gpu:
        has_gpu_text = "No: torch.cuda.is_available() == False"
    else:
        has_gpu_text = "Yes"
    data.append(("GPU available", has_gpu_text))
    if has_gpu:
        # group device ids by identical device name+capability
        devices = defaultdict(list)
        for k in range(torch.cuda.device_count()):
            cap = ".".join((str(x) for x in torch.cuda.get_device_capability(k)))
            name = torch.cuda.get_device_name(k) + f" (arch={cap})"
            devices[name].append(str(k))
        for name, devids in devices.items():
            data.append(("GPU " + ",".join(devids), name))

        if has_rocm:
            msg = " - invalid!" if not (ROCM_HOME and os.path.isdir(ROCM_HOME)) else ""
            data.append(("ROCM_HOME", str(ROCM_HOME) + msg))
        else:
            try:
                from torch.utils.collect_env import (
                    get_nvidia_driver_version,
                    run as _run,
                )

                data.append(("Driver version", get_nvidia_driver_version(_run)))
            except Exception:
                pass
            msg = " - invalid!" if not (CUDA_HOME and os.path.isdir(CUDA_HOME)) else ""
            data.append(("CUDA_HOME", str(CUDA_HOME) + msg))

            cuda_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
            if cuda_arch_list:
                data.append(("TORCH_CUDA_ARCH_LIST", cuda_arch_list))
    data.append(("Pillow", PIL.__version__))

    try:
        data.append(
            (
                "torchvision",
                str(torchvision.__version__) + " @" + os.path.dirname(torchvision.__file__),
            )
        )
        if has_cuda:
            try:
                torchvision_C = importlib.util.find_spec("torchvision._C").origin
                msg = detect_compute_compatibility(CUDA_HOME, torchvision_C)
                data.append(("torchvision arch flags", msg))
            except (ImportError, AttributeError):
                data.append(("torchvision._C", "Not found"))
    except AttributeError:
        data.append(("torchvision", "unknown"))

    try:
        import fvcore

        data.append(("fvcore", fvcore.__version__))
    except (ImportError, AttributeError):
        pass

    try:
        import iopath

        data.append(("iopath", iopath.__version__))
    except (ImportError, AttributeError):
        pass

    try:
        import cv2

        data.append(("cv2", cv2.__version__))
    except (ImportError, AttributeError):
        data.append(("cv2", "Not found"))
    env_str = tabulate(data) + "\n"
    env_str += collect_torch_env()
    return env_str
217
+
218
+
219
def test_nccl_ops():
    """Spawn one process per GPU and check that an NCCL barrier completes."""
    num_gpu = torch.cuda.device_count()
    # file:// rendezvous needs a writable path shared by all local workers
    if os.access("/tmp", os.W_OK):
        import torch.multiprocessing as mp

        dist_url = "file:///tmp/nccl_tmp_file"
        print("Testing NCCL connectivity ... this should not hang.")
        mp.spawn(_test_nccl_worker, nprocs=num_gpu, args=(num_gpu, dist_url), daemon=False)
        print("NCCL succeeded.")
228
+
229
+
230
def _test_nccl_worker(rank, num_gpu, dist_url):
    """Per-process worker for :func:`test_nccl_ops`: join the group, then barrier."""
    import torch.distributed as dist

    dist.init_process_group(backend="NCCL", init_method=dist_url, rank=rank, world_size=num_gpu)
    dist.barrier(device_ids=[rank])
235
+
236
+
237
def main() -> None:
    """
    Print the environment report and run a basic per-GPU CUDA sanity check.

    Prefers the installed detectron2's collect_env_info; falls back to the
    local definition when detectron2 is not importable.
    """
    try:
        from detectron2.utils.collect_env import collect_env_info as f

        print(f())
    except ImportError:
        print(collect_env_info())

    if torch.cuda.is_available():
        num_gpu = torch.cuda.device_count()
        for k in range(num_gpu):
            device = f"cuda:{k}"
            try:
                # Moving a small tensor exercises the CUDA context on each device.
                # (The pointless `global x` declaration was removed: x is a loop
                # temporary and should not leak to module scope.)
                x = torch.tensor([1, 2.0], dtype=torch.float32)
                x = x.to(device)
            except Exception as e:
                print(
                    f"Unable to copy tensor to device={device}: {e}. "
                    "Your CUDA environment is broken."
                )
        if num_gpu > 1:
            test_nccl_ops()


if __name__ == "__main__":
    main()  # pragma: no cover
detectron2/utils/colormap.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ """
4
+ An awesome colormap for really neat visualizations.
5
+ Copied from Detectron, and removed gray colors.
6
+ """
7
+
8
+ import numpy as np
9
+ import random
10
+
11
+ __all__ = ["colormap", "random_color", "random_colors"]
12
+
13
+ # fmt: off
14
+ # RGB:
15
+ _COLORS = np.array(
16
+ [
17
+ 0.000, 0.447, 0.741,
18
+ 0.850, 0.325, 0.098,
19
+ 0.929, 0.694, 0.125,
20
+ 0.494, 0.184, 0.556,
21
+ 0.466, 0.674, 0.188,
22
+ 0.301, 0.745, 0.933,
23
+ 0.635, 0.078, 0.184,
24
+ 0.300, 0.300, 0.300,
25
+ 0.600, 0.600, 0.600,
26
+ 1.000, 0.000, 0.000,
27
+ 1.000, 0.500, 0.000,
28
+ 0.749, 0.749, 0.000,
29
+ 0.000, 1.000, 0.000,
30
+ 0.000, 0.000, 1.000,
31
+ 0.667, 0.000, 1.000,
32
+ 0.333, 0.333, 0.000,
33
+ 0.333, 0.667, 0.000,
34
+ 0.333, 1.000, 0.000,
35
+ 0.667, 0.333, 0.000,
36
+ 0.667, 0.667, 0.000,
37
+ 0.667, 1.000, 0.000,
38
+ 1.000, 0.333, 0.000,
39
+ 1.000, 0.667, 0.000,
40
+ 1.000, 1.000, 0.000,
41
+ 0.000, 0.333, 0.500,
42
+ 0.000, 0.667, 0.500,
43
+ 0.000, 1.000, 0.500,
44
+ 0.333, 0.000, 0.500,
45
+ 0.333, 0.333, 0.500,
46
+ 0.333, 0.667, 0.500,
47
+ 0.333, 1.000, 0.500,
48
+ 0.667, 0.000, 0.500,
49
+ 0.667, 0.333, 0.500,
50
+ 0.667, 0.667, 0.500,
51
+ 0.667, 1.000, 0.500,
52
+ 1.000, 0.000, 0.500,
53
+ 1.000, 0.333, 0.500,
54
+ 1.000, 0.667, 0.500,
55
+ 1.000, 1.000, 0.500,
56
+ 0.000, 0.333, 1.000,
57
+ 0.000, 0.667, 1.000,
58
+ 0.000, 1.000, 1.000,
59
+ 0.333, 0.000, 1.000,
60
+ 0.333, 0.333, 1.000,
61
+ 0.333, 0.667, 1.000,
62
+ 0.333, 1.000, 1.000,
63
+ 0.667, 0.000, 1.000,
64
+ 0.667, 0.333, 1.000,
65
+ 0.667, 0.667, 1.000,
66
+ 0.667, 1.000, 1.000,
67
+ 1.000, 0.000, 1.000,
68
+ 1.000, 0.333, 1.000,
69
+ 1.000, 0.667, 1.000,
70
+ 0.333, 0.000, 0.000,
71
+ 0.500, 0.000, 0.000,
72
+ 0.667, 0.000, 0.000,
73
+ 0.833, 0.000, 0.000,
74
+ 1.000, 0.000, 0.000,
75
+ 0.000, 0.167, 0.000,
76
+ 0.000, 0.333, 0.000,
77
+ 0.000, 0.500, 0.000,
78
+ 0.000, 0.667, 0.000,
79
+ 0.000, 0.833, 0.000,
80
+ 0.000, 1.000, 0.000,
81
+ 0.000, 0.000, 0.167,
82
+ 0.000, 0.000, 0.333,
83
+ 0.000, 0.000, 0.500,
84
+ 0.000, 0.000, 0.667,
85
+ 0.000, 0.000, 0.833,
86
+ 0.000, 0.000, 1.000,
87
+ 0.000, 0.000, 0.000,
88
+ 0.143, 0.143, 0.143,
89
+ 0.857, 0.857, 0.857,
90
+ 1.000, 1.000, 1.000
91
+ ]
92
+ ).astype(np.float32).reshape(-1, 3)
93
+ # fmt: on
94
+
95
+
96
def colormap(rgb=False, maximum=255):
    """
    Args:
        rgb (bool): whether to return RGB colors or BGR colors.
        maximum (int): either 255 or 1

    Returns:
        ndarray: a float32 array of Nx3 colors, in range [0, 255] or [0, 1]
    """
    assert maximum in [255, 1], maximum
    colors = _COLORS * maximum
    # The palette is stored as RGB; flip the channel axis for BGR output.
    return colors if rgb else colors[:, ::-1]
110
+
111
+
112
def random_color(rgb=False, maximum=255):
    """
    Pick one palette color uniformly at random.

    Args:
        rgb (bool): whether to return RGB colors or BGR colors.
        maximum (int): either 255 or 1

    Returns:
        ndarray: a vector of 3 numbers
    """
    choice = np.random.randint(0, len(_COLORS))
    color = _COLORS[choice] * maximum
    if not rgb:
        color = color[::-1]  # RGB -> BGR
    return color
126
+
127
+
128
def random_colors(N, rgb=False, maximum=255):
    """
    Pick N distinct palette colors at random.

    Args:
        N (int): number of unique colors needed
        rgb (bool): whether to return RGB colors or BGR colors.
        maximum (int): either 255 or 1

    Returns:
        ndarray: a list of random_color
    """
    chosen = random.sample(range(len(_COLORS)), N)
    colors = [_COLORS[i] * maximum for i in chosen]
    if rgb:
        return colors
    return [c[::-1] for c in colors]  # RGB -> BGR
143
+
144
+
145
if __name__ == "__main__":
    # Visual smoke test: tile the palette onto a 10x10 grid of 100px squares
    # over random noise and display it with OpenCV.
    import cv2

    size = 100
    H, W = 10, 10
    canvas = np.random.rand(H * size, W * size, 3).astype("float32")
    for h in range(H):
        for w in range(W):
            idx = h * W + w
            if idx >= len(_COLORS):
                break
            canvas[h * size : (h + 1) * size, w * size : (w + 1) * size] = _COLORS[idx]
    cv2.imshow("a", canvas)
    cv2.waitKey(0)
detectron2/utils/comm.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ """
3
+ This file contains primitives for multi-gpu communication.
4
+ This is useful when doing distributed training.
5
+ """
6
+
7
+ import functools
8
+ import numpy as np
9
+ import torch
10
+ import torch.distributed as dist
11
+
12
+ _LOCAL_PROCESS_GROUP = None
13
+ _MISSING_LOCAL_PG_ERROR = (
14
+ "Local process group is not yet created! Please use detectron2's `launch()` "
15
+ "to start processes and initialize pytorch process group. If you need to start "
16
+ "processes in other ways, please call comm.create_local_process_group("
17
+ "num_workers_per_machine) after calling torch.distributed.init_process_group()."
18
+ )
19
+
20
+
21
def get_world_size() -> int:
    """Total number of processes in the job; 1 when torch.distributed is unused."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
27
+
28
+
29
def get_rank() -> int:
    """Global rank of this process; 0 when torch.distributed is unused."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
35
+
36
+
37
@functools.lru_cache()
def create_local_process_group(num_workers_per_machine: int) -> None:
    """
    Create a process group that contains ranks within the same machine.

    Detectron2's launch() in engine/launch.py will call this function. If you start
    workers without launch(), you'll have to also call this. Otherwise utilities
    like `get_local_rank()` will not work.

    This function contains a barrier. All processes must call it together.

    Args:
        num_workers_per_machine: the number of worker processes per machine. Typically
          the number of GPUs.
    """
    global _LOCAL_PROCESS_GROUP
    assert _LOCAL_PROCESS_GROUP is None
    assert get_world_size() % num_workers_per_machine == 0
    num_machines = get_world_size() // num_workers_per_machine
    machine_rank = get_rank() // num_workers_per_machine
    # new_group() is collective: every rank creates every machine's group,
    # but each rank keeps only the group it belongs to.
    for i in range(num_machines):
        ranks_on_i = list(range(i * num_workers_per_machine, (i + 1) * num_workers_per_machine))
        pg = dist.new_group(ranks_on_i)
        if i == machine_rank:
            _LOCAL_PROCESS_GROUP = pg
62
+
63
+
64
def get_local_process_group():
    """
    Returns:
        A torch process group which only includes processes that are on the same
        machine as the current process. This group can be useful for communication
        within a machine, e.g. a per-machine SyncBN.
    """
    # Populated by create_local_process_group(); fails loudly if that was skipped.
    assert _LOCAL_PROCESS_GROUP is not None, _MISSING_LOCAL_PG_ERROR
    return _LOCAL_PROCESS_GROUP
73
+
74
+
75
def get_local_rank() -> int:
    """
    Returns:
        The rank of the current process within the local (per-machine) process group.
    """
    # Not a distributed job -> single process, local rank 0.
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    assert _LOCAL_PROCESS_GROUP is not None, _MISSING_LOCAL_PG_ERROR
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
86
+
87
+
88
def get_local_size() -> int:
    """
    Returns:
        The size of the per-machine process group,
        i.e. the number of processes per machine.
    """
    # Not a distributed job -> a "machine" of exactly one process.
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    assert _LOCAL_PROCESS_GROUP is not None, _MISSING_LOCAL_PG_ERROR
    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
100
+
101
+
102
def is_main_process() -> bool:
    """Whether this process has global rank 0 (the "main" process)."""
    rank = get_rank()
    return rank == 0
104
+
105
+
106
def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    # No-op outside of an initialized multi-process job.
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    if dist.get_backend() == dist.Backend.NCCL:
        # This argument is needed to avoid warnings.
        # It's valid only for NCCL backend.
        dist.barrier(device_ids=[torch.cuda.current_device()])
    else:
        dist.barrier()
124
+
125
+
126
@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks
    The result is cached.
    """
    if dist.get_backend() == "nccl":
        # Use a dedicated gloo (CPU) group for object collectives instead of NCCL.
        return dist.new_group(backend="gloo")
    else:
        return dist.group.WORLD
136
+
137
+
138
def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: list of data gathered from each rank
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()  # use CPU group by default, to reduce GPU RAM usage.
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return [data]

    # all_gather_object pickles `data` internally; no tensor conversion needed.
    output = [None for _ in range(world_size)]
    dist.all_gather_object(output, data, group=group)
    return output
161
+
162
+
163
def gather(data, dst=0, group=None):
    """
    Run gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        dst (int): destination rank
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: on dst, a list of data gathered from each rank. Otherwise,
            an empty list.
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return [data]
    rank = dist.get_rank(group=group)

    # Only the destination rank supplies an output buffer; other ranks pass None.
    if rank == dst:
        output = [None for _ in range(world_size)]
        dist.gather_object(data, output, dst=dst, group=group)
        return output
    else:
        dist.gather_object(data, None, dst=dst, group=group)
        return []
193
+
194
+
195
def shared_random_seed():
    """
    Returns:
        int: a random number that is the same across all workers.
        If workers need a shared RNG, they can use this shared seed to
        create one.

    All workers must call this function, otherwise it will deadlock.
    """
    # Each worker draws its own candidate; rank 0's draw wins for everyone.
    ints = np.random.randint(2**31)
    all_ints = all_gather(ints)
    return all_ints[0]
207
+
208
+
209
def reduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary from all processes so that process with rank
    0 has the reduced results.

    Args:
        input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
        average (bool): whether to do average or sum

    Returns:
        a dict with the same keys as input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        # stack into one tensor so a single reduce covers all entries
        values = torch.stack(values, dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
detectron2/utils/develop.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ """ Utilities for developers only.
3
+ These are not visible to users (not automatically imported). And should not
4
+ appeared in docs."""
5
+ # adapted from https://github.com/tensorpack/tensorpack/blob/master/tensorpack/utils/develop.py
6
+
7
+
8
def create_dummy_class(klass, dependency, message=""):
    """
    When a dependency of a class is not available, create a dummy class which throws ImportError
    when used.

    Args:
        klass (str): name of the class.
        dependency (str): name of the dependency.
        message: extra message to print
    Returns:
        class: a class object
    """
    parts = ["Cannot import '{}', therefore '{}' is not available.".format(dependency, klass)]
    if message:
        parts.append(message)
    err = " ".join(parts)

    class _DummyMetaClass(type):
        # raising in the metaclass makes even class-attribute access fail loudly
        def __getattr__(_, __):  # noqa: B902
            raise ImportError(err)

    class _Dummy(object, metaclass=_DummyMetaClass):
        # raising in __init__ makes instantiation fail loudly
        def __init__(self, *args, **kwargs):
            raise ImportError(err)

    return _Dummy
35
+
36
+
37
def create_dummy_func(func, dependency, message=""):
    """
    When a dependency of a function is not available, create a dummy function which throws
    ImportError when used.

    Args:
        func (str): name of the function.
        dependency (str or list[str]): name(s) of the dependency.
        message: extra message to print
    Returns:
        function: a function object
    """
    # BUGFIX: collapse a list of dependencies into a comma-separated name
    # BEFORE formatting the message. Previously the join ran after `err` was
    # built, so it had no effect and list inputs rendered as a Python repr.
    if isinstance(dependency, (list, tuple)):
        dependency = ",".join(dependency)

    err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, func)
    if message:
        err = err + " " + message

    def _dummy(*args, **kwargs):
        raise ImportError(err)

    return _dummy
detectron2/utils/env.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import importlib
3
+ import importlib.util
4
+ import logging
5
+ import numpy as np
6
+ import os
7
+ import random
8
+ import sys
9
+ from datetime import datetime
10
+ import torch
11
+
12
+ __all__ = ["seed_all_rng"]
13
+
14
+
15
+ TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])
16
+ """
17
+ PyTorch version as a tuple of 2 ints. Useful for comparison.
18
+ """
19
+
20
+
21
+ DOC_BUILDING = os.getenv("_DOC_BUILDING", False) # set in docs/conf.py
22
+ """
23
+ Whether we're building documentation.
24
+ """
25
+
26
+
27
def seed_all_rng(seed=None):
    """
    Set the random seed for the RNG in torch, numpy and python.

    Args:
        seed (int): if None, will use a strong random seed.
    """
    if seed is None:
        # Combine pid, wall-clock microseconds and OS entropy so that
        # concurrent workers started at the same moment get different seeds.
        seed = (
            os.getpid()
            + int(datetime.now().strftime("%S%f"))
            + int.from_bytes(os.urandom(2), "big")
        )
        logger = logging.getLogger(__name__)
        logger.info("Using a generated random seed {}".format(seed))
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)
    # Pass the integer seed directly; the previous str(seed) only worked
    # because torch coerces the argument back with int() internally.
    torch.cuda.manual_seed_all(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
47
+
48
+
49
+ # from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
50
def _import_file(module_name, file_path, make_importable=False):
    """Import a Python source file as a module under the given name.

    From https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path

    Args:
        module_name (str): the name to register the module under.
        file_path (str): path to a ``.py`` file.
        make_importable (bool): if True, also register the module in
            ``sys.modules`` so later ``import module_name`` statements find it.
    Returns:
        the imported module object.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    if make_importable:
        sys.modules[module_name] = mod
    return mod
57
+
58
+
59
def _configure_libraries():
    """
    Configurations for some libraries.

    Side effects only: may poison the `cv2` entry in sys.modules, sets an
    OpenCV environment variable, and asserts minimum versions of torch,
    fvcore and pyyaml.
    """
    # An environment option to disable `import cv2` globally,
    # in case it leads to negative performance impact
    disable_cv2 = int(os.environ.get("DETECTRON2_DISABLE_CV2", False))
    if disable_cv2:
        # Poisoning sys.modules makes any later `import cv2` fail fast.
        sys.modules["cv2"] = None
    else:
        # Disable opencl in opencv since its interaction with cuda often has negative effects
        # This envvar is supported after OpenCV 3.4.0
        os.environ["OPENCV_OPENCL_RUNTIME"] = "disabled"
        try:
            import cv2

            if int(cv2.__version__.split(".")[0]) >= 3:
                cv2.ocl.setUseOpenCL(False)
        except ModuleNotFoundError:
            # Other types of ImportError, if happened, should not be ignored.
            # Because a failed opencv import could mess up address space
            # https://github.com/skvark/opencv-python/issues/381
            pass

    def get_version(module, digit=2):
        # Parse the first `digit` components of module.__version__ as ints.
        return tuple(map(int, module.__version__.split(".")[:digit]))

    # fmt: off
    assert get_version(torch) >= (1, 4), "Requires torch>=1.4"
    import fvcore
    assert get_version(fvcore, 3) >= (0, 1, 2), "Requires fvcore>=0.1.2"
    import yaml
    assert get_version(yaml) >= (5, 1), "Requires pyyaml>=5.1"
    # fmt: on
93
+
94
+
95
+ _ENV_SETUP_DONE = False
96
+
97
+
98
def setup_environment():
    """Perform environment setup work. The default setup is a no-op, but this
    function allows the user to specify a Python source file or a module in
    the $DETECTRON2_ENV_MODULE environment variable, that performs
    custom setup work that may be necessary to their computing environment.
    """
    global _ENV_SETUP_DONE
    if _ENV_SETUP_DONE:
        # Idempotent: subsequent calls are no-ops.
        return
    _ENV_SETUP_DONE = True

    _configure_libraries()

    # Optional user hook: a .py file path or a module name.
    custom_module_path = os.environ.get("DETECTRON2_ENV_MODULE")
    if custom_module_path:
        setup_custom_environment(custom_module_path)
118
+
119
+
120
def setup_custom_environment(custom_module):
    """
    Load custom environment setup by importing a Python source file or a
    module, and run its `setup_environment` function.
    """
    # A path ending in ".py" is loaded from disk; anything else is treated
    # as an importable module name.
    if custom_module.endswith(".py"):
        module = _import_file("detectron2.utils.env.custom_module", custom_module)
    else:
        module = importlib.import_module(custom_module)
    setup_fn = getattr(module, "setup_environment", None)
    assert callable(setup_fn), (
        "Custom environment module defined in {} does not have the "
        "required callable attribute 'setup_environment'."
    ).format(custom_module)
    setup_fn()
134
+
135
+
136
def fixup_module_metadata(module_name, namespace, keys=None):
    """
    Fix the __qualname__ of module members to be their exported api name, so
    when they are referenced in docs, sphinx can find them. Reference:
    https://github.com/python-trio/trio/blob/6754c74eacfad9cc5c92d5c24727a2f3b620624e/trio/_util.py#L216-L241
    """
    if not DOC_BUILDING:
        # Rewriting metadata is only useful (and only safe) during doc builds.
        return
    seen_ids = set()

    def fix_one(qualname, name, obj):
        # avoid infinite recursion (relevant when using
        # typing.Generic, for example)
        if id(obj) in seen_ids:
            return
        seen_ids.add(id(obj))

        mod = getattr(obj, "__module__", None)
        if mod is not None and (mod.startswith(module_name) or mod.startswith("fvcore.")):
            obj.__module__ = module_name
            # Modules, unlike everything else in Python, put fully-qualified
            # names into their __name__ attribute. We check for "." to avoid
            # rewriting these.
            if hasattr(obj, "__name__") and "." not in obj.__name__:
                obj.__name__ = name
                obj.__qualname__ = qualname
            if isinstance(obj, type):
                for attr_name, attr_value in obj.__dict__.items():
                    # BUGFIX: recurse with the member's own qualified name.
                    # The previous code used `objname` (the loop variable of
                    # the enclosing function, resolved at call time), so nested
                    # members were given the top-level name instead of
                    # "Outer.member".
                    fix_one(qualname + "." + attr_name, attr_name, attr_value)

    if keys is None:
        keys = namespace.keys()
    for objname in keys:
        if not objname.startswith("_"):
            obj = namespace[objname]
            fix_one(objname, objname, obj)
detectron2/utils/events.py ADDED
@@ -0,0 +1,557 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import datetime
3
+ import json
4
+ import logging
5
+ import os
6
+ import time
7
+ from collections import defaultdict
8
+ from contextlib import contextmanager
9
+ from functools import cached_property
10
+ from typing import Optional
11
+ import torch
12
+ from fvcore.common.history_buffer import HistoryBuffer
13
+
14
+ from detectron2.utils.file_io import PathManager
15
+
16
+ __all__ = [
17
+ "get_event_storage",
18
+ "has_event_storage",
19
+ "JSONWriter",
20
+ "TensorboardXWriter",
21
+ "CommonMetricPrinter",
22
+ "EventStorage",
23
+ ]
24
+
25
+ _CURRENT_STORAGE_STACK = []
26
+
27
+
28
def get_event_storage():
    """
    Returns:
        The :class:`EventStorage` object that's currently being used.
        Raises an AssertionError if no :class:`EventStorage` is currently enabled.
    """
    # The innermost (most recently entered) storage context wins.
    assert (
        _CURRENT_STORAGE_STACK
    ), "get_event_storage() has to be called inside a 'with EventStorage(...)' context!"
    return _CURRENT_STORAGE_STACK[-1]
38
+
39
+
40
def has_event_storage():
    """
    Returns:
        bool: whether an :class:`EventStorage` context is currently active.
    """
    return bool(_CURRENT_STORAGE_STACK)
46
+
47
+
48
class EventWriter:
    """
    Base class for writers that obtain events from :class:`EventStorage` and process them.
    """

    def write(self):
        """Emit pending events from the current storage. Must be overridden."""
        raise NotImplementedError

    def close(self):
        """Release any resources held by the writer. No-op by default."""
        pass
58
+
59
+
60
class JSONWriter(EventWriter):
    """
    Write scalars from the current :class:`EventStorage` to a json file.

    Each output line is one json object holding every scalar logged at a
    single iteration (plus an "iteration" key), which makes the file easy to
    parse incrementally, e.g.:

        $ cat metrics.json | jq '.loss_mask'
    """

    def __init__(self, json_file, window_size=20):
        """
        Args:
            json_file (str): path to the json file. New data will be appended if the file exists.
            window_size (int): the window size of median smoothing for the scalars whose
                `smoothing_hint` are True.
        """
        self._file_handle = PathManager.open(json_file, "a")
        self._window_size = window_size
        self._last_write = -1  # largest iteration number already persisted

    def write(self):
        storage = get_event_storage()
        # Group the not-yet-written scalars by the iteration they belong to.
        pending = defaultdict(dict)
        for key, (value, itr) in storage.latest_with_smoothing_hint(self._window_size).items():
            if itr <= self._last_write:
                continue  # already persisted
            pending[itr][key] = value
        if pending:
            self._last_write = max(pending.keys())

        for itr, scalars in pending.items():
            scalars["iteration"] = itr
            self._file_handle.write(json.dumps(scalars, sort_keys=True) + "\n")
        self._file_handle.flush()
        try:
            os.fsync(self._file_handle.fileno())
        except AttributeError:
            # Some PathManager file objects do not expose fileno().
            pass

    def close(self):
        self._file_handle.close()
139
+
140
+
141
class TensorboardXWriter(EventWriter):
    """
    Write all scalars to a tensorboard file.
    """

    def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
        """
        Args:
            log_dir (str): the directory to save the output events
            window_size (int): the scalars will be median-smoothed by this window size

            kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)`
        """
        self._window_size = window_size
        self._writer_args = dict(log_dir=log_dir, **kwargs)
        self._last_write = -1

    @cached_property
    def _writer(self):
        # Created lazily so tensorboard is only imported when actually writing.
        from torch.utils.tensorboard import SummaryWriter

        return SummaryWriter(**self._writer_args)

    def write(self):
        storage = get_event_storage()
        newest = self._last_write
        for key, (value, itr) in storage.latest_with_smoothing_hint(self._window_size).items():
            if itr > self._last_write:
                self._writer.add_scalar(key, value, itr)
                newest = max(newest, itr)
        self._last_write = newest

        # storage.put_{image,histogram} is only meant to be used by the
        # tensorboard writer, so its internal fields are accessed directly here.
        if storage._vis_data:
            for img_name, img, step_num in storage._vis_data:
                self._writer.add_image(img_name, img, step_num)
            # The storage keeps all image data and relies on this writer to
            # clear it, which assumes only one writer consumes image data.
            # An alternative design is to let storage store limited recent
            # data (e.g. only the most recent image) that all writers can access;
            # in that case a writer may not see all image data if its period is long.
            storage.clear_images()

        if storage._histograms:
            for params in storage._histograms:
                self._writer.add_histogram_raw(**params)
            storage.clear_histograms()

    def close(self):
        if "_writer" in self.__dict__:  # close only if the lazy writer was created
            self._writer.close()
193
+
194
+
195
class CommonMetricPrinter(EventWriter):
    """
    Print **common** metrics to the terminal, including
    iteration time, ETA, memory, all losses, and the learning rate.
    It also applies smoothing using a window of 20 elements.

    It's meant to print common metrics in common ways.
    To print something in more customized ways, please implement a similar printer by yourself.
    """

    def __init__(self, max_iter: Optional[int] = None, window_size: int = 20):
        """
        Args:
            max_iter: the maximum number of iterations to train.
                Used to compute ETA. If not given, ETA will not be printed.
            window_size (int): the losses will be median-smoothed by this window size
        """
        self.logger = logging.getLogger("detectron2.utils.events")
        self._max_iter = max_iter
        self._window_size = window_size
        self._last_write = None  # (step, time) of last call to write(). Used to compute ETA

    def _get_eta(self, storage) -> Optional[str]:
        """Return an "H:MM:SS" ETA string, "" if max_iter is unset, or None if
        it cannot be estimated yet."""
        if self._max_iter is None:
            return ""
        iteration = storage.iter
        try:
            # Preferred path: a long median of the recorded per-iter time.
            eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration - 1)
            storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
            return str(datetime.timedelta(seconds=int(eta_seconds)))
        except KeyError:
            # estimate eta on our own - more noisy
            eta_string = None
            if self._last_write is not None:
                # Wall-clock time per iteration since the previous write() call.
                estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
                    iteration - self._last_write[0]
                )
                eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
            self._last_write = (iteration, time.perf_counter())
            return eta_string

    def write(self):
        """Log one summary line of losses/metrics/timings for the current iteration."""
        storage = get_event_storage()
        iteration = storage.iter
        if iteration == self._max_iter:
            # This hook only reports training progress (loss, ETA, etc) but not other data,
            # therefore do not write anything after training succeeds, even if this method
            # is called.
            return

        try:
            avg_data_time = storage.history("data_time").avg(
                storage.count_samples("data_time", self._window_size)
            )
            last_data_time = storage.history("data_time").latest()
        except KeyError:
            # they may not exist in the first few iterations (due to warmup)
            # or when SimpleTrainer is not used
            avg_data_time = None
            last_data_time = None
        try:
            avg_iter_time = storage.history("time").global_avg()
            last_iter_time = storage.history("time").latest()
        except KeyError:
            avg_iter_time = None
            last_iter_time = None
        try:
            lr = "{:.5g}".format(storage.history("lr").latest())
        except KeyError:
            lr = "N/A"

        eta_string = self._get_eta(storage)

        if torch.cuda.is_available():
            max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
        else:
            max_mem_mb = None

        # NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
        self.logger.info(
            str.format(
                " {eta}iter: {iter} {losses} {non_losses} {avg_time}{last_time}"
                + "{avg_data_time}{last_data_time} lr: {lr} {memory}",
                eta=f"eta: {eta_string} " if eta_string else "",
                iter=iteration,
                # Median-smoothed values for every scalar whose name contains "loss".
                losses=" ".join(
                    [
                        "{}: {:.4g}".format(
                            k, v.median(storage.count_samples(k, self._window_size))
                        )
                        for k, v in storage.histories().items()
                        if "loss" in k
                    ]
                ),
                # Scalars explicitly tagged "[metric]" are printed the same way.
                non_losses=" ".join(
                    [
                        "{}: {:.4g}".format(
                            k, v.median(storage.count_samples(k, self._window_size))
                        )
                        for k, v in storage.histories().items()
                        if "[metric]" in k
                    ]
                ),
                avg_time=(
                    "time: {:.4f} ".format(avg_iter_time) if avg_iter_time is not None else ""
                ),
                last_time=(
                    "last_time: {:.4f} ".format(last_iter_time)
                    if last_iter_time is not None
                    else ""
                ),
                avg_data_time=(
                    "data_time: {:.4f} ".format(avg_data_time) if avg_data_time is not None else ""
                ),
                last_data_time=(
                    "last_data_time: {:.4f} ".format(last_data_time)
                    if last_data_time is not None
                    else ""
                ),
                lr=lr,
                memory="max_mem: {:.0f}M".format(max_mem_mb) if max_mem_mb is not None else "",
            )
        )
319
+
320
+
321
class EventStorage:
    """
    The user-facing class that provides metric storage functionalities.

    In the future we may add support for storing / logging other types of data if needed.
    """

    def __init__(self, start_iter=0):
        """
        Args:
            start_iter (int): the iteration number to start with
        """
        self._history = defaultdict(HistoryBuffer)  # name -> full scalar history
        self._smoothing_hints = {}  # name -> bool, as given to put_scalar
        self._latest_scalars = {}  # name -> (value, iter) of the most recent put
        self._iter = start_iter
        self._current_prefix = ""  # prefix applied to names, set by name_scope()
        self._vis_data = []  # (img_name, img_tensor, iter) tuples for tensorboard
        self._histograms = []  # kwargs dicts for SummaryWriter.add_histogram_raw

    def put_image(self, img_name, img_tensor):
        """
        Add an `img_tensor` associated with `img_name`, to be shown on
        tensorboard.

        Args:
            img_name (str): The name of the image to put into tensorboard.
            img_tensor (torch.Tensor or numpy.array): An `uint8` or `float`
                Tensor of shape `[channel, height, width]` where `channel` is
                3. The image format should be RGB. The elements in img_tensor
                can either have values in [0, 1] (float32) or [0, 255] (uint8).
                The `img_tensor` will be visualized in tensorboard.
        """
        self._vis_data.append((img_name, img_tensor, self._iter))

    def put_scalar(self, name, value, smoothing_hint=True, cur_iter=None):
        """
        Add a scalar `value` to the `HistoryBuffer` associated with `name`.

        Args:
            smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be
                smoothed when logged. The hint will be accessible through
                :meth:`EventStorage.smoothing_hints`.  A writer may ignore the hint
                and apply custom smoothing rule.

                It defaults to True because most scalars we save need to be smoothed to
                provide any useful signal.
            cur_iter (int): an iteration number to set explicitly instead of current iteration
        """
        name = self._current_prefix + name
        cur_iter = self._iter if cur_iter is None else cur_iter
        history = self._history[name]
        value = float(value)
        history.update(value, cur_iter)
        self._latest_scalars[name] = (value, cur_iter)

        existing_hint = self._smoothing_hints.get(name)

        # A scalar's smoothing hint must stay consistent across puts.
        if existing_hint is not None:
            assert (
                existing_hint == smoothing_hint
            ), "Scalar {} was put with a different smoothing_hint!".format(name)
        else:
            self._smoothing_hints[name] = smoothing_hint

    def put_scalars(self, *, smoothing_hint=True, cur_iter=None, **kwargs):
        """
        Put multiple scalars from keyword arguments.

        Examples:

            storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
        """
        for k, v in kwargs.items():
            self.put_scalar(k, v, smoothing_hint=smoothing_hint, cur_iter=cur_iter)

    def put_histogram(self, hist_name, hist_tensor, bins=1000):
        """
        Create a histogram from a tensor.

        Args:
            hist_name (str): The name of the histogram to put into tensorboard.
            hist_tensor (torch.Tensor): A Tensor of arbitrary shape to be converted
                into a histogram.
            bins (int): Number of histogram bins.
        """
        ht_min, ht_max = hist_tensor.min().item(), hist_tensor.max().item()

        # Create a histogram with PyTorch
        hist_counts = torch.histc(hist_tensor, bins=bins)
        hist_edges = torch.linspace(start=ht_min, end=ht_max, steps=bins + 1, dtype=torch.float32)

        # Parameter for the add_histogram_raw function of SummaryWriter
        hist_params = dict(
            tag=hist_name,
            min=ht_min,
            max=ht_max,
            num=len(hist_tensor),
            sum=float(hist_tensor.sum()),
            sum_squares=float(torch.sum(hist_tensor**2)),
            bucket_limits=hist_edges[1:].tolist(),
            bucket_counts=hist_counts.tolist(),
            global_step=self._iter,
        )
        self._histograms.append(hist_params)

    def history(self, name):
        """
        Returns:
            HistoryBuffer: the scalar history for name

        Raises:
            KeyError: if no scalar was ever logged under `name`.
        """
        ret = self._history.get(name, None)
        if ret is None:
            raise KeyError("No history metric available for {}!".format(name))
        return ret

    def histories(self):
        """
        Returns:
            dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars
        """
        return self._history

    def latest(self):
        """
        Returns:
            dict[str -> (float, int)]: mapping from the name of each scalar to the most
                recent value and the iteration number its added.
        """
        return self._latest_scalars

    def latest_with_smoothing_hint(self, window_size=20):
        """
        Similar to :meth:`latest`, but the returned values
        are either the un-smoothed original latest value,
        or a median of the given window_size,
        depend on whether the smoothing_hint is True.

        This provides a default behavior that other writers can use.

        Note: All scalars saved in the past `window_size` iterations are used for smoothing.
        This is different from the `window_size` definition in HistoryBuffer.
        Use :meth:`get_history_window_size` to get the `window_size` used in HistoryBuffer.
        """
        result = {}
        for k, (v, itr) in self._latest_scalars.items():
            result[k] = (
                (
                    # Median over however many samples fell in the window,
                    # not over a fixed count.
                    self._history[k].median(self.count_samples(k, window_size))
                    if self._smoothing_hints[k]
                    else v
                ),
                itr,
            )
        return result

    def count_samples(self, name, window_size=20):
        """
        Return the number of samples logged in the past `window_size` iterations.
        """
        samples = 0
        data = self._history[name].values()
        # Walk backwards from the newest sample; stop at the first one older
        # than (latest_iter - window_size).
        for _, iter_ in reversed(data):
            if iter_ > data[-1][1] - window_size:
                samples += 1
            else:
                break
        return samples

    def smoothing_hints(self):
        """
        Returns:
            dict[name -> bool]: the user-provided hint on whether the scalar
                is noisy and needs smoothing.
        """
        return self._smoothing_hints

    def step(self):
        """
        User should either: (1) Call this function to increment storage.iter when needed. Or
        (2) Set `storage.iter` to the correct iteration number before each iteration.

        The storage will then be able to associate the new data with an iteration number.
        """
        self._iter += 1

    @property
    def iter(self):
        """
        Returns:
            int: The current iteration number. When used together with a trainer,
                this is ensured to be the same as trainer.iter.
        """
        return self._iter

    @iter.setter
    def iter(self, val):
        self._iter = int(val)

    @property
    def iteration(self):
        # for backward compatibility
        return self._iter

    def __enter__(self):
        # Entering the context makes this the storage returned by get_event_storage().
        _CURRENT_STORAGE_STACK.append(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        assert _CURRENT_STORAGE_STACK[-1] == self
        _CURRENT_STORAGE_STACK.pop()

    @contextmanager
    def name_scope(self, name):
        """
        Yields:
            A context within which all the events added to this storage
            will be prefixed by the name scope.
        """
        old_prefix = self._current_prefix
        self._current_prefix = name.rstrip("/") + "/"
        yield
        self._current_prefix = old_prefix

    def clear_images(self):
        """
        Delete all the stored images for visualization. This should be called
        after images are written to tensorboard.
        """
        self._vis_data = []

    def clear_histograms(self):
        """
        Delete all the stored histograms for visualization.
        This should be called after histograms are written to tensorboard.
        """
        self._histograms = []
detectron2/utils/file_io.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from iopath.common.file_io import HTTPURLHandler, OneDrivePathHandler, PathHandler
3
+ from iopath.common.file_io import PathManager as PathManagerBase
4
+
5
+ __all__ = ["PathManager", "PathHandler"]
6
+
7
+
8
+ PathManager = PathManagerBase()
9
+ """
10
+ This is a detectron2 project-specific PathManager.
11
+ We try to stay away from global PathManager in fvcore as it
12
+ introduces potential conflicts among other libraries.
13
+ """
14
+
15
+
16
class Detectron2Handler(PathHandler):
    """
    Resolve anything that's hosted under detectron2's namespace.
    """

    PREFIX = "detectron2://"
    S3_DETECTRON2_PREFIX = "https://dl.fbaipublicfiles.com/detectron2/"

    def _get_supported_prefixes(self):
        # This handler claims every path starting with "detectron2://".
        return [self.PREFIX]

    def _redirect(self, path):
        # Map "detectron2://X" onto the public fbaipublicfiles URL for X.
        return self.S3_DETECTRON2_PREFIX + path[len(self.PREFIX) :]

    def _get_local_path(self, path, **kwargs):
        return PathManager.get_local_path(self._redirect(path), **kwargs)

    def _open(self, path, mode="r", **kwargs):
        return PathManager.open(self._redirect(path), mode, **kwargs)
35
+
36
+
37
+ PathManager.register_handler(HTTPURLHandler())
38
+ PathManager.register_handler(OneDrivePathHandler())
39
+ PathManager.register_handler(Detectron2Handler())
detectron2/utils/logger.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import atexit
3
+ import functools
4
+ import logging
5
+ import os
6
+ import sys
7
+ import time
8
+ from collections import Counter
9
+ import torch
10
+ from tabulate import tabulate
11
+ from termcolor import colored
12
+
13
+ from detectron2.utils.file_io import PathManager
14
+
15
+ __all__ = ["setup_logger", "log_first_n", "log_every_n", "log_every_n_seconds"]
16
+
17
+ D2_LOG_BUFFER_SIZE_KEY: str = "D2_LOG_BUFFER_SIZE"
18
+
19
+ DEFAULT_LOG_BUFFER_SIZE: int = 1024 * 1024 # 1MB
20
+
21
+
22
class _ColorfulFormatter(logging.Formatter):
    """Formatter that abbreviates the root logger name in each record and
    prefixes WARNING/ERROR/CRITICAL messages with a colored tag."""

    def __init__(self, *args, **kwargs):
        self._root_name = kwargs.pop("root_name") + "."
        self._abbrev_name = kwargs.pop("abbrev_name", "")
        if self._abbrev_name:
            self._abbrev_name += "."
        super(_ColorfulFormatter, self).__init__(*args, **kwargs)

    def formatMessage(self, record):
        # Shorten e.g. "detectron2.utils.logger" -> "d2.utils.logger".
        record.name = record.name.replace(self._root_name, self._abbrev_name)
        formatted = super(_ColorfulFormatter, self).formatMessage(record)
        if record.levelno == logging.WARNING:
            tag = colored("WARNING", "red", attrs=["blink"])
        elif record.levelno in (logging.ERROR, logging.CRITICAL):
            tag = colored("ERROR", "red", attrs=["blink", "underline"])
        else:
            return formatted
        return tag + " " + formatted
40
+
41
+
42
@functools.lru_cache()  # so that calling setup_logger multiple times won't add many handlers
def setup_logger(
    output=None,
    distributed_rank=0,
    *,
    color=True,
    name="detectron2",
    abbrev_name=None,
    enable_propagation: bool = False,
    configure_stdout: bool = True
):
    """
    Initialize the detectron2 logger and set its verbosity level to "DEBUG".

    Args:
        output (str): a file name or a directory to save log. If None, will not save log file.
            If ends with ".txt" or ".log", assumed to be a file name.
            Otherwise, logs will be saved to `output/log.txt`.
        distributed_rank (int): rank of the current process; only rank 0 logs to stdout,
            and other ranks get a ".rank{N}" suffix on their log file.
        color (bool): whether to colorize stdout output with :class:`_ColorfulFormatter`.
        name (str): the root module name of this logger
        abbrev_name (str): an abbreviation of the module, to avoid long names in logs.
            Set to "" to not log the root module in logs.
            By default, will abbreviate "detectron2" to "d2" and leave other
            modules unchanged.
        enable_propagation (bool): whether to propagate logs to the parent logger.
        configure_stdout (bool): whether to configure logging to stdout.


    Returns:
        logging.Logger: a logger
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = enable_propagation

    if abbrev_name is None:
        abbrev_name = "d2" if name == "detectron2" else name

    plain_formatter = logging.Formatter(
        "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
    )
    # stdout logging: master only
    if configure_stdout and distributed_rank == 0:
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        if color:
            formatter = _ColorfulFormatter(
                colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
                datefmt="%m/%d %H:%M:%S",
                root_name=name,
                abbrev_name=str(abbrev_name),
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    # file logging: all workers
    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "log.txt")
        if distributed_rank > 0:
            # give each non-master rank its own file to avoid interleaved writes
            filename = filename + ".rank{}".format(distributed_rank)
        PathManager.mkdirs(os.path.dirname(filename))

        # _cached_log_stream shares one stream per filename across calls
        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)

    return logger
114
+
115
+
116
+ # cache the opened file object, so that different calls to `setup_logger`
117
+ # with the same file name can safely write to the same file.
118
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    """Open `filename` for appending, one shared stream per file name.

    The lru_cache makes repeated setup_logger calls with the same file reuse
    a single handle; the stream is closed automatically at interpreter exit.
    Remote (URL-like) files get a large write buffer, local files the default.
    """
    stream = PathManager.open(filename, "a", buffering=_get_log_stream_buffer_size(filename))
    atexit.register(stream.close)
    return stream
124
+
125
+
126
def _get_log_stream_buffer_size(filename: str) -> int:
    """Choose the io buffer size for the log stream backing `filename`."""
    if "://" not in filename:
        # Local file: the default buffering is sufficient.
        return -1
    # Remote file requires a larger buffer to avoid many small writes,
    # overridable via the D2_LOG_BUFFER_SIZE environment variable.
    return int(os.environ.get(D2_LOG_BUFFER_SIZE_KEY, DEFAULT_LOG_BUFFER_SIZE))
134
+
135
+
136
+ """
137
+ Below are some other convenient logging methods.
138
+ They are mainly adopted from
139
+ https://github.com/abseil/abseil-py/blob/master/absl/logging/__init__.py
140
+ """
141
+
142
+
143
+ def _find_caller():
144
+ """
145
+ Returns:
146
+ str: module name of the caller
147
+ tuple: a hashable key to be used to identify different callers
148
+ """
149
+ frame = sys._getframe(2)
150
+ while frame:
151
+ code = frame.f_code
152
+ if os.path.join("utils", "logger.") not in code.co_filename:
153
+ mod_name = frame.f_globals["__name__"]
154
+ if mod_name == "__main__":
155
+ mod_name = "detectron2"
156
+ return mod_name, (code.co_filename, frame.f_lineno, code.co_name)
157
+ frame = frame.f_back
158
+
159
+
160
# Shared bookkeeping for the log_first_n / log_every_n / log_every_n_seconds
# helpers: per-key invocation counts and per-key last-logged timestamps.
_LOG_COUNTER = Counter()
_LOG_TIMER = {}
162
+
163
+
164
def log_first_n(lvl, msg, n=1, *, name=None, key="caller"):
    """
    Log only for the first n times.

    Args:
        lvl (int): the logging level
        msg (str):
        n (int):
        name (str): name of the logger to use. Will use the caller's module by default.
        key (str or tuple[str]): the string(s) can be one of "caller" or
            "message", which defines how to identify duplicated logs.
            For example, if called with `n=1, key="caller"`, this function
            will only log the first call from the same caller, regardless of
            the message content.
            If called with `n=1, key="message"`, this function will log the
            same content only once, even if they are called from different places.
            If called with `n=1, key=("caller", "message")`, this function
            will not log only if the same caller has logged the same message before.
    """
    keys = (key,) if isinstance(key, str) else key
    assert len(keys) > 0

    caller_module, caller_key = _find_caller()
    # Build the dedup key from the requested identity components.
    dedup_key = ()
    if "caller" in keys:
        dedup_key += caller_key
    if "message" in keys:
        dedup_key += (msg,)

    _LOG_COUNTER[dedup_key] += 1
    if _LOG_COUNTER[dedup_key] <= n:
        logging.getLogger(name or caller_module).log(lvl, msg)
197
+
198
+
199
def log_every_n(lvl, msg, n=1, *, name=None):
    """
    Log once per n times.

    Args:
        lvl (int): the logging level
        msg (str):
        n (int):
        name (str): name of the logger to use. Will use the caller's module by default.
    """
    caller_module, caller_key = _find_caller()
    _LOG_COUNTER[caller_key] += 1
    # Log on the 1st, (n+1)th, (2n+1)th, ... call from this call site.
    should_log = n == 1 or _LOG_COUNTER[caller_key] % n == 1
    if should_log:
        logging.getLogger(name or caller_module).log(lvl, msg)
213
+
214
+
215
def log_every_n_seconds(lvl, msg, n=1, *, name=None):
    """
    Log no more than once per n seconds.

    Args:
        lvl (int): the logging level
        msg (str):
        n (int):
        name (str): name of the logger to use. Will use the caller's module by default.
    """
    caller_module, caller_key = _find_caller()
    now = time.time()
    previous = _LOG_TIMER.get(caller_key)
    # Log when this call site has never logged, or the rate window has elapsed.
    if previous is None or now - previous >= n:
        logging.getLogger(name or caller_module).log(lvl, msg)
        _LOG_TIMER[caller_key] = now
231
+
232
+
233
def create_small_table(small_dict):
    """
    Create a small table using the keys of small_dict as headers. This is only
    suitable for small dictionaries.

    Args:
        small_dict (dict): a result dictionary of only a few items.

    Returns:
        str: the table as a string (empty string for an empty dict).
    """
    # Robustness fix: zip(*{}.items()) raises ValueError on an empty dict.
    if not small_dict:
        return ""
    keys, values = tuple(zip(*small_dict.items()))
    table = tabulate(
        [values],
        headers=keys,
        tablefmt="pipe",
        floatfmt=".3f",
        stralign="center",
        numalign="center",
    )
    return table
254
+
255
+
256
def _log_api_usage(identifier: str):
    """
    Internal function used to log the usage of different detectron2 components
    inside facebook's infra.
    """
    # Emits a one-time usage event via torch's private API-usage hook;
    # presumably a no-op outside fb infra — TODO confirm.
    torch._C._log_api_usage_once("detectron2." + identifier)
detectron2/utils/memory.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ import logging
4
+ from contextlib import contextmanager
5
+ from functools import wraps
6
+ import torch
7
+
8
+ __all__ = ["retry_if_cuda_oom"]
9
+
10
+
11
+ @contextmanager
12
+ def _ignore_torch_cuda_oom():
13
+ """
14
+ A context which ignores CUDA OOM exception from pytorch.
15
+ """
16
+ try:
17
+ yield
18
+ except RuntimeError as e:
19
+ # NOTE: the string may change?
20
+ if "CUDA out of memory. " in str(e):
21
+ pass
22
+ else:
23
+ raise
24
+
25
+
26
def retry_if_cuda_oom(func):
    """
    Makes a function retry itself after encountering
    pytorch's CUDA OOM error.
    It will first retry after calling `torch.cuda.empty_cache()`.

    If that still fails, it will then retry by trying to convert inputs to CPUs.
    In this case, it expects the function to dispatch to CPU implementation.
    The return values may become CPU tensors as well and it's user's
    responsibility to convert it back to CUDA tensor if needed.

    Args:
        func: a stateless callable that takes tensor-like objects as arguments

    Returns:
        a callable which retries `func` if OOM is encountered.

    Examples:
    ::
        output = retry_if_cuda_oom(some_torch_function)(input1, input2)
        # output may be on CPU even if inputs are on GPU

    Note:
        1. When converting inputs to CPU, it will only look at each argument and check
           if it has `.device` and `.to` for conversion. Nested structures of tensors
           are not supported.

        2. Since the function might be called more than once, it has to be
           stateless.
    """

    def _to_cpu_if_cuda(value):
        # Duck-typed check: anything with a cuda `.device` and a `.to` method
        # is treated as a movable tensor; everything else passes through.
        try:
            is_cuda_tensor_like = value.device.type == "cuda" and hasattr(value, "to")
        except AttributeError:
            is_cuda_tensor_like = False
        return value.to(device="cpu") if is_cuda_tensor_like else value

    @wraps(func)
    def wrapped(*args, **kwargs):
        # Attempt 1: run as-is, swallowing only CUDA OOM errors.
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Attempt 2: free cached GPU memory, then run again.
        torch.cuda.empty_cache()
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Attempt 3: move tensor-like inputs to CPU. This slows down the code
        # significantly, therefore print a notice.
        logger = logging.getLogger(__name__)
        logger.info("Attempting to copy inputs of {} to CPU due to CUDA OOM".format(str(func)))
        cpu_args = (_to_cpu_if_cuda(a) for a in args)
        cpu_kwargs = {k: _to_cpu_if_cuda(v) for k, v in kwargs.items()}
        return func(*cpu_args, **cpu_kwargs)

    return wrapped
detectron2/utils/registry.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ from typing import Any
4
+ import pydoc
5
+ from fvcore.common.registry import Registry # for backward compatibility.
6
+
7
+ """
8
+ ``Registry`` and `locate` provide ways to map a string (typically found
9
+ in config files) to callable objects.
10
+ """
11
+
12
+ __all__ = ["Registry", "locate"]
13
+
14
+
15
def _convert_target_to_string(t: Any) -> str:
    """
    Inverse of ``locate()``: produce a dotted string that resolves back to `t`.

    Args:
        t: any object with ``__module__`` and ``__qualname__``
    """
    module, qualname = t.__module__, t.__qualname__

    # Prefer the shortest module prefix that still resolves to the same object:
    # ``module.submodule._impl.class`` may be importable as
    # ``module.submodule.class`` if re-exported, which is shorter and less
    # affected by moving the class implementation.
    parts = module.split(".")
    for cut in range(1, len(parts)):
        candidate = "{}.{}".format(".".join(parts[:cut]), qualname)
        try:
            if locate(candidate) is t:
                return candidate
        except ImportError:
            continue
    return f"{module}.{qualname}"
38
+
39
+
40
def locate(name: str) -> Any:
    """
    Locate and return an object ``x`` using an input string ``{x.__module__}.{x.__qualname__}``,
    such as "module.submodule.class_name".

    Raise Exception if it cannot be found.
    """
    obj = pydoc.locate(name)
    if obj is not None:
        return obj

    # Some cases (e.g. torch.optim.sgd.SGD) not handled correctly
    # by pydoc.locate. Try a private function from hydra.
    try:
        # from hydra.utils import get_method - will print many errors
        from hydra.utils import _locate
    except ImportError as e:
        raise ImportError(f"Cannot dynamically locate object {name}!") from e
    return _locate(name)  # it raises if fails
detectron2/utils/serialize.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import cloudpickle
3
+
4
+
5
class PicklableWrapper:
    """
    Wrap an object to make it more picklable, note that it uses
    heavy weight serialization libraries that are slower than pickle.
    It's best to use it only on closures (which are usually not picklable).

    This is a simplified version of
    https://github.com/joblib/joblib/blob/master/joblib/externals/loky/cloudpickle_wrapper.py
    """

    def __init__(self, obj):
        # Unwrap nested wrappers so that wrapping twice is a no-op.
        while isinstance(obj, PicklableWrapper):
            obj = obj._obj
        self._obj = obj

    def __reduce__(self):
        # Serialize the payload with cloudpickle (handles closures/lambdas),
        # and deserialize on unpickle via cloudpickle.loads.
        payload = cloudpickle.dumps(self._obj)
        return cloudpickle.loads, (payload,)

    def __call__(self, *args, **kwargs):
        return self._obj(*args, **kwargs)

    def __getattr__(self, attr):
        # Delegate attribute access so the wrapper is a drop-in replacement
        # for the wrapped object.
        if attr != "_obj":
            return getattr(self._obj, attr)
        return getattr(self, attr)
detectron2/utils/testing.py ADDED
@@ -0,0 +1,478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import io
3
+ import numpy as np
4
+ import os
5
+ import re
6
+ import tempfile
7
+ import unittest
8
+ from typing import Callable
9
+ import torch
10
+ import torch.onnx.symbolic_helper as sym_help
11
+ from packaging import version
12
+ from torch._C import ListType
13
+ from torch.onnx import register_custom_op_symbolic
14
+
15
+ from detectron2 import model_zoo
16
+ from detectron2.config import CfgNode, LazyConfig, instantiate
17
+ from detectron2.data import DatasetCatalog
18
+ from detectron2.data.detection_utils import read_image
19
+ from detectron2.modeling import build_model
20
+ from detectron2.structures import Boxes, Instances, ROIMasks
21
+ from detectron2.utils.file_io import PathManager
22
+
23
+
24
+ """
25
+ Internal utilities for tests. Don't use except for writing tests.
26
+ """
27
+
28
+
29
def get_model_no_weights(config_path):
    """
    Like model_zoo.get, but do not load any weights (even pretrained).
    """
    cfg = model_zoo.get_config(config_path)
    # LazyConfig objects are instantiated directly; yacs configs go through
    # build_model (with a CPU fallback when no GPU is available).
    if not isinstance(cfg, CfgNode):
        return instantiate(cfg.model)
    if not torch.cuda.is_available():
        cfg.MODEL.DEVICE = "cpu"
    return build_model(cfg)
40
+
41
+
42
def random_boxes(num_boxes, max_coord=100, device="cpu"):
    """
    Create a random Nx4 boxes tensor, with coordinates < max_coord.
    """
    corners = torch.rand(num_boxes, 4, device=device) * (max_coord * 0.5)
    # tiny boxes cause numerical instability in box regression
    corners.clamp_(min=1.0)
    # Note: the implementation of this function in torchvision is:
    # boxes[:, 2:] += torch.rand(N, 2) * 100
    # but it does not guarantee non-negative widths/heights constraints:
    # boxes[:, 2] >= boxes[:, 0] and boxes[:, 3] >= boxes[:, 1]:
    corners[:, 2:] += corners[:, :2]
    return corners
54
+
55
+
56
def get_sample_coco_image(tensor=True):
    """
    Args:
        tensor (bool): if True, returns 3xHxW tensor.
            else, returns a HxWx3 numpy array.

    Returns:
        an image, in BGR color.
    """
    try:
        file_name = DatasetCatalog.get("coco_2017_val_100")[0]["file_name"]
        if not PathManager.exists(file_name):
            raise FileNotFoundError()
    except IOError:
        # for public CI to run: fall back to downloading one COCO image.
        file_name = PathManager.get_local_path(
            "http://images.cocodataset.org/train2017/000000000009.jpg"
        )
    image = read_image(file_name, format="BGR")
    if not tensor:
        return image
    # HWC -> CHW; ascontiguousarray avoids negative-stride issues in from_numpy.
    return torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))
78
+
79
+
80
def convert_scripted_instances(instances):
    """
    Convert a scripted Instances object to a regular :class:`Instances` object
    """
    assert hasattr(
        instances, "image_size"
    ), f"Expect an Instances object, but got {type(instances)}!"
    result = Instances(instances.image_size)
    # Scripted Instances store fields under "_"-prefixed attributes.
    for field_name in instances._field_names:
        field_value = getattr(instances, "_" + field_name, None)
        if field_value is not None:
            result.set(field_name, field_value)
    return result
93
+
94
+
95
def assert_instances_allclose(input, other, *, rtol=1e-5, msg="", size_as_tensor=False):
    """
    Assert that two Instances objects hold (approximately) the same data.

    Args:
        input, other (Instances):
        rtol (float): relative tolerance for floating-point fields.
        msg (str): prefix for assertion-failure messages.
        size_as_tensor: compare image_size of the Instances as tensors (instead of tuples).
            Useful for comparing outputs of tracing.
    """
    # Accept scripted Instances by converting them to regular ones first.
    if not isinstance(input, Instances):
        input = convert_scripted_instances(input)
    if not isinstance(other, Instances):
        other = convert_scripted_instances(other)

    if not msg:
        msg = "Two Instances are different! "
    else:
        msg = msg.rstrip() + " "

    size_error_msg = msg + f"image_size is {input.image_size} vs. {other.image_size}!"
    if size_as_tensor:
        assert torch.equal(
            torch.tensor(input.image_size), torch.tensor(other.image_size)
        ), size_error_msg
    else:
        assert input.image_size == other.image_size, size_error_msg
    # Both objects must carry exactly the same set of fields.
    fields = sorted(input.get_fields().keys())
    fields_other = sorted(other.get_fields().keys())
    assert fields == fields_other, msg + f"Fields are {fields} vs {fields_other}!"

    for f in fields:
        val1, val2 = input.get(f), other.get(f)
        if isinstance(val1, (Boxes, ROIMasks)):
            # boxes in the range of O(100) and can have a larger tolerance
            assert torch.allclose(val1.tensor, val2.tensor, atol=100 * rtol), (
                msg + f"Field {f} differs too much!"
            )
        elif isinstance(val1, torch.Tensor):
            if val1.dtype.is_floating_point:
                # Scale the absolute tolerance by the field's magnitude.
                mag = torch.abs(val1).max().cpu().item()
                assert torch.allclose(val1, val2, atol=mag * rtol), (
                    msg + f"Field {f} differs too much!"
                )
            else:
                # Integer/bool tensors must match exactly.
                assert torch.equal(val1, val2), msg + f"Field {f} is different!"
        else:
            raise ValueError(f"Don't know how to compare type {type(val1)}")
140
+
141
+
142
+ def reload_script_model(module):
143
+ """
144
+ Save a jit module and load it back.
145
+ Similar to the `getExportImportCopy` function in torch/testing/
146
+ """
147
+ buffer = io.BytesIO()
148
+ torch.jit.save(module, buffer)
149
+ buffer.seek(0)
150
+ return torch.jit.load(buffer)
151
+
152
+
153
def reload_lazy_config(cfg):
    """
    Save an object by LazyConfig.save and load it back.
    This is used to test that a config still works the same after
    serialization/deserialization.
    """
    with tempfile.TemporaryDirectory(prefix="detectron2") as tmp_dir:
        cfg_file = os.path.join(tmp_dir, "d2_cfg_test.yaml")
        LazyConfig.save(cfg, cfg_file)
        return LazyConfig.load(cfg_file)
163
+
164
+
165
def min_torch_version(min_version: str) -> bool:
    """
    Returns True when torch's version is at least `min_version`.
    Returns False when torch is not importable at all.
    """
    try:
        import torch
    except ImportError:
        return False

    # Strip the local build tag (e.g. "+cu113") before comparing.
    current = version.parse(torch.__version__.split("+")[0])
    return current >= version.parse(min_version)
177
+
178
+
179
def has_dynamic_axes(onnx_model):
    """
    Return True when all ONNX input/output have only dynamic axes for all ranks
    (i.e. no dim_param is a numeric literal).
    """

    def _all_dims_dynamic(values):
        return all(
            not dim.dim_param.isnumeric()
            for value in values
            for dim in value.type.tensor_type.shape.dim
        )

    graph = onnx_model.graph
    return _all_dims_dynamic(graph.input) and _all_dims_dynamic(graph.output)
192
+
193
+
194
def register_custom_op_onnx_export(
    opname: str, symbolic_fn: Callable, opset_version: int, min_version: str
) -> None:
    """
    Register `symbolic_fn` as PyTorch's symbolic `opname`-`opset_version` for ONNX export.
    The registration is performed only when current PyTorch's version is < `min_version.`
    IMPORTANT: symbolic must be manually unregistered after the caller function returns
    """
    # Newer PyTorch already ships a correct symbolic; nothing to register.
    if min_torch_version(min_version):
        return
    register_custom_op_symbolic(opname, symbolic_fn, opset_version)
    print(f"_register_custom_op_onnx_export({opname}, {opset_version}) succeeded.")
206
+
207
+
208
def unregister_custom_op_onnx_export(opname: str, opset_version: int, min_version: str) -> None:
    """
    Unregister PyTorch's symbolic `opname`-`opset_version` for ONNX export.
    The un-registration is performed only when PyTorch's version is < `min_version`
    IMPORTANT: The symbolic must have been manually registered by the caller, otherwise
    the incorrect symbolic may be unregistered instead.
    """

    # TODO: _unregister_custom_op_symbolic is introduced PyTorch>=1.10
    # Remove after PyTorch 1.10+ is used by ALL detectron2's CI
    try:
        from torch.onnx import unregister_custom_op_symbolic as _unregister_custom_op_symbolic
    except ImportError:

        def _unregister_custom_op_symbolic(symbolic_name, opset_version):
            import torch.onnx.symbolic_registry as sym_registry
            from torch.onnx.symbolic_helper import _onnx_main_opset, _onnx_stable_opsets

            def _get_ns_op_name_from_custom_op(symbolic_name):
                try:
                    from torch.onnx.utils import get_ns_op_name_from_custom_op

                    ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)
                except ImportError as import_error:
                    if not bool(
                        re.match(r"^[a-zA-Z0-9-_]*::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name)
                    ):
                        raise ValueError(
                            f"Invalid symbolic name {symbolic_name}. Must be `domain::name`"
                        ) from import_error

                    ns, op_name = symbolic_name.split("::")
                    if ns == "onnx":
                        raise ValueError(f"{ns} domain cannot be modified.") from import_error

                    if ns == "aten":
                        ns = ""

                return ns, op_name

            def _unregister_op(op_name: str, ns: str, ver: int):
                # BUGFIX: the original body mixed the parameters with closure
                # variables from the enclosing scope (op_name/ns/ver vs.
                # opname/domain/version), making the registry-deletion fallback
                # operate on inconsistent keys. All lookups now use the
                # function's own parameters.
                try:
                    sym_registry.unregister_op(op_name, ns, ver)
                except AttributeError as attribute_error:
                    # Older torch has no unregister_op; fall back to editing
                    # the private registry dict directly.
                    if sym_registry.is_registered_op(op_name, ns, ver):
                        del sym_registry._registry[(ns, ver)][op_name]
                        if not sym_registry._registry[(ns, ver)]:
                            del sym_registry._registry[(ns, ver)]
                    else:
                        raise RuntimeError(
                            f"The opname {op_name} is not registered."
                        ) from attribute_error

            ns, op_name = _get_ns_op_name_from_custom_op(symbolic_name)
            # Registration spans all opsets >= opset_version; mirror that here.
            for ver in _onnx_stable_opsets + [_onnx_main_opset]:
                if ver >= opset_version:
                    _unregister_op(op_name, ns, ver)

    if min_torch_version(min_version):
        return
    _unregister_custom_op_symbolic(opname, opset_version)
    print(f"_unregister_custom_op_onnx_export({opname}, {opset_version}) succeeded.")
270
+
271
+
272
# Decorator that skips a test when running on CI without any CUDA device.
skipIfOnCPUCI = unittest.skipIf(
    os.environ.get("CI") and not torch.cuda.is_available(),
    "The test is too slow on CPUs and will be executed on CircleCI's GPU jobs.",
)
276
+
277
+
278
def skipIfUnsupportedMinOpsetVersion(min_opset_version, current_opset_version=None):
    """
    Skips tests for ONNX Opset versions older than min_opset_version.

    The decorated test's own ``self.opset_version`` takes precedence;
    `current_opset_version` is the fallback when that attribute is absent.
    """

    def skip_dec(func):
        def wrapper(self):
            # getattr-with-default is equivalent to the try/except AttributeError form.
            opset = getattr(self, "opset_version", current_opset_version)
            if opset < min_opset_version:
                raise unittest.SkipTest(
                    f"Unsupported opset_version {opset}"
                    f", required is {min_opset_version}"
                )
            return func(self)

        return wrapper

    return skip_dec
299
+
300
+
301
def skipIfUnsupportedMinTorchVersion(min_version):
    """
    Skips tests for PyTorch versions older than min_version.
    """
    should_skip = not min_torch_version(min_version)
    reason = f"module 'torch' has __version__ {torch.__version__}, required is: {min_version}"
    return unittest.skipIf(should_skip, reason)
307
+
308
+
309
# TODO: Remove after PyTorch 1.11.1+ is used by detectron2's CI
def _pytorch1111_symbolic_opset9_to(g, self, *args):
    """aten::to() symbolic that must be used for testing with PyTorch < 1.11.1."""

    def is_aten_to_device_only(args):
        # Detect the overloads of aten::to that only change device (no dtype),
        # which can be dropped entirely when exporting to ONNX.
        if len(args) == 4:
            # aten::to(Tensor, Device, bool, bool, memory_format)
            return (
                args[0].node().kind() == "prim::device"
                or args[0].type().isSubtypeOf(ListType.ofInts())
                or (
                    sym_help._is_value(args[0])
                    and args[0].node().kind() == "onnx::Constant"
                    and isinstance(args[0].node()["value"], str)
                )
            )
        elif len(args) == 5:
            # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format)
            # When dtype is None, this is a aten::to(device) call
            dtype = sym_help._get_const(args[1], "i", "dtype")
            return dtype is None
        elif len(args) in (6, 7):
            # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format)
            # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format)
            # When dtype is None, this is a aten::to(device) call
            dtype = sym_help._get_const(args[0], "i", "dtype")
            return dtype is None
        return False

    # ONNX doesn't have a concept of a device, so we ignore device-only casts
    if is_aten_to_device_only(args):
        return self

    if len(args) == 4:
        # TestONNXRuntime::test_ones_bool shows args[0] of aten::to can be onnx::Constant[Tensor]
        # In this case, the constant value is a tensor not int,
        # so sym_help._maybe_get_const(args[0], 'i') would not work.
        dtype = args[0]
        if sym_help._is_value(args[0]) and args[0].node().kind() == "onnx::Constant":
            tval = args[0].node()["value"]
            if isinstance(tval, torch.Tensor):
                # 0-d tensors carry a scalar dtype code; extract it.
                if len(tval.shape) == 0:
                    tval = tval.item()
                    dtype = int(tval)
                else:
                    dtype = tval

        if sym_help._is_value(dtype) or isinstance(dtype, torch.Tensor):
            # aten::to(Tensor, Tensor, bool, bool, memory_format)
            dtype = args[0].type().scalarType()
            return g.op("Cast", self, to_i=sym_help.cast_pytorch_to_onnx[dtype])
        else:
            # aten::to(Tensor, ScalarType, bool, bool, memory_format)
            # memory_format is ignored
            return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
    elif len(args) == 5:
        # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format)
        dtype = sym_help._get_const(args[1], "i", "dtype")
        # memory_format is ignored
        return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
    elif len(args) == 6:
        # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format)
        dtype = sym_help._get_const(args[0], "i", "dtype")
        # Layout, device and memory_format are ignored
        return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
    elif len(args) == 7:
        # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format)
        dtype = sym_help._get_const(args[0], "i", "dtype")
        # Layout, device and memory_format are ignored
        return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
    else:
        return sym_help._onnx_unsupported("Unknown aten::to signature")
381
+
382
+
383
# TODO: Remove after PyTorch 1.11.1+ is used by detectron2's CI
def _pytorch1111_symbolic_opset9_repeat_interleave(g, self, repeats, dim=None, output_size=None):
    """aten::repeat_interleave() symbolic for ONNX export with PyTorch < 1.11.1."""

    # from torch.onnx.symbolic_helper import ScalarType
    from torch.onnx.symbolic_opset9 import expand, unsqueeze

    input = self
    # if dim is None flatten
    # By default, use the flattened input array, and return a flat output array
    if sym_help._is_none(dim):
        input = sym_help._reshape_helper(g, self, g.op("Constant", value_t=torch.tensor([-1])))
        dim = 0
    else:
        dim = sym_help._maybe_get_scalar(dim)

    repeats_dim = sym_help._get_tensor_rank(repeats)
    repeats_sizes = sym_help._get_tensor_sizes(repeats)
    input_sizes = sym_help._get_tensor_sizes(input)
    if repeats_dim is None:
        raise RuntimeError(
            "Unsupported: ONNX export of repeat_interleave for unknown " "repeats rank."
        )
    if repeats_sizes is None:
        raise RuntimeError(
            "Unsupported: ONNX export of repeat_interleave for unknown " "repeats size."
        )
    if input_sizes is None:
        raise RuntimeError(
            "Unsupported: ONNX export of repeat_interleave for unknown " "input size."
        )

    input_sizes_temp = input_sizes.copy()
    # Dynamic (None) dims become 0 in input_sizes and -1 in the temp copy,
    # so the later Reshape/Expand constants can encode them.
    for idx, input_size in enumerate(input_sizes):
        if input_size is None:
            input_sizes[idx], input_sizes_temp[idx] = 0, -1

    # Cases where repeats is an int or single value tensor
    if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1):
        if not sym_help._is_tensor(repeats):
            repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
        if input_sizes[dim] == 0:
            return sym_help._onnx_opset_unsupported_detailed(
                "repeat_interleave",
                9,
                13,
                "Unsupported along dimension with unknown input size",
            )
        else:
            reps = input_sizes[dim]
            repeats = expand(g, repeats, g.op("Constant", value_t=torch.tensor([reps])), None)

    # Cases where repeats is a 1 dim Tensor
    elif repeats_dim == 1:
        if input_sizes[dim] == 0:
            return sym_help._onnx_opset_unsupported_detailed(
                "repeat_interleave",
                9,
                13,
                "Unsupported along dimension with unknown input size",
            )
        if repeats_sizes[0] is None:
            return sym_help._onnx_opset_unsupported_detailed(
                "repeat_interleave", 9, 13, "Unsupported for cases with dynamic repeats"
            )
        assert (
            repeats_sizes[0] == input_sizes[dim]
        ), "repeats must have the same size as input along dim"
        reps = repeats_sizes[0]
    else:
        raise RuntimeError("repeats must be 0-dim or 1-dim tensor")

    final_splits = list()
    r_splits = sym_help._repeat_interleave_split_helper(g, repeats, reps, 0)
    if isinstance(r_splits, torch._C.Value):
        r_splits = [r_splits]
    i_splits = sym_help._repeat_interleave_split_helper(g, input, reps, dim)
    if isinstance(i_splits, torch._C.Value):
        i_splits = [i_splits]
    input_sizes[dim], input_sizes_temp[dim] = -1, 1
    # For each slice along `dim`: unsqueeze a new axis, expand it by the
    # slice's repeat count, then reshape back to fold the repeats into `dim`.
    for idx, r_split in enumerate(r_splits):
        i_split = unsqueeze(g, i_splits[idx], dim + 1)
        r_concat = [
            g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[: dim + 1])),
            r_split,
            g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[dim + 1 :])),
        ]
        r_concat = g.op("Concat", *r_concat, axis_i=0)
        i_split = expand(g, i_split, r_concat, None)
        i_split = sym_help._reshape_helper(
            g,
            i_split,
            g.op("Constant", value_t=torch.LongTensor(input_sizes)),
            allowzero=0,
        )
        final_splits.append(i_split)
    return g.op("Concat", *final_splits, axis_i=dim)
detectron2/utils/tracing.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import torch
3
+
4
+ from detectron2.utils.env import TORCH_VERSION
5
+
6
+ try:
7
+ from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current
8
+
9
+ tracing_current_exists = True
10
+ except ImportError:
11
+ tracing_current_exists = False
12
+
13
+ try:
14
+ from torch.fx._symbolic_trace import _orig_module_call
15
+
16
+ tracing_legacy_exists = True
17
+ except ImportError:
18
+ tracing_legacy_exists = False
19
+
20
+
21
@torch.jit.ignore
def is_fx_tracing_legacy() -> bool:
    """
    Returns a bool indicating whether torch.fx is currently symbolically tracing a module.
    Can be useful for gating module logic that is incompatible with symbolic tracing.
    """
    # Legacy detection: the FX tracer monkey-patches nn.Module.__call__ while
    # tracing, so a patched __call__ implies tracing is active.
    return torch.nn.Module.__call__ is not _orig_module_call
28
+
29
+
30
def is_fx_tracing() -> bool:
    """Returns whether execution is currently in
    Torch FX tracing mode"""
    if torch.jit.is_scripting():
        return False
    # Prefer the official API when available on this torch version.
    if tracing_current_exists and TORCH_VERSION >= (1, 10):
        return is_fx_tracing_current()
    if tracing_legacy_exists:
        return is_fx_tracing_legacy()
    # Can't find either current or legacy tracing indication code.
    # Enabling this assert_fx_safe() call regardless of tracing status.
    return False
43
+
44
+
45
def assert_fx_safe(condition: bool, message: str) -> torch.Tensor:
    """An FX-tracing safe version of assert.
    Avoids erroneous type assertion triggering when types are masked inside
    an fx.proxy.Proxy object during tracing.
    Args: condition - either a boolean expression or a string representing
    the condition to test. If this assert triggers an exception when tracing
    due to dynamic control flow, try encasing the expression in quotation
    marks and supplying it as a string."""
    # Must return a concrete tensor for compatibility with PyTorch <=1.8.
    # If <=1.8 compatibility is not needed, return type can be converted to None
    skip_check = torch.jit.is_scripting() or is_fx_tracing()
    if skip_check:
        return torch.zeros(1)
    return _do_assert_fx_safe(condition, message)
58
+
59
+
60
+ def _do_assert_fx_safe(condition: bool, message: str) -> torch.Tensor:
61
+ try:
62
+ if isinstance(condition, str):
63
+ caller_frame = inspect.currentframe().f_back
64
+ torch._assert(eval(condition, caller_frame.f_globals, caller_frame.f_locals), message)
65
+ return torch.ones(1)
66
+ else:
67
+ torch._assert(condition, message)
68
+ return torch.ones(1)
69
+ except torch.fx.proxy.TraceError as e:
70
+ print(
71
+ "Found a non-FX compatible assertion. Skipping the check. Failure is shown below"
72
+ + str(e)
73
+ )
detectron2/utils/video_visualizer.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import numpy as np
3
+ from typing import List
4
+ import pycocotools.mask as mask_util
5
+
6
+ from detectron2.structures import Instances
7
+ from detectron2.utils.visualizer import (
8
+ ColorMode,
9
+ Visualizer,
10
+ _create_text_labels,
11
+ _PanopticPrediction,
12
+ )
13
+
14
+ from .colormap import random_color, random_colors
15
+
16
+
17
+ class _DetectedInstance:
18
+ """
19
+ Used to store data about detected objects in video frame,
20
+ in order to transfer color to objects in the future frames.
21
+
22
+ Attributes:
23
+ label (int):
24
+ bbox (tuple[float]):
25
+ mask_rle (dict):
26
+ color (tuple[float]): RGB colors in range (0, 1)
27
+ ttl (int): time-to-live for the instance. For example, if ttl=2,
28
+ the instance color can be transferred to objects in the next two frames.
29
+ """
30
+
31
+ __slots__ = ["label", "bbox", "mask_rle", "color", "ttl"]
32
+
33
+ def __init__(self, label, bbox, mask_rle, color, ttl):
34
+ self.label = label
35
+ self.bbox = bbox
36
+ self.mask_rle = mask_rle
37
+ self.color = color
38
+ self.ttl = ttl
39
+
40
+
41
class VideoVisualizer:
    """
    Draws per-frame predictions on video frames, keeping instance colors
    stable across frames via either explicit track IDs or an IoU-matching
    heuristic.
    """

    def __init__(self, metadata, instance_mode=ColorMode.IMAGE):
        """
        Args:
            metadata (MetadataCatalog): image metadata.
            instance_mode (ColorMode): only IMAGE and IMAGE_BW are supported.
        """
        self.metadata = metadata
        # Instances from previous frames, kept alive (per their ttl) so their
        # colors can be transferred to matching instances in new frames.
        self._old_instances = []
        assert instance_mode in [
            ColorMode.IMAGE,
            ColorMode.IMAGE_BW,
        ], "Other mode not supported yet."
        self._instance_mode = instance_mode
        # Size of the pre-allocated color pool used by ID-based coloring.
        self._max_num_instances = self.metadata.get("max_num_instances", 74)
        self._assigned_colors = {}  # track ID -> index into self._color_pool
        self._color_pool = random_colors(self._max_num_instances, rgb=True, maximum=1)
        self._color_idx_set = set(range(len(self._color_pool)))  # free pool indices

    def draw_instance_predictions(self, frame, predictions):
        """
        Draw instance-level prediction results on an image.

        Args:
            frame (ndarray): an RGB image of shape (H, W, C), in the range [0, 255].
            predictions (Instances): the output of an instance detection/segmentation
                model. Following fields will be used to draw:
                "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").

        Returns:
            output (VisImage): image object with visualizations.
        """
        frame_visualizer = Visualizer(frame, self.metadata)
        num_instances = len(predictions)
        if num_instances == 0:
            return frame_visualizer.output

        boxes = predictions.pred_boxes.tensor.numpy() if predictions.has("pred_boxes") else None
        scores = predictions.scores if predictions.has("scores") else None
        classes = predictions.pred_classes.numpy() if predictions.has("pred_classes") else None
        keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None
        colors = predictions.COLOR if predictions.has("COLOR") else [None] * len(predictions)
        periods = predictions.ID_period if predictions.has("ID_period") else None
        period_threshold = self.metadata.get("period_threshold", 0)
        # An instance is drawn only once it has been tracked for more than
        # period_threshold frames (all visible when no period info exists).
        visibilities = (
            [True] * len(predictions)
            if periods is None
            else [x > period_threshold for x in periods]
        )

        if predictions.has("pred_masks"):
            masks = predictions.pred_masks
            # mask IOU is not yet enabled
            # masks_rles = mask_util.encode(np.asarray(masks.permute(1, 2, 0), order="F"))
            # assert len(masks_rles) == num_instances
        else:
            masks = None

        if not predictions.has("COLOR"):
            if predictions.has("ID"):
                colors = self._assign_colors_by_id(predictions)
            else:
                # ToDo: clean old assign color method and use a default tracker to assign id
                detected = [
                    _DetectedInstance(classes[i], boxes[i], mask_rle=None, color=colors[i], ttl=8)
                    for i in range(num_instances)
                ]
                colors = self._assign_colors(detected)

        labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None))

        if self._instance_mode == ColorMode.IMAGE_BW:
            # any() returns uint8 tensor
            frame_visualizer.output.reset_image(
                frame_visualizer._create_grayscale_image(
                    (masks.any(dim=0) > 0).numpy() if masks is not None else None
                )
            )
            alpha = 0.3
        else:
            alpha = 0.5

        # Keep only the entries whose instance is visible this frame.
        labels = (
            None
            if labels is None
            else [y[0] for y in filter(lambda x: x[1], zip(labels, visibilities))]
        )  # noqa
        assigned_colors = (
            None
            if colors is None
            else [y[0] for y in filter(lambda x: x[1], zip(colors, visibilities))]
        )  # noqa
        frame_visualizer.overlay_instances(
            boxes=None if masks is not None else boxes[visibilities],  # boxes are a bit distracting
            masks=None if masks is None else masks[visibilities],
            labels=labels,
            keypoints=None if keypoints is None else keypoints[visibilities],
            assigned_colors=assigned_colors,
            alpha=alpha,
        )

        return frame_visualizer.output

    def draw_sem_seg(self, frame, sem_seg, area_threshold=None):
        """
        Args:
            frame (ndarray): an RGB image of shape (H, W, C), in the range [0, 255].
            sem_seg (ndarray or Tensor): semantic segmentation of shape (H, W),
                each value is the integer label.
            area_threshold (Optional[int]): only draw segmentations larger than the threshold

        Returns:
            output (VisImage): image object with visualizations.
        """
        # Semantic segmentation needs no cross-frame state; delegate directly.
        # BUGFIX: previously area_threshold was dropped (None was always passed).
        frame_visualizer = Visualizer(frame, self.metadata)
        frame_visualizer.draw_sem_seg(sem_seg, area_threshold=area_threshold)
        return frame_visualizer.output

    def draw_panoptic_seg_predictions(
        self, frame, panoptic_seg, segments_info, area_threshold=None, alpha=0.5
    ):
        """
        Draw panoptic predictions on a frame, transferring instance ("thing")
        colors across frames with the IoU heuristic of :meth:`_assign_colors`.

        Returns:
            output (VisImage): image object with visualizations.
        """
        frame_visualizer = Visualizer(frame, self.metadata)
        pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)

        if self._instance_mode == ColorMode.IMAGE_BW:
            frame_visualizer.output.reset_image(
                frame_visualizer._create_grayscale_image(pred.non_empty_mask())
            )

        # draw mask for all semantic segments first i.e. "stuff"
        for mask, sinfo in pred.semantic_masks():
            category_idx = sinfo["category_id"]
            try:
                mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]
            except AttributeError:
                # Metadata has no stuff_colors; let the visualizer pick one.
                mask_color = None

            frame_visualizer.draw_binary_mask(
                mask,
                color=mask_color,
                text=self.metadata.stuff_classes[category_idx],
                alpha=alpha,
                area_threshold=area_threshold,
            )

        all_instances = list(pred.instance_masks())
        if len(all_instances) == 0:
            return frame_visualizer.output
        # draw mask for all instances second
        masks, sinfo = list(zip(*all_instances))
        num_instances = len(masks)
        masks_rles = mask_util.encode(
            np.asarray(np.asarray(masks).transpose(1, 2, 0), dtype=np.uint8, order="F")
        )
        assert len(masks_rles) == num_instances

        category_ids = [x["category_id"] for x in sinfo]
        detected = [
            _DetectedInstance(category_ids[i], bbox=None, mask_rle=masks_rles[i], color=None, ttl=8)
            for i in range(num_instances)
        ]
        colors = self._assign_colors(detected)
        labels = [self.metadata.thing_classes[k] for k in category_ids]

        frame_visualizer.overlay_instances(
            boxes=None,
            masks=masks,
            labels=labels,
            keypoints=None,
            assigned_colors=colors,
            alpha=alpha,
        )
        return frame_visualizer.output

    def _assign_colors(self, instances):
        """
        Naive tracking heuristics to assign same color to the same instance,
        will update the internal state of tracked instances.

        Args:
            instances (list[_DetectedInstance]): detections in the new frame.

        Returns:
            list[tuple[float]]: list of colors.
        """

        # Compute iou with either boxes or masks:
        is_crowd = np.zeros((len(instances),), dtype=bool)
        if instances[0].bbox is None:
            assert instances[0].mask_rle is not None
            # use mask iou only when box iou is None
            # because box seems good enough
            rles_old = [x.mask_rle for x in self._old_instances]
            rles_new = [x.mask_rle for x in instances]
            ious = mask_util.iou(rles_old, rles_new, is_crowd)
            threshold = 0.5
        else:
            boxes_old = [x.bbox for x in self._old_instances]
            boxes_new = [x.bbox for x in instances]
            ious = mask_util.iou(boxes_old, boxes_new, is_crowd)
            threshold = 0.6
        if len(ious) == 0:
            ious = np.zeros((len(self._old_instances), len(instances)), dtype="float32")

        # Only allow matching instances of the same label:
        for old_idx, old in enumerate(self._old_instances):
            for new_idx, new in enumerate(instances):
                if old.label != new.label:
                    ious[old_idx, new_idx] = 0

        matched_new_per_old = np.asarray(ious).argmax(axis=1)
        max_iou_per_old = np.asarray(ious).max(axis=1)

        # Try to find match for each old instance:
        extra_instances = []
        for idx, inst in enumerate(self._old_instances):
            if max_iou_per_old[idx] > threshold:
                newidx = matched_new_per_old[idx]
                if instances[newidx].color is None:
                    instances[newidx].color = inst.color
                    continue
            # If an old instance does not match any new instances,
            # keep it for the next frame in case it is just missed by the detector
            inst.ttl -= 1
            if inst.ttl > 0:
                extra_instances.append(inst)

        # Assign random color to newly-detected instances:
        for inst in instances:
            if inst.color is None:
                inst.color = random_color(rgb=True, maximum=1)
        self._old_instances = instances[:] + extra_instances
        return [d.color for d in instances]

    def _assign_colors_by_id(self, instances: Instances) -> List:
        """
        Assign each track ID a stable color from the pool; IDs not seen in
        this frame release their color back to the pool.
        """
        colors = []
        untracked_ids = set(self._assigned_colors.keys())
        for instance_id in instances.ID:
            if instance_id in self._assigned_colors:
                colors.append(self._color_pool[self._assigned_colors[instance_id]])
                untracked_ids.remove(instance_id)
            else:
                assert (
                    len(self._color_idx_set) >= 1
                ), f"Number of id exceeded maximum, max = {self._max_num_instances}"
                idx = self._color_idx_set.pop()
                color = self._color_pool[idx]
                self._assigned_colors[instance_id] = idx
                colors.append(color)
        # Recycle colors of IDs that disappeared in this frame.
        for instance_id in untracked_ids:
            self._color_idx_set.add(self._assigned_colors[instance_id])
            del self._assigned_colors[instance_id]
        return colors
detectron2/utils/visualizer.py ADDED
@@ -0,0 +1,1281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import colorsys
3
+ import logging
4
+ import math
5
+ import numpy as np
6
+ from enum import Enum, unique
7
+ import cv2
8
+ import matplotlib as mpl
9
+ import matplotlib.colors as mplc
10
+ import matplotlib.figure as mplfigure
11
+ import pycocotools.mask as mask_util
12
+ import torch
13
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
14
+ from PIL import Image
15
+
16
+ from detectron2.data import MetadataCatalog
17
+ from detectron2.structures import BitMasks, Boxes, BoxMode, Keypoints, PolygonMasks, RotatedBoxes
18
+ from detectron2.utils.file_io import PathManager
19
+
20
+ from .colormap import random_color
21
+
22
logger = logging.getLogger(__name__)

__all__ = ["ColorMode", "VisImage", "Visualizer"]


# Area thresholds in pixels used by drawing heuristics later in this file.
# NOTE(review): their exact effect lives in methods not shown here — verify
# against overlay_instances()/draw_binary_mask() before changing.
_SMALL_OBJECT_AREA_THRESH = 1000
_LARGE_MASK_AREA_THRESH = 120000
# Colors as RGB tuples in the [0, 1] range (matplotlib convention).
_OFF_WHITE = (1.0, 1.0, 240.0 / 255)
_BLACK = (0, 0, 0)
_RED = (1.0, 0, 0)

# Default confidence cutoff stored on Visualizer as `keypoint_threshold`
# (see Visualizer.__init__).
_KEYPOINT_THRESHOLD = 0.05
34
+
35
+
36
@unique
class ColorMode(Enum):
    """
    Enum of different color modes to use for instance visualizations.
    """

    IMAGE = 0
    """
    Picks a random color for every instance and overlay segmentations with low opacity.
    """
    SEGMENTATION = 1
    """
    Let instances of the same category have similar colors
    (from metadata.thing_colors), and overlay them with
    high opacity. This provides more attention on the quality of segmentation.
    """
    IMAGE_BW = 2
    """
    Same as IMAGE, but convert all areas without masks to gray-scale.
    Only available for drawing per-instance mask predictions.
    """
57
+
58
+
59
class GenericMask:
    """
    Wraps a mask given in any of three formats (COCO RLE dict, list of
    polygons, or binary ndarray) and converts lazily between representations.

    Attribute:
        polygons (list[ndarray]): list[ndarray]: polygons for this mask.
            Each ndarray has format [x, y, x, y, ...]
        mask (ndarray): a binary mask
    """

    def __init__(self, mask_or_polygons, height, width):
        """
        Args:
            mask_or_polygons: an RLE dict (with "counts"/"size"), a list of
                polygons (each flattenable to [x, y, x, y, ...]), or a binary
                ndarray of shape (height, width).
            height (int): image height the mask belongs to.
            width (int): image width the mask belongs to.
        """
        # Caches for the lazy representations; filled on first access.
        self._mask = self._polygons = self._has_holes = None
        self.height = height
        self.width = width

        m = mask_or_polygons
        if isinstance(m, dict):
            # RLEs
            assert "counts" in m and "size" in m
            if isinstance(m["counts"], list):  # uncompressed RLEs
                h, w = m["size"]
                assert h == height and w == width
                m = mask_util.frPyObjects(m, h, w)
            self._mask = mask_util.decode(m)[:, :]
            return

        if isinstance(m, list):  # list[ndarray]
            self._polygons = [np.asarray(x).reshape(-1) for x in m]
            return

        if isinstance(m, np.ndarray):  # assumed to be a binary mask
            # shape[1] == 2 would suggest an (N, 2) point array, not a mask.
            assert m.shape[1] != 2, m.shape
            assert m.shape == (
                height,
                width,
            ), f"mask shape: {m.shape}, target dims: {height}, {width}"
            self._mask = m.astype("uint8")
            return

        raise ValueError("GenericMask cannot handle object {} of type '{}'".format(m, type(m)))

    @property
    def mask(self):
        # Lazily rasterize from polygons on first access.
        if self._mask is None:
            self._mask = self.polygons_to_mask(self._polygons)
        return self._mask

    @property
    def polygons(self):
        # Lazily vectorize from the binary mask on first access.
        if self._polygons is None:
            self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
        return self._polygons

    @property
    def has_holes(self):
        if self._has_holes is None:
            if self._mask is not None:
                self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
            else:
                self._has_holes = False  # if original format is polygon, does not have holes
        return self._has_holes

    def mask_to_polygons(self, mask):
        """Convert a binary mask to (polygons, has_holes) via cv2.findContours."""
        # cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level
        # hierarchy. External contours (boundary) of the object are placed in hierarchy-1.
        # Internal contours (holes) are placed in hierarchy-2.
        # cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours.
        mask = np.ascontiguousarray(mask)  # some versions of cv2 does not support incontiguous arr
        # res[-1]/res[-2] indexing keeps this compatible with the different
        # return arities of cv2 3.x and 4.x.
        res = cv2.findContours(mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
        hierarchy = res[-1]
        if hierarchy is None:  # empty mask
            return [], False
        has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
        res = res[-2]
        res = [x.flatten() for x in res]
        # These coordinates from OpenCV are integers in range [0, W-1 or H-1].
        # We add 0.5 to turn them into real-value coordinate space. A better solution
        # would be to first +0.5 and then dilate the returned polygon by 0.5.
        res = [x + 0.5 for x in res if len(x) >= 6]
        return res, has_holes

    def polygons_to_mask(self, polygons):
        """Rasterize polygons into one merged binary mask of (height, width)."""
        rle = mask_util.frPyObjects(polygons, self.height, self.width)
        rle = mask_util.merge(rle)
        return mask_util.decode(rle)[:, :]

    def area(self):
        """Return the number of foreground pixels in the mask."""
        return self.mask.sum()

    def bbox(self):
        """Return the mask's bounding box, converted from XYWH to XYXY."""
        p = mask_util.frPyObjects(self.polygons, self.height, self.width)
        p = mask_util.merge(p)
        bbox = mask_util.toBbox(p)
        bbox[2] += bbox[0]
        bbox[3] += bbox[1]
        return bbox
153
+
154
+
155
class _PanopticPrediction:
    """
    Unify different panoptic annotation/prediction formats
    """

    def __init__(self, panoptic_seg, segments_info, metadata=None):
        """
        Args:
            panoptic_seg (Tensor): (H, W) map of segment ids.
            segments_info (list[dict] or None): per-segment dicts with at
                least "id" and "category_id"; if None, ids are decoded as
                category_id * label_divisor + instance_id using metadata.
            metadata: required only when segments_info is None.
        """
        if segments_info is None:
            assert metadata is not None
            # If "segments_info" is None, we assume "panoptic_img" is a
            # H*W int32 image storing the panoptic_id in the format of
            # category_id * label_divisor + instance_id. We reserve -1 for
            # VOID label.
            label_divisor = metadata.label_divisor
            segments_info = []
            for panoptic_label in np.unique(panoptic_seg.numpy()):
                if panoptic_label == -1:
                    # VOID region.
                    continue
                pred_class = panoptic_label // label_divisor
                isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
                segments_info.append(
                    {
                        "id": int(panoptic_label),
                        "category_id": int(pred_class),
                        "isthing": bool(isthing),
                    }
                )
        del metadata

        self._seg = panoptic_seg

        self._sinfo = {s["id"]: s for s in segments_info}  # seg id -> seg info
        # Sort segment ids by decreasing area so larger segments are
        # enumerated (and drawn) first.
        segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True)
        areas = areas.numpy()
        sorted_idxs = np.argsort(-areas)
        self._seg_ids, self._seg_areas = segment_ids[sorted_idxs], areas[sorted_idxs]
        self._seg_ids = self._seg_ids.tolist()
        for sid, area in zip(self._seg_ids, self._seg_areas):
            if sid in self._sinfo:
                self._sinfo[sid]["area"] = float(area)

    def non_empty_mask(self):
        """
        Returns:
            (H, W) array, a mask for all pixels that have a prediction
        """
        # Ids present in the map but absent from segments_info carry no label.
        empty_ids = []
        for id in self._seg_ids:
            if id not in self._sinfo:
                empty_ids.append(id)
        if len(empty_ids) == 0:
            return np.zeros(self._seg.shape, dtype=np.uint8)
        assert (
            len(empty_ids) == 1
        ), ">1 ids corresponds to no labels. This is currently not supported"
        return (self._seg != empty_ids[0]).numpy().astype(bool)

    def semantic_masks(self):
        """Yield (bool mask, segment info) for each "stuff" segment, large-to-small."""
        for sid in self._seg_ids:
            sinfo = self._sinfo.get(sid)
            if sinfo is None or sinfo["isthing"]:
                # Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions.
                continue
            yield (self._seg == sid).numpy().astype(bool), sinfo

    def instance_masks(self):
        """Yield (bool mask, segment info) for each non-empty "thing" segment."""
        for sid in self._seg_ids:
            sinfo = self._sinfo.get(sid)
            if sinfo is None or not sinfo["isthing"]:
                continue
            mask = (self._seg == sid).numpy().astype(bool)
            if mask.sum() > 0:
                yield mask, sinfo
228
+
229
+
230
+ def _create_text_labels(classes, scores, class_names, is_crowd=None):
231
+ """
232
+ Args:
233
+ classes (list[int] or None):
234
+ scores (list[float] or None):
235
+ class_names (list[str] or None):
236
+ is_crowd (list[bool] or None):
237
+
238
+ Returns:
239
+ list[str] or None
240
+ """
241
+ labels = None
242
+ if classes is not None:
243
+ if class_names is not None and len(class_names) > 0:
244
+ labels = [class_names[i] for i in classes]
245
+ else:
246
+ labels = [str(i) for i in classes]
247
+ if scores is not None:
248
+ if labels is None:
249
+ labels = ["{:.0f}%".format(s * 100) for s in scores]
250
+ else:
251
+ labels = ["{} {:.0f}%".format(l, s * 100) for l, s in zip(labels, scores)]
252
+ if labels is not None and is_crowd is not None:
253
+ labels = [l + ("|crowd" if crowd else "") for l, crowd in zip(labels, is_crowd)]
254
+ return labels
255
+
256
+
257
class VisImage:
    """
    Holds a matplotlib figure sized exactly to an image, on which drawing
    primitives are rendered; the final raster is retrieved via get_image().
    """

    def __init__(self, img, scale=1.0):
        """
        Args:
            img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
            scale (float): scale the input image
        """
        self.img = img
        self.scale = scale
        self.width, self.height = img.shape[1], img.shape[0]
        self._setup_figure(img)

    def _setup_figure(self, img):
        """
        Args:
            Same as in :meth:`__init__()`.

        Returns:
            fig (matplotlib.pyplot.figure): top level container for all the image plot elements.
            ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system.
        """
        fig = mplfigure.Figure(frameon=False)
        self.dpi = fig.get_dpi()
        # add a small 1e-2 to avoid precision lost due to matplotlib's truncation
        # (https://github.com/matplotlib/matplotlib/issues/15363)
        fig.set_size_inches(
            (self.width * self.scale + 1e-2) / self.dpi,
            (self.height * self.scale + 1e-2) / self.dpi,
        )
        self.canvas = FigureCanvasAgg(fig)
        # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
        # Axes spanning the whole figure, with no decorations, so the image
        # fills the canvas exactly.
        ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
        ax.axis("off")
        self.fig = fig
        self.ax = ax
        self.reset_image(img)

    def reset_image(self, img):
        """
        Args:
            img: same as in __init__
        """
        img = img.astype("uint8")
        # extent maps image pixels 1:1 onto the axes (y axis inverted, as is
        # conventional for images).
        self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest")

    def save(self, filepath):
        """
        Args:
            filepath (str): a string that contains the absolute path, including the file name, where
                the visualized image will be saved.
        """
        self.fig.savefig(filepath)

    def get_image(self):
        """
        Returns:
            ndarray:
                the visualized image of shape (H, W, 3) (RGB) in uint8 type.
                The shape is scaled w.r.t the input image using the given `scale` argument.
        """
        canvas = self.canvas
        # Render to an RGBA byte buffer using the Agg backend.
        s, (width, height) = canvas.print_to_buffer()
        # buf = io.BytesIO()  # works for cairo backend
        # canvas.print_rgba(buf)
        # width, height = self.width, self.height
        # s = buf.getvalue()

        buffer = np.frombuffer(s, dtype="uint8")

        # Drop the alpha channel and return RGB only.
        img_rgba = buffer.reshape(height, width, 4)
        rgb, alpha = np.split(img_rgba, [3], axis=2)
        return rgb.astype("uint8")
329
+
330
+
331
+ class Visualizer:
332
+ """
333
+ Visualizer that draws data about detection/segmentation on images.
334
+
335
+ It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}`
336
+ that draw primitive objects to images, as well as high-level wrappers like
337
+ `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}`
338
+ that draw composite data in some pre-defined style.
339
+
340
+ Note that the exact visualization style for the high-level wrappers are subject to change.
341
+ Style such as color, opacity, label contents, visibility of labels, or even the visibility
342
+ of objects themselves (e.g. when the object is too small) may change according
343
+ to different heuristics, as long as the results still look visually reasonable.
344
+
345
+ To obtain a consistent style, you can implement custom drawing functions with the
346
+ abovementioned primitive methods instead. If you need more customized visualization
347
+ styles, you can process the data yourself following their format documented in
348
+ tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not
349
+ intend to satisfy everyone's preference on drawing styles.
350
+
351
+ This visualizer focuses on high rendering quality rather than performance. It is not
352
+ designed to be used for real-time applications.
353
+ """
354
+
355
+ # TODO implement a fast, rasterized version using OpenCV
356
+
357
    def __init__(
        self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE, font_size_scale=1.0
    ):
        """
        Args:
            img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
                the height and width of the image respectively. C is the number of
                color channels. The image is required to be in RGB format since that
                is a requirement of the Matplotlib library. The image is also expected
                to be in the range [0, 255].
            metadata (Metadata): dataset metadata (e.g. class names and colors)
            instance_mode (ColorMode): defines one of the pre-defined style for drawing
                instances on an image.
            font_size_scale: extra scaling of font size on top of default font size
        """
        self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
        if metadata is None:
            # Empty metadata placeholder so attribute lookups degrade gracefully.
            metadata = MetadataCatalog.get("__nonexist__")
        self.metadata = metadata
        self.output = VisImage(self.img, scale=scale)
        self.cpu_device = torch.device("cpu")

        # too small texts are useless, therefore clamp to 10 // scale
        self._default_font_size = (
            max(np.sqrt(self.output.height * self.output.width) // 90, 10 // scale)
            * font_size_scale
        )
        self._instance_mode = instance_mode
        self.keypoint_threshold = _KEYPOINT_THRESHOLD
386
+
387
    def draw_instance_predictions(self, predictions, jittering: bool = True):
        """
        Draw instance-level prediction results on an image.

        Args:
            predictions (Instances): the output of an instance detection/segmentation
                model. Following fields will be used to draw:
                "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").
            jittering: if True, in color mode SEGMENTATION, randomly jitter the colors per class
                to distinguish instances from the same class

        Returns:
            output (VisImage): image object with visualizations.
        """
        boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
        scores = predictions.scores if predictions.has("scores") else None
        classes = predictions.pred_classes.tolist() if predictions.has("pred_classes") else None
        labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None))
        keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None

        if predictions.has("pred_masks"):
            masks = np.asarray(predictions.pred_masks)
            masks = [GenericMask(x, self.output.height, self.output.width) for x in masks]
        else:
            masks = None

        if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"):
            # Per-class colors from metadata (0-255), normalized to [0, 1];
            # optionally jittered so same-class instances are distinguishable.
            colors = (
                [self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes]
                if jittering
                else [
                    tuple(mplc.to_rgb([x / 255 for x in self.metadata.thing_colors[c]]))
                    for c in classes
                ]
            )

            alpha = 0.8
        else:
            colors = None
            alpha = 0.5

        if self._instance_mode == ColorMode.IMAGE_BW:
            # Grayscale everything outside the union of predicted masks.
            self.output.reset_image(
                self._create_grayscale_image(
                    (predictions.pred_masks.any(dim=0) > 0).numpy()
                    if predictions.has("pred_masks")
                    else None
                )
            )
            alpha = 0.3

        self.overlay_instances(
            masks=masks,
            boxes=boxes,
            labels=labels,
            keypoints=keypoints,
            assigned_colors=colors,
            alpha=alpha,
        )
        return self.output
447
+
448
    def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8):
        """
        Draw semantic segmentation predictions/labels.

        Args:
            sem_seg (Tensor or ndarray): the segmentation of shape (H, W).
                Each value is the integer label of the pixel.
            area_threshold (int): segments with less than `area_threshold` are not drawn.
            alpha (float): the larger it is, the more opaque the segmentations are.

        Returns:
            output (VisImage): image object with visualizations.
        """
        if isinstance(sem_seg, torch.Tensor):
            sem_seg = sem_seg.numpy()
        # Draw larger segments first so small ones stay visible on top.
        labels, areas = np.unique(sem_seg, return_counts=True)
        sorted_idxs = np.argsort(-areas).tolist()
        labels = labels[sorted_idxs]
        # Labels outside the known class range (e.g. an "ignore" value) are skipped.
        for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels):
            try:
                mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]
            except (AttributeError, IndexError):
                # No color defined for this label; let draw_binary_mask pick one.
                mask_color = None

            binary_mask = (sem_seg == label).astype(np.uint8)
            text = self.metadata.stuff_classes[label]
            self.draw_binary_mask(
                binary_mask,
                color=mask_color,
                edge_color=_OFF_WHITE,
                text=text,
                alpha=alpha,
                area_threshold=area_threshold,
            )
        return self.output
483
+
484
    def draw_panoptic_seg(self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7):
        """
        Draw panoptic prediction annotations or results.

        Args:
            panoptic_seg (Tensor): of shape (height, width) where the values are ids for each
                segment.
            segments_info (list[dict] or None): Describe each segment in `panoptic_seg`.
                If it is a ``list[dict]``, each dict contains keys "id", "category_id".
                If None, category id of each pixel is computed by
                ``pixel // metadata.label_divisor``.
            area_threshold (int): stuff segments with less than `area_threshold` are not drawn.
            alpha (float): blending coefficient for the overlaid masks; smaller values
                produce more transparent masks.

        Returns:
            output (VisImage): image object with visualizations.
        """
        pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)

        if self._instance_mode == ColorMode.IMAGE_BW:
            # gray out the image except for the segmented (non-empty) regions
            self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask()))

        # draw mask for all semantic segments first i.e. "stuff"
        for mask, sinfo in pred.semantic_masks():
            category_idx = sinfo["category_id"]
            try:
                mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]
            except AttributeError:
                # metadata has no stuff_colors; draw_binary_mask will pick a random color
                mask_color = None

            text = self.metadata.stuff_classes[category_idx]
            self.draw_binary_mask(
                mask,
                color=mask_color,
                edge_color=_OFF_WHITE,
                text=text,
                alpha=alpha,
                area_threshold=area_threshold,
            )

        # draw mask for all instances second
        all_instances = list(pred.instance_masks())
        if len(all_instances) == 0:
            return self.output
        masks, sinfo = list(zip(*all_instances))
        category_ids = [x["category_id"] for x in sinfo]

        try:
            scores = [x["score"] for x in sinfo]
        except KeyError:
            # ground-truth segments carry no "score" key
            scores = None
        labels = _create_text_labels(
            category_ids, scores, self.metadata.thing_classes, [x.get("iscrowd", 0) for x in sinfo]
        )

        try:
            colors = [
                self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in category_ids
            ]
        except AttributeError:
            # metadata has no thing_colors; overlay_instances will pick random colors
            colors = None
        self.overlay_instances(masks=masks, labels=labels, assigned_colors=colors, alpha=alpha)

        return self.output
547
+
548
    draw_panoptic_seg_predictions = draw_panoptic_seg  # alias kept for backward compatibility
549
+
550
    def draw_dataset_dict(self, dic):
        """
        Draw annotations/segmentations in Detectron2 Dataset format.

        Args:
            dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format.

        Returns:
            output (VisImage): image object with visualizations.
        """
        annos = dic.get("annotations", None)
        if annos:
            if "segmentation" in annos[0]:
                masks = [x["segmentation"] for x in annos]
            else:
                masks = None
            if "keypoints" in annos[0]:
                keypts = [x["keypoints"] for x in annos]
                # reshape flat per-annotation keypoints into (num_instances, K, 3)
                keypts = np.array(keypts).reshape(len(annos), -1, 3)
            else:
                keypts = None

            # Convert 4-element boxes to XYXY_ABS; 5-element (rotated) boxes pass through.
            boxes = [
                (
                    BoxMode.convert(x["bbox"], x["bbox_mode"], BoxMode.XYXY_ABS)
                    if len(x["bbox"]) == 4
                    else x["bbox"]
                )
                for x in annos
            ]

            colors = None
            category_ids = [x["category_id"] for x in annos]
            if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"):
                colors = [
                    self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
                    for c in category_ids
                ]
            names = self.metadata.get("thing_classes", None)
            labels = _create_text_labels(
                category_ids,
                scores=None,
                class_names=names,
                is_crowd=[x.get("iscrowd", 0) for x in annos],
            )
            self.overlay_instances(
                labels=labels, boxes=boxes, masks=masks, keypoints=keypts, assigned_colors=colors
            )

        sem_seg = dic.get("sem_seg", None)
        if sem_seg is None and "sem_seg_file_name" in dic:
            # lazily load the semantic segmentation from file when not inlined in the dict
            with PathManager.open(dic["sem_seg_file_name"], "rb") as f:
                sem_seg = Image.open(f)
                sem_seg = np.asarray(sem_seg, dtype="uint8")
        if sem_seg is not None:
            self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.5)

        pan_seg = dic.get("pan_seg", None)
        if pan_seg is None and "pan_seg_file_name" in dic:
            with PathManager.open(dic["pan_seg_file_name"], "rb") as f:
                pan_seg = Image.open(f)
                pan_seg = np.asarray(pan_seg)
                # local import: panopticapi is an optional dependency
                from panopticapi.utils import rgb2id

                pan_seg = rgb2id(pan_seg)
        if pan_seg is not None:
            segments_info = dic["segments_info"]
            pan_seg = torch.tensor(pan_seg)
            self.draw_panoptic_seg(pan_seg, segments_info, area_threshold=0, alpha=0.5)
        return self.output
620
+
621
    def overlay_instances(
        self,
        *,
        boxes=None,
        labels=None,
        masks=None,
        keypoints=None,
        assigned_colors=None,
        alpha=0.5,
    ):
        """
        Args:
            boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`,
                or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image,
                or a :class:`RotatedBoxes`,
                or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format
                for the N objects in a single image,
            labels (list[str]): the text to be displayed for each instance.
            masks (masks-like object): Supported types are:

                * :class:`detectron2.structures.PolygonMasks`,
                  :class:`detectron2.structures.BitMasks`.
                * list[list[ndarray]]: contains the segmentation masks for all objects in one image.
                  The first level of the list corresponds to individual instances. The second
                  level to all the polygon that compose the instance, and the third level
                  to the polygon coordinates. The third level should have the format of
                  [x0, y0, x1, y1, ..., xn, yn] (n >= 3).
                * list[ndarray]: each ndarray is a binary mask of shape (H, W).
                * list[dict]: each dict is a COCO-style RLE.
            keypoints (Keypoint or array like): an array-like object of shape (N, K, 3),
                where the N is the number of instances and K is the number of keypoints.
                The last dimension corresponds to (x, y, visibility or score).
            assigned_colors (list[matplotlib.colors]): a list of colors, where each color
                corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
                for full list of formats that the colors are accepted in.
            alpha (float): blending coefficient for masks; smaller is more transparent.

        Returns:
            output (VisImage): image object with visualizations.
        """
        # Infer the number of instances from whichever inputs are given and assert
        # that all given inputs agree on it.
        num_instances = 0
        if boxes is not None:
            boxes = self._convert_boxes(boxes)
            num_instances = len(boxes)
        if masks is not None:
            masks = self._convert_masks(masks)
            if num_instances:
                assert len(masks) == num_instances
            else:
                num_instances = len(masks)
        if keypoints is not None:
            if num_instances:
                assert len(keypoints) == num_instances
            else:
                num_instances = len(keypoints)
            keypoints = self._convert_keypoints(keypoints)
        if labels is not None:
            assert len(labels) == num_instances
        if assigned_colors is None:
            assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
        if num_instances == 0:
            return self.output
        if boxes is not None and boxes.shape[1] == 5:
            # Nx5 boxes are rotated boxes; delegate to the rotated-box path.
            return self.overlay_rotated_instances(
                boxes=boxes, labels=labels, assigned_colors=assigned_colors
            )

        # Display in largest to smallest order to reduce occlusion.
        areas = None
        if boxes is not None:
            areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
        elif masks is not None:
            areas = np.asarray([x.area() for x in masks])

        if areas is not None:
            sorted_idxs = np.argsort(-areas).tolist()
            # Re-order overlapped instances in descending order.
            boxes = boxes[sorted_idxs] if boxes is not None else None
            labels = [labels[k] for k in sorted_idxs] if labels is not None else None
            masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None
            assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]
            keypoints = keypoints[sorted_idxs] if keypoints is not None else None

        for i in range(num_instances):
            color = assigned_colors[i]
            if boxes is not None:
                self.draw_box(boxes[i], edge_color=color)

            if masks is not None:
                for segment in masks[i].polygons:
                    self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha)

            if labels is not None:
                # first get a box
                if boxes is not None:
                    x0, y0, x1, y1 = boxes[i]
                    text_pos = (x0, y0)  # if drawing boxes, put text on the box corner.
                    horiz_align = "left"
                elif masks is not None:
                    # skip small mask without polygon
                    if len(masks[i].polygons) == 0:
                        continue

                    x0, y0, x1, y1 = masks[i].bbox()

                    # draw text in the center (defined by median) when box is not drawn
                    # median is less sensitive to outliers.
                    text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1]
                    horiz_align = "center"
                else:
                    continue  # drawing the box confidence for keypoints isn't very useful.
                # for small objects, draw text at the side to avoid occlusion
                instance_area = (y1 - y0) * (x1 - x0)
                if (
                    instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale
                    or y1 - y0 < 40 * self.output.scale
                ):
                    if y1 >= self.output.height - 5:
                        text_pos = (x1, y0)
                    else:
                        text_pos = (x0, y1)

                # scale the font with the instance height relative to the image size
                height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width)
                lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
                font_size = (
                    np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
                    * 0.5
                    * self._default_font_size
                )
                self.draw_text(
                    labels[i],
                    text_pos,
                    color=lighter_color,
                    horizontal_alignment=horiz_align,
                    font_size=font_size,
                )

        # draw keypoints
        if keypoints is not None:
            for keypoints_per_instance in keypoints:
                self.draw_and_connect_keypoints(keypoints_per_instance)

        return self.output
762
+
763
+ def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None):
764
+ """
765
+ Args:
766
+ boxes (ndarray): an Nx5 numpy array of
767
+ (x_center, y_center, width, height, angle_degrees) format
768
+ for the N objects in a single image.
769
+ labels (list[str]): the text to be displayed for each instance.
770
+ assigned_colors (list[matplotlib.colors]): a list of colors, where each color
771
+ corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
772
+ for full list of formats that the colors are accepted in.
773
+
774
+ Returns:
775
+ output (VisImage): image object with visualizations.
776
+ """
777
+ num_instances = len(boxes)
778
+
779
+ if assigned_colors is None:
780
+ assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
781
+ if num_instances == 0:
782
+ return self.output
783
+
784
+ # Display in largest to smallest order to reduce occlusion.
785
+ if boxes is not None:
786
+ areas = boxes[:, 2] * boxes[:, 3]
787
+
788
+ sorted_idxs = np.argsort(-areas).tolist()
789
+ # Re-order overlapped instances in descending order.
790
+ boxes = boxes[sorted_idxs]
791
+ labels = [labels[k] for k in sorted_idxs] if labels is not None else None
792
+ colors = [assigned_colors[idx] for idx in sorted_idxs]
793
+
794
+ for i in range(num_instances):
795
+ self.draw_rotated_box_with_label(
796
+ boxes[i], edge_color=colors[i], label=labels[i] if labels is not None else None
797
+ )
798
+
799
+ return self.output
800
+
801
    def draw_and_connect_keypoints(self, keypoints):
        """
        Draws keypoints of an instance and follows the rules for keypoint connections
        to draw lines between appropriate keypoints. This follows color heuristics for
        line color.

        Args:
            keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints
                and the last dimension corresponds to (x, y, probability).

        Returns:
            output (VisImage): image object with visualizations.
        """
        # maps keypoint name -> (x, y) for keypoints above the confidence threshold
        visible = {}
        keypoint_names = self.metadata.get("keypoint_names")
        for idx, keypoint in enumerate(keypoints):

            # draw keypoint
            x, y, prob = keypoint
            if prob > self.keypoint_threshold:
                self.draw_circle((x, y), color=_RED)
                if keypoint_names:
                    keypoint_name = keypoint_names[idx]
                    visible[keypoint_name] = (x, y)

        if self.metadata.get("keypoint_connection_rules"):
            # connect pairs of visible keypoints per the metadata's rules
            for kp0, kp1, color in self.metadata.keypoint_connection_rules:
                if kp0 in visible and kp1 in visible:
                    x0, y0 = visible[kp0]
                    x1, y1 = visible[kp1]
                    color = tuple(x / 255.0 for x in color)
                    self.draw_line([x0, x1], [y0, y1], color=color)

        # draw lines from nose to mid-shoulder and mid-shoulder to mid-hip
        # Note that this strategy is specific to person keypoints.
        # For other keypoints, it should just do nothing
        try:
            ls_x, ls_y = visible["left_shoulder"]
            rs_x, rs_y = visible["right_shoulder"]
            mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2
        except KeyError:
            # shoulders not both visible: skip all torso heuristics
            pass
        else:
            # draw line from nose to mid-shoulder
            nose_x, nose_y = visible.get("nose", (None, None))
            if nose_x is not None:
                self.draw_line([nose_x, mid_shoulder_x], [nose_y, mid_shoulder_y], color=_RED)

            try:
                # draw line from mid-shoulder to mid-hip
                lh_x, lh_y = visible["left_hip"]
                rh_x, rh_y = visible["right_hip"]
            except KeyError:
                pass
            else:
                mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2
                self.draw_line([mid_hip_x, mid_shoulder_x], [mid_hip_y, mid_shoulder_y], color=_RED)
        return self.output
859
+
860
+ """
861
+ Primitive drawing functions:
862
+ """
863
+
864
+ def draw_text(
865
+ self,
866
+ text,
867
+ position,
868
+ *,
869
+ font_size=None,
870
+ color="g",
871
+ horizontal_alignment="center",
872
+ rotation=0,
873
+ ):
874
+ """
875
+ Args:
876
+ text (str): class label
877
+ position (tuple): a tuple of the x and y coordinates to place text on image.
878
+ font_size (int, optional): font of the text. If not provided, a font size
879
+ proportional to the image width is calculated and used.
880
+ color: color of the text. Refer to `matplotlib.colors` for full list
881
+ of formats that are accepted.
882
+ horizontal_alignment (str): see `matplotlib.text.Text`
883
+ rotation: rotation angle in degrees CCW
884
+
885
+ Returns:
886
+ output (VisImage): image object with text drawn.
887
+ """
888
+ if not font_size:
889
+ font_size = self._default_font_size
890
+
891
+ # since the text background is dark, we don't want the text to be dark
892
+ color = np.maximum(list(mplc.to_rgb(color)), 0.2)
893
+ color[np.argmax(color)] = max(0.8, np.max(color))
894
+
895
+ x, y = position
896
+ self.output.ax.text(
897
+ x,
898
+ y,
899
+ text,
900
+ size=font_size * self.output.scale,
901
+ family="sans-serif",
902
+ bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
903
+ verticalalignment="top",
904
+ horizontalalignment=horizontal_alignment,
905
+ color=color,
906
+ zorder=10,
907
+ rotation=rotation,
908
+ )
909
+ return self.output
910
+
911
+ def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"):
912
+ """
913
+ Args:
914
+ box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0
915
+ are the coordinates of the image's top left corner. x1 and y1 are the
916
+ coordinates of the image's bottom right corner.
917
+ alpha (float): blending efficient. Smaller values lead to more transparent masks.
918
+ edge_color: color of the outline of the box. Refer to `matplotlib.colors`
919
+ for full list of formats that are accepted.
920
+ line_style (string): the string to use to create the outline of the boxes.
921
+
922
+ Returns:
923
+ output (VisImage): image object with box drawn.
924
+ """
925
+ x0, y0, x1, y1 = box_coord
926
+ width = x1 - x0
927
+ height = y1 - y0
928
+
929
+ linewidth = max(self._default_font_size / 4, 1)
930
+
931
+ self.output.ax.add_patch(
932
+ mpl.patches.Rectangle(
933
+ (x0, y0),
934
+ width,
935
+ height,
936
+ fill=False,
937
+ edgecolor=edge_color,
938
+ linewidth=linewidth * self.output.scale,
939
+ alpha=alpha,
940
+ linestyle=line_style,
941
+ )
942
+ )
943
+ return self.output
944
+
945
    def draw_rotated_box_with_label(
        self, rotated_box, alpha=0.5, edge_color="g", line_style="-", label=None
    ):
        """
        Draw a rotated box with label on its top-left corner.

        Args:
            rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle),
                where cnt_x and cnt_y are the center coordinates of the box.
                w and h are the width and height of the box. angle represents how
                many degrees the box is rotated CCW with regard to the 0-degree box.
            alpha (float): blending efficient. Smaller values lead to more transparent masks.
            edge_color: color of the outline of the box. Refer to `matplotlib.colors`
                for full list of formats that are accepted.
            line_style (string): the string to use to create the outline of the boxes.
            label (string): label for rotated box. It will not be rendered when set to None.

        Returns:
            output (VisImage): image object with box drawn.
        """
        cnt_x, cnt_y, w, h, angle = rotated_box
        area = w * h
        # use thinner lines when the box is small
        linewidth = self._default_font_size / (
            6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3
        )

        theta = angle * math.pi / 180.0
        c = math.cos(theta)
        s = math.sin(theta)
        # corners of the axis-aligned box, relative to the center
        rect = [(-w / 2, h / 2), (-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2)]
        # x: left->right ; y: top->down
        # rotate each corner around the center by `angle` degrees CCW, then translate
        rotated_rect = [(s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect]
        for k in range(4):
            j = (k + 1) % 4
            self.draw_line(
                [rotated_rect[k][0], rotated_rect[j][0]],
                [rotated_rect[k][1], rotated_rect[j][1]],
                color=edge_color,
                # draw the second edge dashed so the box orientation is visible
                linestyle="--" if k == 1 else line_style,
                linewidth=linewidth,
            )

        if label is not None:
            text_pos = rotated_rect[1]  # topleft corner

            # scale the font with the box height relative to the image size
            height_ratio = h / np.sqrt(self.output.height * self.output.width)
            label_color = self._change_color_brightness(edge_color, brightness_factor=0.7)
            font_size = (
                np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 * self._default_font_size
            )
            self.draw_text(label, text_pos, color=label_color, font_size=font_size, rotation=angle)

        return self.output
999
+
1000
+ def draw_circle(self, circle_coord, color, radius=3):
1001
+ """
1002
+ Args:
1003
+ circle_coord (list(int) or tuple(int)): contains the x and y coordinates
1004
+ of the center of the circle.
1005
+ color: color of the polygon. Refer to `matplotlib.colors` for a full list of
1006
+ formats that are accepted.
1007
+ radius (int): radius of the circle.
1008
+
1009
+ Returns:
1010
+ output (VisImage): image object with box drawn.
1011
+ """
1012
+ x, y = circle_coord
1013
+ self.output.ax.add_patch(
1014
+ mpl.patches.Circle(circle_coord, radius=radius, fill=True, color=color)
1015
+ )
1016
+ return self.output
1017
+
1018
+ def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None):
1019
+ """
1020
+ Args:
1021
+ x_data (list[int]): a list containing x values of all the points being drawn.
1022
+ Length of list should match the length of y_data.
1023
+ y_data (list[int]): a list containing y values of all the points being drawn.
1024
+ Length of list should match the length of x_data.
1025
+ color: color of the line. Refer to `matplotlib.colors` for a full list of
1026
+ formats that are accepted.
1027
+ linestyle: style of the line. Refer to `matplotlib.lines.Line2D`
1028
+ for a full list of formats that are accepted.
1029
+ linewidth (float or None): width of the line. When it's None,
1030
+ a default value will be computed and used.
1031
+
1032
+ Returns:
1033
+ output (VisImage): image object with line drawn.
1034
+ """
1035
+ if linewidth is None:
1036
+ linewidth = self._default_font_size / 3
1037
+ linewidth = max(linewidth, 1)
1038
+ self.output.ax.add_line(
1039
+ mpl.lines.Line2D(
1040
+ x_data,
1041
+ y_data,
1042
+ linewidth=linewidth * self.output.scale,
1043
+ color=color,
1044
+ linestyle=linestyle,
1045
+ )
1046
+ )
1047
+ return self.output
1048
+
1049
    def draw_binary_mask(
        self, binary_mask, color=None, *, edge_color=None, text=None, alpha=0.5, area_threshold=10
    ):
        """
        Args:
            binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and
                W is the image width. Each value in the array is either a 0 or 1 value of uint8
                type.
            color: color of the mask. Refer to `matplotlib.colors` for a full list of
                formats that are accepted. If None, will pick a random color.
            edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
                full list of formats that are accepted.
            text (str): if None, will be drawn on the object
            alpha (float): blending efficient. Smaller values lead to more transparent masks.
            area_threshold (float): a connected component smaller than this area will not be shown.

        Returns:
            output (VisImage): image object with mask drawn.
        """
        if color is None:
            color = random_color(rgb=True, maximum=1)
        color = mplc.to_rgb(color)

        has_valid_segment = False
        binary_mask = binary_mask.astype("uint8")  # opencv needs uint8
        mask = GenericMask(binary_mask, self.output.height, self.output.width)
        shape2d = (binary_mask.shape[0], binary_mask.shape[1])

        if not mask.has_holes:
            # draw polygons for regular masks
            for segment in mask.polygons:
                # skip connected components whose area is below the threshold
                area = mask_util.area(mask_util.frPyObjects([segment], shape2d[0], shape2d[1]))
                if area < (area_threshold or 0):
                    continue
                has_valid_segment = True
                segment = segment.reshape(-1, 2)
                self.draw_polygon(segment, color=color, edge_color=edge_color, alpha=alpha)
        else:
            # TODO: Use Path/PathPatch to draw vector graphics:
            # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon
            # masks with holes can't be drawn as simple polygons; paint an RGBA overlay
            rgba = np.zeros(shape2d + (4,), dtype="float32")
            rgba[:, :, :3] = color
            rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha
            has_valid_segment = True
            self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))

        if text is not None and has_valid_segment:
            lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
            self._draw_text_in_mask(binary_mask, text, lighter_color)
        return self.output
1099
+
1100
+ def draw_soft_mask(self, soft_mask, color=None, *, text=None, alpha=0.5):
1101
+ """
1102
+ Args:
1103
+ soft_mask (ndarray): float array of shape (H, W), each value in [0, 1].
1104
+ color: color of the mask. Refer to `matplotlib.colors` for a full list of
1105
+ formats that are accepted. If None, will pick a random color.
1106
+ text (str): if None, will be drawn on the object
1107
+ alpha (float): blending efficient. Smaller values lead to more transparent masks.
1108
+
1109
+ Returns:
1110
+ output (VisImage): image object with mask drawn.
1111
+ """
1112
+ if color is None:
1113
+ color = random_color(rgb=True, maximum=1)
1114
+ color = mplc.to_rgb(color)
1115
+
1116
+ shape2d = (soft_mask.shape[0], soft_mask.shape[1])
1117
+ rgba = np.zeros(shape2d + (4,), dtype="float32")
1118
+ rgba[:, :, :3] = color
1119
+ rgba[:, :, 3] = soft_mask * alpha
1120
+ self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))
1121
+
1122
+ if text is not None:
1123
+ lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
1124
+ binary_mask = (soft_mask > 0.5).astype("uint8")
1125
+ self._draw_text_in_mask(binary_mask, text, lighter_color)
1126
+ return self.output
1127
+
1128
+ def draw_polygon(self, segment, color, edge_color=None, alpha=0.5):
1129
+ """
1130
+ Args:
1131
+ segment: numpy array of shape Nx2, containing all the points in the polygon.
1132
+ color: color of the polygon. Refer to `matplotlib.colors` for a full list of
1133
+ formats that are accepted.
1134
+ edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
1135
+ full list of formats that are accepted. If not provided, a darker shade
1136
+ of the polygon color will be used instead.
1137
+ alpha (float): blending efficient. Smaller values lead to more transparent masks.
1138
+
1139
+ Returns:
1140
+ output (VisImage): image object with polygon drawn.
1141
+ """
1142
+ if edge_color is None:
1143
+ # make edge color darker than the polygon color
1144
+ if alpha > 0.8:
1145
+ edge_color = self._change_color_brightness(color, brightness_factor=-0.7)
1146
+ else:
1147
+ edge_color = color
1148
+ edge_color = mplc.to_rgb(edge_color) + (1,)
1149
+
1150
+ polygon = mpl.patches.Polygon(
1151
+ segment,
1152
+ fill=True,
1153
+ facecolor=mplc.to_rgb(color) + (alpha,),
1154
+ edgecolor=edge_color,
1155
+ linewidth=max(self._default_font_size // 15 * self.output.scale, 1),
1156
+ )
1157
+ self.output.ax.add_patch(polygon)
1158
+ return self.output
1159
+
1160
+ """
1161
+ Internal methods:
1162
+ """
1163
+
1164
+ def _jitter(self, color):
1165
+ """
1166
+ Randomly modifies given color to produce a slightly different color than the color given.
1167
+
1168
+ Args:
1169
+ color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color
1170
+ picked. The values in the list are in the [0.0, 1.0] range.
1171
+
1172
+ Returns:
1173
+ jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the
1174
+ color after being jittered. The values in the list are in the [0.0, 1.0] range.
1175
+ """
1176
+ color = mplc.to_rgb(color)
1177
+ vec = np.random.rand(3)
1178
+ # better to do it in another color space
1179
+ vec = vec / np.linalg.norm(vec) * 0.5
1180
+ res = np.clip(vec + color, 0, 1)
1181
+ return tuple(res)
1182
+
1183
+ def _create_grayscale_image(self, mask=None):
1184
+ """
1185
+ Create a grayscale version of the original image.
1186
+ The colors in masked area, if given, will be kept.
1187
+ """
1188
+ img_bw = self.img.astype("f4").mean(axis=2)
1189
+ img_bw = np.stack([img_bw] * 3, axis=2)
1190
+ if mask is not None:
1191
+ img_bw[mask] = self.img[mask]
1192
+ return img_bw
1193
+
1194
+ def _change_color_brightness(self, color, brightness_factor):
1195
+ """
1196
+ Depending on the brightness_factor, gives a lighter or darker color i.e. a color with
1197
+ less or more saturation than the original color.
1198
+
1199
+ Args:
1200
+ color: color of the polygon. Refer to `matplotlib.colors` for a full list of
1201
+ formats that are accepted.
1202
+ brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
1203
+ 0 will correspond to no change, a factor in [-1.0, 0) range will result in
1204
+ a darker color and a factor in (0, 1.0] range will result in a lighter color.
1205
+
1206
+ Returns:
1207
+ modified_color (tuple[double]): a tuple containing the RGB values of the
1208
+ modified color. Each value in the tuple is in the [0.0, 1.0] range.
1209
+ """
1210
+ assert brightness_factor >= -1.0 and brightness_factor <= 1.0
1211
+ color = mplc.to_rgb(color)
1212
+ polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
1213
+ modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
1214
+ modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
1215
+ modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
1216
+ modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2])
1217
+ return tuple(np.clip(modified_color, 0.0, 1.0))
1218
+
1219
+ def _convert_boxes(self, boxes):
1220
+ """
1221
+ Convert different format of boxes to an NxB array, where B = 4 or 5 is the box dimension.
1222
+ """
1223
+ if isinstance(boxes, Boxes) or isinstance(boxes, RotatedBoxes):
1224
+ return boxes.tensor.detach().numpy()
1225
+ else:
1226
+ return np.asarray(boxes)
1227
+
1228
+ def _convert_masks(self, masks_or_polygons):
1229
+ """
1230
+ Convert different format of masks or polygons to a tuple of masks and polygons.
1231
+
1232
+ Returns:
1233
+ list[GenericMask]:
1234
+ """
1235
+
1236
+ m = masks_or_polygons
1237
+ if isinstance(m, PolygonMasks):
1238
+ m = m.polygons
1239
+ if isinstance(m, BitMasks):
1240
+ m = m.tensor.numpy()
1241
+ if isinstance(m, torch.Tensor):
1242
+ m = m.numpy()
1243
+ ret = []
1244
+ for x in m:
1245
+ if isinstance(x, GenericMask):
1246
+ ret.append(x)
1247
+ else:
1248
+ ret.append(GenericMask(x, self.output.height, self.output.width))
1249
+ return ret
1250
+
1251
    def _draw_text_in_mask(self, binary_mask, text, color):
        """
        Find proper places to draw text given a binary mask.
        """
        # TODO sometimes drawn on wrong objects. the heuristics here can improve.
        _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8)
        if stats[1:, -1].size == 0:
            # no foreground components at all: nothing to label
            return
        # +1 because stats row 0 is the background component
        largest_component_id = np.argmax(stats[1:, -1]) + 1

        # draw text on the largest component, as well as other very large components.
        for cid in range(1, _num_cc):
            if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:
                # median is more stable than centroid
                # center = centroids[largest_component_id]
                center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]
                self.draw_text(text, center, color=color)
1268
+
1269
+ def _convert_keypoints(self, keypoints):
1270
+ if isinstance(keypoints, Keypoints):
1271
+ keypoints = keypoints.tensor
1272
+ keypoints = np.asarray(keypoints)
1273
+ return keypoints
1274
+
1275
+ def get_output(self):
1276
+ """
1277
+ Returns:
1278
+ output (VisImage): the image output containing the visualizations added
1279
+ to the image.
1280
+ """
1281
+ return self.output