# core-jepa / benchmarks / rsscn7.py
# Author: Gajesh Ladhar
# initial src and benchmark added (commit c71037b)
import torch
import numpy as np
import kagglehub
import glob
import os
import pandas as pd
from PIL import Image
from mapminer import models
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
def load_models(weight_path="lejepa-l.pt"):
    """Load the DINOv3 baseline and the LeJEPA fine-tuned encoder.

    Both models share the same ViT-L satellite architecture; the LeJEPA
    variant additionally loads fine-tuned weights from ``weight_path``.

    Args:
        weight_path: Path to the LeJEPA checkpoint (default "lejepa-l.pt",
            matching the original hard-coded value).

    Returns:
        (dinov3, lejepa): both models in eval mode, on CUDA.
    """
    dinov3 = models.DINOv3(architecture="vit-l-sat", pretrained=True).cuda()
    dinov3.eval()

    lejepa = models.DINOv3(architecture="vit-l-sat", pretrained=True).cuda()
    # map_location avoids failures when the checkpoint was saved on a
    # different device layout. NOTE(review): torch.load unpickles arbitrary
    # objects — only load checkpoints from trusted sources.
    state_dict = torch.load(weight_path, map_location="cuda")
    # The training checkpoint prefixes keys with 'encoder.'; strip it so the
    # keys line up with the bare model. Plain dicts preserve insertion order
    # (Python 3.7+), so OrderedDict is unnecessary.
    state_dict = {
        (k.replace("encoder.model.", "model.", 1)
         if k.startswith("encoder.model.") else k): v
        for k, v in state_dict.items()
    }
    # strict=False: checkpoint may carry extra heads not present in the model.
    lejepa.load_state_dict(state_dict, strict=False)
    lejepa.eval()
    return dinov3, lejepa
def load_dataset():
    """Fetch the RSSCN7 dataset via kagglehub (cached after first download).

    Returns:
        Filesystem path of the downloaded dataset root.
    """
    dataset_path = kagglehub.dataset_download("nifulislam/rsscn7-dataset")
    print("Path to dataset files:", dataset_path)
    return dataset_path
def get_embedding(model, img_path, device="cuda", size=(256, 256)):
    """Compute the CLS-token embedding of one image.

    Args:
        model: A model exposing ``normalize`` and
            ``model.forward_features`` (assumed DINOv3-style API — the
            features dict must contain 'x_norm_clstoken').
        img_path: Path to an image readable by PIL.
        device: Device the model lives on (default "cuda").
        size: (width, height) to resize to before inference; default keeps
            the original 256x256 behavior.

    Returns:
        1-D numpy array with the CLS-token embedding.
    """
    img = Image.open(img_path).convert("RGB").resize(size)
    # HWC uint8 -> CHW; raw 0-255 values are passed to model.normalize,
    # which is assumed to handle scaling — TODO confirm against mapminer.
    arr = np.asarray(img).transpose(2, 0, 1)
    # Create the tensor directly on the target device instead of building it
    # on CPU and copying afterwards.
    x = torch.tensor(arr, dtype=torch.float32, device=device).unsqueeze(0)
    x = model.normalize(x)
    with torch.no_grad():
        out = model.model.forward_features(x)
    return out["x_norm_clstoken"].squeeze().cpu().numpy()
def evaluate():
    """Benchmark DINOv3 vs LeJEPA embeddings on RSSCN7 with a linear probe.

    Downloads the dataset, embeds every image with both models, then runs
    5-fold stratified logistic-regression classification and prints the
    mean accuracy of each model.
    """
    dataset_path = load_dataset()
    # BUG FIX: use the path kagglehub actually returned instead of a
    # hard-coded /root/.cache/... location that only exists on one machine.
    # 'gamma-correct' is the image subfolder inside the dataset — TODO
    # confirm it exists in every dataset version.
    root = os.path.join(dataset_path, "gamma-correct")

    # One class per subdirectory; skip stray files (e.g. READMEs).
    classes = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    files = [
        (f, cls)
        for cls in classes
        for f in glob.glob(f"{root}/{cls}/*.jpg")
    ]
    df = pd.DataFrame(files, columns=['path', 'class'])
    # Fixed seed keeps the shuffle (and thus the CV folds) reproducible.
    df = df.sample(frac=1.0, random_state=42)

    dinov3, lejepa = load_models()
    X_dino_, X_jepa_, y_ = [], [], []
    for img_path, cls in tqdm(zip(df['path'], df['class']), total=len(df)):
        X_dino_.append(get_embedding(dinov3, img_path))
        X_jepa_.append(get_embedding(lejepa, img_path))
        y_.append(cls)
    X_dino = np.array(X_dino_)
    X_jepa = np.array(X_jepa_)

    # Encode string labels to integers for sklearn.
    y_enc = LabelEncoder().fit_transform(y_)

    kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    accs_dino, accs_jepa = [], []
    for train_idx, test_idx in kf.split(X_dino, y_enc):
        ytr, yte = y_enc[train_idx], y_enc[test_idx]
        # Same probe configuration for both embedding sets.
        for X, accs in ((X_dino, accs_dino), (X_jepa, accs_jepa)):
            clf = LogisticRegression(max_iter=8000, n_jobs=-1)
            clf.fit(X[train_idx], ytr)
            accs.append(accuracy_score(yte, clf.predict(X[test_idx])))

    print("DINOv3 K-fold Accuracy:", round(np.mean(accs_dino) * 100, 1))
    print("LeJEPA K-fold Accuracy:", round(np.mean(accs_jepa) * 100, 1))