Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| # NOTE: | |
| # We intentionally avoid importing `datasets` at module import time so the API can | |
| # start even if the optional DINO dependencies are not installed locally. | |
| from sklearn.metrics.pairwise import cosine_similarity | |
def load_dino(device):
    """Load the ``dinov2_vits14`` model from the ``facebookresearch/dinov2`` hub repo.

    Args:
        device: Target device, in any form accepted by ``nn.Module.to``.

    Returns:
        The hub model in eval mode, placed on ``device``.
    """
    dino = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
    # Both calls modify the module in place, so the same object is returned.
    dino.eval()
    dino.to(device)
    return dino
def build_embeddings(
    dino,
    transform,
    device,
    *,
    dataset=None,
    repo_id="AdarshDS/mold-reference-images",
    split="train",
):
    """Embed every reference image and stack the results into one matrix.

    Args:
        dino: Embedding model; called on a batch holding one transformed image.
        transform: Preprocessing callable turning one RGB image into a tensor.
        device: Device the model lives on; each input tensor is moved there.
        dataset: Optional iterable of samples, each with an ``"image"`` entry
            exposing ``convert`` (e.g. a PIL image). When ``None``, the
            reference set is loaded with ``datasets.load_dataset``, which
            requires the optional ``datasets`` package.
        repo_id: Dataset to load when ``dataset`` is ``None``.
        split: Split to load when ``dataset`` is ``None``.

    Returns:
        A NumPy array with one embedding row per sample.

    Raises:
        RuntimeError: If the dataset yields no samples.
    """
    if dataset is None:
        # Lazy import to keep DINO optional.
        from datasets import load_dataset

        dataset = load_dataset(repo_id, split=split)

    embs = []
    # One no_grad context for the whole pass instead of re-entering it per image.
    with torch.no_grad():
        for sample in dataset:
            img = sample["image"].convert("RGB")
            t = transform(img).unsqueeze(0).to(device)
            # Drop only the batch axis that unsqueeze(0) added; a bare squeeze()
            # would also collapse any other size-1 axis of the model output.
            embs.append(dino(t).squeeze(0).cpu().numpy())

    if not embs:
        raise RuntimeError("No reference images found in HF dataset")
    return np.vstack(embs)
def similarity(dino, mold_embs, image, transform, device):
    """Return the highest cosine similarity between ``image`` and the references.

    Args:
        dino: Embedding model; called on a batch holding one transformed image.
        mold_embs: 2-D array of reference embeddings, one row per reference
            image (the shape ``build_embeddings`` produces).
        image: Query image handed to ``transform``. Objects exposing
            ``convert`` (e.g. PIL images) are converted to RGB first.
        transform: Preprocessing callable producing a tensor for one image.
        device: Device the model lives on; the input tensor is moved there.

    Returns:
        The maximum cosine similarity over all reference rows, as a ``float``.
    """
    # The reference embeddings are computed on RGB-converted images, so apply
    # the same conversion here; otherwise RGBA / grayscale / palette queries
    # reach the transform with a different channel layout than the references.
    if hasattr(image, "convert"):
        image = image.convert("RGB")
    batch = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        query = dino(batch).cpu().numpy()
    # cosine_similarity yields one row of scores per query; keep the best match.
    return float(cosine_similarity(query, mold_embs).max())