import os import torch import numpy as np import gradio as gr import torchvision.transforms.functional as TF from PIL import Image, ImageDraw, ImageFont from transformers import AutoModel from sklearn.decomposition import PCA # ── constants ───────────────────────────────────────────────────────────────── IMAGENET_MEAN = [0.485, 0.456, 0.406] IMAGENET_STD = [0.229, 0.224, 0.225] PATCH_SIZE = 16 PCA_COMPONENTS = 3 DEVICE = "cuda" if torch.cuda.is_available() else "cpu" MODEL_IDS = { "ViT-S/16": { "DiNO": "OK-AI/dino-vits16-pretrain-in1k", "iBOT": "OK-AI/ibot-vits16-pretrain-in1k", "LeJEPA": "OK-AI/lejepa-vits16-pretrain-in1k", }, "ViT-B/16": { "DiNO": "OK-AI/dino-vitb16-pretrain-in1k", "iBOT": "OK-AI/ibot-vitb16-pretrain-in1k", "LeJEPA": "OK-AI/lejepa-vitb16-pretrain-in1k", }, } MODEL_KEYS = ["DiNO", "iBOT", "LeJEPA"] # ── model loading (cached) ──────────────────────────────────────────────────── _model_cache: dict[str, torch.nn.Module] = {} def get_model(repo_id: str, revision: str) -> torch.nn.Module: cache_key = f"{repo_id}@{revision}" if cache_key not in _model_cache: model = AutoModel.from_pretrained( repo_id, revision=revision, trust_remote_code=True, ) model.eval().to(DEVICE) _model_cache[cache_key] = model return _model_cache[cache_key] # ── image helpers ───────────────────────────────────────────────────────────── def create_coming_soon_image( image_size, text="COMING SOON", background_color=(40, 20, 20), text_color="white", ): """ Create a placeholder image with centered text. Args: image_size (int): Width and height of the square image. text (str): Text to display. background_color (tuple): RGB background color. text_color (str|tuple): Text color. Returns: PIL.Image.Image """ image = Image.new("RGB", (image_size, image_size), color=background_color) draw = ImageDraw.Draw(image) try: font = ImageFont.truetype("arial.ttf", size=max(24, image_size // 12)) except Exception: font = ImageFont.load_default() bbox = draw.textbbox((0, 0), text, font=font) text_width = bbox[2] - bbox[0] text_height = bbox[3] - bbox[1] x = (image_size - text_width) // 2 y = (image_size - text_height) // 2 draw.text( (x, y), text, fill=text_color, font=font, stroke_width=2, stroke_fill="black", ) return image def resize_image_for_patches( image: Image.Image, image_size: int, patch_size: int = PATCH_SIZE, ) -> torch.Tensor: """Resize so height = image_size and width is patch-aligned, preserving aspect ratio. Returns (1, 3, H, W) float tensor.""" w, h = image.size h_patches = image_size // patch_size w_patches = max(1, round((w * image_size) / (h * patch_size))) target_h = h_patches * patch_size target_w = w_patches * patch_size resized = TF.resize(image, (target_h, target_w)) return TF.to_tensor(resized).unsqueeze(0) # (1, 3, H, W) def preprocess(image_tensor: torch.Tensor) -> torch.Tensor: """ImageNet-normalise a (1, 3, H, W) tensor.""" return TF.normalize( image_tensor.squeeze(0), mean=IMAGENET_MEAN, std=IMAGENET_STD, ).unsqueeze(0) def pad_to_square(img: Image.Image, canvas_size: int) -> Image.Image: """Letterbox/pillarbox img onto a square canvas with a dark background. Ensures all output images share the same dimensions so the Gradio row never reflows or stretches when aspect ratios differ.""" w, h = img.size size = max(w, h, canvas_size) canvas = Image.new("RGB", (size, size), color=(18, 18, 18)) canvas.paste(img, ((size - w) // 2, (size - h) // 2)) return canvas # ── PCA visualisation ───────────────────────────────────────────────────────── def pca_vis( model: torch.nn.Module, image_tensor: torch.Tensor, canvas_size: int ) -> Image.Image: """Run image through model, PCA patch features → square-padded RGB PIL image.""" model_input = preprocess(image_tensor).to(DEVICE) with torch.inference_mode(): outputs = model(model_input) patch_latent = outputs["patch_latent"][0].cpu().float() # (num_patches, dim) _, _, H, W = image_tensor.shape h_patches = H // PATCH_SIZE w_patches = W // PATCH_SIZE pca = PCA(n_components=PCA_COMPONENTS, whiten=True) projected = pca.fit_transform(patch_latent.numpy()) # (num_patches, 3) projected_t = torch.from_numpy(projected).view(h_patches, w_patches, PCA_COMPONENTS) vis = torch.sigmoid(projected_t * 2.0) pca_array = (vis.numpy() * 255).astype(np.uint8) # (H_p, W_p, 3) # nearest-neighbour upscale → pad to square so all outputs are the same size upscaled = Image.fromarray(pca_array, mode="RGB").resize((W, H), Image.NEAREST) return pad_to_square(upscaled, canvas_size) # ── streaming inference ─────────────────────────────────────────────────────── def run(pil_image: Image.Image, epoch: str, weight_type: str, image_size: int): """ Generator: yields updates sequentially across models and sizes. """ if pil_image is None: raise gr.Error("Please upload an image.") image_size = int(image_size) pending_img = Image.new("RGB", (image_size, image_size), color=(18, 18, 18)) results = [pending_img] * 6 yield tuple(results) pil_image = pil_image.convert("RGB") image_tensor = resize_image_for_patches(pil_image, image_size) idx = 0 for arch in ["ViT-S/16", "ViT-B/16"]: for model_key in MODEL_KEYS: repo_id = MODEL_IDS[arch][model_key] current_weight = "student" if model_key == "LeJEPA" else weight_type revision = f"{epoch}/{current_weight}" try: model = get_model(repo_id, revision) results[idx] = pca_vis(model, image_tensor, image_size) except Exception as e: print(f"Error processing {repo_id} ({revision}): {e}") results[idx] = create_coming_soon_image(image_size) yield tuple(results) idx += 1 # ── UI ──────────────────────────────────────────────────────────────────────── CSS = """ .title-row { text-align: center; padding: 1.5rem 0 0.25rem; } /* Higher contrast subtitle */ .subtitle-row { text-align: center; color: #d1d5db; font-size: 0.9rem; padding-bottom: 1rem; } /* Higher contrast section headers */ .arch-header { font-size: 1.2rem; font-weight: 700; margin-top: 1rem; padding-left: 0.5rem; border-left: 4px solid #60a5fa; color: #f3f4f6; } /* Brighter model labels */ .model-label { text-align: center; font-weight: 700; font-size: 0.9rem; color: #f3f4f6; padding: 0.25rem 0; } /* Make links readable before AND after clicking */ .subtitle-row a, .model-label a, .custom-footer a, .subtitle-row a:visited, .model-label a:visited, .custom-footer a:visited { color: #93c5fd; text-decoration: underline; text-decoration-color: #93c5fd; font-weight: 600; } /* Strong hover state */ .subtitle-row a:hover, .model-label a:hover, .custom-footer a:hover { color: #dbeafe; text-decoration-color: #dbeafe; } /* Prevent browsers from turning visited links purple/dark */ .subtitle-row a:active, .model-label a:active, .custom-footer a:active { color: #bfdbfe; } .output-col { display: flex !important; flex-direction: column !important; align-items: center !important; gap: 0.25rem !important; flex: 1 1 0% !important; min-width: 150px !important; } .output-col img { aspect-ratio: 1 / 1 !important; object-fit: contain !important; max-height: 350px !important; width: 100% !important; } /* Improve contrast of markdown/help text */ .gradio-container p { color: #d1d5db; } /* Improve dropdown labels and general form text */ .gradio-container label, .gradio-container .form, .gradio-container .prose { color: #f3f4f6; } /* More legible footer */ .custom-footer { text-align: center; margin-top: 2.5rem; padding-top: 1rem; border-top: 1px solid #374151; font-size: 0.85rem; color: #d1d5db; } footer { display: none !important; } """ with gr.Blocks(css=CSS, title="SSL ViT PCA Visualiser") as demo: gr.HTML("""
PCA is fit on all patch tokens and projected to 3 components, then scaled with sigmoid for colour display. Results stream seamlessly into view as individual variants complete.
""") with gr.Column(scale=3): # ── ViT-S/16 Row ── gr.HTML('