Safetensors migration, checkpoint integrity, and multi-model training. (#1)
Browse files- .dockerignore +1 -0
- .gitignore +9 -2
- CLAUDE.md +78 -7
- Dockerfile +28 -8
- README.md +10 -4
- checkpoints/pawn-base +1 -1
- checkpoints/pawn-large +1 -1
- checkpoints/pawn-small +1 -1
- deploy/entrypoint-run.sh +19 -0
- deploy/sync.sh +21 -50
- engine/python/chess_engine/__init__.py +2 -0
- engine/src/batch.rs +190 -22
- engine/src/board.rs +1 -1
- engine/src/lib.rs +41 -0
- engine/src/pgn.rs +0 -6
- engine/src/vocab.rs +59 -8
- pawn/checkpoint.py +647 -0
- pawn/config.py +1 -1
- pawn/data.py +47 -28
- pawn/eval_suite/diagnostics.py +1 -1
- pawn/eval_suite/lichess.py +7 -8
- pawn/eval_suite/probes.py +4 -10
- pawn/eval_suite/worker.py +5 -5
- pawn/lichess_data.py +5 -22
- pawn/model.py +1 -1
- pawn/trainer.py +52 -37
- pyproject.toml +1 -0
- scripts/check_progress.sh +14 -34
- scripts/eval_accuracy.py +34 -24
- scripts/eval_probes.py +9 -12
- scripts/export_hf_repo.py +283 -0
- scripts/profile_step.py +7 -6
- scripts/train.py +7 -1
- scripts/train_all.py +400 -0
- scripts/train_bottleneck.py +73 -35
- scripts/train_film.py +59 -20
- scripts/train_hybrid.py +59 -20
- scripts/train_lora.py +59 -20
- scripts/train_sparse.py +59 -20
- scripts/train_tiny.py +45 -12
- tests/test_checkpoint.py +359 -0
- tests/test_clm_format.py +224 -0
- tests/test_graceful_shutdown.py +130 -0
- uv.lock +25 -334
.dockerignore
CHANGED
|
@@ -10,6 +10,7 @@ data/
|
|
| 10 |
checkpoints/
|
| 11 |
logs/
|
| 12 |
deploy/
|
|
|
|
| 13 |
*.so
|
| 14 |
CLAUDE.md
|
| 15 |
docs/
|
|
|
|
| 10 |
checkpoints/
|
| 11 |
logs/
|
| 12 |
deploy/
|
| 13 |
+
!deploy/entrypoint-run.sh
|
| 14 |
*.so
|
| 15 |
CLAUDE.md
|
| 16 |
docs/
|
.gitignore
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
-
#
|
| 2 |
-
checkpoints/
|
| 3 |
logs/
|
| 4 |
|
| 5 |
# Training data
|
|
@@ -7,6 +6,7 @@ data/
|
|
| 7 |
|
| 8 |
# Deployment artifacts
|
| 9 |
deploy/pawn-deploy/
|
|
|
|
| 10 |
|
| 11 |
# Python
|
| 12 |
__pycache__/
|
|
@@ -36,8 +36,15 @@ Thumbs.db
|
|
| 36 |
# uv (track root lockfile only)
|
| 37 |
engine/uv.lock
|
| 38 |
|
|
|
|
|
|
|
|
|
|
| 39 |
# Misc
|
| 40 |
runpodctl.tar.gz
|
| 41 |
.claude/
|
|
|
|
|
|
|
|
|
|
| 42 |
.playwright-mcp/
|
| 43 |
local/
|
|
|
|
|
|
| 1 |
+
# Local checkpoints and logs
|
|
|
|
| 2 |
logs/
|
| 3 |
|
| 4 |
# Training data
|
|
|
|
| 6 |
|
| 7 |
# Deployment artifacts
|
| 8 |
deploy/pawn-deploy/
|
| 9 |
+
export/
|
| 10 |
|
| 11 |
# Python
|
| 12 |
__pycache__/
|
|
|
|
| 36 |
# uv (track root lockfile only)
|
| 37 |
engine/uv.lock
|
| 38 |
|
| 39 |
+
# Local working files
|
| 40 |
+
local/
|
| 41 |
+
|
| 42 |
# Misc
|
| 43 |
runpodctl.tar.gz
|
| 44 |
.claude/
|
| 45 |
+
<<<<<<< HEAD
|
| 46 |
+
.playwright-mcp
|
| 47 |
+
=======
|
| 48 |
.playwright-mcp/
|
| 49 |
local/
|
| 50 |
+
>>>>>>> c965110657b87bef82768cf9960b3f5486c54867
|
CLAUDE.md
CHANGED
|
@@ -41,7 +41,7 @@ uv sync --extra dev # + pytest, ipykernel
|
|
| 41 |
uv run pytest tests/
|
| 42 |
|
| 43 |
# Pretrain from scratch
|
| 44 |
-
uv run python scripts/train.py --variant base
|
| 45 |
```
|
| 46 |
|
| 47 |
## Engine (`engine/`)
|
|
@@ -86,18 +86,24 @@ All adapters freeze the backbone and initialize to identity (zero-init or gamma=
|
|
| 86 |
## Scripts (`scripts/`)
|
| 87 |
|
| 88 |
- `train.py` -- Pretrain from scratch (`--variant small|base|large|toy`)
|
|
|
|
| 89 |
- `train_bottleneck.py`, `train_film.py`, `train_lora.py`, `train_sparse.py`, `train_hybrid.py` -- Adapter behavioral cloning on Lichess PGN
|
| 90 |
- `train_tiny.py` -- Standalone tiny transformer baseline (no frozen backbone)
|
| 91 |
- `eval_accuracy.py` -- MAIA-compatible evaluation (per-phase, per-ply accuracy)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
## Deploy (`deploy/`)
|
| 94 |
|
| 95 |
-
|
| 96 |
-
- `
|
| 97 |
-
- `
|
| 98 |
-
- `
|
|
|
|
| 99 |
|
| 100 |
-
|
| 101 |
|
| 102 |
## Dashboard (`pawn/dashboard/`)
|
| 103 |
|
|
@@ -117,8 +123,73 @@ Auto-detects run type from config fields (`run_type`, `formulation`, `pgn_file`)
|
|
| 117 |
|
| 118 |
## Checkpoints
|
| 119 |
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
## Logs
|
| 123 |
|
| 124 |
Training metrics in `logs/` (gitignored). Each run gets a timestamped directory with `metrics.jsonl`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
uv run pytest tests/
|
| 42 |
|
| 43 |
# Pretrain from scratch
|
| 44 |
+
uv run python scripts/train.py --variant base --local-checkpoints
|
| 45 |
```
|
| 46 |
|
| 47 |
## Engine (`engine/`)
|
|
|
|
| 86 |
## Scripts (`scripts/`)
|
| 87 |
|
| 88 |
- `train.py` -- Pretrain from scratch (`--variant small|base|large|toy`)
|
| 89 |
+
- `train_all.py` -- Train small/base/large simultaneously on shared data batches
|
| 90 |
- `train_bottleneck.py`, `train_film.py`, `train_lora.py`, `train_sparse.py`, `train_hybrid.py` -- Adapter behavioral cloning on Lichess PGN
|
| 91 |
- `train_tiny.py` -- Standalone tiny transformer baseline (no frozen backbone)
|
| 92 |
- `eval_accuracy.py` -- MAIA-compatible evaluation (per-phase, per-ply accuracy)
|
| 93 |
+
- `eval_probes.py` -- Run linear probes on all checkpoints
|
| 94 |
+
- `export_hf_repo.py` -- Convert training run to HuggingFace repo format (safetensors + metrics)
|
| 95 |
+
|
| 96 |
+
All training scripts require `--hf-repo REPO` or `--local-checkpoints`.
|
| 97 |
|
| 98 |
## Deploy (`deploy/`)
|
| 99 |
|
| 100 |
+
Docker-based deployment to Runpod GPU VMs:
|
| 101 |
+
- `Dockerfile` -- Multi-target build: `interactive` (SSH+Jupyter, default) and `runner` (auto-stop)
|
| 102 |
+
- `entrypoint-run.sh` -- Runner entrypoint, pulls from HF via `PAWN_MODEL` env var
|
| 103 |
+
- `sync.sh` -- Pull latest checkpoints/metrics from HuggingFace submodules
|
| 104 |
+
- `pod.sh` -- Pod lifecycle (create/start/stop/delete/ssh)
|
| 105 |
|
| 106 |
+
Code lives at `/opt/pawn` on pods (outside the `/workspace` volume mount).
|
| 107 |
|
| 108 |
## Dashboard (`pawn/dashboard/`)
|
| 109 |
|
|
|
|
| 123 |
|
| 124 |
## Checkpoints
|
| 125 |
|
| 126 |
+
Pre-trained weights are HuggingFace git submodules under `checkpoints/`:
|
| 127 |
+
- `checkpoints/pawn-small` — 9.5M params, `CLMConfig.small()`
|
| 128 |
+
- `checkpoints/pawn-base` — 35.8M params, `CLMConfig.base()`
|
| 129 |
+
- `checkpoints/pawn-large` — 68.4M params, `CLMConfig.large()`
|
| 130 |
+
|
| 131 |
+
Pull with: `git submodule update --init --remote checkpoints/pawn-base`
|
| 132 |
+
|
| 133 |
+
### Checkpoint Format (safetensors)
|
| 134 |
+
|
| 135 |
+
Checkpoints are directories, not single files:
|
| 136 |
+
```
|
| 137 |
+
step_00065000/
|
| 138 |
+
├── model.safetensors # model weights
|
| 139 |
+
├── optimizer.safetensors # flattened optimizer state
|
| 140 |
+
├── training_state.json # step, scheduler, scaler, RNG (base64)
|
| 141 |
+
├── config.json # model + training config
|
| 142 |
+
└── .complete # SHA-256 hashes of all files (integrity sentinel)
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
Central module: `pawn/checkpoint.py`. All save/load goes through this module.
|
| 146 |
+
Legacy `.pt` files are still loadable (backward compatible).
|
| 147 |
+
|
| 148 |
+
### Checkpoint Storage Modes
|
| 149 |
+
|
| 150 |
+
All training scripts require one of:
|
| 151 |
+
- `--hf-repo REPO_ID` — push checkpoints to a HuggingFace branch as they're written (durable)
|
| 152 |
+
- `--local-checkpoints` — save locally only (for development without an HF account)
|
| 153 |
+
|
| 154 |
+
HF mode creates a `run/{run_id}` branch. Squash-merge into main when satisfied.
|
| 155 |
+
|
| 156 |
+
### Data Integrity
|
| 157 |
+
|
| 158 |
+
**Every checkpoint write is atomic**: files are written to a `.tmp` directory, then renamed.
|
| 159 |
+
The `.complete` sentinel contains SHA-256 hashes of every file in the checkpoint.
|
| 160 |
+
**Hashes are always verified on load — no exceptions.**
|
| 161 |
+
|
| 162 |
+
- `IncompleteCheckpointError` — raised when `.complete` sentinel is missing
|
| 163 |
+
- `CheckpointIntegrityError` — raised when any hash mismatches
|
| 164 |
+
|
| 165 |
+
**Never use `kill -9` on training processes.** SIGTERM is handled gracefully: a flag is set,
|
| 166 |
+
the training loop checks it between steps, saves a checkpoint, pushes to HF, and exits cleanly.
|
| 167 |
+
|
| 168 |
+
**Never rsync checkpoint files from running pods.** Checkpoints are pushed to HuggingFace
|
| 169 |
+
from the trainer. Pull via `deploy/sync.sh` (submodule update).
|
| 170 |
|
| 171 |
## Logs
|
| 172 |
|
| 173 |
Training metrics in `logs/` (gitignored). Each run gets a timestamped directory with `metrics.jsonl`.
|
| 174 |
+
|
| 175 |
+
## Runpod Pod Management
|
| 176 |
+
|
| 177 |
+
### Setup
|
| 178 |
+
|
| 179 |
+
- Docker image: multi-target build in `Dockerfile`
|
| 180 |
+
- `interactive` (default) — SSH + Jupyter, stays alive
|
| 181 |
+
- `runner` — executes command then exits (pod auto-stops)
|
| 182 |
+
- Build: `docker build --target runner --build-arg GIT_HASH=$(git rev-parse HEAD) ...`
|
| 183 |
+
|
| 184 |
+
### Required Configuration
|
| 185 |
+
|
| 186 |
+
- **Always attach a network volume.** Checkpoints write to disk during atomic rename and HF push. Ephemeral container disk is lost on pod termination.
|
| 187 |
+
- **Set `HF_TOKEN` as a pod environment variable** for automatic HuggingFace authentication.
|
| 188 |
+
- Set `PAWN_MODEL=thomas-schweich/pawn-base` env var in the runner to auto-pull a checkpoint on startup.
|
| 189 |
+
|
| 190 |
+
### Lifecycle
|
| 191 |
+
|
| 192 |
+
- Create: `runpodctl pod create --name pawn-exp --gpu-id "NVIDIA RTX A5000" --image thomasschweich/pawn:<tag> --volume-in-gb 75 --ports "8888/http,22/tcp"`
|
| 193 |
+
- Stop: `runpodctl pod stop <ID>` — sends SIGTERM → trainer saves and pushes before exiting
|
| 194 |
+
- **Never `runpodctl pod delete` while training is running** — data loss risk
|
| 195 |
+
- Monitor: pull HF submodule (`deploy/sync.sh`) and read `metrics.jsonl`
|
Dockerfile
CHANGED
|
@@ -3,17 +3,27 @@
|
|
| 3 |
# Extends the official Runpod PyTorch template — SSH and JupyterLab
|
| 4 |
# start automatically via the base image's /start.sh entrypoint.
|
| 5 |
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
# Build:
|
| 7 |
# docker build --platform linux/amd64 \
|
| 8 |
# --build-arg GIT_HASH=$(git rev-parse HEAD) \
|
| 9 |
# --build-arg GIT_TAG=$(git tag --points-at HEAD) \
|
|
|
|
| 10 |
# -t pawn:<tag> .
|
| 11 |
#
|
| 12 |
-
#
|
| 13 |
-
#
|
| 14 |
-
#
|
| 15 |
-
#
|
| 16 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
# ── Builder ──────────────────────────────────────────────────────────
|
| 19 |
FROM runpod/pytorch:1.0.2-cu1281-torch280-ubuntu2404 AS builder
|
|
@@ -43,14 +53,14 @@ RUN cd engine && \
|
|
| 43 |
uv run --no-project --with maturin maturin build --release && \
|
| 44 |
cd ..
|
| 45 |
|
| 46 |
-
# ── Runtime ─────────────────────────────
|
| 47 |
-
FROM runpod/pytorch:1.0.2-cu1281-torch280-ubuntu2404
|
| 48 |
|
| 49 |
ENV PYTHONUNBUFFERED=1 \
|
| 50 |
PYTHONPATH=/opt/pawn
|
| 51 |
|
| 52 |
# Direct deps only (torch + numpy already in base image)
|
| 53 |
-
RUN pip install --no-cache-dir psutil tqdm wandb
|
| 54 |
|
| 55 |
COPY --from=builder /workspace/pawn/engine/target/wheels/*.whl /tmp/
|
| 56 |
RUN pip install --no-cache-dir /tmp/*.whl && rm -rf /tmp/*.whl
|
|
@@ -72,3 +82,13 @@ RUN echo "export PYTHONPATH=/opt/pawn" >> /etc/environment && \
|
|
| 72 |
echo "export PAWN_GIT_HASH=${GIT_HASH}" >> /etc/environment && \
|
| 73 |
echo "export PAWN_GIT_TAG=${GIT_TAG}" >> /etc/environment && \
|
| 74 |
cat /etc/environment >> /root/.bashrc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
# Extends the official Runpod PyTorch template — SSH and JupyterLab
|
| 4 |
# start automatically via the base image's /start.sh entrypoint.
|
| 5 |
#
|
| 6 |
+
# Build targets:
|
| 7 |
+
# interactive (default) — SSH + Jupyter, stays alive
|
| 8 |
+
# runner — runs a command then exits (pod auto-stops)
|
| 9 |
+
#
|
| 10 |
# Build:
|
| 11 |
# docker build --platform linux/amd64 \
|
| 12 |
# --build-arg GIT_HASH=$(git rev-parse HEAD) \
|
| 13 |
# --build-arg GIT_TAG=$(git tag --points-at HEAD) \
|
| 14 |
+
# [--target runner] \
|
| 15 |
# -t pawn:<tag> .
|
| 16 |
#
|
| 17 |
+
# Run (interactive):
|
| 18 |
+
# docker run --gpus all pawn:<tag>
|
| 19 |
+
#
|
| 20 |
+
# IMPORTANT: Always attach a Runpod network volume. Checkpoints use
|
| 21 |
+
# atomic directory writes (tmp + rename) that require persistent disk.
|
| 22 |
+
# Set HF_TOKEN as a pod env var for HuggingFace checkpoint push.
|
| 23 |
+
#
|
| 24 |
+
# Run (auto-stop):
|
| 25 |
+
# docker run --gpus all -e PAWN_MODEL=thomas-schweich/pawn-base \
|
| 26 |
+
# pawn:<tag>-runner python scripts/train.py --variant base
|
| 27 |
|
| 28 |
# ── Builder ──────────────────────────────────────────────────────────
|
| 29 |
FROM runpod/pytorch:1.0.2-cu1281-torch280-ubuntu2404 AS builder
|
|
|
|
| 53 |
uv run --no-project --with maturin maturin build --release && \
|
| 54 |
cd ..
|
| 55 |
|
| 56 |
+
# ── Runtime base (shared by all targets) ─────────────────────────────
|
| 57 |
+
FROM runpod/pytorch:1.0.2-cu1281-torch280-ubuntu2404 AS runtime-base
|
| 58 |
|
| 59 |
ENV PYTHONUNBUFFERED=1 \
|
| 60 |
PYTHONPATH=/opt/pawn
|
| 61 |
|
| 62 |
# Direct deps only (torch + numpy already in base image)
|
| 63 |
+
RUN pip install --no-cache-dir psutil safetensors tqdm wandb huggingface-hub
|
| 64 |
|
| 65 |
COPY --from=builder /workspace/pawn/engine/target/wheels/*.whl /tmp/
|
| 66 |
RUN pip install --no-cache-dir /tmp/*.whl && rm -rf /tmp/*.whl
|
|
|
|
| 82 |
echo "export PAWN_GIT_HASH=${GIT_HASH}" >> /etc/environment && \
|
| 83 |
echo "export PAWN_GIT_TAG=${GIT_TAG}" >> /etc/environment && \
|
| 84 |
cat /etc/environment >> /root/.bashrc
|
| 85 |
+
|
| 86 |
+
# ── Runner — executes command then exits (pod auto-stops) ────────────
|
| 87 |
+
FROM runtime-base AS runner
|
| 88 |
+
COPY deploy/entrypoint-run.sh /entrypoint-run.sh
|
| 89 |
+
RUN chmod +x /entrypoint-run.sh
|
| 90 |
+
ENTRYPOINT ["/entrypoint-run.sh"]
|
| 91 |
+
|
| 92 |
+
# ── Interactive (default) — SSH + Jupyter, stays alive ───────────────
|
| 93 |
+
FROM runtime-base AS interactive
|
| 94 |
+
# Inherits /start.sh entrypoint from Runpod base image
|
README.md
CHANGED
|
@@ -32,7 +32,9 @@ Notably, the vocabulary includes impossible moves like `a1a1` and `b1a5`. PAWN n
|
|
| 32 |
|
| 33 |
Conceptually, each token is best thought of as a move in UCI notation--they are effectively coordinates. They do not include any information on the type of peice, side to play, or any direct geometric or board state information other than the factored nature of the embeddings (see the architecture section below for details).
|
| 34 |
|
| 35 |
-
For example, `e2e4` is the token that represents the king's pawn opening, but only when it's the first ply in the sequence (moving a rook between from e2 to e4 in the late game would use the same token). The model learns to track which type of peice is on each square any given moment entirely of its own accord.
|
|
|
|
|
|
|
| 36 |
|
| 37 |
## Quickstart
|
| 38 |
|
|
@@ -50,13 +52,17 @@ uv sync --extra cu128 # NVIDIA GPU (or --extra rocm for AMD)
|
|
| 50 |
uv sync --extra dev --extra eval --extra dashboard
|
| 51 |
|
| 52 |
# Train an adapter on a pre-trained checkpoint
|
|
|
|
| 53 |
uv run python scripts/train_bottleneck.py \
|
| 54 |
-
--checkpoint checkpoints/pawn-base
|
| 55 |
--pgn data/lichess_1800_1900.pgn \
|
| 56 |
-
--bottleneck-dim 32 --lr 1e-4
|
| 57 |
|
| 58 |
# Or pretrain a PAWN variant from scratch (generates random games on-the-fly; no dataset required)
|
| 59 |
-
uv run python scripts/train.py --variant base
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
# Launch the real-time monitoring dashboard (optional dashboard dependency must be installed)
|
| 62 |
uv run python -m pawn.dashboard --log-dir logs --port 8765
|
|
|
|
| 32 |
|
| 33 |
Conceptually, each token is best thought of as a move in UCI notation--they are effectively coordinates. They do not include any information on the type of peice, side to play, or any direct geometric or board state information other than the factored nature of the embeddings (see the architecture section below for details).
|
| 34 |
|
| 35 |
+
For example, `e2e4` is the token that represents the king's pawn opening, but only when it's the first ply in the sequence (moving a rook between from e2 to e4 in the late game would use the same token). The model learns to track which type of peice is on each square any given moment entirely of its own accord.
|
| 36 |
+
|
| 37 |
+
For that matter, it isn't told what piece types exist, what movement patterns they follow, or indeed the concept of a peice. All of that 'understanding' comes purely from observation and can be isolated via [linear probes](https://arxiv.org/abs/1610.01644) (Alain & Bengio, 2016).
|
| 38 |
|
| 39 |
## Quickstart
|
| 40 |
|
|
|
|
| 52 |
uv sync --extra dev --extra eval --extra dashboard
|
| 53 |
|
| 54 |
# Train an adapter on a pre-trained checkpoint
|
| 55 |
+
git submodule update --init checkpoints/pawn-base
|
| 56 |
uv run python scripts/train_bottleneck.py \
|
| 57 |
+
--checkpoint checkpoints/pawn-base \
|
| 58 |
--pgn data/lichess_1800_1900.pgn \
|
| 59 |
+
--bottleneck-dim 32 --lr 1e-4 --local-checkpoints
|
| 60 |
|
| 61 |
# Or pretrain a PAWN variant from scratch (generates random games on-the-fly; no dataset required)
|
| 62 |
+
uv run python scripts/train.py --variant base --local-checkpoints
|
| 63 |
+
|
| 64 |
+
# Or train all three variants simultaneously on shared data
|
| 65 |
+
uv run python scripts/train_all.py --local-checkpoints
|
| 66 |
|
| 67 |
# Launch the real-time monitoring dashboard (optional dashboard dependency must be installed)
|
| 68 |
uv run python -m pawn.dashboard --log-dir logs --port 8765
|
checkpoints/pawn-base
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
Subproject commit
|
|
|
|
| 1 |
+
Subproject commit 45aa6e347ca9516662874238adaeef4d30fc6df8
|
checkpoints/pawn-large
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
Subproject commit
|
|
|
|
| 1 |
+
Subproject commit 1d3f1deee3411e86d69a814520553fcf78f96c5f
|
checkpoints/pawn-small
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
Subproject commit
|
|
|
|
| 1 |
+
Subproject commit 929fee43ce16e02cf12a255cf39cc3f26e0913e1
|
deploy/entrypoint-run.sh
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Entrypoint for auto-stop runner target.
|
| 3 |
+
# Optionally pulls a checkpoint from HuggingFace before running the command.
|
| 4 |
+
set -e
|
| 5 |
+
|
| 6 |
+
cd /opt/pawn
|
| 7 |
+
export PYTHONPATH=/opt/pawn
|
| 8 |
+
|
| 9 |
+
# Pull checkpoint from HuggingFace if PAWN_MODEL is set
|
| 10 |
+
if [ -n "$PAWN_MODEL" ]; then
|
| 11 |
+
echo "Pulling model: $PAWN_MODEL"
|
| 12 |
+
python3 -c "
|
| 13 |
+
from huggingface_hub import snapshot_download
|
| 14 |
+
snapshot_download('$PAWN_MODEL', local_dir='checkpoints/model')
|
| 15 |
+
print('Model downloaded to checkpoints/model/')
|
| 16 |
+
"
|
| 17 |
+
fi
|
| 18 |
+
|
| 19 |
+
exec "$@"
|
deploy/sync.sh
CHANGED
|
@@ -1,65 +1,36 @@
|
|
| 1 |
#!/usr/bin/env bash
|
| 2 |
-
#
|
| 3 |
-
# Usage: bash deploy/sync.sh [
|
| 4 |
#
|
| 5 |
-
# With no args,
|
| 6 |
-
# With a
|
| 7 |
set -euo pipefail
|
| 8 |
|
| 9 |
REPO="$(cd "$(dirname "$0")/.." && pwd)"
|
| 10 |
-
POD_DIR="$HOME/.config/pawn/pods"
|
| 11 |
|
| 12 |
-
|
| 13 |
-
local
|
| 14 |
-
local
|
| 15 |
-
|
| 16 |
-
echo "
|
| 17 |
-
|
| 18 |
-
echo "
|
| 19 |
-
rsync -avz --progress --no-owner --no-group \
|
| 20 |
-
-e "ssh $ssh_opts" \
|
| 21 |
-
"root@$host:$remote_root/logs/" "$REPO/logs/" || {
|
| 22 |
-
echo " Failed to sync logs from $name"
|
| 23 |
-
return 1
|
| 24 |
-
}
|
| 25 |
-
|
| 26 |
-
echo "--- Checkpoints (best.pt only) ---"
|
| 27 |
-
rsync -avz --progress --no-owner --no-group \
|
| 28 |
-
--include='*/' --include='best.pt' --include='*.pt' --exclude='*' \
|
| 29 |
-
-e "ssh $ssh_opts" \
|
| 30 |
-
"root@$host:$remote_root/logs/" "$REPO/logs/" || {
|
| 31 |
-
echo " Failed to sync checkpoints from $name"
|
| 32 |
-
return 1
|
| 33 |
-
}
|
| 34 |
-
|
| 35 |
-
echo "=== $name sync complete ==="
|
| 36 |
echo ""
|
| 37 |
}
|
| 38 |
|
| 39 |
-
if [ ! -d "$POD_DIR" ] || [ -z "$(ls "$POD_DIR"/*.env 2>/dev/null)" ]; then
|
| 40 |
-
echo "No pods configured. Add .env files to $POD_DIR/"
|
| 41 |
-
exit 1
|
| 42 |
-
fi
|
| 43 |
-
|
| 44 |
if [ $# -ge 1 ]; then
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
|
|
|
| 50 |
exit 1
|
| 51 |
fi
|
| 52 |
-
|
| 53 |
-
sync_pod "$1" "$POD_HOST" "$POD_PORT" "${POD_REMOTE_ROOT:-/opt/pawn}"
|
| 54 |
else
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
unset POD_HOST POD_PORT POD_REMOTE_ROOT
|
| 59 |
-
source "$pod_file"
|
| 60 |
-
sync_pod "$pod_name" "$POD_HOST" "$POD_PORT" "${POD_REMOTE_ROOT:-/opt/pawn}" || true
|
| 61 |
done
|
| 62 |
fi
|
| 63 |
-
|
| 64 |
-
echo "Local logs:"
|
| 65 |
-
du -sh "$REPO/logs/"
|
|
|
|
| 1 |
#!/usr/bin/env bash
|
| 2 |
+
# Pull latest checkpoints and metrics from HuggingFace submodules.
|
| 3 |
+
# Usage: bash deploy/sync.sh [submodule-name]
|
| 4 |
#
|
| 5 |
+
# With no args, pulls all checkpoint submodules.
|
| 6 |
+
# With a name, pulls only that submodule.
|
| 7 |
set -euo pipefail
|
| 8 |
|
| 9 |
REPO="$(cd "$(dirname "$0")/.." && pwd)"
|
|
|
|
| 10 |
|
| 11 |
+
pull_submodule() {
|
| 12 |
+
local sub="$1"
|
| 13 |
+
local name="$(basename "$sub")"
|
| 14 |
+
echo "=== Pulling $name ==="
|
| 15 |
+
git -C "$sub" fetch origin 2>/dev/null || { echo " Failed to fetch $name"; return 1; }
|
| 16 |
+
git -C "$sub" pull origin main 2>/dev/null || { echo " Failed to pull $name (main)"; return 1; }
|
| 17 |
+
echo "=== $name up to date ==="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
echo ""
|
| 19 |
}
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
if [ $# -ge 1 ]; then
|
| 22 |
+
sub="$REPO/checkpoints/$1"
|
| 23 |
+
if [ ! -d "$sub/.git" ]; then
|
| 24 |
+
echo "Submodule '$1' not found. Available:"
|
| 25 |
+
for s in "$REPO"/checkpoints/pawn-*/; do
|
| 26 |
+
[ -d "$s/.git" ] && echo " $(basename "$s")"
|
| 27 |
+
done
|
| 28 |
exit 1
|
| 29 |
fi
|
| 30 |
+
pull_submodule "$sub"
|
|
|
|
| 31 |
else
|
| 32 |
+
for sub in "$REPO"/checkpoints/pawn-*/; do
|
| 33 |
+
[ -d "$sub/.git" ] || continue
|
| 34 |
+
pull_submodule "$sub" || true
|
|
|
|
|
|
|
|
|
|
| 35 |
done
|
| 36 |
fi
|
|
|
|
|
|
|
|
|
engine/python/chess_engine/__init__.py
CHANGED
|
@@ -8,6 +8,7 @@ from chess_engine._engine import (
|
|
| 8 |
# Core game generation
|
| 9 |
generate_training_batch,
|
| 10 |
generate_random_games,
|
|
|
|
| 11 |
generate_checkmate_games,
|
| 12 |
generate_checkmate_training_batch,
|
| 13 |
# Diagnostic sets
|
|
@@ -37,6 +38,7 @@ from chess_engine._engine import (
|
|
| 37 |
__all__ = [
|
| 38 |
"generate_training_batch",
|
| 39 |
"generate_random_games",
|
|
|
|
| 40 |
"generate_checkmate_games",
|
| 41 |
"generate_checkmate_training_batch",
|
| 42 |
"generate_diagnostic_sets",
|
|
|
|
| 8 |
# Core game generation
|
| 9 |
generate_training_batch,
|
| 10 |
generate_random_games,
|
| 11 |
+
generate_clm_batch,
|
| 12 |
generate_checkmate_games,
|
| 13 |
generate_checkmate_training_batch,
|
| 14 |
# Diagnostic sets
|
|
|
|
| 38 |
__all__ = [
|
| 39 |
"generate_training_batch",
|
| 40 |
"generate_random_games",
|
| 41 |
+
"generate_clm_batch",
|
| 42 |
"generate_checkmate_games",
|
| 43 |
"generate_checkmate_training_batch",
|
| 44 |
"generate_diagnostic_sets",
|
engine/src/batch.rs
CHANGED
|
@@ -49,23 +49,18 @@ pub fn generate_training_batch(batch_size: usize, max_ply: usize, seed: u64) ->
|
|
| 49 |
game_lengths.push(record.game_length as i16);
|
| 50 |
termination_codes.push(record.termination.as_u8());
|
| 51 |
|
| 52 |
-
// Copy move_ids
|
| 53 |
for t in 0..length {
|
| 54 |
move_ids[b * max_ply + t] = record.move_ids[t] as i16;
|
| 55 |
}
|
| 56 |
-
// Place EOG token at position game_length
|
| 57 |
-
if length < max_ply {
|
| 58 |
-
move_ids[b * max_ply + length] = vocab::EOG_TOKEN as i16;
|
| 59 |
-
}
|
| 60 |
|
| 61 |
-
// Copy legal move grids
|
| 62 |
for t in 0..length {
|
| 63 |
let grid_offset = (b * max_ply + t) * 64;
|
| 64 |
debug_assert_eq!(record.legal_grids[t].len(), 64);
|
| 65 |
legal_move_grid[grid_offset..grid_offset + 64]
|
| 66 |
.copy_from_slice(&record.legal_grids[t]);
|
| 67 |
}
|
| 68 |
-
// EOG position: all zeros (already initialized)
|
| 69 |
|
| 70 |
// Copy promotion masks (contiguous layout: [[bool; 4]; 44] = [bool; 176])
|
| 71 |
for t in 0..length {
|
|
@@ -108,9 +103,6 @@ pub fn generate_random_games(n_games: usize, max_ply: usize, seed: u64) -> GameB
|
|
| 108 |
for t in 0..(*length as usize) {
|
| 109 |
move_ids[b * max_ply + t] = moves[t] as i16;
|
| 110 |
}
|
| 111 |
-
if (*length as usize) < max_ply {
|
| 112 |
-
move_ids[b * max_ply + *length as usize] = vocab::EOG_TOKEN as i16;
|
| 113 |
-
}
|
| 114 |
}
|
| 115 |
|
| 116 |
GameBatch {
|
|
@@ -152,9 +144,6 @@ pub fn generate_checkmate_training_batch(
|
|
| 152 |
for t in 0..(ex.game_length as usize).min(max_ply) {
|
| 153 |
move_ids[b * max_ply + t] = ex.move_ids[t] as i16;
|
| 154 |
}
|
| 155 |
-
if (ex.game_length as usize) < max_ply {
|
| 156 |
-
move_ids[b * max_ply + ex.game_length as usize] = vocab::EOG_TOKEN as i16;
|
| 157 |
-
}
|
| 158 |
checkmate_targets[b * 64..(b + 1) * 64].copy_from_slice(&ex.checkmate_grid);
|
| 159 |
legal_grids[b * 64..(b + 1) * 64].copy_from_slice(&ex.legal_grid);
|
| 160 |
}
|
|
@@ -206,9 +195,6 @@ pub fn generate_completed_games(n_games: usize, max_ply: usize, seed: u64) -> Ga
|
|
| 206 |
for t in 0..(*length as usize) {
|
| 207 |
move_ids[b * max_ply + t] = moves[t] as i16;
|
| 208 |
}
|
| 209 |
-
if (*length as usize) < max_ply {
|
| 210 |
-
move_ids[b * max_ply + *length as usize] = vocab::EOG_TOKEN as i16;
|
| 211 |
-
}
|
| 212 |
}
|
| 213 |
|
| 214 |
GameBatch {
|
|
@@ -281,9 +267,6 @@ pub fn generate_checkmate_games(
|
|
| 281 |
for t in 0..(*length as usize) {
|
| 282 |
move_ids[b * max_ply + t] = moves[t] as i16;
|
| 283 |
}
|
| 284 |
-
if (*length as usize) < max_ply {
|
| 285 |
-
move_ids[b * max_ply + *length as usize] = vocab::EOG_TOKEN as i16;
|
| 286 |
-
}
|
| 287 |
}
|
| 288 |
|
| 289 |
(GameBatch {
|
|
@@ -295,6 +278,97 @@ pub fn generate_checkmate_games(
|
|
| 295 |
}, total_generated)
|
| 296 |
}
|
| 297 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
#[cfg(test)]
|
| 299 |
mod tests {
|
| 300 |
use super::*;
|
|
@@ -321,15 +395,23 @@ mod tests {
|
|
| 321 |
}
|
| 322 |
|
| 323 |
#[test]
|
| 324 |
-
fn
|
| 325 |
let batch = generate_training_batch(2, 256, 42);
|
| 326 |
for b in 0..2 {
|
| 327 |
let len = batch.game_lengths[b] as usize;
|
| 328 |
if len < 256 {
|
| 329 |
assert_eq!(
|
| 330 |
batch.move_ids[b * 256 + len],
|
| 331 |
-
vocab::
|
| 332 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
);
|
| 334 |
}
|
| 335 |
}
|
|
@@ -343,4 +425,90 @@ mod tests {
|
|
| 343 |
assert_eq!(b1.game_lengths, b2.game_lengths);
|
| 344 |
assert_eq!(b1.legal_move_grid, b2.legal_move_grid);
|
| 345 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
}
|
|
|
|
| 49 |
game_lengths.push(record.game_length as i16);
|
| 50 |
termination_codes.push(record.termination.as_u8());
|
| 51 |
|
| 52 |
+
// Copy move_ids (remaining positions are already 0 = PAD)
|
| 53 |
for t in 0..length {
|
| 54 |
move_ids[b * max_ply + t] = record.move_ids[t] as i16;
|
| 55 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
+
// Copy legal move grids (positions beyond game_length are already 0)
|
| 58 |
for t in 0..length {
|
| 59 |
let grid_offset = (b * max_ply + t) * 64;
|
| 60 |
debug_assert_eq!(record.legal_grids[t].len(), 64);
|
| 61 |
legal_move_grid[grid_offset..grid_offset + 64]
|
| 62 |
.copy_from_slice(&record.legal_grids[t]);
|
| 63 |
}
|
|
|
|
| 64 |
|
| 65 |
// Copy promotion masks (contiguous layout: [[bool; 4]; 44] = [bool; 176])
|
| 66 |
for t in 0..length {
|
|
|
|
| 103 |
for t in 0..(*length as usize) {
|
| 104 |
move_ids[b * max_ply + t] = moves[t] as i16;
|
| 105 |
}
|
|
|
|
|
|
|
|
|
|
| 106 |
}
|
| 107 |
|
| 108 |
GameBatch {
|
|
|
|
| 144 |
for t in 0..(ex.game_length as usize).min(max_ply) {
|
| 145 |
move_ids[b * max_ply + t] = ex.move_ids[t] as i16;
|
| 146 |
}
|
|
|
|
|
|
|
|
|
|
| 147 |
checkmate_targets[b * 64..(b + 1) * 64].copy_from_slice(&ex.checkmate_grid);
|
| 148 |
legal_grids[b * 64..(b + 1) * 64].copy_from_slice(&ex.legal_grid);
|
| 149 |
}
|
|
|
|
| 195 |
for t in 0..(*length as usize) {
|
| 196 |
move_ids[b * max_ply + t] = moves[t] as i16;
|
| 197 |
}
|
|
|
|
|
|
|
|
|
|
| 198 |
}
|
| 199 |
|
| 200 |
GameBatch {
|
|
|
|
| 267 |
for t in 0..(*length as usize) {
|
| 268 |
move_ids[b * max_ply + t] = moves[t] as i16;
|
| 269 |
}
|
|
|
|
|
|
|
|
|
|
| 270 |
}
|
| 271 |
|
| 272 |
(GameBatch {
|
|
|
|
| 278 |
}, total_generated)
|
| 279 |
}
|
| 280 |
|
| 281 |
+
/// Output of CLM (Causal Language Model) batch generation.
|
| 282 |
+
///
|
| 283 |
+
/// Contains ready-to-train tensors in the format:
|
| 284 |
+
/// input_ids = [outcome, ply_1, ply_2, ..., ply_N, PAD, ..., PAD]
|
| 285 |
+
/// targets = [ply_1, ply_2, ply_3, ..., PAD, PAD, ..., PAD]
|
| 286 |
+
/// loss_mask = [true, true, true, ..., true, false, ..., false]
|
| 287 |
+
///
|
| 288 |
+
/// Also includes raw move_ids and game_lengths for replay operations
|
| 289 |
+
/// (legal mask computation, board state extraction, validation).
|
| 290 |
+
pub struct CLMBatch {
|
| 291 |
+
pub input_ids: Vec<i16>, // [batch_size * seq_len]
|
| 292 |
+
pub targets: Vec<i16>, // [batch_size * seq_len]
|
| 293 |
+
pub loss_mask: Vec<bool>, // [batch_size * seq_len]
|
| 294 |
+
pub move_ids: Vec<i16>, // [batch_size * max_ply] raw for replay
|
| 295 |
+
pub game_lengths: Vec<i16>, // [batch_size]
|
| 296 |
+
pub termination_codes: Vec<u8>, // [batch_size]
|
| 297 |
+
pub batch_size: usize,
|
| 298 |
+
pub seq_len: usize,
|
| 299 |
+
pub max_ply: usize,
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
/// Generate a CLM training batch: random games packed into model-ready format.
|
| 303 |
+
///
|
| 304 |
+
/// `seq_len` is the total sequence length (256). Games are generated with up to
|
| 305 |
+
/// `seq_len - 1` plies, leaving position 0 for the outcome token.
|
| 306 |
+
pub fn generate_clm_batch(
|
| 307 |
+
batch_size: usize,
|
| 308 |
+
seq_len: usize,
|
| 309 |
+
seed: u64,
|
| 310 |
+
discard_ply_limit: bool,
|
| 311 |
+
) -> CLMBatch {
|
| 312 |
+
let max_ply = seq_len - 1;
|
| 313 |
+
|
| 314 |
+
let game_batch = if discard_ply_limit {
|
| 315 |
+
generate_completed_games(batch_size, max_ply, seed)
|
| 316 |
+
} else {
|
| 317 |
+
generate_random_games(batch_size, max_ply, seed)
|
| 318 |
+
};
|
| 319 |
+
|
| 320 |
+
let mut input_ids = vec![0i16; batch_size * seq_len];
|
| 321 |
+
let mut targets = vec![0i16; batch_size * seq_len];
|
| 322 |
+
let mut loss_mask = vec![false; batch_size * seq_len];
|
| 323 |
+
|
| 324 |
+
for b in 0..batch_size {
|
| 325 |
+
let gl = game_batch.game_lengths[b] as usize;
|
| 326 |
+
let term = match game_batch.termination_codes[b] {
|
| 327 |
+
0 => Termination::Checkmate,
|
| 328 |
+
1 => Termination::Stalemate,
|
| 329 |
+
2 => Termination::SeventyFiveMoveRule,
|
| 330 |
+
3 => Termination::FivefoldRepetition,
|
| 331 |
+
4 => Termination::InsufficientMaterial,
|
| 332 |
+
_ => Termination::PlyLimit,
|
| 333 |
+
};
|
| 334 |
+
let outcome = vocab::termination_to_outcome(term, game_batch.game_lengths[b] as u16);
|
| 335 |
+
|
| 336 |
+
let row = b * seq_len;
|
| 337 |
+
|
| 338 |
+
// Position 0: outcome token
|
| 339 |
+
input_ids[row] = outcome as i16;
|
| 340 |
+
|
| 341 |
+
// Positions 1..=gl: move tokens
|
| 342 |
+
for t in 0..gl {
|
| 343 |
+
input_ids[row + 1 + t] = game_batch.move_ids[b * max_ply + t];
|
| 344 |
+
}
|
| 345 |
+
// Remaining positions are already 0 (PAD)
|
| 346 |
+
|
| 347 |
+
// Targets: input_ids shifted left by 1
|
| 348 |
+
for t in 0..(seq_len - 1) {
|
| 349 |
+
targets[row + t] = input_ids[row + t + 1];
|
| 350 |
+
}
|
| 351 |
+
// targets[row + seq_len - 1] is already 0
|
| 352 |
+
|
| 353 |
+
// Loss mask: positions 0..=gl are true
|
| 354 |
+
for t in 0..=gl {
|
| 355 |
+
loss_mask[row + t] = true;
|
| 356 |
+
}
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
CLMBatch {
|
| 360 |
+
input_ids,
|
| 361 |
+
targets,
|
| 362 |
+
loss_mask,
|
| 363 |
+
move_ids: game_batch.move_ids,
|
| 364 |
+
game_lengths: game_batch.game_lengths,
|
| 365 |
+
termination_codes: game_batch.termination_codes,
|
| 366 |
+
batch_size,
|
| 367 |
+
seq_len,
|
| 368 |
+
max_ply,
|
| 369 |
+
}
|
| 370 |
+
}
|
| 371 |
+
|
| 372 |
#[cfg(test)]
|
| 373 |
mod tests {
|
| 374 |
use super::*;
|
|
|
|
| 395 |
}
|
| 396 |
|
| 397 |
#[test]
|
| 398 |
+
fn test_pad_after_game_end() {
|
| 399 |
let batch = generate_training_batch(2, 256, 42);
|
| 400 |
for b in 0..2 {
|
| 401 |
let len = batch.game_lengths[b] as usize;
|
| 402 |
if len < 256 {
|
| 403 |
assert_eq!(
|
| 404 |
batch.move_ids[b * 256 + len],
|
| 405 |
+
vocab::PAD_TOKEN as i16,
|
| 406 |
+
"Position game_length should be PAD (0)"
|
| 407 |
+
);
|
| 408 |
+
}
|
| 409 |
+
// All positions after game_length should also be PAD
|
| 410 |
+
for t in len..256 {
|
| 411 |
+
assert_eq!(
|
| 412 |
+
batch.move_ids[b * 256 + t],
|
| 413 |
+
0,
|
| 414 |
+
"Position {} (after game_length={}) should be PAD", t, len
|
| 415 |
);
|
| 416 |
}
|
| 417 |
}
|
|
|
|
| 425 |
assert_eq!(b1.game_lengths, b2.game_lengths);
|
| 426 |
assert_eq!(b1.legal_move_grid, b2.legal_move_grid);
|
| 427 |
}
|
| 428 |
+
|
| 429 |
+
#[test]
|
| 430 |
+
fn test_clm_batch_format() {
|
| 431 |
+
let seq_len = 256;
|
| 432 |
+
let batch = generate_clm_batch(8, seq_len, 42, false);
|
| 433 |
+
assert_eq!(batch.input_ids.len(), 8 * seq_len);
|
| 434 |
+
assert_eq!(batch.targets.len(), 8 * seq_len);
|
| 435 |
+
assert_eq!(batch.loss_mask.len(), 8 * seq_len);
|
| 436 |
+
assert_eq!(batch.move_ids.len(), 8 * (seq_len - 1));
|
| 437 |
+
assert_eq!(batch.game_lengths.len(), 8);
|
| 438 |
+
|
| 439 |
+
for b in 0..8 {
|
| 440 |
+
let gl = batch.game_lengths[b] as usize;
|
| 441 |
+
let row = b * seq_len;
|
| 442 |
+
|
| 443 |
+
// Position 0: outcome token (4273-4277)
|
| 444 |
+
let outcome = batch.input_ids[row];
|
| 445 |
+
assert!(outcome >= vocab::OUTCOME_BASE as i16 && outcome <= vocab::PLY_LIMIT as i16,
|
| 446 |
+
"Position 0 should be outcome token, got {}", outcome);
|
| 447 |
+
|
| 448 |
+
// Positions 1..=gl: move tokens (1-4272)
|
| 449 |
+
for t in 1..=gl {
|
| 450 |
+
let tok = batch.input_ids[row + t];
|
| 451 |
+
assert!(tok >= 1 && tok <= 4272,
|
| 452 |
+
"Position {} should be move token, got {}", t, tok);
|
| 453 |
+
}
|
| 454 |
+
|
| 455 |
+
// Positions gl+1..seq_len: PAD (0)
|
| 456 |
+
for t in (gl + 1)..seq_len {
|
| 457 |
+
assert_eq!(batch.input_ids[row + t], 0,
|
| 458 |
+
"Position {} should be PAD, got {}", t, batch.input_ids[row + t]);
|
| 459 |
+
}
|
| 460 |
+
|
| 461 |
+
// Targets: shifted left by 1
|
| 462 |
+
for t in 0..(seq_len - 1) {
|
| 463 |
+
assert_eq!(batch.targets[row + t], batch.input_ids[row + t + 1],
|
| 464 |
+
"targets[{}] should equal input_ids[{}]", t, t + 1);
|
| 465 |
+
}
|
| 466 |
+
assert_eq!(batch.targets[row + seq_len - 1], 0, "Last target should be PAD");
|
| 467 |
+
|
| 468 |
+
// Target at position gl is PAD (end of game)
|
| 469 |
+
assert_eq!(batch.targets[row + gl], 0, "Target at game_length should be PAD");
|
| 470 |
+
|
| 471 |
+
// Loss mask: true for 0..=gl, false after
|
| 472 |
+
for t in 0..=gl {
|
| 473 |
+
assert!(batch.loss_mask[row + t],
|
| 474 |
+
"loss_mask[{}] should be true (gl={})", t, gl);
|
| 475 |
+
}
|
| 476 |
+
for t in (gl + 1)..seq_len {
|
| 477 |
+
assert!(!batch.loss_mask[row + t],
|
| 478 |
+
"loss_mask[{}] should be false (gl={})", t, gl);
|
| 479 |
+
}
|
| 480 |
+
}
|
| 481 |
+
}
|
| 482 |
+
|
| 483 |
+
#[test]
|
| 484 |
+
fn test_clm_batch_deterministic() {
|
| 485 |
+
let b1 = generate_clm_batch(4, 256, 99, false);
|
| 486 |
+
let b2 = generate_clm_batch(4, 256, 99, false);
|
| 487 |
+
assert_eq!(b1.input_ids, b2.input_ids);
|
| 488 |
+
assert_eq!(b1.targets, b2.targets);
|
| 489 |
+
assert_eq!(b1.loss_mask, b2.loss_mask);
|
| 490 |
+
assert_eq!(b1.game_lengths, b2.game_lengths);
|
| 491 |
+
}
|
| 492 |
+
|
| 493 |
+
#[test]
|
| 494 |
+
fn test_clm_batch_outcome_correctness() {
|
| 495 |
+
let batch = generate_clm_batch(32, 256, 42, false);
|
| 496 |
+
for b in 0..32 {
|
| 497 |
+
let gl = batch.game_lengths[b] as usize;
|
| 498 |
+
let tc = batch.termination_codes[b];
|
| 499 |
+
let expected = vocab::termination_to_outcome(
|
| 500 |
+
match tc {
|
| 501 |
+
0 => Termination::Checkmate,
|
| 502 |
+
1 => Termination::Stalemate,
|
| 503 |
+
2 => Termination::SeventyFiveMoveRule,
|
| 504 |
+
3 => Termination::FivefoldRepetition,
|
| 505 |
+
4 => Termination::InsufficientMaterial,
|
| 506 |
+
_ => Termination::PlyLimit,
|
| 507 |
+
},
|
| 508 |
+
gl as u16,
|
| 509 |
+
);
|
| 510 |
+
assert_eq!(batch.input_ids[b * 256] as u16, expected,
|
| 511 |
+
"Game {} outcome mismatch: tc={}, gl={}", b, tc, gl);
|
| 512 |
+
}
|
| 513 |
+
}
|
| 514 |
}
|
engine/src/board.rs
CHANGED
|
@@ -74,7 +74,7 @@ pub fn move_to_token(m: &Move) -> u16 {
|
|
| 74 |
/// Convert our token index to a shakmaty Move, given the current position.
|
| 75 |
/// Finds the legal move matching the token's (src, dst, promo) decomposition.
|
| 76 |
pub fn token_to_move(pos: &Chess, token: u16) -> Option<Move> {
|
| 77 |
-
// Validate the token is decomposable (not PAD/
|
| 78 |
vocab::decompose_token(token)?;
|
| 79 |
let legal = pos.legal_moves();
|
| 80 |
|
|
|
|
| 74 |
/// Convert our token index to a shakmaty Move, given the current position.
|
| 75 |
/// Finds the legal move matching the token's (src, dst, promo) decomposition.
|
| 76 |
pub fn token_to_move(pos: &Chess, token: u16) -> Option<Move> {
|
| 77 |
+
// Validate the token is decomposable (not PAD/outcome)
|
| 78 |
vocab::decompose_token(token)?;
|
| 79 |
let legal = pos.legal_moves();
|
| 80 |
|
engine/src/lib.rs
CHANGED
|
@@ -169,6 +169,46 @@ fn generate_random_games<'py>(
|
|
| 169 |
Ok((move_ids, game_lengths, termination_codes))
|
| 170 |
}
|
| 171 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
/// Compute legal move masks by replaying games. Spec §7.4.
|
| 173 |
#[pyfunction]
|
| 174 |
fn compute_legal_move_masks<'py>(
|
|
@@ -868,6 +908,7 @@ fn _engine(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
|
| 868 |
m.add_function(wrap_pyfunction!(export_move_vocabulary, m)?)?;
|
| 869 |
m.add_function(wrap_pyfunction!(generate_training_batch, m)?)?;
|
| 870 |
m.add_function(wrap_pyfunction!(generate_random_games, m)?)?;
|
|
|
|
| 871 |
m.add_function(wrap_pyfunction!(generate_checkmate_games, m)?)?;
|
| 872 |
m.add_function(wrap_pyfunction!(generate_checkmate_training_batch, m)?)?;
|
| 873 |
m.add_function(wrap_pyfunction!(compute_legal_move_masks, m)?)?;
|
|
|
|
| 169 |
Ok((move_ids, game_lengths, termination_codes))
|
| 170 |
}
|
| 171 |
|
| 172 |
+
/// Generate a CLM training batch with model-ready tensors.
|
| 173 |
+
///
|
| 174 |
+
/// Returns (input_ids, targets, loss_mask, move_ids, game_lengths, term_codes).
|
| 175 |
+
/// input_ids = [outcome, ply_1, ..., ply_N, PAD, ...] (seq_len per row).
|
| 176 |
+
/// move_ids are the raw moves (seq_len-1 per row) for replay operations.
|
| 177 |
+
#[pyfunction]
|
| 178 |
+
#[pyo3(signature = (batch_size, seq_len=256, seed=42, discard_ply_limit=false))]
|
| 179 |
+
fn generate_clm_batch<'py>(
|
| 180 |
+
py: Python<'py>,
|
| 181 |
+
batch_size: usize,
|
| 182 |
+
seq_len: usize,
|
| 183 |
+
seed: u64,
|
| 184 |
+
discard_ply_limit: bool,
|
| 185 |
+
) -> PyResult<(
|
| 186 |
+
Bound<'py, PyArray2<i16>>, // input_ids (B, seq_len)
|
| 187 |
+
Bound<'py, PyArray2<i16>>, // targets (B, seq_len)
|
| 188 |
+
Bound<'py, PyArray2<bool>>, // loss_mask (B, seq_len)
|
| 189 |
+
Bound<'py, PyArray2<i16>>, // move_ids (B, seq_len-1)
|
| 190 |
+
Bound<'py, PyArray1<i16>>, // game_lengths (B,)
|
| 191 |
+
Bound<'py, PyArray1<u8>>, // termination_codes (B,)
|
| 192 |
+
)> {
|
| 193 |
+
let result = py.allow_threads(|| {
|
| 194 |
+
batch::generate_clm_batch(batch_size, seq_len, seed, discard_ply_limit)
|
| 195 |
+
});
|
| 196 |
+
|
| 197 |
+
let max_ply = seq_len - 1;
|
| 198 |
+
let input_ids = numpy::PyArray::from_vec(py, result.input_ids)
|
| 199 |
+
.reshape([batch_size, seq_len])?;
|
| 200 |
+
let targets = numpy::PyArray::from_vec(py, result.targets)
|
| 201 |
+
.reshape([batch_size, seq_len])?;
|
| 202 |
+
let loss_mask = numpy::PyArray::from_vec(py, result.loss_mask)
|
| 203 |
+
.reshape([batch_size, seq_len])?;
|
| 204 |
+
let move_ids = numpy::PyArray::from_vec(py, result.move_ids)
|
| 205 |
+
.reshape([batch_size, max_ply])?;
|
| 206 |
+
let game_lengths = numpy::PyArray::from_vec(py, result.game_lengths);
|
| 207 |
+
let termination_codes = numpy::PyArray::from_vec(py, result.termination_codes);
|
| 208 |
+
|
| 209 |
+
Ok((input_ids, targets, loss_mask, move_ids, game_lengths, termination_codes))
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
/// Compute legal move masks by replaying games. Spec §7.4.
|
| 213 |
#[pyfunction]
|
| 214 |
fn compute_legal_move_masks<'py>(
|
|
|
|
| 908 |
m.add_function(wrap_pyfunction!(export_move_vocabulary, m)?)?;
|
| 909 |
m.add_function(wrap_pyfunction!(generate_training_batch, m)?)?;
|
| 910 |
m.add_function(wrap_pyfunction!(generate_random_games, m)?)?;
|
| 911 |
+
m.add_function(wrap_pyfunction!(generate_clm_batch, m)?)?;
|
| 912 |
m.add_function(wrap_pyfunction!(generate_checkmate_games, m)?)?;
|
| 913 |
m.add_function(wrap_pyfunction!(generate_checkmate_training_batch, m)?)?;
|
| 914 |
m.add_function(wrap_pyfunction!(compute_legal_move_masks, m)?)?;
|
engine/src/pgn.rs
CHANGED
|
@@ -62,9 +62,6 @@ pub fn batch_san_to_tokens(
|
|
| 62 |
for (t, &tok) in tokens.iter().enumerate() {
|
| 63 |
flat[gi * max_ply + t] = tok as i16;
|
| 64 |
}
|
| 65 |
-
if n_valid < max_ply {
|
| 66 |
-
flat[gi * max_ply + n_valid] = crate::vocab::EOG_TOKEN as i16;
|
| 67 |
-
}
|
| 68 |
lengths.push(n_valid as i16);
|
| 69 |
}
|
| 70 |
|
|
@@ -195,9 +192,6 @@ pub fn pgn_file_to_tokens(
|
|
| 195 |
for (t, &tok) in tokens.iter().enumerate() {
|
| 196 |
flat[gi * max_ply + t] = tok as i16;
|
| 197 |
}
|
| 198 |
-
if *n_valid < max_ply {
|
| 199 |
-
flat[gi * max_ply + n_valid] = crate::vocab::EOG_TOKEN as i16;
|
| 200 |
-
}
|
| 201 |
lengths.push(*n_valid as i16);
|
| 202 |
}
|
| 203 |
|
|
|
|
| 62 |
for (t, &tok) in tokens.iter().enumerate() {
|
| 63 |
flat[gi * max_ply + t] = tok as i16;
|
| 64 |
}
|
|
|
|
|
|
|
|
|
|
| 65 |
lengths.push(n_valid as i16);
|
| 66 |
}
|
| 67 |
|
|
|
|
| 192 |
for (t, &tok) in tokens.iter().enumerate() {
|
| 193 |
flat[gi * max_ply + t] = tok as i16;
|
| 194 |
}
|
|
|
|
|
|
|
|
|
|
| 195 |
lengths.push(*n_valid as i16);
|
| 196 |
}
|
| 197 |
|
engine/src/vocab.rs
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
//! Move vocabulary: the single source of truth for token ↔ UCI string mapping.
|
| 2 |
//!
|
| 3 |
-
//! Token layout (4,
|
| 4 |
//! 0 = padding
|
| 5 |
//! 1..=4096 = base grid (64×64 src×dst pairs)
|
| 6 |
//! 4097..=4272 = promotions (44 eligible pairs × 4 piece types)
|
| 7 |
-
//! 4273
|
| 8 |
//!
|
| 9 |
//! Square indexing: file-major within rank.
|
| 10 |
//! a1=0, b1=1, ..., h1=7, a2=8, ..., h8=63
|
|
@@ -12,9 +12,10 @@
|
|
| 12 |
|
| 13 |
use std::collections::HashMap;
|
| 14 |
|
| 15 |
-
|
|
|
|
|
|
|
| 16 |
pub const PAD_TOKEN: u16 = 0;
|
| 17 |
-
pub const EOG_TOKEN: u16 = 4273;
|
| 18 |
pub const BASE_GRID_START: u16 = 1;
|
| 19 |
pub const BASE_GRID_END: u16 = 4096; // inclusive
|
| 20 |
pub const PROMO_START: u16 = 4097;
|
|
@@ -22,6 +23,14 @@ pub const PROMO_END: u16 = 4272; // inclusive
|
|
| 22 |
pub const NUM_PROMO_PAIRS: usize = 44;
|
| 23 |
pub const NUM_PROMO_TYPES: usize = 4;
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
/// Square names in our index order.
|
| 26 |
pub const SQUARE_NAMES: [&str; 64] = [
|
| 27 |
"a1", "b1", "c1", "d1", "e1", "f1", "g1", "h1",
|
|
@@ -131,11 +140,29 @@ pub fn promo_token(src: u8, dst: u8, promo_type: u8) -> Option<u16> {
|
|
| 131 |
Some(PROMO_START + (pair_idx as u16) * 4 + (promo_type as u16))
|
| 132 |
}
|
| 133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
/// Decompose a token into (src_square, dst_square, promo_type).
|
| 135 |
/// promo_type: 0=none, 1=q, 2=r, 3=b, 4=n (matches embedding index).
|
| 136 |
-
/// Returns None for PAD and
|
| 137 |
pub fn decompose_token(token: u16) -> Option<(u8, u8, u8)> {
|
| 138 |
-
if token == PAD_TOKEN || token =
|
| 139 |
return None;
|
| 140 |
}
|
| 141 |
if token >= BASE_GRID_START && token <= BASE_GRID_END {
|
|
@@ -285,8 +312,32 @@ mod tests {
|
|
| 285 |
}
|
| 286 |
|
| 287 |
#[test]
|
| 288 |
-
fn
|
| 289 |
assert!(decompose_token(PAD_TOKEN).is_none());
|
| 290 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
}
|
| 292 |
}
|
|
|
|
| 1 |
//! Move vocabulary: the single source of truth for token ↔ UCI string mapping.
|
| 2 |
//!
|
| 3 |
+
//! Token layout (4,278 total):
|
| 4 |
//! 0 = padding
|
| 5 |
//! 1..=4096 = base grid (64×64 src×dst pairs)
|
| 6 |
//! 4097..=4272 = promotions (44 eligible pairs × 4 piece types)
|
| 7 |
+
//! 4273..=4277 = outcome tokens (game result)
|
| 8 |
//!
|
| 9 |
//! Square indexing: file-major within rank.
|
| 10 |
//! a1=0, b1=1, ..., h1=7, a2=8, ..., h8=63
|
|
|
|
| 12 |
|
| 13 |
use std::collections::HashMap;
|
| 14 |
|
| 15 |
+
use crate::types::Termination;
|
| 16 |
+
|
| 17 |
+
pub const VOCAB_SIZE: usize = 4278;
|
| 18 |
pub const PAD_TOKEN: u16 = 0;
|
|
|
|
| 19 |
pub const BASE_GRID_START: u16 = 1;
|
| 20 |
pub const BASE_GRID_END: u16 = 4096; // inclusive
|
| 21 |
pub const PROMO_START: u16 = 4097;
|
|
|
|
| 23 |
pub const NUM_PROMO_PAIRS: usize = 44;
|
| 24 |
pub const NUM_PROMO_TYPES: usize = 4;
|
| 25 |
|
| 26 |
+
// Outcome tokens — must match pawn/config.py
|
| 27 |
+
pub const OUTCOME_BASE: u16 = 4273;
|
| 28 |
+
pub const WHITE_CHECKMATES: u16 = 4273;
|
| 29 |
+
pub const BLACK_CHECKMATES: u16 = 4274;
|
| 30 |
+
pub const STALEMATE: u16 = 4275;
|
| 31 |
+
pub const DRAW_BY_RULE: u16 = 4276;
|
| 32 |
+
pub const PLY_LIMIT: u16 = 4277;
|
| 33 |
+
|
| 34 |
/// Square names in our index order.
|
| 35 |
pub const SQUARE_NAMES: [&str; 64] = [
|
| 36 |
"a1", "b1", "c1", "d1", "e1", "f1", "g1", "h1",
|
|
|
|
| 140 |
Some(PROMO_START + (pair_idx as u16) * 4 + (promo_type as u16))
|
| 141 |
}
|
| 142 |
|
| 143 |
+
/// Map a game termination reason to the corresponding outcome token.
|
| 144 |
+
///
|
| 145 |
+
/// For checkmate, the winner is determined by game length parity:
|
| 146 |
+
/// - Odd game_length (white made the last move) → WHITE_CHECKMATES
|
| 147 |
+
/// - Even game_length (black made the last move) → BLACK_CHECKMATES
|
| 148 |
+
pub fn termination_to_outcome(term: Termination, game_length: u16) -> u16 {
|
| 149 |
+
match term {
|
| 150 |
+
Termination::Checkmate => {
|
| 151 |
+
if game_length % 2 == 1 { WHITE_CHECKMATES } else { BLACK_CHECKMATES }
|
| 152 |
+
}
|
| 153 |
+
Termination::Stalemate => STALEMATE,
|
| 154 |
+
Termination::SeventyFiveMoveRule
|
| 155 |
+
| Termination::FivefoldRepetition
|
| 156 |
+
| Termination::InsufficientMaterial => DRAW_BY_RULE,
|
| 157 |
+
Termination::PlyLimit => PLY_LIMIT,
|
| 158 |
+
}
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
/// Decompose a token into (src_square, dst_square, promo_type).
|
| 162 |
/// promo_type: 0=none, 1=q, 2=r, 3=b, 4=n (matches embedding index).
|
| 163 |
+
/// Returns None for PAD and outcome tokens.
|
| 164 |
pub fn decompose_token(token: u16) -> Option<(u8, u8, u8)> {
|
| 165 |
+
if token == PAD_TOKEN || token >= OUTCOME_BASE {
|
| 166 |
return None;
|
| 167 |
}
|
| 168 |
if token >= BASE_GRID_START && token <= BASE_GRID_END {
|
|
|
|
| 312 |
}
|
| 313 |
|
| 314 |
#[test]
|
| 315 |
+
fn test_pad_outcome_decompose() {
|
| 316 |
assert!(decompose_token(PAD_TOKEN).is_none());
|
| 317 |
+
// All 5 outcome tokens should return None
|
| 318 |
+
for token in OUTCOME_BASE..=PLY_LIMIT {
|
| 319 |
+
assert!(decompose_token(token).is_none(),
|
| 320 |
+
"outcome token {} should not decompose", token);
|
| 321 |
+
}
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
#[test]
|
| 325 |
+
fn test_termination_to_outcome() {
|
| 326 |
+
use crate::types::Termination;
|
| 327 |
+
|
| 328 |
+
// Checkmate with odd game_length = white wins
|
| 329 |
+
assert_eq!(termination_to_outcome(Termination::Checkmate, 11), WHITE_CHECKMATES);
|
| 330 |
+
assert_eq!(termination_to_outcome(Termination::Checkmate, 1), WHITE_CHECKMATES);
|
| 331 |
+
|
| 332 |
+
// Checkmate with even game_length = black wins
|
| 333 |
+
assert_eq!(termination_to_outcome(Termination::Checkmate, 12), BLACK_CHECKMATES);
|
| 334 |
+
assert_eq!(termination_to_outcome(Termination::Checkmate, 2), BLACK_CHECKMATES);
|
| 335 |
+
|
| 336 |
+
// Other terminations
|
| 337 |
+
assert_eq!(termination_to_outcome(Termination::Stalemate, 50), STALEMATE);
|
| 338 |
+
assert_eq!(termination_to_outcome(Termination::SeventyFiveMoveRule, 100), DRAW_BY_RULE);
|
| 339 |
+
assert_eq!(termination_to_outcome(Termination::FivefoldRepetition, 80), DRAW_BY_RULE);
|
| 340 |
+
assert_eq!(termination_to_outcome(Termination::InsufficientMaterial, 60), DRAW_BY_RULE);
|
| 341 |
+
assert_eq!(termination_to_outcome(Termination::PlyLimit, 255), PLY_LIMIT);
|
| 342 |
}
|
| 343 |
}
|
pawn/checkpoint.py
ADDED
|
@@ -0,0 +1,647 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Checkpoint save/load using safetensors + JSON.
|
| 2 |
+
|
| 3 |
+
Replaces monolithic torch.save() .pt files with directory-based checkpoints:
|
| 4 |
+
- model.safetensors / adapter.safetensors — tensor data
|
| 5 |
+
- optimizer.safetensors — flattened optimizer state tensors
|
| 6 |
+
- training_state.json — scalars, scheduler, scaler, RNG, optimizer metadata
|
| 7 |
+
- config.json — model and training configuration
|
| 8 |
+
- .complete — SHA-256 hashes of all files (integrity sentinel)
|
| 9 |
+
|
| 10 |
+
Writes are atomic: files are written to a .tmp directory, then renamed.
|
| 11 |
+
Loads always verify the .complete sentinel and SHA-256 hashes.
|
| 12 |
+
|
| 13 |
+
Backward compatible: all load functions transparently handle legacy .pt files.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import base64
|
| 19 |
+
import hashlib
|
| 20 |
+
import json
|
| 21 |
+
import os
|
| 22 |
+
import shutil
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn as nn
|
| 27 |
+
from safetensors.torch import save_file, load_file
|
| 28 |
+
|
| 29 |
+
CHECKPOINT_FORMAT_VERSION = 1
|
| 30 |
+
LEGACY_EXTENSIONS = {".pt", ".pth"}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
# Exceptions
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
|
| 37 |
+
class IncompleteCheckpointError(Exception):
|
| 38 |
+
"""Raised when a checkpoint directory is missing its .complete sentinel."""
|
| 39 |
+
pass
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class CheckpointIntegrityError(Exception):
|
| 43 |
+
"""Raised when a checkpoint file's SHA-256 hash doesn't match .complete."""
|
| 44 |
+
pass
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
# SHA-256 helpers
|
| 49 |
+
# ---------------------------------------------------------------------------
|
| 50 |
+
|
| 51 |
+
def _sha256_file(path: Path) -> str:
|
| 52 |
+
"""Compute SHA-256 hex digest of a file."""
|
| 53 |
+
h = hashlib.sha256()
|
| 54 |
+
with open(path, "rb") as f:
|
| 55 |
+
for chunk in iter(lambda: f.read(1 << 20), b""): # 1MB chunks
|
| 56 |
+
h.update(chunk)
|
| 57 |
+
return h.hexdigest()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _write_complete_sentinel(directory: Path) -> None:
|
| 61 |
+
"""Write .complete sentinel with SHA-256 hashes of all checkpoint files."""
|
| 62 |
+
hashes = {}
|
| 63 |
+
for f in sorted(directory.iterdir()):
|
| 64 |
+
if f.name == ".complete" or f.is_dir():
|
| 65 |
+
continue
|
| 66 |
+
hashes[f.name] = _sha256_file(f)
|
| 67 |
+
|
| 68 |
+
sentinel = {"format_version": CHECKPOINT_FORMAT_VERSION, "files": hashes}
|
| 69 |
+
with open(directory / ".complete", "w") as f:
|
| 70 |
+
json.dump(sentinel, f, indent=2)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _verify_complete_sentinel(directory: Path) -> None:
|
| 74 |
+
"""Verify .complete sentinel exists and all hashes match.
|
| 75 |
+
|
| 76 |
+
Raises IncompleteCheckpointError if sentinel is missing.
|
| 77 |
+
Raises CheckpointIntegrityError if any hash mismatches.
|
| 78 |
+
"""
|
| 79 |
+
sentinel_path = directory / ".complete"
|
| 80 |
+
if not sentinel_path.exists():
|
| 81 |
+
raise IncompleteCheckpointError(
|
| 82 |
+
f"Checkpoint {directory} is missing .complete sentinel — "
|
| 83 |
+
f"likely a partial write from a crashed or interrupted save."
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
with open(sentinel_path) as f:
|
| 87 |
+
sentinel = json.load(f)
|
| 88 |
+
|
| 89 |
+
for filename, expected_hash in sentinel["files"].items():
|
| 90 |
+
filepath = directory / filename
|
| 91 |
+
if not filepath.exists():
|
| 92 |
+
raise CheckpointIntegrityError(
|
| 93 |
+
f"File {filename} listed in .complete but missing from {directory}"
|
| 94 |
+
)
|
| 95 |
+
actual_hash = _sha256_file(filepath)
|
| 96 |
+
if actual_hash != expected_hash:
|
| 97 |
+
raise CheckpointIntegrityError(
|
| 98 |
+
f"SHA-256 mismatch for {filename} in {directory}: "
|
| 99 |
+
f"expected {expected_hash[:16]}..., got {actual_hash[:16]}..."
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# ---------------------------------------------------------------------------
|
| 104 |
+
# Atomic directory write
|
| 105 |
+
# ---------------------------------------------------------------------------
|
| 106 |
+
|
| 107 |
+
def _atomic_directory_write(target: Path):
|
| 108 |
+
"""Context manager for atomic directory writes.
|
| 109 |
+
|
| 110 |
+
Usage:
|
| 111 |
+
with _atomic_directory_write(Path("step_00001")) as tmp:
|
| 112 |
+
save_file(tensors, tmp / "model.safetensors")
|
| 113 |
+
...
|
| 114 |
+
# Directory is now at step_00001/ with .complete sentinel
|
| 115 |
+
"""
|
| 116 |
+
class _AtomicDir:
|
| 117 |
+
def __init__(self, target: Path):
|
| 118 |
+
self.target = target
|
| 119 |
+
self.tmp = target.parent / f"{target.name}.tmp"
|
| 120 |
+
|
| 121 |
+
def __enter__(self) -> Path:
|
| 122 |
+
# Clean up any leftover .tmp from a previous crash
|
| 123 |
+
if self.tmp.exists():
|
| 124 |
+
shutil.rmtree(self.tmp)
|
| 125 |
+
self.tmp.mkdir(parents=True, exist_ok=True)
|
| 126 |
+
return self.tmp
|
| 127 |
+
|
| 128 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 129 |
+
if exc_type is not None:
|
| 130 |
+
# Error during write — clean up temp dir
|
| 131 |
+
if self.tmp.exists():
|
| 132 |
+
shutil.rmtree(self.tmp)
|
| 133 |
+
return False
|
| 134 |
+
|
| 135 |
+
# Write .complete sentinel with hashes
|
| 136 |
+
_write_complete_sentinel(self.tmp)
|
| 137 |
+
|
| 138 |
+
# Atomic rename (same filesystem)
|
| 139 |
+
if self.target.exists():
|
| 140 |
+
shutil.rmtree(self.target)
|
| 141 |
+
os.rename(self.tmp, self.target)
|
| 142 |
+
return False
|
| 143 |
+
|
| 144 |
+
return _AtomicDir(target)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# ---------------------------------------------------------------------------
|
| 148 |
+
# JSON helpers
|
| 149 |
+
# ---------------------------------------------------------------------------
|
| 150 |
+
|
| 151 |
+
def _json_default(obj):
|
| 152 |
+
"""JSON serializer for types not natively supported."""
|
| 153 |
+
if isinstance(obj, torch.Tensor):
|
| 154 |
+
return obj.item() if obj.numel() == 1 else obj.tolist()
|
| 155 |
+
if hasattr(obj, "item"): # numpy scalar
|
| 156 |
+
return obj.item()
|
| 157 |
+
if isinstance(obj, Path):
|
| 158 |
+
return str(obj)
|
| 159 |
+
return str(obj)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
# ---------------------------------------------------------------------------
|
| 163 |
+
# RNG state serialization
|
| 164 |
+
# ---------------------------------------------------------------------------
|
| 165 |
+
|
| 166 |
+
def _rng_to_json(
|
| 167 |
+
torch_rng: torch.Tensor | None, cuda_rng: torch.Tensor | None
|
| 168 |
+
) -> dict:
|
| 169 |
+
data = {}
|
| 170 |
+
if torch_rng is not None:
|
| 171 |
+
data["torch_rng_state"] = base64.b64encode(
|
| 172 |
+
torch_rng.numpy().tobytes()
|
| 173 |
+
).decode("ascii")
|
| 174 |
+
if cuda_rng is not None:
|
| 175 |
+
data["cuda_rng_state"] = base64.b64encode(
|
| 176 |
+
cuda_rng.numpy().tobytes()
|
| 177 |
+
).decode("ascii")
|
| 178 |
+
return data
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def _json_to_rng(data: dict) -> tuple[torch.Tensor | None, torch.Tensor | None]:
|
| 182 |
+
torch_rng = None
|
| 183 |
+
if "torch_rng_state" in data:
|
| 184 |
+
raw = base64.b64decode(data["torch_rng_state"])
|
| 185 |
+
torch_rng = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
|
| 186 |
+
cuda_rng = None
|
| 187 |
+
if "cuda_rng_state" in data:
|
| 188 |
+
raw = base64.b64decode(data["cuda_rng_state"])
|
| 189 |
+
cuda_rng = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
|
| 190 |
+
return torch_rng, cuda_rng
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# ---------------------------------------------------------------------------
|
| 194 |
+
# Optimizer state flattening for safetensors
|
| 195 |
+
# ---------------------------------------------------------------------------
|
| 196 |
+
|
| 197 |
+
def _flatten_optimizer_state(
|
| 198 |
+
opt_state_dict: dict,
|
| 199 |
+
) -> tuple[dict[str, torch.Tensor], dict]:
|
| 200 |
+
"""Flatten optimizer state into safetensors-compatible tensors + JSON metadata."""
|
| 201 |
+
tensors: dict[str, torch.Tensor] = {}
|
| 202 |
+
scalars: dict[str, float | int] = {}
|
| 203 |
+
|
| 204 |
+
for param_id, param_state in opt_state_dict["state"].items():
|
| 205 |
+
for key, val in param_state.items():
|
| 206 |
+
flat_key = f"state.{param_id}.{key}"
|
| 207 |
+
if isinstance(val, torch.Tensor):
|
| 208 |
+
t = val.cpu().contiguous()
|
| 209 |
+
if t.ndim == 0:
|
| 210 |
+
t = t.unsqueeze(0)
|
| 211 |
+
tensors[flat_key] = t
|
| 212 |
+
else:
|
| 213 |
+
scalars[flat_key] = val
|
| 214 |
+
|
| 215 |
+
meta = {
|
| 216 |
+
"param_groups": opt_state_dict["param_groups"],
|
| 217 |
+
"scalars": scalars if scalars else None,
|
| 218 |
+
}
|
| 219 |
+
return tensors, meta
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def _unflatten_optimizer_state(
|
| 223 |
+
tensors: dict[str, torch.Tensor],
|
| 224 |
+
meta: dict,
|
| 225 |
+
device: str = "cpu",
|
| 226 |
+
) -> dict:
|
| 227 |
+
"""Reconstruct optimizer state_dict from flattened tensors + metadata."""
|
| 228 |
+
state: dict[int, dict[str, torch.Tensor | float | int]] = {}
|
| 229 |
+
scalars = meta.get("scalars") or {}
|
| 230 |
+
|
| 231 |
+
for flat_key, val in tensors.items():
|
| 232 |
+
parts = flat_key.split(".", 2)
|
| 233 |
+
if len(parts) != 3 or parts[0] != "state":
|
| 234 |
+
continue
|
| 235 |
+
param_id = int(parts[1])
|
| 236 |
+
key = parts[2]
|
| 237 |
+
if param_id not in state:
|
| 238 |
+
state[param_id] = {}
|
| 239 |
+
t = val.to(device)
|
| 240 |
+
if t.shape == (1,) and key == "step":
|
| 241 |
+
t = t.squeeze(0)
|
| 242 |
+
state[param_id][key] = t
|
| 243 |
+
|
| 244 |
+
for flat_key, val in scalars.items():
|
| 245 |
+
parts = flat_key.split(".", 2)
|
| 246 |
+
if len(parts) != 3:
|
| 247 |
+
continue
|
| 248 |
+
param_id = int(parts[1])
|
| 249 |
+
key = parts[2]
|
| 250 |
+
if param_id not in state:
|
| 251 |
+
state[param_id] = {}
|
| 252 |
+
state[param_id][key] = val
|
| 253 |
+
|
| 254 |
+
return {
|
| 255 |
+
"state": state,
|
| 256 |
+
"param_groups": meta["param_groups"],
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
# ---------------------------------------------------------------------------
|
| 261 |
+
# Legacy checkpoint detection
|
| 262 |
+
# ---------------------------------------------------------------------------
|
| 263 |
+
|
| 264 |
+
def is_legacy_checkpoint(path: str | Path) -> bool:
|
| 265 |
+
"""True if path is a single .pt/.pth file (legacy format)."""
|
| 266 |
+
p = Path(path)
|
| 267 |
+
return p.is_file() and p.suffix in LEGACY_EXTENSIONS
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def _load_legacy_pt(path: str | Path, device: str = "cpu") -> dict:
|
| 271 |
+
return torch.load(str(path), map_location=device, weights_only=False)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
# ---------------------------------------------------------------------------
|
| 275 |
+
# Pretrain checkpoint save/load
|
| 276 |
+
# ---------------------------------------------------------------------------
|
| 277 |
+
|
| 278 |
+
def save_pretrain_checkpoint(
|
| 279 |
+
path: str | Path,
|
| 280 |
+
model: nn.Module,
|
| 281 |
+
optimizer: torch.optim.Optimizer,
|
| 282 |
+
scheduler,
|
| 283 |
+
scaler,
|
| 284 |
+
global_step: int,
|
| 285 |
+
model_config: dict,
|
| 286 |
+
training_config: dict,
|
| 287 |
+
) -> None:
|
| 288 |
+
"""Save a pretraining checkpoint atomically.
|
| 289 |
+
|
| 290 |
+
Writes to {path}.tmp/, then renames to {path}/ with .complete sentinel.
|
| 291 |
+
"""
|
| 292 |
+
path = Path(path)
|
| 293 |
+
|
| 294 |
+
with _atomic_directory_write(path) as tmp:
|
| 295 |
+
# 1. Model weights
|
| 296 |
+
state_dict = {k: v.cpu().contiguous() for k, v in model.state_dict().items()}
|
| 297 |
+
save_file(state_dict, tmp / "model.safetensors")
|
| 298 |
+
|
| 299 |
+
# 2. Optimizer tensors
|
| 300 |
+
opt_tensors, opt_meta = _flatten_optimizer_state(optimizer.state_dict())
|
| 301 |
+
if opt_tensors:
|
| 302 |
+
save_file(opt_tensors, tmp / "optimizer.safetensors")
|
| 303 |
+
|
| 304 |
+
# 3. Training state (JSON)
|
| 305 |
+
training_state = {
|
| 306 |
+
"format_version": CHECKPOINT_FORMAT_VERSION,
|
| 307 |
+
"global_step": global_step,
|
| 308 |
+
"scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None,
|
| 309 |
+
"scaler_state_dict": scaler.state_dict() if scaler is not None else None,
|
| 310 |
+
"optimizer_meta": opt_meta,
|
| 311 |
+
**_rng_to_json(
|
| 312 |
+
torch.get_rng_state(),
|
| 313 |
+
torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
|
| 314 |
+
),
|
| 315 |
+
}
|
| 316 |
+
with open(tmp / "training_state.json", "w") as f:
|
| 317 |
+
json.dump(training_state, f, indent=2, default=_json_default)
|
| 318 |
+
|
| 319 |
+
# 4. Config
|
| 320 |
+
config = {
|
| 321 |
+
"format_version": CHECKPOINT_FORMAT_VERSION,
|
| 322 |
+
"checkpoint_type": "pretrain",
|
| 323 |
+
"model_config": model_config,
|
| 324 |
+
"training_config": training_config,
|
| 325 |
+
}
|
| 326 |
+
with open(tmp / "config.json", "w") as f:
|
| 327 |
+
json.dump(config, f, indent=2, default=_json_default)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def load_pretrain_checkpoint(
|
| 331 |
+
path: str | Path,
|
| 332 |
+
model: nn.Module,
|
| 333 |
+
optimizer: torch.optim.Optimizer | None = None,
|
| 334 |
+
scheduler=None,
|
| 335 |
+
scaler=None,
|
| 336 |
+
device: str = "cpu",
|
| 337 |
+
) -> dict:
|
| 338 |
+
"""Load a pretraining checkpoint with integrity verification.
|
| 339 |
+
|
| 340 |
+
Handles both legacy .pt files and new directory format.
|
| 341 |
+
New format: verifies .complete sentinel and SHA-256 hashes.
|
| 342 |
+
"""
|
| 343 |
+
path = Path(path)
|
| 344 |
+
|
| 345 |
+
if is_legacy_checkpoint(path):
|
| 346 |
+
ckpt = _load_legacy_pt(path, device)
|
| 347 |
+
model.load_state_dict(ckpt["model_state_dict"])
|
| 348 |
+
if optimizer and "optimizer_state_dict" in ckpt:
|
| 349 |
+
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
|
| 350 |
+
if scheduler and "scheduler_state_dict" in ckpt:
|
| 351 |
+
scheduler.load_state_dict(ckpt["scheduler_state_dict"])
|
| 352 |
+
if scaler and "scaler_state_dict" in ckpt:
|
| 353 |
+
scaler.load_state_dict(ckpt["scaler_state_dict"])
|
| 354 |
+
if ckpt.get("torch_rng_state") is not None:
|
| 355 |
+
torch.set_rng_state(ckpt["torch_rng_state"].cpu().byte())
|
| 356 |
+
if ckpt.get("cuda_rng_state") is not None and torch.cuda.is_available():
|
| 357 |
+
torch.cuda.set_rng_state(ckpt["cuda_rng_state"].cpu().byte())
|
| 358 |
+
return {
|
| 359 |
+
"global_step": ckpt.get("global_step", 0),
|
| 360 |
+
"model_config": ckpt.get("model_config"),
|
| 361 |
+
"training_config": ckpt.get("training_config"),
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
# New directory format — verify integrity first
|
| 365 |
+
_verify_complete_sentinel(path)
|
| 366 |
+
|
| 367 |
+
weights = load_file(path / "model.safetensors", device=device)
|
| 368 |
+
model.load_state_dict(weights)
|
| 369 |
+
|
| 370 |
+
with open(path / "training_state.json") as f:
|
| 371 |
+
ts = json.load(f)
|
| 372 |
+
|
| 373 |
+
if optimizer and (path / "optimizer.safetensors").exists():
|
| 374 |
+
opt_tensors = load_file(path / "optimizer.safetensors", device=device)
|
| 375 |
+
opt_state = _unflatten_optimizer_state(opt_tensors, ts["optimizer_meta"], device)
|
| 376 |
+
optimizer.load_state_dict(opt_state)
|
| 377 |
+
|
| 378 |
+
if scheduler and "scheduler_state_dict" in ts:
|
| 379 |
+
scheduler.load_state_dict(ts["scheduler_state_dict"])
|
| 380 |
+
|
| 381 |
+
if scaler and "scaler_state_dict" in ts:
|
| 382 |
+
scaler.load_state_dict(ts["scaler_state_dict"])
|
| 383 |
+
|
| 384 |
+
torch_rng, cuda_rng = _json_to_rng(ts)
|
| 385 |
+
if torch_rng is not None:
|
| 386 |
+
torch.set_rng_state(torch_rng.cpu().byte())
|
| 387 |
+
if cuda_rng is not None and torch.cuda.is_available():
|
| 388 |
+
torch.cuda.set_rng_state(cuda_rng.cpu().byte())
|
| 389 |
+
|
| 390 |
+
with open(path / "config.json") as f:
|
| 391 |
+
config = json.load(f)
|
| 392 |
+
|
| 393 |
+
return {
|
| 394 |
+
"global_step": ts.get("global_step", 0),
|
| 395 |
+
"model_config": config.get("model_config"),
|
| 396 |
+
"training_config": config.get("training_config"),
|
| 397 |
+
}
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
# ---------------------------------------------------------------------------
|
| 401 |
+
# Adapter checkpoint save/load
|
| 402 |
+
# ---------------------------------------------------------------------------
|
| 403 |
+
|
| 404 |
+
def save_adapter_checkpoint(
|
| 405 |
+
path: str | Path,
|
| 406 |
+
adapter_state_dict: dict[str, torch.Tensor],
|
| 407 |
+
config: dict,
|
| 408 |
+
epoch: int,
|
| 409 |
+
step: int,
|
| 410 |
+
val_metrics: dict,
|
| 411 |
+
optimizer: torch.optim.Optimizer | None = None,
|
| 412 |
+
scheduler=None,
|
| 413 |
+
scaler=None,
|
| 414 |
+
extra: dict | None = None,
|
| 415 |
+
) -> None:
|
| 416 |
+
"""Save an adapter checkpoint atomically."""
|
| 417 |
+
path = Path(path)
|
| 418 |
+
|
| 419 |
+
with _atomic_directory_write(path) as tmp:
|
| 420 |
+
# 1. Adapter weights
|
| 421 |
+
tensors = {k: v.cpu().contiguous() for k, v in adapter_state_dict.items()}
|
| 422 |
+
save_file(tensors, tmp / "adapter.safetensors")
|
| 423 |
+
|
| 424 |
+
# 2. Optimizer tensors (if provided)
|
| 425 |
+
opt_meta = None
|
| 426 |
+
if optimizer is not None:
|
| 427 |
+
opt_tensors, opt_meta = _flatten_optimizer_state(optimizer.state_dict())
|
| 428 |
+
if opt_tensors:
|
| 429 |
+
save_file(opt_tensors, tmp / "optimizer.safetensors")
|
| 430 |
+
|
| 431 |
+
# 3. Training state
|
| 432 |
+
training_state: dict = {
|
| 433 |
+
"format_version": CHECKPOINT_FORMAT_VERSION,
|
| 434 |
+
"epoch": epoch,
|
| 435 |
+
"step": step,
|
| 436 |
+
"val_metrics": val_metrics,
|
| 437 |
+
}
|
| 438 |
+
if opt_meta is not None:
|
| 439 |
+
training_state["optimizer_meta"] = opt_meta
|
| 440 |
+
if scheduler is not None:
|
| 441 |
+
training_state["scheduler_state_dict"] = scheduler.state_dict()
|
| 442 |
+
if scaler is not None:
|
| 443 |
+
training_state["scaler_state_dict"] = scaler.state_dict()
|
| 444 |
+
if extra:
|
| 445 |
+
training_state.update(extra)
|
| 446 |
+
|
| 447 |
+
with open(tmp / "training_state.json", "w") as f:
|
| 448 |
+
json.dump(training_state, f, indent=2, default=_json_default)
|
| 449 |
+
|
| 450 |
+
# 4. Config
|
| 451 |
+
with open(tmp / "config.json", "w") as f:
|
| 452 |
+
json.dump(config, f, indent=2, default=_json_default)
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
def load_adapter_checkpoint(
|
| 456 |
+
path: str | Path,
|
| 457 |
+
device: str = "cpu",
|
| 458 |
+
) -> dict:
|
| 459 |
+
"""Load an adapter checkpoint with integrity verification."""
|
| 460 |
+
path = Path(path)
|
| 461 |
+
|
| 462 |
+
if is_legacy_checkpoint(path):
|
| 463 |
+
ckpt = _load_legacy_pt(path, device)
|
| 464 |
+
adapter_key = None
|
| 465 |
+
for key in ("lora_state_dict", "bottleneck_state_dict", "film_state_dict",
|
| 466 |
+
"sparse_state_dict", "adapter_state_dict", "model_state_dict"):
|
| 467 |
+
if key in ckpt:
|
| 468 |
+
adapter_key = key
|
| 469 |
+
break
|
| 470 |
+
return {
|
| 471 |
+
"adapter_state_dict": ckpt.get(adapter_key, {}),
|
| 472 |
+
"config": ckpt.get("config", {}),
|
| 473 |
+
"epoch": ckpt.get("epoch", 0),
|
| 474 |
+
"step": ckpt.get("step", 0),
|
| 475 |
+
"val_metrics": {
|
| 476 |
+
"loss": ckpt.get("val_loss"),
|
| 477 |
+
"top1_accuracy": ckpt.get("val_top1"),
|
| 478 |
+
},
|
| 479 |
+
"optimizer_state_dict": ckpt.get("optimizer_state_dict"),
|
| 480 |
+
"scheduler_state_dict": ckpt.get("scheduler_state_dict"),
|
| 481 |
+
"scaler_state_dict": ckpt.get("scaler_state_dict"),
|
| 482 |
+
"best_val_loss": ckpt.get("best_val_loss"),
|
| 483 |
+
"patience_counter": ckpt.get("patience_counter"),
|
| 484 |
+
}
|
| 485 |
+
|
| 486 |
+
# New directory format — verify integrity first
|
| 487 |
+
_verify_complete_sentinel(path)
|
| 488 |
+
|
| 489 |
+
adapter_weights = load_file(path / "adapter.safetensors", device=device)
|
| 490 |
+
|
| 491 |
+
with open(path / "config.json") as f:
|
| 492 |
+
config = json.load(f)
|
| 493 |
+
|
| 494 |
+
ts = {}
|
| 495 |
+
ts_path = path / "training_state.json"
|
| 496 |
+
if ts_path.exists():
|
| 497 |
+
with open(ts_path) as f:
|
| 498 |
+
ts = json.load(f)
|
| 499 |
+
|
| 500 |
+
result: dict = {
|
| 501 |
+
"adapter_state_dict": adapter_weights,
|
| 502 |
+
"config": config,
|
| 503 |
+
"epoch": ts.get("epoch", 0),
|
| 504 |
+
"step": ts.get("step", 0),
|
| 505 |
+
"val_metrics": ts.get("val_metrics", {}),
|
| 506 |
+
"best_val_loss": ts.get("best_val_loss"),
|
| 507 |
+
"patience_counter": ts.get("patience_counter"),
|
| 508 |
+
}
|
| 509 |
+
|
| 510 |
+
if (path / "optimizer.safetensors").exists() and "optimizer_meta" in ts:
|
| 511 |
+
opt_tensors = load_file(path / "optimizer.safetensors", device=device)
|
| 512 |
+
result["optimizer_state_dict"] = _unflatten_optimizer_state(
|
| 513 |
+
opt_tensors, ts["optimizer_meta"], device
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
if "scheduler_state_dict" in ts:
|
| 517 |
+
result["scheduler_state_dict"] = ts["scheduler_state_dict"]
|
| 518 |
+
|
| 519 |
+
if "scaler_state_dict" in ts:
|
| 520 |
+
result["scaler_state_dict"] = ts["scaler_state_dict"]
|
| 521 |
+
|
| 522 |
+
return result
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
# ---------------------------------------------------------------------------
|
| 526 |
+
# Backbone-only loading (inference)
|
| 527 |
+
# ---------------------------------------------------------------------------
|
| 528 |
+
|
| 529 |
+
def load_backbone_weights(
|
| 530 |
+
path: str | Path,
|
| 531 |
+
device: str = "cpu",
|
| 532 |
+
) -> tuple[dict[str, torch.Tensor], dict | None]:
|
| 533 |
+
"""Load model weights and config for inference with integrity verification.
|
| 534 |
+
|
| 535 |
+
Works with:
|
| 536 |
+
- Legacy .pt files (extracts model_state_dict + model_config)
|
| 537 |
+
- New checkpoint directories (reads model.safetensors + config.json, verifies .complete)
|
| 538 |
+
- Bare model.safetensors files (no .complete check — used for HF downloads)
|
| 539 |
+
|
| 540 |
+
Returns (state_dict, model_config_dict_or_None).
|
| 541 |
+
"""
|
| 542 |
+
path = Path(path)
|
| 543 |
+
|
| 544 |
+
if is_legacy_checkpoint(path):
|
| 545 |
+
ckpt = _load_legacy_pt(path, device)
|
| 546 |
+
return ckpt["model_state_dict"], ckpt.get("model_config")
|
| 547 |
+
|
| 548 |
+
# Directory with model.safetensors
|
| 549 |
+
if path.is_dir():
|
| 550 |
+
sf_path = path / "model.safetensors"
|
| 551 |
+
if not sf_path.exists():
|
| 552 |
+
raise FileNotFoundError(f"No model.safetensors in {path}")
|
| 553 |
+
# Verify integrity if .complete exists (new format checkpoints)
|
| 554 |
+
if (path / ".complete").exists():
|
| 555 |
+
_verify_complete_sentinel(path)
|
| 556 |
+
weights = load_file(sf_path, device=device)
|
| 557 |
+
config = None
|
| 558 |
+
config_path = path / "config.json"
|
| 559 |
+
if config_path.exists():
|
| 560 |
+
with open(config_path) as f:
|
| 561 |
+
config = json.load(f).get("model_config")
|
| 562 |
+
return weights, config
|
| 563 |
+
|
| 564 |
+
# Bare safetensors file
|
| 565 |
+
if path.suffix == ".safetensors":
|
| 566 |
+
weights = load_file(path, device=device)
|
| 567 |
+
config = None
|
| 568 |
+
config_path = path.parent / "config.json"
|
| 569 |
+
if config_path.exists():
|
| 570 |
+
with open(config_path) as f:
|
| 571 |
+
config = json.load(f).get("model_config")
|
| 572 |
+
return weights, config
|
| 573 |
+
|
| 574 |
+
raise ValueError(f"Unrecognized checkpoint format: {path}")
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
# ---------------------------------------------------------------------------
|
| 578 |
+
# HuggingFace push
|
| 579 |
+
# ---------------------------------------------------------------------------
|
| 580 |
+
|
| 581 |
+
def push_checkpoint_to_hf(
|
| 582 |
+
checkpoint_path: str | Path,
|
| 583 |
+
repo_id: str,
|
| 584 |
+
branch: str,
|
| 585 |
+
metrics_path: str | Path | None = None,
|
| 586 |
+
step: int = 0,
|
| 587 |
+
) -> None:
|
| 588 |
+
"""Push a complete checkpoint directory to a HuggingFace repo branch.
|
| 589 |
+
|
| 590 |
+
Uploads checkpoint files to checkpoints/step_NNNN/ on the branch.
|
| 591 |
+
Optionally uploads metrics.jsonl (truncated to current step) to the root.
|
| 592 |
+
|
| 593 |
+
Requires HF_TOKEN environment variable or prior `huggingface_hub.login()`.
|
| 594 |
+
"""
|
| 595 |
+
from huggingface_hub import HfApi
|
| 596 |
+
|
| 597 |
+
checkpoint_path = Path(checkpoint_path)
|
| 598 |
+
api = HfApi()
|
| 599 |
+
|
| 600 |
+
# Ensure branch exists
|
| 601 |
+
try:
|
| 602 |
+
api.create_branch(repo_id, repo_type="model", branch=branch, exist_ok=True)
|
| 603 |
+
except Exception:
|
| 604 |
+
pass # Branch may already exist
|
| 605 |
+
|
| 606 |
+
# Upload checkpoint directory
|
| 607 |
+
api.upload_folder(
|
| 608 |
+
folder_path=str(checkpoint_path),
|
| 609 |
+
path_in_repo=f"checkpoints/{checkpoint_path.name}",
|
| 610 |
+
repo_id=repo_id,
|
| 611 |
+
repo_type="model",
|
| 612 |
+
revision=branch,
|
| 613 |
+
commit_message=f"Checkpoint step {step}",
|
| 614 |
+
)
|
| 615 |
+
|
| 616 |
+
# Upload truncated metrics.jsonl to repo root
|
| 617 |
+
if metrics_path is not None:
|
| 618 |
+
metrics_path = Path(metrics_path)
|
| 619 |
+
if metrics_path.exists():
|
| 620 |
+
import tempfile
|
| 621 |
+
# Truncate metrics to current step
|
| 622 |
+
truncated_lines = []
|
| 623 |
+
with open(metrics_path) as f:
|
| 624 |
+
for line in f:
|
| 625 |
+
truncated_lines.append(line)
|
| 626 |
+
try:
|
| 627 |
+
record = json.loads(line)
|
| 628 |
+
if record.get("type") in ("train", "val") and record.get("step", 0) >= step:
|
| 629 |
+
break
|
| 630 |
+
except json.JSONDecodeError:
|
| 631 |
+
continue
|
| 632 |
+
|
| 633 |
+
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as tmp:
|
| 634 |
+
tmp.writelines(truncated_lines)
|
| 635 |
+
tmp_path = tmp.name
|
| 636 |
+
|
| 637 |
+
try:
|
| 638 |
+
api.upload_file(
|
| 639 |
+
path_or_fileobj=tmp_path,
|
| 640 |
+
path_in_repo="metrics.jsonl",
|
| 641 |
+
repo_id=repo_id,
|
| 642 |
+
repo_type="model",
|
| 643 |
+
revision=branch,
|
| 644 |
+
commit_message=f"Metrics through step {step}",
|
| 645 |
+
)
|
| 646 |
+
finally:
|
| 647 |
+
os.unlink(tmp_path)
|
pawn/config.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
| 3 |
from dataclasses import dataclass
|
| 4 |
|
| 5 |
|
| 6 |
-
# Outcome token IDs
|
| 7 |
PAD_TOKEN = 0
|
| 8 |
OUTCOME_TOKEN_BASE = 4273
|
| 9 |
WHITE_CHECKMATES = 4273
|
|
|
|
| 3 |
from dataclasses import dataclass
|
| 4 |
|
| 5 |
|
| 6 |
+
# Outcome token IDs — must match engine/src/vocab.rs
|
| 7 |
PAD_TOKEN = 0
|
| 8 |
OUTCOME_TOKEN_BASE = 4273
|
| 9 |
WHITE_CHECKMATES = 4273
|
pawn/data.py
CHANGED
|
@@ -23,7 +23,6 @@ from pawn.config import (
|
|
| 23 |
_positions_cache: dict[tuple[str, int], torch.Tensor] = {}
|
| 24 |
|
| 25 |
|
| 26 |
-
|
| 27 |
def _map_termination_to_outcome(
|
| 28 |
term_codes: np.ndarray, game_lengths: np.ndarray
|
| 29 |
) -> torch.Tensor:
|
|
@@ -51,46 +50,45 @@ def _map_termination_to_outcome(
|
|
| 51 |
return outcomes
|
| 52 |
|
| 53 |
|
| 54 |
-
def
|
| 55 |
move_ids: np.ndarray,
|
| 56 |
game_lengths: np.ndarray,
|
| 57 |
-
|
| 58 |
seq_len: int,
|
| 59 |
) -> dict[str, torch.Tensor]:
|
| 60 |
-
"""
|
| 61 |
|
| 62 |
Constructs input_ids = [outcome, move_1, ..., move_N, PAD, ...]
|
| 63 |
and targets shifted left by 1.
|
| 64 |
|
| 65 |
Args:
|
| 66 |
-
move_ids: (B,
|
| 67 |
-
game_lengths: (B,) actual game lengths
|
| 68 |
-
|
| 69 |
seq_len: total CLM sequence length (256)
|
| 70 |
"""
|
| 71 |
B = len(game_lengths)
|
| 72 |
n_move_slots = seq_len - 1 # 255 slots for moves (position 0 = outcome)
|
| 73 |
-
|
| 74 |
|
| 75 |
game_lengths_t = torch.from_numpy(game_lengths).long()
|
| 76 |
-
move_ids_t = torch.from_numpy(move_ids).long() # (B,
|
| 77 |
-
outcome_tokens = _map_termination_to_outcome(term_codes, game_lengths)
|
| 78 |
|
| 79 |
# Build input_ids: [outcome, move_0, ..., move_{N-1}, PAD, ...]
|
| 80 |
input_ids = torch.zeros(B, seq_len, dtype=torch.long)
|
| 81 |
input_ids[:, 0] = outcome_tokens
|
| 82 |
|
| 83 |
# Mask out any non-move tokens from engine output
|
| 84 |
-
cache_key = ("engine",
|
| 85 |
engine_positions = _positions_cache.get(cache_key)
|
| 86 |
if engine_positions is None:
|
| 87 |
-
engine_positions = torch.arange(
|
| 88 |
_positions_cache[cache_key] = engine_positions
|
| 89 |
move_mask = engine_positions < game_lengths_t.unsqueeze(1)
|
| 90 |
clean_moves = move_ids_t * move_mask
|
| 91 |
|
| 92 |
# Place moves at positions 1..n_move_slots
|
| 93 |
-
n_to_copy = min(
|
| 94 |
input_ids[:, 1 : n_to_copy + 1] = clean_moves[:, :n_to_copy]
|
| 95 |
|
| 96 |
# Cap game_lengths to n_move_slots (handles edge case where engine
|
|
@@ -116,6 +114,21 @@ def _to_clm_batch(
|
|
| 116 |
}
|
| 117 |
|
| 118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
class CLMDataset(torch.utils.data.IterableDataset):
|
| 120 |
"""Generates CLM training data on-the-fly via the Rust engine.
|
| 121 |
|
|
@@ -156,16 +169,18 @@ class CLMDataset(torch.utils.data.IterableDataset):
|
|
| 156 |
t.start()
|
| 157 |
|
| 158 |
step = self._start_step
|
| 159 |
-
# Engine generates games of up to max_ply-1 moves, leaving 1 slot
|
| 160 |
-
# for the outcome token in the seq_len=max_ply CLM sequence.
|
| 161 |
-
engine_max_ply = self.max_ply - 1
|
| 162 |
while True:
|
| 163 |
seed = self.base_seed + step * num_workers + worker_id
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
step += 1
|
| 170 |
|
| 171 |
|
|
@@ -178,17 +193,21 @@ def create_validation_set(
|
|
| 178 |
Also computes legal move masks for legal move rate evaluation.
|
| 179 |
|
| 180 |
Args:
|
| 181 |
-
max_ply: total CLM sequence length (256).
|
| 182 |
"""
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
batch =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
|
| 189 |
# Compute legal move masks for evaluating legal move rate
|
| 190 |
legal_grid, _legal_promo = engine.compute_legal_move_masks(move_ids, game_lengths)
|
| 191 |
-
batch["legal_grid"] = torch.from_numpy(legal_grid).long()
|
| 192 |
batch["game_lengths"] = torch.from_numpy(game_lengths).long()
|
| 193 |
|
| 194 |
return batch
|
|
|
|
| 23 |
_positions_cache: dict[tuple[str, int], torch.Tensor] = {}
|
| 24 |
|
| 25 |
|
|
|
|
| 26 |
def _map_termination_to_outcome(
|
| 27 |
term_codes: np.ndarray, game_lengths: np.ndarray
|
| 28 |
) -> torch.Tensor:
|
|
|
|
| 50 |
return outcomes
|
| 51 |
|
| 52 |
|
| 53 |
+
def pack_clm_sequences(
|
| 54 |
move_ids: np.ndarray,
|
| 55 |
game_lengths: np.ndarray,
|
| 56 |
+
outcome_tokens: torch.Tensor,
|
| 57 |
seq_len: int,
|
| 58 |
) -> dict[str, torch.Tensor]:
|
| 59 |
+
"""Pack move arrays into CLM training tensors.
|
| 60 |
|
| 61 |
Constructs input_ids = [outcome, move_1, ..., move_N, PAD, ...]
|
| 62 |
and targets shifted left by 1.
|
| 63 |
|
| 64 |
Args:
|
| 65 |
+
move_ids: (B, max_ply) raw move token IDs
|
| 66 |
+
game_lengths: (B,) actual game lengths
|
| 67 |
+
outcome_tokens: (B,) pre-computed outcome token IDs (4273-4277)
|
| 68 |
seq_len: total CLM sequence length (256)
|
| 69 |
"""
|
| 70 |
B = len(game_lengths)
|
| 71 |
n_move_slots = seq_len - 1 # 255 slots for moves (position 0 = outcome)
|
| 72 |
+
max_ply = move_ids.shape[1]
|
| 73 |
|
| 74 |
game_lengths_t = torch.from_numpy(game_lengths).long()
|
| 75 |
+
move_ids_t = torch.from_numpy(move_ids).long() # (B, max_ply)
|
|
|
|
| 76 |
|
| 77 |
# Build input_ids: [outcome, move_0, ..., move_{N-1}, PAD, ...]
|
| 78 |
input_ids = torch.zeros(B, seq_len, dtype=torch.long)
|
| 79 |
input_ids[:, 0] = outcome_tokens
|
| 80 |
|
| 81 |
# Mask out any non-move tokens from engine output
|
| 82 |
+
cache_key = ("engine", max_ply)
|
| 83 |
engine_positions = _positions_cache.get(cache_key)
|
| 84 |
if engine_positions is None:
|
| 85 |
+
engine_positions = torch.arange(max_ply).unsqueeze(0)
|
| 86 |
_positions_cache[cache_key] = engine_positions
|
| 87 |
move_mask = engine_positions < game_lengths_t.unsqueeze(1)
|
| 88 |
clean_moves = move_ids_t * move_mask
|
| 89 |
|
| 90 |
# Place moves at positions 1..n_move_slots
|
| 91 |
+
n_to_copy = min(max_ply, n_move_slots)
|
| 92 |
input_ids[:, 1 : n_to_copy + 1] = clean_moves[:, :n_to_copy]
|
| 93 |
|
| 94 |
# Cap game_lengths to n_move_slots (handles edge case where engine
|
|
|
|
| 114 |
}
|
| 115 |
|
| 116 |
|
| 117 |
+
def _to_clm_batch(
|
| 118 |
+
move_ids: np.ndarray,
|
| 119 |
+
game_lengths: np.ndarray,
|
| 120 |
+
term_codes: np.ndarray,
|
| 121 |
+
seq_len: int,
|
| 122 |
+
) -> dict[str, torch.Tensor]:
|
| 123 |
+
"""Convert Rust engine output to CLM training tensors.
|
| 124 |
+
|
| 125 |
+
Convenience wrapper: computes outcome tokens from termination codes,
|
| 126 |
+
then delegates to pack_clm_sequences.
|
| 127 |
+
"""
|
| 128 |
+
outcome_tokens = _map_termination_to_outcome(term_codes, game_lengths)
|
| 129 |
+
return pack_clm_sequences(move_ids, game_lengths, outcome_tokens, seq_len)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
class CLMDataset(torch.utils.data.IterableDataset):
|
| 133 |
"""Generates CLM training data on-the-fly via the Rust engine.
|
| 134 |
|
|
|
|
| 169 |
t.start()
|
| 170 |
|
| 171 |
step = self._start_step
|
|
|
|
|
|
|
|
|
|
| 172 |
while True:
|
| 173 |
seed = self.base_seed + step * num_workers + worker_id
|
| 174 |
+
input_ids, targets, loss_mask, _move_ids, _gl, _tc = \
|
| 175 |
+
engine.generate_clm_batch(
|
| 176 |
+
self.batch_size, self.max_ply, seed,
|
| 177 |
+
discard_ply_limit=self.discard_ply_limit,
|
| 178 |
+
)
|
| 179 |
+
yield {
|
| 180 |
+
"input_ids": torch.from_numpy(input_ids).long(),
|
| 181 |
+
"targets": torch.from_numpy(targets).long(),
|
| 182 |
+
"loss_mask": torch.from_numpy(loss_mask),
|
| 183 |
+
}
|
| 184 |
step += 1
|
| 185 |
|
| 186 |
|
|
|
|
| 193 |
Also computes legal move masks for legal move rate evaluation.
|
| 194 |
|
| 195 |
Args:
|
| 196 |
+
max_ply: total CLM sequence length (256).
|
| 197 |
"""
|
| 198 |
+
input_ids, targets, loss_mask, move_ids, game_lengths, _tc = \
|
| 199 |
+
engine.generate_clm_batch(
|
| 200 |
+
n_games, max_ply, seed, discard_ply_limit=discard_ply_limit,
|
| 201 |
+
)
|
| 202 |
+
batch = {
|
| 203 |
+
"input_ids": torch.from_numpy(input_ids).long(),
|
| 204 |
+
"targets": torch.from_numpy(targets).long(),
|
| 205 |
+
"loss_mask": torch.from_numpy(loss_mask),
|
| 206 |
+
}
|
| 207 |
|
| 208 |
# Compute legal move masks for evaluating legal move rate
|
| 209 |
legal_grid, _legal_promo = engine.compute_legal_move_masks(move_ids, game_lengths)
|
| 210 |
+
batch["legal_grid"] = torch.from_numpy(legal_grid).long()
|
| 211 |
batch["game_lengths"] = torch.from_numpy(game_lengths).long()
|
| 212 |
|
| 213 |
return batch
|
pawn/eval_suite/diagnostics.py
CHANGED
|
@@ -8,7 +8,7 @@ import torch.nn.functional as F
|
|
| 8 |
import chess_engine as engine
|
| 9 |
|
| 10 |
from pawn.config import PAD_TOKEN
|
| 11 |
-
from pawn.data import
|
| 12 |
|
| 13 |
|
| 14 |
# ---------------------------------------------------------------------------
|
|
|
|
| 8 |
import chess_engine as engine
|
| 9 |
|
| 10 |
from pawn.config import PAD_TOKEN
|
| 11 |
+
from pawn.data import pack_clm_sequences, _map_termination_to_outcome
|
| 12 |
|
| 13 |
|
| 14 |
# ---------------------------------------------------------------------------
|
pawn/eval_suite/lichess.py
CHANGED
|
@@ -9,8 +9,8 @@ import torch.nn as nn
|
|
| 9 |
|
| 10 |
import chess_engine as engine
|
| 11 |
|
| 12 |
-
from pawn.config import PAD_TOKEN, WHITE_CHECKMATES
|
| 13 |
-
from pawn.data import
|
| 14 |
|
| 15 |
|
| 16 |
# ---------------------------------------------------------------------------
|
|
@@ -139,12 +139,11 @@ def evaluate_on_lichess(
|
|
| 139 |
game_lengths = band_data["game_lengths"]
|
| 140 |
n = len(move_ids)
|
| 141 |
|
| 142 |
-
# We don't have true termination codes for Lichess games parsed via PGN
|
| 143 |
-
#
|
| 144 |
-
#
|
| 145 |
-
# For loss/accuracy evaluation, the outcome token choice doesn't matter
|
| 146 |
# much since we evaluate on move prediction, not outcome prediction.
|
| 147 |
-
|
| 148 |
|
| 149 |
engine_max_ply = max_seq_len - 1
|
| 150 |
# Pad/truncate move_ids to engine_max_ply
|
|
@@ -154,7 +153,7 @@ def evaluate_on_lichess(
|
|
| 154 |
padded[i, :gl] = move_ids[i, :gl]
|
| 155 |
game_lengths_capped = np.minimum(game_lengths, engine_max_ply).astype(np.int16)
|
| 156 |
|
| 157 |
-
batch =
|
| 158 |
input_ids = batch["input_ids"]
|
| 159 |
targets = batch["targets"]
|
| 160 |
loss_mask = batch["loss_mask"]
|
|
|
|
| 9 |
|
| 10 |
import chess_engine as engine
|
| 11 |
|
| 12 |
+
from pawn.config import PAD_TOKEN, WHITE_CHECKMATES, PLY_LIMIT
|
| 13 |
+
from pawn.data import pack_clm_sequences, _map_termination_to_outcome
|
| 14 |
|
| 15 |
|
| 16 |
# ---------------------------------------------------------------------------
|
|
|
|
| 139 |
game_lengths = band_data["game_lengths"]
|
| 140 |
n = len(move_ids)
|
| 141 |
|
| 142 |
+
# We don't have true termination codes for Lichess games parsed via PGN.
|
| 143 |
+
# Use PLY_LIMIT as a dummy outcome token for all games — for
|
| 144 |
+
# loss/accuracy evaluation the outcome token choice doesn't matter
|
|
|
|
| 145 |
# much since we evaluate on move prediction, not outcome prediction.
|
| 146 |
+
dummy_outcomes = torch.full((n,), PLY_LIMIT, dtype=torch.long)
|
| 147 |
|
| 148 |
engine_max_ply = max_seq_len - 1
|
| 149 |
# Pad/truncate move_ids to engine_max_ply
|
|
|
|
| 153 |
padded[i, :gl] = move_ids[i, :gl]
|
| 154 |
game_lengths_capped = np.minimum(game_lengths, engine_max_ply).astype(np.int16)
|
| 155 |
|
| 156 |
+
batch = pack_clm_sequences(padded, game_lengths_capped, dummy_outcomes, max_seq_len)
|
| 157 |
input_ids = batch["input_ids"]
|
| 158 |
targets = batch["targets"]
|
| 159 |
loss_mask = batch["loss_mask"]
|
pawn/eval_suite/probes.py
CHANGED
|
@@ -15,7 +15,6 @@ import chess_engine as engine
|
|
| 15 |
|
| 16 |
from pawn.config import CLMConfig
|
| 17 |
from pawn.model import PAWNCLM
|
| 18 |
-
from pawn.data import _to_clm_batch
|
| 19 |
|
| 20 |
|
| 21 |
# ---------------------------------------------------------------------------
|
|
@@ -77,21 +76,16 @@ def extract_probe_data(
|
|
| 77 |
|
| 78 |
Returns dict with all arrays needed for all probes.
|
| 79 |
"""
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
move_ids_np, game_lengths_np, term_codes_np = engine.generate_random_games(
|
| 83 |
-
n_games, engine_max_ply, seed
|
| 84 |
-
)
|
| 85 |
|
| 86 |
boards_np, side_np, castling_np, ep_np, check_np, halfmove_np = (
|
| 87 |
engine.extract_board_states(move_ids_np, game_lengths_np)
|
| 88 |
)
|
| 89 |
|
| 90 |
-
batch = _to_clm_batch(move_ids_np, game_lengths_np, term_codes_np, max_ply)
|
| 91 |
-
|
| 92 |
result = {
|
| 93 |
-
"input_ids":
|
| 94 |
-
"loss_mask":
|
| 95 |
"boards": torch.from_numpy(boards_np.copy()).long(),
|
| 96 |
"side_to_move": torch.from_numpy(side_np.copy()).float(),
|
| 97 |
"castling_rights": torch.from_numpy(castling_np.copy()),
|
|
|
|
| 15 |
|
| 16 |
from pawn.config import CLMConfig
|
| 17 |
from pawn.model import PAWNCLM
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
# ---------------------------------------------------------------------------
|
|
|
|
| 76 |
|
| 77 |
Returns dict with all arrays needed for all probes.
|
| 78 |
"""
|
| 79 |
+
input_ids, targets, loss_mask, move_ids_np, game_lengths_np, _tc = \
|
| 80 |
+
engine.generate_clm_batch(n_games, max_ply, seed)
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
boards_np, side_np, castling_np, ep_np, check_np, halfmove_np = (
|
| 83 |
engine.extract_board_states(move_ids_np, game_lengths_np)
|
| 84 |
)
|
| 85 |
|
|
|
|
|
|
|
| 86 |
result = {
|
| 87 |
+
"input_ids": torch.from_numpy(input_ids).long(),
|
| 88 |
+
"loss_mask": torch.from_numpy(loss_mask),
|
| 89 |
"boards": torch.from_numpy(boards_np.copy()).long(),
|
| 90 |
"side_to_move": torch.from_numpy(side_np.copy()).float(),
|
| 91 |
"castling_rights": torch.from_numpy(castling_np.copy()),
|
pawn/eval_suite/worker.py
CHANGED
|
@@ -62,15 +62,15 @@ def run_in_worker(fn: Callable[..., Any], *args: Any, timeout: float | None = No
|
|
| 62 |
|
| 63 |
def _load_model(checkpoint_path: str, device: str) -> PAWNCLM:
|
| 64 |
"""Load and freeze a PAWNCLM checkpoint. Runs inside worker processes."""
|
| 65 |
-
import
|
| 66 |
from pawn.config import CLMConfig
|
| 67 |
from pawn.model import PAWNCLM
|
| 68 |
|
| 69 |
-
|
| 70 |
-
cfg = CLMConfig(**
|
| 71 |
model = PAWNCLM(cfg).to(device)
|
| 72 |
-
model.load_state_dict(
|
| 73 |
-
del
|
| 74 |
gc.collect()
|
| 75 |
model.eval()
|
| 76 |
for p in model.parameters():
|
|
|
|
| 62 |
|
| 63 |
def _load_model(checkpoint_path: str, device: str) -> PAWNCLM:
|
| 64 |
"""Load and freeze a PAWNCLM checkpoint. Runs inside worker processes."""
|
| 65 |
+
from pawn.checkpoint import load_backbone_weights
|
| 66 |
from pawn.config import CLMConfig
|
| 67 |
from pawn.model import PAWNCLM
|
| 68 |
|
| 69 |
+
state_dict, model_config = load_backbone_weights(checkpoint_path, device)
|
| 70 |
+
cfg = CLMConfig(**model_config) if model_config else CLMConfig()
|
| 71 |
model = PAWNCLM(cfg).to(device)
|
| 72 |
+
model.load_state_dict(state_dict)
|
| 73 |
+
del state_dict
|
| 74 |
gc.collect()
|
| 75 |
model.eval()
|
| 76 |
for p in model.parameters():
|
pawn/lichess_data.py
CHANGED
|
@@ -219,32 +219,15 @@ def prepare_lichess_dataset(
|
|
| 219 |
|
| 220 |
seq_len = max_ply + 1 # outcome token + max_ply move slots
|
| 221 |
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
input_ids[:, 0] = outcome_tokens
|
| 225 |
-
|
| 226 |
-
gl_t = torch.from_numpy(game_lengths).long()
|
| 227 |
-
mid_t = torch.from_numpy(move_ids).long()
|
| 228 |
-
|
| 229 |
-
ply_range = torch.arange(max_ply).unsqueeze(0)
|
| 230 |
-
move_mask = ply_range < gl_t.unsqueeze(1)
|
| 231 |
-
input_ids[:, 1:] = mid_t * move_mask
|
| 232 |
-
|
| 233 |
-
# Targets: shifted left by 1
|
| 234 |
-
targets = torch.zeros(N, seq_len, dtype=torch.long)
|
| 235 |
-
targets[:, :-1] = input_ids[:, 1:]
|
| 236 |
-
|
| 237 |
-
# Loss mask: positions 0..game_length-1 (each has a valid move target)
|
| 238 |
-
# Position gl would target PAD, which we don't want to train on.
|
| 239 |
-
seq_positions = torch.arange(seq_len).unsqueeze(0)
|
| 240 |
-
loss_mask = seq_positions < gl_t.unsqueeze(1)
|
| 241 |
|
| 242 |
return {
|
| 243 |
"move_ids": move_ids,
|
| 244 |
"game_lengths": game_lengths,
|
| 245 |
-
"input_ids": input_ids,
|
| 246 |
-
"targets": targets,
|
| 247 |
-
"loss_mask": loss_mask,
|
| 248 |
"outcome_tokens": outcome_tokens,
|
| 249 |
"n_games": N,
|
| 250 |
}
|
|
|
|
| 219 |
|
| 220 |
seq_len = max_ply + 1 # outcome token + max_ply move slots
|
| 221 |
|
| 222 |
+
from pawn.data import pack_clm_sequences
|
| 223 |
+
batch = pack_clm_sequences(move_ids, game_lengths, outcome_tokens, seq_len)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
|
| 225 |
return {
|
| 226 |
"move_ids": move_ids,
|
| 227 |
"game_lengths": game_lengths,
|
| 228 |
+
"input_ids": batch["input_ids"],
|
| 229 |
+
"targets": batch["targets"],
|
| 230 |
+
"loss_mask": batch["loss_mask"],
|
| 231 |
"outcome_tokens": outcome_tokens,
|
| 232 |
"n_games": N,
|
| 233 |
}
|
pawn/model.py
CHANGED
|
@@ -37,7 +37,7 @@ def _build_decomposition_table() -> torch.Tensor:
|
|
| 37 |
|
| 38 |
for token_idx, uci_str in vocab["token_to_move"].items():
|
| 39 |
if token_idx >= OUTCOME_TOKEN_BASE:
|
| 40 |
-
continue #
|
| 41 |
src_name = uci_str[:2]
|
| 42 |
dst_name = uci_str[2:4]
|
| 43 |
promo_suffix = uci_str[4:] if len(uci_str) > 4 else ""
|
|
|
|
| 37 |
|
| 38 |
for token_idx, uci_str in vocab["token_to_move"].items():
|
| 39 |
if token_idx >= OUTCOME_TOKEN_BASE:
|
| 40 |
+
continue # Outcome tokens use standalone embeddings
|
| 41 |
src_name = uci_str[:2]
|
| 42 |
dst_name = uci_str[2:4]
|
| 43 |
promo_suffix = uci_str[4:] if len(uci_str) > 4 else ""
|
pawn/trainer.py
CHANGED
|
@@ -207,17 +207,27 @@ def _make_run_dir(base_log_dir: str) -> str:
|
|
| 207 |
|
| 208 |
|
| 209 |
class CLMTrainer:
|
| 210 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
self.cfg = train_cfg
|
| 212 |
self.model_cfg = model_cfg
|
| 213 |
self.device = train_cfg.device
|
| 214 |
self.global_step = 0
|
|
|
|
|
|
|
| 215 |
|
| 216 |
self.run_dir = _make_run_dir(train_cfg.log_dir)
|
| 217 |
self.cfg.checkpoint_dir = os.path.join(self.run_dir, "checkpoints")
|
| 218 |
self._jsonl_path = os.path.join(self.run_dir, "metrics.jsonl")
|
| 219 |
self._jsonl_file = None
|
| 220 |
|
|
|
|
|
|
|
|
|
|
| 221 |
self._model = PAWNCLM(model_cfg).to(self.device)
|
| 222 |
self.model = self._model
|
| 223 |
param_count = sum(p.numel() for p in self._model.parameters())
|
|
@@ -449,12 +459,13 @@ class CLMTrainer:
|
|
| 449 |
prefetch_factor=1 if num_workers > 0 else None,
|
| 450 |
)
|
| 451 |
|
|
|
|
|
|
|
|
|
|
| 452 |
def _graceful_exit(signum, frame):
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
self._jsonl_file.close()
|
| 457 |
-
sys.exit(128 + signum)
|
| 458 |
|
| 459 |
old_term = signal.signal(signal.SIGTERM, _graceful_exit)
|
| 460 |
old_int = signal.signal(signal.SIGINT, _graceful_exit)
|
|
@@ -547,6 +558,12 @@ class CLMTrainer:
|
|
| 547 |
self.save_checkpoint()
|
| 548 |
break
|
| 549 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 550 |
step_start = time.time()
|
| 551 |
|
| 552 |
signal.signal(signal.SIGTERM, old_term)
|
|
@@ -557,49 +574,47 @@ class CLMTrainer:
|
|
| 557 |
self._jsonl_file = None
|
| 558 |
|
| 559 |
def save_checkpoint(self, path: str | None = None):
|
|
|
|
|
|
|
| 560 |
if path is None:
|
| 561 |
path = os.path.join(
|
| 562 |
-
self.cfg.checkpoint_dir, f"step_{self.global_step:08d}
|
| 563 |
)
|
| 564 |
|
| 565 |
-
dirname = os.path.dirname(path)
|
| 566 |
-
if dirname:
|
| 567 |
-
os.makedirs(dirname, exist_ok=True)
|
| 568 |
-
|
| 569 |
model: PAWNCLM = self._eager_model()
|
| 570 |
|
| 571 |
-
|
| 572 |
-
{
|
| 573 |
-
"global_step": self.global_step,
|
| 574 |
-
"model_state_dict": model.state_dict(),
|
| 575 |
-
"optimizer_state_dict": self.optimizer.state_dict(),
|
| 576 |
-
"scheduler_state_dict": self.scheduler.state_dict(),
|
| 577 |
-
"scaler_state_dict": self.scaler.state_dict(),
|
| 578 |
-
"model_config": self.model_cfg.__dict__,
|
| 579 |
-
"training_config": self.cfg.__dict__,
|
| 580 |
-
"torch_rng_state": torch.get_rng_state(),
|
| 581 |
-
"cuda_rng_state": (
|
| 582 |
-
torch.cuda.get_rng_state() if torch.cuda.is_available() else None
|
| 583 |
-
),
|
| 584 |
-
},
|
| 585 |
path,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 586 |
)
|
| 587 |
print(f"Checkpoint saved: {path}")
|
| 588 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 589 |
def load_checkpoint(self, path: str):
|
| 590 |
-
|
| 591 |
-
self.global_step = ckpt["global_step"]
|
| 592 |
|
| 593 |
model: PAWNCLM = self._eager_model()
|
| 594 |
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
if ckpt.get("torch_rng_state") is not None:
|
| 601 |
-
torch.set_rng_state(ckpt["torch_rng_state"].cpu().byte())
|
| 602 |
-
if ckpt.get("cuda_rng_state") is not None and torch.cuda.is_available():
|
| 603 |
-
torch.cuda.set_rng_state(ckpt["cuda_rng_state"].cpu().byte())
|
| 604 |
-
|
| 605 |
print(f"Resumed from step {self.global_step}")
|
|
|
|
| 207 |
|
| 208 |
|
| 209 |
class CLMTrainer:
|
| 210 |
+
def __init__(
|
| 211 |
+
self,
|
| 212 |
+
train_cfg: TrainingConfig,
|
| 213 |
+
model_cfg: CLMConfig,
|
| 214 |
+
hf_repo: str | None = None,
|
| 215 |
+
):
|
| 216 |
self.cfg = train_cfg
|
| 217 |
self.model_cfg = model_cfg
|
| 218 |
self.device = train_cfg.device
|
| 219 |
self.global_step = 0
|
| 220 |
+
self.hf_repo = hf_repo
|
| 221 |
+
self.hf_branch: str | None = None
|
| 222 |
|
| 223 |
self.run_dir = _make_run_dir(train_cfg.log_dir)
|
| 224 |
self.cfg.checkpoint_dir = os.path.join(self.run_dir, "checkpoints")
|
| 225 |
self._jsonl_path = os.path.join(self.run_dir, "metrics.jsonl")
|
| 226 |
self._jsonl_file = None
|
| 227 |
|
| 228 |
+
if self.hf_repo:
|
| 229 |
+
self.hf_branch = f"run/{os.path.basename(self.run_dir)}"
|
| 230 |
+
|
| 231 |
self._model = PAWNCLM(model_cfg).to(self.device)
|
| 232 |
self.model = self._model
|
| 233 |
param_count = sum(p.numel() for p in self._model.parameters())
|
|
|
|
| 459 |
prefetch_factor=1 if num_workers > 0 else None,
|
| 460 |
)
|
| 461 |
|
| 462 |
+
_shutdown_requested = False
|
| 463 |
+
_shutdown_signal = None
|
| 464 |
+
|
| 465 |
def _graceful_exit(signum, frame):
|
| 466 |
+
nonlocal _shutdown_requested, _shutdown_signal
|
| 467 |
+
_shutdown_requested = True
|
| 468 |
+
_shutdown_signal = signum
|
|
|
|
|
|
|
| 469 |
|
| 470 |
old_term = signal.signal(signal.SIGTERM, _graceful_exit)
|
| 471 |
old_int = signal.signal(signal.SIGINT, _graceful_exit)
|
|
|
|
| 558 |
self.save_checkpoint()
|
| 559 |
break
|
| 560 |
|
| 561 |
+
if _shutdown_requested:
|
| 562 |
+
print(f"\nShutdown requested (signal {_shutdown_signal}), "
|
| 563 |
+
f"saving checkpoint at step {self.global_step}...")
|
| 564 |
+
self.save_checkpoint()
|
| 565 |
+
break
|
| 566 |
+
|
| 567 |
step_start = time.time()
|
| 568 |
|
| 569 |
signal.signal(signal.SIGTERM, old_term)
|
|
|
|
| 574 |
self._jsonl_file = None
|
| 575 |
|
| 576 |
def save_checkpoint(self, path: str | None = None):
|
| 577 |
+
from pawn.checkpoint import save_pretrain_checkpoint
|
| 578 |
+
|
| 579 |
if path is None:
|
| 580 |
path = os.path.join(
|
| 581 |
+
self.cfg.checkpoint_dir, f"step_{self.global_step:08d}"
|
| 582 |
)
|
| 583 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 584 |
model: PAWNCLM = self._eager_model()
|
| 585 |
|
| 586 |
+
save_pretrain_checkpoint(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 587 |
path,
|
| 588 |
+
model,
|
| 589 |
+
self.optimizer,
|
| 590 |
+
self.scheduler,
|
| 591 |
+
self.scaler,
|
| 592 |
+
self.global_step,
|
| 593 |
+
self.model_cfg.__dict__,
|
| 594 |
+
self.cfg.__dict__,
|
| 595 |
)
|
| 596 |
print(f"Checkpoint saved: {path}")
|
| 597 |
|
| 598 |
+
if self.hf_repo and self.hf_branch:
|
| 599 |
+
from pawn.checkpoint import push_checkpoint_to_hf
|
| 600 |
+
try:
|
| 601 |
+
push_checkpoint_to_hf(
|
| 602 |
+
path, self.hf_repo, self.hf_branch,
|
| 603 |
+
metrics_path=self._jsonl_path,
|
| 604 |
+
step=self.global_step,
|
| 605 |
+
)
|
| 606 |
+
print(f"Pushed to HF: {self.hf_repo}@{self.hf_branch}")
|
| 607 |
+
except Exception as e:
|
| 608 |
+
print(f"WARNING: HF push failed: {e}")
|
| 609 |
+
|
| 610 |
def load_checkpoint(self, path: str):
|
| 611 |
+
from pawn.checkpoint import load_pretrain_checkpoint
|
|
|
|
| 612 |
|
| 613 |
model: PAWNCLM = self._eager_model()
|
| 614 |
|
| 615 |
+
meta = load_pretrain_checkpoint(
|
| 616 |
+
path, model, self.optimizer, self.scheduler, self.scaler,
|
| 617 |
+
device=self.device,
|
| 618 |
+
)
|
| 619 |
+
self.global_step = meta["global_step"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 620 |
print(f"Resumed from step {self.global_step}")
|
pyproject.toml
CHANGED
|
@@ -8,6 +8,7 @@ dependencies = [
|
|
| 8 |
"chess-engine",
|
| 9 |
"numpy~=2.2.0",
|
| 10 |
"psutil>=5.9.0",
|
|
|
|
| 11 |
"tqdm~=4.67.0",
|
| 12 |
"wandb~=0.25.0",
|
| 13 |
]
|
|
|
|
| 8 |
"chess-engine",
|
| 9 |
"numpy~=2.2.0",
|
| 10 |
"psutil>=5.9.0",
|
| 11 |
+
"safetensors>=0.4.0",
|
| 12 |
"tqdm~=4.67.0",
|
| 13 |
"wandb~=0.25.0",
|
| 14 |
]
|
scripts/check_progress.sh
CHANGED
|
@@ -1,33 +1,34 @@
|
|
| 1 |
#!/usr/bin/env bash
|
| 2 |
-
# Check training progress
|
| 3 |
-
# Usage: check_progress.sh [--sync] [
|
| 4 |
set -euo pipefail
|
| 5 |
|
| 6 |
SYNC=false
|
| 7 |
-
AUTO_STOP=false
|
| 8 |
LOG_DIR=""
|
| 9 |
|
| 10 |
for arg in "$@"; do
|
| 11 |
case "$arg" in
|
| 12 |
--sync) SYNC=true ;;
|
| 13 |
-
--auto-stop) AUTO_STOP=true ;;
|
| 14 |
*) LOG_DIR="$arg" ;;
|
| 15 |
esac
|
| 16 |
done
|
| 17 |
LOG_DIR="${LOG_DIR:-logs}"
|
| 18 |
|
| 19 |
REPO="$(cd "$(dirname "$0")/.." && pwd)"
|
| 20 |
-
POD_DIR="$HOME/.config/pawn/pods"
|
| 21 |
|
| 22 |
-
# Sync from
|
| 23 |
-
if $SYNC
|
| 24 |
bash "$REPO/deploy/sync.sh" 2>/dev/null || true
|
| 25 |
fi
|
| 26 |
|
| 27 |
-
# Show progress
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
| 30 |
run_name="$(basename "$(dirname "$path")")"
|
|
|
|
| 31 |
python3 -c "
|
| 32 |
import json, sys
|
| 33 |
records = [json.loads(l) for l in open('$path')]
|
|
@@ -51,29 +52,8 @@ print(f'{\"$run_name\":<28} {variant} discard_ply={str(discard):<5} step {ste
|
|
| 51 |
done
|
| 52 |
|
| 53 |
# Check local training process
|
| 54 |
-
if pgrep -f 'train.py
|
| 55 |
-
echo "Local
|
| 56 |
else
|
| 57 |
-
echo "
|
| 58 |
-
fi
|
| 59 |
-
|
| 60 |
-
# Auto-stop finished pods
|
| 61 |
-
if $AUTO_STOP && [ -d "$POD_DIR" ]; then
|
| 62 |
-
for env_file in "$POD_DIR"/*.env; do
|
| 63 |
-
[ -f "$env_file" ] || continue
|
| 64 |
-
pod_name="$(basename "${env_file%.env}")"
|
| 65 |
-
unset POD_ID POD_HOST POD_PORT POD_GPU 2>/dev/null || true
|
| 66 |
-
source "$env_file"
|
| 67 |
-
|
| 68 |
-
# Check if process is alive on pod
|
| 69 |
-
alive=$(ssh -o ConnectTimeout=5 -p "$POD_PORT" "root@$POD_HOST" \
|
| 70 |
-
"pgrep -f 'train.py' > /dev/null 2>&1 && echo yes || echo no" 2>/dev/null || echo "unreachable")
|
| 71 |
-
|
| 72 |
-
if [ "$alive" = "no" ]; then
|
| 73 |
-
echo ">>> $pod_name: training finished. Final sync + stopping..."
|
| 74 |
-
bash "$REPO/deploy/sync.sh" "$pod_name" 2>/dev/null || true
|
| 75 |
-
runpodctl pod stop "$POD_ID" 2>/dev/null
|
| 76 |
-
echo ">>> $pod_name ($POD_ID) STOPPED"
|
| 77 |
-
fi
|
| 78 |
-
done
|
| 79 |
fi
|
|
|
|
| 1 |
#!/usr/bin/env bash
|
| 2 |
+
# Check training progress from HuggingFace submodules and local logs.
|
| 3 |
+
# Usage: check_progress.sh [--sync] [LOG_DIR]
|
| 4 |
set -euo pipefail
|
| 5 |
|
| 6 |
SYNC=false
|
|
|
|
| 7 |
LOG_DIR=""
|
| 8 |
|
| 9 |
for arg in "$@"; do
|
| 10 |
case "$arg" in
|
| 11 |
--sync) SYNC=true ;;
|
|
|
|
| 12 |
*) LOG_DIR="$arg" ;;
|
| 13 |
esac
|
| 14 |
done
|
| 15 |
LOG_DIR="${LOG_DIR:-logs}"
|
| 16 |
|
| 17 |
REPO="$(cd "$(dirname "$0")/.." && pwd)"
|
|
|
|
| 18 |
|
| 19 |
+
# Sync submodules from HuggingFace
|
| 20 |
+
if $SYNC; then
|
| 21 |
bash "$REPO/deploy/sync.sh" 2>/dev/null || true
|
| 22 |
fi
|
| 23 |
|
| 24 |
+
# Show progress from all metrics.jsonl files (local logs + submodules)
|
| 25 |
+
N=5
|
| 26 |
+
{
|
| 27 |
+
find "$LOG_DIR" -name metrics.jsonl -printf '%T@ %p\n' 2>/dev/null
|
| 28 |
+
find "$REPO/checkpoints" -name metrics.jsonl -printf '%T@ %p\n' 2>/dev/null
|
| 29 |
+
} | sort -rn | head -n "$N" | while read -r _ path; do
|
| 30 |
run_name="$(basename "$(dirname "$path")")"
|
| 31 |
+
|
| 32 |
python3 -c "
|
| 33 |
import json, sys
|
| 34 |
records = [json.loads(l) for l in open('$path')]
|
|
|
|
| 52 |
done
|
| 53 |
|
| 54 |
# Check local training process
|
| 55 |
+
if pgrep -f 'train.py' > /dev/null 2>&1; then
|
| 56 |
+
echo "Local training: RUNNING"
|
| 57 |
else
|
| 58 |
+
echo "Local training: NOT RUNNING"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
fi
|
scripts/eval_accuracy.py
CHANGED
|
@@ -65,37 +65,47 @@ def parse_args():
|
|
| 65 |
return p.parse_args()
|
| 66 |
|
| 67 |
|
| 68 |
-
def _detect_adapter_type(
|
| 69 |
-
"""Auto-detect adapter type from
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
return "hybrid"
|
| 72 |
-
if "
|
| 73 |
return "lora"
|
| 74 |
-
if "
|
| 75 |
-
return "film"
|
| 76 |
-
if "sparse_state_dict" in ckpt:
|
| 77 |
return "sparse"
|
| 78 |
-
if "
|
| 79 |
-
return "
|
| 80 |
-
|
| 81 |
-
|
|
|
|
| 82 |
|
| 83 |
|
| 84 |
def load_model(checkpoint_path: str, adapter_path: str, device: str):
|
| 85 |
"""Load backbone + adapter, auto-detecting adapter type."""
|
|
|
|
|
|
|
| 86 |
# Backbone
|
| 87 |
-
|
| 88 |
-
cfg = CLMConfig(**
|
| 89 |
backbone = PAWNCLM(cfg).to(device)
|
| 90 |
-
backbone.load_state_dict(
|
| 91 |
-
del
|
| 92 |
gc.collect()
|
| 93 |
backbone.eval()
|
| 94 |
|
| 95 |
# Adapter
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
adapter_config =
|
|
|
|
| 99 |
|
| 100 |
if adapter_type == "lora":
|
| 101 |
from pawn.adapters.lora import LoRACLM
|
|
@@ -108,15 +118,15 @@ def load_model(checkpoint_path: str, adapter_path: str, device: str):
|
|
| 108 |
layers=tuple(int(x) for x in adapter_config["lora_layers"].split(","))
|
| 109 |
if adapter_config.get("lora_layers") else None,
|
| 110 |
).to(device)
|
| 111 |
-
model.load_lora_state_dict(
|
| 112 |
|
| 113 |
elif adapter_type == "film":
|
| 114 |
from pawn.adapters.film import FiLMCLM
|
| 115 |
has_output = not adapter_config.get("no_output_film", False)
|
| 116 |
-
if any(k.startswith("output_film.") for k in
|
| 117 |
has_output = True
|
| 118 |
model = FiLMCLM(backbone, use_output_film=has_output).to(device)
|
| 119 |
-
model.load_film_state_dict(
|
| 120 |
|
| 121 |
elif adapter_type == "hybrid":
|
| 122 |
from pawn.adapters.hybrid import HybridCLM
|
|
@@ -133,7 +143,7 @@ def load_model(checkpoint_path: str, adapter_path: str, device: str):
|
|
| 133 |
use_output_film=adapter_config.get("output_film", False),
|
| 134 |
film_layers=tuple(int(x) for x in film_layers.split(",")) if film_layers else None,
|
| 135 |
).to(device)
|
| 136 |
-
model.load_adapter_state_dict(
|
| 137 |
|
| 138 |
elif adapter_type == "sparse":
|
| 139 |
from pawn.adapters.sparse import SparseCLM
|
|
@@ -148,7 +158,7 @@ def load_model(checkpoint_path: str, adapter_path: str, device: str):
|
|
| 148 |
layers=tuple(int(x) for x in sparse_layers.split(",")) if sparse_layers else None,
|
| 149 |
seed=adapter_config.get("sparse_seed", 42),
|
| 150 |
).to(device)
|
| 151 |
-
model.load_sparse_state_dict(
|
| 152 |
|
| 153 |
elif adapter_type == "bottleneck":
|
| 154 |
from pawn.adapters.bottleneck import BottleneckCLM
|
|
@@ -160,7 +170,7 @@ def load_model(checkpoint_path: str, adapter_path: str, device: str):
|
|
| 160 |
adapt_ffn=adapter_config.get("adapt_ffn", True),
|
| 161 |
layers=tuple(int(x) for x in adapter_layers_str.split(",")) if adapter_layers_str else None,
|
| 162 |
).to(device)
|
| 163 |
-
model.load_adapter_state_dict(
|
| 164 |
|
| 165 |
model.eval()
|
| 166 |
return model, adapter_type
|
|
|
|
| 65 |
return p.parse_args()
|
| 66 |
|
| 67 |
|
| 68 |
+
def _detect_adapter_type(config: dict) -> str:
|
| 69 |
+
"""Auto-detect adapter type from config dict.
|
| 70 |
+
|
| 71 |
+
The config dict comes from load_adapter_checkpoint()["config"], which
|
| 72 |
+
contains training args for both legacy .pt files and new-format directories.
|
| 73 |
+
"""
|
| 74 |
+
if "checkpoint_type" in config:
|
| 75 |
+
return config["checkpoint_type"]
|
| 76 |
+
if "bottleneck_dim" in config:
|
| 77 |
+
return "bottleneck"
|
| 78 |
+
if "lora_rank" in config and config.get("use_film") is not None:
|
| 79 |
return "hybrid"
|
| 80 |
+
if "lora_rank" in config:
|
| 81 |
return "lora"
|
| 82 |
+
if "density" in config:
|
|
|
|
|
|
|
| 83 |
return "sparse"
|
| 84 |
+
if "no_output_film" in config:
|
| 85 |
+
return "film"
|
| 86 |
+
|
| 87 |
+
raise ValueError("Cannot detect adapter type from config keys: "
|
| 88 |
+
+ ", ".join(config.keys()))
|
| 89 |
|
| 90 |
|
| 91 |
def load_model(checkpoint_path: str, adapter_path: str, device: str):
|
| 92 |
"""Load backbone + adapter, auto-detecting adapter type."""
|
| 93 |
+
from pawn.checkpoint import load_backbone_weights, load_adapter_checkpoint
|
| 94 |
+
|
| 95 |
# Backbone
|
| 96 |
+
state_dict, model_config = load_backbone_weights(checkpoint_path, device)
|
| 97 |
+
cfg = CLMConfig(**model_config) if model_config else CLMConfig()
|
| 98 |
backbone = PAWNCLM(cfg).to(device)
|
| 99 |
+
backbone.load_state_dict(state_dict)
|
| 100 |
+
del state_dict
|
| 101 |
gc.collect()
|
| 102 |
backbone.eval()
|
| 103 |
|
| 104 |
# Adapter
|
| 105 |
+
adapter_data = load_adapter_checkpoint(adapter_path, device)
|
| 106 |
+
adapter_weights = adapter_data["adapter_state_dict"]
|
| 107 |
+
adapter_config = adapter_data.get("config", {})
|
| 108 |
+
adapter_type = _detect_adapter_type(adapter_config)
|
| 109 |
|
| 110 |
if adapter_type == "lora":
|
| 111 |
from pawn.adapters.lora import LoRACLM
|
|
|
|
| 118 |
layers=tuple(int(x) for x in adapter_config["lora_layers"].split(","))
|
| 119 |
if adapter_config.get("lora_layers") else None,
|
| 120 |
).to(device)
|
| 121 |
+
model.load_lora_state_dict(adapter_weights)
|
| 122 |
|
| 123 |
elif adapter_type == "film":
|
| 124 |
from pawn.adapters.film import FiLMCLM
|
| 125 |
has_output = not adapter_config.get("no_output_film", False)
|
| 126 |
+
if any(k.startswith("output_film.") for k in adapter_weights):
|
| 127 |
has_output = True
|
| 128 |
model = FiLMCLM(backbone, use_output_film=has_output).to(device)
|
| 129 |
+
model.load_film_state_dict(adapter_weights)
|
| 130 |
|
| 131 |
elif adapter_type == "hybrid":
|
| 132 |
from pawn.adapters.hybrid import HybridCLM
|
|
|
|
| 143 |
use_output_film=adapter_config.get("output_film", False),
|
| 144 |
film_layers=tuple(int(x) for x in film_layers.split(",")) if film_layers else None,
|
| 145 |
).to(device)
|
| 146 |
+
model.load_adapter_state_dict(adapter_weights)
|
| 147 |
|
| 148 |
elif adapter_type == "sparse":
|
| 149 |
from pawn.adapters.sparse import SparseCLM
|
|
|
|
| 158 |
layers=tuple(int(x) for x in sparse_layers.split(",")) if sparse_layers else None,
|
| 159 |
seed=adapter_config.get("sparse_seed", 42),
|
| 160 |
).to(device)
|
| 161 |
+
model.load_sparse_state_dict(adapter_weights)
|
| 162 |
|
| 163 |
elif adapter_type == "bottleneck":
|
| 164 |
from pawn.adapters.bottleneck import BottleneckCLM
|
|
|
|
| 170 |
adapt_ffn=adapter_config.get("adapt_ffn", True),
|
| 171 |
layers=tuple(int(x) for x in adapter_layers_str.split(",")) if adapter_layers_str else None,
|
| 172 |
).to(device)
|
| 173 |
+
model.load_adapter_state_dict(adapter_weights)
|
| 174 |
|
| 175 |
model.eval()
|
| 176 |
return model, adapter_type
|
scripts/eval_probes.py
CHANGED
|
@@ -11,20 +11,17 @@ import torch
|
|
| 11 |
from pawn.config import CLMConfig
|
| 12 |
from pawn.model import PAWNCLM
|
| 13 |
from pawn.eval_suite.probes import extract_probe_data, train_all_probes
|
| 14 |
-
from pawn.gpu import configure_gpu
|
| 15 |
|
| 16 |
|
| 17 |
def load_model_from_checkpoint(checkpoint_path: str, device: str) -> PAWNCLM:
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
if
|
| 21 |
-
cfg = CLMConfig(**
|
| 22 |
else:
|
| 23 |
-
state
|
| 24 |
-
d_model =
|
| 25 |
-
n_layers = max(int(k.split(".")[1]) for k in
|
| 26 |
-
n_heads = state["layers.0.attn.wq.weight"].shape[0] // (d_model // (d_model // (state["layers.0.attn.wq.weight"].shape[0] // d_model * d_model // state["layers.0.attn.wq.weight"].shape[0]) if True else 1))
|
| 27 |
-
# Infer config from known variants
|
| 28 |
if d_model == 256 and n_layers == 8:
|
| 29 |
cfg = CLMConfig.small()
|
| 30 |
elif d_model == 512 and n_layers == 8:
|
|
@@ -33,9 +30,8 @@ def load_model_from_checkpoint(checkpoint_path: str, device: str) -> PAWNCLM:
|
|
| 33 |
cfg = CLMConfig.large()
|
| 34 |
else:
|
| 35 |
cfg = CLMConfig(d_model=d_model, n_layers=n_layers)
|
| 36 |
-
|
| 37 |
model = PAWNCLM(cfg).to(device)
|
| 38 |
-
model.load_state_dict(
|
| 39 |
model.eval()
|
| 40 |
return model
|
| 41 |
|
|
@@ -52,6 +48,7 @@ def main():
|
|
| 52 |
|
| 53 |
device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 54 |
if device == "cuda":
|
|
|
|
| 55 |
gpu_cfg = configure_gpu()
|
| 56 |
import pawn.model as model_module
|
| 57 |
model_module.SDPA_BACKEND = gpu_cfg.get("sdpa_backend")
|
|
|
|
| 11 |
from pawn.config import CLMConfig
|
| 12 |
from pawn.model import PAWNCLM
|
| 13 |
from pawn.eval_suite.probes import extract_probe_data, train_all_probes
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
def load_model_from_checkpoint(checkpoint_path: str, device: str) -> PAWNCLM:
|
| 17 |
+
from pawn.checkpoint import load_backbone_weights
|
| 18 |
+
state_dict, model_config = load_backbone_weights(checkpoint_path, device)
|
| 19 |
+
if model_config:
|
| 20 |
+
cfg = CLMConfig(**model_config)
|
| 21 |
else:
|
| 22 |
+
# Fallback: infer from state dict shapes
|
| 23 |
+
d_model = state_dict["embed.src_embed.weight"].shape[1]
|
| 24 |
+
n_layers = max(int(k.split(".")[1]) for k in state_dict if k.startswith("layers.")) + 1
|
|
|
|
|
|
|
| 25 |
if d_model == 256 and n_layers == 8:
|
| 26 |
cfg = CLMConfig.small()
|
| 27 |
elif d_model == 512 and n_layers == 8:
|
|
|
|
| 30 |
cfg = CLMConfig.large()
|
| 31 |
else:
|
| 32 |
cfg = CLMConfig(d_model=d_model, n_layers=n_layers)
|
|
|
|
| 33 |
model = PAWNCLM(cfg).to(device)
|
| 34 |
+
model.load_state_dict(state_dict)
|
| 35 |
model.eval()
|
| 36 |
return model
|
| 37 |
|
|
|
|
| 48 |
|
| 49 |
device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 50 |
if device == "cuda":
|
| 51 |
+
from pawn.gpu import configure_gpu
|
| 52 |
gpu_cfg = configure_gpu()
|
| 53 |
import pawn.model as model_module
|
| 54 |
model_module.SDPA_BACKEND = gpu_cfg.get("sdpa_backend")
|
scripts/export_hf_repo.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Export a training run to HuggingFace repo format.
|
| 3 |
+
|
| 4 |
+
Converts .pt checkpoints to safetensors and structures files for HF upload:
|
| 5 |
+
- Root: best checkpoint (model.safetensors, config.json, metrics.jsonl, README.md)
|
| 6 |
+
- checkpoints/step_NNNN/: other checkpoints with truncated metrics
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
python scripts/export_hf_repo.py \
|
| 10 |
+
--run-dir logs/run_20260322_182707 \
|
| 11 |
+
--output-dir export/pawn-base \
|
| 12 |
+
--repo-name pawn-base \
|
| 13 |
+
--github-url https://github.com/thomas-schweich/PAWN
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import json
|
| 20 |
+
import shutil
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
from safetensors.torch import save_file
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def find_best_step(metrics_path: Path) -> int | None:
|
| 28 |
+
"""Find the step with lowest val loss from metrics.jsonl."""
|
| 29 |
+
best_loss = float("inf")
|
| 30 |
+
best_step = None
|
| 31 |
+
with open(metrics_path) as f:
|
| 32 |
+
for line in f:
|
| 33 |
+
record = json.loads(line)
|
| 34 |
+
if record.get("type") != "val":
|
| 35 |
+
continue
|
| 36 |
+
loss = record.get("val/loss", float("inf"))
|
| 37 |
+
step = record.get("step")
|
| 38 |
+
if loss < best_loss and step is not None:
|
| 39 |
+
best_loss = loss
|
| 40 |
+
best_step = step
|
| 41 |
+
return best_step
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def truncate_metrics(metrics_path: Path, up_to_step: int) -> list[str]:
|
| 45 |
+
"""Return metrics lines up to and including the given step."""
|
| 46 |
+
lines = []
|
| 47 |
+
with open(metrics_path) as f:
|
| 48 |
+
for line in f:
|
| 49 |
+
lines.append(line)
|
| 50 |
+
record = json.loads(line)
|
| 51 |
+
if record.get("type") in ("train", "val") and record.get("step", 0) > up_to_step:
|
| 52 |
+
break
|
| 53 |
+
return lines
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def convert_pt_to_safetensors(pt_path: Path, output_dir: Path):
|
| 57 |
+
"""Convert a .pt checkpoint to safetensors + JSON directory format."""
|
| 58 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 59 |
+
|
| 60 |
+
ckpt = torch.load(str(pt_path), map_location="cpu", weights_only=False)
|
| 61 |
+
|
| 62 |
+
# Model weights -> safetensors
|
| 63 |
+
state_dict = ckpt["model_state_dict"]
|
| 64 |
+
tensors = {k: v.cpu().contiguous() for k, v in state_dict.items()}
|
| 65 |
+
save_file(tensors, output_dir / "model.safetensors")
|
| 66 |
+
|
| 67 |
+
# Config
|
| 68 |
+
config = {
|
| 69 |
+
"format_version": 1,
|
| 70 |
+
"checkpoint_type": "pretrain",
|
| 71 |
+
"model_config": ckpt.get("model_config", {}),
|
| 72 |
+
"training_config": ckpt.get("training_config", {}),
|
| 73 |
+
}
|
| 74 |
+
with open(output_dir / "config.json", "w") as f:
|
| 75 |
+
json.dump(config, f, indent=2, default=str)
|
| 76 |
+
|
| 77 |
+
# Optimizer -> safetensors (if present)
|
| 78 |
+
if "optimizer_state_dict" in ckpt:
|
| 79 |
+
from pawn.checkpoint import _flatten_optimizer_state, _rng_to_json, _json_default
|
| 80 |
+
opt_tensors, opt_meta = _flatten_optimizer_state(ckpt["optimizer_state_dict"])
|
| 81 |
+
if opt_tensors:
|
| 82 |
+
save_file(opt_tensors, output_dir / "optimizer.safetensors")
|
| 83 |
+
|
| 84 |
+
# Training state
|
| 85 |
+
training_state = {
|
| 86 |
+
"format_version": 1,
|
| 87 |
+
"global_step": ckpt.get("global_step", 0),
|
| 88 |
+
"scheduler_state_dict": ckpt.get("scheduler_state_dict"),
|
| 89 |
+
"scaler_state_dict": ckpt.get("scaler_state_dict"),
|
| 90 |
+
"optimizer_meta": opt_meta,
|
| 91 |
+
}
|
| 92 |
+
rng_state = {}
|
| 93 |
+
if ckpt.get("torch_rng_state") is not None:
|
| 94 |
+
rng_state.update(_rng_to_json(ckpt["torch_rng_state"], ckpt.get("cuda_rng_state")))
|
| 95 |
+
training_state.update(rng_state)
|
| 96 |
+
|
| 97 |
+
with open(output_dir / "training_state.json", "w") as f:
|
| 98 |
+
json.dump(training_state, f, indent=2, default=_json_default)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def generate_readme(
|
| 102 |
+
repo_name: str, model_config: dict, training_config: dict,
|
| 103 |
+
best_step: int, val_loss: float, val_acc: float,
|
| 104 |
+
github_url: str, extra_desc: str = "",
|
| 105 |
+
) -> str:
|
| 106 |
+
"""Generate a HuggingFace model card README."""
|
| 107 |
+
d_model = model_config.get("d_model", "?")
|
| 108 |
+
n_layers = model_config.get("n_layers", "?")
|
| 109 |
+
n_heads = model_config.get("n_heads", "?")
|
| 110 |
+
discard = training_config.get("discard_ply_limit", False)
|
| 111 |
+
|
| 112 |
+
# Infer variant name
|
| 113 |
+
variant = "base"
|
| 114 |
+
if d_model == 256:
|
| 115 |
+
variant = "small"
|
| 116 |
+
elif d_model == 640:
|
| 117 |
+
variant = "large"
|
| 118 |
+
|
| 119 |
+
params = {"small": "9.5M", "base": "35.8M", "large": "68.4M"}.get(variant, "?")
|
| 120 |
+
|
| 121 |
+
return f"""---
|
| 122 |
+
license: apache-2.0
|
| 123 |
+
library_name: pytorch
|
| 124 |
+
tags:
|
| 125 |
+
- chess
|
| 126 |
+
- transformer
|
| 127 |
+
- causal-lm
|
| 128 |
+
- world-model
|
| 129 |
+
datasets:
|
| 130 |
+
- random-self-play
|
| 131 |
+
model-index:
|
| 132 |
+
- name: {repo_name}
|
| 133 |
+
results:
|
| 134 |
+
- task:
|
| 135 |
+
type: next-move-prediction
|
| 136 |
+
metrics:
|
| 137 |
+
- name: Val Loss
|
| 138 |
+
type: loss
|
| 139 |
+
value: {val_loss}
|
| 140 |
+
- name: Val Accuracy
|
| 141 |
+
type: accuracy
|
| 142 |
+
value: {val_acc}
|
| 143 |
+
---
|
| 144 |
+
|
| 145 |
+
# {repo_name.upper()}
|
| 146 |
+
|
| 147 |
+
A causal transformer trained on random chess games, designed as a testbed for finetuning and augmentation methods at small scales.
|
| 148 |
+
{extra_desc}
|
| 149 |
+
|
| 150 |
+
## Model Details
|
| 151 |
+
|
| 152 |
+
| | |
|
| 153 |
+
|---|---|
|
| 154 |
+
| **Parameters** | {params} |
|
| 155 |
+
| **Architecture** | Decoder-only transformer (RMSNorm, SwiGLU, RoPE) |
|
| 156 |
+
| **d_model** | {d_model} |
|
| 157 |
+
| **Layers** | {n_layers} |
|
| 158 |
+
| **Heads** | {n_heads} |
|
| 159 |
+
| **Vocabulary** | 4,278 tokens (4,096 grid + 176 promotions + 5 outcomes + 1 PAD) |
|
| 160 |
+
| **Sequence length** | 256 |
|
| 161 |
+
| **Best val loss** | {val_loss:.4f} (step {best_step:,}) |
|
| 162 |
+
| **Best val accuracy** | {val_acc:.1%} |
|
| 163 |
+
|
| 164 |
+
## Usage
|
| 165 |
+
|
| 166 |
+
```python
|
| 167 |
+
import torch
|
| 168 |
+
from safetensors.torch import load_file
|
| 169 |
+
from pawn.config import CLMConfig
|
| 170 |
+
from pawn.model import PAWNCLM
|
| 171 |
+
|
| 172 |
+
cfg = CLMConfig.{variant}()
|
| 173 |
+
model = PAWNCLM(cfg)
|
| 174 |
+
model.load_state_dict(load_file("model.safetensors"))
|
| 175 |
+
model.eval()
|
| 176 |
+
```
|
| 177 |
+
|
| 178 |
+
## Training
|
| 179 |
+
|
| 180 |
+
Trained from scratch on random self-play games generated by a Rust chess engine (shakmaty).
|
| 181 |
+
See the [PAWN repository]({github_url}) for training code, data pipeline, and evaluation suite.
|
| 182 |
+
|
| 183 |
+
## License
|
| 184 |
+
|
| 185 |
+
Apache 2.0
|
| 186 |
+
"""
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def main():
|
| 190 |
+
parser = argparse.ArgumentParser(description="Export training run to HF repo format")
|
| 191 |
+
parser.add_argument("--run-dir", required=True, help="Training run directory")
|
| 192 |
+
parser.add_argument("--output-dir", required=True, help="Output directory for HF repo")
|
| 193 |
+
parser.add_argument("--repo-name", required=True, help="Repository name for README")
|
| 194 |
+
parser.add_argument("--github-url", default="https://github.com/thomas-schweich/PAWN")
|
| 195 |
+
parser.add_argument("--best-only", action="store_true", help="Only export best checkpoint")
|
| 196 |
+
parser.add_argument("--extra-desc", default="", help="Extra description for README")
|
| 197 |
+
args = parser.parse_args()
|
| 198 |
+
|
| 199 |
+
run_dir = Path(args.run_dir)
|
| 200 |
+
output_dir = Path(args.output_dir)
|
| 201 |
+
metrics_path = run_dir / "metrics.jsonl"
|
| 202 |
+
|
| 203 |
+
if not metrics_path.exists():
|
| 204 |
+
print(f"ERROR: {metrics_path} not found")
|
| 205 |
+
return
|
| 206 |
+
|
| 207 |
+
# Find best step
|
| 208 |
+
best_step = find_best_step(metrics_path)
|
| 209 |
+
if best_step is None:
|
| 210 |
+
print("ERROR: No val records found in metrics.jsonl")
|
| 211 |
+
return
|
| 212 |
+
print(f"Best val step: {best_step}")
|
| 213 |
+
|
| 214 |
+
# Find best val metrics
|
| 215 |
+
best_val_loss, best_val_acc = float("inf"), 0.0
|
| 216 |
+
with open(metrics_path) as f:
|
| 217 |
+
for line in f:
|
| 218 |
+
r = json.loads(line)
|
| 219 |
+
if r.get("type") == "val" and r.get("step") == best_step:
|
| 220 |
+
best_val_loss = r.get("val/loss", float("inf"))
|
| 221 |
+
best_val_acc = r.get("val/accuracy", 0.0)
|
| 222 |
+
break
|
| 223 |
+
|
| 224 |
+
# Find all .pt checkpoints
|
| 225 |
+
ckpt_dir = run_dir / "checkpoints"
|
| 226 |
+
checkpoints = sorted(ckpt_dir.glob("step_*.pt")) if ckpt_dir.exists() else []
|
| 227 |
+
if not checkpoints:
|
| 228 |
+
print("ERROR: No checkpoints found")
|
| 229 |
+
return
|
| 230 |
+
|
| 231 |
+
# Find nearest checkpoint to best step
|
| 232 |
+
best_ckpt = min(checkpoints, key=lambda p: abs(
|
| 233 |
+
int(p.stem.replace("step_", "")) - best_step
|
| 234 |
+
))
|
| 235 |
+
print(f"Best checkpoint: {best_ckpt}")
|
| 236 |
+
|
| 237 |
+
# Read config from checkpoint
|
| 238 |
+
ckpt_data = torch.load(str(best_ckpt), map_location="cpu", weights_only=False)
|
| 239 |
+
model_config = ckpt_data.get("model_config", {})
|
| 240 |
+
training_config = ckpt_data.get("training_config", {})
|
| 241 |
+
del ckpt_data
|
| 242 |
+
|
| 243 |
+
# Export best checkpoint to root
|
| 244 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 245 |
+
print(f"\nExporting best checkpoint to {output_dir}/")
|
| 246 |
+
convert_pt_to_safetensors(best_ckpt, output_dir)
|
| 247 |
+
|
| 248 |
+
# Copy full metrics.jsonl
|
| 249 |
+
shutil.copy2(metrics_path, output_dir / "metrics.jsonl")
|
| 250 |
+
|
| 251 |
+
# Generate README
|
| 252 |
+
readme = generate_readme(
|
| 253 |
+
args.repo_name, model_config, training_config,
|
| 254 |
+
best_step, best_val_loss, best_val_acc,
|
| 255 |
+
args.github_url, args.extra_desc,
|
| 256 |
+
)
|
| 257 |
+
with open(output_dir / "README.md", "w") as f:
|
| 258 |
+
f.write(readme)
|
| 259 |
+
|
| 260 |
+
# Export other checkpoints
|
| 261 |
+
if not args.best_only:
|
| 262 |
+
for ckpt in checkpoints:
|
| 263 |
+
if ckpt == best_ckpt:
|
| 264 |
+
continue
|
| 265 |
+
step_name = ckpt.stem # e.g. "step_00005000"
|
| 266 |
+
step_num = int(step_name.replace("step_", ""))
|
| 267 |
+
step_dir = output_dir / "checkpoints" / step_name
|
| 268 |
+
print(f" Exporting {step_name}...")
|
| 269 |
+
convert_pt_to_safetensors(ckpt, step_dir)
|
| 270 |
+
|
| 271 |
+
# Truncated metrics
|
| 272 |
+
truncated = truncate_metrics(metrics_path, step_num)
|
| 273 |
+
with open(step_dir / "metrics.jsonl", "w") as f:
|
| 274 |
+
f.writelines(truncated)
|
| 275 |
+
|
| 276 |
+
print(f"\nExport complete: {output_dir}")
|
| 277 |
+
print(f" Best: model.safetensors, config.json, metrics.jsonl, README.md")
|
| 278 |
+
if not args.best_only:
|
| 279 |
+
print(f" Checkpoints: {len(checkpoints) - 1} in checkpoints/")
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
if __name__ == "__main__":
|
| 283 |
+
main()
|
scripts/profile_step.py
CHANGED
|
@@ -18,17 +18,18 @@ import torch
|
|
| 18 |
# ── PAWN imports ─────────────────────────────────────────────────────────────
|
| 19 |
from pawn.config import CLMConfig
|
| 20 |
from pawn.model import PAWNCLM
|
| 21 |
-
from pawn.data import _to_clm_batch
|
| 22 |
import chess_engine as engine
|
| 23 |
|
| 24 |
|
| 25 |
def generate_clm_batch(batch_size: int, device: str):
|
| 26 |
"""Generate a CLM batch and move to device."""
|
| 27 |
-
|
| 28 |
-
batch_size,
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
|
|
|
|
| 18 |
# ── PAWN imports ─────────────────────────────────────────────────────────────
|
| 19 |
from pawn.config import CLMConfig
|
| 20 |
from pawn.model import PAWNCLM
|
|
|
|
| 21 |
import chess_engine as engine
|
| 22 |
|
| 23 |
|
| 24 |
def generate_clm_batch(batch_size: int, device: str):
|
| 25 |
"""Generate a CLM batch and move to device."""
|
| 26 |
+
input_ids, targets, loss_mask, _mid, _gl, _tc = \
|
| 27 |
+
engine.generate_clm_batch(batch_size, 256, seed=42)
|
| 28 |
+
return {
|
| 29 |
+
"input_ids": torch.from_numpy(input_ids).long().to(device),
|
| 30 |
+
"targets": torch.from_numpy(targets).long().to(device),
|
| 31 |
+
"loss_mask": torch.from_numpy(loss_mask).to(device),
|
| 32 |
+
}
|
| 33 |
|
| 34 |
|
| 35 |
|
scripts/train.py
CHANGED
|
@@ -30,6 +30,12 @@ def parse_args():
|
|
| 30 |
parser.add_argument("--log-dir", type=str, default=None, help="Override log directory")
|
| 31 |
parser.add_argument("--discard-ply-limit", action="store_true",
|
| 32 |
help="Only train on games that ended naturally (no ply limit truncation)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
return parser.parse_args()
|
| 34 |
|
| 35 |
|
|
@@ -73,7 +79,7 @@ def main():
|
|
| 73 |
print(f"Model config: {model_cfg}")
|
| 74 |
print(f"Training config: {train_cfg}")
|
| 75 |
|
| 76 |
-
trainer = CLMTrainer(train_cfg, model_cfg)
|
| 77 |
|
| 78 |
if args.resume:
|
| 79 |
trainer.load_checkpoint(args.resume)
|
|
|
|
| 30 |
parser.add_argument("--log-dir", type=str, default=None, help="Override log directory")
|
| 31 |
parser.add_argument("--discard-ply-limit", action="store_true",
|
| 32 |
help="Only train on games that ended naturally (no ply limit truncation)")
|
| 33 |
+
|
| 34 |
+
ckpt_group = parser.add_mutually_exclusive_group(required=True)
|
| 35 |
+
ckpt_group.add_argument("--hf-repo", type=str, default=None,
|
| 36 |
+
help="Push checkpoints to this HuggingFace repo (requires HF_TOKEN)")
|
| 37 |
+
ckpt_group.add_argument("--local-checkpoints", action="store_true",
|
| 38 |
+
help="Save checkpoints locally only (no HuggingFace push)")
|
| 39 |
return parser.parse_args()
|
| 40 |
|
| 41 |
|
|
|
|
| 79 |
print(f"Model config: {model_cfg}")
|
| 80 |
print(f"Training config: {train_cfg}")
|
| 81 |
|
| 82 |
+
trainer = CLMTrainer(train_cfg, model_cfg, hf_repo=args.hf_repo)
|
| 83 |
|
| 84 |
if args.resume:
|
| 85 |
trainer.load_checkpoint(args.resume)
|
scripts/train_all.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Train small, base, and large PAWN models simultaneously on shared data.
|
| 3 |
+
|
| 4 |
+
All three models see the exact same batches in the same order, eliminating
|
| 5 |
+
data generation overhead and ensuring comparable training conditions.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
uv run python scripts/train_all.py --local-checkpoints
|
| 9 |
+
uv run python scripts/train_all.py --hf-repo thomas-schweich/pawn-{variant}
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
import json
|
| 16 |
+
import math
|
| 17 |
+
import os
|
| 18 |
+
import signal
|
| 19 |
+
import sys
|
| 20 |
+
import time
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.multiprocessing as mp
|
| 24 |
+
from torch.utils.data import DataLoader
|
| 25 |
+
|
| 26 |
+
from pawn.config import CLMConfig, TrainingConfig
|
| 27 |
+
from pawn.model import PAWNCLM, clm_loss
|
| 28 |
+
from pawn.data import CLMDataset, create_validation_set
|
| 29 |
+
from pawn.gpu import configure_gpu, apply_gpu_config
|
| 30 |
+
from pawn.checkpoint import save_pretrain_checkpoint, push_checkpoint_to_hf
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
# Per-model state
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
|
| 37 |
+
class ModelSlot:
|
| 38 |
+
"""Holds everything needed to train and checkpoint one model variant."""
|
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
name: str,
|
| 43 |
+
model_cfg: CLMConfig,
|
| 44 |
+
train_cfg: TrainingConfig,
|
| 45 |
+
device: str,
|
| 46 |
+
hf_repo: str | None,
|
| 47 |
+
):
|
| 48 |
+
self.name = name
|
| 49 |
+
self.model_cfg = model_cfg
|
| 50 |
+
self.train_cfg = train_cfg
|
| 51 |
+
self.device = device
|
| 52 |
+
self.hf_repo = hf_repo
|
| 53 |
+
|
| 54 |
+
self.model = PAWNCLM(model_cfg).to(device)
|
| 55 |
+
param_count = sum(p.numel() for p in self.model.parameters())
|
| 56 |
+
print(f" {name}: {param_count:,} params ({model_cfg.d_model}d/{model_cfg.n_layers}L)")
|
| 57 |
+
|
| 58 |
+
self.optimizer = torch.optim.AdamW(
|
| 59 |
+
self.model.parameters(),
|
| 60 |
+
lr=train_cfg.lr,
|
| 61 |
+
weight_decay=train_cfg.weight_decay,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
from pawn.trainer import CosineWithWarmup
|
| 65 |
+
self.scheduler = CosineWithWarmup(
|
| 66 |
+
self.optimizer,
|
| 67 |
+
warmup_steps=train_cfg.warmup_steps,
|
| 68 |
+
total_steps=train_cfg.total_steps,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
self.scaler = torch.amp.GradScaler(device, enabled=train_cfg.use_amp)
|
| 72 |
+
|
| 73 |
+
# Run directory
|
| 74 |
+
self.run_dir = _make_run_dir(train_cfg.log_dir, name)
|
| 75 |
+
self.checkpoint_dir = os.path.join(self.run_dir, "checkpoints")
|
| 76 |
+
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
| 77 |
+
|
| 78 |
+
self.jsonl_path = os.path.join(self.run_dir, "metrics.jsonl")
|
| 79 |
+
self._jsonl_file: open | None = None
|
| 80 |
+
|
| 81 |
+
self.hf_branch = f"run/{os.path.basename(self.run_dir)}" if hf_repo else None
|
| 82 |
+
self.global_step = 0
|
| 83 |
+
|
| 84 |
+
# Write config
|
| 85 |
+
import subprocess
|
| 86 |
+
try:
|
| 87 |
+
git_hash = subprocess.check_output(
|
| 88 |
+
["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL, text=True
|
| 89 |
+
).strip()
|
| 90 |
+
except Exception:
|
| 91 |
+
git_hash = os.environ.get("PAWN_GIT_HASH")
|
| 92 |
+
try:
|
| 93 |
+
git_tag = subprocess.check_output(
|
| 94 |
+
["git", "tag", "--points-at", "HEAD"], stderr=subprocess.DEVNULL, text=True
|
| 95 |
+
).strip() or None
|
| 96 |
+
except Exception:
|
| 97 |
+
git_tag = os.environ.get("PAWN_GIT_TAG")
|
| 98 |
+
|
| 99 |
+
config_data = {
|
| 100 |
+
"model": model_cfg.__dict__,
|
| 101 |
+
"training": train_cfg.__dict__,
|
| 102 |
+
"param_count": param_count,
|
| 103 |
+
"formulation": "clm",
|
| 104 |
+
"multi_model": True,
|
| 105 |
+
"variant": name,
|
| 106 |
+
"git_hash": git_hash,
|
| 107 |
+
"git_tag": git_tag,
|
| 108 |
+
}
|
| 109 |
+
with open(os.path.join(self.run_dir, "config.json"), "w") as f:
|
| 110 |
+
json.dump(config_data, f, indent=2, default=str)
|
| 111 |
+
self._log_jsonl({"type": "config", **config_data})
|
| 112 |
+
|
| 113 |
+
def train_step(self, batch: dict[str, torch.Tensor]) -> dict[str, float]:
|
| 114 |
+
self.model.train()
|
| 115 |
+
input_ids = batch["input_ids"].to(self.device)
|
| 116 |
+
targets = batch["targets"].to(self.device)
|
| 117 |
+
loss_mask = batch["loss_mask"].to(self.device)
|
| 118 |
+
|
| 119 |
+
with torch.amp.autocast(self.device, enabled=self.train_cfg.use_amp):
|
| 120 |
+
loss, metrics = self.model.forward_train(input_ids, loss_mask, targets)
|
| 121 |
+
|
| 122 |
+
self.scaler.scale(loss).backward()
|
| 123 |
+
return metrics
|
| 124 |
+
|
| 125 |
+
def optimizer_step(self) -> float:
|
| 126 |
+
self.scaler.unscale_(self.optimizer)
|
| 127 |
+
grad_norm = torch.nn.utils.clip_grad_norm_(
|
| 128 |
+
self.model.parameters(), self.train_cfg.max_grad_norm
|
| 129 |
+
).item()
|
| 130 |
+
self.scaler.step(self.optimizer)
|
| 131 |
+
self.scaler.update()
|
| 132 |
+
self.optimizer.zero_grad(set_to_none=True)
|
| 133 |
+
self.scheduler.step()
|
| 134 |
+
return grad_norm
|
| 135 |
+
|
| 136 |
+
def save_checkpoint(self):
|
| 137 |
+
path = os.path.join(self.checkpoint_dir, f"step_{self.global_step:08d}")
|
| 138 |
+
save_pretrain_checkpoint(
|
| 139 |
+
path, self.model, self.optimizer, self.scheduler, self.scaler,
|
| 140 |
+
self.global_step, self.model_cfg.__dict__, self.train_cfg.__dict__,
|
| 141 |
+
)
|
| 142 |
+
print(f" [{self.name}] Checkpoint saved: {path}")
|
| 143 |
+
|
| 144 |
+
if self.hf_repo and self.hf_branch:
|
| 145 |
+
try:
|
| 146 |
+
push_checkpoint_to_hf(
|
| 147 |
+
path, self.hf_repo, self.hf_branch,
|
| 148 |
+
metrics_path=self.jsonl_path, step=self.global_step,
|
| 149 |
+
)
|
| 150 |
+
print(f" [{self.name}] Pushed to HF: {self.hf_repo}@{self.hf_branch}")
|
| 151 |
+
except Exception as e:
|
| 152 |
+
print(f" [{self.name}] WARNING: HF push failed: {e}")
|
| 153 |
+
|
| 154 |
+
@torch.no_grad()
|
| 155 |
+
def evaluate(self, val_data: dict[str, torch.Tensor]) -> dict[str, float]:
|
| 156 |
+
self.model.eval()
|
| 157 |
+
n = val_data["input_ids"].shape[0]
|
| 158 |
+
batch_size = self.train_cfg.batch_size
|
| 159 |
+
total_metrics: dict[str, float] = {}
|
| 160 |
+
n_batches = 0
|
| 161 |
+
|
| 162 |
+
for start in range(0, n, batch_size):
|
| 163 |
+
end = min(start + batch_size, n)
|
| 164 |
+
input_ids = val_data["input_ids"][start:end].to(self.device)
|
| 165 |
+
targets = val_data["targets"][start:end].to(self.device)
|
| 166 |
+
loss_mask = val_data["loss_mask"][start:end].to(self.device)
|
| 167 |
+
|
| 168 |
+
with torch.amp.autocast(self.device, enabled=self.train_cfg.use_amp):
|
| 169 |
+
logits, _ = self.model(input_ids, loss_mask)
|
| 170 |
+
_, metrics = clm_loss(logits, targets, loss_mask)
|
| 171 |
+
|
| 172 |
+
for k, v in metrics.items():
|
| 173 |
+
total_metrics[k] = total_metrics.get(k, 0.0) + v
|
| 174 |
+
n_batches += 1
|
| 175 |
+
|
| 176 |
+
return {f"val/{k}": v / n_batches for k, v in total_metrics.items()}
|
| 177 |
+
|
| 178 |
+
def _log_jsonl(self, record: dict):
|
| 179 |
+
if self._jsonl_file is None:
|
| 180 |
+
self._jsonl_file = open(self.jsonl_path, "a")
|
| 181 |
+
self._jsonl_file.write(json.dumps(record, default=str) + "\n")
|
| 182 |
+
self._jsonl_file.flush()
|
| 183 |
+
|
| 184 |
+
def close(self):
|
| 185 |
+
if self._jsonl_file:
|
| 186 |
+
self._jsonl_file.close()
|
| 187 |
+
self._jsonl_file = None
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def _make_run_dir(log_dir: str, variant: str) -> str:
|
| 191 |
+
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
| 192 |
+
run_dir = os.path.join(log_dir, f"run_{timestamp}_{variant}")
|
| 193 |
+
os.makedirs(run_dir, exist_ok=True)
|
| 194 |
+
return run_dir
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# ---------------------------------------------------------------------------
|
| 198 |
+
# CLI
|
| 199 |
+
# ---------------------------------------------------------------------------
|
| 200 |
+
|
| 201 |
+
def parse_args():
|
| 202 |
+
p = argparse.ArgumentParser(description="Train small/base/large PAWN models simultaneously")
|
| 203 |
+
p.add_argument("--device", type=str, default=None, help="Device (cuda/cpu)")
|
| 204 |
+
p.add_argument("--total-steps", type=int, default=100_000, help="Total training steps")
|
| 205 |
+
p.add_argument("--batch-size", type=int, default=256, help="Batch size (shared across models)")
|
| 206 |
+
p.add_argument("--num-workers", type=int, default=4, help="DataLoader workers")
|
| 207 |
+
p.add_argument("--log-dir", type=str, default="logs", help="Log directory")
|
| 208 |
+
p.add_argument("--log-interval", type=int, default=10)
|
| 209 |
+
p.add_argument("--eval-interval", type=int, default=500)
|
| 210 |
+
p.add_argument("--checkpoint-interval", type=int, default=5000)
|
| 211 |
+
p.add_argument("--discard-ply-limit", action="store_true")
|
| 212 |
+
p.add_argument("--wandb", action="store_true")
|
| 213 |
+
|
| 214 |
+
ckpt_group = p.add_mutually_exclusive_group(required=True)
|
| 215 |
+
ckpt_group.add_argument("--hf-repo", type=str, default=None,
|
| 216 |
+
help="HF repo prefix (appends -{variant}). E.g. thomas-schweich/pawn")
|
| 217 |
+
ckpt_group.add_argument("--local-checkpoints", action="store_true")
|
| 218 |
+
return p.parse_args()
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def main():
|
| 222 |
+
args = parse_args()
|
| 223 |
+
|
| 224 |
+
device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 225 |
+
if device == "cuda":
|
| 226 |
+
gpu_cfg = configure_gpu()
|
| 227 |
+
import pawn.model as model_module
|
| 228 |
+
if gpu_cfg.get("sdpa_backend"):
|
| 229 |
+
model_module.SDPA_BACKEND = gpu_cfg["sdpa_backend"]
|
| 230 |
+
|
| 231 |
+
# Build per-variant configs (shared training hyperparams, different model sizes)
|
| 232 |
+
variants = {
|
| 233 |
+
"small": CLMConfig.small(),
|
| 234 |
+
"base": CLMConfig.base(),
|
| 235 |
+
"large": CLMConfig.large(),
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
print("=== Multi-Model Training ===")
|
| 239 |
+
print(f"Device: {device}")
|
| 240 |
+
print(f"Batch size: {args.batch_size}")
|
| 241 |
+
print(f"Total steps: {args.total_steps}")
|
| 242 |
+
print()
|
| 243 |
+
|
| 244 |
+
slots: list[ModelSlot] = []
|
| 245 |
+
for name, model_cfg in variants.items():
|
| 246 |
+
train_cfg = TrainingConfig()
|
| 247 |
+
train_cfg.total_steps = args.total_steps
|
| 248 |
+
train_cfg.batch_size = args.batch_size
|
| 249 |
+
train_cfg.num_workers = args.num_workers
|
| 250 |
+
train_cfg.device = device
|
| 251 |
+
train_cfg.log_dir = args.log_dir
|
| 252 |
+
train_cfg.log_interval = args.log_interval
|
| 253 |
+
train_cfg.eval_interval = args.eval_interval
|
| 254 |
+
train_cfg.checkpoint_interval = args.checkpoint_interval
|
| 255 |
+
train_cfg.discard_ply_limit = args.discard_ply_limit
|
| 256 |
+
train_cfg.use_wandb = args.wandb
|
| 257 |
+
|
| 258 |
+
hf_repo = f"{args.hf_repo}-{name}" if args.hf_repo else None
|
| 259 |
+
slots.append(ModelSlot(name, model_cfg, train_cfg, device, hf_repo))
|
| 260 |
+
|
| 261 |
+
# Shared dataset and validation set
|
| 262 |
+
max_ply = 256
|
| 263 |
+
dataset = CLMDataset(
|
| 264 |
+
args.batch_size, max_ply, base_seed=42,
|
| 265 |
+
discard_ply_limit=args.discard_ply_limit,
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
print("\nGenerating shared validation set...")
|
| 269 |
+
val_data = create_validation_set(512, max_ply, seed=(2**63) - 1,
|
| 270 |
+
discard_ply_limit=args.discard_ply_limit)
|
| 271 |
+
|
| 272 |
+
# Compile models
|
| 273 |
+
if device != "cpu":
|
| 274 |
+
for slot in slots:
|
| 275 |
+
try:
|
| 276 |
+
slot.model = torch.compile(slot.model, mode="default")
|
| 277 |
+
print(f" [{slot.name}] torch.compile enabled")
|
| 278 |
+
except Exception:
|
| 279 |
+
print(f" [{slot.name}] torch.compile not available")
|
| 280 |
+
|
| 281 |
+
loader = DataLoader(
|
| 282 |
+
dataset,
|
| 283 |
+
batch_size=None,
|
| 284 |
+
num_workers=args.num_workers,
|
| 285 |
+
pin_memory=(device != "cpu"),
|
| 286 |
+
persistent_workers=(args.num_workers > 0),
|
| 287 |
+
prefetch_factor=1 if args.num_workers > 0 else None,
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
# Signal handling
|
| 291 |
+
_shutdown_requested = False
|
| 292 |
+
_shutdown_signal = None
|
| 293 |
+
|
| 294 |
+
def _graceful_exit(signum, frame):
|
| 295 |
+
nonlocal _shutdown_requested, _shutdown_signal
|
| 296 |
+
_shutdown_requested = True
|
| 297 |
+
_shutdown_signal = signum
|
| 298 |
+
|
| 299 |
+
signal.signal(signal.SIGTERM, _graceful_exit)
|
| 300 |
+
signal.signal(signal.SIGINT, _graceful_exit)
|
| 301 |
+
|
| 302 |
+
# Training loop
|
| 303 |
+
global_step = 0
|
| 304 |
+
step_start = time.time()
|
| 305 |
+
|
| 306 |
+
print(f"\nStarting training from step 0", flush=True)
|
| 307 |
+
for slot in slots:
|
| 308 |
+
print(f" [{slot.name}] JSONL: {slot.jsonl_path}", flush=True)
|
| 309 |
+
print()
|
| 310 |
+
|
| 311 |
+
for batch in loader:
|
| 312 |
+
# Forward + backward for each model on the same batch
|
| 313 |
+
all_metrics: dict[str, dict[str, float]] = {}
|
| 314 |
+
for slot in slots:
|
| 315 |
+
metrics = slot.train_step(batch)
|
| 316 |
+
all_metrics[slot.name] = metrics
|
| 317 |
+
|
| 318 |
+
# Optimizer step for each model
|
| 319 |
+
all_grad_norms: dict[str, float] = {}
|
| 320 |
+
for slot in slots:
|
| 321 |
+
gn = slot.optimizer_step()
|
| 322 |
+
all_grad_norms[slot.name] = gn
|
| 323 |
+
|
| 324 |
+
global_step += 1
|
| 325 |
+
for slot in slots:
|
| 326 |
+
slot.global_step = global_step
|
| 327 |
+
|
| 328 |
+
step_time = time.time() - step_start
|
| 329 |
+
games_per_sec = args.batch_size / step_time
|
| 330 |
+
|
| 331 |
+
# Logging
|
| 332 |
+
if global_step % args.log_interval == 0:
|
| 333 |
+
print(f"step {global_step:>7d} | {games_per_sec:.0f} g/s | {step_time:.2f}s", flush=True)
|
| 334 |
+
for slot in slots:
|
| 335 |
+
m = all_metrics[slot.name]
|
| 336 |
+
gn = all_grad_norms[slot.name]
|
| 337 |
+
lr = slot.scheduler.get_lr()
|
| 338 |
+
print(f" {slot.name:>5s}: loss {m['loss']:.4f} | acc {m['accuracy']:.3f} | "
|
| 339 |
+
f"lr {lr:.2e} | gn {gn:.2f}", flush=True)
|
| 340 |
+
|
| 341 |
+
record = {
|
| 342 |
+
"type": "train",
|
| 343 |
+
"step": global_step,
|
| 344 |
+
"timestamp": time.time(),
|
| 345 |
+
"lr": lr,
|
| 346 |
+
"grad_norm": gn,
|
| 347 |
+
"step_time": step_time,
|
| 348 |
+
"games_per_sec": games_per_sec,
|
| 349 |
+
**{f"train/{k}": v for k, v in m.items()},
|
| 350 |
+
}
|
| 351 |
+
slot._log_jsonl(record)
|
| 352 |
+
|
| 353 |
+
# Eval
|
| 354 |
+
if global_step % args.eval_interval == 0:
|
| 355 |
+
for slot in slots:
|
| 356 |
+
val_metrics = slot.evaluate(val_data)
|
| 357 |
+
print(f" {slot.name:>5s} val: loss {val_metrics['val/loss']:.4f} | "
|
| 358 |
+
f"acc {val_metrics['val/accuracy']:.3f}", flush=True)
|
| 359 |
+
slot._log_jsonl({
|
| 360 |
+
"type": "val",
|
| 361 |
+
"step": global_step,
|
| 362 |
+
"timestamp": time.time(),
|
| 363 |
+
**val_metrics,
|
| 364 |
+
})
|
| 365 |
+
|
| 366 |
+
# Checkpoint
|
| 367 |
+
if global_step % args.checkpoint_interval == 0:
|
| 368 |
+
for slot in slots:
|
| 369 |
+
slot.save_checkpoint()
|
| 370 |
+
|
| 371 |
+
# Done?
|
| 372 |
+
if global_step >= args.total_steps:
|
| 373 |
+
print(f"\nTraining complete at step {global_step}")
|
| 374 |
+
for slot in slots:
|
| 375 |
+
slot.save_checkpoint()
|
| 376 |
+
break
|
| 377 |
+
|
| 378 |
+
# Graceful shutdown
|
| 379 |
+
if _shutdown_requested:
|
| 380 |
+
print(f"\nShutdown requested (signal {_shutdown_signal}), "
|
| 381 |
+
f"saving checkpoints at step {global_step}...")
|
| 382 |
+
for slot in slots:
|
| 383 |
+
slot.save_checkpoint()
|
| 384 |
+
break
|
| 385 |
+
|
| 386 |
+
step_start = time.time()
|
| 387 |
+
|
| 388 |
+
# Cleanup
|
| 389 |
+
for slot in slots:
|
| 390 |
+
slot.close()
|
| 391 |
+
|
| 392 |
+
print("\nAll done.")
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
if __name__ == "__main__":
|
| 396 |
+
try:
|
| 397 |
+
mp.set_start_method("forkserver", force=True)
|
| 398 |
+
except ValueError:
|
| 399 |
+
mp.set_start_method("spawn", force=True)
|
| 400 |
+
main()
|
scripts/train_bottleneck.py
CHANGED
|
@@ -16,6 +16,7 @@ from __future__ import annotations
|
|
| 16 |
import argparse
|
| 17 |
import gc
|
| 18 |
import math
|
|
|
|
| 19 |
import time
|
| 20 |
from pathlib import Path
|
| 21 |
|
|
@@ -90,15 +91,22 @@ def parse_args():
|
|
| 90 |
p.add_argument("--resume", type=str, default=None,
|
| 91 |
help="Path to checkpoint to resume from (best.pt or final.pt)")
|
| 92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
return p.parse_args()
|
| 94 |
|
| 95 |
|
| 96 |
def load_backbone(checkpoint_path: str, device: str) -> PAWNCLM:
|
| 97 |
-
|
| 98 |
-
|
|
|
|
| 99 |
model = PAWNCLM(cfg).to(device)
|
| 100 |
-
model.load_state_dict(
|
| 101 |
-
del
|
| 102 |
gc.collect()
|
| 103 |
model.eval()
|
| 104 |
return model
|
|
@@ -197,6 +205,10 @@ def main():
|
|
| 197 |
ckpt_dir = out_dir / "checkpoints"
|
| 198 |
ckpt_dir.mkdir(exist_ok=True)
|
| 199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
print(f"Device: {device}")
|
| 201 |
print(f"Output: {out_dir}")
|
| 202 |
|
|
@@ -312,17 +324,18 @@ def main():
|
|
| 312 |
|
| 313 |
if args.resume:
|
| 314 |
print(f"\nResuming from: {args.resume}")
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
|
|
|
| 318 |
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
|
| 319 |
-
if "scheduler_state_dict"
|
| 320 |
scheduler.load_state_dict(ckpt["scheduler_state_dict"])
|
| 321 |
if scaler and ckpt.get("scaler_state_dict"):
|
| 322 |
scaler.load_state_dict(ckpt["scaler_state_dict"])
|
| 323 |
start_epoch = ckpt["epoch"] + 1
|
| 324 |
global_step = ckpt["step"]
|
| 325 |
-
best_val_loss = ckpt.get("best_val_loss", ckpt.get("
|
| 326 |
patience_counter = ckpt.get("patience_counter", 0)
|
| 327 |
print(f" Resumed at epoch {start_epoch}, step {global_step}, "
|
| 328 |
f"best_val_loss={best_val_loss:.4f}")
|
|
@@ -358,6 +371,13 @@ def main():
|
|
| 358 |
val_metrics = evaluate(model, val_loader, mask_builder, device, use_amp=use_amp,
|
| 359 |
precomputed_indices=val_legal_indices) if args.resume else baseline
|
| 360 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
print(f"\nTraining for up to {args.epochs} epochs ({total_steps} steps)")
|
| 362 |
print(f" Warmup: {warmup_steps} steps, LR: {args.lr}")
|
| 363 |
|
|
@@ -434,38 +454,56 @@ def main():
|
|
| 434 |
if val_metrics["loss"] < best_val_loss:
|
| 435 |
best_val_loss = val_metrics["loss"]
|
| 436 |
patience_counter = 0
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
"
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
"
|
| 449 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 450 |
else:
|
| 451 |
patience_counter += 1
|
| 452 |
if patience_counter >= args.patience:
|
| 453 |
print(f"\n Early stopping at epoch {epoch} (patience={args.patience})")
|
| 454 |
break
|
| 455 |
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
"
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
|
| 470 |
logger.close()
|
| 471 |
print(f"\nDone. Best val_loss={best_val_loss:.4f}")
|
|
|
|
| 16 |
import argparse
|
| 17 |
import gc
|
| 18 |
import math
|
| 19 |
+
import signal
|
| 20 |
import time
|
| 21 |
from pathlib import Path
|
| 22 |
|
|
|
|
| 91 |
p.add_argument("--resume", type=str, default=None,
|
| 92 |
help="Path to checkpoint to resume from (best.pt or final.pt)")
|
| 93 |
|
| 94 |
+
ckpt_group = p.add_mutually_exclusive_group(required=True)
|
| 95 |
+
ckpt_group.add_argument("--hf-repo", type=str, default=None,
|
| 96 |
+
help="Push checkpoints to this HuggingFace repo (requires HF_TOKEN)")
|
| 97 |
+
ckpt_group.add_argument("--local-checkpoints", action="store_true",
|
| 98 |
+
help="Save checkpoints locally only")
|
| 99 |
+
|
| 100 |
return p.parse_args()
|
| 101 |
|
| 102 |
|
| 103 |
def load_backbone(checkpoint_path: str, device: str) -> PAWNCLM:
|
| 104 |
+
from pawn.checkpoint import load_backbone_weights
|
| 105 |
+
state_dict, model_config = load_backbone_weights(checkpoint_path, device)
|
| 106 |
+
cfg = CLMConfig(**model_config) if model_config else CLMConfig()
|
| 107 |
model = PAWNCLM(cfg).to(device)
|
| 108 |
+
model.load_state_dict(state_dict)
|
| 109 |
+
del state_dict
|
| 110 |
gc.collect()
|
| 111 |
model.eval()
|
| 112 |
return model
|
|
|
|
| 205 |
ckpt_dir = out_dir / "checkpoints"
|
| 206 |
ckpt_dir.mkdir(exist_ok=True)
|
| 207 |
|
| 208 |
+
hf_branch = None
|
| 209 |
+
if args.hf_repo:
|
| 210 |
+
hf_branch = f"run/{out_dir.name}"
|
| 211 |
+
|
| 212 |
print(f"Device: {device}")
|
| 213 |
print(f"Output: {out_dir}")
|
| 214 |
|
|
|
|
| 324 |
|
| 325 |
if args.resume:
|
| 326 |
print(f"\nResuming from: {args.resume}")
|
| 327 |
+
from pawn.checkpoint import load_adapter_checkpoint
|
| 328 |
+
ckpt = load_adapter_checkpoint(args.resume, device=device)
|
| 329 |
+
model.load_adapter_state_dict(ckpt["adapter_state_dict"])
|
| 330 |
+
if ckpt.get("optimizer_state_dict"):
|
| 331 |
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
|
| 332 |
+
if ckpt.get("scheduler_state_dict"):
|
| 333 |
scheduler.load_state_dict(ckpt["scheduler_state_dict"])
|
| 334 |
if scaler and ckpt.get("scaler_state_dict"):
|
| 335 |
scaler.load_state_dict(ckpt["scaler_state_dict"])
|
| 336 |
start_epoch = ckpt["epoch"] + 1
|
| 337 |
global_step = ckpt["step"]
|
| 338 |
+
best_val_loss = ckpt.get("best_val_loss", ckpt.get("val_metrics", {}).get("loss", float("inf")))
|
| 339 |
patience_counter = ckpt.get("patience_counter", 0)
|
| 340 |
print(f" Resumed at epoch {start_epoch}, step {global_step}, "
|
| 341 |
f"best_val_loss={best_val_loss:.4f}")
|
|
|
|
| 371 |
val_metrics = evaluate(model, val_loader, mask_builder, device, use_amp=use_amp,
|
| 372 |
precomputed_indices=val_legal_indices) if args.resume else baseline
|
| 373 |
|
| 374 |
+
_shutdown_requested = False
|
| 375 |
+
def _graceful_exit(signum, frame):
|
| 376 |
+
nonlocal _shutdown_requested
|
| 377 |
+
_shutdown_requested = True
|
| 378 |
+
signal.signal(signal.SIGTERM, _graceful_exit)
|
| 379 |
+
signal.signal(signal.SIGINT, _graceful_exit)
|
| 380 |
+
|
| 381 |
print(f"\nTraining for up to {args.epochs} epochs ({total_steps} steps)")
|
| 382 |
print(f" Warmup: {warmup_steps} steps, LR: {args.lr}")
|
| 383 |
|
|
|
|
| 454 |
if val_metrics["loss"] < best_val_loss:
|
| 455 |
best_val_loss = val_metrics["loss"]
|
| 456 |
patience_counter = 0
|
| 457 |
+
from pawn.checkpoint import save_adapter_checkpoint
|
| 458 |
+
save_adapter_checkpoint(
|
| 459 |
+
ckpt_dir / "best",
|
| 460 |
+
model.adapter_state_dict(),
|
| 461 |
+
config=vars(args),
|
| 462 |
+
epoch=epoch,
|
| 463 |
+
step=global_step,
|
| 464 |
+
val_metrics=val_metrics,
|
| 465 |
+
optimizer=optimizer,
|
| 466 |
+
scheduler=scheduler,
|
| 467 |
+
scaler=scaler,
|
| 468 |
+
extra={"best_val_loss": best_val_loss, "patience_counter": patience_counter},
|
| 469 |
+
)
|
| 470 |
+
if args.hf_repo and hf_branch:
|
| 471 |
+
from pawn.checkpoint import push_checkpoint_to_hf
|
| 472 |
+
try:
|
| 473 |
+
push_checkpoint_to_hf(ckpt_dir / "best", args.hf_repo, hf_branch, step=global_step)
|
| 474 |
+
print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
|
| 475 |
+
except Exception as e:
|
| 476 |
+
print(f"WARNING: HF push failed: {e}")
|
| 477 |
else:
|
| 478 |
patience_counter += 1
|
| 479 |
if patience_counter >= args.patience:
|
| 480 |
print(f"\n Early stopping at epoch {epoch} (patience={args.patience})")
|
| 481 |
break
|
| 482 |
|
| 483 |
+
if _shutdown_requested:
|
| 484 |
+
print("Shutdown requested, saving checkpoint...")
|
| 485 |
+
break
|
| 486 |
+
|
| 487 |
+
from pawn.checkpoint import save_adapter_checkpoint
|
| 488 |
+
save_adapter_checkpoint(
|
| 489 |
+
ckpt_dir / "final",
|
| 490 |
+
model.adapter_state_dict(),
|
| 491 |
+
config=vars(args),
|
| 492 |
+
epoch=epoch,
|
| 493 |
+
step=global_step,
|
| 494 |
+
val_metrics=val_metrics,
|
| 495 |
+
optimizer=optimizer,
|
| 496 |
+
scheduler=scheduler,
|
| 497 |
+
scaler=scaler,
|
| 498 |
+
extra={"best_val_loss": best_val_loss, "patience_counter": patience_counter},
|
| 499 |
+
)
|
| 500 |
+
if args.hf_repo and hf_branch:
|
| 501 |
+
from pawn.checkpoint import push_checkpoint_to_hf
|
| 502 |
+
try:
|
| 503 |
+
push_checkpoint_to_hf(ckpt_dir / "final", args.hf_repo, hf_branch, step=global_step)
|
| 504 |
+
print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
|
| 505 |
+
except Exception as e:
|
| 506 |
+
print(f"WARNING: HF push failed: {e}")
|
| 507 |
|
| 508 |
logger.close()
|
| 509 |
print(f"\nDone. Best val_loss={best_val_loss:.4f}")
|
scripts/train_film.py
CHANGED
|
@@ -17,6 +17,7 @@ import argparse
|
|
| 17 |
import gc
|
| 18 |
import json
|
| 19 |
import math
|
|
|
|
| 20 |
import time
|
| 21 |
from pathlib import Path
|
| 22 |
|
|
@@ -80,15 +81,22 @@ def parse_args():
|
|
| 80 |
p.add_argument("--sdpa-math", action="store_true",
|
| 81 |
help="Use MATH SDPA backend (workaround for ROCm flash attn + compile)")
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
return p.parse_args()
|
| 84 |
|
| 85 |
|
| 86 |
def load_backbone(checkpoint_path: str, device: str) -> PAWNCLM:
|
| 87 |
-
|
| 88 |
-
|
|
|
|
| 89 |
model = PAWNCLM(cfg).to(device)
|
| 90 |
-
model.load_state_dict(
|
| 91 |
-
del
|
| 92 |
gc.collect()
|
| 93 |
model.eval()
|
| 94 |
return model
|
|
@@ -180,6 +188,10 @@ def main():
|
|
| 180 |
ckpt_dir = out_dir / "checkpoints"
|
| 181 |
ckpt_dir.mkdir(exist_ok=True)
|
| 182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
device = args.device
|
| 184 |
print(f"Device: {device}")
|
| 185 |
print(f"Output: {out_dir}")
|
|
@@ -284,6 +296,13 @@ def main():
|
|
| 284 |
global_step = 0
|
| 285 |
val_metrics = baseline # for first log if val_every > 1
|
| 286 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
print(f"\nTraining for up to {args.epochs} epochs ({total_steps} steps)")
|
| 288 |
print(f" Warmup: {warmup_steps} steps, LR: {args.lr}")
|
| 289 |
if args.val_every > 1:
|
|
@@ -371,29 +390,49 @@ def main():
|
|
| 371 |
if val_metrics["loss"] < best_val_loss:
|
| 372 |
best_val_loss = val_metrics["loss"]
|
| 373 |
patience_counter = 0
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
"
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
else:
|
| 383 |
patience_counter += 1
|
| 384 |
if patience_counter >= args.patience:
|
| 385 |
print(f"\n Early stopping at epoch {epoch} (patience={args.patience})")
|
| 386 |
break
|
| 387 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 388 |
# Save final checkpoint
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
"
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
|
| 398 |
print(f"\nDone. Best val_loss={best_val_loss:.4f}")
|
| 399 |
print(f"Checkpoints saved to {out_dir}")
|
|
|
|
| 17 |
import gc
|
| 18 |
import json
|
| 19 |
import math
|
| 20 |
+
import signal
|
| 21 |
import time
|
| 22 |
from pathlib import Path
|
| 23 |
|
|
|
|
| 81 |
p.add_argument("--sdpa-math", action="store_true",
|
| 82 |
help="Use MATH SDPA backend (workaround for ROCm flash attn + compile)")
|
| 83 |
|
| 84 |
+
ckpt_group = p.add_mutually_exclusive_group(required=True)
|
| 85 |
+
ckpt_group.add_argument("--hf-repo", type=str, default=None,
|
| 86 |
+
help="Push checkpoints to this HuggingFace repo (requires HF_TOKEN)")
|
| 87 |
+
ckpt_group.add_argument("--local-checkpoints", action="store_true",
|
| 88 |
+
help="Save checkpoints locally only")
|
| 89 |
+
|
| 90 |
return p.parse_args()
|
| 91 |
|
| 92 |
|
| 93 |
def load_backbone(checkpoint_path: str, device: str) -> PAWNCLM:
|
| 94 |
+
from pawn.checkpoint import load_backbone_weights
|
| 95 |
+
state_dict, model_config = load_backbone_weights(checkpoint_path, device)
|
| 96 |
+
cfg = CLMConfig(**model_config) if model_config else CLMConfig()
|
| 97 |
model = PAWNCLM(cfg).to(device)
|
| 98 |
+
model.load_state_dict(state_dict)
|
| 99 |
+
del state_dict
|
| 100 |
gc.collect()
|
| 101 |
model.eval()
|
| 102 |
return model
|
|
|
|
| 188 |
ckpt_dir = out_dir / "checkpoints"
|
| 189 |
ckpt_dir.mkdir(exist_ok=True)
|
| 190 |
|
| 191 |
+
hf_branch = None
|
| 192 |
+
if args.hf_repo:
|
| 193 |
+
hf_branch = f"run/{out_dir.name}"
|
| 194 |
+
|
| 195 |
device = args.device
|
| 196 |
print(f"Device: {device}")
|
| 197 |
print(f"Output: {out_dir}")
|
|
|
|
| 296 |
global_step = 0
|
| 297 |
val_metrics = baseline # for first log if val_every > 1
|
| 298 |
|
| 299 |
+
_shutdown_requested = False
|
| 300 |
+
def _graceful_exit(signum, frame):
|
| 301 |
+
nonlocal _shutdown_requested
|
| 302 |
+
_shutdown_requested = True
|
| 303 |
+
signal.signal(signal.SIGTERM, _graceful_exit)
|
| 304 |
+
signal.signal(signal.SIGINT, _graceful_exit)
|
| 305 |
+
|
| 306 |
print(f"\nTraining for up to {args.epochs} epochs ({total_steps} steps)")
|
| 307 |
print(f" Warmup: {warmup_steps} steps, LR: {args.lr}")
|
| 308 |
if args.val_every > 1:
|
|
|
|
| 390 |
if val_metrics["loss"] < best_val_loss:
|
| 391 |
best_val_loss = val_metrics["loss"]
|
| 392 |
patience_counter = 0
|
| 393 |
+
from pawn.checkpoint import save_adapter_checkpoint
|
| 394 |
+
save_adapter_checkpoint(
|
| 395 |
+
ckpt_dir / "best",
|
| 396 |
+
model.film_state_dict(),
|
| 397 |
+
config=vars(args),
|
| 398 |
+
epoch=epoch,
|
| 399 |
+
step=global_step,
|
| 400 |
+
val_metrics=val_metrics,
|
| 401 |
+
)
|
| 402 |
+
if args.hf_repo and hf_branch:
|
| 403 |
+
from pawn.checkpoint import push_checkpoint_to_hf
|
| 404 |
+
try:
|
| 405 |
+
push_checkpoint_to_hf(ckpt_dir / "best", args.hf_repo, hf_branch, step=global_step)
|
| 406 |
+
print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
|
| 407 |
+
except Exception as e:
|
| 408 |
+
print(f"WARNING: HF push failed: {e}")
|
| 409 |
else:
|
| 410 |
patience_counter += 1
|
| 411 |
if patience_counter >= args.patience:
|
| 412 |
print(f"\n Early stopping at epoch {epoch} (patience={args.patience})")
|
| 413 |
break
|
| 414 |
|
| 415 |
+
if _shutdown_requested:
|
| 416 |
+
print("Shutdown requested, saving checkpoint...")
|
| 417 |
+
break
|
| 418 |
+
|
| 419 |
# Save final checkpoint
|
| 420 |
+
from pawn.checkpoint import save_adapter_checkpoint
|
| 421 |
+
save_adapter_checkpoint(
|
| 422 |
+
ckpt_dir / "final",
|
| 423 |
+
model.film_state_dict(),
|
| 424 |
+
config=vars(args),
|
| 425 |
+
epoch=epoch,
|
| 426 |
+
step=global_step,
|
| 427 |
+
val_metrics=val_metrics,
|
| 428 |
+
)
|
| 429 |
+
if args.hf_repo and hf_branch:
|
| 430 |
+
from pawn.checkpoint import push_checkpoint_to_hf
|
| 431 |
+
try:
|
| 432 |
+
push_checkpoint_to_hf(ckpt_dir / "final", args.hf_repo, hf_branch, step=global_step)
|
| 433 |
+
print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
|
| 434 |
+
except Exception as e:
|
| 435 |
+
print(f"WARNING: HF push failed: {e}")
|
| 436 |
|
| 437 |
print(f"\nDone. Best val_loss={best_val_loss:.4f}")
|
| 438 |
print(f"Checkpoints saved to {out_dir}")
|
scripts/train_hybrid.py
CHANGED
|
@@ -18,6 +18,7 @@ import argparse
|
|
| 18 |
import gc
|
| 19 |
import json
|
| 20 |
import math
|
|
|
|
| 21 |
import time
|
| 22 |
from pathlib import Path
|
| 23 |
|
|
@@ -91,15 +92,22 @@ def parse_args():
|
|
| 91 |
p.add_argument("--sdpa-math", action="store_true",
|
| 92 |
help="Use MATH SDPA backend (workaround for ROCm flash attn + compile)")
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
return p.parse_args()
|
| 95 |
|
| 96 |
|
| 97 |
def load_backbone(checkpoint_path: str, device: str) -> PAWNCLM:
|
| 98 |
-
|
| 99 |
-
|
|
|
|
| 100 |
model = PAWNCLM(cfg).to(device)
|
| 101 |
-
model.load_state_dict(
|
| 102 |
-
del
|
| 103 |
gc.collect()
|
| 104 |
model.eval()
|
| 105 |
return model
|
|
@@ -182,6 +190,10 @@ def main():
|
|
| 182 |
ckpt_dir = out_dir / "checkpoints"
|
| 183 |
ckpt_dir.mkdir(exist_ok=True)
|
| 184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
device = args.device
|
| 186 |
print(f"Device: {device}")
|
| 187 |
print(f"Output: {out_dir}")
|
|
@@ -314,6 +326,13 @@ def main():
|
|
| 314 |
global_step = 0
|
| 315 |
val_metrics = baseline
|
| 316 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
print(f"\nTraining for up to {args.epochs} epochs ({total_steps} steps)")
|
| 318 |
|
| 319 |
for epoch in range(args.epochs):
|
|
@@ -394,28 +413,48 @@ def main():
|
|
| 394 |
if val_metrics["loss"] < best_val_loss:
|
| 395 |
best_val_loss = val_metrics["loss"]
|
| 396 |
patience_counter = 0
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
"
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
else:
|
| 406 |
patience_counter += 1
|
| 407 |
if patience_counter >= args.patience:
|
| 408 |
print(f"\n Early stopping at epoch {epoch} (patience={args.patience})")
|
| 409 |
break
|
| 410 |
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
"
|
| 418 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
|
| 420 |
print(f"\nDone. Best val_loss={best_val_loss:.4f}")
|
| 421 |
print(f"Checkpoints saved to {out_dir}")
|
|
|
|
| 18 |
import gc
|
| 19 |
import json
|
| 20 |
import math
|
| 21 |
+
import signal
|
| 22 |
import time
|
| 23 |
from pathlib import Path
|
| 24 |
|
|
|
|
| 92 |
p.add_argument("--sdpa-math", action="store_true",
|
| 93 |
help="Use MATH SDPA backend (workaround for ROCm flash attn + compile)")
|
| 94 |
|
| 95 |
+
ckpt_group = p.add_mutually_exclusive_group(required=True)
|
| 96 |
+
ckpt_group.add_argument("--hf-repo", type=str, default=None,
|
| 97 |
+
help="Push checkpoints to this HuggingFace repo (requires HF_TOKEN)")
|
| 98 |
+
ckpt_group.add_argument("--local-checkpoints", action="store_true",
|
| 99 |
+
help="Save checkpoints locally only")
|
| 100 |
+
|
| 101 |
return p.parse_args()
|
| 102 |
|
| 103 |
|
| 104 |
def load_backbone(checkpoint_path: str, device: str) -> PAWNCLM:
|
| 105 |
+
from pawn.checkpoint import load_backbone_weights
|
| 106 |
+
state_dict, model_config = load_backbone_weights(checkpoint_path, device)
|
| 107 |
+
cfg = CLMConfig(**model_config) if model_config else CLMConfig()
|
| 108 |
model = PAWNCLM(cfg).to(device)
|
| 109 |
+
model.load_state_dict(state_dict)
|
| 110 |
+
del state_dict
|
| 111 |
gc.collect()
|
| 112 |
model.eval()
|
| 113 |
return model
|
|
|
|
| 190 |
ckpt_dir = out_dir / "checkpoints"
|
| 191 |
ckpt_dir.mkdir(exist_ok=True)
|
| 192 |
|
| 193 |
+
hf_branch = None
|
| 194 |
+
if args.hf_repo:
|
| 195 |
+
hf_branch = f"run/{out_dir.name}"
|
| 196 |
+
|
| 197 |
device = args.device
|
| 198 |
print(f"Device: {device}")
|
| 199 |
print(f"Output: {out_dir}")
|
|
|
|
| 326 |
global_step = 0
|
| 327 |
val_metrics = baseline
|
| 328 |
|
| 329 |
+
_shutdown_requested = False
|
| 330 |
+
def _graceful_exit(signum, frame):
|
| 331 |
+
nonlocal _shutdown_requested
|
| 332 |
+
_shutdown_requested = True
|
| 333 |
+
signal.signal(signal.SIGTERM, _graceful_exit)
|
| 334 |
+
signal.signal(signal.SIGINT, _graceful_exit)
|
| 335 |
+
|
| 336 |
print(f"\nTraining for up to {args.epochs} epochs ({total_steps} steps)")
|
| 337 |
|
| 338 |
for epoch in range(args.epochs):
|
|
|
|
| 413 |
if val_metrics["loss"] < best_val_loss:
|
| 414 |
best_val_loss = val_metrics["loss"]
|
| 415 |
patience_counter = 0
|
| 416 |
+
from pawn.checkpoint import save_adapter_checkpoint
|
| 417 |
+
save_adapter_checkpoint(
|
| 418 |
+
ckpt_dir / "best",
|
| 419 |
+
model.adapter_state_dict(),
|
| 420 |
+
config=vars(args),
|
| 421 |
+
epoch=epoch,
|
| 422 |
+
step=global_step,
|
| 423 |
+
val_metrics=val_metrics,
|
| 424 |
+
)
|
| 425 |
+
if args.hf_repo and hf_branch:
|
| 426 |
+
from pawn.checkpoint import push_checkpoint_to_hf
|
| 427 |
+
try:
|
| 428 |
+
push_checkpoint_to_hf(ckpt_dir / "best", args.hf_repo, hf_branch, step=global_step)
|
| 429 |
+
print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
|
| 430 |
+
except Exception as e:
|
| 431 |
+
print(f"WARNING: HF push failed: {e}")
|
| 432 |
else:
|
| 433 |
patience_counter += 1
|
| 434 |
if patience_counter >= args.patience:
|
| 435 |
print(f"\n Early stopping at epoch {epoch} (patience={args.patience})")
|
| 436 |
break
|
| 437 |
|
| 438 |
+
if _shutdown_requested:
|
| 439 |
+
print("Shutdown requested, saving checkpoint...")
|
| 440 |
+
break
|
| 441 |
+
|
| 442 |
+
from pawn.checkpoint import save_adapter_checkpoint
|
| 443 |
+
save_adapter_checkpoint(
|
| 444 |
+
ckpt_dir / "final",
|
| 445 |
+
model.adapter_state_dict(),
|
| 446 |
+
config=vars(args),
|
| 447 |
+
epoch=epoch,
|
| 448 |
+
step=global_step,
|
| 449 |
+
val_metrics=val_metrics,
|
| 450 |
+
)
|
| 451 |
+
if args.hf_repo and hf_branch:
|
| 452 |
+
from pawn.checkpoint import push_checkpoint_to_hf
|
| 453 |
+
try:
|
| 454 |
+
push_checkpoint_to_hf(ckpt_dir / "final", args.hf_repo, hf_branch, step=global_step)
|
| 455 |
+
print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
|
| 456 |
+
except Exception as e:
|
| 457 |
+
print(f"WARNING: HF push failed: {e}")
|
| 458 |
|
| 459 |
print(f"\nDone. Best val_loss={best_val_loss:.4f}")
|
| 460 |
print(f"Checkpoints saved to {out_dir}")
|
scripts/train_lora.py
CHANGED
|
@@ -17,6 +17,7 @@ import argparse
|
|
| 17 |
import gc
|
| 18 |
import json
|
| 19 |
import math
|
|
|
|
| 20 |
import time
|
| 21 |
from pathlib import Path
|
| 22 |
|
|
@@ -92,15 +93,22 @@ def parse_args():
|
|
| 92 |
p.add_argument("--sdpa-math", action="store_true",
|
| 93 |
help="Use MATH SDPA backend (workaround for ROCm flash attn + compile)")
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
return p.parse_args()
|
| 96 |
|
| 97 |
|
| 98 |
def load_backbone(checkpoint_path: str, device: str) -> PAWNCLM:
|
| 99 |
-
|
| 100 |
-
|
|
|
|
| 101 |
model = PAWNCLM(cfg).to(device)
|
| 102 |
-
model.load_state_dict(
|
| 103 |
-
del
|
| 104 |
gc.collect()
|
| 105 |
model.eval()
|
| 106 |
return model
|
|
@@ -190,6 +198,10 @@ def main():
|
|
| 190 |
ckpt_dir = out_dir / "checkpoints"
|
| 191 |
ckpt_dir.mkdir(exist_ok=True)
|
| 192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
device = args.device
|
| 194 |
print(f"Device: {device}")
|
| 195 |
print(f"Output: {out_dir}")
|
|
@@ -309,6 +321,13 @@ def main():
|
|
| 309 |
global_step = 0
|
| 310 |
val_metrics = baseline
|
| 311 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
print(f"\nTraining for up to {args.epochs} epochs ({total_steps} steps)")
|
| 313 |
print(f" Warmup: {warmup_steps} steps, LR: {args.lr}")
|
| 314 |
if args.val_every > 1:
|
|
@@ -395,29 +414,49 @@ def main():
|
|
| 395 |
if val_metrics["loss"] < best_val_loss:
|
| 396 |
best_val_loss = val_metrics["loss"]
|
| 397 |
patience_counter = 0
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
"
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
else:
|
| 407 |
patience_counter += 1
|
| 408 |
if patience_counter >= args.patience:
|
| 409 |
print(f"\n Early stopping at epoch {epoch} (patience={args.patience})")
|
| 410 |
break
|
| 411 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
# Save final checkpoint
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
"
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 421 |
|
| 422 |
print(f"\nDone. Best val_loss={best_val_loss:.4f}")
|
| 423 |
print(f"Checkpoints saved to {out_dir}")
|
|
|
|
| 17 |
import gc
|
| 18 |
import json
|
| 19 |
import math
|
| 20 |
+
import signal
|
| 21 |
import time
|
| 22 |
from pathlib import Path
|
| 23 |
|
|
|
|
| 93 |
p.add_argument("--sdpa-math", action="store_true",
|
| 94 |
help="Use MATH SDPA backend (workaround for ROCm flash attn + compile)")
|
| 95 |
|
| 96 |
+
ckpt_group = p.add_mutually_exclusive_group(required=True)
|
| 97 |
+
ckpt_group.add_argument("--hf-repo", type=str, default=None,
|
| 98 |
+
help="Push checkpoints to this HuggingFace repo (requires HF_TOKEN)")
|
| 99 |
+
ckpt_group.add_argument("--local-checkpoints", action="store_true",
|
| 100 |
+
help="Save checkpoints locally only")
|
| 101 |
+
|
| 102 |
return p.parse_args()
|
| 103 |
|
| 104 |
|
| 105 |
def load_backbone(checkpoint_path: str, device: str) -> PAWNCLM:
|
| 106 |
+
from pawn.checkpoint import load_backbone_weights
|
| 107 |
+
state_dict, model_config = load_backbone_weights(checkpoint_path, device)
|
| 108 |
+
cfg = CLMConfig(**model_config) if model_config else CLMConfig()
|
| 109 |
model = PAWNCLM(cfg).to(device)
|
| 110 |
+
model.load_state_dict(state_dict)
|
| 111 |
+
del state_dict
|
| 112 |
gc.collect()
|
| 113 |
model.eval()
|
| 114 |
return model
|
|
|
|
| 198 |
ckpt_dir = out_dir / "checkpoints"
|
| 199 |
ckpt_dir.mkdir(exist_ok=True)
|
| 200 |
|
| 201 |
+
hf_branch = None
|
| 202 |
+
if args.hf_repo:
|
| 203 |
+
hf_branch = f"run/{out_dir.name}"
|
| 204 |
+
|
| 205 |
device = args.device
|
| 206 |
print(f"Device: {device}")
|
| 207 |
print(f"Output: {out_dir}")
|
|
|
|
| 321 |
global_step = 0
|
| 322 |
val_metrics = baseline
|
| 323 |
|
| 324 |
+
_shutdown_requested = False
|
| 325 |
+
def _graceful_exit(signum, frame):
|
| 326 |
+
nonlocal _shutdown_requested
|
| 327 |
+
_shutdown_requested = True
|
| 328 |
+
signal.signal(signal.SIGTERM, _graceful_exit)
|
| 329 |
+
signal.signal(signal.SIGINT, _graceful_exit)
|
| 330 |
+
|
| 331 |
print(f"\nTraining for up to {args.epochs} epochs ({total_steps} steps)")
|
| 332 |
print(f" Warmup: {warmup_steps} steps, LR: {args.lr}")
|
| 333 |
if args.val_every > 1:
|
|
|
|
| 414 |
if val_metrics["loss"] < best_val_loss:
|
| 415 |
best_val_loss = val_metrics["loss"]
|
| 416 |
patience_counter = 0
|
| 417 |
+
from pawn.checkpoint import save_adapter_checkpoint
|
| 418 |
+
save_adapter_checkpoint(
|
| 419 |
+
ckpt_dir / "best",
|
| 420 |
+
model.lora_state_dict(),
|
| 421 |
+
config=vars(args),
|
| 422 |
+
epoch=epoch,
|
| 423 |
+
step=global_step,
|
| 424 |
+
val_metrics=val_metrics,
|
| 425 |
+
)
|
| 426 |
+
if args.hf_repo and hf_branch:
|
| 427 |
+
from pawn.checkpoint import push_checkpoint_to_hf
|
| 428 |
+
try:
|
| 429 |
+
push_checkpoint_to_hf(ckpt_dir / "best", args.hf_repo, hf_branch, step=global_step)
|
| 430 |
+
print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
|
| 431 |
+
except Exception as e:
|
| 432 |
+
print(f"WARNING: HF push failed: {e}")
|
| 433 |
else:
|
| 434 |
patience_counter += 1
|
| 435 |
if patience_counter >= args.patience:
|
| 436 |
print(f"\n Early stopping at epoch {epoch} (patience={args.patience})")
|
| 437 |
break
|
| 438 |
|
| 439 |
+
if _shutdown_requested:
|
| 440 |
+
print("Shutdown requested, saving checkpoint...")
|
| 441 |
+
break
|
| 442 |
+
|
| 443 |
# Save final checkpoint
|
| 444 |
+
from pawn.checkpoint import save_adapter_checkpoint
|
| 445 |
+
save_adapter_checkpoint(
|
| 446 |
+
ckpt_dir / "final",
|
| 447 |
+
model.lora_state_dict(),
|
| 448 |
+
config=vars(args),
|
| 449 |
+
epoch=epoch,
|
| 450 |
+
step=global_step,
|
| 451 |
+
val_metrics=val_metrics,
|
| 452 |
+
)
|
| 453 |
+
if args.hf_repo and hf_branch:
|
| 454 |
+
from pawn.checkpoint import push_checkpoint_to_hf
|
| 455 |
+
try:
|
| 456 |
+
push_checkpoint_to_hf(ckpt_dir / "final", args.hf_repo, hf_branch, step=global_step)
|
| 457 |
+
print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
|
| 458 |
+
except Exception as e:
|
| 459 |
+
print(f"WARNING: HF push failed: {e}")
|
| 460 |
|
| 461 |
print(f"\nDone. Best val_loss={best_val_loss:.4f}")
|
| 462 |
print(f"Checkpoints saved to {out_dir}")
|
scripts/train_sparse.py
CHANGED
|
@@ -17,6 +17,7 @@ import argparse
|
|
| 17 |
import gc
|
| 18 |
import json
|
| 19 |
import math
|
|
|
|
| 20 |
import time
|
| 21 |
from pathlib import Path
|
| 22 |
|
|
@@ -83,15 +84,22 @@ def parse_args():
|
|
| 83 |
p.add_argument("--sdpa-math", action="store_true",
|
| 84 |
help="Use MATH SDPA backend (workaround for ROCm flash attn + compile)")
|
| 85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
return p.parse_args()
|
| 87 |
|
| 88 |
|
| 89 |
def load_backbone(checkpoint_path: str, device: str) -> PAWNCLM:
|
| 90 |
-
|
| 91 |
-
|
|
|
|
| 92 |
model = PAWNCLM(cfg).to(device)
|
| 93 |
-
model.load_state_dict(
|
| 94 |
-
del
|
| 95 |
gc.collect()
|
| 96 |
model.eval()
|
| 97 |
return model
|
|
@@ -182,6 +190,10 @@ def main():
|
|
| 182 |
ckpt_dir = out_dir / "checkpoints"
|
| 183 |
ckpt_dir.mkdir(exist_ok=True)
|
| 184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
device = args.device
|
| 186 |
print(f"Device: {device}")
|
| 187 |
print(f"Output: {out_dir}")
|
|
@@ -299,6 +311,13 @@ def main():
|
|
| 299 |
global_step = 0
|
| 300 |
val_metrics = baseline
|
| 301 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
print(f"\nTraining for up to {args.epochs} epochs ({total_steps} steps)")
|
| 303 |
print(f" Warmup: {warmup_steps} steps, LR: {args.lr}")
|
| 304 |
|
|
@@ -379,28 +398,48 @@ def main():
|
|
| 379 |
if val_metrics["loss"] < best_val_loss:
|
| 380 |
best_val_loss = val_metrics["loss"]
|
| 381 |
patience_counter = 0
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
"
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
else:
|
| 391 |
patience_counter += 1
|
| 392 |
if patience_counter >= args.patience:
|
| 393 |
print(f"\n Early stopping at epoch {epoch} (patience={args.patience})")
|
| 394 |
break
|
| 395 |
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
"
|
| 403 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
|
| 405 |
print(f"\nDone. Best val_loss={best_val_loss:.4f}")
|
| 406 |
print(f"Checkpoints saved to {out_dir}")
|
|
|
|
| 17 |
import gc
|
| 18 |
import json
|
| 19 |
import math
|
| 20 |
+
import signal
|
| 21 |
import time
|
| 22 |
from pathlib import Path
|
| 23 |
|
|
|
|
| 84 |
p.add_argument("--sdpa-math", action="store_true",
|
| 85 |
help="Use MATH SDPA backend (workaround for ROCm flash attn + compile)")
|
| 86 |
|
| 87 |
+
ckpt_group = p.add_mutually_exclusive_group(required=True)
|
| 88 |
+
ckpt_group.add_argument("--hf-repo", type=str, default=None,
|
| 89 |
+
help="Push checkpoints to this HuggingFace repo (requires HF_TOKEN)")
|
| 90 |
+
ckpt_group.add_argument("--local-checkpoints", action="store_true",
|
| 91 |
+
help="Save checkpoints locally only")
|
| 92 |
+
|
| 93 |
return p.parse_args()
|
| 94 |
|
| 95 |
|
| 96 |
def load_backbone(checkpoint_path: str, device: str) -> PAWNCLM:
|
| 97 |
+
from pawn.checkpoint import load_backbone_weights
|
| 98 |
+
state_dict, model_config = load_backbone_weights(checkpoint_path, device)
|
| 99 |
+
cfg = CLMConfig(**model_config) if model_config else CLMConfig()
|
| 100 |
model = PAWNCLM(cfg).to(device)
|
| 101 |
+
model.load_state_dict(state_dict)
|
| 102 |
+
del state_dict
|
| 103 |
gc.collect()
|
| 104 |
model.eval()
|
| 105 |
return model
|
|
|
|
| 190 |
ckpt_dir = out_dir / "checkpoints"
|
| 191 |
ckpt_dir.mkdir(exist_ok=True)
|
| 192 |
|
| 193 |
+
hf_branch = None
|
| 194 |
+
if args.hf_repo:
|
| 195 |
+
hf_branch = f"run/{out_dir.name}"
|
| 196 |
+
|
| 197 |
device = args.device
|
| 198 |
print(f"Device: {device}")
|
| 199 |
print(f"Output: {out_dir}")
|
|
|
|
| 311 |
global_step = 0
|
| 312 |
val_metrics = baseline
|
| 313 |
|
| 314 |
+
_shutdown_requested = False
|
| 315 |
+
def _graceful_exit(signum, frame):
|
| 316 |
+
nonlocal _shutdown_requested
|
| 317 |
+
_shutdown_requested = True
|
| 318 |
+
signal.signal(signal.SIGTERM, _graceful_exit)
|
| 319 |
+
signal.signal(signal.SIGINT, _graceful_exit)
|
| 320 |
+
|
| 321 |
print(f"\nTraining for up to {args.epochs} epochs ({total_steps} steps)")
|
| 322 |
print(f" Warmup: {warmup_steps} steps, LR: {args.lr}")
|
| 323 |
|
|
|
|
| 398 |
if val_metrics["loss"] < best_val_loss:
|
| 399 |
best_val_loss = val_metrics["loss"]
|
| 400 |
patience_counter = 0
|
| 401 |
+
from pawn.checkpoint import save_adapter_checkpoint
|
| 402 |
+
save_adapter_checkpoint(
|
| 403 |
+
ckpt_dir / "best",
|
| 404 |
+
model.sparse_state_dict(),
|
| 405 |
+
config=vars(args),
|
| 406 |
+
epoch=epoch,
|
| 407 |
+
step=global_step,
|
| 408 |
+
val_metrics=val_metrics,
|
| 409 |
+
)
|
| 410 |
+
if args.hf_repo and hf_branch:
|
| 411 |
+
from pawn.checkpoint import push_checkpoint_to_hf
|
| 412 |
+
try:
|
| 413 |
+
push_checkpoint_to_hf(ckpt_dir / "best", args.hf_repo, hf_branch, step=global_step)
|
| 414 |
+
print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
|
| 415 |
+
except Exception as e:
|
| 416 |
+
print(f"WARNING: HF push failed: {e}")
|
| 417 |
else:
|
| 418 |
patience_counter += 1
|
| 419 |
if patience_counter >= args.patience:
|
| 420 |
print(f"\n Early stopping at epoch {epoch} (patience={args.patience})")
|
| 421 |
break
|
| 422 |
|
| 423 |
+
if _shutdown_requested:
|
| 424 |
+
print("Shutdown requested, saving checkpoint...")
|
| 425 |
+
break
|
| 426 |
+
|
| 427 |
+
from pawn.checkpoint import save_adapter_checkpoint
|
| 428 |
+
save_adapter_checkpoint(
|
| 429 |
+
ckpt_dir / "final",
|
| 430 |
+
model.sparse_state_dict(),
|
| 431 |
+
config=vars(args),
|
| 432 |
+
epoch=epoch,
|
| 433 |
+
step=global_step,
|
| 434 |
+
val_metrics=val_metrics,
|
| 435 |
+
)
|
| 436 |
+
if args.hf_repo and hf_branch:
|
| 437 |
+
from pawn.checkpoint import push_checkpoint_to_hf
|
| 438 |
+
try:
|
| 439 |
+
push_checkpoint_to_hf(ckpt_dir / "final", args.hf_repo, hf_branch, step=global_step)
|
| 440 |
+
print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
|
| 441 |
+
except Exception as e:
|
| 442 |
+
print(f"WARNING: HF push failed: {e}")
|
| 443 |
|
| 444 |
print(f"\nDone. Best val_loss={best_val_loss:.4f}")
|
| 445 |
print(f"Checkpoints saved to {out_dir}")
|
scripts/train_tiny.py
CHANGED
|
@@ -15,6 +15,7 @@ from __future__ import annotations
|
|
| 15 |
|
| 16 |
import argparse
|
| 17 |
import math
|
|
|
|
| 18 |
import time
|
| 19 |
from pathlib import Path
|
| 20 |
|
|
@@ -228,6 +229,12 @@ def parse_args():
|
|
| 228 |
p.add_argument("--no-compile", action="store_true")
|
| 229 |
p.add_argument("--num-workers", type=int, default=3)
|
| 230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
return p.parse_args()
|
| 232 |
|
| 233 |
|
|
@@ -255,6 +262,10 @@ def main():
|
|
| 255 |
ckpt_dir = out_dir / "checkpoints"
|
| 256 |
ckpt_dir.mkdir(exist_ok=True)
|
| 257 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
# Build model
|
| 259 |
model = TinyChessLM(
|
| 260 |
vocab_size=vocab_size,
|
|
@@ -358,6 +369,13 @@ def main():
|
|
| 358 |
global_step = 0
|
| 359 |
val_metrics = baseline
|
| 360 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
print(f"\nTraining for up to {args.epochs} epochs ({total_steps} steps)")
|
| 362 |
print(f" Warmup: {warmup_steps} steps, LR: {args.lr}")
|
| 363 |
|
|
@@ -428,24 +446,39 @@ def main():
|
|
| 428 |
if val_metrics["loss"] < best_val_loss:
|
| 429 |
best_val_loss = val_metrics["loss"]
|
| 430 |
patience_counter = 0
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
"
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
else:
|
| 444 |
patience_counter += 1
|
| 445 |
if patience_counter >= args.patience:
|
| 446 |
print(f"\n Early stopping at epoch {epoch} (patience={args.patience})")
|
| 447 |
break
|
| 448 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 449 |
logger.close()
|
| 450 |
print(f"\nDone. Best val_loss={best_val_loss:.4f}")
|
| 451 |
print(f"Checkpoints saved to {out_dir}")
|
|
|
|
| 15 |
|
| 16 |
import argparse
|
| 17 |
import math
|
| 18 |
+
import signal
|
| 19 |
import time
|
| 20 |
from pathlib import Path
|
| 21 |
|
|
|
|
| 229 |
p.add_argument("--no-compile", action="store_true")
|
| 230 |
p.add_argument("--num-workers", type=int, default=3)
|
| 231 |
|
| 232 |
+
ckpt_group = p.add_mutually_exclusive_group(required=True)
|
| 233 |
+
ckpt_group.add_argument("--hf-repo", type=str, default=None,
|
| 234 |
+
help="Push checkpoints to this HuggingFace repo (requires HF_TOKEN)")
|
| 235 |
+
ckpt_group.add_argument("--local-checkpoints", action="store_true",
|
| 236 |
+
help="Save checkpoints locally only")
|
| 237 |
+
|
| 238 |
return p.parse_args()
|
| 239 |
|
| 240 |
|
|
|
|
| 262 |
ckpt_dir = out_dir / "checkpoints"
|
| 263 |
ckpt_dir.mkdir(exist_ok=True)
|
| 264 |
|
| 265 |
+
hf_branch = None
|
| 266 |
+
if args.hf_repo:
|
| 267 |
+
hf_branch = f"run/{out_dir.name}"
|
| 268 |
+
|
| 269 |
# Build model
|
| 270 |
model = TinyChessLM(
|
| 271 |
vocab_size=vocab_size,
|
|
|
|
| 369 |
global_step = 0
|
| 370 |
val_metrics = baseline
|
| 371 |
|
| 372 |
+
_shutdown_requested = False
|
| 373 |
+
def _graceful_exit(signum, frame):
|
| 374 |
+
nonlocal _shutdown_requested
|
| 375 |
+
_shutdown_requested = True
|
| 376 |
+
signal.signal(signal.SIGTERM, _graceful_exit)
|
| 377 |
+
signal.signal(signal.SIGINT, _graceful_exit)
|
| 378 |
+
|
| 379 |
print(f"\nTraining for up to {args.epochs} epochs ({total_steps} steps)")
|
| 380 |
print(f" Warmup: {warmup_steps} steps, LR: {args.lr}")
|
| 381 |
|
|
|
|
| 446 |
if val_metrics["loss"] < best_val_loss:
|
| 447 |
best_val_loss = val_metrics["loss"]
|
| 448 |
patience_counter = 0
|
| 449 |
+
from pawn.checkpoint import save_pretrain_checkpoint
|
| 450 |
+
save_pretrain_checkpoint(
|
| 451 |
+
ckpt_dir / "best",
|
| 452 |
+
model=model,
|
| 453 |
+
optimizer=optimizer,
|
| 454 |
+
scheduler=scheduler,
|
| 455 |
+
scaler=scaler,
|
| 456 |
+
global_step=global_step,
|
| 457 |
+
model_config={
|
| 458 |
+
"d_model": args.d_model,
|
| 459 |
+
"n_layers": args.n_layers,
|
| 460 |
+
"n_heads": args.n_heads,
|
| 461 |
+
"d_ff": args.d_ff,
|
| 462 |
+
},
|
| 463 |
+
training_config=vars(args),
|
| 464 |
+
)
|
| 465 |
+
if args.hf_repo and hf_branch:
|
| 466 |
+
from pawn.checkpoint import push_checkpoint_to_hf
|
| 467 |
+
try:
|
| 468 |
+
push_checkpoint_to_hf(ckpt_dir / "best", args.hf_repo, hf_branch, step=global_step)
|
| 469 |
+
print(f"Pushed to HF: {args.hf_repo}@{hf_branch}")
|
| 470 |
+
except Exception as e:
|
| 471 |
+
print(f"WARNING: HF push failed: {e}")
|
| 472 |
else:
|
| 473 |
patience_counter += 1
|
| 474 |
if patience_counter >= args.patience:
|
| 475 |
print(f"\n Early stopping at epoch {epoch} (patience={args.patience})")
|
| 476 |
break
|
| 477 |
|
| 478 |
+
if _shutdown_requested:
|
| 479 |
+
print("Shutdown requested, saving checkpoint...")
|
| 480 |
+
break
|
| 481 |
+
|
| 482 |
logger.close()
|
| 483 |
print(f"\nDone. Best val_loss={best_val_loss:.4f}")
|
| 484 |
print(f"Checkpoints saved to {out_dir}")
|
tests/test_checkpoint.py
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for pawn.checkpoint safetensors save/load."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import tempfile
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
from pawn.config import CLMConfig
|
| 11 |
+
from pawn.model import PAWNCLM
|
| 12 |
+
import pytest
|
| 13 |
+
|
| 14 |
+
from pawn.checkpoint import (
|
| 15 |
+
save_pretrain_checkpoint,
|
| 16 |
+
load_pretrain_checkpoint,
|
| 17 |
+
save_adapter_checkpoint,
|
| 18 |
+
load_adapter_checkpoint,
|
| 19 |
+
load_backbone_weights,
|
| 20 |
+
is_legacy_checkpoint,
|
| 21 |
+
IncompleteCheckpointError,
|
| 22 |
+
CheckpointIntegrityError,
|
| 23 |
+
_flatten_optimizer_state,
|
| 24 |
+
_unflatten_optimizer_state,
|
| 25 |
+
_rng_to_json,
|
| 26 |
+
_json_to_rng,
|
| 27 |
+
_verify_complete_sentinel,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _make_toy_model(device="cpu"):
|
| 32 |
+
cfg = CLMConfig.toy()
|
| 33 |
+
model = PAWNCLM(cfg).to(device)
|
| 34 |
+
return model, cfg
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _make_optimizer(model):
|
| 38 |
+
return torch.optim.AdamW(model.parameters(), lr=1e-3)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class FakeScheduler:
|
| 42 |
+
"""Minimal scheduler stand-in for testing."""
|
| 43 |
+
def __init__(self):
|
| 44 |
+
self._step = 0
|
| 45 |
+
|
| 46 |
+
def state_dict(self):
|
| 47 |
+
return {"step": self._step}
|
| 48 |
+
|
| 49 |
+
def load_state_dict(self, d):
|
| 50 |
+
self._step = d["step"]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class FakeScaler:
|
| 54 |
+
"""Minimal scaler stand-in for testing."""
|
| 55 |
+
def __init__(self):
|
| 56 |
+
self._scale = 65536.0
|
| 57 |
+
|
| 58 |
+
def state_dict(self):
|
| 59 |
+
return {"scale": self._scale, "_growth_tracker": 0}
|
| 60 |
+
|
| 61 |
+
def load_state_dict(self, d):
|
| 62 |
+
self._scale = d["scale"]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def test_pretrain_checkpoint_roundtrip():
|
| 66 |
+
"""Save and load a pretrain checkpoint, verify model weights match."""
|
| 67 |
+
model1, cfg = _make_toy_model()
|
| 68 |
+
opt1 = _make_optimizer(model1)
|
| 69 |
+
sched1 = FakeScheduler()
|
| 70 |
+
sched1._step = 42
|
| 71 |
+
scaler1 = FakeScaler()
|
| 72 |
+
|
| 73 |
+
# Run a fake step to populate optimizer state
|
| 74 |
+
x = torch.randint(0, cfg.vocab_size, (2, cfg.max_seq_len))
|
| 75 |
+
mask = torch.ones(2, cfg.max_seq_len, dtype=torch.bool)
|
| 76 |
+
targets = torch.randint(0, cfg.vocab_size, (2, cfg.max_seq_len))
|
| 77 |
+
loss, _ = model1.forward_train(x, mask, targets)
|
| 78 |
+
loss.backward()
|
| 79 |
+
opt1.step()
|
| 80 |
+
|
| 81 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 82 |
+
ckpt_path = Path(tmpdir) / "step_00000001"
|
| 83 |
+
save_pretrain_checkpoint(
|
| 84 |
+
ckpt_path, model1, opt1, sched1, scaler1,
|
| 85 |
+
global_step=1,
|
| 86 |
+
model_config=cfg.__dict__,
|
| 87 |
+
training_config={"lr": 1e-3},
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
# Verify files exist
|
| 91 |
+
assert (ckpt_path / "model.safetensors").exists()
|
| 92 |
+
assert (ckpt_path / "optimizer.safetensors").exists()
|
| 93 |
+
assert (ckpt_path / "training_state.json").exists()
|
| 94 |
+
assert (ckpt_path / "config.json").exists()
|
| 95 |
+
|
| 96 |
+
# Load into fresh model
|
| 97 |
+
model2, _ = _make_toy_model()
|
| 98 |
+
opt2 = _make_optimizer(model2)
|
| 99 |
+
sched2 = FakeScheduler()
|
| 100 |
+
scaler2 = FakeScaler()
|
| 101 |
+
|
| 102 |
+
meta = load_pretrain_checkpoint(
|
| 103 |
+
ckpt_path, model2, opt2, sched2, scaler2
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
assert meta["global_step"] == 1
|
| 107 |
+
assert meta["model_config"]["d_model"] == cfg.d_model
|
| 108 |
+
assert sched2._step == 42
|
| 109 |
+
|
| 110 |
+
# Verify model weights match
|
| 111 |
+
for k in model1.state_dict():
|
| 112 |
+
torch.testing.assert_close(
|
| 113 |
+
model1.state_dict()[k],
|
| 114 |
+
model2.state_dict()[k],
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def test_optimizer_flatten_unflatten():
|
| 119 |
+
"""Verify optimizer state survives flatten/unflatten."""
|
| 120 |
+
model, cfg = _make_toy_model()
|
| 121 |
+
opt = _make_optimizer(model)
|
| 122 |
+
|
| 123 |
+
# Populate optimizer state
|
| 124 |
+
x = torch.randint(0, cfg.vocab_size, (2, cfg.max_seq_len))
|
| 125 |
+
mask = torch.ones(2, cfg.max_seq_len, dtype=torch.bool)
|
| 126 |
+
targets = torch.randint(0, cfg.vocab_size, (2, cfg.max_seq_len))
|
| 127 |
+
loss, _ = model.forward_train(x, mask, targets)
|
| 128 |
+
loss.backward()
|
| 129 |
+
opt.step()
|
| 130 |
+
|
| 131 |
+
orig_state = opt.state_dict()
|
| 132 |
+
tensors, meta = _flatten_optimizer_state(orig_state)
|
| 133 |
+
|
| 134 |
+
# All tensors should be named "state.{id}.{key}"
|
| 135 |
+
for k in tensors:
|
| 136 |
+
assert k.startswith("state."), f"Unexpected key: {k}"
|
| 137 |
+
|
| 138 |
+
restored = _unflatten_optimizer_state(tensors, meta)
|
| 139 |
+
|
| 140 |
+
# Verify param_groups match
|
| 141 |
+
assert len(restored["param_groups"]) == len(orig_state["param_groups"])
|
| 142 |
+
|
| 143 |
+
# Verify state tensors match
|
| 144 |
+
for param_id in orig_state["state"]:
|
| 145 |
+
for key in ("exp_avg", "exp_avg_sq"):
|
| 146 |
+
torch.testing.assert_close(
|
| 147 |
+
orig_state["state"][param_id][key],
|
| 148 |
+
restored["state"][param_id][key],
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def test_rng_roundtrip():
|
| 153 |
+
"""Verify RNG state survives JSON serialization."""
|
| 154 |
+
torch_rng = torch.get_rng_state()
|
| 155 |
+
cuda_rng = None
|
| 156 |
+
|
| 157 |
+
data = _rng_to_json(torch_rng, cuda_rng)
|
| 158 |
+
assert "torch_rng_state" in data
|
| 159 |
+
assert "cuda_rng_state" not in data
|
| 160 |
+
|
| 161 |
+
restored_torch, restored_cuda = _json_to_rng(data)
|
| 162 |
+
assert restored_cuda is None
|
| 163 |
+
assert torch.equal(torch_rng, restored_torch)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def test_adapter_checkpoint_roundtrip():
|
| 167 |
+
"""Save and load an adapter checkpoint."""
|
| 168 |
+
adapter_weights = {
|
| 169 |
+
"down.weight": torch.randn(8, 64),
|
| 170 |
+
"up.weight": torch.zeros(64, 8),
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 174 |
+
ckpt_path = Path(tmpdir) / "best"
|
| 175 |
+
save_adapter_checkpoint(
|
| 176 |
+
ckpt_path,
|
| 177 |
+
adapter_weights,
|
| 178 |
+
config={"bottleneck_dim": 8, "checkpoint_type": "bottleneck"},
|
| 179 |
+
epoch=5,
|
| 180 |
+
step=1000,
|
| 181 |
+
val_metrics={"loss": 3.14, "top1_accuracy": 0.07},
|
| 182 |
+
extra={"best_val_loss": 3.10, "patience_counter": 2},
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
assert (ckpt_path / "adapter.safetensors").exists()
|
| 186 |
+
assert (ckpt_path / "config.json").exists()
|
| 187 |
+
assert (ckpt_path / "training_state.json").exists()
|
| 188 |
+
assert not (ckpt_path / "optimizer.safetensors").exists() # no optimizer passed
|
| 189 |
+
|
| 190 |
+
loaded = load_adapter_checkpoint(ckpt_path)
|
| 191 |
+
assert loaded["epoch"] == 5
|
| 192 |
+
assert loaded["step"] == 1000
|
| 193 |
+
assert loaded["val_metrics"]["loss"] == 3.14
|
| 194 |
+
assert loaded["best_val_loss"] == 3.10
|
| 195 |
+
|
| 196 |
+
for k in adapter_weights:
|
| 197 |
+
torch.testing.assert_close(adapter_weights[k], loaded["adapter_state_dict"][k])
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def test_load_backbone_weights_new_format():
|
| 201 |
+
"""load_backbone_weights works with new directory format."""
|
| 202 |
+
model1, cfg = _make_toy_model()
|
| 203 |
+
|
| 204 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 205 |
+
ckpt_path = Path(tmpdir) / "step_00000001"
|
| 206 |
+
opt = _make_optimizer(model1)
|
| 207 |
+
save_pretrain_checkpoint(
|
| 208 |
+
ckpt_path, model1, opt, FakeScheduler(), FakeScaler(),
|
| 209 |
+
global_step=1, model_config=cfg.__dict__, training_config={},
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
state_dict, model_config = load_backbone_weights(ckpt_path)
|
| 213 |
+
assert model_config["d_model"] == cfg.d_model
|
| 214 |
+
for k in model1.state_dict():
|
| 215 |
+
torch.testing.assert_close(model1.state_dict()[k], state_dict[k])
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def test_load_backbone_weights_legacy():
|
| 219 |
+
"""load_backbone_weights works with legacy .pt files."""
|
| 220 |
+
model1, cfg = _make_toy_model()
|
| 221 |
+
|
| 222 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 223 |
+
pt_path = Path(tmpdir) / "model.pt"
|
| 224 |
+
torch.save({
|
| 225 |
+
"model_state_dict": model1.state_dict(),
|
| 226 |
+
"model_config": cfg.__dict__,
|
| 227 |
+
}, pt_path)
|
| 228 |
+
|
| 229 |
+
state_dict, model_config = load_backbone_weights(pt_path)
|
| 230 |
+
assert model_config["d_model"] == cfg.d_model
|
| 231 |
+
for k in model1.state_dict():
|
| 232 |
+
torch.testing.assert_close(model1.state_dict()[k], state_dict[k])
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def test_is_legacy_checkpoint():
|
| 236 |
+
"""Detect legacy vs new format."""
|
| 237 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 238 |
+
pt = Path(tmpdir) / "step.pt"
|
| 239 |
+
pt.touch()
|
| 240 |
+
assert is_legacy_checkpoint(pt)
|
| 241 |
+
|
| 242 |
+
d = Path(tmpdir) / "step_00001"
|
| 243 |
+
d.mkdir()
|
| 244 |
+
assert not is_legacy_checkpoint(d)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def test_config_json_contents():
|
| 248 |
+
"""Verify config.json has expected structure."""
|
| 249 |
+
model, cfg = _make_toy_model()
|
| 250 |
+
opt = _make_optimizer(model)
|
| 251 |
+
|
| 252 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 253 |
+
ckpt_path = Path(tmpdir) / "step_00000001"
|
| 254 |
+
save_pretrain_checkpoint(
|
| 255 |
+
ckpt_path, model, opt, FakeScheduler(), FakeScaler(),
|
| 256 |
+
global_step=1, model_config=cfg.__dict__,
|
| 257 |
+
training_config={"lr": 1e-3, "batch_size": 32},
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
with open(ckpt_path / "config.json") as f:
|
| 261 |
+
config = json.load(f)
|
| 262 |
+
|
| 263 |
+
assert config["format_version"] == 1
|
| 264 |
+
assert config["checkpoint_type"] == "pretrain"
|
| 265 |
+
assert config["model_config"]["d_model"] == cfg.d_model
|
| 266 |
+
assert config["training_config"]["lr"] == 1e-3
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def test_complete_sentinel_written():
|
| 270 |
+
"""Verify .complete sentinel is created with SHA-256 hashes."""
|
| 271 |
+
model, cfg = _make_toy_model()
|
| 272 |
+
opt = _make_optimizer(model)
|
| 273 |
+
|
| 274 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 275 |
+
ckpt_path = Path(tmpdir) / "step_00000001"
|
| 276 |
+
save_pretrain_checkpoint(
|
| 277 |
+
ckpt_path, model, opt, FakeScheduler(), FakeScaler(),
|
| 278 |
+
global_step=1, model_config=cfg.__dict__, training_config={},
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
sentinel_path = ckpt_path / ".complete"
|
| 282 |
+
assert sentinel_path.exists()
|
| 283 |
+
|
| 284 |
+
with open(sentinel_path) as f:
|
| 285 |
+
sentinel = json.load(f)
|
| 286 |
+
|
| 287 |
+
assert "files" in sentinel
|
| 288 |
+
assert "model.safetensors" in sentinel["files"]
|
| 289 |
+
assert "config.json" in sentinel["files"]
|
| 290 |
+
assert "training_state.json" in sentinel["files"]
|
| 291 |
+
# Each hash should be a 64-char hex string
|
| 292 |
+
for name, h in sentinel["files"].items():
|
| 293 |
+
assert len(h) == 64, f"Bad hash length for {name}: {len(h)}"
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def test_no_tmp_directory_after_save():
|
| 297 |
+
"""Verify .tmp directory is cleaned up after successful save."""
|
| 298 |
+
model, cfg = _make_toy_model()
|
| 299 |
+
opt = _make_optimizer(model)
|
| 300 |
+
|
| 301 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 302 |
+
ckpt_path = Path(tmpdir) / "step_00000001"
|
| 303 |
+
save_pretrain_checkpoint(
|
| 304 |
+
ckpt_path, model, opt, FakeScheduler(), FakeScaler(),
|
| 305 |
+
global_step=1, model_config=cfg.__dict__, training_config={},
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
tmp_path = Path(tmpdir) / "step_00000001.tmp"
|
| 309 |
+
assert not tmp_path.exists(), ".tmp directory should be removed after rename"
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def test_incomplete_checkpoint_raises():
|
| 313 |
+
"""Loading a checkpoint without .complete raises IncompleteCheckpointError."""
|
| 314 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 315 |
+
ckpt_path = Path(tmpdir) / "step_bad"
|
| 316 |
+
ckpt_path.mkdir()
|
| 317 |
+
# Write some files but no .complete
|
| 318 |
+
(ckpt_path / "model.safetensors").touch()
|
| 319 |
+
(ckpt_path / "config.json").write_text("{}")
|
| 320 |
+
|
| 321 |
+
with pytest.raises(IncompleteCheckpointError):
|
| 322 |
+
_verify_complete_sentinel(ckpt_path)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def test_corrupted_file_raises():
|
| 326 |
+
"""Loading a checkpoint with a tampered file raises CheckpointIntegrityError."""
|
| 327 |
+
model, cfg = _make_toy_model()
|
| 328 |
+
opt = _make_optimizer(model)
|
| 329 |
+
|
| 330 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 331 |
+
ckpt_path = Path(tmpdir) / "step_00000001"
|
| 332 |
+
save_pretrain_checkpoint(
|
| 333 |
+
ckpt_path, model, opt, FakeScheduler(), FakeScaler(),
|
| 334 |
+
global_step=1, model_config=cfg.__dict__, training_config={},
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
# Corrupt a file after save
|
| 338 |
+
config_path = ckpt_path / "config.json"
|
| 339 |
+
config_path.write_text("CORRUPTED DATA")
|
| 340 |
+
|
| 341 |
+
with pytest.raises(CheckpointIntegrityError):
|
| 342 |
+
_verify_complete_sentinel(ckpt_path)
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def test_adapter_checkpoint_has_sentinel():
|
| 346 |
+
"""Adapter checkpoints also get .complete sentinel."""
|
| 347 |
+
adapter_weights = {"down.weight": torch.randn(8, 64)}
|
| 348 |
+
|
| 349 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 350 |
+
ckpt_path = Path(tmpdir) / "best"
|
| 351 |
+
save_adapter_checkpoint(
|
| 352 |
+
ckpt_path, adapter_weights,
|
| 353 |
+
config={"checkpoint_type": "bottleneck"},
|
| 354 |
+
epoch=1, step=100, val_metrics={"loss": 3.0},
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
assert (ckpt_path / ".complete").exists()
|
| 358 |
+
# Should not raise
|
| 359 |
+
_verify_complete_sentinel(ckpt_path)
|
tests/test_clm_format.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for CLM sequence format, off-by-one boundaries, and engine/Python consistency."""
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
import chess_engine as engine
|
| 8 |
+
|
| 9 |
+
from pawn.config import (
|
| 10 |
+
PAD_TOKEN,
|
| 11 |
+
OUTCOME_TOKEN_BASE,
|
| 12 |
+
WHITE_CHECKMATES,
|
| 13 |
+
BLACK_CHECKMATES,
|
| 14 |
+
STALEMATE,
|
| 15 |
+
DRAW_BY_RULE,
|
| 16 |
+
PLY_LIMIT,
|
| 17 |
+
)
|
| 18 |
+
from pawn.data import (
|
| 19 |
+
_map_termination_to_outcome,
|
| 20 |
+
pack_clm_sequences,
|
| 21 |
+
_to_clm_batch,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ---------------------------------------------------------------------------
|
| 26 |
+
# Vocabulary consistency
|
| 27 |
+
# ---------------------------------------------------------------------------
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class TestVocabConsistency:
|
| 31 |
+
def test_vocab_size(self):
|
| 32 |
+
"""Engine move vocab + PAD + outcomes = model vocab_size."""
|
| 33 |
+
vocab = engine.export_move_vocabulary()
|
| 34 |
+
n_moves = len(vocab["token_to_move"])
|
| 35 |
+
# 4096 grid + 176 promotions = 4272 move tokens
|
| 36 |
+
assert n_moves == 4272
|
| 37 |
+
# 1 PAD + 4272 moves + 5 outcomes = 4278
|
| 38 |
+
from pawn.config import CLMConfig
|
| 39 |
+
assert CLMConfig().vocab_size == 1 + n_moves + 5
|
| 40 |
+
|
| 41 |
+
def test_outcome_tokens_not_in_move_vocab(self):
|
| 42 |
+
"""Outcome tokens (4273-4277) must not appear in the move vocabulary."""
|
| 43 |
+
vocab = engine.export_move_vocabulary()
|
| 44 |
+
for token_id in range(OUTCOME_TOKEN_BASE, PLY_LIMIT + 1):
|
| 45 |
+
assert token_id not in vocab["token_to_move"], \
|
| 46 |
+
f"Outcome token {token_id} should not be in move vocab"
|
| 47 |
+
|
| 48 |
+
def test_no_eog_in_raw_move_ids(self):
|
| 49 |
+
"""Raw move_ids from generate_random_games should not contain
|
| 50 |
+
tokens >= OUTCOME_BASE (no EOG or outcome tokens in move data)."""
|
| 51 |
+
move_ids, game_lengths, _tc = engine.generate_random_games(64, 255, seed=42)
|
| 52 |
+
for b in range(64):
|
| 53 |
+
gl = int(game_lengths[b])
|
| 54 |
+
# Moves should be in range 1-4272
|
| 55 |
+
for t in range(gl):
|
| 56 |
+
tok = int(move_ids[b, t])
|
| 57 |
+
assert 1 <= tok <= 4272, \
|
| 58 |
+
f"Game {b}, ply {t}: expected move token, got {tok}"
|
| 59 |
+
# Position game_length should be PAD (0), not EOG
|
| 60 |
+
if gl < 255:
|
| 61 |
+
assert move_ids[b, gl] == PAD_TOKEN, \
|
| 62 |
+
f"Game {b}: position {gl} should be PAD, got {move_ids[b, gl]}"
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# ---------------------------------------------------------------------------
|
| 66 |
+
# CLM batch format (Rust engine)
|
| 67 |
+
# ---------------------------------------------------------------------------
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class TestRustCLMBatch:
|
| 71 |
+
@pytest.fixture
|
| 72 |
+
def clm_batch(self):
|
| 73 |
+
return engine.generate_clm_batch(32, 256, seed=42)
|
| 74 |
+
|
| 75 |
+
def test_shapes(self, clm_batch):
|
| 76 |
+
input_ids, targets, loss_mask, move_ids, game_lengths, term_codes = clm_batch
|
| 77 |
+
assert input_ids.shape == (32, 256)
|
| 78 |
+
assert targets.shape == (32, 256)
|
| 79 |
+
assert loss_mask.shape == (32, 256)
|
| 80 |
+
assert move_ids.shape == (32, 255)
|
| 81 |
+
assert game_lengths.shape == (32,)
|
| 82 |
+
assert term_codes.shape == (32,)
|
| 83 |
+
|
| 84 |
+
def test_position_zero_is_outcome(self, clm_batch):
|
| 85 |
+
input_ids, *_ = clm_batch
|
| 86 |
+
for b in range(input_ids.shape[0]):
|
| 87 |
+
tok = int(input_ids[b, 0])
|
| 88 |
+
assert OUTCOME_TOKEN_BASE <= tok <= PLY_LIMIT, \
|
| 89 |
+
f"Game {b}: position 0 should be outcome token, got {tok}"
|
| 90 |
+
|
| 91 |
+
def test_moves_in_valid_range(self, clm_batch):
|
| 92 |
+
input_ids, _, _, _, game_lengths, _ = clm_batch
|
| 93 |
+
for b in range(input_ids.shape[0]):
|
| 94 |
+
gl = int(game_lengths[b])
|
| 95 |
+
for t in range(1, gl + 1):
|
| 96 |
+
tok = int(input_ids[b, t])
|
| 97 |
+
assert 1 <= tok <= 4272, \
|
| 98 |
+
f"Game {b}, position {t}: expected move token, got {tok}"
|
| 99 |
+
|
| 100 |
+
def test_padding_is_zero(self, clm_batch):
|
| 101 |
+
input_ids, _, _, _, game_lengths, _ = clm_batch
|
| 102 |
+
for b in range(input_ids.shape[0]):
|
| 103 |
+
gl = int(game_lengths[b])
|
| 104 |
+
for t in range(gl + 1, 256):
|
| 105 |
+
assert input_ids[b, t] == 0, \
|
| 106 |
+
f"Game {b}, position {t}: expected PAD, got {input_ids[b, t]}"
|
| 107 |
+
|
| 108 |
+
def test_target_shift_correct(self, clm_batch):
|
| 109 |
+
input_ids, targets, *_ = clm_batch
|
| 110 |
+
B, T = input_ids.shape
|
| 111 |
+
for b in range(B):
|
| 112 |
+
for t in range(T - 1):
|
| 113 |
+
assert targets[b, t] == input_ids[b, t + 1], \
|
| 114 |
+
f"Game {b}: targets[{t}]={targets[b, t]} != input_ids[{t+1}]={input_ids[b, t+1]}"
|
| 115 |
+
assert targets[b, T - 1] == 0, "Last target should be PAD"
|
| 116 |
+
|
| 117 |
+
def test_target_at_game_end_is_pad(self, clm_batch):
|
| 118 |
+
_, targets, _, _, game_lengths, _ = clm_batch
|
| 119 |
+
for b in range(targets.shape[0]):
|
| 120 |
+
gl = int(game_lengths[b])
|
| 121 |
+
assert targets[b, gl] == 0, \
|
| 122 |
+
f"Game {b}: target at game_length={gl} should be PAD, got {targets[b, gl]}"
|
| 123 |
+
|
| 124 |
+
def test_loss_mask_boundary(self, clm_batch):
|
| 125 |
+
_, _, loss_mask, _, game_lengths, _ = clm_batch
|
| 126 |
+
for b in range(loss_mask.shape[0]):
|
| 127 |
+
gl = int(game_lengths[b])
|
| 128 |
+
# True for positions 0..=gl
|
| 129 |
+
for t in range(gl + 1):
|
| 130 |
+
assert loss_mask[b, t], \
|
| 131 |
+
f"Game {b}: loss_mask[{t}] should be True (gl={gl})"
|
| 132 |
+
# False after gl
|
| 133 |
+
for t in range(gl + 1, 256):
|
| 134 |
+
assert not loss_mask[b, t], \
|
| 135 |
+
f"Game {b}: loss_mask[{t}] should be False (gl={gl})"
|
| 136 |
+
|
| 137 |
+
def test_raw_move_ids_replayable(self, clm_batch):
|
| 138 |
+
"""Raw move_ids from generate_clm_batch should work with replay functions."""
|
| 139 |
+
_, _, _, move_ids, game_lengths, _ = clm_batch
|
| 140 |
+
# validate_games should confirm all games are legal
|
| 141 |
+
is_valid, first_illegal = engine.validate_games(move_ids, game_lengths)
|
| 142 |
+
assert all(is_valid), "All generated games should be valid"
|
| 143 |
+
# compute_legal_move_masks should not error
|
| 144 |
+
grid, promo = engine.compute_legal_move_masks(move_ids, game_lengths)
|
| 145 |
+
assert grid.shape[0] == 32
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# ---------------------------------------------------------------------------
|
| 149 |
+
# Rust CLM matches Python pack_clm_sequences
|
| 150 |
+
# ---------------------------------------------------------------------------
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class TestRustPythonConsistency:
|
| 154 |
+
def test_rust_clm_matches_python_pack(self):
|
| 155 |
+
"""Rust generate_clm_batch should produce identical output to
|
| 156 |
+
Python _to_clm_batch with the same seed."""
|
| 157 |
+
seq_len = 256
|
| 158 |
+
seed = 42
|
| 159 |
+
B = 16
|
| 160 |
+
|
| 161 |
+
# Rust path
|
| 162 |
+
r_input_ids, r_targets, r_loss_mask, r_move_ids, r_gl, r_tc = \
|
| 163 |
+
engine.generate_clm_batch(B, seq_len, seed)
|
| 164 |
+
|
| 165 |
+
# Python path: generate raw + pack
|
| 166 |
+
py_batch = _to_clm_batch(r_move_ids, r_gl, r_tc, seq_len)
|
| 167 |
+
|
| 168 |
+
# Compare
|
| 169 |
+
r_input_ids_t = torch.from_numpy(r_input_ids).long()
|
| 170 |
+
r_targets_t = torch.from_numpy(r_targets).long()
|
| 171 |
+
r_loss_mask_t = torch.from_numpy(r_loss_mask)
|
| 172 |
+
|
| 173 |
+
assert torch.equal(r_input_ids_t, py_batch["input_ids"]), \
|
| 174 |
+
"input_ids mismatch between Rust and Python"
|
| 175 |
+
assert torch.equal(r_targets_t, py_batch["targets"]), \
|
| 176 |
+
"targets mismatch between Rust and Python"
|
| 177 |
+
assert torch.equal(r_loss_mask_t, py_batch["loss_mask"]), \
|
| 178 |
+
"loss_mask mismatch between Rust and Python"
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
# ---------------------------------------------------------------------------
|
| 182 |
+
# Boundary / edge cases
|
| 183 |
+
# ---------------------------------------------------------------------------
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class TestBoundaryCases:
|
| 187 |
+
def test_boundary_max_length_game(self):
|
| 188 |
+
"""Test with games that may fill all 255 move slots.
|
| 189 |
+
|
| 190 |
+
We generate many games and check that any near-max-length game
|
| 191 |
+
is correctly formatted (no off-by-one at the boundary).
|
| 192 |
+
"""
|
| 193 |
+
input_ids, targets, loss_mask, _, game_lengths, _ = \
|
| 194 |
+
engine.generate_clm_batch(256, 256, seed=123)
|
| 195 |
+
|
| 196 |
+
for b in range(256):
|
| 197 |
+
gl = int(game_lengths[b])
|
| 198 |
+
# Regardless of length, format invariants hold
|
| 199 |
+
assert OUTCOME_TOKEN_BASE <= input_ids[b, 0] <= PLY_LIMIT
|
| 200 |
+
if gl + 1 < 256:
|
| 201 |
+
assert input_ids[b, gl + 1] == 0
|
| 202 |
+
assert loss_mask[b, gl] == True
|
| 203 |
+
if gl + 1 < 256:
|
| 204 |
+
assert loss_mask[b, gl + 1] == False
|
| 205 |
+
|
| 206 |
+
def test_discard_ply_limit(self):
|
| 207 |
+
"""generate_clm_batch with discard_ply_limit should only have
|
| 208 |
+
naturally-terminated games (no PLY_LIMIT outcome)."""
|
| 209 |
+
input_ids, _, _, _, _, term_codes = \
|
| 210 |
+
engine.generate_clm_batch(32, 256, seed=42, discard_ply_limit=True)
|
| 211 |
+
|
| 212 |
+
for b in range(32):
|
| 213 |
+
assert term_codes[b] != 5, \
|
| 214 |
+
f"Game {b} has PLY_LIMIT termination but discard_ply_limit=True"
|
| 215 |
+
assert input_ids[b, 0] != PLY_LIMIT, \
|
| 216 |
+
f"Game {b} has PLY_LIMIT outcome token but discard_ply_limit=True"
|
| 217 |
+
|
| 218 |
+
def test_determinism(self):
|
| 219 |
+
"""Same seed should produce identical results."""
|
| 220 |
+
r1 = engine.generate_clm_batch(8, 256, seed=99)
|
| 221 |
+
r2 = engine.generate_clm_batch(8, 256, seed=99)
|
| 222 |
+
assert np.array_equal(r1[0], r2[0]), "input_ids should be deterministic"
|
| 223 |
+
assert np.array_equal(r1[1], r2[1]), "targets should be deterministic"
|
| 224 |
+
assert np.array_equal(r1[2], r2[2]), "loss_mask should be deterministic"
|
tests/test_graceful_shutdown.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Regression test: training processes exit gracefully on SIGTERM.
|
| 2 |
+
|
| 3 |
+
Spawns a toy training run as a subprocess, sends SIGTERM, and verifies
|
| 4 |
+
that the process exits cleanly with valid checkpoints.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
import signal
|
| 10 |
+
import subprocess
|
| 11 |
+
import sys
|
| 12 |
+
import tempfile
|
| 13 |
+
import time
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
import pytest
|
| 17 |
+
|
| 18 |
+
from pawn.checkpoint import load_backbone_weights, _verify_complete_sentinel
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@pytest.fixture
|
| 22 |
+
def train_tmpdir():
|
| 23 |
+
with tempfile.TemporaryDirectory() as d:
|
| 24 |
+
yield Path(d)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _wait_for_training_start(proc: subprocess.Popen, timeout: float = 30) -> None:
|
| 28 |
+
"""Wait until the training process prints 'Starting training'."""
|
| 29 |
+
deadline = time.time() + timeout
|
| 30 |
+
assert proc.stdout is not None
|
| 31 |
+
while time.time() < deadline:
|
| 32 |
+
line = proc.stdout.readline()
|
| 33 |
+
if not line:
|
| 34 |
+
time.sleep(0.1)
|
| 35 |
+
continue
|
| 36 |
+
text = line.decode("utf-8", errors="replace")
|
| 37 |
+
if "Starting training" in text:
|
| 38 |
+
return
|
| 39 |
+
raise TimeoutError("Training process did not start within timeout")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def test_sigterm_produces_valid_checkpoint(train_tmpdir):
|
| 43 |
+
"""Spawn a toy training run, send SIGTERM, verify checkpoint is valid."""
|
| 44 |
+
ckpt_dir = train_tmpdir / "checkpoints"
|
| 45 |
+
log_dir = train_tmpdir / "logs"
|
| 46 |
+
|
| 47 |
+
proc = subprocess.Popen(
|
| 48 |
+
[
|
| 49 |
+
sys.executable, "scripts/train.py",
|
| 50 |
+
"--toy",
|
| 51 |
+
"--local-checkpoints",
|
| 52 |
+
"--total-steps", "5000",
|
| 53 |
+
"--device", "cpu",
|
| 54 |
+
"--num-workers", "0",
|
| 55 |
+
"--checkpoint-dir", str(ckpt_dir),
|
| 56 |
+
"--log-dir", str(log_dir),
|
| 57 |
+
],
|
| 58 |
+
stdout=subprocess.PIPE,
|
| 59 |
+
stderr=subprocess.STDOUT,
|
| 60 |
+
env={**os.environ, "PYTHONUNBUFFERED": "1"},
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
try:
|
| 64 |
+
_wait_for_training_start(proc)
|
| 65 |
+
|
| 66 |
+
# Let it train for a few checkpoints
|
| 67 |
+
time.sleep(5)
|
| 68 |
+
|
| 69 |
+
# Send SIGTERM
|
| 70 |
+
proc.send_signal(signal.SIGTERM)
|
| 71 |
+
|
| 72 |
+
# Wait for graceful exit
|
| 73 |
+
returncode = proc.wait(timeout=30)
|
| 74 |
+
except Exception:
|
| 75 |
+
proc.kill()
|
| 76 |
+
proc.wait()
|
| 77 |
+
raise
|
| 78 |
+
|
| 79 |
+
# Process should exit cleanly (0 from break, not 128+signal from sys.exit)
|
| 80 |
+
assert returncode == 0, f"Training exited with code {returncode}"
|
| 81 |
+
|
| 82 |
+
# Find checkpoints (trainer saves under {log_dir}/run_*/checkpoints/)
|
| 83 |
+
checkpoints = sorted(log_dir.glob("run_*/checkpoints/step_*/"))
|
| 84 |
+
assert len(checkpoints) > 0, "No checkpoints were saved"
|
| 85 |
+
|
| 86 |
+
# Every checkpoint must have .complete sentinel and pass integrity check
|
| 87 |
+
for ckpt in checkpoints:
|
| 88 |
+
assert (ckpt / ".complete").exists(), f"Missing .complete in {ckpt.name}"
|
| 89 |
+
_verify_complete_sentinel(ckpt) # raises on failure
|
| 90 |
+
|
| 91 |
+
# Verify we can actually load the weights
|
| 92 |
+
weights, config = load_backbone_weights(ckpt)
|
| 93 |
+
assert config is not None, f"No config in {ckpt.name}"
|
| 94 |
+
assert len(weights) > 0, f"Empty weights in {ckpt.name}"
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def test_sigterm_does_not_leave_tmp_dirs(train_tmpdir):
|
| 98 |
+
"""SIGTERM should not leave .tmp directories from interrupted saves."""
|
| 99 |
+
ckpt_dir = train_tmpdir / "checkpoints"
|
| 100 |
+
log_dir = train_tmpdir / "logs"
|
| 101 |
+
|
| 102 |
+
proc = subprocess.Popen(
|
| 103 |
+
[
|
| 104 |
+
sys.executable, "scripts/train.py",
|
| 105 |
+
"--toy",
|
| 106 |
+
"--local-checkpoints",
|
| 107 |
+
"--total-steps", "5000",
|
| 108 |
+
"--device", "cpu",
|
| 109 |
+
"--num-workers", "0",
|
| 110 |
+
"--checkpoint-dir", str(ckpt_dir),
|
| 111 |
+
"--log-dir", str(log_dir),
|
| 112 |
+
],
|
| 113 |
+
stdout=subprocess.PIPE,
|
| 114 |
+
stderr=subprocess.STDOUT,
|
| 115 |
+
env={**os.environ, "PYTHONUNBUFFERED": "1"},
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
try:
|
| 119 |
+
_wait_for_training_start(proc)
|
| 120 |
+
time.sleep(5)
|
| 121 |
+
proc.send_signal(signal.SIGTERM)
|
| 122 |
+
proc.wait(timeout=30)
|
| 123 |
+
except Exception:
|
| 124 |
+
proc.kill()
|
| 125 |
+
proc.wait()
|
| 126 |
+
raise
|
| 127 |
+
|
| 128 |
+
# No .tmp directories should exist anywhere under log_dir
|
| 129 |
+
tmp_dirs = list(log_dir.glob("**/*.tmp"))
|
| 130 |
+
assert len(tmp_dirs) == 0, f"Leftover .tmp directories: {tmp_dirs}"
|
uv.lock
CHANGED
|
@@ -20,24 +20,6 @@ members = [
|
|
| 20 |
"pawn",
|
| 21 |
]
|
| 22 |
|
| 23 |
-
[[package]]
|
| 24 |
-
name = "aiofiles"
|
| 25 |
-
version = "24.1.0"
|
| 26 |
-
source = { registry = "https://pypi.org/simple" }
|
| 27 |
-
sdist = { url = "https://files.pythonhosted.org/packages/0b/03/a88171e277e8caa88a4c77808c20ebb04ba74cc4681bf1e9416c862de237/aiofiles-24.1.0.tar.gz", hash = "sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c", size = 30247, upload-time = "2024-06-24T11:02:03.584Z" }
|
| 28 |
-
wheels = [
|
| 29 |
-
{ url = "https://files.pythonhosted.org/packages/a5/45/30bb92d442636f570cb5651bc661f52b610e2eec3f891a5dc3a4c3667db0/aiofiles-24.1.0-py3-none-any.whl", hash = "sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5", size = 15896, upload-time = "2024-06-24T11:02:01.529Z" },
|
| 30 |
-
]
|
| 31 |
-
|
| 32 |
-
[[package]]
|
| 33 |
-
name = "annotated-doc"
|
| 34 |
-
version = "0.0.4"
|
| 35 |
-
source = { registry = "https://pypi.org/simple" }
|
| 36 |
-
sdist = { url = "https://files.pythonhosted.org/packages/57/ba/046ceea27344560984e26a590f90bc7f4a75b06701f653222458922b558c/annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4", size = 7288, upload-time = "2025-11-10T22:07:42.062Z" }
|
| 37 |
-
wheels = [
|
| 38 |
-
{ url = "https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320", size = 5303, upload-time = "2025-11-10T22:07:40.673Z" },
|
| 39 |
-
]
|
| 40 |
-
|
| 41 |
[[package]]
|
| 42 |
name = "annotated-types"
|
| 43 |
version = "0.7.0"
|
|
@@ -93,44 +75,6 @@ wheels = [
|
|
| 93 |
{ url = "https://files.pythonhosted.org/packages/64/b4/17d4b0b2a2dc85a6df63d1157e028ed19f90d4cd97c36717afef2bc2f395/attrs-26.1.0-py3-none-any.whl", hash = "sha256:c647aa4a12dfbad9333ca4e71fe62ddc36f4e63b2d260a37a8b83d2f043ac309", size = 67548, upload-time = "2026-03-19T14:22:23.645Z" },
|
| 94 |
]
|
| 95 |
|
| 96 |
-
[[package]]
|
| 97 |
-
name = "brotli"
|
| 98 |
-
version = "1.2.0"
|
| 99 |
-
source = { registry = "https://pypi.org/simple" }
|
| 100 |
-
sdist = { url = "https://files.pythonhosted.org/packages/f7/16/c92ca344d646e71a43b8bb353f0a6490d7f6e06210f8554c8f874e454285/brotli-1.2.0.tar.gz", hash = "sha256:e310f77e41941c13340a95976fe66a8a95b01e783d430eeaf7a2f87e0a57dd0a", size = 7388632, upload-time = "2025-11-05T18:39:42.86Z" }
|
| 101 |
-
wheels = [
|
| 102 |
-
{ url = "https://files.pythonhosted.org/packages/64/10/a090475284fc4a71aed40a96f32e44a7fe5bda39687353dd977720b211b6/brotli-1.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3b90b767916ac44e93a8e28ce6adf8d551e43affb512f2377c732d486ac6514e", size = 863089, upload-time = "2025-11-05T18:38:01.181Z" },
|
| 103 |
-
{ url = "https://files.pythonhosted.org/packages/03/41/17416630e46c07ac21e378c3464815dd2e120b441e641bc516ac32cc51d2/brotli-1.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6be67c19e0b0c56365c6a76e393b932fb0e78b3b56b711d180dd7013cb1fd984", size = 445442, upload-time = "2025-11-05T18:38:02.434Z" },
|
| 104 |
-
{ url = "https://files.pythonhosted.org/packages/24/31/90cc06584deb5d4fcafc0985e37741fc6b9717926a78674bbb3ce018957e/brotli-1.2.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0bbd5b5ccd157ae7913750476d48099aaf507a79841c0d04a9db4415b14842de", size = 1532658, upload-time = "2025-11-05T18:38:03.588Z" },
|
| 105 |
-
{ url = "https://files.pythonhosted.org/packages/62/17/33bf0c83bcbc96756dfd712201d87342732fad70bb3472c27e833a44a4f9/brotli-1.2.0-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:3f3c908bcc404c90c77d5a073e55271a0a498f4e0756e48127c35d91cf155947", size = 1631241, upload-time = "2025-11-05T18:38:04.582Z" },
|
| 106 |
-
{ url = "https://files.pythonhosted.org/packages/48/10/f47854a1917b62efe29bc98ac18e5d4f71df03f629184575b862ef2e743b/brotli-1.2.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1b557b29782a643420e08d75aea889462a4a8796e9a6cf5621ab05a3f7da8ef2", size = 1424307, upload-time = "2025-11-05T18:38:05.587Z" },
|
| 107 |
-
{ url = "https://files.pythonhosted.org/packages/e4/b7/f88eb461719259c17483484ea8456925ee057897f8e64487d76e24e5e38d/brotli-1.2.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:81da1b229b1889f25adadc929aeb9dbc4e922bd18561b65b08dd9343cfccca84", size = 1488208, upload-time = "2025-11-05T18:38:06.613Z" },
|
| 108 |
-
{ url = "https://files.pythonhosted.org/packages/26/59/41bbcb983a0c48b0b8004203e74706c6b6e99a04f3c7ca6f4f41f364db50/brotli-1.2.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:ff09cd8c5eec3b9d02d2408db41be150d8891c5566addce57513bf546e3d6c6d", size = 1597574, upload-time = "2025-11-05T18:38:07.838Z" },
|
| 109 |
-
{ url = "https://files.pythonhosted.org/packages/8e/e6/8c89c3bdabbe802febb4c5c6ca224a395e97913b5df0dff11b54f23c1788/brotli-1.2.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:a1778532b978d2536e79c05dac2d8cd857f6c55cd0c95ace5b03740824e0e2f1", size = 1492109, upload-time = "2025-11-05T18:38:08.816Z" },
|
| 110 |
-
{ url = "https://files.pythonhosted.org/packages/ed/9a/4b19d4310b2dbd545c0c33f176b0528fa68c3cd0754e34b2f2bcf56548ae/brotli-1.2.0-cp310-cp310-win32.whl", hash = "sha256:b232029d100d393ae3c603c8ffd7e3fe6f798c5e28ddca5feabb8e8fdb732997", size = 334461, upload-time = "2025-11-05T18:38:10.729Z" },
|
| 111 |
-
{ url = "https://files.pythonhosted.org/packages/ac/39/70981d9f47705e3c2b95c0847dfa3e7a37aa3b7c6030aedc4873081ed005/brotli-1.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:ef87b8ab2704da227e83a246356a2b179ef826f550f794b2c52cddb4efbd0196", size = 369035, upload-time = "2025-11-05T18:38:11.827Z" },
|
| 112 |
-
{ url = "https://files.pythonhosted.org/packages/7a/ef/f285668811a9e1ddb47a18cb0b437d5fc2760d537a2fe8a57875ad6f8448/brotli-1.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:15b33fe93cedc4caaff8a0bd1eb7e3dab1c61bb22a0bf5bdfdfd97cd7da79744", size = 863110, upload-time = "2025-11-05T18:38:12.978Z" },
|
| 113 |
-
{ url = "https://files.pythonhosted.org/packages/50/62/a3b77593587010c789a9d6eaa527c79e0848b7b860402cc64bc0bc28a86c/brotli-1.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:898be2be399c221d2671d29eed26b6b2713a02c2119168ed914e7d00ceadb56f", size = 445438, upload-time = "2025-11-05T18:38:14.208Z" },
|
| 114 |
-
{ url = "https://files.pythonhosted.org/packages/cd/e1/7fadd47f40ce5549dc44493877db40292277db373da5053aff181656e16e/brotli-1.2.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:350c8348f0e76fff0a0fd6c26755d2653863279d086d3aa2c290a6a7251135dd", size = 1534420, upload-time = "2025-11-05T18:38:15.111Z" },
|
| 115 |
-
{ url = "https://files.pythonhosted.org/packages/12/8b/1ed2f64054a5a008a4ccd2f271dbba7a5fb1a3067a99f5ceadedd4c1d5a7/brotli-1.2.0-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2e1ad3fda65ae0d93fec742a128d72e145c9c7a99ee2fcd667785d99eb25a7fe", size = 1632619, upload-time = "2025-11-05T18:38:16.094Z" },
|
| 116 |
-
{ url = "https://files.pythonhosted.org/packages/89/5a/7071a621eb2d052d64efd5da2ef55ecdac7c3b0c6e4f9d519e9c66d987ef/brotli-1.2.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:40d918bce2b427a0c4ba189df7a006ac0c7277c180aee4617d99e9ccaaf59e6a", size = 1426014, upload-time = "2025-11-05T18:38:17.177Z" },
|
| 117 |
-
{ url = "https://files.pythonhosted.org/packages/26/6d/0971a8ea435af5156acaaccec1a505f981c9c80227633851f2810abd252a/brotli-1.2.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2a7f1d03727130fc875448b65b127a9ec5d06d19d0148e7554384229706f9d1b", size = 1489661, upload-time = "2025-11-05T18:38:18.41Z" },
|
| 118 |
-
{ url = "https://files.pythonhosted.org/packages/f3/75/c1baca8b4ec6c96a03ef8230fab2a785e35297632f402ebb1e78a1e39116/brotli-1.2.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:9c79f57faa25d97900bfb119480806d783fba83cd09ee0b33c17623935b05fa3", size = 1599150, upload-time = "2025-11-05T18:38:19.792Z" },
|
| 119 |
-
{ url = "https://files.pythonhosted.org/packages/0d/1a/23fcfee1c324fd48a63d7ebf4bac3a4115bdb1b00e600f80f727d850b1ae/brotli-1.2.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:844a8ceb8483fefafc412f85c14f2aae2fb69567bf2a0de53cdb88b73e7c43ae", size = 1493505, upload-time = "2025-11-05T18:38:20.913Z" },
|
| 120 |
-
{ url = "https://files.pythonhosted.org/packages/36/e5/12904bbd36afeef53d45a84881a4810ae8810ad7e328a971ebbfd760a0b3/brotli-1.2.0-cp311-cp311-win32.whl", hash = "sha256:aa47441fa3026543513139cb8926a92a8e305ee9c71a6209ef7a97d91640ea03", size = 334451, upload-time = "2025-11-05T18:38:21.94Z" },
|
| 121 |
-
{ url = "https://files.pythonhosted.org/packages/02/8b/ecb5761b989629a4758c394b9301607a5880de61ee2ee5fe104b87149ebc/brotli-1.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:022426c9e99fd65d9475dce5c195526f04bb8be8907607e27e747893f6ee3e24", size = 369035, upload-time = "2025-11-05T18:38:22.941Z" },
|
| 122 |
-
{ url = "https://files.pythonhosted.org/packages/11/ee/b0a11ab2315c69bb9b45a2aaed022499c9c24a205c3a49c3513b541a7967/brotli-1.2.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:35d382625778834a7f3061b15423919aa03e4f5da34ac8e02c074e4b75ab4f84", size = 861543, upload-time = "2025-11-05T18:38:24.183Z" },
|
| 123 |
-
{ url = "https://files.pythonhosted.org/packages/e1/2f/29c1459513cd35828e25531ebfcbf3e92a5e49f560b1777a9af7203eb46e/brotli-1.2.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7a61c06b334bd99bc5ae84f1eeb36bfe01400264b3c352f968c6e30a10f9d08b", size = 444288, upload-time = "2025-11-05T18:38:25.139Z" },
|
| 124 |
-
{ url = "https://files.pythonhosted.org/packages/3d/6f/feba03130d5fceadfa3a1bb102cb14650798c848b1df2a808356f939bb16/brotli-1.2.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:acec55bb7c90f1dfc476126f9711a8e81c9af7fb617409a9ee2953115343f08d", size = 1528071, upload-time = "2025-11-05T18:38:26.081Z" },
|
| 125 |
-
{ url = "https://files.pythonhosted.org/packages/2b/38/f3abb554eee089bd15471057ba85f47e53a44a462cfce265d9bf7088eb09/brotli-1.2.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:260d3692396e1895c5034f204f0db022c056f9e2ac841593a4cf9426e2a3faca", size = 1626913, upload-time = "2025-11-05T18:38:27.284Z" },
|
| 126 |
-
{ url = "https://files.pythonhosted.org/packages/03/a7/03aa61fbc3c5cbf99b44d158665f9b0dd3d8059be16c460208d9e385c837/brotli-1.2.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:072e7624b1fc4d601036ab3f4f27942ef772887e876beff0301d261210bca97f", size = 1419762, upload-time = "2025-11-05T18:38:28.295Z" },
|
| 127 |
-
{ url = "https://files.pythonhosted.org/packages/21/1b/0374a89ee27d152a5069c356c96b93afd1b94eae83f1e004b57eb6ce2f10/brotli-1.2.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:adedc4a67e15327dfdd04884873c6d5a01d3e3b6f61406f99b1ed4865a2f6d28", size = 1484494, upload-time = "2025-11-05T18:38:29.29Z" },
|
| 128 |
-
{ url = "https://files.pythonhosted.org/packages/cf/57/69d4fe84a67aef4f524dcd075c6eee868d7850e85bf01d778a857d8dbe0a/brotli-1.2.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:7a47ce5c2288702e09dc22a44d0ee6152f2c7eda97b3c8482d826a1f3cfc7da7", size = 1593302, upload-time = "2025-11-05T18:38:30.639Z" },
|
| 129 |
-
{ url = "https://files.pythonhosted.org/packages/d5/3b/39e13ce78a8e9a621c5df3aeb5fd181fcc8caba8c48a194cd629771f6828/brotli-1.2.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:af43b8711a8264bb4e7d6d9a6d004c3a2019c04c01127a868709ec29962b6036", size = 1487913, upload-time = "2025-11-05T18:38:31.618Z" },
|
| 130 |
-
{ url = "https://files.pythonhosted.org/packages/62/28/4d00cb9bd76a6357a66fcd54b4b6d70288385584063f4b07884c1e7286ac/brotli-1.2.0-cp312-cp312-win32.whl", hash = "sha256:e99befa0b48f3cd293dafeacdd0d191804d105d279e0b387a32054c1180f3161", size = 334362, upload-time = "2025-11-05T18:38:32.939Z" },
|
| 131 |
-
{ url = "https://files.pythonhosted.org/packages/1c/4e/bc1dcac9498859d5e353c9b153627a3752868a9d5f05ce8dedd81a2354ab/brotli-1.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:b35c13ce241abdd44cb8ca70683f20c0c079728a36a996297adb5334adfc1c44", size = 369115, upload-time = "2025-11-05T18:38:33.765Z" },
|
| 132 |
-
]
|
| 133 |
-
|
| 134 |
[[package]]
|
| 135 |
name = "cachetools"
|
| 136 |
version = "7.0.5"
|
|
@@ -462,22 +406,6 @@ wheels = [
|
|
| 462 |
{ url = "https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017", size = 28317, upload-time = "2025-09-01T09:48:08.5Z" },
|
| 463 |
]
|
| 464 |
|
| 465 |
-
[[package]]
|
| 466 |
-
name = "fastapi"
|
| 467 |
-
version = "0.135.1"
|
| 468 |
-
source = { registry = "https://pypi.org/simple" }
|
| 469 |
-
dependencies = [
|
| 470 |
-
{ name = "annotated-doc", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 471 |
-
{ name = "pydantic", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 472 |
-
{ name = "starlette", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 473 |
-
{ name = "typing-extensions", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 474 |
-
{ name = "typing-inspection", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 475 |
-
]
|
| 476 |
-
sdist = { url = "https://files.pythonhosted.org/packages/e7/7b/f8e0211e9380f7195ba3f3d40c292594fd81ba8ec4629e3854c353aaca45/fastapi-0.135.1.tar.gz", hash = "sha256:d04115b508d936d254cea545b7312ecaa58a7b3a0f84952535b4c9afae7668cd", size = 394962, upload-time = "2026-03-01T18:18:29.369Z" }
|
| 477 |
-
wheels = [
|
| 478 |
-
{ url = "https://files.pythonhosted.org/packages/e4/72/42e900510195b23a56bde950d26a51f8b723846bfcaa0286e90287f0422b/fastapi-0.135.1-py3-none-any.whl", hash = "sha256:46e2fc5745924b7c840f71ddd277382af29ce1cdb7d5eab5bf697e3fb9999c9e", size = 116999, upload-time = "2026-03-01T18:18:30.831Z" },
|
| 479 |
-
]
|
| 480 |
-
|
| 481 |
[[package]]
|
| 482 |
name = "fastjsonschema"
|
| 483 |
version = "2.21.2"
|
|
@@ -487,15 +415,6 @@ wheels = [
|
|
| 487 |
{ url = "https://files.pythonhosted.org/packages/cb/a8/20d0723294217e47de6d9e2e40fd4a9d2f7c4b6ef974babd482a59743694/fastjsonschema-2.21.2-py3-none-any.whl", hash = "sha256:1c797122d0a86c5cace2e54bf4e819c36223b552017172f32c5c024a6b77e463", size = 24024, upload-time = "2025-08-14T18:49:34.776Z" },
|
| 488 |
]
|
| 489 |
|
| 490 |
-
[[package]]
|
| 491 |
-
name = "ffmpy"
|
| 492 |
-
version = "1.0.0"
|
| 493 |
-
source = { registry = "https://pypi.org/simple" }
|
| 494 |
-
sdist = { url = "https://files.pythonhosted.org/packages/7d/d2/1c4c582d71bcc65c76fa69fab85de6257d50fdf6fd4a2317c53917e9a581/ffmpy-1.0.0.tar.gz", hash = "sha256:b12932e95435c8820f1cd041024402765f821971e4bae753b327fc02a6e12f8b", size = 5101, upload-time = "2025-11-11T06:24:23.856Z" }
|
| 495 |
-
wheels = [
|
| 496 |
-
{ url = "https://files.pythonhosted.org/packages/55/56/dd3669eccebb6d8ac81e624542ebd53fe6f08e1b8f2f8d50aeb7e3b83f99/ffmpy-1.0.0-py3-none-any.whl", hash = "sha256:5640e5f0fd03fb6236d0e119b16ccf6522db1c826fdf35dcb87087b60fd7504f", size = 5614, upload-time = "2025-11-11T06:24:22.818Z" },
|
| 497 |
-
]
|
| 498 |
-
|
| 499 |
[[package]]
|
| 500 |
name = "filelock"
|
| 501 |
version = "3.25.2"
|
|
@@ -571,71 +490,6 @@ wheels = [
|
|
| 571 |
{ url = "https://files.pythonhosted.org/packages/6a/09/e21df6aef1e1ffc0c816f0522ddc3f6dcded766c3261813131c78a704470/gitpython-3.1.46-py3-none-any.whl", hash = "sha256:79812ed143d9d25b6d176a10bb511de0f9c67b1fa641d82097b0ab90398a2058", size = 208620, upload-time = "2026-01-01T15:37:30.574Z" },
|
| 572 |
]
|
| 573 |
|
| 574 |
-
[[package]]
|
| 575 |
-
name = "gradio"
|
| 576 |
-
version = "6.9.0"
|
| 577 |
-
source = { registry = "https://pypi.org/simple" }
|
| 578 |
-
dependencies = [
|
| 579 |
-
{ name = "aiofiles", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 580 |
-
{ name = "anyio", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 581 |
-
{ name = "brotli", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 582 |
-
{ name = "fastapi", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 583 |
-
{ name = "ffmpy", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 584 |
-
{ name = "gradio-client", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 585 |
-
{ name = "groovy", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 586 |
-
{ name = "httpx", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 587 |
-
{ name = "huggingface-hub", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 588 |
-
{ name = "jinja2", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 589 |
-
{ name = "markupsafe", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 590 |
-
{ name = "numpy", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 591 |
-
{ name = "orjson", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 592 |
-
{ name = "packaging", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 593 |
-
{ name = "pandas", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and sys_platform == 'linux') or (python_full_version >= '3.11' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm') or (sys_platform != 'linux' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 594 |
-
{ name = "pandas", version = "3.0.1", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform == 'linux') or (python_full_version < '3.11' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm') or (sys_platform != 'linux' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 595 |
-
{ name = "pillow", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 596 |
-
{ name = "pydantic", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 597 |
-
{ name = "pydub", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 598 |
-
{ name = "python-multipart", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 599 |
-
{ name = "pytz", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 600 |
-
{ name = "pyyaml", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 601 |
-
{ name = "safehttpx", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 602 |
-
{ name = "semantic-version", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 603 |
-
{ name = "starlette", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 604 |
-
{ name = "tomlkit", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 605 |
-
{ name = "typer", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 606 |
-
{ name = "typing-extensions", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 607 |
-
{ name = "uvicorn", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 608 |
-
]
|
| 609 |
-
sdist = { url = "https://files.pythonhosted.org/packages/bd/83/29bdbf94b212512e3c775482d390f5b699a72d71a2c431dea367a6e45a37/gradio-6.9.0.tar.gz", hash = "sha256:593e60e33233f3586452ebfa9f741817c5ae849a98cc70945f3ccb8dc895eb22", size = 57904480, upload-time = "2026-03-06T17:44:26.025Z" }
|
| 610 |
-
wheels = [
|
| 611 |
-
{ url = "https://files.pythonhosted.org/packages/b3/8b/dc357ab966544e4dc898a2fee326d755c5f54da82af71a1a802e3476e78e/gradio-6.9.0-py3-none-any.whl", hash = "sha256:c173dd330c9247002a42222c85d76c0ecee65437eff808084e360862e7bbd24f", size = 42940853, upload-time = "2026-03-06T17:44:22.009Z" },
|
| 612 |
-
]
|
| 613 |
-
|
| 614 |
-
[[package]]
|
| 615 |
-
name = "gradio-client"
|
| 616 |
-
version = "2.3.0"
|
| 617 |
-
source = { registry = "https://pypi.org/simple" }
|
| 618 |
-
dependencies = [
|
| 619 |
-
{ name = "fsspec", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 620 |
-
{ name = "httpx", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 621 |
-
{ name = "huggingface-hub", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 622 |
-
{ name = "packaging", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 623 |
-
{ name = "typing-extensions", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 624 |
-
]
|
| 625 |
-
sdist = { url = "https://files.pythonhosted.org/packages/97/d2/de2037f5eff13a5145cdf6982fd34c9735f0806e8a2ee5d4bfe9a7d25a54/gradio_client-2.3.0.tar.gz", hash = "sha256:1c700dc60e65bae4386ba7cf3732b9f9d5bcf5fb8eb451df3944fe092d7d9a29", size = 57552, upload-time = "2026-03-06T17:44:38.247Z" }
|
| 626 |
-
wheels = [
|
| 627 |
-
{ url = "https://files.pythonhosted.org/packages/99/6a/41752781399811afbf8ac858f63c20eff354ed35169daa39604aefced4e8/gradio_client-2.3.0-py3-none-any.whl", hash = "sha256:9ec51a927888fc188e123a0ac5ad341d9265b325539a399554d1fc2604942e74", size = 58531, upload-time = "2026-03-06T17:44:36.961Z" },
|
| 628 |
-
]
|
| 629 |
-
|
| 630 |
-
[[package]]
|
| 631 |
-
name = "groovy"
|
| 632 |
-
version = "0.1.2"
|
| 633 |
-
source = { registry = "https://pypi.org/simple" }
|
| 634 |
-
sdist = { url = "https://files.pythonhosted.org/packages/52/36/bbdede67400277bef33d3ec0e6a31750da972c469f75966b4930c753218f/groovy-0.1.2.tar.gz", hash = "sha256:25c1dc09b3f9d7e292458aa762c6beb96ea037071bf5e917fc81fb78d2231083", size = 17325, upload-time = "2025-02-28T20:24:56.068Z" }
|
| 635 |
-
wheels = [
|
| 636 |
-
{ url = "https://files.pythonhosted.org/packages/28/27/3d6dcadc8a3214d8522c1e7f6a19554e33659be44546d44a2f7572ac7d2a/groovy-0.1.2-py3-none-any.whl", hash = "sha256:7f7975bab18c729a257a8b1ae9dcd70b7cafb1720481beae47719af57c35fa64", size = 14090, upload-time = "2025-02-28T20:24:55.152Z" },
|
| 637 |
-
]
|
| 638 |
-
|
| 639 |
[[package]]
|
| 640 |
name = "h11"
|
| 641 |
version = "0.16.0"
|
|
@@ -645,70 +499,6 @@ wheels = [
|
|
| 645 |
{ url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" },
|
| 646 |
]
|
| 647 |
|
| 648 |
-
[[package]]
|
| 649 |
-
name = "hf-xet"
|
| 650 |
-
version = "1.4.2"
|
| 651 |
-
source = { registry = "https://pypi.org/simple" }
|
| 652 |
-
sdist = { url = "https://files.pythonhosted.org/packages/09/08/23c84a26716382c89151b5b447b4beb19e3345f3a93d3b73009a71a57ad3/hf_xet-1.4.2.tar.gz", hash = "sha256:b7457b6b482d9e0743bd116363239b1fa904a5e65deede350fbc0c4ea67c71ea", size = 672357, upload-time = "2026-03-13T06:58:51.077Z" }
|
| 653 |
-
wheels = [
|
| 654 |
-
{ url = "https://files.pythonhosted.org/packages/b4/86/b40b83a2ff03ef05c4478d2672b1fc2b9683ff870e2b25f4f3af240f2e7b/hf_xet-1.4.2-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:71f02d6e4cdd07f344f6844845d78518cc7186bd2bc52d37c3b73dc26a3b0bc5", size = 3800339, upload-time = "2026-03-13T06:58:36.245Z" },
|
| 655 |
-
{ url = "https://files.pythonhosted.org/packages/64/2e/af4475c32b4378b0e92a587adb1aa3ec53e3450fd3e5fe0372a874531c00/hf_xet-1.4.2-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:e9b38d876e94d4bdcf650778d6ebbaa791dd28de08db9736c43faff06ede1b5a", size = 3559664, upload-time = "2026-03-13T06:58:34.787Z" },
|
| 656 |
-
{ url = "https://files.pythonhosted.org/packages/3c/4c/781267da3188db679e601de18112021a5cb16506fe86b246e22c5401a9c4/hf_xet-1.4.2-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:77e8c180b7ef12d8a96739a4e1e558847002afe9ea63b6f6358b2271a8bdda1c", size = 4217422, upload-time = "2026-03-13T06:58:27.472Z" },
|
| 657 |
-
{ url = "https://files.pythonhosted.org/packages/68/47/d6cf4a39ecf6c7705f887a46f6ef5c8455b44ad9eb0d391aa7e8a2ff7fea/hf_xet-1.4.2-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:c3b3c6a882016b94b6c210957502ff7877802d0dbda8ad142c8595db8b944271", size = 3992847, upload-time = "2026-03-13T06:58:25.989Z" },
|
| 658 |
-
{ url = "https://files.pythonhosted.org/packages/2d/ef/e80815061abff54697239803948abc665c6b1d237102c174f4f7a9a5ffc5/hf_xet-1.4.2-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9d9a634cc929cfbaf2e1a50c0e532ae8c78fa98618426769480c58501e8c8ac2", size = 4193843, upload-time = "2026-03-13T06:58:44.59Z" },
|
| 659 |
-
{ url = "https://files.pythonhosted.org/packages/54/75/07f6aa680575d9646c4167db6407c41340cbe2357f5654c4e72a1b01ca14/hf_xet-1.4.2-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6b0932eb8b10317ea78b7da6bab172b17be03bbcd7809383d8d5abd6a2233e04", size = 4432751, upload-time = "2026-03-13T06:58:46.533Z" },
|
| 660 |
-
{ url = "https://files.pythonhosted.org/packages/cd/71/193eabd7e7d4b903c4aa983a215509c6114915a5a237525ec562baddb868/hf_xet-1.4.2-cp37-abi3-win_amd64.whl", hash = "sha256:ad185719fb2e8ac26f88c8100562dbf9dbdcc3d9d2add00faa94b5f106aea53f", size = 3671149, upload-time = "2026-03-13T06:58:57.07Z" },
|
| 661 |
-
{ url = "https://files.pythonhosted.org/packages/b4/7e/ccf239da366b37ba7f0b36095450efae4a64980bdc7ec2f51354205fdf39/hf_xet-1.4.2-cp37-abi3-win_arm64.whl", hash = "sha256:32c012286b581f783653e718c1862aea5b9eb140631685bb0c5e7012c8719a87", size = 3533426, upload-time = "2026-03-13T06:58:55.46Z" },
|
| 662 |
-
]
|
| 663 |
-
|
| 664 |
-
[[package]]
|
| 665 |
-
name = "httpcore"
|
| 666 |
-
version = "1.0.9"
|
| 667 |
-
source = { registry = "https://pypi.org/simple" }
|
| 668 |
-
dependencies = [
|
| 669 |
-
{ name = "certifi", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 670 |
-
{ name = "h11", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 671 |
-
]
|
| 672 |
-
sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" }
|
| 673 |
-
wheels = [
|
| 674 |
-
{ url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" },
|
| 675 |
-
]
|
| 676 |
-
|
| 677 |
-
[[package]]
|
| 678 |
-
name = "httpx"
|
| 679 |
-
version = "0.28.1"
|
| 680 |
-
source = { registry = "https://pypi.org/simple" }
|
| 681 |
-
dependencies = [
|
| 682 |
-
{ name = "anyio", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 683 |
-
{ name = "certifi", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 684 |
-
{ name = "httpcore", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 685 |
-
{ name = "idna", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 686 |
-
]
|
| 687 |
-
sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" }
|
| 688 |
-
wheels = [
|
| 689 |
-
{ url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" },
|
| 690 |
-
]
|
| 691 |
-
|
| 692 |
-
[[package]]
|
| 693 |
-
name = "huggingface-hub"
|
| 694 |
-
version = "1.7.2"
|
| 695 |
-
source = { registry = "https://pypi.org/simple" }
|
| 696 |
-
dependencies = [
|
| 697 |
-
{ name = "filelock", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 698 |
-
{ name = "fsspec", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 699 |
-
{ name = "hf-xet", marker = "(platform_machine != 'AMD64' and platform_machine != 'aarch64' and platform_machine != 'amd64' and platform_machine != 'arm64' and platform_machine != 'x86_64' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'amd64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 700 |
-
{ name = "httpx", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 701 |
-
{ name = "packaging", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 702 |
-
{ name = "pyyaml", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 703 |
-
{ name = "tqdm", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 704 |
-
{ name = "typer", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 705 |
-
{ name = "typing-extensions", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 706 |
-
]
|
| 707 |
-
sdist = { url = "https://files.pythonhosted.org/packages/19/15/eafc1c57bf0f8afffb243dcd4c0cceb785e956acc17bba4d9bf2ae21fc9c/huggingface_hub-1.7.2.tar.gz", hash = "sha256:7f7e294e9bbb822e025bdb2ada025fa4344d978175a7f78e824d86e35f7ab43b", size = 724684, upload-time = "2026-03-20T10:36:08.767Z" }
|
| 708 |
-
wheels = [
|
| 709 |
-
{ url = "https://files.pythonhosted.org/packages/08/de/3ad061a05f74728927ded48c90b73521b9a9328c85d841bdefb30e01fb85/huggingface_hub-1.7.2-py3-none-any.whl", hash = "sha256:288f33a0a17b2a73a1359e2a5fd28d1becb2c121748c6173ab8643fb342c850e", size = 618036, upload-time = "2026-03-20T10:36:06.824Z" },
|
| 710 |
-
]
|
| 711 |
-
|
| 712 |
[[package]]
|
| 713 |
name = "humanize"
|
| 714 |
version = "4.15.0"
|
|
@@ -1439,57 +1229,6 @@ wheels = [
|
|
| 1439 |
{ url = "https://files.pythonhosted.org/packages/9f/99/4c9c0c329bf9fc125008c3b54c7c94c0023518d06fc025ae36431375e1fe/nvidia_nvtx_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:619c8304aedc69f02ea82dd244541a83c3d9d40993381b3b590f1adaed3db41e", size = 56492, upload-time = "2025-03-07T01:52:24.69Z" },
|
| 1440 |
]
|
| 1441 |
|
| 1442 |
-
[[package]]
|
| 1443 |
-
name = "orjson"
|
| 1444 |
-
version = "3.11.7"
|
| 1445 |
-
source = { registry = "https://pypi.org/simple" }
|
| 1446 |
-
sdist = { url = "https://files.pythonhosted.org/packages/53/45/b268004f745ede84e5798b48ee12b05129d19235d0e15267aa57dcdb400b/orjson-3.11.7.tar.gz", hash = "sha256:9b1a67243945819ce55d24a30b59d6a168e86220452d2c96f4d1f093e71c0c49", size = 6144992, upload-time = "2026-02-02T15:38:49.29Z" }
|
| 1447 |
-
wheels = [
|
| 1448 |
-
{ url = "https://files.pythonhosted.org/packages/de/1a/a373746fa6d0e116dd9e54371a7b54622c44d12296d5d0f3ad5e3ff33490/orjson-3.11.7-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:a02c833f38f36546ba65a452127633afce4cf0dd7296b753d3bb54e55e5c0174", size = 229140, upload-time = "2026-02-02T15:37:06.082Z" },
|
| 1449 |
-
{ url = "https://files.pythonhosted.org/packages/52/a2/fa129e749d500f9b183e8a3446a193818a25f60261e9ce143ad61e975208/orjson-3.11.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b63c6e6738d7c3470ad01601e23376aa511e50e1f3931395b9f9c722406d1a67", size = 128670, upload-time = "2026-02-02T15:37:08.002Z" },
|
| 1450 |
-
{ url = "https://files.pythonhosted.org/packages/08/93/1e82011cd1e0bd051ef9d35bed1aa7fb4ea1f0a055dc2c841b46b43a9ebd/orjson-3.11.7-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:043d3006b7d32c7e233b8cfb1f01c651013ea079e08dcef7189a29abd8befe11", size = 123832, upload-time = "2026-02-02T15:37:09.191Z" },
|
| 1451 |
-
{ url = "https://files.pythonhosted.org/packages/fe/d8/a26b431ef962c7d55736674dddade876822f3e33223c1f47a36879350d04/orjson-3.11.7-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57036b27ac8a25d81112eb0cc9835cd4833c5b16e1467816adc0015f59e870dc", size = 129171, upload-time = "2026-02-02T15:37:11.112Z" },
|
| 1452 |
-
{ url = "https://files.pythonhosted.org/packages/a7/19/f47819b84a580f490da260c3ee9ade214cf4cf78ac9ce8c1c758f80fdfc9/orjson-3.11.7-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:733ae23ada68b804b222c44affed76b39e30806d38660bf1eb200520d259cc16", size = 141967, upload-time = "2026-02-02T15:37:12.282Z" },
|
| 1453 |
-
{ url = "https://files.pythonhosted.org/packages/5b/cd/37ece39a0777ba077fdcdbe4cccae3be8ed00290c14bf8afdc548befc260/orjson-3.11.7-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5fdfad2093bdd08245f2e204d977facd5f871c88c4a71230d5bcbd0e43bf6222", size = 130991, upload-time = "2026-02-02T15:37:13.465Z" },
|
| 1454 |
-
{ url = "https://files.pythonhosted.org/packages/8f/ed/f2b5d66aa9b6b5c02ff5f120efc7b38c7c4962b21e6be0f00fd99a5c348e/orjson-3.11.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cededd6738e1c153530793998e31c05086582b08315db48ab66649768f326baa", size = 133674, upload-time = "2026-02-02T15:37:14.694Z" },
|
| 1455 |
-
{ url = "https://files.pythonhosted.org/packages/c4/6e/baa83e68d1aa09fa8c3e5b2c087d01d0a0bd45256de719ed7bc22c07052d/orjson-3.11.7-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:14f440c7268c8f8633d1b3d443a434bd70cb15686117ea6beff8fdc8f5917a1e", size = 138722, upload-time = "2026-02-02T15:37:16.501Z" },
|
| 1456 |
-
{ url = "https://files.pythonhosted.org/packages/0c/47/7f8ef4963b772cd56999b535e553f7eb5cd27e9dd6c049baee6f18bfa05d/orjson-3.11.7-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:3a2479753bbb95b0ebcf7969f562cdb9668e6d12416a35b0dda79febf89cdea2", size = 409056, upload-time = "2026-02-02T15:37:17.895Z" },
|
| 1457 |
-
{ url = "https://files.pythonhosted.org/packages/38/eb/2df104dd2244b3618f25325a656f85cc3277f74bbd91224752410a78f3c7/orjson-3.11.7-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:71924496986275a737f38e3f22b4e0878882b3f7a310d2ff4dc96e812789120c", size = 144196, upload-time = "2026-02-02T15:37:19.349Z" },
|
| 1458 |
-
{ url = "https://files.pythonhosted.org/packages/b6/2a/ee41de0aa3a6686598661eae2b4ebdff1340c65bfb17fcff8b87138aab21/orjson-3.11.7-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b4a9eefdc70bf8bf9857f0290f973dec534ac84c35cd6a7f4083be43e7170a8f", size = 134979, upload-time = "2026-02-02T15:37:20.906Z" },
|
| 1459 |
-
{ url = "https://files.pythonhosted.org/packages/4c/fa/92fc5d3d402b87a8b28277a9ed35386218a6a5287c7fe5ee9b9f02c53fb2/orjson-3.11.7-cp310-cp310-win32.whl", hash = "sha256:ae9e0b37a834cef7ce8f99de6498f8fad4a2c0bf6bfc3d02abd8ed56aa15b2de", size = 127968, upload-time = "2026-02-02T15:37:23.178Z" },
|
| 1460 |
-
{ url = "https://files.pythonhosted.org/packages/07/29/a576bf36d73d60df06904d3844a9df08e25d59eba64363aaf8ec2f9bff41/orjson-3.11.7-cp310-cp310-win_amd64.whl", hash = "sha256:d772afdb22555f0c58cfc741bdae44180122b3616faa1ecadb595cd526e4c993", size = 125128, upload-time = "2026-02-02T15:37:24.329Z" },
|
| 1461 |
-
{ url = "https://files.pythonhosted.org/packages/37/02/da6cb01fc6087048d7f61522c327edf4250f1683a58a839fdcc435746dd5/orjson-3.11.7-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:9487abc2c2086e7c8eb9a211d2ce8855bae0e92586279d0d27b341d5ad76c85c", size = 228664, upload-time = "2026-02-02T15:37:25.542Z" },
|
| 1462 |
-
{ url = "https://files.pythonhosted.org/packages/c1/c2/5885e7a5881dba9a9af51bc564e8967225a642b3e03d089289a35054e749/orjson-3.11.7-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:79cacb0b52f6004caf92405a7e1f11e6e2de8bdf9019e4f76b44ba045125cd6b", size = 125344, upload-time = "2026-02-02T15:37:26.92Z" },
|
| 1463 |
-
{ url = "https://files.pythonhosted.org/packages/a4/1d/4e7688de0a92d1caf600dfd5fb70b4c5bfff51dfa61ac555072ef2d0d32a/orjson-3.11.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c2e85fe4698b6a56d5e2ebf7ae87544d668eb6bde1ad1226c13f44663f20ec9e", size = 128404, upload-time = "2026-02-02T15:37:28.108Z" },
|
| 1464 |
-
{ url = "https://files.pythonhosted.org/packages/2f/b2/ec04b74ae03a125db7bd69cffd014b227b7f341e3261bf75b5eb88a1aa92/orjson-3.11.7-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b8d14b71c0b12963fe8a62aac87119f1afdf4cb88a400f61ca5ae581449efcb5", size = 123677, upload-time = "2026-02-02T15:37:30.287Z" },
|
| 1465 |
-
{ url = "https://files.pythonhosted.org/packages/4c/69/f95bdf960605f08f827f6e3291fe243d8aa9c5c9ff017a8d7232209184c3/orjson-3.11.7-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:91c81ef070c8f3220054115e1ef468b1c9ce8497b4e526cb9f68ab4dc0a7ac62", size = 128950, upload-time = "2026-02-02T15:37:31.595Z" },
|
| 1466 |
-
{ url = "https://files.pythonhosted.org/packages/a4/1b/de59c57bae1d148ef298852abd31909ac3089cff370dfd4cd84cc99cbc42/orjson-3.11.7-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:411ebaf34d735e25e358a6d9e7978954a9c9d58cfb47bc6683cdc3964cd2f910", size = 141756, upload-time = "2026-02-02T15:37:32.985Z" },
|
| 1467 |
-
{ url = "https://files.pythonhosted.org/packages/ee/9e/9decc59f4499f695f65c650f6cfa6cd4c37a3fbe8fa235a0a3614cb54386/orjson-3.11.7-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a16bcd08ab0bcdfc7e8801d9c4a9cc17e58418e4d48ddc6ded4e9e4b1a94062b", size = 130812, upload-time = "2026-02-02T15:37:34.204Z" },
|
| 1468 |
-
{ url = "https://files.pythonhosted.org/packages/28/e6/59f932bcabd1eac44e334fe8e3281a92eacfcb450586e1f4bde0423728d8/orjson-3.11.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c0b51672e466fd7e56230ffbae7f1639e18d0ce023351fb75da21b71bc2c960", size = 133444, upload-time = "2026-02-02T15:37:35.446Z" },
|
| 1469 |
-
{ url = "https://files.pythonhosted.org/packages/f1/36/b0f05c0eaa7ca30bc965e37e6a2956b0d67adb87a9872942d3568da846ae/orjson-3.11.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:136dcd6a2e796dfd9ffca9fc027d778567b0b7c9968d092842d3c323cef88aa8", size = 138609, upload-time = "2026-02-02T15:37:36.657Z" },
|
| 1470 |
-
{ url = "https://files.pythonhosted.org/packages/b8/03/58ec7d302b8d86944c60c7b4b82975d5161fcce4c9bc8c6cb1d6741b6115/orjson-3.11.7-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:7ba61079379b0ae29e117db13bda5f28d939766e410d321ec1624afc6a0b0504", size = 408918, upload-time = "2026-02-02T15:37:38.076Z" },
|
| 1471 |
-
{ url = "https://files.pythonhosted.org/packages/06/3a/868d65ef9a8b99be723bd510de491349618abd9f62c826cf206d962db295/orjson-3.11.7-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:0527a4510c300e3b406591b0ba69b5dc50031895b0a93743526a3fc45f59d26e", size = 143998, upload-time = "2026-02-02T15:37:39.706Z" },
|
| 1472 |
-
{ url = "https://files.pythonhosted.org/packages/5b/c7/1e18e1c83afe3349f4f6dc9e14910f0ae5f82eac756d1412ea4018938535/orjson-3.11.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a709e881723c9b18acddcfb8ba357322491ad553e277cf467e1e7e20e2d90561", size = 134802, upload-time = "2026-02-02T15:37:41.002Z" },
|
| 1473 |
-
{ url = "https://files.pythonhosted.org/packages/d4/0b/ccb7ee1a65b37e8eeb8b267dc953561d72370e85185e459616d4345bab34/orjson-3.11.7-cp311-cp311-win32.whl", hash = "sha256:c43b8b5bab288b6b90dac410cca7e986a4fa747a2e8f94615aea407da706980d", size = 127828, upload-time = "2026-02-02T15:37:42.241Z" },
|
| 1474 |
-
{ url = "https://files.pythonhosted.org/packages/af/9e/55c776dffda3f381e0f07d010a4f5f3902bf48eaba1bb7684d301acd4924/orjson-3.11.7-cp311-cp311-win_amd64.whl", hash = "sha256:6543001328aa857187f905308a028935864aefe9968af3848401b6fe80dbb471", size = 124941, upload-time = "2026-02-02T15:37:43.444Z" },
|
| 1475 |
-
{ url = "https://files.pythonhosted.org/packages/aa/8e/424a620fa7d263b880162505fb107ef5e0afaa765b5b06a88312ac291560/orjson-3.11.7-cp311-cp311-win_arm64.whl", hash = "sha256:1ee5cc7160a821dfe14f130bc8e63e7611051f964b463d9e2a3a573204446a4d", size = 126245, upload-time = "2026-02-02T15:37:45.18Z" },
|
| 1476 |
-
{ url = "https://files.pythonhosted.org/packages/80/bf/76f4f1665f6983385938f0e2a5d7efa12a58171b8456c252f3bae8a4cf75/orjson-3.11.7-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:bd03ea7606833655048dab1a00734a2875e3e86c276e1d772b2a02556f0d895f", size = 228545, upload-time = "2026-02-02T15:37:46.376Z" },
|
| 1477 |
-
{ url = "https://files.pythonhosted.org/packages/79/53/6c72c002cb13b5a978a068add59b25a8bdf2800ac1c9c8ecdb26d6d97064/orjson-3.11.7-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:89e440ebc74ce8ab5c7bc4ce6757b4a6b1041becb127df818f6997b5c71aa60b", size = 125224, upload-time = "2026-02-02T15:37:47.697Z" },
|
| 1478 |
-
{ url = "https://files.pythonhosted.org/packages/2c/83/10e48852865e5dd151bdfe652c06f7da484578ed02c5fca938e3632cb0b8/orjson-3.11.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ede977b5fe5ac91b1dffc0a517ca4542d2ec8a6a4ff7b2652d94f640796342a", size = 128154, upload-time = "2026-02-02T15:37:48.954Z" },
|
| 1479 |
-
{ url = "https://files.pythonhosted.org/packages/6e/52/a66e22a2b9abaa374b4a081d410edab6d1e30024707b87eab7c734afe28d/orjson-3.11.7-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b7b1dae39230a393df353827c855a5f176271c23434cfd2db74e0e424e693e10", size = 123548, upload-time = "2026-02-02T15:37:50.187Z" },
|
| 1480 |
-
{ url = "https://files.pythonhosted.org/packages/de/38/605d371417021359f4910c496f764c48ceb8997605f8c25bf1dfe58c0ebe/orjson-3.11.7-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ed46f17096e28fb28d2975834836a639af7278aa87c84f68ab08fbe5b8bd75fa", size = 129000, upload-time = "2026-02-02T15:37:51.426Z" },
|
| 1481 |
-
{ url = "https://files.pythonhosted.org/packages/44/98/af32e842b0ffd2335c89714d48ca4e3917b42f5d6ee5537832e069a4b3ac/orjson-3.11.7-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3726be79e36e526e3d9c1aceaadbfb4a04ee80a72ab47b3f3c17fefb9812e7b8", size = 141686, upload-time = "2026-02-02T15:37:52.607Z" },
|
| 1482 |
-
{ url = "https://files.pythonhosted.org/packages/96/0b/fc793858dfa54be6feee940c1463370ece34b3c39c1ca0aa3845f5ba9892/orjson-3.11.7-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0724e265bc548af1dedebd9cb3d24b4e1c1e685a343be43e87ba922a5c5fff2f", size = 130812, upload-time = "2026-02-02T15:37:53.944Z" },
|
| 1483 |
-
{ url = "https://files.pythonhosted.org/packages/dc/91/98a52415059db3f374757d0b7f0f16e3b5cd5976c90d1c2b56acaea039e6/orjson-3.11.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7745312efa9e11c17fbd3cb3097262d079da26930ae9ae7ba28fb738367cbad", size = 133440, upload-time = "2026-02-02T15:37:55.615Z" },
|
| 1484 |
-
{ url = "https://files.pythonhosted.org/packages/dc/b6/cb540117bda61791f46381f8c26c8f93e802892830a6055748d3bb1925ab/orjson-3.11.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f904c24bdeabd4298f7a977ef14ca2a022ca921ed670b92ecd16ab6f3d01f867", size = 138386, upload-time = "2026-02-02T15:37:56.814Z" },
|
| 1485 |
-
{ url = "https://files.pythonhosted.org/packages/63/1a/50a3201c334a7f17c231eee5f841342190723794e3b06293f26e7cf87d31/orjson-3.11.7-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:b9fc4d0f81f394689e0814617aadc4f2ea0e8025f38c226cbf22d3b5ddbf025d", size = 408853, upload-time = "2026-02-02T15:37:58.291Z" },
|
| 1486 |
-
{ url = "https://files.pythonhosted.org/packages/87/cd/8de1c67d0be44fdc22701e5989c0d015a2adf391498ad42c4dc589cd3013/orjson-3.11.7-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:849e38203e5be40b776ed2718e587faf204d184fc9a008ae441f9442320c0cab", size = 144130, upload-time = "2026-02-02T15:38:00.163Z" },
|
| 1487 |
-
{ url = "https://files.pythonhosted.org/packages/0f/fe/d605d700c35dd55f51710d159fc54516a280923cd1b7e47508982fbb387d/orjson-3.11.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:4682d1db3bcebd2b64757e0ddf9e87ae5f00d29d16c5cdf3a62f561d08cc3dd2", size = 134818, upload-time = "2026-02-02T15:38:01.507Z" },
|
| 1488 |
-
{ url = "https://files.pythonhosted.org/packages/e4/e4/15ecc67edb3ddb3e2f46ae04475f2d294e8b60c1825fbe28a428b93b3fbd/orjson-3.11.7-cp312-cp312-win32.whl", hash = "sha256:f4f7c956b5215d949a1f65334cf9d7612dde38f20a95f2315deef167def91a6f", size = 127923, upload-time = "2026-02-02T15:38:02.75Z" },
|
| 1489 |
-
{ url = "https://files.pythonhosted.org/packages/34/70/2e0855361f76198a3965273048c8e50a9695d88cd75811a5b46444895845/orjson-3.11.7-cp312-cp312-win_amd64.whl", hash = "sha256:bf742e149121dc5648ba0a08ea0871e87b660467ef168a3a5e53bc1fbd64bb74", size = 125007, upload-time = "2026-02-02T15:38:04.032Z" },
|
| 1490 |
-
{ url = "https://files.pythonhosted.org/packages/68/40/c2051bd19fc467610fed469dc29e43ac65891571138f476834ca192bc290/orjson-3.11.7-cp312-cp312-win_arm64.whl", hash = "sha256:26c3b9132f783b7d7903bf1efb095fed8d4a3a85ec0d334ee8beff3d7a4749d5", size = 126089, upload-time = "2026-02-02T15:38:05.297Z" },
|
| 1491 |
-
]
|
| 1492 |
-
|
| 1493 |
[[package]]
|
| 1494 |
name = "packaging"
|
| 1495 |
version = "26.0"
|
|
@@ -1586,6 +1325,7 @@ dependencies = [
|
|
| 1586 |
{ name = "chess-engine", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1587 |
{ name = "numpy", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1588 |
{ name = "psutil", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
|
|
|
| 1589 |
{ name = "tqdm", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1590 |
{ name = "wandb", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1591 |
]
|
|
@@ -1601,9 +1341,6 @@ dashboard = [
|
|
| 1601 |
{ name = "plotly", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1602 |
{ name = "solara", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1603 |
]
|
| 1604 |
-
demo = [
|
| 1605 |
-
{ name = "gradio", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1606 |
-
]
|
| 1607 |
dev = [
|
| 1608 |
{ name = "ipykernel", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1609 |
{ name = "pytest", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
|
@@ -1623,7 +1360,6 @@ rocm = [
|
|
| 1623 |
requires-dist = [
|
| 1624 |
{ name = "anywidget", marker = "extra == 'dashboard'", specifier = ">=0.9.21" },
|
| 1625 |
{ name = "chess-engine", editable = "engine" },
|
| 1626 |
-
{ name = "gradio", marker = "extra == 'demo'", specifier = ">=5.0.0" },
|
| 1627 |
{ name = "ipykernel", marker = "extra == 'dev'", specifier = ">=7.2.0" },
|
| 1628 |
{ name = "matplotlib", marker = "extra == 'eval'", specifier = ">=3.10.8" },
|
| 1629 |
{ name = "numpy", specifier = "~=2.2.0" },
|
|
@@ -1633,6 +1369,7 @@ requires-dist = [
|
|
| 1633 |
{ name = "psutil", specifier = ">=5.9.0" },
|
| 1634 |
{ name = "pyarrow", marker = "extra == 'eval'", specifier = ">=23.0.1" },
|
| 1635 |
{ name = "pytest", marker = "extra == 'dev'", specifier = "~=9.0.0" },
|
|
|
|
| 1636 |
{ name = "seaborn", marker = "extra == 'eval'", specifier = ">=0.13.2" },
|
| 1637 |
{ name = "solara", marker = "extra == 'dashboard'", specifier = ">=1.0.0" },
|
| 1638 |
{ name = "torch", marker = "extra == 'cu128'", specifier = "~=2.10.0", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "pawn", extra = "cu128" } },
|
|
@@ -1641,7 +1378,7 @@ requires-dist = [
|
|
| 1641 |
{ name = "triton-rocm", marker = "extra == 'rocm'", specifier = ">=3.6.0", index = "https://download.pytorch.org/whl/rocm7.1", conflict = { package = "pawn", extra = "rocm" } },
|
| 1642 |
{ name = "wandb", specifier = "~=0.25.0" },
|
| 1643 |
]
|
| 1644 |
-
provides-extras = ["rocm", "cu128", "eval", "dashboard", "
|
| 1645 |
|
| 1646 |
[[package]]
|
| 1647 |
name = "pexpect"
|
|
@@ -1976,15 +1713,6 @@ wheels = [
|
|
| 1976 |
{ url = "https://files.pythonhosted.org/packages/36/c7/cfc8e811f061c841d7990b0201912c3556bfeb99cdcb7ed24adc8d6f8704/pydantic_core-2.41.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:56121965f7a4dc965bff783d70b907ddf3d57f6eba29b6d2e5dabfaf07799c51", size = 2145302, upload-time = "2025-11-04T13:43:46.64Z" },
|
| 1977 |
]
|
| 1978 |
|
| 1979 |
-
[[package]]
|
| 1980 |
-
name = "pydub"
|
| 1981 |
-
version = "0.25.1"
|
| 1982 |
-
source = { registry = "https://pypi.org/simple" }
|
| 1983 |
-
sdist = { url = "https://files.pythonhosted.org/packages/fe/9a/e6bca0eed82db26562c73b5076539a4a08d3cffd19c3cc5913a3e61145fd/pydub-0.25.1.tar.gz", hash = "sha256:980a33ce9949cab2a569606b65674d748ecbca4f0796887fd6f46173a7b0d30f", size = 38326, upload-time = "2021-03-10T02:09:54.659Z" }
|
| 1984 |
-
wheels = [
|
| 1985 |
-
{ url = "https://files.pythonhosted.org/packages/a6/53/d78dc063216e62fc55f6b2eebb447f6a4b0a59f55c8406376f76bf959b08/pydub-0.25.1-py2.py3-none-any.whl", hash = "sha256:65617e33033874b59d87db603aa1ed450633288aefead953b30bded59cb599a6", size = 32327, upload-time = "2021-03-10T02:09:53.503Z" },
|
| 1986 |
-
]
|
| 1987 |
-
|
| 1988 |
[[package]]
|
| 1989 |
name = "pygments"
|
| 1990 |
version = "2.19.2"
|
|
@@ -2045,15 +1773,6 @@ wheels = [
|
|
| 2045 |
{ url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" },
|
| 2046 |
]
|
| 2047 |
|
| 2048 |
-
[[package]]
|
| 2049 |
-
name = "python-multipart"
|
| 2050 |
-
version = "0.0.22"
|
| 2051 |
-
source = { registry = "https://pypi.org/simple" }
|
| 2052 |
-
sdist = { url = "https://files.pythonhosted.org/packages/94/01/979e98d542a70714b0cb2b6728ed0b7c46792b695e3eaec3e20711271ca3/python_multipart-0.0.22.tar.gz", hash = "sha256:7340bef99a7e0032613f56dc36027b959fd3b30a787ed62d310e951f7c3a3a58", size = 37612, upload-time = "2026-01-25T10:15:56.219Z" }
|
| 2053 |
-
wheels = [
|
| 2054 |
-
{ url = "https://files.pythonhosted.org/packages/1b/d0/397f9626e711ff749a95d96b7af99b9c566a9bb5129b8e4c10fc4d100304/python_multipart-0.0.22-py3-none-any.whl", hash = "sha256:2b2cd894c83d21bf49d702499531c7bafd057d730c201782048f7945d82de155", size = 24579, upload-time = "2026-01-25T10:15:54.811Z" },
|
| 2055 |
-
]
|
| 2056 |
-
|
| 2057 |
[[package]]
|
| 2058 |
name = "pytz"
|
| 2059 |
version = "2026.1.post1"
|
|
@@ -2284,15 +2003,29 @@ wheels = [
|
|
| 2284 |
]
|
| 2285 |
|
| 2286 |
[[package]]
|
| 2287 |
-
name = "
|
| 2288 |
-
version = "0.
|
| 2289 |
source = { registry = "https://pypi.org/simple" }
|
| 2290 |
-
|
| 2291 |
-
|
| 2292 |
-
|
| 2293 |
-
|
| 2294 |
-
|
| 2295 |
-
{ url = "https://files.pythonhosted.org/packages/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2296 |
]
|
| 2297 |
|
| 2298 |
[[package]]
|
|
@@ -2310,15 +2043,6 @@ wheels = [
|
|
| 2310 |
{ url = "https://files.pythonhosted.org/packages/83/11/00d3c3dfc25ad54e731d91449895a79e4bf2384dc3ac01809010ba88f6d5/seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987", size = 294914, upload-time = "2024-01-25T13:21:49.598Z" },
|
| 2311 |
]
|
| 2312 |
|
| 2313 |
-
[[package]]
|
| 2314 |
-
name = "semantic-version"
|
| 2315 |
-
version = "2.10.0"
|
| 2316 |
-
source = { registry = "https://pypi.org/simple" }
|
| 2317 |
-
sdist = { url = "https://files.pythonhosted.org/packages/7d/31/f2289ce78b9b473d582568c234e104d2a342fd658cc288a7553d83bb8595/semantic_version-2.10.0.tar.gz", hash = "sha256:bdabb6d336998cbb378d4b9db3a4b56a1e3235701dc05ea2690d9a997ed5041c", size = 52289, upload-time = "2022-05-26T13:35:23.454Z" }
|
| 2318 |
-
wheels = [
|
| 2319 |
-
{ url = "https://files.pythonhosted.org/packages/6a/23/8146aad7d88f4fcb3a6218f41a60f6c2d4e3a72de72da1825dc7c8f7877c/semantic_version-2.10.0-py2.py3-none-any.whl", hash = "sha256:de78a3b8e0feda74cabc54aab2da702113e33ac9d9eb9d2389bcf1f58b7d9177", size = 15552, upload-time = "2022-05-26T13:35:21.206Z" },
|
| 2320 |
-
]
|
| 2321 |
-
|
| 2322 |
[[package]]
|
| 2323 |
name = "sentry-sdk"
|
| 2324 |
version = "2.55.0"
|
|
@@ -2341,15 +2065,6 @@ wheels = [
|
|
| 2341 |
{ url = "https://files.pythonhosted.org/packages/9d/76/f789f7a86709c6b087c5a2f52f911838cad707cc613162401badc665acfe/setuptools-82.0.1-py3-none-any.whl", hash = "sha256:a59e362652f08dcd477c78bb6e7bd9d80a7995bc73ce773050228a348ce2e5bb", size = 1006223, upload-time = "2026-03-09T12:47:15.026Z" },
|
| 2342 |
]
|
| 2343 |
|
| 2344 |
-
[[package]]
|
| 2345 |
-
name = "shellingham"
|
| 2346 |
-
version = "1.5.4"
|
| 2347 |
-
source = { registry = "https://pypi.org/simple" }
|
| 2348 |
-
sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310, upload-time = "2023-10-24T04:13:40.426Z" }
|
| 2349 |
-
wheels = [
|
| 2350 |
-
{ url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" },
|
| 2351 |
-
]
|
| 2352 |
-
|
| 2353 |
[[package]]
|
| 2354 |
name = "six"
|
| 2355 |
version = "1.17.0"
|
|
@@ -2504,15 +2219,6 @@ wheels = [
|
|
| 2504 |
{ url = "https://files.pythonhosted.org/packages/23/d1/136eb2cb77520a31e1f64cbae9d33ec6df0d78bdf4160398e86eec8a8754/tomli-2.4.0-py3-none-any.whl", hash = "sha256:1f776e7d669ebceb01dee46484485f43a4048746235e683bcdffacdf1fb4785a", size = 14477, upload-time = "2026-01-11T11:22:37.446Z" },
|
| 2505 |
]
|
| 2506 |
|
| 2507 |
-
[[package]]
|
| 2508 |
-
name = "tomlkit"
|
| 2509 |
-
version = "0.13.3"
|
| 2510 |
-
source = { registry = "https://pypi.org/simple" }
|
| 2511 |
-
sdist = { url = "https://files.pythonhosted.org/packages/cc/18/0bbf3884e9eaa38819ebe46a7bd25dcd56b67434402b66a58c4b8e552575/tomlkit-0.13.3.tar.gz", hash = "sha256:430cf247ee57df2b94ee3fbe588e71d362a941ebb545dec29b53961d61add2a1", size = 185207, upload-time = "2025-06-05T07:13:44.947Z" }
|
| 2512 |
-
wheels = [
|
| 2513 |
-
{ url = "https://files.pythonhosted.org/packages/bd/75/8539d011f6be8e29f339c42e633aae3cb73bffa95dd0f9adec09b9c58e85/tomlkit-0.13.3-py3-none-any.whl", hash = "sha256:c89c649d79ee40629a9fda55f8ace8c6a1b42deb912b2a8fd8d942ddadb606b0", size = 38901, upload-time = "2025-06-05T07:13:43.546Z" },
|
| 2514 |
-
]
|
| 2515 |
-
|
| 2516 |
[[package]]
|
| 2517 |
name = "torch"
|
| 2518 |
version = "2.10.0+cu128"
|
|
@@ -2645,21 +2351,6 @@ wheels = [
|
|
| 2645 |
{ url = "https://download.pytorch.org/whl/triton_rocm-3.6.0-cp312-cp312-linux_x86_64.whl" },
|
| 2646 |
]
|
| 2647 |
|
| 2648 |
-
[[package]]
|
| 2649 |
-
name = "typer"
|
| 2650 |
-
version = "0.24.1"
|
| 2651 |
-
source = { registry = "https://pypi.org/simple" }
|
| 2652 |
-
dependencies = [
|
| 2653 |
-
{ name = "annotated-doc", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 2654 |
-
{ name = "click", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 2655 |
-
{ name = "rich", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 2656 |
-
{ name = "shellingham", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 2657 |
-
]
|
| 2658 |
-
sdist = { url = "https://files.pythonhosted.org/packages/f5/24/cb09efec5cc954f7f9b930bf8279447d24618bb6758d4f6adf2574c41780/typer-0.24.1.tar.gz", hash = "sha256:e39b4732d65fbdcde189ae76cf7cd48aeae72919dea1fdfc16593be016256b45", size = 118613, upload-time = "2026-02-21T16:54:40.609Z" }
|
| 2659 |
-
wheels = [
|
| 2660 |
-
{ url = "https://files.pythonhosted.org/packages/4a/91/48db081e7a63bb37284f9fbcefda7c44c277b18b0e13fbc36ea2335b71e6/typer-0.24.1-py3-none-any.whl", hash = "sha256:112c1f0ce578bfb4cab9ffdabc68f031416ebcc216536611ba21f04e9aa84c9e", size = 56085, upload-time = "2026-02-21T16:54:41.616Z" },
|
| 2661 |
-
]
|
| 2662 |
-
|
| 2663 |
[[package]]
|
| 2664 |
name = "typing-extensions"
|
| 2665 |
version = "4.15.0"
|
|
|
|
| 20 |
"pawn",
|
| 21 |
]
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
[[package]]
|
| 24 |
name = "annotated-types"
|
| 25 |
version = "0.7.0"
|
|
|
|
| 75 |
{ url = "https://files.pythonhosted.org/packages/64/b4/17d4b0b2a2dc85a6df63d1157e028ed19f90d4cd97c36717afef2bc2f395/attrs-26.1.0-py3-none-any.whl", hash = "sha256:c647aa4a12dfbad9333ca4e71fe62ddc36f4e63b2d260a37a8b83d2f043ac309", size = 67548, upload-time = "2026-03-19T14:22:23.645Z" },
|
| 76 |
]
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
[[package]]
|
| 79 |
name = "cachetools"
|
| 80 |
version = "7.0.5"
|
|
|
|
| 406 |
{ url = "https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017", size = 28317, upload-time = "2025-09-01T09:48:08.5Z" },
|
| 407 |
]
|
| 408 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
[[package]]
|
| 410 |
name = "fastjsonschema"
|
| 411 |
version = "2.21.2"
|
|
|
|
| 415 |
{ url = "https://files.pythonhosted.org/packages/cb/a8/20d0723294217e47de6d9e2e40fd4a9d2f7c4b6ef974babd482a59743694/fastjsonschema-2.21.2-py3-none-any.whl", hash = "sha256:1c797122d0a86c5cace2e54bf4e819c36223b552017172f32c5c024a6b77e463", size = 24024, upload-time = "2025-08-14T18:49:34.776Z" },
|
| 416 |
]
|
| 417 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
[[package]]
|
| 419 |
name = "filelock"
|
| 420 |
version = "3.25.2"
|
|
|
|
| 490 |
{ url = "https://files.pythonhosted.org/packages/6a/09/e21df6aef1e1ffc0c816f0522ddc3f6dcded766c3261813131c78a704470/gitpython-3.1.46-py3-none-any.whl", hash = "sha256:79812ed143d9d25b6d176a10bb511de0f9c67b1fa641d82097b0ab90398a2058", size = 208620, upload-time = "2026-01-01T15:37:30.574Z" },
|
| 491 |
]
|
| 492 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 493 |
[[package]]
|
| 494 |
name = "h11"
|
| 495 |
version = "0.16.0"
|
|
|
|
| 499 |
{ url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" },
|
| 500 |
]
|
| 501 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 502 |
[[package]]
|
| 503 |
name = "humanize"
|
| 504 |
version = "4.15.0"
|
|
|
|
| 1229 |
{ url = "https://files.pythonhosted.org/packages/9f/99/4c9c0c329bf9fc125008c3b54c7c94c0023518d06fc025ae36431375e1fe/nvidia_nvtx_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:619c8304aedc69f02ea82dd244541a83c3d9d40993381b3b590f1adaed3db41e", size = 56492, upload-time = "2025-03-07T01:52:24.69Z" },
|
| 1230 |
]
|
| 1231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1232 |
[[package]]
|
| 1233 |
name = "packaging"
|
| 1234 |
version = "26.0"
|
|
|
|
| 1325 |
{ name = "chess-engine", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1326 |
{ name = "numpy", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1327 |
{ name = "psutil", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1328 |
+
{ name = "safetensors", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1329 |
{ name = "tqdm", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1330 |
{ name = "wandb", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1331 |
]
|
|
|
|
| 1341 |
{ name = "plotly", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1342 |
{ name = "solara", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1343 |
]
|
|
|
|
|
|
|
|
|
|
| 1344 |
dev = [
|
| 1345 |
{ name = "ipykernel", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
| 1346 |
{ name = "pytest", marker = "sys_platform == 'linux' or (extra == 'extra-4-pawn-cu128' and extra == 'extra-4-pawn-rocm')" },
|
|
|
|
| 1360 |
requires-dist = [
|
| 1361 |
{ name = "anywidget", marker = "extra == 'dashboard'", specifier = ">=0.9.21" },
|
| 1362 |
{ name = "chess-engine", editable = "engine" },
|
|
|
|
| 1363 |
{ name = "ipykernel", marker = "extra == 'dev'", specifier = ">=7.2.0" },
|
| 1364 |
{ name = "matplotlib", marker = "extra == 'eval'", specifier = ">=3.10.8" },
|
| 1365 |
{ name = "numpy", specifier = "~=2.2.0" },
|
|
|
|
| 1369 |
{ name = "psutil", specifier = ">=5.9.0" },
|
| 1370 |
{ name = "pyarrow", marker = "extra == 'eval'", specifier = ">=23.0.1" },
|
| 1371 |
{ name = "pytest", marker = "extra == 'dev'", specifier = "~=9.0.0" },
|
| 1372 |
+
{ name = "safetensors", specifier = ">=0.4.0" },
|
| 1373 |
{ name = "seaborn", marker = "extra == 'eval'", specifier = ">=0.13.2" },
|
| 1374 |
{ name = "solara", marker = "extra == 'dashboard'", specifier = ">=1.0.0" },
|
| 1375 |
{ name = "torch", marker = "extra == 'cu128'", specifier = "~=2.10.0", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "pawn", extra = "cu128" } },
|
|
|
|
| 1378 |
{ name = "triton-rocm", marker = "extra == 'rocm'", specifier = ">=3.6.0", index = "https://download.pytorch.org/whl/rocm7.1", conflict = { package = "pawn", extra = "rocm" } },
|
| 1379 |
{ name = "wandb", specifier = "~=0.25.0" },
|
| 1380 |
]
|
| 1381 |
+
provides-extras = ["rocm", "cu128", "eval", "dashboard", "dev"]
|
| 1382 |
|
| 1383 |
[[package]]
|
| 1384 |
name = "pexpect"
|
|
|
|
| 1713 |
{ url = "https://files.pythonhosted.org/packages/36/c7/cfc8e811f061c841d7990b0201912c3556bfeb99cdcb7ed24adc8d6f8704/pydantic_core-2.41.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:56121965f7a4dc965bff783d70b907ddf3d57f6eba29b6d2e5dabfaf07799c51", size = 2145302, upload-time = "2025-11-04T13:43:46.64Z" },
|
| 1714 |
]
|
| 1715 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1716 |
[[package]]
|
| 1717 |
name = "pygments"
|
| 1718 |
version = "2.19.2"
|
|
|
|
| 1773 |
{ url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" },
|
| 1774 |
]
|
| 1775 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1776 |
[[package]]
|
| 1777 |
name = "pytz"
|
| 1778 |
version = "2026.1.post1"
|
|
|
|
| 2003 |
]
|
| 2004 |
|
| 2005 |
[[package]]
|
| 2006 |
+
name = "safetensors"
|
| 2007 |
+
version = "0.7.0"
|
| 2008 |
source = { registry = "https://pypi.org/simple" }
|
| 2009 |
+
sdist = { url = "https://files.pythonhosted.org/packages/29/9c/6e74567782559a63bd040a236edca26fd71bc7ba88de2ef35d75df3bca5e/safetensors-0.7.0.tar.gz", hash = "sha256:07663963b67e8bd9f0b8ad15bb9163606cd27cc5a1b96235a50d8369803b96b0", size = 200878, upload-time = "2025-11-19T15:18:43.199Z" }
|
| 2010 |
+
wheels = [
|
| 2011 |
+
{ url = "https://files.pythonhosted.org/packages/fa/47/aef6c06649039accf914afef490268e1067ed82be62bcfa5b7e886ad15e8/safetensors-0.7.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c82f4d474cf725255d9e6acf17252991c3c8aac038d6ef363a4bf8be2f6db517", size = 467781, upload-time = "2025-11-19T15:18:35.84Z" },
|
| 2012 |
+
{ url = "https://files.pythonhosted.org/packages/e8/00/374c0c068e30cd31f1e1b46b4b5738168ec79e7689ca82ee93ddfea05109/safetensors-0.7.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:94fd4858284736bb67a897a41608b5b0c2496c9bdb3bf2af1fa3409127f20d57", size = 447058, upload-time = "2025-11-19T15:18:34.416Z" },
|
| 2013 |
+
{ url = "https://files.pythonhosted.org/packages/f1/06/578ffed52c2296f93d7fd2d844cabfa92be51a587c38c8afbb8ae449ca89/safetensors-0.7.0-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e07d91d0c92a31200f25351f4acb2bc6aff7f48094e13ebb1d0fb995b54b6542", size = 491748, upload-time = "2025-11-19T15:18:09.79Z" },
|
| 2014 |
+
{ url = "https://files.pythonhosted.org/packages/ae/33/1debbbb70e4791dde185edb9413d1fe01619255abb64b300157d7f15dddd/safetensors-0.7.0-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8469155f4cb518bafb4acf4865e8bb9d6804110d2d9bdcaa78564b9fd841e104", size = 503881, upload-time = "2025-11-19T15:18:16.145Z" },
|
| 2015 |
+
{ url = "https://files.pythonhosted.org/packages/8e/1c/40c2ca924d60792c3be509833df711b553c60effbd91da6f5284a83f7122/safetensors-0.7.0-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:54bef08bf00a2bff599982f6b08e8770e09cc012d7bba00783fc7ea38f1fb37d", size = 623463, upload-time = "2025-11-19T15:18:21.11Z" },
|
| 2016 |
+
{ url = "https://files.pythonhosted.org/packages/9b/3a/13784a9364bd43b0d61eef4bea2845039bc2030458b16594a1bd787ae26e/safetensors-0.7.0-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:42cb091236206bb2016d245c377ed383aa7f78691748f3bb6ee1bfa51ae2ce6a", size = 532855, upload-time = "2025-11-19T15:18:25.719Z" },
|
| 2017 |
+
{ url = "https://files.pythonhosted.org/packages/a0/60/429e9b1cb3fc651937727befe258ea24122d9663e4d5709a48c9cbfceecb/safetensors-0.7.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dac7252938f0696ddea46f5e855dd3138444e82236e3be475f54929f0c510d48", size = 507152, upload-time = "2025-11-19T15:18:33.023Z" },
|
| 2018 |
+
{ url = "https://files.pythonhosted.org/packages/3c/a8/4b45e4e059270d17af60359713ffd83f97900d45a6afa73aaa0d737d48b6/safetensors-0.7.0-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1d060c70284127fa805085d8f10fbd0962792aed71879d00864acda69dbab981", size = 541856, upload-time = "2025-11-19T15:18:31.075Z" },
|
| 2019 |
+
{ url = "https://files.pythonhosted.org/packages/06/87/d26d8407c44175d8ae164a95b5a62707fcc445f3c0c56108e37d98070a3d/safetensors-0.7.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:cdab83a366799fa730f90a4ebb563e494f28e9e92c4819e556152ad55e43591b", size = 674060, upload-time = "2025-11-19T15:18:37.211Z" },
|
| 2020 |
+
{ url = "https://files.pythonhosted.org/packages/11/f5/57644a2ff08dc6325816ba7217e5095f17269dada2554b658442c66aed51/safetensors-0.7.0-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:672132907fcad9f2aedcb705b2d7b3b93354a2aec1b2f706c4db852abe338f85", size = 771715, upload-time = "2025-11-19T15:18:38.689Z" },
|
| 2021 |
+
{ url = "https://files.pythonhosted.org/packages/86/31/17883e13a814bd278ae6e266b13282a01049b0c81341da7fd0e3e71a80a3/safetensors-0.7.0-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:5d72abdb8a4d56d4020713724ba81dac065fedb7f3667151c4a637f1d3fb26c0", size = 714377, upload-time = "2025-11-19T15:18:40.162Z" },
|
| 2022 |
+
{ url = "https://files.pythonhosted.org/packages/4a/d8/0c8a7dc9b41dcac53c4cbf9df2b9c83e0e0097203de8b37a712b345c0be5/safetensors-0.7.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b0f6d66c1c538d5a94a73aa9ddca8ccc4227e6c9ff555322ea40bdd142391dd4", size = 677368, upload-time = "2025-11-19T15:18:41.627Z" },
|
| 2023 |
+
{ url = "https://files.pythonhosted.org/packages/05/e5/cb4b713c8a93469e3c5be7c3f8d77d307e65fe89673e731f5c2bfd0a9237/safetensors-0.7.0-cp38-abi3-win32.whl", hash = "sha256:c74af94bf3ac15ac4d0f2a7c7b4663a15f8c2ab15ed0fc7531ca61d0835eccba", size = 326423, upload-time = "2025-11-19T15:18:45.74Z" },
|
| 2024 |
+
{ url = "https://files.pythonhosted.org/packages/5d/e6/ec8471c8072382cb91233ba7267fd931219753bb43814cbc71757bfd4dab/safetensors-0.7.0-cp38-abi3-win_amd64.whl", hash = "sha256:d1239932053f56f3456f32eb9625590cc7582e905021f94636202a864d470755", size = 341380, upload-time = "2025-11-19T15:18:44.427Z" },
|
| 2025 |
+
{ url = "https://files.pythonhosted.org/packages/a7/6a/4d08d89a6fcbe905c5ae68b8b34f0791850882fc19782d0d02c65abbdf3b/safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4729811a6640d019a4b7ba8638ee2fd21fa5ca8c7e7bdf0fed62068fcaac737", size = 492430, upload-time = "2025-11-19T15:18:11.884Z" },
|
| 2026 |
+
{ url = "https://files.pythonhosted.org/packages/dd/29/59ed8152b30f72c42d00d241e58eaca558ae9dbfa5695206e2e0f54c7063/safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:12f49080303fa6bb424b362149a12949dfbbf1e06811a88f2307276b0c131afd", size = 503977, upload-time = "2025-11-19T15:18:17.523Z" },
|
| 2027 |
+
{ url = "https://files.pythonhosted.org/packages/d3/0b/4811bfec67fa260e791369b16dab105e4bae82686120554cc484064e22b4/safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0071bffba4150c2f46cae1432d31995d77acfd9f8db598b5d1a2ce67e8440ad2", size = 623890, upload-time = "2025-11-19T15:18:22.666Z" },
|
| 2028 |
+
{ url = "https://files.pythonhosted.org/packages/58/5b/632a58724221ef03d78ab65062e82a1010e1bef8e8e0b9d7c6d7b8044841/safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:473b32699f4200e69801bf5abf93f1a4ecd432a70984df164fc22ccf39c4a6f3", size = 531885, upload-time = "2025-11-19T15:18:27.146Z" },
|
| 2029 |
]
|
| 2030 |
|
| 2031 |
[[package]]
|
|
|
|
| 2043 |
{ url = "https://files.pythonhosted.org/packages/83/11/00d3c3dfc25ad54e731d91449895a79e4bf2384dc3ac01809010ba88f6d5/seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987", size = 294914, upload-time = "2024-01-25T13:21:49.598Z" },
|
| 2044 |
]
|
| 2045 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2046 |
[[package]]
|
| 2047 |
name = "sentry-sdk"
|
| 2048 |
version = "2.55.0"
|
|
|
|
| 2065 |
{ url = "https://files.pythonhosted.org/packages/9d/76/f789f7a86709c6b087c5a2f52f911838cad707cc613162401badc665acfe/setuptools-82.0.1-py3-none-any.whl", hash = "sha256:a59e362652f08dcd477c78bb6e7bd9d80a7995bc73ce773050228a348ce2e5bb", size = 1006223, upload-time = "2026-03-09T12:47:15.026Z" },
|
| 2066 |
]
|
| 2067 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2068 |
[[package]]
|
| 2069 |
name = "six"
|
| 2070 |
version = "1.17.0"
|
|
|
|
| 2219 |
{ url = "https://files.pythonhosted.org/packages/23/d1/136eb2cb77520a31e1f64cbae9d33ec6df0d78bdf4160398e86eec8a8754/tomli-2.4.0-py3-none-any.whl", hash = "sha256:1f776e7d669ebceb01dee46484485f43a4048746235e683bcdffacdf1fb4785a", size = 14477, upload-time = "2026-01-11T11:22:37.446Z" },
|
| 2220 |
]
|
| 2221 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2222 |
[[package]]
|
| 2223 |
name = "torch"
|
| 2224 |
version = "2.10.0+cu128"
|
|
|
|
| 2351 |
{ url = "https://download.pytorch.org/whl/triton_rocm-3.6.0-cp312-cp312-linux_x86_64.whl" },
|
| 2352 |
]
|
| 2353 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2354 |
[[package]]
|
| 2355 |
name = "typing-extensions"
|
| 2356 |
version = "4.15.0"
|