patchify / vit_mosaic.py
daidedou
Initial commit
856dd67
"""
vit_mosaic.py
ViT-style patch mosaic generator
Supports:
- Auto grid selection (12 / 16 patches)
- Transparent or colored padding
- Rounded borders
- True rounded clipping
- Supersampling
- Downscale or keep resolution
"""
import math
import numpy as np
from PIL import Image, ImageDraw
from typing import Iterable, Tuple, Union
# A color is an (R, G, B) or (R, G, B, A) tuple of 0-255 ints, or a
# "#RRGGBB" / "#RRGGBBAA" hex string.
ColorType = Union[Tuple[int, int, int], Tuple[int, int, int, int], str]


def parse_color(color: ColorType):
    """Convert a color specification into an ``(R, G, B, A)`` tuple.

    Args:
        color: Either a 3-tuple ``(R, G, B)`` (alpha defaults to 255), a
            4-tuple ``(R, G, B, A)`` (returned unchanged), or a hex string
            ``"#RRGGBB"`` / ``"#RRGGBBAA"``. Surrounding whitespace in the
            string is ignored.

    Returns:
        A 4-tuple ``(R, G, B, A)``.

    Raises:
        ValueError: If ``color`` is not one of the accepted forms (wrong
            tuple length, missing ``#``, wrong number of hex digits, or
            non-hex characters).
    """
    error = ValueError(
        f"Color must be RGB tuple or hex string '#RRGGBB', got {color!r}"
    )

    if isinstance(color, tuple):
        if len(color) == 3:
            return (*color, 255)
        if len(color) == 4:
            # Already carries an alpha channel; appending 255 would yield an
            # invalid 5-element color.
            return tuple(color)
        raise error

    if isinstance(color, str):
        text = color.strip()
        if text.startswith("#"):
            digits = text[1:]
            # Check characters explicitly: int(..., 16) alone would accept
            # signs and underscores such as "+F".
            is_hex = all(ch in "0123456789abcdefABCDEF" for ch in digits)
            if len(digits) in (6, 8) and is_hex:
                channels = tuple(
                    int(digits[i:i + 2], 16) for i in range(0, len(digits), 2)
                )
                if len(channels) == 3:
                    return (*channels, 255)
                return channels

    raise error
def make_vit_mosaic(
    image: Image.Image,
    target_total_patches: Iterable[int] = (12, 16),
    max_long_side: int = 256,
    spacing: int = 12,
    border_thickness: int = 14,
    border_color: ColorType = "#00FFFF",
    padding_color: Union[None, ColorType] = None,
    corner_radius: int = 22,
    rounded: bool = True,
    true_clipping: bool = True,
    supersample: int = 1,
    output_scale_mode: str = "keep",  # "keep" or "downscale"
):
    """Split an image into square patches and render them as a framed mosaic.

    The image is resized so its longer side equals ``max_long_side``, centred
    on a canvas that divides evenly into ``rows x cols`` square patches, cut
    into those patches, and each patch is drawn inside a bordered tile on a
    transparent mosaic image.

    Args:
        image: Source image; converted to RGBA internally.
        target_total_patches: Candidate total patch counts. Every
            factorisation ``rows * cols`` of every candidate is considered and
            the one whose ``cols / rows`` is closest to the resized image's
            aspect ratio is used (first match wins on ties).
        max_long_side: Length in pixels of the longer side after resizing.
        spacing: Gap in pixels between tiles and around the mosaic edge.
        border_thickness: Width in pixels of the frame around each patch.
        border_color: Frame color, in any form ``parse_color`` accepts.
        padding_color: Fill for the canvas area not covered by the image;
            ``None`` means fully transparent.
        corner_radius: Corner radius in pixels of the rounded frame.
        rounded: Draw frames as rounded rectangles instead of plain ones.
        true_clipping: When ``rounded`` is set, also clip each patch with a
            rounded mask of radius ``corner_radius - border_thickness``.
        supersample: Integer factor by which all tile geometry is scaled up;
            values below 1 are treated as 1.
        output_scale_mode: ``"downscale"`` shrinks the supersampled mosaic
            back by ``supersample``; ``"keep"`` returns it at rendered size.

    Returns:
        A tuple ``(mosaic, patches)`` where ``mosaic`` is an RGBA
        ``PIL.Image.Image`` and ``patches`` is a ``uint8`` array of shape
        ``(rows * cols, patch_size, patch_size, 4)`` in row-major grid order.

    Raises:
        ValueError: If the image has a zero dimension, ``max_long_side`` is
            not positive, ``spacing`` or ``border_thickness`` is negative,
            ``output_scale_mode`` is unknown, or ``target_total_patches``
            contains no positive value.
    """
    if output_scale_mode not in ("keep", "downscale"):
        raise ValueError(
            "output_scale_mode must be 'keep' or 'downscale', "
            f"got {output_scale_mode!r}"
        )
    if max_long_side < 1:
        raise ValueError(f"max_long_side must be positive, got {max_long_side!r}")
    if spacing < 0 or border_thickness < 0:
        raise ValueError("spacing and border_thickness must be non-negative")

    border_rgba = parse_color(border_color)
    image = image.convert("RGBA")

    w, h = image.size
    if w == 0 or h == 0:
        raise ValueError("image must have non-zero width and height")

    # Integer arithmetic keeps the long side at exactly max_long_side; the
    # float form int(w * (max_long_side / w)) can truncate to one pixel less.
    # Clamp to 1 so extreme aspect ratios never produce a zero-sized side.
    long_side = max(w, h)
    new_w = max(1, int(w * max_long_side // long_side))
    new_h = max(1, int(h * max_long_side // long_side))
    image = image.resize((new_w, new_h), Image.LANCZOS)

    # Pick the grid whose cols/rows ratio best matches the image aspect.
    aspect = new_w / new_h
    best_choice = None
    best_diff = float("inf")
    for total in target_total_patches:
        for rows in range(1, total + 1):
            if total % rows:
                continue
            cols = total // rows
            diff = abs((cols / rows) - aspect)
            if diff < best_diff:
                best_diff = diff
                best_choice = (rows, cols)
    if best_choice is None:
        raise ValueError(
            "target_total_patches must contain at least one positive integer"
        )
    rows, cols = best_choice

    # Square patches large enough to cover the image in both directions.
    patch_size = max(math.ceil(new_w / cols), math.ceil(new_h / rows))
    pad_w = patch_size * cols
    pad_h = patch_size * rows

    canvas_fill = (0, 0, 0, 0) if padding_color is None else parse_color(padding_color)
    canvas = Image.new("RGBA", (pad_w, pad_h), canvas_fill)
    offset_x = (pad_w - new_w) // 2
    offset_y = (pad_h - new_h) // 2
    # alpha_composite blends correctly; paste() with the image as its own
    # mask also scales the alpha band and distorts semi-transparent pixels.
    canvas.alpha_composite(image, (offset_x, offset_y))

    # (rows*ps, cols*ps, 4) -> (rows, cols, ps, ps, 4) -> (rows*cols, ps, ps, 4)
    arr = np.array(canvas, dtype=np.uint8)
    patches = (
        arr.reshape(rows, patch_size, cols, patch_size, 4)
        .transpose(0, 2, 1, 3, 4)
        .reshape(rows * cols, patch_size, patch_size, 4)
    )

    ss = max(1, int(supersample))
    scaled_patch = patch_size * ss
    scaled_border = border_thickness * ss
    scaled_radius = corner_radius * ss
    scaled_spacing = spacing * ss

    tile_w = scaled_patch + 2 * scaled_border
    tile_h = scaled_patch + 2 * scaled_border
    mosaic_w = cols * tile_w + (cols + 1) * scaled_spacing
    mosaic_h = rows * tile_h + (rows + 1) * scaled_spacing
    mosaic = Image.new("RGBA", (mosaic_w, mosaic_h), (0, 0, 0, 0))

    # The border frame is identical for every tile: draw it once, copy per tile.
    frame = Image.new("RGBA", (tile_w, tile_h), (0, 0, 0, 0))
    frame_draw = ImageDraw.Draw(frame)
    frame_box = [0, 0, tile_w - 1, tile_h - 1]
    if rounded:
        frame_draw.rounded_rectangle(
            frame_box,
            radius=scaled_radius,
            fill=border_rgba,
        )
    else:
        frame_draw.rectangle(frame_box, fill=border_rgba)

    # The rounded clipping mask is likewise shared by all tiles.
    clip_mask = None
    if rounded and true_clipping:
        clip_mask = Image.new("L", (scaled_patch, scaled_patch), 0)
        ImageDraw.Draw(clip_mask).rounded_rectangle(
            [0, 0, scaled_patch - 1, scaled_patch - 1],
            radius=max(0, scaled_radius - scaled_border),
            fill=255,
        )

    def create_tile(patch_img):
        """Return one bordered tile containing ``patch_img``."""
        patch_img = patch_img.resize((scaled_patch, scaled_patch), Image.NEAREST)
        tile = frame.copy()
        if clip_mask is not None:
            # NOTE(review): this path uses only the rounded mask, so
            # transparent patch pixels replace the border fill, whereas the
            # unclipped path lets the fill show through - confirm intent.
            tile.paste(patch_img, (scaled_border, scaled_border), clip_mask)
        else:
            tile.alpha_composite(patch_img, (scaled_border, scaled_border))
        return tile

    for idx, patch in enumerate(patches):
        r, c = divmod(idx, cols)
        tile = create_tile(Image.fromarray(patch))
        x = scaled_spacing + c * (tile_w + scaled_spacing)
        y = scaled_spacing + r * (tile_h + scaled_spacing)
        mosaic.alpha_composite(tile, (x, y))

    if ss > 1 and output_scale_mode == "downscale":
        mosaic = mosaic.resize(
            (mosaic_w // ss, mosaic_h // ss),
            Image.LANCZOS,
        )

    return mosaic, patches