OpenTransformer commited on
Commit
269c08f
·
verified ·
1 Parent(s): b572c64

AGILLM4_training_script_and_bounded_uploads

Browse files
AGILLM-4.md CHANGED
@@ -37,7 +37,7 @@ Production first-run recipe on 4090:
37
  ```bash
38
  AGILLM4_4090_WARMSTART_FROM=/workspace/agillm-4/agillm4_floor_seed_from_v3_v47.pt \
39
  AGILLM4_4090_PRESET=agillm4_floor \
40
- AGILLM4_4090_BLOCK=512 \
41
  AGILLM4_4090_TOKEN_PARAM_RATIO=100 \
42
  tmux new-session -d -s agillm4_floor \
43
  /workspace/agillm-4/launch_agillm4_4090_floor_from_v47.sh
@@ -47,14 +47,18 @@ Important: `--sat_every 1 --nat_every 4` keeps SAT trained every step and NAT ac
47
 
48
  Escalation ladder on 4090:
49
 
50
- 1. `block=512`
51
- 2. `block=640`
52
- 3. `block=768`
53
- 4. `block=1024`
54
- 5. `block=1280+` only after measured VRAM headroom
55
 
56
  If 8-bit optimizer is unavailable, install `bitsandbytes` rather than dropping the long-block target. SAT remains active every step; NAT should stay enabled with a slower cadence or `--nat_max_tokens` cap on 24GB. The code lowers peak memory by backpropagating AR, SAT, and NAT sequentially, not by deleting heads.
57
 
 
 
 
 
 
 
58
  ## Warm Start from AGILLM-3 (function-preserving)
59
 
60
  The next AGILLM-4 run does **not** start from a random init. `build_v4_seed.py`
@@ -96,7 +100,7 @@ Use with `--warmstart_from`:
96
  python /workspace/agillm-4/nB300_agillm4.py train \
97
  --preset agillm4_floor \
98
  --warmstart_from /workspace/agillm-4/agillm4_floor_seed_from_v3_v47.pt \
99
- --batch_size 1 --block 512 --amp --grad_checkpoint --sat_every 1 --nat_every 4 \
100
  --token_param_ratio 100
101
  ```
102
 
 
37
  ```bash
38
  AGILLM4_4090_WARMSTART_FROM=/workspace/agillm-4/agillm4_floor_seed_from_v3_v47.pt \
39
  AGILLM4_4090_PRESET=agillm4_floor \
40
+ AGILLM4_4090_BLOCK=1280 \
41
  AGILLM4_4090_TOKEN_PARAM_RATIO=100 \
42
  tmux new-session -d -s agillm4_floor \
43
  /workspace/agillm-4/launch_agillm4_4090_floor_from_v47.sh
 
47
 
48
  Escalation ladder on 4090:
49
 
50
+ 1. Start `block=1280`.
51
+ 2. If OOM, back off by about 20% at a time (`1024`, `768`, ...), not straight to half.
52
+ 3. If stable below 23GB VRAM after real step timing, raise toward `1536`.
 
 
53
 
54
  If 8-bit optimizer is unavailable, install `bitsandbytes` rather than dropping the long-block target. SAT remains active every step; NAT should stay enabled with a slower cadence or `--nat_max_tokens` cap on 24GB. The code lowers peak memory by backpropagating AR, SAT, and NAT sequentially, not by deleting heads.
55
 
56
+ For the year-scale full run, keep checkpoint retention bounded. The 4090 launcher
57
+ keeps one local full checkpoint and one local delta. The companion uploader
58
+ publishes status/log tails every 30 minutes, uploads the newest delta at most
59
+ daily, uploads full checkpoints at most weekly, and prunes remote AGILLM-4
60
+ training uploads to the configured rolling window.
61
+
62
  ## Warm Start from AGILLM-3 (function-preserving)
63
 
64
  The next AGILLM-4 run does **not** start from a random init. `build_v4_seed.py`
 
100
  python /workspace/agillm-4/nB300_agillm4.py train \
101
  --preset agillm4_floor \
102
  --warmstart_from /workspace/agillm-4/agillm4_floor_seed_from_v3_v47.pt \
103
+ --batch_size 1 --block 1280 --amp --grad_checkpoint --sat_every 1 --nat_every 4 \
104
  --token_param_ratio 100
105
  ```
106
 
README.md CHANGED
@@ -29,10 +29,16 @@ recipes. The current sublinear backend is intentionally experimental: profile it
29
  against SDPA before using it for a real run.
30
 
31
  On RTX 4090-class 24GB cards, `run_agillm4_4090_longblock.sh` now defaults to
32
- `agillm4_floor` instead of the AGILLM-3-sized `large` preset. Override
33
- `AGILLM4_4090_BLOCK` upward only after the first floor run is stable.
34
  For the current v47 seed, launch tmux with
35
  `/workspace/agillm-4/launch_agillm4_4090_floor_from_v47.sh`; it writes
36
  `/workspace/agillm4_floor_train.log`.
37
 
 
 
 
 
 
 
38
  Current harvest status from n1.py is tracked in [N1_HARVEST.md](N1_HARVEST.md).
 
29
  against SDPA before using it for a real run.
30
 
31
  On RTX 4090-class 24GB cards, `run_agillm4_4090_longblock.sh` now defaults to
32
+ `agillm4_floor` instead of the AGILLM-3-sized `large` preset, starts at block
33
+ `1280`, and backs off in smaller 20% steps if VRAM is too tight.
34
  For the current v47 seed, launch tmux with
35
  `/workspace/agillm-4/launch_agillm4_4090_floor_from_v47.sh`; it writes
36
  `/workspace/agillm4_floor_train.log`.
37
 
38
+ Checkpoint upload policy is intentionally bounded for the public HF storage
39
+ quota: status and log tails upload every 30 minutes, the latest multi-GB delta
40
+ uploads at most daily, and full checkpoints upload at most weekly with only two
41
+ current remote files retained. Local full saves default to daily and local
42
+ retention is one full plus one delta, so the 64GB Vast disk does not slowly fill.
43
+
44
  Current harvest status from n1.py is tracked in [N1_HARVEST.md](N1_HARVEST.md).
launch_agillm4_4090_floor_from_v47.sh CHANGED
@@ -9,11 +9,14 @@ echo "LAUNCH_AGILLM4_4090_FLOOR_FROM_V47 $(date -u +%Y-%m-%dT%H:%M:%SZ) host=$(h
9
 
10
  export AGILLM4_4090_WARMSTART_FROM="${AGILLM4_4090_WARMSTART_FROM:-/workspace/agillm-4/agillm4_floor_seed_from_v3_v47.pt}"
11
  export AGILLM4_4090_PRESET="${AGILLM4_4090_PRESET:-agillm4_floor}"
12
- export AGILLM4_4090_BLOCK="${AGILLM4_4090_BLOCK:-512}"
13
  export AGILLM4_4090_TOKEN_PARAM_RATIO="${AGILLM4_4090_TOKEN_PARAM_RATIO:-100}"
14
  export AGILLM4_4090_NAT_EVERY="${AGILLM4_4090_NAT_EVERY:-4}"
15
- export AGILLM4_4090_NAT_MAX_TOKENS="${AGILLM4_4090_NAT_MAX_TOKENS:-512}"
16
- export AGILLM4_4090_SAVE_EVERY_SEC="${AGILLM4_4090_SAVE_EVERY_SEC:-21600}"
 
 
 
17
  export AGILLM4_4090_SAVE_DIR="${AGILLM4_4090_SAVE_DIR:-/workspace/agillm4_4090_ckpts}"
18
 
19
  exec bash /workspace/agillm-4/run_agillm4_4090_longblock.sh
 
9
 
10
  export AGILLM4_4090_WARMSTART_FROM="${AGILLM4_4090_WARMSTART_FROM:-/workspace/agillm-4/agillm4_floor_seed_from_v3_v47.pt}"
11
  export AGILLM4_4090_PRESET="${AGILLM4_4090_PRESET:-agillm4_floor}"
12
+ export AGILLM4_4090_BLOCK="${AGILLM4_4090_BLOCK:-1280}"
13
  export AGILLM4_4090_TOKEN_PARAM_RATIO="${AGILLM4_4090_TOKEN_PARAM_RATIO:-100}"
14
  export AGILLM4_4090_NAT_EVERY="${AGILLM4_4090_NAT_EVERY:-4}"
15
+ export AGILLM4_4090_NAT_MAX_TOKENS="${AGILLM4_4090_NAT_MAX_TOKENS:-768}"
16
+ export AGILLM4_4090_SAVE_EVERY_SEC="${AGILLM4_4090_SAVE_EVERY_SEC:-86400}"
17
+ export AGILLM4_4090_DELTA_EVERY_STEPS="${AGILLM4_4090_DELTA_EVERY_STEPS:-25000}"
18
+ export AGILLM4_4090_DELTA_MAX_KEEP="${AGILLM4_4090_DELTA_MAX_KEEP:-1}"
19
+ export AGILLM4_4090_MAX_CKPTS="${AGILLM4_4090_MAX_CKPTS:-1}"
20
  export AGILLM4_4090_SAVE_DIR="${AGILLM4_4090_SAVE_DIR:-/workspace/agillm4_4090_ckpts}"
21
 
22
  exec bash /workspace/agillm-4/run_agillm4_4090_longblock.sh
nB300_agillm4.py CHANGED
@@ -69,7 +69,11 @@ _STATUS_PROGRESS_RE = re.compile(
69
  r"^\[(?P<percent>\d+(?:\.\d+)?)%\]\s+"
70
  r"(?P<seen>[\d,]+)/(?P<target>[\d,]+)\s+tok\s+\|\s+"
71
  r"(?P<tok_s>[\d.]+)\s+tok/s\s+\|\s+"
72
- r"loss=(?P<loss>-?[\d.]+)\s+B=(?P<batch>\d+)\s+L=(?P<block>\d+)\s*$"
 
 
 
 
73
  )
74
  _STATUS_DELTA_RE = re.compile(r"\[delta\]\s+saved\s+(?P<name>\S+?\.pt)\s+\((?P<sha>[0-9a-f]+)\.\.\.\)")
75
  _STATUS_STEP_RE = re.compile(r"step(?P<step>\d+)")
@@ -99,6 +103,30 @@ def _status_human_duration(seconds: Optional[float]) -> Optional[str]:
99
  return " ".join(parts)
100
 
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  def _status_format_int(value: Optional[int]) -> str:
103
  return "?" if value is None else f"{value:,}"
104
 
@@ -193,6 +221,9 @@ def _status_parse_progress_line(line: str) -> Optional[Dict[str, Any]]:
193
  "loss": loss,
194
  "batch": int(match.group("batch")),
195
  "block": int(match.group("block")),
 
 
 
196
  }
197
 
198
 
@@ -508,12 +539,18 @@ def _format_status_text(status: Dict[str, Any]) -> str:
508
 
509
  progress = status.get("progress")
510
  if progress:
 
 
 
 
511
  lines.append(
512
  "Progress: "
513
  f"{progress['percent']:.1f}% | "
514
  f"{_status_format_int(progress['seen_tokens'])}/{_status_format_int(progress['target_tokens'])} tok | "
515
  f"{progress['tok_per_sec']} tok/s | loss {progress['loss']:.3f} | "
516
  f"B={progress['batch']} L={progress['block']}"
 
 
517
  )
518
  else:
519
  lines.append("Progress: unavailable")
@@ -590,23 +627,45 @@ from anchor_memory import AnchorMemoryConfig, AnchorMemoryLayer
590
 
591
  # SafeProgress - Claude-safe progress (discrete lines, not single growing line)
592
  class SafeProgress:
593
- def __init__(self, total, initial=0, unit="tok", print_every=500):
594
  self.total, self.n, self.unit = total, initial, unit
595
  self.initial = initial
596
  self.last_print, self.postfix = initial, {}
 
 
 
 
597
  self.start_time = __import__('time').time()
 
598
  def update(self, n=1):
599
  self.n += n
600
- if self.n - self.last_print >= 1000000: # print every ~1M tokens
601
- self._print(); self.last_print = self.n
 
 
 
 
 
 
 
 
 
602
  def set_postfix(self, **kwargs): self.postfix = kwargs
603
- def _print(self):
604
- elapsed = __import__('time').time() - self.start_time
 
605
  rate = (self.n - self.initial) / elapsed if elapsed > 0 else 0
606
  pct = 100 * self.n / self.total if self.total > 0 else 0
607
  pf = ' '.join(f"{k}={v}" for k,v in self.postfix.items())
608
- print(f"[{pct:.1f}%] {self.n:,}/{self.total:,} {self.unit} | {rate:.0f} tok/s | {pf}")
609
- def close(self): self._print(); print("Done.")
 
 
 
 
 
 
 
610
 
611
  import torch.nn as nn
612
  import torch.nn.functional as F
@@ -2118,7 +2177,10 @@ def _train_phase(
2118
  BATCH -= 1
2119
  time.sleep(2)
2120
  else:
2121
- new_block = max(128, BLOCK // 2)
 
 
 
2122
  print(f"\n[{phase_name} OOM] Reducing Block: {BLOCK} -> {new_block}")
2123
  BLOCK = new_block
2124
  time.sleep(2)
@@ -2139,8 +2201,8 @@ def _train_phase(
2139
  oom_retries = 0
2140
  toks_processed = BLOCK * BATCH
2141
  seen_tok += toks_processed
2142
- pbar.update(toks_processed)
2143
  pbar.set_postfix(loss=f"{loss_value:.3f}", B=BATCH, L=BLOCK)
 
2144
  if args.save_every_sec > 0:
2145
  now_mono = time.monotonic()
2146
  if now_mono - last_save_mono >= args.save_every_sec:
 
69
  r"^\[(?P<percent>\d+(?:\.\d+)?)%\]\s+"
70
  r"(?P<seen>[\d,]+)/(?P<target>[\d,]+)\s+tok\s+\|\s+"
71
  r"(?P<tok_s>[\d.]+)\s+tok/s\s+\|\s+"
72
+ r"loss=(?P<loss>-?[\d.]+)\s+B=(?P<batch>\d+)\s+L=(?P<block>\d+)"
73
+ r"(?:\s+step=(?P<step>\d+))?"
74
+ r"(?:\s+eta=(?P<eta>\S+))?"
75
+ r"(?:\s+elapsed=(?P<elapsed>\S+))?"
76
+ r"\s*$"
77
  )
78
  _STATUS_DELTA_RE = re.compile(r"\[delta\]\s+saved\s+(?P<name>\S+?\.pt)\s+\((?P<sha>[0-9a-f]+)\.\.\.\)")
79
  _STATUS_STEP_RE = re.compile(r"step(?P<step>\d+)")
 
103
  return " ".join(parts)
104
 
105
 
106
+ def _status_compact_duration(seconds: Optional[float]) -> str:
107
+ if seconds is None:
108
+ return "unknown"
109
+ try:
110
+ if not math.isfinite(float(seconds)):
111
+ return "unknown"
112
+ except Exception:
113
+ return "unknown"
114
+ total = max(0, int(seconds))
115
+ years, rem = divmod(total, 365 * 86400)
116
+ days, rem = divmod(rem, 86400)
117
+ hours, rem = divmod(rem, 3600)
118
+ minutes, secs = divmod(rem, 60)
119
+ if years:
120
+ return f"{years}y{days}d{hours}h"
121
+ if days:
122
+ return f"{days}d{hours}h{minutes}m"
123
+ if hours:
124
+ return f"{hours}h{minutes}m{secs}s"
125
+ if minutes:
126
+ return f"{minutes}m{secs}s"
127
+ return f"{secs}s"
128
+
129
+
130
  def _status_format_int(value: Optional[int]) -> str:
131
  return "?" if value is None else f"{value:,}"
132
 
 
221
  "loss": loss,
222
  "batch": int(match.group("batch")),
223
  "block": int(match.group("block")),
224
+ "step": int(match.group("step")) if match.group("step") else None,
225
+ "eta": match.group("eta"),
226
+ "elapsed": match.group("elapsed"),
227
  }
228
 
229
 
 
539
 
540
  progress = status.get("progress")
541
  if progress:
542
+ eta = progress.get("eta")
543
+ if not eta and progress.get("tok_per_sec"):
544
+ remaining = max(0, progress["target_tokens"] - progress["seen_tokens"])
545
+ eta = _status_compact_duration(remaining / float(progress["tok_per_sec"]))
546
  lines.append(
547
  "Progress: "
548
  f"{progress['percent']:.1f}% | "
549
  f"{_status_format_int(progress['seen_tokens'])}/{_status_format_int(progress['target_tokens'])} tok | "
550
  f"{progress['tok_per_sec']} tok/s | loss {progress['loss']:.3f} | "
551
  f"B={progress['batch']} L={progress['block']}"
552
+ + (f" | step {progress['step']}" if progress.get("step") else "")
553
+ + (f" | ETA {eta}" if eta else "")
554
  )
555
  else:
556
  lines.append("Progress: unavailable")
 
627
 
628
  # SafeProgress - Claude-safe progress (discrete lines, not single growing line)
629
  class SafeProgress:
630
+ def __init__(self, total, initial=0, unit="tok", print_every=100, print_every_sec=60):
631
  self.total, self.n, self.unit = total, initial, unit
632
  self.initial = initial
633
  self.last_print, self.postfix = initial, {}
634
+ self.print_every = max(1, int(print_every))
635
+ self.print_every_sec = max(1, int(print_every_sec))
636
+ self.step = 0
637
+ self.last_print_step = 0
638
  self.start_time = __import__('time').time()
639
+ self.last_print_time = self.start_time
640
  def update(self, n=1):
641
  self.n += n
642
+ self.step += 1
643
+ now = __import__('time').time()
644
+ if (
645
+ self.step == 1
646
+ or (self.step - self.last_print_step) >= self.print_every
647
+ or (now - self.last_print_time) >= self.print_every_sec
648
+ ):
649
+ self._print(now)
650
+ self.last_print = self.n
651
+ self.last_print_step = self.step
652
+ self.last_print_time = now
653
  def set_postfix(self, **kwargs): self.postfix = kwargs
654
+ def _print(self, now=None):
655
+ now = now or __import__('time').time()
656
+ elapsed = now - self.start_time
657
  rate = (self.n - self.initial) / elapsed if elapsed > 0 else 0
658
  pct = 100 * self.n / self.total if self.total > 0 else 0
659
  pf = ' '.join(f"{k}={v}" for k,v in self.postfix.items())
660
+ remaining = max(0, self.total - self.n)
661
+ eta = _status_compact_duration(remaining / rate) if rate > 0 else "unknown"
662
+ elapsed_s = _status_compact_duration(elapsed)
663
+ print(
664
+ f"[{pct:.4f}%] {self.n:,}/{self.total:,} {self.unit} | "
665
+ f"{rate:.2f} tok/s | {pf} step={self.step} eta={eta} elapsed={elapsed_s}",
666
+ flush=True,
667
+ )
668
+ def close(self): self._print(); print("Done.", flush=True)
669
 
670
  import torch.nn as nn
671
  import torch.nn.functional as F
 
2177
  BATCH -= 1
2178
  time.sleep(2)
2179
  else:
2180
+ new_block = max(128, int(BLOCK * 0.8))
2181
+ new_block = max(128, (new_block // 128) * 128)
2182
+ if new_block >= BLOCK:
2183
+ new_block = max(128, BLOCK - 128)
2184
  print(f"\n[{phase_name} OOM] Reducing Block: {BLOCK} -> {new_block}")
2185
  BLOCK = new_block
2186
  time.sleep(2)
 
2201
  oom_retries = 0
2202
  toks_processed = BLOCK * BATCH
2203
  seen_tok += toks_processed
 
2204
  pbar.set_postfix(loss=f"{loss_value:.3f}", B=BATCH, L=BLOCK)
2205
+ pbar.update(toks_processed)
2206
  if args.save_every_sec > 0:
2207
  now_mono = time.monotonic()
2208
  if now_mono - last_save_mono >= args.save_every_sec:
run_agillm4_4090_longblock.sh CHANGED
@@ -13,7 +13,7 @@ if [ -f /root/.cache/huggingface/token ]; then
13
  fi
14
 
15
  PRESET="${AGILLM4_4090_PRESET:-agillm4_floor}"
16
- BLOCK="${AGILLM4_4090_BLOCK:-512}"
17
  TOKEN_PARAM_RATIO="${AGILLM4_4090_TOKEN_PARAM_RATIO:-100}"
18
  SAVE_DIR="${AGILLM4_4090_SAVE_DIR:-/workspace/agillm4_4090_ckpts}"
19
 
@@ -41,10 +41,10 @@ exec python -u /workspace/agillm-4/nB300_agillm4.py train \
41
  --nat_every "${AGILLM4_4090_NAT_EVERY:-4}" \
42
  --nat_loss_weight "${AGILLM4_4090_NAT_LOSS_WEIGHT:-1.0}" \
43
  --nat_expand "${AGILLM4_4090_NAT_EXPAND:-2}" \
44
- --nat_max_tokens "${AGILLM4_4090_NAT_MAX_TOKENS:-512}" \
45
  --token_param_ratio "$TOKEN_PARAM_RATIO" \
46
  --save_dir "$SAVE_DIR" \
47
- --save_every_sec "${AGILLM4_4090_SAVE_EVERY_SEC:-21600}" \
48
  --delta_every_steps "${AGILLM4_4090_DELTA_EVERY_STEPS:-25000}" \
49
- --delta_max_keep "${AGILLM4_4090_DELTA_MAX_KEEP:-8}" \
50
- --max_ckpts "${AGILLM4_4090_MAX_CKPTS:-3}"
 
13
  fi
14
 
15
  PRESET="${AGILLM4_4090_PRESET:-agillm4_floor}"
16
+ BLOCK="${AGILLM4_4090_BLOCK:-1280}"
17
  TOKEN_PARAM_RATIO="${AGILLM4_4090_TOKEN_PARAM_RATIO:-100}"
18
  SAVE_DIR="${AGILLM4_4090_SAVE_DIR:-/workspace/agillm4_4090_ckpts}"
19
 
 
41
  --nat_every "${AGILLM4_4090_NAT_EVERY:-4}" \
42
  --nat_loss_weight "${AGILLM4_4090_NAT_LOSS_WEIGHT:-1.0}" \
43
  --nat_expand "${AGILLM4_4090_NAT_EXPAND:-2}" \
44
+ --nat_max_tokens "${AGILLM4_4090_NAT_MAX_TOKENS:-768}" \
45
  --token_param_ratio "$TOKEN_PARAM_RATIO" \
46
  --save_dir "$SAVE_DIR" \
47
+ --save_every_sec "${AGILLM4_4090_SAVE_EVERY_SEC:-86400}" \
48
  --delta_every_steps "${AGILLM4_4090_DELTA_EVERY_STEPS:-25000}" \
49
+ --delta_max_keep "${AGILLM4_4090_DELTA_MAX_KEEP:-1}" \
50
+ --max_ckpts "${AGILLM4_4090_MAX_CKPTS:-1}"
upload_agillm4_checkpoints.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import hashlib
6
+ import json
7
+ import os
8
+ import shutil
9
+ import subprocess
10
+ import sys
11
+ import time
12
+ from datetime import datetime, timezone
13
+ from pathlib import Path
14
+ from typing import Any
15
+
16
+ from huggingface_hub import HfApi
17
+
18
+
19
+ def iso_now() -> str:
20
+ return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
21
+
22
+
23
+ def stamp_now() -> str:
24
+ return datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
25
+
26
+
27
+ def load_json(path: Path, default: Any) -> Any:
28
+ try:
29
+ return json.loads(path.read_text(encoding="utf-8"))
30
+ except Exception:
31
+ return default
32
+
33
+
34
+ def save_json(path: Path, data: Any) -> None:
35
+ path.parent.mkdir(parents=True, exist_ok=True)
36
+ path.write_text(json.dumps(data, indent=2, sort_keys=True), encoding="utf-8")
37
+
38
+
39
+ def sha256_file(path: Path) -> str:
40
+ h = hashlib.sha256()
41
+ with path.open("rb") as handle:
42
+ for chunk in iter(lambda: handle.read(1 << 22), b""):
43
+ h.update(chunk)
44
+ return h.hexdigest()
45
+
46
+
47
+ def upload_file(api: HfApi, repo_id: str, local_path: Path, remote_path: str, message: str) -> None:
48
+ api.upload_file(
49
+ repo_id=repo_id,
50
+ path_or_fileobj=str(local_path),
51
+ path_in_repo=remote_path,
52
+ commit_message=message,
53
+ )
54
+
55
+
56
+ def delete_remote_not_kept(api: HfApi, repo_id: str, remote_dir: str, keep_basenames: set[str]) -> list[str]:
57
+ deleted: list[str] = []
58
+ try:
59
+ files = api.list_repo_files(repo_id=repo_id)
60
+ except Exception as exc:
61
+ print(f"[upload] WARN list_repo_files failed: {exc}", flush=True)
62
+ return deleted
63
+ prefix = remote_dir.rstrip("/") + "/"
64
+ victims = []
65
+ for file_path in files:
66
+ if not file_path.startswith(prefix):
67
+ continue
68
+ name = Path(file_path).name
69
+ base = name[:-7] if name.endswith(".sha256") else name
70
+ if base not in keep_basenames:
71
+ victims.append(file_path)
72
+ if victims:
73
+ try:
74
+ api.delete_files(repo_id=repo_id, paths=victims, commit_message=f"Prune AGILLM4 uploads under {remote_dir}")
75
+ deleted.extend(victims)
76
+ except Exception as exc:
77
+ print(f"[upload] WARN delete_files failed for {len(victims)} files: {exc}", flush=True)
78
+ return deleted
79
+
80
+
81
+ def latest_file(glob_root: Path, pattern: str) -> Path | None:
82
+ files = [p for p in glob_root.glob(pattern) if p.is_file()]
83
+ return max(files, key=lambda p: p.stat().st_mtime) if files else None
84
+
85
+
86
+ def status_json(script: Path, log: Path, save_dir: Path) -> dict[str, Any]:
87
+ result = subprocess.run(
88
+ [sys.executable, "-u", str(script), "status", "--json", "--log", str(log), "--save_dir", str(save_dir)],
89
+ capture_output=True,
90
+ text=True,
91
+ timeout=60,
92
+ check=False,
93
+ )
94
+ if result.returncode != 0:
95
+ return {"checked_at": iso_now(), "error": result.stderr.strip() or result.stdout.strip()}
96
+ try:
97
+ return json.loads(result.stdout)
98
+ except Exception:
99
+ return {"checked_at": iso_now(), "error": "failed to parse status json", "raw": result.stdout[-4000:]}
100
+
101
+
102
+ def write_tail(src: Path, dst: Path, lines: int) -> None:
103
+ dst.parent.mkdir(parents=True, exist_ok=True)
104
+ if not src.exists():
105
+ dst.write_text("", encoding="utf-8")
106
+ return
107
+ result = subprocess.run(["tail", "-n", str(lines), str(src)], capture_output=True, text=True, check=False)
108
+ dst.write_text(result.stdout, encoding="utf-8", errors="replace")
109
+
110
+
111
+ def maybe_upload_large(
112
+ api: HfApi,
113
+ repo_id: str,
114
+ state: dict[str, Any],
115
+ kind: str,
116
+ path: Path | None,
117
+ remote_dir: str,
118
+ interval_sec: int,
119
+ keep: int,
120
+ ) -> bool:
121
+ if path is None or not path.exists():
122
+ print(f"[upload] no {kind} checkpoint yet", flush=True)
123
+ return False
124
+ now = time.time()
125
+ last_t = float(state.get(f"last_{kind}_upload_time", 0) or 0)
126
+ identity = f"{path.name}:{path.stat().st_size}:{int(path.stat().st_mtime)}"
127
+ if state.get(f"last_{kind}_identity") == identity:
128
+ print(f"[upload] {kind} unchanged: {path.name}", flush=True)
129
+ return False
130
+ if last_t and now - last_t < interval_sec:
131
+ remaining = int(interval_sec - (now - last_t))
132
+ print(f"[upload] {kind} interval not due for {remaining}s: {path.name}", flush=True)
133
+ return False
134
+
135
+ digest = sha256_file(path)
136
+ sha_path = path.with_suffix(path.suffix + ".upload.sha256")
137
+ sha_path.write_text(f"{digest} {path.name}\n", encoding="utf-8")
138
+ remote_name = f"{stamp_now()}_{path.name}"
139
+ remote_path = f"{remote_dir.rstrip('/')}/{remote_name}"
140
+ print(f"[upload] uploading {kind}: {path} -> {repo_id}/{remote_path}", flush=True)
141
+ upload_file(api, repo_id, path, remote_path, f"Upload AGILLM4 {kind} checkpoint {path.name}")
142
+ upload_file(api, repo_id, sha_path, remote_path + ".sha256", f"Upload AGILLM4 {kind} checksum {path.name}")
143
+
144
+ history = list(state.get(f"{kind}_uploads", []))
145
+ history.append({"name": remote_name, "remote_path": remote_path, "sha256": digest, "uploaded_at": iso_now(), "size": path.stat().st_size})
146
+ history = history[-max(1, keep):]
147
+ state[f"{kind}_uploads"] = history
148
+ state[f"last_{kind}_upload_time"] = now
149
+ state[f"last_{kind}_identity"] = identity
150
+ keep_names = {item["name"] for item in history}
151
+ deleted = delete_remote_not_kept(api, repo_id, remote_dir, keep_names)
152
+ if deleted:
153
+ print(f"[upload] pruned {len(deleted)} remote {kind} files", flush=True)
154
+ return True
155
+
156
+
157
+ def main() -> int:
158
+ parser = argparse.ArgumentParser(description="Bounded AGILLM4 checkpoint uploader")
159
+ parser.add_argument("--repo", default=os.environ.get("AGILLM4_UPLOAD_REPO", "OpenTransformer/AGILLM-4"))
160
+ parser.add_argument("--prefix", default=os.environ.get("AGILLM4_UPLOAD_PREFIX", "training/agillm4_floor_v47"))
161
+ parser.add_argument("--save-dir", type=Path, default=Path(os.environ.get("AGILLM4_UPLOAD_SAVE_DIR", "/workspace/agillm4_4090_ckpts")))
162
+ parser.add_argument("--log", type=Path, default=Path(os.environ.get("AGILLM4_UPLOAD_LOG", "/workspace/agillm4_floor_train.log")))
163
+ parser.add_argument("--script", type=Path, default=Path(os.environ.get("AGILLM4_UPLOAD_SCRIPT", "/workspace/agillm-4/nB300_agillm4.py")))
164
+ parser.add_argument("--state", type=Path, default=Path(os.environ.get("AGILLM4_UPLOAD_STATE", "/workspace/agillm4_upload_state.json")))
165
+ parser.add_argument("--stage", type=Path, default=Path(os.environ.get("AGILLM4_UPLOAD_STAGE", "/workspace/agillm4_upload_stage")))
166
+ parser.add_argument("--full-interval-sec", type=int, default=int(os.environ.get("AGILLM4_UPLOAD_FULL_INTERVAL_SEC", str(7 * 24 * 3600))))
167
+ parser.add_argument("--delta-interval-sec", type=int, default=int(os.environ.get("AGILLM4_UPLOAD_DELTA_INTERVAL_SEC", str(24 * 3600))))
168
+ parser.add_argument("--keep-full", type=int, default=int(os.environ.get("AGILLM4_UPLOAD_KEEP_FULL", "2")))
169
+ parser.add_argument("--keep-delta", type=int, default=int(os.environ.get("AGILLM4_UPLOAD_KEEP_DELTA", "2")))
170
+ parser.add_argument("--tail-lines", type=int, default=int(os.environ.get("AGILLM4_UPLOAD_TAIL_LINES", "5000")))
171
+ args = parser.parse_args()
172
+
173
+ api = HfApi()
174
+ prefix = args.prefix.strip("/")
175
+ args.stage.mkdir(parents=True, exist_ok=True)
176
+ state = load_json(args.state, {})
177
+
178
+ status = status_json(args.script, args.log, args.save_dir)
179
+ status["upload_policy"] = {
180
+ "full_interval_sec": args.full_interval_sec,
181
+ "delta_interval_sec": args.delta_interval_sec,
182
+ "keep_full_current_files": args.keep_full,
183
+ "keep_delta_current_files": args.keep_delta,
184
+ "note": "Small status/log tail uploads are frequent; multi-GB deltas/full checkpoints are rate-limited for HF public storage.",
185
+ }
186
+ status_path = args.stage / "status.json"
187
+ save_json(status_path, status)
188
+ upload_file(api, args.repo, status_path, f"{prefix}/status/status.json", "Update AGILLM4 training status")
189
+
190
+ tail_path = args.stage / "train_tail.log"
191
+ write_tail(args.log, tail_path, args.tail_lines)
192
+ upload_file(api, args.repo, tail_path, f"{prefix}/logs/train_tail.log", "Update AGILLM4 training log tail")
193
+
194
+ latest_json = args.save_dir / "latest.json"
195
+ if latest_json.exists():
196
+ shutil.copy2(latest_json, args.stage / "latest.json")
197
+ upload_file(api, args.repo, args.stage / "latest.json", f"{prefix}/status/latest.json", "Update AGILLM4 latest checkpoint metadata")
198
+
199
+ newest_delta = latest_file(args.save_dir, "*_delta_step*.pt")
200
+ newest_full = latest_file(args.save_dir, "*_step*.pt")
201
+ maybe_upload_large(api, args.repo, state, "delta", newest_delta, f"{prefix}/checkpoints/deltas", args.delta_interval_sec, args.keep_delta)
202
+ maybe_upload_large(api, args.repo, state, "full", newest_full, f"{prefix}/checkpoints/full", args.full_interval_sec, args.keep_full)
203
+
204
+ state["last_status_upload_at"] = iso_now()
205
+ save_json(args.state, state)
206
+ manifest_path = args.stage / "upload_state.json"
207
+ save_json(manifest_path, state)
208
+ upload_file(api, args.repo, manifest_path, f"{prefix}/status/upload_state.json", "Update AGILLM4 upload state")
209
+ print(f"[upload] done {iso_now()} repo={args.repo} prefix={prefix}", flush=True)
210
+ return 0
211
+
212
+
213
+ if __name__ == "__main__":
214
+ raise SystemExit(main())
upload_agillm4_checkpoints_loop.sh ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -Eeuo pipefail
3
+
4
+ cd /workspace/agillm-4
5
+
6
+ LOG="${AGILLM4_UPLOAD_LOOP_LOG:-/workspace/agillm4_upload_loop.log}"
7
+ INTERVAL="${AGILLM4_UPLOAD_LOOP_INTERVAL_SEC:-1800}"
8
+
9
+ if [ -f /root/.cache/huggingface/token ]; then
10
+ export HF_TOKEN="$(tr -d '\r\n' < /root/.cache/huggingface/token)"
11
+ export HUGGING_FACE_HUB_TOKEN="$HF_TOKEN"
12
+ fi
13
+
14
+ mkdir -p "$(dirname "$LOG")"
15
+ exec >> "$LOG" 2>&1
16
+
17
+ echo "START_AGILLM4_UPLOAD_LOOP $(date -u +%Y-%m-%dT%H:%M:%SZ) interval=${INTERVAL}s"
18
+
19
+ while true; do
20
+ echo "UPLOAD_TICK $(date -u +%Y-%m-%dT%H:%M:%SZ)"
21
+ python -u /workspace/agillm-4/upload_agillm4_checkpoints.py || true
22
+ sleep "$INTERVAL"
23
+ done