#!/usr/bin/env python3
"""Parse DDPM Slurm stdout or bundled JSON for Train/Val loss series."""

from __future__ import annotations

import json
import re
from pathlib import Path
from typing import Tuple

# One training-progress row, e.g. ``Epoch 3/100 | Train: 0.0123 | Val: 1.5e-02``.
# Named groups: ``ep`` = epoch number, ``tr`` = train loss, ``va`` = val loss.
# The total-epoch count after the slash is matched but not captured.
_ROW = re.compile(
    r"Epoch\s+(?P<ep>\d+)/\d+\s+\|\s+Train:\s+(?P<tr>[\d.eE+-]+)\s+\|\s+Val:\s+(?P<va>[\d.eE+-]+)",
)


def parse_slurm_training_log(path: str | Path) -> Tuple[list[int], list[float], list[float]]:
    """Return (epochs, train_losses, val_losses) parsed from Slurm *.out stdout.

    Every occurrence of the ``Epoch N/M | Train: x | Val: y`` pattern in the
    file contributes one entry to each list, in order of appearance. Text that
    does not match the pattern is ignored, so a log without any matching row
    yields three empty lists.

    Args:
        path: Location of the log file.

    Returns:
        Three parallel lists: epoch numbers, train losses, validation losses.

    Raises:
        OSError: If the file cannot be read.
        ValueError: If a captured loss token is not a valid float literal.
    """
    p = Path(path)
    # errors="replace": undecodable bytes are substituted instead of raising,
    # so stray non-UTF-8 output does not abort the parse.
    text = p.read_text(encoding="utf-8", errors="replace")

    epochs: list[int] = []
    trains: list[float] = []
    vals: list[float] = []
    for m in _ROW.finditer(text):
        epochs.append(int(m.group("ep")))
        trains.append(float(m.group("tr")))
        vals.append(float(m.group("va")))
    return epochs, trains, vals


def load_training_loss_json(path: str | Path) -> Tuple[list[int], list[float], list[float]]:
    """Return (epochs, train_losses, val_losses) from a JSON export (keys: epochs, train, val).

    Values are coerced with ``int`` (epochs) and ``float`` (losses).

    Args:
        path: Location of the JSON file.

    Returns:
        Three parallel lists: epoch numbers, train losses, validation losses.

    Raises:
        OSError: If the file cannot be read.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If one of the ``epochs`` / ``train`` / ``val`` keys is missing.
        ValueError: If the three series differ in length, or a value cannot
            be converted to a number.
    """
    p = Path(path)
    raw = json.loads(p.read_text(encoding="utf-8"))

    epochs = [int(e) for e in raw["epochs"]]
    trains = [float(x) for x in raw["train"]]
    vals = [float(x) for x in raw["val"]]

    # The three series are consumed in lockstep by callers; reject ragged data.
    if not (len(epochs) == len(trains) == len(vals)):
        raise ValueError(f"{p}: mismatched lengths in epochs/train/val")
    return epochs, trains, vals


def load_train_val_series(path: str | Path) -> Tuple[list[int], list[float], list[float]]:
    """Slurm *.out or *.json with the same semantic output as ``parse_slurm_training_log``.

    Dispatches on the file suffix (case-insensitive): ``.json`` goes to
    ``load_training_loss_json``; anything else is treated as a Slurm log.

    Args:
        path: Location of the log or JSON file.

    Returns:
        Three parallel lists: epoch numbers, train losses, validation losses.
    """
    p = Path(path)
    if p.suffix.lower() == ".json":
        return load_training_loss_json(p)
    return parse_slurm_training_log(p)