File size: 12,317 Bytes
36c95ba | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 | import math
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from kornia.filters import filter2d, gaussian_blur2d
# Public API exported by this module.
__all__ = [
    "PyrDown",
    "PyrUp",
    "ScalePyramid",
    "pyrdown",
    "pyrup",
    "build_pyramid",
]
def _get_pyramid_gaussian_kernel() -> torch.Tensor:
    """Return the fixed 5x5 binomial smoothing kernel used by the pyramid operators.

    The kernel is the outer product of the binomial row ``[1, 4, 6, 4, 1]`` with
    itself, divided by 256 so that its entries sum to 1. It is returned with a
    leading singleton dimension, i.e. with shape ``(1, 5, 5)``.
    """
    taps = torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0])
    # Column vector times row vector -> (5, 5) separable kernel with integer entries.
    kernel_2d = taps.unsqueeze(1) * taps.unsqueeze(0)
    return (kernel_2d / 256.0).unsqueeze(0)
class PyrDown(nn.Module):
    r"""Module wrapper around :func:`pyrdown`: blur a tensor, then halve its spatial size.

    Args:
        border_type: padding mode used before the blur convolution. One of
            ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
        align_corners: flag forwarded to the bilinear interpolation step.

    Return:
        the downsampled tensor.

    Shape:
        - Input: :math:`(B, C, H, W)`
        - Output: :math:`(B, C, H / 2, W / 2)`

    Examples:
        >>> input = torch.rand(1, 2, 4, 4)
        >>> output = PyrDown()(input)  # 1x2x2x2
    """

    def __init__(self, border_type: str = 'reflect', align_corners: bool = False) -> None:
        super().__init__()
        # Both options are stored verbatim and passed to ``pyrdown`` on every call.
        self.border_type: str = border_type
        self.align_corners: bool = align_corners

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return pyrdown(input, border_type=self.border_type, align_corners=self.align_corners)
class PyrUp(nn.Module):
    r"""Upsample a tensor and then blur it.

    Module wrapper around :func:`pyrup`.

    Args:
        border_type: the padding mode to be applied before convolving.
            The expected modes are: ``'constant'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``.
        align_corners: interpolation flag.

    Return:
        the upsampled tensor.

    Shape:
        - Input: :math:`(B, C, H, W)`
        - Output: :math:`(B, C, H * 2, W * 2)`

    Examples:
        >>> input = torch.rand(1, 2, 4, 4)
        >>> output = PyrUp()(input)  # 1x2x8x8
    """

    def __init__(self, border_type: str = 'reflect', align_corners: bool = False) -> None:
        super().__init__()
        self.border_type: str = border_type
        self.align_corners: bool = align_corners

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return pyrup(input, self.border_type, self.align_corners)
class ScalePyramid(nn.Module):
    r"""Create a scale pyramid of an image, usually used for local feature detection.

    Images are consecutively smoothed with Gaussian blur and downscaled.

    Args:
        n_levels: number of the levels in octave.
        init_sigma: initial blur level.
        min_size: the minimum size of the octave in pixels. Octave construction
            stops once the next octave's shorter side would be ``<= min_size``.
        double_image: add 2x upscaled image as 1st level of pyramid. OpenCV SIFT does this.

    Returns:
        1st output: images
        2nd output: sigmas (coefficients for scale conversion)
        3rd output: pixelDists (coefficients for coordinate conversion)

    Shape:
        - Input: :math:`(B, C, H, W)`
        - Output 1st: :math:`[(B, C, NL, H, W), (B, C, NL, H/2, W/2), ...]`
        - Output 2nd: :math:`[(B, NL), (B, NL), (B, NL), ...]`
        - Output 3rd: :math:`[(B, NL), (B, NL), (B, NL), ...]`

        where ``NL = n_levels + 3`` (three extra levels are needed for DoG
        non-maximum suppression). With ``double_image=True`` the first octave is
        :math:`(2H, 2W)` instead. The sigma and pixel-distance tensors of every
        octave share the input's device and dtype.

    Examples:
        >>> input = torch.rand(2, 4, 100, 100)
        >>> sp, sigmas, pds = ScalePyramid(n_levels=3, min_size=15)(input)
    """

    def __init__(self, n_levels: int = 3, init_sigma: float = 1.6, min_size: int = 15, double_image: bool = False):
        super().__init__()
        # 3 extra levels are needed for DoG nms.
        self.n_levels = n_levels
        self.extra_levels: int = 3
        self.init_sigma = init_sigma
        self.min_size = min_size
        self.border = min_size // 2 - 1
        # Per-level multiplicative sigma increment: n_levels steps exactly double sigma.
        self.sigma_step = 2 ** (1.0 / float(self.n_levels))
        self.double_image = double_image

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(n_levels={self.n_levels}, init_sigma={self.init_sigma}, "
            f"min_size={self.min_size}, extra_levels={self.extra_levels}, border={self.border}, "
            f"sigma_step={self.sigma_step}, double_image={self.double_image})"
        )

    def get_kernel_size(self, sigma: float) -> int:
        """Return an odd Gaussian kernel size spanning roughly +-4 sigma."""
        ksize = int(2.0 * 4.0 * sigma + 1.0)
        # matches OpenCV, but may cause padding problem for small images
        # PyTorch does not allow to pad more than original size.
        # Therefore there is a hack in forward function
        if ksize % 2 == 0:
            ksize += 1
        return ksize

    def get_first_level(self, input: torch.Tensor) -> Tuple[torch.Tensor, float, float]:
        """Build the first level of the first octave.

        The input is treated as already carrying a blur of 0.5 (1.0 after the
        optional 2x upsampling); it is then blurred up to ``init_sigma`` if needed.

        Returns:
            ``(level, sigma_of_level, pixel_distance)``.
        """
        pixel_distance = 1.0
        cur_sigma = 0.5
        # Same as in OpenCV up to interpolation difference
        if self.double_image:
            x = F.interpolate(input, scale_factor=2.0, mode='bilinear', align_corners=False)
            pixel_distance = 0.5
            cur_sigma *= 2.0
        else:
            x = input
        if self.init_sigma > cur_sigma:
            # Incremental blur so that the total blur equals init_sigma; floored to stay positive.
            sigma = max(math.sqrt(self.init_sigma ** 2 - cur_sigma ** 2), 0.01)
            ksize = self.get_kernel_size(sigma)
            cur_level = gaussian_blur2d(x, (ksize, ksize), (sigma, sigma))
            cur_sigma = self.init_sigma
        else:
            cur_level = x
        return cur_level, cur_sigma, pixel_distance

    def _level_constants(self, value: float, batch_size: int, ref: torch.Tensor) -> torch.Tensor:
        """Return a ``(batch_size, n_levels + extra_levels)`` tensor filled with ``value``, on ``ref``'s device/dtype."""
        return torch.full(
            (batch_size, self.n_levels + self.extra_levels), value, device=ref.device, dtype=ref.dtype
        )

    def forward(  # type: ignore
        self, x: torch.Tensor
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        bs, _, _, _ = x.size()
        levels_per_octave = self.n_levels + self.extra_levels
        cur_level, cur_sigma, pixel_distance = self.get_first_level(x)

        # Every octave's bookkeeping tensors use x's device AND dtype (previously only
        # the first octave was cast to x.dtype, later ones stayed float32).
        sigmas = [self._level_constants(cur_sigma, bs, x)]
        pixel_dists = [self._level_constants(pixel_distance, bs, x)]
        pyr = [[cur_level]]
        while True:
            cur_level = pyr[-1][0]
            for level_idx in range(1, levels_per_octave):
                # Extra blur that takes the level from cur_sigma to cur_sigma * sigma_step.
                sigma = cur_sigma * math.sqrt(self.sigma_step ** 2 - 1.0)
                ksize = self.get_kernel_size(sigma)
                # Hack, because PyTorch does not allow to pad more than original size.
                # But for the huge sigmas, one needs huge kernel and padding...
                ksize = min(ksize, min(cur_level.size(2), cur_level.size(3)))
                if ksize % 2 == 0:
                    ksize += 1
                cur_level = gaussian_blur2d(cur_level, (ksize, ksize), (sigma, sigma))
                cur_sigma *= self.sigma_step
                pyr[-1].append(cur_level)
                sigmas[-1][:, level_idx] = cur_sigma
                pixel_dists[-1][:, level_idx] = pixel_distance
            # Level index n_levels has twice the octave's starting blur, so halving its
            # resolution gives the next octave's starting level.
            _pyr = pyr[-1][-self.extra_levels]
            next_octave_first_level = F.interpolate(
                _pyr, size=(_pyr.size(-2) // 2, _pyr.size(-1) // 2), mode='nearest'
            )  # Nearest matches OpenCV SIFT
            pixel_distance *= 2.0
            cur_sigma = self.init_sigma
            if min(next_octave_first_level.size(2), next_octave_first_level.size(3)) <= self.min_size:
                break
            pyr.append([next_octave_first_level])
            sigmas.append(self._level_constants(cur_sigma, bs, x))
            pixel_dists.append(self._level_constants(pixel_distance, bs, x))
        # Stack each octave's levels along a new dim 2: (B, C, NL, H, W).
        images = [torch.stack(octave, dim=2) for octave in pyr]
        return images, sigmas, pixel_dists
def pyrdown(input: torch.Tensor, border_type: str = 'reflect', align_corners: bool = False) -> torch.Tensor:
    r"""Blur a tensor and then downsample it by a factor of two.

    .. image:: _static/img/pyrdown.png

    The input is smoothed with the fixed 5x5 pyramid kernel and then bilinearly
    resized to ``(H // 2, W // 2)``, where ``H, W`` are the input's spatial sizes.

    Args:
        input: the tensor to be downsampled, with shape :math:`(B, C, H, W)`.
        border_type: the padding mode to be applied before convolving.
            The expected modes are: ``'constant'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``.
        align_corners: interpolation flag.

    Return:
        the downsampled tensor.

    Raises:
        ValueError: if ``input`` is not 4-dimensional.

    Examples:
        >>> input = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
        >>> pyrdown(input, align_corners=True)
        tensor([[[[ 3.7500,  5.2500],
                  [ 9.7500, 11.2500]]]])
    """
    if len(input.shape) != 4:
        raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")

    *_, height, width = input.shape
    target_size = (height // 2, width // 2)

    smoothed: torch.Tensor = filter2d(input, _get_pyramid_gaussian_kernel(), border_type)

    # TODO: use kornia.geometry.resize/rescale
    return F.interpolate(smoothed, size=target_size, mode='bilinear', align_corners=align_corners)
def pyrup(input: torch.Tensor, border_type: str = 'reflect', align_corners: bool = False) -> torch.Tensor:
    r"""Upsample a tensor and then blur it.

    .. image:: _static/img/pyrup.png

    The input is bilinearly resized to ``(2 * H, 2 * W)`` and then smoothed with
    the fixed 5x5 pyramid kernel.

    Args:
        input: the tensor to be upsampled, with shape :math:`(B, C, H, W)`.
        border_type: the padding mode to be applied before convolving.
            The expected modes are: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
        align_corners: interpolation flag.

    Return:
        the upsampled tensor.

    Raises:
        ValueError: if ``input`` is not 4-dimensional.

    Examples:
        >>> input = torch.arange(4, dtype=torch.float32).reshape(1, 1, 2, 2)
        >>> pyrup(input, align_corners=True)
        tensor([[[[0.7500, 0.8750, 1.1250, 1.2500],
                  [1.0000, 1.1250, 1.3750, 1.5000],
                  [1.5000, 1.6250, 1.8750, 2.0000],
                  [1.7500, 1.8750, 2.1250, 2.2500]]]])
    """
    if not len(input.shape) == 4:
        raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")

    kernel: torch.Tensor = _get_pyramid_gaussian_kernel()

    # upsample tensor
    _, _, height, width = input.shape
    # TODO: use kornia.geometry.resize/rescale
    x_up: torch.Tensor = F.interpolate(
        input, size=(height * 2, width * 2), mode='bilinear', align_corners=align_corners
    )

    # blur the upsampled tensor
    x_blur: torch.Tensor = filter2d(x_up, kernel, border_type)
    return x_blur
def build_pyramid(
    input: torch.Tensor, max_level: int, border_type: str = 'reflect', align_corners: bool = False
) -> List[torch.Tensor]:
    r"""Construct the Gaussian pyramid for an image.

    .. image:: _static/img/build_pyramid.png

    The function constructs a list of images and builds the Gaussian pyramid
    by recursively applying :func:`pyrdown` to the previously built pyramid layers.

    Args:
        input : the tensor to be used to construct the pyramid.
        max_level: number of pyramid levels to return, counting ``input`` itself
            as the first level. ``0`` and ``1`` both return just ``[input]``.
            It must be non-negative.
        border_type: the padding mode to be applied before convolving.
            The expected modes are: ``'constant'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``.
        align_corners: interpolation flag.

    Returns:
        a list of ``max(max_level, 1)`` tensors. The first element is ``input``
        itself, and each following one is the :func:`pyrdown` of its predecessor.

    Raises:
        TypeError: if ``input`` is not a ``torch.Tensor``.
        ValueError: if ``input`` is not 4-dimensional or ``max_level`` is not a
            non-negative integer.

    Shape:
        - Input: :math:`(B, C, H, W)`
        - Output :math:`[(B, C, H, W), (B, C, H/2, W/2), ...]`
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if not len(input.shape) == 4:
        raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")

    if not isinstance(max_level, int) or max_level < 0:
        raise ValueError(f"Invalid max_level, it must be a non-negative integer. Got: {max_level}")

    # NOTE: unlike OpenCV's ``buildPyramid`` (where ``maxlevel`` is the 0-based index of the
    # last layer and ``maxlevel + 1`` layers are produced), this loop yields ``max_level``
    # layers in total (at least one). Kept as-is for backward compatibility with callers.
    pyramid: List[torch.Tensor] = [input]
    for _ in range(max_level - 1):
        pyramid.append(pyrdown(pyramid[-1], border_type, align_corners))
    return pyramid
|