# -*- coding: utf-8 -*-
"""
Image processing utilities
"""
import random
from typing import Optional

import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw
from diffusers import VQModel
from diffusers.image_processor import VaeImageProcessor

def decode_vq_to_image(
    vq_codes: torch.LongTensor,
    save_path: Optional[str] = None,
    vae_ckpt: Optional[str] = None,
    image_height: int = 512,
    image_width: int = 512,
    vqvae: Optional[VQModel] = None,
) -> Image.Image:
"""
Decode VQ codes to image
Args:
vq_codes: VQ codes in range [0, codebook_size), shape [batch_size, seq_len]
save_path: Save path (optional, if None will not save to file)
vae_ckpt: VAE checkpoint path (optional if vqvae is provided)
image_height: Image height
image_width: Image width
vqvae: VQ-VAE model, if None will load from vae_ckpt
Returns:
PIL image
"""
device = vq_codes.device
if vqvae is None:
vqvae = VQModel.from_pretrained(vae_ckpt, subfolder="vqvae").to(device)
    # Spatial downsampling factor: each VQ-VAE down block halves the resolution
    scale = 2 ** (len(vqvae.config.block_out_channels) - 1)
    img_proc = VaeImageProcessor(vae_scale_factor=scale, do_normalize=False)
# Calculate latent space grid size
latent_height = image_height // scale
latent_width = image_width // scale
# Ensure VQ codes length matches
expected_len = latent_height * latent_width
if vq_codes.shape[1] != expected_len:
raise ValueError(
f"VQ codes length mismatch: {vq_codes.shape[1]} != {expected_len} "
f"for image size ({image_height},{image_width}) with scale {scale}"
)
    # Reshape to a 2D grid: [batch_size, seq_len] -> [batch_size, latent_height, latent_width].
    # vq_codes must already be plain codebook indices in [0, codebook_size); any
    # token-id offset (e.g. the 126356 text-vocabulary offset added at encoding
    # time) has to be removed by the caller.
    latents = vq_codes.view(vq_codes.shape[0], latent_height, latent_width).long()
# Decode
recon = vqvae.decode(
latents,
force_not_quantize=True,
shape=(vq_codes.shape[0], latent_height, latent_width, vqvae.config.latent_channels),
).sample.clip(0, 1)
# Post-process
img = img_proc.postprocess(recon.detach(), output_type="pil")[0]
# Save image (only if save_path is provided)
if save_path is not None:
img.save(save_path)
return img
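
# Example usage (illustrative sketch; the checkpoint path and the codebook size
# of 8192 are placeholder assumptions, not values fixed by this module):
#
#     codes = torch.randint(0, 8192, (1, 32 * 32))  # 32x32 grid for 512px / scale 16
#     img = decode_vq_to_image(
#         codes, save_path="decoded.png", vae_ckpt="path/to/checkpoint",
#         image_height=512, image_width=512,
#     )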

def preprocess_image(image_path: str, target_size: tuple = (512, 512)):
    """
    Preprocess an image: load it and apply a variable center crop.
    Args:
        image_path: Path to the image file
        target_size: Target size (width, height); only the total area
            (target_size[0] ** 2 pixels, in 32-pixel patches) constrains the crop
    Returns:
        Processed PIL image
    """
img = Image.open(image_path).convert("RGB")
crop_size_list = generate_crop_size_list((target_size[0] // 32) ** 2, 32)
processed_img = var_center_crop(img, crop_size_list=crop_size_list)
return processed_img
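
# Example usage (illustrative; "input.jpg" is a placeholder path):
#
#     img = preprocess_image("input.jpg", target_size=(512, 512))
#     print(img.size)  # a multiple-of-32 size drawn from generate_crop_size_list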

def calculate_vq_params(image_height: int, image_width: int, vae_scale: int = 16):
    """
    Calculate VQ-related sequence parameters for an image of the given size.
    Args:
        image_height: Image height in pixels
        image_width: Image width in pixels
        vae_scale: VAE spatial downsampling factor
    Returns:
        seq_len, newline_every, token_grid_height, token_grid_width
    """
token_grid_height = image_height // vae_scale
token_grid_width = image_width // vae_scale
seq_len = token_grid_height * token_grid_width
newline_every = token_grid_width
return seq_len, newline_every, token_grid_height, token_grid_width
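
# Worked example: a 512x512 image with vae_scale=16 yields a 32x32 token grid,
# i.e. seq_len=1024 with a newline token after every 32 tokens:
#
#     calculate_vq_params(512, 512, 16)  # -> (1024, 32, 32, 32)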

def center_crop(pil_image, crop_size):
    """Rescale the image to just cover crop_size, then take a randomly placed crop."""
while pil_image.size[0] >= 2 * crop_size[0] and pil_image.size[1] >= 2 * crop_size[1]:
pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
scale = max(crop_size[0] / pil_image.size[0], crop_size[1] / pil_image.size[1])
pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
crop_left = random.randint(0, pil_image.size[0] - crop_size[0])
crop_upper = random.randint(0, pil_image.size[1] - crop_size[1])
crop_right = crop_left + crop_size[0]
crop_lower = crop_upper + crop_size[1]
return pil_image.crop(box=(crop_left, crop_upper, crop_right, crop_lower))
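
# Example (illustrative): the image is rescaled to cover the requested size, so
# the returned crop is always exactly crop_size:
#
#     cropped = center_crop(Image.new("RGB", (1024, 640)), (300, 200))
#     assert cropped.size == (300, 200)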

def var_center_crop(pil_image, crop_size_list, random_top_k=1):
    """Crop to one of the top-k crop sizes whose aspect ratio best matches the image."""
w, h = pil_image.size
rem_percent = [min(cw / w, ch / h) / max(cw / w, ch / h) for cw, ch in crop_size_list]
crop_size = random.choice(
sorted(((x, y) for x, y in zip(rem_percent, crop_size_list)), reverse=True)[:random_top_k]
)[1]
return center_crop(pil_image, crop_size)
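
# Example (illustrative): build a list of candidate crop sizes for a 512px area
# budget, then crop an image to the best-matching aspect ratio:
#
#     sizes = generate_crop_size_list((512 // 32) ** 2, 32)
#     img = var_center_crop(Image.new("RGB", (800, 600)), sizes)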

def generate_crop_size_list(num_patches, patch_size, max_ratio=4.0):
    """Enumerate (width, height) crop sizes using at most num_patches patches of
    patch_size pixels each, with aspect ratio bounded by max_ratio."""
assert max_ratio >= 1.0
crop_size_list = []
wp, hp = num_patches, 1
while wp > 0:
if max(wp, hp) / min(wp, hp) <= max_ratio:
crop_size_list.append((wp * patch_size, hp * patch_size))
if (hp + 1) * wp <= num_patches:
hp += 1
else:
wp -= 1
return crop_size_list
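
# Worked example: with a budget of 4 patches of 32 pixels, the enumeration walks
# from wide to tall while keeping the aspect ratio within max_ratio:
#
#     generate_crop_size_list(4, 32)
#     # -> [(128, 32), (96, 32), (64, 32), (64, 64), (32, 64), (32, 96), (32, 128)]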

def add_break_line(sequence: list, H: int, W: int, new_number: int = 0) -> list:
    """Insert new_number after every W elements of a flattened H x W grid."""
result = []
for i in range(H):
start = i * W
end = start + W
row = sequence[start:end]
result.extend(row + [new_number])
return result
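
# Worked example: flatten a 2x3 grid and append the break token 0 to each row:
#
#     add_break_line([1, 2, 3, 4, 5, 6], H=2, W=3)
#     # -> [1, 2, 3, 0, 4, 5, 6, 0]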

@torch.no_grad()
def encode_img_with_breaks(img, vqvae, vae_scale_factor: int = 16):
    """Encode an image to VQ token ids, inserting a newline token after each row."""
    orig = img.convert("RGB")
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor, do_normalize=False)
    x = image_processor.preprocess(orig).to(vqvae.device)
    latents = vqvae.encode(x).latents
    latents_bsz, channels, lat_h, lat_w = latents.shape
    # Shift codebook indices by the text-vocabulary offset (126356)
    quantized = vqvae.quantize(latents)[2][2] + 126356
    quantized = quantized.reshape(latents_bsz, lat_h, lat_w).flatten().tolist()
    # 126084 is the newline token; 126349/126350 are the <boi>/<eoi> wrappers
    img_token = add_break_line(quantized, lat_h, lat_w, new_number=126084)
    img_token = [126349] + img_token + [126350]
    return img_token
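
# Example usage (illustrative sketch; the checkpoint path is a placeholder, and
# the token count assumes a 512x512 input with a downsampling factor of 16):
#
#     vq = VQModel.from_pretrained("path/to/checkpoint", subfolder="vqvae")
#     tokens = encode_img_with_breaks(Image.open("input.jpg").convert("RGB"), vq)
#     # len(tokens) == 2 + 32 * 32 + 32   (<boi>/<eoi> + codes + newline tokens)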

@torch.no_grad()
def encode_img_with_paint(
    img: Image.Image,
    vqvae: VQModel,
    *,
    mask_h_ratio: float = 1.0,      # Mask height ratio
    mask_w_ratio: float = 0.2,      # Mask width ratio
    gray_value: int = 127,          # Visualization gray value
    downsample_mode: str = "area",  # Pixel-mask alignment to the latent grid
    dilate_latent_k: int = 0,       # Optional dilation on the latent grid (in grid cells)
    mask_mode: str = "inpainting",  # "inpainting" | "outpainting"
):
"""
Encode image with mask for inpainting/outpainting tasks
Args:
img: Input PIL image
vqvae: VQ-VAE model for encoding
mask_h_ratio: Height ratio for mask region (default: 1.0)
mask_w_ratio: Width ratio for mask region (default: 0.2)
gray_value: Gray value for mask visualization (default: 127)
downsample_mode: Downsampling mode for mask alignment ("area", "nearest", "bilinear")
dilate_latent_k: Dilation kernel size for latent grid (default: 0)
mask_mode: Mask mode - "inpainting" (mask inside) or "outpainting" (mask outside)
Returns:
img_token: List[int] - Token sequence with newlines (126084) inserted at row ends;
masked positions = 126336, others = index + 126356
vis_img: PIL.Image - Gray mask visualization image (consistent with mask_mode)
Note:
* Encoding uses original image strictly; mask only maps to latent grid to determine
which tokens are set to MASK_TOKEN_ID.
* mask_mode="inpainting": mask inside rectangle; "outpainting": mask outside rectangle (inverse).
"""
MASK_TOKEN_ID = 126336 # mask token
NEWLINE_TOKEN_ID = 126084 # newline token
VQ_OFFSET = 126356 # quantization index offset
assert mask_mode in ("inpainting", "outpainting"), "mask_mode must be 'inpainting' or 'outpainting'"
# --- 1) Calculate center rectangle and generate visualization ---
img = img.convert("RGB")
W, H = img.size
mh = int(round(H * mask_h_ratio))
mw = int(round(W * mask_w_ratio))
top = (H - mh) // 2
left = (W - mw) // 2
bottom = top + mh
right = left + mw
if mask_mode == "inpainting":
vis_img = img.copy()
draw = ImageDraw.Draw(vis_img)
draw.rectangle([left, top, right, bottom], fill=(gray_value, gray_value, gray_value))
    else:  # outpainting: keep the rectangle, gray out the surroundings
bg = Image.new("RGB", (W, H), (gray_value, gray_value, gray_value))
crop = img.crop((left, top, right, bottom))
bg.paste(crop, (left, top))
vis_img = bg
# --- 2) VQ encoding using original image ---
vae_scale_factor = 2 ** (len(vqvae.config.block_out_channels) - 1)
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor, do_normalize=False)
x = image_processor.preprocess(img).to(vqvae.device) # 1 x 3 x H' x W'
latents = vqvae.encode(x).latents # 1 x C x h x w
_, _, lat_h, lat_w = latents.shape
# Quantization indices
quant_pack = vqvae.quantize(latents)
indices = quant_pack[2][2].view(1, lat_h, lat_w) # 1 x h x w, long
# --- 3) Pixel mask -> latent grid mask (aligned with encoding input size) ---
Hp, Wp = x.shape[-2:]
mask_px = torch.zeros((1, 1, Hp, Wp), dtype=torch.float32, device=vqvae.device)
    # Build a mask that is 1 inside the rectangle and 0 outside
top_p = int(round(top * Hp / H))
left_p = int(round(left * Wp / W))
bh_p = int(round(mh * Hp / H))
bw_p = int(round(mw * Wp / W))
mask_px[:, :, top_p:top_p+bh_p, left_p:left_p+bw_p] = 1.0
    # For outpainting, invert the mask: the region outside the rectangle is masked
if mask_mode == "outpainting":
mask_px = 1.0 - mask_px
    if downsample_mode not in ("nearest", "area", "bilinear"):
        downsample_mode = "area"  # fall back to a safe default
    mask_lat = F.interpolate(mask_px, size=(lat_h, lat_w), mode=downsample_mode)
    # "area" averages pixel coverage, so threshold strictly; other modes use >= 0.5
    mask_lat = (mask_lat > 0.5) if downsample_mode == "area" else (mask_lat >= 0.5)
mask_lat = mask_lat[0, 0] # h x w (bool)
# Optional: latent grid dilation (after inversion is applied)
if dilate_latent_k > 0:
m = mask_lat.float().unsqueeze(0).unsqueeze(0)
ker = 2 * dilate_latent_k + 1
m = F.max_pool2d(m, kernel_size=ker, stride=1, padding=dilate_latent_k)
mask_lat = (m[0, 0] > 0.5)
# --- 4) Generate tokens: masked positions=MASK_TOKEN_ID, others=indices+VQ_OFFSET ---
idx_flat = indices.view(-1)
mask_flat = mask_lat.view(-1)
tokens = torch.empty_like(idx_flat)
tokens[mask_flat] = MASK_TOKEN_ID
tokens[~mask_flat] = idx_flat[~mask_flat] + VQ_OFFSET
tokens_list = tokens.tolist()
    # --- 5) Insert newline tokens (the sequence is not wrapped in <boi>/<eoi> here) ---
img_token = add_break_line(tokens_list, lat_h, lat_w, NEWLINE_TOKEN_ID)
return img_token, vis_img
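
# Example usage (illustrative sketch; the checkpoint path is a placeholder):
#
#     vq = VQModel.from_pretrained("path/to/checkpoint", subfolder="vqvae")
#     tokens, vis = encode_img_with_paint(
#         Image.open("input.jpg").convert("RGB"), vq,
#         mask_w_ratio=0.5, mask_mode="outpainting",
#     )
#     vis.save("mask_preview.png")  # gray everywhere except the kept center strip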