Fix: _unstack_scan_params breaks after flax deserialization (from_bytes returns numpy arrays)
#1
by dignity045 - opened
Same fix as https://huggingface.co/Dhiraj45/LaughLM/discussions/1
Bug
generate.py crashes with flax.errors.ScopeParamShapeError when loading an exported model (params.msgpack). _unstack_scan_params used isinstance(tree, jnp.ndarray) to detect leaf arrays, but after flax.serialization.from_bytes() the params are plain numpy.ndarray, NOT jnp.ndarray. As a result, stacked (num_layers, d_model) params slip through unsplit.
Fix
Replace isinstance(tree, jnp.ndarray) with duck-typing: hasattr(tree, 'ndim') and hasattr(tree, 'shape'), which handles both JAX and NumPy arrays.
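A minimal sketch of the duck-typed check, assuming the stacked params are a nested dict whose array leaves carry a leading num_layers axis. The helpers is_array_leaf and unstack_tree are hypothetical illustrations, not the actual code in generate.py. Only NumPy is used here, since that is what from_bytes() hands back; JAX arrays also expose .ndim and .shape, so the same check accepts them too.

```python
import numpy as np


def is_array_leaf(x):
    # Duck-typing: numpy.ndarray (returned by flax.serialization.from_bytes)
    # and JAX arrays both expose .ndim and .shape, so either passes.
    return hasattr(x, "ndim") and hasattr(x, "shape")


def unstack_tree(tree, layer):
    """Hypothetical helper: take the slice for one layer from a nested dict
    of params stacked along axis 0, i.e. shaped (num_layers, ...)."""
    if isinstance(tree, dict):
        return {k: unstack_tree(v, layer) for k, v in tree.items()}
    if is_array_leaf(tree) and tree.ndim >= 1:
        # Index the leading num_layers axis.
        return tree[layer]
    # Scalars and non-array values pass through unchanged.
    return tree


# Usage: 4 stacked layers with d_model = 16, as NumPy arrays after deserialization.
stacked = {"mlp": {"kernel": np.arange(64.0).reshape(4, 16)}}
layer2 = unstack_tree(stacked, 2)
print(layer2["mlp"]["kernel"].shape)  # (16,)
```

With the old isinstance(tree, jnp.ndarray) test, the NumPy leaf above would not be recognized as an array and would reach the model still shaped (4, 16), which is the shape mismatch behind the ScopeParamShapeError.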
dignity045 changed pull request status to open
dignity045 changed pull request status to merged