File size: 3,299 Bytes
3f984f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from __future__ import annotations

import dataclasses
import os
import random
from datetime import datetime
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
import torch

if TYPE_CHECKING:
    from src.config import Config


def set_seed(seed: int) -> None:
    """Seed every random number generator used here, for reproducible runs.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's default
    generator, in that order. The accelerator RNG is then seeded explicitly:
    all CUDA devices when CUDA is available, otherwise MPS when that backend
    is available.

    Parameters
    ──────────
    seed : value passed unchanged to each seeding function
    """
    # Host-side generators first, in a fixed order.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Accelerator generators: CUDA takes priority, MPS is only considered
    # when CUDA is not available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        return
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)


def free_device_cache(device: str) -> None:
    """Ask PyTorch to empty its cache for the given accelerator.

    Intended to be called between seeds / Optuna trials. Any ``device`` value
    other than ``"cuda"`` or ``"mps"`` is a no-op.

    Parameters
    ──────────
    device : device name; only the exact strings "cuda" and "mps" trigger
             a cache release
    """
    if device == "cuda":
        torch.cuda.empty_cache()
        return
    if device == "mps":
        torch.mps.empty_cache()


def log_run(
    model_name: str,
    val_metrics: dict,
    test_metrics: dict,
    config: "Config",
    n_seeds: int,
    log_path: str = "results_log.csv",
) -> pd.DataFrame:
    """Append one training run to the global results log CSV.

    The existing log (if any) is read in full, the new row is concatenated to
    it and the whole table is written back. A missing or empty log file is
    treated as "no previous runs", and the parent directory of ``log_path`` is
    created when it does not exist.

    Columns
    ───────
    run_id, model_name, created_at, n_seeds,
    <all Config fields except device/csv_path/image_dir/submission_test_dir/output_dir>,
    val_*  (auc, sensitivity, specificity, accuracy, youden, composite,
            precision, tp, tn, fp, fn, threshold)
    test_* (same set)

    Every key present in a metrics dict is written as ``<prefix>_<key>``;
    ``<prefix>_threshold`` is always written last for its prefix and is NaN
    when the dict has no "threshold" key.

    Parameters
    ──────────
    model_name   : human-readable name for this run (e.g. "xrv_densenet_ensemble")
    val_metrics  : dict returned by metrics_at_threshold on the validation split
    test_metrics : dict returned by metrics_at_threshold on the test split
    config       : the Config instance used for this run (must be a dataclass
                   instance, since it is converted with dataclasses.asdict)
    n_seeds      : number of seeds in the ensemble
    log_path     : path to the CSV results log (created if missing)

    Returns
    ───────
    The complete log as a DataFrame, including the row just added.
    """
    cfg_dict = dataclasses.asdict(config)

    # Exclude path/device fields — not meaningful for comparison.
    skip = {"csv_path", "image_dir", "submission_test_dir", "output_dir", "device"}
    hyperparams = {k: v for k, v in cfg_dict.items() if k not in skip}

    # `seeds` is a sequence — flatten to a string so the CSV stays readable.
    # Tuples are handled too: dataclasses.asdict keeps a tuple field a tuple.
    seeds = hyperparams.get("seeds")
    if isinstance(seeds, (list, tuple)):
        hyperparams["seeds"] = ",".join(str(s) for s in seeds)

    # Take the timestamp once so run_id and created_at always describe the
    # same instant (two separate now() calls can straddle a second boundary).
    now = datetime.now()

    # NOTE: hyperparams is unpacked last, so a Config field named like one of
    # the fixed columns (e.g. "model_name") overwrites the value set here.
    row: dict = {
        "run_id":     now.strftime("%Y%m%d_%H%M%S"),
        "model_name": model_name,
        "created_at": now.strftime("%Y-%m-%d %H:%M:%S"),
        "n_seeds":    n_seeds,
        **hyperparams,
    }

    for prefix, metrics in (("val", val_metrics), ("test", test_metrics)):
        for key, value in metrics.items():
            if key != "threshold":
                row[f"{prefix}_{key}"] = value
        # Written after the other metrics so the threshold column comes last.
        row[f"{prefix}_threshold"] = metrics.get("threshold", float("nan"))

    new_row_df = pd.DataFrame([row])

    try:
        previous_df = pd.read_csv(log_path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        # No log yet, or a zero-byte file left behind: start a fresh log.
        log_df = new_row_df
    else:
        # concat aligns on column names, so rows logged with a different set
        # of Config fields or metrics still fit (missing cells become NaN).
        log_df = pd.concat([previous_df, new_row_df], ignore_index=True)

    # Make sure the target directory exists before writing.
    parent_dir = os.path.dirname(log_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    log_df.to_csv(log_path, index=False)
    print(f"Run logged → {log_path}  ({len(log_df)} total runs)")
    return log_df