import json
import os
import re
from dataclasses import asdict, fields, is_dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

try:  # pragma: no cover - prefer PyYAML when available
    import yaml  # type: ignore
except ModuleNotFoundError:  # pragma: no cover - fallback for minimal environments
    from sim_priors_pk.config_classes import yaml_fallback as yaml

import lightning.pytorch as pl
import torch
from lightning import Callback

from sim_priors_pk.config_classes.node_pk_config import (
    EncoderDecoderNetworkConfig,
    MetaDosingConfig,
    MetaStudyConfig,
    MixDataConfig,
    NodePKExperimentConfig,
    ObservationsConfig,
)
from sim_priors_pk.config_classes.training_config import TrainingConfig

EXPERIMENT_CONFIG_FILENAME = "experiment_config.yaml"

# Extracts the epoch number from names such as ``best-epoch=12.ckpt``.
_EPOCH_PATTERN = re.compile(r"epoch=(\d+)")
# Detects a ``-v<N>`` version suffix, e.g. ``best-epoch=12-v1.ckpt``.
_VERSION_SUFFIX_PATTERN = re.compile(r"-v\d+(?:\.ckpt)?$")


def _record_epoch_checkpoint(registry, kind, filename, path):
    """Store ``path`` under ``kind`` if it outranks the entry already recorded.

    Candidates are ranked by epoch first; on equal epochs a filename without
    a ``-v<N>`` suffix wins over a versioned one. Filenames without an
    ``epoch=<digits>`` part are ignored.
    """
    match = _EPOCH_PATTERN.search(filename)
    if match is None:
        return
    epoch = int(match.group(1))
    unversioned = _VERSION_SUFFIX_PATTERN.search(filename) is None
    current = registry.get(kind)
    if current is None or (epoch, unversioned) > (current["epoch"], current["unversioned"]):
        registry[kind] = {"epoch": epoch, "path": path, "unversioned": unversioned}


def get_lightning_checkpoint_path(experiment_dir, checkpoint_type="best"):
    """Return the path of the requested Lightning checkpoint in ``experiment_dir``.

    Recognized checkpoint types and the filenames they match:

    - ``'best'``     -> ``best-epoch=<N>...`` (highest epoch wins)
    - ``'recon'``    -> ``recon-epoch=<N>...`` (highest epoch wins)
    - ``'best_log'`` -> exactly ``best_log_rmse.ckpt``
    - ``'last'``     -> exactly ``last.ckpt``

    For the epoch-based types, a non-versioned filename is preferred over a
    ``-v<N>`` one when both carry the same epoch. The directory listing is
    sorted first so that the result never depends on the arbitrary order
    returned by ``os.listdir``.

    Args:
        experiment_dir: Directory that holds the checkpoint files.
        checkpoint_type: One of ``'best'``, ``'best_log'``, ``'last'``, ``'recon'``.

    Returns:
        A ``Path`` to the selected checkpoint. If the requested type is
        missing, the first available of ``'last'``, ``'best'``, ``'recon'`` is
        returned instead; if none of those exist either, the first directory
        entry (in sorted order) is returned. ``None`` is returned when the
        directory does not exist, is not a directory, or is empty.
    """
    checkpoints_path = Path(experiment_dir)
    if not checkpoints_path.exists():
        print(f"CHECKPOINTS DIRECTORY DOES NOT EXIST: {checkpoints_path}")
        return None
    if not checkpoints_path.is_dir():
        print(f"CHECKPOINTS PATH IS NOT A DIRECTORY: {checkpoints_path}")
        return None

    # Sorted for reproducibility: os.listdir makes no ordering guarantee.
    all_checkpoints = sorted(os.listdir(checkpoints_path))
    if not all_checkpoints:
        print("NO CHECKPOINTS FOUND IN DIRECTORY.")
        return None

    # Populate available checkpoints, keyed by checkpoint type.
    checkpoint_epochs = {}
    for ckpt in all_checkpoints:
        ckpt_path = checkpoints_path / ckpt
        if ckpt.startswith("best-epoch="):
            _record_epoch_checkpoint(checkpoint_epochs, "best", ckpt, ckpt_path)
        elif ckpt == "best_log_rmse.ckpt":
            checkpoint_epochs["best_log"] = {"epoch": -1, "path": ckpt_path}
        elif ckpt == "last.ckpt":
            checkpoint_epochs["last"] = {"epoch": -1, "path": ckpt_path}
        elif ckpt.startswith("recon-epoch="):
            _record_epoch_checkpoint(checkpoint_epochs, "recon", ckpt, ckpt_path)

    if not checkpoint_epochs:
        print("NO VALID CHECKPOINTS FOUND, RETURNING FIRST AVAILABLE")
        return checkpoints_path / all_checkpoints[0]

    if checkpoint_type in checkpoint_epochs:
        return checkpoint_epochs[checkpoint_type]["path"]

    for preferred_type in ["last", "best", "recon"]:
        if preferred_type in checkpoint_epochs:
            print(
                f"CHECKPOINT TYPE '{checkpoint_type}' NOT FOUND, RETURNING '{preferred_type}' INSTEAD."
            )
            return checkpoint_epochs[preferred_type]["path"]

    # Only reachable when the sole recognized file is 'best_log' and another
    # type was requested.
    print("NO RECOGNIZED CHECKPOINT TYPES FOUND, RETURNING FIRST AVAILABLE")
    return checkpoints_path / all_checkpoints[0]


class UnusedParamReporter(pl.Callback):
    """Report trainable parameters that never received a gradient during an epoch.

    Intended as a debugging aid for checking DDP strategy compatibility.
    Example usage::

        trainer = Trainer(
            default_root_dir=self.experiment_dir,
            accelerator="gpu" if torch.cuda.is_available() else "cpu",
            devices=self.devices,
            strategy=self.strategy,
            logger=self.logger,
            max_epochs=3,
            callbacks=[UnusedParamReporter()],
            log_every_n_steps=self.model_config.train.log_interval,
            gradient_clip_val=self.model_config.train.gradient_clip_val,
            reload_dataloaders_every_n_epochs=1,
        )
    """

    def __init__(self):
        super().__init__()
        # Names of parameters that had a non-None gradient in the current epoch.
        self.touched = set()

    def on_train_epoch_start(self, trainer, pl_module):
        """Forget the parameters seen in the previous epoch."""
        self.touched = set()  # reset every epoch

    def on_after_backward(self, trainer, pl_module):
        """Record every trainable parameter that currently holds a gradient."""
        # Runs immediately after each backward() call.
        for name, param in pl_module.named_parameters():
            if param.requires_grad and param.grad is not None:
                self.touched.add(name)

    def on_train_epoch_end(self, trainer, pl_module):
        """Print the trainable parameters that never received a gradient."""
        unused = [
            name
            for name, param in pl_module.named_parameters()
            if param.requires_grad and name not in self.touched
        ]
        if unused:
            print("\n[Unused parameters]")
            for name in unused:
                print("  •", name)


class NonFiniteLossCallback(pl.Callback):
    """Abort training when any logged metric is non-finite and record the last valid checkpoint path."""

    def __init__(self):
        super().__init__()
        # Path reported by a ModelCheckpoint callback after the most recent
        # validation epoch whose metrics were all finite.
        self.last_valid_checkpoint: Optional[str] = None

    def on_train_start(self, trainer, pl_module) -> None:
        """Reset the ``nan_detected`` flag on the module at the start of training."""
        setattr(pl_module, "nan_detected", False)

    def on_validation_epoch_end(self, trainer, pl_module) -> None:
        """Stop training on non-finite metrics; otherwise remember the latest checkpoint."""
        if not metrics_are_finite(trainer.callback_metrics):
            setattr(pl_module, "nan_detected", True)
            trainer.should_stop = True
            return

        # Metrics are healthy: remember the path from the first ModelCheckpoint
        # callback that has a non-empty ``last_model_path``.
        for cb in trainer.callbacks:
            if isinstance(cb, pl.callbacks.ModelCheckpoint) and getattr(
                cb, "last_model_path", None
            ):
                self.last_valid_checkpoint = cb.last_model_path
                break


def metrics_are_finite(metrics: dict[str, torch.Tensor]) -> bool:
    """Return ``True`` if all tensor values are finite.

    Values that are not ``torch.Tensor`` instances are ignored.
    """
    return all(
        bool(torch.isfinite(value).all())
        for value in metrics.values()
        if isinstance(value, torch.Tensor)
    )


def dataclass_from_dict(klass, dikt):
    """Build an instance of dataclass ``klass`` from the mapping ``dikt``.

    Only keys that correspond to declared fields are used; missing fields are
    left to the dataclass defaults. A field whose declared type is itself a
    dataclass is built recursively when its value is a ``dict``. If ``klass``
    is not a dataclass, ``dikt`` is returned unchanged.
    """
    if not is_dataclass(klass):
        return dikt

    kwargs = {}
    for f in fields(klass):
        name = f.name
        if name not in dikt:
            continue
        value = dikt[name]
        # Nested dataclass. When ``f.type`` is not a dataclass (for example a
        # string annotation), the raw value is passed through as-is.
        if is_dataclass(f.type) and isinstance(value, dict):
            kwargs[name] = dataclass_from_dict(f.type, value)
        else:
            kwargs[name] = value
    return klass(**kwargs)


def _filter_dataclass_kwargs(dataclass_type, raw_mapping: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only fields declared on ``dataclass_type``.

    This makes Comet config reconstruction tolerant to legacy parameters that
    may still exist on older runs but were removed from the current config
    dataclasses. A non-dict ``raw_mapping`` yields an empty dict.
    """
    if not isinstance(raw_mapping, dict):
        return {}
    valid = {field_.name for field_ in fields(dataclass_type)}
    return {key: value for key, value in raw_mapping.items() if key in valid}


def get_experiment_config_path(
    experiment_dir: str,
    filename: str = EXPERIMENT_CONFIG_FILENAME,
) -> str:
    """Return the expected YAML path for an experiment configuration."""
    return os.path.join(experiment_dir, filename)


def save_experiment_config_yaml(
    exp_config: Any,
    experiment_dir: str,
    filename: str = EXPERIMENT_CONFIG_FILENAME,
) -> str:
    """Persist ``exp_config`` as YAML under ``experiment_dir`` and return the file path.

    If ``exp_config`` exposes a callable ``to_yaml`` attribute, it is called
    with the target path. Otherwise dataclasses are converted with ``asdict``
    and the result (or ``exp_config`` itself) is written with ``yaml.dump``.

    Raises:
        ValueError: If ``exp_config`` is ``None``.
    """
    if exp_config is None:
        raise ValueError("exp_config must be provided to save the experiment YAML.")

    os.makedirs(experiment_dir, exist_ok=True)
    yaml_path = get_experiment_config_path(experiment_dir, filename)

    to_yaml = getattr(exp_config, "to_yaml", None)
    if callable(to_yaml):
        to_yaml(yaml_path)
        return yaml_path

    payload = asdict(exp_config) if is_dataclass(exp_config) else exp_config
    with open(yaml_path, "w", encoding="utf-8") as handle:
        yaml.dump(payload, handle, default_flow_style=False)
    return yaml_path


def load_experiment_config_yaml(
    experiment_dir: str,
    filename: str = EXPERIMENT_CONFIG_FILENAME,
):
    """Load an experiment config YAML from ``experiment_dir`` and parse it.

    Raises:
        FileNotFoundError: If the YAML file does not exist.
    """
    yaml_path = get_experiment_config_path(experiment_dir, filename)
    if not os.path.isfile(yaml_path):
        raise FileNotFoundError(f"Experiment config YAML not found at: {yaml_path}")

    # Imported lazily; presumably to avoid a circular import - TODO confirm.
    from sim_priors_pk.models import get_model_config

    return get_model_config(yaml_path)


def _convert_value(val_str: str) -> Any:
    """Convert a Comet string value to an appropriate Python type.

    - ``None`` / ``"null"`` -> ``None``
    - JSON lists/dicts -> Python list/dict
    - ``"true"`` / ``"false"`` (any case) -> bool
    - ints/floats -> int/float
    - fallback: original string

    Values that are not strings are returned unchanged, so that an
    already-typed number such as ``2.7`` is not truncated by ``int()``.
    """
    if val_str is None or val_str == "null":
        return None
    if not isinstance(val_str, str):
        return val_str

    # Try JSON for list or dict literals.
    stripped = val_str.strip()
    if (stripped.startswith("[") and stripped.endswith("]")) or (
        stripped.startswith("{") and stripped.endswith("}")
    ):
        try:
            return json.loads(stripped)
        except json.JSONDecodeError:
            pass  # not valid JSON; fall through to the scalar conversions

    low = val_str.lower()
    if low in ("true", "false"):
        return low == "true"

    # Numeric? Try int first so that "1" -> 1, not 1.0.
    try:
        return int(val_str)
    except ValueError:
        pass
    try:
        return float(val_str)
    except ValueError:
        return val_str


def parse_comet_parameters_summary(parameters: List[Dict[str, Any]]) -> NodePKExperimentConfig:
    """Parse a list of Comet API parameter summaries into a ``NodePKExperimentConfig``.

    - Strips a leading "config" or "model_config" name component if present.
    - Splits names on '/', '.', or '|' to detect section and field.
    - Converts JSON lists/dicts, booleans, numbers, and "null".
    - Builds nested sub-configs for known sections.
    - Filters out any keys not defined on ``NodePKExperimentConfig``.

    Each entry is read through its ``"name"`` and ``"valueCurrent"`` keys.
    Names with more than two components and unknown sections are ignored.
    """
    nested: Dict[str, Dict[str, Any]] = {}
    top_level: Dict[str, Any] = {}

    for entry in parameters:
        # ``or ""`` guards against an explicit ``None`` name, which would
        # otherwise make re.split raise.
        raw = entry.get("name") or ""
        # Split on '/', '.', or '|' and drop empty parts.
        parts = [p for p in re.split(r"[\/\.|]", raw) if p]
        # Drop "config" or "model_config" prefix.
        if parts and parts[0] in ("config", "model_config"):
            parts = parts[1:]
        if not parts:
            continue

        val = _convert_value(entry.get("valueCurrent"))
        if len(parts) == 1:
            top_level[parts[0]] = val
        elif len(parts) == 2:
            sec, fld = parts
            nested.setdefault(sec, {})[fld] = val
        # Deeper nesting is ignored.

    # Sections backed by a plain dataclass; "train" is handled separately
    # because TrainingConfig provides its own keyword filter.
    section_types = {
        "network": EncoderDecoderNetworkConfig,
        "mix_data": MixDataConfig,
        "context_observations": ObservationsConfig,
        "target_observations": ObservationsConfig,
        "meta_study": MetaStudyConfig,
        "dosing": MetaDosingConfig,
    }

    # Assemble constructor kwargs.
    init_kwargs: Dict[str, Any] = {}
    for sec, mapping in nested.items():
        if sec in section_types:
            config_type = section_types[sec]
            init_kwargs[sec] = config_type(**_filter_dataclass_kwargs(config_type, mapping))
        elif sec == "train":
            init_kwargs["train"] = TrainingConfig(**TrainingConfig._filter_kwargs(mapping))
        # Unknown sections get dropped.

    # Top-level values take precedence over a section of the same name.
    init_kwargs.update(top_level)

    # Filter to only the real NodePKExperimentConfig fields.
    valid = {f.name for f in fields(NodePKExperimentConfig)}
    filtered = {k: v for k, v in init_kwargs.items() if k in valid}
    return NodePKExperimentConfig(**filtered)


class SkipNonFiniteGrad(Callback):
    """Cancel the optimizer step when the total grad-norm is not finite."""

    def on_before_optimizer_step(self, trainer, pl_module, optimizer, opt_idx=None):
        """Clip gradients and drop them when their total norm is not finite.

        ``opt_idx`` is optional so the hook can be invoked both with and
        without an optimizer index; it is not used.

        Note that ``clip_grad_norm_`` is called with ``max_norm=1.0`` on every
        step, so gradients are always clipped as a side effect of measuring
        the norm.
        """
        grad_norm = torch.nn.utils.clip_grad_norm_(
            pl_module.parameters(), max_norm=1.0
        )  # clip + return norm
        if not torch.isfinite(grad_norm):
            pl_module.print(f"⚠️ Non‑finite grad‑norm ({grad_norm}); skipping step.")
            optimizer.zero_grad(set_to_none=True)
            # NOTE(review): private Lightning attribute, and nothing in this
            # file sets it back to False - confirm that later batches still
            # run backward after a skipped step.
            trainer.fit_loop._skip_backward = True  # Lightning ≥2.2


if __name__ == "__main__":
    import tempfile

    demo_checkpoints = [
        "best-epoch=01.ckpt",
        "best-epoch=19.ckpt",
        "best-epoch=20-v1.ckpt",
        "best-epoch=20.ckpt",
        "last.ckpt",
        "recon-epoch=01.ckpt",
        "recon-epoch=05.ckpt",
    ]

    # Use a real temporary directory instead of patching os.listdir, so the
    # existence checks inside the function pass and no global state is altered.
    with tempfile.TemporaryDirectory() as experiment_dir:
        for checkpoint_name in demo_checkpoints:
            (Path(experiment_dir) / checkpoint_name).touch()

        # Test different checkpoint types.
        print(get_lightning_checkpoint_path(experiment_dir, "last"))  # Should return last.ckpt
        print(
            get_lightning_checkpoint_path(experiment_dir, "best")
        )  # Should return best-epoch=20.ckpt (non-versioned wins the tie)
        print(
            get_lightning_checkpoint_path(experiment_dir, "recon")
        )  # Should return recon-epoch=05.ckpt