""" patch_epoch_saving.py — Wrap `opf train` to save a checkpoint after every epoch. Usage (drop-in replacement for `opf train`): python scripts/patch_epoch_saving.py train.jsonl \\ --output-dir ./ckpt --epochs 5 [... all normal opf train flags ...] Output layout: ./ckpt/ <- best epoch checkpoint (identical to plain `opf train`) ./ckpt/epoch_1/ <- checkpoint after epoch 1 ./ckpt/epoch_2/ <- checkpoint after epoch 2 ... Per-epoch checkpoints save the model state at the END of each epoch (not the best-so-far). Use them for manual selection or checkpoint averaging afterwards. How it works (implementation notes): opf/_train/runner.py has a single `main()` function with an inline epoch loop. We can't inject code between loop lines without modifying the source, so we patch three private functions that are called at predictable points: 1. _train_one_epoch(model=..., epoch_index=N, ...) → Called once per epoch. We wrap it to (a) capture the model object on the first call and (b) save a per-epoch checkpoint after each call. 2. _resolve_output_dtype(...) → Called once, just before the epoch loop writes the final checkpoint. We intercept it to capture the serialisation dtype. 3. json.dumps (inside the runner module's namespace) → The output_config dict (containing span_class_names, num_labels, …) is serialised with json.dumps and then written to config.json. We intercept the specific call by replacing json in the runner module's globals with a thin shim that captures the first dict containing "span_class_names". This fires BEFORE the epoch loop, so all per-epoch checkpoints get a proper config.json. All patches are installed before opf_main() is called and are module-local — they do not affect any other imports in the process. """ from __future__ import annotations import json import sys from pathlib import Path from typing import Any import torch from safetensors.torch import save_file # --------------------------------------------------------------------------- # Shared context — populated by the patches during training # --------------------------------------------------------------------------- _CTX: dict[str, Any] = { "model": None, # torch.nn.Module — captured from first _train_one_epoch call "output_dir": None, # Path — captured from _ensure_output_dir patch "output_config": None, # dict — captured from json.dumps shim "output_dtype": None, # torch.dtype — captured from _resolve_output_dtype patch "epoch_saves": [], # list[(epoch, Path)] — filled as epochs complete } # --------------------------------------------------------------------------- # Per-epoch checkpoint writer # --------------------------------------------------------------------------- def _save_epoch_checkpoint(epoch: int) -> None: model = _CTX["model"] output_dir: Path | None = _CTX["output_dir"] output_config: dict | None = _CTX["output_config"] output_dtype: torch.dtype | None = _CTX["output_dtype"] if model is None or output_dir is None: print( f"[patch_epoch_saving] warning: skipping epoch {epoch} save " "(context not yet populated)", flush=True, ) return epoch_dir = output_dir / f"epoch_{epoch}" epoch_dir.mkdir(parents=True, exist_ok=True) # config.json — use captured config if available, else copy from output_dir if output_config is not None: (epoch_dir / "config.json").write_text( json.dumps(output_config, indent=2, sort_keys=True) + "\n", encoding="utf-8", ) else: src = output_dir / "config.json" if src.exists(): import shutil # noqa: PLC0415 shutil.copy2(src, epoch_dir / "config.json") # model.safetensors — snapshot named_parameters at current dtype tensors: dict[str, torch.Tensor] = {} for name, param in model.named_parameters(): t = param.detach().cpu() if output_dtype is not None: t = t.to(output_dtype) tensors[name] = t.contiguous() save_file(tensors, str(epoch_dir / "model.safetensors")) _CTX["epoch_saves"].append((epoch, epoch_dir)) print( f"[patch_epoch_saving] epoch {epoch} checkpoint -> {epoch_dir}", flush=True, ) # --------------------------------------------------------------------------- # Patch installation # --------------------------------------------------------------------------- def _install_patches() -> None: import opf._train.runner as runner # noqa: PLC0415 from opf._model import weights as weights_mod # noqa: PLC0415 # ---- 1. Wrap _train_one_epoch to capture model + fire save hook -------- _orig_train = runner._train_one_epoch def _patched_train_one_epoch(*args, **kwargs): if _CTX["model"] is None: _CTX["model"] = kwargs.get("model") result = _orig_train(*args, **kwargs) epoch = kwargs.get("epoch_index") if epoch is not None: _save_epoch_checkpoint(epoch) return result runner._train_one_epoch = _patched_train_one_epoch # ---- 2. Wrap _ensure_output_dir to capture output_dir ------------------ _orig_ensure = runner._ensure_output_dir def _patched_ensure_output_dir(path: Path, *, overwrite: bool) -> None: _CTX["output_dir"] = path return _orig_ensure(path, overwrite=overwrite) runner._ensure_output_dir = _patched_ensure_output_dir # ---- 3. Wrap _resolve_output_dtype to capture output torch.dtype ------- _orig_resolve_dtype = runner._resolve_output_dtype def _patched_resolve_output_dtype(**kwargs): serialized, dtype = _orig_resolve_dtype(**kwargs) _CTX["output_dtype"] = dtype return serialized, dtype runner._resolve_output_dtype = _patched_resolve_output_dtype # ---- 4. Shim json.dumps in the runner module to capture output_config -- # The output_config dict is assembled inline in main() and passed to # json.dumps right before it is written to config.json. We identify it by # the presence of "span_class_names" + "num_labels" (unique to that dict). _real_json = runner.json # the json module as imported by runner class _JsonShim: """Proxy for the json module that intercepts the output_config dump.""" def __getattr__(self, name: str): return getattr(_real_json, name) def dumps(self, obj, *args, **kwargs): # noqa: ANN001 result = _real_json.dumps(obj, *args, **kwargs) if ( _CTX["output_config"] is None and isinstance(obj, dict) and "span_class_names" in obj and "num_labels" in obj and "epoch_metrics" not in obj # exclude the summary JSON ): _CTX["output_config"] = dict(obj) return result runner.json = _JsonShim() # --------------------------------------------------------------------------- # Entry point # --------------------------------------------------------------------------- def main(argv: list[str] | None = None) -> int: _install_patches() from opf._train.runner import main as opf_main # noqa: PLC0415 ret = opf_main(argv) saves = _CTX["epoch_saves"] if saves: print(f"\n[patch_epoch_saving] {len(saves)} per-epoch checkpoint(s) written:") for epoch, path in saves: print(f" epoch {epoch:>3}: {path}") else: print("\n[patch_epoch_saving] warning: no per-epoch checkpoints were written") return ret if isinstance(ret, int) else 0 if __name__ == "__main__": sys.exit(main(sys.argv[1:]))