| from __future__ import annotations |
|
|
| import dataclasses |
| import os |
| import random |
| from datetime import datetime |
| from typing import TYPE_CHECKING |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
|
|
| if TYPE_CHECKING: |
| from src.config import Config |
|
|
|
|
def set_seed(seed: int) -> None:
    """Seed Python's, NumPy's, and PyTorch's RNGs for reproducible runs.

    Also seeds the active accelerator (all CUDA devices, or MPS) when one
    is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    elif torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)
|
|
|
|
| def free_device_cache(device: str) -> None: |
| """Release unused memory on GPU / MPS (useful between seeds / Optuna trials).""" |
| if device == "mps": |
| torch.mps.empty_cache() |
| elif device == "cuda": |
| torch.cuda.empty_cache() |
|
|
|
|
def log_run(
    model_name: str,
    val_metrics: dict,
    test_metrics: dict,
    config: "Config",
    n_seeds: int,
    log_path: str = "results_log.csv",
) -> pd.DataFrame:
    """Append one training run to the global results log CSV.

    Creates the file with a header if it does not exist yet, otherwise appends.

    Columns
    -------
    run_id, model_name, created_at,
    <all Config fields except device/csv_path/image_dir/submission_test_dir/output_dir>,
    n_seeds,
    val_* (auc, sensitivity, specificity, accuracy, youden, composite,
           precision, tp, tn, fp, fn, threshold)
    test_* (same set)

    Parameters
    ----------
    model_name : human-readable name for this run (e.g. "xrv_densenet_ensemble")
    val_metrics : dict returned by metrics_at_threshold on the validation split
    test_metrics : dict returned by metrics_at_threshold on the test split
    config : the Config instance used for this run
    n_seeds : number of seeds in the ensemble
    log_path : path to the CSV results log (created if missing)

    Returns
    -------
    pd.DataFrame
        The full log (all runs), including the newly appended row.
    """
    cfg_dict = dataclasses.asdict(config)

    # Paths and device are environment-specific, not hyperparameters — skip them.
    skip = {"csv_path", "image_dir", "submission_test_dir", "output_dir", "device"}
    hyperparams = {k: v for k, v in cfg_dict.items() if k not in skip}

    # Flatten the seed list into one CSV-friendly cell (e.g. "0,1,2").
    if "seeds" in hyperparams and isinstance(hyperparams["seeds"], list):
        hyperparams["seeds"] = ",".join(str(s) for s in hyperparams["seeds"])

    # Capture the timestamp ONCE so run_id and created_at always agree;
    # two separate datetime.now() calls could straddle a second boundary.
    now = datetime.now()
    row: dict = {
        "run_id": now.strftime("%Y%m%d_%H%M%S"),
        "model_name": model_name,
        "created_at": now.strftime("%Y-%m-%d %H:%M:%S"),
        "n_seeds": n_seeds,
        **hyperparams,
    }

    for prefix, metrics in [("val", val_metrics), ("test", test_metrics)]:
        for key, value in metrics.items():
            if key != "threshold":
                row[f"{prefix}_{key}"] = value
        # Always emit the threshold column (NaN when absent) so the CSV
        # schema stays stable across runs.
        row[f"{prefix}_threshold"] = metrics.get("threshold", float("nan"))

    new_row_df = pd.DataFrame([row])

    if os.path.exists(log_path):
        log_df = pd.read_csv(log_path)
        log_df = pd.concat([log_df, new_row_df], ignore_index=True)
    else:
        log_df = new_row_df

    log_df.to_csv(log_path, index=False)
    # NOTE(review): arrow repaired from mojibake ("β") in the original message.
    print(f"Run logged → {log_path} ({len(log_df)} total runs)")
    return log_df
|
|