File size: 7,866 Bytes
3dac39e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
"""
patch_epoch_saving.py — Wrap `opf train` to save a checkpoint after every epoch.

Usage (drop-in replacement for `opf train`):
    python scripts/patch_epoch_saving.py train.jsonl \\
        --output-dir ./ckpt --epochs 5 [... all normal opf train flags ...]

Output layout:
    ./ckpt/          <- best epoch checkpoint (identical to plain `opf train`)
    ./ckpt/epoch_1/  <- checkpoint after epoch 1
    ./ckpt/epoch_2/  <- checkpoint after epoch 2
    ...

Per-epoch checkpoints save the model state at the END of each epoch (not the
best-so-far). Use them for manual selection or checkpoint averaging afterwards.

How it works (implementation notes):
    opf/_train/runner.py has a single `main()` function with an inline epoch loop.
    We can't inject code between loop lines without modifying the source, so we
    rebind four names in the runner module that are used at predictable points:

    1. _train_one_epoch(model=..., epoch_index=N, ...)
       → Called once per epoch. We wrap it to (a) capture the model object on
         the first call and (b) save a per-epoch checkpoint after each call.

    2. _ensure_output_dir(path, overwrite=...)
       → We wrap it to capture the output directory; per-epoch checkpoints are
         written to epoch_N/ subdirectories beneath it.

    3. _resolve_output_dtype(...)
       → Called once, just before the epoch loop writes the final checkpoint.
         We intercept it to capture the serialisation dtype. Until it has been
         called, per-epoch checkpoints keep the parameters' current dtype.

    4. json.dumps (inside the runner module's namespace)
       → The output_config dict (containing span_class_names, num_labels, …) is
         serialised with json.dumps and then written to config.json. We intercept
         the specific call by replacing json in the runner module's globals with a
         thin shim that captures the first dict containing both "span_class_names"
         and "num_labels" (and no "epoch_metrics", which marks the summary JSON).
         This fires BEFORE the epoch loop, so all per-epoch checkpoints get a
         proper config.json.

All patches are installed before opf_main() is called and only rebind names
inside the opf._train.runner module — the global json module and other imports
in the process are left untouched.
"""

from __future__ import annotations

import json
import sys
from pathlib import Path
from typing import Any

import torch
from safetensors.torch import save_file


# ---------------------------------------------------------------------------
# Shared context — populated by the patches during training
# ---------------------------------------------------------------------------

_CTX: dict[str, Any] = {
    "model": None,        # torch.nn.Module β€” captured from first _train_one_epoch call
    "output_dir": None,   # Path β€” captured from _ensure_output_dir patch
    "output_config": None,  # dict β€” captured from json.dumps shim
    "output_dtype": None,   # torch.dtype β€” captured from _resolve_output_dtype patch
    "epoch_saves": [],      # list[(epoch, Path)] β€” filled as epochs complete
}


# ---------------------------------------------------------------------------
# Per-epoch checkpoint writer
# ---------------------------------------------------------------------------

def _save_epoch_checkpoint(epoch: int) -> None:
    """Write ``<output_dir>/epoch_<epoch>/`` containing config.json and model.safetensors.

    All inputs come from ``_CTX``, which the runner patches populate during
    training. If the model or output directory has not been captured yet, a
    warning is printed and nothing is written. On success the
    ``(epoch, epoch_dir)`` pair is appended to ``_CTX["epoch_saves"]``.

    Args:
        epoch: The ``epoch_index`` the runner passed to ``_train_one_epoch``,
            used verbatim as the directory suffix. NOTE(review): if that index
            is 0-based the first directory is ``epoch_0`` — confirm in runner.py.
    """
    model = _CTX["model"]
    output_dir: Path | None = _CTX["output_dir"]
    output_config: dict | None = _CTX["output_config"]
    output_dtype: torch.dtype | None = _CTX["output_dtype"]

    if model is None or output_dir is None:
        print(
            f"[patch_epoch_saving] warning: skipping epoch {epoch} save "
            "(context not yet populated)",
            flush=True,
        )
        return

    # Accept a str as well: the "/" joins below only work on Path objects.
    output_dir = Path(output_dir)
    epoch_dir = output_dir / f"epoch_{epoch}"
    epoch_dir.mkdir(parents=True, exist_ok=True)

    # config.json — use the captured config if available, else copy the one
    # already present in output_dir.
    config_path = epoch_dir / "config.json"
    if output_config is not None:
        config_path.write_text(
            json.dumps(output_config, indent=2, sort_keys=True) + "\n",
            encoding="utf-8",
        )
    else:
        src = output_dir / "config.json"
        if src.exists():
            import shutil  # noqa: PLC0415
            shutil.copy2(src, config_path)
        else:
            # Surface the gap instead of silently writing weights without a config.
            print(
                f"[patch_epoch_saving] warning: no config.json available for "
                f"epoch {epoch}; {epoch_dir} will contain weights only",
                flush=True,
            )

    # model.safetensors — parameters only (named_parameters() excludes buffers),
    # cast to the captured output dtype when known, otherwise left in their
    # current dtype.
    # NOTE(review): if the final checkpoint is built from state_dict(), any
    # registered buffers would be missing here — confirm against runner.py.
    tensors: dict[str, torch.Tensor] = {}
    for name, param in model.named_parameters():
        tensor = param.detach().cpu()
        if output_dtype is not None:
            tensor = tensor.to(output_dtype)
        tensors[name] = tensor.contiguous()

    save_file(tensors, str(epoch_dir / "model.safetensors"))

    _CTX["epoch_saves"].append((epoch, epoch_dir))
    print(
        f"[patch_epoch_saving] epoch {epoch} checkpoint -> {epoch_dir}",
        flush=True,
    )


# ---------------------------------------------------------------------------
# Patch installation
# ---------------------------------------------------------------------------

def _install_patches() -> None:
    """Hook the opf runner so per-epoch checkpoints are written during training.

    Rebinds four names inside ``opf._train.runner`` (and nowhere else):

    * ``_train_one_epoch``      — captures the model, then saves a checkpoint
      after each call.
    * ``_ensure_output_dir``    — captures the output directory (as a ``Path``).
    * ``_resolve_output_dtype`` — captures the serialisation ``torch.dtype``.
    * ``json``                  — a shim whose ``dumps`` captures output_config.

    Call arguments are bound against each wrapped function's signature, so the
    captures work whether the runner passes values positionally or by keyword.
    All wrappers forward ``*args, **kwargs`` untouched. Calling this more than
    once is a no-op; stacked wrappers would otherwise save every epoch twice.
    """
    import functools  # noqa: PLC0415
    import inspect  # noqa: PLC0415

    import opf._train.runner as runner  # noqa: PLC0415

    if getattr(runner, "_patch_epoch_saving_installed", False):
        return

    def _make_binder(func):
        """Return ``bind(args, kwargs) -> dict`` mapping *func*'s parameter names to values."""
        try:
            signature = inspect.signature(func)
        except (TypeError, ValueError):  # callables without an inspectable signature
            signature = None

        def bind(args, kwargs):
            named = dict(kwargs)  # keyword arguments are always visible by name
            if signature is not None:
                try:
                    named.update(signature.bind_partial(*args, **kwargs).arguments)
                except TypeError:
                    # Call doesn't match the signature; the real function will
                    # raise on its own, so just fall back to the keywords.
                    pass
            return named

        return bind

    # ---- 1. Wrap _train_one_epoch to capture model + fire save hook --------
    orig_train = runner._train_one_epoch
    bind_train = _make_binder(orig_train)

    @functools.wraps(orig_train)
    def _patched_train_one_epoch(*args, **kwargs):
        call = bind_train(args, kwargs)
        if _CTX["model"] is None:
            _CTX["model"] = call.get("model")
        result = orig_train(*args, **kwargs)
        epoch = call.get("epoch_index")
        if epoch is not None:
            _save_epoch_checkpoint(epoch)
        return result

    runner._train_one_epoch = _patched_train_one_epoch

    # ---- 2. Wrap _ensure_output_dir to capture output_dir ------------------
    orig_ensure = runner._ensure_output_dir
    try:
        dir_param = next(iter(inspect.signature(orig_ensure).parameters), "path")
    except (TypeError, ValueError):
        dir_param = "path"

    @functools.wraps(orig_ensure)
    def _patched_ensure_output_dir(*args, **kwargs):
        # The directory is the first parameter, passed positionally or by name.
        raw = args[0] if args else kwargs.get(dir_param)
        if raw is not None:
            try:
                # Normalise to Path: checkpoint paths are built with "/".
                _CTX["output_dir"] = Path(raw)
            except TypeError:
                _CTX["output_dir"] = raw  # not path-like; store it unchanged
        return orig_ensure(*args, **kwargs)

    runner._ensure_output_dir = _patched_ensure_output_dir

    # ---- 3. Wrap _resolve_output_dtype to capture output torch.dtype -------
    orig_resolve_dtype = runner._resolve_output_dtype

    @functools.wraps(orig_resolve_dtype)
    def _patched_resolve_output_dtype(*args, **kwargs):
        result = orig_resolve_dtype(*args, **kwargs)
        # The result unpacks as a (serialized, dtype) pair; only the dtype is kept.
        _serialized, dtype = result
        _CTX["output_dtype"] = dtype
        return result

    runner._resolve_output_dtype = _patched_resolve_output_dtype

    # ---- 4. Shim json.dumps in the runner module to capture output_config --
    # The output_config dict is assembled inline in main() and passed to
    # json.dumps right before it is written to config.json. We identify it by
    # the presence of "span_class_names" + "num_labels" (unique to that dict).
    real_json = runner.json  # the json module as imported by runner

    class _JsonShim:
        """Proxy for the json module that intercepts the output_config dump."""

        def __getattr__(self, name: str):
            # Everything except dumps is forwarded to the real module.
            return getattr(real_json, name)

        def dumps(self, obj, *args, **kwargs):  # noqa: ANN001
            result = real_json.dumps(obj, *args, **kwargs)
            if (
                _CTX["output_config"] is None
                and isinstance(obj, dict)
                and "span_class_names" in obj
                and "num_labels" in obj
                and "epoch_metrics" not in obj  # exclude the summary JSON
            ):
                _CTX["output_config"] = dict(obj)
            return result

    runner.json = _JsonShim()
    runner._patch_epoch_saving_installed = True


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def main(argv: list[str] | None = None) -> int:
    """Run ``opf train`` with per-epoch checkpointing, then print a summary.

    Args:
        argv: Command-line arguments forwarded unchanged to opf's ``main``.

    Returns:
        opf's return value when it is an ``int``, otherwise ``0``.
    """
    # Reset shared state before every run. The model and output_config are only
    # captured while their slots are None. Without a reset, a second in-process
    # call would snapshot the previous run's model and config, and would list the
    # previous run's saves.
    _CTX.update(
        model=None,
        output_dir=None,
        output_config=None,
        output_dtype=None,
        epoch_saves=[],
    )
    _install_patches()

    from opf._train.runner import main as opf_main  # noqa: PLC0415

    ret = opf_main(argv)

    saves = _CTX["epoch_saves"]
    if saves:
        print(f"\n[patch_epoch_saving] {len(saves)} per-epoch checkpoint(s) written:")
        for epoch, path in saves:
            print(f"  epoch {epoch:>3}: {path}")
    else:
        print("\n[patch_epoch_saving] warning: no per-epoch checkpoints were written")

    return ret if isinstance(ret, int) else 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))