File size: 2,181 Bytes
97bc03d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import numpy as np
from PIL import Image
import torch
import torchvision
from torchvision import transforms

# Mid-gray RGB fill color. Not referenced by the functions in this part of the
# module; presumably used by callers (e.g. for padding or alpha compositing) —
# confirm before changing.
BACKGROUND_COLOR = (127, 127, 127)

from torchvision.transforms import InterpolationMode

def preprocess_image_with_min_size(image, min_factor=28, *, resample=None):
    """Upscale ``image`` so that neither side is smaller than ``min_factor``.

    If width and height are both already at least ``min_factor``, the input
    object is returned unchanged. Otherwise, both sides are multiplied by the
    single factor that brings the shorter side up to ``min_factor``. This
    preserves the aspect ratio, up to integer truncation of the longer side.

    Args:
        image: PIL-style image exposing ``.size`` as ``(width, height)`` and a
            ``.resize(size, resample)`` method.
        min_factor: Minimum allowed side length in pixels.
        resample: Resampling filter passed to ``image.resize``. ``None``
            (the default) selects ``Image.Resampling.LANCZOS``.

    Returns:
        Either ``image`` itself or a resized image whose sides are both
        ``>= min_factor``.

    Raises:
        ValueError: If resizing is needed and a dimension is not positive.
    """
    width, height = image.size
    if width >= min_factor and height >= min_factor:
        return image
    if width <= 0 or height <= 0:
        # Guard the divisions below; a zero side cannot be scaled up.
        raise ValueError(
            f"image must have positive dimensions, got {width}x{height}"
        )

    # min_factor / min(width, height): scales the shorter side to min_factor.
    scale_factor = max(min_factor / height, min_factor / width)
    # Truncating the float product can fall one pixel short of the target
    # (e.g. int(49 * (64 / 49)) == 63), so clamp. Both sides are at least as
    # long as the shorter side, so both must reach min_factor.
    new_width = max(min_factor, int(width * scale_factor))
    new_height = max(min_factor, int(height * scale_factor))

    if resample is None:
        resample = Image.Resampling.LANCZOS
    return image.resize((new_width, new_height), resample)

def preprocess_image_gen(images, processor, vq_transform):
    """Preprocess a batch of images for the visual processor and the VQ branch.

    Each image is first passed through ``preprocess_image_with_min_size``
    (default ``min_factor``). It is then fed both to
    ``processor.preprocess(..., return_tensors="pt")`` and to ``vq_transform``.

    Args:
        images: Iterable of images accepted by
            ``preprocess_image_with_min_size``.
        processor: Object whose ``preprocess`` returns a mapping with
            ``"pixel_values"`` (a tensor, or a list whose first element is
            used) and ``"image_grid_thw"`` (whose first row is used).
        vq_transform: Callable mapping an image to a tensor.

    Returns:
        dict with keys ``"pixel_values"``, ``"image_grid_thw"`` and
        ``"vq_pixel_values"``. Each value holds the per-image results stacked
        along a new dim 0. ``torch.stack`` requires equal per-image shapes;
        presumably callers pass images that satisfy this — confirm.

    Raises:
        ValueError: If ``images`` yields no images.
    """
    pixel_values, grid_thws, vq_pixel_values = [], [], []
    for image in images:
        image = preprocess_image_with_min_size(image)

        visual_processed = processor.preprocess(image, return_tensors="pt")
        image_tensor = visual_processed["pixel_values"]
        # If the processor returned a list, keep only its first element.
        if isinstance(image_tensor, list):
            image_tensor = image_tensor[0]
        pixel_values.append(image_tensor)
        grid_thws.append(visual_processed["image_grid_thw"][0])

        vq_pixel_values.append(vq_transform(image))

    if not pixel_values:
        # torch.stack([]) would otherwise fail with an opaque RuntimeError.
        raise ValueError("preprocess_image_gen requires at least one image")

    return {
        "pixel_values": torch.stack(pixel_values, dim=0),
        "image_grid_thw": torch.stack(grid_thws, dim=0),
        "vq_pixel_values": torch.stack(vq_pixel_values, dim=0),
    }



def get_vq_transform(args):
    """Build the image pipeline feeding the VQ branch.

    The image is resized bilinearly to a square of side ``args.vq_image_size``
    and converted to a tensor. Values are then shifted from [0, 1] to [-1, 1].
    """
    side = args.vq_image_size
    steps = [
        transforms.Resize((side, side), interpolation=InterpolationMode.BILINEAR),
        # ToTensor: [0, 255] -> [0, 1]
        transforms.ToTensor(),
        # (x - 0.5) / 0.5: [0, 1] -> [-1, 1]
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
    return transforms.Compose(steps)

def get_full_transform(args, *, size=1024):
    """Build the full-resolution image pipeline.

    The image is resized bilinearly to exactly ``size`` x ``size`` (aspect
    ratio is not preserved) and converted to a tensor. Values are then
    shifted from [0, 1] to [-1, 1].

    Args:
        args: Not read by this function; kept for interface compatibility.
        size: Output side length in pixels. Defaults to 1024, the value that
            was previously hard-coded.

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    return transforms.Compose([
        transforms.Resize((size, size), interpolation=InterpolationMode.BILINEAR),
        transforms.ToTensor(),    # [0, 255] -> [0, 1]
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),    # [0, 1] -> [-1, 1]
    ])