| import os |
|
|
| import torch |
| import numpy as np |
| import gradio as gr |
| import torchvision.transforms.functional as TF |
|
|
| from PIL import Image, ImageDraw, ImageFont |
|
|
| from transformers import AutoModel |
| from sklearn.decomposition import PCA |
|
|
| |
|
|
# ImageNet per-channel statistics used to normalise model inputs (see `preprocess`).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# ViT patch side in pixels, and how many PCA components are mapped to RGB.
PATCH_SIZE = 16
PCA_COMPONENTS = 3

# Run on the GPU whenever torch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Self-supervised methods shown for every architecture, in display order.
MODEL_KEYS = ["DiNO", "iBOT", "LeJEPA"]

# Hugging Face repo id per architecture and method. Every repo follows the
# pattern "OK-AI/<method, lower-cased>-<arch slug>-pretrain-in1k".
MODEL_IDS = {
    arch: {
        method: f"OK-AI/{method.lower()}-{slug}-pretrain-in1k"
        for method in MODEL_KEYS
    }
    for arch, slug in (("ViT-S/16", "vits16"), ("ViT-B/16", "vitb16"))
}

# Loaded models keyed by "<repo_id>@<revision>"; populated by `get_model`.
_model_cache: dict[str, torch.nn.Module] = {}
|
|
|
|
def get_model(repo_id: str, revision: str) -> torch.nn.Module:
    """Return the model for ``repo_id`` at ``revision``, loading it on first use.

    A freshly loaded model is switched to eval mode, moved to ``DEVICE`` and
    memoised in ``_model_cache`` under ``"<repo_id>@<revision>"``. Later calls
    with the same pair return that same object.

    NOTE: ``trust_remote_code=True`` executes Python code shipped with the
    checkpoint repository, so only point this at repositories you trust.
    """
    key = f"{repo_id}@{revision}"
    try:
        return _model_cache[key]
    except KeyError:
        pass

    loaded = AutoModel.from_pretrained(
        repo_id,
        revision=revision,
        trust_remote_code=True,
    )
    loaded.eval().to(DEVICE)
    _model_cache[key] = loaded
    return loaded
|
|
|
|
| |
def create_coming_soon_image(
    image_size,
    text="COMING SOON",
    background_color=(40, 20, 20),
    text_color="white",
):
    """
    Create a square placeholder image with the text centred on it.

    Args:
        image_size (int): Width and height of the square image, in pixels.
        text (str): Text to display.
        background_color (tuple): RGB background color.
        text_color (str|tuple): Text fill color (drawn with a black outline).

    Returns:
        PIL.Image.Image: A new RGB image of size ``(image_size, image_size)``.
    """
    stroke_width = 2
    font_size = max(24, image_size // 12)

    image = Image.new("RGB", (image_size, image_size), color=background_color)
    draw = ImageDraw.Draw(image)

    try:
        font = ImageFont.truetype("arial.ttf", size=font_size)
    except (OSError, ImportError):
        # Arial is often unavailable (e.g. on Linux hosts), and FreeType
        # support may be missing. Fall back to Pillow's built-in font. Pillow
        # >= 10.1 can scale it; older versions reject ``size`` with TypeError.
        try:
            font = ImageFont.load_default(size=font_size)
        except TypeError:
            font = ImageFont.load_default()

    # textbbox is measured relative to the drawing origin, and its left/top
    # are usually non-zero (glyph bearings). Subtract them so the ink itself
    # is centred. Measure with the stroke so the outline is centred too.
    left, top, right, bottom = draw.textbbox(
        (0, 0), text, font=font, stroke_width=stroke_width
    )
    x = (image_size - (right - left)) // 2 - left
    y = (image_size - (bottom - top)) // 2 - top

    draw.text(
        (x, y),
        text,
        fill=text_color,
        font=font,
        stroke_width=stroke_width,
        stroke_fill="black",
    )
    return image
|
|
|
|
def resize_image_for_patches(
    image: Image.Image,
    image_size: int,
    patch_size: int = PATCH_SIZE,
) -> torch.Tensor:
    """Resize ``image`` to patch-aligned dimensions, preserving aspect ratio.

    The height is ``image_size`` rounded down to a multiple of ``patch_size``,
    and is never less than one patch. The width is scaled by the same factor
    as the height and rounded to the nearest multiple of ``patch_size``, also
    at least one patch. The result therefore tiles exactly into ViT patches.

    Args:
        image: Source PIL image.
        image_size: Desired output height in pixels.
        patch_size: Patch side length in pixels.

    Returns:
        A ``(1, C, H, W)`` float tensor produced by ``TF.to_tensor`` (C == 3 for
        RGB input). H and W are multiples of ``patch_size``.
    """
    w, h = image.size
    # max(1, ...) avoids a zero-height resize when image_size < patch_size.
    h_patches = max(1, image_size // patch_size)
    target_h = h_patches * patch_size
    # Scale the width by the height actually used (target_h), so the aspect
    # ratio is kept even when image_size is not a multiple of patch_size.
    w_patches = max(1, round((w * target_h) / (h * patch_size)))
    target_w = w_patches * patch_size
    resized = TF.resize(image, (target_h, target_w))
    return TF.to_tensor(resized).unsqueeze(0)
|
|
|
|
def preprocess(image_tensor: torch.Tensor) -> torch.Tensor:
    """ImageNet-normalise an image tensor.

    Args:
        image_tensor: Float tensor of shape ``(B, 3, H, W)``; the app passes
            ``B == 1``. A single ``(3, H, W)`` image is also accepted and is
            given a leading batch dimension.

    Returns:
        A new ``(B, 3, H, W)`` tensor, normalised per channel with
        ``IMAGENET_MEAN`` / ``IMAGENET_STD``. The input is not modified.
    """
    if image_tensor.ndim == 3:
        image_tensor = image_tensor.unsqueeze(0)
    # TF.normalize operates on (..., C, H, W) directly, so any batch size
    # works. Squeezing dim 0 first only worked for B == 1 and produced a 5-D
    # tensor otherwise.
    return TF.normalize(image_tensor, mean=IMAGENET_MEAN, std=IMAGENET_STD)
|
|
|
|
def pad_to_square(img: Image.Image, canvas_size: int) -> Image.Image:
    """Centre ``img`` on a dark (18, 18, 18) square canvas.

    The canvas side is the largest of the image width, the image height and
    ``canvas_size``, so the image is never cropped. Giving every output the
    same square shape keeps the Gradio grid from reflowing or stretching when
    inputs have different aspect ratios.
    """
    width, height = img.size
    side = max(width, height, canvas_size)
    top_left = ((side - width) // 2, (side - height) // 2)

    padded = Image.new("RGB", (side, side), color=(18, 18, 18))
    padded.paste(img, top_left)
    return padded
|
|
|
|
| |
|
|
|
|
def pca_vis(
    model: torch.nn.Module, image_tensor: torch.Tensor, canvas_size: int
) -> Image.Image:
    """Run the image through the model and turn its patch features into a
    square-padded RGB PIL image via PCA.

    Steps:
    1. Normalise the image and run it through ``model``.
    2. Take the features in ``outputs["patch_latent"][0]`` and project them
       onto their first ``PCA_COMPONENTS`` whitened principal components.
    3. Squash the projections with ``sigmoid(2x)`` and use them as RGB.
    4. Upscale the patch grid (nearest neighbour) to the input resolution.
    5. Letterbox the result onto a square canvas with ``pad_to_square``.

    Args:
        model: Model whose output supports ``outputs["patch_latent"]``.
        image_tensor: ``(1, 3, H, W)`` tensor, as produced by
            ``resize_image_for_patches`` (H and W are multiples of ``PATCH_SIZE``).
        canvas_size: Minimum side length of the square output image.

    Returns:
        Square RGB PIL image.

    Raises:
        ValueError: If the number of patch tokens does not match the
            ``(H // PATCH_SIZE) x (W // PATCH_SIZE)`` grid.
    """
    model_input = preprocess(image_tensor).to(DEVICE)
    with torch.inference_mode():
        outputs = model(model_input)

    # NOTE(review): assumes patch_latent is (batch, num_patches, dim) and
    # excludes CLS/register tokens; confirm against the model implementation.
    patch_latent = outputs["patch_latent"][0].cpu().float()

    _, _, H, W = image_tensor.shape
    h_patches = H // PATCH_SIZE
    w_patches = W // PATCH_SIZE
    num_tokens = patch_latent.shape[0]
    if num_tokens != h_patches * w_patches:
        raise ValueError(
            f"Expected {h_patches * w_patches} patch tokens for a "
            f"{h_patches}x{w_patches} grid, got {num_tokens}"
        )

    # A fixed random_state keeps the colours reproducible. For inputs this
    # size, sklearn's 'auto' solver may choose a randomized SVD.
    pca = PCA(n_components=PCA_COMPONENTS, whiten=True, random_state=0)
    projected = pca.fit_transform(patch_latent.numpy())

    grid = torch.from_numpy(projected).reshape(h_patches, w_patches, PCA_COMPONENTS)
    vis = torch.sigmoid(grid * 2.0)
    pca_array = (vis.numpy() * 255).astype(np.uint8)

    # A uint8 (h, w, 3) array is inferred as RGB. The ``mode`` argument of
    # Image.fromarray is deprecated in recent Pillow releases.
    upscaled = Image.fromarray(pca_array).resize((W, H), Image.NEAREST)
    return pad_to_square(upscaled, canvas_size)
|
|
|
|
| |
|
|
|
|
def run(pil_image: Image.Image, epoch: str, weight_type: str, image_size: int):
    """Stream PCA visualisations for every (architecture, method) pair.

    This is a generator used as the Gradio click handler. It first yields a
    tuple of dark placeholder images, one per output slot. It then re-yields
    the whole tuple each time another model finishes, so results appear
    progressively. Slots are ordered architecture-major (``MODEL_IDS`` order),
    then by method (``MODEL_KEYS`` order).

    Args:
        pil_image: Uploaded image, or ``None`` if nothing was uploaded.
        epoch: Checkpoint epoch tag, e.g. ``"ep300"``. It forms the Hub
            revision ``"<epoch>/<weight>"``.
        weight_type: ``"student"`` or ``"teacher"``. Ignored for LeJEPA, which
            always uses ``"student"``.
        image_size: Target height in pixels. It may arrive as a string from
            the dropdown and is converted with ``int``.

    Raises:
        gr.Error: If no image was provided (raised on the first iteration).
    """
    if pil_image is None:
        raise gr.Error("Please upload an image.")

    image_size = int(image_size)
    pending_img = Image.new("RGB", (image_size, image_size), color=(18, 18, 18))

    # One job per output slot, derived from the model table rather than a
    # hard-coded count and architecture list.
    jobs = [
        (model_key, repos[model_key])
        for repos in MODEL_IDS.values()
        for model_key in MODEL_KEYS
    ]

    # Show placeholders immediately so the grid has a stable layout while
    # models load.
    results = [pending_img] * len(jobs)
    yield tuple(results)

    pil_image = pil_image.convert("RGB")
    image_tensor = resize_image_for_patches(pil_image, image_size)

    for idx, (model_key, repo_id) in enumerate(jobs):
        current_weight = "student" if model_key == "LeJEPA" else weight_type
        revision = f"{epoch}/{current_weight}"

        try:
            model = get_model(repo_id, revision)
            results[idx] = pca_vis(model, image_tensor, image_size)
        except Exception as e:
            # Best effort: a failure for one model is logged and shown as a
            # "COMING SOON" placeholder instead of aborting the remaining
            # models. Presumably the usual cause is an unpublished revision;
            # confirm if that matters.
            print(f"Error processing {repo_id} ({revision}): {e}")
            results[idx] = create_coming_soon_image(image_size)

        yield tuple(results)
|
|
|
|
| |
|
|
# Stylesheet passed to gr.Blocks(css=...). Its classes (title-row,
# subtitle-row, arch-header, model-label, output-col, custom-footer) are the
# ones used by the gr.HTML snippets and the output columns' elem_classes.
# `.output-col img` forces square, contained thumbnails capped at 350px.
# `footer { display: none }` hides the page's <footer> element, presumably
# Gradio's default footer.
# NOTE(review): the "Prevent browsers from turning visited links purple/dark"
# comment sits above the :active rules. The :visited colours are actually set
# in the earlier "Make links readable before AND after clicking" rule.
CSS = """
.title-row {
text-align: center;
padding: 1.5rem 0 0.25rem;
}
/* Higher contrast subtitle */
.subtitle-row {
text-align: center;
color: #d1d5db;
font-size: 0.9rem;
padding-bottom: 1rem;
}
/* Higher contrast section headers */
.arch-header {
font-size: 1.2rem;
font-weight: 700;
margin-top: 1rem;
padding-left: 0.5rem;
border-left: 4px solid #60a5fa;
color: #f3f4f6;
}
/* Brighter model labels */
.model-label {
text-align: center;
font-weight: 700;
font-size: 0.9rem;
color: #f3f4f6;
padding: 0.25rem 0;
}
/* Make links readable before AND after clicking */
.subtitle-row a,
.model-label a,
.custom-footer a,
.subtitle-row a:visited,
.model-label a:visited,
.custom-footer a:visited {
color: #93c5fd;
text-decoration: underline;
text-decoration-color: #93c5fd;
font-weight: 600;
}
/* Strong hover state */
.subtitle-row a:hover,
.model-label a:hover,
.custom-footer a:hover {
color: #dbeafe;
text-decoration-color: #dbeafe;
}
/* Prevent browsers from turning visited links purple/dark */
.subtitle-row a:active,
.model-label a:active,
.custom-footer a:active {
color: #bfdbfe;
}
.output-col {
display: flex !important;
flex-direction: column !important;
align-items: center !important;
gap: 0.25rem !important;
flex: 1 1 0% !important;
min-width: 150px !important;
}
.output-col img {
aspect-ratio: 1 / 1 !important;
object-fit: contain !important;
max-height: 350px !important;
width: 100% !important;
}
/* Improve contrast of markdown/help text */
.gradio-container p {
color: #d1d5db;
}
/* Improve dropdown labels and general form text */
.gradio-container label,
.gradio-container .form,
.gradio-container .prose {
color: #f3f4f6;
}
/* More legible footer */
.custom-footer {
text-align: center;
margin-top: 2.5rem;
padding-top: 1rem;
border-top: 1px solid #374151;
font-size: 0.85rem;
color: #d1d5db;
}
footer { display: none !important; }
"""
|
|
# Gradio UI. The left column holds the input image, checkpoint options and
# the run button. The right column holds a grid of PCA visualisations, with
# one row per architecture in MODEL_IDS and one column per entry in
# MODEL_KEYS. Output components are created architecture-major, the same
# order in which `run` yields its results.
with gr.Blocks(css=CSS, title="SSL ViT PCA Visualiser") as demo:
    gr.HTML("""
<div class="title-row">
    <h1 style="font-size:1.6rem; font-weight:700; margin:0;">
        SSL ViT — Patch Feature PCA
    </h1>
</div>
<div class="subtitle-row">
    ImageNet-1K pre-training ·
    <a href="https://huggingface.co/OK-AI" target="_blank">OK-AI Models</a>
</div>
""")

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="pil",
                label="Input image",
                show_label=True,
            )

            with gr.Row():
                opt_epoch = gr.Dropdown(
                    choices=["ep100", "ep300"],
                    value="ep300",
                    label="Epochs",
                    interactive=True,
                )
                opt_weight = gr.Dropdown(
                    choices=["student", "teacher"],
                    value="teacher",
                    label="Weight Type",
                    info="LeJEPA always uses student",
                    interactive=True,
                )

            opt_size = gr.Dropdown(
                choices=["224", "448", "672", "1280"],
                value="672",
                label="Image Target Resolution",
                interactive=True,
            )

            run_btn = gr.Button("Visualise", variant="primary")

            gr.HTML("""
<p style="font-size:0.8rem; color:#9ca3af; margin-top:0.5rem; line-height:1.5;">
    PCA is fit on all patch tokens and projected to
    3 components, then scaled with sigmoid for colour display.
    Results stream seamlessly into view as individual variants complete.
</p>

<div class="custom-footer">
    Models: <a href="https://huggingface.co/OK-AI" target="_blank">OK-AI on HuggingFace</a>
    ·
    Code: <a href="https://github.com/Open-Knowledge-AI/lite_ssl" target="_blank">lite_ssl</a>
</div>
""")

        with gr.Column(scale=3):
            # One header + row per architecture, one labelled image per method.
            output_targets = []
            for arch, repos in MODEL_IDS.items():
                arch_short = arch.removeprefix("ViT-")  # "ViT-S/16" -> "S/16"
                gr.HTML(f'<div class="arch-header">{arch} Grid</div>')
                with gr.Row(equal_height=True):
                    for model_key in MODEL_KEYS:
                        with gr.Column(elem_classes="output-col"):
                            gr.HTML(
                                '<div class="model-label">'
                                f'<a href="https://huggingface.co/{repos[model_key]}" target="_blank">'
                                f"{model_key} ({arch_short})</a></div>"
                            )
                            output_targets.append(
                                gr.Image(show_label=False, interactive=False)
                            )

    # The original per-component names are kept in case anything references
    # them directly.
    (
        out_dino_s,
        out_ibot_s,
        out_lejepa_s,
        out_dino_b,
        out_ibot_b,
        out_lejepa_b,
    ) = output_targets

    run_btn.click(
        fn=run,
        inputs=[input_image, opt_epoch, opt_weight, opt_size],
        outputs=output_targets,
    )

    # "examples" is resolved relative to the current working directory.
    if os.path.isdir("examples"):
        example_rows = [
            [f"examples/{fname}"]
            for fname in sorted(os.listdir("examples"))
            if fname.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))
        ]
        # Only add the Examples panel when there is something to show.
        if example_rows:
            gr.Examples(
                examples=example_rows,
                inputs=[input_image],
            )


if __name__ == "__main__":
    demo.launch()
|
|