File size: 4,740 Bytes
856dd67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""
vit_mosaic.py

ViT-style patch mosaic generator
Supports:
- Auto grid selection (12 / 16 patches)
- Transparent or colored padding
- Rounded borders
- True rounded clipping
- Supersampling
- Downscale or keep resolution
"""

import math
import numpy as np
from PIL import Image, ImageDraw
from typing import Iterable, Tuple, Union


ColorType = Union[Tuple[int, int, int], Tuple[int, int, int, int], str]


def parse_color(color: ColorType) -> Tuple[int, int, int, int]:
    """Normalize a color specification to an ``(r, g, b, a)`` tuple.

    Accepted forms:
      * an ``(r, g, b)`` tuple -> alpha 255 is appended;
      * an ``(r, g, b, a)`` tuple -> returned as a tuple unchanged;
      * a hex string ``'#RGB'``, ``'#RRGGBB'`` or ``'#RRGGBBAA'``
        (case-insensitive, surrounding whitespace ignored); alpha defaults
        to 255 when not given.

    Raises:
        ValueError: for any other input, including malformed hex strings.
    """
    if isinstance(color, tuple):
        if len(color) == 3:
            return (*color, 255)
        if len(color) == 4:
            return tuple(color)
    elif isinstance(color, str):
        text = color.strip()
        if text.startswith("#"):
            digits = text[1:]
            if len(digits) == 3:
                # Shorthand '#RGB' means '#RRGGBB'.
                digits = "".join(ch * 2 for ch in digits)
            # Validate explicitly: int(..., 16) alone would accept signs and
            # embedded whitespace (e.g. "-F" -> -15).
            if len(digits) in (6, 8) and all(
                ch in "0123456789abcdefABCDEF" for ch in digits
            ):
                channels = tuple(
                    int(digits[i:i + 2], 16) for i in range(0, len(digits), 2)
                )
                return channels if len(channels) == 4 else (*channels, 255)
    raise ValueError(
        "Color must be an RGB/RGBA tuple or a hex string "
        "'#RGB', '#RRGGBB' or '#RRGGBBAA'"
    )


def _choose_grid(aspect: float, totals: Iterable[int]) -> Tuple[int, int]:
    """Pick the ``(rows, cols)`` grid whose ``cols / rows`` best matches ``aspect``.

    Every factorization ``rows * cols == total`` of every ``total`` in
    ``totals`` is a candidate. Ties keep the first candidate in iteration
    order (totals in the given order, rows ascending).

    Raises:
        ValueError: if ``totals`` contains no positive integer.
    """
    candidates = [
        (rows, total // rows)
        for total in totals
        for rows in range(1, total + 1)
        if total % rows == 0
    ]
    if not candidates:
        raise ValueError(
            "target_total_patches must contain at least one positive integer"
        )
    # min() returns the first of several equal minima, matching a strict "<" scan.
    return min(candidates, key=lambda rc: abs(rc[1] / rc[0] - aspect))


def make_vit_mosaic(
    image: Image.Image,
    target_total_patches: Iterable[int] = (12, 16),
    max_long_side: int = 256,
    spacing: int = 12,
    border_thickness: int = 14,
    border_color: ColorType = "#00FFFF",
    padding_color: Union[None, ColorType] = None,
    corner_radius: int = 22,
    rounded: bool = True,
    true_clipping: bool = True,
    supersample: int = 1,
    output_scale_mode: str = "keep",  # "keep" or "downscale"
):
    """Split an image into square patches and lay them out as a framed mosaic.

    The image is resized so its longer side equals ``max_long_side``. A
    ``rows x cols`` grid is chosen from ``target_total_patches`` to best
    match the aspect ratio. The image is then centered on a padded canvas
    that divides evenly into square patches. Finally, each patch is drawn
    inside a bordered tile on a transparent mosaic.

    Args:
        image: Source image; converted to RGBA.
        target_total_patches: Candidate patch counts; every factorization of
            each count is considered as a grid shape.
        max_long_side: Length in pixels of the resized image's longer side.
        spacing: Gap (before supersampling) between adjacent tiles and
            between the tiles and the mosaic edge.
        border_thickness: Frame width around each patch (before supersampling).
        border_color: Frame color (RGB/RGBA tuple or hex string).
        padding_color: Fill for the area added to reach a whole grid;
            ``None`` leaves it transparent.
        corner_radius: Outer corner radius of each tile when ``rounded``;
            the inner clip radius is ``corner_radius - border_thickness``
            (at least 0).
        rounded: Draw tiles with rounded corners.
        true_clipping: When ``rounded``, clip each patch to a rounded inner
            rectangle so the frame color fills the clipped-off corners.
        supersample: Integer factor applied to all drawing dimensions;
            values below 1 are treated as 1.
        output_scale_mode: ``"keep"`` returns the supersampled mosaic;
            ``"downscale"`` resizes it back down by ``supersample``.

    Returns:
        ``(mosaic, patches)``: the RGBA mosaic image and a ``uint8`` array
        of shape ``(rows * cols, patch_size, patch_size, 4)``. The array
        holds the raw (not supersampled) patches in row-major order.

    Raises:
        ValueError: on an empty image, a non-positive ``max_long_side``,
            negative ``spacing``/``border_thickness``, an unknown
            ``output_scale_mode``, or when ``target_total_patches`` has no
            positive entry.
    """
    if output_scale_mode not in ("keep", "downscale"):
        raise ValueError(
            f"output_scale_mode must be 'keep' or 'downscale', got {output_scale_mode!r}"
        )
    if max_long_side <= 0:
        raise ValueError("max_long_side must be positive")
    if spacing < 0 or border_thickness < 0:
        raise ValueError("spacing and border_thickness must be non-negative")

    border_rgba = parse_color(border_color)
    image = image.convert("RGBA")
    w, h = image.size
    if w == 0 or h == 0:
        raise ValueError("image must have non-zero width and height")

    # Integer arithmetic keeps the long side exactly max_long_side: float
    # scaling such as (m / L) * L can land just below m and truncate to m - 1.
    # Clamp to 1 so a very thin image never collapses to a zero-sized axis.
    long_side = max(w, h)
    new_w = max(1, int(w * max_long_side // long_side))
    new_h = max(1, int(h * max_long_side // long_side))
    image = image.resize((new_w, new_h), Image.LANCZOS)

    rows, cols = _choose_grid(new_w / new_h, target_total_patches)

    # Patches are square: taking the larger per-axis size guarantees the
    # padded canvas covers the whole image on both axes.
    patch_size = max(math.ceil(new_w / cols), math.ceil(new_h / rows))
    pad_w = patch_size * cols
    pad_h = patch_size * rows

    fill = (0, 0, 0, 0) if padding_color is None else parse_color(padding_color)
    canvas = Image.new("RGBA", (pad_w, pad_h), fill)
    # alpha_composite is a proper "over" blend. paste() with the image as its
    # own mask would multiply alpha by itself and darken semi-transparent pixels.
    canvas.alpha_composite(image, dest=((pad_w - new_w) // 2, (pad_h - new_h) // 2))

    arr = np.array(canvas, dtype=np.uint8)
    # (H, W, 4) -> (rows, p, cols, p, 4) -> (rows, cols, p, p, 4) -> (N, p, p, 4)
    patches = (
        arr.reshape(rows, patch_size, cols, patch_size, 4)
        .transpose(0, 2, 1, 3, 4)
        .reshape(rows * cols, patch_size, patch_size, 4)
    )

    ss = max(1, supersample)
    scaled_patch = patch_size * ss
    scaled_border = border_thickness * ss
    scaled_radius = corner_radius * ss
    scaled_spacing = spacing * ss
    tile_size = scaled_patch + 2 * scaled_border

    # The frame and the clip mask are identical for every tile: build them once.
    frame = Image.new("RGBA", (tile_size, tile_size), (0, 0, 0, 0))
    frame_box = [0, 0, tile_size - 1, tile_size - 1]
    if rounded:
        ImageDraw.Draw(frame).rounded_rectangle(
            frame_box, radius=scaled_radius, fill=border_rgba
        )
    else:
        ImageDraw.Draw(frame).rectangle(frame_box, fill=border_rgba)

    clip_mask = None
    if rounded and true_clipping:
        clip_mask = Image.new("L", (scaled_patch, scaled_patch), 0)
        ImageDraw.Draw(clip_mask).rounded_rectangle(
            [0, 0, scaled_patch - 1, scaled_patch - 1],
            radius=max(0, scaled_radius - scaled_border),
            fill=255,
        )

    mosaic_w = cols * tile_size + (cols + 1) * scaled_spacing
    mosaic_h = rows * tile_size + (rows + 1) * scaled_spacing
    mosaic = Image.new("RGBA", (mosaic_w, mosaic_h), (0, 0, 0, 0))

    inner = (scaled_border, scaled_border)
    for idx, patch in enumerate(patches):
        r, c = divmod(idx, cols)
        # NEAREST keeps hard pixel edges when upscaling for supersampling.
        patch_img = Image.fromarray(patch).resize(
            (scaled_patch, scaled_patch), Image.NEAREST
        )
        tile = frame.copy()
        if clip_mask is not None:
            # Straight copy inside the rounded mask (patch alpha preserved).
            tile.paste(patch_img, inner, clip_mask)
        else:
            # Composite over the frame, so transparent patch pixels show the
            # frame color underneath.
            tile.alpha_composite(patch_img, dest=inner)

        x = scaled_spacing + c * (tile_size + scaled_spacing)
        y = scaled_spacing + r * (tile_size + scaled_spacing)
        mosaic.alpha_composite(tile, dest=(x, y))

    if ss > 1 and output_scale_mode == "downscale":
        # mosaic_w and mosaic_h are exact multiples of ss, so this is lossless in size.
        mosaic = mosaic.resize((mosaic_w // ss, mosaic_h // ss), Image.LANCZOS)

    return mosaic, patches