Spaces:
Paused
Paused
Upload 30 files
Browse files- detectron2/utils/README.md +5 -0
- detectron2/utils/__init__.py +1 -0
- detectron2/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- detectron2/utils/__pycache__/collect_env.cpython-311.pyc +0 -0
- detectron2/utils/__pycache__/comm.cpython-311.pyc +0 -0
- detectron2/utils/__pycache__/develop.cpython-311.pyc +0 -0
- detectron2/utils/__pycache__/env.cpython-311.pyc +0 -0
- detectron2/utils/__pycache__/events.cpython-311.pyc +0 -0
- detectron2/utils/__pycache__/file_io.cpython-311.pyc +0 -0
- detectron2/utils/__pycache__/logger.cpython-311.pyc +0 -0
- detectron2/utils/__pycache__/memory.cpython-311.pyc +0 -0
- detectron2/utils/__pycache__/registry.cpython-311.pyc +0 -0
- detectron2/utils/__pycache__/serialize.cpython-311.pyc +0 -0
- detectron2/utils/__pycache__/tracing.cpython-311.pyc +0 -0
- detectron2/utils/analysis.py +188 -0
- detectron2/utils/collect_env.py +263 -0
- detectron2/utils/colormap.py +158 -0
- detectron2/utils/comm.py +238 -0
- detectron2/utils/develop.py +59 -0
- detectron2/utils/env.py +171 -0
- detectron2/utils/events.py +557 -0
- detectron2/utils/file_io.py +39 -0
- detectron2/utils/logger.py +261 -0
- detectron2/utils/memory.py +84 -0
- detectron2/utils/registry.py +60 -0
- detectron2/utils/serialize.py +32 -0
- detectron2/utils/testing.py +478 -0
- detectron2/utils/tracing.py +73 -0
- detectron2/utils/video_visualizer.py +287 -0
- detectron2/utils/visualizer.py +1281 -0
detectron2/utils/README.md
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Utility functions
|
| 2 |
+
|
| 3 |
+
This folder contain utility functions that are not used in the
|
| 4 |
+
core library, but are useful for building models or training
|
| 5 |
+
code using the config system.
|
detectron2/utils/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
detectron2/utils/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (171 Bytes). View file
|
|
|
detectron2/utils/__pycache__/collect_env.cpython-311.pyc
ADDED
|
Binary file (14.7 kB). View file
|
|
|
detectron2/utils/__pycache__/comm.cpython-311.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|
detectron2/utils/__pycache__/develop.cpython-311.pyc
ADDED
|
Binary file (3.01 kB). View file
|
|
|
detectron2/utils/__pycache__/env.cpython-311.pyc
ADDED
|
Binary file (8.3 kB). View file
|
|
|
detectron2/utils/__pycache__/events.cpython-311.pyc
ADDED
|
Binary file (26.9 kB). View file
|
|
|
detectron2/utils/__pycache__/file_io.cpython-311.pyc
ADDED
|
Binary file (2.11 kB). View file
|
|
|
detectron2/utils/__pycache__/logger.cpython-311.pyc
ADDED
|
Binary file (12 kB). View file
|
|
|
detectron2/utils/__pycache__/memory.cpython-311.pyc
ADDED
|
Binary file (4.38 kB). View file
|
|
|
detectron2/utils/__pycache__/registry.cpython-311.pyc
ADDED
|
Binary file (2.06 kB). View file
|
|
|
detectron2/utils/__pycache__/serialize.cpython-311.pyc
ADDED
|
Binary file (1.92 kB). View file
|
|
|
detectron2/utils/__pycache__/tracing.cpython-311.pyc
ADDED
|
Binary file (3.72 kB). View file
|
|
|
detectron2/utils/analysis.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import typing
|
| 5 |
+
from typing import Any, List
|
| 6 |
+
import fvcore
|
| 7 |
+
from fvcore.nn import activation_count, flop_count, parameter_count, parameter_count_table
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
from detectron2.export import TracingAdapter
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"activation_count_operators",
|
| 14 |
+
"flop_count_operators",
|
| 15 |
+
"parameter_count_table",
|
| 16 |
+
"parameter_count",
|
| 17 |
+
"FlopCountAnalysis",
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
FLOPS_MODE = "flops"
|
| 21 |
+
ACTIVATIONS_MODE = "activations"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Some extra ops to ignore from counting, including elementwise and reduction ops
|
| 25 |
+
_IGNORED_OPS = {
|
| 26 |
+
"aten::add",
|
| 27 |
+
"aten::add_",
|
| 28 |
+
"aten::argmax",
|
| 29 |
+
"aten::argsort",
|
| 30 |
+
"aten::batch_norm",
|
| 31 |
+
"aten::constant_pad_nd",
|
| 32 |
+
"aten::div",
|
| 33 |
+
"aten::div_",
|
| 34 |
+
"aten::exp",
|
| 35 |
+
"aten::log2",
|
| 36 |
+
"aten::max_pool2d",
|
| 37 |
+
"aten::meshgrid",
|
| 38 |
+
"aten::mul",
|
| 39 |
+
"aten::mul_",
|
| 40 |
+
"aten::neg",
|
| 41 |
+
"aten::nonzero_numpy",
|
| 42 |
+
"aten::reciprocal",
|
| 43 |
+
"aten::repeat_interleave",
|
| 44 |
+
"aten::rsub",
|
| 45 |
+
"aten::sigmoid",
|
| 46 |
+
"aten::sigmoid_",
|
| 47 |
+
"aten::softmax",
|
| 48 |
+
"aten::sort",
|
| 49 |
+
"aten::sqrt",
|
| 50 |
+
"aten::sub",
|
| 51 |
+
"torchvision::nms", # TODO estimate flop for nms
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class FlopCountAnalysis(fvcore.nn.FlopCountAnalysis):
|
| 56 |
+
"""
|
| 57 |
+
Same as :class:`fvcore.nn.FlopCountAnalysis`, but supports detectron2 models.
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
def __init__(self, model, inputs):
|
| 61 |
+
"""
|
| 62 |
+
Args:
|
| 63 |
+
model (nn.Module):
|
| 64 |
+
inputs (Any): inputs of the given model. Does not have to be tuple of tensors.
|
| 65 |
+
"""
|
| 66 |
+
wrapper = TracingAdapter(model, inputs, allow_non_tensor=True)
|
| 67 |
+
super().__init__(wrapper, wrapper.flattened_inputs)
|
| 68 |
+
self.set_op_handle(**{k: None for k in _IGNORED_OPS})
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def flop_count_operators(model: nn.Module, inputs: list) -> typing.DefaultDict[str, float]:
|
| 72 |
+
"""
|
| 73 |
+
Implement operator-level flops counting using jit.
|
| 74 |
+
This is a wrapper of :func:`fvcore.nn.flop_count` and adds supports for standard
|
| 75 |
+
detection models in detectron2.
|
| 76 |
+
Please use :class:`FlopCountAnalysis` for more advanced functionalities.
|
| 77 |
+
|
| 78 |
+
Note:
|
| 79 |
+
The function runs the input through the model to compute flops.
|
| 80 |
+
The flops of a detection model is often input-dependent, for example,
|
| 81 |
+
the flops of box & mask head depends on the number of proposals &
|
| 82 |
+
the number of detected objects.
|
| 83 |
+
Therefore, the flops counting using a single input may not accurately
|
| 84 |
+
reflect the computation cost of a model. It's recommended to average
|
| 85 |
+
across a number of inputs.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
model: a detectron2 model that takes `list[dict]` as input.
|
| 89 |
+
inputs (list[dict]): inputs to model, in detectron2's standard format.
|
| 90 |
+
Only "image" key will be used.
|
| 91 |
+
supported_ops (dict[str, Handle]): see documentation of :func:`fvcore.nn.flop_count`
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
Counter: Gflop count per operator
|
| 95 |
+
"""
|
| 96 |
+
old_train = model.training
|
| 97 |
+
model.eval()
|
| 98 |
+
ret = FlopCountAnalysis(model, inputs).by_operator()
|
| 99 |
+
model.train(old_train)
|
| 100 |
+
return {k: v / 1e9 for k, v in ret.items()}
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def activation_count_operators(
|
| 104 |
+
model: nn.Module, inputs: list, **kwargs
|
| 105 |
+
) -> typing.DefaultDict[str, float]:
|
| 106 |
+
"""
|
| 107 |
+
Implement operator-level activations counting using jit.
|
| 108 |
+
This is a wrapper of fvcore.nn.activation_count, that supports standard detection models
|
| 109 |
+
in detectron2.
|
| 110 |
+
|
| 111 |
+
Note:
|
| 112 |
+
The function runs the input through the model to compute activations.
|
| 113 |
+
The activations of a detection model is often input-dependent, for example,
|
| 114 |
+
the activations of box & mask head depends on the number of proposals &
|
| 115 |
+
the number of detected objects.
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
model: a detectron2 model that takes `list[dict]` as input.
|
| 119 |
+
inputs (list[dict]): inputs to model, in detectron2's standard format.
|
| 120 |
+
Only "image" key will be used.
|
| 121 |
+
|
| 122 |
+
Returns:
|
| 123 |
+
Counter: activation count per operator
|
| 124 |
+
"""
|
| 125 |
+
return _wrapper_count_operators(model=model, inputs=inputs, mode=ACTIVATIONS_MODE, **kwargs)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _wrapper_count_operators(
|
| 129 |
+
model: nn.Module, inputs: list, mode: str, **kwargs
|
| 130 |
+
) -> typing.DefaultDict[str, float]:
|
| 131 |
+
# ignore some ops
|
| 132 |
+
supported_ops = {k: lambda *args, **kwargs: {} for k in _IGNORED_OPS}
|
| 133 |
+
supported_ops.update(kwargs.pop("supported_ops", {}))
|
| 134 |
+
kwargs["supported_ops"] = supported_ops
|
| 135 |
+
|
| 136 |
+
assert len(inputs) == 1, "Please use batch size=1"
|
| 137 |
+
tensor_input = inputs[0]["image"]
|
| 138 |
+
inputs = [{"image": tensor_input}] # remove other keys, in case there are any
|
| 139 |
+
|
| 140 |
+
old_train = model.training
|
| 141 |
+
if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)):
|
| 142 |
+
model = model.module
|
| 143 |
+
wrapper = TracingAdapter(model, inputs)
|
| 144 |
+
wrapper.eval()
|
| 145 |
+
if mode == FLOPS_MODE:
|
| 146 |
+
ret = flop_count(wrapper, (tensor_input,), **kwargs)
|
| 147 |
+
elif mode == ACTIVATIONS_MODE:
|
| 148 |
+
ret = activation_count(wrapper, (tensor_input,), **kwargs)
|
| 149 |
+
else:
|
| 150 |
+
raise NotImplementedError("Count for mode {} is not supported yet.".format(mode))
|
| 151 |
+
# compatible with change in fvcore
|
| 152 |
+
if isinstance(ret, tuple):
|
| 153 |
+
ret = ret[0]
|
| 154 |
+
model.train(old_train)
|
| 155 |
+
return ret
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def find_unused_parameters(model: nn.Module, inputs: Any) -> List[str]:
|
| 159 |
+
"""
|
| 160 |
+
Given a model, find parameters that do not contribute
|
| 161 |
+
to the loss.
|
| 162 |
+
|
| 163 |
+
Args:
|
| 164 |
+
model: a model in training mode that returns losses
|
| 165 |
+
inputs: argument or a tuple of arguments. Inputs of the model
|
| 166 |
+
|
| 167 |
+
Returns:
|
| 168 |
+
list[str]: the name of unused parameters
|
| 169 |
+
"""
|
| 170 |
+
assert model.training
|
| 171 |
+
for _, prm in model.named_parameters():
|
| 172 |
+
prm.grad = None
|
| 173 |
+
|
| 174 |
+
if isinstance(inputs, tuple):
|
| 175 |
+
losses = model(*inputs)
|
| 176 |
+
else:
|
| 177 |
+
losses = model(inputs)
|
| 178 |
+
|
| 179 |
+
if isinstance(losses, dict):
|
| 180 |
+
losses = sum(losses.values())
|
| 181 |
+
losses.backward()
|
| 182 |
+
|
| 183 |
+
unused: List[str] = []
|
| 184 |
+
for name, prm in model.named_parameters():
|
| 185 |
+
if prm.grad is None:
|
| 186 |
+
unused.append(name)
|
| 187 |
+
prm.grad = None
|
| 188 |
+
return unused
|
detectron2/utils/collect_env.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
import importlib
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
import subprocess
|
| 7 |
+
import sys
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
import PIL
|
| 10 |
+
import torch
|
| 11 |
+
import torchvision
|
| 12 |
+
from tabulate import tabulate
|
| 13 |
+
|
| 14 |
+
__all__ = ["collect_env_info"]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def collect_torch_env():
|
| 18 |
+
try:
|
| 19 |
+
import torch.__config__
|
| 20 |
+
|
| 21 |
+
return torch.__config__.show()
|
| 22 |
+
except ImportError:
|
| 23 |
+
# compatible with older versions of pytorch
|
| 24 |
+
from torch.utils.collect_env import get_pretty_env_info
|
| 25 |
+
|
| 26 |
+
return get_pretty_env_info()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_env_module():
|
| 30 |
+
var_name = "DETECTRON2_ENV_MODULE"
|
| 31 |
+
return var_name, os.environ.get(var_name, "<not set>")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def detect_compute_compatibility(CUDA_HOME, so_file):
|
| 35 |
+
try:
|
| 36 |
+
cuobjdump = os.path.join(CUDA_HOME, "bin", "cuobjdump")
|
| 37 |
+
if os.path.isfile(cuobjdump):
|
| 38 |
+
output = subprocess.check_output(
|
| 39 |
+
"'{}' --list-elf '{}'".format(cuobjdump, so_file), shell=True
|
| 40 |
+
)
|
| 41 |
+
output = output.decode("utf-8").strip().split("\n")
|
| 42 |
+
arch = []
|
| 43 |
+
for line in output:
|
| 44 |
+
line = re.findall(r"\.sm_([0-9]*)\.", line)[0]
|
| 45 |
+
arch.append(".".join(line))
|
| 46 |
+
arch = sorted(set(arch))
|
| 47 |
+
return ", ".join(arch)
|
| 48 |
+
else:
|
| 49 |
+
return so_file + "; cannot find cuobjdump"
|
| 50 |
+
except Exception:
|
| 51 |
+
# unhandled failure
|
| 52 |
+
return so_file
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def collect_env_info():
|
| 56 |
+
has_gpu = torch.cuda.is_available() # true for both CUDA & ROCM
|
| 57 |
+
torch_version = torch.__version__
|
| 58 |
+
|
| 59 |
+
# NOTE that CUDA_HOME/ROCM_HOME could be None even when CUDA runtime libs are functional
|
| 60 |
+
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
|
| 61 |
+
|
| 62 |
+
has_rocm = False
|
| 63 |
+
if (getattr(torch.version, "hip", None) is not None) and (ROCM_HOME is not None):
|
| 64 |
+
has_rocm = True
|
| 65 |
+
has_cuda = has_gpu and (not has_rocm)
|
| 66 |
+
|
| 67 |
+
data = []
|
| 68 |
+
data.append(("sys.platform", sys.platform)) # check-template.yml depends on it
|
| 69 |
+
data.append(("Python", sys.version.replace("\n", "")))
|
| 70 |
+
data.append(("numpy", np.__version__))
|
| 71 |
+
|
| 72 |
+
try:
|
| 73 |
+
import detectron2 # noqa
|
| 74 |
+
|
| 75 |
+
data.append(
|
| 76 |
+
(
|
| 77 |
+
"detectron2",
|
| 78 |
+
detectron2.__version__ + " @" + os.path.dirname(detectron2.__file__),
|
| 79 |
+
)
|
| 80 |
+
)
|
| 81 |
+
except ImportError:
|
| 82 |
+
data.append(("detectron2", "failed to import"))
|
| 83 |
+
except AttributeError:
|
| 84 |
+
data.append(("detectron2", "imported a wrong installation"))
|
| 85 |
+
|
| 86 |
+
try:
|
| 87 |
+
import detectron2._C as _C
|
| 88 |
+
except ImportError as e:
|
| 89 |
+
data.append(("detectron2._C", f"not built correctly: {e}"))
|
| 90 |
+
|
| 91 |
+
# print system compilers when extension fails to build
|
| 92 |
+
if sys.platform != "win32": # don't know what to do for windows
|
| 93 |
+
try:
|
| 94 |
+
# this is how torch/utils/cpp_extensions.py choose compiler
|
| 95 |
+
cxx = os.environ.get("CXX", "c++")
|
| 96 |
+
cxx = subprocess.check_output("'{}' --version".format(cxx), shell=True)
|
| 97 |
+
cxx = cxx.decode("utf-8").strip().split("\n")[0]
|
| 98 |
+
except subprocess.SubprocessError:
|
| 99 |
+
cxx = "Not found"
|
| 100 |
+
data.append(("Compiler ($CXX)", cxx))
|
| 101 |
+
|
| 102 |
+
if has_cuda and CUDA_HOME is not None:
|
| 103 |
+
try:
|
| 104 |
+
nvcc = os.path.join(CUDA_HOME, "bin", "nvcc")
|
| 105 |
+
nvcc = subprocess.check_output("'{}' -V".format(nvcc), shell=True)
|
| 106 |
+
nvcc = nvcc.decode("utf-8").strip().split("\n")[-1]
|
| 107 |
+
except subprocess.SubprocessError:
|
| 108 |
+
nvcc = "Not found"
|
| 109 |
+
data.append(("CUDA compiler", nvcc))
|
| 110 |
+
if has_cuda and sys.platform != "win32":
|
| 111 |
+
try:
|
| 112 |
+
so_file = importlib.util.find_spec("detectron2._C").origin
|
| 113 |
+
except (ImportError, AttributeError):
|
| 114 |
+
pass
|
| 115 |
+
else:
|
| 116 |
+
data.append(
|
| 117 |
+
(
|
| 118 |
+
"detectron2 arch flags",
|
| 119 |
+
detect_compute_compatibility(CUDA_HOME, so_file),
|
| 120 |
+
)
|
| 121 |
+
)
|
| 122 |
+
else:
|
| 123 |
+
# print compilers that are used to build extension
|
| 124 |
+
data.append(("Compiler", _C.get_compiler_version()))
|
| 125 |
+
data.append(("CUDA compiler", _C.get_cuda_version())) # cuda or hip
|
| 126 |
+
if has_cuda and getattr(_C, "has_cuda", lambda: True)():
|
| 127 |
+
data.append(
|
| 128 |
+
(
|
| 129 |
+
"detectron2 arch flags",
|
| 130 |
+
detect_compute_compatibility(CUDA_HOME, _C.__file__),
|
| 131 |
+
)
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
data.append(get_env_module())
|
| 135 |
+
data.append(("PyTorch", torch_version + " @" + os.path.dirname(torch.__file__)))
|
| 136 |
+
data.append(("PyTorch debug build", torch.version.debug))
|
| 137 |
+
try:
|
| 138 |
+
data.append(("torch._C._GLIBCXX_USE_CXX11_ABI", torch._C._GLIBCXX_USE_CXX11_ABI))
|
| 139 |
+
except Exception:
|
| 140 |
+
pass
|
| 141 |
+
|
| 142 |
+
if not has_gpu:
|
| 143 |
+
has_gpu_text = "No: torch.cuda.is_available() == False"
|
| 144 |
+
else:
|
| 145 |
+
has_gpu_text = "Yes"
|
| 146 |
+
data.append(("GPU available", has_gpu_text))
|
| 147 |
+
if has_gpu:
|
| 148 |
+
devices = defaultdict(list)
|
| 149 |
+
for k in range(torch.cuda.device_count()):
|
| 150 |
+
cap = ".".join((str(x) for x in torch.cuda.get_device_capability(k)))
|
| 151 |
+
name = torch.cuda.get_device_name(k) + f" (arch={cap})"
|
| 152 |
+
devices[name].append(str(k))
|
| 153 |
+
for name, devids in devices.items():
|
| 154 |
+
data.append(("GPU " + ",".join(devids), name))
|
| 155 |
+
|
| 156 |
+
if has_rocm:
|
| 157 |
+
msg = " - invalid!" if not (ROCM_HOME and os.path.isdir(ROCM_HOME)) else ""
|
| 158 |
+
data.append(("ROCM_HOME", str(ROCM_HOME) + msg))
|
| 159 |
+
else:
|
| 160 |
+
try:
|
| 161 |
+
from torch.utils.collect_env import (
|
| 162 |
+
get_nvidia_driver_version,
|
| 163 |
+
run as _run,
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
data.append(("Driver version", get_nvidia_driver_version(_run)))
|
| 167 |
+
except Exception:
|
| 168 |
+
pass
|
| 169 |
+
msg = " - invalid!" if not (CUDA_HOME and os.path.isdir(CUDA_HOME)) else ""
|
| 170 |
+
data.append(("CUDA_HOME", str(CUDA_HOME) + msg))
|
| 171 |
+
|
| 172 |
+
cuda_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
|
| 173 |
+
if cuda_arch_list:
|
| 174 |
+
data.append(("TORCH_CUDA_ARCH_LIST", cuda_arch_list))
|
| 175 |
+
data.append(("Pillow", PIL.__version__))
|
| 176 |
+
|
| 177 |
+
try:
|
| 178 |
+
data.append(
|
| 179 |
+
(
|
| 180 |
+
"torchvision",
|
| 181 |
+
str(torchvision.__version__) + " @" + os.path.dirname(torchvision.__file__),
|
| 182 |
+
)
|
| 183 |
+
)
|
| 184 |
+
if has_cuda:
|
| 185 |
+
try:
|
| 186 |
+
torchvision_C = importlib.util.find_spec("torchvision._C").origin
|
| 187 |
+
msg = detect_compute_compatibility(CUDA_HOME, torchvision_C)
|
| 188 |
+
data.append(("torchvision arch flags", msg))
|
| 189 |
+
except (ImportError, AttributeError):
|
| 190 |
+
data.append(("torchvision._C", "Not found"))
|
| 191 |
+
except AttributeError:
|
| 192 |
+
data.append(("torchvision", "unknown"))
|
| 193 |
+
|
| 194 |
+
try:
|
| 195 |
+
import fvcore
|
| 196 |
+
|
| 197 |
+
data.append(("fvcore", fvcore.__version__))
|
| 198 |
+
except (ImportError, AttributeError):
|
| 199 |
+
pass
|
| 200 |
+
|
| 201 |
+
try:
|
| 202 |
+
import iopath
|
| 203 |
+
|
| 204 |
+
data.append(("iopath", iopath.__version__))
|
| 205 |
+
except (ImportError, AttributeError):
|
| 206 |
+
pass
|
| 207 |
+
|
| 208 |
+
try:
|
| 209 |
+
import cv2
|
| 210 |
+
|
| 211 |
+
data.append(("cv2", cv2.__version__))
|
| 212 |
+
except (ImportError, AttributeError):
|
| 213 |
+
data.append(("cv2", "Not found"))
|
| 214 |
+
env_str = tabulate(data) + "\n"
|
| 215 |
+
env_str += collect_torch_env()
|
| 216 |
+
return env_str
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def test_nccl_ops():
|
| 220 |
+
num_gpu = torch.cuda.device_count()
|
| 221 |
+
if os.access("/tmp", os.W_OK):
|
| 222 |
+
import torch.multiprocessing as mp
|
| 223 |
+
|
| 224 |
+
dist_url = "file:///tmp/nccl_tmp_file"
|
| 225 |
+
print("Testing NCCL connectivity ... this should not hang.")
|
| 226 |
+
mp.spawn(_test_nccl_worker, nprocs=num_gpu, args=(num_gpu, dist_url), daemon=False)
|
| 227 |
+
print("NCCL succeeded.")
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def _test_nccl_worker(rank, num_gpu, dist_url):
|
| 231 |
+
import torch.distributed as dist
|
| 232 |
+
|
| 233 |
+
dist.init_process_group(backend="NCCL", init_method=dist_url, rank=rank, world_size=num_gpu)
|
| 234 |
+
dist.barrier(device_ids=[rank])
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def main() -> None:
|
| 238 |
+
global x
|
| 239 |
+
try:
|
| 240 |
+
from detectron2.utils.collect_env import collect_env_info as f
|
| 241 |
+
|
| 242 |
+
print(f())
|
| 243 |
+
except ImportError:
|
| 244 |
+
print(collect_env_info())
|
| 245 |
+
|
| 246 |
+
if torch.cuda.is_available():
|
| 247 |
+
num_gpu = torch.cuda.device_count()
|
| 248 |
+
for k in range(num_gpu):
|
| 249 |
+
device = f"cuda:{k}"
|
| 250 |
+
try:
|
| 251 |
+
x = torch.tensor([1, 2.0], dtype=torch.float32)
|
| 252 |
+
x = x.to(device)
|
| 253 |
+
except Exception as e:
|
| 254 |
+
print(
|
| 255 |
+
f"Unable to copy tensor to device={device}: {e}. "
|
| 256 |
+
"Your CUDA environment is broken."
|
| 257 |
+
)
|
| 258 |
+
if num_gpu > 1:
|
| 259 |
+
test_nccl_ops()
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
if __name__ == "__main__":
|
| 263 |
+
main() # pragma: no cover
|
detectron2/utils/colormap.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
An awesome colormap for really neat visualizations.
|
| 5 |
+
Copied from Detectron, and removed gray colors.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import random
|
| 10 |
+
|
| 11 |
+
__all__ = ["colormap", "random_color", "random_colors"]
|
| 12 |
+
|
| 13 |
+
# fmt: off
|
| 14 |
+
# RGB:
|
| 15 |
+
_COLORS = np.array(
|
| 16 |
+
[
|
| 17 |
+
0.000, 0.447, 0.741,
|
| 18 |
+
0.850, 0.325, 0.098,
|
| 19 |
+
0.929, 0.694, 0.125,
|
| 20 |
+
0.494, 0.184, 0.556,
|
| 21 |
+
0.466, 0.674, 0.188,
|
| 22 |
+
0.301, 0.745, 0.933,
|
| 23 |
+
0.635, 0.078, 0.184,
|
| 24 |
+
0.300, 0.300, 0.300,
|
| 25 |
+
0.600, 0.600, 0.600,
|
| 26 |
+
1.000, 0.000, 0.000,
|
| 27 |
+
1.000, 0.500, 0.000,
|
| 28 |
+
0.749, 0.749, 0.000,
|
| 29 |
+
0.000, 1.000, 0.000,
|
| 30 |
+
0.000, 0.000, 1.000,
|
| 31 |
+
0.667, 0.000, 1.000,
|
| 32 |
+
0.333, 0.333, 0.000,
|
| 33 |
+
0.333, 0.667, 0.000,
|
| 34 |
+
0.333, 1.000, 0.000,
|
| 35 |
+
0.667, 0.333, 0.000,
|
| 36 |
+
0.667, 0.667, 0.000,
|
| 37 |
+
0.667, 1.000, 0.000,
|
| 38 |
+
1.000, 0.333, 0.000,
|
| 39 |
+
1.000, 0.667, 0.000,
|
| 40 |
+
1.000, 1.000, 0.000,
|
| 41 |
+
0.000, 0.333, 0.500,
|
| 42 |
+
0.000, 0.667, 0.500,
|
| 43 |
+
0.000, 1.000, 0.500,
|
| 44 |
+
0.333, 0.000, 0.500,
|
| 45 |
+
0.333, 0.333, 0.500,
|
| 46 |
+
0.333, 0.667, 0.500,
|
| 47 |
+
0.333, 1.000, 0.500,
|
| 48 |
+
0.667, 0.000, 0.500,
|
| 49 |
+
0.667, 0.333, 0.500,
|
| 50 |
+
0.667, 0.667, 0.500,
|
| 51 |
+
0.667, 1.000, 0.500,
|
| 52 |
+
1.000, 0.000, 0.500,
|
| 53 |
+
1.000, 0.333, 0.500,
|
| 54 |
+
1.000, 0.667, 0.500,
|
| 55 |
+
1.000, 1.000, 0.500,
|
| 56 |
+
0.000, 0.333, 1.000,
|
| 57 |
+
0.000, 0.667, 1.000,
|
| 58 |
+
0.000, 1.000, 1.000,
|
| 59 |
+
0.333, 0.000, 1.000,
|
| 60 |
+
0.333, 0.333, 1.000,
|
| 61 |
+
0.333, 0.667, 1.000,
|
| 62 |
+
0.333, 1.000, 1.000,
|
| 63 |
+
0.667, 0.000, 1.000,
|
| 64 |
+
0.667, 0.333, 1.000,
|
| 65 |
+
0.667, 0.667, 1.000,
|
| 66 |
+
0.667, 1.000, 1.000,
|
| 67 |
+
1.000, 0.000, 1.000,
|
| 68 |
+
1.000, 0.333, 1.000,
|
| 69 |
+
1.000, 0.667, 1.000,
|
| 70 |
+
0.333, 0.000, 0.000,
|
| 71 |
+
0.500, 0.000, 0.000,
|
| 72 |
+
0.667, 0.000, 0.000,
|
| 73 |
+
0.833, 0.000, 0.000,
|
| 74 |
+
1.000, 0.000, 0.000,
|
| 75 |
+
0.000, 0.167, 0.000,
|
| 76 |
+
0.000, 0.333, 0.000,
|
| 77 |
+
0.000, 0.500, 0.000,
|
| 78 |
+
0.000, 0.667, 0.000,
|
| 79 |
+
0.000, 0.833, 0.000,
|
| 80 |
+
0.000, 1.000, 0.000,
|
| 81 |
+
0.000, 0.000, 0.167,
|
| 82 |
+
0.000, 0.000, 0.333,
|
| 83 |
+
0.000, 0.000, 0.500,
|
| 84 |
+
0.000, 0.000, 0.667,
|
| 85 |
+
0.000, 0.000, 0.833,
|
| 86 |
+
0.000, 0.000, 1.000,
|
| 87 |
+
0.000, 0.000, 0.000,
|
| 88 |
+
0.143, 0.143, 0.143,
|
| 89 |
+
0.857, 0.857, 0.857,
|
| 90 |
+
1.000, 1.000, 1.000
|
| 91 |
+
]
|
| 92 |
+
).astype(np.float32).reshape(-1, 3)
|
| 93 |
+
# fmt: on
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def colormap(rgb=False, maximum=255):
|
| 97 |
+
"""
|
| 98 |
+
Args:
|
| 99 |
+
rgb (bool): whether to return RGB colors or BGR colors.
|
| 100 |
+
maximum (int): either 255 or 1
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
ndarray: a float32 array of Nx3 colors, in range [0, 255] or [0, 1]
|
| 104 |
+
"""
|
| 105 |
+
assert maximum in [255, 1], maximum
|
| 106 |
+
c = _COLORS * maximum
|
| 107 |
+
if not rgb:
|
| 108 |
+
c = c[:, ::-1]
|
| 109 |
+
return c
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def random_color(rgb=False, maximum=255):
|
| 113 |
+
"""
|
| 114 |
+
Args:
|
| 115 |
+
rgb (bool): whether to return RGB colors or BGR colors.
|
| 116 |
+
maximum (int): either 255 or 1
|
| 117 |
+
|
| 118 |
+
Returns:
|
| 119 |
+
ndarray: a vector of 3 numbers
|
| 120 |
+
"""
|
| 121 |
+
idx = np.random.randint(0, len(_COLORS))
|
| 122 |
+
ret = _COLORS[idx] * maximum
|
| 123 |
+
if not rgb:
|
| 124 |
+
ret = ret[::-1]
|
| 125 |
+
return ret
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def random_colors(N, rgb=False, maximum=255):
|
| 129 |
+
"""
|
| 130 |
+
Args:
|
| 131 |
+
N (int): number of unique colors needed
|
| 132 |
+
rgb (bool): whether to return RGB colors or BGR colors.
|
| 133 |
+
maximum (int): either 255 or 1
|
| 134 |
+
|
| 135 |
+
Returns:
|
| 136 |
+
ndarray: a list of random_color
|
| 137 |
+
"""
|
| 138 |
+
indices = random.sample(range(len(_COLORS)), N)
|
| 139 |
+
ret = [_COLORS[i] * maximum for i in indices]
|
| 140 |
+
if not rgb:
|
| 141 |
+
ret = [x[::-1] for x in ret]
|
| 142 |
+
return ret
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
if __name__ == "__main__":
|
| 146 |
+
import cv2
|
| 147 |
+
|
| 148 |
+
size = 100
|
| 149 |
+
H, W = 10, 10
|
| 150 |
+
canvas = np.random.rand(H * size, W * size, 3).astype("float32")
|
| 151 |
+
for h in range(H):
|
| 152 |
+
for w in range(W):
|
| 153 |
+
idx = h * W + w
|
| 154 |
+
if idx >= len(_COLORS):
|
| 155 |
+
break
|
| 156 |
+
canvas[h * size : (h + 1) * size, w * size : (w + 1) * size] = _COLORS[idx]
|
| 157 |
+
cv2.imshow("a", canvas)
|
| 158 |
+
cv2.waitKey(0)
|
detectron2/utils/comm.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
"""
|
| 3 |
+
This file contains primitives for multi-gpu communication.
|
| 4 |
+
This is useful when doing distributed training.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import functools
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.distributed as dist
|
| 11 |
+
|
| 12 |
+
_LOCAL_PROCESS_GROUP = None
|
| 13 |
+
_MISSING_LOCAL_PG_ERROR = (
|
| 14 |
+
"Local process group is not yet created! Please use detectron2's `launch()` "
|
| 15 |
+
"to start processes and initialize pytorch process group. If you need to start "
|
| 16 |
+
"processes in other ways, please call comm.create_local_process_group("
|
| 17 |
+
"num_workers_per_machine) after calling torch.distributed.init_process_group()."
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_world_size() -> int:
|
| 22 |
+
if not dist.is_available():
|
| 23 |
+
return 1
|
| 24 |
+
if not dist.is_initialized():
|
| 25 |
+
return 1
|
| 26 |
+
return dist.get_world_size()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_rank() -> int:
|
| 30 |
+
if not dist.is_available():
|
| 31 |
+
return 0
|
| 32 |
+
if not dist.is_initialized():
|
| 33 |
+
return 0
|
| 34 |
+
return dist.get_rank()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@functools.lru_cache()
def create_local_process_group(num_workers_per_machine: int) -> None:
    """
    Build a process group containing exactly the ranks on this machine.

    Detectron2's ``launch()`` (engine/launch.py) calls this automatically; when
    workers are started by other means, call it manually, otherwise helpers such
    as ``get_local_rank()`` will fail.

    This is a collective call (it contains a barrier): every process must
    execute it together.

    Args:
        num_workers_per_machine: number of worker processes per machine,
            typically the number of GPUs.
    """
    global _LOCAL_PROCESS_GROUP
    assert _LOCAL_PROCESS_GROUP is None
    world = get_world_size()
    assert world % num_workers_per_machine == 0
    machine_rank = get_rank() // num_workers_per_machine
    # new_group() is collective: every rank must create every group, even the
    # ones it is not a member of. Keep only the group for this machine.
    for machine in range(world // num_workers_per_machine):
        members = list(
            range(machine * num_workers_per_machine, (machine + 1) * num_workers_per_machine)
        )
        group = dist.new_group(members)
        if machine == machine_rank:
            _LOCAL_PROCESS_GROUP = group
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def get_local_process_group():
    """
    Return the torch process group made up of only the workers on this
    machine. Useful for intra-machine communication, e.g. a per-machine
    SyncBN.

    Raises:
        AssertionError: if ``create_local_process_group()`` was never called.
    """
    assert _LOCAL_PROCESS_GROUP is not None, _MISSING_LOCAL_PG_ERROR
    return _LOCAL_PROCESS_GROUP
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def get_local_rank() -> int:
    """Return this process' rank within its machine-local group (0 when not distributed)."""
    if not dist.is_available() or not dist.is_initialized():
        return 0
    assert _LOCAL_PROCESS_GROUP is not None, _MISSING_LOCAL_PG_ERROR
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def get_local_size() -> int:
    """Return the number of processes on this machine (1 when not distributed)."""
    if not dist.is_available() or not dist.is_initialized():
        return 1
    assert _LOCAL_PROCESS_GROUP is not None, _MISSING_LOCAL_PG_ERROR
    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def is_main_process() -> bool:
    """True only on the rank-0 (main) process."""
    return get_rank() == 0
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def synchronize():
    """
    Barrier across all workers. A no-op when distributed training is
    unavailable, uninitialized, or running with a single process.
    """
    if not dist.is_available() or not dist.is_initialized():
        return
    if dist.get_world_size() == 1:
        return
    if dist.get_backend() == dist.Backend.NCCL:
        # NCCL wants an explicit device list; passing it avoids warnings.
        dist.barrier(device_ids=[torch.cuda.current_device()])
    else:
        dist.barrier()
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a (cached) process group on the gloo backend containing all ranks.
    When the default backend is already CPU-friendly, the WORLD group is reused.
    """
    backend = dist.get_backend()
    return dist.new_group(backend="gloo") if backend == "nccl" else dist.group.WORLD
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not only tensors).

    Args:
        data: any picklable object
        group: a torch process group. Defaults to a gloo group containing
            all ranks, keeping the transfer on CPU to reduce GPU RAM usage.

    Returns:
        list[data]: one element gathered from each rank
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return [data]

    gathered = [None] * world_size
    dist.all_gather_object(gathered, data, group=group)
    return gathered
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def gather(data, dst=0, group=None):
    """
    Run gather on arbitrary picklable data (not only tensors).

    Args:
        data: any picklable object
        dst (int): destination rank
        group: a torch process group. Defaults to a gloo group with all ranks.

    Returns:
        list[data]: on rank ``dst``, the data from every rank; elsewhere an
        empty list.
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return [data]

    if dist.get_rank(group=group) != dst:
        # Non-destination ranks only contribute; they receive nothing.
        dist.gather_object(data, None, dst=dst, group=group)
        return []
    received = [None] * world_size
    dist.gather_object(data, received, dst=dst, group=group)
    return received
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def shared_random_seed():
    """
    Return a random integer that is identical across all workers.

    Each worker draws a local random number, the draws are all-gathered,
    and everyone keeps rank 0's value. Workers that need a shared RNG can
    seed it with this value.

    Every worker must call this function, otherwise it deadlocks.
    """
    local_draw = np.random.randint(2**31)
    return all_gather(local_draw)[0]
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def reduce_dict(input_dict, average=True):
    """
    Reduce dictionary values across processes so rank 0 holds the result.

    Args:
        input_dict (dict): values must be scalar CUDA tensors.
        average (bool): average the values (True) or just sum them (False).

    Returns:
        dict: same keys as ``input_dict`` with reduced values.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        # Sort keys so every process stacks values in an identical order.
        keys = sorted(input_dict.keys())
        stacked = torch.stack([input_dict[k] for k in keys], dim=0)
        dist.reduce(stacked, dst=0)
        if average and dist.get_rank() == 0:
            # Only the destination rank holds the accumulated sum, so only
            # it divides by the world size.
            stacked /= world_size
        return dict(zip(keys, stacked))
|
detectron2/utils/develop.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
""" Utilities for developers only.
|
| 3 |
+
These are not visible to users (not automatically imported). And should not
|
| 4 |
+
appeared in docs."""
|
| 5 |
+
# adapted from https://github.com/tensorpack/tensorpack/blob/master/tensorpack/utils/develop.py
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def create_dummy_class(klass, dependency, message=""):
    """
    Build a placeholder class that raises ImportError whenever it is
    instantiated or any of its class attributes is accessed. Used when a
    dependency of the real class is missing.

    Args:
        klass (str): name of the class.
        dependency (str): name of the dependency.
        message: extra message to append to the error.
    Returns:
        class: the placeholder class object
    """
    err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, klass)
    err = " ".join(part for part in (err, message) if part)

    class _DummyMetaClass(type):
        # Accessing any attribute on the class itself fails loudly.
        def __getattr__(_, __):  # noqa: B902
            raise ImportError(err)

    class _Dummy(object, metaclass=_DummyMetaClass):
        # Constructing an instance fails loudly as well.
        def __init__(self, *args, **kwargs):
            raise ImportError(err)

    return _Dummy
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def create_dummy_func(func, dependency, message=""):
    """
    Build a placeholder function that raises ImportError when called. Used
    when a dependency of the real function is missing.

    Args:
        func (str): name of the function.
        dependency (str or list[str]): name(s) of the dependency.
        message: extra message to append to the error.
    Returns:
        function: the placeholder function object
    """
    # BUGFIX: normalize list/tuple dependencies BEFORE composing the message.
    # Previously the join happened after `err` was built, so it was dead code
    # and a list dependency rendered as its Python repr in the error text.
    if isinstance(dependency, (list, tuple)):
        dependency = ",".join(dependency)

    err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, func)
    if message:
        err = err + " " + message

    def _dummy(*args, **kwargs):
        raise ImportError(err)

    return _dummy
|
detectron2/utils/env.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
import importlib
|
| 3 |
+
import importlib.util
|
| 4 |
+
import logging
|
| 5 |
+
import numpy as np
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
import sys
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
__all__ = ["seed_all_rng"]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Parsed once at import time from torch.__version__, e.g. "1.13.1" -> (1, 13).
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])
"""
PyTorch version as a tuple of 2 ints. Useful for comparison.
"""


# Truthy only while Sphinx builds the documentation (docs/conf.py sets the
# _DOC_BUILDING environment variable); defaults to False otherwise.
DOC_BUILDING = os.getenv("_DOC_BUILDING", False)  # set in docs/conf.py
"""
Whether we're building documentation.
"""
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def seed_all_rng(seed=None):
    """
    Seed all relevant RNGs: numpy, torch (CPU + every CUDA device), python's
    ``random`` module, and PYTHONHASHSEED.

    Args:
        seed (int): if None, a strong random seed is derived from the pid,
            the current time, and OS entropy.
    """
    if seed is None:
        seed = (
            os.getpid()
            + int(datetime.now().strftime("%S%f"))
            + int.from_bytes(os.urandom(2), "big")
        )
        logger = logging.getLogger(__name__)
        logger.info("Using a generated random seed {}".format(seed))
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)
    # FIX: pass the integer seed directly. The previous code passed str(seed),
    # which only worked because torch.cuda.manual_seed_all happens to call
    # int() on its argument — a fragile implicit conversion, inconsistent with
    # every other seeding call here.
    torch.cuda.manual_seed_all(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
|
| 50 |
+
def _import_file(module_name, file_path, make_importable=False):
|
| 51 |
+
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
| 52 |
+
module = importlib.util.module_from_spec(spec)
|
| 53 |
+
spec.loader.exec_module(module)
|
| 54 |
+
if make_importable:
|
| 55 |
+
sys.modules[module_name] = module
|
| 56 |
+
return module
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _configure_libraries():
    """
    Configurations for some libraries.

    Side effects: may block cv2 imports or disable its OpenCL backend, and
    asserts minimum versions of torch, fvcore and pyyaml.
    """
    # An environment option to disable `import cv2` globally,
    # in case it leads to negative performance impact
    disable_cv2 = int(os.environ.get("DETECTRON2_DISABLE_CV2", False))
    if disable_cv2:
        # Registering None makes any subsequent `import cv2` raise ImportError.
        sys.modules["cv2"] = None
    else:
        # Disable opencl in opencv since its interaction with cuda often has negative effects
        # This envvar is supported after OpenCV 3.4.0
        os.environ["OPENCV_OPENCL_RUNTIME"] = "disabled"
        try:
            import cv2

            if int(cv2.__version__.split(".")[0]) >= 3:
                cv2.ocl.setUseOpenCL(False)
        except ModuleNotFoundError:
            # Other types of ImportError, if happened, should not be ignored.
            # Because a failed opencv import could mess up address space
            # https://github.com/skvark/opencv-python/issues/381
            pass

    def get_version(module, digit=2):
        # Parse the first `digit` components of module.__version__ as ints.
        return tuple(map(int, module.__version__.split(".")[:digit]))

    # fmt: off
    assert get_version(torch) >= (1, 4), "Requires torch>=1.4"
    import fvcore
    assert get_version(fvcore, 3) >= (0, 1, 2), "Requires fvcore>=0.1.2"
    import yaml
    assert get_version(yaml) >= (5, 1), "Requires pyyaml>=5.1"
    # fmt: on
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# Guard so that setup_environment() performs its work at most once per process.
_ENV_SETUP_DONE = False
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def setup_environment():
    """Run one-time environment setup work.

    The default setup only configures libraries; additionally, a Python
    source file or module named in the $DETECTRON2_ENV_MODULE environment
    variable can supply custom setup required by a particular computing
    environment. Repeated calls are no-ops.
    """
    global _ENV_SETUP_DONE
    if _ENV_SETUP_DONE:
        return
    _ENV_SETUP_DONE = True

    _configure_libraries()

    custom_module_path = os.environ.get("DETECTRON2_ENV_MODULE")
    if custom_module_path:
        setup_custom_environment(custom_module_path)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def setup_custom_environment(custom_module):
    """
    Import ``custom_module`` — either a path to a ``.py`` source file or a
    module name — and invoke its ``setup_environment()`` callable.
    """
    if custom_module.endswith(".py"):
        module = _import_file("detectron2.utils.env.custom_module", custom_module)
    else:
        module = importlib.import_module(custom_module)
    setup_fn = getattr(module, "setup_environment", None)
    assert callable(setup_fn), (
        "Custom environment module defined in {} does not have the "
        "required callable attribute 'setup_environment'."
    ).format(custom_module)
    setup_fn()
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def fixup_module_metadata(module_name, namespace, keys=None):
    """
    Fix the __qualname__ of module members to be their exported api name, so
    when they are referenced in docs, sphinx can find them. Reference:
    https://github.com/python-trio/trio/blob/6754c74eacfad9cc5c92d5c24727a2f3b620624e/trio/_util.py#L216-L241
    """
    # No-op unless we are building the docs (DOC_BUILDING is set in docs/conf.py).
    if not DOC_BUILDING:
        return
    seen_ids = set()

    def fix_one(qualname, name, obj):
        # avoid infinite recursion (relevant when using
        # typing.Generic, for example)
        if id(obj) in seen_ids:
            return
        seen_ids.add(id(obj))

        mod = getattr(obj, "__module__", None)
        if mod is not None and (mod.startswith(module_name) or mod.startswith("fvcore.")):
            obj.__module__ = module_name
            # Modules, unlike everything else in Python, put fully-qualified
            # names into their __name__ attribute. We check for "." to avoid
            # rewriting these.
            if hasattr(obj, "__name__") and "." not in obj.__name__:
                obj.__name__ = name
                obj.__qualname__ = qualname
            if isinstance(obj, type):
                for attr_name, attr_value in obj.__dict__.items():
                    # NOTE(review): this intentionally(?) closes over the loop
                    # variable `objname` (the top-level export name) instead of
                    # using `qualname`, so deeply nested members get a flattened
                    # qualname. It matches the upstream trio helper — confirm
                    # before changing.
                    fix_one(objname + "." + attr_name, attr_name, attr_value)

    if keys is None:
        keys = namespace.keys()
    for objname in keys:
        if not objname.startswith("_"):
            obj = namespace[objname]
            fix_one(objname, objname, obj)
|
detectron2/utils/events.py
ADDED
|
@@ -0,0 +1,557 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
import datetime
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
from contextlib import contextmanager
|
| 9 |
+
from functools import cached_property
|
| 10 |
+
from typing import Optional
|
| 11 |
+
import torch
|
| 12 |
+
from fvcore.common.history_buffer import HistoryBuffer
|
| 13 |
+
|
| 14 |
+
from detectron2.utils.file_io import PathManager
|
| 15 |
+
|
| 16 |
+
# Public API of this module.
__all__ = [
    "get_event_storage",
    "has_event_storage",
    "JSONWriter",
    "TensorboardXWriter",
    "CommonMetricPrinter",
    "EventStorage",
]

# Stack of active EventStorage contexts; get_event_storage() returns the top.
_CURRENT_STORAGE_STACK = []
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_event_storage():
    """Return the innermost :class:`EventStorage` currently in use.

    Raises:
        AssertionError: when called outside any ``with EventStorage(...)``
            context.
    """
    assert (
        _CURRENT_STORAGE_STACK
    ), "get_event_storage() has to be called inside a 'with EventStorage(...)' context!"
    return _CURRENT_STORAGE_STACK[-1]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def has_event_storage():
    """Return True when at least one EventStorage context is active."""
    return len(_CURRENT_STORAGE_STACK) > 0
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class EventWriter:
    """
    Abstract base for writers that consume events from an
    :class:`EventStorage` and process them.

    Subclasses must implement :meth:`write`; overriding :meth:`close` is
    optional.
    """

    def write(self):
        raise NotImplementedError

    def close(self):
        pass
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class JSONWriter(EventWriter):
    """
    Append scalar events to a file as one JSON object per line (instead of
    a single big JSON document), which makes the output easy to stream, e.g.::

        $ cat metrics.json | jq -s '.[0:2]'
        $ cat metrics.json | jq '.loss_mask'

    Each record carries an ``"iteration"`` field plus every scalar updated
    at that iteration.
    """

    def __init__(self, json_file, window_size=20):
        """
        Args:
            json_file (str): path to the json file. New data will be appended if the file exists.
            window_size (int): the window size of median smoothing for the scalars whose
                `smoothing_hint` are True.
        """
        self._file_handle = PathManager.open(json_file, "a")
        self._window_size = window_size
        self._last_write = -1  # largest iteration already flushed to disk

    def write(self):
        storage = get_event_storage()
        pending = defaultdict(dict)

        # Collect only the scalars that appeared after the previous write().
        for key, (value, itr) in storage.latest_with_smoothing_hint(self._window_size).items():
            if itr <= self._last_write:
                continue
            pending[itr][key] = value
        if len(pending):
            self._last_write = max(pending.keys())

        for itr, record in pending.items():
            record["iteration"] = itr
            self._file_handle.write(json.dumps(record, sort_keys=True) + "\n")
        self._file_handle.flush()
        try:
            os.fsync(self._file_handle.fileno())
        except AttributeError:
            # Some PathManager handles do not expose a real file descriptor.
            pass

    def close(self):
        self._file_handle.close()
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class TensorboardXWriter(EventWriter):
    """
    Push all scalars — plus any buffered images and histograms — to a
    tensorboard event file.
    """

    def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
        """
        Args:
            log_dir (str): the directory to save the output events
            window_size (int): the scalars will be median-smoothed by this window size

            kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)`
        """
        self._window_size = window_size
        self._writer_args = {"log_dir": log_dir, **kwargs}
        self._last_write = -1

    @cached_property
    def _writer(self):
        # Created lazily so that importing this module never requires tensorboard.
        from torch.utils.tensorboard import SummaryWriter

        return SummaryWriter(**self._writer_args)

    def write(self):
        storage = get_event_storage()
        newest = self._last_write
        for key, (value, itr) in storage.latest_with_smoothing_hint(self._window_size).items():
            if itr > self._last_write:
                self._writer.add_scalar(key, value, itr)
                newest = max(newest, itr)
        self._last_write = newest

        # storage.put_{image,histogram} is only meant to be used by the
        # tensorboard writer, so its internal fields are accessed directly.
        if len(storage._vis_data) >= 1:
            for img_name, img, step_num in storage._vis_data:
                self._writer.add_image(img_name, img, step_num)
            # The storage retains image data until a writer clears it; this
            # therefore assumes a single writer consumes the images. (A longer
            # write period on a second writer would miss some of them.)
            storage.clear_images()

        if len(storage._histograms) >= 1:
            for params in storage._histograms:
                self._writer.add_histogram_raw(**params)
            storage.clear_histograms()

    def close(self):
        # Close only if the lazy cached_property was actually materialized.
        if "_writer" in self.__dict__:
            self._writer.close()
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class CommonMetricPrinter(EventWriter):
|
| 196 |
+
"""
|
| 197 |
+
Print **common** metrics to the terminal, including
|
| 198 |
+
iteration time, ETA, memory, all losses, and the learning rate.
|
| 199 |
+
It also applies smoothing using a window of 20 elements.
|
| 200 |
+
|
| 201 |
+
It's meant to print common metrics in common ways.
|
| 202 |
+
To print something in more customized ways, please implement a similar printer by yourself.
|
| 203 |
+
"""
|
| 204 |
+
|
| 205 |
+
def __init__(self, max_iter: Optional[int] = None, window_size: int = 20):
|
| 206 |
+
"""
|
| 207 |
+
Args:
|
| 208 |
+
max_iter: the maximum number of iterations to train.
|
| 209 |
+
Used to compute ETA. If not given, ETA will not be printed.
|
| 210 |
+
window_size (int): the losses will be median-smoothed by this window size
|
| 211 |
+
"""
|
| 212 |
+
self.logger = logging.getLogger("detectron2.utils.events")
|
| 213 |
+
self._max_iter = max_iter
|
| 214 |
+
self._window_size = window_size
|
| 215 |
+
self._last_write = None # (step, time) of last call to write(). Used to compute ETA
|
| 216 |
+
|
| 217 |
+
    def _get_eta(self, storage) -> Optional[str]:
        """Return a human-readable ETA string, "" when max_iter is unset,
        or None when no estimate is possible yet.

        Prefers the smoothed "time" history in ``storage``; if that metric
        does not exist yet, falls back to timing consecutive write() calls.
        """
        if self._max_iter is None:
            return ""
        iteration = storage.iter
        try:
            eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration - 1)
            # Publish the estimate so other writers (e.g. tensorboard) can log it.
            storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
            return str(datetime.timedelta(seconds=int(eta_seconds)))
        except KeyError:
            # estimate eta on our own - more noisy
            eta_string = None
            if self._last_write is not None:
                estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
                    iteration - self._last_write[0]
                )
                eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
            # Record this call so the next one can measure elapsed wall time.
            self._last_write = (iteration, time.perf_counter())
            return eta_string
|
| 236 |
+
|
| 237 |
+
def write(self):
    """
    Print one line of smoothed training progress (ETA, iteration, losses,
    metrics, timings, learning rate, GPU memory) to the logger.
    """
    storage = get_event_storage()
    cur_iter = storage.iter
    if cur_iter == self._max_iter:
        # This hook only reports training progress (loss, ETA, etc) but not
        # other data; stay silent after training succeeds, even if called.
        return

    try:
        data_hist = storage.history("data_time")
        avg_data_time = data_hist.avg(storage.count_samples("data_time", self._window_size))
        last_data_time = data_hist.latest()
    except KeyError:
        # Metric may not exist in the first few iterations (warmup),
        # or when SimpleTrainer is not used.
        avg_data_time = None
        last_data_time = None
    try:
        time_hist = storage.history("time")
        avg_iter_time = time_hist.global_avg()
        last_iter_time = time_hist.latest()
    except KeyError:
        avg_iter_time = None
        last_iter_time = None
    try:
        lr = "{:.5g}".format(storage.history("lr").latest())
    except KeyError:
        lr = "N/A"

    eta_string = self._get_eta(storage)
    max_mem_mb = (
        torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
        if torch.cuda.is_available()
        else None
    )

    def _fmt_metrics(selector):
        # Median-smoothed "name: value" pairs for all histories matching selector.
        return " ".join(
            "{}: {:.4g}".format(k, h.median(storage.count_samples(k, self._window_size)))
            for k, h in storage.histories().items()
            if selector(k)
        )

    # NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
    self.logger.info(
        " {eta}iter: {it} {losses} {non_losses} {avg_time}{last_time}"
        "{avg_data_time}{last_data_time} lr: {lr} {memory}".format(
            eta="eta: {} ".format(eta_string) if eta_string else "",
            it=cur_iter,
            losses=_fmt_metrics(lambda k: "loss" in k),
            non_losses=_fmt_metrics(lambda k: "[metric]" in k),
            avg_time="time: {:.4f} ".format(avg_iter_time) if avg_iter_time is not None else "",
            last_time=(
                "last_time: {:.4f} ".format(last_iter_time)
                if last_iter_time is not None
                else ""
            ),
            avg_data_time=(
                "data_time: {:.4f} ".format(avg_data_time) if avg_data_time is not None else ""
            ),
            last_data_time=(
                "last_data_time: {:.4f} ".format(last_data_time)
                if last_data_time is not None
                else ""
            ),
            lr=lr,
            memory="max_mem: {:.0f}M".format(max_mem_mb) if max_mem_mb is not None else "",
        )
    )
class EventStorage:
    """
    The user-facing class that provides metric storage functionalities.

    In the future we may add support for storing / logging other types of data if needed.
    """

    def __init__(self, start_iter=0):
        """
        Args:
            start_iter (int): the iteration number to start with
        """
        self._history = defaultdict(HistoryBuffer)  # name -> HistoryBuffer
        self._smoothing_hints = {}  # name -> bool, see put_scalar
        self._latest_scalars = {}  # name -> (value, iteration)
        self._iter = start_iter
        self._current_prefix = ""  # prefix applied by name_scope()
        self._vis_data = []  # (img_name, img_tensor, iteration) tuples
        self._histograms = []  # raw histogram params for tensorboard

    def put_image(self, img_name, img_tensor):
        """
        Add an `img_tensor` associated with `img_name`, to be shown on
        tensorboard.

        Args:
            img_name (str): The name of the image to put into tensorboard.
            img_tensor (torch.Tensor or numpy.array): An `uint8` or `float`
                Tensor of shape `[channel, height, width]` where `channel` is
                3. The image format should be RGB. The elements in img_tensor
                can either have values in [0, 1] (float32) or [0, 255] (uint8).
                The `img_tensor` will be visualized in tensorboard.
        """
        self._vis_data.append((img_name, img_tensor, self._iter))

    def put_scalar(self, name, value, smoothing_hint=True, cur_iter=None):
        """
        Add a scalar `value` to the `HistoryBuffer` associated with `name`.

        Args:
            smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be
                smoothed when logged. The hint will be accessible through
                :meth:`EventStorage.smoothing_hints`. A writer may ignore the hint
                and apply custom smoothing rule.

                It defaults to True because most scalars we save need to be smoothed to
                provide any useful signal.
            cur_iter (int): an iteration number to set explicitly instead of current iteration
        """
        name = self._current_prefix + name
        cur_iter = self._iter if cur_iter is None else cur_iter
        value = float(value)
        self._history[name].update(value, cur_iter)
        self._latest_scalars[name] = (value, cur_iter)

        existing_hint = self._smoothing_hints.get(name)
        if existing_hint is not None:
            # The hint must stay consistent across puts of the same scalar.
            assert (
                existing_hint == smoothing_hint
            ), "Scalar {} was put with a different smoothing_hint!".format(name)
        else:
            self._smoothing_hints[name] = smoothing_hint

    def put_scalars(self, *, smoothing_hint=True, cur_iter=None, **kwargs):
        """
        Put multiple scalars from keyword arguments.

        Examples:

            storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
        """
        for k, v in kwargs.items():
            self.put_scalar(k, v, smoothing_hint=smoothing_hint, cur_iter=cur_iter)

    def put_histogram(self, hist_name, hist_tensor, bins=1000):
        """
        Create a histogram from a tensor.

        Args:
            hist_name (str): The name of the histogram to put into tensorboard.
            hist_tensor (torch.Tensor): A Tensor of arbitrary shape to be converted
                into a histogram.
            bins (int): Number of histogram bins.
        """
        ht_min, ht_max = hist_tensor.min().item(), hist_tensor.max().item()

        # Create a histogram with PyTorch
        hist_counts = torch.histc(hist_tensor, bins=bins)
        hist_edges = torch.linspace(start=ht_min, end=ht_max, steps=bins + 1, dtype=torch.float32)

        # Parameters for the add_histogram_raw function of SummaryWriter
        hist_params = dict(
            tag=hist_name,
            min=ht_min,
            max=ht_max,
            num=len(hist_tensor),
            sum=float(hist_tensor.sum()),
            sum_squares=float(torch.sum(hist_tensor**2)),
            bucket_limits=hist_edges[1:].tolist(),
            bucket_counts=hist_counts.tolist(),
            global_step=self._iter,
        )
        self._histograms.append(hist_params)

    def history(self, name):
        """
        Returns:
            HistoryBuffer: the scalar history for name

        Raises:
            KeyError: if no scalar was ever put under `name`.
        """
        ret = self._history.get(name, None)
        if ret is None:
            raise KeyError("No history metric available for {}!".format(name))
        return ret

    def histories(self):
        """
        Returns:
            dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars
        """
        return self._history

    def latest(self):
        """
        Returns:
            dict[str -> (float, int)]: mapping from the name of each scalar to the most
                recent value and the iteration number its added.
        """
        return self._latest_scalars

    def latest_with_smoothing_hint(self, window_size=20):
        """
        Similar to :meth:`latest`, but the returned values
        are either the un-smoothed original latest value,
        or a median of the given window_size,
        depend on whether the smoothing_hint is True.

        This provides a default behavior that other writers can use.

        Note: All scalars saved in the past `window_size` iterations are used for smoothing.
        This is different from the `window_size` definition in HistoryBuffer.
        Use :meth:`get_history_window_size` to get the `window_size` used in HistoryBuffer.
        """
        result = {}
        for k, (v, itr) in self._latest_scalars.items():
            result[k] = (
                (
                    self._history[k].median(self.count_samples(k, window_size))
                    if self._smoothing_hints[k]
                    else v
                ),
                itr,
            )
        return result

    def count_samples(self, name, window_size=20):
        """
        Return the number of samples logged in the past `window_size` iterations.
        """
        data = self._history[name].values()
        if not data:
            # No samples yet (e.g. a defaultdict-created empty buffer);
            # previously this raised IndexError on data[-1].
            return 0
        samples = 0
        latest_iter = data[-1][1]
        for _, iter_ in reversed(data):
            if iter_ > latest_iter - window_size:
                samples += 1
            else:
                break
        return samples

    def smoothing_hints(self):
        """
        Returns:
            dict[name -> bool]: the user-provided hint on whether the scalar
                is noisy and needs smoothing.
        """
        return self._smoothing_hints

    def step(self):
        """
        User should either: (1) Call this function to increment storage.iter when needed. Or
        (2) Set `storage.iter` to the correct iteration number before each iteration.

        The storage will then be able to associate the new data with an iteration number.
        """
        self._iter += 1

    @property
    def iter(self):
        """
        Returns:
            int: The current iteration number. When used together with a trainer,
                this is ensured to be the same as trainer.iter.
        """
        return self._iter

    @iter.setter
    def iter(self, val):
        self._iter = int(val)

    @property
    def iteration(self):
        # for backward compatibility
        return self._iter

    def __enter__(self):
        _CURRENT_STORAGE_STACK.append(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        assert _CURRENT_STORAGE_STACK[-1] == self
        _CURRENT_STORAGE_STACK.pop()

    @contextmanager
    def name_scope(self, name):
        """
        Yields:
            A context within which all the events added to this storage
            will be prefixed by the name scope.
        """
        old_prefix = self._current_prefix
        self._current_prefix = name.rstrip("/") + "/"
        try:
            yield
        finally:
            # Restore the prefix even when the body raises, so a failing
            # scope cannot permanently corrupt subsequent metric names.
            self._current_prefix = old_prefix

    def clear_images(self):
        """
        Delete all the stored images for visualization. This should be called
        after images are written to tensorboard.
        """
        self._vis_data = []

    def clear_histograms(self):
        """
        Delete all the stored histograms for visualization.
        This should be called after histograms are written to tensorboard.
        """
        self._histograms = []
detectron2/utils/file_io.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
from iopath.common.file_io import HTTPURLHandler, OneDrivePathHandler, PathHandler
|
| 3 |
+
from iopath.common.file_io import PathManager as PathManagerBase
|
| 4 |
+
|
| 5 |
+
__all__ = ["PathManager", "PathHandler"]
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
PathManager = PathManagerBase()
|
| 9 |
+
"""
|
| 10 |
+
This is a detectron2 project-specific PathManager.
|
| 11 |
+
We try to stay away from global PathManager in fvcore as it
|
| 12 |
+
introduces potential conflicts among other libraries.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Detectron2Handler(PathHandler):
    """
    Resolve anything that's hosted under detectron2's namespace
    by redirecting "detectron2://" paths to the public S3 bucket.
    """

    PREFIX = "detectron2://"
    S3_DETECTRON2_PREFIX = "https://dl.fbaipublicfiles.com/detectron2/"

    def _get_supported_prefixes(self):
        return [self.PREFIX]

    def _redirect(self, path):
        # Map "detectron2://X" to the corresponding public URL.
        return self.S3_DETECTRON2_PREFIX + path[len(self.PREFIX) :]

    def _get_local_path(self, path, **kwargs):
        return PathManager.get_local_path(self._redirect(path), **kwargs)

    def _open(self, path, mode="r", **kwargs):
        return PathManager.open(self._redirect(path), mode, **kwargs)


PathManager.register_handler(HTTPURLHandler())
PathManager.register_handler(OneDrivePathHandler())
PathManager.register_handler(Detectron2Handler())
detectron2/utils/logger.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
import atexit
|
| 3 |
+
import functools
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import time
|
| 8 |
+
from collections import Counter
|
| 9 |
+
import torch
|
| 10 |
+
from tabulate import tabulate
|
| 11 |
+
from termcolor import colored
|
| 12 |
+
|
| 13 |
+
from detectron2.utils.file_io import PathManager
|
| 14 |
+
|
| 15 |
+
__all__ = ["setup_logger", "log_first_n", "log_every_n", "log_every_n_seconds"]
|
| 16 |
+
|
| 17 |
+
D2_LOG_BUFFER_SIZE_KEY: str = "D2_LOG_BUFFER_SIZE"
|
| 18 |
+
|
| 19 |
+
DEFAULT_LOG_BUFFER_SIZE: int = 1024 * 1024 # 1MB
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class _ColorfulFormatter(logging.Formatter):
    """
    Formatter that abbreviates the root module name in record names and
    prepends a colored severity tag for WARNING/ERROR/CRITICAL records.
    """

    def __init__(self, *args, **kwargs):
        self._root_name = kwargs.pop("root_name") + "."
        abbrev = kwargs.pop("abbrev_name", "")
        # Non-empty abbreviations get a trailing dot so they splice cleanly.
        self._abbrev_name = abbrev + "." if len(abbrev) else abbrev
        super(_ColorfulFormatter, self).__init__(*args, **kwargs)

    def formatMessage(self, record):
        record.name = record.name.replace(self._root_name, self._abbrev_name)
        message = super(_ColorfulFormatter, self).formatMessage(record)
        if record.levelno == logging.WARNING:
            return colored("WARNING", "red", attrs=["blink"]) + " " + message
        if record.levelno in (logging.ERROR, logging.CRITICAL):
            return colored("ERROR", "red", attrs=["blink", "underline"]) + " " + message
        return message
@functools.lru_cache()  # so that calling setup_logger multiple times won't add many handlers
def setup_logger(
    output=None,
    distributed_rank=0,
    *,
    color=True,
    name="detectron2",
    abbrev_name=None,
    enable_propagation: bool = False,
    configure_stdout: bool = True
):
    """
    Initialize the detectron2 logger and set its verbosity level to "DEBUG".

    Args:
        output (str): a file name or a directory to save log. If None, will not save log file.
            If ends with ".txt" or ".log", assumed to be a file name.
            Otherwise, logs will be saved to `output/log.txt`.
        name (str): the root module name of this logger
        abbrev_name (str): an abbreviation of the module, to avoid long names in logs.
            Set to "" to not log the root module in logs.
            By default, will abbreviate "detectron2" to "d2" and leave other
            modules unchanged.
        enable_propagation (bool): whether to propagate logs to the parent logger.
        configure_stdout (bool): whether to configure logging to stdout.

    Returns:
        logging.Logger: a logger
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = enable_propagation

    if abbrev_name is None:
        abbrev_name = "d2" if name == "detectron2" else name

    plain_formatter = logging.Formatter(
        "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
    )

    # Console handler: attached on the master process only.
    if configure_stdout and distributed_rank == 0:
        stream_handler = logging.StreamHandler(stream=sys.stdout)
        stream_handler.setLevel(logging.DEBUG)
        stream_handler.setFormatter(
            _ColorfulFormatter(
                colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
                datefmt="%m/%d %H:%M:%S",
                root_name=name,
                abbrev_name=str(abbrev_name),
            )
            if color
            else plain_formatter
        )
        logger.addHandler(stream_handler)

    # File handler: every worker writes (rank > 0 gets a suffixed file).
    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "log.txt")
        if distributed_rank > 0:
            filename = filename + ".rank{}".format(distributed_rank)
        PathManager.mkdirs(os.path.dirname(filename))

        file_handler = logging.StreamHandler(_cached_log_stream(filename))
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(plain_formatter)
        logger.addHandler(file_handler)

    return logger
# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    """Open `filename` for appending, memoized per filename; closed at exit."""
    stream = PathManager.open(filename, "a", buffering=_get_log_stream_buffer_size(filename))
    atexit.register(stream.close)
    return stream
| 125 |
+
|
| 126 |
+
def _get_log_stream_buffer_size(filename: str) -> int:
    """Pick a write-buffer size for the log stream backing `filename`."""
    if "://" not in filename:
        # Local file: rely on the OS default buffering.
        return -1
    # Remote storage needs a large buffer to avoid many small writes;
    # allow an override through the environment.
    if D2_LOG_BUFFER_SIZE_KEY in os.environ:
        return int(os.environ[D2_LOG_BUFFER_SIZE_KEY])
    return DEFAULT_LOG_BUFFER_SIZE
| 135 |
+
|
| 136 |
+
"""
|
| 137 |
+
Below are some other convenient logging methods.
|
| 138 |
+
They are mainly adopted from
|
| 139 |
+
https://github.com/abseil/abseil-py/blob/master/absl/logging/__init__.py
|
| 140 |
+
"""
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def _find_caller():
    """
    Returns:
        str: module name of the caller
        tuple: a hashable key to be used to identify different callers
    """
    frame = sys._getframe(2)
    while frame:
        code = frame.f_code
        # Skip frames belonging to this logging module itself.
        if os.path.join("utils", "logger.") in code.co_filename:
            frame = frame.f_back
            continue
        mod_name = frame.f_globals["__name__"]
        if mod_name == "__main__":
            mod_name = "detectron2"
        return mod_name, (code.co_filename, frame.f_lineno, code.co_name)
|
| 160 |
+
_LOG_COUNTER = Counter()
|
| 161 |
+
_LOG_TIMER = {}
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def log_first_n(lvl, msg, n=1, *, name=None, key="caller"):
    """
    Log only for the first n times.

    Args:
        lvl (int): the logging level
        msg (str):
        n (int):
        name (str): name of the logger to use. Will use the caller's module by default.
        key (str or tuple[str]): the string(s) can be one of "caller" or
            "message", which defines how to identify duplicated logs.
            For example, if called with `n=1, key="caller"`, this function
            will only log the first call from the same caller, regardless of
            the message content.
            If called with `n=1, key="message"`, this function will log the
            same content only once, even if they are called from different places.
            If called with `n=1, key=("caller", "message")`, this function
            will not log only if the same caller has logged the same message before.
    """
    keys = (key,) if isinstance(key, str) else key
    assert len(keys) > 0

    caller_module, caller_key = _find_caller()
    # Build the deduplication key from the requested components.
    hash_key = ()
    if "caller" in keys:
        hash_key += caller_key
    if "message" in keys:
        hash_key += (msg,)

    _LOG_COUNTER[hash_key] += 1
    if _LOG_COUNTER[hash_key] <= n:
        logging.getLogger(name or caller_module).log(lvl, msg)
| 199 |
+
def log_every_n(lvl, msg, n=1, *, name=None):
    """
    Log once per n times (the 1st, (n+1)th, (2n+1)th, ... call per site).

    Args:
        lvl (int): the logging level
        msg (str):
        n (int):
        name (str): name of the logger to use. Will use the caller's module by default.
    """
    caller_module, site_key = _find_caller()
    _LOG_COUNTER[site_key] += 1
    if n == 1 or _LOG_COUNTER[site_key] % n == 1:
        logging.getLogger(name or caller_module).log(lvl, msg)
| 214 |
+
|
| 215 |
+
def log_every_n_seconds(lvl, msg, n=1, *, name=None):
    """
    Log no more than once per n seconds (per calling site).

    Args:
        lvl (int): the logging level
        msg (str):
        n (int):
        name (str): name of the logger to use. Will use the caller's module by default.
    """
    caller_module, site_key = _find_caller()
    now = time.time()
    previous = _LOG_TIMER.get(site_key, None)
    if previous is None or now - previous >= n:
        logging.getLogger(name or caller_module).log(lvl, msg)
        _LOG_TIMER[site_key] = now
+
|
| 233 |
+
def create_small_table(small_dict):
    """
    Create a small table using the keys of small_dict as headers. This is only
    suitable for small dictionaries.

    Args:
        small_dict (dict): a result dictionary of only a few items.

    Returns:
        str: the table as a string.
    """
    keys, values = tuple(zip(*small_dict.items()))
    return tabulate(
        [values],
        headers=keys,
        tablefmt="pipe",
        floatfmt=".3f",
        stralign="center",
        numalign="center",
    )
| 255 |
+
|
| 256 |
+
def _log_api_usage(identifier: str):
    """
    Internal function used to log the usage of different detectron2 components
    inside facebook's infra (a no-op elsewhere).
    """
    key = "detectron2." + identifier
    torch._C._log_api_usage_once(key)
detectron2/utils/memory.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from contextlib import contextmanager
|
| 5 |
+
from functools import wraps
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
__all__ = ["retry_if_cuda_oom"]
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@contextmanager
def _ignore_torch_cuda_oom():
    """
    A context which ignores the CUDA out-of-memory RuntimeError from pytorch
    and re-raises every other exception.
    """
    try:
        yield
    except RuntimeError as e:
        # NOTE: pytorch may change this message string in future versions.
        if "CUDA out of memory. " not in str(e):
            raise
|
| 26 |
+
def retry_if_cuda_oom(func):
    """
    Makes a function retry itself after encountering
    pytorch's CUDA OOM error.
    It will first retry after calling `torch.cuda.empty_cache()`.

    If that still fails, it will then retry by trying to convert inputs to CPUs.
    In this case, it expects the function to dispatch to CPU implementation.
    The return values may become CPU tensors as well and it's user's
    responsibility to convert it back to CUDA tensor if needed.

    Args:
        func: a stateless callable that takes tensor-like objects as arguments

    Returns:
        a callable which retries `func` if OOM is encountered.

    Examples:
    ::
        output = retry_if_cuda_oom(some_torch_function)(input1, input2)
        # output may be on CPU even if inputs are on GPU

    Note:
        1. When converting inputs to CPU, it will only look at each argument and check
           if it has `.device` and `.to` for conversion. Nested structures of tensors
           are not supported.

        2. Since the function might be called more than once, it has to be
           stateless.
    """

    def _cpu_copy(value):
        # Move GPU-tensor-like objects to CPU; everything else passes through.
        try:
            is_gpu_tensor = value.device.type == "cuda" and hasattr(value, "to")
        except AttributeError:
            is_gpu_tensor = False
        return value.to(device="cpu") if is_gpu_tensor else value

    @wraps(func)
    def wrapped(*args, **kwargs):
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # First retry: free cached blocks and run again on GPU.
        torch.cuda.empty_cache()
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Last resort: move inputs to CPU. This is slow, so leave a trace.
        logger = logging.getLogger(__name__)
        logger.info("Attempting to copy inputs of {} to CPU due to CUDA OOM".format(str(func)))
        cpu_args = (_cpu_copy(a) for a in args)
        cpu_kwargs = {key: _cpu_copy(val) for key, val in kwargs.items()}
        return func(*cpu_args, **cpu_kwargs)

    return wrapped
detectron2/utils/registry.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
import pydoc
|
| 5 |
+
from fvcore.common.registry import Registry # for backward compatibility.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
``Registry`` and `locate` provide ways to map a string (typically found
|
| 9 |
+
in config files) to callable objects.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
__all__ = ["Registry", "locate"]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _convert_target_to_string(t: Any) -> str:
    """
    Inverse of ``locate()``: build a dotted-path string for object `t`.

    Args:
        t: any object with ``__module__`` and ``__qualname__``
    """
    module, qualname = t.__module__, t.__qualname__

    # Try progressively longer module prefixes; the first candidate that
    # resolves back to the same object gives a shorter path that is also
    # less affected by moving the implementation between submodules
    # (e.g. ``module.submodule._impl.class`` -> ``module.submodule.class``).
    parts = module.split(".")
    for k in range(1, len(parts)):
        candidate = f"{'.'.join(parts[:k])}.{qualname}"
        try:
            if locate(candidate) is t:
                return candidate
        except ImportError:
            pass
    return f"{module}.{qualname}"
|
| 40 |
+
def locate(name: str) -> Any:
    """
    Locate and return an object ``x`` using an input string ``{x.__module__}.{x.__qualname__}``,
    such as "module.submodule.class_name".

    Raise Exception if it cannot be found.
    """
    obj = pydoc.locate(name)
    if obj is not None:
        return obj

    # pydoc.locate mishandles some cases (e.g. torch.optim.sgd.SGD);
    # fall back to a private helper from hydra.
    try:
        # from hydra.utils import get_method - will print many errors
        from hydra.utils import _locate
    except ImportError as e:
        raise ImportError(f"Cannot dynamically locate object {name}!") from e
    return _locate(name)  # it raises if it fails
detectron2/utils/serialize.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
import cloudpickle
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class PicklableWrapper:
    """
    Wrap an object so it becomes picklable, at the cost of using a
    heavy-weight serialization library (cloudpickle) that is slower than
    plain pickle. Best used only on closures, which normal pickle rejects.

    This is a simplified version of
    https://github.com/joblib/joblib/blob/master/joblib/externals/loky/cloudpickle_wrapper.py
    """

    def __init__(self, obj):
        # Unwrap any nested wrappers first, so wrapping twice is a no-op.
        while isinstance(obj, PicklableWrapper):
            obj = obj._obj
        self._obj = obj

    def __reduce__(self):
        # Serialize with cloudpickle; unpickling calls cloudpickle.loads on
        # the payload and yields the original (unwrapped) object.
        payload = cloudpickle.dumps(self._obj)
        return (cloudpickle.loads, (payload,))

    def __call__(self, *args, **kwargs):
        # Delegate calls straight to the wrapped object.
        return self._obj(*args, **kwargs)

    def __getattr__(self, attr):
        # Forward attribute access so the wrapper can be used seamlessly
        # in place of the wrapped object.
        if attr == "_obj":
            return getattr(self, attr)
        return getattr(self._obj, attr)
|
detectron2/utils/testing.py
ADDED
|
@@ -0,0 +1,478 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
import io
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
import tempfile
|
| 7 |
+
import unittest
|
| 8 |
+
from typing import Callable
|
| 9 |
+
import torch
|
| 10 |
+
import torch.onnx.symbolic_helper as sym_help
|
| 11 |
+
from packaging import version
|
| 12 |
+
from torch._C import ListType
|
| 13 |
+
from torch.onnx import register_custom_op_symbolic
|
| 14 |
+
|
| 15 |
+
from detectron2 import model_zoo
|
| 16 |
+
from detectron2.config import CfgNode, LazyConfig, instantiate
|
| 17 |
+
from detectron2.data import DatasetCatalog
|
| 18 |
+
from detectron2.data.detection_utils import read_image
|
| 19 |
+
from detectron2.modeling import build_model
|
| 20 |
+
from detectron2.structures import Boxes, Instances, ROIMasks
|
| 21 |
+
from detectron2.utils.file_io import PathManager
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
"""
|
| 25 |
+
Internal utilities for tests. Don't use except for writing tests.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_model_no_weights(config_path):
    """
    Like model_zoo.get, but do not load any weights (even pretrained).

    Args:
        config_path: config file path relative to the model zoo.

    Returns:
        a model built from the config, with randomly-initialized weights.
    """
    cfg = model_zoo.get_config(config_path)
    if not isinstance(cfg, CfgNode):
        # LazyConfig-style config: just instantiate the model node.
        return instantiate(cfg.model)
    # Yacs-style config: fall back to CPU when no GPU is present.
    if not torch.cuda.is_available():
        cfg.MODEL.DEVICE = "cpu"
    return build_model(cfg)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def random_boxes(num_boxes, max_coord=100, device="cpu"):
    """
    Create a random Nx4 XYXY boxes tensor, with coordinates < max_coord.

    Args:
        num_boxes (int): number of boxes N.
        max_coord (float): upper bound for coordinate values.
        device (str): device for the returned tensor.

    Returns:
        Tensor of shape (N, 4) with x2 >= x1 and y2 >= y1.
    """
    # Sample top-left corners in [0, max_coord / 2), then clamp away from 0:
    # tiny boxes cause numerical instability in box regression.
    boxes = torch.rand(num_boxes, 4, device=device) * (max_coord * 0.5)
    boxes.clamp_(min=1.0)
    # torchvision's version does `boxes[:, 2:] += torch.rand(N, 2) * 100`,
    # which does not guarantee non-negative widths/heights
    # (boxes[:, 2] >= boxes[:, 0] and boxes[:, 3] >= boxes[:, 1]).
    # Adding the top-left corner to the second pair does guarantee it.
    boxes[:, 2:] = boxes[:, 2:] + boxes[:, :2]
    return boxes
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def get_sample_coco_image(tensor=True):
    """
    Fetch one sample COCO image, either from a locally-registered dataset or
    by downloading a fixed COCO image.

    Args:
        tensor (bool): if True, returns 3xHxW tensor.
            else, returns a HxWx3 numpy array.

    Returns:
        an image, in BGR color.
    """
    try:
        # Prefer an image from the locally-registered mini COCO val split.
        file_name = DatasetCatalog.get("coco_2017_val_100")[0]["file_name"]
        if not PathManager.exists(file_name):
            # FileNotFoundError subclasses OSError (== IOError), so it is
            # caught by the handler below.
            raise FileNotFoundError()
    except IOError:
        # for public CI to run
        # NOTE(review): DatasetCatalog.get raises KeyError for an
        # unregistered dataset, which this handler does not catch — confirm
        # the dataset is always registered in environments that reach here.
        file_name = PathManager.get_local_path(
            "http://images.cocodataset.org/train2017/000000000009.jpg"
        )
    ret = read_image(file_name, format="BGR")
    if tensor:
        # HWC -> CHW; ascontiguousarray ensures a C-contiguous array for
        # torch.from_numpy.
        ret = torch.from_numpy(np.ascontiguousarray(ret.transpose(2, 0, 1)))
    return ret
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def convert_scripted_instances(instances):
    """
    Convert a scripted Instances object to a regular :class:`Instances` object.

    Args:
        instances: a scripted Instances object (must expose ``image_size``).

    Returns:
        Instances: a regular Instances with the same fields.
    """
    assert hasattr(
        instances, "image_size"
    ), f"Expect an Instances object, but got {type(instances)}!"
    out = Instances(instances.image_size)
    # A scripted Instances stores each field as a "_<name>" attribute,
    # which is None when the field was never set.
    for field_name in instances._field_names:
        value = getattr(instances, "_" + field_name, None)
        if value is not None:
            out.set(field_name, value)
    return out
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def assert_instances_allclose(input, other, *, rtol=1e-5, msg="", size_as_tensor=False):
    """
    Assert that two Instances objects are (approximately) equal: same
    image_size, same field names, and field values close within tolerance.

    Args:
        input, other (Instances):
        rtol (float): relative tolerance used to derive per-field absolute
            tolerances below.
        msg (str): prefix for all assertion messages.
        size_as_tensor: compare image_size of the Instances as tensors (instead of tuples).
            Useful for comparing outputs of tracing.

    Raises:
        AssertionError: if the two Instances differ.
        ValueError: if a field has a type this function cannot compare.
    """
    # Normalize scripted Instances into regular ones first.
    if not isinstance(input, Instances):
        input = convert_scripted_instances(input)
    if not isinstance(other, Instances):
        other = convert_scripted_instances(other)

    if not msg:
        msg = "Two Instances are different! "
    else:
        msg = msg.rstrip() + " "

    size_error_msg = msg + f"image_size is {input.image_size} vs. {other.image_size}!"
    if size_as_tensor:
        assert torch.equal(
            torch.tensor(input.image_size), torch.tensor(other.image_size)
        ), size_error_msg
    else:
        assert input.image_size == other.image_size, size_error_msg
    # Field sets must match exactly (order-insensitive via sorting).
    fields = sorted(input.get_fields().keys())
    fields_other = sorted(other.get_fields().keys())
    assert fields == fields_other, msg + f"Fields are {fields} vs {fields_other}!"

    for f in fields:
        val1, val2 = input.get(f), other.get(f)
        if isinstance(val1, (Boxes, ROIMasks)):
            # boxes in the range of O(100) and can have a larger tolerance
            assert torch.allclose(val1.tensor, val2.tensor, atol=100 * rtol), (
                msg + f"Field {f} differs too much!"
            )
        elif isinstance(val1, torch.Tensor):
            if val1.dtype.is_floating_point:
                # Scale the absolute tolerance by the magnitude of the values.
                mag = torch.abs(val1).max().cpu().item()
                assert torch.allclose(val1, val2, atol=mag * rtol), (
                    msg + f"Field {f} differs too much!"
                )
            else:
                # Integer/bool tensors must match exactly.
                assert torch.equal(val1, val2), msg + f"Field {f} is different!"
        else:
            raise ValueError(f"Don't know how to compare type {type(val1)}")
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def reload_script_model(module):
    """
    Save a jit module and load it back through an in-memory buffer.
    Similar to the `getExportImportCopy` function in torch/testing/

    Args:
        module: a torch.jit ScriptModule/ScriptFunction.

    Returns:
        the round-tripped (saved-then-loaded) module.
    """
    stream = io.BytesIO()
    torch.jit.save(module, stream)
    stream.seek(0)
    return torch.jit.load(stream)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def reload_lazy_config(cfg):
    """
    Save an object by LazyConfig.save and load it back.
    This is used to test that a config still works the same after
    serialization/deserialization.
    """
    # Round-trip through a temporary yaml file that is removed afterwards.
    with tempfile.TemporaryDirectory(prefix="detectron2") as tmp_dir:
        path = os.path.join(tmp_dir, "d2_cfg_test.yaml")
        LazyConfig.save(cfg, path)
        return LazyConfig.load(path)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def min_torch_version(min_version: str) -> bool:
    """
    Returns True when torch's version is at least `min_version`.
    Returns False when torch is not importable at all.
    """
    try:
        import torch
    except ImportError:
        return False

    # Strip any local build suffix such as "+cu113" before parsing.
    current = version.parse(torch.__version__.split("+")[0])
    return current >= version.parse(min_version)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def has_dynamic_axes(onnx_model):
    """
    Return True when all ONNX input/output have only dynamic axes for all ranks.

    A dim is considered static when its ``dim_param`` is a numeric string.
    """

    def _all_dims_dynamic(values):
        # Walk every dim of every value's tensor type shape.
        return all(
            not dim.dim_param.isnumeric()
            for value in values
            for dim in value.type.tensor_type.shape.dim
        )

    graph = onnx_model.graph
    return _all_dims_dynamic(graph.input) and _all_dims_dynamic(graph.output)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def register_custom_op_onnx_export(
    opname: str, symbolic_fn: Callable, opset_version: int, min_version: str
) -> None:
    """
    Register `symbolic_fn` as PyTorch's symbolic `opname`-`opset_version` for ONNX export.
    The registration is performed only when current PyTorch's version is < `min_version.`
    IMPORTANT: symbolic must be manually unregistered after the caller function returns
    """
    # Recent-enough PyTorch already ships a correct symbolic; nothing to do.
    if not min_torch_version(min_version):
        register_custom_op_symbolic(opname, symbolic_fn, opset_version)
        print(f"_register_custom_op_onnx_export({opname}, {opset_version}) succeeded.")
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def unregister_custom_op_onnx_export(opname: str, opset_version: int, min_version: str) -> None:
    """
    Unregister PyTorch's symbolic `opname`-`opset_version` for ONNX export.
    The un-registration is performed only when PyTorch's version is < `min_version`
    IMPORTANT: The symbolic must have been manually registered by the caller, otherwise
    the incorrect symbolic may be unregistered instead.
    """

    # TODO: _unregister_custom_op_symbolic is introduced PyTorch>=1.10
    # Remove after PyTorch 1.10+ is used by ALL detectron2's CI
    try:
        from torch.onnx import unregister_custom_op_symbolic as _unregister_custom_op_symbolic
    except ImportError:
        # Older PyTorch: emulate unregistration by editing the private
        # symbolic registry directly.

        def _unregister_custom_op_symbolic(symbolic_name, opset_version):
            import torch.onnx.symbolic_registry as sym_registry
            from torch.onnx.symbolic_helper import _onnx_main_opset, _onnx_stable_opsets

            def _get_ns_op_name_from_custom_op(symbolic_name):
                # Split "domain::name" into namespace and op name, preferring
                # the torch helper when available.
                try:
                    from torch.onnx.utils import get_ns_op_name_from_custom_op

                    ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)
                except ImportError as import_error:
                    if not bool(
                        re.match(r"^[a-zA-Z0-9-_]*::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name)
                    ):
                        raise ValueError(
                            f"Invalid symbolic name {symbolic_name}. Must be `domain::name`"
                        ) from import_error

                    ns, op_name = symbolic_name.split("::")
                    if ns == "onnx":
                        raise ValueError(f"{ns} domain cannot be modified.") from import_error

                    if ns == "aten":
                        ns = ""

                return ns, op_name

            def _unregister_op(opname: str, domain: str, version: int):
                # NOTE(review): the first attempt below uses the enclosing
                # scope's `op_name`, `ns`, `ver` rather than this function's
                # parameters; they happen to hold the same values at the call
                # site, but the shadowing is confusing — confirm intended.
                try:
                    sym_registry.unregister_op(op_name, ns, ver)
                except AttributeError as attribute_error:
                    # unregister_op does not exist on this torch version:
                    # remove the entry from the registry dict by hand.
                    if sym_registry.is_registered_op(opname, domain, version):
                        del sym_registry._registry[(domain, version)][opname]
                        if not sym_registry._registry[(domain, version)]:
                            del sym_registry._registry[(domain, version)]
                    else:
                        raise RuntimeError(
                            f"The opname {opname} is not registered."
                        ) from attribute_error

            ns, op_name = _get_ns_op_name_from_custom_op(symbolic_name)
            # Unregister from every opset at or above the requested version.
            for ver in _onnx_stable_opsets + [_onnx_main_opset]:
                if ver >= opset_version:
                    _unregister_op(op_name, ns, ver)

    # Recent-enough PyTorch never had the custom symbolic registered.
    if min_torch_version(min_version):
        return
    _unregister_custom_op_symbolic(opname, opset_version)
    print(f"_unregister_custom_op_onnx_export({opname}, {opset_version}) succeeded.")
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# Decorator that skips a test when running on CI ("CI" env var set) without
# a CUDA GPU; such tests are expected to run on the GPU CI jobs instead.
skipIfOnCPUCI = unittest.skipIf(
    os.environ.get("CI") and not torch.cuda.is_available(),
    "The test is too slow on CPUs and will be executed on CircleCI's GPU jobs.",
)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def skipIfUnsupportedMinOpsetVersion(min_opset_version, current_opset_version=None):
    """
    Skips tests for ONNX Opset versions older than min_opset_version.

    The opset is taken from the test case's ``opset_version`` attribute when
    present, otherwise from `current_opset_version`.
    """

    def skip_dec(func):
        def wrapper(self):
            # Prefer the opset declared on the test case; fall back to the
            # version supplied at decoration time.
            opset_version = getattr(self, "opset_version", current_opset_version)
            if opset_version < min_opset_version:
                raise unittest.SkipTest(
                    f"Unsupported opset_version {opset_version}"
                    f", required is {min_opset_version}"
                )
            return func(self)

        return wrapper

    return skip_dec
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def skipIfUnsupportedMinTorchVersion(min_version):
    """
    Skips tests for PyTorch versions older than min_version.
    """
    reason = (
        f"module 'torch' has __version__ {torch.__version__}, required is: {min_version}"
    )
    return unittest.skipIf(not min_torch_version(min_version), reason)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
# TODO: Remove after PyTorch 1.11.1+ is used by detectron2's CI
|
| 310 |
+
def _pytorch1111_symbolic_opset9_to(g, self, *args):
    """aten::to() symbolic that must be used for testing with PyTorch < 1.11.1."""

    def is_aten_to_device_only(args):
        # Detect the aten::to overloads that only move the tensor to a device
        # (no dtype change); those are no-ops for ONNX.
        if len(args) == 4:
            # aten::to(Tensor, Device, bool, bool, memory_format)
            return (
                args[0].node().kind() == "prim::device"
                or args[0].type().isSubtypeOf(ListType.ofInts())
                or (
                    sym_help._is_value(args[0])
                    and args[0].node().kind() == "onnx::Constant"
                    and isinstance(args[0].node()["value"], str)
                )
            )
        elif len(args) == 5:
            # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format)
            # When dtype is None, this is a aten::to(device) call
            dtype = sym_help._get_const(args[1], "i", "dtype")
            return dtype is None
        elif len(args) in (6, 7):
            # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format)
            # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format)
            # When dtype is None, this is a aten::to(device) call
            dtype = sym_help._get_const(args[0], "i", "dtype")
            return dtype is None
        return False

    # ONNX doesn't have a concept of a device, so we ignore device-only casts
    if is_aten_to_device_only(args):
        return self

    if len(args) == 4:
        # TestONNXRuntime::test_ones_bool shows args[0] of aten::to can be onnx::Constant[Tensor]
        # In this case, the constant value is a tensor not int,
        # so sym_help._maybe_get_const(args[0], 'i') would not work.
        dtype = args[0]
        if sym_help._is_value(args[0]) and args[0].node().kind() == "onnx::Constant":
            tval = args[0].node()["value"]
            if isinstance(tval, torch.Tensor):
                if len(tval.shape) == 0:
                    # 0-dim tensor constant: extract the scalar dtype code.
                    tval = tval.item()
                    dtype = int(tval)
                else:
                    dtype = tval

        if sym_help._is_value(dtype) or isinstance(dtype, torch.Tensor):
            # aten::to(Tensor, Tensor, bool, bool, memory_format)
            # Cast to the scalar type of the reference tensor.
            dtype = args[0].type().scalarType()
            return g.op("Cast", self, to_i=sym_help.cast_pytorch_to_onnx[dtype])
        else:
            # aten::to(Tensor, ScalarType, bool, bool, memory_format)
            # memory_format is ignored
            return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
    elif len(args) == 5:
        # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format)
        dtype = sym_help._get_const(args[1], "i", "dtype")
        # memory_format is ignored
        return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
    elif len(args) == 6:
        # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format)
        dtype = sym_help._get_const(args[0], "i", "dtype")
        # Layout, device and memory_format are ignored
        return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
    elif len(args) == 7:
        # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format)
        dtype = sym_help._get_const(args[0], "i", "dtype")
        # Layout, device and memory_format are ignored
        return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
    else:
        return sym_help._onnx_unsupported("Unknown aten::to signature")
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
# TODO: Remove after PyTorch 1.11.1+ is used by detectron2's CI
|
| 384 |
+
def _pytorch1111_symbolic_opset9_repeat_interleave(g, self, repeats, dim=None, output_size=None):
    """aten::repeat_interleave symbolic for ONNX export with PyTorch < 1.11.1.

    Splits the input along `dim`, expands each slice by its repeat count,
    and concatenates the results back together.
    """

    # from torch.onnx.symbolic_helper import ScalarType
    from torch.onnx.symbolic_opset9 import expand, unsqueeze

    input = self
    # if dim is None flatten
    # By default, use the flattened input array, and return a flat output array
    if sym_help._is_none(dim):
        input = sym_help._reshape_helper(g, self, g.op("Constant", value_t=torch.tensor([-1])))
        dim = 0
    else:
        dim = sym_help._maybe_get_scalar(dim)

    repeats_dim = sym_help._get_tensor_rank(repeats)
    repeats_sizes = sym_help._get_tensor_sizes(repeats)
    input_sizes = sym_help._get_tensor_sizes(input)
    # Static rank/size information is required for this opset-9 lowering.
    if repeats_dim is None:
        raise RuntimeError(
            "Unsupported: ONNX export of repeat_interleave for unknown " "repeats rank."
        )
    if repeats_sizes is None:
        raise RuntimeError(
            "Unsupported: ONNX export of repeat_interleave for unknown " "repeats size."
        )
    if input_sizes is None:
        raise RuntimeError(
            "Unsupported: ONNX export of repeat_interleave for unknown " "input size."
        )

    # Replace unknown (None) dims: 0 in input_sizes, -1 in the temp copy used
    # to build Expand/Reshape target shapes.
    input_sizes_temp = input_sizes.copy()
    for idx, input_size in enumerate(input_sizes):
        if input_size is None:
            input_sizes[idx], input_sizes_temp[idx] = 0, -1

    # Cases where repeats is an int or single value tensor
    if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1):
        if not sym_help._is_tensor(repeats):
            repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
        if input_sizes[dim] == 0:
            return sym_help._onnx_opset_unsupported_detailed(
                "repeat_interleave",
                9,
                13,
                "Unsupported along dimension with unknown input size",
            )
        else:
            # Broadcast the single repeat count to one entry per slice.
            reps = input_sizes[dim]
            repeats = expand(g, repeats, g.op("Constant", value_t=torch.tensor([reps])), None)

    # Cases where repeats is a 1 dim Tensor
    elif repeats_dim == 1:
        if input_sizes[dim] == 0:
            return sym_help._onnx_opset_unsupported_detailed(
                "repeat_interleave",
                9,
                13,
                "Unsupported along dimension with unknown input size",
            )
        if repeats_sizes[0] is None:
            return sym_help._onnx_opset_unsupported_detailed(
                "repeat_interleave", 9, 13, "Unsupported for cases with dynamic repeats"
            )
        assert (
            repeats_sizes[0] == input_sizes[dim]
        ), "repeats must have the same size as input along dim"
        reps = repeats_sizes[0]
    else:
        raise RuntimeError("repeats must be 0-dim or 1-dim tensor")

    # Split both the repeats and the input into per-slice pieces; a single
    # torch._C.Value is returned (not a list) when there is only one piece.
    final_splits = list()
    r_splits = sym_help._repeat_interleave_split_helper(g, repeats, reps, 0)
    if isinstance(r_splits, torch._C.Value):
        r_splits = [r_splits]
    i_splits = sym_help._repeat_interleave_split_helper(g, input, reps, dim)
    if isinstance(i_splits, torch._C.Value):
        i_splits = [i_splits]
    input_sizes[dim], input_sizes_temp[dim] = -1, 1
    for idx, r_split in enumerate(r_splits):
        # Insert a unit axis after `dim`, expand it to the repeat count, then
        # fold it back into `dim` with a reshape.
        i_split = unsqueeze(g, i_splits[idx], dim + 1)
        r_concat = [
            g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[: dim + 1])),
            r_split,
            g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[dim + 1 :])),
        ]
        r_concat = g.op("Concat", *r_concat, axis_i=0)
        i_split = expand(g, i_split, r_concat, None)
        i_split = sym_help._reshape_helper(
            g,
            i_split,
            g.op("Constant", value_t=torch.LongTensor(input_sizes)),
            allowzero=0,
        )
        final_splits.append(i_split)
    return g.op("Concat", *final_splits, axis_i=dim)
|
detectron2/utils/tracing.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from detectron2.utils.env import TORCH_VERSION
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current
|
| 8 |
+
|
| 9 |
+
tracing_current_exists = True
|
| 10 |
+
except ImportError:
|
| 11 |
+
tracing_current_exists = False
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
from torch.fx._symbolic_trace import _orig_module_call
|
| 15 |
+
|
| 16 |
+
tracing_legacy_exists = True
|
| 17 |
+
except ImportError:
|
| 18 |
+
tracing_legacy_exists = False
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@torch.jit.ignore
def is_fx_tracing_legacy() -> bool:
    """
    Returns a bool indicating whether torch.fx is currently symbolically tracing a module.
    Can be useful for gating module logic that is incompatible with symbolic tracing.
    """
    # While tracing, torch.fx monkey-patches Module.__call__; detecting the
    # patch (current __call__ differs from the saved original) means tracing.
    return torch.nn.Module.__call__ is not _orig_module_call
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def is_fx_tracing() -> bool:
    """Returns whether execution is currently in
    Torch FX tracing mode"""
    # TorchScript execution is never FX tracing.
    if torch.jit.is_scripting():
        return False
    # Dispatch on which FX internals were importable at module load time
    # (`tracing_current_exists` / `tracing_legacy_exists` flags above).
    if TORCH_VERSION >= (1, 10) and tracing_current_exists:
        return is_fx_tracing_current()
    elif tracing_legacy_exists:
        return is_fx_tracing_legacy()
    else:
        # Can't find either current or legacy tracing indication code.
        # Enabling this assert_fx_safe() call regardless of tracing status.
        return False
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def assert_fx_safe(condition: bool, message: str) -> torch.Tensor:
    """An FX-tracing safe version of assert.
    Avoids erroneous type assertion triggering when types are masked inside
    an fx.proxy.Proxy object during tracing.
    Args: condition - either a boolean expression or a string representing
    the condition to test. If this assert triggers an exception when tracing
    due to dynamic control flow, try encasing the expression in quotation
    marks and supplying it as a string."""
    # Must return a concrete tensor for compatibility with PyTorch <=1.8.
    # If <=1.8 compatibility is not needed, return type can be converted to None
    if not (torch.jit.is_scripting() or is_fx_tracing()):
        return _do_assert_fx_safe(condition, message)
    return torch.zeros(1)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _do_assert_fx_safe(condition: bool, message: str) -> torch.Tensor:
    """
    Evaluate and assert `condition` (a bool or a string expression), returning
    a concrete tensor on success. FX trace errors are caught and reported
    instead of raised.
    """
    try:
        if isinstance(condition, str):
            # Evaluate the string in the caller's frame.
            # NOTE(review): f_back here is the frame of the function that
            # called this helper (i.e. assert_fx_safe), not the user's frame —
            # confirm string conditions resolve the intended names.
            caller_frame = inspect.currentframe().f_back
            torch._assert(eval(condition, caller_frame.f_globals, caller_frame.f_locals), message)
            return torch.ones(1)
        else:
            torch._assert(condition, message)
            return torch.ones(1)
    except torch.fx.proxy.TraceError as e:
        # Dynamic control flow under FX tracing: skip the check, report it,
        # and (implicitly) return None.
        print(
            "Found a non-FX compatible assertion. Skipping the check. Failure is shown below"
            + str(e)
        )
|
detectron2/utils/video_visualizer.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
import numpy as np
|
| 3 |
+
from typing import List
|
| 4 |
+
import pycocotools.mask as mask_util
|
| 5 |
+
|
| 6 |
+
from detectron2.structures import Instances
|
| 7 |
+
from detectron2.utils.visualizer import (
|
| 8 |
+
ColorMode,
|
| 9 |
+
Visualizer,
|
| 10 |
+
_create_text_labels,
|
| 11 |
+
_PanopticPrediction,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
from .colormap import random_color, random_colors
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class _DetectedInstance:
|
| 18 |
+
"""
|
| 19 |
+
Used to store data about detected objects in video frame,
|
| 20 |
+
in order to transfer color to objects in the future frames.
|
| 21 |
+
|
| 22 |
+
Attributes:
|
| 23 |
+
label (int):
|
| 24 |
+
bbox (tuple[float]):
|
| 25 |
+
mask_rle (dict):
|
| 26 |
+
color (tuple[float]): RGB colors in range (0, 1)
|
| 27 |
+
ttl (int): time-to-live for the instance. For example, if ttl=2,
|
| 28 |
+
the instance color can be transferred to objects in the next two frames.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
__slots__ = ["label", "bbox", "mask_rle", "color", "ttl"]
|
| 32 |
+
|
| 33 |
+
def __init__(self, label, bbox, mask_rle, color, ttl):
|
| 34 |
+
self.label = label
|
| 35 |
+
self.bbox = bbox
|
| 36 |
+
self.mask_rle = mask_rle
|
| 37 |
+
self.color = color
|
| 38 |
+
self.ttl = ttl
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class VideoVisualizer:
    """
    Draw detection / segmentation results on video frames while keeping colors
    temporally consistent: the same tracked object is drawn with the same color
    across frames (via IoU matching or explicit track ids).
    """

    def __init__(self, metadata, instance_mode=ColorMode.IMAGE):
        """
        Args:
            metadata (MetadataCatalog): image metadata.
            instance_mode (ColorMode): only ColorMode.IMAGE and ColorMode.IMAGE_BW
                are supported.
        """
        self.metadata = metadata
        # Instances carried over from previous frames, used by the IoU-based
        # color-transfer heuristic in _assign_colors().
        self._old_instances = []
        assert instance_mode in [
            ColorMode.IMAGE,
            ColorMode.IMAGE_BW,
        ], "Other mode not supported yet."
        self._instance_mode = instance_mode
        self._max_num_instances = self.metadata.get("max_num_instances", 74)
        # track id -> index into self._color_pool (used by _assign_colors_by_id)
        self._assigned_colors = {}
        self._color_pool = random_colors(self._max_num_instances, rgb=True, maximum=1)
        self._color_idx_set = set(range(len(self._color_pool)))

    def draw_instance_predictions(self, frame, predictions):
        """
        Draw instance-level prediction results on an image.

        Args:
            frame (ndarray): an RGB image of shape (H, W, C), in the range [0, 255].
            predictions (Instances): the output of an instance detection/segmentation
                model. Following fields will be used to draw:
                "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").

        Returns:
            output (VisImage): image object with visualizations.
        """
        frame_visualizer = Visualizer(frame, self.metadata)
        num_instances = len(predictions)
        if num_instances == 0:
            return frame_visualizer.output

        boxes = predictions.pred_boxes.tensor.numpy() if predictions.has("pred_boxes") else None
        scores = predictions.scores if predictions.has("scores") else None
        classes = predictions.pred_classes.numpy() if predictions.has("pred_classes") else None
        keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None
        colors = predictions.COLOR if predictions.has("COLOR") else [None] * len(predictions)
        periods = predictions.ID_period if predictions.has("ID_period") else None
        period_threshold = self.metadata.get("period_threshold", 0)
        # A tracked instance is drawn only once its track age exceeds the threshold;
        # without track periods everything is visible.
        visibilities = (
            [True] * len(predictions)
            if periods is None
            else [x > period_threshold for x in periods]
        )

        if predictions.has("pred_masks"):
            masks = predictions.pred_masks
            # mask IOU is not yet enabled
            # masks_rles = mask_util.encode(np.asarray(masks.permute(1, 2, 0), order="F"))
            # assert len(masks_rles) == num_instances
        else:
            masks = None

        if not predictions.has("COLOR"):
            if predictions.has("ID"):
                colors = self._assign_colors_by_id(predictions)
            else:
                # ToDo: clean old assign color method and use a default tracker to assign id
                detected = [
                    _DetectedInstance(classes[i], boxes[i], mask_rle=None, color=colors[i], ttl=8)
                    for i in range(num_instances)
                ]
                colors = self._assign_colors(detected)

        labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None))

        if self._instance_mode == ColorMode.IMAGE_BW:
            # any() returns uint8 tensor; compare against 0 to get a boolean mask
            frame_visualizer.output.reset_image(
                frame_visualizer._create_grayscale_image(
                    (masks.any(dim=0) > 0).numpy() if masks is not None else None
                )
            )
            alpha = 0.3
        else:
            alpha = 0.5

        # Drop labels/colors of instances that are not visible this frame.
        labels = (
            None
            if labels is None
            else [y[0] for y in filter(lambda x: x[1], zip(labels, visibilities))]
        )  # noqa
        assigned_colors = (
            None
            if colors is None
            else [y[0] for y in filter(lambda x: x[1], zip(colors, visibilities))]
        )  # noqa
        frame_visualizer.overlay_instances(
            boxes=None if masks is not None else boxes[visibilities],  # boxes are a bit distracting
            masks=None if masks is None else masks[visibilities],
            labels=labels,
            keypoints=None if keypoints is None else keypoints[visibilities],
            assigned_colors=assigned_colors,
            alpha=alpha,
        )

        return frame_visualizer.output

    def draw_sem_seg(self, frame, sem_seg, area_threshold=None):
        """
        Draw semantic segmentation on a single frame.

        Args:
            frame (ndarray): an RGB image of shape (H, W, C), in the range [0, 255].
            sem_seg (ndarray or Tensor): semantic segmentation of shape (H, W),
                each value is the integer label.
            area_threshold (Optional[int]): only draw segmentations larger than the threshold

        Returns:
            output (VisImage): image object with visualizations.
        """
        # Semantic segmentation needs no temporal bookkeeping; delegate per frame.
        frame_visualizer = Visualizer(frame, self.metadata)
        # BUGFIX: forward the caller's area_threshold instead of hard-coding None,
        # so the documented parameter actually takes effect. Default (None) behavior
        # is unchanged.
        frame_visualizer.draw_sem_seg(sem_seg, area_threshold=area_threshold)
        return frame_visualizer.output

    def draw_panoptic_seg_predictions(
        self, frame, panoptic_seg, segments_info, area_threshold=None, alpha=0.5
    ):
        """
        Draw panoptic predictions on a frame; instance ("thing") segments get
        temporally consistent colors via mask-IoU matching against past frames.
        """
        frame_visualizer = Visualizer(frame, self.metadata)
        pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)

        if self._instance_mode == ColorMode.IMAGE_BW:
            frame_visualizer.output.reset_image(
                frame_visualizer._create_grayscale_image(pred.non_empty_mask())
            )

        # draw mask for all semantic segments first i.e. "stuff"
        for mask, sinfo in pred.semantic_masks():
            category_idx = sinfo["category_id"]
            try:
                mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]
            except AttributeError:
                # metadata may not define stuff_colors; let the visualizer pick one
                mask_color = None

            frame_visualizer.draw_binary_mask(
                mask,
                color=mask_color,
                text=self.metadata.stuff_classes[category_idx],
                alpha=alpha,
                area_threshold=area_threshold,
            )

        all_instances = list(pred.instance_masks())
        if len(all_instances) == 0:
            return frame_visualizer.output
        # draw mask for all instances second
        masks, sinfo = list(zip(*all_instances))
        num_instances = len(masks)
        # RLE-encode the masks so _assign_colors can match by mask IoU.
        masks_rles = mask_util.encode(
            np.asarray(np.asarray(masks).transpose(1, 2, 0), dtype=np.uint8, order="F")
        )
        assert len(masks_rles) == num_instances

        category_ids = [x["category_id"] for x in sinfo]
        detected = [
            _DetectedInstance(category_ids[i], bbox=None, mask_rle=masks_rles[i], color=None, ttl=8)
            for i in range(num_instances)
        ]
        colors = self._assign_colors(detected)
        labels = [self.metadata.thing_classes[k] for k in category_ids]

        frame_visualizer.overlay_instances(
            boxes=None,
            masks=masks,
            labels=labels,
            keypoints=None,
            assigned_colors=colors,
            alpha=alpha,
        )
        return frame_visualizer.output

    def _assign_colors(self, instances):
        """
        Naive tracking heuristics to assign same color to the same instance,
        will update the internal state of tracked instances.

        Args:
            instances (list[_DetectedInstance]): detections of the current frame.

        Returns:
            list[tuple[float]]: list of colors.
        """

        # Compute iou with either boxes or masks:
        is_crowd = np.zeros((len(instances),), dtype=bool)
        if instances[0].bbox is None:
            assert instances[0].mask_rle is not None
            # use mask iou only when box iou is None
            # because box seems good enough
            rles_old = [x.mask_rle for x in self._old_instances]
            rles_new = [x.mask_rle for x in instances]
            ious = mask_util.iou(rles_old, rles_new, is_crowd)
            threshold = 0.5
        else:
            boxes_old = [x.bbox for x in self._old_instances]
            boxes_new = [x.bbox for x in instances]
            ious = mask_util.iou(boxes_old, boxes_new, is_crowd)
            threshold = 0.6
        if len(ious) == 0:
            ious = np.zeros((len(self._old_instances), len(instances)), dtype="float32")

        # Only allow matching instances of the same label:
        for old_idx, old in enumerate(self._old_instances):
            for new_idx, new in enumerate(instances):
                if old.label != new.label:
                    ious[old_idx, new_idx] = 0

        matched_new_per_old = np.asarray(ious).argmax(axis=1)
        max_iou_per_old = np.asarray(ious).max(axis=1)

        # Try to find match for each old instance:
        extra_instances = []
        for idx, inst in enumerate(self._old_instances):
            if max_iou_per_old[idx] > threshold:
                newidx = matched_new_per_old[idx]
                if instances[newidx].color is None:
                    instances[newidx].color = inst.color
                continue
            # If an old instance does not match any new instances,
            # keep it for the next frame in case it is just missed by the detector
            inst.ttl -= 1
            if inst.ttl > 0:
                extra_instances.append(inst)

        # Assign random color to newly-detected instances:
        for inst in instances:
            if inst.color is None:
                inst.color = random_color(rgb=True, maximum=1)
        self._old_instances = instances[:] + extra_instances
        return [d.color for d in instances]

    def _assign_colors_by_id(self, instances: Instances) -> List:
        """
        Assign a stable color to every tracked id in ``instances.ID``; colors of
        ids that disappeared since the previous call are recycled into the pool.
        """
        colors = []
        untracked_ids = set(self._assigned_colors.keys())
        for inst_id in instances.ID:  # renamed from `id` to avoid shadowing the builtin
            if inst_id in self._assigned_colors:
                colors.append(self._color_pool[self._assigned_colors[inst_id]])
                untracked_ids.remove(inst_id)
            else:
                assert (
                    len(self._color_idx_set) >= 1
                ), f"Number of id exceeded maximum, max = {self._max_num_instances}"
                idx = self._color_idx_set.pop()
                color = self._color_pool[idx]
                self._assigned_colors[inst_id] = idx
                colors.append(color)
        # Return colors of vanished tracks to the free pool.
        for inst_id in untracked_ids:
            self._color_idx_set.add(self._assigned_colors[inst_id])
            del self._assigned_colors[inst_id]
        return colors
|
detectron2/utils/visualizer.py
ADDED
|
@@ -0,0 +1,1281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
import colorsys
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
import numpy as np
|
| 6 |
+
from enum import Enum, unique
|
| 7 |
+
import cv2
|
| 8 |
+
import matplotlib as mpl
|
| 9 |
+
import matplotlib.colors as mplc
|
| 10 |
+
import matplotlib.figure as mplfigure
|
| 11 |
+
import pycocotools.mask as mask_util
|
| 12 |
+
import torch
|
| 13 |
+
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
| 14 |
+
from PIL import Image
|
| 15 |
+
|
| 16 |
+
from detectron2.data import MetadataCatalog
|
| 17 |
+
from detectron2.structures import BitMasks, Boxes, BoxMode, Keypoints, PolygonMasks, RotatedBoxes
|
| 18 |
+
from detectron2.utils.file_io import PathManager
|
| 19 |
+
|
| 20 |
+
from .colormap import random_color
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
__all__ = ["ColorMode", "VisImage", "Visualizer"]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
_SMALL_OBJECT_AREA_THRESH = 1000
|
| 28 |
+
_LARGE_MASK_AREA_THRESH = 120000
|
| 29 |
+
_OFF_WHITE = (1.0, 1.0, 240.0 / 255)
|
| 30 |
+
_BLACK = (0, 0, 0)
|
| 31 |
+
_RED = (1.0, 0, 0)
|
| 32 |
+
|
| 33 |
+
_KEYPOINT_THRESHOLD = 0.05
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@unique
class ColorMode(Enum):
    """
    Enum of different color modes to use for instance visualizations.
    """

    IMAGE = 0
    """
    Picks a random color for every instance and overlay segmentations with low opacity.
    """
    SEGMENTATION = 1
    """
    Let instances of the same category have similar colors
    (from metadata.thing_colors), and overlay them with
    high opacity. This provides more attention on the quality of segmentation.
    """
    IMAGE_BW = 2
    """
    Same as IMAGE, but convert all areas without masks to gray-scale.
    Only available for drawing per-instance mask predictions.
    """
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class GenericMask:
    """
    Wrap one object mask and convert lazily between representations.

    Attribute:
        polygons (list[ndarray]): polygons for this mask.
            Each ndarray has format [x, y, x, y, ...]
        mask (ndarray): a binary mask
    """

    def __init__(self, mask_or_polygons, height, width):
        self._mask = self._polygons = self._has_holes = None
        self.height = height
        self.width = width

        obj = mask_or_polygons
        if isinstance(obj, dict):
            # COCO RLE dict
            assert "counts" in obj and "size" in obj
            if isinstance(obj["counts"], list):  # uncompressed RLE -> compress first
                h, w = obj["size"]
                assert h == height and w == width
                obj = mask_util.frPyObjects(obj, h, w)
            self._mask = mask_util.decode(obj)[:, :]
        elif isinstance(obj, list):  # list of flat polygon coordinate arrays
            self._polygons = [np.asarray(p).reshape(-1) for p in obj]
        elif isinstance(obj, np.ndarray):  # assumed to be a binary mask
            assert obj.shape[1] != 2, obj.shape
            assert obj.shape == (
                height,
                width,
            ), f"mask shape: {obj.shape}, target dims: {height}, {width}"
            self._mask = obj.astype("uint8")
        else:
            raise ValueError(
                "GenericMask cannot handle object {} of type '{}'".format(obj, type(obj))
            )

    @property
    def mask(self):
        # Derived lazily from polygons on first access.
        if self._mask is None:
            self._mask = self.polygons_to_mask(self._polygons)
        return self._mask

    @property
    def polygons(self):
        # Derived lazily from the binary mask on first access.
        if self._polygons is None:
            self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
        return self._polygons

    @property
    def has_holes(self):
        if self._has_holes is None:
            if self._mask is None:
                # original format was polygons, which cannot encode holes
                self._has_holes = False
            else:
                self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
        return self._has_holes

    def mask_to_polygons(self, mask):
        # cv2.RETR_CCOMP retrieves contours in a 2-level hierarchy: external
        # boundaries in level 1, internal contours (holes) in level 2.
        # cv2.CHAIN_APPROX_NONE keeps every contour vertex.
        mask = np.ascontiguousarray(mask)  # some versions of cv2 do not support non-contiguous arrays
        res = cv2.findContours(mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
        hierarchy = res[-1]
        if hierarchy is None:  # empty mask
            return [], False
        has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
        # res[-2] is the contour list on both cv2 3.x and 4.x return conventions.
        contours = [c.flatten() for c in res[-2]]
        # OpenCV coordinates are integers in [0, W-1 or H-1]. Add 0.5 to move
        # them into real-valued coordinate space. (A better fix would be to
        # first +0.5 and then dilate the returned polygon by 0.5.)
        contours = [c + 0.5 for c in contours if len(c) >= 6]
        return contours, has_holes

    def polygons_to_mask(self, polygons):
        rle = mask_util.merge(mask_util.frPyObjects(polygons, self.height, self.width))
        return mask_util.decode(rle)[:, :]

    def area(self):
        return self.mask.sum()

    def bbox(self):
        merged = mask_util.merge(mask_util.frPyObjects(self.polygons, self.height, self.width))
        box = mask_util.toBbox(merged)
        # convert XYWH -> XYXY
        box[2] += box[0]
        box[3] += box[1]
        return box
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class _PanopticPrediction:
|
| 156 |
+
"""
|
| 157 |
+
Unify different panoptic annotation/prediction formats
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
def __init__(self, panoptic_seg, segments_info, metadata=None):
|
| 161 |
+
if segments_info is None:
|
| 162 |
+
assert metadata is not None
|
| 163 |
+
# If "segments_info" is None, we assume "panoptic_img" is a
|
| 164 |
+
# H*W int32 image storing the panoptic_id in the format of
|
| 165 |
+
# category_id * label_divisor + instance_id. We reserve -1 for
|
| 166 |
+
# VOID label.
|
| 167 |
+
label_divisor = metadata.label_divisor
|
| 168 |
+
segments_info = []
|
| 169 |
+
for panoptic_label in np.unique(panoptic_seg.numpy()):
|
| 170 |
+
if panoptic_label == -1:
|
| 171 |
+
# VOID region.
|
| 172 |
+
continue
|
| 173 |
+
pred_class = panoptic_label // label_divisor
|
| 174 |
+
isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
|
| 175 |
+
segments_info.append(
|
| 176 |
+
{
|
| 177 |
+
"id": int(panoptic_label),
|
| 178 |
+
"category_id": int(pred_class),
|
| 179 |
+
"isthing": bool(isthing),
|
| 180 |
+
}
|
| 181 |
+
)
|
| 182 |
+
del metadata
|
| 183 |
+
|
| 184 |
+
self._seg = panoptic_seg
|
| 185 |
+
|
| 186 |
+
self._sinfo = {s["id"]: s for s in segments_info} # seg id -> seg info
|
| 187 |
+
segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True)
|
| 188 |
+
areas = areas.numpy()
|
| 189 |
+
sorted_idxs = np.argsort(-areas)
|
| 190 |
+
self._seg_ids, self._seg_areas = segment_ids[sorted_idxs], areas[sorted_idxs]
|
| 191 |
+
self._seg_ids = self._seg_ids.tolist()
|
| 192 |
+
for sid, area in zip(self._seg_ids, self._seg_areas):
|
| 193 |
+
if sid in self._sinfo:
|
| 194 |
+
self._sinfo[sid]["area"] = float(area)
|
| 195 |
+
|
| 196 |
+
def non_empty_mask(self):
|
| 197 |
+
"""
|
| 198 |
+
Returns:
|
| 199 |
+
(H, W) array, a mask for all pixels that have a prediction
|
| 200 |
+
"""
|
| 201 |
+
empty_ids = []
|
| 202 |
+
for id in self._seg_ids:
|
| 203 |
+
if id not in self._sinfo:
|
| 204 |
+
empty_ids.append(id)
|
| 205 |
+
if len(empty_ids) == 0:
|
| 206 |
+
return np.zeros(self._seg.shape, dtype=np.uint8)
|
| 207 |
+
assert (
|
| 208 |
+
len(empty_ids) == 1
|
| 209 |
+
), ">1 ids corresponds to no labels. This is currently not supported"
|
| 210 |
+
return (self._seg != empty_ids[0]).numpy().astype(bool)
|
| 211 |
+
|
| 212 |
+
def semantic_masks(self):
|
| 213 |
+
for sid in self._seg_ids:
|
| 214 |
+
sinfo = self._sinfo.get(sid)
|
| 215 |
+
if sinfo is None or sinfo["isthing"]:
|
| 216 |
+
# Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions.
|
| 217 |
+
continue
|
| 218 |
+
yield (self._seg == sid).numpy().astype(bool), sinfo
|
| 219 |
+
|
| 220 |
+
def instance_masks(self):
|
| 221 |
+
for sid in self._seg_ids:
|
| 222 |
+
sinfo = self._sinfo.get(sid)
|
| 223 |
+
if sinfo is None or not sinfo["isthing"]:
|
| 224 |
+
continue
|
| 225 |
+
mask = (self._seg == sid).numpy().astype(bool)
|
| 226 |
+
if mask.sum() > 0:
|
| 227 |
+
yield mask, sinfo
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def _create_text_labels(classes, scores, class_names, is_crowd=None):
|
| 231 |
+
"""
|
| 232 |
+
Args:
|
| 233 |
+
classes (list[int] or None):
|
| 234 |
+
scores (list[float] or None):
|
| 235 |
+
class_names (list[str] or None):
|
| 236 |
+
is_crowd (list[bool] or None):
|
| 237 |
+
|
| 238 |
+
Returns:
|
| 239 |
+
list[str] or None
|
| 240 |
+
"""
|
| 241 |
+
labels = None
|
| 242 |
+
if classes is not None:
|
| 243 |
+
if class_names is not None and len(class_names) > 0:
|
| 244 |
+
labels = [class_names[i] for i in classes]
|
| 245 |
+
else:
|
| 246 |
+
labels = [str(i) for i in classes]
|
| 247 |
+
if scores is not None:
|
| 248 |
+
if labels is None:
|
| 249 |
+
labels = ["{:.0f}%".format(s * 100) for s in scores]
|
| 250 |
+
else:
|
| 251 |
+
labels = ["{} {:.0f}%".format(l, s * 100) for l, s in zip(labels, scores)]
|
| 252 |
+
if labels is not None and is_crowd is not None:
|
| 253 |
+
labels = [l + ("|crowd" if crowd else "") for l, crowd in zip(labels, is_crowd)]
|
| 254 |
+
return labels
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class VisImage:
    """
    Holds an image together with the matplotlib figure/axes used to draw on it.
    """

    def __init__(self, img, scale=1.0):
        """
        Args:
            img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
            scale (float): scale the input image
        """
        self.img = img
        self.scale = scale
        self.height, self.width = img.shape[0], img.shape[1]
        self._setup_figure(img)

    def _setup_figure(self, img):
        """
        Build a frameless figure whose canvas matches the (scaled) image size.

        Args:
            Same as in :meth:`__init__()`.

        Returns:
            fig (matplotlib.pyplot.figure): top level container for all the image plot elements.
            ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system.
        """
        fig = mplfigure.Figure(frameon=False)
        self.dpi = fig.get_dpi()
        # add a small 1e-2 to avoid precision lost due to matplotlib's truncation
        # (https://github.com/matplotlib/matplotlib/issues/15363)
        width_inches = (self.width * self.scale + 1e-2) / self.dpi
        height_inches = (self.height * self.scale + 1e-2) / self.dpi
        fig.set_size_inches(width_inches, height_inches)
        self.canvas = FigureCanvasAgg(fig)
        ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
        ax.axis("off")
        self.fig = fig
        self.ax = ax
        self.reset_image(img)

    def reset_image(self, img):
        """
        Args:
            img: same as in __init__
        """
        # Flipped y-extent keeps image coordinates (origin at top-left).
        self.ax.imshow(
            img.astype("uint8"),
            extent=(0, self.width, self.height, 0),
            interpolation="nearest",
        )

    def save(self, filepath):
        """
        Args:
            filepath (str): a string that contains the absolute path, including the file name, where
                the visualized image will be saved.
        """
        self.fig.savefig(filepath)

    def get_image(self):
        """
        Returns:
            ndarray:
                the visualized image of shape (H, W, 3) (RGB) in uint8 type.
                The shape is scaled w.r.t the input image using the given `scale` argument.
        """
        raw, (width, height) = self.canvas.print_to_buffer()
        rgba = np.frombuffer(raw, dtype="uint8").reshape(height, width, 4)
        rgb, _alpha = np.split(rgba, [3], axis=2)
        return rgb.astype("uint8")
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
class Visualizer:
|
| 332 |
+
"""
|
| 333 |
+
Visualizer that draws data about detection/segmentation on images.
|
| 334 |
+
|
| 335 |
+
It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}`
|
| 336 |
+
that draw primitive objects to images, as well as high-level wrappers like
|
| 337 |
+
`draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}`
|
| 338 |
+
that draw composite data in some pre-defined style.
|
| 339 |
+
|
| 340 |
+
Note that the exact visualization style for the high-level wrappers are subject to change.
|
| 341 |
+
Style such as color, opacity, label contents, visibility of labels, or even the visibility
|
| 342 |
+
of objects themselves (e.g. when the object is too small) may change according
|
| 343 |
+
to different heuristics, as long as the results still look visually reasonable.
|
| 344 |
+
|
| 345 |
+
To obtain a consistent style, you can implement custom drawing functions with the
|
| 346 |
+
abovementioned primitive methods instead. If you need more customized visualization
|
| 347 |
+
styles, you can process the data yourself following their format documented in
|
| 348 |
+
tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not
|
| 349 |
+
intend to satisfy everyone's preference on drawing styles.
|
| 350 |
+
|
| 351 |
+
This visualizer focuses on high rendering quality rather than performance. It is not
|
| 352 |
+
designed to be used for real-time applications.
|
| 353 |
+
"""
|
| 354 |
+
|
| 355 |
+
# TODO implement a fast, rasterized version using OpenCV
|
| 356 |
+
|
| 357 |
+
    def __init__(
        self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE, font_size_scale=1.0
    ):
        """
        Args:
            img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
                the height and width of the image respectively. C is the number of
                color channels. The image is required to be in RGB format since that
                is a requirement of the Matplotlib library. The image is also expected
                to be in the range [0, 255].
            metadata (Metadata): dataset metadata (e.g. class names and colors)
            scale (float): scale factor applied to the rendered visualization.
            instance_mode (ColorMode): defines one of the pre-defined style for drawing
                instances on an image.
            font_size_scale: extra scaling of font size on top of default font size
        """
        # Clip to the valid pixel range before converting to avoid uint8 wrap-around.
        self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
        if metadata is None:
            # Empty metadata: attribute lookups via .get() fall back to defaults.
            metadata = MetadataCatalog.get("__nonexist__")
        self.metadata = metadata
        self.output = VisImage(self.img, scale=scale)
        self.cpu_device = torch.device("cpu")

        # too small texts are useless, therefore clamp to 10 // scale
        self._default_font_size = (
            max(np.sqrt(self.output.height * self.output.width) // 90, 10 // scale)
            * font_size_scale
        )
        self._instance_mode = instance_mode
        self.keypoint_threshold = _KEYPOINT_THRESHOLD
|
| 386 |
+
|
| 387 |
+
    def draw_instance_predictions(self, predictions, jittering: bool = True):
        """
        Draw instance-level prediction results on an image.

        Args:
            predictions (Instances): the output of an instance detection/segmentation
                model. Following fields will be used to draw:
                "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").
            jittering: if True, in color mode SEGMENTATION, randomly jitter the colors per class
                to distinguish instances from the same class

        Returns:
            output (VisImage): image object with visualizations.
        """
        # Every field is optional; missing fields become None and are skipped downstream.
        boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
        scores = predictions.scores if predictions.has("scores") else None
        classes = predictions.pred_classes.tolist() if predictions.has("pred_classes") else None
        labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None))
        keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None

        if predictions.has("pred_masks"):
            masks = np.asarray(predictions.pred_masks)
            masks = [GenericMask(x, self.output.height, self.output.width) for x in masks]
        else:
            masks = None

        if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"):
            # Use per-class metadata colors (optionally jittered per instance so
            # instances of the same class remain distinguishable).
            colors = (
                [self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes]
                if jittering
                else [
                    tuple(mplc.to_rgb([x / 255 for x in self.metadata.thing_colors[c]]))
                    for c in classes
                ]
            )

            alpha = 0.8
        else:
            colors = None
            alpha = 0.5

        if self._instance_mode == ColorMode.IMAGE_BW:
            # Gray out everything outside the union of predicted masks.
            self.output.reset_image(
                self._create_grayscale_image(
                    (predictions.pred_masks.any(dim=0) > 0).numpy()
                    if predictions.has("pred_masks")
                    else None
                )
            )
            alpha = 0.3

        self.overlay_instances(
            masks=masks,
            boxes=boxes,
            labels=labels,
            keypoints=keypoints,
            assigned_colors=colors,
            alpha=alpha,
        )
        return self.output
|
| 447 |
+
|
| 448 |
+
    def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8):
        """
        Draw semantic segmentation predictions/labels.

        Args:
            sem_seg (Tensor or ndarray): the segmentation of shape (H, W).
                Each value is the integer label of the pixel.
            area_threshold (int): segments with less than `area_threshold` are not drawn.
            alpha (float): the larger it is, the more opaque the segmentations are.

        Returns:
            output (VisImage): image object with visualizations.
        """
        if isinstance(sem_seg, torch.Tensor):
            sem_seg = sem_seg.numpy()
        # Draw larger segments first so small ones are not fully occluded.
        labels, areas = np.unique(sem_seg, return_counts=True)
        sorted_idxs = np.argsort(-areas).tolist()
        labels = labels[sorted_idxs]
        # Labels outside the known class range (e.g. "ignore" values) are skipped.
        for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels):
            try:
                mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]
            except (AttributeError, IndexError):
                # No color metadata for this class: let draw_binary_mask pick one.
                mask_color = None

            binary_mask = (sem_seg == label).astype(np.uint8)
            text = self.metadata.stuff_classes[label]
            self.draw_binary_mask(
                binary_mask,
                color=mask_color,
                edge_color=_OFF_WHITE,
                text=text,
                alpha=alpha,
                area_threshold=area_threshold,
            )
        return self.output
|
| 483 |
+
|
| 484 |
+
    def draw_panoptic_seg(self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7):
        """
        Draw panoptic prediction annotations or results.

        Args:
            panoptic_seg (Tensor): of shape (height, width) where the values are ids for each
                segment.
            segments_info (list[dict] or None): Describe each segment in `panoptic_seg`.
                If it is a ``list[dict]``, each dict contains keys "id", "category_id".
                If None, category id of each pixel is computed by
                ``pixel // metadata.label_divisor``.
            area_threshold (int): stuff segments with less than `area_threshold` are not drawn.
            alpha (float): the larger it is, the more opaque the segmentations are.

        Returns:
            output (VisImage): image object with visualizations.
        """
        pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)

        if self._instance_mode == ColorMode.IMAGE_BW:
            # Gray out pixels not covered by any segment.
            self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask()))

        # draw mask for all semantic segments first i.e. "stuff"
        for mask, sinfo in pred.semantic_masks():
            category_idx = sinfo["category_id"]
            try:
                mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]
            except AttributeError:
                # No stuff_colors metadata: let draw_binary_mask pick a color.
                mask_color = None

            text = self.metadata.stuff_classes[category_idx]
            self.draw_binary_mask(
                mask,
                color=mask_color,
                edge_color=_OFF_WHITE,
                text=text,
                alpha=alpha,
                area_threshold=area_threshold,
            )

        # draw mask for all instances second
        all_instances = list(pred.instance_masks())
        if len(all_instances) == 0:
            return self.output
        masks, sinfo = list(zip(*all_instances))
        category_ids = [x["category_id"] for x in sinfo]

        try:
            # Scores exist for predictions but not for ground-truth annotations.
            scores = [x["score"] for x in sinfo]
        except KeyError:
            scores = None
        labels = _create_text_labels(
            category_ids, scores, self.metadata.thing_classes, [x.get("iscrowd", 0) for x in sinfo]
        )

        try:
            colors = [
                self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in category_ids
            ]
        except AttributeError:
            colors = None
        self.overlay_instances(masks=masks, labels=labels, assigned_colors=colors, alpha=alpha)

        return self.output

    draw_panoptic_seg_predictions = draw_panoptic_seg  # backward compatibility
|
| 549 |
+
|
| 550 |
+
    def draw_dataset_dict(self, dic):
        """
        Draw annotations/segmentations in Detectron2 Dataset format.

        Args:
            dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format.

        Returns:
            output (VisImage): image object with visualizations.
        """
        annos = dic.get("annotations", None)
        if annos:
            # Field presence is checked on the first annotation only; the dataset
            # format requires all annotations of one image to share the same fields.
            if "segmentation" in annos[0]:
                masks = [x["segmentation"] for x in annos]
            else:
                masks = None
            if "keypoints" in annos[0]:
                keypts = [x["keypoints"] for x in annos]
                keypts = np.array(keypts).reshape(len(annos), -1, 3)
            else:
                keypts = None

            # Convert 4-element boxes to XYXY_ABS; rotated (5-element) boxes pass through.
            boxes = [
                (
                    BoxMode.convert(x["bbox"], x["bbox_mode"], BoxMode.XYXY_ABS)
                    if len(x["bbox"]) == 4
                    else x["bbox"]
                )
                for x in annos
            ]

            colors = None
            category_ids = [x["category_id"] for x in annos]
            if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"):
                colors = [
                    self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
                    for c in category_ids
                ]
            names = self.metadata.get("thing_classes", None)
            labels = _create_text_labels(
                category_ids,
                scores=None,
                class_names=names,
                is_crowd=[x.get("iscrowd", 0) for x in annos],
            )
            self.overlay_instances(
                labels=labels, boxes=boxes, masks=masks, keypoints=keypts, assigned_colors=colors
            )

        sem_seg = dic.get("sem_seg", None)
        if sem_seg is None and "sem_seg_file_name" in dic:
            # Lazily load semantic segmentation from file when not provided inline.
            with PathManager.open(dic["sem_seg_file_name"], "rb") as f:
                sem_seg = Image.open(f)
                sem_seg = np.asarray(sem_seg, dtype="uint8")
        if sem_seg is not None:
            self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.5)

        pan_seg = dic.get("pan_seg", None)
        if pan_seg is None and "pan_seg_file_name" in dic:
            with PathManager.open(dic["pan_seg_file_name"], "rb") as f:
                pan_seg = Image.open(f)
                pan_seg = np.asarray(pan_seg)
                # COCO panoptic PNGs encode segment ids in RGB channels.
                from panopticapi.utils import rgb2id

                pan_seg = rgb2id(pan_seg)
        if pan_seg is not None:
            segments_info = dic["segments_info"]
            pan_seg = torch.tensor(pan_seg)
            self.draw_panoptic_seg(pan_seg, segments_info, area_threshold=0, alpha=0.5)
        return self.output
|
| 620 |
+
|
| 621 |
+
    def overlay_instances(
        self,
        *,
        boxes=None,
        labels=None,
        masks=None,
        keypoints=None,
        assigned_colors=None,
        alpha=0.5,
    ):
        """
        Args:
            boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`,
                or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image,
                or a :class:`RotatedBoxes`,
                or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format
                for the N objects in a single image,
            labels (list[str]): the text to be displayed for each instance.
            masks (masks-like object): Supported types are:

                * :class:`detectron2.structures.PolygonMasks`,
                  :class:`detectron2.structures.BitMasks`.
                * list[list[ndarray]]: contains the segmentation masks for all objects in one image.
                  The first level of the list corresponds to individual instances. The second
                  level to all the polygon that compose the instance, and the third level
                  to the polygon coordinates. The third level should have the format of
                  [x0, y0, x1, y1, ..., xn, yn] (n >= 3).
                * list[ndarray]: each ndarray is a binary mask of shape (H, W).
                * list[dict]: each dict is a COCO-style RLE.
            keypoints (Keypoint or array like): an array-like object of shape (N, K, 3),
                where the N is the number of instances and K is the number of keypoints.
                The last dimension corresponds to (x, y, visibility or score).
            assigned_colors (list[matplotlib.colors]): a list of colors, where each color
                corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
                for full list of formats that the colors are accepted in.
            alpha (float): blending coefficient for masks/polygons; smaller is more transparent.
        Returns:
            output (VisImage): image object with visualizations.
        """
        # Determine the instance count from whichever input is given; all given
        # inputs must agree on the count.
        num_instances = 0
        if boxes is not None:
            boxes = self._convert_boxes(boxes)
            num_instances = len(boxes)
        if masks is not None:
            masks = self._convert_masks(masks)
            if num_instances:
                assert len(masks) == num_instances
            else:
                num_instances = len(masks)
        if keypoints is not None:
            if num_instances:
                assert len(keypoints) == num_instances
            else:
                num_instances = len(keypoints)
            keypoints = self._convert_keypoints(keypoints)
        if labels is not None:
            assert len(labels) == num_instances
        if assigned_colors is None:
            assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
        if num_instances == 0:
            return self.output
        # Nx5 boxes are rotated boxes; delegate to the rotated-instance path.
        if boxes is not None and boxes.shape[1] == 5:
            return self.overlay_rotated_instances(
                boxes=boxes, labels=labels, assigned_colors=assigned_colors
            )

        # Display in largest to smallest order to reduce occlusion.
        areas = None
        if boxes is not None:
            areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
        elif masks is not None:
            areas = np.asarray([x.area() for x in masks])

        if areas is not None:
            sorted_idxs = np.argsort(-areas).tolist()
            # Re-order overlapped instances in descending order.
            boxes = boxes[sorted_idxs] if boxes is not None else None
            labels = [labels[k] for k in sorted_idxs] if labels is not None else None
            masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None
            assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]
            keypoints = keypoints[sorted_idxs] if keypoints is not None else None

        for i in range(num_instances):
            color = assigned_colors[i]
            if boxes is not None:
                self.draw_box(boxes[i], edge_color=color)

            if masks is not None:
                for segment in masks[i].polygons:
                    self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha)

            if labels is not None:
                # first get a box
                if boxes is not None:
                    x0, y0, x1, y1 = boxes[i]
                    text_pos = (x0, y0)  # if drawing boxes, put text on the box corner.
                    horiz_align = "left"
                elif masks is not None:
                    # skip small mask without polygon
                    if len(masks[i].polygons) == 0:
                        continue

                    x0, y0, x1, y1 = masks[i].bbox()

                    # draw text in the center (defined by median) when box is not drawn
                    # median is less sensitive to outliers.
                    text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1]
                    horiz_align = "center"
                else:
                    continue  # drawing the box confidence for keypoints isn't very useful.
                # for small objects, draw text at the side to avoid occlusion
                instance_area = (y1 - y0) * (x1 - x0)
                if (
                    instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale
                    or y1 - y0 < 40 * self.output.scale
                ):
                    if y1 >= self.output.height - 5:
                        text_pos = (x1, y0)
                    else:
                        text_pos = (x0, y1)

                # Scale the font with the instance height, clamped to a sane range.
                height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width)
                lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
                font_size = (
                    np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
                    * 0.5
                    * self._default_font_size
                )
                self.draw_text(
                    labels[i],
                    text_pos,
                    color=lighter_color,
                    horizontal_alignment=horiz_align,
                    font_size=font_size,
                )

        # draw keypoints
        if keypoints is not None:
            for keypoints_per_instance in keypoints:
                self.draw_and_connect_keypoints(keypoints_per_instance)

        return self.output
|
| 762 |
+
|
| 763 |
+
    def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None):
        """
        Args:
            boxes (ndarray): an Nx5 numpy array of
                (x_center, y_center, width, height, angle_degrees) format
                for the N objects in a single image.
            labels (list[str]): the text to be displayed for each instance.
            assigned_colors (list[matplotlib.colors]): a list of colors, where each color
                corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
                for full list of formats that the colors are accepted in.

        Returns:
            output (VisImage): image object with visualizations.
        """
        num_instances = len(boxes)

        if assigned_colors is None:
            assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
        if num_instances == 0:
            return self.output

        # Display in largest to smallest order to reduce occlusion.
        if boxes is not None:
            areas = boxes[:, 2] * boxes[:, 3]

        sorted_idxs = np.argsort(-areas).tolist()
        # Re-order overlapped instances in descending order.
        boxes = boxes[sorted_idxs]
        labels = [labels[k] for k in sorted_idxs] if labels is not None else None
        colors = [assigned_colors[idx] for idx in sorted_idxs]

        for i in range(num_instances):
            self.draw_rotated_box_with_label(
                boxes[i], edge_color=colors[i], label=labels[i] if labels is not None else None
            )

        return self.output
|
| 800 |
+
|
| 801 |
+
    def draw_and_connect_keypoints(self, keypoints):
        """
        Draws keypoints of an instance and follows the rules for keypoint connections
        to draw lines between appropriate keypoints. This follows color heuristics for
        line color.

        Args:
            keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints
                and the last dimension corresponds to (x, y, probability).

        Returns:
            output (VisImage): image object with visualizations.
        """
        # Map keypoint name -> (x, y), populated only for confident keypoints.
        visible = {}
        keypoint_names = self.metadata.get("keypoint_names")
        for idx, keypoint in enumerate(keypoints):

            # draw keypoint
            x, y, prob = keypoint
            if prob > self.keypoint_threshold:
                self.draw_circle((x, y), color=_RED)
                if keypoint_names:
                    keypoint_name = keypoint_names[idx]
                    visible[keypoint_name] = (x, y)

        if self.metadata.get("keypoint_connection_rules"):
            # Connect keypoint pairs per the dataset's connection rules; a line is
            # drawn only when both endpoints are visible.
            for kp0, kp1, color in self.metadata.keypoint_connection_rules:
                if kp0 in visible and kp1 in visible:
                    x0, y0 = visible[kp0]
                    x1, y1 = visible[kp1]
                    color = tuple(x / 255.0 for x in color)
                    self.draw_line([x0, x1], [y0, y1], color=color)

        # draw lines from nose to mid-shoulder and mid-shoulder to mid-hip
        # Note that this strategy is specific to person keypoints.
        # For other keypoints, it should just do nothing
        try:
            ls_x, ls_y = visible["left_shoulder"]
            rs_x, rs_y = visible["right_shoulder"]
            mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2
        except KeyError:
            # Not a person skeleton (or shoulders not visible): nothing more to draw.
            pass
        else:
            # draw line from nose to mid-shoulder
            nose_x, nose_y = visible.get("nose", (None, None))
            if nose_x is not None:
                self.draw_line([nose_x, mid_shoulder_x], [nose_y, mid_shoulder_y], color=_RED)

            try:
                # draw line from mid-shoulder to mid-hip
                lh_x, lh_y = visible["left_hip"]
                rh_x, rh_y = visible["right_hip"]
            except KeyError:
                pass
            else:
                mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2
                self.draw_line([mid_hip_x, mid_shoulder_x], [mid_hip_y, mid_shoulder_y], color=_RED)
        return self.output
|
| 859 |
+
|
| 860 |
+
"""
|
| 861 |
+
Primitive drawing functions:
|
| 862 |
+
"""
|
| 863 |
+
|
| 864 |
+
    def draw_text(
        self,
        text,
        position,
        *,
        font_size=None,
        color="g",
        horizontal_alignment="center",
        rotation=0,
    ):
        """
        Args:
            text (str): class label
            position (tuple): a tuple of the x and y coordinates to place text on image.
            font_size (int, optional): font of the text. If not provided, a font size
                proportional to the image width is calculated and used.
            color: color of the text. Refer to `matplotlib.colors` for full list
                of formats that are accepted.
            horizontal_alignment (str): see `matplotlib.text.Text`
            rotation: rotation angle in degrees CCW

        Returns:
            output (VisImage): image object with text drawn.
        """
        if not font_size:
            font_size = self._default_font_size

        # since the text background is dark, we don't want the text to be dark
        # Raise each channel to at least 0.2, then boost the dominant channel
        # to at least 0.8 so the text stays bright against the black bbox.
        color = np.maximum(list(mplc.to_rgb(color)), 0.2)
        color[np.argmax(color)] = max(0.8, np.max(color))

        x, y = position
        self.output.ax.text(
            x,
            y,
            text,
            size=font_size * self.output.scale,
            family="sans-serif",
            bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
            verticalalignment="top",
            horizontalalignment=horizontal_alignment,
            color=color,
            zorder=10,  # draw text above boxes/masks
            rotation=rotation,
        )
        return self.output
|
| 910 |
+
|
| 911 |
+
def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"):
|
| 912 |
+
"""
|
| 913 |
+
Args:
|
| 914 |
+
box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0
|
| 915 |
+
are the coordinates of the image's top left corner. x1 and y1 are the
|
| 916 |
+
coordinates of the image's bottom right corner.
|
| 917 |
+
alpha (float): blending efficient. Smaller values lead to more transparent masks.
|
| 918 |
+
edge_color: color of the outline of the box. Refer to `matplotlib.colors`
|
| 919 |
+
for full list of formats that are accepted.
|
| 920 |
+
line_style (string): the string to use to create the outline of the boxes.
|
| 921 |
+
|
| 922 |
+
Returns:
|
| 923 |
+
output (VisImage): image object with box drawn.
|
| 924 |
+
"""
|
| 925 |
+
x0, y0, x1, y1 = box_coord
|
| 926 |
+
width = x1 - x0
|
| 927 |
+
height = y1 - y0
|
| 928 |
+
|
| 929 |
+
linewidth = max(self._default_font_size / 4, 1)
|
| 930 |
+
|
| 931 |
+
self.output.ax.add_patch(
|
| 932 |
+
mpl.patches.Rectangle(
|
| 933 |
+
(x0, y0),
|
| 934 |
+
width,
|
| 935 |
+
height,
|
| 936 |
+
fill=False,
|
| 937 |
+
edgecolor=edge_color,
|
| 938 |
+
linewidth=linewidth * self.output.scale,
|
| 939 |
+
alpha=alpha,
|
| 940 |
+
linestyle=line_style,
|
| 941 |
+
)
|
| 942 |
+
)
|
| 943 |
+
return self.output
|
| 944 |
+
|
| 945 |
+
    def draw_rotated_box_with_label(
        self, rotated_box, alpha=0.5, edge_color="g", line_style="-", label=None
    ):
        """
        Draw a rotated box with label on its top-left corner.

        Args:
            rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle),
                where cnt_x and cnt_y are the center coordinates of the box.
                w and h are the width and height of the box. angle represents how
                many degrees the box is rotated CCW with regard to the 0-degree box.
            alpha (float): blending coefficient. Smaller values lead to more transparent boxes.
            edge_color: color of the outline of the box. Refer to `matplotlib.colors`
                for full list of formats that are accepted.
            line_style (string): the string to use to create the outline of the boxes.
            label (string): label for rotated box. It will not be rendered when set to None.

        Returns:
            output (VisImage): image object with box drawn.
        """
        cnt_x, cnt_y, w, h, angle = rotated_box
        area = w * h
        # use thinner lines when the box is small
        linewidth = self._default_font_size / (
            6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3
        )

        theta = angle * math.pi / 180.0
        c = math.cos(theta)
        s = math.sin(theta)
        # Corners of the axis-aligned box, centered at the origin.
        rect = [(-w / 2, h / 2), (-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2)]
        # x: left->right ; y: top->down
        # Rotate each corner CCW by theta and translate to the box center.
        rotated_rect = [(s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect]
        for k in range(4):
            j = (k + 1) % 4
            self.draw_line(
                [rotated_rect[k][0], rotated_rect[j][0]],
                [rotated_rect[k][1], rotated_rect[j][1]],
                color=edge_color,
                # One dashed edge marks the box's top side/orientation.
                linestyle="--" if k == 1 else line_style,
                linewidth=linewidth,
            )

        if label is not None:
            text_pos = rotated_rect[1]  # topleft corner

            height_ratio = h / np.sqrt(self.output.height * self.output.width)
            label_color = self._change_color_brightness(edge_color, brightness_factor=0.7)
            font_size = (
                np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 * self._default_font_size
            )
            self.draw_text(label, text_pos, color=label_color, font_size=font_size, rotation=angle)

        return self.output
|
| 999 |
+
|
| 1000 |
+
def draw_circle(self, circle_coord, color, radius=3):
|
| 1001 |
+
"""
|
| 1002 |
+
Args:
|
| 1003 |
+
circle_coord (list(int) or tuple(int)): contains the x and y coordinates
|
| 1004 |
+
of the center of the circle.
|
| 1005 |
+
color: color of the polygon. Refer to `matplotlib.colors` for a full list of
|
| 1006 |
+
formats that are accepted.
|
| 1007 |
+
radius (int): radius of the circle.
|
| 1008 |
+
|
| 1009 |
+
Returns:
|
| 1010 |
+
output (VisImage): image object with box drawn.
|
| 1011 |
+
"""
|
| 1012 |
+
x, y = circle_coord
|
| 1013 |
+
self.output.ax.add_patch(
|
| 1014 |
+
mpl.patches.Circle(circle_coord, radius=radius, fill=True, color=color)
|
| 1015 |
+
)
|
| 1016 |
+
return self.output
|
| 1017 |
+
|
| 1018 |
+
def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None):
|
| 1019 |
+
"""
|
| 1020 |
+
Args:
|
| 1021 |
+
x_data (list[int]): a list containing x values of all the points being drawn.
|
| 1022 |
+
Length of list should match the length of y_data.
|
| 1023 |
+
y_data (list[int]): a list containing y values of all the points being drawn.
|
| 1024 |
+
Length of list should match the length of x_data.
|
| 1025 |
+
color: color of the line. Refer to `matplotlib.colors` for a full list of
|
| 1026 |
+
formats that are accepted.
|
| 1027 |
+
linestyle: style of the line. Refer to `matplotlib.lines.Line2D`
|
| 1028 |
+
for a full list of formats that are accepted.
|
| 1029 |
+
linewidth (float or None): width of the line. When it's None,
|
| 1030 |
+
a default value will be computed and used.
|
| 1031 |
+
|
| 1032 |
+
Returns:
|
| 1033 |
+
output (VisImage): image object with line drawn.
|
| 1034 |
+
"""
|
| 1035 |
+
if linewidth is None:
|
| 1036 |
+
linewidth = self._default_font_size / 3
|
| 1037 |
+
linewidth = max(linewidth, 1)
|
| 1038 |
+
self.output.ax.add_line(
|
| 1039 |
+
mpl.lines.Line2D(
|
| 1040 |
+
x_data,
|
| 1041 |
+
y_data,
|
| 1042 |
+
linewidth=linewidth * self.output.scale,
|
| 1043 |
+
color=color,
|
| 1044 |
+
linestyle=linestyle,
|
| 1045 |
+
)
|
| 1046 |
+
)
|
| 1047 |
+
return self.output
|
| 1048 |
+
|
| 1049 |
+
    def draw_binary_mask(
        self, binary_mask, color=None, *, edge_color=None, text=None, alpha=0.5, area_threshold=10
    ):
        """
        Draw a binary mask, as filled polygons when the mask has no holes, or as a
        translucent RGBA overlay otherwise.

        Args:
            binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and
                W is the image width. Each value in the array is either a 0 or 1 value of uint8
                type.
            color: color of the mask. Refer to `matplotlib.colors` for a full list of
                formats that are accepted. If None, will pick a random color.
            edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
                full list of formats that are accepted.
            text (str): if not None, will be drawn on the object.
            alpha (float): blending coefficient. Smaller values lead to more transparent masks.
            area_threshold (float): a connected component smaller than this area will not be shown.

        Returns:
            output (VisImage): image object with mask drawn.
        """
        if color is None:
            color = random_color(rgb=True, maximum=1)
        color = mplc.to_rgb(color)

        has_valid_segment = False
        binary_mask = binary_mask.astype("uint8")  # opencv needs uint8
        mask = GenericMask(binary_mask, self.output.height, self.output.width)
        shape2d = (binary_mask.shape[0], binary_mask.shape[1])

        if not mask.has_holes:
            # draw polygons for regular masks
            for segment in mask.polygons:
                # pixel area of this polygon, computed via pycocotools RLE utilities
                area = mask_util.area(mask_util.frPyObjects([segment], shape2d[0], shape2d[1]))
                # `area_threshold or 0`: a None threshold disables filtering
                if area < (area_threshold or 0):
                    continue
                has_valid_segment = True
                segment = segment.reshape(-1, 2)  # flat [x0, y0, x1, y1, ...] -> Nx2 points
                self.draw_polygon(segment, color=color, edge_color=edge_color, alpha=alpha)
        else:
            # Masks with holes fall back to a raster RGBA overlay instead of polygons.
            # TODO: Use Path/PathPatch to draw vector graphics:
            # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon
            rgba = np.zeros(shape2d + (4,), dtype="float32")
            rgba[:, :, :3] = color
            rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha
            has_valid_segment = True
            self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))

        if text is not None and has_valid_segment:
            # use a lighter shade so the label stays readable on top of the mask
            lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
            self._draw_text_in_mask(binary_mask, text, lighter_color)
        return self.output
|
| 1099 |
+
|
| 1100 |
+
def draw_soft_mask(self, soft_mask, color=None, *, text=None, alpha=0.5):
|
| 1101 |
+
"""
|
| 1102 |
+
Args:
|
| 1103 |
+
soft_mask (ndarray): float array of shape (H, W), each value in [0, 1].
|
| 1104 |
+
color: color of the mask. Refer to `matplotlib.colors` for a full list of
|
| 1105 |
+
formats that are accepted. If None, will pick a random color.
|
| 1106 |
+
text (str): if None, will be drawn on the object
|
| 1107 |
+
alpha (float): blending efficient. Smaller values lead to more transparent masks.
|
| 1108 |
+
|
| 1109 |
+
Returns:
|
| 1110 |
+
output (VisImage): image object with mask drawn.
|
| 1111 |
+
"""
|
| 1112 |
+
if color is None:
|
| 1113 |
+
color = random_color(rgb=True, maximum=1)
|
| 1114 |
+
color = mplc.to_rgb(color)
|
| 1115 |
+
|
| 1116 |
+
shape2d = (soft_mask.shape[0], soft_mask.shape[1])
|
| 1117 |
+
rgba = np.zeros(shape2d + (4,), dtype="float32")
|
| 1118 |
+
rgba[:, :, :3] = color
|
| 1119 |
+
rgba[:, :, 3] = soft_mask * alpha
|
| 1120 |
+
self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))
|
| 1121 |
+
|
| 1122 |
+
if text is not None:
|
| 1123 |
+
lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
|
| 1124 |
+
binary_mask = (soft_mask > 0.5).astype("uint8")
|
| 1125 |
+
self._draw_text_in_mask(binary_mask, text, lighter_color)
|
| 1126 |
+
return self.output
|
| 1127 |
+
|
| 1128 |
+
def draw_polygon(self, segment, color, edge_color=None, alpha=0.5):
|
| 1129 |
+
"""
|
| 1130 |
+
Args:
|
| 1131 |
+
segment: numpy array of shape Nx2, containing all the points in the polygon.
|
| 1132 |
+
color: color of the polygon. Refer to `matplotlib.colors` for a full list of
|
| 1133 |
+
formats that are accepted.
|
| 1134 |
+
edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
|
| 1135 |
+
full list of formats that are accepted. If not provided, a darker shade
|
| 1136 |
+
of the polygon color will be used instead.
|
| 1137 |
+
alpha (float): blending efficient. Smaller values lead to more transparent masks.
|
| 1138 |
+
|
| 1139 |
+
Returns:
|
| 1140 |
+
output (VisImage): image object with polygon drawn.
|
| 1141 |
+
"""
|
| 1142 |
+
if edge_color is None:
|
| 1143 |
+
# make edge color darker than the polygon color
|
| 1144 |
+
if alpha > 0.8:
|
| 1145 |
+
edge_color = self._change_color_brightness(color, brightness_factor=-0.7)
|
| 1146 |
+
else:
|
| 1147 |
+
edge_color = color
|
| 1148 |
+
edge_color = mplc.to_rgb(edge_color) + (1,)
|
| 1149 |
+
|
| 1150 |
+
polygon = mpl.patches.Polygon(
|
| 1151 |
+
segment,
|
| 1152 |
+
fill=True,
|
| 1153 |
+
facecolor=mplc.to_rgb(color) + (alpha,),
|
| 1154 |
+
edgecolor=edge_color,
|
| 1155 |
+
linewidth=max(self._default_font_size // 15 * self.output.scale, 1),
|
| 1156 |
+
)
|
| 1157 |
+
self.output.ax.add_patch(polygon)
|
| 1158 |
+
return self.output
|
| 1159 |
+
|
| 1160 |
+
"""
|
| 1161 |
+
Internal methods:
|
| 1162 |
+
"""
|
| 1163 |
+
|
| 1164 |
+
def _jitter(self, color):
|
| 1165 |
+
"""
|
| 1166 |
+
Randomly modifies given color to produce a slightly different color than the color given.
|
| 1167 |
+
|
| 1168 |
+
Args:
|
| 1169 |
+
color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color
|
| 1170 |
+
picked. The values in the list are in the [0.0, 1.0] range.
|
| 1171 |
+
|
| 1172 |
+
Returns:
|
| 1173 |
+
jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the
|
| 1174 |
+
color after being jittered. The values in the list are in the [0.0, 1.0] range.
|
| 1175 |
+
"""
|
| 1176 |
+
color = mplc.to_rgb(color)
|
| 1177 |
+
vec = np.random.rand(3)
|
| 1178 |
+
# better to do it in another color space
|
| 1179 |
+
vec = vec / np.linalg.norm(vec) * 0.5
|
| 1180 |
+
res = np.clip(vec + color, 0, 1)
|
| 1181 |
+
return tuple(res)
|
| 1182 |
+
|
| 1183 |
+
def _create_grayscale_image(self, mask=None):
|
| 1184 |
+
"""
|
| 1185 |
+
Create a grayscale version of the original image.
|
| 1186 |
+
The colors in masked area, if given, will be kept.
|
| 1187 |
+
"""
|
| 1188 |
+
img_bw = self.img.astype("f4").mean(axis=2)
|
| 1189 |
+
img_bw = np.stack([img_bw] * 3, axis=2)
|
| 1190 |
+
if mask is not None:
|
| 1191 |
+
img_bw[mask] = self.img[mask]
|
| 1192 |
+
return img_bw
|
| 1193 |
+
|
| 1194 |
+
def _change_color_brightness(self, color, brightness_factor):
|
| 1195 |
+
"""
|
| 1196 |
+
Depending on the brightness_factor, gives a lighter or darker color i.e. a color with
|
| 1197 |
+
less or more saturation than the original color.
|
| 1198 |
+
|
| 1199 |
+
Args:
|
| 1200 |
+
color: color of the polygon. Refer to `matplotlib.colors` for a full list of
|
| 1201 |
+
formats that are accepted.
|
| 1202 |
+
brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
|
| 1203 |
+
0 will correspond to no change, a factor in [-1.0, 0) range will result in
|
| 1204 |
+
a darker color and a factor in (0, 1.0] range will result in a lighter color.
|
| 1205 |
+
|
| 1206 |
+
Returns:
|
| 1207 |
+
modified_color (tuple[double]): a tuple containing the RGB values of the
|
| 1208 |
+
modified color. Each value in the tuple is in the [0.0, 1.0] range.
|
| 1209 |
+
"""
|
| 1210 |
+
assert brightness_factor >= -1.0 and brightness_factor <= 1.0
|
| 1211 |
+
color = mplc.to_rgb(color)
|
| 1212 |
+
polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
|
| 1213 |
+
modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
|
| 1214 |
+
modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
|
| 1215 |
+
modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
|
| 1216 |
+
modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2])
|
| 1217 |
+
return tuple(np.clip(modified_color, 0.0, 1.0))
|
| 1218 |
+
|
| 1219 |
+
def _convert_boxes(self, boxes):
|
| 1220 |
+
"""
|
| 1221 |
+
Convert different format of boxes to an NxB array, where B = 4 or 5 is the box dimension.
|
| 1222 |
+
"""
|
| 1223 |
+
if isinstance(boxes, Boxes) or isinstance(boxes, RotatedBoxes):
|
| 1224 |
+
return boxes.tensor.detach().numpy()
|
| 1225 |
+
else:
|
| 1226 |
+
return np.asarray(boxes)
|
| 1227 |
+
|
| 1228 |
+
def _convert_masks(self, masks_or_polygons):
|
| 1229 |
+
"""
|
| 1230 |
+
Convert different format of masks or polygons to a tuple of masks and polygons.
|
| 1231 |
+
|
| 1232 |
+
Returns:
|
| 1233 |
+
list[GenericMask]:
|
| 1234 |
+
"""
|
| 1235 |
+
|
| 1236 |
+
m = masks_or_polygons
|
| 1237 |
+
if isinstance(m, PolygonMasks):
|
| 1238 |
+
m = m.polygons
|
| 1239 |
+
if isinstance(m, BitMasks):
|
| 1240 |
+
m = m.tensor.numpy()
|
| 1241 |
+
if isinstance(m, torch.Tensor):
|
| 1242 |
+
m = m.numpy()
|
| 1243 |
+
ret = []
|
| 1244 |
+
for x in m:
|
| 1245 |
+
if isinstance(x, GenericMask):
|
| 1246 |
+
ret.append(x)
|
| 1247 |
+
else:
|
| 1248 |
+
ret.append(GenericMask(x, self.output.height, self.output.width))
|
| 1249 |
+
return ret
|
| 1250 |
+
|
| 1251 |
+
def _draw_text_in_mask(self, binary_mask, text, color):
|
| 1252 |
+
"""
|
| 1253 |
+
Find proper places to draw text given a binary mask.
|
| 1254 |
+
"""
|
| 1255 |
+
# TODO sometimes drawn on wrong objects. the heuristics here can improve.
|
| 1256 |
+
_num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8)
|
| 1257 |
+
if stats[1:, -1].size == 0:
|
| 1258 |
+
return
|
| 1259 |
+
largest_component_id = np.argmax(stats[1:, -1]) + 1
|
| 1260 |
+
|
| 1261 |
+
# draw text on the largest component, as well as other very large components.
|
| 1262 |
+
for cid in range(1, _num_cc):
|
| 1263 |
+
if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:
|
| 1264 |
+
# median is more stable than centroid
|
| 1265 |
+
# center = centroids[largest_component_id]
|
| 1266 |
+
center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]
|
| 1267 |
+
self.draw_text(text, center, color=color)
|
| 1268 |
+
|
| 1269 |
+
def _convert_keypoints(self, keypoints):
|
| 1270 |
+
if isinstance(keypoints, Keypoints):
|
| 1271 |
+
keypoints = keypoints.tensor
|
| 1272 |
+
keypoints = np.asarray(keypoints)
|
| 1273 |
+
return keypoints
|
| 1274 |
+
|
| 1275 |
+
    def get_output(self):
        """
        Return the accumulated visualization result.

        Returns:
            output (VisImage): the image output containing the visualizations added
                to the image.
        """
        return self.output
|