Spaces:
Sleeping
Sleeping
| # tests/test_cli/test_experiments.py | |
| import torch | |
| from app.cli.definitions import ( | |
| DATASET_PATHS, | |
| DEFAULT_CONFIG, | |
| EXPERIMENTS, | |
| ) | |
| from app.cli.experiments import parse_args | |
| from app.cli.runner import get_device, show_status | |
class TestCLI:
    """Tests for the CLI definitions, argument parsing and runner helpers.

    The ``parse_args`` tests patch ``sys.argv`` through pytest's
    ``monkeypatch`` fixture, so each test parses an isolated command line
    and the real process arguments are restored afterwards.
    """

    def test_experiments_defined(self):
        """Test that the expected experiments are defined, and no extra ones."""
        expected_ids = {6, 7, 8, 9}
        # NOTE(review): the count below is pinned to 5 but only four IDs are
        # named here. TODO: confirm the fifth experiment ID and add it to
        # ``expected_ids`` so this becomes an exact equality check.
        missing = expected_ids - set(EXPERIMENTS)
        assert not missing, f"Missing experiments: {sorted(missing)}"
        assert len(EXPERIMENTS) == 5

    def test_experiment_structure(self):
        """Test that every experiment config has all required fields."""
        required_fields = (
            "name",
            "train_dataset",
            "test_dataset",
            "num_classes",
            "test_classes",
            "description",
        )
        for exp_id, config in EXPERIMENTS.items():
            # Collect every missing field so one failure reports them all.
            missing = [name for name in required_fields if name not in config]
            assert not missing, f"Exp{exp_id} missing fields: {missing}"

    def test_dataset_paths_defined(self):
        """Test that required dataset paths are defined."""
        required_datasets = ("clean",)
        missing = [name for name in required_datasets if name not in DATASET_PATHS]
        assert not missing, f"Missing datasets: {missing}"

    def test_default_config(self):
        """Test the concrete default configuration values."""
        assert DEFAULT_CONFIG["model_variant"] == "resnet18"
        assert DEFAULT_CONFIG["epochs"] == 26
        assert DEFAULT_CONFIG["batch_size"] == 8
        assert DEFAULT_CONFIG["learning_rate"] == 0.001
        assert DEFAULT_CONFIG["num_folds"] == 5

    def test_parse_args_default(self, monkeypatch):
        """Test that CLI defaults mirror DEFAULT_CONFIG when no flags are given."""
        monkeypatch.setattr("sys.argv", ["prog"])
        args = parse_args()
        # The concrete values are pinned by test_default_config; this test only
        # checks that parse_args is wired to the same defaults.
        assert args.model == DEFAULT_CONFIG["model_variant"]
        assert args.epochs == DEFAULT_CONFIG["epochs"]
        assert args.batch_size == DEFAULT_CONFIG["batch_size"]
        assert args.lr == DEFAULT_CONFIG["learning_rate"]
        assert args.folds == DEFAULT_CONFIG["num_folds"]

    def test_parse_args_model_resnet50(self, monkeypatch):
        """Test the --model argument with resnet50."""
        monkeypatch.setattr("sys.argv", ["prog", "--model", "resnet50"])
        args = parse_args()
        assert args.model == "resnet50"

    def test_parse_args_epochs_override(self, monkeypatch):
        """Test that --epochs overrides the default and is parsed as an int."""
        monkeypatch.setattr("sys.argv", ["prog", "--epochs", "50"])
        args = parse_args()
        assert args.epochs == 50

    def test_parse_args_folds(self, monkeypatch):
        """Test the --folds argument together with --cross-validate."""
        monkeypatch.setattr("sys.argv", ["prog", "--cross-validate", "--folds", "10"])
        args = parse_args()
        assert args.folds == 10
        assert args.cross_validate is True

    def test_parse_args_flags(self, monkeypatch):
        """Test that action flags are set only when passed."""
        monkeypatch.setattr("sys.argv", ["prog", "--all", "--train"])
        args = parse_args()
        assert args.all is True
        assert args.train is True
        assert args.test is False

    def test_parse_args_cross_validate(self, monkeypatch):
        """Test the --cross-validate flag combined with --exp."""
        monkeypatch.setattr("sys.argv", ["prog", "--cross-validate", "--exp", "4"])
        args = parse_args()
        assert args.cross_validate is True
        assert args.exp == 4

    def test_parse_args_exp_selection(self, monkeypatch):
        """Test experiment selection combined with training and model flags."""
        monkeypatch.setattr(
            "sys.argv", ["prog", "--train", "--exp", "4", "--model", "resnet50"]
        )
        args = parse_args()
        assert args.exp == 4
        assert args.train is True
        assert args.model == "resnet50"

    def test_get_device(self):
        """Test that get_device returns a supported torch device."""
        device = get_device()
        assert isinstance(device, torch.device)
        # Compare on .type: torch.device("cuda") != torch.device("cuda:0") because
        # the index differs, so equality-based membership would reject an
        # indexed device that is perfectly valid.
        assert device.type in {"cuda", "mps", "cpu"}

    def test_show_status_runs(self, capsys):
        """Test that show_status runs and prints the expected status sections."""
        show_status()
        captured = capsys.readouterr()
        assert "EXPERIMENT STATUS" in captured.out
        assert "Exp6" in captured.out
        assert "Clean split" in captured.out or "clean_split" in captured.out