File size: 2,584 Bytes
090a270
 
f2a237f
090a270
f2a237f
090a270
 
 
 
 
 
20d06bb
 
 
 
 
 
 
 
 
090a270
20d06bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
090a270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import json
import os

import pytest

from app.visualization import plot_experiment


@pytest.fixture
def full_exp_dir(tmp_path):
    """Build a complete experiment directory for the integration tests.

    Populates ``tmp_path`` with:

    * ``config.json`` - model variant, epoch count and patience;
    * ``fold_0.json`` and ``fold_1.json`` - identical per-fold metrics and
      training history, differing only in their ``fold`` field (1 and 2);
    * ``results.json`` - aggregate cross-validation stats plus test-set
      metrics with a 2x2 labelled confusion matrix.

    Returns:
        The populated ``tmp_path`` directory.
    """

    def dump(name, payload):
        # Plain json.dumps (default separators, no indent) so the bytes on
        # disk are exactly what a direct write_text(json.dumps(...)) produces.
        (tmp_path / name).write_text(json.dumps(payload))

    dump(
        "config.json",
        {"model_variant": "resnet50", "num_epochs": 10, "patience": 5},
    )

    # One shared history: json.dumps only reads it, so reuse across folds is safe.
    history = {
        "train_loss": [1.0, 0.5, 0.3, 0.15, 0.1],
        "val_loss": [0.9, 0.6, 0.5, 0.48, 0.5],
        "val_accuracy": [0.5, 0.7, 0.8, 0.85, 0.84],
        "val_f1": [0.45, 0.65, 0.78, 0.85, 0.83],
    }
    for idx in range(2):
        dump(
            f"fold_{idx}.json",
            {
                "best_f1": 0.85,
                "best_accuracy": 0.86,
                "best_epoch": 3,
                "final_train_loss": 0.1,
                "final_val_loss": 0.5,
                "history": history,
                # File names are 0-based, but the stored fold number is 1-based.
                "fold": idx + 1,
                "batch_size": 8,
            },
        )

    test_metrics = {
        "accuracy": 0.97,
        "f1_macro": 0.97,
        "confusion_matrix": [[10, 1], [2, 15]],
        "confusion_matrix_labels": ["A", "B"],
    }
    dump(
        "results.json",
        {
            "mean_accuracy": 0.855,
            "std_accuracy": 0.01,
            "mean_f1": 0.85,
            "std_f1": 0.01,
            "best_fold": 0,
            "test_metrics": test_metrics,
            "fold_results": [],
        },
    )
    return tmp_path


def test_plot_experiment_creates_all_figures(full_exp_dir):
    """An explicit output_dir receives every figure as a non-empty file."""
    figures = full_exp_dir / "figures"
    plot_experiment(str(full_exp_dir), str(figures))

    for fname in (
        "loss_curves.png",
        "metric_curves.png",
        "fold_summary.png",
        "confusion_matrix.png",
        "overfitting_proxy.png",
    ):
        target = figures / fname
        assert target.exists(), f"Missing: {fname}"
        assert target.stat().st_size > 0, f"Empty: {fname}"


def test_plot_experiment_default_output_dir(full_exp_dir):
    """Omitting output_dir places figures under ``<exp_dir>/figures/``."""
    plot_experiment(str(full_exp_dir))

    default_dir = full_exp_dir / "figures"
    assert default_dir.is_dir()
    assert (default_dir / "loss_curves.png").exists()