Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import dataclasses | |
| import os | |
| import random | |
| from datetime import datetime | |
| from typing import TYPE_CHECKING | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| if TYPE_CHECKING: | |
| from src.config import Config | |
| def set_seed(seed: int) -> None: | |
| """Set all relevant random seeds for reproducibility.""" | |
| random.seed(seed) | |
| np.random.seed(seed) | |
| torch.manual_seed(seed) | |
| if torch.cuda.is_available(): | |
| torch.cuda.manual_seed_all(seed) | |
| elif torch.backends.mps.is_available(): | |
| torch.mps.manual_seed(seed) | |
| def free_device_cache(device: str) -> None: | |
| """Release unused memory on GPU / MPS (useful between seeds / Optuna trials).""" | |
| if device == "mps": | |
| torch.mps.empty_cache() | |
| elif device == "cuda": | |
| torch.cuda.empty_cache() | |
| def log_run( | |
| model_name: str, | |
| val_metrics: dict, | |
| test_metrics: dict, | |
| config: "Config", | |
| n_seeds: int, | |
| log_path: str = "results_log.csv", | |
| ) -> pd.DataFrame: | |
| """Append one training run to the global results log CSV. | |
| Creates the file with a header if it does not exist yet, otherwise appends. | |
| Columns | |
| βββββββ | |
| run_id, model_name, created_at, | |
| <all Config fields except device/csv_path/image_dir/submission_test_dir/output_dir>, | |
| n_seeds, | |
| val_* (auc, sensitivity, specificity, accuracy, youden, composite, | |
| precision, tp, tn, fp, fn, threshold) | |
| test_* (same set) | |
| Parameters | |
| ββββββββββ | |
| model_name : human-readable name for this run (e.g. "xrv_densenet_ensemble") | |
| val_metrics : dict returned by metrics_at_threshold on the validation split | |
| test_metrics : dict returned by metrics_at_threshold on the test split | |
| config : the Config instance used for this run | |
| n_seeds : number of seeds in the ensemble | |
| log_path : path to the CSV results log (created if missing) | |
| """ | |
| cfg_dict = dataclasses.asdict(config) | |
| # exclude path/device fields β not meaningful for comparison | |
| skip = {"csv_path", "image_dir", "submission_test_dir", "output_dir", "device"} | |
| hyperparams = {k: v for k, v in cfg_dict.items() if k not in skip} | |
| # `seeds` is a list β flatten to a string so the CSV stays readable | |
| if "seeds" in hyperparams and isinstance(hyperparams["seeds"], list): | |
| hyperparams["seeds"] = ",".join(str(s) for s in hyperparams["seeds"]) | |
| row: dict = { | |
| "run_id": datetime.now().strftime("%Y%m%d_%H%M%S"), | |
| "model_name": model_name, | |
| "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), | |
| "n_seeds": n_seeds, | |
| **hyperparams, | |
| } | |
| for prefix, metrics in [("val", val_metrics), ("test", test_metrics)]: | |
| for key, value in metrics.items(): | |
| if key != "threshold": | |
| row[f"{prefix}_{key}"] = value | |
| row[f"{prefix}_threshold"] = metrics.get("threshold", float("nan")) | |
| new_row_df = pd.DataFrame([row]) | |
| if os.path.exists(log_path): | |
| log_df = pd.read_csv(log_path) | |
| log_df = pd.concat([log_df, new_row_df], ignore_index=True) | |
| else: | |
| log_df = new_row_df | |
| log_df.to_csv(log_path, index=False) | |
| print(f"Run logged β {log_path} ({len(log_df)} total runs)") | |
| return log_df | |