thomas-schweich committed
Commit 230508d · unverified · 1 Parent(s): c965110

Safetensors migration, checkpoint integrity, and multi-model training. (#1)

.dockerignore CHANGED
@@ -10,6 +10,7 @@ data/
 checkpoints/
 logs/
 deploy/
+!deploy/entrypoint-run.sh
 *.so
 CLAUDE.md
 docs/
.gitignore CHANGED
@@ -1,5 +1,4 @@
-# Checkpoints and logs
-checkpoints/
+# Local checkpoints and logs
 logs/
 
 # Training data
@@ -7,6 +6,7 @@ data/
 
 # Deployment artifacts
 deploy/pawn-deploy/
+export/
 
 # Python
 __pycache__/
@@ -36,8 +36,11 @@ Thumbs.db
 # uv (track root lockfile only)
 engine/uv.lock
 
+# Local working files
+local/
+
 # Misc
 runpodctl.tar.gz
 .claude/
 .playwright-mcp/
 local/
CLAUDE.md CHANGED
@@ -41,7 +41,7 @@ uv sync --extra dev # + pytest, ipykernel
 uv run pytest tests/
 
 # Pretrain from scratch
-uv run python scripts/train.py --variant base
+uv run python scripts/train.py --variant base --local-checkpoints
 ```
 
 ## Engine (`engine/`)
@@ -86,18 +86,24 @@ All adapters freeze the backbone and initialize to identity (zero-init or gamma=
 ## Scripts (`scripts/`)
 
 - `train.py` -- Pretrain from scratch (`--variant small|base|large|toy`)
+- `train_all.py` -- Train small/base/large simultaneously on shared data batches
 - `train_bottleneck.py`, `train_film.py`, `train_lora.py`, `train_sparse.py`, `train_hybrid.py` -- Adapter behavioral cloning on Lichess PGN
 - `train_tiny.py` -- Standalone tiny transformer baseline (no frozen backbone)
 - `eval_accuracy.py` -- MAIA-compatible evaluation (per-phase, per-ply accuracy)
+- `eval_probes.py` -- Run linear probes on all checkpoints
+- `export_hf_repo.py` -- Convert training run to HuggingFace repo format (safetensors + metrics)
+
+All training scripts require `--hf-repo REPO` or `--local-checkpoints`.
 
 ## Deploy (`deploy/`)
 
-Scripts for Runpod GPU VM deployment:
-- `setup.sh` -- One-time pod setup (Rust, uv, engine build, `uv sync --extra cu128`)
-- `build.sh` -- Package repo for transfer (`--checkpoint PATH`, `--data-dir PATH`)
-- `pod.sh` -- Pod lifecycle (create/start/stop/delete/ssh/deploy/launch)
+Docker-based deployment to Runpod GPU VMs:
+- `Dockerfile` -- Multi-target build: `interactive` (SSH+Jupyter, default) and `runner` (auto-stop)
+- `entrypoint-run.sh` -- Runner entrypoint, pulls from HF via `PAWN_MODEL` env var
+- `sync.sh` -- Pull latest checkpoints/metrics from HuggingFace submodules
+- `pod.sh` -- Pod lifecycle (create/start/stop/delete/ssh)
 
-Pod configs cached in `~/.config/pawn/pods/`. Working directory on pods: `/workspace/pawn/`.
+Code lives at `/opt/pawn` on pods (outside the `/workspace` volume mount).
 
 ## Dashboard (`pawn/dashboard/`)
@@ -117,8 +123,73 @@ Auto-detects run type from config fields (`run_type`, `formulation`, `pgn_file`)
 
 ## Checkpoints
 
-Stored in `checkpoints/` (gitignored). Pre-trained weights downloadable from HuggingFace Hub / GitHub Releases.
+Pre-trained weights are HuggingFace git submodules under `checkpoints/`:
+- `checkpoints/pawn-small` — 9.5M params, `CLMConfig.small()`
+- `checkpoints/pawn-base` — 35.8M params, `CLMConfig.base()`
+- `checkpoints/pawn-large` — 68.4M params, `CLMConfig.large()`
+
+Pull with: `git submodule update --init --remote checkpoints/pawn-base`
+
+### Checkpoint Format (safetensors)
+
+Checkpoints are directories, not single files:
+```
+step_00065000/
+├── model.safetensors       # model weights
+├── optimizer.safetensors   # flattened optimizer state
+├── training_state.json     # step, scheduler, scaler, RNG (base64)
+├── config.json             # model + training config
+└── .complete               # SHA-256 hashes of all files (integrity sentinel)
+```
+
+Central module: `pawn/checkpoint.py`. All save/load goes through this module.
+Legacy `.pt` files are still loadable (backward compatible).
+
+### Checkpoint Storage Modes
+
+All training scripts require one of:
+- `--hf-repo REPO_ID` — push checkpoints to a HuggingFace branch as they're written (durable)
+- `--local-checkpoints` — save locally only (for development without an HF account)
+
+HF mode creates a `run/{run_id}` branch. Squash-merge into main when satisfied.
+
+### Data Integrity
+
+**Every checkpoint write is atomic**: files are written to a `.tmp` directory, then renamed.
+The `.complete` sentinel contains SHA-256 hashes of every file in the checkpoint.
+**Hashes are always verified on load — no exceptions.**
+
+- `IncompleteCheckpointError` — raised when the `.complete` sentinel is missing
+- `CheckpointIntegrityError` — raised when any hash mismatches
+
+**Never use `kill -9` on training processes.** SIGTERM is handled gracefully: a flag is set,
+the training loop checks it between steps, saves a checkpoint, pushes to HF, and exits cleanly.
+
+**Never rsync checkpoint files from running pods.** Checkpoints are pushed to HuggingFace
+from the trainer. Pull via `deploy/sync.sh` (submodule update).
 
 ## Logs
 
 Training metrics in `logs/` (gitignored). Each run gets a timestamped directory with `metrics.jsonl`.
+
+## Runpod Pod Management
+
+### Setup
+
+- Docker image: multi-target build in `Dockerfile`
+  - `interactive` (default) — SSH + Jupyter, stays alive
+  - `runner` — executes command then exits (pod auto-stops)
+- Build: `docker build --target runner --build-arg GIT_HASH=$(git rev-parse HEAD) ...`
+
+### Required Configuration
+
+- **Always attach a network volume.** Checkpoints write to disk during atomic rename and HF push. The ephemeral container disk is lost on pod termination.
+- **Set `HF_TOKEN` as a pod environment variable** for automatic HuggingFace authentication.
+- Set `PAWN_MODEL=thomas-schweich/pawn-base` env var in the runner to auto-pull a checkpoint on startup.
+
+### Lifecycle
+
+- Create: `runpodctl pod create --name pawn-exp --gpu-id "NVIDIA RTX A5000" --image thomasschweich/pawn:<tag> --volume-in-gb 75 --ports "8888/http,22/tcp"`
+- Stop: `runpodctl pod stop <ID>` — sends SIGTERM → trainer saves and pushes before exiting
+- **Never `runpodctl pod delete` while training is running** — data loss risk
+- Monitor: pull the HF submodule (`deploy/sync.sh`) and read `metrics.jsonl`
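The integrity scheme in the Data Integrity section above (a `.complete` sentinel holding SHA-256 hashes of every file, always verified on load) can be sketched in Python. This is an illustrative sketch only: the sentinel's on-disk layout (here, a JSON map of filename to hex digest) and the helper names `write_sentinel`/`verify_checkpoint` are assumptions, not the actual `pawn/checkpoint.py` API.

```python
import hashlib
import json
from pathlib import Path

class IncompleteCheckpointError(Exception):
    """Raised when the .complete sentinel is missing."""

class CheckpointIntegrityError(Exception):
    """Raised when a file's hash does not match the sentinel."""

def sha256_file(path: Path) -> str:
    # Stream the file in 1 MiB chunks so large safetensors don't fill RAM.
    h = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

def write_sentinel(ckpt_dir: Path) -> None:
    # Hash every file in the checkpoint directory, then write the sentinel last,
    # so a sentinel only exists once all other files are fully on disk.
    hashes = {
        p.name: sha256_file(p)
        for p in sorted(ckpt_dir.iterdir())
        if p.name != ".complete"
    }
    (ckpt_dir / ".complete").write_text(json.dumps(hashes, indent=2))

def verify_checkpoint(ckpt_dir: Path) -> None:
    sentinel = ckpt_dir / ".complete"
    if not sentinel.exists():
        raise IncompleteCheckpointError(f"{ckpt_dir} has no .complete sentinel")
    for name, expected in json.loads(sentinel.read_text()).items():
        actual = sha256_file(ckpt_dir / name)
        if actual != expected:
            raise CheckpointIntegrityError(f"{name}: expected {expected}, got {actual}")
```

Writing the sentinel last is what makes it a useful completeness marker: a crash mid-save leaves a directory without `.complete`, which loaders then reject.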
Dockerfile CHANGED
@@ -3,17 +3,27 @@
 # Extends the official Runpod PyTorch template — SSH and JupyterLab
 # start automatically via the base image's /start.sh entrypoint.
 #
+# Build targets:
+#   interactive (default) — SSH + Jupyter, stays alive
+#   runner — runs a command then exits (pod auto-stops)
+#
 # Build:
 #   docker build --platform linux/amd64 \
 #     --build-arg GIT_HASH=$(git rev-parse HEAD) \
 #     --build-arg GIT_TAG=$(git tag --points-at HEAD) \
+#     [--target runner] \
 #     -t pawn:<tag> .
 #
-# Runpod BYOC:
-#   Push to a registry, then set as the container image in a Pod template.
-#   Configure HTTP port 8888 (Jupyter) and TCP port 22 (SSH).
-#   Mount a network volume at /workspace for data, checkpoints, and logs.
-#   Code lives at /opt/pawn (outside the volume mount).
+# Run (interactive):
+#   docker run --gpus all pawn:<tag>
+#
+# IMPORTANT: Always attach a Runpod network volume. Checkpoints use
+# atomic directory writes (tmp + rename) that require persistent disk.
+# Set HF_TOKEN as a pod env var for HuggingFace checkpoint push.
+#
+# Run (auto-stop):
+#   docker run --gpus all -e PAWN_MODEL=thomas-schweich/pawn-base \
+#     pawn:<tag>-runner python scripts/train.py --variant base
 
 # ── Builder ──────────────────────────────────────────────────────────
 FROM runpod/pytorch:1.0.2-cu1281-torch280-ubuntu2404 AS builder
@@ -43,14 +53,14 @@ RUN cd engine && \
     uv run --no-project --with maturin maturin build --release && \
     cd ..
 
-# ── Runtime ──────────────────────────────────────────────────────────
-FROM runpod/pytorch:1.0.2-cu1281-torch280-ubuntu2404
+# ── Runtime base (shared by all targets) ─────────────────────────────
+FROM runpod/pytorch:1.0.2-cu1281-torch280-ubuntu2404 AS runtime-base
 
 ENV PYTHONUNBUFFERED=1 \
     PYTHONPATH=/opt/pawn
 
 # Direct deps only (torch + numpy already in base image)
-RUN pip install --no-cache-dir psutil tqdm wandb
+RUN pip install --no-cache-dir psutil safetensors tqdm wandb huggingface-hub
 
 COPY --from=builder /workspace/pawn/engine/target/wheels/*.whl /tmp/
 RUN pip install --no-cache-dir /tmp/*.whl && rm -rf /tmp/*.whl
@@ -72,3 +82,13 @@ RUN echo "export PYTHONPATH=/opt/pawn" >> /etc/environment && \
     echo "export PAWN_GIT_HASH=${GIT_HASH}" >> /etc/environment && \
     echo "export PAWN_GIT_TAG=${GIT_TAG}" >> /etc/environment && \
     cat /etc/environment >> /root/.bashrc
+
+# ── Runner — executes command then exits (pod auto-stops) ────────────
+FROM runtime-base AS runner
+COPY deploy/entrypoint-run.sh /entrypoint-run.sh
+RUN chmod +x /entrypoint-run.sh
+ENTRYPOINT ["/entrypoint-run.sh"]
+
+# ── Interactive (default) — SSH + Jupyter, stays alive ───────────────
+FROM runtime-base AS interactive
+# Inherits /start.sh entrypoint from Runpod base image
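The "atomic directory writes (tmp + rename)" noted in the Dockerfile comment can be illustrated with a small Python sketch. This is hypothetical, not the actual `pawn/checkpoint.py` code; the helper name `atomic_save` and its signature are invented for illustration, and it assumes the destination directory does not already exist.

```python
import os
import shutil
from pathlib import Path

def atomic_save(ckpt_dir: Path, files: dict) -> None:
    """Write every file into a temporary sibling directory, then rename.

    os.rename is atomic on POSIX when source and destination share a
    filesystem, so a reader either sees the complete checkpoint directory
    or no directory at all — never a half-written one. This is why the
    checkpoint directory must live on a real persistent volume.
    """
    tmp = ckpt_dir.parent / (ckpt_dir.name + ".tmp")
    if tmp.exists():
        shutil.rmtree(tmp)  # stale leftover from an earlier crash
    tmp.mkdir(parents=True)
    for name, data in files.items():
        (tmp / name).write_bytes(data)
    os.rename(tmp, ckpt_dir)  # atomic: the directory appears fully formed
```

A crash before the final `os.rename` leaves only a `.tmp` directory behind, which the next save simply discards.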
README.md CHANGED
@@ -32,7 +32,9 @@ Notably, the vocabulary includes impossible moves like `a1a1` and `b1a5`. PAWN n
 
 Conceptually, each token is best thought of as a move in UCI notation--they are effectively coordinates. They do not include any information on the type of piece, side to play, or any direct geometric or board state information other than the factored nature of the embeddings (see the architecture section below for details).
 
-For example, `e2e4` is the token that represents the king's pawn opening, but only when it's the first ply in the sequence (moving a rook from e2 to e4 in the late game would use the same token). The model learns to track which type of piece is on each square at any given moment entirely of its own accord. For that matter, it isn't even told what piece types exist and what movement patterns they follow, or indeed even the concept of a piece. All of that 'understanding' comes purely from observation.
+For example, `e2e4` is the token that represents the king's pawn opening, but only when it's the first ply in the sequence (moving a rook from e2 to e4 in the late game would use the same token). The model learns to track which type of piece is on each square at any given moment entirely of its own accord.
+
+For that matter, it isn't told what piece types exist, what movement patterns they follow, or indeed the concept of a piece. All of that 'understanding' comes purely from observation and can be isolated via [linear probes](https://arxiv.org/abs/1610.01644) (Alain & Bengio, 2016).
 
 ## Quickstart
 
@@ -50,13 +52,17 @@ uv sync --extra cu128 # NVIDIA GPU (or --extra rocm for AMD)
 uv sync --extra dev --extra eval --extra dashboard
 
 # Train an adapter on a pre-trained checkpoint
+git submodule update --init checkpoints/pawn-base
 uv run python scripts/train_bottleneck.py \
-  --checkpoint checkpoints/pawn-base.pt \
+  --checkpoint checkpoints/pawn-base \
   --pgn data/lichess_1800_1900.pgn \
-  --bottleneck-dim 32 --lr 1e-4
+  --bottleneck-dim 32 --lr 1e-4 --local-checkpoints
 
 # Or pretrain a PAWN variant from scratch (generates random games on-the-fly; no dataset required)
-uv run python scripts/train.py --variant base
+uv run python scripts/train.py --variant base --local-checkpoints
+
+# Or train all three variants simultaneously on shared data
+uv run python scripts/train_all.py --local-checkpoints
 
 # Launch the real-time monitoring dashboard (optional dashboard dependency must be installed)
 uv run python -m pawn.dashboard --log-dir logs --port 8765
checkpoints/pawn-base CHANGED
@@ -1 +1 @@
-Subproject commit d09eee93b2368af24e8c6d1f1d64c399a4a22964
+Subproject commit 45aa6e347ca9516662874238adaeef4d30fc6df8
checkpoints/pawn-large CHANGED
@@ -1 +1 @@
-Subproject commit 4076e45f4772228f98d68095c012202e1054bd3c
+Subproject commit 1d3f1deee3411e86d69a814520553fcf78f96c5f
checkpoints/pawn-small CHANGED
@@ -1 +1 @@
-Subproject commit ac274df81dd414044dcb21e869a010d6bd5e4c24
+Subproject commit 929fee43ce16e02cf12a255cf39cc3f26e0913e1
deploy/entrypoint-run.sh ADDED
@@ -0,0 +1,19 @@
+#!/bin/bash
+# Entrypoint for auto-stop runner target.
+# Optionally pulls a checkpoint from HuggingFace before running the command.
+set -e
+
+cd /opt/pawn
+export PYTHONPATH=/opt/pawn
+
+# Pull checkpoint from HuggingFace if PAWN_MODEL is set
+if [ -n "$PAWN_MODEL" ]; then
+    echo "Pulling model: $PAWN_MODEL"
+    python3 -c "
+from huggingface_hub import snapshot_download
+snapshot_download('$PAWN_MODEL', local_dir='checkpoints/model')
+print('Model downloaded to checkpoints/model/')
+"
+fi
+
+exec "$@"
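When a runner pod is stopped, the training command launched by this entrypoint receives SIGTERM. The graceful-shutdown pattern CLAUDE.md describes (the handler only sets a flag; the training loop checks it between steps, then saves and exits) looks roughly like this in Python. A hypothetical sketch, not the project's actual trainer code; `GracefulShutdown` and `train` are invented names.

```python
import signal

class GracefulShutdown:
    """Set a flag on SIGTERM; the training loop polls it between steps."""

    def __init__(self) -> None:
        self.requested = False
        signal.signal(signal.SIGTERM, self._handle)

    def _handle(self, signum, frame) -> None:
        # Do no real work in the handler itself — just record the request.
        self.requested = True

def train(max_steps: int, shutdown: GracefulShutdown) -> int:
    step = 0
    while step < max_steps:
        step += 1  # one optimizer step would happen here
        if shutdown.requested:
            # Save a checkpoint, push to HF, then exit cleanly.
            break
    return step
```

Because the flag is only checked between steps, a checkpoint is never written mid-step, which is also why `kill -9` (which skips the handler entirely) is forbidden.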
deploy/sync.sh CHANGED
@@ -1,65 +1,36 @@
 #!/usr/bin/env bash
-# Sync logs and checkpoints from Runpod pod(s) to local machine.
-# Usage: bash deploy/sync.sh [pod-name]
+# Pull latest checkpoints and metrics from HuggingFace submodules.
+# Usage: bash deploy/sync.sh [submodule-name]
 #
-# With no args, syncs from all pods in ~/.config/pawn/pods/
-# With a pod name, syncs from that specific pod only.
+# With no args, pulls all checkpoint submodules.
+# With a name, pulls only that submodule.
 set -euo pipefail
 
 REPO="$(cd "$(dirname "$0")/.." && pwd)"
-POD_DIR="$HOME/.config/pawn/pods"
 
-sync_pod() {
-  local name="$1" host="$2" port="$3" remote_root="${4:-/opt/pawn}"
-  local ssh_opts="-o StrictHostKeyChecking=accept-new -o ConnectTimeout=10 -p $port"
-
-  echo "=== Syncing from $name ($host:$port) ==="
-
-  echo "--- Logs ---"
-  rsync -avz --progress --no-owner --no-group \
-    -e "ssh $ssh_opts" \
-    "root@$host:$remote_root/logs/" "$REPO/logs/" || {
-    echo "  Failed to sync logs from $name"
-    return 1
-  }
-
-  echo "--- Checkpoints (best.pt only) ---"
-  rsync -avz --progress --no-owner --no-group \
-    --include='*/' --include='best.pt' --include='*.pt' --exclude='*' \
-    -e "ssh $ssh_opts" \
-    "root@$host:$remote_root/logs/" "$REPO/logs/" || {
-    echo "  Failed to sync checkpoints from $name"
-    return 1
-  }
-
-  echo "=== $name sync complete ==="
+pull_submodule() {
+  local sub="$1"
+  local name="$(basename "$sub")"
+  echo "=== Pulling $name ==="
+  git -C "$sub" fetch origin 2>/dev/null || { echo "  Failed to fetch $name"; return 1; }
+  git -C "$sub" pull origin main 2>/dev/null || { echo "  Failed to pull $name (main)"; return 1; }
+  echo "=== $name up to date ==="
   echo ""
 }
 
-if [ ! -d "$POD_DIR" ] || [ -z "$(ls "$POD_DIR"/*.env 2>/dev/null)" ]; then
-  echo "No pods configured. Add .env files to $POD_DIR/"
-  exit 1
-fi
-
 if [ $# -ge 1 ]; then
-  # Sync specific pod
-  POD_FILE="$POD_DIR/$1.env"
-  if [ ! -f "$POD_FILE" ]; then
-    echo "Pod '$1' not found. Available pods:"
-    ls "$POD_DIR"/*.env 2>/dev/null | xargs -I{} basename {} .env | sed 's/^/  /'
+  sub="$REPO/checkpoints/$1"
+  if [ ! -d "$sub/.git" ]; then
+    echo "Submodule '$1' not found. Available:"
+    for s in "$REPO"/checkpoints/pawn-*/; do
+      [ -d "$s/.git" ] && echo "  $(basename "$s")"
+    done
     exit 1
   fi
-  source "$POD_FILE"
-  sync_pod "$1" "$POD_HOST" "$POD_PORT" "${POD_REMOTE_ROOT:-/opt/pawn}"
+  pull_submodule "$sub"
 else
-  # Sync all pods
-  for pod_file in "$POD_DIR"/*.env; do
-    pod_name="$(basename "${pod_file%.env}")"
-    unset POD_HOST POD_PORT POD_REMOTE_ROOT
-    source "$pod_file"
-    sync_pod "$pod_name" "$POD_HOST" "$POD_PORT" "${POD_REMOTE_ROOT:-/opt/pawn}" || true
+  for sub in "$REPO"/checkpoints/pawn-*/; do
+    [ -d "$sub/.git" ] || continue
+    pull_submodule "$sub" || true
   done
 fi
-
-echo "Local logs:"
-du -sh "$REPO/logs/"
engine/python/chess_engine/__init__.py CHANGED
@@ -8,6 +8,7 @@ from chess_engine._engine import (
     # Core game generation
     generate_training_batch,
    generate_random_games,
+    generate_clm_batch,
    generate_checkmate_games,
    generate_checkmate_training_batch,
    # Diagnostic sets
@@ -37,6 +38,7 @@ from chess_engine._engine import (
 __all__ = [
    "generate_training_batch",
    "generate_random_games",
+    "generate_clm_batch",
    "generate_checkmate_games",
    "generate_checkmate_training_batch",
    "generate_diagnostic_sets",
engine/src/batch.rs CHANGED
@@ -49,23 +49,18 @@ pub fn generate_training_batch(batch_size: usize, max_ply: usize, seed: u64) ->
49
  game_lengths.push(record.game_length as i16);
50
  termination_codes.push(record.termination.as_u8());
51
 
52
- // Copy move_ids with padding
53
  for t in 0..length {
54
  move_ids[b * max_ply + t] = record.move_ids[t] as i16;
55
  }
56
- // Place EOG token at position game_length
57
- if length < max_ply {
58
- move_ids[b * max_ply + length] = vocab::EOG_TOKEN as i16;
59
- }
60
 
61
- // Copy legal move grids
62
  for t in 0..length {
63
  let grid_offset = (b * max_ply + t) * 64;
64
  debug_assert_eq!(record.legal_grids[t].len(), 64);
65
  legal_move_grid[grid_offset..grid_offset + 64]
66
  .copy_from_slice(&record.legal_grids[t]);
67
  }
68
- // EOG position: all zeros (already initialized)
69
 
70
  // Copy promotion masks (contiguous layout: [[bool; 4]; 44] = [bool; 176])
71
  for t in 0..length {
@@ -108,9 +103,6 @@ pub fn generate_random_games(n_games: usize, max_ply: usize, seed: u64) -> GameB
108
  for t in 0..(*length as usize) {
109
  move_ids[b * max_ply + t] = moves[t] as i16;
110
  }
111
- if (*length as usize) < max_ply {
112
- move_ids[b * max_ply + *length as usize] = vocab::EOG_TOKEN as i16;
113
- }
114
  }
115
 
116
  GameBatch {
@@ -152,9 +144,6 @@ pub fn generate_checkmate_training_batch(
152
  for t in 0..(ex.game_length as usize).min(max_ply) {
153
  move_ids[b * max_ply + t] = ex.move_ids[t] as i16;
154
  }
155
- if (ex.game_length as usize) < max_ply {
156
- move_ids[b * max_ply + ex.game_length as usize] = vocab::EOG_TOKEN as i16;
157
- }
158
  checkmate_targets[b * 64..(b + 1) * 64].copy_from_slice(&ex.checkmate_grid);
159
  legal_grids[b * 64..(b + 1) * 64].copy_from_slice(&ex.legal_grid);
160
  }
@@ -206,9 +195,6 @@ pub fn generate_completed_games(n_games: usize, max_ply: usize, seed: u64) -> Ga
206
  for t in 0..(*length as usize) {
207
  move_ids[b * max_ply + t] = moves[t] as i16;
208
  }
209
- if (*length as usize) < max_ply {
210
- move_ids[b * max_ply + *length as usize] = vocab::EOG_TOKEN as i16;
211
- }
212
  }
213
 
214
  GameBatch {
@@ -281,9 +267,6 @@ pub fn generate_checkmate_games(
281
  for t in 0..(*length as usize) {
282
  move_ids[b * max_ply + t] = moves[t] as i16;
283
  }
284
- if (*length as usize) < max_ply {
285
- move_ids[b * max_ply + *length as usize] = vocab::EOG_TOKEN as i16;
286
- }
287
  }
288
 
289
  (GameBatch {
@@ -295,6 +278,97 @@ pub fn generate_checkmate_games(
295
  }, total_generated)
296
  }
297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  #[cfg(test)]
299
  mod tests {
300
  use super::*;
@@ -321,15 +395,23 @@ mod tests {
321
  }
322
 
323
  #[test]
324
- fn test_eog_token_placement() {
325
  let batch = generate_training_batch(2, 256, 42);
326
  for b in 0..2 {
327
  let len = batch.game_lengths[b] as usize;
328
  if len < 256 {
329
  assert_eq!(
330
  batch.move_ids[b * 256 + len],
331
- vocab::EOG_TOKEN as i16,
332
- "EOG token should be at position game_length"
 
 
 
 
 
 
 
 
333
  );
334
  }
335
  }
@@ -343,4 +425,90 @@ mod tests {
343
  assert_eq!(b1.game_lengths, b2.game_lengths);
344
  assert_eq!(b1.legal_move_grid, b2.legal_move_grid);
345
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
346
  }
 
49
  game_lengths.push(record.game_length as i16);
50
  termination_codes.push(record.termination.as_u8());
51
 
52
+ // Copy move_ids (remaining positions are already 0 = PAD)
53
  for t in 0..length {
54
  move_ids[b * max_ply + t] = record.move_ids[t] as i16;
55
  }
 
 
 
 
56
 
57
+ // Copy legal move grids (positions beyond game_length are already 0)
58
  for t in 0..length {
59
  let grid_offset = (b * max_ply + t) * 64;
60
  debug_assert_eq!(record.legal_grids[t].len(), 64);
61
  legal_move_grid[grid_offset..grid_offset + 64]
62
  .copy_from_slice(&record.legal_grids[t]);
63
  }
 
64
 
65
  // Copy promotion masks (contiguous layout: [[bool; 4]; 44] = [bool; 176])
66
  for t in 0..length {
 
103
  for t in 0..(*length as usize) {
104
  move_ids[b * max_ply + t] = moves[t] as i16;
105
  }
 
 
 
106
  }
107
 
108
  GameBatch {
 
144
  for t in 0..(ex.game_length as usize).min(max_ply) {
145
  move_ids[b * max_ply + t] = ex.move_ids[t] as i16;
146
  }
 
 
 
147
  checkmate_targets[b * 64..(b + 1) * 64].copy_from_slice(&ex.checkmate_grid);
148
  legal_grids[b * 64..(b + 1) * 64].copy_from_slice(&ex.legal_grid);
149
  }
 
195
  for t in 0..(*length as usize) {
196
  move_ids[b * max_ply + t] = moves[t] as i16;
197
  }
 
 
 
198
  }
199
 
200
  GameBatch {
 
267
  for t in 0..(*length as usize) {
268
  move_ids[b * max_ply + t] = moves[t] as i16;
269
  }
 
 
 
270
  }
271
 
272
  (GameBatch {
 
278
  }, total_generated)
279
  }
280
 
281
+ /// Output of CLM (Causal Language Model) batch generation.
282
+ ///
283
+ /// Contains ready-to-train tensors in the format:
284
+ /// input_ids = [outcome, ply_1, ply_2, ..., ply_N, PAD, ..., PAD]
285
+ /// targets = [ply_1, ply_2, ply_3, ..., PAD, PAD, ..., PAD]
286
+ /// loss_mask = [true, true, true, ..., true, false, ..., false]
287
+ ///
288
+ /// Also includes raw move_ids and game_lengths for replay operations
289
+ /// (legal mask computation, board state extraction, validation).
290
+ pub struct CLMBatch {
291
+ pub input_ids: Vec<i16>, // [batch_size * seq_len]
292
+ pub targets: Vec<i16>, // [batch_size * seq_len]
293
+ pub loss_mask: Vec<bool>, // [batch_size * seq_len]
294
+ pub move_ids: Vec<i16>, // [batch_size * max_ply] raw for replay
295
+ pub game_lengths: Vec<i16>, // [batch_size]
296
+ pub termination_codes: Vec<u8>, // [batch_size]
297
+ pub batch_size: usize,
298
+ pub seq_len: usize,
299
+ pub max_ply: usize,
300
+ }
301
+
302
+ /// Generate a CLM training batch: random games packed into model-ready format.
303
+ ///
304
+ /// `seq_len` is the total sequence length (256). Games are generated with up to
305
+ /// `seq_len - 1` plies, leaving position 0 for the outcome token.
306
+ pub fn generate_clm_batch(
307
+ batch_size: usize,
308
+ seq_len: usize,
309
+ seed: u64,
310
+ discard_ply_limit: bool,
311
+ ) -> CLMBatch {
312
+ let max_ply = seq_len - 1;
313
+
314
+ let game_batch = if discard_ply_limit {
315
+ generate_completed_games(batch_size, max_ply, seed)
316
+ } else {
317
+ generate_random_games(batch_size, max_ply, seed)
318
+ };
319
+
320
+ let mut input_ids = vec![0i16; batch_size * seq_len];
321
+ let mut targets = vec![0i16; batch_size * seq_len];
322
+ let mut loss_mask = vec![false; batch_size * seq_len];
323
+
324
+ for b in 0..batch_size {
325
+ let gl = game_batch.game_lengths[b] as usize;
326
+ let term = match game_batch.termination_codes[b] {
327
+ 0 => Termination::Checkmate,
328
+ 1 => Termination::Stalemate,
329
+ 2 => Termination::SeventyFiveMoveRule,
330
+ 3 => Termination::FivefoldRepetition,
331
+ 4 => Termination::InsufficientMaterial,
332
+ _ => Termination::PlyLimit,
333
+ };
334
+ let outcome = vocab::termination_to_outcome(term, game_batch.game_lengths[b] as u16);
335
+
336
+ let row = b * seq_len;
337
+
338
+ // Position 0: outcome token
339
+ input_ids[row] = outcome as i16;
340
+
341
+ // Positions 1..=gl: move tokens
342
+ for t in 0..gl {
343
+ input_ids[row + 1 + t] = game_batch.move_ids[b * max_ply + t];
344
+ }
345
+ // Remaining positions are already 0 (PAD)
346
+
347
+ // Targets: input_ids shifted left by 1
348
+ for t in 0..(seq_len - 1) {
349
+ targets[row + t] = input_ids[row + t + 1];
350
+ }
351
+ // targets[row + seq_len - 1] is already 0
352
+
353
+ // Loss mask: positions 0..=gl are true
354
+ for t in 0..=gl {
355
+ loss_mask[row + t] = true;
356
+ }
357
+ }
358
+
359
+ CLMBatch {
360
+ input_ids,
361
+ targets,
362
+ loss_mask,
363
+ move_ids: game_batch.move_ids,
364
+ game_lengths: game_batch.game_lengths,
365
+ termination_codes: game_batch.termination_codes,
366
+ batch_size,
367
+ seq_len,
368
+ max_ply,
369
+ }
370
+ }
371
+
372
  #[cfg(test)]
373
  mod tests {
374
  use super::*;
 
395
  }
396
 
397
  #[test]
398
+ fn test_pad_after_game_end() {
399
  let batch = generate_training_batch(2, 256, 42);
400
  for b in 0..2 {
401
  let len = batch.game_lengths[b] as usize;
402
  if len < 256 {
403
  assert_eq!(
404
  batch.move_ids[b * 256 + len],
405
+ vocab::PAD_TOKEN as i16,
406
+ "Position game_length should be PAD (0)"
407
+ );
408
+ }
409
+ // All positions after game_length should also be PAD
410
+ for t in len..256 {
411
+ assert_eq!(
412
+ batch.move_ids[b * 256 + t],
413
+ 0,
414
+ "Position {} (after game_length={}) should be PAD", t, len
415
  );
416
  }
417
  }
 
425
         assert_eq!(b1.game_lengths, b2.game_lengths);
         assert_eq!(b1.legal_move_grid, b2.legal_move_grid);
     }
+
+     #[test]
+     fn test_clm_batch_format() {
+         let seq_len = 256;
+         let batch = generate_clm_batch(8, seq_len, 42, false);
+         assert_eq!(batch.input_ids.len(), 8 * seq_len);
+         assert_eq!(batch.targets.len(), 8 * seq_len);
+         assert_eq!(batch.loss_mask.len(), 8 * seq_len);
+         assert_eq!(batch.move_ids.len(), 8 * (seq_len - 1));
+         assert_eq!(batch.game_lengths.len(), 8);
+
+         for b in 0..8 {
+             let gl = batch.game_lengths[b] as usize;
+             let row = b * seq_len;
+
+             // Position 0: outcome token (4273-4277)
+             let outcome = batch.input_ids[row];
+             assert!(outcome >= vocab::OUTCOME_BASE as i16 && outcome <= vocab::PLY_LIMIT as i16,
+                 "Position 0 should be outcome token, got {}", outcome);
+
+             // Positions 1..=gl: move tokens (1-4272)
+             for t in 1..=gl {
+                 let tok = batch.input_ids[row + t];
+                 assert!(tok >= 1 && tok <= 4272,
+                     "Position {} should be move token, got {}", t, tok);
+             }
+
+             // Positions gl+1..seq_len: PAD (0)
+             for t in (gl + 1)..seq_len {
+                 assert_eq!(batch.input_ids[row + t], 0,
+                     "Position {} should be PAD, got {}", t, batch.input_ids[row + t]);
+             }
+
+             // Targets: shifted left by 1
+             for t in 0..(seq_len - 1) {
+                 assert_eq!(batch.targets[row + t], batch.input_ids[row + t + 1],
+                     "targets[{}] should equal input_ids[{}]", t, t + 1);
+             }
+             assert_eq!(batch.targets[row + seq_len - 1], 0, "Last target should be PAD");
+
+             // Target at position gl is PAD (end of game)
+             assert_eq!(batch.targets[row + gl], 0, "Target at game_length should be PAD");
+
+             // Loss mask: true for 0..=gl, false after
+             for t in 0..=gl {
+                 assert!(batch.loss_mask[row + t],
+                     "loss_mask[{}] should be true (gl={})", t, gl);
+             }
+             for t in (gl + 1)..seq_len {
+                 assert!(!batch.loss_mask[row + t],
+                     "loss_mask[{}] should be false (gl={})", t, gl);
+             }
+         }
+     }
+
+     #[test]
+     fn test_clm_batch_deterministic() {
+         let b1 = generate_clm_batch(4, 256, 99, false);
+         let b2 = generate_clm_batch(4, 256, 99, false);
+         assert_eq!(b1.input_ids, b2.input_ids);
+         assert_eq!(b1.targets, b2.targets);
+         assert_eq!(b1.loss_mask, b2.loss_mask);
+         assert_eq!(b1.game_lengths, b2.game_lengths);
+     }
+
+     #[test]
+     fn test_clm_batch_outcome_correctness() {
+         let batch = generate_clm_batch(32, 256, 42, false);
+         for b in 0..32 {
+             let gl = batch.game_lengths[b] as usize;
+             let tc = batch.termination_codes[b];
+             let expected = vocab::termination_to_outcome(
+                 match tc {
+                     0 => Termination::Checkmate,
+                     1 => Termination::Stalemate,
+                     2 => Termination::SeventyFiveMoveRule,
+                     3 => Termination::FivefoldRepetition,
+                     4 => Termination::InsufficientMaterial,
+                     _ => Termination::PlyLimit,
+                 },
+                 gl as u16,
+             );
+             assert_eq!(batch.input_ids[b * 256] as u16, expected,
+                 "Game {} outcome mismatch: tc={}, gl={}", b, tc, gl);
+         }
+     }
 }
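The batch invariants these tests pin down (outcome token at position 0, left-shifted targets, loss mask covering positions 0..=gl) can be sketched independently with numpy. This is a toy illustration of the layout, not the engine's API; the token values are stand-ins:

```python
import numpy as np

seq_len, gl = 16, 5  # toy sizes; real batches use seq_len=256
OUTCOME, PAD = 4273, 0

# input_ids = [outcome, ply_1..ply_gl, PAD, ...]
input_ids = np.zeros(seq_len, dtype=np.int16)
input_ids[0] = OUTCOME
input_ids[1 : gl + 1] = np.arange(1, gl + 1)  # stand-in move tokens

# targets are input_ids shifted left by one; the last slot is PAD
targets = np.roll(input_ids, -1)
targets[-1] = PAD

# loss is applied on positions 0..=gl (outcome prediction + all moves)
loss_mask = np.arange(seq_len) <= gl

assert targets[gl] == PAD  # end-of-game target is PAD
assert loss_mask[: gl + 1].all() and not loss_mask[gl + 1 :].any()
```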
engine/src/board.rs CHANGED
@@ -74,7 +74,7 @@ pub fn move_to_token(m: &Move) -> u16 {
  /// Convert our token index to a shakmaty Move, given the current position.
  /// Finds the legal move matching the token's (src, dst, promo) decomposition.
  pub fn token_to_move(pos: &Chess, token: u16) -> Option<Move> {
-     // Validate the token is decomposable (not PAD/EOG)
+     // Validate the token is decomposable (not PAD/outcome)
      vocab::decompose_token(token)?;
      let legal = pos.legal_moves();
engine/src/lib.rs CHANGED
@@ -169,6 +169,46 @@ fn generate_random_games<'py>(
      Ok((move_ids, game_lengths, termination_codes))
  }

+ /// Generate a CLM training batch with model-ready tensors.
+ ///
+ /// Returns (input_ids, targets, loss_mask, move_ids, game_lengths, term_codes).
+ /// input_ids = [outcome, ply_1, ..., ply_N, PAD, ...] (seq_len per row).
+ /// move_ids are the raw moves (seq_len-1 per row) for replay operations.
+ #[pyfunction]
+ #[pyo3(signature = (batch_size, seq_len=256, seed=42, discard_ply_limit=false))]
+ fn generate_clm_batch<'py>(
+     py: Python<'py>,
+     batch_size: usize,
+     seq_len: usize,
+     seed: u64,
+     discard_ply_limit: bool,
+ ) -> PyResult<(
+     Bound<'py, PyArray2<i16>>,  // input_ids (B, seq_len)
+     Bound<'py, PyArray2<i16>>,  // targets (B, seq_len)
+     Bound<'py, PyArray2<bool>>, // loss_mask (B, seq_len)
+     Bound<'py, PyArray2<i16>>,  // move_ids (B, seq_len-1)
+     Bound<'py, PyArray1<i16>>,  // game_lengths (B,)
+     Bound<'py, PyArray1<u8>>,   // termination_codes (B,)
+ )> {
+     let result = py.allow_threads(|| {
+         batch::generate_clm_batch(batch_size, seq_len, seed, discard_ply_limit)
+     });
+
+     let max_ply = seq_len - 1;
+     let input_ids = numpy::PyArray::from_vec(py, result.input_ids)
+         .reshape([batch_size, seq_len])?;
+     let targets = numpy::PyArray::from_vec(py, result.targets)
+         .reshape([batch_size, seq_len])?;
+     let loss_mask = numpy::PyArray::from_vec(py, result.loss_mask)
+         .reshape([batch_size, seq_len])?;
+     let move_ids = numpy::PyArray::from_vec(py, result.move_ids)
+         .reshape([batch_size, max_ply])?;
+     let game_lengths = numpy::PyArray::from_vec(py, result.game_lengths);
+     let termination_codes = numpy::PyArray::from_vec(py, result.termination_codes);
+
+     Ok((input_ids, targets, loss_mask, move_ids, game_lengths, termination_codes))
+ }
+
  /// Compute legal move masks by replaying games. Spec §7.4.
  #[pyfunction]
  fn compute_legal_move_masks<'py>(
@@ -868,6 +908,7 @@ fn _engine(m: &Bound<'_, PyModule>) -> PyResult<()> {
      m.add_function(wrap_pyfunction!(export_move_vocabulary, m)?)?;
      m.add_function(wrap_pyfunction!(generate_training_batch, m)?)?;
      m.add_function(wrap_pyfunction!(generate_random_games, m)?)?;
+     m.add_function(wrap_pyfunction!(generate_clm_batch, m)?)?;
      m.add_function(wrap_pyfunction!(generate_checkmate_games, m)?)?;
      m.add_function(wrap_pyfunction!(generate_checkmate_training_batch, m)?)?;
      m.add_function(wrap_pyfunction!(compute_legal_move_masks, m)?)?;
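From Python, the binding yields six numpy arrays. A hedged sketch of the consumer's view (the `pawn._engine` import path is an assumption based on the `_engine` module name above; the reshape below mirrors what the binding itself does to the flat `Vec`s, using stand-in buffers so it runs without the compiled extension):

```python
import numpy as np

# Hypothetical call — requires the compiled extension:
#   from pawn import _engine
#   (input_ids, targets, loss_mask,
#    move_ids, game_lengths, term_codes) = _engine.generate_clm_batch(8)

# Shape contract, mirrored with stand-in flat buffers:
B, seq_len = 8, 256
input_ids = np.zeros(B * seq_len, dtype=np.int16).reshape(B, seq_len)
loss_mask = np.zeros(B * seq_len, dtype=bool).reshape(B, seq_len)
move_ids = np.zeros(B * (seq_len - 1), dtype=np.int16).reshape(B, seq_len - 1)

assert input_ids.shape == (8, 256)
assert move_ids.shape == (8, 255)  # one fewer column: no outcome slot
```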
engine/src/pgn.rs CHANGED
@@ -62,9 +62,6 @@ pub fn batch_san_to_tokens(
          for (t, &tok) in tokens.iter().enumerate() {
              flat[gi * max_ply + t] = tok as i16;
          }
-         if n_valid < max_ply {
-             flat[gi * max_ply + n_valid] = crate::vocab::EOG_TOKEN as i16;
-         }
          lengths.push(n_valid as i16);
      }

@@ -195,9 +192,6 @@ pub fn pgn_file_to_tokens(
          for (t, &tok) in tokens.iter().enumerate() {
              flat[gi * max_ply + t] = tok as i16;
          }
-         if *n_valid < max_ply {
-             flat[gi * max_ply + n_valid] = crate::vocab::EOG_TOKEN as i16;
-         }
          lengths.push(*n_valid as i16);
      }
engine/src/vocab.rs CHANGED
@@ -1,10 +1,10 @@
  //! Move vocabulary: the single source of truth for token ↔ UCI string mapping.
  //!
- //! Token layout (4,274 total):
+ //! Token layout (4,278 total):
  //! 0 = padding
  //! 1..=4096 = base grid (64×64 src×dst pairs)
  //! 4097..=4272 = promotions (44 eligible pairs × 4 piece types)
- //! 4273 = end-of-game
+ //! 4273..=4277 = outcome tokens (game result)
  //!
  //! Square indexing: file-major within rank.
  //! a1=0, b1=1, ..., h1=7, a2=8, ..., h8=63
@@ -12,9 +12,10 @@

  use std::collections::HashMap;

- pub const VOCAB_SIZE: usize = 4274;
+ use crate::types::Termination;
+
+ pub const VOCAB_SIZE: usize = 4278;
  pub const PAD_TOKEN: u16 = 0;
- pub const EOG_TOKEN: u16 = 4273;
  pub const BASE_GRID_START: u16 = 1;
  pub const BASE_GRID_END: u16 = 4096; // inclusive
  pub const PROMO_START: u16 = 4097;
@@ -22,6 +23,14 @@ pub const PROMO_END: u16 = 4272; // inclusive
  pub const NUM_PROMO_PAIRS: usize = 44;
  pub const NUM_PROMO_TYPES: usize = 4;

+ // Outcome tokens — must match pawn/config.py
+ pub const OUTCOME_BASE: u16 = 4273;
+ pub const WHITE_CHECKMATES: u16 = 4273;
+ pub const BLACK_CHECKMATES: u16 = 4274;
+ pub const STALEMATE: u16 = 4275;
+ pub const DRAW_BY_RULE: u16 = 4276;
+ pub const PLY_LIMIT: u16 = 4277;
+
  /// Square names in our index order.
  pub const SQUARE_NAMES: [&str; 64] = [
      "a1", "b1", "c1", "d1", "e1", "f1", "g1", "h1",
@@ -131,11 +140,29 @@ pub fn promo_token(src: u8, dst: u8, promo_type: u8) -> Option<u16> {
      Some(PROMO_START + (pair_idx as u16) * 4 + (promo_type as u16))
  }

+ /// Map a game termination reason to the corresponding outcome token.
+ ///
+ /// For checkmate, the winner is determined by game length parity:
+ /// - Odd game_length (white made the last move) → WHITE_CHECKMATES
+ /// - Even game_length (black made the last move) → BLACK_CHECKMATES
+ pub fn termination_to_outcome(term: Termination, game_length: u16) -> u16 {
+     match term {
+         Termination::Checkmate => {
+             if game_length % 2 == 1 { WHITE_CHECKMATES } else { BLACK_CHECKMATES }
+         }
+         Termination::Stalemate => STALEMATE,
+         Termination::SeventyFiveMoveRule
+         | Termination::FivefoldRepetition
+         | Termination::InsufficientMaterial => DRAW_BY_RULE,
+         Termination::PlyLimit => PLY_LIMIT,
+     }
+ }
+
  /// Decompose a token into (src_square, dst_square, promo_type).
  /// promo_type: 0=none, 1=q, 2=r, 3=b, 4=n (matches embedding index).
- /// Returns None for PAD and EOG tokens.
+ /// Returns None for PAD and outcome tokens.
  pub fn decompose_token(token: u16) -> Option<(u8, u8, u8)> {
-     if token == PAD_TOKEN || token == EOG_TOKEN {
+     if token == PAD_TOKEN || token >= OUTCOME_BASE {
          return None;
      }
      if token >= BASE_GRID_START && token <= BASE_GRID_END {
@@ -285,8 +312,32 @@ mod tests {
      }

      #[test]
-     fn test_pad_eog_decompose() {
+     fn test_pad_outcome_decompose() {
          assert!(decompose_token(PAD_TOKEN).is_none());
-         assert!(decompose_token(EOG_TOKEN).is_none());
+         // All 5 outcome tokens should return None
+         for token in OUTCOME_BASE..=PLY_LIMIT {
+             assert!(decompose_token(token).is_none(),
+                 "outcome token {} should not decompose", token);
+         }
+     }
+
+     #[test]
+     fn test_termination_to_outcome() {
+         use crate::types::Termination;
+
+         // Checkmate with odd game_length = white wins
+         assert_eq!(termination_to_outcome(Termination::Checkmate, 11), WHITE_CHECKMATES);
+         assert_eq!(termination_to_outcome(Termination::Checkmate, 1), WHITE_CHECKMATES);
+
+         // Checkmate with even game_length = black wins
+         assert_eq!(termination_to_outcome(Termination::Checkmate, 12), BLACK_CHECKMATES);
+         assert_eq!(termination_to_outcome(Termination::Checkmate, 2), BLACK_CHECKMATES);
+
+         // Other terminations
+         assert_eq!(termination_to_outcome(Termination::Stalemate, 50), STALEMATE);
+         assert_eq!(termination_to_outcome(Termination::SeventyFiveMoveRule, 100), DRAW_BY_RULE);
+         assert_eq!(termination_to_outcome(Termination::FivefoldRepetition, 80), DRAW_BY_RULE);
+         assert_eq!(termination_to_outcome(Termination::InsufficientMaterial, 60), DRAW_BY_RULE);
+         assert_eq!(termination_to_outcome(Termination::PlyLimit, 255), PLY_LIMIT);
      }
  }
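The parity rule in `termination_to_outcome` can be mirrored in a few lines of Python (a hedged sketch for illustration; the string termination codes are hypothetical, not the names used in `pawn/config.py`):

```python
WHITE_CHECKMATES, BLACK_CHECKMATES = 4273, 4274
STALEMATE, DRAW_BY_RULE, PLY_LIMIT = 4275, 4276, 4277

def termination_to_outcome(term: str, game_length: int) -> int:
    """Mirror of the Rust mapping; `term` uses stand-in string codes."""
    if term == "checkmate":
        # Odd game_length: white made the last (mating) move.
        return WHITE_CHECKMATES if game_length % 2 == 1 else BLACK_CHECKMATES
    if term == "stalemate":
        return STALEMATE
    if term in ("seventy_five_move_rule", "fivefold_repetition",
                "insufficient_material"):
        return DRAW_BY_RULE
    return PLY_LIMIT  # ply-limit fallback

assert termination_to_outcome("checkmate", 11) == WHITE_CHECKMATES
assert termination_to_outcome("checkmate", 12) == BLACK_CHECKMATES
```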
pawn/checkpoint.py ADDED
@@ -0,0 +1,647 @@
+ """Checkpoint save/load using safetensors + JSON.
+
+ Replaces monolithic torch.save() .pt files with directory-based checkpoints:
+ - model.safetensors / adapter.safetensors — tensor data
+ - optimizer.safetensors — flattened optimizer state tensors
+ - training_state.json — scalars, scheduler, scaler, RNG, optimizer metadata
+ - config.json — model and training configuration
+ - .complete — SHA-256 hashes of all files (integrity sentinel)
+
+ Writes are atomic: files are written to a .tmp directory, then renamed.
+ Loads always verify the .complete sentinel and SHA-256 hashes.
+
+ Backward compatible: all load functions transparently handle legacy .pt files.
+ """
+
+ from __future__ import annotations
+
+ import base64
+ import hashlib
+ import json
+ import os
+ import shutil
+ from pathlib import Path
+
+ import torch
+ import torch.nn as nn
+ from safetensors.torch import save_file, load_file
+
+ CHECKPOINT_FORMAT_VERSION = 1
+ LEGACY_EXTENSIONS = {".pt", ".pth"}
+
+
+ # ---------------------------------------------------------------------------
+ # Exceptions
+ # ---------------------------------------------------------------------------
+
+ class IncompleteCheckpointError(Exception):
+     """Raised when a checkpoint directory is missing its .complete sentinel."""
+     pass
+
+
+ class CheckpointIntegrityError(Exception):
+     """Raised when a checkpoint file's SHA-256 hash doesn't match .complete."""
+     pass
+
+
+ # ---------------------------------------------------------------------------
+ # SHA-256 helpers
+ # ---------------------------------------------------------------------------
+
+ def _sha256_file(path: Path) -> str:
+     """Compute SHA-256 hex digest of a file."""
+     h = hashlib.sha256()
+     with open(path, "rb") as f:
+         for chunk in iter(lambda: f.read(1 << 20), b""):  # 1MB chunks
+             h.update(chunk)
+     return h.hexdigest()
+
+
+ def _write_complete_sentinel(directory: Path) -> None:
+     """Write .complete sentinel with SHA-256 hashes of all checkpoint files."""
+     hashes = {}
+     for f in sorted(directory.iterdir()):
+         if f.name == ".complete" or f.is_dir():
+             continue
+         hashes[f.name] = _sha256_file(f)
+
+     sentinel = {"format_version": CHECKPOINT_FORMAT_VERSION, "files": hashes}
+     with open(directory / ".complete", "w") as f:
+         json.dump(sentinel, f, indent=2)
+
+
+ def _verify_complete_sentinel(directory: Path) -> None:
+     """Verify .complete sentinel exists and all hashes match.
+
+     Raises IncompleteCheckpointError if sentinel is missing.
+     Raises CheckpointIntegrityError if any hash mismatches.
+     """
+     sentinel_path = directory / ".complete"
+     if not sentinel_path.exists():
+         raise IncompleteCheckpointError(
+             f"Checkpoint {directory} is missing .complete sentinel — "
+             f"likely a partial write from a crashed or interrupted save."
+         )
+
+     with open(sentinel_path) as f:
+         sentinel = json.load(f)
+
+     for filename, expected_hash in sentinel["files"].items():
+         filepath = directory / filename
+         if not filepath.exists():
+             raise CheckpointIntegrityError(
+                 f"File {filename} listed in .complete but missing from {directory}"
+             )
+         actual_hash = _sha256_file(filepath)
+         if actual_hash != expected_hash:
+             raise CheckpointIntegrityError(
+                 f"SHA-256 mismatch for {filename} in {directory}: "
+                 f"expected {expected_hash[:16]}..., got {actual_hash[:16]}..."
+             )
+
+
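The sentinel scheme above can be exercised end to end with plain stdlib code. This is a self-contained sketch that re-implements a minimal write/verify pair for illustration (the helper names here are hypothetical, not imports from `pawn.checkpoint`):

```python
import hashlib
import json
import tempfile
from pathlib import Path

def write_sentinel(d: Path) -> None:
    # Hash every regular file, then record the digests in .complete
    files = {f.name: hashlib.sha256(f.read_bytes()).hexdigest()
             for f in sorted(d.iterdir())
             if f.is_file() and f.name != ".complete"}
    (d / ".complete").write_text(json.dumps({"format_version": 1, "files": files}))

def verify_sentinel(d: Path) -> bool:
    sentinel = json.loads((d / ".complete").read_text())
    return all(hashlib.sha256((d / name).read_bytes()).hexdigest() == digest
               for name, digest in sentinel["files"].items())

d = Path(tempfile.mkdtemp())
(d / "model.safetensors").write_bytes(b"\x00" * 16)
write_sentinel(d)
ok_before = verify_sentinel(d)                        # hashes match
(d / "model.safetensors").write_bytes(b"\x01" * 16)   # simulate corruption
ok_after = verify_sentinel(d)                         # SHA-256 mismatch detected
```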
+ # ---------------------------------------------------------------------------
+ # Atomic directory write
+ # ---------------------------------------------------------------------------
+
+ def _atomic_directory_write(target: Path):
+     """Context manager for atomic directory writes.
+
+     Usage:
+         with _atomic_directory_write(Path("step_00001")) as tmp:
+             save_file(tensors, tmp / "model.safetensors")
+             ...
+         # Directory is now at step_00001/ with .complete sentinel
+     """
+     class _AtomicDir:
+         def __init__(self, target: Path):
+             self.target = target
+             self.tmp = target.parent / f"{target.name}.tmp"
+
+         def __enter__(self) -> Path:
+             # Clean up any leftover .tmp from a previous crash
+             if self.tmp.exists():
+                 shutil.rmtree(self.tmp)
+             self.tmp.mkdir(parents=True, exist_ok=True)
+             return self.tmp
+
+         def __exit__(self, exc_type, exc_val, exc_tb):
+             if exc_type is not None:
+                 # Error during write — clean up temp dir
+                 if self.tmp.exists():
+                     shutil.rmtree(self.tmp)
+                 return False
+
+             # Write .complete sentinel with hashes
+             _write_complete_sentinel(self.tmp)
+
+             # Atomic rename (same filesystem)
+             if self.target.exists():
+                 shutil.rmtree(self.target)
+             os.rename(self.tmp, self.target)
+             return False
+
+     return _AtomicDir(target)
+
+
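The write-to-.tmp-then-rename discipline can be demonstrated in isolation (stdlib-only sketch; `os.rename` is atomic only within a single filesystem, which is why the `.tmp` sibling lives next to the target rather than in a system temp directory):

```python
import os
import shutil
import tempfile
from pathlib import Path

root = Path(tempfile.mkdtemp())
target = root / "step_00001"
tmp = root / "step_00001.tmp"

# 1. Stage all files in the sibling .tmp directory
tmp.mkdir()
(tmp / "model.safetensors").write_bytes(b"weights")

# 2. A single rename publishes the checkpoint; readers never see it half-written
if target.exists():
    shutil.rmtree(target)
os.rename(tmp, target)

assert (target / "model.safetensors").read_bytes() == b"weights"
assert not tmp.exists()
```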
+ # ---------------------------------------------------------------------------
+ # JSON helpers
+ # ---------------------------------------------------------------------------
+
+ def _json_default(obj):
+     """JSON serializer for types not natively supported."""
+     if isinstance(obj, torch.Tensor):
+         return obj.item() if obj.numel() == 1 else obj.tolist()
+     if hasattr(obj, "item"):  # numpy scalar
+         return obj.item()
+     if isinstance(obj, Path):
+         return str(obj)
+     return str(obj)
+
+
+ # ---------------------------------------------------------------------------
+ # RNG state serialization
+ # ---------------------------------------------------------------------------
+
+ def _rng_to_json(
+     torch_rng: torch.Tensor | None, cuda_rng: torch.Tensor | None
+ ) -> dict:
+     data = {}
+     if torch_rng is not None:
+         data["torch_rng_state"] = base64.b64encode(
+             torch_rng.numpy().tobytes()
+         ).decode("ascii")
+     if cuda_rng is not None:
+         data["cuda_rng_state"] = base64.b64encode(
+             cuda_rng.numpy().tobytes()
+         ).decode("ascii")
+     return data
+
+
+ def _json_to_rng(data: dict) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+     torch_rng = None
+     if "torch_rng_state" in data:
+         raw = base64.b64decode(data["torch_rng_state"])
+         torch_rng = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
+     cuda_rng = None
+     if "cuda_rng_state" in data:
+         raw = base64.b64decode(data["cuda_rng_state"])
+         cuda_rng = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
+     return torch_rng, cuda_rng
+
+
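The base64 transport used for the RNG bytes is a plain round trip through JSON; a minimal stdlib sketch (raw bytes stand in for `torch.get_rng_state().numpy().tobytes()` so it runs without torch):

```python
import base64
import json

rng_bytes = bytes(range(8))  # stand-in for the uint8 RNG state tensor
payload = json.dumps(
    {"torch_rng_state": base64.b64encode(rng_bytes).decode("ascii")}
)

# Decoding recovers the state byte-for-byte after the JSON round trip
decoded = base64.b64decode(json.loads(payload)["torch_rng_state"])
assert decoded == rng_bytes
```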
+ # ---------------------------------------------------------------------------
+ # Optimizer state flattening for safetensors
+ # ---------------------------------------------------------------------------
+
+ def _flatten_optimizer_state(
+     opt_state_dict: dict,
+ ) -> tuple[dict[str, torch.Tensor], dict]:
+     """Flatten optimizer state into safetensors-compatible tensors + JSON metadata."""
+     tensors: dict[str, torch.Tensor] = {}
+     scalars: dict[str, float | int] = {}
+
+     for param_id, param_state in opt_state_dict["state"].items():
+         for key, val in param_state.items():
+             flat_key = f"state.{param_id}.{key}"
+             if isinstance(val, torch.Tensor):
+                 t = val.cpu().contiguous()
+                 if t.ndim == 0:
+                     t = t.unsqueeze(0)
+                 tensors[flat_key] = t
+             else:
+                 scalars[flat_key] = val
+
+     meta = {
+         "param_groups": opt_state_dict["param_groups"],
+         "scalars": scalars if scalars else None,
+     }
+     return tensors, meta
+
+
+ def _unflatten_optimizer_state(
+     tensors: dict[str, torch.Tensor],
+     meta: dict,
+     device: str = "cpu",
+ ) -> dict:
+     """Reconstruct optimizer state_dict from flattened tensors + metadata."""
+     state: dict[int, dict[str, torch.Tensor | float | int]] = {}
+     scalars = meta.get("scalars") or {}
+
+     for flat_key, val in tensors.items():
+         parts = flat_key.split(".", 2)
+         if len(parts) != 3 or parts[0] != "state":
+             continue
+         param_id = int(parts[1])
+         key = parts[2]
+         if param_id not in state:
+             state[param_id] = {}
+         t = val.to(device)
+         if t.shape == (1,) and key == "step":
+             t = t.squeeze(0)
+         state[param_id][key] = t
+
+     for flat_key, val in scalars.items():
+         parts = flat_key.split(".", 2)
+         if len(parts) != 3:
+             continue
+         param_id = int(parts[1])
+         key = parts[2]
+         if param_id not in state:
+             state[param_id] = {}
+         state[param_id][key] = val
+
+     return {
+         "state": state,
+         "param_groups": meta["param_groups"],
+     }
+
+
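The `state.{param_id}.{key}` naming can be illustrated without torch. This plain-dict sketch mirrors the flatten/unflatten round trip (lists stand in for tensors; the helper names are hypothetical):

```python
def flatten(opt_state: dict) -> dict:
    # "state.<param_id>.<key>" -> tensor-like value
    return {f"state.{pid}.{key}": val
            for pid, pstate in opt_state["state"].items()
            for key, val in pstate.items()}

def unflatten(flat: dict, param_groups: list) -> dict:
    state: dict = {}
    for flat_key, val in flat.items():
        # maxsplit=2 keeps any later dots inside the key
        _, pid, key = flat_key.split(".", 2)
        state.setdefault(int(pid), {})[key] = val
    return {"state": state, "param_groups": param_groups}

opt = {"state": {0: {"exp_avg": [0.1, 0.2], "step": 7}},
       "param_groups": [{"lr": 3e-4}]}
flat = flatten(opt)
assert "state.0.exp_avg" in flat
assert unflatten(flat, opt["param_groups"]) == opt  # lossless round trip
```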
+ # ---------------------------------------------------------------------------
+ # Legacy checkpoint detection
+ # ---------------------------------------------------------------------------
+
+ def is_legacy_checkpoint(path: str | Path) -> bool:
+     """True if path is a single .pt/.pth file (legacy format)."""
+     p = Path(path)
+     return p.is_file() and p.suffix in LEGACY_EXTENSIONS
+
+
+ def _load_legacy_pt(path: str | Path, device: str = "cpu") -> dict:
+     return torch.load(str(path), map_location=device, weights_only=False)
+
+
+ # ---------------------------------------------------------------------------
+ # Pretrain checkpoint save/load
+ # ---------------------------------------------------------------------------
+
+ def save_pretrain_checkpoint(
+     path: str | Path,
+     model: nn.Module,
+     optimizer: torch.optim.Optimizer,
+     scheduler,
+     scaler,
+     global_step: int,
+     model_config: dict,
+     training_config: dict,
+ ) -> None:
+     """Save a pretraining checkpoint atomically.
+
+     Writes to {path}.tmp/, then renames to {path}/ with .complete sentinel.
+     """
+     path = Path(path)
+
+     with _atomic_directory_write(path) as tmp:
+         # 1. Model weights
+         state_dict = {k: v.cpu().contiguous() for k, v in model.state_dict().items()}
+         save_file(state_dict, tmp / "model.safetensors")
+
+         # 2. Optimizer tensors
+         opt_tensors, opt_meta = _flatten_optimizer_state(optimizer.state_dict())
+         if opt_tensors:
+             save_file(opt_tensors, tmp / "optimizer.safetensors")
+
+         # 3. Training state (JSON)
+         training_state = {
+             "format_version": CHECKPOINT_FORMAT_VERSION,
+             "global_step": global_step,
+             "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None,
+             "scaler_state_dict": scaler.state_dict() if scaler is not None else None,
+             "optimizer_meta": opt_meta,
+             **_rng_to_json(
+                 torch.get_rng_state(),
+                 torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
+             ),
+         }
+         with open(tmp / "training_state.json", "w") as f:
+             json.dump(training_state, f, indent=2, default=_json_default)
+
+         # 4. Config
+         config = {
+             "format_version": CHECKPOINT_FORMAT_VERSION,
+             "checkpoint_type": "pretrain",
+             "model_config": model_config,
+             "training_config": training_config,
+         }
+         with open(tmp / "config.json", "w") as f:
+             json.dump(config, f, indent=2, default=_json_default)
+
+
+ def load_pretrain_checkpoint(
+     path: str | Path,
+     model: nn.Module,
+     optimizer: torch.optim.Optimizer | None = None,
+     scheduler=None,
+     scaler=None,
+     device: str = "cpu",
+ ) -> dict:
+     """Load a pretraining checkpoint with integrity verification.
+
+     Handles both legacy .pt files and new directory format.
+     New format: verifies .complete sentinel and SHA-256 hashes.
+     """
+     path = Path(path)
+
+     if is_legacy_checkpoint(path):
+         ckpt = _load_legacy_pt(path, device)
+         model.load_state_dict(ckpt["model_state_dict"])
+         if optimizer and "optimizer_state_dict" in ckpt:
+             optimizer.load_state_dict(ckpt["optimizer_state_dict"])
+         if scheduler and "scheduler_state_dict" in ckpt:
+             scheduler.load_state_dict(ckpt["scheduler_state_dict"])
+         if scaler and "scaler_state_dict" in ckpt:
+             scaler.load_state_dict(ckpt["scaler_state_dict"])
+         if ckpt.get("torch_rng_state") is not None:
+             torch.set_rng_state(ckpt["torch_rng_state"].cpu().byte())
+         if ckpt.get("cuda_rng_state") is not None and torch.cuda.is_available():
+             torch.cuda.set_rng_state(ckpt["cuda_rng_state"].cpu().byte())
+         return {
+             "global_step": ckpt.get("global_step", 0),
+             "model_config": ckpt.get("model_config"),
+             "training_config": ckpt.get("training_config"),
+         }
+
+     # New directory format — verify integrity first
+     _verify_complete_sentinel(path)
+
+     weights = load_file(path / "model.safetensors", device=device)
+     model.load_state_dict(weights)
+
+     with open(path / "training_state.json") as f:
+         ts = json.load(f)
+
+     if optimizer and (path / "optimizer.safetensors").exists():
+         opt_tensors = load_file(path / "optimizer.safetensors", device=device)
+         opt_state = _unflatten_optimizer_state(opt_tensors, ts["optimizer_meta"], device)
+         optimizer.load_state_dict(opt_state)
+
+     if scheduler and "scheduler_state_dict" in ts:
+         scheduler.load_state_dict(ts["scheduler_state_dict"])
+
+     if scaler and "scaler_state_dict" in ts:
+         scaler.load_state_dict(ts["scaler_state_dict"])
+
+     torch_rng, cuda_rng = _json_to_rng(ts)
+     if torch_rng is not None:
+         torch.set_rng_state(torch_rng.cpu().byte())
+     if cuda_rng is not None and torch.cuda.is_available():
+         torch.cuda.set_rng_state(cuda_rng.cpu().byte())
+
+     with open(path / "config.json") as f:
+         config = json.load(f)
+
+     return {
+         "global_step": ts.get("global_step", 0),
+         "model_config": config.get("model_config"),
+         "training_config": config.get("training_config"),
+     }
+
+
+ # ---------------------------------------------------------------------------
+ # Adapter checkpoint save/load
+ # ---------------------------------------------------------------------------
+
+ def save_adapter_checkpoint(
+     path: str | Path,
+     adapter_state_dict: dict[str, torch.Tensor],
+     config: dict,
+     epoch: int,
+     step: int,
+     val_metrics: dict,
+     optimizer: torch.optim.Optimizer | None = None,
+     scheduler=None,
+     scaler=None,
+     extra: dict | None = None,
+ ) -> None:
+     """Save an adapter checkpoint atomically."""
+     path = Path(path)
+
+     with _atomic_directory_write(path) as tmp:
+         # 1. Adapter weights
+         tensors = {k: v.cpu().contiguous() for k, v in adapter_state_dict.items()}
+         save_file(tensors, tmp / "adapter.safetensors")
+
+         # 2. Optimizer tensors (if provided)
+         opt_meta = None
+         if optimizer is not None:
+             opt_tensors, opt_meta = _flatten_optimizer_state(optimizer.state_dict())
+             if opt_tensors:
+                 save_file(opt_tensors, tmp / "optimizer.safetensors")
+
+         # 3. Training state
+         training_state: dict = {
+             "format_version": CHECKPOINT_FORMAT_VERSION,
+             "epoch": epoch,
+             "step": step,
+             "val_metrics": val_metrics,
+         }
+         if opt_meta is not None:
+             training_state["optimizer_meta"] = opt_meta
+         if scheduler is not None:
+             training_state["scheduler_state_dict"] = scheduler.state_dict()
+         if scaler is not None:
+             training_state["scaler_state_dict"] = scaler.state_dict()
+         if extra:
+             training_state.update(extra)
+
+         with open(tmp / "training_state.json", "w") as f:
+             json.dump(training_state, f, indent=2, default=_json_default)
+
+         # 4. Config
+         with open(tmp / "config.json", "w") as f:
+             json.dump(config, f, indent=2, default=_json_default)
+
+
+ def load_adapter_checkpoint(
+     path: str | Path,
+     device: str = "cpu",
+ ) -> dict:
+     """Load an adapter checkpoint with integrity verification."""
+     path = Path(path)
+
+     if is_legacy_checkpoint(path):
+         ckpt = _load_legacy_pt(path, device)
+         adapter_key = None
+         for key in ("lora_state_dict", "bottleneck_state_dict", "film_state_dict",
+                     "sparse_state_dict", "adapter_state_dict", "model_state_dict"):
+             if key in ckpt:
+                 adapter_key = key
+                 break
+         return {
+             "adapter_state_dict": ckpt.get(adapter_key, {}),
+             "config": ckpt.get("config", {}),
+             "epoch": ckpt.get("epoch", 0),
+             "step": ckpt.get("step", 0),
+             "val_metrics": {
+                 "loss": ckpt.get("val_loss"),
+                 "top1_accuracy": ckpt.get("val_top1"),
+             },
+             "optimizer_state_dict": ckpt.get("optimizer_state_dict"),
+             "scheduler_state_dict": ckpt.get("scheduler_state_dict"),
+             "scaler_state_dict": ckpt.get("scaler_state_dict"),
+             "best_val_loss": ckpt.get("best_val_loss"),
+             "patience_counter": ckpt.get("patience_counter"),
+         }
+
+     # New directory format — verify integrity first
+     _verify_complete_sentinel(path)
+
+     adapter_weights = load_file(path / "adapter.safetensors", device=device)
+
+     with open(path / "config.json") as f:
+         config = json.load(f)
+
+     ts = {}
+     ts_path = path / "training_state.json"
+     if ts_path.exists():
+         with open(ts_path) as f:
+             ts = json.load(f)
+
+     result: dict = {
+         "adapter_state_dict": adapter_weights,
+         "config": config,
+         "epoch": ts.get("epoch", 0),
+         "step": ts.get("step", 0),
+         "val_metrics": ts.get("val_metrics", {}),
+         "best_val_loss": ts.get("best_val_loss"),
+         "patience_counter": ts.get("patience_counter"),
+     }
+
+     if (path / "optimizer.safetensors").exists() and "optimizer_meta" in ts:
+         opt_tensors = load_file(path / "optimizer.safetensors", device=device)
+         result["optimizer_state_dict"] = _unflatten_optimizer_state(
+             opt_tensors, ts["optimizer_meta"], device
+         )
+
+     if "scheduler_state_dict" in ts:
+         result["scheduler_state_dict"] = ts["scheduler_state_dict"]
+
+     if "scaler_state_dict" in ts:
+         result["scaler_state_dict"] = ts["scaler_state_dict"]
+
+     return result
+
+
+ # ---------------------------------------------------------------------------
+ # Backbone-only loading (inference)
+ # ---------------------------------------------------------------------------
+
+ def load_backbone_weights(
+     path: str | Path,
+     device: str = "cpu",
+ ) -> tuple[dict[str, torch.Tensor], dict | None]:
+     """Load model weights and config for inference with integrity verification.
+
+     Works with:
+     - Legacy .pt files (extracts model_state_dict + model_config)
+     - New checkpoint directories (reads model.safetensors + config.json, verifies .complete)
+     - Bare model.safetensors files (no .complete check — used for HF downloads)
+
+     Returns (state_dict, model_config_dict_or_None).
+     """
+     path = Path(path)
+
+     if is_legacy_checkpoint(path):
+         ckpt = _load_legacy_pt(path, device)
+         return ckpt["model_state_dict"], ckpt.get("model_config")
+
+     # Directory with model.safetensors
+     if path.is_dir():
+         sf_path = path / "model.safetensors"
+         if not sf_path.exists():
+             raise FileNotFoundError(f"No model.safetensors in {path}")
+         # Verify integrity if .complete exists (new format checkpoints)
+         if (path / ".complete").exists():
+             _verify_complete_sentinel(path)
+         weights = load_file(sf_path, device=device)
+         config = None
+         config_path = path / "config.json"
+         if config_path.exists():
+             with open(config_path) as f:
+                 config = json.load(f).get("model_config")
+         return weights, config
+
+     # Bare safetensors file
+     if path.suffix == ".safetensors":
+         weights = load_file(path, device=device)
+         config = None
+         config_path = path.parent / "config.json"
+         if config_path.exists():
+             with open(config_path) as f:
+                 config = json.load(f).get("model_config")
+         return weights, config
+
+     raise ValueError(f"Unrecognized checkpoint format: {path}")
+
+
+ # ---------------------------------------------------------------------------
+ # HuggingFace push
579
+ # ---------------------------------------------------------------------------
580
+
581
+ def push_checkpoint_to_hf(
582
+ checkpoint_path: str | Path,
583
+ repo_id: str,
584
+ branch: str,
585
+ metrics_path: str | Path | None = None,
586
+ step: int = 0,
587
+ ) -> None:
588
+ """Push a complete checkpoint directory to a HuggingFace repo branch.
589
+
590
+ Uploads checkpoint files to checkpoints/step_NNNN/ on the branch.
591
+ Optionally uploads metrics.jsonl (truncated to current step) to the root.
592
+
593
+ Requires HF_TOKEN environment variable or prior `huggingface_hub.login()`.
594
+ """
595
+ from huggingface_hub import HfApi
596
+
597
+ checkpoint_path = Path(checkpoint_path)
598
+ api = HfApi()
599
+
600
+ # Ensure branch exists
601
+ try:
602
+ api.create_branch(repo_id, repo_type="model", branch=branch, exist_ok=True)
603
+ except Exception:
604
 + pass  # Non-fatal: exist_ok already tolerates an existing branch; ignore transient API errors
605
+
606
+ # Upload checkpoint directory
607
+ api.upload_folder(
608
+ folder_path=str(checkpoint_path),
609
+ path_in_repo=f"checkpoints/{checkpoint_path.name}",
610
+ repo_id=repo_id,
611
+ repo_type="model",
612
+ revision=branch,
613
+ commit_message=f"Checkpoint step {step}",
614
+ )
615
+
616
+ # Upload truncated metrics.jsonl to repo root
617
+ if metrics_path is not None:
618
+ metrics_path = Path(metrics_path)
619
+ if metrics_path.exists():
620
+ import tempfile
621
+ # Truncate metrics to current step
622
+ truncated_lines = []
623
+ with open(metrics_path) as f:
624
+ for line in f:
625
+ truncated_lines.append(line)
626
+ try:
627
+ record = json.loads(line)
628
+ if record.get("type") in ("train", "val") and record.get("step", 0) >= step:
629
+ break
630
+ except json.JSONDecodeError:
631
+ continue
632
+
633
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as tmp:
634
+ tmp.writelines(truncated_lines)
635
+ tmp_path = tmp.name
636
+
637
+ try:
638
+ api.upload_file(
639
+ path_or_fileobj=tmp_path,
640
+ path_in_repo="metrics.jsonl",
641
+ repo_id=repo_id,
642
+ repo_type="model",
643
+ revision=branch,
644
+ commit_message=f"Metrics through step {step}",
645
+ )
646
+ finally:
647
+ os.unlink(tmp_path)
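The `.complete` sentinel checked by `_verify_complete_sentinel` above is the core of the integrity story: write every checkpoint file first, write the sentinel last, and refuse to load anything whose sentinel is missing or stale. A self-contained sketch of that pattern, with hypothetical helper names and a size+hash manifest (the real implementation in `pawn/checkpoint.py` may record different metadata):

```python
import hashlib
import json
import tempfile
from pathlib import Path

def write_with_sentinel(ckpt_dir: Path, files: dict[str, bytes]) -> None:
    """Write checkpoint files, then a .complete sentinel recording each file's
    size and sha256. The sentinel is written last, so its presence implies
    every file landed on disk intact."""
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    meta = {}
    for name, payload in files.items():
        (ckpt_dir / name).write_bytes(payload)
        meta[name] = {
            "size": len(payload),
            "sha256": hashlib.sha256(payload).hexdigest(),
        }
    (ckpt_dir / ".complete").write_text(json.dumps(meta))

def verify_complete_sentinel(ckpt_dir: Path) -> None:
    """Raise if the sentinel is missing or any recorded file was truncated
    or modified after the checkpoint was written."""
    sentinel = ckpt_dir / ".complete"
    if not sentinel.exists():
        raise RuntimeError(f"Incomplete checkpoint (no .complete): {ckpt_dir}")
    for name, info in json.loads(sentinel.read_text()).items():
        payload = (ckpt_dir / name).read_bytes()
        if (len(payload) != info["size"]
                or hashlib.sha256(payload).hexdigest() != info["sha256"]):
            raise RuntimeError(f"Corrupt checkpoint file: {name}")

d = Path(tempfile.mkdtemp()) / "step_00000100"
write_with_sentinel(d, {"model.safetensors": b"weights", "config.json": b"{}"})
verify_complete_sentinel(d)  # passes: sentinel present, files intact
```

The write-last ordering is what makes a crash mid-save detectable: an interrupted save leaves no `.complete`, so the loader rejects the directory instead of loading half a checkpoint.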
pawn/config.py CHANGED
@@ -3,7 +3,7 @@
3
  from dataclasses import dataclass
4
 
5
 
6
- # Outcome token IDs (spec section 3.2)
7
  PAD_TOKEN = 0
8
  OUTCOME_TOKEN_BASE = 4273
9
  WHITE_CHECKMATES = 4273
 
3
  from dataclasses import dataclass
4
 
5
 
6
+ # Outcome token IDs must match engine/src/vocab.rs
7
  PAD_TOKEN = 0
8
  OUTCOME_TOKEN_BASE = 4273
9
  WHITE_CHECKMATES = 4273
pawn/data.py CHANGED
@@ -23,7 +23,6 @@ from pawn.config import (
23
  _positions_cache: dict[tuple[str, int], torch.Tensor] = {}
24
 
25
 
26
-
27
  def _map_termination_to_outcome(
28
  term_codes: np.ndarray, game_lengths: np.ndarray
29
  ) -> torch.Tensor:
@@ -51,46 +50,45 @@ def _map_termination_to_outcome(
51
  return outcomes
52
 
53
 
54
- def _to_clm_batch(
55
  move_ids: np.ndarray,
56
  game_lengths: np.ndarray,
57
- term_codes: np.ndarray,
58
  seq_len: int,
59
  ) -> dict[str, torch.Tensor]:
60
- """Convert Rust engine output to CLM training tensors.
61
 
62
  Constructs input_ids = [outcome, move_1, ..., move_N, PAD, ...]
63
  and targets shifted left by 1.
64
 
65
  Args:
66
- move_ids: (B, engine_max_ply) from generate_random_games
67
- game_lengths: (B,) actual game lengths (≤ engine_max_ply)
68
- term_codes: (B,) termination codes
69
  seq_len: total CLM sequence length (256)
70
  """
71
  B = len(game_lengths)
72
  n_move_slots = seq_len - 1 # 255 slots for moves (position 0 = outcome)
73
- engine_max_ply = move_ids.shape[1]
74
 
75
  game_lengths_t = torch.from_numpy(game_lengths).long()
76
- move_ids_t = torch.from_numpy(move_ids).long() # (B, engine_max_ply)
77
- outcome_tokens = _map_termination_to_outcome(term_codes, game_lengths)
78
 
79
  # Build input_ids: [outcome, move_0, ..., move_{N-1}, PAD, ...]
80
  input_ids = torch.zeros(B, seq_len, dtype=torch.long)
81
  input_ids[:, 0] = outcome_tokens
82
 
83
  # Mask out any non-move tokens from engine output
84
- cache_key = ("engine", engine_max_ply)
85
  engine_positions = _positions_cache.get(cache_key)
86
  if engine_positions is None:
87
- engine_positions = torch.arange(engine_max_ply).unsqueeze(0)
88
  _positions_cache[cache_key] = engine_positions
89
  move_mask = engine_positions < game_lengths_t.unsqueeze(1)
90
  clean_moves = move_ids_t * move_mask
91
 
92
  # Place moves at positions 1..n_move_slots
93
- n_to_copy = min(engine_max_ply, n_move_slots)
94
  input_ids[:, 1 : n_to_copy + 1] = clean_moves[:, :n_to_copy]
95
 
96
  # Cap game_lengths to n_move_slots (handles edge case where engine
@@ -116,6 +114,21 @@ def _to_clm_batch(
116
  }
117
 
118
 
119
  class CLMDataset(torch.utils.data.IterableDataset):
120
  """Generates CLM training data on-the-fly via the Rust engine.
121
 
@@ -156,16 +169,18 @@ class CLMDataset(torch.utils.data.IterableDataset):
156
  t.start()
157
 
158
  step = self._start_step
159
- # Engine generates games of up to max_ply-1 moves, leaving 1 slot
160
- # for the outcome token in the seq_len=max_ply CLM sequence.
161
- engine_max_ply = self.max_ply - 1
162
  while True:
163
  seed = self.base_seed + step * num_workers + worker_id
164
- move_ids, game_lengths, term_codes = engine.generate_random_games(
165
- self.batch_size, engine_max_ply, seed,
166
- discard_ply_limit=self.discard_ply_limit,
167
- )
168
- yield _to_clm_batch(move_ids, game_lengths, term_codes, self.max_ply)
169
  step += 1
170
 
171
 
@@ -178,17 +193,21 @@ def create_validation_set(
178
  Also computes legal move masks for legal move rate evaluation.
179
 
180
  Args:
181
- max_ply: total CLM sequence length (256). Engine gets max_ply-1.
182
  """
183
- engine_max_ply = max_ply - 1
184
- move_ids, game_lengths, term_codes = engine.generate_random_games(
185
- n_games, engine_max_ply, seed, discard_ply_limit=discard_ply_limit,
186
- )
187
- batch = _to_clm_batch(move_ids, game_lengths, term_codes, max_ply)
188
 
189
  # Compute legal move masks for evaluating legal move rate
190
  legal_grid, _legal_promo = engine.compute_legal_move_masks(move_ids, game_lengths)
191
- batch["legal_grid"] = torch.from_numpy(legal_grid).long() # (B, engine_max_ply, 64)
192
  batch["game_lengths"] = torch.from_numpy(game_lengths).long()
193
 
194
  return batch
 
23
  _positions_cache: dict[tuple[str, int], torch.Tensor] = {}
24
 
25
 
 
26
  def _map_termination_to_outcome(
27
  term_codes: np.ndarray, game_lengths: np.ndarray
28
  ) -> torch.Tensor:
 
50
  return outcomes
51
 
52
 
53
+ def pack_clm_sequences(
54
  move_ids: np.ndarray,
55
  game_lengths: np.ndarray,
56
+ outcome_tokens: torch.Tensor,
57
  seq_len: int,
58
  ) -> dict[str, torch.Tensor]:
59
+ """Pack move arrays into CLM training tensors.
60
 
61
  Constructs input_ids = [outcome, move_1, ..., move_N, PAD, ...]
62
  and targets shifted left by 1.
63
 
64
  Args:
65
+ move_ids: (B, max_ply) raw move token IDs
66
+ game_lengths: (B,) actual game lengths
67
+ outcome_tokens: (B,) pre-computed outcome token IDs (4273-4277)
68
  seq_len: total CLM sequence length (256)
69
  """
70
  B = len(game_lengths)
71
  n_move_slots = seq_len - 1 # 255 slots for moves (position 0 = outcome)
72
+ max_ply = move_ids.shape[1]
73
 
74
  game_lengths_t = torch.from_numpy(game_lengths).long()
75
+ move_ids_t = torch.from_numpy(move_ids).long() # (B, max_ply)
 
76
 
77
  # Build input_ids: [outcome, move_0, ..., move_{N-1}, PAD, ...]
78
  input_ids = torch.zeros(B, seq_len, dtype=torch.long)
79
  input_ids[:, 0] = outcome_tokens
80
 
81
  # Mask out any non-move tokens from engine output
82
+ cache_key = ("engine", max_ply)
83
  engine_positions = _positions_cache.get(cache_key)
84
  if engine_positions is None:
85
+ engine_positions = torch.arange(max_ply).unsqueeze(0)
86
  _positions_cache[cache_key] = engine_positions
87
  move_mask = engine_positions < game_lengths_t.unsqueeze(1)
88
  clean_moves = move_ids_t * move_mask
89
 
90
  # Place moves at positions 1..n_move_slots
91
+ n_to_copy = min(max_ply, n_move_slots)
92
  input_ids[:, 1 : n_to_copy + 1] = clean_moves[:, :n_to_copy]
93
 
94
  # Cap game_lengths to n_move_slots (handles edge case where engine
 
114
  }
115
 
116
 
117
+ def _to_clm_batch(
118
+ move_ids: np.ndarray,
119
+ game_lengths: np.ndarray,
120
+ term_codes: np.ndarray,
121
+ seq_len: int,
122
+ ) -> dict[str, torch.Tensor]:
123
+ """Convert Rust engine output to CLM training tensors.
124
+
125
+ Convenience wrapper: computes outcome tokens from termination codes,
126
+ then delegates to pack_clm_sequences.
127
+ """
128
+ outcome_tokens = _map_termination_to_outcome(term_codes, game_lengths)
129
+ return pack_clm_sequences(move_ids, game_lengths, outcome_tokens, seq_len)
130
+
131
+
132
  class CLMDataset(torch.utils.data.IterableDataset):
133
  """Generates CLM training data on-the-fly via the Rust engine.
134
 
 
169
  t.start()
170
 
171
  step = self._start_step
 
 
 
172
  while True:
173
  seed = self.base_seed + step * num_workers + worker_id
174
+ input_ids, targets, loss_mask, _move_ids, _gl, _tc = \
175
+ engine.generate_clm_batch(
176
+ self.batch_size, self.max_ply, seed,
177
+ discard_ply_limit=self.discard_ply_limit,
178
+ )
179
+ yield {
180
+ "input_ids": torch.from_numpy(input_ids).long(),
181
+ "targets": torch.from_numpy(targets).long(),
182
+ "loss_mask": torch.from_numpy(loss_mask),
183
+ }
184
  step += 1
185
 
186
 
 
193
  Also computes legal move masks for legal move rate evaluation.
194
 
195
  Args:
196
+ max_ply: total CLM sequence length (256).
197
  """
198
+ input_ids, targets, loss_mask, move_ids, game_lengths, _tc = \
199
+ engine.generate_clm_batch(
200
+ n_games, max_ply, seed, discard_ply_limit=discard_ply_limit,
201
+ )
202
+ batch = {
203
+ "input_ids": torch.from_numpy(input_ids).long(),
204
+ "targets": torch.from_numpy(targets).long(),
205
+ "loss_mask": torch.from_numpy(loss_mask),
206
+ }
207
 
208
  # Compute legal move masks for evaluating legal move rate
209
  legal_grid, _legal_promo = engine.compute_legal_move_masks(move_ids, game_lengths)
210
+ batch["legal_grid"] = torch.from_numpy(legal_grid).long()
211
  batch["game_lengths"] = torch.from_numpy(game_lengths).long()
212
 
213
  return batch
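For readers tracing the move of batch construction into the Rust engine's `generate_clm_batch`, the packing contract that `pack_clm_sequences` states — `input_ids = [outcome, move_1, ..., move_N, PAD, ...]`, targets shifted left by one, loss mask over the first `game_length` positions — fits in a few lines. A toy numpy-only re-implementation at reduced scale (illustrative, not the production code; the outcome token ID below is just an example):

```python
import numpy as np

def pack_clm_sequences_toy(move_ids, game_lengths, outcome_tokens, seq_len):
    """Toy sketch of the CLM packing contract described above."""
    B, max_ply = move_ids.shape
    n_move_slots = seq_len - 1                       # slot 0 holds the outcome token
    input_ids = np.zeros((B, seq_len), dtype=np.int64)
    input_ids[:, 0] = outcome_tokens
    # Zero out anything past each game's actual length
    mask = np.arange(max_ply)[None, :] < game_lengths[:, None]
    clean = move_ids * mask
    n = min(max_ply, n_move_slots)
    input_ids[:, 1:n + 1] = clean[:, :n]
    # Targets: shifted left by 1 (position t predicts token t+1)
    targets = np.zeros_like(input_ids)
    targets[:, :-1] = input_ids[:, 1:]
    # Loss mask: positions 0..game_length-1 each have a valid move target
    capped = np.minimum(game_lengths, n_move_slots)
    loss_mask = np.arange(seq_len)[None, :] < capped[:, None]
    return input_ids, targets, loss_mask

ids, tgt, lm = pack_clm_sequences_toy(
    np.array([[7, 8, 9, 0]]), np.array([3]), np.array([4277]), seq_len=6
)
# ids[0] == [4277, 7, 8, 9, 0, 0]; tgt[0] == [7, 8, 9, 0, 0, 0]; lm[0] has 3 True
```

Keeping this contract identical on the Python (`pack_clm_sequences`) and Rust (`generate_clm_batch`) sides is what lets the eval suite mix PGN-parsed games with engine-generated ones.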
pawn/eval_suite/diagnostics.py CHANGED
@@ -8,7 +8,7 @@ import torch.nn.functional as F
8
  import chess_engine as engine
9
 
10
  from pawn.config import PAD_TOKEN
11
- from pawn.data import _to_clm_batch, _map_termination_to_outcome
12
 
13
 
14
  # ---------------------------------------------------------------------------
 
8
  import chess_engine as engine
9
 
10
  from pawn.config import PAD_TOKEN
11
+ from pawn.data import pack_clm_sequences, _map_termination_to_outcome
12
 
13
 
14
  # ---------------------------------------------------------------------------
pawn/eval_suite/lichess.py CHANGED
@@ -9,8 +9,8 @@ import torch.nn as nn
9
 
10
  import chess_engine as engine
11
 
12
- from pawn.config import PAD_TOKEN, WHITE_CHECKMATES
13
- from pawn.data import _to_clm_batch, _map_termination_to_outcome
14
 
15
 
16
  # ---------------------------------------------------------------------------
@@ -139,12 +139,11 @@ def evaluate_on_lichess(
139
  game_lengths = band_data["game_lengths"]
140
  n = len(move_ids)
141
 
142
- # We don't have true termination codes for Lichess games parsed via PGN
143
- # (the Rust parser returns move sequences, not outcomes).
144
- # Use a dummy termination code and outcome token.
145
- # For loss/accuracy evaluation, the outcome token choice doesn't matter
146
  # much since we evaluate on move prediction, not outcome prediction.
147
- term_codes = np.full(n, 5, dtype=np.uint8) # PLY_LIMIT as default
148
 
149
  engine_max_ply = max_seq_len - 1
150
  # Pad/truncate move_ids to engine_max_ply
@@ -154,7 +153,7 @@ def evaluate_on_lichess(
154
  padded[i, :gl] = move_ids[i, :gl]
155
  game_lengths_capped = np.minimum(game_lengths, engine_max_ply).astype(np.int16)
156
 
157
- batch = _to_clm_batch(padded, game_lengths_capped, term_codes, max_seq_len)
158
  input_ids = batch["input_ids"]
159
  targets = batch["targets"]
160
  loss_mask = batch["loss_mask"]
 
9
 
10
  import chess_engine as engine
11
 
12
+ from pawn.config import PAD_TOKEN, WHITE_CHECKMATES, PLY_LIMIT
13
+ from pawn.data import pack_clm_sequences, _map_termination_to_outcome
14
 
15
 
16
  # ---------------------------------------------------------------------------
 
139
  game_lengths = band_data["game_lengths"]
140
  n = len(move_ids)
141
 
142
+ # We don't have true termination codes for Lichess games parsed via PGN.
143
+ # Use PLY_LIMIT as a dummy outcome token for all games — for
144
+ # loss/accuracy evaluation the outcome token choice doesn't matter
 
145
  # much since we evaluate on move prediction, not outcome prediction.
146
+ dummy_outcomes = torch.full((n,), PLY_LIMIT, dtype=torch.long)
147
 
148
  engine_max_ply = max_seq_len - 1
149
  # Pad/truncate move_ids to engine_max_ply
 
153
  padded[i, :gl] = move_ids[i, :gl]
154
  game_lengths_capped = np.minimum(game_lengths, engine_max_ply).astype(np.int16)
155
 
156
+ batch = pack_clm_sequences(padded, game_lengths_capped, dummy_outcomes, max_seq_len)
157
  input_ids = batch["input_ids"]
158
  targets = batch["targets"]
159
  loss_mask = batch["loss_mask"]
pawn/eval_suite/probes.py CHANGED
@@ -15,7 +15,6 @@ import chess_engine as engine
15
 
16
  from pawn.config import CLMConfig
17
  from pawn.model import PAWNCLM
18
- from pawn.data import _to_clm_batch
19
 
20
 
21
  # ---------------------------------------------------------------------------
@@ -77,21 +76,16 @@ def extract_probe_data(
77
 
78
  Returns dict with all arrays needed for all probes.
79
  """
80
- engine_max_ply = max_ply - 1
81
-
82
- move_ids_np, game_lengths_np, term_codes_np = engine.generate_random_games(
83
- n_games, engine_max_ply, seed
84
- )
85
 
86
  boards_np, side_np, castling_np, ep_np, check_np, halfmove_np = (
87
  engine.extract_board_states(move_ids_np, game_lengths_np)
88
  )
89
 
90
- batch = _to_clm_batch(move_ids_np, game_lengths_np, term_codes_np, max_ply)
91
-
92
  result = {
93
- "input_ids": batch["input_ids"],
94
- "loss_mask": batch["loss_mask"],
95
  "boards": torch.from_numpy(boards_np.copy()).long(),
96
  "side_to_move": torch.from_numpy(side_np.copy()).float(),
97
  "castling_rights": torch.from_numpy(castling_np.copy()),
 
15
 
16
  from pawn.config import CLMConfig
17
  from pawn.model import PAWNCLM
 
18
 
19
 
20
  # ---------------------------------------------------------------------------
 
76
 
77
  Returns dict with all arrays needed for all probes.
78
  """
79
+ input_ids, targets, loss_mask, move_ids_np, game_lengths_np, _tc = \
80
+ engine.generate_clm_batch(n_games, max_ply, seed)
 
 
 
81
 
82
  boards_np, side_np, castling_np, ep_np, check_np, halfmove_np = (
83
  engine.extract_board_states(move_ids_np, game_lengths_np)
84
  )
85
 
 
 
86
  result = {
87
+ "input_ids": torch.from_numpy(input_ids).long(),
88
+ "loss_mask": torch.from_numpy(loss_mask),
89
  "boards": torch.from_numpy(boards_np.copy()).long(),
90
  "side_to_move": torch.from_numpy(side_np.copy()).float(),
91
  "castling_rights": torch.from_numpy(castling_np.copy()),
pawn/eval_suite/worker.py CHANGED
@@ -62,15 +62,15 @@ def run_in_worker(fn: Callable[..., Any], *args: Any, timeout: float | None = No
62
 
63
  def _load_model(checkpoint_path: str, device: str) -> PAWNCLM:
64
  """Load and freeze a PAWNCLM checkpoint. Runs inside worker processes."""
65
- import torch
66
  from pawn.config import CLMConfig
67
  from pawn.model import PAWNCLM
68
 
69
- ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
70
- cfg = CLMConfig(**ckpt["model_config"]) if "model_config" in ckpt else CLMConfig()
71
  model = PAWNCLM(cfg).to(device)
72
- model.load_state_dict(ckpt["model_state_dict"])
73
- del ckpt
74
  gc.collect()
75
  model.eval()
76
  for p in model.parameters():
 
62
 
63
  def _load_model(checkpoint_path: str, device: str) -> PAWNCLM:
64
  """Load and freeze a PAWNCLM checkpoint. Runs inside worker processes."""
65
+ from pawn.checkpoint import load_backbone_weights
66
  from pawn.config import CLMConfig
67
  from pawn.model import PAWNCLM
68
 
69
+ state_dict, model_config = load_backbone_weights(checkpoint_path, device)
70
+ cfg = CLMConfig(**model_config) if model_config else CLMConfig()
71
  model = PAWNCLM(cfg).to(device)
72
+ model.load_state_dict(state_dict)
73
+ del state_dict
74
  gc.collect()
75
  model.eval()
76
  for p in model.parameters():
pawn/lichess_data.py CHANGED
@@ -219,32 +219,15 @@ def prepare_lichess_dataset(
219
 
220
  seq_len = max_ply + 1 # outcome token + max_ply move slots
221
 
222
- # Build CLM sequences: [outcome, move_0, ..., move_{N-1}, PAD, ...]
223
- input_ids = torch.zeros(N, seq_len, dtype=torch.long)
224
- input_ids[:, 0] = outcome_tokens
225
-
226
- gl_t = torch.from_numpy(game_lengths).long()
227
- mid_t = torch.from_numpy(move_ids).long()
228
-
229
- ply_range = torch.arange(max_ply).unsqueeze(0)
230
- move_mask = ply_range < gl_t.unsqueeze(1)
231
- input_ids[:, 1:] = mid_t * move_mask
232
-
233
- # Targets: shifted left by 1
234
- targets = torch.zeros(N, seq_len, dtype=torch.long)
235
- targets[:, :-1] = input_ids[:, 1:]
236
-
237
- # Loss mask: positions 0..game_length-1 (each has a valid move target)
238
- # Position gl would target PAD, which we don't want to train on.
239
- seq_positions = torch.arange(seq_len).unsqueeze(0)
240
- loss_mask = seq_positions < gl_t.unsqueeze(1)
241
 
242
  return {
243
  "move_ids": move_ids,
244
  "game_lengths": game_lengths,
245
- "input_ids": input_ids,
246
- "targets": targets,
247
- "loss_mask": loss_mask,
248
  "outcome_tokens": outcome_tokens,
249
  "n_games": N,
250
  }
 
219
 
220
  seq_len = max_ply + 1 # outcome token + max_ply move slots
221
 
222
+ from pawn.data import pack_clm_sequences
223
+ batch = pack_clm_sequences(move_ids, game_lengths, outcome_tokens, seq_len)
224
 
225
  return {
226
  "move_ids": move_ids,
227
  "game_lengths": game_lengths,
228
+ "input_ids": batch["input_ids"],
229
+ "targets": batch["targets"],
230
+ "loss_mask": batch["loss_mask"],
231
  "outcome_tokens": outcome_tokens,
232
  "n_games": N,
233
  }
pawn/model.py CHANGED
@@ -37,7 +37,7 @@ def _build_decomposition_table() -> torch.Tensor:
37
 
38
  for token_idx, uci_str in vocab["token_to_move"].items():
39
  if token_idx >= OUTCOME_TOKEN_BASE:
40
- continue # Skip EOG (and beyond)
41
  src_name = uci_str[:2]
42
  dst_name = uci_str[2:4]
43
  promo_suffix = uci_str[4:] if len(uci_str) > 4 else ""
 
37
 
38
  for token_idx, uci_str in vocab["token_to_move"].items():
39
  if token_idx >= OUTCOME_TOKEN_BASE:
40
+ continue # Outcome tokens use standalone embeddings
41
  src_name = uci_str[:2]
42
  dst_name = uci_str[2:4]
43
  promo_suffix = uci_str[4:] if len(uci_str) > 4 else ""
pawn/trainer.py CHANGED
@@ -207,17 +207,27 @@ def _make_run_dir(base_log_dir: str) -> str:
207
 
208
 
209
  class CLMTrainer:
210
- def __init__(self, train_cfg: TrainingConfig, model_cfg: CLMConfig):
211
  self.cfg = train_cfg
212
  self.model_cfg = model_cfg
213
  self.device = train_cfg.device
214
  self.global_step = 0
 
 
215
 
216
  self.run_dir = _make_run_dir(train_cfg.log_dir)
217
  self.cfg.checkpoint_dir = os.path.join(self.run_dir, "checkpoints")
218
  self._jsonl_path = os.path.join(self.run_dir, "metrics.jsonl")
219
  self._jsonl_file = None
220
 
 
 
 
221
  self._model = PAWNCLM(model_cfg).to(self.device)
222
  self.model = self._model
223
  param_count = sum(p.numel() for p in self._model.parameters())
@@ -449,12 +459,13 @@ class CLMTrainer:
449
  prefetch_factor=1 if num_workers > 0 else None,
450
  )
451
 
 
 
 
452
  def _graceful_exit(signum, frame):
453
- print(f"\nReceived signal {signum}, saving checkpoint and exiting...")
454
- self.save_checkpoint()
455
- if self._jsonl_file:
456
- self._jsonl_file.close()
457
- sys.exit(128 + signum)
458
 
459
  old_term = signal.signal(signal.SIGTERM, _graceful_exit)
460
  old_int = signal.signal(signal.SIGINT, _graceful_exit)
@@ -547,6 +558,12 @@ class CLMTrainer:
547
  self.save_checkpoint()
548
  break
549
 
550
  step_start = time.time()
551
 
552
  signal.signal(signal.SIGTERM, old_term)
@@ -557,49 +574,47 @@ class CLMTrainer:
557
  self._jsonl_file = None
558
 
559
  def save_checkpoint(self, path: str | None = None):
 
 
560
  if path is None:
561
  path = os.path.join(
562
- self.cfg.checkpoint_dir, f"step_{self.global_step:08d}.pt"
563
  )
564
 
565
- dirname = os.path.dirname(path)
566
- if dirname:
567
- os.makedirs(dirname, exist_ok=True)
568
-
569
  model: PAWNCLM = self._eager_model()
570
 
571
- torch.save(
572
- {
573
- "global_step": self.global_step,
574
- "model_state_dict": model.state_dict(),
575
- "optimizer_state_dict": self.optimizer.state_dict(),
576
- "scheduler_state_dict": self.scheduler.state_dict(),
577
- "scaler_state_dict": self.scaler.state_dict(),
578
- "model_config": self.model_cfg.__dict__,
579
- "training_config": self.cfg.__dict__,
580
- "torch_rng_state": torch.get_rng_state(),
581
- "cuda_rng_state": (
582
- torch.cuda.get_rng_state() if torch.cuda.is_available() else None
583
- ),
584
- },
585
  path,
586
  )
587
  print(f"Checkpoint saved: {path}")
588
 
589
  def load_checkpoint(self, path: str):
590
- ckpt = torch.load(path, map_location=self.device, weights_only=False)
591
- self.global_step = ckpt["global_step"]
592
 
593
  model: PAWNCLM = self._eager_model()
594
 
595
- model.load_state_dict(ckpt["model_state_dict"])
596
- self.optimizer.load_state_dict(ckpt["optimizer_state_dict"])
597
- self.scheduler.load_state_dict(ckpt["scheduler_state_dict"])
598
- self.scaler.load_state_dict(ckpt["scaler_state_dict"])
599
-
600
- if ckpt.get("torch_rng_state") is not None:
601
- torch.set_rng_state(ckpt["torch_rng_state"].cpu().byte())
602
- if ckpt.get("cuda_rng_state") is not None and torch.cuda.is_available():
603
- torch.cuda.set_rng_state(ckpt["cuda_rng_state"].cpu().byte())
604
-
605
  print(f"Resumed from step {self.global_step}")
 
207
 
208
 
209
  class CLMTrainer:
210
+ def __init__(
211
+ self,
212
+ train_cfg: TrainingConfig,
213
+ model_cfg: CLMConfig,
214
+ hf_repo: str | None = None,
215
+ ):
216
  self.cfg = train_cfg
217
  self.model_cfg = model_cfg
218
  self.device = train_cfg.device
219
  self.global_step = 0
220
+ self.hf_repo = hf_repo
221
+ self.hf_branch: str | None = None
222
 
223
  self.run_dir = _make_run_dir(train_cfg.log_dir)
224
  self.cfg.checkpoint_dir = os.path.join(self.run_dir, "checkpoints")
225
  self._jsonl_path = os.path.join(self.run_dir, "metrics.jsonl")
226
  self._jsonl_file = None
227
 
228
+ if self.hf_repo:
229
+ self.hf_branch = f"run/{os.path.basename(self.run_dir)}"
230
+
231
  self._model = PAWNCLM(model_cfg).to(self.device)
232
  self.model = self._model
233
  param_count = sum(p.numel() for p in self._model.parameters())
 
459
  prefetch_factor=1 if num_workers > 0 else None,
460
  )
461
 
462
+ _shutdown_requested = False
463
+ _shutdown_signal = None
464
+
465
  def _graceful_exit(signum, frame):
466
+ nonlocal _shutdown_requested, _shutdown_signal
467
+ _shutdown_requested = True
468
+ _shutdown_signal = signum
 
 
469
 
470
  old_term = signal.signal(signal.SIGTERM, _graceful_exit)
471
  old_int = signal.signal(signal.SIGINT, _graceful_exit)
 
558
  self.save_checkpoint()
559
  break
560
 
561
+ if _shutdown_requested:
562
+ print(f"\nShutdown requested (signal {_shutdown_signal}), "
563
+ f"saving checkpoint at step {self.global_step}...")
564
+ self.save_checkpoint()
565
+ break
566
+
567
  step_start = time.time()
568
 
569
  signal.signal(signal.SIGTERM, old_term)
 
574
  self._jsonl_file = None
575
 
576
  def save_checkpoint(self, path: str | None = None):
577
+ from pawn.checkpoint import save_pretrain_checkpoint
578
+
579
  if path is None:
580
  path = os.path.join(
581
+ self.cfg.checkpoint_dir, f"step_{self.global_step:08d}"
582
  )
583
 
 
 
 
 
584
  model: PAWNCLM = self._eager_model()
585
 
586
+ save_pretrain_checkpoint(
587
  path,
588
+ model,
589
+ self.optimizer,
590
+ self.scheduler,
591
+ self.scaler,
592
+ self.global_step,
593
+ self.model_cfg.__dict__,
594
+ self.cfg.__dict__,
595
  )
596
  print(f"Checkpoint saved: {path}")
597
 
598
+ if self.hf_repo and self.hf_branch:
599
+ from pawn.checkpoint import push_checkpoint_to_hf
600
+ try:
601
+ push_checkpoint_to_hf(
602
+ path, self.hf_repo, self.hf_branch,
603
+ metrics_path=self._jsonl_path,
604
+ step=self.global_step,
605
+ )
606
+ print(f"Pushed to HF: {self.hf_repo}@{self.hf_branch}")
607
+ except Exception as e:
608
+ print(f"WARNING: HF push failed: {e}")
609
+
610
  def load_checkpoint(self, path: str):
611
+ from pawn.checkpoint import load_pretrain_checkpoint
 
612
 
613
  model: PAWNCLM = self._eager_model()
614
 
615
+ meta = load_pretrain_checkpoint(
616
+ path, model, self.optimizer, self.scheduler, self.scaler,
617
+ device=self.device,
618
+ )
619
+ self.global_step = meta["global_step"]
620
  print(f"Resumed from step {self.global_step}")
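The trainer change above replaces checkpoint-saving inside the signal handler with a flag that the training loop checks at a safe point; handlers should do as little as possible, and checkpoint I/O from handler context is fragile. A self-contained sketch of the deferred-shutdown pattern (SIGUSR1 stands in for SIGTERM so the demo is harmless to run):

```python
import os
import signal

shutdown_requested = False
shutdown_signal = None

def graceful_exit(signum, frame):
    # The handler only records the request; heavy work runs later,
    # from normal (non-handler) context.
    global shutdown_requested, shutdown_signal
    shutdown_requested = True
    shutdown_signal = signum

old = signal.signal(signal.SIGUSR1, graceful_exit)
os.kill(os.getpid(), signal.SIGUSR1)  # simulate an external shutdown request

steps_done = 0
for step in range(1000):
    if shutdown_requested:
        # save_checkpoint() would run here, at a step boundary
        break
    steps_done += 1

signal.signal(signal.SIGUSR1, old)  # restore, mirroring the trainer's cleanup
```

Checking the flag at step boundaries also means a checkpoint is never written mid-optimizer-step, which pairs naturally with the `.complete` sentinel scheme.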
pyproject.toml CHANGED
@@ -8,6 +8,7 @@ dependencies = [
8
  "chess-engine",
9
  "numpy~=2.2.0",
10
  "psutil>=5.9.0",
 
11
  "tqdm~=4.67.0",
12
  "wandb~=0.25.0",
13
  ]
 
8
  "chess-engine",
9
  "numpy~=2.2.0",
10
  "psutil>=5.9.0",
11
+ "safetensors>=0.4.0",
12
  "tqdm~=4.67.0",
13
  "wandb~=0.25.0",
14
  ]
scripts/check_progress.sh CHANGED
@@ -1,33 +1,34 @@
1
  #!/usr/bin/env bash
2
- # Check training progress, sync from pods, and auto-stop finished pods.
3
- # Usage: check_progress.sh [--sync] [--auto-stop] [LOG_DIR]
4
  set -euo pipefail
5
 
6
  SYNC=false
7
- AUTO_STOP=false
8
  LOG_DIR=""
9
 
10
  for arg in "$@"; do
11
  case "$arg" in
12
  --sync) SYNC=true ;;
13
- --auto-stop) AUTO_STOP=true ;;
14
  *) LOG_DIR="$arg" ;;
15
  esac
16
  done
17
  LOG_DIR="${LOG_DIR:-logs}"
18
 
19
  REPO="$(cd "$(dirname "$0")/.." && pwd)"
20
- POD_DIR="$HOME/.config/pawn/pods"
21
 
22
- # Sync from all pods
23
- if $SYNC && [ -d "$POD_DIR" ]; then
24
  bash "$REPO/deploy/sync.sh" 2>/dev/null || true
25
  fi
26
 
27
- # Show progress for top 5 most recent runs
28
- find "$LOG_DIR" -name metrics.jsonl -printf '%T@ %p\n' 2>/dev/null \
29
- | sort -rn | head -n 5 | while read -r _ path; do
 
 
 
30
  run_name="$(basename "$(dirname "$path")")"
 
31
  python3 -c "
32
  import json, sys
33
  records = [json.loads(l) for l in open('$path')]
@@ -51,29 +52,8 @@ print(f'{\"$run_name\":<28} {variant} discard_ply={str(discard):<5} step {ste
51
  done
52
 
53
  # Check local training process
54
- if pgrep -f 'train.py.*discard-ply-limit' > /dev/null 2>&1; then
55
- echo "Local discard-ply-limit run: RUNNING"
56
  else
57
- echo "WARNING: Local discard-ply-limit run: NOT RUNNING"
58
- fi
59
-
60
- # Auto-stop finished pods
61
- if $AUTO_STOP && [ -d "$POD_DIR" ]; then
62
- for env_file in "$POD_DIR"/*.env; do
63
- [ -f "$env_file" ] || continue
64
- pod_name="$(basename "${env_file%.env}")"
65
- unset POD_ID POD_HOST POD_PORT POD_GPU 2>/dev/null || true
66
- source "$env_file"
67
-
68
- # Check if process is alive on pod
69
- alive=$(ssh -o ConnectTimeout=5 -p "$POD_PORT" "root@$POD_HOST" \
70
- "pgrep -f 'train.py' > /dev/null 2>&1 && echo yes || echo no" 2>/dev/null || echo "unreachable")
71
-
72
- if [ "$alive" = "no" ]; then
73
- echo ">>> $pod_name: training finished. Final sync + stopping..."
74
- bash "$REPO/deploy/sync.sh" "$pod_name" 2>/dev/null || true
75
- runpodctl pod stop "$POD_ID" 2>/dev/null
76
- echo ">>> $pod_name ($POD_ID) STOPPED"
77
- fi
78
- done
79
  fi
 
1
  #!/usr/bin/env bash
2
+ # Check training progress from HuggingFace submodules and local logs.
3
+ # Usage: check_progress.sh [--sync] [LOG_DIR]
4
  set -euo pipefail
5
 
6
  SYNC=false
 
7
  LOG_DIR=""
8
 
9
  for arg in "$@"; do
10
  case "$arg" in
11
  --sync) SYNC=true ;;
 
12
  *) LOG_DIR="$arg" ;;
13
  esac
14
  done
15
  LOG_DIR="${LOG_DIR:-logs}"
16
 
17
  REPO="$(cd "$(dirname "$0")/.." && pwd)"
 
18
 
19
+ # Sync submodules from HuggingFace
20
+ if $SYNC; then
21
  bash "$REPO/deploy/sync.sh" 2>/dev/null || true
22
  fi
23
 
24
+ # Show progress from all metrics.jsonl files (local logs + submodules)
25
+ N=5
26
+ {
27
+ find "$LOG_DIR" -name metrics.jsonl -printf '%T@ %p\n' 2>/dev/null
28
+ find "$REPO/checkpoints" -name metrics.jsonl -printf '%T@ %p\n' 2>/dev/null
29
+ } | sort -rn | head -n "$N" | while read -r _ path; do
30
  run_name="$(basename "$(dirname "$path")")"
31
+
32
  python3 -c "
33
  import json, sys
34
  records = [json.loads(l) for l in open('$path')]
 
52
  done
53
 
54
  # Check local training process
55
+ if pgrep -f 'train.py' > /dev/null 2>&1; then
56
+ echo "Local training: RUNNING"
57
  else
58
+ echo "Local training: NOT RUNNING"
59
  fi
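The embedded `python3 -c` block above summarizes each run from its `metrics.jsonl`. A standalone sketch of that parsing; the record schema here (`type`, `step`, `variant`, `discard_ply_limit` fields) is inferred from what the script prints and is an assumption, not a documented format:

```python
import json
import tempfile
from pathlib import Path

# Toy metrics.jsonl with one JSON object per line (assumed schema).
path = Path(tempfile.mkdtemp()) / "metrics.jsonl"
path.write_text(
    '{"type": "config", "variant": "base", "discard_ply_limit": 40}\n'
    '{"type": "train", "step": 500, "loss": 2.41}\n'
    '{"type": "train", "step": 1000, "loss": 2.05}\n'
)

records = [json.loads(line) for line in path.open()]
config = next((r for r in records if r.get("type") == "config"), {})
train = [r for r in records if r.get("type") == "train"]
summary = {
    "variant": config.get("variant", "?"),
    "discard_ply_limit": config.get("discard_ply_limit"),
    "step": train[-1]["step"] if train else 0,
}
# summary["step"] == 1000
```

Since the same `metrics.jsonl` is now also pushed to HuggingFace (truncated to the checkpoint step), this parsing works identically on local logs and synced submodule copies.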
scripts/eval_accuracy.py CHANGED
@@ -65,37 +65,47 @@ def parse_args():
65
  return p.parse_args()
66
 
67
 
68
- def _detect_adapter_type(ckpt: dict) -> str:
69
- """Auto-detect adapter type from checkpoint keys."""
70
- if "adapter_state_dict" in ckpt:
  return "hybrid"
72
- if "lora_state_dict" in ckpt:
73
  return "lora"
74
- if "film_state_dict" in ckpt:
75
- return "film"
76
- if "sparse_state_dict" in ckpt:
77
  return "sparse"
78
- if "bottleneck_state_dict" in ckpt:
79
- return "bottleneck"
80
- raise ValueError("Cannot detect adapter type from checkpoint keys: "
81
- + ", ".join(ckpt.keys()))
 
82
 
83
 
84
  def load_model(checkpoint_path: str, adapter_path: str, device: str):
85
  """Load backbone + adapter, auto-detecting adapter type."""
 
 
86
  # Backbone
87
- ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
88
- cfg = CLMConfig(**ckpt["model_config"]) if "model_config" in ckpt else CLMConfig()
89
  backbone = PAWNCLM(cfg).to(device)
90
- backbone.load_state_dict(ckpt["model_state_dict"])
91
- del ckpt
92
  gc.collect()
93
  backbone.eval()
94
 
95
  # Adapter
96
- adapter_ckpt = torch.load(adapter_path, map_location=device, weights_only=False)
97
- adapter_type = _detect_adapter_type(adapter_ckpt)
98
- adapter_config = adapter_ckpt.get("config", {})
 
99
 
100
  if adapter_type == "lora":
101
  from pawn.adapters.lora import LoRACLM
@@ -108,15 +118,15 @@ def load_model(checkpoint_path: str, adapter_path: str, device: str):
108
  layers=tuple(int(x) for x in adapter_config["lora_layers"].split(","))
109
  if adapter_config.get("lora_layers") else None,
110
  ).to(device)
111
- model.load_lora_state_dict(adapter_ckpt["lora_state_dict"])
112
 
113
  elif adapter_type == "film":
114
  from pawn.adapters.film import FiLMCLM
115
  has_output = not adapter_config.get("no_output_film", False)
116
- if any(k.startswith("output_film.") for k in adapter_ckpt["film_state_dict"]):
117
  has_output = True
118
  model = FiLMCLM(backbone, use_output_film=has_output).to(device)
119
- model.load_film_state_dict(adapter_ckpt["film_state_dict"])
120
 
121
  elif adapter_type == "hybrid":
122
  from pawn.adapters.hybrid import HybridCLM
@@ -133,7 +143,7 @@ def load_model(checkpoint_path: str, adapter_path: str, device: str):
133
  use_output_film=adapter_config.get("output_film", False),
134
  film_layers=tuple(int(x) for x in film_layers.split(",")) if film_layers else None,
135
  ).to(device)
136
- model.load_adapter_state_dict(adapter_ckpt["adapter_state_dict"])
137
 
138
  elif adapter_type == "sparse":
139
  from pawn.adapters.sparse import SparseCLM
@@ -148,7 +158,7 @@ def load_model(checkpoint_path: str, adapter_path: str, device: str):
148
  layers=tuple(int(x) for x in sparse_layers.split(",")) if sparse_layers else None,
149
  seed=adapter_config.get("sparse_seed", 42),
150
  ).to(device)
151
- model.load_sparse_state_dict(adapter_ckpt["sparse_state_dict"])
152
 
153
  elif adapter_type == "bottleneck":
154
  from pawn.adapters.bottleneck import BottleneckCLM
@@ -160,7 +170,7 @@ def load_model(checkpoint_path: str, adapter_path: str, device: str):
160
  adapt_ffn=adapter_config.get("adapt_ffn", True),
161
  layers=tuple(int(x) for x in adapter_layers_str.split(",")) if adapter_layers_str else None,
162
  ).to(device)
163
- model.load_adapter_state_dict(adapter_ckpt["bottleneck_state_dict"])
164
 
165
  model.eval()
166
  return model, adapter_type
 
  return p.parse_args()
 
 
+ def _detect_adapter_type(config: dict) -> str:
+ """Auto-detect adapter type from config dict.
+
+ The config dict comes from load_adapter_checkpoint()["config"], which
+ contains training args for both legacy .pt files and new-format directories.
+ """
+ if "checkpoint_type" in config:
+ return config["checkpoint_type"]
+ if "bottleneck_dim" in config:
+ return "bottleneck"
+ if "lora_rank" in config and config.get("use_film") is not None:
  return "hybrid"
+ if "lora_rank" in config:
  return "lora"
+ if "density" in config:
  return "sparse"
+ if "no_output_film" in config:
+ return "film"
+
+ raise ValueError("Cannot detect adapter type from config keys: "
+ + ", ".join(config.keys()))
 
 
  def load_model(checkpoint_path: str, adapter_path: str, device: str):
  """Load backbone + adapter, auto-detecting adapter type."""
+ from pawn.checkpoint import load_backbone_weights, load_adapter_checkpoint
+
  # Backbone
+ state_dict, model_config = load_backbone_weights(checkpoint_path, device)
+ cfg = CLMConfig(**model_config) if model_config else CLMConfig()
  backbone = PAWNCLM(cfg).to(device)
+ backbone.load_state_dict(state_dict)
+ del state_dict
  gc.collect()
  backbone.eval()
 
  # Adapter
+ adapter_data = load_adapter_checkpoint(adapter_path, device)
+ adapter_weights = adapter_data["adapter_state_dict"]
+ adapter_config = adapter_data.get("config", {})
+ adapter_type = _detect_adapter_type(adapter_config)
 
  if adapter_type == "lora":
  from pawn.adapters.lora import LoRACLM
 
  layers=tuple(int(x) for x in adapter_config["lora_layers"].split(","))
  if adapter_config.get("lora_layers") else None,
  ).to(device)
+ model.load_lora_state_dict(adapter_weights)
 
  elif adapter_type == "film":
  from pawn.adapters.film import FiLMCLM
  has_output = not adapter_config.get("no_output_film", False)
+ if any(k.startswith("output_film.") for k in adapter_weights):
  has_output = True
  model = FiLMCLM(backbone, use_output_film=has_output).to(device)
+ model.load_film_state_dict(adapter_weights)
 
  elif adapter_type == "hybrid":
  from pawn.adapters.hybrid import HybridCLM
 
  use_output_film=adapter_config.get("output_film", False),
  film_layers=tuple(int(x) for x in film_layers.split(",")) if film_layers else None,
  ).to(device)
+ model.load_adapter_state_dict(adapter_weights)
 
  elif adapter_type == "sparse":
  from pawn.adapters.sparse import SparseCLM
 
  layers=tuple(int(x) for x in sparse_layers.split(",")) if sparse_layers else None,
  seed=adapter_config.get("sparse_seed", 42),
  ).to(device)
+ model.load_sparse_state_dict(adapter_weights)
 
  elif adapter_type == "bottleneck":
  from pawn.adapters.bottleneck import BottleneckCLM
 
  adapt_ffn=adapter_config.get("adapt_ffn", True),
  layers=tuple(int(x) for x in adapter_layers_str.split(",")) if adapter_layers_str else None,
  ).to(device)
+ model.load_adapter_state_dict(adapter_weights)
 
  model.eval()
  return model, adapter_type
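The new config-based detection can be exercised in isolation. This sketch copies the fall-through order from the diff into a standalone function (note the hybrid check must precede the plain LoRA check, since hybrid configs also carry `lora_rank`):

```python
def detect_adapter_type(config: dict) -> str:
    # Explicit tag wins (new-format checkpoints); otherwise fall back to
    # heuristics on training-arg names, most specific first.
    if "checkpoint_type" in config:
        return config["checkpoint_type"]
    if "bottleneck_dim" in config:
        return "bottleneck"
    if "lora_rank" in config and config.get("use_film") is not None:
        return "hybrid"
    if "lora_rank" in config:
        return "lora"
    if "density" in config:
        return "sparse"
    if "no_output_film" in config:
        return "film"
    raise ValueError("Cannot detect adapter type from config keys: " + ", ".join(config))

print(detect_adapter_type({"lora_rank": 8}))                    # -> lora
print(detect_adapter_type({"lora_rank": 8, "use_film": True}))  # -> hybrid
print(detect_adapter_type({"checkpoint_type": "sparse"}))       # -> sparse
```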
scripts/eval_probes.py CHANGED
@@ -11,20 +11,17 @@ import torch
  from pawn.config import CLMConfig
  from pawn.model import PAWNCLM
  from pawn.eval_suite.probes import extract_probe_data, train_all_probes
- from pawn.gpu import configure_gpu
 
 
  def load_model_from_checkpoint(checkpoint_path: str, device: str) -> PAWNCLM:
- ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
- config_dict = ckpt.get("model_config")
- if config_dict:
- cfg = CLMConfig(**config_dict)
  else:
- state = ckpt["model_state_dict"]
- d_model = state["embed.src_embed.weight"].shape[1]
- n_layers = max(int(k.split(".")[1]) for k in state if k.startswith("layers.")) + 1
- n_heads = state["layers.0.attn.wq.weight"].shape[0] // (d_model // (d_model // (state["layers.0.attn.wq.weight"].shape[0] // d_model * d_model // state["layers.0.attn.wq.weight"].shape[0]) if True else 1))
- # Infer config from known variants
  if d_model == 256 and n_layers == 8:
  cfg = CLMConfig.small()
  elif d_model == 512 and n_layers == 8:
@@ -33,9 +30,8 @@ def load_model_from_checkpoint(checkpoint_path: str, device: str) -> PAWNCLM:
  cfg = CLMConfig.large()
  else:
  cfg = CLMConfig(d_model=d_model, n_layers=n_layers)
-
  model = PAWNCLM(cfg).to(device)
- model.load_state_dict(ckpt["model_state_dict"])
  model.eval()
  return model
 
@@ -52,6 +48,7 @@ def main():
 
  device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
  if device == "cuda":
  gpu_cfg = configure_gpu()
  import pawn.model as model_module
  model_module.SDPA_BACKEND = gpu_cfg.get("sdpa_backend")
 
  from pawn.config import CLMConfig
  from pawn.model import PAWNCLM
  from pawn.eval_suite.probes import extract_probe_data, train_all_probes
 
 
  def load_model_from_checkpoint(checkpoint_path: str, device: str) -> PAWNCLM:
+ from pawn.checkpoint import load_backbone_weights
+ state_dict, model_config = load_backbone_weights(checkpoint_path, device)
+ if model_config:
+ cfg = CLMConfig(**model_config)
  else:
+ # Fallback: infer from state dict shapes
+ d_model = state_dict["embed.src_embed.weight"].shape[1]
+ n_layers = max(int(k.split(".")[1]) for k in state_dict if k.startswith("layers.")) + 1
  if d_model == 256 and n_layers == 8:
  cfg = CLMConfig.small()
  elif d_model == 512 and n_layers == 8:
 
  cfg = CLMConfig.large()
  else:
  cfg = CLMConfig(d_model=d_model, n_layers=n_layers)
  model = PAWNCLM(cfg).to(device)
+ model.load_state_dict(state_dict)
  model.eval()
  return model
 
 
  device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
  if device == "cuda":
+ from pawn.gpu import configure_gpu
  gpu_cfg = configure_gpu()
  import pawn.model as model_module
  model_module.SDPA_BACKEND = gpu_cfg.get("sdpa_backend")
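The shape-based fallback above keys off two invariants of the state dict: the embedding weight's second dimension is `d_model`, and layer parameters are named `layers.<i>....`. A minimal sketch of that inference, using plain shape tuples in place of real tensors (the shape values here are illustrative):

```python
def infer_dims(shapes: dict[str, tuple[int, ...]]) -> tuple[int, int]:
    """Infer (d_model, n_layers) from a parameter-name -> shape mapping."""
    d_model = shapes["embed.src_embed.weight"][1]        # (vocab_size, d_model)
    n_layers = max(int(k.split(".")[1]) for k in shapes
                   if k.startswith("layers.")) + 1       # layer indices are 0-based
    return d_model, n_layers

shapes = {
    "embed.src_embed.weight": (4278, 512),
    "layers.0.attn.wq.weight": (512, 512),
    "layers.7.attn.wq.weight": (512, 512),
}
print(infer_dims(shapes))  # -> (512, 8)
```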
scripts/export_hf_repo.py ADDED
@@ -0,0 +1,283 @@
+ #!/usr/bin/env python3
+ """Export a training run to HuggingFace repo format.
+
+ Converts .pt checkpoints to safetensors and structures files for HF upload:
+ - Root: best checkpoint (model.safetensors, config.json, metrics.jsonl, README.md)
+ - checkpoints/step_NNNN/: other checkpoints with truncated metrics
+
+ Usage:
+ python scripts/export_hf_repo.py \
+ --run-dir logs/run_20260322_182707 \
+ --output-dir export/pawn-base \
+ --repo-name pawn-base \
+ --github-url https://github.com/thomas-schweich/PAWN
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import shutil
+ from pathlib import Path
+
+ import torch
+ from safetensors.torch import save_file
+
+
+ def find_best_step(metrics_path: Path) -> int | None:
+ """Find the step with lowest val loss from metrics.jsonl."""
+ best_loss = float("inf")
+ best_step = None
+ with open(metrics_path) as f:
+ for line in f:
+ record = json.loads(line)
+ if record.get("type") != "val":
+ continue
+ loss = record.get("val/loss", float("inf"))
+ step = record.get("step")
+ if loss < best_loss and step is not None:
+ best_loss = loss
+ best_step = step
+ return best_step
+
+
+ def truncate_metrics(metrics_path: Path, up_to_step: int) -> list[str]:
+ """Return metrics lines up to and including the given step."""
+ lines = []
+ with open(metrics_path) as f:
+ for line in f:
+ lines.append(line)
+ record = json.loads(line)
+ if record.get("type") in ("train", "val") and record.get("step", 0) > up_to_step:
+ break
+ return lines
+
+
+ def convert_pt_to_safetensors(pt_path: Path, output_dir: Path):
+ """Convert a .pt checkpoint to safetensors + JSON directory format."""
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ ckpt = torch.load(str(pt_path), map_location="cpu", weights_only=False)
+
+ # Model weights -> safetensors
+ state_dict = ckpt["model_state_dict"]
+ tensors = {k: v.cpu().contiguous() for k, v in state_dict.items()}
+ save_file(tensors, output_dir / "model.safetensors")
+
+ # Config
+ config = {
+ "format_version": 1,
+ "checkpoint_type": "pretrain",
+ "model_config": ckpt.get("model_config", {}),
+ "training_config": ckpt.get("training_config", {}),
+ }
+ with open(output_dir / "config.json", "w") as f:
+ json.dump(config, f, indent=2, default=str)
+
+ # Optimizer -> safetensors (if present)
+ if "optimizer_state_dict" in ckpt:
+ from pawn.checkpoint import _flatten_optimizer_state, _rng_to_json, _json_default
+ opt_tensors, opt_meta = _flatten_optimizer_state(ckpt["optimizer_state_dict"])
+ if opt_tensors:
+ save_file(opt_tensors, output_dir / "optimizer.safetensors")
+
+ # Training state
+ training_state = {
+ "format_version": 1,
+ "global_step": ckpt.get("global_step", 0),
+ "scheduler_state_dict": ckpt.get("scheduler_state_dict"),
+ "scaler_state_dict": ckpt.get("scaler_state_dict"),
+ "optimizer_meta": opt_meta,
+ }
+ rng_state = {}
+ if ckpt.get("torch_rng_state") is not None:
+ rng_state.update(_rng_to_json(ckpt["torch_rng_state"], ckpt.get("cuda_rng_state")))
+ training_state.update(rng_state)
+
+ with open(output_dir / "training_state.json", "w") as f:
+ json.dump(training_state, f, indent=2, default=_json_default)
+
+
+ def generate_readme(
+ repo_name: str, model_config: dict, training_config: dict,
+ best_step: int, val_loss: float, val_acc: float,
+ github_url: str, extra_desc: str = "",
+ ) -> str:
+ """Generate a HuggingFace model card README."""
+ d_model = model_config.get("d_model", "?")
+ n_layers = model_config.get("n_layers", "?")
+ n_heads = model_config.get("n_heads", "?")
+ discard = training_config.get("discard_ply_limit", False)
+
+ # Infer variant name
+ variant = "base"
+ if d_model == 256:
+ variant = "small"
+ elif d_model == 640:
+ variant = "large"
+
+ params = {"small": "9.5M", "base": "35.8M", "large": "68.4M"}.get(variant, "?")
+
+ return f"""---
+ license: apache-2.0
+ library_name: pytorch
+ tags:
+ - chess
+ - transformer
+ - causal-lm
+ - world-model
+ datasets:
+ - random-self-play
+ model-index:
+ - name: {repo_name}
+ results:
+ - task:
+ type: next-move-prediction
+ metrics:
+ - name: Val Loss
+ type: loss
+ value: {val_loss}
+ - name: Val Accuracy
+ type: accuracy
+ value: {val_acc}
+ ---
+
+ # {repo_name.upper()}
+
+ A causal transformer trained on random chess games, designed as a testbed for finetuning and augmentation methods at small scales.
+ {extra_desc}
+
+ ## Model Details
+
+ | | |
+ |---|---|
+ | **Parameters** | {params} |
+ | **Architecture** | Decoder-only transformer (RMSNorm, SwiGLU, RoPE) |
+ | **d_model** | {d_model} |
+ | **Layers** | {n_layers} |
+ | **Heads** | {n_heads} |
+ | **Vocabulary** | 4,278 tokens (4,096 grid + 176 promotions + 5 outcomes + 1 PAD) |
+ | **Sequence length** | 256 |
+ | **Best val loss** | {val_loss:.4f} (step {best_step:,}) |
+ | **Best val accuracy** | {val_acc:.1%} |
+
+ ## Usage
+
+ ```python
+ import torch
+ from safetensors.torch import load_file
+ from pawn.config import CLMConfig
+ from pawn.model import PAWNCLM
+
+ cfg = CLMConfig.{variant}()
+ model = PAWNCLM(cfg)
+ model.load_state_dict(load_file("model.safetensors"))
+ model.eval()
+ ```
+
+ ## Training
+
+ Trained from scratch on random self-play games generated by a Rust chess engine (shakmaty).
+ See the [PAWN repository]({github_url}) for training code, data pipeline, and evaluation suite.
+
+ ## License
+
+ Apache 2.0
+ """
+
+
+ def main():
+ parser = argparse.ArgumentParser(description="Export training run to HF repo format")
+ parser.add_argument("--run-dir", required=True, help="Training run directory")
+ parser.add_argument("--output-dir", required=True, help="Output directory for HF repo")
+ parser.add_argument("--repo-name", required=True, help="Repository name for README")
+ parser.add_argument("--github-url", default="https://github.com/thomas-schweich/PAWN")
+ parser.add_argument("--best-only", action="store_true", help="Only export best checkpoint")
+ parser.add_argument("--extra-desc", default="", help="Extra description for README")
+ args = parser.parse_args()
+
+ run_dir = Path(args.run_dir)
+ output_dir = Path(args.output_dir)
+ metrics_path = run_dir / "metrics.jsonl"
+
+ if not metrics_path.exists():
+ print(f"ERROR: {metrics_path} not found")
+ return
+
+ # Find best step
+ best_step = find_best_step(metrics_path)
+ if best_step is None:
+ print("ERROR: No val records found in metrics.jsonl")
+ return
+ print(f"Best val step: {best_step}")
+
+ # Find best val metrics
+ best_val_loss, best_val_acc = float("inf"), 0.0
+ with open(metrics_path) as f:
+ for line in f:
+ r = json.loads(line)
+ if r.get("type") == "val" and r.get("step") == best_step:
+ best_val_loss = r.get("val/loss", float("inf"))
+ best_val_acc = r.get("val/accuracy", 0.0)
+ break
+
+ # Find all .pt checkpoints
+ ckpt_dir = run_dir / "checkpoints"
+ checkpoints = sorted(ckpt_dir.glob("step_*.pt")) if ckpt_dir.exists() else []
+ if not checkpoints:
+ print("ERROR: No checkpoints found")
+ return
+
+ # Find nearest checkpoint to best step
+ best_ckpt = min(checkpoints, key=lambda p: abs(
+ int(p.stem.replace("step_", "")) - best_step
+ ))
+ print(f"Best checkpoint: {best_ckpt}")
+
+ # Read config from checkpoint
+ ckpt_data = torch.load(str(best_ckpt), map_location="cpu", weights_only=False)
+ model_config = ckpt_data.get("model_config", {})
+ training_config = ckpt_data.get("training_config", {})
+ del ckpt_data
+
+ # Export best checkpoint to root
+ output_dir.mkdir(parents=True, exist_ok=True)
+ print(f"\nExporting best checkpoint to {output_dir}/")
+ convert_pt_to_safetensors(best_ckpt, output_dir)
+
+ # Copy full metrics.jsonl
+ shutil.copy2(metrics_path, output_dir / "metrics.jsonl")
+
+ # Generate README
+ readme = generate_readme(
+ args.repo_name, model_config, training_config,
+ best_step, best_val_loss, best_val_acc,
+ args.github_url, args.extra_desc,
+ )
+ with open(output_dir / "README.md", "w") as f:
+ f.write(readme)
+
+ # Export other checkpoints
+ if not args.best_only:
+ for ckpt in checkpoints:
+ if ckpt == best_ckpt:
+ continue
+ step_name = ckpt.stem # e.g. "step_00005000"
+ step_num = int(step_name.replace("step_", ""))
+ step_dir = output_dir / "checkpoints" / step_name
+ print(f" Exporting {step_name}...")
+ convert_pt_to_safetensors(ckpt, step_dir)
+
+ # Truncated metrics
+ truncated = truncate_metrics(metrics_path, step_num)
+ with open(step_dir / "metrics.jsonl", "w") as f:
+ f.writelines(truncated)
+
+ print(f"\nExport complete: {output_dir}")
+ print(f" Best: model.safetensors, config.json, metrics.jsonl, README.md")
+ if not args.best_only:
+ print(f" Checkpoints: {len(checkpoints) - 1} in checkpoints/")
+
+
+ if __name__ == "__main__":
+ main()
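The best-step selection in `find_best_step` is a single scan over val records. This stream-based variant (taking a file object instead of a `Path`, otherwise the same logic as above) can be tested without touching disk:

```python
import json
from io import StringIO

def find_best_step(fh):
    """Return the step with the lowest val/loss among records of type 'val'."""
    best_loss, best_step = float("inf"), None
    for line in fh:
        record = json.loads(line)
        if record.get("type") != "val":
            continue  # train records carry no val/loss
        loss = record.get("val/loss", float("inf"))
        step = record.get("step")
        if loss < best_loss and step is not None:
            best_loss, best_step = loss, step
    return best_step

metrics = StringIO(
    '{"type": "val", "step": 1000, "val/loss": 2.9}\n'
    '{"type": "train", "step": 1500}\n'
    '{"type": "val", "step": 2000, "val/loss": 2.6}\n'
    '{"type": "val", "step": 3000, "val/loss": 2.7}\n'
)
print(find_best_step(metrics))  # -> 2000
```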
scripts/profile_step.py CHANGED
@@ -18,17 +18,18 @@ import torch
  # ── PAWN imports ─────────────────────────────────────────────────────────────
  from pawn.config import CLMConfig
  from pawn.model import PAWNCLM
- from pawn.data import _to_clm_batch
  import chess_engine as engine
 
 
  def generate_clm_batch(batch_size: int, device: str):
  """Generate a CLM batch and move to device."""
- move_ids, game_lengths, term_codes = engine.generate_random_games(
- batch_size, 255, seed=42
- )
- batch = _to_clm_batch(move_ids, game_lengths, term_codes, 256)
- return {k: v.to(device) for k, v in batch.items()}
 
  # ── PAWN imports ─────────────────────────────────────────────────────────────
  from pawn.config import CLMConfig
  from pawn.model import PAWNCLM
  import chess_engine as engine
 
 
  def generate_clm_batch(batch_size: int, device: str):
  """Generate a CLM batch and move to device."""
+ input_ids, targets, loss_mask, _mid, _gl, _tc = \
+ engine.generate_clm_batch(batch_size, 256, seed=42)
+ return {
+ "input_ids": torch.from_numpy(input_ids).long().to(device),
+ "targets": torch.from_numpy(targets).long().to(device),
+ "loss_mask": torch.from_numpy(loss_mask).to(device),
+ }
scripts/train.py CHANGED
@@ -30,6 +30,12 @@ def parse_args():
  parser.add_argument("--log-dir", type=str, default=None, help="Override log directory")
  parser.add_argument("--discard-ply-limit", action="store_true",
  help="Only train on games that ended naturally (no ply limit truncation)")
  return parser.parse_args()
 
 
@@ -73,7 +79,7 @@ def main():
  print(f"Model config: {model_cfg}")
  print(f"Training config: {train_cfg}")
 
- trainer = CLMTrainer(train_cfg, model_cfg)
 
  if args.resume:
  trainer.load_checkpoint(args.resume)
 
  parser.add_argument("--log-dir", type=str, default=None, help="Override log directory")
  parser.add_argument("--discard-ply-limit", action="store_true",
  help="Only train on games that ended naturally (no ply limit truncation)")
+
+ ckpt_group = parser.add_mutually_exclusive_group(required=True)
+ ckpt_group.add_argument("--hf-repo", type=str, default=None,
+ help="Push checkpoints to this HuggingFace repo (requires HF_TOKEN)")
+ ckpt_group.add_argument("--local-checkpoints", action="store_true",
+ help="Save checkpoints locally only (no HuggingFace push)")
  return parser.parse_args()
 
 
  print(f"Model config: {model_cfg}")
  print(f"Training config: {train_cfg}")
 
+ trainer = CLMTrainer(train_cfg, model_cfg, hf_repo=args.hf_repo)
 
  if args.resume:
  trainer.load_checkpoint(args.resume)
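The required mutually-exclusive group means every invocation must now name a checkpoint destination explicitly. A minimal sketch of the same argparse pattern (flag names copied from the diff; this standalone parser is illustrative, not the script's own):

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--hf-repo", type=str, default=None)
group.add_argument("--local-checkpoints", action="store_true")

args = parser.parse_args(["--local-checkpoints"])
print(args.hf_repo, args.local_checkpoints)  # -> None True

# Supplying both flags (or neither) is rejected by argparse with a usage error.
try:
    parser.parse_args(["--hf-repo", "user/pawn-base", "--local-checkpoints"])
except SystemExit:
    print("rejected")  # -> rejected
```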
scripts/train_all.py ADDED
@@ -0,0 +1,400 @@
+ #!/usr/bin/env python3
+ """Train small, base, and large PAWN models simultaneously on shared data.
+
+ All three models see the exact same batches in the same order, eliminating
+ data generation overhead and ensuring comparable training conditions.
+
+ Usage:
+ uv run python scripts/train_all.py --local-checkpoints
+ uv run python scripts/train_all.py --hf-repo thomas-schweich/pawn-{variant}
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import math
+ import os
+ import signal
+ import sys
+ import time
+
+ import torch
+ import torch.multiprocessing as mp
+ from torch.utils.data import DataLoader
+
+ from pawn.config import CLMConfig, TrainingConfig
+ from pawn.model import PAWNCLM, clm_loss
+ from pawn.data import CLMDataset, create_validation_set
+ from pawn.gpu import configure_gpu, apply_gpu_config
+ from pawn.checkpoint import save_pretrain_checkpoint, push_checkpoint_to_hf
+
+
+ # ---------------------------------------------------------------------------
+ # Per-model state
+ # ---------------------------------------------------------------------------
+
+ class ModelSlot:
+ """Holds everything needed to train and checkpoint one model variant."""
+
+ def __init__(
+ self,
+ name: str,
+ model_cfg: CLMConfig,
+ train_cfg: TrainingConfig,
+ device: str,
+ hf_repo: str | None,
+ ):
+ self.name = name
+ self.model_cfg = model_cfg
+ self.train_cfg = train_cfg
+ self.device = device
+ self.hf_repo = hf_repo
+
+ self.model = PAWNCLM(model_cfg).to(device)
+ param_count = sum(p.numel() for p in self.model.parameters())
+ print(f" {name}: {param_count:,} params ({model_cfg.d_model}d/{model_cfg.n_layers}L)")
+
+ self.optimizer = torch.optim.AdamW(
+ self.model.parameters(),
+ lr=train_cfg.lr,
+ weight_decay=train_cfg.weight_decay,
+ )
+
+ from pawn.trainer import CosineWithWarmup
+ self.scheduler = CosineWithWarmup(
+ self.optimizer,
+ warmup_steps=train_cfg.warmup_steps,
+ total_steps=train_cfg.total_steps,
+ )
+
+ self.scaler = torch.amp.GradScaler(device, enabled=train_cfg.use_amp)
+
+ # Run directory
+ self.run_dir = _make_run_dir(train_cfg.log_dir, name)
+ self.checkpoint_dir = os.path.join(self.run_dir, "checkpoints")
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
+
+ self.jsonl_path = os.path.join(self.run_dir, "metrics.jsonl")
+ self._jsonl_file: open | None = None
+
+ self.hf_branch = f"run/{os.path.basename(self.run_dir)}" if hf_repo else None
+ self.global_step = 0
+
+ # Write config
+ import subprocess
+ try:
+ git_hash = subprocess.check_output(
+ ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL, text=True
+ ).strip()
+ except Exception:
+ git_hash = os.environ.get("PAWN_GIT_HASH")
+ try:
+ git_tag = subprocess.check_output(
+ ["git", "tag", "--points-at", "HEAD"], stderr=subprocess.DEVNULL, text=True
+ ).strip() or None
+ except Exception:
+ git_tag = os.environ.get("PAWN_GIT_TAG")
+
+ config_data = {
+ "model": model_cfg.__dict__,
+ "training": train_cfg.__dict__,
+ "param_count": param_count,
+ "formulation": "clm",
+ "multi_model": True,
+ "variant": name,
+ "git_hash": git_hash,
+ "git_tag": git_tag,
+ }
+ with open(os.path.join(self.run_dir, "config.json"), "w") as f:
+ json.dump(config_data, f, indent=2, default=str)
+ self._log_jsonl({"type": "config", **config_data})
+
+ def train_step(self, batch: dict[str, torch.Tensor]) -> dict[str, float]:
+ self.model.train()
+ input_ids = batch["input_ids"].to(self.device)
+ targets = batch["targets"].to(self.device)
+ loss_mask = batch["loss_mask"].to(self.device)
+
+ with torch.amp.autocast(self.device, enabled=self.train_cfg.use_amp):
+ loss, metrics = self.model.forward_train(input_ids, loss_mask, targets)
+
+ self.scaler.scale(loss).backward()
+ return metrics
+
+ def optimizer_step(self) -> float:
+ self.scaler.unscale_(self.optimizer)
+ grad_norm = torch.nn.utils.clip_grad_norm_(
+ self.model.parameters(), self.train_cfg.max_grad_norm
+ ).item()
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+ self.optimizer.zero_grad(set_to_none=True)
+ self.scheduler.step()
+ return grad_norm
+
+ def save_checkpoint(self):
+ path = os.path.join(self.checkpoint_dir, f"step_{self.global_step:08d}")
+ save_pretrain_checkpoint(
+ path, self.model, self.optimizer, self.scheduler, self.scaler,
+ self.global_step, self.model_cfg.__dict__, self.train_cfg.__dict__,
+ )
+ print(f" [{self.name}] Checkpoint saved: {path}")
+
+ if self.hf_repo and self.hf_branch:
+ try:
+ push_checkpoint_to_hf(
+ path, self.hf_repo, self.hf_branch,
+ metrics_path=self.jsonl_path, step=self.global_step,
+ )
+ print(f" [{self.name}] Pushed to HF: {self.hf_repo}@{self.hf_branch}")
+ except Exception as e:
+ print(f" [{self.name}] WARNING: HF push failed: {e}")
+
+ @torch.no_grad()
+ def evaluate(self, val_data: dict[str, torch.Tensor]) -> dict[str, float]:
+ self.model.eval()
+ n = val_data["input_ids"].shape[0]
+ batch_size = self.train_cfg.batch_size
+ total_metrics: dict[str, float] = {}
+ n_batches = 0
+
+ for start in range(0, n, batch_size):
+ end = min(start + batch_size, n)
+ input_ids = val_data["input_ids"][start:end].to(self.device)
+ targets = val_data["targets"][start:end].to(self.device)
+ loss_mask = val_data["loss_mask"][start:end].to(self.device)
+
+ with torch.amp.autocast(self.device, enabled=self.train_cfg.use_amp):
+ logits, _ = self.model(input_ids, loss_mask)
+ _, metrics = clm_loss(logits, targets, loss_mask)
+
+ for k, v in metrics.items():
+ total_metrics[k] = total_metrics.get(k, 0.0) + v
+ n_batches += 1
+
+ return {f"val/{k}": v / n_batches for k, v in total_metrics.items()}
+
+ def _log_jsonl(self, record: dict):
+ if self._jsonl_file is None:
+ self._jsonl_file = open(self.jsonl_path, "a")
+ self._jsonl_file.write(json.dumps(record, default=str) + "\n")
+ self._jsonl_file.flush()
+
+ def close(self):
+ if self._jsonl_file:
+ self._jsonl_file.close()
+ self._jsonl_file = None
+
+
+ def _make_run_dir(log_dir: str, variant: str) -> str:
+ timestamp = time.strftime("%Y%m%d_%H%M%S")
+ run_dir = os.path.join(log_dir, f"run_{timestamp}_{variant}")
+ os.makedirs(run_dir, exist_ok=True)
+ return run_dir
+
+
+ # ---------------------------------------------------------------------------
+ # CLI
+ # ---------------------------------------------------------------------------
+
+ def parse_args():
+ p = argparse.ArgumentParser(description="Train small/base/large PAWN models simultaneously")
+ p.add_argument("--device", type=str, default=None, help="Device (cuda/cpu)")
+ p.add_argument("--total-steps", type=int, default=100_000, help="Total training steps")
+ p.add_argument("--batch-size", type=int, default=256, help="Batch size (shared across models)")
+ p.add_argument("--num-workers", type=int, default=4, help="DataLoader workers")
+ p.add_argument("--log-dir", type=str, default="logs", help="Log directory")
+ p.add_argument("--log-interval", type=int, default=10)
+ p.add_argument("--eval-interval", type=int, default=500)
+ p.add_argument("--checkpoint-interval", type=int, default=5000)
+ p.add_argument("--discard-ply-limit", action="store_true")
+ p.add_argument("--wandb", action="store_true")
+
+ ckpt_group = p.add_mutually_exclusive_group(required=True)
+ ckpt_group.add_argument("--hf-repo", type=str, default=None,
+ help="HF repo prefix (appends -{variant}). E.g. thomas-schweich/pawn")
+ ckpt_group.add_argument("--local-checkpoints", action="store_true")
+ return p.parse_args()
+
+
+ def main():
+ args = parse_args()
+
+ device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
+ if device == "cuda":
+ gpu_cfg = configure_gpu()
+ import pawn.model as model_module
+ if gpu_cfg.get("sdpa_backend"):
+ model_module.SDPA_BACKEND = gpu_cfg["sdpa_backend"]
+
+ # Build per-variant configs (shared training hyperparams, different model sizes)
+ variants = {
+ "small": CLMConfig.small(),
+ "base": CLMConfig.base(),
+ "large": CLMConfig.large(),
+ }
+
+ print("=== Multi-Model Training ===")
+ print(f"Device: {device}")
+ print(f"Batch size: {args.batch_size}")
+ print(f"Total steps: {args.total_steps}")
+ print()
+
+ slots: list[ModelSlot] = []
+ for name, model_cfg in variants.items():
+ train_cfg = TrainingConfig()
+ train_cfg.total_steps = args.total_steps
+ train_cfg.batch_size = args.batch_size
+ train_cfg.num_workers = args.num_workers
+ train_cfg.device = device
+ train_cfg.log_dir = args.log_dir
+ train_cfg.log_interval = args.log_interval
+ train_cfg.eval_interval = args.eval_interval
+ train_cfg.checkpoint_interval = args.checkpoint_interval
+ train_cfg.discard_ply_limit = args.discard_ply_limit
+ train_cfg.use_wandb = args.wandb
+
+ hf_repo = f"{args.hf_repo}-{name}" if args.hf_repo else None
+ slots.append(ModelSlot(name, model_cfg, train_cfg, device, hf_repo))
+
+ # Shared dataset and validation set
+ max_ply = 256
+ dataset = CLMDataset(
+ args.batch_size, max_ply, base_seed=42,
+ discard_ply_limit=args.discard_ply_limit,
+ )
+
+ print("\nGenerating shared validation set...")
+ val_data = create_validation_set(512, max_ply, seed=(2**63) - 1,
+ discard_ply_limit=args.discard_ply_limit)
+
+ # Compile models
+ if device != "cpu":
+ for slot in slots:
+ try:
+ slot.model = torch.compile(slot.model, mode="default")
+ print(f" [{slot.name}] torch.compile enabled")
+ except Exception:
+ print(f" [{slot.name}] torch.compile not available")
+
+ loader = DataLoader(
+ dataset,
+ batch_size=None,
+ num_workers=args.num_workers,
+ pin_memory=(device != "cpu"),
+ persistent_workers=(args.num_workers > 0),
+ prefetch_factor=1 if args.num_workers > 0 else None,
+ )
+
+ # Signal handling
+ _shutdown_requested = False
+ _shutdown_signal = None
+
+ def _graceful_exit(signum, frame):
+ nonlocal _shutdown_requested, _shutdown_signal
+ _shutdown_requested = True
+ _shutdown_signal = signum
+
+ signal.signal(signal.SIGTERM, _graceful_exit)
+ signal.signal(signal.SIGINT, _graceful_exit)
+
+ # Training loop
+ global_step = 0
+ step_start = time.time()
+
+ print(f"\nStarting training from step 0", flush=True)
+ for slot in slots:
+ print(f" [{slot.name}] JSONL: {slot.jsonl_path}", flush=True)
+ print()
+
+ for batch in loader:
+ # Forward + backward for each model on the same batch
+ all_metrics: dict[str, dict[str, float]] = {}
+ for slot in slots:
315
+ metrics = slot.train_step(batch)
316
+ all_metrics[slot.name] = metrics
317
+
318
+ # Optimizer step for each model
319
+ all_grad_norms: dict[str, float] = {}
320
+ for slot in slots:
321
+ gn = slot.optimizer_step()
322
+ all_grad_norms[slot.name] = gn
323
+
324
+ global_step += 1
325
+ for slot in slots:
326
+ slot.global_step = global_step
327
+
328
+ step_time = time.time() - step_start
329
+ games_per_sec = args.batch_size / step_time
330
+
331
+ # Logging
332
+ if global_step % args.log_interval == 0:
333
+ print(f"step {global_step:>7d} | {games_per_sec:.0f} g/s | {step_time:.2f}s", flush=True)
334
+ for slot in slots:
335
+ m = all_metrics[slot.name]
336
+ gn = all_grad_norms[slot.name]
337
+ lr = slot.scheduler.get_lr()
338
+ print(f" {slot.name:>5s}: loss {m['loss']:.4f} | acc {m['accuracy']:.3f} | "
339
+ f"lr {lr:.2e} | gn {gn:.2f}", flush=True)
340
+
341
+ record = {
342
+ "type": "train",
343
+ "step": global_step,
344
+ "timestamp": time.time(),
345
+ "lr": lr,
346
+ "grad_norm": gn,
347
+ "step_time": step_time,
348
+ "games_per_sec": games_per_sec,
349
+ **{f"train/{k}": v for k, v in m.items()},
350
+ }
351
+ slot._log_jsonl(record)
352
+
353
+ # Eval
354
+ if global_step % args.eval_interval == 0:
355
+ for slot in slots:
356
+ val_metrics = slot.evaluate(val_data)
357
+ print(f" {slot.name:>5s} val: loss {val_metrics['val/loss']:.4f} | "
358
+ f"acc {val_metrics['val/accuracy']:.3f}", flush=True)
359
+ slot._log_jsonl({
360
+ "type": "val",
361
+ "step": global_step,
362
+ "timestamp": time.time(),
363
+ **val_metrics,
364
+ })
365
+
366
+ # Checkpoint
367
+ if global_step % args.checkpoint_interval == 0:
368
+ for slot in slots:
369
+ slot.save_checkpoint()
370
+
371
+ # Done?
372
+ if global_step >= args.total_steps:
373
+ print(f"\nTraining complete at step {global_step}")
374
+ for slot in slots:
375
+ slot.save_checkpoint()
376
+ break
377
+
378
+ # Graceful shutdown
379
+ if _shutdown_requested:
380
+ print(f"\nShutdown requested (signal {_shutdown_signal}), "
381
+ f"saving checkpoints at step {global_step}...")
382
+ for slot in slots:
383
+ slot.save_checkpoint()
384
+ break
385
+
386
+ step_start = time.time()
387
+
388
+ # Cleanup
389
+ for slot in slots:
390
+ slot.close()
391
+
392
+ print("\nAll done.")
393
+
394
+
395
+ if __name__ == "__main__":
396
+ try:
397
+ mp.set_start_method("forkserver", force=True)
398
+ except ValueError:
399
+ mp.set_start_method("spawn", force=True)
400
+ main()
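The graceful-shutdown handling used throughout these scripts (a handler that only flips a flag, so the loop can finish the in-flight step, checkpoint, and break) can be sketched standalone. This is a minimal illustration of the pattern under POSIX signals, not the project's `ModelSlot` machinery:

```python
import os
import signal

# Flag-based shutdown: the handler only flips a bool, so the training
# loop can finish the current step, save checkpoints, then break.
shutdown_requested = False

def _graceful_exit(signum, frame):
    global shutdown_requested
    shutdown_requested = True

signal.signal(signal.SIGTERM, _graceful_exit)
signal.signal(signal.SIGINT, _graceful_exit)

# Simulate an external `kill <pid>` reaching this process.
os.kill(os.getpid(), signal.SIGTERM)
print(shutdown_requested)  # True
```

Because the handler does no I/O itself, checkpoint writes always happen from the main loop at a step boundary, which is what keeps the saved state consistent.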
scripts/train_bottleneck.py CHANGED
@@ -16,6 +16,7 @@ from __future__ import annotations
 import argparse
 import gc
 import math
+import signal
 import time
 from pathlib import Path

@@ -90,15 +91,22 @@ def parse_args():
     p.add_argument("--resume", type=str, default=None,
                    help="Path to checkpoint to resume from (best.pt or final.pt)")

+    ckpt_group = p.add_mutually_exclusive_group(required=True)
+    ckpt_group.add_argument("--hf-repo", type=str, default=None,
+                            help="Push checkpoints to this HuggingFace repo (requires HF_TOKEN)")
+    ckpt_group.add_argument("--local-checkpoints", action="store_true",
+                            help="Save checkpoints locally only")
+
     return p.parse_args()


 def load_backbone(checkpoint_path: str, device: str) -> PAWNCLM:
-    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
-    cfg = CLMConfig(**ckpt["model_config"]) if "model_config" in ckpt else CLMConfig()
+    from pawn.checkpoint import load_backbone_weights
+    state_dict, model_config = load_backbone_weights(checkpoint_path, device)
+    cfg = CLMConfig(**model_config) if model_config else CLMConfig()
     model = PAWNCLM(cfg).to(device)
-    model.load_state_dict(ckpt["model_state_dict"])
-    del ckpt
+    model.load_state_dict(state_dict)
+    del state_dict
     gc.collect()
     model.eval()
     return model

@@ -197,6 +205,10 @@ def main():
     ckpt_dir = out_dir / "checkpoints"
     ckpt_dir.mkdir(exist_ok=True)

+    hf_branch = None
+    if args.hf_repo:
+        hf_branch = f"run/{out_dir.name}"
+
     print(f"Device: {device}")
     print(f"Output: {out_dir}")

@@ -312,17 +324,18 @@ def main():

     if args.resume:
         print(f"\nResuming from: {args.resume}")
-        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
-        model.load_adapter_state_dict(ckpt["bottleneck_state_dict"])
-        if "optimizer_state_dict" in ckpt:
+        from pawn.checkpoint import load_adapter_checkpoint
+        ckpt = load_adapter_checkpoint(args.resume, device=device)
+        model.load_adapter_state_dict(ckpt["adapter_state_dict"])
+        if ckpt.get("optimizer_state_dict"):
             optimizer.load_state_dict(ckpt["optimizer_state_dict"])
-        if "scheduler_state_dict" in ckpt:
+        if ckpt.get("scheduler_state_dict"):
             scheduler.load_state_dict(ckpt["scheduler_state_dict"])
         if scaler and ckpt.get("scaler_state_dict"):
             scaler.load_state_dict(ckpt["scaler_state_dict"])
         start_epoch = ckpt["epoch"] + 1
         global_step = ckpt["step"]
-        best_val_loss = ckpt.get("best_val_loss", ckpt.get("val_loss", float("inf")))
+        best_val_loss = ckpt.get("best_val_loss", ckpt.get("val_metrics", {}).get("loss", float("inf")))
         patience_counter = ckpt.get("patience_counter", 0)
         print(f" Resumed at epoch {start_epoch}, step {global_step}, "
               f"best_val_loss={best_val_loss:.4f}")

@@ -358,6 +371,13 @@ def main():
     val_metrics = evaluate(model, val_loader, mask_builder, device, use_amp=use_amp,
                            precomputed_indices=val_legal_indices) if args.resume else baseline

+    _shutdown_requested = False
+    def _graceful_exit(signum, frame):
+        nonlocal _shutdown_requested
+        _shutdown_requested = True
+    signal.signal(signal.SIGTERM, _graceful_exit)
+    signal.signal(signal.SIGINT, _graceful_exit)
+
     print(f"\nTraining for up to {args.epochs} epochs ({total_steps} steps)")
     print(f" Warmup: {warmup_steps} steps, LR: {args.lr}")

@@ -434,38 +454,56 @@ def main():
             if val_metrics["loss"] < best_val_loss:
                 best_val_loss = val_metrics["loss"]
                 patience_counter = 0
-                torch.save({
-                    "bottleneck_state_dict": model.adapter_state_dict(),
-                    "optimizer_state_dict": optimizer.state_dict(),
-                    "scheduler_state_dict": scheduler.state_dict(),
-                    "scaler_state_dict": scaler.state_dict() if scaler else None,
-                    "epoch": epoch,
-                    "step": global_step,
-                    "val_loss": val_metrics["loss"],
-                    "val_top1": val_metrics["top1_accuracy"],
-                    "best_val_loss": best_val_loss,
-                    "patience_counter": patience_counter,
-                    "config": vars(args),
-                }, ckpt_dir / "best.pt")
+                from pawn.checkpoint import save_adapter_checkpoint
+                save_adapter_checkpoint(
+                    ckpt_dir / "best",
+                    model.adapter_state_dict(),
+                    config=vars(args),
+                    epoch=epoch,
+                    step=global_step,
+                    val_metrics=val_metrics,
+                    optimizer=optimizer,
+                    scheduler=scheduler,
+                    scaler=scaler,
+                    extra={"best_val_loss": best_val_loss, "patience_counter": patience_counter},
+                )
+                if args.hf_repo and hf_branch:
+                    from pawn.checkpoint import push_checkpoint_to_hf
+                    try:
+                        push_checkpoint_to_hf(ckpt_dir / "best", args.hf_repo, hf_branch, step=global_step)
+                        print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
+                    except Exception as e:
+                        print(f"WARNING: HF push failed: {e}")
             else:
                 patience_counter += 1
                 if patience_counter >= args.patience:
                     print(f"\n Early stopping at epoch {epoch} (patience={args.patience})")
                     break

-    torch.save({
-        "bottleneck_state_dict": model.adapter_state_dict(),
-        "optimizer_state_dict": optimizer.state_dict(),
-        "scheduler_state_dict": scheduler.state_dict(),
-        "scaler_state_dict": scaler.state_dict() if scaler else None,
-        "epoch": epoch,
-        "step": global_step,
-        "val_loss": val_metrics["loss"],
-        "val_top1": val_metrics["top1_accuracy"],
-        "best_val_loss": best_val_loss,
-        "patience_counter": patience_counter,
-        "config": vars(args),
-    }, ckpt_dir / "final.pt")
+        if _shutdown_requested:
+            print("Shutdown requested, saving checkpoint...")
+            break
+
+    from pawn.checkpoint import save_adapter_checkpoint
+    save_adapter_checkpoint(
+        ckpt_dir / "final",
+        model.adapter_state_dict(),
+        config=vars(args),
+        epoch=epoch,
+        step=global_step,
+        val_metrics=val_metrics,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        scaler=scaler,
+        extra={"best_val_loss": best_val_loss, "patience_counter": patience_counter},
+    )
+    if args.hf_repo and hf_branch:
+        from pawn.checkpoint import push_checkpoint_to_hf
+        try:
+            push_checkpoint_to_hf(ckpt_dir / "final", args.hf_repo, hf_branch, step=global_step)
+            print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
+        except Exception as e:
+            print(f"WARNING: HF push failed: {e}")

     logger.close()
     print(f"\nDone. Best val_loss={best_val_loss:.4f}")
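The new resume path reads `best_val_loss` with a chained fallback: checkpoints written by `save_adapter_checkpoint` nest metrics under `val_metrics`, while older `torch.save` checkpoints stored flat keys. A small self-contained sketch of that lookup, using hypothetical checkpoint dicts:

```python
# Hypothetical checkpoint dicts: new-style nests metrics under
# "val_metrics"; old-style stored a flat "best_val_loss" key.
new_style = {"val_metrics": {"loss": 0.42}}
old_style = {"best_val_loss": 0.37}
empty = {}

def resume_best_val_loss(ckpt):
    # Chained .get() fallback, same shape as the resume code:
    # flat key first, then nested metrics, then +inf for fresh runs.
    return ckpt.get("best_val_loss",
                    ckpt.get("val_metrics", {}).get("loss", float("inf")))

print(resume_best_val_loss(new_style))  # 0.42
print(resume_best_val_loss(old_style))  # 0.37
print(resume_best_val_loss(empty))      # inf
```

Defaulting to `float("inf")` means a checkpoint with no recorded loss never blocks the first "best" save after resuming.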
scripts/train_film.py CHANGED
@@ -17,6 +17,7 @@ import argparse
 import gc
 import json
 import math
+import signal
 import time
 from pathlib import Path

@@ -80,15 +81,22 @@ def parse_args():
     p.add_argument("--sdpa-math", action="store_true",
                    help="Use MATH SDPA backend (workaround for ROCm flash attn + compile)")

+    ckpt_group = p.add_mutually_exclusive_group(required=True)
+    ckpt_group.add_argument("--hf-repo", type=str, default=None,
+                            help="Push checkpoints to this HuggingFace repo (requires HF_TOKEN)")
+    ckpt_group.add_argument("--local-checkpoints", action="store_true",
+                            help="Save checkpoints locally only")
+
     return p.parse_args()


 def load_backbone(checkpoint_path: str, device: str) -> PAWNCLM:
-    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
-    cfg = CLMConfig(**ckpt["model_config"]) if "model_config" in ckpt else CLMConfig()
+    from pawn.checkpoint import load_backbone_weights
+    state_dict, model_config = load_backbone_weights(checkpoint_path, device)
+    cfg = CLMConfig(**model_config) if model_config else CLMConfig()
     model = PAWNCLM(cfg).to(device)
-    model.load_state_dict(ckpt["model_state_dict"])
-    del ckpt
+    model.load_state_dict(state_dict)
+    del state_dict
     gc.collect()
     model.eval()
     return model

@@ -180,6 +188,10 @@ def main():
     ckpt_dir = out_dir / "checkpoints"
     ckpt_dir.mkdir(exist_ok=True)

+    hf_branch = None
+    if args.hf_repo:
+        hf_branch = f"run/{out_dir.name}"
+
     device = args.device
     print(f"Device: {device}")
     print(f"Output: {out_dir}")

@@ -284,6 +296,13 @@ def main():
     global_step = 0
     val_metrics = baseline  # for first log if val_every > 1

+    _shutdown_requested = False
+    def _graceful_exit(signum, frame):
+        nonlocal _shutdown_requested
+        _shutdown_requested = True
+    signal.signal(signal.SIGTERM, _graceful_exit)
+    signal.signal(signal.SIGINT, _graceful_exit)
+
     print(f"\nTraining for up to {args.epochs} epochs ({total_steps} steps)")
     print(f" Warmup: {warmup_steps} steps, LR: {args.lr}")
     if args.val_every > 1:

@@ -371,29 +390,49 @@ def main():
             if val_metrics["loss"] < best_val_loss:
                 best_val_loss = val_metrics["loss"]
                 patience_counter = 0
-                torch.save({
-                    "film_state_dict": model.film_state_dict(),
-                    "epoch": epoch,
-                    "step": global_step,
-                    "val_loss": val_metrics["loss"],
-                    "val_top1": val_metrics["top1_accuracy"],
-                    "config": vars(args),
-                }, ckpt_dir / "best.pt")
+                from pawn.checkpoint import save_adapter_checkpoint
+                save_adapter_checkpoint(
+                    ckpt_dir / "best",
+                    model.film_state_dict(),
+                    config=vars(args),
+                    epoch=epoch,
+                    step=global_step,
+                    val_metrics=val_metrics,
+                )
+                if args.hf_repo and hf_branch:
+                    from pawn.checkpoint import push_checkpoint_to_hf
+                    try:
+                        push_checkpoint_to_hf(ckpt_dir / "best", args.hf_repo, hf_branch, step=global_step)
+                        print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
+                    except Exception as e:
+                        print(f"WARNING: HF push failed: {e}")
             else:
                 patience_counter += 1
                 if patience_counter >= args.patience:
                     print(f"\n Early stopping at epoch {epoch} (patience={args.patience})")
                     break

+        if _shutdown_requested:
+            print("Shutdown requested, saving checkpoint...")
+            break
+
     # Save final checkpoint
-    torch.save({
-        "film_state_dict": model.film_state_dict(),
-        "epoch": epoch,
-        "step": global_step,
-        "val_loss": val_metrics["loss"],
-        "val_top1": val_metrics["top1_accuracy"],
-        "config": vars(args),
-    }, ckpt_dir / "final.pt")
+    from pawn.checkpoint import save_adapter_checkpoint
+    save_adapter_checkpoint(
+        ckpt_dir / "final",
+        model.film_state_dict(),
+        config=vars(args),
+        epoch=epoch,
+        step=global_step,
+        val_metrics=val_metrics,
+    )
+    if args.hf_repo and hf_branch:
+        from pawn.checkpoint import push_checkpoint_to_hf
+        try:
+            push_checkpoint_to_hf(ckpt_dir / "final", args.hf_repo, hf_branch, step=global_step)
+            print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
+        except Exception as e:
+            print(f"WARNING: HF push failed: {e}")

     print(f"\nDone. Best val_loss={best_val_loss:.4f}")
     print(f"Checkpoints saved to {out_dir}")
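The `ckpt_group` added to each script's `parse_args` is a standard argparse mutually-exclusive required group. Stripped of the surrounding options, its behavior can be sketched as:

```python
import argparse

# Mirrors the ckpt_group added in parse_args(): exactly one checkpoint
# destination must be chosen (omitting both, or passing both, errors out
# with SystemExit before training starts).
p = argparse.ArgumentParser()
ckpt_group = p.add_mutually_exclusive_group(required=True)
ckpt_group.add_argument("--hf-repo", type=str, default=None)
ckpt_group.add_argument("--local-checkpoints", action="store_true")

local = p.parse_args(["--local-checkpoints"])
remote = p.parse_args(["--hf-repo", "user/pawn"])
print(local.local_checkpoints, remote.hf_repo)  # True user/pawn
```

Making the group `required=True` turns "forgot to say where checkpoints go" into an immediate CLI error rather than a silent default, which matters for long unattended runs.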
scripts/train_hybrid.py CHANGED
@@ -18,6 +18,7 @@ import argparse
 import gc
 import json
 import math
+import signal
 import time
 from pathlib import Path

@@ -91,15 +92,22 @@ def parse_args():
     p.add_argument("--sdpa-math", action="store_true",
                    help="Use MATH SDPA backend (workaround for ROCm flash attn + compile)")

+    ckpt_group = p.add_mutually_exclusive_group(required=True)
+    ckpt_group.add_argument("--hf-repo", type=str, default=None,
+                            help="Push checkpoints to this HuggingFace repo (requires HF_TOKEN)")
+    ckpt_group.add_argument("--local-checkpoints", action="store_true",
+                            help="Save checkpoints locally only")
+
     return p.parse_args()


 def load_backbone(checkpoint_path: str, device: str) -> PAWNCLM:
-    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
-    cfg = CLMConfig(**ckpt["model_config"]) if "model_config" in ckpt else CLMConfig()
+    from pawn.checkpoint import load_backbone_weights
+    state_dict, model_config = load_backbone_weights(checkpoint_path, device)
+    cfg = CLMConfig(**model_config) if model_config else CLMConfig()
     model = PAWNCLM(cfg).to(device)
-    model.load_state_dict(ckpt["model_state_dict"])
-    del ckpt
+    model.load_state_dict(state_dict)
+    del state_dict
     gc.collect()
     model.eval()
     return model

@@ -182,6 +190,10 @@ def main():
     ckpt_dir = out_dir / "checkpoints"
     ckpt_dir.mkdir(exist_ok=True)

+    hf_branch = None
+    if args.hf_repo:
+        hf_branch = f"run/{out_dir.name}"
+
     device = args.device
     print(f"Device: {device}")
     print(f"Output: {out_dir}")

@@ -314,6 +326,13 @@ def main():
     global_step = 0
     val_metrics = baseline

+    _shutdown_requested = False
+    def _graceful_exit(signum, frame):
+        nonlocal _shutdown_requested
+        _shutdown_requested = True
+    signal.signal(signal.SIGTERM, _graceful_exit)
+    signal.signal(signal.SIGINT, _graceful_exit)
+
     print(f"\nTraining for up to {args.epochs} epochs ({total_steps} steps)")

     for epoch in range(args.epochs):

@@ -394,28 +413,48 @@ def main():
             if val_metrics["loss"] < best_val_loss:
                 best_val_loss = val_metrics["loss"]
                 patience_counter = 0
-                torch.save({
-                    "adapter_state_dict": model.adapter_state_dict(),
-                    "epoch": epoch,
-                    "step": global_step,
-                    "val_loss": val_metrics["loss"],
-                    "val_top1": val_metrics["top1_accuracy"],
-                    "config": vars(args),
-                }, ckpt_dir / "best.pt")
+                from pawn.checkpoint import save_adapter_checkpoint
+                save_adapter_checkpoint(
+                    ckpt_dir / "best",
+                    model.adapter_state_dict(),
+                    config=vars(args),
+                    epoch=epoch,
+                    step=global_step,
+                    val_metrics=val_metrics,
+                )
+                if args.hf_repo and hf_branch:
+                    from pawn.checkpoint import push_checkpoint_to_hf
+                    try:
+                        push_checkpoint_to_hf(ckpt_dir / "best", args.hf_repo, hf_branch, step=global_step)
+                        print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
+                    except Exception as e:
+                        print(f"WARNING: HF push failed: {e}")
             else:
                 patience_counter += 1
                 if patience_counter >= args.patience:
                     print(f"\n Early stopping at epoch {epoch} (patience={args.patience})")
                     break

-    torch.save({
-        "adapter_state_dict": model.adapter_state_dict(),
-        "epoch": epoch,
-        "step": global_step,
-        "val_loss": val_metrics["loss"],
-        "val_top1": val_metrics["top1_accuracy"],
-        "config": vars(args),
-    }, ckpt_dir / "final.pt")
+        if _shutdown_requested:
+            print("Shutdown requested, saving checkpoint...")
+            break
+
+    from pawn.checkpoint import save_adapter_checkpoint
+    save_adapter_checkpoint(
+        ckpt_dir / "final",
+        model.adapter_state_dict(),
+        config=vars(args),
+        epoch=epoch,
+        step=global_step,
+        val_metrics=val_metrics,
+    )
+    if args.hf_repo and hf_branch:
+        from pawn.checkpoint import push_checkpoint_to_hf
+        try:
+            push_checkpoint_to_hf(ckpt_dir / "final", args.hf_repo, hf_branch, step=global_step)
+            print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
+        except Exception as e:
+            print(f"WARNING: HF push failed: {e}")

     print(f"\nDone. Best val_loss={best_val_loss:.4f}")
     print(f"Checkpoints saved to {out_dir}")
scripts/train_lora.py CHANGED
@@ -17,6 +17,7 @@ import argparse
 import gc
 import json
 import math
+import signal
 import time
 from pathlib import Path

@@ -92,15 +93,22 @@ def parse_args():
     p.add_argument("--sdpa-math", action="store_true",
                    help="Use MATH SDPA backend (workaround for ROCm flash attn + compile)")

+    ckpt_group = p.add_mutually_exclusive_group(required=True)
+    ckpt_group.add_argument("--hf-repo", type=str, default=None,
+                            help="Push checkpoints to this HuggingFace repo (requires HF_TOKEN)")
+    ckpt_group.add_argument("--local-checkpoints", action="store_true",
+                            help="Save checkpoints locally only")
+
     return p.parse_args()


 def load_backbone(checkpoint_path: str, device: str) -> PAWNCLM:
-    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
-    cfg = CLMConfig(**ckpt["model_config"]) if "model_config" in ckpt else CLMConfig()
+    from pawn.checkpoint import load_backbone_weights
+    state_dict, model_config = load_backbone_weights(checkpoint_path, device)
+    cfg = CLMConfig(**model_config) if model_config else CLMConfig()
     model = PAWNCLM(cfg).to(device)
-    model.load_state_dict(ckpt["model_state_dict"])
-    del ckpt
+    model.load_state_dict(state_dict)
+    del state_dict
     gc.collect()
     model.eval()
     return model

@@ -190,6 +198,10 @@ def main():
     ckpt_dir = out_dir / "checkpoints"
     ckpt_dir.mkdir(exist_ok=True)

+    hf_branch = None
+    if args.hf_repo:
+        hf_branch = f"run/{out_dir.name}"
+
     device = args.device
     print(f"Device: {device}")
     print(f"Output: {out_dir}")

@@ -309,6 +321,13 @@ def main():
     global_step = 0
     val_metrics = baseline

+    _shutdown_requested = False
+    def _graceful_exit(signum, frame):
+        nonlocal _shutdown_requested
+        _shutdown_requested = True
+    signal.signal(signal.SIGTERM, _graceful_exit)
+    signal.signal(signal.SIGINT, _graceful_exit)
+
     print(f"\nTraining for up to {args.epochs} epochs ({total_steps} steps)")
     print(f" Warmup: {warmup_steps} steps, LR: {args.lr}")
     if args.val_every > 1:

@@ -395,29 +414,49 @@ def main():
             if val_metrics["loss"] < best_val_loss:
                 best_val_loss = val_metrics["loss"]
                 patience_counter = 0
-                torch.save({
-                    "lora_state_dict": model.lora_state_dict(),
-                    "epoch": epoch,
-                    "step": global_step,
-                    "val_loss": val_metrics["loss"],
-                    "val_top1": val_metrics["top1_accuracy"],
-                    "config": vars(args),
-                }, ckpt_dir / "best.pt")
+                from pawn.checkpoint import save_adapter_checkpoint
+                save_adapter_checkpoint(
+                    ckpt_dir / "best",
+                    model.lora_state_dict(),
+                    config=vars(args),
+                    epoch=epoch,
+                    step=global_step,
+                    val_metrics=val_metrics,
+                )
+                if args.hf_repo and hf_branch:
+                    from pawn.checkpoint import push_checkpoint_to_hf
+                    try:
+                        push_checkpoint_to_hf(ckpt_dir / "best", args.hf_repo, hf_branch, step=global_step)
+                        print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
+                    except Exception as e:
+                        print(f"WARNING: HF push failed: {e}")
             else:
                 patience_counter += 1
                 if patience_counter >= args.patience:
                     print(f"\n Early stopping at epoch {epoch} (patience={args.patience})")
                     break

+        if _shutdown_requested:
+            print("Shutdown requested, saving checkpoint...")
+            break
+
     # Save final checkpoint
-    torch.save({
-        "lora_state_dict": model.lora_state_dict(),
-        "epoch": epoch,
-        "step": global_step,
-        "val_loss": val_metrics["loss"],
-        "val_top1": val_metrics["top1_accuracy"],
-        "config": vars(args),
-    }, ckpt_dir / "final.pt")
+    from pawn.checkpoint import save_adapter_checkpoint
+    save_adapter_checkpoint(
+        ckpt_dir / "final",
+        model.lora_state_dict(),
+        config=vars(args),
+        epoch=epoch,
+        step=global_step,
+        val_metrics=val_metrics,
+    )
+    if args.hf_repo and hf_branch:
+        from pawn.checkpoint import push_checkpoint_to_hf
+        try:
+            push_checkpoint_to_hf(ckpt_dir / "final", args.hf_repo, hf_branch, step=global_step)
+            print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
+        except Exception as e:
+            print(f"WARNING: HF push failed: {e}")

     print(f"\nDone. Best val_loss={best_val_loss:.4f}")
     print(f"Checkpoints saved to {out_dir}")
scripts/train_sparse.py CHANGED
@@ -17,6 +17,7 @@ import argparse
 import gc
 import json
 import math
+import signal
 import time
 from pathlib import Path
 
@@ -83,15 +84,22 @@ def parse_args():
     p.add_argument("--sdpa-math", action="store_true",
                    help="Use MATH SDPA backend (workaround for ROCm flash attn + compile)")
 
+    ckpt_group = p.add_mutually_exclusive_group(required=True)
+    ckpt_group.add_argument("--hf-repo", type=str, default=None,
+                            help="Push checkpoints to this HuggingFace repo (requires HF_TOKEN)")
+    ckpt_group.add_argument("--local-checkpoints", action="store_true",
+                            help="Save checkpoints locally only")
+
     return p.parse_args()
 
 
 def load_backbone(checkpoint_path: str, device: str) -> PAWNCLM:
-    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
-    cfg = CLMConfig(**ckpt["model_config"]) if "model_config" in ckpt else CLMConfig()
+    from pawn.checkpoint import load_backbone_weights
+    state_dict, model_config = load_backbone_weights(checkpoint_path, device)
+    cfg = CLMConfig(**model_config) if model_config else CLMConfig()
     model = PAWNCLM(cfg).to(device)
-    model.load_state_dict(ckpt["model_state_dict"])
-    del ckpt
+    model.load_state_dict(state_dict)
+    del state_dict
     gc.collect()
     model.eval()
     return model
@@ -182,6 +190,10 @@ def main():
     ckpt_dir = out_dir / "checkpoints"
     ckpt_dir.mkdir(exist_ok=True)
 
+    hf_branch = None
+    if args.hf_repo:
+        hf_branch = f"run/{out_dir.name}"
+
     device = args.device
     print(f"Device: {device}")
     print(f"Output: {out_dir}")
@@ -299,6 +311,13 @@ def main():
     global_step = 0
     val_metrics = baseline
 
+    _shutdown_requested = False
+    def _graceful_exit(signum, frame):
+        nonlocal _shutdown_requested
+        _shutdown_requested = True
+    signal.signal(signal.SIGTERM, _graceful_exit)
+    signal.signal(signal.SIGINT, _graceful_exit)
+
     print(f"\nTraining for up to {args.epochs} epochs ({total_steps} steps)")
     print(f"  Warmup: {warmup_steps} steps, LR: {args.lr}")
 
@@ -379,28 +398,48 @@ def main():
         if val_metrics["loss"] < best_val_loss:
             best_val_loss = val_metrics["loss"]
             patience_counter = 0
-            torch.save({
-                "sparse_state_dict": model.sparse_state_dict(),
-                "epoch": epoch,
-                "step": global_step,
-                "val_loss": val_metrics["loss"],
-                "val_top1": val_metrics["top1_accuracy"],
-                "config": vars(args),
-            }, ckpt_dir / "best.pt")
+            from pawn.checkpoint import save_adapter_checkpoint
+            save_adapter_checkpoint(
+                ckpt_dir / "best",
+                model.sparse_state_dict(),
+                config=vars(args),
+                epoch=epoch,
+                step=global_step,
+                val_metrics=val_metrics,
+            )
+            if args.hf_repo and hf_branch:
+                from pawn.checkpoint import push_checkpoint_to_hf
+                try:
+                    push_checkpoint_to_hf(ckpt_dir / "best", args.hf_repo, hf_branch, step=global_step)
+                    print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
+                except Exception as e:
+                    print(f"WARNING: HF push failed: {e}")
         else:
             patience_counter += 1
             if patience_counter >= args.patience:
                 print(f"\n  Early stopping at epoch {epoch} (patience={args.patience})")
                 break
 
+        if _shutdown_requested:
+            print("Shutdown requested, saving checkpoint...")
+            break
+
-    torch.save({
-        "sparse_state_dict": model.sparse_state_dict(),
-        "epoch": epoch,
-        "step": global_step,
-        "val_loss": val_metrics["loss"],
-        "val_top1": val_metrics["top1_accuracy"],
-        "config": vars(args),
-    }, ckpt_dir / "final.pt")
+    from pawn.checkpoint import save_adapter_checkpoint
+    save_adapter_checkpoint(
+        ckpt_dir / "final",
+        model.sparse_state_dict(),
+        config=vars(args),
+        epoch=epoch,
+        step=global_step,
+        val_metrics=val_metrics,
+    )
+    if args.hf_repo and hf_branch:
+        from pawn.checkpoint import push_checkpoint_to_hf
+        try:
+            push_checkpoint_to_hf(ckpt_dir / "final", args.hf_repo, hf_branch, step=global_step)
+            print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
+        except Exception as e:
+            print(f"WARNING: HF push failed: {e}")
 
     print(f"\nDone. Best val_loss={best_val_loss:.4f}")
     print(f"Checkpoints saved to {out_dir}")
scripts/train_tiny.py CHANGED
@@ -15,6 +15,7 @@ from __future__ import annotations
 
 import argparse
 import math
+import signal
 import time
 from pathlib import Path
 
@@ -228,6 +229,12 @@ def parse_args():
     p.add_argument("--no-compile", action="store_true")
     p.add_argument("--num-workers", type=int, default=3)
 
+    ckpt_group = p.add_mutually_exclusive_group(required=True)
+    ckpt_group.add_argument("--hf-repo", type=str, default=None,
+                            help="Push checkpoints to this HuggingFace repo (requires HF_TOKEN)")
+    ckpt_group.add_argument("--local-checkpoints", action="store_true",
+                            help="Save checkpoints locally only")
+
     return p.parse_args()
 
 
@@ -255,6 +262,10 @@ def main():
     ckpt_dir = out_dir / "checkpoints"
     ckpt_dir.mkdir(exist_ok=True)
 
+    hf_branch = None
+    if args.hf_repo:
+        hf_branch = f"run/{out_dir.name}"
+
     # Build model
     model = TinyChessLM(
         vocab_size=vocab_size,
@@ -358,6 +369,13 @@ def main():
     global_step = 0
     val_metrics = baseline
 
+    _shutdown_requested = False
+    def _graceful_exit(signum, frame):
+        nonlocal _shutdown_requested
+        _shutdown_requested = True
+    signal.signal(signal.SIGTERM, _graceful_exit)
+    signal.signal(signal.SIGINT, _graceful_exit)
+
     print(f"\nTraining for up to {args.epochs} epochs ({total_steps} steps)")
     print(f"  Warmup: {warmup_steps} steps, LR: {args.lr}")
 
@@ -428,24 +446,39 @@ def main():
         if val_metrics["loss"] < best_val_loss:
             best_val_loss = val_metrics["loss"]
             patience_counter = 0
-            torch.save({
-                "model_state_dict": model.state_dict(),
-                "optimizer_state_dict": optimizer.state_dict(),
-                "scheduler_state_dict": scheduler.state_dict(),
-                "scaler_state_dict": scaler.state_dict() if scaler else None,
-                "epoch": epoch, "step": global_step,
-                "val_loss": val_metrics["loss"],
-                "val_top1": val_metrics["top1_accuracy"],
-                "best_val_loss": best_val_loss,
-                "patience_counter": patience_counter,
-                "config": vars(args),
-            }, ckpt_dir / "best.pt")
+            from pawn.checkpoint import save_pretrain_checkpoint
+            save_pretrain_checkpoint(
+                ckpt_dir / "best",
+                model=model,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                scaler=scaler,
+                global_step=global_step,
+                model_config={
+                    "d_model": args.d_model,
+                    "n_layers": args.n_layers,
+                    "n_heads": args.n_heads,
+                    "d_ff": args.d_ff,
+                },
+                training_config=vars(args),
+            )
+            if args.hf_repo and hf_branch:
+                from pawn.checkpoint import push_checkpoint_to_hf
+                try:
+                    push_checkpoint_to_hf(ckpt_dir / "best", args.hf_repo, hf_branch, step=global_step)
+                    print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
+                except Exception as e:
+                    print(f"WARNING: HF push failed: {e}")
         else:
             patience_counter += 1
             if patience_counter >= args.patience:
                 print(f"\n  Early stopping at epoch {epoch} (patience={args.patience})")
                 break
 
+        if _shutdown_requested:
+            print("Shutdown requested, saving checkpoint...")
+            break
+
     logger.close()
     print(f"\nDone. Best val_loss={best_val_loss:.4f}")
     print(f"Checkpoints saved to {out_dir}")
tests/test_checkpoint.py ADDED
@@ -0,0 +1,359 @@
+"""Tests for pawn.checkpoint safetensors save/load."""
+
+import json
+import tempfile
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+
+from pawn.config import CLMConfig
+from pawn.model import PAWNCLM
+import pytest
+
+from pawn.checkpoint import (
+    save_pretrain_checkpoint,
+    load_pretrain_checkpoint,
+    save_adapter_checkpoint,
+    load_adapter_checkpoint,
+    load_backbone_weights,
+    is_legacy_checkpoint,
+    IncompleteCheckpointError,
+    CheckpointIntegrityError,
+    _flatten_optimizer_state,
+    _unflatten_optimizer_state,
+    _rng_to_json,
+    _json_to_rng,
+    _verify_complete_sentinel,
+)
+
+
+def _make_toy_model(device="cpu"):
+    cfg = CLMConfig.toy()
+    model = PAWNCLM(cfg).to(device)
+    return model, cfg
+
+
+def _make_optimizer(model):
+    return torch.optim.AdamW(model.parameters(), lr=1e-3)
+
+
+class FakeScheduler:
+    """Minimal scheduler stand-in for testing."""
+    def __init__(self):
+        self._step = 0
+
+    def state_dict(self):
+        return {"step": self._step}
+
+    def load_state_dict(self, d):
+        self._step = d["step"]
+
+
+class FakeScaler:
+    """Minimal scaler stand-in for testing."""
+    def __init__(self):
+        self._scale = 65536.0
+
+    def state_dict(self):
+        return {"scale": self._scale, "_growth_tracker": 0}
+
+    def load_state_dict(self, d):
+        self._scale = d["scale"]
+
+
+def test_pretrain_checkpoint_roundtrip():
+    """Save and load a pretrain checkpoint, verify model weights match."""
+    model1, cfg = _make_toy_model()
+    opt1 = _make_optimizer(model1)
+    sched1 = FakeScheduler()
+    sched1._step = 42
+    scaler1 = FakeScaler()
+
+    # Run a fake step to populate optimizer state
+    x = torch.randint(0, cfg.vocab_size, (2, cfg.max_seq_len))
+    mask = torch.ones(2, cfg.max_seq_len, dtype=torch.bool)
+    targets = torch.randint(0, cfg.vocab_size, (2, cfg.max_seq_len))
+    loss, _ = model1.forward_train(x, mask, targets)
+    loss.backward()
+    opt1.step()
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        ckpt_path = Path(tmpdir) / "step_00000001"
+        save_pretrain_checkpoint(
+            ckpt_path, model1, opt1, sched1, scaler1,
+            global_step=1,
+            model_config=cfg.__dict__,
+            training_config={"lr": 1e-3},
+        )
+
+        # Verify files exist
+        assert (ckpt_path / "model.safetensors").exists()
+        assert (ckpt_path / "optimizer.safetensors").exists()
+        assert (ckpt_path / "training_state.json").exists()
+        assert (ckpt_path / "config.json").exists()
+
+        # Load into fresh model
+        model2, _ = _make_toy_model()
+        opt2 = _make_optimizer(model2)
+        sched2 = FakeScheduler()
+        scaler2 = FakeScaler()
+
+        meta = load_pretrain_checkpoint(
+            ckpt_path, model2, opt2, sched2, scaler2
+        )
+
+        assert meta["global_step"] == 1
+        assert meta["model_config"]["d_model"] == cfg.d_model
+        assert sched2._step == 42
+
+        # Verify model weights match
+        for k in model1.state_dict():
+            torch.testing.assert_close(
+                model1.state_dict()[k],
+                model2.state_dict()[k],
+            )
+
+
+def test_optimizer_flatten_unflatten():
+    """Verify optimizer state survives flatten/unflatten."""
+    model, cfg = _make_toy_model()
+    opt = _make_optimizer(model)
+
+    # Populate optimizer state
+    x = torch.randint(0, cfg.vocab_size, (2, cfg.max_seq_len))
+    mask = torch.ones(2, cfg.max_seq_len, dtype=torch.bool)
+    targets = torch.randint(0, cfg.vocab_size, (2, cfg.max_seq_len))
+    loss, _ = model.forward_train(x, mask, targets)
+    loss.backward()
+    opt.step()
+
+    orig_state = opt.state_dict()
+    tensors, meta = _flatten_optimizer_state(orig_state)
+
+    # All tensors should be named "state.{id}.{key}"
+    for k in tensors:
+        assert k.startswith("state."), f"Unexpected key: {k}"
+
+    restored = _unflatten_optimizer_state(tensors, meta)
+
+    # Verify param_groups match
+    assert len(restored["param_groups"]) == len(orig_state["param_groups"])
+
+    # Verify state tensors match
+    for param_id in orig_state["state"]:
+        for key in ("exp_avg", "exp_avg_sq"):
+            torch.testing.assert_close(
+                orig_state["state"][param_id][key],
+                restored["state"][param_id][key],
+            )
+
+
+def test_rng_roundtrip():
+    """Verify RNG state survives JSON serialization."""
+    torch_rng = torch.get_rng_state()
+    cuda_rng = None
+
+    data = _rng_to_json(torch_rng, cuda_rng)
+    assert "torch_rng_state" in data
+    assert "cuda_rng_state" not in data
+
+    restored_torch, restored_cuda = _json_to_rng(data)
+    assert restored_cuda is None
+    assert torch.equal(torch_rng, restored_torch)
+
+
+def test_adapter_checkpoint_roundtrip():
+    """Save and load an adapter checkpoint."""
+    adapter_weights = {
+        "down.weight": torch.randn(8, 64),
+        "up.weight": torch.zeros(64, 8),
+    }
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        ckpt_path = Path(tmpdir) / "best"
+        save_adapter_checkpoint(
+            ckpt_path,
+            adapter_weights,
+            config={"bottleneck_dim": 8, "checkpoint_type": "bottleneck"},
+            epoch=5,
+            step=1000,
+            val_metrics={"loss": 3.14, "top1_accuracy": 0.07},
+            extra={"best_val_loss": 3.10, "patience_counter": 2},
+        )
+
+        assert (ckpt_path / "adapter.safetensors").exists()
+        assert (ckpt_path / "config.json").exists()
+        assert (ckpt_path / "training_state.json").exists()
+        assert not (ckpt_path / "optimizer.safetensors").exists()  # no optimizer passed
+
+        loaded = load_adapter_checkpoint(ckpt_path)
+        assert loaded["epoch"] == 5
+        assert loaded["step"] == 1000
+        assert loaded["val_metrics"]["loss"] == 3.14
+        assert loaded["best_val_loss"] == 3.10
+
+        for k in adapter_weights:
+            torch.testing.assert_close(adapter_weights[k], loaded["adapter_state_dict"][k])
+
+
+def test_load_backbone_weights_new_format():
+    """load_backbone_weights works with new directory format."""
+    model1, cfg = _make_toy_model()
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        ckpt_path = Path(tmpdir) / "step_00000001"
+        opt = _make_optimizer(model1)
+        save_pretrain_checkpoint(
+            ckpt_path, model1, opt, FakeScheduler(), FakeScaler(),
+            global_step=1, model_config=cfg.__dict__, training_config={},
+        )
+
+        state_dict, model_config = load_backbone_weights(ckpt_path)
+        assert model_config["d_model"] == cfg.d_model
+        for k in model1.state_dict():
+            torch.testing.assert_close(model1.state_dict()[k], state_dict[k])
+
+
+def test_load_backbone_weights_legacy():
+    """load_backbone_weights works with legacy .pt files."""
+    model1, cfg = _make_toy_model()
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        pt_path = Path(tmpdir) / "model.pt"
+        torch.save({
+            "model_state_dict": model1.state_dict(),
+            "model_config": cfg.__dict__,
+        }, pt_path)
+
+        state_dict, model_config = load_backbone_weights(pt_path)
+        assert model_config["d_model"] == cfg.d_model
+        for k in model1.state_dict():
+            torch.testing.assert_close(model1.state_dict()[k], state_dict[k])
+
+
+def test_is_legacy_checkpoint():
+    """Detect legacy vs new format."""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        pt = Path(tmpdir) / "step.pt"
+        pt.touch()
+        assert is_legacy_checkpoint(pt)
+
+        d = Path(tmpdir) / "step_00001"
+        d.mkdir()
+        assert not is_legacy_checkpoint(d)
+
+
+def test_config_json_contents():
+    """Verify config.json has expected structure."""
+    model, cfg = _make_toy_model()
+    opt = _make_optimizer(model)
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        ckpt_path = Path(tmpdir) / "step_00000001"
+        save_pretrain_checkpoint(
+            ckpt_path, model, opt, FakeScheduler(), FakeScaler(),
+            global_step=1, model_config=cfg.__dict__,
+            training_config={"lr": 1e-3, "batch_size": 32},
+        )
+
+        with open(ckpt_path / "config.json") as f:
+            config = json.load(f)
+
+        assert config["format_version"] == 1
+        assert config["checkpoint_type"] == "pretrain"
+        assert config["model_config"]["d_model"] == cfg.d_model
+        assert config["training_config"]["lr"] == 1e-3
+
+
+def test_complete_sentinel_written():
+    """Verify .complete sentinel is created with SHA-256 hashes."""
+    model, cfg = _make_toy_model()
+    opt = _make_optimizer(model)
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        ckpt_path = Path(tmpdir) / "step_00000001"
+        save_pretrain_checkpoint(
+            ckpt_path, model, opt, FakeScheduler(), FakeScaler(),
+            global_step=1, model_config=cfg.__dict__, training_config={},
+        )
+
+        sentinel_path = ckpt_path / ".complete"
+        assert sentinel_path.exists()
+
+        with open(sentinel_path) as f:
+            sentinel = json.load(f)
+
+        assert "files" in sentinel
+        assert "model.safetensors" in sentinel["files"]
+        assert "config.json" in sentinel["files"]
+        assert "training_state.json" in sentinel["files"]
+        # Each hash should be a 64-char hex string
+        for name, h in sentinel["files"].items():
+            assert len(h) == 64, f"Bad hash length for {name}: {len(h)}"
+
+
+def test_no_tmp_directory_after_save():
+    """Verify .tmp directory is cleaned up after successful save."""
+    model, cfg = _make_toy_model()
+    opt = _make_optimizer(model)
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        ckpt_path = Path(tmpdir) / "step_00000001"
+        save_pretrain_checkpoint(
+            ckpt_path, model, opt, FakeScheduler(), FakeScaler(),
+            global_step=1, model_config=cfg.__dict__, training_config={},
+        )
+
+        tmp_path = Path(tmpdir) / "step_00000001.tmp"
+        assert not tmp_path.exists(), ".tmp directory should be removed after rename"
+
+
+def test_incomplete_checkpoint_raises():
+    """Loading a checkpoint without .complete raises IncompleteCheckpointError."""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        ckpt_path = Path(tmpdir) / "step_bad"
+        ckpt_path.mkdir()
+        # Write some files but no .complete
+        (ckpt_path / "model.safetensors").touch()
+        (ckpt_path / "config.json").write_text("{}")
+
+        with pytest.raises(IncompleteCheckpointError):
+            _verify_complete_sentinel(ckpt_path)
+
+
+def test_corrupted_file_raises():
+    """Loading a checkpoint with a tampered file raises CheckpointIntegrityError."""
+    model, cfg = _make_toy_model()
+    opt = _make_optimizer(model)
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        ckpt_path = Path(tmpdir) / "step_00000001"
+        save_pretrain_checkpoint(
+            ckpt_path, model, opt, FakeScheduler(), FakeScaler(),
+            global_step=1, model_config=cfg.__dict__, training_config={},
+        )
+
+        # Corrupt a file after save
+        config_path = ckpt_path / "config.json"
+        config_path.write_text("CORRUPTED DATA")
+
+        with pytest.raises(CheckpointIntegrityError):
+            _verify_complete_sentinel(ckpt_path)
+
+
+def test_adapter_checkpoint_has_sentinel():
+    """Adapter checkpoints also get .complete sentinel."""
+    adapter_weights = {"down.weight": torch.randn(8, 64)}
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        ckpt_path = Path(tmpdir) / "best"
+        save_adapter_checkpoint(
+            ckpt_path, adapter_weights,
+            config={"checkpoint_type": "bottleneck"},
+            epoch=1, step=100, val_metrics={"loss": 3.0},
+        )
+
+        assert (ckpt_path / ".complete").exists()
+        # Should not raise
+        _verify_complete_sentinel(ckpt_path)
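The integrity scheme these tests exercise can be sketched in isolation. The assumptions here come straight from the assertions above: the sentinel is a JSON file named `.complete` whose `"files"` map records a 64-char SHA-256 hex digest per checkpoint file, a missing sentinel means an incomplete checkpoint, and a digest mismatch means corruption. The function names `write_sentinel`/`verify_sentinel` are illustrative, not the actual `pawn.checkpoint` API:

```python
import hashlib
import json
from pathlib import Path

def sha256_file(path: Path) -> str:
    """Stream a file through SHA-256 in 1 MiB chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

def write_sentinel(ckpt_dir: Path) -> None:
    """Record a digest for every checkpoint file; written last, so its
    presence implies every other file was fully written."""
    files = {p.name: sha256_file(p)
             for p in ckpt_dir.iterdir() if p.name != ".complete"}
    (ckpt_dir / ".complete").write_text(json.dumps({"files": files}))

def verify_sentinel(ckpt_dir: Path) -> None:
    """Raise if the sentinel is missing or any recorded digest mismatches."""
    sentinel = ckpt_dir / ".complete"
    if not sentinel.exists():
        raise FileNotFoundError(f"incomplete checkpoint: {ckpt_dir}")
    recorded = json.loads(sentinel.read_text())["files"]
    for name, digest in recorded.items():
        if sha256_file(ckpt_dir / name) != digest:
            raise ValueError(f"integrity check failed for {name}")
```

Combined with the save-to-`.tmp`-then-rename behavior checked by `test_no_tmp_directory_after_save`, a reader of the checkpoint directory either sees no checkpoint at all or a complete, verifiable one.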
tests/test_clm_format.py ADDED
@@ -0,0 +1,224 @@
+"""Tests for CLM sequence format, off-by-one boundaries, and engine/Python consistency."""
+
+import numpy as np
+import torch
+import pytest
+
+import chess_engine as engine
+
+from pawn.config import (
+    PAD_TOKEN,
+    OUTCOME_TOKEN_BASE,
+    WHITE_CHECKMATES,
+    BLACK_CHECKMATES,
+    STALEMATE,
+    DRAW_BY_RULE,
+    PLY_LIMIT,
+)
+from pawn.data import (
+    _map_termination_to_outcome,
+    pack_clm_sequences,
+    _to_clm_batch,
+)
+
+
+# ---------------------------------------------------------------------------
+# Vocabulary consistency
+# ---------------------------------------------------------------------------
+
+
+class TestVocabConsistency:
+    def test_vocab_size(self):
+        """Engine move vocab + PAD + outcomes = model vocab_size."""
+        vocab = engine.export_move_vocabulary()
+        n_moves = len(vocab["token_to_move"])
+        # 4096 grid + 176 promotions = 4272 move tokens
+        assert n_moves == 4272
+        # 1 PAD + 4272 moves + 5 outcomes = 4278
+        from pawn.config import CLMConfig
+        assert CLMConfig().vocab_size == 1 + n_moves + 5
+
+    def test_outcome_tokens_not_in_move_vocab(self):
+        """Outcome tokens (4273-4277) must not appear in the move vocabulary."""
+        vocab = engine.export_move_vocabulary()
+        for token_id in range(OUTCOME_TOKEN_BASE, PLY_LIMIT + 1):
+            assert token_id not in vocab["token_to_move"], \
+                f"Outcome token {token_id} should not be in move vocab"
+
+    def test_no_eog_in_raw_move_ids(self):
+        """Raw move_ids from generate_random_games should not contain
+        tokens >= OUTCOME_BASE (no EOG or outcome tokens in move data)."""
+        move_ids, game_lengths, _tc = engine.generate_random_games(64, 255, seed=42)
+        for b in range(64):
+            gl = int(game_lengths[b])
+            # Moves should be in range 1-4272
+            for t in range(gl):
+                tok = int(move_ids[b, t])
+                assert 1 <= tok <= 4272, \
+                    f"Game {b}, ply {t}: expected move token, got {tok}"
+            # Position game_length should be PAD (0), not EOG
+            if gl < 255:
+                assert move_ids[b, gl] == PAD_TOKEN, \
+                    f"Game {b}: position {gl} should be PAD, got {move_ids[b, gl]}"
+
+
+# ---------------------------------------------------------------------------
+# CLM batch format (Rust engine)
+# ---------------------------------------------------------------------------
+
+
+class TestRustCLMBatch:
+    @pytest.fixture
+    def clm_batch(self):
+        return engine.generate_clm_batch(32, 256, seed=42)
+
+    def test_shapes(self, clm_batch):
+        input_ids, targets, loss_mask, move_ids, game_lengths, term_codes = clm_batch
+        assert input_ids.shape == (32, 256)
+        assert targets.shape == (32, 256)
+        assert loss_mask.shape == (32, 256)
+        assert move_ids.shape == (32, 255)
+        assert game_lengths.shape == (32,)
+        assert term_codes.shape == (32,)
+
+    def test_position_zero_is_outcome(self, clm_batch):
+        input_ids, *_ = clm_batch
+        for b in range(input_ids.shape[0]):
+            tok = int(input_ids[b, 0])
+            assert OUTCOME_TOKEN_BASE <= tok <= PLY_LIMIT, \
+                f"Game {b}: position 0 should be outcome token, got {tok}"
+
+    def test_moves_in_valid_range(self, clm_batch):
+        input_ids, _, _, _, game_lengths, _ = clm_batch
+        for b in range(input_ids.shape[0]):
+            gl = int(game_lengths[b])
+            for t in range(1, gl + 1):
+                tok = int(input_ids[b, t])
+                assert 1 <= tok <= 4272, \
+                    f"Game {b}, position {t}: expected move token, got {tok}"
+
+    def test_padding_is_zero(self, clm_batch):
+        input_ids, _, _, _, game_lengths, _ = clm_batch
+        for b in range(input_ids.shape[0]):
+            gl = int(game_lengths[b])
+            for t in range(gl + 1, 256):
+                assert input_ids[b, t] == 0, \
+                    f"Game {b}, position {t}: expected PAD, got {input_ids[b, t]}"
+
+    def test_target_shift_correct(self, clm_batch):
+        input_ids, targets, *_ = clm_batch
+        B, T = input_ids.shape
+        for b in range(B):
+            for t in range(T - 1):
+                assert targets[b, t] == input_ids[b, t + 1], \
+                    f"Game {b}: targets[{t}]={targets[b, t]} != input_ids[{t+1}]={input_ids[b, t+1]}"
+            assert targets[b, T - 1] == 0, "Last target should be PAD"
+
+    def test_target_at_game_end_is_pad(self, clm_batch):
+        _, targets, _, _, game_lengths, _ = clm_batch
+        for b in range(targets.shape[0]):
+            gl = int(game_lengths[b])
+            assert targets[b, gl] == 0, \
+                f"Game {b}: target at game_length={gl} should be PAD, got {targets[b, gl]}"
+
+    def test_loss_mask_boundary(self, clm_batch):
+        _, _, loss_mask, _, game_lengths, _ = clm_batch
+        for b in range(loss_mask.shape[0]):
+            gl = int(game_lengths[b])
+            # True for positions 0..=gl
+            for t in range(gl + 1):
+                assert loss_mask[b, t], \
+                    f"Game {b}: loss_mask[{t}] should be True (gl={gl})"
+            # False after gl
+            for t in range(gl + 1, 256):
+                assert not loss_mask[b, t], \
+                    f"Game {b}: loss_mask[{t}] should be False (gl={gl})"
+
+    def test_raw_move_ids_replayable(self, clm_batch):
+        """Raw move_ids from generate_clm_batch should work with replay functions."""
+        _, _, _, move_ids, game_lengths, _ = clm_batch
+        # validate_games should confirm all games are legal
+        is_valid, first_illegal = engine.validate_games(move_ids, game_lengths)
+        assert all(is_valid), "All generated games should be valid"
+        # compute_legal_move_masks should not error
+        grid, promo = engine.compute_legal_move_masks(move_ids, game_lengths)
+        assert grid.shape[0] == 32
+
+
+# ---------------------------------------------------------------------------
+# Rust CLM matches Python pack_clm_sequences
+# ---------------------------------------------------------------------------
+
+
+class TestRustPythonConsistency:
+    def test_rust_clm_matches_python_pack(self):
+        """Rust generate_clm_batch should produce identical output to
+        Python _to_clm_batch with the same seed."""
+        seq_len = 256
+        seed = 42
+        B = 16
+
+        # Rust path
+        r_input_ids, r_targets, r_loss_mask, r_move_ids, r_gl, r_tc = \
+            engine.generate_clm_batch(B, seq_len, seed)
+
+        # Python path: generate raw + pack
+        py_batch = _to_clm_batch(r_move_ids, r_gl, r_tc, seq_len)
+
+        # Compare
+        r_input_ids_t = torch.from_numpy(r_input_ids).long()
+        r_targets_t = torch.from_numpy(r_targets).long()
+        r_loss_mask_t = torch.from_numpy(r_loss_mask)
+
+        assert torch.equal(r_input_ids_t, py_batch["input_ids"]), \
+            "input_ids mismatch between Rust and Python"
+        assert torch.equal(r_targets_t, py_batch["targets"]), \
+            "targets mismatch between Rust and Python"
+        assert torch.equal(r_loss_mask_t, py_batch["loss_mask"]), \
+            "loss_mask mismatch between Rust and Python"
+
+
+# ---------------------------------------------------------------------------
+# Boundary / edge cases
+# ---------------------------------------------------------------------------
+
+
+class TestBoundaryCases:
+    def test_boundary_max_length_game(self):
+        """Test with games that may fill all 255 move slots.
+
+        We generate many games and check that any near-max-length game
+        is correctly formatted (no off-by-one at the boundary).
+        """
+        input_ids, targets, loss_mask, _, game_lengths, _ = \
+            engine.generate_clm_batch(256, 256, seed=123)
+
+        for b in range(256):
+            gl = int(game_lengths[b])
+            # Regardless of length, format invariants hold
+            assert OUTCOME_TOKEN_BASE <= input_ids[b, 0] <= PLY_LIMIT
+            if gl + 1 < 256:
+                assert input_ids[b, gl + 1] == 0
+            assert loss_mask[b, gl] == True
+            if gl + 1 < 256:
+                assert loss_mask[b, gl + 1] == False
+
+    def test_discard_ply_limit(self):
+        """generate_clm_batch with discard_ply_limit should only have
+        naturally-terminated games (no PLY_LIMIT outcome)."""
+        input_ids, _, _, _, _, term_codes = \
+            engine.generate_clm_batch(32, 256, seed=42, discard_ply_limit=True)
+
+        for b in range(32):
+            assert term_codes[b] != 5, \
+                f"Game {b} has PLY_LIMIT termination but discard_ply_limit=True"
+            assert input_ids[b, 0] != PLY_LIMIT, \
+                f"Game {b} has PLY_LIMIT outcome token but discard_ply_limit=True"
+
+    def test_determinism(self):
+        """Same seed should produce identical results."""
+        r1 = engine.generate_clm_batch(8, 256, seed=99)
+        r2 = engine.generate_clm_batch(8, 256, seed=99)
+        assert np.array_equal(r1[0], r2[0]), "input_ids should be deterministic"
+        assert np.array_equal(r1[1], r2[1]), "targets should be deterministic"
+        assert np.array_equal(r1[2], r2[2]), "loss_mask should be deterministic"
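The format invariants these tests pin down (outcome token at position 0, then the moves, zero padding, targets shifted left by one, loss on positions 0..=game_length) can be restated as a tiny single-game packer. This is a sketch built only from the asserted invariants; `pack_clm_sequence` is a hypothetical name, and the project's actual packing lives in the Rust `generate_clm_batch` and `pawn.data._to_clm_batch`:

```python
import numpy as np

def pack_clm_sequence(outcome_token, moves, seq_len):
    """Pack one game into (input_ids, targets, loss_mask) satisfying the
    invariants the tests check: outcome-first conditioning, PAD=0 fill,
    next-token targets, and loss over positions 0..=game_length."""
    gl = len(moves)
    input_ids = np.zeros(seq_len, dtype=np.int64)
    input_ids[0] = outcome_token          # game outcome conditions the sequence
    input_ids[1:1 + gl] = moves           # move tokens at positions 1..=gl

    targets = np.zeros(seq_len, dtype=np.int64)
    targets[:-1] = input_ids[1:]          # targets[t] == input_ids[t + 1]

    loss_mask = np.zeros(seq_len, dtype=bool)
    loss_mask[:gl + 1] = True             # loss on 0..=gl; padding excluded
    return input_ids, targets, loss_mask
```

Note how the boundary tests fall out of this shape: the target at position `gl` is `input_ids[gl + 1]`, which is PAD, so the model learns to predict "game over" as PAD at the last loss-bearing position.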
tests/test_graceful_shutdown.py ADDED
@@ -0,0 +1,130 @@
+"""Regression test: training processes exit gracefully on SIGTERM.
+
+Spawns a toy training run as a subprocess, sends SIGTERM, and verifies
+that the process exits cleanly with valid checkpoints.
+"""
+
+import json
+import os
+import signal
+import subprocess
+import sys
+import tempfile
+import time
+from pathlib import Path
+
+import pytest
+
+from pawn.checkpoint import load_backbone_weights, _verify_complete_sentinel
+
+
+@pytest.fixture
+def train_tmpdir():
+    with tempfile.TemporaryDirectory() as d:
+        yield Path(d)
+
+
+def _wait_for_training_start(proc: subprocess.Popen, timeout: float = 30) -> None:
+    """Wait until the training process prints 'Starting training'."""
+    deadline = time.time() + timeout
+    assert proc.stdout is not None
+    while time.time() < deadline:
+        line = proc.stdout.readline()
+        if not line:
+            time.sleep(0.1)
+            continue
+        text = line.decode("utf-8", errors="replace")
+        if "Starting training" in text:
+            return
+    raise TimeoutError("Training process did not start within timeout")
+
+
+def test_sigterm_produces_valid_checkpoint(train_tmpdir):
+    """Spawn a toy training run, send SIGTERM, verify checkpoint is valid."""
+    ckpt_dir = train_tmpdir / "checkpoints"
+    log_dir = train_tmpdir / "logs"
+
+    proc = subprocess.Popen(
+        [
+            sys.executable, "scripts/train.py",
+            "--toy",
+            "--local-checkpoints",
+            "--total-steps", "5000",
+            "--device", "cpu",
+            "--num-workers", "0",
+            "--checkpoint-dir", str(ckpt_dir),
+            "--log-dir", str(log_dir),
+        ],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        env={**os.environ, "PYTHONUNBUFFERED": "1"},
+    )
+
+    try:
+        _wait_for_training_start(proc)
+
+        # Let it train for a few checkpoints
+        time.sleep(5)
+
+        # Send SIGTERM
+        proc.send_signal(signal.SIGTERM)
+
+        # Wait for graceful exit
+        returncode = proc.wait(timeout=30)
+    except Exception:
+        proc.kill()
+        proc.wait()
+        raise
+
+    # Process should exit cleanly (0 from break, not 128+signal from sys.exit)
+    assert returncode == 0, f"Training exited with code {returncode}"
+
+    # Find checkpoints (trainer saves under {log_dir}/run_*/checkpoints/)
+    checkpoints = sorted(log_dir.glob("run_*/checkpoints/step_*/"))
+    assert len(checkpoints) > 0, "No checkpoints were saved"
+
+    # Every checkpoint must have .complete sentinel and pass integrity check
+    for ckpt in checkpoints:
+        assert (ckpt / ".complete").exists(), f"Missing .complete in {ckpt.name}"
+        _verify_complete_sentinel(ckpt)  # raises on failure
+
+        # Verify we can actually load the weights
+        weights, config = load_backbone_weights(ckpt)
+        assert config is not None, f"No config in {ckpt.name}"
+        assert len(weights) > 0, f"Empty weights in {ckpt.name}"
+
+
+def test_sigterm_does_not_leave_tmp_dirs(train_tmpdir):
+    """SIGTERM should not leave .tmp directories from interrupted saves."""
+    ckpt_dir = train_tmpdir / "checkpoints"
+    log_dir = train_tmpdir / "logs"
+
+    proc = subprocess.Popen(
+        [
+            sys.executable, "scripts/train.py",
+            "--toy",
+            "--local-checkpoints",
+            "--total-steps", "5000",
+            "--device", "cpu",
+            "--num-workers", "0",
+            "--checkpoint-dir", str(ckpt_dir),
+            "--log-dir", str(log_dir),
+        ],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        env={**os.environ, "PYTHONUNBUFFERED": "1"},
+    )
+
+    try:
+        _wait_for_training_start(proc)
+        time.sleep(5)
+        proc.send_signal(signal.SIGTERM)
+        proc.wait(timeout=30)
+    except Exception:
+        proc.kill()
+        proc.wait()
+        raise
+
+    # No .tmp directories should exist anywhere under log_dir
+    tmp_dirs = list(log_dir.glob("**/*.tmp"))
+    assert len(tmp_dirs) == 0, f"Leftover .tmp directories: {tmp_dirs}"
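Both tests assume a particular protocol on the trainer side: a SIGTERM handler that only sets a flag (so the loop exits via `break` and the process returns 0), and checkpoint saves that stage into a `.tmp` directory, write a `.complete` sentinel, then rename into place. A minimal sketch of that pattern — all names here are illustrative, not the actual `pawn` trainer API:

```python
import json
import os
import signal
from pathlib import Path

stop_requested = False

def _on_sigterm(signum, frame):
    # Only flip a flag; the training loop checks it between steps.
    global stop_requested
    stop_requested = True

def save_checkpoint(ckpt_dir: Path, files: dict) -> None:
    """Stage into a .tmp sibling, add a .complete sentinel, then rename."""
    tmp = ckpt_dir.parent / (ckpt_dir.name + ".tmp")
    tmp.mkdir(parents=True)
    for name, data in files.items():
        (tmp / name).write_bytes(data)
    # Sentinel records expected file sizes so loaders can verify integrity.
    manifest = {name: len(data) for name, data in files.items()}
    (tmp / ".complete").write_text(json.dumps(manifest))
    os.replace(tmp, ckpt_dir)  # atomic rename: no half-written step_* dir

def train_loop(total_steps: int, log_dir: Path) -> int:
    signal.signal(signal.SIGTERM, _on_sigterm)
    for step in range(total_steps):
        # ... optimizer step would go here ...
        if step % 10 == 0 or stop_requested:
            save_checkpoint(log_dir / f"step_{step:06d}",
                            {"model.safetensors": b"\x00" * 8})
        if stop_requested:
            break  # exit code 0, not 128 + SIGTERM
    return 0
```

Because the rename is the last step, a kill at any point leaves either a complete `step_*` directory or a `.tmp` directory that the next run can sweep away, which is exactly the invariant the two tests above assert.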
uv.lock CHANGED
@@ -20,24 +20,6 @@ members = [
  "pawn",
  ]
 
- [[package]]
- name = "aiofiles"
- version = "24.1.0"
- source = { registry = "https://pypi.org/simple" }
- sdist = { url = "https://files.pythonhosted.org/packages/0b/03/a88171e277e8caa88a4c77808c20ebb04ba74cc4681bf1e9416c862de237/aiofiles-24.1.0.tar.gz", hash = "sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c", size = 30247, upload-time = "2024-06-24T11:02:03.584Z" }
- wheels = [
- { url = "https://files.pythonhosted.org/packages/a5/45/30bb92d442636f570cb5651bc661f52b610e2eec3f891a5dc3a4c3667db0/aiofiles-24.1.0-py3-none-any.whl", hash = "sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5", size = 15896, upload-time = "2024-06-24T11:02:01.529Z" },
- ]
-
- [[package]]
- name = "annotated-doc"
- version = "0.0.4"
- source = { registry = "https://pypi.org/simple" }
- sdist = { url = "https://files.pythonhosted.org/packages/57/ba/046ceea27344560984e26a590f90bc7f4a75b06701f653222458922b558c/annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4", size = 7288, upload-time = "2025-11-10T22:07:42.062Z" }
- wheels = [
- { url = "https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320", size = 5303, upload-time = "2025-11-10T22:07:40.673Z" },
- ]
-
 [[package]]
 name = "annotated-types"
 version = "0.7.0"
@@ -93,44 +75,6 @@ wheels = [
  { url = "https://files.pythonhosted.org/packages/64/b4/17d4b0b2a2dc85a6df63d1157e028ed19f90d4cd97c36717afef2bc2f395/attrs-26.1.0-py3-none-any.whl", hash = "sha256:c647aa4a12dfbad9333ca4e71fe62ddc36f4e63b2d260a37a8b83d2f043ac309", size = 67548, upload-time = "2026-03-19T14:22:23.645Z" },
  ]
 
- [[package]]
- name = "brotli"
- version = "1.2.0"
- source = { registry = "https://pypi.org/simple" }
- sdist = { url = "https://files.pythonhosted.org/packages/f7/16/c92ca344d646e71a43b8bb353f0a6490d7f6e06210f8554c8f874e454285/brotli-1.2.0.tar.gz", hash = "sha256:e310f77e41941c13340a95976fe66a8a95b01e783d430eeaf7a2f87e0a57dd0a", size = 7388632, upload-time = "2025-11-05T18:39:42.86Z" }
- wheels = [
- { url = "https://files.pythonhosted.org/packages/64/10/a090475284fc4a71aed40a96f32e44a7fe5bda39687353dd977720b211b6/brotli-1.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3b90b767916ac44e93a8e28ce6adf8d551e43affb512f2377c732d486ac6514e", size = 863089, upload-time = "2025-11-05T18:38:01.181Z" },
- { url = "https://files.pythonhosted.org/packages/03/41/17416630e46c07ac21e378c3464815dd2e120b441e641bc516ac32cc51d2/brotli-1.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6be67c19e0b0c56365c6a76e393b932fb0e78b3b56b711d180dd7013cb1fd984", size = 445442, upload-time = "2025-11-05T18:38:02.434Z" },
- { url = "https://files.pythonhosted.org/packages/24/31/90cc06584deb5d4fcafc0985e37741fc6b9717926a78674bbb3ce018957e/brotli-1.2.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0bbd5b5ccd157ae7913750476d48099aaf507a79841c0d04a9db4415b14842de", size = 1532658, upload-time = "2025-11-05T18:38:03.588Z" },
- { url = "https://files.pythonhosted.org/packages/62/17/33bf0c83bcbc96756dfd712201d87342732fad70bb3472c27e833a44a4f9/brotli-1.2.0-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:3f3c908bcc404c90c77d5a073e55271a0a498f4e0756e48127c35d91cf155947", size = 1631241, upload-time = "2025-11-05T18:38:04.582Z" },
- { url = "https://files.pythonhosted.org/packages/48/10/f47854a1917b62efe29bc98ac18e5d4f71df03f629184575b862ef2e743b/brotli-1.2.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1b557b29782a643420e08d75aea889462a4a8796e9a6cf5621ab05a3f7da8ef2", size = 1424307, upload-time = "2025-11-05T18:38:05.587Z" },
- { url = "https://files.pythonhosted.org/packages/e4/b7/f88eb461719259c17483484ea8456925ee057897f8e64487d76e24e5e38d/brotli-1.2.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:81da1b229b1889f25adadc929aeb9dbc4e922bd18561b65b08dd9343cfccca84", size = 1488208, upload-time = "2025-11-05T18:38:06.613Z" },
- { url = "https://files.pythonhosted.org/packages/26/59/41bbcb983a0c48b0b8004203e74706c6b6e99a04f3c7ca6f4f41f364db50/brotli-1.2.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:ff09cd8c5eec3b9d02d2408db41be150d8891c5566addce57513bf546e3d6c6d", size = 1597574, upload-time = "2025-11-05T18:38:07.838Z" },
- { url = "https://files.pythonhosted.org/packages/8e/e6/8c89c3bdabbe802febb4c5c6ca224a395e97913b5df0dff11b54f23c1788/brotli-1.2.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:a1778532b978d2536e79c05dac2d8cd857f6c55cd0c95ace5b03740824e0e2f1", size = 1492109, upload-time = "2025-11-05T18:38:08.816Z" },
- { url = "https://files.pythonhosted.org/packages/ed/9a/4b19d4310b2dbd545c0c33f176b0528fa68c3cd0754e34b2f2bcf56548ae/brotli-1.2.0-cp310-cp310-win32.whl", hash = "sha256:b232029d100d393ae3c603c8ffd7e3fe6f798c5e28ddca5feabb8e8fdb732997", size = 334461, upload-time = "2025-11-05T18:38:10.729Z" },
- { url = "https://files.pythonhosted.org/packages/ac/39/70981d9f47705e3c2b95c0847dfa3e7a37aa3b7c6030aedc4873081ed005/brotli-1.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:ef87b8ab2704da227e83a246356a2b179ef826f550f794b2c52cddb4efbd0196", size = 369035, upload-time = "2025-11-05T18:38:11.827Z" },
- { url = "https://files.pythonhosted.org/packages/7a/ef/f285668811a9e1ddb47a18cb0b437d5fc2760d537a2fe8a57875ad6f8448/brotli-1.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:15b33fe93cedc4caaff8a0bd1eb7e3dab1c61bb22a0bf5bdfdfd97cd7da79744", size = 863110, upload-time = "2025-11-05T18:38:12.978Z" },
- { url = "https://files.pythonhosted.org/packages/50/62/a3b77593587010c789a9d6eaa527c79e0848b7b860402cc64bc0bc28a86c/brotli-1.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:898be2be399c221d2671d29eed26b6b2713a02c2119168ed914e7d00ceadb56f", size = 445438, upload-time = "2025-11-05T18:38:14.208Z" },
- { url = "https://files.pythonhosted.org/packages/cd/e1/7fadd47f40ce5549dc44493877db40292277db373da5053aff181656e16e/brotli-1.2.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:350c8348f0e76fff0a0fd6c26755d2653863279d086d3aa2c290a6a7251135dd", size = 1534420, upload-time = "2025-11-05T18:38:15.111Z" },
- { url = "https://files.pythonhosted.org/packages/12/8b/1ed2f64054a5a008a4ccd2f271dbba7a5fb1a3067a99f5ceadedd4c1d5a7/brotli-1.2.0-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2e1ad3fda65ae0d93fec742a128d72e145c9c7a99ee2fcd667785d99eb25a7fe", size = 1632619, upload-time = "2025-11-05T18:38:16.094Z" },
- { url = "https://files.pythonhosted.org/packages/89/5a/7071a621eb2d052d64efd5da2ef55ecdac7c3b0c6e4f9d519e9c66d987ef/brotli-1.2.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:40d918bce2b427a0c4ba189df7a006ac0c7277c180aee4617d99e9ccaaf59e6a", size = 1426014, upload-time = "2025-11-05T18:38:17.177Z" },
- { url = "https://files.pythonhosted.org/packages/26/6d/0971a8ea435af5156acaaccec1a505f981c9c80227633851f2810abd252a/brotli-1.2.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2a7f1d03727130fc875448b65b127a9ec5d06d19d0148e7554384229706f9d1b", size = 1489661, upload-time = "2025-11-05T18:38:18.41Z" },
- { url = "https://files.pythonhosted.org/packages/f3/75/c1baca8b4ec6c96a03ef8230fab2a785e35297632f402ebb1e78a1e39116/brotli-1.2.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:9c79f57faa25d97900bfb119480806d783fba83cd09ee0b33c17623935b05fa3", size = 1599150, upload-time = "2025-11-05T18:38:19.792Z" },
- { url = "https://files.pythonhosted.org/packages/0d/1a/23fcfee1c324fd48a63d7ebf4bac3a4115bdb1b00e600f80f727d850b1ae/brotli-1.2.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:844a8ceb8483fefafc412f85c14f2aae2fb69567bf2a0de53cdb88b73e7c43ae", size = 1493505, upload-time = "2025-11-05T18:38:20.913Z" },
- { url = "https://files.pythonhosted.org/packages/36/e5/12904bbd36afeef53d45a84881a4810ae8810ad7e328a971ebbfd760a0b3/brotli-1.2.0-cp311-cp311-win32.whl", hash = "sha256:aa47441fa3026543513139cb8926a92a8e305ee9c71a6209ef7a97d91640ea03", size = 334451, upload-time = "2025-11-05T18:38:21.94Z" },
- { url = "https://files.pythonhosted.org/packages/02/8b/ecb5761b989629a4758c394b9301607a5880de61ee2ee5fe104b87149ebc/brotli-1.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:022426c9e99fd65d9475dce5c195526f04bb8be8907607e27e747893f6ee3e24", size = 369035, upload-time = "2025-11-05T18:38:22.941Z" },
- { url = "https://files.pythonhosted.org/packages/11/ee/b0a11ab2315c69bb9b45a2aaed022499c9c24a205c3a49c3513b541a7967/brotli-1.2.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:35d382625778834a7f3061b15423919aa03e4f5da34ac8e02c074e4b75ab4f84", size = 861543, upload-time = "2025-11-05T18:38:24.183Z" },
- { url = "https://files.pythonhosted.org/packages/e1/2f/29c1459513cd35828e25531ebfcbf3e92a5e49f560b1777a9af7203eb46e/brotli-1.2.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7a61c06b334bd99bc5ae84f1eeb36bfe01400264b3c352f968c6e30a10f9d08b", size = 444288, upload-time = "2025-11-05T18:38:25.139Z" },
- { url = "https://files.pythonhosted.org/packages/3d/6f/feba03130d5fceadfa3a1bb102cb14650798c848b1df2a808356f939bb16/brotli-1.2.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:acec55bb7c90f1dfc476126f9711a8e81c9af7fb617409a9ee2953115343f08d", size = 1528071, upload-time = "2025-11-05T18:38:26.081Z" },
- { url = "https://files.pythonhosted.org/packages/2b/38/f3abb554eee089bd15471057ba85f47e53a44a462cfce265d9bf7088eb09/brotli-1.2.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:260d3692396e1895c5034f204f0db022c056f9e2ac841593a4cf9426e2a3faca", size = 1626913, upload-time = "2025-11-05T18:38:27.284Z" },
- { url = "https://files.pythonhosted.org/packages/03/a7/03aa61fbc3c5cbf99b44d158665f9b0dd3d8059be16c460208d9e385c837/brotli-1.2.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:072e7624b1fc4d601036ab3f4f27942ef772887e876beff0301d261210bca97f", size = 1419762, upload-time = "2025-11-05T18:38:28.295Z" },
- { url = "https://files.pythonhosted.org/packages/21/1b/0374a89ee27d152a5069c356c96b93afd1b94eae83f1e004b57eb6ce2f10/brotli-1.2.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:adedc4a67e15327dfdd04884873c6d5a01d3e3b6f61406f99b1ed4865a2f6d28", size = 1484494, upload-time = "2025-11-05T18:38:29.29Z" },
- { url = "https://files.pythonhosted.org/packages/cf/57/69d4fe84a67aef4f524dcd075c6eee868d7850e85bf01d778a857d8dbe0a/brotli-1.2.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:7a47ce5c2288702e09dc22a44d0ee6152f2c7eda97b3c8482d826a1f3cfc7da7", size = 1593302, upload-time = "2025-11-05T18:38:30.639Z" },
- { url = "https://files.pythonhosted.org/packages/d5/3b/39e13ce78a8e9a621c5df3aeb5fd181fcc8caba8c48a194cd629771f6828/brotli-1.2.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:af43b8711a8264bb4e7d6d9a6d004c3a2019c04c01127a868709ec29962b6036", size = 1487913, upload-time = "2025-11-05T18:38:31.618Z" },
- { url = "https://files.pythonhosted.org/packages/62/28/4d00cb9bd76a6357a66fcd54b4b6d70288385584063f4b07884c1e7286ac/brotli-1.2.0-cp312-cp312-win32.whl", hash = "sha256:e99befa0b48f3cd293dafeacdd0d191804d105d279e0b387a32054c1180f3161", size = 334362, upload-time = "2025-11-05T18:38:32.939Z" },
- { url = "https://files.pythonhosted.org/packages/1c/4e/bc1dcac9498859d5e353c9b153627a3752868a9d5f05ce8dedd81a2354ab/brotli-1.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:b35c13ce241abdd44cb8ca70683f20c0c079728a36a996297adb5334adfc1c44", size = 369115, upload-time = "2025-11-05T18:38:33.765Z" },
- ]
-
 [[package]]
 name = "cachetools"
 version = "7.0.5"
@@ -462,22 +406,6 @@ wheels = [
  { url = "https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017", size = 28317, upload-time = "2025-09-01T09:48:08.5Z" },
  ]
 
- [[package]]
- name = "fastapi"
- version = "0.135.1"
- source = { registry = "https://pypi.org/simple" }
- dependencies = [
- { name = "annotated-doc", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "pydantic", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "starlette", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "typing-extensions", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "typing-inspection", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- ]
- sdist = { url = "https://files.pythonhosted.org/packages/e7/7b/f8e0211e9380f7195ba3f3d40c292594fd81ba8ec4629e3854c353aaca45/fastapi-0.135.1.tar.gz", hash = "sha256:d04115b508d936d254cea545b7312ecaa58a7b3a0f84952535b4c9afae7668cd", size = 394962, upload-time = "2026-03-01T18:18:29.369Z" }
- wheels = [
- { url = "https://files.pythonhosted.org/packages/e4/72/42e900510195b23a56bde950d26a51f8b723846bfcaa0286e90287f0422b/fastapi-0.135.1-py3-none-any.whl", hash = "sha256:46e2fc5745924b7c840f71ddd277382af29ce1cdb7d5eab5bf697e3fb9999c9e", size = 116999, upload-time = "2026-03-01T18:18:30.831Z" },
- ]
-
 [[package]]
 name = "fastjsonschema"
 version = "2.21.2"
@@ -487,15 +415,6 @@ wheels = [
  { url = "https://files.pythonhosted.org/packages/cb/a8/20d0723294217e47de6d9e2e40fd4a9d2f7c4b6ef974babd482a59743694/fastjsonschema-2.21.2-py3-none-any.whl", hash = "sha256:1c797122d0a86c5cace2e54bf4e819c36223b552017172f32c5c024a6b77e463", size = 24024, upload-time = "2025-08-14T18:49:34.776Z" },
  ]
 
- [[package]]
- name = "ffmpy"
- version = "1.0.0"
- source = { registry = "https://pypi.org/simple" }
- sdist = { url = "https://files.pythonhosted.org/packages/7d/d2/1c4c582d71bcc65c76fa69fab85de6257d50fdf6fd4a2317c53917e9a581/ffmpy-1.0.0.tar.gz", hash = "sha256:b12932e95435c8820f1cd041024402765f821971e4bae753b327fc02a6e12f8b", size = 5101, upload-time = "2025-11-11T06:24:23.856Z" }
- wheels = [
- { url = "https://files.pythonhosted.org/packages/55/56/dd3669eccebb6d8ac81e624542ebd53fe6f08e1b8f2f8d50aeb7e3b83f99/ffmpy-1.0.0-py3-none-any.whl", hash = "sha256:5640e5f0fd03fb6236d0e119b16ccf6522db1c826fdf35dcb87087b60fd7504f", size = 5614, upload-time = "2025-11-11T06:24:22.818Z" },
- ]
-
 [[package]]
 name = "filelock"
 version = "3.25.2"
@@ -571,71 +490,6 @@ wheels = [
  { url = "https://files.pythonhosted.org/packages/6a/09/e21df6aef1e1ffc0c816f0522ddc3f6dcded766c3261813131c78a704470/gitpython-3.1.46-py3-none-any.whl", hash = "sha256:79812ed143d9d25b6d176a10bb511de0f9c67b1fa641d82097b0ab90398a2058", size = 208620, upload-time = "2026-01-01T15:37:30.574Z" },
  ]
 
- [[package]]
- name = "gradio"
- version = "6.9.0"
- source = { registry = "https://pypi.org/simple" }
- dependencies = [
- { name = "aiofiles", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "anyio", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "brotli", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "fastapi", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "ffmpy", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "gradio-client", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "groovy", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "httpx", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "huggingface-hub", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "jinja2", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "markupsafe", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "numpy", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "orjson", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "packaging", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "pandas", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and sys_platform == 'linux') or (python_full_version >= '3.11' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm') or (sys_platform != 'linux' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "pandas", version = "3.0.1", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform == 'linux') or (python_full_version < '3.11' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm') or (sys_platform != 'linux' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "pillow", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "pydantic", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "pydub", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "python-multipart", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "pytz", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "pyyaml", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "safehttpx", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "semantic-version", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "starlette", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "tomlkit", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "typer", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "typing-extensions", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "uvicorn", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- ]
- sdist = { url = "https://files.pythonhosted.org/packages/bd/83/29bdbf94b212512e3c775482d390f5b699a72d71a2c431dea367a6e45a37/gradio-6.9.0.tar.gz", hash = "sha256:593e60e33233f3586452ebfa9f741817c5ae849a98cc70945f3ccb8dc895eb22", size = 57904480, upload-time = "2026-03-06T17:44:26.025Z" }
- wheels = [
- { url = "https://files.pythonhosted.org/packages/b3/8b/dc357ab966544e4dc898a2fee326d755c5f54da82af71a1a802e3476e78e/gradio-6.9.0-py3-none-any.whl", hash = "sha256:c173dd330c9247002a42222c85d76c0ecee65437eff808084e360862e7bbd24f", size = 42940853, upload-time = "2026-03-06T17:44:22.009Z" },
- ]
-
- [[package]]
- name = "gradio-client"
- version = "2.3.0"
- source = { registry = "https://pypi.org/simple" }
- dependencies = [
- { name = "fsspec", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "httpx", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "huggingface-hub", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "packaging", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "typing-extensions", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- ]
- sdist = { url = "https://files.pythonhosted.org/packages/97/d2/de2037f5eff13a5145cdf6982fd34c9735f0806e8a2ee5d4bfe9a7d25a54/gradio_client-2.3.0.tar.gz", hash = "sha256:1c700dc60e65bae4386ba7cf3732b9f9d5bcf5fb8eb451df3944fe092d7d9a29", size = 57552, upload-time = "2026-03-06T17:44:38.247Z" }
- wheels = [
- { url = "https://files.pythonhosted.org/packages/99/6a/41752781399811afbf8ac858f63c20eff354ed35169daa39604aefced4e8/gradio_client-2.3.0-py3-none-any.whl", hash = "sha256:9ec51a927888fc188e123a0ac5ad341d9265b325539a399554d1fc2604942e74", size = 58531, upload-time = "2026-03-06T17:44:36.961Z" },
- ]
-
- [[package]]
- name = "groovy"
- version = "0.1.2"
- source = { registry = "https://pypi.org/simple" }
- sdist = { url = "https://files.pythonhosted.org/packages/52/36/bbdede67400277bef33d3ec0e6a31750da972c469f75966b4930c753218f/groovy-0.1.2.tar.gz", hash = "sha256:25c1dc09b3f9d7e292458aa762c6beb96ea037071bf5e917fc81fb78d2231083", size = 17325, upload-time = "2025-02-28T20:24:56.068Z" }
- wheels = [
- { url = "https://files.pythonhosted.org/packages/28/27/3d6dcadc8a3214d8522c1e7f6a19554e33659be44546d44a2f7572ac7d2a/groovy-0.1.2-py3-none-any.whl", hash = "sha256:7f7975bab18c729a257a8b1ae9dcd70b7cafb1720481beae47719af57c35fa64", size = 14090, upload-time = "2025-02-28T20:24:55.152Z" },
- ]
-
 [[package]]
 name = "h11"
 version = "0.16.0"
@@ -645,70 +499,6 @@ wheels = [
645
  { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" },
646
  ]
647
 
648
- [[package]]
649
- name = "hf-xet"
650
- version = "1.4.2"
651
- source = { registry = "https://pypi.org/simple" }
652
- sdist = { url = "https://files.pythonhosted.org/packages/09/08/23c84a26716382c89151b5b447b4beb19e3345f3a93d3b73009a71a57ad3/hf_xet-1.4.2.tar.gz", hash = "sha256:b7457b6b482d9e0743bd116363239b1fa904a5e65deede350fbc0c4ea67c71ea", size = 672357, upload-time = "2026-03-13T06:58:51.077Z" }
653
- wheels = [
654
- { url = "https://files.pythonhosted.org/packages/b4/86/b40b83a2ff03ef05c4478d2672b1fc2b9683ff870e2b25f4f3af240f2e7b/hf_xet-1.4.2-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:71f02d6e4cdd07f344f6844845d78518cc7186bd2bc52d37c3b73dc26a3b0bc5", size = 3800339, upload-time = "2026-03-13T06:58:36.245Z" },
655
- { url = "https://files.pythonhosted.org/packages/64/2e/af4475c32b4378b0e92a587adb1aa3ec53e3450fd3e5fe0372a874531c00/hf_xet-1.4.2-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:e9b38d876e94d4bdcf650778d6ebbaa791dd28de08db9736c43faff06ede1b5a", size = 3559664, upload-time = "2026-03-13T06:58:34.787Z" },
656
- { url = "https://files.pythonhosted.org/packages/3c/4c/781267da3188db679e601de18112021a5cb16506fe86b246e22c5401a9c4/hf_xet-1.4.2-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:77e8c180b7ef12d8a96739a4e1e558847002afe9ea63b6f6358b2271a8bdda1c", size = 4217422, upload-time = "2026-03-13T06:58:27.472Z" },
657
- { url = "https://files.pythonhosted.org/packages/68/47/d6cf4a39ecf6c7705f887a46f6ef5c8455b44ad9eb0d391aa7e8a2ff7fea/hf_xet-1.4.2-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:c3b3c6a882016b94b6c210957502ff7877802d0dbda8ad142c8595db8b944271", size = 3992847, upload-time = "2026-03-13T06:58:25.989Z" },
658
- { url = "https://files.pythonhosted.org/packages/2d/ef/e80815061abff54697239803948abc665c6b1d237102c174f4f7a9a5ffc5/hf_xet-1.4.2-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9d9a634cc929cfbaf2e1a50c0e532ae8c78fa98618426769480c58501e8c8ac2", size = 4193843, upload-time = "2026-03-13T06:58:44.59Z" },
659
- { url = "https://files.pythonhosted.org/packages/54/75/07f6aa680575d9646c4167db6407c41340cbe2357f5654c4e72a1b01ca14/hf_xet-1.4.2-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6b0932eb8b10317ea78b7da6bab172b17be03bbcd7809383d8d5abd6a2233e04", size = 4432751, upload-time = "2026-03-13T06:58:46.533Z" },
660
- { url = "https://files.pythonhosted.org/packages/cd/71/193eabd7e7d4b903c4aa983a215509c6114915a5a237525ec562baddb868/hf_xet-1.4.2-cp37-abi3-win_amd64.whl", hash = "sha256:ad185719fb2e8ac26f88c8100562dbf9dbdcc3d9d2add00faa94b5f106aea53f", size = 3671149, upload-time = "2026-03-13T06:58:57.07Z" },
661
- { url = "https://files.pythonhosted.org/packages/b4/7e/ccf239da366b37ba7f0b36095450efae4a64980bdc7ec2f51354205fdf39/hf_xet-1.4.2-cp37-abi3-win_arm64.whl", hash = "sha256:32c012286b581f783653e718c1862aea5b9eb140631685bb0c5e7012c8719a87", size = 3533426, upload-time = "2026-03-13T06:58:55.46Z" },
662
- ]
663
-
664
- [[package]]
665
- name = "httpcore"
666
- version = "1.0.9"
667
- source = { registry = "https://pypi.org/simple" }
668
- dependencies = [
669
- { name = "certifi", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
670
- { name = "h11", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
671
- ]
672
- sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" }
673
- wheels = [
674
- { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" },
- ]
-
- [[package]]
- name = "httpx"
- version = "0.28.1"
- source = { registry = "https://pypi.org/simple" }
- dependencies = [
- { name = "anyio", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "certifi", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "httpcore", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "idna", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- ]
- sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" }
- wheels = [
- { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" },
- ]
-
- [[package]]
- name = "huggingface-hub"
- version = "1.7.2"
- source = { registry = "https://pypi.org/simple" }
- dependencies = [
- { name = "filelock", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "fsspec", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "hf-xet", marker = "(platform_machine != 'AMD64' and platform_machine != 'aarch64' and platform_machine != 'amd64' and platform_machine != 'arm64' and platform_machine != 'x86_64' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'amd64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "httpx", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "packaging", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "pyyaml", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "tqdm", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "typer", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "typing-extensions", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- ]
- sdist = { url = "https://files.pythonhosted.org/packages/19/15/eafc1c57bf0f8afffb243dcd4c0cceb785e956acc17bba4d9bf2ae21fc9c/huggingface_hub-1.7.2.tar.gz", hash = "sha256:7f7e294e9bbb822e025bdb2ada025fa4344d978175a7f78e824d86e35f7ab43b", size = 724684, upload-time = "2026-03-20T10:36:08.767Z" }
- wheels = [
- { url = "https://files.pythonhosted.org/packages/08/de/3ad061a05f74728927ded48c90b73521b9a9328c85d841bdefb30e01fb85/huggingface_hub-1.7.2-py3-none-any.whl", hash = "sha256:288f33a0a17b2a73a1359e2a5fd28d1becb2c121748c6173ab8643fb342c850e", size = 618036, upload-time = "2026-03-20T10:36:06.824Z" },
- ]
-
 [[package]]
 name = "humanize"
 version = "4.15.0"
@@ -1439,57 +1229,6 @@ wheels = [
 { url = "https://files.pythonhosted.org/packages/9f/99/4c9c0c329bf9fc125008c3b54c7c94c0023518d06fc025ae36431375e1fe/nvidia_nvtx_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:619c8304aedc69f02ea82dd244541a83c3d9d40993381b3b590f1adaed3db41e", size = 56492, upload-time = "2025-03-07T01:52:24.69Z" },
 ]

- [[package]]
- name = "orjson"
- version = "3.11.7"
- source = { registry = "https://pypi.org/simple" }
- sdist = { url = "https://files.pythonhosted.org/packages/53/45/b268004f745ede84e5798b48ee12b05129d19235d0e15267aa57dcdb400b/orjson-3.11.7.tar.gz", hash = "sha256:9b1a67243945819ce55d24a30b59d6a168e86220452d2c96f4d1f093e71c0c49", size = 6144992, upload-time = "2026-02-02T15:38:49.29Z" }
- wheels = [
- { url = "https://files.pythonhosted.org/packages/de/1a/a373746fa6d0e116dd9e54371a7b54622c44d12296d5d0f3ad5e3ff33490/orjson-3.11.7-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:a02c833f38f36546ba65a452127633afce4cf0dd7296b753d3bb54e55e5c0174", size = 229140, upload-time = "2026-02-02T15:37:06.082Z" },
- { url = "https://files.pythonhosted.org/packages/52/a2/fa129e749d500f9b183e8a3446a193818a25f60261e9ce143ad61e975208/orjson-3.11.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b63c6e6738d7c3470ad01601e23376aa511e50e1f3931395b9f9c722406d1a67", size = 128670, upload-time = "2026-02-02T15:37:08.002Z" },
- { url = "https://files.pythonhosted.org/packages/08/93/1e82011cd1e0bd051ef9d35bed1aa7fb4ea1f0a055dc2c841b46b43a9ebd/orjson-3.11.7-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:043d3006b7d32c7e233b8cfb1f01c651013ea079e08dcef7189a29abd8befe11", size = 123832, upload-time = "2026-02-02T15:37:09.191Z" },
- { url = "https://files.pythonhosted.org/packages/fe/d8/a26b431ef962c7d55736674dddade876822f3e33223c1f47a36879350d04/orjson-3.11.7-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57036b27ac8a25d81112eb0cc9835cd4833c5b16e1467816adc0015f59e870dc", size = 129171, upload-time = "2026-02-02T15:37:11.112Z" },
- { url = "https://files.pythonhosted.org/packages/a7/19/f47819b84a580f490da260c3ee9ade214cf4cf78ac9ce8c1c758f80fdfc9/orjson-3.11.7-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:733ae23ada68b804b222c44affed76b39e30806d38660bf1eb200520d259cc16", size = 141967, upload-time = "2026-02-02T15:37:12.282Z" },
- { url = "https://files.pythonhosted.org/packages/5b/cd/37ece39a0777ba077fdcdbe4cccae3be8ed00290c14bf8afdc548befc260/orjson-3.11.7-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5fdfad2093bdd08245f2e204d977facd5f871c88c4a71230d5bcbd0e43bf6222", size = 130991, upload-time = "2026-02-02T15:37:13.465Z" },
- { url = "https://files.pythonhosted.org/packages/8f/ed/f2b5d66aa9b6b5c02ff5f120efc7b38c7c4962b21e6be0f00fd99a5c348e/orjson-3.11.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cededd6738e1c153530793998e31c05086582b08315db48ab66649768f326baa", size = 133674, upload-time = "2026-02-02T15:37:14.694Z" },
- { url = "https://files.pythonhosted.org/packages/c4/6e/baa83e68d1aa09fa8c3e5b2c087d01d0a0bd45256de719ed7bc22c07052d/orjson-3.11.7-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:14f440c7268c8f8633d1b3d443a434bd70cb15686117ea6beff8fdc8f5917a1e", size = 138722, upload-time = "2026-02-02T15:37:16.501Z" },
- { url = "https://files.pythonhosted.org/packages/0c/47/7f8ef4963b772cd56999b535e553f7eb5cd27e9dd6c049baee6f18bfa05d/orjson-3.11.7-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:3a2479753bbb95b0ebcf7969f562cdb9668e6d12416a35b0dda79febf89cdea2", size = 409056, upload-time = "2026-02-02T15:37:17.895Z" },
- { url = "https://files.pythonhosted.org/packages/38/eb/2df104dd2244b3618f25325a656f85cc3277f74bbd91224752410a78f3c7/orjson-3.11.7-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:71924496986275a737f38e3f22b4e0878882b3f7a310d2ff4dc96e812789120c", size = 144196, upload-time = "2026-02-02T15:37:19.349Z" },
- { url = "https://files.pythonhosted.org/packages/b6/2a/ee41de0aa3a6686598661eae2b4ebdff1340c65bfb17fcff8b87138aab21/orjson-3.11.7-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b4a9eefdc70bf8bf9857f0290f973dec534ac84c35cd6a7f4083be43e7170a8f", size = 134979, upload-time = "2026-02-02T15:37:20.906Z" },
- { url = "https://files.pythonhosted.org/packages/4c/fa/92fc5d3d402b87a8b28277a9ed35386218a6a5287c7fe5ee9b9f02c53fb2/orjson-3.11.7-cp310-cp310-win32.whl", hash = "sha256:ae9e0b37a834cef7ce8f99de6498f8fad4a2c0bf6bfc3d02abd8ed56aa15b2de", size = 127968, upload-time = "2026-02-02T15:37:23.178Z" },
- { url = "https://files.pythonhosted.org/packages/07/29/a576bf36d73d60df06904d3844a9df08e25d59eba64363aaf8ec2f9bff41/orjson-3.11.7-cp310-cp310-win_amd64.whl", hash = "sha256:d772afdb22555f0c58cfc741bdae44180122b3616faa1ecadb595cd526e4c993", size = 125128, upload-time = "2026-02-02T15:37:24.329Z" },
- { url = "https://files.pythonhosted.org/packages/37/02/da6cb01fc6087048d7f61522c327edf4250f1683a58a839fdcc435746dd5/orjson-3.11.7-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:9487abc2c2086e7c8eb9a211d2ce8855bae0e92586279d0d27b341d5ad76c85c", size = 228664, upload-time = "2026-02-02T15:37:25.542Z" },
- { url = "https://files.pythonhosted.org/packages/c1/c2/5885e7a5881dba9a9af51bc564e8967225a642b3e03d089289a35054e749/orjson-3.11.7-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:79cacb0b52f6004caf92405a7e1f11e6e2de8bdf9019e4f76b44ba045125cd6b", size = 125344, upload-time = "2026-02-02T15:37:26.92Z" },
- { url = "https://files.pythonhosted.org/packages/a4/1d/4e7688de0a92d1caf600dfd5fb70b4c5bfff51dfa61ac555072ef2d0d32a/orjson-3.11.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c2e85fe4698b6a56d5e2ebf7ae87544d668eb6bde1ad1226c13f44663f20ec9e", size = 128404, upload-time = "2026-02-02T15:37:28.108Z" },
- { url = "https://files.pythonhosted.org/packages/2f/b2/ec04b74ae03a125db7bd69cffd014b227b7f341e3261bf75b5eb88a1aa92/orjson-3.11.7-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b8d14b71c0b12963fe8a62aac87119f1afdf4cb88a400f61ca5ae581449efcb5", size = 123677, upload-time = "2026-02-02T15:37:30.287Z" },
- { url = "https://files.pythonhosted.org/packages/4c/69/f95bdf960605f08f827f6e3291fe243d8aa9c5c9ff017a8d7232209184c3/orjson-3.11.7-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:91c81ef070c8f3220054115e1ef468b1c9ce8497b4e526cb9f68ab4dc0a7ac62", size = 128950, upload-time = "2026-02-02T15:37:31.595Z" },
- { url = "https://files.pythonhosted.org/packages/a4/1b/de59c57bae1d148ef298852abd31909ac3089cff370dfd4cd84cc99cbc42/orjson-3.11.7-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:411ebaf34d735e25e358a6d9e7978954a9c9d58cfb47bc6683cdc3964cd2f910", size = 141756, upload-time = "2026-02-02T15:37:32.985Z" },
- { url = "https://files.pythonhosted.org/packages/ee/9e/9decc59f4499f695f65c650f6cfa6cd4c37a3fbe8fa235a0a3614cb54386/orjson-3.11.7-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a16bcd08ab0bcdfc7e8801d9c4a9cc17e58418e4d48ddc6ded4e9e4b1a94062b", size = 130812, upload-time = "2026-02-02T15:37:34.204Z" },
- { url = "https://files.pythonhosted.org/packages/28/e6/59f932bcabd1eac44e334fe8e3281a92eacfcb450586e1f4bde0423728d8/orjson-3.11.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c0b51672e466fd7e56230ffbae7f1639e18d0ce023351fb75da21b71bc2c960", size = 133444, upload-time = "2026-02-02T15:37:35.446Z" },
- { url = "https://files.pythonhosted.org/packages/f1/36/b0f05c0eaa7ca30bc965e37e6a2956b0d67adb87a9872942d3568da846ae/orjson-3.11.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:136dcd6a2e796dfd9ffca9fc027d778567b0b7c9968d092842d3c323cef88aa8", size = 138609, upload-time = "2026-02-02T15:37:36.657Z" },
- { url = "https://files.pythonhosted.org/packages/b8/03/58ec7d302b8d86944c60c7b4b82975d5161fcce4c9bc8c6cb1d6741b6115/orjson-3.11.7-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:7ba61079379b0ae29e117db13bda5f28d939766e410d321ec1624afc6a0b0504", size = 408918, upload-time = "2026-02-02T15:37:38.076Z" },
- { url = "https://files.pythonhosted.org/packages/06/3a/868d65ef9a8b99be723bd510de491349618abd9f62c826cf206d962db295/orjson-3.11.7-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:0527a4510c300e3b406591b0ba69b5dc50031895b0a93743526a3fc45f59d26e", size = 143998, upload-time = "2026-02-02T15:37:39.706Z" },
- { url = "https://files.pythonhosted.org/packages/5b/c7/1e18e1c83afe3349f4f6dc9e14910f0ae5f82eac756d1412ea4018938535/orjson-3.11.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a709e881723c9b18acddcfb8ba357322491ad553e277cf467e1e7e20e2d90561", size = 134802, upload-time = "2026-02-02T15:37:41.002Z" },
- { url = "https://files.pythonhosted.org/packages/d4/0b/ccb7ee1a65b37e8eeb8b267dc953561d72370e85185e459616d4345bab34/orjson-3.11.7-cp311-cp311-win32.whl", hash = "sha256:c43b8b5bab288b6b90dac410cca7e986a4fa747a2e8f94615aea407da706980d", size = 127828, upload-time = "2026-02-02T15:37:42.241Z" },
- { url = "https://files.pythonhosted.org/packages/af/9e/55c776dffda3f381e0f07d010a4f5f3902bf48eaba1bb7684d301acd4924/orjson-3.11.7-cp311-cp311-win_amd64.whl", hash = "sha256:6543001328aa857187f905308a028935864aefe9968af3848401b6fe80dbb471", size = 124941, upload-time = "2026-02-02T15:37:43.444Z" },
- { url = "https://files.pythonhosted.org/packages/aa/8e/424a620fa7d263b880162505fb107ef5e0afaa765b5b06a88312ac291560/orjson-3.11.7-cp311-cp311-win_arm64.whl", hash = "sha256:1ee5cc7160a821dfe14f130bc8e63e7611051f964b463d9e2a3a573204446a4d", size = 126245, upload-time = "2026-02-02T15:37:45.18Z" },
- { url = "https://files.pythonhosted.org/packages/80/bf/76f4f1665f6983385938f0e2a5d7efa12a58171b8456c252f3bae8a4cf75/orjson-3.11.7-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:bd03ea7606833655048dab1a00734a2875e3e86c276e1d772b2a02556f0d895f", size = 228545, upload-time = "2026-02-02T15:37:46.376Z" },
- { url = "https://files.pythonhosted.org/packages/79/53/6c72c002cb13b5a978a068add59b25a8bdf2800ac1c9c8ecdb26d6d97064/orjson-3.11.7-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:89e440ebc74ce8ab5c7bc4ce6757b4a6b1041becb127df818f6997b5c71aa60b", size = 125224, upload-time = "2026-02-02T15:37:47.697Z" },
- { url = "https://files.pythonhosted.org/packages/2c/83/10e48852865e5dd151bdfe652c06f7da484578ed02c5fca938e3632cb0b8/orjson-3.11.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ede977b5fe5ac91b1dffc0a517ca4542d2ec8a6a4ff7b2652d94f640796342a", size = 128154, upload-time = "2026-02-02T15:37:48.954Z" },
- { url = "https://files.pythonhosted.org/packages/6e/52/a66e22a2b9abaa374b4a081d410edab6d1e30024707b87eab7c734afe28d/orjson-3.11.7-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b7b1dae39230a393df353827c855a5f176271c23434cfd2db74e0e424e693e10", size = 123548, upload-time = "2026-02-02T15:37:50.187Z" },
- { url = "https://files.pythonhosted.org/packages/de/38/605d371417021359f4910c496f764c48ceb8997605f8c25bf1dfe58c0ebe/orjson-3.11.7-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ed46f17096e28fb28d2975834836a639af7278aa87c84f68ab08fbe5b8bd75fa", size = 129000, upload-time = "2026-02-02T15:37:51.426Z" },
- { url = "https://files.pythonhosted.org/packages/44/98/af32e842b0ffd2335c89714d48ca4e3917b42f5d6ee5537832e069a4b3ac/orjson-3.11.7-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3726be79e36e526e3d9c1aceaadbfb4a04ee80a72ab47b3f3c17fefb9812e7b8", size = 141686, upload-time = "2026-02-02T15:37:52.607Z" },
- { url = "https://files.pythonhosted.org/packages/96/0b/fc793858dfa54be6feee940c1463370ece34b3c39c1ca0aa3845f5ba9892/orjson-3.11.7-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0724e265bc548af1dedebd9cb3d24b4e1c1e685a343be43e87ba922a5c5fff2f", size = 130812, upload-time = "2026-02-02T15:37:53.944Z" },
- { url = "https://files.pythonhosted.org/packages/dc/91/98a52415059db3f374757d0b7f0f16e3b5cd5976c90d1c2b56acaea039e6/orjson-3.11.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7745312efa9e11c17fbd3cb3097262d079da26930ae9ae7ba28fb738367cbad", size = 133440, upload-time = "2026-02-02T15:37:55.615Z" },
- { url = "https://files.pythonhosted.org/packages/dc/b6/cb540117bda61791f46381f8c26c8f93e802892830a6055748d3bb1925ab/orjson-3.11.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f904c24bdeabd4298f7a977ef14ca2a022ca921ed670b92ecd16ab6f3d01f867", size = 138386, upload-time = "2026-02-02T15:37:56.814Z" },
- { url = "https://files.pythonhosted.org/packages/63/1a/50a3201c334a7f17c231eee5f841342190723794e3b06293f26e7cf87d31/orjson-3.11.7-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:b9fc4d0f81f394689e0814617aadc4f2ea0e8025f38c226cbf22d3b5ddbf025d", size = 408853, upload-time = "2026-02-02T15:37:58.291Z" },
- { url = "https://files.pythonhosted.org/packages/87/cd/8de1c67d0be44fdc22701e5989c0d015a2adf391498ad42c4dc589cd3013/orjson-3.11.7-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:849e38203e5be40b776ed2718e587faf204d184fc9a008ae441f9442320c0cab", size = 144130, upload-time = "2026-02-02T15:38:00.163Z" },
- { url = "https://files.pythonhosted.org/packages/0f/fe/d605d700c35dd55f51710d159fc54516a280923cd1b7e47508982fbb387d/orjson-3.11.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:4682d1db3bcebd2b64757e0ddf9e87ae5f00d29d16c5cdf3a62f561d08cc3dd2", size = 134818, upload-time = "2026-02-02T15:38:01.507Z" },
- { url = "https://files.pythonhosted.org/packages/e4/e4/15ecc67edb3ddb3e2f46ae04475f2d294e8b60c1825fbe28a428b93b3fbd/orjson-3.11.7-cp312-cp312-win32.whl", hash = "sha256:f4f7c956b5215d949a1f65334cf9d7612dde38f20a95f2315deef167def91a6f", size = 127923, upload-time = "2026-02-02T15:38:02.75Z" },
- { url = "https://files.pythonhosted.org/packages/34/70/2e0855361f76198a3965273048c8e50a9695d88cd75811a5b46444895845/orjson-3.11.7-cp312-cp312-win_amd64.whl", hash = "sha256:bf742e149121dc5648ba0a08ea0871e87b660467ef168a3a5e53bc1fbd64bb74", size = 125007, upload-time = "2026-02-02T15:38:04.032Z" },
- { url = "https://files.pythonhosted.org/packages/68/40/c2051bd19fc467610fed469dc29e43ac65891571138f476834ca192bc290/orjson-3.11.7-cp312-cp312-win_arm64.whl", hash = "sha256:26c3b9132f783b7d7903bf1efb095fed8d4a3a85ec0d334ee8beff3d7a4749d5", size = 126089, upload-time = "2026-02-02T15:38:05.297Z" },
- ]
-
 [[package]]
 name = "packaging"
 version = "26.0"
@@ -1586,6 +1325,7 @@ dependencies = [
 { name = "chess-engine", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
 { name = "numpy", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
 { name = "psutil", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
 { name = "tqdm", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
 { name = "wandb", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
 ]
@@ -1601,9 +1341,6 @@ dashboard = [
 { name = "plotly", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
 { name = "solara", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
 ]
- demo = [
- { name = "gradio", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- ]
 dev = [
 { name = "ipykernel", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
 { name = "pytest", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
@@ -1623,7 +1360,6 @@ rocm = [
 requires-dist = [
 { name = "anywidget", marker = "extra == 'dashboard'", specifier = ">=0.9.21" },
 { name = "chess-engine", editable = "engine" },
- { name = "gradio", marker = "extra == 'demo'", specifier = ">=5.0.0" },
 { name = "ipykernel", marker = "extra == 'dev'", specifier = ">=7.2.0" },
 { name = "matplotlib", marker = "extra == 'eval'", specifier = ">=3.10.8" },
 { name = "numpy", specifier = "~=2.2.0" },
@@ -1633,6 +1369,7 @@ requires-dist = [
 { name = "psutil", specifier = ">=5.9.0" },
 { name = "pyarrow", marker = "extra == 'eval'", specifier = ">=23.0.1" },
 { name = "pytest", marker = "extra == 'dev'", specifier = "~=9.0.0" },
 { name = "seaborn", marker = "extra == 'eval'", specifier = ">=0.13.2" },
 { name = "solara", marker = "extra == 'dashboard'", specifier = ">=1.0.0" },
 { name = "torch", marker = "extra == 'cu128'", specifier = "~=2.10.0", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "pawn", extra = "cu128" } },
@@ -1641,7 +1378,7 @@ requires-dist = [
 { name = "triton-rocm", marker = "extra == 'rocm'", specifier = ">=3.6.0", index = "https://download.pytorch.org/whl/rocm7.1", conflict = { package = "pawn", extra = "rocm" } },
 { name = "wandb", specifier = "~=0.25.0" },
 ]
- provides-extras = ["rocm", "cu128", "eval", "dashboard", "demo", "dev"]

 [[package]]
 name = "pexpect"
@@ -1976,15 +1713,6 @@ wheels = [
 { url = "https://files.pythonhosted.org/packages/36/c7/cfc8e811f061c841d7990b0201912c3556bfeb99cdcb7ed24adc8d6f8704/pydantic_core-2.41.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:56121965f7a4dc965bff783d70b907ddf3d57f6eba29b6d2e5dabfaf07799c51", size = 2145302, upload-time = "2025-11-04T13:43:46.64Z" },
 ]

- [[package]]
- name = "pydub"
- version = "0.25.1"
- source = { registry = "https://pypi.org/simple" }
- sdist = { url = "https://files.pythonhosted.org/packages/fe/9a/e6bca0eed82db26562c73b5076539a4a08d3cffd19c3cc5913a3e61145fd/pydub-0.25.1.tar.gz", hash = "sha256:980a33ce9949cab2a569606b65674d748ecbca4f0796887fd6f46173a7b0d30f", size = 38326, upload-time = "2021-03-10T02:09:54.659Z" }
- wheels = [
- { url = "https://files.pythonhosted.org/packages/a6/53/d78dc063216e62fc55f6b2eebb447f6a4b0a59f55c8406376f76bf959b08/pydub-0.25.1-py2.py3-none-any.whl", hash = "sha256:65617e33033874b59d87db603aa1ed450633288aefead953b30bded59cb599a6", size = 32327, upload-time = "2021-03-10T02:09:53.503Z" },
- ]
-
 [[package]]
 name = "pygments"
 version = "2.19.2"
@@ -2045,15 +1773,6 @@ wheels = [
 { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" },
 ]

- [[package]]
- name = "python-multipart"
- version = "0.0.22"
- source = { registry = "https://pypi.org/simple" }
- sdist = { url = "https://files.pythonhosted.org/packages/94/01/979e98d542a70714b0cb2b6728ed0b7c46792b695e3eaec3e20711271ca3/python_multipart-0.0.22.tar.gz", hash = "sha256:7340bef99a7e0032613f56dc36027b959fd3b30a787ed62d310e951f7c3a3a58", size = 37612, upload-time = "2026-01-25T10:15:56.219Z" }
- wheels = [
- { url = "https://files.pythonhosted.org/packages/1b/d0/397f9626e711ff749a95d96b7af99b9c566a9bb5129b8e4c10fc4d100304/python_multipart-0.0.22-py3-none-any.whl", hash = "sha256:2b2cd894c83d21bf49d702499531c7bafd057d730c201782048f7945d82de155", size = 24579, upload-time = "2026-01-25T10:15:54.811Z" },
- ]
-
 [[package]]
 name = "pytz"
 version = "2026.1.post1"
@@ -2284,15 +2003,29 @@ wheels = [
 ]

 [[package]]
- name = "safehttpx"
- version = "0.1.7"
 source = { registry = "https://pypi.org/simple" }
- dependencies = [
- { name = "httpx", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- ]
- sdist = { url = "https://files.pythonhosted.org/packages/89/d1/4282284d9cf1ee873607a46442da977fc3c985059315ab23610be31d5885/safehttpx-0.1.7.tar.gz", hash = "sha256:db201c0978c41eddb8bb480f3eee59dd67304fdd91646035e9d9a720049a9d23", size = 10385, upload-time = "2025-10-24T18:30:09.783Z" }
- wheels = [
- { url = "https://files.pythonhosted.org/packages/2e/a3/0f0b7d78e2f1eb9e8e1afbff1d2bff8d60144aee17aca51c065b516743dd/safehttpx-0.1.7-py3-none-any.whl", hash = "sha256:c4f4a162db6993464d7ca3d7cc4af0ffc6515a606dfd220b9f82c6945d869cde", size = 8959, upload-time = "2025-10-24T18:30:08.733Z" },
 ]

 [[package]]
@@ -2310,15 +2043,6 @@ wheels = [
 { url = "https://files.pythonhosted.org/packages/83/11/00d3c3dfc25ad54e731d91449895a79e4bf2384dc3ac01809010ba88f6d5/seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987", size = 294914, upload-time = "2024-01-25T13:21:49.598Z" },
 ]

- [[package]]
- name = "semantic-version"
- version = "2.10.0"
- source = { registry = "https://pypi.org/simple" }
- sdist = { url = "https://files.pythonhosted.org/packages/7d/31/f2289ce78b9b473d582568c234e104d2a342fd658cc288a7553d83bb8595/semantic_version-2.10.0.tar.gz", hash = "sha256:bdabb6d336998cbb378d4b9db3a4b56a1e3235701dc05ea2690d9a997ed5041c", size = 52289, upload-time = "2022-05-26T13:35:23.454Z" }
- wheels = [
- { url = "https://files.pythonhosted.org/packages/6a/23/8146aad7d88f4fcb3a6218f41a60f6c2d4e3a72de72da1825dc7c8f7877c/semantic_version-2.10.0-py2.py3-none-any.whl", hash = "sha256:de78a3b8e0feda74cabc54aab2da702113e33ac9d9eb9d2389bcf1f58b7d9177", size = 15552, upload-time = "2022-05-26T13:35:21.206Z" },
- ]
-
 [[package]]
 name = "sentry-sdk"
 version = "2.55.0"
@@ -2341,15 +2065,6 @@ wheels = [
 { url = "https://files.pythonhosted.org/packages/9d/76/f789f7a86709c6b087c5a2f52f911838cad707cc613162401badc665acfe/setuptools-82.0.1-py3-none-any.whl", hash = "sha256:a59e362652f08dcd477c78bb6e7bd9d80a7995bc73ce773050228a348ce2e5bb", size = 1006223, upload-time = "2026-03-09T12:47:15.026Z" },
 ]

- [[package]]
- name = "shellingham"
- version = "1.5.4"
- source = { registry = "https://pypi.org/simple" }
- sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310, upload-time = "2023-10-24T04:13:40.426Z" }
- wheels = [
- { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" },
- ]
-
 [[package]]
 name = "six"
 version = "1.17.0"
@@ -2504,15 +2219,6 @@ wheels = [
 { url = "https://files.pythonhosted.org/packages/23/d1/136eb2cb77520a31e1f64cbae9d33ec6df0d78bdf4160398e86eec8a8754/tomli-2.4.0-py3-none-any.whl", hash = "sha256:1f776e7d669ebceb01dee46484485f43a4048746235e683bcdffacdf1fb4785a", size = 14477, upload-time = "2026-01-11T11:22:37.446Z" },
 ]

- [[package]]
- name = "tomlkit"
- version = "0.13.3"
- source = { registry = "https://pypi.org/simple" }
- sdist = { url = "https://files.pythonhosted.org/packages/cc/18/0bbf3884e9eaa38819ebe46a7bd25dcd56b67434402b66a58c4b8e552575/tomlkit-0.13.3.tar.gz", hash = "sha256:430cf247ee57df2b94ee3fbe588e71d362a941ebb545dec29b53961d61add2a1", size = 185207, upload-time = "2025-06-05T07:13:44.947Z" }
- wheels = [
- { url = "https://files.pythonhosted.org/packages/bd/75/8539d011f6be8e29f339c42e633aae3cb73bffa95dd0f9adec09b9c58e85/tomlkit-0.13.3-py3-none-any.whl", hash = "sha256:c89c649d79ee40629a9fda55f8ace8c6a1b42deb912b2a8fd8d942ddadb606b0", size = 38901, upload-time = "2025-06-05T07:13:43.546Z" },
- ]
-
 [[package]]
 name = "torch"
 version = "2.10.0+cu128"
@@ -2645,21 +2351,6 @@ wheels = [
 { url = "https://download.pytorch.org/whl/triton_rocm-3.6.0-cp312-cp312-linux_x86_64.whl" },
 ]

- [[package]]
- name = "typer"
- version = "0.24.1"
- source = { registry = "https://pypi.org/simple" }
- dependencies = [
- { name = "annotated-doc", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "click", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "rich", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- { name = "shellingham", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
- ]
- sdist = { url = "https://files.pythonhosted.org/packages/f5/24/cb09efec5cc954f7f9b930bf8279447d24618bb6758d4f6adf2574c41780/typer-0.24.1.tar.gz", hash = "sha256:e39b4732d65fbdcde189ae76cf7cd48aeae72919dea1fdfc16593be016256b45", size = 118613, upload-time = "2026-02-21T16:54:40.609Z" }
- wheels = [
- { url = "https://files.pythonhosted.org/packages/4a/91/48db081e7a63bb37284f9fbcefda7c44c277b18b0e13fbc36ea2335b71e6/typer-0.24.1-py3-none-any.whl", hash = "sha256:112c1f0ce578bfb4cab9ffdabc68f031416ebcc216536611ba21f04e9aa84c9e", size = 56085, upload-time = "2026-02-21T16:54:41.616Z" },
- ]
-
 [[package]]
 name = "typing-extensions"
 version = "4.15.0"

 "pawn",
 ]

 [[package]]
 name = "annotated-types"
 version = "0.7.0"
 
 { url = "https://files.pythonhosted.org/packages/64/b4/17d4b0b2a2dc85a6df63d1157e028ed19f90d4cd97c36717afef2bc2f395/attrs-26.1.0-py3-none-any.whl", hash = "sha256:c647aa4a12dfbad9333ca4e71fe62ddc36f4e63b2d260a37a8b83d2f043ac309", size = 67548, upload-time = "2026-03-19T14:22:23.645Z" },
 ]

 [[package]]
 name = "cachetools"
 version = "7.0.5"
 
 { url = "https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017", size = 28317, upload-time = "2025-09-01T09:48:08.5Z" },
 ]

 [[package]]
 name = "fastjsonschema"
 version = "2.21.2"
 
 { url = "https://files.pythonhosted.org/packages/cb/a8/20d0723294217e47de6d9e2e40fd4a9d2f7c4b6ef974babd482a59743694/fastjsonschema-2.21.2-py3-none-any.whl", hash = "sha256:1c797122d0a86c5cace2e54bf4e819c36223b552017172f32c5c024a6b77e463", size = 24024, upload-time = "2025-08-14T18:49:34.776Z" },
 ]

 [[package]]
 name = "filelock"
 version = "3.25.2"
 
 { url = "https://files.pythonhosted.org/packages/6a/09/e21df6aef1e1ffc0c816f0522ddc3f6dcded766c3261813131c78a704470/gitpython-3.1.46-py3-none-any.whl", hash = "sha256:79812ed143d9d25b6d176a10bb511de0f9c67b1fa641d82097b0ab90398a2058", size = 208620, upload-time = "2026-01-01T15:37:30.574Z" },
 ]

 [[package]]
 name = "h11"
 version = "0.16.0"
 
 { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" },
 ]

 [[package]]
 name = "humanize"
 version = "4.15.0"
 
 { url = "https://files.pythonhosted.org/packages/9f/99/4c9c0c329bf9fc125008c3b54c7c94c0023518d06fc025ae36431375e1fe/nvidia_nvtx_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:619c8304aedc69f02ea82dd244541a83c3d9d40993381b3b590f1adaed3db41e", size = 56492, upload-time = "2025-03-07T01:52:24.69Z" },
 ]

 [[package]]
 name = "packaging"
 version = "26.0"
 
 { name = "chess-engine", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
 { name = "numpy", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
 { name = "psutil", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
+ { name = "safetensors", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
 { name = "tqdm", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
 { name = "wandb", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
 ]
 
  { name = "plotly", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
  { name = "solara", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
  ]
  dev = [
  { name = "ipykernel", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
  { name = "pytest", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
 
  requires-dist = [
  { name = "anywidget", marker = "extra == 'dashboard'", specifier = ">=0.9.21" },
  { name = "chess-engine", editable = "engine" },
  { name = "ipykernel", marker = "extra == 'dev'", specifier = ">=7.2.0" },
  { name = "matplotlib", marker = "extra == 'eval'", specifier = ">=3.10.8" },
  { name = "numpy", specifier = "~=2.2.0" },
 
  { name = "psutil", specifier = ">=5.9.0" },
  { name = "pyarrow", marker = "extra == 'eval'", specifier = ">=23.0.1" },
  { name = "pytest", marker = "extra == 'dev'", specifier = "~=9.0.0" },
+ { name = "safetensors", specifier = ">=0.4.0" },
  { name = "seaborn", marker = "extra == 'eval'", specifier = ">=0.13.2" },
  { name = "solara", marker = "extra == 'dashboard'", specifier = ">=1.0.0" },
  { name = "torch", marker = "extra == 'cu128'", specifier = "~=2.10.0", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "pawn", extra = "cu128" } },
 
  { name = "triton-rocm", marker = "extra == 'rocm'", specifier = ">=3.6.0", index = "https://download.pytorch.org/whl/rocm7.1", conflict = { package = "pawn", extra = "rocm" } },
  { name = "wandb", specifier = "~=0.25.0" },
  ]
+ provides-extras = ["rocm", "cu128", "eval", "dashboard", "dev"]

  [[package]]
  name = "pexpect"
 
  { url = "https://files.pythonhosted.org/packages/36/c7/cfc8e811f061c841d7990b0201912c3556bfeb99cdcb7ed24adc8d6f8704/pydantic_core-2.41.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:56121965f7a4dc965bff783d70b907ddf3d57f6eba29b6d2e5dabfaf07799c51", size = 2145302, upload-time = "2025-11-04T13:43:46.64Z" },
  ]

  [[package]]
  name = "pygments"
  version = "2.19.2"
 
  { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" },
  ]

  [[package]]
  name = "pytz"
  version = "2026.1.post1"
 
  ]

  [[package]]
+ name = "safetensors"
+ version = "0.7.0"
  source = { registry = "https://pypi.org/simple" }
+ sdist = { url = "https://files.pythonhosted.org/packages/29/9c/6e74567782559a63bd040a236edca26fd71bc7ba88de2ef35d75df3bca5e/safetensors-0.7.0.tar.gz", hash = "sha256:07663963b67e8bd9f0b8ad15bb9163606cd27cc5a1b96235a50d8369803b96b0", size = 200878, upload-time = "2025-11-19T15:18:43.199Z" }
+ wheels = [
+ { url = "https://files.pythonhosted.org/packages/fa/47/aef6c06649039accf914afef490268e1067ed82be62bcfa5b7e886ad15e8/safetensors-0.7.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c82f4d474cf725255d9e6acf17252991c3c8aac038d6ef363a4bf8be2f6db517", size = 467781, upload-time = "2025-11-19T15:18:35.84Z" },
+ { url = "https://files.pythonhosted.org/packages/e8/00/374c0c068e30cd31f1e1b46b4b5738168ec79e7689ca82ee93ddfea05109/safetensors-0.7.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:94fd4858284736bb67a897a41608b5b0c2496c9bdb3bf2af1fa3409127f20d57", size = 447058, upload-time = "2025-11-19T15:18:34.416Z" },
+ { url = "https://files.pythonhosted.org/packages/f1/06/578ffed52c2296f93d7fd2d844cabfa92be51a587c38c8afbb8ae449ca89/safetensors-0.7.0-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e07d91d0c92a31200f25351f4acb2bc6aff7f48094e13ebb1d0fb995b54b6542", size = 491748, upload-time = "2025-11-19T15:18:09.79Z" },
+ { url = "https://files.pythonhosted.org/packages/ae/33/1debbbb70e4791dde185edb9413d1fe01619255abb64b300157d7f15dddd/safetensors-0.7.0-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8469155f4cb518bafb4acf4865e8bb9d6804110d2d9bdcaa78564b9fd841e104", size = 503881, upload-time = "2025-11-19T15:18:16.145Z" },
+ { url = "https://files.pythonhosted.org/packages/8e/1c/40c2ca924d60792c3be509833df711b553c60effbd91da6f5284a83f7122/safetensors-0.7.0-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:54bef08bf00a2bff599982f6b08e8770e09cc012d7bba00783fc7ea38f1fb37d", size = 623463, upload-time = "2025-11-19T15:18:21.11Z" },
+ { url = "https://files.pythonhosted.org/packages/9b/3a/13784a9364bd43b0d61eef4bea2845039bc2030458b16594a1bd787ae26e/safetensors-0.7.0-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:42cb091236206bb2016d245c377ed383aa7f78691748f3bb6ee1bfa51ae2ce6a", size = 532855, upload-time = "2025-11-19T15:18:25.719Z" },
+ { url = "https://files.pythonhosted.org/packages/a0/60/429e9b1cb3fc651937727befe258ea24122d9663e4d5709a48c9cbfceecb/safetensors-0.7.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dac7252938f0696ddea46f5e855dd3138444e82236e3be475f54929f0c510d48", size = 507152, upload-time = "2025-11-19T15:18:33.023Z" },
+ { url = "https://files.pythonhosted.org/packages/3c/a8/4b45e4e059270d17af60359713ffd83f97900d45a6afa73aaa0d737d48b6/safetensors-0.7.0-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1d060c70284127fa805085d8f10fbd0962792aed71879d00864acda69dbab981", size = 541856, upload-time = "2025-11-19T15:18:31.075Z" },
+ { url = "https://files.pythonhosted.org/packages/06/87/d26d8407c44175d8ae164a95b5a62707fcc445f3c0c56108e37d98070a3d/safetensors-0.7.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:cdab83a366799fa730f90a4ebb563e494f28e9e92c4819e556152ad55e43591b", size = 674060, upload-time = "2025-11-19T15:18:37.211Z" },
+ { url = "https://files.pythonhosted.org/packages/11/f5/57644a2ff08dc6325816ba7217e5095f17269dada2554b658442c66aed51/safetensors-0.7.0-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:672132907fcad9f2aedcb705b2d7b3b93354a2aec1b2f706c4db852abe338f85", size = 771715, upload-time = "2025-11-19T15:18:38.689Z" },
+ { url = "https://files.pythonhosted.org/packages/86/31/17883e13a814bd278ae6e266b13282a01049b0c81341da7fd0e3e71a80a3/safetensors-0.7.0-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:5d72abdb8a4d56d4020713724ba81dac065fedb7f3667151c4a637f1d3fb26c0", size = 714377, upload-time = "2025-11-19T15:18:40.162Z" },
+ { url = "https://files.pythonhosted.org/packages/4a/d8/0c8a7dc9b41dcac53c4cbf9df2b9c83e0e0097203de8b37a712b345c0be5/safetensors-0.7.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b0f6d66c1c538d5a94a73aa9ddca8ccc4227e6c9ff555322ea40bdd142391dd4", size = 677368, upload-time = "2025-11-19T15:18:41.627Z" },
+ { url = "https://files.pythonhosted.org/packages/05/e5/cb4b713c8a93469e3c5be7c3f8d77d307e65fe89673e731f5c2bfd0a9237/safetensors-0.7.0-cp38-abi3-win32.whl", hash = "sha256:c74af94bf3ac15ac4d0f2a7c7b4663a15f8c2ab15ed0fc7531ca61d0835eccba", size = 326423, upload-time = "2025-11-19T15:18:45.74Z" },
+ { url = "https://files.pythonhosted.org/packages/5d/e6/ec8471c8072382cb91233ba7267fd931219753bb43814cbc71757bfd4dab/safetensors-0.7.0-cp38-abi3-win_amd64.whl", hash = "sha256:d1239932053f56f3456f32eb9625590cc7582e905021f94636202a864d470755", size = 341380, upload-time = "2025-11-19T15:18:44.427Z" },
+ { url = "https://files.pythonhosted.org/packages/a7/6a/4d08d89a6fcbe905c5ae68b8b34f0791850882fc19782d0d02c65abbdf3b/safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4729811a6640d019a4b7ba8638ee2fd21fa5ca8c7e7bdf0fed62068fcaac737", size = 492430, upload-time = "2025-11-19T15:18:11.884Z" },
+ { url = "https://files.pythonhosted.org/packages/dd/29/59ed8152b30f72c42d00d241e58eaca558ae9dbfa5695206e2e0f54c7063/safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:12f49080303fa6bb424b362149a12949dfbbf1e06811a88f2307276b0c131afd", size = 503977, upload-time = "2025-11-19T15:18:17.523Z" },
+ { url = "https://files.pythonhosted.org/packages/d3/0b/4811bfec67fa260e791369b16dab105e4bae82686120554cc484064e22b4/safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0071bffba4150c2f46cae1432d31995d77acfd9f8db598b5d1a2ce67e8440ad2", size = 623890, upload-time = "2025-11-19T15:18:22.666Z" },
+ { url = "https://files.pythonhosted.org/packages/58/5b/632a58724221ef03d78ab65062e82a1010e1bef8e8e0b9d7c6d7b8044841/safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:473b32699f4200e69801bf5abf93f1a4ecd432a70984df164fc22ccf39c4a6f3", size = 531885, upload-time = "2025-11-19T15:18:27.146Z" },
  ]

  [[package]]
 
  { url = "https://files.pythonhosted.org/packages/83/11/00d3c3dfc25ad54e731d91449895a79e4bf2384dc3ac01809010ba88f6d5/seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987", size = 294914, upload-time = "2024-01-25T13:21:49.598Z" },
  ]

  [[package]]
  name = "sentry-sdk"
  version = "2.55.0"
 
  { url = "https://files.pythonhosted.org/packages/9d/76/f789f7a86709c6b087c5a2f52f911838cad707cc613162401badc665acfe/setuptools-82.0.1-py3-none-any.whl", hash = "sha256:a59e362652f08dcd477c78bb6e7bd9d80a7995bc73ce773050228a348ce2e5bb", size = 1006223, upload-time = "2026-03-09T12:47:15.026Z" },
  ]

  [[package]]
  name = "six"
  version = "1.17.0"
 
  { url = "https://files.pythonhosted.org/packages/23/d1/136eb2cb77520a31e1f64cbae9d33ec6df0d78bdf4160398e86eec8a8754/tomli-2.4.0-py3-none-any.whl", hash = "sha256:1f776e7d669ebceb01dee46484485f43a4048746235e683bcdffacdf1fb4785a", size = 14477, upload-time = "2026-01-11T11:22:37.446Z" },
  ]

  [[package]]
  name = "torch"
  version = "2.10.0+cu128"
 
  { url = "https://download.pytorch.org/whl/triton_rocm-3.6.0-cp312-cp312-linux_x86_64.whl" },
  ]

  [[package]]
  name = "typing-extensions"
  version = "4.15.0"