# Author: Tenbatsu24
# Last commit 9f37452: "change: layout improvement"
import os
import torch
import numpy as np
import gradio as gr
import torchvision.transforms.functional as TF
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoModel
from sklearn.decomposition import PCA
# ── constants ─────────────────────────────────────────────────────────────────
# Per-channel RGB mean/std of ImageNet; `preprocess` normalises inputs with them.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Side length (pixels) of one ViT patch -- the "/16" in every model name below.
PATCH_SIZE = 16

# Number of principal components kept; each one becomes one RGB channel.
PCA_COMPONENTS = 3

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Hugging Face repo id for every (architecture, method) pair. The concrete
# checkpoint inside a repo is selected at load time via the revision
# "<epoch>/<weights>".
MODEL_IDS = {
    "ViT-S/16": dict(
        DiNO="OK-AI/dino-vits16-pretrain-in1k",
        iBOT="OK-AI/ibot-vits16-pretrain-in1k",
        LeJEPA="OK-AI/lejepa-vits16-pretrain-in1k",
    ),
    "ViT-B/16": dict(
        DiNO="OK-AI/dino-vitb16-pretrain-in1k",
        iBOT="OK-AI/ibot-vitb16-pretrain-in1k",
        LeJEPA="OK-AI/lejepa-vitb16-pretrain-in1k",
    ),
}

# Method names in display order (the same for every architecture).
MODEL_KEYS = list(MODEL_IDS["ViT-S/16"])
# ── model loading (cached) ────────────────────────────────────────────────────
# Loaded models keyed by "<repo_id>@<revision>". Entries are never evicted.
_model_cache: dict[str, torch.nn.Module] = {}


def get_model(repo_id: str, revision: str) -> torch.nn.Module:
    """Return the model stored at *repo_id* / *revision*, in eval mode on DEVICE.

    The first request instantiates the model with ``trust_remote_code=True``
    and caches it; later requests for the same pair return the same object.
    Loading errors propagate to the caller and nothing is cached for them.
    """
    key = f"{repo_id}@{revision}"
    cached = _model_cache.get(key)
    if cached is not None:
        return cached

    model = AutoModel.from_pretrained(
        repo_id,
        revision=revision,
        trust_remote_code=True,
    )
    # nn.Module.eval() and nn.Module.to() both update the module in place.
    model.eval().to(DEVICE)
    _model_cache[key] = model
    return model
# ── image helpers ─────────────────────────────────────────────────────────────
def _load_placeholder_font(size: int):
    """Return a font of roughly ``size`` px, degrading gracefully.

    Tries TrueType fonts first ("arial.ttf" is often missing on Linux, where
    "DejaVuSans.ttf" is frequently available). It then tries Pillow's bundled
    scalable default, which accepts ``size`` on Pillow >= 10.1. The last
    resort is the legacy fixed-size bitmap font.
    """
    for font_name in ("arial.ttf", "DejaVuSans.ttf"):
        try:
            return ImageFont.truetype(font_name, size=size)
        except (OSError, ImportError):
            # OSError: font file not found; ImportError: Pillow built without FreeType.
            continue
    try:
        return ImageFont.load_default(size=size)
    except (TypeError, OSError, ImportError):
        # TypeError: Pillow < 10.1 does not accept ``size``.
        return ImageFont.load_default()


def create_coming_soon_image(
    image_size,
    text="COMING SOON",
    background_color=(40, 20, 20),
    text_color="white",
):
    """
    Create a square placeholder image with centred, black-outlined text.
    Args:
        image_size (int): Width and height of the square image in pixels.
        text (str): Text to display.
        background_color (tuple): RGB background colour.
        text_color (str|tuple): Fill colour of the text (outline is black).
    Returns:
        PIL.Image.Image: RGB image of size (image_size, image_size).
    """
    stroke_width = 2
    image = Image.new("RGB", (image_size, image_size), color=background_color)
    draw = ImageDraw.Draw(image)
    font = _load_placeholder_font(max(24, image_size // 12))

    # Measure the inked box *including* the outline. The box reported for
    # origin (0, 0) usually starts at a non-zero offset (glyph bearing /
    # ascender space). Subtract that offset so the visible text, not the
    # drawing origin, ends up centred.
    left, top, right, bottom = draw.textbbox(
        (0, 0), text, font=font, stroke_width=stroke_width
    )
    x = (image_size - (right - left)) // 2 - left
    y = (image_size - (bottom - top)) // 2 - top

    draw.text(
        (x, y),
        text,
        fill=text_color,
        font=font,
        stroke_width=stroke_width,
        stroke_fill="black",
    )
    return image
def resize_image_for_patches(
    image: Image.Image,
    image_size: int,
    patch_size: int = PATCH_SIZE,
) -> torch.Tensor:
    """Resize *image* onto a whole-patch grid, preserving its aspect ratio.

    The output height is ``image_size`` rounded down to a multiple of
    ``patch_size`` (never less than one patch). The width is the whole number
    of patches that best matches the source aspect ratio at that height
    (also at least one patch).

    Args:
        image: Source image (the caller converts it to RGB first).
        image_size: Desired output height in pixels.
        patch_size: ViT patch side length in pixels.

    Returns:
        Float tensor of shape (1, C, H, W) with values in [0, 1] (C == 3 for
        RGB input). H and W are multiples of ``patch_size``.
    """
    w, h = image.size
    h_patches = max(1, image_size // patch_size)
    # Derive the width from the *aligned* height so the aspect ratio is kept
    # even when image_size is not an exact multiple of patch_size.
    w_patches = max(1, round(w * h_patches / h))
    target_h = h_patches * patch_size
    target_w = w_patches * patch_size
    resized = TF.resize(image, (target_h, target_w))
    return TF.to_tensor(resized).unsqueeze(0)  # (1, C, H, W)
def preprocess(image_tensor: torch.Tensor) -> torch.Tensor:
    """ImageNet-normalise an image tensor of shape (..., 3, H, W).

    ``TF.normalize`` broadcasts over any leading batch dimensions. A
    (1, 3, H, W) input therefore yields a (1, 3, H, W) output, and batches
    larger than one keep their shape as well.

    Returns a new tensor; the input is left unmodified (``inplace`` defaults
    to False).
    """
    return TF.normalize(image_tensor, mean=IMAGENET_MEAN, std=IMAGENET_STD)
def pad_to_square(img: Image.Image, canvas_size: int) -> Image.Image:
    """Centre *img* on a dark (18, 18, 18) square canvas.

    The canvas side is the largest of ``canvas_size``, the image width and the
    image height, so the image is never cropped. Every output gets a square
    shape, which stops the Gradio row from reflowing or stretching when inputs
    have different aspect ratios.
    """
    width, height = img.size
    side = max(canvas_size, width, height)
    left = (side - width) // 2
    top = (side - height) // 2
    square = Image.new("RGB", (side, side), color=(18, 18, 18))
    square.paste(img, (left, top))
    return square
# ── PCA visualisation ─────────────────────────────────────────────────────────
def pca_vis(
    model: torch.nn.Module, image_tensor: torch.Tensor, canvas_size: int
) -> Image.Image:
    """Colour-code a model's patch features with PCA → square-padded RGB image.

    Pipeline:
    1. Run *image_tensor* through *model*.
    2. Project the patch tokens onto their first ``PCA_COMPONENTS`` whitened
       principal components.
    3. Squash the projections into RGB with a sigmoid.
    4. Upsample the patch grid back to the input resolution (nearest
       neighbour).
    5. Pad the result to a square canvas.

    Args:
        model: Backbone whose output mapping has a ``"patch_latent"`` entry.
            Assumed to hold exactly one token per patch in row-major order
            (no CLS/register tokens) -- TODO confirm against the model code.
        image_tensor: (1, 3, H, W) tensor as produced by
            ``resize_image_for_patches`` (H and W multiples of PATCH_SIZE).
        canvas_size: Minimum side length of the square output.

    Returns:
        Square RGB PIL image.

    Raises:
        ValueError: if the number of patch tokens does not match the
            (H // PATCH_SIZE) x (W // PATCH_SIZE) grid.
    """
    model_input = preprocess(image_tensor).to(DEVICE)
    with torch.inference_mode():
        outputs = model(model_input)
    patch_latent = outputs["patch_latent"][0].cpu().float()  # (num_patches, dim)

    _, _, H, W = image_tensor.shape
    h_patches = H // PATCH_SIZE
    w_patches = W // PATCH_SIZE
    # Fail with a readable message instead of an opaque `view` shape error.
    num_tokens = patch_latent.shape[0]
    if num_tokens != h_patches * w_patches:
        raise ValueError(
            f"expected {h_patches * w_patches} patch tokens for a "
            f"{h_patches}x{w_patches} grid, got {num_tokens}"
        )

    pca = PCA(n_components=PCA_COMPONENTS, whiten=True)
    projected = pca.fit_transform(patch_latent.numpy())  # (num_patches, 3)
    projected_t = torch.from_numpy(projected).view(h_patches, w_patches, PCA_COMPONENTS)
    # Whitened components have unit variance; scaling by 2 before the sigmoid
    # spreads typical values over most of the (0, 1) colour range.
    vis = torch.sigmoid(projected_t * 2.0)
    pca_array = (vis.numpy() * 255).astype(np.uint8)  # (H_p, W_p, 3)

    # Nearest-neighbour upscale keeps patch boundaries crisp. Pad to a square
    # so all outputs share the same size. A uint8 (H, W, 3) array is inferred
    # as "RGB", so the deprecated `mode=` argument is not passed.
    upscaled = Image.fromarray(pca_array).resize((W, H), Image.NEAREST)
    return pad_to_square(upscaled, canvas_size)
# ── streaming inference ───────────────────────────────────────────────────────
def run(pil_image: Image.Image, epoch: str, weight_type: str, image_size: int):
    """
    Generator: stream one PCA visualisation per (architecture, method) pair.

    Each yield is a tuple with one image per output panel. The panels are
    ordered by ``MODEL_IDS`` (architectures) and then ``MODEL_KEYS``
    (methods). The first yield shows dark placeholders; each later yield
    fills in one more panel.

    Panel contents on failure:
    - "COMING SOON" when the checkpoint cannot be loaded (e.g. the revision
      does not exist).
    - "ERROR" when a loaded model fails during inference or PCA.

    Args:
        pil_image: Uploaded image, or None if nothing was uploaded.
        epoch: Revision prefix (e.g. "ep300").
        weight_type: "student" or "teacher". LeJEPA always uses "student".
        image_size: Target height in pixels; may arrive as a string.

    Raises:
        gr.Error: if no image was uploaded.
    """
    if pil_image is None:
        raise gr.Error("Please upload an image.")

    image_size = int(image_size)
    variants = [(arch, model_key) for arch in MODEL_IDS for model_key in MODEL_KEYS]
    pending_img = Image.new("RGB", (image_size, image_size), color=(18, 18, 18))
    results = [pending_img] * len(variants)
    yield tuple(results)

    image_tensor = resize_image_for_patches(pil_image.convert("RGB"), image_size)

    for idx, (arch, model_key) in enumerate(variants):
        repo_id = MODEL_IDS[arch][model_key]
        current_weight = "student" if model_key == "LeJEPA" else weight_type
        revision = f"{epoch}/{current_weight}"
        try:
            model = get_model(repo_id, revision)
        except Exception as e:  # UI boundary: unpublished revision, network, ...
            print(f"Error loading {repo_id} ({revision}): {e}")
            results[idx] = create_coming_soon_image(image_size)
        else:
            try:
                results[idx] = pca_vis(model, image_tensor, image_size)
            except Exception as e:  # model exists but inference/PCA failed
                print(f"Error processing {repo_id} ({revision}): {e}")
                results[idx] = create_coming_soon_image(image_size, text="ERROR")
        yield tuple(results)
# ── UI ────────────────────────────────────────────────────────────────────────
# Custom stylesheet passed to gr.Blocks(css=...). Its class selectors match the
# HTML class attributes / elem_classes used when the layout is built
# (title-row, subtitle-row, arch-header, model-label, output-col, custom-footer).
# `.output-col img` forces square, contain-fitted output images (max 350px tall).
# `footer { display: none }` presumably hides Gradio's default page footer;
# `.custom-footer` is this app's own footer.
# NOTE(review): text colours are hard-coded light greys (#f3f4f6, #d1d5db),
# which presumes a dark background; in a light theme they may be hard to read.
# Confirm the intended theme, or switch to the theme's CSS variables.
# NOTE(review): the CSS comment above the `:active` rule mentions visited
# links, but that rule only styles the pressed state; visited links are set by
# the first link rule.
CSS = """
.title-row {
text-align: center;
padding: 1.5rem 0 0.25rem;
}
/* Higher contrast subtitle */
.subtitle-row {
text-align: center;
color: #d1d5db;
font-size: 0.9rem;
padding-bottom: 1rem;
}
/* Higher contrast section headers */
.arch-header {
font-size: 1.2rem;
font-weight: 700;
margin-top: 1rem;
padding-left: 0.5rem;
border-left: 4px solid #60a5fa;
color: #f3f4f6;
}
/* Brighter model labels */
.model-label {
text-align: center;
font-weight: 700;
font-size: 0.9rem;
color: #f3f4f6;
padding: 0.25rem 0;
}
/* Make links readable before AND after clicking */
.subtitle-row a,
.model-label a,
.custom-footer a,
.subtitle-row a:visited,
.model-label a:visited,
.custom-footer a:visited {
color: #93c5fd;
text-decoration: underline;
text-decoration-color: #93c5fd;
font-weight: 600;
}
/* Strong hover state */
.subtitle-row a:hover,
.model-label a:hover,
.custom-footer a:hover {
color: #dbeafe;
text-decoration-color: #dbeafe;
}
/* Prevent browsers from turning visited links purple/dark */
.subtitle-row a:active,
.model-label a:active,
.custom-footer a:active {
color: #bfdbfe;
}
.output-col {
display: flex !important;
flex-direction: column !important;
align-items: center !important;
gap: 0.25rem !important;
flex: 1 1 0% !important;
min-width: 150px !important;
}
.output-col img {
aspect-ratio: 1 / 1 !important;
object-fit: contain !important;
max-height: 350px !important;
width: 100% !important;
}
/* Improve contrast of markdown/help text */
.gradio-container p {
color: #d1d5db;
}
/* Improve dropdown labels and general form text */
.gradio-container label,
.gradio-container .form,
.gradio-container .prose {
color: #f3f4f6;
}
/* More legible footer */
.custom-footer {
text-align: center;
margin-top: 2.5rem;
padding-top: 1rem;
border-top: 1px solid #374151;
font-size: 0.85rem;
color: #d1d5db;
}
footer { display: none !important; }
"""
def _model_label_html(arch: str, model_key: str) -> str:
    """Return the caption HTML for one output panel, linking to its HF repo."""
    repo_id = MODEL_IDS[arch][model_key]
    short_arch = arch.split("-", 1)[-1]  # "ViT-S/16" -> "S/16"
    return (
        f'<div class="model-label"><a href="https://huggingface.co/{repo_id}" '
        f'target="_blank">{model_key} ({short_arch})</a></div>'
    )


def _example_rows(folder: str = "examples") -> list[list[str]]:
    """Return bundled example images as ``gr.Examples`` rows, sorted by name.

    Returns an empty list when *folder* is missing or not a directory, or when
    it holds no .jpg/.jpeg/.png/.webp files.
    """
    if not os.path.isdir(folder):
        return []
    return [
        [f"{folder}/{name}"]
        for name in sorted(os.listdir(folder))
        if name.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))
    ]


with gr.Blocks(css=CSS, title="SSL ViT PCA Visualiser") as demo:
    # HTML entities (&mdash;, &middot;) keep these literals ASCII-only, so the
    # punctuation cannot be garbled by a source-encoding mix-up.
    gr.HTML(
        """
        <div class="title-row">
          <h1 style="font-size:1.6rem; font-weight:700; margin:0;">
            SSL ViT &mdash; Patch Feature PCA
          </h1>
        </div>
        <div class="subtitle-row">
          ImageNet-1K pre-training &nbsp;&middot;&nbsp;
          <a href="https://huggingface.co/OK-AI" target="_blank">OK-AI Models</a>
        </div>
        """
    )

    with gr.Row():
        # Left column: input image, options, run button and notes.
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="pil",
                label="Input image",
                show_label=True,
            )
            with gr.Row():
                opt_epoch = gr.Dropdown(
                    choices=["ep100", "ep300"],
                    value="ep300",
                    label="Epochs",
                    interactive=True,
                )
                opt_weight = gr.Dropdown(
                    choices=["student", "teacher"],
                    value="teacher",
                    label="Weight Type",
                    info="LeJEPA always uses student",
                    interactive=True,
                )
            opt_size = gr.Dropdown(
                choices=["224", "448", "672", "1280"],
                value="672",
                label="Image Target Resolution",
                interactive=True,
            )
            run_btn = gr.Button("Visualise", variant="primary")
            gr.HTML(
                """
                <p style="font-size:0.8rem; color:#9ca3af; margin-top:0.5rem; line-height:1.5;">
                  PCA is fit on all patch tokens and projected to
                  3 components, then scaled with sigmoid for colour display.
                  Results stream seamlessly into view as individual variants complete.
                </p>
                <div class="custom-footer">
                  Models: <a href="https://huggingface.co/OK-AI" target="_blank">OK-AI on HuggingFace</a>
                  &nbsp;&middot;&nbsp;
                  Code: <a href="https://github.com/Open-Knowledge-AI/lite_ssl" target="_blank">lite_ssl</a>
                </div>
                """
            )

        # Right column: one row per architecture, one panel per method. Panels
        # are created in the same (architecture, method) order in which `run`
        # yields its results.
        with gr.Column(scale=3):
            output_targets = []
            for arch in MODEL_IDS:
                gr.HTML(f'<div class="arch-header">{arch} Grid</div>')
                with gr.Row(equal_height=True):
                    for model_key in MODEL_KEYS:
                        with gr.Column(elem_classes="output-col"):
                            gr.HTML(_model_label_html(arch, model_key))
                            output_targets.append(
                                gr.Image(show_label=False, interactive=False)
                            )

    # Named handles to the individual panels (ViT-S/16 row, then ViT-B/16 row).
    (
        out_dino_s,
        out_ibot_s,
        out_lejepa_s,
        out_dino_b,
        out_ibot_b,
        out_lejepa_b,
    ) = output_targets

    run_btn.click(
        fn=run,
        inputs=[input_image, opt_epoch, opt_weight, opt_size],
        outputs=output_targets,
    )

    # Only add the examples gallery when there is something to show.
    example_rows = _example_rows()
    if example_rows:
        gr.Examples(examples=example_rows, inputs=[input_image])


if __name__ == "__main__":
    demo.launch()