Add type annotations across codebase and configure pyright
Configure pyright in basic mode for static type checking. Add type
annotations to all function signatures (~76 additions across 18 files),
bringing coverage from ~65% to 100%.
Key structural changes to achieve clean type checking without
suppressions:
- Add PAWNCLM.get_block() typed accessor to avoid ModuleList type
erasure (single type-narrowing point for all adapter consumers)
- Replace None with nn.Identity() in adapter ModuleLists (bottleneck,
hybrid) so entries are always valid Modules
- Separate trainer's unwrapped _model from potentially-compiled model
to preserve concrete PAWNCLM type
- Define GenerativeModel Protocol for eval suite's duck-typed model
parameter instead of bare nn.Module
- Add class-level type declarations for registered buffers and
submodules (rope_cos, causal_mask, TransformerBlock attrs)
Result: 0 pyright errors, 1 type: ignore (in get_block), all tests pass.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- pawn/adapters/bottleneck.py +19 -22
- pawn/adapters/film.py +7 -5
- pawn/adapters/hybrid.py +17 -12
- pawn/adapters/lora.py +7 -6
- pawn/adapters/sparse.py +10 -5
- pawn/data.py +4 -3
- pawn/eval_suite/corpus.py +9 -8
- pawn/eval_suite/diagnostics.py +3 -2
- pawn/eval_suite/generation.py +41 -11
- pawn/eval_suite/lichess.py +2 -1
- pawn/eval_suite/probes.py +6 -6
- pawn/eval_suite/viz.py +11 -5
- pawn/eval_suite/worker.py +36 -26
- pawn/lichess_data.py +2 -2
- pawn/logging.py +4 -4
- pawn/model.py +21 -2
- pawn/trainer.py +11 -14
- pyproject.toml +8 -0
|
@@ -80,18 +80,18 @@ class BottleneckCLM(nn.Module):
|
|
| 80 |
for p in backbone.parameters():
|
| 81 |
p.requires_grad = False
|
| 82 |
|
| 83 |
-
# Create adapter modules (
|
| 84 |
self.attn_adapters = nn.ModuleList()
|
| 85 |
self.ffn_adapters = nn.ModuleList()
|
| 86 |
for i in range(n_layers):
|
| 87 |
if i in self._attn_set:
|
| 88 |
self.attn_adapters.append(BottleneckAdapter(cfg.d_model, bottleneck_dim))
|
| 89 |
else:
|
| 90 |
-
self.attn_adapters.append(
|
| 91 |
if i in self._ffn_set:
|
| 92 |
self.ffn_adapters.append(BottleneckAdapter(cfg.d_model, bottleneck_dim))
|
| 93 |
else:
|
| 94 |
-
self.ffn_adapters.append(
|
| 95 |
|
| 96 |
@property
|
| 97 |
def cfg(self) -> CLMConfig:
|
|
@@ -106,16 +106,15 @@ class BottleneckCLM(nn.Module):
|
|
| 106 |
rope_cos = bb.rope_cos[:, :, :T, :]
|
| 107 |
rope_sin = bb.rope_sin[:, :, :T, :]
|
| 108 |
|
| 109 |
-
for i
|
|
|
|
| 110 |
# Attention sublayer + adapter
|
| 111 |
x = x + block.attn(block.attn_norm(x), rope_cos, rope_sin, None)
|
| 112 |
-
|
| 113 |
-
x = self.attn_adapters[i](x)
|
| 114 |
|
| 115 |
# FFN sublayer + adapter
|
| 116 |
x = x + block.ffn(block.ffn_norm(x))
|
| 117 |
-
|
| 118 |
-
x = self.ffn_adapters[i](x)
|
| 119 |
|
| 120 |
return bb.final_norm(x)
|
| 121 |
|
|
@@ -143,14 +142,13 @@ class BottleneckCLM(nn.Module):
|
|
| 143 |
rope_cos = bb.rope_cos[:, :, :T, :]
|
| 144 |
rope_sin = bb.rope_sin[:, :, :T, :]
|
| 145 |
|
| 146 |
-
for i
|
|
|
|
| 147 |
x = x + block.attn(block.attn_norm(x), rope_cos, rope_sin, mask)
|
| 148 |
-
|
| 149 |
-
x = self.attn_adapters[i](x)
|
| 150 |
|
| 151 |
x = x + block.ffn(block.ffn_norm(x))
|
| 152 |
-
|
| 153 |
-
x = self.ffn_adapters[i](x)
|
| 154 |
|
| 155 |
x = bb.final_norm(x)
|
| 156 |
return self.project_head(x)
|
|
@@ -174,20 +172,19 @@ class BottleneckCLM(nn.Module):
|
|
| 174 |
rope_sin = bb.rope_sin[:, :, :T_new, :]
|
| 175 |
|
| 176 |
new_kv_cache = []
|
| 177 |
-
for i
|
|
|
|
| 178 |
# KV-cache forward for attention
|
| 179 |
layer_cache = kv_cache[i] if kv_cache is not None else None
|
| 180 |
attn_out, new_cache = block.attn.forward_kv(
|
| 181 |
block.attn_norm(x), rope_cos, rope_sin, layer_cache,
|
| 182 |
)
|
| 183 |
x = x + attn_out
|
| 184 |
-
|
| 185 |
-
x = self.attn_adapters[i](x)
|
| 186 |
new_kv_cache.append(new_cache)
|
| 187 |
|
| 188 |
x = x + block.ffn(block.ffn_norm(x))
|
| 189 |
-
|
| 190 |
-
x = self.ffn_adapters[i](x)
|
| 191 |
|
| 192 |
x = bb.final_norm(x[:, -1:, :])
|
| 193 |
logits = bb.lm_head(x)
|
|
@@ -218,12 +215,12 @@ class BottleneckCLM(nn.Module):
|
|
| 218 |
"""Per-layer adapter weight norms for monitoring."""
|
| 219 |
report = {}
|
| 220 |
for i in range(len(self.backbone.layers)):
|
| 221 |
-
|
| 222 |
-
|
| 223 |
report[f"adapter/layer{i}.attn.down"] = a.down.weight.data.norm().item()
|
| 224 |
report[f"adapter/layer{i}.attn.up"] = a.up.weight.data.norm().item()
|
| 225 |
-
|
| 226 |
-
|
| 227 |
report[f"adapter/layer{i}.ffn.down"] = a.down.weight.data.norm().item()
|
| 228 |
report[f"adapter/layer{i}.ffn.up"] = a.up.weight.data.norm().item()
|
| 229 |
return report
|
|
|
|
| 80 |
for p in backbone.parameters():
|
| 81 |
p.requires_grad = False
|
| 82 |
|
| 83 |
+
# Create adapter modules (Identity for non-adapted layers)
|
| 84 |
self.attn_adapters = nn.ModuleList()
|
| 85 |
self.ffn_adapters = nn.ModuleList()
|
| 86 |
for i in range(n_layers):
|
| 87 |
if i in self._attn_set:
|
| 88 |
self.attn_adapters.append(BottleneckAdapter(cfg.d_model, bottleneck_dim))
|
| 89 |
else:
|
| 90 |
+
self.attn_adapters.append(nn.Identity())
|
| 91 |
if i in self._ffn_set:
|
| 92 |
self.ffn_adapters.append(BottleneckAdapter(cfg.d_model, bottleneck_dim))
|
| 93 |
else:
|
| 94 |
+
self.ffn_adapters.append(nn.Identity())
|
| 95 |
|
| 96 |
@property
|
| 97 |
def cfg(self) -> CLMConfig:
|
|
|
|
| 106 |
rope_cos = bb.rope_cos[:, :, :T, :]
|
| 107 |
rope_sin = bb.rope_sin[:, :, :T, :]
|
| 108 |
|
| 109 |
+
for i in range(len(bb.layers)):
|
| 110 |
+
block = bb.get_block(i)
|
| 111 |
# Attention sublayer + adapter
|
| 112 |
x = x + block.attn(block.attn_norm(x), rope_cos, rope_sin, None)
|
| 113 |
+
x = self.attn_adapters[i](x)
|
|
|
|
| 114 |
|
| 115 |
# FFN sublayer + adapter
|
| 116 |
x = x + block.ffn(block.ffn_norm(x))
|
| 117 |
+
x = self.ffn_adapters[i](x)
|
|
|
|
| 118 |
|
| 119 |
return bb.final_norm(x)
|
| 120 |
|
|
|
|
| 142 |
rope_cos = bb.rope_cos[:, :, :T, :]
|
| 143 |
rope_sin = bb.rope_sin[:, :, :T, :]
|
| 144 |
|
| 145 |
+
for i in range(len(bb.layers)):
|
| 146 |
+
block = bb.get_block(i)
|
| 147 |
x = x + block.attn(block.attn_norm(x), rope_cos, rope_sin, mask)
|
| 148 |
+
x = self.attn_adapters[i](x)
|
|
|
|
| 149 |
|
| 150 |
x = x + block.ffn(block.ffn_norm(x))
|
| 151 |
+
x = self.ffn_adapters[i](x)
|
|
|
|
| 152 |
|
| 153 |
x = bb.final_norm(x)
|
| 154 |
return self.project_head(x)
|
|
|
|
| 172 |
rope_sin = bb.rope_sin[:, :, :T_new, :]
|
| 173 |
|
| 174 |
new_kv_cache = []
|
| 175 |
+
for i in range(len(bb.layers)):
|
| 176 |
+
block = bb.get_block(i)
|
| 177 |
# KV-cache forward for attention
|
| 178 |
layer_cache = kv_cache[i] if kv_cache is not None else None
|
| 179 |
attn_out, new_cache = block.attn.forward_kv(
|
| 180 |
block.attn_norm(x), rope_cos, rope_sin, layer_cache,
|
| 181 |
)
|
| 182 |
x = x + attn_out
|
| 183 |
+
x = self.attn_adapters[i](x)
|
|
|
|
| 184 |
new_kv_cache.append(new_cache)
|
| 185 |
|
| 186 |
x = x + block.ffn(block.ffn_norm(x))
|
| 187 |
+
x = self.ffn_adapters[i](x)
|
|
|
|
| 188 |
|
| 189 |
x = bb.final_norm(x[:, -1:, :])
|
| 190 |
logits = bb.lm_head(x)
|
|
|
|
| 215 |
"""Per-layer adapter weight norms for monitoring."""
|
| 216 |
report = {}
|
| 217 |
for i in range(len(self.backbone.layers)):
|
| 218 |
+
a = self.attn_adapters[i]
|
| 219 |
+
if isinstance(a, BottleneckAdapter):
|
| 220 |
report[f"adapter/layer{i}.attn.down"] = a.down.weight.data.norm().item()
|
| 221 |
report[f"adapter/layer{i}.attn.up"] = a.up.weight.data.norm().item()
|
| 222 |
+
a = self.ffn_adapters[i]
|
| 223 |
+
if isinstance(a, BottleneckAdapter):
|
| 224 |
report[f"adapter/layer{i}.ffn.down"] = a.down.weight.data.norm().item()
|
| 225 |
report[f"adapter/layer{i}.ffn.up"] = a.up.weight.data.norm().item()
|
| 226 |
return report
|
|
@@ -140,10 +140,11 @@ class FiLMCLM(nn.Module):
|
|
| 140 |
rope_sin = bb.rope_sin[:, :, :T_new, :]
|
| 141 |
|
| 142 |
new_kv_cache = []
|
| 143 |
-
for i
|
|
|
|
| 144 |
layer_cache = kv_cache[i] if kv_cache is not None else None
|
| 145 |
-
x, new_cache =
|
| 146 |
-
x =
|
| 147 |
new_kv_cache.append(new_cache)
|
| 148 |
|
| 149 |
x = bb.final_norm(x[:, -1:, :])
|
|
@@ -182,8 +183,9 @@ class FiLMCLM(nn.Module):
|
|
| 182 |
"""Per-layer FiLM deviation from identity, for monitoring."""
|
| 183 |
report = {}
|
| 184 |
for i, film in enumerate(self.hidden_films):
|
| 185 |
-
|
| 186 |
-
|
|
|
|
| 187 |
if self.output_film is not None:
|
| 188 |
report["output/gamma_dev"] = (self.output_film.gamma - 1.0).norm().item()
|
| 189 |
report["output/beta_norm"] = self.output_film.beta.norm().item()
|
|
|
|
| 140 |
rope_sin = bb.rope_sin[:, :, :T_new, :]
|
| 141 |
|
| 142 |
new_kv_cache = []
|
| 143 |
+
for i in range(len(bb.layers)):
|
| 144 |
+
block = bb.get_block(i)
|
| 145 |
layer_cache = kv_cache[i] if kv_cache is not None else None
|
| 146 |
+
x, new_cache = block.forward_kv(x, rope_cos, rope_sin, layer_cache)
|
| 147 |
+
x = self.hidden_films[i](x)
|
| 148 |
new_kv_cache.append(new_cache)
|
| 149 |
|
| 150 |
x = bb.final_norm(x[:, -1:, :])
|
|
|
|
| 183 |
"""Per-layer FiLM deviation from identity, for monitoring."""
|
| 184 |
report = {}
|
| 185 |
for i, film in enumerate(self.hidden_films):
|
| 186 |
+
if isinstance(film, FiLMLayer):
|
| 187 |
+
report[f"hidden_{i}/gamma_dev"] = (film.gamma - 1.0).norm().item()
|
| 188 |
+
report[f"hidden_{i}/beta_norm"] = film.beta.norm().item()
|
| 189 |
if self.output_film is not None:
|
| 190 |
report["output/gamma_dev"] = (self.output_film.gamma - 1.0).norm().item()
|
| 191 |
report["output/beta_norm"] = self.output_film.beta.norm().item()
|
|
@@ -65,9 +65,11 @@ class HybridCLM(nn.Module):
|
|
| 65 |
p.requires_grad = False
|
| 66 |
|
| 67 |
# Inject LoRA
|
| 68 |
-
for layer_idx
|
| 69 |
if layer_idx not in self.lora_layer_set:
|
| 70 |
continue
|
|
|
|
|
|
|
| 71 |
attn: Attention = block.attn
|
| 72 |
for proj_name in self.attn_targets:
|
| 73 |
original = getattr(attn, proj_name)
|
|
@@ -78,14 +80,14 @@ class HybridCLM(nn.Module):
|
|
| 78 |
original = getattr(ffn, proj_name)
|
| 79 |
setattr(ffn, proj_name, LoRALinear(original, lora_rank, self.lora_alpha))
|
| 80 |
|
| 81 |
-
# Create FiLM layers (
|
| 82 |
if use_film:
|
| 83 |
self.hidden_films = nn.ModuleList()
|
| 84 |
for i in range(n_layers):
|
| 85 |
if i in self.film_layer_set:
|
| 86 |
self.hidden_films.append(FiLMLayer(cfg.d_model))
|
| 87 |
else:
|
| 88 |
-
self.hidden_films.append(
|
| 89 |
else:
|
| 90 |
self.hidden_films = None
|
| 91 |
|
|
@@ -109,7 +111,7 @@ class HybridCLM(nn.Module):
|
|
| 109 |
|
| 110 |
for i, layer in enumerate(bb.layers):
|
| 111 |
x = layer(x, rope_cos, rope_sin, None) # LoRA happens inside
|
| 112 |
-
if self.hidden_films is not None
|
| 113 |
x = self.hidden_films[i](x)
|
| 114 |
|
| 115 |
return bb.final_norm(x)
|
|
@@ -143,7 +145,7 @@ class HybridCLM(nn.Module):
|
|
| 143 |
|
| 144 |
for i, layer in enumerate(bb.layers):
|
| 145 |
x = layer(x, rope_cos, rope_sin, mask)
|
| 146 |
-
if self.hidden_films is not None
|
| 147 |
x = self.hidden_films[i](x)
|
| 148 |
|
| 149 |
x = bb.final_norm(x)
|
|
@@ -168,10 +170,11 @@ class HybridCLM(nn.Module):
|
|
| 168 |
rope_sin = bb.rope_sin[:, :, :T_new, :]
|
| 169 |
|
| 170 |
new_kv_cache = []
|
| 171 |
-
for i
|
|
|
|
| 172 |
layer_cache = kv_cache[i] if kv_cache is not None else None
|
| 173 |
-
x, new_cache =
|
| 174 |
-
if self.hidden_films is not None
|
| 175 |
x = self.hidden_films[i](x)
|
| 176 |
new_kv_cache.append(new_cache)
|
| 177 |
|
|
@@ -187,7 +190,8 @@ class HybridCLM(nn.Module):
|
|
| 187 |
def lora_parameters(self) -> list[nn.Parameter]:
|
| 188 |
"""Return only LoRA A/B parameters."""
|
| 189 |
params = []
|
| 190 |
-
for
|
|
|
|
| 191 |
for proj_name in self.attn_targets:
|
| 192 |
module = getattr(block.attn, proj_name)
|
| 193 |
if isinstance(module, LoRALinear):
|
|
@@ -206,7 +210,7 @@ class HybridCLM(nn.Module):
|
|
| 206 |
params = []
|
| 207 |
if self.hidden_films is not None:
|
| 208 |
for film in self.hidden_films:
|
| 209 |
-
if film
|
| 210 |
params.extend(film.parameters())
|
| 211 |
if self.output_film is not None:
|
| 212 |
params.extend(self.output_film.parameters())
|
|
@@ -233,7 +237,8 @@ class HybridCLM(nn.Module):
|
|
| 233 |
report = {}
|
| 234 |
|
| 235 |
# LoRA norms
|
| 236 |
-
for layer_idx
|
|
|
|
| 237 |
for proj_name in self.attn_targets:
|
| 238 |
module = getattr(block.attn, proj_name)
|
| 239 |
if isinstance(module, LoRALinear):
|
|
@@ -247,7 +252,7 @@ class HybridCLM(nn.Module):
|
|
| 247 |
# FiLM norms
|
| 248 |
if self.hidden_films is not None:
|
| 249 |
for i, film in enumerate(self.hidden_films):
|
| 250 |
-
if film
|
| 251 |
report[f"film/hidden_{i}/gamma_dev"] = (film.gamma - 1.0).norm().item()
|
| 252 |
report[f"film/hidden_{i}/beta_norm"] = film.beta.norm().item()
|
| 253 |
if self.output_film is not None:
|
|
|
|
| 65 |
p.requires_grad = False
|
| 66 |
|
| 67 |
# Inject LoRA
|
| 68 |
+
for layer_idx in range(n_layers):
|
| 69 |
if layer_idx not in self.lora_layer_set:
|
| 70 |
continue
|
| 71 |
+
block = backbone.get_block(layer_idx)
|
| 72 |
+
|
| 73 |
attn: Attention = block.attn
|
| 74 |
for proj_name in self.attn_targets:
|
| 75 |
original = getattr(attn, proj_name)
|
|
|
|
| 80 |
original = getattr(ffn, proj_name)
|
| 81 |
setattr(ffn, proj_name, LoRALinear(original, lora_rank, self.lora_alpha))
|
| 82 |
|
| 83 |
+
# Create FiLM layers (Identity for non-adapted layers)
|
| 84 |
if use_film:
|
| 85 |
self.hidden_films = nn.ModuleList()
|
| 86 |
for i in range(n_layers):
|
| 87 |
if i in self.film_layer_set:
|
| 88 |
self.hidden_films.append(FiLMLayer(cfg.d_model))
|
| 89 |
else:
|
| 90 |
+
self.hidden_films.append(nn.Identity())
|
| 91 |
else:
|
| 92 |
self.hidden_films = None
|
| 93 |
|
|
|
|
| 111 |
|
| 112 |
for i, layer in enumerate(bb.layers):
|
| 113 |
x = layer(x, rope_cos, rope_sin, None) # LoRA happens inside
|
| 114 |
+
if self.hidden_films is not None:
|
| 115 |
x = self.hidden_films[i](x)
|
| 116 |
|
| 117 |
return bb.final_norm(x)
|
|
|
|
| 145 |
|
| 146 |
for i, layer in enumerate(bb.layers):
|
| 147 |
x = layer(x, rope_cos, rope_sin, mask)
|
| 148 |
+
if self.hidden_films is not None:
|
| 149 |
x = self.hidden_films[i](x)
|
| 150 |
|
| 151 |
x = bb.final_norm(x)
|
|
|
|
| 170 |
rope_sin = bb.rope_sin[:, :, :T_new, :]
|
| 171 |
|
| 172 |
new_kv_cache = []
|
| 173 |
+
for i in range(len(bb.layers)):
|
| 174 |
+
block = bb.get_block(i)
|
| 175 |
layer_cache = kv_cache[i] if kv_cache is not None else None
|
| 176 |
+
x, new_cache = block.forward_kv(x, rope_cos, rope_sin, layer_cache)
|
| 177 |
+
if self.hidden_films is not None:
|
| 178 |
x = self.hidden_films[i](x)
|
| 179 |
new_kv_cache.append(new_cache)
|
| 180 |
|
|
|
|
| 190 |
def lora_parameters(self) -> list[nn.Parameter]:
|
| 191 |
"""Return only LoRA A/B parameters."""
|
| 192 |
params = []
|
| 193 |
+
for layer_idx in range(len(self.backbone.layers)):
|
| 194 |
+
block = self.backbone.get_block(layer_idx)
|
| 195 |
for proj_name in self.attn_targets:
|
| 196 |
module = getattr(block.attn, proj_name)
|
| 197 |
if isinstance(module, LoRALinear):
|
|
|
|
| 210 |
params = []
|
| 211 |
if self.hidden_films is not None:
|
| 212 |
for film in self.hidden_films:
|
| 213 |
+
if isinstance(film, FiLMLayer):
|
| 214 |
params.extend(film.parameters())
|
| 215 |
if self.output_film is not None:
|
| 216 |
params.extend(self.output_film.parameters())
|
|
|
|
| 237 |
report = {}
|
| 238 |
|
| 239 |
# LoRA norms
|
| 240 |
+
for layer_idx in range(len(self.backbone.layers)):
|
| 241 |
+
block = self.backbone.get_block(layer_idx)
|
| 242 |
for proj_name in self.attn_targets:
|
| 243 |
module = getattr(block.attn, proj_name)
|
| 244 |
if isinstance(module, LoRALinear):
|
|
|
|
| 252 |
# FiLM norms
|
| 253 |
if self.hidden_films is not None:
|
| 254 |
for i, film in enumerate(self.hidden_films):
|
| 255 |
+
if isinstance(film, FiLMLayer):
|
| 256 |
report[f"film/hidden_{i}/gamma_dev"] = (film.gamma - 1.0).norm().item()
|
| 257 |
report[f"film/hidden_{i}/beta_norm"] = film.beta.norm().item()
|
| 258 |
if self.output_film is not None:
|
|
@@ -90,9 +90,10 @@ class LoRACLM(nn.Module):
|
|
| 90 |
p.requires_grad = False
|
| 91 |
|
| 92 |
# Inject LoRA into selected layers
|
| 93 |
-
for layer_idx
|
| 94 |
if layer_idx not in self.adapted_layers:
|
| 95 |
continue
|
|
|
|
| 96 |
|
| 97 |
attn: Attention = block.attn
|
| 98 |
for proj_name in self.attn_targets:
|
|
@@ -182,9 +183,9 @@ class LoRACLM(nn.Module):
|
|
| 182 |
rope_sin = bb.rope_sin[:, :, :T_new, :]
|
| 183 |
|
| 184 |
new_kv_cache = []
|
| 185 |
-
for i
|
| 186 |
layer_cache = kv_cache[i] if kv_cache is not None else None
|
| 187 |
-
x, new_cache =
|
| 188 |
new_kv_cache.append(new_cache)
|
| 189 |
|
| 190 |
x = bb.final_norm(x[:, -1:, :])
|
|
@@ -215,8 +216,8 @@ class LoRACLM(nn.Module):
|
|
| 215 |
def lora_weight_report(self) -> dict[str, float]:
|
| 216 |
"""Per-layer LoRA weight norms for monitoring."""
|
| 217 |
report = {}
|
| 218 |
-
for layer_idx
|
| 219 |
-
attn =
|
| 220 |
for proj_name in self.attn_targets:
|
| 221 |
module = getattr(attn, proj_name)
|
| 222 |
if isinstance(module, LoRALinear):
|
|
@@ -224,7 +225,7 @@ class LoRACLM(nn.Module):
|
|
| 224 |
report[f"layer{layer_idx}.{proj_name}.B"] = module.lora_B.data.norm().item()
|
| 225 |
|
| 226 |
if self.adapt_ffn:
|
| 227 |
-
ffn =
|
| 228 |
for proj_name in _FFN_TARGETS:
|
| 229 |
module = getattr(ffn, proj_name)
|
| 230 |
if isinstance(module, LoRALinear):
|
|
|
|
| 90 |
p.requires_grad = False
|
| 91 |
|
| 92 |
# Inject LoRA into selected layers
|
| 93 |
+
for layer_idx in range(len(backbone.layers)):
|
| 94 |
if layer_idx not in self.adapted_layers:
|
| 95 |
continue
|
| 96 |
+
block = backbone.get_block(layer_idx)
|
| 97 |
|
| 98 |
attn: Attention = block.attn
|
| 99 |
for proj_name in self.attn_targets:
|
|
|
|
| 183 |
rope_sin = bb.rope_sin[:, :, :T_new, :]
|
| 184 |
|
| 185 |
new_kv_cache = []
|
| 186 |
+
for i in range(len(bb.layers)):
|
| 187 |
layer_cache = kv_cache[i] if kv_cache is not None else None
|
| 188 |
+
x, new_cache = bb.get_block(i).forward_kv(x, rope_cos, rope_sin, layer_cache)
|
| 189 |
new_kv_cache.append(new_cache)
|
| 190 |
|
| 191 |
x = bb.final_norm(x[:, -1:, :])
|
|
|
|
| 216 |
def lora_weight_report(self) -> dict[str, float]:
|
| 217 |
"""Per-layer LoRA weight norms for monitoring."""
|
| 218 |
report = {}
|
| 219 |
+
for layer_idx in range(len(self.backbone.layers)):
|
| 220 |
+
attn = self.backbone.get_block(layer_idx).attn
|
| 221 |
for proj_name in self.attn_targets:
|
| 222 |
module = getattr(attn, proj_name)
|
| 223 |
if isinstance(module, LoRALinear):
|
|
|
|
| 225 |
report[f"layer{layer_idx}.{proj_name}.B"] = module.lora_B.data.norm().item()
|
| 226 |
|
| 227 |
if self.adapt_ffn:
|
| 228 |
+
ffn = self.backbone.get_block(layer_idx).ffn
|
| 229 |
for proj_name in _FFN_TARGETS:
|
| 230 |
module = getattr(ffn, proj_name)
|
| 231 |
if isinstance(module, LoRALinear):
|
|
@@ -24,6 +24,8 @@ class SparseLinear(nn.Module):
|
|
| 24 |
output = F.linear(x, W_frozen + delta * mask, bias)
|
| 25 |
"""
|
| 26 |
|
|
|
|
|
|
|
| 27 |
def __init__(self, frozen_linear: nn.Linear, mask: torch.Tensor):
|
| 28 |
super().__init__()
|
| 29 |
self.frozen = frozen_linear
|
|
@@ -82,9 +84,10 @@ class SparseCLM(nn.Module):
|
|
| 82 |
gen = torch.Generator().manual_seed(seed)
|
| 83 |
|
| 84 |
# Inject sparse adapters
|
| 85 |
-
for layer_idx
|
| 86 |
if layer_idx not in self.adapted_layers:
|
| 87 |
continue
|
|
|
|
| 88 |
|
| 89 |
attn: Attention = block.attn
|
| 90 |
for proj_name in self.attn_targets:
|
|
@@ -166,9 +169,9 @@ class SparseCLM(nn.Module):
|
|
| 166 |
rope_sin = bb.rope_sin[:, :, :T_new, :]
|
| 167 |
|
| 168 |
new_kv_cache = []
|
| 169 |
-
for i
|
| 170 |
layer_cache = kv_cache[i] if kv_cache is not None else None
|
| 171 |
-
x, new_cache =
|
| 172 |
new_kv_cache.append(new_cache)
|
| 173 |
|
| 174 |
x = bb.final_norm(x[:, -1:, :])
|
|
@@ -184,7 +187,8 @@ class SparseCLM(nn.Module):
|
|
| 184 |
def n_active_params(self) -> int:
|
| 185 |
"""Count of actually active (masked-in) parameters."""
|
| 186 |
total = 0
|
| 187 |
-
for
|
|
|
|
| 188 |
for proj_name in self.attn_targets:
|
| 189 |
module = getattr(block.attn, proj_name)
|
| 190 |
if isinstance(module, SparseLinear):
|
|
@@ -215,7 +219,8 @@ class SparseCLM(nn.Module):
|
|
| 215 |
def sparse_weight_report(self) -> dict[str, float]:
|
| 216 |
"""Per-layer sparse delta norms for monitoring."""
|
| 217 |
report = {}
|
| 218 |
-
for layer_idx
|
|
|
|
| 219 |
for proj_name in self.attn_targets:
|
| 220 |
module = getattr(block.attn, proj_name)
|
| 221 |
if isinstance(module, SparseLinear):
|
|
|
|
| 24 |
output = F.linear(x, W_frozen + delta * mask, bias)
|
| 25 |
"""
|
| 26 |
|
| 27 |
+
mask: torch.Tensor
|
| 28 |
+
|
| 29 |
def __init__(self, frozen_linear: nn.Linear, mask: torch.Tensor):
|
| 30 |
super().__init__()
|
| 31 |
self.frozen = frozen_linear
|
|
|
|
| 84 |
gen = torch.Generator().manual_seed(seed)
|
| 85 |
|
| 86 |
# Inject sparse adapters
|
| 87 |
+
for layer_idx in range(len(backbone.layers)):
|
| 88 |
if layer_idx not in self.adapted_layers:
|
| 89 |
continue
|
| 90 |
+
block = backbone.get_block(layer_idx)
|
| 91 |
|
| 92 |
attn: Attention = block.attn
|
| 93 |
for proj_name in self.attn_targets:
|
|
|
|
| 169 |
rope_sin = bb.rope_sin[:, :, :T_new, :]
|
| 170 |
|
| 171 |
new_kv_cache = []
|
| 172 |
+
for i in range(len(bb.layers)):
|
| 173 |
layer_cache = kv_cache[i] if kv_cache is not None else None
|
| 174 |
+
x, new_cache = bb.get_block(i).forward_kv(x, rope_cos, rope_sin, layer_cache)
|
| 175 |
new_kv_cache.append(new_cache)
|
| 176 |
|
| 177 |
x = bb.final_norm(x[:, -1:, :])
|
|
|
|
| 187 |
def n_active_params(self) -> int:
|
| 188 |
"""Count of actually active (masked-in) parameters."""
|
| 189 |
total = 0
|
| 190 |
+
for layer_idx in range(len(self.backbone.layers)):
|
| 191 |
+
block = self.backbone.get_block(layer_idx)
|
| 192 |
for proj_name in self.attn_targets:
|
| 193 |
module = getattr(block.attn, proj_name)
|
| 194 |
if isinstance(module, SparseLinear):
|
|
|
|
| 219 |
def sparse_weight_report(self) -> dict[str, float]:
|
| 220 |
"""Per-layer sparse delta norms for monitoring."""
|
| 221 |
report = {}
|
| 222 |
+
for layer_idx in range(len(self.backbone.layers)):
|
| 223 |
+
block = self.backbone.get_block(layer_idx)
|
| 224 |
for proj_name in self.attn_targets:
|
| 225 |
module = getattr(block.attn, proj_name)
|
| 226 |
if isinstance(module, SparseLinear):
|
|
@@ -3,6 +3,7 @@
|
|
| 3 |
import os
|
| 4 |
import threading
|
| 5 |
import time
|
|
|
|
| 6 |
|
| 7 |
import numpy as np
|
| 8 |
import torch
|
|
@@ -19,7 +20,7 @@ from pawn.config import (
|
|
| 19 |
)
|
| 20 |
|
| 21 |
|
| 22 |
-
_positions_cache: dict = {}
|
| 23 |
|
| 24 |
|
| 25 |
|
|
@@ -130,10 +131,10 @@ class CLMDataset(torch.utils.data.IterableDataset):
|
|
| 130 |
self._start_step = 0
|
| 131 |
self._main_pid = os.getpid()
|
| 132 |
|
| 133 |
-
def set_start_step(self, step: int):
|
| 134 |
self._start_step = step
|
| 135 |
|
| 136 |
-
def __iter__(self):
|
| 137 |
worker_info = torch.utils.data.get_worker_info()
|
| 138 |
worker_id = worker_info.id if worker_info else 0
|
| 139 |
num_workers = worker_info.num_workers if worker_info else 1
|
|
|
|
| 3 |
import os
|
| 4 |
import threading
|
| 5 |
import time
|
| 6 |
+
from collections.abc import Iterator
|
| 7 |
|
| 8 |
import numpy as np
|
| 9 |
import torch
|
|
|
|
| 20 |
)
|
| 21 |
|
| 22 |
|
| 23 |
+
_positions_cache: dict[tuple[str, int], torch.Tensor] = {}
|
| 24 |
|
| 25 |
|
| 26 |
|
|
|
|
| 131 |
self._start_step = 0
|
| 132 |
self._main_pid = os.getpid()
|
| 133 |
|
| 134 |
+
def set_start_step(self, step: int) -> None:
|
| 135 |
self._start_step = step
|
| 136 |
|
| 137 |
+
def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
|
| 138 |
worker_info = torch.utils.data.get_worker_info()
|
| 139 |
worker_id = worker_info.id if worker_info else 0
|
| 140 |
num_workers = worker_info.num_workers if worker_info else 1
|
|
@@ -10,6 +10,7 @@ Storage layout:
|
|
| 10 |
|
| 11 |
import json
|
| 12 |
import time
|
|
|
|
| 13 |
from pathlib import Path
|
| 14 |
|
| 15 |
import numpy as np
|
|
@@ -32,7 +33,7 @@ def _popcount_u64(arr: np.ndarray) -> np.ndarray:
|
|
| 32 |
return result
|
| 33 |
|
| 34 |
|
| 35 |
-
def _count_legal_moves(move_ids, game_lengths):
|
| 36 |
"""Legal move count per ply via bit-packed grids + promo mask."""
|
| 37 |
grid, promo_mask = engine.compute_legal_move_masks(move_ids, game_lengths)
|
| 38 |
grid_counts = np.zeros(grid.shape[:2], dtype=np.uint32)
|
|
@@ -384,7 +385,7 @@ _PHASES = [("ply_1_20", 0, 20), ("ply_21_80", 20, 80),
|
|
| 384 |
("ply_81_150", 80, 150), ("ply_150_plus", 150, 9999)]
|
| 385 |
|
| 386 |
|
| 387 |
-
def _iter_position_parts(corpus: dict):
|
| 388 |
"""Yield each position parquet part as an eager DataFrame."""
|
| 389 |
from pathlib import Path
|
| 390 |
# Find corpus dir from the LazyFrame's file path
|
|
@@ -395,7 +396,7 @@ def _iter_position_parts(corpus: dict):
|
|
| 395 |
yield pl.read_parquet(f)
|
| 396 |
|
| 397 |
|
| 398 |
-
def _new_accumulator() -> dict:
|
| 399 |
return {
|
| 400 |
"n": 0, "sum_k": 0.0, "sum_k_sq": 0.0, "k_min": 999, "k_max": 0,
|
| 401 |
"sum_inv_k": 0.0, "sum_inv_k_sq": 0.0,
|
|
@@ -408,7 +409,7 @@ def _new_accumulator() -> dict:
|
|
| 408 |
}
|
| 409 |
|
| 410 |
|
| 411 |
-
def _accumulate(acc: dict, df: pl.DataFrame):
|
| 412 |
"""Accumulate stats from one chunk (already filtered to k > 0)."""
|
| 413 |
k = df["k"].to_numpy().astype(np.float64)
|
| 414 |
ply = df["ply"].to_numpy()
|
|
@@ -455,7 +456,7 @@ def _accumulate(acc: dict, df: pl.DataFrame):
|
|
| 455 |
del k, ply, chk, inv_k, ln_k, top5
|
| 456 |
|
| 457 |
|
| 458 |
-
def _finalize_k_stats(acc):
|
| 459 |
N = acc["n"]
|
| 460 |
mean = acc["sum_k"] / N
|
| 461 |
var = acc["sum_k_sq"] / N - mean ** 2
|
|
@@ -466,13 +467,13 @@ def _finalize_k_stats(acc):
|
|
| 466 |
"min": acc["k_min"], "max": acc["k_max"]}
|
| 467 |
|
| 468 |
|
| 469 |
-
def _finalize_k_hist(acc):
|
| 470 |
h = acc["k_hist"]
|
| 471 |
nz = h > 0
|
| 472 |
return {"values": np.arange(300)[nz].tolist(), "counts": h[nz].tolist(), "total": acc["n"]}
|
| 473 |
|
| 474 |
|
| 475 |
-
def _finalize_phases(acc):
|
| 476 |
result = {}
|
| 477 |
for name, _, _ in _PHASES:
|
| 478 |
c = acc[f"{name}_n"]
|
|
@@ -486,7 +487,7 @@ def _finalize_phases(acc):
|
|
| 486 |
return result
|
| 487 |
|
| 488 |
|
| 489 |
-
def _finalize_checks(acc):
|
| 490 |
N = acc["n"]
|
| 491 |
result = {}
|
| 492 |
for label in ("chk", "nochk"):
|
|
|
|
| 10 |
|
| 11 |
import json
|
| 12 |
import time
|
| 13 |
+
from collections.abc import Iterator
|
| 14 |
from pathlib import Path
|
| 15 |
|
| 16 |
import numpy as np
|
|
|
|
| 33 |
return result
|
| 34 |
|
| 35 |
|
| 36 |
+
def _count_legal_moves(move_ids: np.ndarray, game_lengths: np.ndarray) -> np.ndarray:
|
| 37 |
"""Legal move count per ply via bit-packed grids + promo mask."""
|
| 38 |
grid, promo_mask = engine.compute_legal_move_masks(move_ids, game_lengths)
|
| 39 |
grid_counts = np.zeros(grid.shape[:2], dtype=np.uint32)
|
|
|
|
| 385 |
("ply_81_150", 80, 150), ("ply_150_plus", 150, 9999)]
|
| 386 |
|
| 387 |
|
| 388 |
+
def _iter_position_parts(corpus: dict) -> Iterator[pl.DataFrame]:
|
| 389 |
"""Yield each position parquet part as an eager DataFrame."""
|
| 390 |
from pathlib import Path
|
| 391 |
# Find corpus dir from the LazyFrame's file path
|
|
|
|
| 396 |
yield pl.read_parquet(f)
|
| 397 |
|
| 398 |
|
| 399 |
+
def _new_accumulator() -> dict[str, int | float | np.ndarray]:
|
| 400 |
return {
|
| 401 |
"n": 0, "sum_k": 0.0, "sum_k_sq": 0.0, "k_min": 999, "k_max": 0,
|
| 402 |
"sum_inv_k": 0.0, "sum_inv_k_sq": 0.0,
|
|
|
|
| 409 |
}
|
| 410 |
|
| 411 |
|
| 412 |
+
def _accumulate(acc: dict, df: pl.DataFrame) -> None:
|
| 413 |
"""Accumulate stats from one chunk (already filtered to k > 0)."""
|
| 414 |
k = df["k"].to_numpy().astype(np.float64)
|
| 415 |
ply = df["ply"].to_numpy()
|
|
|
|
| 456 |
del k, ply, chk, inv_k, ln_k, top5
|
| 457 |
|
| 458 |
|
| 459 |
+
def _finalize_k_stats(acc: dict) -> dict[str, float | int]:
|
| 460 |
N = acc["n"]
|
| 461 |
mean = acc["sum_k"] / N
|
| 462 |
var = acc["sum_k_sq"] / N - mean ** 2
|
|
|
|
| 467 |
"min": acc["k_min"], "max": acc["k_max"]}
|
| 468 |
|
| 469 |
|
| 470 |
+
def _finalize_k_hist(acc: dict) -> dict:
|
| 471 |
h = acc["k_hist"]
|
| 472 |
nz = h > 0
|
| 473 |
return {"values": np.arange(300)[nz].tolist(), "counts": h[nz].tolist(), "total": acc["n"]}
|
| 474 |
|
| 475 |
|
| 476 |
+
def _finalize_phases(acc: dict) -> dict:
|
| 477 |
result = {}
|
| 478 |
for name, _, _ in _PHASES:
|
| 479 |
c = acc[f"{name}_n"]
|
|
|
|
| 487 |
return result
|
| 488 |
|
| 489 |
|
| 490 |
+
def _finalize_checks(acc: dict) -> dict:
|
| 491 |
N = acc["n"]
|
| 492 |
result = {}
|
| 493 |
for label in ("chk", "nochk"):
|
|
@@ -2,6 +2,7 @@
|
|
| 2 |
|
| 3 |
import numpy as np
|
| 4 |
import torch
|
|
|
|
| 5 |
import torch.nn.functional as F
|
| 6 |
|
| 7 |
import chess_engine as engine
|
|
@@ -39,7 +40,7 @@ def extract_diagnostic_positions(
|
|
| 39 |
corpus: dict,
|
| 40 |
min_per_category: int = 2000,
|
| 41 |
max_per_category: int = 5000,
|
| 42 |
-
) -> dict:
|
| 43 |
"""Extract diagnostic positions from corpus.
|
| 44 |
|
| 45 |
Returns dict[category_name] -> list of dicts with:
|
|
@@ -140,7 +141,7 @@ def _term_code_to_outcome_name(tc: int, gl: int) -> str:
|
|
| 140 |
|
| 141 |
@torch.no_grad()
|
| 142 |
def evaluate_diagnostic_positions(
|
| 143 |
-
model,
|
| 144 |
positions: dict,
|
| 145 |
corpus: dict,
|
| 146 |
device: str,
|
|
|
|
| 2 |
|
| 3 |
import numpy as np
|
| 4 |
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
import torch.nn.functional as F
|
| 7 |
|
| 8 |
import chess_engine as engine
|
|
|
|
| 40 |
corpus: dict,
|
| 41 |
min_per_category: int = 2000,
|
| 42 |
max_per_category: int = 5000,
|
| 43 |
+
) -> dict[str, list[dict]]:
|
| 44 |
"""Extract diagnostic positions from corpus.
|
| 45 |
|
| 46 |
Returns dict[category_name] -> list of dicts with:
|
|
|
|
| 141 |
|
| 142 |
@torch.no_grad()
|
| 143 |
def evaluate_diagnostic_positions(
|
| 144 |
+
model: nn.Module,
|
| 145 |
positions: dict,
|
| 146 |
corpus: dict,
|
| 147 |
device: str,
|
|
@@ -1,9 +1,13 @@
|
|
| 1 |
"""Autoregressive generation for outcome token signal tests (§6)."""
|
| 2 |
|
|
|
|
|
|
|
| 3 |
import gc
|
|
|
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
import torch
|
|
|
|
| 7 |
import torch.nn.functional as F
|
| 8 |
|
| 9 |
import chess_engine as engine
|
|
@@ -12,13 +16,32 @@ from pawn.config import PAD_TOKEN, WHITE_CHECKMATES, PLY_LIMIT, CLMConfig
|
|
| 12 |
from pawn.data import _map_termination_to_outcome
|
| 13 |
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
# ---------------------------------------------------------------------------
|
| 16 |
# Core autoregressive generation
|
| 17 |
# ---------------------------------------------------------------------------
|
| 18 |
|
| 19 |
|
| 20 |
def autoregressive_generate(
|
| 21 |
-
model,
|
| 22 |
outcome_token: int,
|
| 23 |
n_games: int,
|
| 24 |
device: str,
|
|
@@ -28,7 +51,7 @@ def autoregressive_generate(
|
|
| 28 |
max_seq_len: int = 256,
|
| 29 |
temperature: float = 1.0,
|
| 30 |
batch_size: int = 64,
|
| 31 |
-
) -> dict:
|
| 32 |
"""Generate games autoregressively from a trained PAWN.
|
| 33 |
|
| 34 |
Args:
|
|
@@ -74,9 +97,16 @@ def autoregressive_generate(
|
|
| 74 |
|
| 75 |
@torch.no_grad()
|
| 76 |
def _generate_batch(
|
| 77 |
-
model
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
"""Generate a batch of games using batch Rust engine for state management."""
|
| 81 |
cfg_vocab_size = CLMConfig.vocab_size # 4278
|
| 82 |
max_move_positions = max_seq_len - 1 # position 0 is outcome token
|
|
@@ -264,7 +294,7 @@ OUTCOME_TOKENS = {
|
|
| 264 |
|
| 265 |
|
| 266 |
def outcome_signal_test(
|
| 267 |
-
model,
|
| 268 |
device: str,
|
| 269 |
n_per_outcome: int = 1000,
|
| 270 |
mask_conditions: tuple[bool, ...] = (False, True),
|
|
@@ -354,7 +384,7 @@ def _analyze_generated_games(gen: dict, conditioned_outcome: str) -> dict:
|
|
| 354 |
|
| 355 |
|
| 356 |
def prefix_continuation_test(
|
| 357 |
-
model,
|
| 358 |
corpus: dict,
|
| 359 |
device: str,
|
| 360 |
n_per_bucket: int = 200,
|
|
@@ -475,7 +505,7 @@ def prefix_continuation_test(
|
|
| 475 |
return results
|
| 476 |
|
| 477 |
|
| 478 |
-
def _outcome_mask(term_codes, game_lengths, outcome_name):
|
| 479 |
"""Create a boolean mask for games matching the given outcome."""
|
| 480 |
if outcome_name == "WHITE_CHECKMATES":
|
| 481 |
return (term_codes == 0) & (game_lengths % 2 == 1)
|
|
@@ -503,7 +533,7 @@ POISONING_PAIRS = [
|
|
| 503 |
|
| 504 |
|
| 505 |
def poisoned_prefix_test(
|
| 506 |
-
model,
|
| 507 |
corpus: dict,
|
| 508 |
device: str,
|
| 509 |
n_per_pair: int = 500,
|
|
@@ -566,7 +596,7 @@ def poisoned_prefix_test(
|
|
| 566 |
|
| 567 |
|
| 568 |
def impossible_task_test(
|
| 569 |
-
model,
|
| 570 |
corpus: dict,
|
| 571 |
device: str,
|
| 572 |
n_per_scenario: int = 200,
|
|
@@ -652,7 +682,7 @@ def impossible_task_test(
|
|
| 652 |
|
| 653 |
|
| 654 |
def improbable_task_test(
|
| 655 |
-
model,
|
| 656 |
corpus: dict,
|
| 657 |
device: str,
|
| 658 |
n_per_scenario: int = 200,
|
|
|
|
| 1 |
"""Autoregressive generation for outcome token signal tests (§6)."""
|
| 2 |
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
import gc
|
| 6 |
+
from typing import Protocol
|
| 7 |
|
| 8 |
import numpy as np
|
| 9 |
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
import torch.nn.functional as F
|
| 12 |
|
| 13 |
import chess_engine as engine
|
|
|
|
| 16 |
from pawn.data import _map_termination_to_outcome
|
| 17 |
|
| 18 |
|
| 19 |
+
class GenerativeModel(Protocol):
|
| 20 |
+
"""Structural type for models usable in autoregressive generation."""
|
| 21 |
+
|
| 22 |
+
def eval(self) -> nn.Module: ...
|
| 23 |
+
|
| 24 |
+
def __call__(
|
| 25 |
+
self,
|
| 26 |
+
input_ids: torch.Tensor,
|
| 27 |
+
attention_mask: torch.Tensor,
|
| 28 |
+
hidden_only: bool = ...,
|
| 29 |
+
) -> tuple[torch.Tensor, list[torch.Tensor]]: ...
|
| 30 |
+
|
| 31 |
+
def forward_generate(
|
| 32 |
+
self,
|
| 33 |
+
input_ids: torch.Tensor,
|
| 34 |
+
kv_cache: list[tuple[torch.Tensor, torch.Tensor]] | None = ...,
|
| 35 |
+
) -> tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]: ...
|
| 36 |
+
|
| 37 |
+
|
| 38 |
# ---------------------------------------------------------------------------
|
| 39 |
# Core autoregressive generation
|
| 40 |
# ---------------------------------------------------------------------------
|
| 41 |
|
| 42 |
|
| 43 |
def autoregressive_generate(
|
| 44 |
+
model: GenerativeModel,
|
| 45 |
outcome_token: int,
|
| 46 |
n_games: int,
|
| 47 |
device: str,
|
|
|
|
| 51 |
max_seq_len: int = 256,
|
| 52 |
temperature: float = 1.0,
|
| 53 |
batch_size: int = 64,
|
| 54 |
+
) -> dict[str, np.ndarray]:
|
| 55 |
"""Generate games autoregressively from a trained PAWN.
|
| 56 |
|
| 57 |
Args:
|
|
|
|
| 97 |
|
| 98 |
@torch.no_grad()
|
| 99 |
def _generate_batch(
|
| 100 |
+
model: GenerativeModel,
|
| 101 |
+
outcome_token: int,
|
| 102 |
+
n_games: int,
|
| 103 |
+
device: str,
|
| 104 |
+
mask_illegal: bool,
|
| 105 |
+
prefix_moves: np.ndarray | None,
|
| 106 |
+
prefix_lengths: np.ndarray | None,
|
| 107 |
+
max_seq_len: int,
|
| 108 |
+
temperature: float,
|
| 109 |
+
) -> dict[str, np.ndarray]:
|
| 110 |
"""Generate a batch of games using batch Rust engine for state management."""
|
| 111 |
cfg_vocab_size = CLMConfig.vocab_size # 4278
|
| 112 |
max_move_positions = max_seq_len - 1 # position 0 is outcome token
|
|
|
|
| 294 |
|
| 295 |
|
| 296 |
def outcome_signal_test(
|
| 297 |
+
model: GenerativeModel,
|
| 298 |
device: str,
|
| 299 |
n_per_outcome: int = 1000,
|
| 300 |
mask_conditions: tuple[bool, ...] = (False, True),
|
|
|
|
| 384 |
|
| 385 |
|
| 386 |
def prefix_continuation_test(
|
| 387 |
+
model: GenerativeModel,
|
| 388 |
corpus: dict,
|
| 389 |
device: str,
|
| 390 |
n_per_bucket: int = 200,
|
|
|
|
| 505 |
return results
|
| 506 |
|
| 507 |
|
| 508 |
+
def _outcome_mask(term_codes: np.ndarray, game_lengths: np.ndarray, outcome_name: str) -> np.ndarray:
|
| 509 |
"""Create a boolean mask for games matching the given outcome."""
|
| 510 |
if outcome_name == "WHITE_CHECKMATES":
|
| 511 |
return (term_codes == 0) & (game_lengths % 2 == 1)
|
|
|
|
| 533 |
|
| 534 |
|
| 535 |
def poisoned_prefix_test(
|
| 536 |
+
model: GenerativeModel,
|
| 537 |
corpus: dict,
|
| 538 |
device: str,
|
| 539 |
n_per_pair: int = 500,
|
|
|
|
| 596 |
|
| 597 |
|
| 598 |
def impossible_task_test(
|
| 599 |
+
model: GenerativeModel,
|
| 600 |
corpus: dict,
|
| 601 |
device: str,
|
| 602 |
n_per_scenario: int = 200,
|
|
|
|
| 682 |
|
| 683 |
|
| 684 |
def improbable_task_test(
|
| 685 |
+
model: GenerativeModel,
|
| 686 |
corpus: dict,
|
| 687 |
device: str,
|
| 688 |
n_per_scenario: int = 200,
|
|
@@ -5,6 +5,7 @@ from pathlib import Path
|
|
| 5 |
|
| 6 |
import numpy as np
|
| 7 |
import torch
|
|
|
|
| 8 |
|
| 9 |
import chess_engine as engine
|
| 10 |
|
|
@@ -119,7 +120,7 @@ def _extract_elos_from_pgn(pgn_path: Path, max_games: int) -> list[tuple[int, in
|
|
| 119 |
|
| 120 |
@torch.no_grad()
|
| 121 |
def evaluate_on_lichess(
|
| 122 |
-
model,
|
| 123 |
lichess_data: dict,
|
| 124 |
device: str,
|
| 125 |
max_seq_len: int = 256,
|
|
|
|
| 5 |
|
| 6 |
import numpy as np
|
| 7 |
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
|
| 10 |
import chess_engine as engine
|
| 11 |
|
|
|
|
| 120 |
|
| 121 |
@torch.no_grad()
|
| 122 |
def evaluate_on_lichess(
|
| 123 |
+
model: nn.Module,
|
| 124 |
lichess_data: dict,
|
| 125 |
device: str,
|
| 126 |
max_seq_len: int = 256,
|
|
@@ -295,7 +295,7 @@ def _extract_targets(
|
|
| 295 |
# ---------------------------------------------------------------------------
|
| 296 |
|
| 297 |
|
| 298 |
-
def _compute_loss(logits, targets, loss_type, n_outputs):
|
| 299 |
if loss_type == "ce":
|
| 300 |
return F.cross_entropy(logits, targets)
|
| 301 |
elif loss_type == "ce_per_square":
|
|
@@ -308,7 +308,7 @@ def _compute_loss(logits, targets, loss_type, n_outputs):
|
|
| 308 |
raise ValueError(f"Unknown loss type: {loss_type}")
|
| 309 |
|
| 310 |
|
| 311 |
-
def _compute_accuracy(logits, targets, loss_type, n_outputs):
|
| 312 |
if loss_type == "ce":
|
| 313 |
preds = logits.argmax(dim=-1)
|
| 314 |
return (preds == targets).float().mean().item()
|
|
@@ -328,7 +328,7 @@ def _compute_accuracy(logits, targets, loss_type, n_outputs):
|
|
| 328 |
raise ValueError(f"Unknown loss type: {loss_type}")
|
| 329 |
|
| 330 |
|
| 331 |
-
def _compute_mae(logits, targets):
|
| 332 |
"""Mean absolute error for regression probes."""
|
| 333 |
return (logits - targets).abs().mean().item()
|
| 334 |
|
|
@@ -427,7 +427,7 @@ def train_single_probe(
|
|
| 427 |
)
|
| 428 |
|
| 429 |
|
| 430 |
-
def _eval_in_batches(probe, h, t, loss_type, n_outputs, device, batch_size):
|
| 431 |
"""Accuracy in mini-batches."""
|
| 432 |
total_correct = 0.0
|
| 433 |
total = 0
|
|
@@ -441,7 +441,7 @@ def _eval_in_batches(probe, h, t, loss_type, n_outputs, device, batch_size):
|
|
| 441 |
return total_correct / total if total > 0 else 0.0
|
| 442 |
|
| 443 |
|
| 444 |
-
def _eval_loss_in_batches(probe, h, t, loss_type, n_outputs, device, batch_size):
|
| 445 |
"""Loss in mini-batches (returns scalar)."""
|
| 446 |
total_loss = 0.0
|
| 447 |
total = 0
|
|
@@ -455,7 +455,7 @@ def _eval_loss_in_batches(probe, h, t, loss_type, n_outputs, device, batch_size)
|
|
| 455 |
return total_loss / total if total > 0 else 0.0
|
| 456 |
|
| 457 |
|
| 458 |
-
def _eval_mae_in_batches(probe, h, t, device, batch_size):
|
| 459 |
"""MAE in mini-batches."""
|
| 460 |
total_ae = 0.0
|
| 461 |
total = 0
|
|
|
|
| 295 |
# ---------------------------------------------------------------------------
|
| 296 |
|
| 297 |
|
| 298 |
+
def _compute_loss(logits: torch.Tensor, targets: torch.Tensor, loss_type: str, n_outputs: int) -> torch.Tensor:
|
| 299 |
if loss_type == "ce":
|
| 300 |
return F.cross_entropy(logits, targets)
|
| 301 |
elif loss_type == "ce_per_square":
|
|
|
|
| 308 |
raise ValueError(f"Unknown loss type: {loss_type}")
|
| 309 |
|
| 310 |
|
| 311 |
+
def _compute_accuracy(logits: torch.Tensor, targets: torch.Tensor, loss_type: str, n_outputs: int) -> float:
|
| 312 |
if loss_type == "ce":
|
| 313 |
preds = logits.argmax(dim=-1)
|
| 314 |
return (preds == targets).float().mean().item()
|
|
|
|
| 328 |
raise ValueError(f"Unknown loss type: {loss_type}")
|
| 329 |
|
| 330 |
|
| 331 |
+
def _compute_mae(logits: torch.Tensor, targets: torch.Tensor) -> float:
|
| 332 |
"""Mean absolute error for regression probes."""
|
| 333 |
return (logits - targets).abs().mean().item()
|
| 334 |
|
|
|
|
| 427 |
)
|
| 428 |
|
| 429 |
|
| 430 |
+
def _eval_in_batches(probe: LinearProbe, h: torch.Tensor, t: torch.Tensor, loss_type: str, n_outputs: int, device: str, batch_size: int) -> float:
|
| 431 |
"""Accuracy in mini-batches."""
|
| 432 |
total_correct = 0.0
|
| 433 |
total = 0
|
|
|
|
| 441 |
return total_correct / total if total > 0 else 0.0
|
| 442 |
|
| 443 |
|
| 444 |
+
def _eval_loss_in_batches(probe: LinearProbe, h: torch.Tensor, t: torch.Tensor, loss_type: str, n_outputs: int, device: str, batch_size: int) -> float:
|
| 445 |
"""Loss in mini-batches (returns scalar)."""
|
| 446 |
total_loss = 0.0
|
| 447 |
total = 0
|
|
|
|
| 455 |
return total_loss / total if total > 0 else 0.0
|
| 456 |
|
| 457 |
|
| 458 |
+
def _eval_mae_in_batches(probe: LinearProbe, h: torch.Tensor, t: torch.Tensor, device: str, batch_size: int) -> float:
|
| 459 |
"""MAE in mini-batches."""
|
| 460 |
total_ae = 0.0
|
| 461 |
total = 0
|
|
@@ -3,6 +3,7 @@
|
|
| 3 |
import numpy as np
|
| 4 |
import matplotlib.pyplot as plt
|
| 5 |
import matplotlib.ticker as mticker
|
|
|
|
| 6 |
import seaborn as sns
|
| 7 |
|
| 8 |
# Consistent style
|
|
@@ -26,11 +27,12 @@ GRID_PAWN_BASELINES = {
|
|
| 26 |
# ---------------------------------------------------------------------------
|
| 27 |
|
| 28 |
|
| 29 |
-
def plot_game_length_distribution(stats: dict, ax=None) -> plt.Figure:
|
| 30 |
"""Histogram of game lengths."""
|
| 31 |
fig = None
|
| 32 |
if ax is None:
|
| 33 |
fig, ax = plt.subplots(figsize=FIGSIZE)
|
|
|
|
| 34 |
counts = stats["game_length"]["histogram_counts"]
|
| 35 |
edges = stats["game_length"]["histogram_edges"]
|
| 36 |
centers = [(edges[i] + edges[i + 1]) / 2 for i in range(len(counts))]
|
|
@@ -44,11 +46,12 @@ def plot_game_length_distribution(stats: dict, ax=None) -> plt.Figure:
|
|
| 44 |
return fig or ax.figure
|
| 45 |
|
| 46 |
|
| 47 |
-
def plot_legal_move_distribution(bounds: dict, ax=None) -> plt.Figure:
|
| 48 |
"""Histogram of legal move counts (K) from pre-computed histogram data."""
|
| 49 |
fig = None
|
| 50 |
if ax is None:
|
| 51 |
fig, ax = plt.subplots(figsize=FIGSIZE)
|
|
|
|
| 52 |
k_hist = bounds["k_histogram"]
|
| 53 |
k_vals = np.array(k_hist["values"])
|
| 54 |
k_counts = np.array(k_hist["counts"], dtype=np.float64)
|
|
@@ -64,11 +67,12 @@ def plot_legal_move_distribution(bounds: dict, ax=None) -> plt.Figure:
|
|
| 64 |
return fig or ax.figure
|
| 65 |
|
| 66 |
|
| 67 |
-
def plot_outcome_rates(stats: dict, ax=None) -> plt.Figure:
|
| 68 |
"""Bar chart of outcome base rates."""
|
| 69 |
fig = None
|
| 70 |
if ax is None:
|
| 71 |
fig, ax = plt.subplots(figsize=FIGSIZE)
|
|
|
|
| 72 |
rates = stats["outcome_rates"]
|
| 73 |
names = list(rates.keys())
|
| 74 |
values = [rates[n] * 100 for n in names]
|
|
@@ -82,11 +86,12 @@ def plot_outcome_rates(stats: dict, ax=None) -> plt.Figure:
|
|
| 82 |
return fig or ax.figure
|
| 83 |
|
| 84 |
|
| 85 |
-
def plot_k_by_phase(bounds: dict, ax=None) -> plt.Figure:
|
| 86 |
"""E[1/K] by game phase."""
|
| 87 |
fig = None
|
| 88 |
if ax is None:
|
| 89 |
fig, ax = plt.subplots(figsize=FIGSIZE)
|
|
|
|
| 90 |
phase_data = bounds["phase_bounds"]
|
| 91 |
names = list(phase_data.keys())
|
| 92 |
e_inv_k = [phase_data[n]["e_1_over_k"] * 100 for n in names]
|
|
@@ -104,11 +109,12 @@ def plot_k_by_phase(bounds: dict, ax=None) -> plt.Figure:
|
|
| 104 |
return fig or ax.figure
|
| 105 |
|
| 106 |
|
| 107 |
-
def plot_prefix_histogram(sanity: dict, ax=None) -> plt.Figure:
|
| 108 |
"""Histogram of common prefix lengths."""
|
| 109 |
fig = None
|
| 110 |
if ax is None:
|
| 111 |
fig, ax = plt.subplots(figsize=FIGSIZE)
|
|
|
|
| 112 |
hist = sanity["prefix_length_histogram"]
|
| 113 |
ks = sorted(hist.keys())
|
| 114 |
vs = [hist[k] for k in ks]
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
import matplotlib.pyplot as plt
|
| 5 |
import matplotlib.ticker as mticker
|
| 6 |
+
from matplotlib.axes import Axes
|
| 7 |
import seaborn as sns
|
| 8 |
|
| 9 |
# Consistent style
|
|
|
|
| 27 |
# ---------------------------------------------------------------------------
|
| 28 |
|
| 29 |
|
| 30 |
+
def plot_game_length_distribution(stats: dict, ax: Axes | None = None) -> plt.Figure:
|
| 31 |
"""Histogram of game lengths."""
|
| 32 |
fig = None
|
| 33 |
if ax is None:
|
| 34 |
fig, ax = plt.subplots(figsize=FIGSIZE)
|
| 35 |
+
assert ax is not None
|
| 36 |
counts = stats["game_length"]["histogram_counts"]
|
| 37 |
edges = stats["game_length"]["histogram_edges"]
|
| 38 |
centers = [(edges[i] + edges[i + 1]) / 2 for i in range(len(counts))]
|
|
|
|
| 46 |
return fig or ax.figure
|
| 47 |
|
| 48 |
|
| 49 |
+
def plot_legal_move_distribution(bounds: dict, ax: Axes | None = None) -> plt.Figure:
|
| 50 |
"""Histogram of legal move counts (K) from pre-computed histogram data."""
|
| 51 |
fig = None
|
| 52 |
if ax is None:
|
| 53 |
fig, ax = plt.subplots(figsize=FIGSIZE)
|
| 54 |
+
assert ax is not None
|
| 55 |
k_hist = bounds["k_histogram"]
|
| 56 |
k_vals = np.array(k_hist["values"])
|
| 57 |
k_counts = np.array(k_hist["counts"], dtype=np.float64)
|
|
|
|
| 67 |
return fig or ax.figure
|
| 68 |
|
| 69 |
|
| 70 |
+
def plot_outcome_rates(stats: dict, ax: Axes | None = None) -> plt.Figure:
|
| 71 |
"""Bar chart of outcome base rates."""
|
| 72 |
fig = None
|
| 73 |
if ax is None:
|
| 74 |
fig, ax = plt.subplots(figsize=FIGSIZE)
|
| 75 |
+
assert ax is not None
|
| 76 |
rates = stats["outcome_rates"]
|
| 77 |
names = list(rates.keys())
|
| 78 |
values = [rates[n] * 100 for n in names]
|
|
|
|
| 86 |
return fig or ax.figure
|
| 87 |
|
| 88 |
|
| 89 |
+
def plot_k_by_phase(bounds: dict, ax: Axes | None = None) -> plt.Figure:
|
| 90 |
"""E[1/K] by game phase."""
|
| 91 |
fig = None
|
| 92 |
if ax is None:
|
| 93 |
fig, ax = plt.subplots(figsize=FIGSIZE)
|
| 94 |
+
assert ax is not None
|
| 95 |
phase_data = bounds["phase_bounds"]
|
| 96 |
names = list(phase_data.keys())
|
| 97 |
e_inv_k = [phase_data[n]["e_1_over_k"] * 100 for n in names]
|
|
|
|
| 109 |
return fig or ax.figure
|
| 110 |
|
| 111 |
|
| 112 |
+
def plot_prefix_histogram(sanity: dict, ax: Axes | None = None) -> plt.Figure:
|
| 113 |
"""Histogram of common prefix lengths."""
|
| 114 |
fig = None
|
| 115 |
if ax is None:
|
| 116 |
fig, ax = plt.subplots(figsize=FIGSIZE)
|
| 117 |
+
assert ax is not None
|
| 118 |
hist = sanity["prefix_length_histogram"]
|
| 119 |
ks = sorted(hist.keys())
|
| 120 |
vs = [hist[k] for k in ks]
|
|
@@ -20,7 +20,12 @@ from __future__ import annotations
|
|
| 20 |
|
| 21 |
import gc
|
| 22 |
import multiprocessing as mp
|
|
|
|
| 23 |
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
# Use "spawn" so the child gets a clean process with no inherited GPU state.
|
| 26 |
_ctx = mp.get_context("spawn")
|
|
@@ -31,11 +36,11 @@ _ctx = mp.get_context("spawn")
|
|
| 31 |
# ---------------------------------------------------------------------------
|
| 32 |
|
| 33 |
|
| 34 |
-
def _worker_entry(fn, args, kwargs):
|
| 35 |
return fn(*args, **kwargs)
|
| 36 |
|
| 37 |
|
| 38 |
-
def run_in_worker(fn, *args, timeout: float | None = None, **kwargs):
|
| 39 |
"""Run fn(*args, **kwargs) in an isolated worker process.
|
| 40 |
|
| 41 |
On KeyboardInterrupt, the worker is terminated and the interrupt is
|
|
@@ -55,7 +60,7 @@ def run_in_worker(fn, *args, timeout: float | None = None, **kwargs):
|
|
| 55 |
# ---------------------------------------------------------------------------
|
| 56 |
|
| 57 |
|
| 58 |
-
def _load_model(checkpoint_path: str, device: str):
|
| 59 |
"""Load and freeze a PAWNCLM checkpoint. Runs inside worker processes."""
|
| 60 |
import torch
|
| 61 |
from pawn.config import CLMConfig
|
|
@@ -83,8 +88,8 @@ def _load_corpus(corpus_dir: str) -> dict:
|
|
| 83 |
# ---------------------------------------------------------------------------
|
| 84 |
|
| 85 |
|
| 86 |
-
def _probes_worker(checkpoint_path, device, n_train, n_val
|
| 87 |
-
seed_train, seed_val):
|
| 88 |
from pawn.eval_suite.probes import extract_probe_data, train_all_probes
|
| 89 |
model = _load_model(checkpoint_path, device)
|
| 90 |
train_data = extract_probe_data(n_train, max_ply=256, seed=seed_train)
|
|
@@ -94,7 +99,7 @@ def _probes_worker(checkpoint_path, device, n_train, n_val, n_epochs,
|
|
| 94 |
|
| 95 |
|
| 96 |
def run_probes(
|
| 97 |
-
checkpoint_path,
|
| 98 |
device: str,
|
| 99 |
n_train: int = 5_000,
|
| 100 |
n_val: int = 1_000,
|
|
@@ -110,7 +115,8 @@ def run_probes(
|
|
| 110 |
)
|
| 111 |
|
| 112 |
|
| 113 |
-
def _signal_test_worker(checkpoint_path, device
|
|
|
|
| 114 |
from pawn.eval_suite.generation import outcome_signal_test
|
| 115 |
model = _load_model(checkpoint_path, device)
|
| 116 |
return outcome_signal_test(model, device, n_per_outcome=n_per_outcome,
|
|
@@ -118,7 +124,7 @@ def _signal_test_worker(checkpoint_path, device, n_per_outcome, mask_conditions)
|
|
| 118 |
|
| 119 |
|
| 120 |
def run_outcome_signal_test(
|
| 121 |
-
checkpoint_path,
|
| 122 |
device: str,
|
| 123 |
n_per_outcome: int = 1000,
|
| 124 |
mask_conditions: tuple[bool, ...] = (False, True),
|
|
@@ -130,8 +136,9 @@ def run_outcome_signal_test(
|
|
| 130 |
)
|
| 131 |
|
| 132 |
|
| 133 |
-
def _prefix_continuation_worker(checkpoint_path, corpus_dir, device,
|
| 134 |
-
n_per_bucket
|
|
|
|
| 135 |
from pawn.eval_suite.generation import prefix_continuation_test
|
| 136 |
model = _load_model(checkpoint_path, device)
|
| 137 |
corpus = _load_corpus(corpus_dir)
|
|
@@ -142,8 +149,8 @@ def _prefix_continuation_worker(checkpoint_path, corpus_dir, device,
|
|
| 142 |
|
| 143 |
|
| 144 |
def run_prefix_continuation_test(
|
| 145 |
-
checkpoint_path,
|
| 146 |
-
corpus_dir,
|
| 147 |
device: str,
|
| 148 |
n_per_bucket: int = 200,
|
| 149 |
prefix_pcts: tuple[float, ...] = (0.1, 0.5, 0.9),
|
|
@@ -157,8 +164,8 @@ def run_prefix_continuation_test(
|
|
| 157 |
)
|
| 158 |
|
| 159 |
|
| 160 |
-
def _poisoned_prefix_worker(checkpoint_path, corpus_dir, device,
|
| 161 |
-
n_per_pair, prefix_pct):
|
| 162 |
from pawn.eval_suite.generation import poisoned_prefix_test
|
| 163 |
model = _load_model(checkpoint_path, device)
|
| 164 |
corpus = _load_corpus(corpus_dir)
|
|
@@ -167,8 +174,8 @@ def _poisoned_prefix_worker(checkpoint_path, corpus_dir, device,
|
|
| 167 |
|
| 168 |
|
| 169 |
def run_poisoned_prefix_test(
|
| 170 |
-
checkpoint_path,
|
| 171 |
-
corpus_dir,
|
| 172 |
device: str,
|
| 173 |
n_per_pair: int = 500,
|
| 174 |
prefix_pct: float = 0.5,
|
|
@@ -180,7 +187,8 @@ def run_poisoned_prefix_test(
|
|
| 180 |
)
|
| 181 |
|
| 182 |
|
| 183 |
-
def _impossible_task_worker(checkpoint_path, corpus_dir
|
|
|
|
| 184 |
from pawn.eval_suite.generation import impossible_task_test
|
| 185 |
model = _load_model(checkpoint_path, device)
|
| 186 |
corpus = _load_corpus(corpus_dir)
|
|
@@ -188,8 +196,8 @@ def _impossible_task_worker(checkpoint_path, corpus_dir, device, n_per_scenario)
|
|
| 188 |
|
| 189 |
|
| 190 |
def run_impossible_task_test(
|
| 191 |
-
checkpoint_path,
|
| 192 |
-
corpus_dir,
|
| 193 |
device: str,
|
| 194 |
n_per_scenario: int = 200,
|
| 195 |
) -> dict:
|
|
@@ -200,7 +208,8 @@ def run_impossible_task_test(
|
|
| 200 |
)
|
| 201 |
|
| 202 |
|
| 203 |
-
def _improbable_task_worker(checkpoint_path, corpus_dir
|
|
|
|
| 204 |
from pawn.eval_suite.generation import improbable_task_test
|
| 205 |
model = _load_model(checkpoint_path, device)
|
| 206 |
corpus = _load_corpus(corpus_dir)
|
|
@@ -208,8 +217,8 @@ def _improbable_task_worker(checkpoint_path, corpus_dir, device, n_per_scenario)
|
|
| 208 |
|
| 209 |
|
| 210 |
def run_improbable_task_test(
|
| 211 |
-
checkpoint_path,
|
| 212 |
-
corpus_dir,
|
| 213 |
device: str,
|
| 214 |
n_per_scenario: int = 200,
|
| 215 |
) -> dict:
|
|
@@ -220,8 +229,9 @@ def run_improbable_task_test(
|
|
| 220 |
)
|
| 221 |
|
| 222 |
|
| 223 |
-
def _diagnostic_worker(checkpoint_path, corpus_dir, device
|
| 224 |
-
|
|
|
|
| 225 |
from pawn.eval_suite.diagnostics import (
|
| 226 |
extract_diagnostic_positions, evaluate_diagnostic_positions,
|
| 227 |
)
|
|
@@ -240,8 +250,8 @@ def _diagnostic_worker(checkpoint_path, corpus_dir, device, min_per_category,
|
|
| 240 |
|
| 241 |
|
| 242 |
def run_diagnostic_eval(
|
| 243 |
-
checkpoint_path,
|
| 244 |
-
corpus_dir,
|
| 245 |
device: str,
|
| 246 |
min_per_category: int = 2000,
|
| 247 |
max_per_category: int = 5000,
|
|
|
|
| 20 |
|
| 21 |
import gc
|
| 22 |
import multiprocessing as mp
|
| 23 |
+
from collections.abc import Callable
|
| 24 |
from pathlib import Path
|
| 25 |
+
from typing import Any, TYPE_CHECKING
|
| 26 |
+
|
| 27 |
+
if TYPE_CHECKING:
|
| 28 |
+
from pawn.model import PAWNCLM
|
| 29 |
|
| 30 |
# Use "spawn" so the child gets a clean process with no inherited GPU state.
|
| 31 |
_ctx = mp.get_context("spawn")
|
|
|
|
| 36 |
# ---------------------------------------------------------------------------
|
| 37 |
|
| 38 |
|
| 39 |
+
def _worker_entry(fn: Callable[..., Any], args: tuple, kwargs: dict) -> Any:
|
| 40 |
return fn(*args, **kwargs)
|
| 41 |
|
| 42 |
|
| 43 |
+
def run_in_worker(fn: Callable[..., Any], *args: Any, timeout: float | None = None, **kwargs: Any) -> Any:
|
| 44 |
"""Run fn(*args, **kwargs) in an isolated worker process.
|
| 45 |
|
| 46 |
On KeyboardInterrupt, the worker is terminated and the interrupt is
|
|
|
|
| 60 |
# ---------------------------------------------------------------------------
|
| 61 |
|
| 62 |
|
| 63 |
+
def _load_model(checkpoint_path: str, device: str) -> PAWNCLM:
|
| 64 |
"""Load and freeze a PAWNCLM checkpoint. Runs inside worker processes."""
|
| 65 |
import torch
|
| 66 |
from pawn.config import CLMConfig
|
|
|
|
| 88 |
# ---------------------------------------------------------------------------
|
| 89 |
|
| 90 |
|
| 91 |
+
def _probes_worker(checkpoint_path: str, device: str, n_train: int, n_val: int,
|
| 92 |
+
n_epochs: int, seed_train: int, seed_val: int) -> dict:
|
| 93 |
from pawn.eval_suite.probes import extract_probe_data, train_all_probes
|
| 94 |
model = _load_model(checkpoint_path, device)
|
| 95 |
train_data = extract_probe_data(n_train, max_ply=256, seed=seed_train)
|
|
|
|
| 99 |
|
| 100 |
|
| 101 |
def run_probes(
|
| 102 |
+
checkpoint_path: str | Path,
|
| 103 |
device: str,
|
| 104 |
n_train: int = 5_000,
|
| 105 |
n_val: int = 1_000,
|
|
|
|
| 115 |
)
|
| 116 |
|
| 117 |
|
| 118 |
+
def _signal_test_worker(checkpoint_path: str, device: str, n_per_outcome: int,
|
| 119 |
+
mask_conditions: list[bool]) -> dict:
|
| 120 |
from pawn.eval_suite.generation import outcome_signal_test
|
| 121 |
model = _load_model(checkpoint_path, device)
|
| 122 |
return outcome_signal_test(model, device, n_per_outcome=n_per_outcome,
|
|
|
|
| 124 |
|
| 125 |
|
| 126 |
def run_outcome_signal_test(
|
| 127 |
+
checkpoint_path: str | Path,
|
| 128 |
device: str,
|
| 129 |
n_per_outcome: int = 1000,
|
| 130 |
mask_conditions: tuple[bool, ...] = (False, True),
|
|
|
|
| 136 |
)
|
| 137 |
|
| 138 |
|
| 139 |
+
def _prefix_continuation_worker(checkpoint_path: str, corpus_dir: str, device: str,
|
| 140 |
+
n_per_bucket: int, prefix_pcts: list[float],
|
| 141 |
+
absolute_plies: list[int]) -> dict:
|
| 142 |
from pawn.eval_suite.generation import prefix_continuation_test
|
| 143 |
model = _load_model(checkpoint_path, device)
|
| 144 |
corpus = _load_corpus(corpus_dir)
|
|
|
|
| 149 |
|
| 150 |
|
| 151 |
def run_prefix_continuation_test(
|
| 152 |
+
checkpoint_path: str | Path,
|
| 153 |
+
corpus_dir: str | Path,
|
| 154 |
device: str,
|
| 155 |
n_per_bucket: int = 200,
|
| 156 |
prefix_pcts: tuple[float, ...] = (0.1, 0.5, 0.9),
|
|
|
|
| 164 |
)
|
| 165 |
|
| 166 |
|
| 167 |
+
def _poisoned_prefix_worker(checkpoint_path: str, corpus_dir: str, device: str,
|
| 168 |
+
n_per_pair: int, prefix_pct: float) -> dict:
|
| 169 |
from pawn.eval_suite.generation import poisoned_prefix_test
|
| 170 |
model = _load_model(checkpoint_path, device)
|
| 171 |
corpus = _load_corpus(corpus_dir)
|
|
|
|
| 174 |
|
| 175 |
|
| 176 |
def run_poisoned_prefix_test(
|
| 177 |
+
checkpoint_path: str | Path,
|
| 178 |
+
corpus_dir: str | Path,
|
| 179 |
device: str,
|
| 180 |
n_per_pair: int = 500,
|
| 181 |
prefix_pct: float = 0.5,
|
|
|
|
| 187 |
)
|
| 188 |
|
| 189 |
|
| 190 |
+
def _impossible_task_worker(checkpoint_path: str, corpus_dir: str, device: str,
|
| 191 |
+
n_per_scenario: int) -> dict:
|
| 192 |
from pawn.eval_suite.generation import impossible_task_test
|
| 193 |
model = _load_model(checkpoint_path, device)
|
| 194 |
corpus = _load_corpus(corpus_dir)
|
|
|
|
| 196 |
|
| 197 |
|
| 198 |
def run_impossible_task_test(
|
| 199 |
+
checkpoint_path: str | Path,
|
| 200 |
+
corpus_dir: str | Path,
|
| 201 |
device: str,
|
| 202 |
n_per_scenario: int = 200,
|
| 203 |
) -> dict:
|
|
|
|
| 208 |
)
|
| 209 |
|
| 210 |
|
| 211 |
+
def _improbable_task_worker(checkpoint_path: str, corpus_dir: str, device: str,
|
| 212 |
+
n_per_scenario: int) -> dict:
|
| 213 |
from pawn.eval_suite.generation import improbable_task_test
|
| 214 |
model = _load_model(checkpoint_path, device)
|
| 215 |
corpus = _load_corpus(corpus_dir)
|
|
|
|
| 217 |
|
| 218 |
|
| 219 |
def run_improbable_task_test(
|
| 220 |
+
checkpoint_path: str | Path,
|
| 221 |
+
corpus_dir: str | Path,
|
| 222 |
device: str,
|
| 223 |
n_per_scenario: int = 200,
|
| 224 |
) -> dict:
|
|
|
|
| 229 |
)
|
| 230 |
|
| 231 |
|
| 232 |
+
def _diagnostic_worker(checkpoint_path: str, corpus_dir: str, device: str,
|
| 233 |
+
min_per_category: int, max_per_category: int,
|
| 234 |
+
n_samples: int, batch_size: int) -> dict:
|
| 235 |
from pawn.eval_suite.diagnostics import (
|
| 236 |
extract_diagnostic_positions, evaluate_diagnostic_positions,
|
| 237 |
)
|
|
|
|
| 250 |
|
| 251 |
|
| 252 |
def run_diagnostic_eval(
|
| 253 |
+
checkpoint_path: str | Path,
|
| 254 |
+
corpus_dir: str | Path,
|
| 255 |
device: str,
|
| 256 |
min_per_category: int = 2000,
|
| 257 |
max_per_category: int = 5000,
|
|
@@ -308,10 +308,10 @@ class LichessDataset(torch.utils.data.Dataset):
|
|
| 308 |
self.game_lengths = torch.from_numpy(np.array(self.game_lengths)).share_memory_()
|
| 309 |
return self
|
| 310 |
|
| 311 |
-
def __len__(self):
|
| 312 |
return len(self.input_ids)
|
| 313 |
|
| 314 |
-
def __getitem__(self, idx):
|
| 315 |
return {
|
| 316 |
"input_ids": self.input_ids[idx],
|
| 317 |
"targets": self.targets[idx],
|
|
|
|
| 308 |
self.game_lengths = torch.from_numpy(np.array(self.game_lengths)).share_memory_()
|
| 309 |
return self
|
| 310 |
|
| 311 |
+
def __len__(self) -> int:
|
| 312 |
return len(self.input_ids)
|
| 313 |
|
| 314 |
+
def __getitem__(self, idx: int) -> dict[str, torch.Tensor | int]:
|
| 315 |
return {
|
| 316 |
"input_ids": self.input_ids[idx],
|
| 317 |
"targets": self.targets[idx],
|
|
@@ -72,7 +72,7 @@ class MetricsLogger:
|
|
| 72 |
record_type: Record type (train, eval, batch, etc.)
|
| 73 |
include_resources: Whether to include memory/CPU stats
|
| 74 |
"""
|
| 75 |
-
record = {"type": record_type}
|
| 76 |
|
| 77 |
if step is not None:
|
| 78 |
record["step"] = step
|
|
@@ -119,14 +119,14 @@ class MetricsLogger:
|
|
| 119 |
def close(self) -> None:
|
| 120 |
self._file.close()
|
| 121 |
|
| 122 |
-
def __enter__(self):
|
| 123 |
return self
|
| 124 |
|
| 125 |
-
def __exit__(self, *args):
|
| 126 |
self.close()
|
| 127 |
|
| 128 |
|
| 129 |
-
def _sanitize(obj):
|
| 130 |
"""Replace NaN/Inf with None for valid JSON."""
|
| 131 |
if isinstance(obj, float) and (math.isnan(obj) or math.isinf(obj)):
|
| 132 |
return None
|
|
|
|
| 72 |
record_type: Record type (train, eval, batch, etc.)
|
| 73 |
include_resources: Whether to include memory/CPU stats
|
| 74 |
"""
|
| 75 |
+
record: dict[str, object] = {"type": record_type}
|
| 76 |
|
| 77 |
if step is not None:
|
| 78 |
record["step"] = step
|
|
|
|
| 119 |
def close(self) -> None:
|
| 120 |
self._file.close()
|
| 121 |
|
| 122 |
+
def __enter__(self) -> "MetricsLogger":
|
| 123 |
return self
|
| 124 |
|
| 125 |
+
def __exit__(self, *args: object) -> None:
|
| 126 |
self.close()
|
| 127 |
|
| 128 |
|
| 129 |
+
def _sanitize(obj: object) -> object:
|
| 130 |
"""Replace NaN/Inf with None for valid JSON."""
|
| 131 |
if isinstance(obj, float) and (math.isnan(obj) or math.isinf(obj)):
|
| 132 |
return None
|
|
@@ -179,6 +179,11 @@ class SwiGLUFFN(nn.Module):
|
|
| 179 |
|
| 180 |
|
| 181 |
class TransformerBlock(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
def __init__(self, cfg: CLMConfig):
|
| 183 |
super().__init__()
|
| 184 |
self.attn_norm = RMSNorm(cfg.d_model)
|
|
@@ -220,6 +225,8 @@ class CLMEmbedding(nn.Module):
|
|
| 220 |
PAD and outcome tokens use standalone embeddings.
|
| 221 |
"""
|
| 222 |
|
|
|
|
|
|
|
| 223 |
def __init__(self, cfg: CLMConfig):
|
| 224 |
super().__init__()
|
| 225 |
self.d_model = cfg.d_model
|
|
@@ -273,6 +280,14 @@ class PAWNCLM(nn.Module):
|
|
| 273 |
full vocabulary. No factored output head, no grid, no BCE.
|
| 274 |
"""
|
| 275 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
def __init__(self, cfg: CLMConfig):
|
| 277 |
super().__init__()
|
| 278 |
self.cfg = cfg
|
|
@@ -300,6 +315,10 @@ class PAWNCLM(nn.Module):
|
|
| 300 |
|
| 301 |
self._init_weights()
|
| 302 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
def _init_weights(self):
|
| 304 |
for p in self.parameters():
|
| 305 |
if p.dim() > 1:
|
|
@@ -425,9 +444,9 @@ class PAWNCLM(nn.Module):
|
|
| 425 |
rope_sin = self.rope_sin[:, :, :T_new, :]
|
| 426 |
|
| 427 |
new_kv_cache = []
|
| 428 |
-
for i
|
| 429 |
layer_cache = kv_cache[i] if kv_cache is not None else None
|
| 430 |
-
x, new_cache =
|
| 431 |
new_kv_cache.append(new_cache)
|
| 432 |
|
| 433 |
x = self.final_norm(x[:, -1:, :])
|
|
|
|
| 179 |
|
| 180 |
|
| 181 |
class TransformerBlock(nn.Module):
|
| 182 |
+
attn_norm: RMSNorm
|
| 183 |
+
attn: Attention
|
| 184 |
+
ffn_norm: RMSNorm
|
| 185 |
+
ffn: SwiGLUFFN
|
| 186 |
+
|
| 187 |
def __init__(self, cfg: CLMConfig):
|
| 188 |
super().__init__()
|
| 189 |
self.attn_norm = RMSNorm(cfg.d_model)
|
|
|
|
| 225 |
PAD and outcome tokens use standalone embeddings.
|
| 226 |
"""
|
| 227 |
|
| 228 |
+
decomp_table: torch.Tensor
|
| 229 |
+
|
| 230 |
def __init__(self, cfg: CLMConfig):
|
| 231 |
super().__init__()
|
| 232 |
self.d_model = cfg.d_model
|
|
|
|
| 280 |
full vocabulary. No factored output head, no grid, no BCE.
|
| 281 |
"""
|
| 282 |
|
| 283 |
+
rope_cos: torch.Tensor
|
| 284 |
+
rope_sin: torch.Tensor
|
| 285 |
+
causal_mask: torch.Tensor
|
| 286 |
+
embed: CLMEmbedding
|
| 287 |
+
layers: nn.ModuleList
|
| 288 |
+
final_norm: RMSNorm
|
| 289 |
+
lm_head: nn.Linear
|
| 290 |
+
|
| 291 |
def __init__(self, cfg: CLMConfig):
|
| 292 |
super().__init__()
|
| 293 |
self.cfg = cfg
|
|
|
|
| 315 |
|
| 316 |
self._init_weights()
|
| 317 |
|
| 318 |
+
def get_block(self, i: int) -> TransformerBlock:
|
| 319 |
+
"""Typed accessor for transformer layers (avoids ModuleList type erasure)."""
|
| 320 |
+
return self.layers[i] # type: ignore[return-value]
|
| 321 |
+
|
| 322 |
def _init_weights(self):
|
| 323 |
for p in self.parameters():
|
| 324 |
if p.dim() > 1:
|
|
|
|
| 444 |
rope_sin = self.rope_sin[:, :, :T_new, :]
|
| 445 |
|
| 446 |
new_kv_cache = []
|
| 447 |
+
for i in range(len(self.layers)):
|
| 448 |
layer_cache = kv_cache[i] if kv_cache is not None else None
|
| 449 |
+
x, new_cache = self.get_block(i).forward_kv(x, rope_cos, rope_sin, layer_cache)
|
| 450 |
new_kv_cache.append(new_cache)
|
| 451 |
|
| 452 |
x = self.final_norm(x[:, -1:, :])
|
|
@@ -38,12 +38,12 @@ class CosineWithWarmup:
|
|
| 38 |
self._step = 0
|
| 39 |
self._apply_lr(0)
|
| 40 |
|
| 41 |
-
def _apply_lr(self, step: int):
|
| 42 |
lr_scale = self._compute_lr_scale(step)
|
| 43 |
for pg, base_lr in zip(self.optimizer.param_groups, self.base_lrs, strict=True):
|
| 44 |
pg["lr"] = base_lr * lr_scale
|
| 45 |
|
| 46 |
-
def step(self):
|
| 47 |
self._step += 1
|
| 48 |
self._apply_lr(self._step)
|
| 49 |
|
|
@@ -59,10 +59,10 @@ class CosineWithWarmup:
|
|
| 59 |
def get_lr(self) -> float:
|
| 60 |
return self.optimizer.param_groups[0]["lr"]
|
| 61 |
|
| 62 |
-
def state_dict(self):
|
| 63 |
return {"step": self._step}
|
| 64 |
|
| 65 |
-
def load_state_dict(self, state):
|
| 66 |
self._step = state["step"]
|
| 67 |
self._apply_lr(self._step)
|
| 68 |
|
|
@@ -208,8 +208,9 @@ class CLMTrainer:
|
|
| 208 |
self._jsonl_path = os.path.join(self.run_dir, "metrics.jsonl")
|
| 209 |
self._jsonl_file = None
|
| 210 |
|
| 211 |
-
self.
|
| 212 |
-
|
|
|
|
| 213 |
print(f"Model parameters: {param_count:,}")
|
| 214 |
print(f"Run directory: {self.run_dir}")
|
| 215 |
|
|
@@ -345,7 +346,7 @@ class CLMTrainer:
|
|
| 345 |
|
| 346 |
def optimizer_step(self) -> float:
|
| 347 |
self.scaler.unscale_(self.optimizer)
|
| 348 |
-
grad_norm = _get_grad_norm(self.
|
| 349 |
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.max_grad_norm)
|
| 350 |
self.scaler.step(self.optimizer)
|
| 351 |
self.scaler.update()
|
|
@@ -354,7 +355,7 @@ class CLMTrainer:
|
|
| 354 |
return grad_norm
|
| 355 |
|
| 356 |
def _eager_model(self) -> PAWNCLM:
|
| 357 |
-
return self.
|
| 358 |
|
| 359 |
@torch.no_grad()
|
| 360 |
def evaluate(self) -> dict[str, float]:
|
|
@@ -535,9 +536,7 @@ class CLMTrainer:
|
|
| 535 |
if dirname:
|
| 536 |
os.makedirs(dirname, exist_ok=True)
|
| 537 |
|
| 538 |
-
model = self.
|
| 539 |
-
if hasattr(model, "_orig_mod"):
|
| 540 |
-
model = model._orig_mod
|
| 541 |
|
| 542 |
torch.save(
|
| 543 |
{
|
|
@@ -561,9 +560,7 @@ class CLMTrainer:
|
|
| 561 |
ckpt = torch.load(path, map_location=self.device, weights_only=False)
|
| 562 |
self.global_step = ckpt["global_step"]
|
| 563 |
|
| 564 |
-
model = self.
|
| 565 |
-
if hasattr(model, "_orig_mod"):
|
| 566 |
-
model = model._orig_mod
|
| 567 |
|
| 568 |
model.load_state_dict(ckpt["model_state_dict"])
|
| 569 |
self.optimizer.load_state_dict(ckpt["optimizer_state_dict"])
|
|
|
|
| 38 |
self._step = 0
|
| 39 |
self._apply_lr(0)
|
| 40 |
|
| 41 |
+
def _apply_lr(self, step: int) -> None:
|
| 42 |
lr_scale = self._compute_lr_scale(step)
|
| 43 |
for pg, base_lr in zip(self.optimizer.param_groups, self.base_lrs, strict=True):
|
| 44 |
pg["lr"] = base_lr * lr_scale
|
| 45 |
|
| 46 |
+
def step(self) -> None:
|
| 47 |
self._step += 1
|
| 48 |
self._apply_lr(self._step)
|
| 49 |
|
|
|
|
| 59 |
def get_lr(self) -> float:
|
| 60 |
return self.optimizer.param_groups[0]["lr"]
|
| 61 |
|
| 62 |
+
def state_dict(self) -> dict[str, int]:
|
| 63 |
return {"step": self._step}
|
| 64 |
|
| 65 |
+
def load_state_dict(self, state: dict[str, int]) -> None:
|
| 66 |
self._step = state["step"]
|
| 67 |
self._apply_lr(self._step)
|
| 68 |
|
|
|
|
| 208 |
self._jsonl_path = os.path.join(self.run_dir, "metrics.jsonl")
|
| 209 |
self._jsonl_file = None
|
| 210 |
|
| 211 |
+
self._model = PAWNCLM(model_cfg).to(self.device)
|
| 212 |
+
self.model = self._model
|
| 213 |
+
param_count = sum(p.numel() for p in self._model.parameters())
|
| 214 |
print(f"Model parameters: {param_count:,}")
|
| 215 |
print(f"Run directory: {self.run_dir}")
|
| 216 |
|
|
|
|
| 346 |
|
| 347 |
def optimizer_step(self) -> float:
|
| 348 |
self.scaler.unscale_(self.optimizer)
|
| 349 |
+
grad_norm = _get_grad_norm(self._model)
|
| 350 |
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.max_grad_norm)
|
| 351 |
self.scaler.step(self.optimizer)
|
| 352 |
self.scaler.update()
|
|
|
|
| 355 |
return grad_norm
|
| 356 |
|
| 357 |
def _eager_model(self) -> PAWNCLM:
|
| 358 |
+
return self._model
|
| 359 |
|
| 360 |
@torch.no_grad()
|
| 361 |
def evaluate(self) -> dict[str, float]:
|
|
|
|
| 536 |
if dirname:
|
| 537 |
os.makedirs(dirname, exist_ok=True)
|
| 538 |
|
| 539 |
+
model: PAWNCLM = self._eager_model()
|
|
|
|
|
|
|
| 540 |
|
| 541 |
torch.save(
|
| 542 |
{
|
|
|
|
| 560 |
ckpt = torch.load(path, map_location=self.device, weights_only=False)
|
| 561 |
self.global_step = ckpt["global_step"]
|
| 562 |
|
| 563 |
+
model: PAWNCLM = self._eager_model()
|
|
|
|
|
|
|
| 564 |
|
| 565 |
model.load_state_dict(ckpt["model_state_dict"])
|
| 566 |
self.optimizer.load_state_dict(ckpt["optimizer_state_dict"])
|
|
@@ -57,5 +57,13 @@ name = "pytorch-cu128"
|
|
| 57 |
url = "https://download.pytorch.org/whl/cu128"
|
| 58 |
explicit = true
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
[tool.pytest.ini_options]
|
| 61 |
testpaths = ["tests"]
|
|
|
|
| 57 |
url = "https://download.pytorch.org/whl/cu128"
|
| 58 |
explicit = true
|
| 59 |
|
| 60 |
+
[tool.pyright]
|
| 61 |
+
pythonVersion = "3.10"
|
| 62 |
+
typeCheckingMode = "basic"
|
| 63 |
+
reportMissingTypeStubs = false
|
| 64 |
+
reportPrivateImportUsage = false
|
| 65 |
+
reportMissingImports = "warning"
|
| 66 |
+
include = ["pawn"]
|
| 67 |
+
|
| 68 |
[tool.pytest.ini_options]
|
| 69 |
testpaths = ["tests"]
|