"""
vit_mosaic.py
ViT-style patch mosaic generator
Supports:
- Auto grid selection (12 / 16 patches)
- Transparent or colored padding
- Rounded borders
- True rounded clipping
- Supersampling
- Downscale or keep resolution
"""
import math
import numpy as np
from PIL import Image, ImageDraw
from typing import Iterable, Tuple, Union
# A colour is either an (R, G, B) tuple of ints or a "#RRGGBB" hex string.
ColorType = Union[Tuple[int, int, int], str]


def parse_color(color: ColorType):
    """Convert a colour specification into an opaque (R, G, B, A) tuple.

    Args:
        color: Either a 3-tuple ``(r, g, b)`` or a string of the form
            ``"#RRGGBB"`` (surrounding whitespace is ignored, hex digits
            may be upper or lower case).

    Returns:
        A 4-tuple ``(r, g, b, 255)``.

    Raises:
        ValueError: If ``color`` is not a 3-tuple or a well-formed
            ``"#RRGGBB"`` string.
    """
    if isinstance(color, tuple):
        # Only plain RGB is accepted; appending alpha to any other length
        # would produce a tuple that is not a valid RGBA colour.
        if len(color) == 3:
            return (*color, 255)
    elif isinstance(color, str):
        text = color.strip()
        digits = text[1:]
        # Check the digits explicitly: int(x, 16) alone would also accept
        # signs and whitespace, and slicing would silently truncate or
        # shorten strings of the wrong length.
        if (
            text.startswith("#")
            and len(digits) == 6
            and all(ch in "0123456789abcdefABCDEF" for ch in digits)
        ):
            r = int(digits[0:2], 16)
            g = int(digits[2:4], 16)
            b = int(digits[4:6], 16)
            return (r, g, b, 255)
    raise ValueError("Color must be RGB tuple or hex string '#RRGGBB'")
def _choose_grid(aspect: float, totals: Iterable[int]) -> Tuple[int, int]:
    """Return the (rows, cols) factorisation whose cols/rows is closest to aspect."""
    best_choice = None
    best_diff = float("inf")
    for total in totals:
        for rows in range(1, total + 1):
            if total % rows:
                continue
            cols = total // rows
            diff = abs((cols / rows) - aspect)
            # Strict "<" keeps the first candidate on ties.
            if diff < best_diff:
                best_diff = diff
                best_choice = (rows, cols)
    if best_choice is None:
        raise ValueError(
            "target_total_patches must contain at least one positive integer"
        )
    return best_choice


def make_vit_mosaic(
    image: Image.Image,
    target_total_patches: Iterable[int] = (12, 16),
    max_long_side: int = 256,
    spacing: int = 12,
    border_thickness: int = 14,
    border_color: ColorType = "#00FFFF",
    padding_color: Union[None, ColorType] = None,
    corner_radius: int = 22,
    rounded: bool = True,
    true_clipping: bool = True,
    supersample: int = 1,
    output_scale_mode: str = "keep",  # "keep" or "downscale"
) -> Tuple[Image.Image, np.ndarray]:
    """Cut an image into a grid of square patches and render them as bordered tiles.

    The image is resized so its longer side equals ``max_long_side``, centred
    on a canvas whose size is an exact multiple of the patch size, split into
    ``rows * cols`` square patches, and each patch is drawn inside a coloured
    border tile. Tiles are laid out on a transparent mosaic with ``spacing``
    pixels between them and around the outer edge.

    Args:
        image: Source PIL image; it is converted to RGBA.
        target_total_patches: Candidate patch counts. Every factorisation of
            every candidate is considered and the one whose cols/rows ratio
            is closest to the resized image's aspect ratio is used.
        max_long_side: Length in pixels of the longer side after resizing.
        spacing: Gap in pixels between tiles and around the mosaic edge.
        border_thickness: Border width in pixels on each side of a patch.
        border_color: Border fill, as accepted by ``parse_color``.
        padding_color: Canvas fill behind the centred image, as accepted by
            ``parse_color``; ``None`` means fully transparent.
        corner_radius: Corner radius in pixels of the rounded tile.
        rounded: Draw tiles as rounded rectangles instead of plain ones.
        true_clipping: When ``rounded`` is set, also clip the patch itself
            with a rounded mask of radius ``corner_radius - border_thickness``.
        supersample: Integer factor by which all tile geometry is scaled
            while drawing; values below 1 are treated as 1.
        output_scale_mode: ``"downscale"`` shrinks the finished mosaic by the
            supersample factor; ``"keep"`` returns it at the drawn size.

    Returns:
        A tuple ``(mosaic, patches)``: the RGBA mosaic image and a uint8
        array of shape ``(rows * cols, patch_size, patch_size, 4)`` holding
        the patches in row-major order, before supersampling and borders.

    Raises:
        ValueError: If a colour is malformed, the image has a zero dimension,
            ``max_long_side`` is below 1, ``target_total_patches`` yields no
            usable grid, or ``output_scale_mode`` is not recognised.
    """
    if output_scale_mode not in ("keep", "downscale"):
        raise ValueError("output_scale_mode must be 'keep' or 'downscale'")
    if max_long_side < 1:
        raise ValueError("max_long_side must be at least 1")

    border_rgba = parse_color(border_color)
    image = image.convert("RGBA")

    w, h = image.size
    if w == 0 or h == 0:
        raise ValueError("image must have non-zero width and height")
    scale = max_long_side / max(w, h)
    # Clamp to 1 so very elongated images do not collapse to a zero-sized
    # side, which would break both the resize and the aspect computation.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    image = image.resize((new_w, new_h), Image.LANCZOS)

    rows, cols = _choose_grid(new_w / new_h, target_total_patches)

    # Patches are square: use the larger of the two per-cell extents so the
    # whole image fits, then pad the canvas out to a full grid.
    patch_w = math.ceil(new_w / cols)
    patch_h = math.ceil(new_h / rows)
    patch_size = max(patch_w, patch_h)
    pad_w = patch_size * cols
    pad_h = patch_size * rows

    if padding_color is None:
        canvas = Image.new("RGBA", (pad_w, pad_h), (0, 0, 0, 0))
    else:
        canvas = Image.new("RGBA", (pad_w, pad_h), parse_color(padding_color))

    # Centre the image; its own alpha is the paste mask, so transparent
    # image pixels leave the canvas fill showing.
    offset_x = (pad_w - new_w) // 2
    offset_y = (pad_h - new_h) // 2
    canvas.paste(image, (offset_x, offset_y), image)

    # (H, W, 4) -> (rows, patch, cols, patch, 4) -> (rows*cols, patch, patch, 4)
    arr = np.array(canvas, dtype=np.uint8)
    patches = (
        arr.reshape(rows, patch_size, cols, patch_size, 4)
        .transpose(0, 2, 1, 3, 4)
        .reshape(rows * cols, patch_size, patch_size, 4)
    )

    ss = max(1, supersample)
    scaled_patch = patch_size * ss
    scaled_border = border_thickness * ss
    scaled_radius = corner_radius * ss
    scaled_spacing = spacing * ss

    tile_w = scaled_patch + 2 * scaled_border
    tile_h = scaled_patch + 2 * scaled_border
    # One gutter before every tile plus one after the last.
    mosaic_w = cols * tile_w + (cols + 1) * scaled_spacing
    mosaic_h = rows * tile_h + (rows + 1) * scaled_spacing
    mosaic = Image.new("RGBA", (mosaic_w, mosaic_h), (0, 0, 0, 0))

    def create_tile(patch_img):
        """Draw one patch centred inside a border-coloured tile."""
        patch_img = patch_img.resize(
            (scaled_patch, scaled_patch),
            Image.NEAREST,
        )
        tile = Image.new("RGBA", (tile_w, tile_h), (0, 0, 0, 0))
        draw = ImageDraw.Draw(tile)
        if rounded:
            draw.rounded_rectangle(
                [0, 0, tile_w - 1, tile_h - 1],
                radius=scaled_radius,
                fill=border_rgba,
            )
        else:
            draw.rectangle(
                [0, 0, tile_w - 1, tile_h - 1],
                fill=border_rgba,
            )

        if rounded and true_clipping:
            # Inner radius is the outer radius minus the border width so the
            # patch corners follow the border's inner edge.
            mask = Image.new("L", (scaled_patch, scaled_patch), 0)
            mask_draw = ImageDraw.Draw(mask)
            mask_draw.rounded_rectangle(
                [0, 0, scaled_patch - 1, scaled_patch - 1],
                radius=max(0, scaled_radius - scaled_border),
                fill=255,
            )
            tile.paste(patch_img, (scaled_border, scaled_border), mask)
        else:
            # NOTE(review): here the patch's own alpha is the mask, so
            # transparent patch pixels leave the border fill visible, unlike
            # the clipped branch above - confirm this difference is intended.
            tile.paste(patch_img, (scaled_border, scaled_border), patch_img)
        return tile

    for idx in range(patches.shape[0]):
        r, c = divmod(idx, cols)
        tile = create_tile(Image.fromarray(patches[idx]))
        x = scaled_spacing + c * (tile_w + scaled_spacing)
        y = scaled_spacing + r * (tile_h + scaled_spacing)
        mosaic.paste(tile, (x, y), tile)

    if ss > 1 and output_scale_mode == "downscale":
        mosaic = mosaic.resize(
            (mosaic_w // ss, mosaic_h // ss),
            Image.LANCZOS,
        )
    return mosaic, patches