# -*- coding: utf-8 -*-
"""

Image processing utilities

"""
import random

import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw
from diffusers import VQModel
from diffusers.image_processor import VaeImageProcessor

def decode_vq_to_image(
    vq_codes: torch.LongTensor,
    save_path: str = None,
    vae_ckpt: str = None,
    image_height: int = 512,
    image_width: int = 512,
    vqvae: VQModel = None,
) -> Image.Image:
    """

    Decode VQ codes to image

    

    Args:

        vq_codes: VQ codes in range [0, codebook_size), shape [batch_size, seq_len]

        save_path: Save path (optional, if None will not save to file)

        vae_ckpt: VAE checkpoint path (optional if vqvae is provided)

        image_height: Image height

        image_width: Image width

        vqvae: VQ-VAE model, if None will load from vae_ckpt

    

    Returns:

        PIL image

    """
    device = vq_codes.device
    if vqvae is None:
        vqvae = VQModel.from_pretrained(vae_ckpt, subfolder="vqvae").to(device)
    
    # Spatial downsampling factor of the VQ-VAE, e.g. 16 when the model has
    # five resolution levels (2 ** 4).
    scale = 2 ** (len(vqvae.config.block_out_channels) - 1)
    img_proc = VaeImageProcessor(vae_scale_factor=scale, do_normalize=False)

    # Calculate latent space grid size
    latent_height = image_height // scale
    latent_width = image_width // scale

    # Ensure VQ codes length matches
    expected_len = latent_height * latent_width
    if vq_codes.shape[1] != expected_len:
        raise ValueError(
            f"VQ codes length mismatch: {vq_codes.shape[1]} != {expected_len} "
            f"for image size ({image_height},{image_width}) with scale {scale}"
        )

    # Reshape to a 2D grid: [batch_size, seq_len] -> [batch_size, latent_height, latent_width].
    # vq_codes are expected to already be raw codebook indices in [0, codebook_size),
    # so no vocabulary offset (e.g. 126356) is subtracted here.
    latents = vq_codes.view(vq_codes.shape[0], latent_height, latent_width).long()

    # Decode
    recon = vqvae.decode(
        latents,
        force_not_quantize=True,
        shape=(vq_codes.shape[0], latent_height, latent_width, vqvae.config.latent_channels),
    ).sample.clip(0, 1)

    # Post-process
    img = img_proc.postprocess(recon.detach(), output_type="pil")[0]
    
    # Save image (only if save_path is provided)
    if save_path is not None:
        img.save(save_path)
    
    return img
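
# A minimal usage sketch for decode_vq_to_image, assuming a local checkpoint and
# a codebook of size 8192 (both hypothetical; real codes would come from the
# matching encoder, see encode_img_with_breaks below). A 512x512 image at
# scale 16 needs a 32x32 = 1024-token sequence:
#
#     codes = torch.randint(0, 8192, (1, 32 * 32))
#     img = decode_vq_to_image(
#         codes, save_path="recon.png", vae_ckpt="path/to/checkpoint",
#         image_height=512, image_width=512,
#     )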


def preprocess_image(image_path: str, target_size: tuple = (512, 512)):
    """

    Preprocess image: load, crop, resize

    

    Args:

        image_path: Image path

        target_size: Target size (width, height)

    

    Returns:

        Processed PIL image

    """
    img = Image.open(image_path).convert("RGB")
    crop_size_list = generate_crop_size_list((target_size[0] // 32) ** 2, 32)
    processed_img = var_center_crop(img, crop_size_list=crop_size_list)
    return processed_img
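
# A minimal sketch (the file name is hypothetical). Because var_center_crop
# picks the candidate size whose aspect ratio best matches the input, a wide
# photo may come back at a wide crop size rather than exactly 512x512:
#
#     img = preprocess_image("photo.jpg", target_size=(512, 512))
#     print(img.size)  # e.g. (1024, 256) for a ~4:1 panorama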


def calculate_vq_params(image_height: int, image_width: int, vae_scale: int = 16):
    """

    Calculate VQ related parameters

    

    Args:

        image_height: Image height

        image_width: Image width

        vae_scale: VAE scale factor

    

    Returns:

        seq_len, newline_every, token_grid_height, token_grid_width

    """
    token_grid_height = image_height // vae_scale
    token_grid_width = image_width // vae_scale
    seq_len = token_grid_height * token_grid_width
    newline_every = token_grid_width
    return seq_len, newline_every, token_grid_height, token_grid_width
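
# Worked example: a 512x512 image with vae_scale=16 yields a 32x32 token grid,
# i.e. seq_len = 1024 with a newline token after every 32 image tokens:
#
#     seq_len, newline_every, gh, gw = calculate_vq_params(512, 512, vae_scale=16)
#     assert (seq_len, newline_every, gh, gw) == (1024, 32, 32, 32)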

def center_crop(pil_image, crop_size):
    # Repeatedly halve the image with a box filter while it is at least twice
    # the crop size in both dimensions (cheap, alias-resistant downscaling).
    while pil_image.size[0] >= 2 * crop_size[0] and pil_image.size[1] >= 2 * crop_size[1]:
        pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)

    # Scale so the image just covers the crop box, then crop. Note that despite
    # the function name, the crop offset is sampled uniformly at random within
    # the valid range rather than centered.
    scale = max(crop_size[0] / pil_image.size[0], crop_size[1] / pil_image.size[1])
    pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)

    crop_left = random.randint(0, pil_image.size[0] - crop_size[0])
    crop_upper = random.randint(0, pil_image.size[1] - crop_size[1])
    crop_right = crop_left + crop_size[0]
    crop_lower = crop_upper + crop_size[1]
    return pil_image.crop(box=(crop_left, crop_upper, crop_right, crop_lower))


def var_center_crop(pil_image, crop_size_list, random_top_k=1):
    w, h = pil_image.size
    # Score each candidate (cw, ch) by how closely its aspect ratio matches the
    # image's: 1.0 is a perfect match, smaller means more of the image is cut away.
    rem_percent = [min(cw / w, ch / h) / max(cw / w, ch / h) for cw, ch in crop_size_list]
    # Pick randomly among the top-k best-matching candidates (deterministic for k=1).
    crop_size = random.choice(
        sorted(((x, y) for x, y in zip(rem_percent, crop_size_list)), reverse=True)[:random_top_k]
    )[1]
    return center_crop(pil_image, crop_size)
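
# The score equals min(rc / r, r / rc), where r is the image aspect ratio and
# rc a candidate's, so the closest aspect ratio wins. For example, for a
# 1200x400 image (3:1), (1024, 256) (4:1, score 0.75) beats (512, 512)
# (1:1, score ~0.33):
#
#     var_center_crop(img, [(512, 512), (1024, 256)])  # -> 1024x256 crop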


def generate_crop_size_list(num_patches, patch_size, max_ratio=4.0):
    """Enumerate (width, height) crop sizes of at most num_patches patches with aspect ratio <= max_ratio."""
    assert max_ratio >= 1.0
    crop_size_list = []
    # Walk the staircase of maximal (wp, hp) pairs with wp * hp <= num_patches,
    # from the widest shape (num_patches, 1) toward the tallest (1, num_patches).
    wp, hp = num_patches, 1
    while wp > 0:
        if max(wp, hp) / min(wp, hp) <= max_ratio:
            crop_size_list.append((wp * patch_size, hp * patch_size))
        if (hp + 1) * wp <= num_patches:
            hp += 1
        else:
            wp -= 1
    return crop_size_list
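
# Small worked example: with num_patches = (512 // 32) ** 2 = 256 and
# patch_size = 32 (the values preprocess_image uses), the list contains sizes
# such as (512, 512), (1024, 256), and (256, 1024), all holding at most 256
# patches with an aspect ratio of at most 4:1:
#
#     sizes = generate_crop_size_list((512 // 32) ** 2, 32)
#     assert (512, 512) in sizes and (1024, 256) in sizes
#     assert all(w * h <= 256 * 32 * 32 for w, h in sizes)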

def add_break_line(sequence: list, H: int, W: int, new_number: int = 0) -> list:
    """Insert a newline token (new_number) after each row of a flattened H x W token grid."""
    result = []
    for i in range(H):
        start = i * W
        end = start + W
        row = sequence[start:end]
        result.extend(row + [new_number])
    return result
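
# For example, a flattened 2x3 grid gains a newline token (here 0) after each row:
#
#     assert add_break_line([1, 2, 3, 4, 5, 6], H=2, W=3) == [1, 2, 3, 0, 4, 5, 6, 0]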

def encode_img_with_breaks(img, vqvae, vae_scale_factor: int = 16):
    """Encode an image to VQ tokens with newline tokens, wrapped in image delimiter tokens."""
    orig = img.convert("RGB")
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor, do_normalize=False)
    x = image_processor.preprocess(orig).to(vqvae.device)
    latents = vqvae.encode(x).latents
    latents_bsz, channels, lat_h, lat_w = latents.shape
    # Shift raw codebook indices into the text vocabulary (offset 126356, cf. VQ_OFFSET below).
    quantized = vqvae.quantize(latents)[2][2] + 126356
    quantized = quantized.reshape(latents_bsz, lat_h, lat_w).flatten().tolist()
    # 126084 is the newline token; 126349/126350 are the <boi>/<eoi> delimiters.
    img_token = add_break_line(quantized, lat_h, lat_w, new_number=126084)
    img_token = [126349] + img_token + [126350]
    return img_token
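
# A minimal usage sketch, assuming a locally available VQModel checkpoint (the
# path is hypothetical) and a 512x512 input:
#
#     vqvae = VQModel.from_pretrained("path/to/checkpoint", subfolder="vqvae").to("cuda")
#     img = preprocess_image("example.jpg", target_size=(512, 512))
#     tokens = encode_img_with_breaks(img, vqvae)  # [<boi>, row tokens + newlines..., <eoi>]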

@torch.no_grad()
def encode_img_with_paint(
    img: Image.Image,
    vqvae: VQModel,
    *,
    mask_h_ratio: float = 1.0,      # height ratio of the masked rectangle
    mask_w_ratio: float = 0.2,      # width ratio of the masked rectangle
    gray_value: int = 127,          # gray level used for the visualization
    downsample_mode: str = "area",  # how the pixel mask is aligned to the latent grid
    dilate_latent_k: int = 0,       # optional dilation on the latent grid (in grid cells)
    mask_mode: str = "inpainting",  # "inpainting" | "outpainting"
):
    """

    Encode image with mask for inpainting/outpainting tasks

    

    Args:

        img: Input PIL image

        vqvae: VQ-VAE model for encoding

        mask_h_ratio: Height ratio for mask region (default: 1.0)

        mask_w_ratio: Width ratio for mask region (default: 0.2)

        gray_value: Gray value for mask visualization (default: 127)

        downsample_mode: Downsampling mode for mask alignment ("area", "nearest", "bilinear")

        dilate_latent_k: Dilation kernel size for latent grid (default: 0)

        mask_mode: Mask mode - "inpainting" (mask inside) or "outpainting" (mask outside)

    

    Returns:

        img_token: List[int] - Token sequence with newlines (126084) inserted at row ends;

                              masked positions = 126336, others = index + 126356

        vis_img: PIL.Image - Gray mask visualization image (consistent with mask_mode)

    

    Note:

        * Encoding uses original image strictly; mask only maps to latent grid to determine

          which tokens are set to MASK_TOKEN_ID.

        * mask_mode="inpainting": mask inside rectangle; "outpainting": mask outside rectangle (inverse).

    """
    MASK_TOKEN_ID = 126336      # mask token
    NEWLINE_TOKEN_ID = 126084   # newline token
    VQ_OFFSET = 126356          # quantization index offset

    assert mask_mode in ("inpainting", "outpainting"), "mask_mode must be 'inpainting' or 'outpainting'"

    # --- 1) Calculate center rectangle and generate visualization ---
    img = img.convert("RGB")
    W, H = img.size
    mh = int(round(H * mask_h_ratio))
    mw = int(round(W * mask_w_ratio))
    top = (H - mh) // 2
    left = (W - mw) // 2
    bottom = top + mh
    right = left + mw

    if mask_mode == "inpainting":
        vis_img = img.copy()
        draw = ImageDraw.Draw(vis_img)
        draw.rectangle([left, top, right, bottom], fill=(gray_value, gray_value, gray_value))
    else:  # outpainting: keep the rectangle, gray out everything else
        bg = Image.new("RGB", (W, H), (gray_value, gray_value, gray_value))
        crop = img.crop((left, top, right, bottom))
        bg.paste(crop, (left, top))
        vis_img = bg

    # --- 2) VQ encoding using original image ---
    vae_scale_factor = 2 ** (len(vqvae.config.block_out_channels) - 1)
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor, do_normalize=False)
    x = image_processor.preprocess(img).to(vqvae.device)  # 1 x 3 x H' x W'
    latents = vqvae.encode(x).latents                     # 1 x C x h x w
    _, _, lat_h, lat_w = latents.shape

    # Quantization indices
    quant_pack = vqvae.quantize(latents)
    indices = quant_pack[2][2].view(1, lat_h, lat_w)      # 1 x h x w, long

    # --- 3) Pixel mask -> latent grid mask (aligned with encoding input size) ---
    Hp, Wp = x.shape[-2:]
    mask_px = torch.zeros((1, 1, Hp, Wp), dtype=torch.float32, device=vqvae.device)
    # First generate mask where "rectangle inside=1, outside=0"
    top_p  = int(round(top  * Hp / H))
    left_p = int(round(left * Wp / W))
    bh_p   = int(round(mh   * Hp / H))
    bw_p   = int(round(mw   * Wp / W))
    mask_px[:, :, top_p:top_p+bh_p, left_p:left_p+bw_p] = 1.0

    # If outpainting, need to invert (outside=1, inside=0 is the masked region)
    if mask_mode == "outpainting":
        mask_px = 1.0 - mask_px

    if downsample_mode not in ("nearest", "area", "bilinear"):
        downsample_mode = "area"
    mask_lat = F.interpolate(mask_px, size=(lat_h, lat_w), mode=downsample_mode)
    mask_lat = (mask_lat > 0.5) if downsample_mode == "area" else (mask_lat >= 0.5)
    mask_lat = mask_lat[0, 0]        # h x w (bool)

    # Optional: latent grid dilation (after inversion is applied)
    if dilate_latent_k > 0:
        m = mask_lat.float().unsqueeze(0).unsqueeze(0)
        ker = 2 * dilate_latent_k + 1
        m = F.max_pool2d(m, kernel_size=ker, stride=1, padding=dilate_latent_k)
        mask_lat = (m[0, 0] > 0.5)

    # --- 4) Generate tokens: masked positions=MASK_TOKEN_ID, others=indices+VQ_OFFSET ---
    idx_flat = indices.view(-1)
    mask_flat = mask_lat.view(-1)
    tokens = torch.empty_like(idx_flat)
    tokens[mask_flat] = MASK_TOKEN_ID
    tokens[~mask_flat] = idx_flat[~mask_flat] + VQ_OFFSET
    tokens_list = tokens.tolist()

    # --- 5) Insert newline tokens (unlike encode_img_with_breaks, the sequence is
    #        not wrapped in the <boi>/<eoi> delimiters) ---
    img_token = add_break_line(tokens_list, lat_h, lat_w, NEWLINE_TOKEN_ID)
    return img_token, vis_img
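

# A minimal usage sketch for encode_img_with_paint. The checkpoint path is
# hypothetical; masked positions (126336) mark the rectangle a downstream model
# should repaint, while the remaining tokens carry the original VQ codes:
#
#     vqvae = VQModel.from_pretrained("path/to/checkpoint", subfolder="vqvae").to("cuda")
#     img = preprocess_image("example.jpg")
#     tokens, vis = encode_img_with_paint(
#         img, vqvae, mask_h_ratio=1.0, mask_w_ratio=0.2, mask_mode="inpainting"
#     )
#     vis.save("mask_preview.png")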