Fix: _unstack_scan_params breaks after flax deserialization (from_bytes returns numpy arrays)
#1
by dignity045 - opened
Bug
generate.py crashes with flax.errors.ScopeParamShapeError when loading an exported model (params.msgpack). The inference path uses _unstack_scan_params in gpt.py to extract per-layer params from the scanned param tree. That function used isinstance(tree, jnp.ndarray) to detect leaf arrays, but after flax.serialization.from_bytes() the params are plain numpy.ndarray instances, not jnp.ndarray. The check therefore fails silently: no leaf is recognized as an array, so nothing gets sliced. The stacked (num_layers, d_model) params are passed to _ref_block.apply() instead of per-layer (d_model,) slices, which triggers the shape error.
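The type mismatch can be seen in isolation with a minimal sketch. It does not call from_bytes; it assumes, as described above, that a deserialized leaf is a plain numpy.ndarray and stands one in with np.ones. The variable names are illustrative and not taken from gpt.py.

```python
import numpy as np
import jax.numpy as jnp

num_layers, d_model = 4, 8

# Stand-in for a leaf as it looks after deserialization (plain NumPy).
numpy_leaf = np.ones((num_layers, d_model), dtype=np.float32)
# The same leaf as a JAX array, as it looks during training.
jax_leaf = jnp.ones((num_layers, d_model), dtype=jnp.float32)

# The old leaf test accepts the JAX array but rejects the NumPy one.
old_check_on_jax = isinstance(jax_leaf, jnp.ndarray)
old_check_on_numpy = isinstance(numpy_leaf, jnp.ndarray)

print(old_check_on_jax, old_check_on_numpy)
```

Because the NumPy leaf is rejected, it is never indexed by layer and keeps its leading num_layers axis.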
Fix
Replaced isinstance(tree, jnp.ndarray) with duck-typing via hasattr(tree, 'ndim') and hasattr(tree, 'shape'), which correctly handles both JAX and NumPy arrays.
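A rough sketch of the duck-typed approach follows. The helper names (is_array_like, unstack_layer) and the tree layout are hypothetical and only illustrate the idea; the actual signature and body of _unstack_scan_params may differ. NumPy is used here to mimic deserialized params, and JAX arrays expose the same ndim and shape attributes.

```python
from collections.abc import Mapping

import numpy as np


def is_array_like(x):
    # Duck-typed leaf test: true for both NumPy and JAX arrays.
    return hasattr(x, "ndim") and hasattr(x, "shape")


def unstack_layer(tree, layer_idx):
    # Hypothetical helper: take one layer's slice from every stacked leaf.
    if is_array_like(tree):
        return tree[layer_idx]
    if isinstance(tree, Mapping):
        return {k: unstack_layer(v, layer_idx) for k, v in tree.items()}
    return tree  # leave non-array, non-mapping values untouched


num_layers, d_model = 4, 8

# Stacked params as plain NumPy arrays, as they would look after loading.
stacked = {
    "ln": {"scale": np.ones((num_layers, d_model), dtype=np.float32)},
    "mlp": {"kernel": np.zeros((num_layers, d_model, 2 * d_model), dtype=np.float32)},
}

layer0 = unstack_layer(stacked, 0)
print(layer0["ln"]["scale"].shape, layer0["mlp"]["kernel"].shape)
```

Checking attributes rather than the concrete type keeps the helper independent of which array library produced the leaves.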
Tested
- Reproduced the exact error with the user's config (d_model=768, num_layers=16, parallel_block=True)
- Verified the fix resolves the error in both the parallel_block=True and parallel_block=False cases
- Full end-to-end export→load→generate cycle passes
dignity045 changed pull request status to open
Dhiraj45 changed pull request status to merged