Fix: _unstack_scan_params breaks after flax deserialization (from_bytes returns numpy arrays)

#1
by dignity045 - opened

Same fix as https://huggingface.co/Dhiraj45/LaughLM/discussions/1

Bug

generate.py crashes with flax.errors.ScopeParamShapeError when loading an exported model (params.msgpack). _unstack_scan_params used isinstance(tree, jnp.ndarray) to detect leaf arrays, but after flax.serialization.from_bytes() the params are plain numpy.ndarray, not jnp.ndarray. The check therefore never matches, and stacked (num_layers, d_model) params slip through unsplit.
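The type mismatch can be reproduced without the model. This is a minimal sketch, not code from the repository: it assumes jax and numpy are installed, and it uses np.asarray as a stand-in for the NumPy arrays that deserialization returns. The variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

# A JAX array, as the params look right after model.init().
jax_leaf = jnp.zeros((2, 4))

# Stand-in (assumption) for a deserialized leaf: the same data as a plain NumPy array.
numpy_leaf = np.asarray(jax_leaf)

# The original leaf check recognises only the JAX array.
jax_is_leaf = isinstance(jax_leaf, jnp.ndarray)
numpy_is_leaf = isinstance(numpy_leaf, jnp.ndarray)

print("JAX array detected as leaf:  ", jax_is_leaf)
print("NumPy array detected as leaf:", numpy_is_leaf)
```

When the second check fails, the unstacking logic treats the array as something other than a leaf and leaves its leading layer axis in place.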

Fix

Replace isinstance(tree, jnp.ndarray) with duck-typing: hasattr(tree, 'ndim') and hasattr(tree, 'shape'), which handles both JAX and NumPy arrays.
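For illustration, a duck-typed leaf check could look like the sketch below. The names is_array_leaf and unstack_tree are hypothetical, and the function is not the repository's _unstack_scan_params; it only assumes that params are nested dicts whose array leaves carry a leading num_layers axis. NumPy arrays are used here so the example runs without JAX, but JAX arrays expose the same ndim and shape attributes.

```python
import numpy as np


def is_array_leaf(x):
    """Duck-typed check: true for any object exposing ndim and shape."""
    return hasattr(x, "ndim") and hasattr(x, "shape")


def unstack_tree(tree, num_layers):
    """Split a nested dict of stacked arrays into a list of per-layer dicts.

    Hypothetical helper: every array leaf must have a leading axis of
    length num_layers.
    """
    if is_array_leaf(tree):
        if tree.ndim == 0 or tree.shape[0] != num_layers:
            raise ValueError(f"expected leading axis {num_layers}, got shape {tree.shape}")
        # Indexing along axis 0 yields one slice per layer.
        return [tree[i] for i in range(num_layers)]
    if isinstance(tree, dict):
        # Unstack each child, then regroup the results layer by layer.
        per_key = {k: unstack_tree(v, num_layers) for k, v in tree.items()}
        return [{k: per_key[k][i] for k in per_key} for i in range(num_layers)]
    raise TypeError(f"unsupported node type: {type(tree).__name__}")


# Stacked params as plain NumPy arrays, as they would look after deserialization.
stacked = {
    "attn": {"kernel": np.zeros((3, 4, 4))},
    "norm": {"scale": np.ones((3, 4))},
}
layers = unstack_tree(stacked, num_layers=3)
print(len(layers), layers[0]["norm"]["scale"].shape)
```

The attribute check is deliberately permissive: anything with ndim and shape is treated as a leaf, which is what lets the same code path handle both array types.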

dignity045 changed pull request status to open
dignity045 changed pull request status to merged
