vlm-demo / src /data /custom_transforms.py
tiltaf's picture
Upload 29 files
cf3d756 verified
import math
import torch
from torchvision.transforms.functional import resize, InterpolationMode
from einops import rearrange
from typing import Tuple, Union
from PIL import Image
class DynamicResize(torch.nn.Module):
"""
Resize so that:
* the longer side ≤ `max_side_len` **and** is divisible by `patch_size`
* the shorter side keeps aspect ratio and is also divisible by `patch_size`
Optionally forbids up-scaling.
Works on PIL Images, (C, H, W) tensors, or (B, C, H, W) tensors.
Returns the same type it receives.
"""
def __init__(
self,
patch_size: int,
max_side_len: int,
resize_to_max_side_len: bool = False,
interpolation: InterpolationMode = InterpolationMode.BICUBIC,
) -> None:
super().__init__()
self.p = int(patch_size)
self.m = int(max_side_len)
self.interpolation = interpolation
print(f"Resize to max side len: {resize_to_max_side_len}")
self.resize_to_max_side_len = resize_to_max_side_len
# ------------------------------------------------------------
def _get_new_hw(self, h: int, w: int) -> Tuple[int, int]:
"""Compute target (h, w) divisible by patch_size."""
long, short = (w, h) if w >= h else (h, w)
# 1) upscale long side
target_long = self.m if self.resize_to_max_side_len else min(self.m, math.ceil(long / self.p) * self.p)
# 2) scale factor
scale = target_long / long
# 3) compute short side with ceil → never undershoot
target_short = math.ceil(short * scale / self.p) * self.p
target_short = max(target_short, self.p) # just in case
return (target_short, target_long) if w >= h else (target_long, target_short)
# ------------------------------------------------------------
def forward(self, img: Union[Image.Image, torch.Tensor]):
if isinstance(img, Image.Image):
w, h = img.size
new_h, new_w = self._get_new_hw(h, w)
return resize(img, [new_h, new_w], interpolation=self.interpolation)
if not torch.is_tensor(img):
raise TypeError(
"DynamicResize expects a PIL Image or a torch.Tensor; "
f"got {type(img)}"
)
# tensor path ---------------------------------------------------------
batched = img.ndim == 4
if img.ndim not in (3, 4):
raise ValueError(
"Tensor input must have shape (C,H,W) or (B,C,H,W); "
f"got {img.shape}"
)
# operate batch-wise
imgs = img if batched else img.unsqueeze(0)
_, _, h, w = imgs.shape
new_h, new_w = self._get_new_hw(h, w)
out = resize(imgs, [new_h, new_w], interpolation=self.interpolation)
return out if batched else out.squeeze(0)
class SplitImage(torch.nn.Module):
"""Split (B, C, H, W) image tensor into square patches.
Returns:
patches: (B·n_h·n_w, C, patch_size, patch_size)
grid: (n_h, n_w) - number of patches along H and W
"""
def __init__(self, patch_size: int) -> None:
super().__init__()
self.p = patch_size
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
if x.ndim == 3: # add batch dim if missing
x = x.unsqueeze(0)
b, c, h, w = x.shape
if h % self.p or w % self.p:
raise ValueError(f'Image size {(h,w)} not divisible by patch_size {self.p}')
n_h, n_w = h // self.p, w // self.p
patches = rearrange(x, 'b c (nh ph) (nw pw) -> (b nh nw) c ph pw',
ph=self.p, pw=self.p)
return patches, (n_h, n_w)
class GlobalAndSplitImages(torch.nn.Module):
def __init__(self, patch_size: int):
super().__init__()
self.p = patch_size
self.splitter = SplitImage(patch_size)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
if x.ndim == 3:
x = x.unsqueeze(0)
patches, grid = self.splitter(x)
if grid == (1, 1):
return patches, grid # Dont add global patch if there is only one patch
global_patch = resize(x, [self.p, self.p])
return torch.cat([global_patch, patches], dim=0), grid