"""Test weight loader key coverage against model.safetensors.index.json. Walks the index file and verifies every HF key has a mapping — does NOT actually load safetensor blobs (so this runs before the full download finishes). """ import json from collections import Counter from pathlib import Path from llada2.model import LLaDA2Config, LLaDA2Model from llada2.weights import _route_weight def main(): snap = Path( "/Users/ritesh/.cache/huggingface/hub/models--inclusionAI--LLaDA2.0-Uni/" "snapshots/f94fc0230089a283cc6bd1dd6542c24d7e7489b2" ) idx = json.loads((snap / "model.safetensors.index.json").read_text()) cfg = LLaDA2Config.from_hf(json.loads((snap / "config.json").read_text())) params: dict = {} expert_slots: dict = {} import mlx.core as mx dummy = mx.zeros((1,)) # placeholder; we only care about keys routed = 0 errors = [] for hf_key in idx["weight_map"].keys(): try: _route_weight(hf_key, dummy, params, expert_slots, cfg) routed += 1 except Exception as e: errors.append((hf_key, str(e))) print(f"routed ok: {routed} / {len(idx['weight_map'])}") if errors: print("first 10 errors:") for k, e in errors[:10]: print(f" {k}: {e}") else: # Summarize expert coverage per layer per_layer = Counter() for (li, ei) in expert_slots.keys(): per_layer[li] += 1 for li in sorted(per_layer): print(f" layer {li}: {per_layer[li]} experts routed") # Build model and cross-check that the param names we produced exist model = LLaDA2Model(cfg) from mlx.utils import tree_flatten expected = {k for k, _ in tree_flatten(model.parameters())} produced = set(params.keys()) | { f"model.layers.{li}.mlp.experts_gate_w" for li in range(cfg.num_hidden_layers) if li >= cfg.first_k_dense_replace } | { f"model.layers.{li}.mlp.experts_up_w" for li in range(cfg.num_hidden_layers) if li >= cfg.first_k_dense_replace } | { f"model.layers.{li}.mlp.experts_down_w" for li in range(cfg.num_hidden_layers) if li >= cfg.first_k_dense_replace } missing_from_expected = produced - expected missing_from_produced = expected - produced print(f"\nparams in model but not produced: {len(missing_from_produced)}") for k in sorted(missing_from_produced)[:20]: print(f" MISSING FROM LOADER: {k}") print(f"\nparams produced but not in model: {len(missing_from_expected)}") for k in sorted(missing_from_expected)[:20]: print(f" EXTRA: {k}") if __name__ == "__main__": main()