import torch
import numpy as np
from PIL import Image

# NOTE:
# We intentionally avoid importing `datasets` at module import time so the API can
# start even if the optional DINO dependencies are not installed locally.
from sklearn.metrics.pairwise import cosine_similarity


def load_dino(device):
    """Load the DINOv2 ViT-S/14 model through ``torch.hub``.

    Args:
        device: Target device passed to ``Module.to`` (e.g. ``"cpu"``,
            ``"cuda"`` or a ``torch.device``).

    Returns:
        The ``dinov2_vits14`` model from the ``facebookresearch/dinov2`` hub
        repository, switched to eval mode and moved to ``device``.
    """
    model = torch.hub.load(
        "facebookresearch/dinov2",
        "dinov2_vits14",
    )
    model.eval().to(device)
    return model


def build_embeddings(dino, transform, device):
    """Embed every reference image of the Hugging Face mold dataset.

    Args:
        dino: Callable model mapping a batched image tensor to embeddings.
        transform: Callable turning a PIL image into a tensor without a
            batch dimension (one is added here).
        device: Device the input tensors are moved to before inference.

    Returns:
        A NumPy array built by vertically stacking the per-image embeddings
        (one row per reference image, assuming the model yields a single
        vector per image).

    Raises:
        RuntimeError: If the dataset split yields no samples.
    """
    # Lazy import to keep DINO optional.
    from datasets import load_dataset

    dataset = load_dataset(
        "AdarshDS/mold-reference-images",
        split="train",
    )

    embs = []
    # Inference only: gradients are never needed for reference embeddings.
    with torch.no_grad():
        for sample in dataset:
            img: Image.Image = sample["image"].convert("RGB")
            t = transform(img).unsqueeze(0).to(device)
            e = dino(t)
            # Drop the batch dimension so each embedding becomes one row.
            embs.append(e.squeeze().cpu().numpy())

    if not embs:
        raise RuntimeError(
            "No reference images found in HF dataset"
        )

    return np.vstack(embs)


def similarity(dino, mold_embs, image, transform, device):
    """Return the best cosine similarity between ``image`` and the references.

    Args:
        dino: Callable model mapping a batched image tensor to embeddings.
        mold_embs: 2-D array of reference embeddings, e.g. the result of
            ``build_embeddings``.
        image: Query image. PIL images that are not already RGB are
            converted to RGB so they are embedded the same way as the
            reference images; any other input type is handed to
            ``transform`` unchanged.
        transform: Callable turning ``image`` into a tensor without a batch
            dimension.
        device: Device the input tensor is moved to before inference.

    Returns:
        The maximum cosine similarity between the query embedding and any
        row of ``mold_embs``, as a Python ``float``.
    """
    # Reference images are converted to RGB in build_embeddings; do the same
    # here so RGBA / grayscale / palette inputs do not reach the model with a
    # different channel layout.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    t = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        e = dino(t).cpu().numpy()

    return float(cosine_similarity(e, mold_embs).max())