File size: 2,697 Bytes
ccec450
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
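"""Test weight loader key coverage against model.safetensors.index.json.

Walks the index file and verifies that every Hugging Face checkpoint key can be
routed by the weight loader, then cross-checks the produced parameter names
against a freshly built LLaDA2Model. It does NOT load any safetensors shards,
so it can run before the full download finishes.
"""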
import json
import os
from collections import Counter
from pathlib import Path

from llada2.model import LLaDA2Config, LLaDA2Model
from llada2.weights import _route_weight


def _default_snapshot_dir():
    """Return the pinned LLaDA2.0-Uni snapshot directory in the local Hugging Face cache.

    The cache root is ``$HF_HUB_CACHE`` if set, else ``$HF_HOME/hub``, else
    ``~/.cache/huggingface/hub``. This replaces a hard-coded, user-specific path.
    """
    hub_cache = os.environ.get("HF_HUB_CACHE")
    if not hub_cache:
        hf_home = os.environ.get("HF_HOME") or str(Path.home() / ".cache" / "huggingface")
        hub_cache = str(Path(hf_home).expanduser() / "hub")
    return (
        Path(hub_cache).expanduser()
        / "models--inclusionAI--LLaDA2.0-Uni"
        / "snapshots"
        / "f94fc0230089a283cc6bd1dd6542c24d7e7489b2"
    )


def _route_all_keys(weight_map, cfg, placeholder):
    """Route every key of ``weight_map`` through ``_route_weight``.

    Returns ``(params, expert_slots, routed_count, errors)``. ``errors`` is a list
    of ``(hf_key, message)`` pairs, one for each key whose routing raised.
    """
    params: dict = {}
    expert_slots: dict = {}
    routed = 0
    errors = []
    for hf_key in weight_map:
        try:
            _route_weight(hf_key, placeholder, params, expert_slots, cfg)
        except Exception as exc:  # diagnostic tool: collect every failure, don't stop at the first
            errors.append((hf_key, f"{type(exc).__name__}: {exc}"))
        else:
            routed += 1
    return params, expert_slots, routed, errors


def _stacked_expert_param_names(cfg):
    """Return the stacked MoE expert parameter names expected on every non-dense layer.

    These names are not emitted into ``params`` by the routing step (presumably
    they are assembled from ``expert_slots`` elsewhere), so they are added here
    for every layer index >= ``cfg.first_k_dense_replace``.
    """
    return {
        f"model.layers.{li}.mlp.experts_{proj}_w"
        for li in range(cfg.num_hidden_layers)
        if li >= cfg.first_k_dense_replace
        for proj in ("gate", "up", "down")
    }


def main(snapshot=None):
    """Check that the weight loader can route every key in the checkpoint index.

    Each key of ``model.safetensors.index.json`` is routed through ``_route_weight``
    with a placeholder tensor, so no safetensors data is read. The produced
    parameter names are then compared with those of a freshly built ``LLaDA2Model``.

    Args:
        snapshot: Snapshot directory containing ``config.json`` and
            ``model.safetensors.index.json``. Defaults to the pinned revision in
            the local Hugging Face cache.

    Returns:
        0 when every key routed and the name sets match exactly, 1 otherwise.

    Raises:
        FileNotFoundError: if the index file is not present in the snapshot dir.
    """
    snap = Path(snapshot).expanduser() if snapshot is not None else _default_snapshot_dir()
    index_path = snap / "model.safetensors.index.json"
    if not index_path.is_file():
        raise FileNotFoundError(
            f"{index_path} not found; pass the snapshot directory explicitly "
            "or set HF_HUB_CACHE / HF_HOME"
        )
    idx = json.loads(index_path.read_text(encoding="utf-8"))
    cfg = LLaDA2Config.from_hf(
        json.loads((snap / "config.json").read_text(encoding="utf-8"))
    )
    weight_map = idx["weight_map"]

    # mlx is imported lazily, so the path checks above run without touching it.
    import mlx.core as mx
    from mlx.utils import tree_flatten

    # Only the key names matter here; every key is routed with the same tiny placeholder.
    dummy = mx.zeros((1,))
    params, expert_slots, routed, errors = _route_all_keys(weight_map, cfg, dummy)

    print(f"routed ok: {routed} / {len(weight_map)}")
    if errors:
        print("first 10 errors:")
        for k, e in errors[:10]:
            print(f"  {k}: {e}")
    else:
        # expert_slots keys are unpacked as (layer, expert) pairs; count experts per layer.
        per_layer = Counter(li for li, _ in expert_slots)
        for li in sorted(per_layer):
            print(f"  layer {li}: {per_layer[li]} experts routed")

    # Build the model and cross-check that the param names we produced exist on it.
    model = LLaDA2Model(cfg)
    expected = {k for k, _ in tree_flatten(model.parameters())}
    produced = set(params) | _stacked_expert_param_names(cfg)
    extra = produced - expected      # produced by the loader but absent from the model
    missing = expected - produced    # present on the model but never produced
    print(f"\nparams in model but not produced: {len(missing)}")
    for k in sorted(missing)[:20]:
        print(f"  MISSING FROM LOADER: {k}")
    print(f"\nparams produced but not in model: {len(extra)}")
    for k in sorted(extra)[:20]:
        print(f"  EXTRA: {k}")

    return 1 if errors or missing or extra else 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status so CI can detect a
    # failed coverage check. A None return still exits with status 0.
    raise SystemExit(main())