Fix: _unstack_scan_params breaks after flax deserialization (from_bytes returns numpy arrays)

#1
by dignity045 - opened
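Minimal repro of the failure mode (a sketch; the array below is illustrative, not the actual checkpoint):

import numpy as np
import jax.numpy as jnp

# flax.serialization.from_bytes() restores leaves as plain numpy arrays.
leaf = np.zeros((4, 128))  # a scanned param stacked as (num_layers, d_model)

isinstance(leaf, jnp.ndarray)                       # False -> the old check never splits the stack
hasattr(leaf, "ndim") and hasattr(leaf, "shape")    # True  -> the duck-typed check does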
Files changed (1)
  1. LaughLM/model/gpt.py +216 -75
LaughLM/model/gpt.py CHANGED
@@ -1,35 +1,163 @@
 
  import jax
  import jax.numpy as jnp
  from flax import linen as nn
- from typing import Optional, Tuple
 
  from LaughLM.config.schema import LaughLMConfig
- from LaughLM.model.transformer_block import TransformerBlock
  from LaughLM.model.layers.normalization import build_normalization
  from LaughLM.model.layers.positional import (
      build_positional_encoding,
      build_rope_tables,
  )
 
 
  class GPTModel(nn.Module):
      config: LaughLMConfig
 
      def setup(self):
 
-         cfg = self.config
-         d_model = cfg.model.d_model
-         vocab_size = cfg.model.vocab_size
-         num_layers = cfg.model.num_layers
-         pos_type = cfg.architecture.positional
-         compute_bf16 = (cfg.parallelism.compute_dtype == "bfloat16")
-
-         self._compute_dtype = jnp.bfloat16 if compute_bf16 else jnp.float32
 
-         # ------------------------------------------------------------
-         # Token embedding
-         # ------------------------------------------------------------
          self.token_embedding = nn.Embed(
              num_embeddings=vocab_size,
              features=d_model,
@@ -38,111 +166,124 @@ class GPTModel(nn.Module):
              ),
          )
 
-         # ------------------------------------------------------------
-         # Positional encoding (additive only)
-         # ------------------------------------------------------------
          self.positional = build_positional_encoding(cfg)
 
-         # ------------------------------------------------------------
-         # RoPE
-         # ------------------------------------------------------------
          self._use_rope = pos_type in ("rope", "rope_scaled")
 
          if self._use_rope:
              head_dim = d_model // cfg.model.num_heads
              self._rope_sin, self._rope_cos = build_rope_tables(
                  head_dim=head_dim,
                  max_seq_len=cfg.model.max_seq_len,
              )
          else:
              self._rope_sin = None
              self._rope_cos = None
 
-         # ------------------------------------------------------------
-         # Transformer blocks
-         # ------------------------------------------------------------
-         self.blocks = [
-             TransformerBlock(config=cfg)
-             for _ in range(num_layers)
-         ]
 
          self.final_norm = build_normalization(cfg)
 
          if not cfg.architecture.weight_tying:
              self.lm_head = nn.Dense(
                  vocab_size,
                  use_bias=cfg.architecture.bias,
-                 kernel_init=nn.initializers.normal(
-                     stddev=cfg.initialization.std
-                 ),
              )
 
      def __call__(
          self,
          input_ids: jnp.ndarray,
          doc_ids: Optional[jnp.ndarray] = None,
-     ) -> jnp.ndarray:
-
-         # ------------------------------------------------------------
-         # 🔴 CRITICAL: enforce input contract
-         # ------------------------------------------------------------
-         assert input_ids.ndim == 2, f"[GPT] Expected (B, T), got {input_ids.shape}"
 
          B, T = input_ids.shape
 
-         # ------------------------------------------------------------
-         # Token embedding
-         # ------------------------------------------------------------
-         x = self.token_embedding(input_ids)  # (B, T, D)
          x = x.astype(self._compute_dtype)
 
-         # ------------------------------------------------------------
-         # Positional encoding (safe broadcasting)
-         # ------------------------------------------------------------
          if self.positional is not None:
-             positions = jnp.arange(T)[None, :]    # (1, T)
-             pos_emb = self.positional(positions)  # (1, T, D)
-
-             # 🔴 CRITICAL FIX: enforce shape explicitly
-             assert pos_emb.ndim == 3, f"[GPT] pos_emb wrong shape: {pos_emb.shape}"
-             assert pos_emb.shape[1] == T, f"[GPT] pos_emb T mismatch: {pos_emb.shape}"
-
-             # Safe broadcast
              x = x + pos_emb.astype(self._compute_dtype)
 
-         # ------------------------------------------------------------
-         # RoPE tables (slice once)
-         # ------------------------------------------------------------
-         rope_tables: Optional[Tuple] = None
          if self._use_rope:
-             rope_tables = (
-                 self._rope_sin[:T],
-                 self._rope_cos[:T],
-             )
 
-         # ------------------------------------------------------------
-         # Transformer stack
-         # ------------------------------------------------------------
-         for block in self.blocks:
-             x = block(x, rope_tables=rope_tables, doc_ids=doc_ids)
 
-         # ------------------------------------------------------------
-         # Final norm
-         # ------------------------------------------------------------
-         x = self.final_norm(x)
 
-         # ------------------------------------------------------------
-         # Back to FP32 for logits
-         # ------------------------------------------------------------
          x = x.astype(jnp.float32)
 
-         # ------------------------------------------------------------
-         # Output projection
-         # ------------------------------------------------------------
          if self.config.architecture.weight_tying:
-             embedding_table = self.token_embedding.embedding  # (V, D)
              logits = jnp.einsum("btd,vd->btv", x, embedding_table)
          else:
              logits = self.lm_head(x)
 
-         return logits
+ """
+ LaughLM/model/gpt.py
+
+ Top-level GPT model — nn.scan for training, for-loop for inference.
+
+ Key design:
+ - Training (kv_caches=None): uses nn.scan when scan_layers=True for O(1) compile
+ - Inference (kv_caches != None): uses a for-loop with per-layer params extracted
+   from the scan variable tree
+
+ FIX (audit 2025): Previous code created UNINITIALIZED TransformerBlock instances
+ during inference when scan_layers=True, producing GARBAGE output.
+
+ The fix: when scan_layers=True, inference extracts per-layer params from the
+ scanned param tree via _unstack_scan_params() and uses .apply() to run each
+ block statelessly. When scan_layers=False, blocks run normally via self.blocks.
+
+ A single "reference block" is created for type/structure only during scan mode —
+ it's used via .apply() with per-layer params, never via __call__ with its own params.
+
+ FIX (2026-05-06): _unstack_scan_params used isinstance(tree, jnp.ndarray) to detect
+ leaf arrays to split. After flax.serialization.from_bytes(), params become plain
+ numpy.ndarray instances (NOT jnp.ndarray), causing the check to fail and the
+ scanned params to be returned un-split. The reference block then received
+ stacked params with shape (num_layers, d_model) instead of per-layer (d_model,),
+ triggering flax.errors.ScopeParamShapeError. Fixed by using duck-typing
+ (hasattr ndim + shape) instead of isinstance.
+ """
 
  import jax
31
  import jax.numpy as jnp
32
  from flax import linen as nn
33
+ from typing import Optional, Tuple, List
34
 
35
  from LaughLM.config.schema import LaughLMConfig
36
+ from LaughLM.model.transformer_block import TransformerBlock, build_block, get_remat_policy
37
  from LaughLM.model.layers.normalization import build_normalization
38
  from LaughLM.model.layers.positional import (
39
  build_positional_encoding,
40
  build_rope_tables,
41
  )
42
+ from LaughLM.model.layers.attention import KVCache
43
+ from LaughLM.utils.dtype import resolve_compute_dtype
44
+
45
+
46
+ def _build_scanned_block(config: LaughLMConfig):
47
+ """
48
+ Build a scanned transformer stack using nn.scan.
49
+ Params stacked [num_layers, ...]. Single XLA trace reused N times.
50
+ """
51
+ remat_cfg = config.spmd.remat
52
+
53
+ BlockClass = TransformerBlock
54
+
55
+ if remat_cfg.policy != "everything_saveable":
56
+ policy = get_remat_policy(remat_cfg.policy)
57
+ BlockClass = nn.remat(
58
+ BlockClass,
59
+ policy=policy,
60
+ prevent_cse=remat_cfg.prevent_cse,
61
+ )
62
+
63
+ ScanBlock = nn.scan(
64
+ BlockClass,
65
+ variable_axes={"params": 0},
66
+ split_rngs={"params": True},
67
+ in_axes=(nn.broadcast, nn.broadcast, nn.broadcast),
68
+ length=config.model.num_layers,
69
+ )
70
+
71
+ return ScanBlock(config=config)
72
+
73
+
74
+ def _unstack_scan_params(params, num_layers):
75
+ """
76
+ Convert scanned param tree β†’ list of per-layer param dicts.
77
+
78
+ nn.scan with variable_axes={"params": 0} stacks each scanned variable
79
+ along axis 0. This function recursively walks the param dict and for
80
+ each leaf ndarray with shape[0] == num_layers, splits it into
81
+ num_layers slices. Other leaves (non-scanned params like embedding tables)
82
+ are kept unchanged.
83
+
84
+ Returns: list of num_layers param dicts, each structured like a
85
+ single TransformerBlock's params.
86
+ """
87
+
88
+ def _is_array(tree):
89
+ """Check if tree is an ndarray-like (JAX or numpy).
90
+
91
+ After flax.serialization.from_bytes(), params become plain
92
+ numpy.ndarray instances, NOT jnp.ndarray. Duck-typing by
93
+ ndim + shape handles both cases.
94
+ """
95
+ return hasattr(tree, 'ndim') and hasattr(tree, 'shape')
96
+
97
+ def _split(tree):
98
+ """Recursively split a param tree. Returns either the original
99
+ (non-scanned) or a list of per-layer dicts (scanned)."""
100
+ if isinstance(tree, dict):
101
+ keys = sorted(tree.keys())
102
+ # Recursively split each child
103
+ split_children = {k: _split(tree[k]) for k in keys}
104
+
105
+ # Determine if any child was split (returned a list)
106
+ any_split = any(isinstance(v, list) for v in split_children.values())
107
+
108
+ if any_split:
109
+ # Merge per-layer dicts across all children
110
+ result = []
111
+ for i in range(num_layers):
112
+ layer_dict = {}
113
+ for k in keys:
114
+ child = split_children[k]
115
+ if isinstance(child, list):
116
+ layer_dict[k] = child[i]
117
+ else:
118
+ # Non-scanned: same across all layers
119
+ layer_dict[k] = child
120
+ result.append(layer_dict)
121
+ return result
122
+ else:
123
+ return tree
124
+ elif _is_array(tree):
125
+ if tree.ndim > 0 and tree.shape[0] == num_layers:
126
+ return [tree[i] for i in range(num_layers)]
127
+ else:
128
+ return tree
129
+ else:
130
+ return tree
131
+
132
+ # The scan_block subtree contains stacked per-layer params
133
+ # Structure: scan_block β†’ {Dense_0: {kernel: [L, ...], bias: [L, ...]}, ...}
134
+ result = _split(params)
135
+
136
+ if isinstance(result, list):
137
+ return result
138
+ elif isinstance(result, dict):
139
+ # No scanned params found β€” this means all params are non-scanned
140
+ # which shouldn't happen for scan_block. Return as replicated.
141
+ return [result] * num_layers
142
+ else:
143
+ raise ValueError(f"Unexpected result from _split: {type(result)}")
144
 
145
 
146
  class GPTModel(nn.Module):
147
  config: LaughLMConfig
148
 
149
  def setup(self):
150
+ cfg = self.config
151
+ d_model = cfg.model.d_model
152
+ vocab_size = cfg.model.vocab_size
153
+ num_layers = cfg.model.num_layers
154
+ pos_type = cfg.architecture.positional
155
 
156
+ self._compute_dtype = resolve_compute_dtype(cfg)
157
+ self._use_scan = cfg.spmd.remat.scan_layers
158
+ self._num_layers = num_layers
 
 
 
 
 
159
 
160
+ # ── Token embedding ───────────────────────────────────
 
 
161
  self.token_embedding = nn.Embed(
162
  num_embeddings=vocab_size,
163
  features=d_model,
 
166
  ),
167
  )
168
 
169
+ # ── Positional encoding (additive only) ──────────────
 
 
170
  self.positional = build_positional_encoding(cfg)
171
 
172
+ # ── RoPE tables ───────────────────────────────────────
 
 
173
  self._use_rope = pos_type in ("rope", "rope_scaled")
174
 
175
  if self._use_rope:
176
  head_dim = d_model // cfg.model.num_heads
177
+ scale_factor = 4.0 if pos_type == "rope_scaled" else None
178
  self._rope_sin, self._rope_cos = build_rope_tables(
179
  head_dim=head_dim,
180
  max_seq_len=cfg.model.max_seq_len,
181
+ scale_factor=scale_factor,
182
  )
183
  else:
184
  self._rope_sin = None
185
  self._rope_cos = None
186
 
187
+ # ── Transformer blocks ────────────────────────────────
188
+ if self._use_scan:
189
+ # Scan mode: use nn.scan for training (O(1) compile)
190
+ # Also create a reference block for inference .apply() calls
191
+ self.scan_block = _build_scanned_block(cfg)
192
+ self._ref_block = TransformerBlock(config=cfg)
193
+ else:
194
+ # Non-scan mode: explicit blocks for both training and inference
195
+ self.blocks = [build_block(cfg) for _ in range(num_layers)]
196
 
197
+ # ── Final norm ────────────────────────────────────────
198
  self.final_norm = build_normalization(cfg)
199
 
200
+ # ── LM head ──────────────────────────────────────────
201
  if not cfg.architecture.weight_tying:
202
  self.lm_head = nn.Dense(
203
  vocab_size,
204
  use_bias=cfg.architecture.bias,
205
+ kernel_init=nn.initializers.normal(stddev=cfg.initialization.std),
 
 
206
  )
207
 
      def __call__(
          self,
          input_ids: jnp.ndarray,
          doc_ids: Optional[jnp.ndarray] = None,
+         kv_caches: Optional[List[KVCache]] = None,
+     ) -> Tuple[jnp.ndarray, Optional[List[KVCache]]]:
 
+         assert input_ids.ndim == 2, f"Expected (B, T), got {input_ids.shape}"
          B, T = input_ids.shape
 
+         # ── Token embedding ───────────────────────────────────
+         x = self.token_embedding(input_ids)
          x = x.astype(self._compute_dtype)
 
+         # ── Positional encoding ───────────────────────────────
          if self.positional is not None:
+             positions = jnp.arange(T)[None, :]
+             pos_emb = self.positional(positions)
              x = x + pos_emb.astype(self._compute_dtype)
 
+         # ── RoPE tables ───────────────────────────────────────
+         rope_tables = None
          if self._use_rope:
+             rope_tables = (self._rope_sin[:T], self._rope_cos[:T])
 
+         # ── Transformer stack ─────────────────────────────────
+         if kv_caches is not None:
+             # ── Inference: for-loop with per-layer KV cache ──
+             new_caches = []
 
+             if self._use_scan:
+                 # Extract per-layer params from the scanned parameter tree.
+                 # self.variables['params'] contains:
+                 #   {token_embedding: {...}, scan_block: {Dense_0: {kernel: [L, ...]}, ...}, ...}
+                 # We need just the scan_block subtree, unstacked per layer.
+                 all_params = self.variables.get("params", {})
+                 scan_params = all_params.get("scan_block", all_params)
+                 layer_params_list = _unstack_scan_params(scan_params, self._num_layers)
+
+                 for i in range(self._num_layers):
+                     # Use .apply() with per-layer params — stateless, no init needed
+                     block_vars = {"params": layer_params_list[i]}
+                     x, new_cache = self._ref_block.apply(
+                         block_vars,
+                         x,
+                         rope_tables=rope_tables,
+                         doc_ids=doc_ids,
+                         kv_cache=kv_caches[i],
+                     )
+                     new_caches.append(new_cache)
+             else:
+                 for i, block in enumerate(self.blocks):
+                     x, new_cache = block(
+                         x,
+                         rope_tables=rope_tables,
+                         doc_ids=doc_ids,
+                         kv_cache=kv_caches[i],
+                     )
+                     new_caches.append(new_cache)
 
+         elif self._use_scan:
+             # ── Training: nn.scan (O(1) compile, optimal) ──
+             x, _ = self.scan_block(x, rope_tables, doc_ids, None)
+             new_caches = None
+
+         else:
+             # ── Fallback: for-loop (no scan) ──
+             for block in self.blocks:
+                 x, _ = block(x, rope_tables=rope_tables, doc_ids=doc_ids, kv_cache=None)
+             new_caches = None
+
+         # ── Final norm + logits ───────────────────────────────
+         x = self.final_norm(x)
          x = x.astype(jnp.float32)
 
          if self.config.architecture.weight_tying:
+             embedding_table = self.token_embedding.embedding
              logits = jnp.einsum("btd,vd->btv", x, embedding_table)
          else:
              logits = self.lm_head(x)
 
+         return logits, new_caches
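
A quick sanity check for the un-stacking on numpy leaves (the fake pytree and shapes below are illustrative only):

import numpy as np
from LaughLM.model.gpt import _unstack_scan_params

num_layers = 4
# Simulate a scan_block subtree as it looks after flax.serialization.from_bytes():
# numpy leaves stacked along axis 0, num_layers first.
scan_params = {
    "Dense_0": {
        "kernel": np.zeros((num_layers, 128, 512)),
        "bias": np.zeros((num_layers, 512)),
    },
}

per_layer = _unstack_scan_params(scan_params, num_layers)
assert len(per_layer) == num_layers
assert per_layer[0]["Dense_0"]["kernel"].shape == (128, 512)  # per-layer, un-stacked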