File size: 4,024 Bytes
090a270
f2a237f
 
 
 
 
 
090a270
f2a237f
 
090a270
 
 
 
 
 
f2a237f
00e5cc8
f2a237f
090a270
 
 
 
 
 
 
 
 
 
f2a237f
 
090a270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f2a237f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# tests/test_cli/test_experiments.py
import torch

from app.cli.definitions import (
    DATASET_PATHS,
    DEFAULT_CONFIG,
    EXPERIMENTS,
)
from app.cli.experiments import parse_args
from app.cli.runner import get_device, show_status


class TestCLI:
    """Checks for the CLI definitions, the argument parser and runner helpers."""

    @staticmethod
    def _parse_with(monkeypatch, *cli_args):
        """Swap ``sys.argv`` for ``prog`` plus *cli_args* and return the parsed namespace."""
        monkeypatch.setattr("sys.argv", ["prog", *cli_args])
        return parse_args()

    def test_experiments_defined(self):
        """The registry holds five entries and contains ids 6 through 9."""
        expected_ids = (6, 7, 8, 9)

        assert len(EXPERIMENTS) == 5
        # Stop at the first absent id, mirroring a short-circuiting membership scan.
        for exp_id in expected_ids:
            assert exp_id in EXPERIMENTS

    def test_experiment_structure(self):
        """Every experiment config carries each of the mandatory keys."""
        required_fields = (
            "name",
            "train_dataset",
            "test_dataset",
            "num_classes",
            "test_classes",
            "description",
        )

        # Walk experiments in registry order and, within each, fields in the
        # order listed above, so the first gap found is the one reported.
        checks = (
            (exp_id, config, field)
            for exp_id, config in EXPERIMENTS.items()
            for field in required_fields
        )
        for exp_id, config, field in checks:
            assert field in config, f"Exp{exp_id} missing field: {field}"

    def test_dataset_paths_defined(self):
        """Each mandatory dataset key is present in ``DATASET_PATHS``."""
        for dataset in ("clean",):
            assert dataset in DATASET_PATHS, f"Missing dataset: {dataset}"

    def test_default_config(self):
        """``DEFAULT_CONFIG`` exposes the expected baseline values."""
        expected_defaults = {
            "model_variant": "resnet18",
            "epochs": 26,
            "batch_size": 8,
            "learning_rate": 0.001,
            "num_folds": 5,
        }

        for key, wanted in expected_defaults.items():
            assert DEFAULT_CONFIG[key] == wanted

    def test_parse_args_default(self, monkeypatch):
        """With no CLI options the parser yields the default values."""
        parsed = self._parse_with(monkeypatch)

        expected_attrs = (
            ("model", "resnet18"),
            ("epochs", 26),
            ("batch_size", 8),
            ("lr", 0.001),
            ("folds", 5),
        )
        for attr, wanted in expected_attrs:
            assert getattr(parsed, attr) == wanted

    def test_parse_args_model_resnet50(self, monkeypatch):
        """``--model resnet50`` is reflected on the parsed namespace."""
        parsed = self._parse_with(monkeypatch, "--model", "resnet50")

        assert parsed.model == "resnet50"

    def test_parse_args_epochs_override(self, monkeypatch):
        """``--epochs 50`` replaces the default epoch count."""
        parsed = self._parse_with(monkeypatch, "--epochs", "50")

        assert parsed.epochs == 50

    def test_parse_args_folds(self, monkeypatch):
        """``--folds 10`` combined with ``--cross-validate`` sets both values."""
        parsed = self._parse_with(monkeypatch, "--cross-validate", "--folds", "10")

        assert parsed.folds == 10
        assert parsed.cross_validate is True

    def test_parse_args_flags(self, monkeypatch):
        """Supplied action flags are ``True``; an omitted one stays ``False``."""
        parsed = self._parse_with(monkeypatch, "--all", "--train")

        for flag, wanted in (("all", True), ("train", True), ("test", False)):
            assert getattr(parsed, flag) is wanted

    def test_parse_args_cross_validate(self, monkeypatch):
        """``--cross-validate`` together with ``--exp 4`` sets both values."""
        parsed = self._parse_with(monkeypatch, "--cross-validate", "--exp", "4")

        assert parsed.cross_validate is True
        assert parsed.exp == 4

    def test_parse_args_exp_selection(self, monkeypatch):
        """Experiment id, train flag and model can be combined in one call."""
        parsed = self._parse_with(
            monkeypatch, "--train", "--exp", "4", "--model", "resnet50"
        )

        assert parsed.exp == 4
        assert parsed.train is True
        assert parsed.model == "resnet50"

    def test_get_device(self):
        """``get_device`` returns one of the cuda, mps or cpu devices."""
        chosen = get_device()
        allowed = [torch.device(kind) for kind in ("cuda", "mps", "cpu")]

        assert chosen in allowed

    def test_show_status_runs(self, capsys):
        """``show_status`` prints the status header and experiment details."""
        show_status()
        stdout = capsys.readouterr().out

        assert "EXPERIMENT STATUS" in stdout
        assert "Exp6" in stdout
        # Either spelling of the clean-split label is accepted.
        assert any(label in stdout for label in ("Clean split", "clean_split"))