AdarshRajDS
Add ResNet baseline and ConvNeXt v2 backend
2b8b06c
import torch
import numpy as np
from PIL import Image
# NOTE:
# We intentionally avoid importing `datasets` at module import time so the API can
# start even if the optional DINO dependencies are not installed locally.
from sklearn.metrics.pairwise import cosine_similarity
def load_dino(device):
    """Load the ``dinov2_vits14`` model through ``torch.hub`` for inference.

    Args:
        device: Target passed to ``Module.to`` (e.g. ``"cpu"`` or a
            ``torch.device``).

    Returns:
        The hub model, switched to eval mode and moved to ``device``.
    """
    backbone = torch.hub.load(
        repo_or_dir="facebookresearch/dinov2",
        model="dinov2_vits14",
    )
    # Eval mode first, then placement on the requested device.
    backbone.eval()
    backbone.to(device)
    return backbone
def build_embeddings(dino, transform, device):
    """Embed every image of the reference dataset with ``dino``.

    The ``train`` split of the ``AdarshDS/mold-reference-images`` dataset is
    loaded, each sample's ``"image"`` is converted to RGB, passed through
    ``transform`` and then through ``dino`` with gradients disabled.

    Args:
        dino: Callable model applied to a single-image batch.
        transform: Callable turning an RGB PIL image into a tensor
            (presumably channel-first, since a batch axis is prepended).
        device: Device the input batch is moved to before the forward pass.

    Returns:
        ``np.vstack`` of the per-image embeddings, in dataset order.

    Raises:
        RuntimeError: If the dataset yields no samples.
    """
    # Lazy import to keep DINO optional.
    from datasets import load_dataset

    reference_set = load_dataset("AdarshDS/mold-reference-images", split="train")

    def embed(record):
        # One image -> one squeezed embedding as a NumPy array on the CPU.
        rgb = record["image"].convert("RGB")
        batch = transform(rgb).unsqueeze(0).to(device)
        with torch.no_grad():
            features = dino(batch)
        return features.squeeze().cpu().numpy()

    vectors = [embed(record) for record in reference_set]

    if len(vectors) == 0:
        raise RuntimeError("No reference images found in HF dataset")
    return np.vstack(vectors)
def similarity(dino, mold_embs, image, transform, device):
    """Return the best cosine similarity between ``image`` and the references.

    Args:
        dino: Callable model applied to a single-image batch.
        mold_embs: 2-D array of reference embeddings, one row per reference
            (as produced by ``build_embeddings``).
        image: Query image. PIL images in a non-RGB mode are converted to RGB
            first; any other input type is handed to ``transform`` unchanged.
        transform: Callable turning ``image`` into a tensor to which a batch
            axis is prepended.
        device: Device the input batch is moved to before the forward pass.

    Returns:
        The maximum cosine similarity, as a Python ``float``, between the
        query embedding and the rows of ``mold_embs``.
    """
    # The reference embeddings are computed from RGB-converted images, so give
    # the query the same treatment; otherwise RGBA / grayscale / palette inputs
    # reach the model with a different channel layout than the references.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    batch = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        query = dino(batch).cpu().numpy()

    # cosine_similarity yields a (1, n_references) matrix; keep the best match.
    return float(cosine_similarity(query, mold_embs).max())