Spaces:
Sleeping
Sleeping
| # app.py — DINOv3 two‑image patch similarity (click on Image 1 → show similarities on both images) | |
| # Runs on CPU or CUDA. No external image URLs. | |
| import os | |
| from typing import Tuple | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image, ImageDraw | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision.transforms.functional as TF | |
| #from transformers import AutoModel # trust_remote_code=True | |
| from transformers import AutoModel | |
# ============================
# Config
# ============================
# Hugging Face Hub checkpoints offered in the "Backbone" dropdown; the default
# one is loaded at startup.
DEFAULT_MODEL_ID = "facebook/dinov3-vits16plus-pretrain-lvd1689m"
ALT_MODEL_ID = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
AVAILABLE_MODELS = [DEFAULT_MODEL_ID, ALT_MODEL_ID]

# Side length (pixels) of one ViT patch; inputs are resized to multiples of it.
PATCH_SIZE = 16

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Per-channel RGB normalisation statistics applied before the forward pass.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Tokens preceding the patch tokens in the model output. Many DINOv3 HF ports
# presumably expose 1 [CLS] + 4 register tokens — confirm per checkpoint.
N_SPECIAL_TOKENS = 1 + 4
# Colormap lookup that works across Matplotlib versions: newer releases expose
# the `matplotlib.colormaps` registry, older ones only `matplotlib.cm.get_cmap`.
try:
    from matplotlib import colormaps as _mpl_colormaps

    def _get_cmap(name: str):
        """Return the Matplotlib colormap registered under *name*."""
        return _mpl_colormaps[name]

except ImportError:
    # Only a missing `colormaps` registry should trigger the fallback; any other
    # error is a real problem and must not be silently swallowed.
    import matplotlib.cm as _cm

    def _get_cmap(name: str):
        """Return the Matplotlib colormap registered under *name* (legacy API)."""
        return _cm.get_cmap(name)
| # ============================ | |
| # Model loading / cache | |
| # ============================ | |
| _model_cache = {} | |
| _current_model_id = None | |
| model = None | |
def load_model_from_hubold(model_id: str):
    """Legacy alias of ``load_model_from_hub``.

    The original body was an exact copy of ``load_model_from_hub`` (same Hub
    download with ``HF_TOKEN``, same device placement, eval mode and log
    lines), so it simply delegates.
    """
    return load_model_from_hub(model_id)
def load_model_from_hubold2(model_id: str):
    """Legacy loader: build a Transformers ``image-feature-extraction`` pipeline.

    Not called anywhere in the visible code. NOTE(review): this returns a
    pipeline object, not the model returned by ``load_model_from_hub``, so it is
    presumably not usable by ``extract_image_features`` (which calls the model
    on a tensor and reads ``.last_hidden_state``) — confirm before wiring it in.

    Args:
        model_id: Hugging Face Hub model ID.

    Returns:
        The constructed pipeline.
    """
    # `pipeline` is not imported at module level; without this local import
    # any call to this function raised NameError.
    from transformers import pipeline

    print(f"Loading model '{model_id}' from HF Hub…")
    token = os.environ.get("HF_TOKEN")
    extractor = pipeline(
        "image-feature-extraction",
        model=model_id,
        token=token,
        trust_remote_code=True,
        # pipeline convention: CUDA device index, or -1 for CPU.
        device=0 if DEVICE == "cuda" else -1,
    )
    print(f"✅ Loaded '{model_id}' on {DEVICE}")
    return extractor
def load_model_from_hub(model_id: str):
    """Download *model_id* from the Hugging Face Hub and prepare it for inference.

    An access token is read from the ``HF_TOKEN`` environment variable (may be
    unset). The model is moved to ``DEVICE`` and switched to eval mode.

    Returns:
        The loaded ``AutoModel`` instance.
    """
    print(f"Loading model '{model_id}' from HF Hub…")
    backbone = AutoModel.from_pretrained(
        model_id,
        token=os.environ.get("HF_TOKEN"),
        trust_remote_code=True,
    )
    backbone.to(DEVICE).eval()
    print(f"✅ Loaded '{model_id}' on {DEVICE}")
    return backbone
def get_model(model_id: str):
    """Return the backbone for *model_id*, loading it from the Hub on first use.

    Loaded models are memoised in ``_model_cache`` so later calls with the same
    ID reuse the already-loaded weights.
    """
    if model_id not in _model_cache:
        _model_cache[model_id] = load_model_from_hub(model_id)
    return _model_cache[model_id]
# Load the default backbone at import time and mark it as the active one.
model, _current_model_id = get_model(DEFAULT_MODEL_ID), DEFAULT_MODEL_ID
# ============================
# Helpers
# ============================
def resize_to_grid(img: Image.Image, long_side: int, patch: int = PATCH_SIZE) -> torch.Tensor:
    """Resize *img* so both sides are whole multiples of *patch*.

    The image is scaled (aspect ratio kept) so that its longer side is about
    *long_side* pixels, and neither side is allowed below one patch. Each side
    is then rounded UP to the next multiple of *patch*, and the image is
    resized to exactly that size. No padding is added, so the aspect ratio can
    change slightly. The long side can exceed *long_side* by up to
    ``patch - 1`` pixels when *long_side* is not itself a multiple of *patch*.

    Args:
        img: input image in any PIL mode (converted to RGB).
        long_side: target length of the longer side, in pixels.
        patch: grid cell size in pixels.

    Returns:
        A (3, H, W) float tensor with values in [0, 1].

    Raises:
        ValueError: if *img* has zero width or height.
    """
    w, h = img.size
    if w <= 0 or h <= 0:
        raise ValueError(f"Cannot resize an empty image of size {w}x{h}.")
    scale = long_side / max(h, w)
    # Aspect-preserving target size, clamped to at least one patch per side.
    new_h = max(patch, int(round(h * scale)))
    new_w = max(patch, int(round(w * scale)))
    # Round each side UP to a whole number of patches (the image is stretched
    # to this size, not padded).
    new_h = ((new_h + patch - 1) // patch) * patch
    new_w = ((new_w + patch - 1) // patch) * patch
    return TF.to_tensor(TF.resize(img.convert("RGB"), (new_h, new_w)))
def colorize(sim_map_up: np.ndarray, cmap_name: str = "viridis") -> Image.Image:
    """Min-max normalise a 2-D map and render it with a Matplotlib colormap.

    Non-finite entries (NaN/±inf) are ignored when computing the value range
    and painted with the lowest colour. Such entries appear, for example, when
    a region filled with -inf is bicubically upsampled. Previously a single
    non-finite value turned the whole normalised map into NaN.

    Args:
        sim_map_up: 2-D array of scores.
        cmap_name: name of a registered Matplotlib colormap.

    Returns:
        An RGB PIL image with the same height and width as *sim_map_up*.
    """
    x = np.asarray(sim_map_up, dtype=np.float32)
    finite = np.isfinite(x)
    if finite.any():
        lo = float(x[finite].min())
        hi = float(x[finite].max())
    else:
        # Nothing usable: render a flat map in the lowest colour.
        lo = hi = 0.0
    x = np.where(finite, x, lo)
    x = (x - lo) / (hi - lo + 1e-6)  # epsilon avoids 0/0 on a constant map
    rgb = (_get_cmap(cmap_name)(x)[..., :3] * 255).astype(np.uint8)
    return Image.fromarray(rgb)
def blend(base: Image.Image, heat: Image.Image, alpha: float = 0.55) -> Image.Image:
    """Composite *heat* over *base* at uniform opacity *alpha*.

    *alpha* = 0 keeps only *base*, 1 shows only *heat*. Both images must have
    the same size (an ``Image.alpha_composite`` requirement). The inputs are
    not modified because ``convert`` returns new images.

    Returns:
        The blended image in RGB mode.
    """
    overlay = heat.convert("RGBA")
    overlay.putalpha(Image.new("L", overlay.size, int(255 * alpha)))
    return Image.alpha_composite(base.convert("RGBA"), overlay).convert("RGB")
def draw_crosshair(img: Image.Image, x: int, y: int, radius: int | None = None) -> Image.Image:
    """Return a copy of *img* with a red "+" marker centred at (x, y).

    Each arm extends *radius* pixels from the centre. When *radius* is None it
    defaults to ``max(2, PATCH_SIZE // 2)``. Lines are 3 px wide.
    """
    half = max(2, PATCH_SIZE // 2) if radius is None else radius
    marked = img.copy()
    pen = ImageDraw.Draw(marked)
    horizontal = ((x - half, y), (x + half, y))
    vertical = ((x, y - half), (x, y + half))
    for segment in (horizontal, vertical):
        pen.line(list(segment), fill="red", width=3)
    return marked
# ============================
# Feature extraction
# ============================
def extract_image_features(image_pil: Image.Image, target_long_side: int, mdl=None):
    """Embed every patch of *image_pil* with the backbone.

    Args:
        image_pil: input image (any PIL mode; resize_to_grid converts to RGB).
        target_long_side: approximate long side in pixels after resizing.
        mdl: backbone to use; defaults to the module-level ``model``.

    Returns:
        dict with:
          "X": (Hp*Wp, D) L2-normalised patch embeddings, so a dot product
               between rows is a cosine similarity. Rows are assumed to be in
               row-major patch order, which is how callers index them.
          "Hp", "Wp": patch-grid height and width.
          "img": the resized RGB image the features were computed on.
          "orig_size": (width, height) of *image_pil*, used to map clicks made
               on the original image onto "img".

    Raises:
        RuntimeError: if the model returns fewer tokens than there are patches.
    """
    if mdl is None:
        mdl = model
    t = resize_to_grid(image_pil, target_long_side, PATCH_SIZE)
    t_norm = TF.normalize(t, IMAGENET_MEAN, IMAGENET_STD).unsqueeze(0).to(DEVICE)
    _, _, H, W = t_norm.shape
    Hp, Wp = H // PATCH_SIZE, W // PATCH_SIZE
    n_patches = Hp * Wp

    # Pure inference. Without no_grad the forward pass records an autograd
    # graph, which X would keep alive inside the Gradio session state.
    with torch.no_grad():
        outputs = mdl(t_norm)

    tokens = outputs.last_hidden_state.squeeze(0)  # (n_special + Hp*Wp, D)
    # Special tokens ([CLS] + registers) precede the patch tokens. Their count
    # is derived from the output length instead of assuming N_SPECIAL_TOKENS,
    # so backbones with a different number of registers still line up with the
    # Hp x Wp grid.
    n_special = tokens.shape[0] - n_patches
    if n_special < 0:
        raise RuntimeError(
            f"Model returned {tokens.shape[0]} tokens for a {Hp}x{Wp} patch grid; "
            f"expected at least {n_patches}."
        )
    X = F.normalize(tokens[n_special:, :], p=2, dim=-1)
    return {
        "X": X,
        "Hp": Hp,
        "Wp": Wp,
        "img": TF.to_pil_image(t),
        "orig_size": image_pil.size,
    }
# ============================
# Similarity utilities
# ============================
def row_col_from_xy(x_pix: int, y_pix: int, Hp: int, Wp: int):
    """Return the (row, col) of the patch containing pixel (x_pix, y_pix).

    Out-of-range coordinates are clamped to the Hp x Wp grid.
    """
    grid_row = y_pix // PATCH_SIZE
    grid_col = x_pix // PATCH_SIZE
    return int(np.clip(grid_row, 0, Hp - 1)), int(np.clip(grid_col, 0, Wp - 1))
def similarity_map(X: torch.Tensor, Hp: int, Wp: int, q_vec: torch.Tensor,
                   img_h: int, img_w: int):
    """Score every patch embedding against *q_vec* and upsample to image size.

    Args:
        X: (Hp*Wp, D) patch embeddings.
        Hp, Wp: patch-grid dimensions used to reshape the scores.
        q_vec: (D,) query embedding.
        img_h, img_w: output size of the upsampled map.

    Returns:
        (sim_map, sim_up): the (Hp, Wp) score tensor, and the bicubically
        upsampled (img_h, img_w) map as a NumPy array.
    """
    grid = (X @ q_vec).view(Hp, Wp)
    upsampled = F.interpolate(
        grid[None, None],  # interpolate expects (N, C, H, W)
        size=(img_h, img_w),
        mode="bicubic",
        align_corners=False,
    )
    return grid, upsampled.squeeze().detach().cpu().numpy()
# ============================
# Core: click on image 1 → heatmaps on image 1 and image 2
# ============================
def _upsample_map(sim_map: torch.Tensor, out_h: int, out_w: int) -> np.ndarray:
    """Bicubically upsample a (Hp, Wp) score map to (out_h, out_w) as NumPy."""
    up = F.interpolate(
        sim_map[None, None],  # interpolate expects (N, C, H, W)
        size=(out_h, out_w),
        mode="bicubic",
        align_corners=False,
    )
    return up[0, 0].detach().cpu().numpy()


def click_two_image_similarity(state1: dict, state2: dict, click_xy: Tuple[int, int],
                               exclude_radius_patches: int, alpha: float, cmap_name: str):
    """Compare the patch clicked on Image 1 with every patch of both images.

    Args:
        state1, state2: feature dicts as built by ``extract_image_features``
            (keys "X", "Hp", "Wp", "img"; optional "orig_size").
        click_xy: (x, y) click position. If state1 has "orig_size" (the size of
            the image that was clicked), the click is rescaled from that size
            to the resized image in state1["img"]. Otherwise it is taken as
            already being in state1["img"] pixels.
        exclude_radius_patches: if > 0, patches within this many patches
            (square neighbourhood) of the clicked patch are excluded from
            Image 1's heatmap, so the self-match does not dominate its scale.
        alpha: heatmap opacity for the overlays.
        cmap_name: Matplotlib colormap name.

    Returns:
        (marked1, heat1, overlay1, heat2, overlay2, max_sim2). max_sim2 is the
        highest patch-level dot product on Image 2 (cosine similarity for
        features from extract_image_features). All six are None if either
        state is missing.
    """
    if not state1 or not state2:
        return (None,) * 6

    X1, Hp1, Wp1, img1 = state1["X"], state1["Hp"], state1["Wp"], state1["img"]
    X2, Hp2, Wp2, img2 = state2["X"], state2["Hp"], state2["Wp"], state2["img"]
    img1_w, img1_h = img1.size
    img2_w, img2_h = img2.size

    # The click is presumably reported in pixels of the image shown in the
    # input component (the original upload), while the features were computed
    # on the resized image. Rescale before choosing a patch.
    src_w, src_h = state1.get("orig_size") or (img1_w, img1_h)
    x = int(click_xy[0] * img1_w / max(src_w, 1))
    y = int(click_xy[1] * img1_h / max(src_h, 1))
    row, col = row_col_from_xy(x, y, Hp1, Wp1)

    with torch.no_grad():
        q = X1[row * Wp1 + col]  # (D,) query embedding

        # Image 1, with an optional exclusion zone around the clicked patch.
        sim_map1 = (X1 @ q).view(Hp1, Wp1)
        if exclude_radius_patches > 0:
            rr, cc = torch.meshgrid(
                torch.arange(Hp1, device=sim_map1.device),
                torch.arange(Wp1, device=sim_map1.device),
                indexing="ij",
            )
            mask1 = ((rr - row).abs() <= exclude_radius_patches) & (
                (cc - col).abs() <= exclude_radius_patches
            )
            # Fill with the lowest remaining score instead of -inf: bicubic
            # upsampling of -inf produces NaN/±inf and wrecks the colour scale.
            fill = 0.0 if bool(mask1.all()) else float(sim_map1[~mask1].min())
            sim_map1 = sim_map1.masked_fill(mask1, fill)
        sim1_up = _upsample_map(sim_map1, img1_h, img1_w)

        # Image 2.
        sim_map2 = (X2 @ q).view(Hp2, Wp2)
        sim2_up = _upsample_map(sim_map2, img2_h, img2_w)
        # Patch-level maximum: the bicubic map can overshoot the true scores.
        max_sim2 = float(sim_map2.max())

    heat1 = colorize(sim1_up, cmap_name)
    overlay1 = blend(img1, heat1, alpha)
    marked1 = draw_crosshair(img1, x, y, radius=PATCH_SIZE // 2)
    heat2 = colorize(sim2_up, cmap_name)
    overlay2 = blend(img2, heat2, alpha)
    return marked1, heat1, overlay1, heat2, overlay2, max_sim2
# ============================
# Gradio UI
# ============================
with gr.Blocks(theme=gr.themes.Soft(), title="DINOv3 Two‑Image Patch Similarity") as demo:
    gr.Markdown("# DINOv3 Two‑Image Patch Similarity")
    gr.Markdown("Upload two images and press **Process both**. Then click on **Image 1** to see similar regions on **both** images.")
    # Per-session feature dicts produced by "Process both" (see _run_both).
    state1 = gr.State()
    state2 = gr.State()
    with gr.Row():
        with gr.Column():
            model_choice = gr.Dropdown(choices=AVAILABLE_MODELS, value=DEFAULT_MODEL_ID, label="Backbone")
            target_long_side = gr.Slider(224, 1024, value=768, step=16, label="Resolution (long side)")
            alpha = gr.Slider(0.0, 1.0, value=0.55, step=0.05, label="Overlay opacity")
            cmap = gr.Dropdown(["viridis", "magma", "plasma", "inferno", "turbo", "cividis"], value="viridis", label="Colormap")
            exclude_r = gr.Slider(0, 10, value=0, step=1, label="Exclude radius (patches) for Image 1")
            start_btn = gr.Button("▶️ Process both", variant="primary")
        with gr.Column():
            img1 = gr.Image(label="Image 1 (clickable)", type="pil", sources=["upload", "clipboard"], value=None)
            img2 = gr.Image(label="Image 2", type="pil", sources=["upload", "clipboard"], value=None)
    with gr.Row():
        with gr.Column():
            marked1 = gr.Image(label="Image 1 — click marker / preview", interactive=False)
            heat1 = gr.Image(label="Image 1 — similarity heatmap", interactive=False)
            overlay1 = gr.Image(label="Image 1 — overlay", interactive=False)
        with gr.Column():
            heat2 = gr.Image(label="Image 2 — similarity heatmap", interactive=False)
            overlay2 = gr.Image(label="Image 2 — overlay", interactive=False)
            score2 = gr.Number(label="Image 2 — max similarity score", precision=6)

    # Utilities
    def _ensure_model(model_id: str):
        """Make *model_id* the active backbone (loaded/cached via get_model)."""
        global model, _current_model_id
        if model_id != _current_model_id:
            model = get_model(model_id)
            _current_model_id = model_id

    # Process button → extract features for both images and store in state
    def _run_both(im1: Image.Image, im2: Image.Image, long_side: int, model_id: str, progress=gr.Progress(track_tqdm=False)):
        """Extract features for both images with the selected backbone.

        Returns the resized previews, the two feature states, and None for the
        heatmap/overlay/score outputs. Results from an earlier click refer to
        the previous features and would otherwise stay on screen.
        """
        if im1 is None or im2 is None:
            raise gr.Error("Please provide both images before processing.")
        _ensure_model(model_id)
        progress(0, desc="Extracting features for Image 1…")
        st1 = extract_image_features(im1, int(long_side), mdl=model)
        progress(0.5, desc="Extracting features for Image 2…")
        st2 = extract_image_features(im2, int(long_side), mdl=model)
        progress(1, desc="Done")
        # Show quick previews to confirm processing; clear stale click results.
        return st1["img"], st2["img"], st1, st2, None, None, None, None

    start_btn.click(
        _run_both,
        inputs=[img1, img2, target_long_side, model_choice],
        outputs=[marked1, overlay2, state1, state2, heat1, overlay1, heat2, score2],
    )

    # Features describe the images as they were when "Process both" was
    # pressed; drop them when an input changes so clicks never use stale ones.
    img1.change(lambda: None, inputs=None, outputs=state1)
    img2.change(lambda: None, inputs=None, outputs=state2)

    # Clicking on Image 1 → compute similarities on both images
    def _on_click(st1, st2, a: float, m: str, excl: int, evt: gr.SelectData):
        """Forward a click on Image 1 to click_two_image_similarity."""
        if not st1 or not st2:
            gr.Warning('Press "Process both" before clicking on Image 1.')
            return (None,) * 6
        if evt is None:
            return (None,) * 6
        return click_two_image_similarity(
            st1, st2,
            click_xy=evt.index,
            exclude_radius_patches=int(excl),
            alpha=float(a), cmap_name=m,
        )

    img1.select(
        _on_click,
        inputs=[state1, state2, alpha, cmap, exclude_r],
        outputs=[marked1, heat1, overlay1, heat2, overlay2, score2],
    )

if __name__ == "__main__":
    demo.launch()