import math
import torch
from torchvision.transforms.functional import resize, InterpolationMode
from einops import rearrange
from typing import Tuple, Union
from PIL import Image
class DynamicResize(torch.nn.Module):
    """
    Resize so that:
      * the longer side ≤ `max_side_len` **and** is divisible by `patch_size`
      * the shorter side keeps aspect ratio and is also divisible by `patch_size`

    Optionally forbids up-scaling (beyond the minimal rounding needed to reach
    the next multiple of `patch_size`).

    Works on PIL Images, (C, H, W) tensors, or (B, C, H, W) tensors.
    Returns the same type it receives.

    NOTE(review): if `max_side_len` is not itself a multiple of `patch_size`,
    the long side can be clamped to a non-multiple of `patch_size` — confirm
    callers always pass a multiple.
    """

    def __init__(
        self,
        patch_size: int,
        max_side_len: int,
        resize_to_max_side_len: bool = False,
        interpolation: InterpolationMode = InterpolationMode.BICUBIC,
    ) -> None:
        super().__init__()
        self.p = int(patch_size)   # all output dims are multiples of this
        self.m = int(max_side_len) # upper bound for the longer side
        self.interpolation = interpolation
        # Fix: removed the stray debug print that wrote to stdout on every
        # construction; the flag remains inspectable via this attribute.
        self.resize_to_max_side_len = resize_to_max_side_len

    # ------------------------------------------------------------
    def _get_new_hw(self, h: int, w: int) -> Tuple[int, int]:
        """Compute target (h, w), each divisible by `patch_size`."""
        long, short = (w, h) if w >= h else (h, w)
        # 1) long side: either force it to max_side_len, or round it up to the
        #    next multiple of patch_size, clamped at max_side_len
        target_long = self.m if self.resize_to_max_side_len else min(self.m, math.ceil(long / self.p) * self.p)
        # 2) scale factor implied by the long-side change
        scale = target_long / long
        # 3) short side: ceil so the result never undershoots a patch boundary
        target_short = math.ceil(short * scale / self.p) * self.p
        target_short = max(target_short, self.p)  # guard against a degenerate 0
        return (target_short, target_long) if w >= h else (target_long, target_short)

    # ------------------------------------------------------------
    def forward(self, img: Union[Image.Image, torch.Tensor]):
        """Resize `img` (PIL Image or (C,H,W)/(B,C,H,W) tensor) to the
        patch-aligned size computed by `_get_new_hw`.

        Raises:
            TypeError: if `img` is neither a PIL Image nor a tensor.
            ValueError: if a tensor input is not 3- or 4-dimensional.
        """
        if isinstance(img, Image.Image):
            w, h = img.size  # PIL reports (width, height)
            new_h, new_w = self._get_new_hw(h, w)
            return resize(img, [new_h, new_w], interpolation=self.interpolation)
        if not torch.is_tensor(img):
            raise TypeError(
                "DynamicResize expects a PIL Image or a torch.Tensor; "
                f"got {type(img)}"
            )
        # tensor path ---------------------------------------------------------
        batched = img.ndim == 4
        if img.ndim not in (3, 4):
            raise ValueError(
                "Tensor input must have shape (C,H,W) or (B,C,H,W); "
                f"got {img.shape}"
            )
        # operate batch-wise; restore the original rank on the way out
        imgs = img if batched else img.unsqueeze(0)
        _, _, h, w = imgs.shape
        new_h, new_w = self._get_new_hw(h, w)
        out = resize(imgs, [new_h, new_w], interpolation=self.interpolation)
        return out if batched else out.squeeze(0)
class SplitImage(torch.nn.Module):
    """Cut a (B, C, H, W) image tensor into non-overlapping square tiles.

    Returns:
        patches: (B·n_h·n_w, C, patch_size, patch_size)
        grid:    (n_h, n_w) — tile counts along H and W
    """

    def __init__(self, patch_size: int) -> None:
        super().__init__()
        self.p = patch_size

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
        # Promote an unbatched (C, H, W) input to a batch of one.
        if x.ndim == 3:
            x = x.unsqueeze(0)
        b, c, h, w = x.shape
        if h % self.p or w % self.p:
            raise ValueError(f'Image size {(h,w)} not divisible by patch_size {self.p}')
        rows, cols = h // self.p, w // self.p
        # Equivalent to einops 'b c (nh ph) (nw pw) -> (b nh nw) c ph pw':
        # expose the tile grid, move it next to the batch axis, then flatten.
        tiles = (
            x.reshape(b, c, rows, self.p, cols, self.p)
             .permute(0, 2, 4, 1, 3, 5)
             .reshape(b * rows * cols, c, self.p, self.p)
        )
        return tiles, (rows, cols)
class GlobalAndSplitImages(torch.nn.Module):
    """Tile an image into patches and prepend a down-scaled global view.

    Returns the patch tensor (global patch first, when one is added) together
    with the (n_h, n_w) grid of the local patches.
    """

    def __init__(self, patch_size: int):
        super().__init__()
        self.p = patch_size
        self.splitter = SplitImage(patch_size)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
        imgs = x.unsqueeze(0) if x.ndim == 3 else x
        local_patches, layout = self.splitter(imgs)
        if layout != (1, 1):
            # Shrink the whole image down to a single patch and put it first.
            overview = resize(imgs, [self.p, self.p])
            return torch.cat([overview, local_patches], dim=0), layout
        # A single patch already covers the whole image — skip the global view.
        return local_patches, layout