# Source: arcspan/scripts/patch_epoch_saving.py (uploaded by chairulridjal, commit 3dac39e)
"""
patch_epoch_saving.py — Wrap `opf train` to save a checkpoint after every epoch.
Usage (drop-in replacement for `opf train`):
python scripts/patch_epoch_saving.py train.jsonl \\
--output-dir ./ckpt --epochs 5 [... all normal opf train flags ...]
Output layout:
./ckpt/ <- best epoch checkpoint (identical to plain `opf train`)
./ckpt/epoch_1/ <- checkpoint after epoch 1
./ckpt/epoch_2/ <- checkpoint after epoch 2
...
Per-epoch checkpoints save the model state at the END of each epoch (not the
best-so-far). Use them for manual selection or checkpoint averaging afterwards.
How it works (implementation notes):
opf/_train/runner.py has a single `main()` function with an inline epoch loop.
We can't inject code between loop lines without modifying the source, so we
hook points that the runner reaches predictably. `_ensure_output_dir` is also
wrapped, solely to capture the output directory; the other hooks are:
1. _train_one_epoch(model=..., epoch_index=N, ...)
→ Called once per epoch. We wrap it to (a) capture the model object on
the first call and (b) save a per-epoch checkpoint after each call.
2. _resolve_output_dtype(...)
→ Called once, just before the epoch loop writes the final checkpoint.
We intercept it to capture the serialisation dtype.
3. json.dumps (inside the runner module's namespace)
→ The output_config dict (containing span_class_names, num_labels, …) is
serialised with json.dumps and then written to config.json. We intercept
the specific call by replacing json in the runner module's globals with a
thin shim that captures the first dict containing "span_class_names".
This fires BEFORE the epoch loop, so all per-epoch checkpoints get a
proper config.json.
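All patches are installed before opf_main() is called. They replace attributes
on the `opf._train.runner` module only: the global `json` module and other
modules are left untouched, but any other code in the same process that uses
`opf._train.runner` will also see the patched functions.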
"""
from __future__ import annotations
import json
import sys
from pathlib import Path
from typing import Any
import torch
from safetensors.torch import save_file
# ---------------------------------------------------------------------------
# Shared context — populated by the patches during training
# ---------------------------------------------------------------------------
_CTX: dict[str, Any] = {
"model": None, # torch.nn.Module — captured from first _train_one_epoch call
"output_dir": None, # Path — captured from _ensure_output_dir patch
"output_config": None, # dict — captured from json.dumps shim
"output_dtype": None, # torch.dtype — captured from _resolve_output_dtype patch
"epoch_saves": [], # list[(epoch, Path)] — filled as epochs complete
}
# ---------------------------------------------------------------------------
# Per-epoch checkpoint writer
# ---------------------------------------------------------------------------
def _save_epoch_checkpoint(epoch: int) -> None:
    """Write ``<output_dir>/epoch_<epoch>/`` containing config.json and model.safetensors.

    All inputs come from ``_CTX``. If the model or the output directory has
    not been captured yet, a warning is printed and nothing is written.

    Args:
        epoch: Value used verbatim in the directory name ``epoch_<epoch>``.
    """
    model = _CTX["model"]
    output_dir: Path | None = _CTX["output_dir"]
    captured_config: dict | None = _CTX["output_config"]
    dtype: torch.dtype | None = _CTX["output_dtype"]

    if output_dir is None or model is None:
        print(
            f"[patch_epoch_saving] warning: skipping epoch {epoch} save "
            "(context not yet populated)",
            flush=True,
        )
        return

    target = output_dir / f"epoch_{epoch}"
    target.mkdir(parents=True, exist_ok=True)
    config_path = target / "config.json"

    # Prefer the config captured by the json.dumps shim; fall back to copying
    # an existing top-level config.json if the shim has not fired.
    if captured_config is None:
        existing = output_dir / "config.json"
        if existing.exists():
            import shutil  # noqa: PLC0415

            shutil.copy2(existing, config_path)
    else:
        config_path.write_text(
            json.dumps(captured_config, indent=2, sort_keys=True) + "\n",
            encoding="utf-8",
        )

    # Only model.named_parameters() is saved (buffers are not). Each tensor is
    # detached, moved to CPU, cast to the captured output dtype when one is
    # known (otherwise left at its training dtype), and made contiguous.
    def _snapshot(tensor: torch.Tensor) -> torch.Tensor:
        cpu_copy = tensor.detach().cpu()
        if dtype is not None:
            cpu_copy = cpu_copy.to(dtype)
        return cpu_copy.contiguous()

    state = {name: _snapshot(param) for name, param in model.named_parameters()}
    save_file(state, str(target / "model.safetensors"))

    _CTX["epoch_saves"].append((epoch, target))
    print(f"[patch_epoch_saving] epoch {epoch} checkpoint -> {target}", flush=True)
# ---------------------------------------------------------------------------
# Patch installation
# ---------------------------------------------------------------------------
def _install_patches() -> None:
    """Install the per-epoch-saving hooks into ``opf._train.runner``.

    Wraps ``_train_one_epoch``, ``_ensure_output_dir`` and
    ``_resolve_output_dtype``, and replaces the runner's ``json`` reference
    with a proxy. The wrappers record what they see in ``_CTX``. The
    ``_train_one_epoch`` wrapper also calls ``_save_epoch_checkpoint`` after
    each epoch.

    Calling this more than once is a no-op. Without that guard, a second call
    would wrap the already-wrapped functions and every epoch would be saved
    twice.
    """
    import functools  # noqa: PLC0415

    import opf._train.runner as runner  # noqa: PLC0415

    if getattr(runner, "_epoch_saving_patched", False):
        return

    # ---- 1. Wrap _train_one_epoch to capture model + fire save hook --------
    orig_train = runner._train_one_epoch

    @functools.wraps(orig_train)
    def _patched_train_one_epoch(*args, **kwargs):
        # Only keyword arguments are inspected (model=..., epoch_index=...).
        if _CTX["model"] is None:
            _CTX["model"] = kwargs.get("model")
        result = orig_train(*args, **kwargs)
        # NOTE(review): epoch_index is used as-is for the directory name; if
        # the runner's index is 0-based the first directory is epoch_0 rather
        # than the epoch_1 shown in the module docstring — confirm.
        epoch = kwargs.get("epoch_index")
        if epoch is not None:
            _save_epoch_checkpoint(epoch)
        return result

    runner._train_one_epoch = _patched_train_one_epoch

    # ---- 2. Wrap _ensure_output_dir to capture output_dir ------------------
    orig_ensure = runner._ensure_output_dir

    @functools.wraps(orig_ensure)
    def _patched_ensure_output_dir(path, *args, **kwargs):
        # Coerce to Path so _save_epoch_checkpoint can join with "/" even if
        # the runner passes a plain string. All arguments are forwarded
        # unchanged, positional or keyword.
        _CTX["output_dir"] = Path(path)
        return orig_ensure(path, *args, **kwargs)

    runner._ensure_output_dir = _patched_ensure_output_dir

    # ---- 3. Wrap _resolve_output_dtype to capture output torch.dtype -------
    orig_resolve_dtype = runner._resolve_output_dtype

    @functools.wraps(orig_resolve_dtype)
    def _patched_resolve_output_dtype(*args, **kwargs):
        serialized, dtype = orig_resolve_dtype(*args, **kwargs)
        _CTX["output_dtype"] = dtype
        return serialized, dtype

    runner._resolve_output_dtype = _patched_resolve_output_dtype

    # ---- 4. Shim json.dumps in the runner module to capture output_config --
    # The output_config dict is passed to json.dumps before it is written to
    # config.json. It is identified by containing "span_class_names" and
    # "num_labels" but not "epoch_metrics" (which marks the summary JSON).
    real_json = runner.json  # the json module as imported by runner

    class _JsonShim:
        """Proxy for the json module that intercepts the output_config dump."""

        def __getattr__(self, name: str):
            # Everything other than dumps() is forwarded to the real module.
            return getattr(real_json, name)

        def dumps(self, obj, *args, **kwargs):  # noqa: ANN001
            result = real_json.dumps(obj, *args, **kwargs)
            if (
                _CTX["output_config"] is None
                and isinstance(obj, dict)
                and "span_class_names" in obj
                and "num_labels" in obj
                and "epoch_metrics" not in obj  # exclude the summary JSON
            ):
                _CTX["output_config"] = dict(obj)
            return result

    runner.json = _JsonShim()

    # Mark the module only after every hook is in place.
    runner._epoch_saving_patched = True
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def main(argv: list[str] | None = None) -> int:
    """Run ``opf train`` with per-epoch checkpoint saving enabled.

    Args:
        argv: Arguments forwarded unchanged to opf's runner ``main``.

    Returns:
        The runner's return value if it is an ``int``, otherwise ``0``.
    """
    # Start every run from a clean context. The hooks only fill "model" and
    # "output_config" while those slots are still None. Without this reset, a
    # second main() call in the same process would snapshot the previous
    # run's model and config, and would report the previous run's saves.
    _CTX.update(
        model=None,
        output_dir=None,
        output_config=None,
        output_dtype=None,
        epoch_saves=[],
    )
    _install_patches()
    from opf._train.runner import main as opf_main  # noqa: PLC0415

    ret = opf_main(argv)
    saves = _CTX["epoch_saves"]
    if saves:
        print(f"\n[patch_epoch_saving] {len(saves)} per-epoch checkpoint(s) written:")
        for epoch, path in saves:
            print(f"  epoch {epoch:>3}: {path}")
    else:
        print("\n[patch_epoch_saving] warning: no per-epoch checkpoints were written")
    return ret if isinstance(ret, int) else 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main(sys.argv[1:]))