File size: 6,670 Bytes
bc7101b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""Auto-push new pilot artifacts to HF as they appear.

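Designed to run as a long-lived (~8 h) nohup'd daemon on the training box.
The HF token is taken from the BLT_HF_TOKEN env var (popped from the
environment once read, since `nohup &` detaches stdin) or, if that is unset,
read from stdin; either way it is then held only in process memory.

Polls every 2 min for:
  - new local ckpts (default: ckpt-step10000, ckpt-step12000; override with
    `--ckpts_to_watch`): pushes each one, once all its artifacts are present,
    as `ckpts/<name>/` under the repo.
  - once the training PID (`--train_pid`) has exited: `final/` +
    `final/ablation_n200.json` (the auto-eval result): pushes both plus the
    run logs, and regenerates the README with the full ablation table.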

Exits after final/ is pushed, or after `--max_hours` (default 8) regardless.
"""
from __future__ import annotations

import argparse
import json
import os
import shutil
import sys
import time
from pathlib import Path


def push_folder(api, folder, path_in_repo, repo, msg):
    """Upload a local directory into a model repo on the Hub.

    Args:
        api: object exposing ``upload_folder`` (e.g. ``huggingface_hub.HfApi``).
        folder: local directory to upload; passed on as ``str(folder)``.
        path_in_repo: destination prefix inside the repo.
        repo: repo id to upload into.
        msg: commit message for the upload.

    Returns None; the result of ``upload_folder`` is discarded.
    """
    upload_kwargs = dict(
        folder_path=str(folder),
        path_in_repo=path_in_repo,
        repo_id=repo,
        repo_type="model",
        commit_message=msg,
    )
    api.upload_folder(**upload_kwargs)


def push_file(api, fpath, path_in_repo, repo, msg):
    """Upload a single local file to a model repo on the Hub.

    Args:
        api: object exposing ``upload_file`` (e.g. ``huggingface_hub.HfApi``).
        fpath: local file path; passed on as ``str(fpath)``.
        path_in_repo: destination path of the file inside the repo.
        repo: repo id to upload into.
        msg: commit message for the upload.

    Returns None; the result of ``upload_file`` is discarded.
    """
    local_path = str(fpath)
    api.upload_file(
        repo_id=repo,
        repo_type="model",
        path_or_fileobj=local_path,
        path_in_repo=path_in_repo,
        commit_message=msg,
    )


def ckpt_complete(ckpt_dir: Path) -> bool:
    """Return True iff all 3 expected ckpt artifacts exist with nonzero size.

    The artifacts checked, relative to *ckpt_dir*, are
    ``model/adapter_model.safetensors``, ``projector.pt`` and ``head.pt``.

    Each artifact is probed with a single ``stat()`` rather than ``exists()``
    followed by ``stat()``. With the two-step form, a file disappearing
    between the calls raised ``FileNotFoundError`` to the caller instead of
    reporting "not complete yet".
    """
    for rel in ("model/adapter_model.safetensors", "projector.pt", "head.pt"):
        try:
            size = (ckpt_dir / rel).stat().st_size
        except OSError:
            # Missing, vanished mid-check, or otherwise unreadable: not complete.
            return False
        if size == 0:
            return False
    return True


def regen_readme(pilot_dir: Path, code_dir: Path, repo: str) -> str:
    """Build the README text with the local ``hf_upload_pilot.build_readme``.

    Puts ``code_dir.parent`` at the front of ``sys.path`` so that
    ``experiments.blt_reasoner.scripts.hf_upload_pilot`` resolves from the
    local checkout. Any existing copy of that entry is dropped first, so
    repeated calls (e.g. a retried final push) do not stack duplicate
    entries.

    Python caches the imported module in ``sys.modules``, so later calls
    reuse the module loaded by the first successful import rather than
    re-reading it.

    Returns:
        Whatever ``build_readme(pilot_dir, code_dir, repo)`` returns.

    Raises:
        ImportError: if the helper module cannot be imported.
    """
    import_root = str(code_dir.parent)
    # Dedupe, then re-insert at index 0 so this entry keeps top priority.
    sys.path[:] = [entry for entry in sys.path if entry != import_root]
    sys.path.insert(0, import_root)
    from experiments.blt_reasoner.scripts.hf_upload_pilot import build_readme
    return build_readme(pilot_dir, code_dir, repo)


def _train_alive(pid: int) -> bool:
    """Return True if a process with *pid* still exists (signal-0 probe)."""
    try:
        os.kill(pid, 0)
    except ProcessLookupError:
        return False
    except PermissionError:
        # The process exists but we are not allowed to signal it.
        return True
    return True


def _push_readme(api, pilot: Path, code: Path, repo: str) -> None:
    """Regenerate README.md via regen_readme() and push it to *repo*.

    The text is staged in a private temp file created by ``tempfile.mkstemp``
    (unique name, owner-only permissions). It replaces the old fixed
    ``/tmp/blt_final_readme.md``, which could collide with another daemon or
    be pre-created by another user on a shared box. The temp file is written
    as UTF-8 and removed even if the push fails.
    """
    import tempfile  # local import: only needed for this scratch file

    readme = regen_readme(pilot, code, repo)
    fd, tmp_name = tempfile.mkstemp(prefix="blt_final_readme_", suffix=".md")
    tmp = Path(tmp_name)
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as fh:
            fh.write(readme)
        push_file(api, tmp, "README.md", repo,
                  "Regenerate README with final ablation results")
    finally:
        tmp.unlink(missing_ok=True)


def main():
    """Run the auto-push daemon until final/ is pushed or --max_hours elapses.

    CLI flags:
        --repo            target HF model repo id (required).
        --pilot_dir       local pilot output dir that is polled (required).
        --code_dir        local code dir; its parent is put on sys.path to
                          import the README builder (required).
        --train_pid       PID of the training process; final/ is only pushed
                          after this PID no longer exists (required).
        --max_hours       wall-clock limit for the polling loop (default 8.0).
        --log             log file path (default <pilot_dir>/auto_push.log).
        --ckpts_to_watch  comma-separated ckpt dir names under pilot_dir to
                          push once complete.

    Exits with status 2 if no ``hf_``-prefixed token is found in the
    BLT_HF_TOKEN env var or on stdin.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--repo", required=True)
    p.add_argument("--pilot_dir", required=True)
    p.add_argument("--code_dir", required=True)
    p.add_argument("--train_pid", type=int, required=True)
    p.add_argument("--max_hours", type=float, default=8.0)
    p.add_argument("--log", default=None)
    p.add_argument("--ckpts_to_watch", default="ckpt-step10000,ckpt-step12000")
    args = p.parse_args()

    log_path = args.log or os.path.join(args.pilot_dir, "auto_push.log")
    # Line-buffered so each entry is flushed to the file as it is written.
    log_f = open(log_path, "a", buffering=1)

    def log(m):
        line = f"[{time.strftime('%H:%M:%S')}] {m}"
        print(line, flush=True)
        log_f.write(line + "\n")

    try:
        # Token source: env var BLT_HF_TOKEN (used when launched via `nohup &`,
        # which detaches stdin). Falls back to stdin for interactive launches.
        # pop() removes the token from the environment once read.
        token = os.environ.pop("BLT_HF_TOKEN", "").strip()
        if not token:
            try:
                token = sys.stdin.read().strip()
            except Exception:
                token = ""
        if not token.startswith("hf_"):
            log("ERROR: no hf_ token in BLT_HF_TOKEN env or stdin; aborting")
            sys.exit(2)

        from huggingface_hub import HfApi
        api = HfApi(token=token)

        pilot = Path(args.pilot_dir)
        code = Path(args.code_dir)
        watchlist = [c.strip() for c in args.ckpts_to_watch.split(",") if c.strip()]
        pushed = set()

        log(f"daemon start: repo={args.repo} pilot={pilot} train_pid={args.train_pid} watching={watchlist}")

        deadline = time.time() + args.max_hours * 3600
        final_pushed = False

        while time.time() < deadline and not final_pushed:
            # 1. Push any newly-complete ckpts
            for ckpt_name in watchlist:
                if ckpt_name in pushed:
                    continue
                ckpt = pilot / ckpt_name
                if ckpt.exists() and ckpt_complete(ckpt):
                    # Wait a bit in case files are still flushing
                    time.sleep(15)
                    try:
                        log(f"pushing {ckpt_name}")
                        push_folder(api, ckpt, f"ckpts/{ckpt_name}", args.repo,
                                    f"Add {ckpt_name}")
                        pushed.add(ckpt_name)
                        log(f"  ok: {ckpt_name}")
                    except Exception as e:
                        log(f"  ERROR pushing {ckpt_name}: {e!r}; will retry")

            # 2. Once training has exited, push final/ when its ablation result exists
            if not _train_alive(args.train_pid):
                # Auto-eval poller may still be running. Wait until ablation_n200.json appears.
                final_dir = pilot / "final"
                final_abl = final_dir / "ablation_n200.json"
                if final_dir.exists() and final_abl.exists():
                    # Wait a bit for any final flushes
                    time.sleep(30)
                    try:
                        log("pushing final/ (ckpt + ablation_n200.json)")
                        push_folder(api, final_dir, "final", args.repo,
                                    "Add final ckpt + n=200 pre-registered z-ablation")
                        # Also push the auto-eval logs (best-effort, per file)
                        for name in ("auto_eval.log", "run.log", "metrics.jsonl"):
                            f = pilot / name
                            if f.exists():
                                try:
                                    push_file(api, f, f"logs/{name}", args.repo,
                                              f"Refresh logs/{name} at end of pilot")
                                except Exception as e:
                                    log(f"  warn: log push {name}: {e!r}")
                        # Regenerate README with final ablation table (best-effort)
                        try:
                            _push_readme(api, pilot, code, args.repo)
                        except Exception as e:
                            log(f"  warn: README regen: {e!r}")
                        final_pushed = True
                        log("DONE: final pushed; daemon exiting")
                    except Exception as e:
                        log(f"  ERROR pushing final: {e!r}; will retry")
                else:
                    log("train PID gone but final/ablation_n200.json not yet present; waiting")

            if not final_pushed:
                time.sleep(120)

        if not final_pushed:
            log(f"deadline reached or daemon exiting without final push (pushed={pushed})")
    finally:
        # Close the log on every exit path, including sys.exit(2) and crashes.
        log_f.close()


if __name__ == "__main__":
    # Script entry point (e.g. when launched under nohup); main() parses the CLI flags.
    main()