| from __future__ import annotations |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import AutoImageProcessor, AutoModel |
| from transformers.image_utils import load_image |
|
|
# Local paths (placeholders — edit before running):
# MODEL_DIR: directory of the pretrained DINOv3 ViT-H/16+ backbone (HF format).
# CKPT_PATH: checkpoint dict holding the trained projection head's weights
#            under "head_state_dict" (plus an optional "config" sub-dict).
MODEL_DIR = "your/path/to/facebook/dinov3-vith16plus-pretrain-lvd1689m"
CKPT_PATH = "your/path/to/PokeCon_head.pt"


# The two images whose embeddings are compared (placeholders).
IMG1 = "your_image1.png"
IMG2 = "your_image2.png"


# Compute dtype used for both the backbone and the projection head.
DTYPE = torch.bfloat16
|
|
class ProjectionHead(nn.Module):
    """MLP projection head: a stack of Linear layers sized by *hidden_dims*.

    Every layer except the last is followed by an optional LayerNorm, a GELU,
    and an optional Dropout; the final Linear is left bare so the output can
    be normalized externally. An empty *hidden_dims* yields the identity.

    Args:
        in_dim: dimensionality of the incoming features.
        hidden_dims: output width of each successive Linear layer.
        dropout: dropout probability between layers (0 disables Dropout).
        use_layernorm: insert LayerNorm after each non-final Linear.
    """

    def __init__(self, in_dim: int, hidden_dims: list[int], dropout: float = 0.05, use_layernorm: bool = True):
        super().__init__()
        widths = [in_dim, *hidden_dims]
        final = len(hidden_dims) - 1
        modules: list[nn.Module] = []
        # Walk consecutive (source, target) width pairs; only non-final
        # layers receive normalization / activation / dropout.
        for idx, (src, dst) in enumerate(zip(widths[:-1], widths[1:])):
            modules.append(nn.Linear(src, dst))
            if idx == final:
                continue
            if use_layernorm:
                modules.append(nn.LayerNorm(dst))
            modules.append(nn.GELU())
            if dropout > 0:
                modules.append(nn.Dropout(dropout))
        self.net = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project *x* through the MLP stack."""
        return self.net(x)
|
|
|
|
def main() -> None:
    """Embed two images with a frozen DINOv3 backbone plus the trained
    projection head, then print their cosine similarity.

    Raises:
        RuntimeError: if the checkpoint is not a dict containing
            'head_state_dict'.
    """
    image1 = load_image(IMG1)
    image2 = load_image(IMG2)

    processor = AutoImageProcessor.from_pretrained(MODEL_DIR, local_files_only=True)
    backbone = AutoModel.from_pretrained(
        MODEL_DIR,
        local_files_only=True,
        torch_dtype=DTYPE,
        device_map="auto",
    ).eval()

    # weights_only=True blocks arbitrary-code pickle execution when loading
    # the checkpoint; this checkpoint only needs tensors and plain Python
    # containers, all of which the restricted unpickler permits.
    ckpt = torch.load(CKPT_PATH, map_location="cpu", weights_only=True)
    if not isinstance(ckpt, dict) or "head_state_dict" not in ckpt:
        raise RuntimeError("Expected a checkpoint dict with key: 'head_state_dict'")

    # Rebuild the head with the hyperparameters saved at training time so the
    # architecture matches the state dict; fall back to the training defaults.
    cfg = ckpt.get("config", {})
    proj_dims = list(cfg.get("proj_dims", [512, 256]))
    dropout = float(cfg.get("dropout", 0.05))
    use_layernorm = bool(cfg.get("use_layernorm", True))

    head = ProjectionHead(
        in_dim=backbone.config.hidden_size,
        hidden_dims=proj_dims,
        dropout=dropout,
        use_layernorm=use_layernorm,
    ).to(device=backbone.device, dtype=DTYPE)
    head.load_state_dict(ckpt["head_state_dict"], strict=True)
    head.eval()

    inputs = processor(images=[image1, image2], return_tensors="pt").to(backbone.device)

    with torch.inference_mode():
        # Cast pixel values to the backbone dtype (bf16) before the forward.
        out = backbone(pixel_values=inputs["pixel_values"].to(backbone.dtype))
        pooled = out.pooler_output  # one pooled embedding per image
        z = head(pooled.to(DTYPE))
        # After L2-normalization, the dot product of the two rows equals
        # their cosine similarity.
        z = F.normalize(z, dim=-1)
        cos = (z[0] * z[1]).sum().item()

    print(f"Cosine similarity: {cos:.6f}")
|
|
|
|
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|
|