""" vit_mosaic.py ViT-style patch mosaic generator Supports: - Auto grid selection (12 / 16 patches) - Transparent or colored padding - Rounded borders - True rounded clipping - Supersampling - Downscale or keep resolution """ import math import numpy as np from PIL import Image, ImageDraw from typing import Iterable, Tuple, Union ColorType = Union[Tuple[int, int, int], str] def parse_color(color: ColorType): if isinstance(color, tuple): return (*color, 255) if isinstance(color, str): color = color.strip() if color.startswith("#"): r = int(color[1:3], 16) g = int(color[3:5], 16) b = int(color[5:7], 16) return (r, g, b, 255) raise ValueError("Color must be RGB tuple or hex string '#RRGGBB'") def make_vit_mosaic( image: Image.Image, target_total_patches: Iterable[int] = (12, 16), max_long_side: int = 256, spacing: int = 12, border_thickness: int = 14, border_color: ColorType = "#00FFFF", padding_color: Union[None, ColorType] = None, corner_radius: int = 22, rounded: bool = True, true_clipping: bool = True, supersample: int = 1, output_scale_mode: str = "keep", # "keep" or "downscale" ): border_rgba = parse_color(border_color) image = image.convert("RGBA") w, h = image.size scale = max_long_side / max(w, h) new_w = int(w * scale) new_h = int(h * scale) image = image.resize((new_w, new_h), Image.LANCZOS) aspect = new_w / new_h best_choice = None best_diff = float("inf") for total in target_total_patches: for rows in range(1, total + 1): if total % rows == 0: cols = total // rows diff = abs((cols / rows) - aspect) if diff < best_diff: best_diff = diff best_choice = (rows, cols) rows, cols = best_choice patch_w = math.ceil(new_w / cols) patch_h = math.ceil(new_h / rows) patch_size = max(patch_w, patch_h) pad_w = patch_size * cols pad_h = patch_size * rows if padding_color is None: canvas = Image.new("RGBA", (pad_w, pad_h), (0, 0, 0, 0)) else: canvas = Image.new("RGBA", (pad_w, pad_h), parse_color(padding_color)) offset_x = (pad_w - new_w) // 2 offset_y = (pad_h - new_h) // 2 canvas.paste(image, (offset_x, offset_y), image) arr = np.array(canvas, dtype=np.uint8) patches = ( arr.reshape(rows, patch_size, cols, patch_size, 4) .transpose(0, 2, 1, 3, 4) .reshape(rows * cols, patch_size, patch_size, 4) ) ss = max(1, supersample) scaled_patch = patch_size * ss scaled_border = border_thickness * ss scaled_radius = corner_radius * ss scaled_spacing = spacing * ss tile_w = scaled_patch + 2 * scaled_border tile_h = scaled_patch + 2 * scaled_border mosaic_w = cols * tile_w + (cols + 1) * scaled_spacing mosaic_h = rows * tile_h + (rows + 1) * scaled_spacing mosaic = Image.new("RGBA", (mosaic_w, mosaic_h), (0, 0, 0, 0)) def create_tile(patch_img): patch_img = patch_img.resize( (scaled_patch, scaled_patch), Image.NEAREST ) tile = Image.new("RGBA", (tile_w, tile_h), (0, 0, 0, 0)) draw = ImageDraw.Draw(tile) if rounded: draw.rounded_rectangle( [0, 0, tile_w - 1, tile_h - 1], radius=scaled_radius, fill=border_rgba, ) else: draw.rectangle( [0, 0, tile_w - 1, tile_h - 1], fill=border_rgba, ) if rounded and true_clipping: mask = Image.new("L", (scaled_patch, scaled_patch), 0) mask_draw = ImageDraw.Draw(mask) mask_draw.rounded_rectangle( [0, 0, scaled_patch - 1, scaled_patch - 1], radius=max(0, scaled_radius - scaled_border), fill=255, ) tile.paste(patch_img, (scaled_border, scaled_border), mask) else: tile.paste(patch_img, (scaled_border, scaled_border), patch_img) return tile for idx in range(patches.shape[0]): r = idx // cols c = idx % cols patch_img = Image.fromarray(patches[idx]) tile = create_tile(patch_img) x = scaled_spacing + c * (tile_w + scaled_spacing) y = scaled_spacing + r * (tile_h + scaled_spacing) mosaic.paste(tile, (x, y), tile) if ss > 1 and output_scale_mode == "downscale": mosaic = mosaic.resize( (mosaic_w // ss, mosaic_h // ss), Image.LANCZOS ) return mosaic, patches