# -*- coding: utf-8 -*-
"""
Image processing utilities
"""
import random
from typing import Optional

import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw
from diffusers import VQModel
from diffusers.image_processor import VaeImageProcessor

def decode_vq_to_image(
    vq_codes: torch.LongTensor,
    save_path: Optional[str] = None,
    vae_ckpt: Optional[str] = None,
    image_height: int = 512,
    image_width: int = 512,
    vqvae: Optional[VQModel] = None,
) -> Image.Image:
    """
    Decode VQ codes to an image.

    Args:
        vq_codes: VQ codes in range [0, codebook_size), shape [batch_size, seq_len]
        save_path: Save path (optional; if None, the image is not written to disk)
        vae_ckpt: VAE checkpoint path (optional if vqvae is provided)
        image_height: Image height in pixels
        image_width: Image width in pixels
        vqvae: VQ-VAE model; if None, it is loaded from vae_ckpt

    Returns:
        PIL image
    """
    device = vq_codes.device
    if vqvae is None:
        vqvae = VQModel.from_pretrained(vae_ckpt, subfolder="vqvae").to(device)
    scale = 2 ** (len(vqvae.config.block_out_channels) - 1)
    img_proc = VaeImageProcessor(vae_scale_factor=scale, do_normalize=False)
    # Compute the latent-space grid size
    latent_height = image_height // scale
    latent_width = image_width // scale
    # Ensure the VQ code sequence length matches the latent grid
    expected_len = latent_height * latent_width
    if vq_codes.shape[1] != expected_len:
        raise ValueError(
            f"VQ codes length mismatch: {vq_codes.shape[1]} != {expected_len} "
            f"for image size ({image_height},{image_width}) with scale {scale}"
        )
    # Reshape to a 2D grid: [batch_size, seq_len] -> [batch_size, latent_height, latent_width].
    # vq_codes should already be in range [0, codebook_size); no offset is applied here.
    latents = vq_codes.view(vq_codes.shape[0], latent_height, latent_width).long()
    # Decode the codebook indices back to pixels
    recon = vqvae.decode(
        latents,
        force_not_quantize=True,
        shape=(vq_codes.shape[0], latent_height, latent_width, vqvae.config.latent_channels),
    ).sample.clip(0, 1)
    # Post-process to a PIL image
    img = img_proc.postprocess(recon.detach(), output_type="pil")[0]
    # Save the image only if a path was provided
    if save_path is not None:
        img.save(save_path)
    return img

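# Example (hedged sketch): decode a random token grid. The 16x downscale and
# the 8192-entry codebook are assumptions about the checkpoint, and
# "path/to/checkpoint" is a placeholder for a real VQ-VAE directory:
#
#     codes = torch.randint(0, 8192, (1, 32 * 32))  # [batch_size, seq_len]
#     img = decode_vq_to_image(
#         codes, save_path="recon.png", vae_ckpt="path/to/checkpoint",
#         image_height=512, image_width=512,
#     )
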
def preprocess_image(image_path: str, target_size: tuple = (512, 512)):
    """
    Preprocess an image: load it, pick a matching crop size, and crop.

    Args:
        image_path: Path to the image file
        target_size: Target size (width, height); only target_size[0] is used,
            to derive the patch budget for the crop-size list

    Returns:
        Processed PIL image
    """
    img = Image.open(image_path).convert("RGB")
    crop_size_list = generate_crop_size_list((target_size[0] // 32) ** 2, 32)
    processed_img = var_center_crop(img, crop_size_list=crop_size_list)
    return processed_img

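# Example: with the default target_size=(512, 512) the patch budget is
# (512 // 32) ** 2 = 256 patches of 32 px, so candidate crops range from
# 512x512 (16x16 patches) to elongated shapes such as 1024x256 (ratio 4.0).
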
def calculate_vq_params(image_height: int, image_width: int, vae_scale: int = 16):
    """
    Calculate VQ-related sequence parameters.

    Args:
        image_height: Image height in pixels
        image_width: Image width in pixels
        vae_scale: VAE downscale factor

    Returns:
        Tuple of (seq_len, newline_every, token_grid_height, token_grid_width)
    """
    token_grid_height = image_height // vae_scale
    token_grid_width = image_width // vae_scale
    seq_len = token_grid_height * token_grid_width
    newline_every = token_grid_width
    return seq_len, newline_every, token_grid_height, token_grid_width

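# Example: calculate_vq_params(512, 512, vae_scale=16) returns (1024, 32, 32, 32),
# i.e. a 32x32 token grid, 1024 image tokens total, with a newline token due
# after every 32 tokens.
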
def center_crop(pil_image, crop_size):
    """Downscale an image, then crop a crop_size window at a random position."""
    # Halve the image with a box filter while it is at least 2x the target in
    # both dimensions, then do a final bicubic resize so the crop just fits.
    while pil_image.size[0] >= 2 * crop_size[0] and pil_image.size[1] >= 2 * crop_size[1]:
        pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
    scale = max(crop_size[0] / pil_image.size[0], crop_size[1] / pil_image.size[1])
    pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
    # Despite the name, the crop offset is sampled uniformly within the slack
    crop_left = random.randint(0, pil_image.size[0] - crop_size[0])
    crop_upper = random.randint(0, pil_image.size[1] - crop_size[1])
    crop_right = crop_left + crop_size[0]
    crop_lower = crop_upper + crop_size[1]
    return pil_image.crop(box=(crop_left, crop_upper, crop_right, crop_lower))

def var_center_crop(pil_image, crop_size_list, random_top_k=1):
    """Pick the candidate crop size whose aspect ratio best matches the image, then crop."""
    w, h = pil_image.size
    # Retained-area ratio per candidate size; 1.0 means a perfect aspect match
    rem_percent = [min(cw / w, ch / h) / max(cw / w, ch / h) for cw, ch in crop_size_list]
    crop_size = random.choice(
        sorted(((x, y) for x, y in zip(rem_percent, crop_size_list)), reverse=True)[:random_top_k]
    )[1]
    return center_crop(pil_image, crop_size)

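# Example: for an 800x600 image and candidate crop (512, 512),
# rem_percent = min(512/800, 512/600) / max(512/800, 512/600) = 0.75;
# the candidate whose aspect ratio best matches the image scores highest.
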
def generate_crop_size_list(num_patches, patch_size, max_ratio=4.0):
    """Enumerate (width, height) crop sizes covering at most num_patches patches
    whose aspect ratio does not exceed max_ratio."""
    assert max_ratio >= 1.0
    crop_size_list = []
    wp, hp = num_patches, 1
    while wp > 0:
        if max(wp, hp) / min(wp, hp) <= max_ratio:
            crop_size_list.append((wp * patch_size, hp * patch_size))
        if (hp + 1) * wp <= num_patches:
            hp += 1
        else:
            wp -= 1
    return crop_size_list

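# Example: generate_crop_size_list(4, 32) walks the patch-count frontier and
# yields [(128, 32), (96, 32), (64, 32), (64, 64), (32, 64), (32, 96), (32, 128)],
# i.e. every grid of at most 4 patches with aspect ratio at most 4.
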
def add_break_line(sequence: list, H: int, W: int, new_number: int = 0) -> list:
    """Insert a break token (new_number) after every W elements of a flattened H x W grid."""
    result = []
    for i in range(H):
        start = i * W
        end = start + W
        row = sequence[start:end]
        result.extend(row + [new_number])
    return result

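# Example: add_break_line([1, 2, 3, 4], H=2, W=2, new_number=9) returns
# [1, 2, 9, 3, 4, 9]: one break token after each row of the grid.
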
def encode_img_with_breaks(img, vqvae, vae_scale_factor: int = 16):
    """Encode an image to VQ tokens, add newline tokens, and wrap with image delimiters."""
    orig = img.convert("RGB")
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor, do_normalize=False)
    x = image_processor.preprocess(orig).to(vqvae.device)
    latents = vqvae.encode(x).latents
    latents_bsz, channels, lat_h, lat_w = latents.shape
    # Shift codebook indices into the LM vocabulary (offset 126356)
    quantized = vqvae.quantize(latents)[2][2] + 126356
    quantized = quantized.reshape(latents_bsz, lat_h, lat_w).flatten().tolist()
    # Insert a newline token (126084) after each latent row
    img_token = add_break_line(quantized, lat_h, lat_w, new_number=126084)
    # Wrap with begin-of-image (126349) and end-of-image (126350) tokens
    img_token = [126349] + img_token + [126350]
    return img_token

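# Example (hedged sketch): tokenize a PIL image with a loaded VQModel. The
# checkpoint path and image file below are placeholder assumptions:
#
#     vqvae = VQModel.from_pretrained("path/to/checkpoint", subfolder="vqvae")
#     tokens = encode_img_with_breaks(Image.open("cat.png"), vqvae)
#     # tokens = [126349, <row of code tokens>, 126084, ..., 126350]
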
def encode_img_with_paint(
    img: Image.Image,
    vqvae: VQModel,
    *,
    mask_h_ratio: float = 1.0,      # height ratio of the mask rectangle
    mask_w_ratio: float = 0.2,      # width ratio of the mask rectangle
    gray_value: int = 127,          # gray value used for the visualization
    downsample_mode: str = "area",  # how the pixel mask is aligned to the latent grid
    dilate_latent_k: int = 0,       # optional dilation on the latent grid (in grid cells)
    mask_mode: str = "inpainting",  # "inpainting" | "outpainting"
):
    """
    Encode an image with a mask for inpainting/outpainting tasks.

    Args:
        img: Input PIL image
        vqvae: VQ-VAE model used for encoding
        mask_h_ratio: Height ratio of the mask rectangle (default: 1.0)
        mask_w_ratio: Width ratio of the mask rectangle (default: 0.2)
        gray_value: Gray value for the mask visualization (default: 127)
        downsample_mode: Downsampling mode for mask alignment ("area", "nearest", "bilinear")
        dilate_latent_k: Dilation radius on the latent grid (default: 0)
        mask_mode: "inpainting" (mask inside the rectangle) or "outpainting" (mask outside)

    Returns:
        img_token: List[int], token sequence with newline tokens (126084) inserted at
            row ends; masked positions = 126336, others = codebook index + 126356
        vis_img: PIL.Image, gray-mask visualization image (consistent with mask_mode)

    Note:
        * Encoding always uses the original (unmasked) image; the mask is only mapped
          onto the latent grid to decide which tokens are replaced with MASK_TOKEN_ID.
        * mask_mode="inpainting" masks inside the rectangle; "outpainting" masks
          outside it (the inverse region).
    """
    MASK_TOKEN_ID = 126336     # mask token
    NEWLINE_TOKEN_ID = 126084  # newline token
    VQ_OFFSET = 126356         # offset added to quantization indices
    assert mask_mode in ("inpainting", "outpainting"), "mask_mode must be 'inpainting' or 'outpainting'"
    # --- 1) Compute the centered rectangle and build the visualization ---
    img = img.convert("RGB")
    W, H = img.size
    mh = int(round(H * mask_h_ratio))
    mw = int(round(W * mask_w_ratio))
    top = (H - mh) // 2
    left = (W - mw) // 2
    bottom = top + mh
    right = left + mw
    if mask_mode == "inpainting":
        # Gray out the rectangle interior
        vis_img = img.copy()
        draw = ImageDraw.Draw(vis_img)
        draw.rectangle([left, top, right, bottom], fill=(gray_value, gray_value, gray_value))
    else:  # outpainting: keep the rectangle, gray out everything around it
        bg = Image.new("RGB", (W, H), (gray_value, gray_value, gray_value))
        crop = img.crop((left, top, right, bottom))
        bg.paste(crop, (left, top))
        vis_img = bg
    # --- 2) VQ-encode the original (unmasked) image ---
    vae_scale_factor = 2 ** (len(vqvae.config.block_out_channels) - 1)
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor, do_normalize=False)
    x = image_processor.preprocess(img).to(vqvae.device)  # 1 x 3 x H' x W'
    latents = vqvae.encode(x).latents                     # 1 x C x h x w
    _, _, lat_h, lat_w = latents.shape
    # Quantization indices
    quant_pack = vqvae.quantize(latents)
    indices = quant_pack[2][2].view(1, lat_h, lat_w)  # 1 x h x w, long
    # --- 3) Pixel mask -> latent-grid mask (aligned with the encoder input size) ---
    Hp, Wp = x.shape[-2:]
    mask_px = torch.zeros((1, 1, Hp, Wp), dtype=torch.float32, device=vqvae.device)
    # First build a mask that is 1 inside the rectangle and 0 outside
    top_p = int(round(top * Hp / H))
    left_p = int(round(left * Wp / W))
    bh_p = int(round(mh * Hp / H))
    bw_p = int(round(mw * Wp / W))
    mask_px[:, :, top_p:top_p + bh_p, left_p:left_p + bw_p] = 1.0
    # For outpainting, invert it (outside = 1 becomes the masked region)
    if mask_mode == "outpainting":
        mask_px = 1.0 - mask_px
    if downsample_mode not in ("nearest", "area", "bilinear"):
        downsample_mode = "area"
    mask_lat = F.interpolate(mask_px, size=(lat_h, lat_w), mode=downsample_mode)
    mask_lat = (mask_lat > 0.5) if downsample_mode == "area" else (mask_lat >= 0.5)
    mask_lat = mask_lat[0, 0]  # h x w (bool)
    # Optional: dilate the latent-grid mask (applied after any inversion)
    if dilate_latent_k > 0:
        m = mask_lat.float().unsqueeze(0).unsqueeze(0)
        ker = 2 * dilate_latent_k + 1
        m = F.max_pool2d(m, kernel_size=ker, stride=1, padding=dilate_latent_k)
        mask_lat = m[0, 0] > 0.5
    # --- 4) Build tokens: masked positions = MASK_TOKEN_ID, others = index + VQ_OFFSET ---
    idx_flat = indices.view(-1)
    mask_flat = mask_lat.view(-1)
    tokens = torch.empty_like(idx_flat)
    tokens[mask_flat] = MASK_TOKEN_ID
    tokens[~mask_flat] = idx_flat[~mask_flat] + VQ_OFFSET
    tokens_list = tokens.tolist()
    # --- 5) Insert newlines (not wrapped in <boi>/<eoi>, consistent with the current return) ---
    img_token = add_break_line(tokens_list, lat_h, lat_w, NEWLINE_TOKEN_ID)
    return img_token, vis_img
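

if __name__ == "__main__":
    # Hedged smoke test: the checkpoint path and image file below are
    # placeholder assumptions, not locations shipped with this module.
    vqvae = VQModel.from_pretrained("path/to/checkpoint", subfolder="vqvae")
    image = preprocess_image("example.jpg")
    tokens, vis = encode_img_with_paint(image, vqvae, mask_mode="inpainting")
    vis.save("mask_vis.png")
    print(f"token sequence length (with newline tokens): {len(tokens)}")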