File size: 1,248 Bytes
36c95ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import pytest
import torch.nn as nn

from kornia.metrics import AverageMeter
from kornia.x import EarlyStopping, ModelCheckpoint
from kornia.x.utils import TrainerState


@pytest.fixture
def model():
    """Provide a tiny 1x1 convolution (3 input channels -> 10 output channels).

    The callbacks under test only need some ``nn.Module`` to save or inspect,
    so the smallest convenient layer is used.
    """
    conv = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=1)
    return conv


def test_callback_modelcheckpoint(tmp_path, model):
    """A first reported metric of 1.0 becomes the best and writes ``model_0.pt``."""
    checkpoint = ModelCheckpoint(tmp_path, 'test_monitor')
    assert checkpoint is not None

    # The callback reads ``valid_metric[monitor].avg``; poke the average directly.
    meter = AverageMeter()
    valid_metric = {'test_monitor': meter}
    meter.avg = 1.0

    checkpoint(model, epoch=0, valid_metric=valid_metric)
    assert checkpoint.best_metric == 1.0

    # Checkpoint file is named after the epoch passed in above.
    saved_file = tmp_path / "model_0.pt"
    assert saved_file.is_file()


def test_callback_earlystopping(model):
    """EarlyStopping keeps training until ``patience`` non-improving reports, then terminates.

    The expected ``best_score`` of -1 after reporting an average of 1 shows the
    callback stores the negated metric. A later, larger average is therefore
    counted as "no improvement".
    """
    stopper = EarlyStopping('test_monitor', patience=2)
    assert stopper is not None
    assert stopper.counter == 0

    meter = AverageMeter()
    valid_metric = {'test_monitor': meter}

    # (reported avg, expected state, expected best_score, expected counter)
    expectations = [
        (1, TrainerState.TRAINING, -1, 0),  # first report seeds the best score
        (2, TrainerState.TRAINING, -1, 1),  # worse value: patience counter starts
    ]
    for reported, want_state, want_best, want_counter in expectations:
        meter.avg = reported
        state = stopper(model, epoch=0, valid_metric=valid_metric)
        assert state == want_state
        assert stopper.best_score == want_best
        assert stopper.counter == want_counter

    # Same (still non-improving) value again exhausts patience=2.
    final_state = stopper(model, epoch=0, valid_metric=valid_metric)
    assert final_state == TrainerState.TERMINATE