|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import torch |
|
|
import torchvision |
|
|
from torchvision import transforms |
|
|
|
|
|
# Three-component fill color; presumably an RGB mid-gray used for padding or
# backgrounds. NOTE(review): not referenced in the visible part of this file.
BACKGROUND_COLOR = (127, 127, 127)
|
|
|
|
|
from torchvision.transforms import InterpolationMode |
|
|
|
|
|
def preprocess_image_with_min_size(image, min_factor=28):
    """Upscale ``image`` so that both of its sides are at least ``min_factor``.

    An image whose width and height already reach ``min_factor`` is returned
    unchanged (the very same object). Otherwise it is enlarged with one scale
    factor shared by both axes, using LANCZOS resampling.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize(size, resample)`` method. ``Image.Resampling.LANCZOS`` is
            passed as the filter, so a PIL image is expected.
        min_factor: Minimum allowed length of either side.

    Returns:
        The input image when no resize is needed, otherwise the value
        returned by ``image.resize``.

    Raises:
        ZeroDivisionError: If a resize is needed and either side is 0.
    """
    import math

    width, height = image.size
    if height >= min_factor and width >= min_factor:
        return image

    # The larger of the two ratios brings the shorter side up to min_factor
    # while keeping the aspect ratio.
    scale_factor = max(min_factor / height, min_factor / width)

    # int() truncates, and float rounding can leave ``side * scale_factor``
    # a hair below the target (e.g. 25 * (29 / 25) lands just under 29), which
    # would produce a side of min_factor - 1. Clamp so the guarantee holds.
    smallest_side = math.ceil(min_factor)
    new_width = max(int(width * scale_factor), smallest_side)
    new_height = max(int(height * scale_factor), smallest_side)

    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
|
|
|
|
def preprocess_image_gen(images, processor, vq_transform):
    """Run every image through the visual processor and the VQ transform.

    Each image is first passed through ``preprocess_image_with_min_size``;
    the resulting image feeds both ``processor.preprocess`` and
    ``vq_transform``.

    Args:
        images: Iterable of images accepted by
            ``preprocess_image_with_min_size``.
        processor: Object whose ``preprocess(image, return_tensors="pt")``
            returns a mapping with ``"pixel_values"`` and
            ``"image_grid_thw"`` entries.
        vq_transform: Callable mapping an image to a tensor.

    Returns:
        dict with three tensors, each stacked along a new leading dimension
        in input order: ``"pixel_values"``, ``"image_grid_thw"`` and
        ``"vq_pixel_values"``.

    Note:
        ``torch.stack`` needs at least one entry and equally shaped entries,
        so an empty ``images`` or differently shaped per-image outputs raise.
    """
    pixel_batches = []
    grid_batches = []
    vq_batches = []

    for raw_image in images:
        sized = preprocess_image_with_min_size(raw_image)

        encoded = processor.preprocess(sized, return_tensors="pt")
        pixels = encoded["pixel_values"]
        # Some processors hand back a list; only its first entry is kept.
        pixel_batches.append(pixels[0] if isinstance(pixels, list) else pixels)

        # Only the first grid entry of the per-image output is kept.
        grid_batches.append(encoded["image_grid_thw"][0])

        vq_batches.append(vq_transform(sized))

    return {
        "pixel_values": torch.stack(pixel_batches, dim=0),
        "image_grid_thw": torch.stack(grid_batches, dim=0),
        "vq_pixel_values": torch.stack(vq_batches, dim=0),
    }
|
|
|
|
|
|
|
|
|
|
|
def get_vq_transform(args):
    """Build the transform pipeline that prepares images for the VQ branch.

    The pipeline resizes to a square of side ``args.vq_image_size`` with
    bilinear interpolation, converts to a tensor, and normalizes each of
    three channels with mean 0.5 and std 0.5.

    Args:
        args: Object providing the ``vq_image_size`` attribute.

    Returns:
        A ``transforms.Compose`` of the three stages above.
    """
    side = args.vq_image_size
    resize = transforms.Resize(
        (side, side),
        interpolation=InterpolationMode.BILINEAR,
    )
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    return transforms.Compose([resize, to_tensor, normalize])
|
|
|
|
|
def get_full_transform(args):
    """Build the transform pipeline for fixed 1024x1024 images.

    The pipeline resizes to 1024x1024 with bilinear interpolation, converts
    to a tensor, and normalizes each of three channels with mean 0.5 and
    std 0.5.

    Args:
        args: Accepted for signature parity with ``get_vq_transform``; the
            body does not read it.

    Returns:
        A ``transforms.Compose`` of the three stages above.
    """
    target_size = (1024, 1024)
    half = (0.5, 0.5, 0.5)
    stages = [
        transforms.Resize(target_size, interpolation=InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize(mean=half, std=half),
    ]
    return transforms.Compose(stages)
|
|
|