| | import pytest |
| | import torch |
| | import torch.nn as nn |
| | from torch.utils.data import Dataset |
| |
|
| | from kornia.x import Configuration, SemanticSegmentationTrainer |
| |
|
| |
|
class DummyDatasetSegmentation(Dataset):
    """Tiny in-memory segmentation dataset used to smoke-test the trainer.

    Every sample is an all-ones ``(3, 32, 32)`` float image paired with an
    all-ones ``(32, 32)`` ``long`` mask, i.e. every pixel carries label 1.

    Args:
        length: number of samples the dataset reports; defaults to 10.
    """

    def __init__(self, length: int = 10) -> None:
        super().__init__()
        self.length = length

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, index: int):
        # Raise IndexError outside [-length, length) so that plain iteration
        # (``for x in ds`` / ``list(ds)``), which calls __getitem__ with
        # 0, 1, 2, ... until IndexError, terminates instead of looping forever.
        if not -self.length <= index < self.length:
            raise IndexError(f"index {index} out of range for dataset of length {self.length}")
        return torch.ones(3, 32, 32), torch.ones(32, 32).long()
| |
|
| |
|
@pytest.fixture
def model():
    """Provide a 1x1 convolution from 3 input channels to 10 output channels.

    The 10 output channels presumably act as per-pixel class logits for the
    cross-entropy criterion.
    """
    segmentation_head = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=1)
    return segmentation_head
| |
|
| |
|
@pytest.fixture
def dataloader():
    """Provide a DataLoader over ``DummyDatasetSegmentation`` yielding one sample per batch."""
    return torch.utils.data.DataLoader(DummyDatasetSegmentation(), batch_size=1)
| |
|
| |
|
@pytest.fixture
def criterion():
    """Provide a default-configured cross-entropy loss."""
    loss_fn = nn.CrossEntropyLoss()
    return loss_fn
| |
|
| |
|
@pytest.fixture
def optimizer(model):
    """Provide an AdamW optimizer (default hyper-parameters) over ``model``'s parameters."""
    params = model.parameters()
    return torch.optim.AdamW(params)
| |
|
| |
|
@pytest.fixture
def scheduler(optimizer, dataloader):
    """Provide a cosine-annealing LR schedule with ``T_max`` equal to the number of batches."""
    steps_per_epoch = len(dataloader)
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=steps_per_epoch)
| |
|
| |
|
@pytest.fixture
def configuration():
    """Provide a default ``Configuration`` with ``num_epochs`` overridden to 1."""
    cfg = Configuration()
    cfg.num_epochs = 1
    return cfg
| |
|
| |
|
class TestsemanticSegmentationTrainer:
    """Smoke tests for ``SemanticSegmentationTrainer`` built from the module fixtures."""

    def test_fit(self, model, dataloader, criterion, optimizer, scheduler, configuration):
        """Constructing the trainer and calling ``fit()`` completes without raising."""
        # The same loader is passed in both loader positions (presumably the
        # train and validation splits -- confirm against the trainer signature).
        train_loader = valid_loader = dataloader
        trainer = SemanticSegmentationTrainer(
            model, train_loader, valid_loader, criterion, optimizer, scheduler, configuration
        )
        trainer.fit()

    def test_exception(self, model, dataloader, criterion, optimizer, scheduler, configuration):
        """Passing ``callbacks={'frodo': None}`` makes the constructor raise ``ValueError``.

        'frodo' is presumably not a recognised callback hook name.
        """
        unknown_callbacks = {'frodo': None}
        with pytest.raises(ValueError):
            SemanticSegmentationTrainer(
                model, dataloader, dataloader, criterion, optimizer, scheduler, configuration,
                callbacks=unknown_callbacks,
            )
| |
|