thomas-schweich Claude Opus 4.6 (1M context) committed on
Commit
36caadb
·
1 Parent(s): d7ecc62

Add type annotations across codebase and configure pyright

Browse files

Configure pyright in basic mode for static type checking. Add type
annotations to all function signatures (~76 additions across 18 files),
bringing coverage from ~65% to 100%.

Key structural changes to achieve clean type checking without
suppressions:

- Add PAWNCLM.get_block() typed accessor to avoid ModuleList type
erasure (single type-narrowing point for all adapter consumers)
- Replace None with nn.Identity() in adapter ModuleLists (bottleneck,
hybrid) so entries are always valid Modules
- Separate trainer's unwrapped _model from potentially-compiled model
to preserve concrete PAWNCLM type
- Define GenerativeModel Protocol for eval suite's duck-typed model
parameter instead of bare nn.Module
- Add class-level type declarations for registered buffers and
submodules (rope_cos, causal_mask, TransformerBlock attrs)

Result: 0 pyright errors, 1 type: ignore (in get_block), all tests pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

pawn/adapters/bottleneck.py CHANGED
@@ -80,18 +80,18 @@ class BottleneckCLM(nn.Module):
80
  for p in backbone.parameters():
81
  p.requires_grad = False
82
 
83
- # Create adapter modules (None for non-adapted layers)
84
  self.attn_adapters = nn.ModuleList()
85
  self.ffn_adapters = nn.ModuleList()
86
  for i in range(n_layers):
87
  if i in self._attn_set:
88
  self.attn_adapters.append(BottleneckAdapter(cfg.d_model, bottleneck_dim))
89
  else:
90
- self.attn_adapters.append(None)
91
  if i in self._ffn_set:
92
  self.ffn_adapters.append(BottleneckAdapter(cfg.d_model, bottleneck_dim))
93
  else:
94
- self.ffn_adapters.append(None)
95
 
96
  @property
97
  def cfg(self) -> CLMConfig:
@@ -106,16 +106,15 @@ class BottleneckCLM(nn.Module):
106
  rope_cos = bb.rope_cos[:, :, :T, :]
107
  rope_sin = bb.rope_sin[:, :, :T, :]
108
 
109
- for i, block in enumerate(bb.layers):
 
110
  # Attention sublayer + adapter
111
  x = x + block.attn(block.attn_norm(x), rope_cos, rope_sin, None)
112
- if self.attn_adapters[i] is not None:
113
- x = self.attn_adapters[i](x)
114
 
115
  # FFN sublayer + adapter
116
  x = x + block.ffn(block.ffn_norm(x))
117
- if self.ffn_adapters[i] is not None:
118
- x = self.ffn_adapters[i](x)
119
 
120
  return bb.final_norm(x)
121
 
@@ -143,14 +142,13 @@ class BottleneckCLM(nn.Module):
143
  rope_cos = bb.rope_cos[:, :, :T, :]
144
  rope_sin = bb.rope_sin[:, :, :T, :]
145
 
146
- for i, block in enumerate(bb.layers):
 
147
  x = x + block.attn(block.attn_norm(x), rope_cos, rope_sin, mask)
148
- if self.attn_adapters[i] is not None:
149
- x = self.attn_adapters[i](x)
150
 
151
  x = x + block.ffn(block.ffn_norm(x))
152
- if self.ffn_adapters[i] is not None:
153
- x = self.ffn_adapters[i](x)
154
 
155
  x = bb.final_norm(x)
156
  return self.project_head(x)
@@ -174,20 +172,19 @@ class BottleneckCLM(nn.Module):
174
  rope_sin = bb.rope_sin[:, :, :T_new, :]
175
 
176
  new_kv_cache = []
177
- for i, block in enumerate(bb.layers):
 
178
  # KV-cache forward for attention
179
  layer_cache = kv_cache[i] if kv_cache is not None else None
180
  attn_out, new_cache = block.attn.forward_kv(
181
  block.attn_norm(x), rope_cos, rope_sin, layer_cache,
182
  )
183
  x = x + attn_out
184
- if self.attn_adapters[i] is not None:
185
- x = self.attn_adapters[i](x)
186
  new_kv_cache.append(new_cache)
187
 
188
  x = x + block.ffn(block.ffn_norm(x))
189
- if self.ffn_adapters[i] is not None:
190
- x = self.ffn_adapters[i](x)
191
 
192
  x = bb.final_norm(x[:, -1:, :])
193
  logits = bb.lm_head(x)
@@ -218,12 +215,12 @@ class BottleneckCLM(nn.Module):
218
  """Per-layer adapter weight norms for monitoring."""
219
  report = {}
220
  for i in range(len(self.backbone.layers)):
221
- if self.attn_adapters[i] is not None:
222
- a = self.attn_adapters[i]
223
  report[f"adapter/layer{i}.attn.down"] = a.down.weight.data.norm().item()
224
  report[f"adapter/layer{i}.attn.up"] = a.up.weight.data.norm().item()
225
- if self.ffn_adapters[i] is not None:
226
- a = self.ffn_adapters[i]
227
  report[f"adapter/layer{i}.ffn.down"] = a.down.weight.data.norm().item()
228
  report[f"adapter/layer{i}.ffn.up"] = a.up.weight.data.norm().item()
229
  return report
 
80
  for p in backbone.parameters():
81
  p.requires_grad = False
82
 
83
+ # Create adapter modules (Identity for non-adapted layers)
84
  self.attn_adapters = nn.ModuleList()
85
  self.ffn_adapters = nn.ModuleList()
86
  for i in range(n_layers):
87
  if i in self._attn_set:
88
  self.attn_adapters.append(BottleneckAdapter(cfg.d_model, bottleneck_dim))
89
  else:
90
+ self.attn_adapters.append(nn.Identity())
91
  if i in self._ffn_set:
92
  self.ffn_adapters.append(BottleneckAdapter(cfg.d_model, bottleneck_dim))
93
  else:
94
+ self.ffn_adapters.append(nn.Identity())
95
 
96
  @property
97
  def cfg(self) -> CLMConfig:
 
106
  rope_cos = bb.rope_cos[:, :, :T, :]
107
  rope_sin = bb.rope_sin[:, :, :T, :]
108
 
109
+ for i in range(len(bb.layers)):
110
+ block = bb.get_block(i)
111
  # Attention sublayer + adapter
112
  x = x + block.attn(block.attn_norm(x), rope_cos, rope_sin, None)
113
+ x = self.attn_adapters[i](x)
 
114
 
115
  # FFN sublayer + adapter
116
  x = x + block.ffn(block.ffn_norm(x))
117
+ x = self.ffn_adapters[i](x)
 
118
 
119
  return bb.final_norm(x)
120
 
 
142
  rope_cos = bb.rope_cos[:, :, :T, :]
143
  rope_sin = bb.rope_sin[:, :, :T, :]
144
 
145
+ for i in range(len(bb.layers)):
146
+ block = bb.get_block(i)
147
  x = x + block.attn(block.attn_norm(x), rope_cos, rope_sin, mask)
148
+ x = self.attn_adapters[i](x)
 
149
 
150
  x = x + block.ffn(block.ffn_norm(x))
151
+ x = self.ffn_adapters[i](x)
 
152
 
153
  x = bb.final_norm(x)
154
  return self.project_head(x)
 
172
  rope_sin = bb.rope_sin[:, :, :T_new, :]
173
 
174
  new_kv_cache = []
175
+ for i in range(len(bb.layers)):
176
+ block = bb.get_block(i)
177
  # KV-cache forward for attention
178
  layer_cache = kv_cache[i] if kv_cache is not None else None
179
  attn_out, new_cache = block.attn.forward_kv(
180
  block.attn_norm(x), rope_cos, rope_sin, layer_cache,
181
  )
182
  x = x + attn_out
183
+ x = self.attn_adapters[i](x)
 
184
  new_kv_cache.append(new_cache)
185
 
186
  x = x + block.ffn(block.ffn_norm(x))
187
+ x = self.ffn_adapters[i](x)
 
188
 
189
  x = bb.final_norm(x[:, -1:, :])
190
  logits = bb.lm_head(x)
 
215
  """Per-layer adapter weight norms for monitoring."""
216
  report = {}
217
  for i in range(len(self.backbone.layers)):
218
+ a = self.attn_adapters[i]
219
+ if isinstance(a, BottleneckAdapter):
220
  report[f"adapter/layer{i}.attn.down"] = a.down.weight.data.norm().item()
221
  report[f"adapter/layer{i}.attn.up"] = a.up.weight.data.norm().item()
222
+ a = self.ffn_adapters[i]
223
+ if isinstance(a, BottleneckAdapter):
224
  report[f"adapter/layer{i}.ffn.down"] = a.down.weight.data.norm().item()
225
  report[f"adapter/layer{i}.ffn.up"] = a.up.weight.data.norm().item()
226
  return report
pawn/adapters/film.py CHANGED
@@ -140,10 +140,11 @@ class FiLMCLM(nn.Module):
140
  rope_sin = bb.rope_sin[:, :, :T_new, :]
141
 
142
  new_kv_cache = []
143
- for i, (layer, film) in enumerate(zip(bb.layers, self.hidden_films)):
 
144
  layer_cache = kv_cache[i] if kv_cache is not None else None
145
- x, new_cache = layer.forward_kv(x, rope_cos, rope_sin, layer_cache)
146
- x = film(x)
147
  new_kv_cache.append(new_cache)
148
 
149
  x = bb.final_norm(x[:, -1:, :])
@@ -182,8 +183,9 @@ class FiLMCLM(nn.Module):
182
  """Per-layer FiLM deviation from identity, for monitoring."""
183
  report = {}
184
  for i, film in enumerate(self.hidden_films):
185
- report[f"hidden_{i}/gamma_dev"] = (film.gamma - 1.0).norm().item()
186
- report[f"hidden_{i}/beta_norm"] = film.beta.norm().item()
 
187
  if self.output_film is not None:
188
  report["output/gamma_dev"] = (self.output_film.gamma - 1.0).norm().item()
189
  report["output/beta_norm"] = self.output_film.beta.norm().item()
 
140
  rope_sin = bb.rope_sin[:, :, :T_new, :]
141
 
142
  new_kv_cache = []
143
+ for i in range(len(bb.layers)):
144
+ block = bb.get_block(i)
145
  layer_cache = kv_cache[i] if kv_cache is not None else None
146
+ x, new_cache = block.forward_kv(x, rope_cos, rope_sin, layer_cache)
147
+ x = self.hidden_films[i](x)
148
  new_kv_cache.append(new_cache)
149
 
150
  x = bb.final_norm(x[:, -1:, :])
 
183
  """Per-layer FiLM deviation from identity, for monitoring."""
184
  report = {}
185
  for i, film in enumerate(self.hidden_films):
186
+ if isinstance(film, FiLMLayer):
187
+ report[f"hidden_{i}/gamma_dev"] = (film.gamma - 1.0).norm().item()
188
+ report[f"hidden_{i}/beta_norm"] = film.beta.norm().item()
189
  if self.output_film is not None:
190
  report["output/gamma_dev"] = (self.output_film.gamma - 1.0).norm().item()
191
  report["output/beta_norm"] = self.output_film.beta.norm().item()
pawn/adapters/hybrid.py CHANGED
@@ -65,9 +65,11 @@ class HybridCLM(nn.Module):
65
  p.requires_grad = False
66
 
67
  # Inject LoRA
68
- for layer_idx, block in enumerate(backbone.layers):
69
  if layer_idx not in self.lora_layer_set:
70
  continue
 
 
71
  attn: Attention = block.attn
72
  for proj_name in self.attn_targets:
73
  original = getattr(attn, proj_name)
@@ -78,14 +80,14 @@ class HybridCLM(nn.Module):
78
  original = getattr(ffn, proj_name)
79
  setattr(ffn, proj_name, LoRALinear(original, lora_rank, self.lora_alpha))
80
 
81
- # Create FiLM layers (identity for non-adapted layers)
82
  if use_film:
83
  self.hidden_films = nn.ModuleList()
84
  for i in range(n_layers):
85
  if i in self.film_layer_set:
86
  self.hidden_films.append(FiLMLayer(cfg.d_model))
87
  else:
88
- self.hidden_films.append(None)
89
  else:
90
  self.hidden_films = None
91
 
@@ -109,7 +111,7 @@ class HybridCLM(nn.Module):
109
 
110
  for i, layer in enumerate(bb.layers):
111
  x = layer(x, rope_cos, rope_sin, None) # LoRA happens inside
112
- if self.hidden_films is not None and self.hidden_films[i] is not None:
113
  x = self.hidden_films[i](x)
114
 
115
  return bb.final_norm(x)
@@ -143,7 +145,7 @@ class HybridCLM(nn.Module):
143
 
144
  for i, layer in enumerate(bb.layers):
145
  x = layer(x, rope_cos, rope_sin, mask)
146
- if self.hidden_films is not None and self.hidden_films[i] is not None:
147
  x = self.hidden_films[i](x)
148
 
149
  x = bb.final_norm(x)
@@ -168,10 +170,11 @@ class HybridCLM(nn.Module):
168
  rope_sin = bb.rope_sin[:, :, :T_new, :]
169
 
170
  new_kv_cache = []
171
- for i, layer in enumerate(bb.layers):
 
172
  layer_cache = kv_cache[i] if kv_cache is not None else None
173
- x, new_cache = layer.forward_kv(x, rope_cos, rope_sin, layer_cache)
174
- if self.hidden_films is not None and self.hidden_films[i] is not None:
175
  x = self.hidden_films[i](x)
176
  new_kv_cache.append(new_cache)
177
 
@@ -187,7 +190,8 @@ class HybridCLM(nn.Module):
187
  def lora_parameters(self) -> list[nn.Parameter]:
188
  """Return only LoRA A/B parameters."""
189
  params = []
190
- for block in self.backbone.layers:
 
191
  for proj_name in self.attn_targets:
192
  module = getattr(block.attn, proj_name)
193
  if isinstance(module, LoRALinear):
@@ -206,7 +210,7 @@ class HybridCLM(nn.Module):
206
  params = []
207
  if self.hidden_films is not None:
208
  for film in self.hidden_films:
209
- if film is not None:
210
  params.extend(film.parameters())
211
  if self.output_film is not None:
212
  params.extend(self.output_film.parameters())
@@ -233,7 +237,8 @@ class HybridCLM(nn.Module):
233
  report = {}
234
 
235
  # LoRA norms
236
- for layer_idx, block in enumerate(self.backbone.layers):
 
237
  for proj_name in self.attn_targets:
238
  module = getattr(block.attn, proj_name)
239
  if isinstance(module, LoRALinear):
@@ -247,7 +252,7 @@ class HybridCLM(nn.Module):
247
  # FiLM norms
248
  if self.hidden_films is not None:
249
  for i, film in enumerate(self.hidden_films):
250
- if film is not None:
251
  report[f"film/hidden_{i}/gamma_dev"] = (film.gamma - 1.0).norm().item()
252
  report[f"film/hidden_{i}/beta_norm"] = film.beta.norm().item()
253
  if self.output_film is not None:
 
65
  p.requires_grad = False
66
 
67
  # Inject LoRA
68
+ for layer_idx in range(n_layers):
69
  if layer_idx not in self.lora_layer_set:
70
  continue
71
+ block = backbone.get_block(layer_idx)
72
+
73
  attn: Attention = block.attn
74
  for proj_name in self.attn_targets:
75
  original = getattr(attn, proj_name)
 
80
  original = getattr(ffn, proj_name)
81
  setattr(ffn, proj_name, LoRALinear(original, lora_rank, self.lora_alpha))
82
 
83
+ # Create FiLM layers (Identity for non-adapted layers)
84
  if use_film:
85
  self.hidden_films = nn.ModuleList()
86
  for i in range(n_layers):
87
  if i in self.film_layer_set:
88
  self.hidden_films.append(FiLMLayer(cfg.d_model))
89
  else:
90
+ self.hidden_films.append(nn.Identity())
91
  else:
92
  self.hidden_films = None
93
 
 
111
 
112
  for i, layer in enumerate(bb.layers):
113
  x = layer(x, rope_cos, rope_sin, None) # LoRA happens inside
114
+ if self.hidden_films is not None:
115
  x = self.hidden_films[i](x)
116
 
117
  return bb.final_norm(x)
 
145
 
146
  for i, layer in enumerate(bb.layers):
147
  x = layer(x, rope_cos, rope_sin, mask)
148
+ if self.hidden_films is not None:
149
  x = self.hidden_films[i](x)
150
 
151
  x = bb.final_norm(x)
 
170
  rope_sin = bb.rope_sin[:, :, :T_new, :]
171
 
172
  new_kv_cache = []
173
+ for i in range(len(bb.layers)):
174
+ block = bb.get_block(i)
175
  layer_cache = kv_cache[i] if kv_cache is not None else None
176
+ x, new_cache = block.forward_kv(x, rope_cos, rope_sin, layer_cache)
177
+ if self.hidden_films is not None:
178
  x = self.hidden_films[i](x)
179
  new_kv_cache.append(new_cache)
180
 
 
190
  def lora_parameters(self) -> list[nn.Parameter]:
191
  """Return only LoRA A/B parameters."""
192
  params = []
193
+ for layer_idx in range(len(self.backbone.layers)):
194
+ block = self.backbone.get_block(layer_idx)
195
  for proj_name in self.attn_targets:
196
  module = getattr(block.attn, proj_name)
197
  if isinstance(module, LoRALinear):
 
210
  params = []
211
  if self.hidden_films is not None:
212
  for film in self.hidden_films:
213
+ if isinstance(film, FiLMLayer):
214
  params.extend(film.parameters())
215
  if self.output_film is not None:
216
  params.extend(self.output_film.parameters())
 
237
  report = {}
238
 
239
  # LoRA norms
240
+ for layer_idx in range(len(self.backbone.layers)):
241
+ block = self.backbone.get_block(layer_idx)
242
  for proj_name in self.attn_targets:
243
  module = getattr(block.attn, proj_name)
244
  if isinstance(module, LoRALinear):
 
252
  # FiLM norms
253
  if self.hidden_films is not None:
254
  for i, film in enumerate(self.hidden_films):
255
+ if isinstance(film, FiLMLayer):
256
  report[f"film/hidden_{i}/gamma_dev"] = (film.gamma - 1.0).norm().item()
257
  report[f"film/hidden_{i}/beta_norm"] = film.beta.norm().item()
258
  if self.output_film is not None:
pawn/adapters/lora.py CHANGED
@@ -90,9 +90,10 @@ class LoRACLM(nn.Module):
90
  p.requires_grad = False
91
 
92
  # Inject LoRA into selected layers
93
- for layer_idx, block in enumerate(backbone.layers):
94
  if layer_idx not in self.adapted_layers:
95
  continue
 
96
 
97
  attn: Attention = block.attn
98
  for proj_name in self.attn_targets:
@@ -182,9 +183,9 @@ class LoRACLM(nn.Module):
182
  rope_sin = bb.rope_sin[:, :, :T_new, :]
183
 
184
  new_kv_cache = []
185
- for i, layer in enumerate(bb.layers):
186
  layer_cache = kv_cache[i] if kv_cache is not None else None
187
- x, new_cache = layer.forward_kv(x, rope_cos, rope_sin, layer_cache)
188
  new_kv_cache.append(new_cache)
189
 
190
  x = bb.final_norm(x[:, -1:, :])
@@ -215,8 +216,8 @@ class LoRACLM(nn.Module):
215
  def lora_weight_report(self) -> dict[str, float]:
216
  """Per-layer LoRA weight norms for monitoring."""
217
  report = {}
218
- for layer_idx, block in enumerate(self.backbone.layers):
219
- attn = block.attn
220
  for proj_name in self.attn_targets:
221
  module = getattr(attn, proj_name)
222
  if isinstance(module, LoRALinear):
@@ -224,7 +225,7 @@ class LoRACLM(nn.Module):
224
  report[f"layer{layer_idx}.{proj_name}.B"] = module.lora_B.data.norm().item()
225
 
226
  if self.adapt_ffn:
227
- ffn = block.ffn
228
  for proj_name in _FFN_TARGETS:
229
  module = getattr(ffn, proj_name)
230
  if isinstance(module, LoRALinear):
 
90
  p.requires_grad = False
91
 
92
  # Inject LoRA into selected layers
93
+ for layer_idx in range(len(backbone.layers)):
94
  if layer_idx not in self.adapted_layers:
95
  continue
96
+ block = backbone.get_block(layer_idx)
97
 
98
  attn: Attention = block.attn
99
  for proj_name in self.attn_targets:
 
183
  rope_sin = bb.rope_sin[:, :, :T_new, :]
184
 
185
  new_kv_cache = []
186
+ for i in range(len(bb.layers)):
187
  layer_cache = kv_cache[i] if kv_cache is not None else None
188
+ x, new_cache = bb.get_block(i).forward_kv(x, rope_cos, rope_sin, layer_cache)
189
  new_kv_cache.append(new_cache)
190
 
191
  x = bb.final_norm(x[:, -1:, :])
 
216
  def lora_weight_report(self) -> dict[str, float]:
217
  """Per-layer LoRA weight norms for monitoring."""
218
  report = {}
219
+ for layer_idx in range(len(self.backbone.layers)):
220
+ attn = self.backbone.get_block(layer_idx).attn
221
  for proj_name in self.attn_targets:
222
  module = getattr(attn, proj_name)
223
  if isinstance(module, LoRALinear):
 
225
  report[f"layer{layer_idx}.{proj_name}.B"] = module.lora_B.data.norm().item()
226
 
227
  if self.adapt_ffn:
228
+ ffn = self.backbone.get_block(layer_idx).ffn
229
  for proj_name in _FFN_TARGETS:
230
  module = getattr(ffn, proj_name)
231
  if isinstance(module, LoRALinear):
pawn/adapters/sparse.py CHANGED
@@ -24,6 +24,8 @@ class SparseLinear(nn.Module):
24
  output = F.linear(x, W_frozen + delta * mask, bias)
25
  """
26
 
 
 
27
  def __init__(self, frozen_linear: nn.Linear, mask: torch.Tensor):
28
  super().__init__()
29
  self.frozen = frozen_linear
@@ -82,9 +84,10 @@ class SparseCLM(nn.Module):
82
  gen = torch.Generator().manual_seed(seed)
83
 
84
  # Inject sparse adapters
85
- for layer_idx, block in enumerate(backbone.layers):
86
  if layer_idx not in self.adapted_layers:
87
  continue
 
88
 
89
  attn: Attention = block.attn
90
  for proj_name in self.attn_targets:
@@ -166,9 +169,9 @@ class SparseCLM(nn.Module):
166
  rope_sin = bb.rope_sin[:, :, :T_new, :]
167
 
168
  new_kv_cache = []
169
- for i, layer in enumerate(bb.layers):
170
  layer_cache = kv_cache[i] if kv_cache is not None else None
171
- x, new_cache = layer.forward_kv(x, rope_cos, rope_sin, layer_cache)
172
  new_kv_cache.append(new_cache)
173
 
174
  x = bb.final_norm(x[:, -1:, :])
@@ -184,7 +187,8 @@ class SparseCLM(nn.Module):
184
  def n_active_params(self) -> int:
185
  """Count of actually active (masked-in) parameters."""
186
  total = 0
187
- for block in self.backbone.layers:
 
188
  for proj_name in self.attn_targets:
189
  module = getattr(block.attn, proj_name)
190
  if isinstance(module, SparseLinear):
@@ -215,7 +219,8 @@ class SparseCLM(nn.Module):
215
  def sparse_weight_report(self) -> dict[str, float]:
216
  """Per-layer sparse delta norms for monitoring."""
217
  report = {}
218
- for layer_idx, block in enumerate(self.backbone.layers):
 
219
  for proj_name in self.attn_targets:
220
  module = getattr(block.attn, proj_name)
221
  if isinstance(module, SparseLinear):
 
24
  output = F.linear(x, W_frozen + delta * mask, bias)
25
  """
26
 
27
+ mask: torch.Tensor
28
+
29
  def __init__(self, frozen_linear: nn.Linear, mask: torch.Tensor):
30
  super().__init__()
31
  self.frozen = frozen_linear
 
84
  gen = torch.Generator().manual_seed(seed)
85
 
86
  # Inject sparse adapters
87
+ for layer_idx in range(len(backbone.layers)):
88
  if layer_idx not in self.adapted_layers:
89
  continue
90
+ block = backbone.get_block(layer_idx)
91
 
92
  attn: Attention = block.attn
93
  for proj_name in self.attn_targets:
 
169
  rope_sin = bb.rope_sin[:, :, :T_new, :]
170
 
171
  new_kv_cache = []
172
+ for i in range(len(bb.layers)):
173
  layer_cache = kv_cache[i] if kv_cache is not None else None
174
+ x, new_cache = bb.get_block(i).forward_kv(x, rope_cos, rope_sin, layer_cache)
175
  new_kv_cache.append(new_cache)
176
 
177
  x = bb.final_norm(x[:, -1:, :])
 
187
  def n_active_params(self) -> int:
188
  """Count of actually active (masked-in) parameters."""
189
  total = 0
190
+ for layer_idx in range(len(self.backbone.layers)):
191
+ block = self.backbone.get_block(layer_idx)
192
  for proj_name in self.attn_targets:
193
  module = getattr(block.attn, proj_name)
194
  if isinstance(module, SparseLinear):
 
219
  def sparse_weight_report(self) -> dict[str, float]:
220
  """Per-layer sparse delta norms for monitoring."""
221
  report = {}
222
+ for layer_idx in range(len(self.backbone.layers)):
223
+ block = self.backbone.get_block(layer_idx)
224
  for proj_name in self.attn_targets:
225
  module = getattr(block.attn, proj_name)
226
  if isinstance(module, SparseLinear):
pawn/data.py CHANGED
@@ -3,6 +3,7 @@
3
  import os
4
  import threading
5
  import time
 
6
 
7
  import numpy as np
8
  import torch
@@ -19,7 +20,7 @@ from pawn.config import (
19
  )
20
 
21
 
22
- _positions_cache: dict = {}
23
 
24
 
25
 
@@ -130,10 +131,10 @@ class CLMDataset(torch.utils.data.IterableDataset):
130
  self._start_step = 0
131
  self._main_pid = os.getpid()
132
 
133
- def set_start_step(self, step: int):
134
  self._start_step = step
135
 
136
- def __iter__(self):
137
  worker_info = torch.utils.data.get_worker_info()
138
  worker_id = worker_info.id if worker_info else 0
139
  num_workers = worker_info.num_workers if worker_info else 1
 
3
  import os
4
  import threading
5
  import time
6
+ from collections.abc import Iterator
7
 
8
  import numpy as np
9
  import torch
 
20
  )
21
 
22
 
23
+ _positions_cache: dict[tuple[str, int], torch.Tensor] = {}
24
 
25
 
26
 
 
131
  self._start_step = 0
132
  self._main_pid = os.getpid()
133
 
134
+ def set_start_step(self, step: int) -> None:
135
  self._start_step = step
136
 
137
+ def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
138
  worker_info = torch.utils.data.get_worker_info()
139
  worker_id = worker_info.id if worker_info else 0
140
  num_workers = worker_info.num_workers if worker_info else 1
pawn/eval_suite/corpus.py CHANGED
@@ -10,6 +10,7 @@ Storage layout:
10
 
11
  import json
12
  import time
 
13
  from pathlib import Path
14
 
15
  import numpy as np
@@ -32,7 +33,7 @@ def _popcount_u64(arr: np.ndarray) -> np.ndarray:
32
  return result
33
 
34
 
35
- def _count_legal_moves(move_ids, game_lengths):
36
  """Legal move count per ply via bit-packed grids + promo mask."""
37
  grid, promo_mask = engine.compute_legal_move_masks(move_ids, game_lengths)
38
  grid_counts = np.zeros(grid.shape[:2], dtype=np.uint32)
@@ -384,7 +385,7 @@ _PHASES = [("ply_1_20", 0, 20), ("ply_21_80", 20, 80),
384
  ("ply_81_150", 80, 150), ("ply_150_plus", 150, 9999)]
385
 
386
 
387
- def _iter_position_parts(corpus: dict):
388
  """Yield each position parquet part as an eager DataFrame."""
389
  from pathlib import Path
390
  # Find corpus dir from the LazyFrame's file path
@@ -395,7 +396,7 @@ def _iter_position_parts(corpus: dict):
395
  yield pl.read_parquet(f)
396
 
397
 
398
- def _new_accumulator() -> dict:
399
  return {
400
  "n": 0, "sum_k": 0.0, "sum_k_sq": 0.0, "k_min": 999, "k_max": 0,
401
  "sum_inv_k": 0.0, "sum_inv_k_sq": 0.0,
@@ -408,7 +409,7 @@ def _new_accumulator() -> dict:
408
  }
409
 
410
 
411
- def _accumulate(acc: dict, df: pl.DataFrame):
412
  """Accumulate stats from one chunk (already filtered to k > 0)."""
413
  k = df["k"].to_numpy().astype(np.float64)
414
  ply = df["ply"].to_numpy()
@@ -455,7 +456,7 @@ def _accumulate(acc: dict, df: pl.DataFrame):
455
  del k, ply, chk, inv_k, ln_k, top5
456
 
457
 
458
- def _finalize_k_stats(acc):
459
  N = acc["n"]
460
  mean = acc["sum_k"] / N
461
  var = acc["sum_k_sq"] / N - mean ** 2
@@ -466,13 +467,13 @@ def _finalize_k_stats(acc):
466
  "min": acc["k_min"], "max": acc["k_max"]}
467
 
468
 
469
- def _finalize_k_hist(acc):
470
  h = acc["k_hist"]
471
  nz = h > 0
472
  return {"values": np.arange(300)[nz].tolist(), "counts": h[nz].tolist(), "total": acc["n"]}
473
 
474
 
475
- def _finalize_phases(acc):
476
  result = {}
477
  for name, _, _ in _PHASES:
478
  c = acc[f"{name}_n"]
@@ -486,7 +487,7 @@ def _finalize_phases(acc):
486
  return result
487
 
488
 
489
- def _finalize_checks(acc):
490
  N = acc["n"]
491
  result = {}
492
  for label in ("chk", "nochk"):
 
10
 
11
  import json
12
  import time
13
+ from collections.abc import Iterator
14
  from pathlib import Path
15
 
16
  import numpy as np
 
33
  return result
34
 
35
 
36
+ def _count_legal_moves(move_ids: np.ndarray, game_lengths: np.ndarray) -> np.ndarray:
37
  """Legal move count per ply via bit-packed grids + promo mask."""
38
  grid, promo_mask = engine.compute_legal_move_masks(move_ids, game_lengths)
39
  grid_counts = np.zeros(grid.shape[:2], dtype=np.uint32)
 
385
  ("ply_81_150", 80, 150), ("ply_150_plus", 150, 9999)]
386
 
387
 
388
+ def _iter_position_parts(corpus: dict) -> Iterator[pl.DataFrame]:
389
  """Yield each position parquet part as an eager DataFrame."""
390
  from pathlib import Path
391
  # Find corpus dir from the LazyFrame's file path
 
396
  yield pl.read_parquet(f)
397
 
398
 
399
+ def _new_accumulator() -> dict[str, int | float | np.ndarray]:
400
  return {
401
  "n": 0, "sum_k": 0.0, "sum_k_sq": 0.0, "k_min": 999, "k_max": 0,
402
  "sum_inv_k": 0.0, "sum_inv_k_sq": 0.0,
 
409
  }
410
 
411
 
412
+ def _accumulate(acc: dict, df: pl.DataFrame) -> None:
413
  """Accumulate stats from one chunk (already filtered to k > 0)."""
414
  k = df["k"].to_numpy().astype(np.float64)
415
  ply = df["ply"].to_numpy()
 
456
  del k, ply, chk, inv_k, ln_k, top5
457
 
458
 
459
+ def _finalize_k_stats(acc: dict) -> dict[str, float | int]:
460
  N = acc["n"]
461
  mean = acc["sum_k"] / N
462
  var = acc["sum_k_sq"] / N - mean ** 2
 
467
  "min": acc["k_min"], "max": acc["k_max"]}
468
 
469
 
470
+ def _finalize_k_hist(acc: dict) -> dict:
471
  h = acc["k_hist"]
472
  nz = h > 0
473
  return {"values": np.arange(300)[nz].tolist(), "counts": h[nz].tolist(), "total": acc["n"]}
474
 
475
 
476
+ def _finalize_phases(acc: dict) -> dict:
477
  result = {}
478
  for name, _, _ in _PHASES:
479
  c = acc[f"{name}_n"]
 
487
  return result
488
 
489
 
490
+ def _finalize_checks(acc: dict) -> dict:
491
  N = acc["n"]
492
  result = {}
493
  for label in ("chk", "nochk"):
pawn/eval_suite/diagnostics.py CHANGED
@@ -2,6 +2,7 @@
2
 
3
  import numpy as np
4
  import torch
 
5
  import torch.nn.functional as F
6
 
7
  import chess_engine as engine
@@ -39,7 +40,7 @@ def extract_diagnostic_positions(
39
  corpus: dict,
40
  min_per_category: int = 2000,
41
  max_per_category: int = 5000,
42
- ) -> dict:
43
  """Extract diagnostic positions from corpus.
44
 
45
  Returns dict[category_name] -> list of dicts with:
@@ -140,7 +141,7 @@ def _term_code_to_outcome_name(tc: int, gl: int) -> str:
140
 
141
  @torch.no_grad()
142
  def evaluate_diagnostic_positions(
143
- model,
144
  positions: dict,
145
  corpus: dict,
146
  device: str,
 
2
 
3
  import numpy as np
4
  import torch
5
+ import torch.nn as nn
6
  import torch.nn.functional as F
7
 
8
  import chess_engine as engine
 
40
  corpus: dict,
41
  min_per_category: int = 2000,
42
  max_per_category: int = 5000,
43
+ ) -> dict[str, list[dict]]:
44
  """Extract diagnostic positions from corpus.
45
 
46
  Returns dict[category_name] -> list of dicts with:
 
141
 
142
  @torch.no_grad()
143
  def evaluate_diagnostic_positions(
144
+ model: nn.Module,
145
  positions: dict,
146
  corpus: dict,
147
  device: str,
pawn/eval_suite/generation.py CHANGED
@@ -1,9 +1,13 @@
1
  """Autoregressive generation for outcome token signal tests (§6)."""
2
 
 
 
3
  import gc
 
4
 
5
  import numpy as np
6
  import torch
 
7
  import torch.nn.functional as F
8
 
9
  import chess_engine as engine
@@ -12,13 +16,32 @@ from pawn.config import PAD_TOKEN, WHITE_CHECKMATES, PLY_LIMIT, CLMConfig
12
  from pawn.data import _map_termination_to_outcome
13
 
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # ---------------------------------------------------------------------------
16
  # Core autoregressive generation
17
  # ---------------------------------------------------------------------------
18
 
19
 
20
  def autoregressive_generate(
21
- model,
22
  outcome_token: int,
23
  n_games: int,
24
  device: str,
@@ -28,7 +51,7 @@ def autoregressive_generate(
28
  max_seq_len: int = 256,
29
  temperature: float = 1.0,
30
  batch_size: int = 64,
31
- ) -> dict:
32
  """Generate games autoregressively from a trained PAWN.
33
 
34
  Args:
@@ -74,9 +97,16 @@ def autoregressive_generate(
74
 
75
  @torch.no_grad()
76
  def _generate_batch(
77
- model, outcome_token, n_games, device, mask_illegal,
78
- prefix_moves, prefix_lengths, max_seq_len, temperature,
79
- ) -> dict:
 
 
 
 
 
 
 
80
  """Generate a batch of games using batch Rust engine for state management."""
81
  cfg_vocab_size = CLMConfig.vocab_size # 4278
82
  max_move_positions = max_seq_len - 1 # position 0 is outcome token
@@ -264,7 +294,7 @@ OUTCOME_TOKENS = {
264
 
265
 
266
  def outcome_signal_test(
267
- model,
268
  device: str,
269
  n_per_outcome: int = 1000,
270
  mask_conditions: tuple[bool, ...] = (False, True),
@@ -354,7 +384,7 @@ def _analyze_generated_games(gen: dict, conditioned_outcome: str) -> dict:
354
 
355
 
356
  def prefix_continuation_test(
357
- model,
358
  corpus: dict,
359
  device: str,
360
  n_per_bucket: int = 200,
@@ -475,7 +505,7 @@ def prefix_continuation_test(
475
  return results
476
 
477
 
478
- def _outcome_mask(term_codes, game_lengths, outcome_name):
479
  """Create a boolean mask for games matching the given outcome."""
480
  if outcome_name == "WHITE_CHECKMATES":
481
  return (term_codes == 0) & (game_lengths % 2 == 1)
@@ -503,7 +533,7 @@ POISONING_PAIRS = [
503
 
504
 
505
  def poisoned_prefix_test(
506
- model,
507
  corpus: dict,
508
  device: str,
509
  n_per_pair: int = 500,
@@ -566,7 +596,7 @@ def poisoned_prefix_test(
566
 
567
 
568
  def impossible_task_test(
569
- model,
570
  corpus: dict,
571
  device: str,
572
  n_per_scenario: int = 200,
@@ -652,7 +682,7 @@ def impossible_task_test(
652
 
653
 
654
  def improbable_task_test(
655
- model,
656
  corpus: dict,
657
  device: str,
658
  n_per_scenario: int = 200,
 
1
  """Autoregressive generation for outcome token signal tests (§6)."""
2
 
3
+ from __future__ import annotations
4
+
5
  import gc
6
+ from typing import Protocol
7
 
8
  import numpy as np
9
  import torch
10
+ import torch.nn as nn
11
  import torch.nn.functional as F
12
 
13
  import chess_engine as engine
 
16
  from pawn.data import _map_termination_to_outcome
17
 
18
 
19
+ class GenerativeModel(Protocol):
20
+ """Structural type for models usable in autoregressive generation."""
21
+
22
+ def eval(self) -> nn.Module: ...
23
+
24
+ def __call__(
25
+ self,
26
+ input_ids: torch.Tensor,
27
+ attention_mask: torch.Tensor,
28
+ hidden_only: bool = ...,
29
+ ) -> tuple[torch.Tensor, list[torch.Tensor]]: ...
30
+
31
+ def forward_generate(
32
+ self,
33
+ input_ids: torch.Tensor,
34
+ kv_cache: list[tuple[torch.Tensor, torch.Tensor]] | None = ...,
35
+ ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]: ...
36
+
37
+
38
  # ---------------------------------------------------------------------------
39
  # Core autoregressive generation
40
  # ---------------------------------------------------------------------------
41
 
42
 
43
  def autoregressive_generate(
44
+ model: GenerativeModel,
45
  outcome_token: int,
46
  n_games: int,
47
  device: str,
 
51
  max_seq_len: int = 256,
52
  temperature: float = 1.0,
53
  batch_size: int = 64,
54
+ ) -> dict[str, np.ndarray]:
55
  """Generate games autoregressively from a trained PAWN.
56
 
57
  Args:
 
97
 
98
  @torch.no_grad()
99
  def _generate_batch(
100
+ model: GenerativeModel,
101
+ outcome_token: int,
102
+ n_games: int,
103
+ device: str,
104
+ mask_illegal: bool,
105
+ prefix_moves: np.ndarray | None,
106
+ prefix_lengths: np.ndarray | None,
107
+ max_seq_len: int,
108
+ temperature: float,
109
+ ) -> dict[str, np.ndarray]:
110
  """Generate a batch of games using batch Rust engine for state management."""
111
  cfg_vocab_size = CLMConfig.vocab_size # 4278
112
  max_move_positions = max_seq_len - 1 # position 0 is outcome token
 
294
 
295
 
296
  def outcome_signal_test(
297
+ model: GenerativeModel,
298
  device: str,
299
  n_per_outcome: int = 1000,
300
  mask_conditions: tuple[bool, ...] = (False, True),
 
384
 
385
 
386
  def prefix_continuation_test(
387
+ model: GenerativeModel,
388
  corpus: dict,
389
  device: str,
390
  n_per_bucket: int = 200,
 
505
  return results
506
 
507
 
508
+ def _outcome_mask(term_codes: np.ndarray, game_lengths: np.ndarray, outcome_name: str) -> np.ndarray:
509
  """Create a boolean mask for games matching the given outcome."""
510
  if outcome_name == "WHITE_CHECKMATES":
511
  return (term_codes == 0) & (game_lengths % 2 == 1)
 
533
 
534
 
535
  def poisoned_prefix_test(
536
+ model: GenerativeModel,
537
  corpus: dict,
538
  device: str,
539
  n_per_pair: int = 500,
 
596
 
597
 
598
  def impossible_task_test(
599
+ model: GenerativeModel,
600
  corpus: dict,
601
  device: str,
602
  n_per_scenario: int = 200,
 
682
 
683
 
684
  def improbable_task_test(
685
+ model: GenerativeModel,
686
  corpus: dict,
687
  device: str,
688
  n_per_scenario: int = 200,
pawn/eval_suite/lichess.py CHANGED
@@ -5,6 +5,7 @@ from pathlib import Path
5
 
6
  import numpy as np
7
  import torch
 
8
 
9
  import chess_engine as engine
10
 
@@ -119,7 +120,7 @@ def _extract_elos_from_pgn(pgn_path: Path, max_games: int) -> list[tuple[int, in
119
 
120
  @torch.no_grad()
121
  def evaluate_on_lichess(
122
- model,
123
  lichess_data: dict,
124
  device: str,
125
  max_seq_len: int = 256,
 
5
 
6
  import numpy as np
7
  import torch
8
+ import torch.nn as nn
9
 
10
  import chess_engine as engine
11
 
 
120
 
121
  @torch.no_grad()
122
  def evaluate_on_lichess(
123
+ model: nn.Module,
124
  lichess_data: dict,
125
  device: str,
126
  max_seq_len: int = 256,
pawn/eval_suite/probes.py CHANGED
@@ -295,7 +295,7 @@ def _extract_targets(
295
  # ---------------------------------------------------------------------------
296
 
297
 
298
- def _compute_loss(logits, targets, loss_type, n_outputs):
299
  if loss_type == "ce":
300
  return F.cross_entropy(logits, targets)
301
  elif loss_type == "ce_per_square":
@@ -308,7 +308,7 @@ def _compute_loss(logits, targets, loss_type, n_outputs):
308
  raise ValueError(f"Unknown loss type: {loss_type}")
309
 
310
 
311
- def _compute_accuracy(logits, targets, loss_type, n_outputs):
312
  if loss_type == "ce":
313
  preds = logits.argmax(dim=-1)
314
  return (preds == targets).float().mean().item()
@@ -328,7 +328,7 @@ def _compute_accuracy(logits, targets, loss_type, n_outputs):
328
  raise ValueError(f"Unknown loss type: {loss_type}")
329
 
330
 
331
- def _compute_mae(logits, targets):
332
  """Mean absolute error for regression probes."""
333
  return (logits - targets).abs().mean().item()
334
 
@@ -427,7 +427,7 @@ def train_single_probe(
427
  )
428
 
429
 
430
- def _eval_in_batches(probe, h, t, loss_type, n_outputs, device, batch_size):
431
  """Accuracy in mini-batches."""
432
  total_correct = 0.0
433
  total = 0
@@ -441,7 +441,7 @@ def _eval_in_batches(probe, h, t, loss_type, n_outputs, device, batch_size):
441
  return total_correct / total if total > 0 else 0.0
442
 
443
 
444
- def _eval_loss_in_batches(probe, h, t, loss_type, n_outputs, device, batch_size):
445
  """Loss in mini-batches (returns scalar)."""
446
  total_loss = 0.0
447
  total = 0
@@ -455,7 +455,7 @@ def _eval_loss_in_batches(probe, h, t, loss_type, n_outputs, device, batch_size)
455
  return total_loss / total if total > 0 else 0.0
456
 
457
 
458
- def _eval_mae_in_batches(probe, h, t, device, batch_size):
459
  """MAE in mini-batches."""
460
  total_ae = 0.0
461
  total = 0
 
295
  # ---------------------------------------------------------------------------
296
 
297
 
298
+ def _compute_loss(logits: torch.Tensor, targets: torch.Tensor, loss_type: str, n_outputs: int) -> torch.Tensor:
299
  if loss_type == "ce":
300
  return F.cross_entropy(logits, targets)
301
  elif loss_type == "ce_per_square":
 
308
  raise ValueError(f"Unknown loss type: {loss_type}")
309
 
310
 
311
+ def _compute_accuracy(logits: torch.Tensor, targets: torch.Tensor, loss_type: str, n_outputs: int) -> float:
312
  if loss_type == "ce":
313
  preds = logits.argmax(dim=-1)
314
  return (preds == targets).float().mean().item()
 
328
  raise ValueError(f"Unknown loss type: {loss_type}")
329
 
330
 
331
+ def _compute_mae(logits: torch.Tensor, targets: torch.Tensor) -> float:
332
  """Mean absolute error for regression probes."""
333
  return (logits - targets).abs().mean().item()
334
 
 
427
  )
428
 
429
 
430
+ def _eval_in_batches(probe: LinearProbe, h: torch.Tensor, t: torch.Tensor, loss_type: str, n_outputs: int, device: str, batch_size: int) -> float:
431
  """Accuracy in mini-batches."""
432
  total_correct = 0.0
433
  total = 0
 
441
  return total_correct / total if total > 0 else 0.0
442
 
443
 
444
+ def _eval_loss_in_batches(probe: LinearProbe, h: torch.Tensor, t: torch.Tensor, loss_type: str, n_outputs: int, device: str, batch_size: int) -> float:
445
  """Loss in mini-batches (returns scalar)."""
446
  total_loss = 0.0
447
  total = 0
 
455
  return total_loss / total if total > 0 else 0.0
456
 
457
 
458
+ def _eval_mae_in_batches(probe: LinearProbe, h: torch.Tensor, t: torch.Tensor, device: str, batch_size: int) -> float:
459
  """MAE in mini-batches."""
460
  total_ae = 0.0
461
  total = 0
pawn/eval_suite/viz.py CHANGED
@@ -3,6 +3,7 @@
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
  import matplotlib.ticker as mticker
 
6
  import seaborn as sns
7
 
8
  # Consistent style
@@ -26,11 +27,12 @@ GRID_PAWN_BASELINES = {
26
  # ---------------------------------------------------------------------------
27
 
28
 
29
- def plot_game_length_distribution(stats: dict, ax=None) -> plt.Figure:
30
  """Histogram of game lengths."""
31
  fig = None
32
  if ax is None:
33
  fig, ax = plt.subplots(figsize=FIGSIZE)
 
34
  counts = stats["game_length"]["histogram_counts"]
35
  edges = stats["game_length"]["histogram_edges"]
36
  centers = [(edges[i] + edges[i + 1]) / 2 for i in range(len(counts))]
@@ -44,11 +46,12 @@ def plot_game_length_distribution(stats: dict, ax=None) -> plt.Figure:
44
  return fig or ax.figure
45
 
46
 
47
- def plot_legal_move_distribution(bounds: dict, ax=None) -> plt.Figure:
48
  """Histogram of legal move counts (K) from pre-computed histogram data."""
49
  fig = None
50
  if ax is None:
51
  fig, ax = plt.subplots(figsize=FIGSIZE)
 
52
  k_hist = bounds["k_histogram"]
53
  k_vals = np.array(k_hist["values"])
54
  k_counts = np.array(k_hist["counts"], dtype=np.float64)
@@ -64,11 +67,12 @@ def plot_legal_move_distribution(bounds: dict, ax=None) -> plt.Figure:
64
  return fig or ax.figure
65
 
66
 
67
- def plot_outcome_rates(stats: dict, ax=None) -> plt.Figure:
68
  """Bar chart of outcome base rates."""
69
  fig = None
70
  if ax is None:
71
  fig, ax = plt.subplots(figsize=FIGSIZE)
 
72
  rates = stats["outcome_rates"]
73
  names = list(rates.keys())
74
  values = [rates[n] * 100 for n in names]
@@ -82,11 +86,12 @@ def plot_outcome_rates(stats: dict, ax=None) -> plt.Figure:
82
  return fig or ax.figure
83
 
84
 
85
- def plot_k_by_phase(bounds: dict, ax=None) -> plt.Figure:
86
  """E[1/K] by game phase."""
87
  fig = None
88
  if ax is None:
89
  fig, ax = plt.subplots(figsize=FIGSIZE)
 
90
  phase_data = bounds["phase_bounds"]
91
  names = list(phase_data.keys())
92
  e_inv_k = [phase_data[n]["e_1_over_k"] * 100 for n in names]
@@ -104,11 +109,12 @@ def plot_k_by_phase(bounds: dict, ax=None) -> plt.Figure:
104
  return fig or ax.figure
105
 
106
 
107
- def plot_prefix_histogram(sanity: dict, ax=None) -> plt.Figure:
108
  """Histogram of common prefix lengths."""
109
  fig = None
110
  if ax is None:
111
  fig, ax = plt.subplots(figsize=FIGSIZE)
 
112
  hist = sanity["prefix_length_histogram"]
113
  ks = sorted(hist.keys())
114
  vs = [hist[k] for k in ks]
 
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
  import matplotlib.ticker as mticker
6
+ from matplotlib.axes import Axes
7
  import seaborn as sns
8
 
9
  # Consistent style
 
27
  # ---------------------------------------------------------------------------
28
 
29
 
30
+ def plot_game_length_distribution(stats: dict, ax: Axes | None = None) -> plt.Figure:
31
  """Histogram of game lengths."""
32
  fig = None
33
  if ax is None:
34
  fig, ax = plt.subplots(figsize=FIGSIZE)
35
+ assert ax is not None
36
  counts = stats["game_length"]["histogram_counts"]
37
  edges = stats["game_length"]["histogram_edges"]
38
  centers = [(edges[i] + edges[i + 1]) / 2 for i in range(len(counts))]
 
46
  return fig or ax.figure
47
 
48
 
49
+ def plot_legal_move_distribution(bounds: dict, ax: Axes | None = None) -> plt.Figure:
50
  """Histogram of legal move counts (K) from pre-computed histogram data."""
51
  fig = None
52
  if ax is None:
53
  fig, ax = plt.subplots(figsize=FIGSIZE)
54
+ assert ax is not None
55
  k_hist = bounds["k_histogram"]
56
  k_vals = np.array(k_hist["values"])
57
  k_counts = np.array(k_hist["counts"], dtype=np.float64)
 
67
  return fig or ax.figure
68
 
69
 
70
+ def plot_outcome_rates(stats: dict, ax: Axes | None = None) -> plt.Figure:
71
  """Bar chart of outcome base rates."""
72
  fig = None
73
  if ax is None:
74
  fig, ax = plt.subplots(figsize=FIGSIZE)
75
+ assert ax is not None
76
  rates = stats["outcome_rates"]
77
  names = list(rates.keys())
78
  values = [rates[n] * 100 for n in names]
 
86
  return fig or ax.figure
87
 
88
 
89
+ def plot_k_by_phase(bounds: dict, ax: Axes | None = None) -> plt.Figure:
90
  """E[1/K] by game phase."""
91
  fig = None
92
  if ax is None:
93
  fig, ax = plt.subplots(figsize=FIGSIZE)
94
+ assert ax is not None
95
  phase_data = bounds["phase_bounds"]
96
  names = list(phase_data.keys())
97
  e_inv_k = [phase_data[n]["e_1_over_k"] * 100 for n in names]
 
109
  return fig or ax.figure
110
 
111
 
112
+ def plot_prefix_histogram(sanity: dict, ax: Axes | None = None) -> plt.Figure:
113
  """Histogram of common prefix lengths."""
114
  fig = None
115
  if ax is None:
116
  fig, ax = plt.subplots(figsize=FIGSIZE)
117
+ assert ax is not None
118
  hist = sanity["prefix_length_histogram"]
119
  ks = sorted(hist.keys())
120
  vs = [hist[k] for k in ks]
pawn/eval_suite/worker.py CHANGED
@@ -20,7 +20,12 @@ from __future__ import annotations
20
 
21
  import gc
22
  import multiprocessing as mp
 
23
  from pathlib import Path
 
 
 
 
24
 
25
  # Use "spawn" so the child gets a clean process with no inherited GPU state.
26
  _ctx = mp.get_context("spawn")
@@ -31,11 +36,11 @@ _ctx = mp.get_context("spawn")
31
  # ---------------------------------------------------------------------------
32
 
33
 
34
- def _worker_entry(fn, args, kwargs):
35
  return fn(*args, **kwargs)
36
 
37
 
38
- def run_in_worker(fn, *args, timeout: float | None = None, **kwargs):
39
  """Run fn(*args, **kwargs) in an isolated worker process.
40
 
41
  On KeyboardInterrupt, the worker is terminated and the interrupt is
@@ -55,7 +60,7 @@ def run_in_worker(fn, *args, timeout: float | None = None, **kwargs):
55
  # ---------------------------------------------------------------------------
56
 
57
 
58
- def _load_model(checkpoint_path: str, device: str):
59
  """Load and freeze a PAWNCLM checkpoint. Runs inside worker processes."""
60
  import torch
61
  from pawn.config import CLMConfig
@@ -83,8 +88,8 @@ def _load_corpus(corpus_dir: str) -> dict:
83
  # ---------------------------------------------------------------------------
84
 
85
 
86
- def _probes_worker(checkpoint_path, device, n_train, n_val, n_epochs,
87
- seed_train, seed_val):
88
  from pawn.eval_suite.probes import extract_probe_data, train_all_probes
89
  model = _load_model(checkpoint_path, device)
90
  train_data = extract_probe_data(n_train, max_ply=256, seed=seed_train)
@@ -94,7 +99,7 @@ def _probes_worker(checkpoint_path, device, n_train, n_val, n_epochs,
94
 
95
 
96
  def run_probes(
97
- checkpoint_path,
98
  device: str,
99
  n_train: int = 5_000,
100
  n_val: int = 1_000,
@@ -110,7 +115,8 @@ def run_probes(
110
  )
111
 
112
 
113
- def _signal_test_worker(checkpoint_path, device, n_per_outcome, mask_conditions):
 
114
  from pawn.eval_suite.generation import outcome_signal_test
115
  model = _load_model(checkpoint_path, device)
116
  return outcome_signal_test(model, device, n_per_outcome=n_per_outcome,
@@ -118,7 +124,7 @@ def _signal_test_worker(checkpoint_path, device, n_per_outcome, mask_conditions)
118
 
119
 
120
  def run_outcome_signal_test(
121
- checkpoint_path,
122
  device: str,
123
  n_per_outcome: int = 1000,
124
  mask_conditions: tuple[bool, ...] = (False, True),
@@ -130,8 +136,9 @@ def run_outcome_signal_test(
130
  )
131
 
132
 
133
- def _prefix_continuation_worker(checkpoint_path, corpus_dir, device,
134
- n_per_bucket, prefix_pcts, absolute_plies):
 
135
  from pawn.eval_suite.generation import prefix_continuation_test
136
  model = _load_model(checkpoint_path, device)
137
  corpus = _load_corpus(corpus_dir)
@@ -142,8 +149,8 @@ def _prefix_continuation_worker(checkpoint_path, corpus_dir, device,
142
 
143
 
144
  def run_prefix_continuation_test(
145
- checkpoint_path,
146
- corpus_dir,
147
  device: str,
148
  n_per_bucket: int = 200,
149
  prefix_pcts: tuple[float, ...] = (0.1, 0.5, 0.9),
@@ -157,8 +164,8 @@ def run_prefix_continuation_test(
157
  )
158
 
159
 
160
- def _poisoned_prefix_worker(checkpoint_path, corpus_dir, device,
161
- n_per_pair, prefix_pct):
162
  from pawn.eval_suite.generation import poisoned_prefix_test
163
  model = _load_model(checkpoint_path, device)
164
  corpus = _load_corpus(corpus_dir)
@@ -167,8 +174,8 @@ def _poisoned_prefix_worker(checkpoint_path, corpus_dir, device,
167
 
168
 
169
  def run_poisoned_prefix_test(
170
- checkpoint_path,
171
- corpus_dir,
172
  device: str,
173
  n_per_pair: int = 500,
174
  prefix_pct: float = 0.5,
@@ -180,7 +187,8 @@ def run_poisoned_prefix_test(
180
  )
181
 
182
 
183
- def _impossible_task_worker(checkpoint_path, corpus_dir, device, n_per_scenario):
 
184
  from pawn.eval_suite.generation import impossible_task_test
185
  model = _load_model(checkpoint_path, device)
186
  corpus = _load_corpus(corpus_dir)
@@ -188,8 +196,8 @@ def _impossible_task_worker(checkpoint_path, corpus_dir, device, n_per_scenario)
188
 
189
 
190
  def run_impossible_task_test(
191
- checkpoint_path,
192
- corpus_dir,
193
  device: str,
194
  n_per_scenario: int = 200,
195
  ) -> dict:
@@ -200,7 +208,8 @@ def run_impossible_task_test(
200
  )
201
 
202
 
203
- def _improbable_task_worker(checkpoint_path, corpus_dir, device, n_per_scenario):
 
204
  from pawn.eval_suite.generation import improbable_task_test
205
  model = _load_model(checkpoint_path, device)
206
  corpus = _load_corpus(corpus_dir)
@@ -208,8 +217,8 @@ def _improbable_task_worker(checkpoint_path, corpus_dir, device, n_per_scenario)
208
 
209
 
210
  def run_improbable_task_test(
211
- checkpoint_path,
212
- corpus_dir,
213
  device: str,
214
  n_per_scenario: int = 200,
215
  ) -> dict:
@@ -220,8 +229,9 @@ def run_improbable_task_test(
220
  )
221
 
222
 
223
- def _diagnostic_worker(checkpoint_path, corpus_dir, device, min_per_category,
224
- max_per_category, n_samples, batch_size):
 
225
  from pawn.eval_suite.diagnostics import (
226
  extract_diagnostic_positions, evaluate_diagnostic_positions,
227
  )
@@ -240,8 +250,8 @@ def _diagnostic_worker(checkpoint_path, corpus_dir, device, min_per_category,
240
 
241
 
242
  def run_diagnostic_eval(
243
- checkpoint_path,
244
- corpus_dir,
245
  device: str,
246
  min_per_category: int = 2000,
247
  max_per_category: int = 5000,
 
20
 
21
  import gc
22
  import multiprocessing as mp
23
+ from collections.abc import Callable
24
  from pathlib import Path
25
+ from typing import Any, TYPE_CHECKING
26
+
27
+ if TYPE_CHECKING:
28
+ from pawn.model import PAWNCLM
29
 
30
  # Use "spawn" so the child gets a clean process with no inherited GPU state.
31
  _ctx = mp.get_context("spawn")
 
36
  # ---------------------------------------------------------------------------
37
 
38
 
39
+ def _worker_entry(fn: Callable[..., Any], args: tuple, kwargs: dict) -> Any:
40
  return fn(*args, **kwargs)
41
 
42
 
43
+ def run_in_worker(fn: Callable[..., Any], *args: Any, timeout: float | None = None, **kwargs: Any) -> Any:
44
  """Run fn(*args, **kwargs) in an isolated worker process.
45
 
46
  On KeyboardInterrupt, the worker is terminated and the interrupt is
 
60
  # ---------------------------------------------------------------------------
61
 
62
 
63
+ def _load_model(checkpoint_path: str, device: str) -> PAWNCLM:
64
  """Load and freeze a PAWNCLM checkpoint. Runs inside worker processes."""
65
  import torch
66
  from pawn.config import CLMConfig
 
88
  # ---------------------------------------------------------------------------
89
 
90
 
91
+ def _probes_worker(checkpoint_path: str, device: str, n_train: int, n_val: int,
92
+ n_epochs: int, seed_train: int, seed_val: int) -> dict:
93
  from pawn.eval_suite.probes import extract_probe_data, train_all_probes
94
  model = _load_model(checkpoint_path, device)
95
  train_data = extract_probe_data(n_train, max_ply=256, seed=seed_train)
 
99
 
100
 
101
  def run_probes(
102
+ checkpoint_path: str | Path,
103
  device: str,
104
  n_train: int = 5_000,
105
  n_val: int = 1_000,
 
115
  )
116
 
117
 
118
+ def _signal_test_worker(checkpoint_path: str, device: str, n_per_outcome: int,
119
+ mask_conditions: list[bool]) -> dict:
120
  from pawn.eval_suite.generation import outcome_signal_test
121
  model = _load_model(checkpoint_path, device)
122
  return outcome_signal_test(model, device, n_per_outcome=n_per_outcome,
 
124
 
125
 
126
  def run_outcome_signal_test(
127
+ checkpoint_path: str | Path,
128
  device: str,
129
  n_per_outcome: int = 1000,
130
  mask_conditions: tuple[bool, ...] = (False, True),
 
136
  )
137
 
138
 
139
+ def _prefix_continuation_worker(checkpoint_path: str, corpus_dir: str, device: str,
140
+ n_per_bucket: int, prefix_pcts: list[float],
141
+ absolute_plies: list[int]) -> dict:
142
  from pawn.eval_suite.generation import prefix_continuation_test
143
  model = _load_model(checkpoint_path, device)
144
  corpus = _load_corpus(corpus_dir)
 
149
 
150
 
151
  def run_prefix_continuation_test(
152
+ checkpoint_path: str | Path,
153
+ corpus_dir: str | Path,
154
  device: str,
155
  n_per_bucket: int = 200,
156
  prefix_pcts: tuple[float, ...] = (0.1, 0.5, 0.9),
 
164
  )
165
 
166
 
167
+ def _poisoned_prefix_worker(checkpoint_path: str, corpus_dir: str, device: str,
168
+ n_per_pair: int, prefix_pct: float) -> dict:
169
  from pawn.eval_suite.generation import poisoned_prefix_test
170
  model = _load_model(checkpoint_path, device)
171
  corpus = _load_corpus(corpus_dir)
 
174
 
175
 
176
  def run_poisoned_prefix_test(
177
+ checkpoint_path: str | Path,
178
+ corpus_dir: str | Path,
179
  device: str,
180
  n_per_pair: int = 500,
181
  prefix_pct: float = 0.5,
 
187
  )
188
 
189
 
190
+ def _impossible_task_worker(checkpoint_path: str, corpus_dir: str, device: str,
191
+ n_per_scenario: int) -> dict:
192
  from pawn.eval_suite.generation import impossible_task_test
193
  model = _load_model(checkpoint_path, device)
194
  corpus = _load_corpus(corpus_dir)
 
196
 
197
 
198
  def run_impossible_task_test(
199
+ checkpoint_path: str | Path,
200
+ corpus_dir: str | Path,
201
  device: str,
202
  n_per_scenario: int = 200,
203
  ) -> dict:
 
208
  )
209
 
210
 
211
+ def _improbable_task_worker(checkpoint_path: str, corpus_dir: str, device: str,
212
+ n_per_scenario: int) -> dict:
213
  from pawn.eval_suite.generation import improbable_task_test
214
  model = _load_model(checkpoint_path, device)
215
  corpus = _load_corpus(corpus_dir)
 
217
 
218
 
219
  def run_improbable_task_test(
220
+ checkpoint_path: str | Path,
221
+ corpus_dir: str | Path,
222
  device: str,
223
  n_per_scenario: int = 200,
224
  ) -> dict:
 
229
  )
230
 
231
 
232
+ def _diagnostic_worker(checkpoint_path: str, corpus_dir: str, device: str,
233
+ min_per_category: int, max_per_category: int,
234
+ n_samples: int, batch_size: int) -> dict:
235
  from pawn.eval_suite.diagnostics import (
236
  extract_diagnostic_positions, evaluate_diagnostic_positions,
237
  )
 
250
 
251
 
252
  def run_diagnostic_eval(
253
+ checkpoint_path: str | Path,
254
+ corpus_dir: str | Path,
255
  device: str,
256
  min_per_category: int = 2000,
257
  max_per_category: int = 5000,
pawn/lichess_data.py CHANGED
@@ -308,10 +308,10 @@ class LichessDataset(torch.utils.data.Dataset):
308
  self.game_lengths = torch.from_numpy(np.array(self.game_lengths)).share_memory_()
309
  return self
310
 
311
- def __len__(self):
312
  return len(self.input_ids)
313
 
314
- def __getitem__(self, idx):
315
  return {
316
  "input_ids": self.input_ids[idx],
317
  "targets": self.targets[idx],
 
308
  self.game_lengths = torch.from_numpy(np.array(self.game_lengths)).share_memory_()
309
  return self
310
 
311
+ def __len__(self) -> int:
312
  return len(self.input_ids)
313
 
314
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor | int]:
315
  return {
316
  "input_ids": self.input_ids[idx],
317
  "targets": self.targets[idx],
pawn/logging.py CHANGED
@@ -72,7 +72,7 @@ class MetricsLogger:
72
  record_type: Record type (train, eval, batch, etc.)
73
  include_resources: Whether to include memory/CPU stats
74
  """
75
- record = {"type": record_type}
76
 
77
  if step is not None:
78
  record["step"] = step
@@ -119,14 +119,14 @@ class MetricsLogger:
119
  def close(self) -> None:
120
  self._file.close()
121
 
122
- def __enter__(self):
123
  return self
124
 
125
- def __exit__(self, *args):
126
  self.close()
127
 
128
 
129
- def _sanitize(obj):
130
  """Replace NaN/Inf with None for valid JSON."""
131
  if isinstance(obj, float) and (math.isnan(obj) or math.isinf(obj)):
132
  return None
 
72
  record_type: Record type (train, eval, batch, etc.)
73
  include_resources: Whether to include memory/CPU stats
74
  """
75
+ record: dict[str, object] = {"type": record_type}
76
 
77
  if step is not None:
78
  record["step"] = step
 
119
  def close(self) -> None:
120
  self._file.close()
121
 
122
+ def __enter__(self) -> "MetricsLogger":
123
  return self
124
 
125
+ def __exit__(self, *args: object) -> None:
126
  self.close()
127
 
128
 
129
+ def _sanitize(obj: object) -> object:
130
  """Replace NaN/Inf with None for valid JSON."""
131
  if isinstance(obj, float) and (math.isnan(obj) or math.isinf(obj)):
132
  return None
pawn/model.py CHANGED
@@ -179,6 +179,11 @@ class SwiGLUFFN(nn.Module):
179
 
180
 
181
  class TransformerBlock(nn.Module):
 
 
 
 
 
182
  def __init__(self, cfg: CLMConfig):
183
  super().__init__()
184
  self.attn_norm = RMSNorm(cfg.d_model)
@@ -220,6 +225,8 @@ class CLMEmbedding(nn.Module):
220
  PAD and outcome tokens use standalone embeddings.
221
  """
222
 
 
 
223
  def __init__(self, cfg: CLMConfig):
224
  super().__init__()
225
  self.d_model = cfg.d_model
@@ -273,6 +280,14 @@ class PAWNCLM(nn.Module):
273
  full vocabulary. No factored output head, no grid, no BCE.
274
  """
275
 
 
 
 
 
 
 
 
 
276
  def __init__(self, cfg: CLMConfig):
277
  super().__init__()
278
  self.cfg = cfg
@@ -300,6 +315,10 @@ class PAWNCLM(nn.Module):
300
 
301
  self._init_weights()
302
 
 
 
 
 
303
  def _init_weights(self):
304
  for p in self.parameters():
305
  if p.dim() > 1:
@@ -425,9 +444,9 @@ class PAWNCLM(nn.Module):
425
  rope_sin = self.rope_sin[:, :, :T_new, :]
426
 
427
  new_kv_cache = []
428
- for i, layer in enumerate(self.layers):
429
  layer_cache = kv_cache[i] if kv_cache is not None else None
430
- x, new_cache = layer.forward_kv(x, rope_cos, rope_sin, layer_cache)
431
  new_kv_cache.append(new_cache)
432
 
433
  x = self.final_norm(x[:, -1:, :])
 
179
 
180
 
181
  class TransformerBlock(nn.Module):
182
+ attn_norm: RMSNorm
183
+ attn: Attention
184
+ ffn_norm: RMSNorm
185
+ ffn: SwiGLUFFN
186
+
187
  def __init__(self, cfg: CLMConfig):
188
  super().__init__()
189
  self.attn_norm = RMSNorm(cfg.d_model)
 
225
  PAD and outcome tokens use standalone embeddings.
226
  """
227
 
228
+ decomp_table: torch.Tensor
229
+
230
  def __init__(self, cfg: CLMConfig):
231
  super().__init__()
232
  self.d_model = cfg.d_model
 
280
  full vocabulary. No factored output head, no grid, no BCE.
281
  """
282
 
283
+ rope_cos: torch.Tensor
284
+ rope_sin: torch.Tensor
285
+ causal_mask: torch.Tensor
286
+ embed: CLMEmbedding
287
+ layers: nn.ModuleList
288
+ final_norm: RMSNorm
289
+ lm_head: nn.Linear
290
+
291
  def __init__(self, cfg: CLMConfig):
292
  super().__init__()
293
  self.cfg = cfg
 
315
 
316
  self._init_weights()
317
 
318
+ def get_block(self, i: int) -> TransformerBlock:
319
+ """Typed accessor for transformer layers (avoids ModuleList type erasure)."""
320
+ return self.layers[i] # type: ignore[return-value]
321
+
322
  def _init_weights(self):
323
  for p in self.parameters():
324
  if p.dim() > 1:
 
444
  rope_sin = self.rope_sin[:, :, :T_new, :]
445
 
446
  new_kv_cache = []
447
+ for i in range(len(self.layers)):
448
  layer_cache = kv_cache[i] if kv_cache is not None else None
449
+ x, new_cache = self.get_block(i).forward_kv(x, rope_cos, rope_sin, layer_cache)
450
  new_kv_cache.append(new_cache)
451
 
452
  x = self.final_norm(x[:, -1:, :])
pawn/trainer.py CHANGED
@@ -38,12 +38,12 @@ class CosineWithWarmup:
38
  self._step = 0
39
  self._apply_lr(0)
40
 
41
- def _apply_lr(self, step: int):
42
  lr_scale = self._compute_lr_scale(step)
43
  for pg, base_lr in zip(self.optimizer.param_groups, self.base_lrs, strict=True):
44
  pg["lr"] = base_lr * lr_scale
45
 
46
- def step(self):
47
  self._step += 1
48
  self._apply_lr(self._step)
49
 
@@ -59,10 +59,10 @@ class CosineWithWarmup:
59
  def get_lr(self) -> float:
60
  return self.optimizer.param_groups[0]["lr"]
61
 
62
- def state_dict(self):
63
  return {"step": self._step}
64
 
65
- def load_state_dict(self, state):
66
  self._step = state["step"]
67
  self._apply_lr(self._step)
68
 
@@ -208,8 +208,9 @@ class CLMTrainer:
208
  self._jsonl_path = os.path.join(self.run_dir, "metrics.jsonl")
209
  self._jsonl_file = None
210
 
211
- self.model = PAWNCLM(model_cfg).to(self.device)
212
- param_count = sum(p.numel() for p in self.model.parameters())
 
213
  print(f"Model parameters: {param_count:,}")
214
  print(f"Run directory: {self.run_dir}")
215
 
@@ -345,7 +346,7 @@ class CLMTrainer:
345
 
346
  def optimizer_step(self) -> float:
347
  self.scaler.unscale_(self.optimizer)
348
- grad_norm = _get_grad_norm(self.model)
349
  torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.max_grad_norm)
350
  self.scaler.step(self.optimizer)
351
  self.scaler.update()
@@ -354,7 +355,7 @@ class CLMTrainer:
354
  return grad_norm
355
 
356
  def _eager_model(self) -> PAWNCLM:
357
- return self.model._orig_mod if hasattr(self.model, "_orig_mod") else self.model
358
 
359
  @torch.no_grad()
360
  def evaluate(self) -> dict[str, float]:
@@ -535,9 +536,7 @@ class CLMTrainer:
535
  if dirname:
536
  os.makedirs(dirname, exist_ok=True)
537
 
538
- model = self.model
539
- if hasattr(model, "_orig_mod"):
540
- model = model._orig_mod
541
 
542
  torch.save(
543
  {
@@ -561,9 +560,7 @@ class CLMTrainer:
561
  ckpt = torch.load(path, map_location=self.device, weights_only=False)
562
  self.global_step = ckpt["global_step"]
563
 
564
- model = self.model
565
- if hasattr(model, "_orig_mod"):
566
- model = model._orig_mod
567
 
568
  model.load_state_dict(ckpt["model_state_dict"])
569
  self.optimizer.load_state_dict(ckpt["optimizer_state_dict"])
 
38
  self._step = 0
39
  self._apply_lr(0)
40
 
41
+ def _apply_lr(self, step: int) -> None:
42
  lr_scale = self._compute_lr_scale(step)
43
  for pg, base_lr in zip(self.optimizer.param_groups, self.base_lrs, strict=True):
44
  pg["lr"] = base_lr * lr_scale
45
 
46
+ def step(self) -> None:
47
  self._step += 1
48
  self._apply_lr(self._step)
49
 
 
59
  def get_lr(self) -> float:
60
  return self.optimizer.param_groups[0]["lr"]
61
 
62
+ def state_dict(self) -> dict[str, int]:
63
  return {"step": self._step}
64
 
65
+ def load_state_dict(self, state: dict[str, int]) -> None:
66
  self._step = state["step"]
67
  self._apply_lr(self._step)
68
 
 
208
  self._jsonl_path = os.path.join(self.run_dir, "metrics.jsonl")
209
  self._jsonl_file = None
210
 
211
+ self._model = PAWNCLM(model_cfg).to(self.device)
212
+ self.model = self._model
213
+ param_count = sum(p.numel() for p in self._model.parameters())
214
  print(f"Model parameters: {param_count:,}")
215
  print(f"Run directory: {self.run_dir}")
216
 
 
346
 
347
  def optimizer_step(self) -> float:
348
  self.scaler.unscale_(self.optimizer)
349
+ grad_norm = _get_grad_norm(self._model)
350
  torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.max_grad_norm)
351
  self.scaler.step(self.optimizer)
352
  self.scaler.update()
 
355
  return grad_norm
356
 
357
  def _eager_model(self) -> PAWNCLM:
358
+ return self._model
359
 
360
  @torch.no_grad()
361
  def evaluate(self) -> dict[str, float]:
 
536
  if dirname:
537
  os.makedirs(dirname, exist_ok=True)
538
 
539
+ model: PAWNCLM = self._eager_model()
 
 
540
 
541
  torch.save(
542
  {
 
560
  ckpt = torch.load(path, map_location=self.device, weights_only=False)
561
  self.global_step = ckpt["global_step"]
562
 
563
+ model: PAWNCLM = self._eager_model()
 
 
564
 
565
  model.load_state_dict(ckpt["model_state_dict"])
566
  self.optimizer.load_state_dict(ckpt["optimizer_state_dict"])
pyproject.toml CHANGED
@@ -57,5 +57,13 @@ name = "pytorch-cu128"
57
  url = "https://download.pytorch.org/whl/cu128"
58
  explicit = true
59
 
 
 
 
 
 
 
 
 
60
  [tool.pytest.ini_options]
61
  testpaths = ["tests"]
 
57
  url = "https://download.pytorch.org/whl/cu128"
58
  explicit = true
59
 
60
+ [tool.pyright]
61
+ pythonVersion = "3.10"
62
+ typeCheckingMode = "basic"
63
+ reportMissingTypeStubs = false
64
+ reportPrivateImportUsage = false
65
+ reportMissingImports = "warning"
66
+ include = ["pawn"]
67
+
68
  [tool.pytest.ini_options]
69
  testpaths = ["tests"]