thomas-schweich committed on
Commit
87b2fa6
·
1 Parent(s): f828a0b

Add post-training evals, /dev/shm checkpoints, async HF push, and _orig_mod fix

Browse files

train_all.py:
- Fix torch.compile _orig_mod prefix bug: unwrap model before saving
- Add --run-evals for automatic probes, diagnostics, and Lichess eval
- Add --publish-results to push eval results to HuggingFace
- Add --shm-checkpoints for RAM-backed volatile storage (requires --hf-repo)
- Push checkpoints to HF in background threads (non-blocking)
- Track best val checkpoint for eval and /dev/shm cleanup
- Clean up old /dev/shm checkpoints after successful HF push

eval_probes.py: handle directory-based safetensors checkpoints
monitor_training.sh: new script for remote training monitoring
CLAUDE.md: document --shm-checkpoints, --run-evals, async HF push

Files changed (5) hide show
  1. CLAUDE.md +4 -2
  2. scripts/eval_probes.py +5 -1
  3. scripts/monitor_training.sh +40 -0
  4. scripts/train_all.py +180 -7
  5. uv.lock +121 -28
CLAUDE.md CHANGED
@@ -86,7 +86,7 @@ All adapters freeze the backbone and initialize to identity (zero-init or gamma=
86
  ## Scripts (`scripts/`)
87
 
88
  - `train.py` -- Pretrain from scratch (`--variant small|base|large|toy`)
89
- - `train_all.py` -- Train small/base/large simultaneously on shared data batches
90
  - `train_bottleneck.py`, `train_film.py`, `train_lora.py`, `train_sparse.py`, `train_hybrid.py` -- Adapter behavioral cloning on Lichess PGN
91
  - `train_tiny.py` -- Standalone tiny transformer baseline (no frozen backbone)
92
  - `eval_accuracy.py` -- MAIA-compatible evaluation (per-phase, per-ply accuracy)
@@ -151,7 +151,9 @@ All training scripts require one of:
151
  - `--hf-repo REPO_ID` — push checkpoints to a HuggingFace branch as they're written (durable)
152
  - `--local-checkpoints` — save locally only (for development without an HF account)
153
 
154
- HF mode creates a `run/{run_id}` branch. Squash-merge into main when satisfied.
 
 
155
 
156
  ### Data Integrity
157
 
 
86
  ## Scripts (`scripts/`)
87
 
88
  - `train.py` -- Pretrain from scratch (`--variant small|base|large|toy`)
89
+ - `train_all.py` -- Train small/base/large simultaneously on shared data batches. Supports `--run-evals` for automatic post-training probes, diagnostics, and Lichess eval, and `--publish-results` to push eval results to HF.
90
  - `train_bottleneck.py`, `train_film.py`, `train_lora.py`, `train_sparse.py`, `train_hybrid.py` -- Adapter behavioral cloning on Lichess PGN
91
  - `train_tiny.py` -- Standalone tiny transformer baseline (no frozen backbone)
92
  - `eval_accuracy.py` -- MAIA-compatible evaluation (per-phase, per-ply accuracy)
 
151
  - `--hf-repo REPO_ID` — push checkpoints to a HuggingFace branch as they're written (durable)
152
  - `--local-checkpoints` — save locally only (for development without an HF account)
153
 
154
+ HF mode creates a `run/{run_id}` branch. HF pushes happen in background threads (one per model slot) so training is not blocked by uploads. Squash-merge into main when satisfied.
155
+
156
+ Optional: `--shm-checkpoints` writes checkpoints to `/dev/shm` (RAM-backed filesystem, instant writes). Requires `--hf-repo` since `/dev/shm` is volatile. Old checkpoints are cleaned up after successful HF push, keeping only the latest and the best (by val loss) for post-training evals.
157
 
158
  ### Data Integrity
159
 
scripts/eval_probes.py CHANGED
@@ -61,7 +61,11 @@ def main():
61
  run_dir = config_path.parent
62
  if args.run and run_dir.name != args.run:
63
  continue
64
- checkpoints = sorted(run_dir.glob("checkpoints/step_*.pt"))
 
 
 
 
65
  if not checkpoints:
66
  continue
67
  latest = checkpoints[-1]
 
61
  run_dir = config_path.parent
62
  if args.run and run_dir.name != args.run:
63
  continue
64
+ # Find checkpoints: directory-based (safetensors) or legacy .pt
65
+ checkpoints = sorted(
66
+ [d for d in run_dir.glob("checkpoints/step_*") if d.is_dir()]
67
+ or list(run_dir.glob("checkpoints/step_*.pt"))
68
+ )
69
  if not checkpoints:
70
  continue
71
  latest = checkpoints[-1]
scripts/monitor_training.sh ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Monitor multi-model training: check pod log + HuggingFace checkpoints.
# Usage: bash scripts/monitor_training.sh <host> <port>
#
# Read-only: every remote command either succeeds or prints a "(... failed)"
# marker; the script never mutates state on the pod.
set -euo pipefail

# Remote pod SSH coordinates; defaults point at the current training pod.
HOST="${1:-50.145.48.110}"
PORT="${2:-13321}"
# accept-new avoids an interactive host-key prompt on a freshly created pod.
SSH="ssh -o StrictHostKeyChecking=accept-new -o ConnectTimeout=10 -p $PORT root@$HOST"

# 1) Tail the trainer's log file on the pod.
echo "=== Training Log ==="
$SSH "tail -15 /workspace/logs/train_all.log" 2>/dev/null || echo " (SSH failed)"

echo ""
# 2) Check whether the train_all process is still running.
echo "=== Process Status ==="
$SSH "pgrep -f train_all > /dev/null && echo RUNNING || echo STOPPED" 2>/dev/null || echo " (SSH failed)"

echo ""
# 3) Count checkpoints on each variant repo's run/* branches (durable store).
#    Runs locally; requires HF credentials in the environment.
echo "=== HuggingFace Checkpoints ==="
uv run python3 -c "
from huggingface_hub import HfApi
api = HfApi()
for variant in ['small', 'base', 'large']:
    repo = f'thomas-schweich/pawn-{variant}'
    try:
        branches = [b.name for b in api.list_repo_refs(repo, repo_type='model').branches if b.name.startswith('run/')]
        for branch in branches:
            files = [f.rfilename for f in api.list_repo_tree(repo, revision=branch, repo_type='model', recursive=True) if hasattr(f, 'rfilename') and 'checkpoints/' in f.rfilename]
            ckpts = sorted(set(f.split('/')[1] for f in files if f.startswith('checkpoints/step_')))
            print(f' {repo}@{branch}: {len(ckpts)} checkpoints ({ckpts[-1] if ckpts else \"none\"})')
        if not branches:
            print(f' {repo}: no run branches')
    except Exception as e:
        print(f' {repo}: {e}')
" 2>/dev/null || echo " (HF check failed)"

echo ""
# 4) Pull metrics.jsonl and config.json from the pod for local inspection;
#    the include/exclude dance copies only those two file names.
echo "=== Metrics Sync ==="
rsync -az --include='*/' --include='metrics.jsonl' --include='config.json' --exclude='*' \
-e "ssh -o StrictHostKeyChecking=accept-new -p $PORT" \
"root@$HOST:/workspace/logs/" logs/ 2>/dev/null && echo " Synced" || echo " (Sync failed)"
scripts/train_all.py CHANGED
@@ -15,9 +15,11 @@ import argparse
15
  import json
16
  import math
17
  import os
 
18
  import signal
19
  import sys
20
  import time
 
21
 
22
  import torch
23
  import torch.multiprocessing as mp
@@ -26,7 +28,7 @@ from torch.utils.data import DataLoader
26
  from pawn.config import CLMConfig, TrainingConfig
27
  from pawn.model import PAWNCLM, clm_loss
28
  from pawn.data import CLMDataset, create_validation_set
29
- from pawn.gpu import configure_gpu, apply_gpu_config
30
  from pawn.checkpoint import save_pretrain_checkpoint, push_checkpoint_to_hf
31
 
32
 
@@ -44,12 +46,14 @@ class ModelSlot:
44
  train_cfg: TrainingConfig,
45
  device: str,
46
  hf_repo: str | None,
 
47
  ):
48
  self.name = name
49
  self.model_cfg = model_cfg
50
  self.train_cfg = train_cfg
51
  self.device = device
52
  self.hf_repo = hf_repo
 
53
 
54
  self.model = PAWNCLM(model_cfg).to(device)
55
  param_count = sum(p.numel() for p in self.model.parameters())
@@ -70,9 +74,14 @@ class ModelSlot:
70
 
71
  self.scaler = torch.amp.GradScaler(device, enabled=train_cfg.use_amp)
72
 
73
- # Run directory
74
  self.run_dir = _make_run_dir(train_cfg.log_dir, name)
75
- self.checkpoint_dir = os.path.join(self.run_dir, "checkpoints")
 
 
 
 
 
76
  os.makedirs(self.checkpoint_dir, exist_ok=True)
77
 
78
  self.jsonl_path = os.path.join(self.run_dir, "metrics.jsonl")
@@ -80,6 +89,13 @@ class ModelSlot:
80
 
81
  self.hf_branch = f"run/{os.path.basename(self.run_dir)}" if hf_repo else None
82
  self.global_step = 0
 
 
 
 
 
 
 
83
 
84
  # Write config
85
  import subprocess
@@ -133,24 +149,56 @@ class ModelSlot:
133
  self.scheduler.step()
134
  return grad_norm
135
 
 
 
 
 
 
 
 
136
  def save_checkpoint(self):
137
  path = os.path.join(self.checkpoint_dir, f"step_{self.global_step:08d}")
138
  save_pretrain_checkpoint(
139
- path, self.model, self.optimizer, self.scheduler, self.scaler,
140
  self.global_step, self.model_cfg.__dict__, self.train_cfg.__dict__,
141
  )
142
  print(f" [{self.name}] Checkpoint saved: {path}")
143
 
144
  if self.hf_repo and self.hf_branch:
 
 
 
 
 
 
 
 
 
145
  try:
146
  push_checkpoint_to_hf(
147
- path, self.hf_repo, self.hf_branch,
148
- metrics_path=self.jsonl_path, step=self.global_step,
149
  )
150
  print(f" [{self.name}] Pushed to HF: {self.hf_repo}@{self.hf_branch}")
 
 
 
 
 
 
 
 
151
  except Exception as e:
152
  print(f" [{self.name}] WARNING: HF push failed: {e}")
153
 
 
 
 
 
 
 
 
 
154
  @torch.no_grad()
155
  def evaluate(self, val_data: dict[str, torch.Tensor]) -> dict[str, float]:
156
  self.model.eval()
@@ -182,6 +230,8 @@ class ModelSlot:
182
  self._jsonl_file.flush()
183
 
184
  def close(self):
 
 
185
  if self._jsonl_file:
186
  self._jsonl_file.close()
187
  self._jsonl_file = None
@@ -215,12 +265,120 @@ def parse_args():
215
  ckpt_group.add_argument("--hf-repo", type=str, default=None,
216
  help="HF repo prefix (appends -{variant}). E.g. thomas-schweich/pawn")
217
  ckpt_group.add_argument("--local-checkpoints", action="store_true")
 
 
 
 
 
 
 
 
 
 
 
218
  return p.parse_args()
219
 
220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  def main():
222
  args = parse_args()
223
 
 
 
 
 
224
  device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
225
  if device == "cuda":
226
  gpu_cfg = configure_gpu()
@@ -239,6 +397,8 @@ def main():
239
  print(f"Device: {device}")
240
  print(f"Batch size: {args.batch_size}")
241
  print(f"Total steps: {args.total_steps}")
 
 
242
  print()
243
 
244
  slots: list[ModelSlot] = []
@@ -256,7 +416,8 @@ def main():
256
  train_cfg.use_wandb = args.wandb
257
 
258
  hf_repo = f"{args.hf_repo}-{name}" if args.hf_repo else None
259
- slots.append(ModelSlot(name, model_cfg, train_cfg, device, hf_repo))
 
260
 
261
  # Shared dataset and validation set
262
  max_ply = 256
@@ -362,6 +523,11 @@ def main():
362
  "timestamp": time.time(),
363
  **val_metrics,
364
  })
 
 
 
 
 
365
 
366
  # Checkpoint
367
  if global_step % args.checkpoint_interval == 0:
@@ -389,6 +555,13 @@ def main():
389
  for slot in slots:
390
  slot.close()
391
 
 
 
 
 
 
 
 
392
  print("\nAll done.")
393
 
394
 
 
15
  import json
16
  import math
17
  import os
18
+ import shutil
19
  import signal
20
  import sys
21
  import time
22
+ from pathlib import Path
23
 
24
  import torch
25
  import torch.multiprocessing as mp
 
28
  from pawn.config import CLMConfig, TrainingConfig
29
  from pawn.model import PAWNCLM, clm_loss
30
  from pawn.data import CLMDataset, create_validation_set
31
+ from pawn.gpu import configure_gpu
32
  from pawn.checkpoint import save_pretrain_checkpoint, push_checkpoint_to_hf
33
 
34
 
 
46
  train_cfg: TrainingConfig,
47
  device: str,
48
  hf_repo: str | None,
49
+ shm_checkpoints: bool = False,
50
  ):
51
  self.name = name
52
  self.model_cfg = model_cfg
53
  self.train_cfg = train_cfg
54
  self.device = device
55
  self.hf_repo = hf_repo
56
+ self.shm_checkpoints = shm_checkpoints
57
 
58
  self.model = PAWNCLM(model_cfg).to(device)
59
  param_count = sum(p.numel() for p in self.model.parameters())
 
74
 
75
  self.scaler = torch.amp.GradScaler(device, enabled=train_cfg.use_amp)
76
 
77
+ # Run directory (logs always on persistent disk)
78
  self.run_dir = _make_run_dir(train_cfg.log_dir, name)
79
+
80
+ # Checkpoint directory: /dev/shm if requested, else under run_dir
81
+ if shm_checkpoints:
82
+ self.checkpoint_dir = f"/dev/shm/pawn_checkpoints/{name}"
83
+ else:
84
+ self.checkpoint_dir = os.path.join(self.run_dir, "checkpoints")
85
  os.makedirs(self.checkpoint_dir, exist_ok=True)
86
 
87
  self.jsonl_path = os.path.join(self.run_dir, "metrics.jsonl")
 
89
 
90
  self.hf_branch = f"run/{os.path.basename(self.run_dir)}" if hf_repo else None
91
  self.global_step = 0
92
+ self.best_val_step = 0
93
+ self.best_val_loss = float("inf")
94
+
95
+ # Background HF push (one thread per slot, so pushes don't block training)
96
+ from concurrent.futures import ThreadPoolExecutor
97
+ self._hf_push_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix=f"hf-{name}")
98
+ self._hf_push_future = None
99
 
100
  # Write config
101
  import subprocess
 
149
  self.scheduler.step()
150
  return grad_norm
151
 
152
def _unwrapped_model(self):
    """Return the bare model, peeling off any torch.compile wrapper(s).

    torch.compile stores the original module under ``_orig_mod``; saving
    the compiled wrapper would prefix every state-dict key with
    ``_orig_mod.``, so checkpoints must use the unwrapped module.
    """
    candidate = self.model
    while hasattr(candidate, "_orig_mod"):
        candidate = candidate._orig_mod
    return candidate
158
+
159
def save_checkpoint(self):
    """Write a checkpoint for the current step, then kick off an async HF push.

    The model is unwrapped first so torch.compile's ``_orig_mod.`` prefix
    never leaks into the saved state-dict keys.
    """
    # Checkpoint path is step-numbered and zero-padded so lexicographic
    # sort equals step order.
    path = os.path.join(self.checkpoint_dir, f"step_{self.global_step:08d}")
    save_pretrain_checkpoint(
        path, self._unwrapped_model(), self.optimizer, self.scheduler, self.scaler,
        self.global_step, self.model_cfg.__dict__, self.train_cfg.__dict__,
    )
    print(f" [{self.name}] Checkpoint saved: {path}")

    # HF push runs on this slot's background thread so training is not
    # blocked by the upload.
    if self.hf_repo and self.hf_branch:
        self._push_to_hf_async(path, self.global_step)
169
+
170
def _push_to_hf_async(self, ckpt_path: str, step: int):
    """Push checkpoint to HuggingFace in a background thread.

    Pushes are serialized per slot: a new push is only submitted after the
    previous one finishes, so at most one upload per slot is in flight and
    checkpoints land on the branch in step order.
    """
    # Wait for any previous push to finish before starting a new one.
    # NOTE: the inner _push catches its own exceptions, so .result() here
    # only blocks; upload failures do not propagate to the caller.
    if self._hf_push_future is not None:
        self._hf_push_future.result()

    def _push():
        try:
            push_checkpoint_to_hf(
                ckpt_path, self.hf_repo, self.hf_branch,
                metrics_path=self.jsonl_path, step=step,
            )
            print(f" [{self.name}] Pushed to HF: {self.hf_repo}@{self.hf_branch}")

            # On /dev/shm, clean up old checkpoints after successful push.
            # Keep the latest (just saved) and the best (for post-training evals).
            if self.shm_checkpoints:
                keep = {Path(ckpt_path).name, f"step_{self.best_val_step:08d}"}
                for old in sorted(Path(self.checkpoint_dir).glob("step_*")):
                    if old.name not in keep:
                        shutil.rmtree(old, ignore_errors=True)
        except Exception as e:
            # Best-effort by design: training must not die because an upload
            # failed; the next checkpoint interval effectively retries.
            print(f" [{self.name}] WARNING: HF push failed: {e}")

    self._hf_push_future = self._hf_push_pool.submit(_push)
195
+
196
def wait_for_push(self):
    """Block until the in-flight HF push, if any, has completed."""
    pending = self._hf_push_future
    if pending is None:
        return
    pending.result()
    self._hf_push_future = None
201
+
202
  @torch.no_grad()
203
  def evaluate(self, val_data: dict[str, torch.Tensor]) -> dict[str, float]:
204
  self.model.eval()
 
230
  self._jsonl_file.flush()
231
 
232
  def close(self):
233
+ self.wait_for_push()
234
+ self._hf_push_pool.shutdown(wait=True)
235
  if self._jsonl_file:
236
  self._jsonl_file.close()
237
  self._jsonl_file = None
 
265
  ckpt_group.add_argument("--hf-repo", type=str, default=None,
266
  help="HF repo prefix (appends -{variant}). E.g. thomas-schweich/pawn")
267
  ckpt_group.add_argument("--local-checkpoints", action="store_true")
268
+
269
+ p.add_argument("--shm-checkpoints", action="store_true",
270
+ help="Write checkpoints to /dev/shm (RAM-backed, instant writes). "
271
+ "Requires --hf-repo since /dev/shm is volatile.")
272
+
273
+ p.add_argument("--run-evals", action="store_true",
274
+ help="Run probes, diagnostics, and Lichess eval after training completes")
275
+ p.add_argument("--lichess-pgn", type=str, default=None,
276
+ help="Path to Lichess PGN file for eval (required with --run-evals)")
277
+ p.add_argument("--publish-results", action="store_true",
278
+ help="Push eval results to HuggingFace (requires --hf-repo and --run-evals)")
279
  return p.parse_args()
280
 
281
 
282
def _run_post_training_evals(slots: list[ModelSlot], args):
    """Run probes, diagnostics, and Lichess eval on best checkpoint per variant.

    For each slot: load the best-val-loss checkpoint (falling back to the
    latest one), run the eval suite, write ``eval_results.json`` into the
    slot's run dir, and optionally publish it to the slot's HF branch.
    Eval imports are deferred so training runs don't pay for them unless
    --run-evals is set.
    """
    import tempfile
    from pawn.eval_suite.probes import extract_probe_data, train_all_probes
    from pawn.eval_suite.corpus import generate_corpus, load_corpus
    from pawn.eval_suite.diagnostics import extract_diagnostic_positions, evaluate_diagnostic_positions

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")

    for slot in slots:
        print(f"\n--- Evaluating {slot.name} ---")

        # Use tracked best val step (kept on /dev/shm if shm_checkpoints)
        best_step = slot.best_val_step
        best_loss = slot.best_val_loss

        ckpt_path = os.path.join(slot.checkpoint_dir, f"step_{best_step:08d}")
        if not os.path.isdir(ckpt_path):
            # Fall back to latest.
            # NOTE(review): isdir() is False for legacy .pt checkpoint files,
            # so that format always falls back to the latest checkpoint, not
            # the best-val one — confirm intended.
            ckpts = sorted(Path(slot.checkpoint_dir).glob("step_*"))
            ckpt_path = str(ckpts[-1]) if ckpts else None

        if not ckpt_path:
            print(f" No checkpoint found, skipping")
            continue

        print(f" Best checkpoint: {ckpt_path} (val_loss={best_loss:.4f})")

        # Load model (unwrapped): a fresh PAWNCLM + backbone weights, so no
        # torch.compile wrapper or optimizer state comes along.
        from pawn.checkpoint import load_backbone_weights
        state_dict, _ = load_backbone_weights(ckpt_path)
        model = PAWNCLM(slot.model_cfg).to(device)
        model.load_state_dict(state_dict)
        model.eval()

        results = {}

        # 1. Probes (fixed seeds so results are comparable across runs)
        print(" Running probes...")
        train_data = extract_probe_data(2048, 256, seed=12345)
        val_data = extract_probe_data(512, 256, seed=54321)
        probe_results = train_all_probes(
            model, train_data, val_data, device=device,
            per_layer=True, n_epochs=20, verbose=True,
        )
        results["probes"] = probe_results
        del train_data, val_data

        # 2. Diagnostics on a freshly generated, temporary corpus
        print(" Running diagnostics...")
        with tempfile.TemporaryDirectory() as tmpdir:
            corpus_path = generate_corpus(tmpdir, n_games=2048, max_ply=255, seed=99999, batch_size=2048)
            corpus = load_corpus(corpus_path)
            positions = extract_diagnostic_positions(corpus, min_per_category=200, max_per_category=1000)
            diag_results = evaluate_diagnostic_positions(model, positions, corpus, device=device)
            results["diagnostics"] = diag_results

        # 3. Lichess eval (if PGN provided)
        if args.lichess_pgn:
            print(" Running Lichess eval...")
            from pawn.eval_suite.lichess import prepare_lichess_corpus, evaluate_on_lichess
            lichess_data = prepare_lichess_corpus(args.lichess_pgn, max_games_per_band=1000)
            lichess_results = evaluate_on_lichess(model, lichess_data, device=device)
            results["lichess"] = lichess_results

        # Save results (default=str coerces non-JSON types, e.g. numpy scalars,
        # to strings rather than failing the dump)
        results_path = os.path.join(slot.run_dir, "eval_results.json")
        with open(results_path, "w") as f:
            json.dump(results, f, indent=2, default=str)
        print(f" Results saved: {results_path}")

        # Publish to HF (best-effort: a failed upload only warns)
        if args.publish_results and slot.hf_repo and slot.hf_branch:
            from huggingface_hub import HfApi
            api = HfApi()
            try:
                api.upload_file(
                    path_or_fileobj=results_path,
                    path_in_repo="eval_results.json",
                    repo_id=slot.hf_repo,
                    repo_type="model",
                    revision=slot.hf_branch,
                    commit_message=f"Eval results (best step {best_step})",
                )
                print(f" Published to {slot.hf_repo}@{slot.hf_branch}")
            except Exception as e:
                print(f" WARNING: HF publish failed: {e}")

        # Free GPU memory before evaluating the next (possibly larger) slot.
        del model, state_dict
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
373
+
374
+
375
  def main():
376
  args = parse_args()
377
 
378
+ if args.shm_checkpoints and not args.hf_repo:
379
+ print("ERROR: --shm-checkpoints requires --hf-repo (HF is the only durable store)")
380
+ sys.exit(1)
381
+
382
  device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
383
  if device == "cuda":
384
  gpu_cfg = configure_gpu()
 
397
  print(f"Device: {device}")
398
  print(f"Batch size: {args.batch_size}")
399
  print(f"Total steps: {args.total_steps}")
400
+ if args.shm_checkpoints:
401
+ print("Checkpoints: /dev/shm (volatile, HF push is durable store)")
402
  print()
403
 
404
  slots: list[ModelSlot] = []
 
416
  train_cfg.use_wandb = args.wandb
417
 
418
  hf_repo = f"{args.hf_repo}-{name}" if args.hf_repo else None
419
+ slots.append(ModelSlot(name, model_cfg, train_cfg, device, hf_repo,
420
+ shm_checkpoints=args.shm_checkpoints))
421
 
422
  # Shared dataset and validation set
423
  max_ply = 256
 
523
  "timestamp": time.time(),
524
  **val_metrics,
525
  })
526
+ # Track best for eval and /dev/shm cleanup
527
+ vl = val_metrics["val/loss"]
528
+ if vl < slot.best_val_loss:
529
+ slot.best_val_loss = vl
530
+ slot.best_val_step = global_step
531
 
532
  # Checkpoint
533
  if global_step % args.checkpoint_interval == 0:
 
555
  for slot in slots:
556
  slot.close()
557
 
558
+ # Post-training evals
559
+ if args.run_evals:
560
+ print("\n" + "=" * 60)
561
+ print("POST-TRAINING EVALUATION")
562
+ print("=" * 60)
563
+ _run_post_training_evals(slots, args)
564
+
565
  print("\nAll done.")
566
 
567
 
uv.lock CHANGED
@@ -20,6 +20,15 @@ members = [
20
  "pawn",
21
  ]
22
 
 
 
 
 
 
 
 
 
 
23
  [[package]]
24
  name = "annotated-types"
25
  version = "0.7.0"
@@ -499,6 +508,70 @@ wheels = [
499
  { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" },
500
  ]
501
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
502
  [[package]]
503
  name = "humanize"
504
  version = "4.15.0"
@@ -1322,10 +1395,22 @@ name = "pawn"
1322
  version = "0.1.0"
1323
  source = { editable = "." }
1324
  dependencies = [
 
1325
  { name = "chess-engine", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
 
 
 
1326
  { name = "numpy", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
 
 
 
 
1327
  { name = "psutil", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
 
 
1328
  { name = "safetensors", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
 
 
1329
  { name = "tqdm", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1330
  { name = "wandb", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1331
  ]
@@ -1334,23 +1419,6 @@ dependencies = [
1334
  cu128 = [
1335
  { name = "torch", version = "2.10.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux'" },
1336
  ]
1337
- dashboard = [
1338
- { name = "anywidget", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1339
- { name = "pandas", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and sys_platform == 'linux') or (python_full_version >= '3.11' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm') or (sys_platform != 'linux' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1340
- { name = "pandas", version = "3.0.1", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform == 'linux') or (python_full_version < '3.11' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm') or (sys_platform != 'linux' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1341
- { name = "plotly", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1342
- { name = "solara", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1343
- ]
1344
- dev = [
1345
- { name = "ipykernel", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1346
- { name = "pytest", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1347
- ]
1348
- eval = [
1349
- { name = "matplotlib", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1350
- { name = "polars", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1351
- { name = "pyarrow", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1352
- { name = "seaborn", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1353
- ]
1354
  rocm = [
1355
  { name = "torch", version = "2.10.0+rocm7.1", source = { registry = "https://download.pytorch.org/whl/rocm7.1" }, marker = "sys_platform == 'linux'" },
1356
  { name = "triton-rocm", marker = "sys_platform == 'linux'" },
@@ -1358,27 +1426,28 @@ rocm = [
1358
 
1359
  [package.metadata]
1360
  requires-dist = [
1361
- { name = "anywidget", marker = "extra == 'dashboard'", specifier = ">=0.9.21" },
1362
  { name = "chess-engine", editable = "engine" },
1363
- { name = "ipykernel", marker = "extra == 'dev'", specifier = ">=7.2.0" },
1364
- { name = "matplotlib", marker = "extra == 'eval'", specifier = ">=3.10.8" },
 
1365
  { name = "numpy", specifier = "~=2.2.0" },
1366
- { name = "pandas", marker = "extra == 'dashboard'", specifier = ">=2.0.0" },
1367
- { name = "plotly", marker = "extra == 'dashboard'", specifier = ">=5.18.0" },
1368
- { name = "polars", marker = "extra == 'eval'", specifier = ">=1.39.0" },
1369
  { name = "psutil", specifier = ">=5.9.0" },
1370
- { name = "pyarrow", marker = "extra == 'eval'", specifier = ">=23.0.1" },
1371
- { name = "pytest", marker = "extra == 'dev'", specifier = "~=9.0.0" },
1372
  { name = "safetensors", specifier = ">=0.4.0" },
1373
- { name = "seaborn", marker = "extra == 'eval'", specifier = ">=0.13.2" },
1374
- { name = "solara", marker = "extra == 'dashboard'", specifier = ">=1.0.0" },
1375
  { name = "torch", marker = "extra == 'cu128'", specifier = "~=2.10.0", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "pawn", extra = "cu128" } },
1376
  { name = "torch", marker = "extra == 'rocm'", specifier = "~=2.10.0", index = "https://download.pytorch.org/whl/rocm7.1", conflict = { package = "pawn", extra = "rocm" } },
1377
  { name = "tqdm", specifier = "~=4.67.0" },
1378
  { name = "triton-rocm", marker = "extra == 'rocm'", specifier = ">=3.6.0", index = "https://download.pytorch.org/whl/rocm7.1", conflict = { package = "pawn", extra = "rocm" } },
1379
  { name = "wandb", specifier = "~=0.25.0" },
1380
  ]
1381
- provides-extras = ["rocm", "cu128", "eval", "dashboard", "dev"]
1382
 
1383
  [[package]]
1384
  name = "pexpect"
@@ -2065,6 +2134,15 @@ wheels = [
2065
  { url = "https://files.pythonhosted.org/packages/9d/76/f789f7a86709c6b087c5a2f52f911838cad707cc613162401badc665acfe/setuptools-82.0.1-py3-none-any.whl", hash = "sha256:a59e362652f08dcd477c78bb6e7bd9d80a7995bc73ce773050228a348ce2e5bb", size = 1006223, upload-time = "2026-03-09T12:47:15.026Z" },
2066
  ]
2067
 
 
 
 
 
 
 
 
 
 
2068
  [[package]]
2069
  name = "six"
2070
  version = "1.17.0"
@@ -2351,6 +2429,21 @@ wheels = [
2351
  { url = "https://download.pytorch.org/whl/triton_rocm-3.6.0-cp312-cp312-linux_x86_64.whl" },
2352
  ]
2353
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2354
  [[package]]
2355
  name = "typing-extensions"
2356
  version = "4.15.0"
 
20
  "pawn",
21
  ]
22
 
23
+ [[package]]
24
+ name = "annotated-doc"
25
+ version = "0.0.4"
26
+ source = { registry = "https://pypi.org/simple" }
27
+ sdist = { url = "https://files.pythonhosted.org/packages/57/ba/046ceea27344560984e26a590f90bc7f4a75b06701f653222458922b558c/annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4", size = 7288, upload-time = "2025-11-10T22:07:42.062Z" }
28
+ wheels = [
29
+ { url = "https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320", size = 5303, upload-time = "2025-11-10T22:07:40.673Z" },
30
+ ]
31
+
32
  [[package]]
33
  name = "annotated-types"
34
  version = "0.7.0"
 
508
  { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" },
509
  ]
510
 
511
+ [[package]]
512
+ name = "hf-xet"
513
+ version = "1.4.2"
514
+ source = { registry = "https://pypi.org/simple" }
515
+ sdist = { url = "https://files.pythonhosted.org/packages/09/08/23c84a26716382c89151b5b447b4beb19e3345f3a93d3b73009a71a57ad3/hf_xet-1.4.2.tar.gz", hash = "sha256:b7457b6b482d9e0743bd116363239b1fa904a5e65deede350fbc0c4ea67c71ea", size = 672357, upload-time = "2026-03-13T06:58:51.077Z" }
516
+ wheels = [
517
+ { url = "https://files.pythonhosted.org/packages/b4/86/b40b83a2ff03ef05c4478d2672b1fc2b9683ff870e2b25f4f3af240f2e7b/hf_xet-1.4.2-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:71f02d6e4cdd07f344f6844845d78518cc7186bd2bc52d37c3b73dc26a3b0bc5", size = 3800339, upload-time = "2026-03-13T06:58:36.245Z" },
518
+ { url = "https://files.pythonhosted.org/packages/64/2e/af4475c32b4378b0e92a587adb1aa3ec53e3450fd3e5fe0372a874531c00/hf_xet-1.4.2-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:e9b38d876e94d4bdcf650778d6ebbaa791dd28de08db9736c43faff06ede1b5a", size = 3559664, upload-time = "2026-03-13T06:58:34.787Z" },
519
+ { url = "https://files.pythonhosted.org/packages/3c/4c/781267da3188db679e601de18112021a5cb16506fe86b246e22c5401a9c4/hf_xet-1.4.2-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:77e8c180b7ef12d8a96739a4e1e558847002afe9ea63b6f6358b2271a8bdda1c", size = 4217422, upload-time = "2026-03-13T06:58:27.472Z" },
520
+ { url = "https://files.pythonhosted.org/packages/68/47/d6cf4a39ecf6c7705f887a46f6ef5c8455b44ad9eb0d391aa7e8a2ff7fea/hf_xet-1.4.2-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:c3b3c6a882016b94b6c210957502ff7877802d0dbda8ad142c8595db8b944271", size = 3992847, upload-time = "2026-03-13T06:58:25.989Z" },
521
+ { url = "https://files.pythonhosted.org/packages/2d/ef/e80815061abff54697239803948abc665c6b1d237102c174f4f7a9a5ffc5/hf_xet-1.4.2-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9d9a634cc929cfbaf2e1a50c0e532ae8c78fa98618426769480c58501e8c8ac2", size = 4193843, upload-time = "2026-03-13T06:58:44.59Z" },
522
+ { url = "https://files.pythonhosted.org/packages/54/75/07f6aa680575d9646c4167db6407c41340cbe2357f5654c4e72a1b01ca14/hf_xet-1.4.2-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6b0932eb8b10317ea78b7da6bab172b17be03bbcd7809383d8d5abd6a2233e04", size = 4432751, upload-time = "2026-03-13T06:58:46.533Z" },
523
+ { url = "https://files.pythonhosted.org/packages/cd/71/193eabd7e7d4b903c4aa983a215509c6114915a5a237525ec562baddb868/hf_xet-1.4.2-cp37-abi3-win_amd64.whl", hash = "sha256:ad185719fb2e8ac26f88c8100562dbf9dbdcc3d9d2add00faa94b5f106aea53f", size = 3671149, upload-time = "2026-03-13T06:58:57.07Z" },
524
+ { url = "https://files.pythonhosted.org/packages/b4/7e/ccf239da366b37ba7f0b36095450efae4a64980bdc7ec2f51354205fdf39/hf_xet-1.4.2-cp37-abi3-win_arm64.whl", hash = "sha256:32c012286b581f783653e718c1862aea5b9eb140631685bb0c5e7012c8719a87", size = 3533426, upload-time = "2026-03-13T06:58:55.46Z" },
525
+ ]
526
+
527
+ [[package]]
528
+ name = "httpcore"
529
+ version = "1.0.9"
530
+ source = { registry = "https://pypi.org/simple" }
531
+ dependencies = [
532
+ { name = "certifi", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
533
+ { name = "h11", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
534
+ ]
535
+ sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" }
536
+ wheels = [
537
+ { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" },
538
+ ]
539
+
540
+ [[package]]
541
+ name = "httpx"
542
+ version = "0.28.1"
543
+ source = { registry = "https://pypi.org/simple" }
544
+ dependencies = [
545
+ { name = "anyio", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
546
+ { name = "certifi", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
547
+ { name = "httpcore", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
548
+ { name = "idna", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
549
+ ]
550
+ sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" }
551
+ wheels = [
552
+ { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" },
553
+ ]
554
+
555
+ [[package]]
556
+ name = "huggingface-hub"
557
+ version = "1.7.2"
558
+ source = { registry = "https://pypi.org/simple" }
559
+ dependencies = [
560
+ { name = "filelock", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
561
+ { name = "fsspec", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
562
+ { name = "hf-xet", marker = "(platform_machine != 'AMD64' and platform_machine != 'aarch64' and platform_machine != 'amd64' and platform_machine != 'arm64' and platform_machine != 'x86_64' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'amd64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
563
+ { name = "httpx", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
564
+ { name = "packaging", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
565
+ { name = "pyyaml", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
566
+ { name = "tqdm", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
567
+ { name = "typer", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
568
+ { name = "typing-extensions", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
569
+ ]
570
+ sdist = { url = "https://files.pythonhosted.org/packages/19/15/eafc1c57bf0f8afffb243dcd4c0cceb785e956acc17bba4d9bf2ae21fc9c/huggingface_hub-1.7.2.tar.gz", hash = "sha256:7f7e294e9bbb822e025bdb2ada025fa4344d978175a7f78e824d86e35f7ab43b", size = 724684, upload-time = "2026-03-20T10:36:08.767Z" }
571
+ wheels = [
572
+ { url = "https://files.pythonhosted.org/packages/08/de/3ad061a05f74728927ded48c90b73521b9a9328c85d841bdefb30e01fb85/huggingface_hub-1.7.2-py3-none-any.whl", hash = "sha256:288f33a0a17b2a73a1359e2a5fd28d1becb2c121748c6173ab8643fb342c850e", size = 618036, upload-time = "2026-03-20T10:36:06.824Z" },
573
+ ]
574
+
575
  [[package]]
576
  name = "humanize"
577
  version = "4.15.0"
 
1395
  version = "0.1.0"
1396
  source = { editable = "." }
1397
  dependencies = [
1398
+ { name = "anywidget", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1399
  { name = "chess-engine", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1400
+ { name = "huggingface-hub", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1401
+ { name = "ipykernel", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1402
+ { name = "matplotlib", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1403
  { name = "numpy", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1404
+ { name = "pandas", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and sys_platform == 'linux') or (python_full_version >= '3.11' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm') or (sys_platform != 'linux' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1405
+ { name = "pandas", version = "3.0.1", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform == 'linux') or (python_full_version < '3.11' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm') or (sys_platform != 'linux' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1406
+ { name = "plotly", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1407
+ { name = "polars", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1408
  { name = "psutil", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1409
+ { name = "pyarrow", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1410
+ { name = "pytest", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1411
  { name = "safetensors", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1412
+ { name = "seaborn", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1413
+ { name = "solara", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1414
  { name = "tqdm", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1415
  { name = "wandb", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
1416
  ]
 
1419
  cu128 = [
1420
  { name = "torch", version = "2.10.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux'" },
1421
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1422
  rocm = [
1423
  { name = "torch", version = "2.10.0+rocm7.1", source = { registry = "https://download.pytorch.org/whl/rocm7.1" }, marker = "sys_platform == 'linux'" },
1424
  { name = "triton-rocm", marker = "sys_platform == 'linux'" },
 
1426
 
1427
  [package.metadata]
1428
  requires-dist = [
1429
+ { name = "anywidget", specifier = ">=0.9.21" },
1430
  { name = "chess-engine", editable = "engine" },
1431
+ { name = "huggingface-hub", specifier = ">=0.20.0" },
1432
+ { name = "ipykernel", specifier = ">=7.2.0" },
1433
+ { name = "matplotlib", specifier = ">=3.10.8" },
1434
  { name = "numpy", specifier = "~=2.2.0" },
1435
+ { name = "pandas", specifier = ">=2.0.0" },
1436
+ { name = "plotly", specifier = ">=5.18.0" },
1437
+ { name = "polars", specifier = ">=1.39.0" },
1438
  { name = "psutil", specifier = ">=5.9.0" },
1439
+ { name = "pyarrow", specifier = ">=23.0.1" },
1440
+ { name = "pytest", specifier = "~=9.0.0" },
1441
  { name = "safetensors", specifier = ">=0.4.0" },
1442
+ { name = "seaborn", specifier = ">=0.13.2" },
1443
+ { name = "solara", specifier = ">=1.0.0" },
1444
  { name = "torch", marker = "extra == 'cu128'", specifier = "~=2.10.0", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "pawn", extra = "cu128" } },
1445
  { name = "torch", marker = "extra == 'rocm'", specifier = "~=2.10.0", index = "https://download.pytorch.org/whl/rocm7.1", conflict = { package = "pawn", extra = "rocm" } },
1446
  { name = "tqdm", specifier = "~=4.67.0" },
1447
  { name = "triton-rocm", marker = "extra == 'rocm'", specifier = ">=3.6.0", index = "https://download.pytorch.org/whl/rocm7.1", conflict = { package = "pawn", extra = "rocm" } },
1448
  { name = "wandb", specifier = "~=0.25.0" },
1449
  ]
1450
+ provides-extras = ["rocm", "cu128"]
1451
 
1452
  [[package]]
1453
  name = "pexpect"
 
2134
  { url = "https://files.pythonhosted.org/packages/9d/76/f789f7a86709c6b087c5a2f52f911838cad707cc613162401badc665acfe/setuptools-82.0.1-py3-none-any.whl", hash = "sha256:a59e362652f08dcd477c78bb6e7bd9d80a7995bc73ce773050228a348ce2e5bb", size = 1006223, upload-time = "2026-03-09T12:47:15.026Z" },
2135
  ]
2136
 
2137
+ [[package]]
2138
+ name = "shellingham"
2139
+ version = "1.5.4"
2140
+ source = { registry = "https://pypi.org/simple" }
2141
+ sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310, upload-time = "2023-10-24T04:13:40.426Z" }
2142
+ wheels = [
2143
+ { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" },
2144
+ ]
2145
+
2146
  [[package]]
2147
  name = "six"
2148
  version = "1.17.0"
 
2429
  { url = "https://download.pytorch.org/whl/triton_rocm-3.6.0-cp312-cp312-linux_x86_64.whl" },
2430
  ]
2431
 
2432
+ [[package]]
2433
+ name = "typer"
2434
+ version = "0.24.1"
2435
+ source = { registry = "https://pypi.org/simple" }
2436
+ dependencies = [
2437
+ { name = "annotated-doc", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
2438
+ { name = "click", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
2439
+ { name = "rich", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
2440
+ { name = "shellingham", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
2441
+ ]
2442
+ sdist = { url = "https://files.pythonhosted.org/packages/f5/24/cb09efec5cc954f7f9b930bf8279447d24618bb6758d4f6adf2574c41780/typer-0.24.1.tar.gz", hash = "sha256:e39b4732d65fbdcde189ae76cf7cd48aeae72919dea1fdfc16593be016256b45", size = 118613, upload-time = "2026-02-21T16:54:40.609Z" }
2443
+ wheels = [
2444
+ { url = "https://files.pythonhosted.org/packages/4a/91/48db081e7a63bb37284f9fbcefda7c44c277b18b0e13fbc36ea2335b71e6/typer-0.24.1-py3-none-any.whl", hash = "sha256:112c1f0ce578bfb4cab9ffdabc68f031416ebcc216536611ba21f04e9aa84c9e", size = 56085, upload-time = "2026-02-21T16:54:41.616Z" },
2445
+ ]
2446
+
2447
  [[package]]
2448
  name = "typing-extensions"
2449
  version = "4.15.0"