Commit ·
87b2fa6
1
Parent(s): f828a0b
Add post-training evals, /dev/shm checkpoints, async HF push, and _orig_mod fix
Browse filestrain_all.py:
- Fix torch.compile _orig_mod prefix bug: unwrap model before saving
- Add --run-evals for automatic probes, diagnostics, and Lichess eval
- Add --publish-results to push eval results to HuggingFace
- Add --shm-checkpoints for RAM-backed volatile storage (requires --hf-repo)
- Push checkpoints to HF in background threads (non-blocking)
- Track best val checkpoint for eval and /dev/shm cleanup
- Clean up old /dev/shm checkpoints after successful HF push
eval_probes.py: handle directory-based safetensors checkpoints
monitor_training.sh: new script for remote training monitoring
CLAUDE.md: document --shm-checkpoints, --run-evals, async HF push
- CLAUDE.md +4 -2
- scripts/eval_probes.py +5 -1
- scripts/monitor_training.sh +40 -0
- scripts/train_all.py +180 -7
- uv.lock +121 -28
CLAUDE.md
CHANGED
|
@@ -86,7 +86,7 @@ All adapters freeze the backbone and initialize to identity (zero-init or gamma=
|
|
| 86 |
## Scripts (`scripts/`)
|
| 87 |
|
| 88 |
- `train.py` -- Pretrain from scratch (`--variant small|base|large|toy`)
|
| 89 |
-
- `train_all.py` -- Train small/base/large simultaneously on shared data batches
|
| 90 |
- `train_bottleneck.py`, `train_film.py`, `train_lora.py`, `train_sparse.py`, `train_hybrid.py` -- Adapter behavioral cloning on Lichess PGN
|
| 91 |
- `train_tiny.py` -- Standalone tiny transformer baseline (no frozen backbone)
|
| 92 |
- `eval_accuracy.py` -- MAIA-compatible evaluation (per-phase, per-ply accuracy)
|
|
@@ -151,7 +151,9 @@ All training scripts require one of:
|
|
| 151 |
- `--hf-repo REPO_ID` — push checkpoints to a HuggingFace branch as they're written (durable)
|
| 152 |
- `--local-checkpoints` — save locally only (for development without an HF account)
|
| 153 |
|
| 154 |
-
HF mode creates a `run/{run_id}` branch. Squash-merge into main when satisfied.
|
|
|
|
|
|
|
| 155 |
|
| 156 |
### Data Integrity
|
| 157 |
|
|
|
|
| 86 |
## Scripts (`scripts/`)
|
| 87 |
|
| 88 |
- `train.py` -- Pretrain from scratch (`--variant small|base|large|toy`)
|
| 89 |
+
- `train_all.py` -- Train small/base/large simultaneously on shared data batches. Supports `--run-evals` for automatic post-training probes, diagnostics, and Lichess eval, and `--publish-results` to push eval results to HF.
|
| 90 |
- `train_bottleneck.py`, `train_film.py`, `train_lora.py`, `train_sparse.py`, `train_hybrid.py` -- Adapter behavioral cloning on Lichess PGN
|
| 91 |
- `train_tiny.py` -- Standalone tiny transformer baseline (no frozen backbone)
|
| 92 |
- `eval_accuracy.py` -- MAIA-compatible evaluation (per-phase, per-ply accuracy)
|
|
|
|
| 151 |
- `--hf-repo REPO_ID` — push checkpoints to a HuggingFace branch as they're written (durable)
|
| 152 |
- `--local-checkpoints` — save locally only (for development without an HF account)
|
| 153 |
|
| 154 |
+
HF mode creates a `run/{run_id}` branch. HF pushes happen in background threads (one per model slot) so training is not blocked by uploads. Squash-merge into main when satisfied.
|
| 155 |
+
|
| 156 |
+
Optional: `--shm-checkpoints` writes checkpoints to `/dev/shm` (RAM-backed filesystem, instant writes). Requires `--hf-repo` since `/dev/shm` is volatile. Old checkpoints are cleaned up after successful HF push, keeping only the latest and the best (by val loss) for post-training evals.
|
| 157 |
|
| 158 |
### Data Integrity
|
| 159 |
|
scripts/eval_probes.py
CHANGED
|
@@ -61,7 +61,11 @@ def main():
|
|
| 61 |
run_dir = config_path.parent
|
| 62 |
if args.run and run_dir.name != args.run:
|
| 63 |
continue
|
| 64 |
-
checkpoints
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
if not checkpoints:
|
| 66 |
continue
|
| 67 |
latest = checkpoints[-1]
|
|
|
|
| 61 |
run_dir = config_path.parent
|
| 62 |
if args.run and run_dir.name != args.run:
|
| 63 |
continue
|
| 64 |
+
# Find checkpoints: directory-based (safetensors) or legacy .pt
|
| 65 |
+
checkpoints = sorted(
|
| 66 |
+
[d for d in run_dir.glob("checkpoints/step_*") if d.is_dir()]
|
| 67 |
+
or list(run_dir.glob("checkpoints/step_*.pt"))
|
| 68 |
+
)
|
| 69 |
if not checkpoints:
|
| 70 |
continue
|
| 71 |
latest = checkpoints[-1]
|
scripts/monitor_training.sh
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Monitor multi-model training: check pod log + HuggingFace checkpoints.
|
| 3 |
+
# Usage: bash scripts/monitor_training.sh <host> <port>
|
| 4 |
+
set -euo pipefail
|
| 5 |
+
|
| 6 |
+
HOST="${1:-50.145.48.110}"
|
| 7 |
+
PORT="${2:-13321}"
|
| 8 |
+
SSH="ssh -o StrictHostKeyChecking=accept-new -o ConnectTimeout=10 -p $PORT root@$HOST"
|
| 9 |
+
|
| 10 |
+
echo "=== Training Log ==="
|
| 11 |
+
$SSH "tail -15 /workspace/logs/train_all.log" 2>/dev/null || echo " (SSH failed)"
|
| 12 |
+
|
| 13 |
+
echo ""
|
| 14 |
+
echo "=== Process Status ==="
|
| 15 |
+
$SSH "pgrep -f train_all > /dev/null && echo RUNNING || echo STOPPED" 2>/dev/null || echo " (SSH failed)"
|
| 16 |
+
|
| 17 |
+
echo ""
|
| 18 |
+
echo "=== HuggingFace Checkpoints ==="
|
| 19 |
+
uv run python3 -c "
|
| 20 |
+
from huggingface_hub import HfApi
|
| 21 |
+
api = HfApi()
|
| 22 |
+
for variant in ['small', 'base', 'large']:
|
| 23 |
+
repo = f'thomas-schweich/pawn-{variant}'
|
| 24 |
+
try:
|
| 25 |
+
branches = [b.name for b in api.list_repo_refs(repo, repo_type='model').branches if b.name.startswith('run/')]
|
| 26 |
+
for branch in branches:
|
| 27 |
+
files = [f.rfilename for f in api.list_repo_tree(repo, revision=branch, repo_type='model', recursive=True) if hasattr(f, 'rfilename') and 'checkpoints/' in f.rfilename]
|
| 28 |
+
ckpts = sorted(set(f.split('/')[1] for f in files if f.startswith('checkpoints/step_')))
|
| 29 |
+
print(f' {repo}@{branch}: {len(ckpts)} checkpoints ({ckpts[-1] if ckpts else \"none\"})')
|
| 30 |
+
if not branches:
|
| 31 |
+
print(f' {repo}: no run branches')
|
| 32 |
+
except Exception as e:
|
| 33 |
+
print(f' {repo}: {e}')
|
| 34 |
+
" 2>/dev/null || echo " (HF check failed)"
|
| 35 |
+
|
| 36 |
+
echo ""
|
| 37 |
+
echo "=== Metrics Sync ==="
|
| 38 |
+
rsync -az --include='*/' --include='metrics.jsonl' --include='config.json' --exclude='*' \
|
| 39 |
+
-e "ssh -o StrictHostKeyChecking=accept-new -p $PORT" \
|
| 40 |
+
"root@$HOST:/workspace/logs/" logs/ 2>/dev/null && echo " Synced" || echo " (Sync failed)"
|
scripts/train_all.py
CHANGED
|
@@ -15,9 +15,11 @@ import argparse
|
|
| 15 |
import json
|
| 16 |
import math
|
| 17 |
import os
|
|
|
|
| 18 |
import signal
|
| 19 |
import sys
|
| 20 |
import time
|
|
|
|
| 21 |
|
| 22 |
import torch
|
| 23 |
import torch.multiprocessing as mp
|
|
@@ -26,7 +28,7 @@ from torch.utils.data import DataLoader
|
|
| 26 |
from pawn.config import CLMConfig, TrainingConfig
|
| 27 |
from pawn.model import PAWNCLM, clm_loss
|
| 28 |
from pawn.data import CLMDataset, create_validation_set
|
| 29 |
-
from pawn.gpu import configure_gpu
|
| 30 |
from pawn.checkpoint import save_pretrain_checkpoint, push_checkpoint_to_hf
|
| 31 |
|
| 32 |
|
|
@@ -44,12 +46,14 @@ class ModelSlot:
|
|
| 44 |
train_cfg: TrainingConfig,
|
| 45 |
device: str,
|
| 46 |
hf_repo: str | None,
|
|
|
|
| 47 |
):
|
| 48 |
self.name = name
|
| 49 |
self.model_cfg = model_cfg
|
| 50 |
self.train_cfg = train_cfg
|
| 51 |
self.device = device
|
| 52 |
self.hf_repo = hf_repo
|
|
|
|
| 53 |
|
| 54 |
self.model = PAWNCLM(model_cfg).to(device)
|
| 55 |
param_count = sum(p.numel() for p in self.model.parameters())
|
|
@@ -70,9 +74,14 @@ class ModelSlot:
|
|
| 70 |
|
| 71 |
self.scaler = torch.amp.GradScaler(device, enabled=train_cfg.use_amp)
|
| 72 |
|
| 73 |
-
# Run directory
|
| 74 |
self.run_dir = _make_run_dir(train_cfg.log_dir, name)
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
| 77 |
|
| 78 |
self.jsonl_path = os.path.join(self.run_dir, "metrics.jsonl")
|
|
@@ -80,6 +89,13 @@ class ModelSlot:
|
|
| 80 |
|
| 81 |
self.hf_branch = f"run/{os.path.basename(self.run_dir)}" if hf_repo else None
|
| 82 |
self.global_step = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
# Write config
|
| 85 |
import subprocess
|
|
@@ -133,24 +149,56 @@ class ModelSlot:
|
|
| 133 |
self.scheduler.step()
|
| 134 |
return grad_norm
|
| 135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
def save_checkpoint(self):
|
| 137 |
path = os.path.join(self.checkpoint_dir, f"step_{self.global_step:08d}")
|
| 138 |
save_pretrain_checkpoint(
|
| 139 |
-
path, self.
|
| 140 |
self.global_step, self.model_cfg.__dict__, self.train_cfg.__dict__,
|
| 141 |
)
|
| 142 |
print(f" [{self.name}] Checkpoint saved: {path}")
|
| 143 |
|
| 144 |
if self.hf_repo and self.hf_branch:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
try:
|
| 146 |
push_checkpoint_to_hf(
|
| 147 |
-
|
| 148 |
-
metrics_path=self.jsonl_path, step=
|
| 149 |
)
|
| 150 |
print(f" [{self.name}] Pushed to HF: {self.hf_repo}@{self.hf_branch}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
except Exception as e:
|
| 152 |
print(f" [{self.name}] WARNING: HF push failed: {e}")
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
@torch.no_grad()
|
| 155 |
def evaluate(self, val_data: dict[str, torch.Tensor]) -> dict[str, float]:
|
| 156 |
self.model.eval()
|
|
@@ -182,6 +230,8 @@ class ModelSlot:
|
|
| 182 |
self._jsonl_file.flush()
|
| 183 |
|
| 184 |
def close(self):
|
|
|
|
|
|
|
| 185 |
if self._jsonl_file:
|
| 186 |
self._jsonl_file.close()
|
| 187 |
self._jsonl_file = None
|
|
@@ -215,12 +265,120 @@ def parse_args():
|
|
| 215 |
ckpt_group.add_argument("--hf-repo", type=str, default=None,
|
| 216 |
help="HF repo prefix (appends -{variant}). E.g. thomas-schweich/pawn")
|
| 217 |
ckpt_group.add_argument("--local-checkpoints", action="store_true")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
return p.parse_args()
|
| 219 |
|
| 220 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
def main():
|
| 222 |
args = parse_args()
|
| 223 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 225 |
if device == "cuda":
|
| 226 |
gpu_cfg = configure_gpu()
|
|
@@ -239,6 +397,8 @@ def main():
|
|
| 239 |
print(f"Device: {device}")
|
| 240 |
print(f"Batch size: {args.batch_size}")
|
| 241 |
print(f"Total steps: {args.total_steps}")
|
|
|
|
|
|
|
| 242 |
print()
|
| 243 |
|
| 244 |
slots: list[ModelSlot] = []
|
|
@@ -256,7 +416,8 @@ def main():
|
|
| 256 |
train_cfg.use_wandb = args.wandb
|
| 257 |
|
| 258 |
hf_repo = f"{args.hf_repo}-{name}" if args.hf_repo else None
|
| 259 |
-
slots.append(ModelSlot(name, model_cfg, train_cfg, device, hf_repo
|
|
|
|
| 260 |
|
| 261 |
# Shared dataset and validation set
|
| 262 |
max_ply = 256
|
|
@@ -362,6 +523,11 @@ def main():
|
|
| 362 |
"timestamp": time.time(),
|
| 363 |
**val_metrics,
|
| 364 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
|
| 366 |
# Checkpoint
|
| 367 |
if global_step % args.checkpoint_interval == 0:
|
|
@@ -389,6 +555,13 @@ def main():
|
|
| 389 |
for slot in slots:
|
| 390 |
slot.close()
|
| 391 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
print("\nAll done.")
|
| 393 |
|
| 394 |
|
|
|
|
| 15 |
import json
|
| 16 |
import math
|
| 17 |
import os
|
| 18 |
+
import shutil
|
| 19 |
import signal
|
| 20 |
import sys
|
| 21 |
import time
|
| 22 |
+
from pathlib import Path
|
| 23 |
|
| 24 |
import torch
|
| 25 |
import torch.multiprocessing as mp
|
|
|
|
| 28 |
from pawn.config import CLMConfig, TrainingConfig
|
| 29 |
from pawn.model import PAWNCLM, clm_loss
|
| 30 |
from pawn.data import CLMDataset, create_validation_set
|
| 31 |
+
from pawn.gpu import configure_gpu
|
| 32 |
from pawn.checkpoint import save_pretrain_checkpoint, push_checkpoint_to_hf
|
| 33 |
|
| 34 |
|
|
|
|
| 46 |
train_cfg: TrainingConfig,
|
| 47 |
device: str,
|
| 48 |
hf_repo: str | None,
|
| 49 |
+
shm_checkpoints: bool = False,
|
| 50 |
):
|
| 51 |
self.name = name
|
| 52 |
self.model_cfg = model_cfg
|
| 53 |
self.train_cfg = train_cfg
|
| 54 |
self.device = device
|
| 55 |
self.hf_repo = hf_repo
|
| 56 |
+
self.shm_checkpoints = shm_checkpoints
|
| 57 |
|
| 58 |
self.model = PAWNCLM(model_cfg).to(device)
|
| 59 |
param_count = sum(p.numel() for p in self.model.parameters())
|
|
|
|
| 74 |
|
| 75 |
self.scaler = torch.amp.GradScaler(device, enabled=train_cfg.use_amp)
|
| 76 |
|
| 77 |
+
# Run directory (logs always on persistent disk)
|
| 78 |
self.run_dir = _make_run_dir(train_cfg.log_dir, name)
|
| 79 |
+
|
| 80 |
+
# Checkpoint directory: /dev/shm if requested, else under run_dir
|
| 81 |
+
if shm_checkpoints:
|
| 82 |
+
self.checkpoint_dir = f"/dev/shm/pawn_checkpoints/{name}"
|
| 83 |
+
else:
|
| 84 |
+
self.checkpoint_dir = os.path.join(self.run_dir, "checkpoints")
|
| 85 |
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
| 86 |
|
| 87 |
self.jsonl_path = os.path.join(self.run_dir, "metrics.jsonl")
|
|
|
|
| 89 |
|
| 90 |
self.hf_branch = f"run/{os.path.basename(self.run_dir)}" if hf_repo else None
|
| 91 |
self.global_step = 0
|
| 92 |
+
self.best_val_step = 0
|
| 93 |
+
self.best_val_loss = float("inf")
|
| 94 |
+
|
| 95 |
+
# Background HF push (one thread per slot, so pushes don't block training)
|
| 96 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 97 |
+
self._hf_push_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix=f"hf-{name}")
|
| 98 |
+
self._hf_push_future = None
|
| 99 |
|
| 100 |
# Write config
|
| 101 |
import subprocess
|
|
|
|
| 149 |
self.scheduler.step()
|
| 150 |
return grad_norm
|
| 151 |
|
| 152 |
+
def _unwrapped_model(self):
|
| 153 |
+
"""Return the unwrapped model (strips torch.compile wrapper)."""
|
| 154 |
+
m = self.model
|
| 155 |
+
while hasattr(m, '_orig_mod'):
|
| 156 |
+
m = m._orig_mod
|
| 157 |
+
return m
|
| 158 |
+
|
| 159 |
def save_checkpoint(self):
|
| 160 |
path = os.path.join(self.checkpoint_dir, f"step_{self.global_step:08d}")
|
| 161 |
save_pretrain_checkpoint(
|
| 162 |
+
path, self._unwrapped_model(), self.optimizer, self.scheduler, self.scaler,
|
| 163 |
self.global_step, self.model_cfg.__dict__, self.train_cfg.__dict__,
|
| 164 |
)
|
| 165 |
print(f" [{self.name}] Checkpoint saved: {path}")
|
| 166 |
|
| 167 |
if self.hf_repo and self.hf_branch:
|
| 168 |
+
self._push_to_hf_async(path, self.global_step)
|
| 169 |
+
|
| 170 |
+
def _push_to_hf_async(self, ckpt_path: str, step: int):
|
| 171 |
+
"""Push checkpoint to HuggingFace in a background thread."""
|
| 172 |
+
# Wait for any previous push to finish before starting a new one
|
| 173 |
+
if self._hf_push_future is not None:
|
| 174 |
+
self._hf_push_future.result() # raises if previous push failed
|
| 175 |
+
|
| 176 |
+
def _push():
|
| 177 |
try:
|
| 178 |
push_checkpoint_to_hf(
|
| 179 |
+
ckpt_path, self.hf_repo, self.hf_branch,
|
| 180 |
+
metrics_path=self.jsonl_path, step=step,
|
| 181 |
)
|
| 182 |
print(f" [{self.name}] Pushed to HF: {self.hf_repo}@{self.hf_branch}")
|
| 183 |
+
|
| 184 |
+
# On /dev/shm, clean up old checkpoints after successful push.
|
| 185 |
+
# Keep the latest (just saved) and the best (for post-training evals).
|
| 186 |
+
if self.shm_checkpoints:
|
| 187 |
+
keep = {Path(ckpt_path).name, f"step_{self.best_val_step:08d}"}
|
| 188 |
+
for old in sorted(Path(self.checkpoint_dir).glob("step_*")):
|
| 189 |
+
if old.name not in keep:
|
| 190 |
+
shutil.rmtree(old, ignore_errors=True)
|
| 191 |
except Exception as e:
|
| 192 |
print(f" [{self.name}] WARNING: HF push failed: {e}")
|
| 193 |
|
| 194 |
+
self._hf_push_future = self._hf_push_pool.submit(_push)
|
| 195 |
+
|
| 196 |
+
def wait_for_push(self):
|
| 197 |
+
"""Block until any in-flight HF push completes."""
|
| 198 |
+
if self._hf_push_future is not None:
|
| 199 |
+
self._hf_push_future.result()
|
| 200 |
+
self._hf_push_future = None
|
| 201 |
+
|
| 202 |
@torch.no_grad()
|
| 203 |
def evaluate(self, val_data: dict[str, torch.Tensor]) -> dict[str, float]:
|
| 204 |
self.model.eval()
|
|
|
|
| 230 |
self._jsonl_file.flush()
|
| 231 |
|
| 232 |
def close(self):
|
| 233 |
+
self.wait_for_push()
|
| 234 |
+
self._hf_push_pool.shutdown(wait=True)
|
| 235 |
if self._jsonl_file:
|
| 236 |
self._jsonl_file.close()
|
| 237 |
self._jsonl_file = None
|
|
|
|
| 265 |
ckpt_group.add_argument("--hf-repo", type=str, default=None,
|
| 266 |
help="HF repo prefix (appends -{variant}). E.g. thomas-schweich/pawn")
|
| 267 |
ckpt_group.add_argument("--local-checkpoints", action="store_true")
|
| 268 |
+
|
| 269 |
+
p.add_argument("--shm-checkpoints", action="store_true",
|
| 270 |
+
help="Write checkpoints to /dev/shm (RAM-backed, instant writes). "
|
| 271 |
+
"Requires --hf-repo since /dev/shm is volatile.")
|
| 272 |
+
|
| 273 |
+
p.add_argument("--run-evals", action="store_true",
|
| 274 |
+
help="Run probes, diagnostics, and Lichess eval after training completes")
|
| 275 |
+
p.add_argument("--lichess-pgn", type=str, default=None,
|
| 276 |
+
help="Path to Lichess PGN file for eval (required with --run-evals)")
|
| 277 |
+
p.add_argument("--publish-results", action="store_true",
|
| 278 |
+
help="Push eval results to HuggingFace (requires --hf-repo and --run-evals)")
|
| 279 |
return p.parse_args()
|
| 280 |
|
| 281 |
|
| 282 |
+
def _run_post_training_evals(slots: list[ModelSlot], args):
|
| 283 |
+
"""Run probes, diagnostics, and Lichess eval on best checkpoint per variant."""
|
| 284 |
+
import tempfile
|
| 285 |
+
from pawn.eval_suite.probes import extract_probe_data, train_all_probes
|
| 286 |
+
from pawn.eval_suite.corpus import generate_corpus, load_corpus
|
| 287 |
+
from pawn.eval_suite.diagnostics import extract_diagnostic_positions, evaluate_diagnostic_positions
|
| 288 |
+
|
| 289 |
+
device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 290 |
+
|
| 291 |
+
for slot in slots:
|
| 292 |
+
print(f"\n--- Evaluating {slot.name} ---")
|
| 293 |
+
|
| 294 |
+
# Use tracked best val step (kept on /dev/shm if shm_checkpoints)
|
| 295 |
+
best_step = slot.best_val_step
|
| 296 |
+
best_loss = slot.best_val_loss
|
| 297 |
+
|
| 298 |
+
ckpt_path = os.path.join(slot.checkpoint_dir, f"step_{best_step:08d}")
|
| 299 |
+
if not os.path.isdir(ckpt_path):
|
| 300 |
+
# Fall back to latest
|
| 301 |
+
ckpts = sorted(Path(slot.checkpoint_dir).glob("step_*"))
|
| 302 |
+
ckpt_path = str(ckpts[-1]) if ckpts else None
|
| 303 |
+
|
| 304 |
+
if not ckpt_path:
|
| 305 |
+
print(f" No checkpoint found, skipping")
|
| 306 |
+
continue
|
| 307 |
+
|
| 308 |
+
print(f" Best checkpoint: {ckpt_path} (val_loss={best_loss:.4f})")
|
| 309 |
+
|
| 310 |
+
# Load model (unwrapped)
|
| 311 |
+
from pawn.checkpoint import load_backbone_weights
|
| 312 |
+
state_dict, _ = load_backbone_weights(ckpt_path)
|
| 313 |
+
model = PAWNCLM(slot.model_cfg).to(device)
|
| 314 |
+
model.load_state_dict(state_dict)
|
| 315 |
+
model.eval()
|
| 316 |
+
|
| 317 |
+
results = {}
|
| 318 |
+
|
| 319 |
+
# 1. Probes
|
| 320 |
+
print(" Running probes...")
|
| 321 |
+
train_data = extract_probe_data(2048, 256, seed=12345)
|
| 322 |
+
val_data = extract_probe_data(512, 256, seed=54321)
|
| 323 |
+
probe_results = train_all_probes(
|
| 324 |
+
model, train_data, val_data, device=device,
|
| 325 |
+
per_layer=True, n_epochs=20, verbose=True,
|
| 326 |
+
)
|
| 327 |
+
results["probes"] = probe_results
|
| 328 |
+
del train_data, val_data
|
| 329 |
+
|
| 330 |
+
# 2. Diagnostics
|
| 331 |
+
print(" Running diagnostics...")
|
| 332 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 333 |
+
corpus_path = generate_corpus(tmpdir, n_games=2048, max_ply=255, seed=99999, batch_size=2048)
|
| 334 |
+
corpus = load_corpus(corpus_path)
|
| 335 |
+
positions = extract_diagnostic_positions(corpus, min_per_category=200, max_per_category=1000)
|
| 336 |
+
diag_results = evaluate_diagnostic_positions(model, positions, corpus, device=device)
|
| 337 |
+
results["diagnostics"] = diag_results
|
| 338 |
+
|
| 339 |
+
# 3. Lichess eval (if PGN provided)
|
| 340 |
+
if args.lichess_pgn:
|
| 341 |
+
print(" Running Lichess eval...")
|
| 342 |
+
from pawn.eval_suite.lichess import prepare_lichess_corpus, evaluate_on_lichess
|
| 343 |
+
lichess_data = prepare_lichess_corpus(args.lichess_pgn, max_games_per_band=1000)
|
| 344 |
+
lichess_results = evaluate_on_lichess(model, lichess_data, device=device)
|
| 345 |
+
results["lichess"] = lichess_results
|
| 346 |
+
|
| 347 |
+
# Save results
|
| 348 |
+
results_path = os.path.join(slot.run_dir, "eval_results.json")
|
| 349 |
+
with open(results_path, "w") as f:
|
| 350 |
+
json.dump(results, f, indent=2, default=str)
|
| 351 |
+
print(f" Results saved: {results_path}")
|
| 352 |
+
|
| 353 |
+
# Publish to HF
|
| 354 |
+
if args.publish_results and slot.hf_repo and slot.hf_branch:
|
| 355 |
+
from huggingface_hub import HfApi
|
| 356 |
+
api = HfApi()
|
| 357 |
+
try:
|
| 358 |
+
api.upload_file(
|
| 359 |
+
path_or_fileobj=results_path,
|
| 360 |
+
path_in_repo="eval_results.json",
|
| 361 |
+
repo_id=slot.hf_repo,
|
| 362 |
+
repo_type="model",
|
| 363 |
+
revision=slot.hf_branch,
|
| 364 |
+
commit_message=f"Eval results (best step {best_step})",
|
| 365 |
+
)
|
| 366 |
+
print(f" Published to {slot.hf_repo}@{slot.hf_branch}")
|
| 367 |
+
except Exception as e:
|
| 368 |
+
print(f" WARNING: HF publish failed: {e}")
|
| 369 |
+
|
| 370 |
+
del model, state_dict
|
| 371 |
+
if torch.cuda.is_available():
|
| 372 |
+
torch.cuda.empty_cache()
|
| 373 |
+
|
| 374 |
+
|
| 375 |
def main():
|
| 376 |
args = parse_args()
|
| 377 |
|
| 378 |
+
if args.shm_checkpoints and not args.hf_repo:
|
| 379 |
+
print("ERROR: --shm-checkpoints requires --hf-repo (HF is the only durable store)")
|
| 380 |
+
sys.exit(1)
|
| 381 |
+
|
| 382 |
device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 383 |
if device == "cuda":
|
| 384 |
gpu_cfg = configure_gpu()
|
|
|
|
| 397 |
print(f"Device: {device}")
|
| 398 |
print(f"Batch size: {args.batch_size}")
|
| 399 |
print(f"Total steps: {args.total_steps}")
|
| 400 |
+
if args.shm_checkpoints:
|
| 401 |
+
print("Checkpoints: /dev/shm (volatile, HF push is durable store)")
|
| 402 |
print()
|
| 403 |
|
| 404 |
slots: list[ModelSlot] = []
|
|
|
|
| 416 |
train_cfg.use_wandb = args.wandb
|
| 417 |
|
| 418 |
hf_repo = f"{args.hf_repo}-{name}" if args.hf_repo else None
|
| 419 |
+
slots.append(ModelSlot(name, model_cfg, train_cfg, device, hf_repo,
|
| 420 |
+
shm_checkpoints=args.shm_checkpoints))
|
| 421 |
|
| 422 |
# Shared dataset and validation set
|
| 423 |
max_ply = 256
|
|
|
|
| 523 |
"timestamp": time.time(),
|
| 524 |
**val_metrics,
|
| 525 |
})
|
| 526 |
+
# Track best for eval and /dev/shm cleanup
|
| 527 |
+
vl = val_metrics["val/loss"]
|
| 528 |
+
if vl < slot.best_val_loss:
|
| 529 |
+
slot.best_val_loss = vl
|
| 530 |
+
slot.best_val_step = global_step
|
| 531 |
|
| 532 |
# Checkpoint
|
| 533 |
if global_step % args.checkpoint_interval == 0:
|
|
|
|
| 555 |
for slot in slots:
|
| 556 |
slot.close()
|
| 557 |
|
| 558 |
+
# Post-training evals
|
| 559 |
+
if args.run_evals:
|
| 560 |
+
print("\n" + "=" * 60)
|
| 561 |
+
print("POST-TRAINING EVALUATION")
|
| 562 |
+
print("=" * 60)
|
| 563 |
+
_run_post_training_evals(slots, args)
|
| 564 |
+
|
| 565 |
print("\nAll done.")
|
| 566 |
|
| 567 |
|
uv.lock
CHANGED
|
@@ -20,6 +20,15 @@ members = [
|
|
| 20 |
"pawn",
|
| 21 |
]
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
[[package]]
|
| 24 |
name = "annotated-types"
|
| 25 |
version = "0.7.0"
|
|
@@ -499,6 +508,70 @@ wheels = [
|
|
| 499 |
{ url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" },
|
| 500 |
]
|
| 501 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 502 |
[[package]]
|
| 503 |
name = "humanize"
|
| 504 |
version = "4.15.0"
|
|
@@ -1322,10 +1395,22 @@ name = "pawn"
|
|
| 1322 |
version = "0.1.0"
|
| 1323 |
source = { editable = "." }
|
| 1324 |
dependencies = [
|
|
|
|
| 1325 |
{ name = "chess-engine", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
|
|
|
|
|
|
|
|
|
| 1326 |
{ name = "numpy", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1327 |
{ name = "psutil", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
|
|
|
|
|
|
| 1328 |
{ name = "safetensors", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
|
|
|
|
|
|
| 1329 |
{ name = "tqdm", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1330 |
{ name = "wandb", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1331 |
]
|
|
@@ -1334,23 +1419,6 @@ dependencies = [
|
|
| 1334 |
cu128 = [
|
| 1335 |
{ name = "torch", version = "2.10.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux'" },
|
| 1336 |
]
|
| 1337 |
-
dashboard = [
|
| 1338 |
-
{ name = "anywidget", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1339 |
-
{ name = "pandas", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and sys_platform == 'linux') or (python_full_version >= '3.11' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm') or (sys_platform != 'linux' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1340 |
-
{ name = "pandas", version = "3.0.1", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform == 'linux') or (python_full_version < '3.11' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm') or (sys_platform != 'linux' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1341 |
-
{ name = "plotly", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1342 |
-
{ name = "solara", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1343 |
-
]
|
| 1344 |
-
dev = [
|
| 1345 |
-
{ name = "ipykernel", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1346 |
-
{ name = "pytest", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1347 |
-
]
|
| 1348 |
-
eval = [
|
| 1349 |
-
{ name = "matplotlib", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1350 |
-
{ name = "polars", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1351 |
-
{ name = "pyarrow", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1352 |
-
{ name = "seaborn", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1353 |
-
]
|
| 1354 |
rocm = [
|
| 1355 |
{ name = "torch", version = "2.10.0+rocm7.1", source = { registry = "https://download.pytorch.org/whl/rocm7.1" }, marker = "sys_platform == 'linux'" },
|
| 1356 |
{ name = "triton-rocm", marker = "sys_platform == 'linux'" },
|
|
@@ -1358,27 +1426,28 @@ rocm = [
|
|
| 1358 |
|
| 1359 |
[package.metadata]
|
| 1360 |
requires-dist = [
|
| 1361 |
-
{ name = "anywidget",
|
| 1362 |
{ name = "chess-engine", editable = "engine" },
|
| 1363 |
-
{ name = "
|
| 1364 |
-
{ name = "
|
|
|
|
| 1365 |
{ name = "numpy", specifier = "~=2.2.0" },
|
| 1366 |
-
{ name = "pandas",
|
| 1367 |
-
{ name = "plotly",
|
| 1368 |
-
{ name = "polars",
|
| 1369 |
{ name = "psutil", specifier = ">=5.9.0" },
|
| 1370 |
-
{ name = "pyarrow",
|
| 1371 |
-
{ name = "pytest",
|
| 1372 |
{ name = "safetensors", specifier = ">=0.4.0" },
|
| 1373 |
-
{ name = "seaborn",
|
| 1374 |
-
{ name = "solara",
|
| 1375 |
{ name = "torch", marker = "extra == 'cu128'", specifier = "~=2.10.0", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "pawn", extra = "cu128" } },
|
| 1376 |
{ name = "torch", marker = "extra == 'rocm'", specifier = "~=2.10.0", index = "https://download.pytorch.org/whl/rocm7.1", conflict = { package = "pawn", extra = "rocm" } },
|
| 1377 |
{ name = "tqdm", specifier = "~=4.67.0" },
|
| 1378 |
{ name = "triton-rocm", marker = "extra == 'rocm'", specifier = ">=3.6.0", index = "https://download.pytorch.org/whl/rocm7.1", conflict = { package = "pawn", extra = "rocm" } },
|
| 1379 |
{ name = "wandb", specifier = "~=0.25.0" },
|
| 1380 |
]
|
| 1381 |
-
provides-extras = ["rocm", "cu128"
|
| 1382 |
|
| 1383 |
[[package]]
|
| 1384 |
name = "pexpect"
|
|
@@ -2065,6 +2134,15 @@ wheels = [
|
|
| 2065 |
{ url = "https://files.pythonhosted.org/packages/9d/76/f789f7a86709c6b087c5a2f52f911838cad707cc613162401badc665acfe/setuptools-82.0.1-py3-none-any.whl", hash = "sha256:a59e362652f08dcd477c78bb6e7bd9d80a7995bc73ce773050228a348ce2e5bb", size = 1006223, upload-time = "2026-03-09T12:47:15.026Z" },
|
| 2066 |
]
|
| 2067 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2068 |
[[package]]
|
| 2069 |
name = "six"
|
| 2070 |
version = "1.17.0"
|
|
@@ -2351,6 +2429,21 @@ wheels = [
|
|
| 2351 |
{ url = "https://download.pytorch.org/whl/triton_rocm-3.6.0-cp312-cp312-linux_x86_64.whl" },
|
| 2352 |
]
|
| 2353 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2354 |
[[package]]
|
| 2355 |
name = "typing-extensions"
|
| 2356 |
version = "4.15.0"
|
|
|
|
| 20 |
"pawn",
|
| 21 |
]
|
| 22 |
|
| 23 |
+
[[package]]
|
| 24 |
+
name = "annotated-doc"
|
| 25 |
+
version = "0.0.4"
|
| 26 |
+
source = { registry = "https://pypi.org/simple" }
|
| 27 |
+
sdist = { url = "https://files.pythonhosted.org/packages/57/ba/046ceea27344560984e26a590f90bc7f4a75b06701f653222458922b558c/annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4", size = 7288, upload-time = "2025-11-10T22:07:42.062Z" }
|
| 28 |
+
wheels = [
|
| 29 |
+
{ url = "https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320", size = 5303, upload-time = "2025-11-10T22:07:40.673Z" },
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
[[package]]
|
| 33 |
name = "annotated-types"
|
| 34 |
version = "0.7.0"
|
|
|
|
| 508 |
{ url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" },
|
| 509 |
]
|
| 510 |
|
| 511 |
+
[[package]]
|
| 512 |
+
name = "hf-xet"
|
| 513 |
+
version = "1.4.2"
|
| 514 |
+
source = { registry = "https://pypi.org/simple" }
|
| 515 |
+
sdist = { url = "https://files.pythonhosted.org/packages/09/08/23c84a26716382c89151b5b447b4beb19e3345f3a93d3b73009a71a57ad3/hf_xet-1.4.2.tar.gz", hash = "sha256:b7457b6b482d9e0743bd116363239b1fa904a5e65deede350fbc0c4ea67c71ea", size = 672357, upload-time = "2026-03-13T06:58:51.077Z" }
|
| 516 |
+
wheels = [
|
| 517 |
+
{ url = "https://files.pythonhosted.org/packages/b4/86/b40b83a2ff03ef05c4478d2672b1fc2b9683ff870e2b25f4f3af240f2e7b/hf_xet-1.4.2-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:71f02d6e4cdd07f344f6844845d78518cc7186bd2bc52d37c3b73dc26a3b0bc5", size = 3800339, upload-time = "2026-03-13T06:58:36.245Z" },
|
| 518 |
+
{ url = "https://files.pythonhosted.org/packages/64/2e/af4475c32b4378b0e92a587adb1aa3ec53e3450fd3e5fe0372a874531c00/hf_xet-1.4.2-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:e9b38d876e94d4bdcf650778d6ebbaa791dd28de08db9736c43faff06ede1b5a", size = 3559664, upload-time = "2026-03-13T06:58:34.787Z" },
|
| 519 |
+
{ url = "https://files.pythonhosted.org/packages/3c/4c/781267da3188db679e601de18112021a5cb16506fe86b246e22c5401a9c4/hf_xet-1.4.2-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:77e8c180b7ef12d8a96739a4e1e558847002afe9ea63b6f6358b2271a8bdda1c", size = 4217422, upload-time = "2026-03-13T06:58:27.472Z" },
|
| 520 |
+
{ url = "https://files.pythonhosted.org/packages/68/47/d6cf4a39ecf6c7705f887a46f6ef5c8455b44ad9eb0d391aa7e8a2ff7fea/hf_xet-1.4.2-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:c3b3c6a882016b94b6c210957502ff7877802d0dbda8ad142c8595db8b944271", size = 3992847, upload-time = "2026-03-13T06:58:25.989Z" },
|
| 521 |
+
{ url = "https://files.pythonhosted.org/packages/2d/ef/e80815061abff54697239803948abc665c6b1d237102c174f4f7a9a5ffc5/hf_xet-1.4.2-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9d9a634cc929cfbaf2e1a50c0e532ae8c78fa98618426769480c58501e8c8ac2", size = 4193843, upload-time = "2026-03-13T06:58:44.59Z" },
|
| 522 |
+
{ url = "https://files.pythonhosted.org/packages/54/75/07f6aa680575d9646c4167db6407c41340cbe2357f5654c4e72a1b01ca14/hf_xet-1.4.2-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6b0932eb8b10317ea78b7da6bab172b17be03bbcd7809383d8d5abd6a2233e04", size = 4432751, upload-time = "2026-03-13T06:58:46.533Z" },
|
| 523 |
+
{ url = "https://files.pythonhosted.org/packages/cd/71/193eabd7e7d4b903c4aa983a215509c6114915a5a237525ec562baddb868/hf_xet-1.4.2-cp37-abi3-win_amd64.whl", hash = "sha256:ad185719fb2e8ac26f88c8100562dbf9dbdcc3d9d2add00faa94b5f106aea53f", size = 3671149, upload-time = "2026-03-13T06:58:57.07Z" },
|
| 524 |
+
{ url = "https://files.pythonhosted.org/packages/b4/7e/ccf239da366b37ba7f0b36095450efae4a64980bdc7ec2f51354205fdf39/hf_xet-1.4.2-cp37-abi3-win_arm64.whl", hash = "sha256:32c012286b581f783653e718c1862aea5b9eb140631685bb0c5e7012c8719a87", size = 3533426, upload-time = "2026-03-13T06:58:55.46Z" },
|
| 525 |
+
]
|
| 526 |
+
|
| 527 |
+
[[package]]
|
| 528 |
+
name = "httpcore"
|
| 529 |
+
version = "1.0.9"
|
| 530 |
+
source = { registry = "https://pypi.org/simple" }
|
| 531 |
+
dependencies = [
|
| 532 |
+
{ name = "certifi", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 533 |
+
{ name = "h11", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 534 |
+
]
|
| 535 |
+
sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" }
|
| 536 |
+
wheels = [
|
| 537 |
+
{ url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" },
|
| 538 |
+
]
|
| 539 |
+
|
| 540 |
+
[[package]]
|
| 541 |
+
name = "httpx"
|
| 542 |
+
version = "0.28.1"
|
| 543 |
+
source = { registry = "https://pypi.org/simple" }
|
| 544 |
+
dependencies = [
|
| 545 |
+
{ name = "anyio", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 546 |
+
{ name = "certifi", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 547 |
+
{ name = "httpcore", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 548 |
+
{ name = "idna", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 549 |
+
]
|
| 550 |
+
sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" }
|
| 551 |
+
wheels = [
|
| 552 |
+
{ url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" },
|
| 553 |
+
]
|
| 554 |
+
|
| 555 |
+
[[package]]
|
| 556 |
+
name = "huggingface-hub"
|
| 557 |
+
version = "1.7.2"
|
| 558 |
+
source = { registry = "https://pypi.org/simple" }
|
| 559 |
+
dependencies = [
|
| 560 |
+
{ name = "filelock", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 561 |
+
{ name = "fsspec", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 562 |
+
{ name = "hf-xet", marker = "(platform_machine != 'AMD64' and platform_machine != 'aarch64' and platform_machine != 'amd64' and platform_machine != 'arm64' and platform_machine != 'x86_64' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'amd64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 563 |
+
{ name = "httpx", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 564 |
+
{ name = "packaging", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 565 |
+
{ name = "pyyaml", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 566 |
+
{ name = "tqdm", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 567 |
+
{ name = "typer", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 568 |
+
{ name = "typing-extensions", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 569 |
+
]
|
| 570 |
+
sdist = { url = "https://files.pythonhosted.org/packages/19/15/eafc1c57bf0f8afffb243dcd4c0cceb785e956acc17bba4d9bf2ae21fc9c/huggingface_hub-1.7.2.tar.gz", hash = "sha256:7f7e294e9bbb822e025bdb2ada025fa4344d978175a7f78e824d86e35f7ab43b", size = 724684, upload-time = "2026-03-20T10:36:08.767Z" }
|
| 571 |
+
wheels = [
|
| 572 |
+
{ url = "https://files.pythonhosted.org/packages/08/de/3ad061a05f74728927ded48c90b73521b9a9328c85d841bdefb30e01fb85/huggingface_hub-1.7.2-py3-none-any.whl", hash = "sha256:288f33a0a17b2a73a1359e2a5fd28d1becb2c121748c6173ab8643fb342c850e", size = 618036, upload-time = "2026-03-20T10:36:06.824Z" },
|
| 573 |
+
]
|
| 574 |
+
|
| 575 |
[[package]]
|
| 576 |
name = "humanize"
|
| 577 |
version = "4.15.0"
|
|
|
|
| 1395 |
version = "0.1.0"
|
| 1396 |
source = { editable = "." }
|
| 1397 |
dependencies = [
|
| 1398 |
+
{ name = "anywidget", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1399 |
{ name = "chess-engine", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1400 |
+
{ name = "huggingface-hub", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1401 |
+
{ name = "ipykernel", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1402 |
+
{ name = "matplotlib", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1403 |
{ name = "numpy", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1404 |
+
{ name = "pandas", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and sys_platform == 'linux') or (python_full_version >= '3.11' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm') or (sys_platform != 'linux' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1405 |
+
{ name = "pandas", version = "3.0.1", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform == 'linux') or (python_full_version < '3.11' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm') or (sys_platform != 'linux' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1406 |
+
{ name = "plotly", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1407 |
+
{ name = "polars", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1408 |
{ name = "psutil", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1409 |
+
{ name = "pyarrow", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1410 |
+
{ name = "pytest", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1411 |
{ name = "safetensors", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1412 |
+
{ name = "seaborn", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1413 |
+
{ name = "solara", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1414 |
{ name = "tqdm", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1415 |
{ name = "wandb", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1416 |
]
|
|
|
|
| 1419 |
cu128 = [
|
| 1420 |
{ name = "torch", version = "2.10.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux'" },
|
| 1421 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1422 |
rocm = [
|
| 1423 |
{ name = "torch", version = "2.10.0+rocm7.1", source = { registry = "https://download.pytorch.org/whl/rocm7.1" }, marker = "sys_platform == 'linux'" },
|
| 1424 |
{ name = "triton-rocm", marker = "sys_platform == 'linux'" },
|
|
|
|
| 1426 |
|
| 1427 |
[package.metadata]
|
| 1428 |
requires-dist = [
|
| 1429 |
+
{ name = "anywidget", specifier = ">=0.9.21" },
|
| 1430 |
{ name = "chess-engine", editable = "engine" },
|
| 1431 |
+
{ name = "huggingface-hub", specifier = ">=0.20.0" },
|
| 1432 |
+
{ name = "ipykernel", specifier = ">=7.2.0" },
|
| 1433 |
+
{ name = "matplotlib", specifier = ">=3.10.8" },
|
| 1434 |
{ name = "numpy", specifier = "~=2.2.0" },
|
| 1435 |
+
{ name = "pandas", specifier = ">=2.0.0" },
|
| 1436 |
+
{ name = "plotly", specifier = ">=5.18.0" },
|
| 1437 |
+
{ name = "polars", specifier = ">=1.39.0" },
|
| 1438 |
{ name = "psutil", specifier = ">=5.9.0" },
|
| 1439 |
+
{ name = "pyarrow", specifier = ">=23.0.1" },
|
| 1440 |
+
{ name = "pytest", specifier = "~=9.0.0" },
|
| 1441 |
{ name = "safetensors", specifier = ">=0.4.0" },
|
| 1442 |
+
{ name = "seaborn", specifier = ">=0.13.2" },
|
| 1443 |
+
{ name = "solara", specifier = ">=1.0.0" },
|
| 1444 |
{ name = "torch", marker = "extra == 'cu128'", specifier = "~=2.10.0", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "pawn", extra = "cu128" } },
|
| 1445 |
{ name = "torch", marker = "extra == 'rocm'", specifier = "~=2.10.0", index = "https://download.pytorch.org/whl/rocm7.1", conflict = { package = "pawn", extra = "rocm" } },
|
| 1446 |
{ name = "tqdm", specifier = "~=4.67.0" },
|
| 1447 |
{ name = "triton-rocm", marker = "extra == 'rocm'", specifier = ">=3.6.0", index = "https://download.pytorch.org/whl/rocm7.1", conflict = { package = "pawn", extra = "rocm" } },
|
| 1448 |
{ name = "wandb", specifier = "~=0.25.0" },
|
| 1449 |
]
|
| 1450 |
+
provides-extras = ["rocm", "cu128"]
|
| 1451 |
|
| 1452 |
[[package]]
|
| 1453 |
name = "pexpect"
|
|
|
|
| 2134 |
{ url = "https://files.pythonhosted.org/packages/9d/76/f789f7a86709c6b087c5a2f52f911838cad707cc613162401badc665acfe/setuptools-82.0.1-py3-none-any.whl", hash = "sha256:a59e362652f08dcd477c78bb6e7bd9d80a7995bc73ce773050228a348ce2e5bb", size = 1006223, upload-time = "2026-03-09T12:47:15.026Z" },
|
| 2135 |
]
|
| 2136 |
|
| 2137 |
+
[[package]]
|
| 2138 |
+
name = "shellingham"
|
| 2139 |
+
version = "1.5.4"
|
| 2140 |
+
source = { registry = "https://pypi.org/simple" }
|
| 2141 |
+
sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310, upload-time = "2023-10-24T04:13:40.426Z" }
|
| 2142 |
+
wheels = [
|
| 2143 |
+
{ url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" },
|
| 2144 |
+
]
|
| 2145 |
+
|
| 2146 |
[[package]]
|
| 2147 |
name = "six"
|
| 2148 |
version = "1.17.0"
|
|
|
|
| 2429 |
{ url = "https://download.pytorch.org/whl/triton_rocm-3.6.0-cp312-cp312-linux_x86_64.whl" },
|
| 2430 |
]
|
| 2431 |
|
| 2432 |
+
[[package]]
|
| 2433 |
+
name = "typer"
|
| 2434 |
+
version = "0.24.1"
|
| 2435 |
+
source = { registry = "https://pypi.org/simple" }
|
| 2436 |
+
dependencies = [
|
| 2437 |
+
{ name = "annotated-doc", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 2438 |
+
{ name = "click", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 2439 |
+
{ name = "rich", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 2440 |
+
{ name = "shellingham", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 2441 |
+
]
|
| 2442 |
+
sdist = { url = "https://files.pythonhosted.org/packages/f5/24/cb09efec5cc954f7f9b930bf8279447d24618bb6758d4f6adf2574c41780/typer-0.24.1.tar.gz", hash = "sha256:e39b4732d65fbdcde189ae76cf7cd48aeae72919dea1fdfc16593be016256b45", size = 118613, upload-time = "2026-02-21T16:54:40.609Z" }
|
| 2443 |
+
wheels = [
|
| 2444 |
+
{ url = "https://files.pythonhosted.org/packages/4a/91/48db081e7a63bb37284f9fbcefda7c44c277b18b0e13fbc36ea2335b71e6/typer-0.24.1-py3-none-any.whl", hash = "sha256:112c1f0ce578bfb4cab9ffdabc68f031416ebcc216536611ba21f04e9aa84c9e", size = 56085, upload-time = "2026-02-21T16:54:41.616Z" },
|
| 2445 |
+
]
|
| 2446 |
+
|
| 2447 |
[[package]]
|
| 2448 |
name = "typing-extensions"
|
| 2449 |
version = "4.15.0"
|