|
|
import torch |
|
|
import requests |
|
|
import torchvision.transforms as transforms |
|
|
from math import ceil |
|
|
from PIL import Image |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
# Longest side of any resized image is capped at this many pixels.
MAX_RESOLUTION = 1024


def get_resize_output_image_size(
    image_size,
    fix_resolution=False,
    max_resolution: int = MAX_RESOLUTION,
    patch_size=32
) -> tuple:
    """Compute an output size whose sides are both multiples of ``patch_size``.

    The longer side is rounded up to the nearest multiple of ``patch_size``
    and capped at ``max_resolution``; the shorter side is scaled so the
    aspect ratio is (approximately) preserved, then also rounded up to a
    multiple of ``patch_size``.

    Args:
        image_size: ``(l1, l2)`` pair — the caller passes ``(width, height)``.
        fix_resolution: if truthy, ignore the input size and return ``(224, 224)``.
        max_resolution: upper bound for the longer output side.
        patch_size: grid size each output dimension must be a multiple of.

    Returns:
        A 2-tuple in the same order as ``image_size``.
    """
    # Fixed-resolution mode short-circuits all the aspect-ratio math.
    if fix_resolution:
        return 224, 224

    l1, l2 = image_size
    short, long = (l2, l1) if l2 <= l1 else (l1, l2)

    # Round the long side up to the patch grid, but never exceed the cap.
    requested_new_long = min(ceil(long / patch_size) * patch_size, max_resolution)

    # Scale the short side to preserve aspect ratio, then snap it up to the grid.
    new_long, new_short = requested_new_long, int(requested_new_long * short / long)
    new_short = ceil(new_short / patch_size) * patch_size

    # Map (long, short) back to the original (l1, l2) ordering.
    return (new_long, new_short) if l2 <= l1 else (new_short, new_long)
|
|
|
|
|
|
|
|
def preprocess_image(
    image_tensor: torch.Tensor,
    patch_size=32
) -> torch.Tensor:
    """Split an image tensor into a grid of non-overlapping square patches.

    Assumes ``image_tensor`` is channels-first ``(C, H, W)`` — TODO confirm
    with callers. If H or W is not a multiple of ``patch_size``, ``unfold``
    silently drops the trailing remainder.

    Args:
        image_tensor: input tensor, patched along dims 1 and 2.
        patch_size: side length of each square patch (also the stride, so
            patches do not overlap).

    Returns:
        Tensor of shape ``(H // patch_size, W // patch_size, C, patch_size,
        patch_size)`` — grid position first, then channels, then the patch.
    """
    # Carve dim 1 (rows) into patch_size strips, then dim 2 (cols).
    row_blocks = image_tensor.unfold(1, patch_size, patch_size)
    grid = row_blocks.unfold(2, patch_size, patch_size)
    # Move the grid indices in front of the channel dim; contiguous() gives
    # the permuted tensor its own memory layout.
    return grid.permute(1, 2, 0, 3, 4).contiguous()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_transform(height, width):
    """Build the preprocessing pipeline for a target ``(height, width)``.

    The pipeline resizes the image, converts it to a tensor, and normalizes
    with the standard ImageNet channel statistics.

    Args:
        height: target image height in pixels.
        width: target image width in pixels.

    Returns:
        A ``transforms.Compose`` callable mapping a PIL image to a tensor.
    """
    steps = [
        transforms.Resize((height, width)),
        transforms.ToTensor(),
        # ImageNet per-channel mean/std.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
    return transforms.Compose(steps)
|
|
|
|
|
def convert_image_to_patches(image, patch_size=32) -> torch.Tensor:
    """Resize a PIL image to patch-aligned dimensions and cut it into patches.

    Args:
        image: input PIL image (``image.size`` is ``(width, height)``).
        patch_size: side length of each square patch.

    Returns:
        Patch tensor as produced by ``preprocess_image``.
    """
    original_width, original_height = image.size
    # Choose output dims that are multiples of patch_size (aspect preserved).
    target_width, target_height = get_resize_output_image_size(
        (original_width, original_height),
        fix_resolution=False,
        patch_size=patch_size,
    )
    transform = get_transform(target_height, target_width)
    resized_tensor = transform(image)
    return preprocess_image(resized_tensor, patch_size=patch_size)
|
|
|