"""
DINOv2 Vision Demo — Image Classification, Feature Visualization, and Similarity.

Lightweight Gradio app running DINOv2-small entirely on CPU.
"""

import io
import warnings

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from sklearn.decomposition import PCA
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    Dinov2Model,
)

warnings.filterwarnings("ignore")

# ---------------------------------------------------------------------------
# Global model cache — each model is loaded lazily on its first use and then
# reused for every later request.
# ---------------------------------------------------------------------------
_classification_model = None
_classification_processor = None
_feature_model = None
_feature_processor = None

CLASSIFICATION_REPO = "facebook/dinov2-small-imagenet1k-1-layer"
FEATURE_REPO = "facebook/dinov2-small"

# Number of predictions returned by the classifier and shown in the UI.
TOP_K = 5

# Ordered (threshold, label) pairs for interpreting cosine similarity: the
# first threshold that the similarity strictly exceeds selects the label;
# if none does, the label is "Different".
_SIMILARITY_LABELS = (
    (0.99, "Identical"),
    (0.85, "Very similar"),
    (0.6, "Similar"),
    (0.3, "Somewhat related"),
)


def _load_classification_model():
    """Return the cached (processor, model) pair for ImageNet classification.

    On the first call both objects are created with ``from_pretrained`` from
    ``CLASSIFICATION_REPO`` and stored in module globals; the model is switched
    to eval mode. Later calls return the cached objects.
    """
    global _classification_model, _classification_processor
    if _classification_model is None:
        _classification_processor = AutoImageProcessor.from_pretrained(CLASSIFICATION_REPO)
        _classification_model = AutoModelForImageClassification.from_pretrained(
            CLASSIFICATION_REPO
        )
        _classification_model.eval()
    return _classification_processor, _classification_model


def _load_feature_model():
    """Return the cached (processor, model) pair for the DINOv2 backbone.

    On the first call both objects are created with ``from_pretrained`` from
    ``FEATURE_REPO`` and stored in module globals; the model is switched to
    eval mode. Later calls return the cached objects.
    """
    global _feature_model, _feature_processor
    if _feature_model is None:
        _feature_processor = AutoImageProcessor.from_pretrained(FEATURE_REPO)
        _feature_model = Dinov2Model.from_pretrained(FEATURE_REPO)
        _feature_model.eval()
    return _feature_processor, _feature_model


# ---------------------------------------------------------------------------
# Tab 1 — Image Classification
# ---------------------------------------------------------------------------
def classify_image(image: Image.Image | None) -> dict[str, float]:
    """Return the top predicted ImageNet classes with their confidences.

    Args:
        image: Uploaded PIL image, or ``None`` if nothing was uploaded.

    Returns:
        Mapping of class label -> softmax probability for the ``TOP_K`` most
        likely classes (fewer if the head has fewer labels).

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    processor, model = _load_classification_model()
    image = image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits

    probs = F.softmax(logits, dim=-1)[0]
    # Clamp k so topk cannot fail on a head with fewer than TOP_K labels.
    top_probs, top_indices = torch.topk(probs, k=min(TOP_K, probs.numel()))

    labels = model.config.id2label
    return {labels[idx.item()]: float(prob) for prob, idx in zip(top_probs, top_indices)}


# ---------------------------------------------------------------------------
# Tab 2 — Feature Visualization (PCA on patch tokens)
# ---------------------------------------------------------------------------
def visualize_features(image: Image.Image | None) -> tuple[Image.Image, Image.Image]:
    """Compute PCA over DINOv2 patch tokens and render the first 3 components as RGB.

    Args:
        image: Uploaded PIL image, or ``None`` if nothing was uploaded.

    Returns:
        ``(original, feature_map)``: the RGB-converted input image and the PCA
        feature map resized to the same size for side-by-side display.

    Raises:
        gr.Error: if no image was provided, or if the number of patch tokens
            does not match the patch grid of the processed input.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    processor, model = _load_feature_model()
    image = image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # Derive the patch grid from the processed tensor rather than assuming a
    # square grid, so non-square processed inputs still reshape correctly.
    # The conv patch embedding floors H/W by the patch size.
    patch_size = model.config.patch_size
    grid_h = inputs["pixel_values"].shape[-2] // patch_size
    grid_w = inputs["pixel_values"].shape[-1] // patch_size

    # Patch tokens (drop CLS token at position 0)
    patch_tokens = outputs.last_hidden_state[0, 1:, :].cpu().numpy()  # (N_patches, D)
    if patch_tokens.shape[0] != grid_h * grid_w:
        raise gr.Error("Unexpected number of patch tokens; cannot build feature map.")

    # PCA → 3 components for RGB
    pca = PCA(n_components=3)
    pca_result = pca.fit_transform(patch_tokens)  # (N_patches, 3)

    # Min-max normalize each component to [0, 255]; a constant component
    # (zero span) maps to 0 instead of dividing by zero.
    lo = pca_result.min(axis=0)
    span = pca_result.max(axis=0) - lo
    safe_span = np.where(span > 0, span, 1.0)
    scaled = np.where(span > 0, (pca_result - lo) / safe_span * 255.0, 0.0)

    pca_img = scaled.reshape(grid_h, grid_w, 3).astype(np.uint8)

    # Resize to match the original image for a cleaner side-by-side view.
    # NOTE(review): the processor may resize/crop before the model sees the
    # image, in which case the map covers only the processed view and is
    # stretched here — confirm against the checkpoint's preprocessing config.
    pca_pil = Image.fromarray(pca_img).resize(image.size, Image.BILINEAR)

    return image, pca_pil


# ---------------------------------------------------------------------------
# Tab 3 — Image Similarity (CLS token cosine similarity)
# ---------------------------------------------------------------------------
def compute_similarity(
    image1: Image.Image | None,
    image2: Image.Image | None,
) -> str:
    """Compute cosine similarity between two images using DINOv2 CLS embeddings.

    Args:
        image1: First uploaded PIL image, or ``None``.
        image2: Second uploaded PIL image, or ``None``.

    Returns:
        Markdown text with the similarity score and a coarse interpretation.

    Raises:
        gr.Error: if either image is missing.
    """
    if image1 is None or image2 is None:
        raise gr.Error("Please upload both images.")

    processor, model = _load_feature_model()
    imgs = [img.convert("RGB") for img in (image1, image2)]
    inputs = processor(images=imgs, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    cls_tokens = outputs.last_hidden_state[:, 0, :]  # (2, D)
    similarity = F.cosine_similarity(cls_tokens[0:1], cls_tokens[1:2]).item()

    label = next(
        (name for threshold, name in _SIMILARITY_LABELS if similarity > threshold),
        "Different",
    )

    md = (
        f"## Cosine Similarity: **{similarity:.4f}**\n\n"
        f"Interpretation: **{label}**\n\n"
        f"*Scale: -1 (opposite) ← 0 (unrelated) → 1 (identical)*"
    )
    return md


# ---------------------------------------------------------------------------
# No example images — users upload their own
# ---------------------------------------------------------------------------

# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks(
    title="DINOv2 Vision Demo",
) as demo:
    gr.Markdown(
        "# DINOv2 Vision Demo\n"
        "Explore Meta's **DINOv2** self-supervised vision transformer — "
        "image classification, patch-level feature visualization, and embedding similarity. "
        "Everything runs on CPU using `dinov2-small` (~86 MB)."
    )

    # ---- Tab 1: Classification ----
    with gr.Tab("Image Classification"):
        gr.Markdown(
            "Upload an image to classify it against **ImageNet-1k** labels using "
            "`facebook/dinov2-small-imagenet1k-1-layer`."
        )
        with gr.Row():
            with gr.Column():
                cls_input = gr.Image(type="pil", label="Input Image")
                cls_btn = gr.Button("Classify", variant="primary")
            with gr.Column():
                cls_output = gr.Label(num_top_classes=TOP_K, label="Top-5 Predictions")

        cls_btn.click(fn=classify_image, inputs=cls_input, outputs=cls_output)

    # ---- Tab 2: Feature Visualization ----
    with gr.Tab("Feature Visualization"):
        gr.Markdown(
            "Visualize **DINOv2 patch token features** via PCA. "
            "The first 3 principal components are mapped to RGB channels, "
            "revealing how the model perceives structure and semantics."
        )
        with gr.Row():
            feat_input = gr.Image(type="pil", label="Input Image")
        feat_btn = gr.Button("Visualize Features", variant="primary")
        with gr.Row():
            feat_orig = gr.Image(type="pil", label="Original")
            feat_pca = gr.Image(type="pil", label="PCA Feature Map")

        feat_btn.click(fn=visualize_features, inputs=feat_input, outputs=[feat_orig, feat_pca])

    # ---- Tab 3: Image Similarity ----
    with gr.Tab("Image Similarity"):
        gr.Markdown(
            "Upload two images to compare their **DINOv2 CLS embeddings** "
            "via cosine similarity."
        )
        with gr.Row():
            sim_img1 = gr.Image(type="pil", label="Image A")
            sim_img2 = gr.Image(type="pil", label="Image B")
        sim_btn = gr.Button("Compare", variant="primary")
        sim_output = gr.Markdown(label="Similarity")

        sim_btn.click(
            fn=compute_similarity,
            inputs=[sim_img1, sim_img2],
            outputs=sim_output,
        )


if __name__ == "__main__":
    demo.launch()