File size: 7,866 Bytes
"""
patch_epoch_saving.py — Wrap `opf train` to save a checkpoint after every epoch.

Usage (drop-in replacement for `opf train`):

    python scripts/patch_epoch_saving.py train.jsonl \\
        --output-dir ./ckpt --epochs 5 [... all normal opf train flags ...]

Output layout:

    ./ckpt/          <- best epoch checkpoint (identical to plain `opf train`)
    ./ckpt/epoch_1/  <- checkpoint after epoch 1
    ./ckpt/epoch_2/  <- checkpoint after epoch 2
    ...

Per-epoch checkpoints save the model state at the END of each epoch (not the
best-so-far). Use them for manual selection or checkpoint averaging afterwards.
The directory suffix is the `epoch_index` value the runner passes to
`_train_one_epoch`, so the numbering follows the runner's own convention.

How it works (implementation notes):

opf/_train/runner.py has a single `main()` function with an inline epoch loop.
We can't inject code between loop lines without modifying the source, so we
instead rebind four names in the runner module that it uses at predictable
points:

1. _train_one_epoch(model=..., epoch_index=N, ...)
   — Called once per epoch. We wrap it to (a) capture the model object on
   the first call and (b) save a per-epoch checkpoint after each call.
2. _ensure_output_dir(path, overwrite=...)
   — Called with the output directory. We wrap it to capture that path; the
   epoch_<N>/ subdirectories are created inside it.
3. _resolve_output_dtype(...)
   — Called once, before the runner writes the final checkpoint. We intercept
   it to capture the serialisation dtype. Per-epoch checkpoints are cast to
   that dtype once it is known; any saved before then keep the training dtype.
4. json.dumps (inside the runner module's namespace)
   — The output_config dict (containing span_class_names, num_labels, …) is
   serialised with json.dumps and then written to config.json. We intercept
   that call by replacing `json` in the runner module's globals with a thin
   shim. The shim captures the first dict that contains both
   "span_class_names" and "num_labels" and has no "epoch_metrics" key (that
   key marks the training summary). This fires BEFORE the epoch loop, so all
   per-epoch checkpoints get a proper config.json.

All patches are installed before opf_main() is called. They only rebind names
inside the opf._train.runner module object; the real json module and every
other import in the process are left untouched.
"""
from __future__ import annotations
import json
import sys
from pathlib import Path
from typing import Any
import torch
from safetensors.torch import save_file
# ---------------------------------------------------------------------------
# Shared context β populated by the patches during training
# ---------------------------------------------------------------------------
_CTX: dict[str, Any] = {
"model": None, # torch.nn.Module β captured from first _train_one_epoch call
"output_dir": None, # Path β captured from _ensure_output_dir patch
"output_config": None, # dict β captured from json.dumps shim
"output_dtype": None, # torch.dtype β captured from _resolve_output_dtype patch
"epoch_saves": [], # list[(epoch, Path)] β filled as epochs complete
}
# ---------------------------------------------------------------------------
# Per-epoch checkpoint writer
# ---------------------------------------------------------------------------
def _save_epoch_checkpoint(epoch: int) -> None:
    """Write a checkpoint for *epoch* into ``<output_dir>/epoch_<epoch>/``.

    The directory receives:

    * ``config.json``. This is the ``output_config`` dict captured by the json
      shim. If nothing was captured, it is a copy of ``<output_dir>/config.json``
      when that file exists. If neither is available, a warning is printed and
      the checkpoint is written without a config.
    * ``model.safetensors``. This holds every entry of
      ``model.named_parameters()``, detached, moved to CPU, and made contiguous.
      Each entry is cast to the captured output dtype when one has been
      captured; otherwise it keeps its current (training) dtype.

    Only parameters are saved; buffers (``model.named_buffers()``) are not.
    NOTE(review): confirm this matches what the runner's final checkpoint
    contains.

    Args:
        epoch: The ``epoch_index`` argument of the ``_train_one_epoch`` call,
            used verbatim in the directory name. NOTE(review): if the runner
            counts epochs from 0, the first directory is ``epoch_0``; confirm.

    Prints a warning and saves nothing if the model or the output directory
    has not been captured yet. Each successful save is appended to
    ``_CTX["epoch_saves"]`` as ``(epoch, epoch_dir)``.
    """
    model = _CTX["model"]
    output_dir: Path | None = _CTX["output_dir"]
    output_config: dict | None = _CTX["output_config"]
    output_dtype: torch.dtype | None = _CTX["output_dtype"]
    if model is None or output_dir is None:
        print(
            f"[patch_epoch_saving] warning: skipping epoch {epoch} save "
            "(context not yet populated)",
            flush=True,
        )
        return

    # Tolerate a plain str path; everything below relies on Path's `/` operator.
    output_dir = Path(output_dir)
    epoch_dir = output_dir / f"epoch_{epoch}"
    epoch_dir.mkdir(parents=True, exist_ok=True)

    # config.json: prefer the captured config. Otherwise copy the one the runner
    # may already have written to output_dir. `json` here is this script's own
    # import, not the shim installed into the runner module.
    config_dst = epoch_dir / "config.json"
    if output_config is not None:
        config_dst.write_text(
            json.dumps(output_config, indent=2, sort_keys=True) + "\n",
            encoding="utf-8",
        )
    else:
        config_src = output_dir / "config.json"
        if config_src.exists():
            import shutil  # noqa: PLC0415

            shutil.copy2(config_src, config_dst)
        else:
            # Previously this case was silent and produced a weights-only
            # checkpoint that is hard to load later; make it visible.
            print(
                f"[patch_epoch_saving] warning: no config.json available for "
                f"epoch {epoch}; {epoch_dir} contains weights only",
                flush=True,
            )

    # model.safetensors: parameters cast to the captured output dtype, or kept
    # at their current dtype if _resolve_output_dtype has not run yet.
    tensors: dict[str, torch.Tensor] = {}
    for name, param in model.named_parameters():
        t = param.detach().cpu()
        if output_dtype is not None:
            t = t.to(output_dtype)
        tensors[name] = t.contiguous()
    save_file(tensors, str(epoch_dir / "model.safetensors"))

    _CTX["epoch_saves"].append((epoch, epoch_dir))
    print(
        f"[patch_epoch_saving] epoch {epoch} checkpoint -> {epoch_dir}",
        flush=True,
    )
# ---------------------------------------------------------------------------
# Patch installation
# ---------------------------------------------------------------------------
def _install_patches() -> None:
    """Monkey-patch ``opf._train.runner`` so every epoch also writes a checkpoint.

    Four names in the runner module are rebound, and the values they observe
    are stored in ``_CTX``:

    1. ``_train_one_epoch``: on the first call, records its ``model``
       argument. After every call that returns normally, it passes its
       ``epoch_index`` argument to :func:`_save_epoch_checkpoint`.
    2. ``_ensure_output_dir``: records its first (``path``) argument as a
       :class:`~pathlib.Path`.
    3. ``_resolve_output_dtype``: records the second element of the
       ``(serialized, dtype)`` pair it returns.
    4. ``json``: replaced by a proxy whose ``dumps`` records the first dict
       that has ``span_class_names`` and ``num_labels`` but no
       ``epoch_metrics``. Every other attribute is forwarded to the real
       ``json`` module.

    The wrappers forward positional and keyword arguments unchanged. The
    ``_train_one_epoch`` wrapper finds ``model`` / ``epoch_index`` by name
    whether they were passed positionally or by keyword. A second call is a
    no-op, because re-wrapping would stack the hooks and save every epoch twice.
    """
    import functools  # noqa: PLC0415
    import inspect  # noqa: PLC0415

    import opf._train.runner as runner  # noqa: PLC0415

    if getattr(runner, "_patch_epoch_saving_installed", False):
        return

    # ---- 1. Wrap _train_one_epoch to capture model + fire save hook --------
    _orig_train = runner._train_one_epoch
    try:
        _train_sig = inspect.signature(_orig_train)
    except (TypeError, ValueError):  # not introspectable; fall back to kwargs only
        _train_sig = None

    def _train_arg(name: str, args: tuple, kwargs: dict) -> Any:
        """Return argument *name* of a ``_train_one_epoch`` call, or ``None``."""
        if name in kwargs:
            return kwargs[name]
        if _train_sig is None:
            return None
        try:
            bound = _train_sig.bind_partial(*args, **kwargs)
        except TypeError:  # call doesn't match; the real call will raise itself
            return None
        return bound.arguments.get(name)

    @functools.wraps(_orig_train)
    def _patched_train_one_epoch(*args, **kwargs):
        if _CTX["model"] is None:
            _CTX["model"] = _train_arg("model", args, kwargs)
        result = _orig_train(*args, **kwargs)
        epoch = _train_arg("epoch_index", args, kwargs)
        if epoch is not None:
            _save_epoch_checkpoint(epoch)
        else:
            print(
                "[patch_epoch_saving] warning: could not determine epoch_index "
                "from the _train_one_epoch call; skipping per-epoch save",
                flush=True,
            )
        return result

    runner._train_one_epoch = _patched_train_one_epoch

    # ---- 2. Wrap _ensure_output_dir to capture output_dir ------------------
    _orig_ensure = runner._ensure_output_dir

    @functools.wraps(_orig_ensure)
    def _patched_ensure_output_dir(*args, **kwargs):
        path = args[0] if args else kwargs.get("path")
        if path is not None:
            _CTX["output_dir"] = Path(path)
        return _orig_ensure(*args, **kwargs)

    runner._ensure_output_dir = _patched_ensure_output_dir

    # ---- 3. Wrap _resolve_output_dtype to capture output torch.dtype -------
    _orig_resolve_dtype = runner._resolve_output_dtype

    @functools.wraps(_orig_resolve_dtype)
    def _patched_resolve_output_dtype(*args, **kwargs):
        # Accept positional arguments too; the previous **kwargs-only wrapper
        # raised TypeError on any positional call.
        serialized, dtype = _orig_resolve_dtype(*args, **kwargs)
        _CTX["output_dtype"] = dtype
        return serialized, dtype

    runner._resolve_output_dtype = _patched_resolve_output_dtype

    # ---- 4. Shim json.dumps in the runner module to capture output_config --
    # The output_config dict is assembled inline in main() and passed to
    # json.dumps before being written to config.json. We identify it by the
    # presence of "span_class_names" + "num_labels", and exclude the training
    # summary, which also carries "epoch_metrics".
    _real_json = runner.json  # the json module as imported by runner

    class _JsonShim:
        """Proxy for the json module that intercepts the output_config dump."""

        def __getattr__(self, name: str):
            return getattr(_real_json, name)

        def dumps(self, obj, *args, **kwargs):  # noqa: ANN001
            result = _real_json.dumps(obj, *args, **kwargs)
            if (
                _CTX["output_config"] is None
                and isinstance(obj, dict)
                and "span_class_names" in obj
                and "num_labels" in obj
                and "epoch_metrics" not in obj  # exclude the summary JSON
            ):
                _CTX["output_config"] = dict(obj)
            return result

    runner.json = _JsonShim()

    runner._patch_epoch_saving_installed = True
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def main(argv: list[str] | None = None) -> int:
    """Run ``opf train`` with per-epoch checkpointing and summarise the result.

    Args:
        argv: Command-line arguments forwarded unchanged to the runner's
            ``main``. ``None`` is passed straight through.

    Returns:
        The runner's return value when it is an ``int``, otherwise ``0``.
    """
    _install_patches()
    # Imported only after patching so the runner module is already rewired.
    from opf._train.runner import main as opf_main  # noqa: PLC0415

    exit_code = opf_main(argv)

    written = _CTX["epoch_saves"]
    if not written:
        print("\n[patch_epoch_saving] warning: no per-epoch checkpoints were written")
    else:
        print(f"\n[patch_epoch_saving] {len(written)} per-epoch checkpoint(s) written:")
        for epoch_num, ckpt_path in written:
            print(f" epoch {epoch_num:>3}: {ckpt_path}")

    if isinstance(exit_code, int):
        return exit_code
    return 0
if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))
|