Mithridatium / tests /test_evaluator.py
Gustavo Lucca
added model argument
447fe2d
import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet18
import mithridatium.evaluator as evaluator
import mithridatium.loader as loader
import unittest
class TestEvaluator(unittest.TestCase):
    """Smoke test for ``evaluator.extract_embeddings`` and ``evaluator.evaluate``."""

    def test_extract_embeddings_and_evaluate(self):
        """Run ``extract_embeddings`` and ``evaluate`` on 512 CIFAR-10 test images.

        The checkpoint path and batch size are read from the ``MODEL_PATH``
        and ``BATCH_SIZE`` environment variables, falling back to
        ``models/resnet18_bd.pth`` and ``128`` when unset. Example::

            export MODEL_PATH=models/resnet18_bd.pth
            export BATCH_SIZE=128
            .venv/bin/python -m unittest tests/test_evaluator.py

        CIFAR-10 is downloaded into ``./data`` if it is not already there.
        """
        model_path = os.environ.get("MODEL_PATH", "models/resnet18_bd.pth")
        batch_size = int(os.environ.get("BATCH_SIZE", 128))

        # Use a tiny subset of CIFAR-10: the first 512 images of the test
        # split, in fixed order (shuffle=False) so runs are repeatable.
        transform = transforms.Compose([
            transforms.ToTensor(),
            # Per-channel (mean, std) handed to Normalize.
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        testset = datasets.CIFAR10('./data', train=False, download=True, transform=transform)
        indices = list(range(512))
        subset = torch.utils.data.Subset(testset, indices)
        loader_ = DataLoader(subset, batch_size=batch_size, shuffle=False)

        model, feature_module = loader.load_resnet18(model_path)

        embs, labels = evaluator.extract_embeddings(model, loader_, feature_module)
        print(f"Embeddings shape: {embs.shape}")
        print(f"Labels shape: {labels.shape}")
        print(f"First 5 labels: {labels[:5].tolist()}")

        loss, accy = evaluator.evaluate(model, loader_)
        print(f"Loss: {loss:.4f}")
        print(f"Accuracy: {accy*100:.2f}%")

        # Dedicated comparison assertions report both operands on failure,
        # unlike assertTrue(a > b), which only says "False is not true".
        self.assertGreater(embs.shape[0], 0)
        self.assertGreater(labels.shape[0], 0)
        self.assertGreaterEqual(loss, 0)
        self.assertGreaterEqual(accy, 0)
# Allow running this file directly (``python tests/test_evaluator.py``);
# when it is imported by a test runner the guard is false and nothing runs.
if __name__ == "__main__":
    unittest.main()