# mlx-llada2-uni: llada2/weights.py
# Uploaded by treadon with huggingface_hub (commit 56fe8b3, verified).
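"""Load HF safetensors into the MLX LLaDA2Model.

HF stores routed experts as a ModuleList (experts.0, experts.1, ...). We stack
them along a new leading expert axis into 3 arrays per MoE layer:
experts_gate_w, experts_up_w, experts_down_w. Shared-expert projections are
renamed (mlp.shared_experts.{gate,up,down}_proj -> mlp.shared_{gate,up,down}_proj);
all other weight names map directly.
"""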
from __future__ import annotations
import json
from pathlib import Path
from typing import Dict
import mlx.core as mx
from mlx.utils import tree_unflatten
def _bf16(a: mx.array) -> mx.array:
    """Return ``a`` converted to bfloat16 via ``astype``."""
    target_dtype = mx.bfloat16
    return a.astype(target_dtype)
def load_weights_into_model(model, snapshot_dir: Path, dtype=mx.bfloat16, verbose: bool = True):
    """Load HF safetensors shards from ``snapshot_dir`` into ``model`` (an LLaDA2Model).

    Expects ``model.safetensors.index.json`` and the shard files its
    ``weight_map`` references to live in ``snapshot_dir``. Each tensor is cast
    to ``dtype`` and mapped from the HF key layout to our MLX key layout by
    ``_route_weight``. The routed experts of every MoE layer are then stacked
    into ``experts_{gate,up,down}_w`` arrays.

    Args:
        model: Target model. Must expose ``config`` (read for layer/expert
            counts) plus ``update`` and ``parameters``.
        snapshot_dir: Directory containing the index and shard files.
        dtype: dtype applied to every loaded tensor.
        verbose: Print per-shard progress and the final tensor count.

    Returns:
        ``model``, updated in place.

    Raises:
        ValueError: If a key is not recognised, an expert (or one of its
            gate/up/down projections) is missing, or the checkpoint holds
            expert weights that no MoE layer in ``config`` consumes.
    """
    snapshot_dir = Path(snapshot_dir)
    index_path = snapshot_dir / "model.safetensors.index.json"
    index = json.loads(index_path.read_text(encoding="utf-8"))
    weight_map: Dict[str, str] = index["weight_map"]
    config = model.config

    # Group keys by shard file so each shard is opened exactly once.
    shards: dict[str, list[str]] = {}
    for key, shard_file in weight_map.items():
        shards.setdefault(shard_file, []).append(key)

    # (layer_idx, expert_idx) -> {"gate"|"up"|"down": tensor}; packed after all shards are read,
    # because one expert's projections may be split across shards.
    expert_slots: dict[tuple[int, int], dict[str, mx.array]] = {}
    params: dict[str, mx.array] = {}
    for shard_idx, (shard_file, keys) in enumerate(shards.items()):
        if verbose:
            print(f"[weights] opening {shard_file} ({len(keys)} keys) [{shard_idx+1}/{len(shards)}]")
        tensors = mx.load(str(snapshot_dir / shard_file))  # dict[str, mx.array]
        for hf_key in keys:
            # NOTE(review): norms and gate.expert_bias are cast to `dtype` too;
            # confirm that precision is acceptable for them.
            _route_weight(hf_key, tensors[hf_key].astype(dtype), params, expert_slots, config)
        # Release the shard dict so per-expert refs only live in expert_slots.
        del tensors

    _pack_experts(params, expert_slots, config)

    leaves = list(params.items())
    model.update(tree_unflatten(leaves))
    mx.eval(model.parameters())
    if verbose:
        print(f"[weights] loaded {len(leaves)} tensors")
    return model


def _pack_experts(params: dict, expert_slots: dict, config) -> None:
    """Stack the per-expert weights of every MoE layer into packed arrays in ``params``.

    Layers with index < ``config.first_k_dense_replace`` are dense and skipped.
    For each MoE layer, entries are popped from ``expert_slots`` and stacked
    along axis 0 into ``model.layers.{i}.mlp.experts_{gate,up,down}_w``.

    Raises:
        ValueError: If an expert or one of its projections is missing, or if
            ``expert_slots`` still holds entries once all MoE layers are packed.
    """
    for layer_idx in range(config.num_hidden_layers):
        if layer_idx < config.first_k_dense_replace:
            continue
        per_proj: dict[str, list] = {"gate": [], "up": [], "down": []}
        for e in range(config.num_experts):
            slot = expert_slots.pop((layer_idx, e), None)
            if slot is None:
                raise ValueError(f"Missing expert ({layer_idx}, {e})")
            for proj, bucket in per_proj.items():
                # Report an incomplete expert explicitly instead of a bare KeyError.
                if proj not in slot:
                    raise ValueError(f"Missing {proj}_proj for expert ({layer_idx}, {e})")
                bucket.append(slot[proj])

        prefix = f"model.layers.{layer_idx}.mlp"
        packed = []
        for proj, bucket in per_proj.items():
            stacked = mx.stack(bucket, axis=0)
            params[f"{prefix}.experts_{proj}_w"] = stacked
            bucket.clear()  # drop the per-expert list refs once stacked
            packed.append(stacked)
        # Evaluate per layer so stacking work is not deferred to the final model-wide eval.
        mx.eval(*packed)

    # Anything left over would otherwise be silently dropped (config/checkpoint mismatch).
    if expert_slots:
        sample = sorted(expert_slots)[:5]
        raise ValueError(
            f"{len(expert_slots)} expert slot(s) not consumed by any MoE layer "
            f"(e.g. {sample}); check num_experts / num_hidden_layers / first_k_dense_replace"
        )
# Sub-modules under `attention.` whose `.weight` passes through unchanged.
_ATTN_WEIGHT_NAMES = frozenset({"query_key_value", "dense", "query_layernorm", "key_layernorm"})
# HF projection module name -> key used in the per-expert slot dicts.
_PROJ_TO_SLOT = {"gate_proj": "gate", "up_proj": "up", "down_proj": "down"}


def _route_weight(hf_key: str, tensor: mx.array, params: dict, expert_slots: dict, config):
    """Map one HF checkpoint key to our MLX key and store ``tensor`` accordingly.

    Most keys are stored in ``params`` under the same name. Two groups differ:

    * ``model.layers.{i}.mlp.shared_experts.{gate,up,down}_proj.weight`` is
      stored as ``model.layers.{i}.mlp.shared_{gate,up,down}_proj.weight``.
    * ``model.layers.{i}.mlp.experts.{e}.{gate,up,down}_proj.weight`` is not put
      in ``params``. It goes to ``expert_slots[(i, e)]["gate"|"up"|"down"]`` so
      the caller can stack experts afterwards.

    Layers with index < ``config.first_k_dense_replace`` are dense. Under
    ``mlp`` they accept only ``{gate,up,down}_proj.weight``.

    Every sub-key must match exactly, including the trailing ``weight``
    component. For example, an ``experts.{e}.gate_proj.bias`` tensor is
    rejected rather than silently overwriting that expert's gate weight.

    Raises:
        ValueError: For any key outside the layout shown below, or a
            non-integer layer/expert index.
    """
    # Examples of HF keys:
    #   model.word_embeddings.weight
    #   model.layers.3.attention.query_key_value.weight
    #   model.layers.3.attention.dense.weight
    #   model.layers.3.attention.query_layernorm.weight
    #   model.layers.3.attention.key_layernorm.weight
    #   model.layers.3.input_layernorm.weight
    #   model.layers.3.post_attention_layernorm.weight
    #   model.layers.0.mlp.{gate,up,down}_proj.weight            (dense layer)
    #   model.layers.3.mlp.experts.{e}.{gate,up,down}_proj.weight (routed expert)
    #   model.layers.3.mlp.gate.weight
    #   model.layers.3.mlp.gate.expert_bias
    #   model.layers.3.mlp.shared_experts.{gate,up,down}_proj.weight
    #   model.norm.weight
    #   lm_head.weight
    # Our LLaDA2Model stores the backbone as `self.model`, so flattened param keys
    # look like `model.word_embeddings.weight`, `model.layers.{i}.*`. The HF checkpoint
    # also uses `model.*` at the top level, so the prefix passes through unchanged.
    if hf_key in ("model.word_embeddings.weight", "model.norm.weight", "lm_head.weight"):
        params[hf_key] = tensor
        return
    if not hf_key.startswith("model.layers."):
        raise ValueError(f"Unknown HF key prefix: {hf_key}")

    parts = hf_key.split(".")
    layer_idx = int(parts[2])
    rest = tuple(parts[3:])
    prefix = f"model.layers.{layer_idx}"

    # Per-layer norms
    if rest in (("input_layernorm", "weight"), ("post_attention_layernorm", "weight")):
        params[f"{prefix}.{rest[0]}.weight"] = tensor
        return

    # Attention: every known sub-module's weight passes through unchanged.
    if rest[:1] == ("attention",):
        sub = rest[1:]
        if len(sub) == 2 and sub[0] in _ATTN_WEIGHT_NAMES and sub[1] == "weight":
            params[f"{prefix}.attention.{sub[0]}.weight"] = tensor
            return
        raise ValueError(f"Unknown attention subkey: {hf_key}")

    if rest[:1] != ("mlp",):
        raise ValueError(f"Unhandled key: {hf_key}")
    sub = rest[1:]

    # Dense MLP (layers before first_k_dense_replace): names pass through.
    if layer_idx < config.first_k_dense_replace:
        if len(sub) == 2 and sub[0] in _PROJ_TO_SLOT and sub[1] == "weight":
            params[f"{prefix}.mlp.{sub[0]}.weight"] = tensor
            return
        raise ValueError(f"Unknown dense-mlp subkey: {hf_key}")

    # MoE layer
    head = sub[:1]
    if head == ("gate",):
        if sub in (("gate", "weight"), ("gate", "expert_bias")):
            params[f"{prefix}.mlp.gate.{sub[1]}"] = tensor
            return
        raise ValueError(f"Unknown gate subkey: {hf_key}")
    if head == ("shared_experts",):
        # shared_experts.{gate|up|down}_proj.weight -> shared_{gate|up|down}_proj.weight
        if len(sub) == 3 and sub[1] in _PROJ_TO_SLOT and sub[2] == "weight":
            params[f"{prefix}.mlp.shared_{sub[1]}.weight"] = tensor
            return
        raise ValueError(f"Unknown shared_experts subkey: {hf_key}")
    if head == ("experts",):
        # experts.{e}.{gate|up|down}_proj.weight -> expert_slots[(layer, e)][gate|up|down]
        if len(sub) == 4 and sub[2] in _PROJ_TO_SLOT and sub[3] == "weight":
            expert_idx = int(sub[1])
            slot = expert_slots.setdefault((layer_idx, expert_idx), {})
            slot[_PROJ_TO_SLOT[sub[2]]] = tensor
            return
        raise ValueError(f"Unknown expert subkey: {hf_key}")
    raise ValueError(f"Unknown mlp subkey: {hf_key}")