dino / app.py
WolfDavid's picture
Upload app.py with huggingface_hub
84dd851 verified
"""
DINOv2 Vision Demo — Image Classification, Feature Visualization, and Similarity.
Lightweight Gradio app running DINOv2-small entirely on CPU.
"""
import io
import warnings
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from sklearn.decomposition import PCA
from transformers import (
AutoImageProcessor,
AutoModelForImageClassification,
Dinov2Model,
)
# Suppress every warning process-wide to keep the demo's log quiet.
# NOTE: this also hides deprecation warnings from the libraries in use.
warnings.filterwarnings("ignore")
# ---------------------------------------------------------------------------
# Model checkpoints and lazily-populated model cache
# ---------------------------------------------------------------------------
# Hugging Face Hub checkpoints used by the app.
CLASSIFICATION_REPO = "facebook/dinov2-small-imagenet1k-1-layer"  # ImageNet-1k classifier
FEATURE_REPO = "facebook/dinov2-small"  # plain backbone, loaded via Dinov2Model
# Cache slots: stay None until _load_classification_model / _load_feature_model
# are first called, then hold the loaded processor and model for reuse.
_classification_model = _classification_processor = None
_feature_model = _feature_processor = None
def _load_classification_model():
    """Return the cached ``(processor, model)`` pair for classification.

    On the first call (while the model slot is still None) both objects are
    created with ``from_pretrained(CLASSIFICATION_REPO)`` and the model is put
    in eval mode. Later calls return the same cached objects.
    """
    global _classification_model, _classification_processor
    if _classification_model is not None:
        return _classification_processor, _classification_model
    _classification_processor = AutoImageProcessor.from_pretrained(CLASSIFICATION_REPO)
    _classification_model = AutoModelForImageClassification.from_pretrained(CLASSIFICATION_REPO)
    # Inference only: switch off training-time behaviour such as dropout.
    _classification_model.eval()
    return _classification_processor, _classification_model
def _load_feature_model():
    """Return the cached ``(processor, model)`` pair for the DINOv2 backbone.

    On the first call (while the model slot is still None) the processor and a
    ``Dinov2Model`` are created from ``FEATURE_REPO`` and the model is put in
    eval mode. Later calls return the same cached objects.
    """
    global _feature_model, _feature_processor
    if _feature_model is not None:
        return _feature_processor, _feature_model
    _feature_processor = AutoImageProcessor.from_pretrained(FEATURE_REPO)
    _feature_model = Dinov2Model.from_pretrained(FEATURE_REPO)
    # Inference only: switch off training-time behaviour such as dropout.
    _feature_model.eval()
    return _feature_processor, _feature_model
# ---------------------------------------------------------------------------
# Tab 1 — Image Classification
# ---------------------------------------------------------------------------
def classify_image(image: Image.Image | None):
    """Classify *image* and return its top-5 labels with probabilities.

    Args:
        image: PIL image from the Gradio input, or ``None`` if nothing was
            uploaded.

    Returns:
        A ``{label: probability}`` dict (at most five entries, highest first)
        in the form ``gr.Label`` accepts.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    processor, model = _load_classification_model()
    # Normalise palette / RGBA / greyscale uploads to three channels.
    inputs = processor(images=image.convert("RGB"), return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = F.softmax(logits, dim=-1)[0]
    # topk raises if k exceeds the number of classes, so clamp it.
    k = min(5, probs.shape[-1])
    top_probs, top_indices = torch.topk(probs, k=k)
    id2label = model.config.id2label
    results = {}
    for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
        label = id2label[idx]
        # Label sets may reuse one name for distinct classes (e.g. "crane" is
        # commonly listed twice in ImageNet-1k mappings). Keying a plain dict
        # on the name would let the lower-scoring entry overwrite the higher
        # one, so tag a colliding name with its class index instead.
        if label in results:
            label = f"{label} (class {idx})"
        results[label] = prob
    return results
# ---------------------------------------------------------------------------
# Tab 2 — Feature Visualization (PCA on patch tokens)
# ---------------------------------------------------------------------------
def _scale_components_to_uint8(components: np.ndarray) -> np.ndarray:
    """Min-max scale each column of *components* to [0, 255] and cast to uint8.

    A column whose values are all equal (zero range) maps to 0 instead of
    dividing by zero.
    """
    lo = components.min(axis=0)
    span = components.max(axis=0) - lo
    # Substitute 1 for zero spans. The numerator of such a column is all
    # zeros, so it still scales to 0.
    safe_span = np.where(span > 0, span, 1.0)
    return ((components - lo) / safe_span * 255.0).astype(np.uint8)


def visualize_features(image: Image.Image | None):
    """Render DINOv2 patch-token features as an RGB image via 3-component PCA.

    Args:
        image: PIL image from the Gradio input, or ``None`` if nothing was
            uploaded.

    Returns:
        ``(image, pca_map)``: the RGB-converted input, and the PCA map (first
        three components as R, G, B) resized bilinearly to the input's size.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    processor, model = _load_feature_model()
    image = image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Derive the patch grid from the tensor actually fed to the model instead
    # of assuming a square sqrt(N) layout. This keeps the reshape below valid
    # when the preprocessed input is not square.
    patch_size = model.config.patch_size
    pixel_h, pixel_w = inputs["pixel_values"].shape[-2:]
    grid_h, grid_w = pixel_h // patch_size, pixel_w // patch_size
    # Patch tokens are the trailing grid_h * grid_w positions. Slicing from the
    # end drops the CLS token at position 0.
    patch_tokens = (
        outputs.last_hidden_state[0, -grid_h * grid_w:, :].cpu().numpy()
    )  # (grid_h * grid_w, D)

    # Project onto the first three principal components -> one RGB channel each.
    components = PCA(n_components=3).fit_transform(patch_tokens)
    pca_img = _scale_components_to_uint8(components).reshape(grid_h, grid_w, 3)

    # NOTE(review): if the processor center-crops (presumably the default for
    # this checkpoint; confirm), the map covers only the cropped region, yet it
    # is stretched over the full original image here.
    pca_pil = Image.fromarray(pca_img).resize(image.size, Image.BILINEAR)
    return image, pca_pil
# ---------------------------------------------------------------------------
# Tab 3 — Image Similarity (CLS token cosine similarity)
# ---------------------------------------------------------------------------
def compute_similarity(
    image1: Image.Image | None,
    image2: Image.Image | None,
):
    """Compare two images by cosine similarity of their DINOv2 CLS tokens.

    Both images go through the feature processor as one batch. The CLS token
    (sequence position 0 of ``last_hidden_state``) of each is compared.

    Returns:
        A Markdown string with the score (4 decimals), a coarse verbal
        interpretation, and a legend for the [-1, 1] scale.

    Raises:
        gr.Error: if either image is missing.
    """
    if image1 is None or image2 is None:
        raise gr.Error("Please upload both images.")
    processor, model = _load_feature_model()
    batch = processor(
        images=[image1.convert("RGB"), image2.convert("RGB")],
        return_tensors="pt",
    )
    with torch.no_grad():
        hidden = model(**batch).last_hidden_state
    # Keep a leading batch dim of 1 on each side: (1, D) vs (1, D).
    cls_first, cls_second = hidden[0:1, 0, :], hidden[1:2, 0, :]
    similarity = F.cosine_similarity(cls_first, cls_second).item()

    # (exclusive lower bound, label) pairs from strictest to loosest. The first
    # bound the score exceeds wins; nothing above 0.3 means "Different".
    bands = (
        (0.99, "Identical"),
        (0.85, "Very similar"),
        (0.6, "Similar"),
        (0.3, "Somewhat related"),
    )
    label = next(
        (name for cutoff, name in bands if similarity > cutoff),
        "Different",
    )

    return (
        f"## Cosine Similarity: **{similarity:.4f}**\n\n"
        f"Interpretation: **{label}**\n\n"
        f"*Scale: -1 (opposite) ← 0 (unrelated) → 1 (identical)*"
    )
# ---------------------------------------------------------------------------
# No example images — users upload their own
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Build the Gradio interface: three tabs, each wiring a button's click to one
# of the inference functions (classify_image, visualize_features,
# compute_similarity). Gradio lays components out in creation order inside
# each `with` context, so statement order here is significant.
with gr.Blocks(
    title="DINOv2 Vision Demo",
) as demo:
    # Page header shown above all tabs.
    gr.Markdown(
        "# DINOv2 Vision Demo\n"
        "Explore Meta's **DINOv2** self-supervised vision transformer — "
        "image classification, patch-level feature visualization, and embedding similarity. "
        "Everything runs on CPU using `dinov2-small` (~86 MB)."
    )
    # ---- Tab 1: Classification ----
    with gr.Tab("Image Classification"):
        gr.Markdown(
            "Upload an image to classify it against **ImageNet-1k** labels using "
            "`facebook/dinov2-small-imagenet1k-1-layer`."
        )
        with gr.Row():
            # Left column: input image and trigger button.
            with gr.Column():
                # type="pil" delivers the upload as a PIL image, matching the
                # handler's Image.Image parameter.
                cls_input = gr.Image(type="pil", label="Input Image")
                cls_btn = gr.Button("Classify", variant="primary")
            # Right column: results.
            with gr.Column():
                # num_top_classes=5 matches the top-5 dict classify_image returns.
                cls_output = gr.Label(num_top_classes=5, label="Top-5 Predictions")
        cls_btn.click(fn=classify_image, inputs=cls_input, outputs=cls_output)
    # ---- Tab 2: Feature Visualization ----
    with gr.Tab("Feature Visualization"):
        gr.Markdown(
            "Visualize **DINOv2 patch token features** via PCA. "
            "The first 3 principal components are mapped to RGB channels, "
            "revealing how the model perceives structure and semantics."
        )
        with gr.Row():
            feat_input = gr.Image(type="pil", label="Input Image")
            feat_btn = gr.Button("Visualize Features", variant="primary")
        with gr.Row():
            # Two outputs, in the same order as visualize_features'
            # (image, pca_pil) return tuple.
            feat_orig = gr.Image(type="pil", label="Original")
            feat_pca = gr.Image(type="pil", label="PCA Feature Map")
        feat_btn.click(fn=visualize_features, inputs=feat_input, outputs=[feat_orig, feat_pca])
    # ---- Tab 3: Image Similarity ----
    with gr.Tab("Image Similarity"):
        gr.Markdown(
            "Upload two images to compare their **DINOv2 CLS embeddings** "
            "via cosine similarity."
        )
        with gr.Row():
            sim_img1 = gr.Image(type="pil", label="Image A")
            sim_img2 = gr.Image(type="pil", label="Image B")
        sim_btn = gr.Button("Compare", variant="primary")
        # compute_similarity returns a Markdown string, rendered here.
        sim_output = gr.Markdown(label="Similarity")
        sim_btn.click(
            fn=compute_similarity,
            inputs=[sim_img1, sim_img2],
            outputs=sim_output,
        )
# Start the server only when executed as a script; importing the module just
# builds `demo` without launching it.
if __name__ == "__main__":
    demo.launch()