| import json |
| import os |
| import re |
| from dataclasses import asdict, fields, is_dataclass |
| from pathlib import Path |
| from typing import Any, Dict, List |
|
|
| try: |
| import yaml |
| except ModuleNotFoundError: |
| from sim_priors_pk.config_classes import yaml_fallback as yaml |
|
|
| import lightning.pytorch as pl |
| import torch |
| from lightning import Callback |
|
|
| from sim_priors_pk.config_classes.node_pk_config import ( |
| EncoderDecoderNetworkConfig, |
| MetaDosingConfig, |
| MetaStudyConfig, |
| MixDataConfig, |
| NodePKExperimentConfig, |
| ObservationsConfig, |
| ) |
| from sim_priors_pk.config_classes.training_config import TrainingConfig |
|
|
# Default file name of the experiment configuration YAML; used as the default
# ``filename`` argument by the experiment-config path/save/load helpers below.
EXPERIMENT_CONFIG_FILENAME = "experiment_config.yaml"
|
|
|
|
# Extracts the epoch number embedded in names such as ``best-epoch=12.ckpt``.
_EPOCH_PATTERN = re.compile(r"epoch=(\d+)")
# Detects a trailing ``-v<N>`` version suffix such as ``best-epoch=12-v1.ckpt``.
_VERSION_SUFFIX_PATTERN = re.compile(r"-v\d+(?:\.[^.]+)?$")


def _epoch_checkpoint_rank(entry):
    """Sort key for epoch checkpoints: higher epoch first, then non-versioned."""
    return (entry["epoch"], not entry["versioned"])


def get_lightning_checkpoint_path(experiment_dir, checkpoint_type="best"):
    """
    Locate a Lightning checkpoint file inside ``experiment_dir``.

    For epoch-stamped checkpoints the file with the highest epoch number is
    selected; when two files share the same epoch, the non-versioned filename
    (no ``-v<N>`` suffix) is preferred.

    checkpoint_type: one of ['best', 'best_log', 'last', 'recon']
        - 'best' matches files like 'best-epoch=...'
        - 'best_log' matches 'best_log_rmse.ckpt'
        - 'last' matches 'last.ckpt'
        - 'recon' matches files like 'recon-epoch=...'

    Returns:
        A ``pathlib.Path`` to the selected file, or ``None`` when the
        directory does not exist, is not a directory, or is empty.
        If the requested type is missing, the first available of
        'last', 'best', 'recon' is returned instead; if none of those exist,
        the first directory entry (in sorted order) is returned.
    """
    checkpoints_path = Path(experiment_dir)
    if not checkpoints_path.exists():
        print(f"CHECKPOINTS DIRECTORY DOES NOT EXIST: {checkpoints_path}")
        return None
    if not checkpoints_path.is_dir():
        print(f"CHECKPOINTS PATH IS NOT A DIRECTORY: {checkpoints_path}")
        return None

    # os.listdir returns entries in arbitrary order; sort so that tie-breaks
    # and the "first available" fallbacks are deterministic.
    all_checkpoints = sorted(os.listdir(checkpoints_path))
    if not all_checkpoints:
        print("NO CHECKPOINTS FOUND IN DIRECTORY.")
        return None

    fixed_names = {"best_log_rmse.ckpt": "best_log", "last.ckpt": "last"}
    epoch_prefixes = {"best-epoch=": "best", "recon-epoch=": "recon"}

    checkpoint_epochs = {}
    for ckpt in all_checkpoints:
        ckpt_path = checkpoints_path / ckpt

        # Fixed-name checkpoints carry no epoch information.
        if ckpt in fixed_names:
            checkpoint_epochs[fixed_names[ckpt]] = {
                "epoch": -1,
                "versioned": False,
                "path": ckpt_path,
            }
            continue

        for prefix, kind in epoch_prefixes.items():
            if not ckpt.startswith(prefix):
                continue
            match = _EPOCH_PATTERN.search(ckpt)
            if match:
                candidate = {
                    "epoch": int(match.group(1)),
                    "versioned": bool(_VERSION_SUFFIX_PATTERN.search(ckpt)),
                    "path": ckpt_path,
                }
                current = checkpoint_epochs.get(kind)
                if current is None or _epoch_checkpoint_rank(
                    candidate
                ) > _epoch_checkpoint_rank(current):
                    checkpoint_epochs[kind] = candidate
            break

    if not checkpoint_epochs:
        print("NO VALID CHECKPOINTS FOUND, RETURNING FIRST AVAILABLE")
        return checkpoints_path / all_checkpoints[0]

    if checkpoint_type in checkpoint_epochs:
        return checkpoint_epochs[checkpoint_type]["path"]

    for preferred_type in ["last", "best", "recon"]:
        if preferred_type in checkpoint_epochs:
            print(
                f"CHECKPOINT TYPE '{checkpoint_type}' NOT FOUND, RETURNING '{preferred_type}' INSTEAD."
            )
            return checkpoint_epochs[preferred_type]["path"]

    print("NO RECOGNIZED CHECKPOINT TYPES FOUND, RETURNING FIRST AVAILABLE")
    return checkpoints_path / all_checkpoints[0]
|
|
|
|
| |
class UnusedParamReporter(pl.Callback):
    """
    Report trainable parameters that never received a gradient during an epoch.

    Intended as a debugging aid for checking DDP strategy compatibility: at the
    end of every training epoch it prints the names of all parameters with
    ``requires_grad=True`` whose ``.grad`` was ``None`` after every backward
    pass of that epoch.

    Example::

        trainer = Trainer(
            default_root_dir=self.experiment_dir,
            accelerator="gpu" if torch.cuda.is_available() else "cpu",
            devices=self.devices,
            strategy=self.strategy,
            logger=self.logger,
            max_epochs=3,
            callbacks=[UnusedParamReporter()],
            log_every_n_steps=self.model_config.train.log_interval,
            gradient_clip_val=self.model_config.train.gradient_clip_val,
            reload_dataloaders_every_n_epochs=1,
        )
    """

    def __init__(self):
        super().__init__()
        # Names of parameters that received a gradient in the current epoch.
        # Initialised here so the hooks below never hit a missing attribute.
        self.touched = set()

    def on_train_epoch_start(self, trainer, pl_module):
        """Reset the per-epoch record of parameters that received gradients."""
        self.touched = set()

    def on_after_backward(self, trainer, pl_module):
        """Record every trainable parameter that has a gradient after backward."""
        self.touched.update(
            name
            for name, param in pl_module.named_parameters()
            if param.requires_grad and param.grad is not None
        )

    def on_train_epoch_end(self, trainer, pl_module):
        """Print the trainable parameters that never received a gradient."""
        unused = [
            name
            for name, param in pl_module.named_parameters()
            if param.requires_grad and name not in self.touched
        ]
        if unused:
            print("\n[Unused parameters]")
            for name in unused:
                print(" •", name)
|
|
|
|
class NonFiniteLossCallback(pl.Callback):
    """Abort training when any logged metric is non-finite and record the last
    valid checkpoint path.

    After each validation epoch the tensors in ``trainer.callback_metrics`` are
    inspected. If any contains a non-finite value, ``pl_module.nan_detected``
    is set to ``True`` and ``trainer.should_stop`` is raised. Otherwise the
    ``last_model_path`` of the first ``ModelCheckpoint`` callback that has one
    is stored in ``last_valid_checkpoint``.
    """

    def __init__(self):
        from typing import Optional

        super().__init__()
        # Path of the most recent checkpoint seen while all metrics were finite.
        self.last_valid_checkpoint: Optional[str] = None

    def on_train_start(self, trainer, pl_module) -> None:
        """Clear the ``nan_detected`` flag on the module at the start of training."""
        pl_module.nan_detected = False

    def on_validation_epoch_end(self, trainer, pl_module) -> None:
        """Stop training on non-finite metrics; otherwise remember the last checkpoint."""
        # Reuse the module-level helper so the finiteness rule lives in one place.
        if not metrics_are_finite(trainer.callback_metrics):
            pl_module.nan_detected = True
            trainer.should_stop = True
            return

        for cb in trainer.callbacks:
            if isinstance(cb, pl.callbacks.ModelCheckpoint) and getattr(
                cb, "last_model_path", None
            ):
                self.last_valid_checkpoint = cb.last_model_path
                break
|
|
|
|
def metrics_are_finite(metrics: dict[str, torch.Tensor]) -> bool:
    """Tell whether every tensor in ``metrics`` holds only finite entries.

    Values that are not ``torch.Tensor`` instances are skipped, so an empty
    mapping (or one without tensors) yields ``True``.
    """
    return all(
        bool(torch.isfinite(value).all())
        for value in metrics.values()
        if isinstance(value, torch.Tensor)
    )
|
|
|
|
def _dataclass_in_annotation(annotation):
    """Return the dataclass type named by ``annotation`` (unwrapping a Union), else ``None``."""
    import types
    import typing

    if isinstance(annotation, type) and is_dataclass(annotation):
        return annotation

    union_origins = [typing.Union]
    if hasattr(types, "UnionType"):  # ``X | None`` syntax on Python 3.10+
        union_origins.append(types.UnionType)

    if any(typing.get_origin(annotation) is origin for origin in union_origins):
        candidates = [
            arg
            for arg in typing.get_args(annotation)
            if isinstance(arg, type) and is_dataclass(arg)
        ]
        # Only unwrap when the choice is unambiguous (e.g. ``Optional[Config]``).
        if len(candidates) == 1:
            return candidates[0]
    return None


def dataclass_from_dict(klass, dikt):
    """Recursively build an instance of dataclass ``klass`` from mapping ``dikt``.

    Args:
        klass: Target type. If it is not a dataclass, ``dikt`` is returned
            unchanged.
        dikt: Mapping of field names to values. Keys that are not fields of
            ``klass`` are ignored; fields missing from ``dikt`` keep their
            dataclass defaults.

    Returns:
        ``klass(**kwargs)`` where dict values of fields annotated with a
        dataclass (directly, via ``Optional[...]``/``Union``, or as a string
        annotation) are themselves converted recursively.

    Fields declared with ``init=False`` are skipped, since passing them to the
    constructor would raise ``TypeError`` (``dataclasses.asdict`` output
    contains them).
    """
    if not is_dataclass(klass):
        return dikt

    import typing

    # Resolve string / forward-reference annotations; fall back to the raw
    # ``field.type`` when resolution is not possible.
    try:
        hints = typing.get_type_hints(klass)
    except (NameError, TypeError, AttributeError, SyntaxError):
        hints = {}

    kwargs = {}
    for f in fields(klass):
        name = f.name
        if not f.init or name not in dikt:
            continue
        value = dikt[name]
        nested_type = None
        if isinstance(value, dict):
            nested_type = _dataclass_in_annotation(hints.get(name, f.type))
        if nested_type is not None:
            kwargs[name] = dataclass_from_dict(nested_type, value)
        else:
            kwargs[name] = value
    return klass(**kwargs)
|
|
|
|
| def _filter_dataclass_kwargs(dataclass_type, raw_mapping: Dict[str, Any]) -> Dict[str, Any]: |
| """Keep only fields declared on ``dataclass_type``. |
| |
| This makes Comet config reconstruction tolerant to legacy parameters that |
| may still exist on older runs but were removed from the current config |
| dataclasses. |
| """ |
|
|
| if not isinstance(raw_mapping, dict): |
| return {} |
| valid = {field_.name for field_ in fields(dataclass_type)} |
| return {key: value for key, value in raw_mapping.items() if key in valid} |
|
|
|
|
def get_experiment_config_path(
    experiment_dir: str,
    filename: str = EXPERIMENT_CONFIG_FILENAME,
) -> str:
    """Build the location of an experiment's YAML configuration file.

    The result is ``experiment_dir`` joined with ``filename`` via
    ``os.path.join``; nothing is checked or created on disk.
    """
    config_path = os.path.join(experiment_dir, filename)
    return config_path
|
|
|
|
def save_experiment_config_yaml(
    exp_config: Any,
    experiment_dir: str,
    filename: str = EXPERIMENT_CONFIG_FILENAME,
) -> str:
    """Write ``exp_config`` to a YAML file inside ``experiment_dir``.

    The directory is created if needed. When ``exp_config`` exposes a callable
    ``to_yaml`` attribute, that is invoked with the target path; otherwise the
    config (converted with ``asdict`` if it is a dataclass) is dumped through
    ``yaml.dump``.

    Returns:
        The path of the written YAML file.

    Raises:
        ValueError: If ``exp_config`` is ``None``.
    """
    if exp_config is None:
        raise ValueError("exp_config must be provided to save the experiment YAML.")

    os.makedirs(experiment_dir, exist_ok=True)
    yaml_path = get_experiment_config_path(experiment_dir, filename)

    writer = getattr(exp_config, "to_yaml", None)
    if not callable(writer):
        # No custom serializer available: fall back to a plain YAML dump.
        if is_dataclass(exp_config):
            payload = asdict(exp_config)
        else:
            payload = exp_config
        with open(yaml_path, "w", encoding="utf-8") as handle:
            yaml.dump(payload, handle, default_flow_style=False)
    else:
        writer(yaml_path)
    return yaml_path
|
|
|
|
def load_experiment_config_yaml(
    experiment_dir: str,
    filename: str = EXPERIMENT_CONFIG_FILENAME,
):
    """Read the experiment configuration stored as YAML in ``experiment_dir``.

    Returns:
        Whatever ``sim_priors_pk.models.get_model_config`` produces for the
        resolved YAML path.

    Raises:
        FileNotFoundError: If the YAML file does not exist.
    """
    yaml_path = get_experiment_config_path(experiment_dir, filename)
    if os.path.isfile(yaml_path):
        # Imported lazily, only once the file is known to exist.
        # NOTE(review): presumably deferred to avoid a circular import — confirm.
        from sim_priors_pk.models import get_model_config

        return get_model_config(yaml_path)
    raise FileNotFoundError(f"Experiment config YAML not found at: {yaml_path}")
|
|
|
|
| def _convert_value(val_str: str) -> Any: |
| """ |
| Convert a Comet string value to an appropriate Python type. |
| |
| - "null" -> None |
| - JSON lists/dicts -> Python list/dict |
| - "true"/"false" -> bool |
| - ints/floats -> int/float |
| - fallback: original string |
| """ |
| if val_str is None or val_str == "null": |
| return None |
|
|
| |
| if isinstance(val_str, str): |
| s = val_str.strip() |
| if (s.startswith("[") and s.endswith("]")) or (s.startswith("{") and s.endswith("}")): |
| try: |
| parsed = json.loads(s) |
| return parsed |
| except json.JSONDecodeError: |
| pass |
|
|
| low = str(val_str).lower() |
| if low in ("true", "false"): |
| return low == "true" |
|
|
| |
| try: |
| |
| i = int(val_str) |
| return i |
| except (ValueError, TypeError): |
| try: |
| return float(val_str) |
| except (ValueError, TypeError): |
| return val_str |
|
|
|
|
def parse_comet_parameters_summary(parameters: List[Dict[str, Any]]) -> NodePKExperimentConfig:
    """
    Rebuild a ``NodePKExperimentConfig`` from Comet API parameter summaries.

    Each entry's ``"name"`` is split on '/', '.', or '|'; a leading "config"
    or "model_config" component is discarded. The entry's ``"valueCurrent"``
    is converted with ``_convert_value``.

    - One remaining component: stored as a top-level value.
    - Two remaining components: stored as ``section -> field``.
    - Any other number of components: the entry is ignored.

    Known sections are turned into their config objects (unknown fields are
    filtered out first). Top-level values are applied afterwards, so they take
    precedence over a section of the same name. Finally, only keys that are
    fields of ``NodePKExperimentConfig`` are passed to its constructor.
    """
    # Sections whose kwargs are filtered with ``_filter_dataclass_kwargs``.
    section_types = {
        "network": EncoderDecoderNetworkConfig,
        "mix_data": MixDataConfig,
        "context_observations": ObservationsConfig,
        "target_observations": ObservationsConfig,
        "meta_study": MetaStudyConfig,
        "dosing": MetaDosingConfig,
    }

    section_values: Dict[str, Dict[str, Any]] = {}
    flat_values: Dict[str, Any] = {}

    for entry in parameters:
        name_parts = [piece for piece in re.split(r"[\/\.|]", entry.get("name", "")) if piece]
        if name_parts and name_parts[0] in ("config", "model_config"):
            name_parts = name_parts[1:]
        if not name_parts:
            continue

        converted = _convert_value(entry.get("valueCurrent"))
        depth = len(name_parts)
        if depth == 1:
            flat_values[name_parts[0]] = converted
        elif depth == 2:
            section_name, field_name = name_parts
            if section_name not in section_values:
                section_values[section_name] = {}
            section_values[section_name][field_name] = converted

    init_kwargs: Dict[str, Any] = {}
    for section_name, mapping in section_values.items():
        if section_name == "train":
            # TrainingConfig provides its own kwargs filter.
            init_kwargs[section_name] = TrainingConfig(**TrainingConfig._filter_kwargs(mapping))
        elif section_name in section_types:
            config_cls = section_types[section_name]
            init_kwargs[section_name] = config_cls(
                **_filter_dataclass_kwargs(config_cls, mapping)
            )

    # Top-level values are merged last and therefore win on key collisions.
    init_kwargs.update(flat_values)

    accepted = {declared.name for declared in fields(NodePKExperimentConfig)}
    constructor_kwargs = {}
    for key, value in init_kwargs.items():
        if key in accepted:
            constructor_kwargs[key] = value
    return NodePKExperimentConfig(**constructor_kwargs)
|
|
|
|
class SkipNonFiniteGrad(Callback):
    """Cancel the optimizer step when the total grad‑norm is not finite.

    NOTE: the norm is obtained through ``torch.nn.utils.clip_grad_norm_`` with
    ``max_norm=1.0``, which also rescales the gradients in place whenever
    their total norm exceeds 1.0.
    """

    def on_before_optimizer_step(self, trainer, pl_module, optimizer, opt_idx=None):
        """Drop the gradients of this step if their total norm is NaN/inf.

        ``opt_idx`` is optional so the hook works both with Lightning versions
        that pass an optimizer index and with those that call the hook with
        ``(trainer, pl_module, optimizer)`` only.
        """
        grad_norm = torch.nn.utils.clip_grad_norm_(pl_module.parameters(), max_norm=1.0)
        if not torch.isfinite(grad_norm):
            pl_module.print(f"⚠️ Non‑finite grad‑norm ({grad_norm}); skipping step.")
            # Clear the gradients so the optimizer has nothing to apply.
            optimizer.zero_grad(set_to_none=True)
            # NOTE(review): private Lightning attribute that is never reset
            # here — confirm it exists on ``fit_loop`` in the installed version
            # and has the intended effect.
            trainer.fit_loop._skip_backward = True
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test for ``get_lightning_checkpoint_path``.
    #
    # Empty files are created in a real temporary directory: the lookup checks
    # that the directory exists before listing it, so faking ``os.listdir`` for
    # a non-existent folder would always yield ``None`` (and would leave the
    # global ``os.listdir`` patched).
    import tempfile

    all_checkpoints = [
        "best-epoch=01.ckpt",
        "best-epoch=19.ckpt",
        "best-epoch=20.ckpt",
        "last.ckpt",
        "last-epoch=11.ckpt",
        "last-epoch=20-v1.ckpt",
        "last-epoch=20.ckpt",
        "recon-epoch=01.ckpt",
        "recon-epoch=05.ckpt",
    ]

    with tempfile.TemporaryDirectory() as experiment_dir:
        for checkpoint_name in all_checkpoints:
            (Path(experiment_dir) / checkpoint_name).touch()

        print(get_lightning_checkpoint_path(experiment_dir, "last"))
        print(get_lightning_checkpoint_path(experiment_dir, "best"))
        print(get_lightning_checkpoint_path(experiment_dir, "recon"))
|
|