# mlx-llada2-uni / test_weights.py
# Author: treadon
# Commit: "Upload test_weights.py with huggingface_hub" (ccec450, verified)
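"""Test weight-loader key coverage against model.safetensors.index.json.

Walks the index file and verifies that every HF key has a mapping, then
cross-checks the produced parameter names against the model's parameter tree.
It does NOT load the safetensor blobs (a dummy tensor stands in for each
weight), so it can run before the full download finishes: only the index
file and config.json are needed.
"""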
import json
from collections import Counter
from pathlib import Path
from llada2.model import LLaDA2Config, LLaDA2Model
from llada2.weights import _route_weight
# Snapshot directory used when neither an explicit path nor $LLADA2_SNAPSHOT
# is given (the HF cache path on the original author's machine).
_DEFAULT_SNAPSHOT = (
    "/Users/ritesh/.cache/huggingface/hub/models--inclusionAI--LLaDA2.0-Uni/"
    "snapshots/f94fc0230089a283cc6bd1dd6542c24d7e7489b2"
)

# Stacked per-layer expert parameter names. They are not read back from
# `params` here; presumably the loader stacks them from `expert_slots` after
# routing (TODO confirm), so the cross-check adds them explicitly for every
# layer at or after cfg.first_k_dense_replace.
_STACKED_EXPERT_SUFFIXES = ("experts_gate_w", "experts_up_w", "experts_down_w")


def _resolve_snapshot(snapshot=None):
    """Return the snapshot dir: explicit arg, else $LLADA2_SNAPSHOT, else default."""
    import os

    if snapshot is None:
        # `or` so that an empty environment variable falls back to the default
        # instead of silently becoming the current directory.
        snapshot = os.environ.get("LLADA2_SNAPSHOT") or _DEFAULT_SNAPSHOT
    return Path(snapshot).expanduser()


def _stacked_expert_names(cfg):
    """Return the stacked expert parameter names for layers >= first_k_dense_replace."""
    return {
        f"model.layers.{li}.mlp.{suffix}"
        for li in range(cfg.first_k_dense_replace, cfg.num_hidden_layers)
        for suffix in _STACKED_EXPERT_SUFFIXES
    }


def main(snapshot=None):
    """Check that every tensor in the checkpoint index routes to a model parameter.

    Reads ``model.safetensors.index.json`` and ``config.json`` from a local HF
    snapshot and runs each indexed key through ``_route_weight`` with a dummy
    tensor (no safetensor blobs are loaded). It then prints one of:

    * the routed count and the first 10 routing errors, if any key failed; or
    * the number of expert slots routed per layer, followed by a diff between
      the produced parameter names and ``LLaDA2Model(cfg)``'s flattened
      parameter tree (at most 20 entries in each direction).

    Args:
        snapshot: Snapshot directory. When None, the ``LLADA2_SNAPSHOT``
            environment variable is used if set and non-empty; otherwise the
            original hard-coded cache path.

    Returns:
        0 when every key routed and the produced and expected name sets match
        exactly, otherwise 1.
    """
    import mlx.core as mx
    from mlx.utils import tree_flatten

    snap = _resolve_snapshot(snapshot)
    weight_map = json.loads(
        (snap / "model.safetensors.index.json").read_text(encoding="utf-8")
    )["weight_map"]
    cfg = LLaDA2Config.from_hf(
        json.loads((snap / "config.json").read_text(encoding="utf-8"))
    )

    params: dict = {}
    expert_slots: dict = {}
    dummy = mx.zeros((1,))  # placeholder; we only care about keys
    errors = []
    for hf_key in weight_map:
        try:
            _route_weight(hf_key, dummy, params, expert_slots, cfg)
        except Exception as e:  # deliberate: collect every failure, don't stop at the first
            errors.append((hf_key, str(e)))
    routed = len(weight_map) - len(errors)  # dict keys are unique

    print(f"routed ok: {routed} / {len(weight_map)}")
    if errors:
        print("first 10 errors:")
        for k, e in errors[:10]:
            print(f" {k}: {e}")
        return 1

    # Summarize expert coverage per layer; expert_slots keys unpack as
    # (layer_idx, expert_idx) pairs.
    per_layer = Counter(li for li, _ in expert_slots)
    for li in sorted(per_layer):
        print(f" layer {li}: {per_layer[li]} experts routed")

    # Build the model and cross-check that the param names we produced exist.
    model = LLaDA2Model(cfg)
    expected = {k for k, _ in tree_flatten(model.parameters())}
    produced = set(params) | _stacked_expert_names(cfg)
    not_produced = expected - produced  # model params the loader never fills
    extra = produced - expected  # loader outputs with no matching model param

    print(f"\nparams in model but not produced: {len(not_produced)}")
    for k in sorted(not_produced)[:20]:
        print(f" MISSING FROM LOADER: {k}")
    print(f"\nparams produced but not in model: {len(extra)}")
    for k in sorted(extra)[:20]:
        print(f" EXTRA: {k}")
    return 1 if not_produced or extra else 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status so shell/CI callers
    # can detect a failed coverage check; a None return still exits with 0.
    raise SystemExit(main())