"""Load HF safetensors into the MLX LLaDA2Model. HF uses ModuleList for experts (experts.0, experts.1, ...). We pack them into 3 big arrays per layer: experts_gate_w, experts_up_w, experts_down_w. All other weight names map directly. """ from __future__ import annotations import json from pathlib import Path from typing import Dict import mlx.core as mx from mlx.utils import tree_unflatten def _bf16(a: mx.array) -> mx.array: return a.astype(mx.bfloat16) def load_weights_into_model(model, snapshot_dir: Path, dtype=mx.bfloat16, verbose: bool = True): """Load shards from `snapshot_dir` into `model` (an LLaDA2Model). Expects model.safetensors.index.json and model-*.safetensors in the dir. Converts HF key layout → our MLX key layout. """ snapshot_dir = Path(snapshot_dir) index_path = snapshot_dir / "model.safetensors.index.json" index = json.loads(index_path.read_text()) weight_map: Dict[str, str] = index["weight_map"] config = model.config # Group keys by shard file to open each once shards: dict[str, list[str]] = {} for k, f in weight_map.items(): shards.setdefault(f, []).append(k) # Collect (layer_idx, expert_idx) → (gate_w, up_w, down_w) in temp storage expert_slots: dict[tuple[int, int], dict[str, mx.array]] = {} params: dict[str, mx.array] = {} for shard_idx, (shard_file, keys) in enumerate(shards.items()): if verbose: print(f"[weights] opening {shard_file} ({len(keys)} keys) [{shard_idx+1}/{len(shards)}]") shard_path = snapshot_dir / shard_file tensors = mx.load(str(shard_path)) # dict[str, mx.array] for hf_key in keys: t = tensors[hf_key].astype(dtype) _route_weight(hf_key, t, params, expert_slots, config) # release the shard dict so its per-expert refs only live in expert_slots del tensors # Stack experts into packed arrays (memory-conscious: pop/delete as we go) for layer_idx in range(config.num_hidden_layers): if layer_idx < config.first_k_dense_replace: continue gates, ups, downs = [], [], [] for e in range(config.num_experts): slot = expert_slots.pop((layer_idx, e), None) if slot is None: raise ValueError(f"Missing expert ({layer_idx}, {e})") gates.append(slot["gate"]) ups.append(slot["up"]) downs.append(slot["down"]) prefix = f"model.layers.{layer_idx}.mlp" params[f"{prefix}.experts_gate_w"] = mx.stack(gates, axis=0) del gates params[f"{prefix}.experts_up_w"] = mx.stack(ups, axis=0) del ups params[f"{prefix}.experts_down_w"] = mx.stack(downs, axis=0) del downs mx.eval(params[f"{prefix}.experts_gate_w"], params[f"{prefix}.experts_up_w"], params[f"{prefix}.experts_down_w"]) # Load into model leaves = [(k, v) for k, v in params.items()] model.update(tree_unflatten(leaves)) mx.eval(model.parameters()) if verbose: print(f"[weights] loaded {len(leaves)} tensors") return model def _route_weight(hf_key: str, tensor: mx.array, params: dict, expert_slots: dict, config): """Map a single HF key -> our MLX key (and packing for experts).""" # Examples of HF keys: # model.word_embeddings.weight # model.layers.3.attention.query_key_value.weight # model.layers.3.attention.dense.weight # model.layers.3.attention.query_layernorm.weight # model.layers.3.attention.key_layernorm.weight # model.layers.3.input_layernorm.weight # model.layers.3.post_attention_layernorm.weight # model.layers.0.mlp.{gate,up,down}_proj.weight (dense layer) # model.layers.3.mlp.experts.{e}.{gate,up,down}_proj.weight (routed expert) # model.layers.3.mlp.gate.weight # model.layers.3.mlp.gate.expert_bias # model.layers.3.mlp.shared_experts.{gate,up,down}_proj.weight # model.norm.weight # lm_head.weight # Our LLaDA2Model stores the backbone as `self.model` so flattened param keys # look like `model.word_embeddings.weight`, `model.layers.{i}.*`. The HF checkpoint # also uses `model.*` at the top level, so the prefix passes through unchanged. if hf_key == "model.word_embeddings.weight": params["model.word_embeddings.weight"] = tensor return if hf_key == "model.norm.weight": params["model.norm.weight"] = tensor return if hf_key == "lm_head.weight": params["lm_head.weight"] = tensor return if not hf_key.startswith("model.layers."): raise ValueError(f"Unknown HF key prefix: {hf_key}") parts = hf_key.split(".") layer_idx = int(parts[2]) rest = parts[3:] prefix = f"model.layers.{layer_idx}" # Per-layer norms if rest == ["input_layernorm", "weight"]: params[f"{prefix}.input_layernorm.weight"] = tensor return if rest == ["post_attention_layernorm", "weight"]: params[f"{prefix}.post_attention_layernorm.weight"] = tensor return # Attention if rest[:1] == ["attention"]: sub = rest[1:] if sub == ["query_key_value", "weight"]: params[f"{prefix}.attention.query_key_value.weight"] = tensor return if sub == ["dense", "weight"]: params[f"{prefix}.attention.dense.weight"] = tensor return if sub == ["query_layernorm", "weight"]: params[f"{prefix}.attention.query_layernorm.weight"] = tensor return if sub == ["key_layernorm", "weight"]: params[f"{prefix}.attention.key_layernorm.weight"] = tensor return raise ValueError(f"Unknown attention subkey: {hf_key}") # MLP if rest[:1] == ["mlp"]: sub = rest[1:] if layer_idx < config.first_k_dense_replace: # Dense MLP if sub == ["gate_proj", "weight"]: params[f"{prefix}.mlp.gate_proj.weight"] = tensor return if sub == ["up_proj", "weight"]: params[f"{prefix}.mlp.up_proj.weight"] = tensor return if sub == ["down_proj", "weight"]: params[f"{prefix}.mlp.down_proj.weight"] = tensor return raise ValueError(f"Unknown dense-mlp subkey: {hf_key}") # MoE layer if sub[:1] == ["gate"]: if sub == ["gate", "weight"]: params[f"{prefix}.mlp.gate.weight"] = tensor return if sub == ["gate", "expert_bias"]: params[f"{prefix}.mlp.gate.expert_bias"] = tensor return raise ValueError(f"Unknown gate subkey: {hf_key}") if sub[:1] == ["shared_experts"]: which = sub[1] if which == "gate_proj": params[f"{prefix}.mlp.shared_gate_proj.weight"] = tensor; return if which == "up_proj": params[f"{prefix}.mlp.shared_up_proj.weight"] = tensor; return if which == "down_proj": params[f"{prefix}.mlp.shared_down_proj.weight"] = tensor; return raise ValueError(f"Unknown shared_experts subkey: {hf_key}") if sub[:1] == ["experts"]: # experts.{e}.{gate|up|down}_proj.weight expert_idx = int(sub[1]) which = sub[2] slot = expert_slots.setdefault((layer_idx, expert_idx), {}) if which == "gate_proj": slot["gate"] = tensor; return if which == "up_proj": slot["up"] = tensor; return if which == "down_proj": slot["down"] = tensor; return raise ValueError(f"Unknown expert subkey: {hf_key}") raise ValueError(f"Unknown mlp subkey: {hf_key}") raise ValueError(f"Unhandled key: {hf_key}")