# Source: compvis/test/x/test_image_classification.py
# Uploaded by Dexter via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 36c95ba (verified), 1.64 kB.
import pytest
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from kornia.contrib import ClassificationHead, VisionTransformer
from kornia.x import Configuration, ImageClassifierTrainer
class DummyDatasetClassification(Dataset):
    """Tiny constant-valued classification dataset for smoke-testing the trainer.

    Every sample is an all-ones float image tensor of shape
    ``(channels, image_size, image_size)`` paired with the same scalar integer
    label tensor. The defaults reproduce the original fixture: 10 samples of
    ``3x32x32`` images, all labelled ``1``.

    Args:
        length: number of samples reported by ``len()``.
        image_size: spatial height and width of each image.
        label: class index returned for every sample.
        channels: number of image channels.
    """

    def __init__(self, length: int = 10, image_size: int = 32, label: int = 1, channels: int = 3) -> None:
        super().__init__()
        self.length = length
        self.image_size = image_size
        self.label = label
        self.channels = channels

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, index: int):
        """Return ``(image, label)`` for ``index``.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        # torch's map-style Dataset defines no __iter__, so ``list(ds)`` and
        # ``for x in ds`` fall back to calling __getitem__(0), (1), ... until
        # IndexError is raised; without this check such loops never end.
        position = index + self.length if index < 0 else index
        if not 0 <= position < self.length:
            raise IndexError(f"index {index} is out of range for a dataset of length {self.length}")
        return torch.ones(self.channels, self.image_size, self.image_size), torch.tensor(self.label)
@pytest.fixture
def model():
    """Provide a VisionTransformer (32x32 input) followed by a 10-class ClassificationHead."""
    backbone = VisionTransformer(image_size=32)
    head = ClassificationHead(num_classes=10)
    return nn.Sequential(backbone, head)
@pytest.fixture
def dataloader():
    """Provide a DataLoader over DummyDatasetClassification that yields one sample per batch."""
    return torch.utils.data.DataLoader(DummyDatasetClassification(), batch_size=1)
@pytest.fixture
def criterion():
    """Provide the cross-entropy loss used to train the classifier."""
    loss_fn = nn.CrossEntropyLoss()
    return loss_fn
@pytest.fixture
def optimizer(model):
    """Provide an AdamW optimizer, with default hyper-parameters, over ``model``'s parameters."""
    params = model.parameters()
    return torch.optim.AdamW(params)
@pytest.fixture
def scheduler(optimizer, dataloader):
    """Provide a cosine-annealing LR scheduler with ``T_max`` set to the number of batches per epoch."""
    num_batches = len(dataloader)
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_batches)
@pytest.fixture
def configuration():
    """Provide a ``Configuration()`` whose ``num_epochs`` is overridden to 1."""
    cfg = Configuration()
    cfg.num_epochs = 1  # keep the smoke test to a single pass over the data
    return cfg
class TestImageClassifierTrainer:
    """Smoke tests for ``ImageClassifierTrainer`` built from the fixtures above."""

    def test_fit(self, model, dataloader, criterion, optimizer, scheduler, configuration):
        """Running ``fit()`` on the dummy data completes without raising."""
        # One loader doubles as both the training and the validation split.
        train_loader = valid_loader = dataloader
        trainer = ImageClassifierTrainer(
            model, train_loader, valid_loader, criterion, optimizer, scheduler, configuration
        )
        trainer.fit()

    def test_exception(self, model, dataloader, criterion, optimizer, scheduler, configuration):
        """Constructing the trainer with ``callbacks={'frodo': None}`` raises ``ValueError``."""
        bad_callbacks = {'frodo': None}
        with pytest.raises(ValueError):
            ImageClassifierTrainer(
                model,
                dataloader,
                dataloader,
                criterion,
                optimizer,
                scheduler,
                configuration,
                callbacks=bad_callbacks,
            )