| """ |
| patch_epoch_saving.py — Wrap `opf train` to save a checkpoint after every epoch. |
| |
| Usage (drop-in replacement for `opf train`): |
| python scripts/patch_epoch_saving.py train.jsonl \\ |
| --output-dir ./ckpt --epochs 5 [... all normal opf train flags ...] |
| |
| Output layout: |
| ./ckpt/ <- best epoch checkpoint (identical to plain `opf train`) |
| ./ckpt/epoch_1/ <- checkpoint after epoch 1 |
| ./ckpt/epoch_2/ <- checkpoint after epoch 2 |
| ... |
| |
Per-epoch checkpoints save the model parameters at the END of each epoch (not
the best-so-far); module buffers are not included. Use them for manual
selection or checkpoint averaging afterwards.
| |
How it works (implementation notes):
opf/_train/runner.py has a single `main()` function with an inline epoch loop.
We can't inject code between loop lines without modifying the source, so we
patch three private functions and one module-level name that are used at
predictable points:

1. _train_one_epoch(model=..., epoch_index=N, ...)
   → Called once per epoch. We wrap it to (a) capture the model object on
     the first call and (b) save a per-epoch checkpoint after each call.

2. _ensure_output_dir(path, overwrite=...)
   → We intercept it to capture the output directory under which the
     epoch_N/ subdirectories are created.

3. _resolve_output_dtype(...)
   → We intercept it to capture the serialisation dtype. Per-epoch
     checkpoints are cast to this dtype only once it has been captured; if
     the runner calls it only after the epoch loop, per-epoch checkpoints
     keep the model's in-memory dtype.

4. json.dumps (inside the runner module's namespace)
   → The output_config dict (containing span_class_names, num_labels, …) is
     serialised with json.dumps and then written to config.json. We intercept
     that call by replacing `json` in the runner module's globals with a
     thin shim. The shim captures the first dict that contains
     "span_class_names" and "num_labels" but not "epoch_metrics".
     This fires BEFORE the epoch loop, so all per-epoch checkpoints get a
     proper config.json.

All patches are installed before opf_main() is called. They rebind names on
the `opf._train.runner` module only; the real `json` module and other modules
are left untouched. Any other code in this process that calls into
`opf._train.runner` will also see the patched functions.
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import sys |
| from pathlib import Path |
| from typing import Any |
|
|
| import torch |
| from safetensors.torch import save_file |
|
|
|
|
| |
| |
| |
|
|
| _CTX: dict[str, Any] = { |
| "model": None, |
| "output_dir": None, |
| "output_config": None, |
| "output_dtype": None, |
| "epoch_saves": [], |
| } |
|
|
|
|
| |
| |
| |
|
|
def _save_epoch_checkpoint(epoch: int) -> None:
    """Write the current model weights to ``<output_dir>/epoch_<epoch>/``.

    Uses the model, output directory, config and dtype captured in ``_CTX`` by
    the runner patches. If the model or the output directory has not been
    captured yet, a warning is printed and nothing is written.

    Files written:
      * ``config.json``: the captured output config (``indent=2``, sorted
        keys). If no config was captured, ``<output_dir>/config.json`` is
        copied instead when it already exists; otherwise no config is written.
      * ``model.safetensors``: every entry of ``model.named_parameters()``,
        detached, moved to CPU, cast to the captured output dtype when one is
        known, and made contiguous. Module buffers are not included.

    On success the ``(epoch, epoch_dir)`` pair is appended to
    ``_CTX["epoch_saves"]``.

    Args:
        epoch: Used verbatim in the directory name. It is the runner's
            ``epoch_index`` argument (NOTE(review): confirm whether that index
            is 0- or 1-based; the module docstring assumes 1-based names).
    """
    model = _CTX["model"]
    raw_output_dir = _CTX["output_dir"]
    output_config: dict | None = _CTX["output_config"]
    output_dtype: torch.dtype | None = _CTX["output_dtype"]

    if model is None or raw_output_dir is None:
        print(
            f"[patch_epoch_saving] warning: skipping epoch {epoch} save "
            "(context not yet populated)",
            flush=True,
        )
        return

    # The runner may hand _ensure_output_dir a plain str; coerce so the `/`
    # joins below cannot raise TypeError in the middle of training.
    output_dir = Path(raw_output_dir)
    epoch_dir = output_dir / f"epoch_{epoch}"
    epoch_dir.mkdir(parents=True, exist_ok=True)

    # config.json: prefer the captured dict; fall back to copying the final
    # config if the runner has already written one.
    if output_config is not None:
        (epoch_dir / "config.json").write_text(
            json.dumps(output_config, indent=2, sort_keys=True) + "\n",
            encoding="utf-8",
        )
    else:
        src = output_dir / "config.json"
        if src.exists():
            import shutil

            shutil.copy2(src, epoch_dir / "config.json")

    # Weights: CPU copies (optionally cast) so serialisation does not touch
    # the live training tensors' device or dtype.
    tensors: dict[str, torch.Tensor] = {}
    for name, param in model.named_parameters():
        t = param.detach().cpu()
        if output_dtype is not None:
            t = t.to(output_dtype)
        tensors[name] = t.contiguous()

    save_file(tensors, str(epoch_dir / "model.safetensors"))

    _CTX["epoch_saves"].append((epoch, epoch_dir))
    print(
        f"[patch_epoch_saving] epoch {epoch} checkpoint -> {epoch_dir}",
        flush=True,
    )
|
|
|
|
| |
| |
| |
|
|
def _install_patches() -> None:
    """Monkeypatch hooks in ``opf._train.runner`` that feed the per-epoch saver.

    Rebinds four names on the runner module:

    * ``_train_one_epoch``: captures the model on the first call that provides
      one, and calls ``_save_epoch_checkpoint`` after every call that carries
      an ``epoch_index`` argument.
    * ``_ensure_output_dir``: records the output directory (as a ``Path``) in
      ``_CTX``.
    * ``_resolve_output_dtype``: records the returned dtype in ``_CTX``.
    * ``json``: replaced by a proxy whose ``dumps`` captures the first dict
      that has ``span_class_names`` and ``num_labels`` but no
      ``epoch_metrics``. All other attributes forward to the real module.

    ``model`` and ``epoch_index`` are resolved through the original function's
    signature, so the hooks work whether the runner passes them by keyword or
    by position. Calling this function more than once is a no-op.
    """
    import functools
    import inspect

    import opf._train.runner as runner

    # Re-wrapping already-patched functions would save every epoch twice.
    if getattr(runner, "_patch_epoch_saving_installed", False):
        return

    def _signature_or_none(func):
        """Return ``inspect.signature(func)``, or None if it can't be introspected."""
        try:
            return inspect.signature(func)
        except (TypeError, ValueError):
            return None

    def _lookup(sig, name, args, kwargs):
        """Return argument ``name`` of a call, whether passed by keyword or position."""
        if name in kwargs:
            return kwargs[name]
        if sig is None or not args:
            return None
        try:
            return sig.bind_partial(*args, **kwargs).arguments.get(name)
        except TypeError:
            return None

    # 1. Capture the model and write a checkpoint after every epoch.
    _orig_train = runner._train_one_epoch
    train_sig = _signature_or_none(_orig_train)

    @functools.wraps(_orig_train)
    def _patched_train_one_epoch(*args, **kwargs):
        if _CTX["model"] is None:
            _CTX["model"] = _lookup(train_sig, "model", args, kwargs)
        result = _orig_train(*args, **kwargs)
        epoch = _lookup(train_sig, "epoch_index", args, kwargs)
        if epoch is not None:
            _save_epoch_checkpoint(epoch)
        return result

    runner._train_one_epoch = _patched_train_one_epoch

    # 2. Capture the output directory (first argument of _ensure_output_dir).
    _orig_ensure = runner._ensure_output_dir
    ensure_sig = _signature_or_none(_orig_ensure)
    path_param = (
        next(iter(ensure_sig.parameters), "path") if ensure_sig is not None else "path"
    )

    @functools.wraps(_orig_ensure)
    def _patched_ensure_output_dir(*args, **kwargs):
        path = args[0] if args else kwargs.get(path_param, kwargs.get("path"))
        if path is not None:
            # Coerce so `output_dir / "epoch_N"` works even if given a str.
            _CTX["output_dir"] = Path(path)
        return _orig_ensure(*args, **kwargs)

    runner._ensure_output_dir = _patched_ensure_output_dir

    # 3. Capture the serialisation dtype from the (serialized, dtype) result.
    _orig_resolve_dtype = runner._resolve_output_dtype

    @functools.wraps(_orig_resolve_dtype)
    def _patched_resolve_output_dtype(*args, **kwargs):
        serialized, dtype = _orig_resolve_dtype(*args, **kwargs)
        _CTX["output_dtype"] = dtype
        return serialized, dtype

    runner._resolve_output_dtype = _patched_resolve_output_dtype

    # 4. Capture output_config by proxying the runner's `json` global. Only the
    # runner's reference is replaced; the real json module is untouched.
    _real_json = runner.json

    class _JsonShim:
        """Proxy for the json module that intercepts the output_config dump."""

        def __getattr__(self, name: str):
            return getattr(_real_json, name)

        def dumps(self, obj, *args, **kwargs):
            result = _real_json.dumps(obj, *args, **kwargs)
            # Heuristic: the output config has these keys; the epoch-metrics
            # payload (also dumped by the runner) is excluded explicitly.
            if (
                _CTX["output_config"] is None
                and isinstance(obj, dict)
                and "span_class_names" in obj
                and "num_labels" in obj
                and "epoch_metrics" not in obj
            ):
                _CTX["output_config"] = dict(obj)
            return result

    runner.json = _JsonShim()

    runner._patch_epoch_saving_installed = True
|
|
|
|
| |
| |
| |
|
|
def _print_epoch_save_summary() -> None:
    """Print the per-epoch checkpoints recorded in ``_CTX``, or warn if there are none."""
    saves = _CTX["epoch_saves"]
    if saves:
        print(f"\n[patch_epoch_saving] {len(saves)} per-epoch checkpoint(s) written:")
        for epoch, path in saves:
            print(f" epoch {epoch:>3}: {path}")
    else:
        print("\n[patch_epoch_saving] warning: no per-epoch checkpoints were written")


def main(argv: list[str] | None = None) -> int:
    """Run ``opf train`` with per-epoch checkpoint saving enabled.

    Installs the runner patches, then delegates to ``opf._train.runner.main``
    with ``argv`` unchanged. The summary of written checkpoints is printed even
    if training raises or exits early. This matters after a crash, to see
    which epochs survived.

    Args:
        argv: Command-line arguments forwarded verbatim to opf's ``main``.

    Returns:
        opf's return value when it is an int, otherwise 0.
    """
    _install_patches()

    # Imported after patching; the patches rebind names on this same module.
    from opf._train.runner import main as opf_main

    try:
        ret = opf_main(argv)
    finally:
        _print_epoch_save_summary()

    return ret if isinstance(ret, int) else 0
|
|
|
|
if __name__ == "__main__":
    # Forward everything after the script name to opf's CLI; main()'s return
    # value becomes the process exit status.
    exit_code = main(sys.argv[1:])
    sys.exit(exit_code)
|
|