DDPM-2param / cross_model /scripts /compare_ddpm_training_curves.py
collins909's picture
Upload 2-parameter conditional DDPM (HI emulation, CAMELS LH params_2, epoch 200) with full training/eval/posterior toolchain
c496462 verified
#!/usr/bin/env python3
"""Parse DDPM Slurm stdout or bundled JSON for Train/Val loss series."""
from __future__ import annotations
import json
import re
from pathlib import Path
from typing import Tuple
# One loss value as printed in the training log. The token must be a float
# literal that ``float()`` accepts: optional sign, digits with an optional
# decimal point, optional exponent. ``nan``/``inf``/``infinity`` are also
# accepted (case-insensitively) so that an epoch whose loss diverged is kept
# in the parsed series instead of being silently skipped. Skipping it would
# leave a gap in the train/val curves.
_FLOAT_TOKEN = r"[+-]?(?:(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?|(?i:nan|inf(?:inity)?))"

# Matches one per-epoch summary row of the form
#   "Epoch <n>/<total> | Train: <loss> | Val: <loss>"
# and captures the epoch number (``ep``), train loss (``tr``) and val loss (``va``).
_ROW = re.compile(
    r"Epoch\s+(?P<ep>\d+)/\d+\s+\|\s+"
    rf"Train:\s+(?P<tr>{_FLOAT_TOKEN})\s+\|\s+"
    rf"Val:\s+(?P<va>{_FLOAT_TOKEN})",
)
def parse_slurm_training_log(path: str | Path) -> Tuple[list[int], list[float], list[float]]:
    """Extract per-epoch train/val losses from a Slurm ``*.out`` stdout file.

    Every occurrence of the ``Epoch <n>/<total> | Train: <x> | Val: <y>`` row
    pattern is collected in file order. Undecodable UTF-8 bytes are replaced,
    not raised on, so a partially corrupted log still parses.

    Args:
        path: Location of the Slurm stdout file.

    Returns:
        Three parallel lists ``(epochs, train_losses, val_losses)``. All three
        are empty when the file contains no matching rows.
    """
    log_text = Path(path).read_text(encoding="utf-8", errors="replace")
    rows = [
        (int(match["ep"]), float(match["tr"]), float(match["va"]))
        for match in _ROW.finditer(log_text)
    ]
    epoch_list = [epoch for epoch, _, _ in rows]
    train_list = [train for _, train, _ in rows]
    val_list = [val for _, _, val in rows]
    return epoch_list, train_list, val_list
def load_training_loss_json(path: str | Path) -> Tuple[list[int], list[float], list[float]]:
    """Read a train/val loss series from a JSON export.

    The JSON document must provide the keys ``epochs``, ``train`` and ``val``,
    each holding a sequence. Epoch entries are coerced with ``int`` and loss
    entries with ``float``.

    Args:
        path: Location of the JSON file (decoded as UTF-8).

    Returns:
        Three parallel lists ``(epochs, train_losses, val_losses)``.

    Raises:
        KeyError: If one of the required keys is missing.
        ValueError: If the three sequences do not all have the same length.
    """
    source = Path(path)
    payload = json.loads(source.read_text(encoding="utf-8"))
    epoch_list = list(map(int, payload["epochs"]))
    train_list = list(map(float, payload["train"]))
    val_list = list(map(float, payload["val"]))
    if len(epoch_list) != len(train_list) or len(train_list) != len(val_list):
        raise ValueError(f"{source}: mismatched lengths in epochs/train/val")
    return epoch_list, train_list, val_list
def load_train_val_series(path: str | Path) -> Tuple[list[int], list[float], list[float]]:
    """Load ``(epochs, train_losses, val_losses)`` from either supported format.

    The reader is chosen from the file extension. A ``.json`` suffix (matched
    case-insensitively) is read with ``load_training_loss_json``. Any other
    file is treated as Slurm stdout and parsed with ``parse_slurm_training_log``.
    Both readers return the same three parallel lists.
    """
    source = Path(path)
    is_json_export = source.suffix.lower() == ".json"
    reader = load_training_loss_json if is_json_export else parse_slurm_training_log
    return reader(source)