Any-to-Any
MLX
diffusion-lm
mixture-of-experts
multimodal
text-to-image
image-understanding
apple-silicon
llada
Instructions to use treadon/mlx-llada2-uni with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use treadon/mlx-llada2-uni with MLX:
# Download the model from the Hub
pip install "huggingface_hub[hf_xet]"
huggingface-cli download --local-dir mlx-llada2-uni treadon/mlx-llada2-uni
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- LM Studio
| """Load HF safetensors into the MLX LLaDA2Model. | |
| HF uses ModuleList for experts (experts.0, experts.1, ...). We pack them into | |
| 3 big arrays per layer: experts_gate_w, experts_up_w, experts_down_w. | |
| All other weight names map directly. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| from pathlib import Path | |
| from typing import Dict | |
| import mlx.core as mx | |
| from mlx.utils import tree_unflatten | |
def _bf16(a: mx.array) -> mx.array:
    """Return a copy of `a` converted to the bfloat16 dtype."""
    target_dtype = mx.bfloat16
    return a.astype(target_dtype)
def load_weights_into_model(model, snapshot_dir: Path, dtype=mx.bfloat16, verbose: bool = True):
    """Load sharded HF safetensors from `snapshot_dir` into `model` (an LLaDA2Model).

    Expects `model.safetensors.index.json` in `snapshot_dir`, plus every shard file
    its `weight_map` references. Each tensor is cast to `dtype` and mapped onto the
    MLX key layout by `_route_weight`. For every MoE layer, the per-expert
    projections are then stacked into `experts_{gate,up,down}_w`.

    Args:
        model: Target model. Must expose `.config` (with `num_hidden_layers`,
            `first_k_dense_replace`, `num_experts`), `.update()` and `.parameters()`.
        snapshot_dir: Directory holding the index JSON and the shard files.
        dtype: dtype every loaded tensor is cast to.
        verbose: Print per-shard progress and the final tensor count.

    Returns:
        The same `model`, updated in place.

    Raises:
        ValueError: On an unrecognised HF key (from `_route_weight`), a missing
            expert or expert projection in a MoE layer, or expert weights that
            the config's layer/expert counts do not account for.
    """
    snapshot_dir = Path(snapshot_dir)
    index = json.loads((snapshot_dir / "model.safetensors.index.json").read_text())
    weight_map: Dict[str, str] = index["weight_map"]
    config = model.config

    # Group keys by shard file so each shard is opened exactly once.
    shards: dict[str, list[str]] = {}
    for key, shard_file in weight_map.items():
        shards.setdefault(shard_file, []).append(key)

    # Routed-expert tensors are parked here, keyed by (layer_idx, expert_idx),
    # until all shards are read; every other tensor goes straight into `params`.
    expert_slots: dict[tuple[int, int], dict[str, mx.array]] = {}
    params: dict[str, mx.array] = {}
    for shard_idx, (shard_file, keys) in enumerate(shards.items()):
        if verbose:
            print(f"[weights] opening {shard_file} ({len(keys)} keys) [{shard_idx+1}/{len(shards)}]")
        tensors = mx.load(str(snapshot_dir / shard_file))  # dict[str, mx.array]
        for hf_key in keys:
            _route_weight(hf_key, tensors[hf_key].astype(dtype), params, expert_slots, config)
        # Release the shard dict so the routed tensors are the only remaining references.
        del tensors

    # Layers below first_k_dense_replace use a dense MLP and carry no routed experts.
    for layer_idx in range(config.num_hidden_layers):
        if layer_idx < config.first_k_dense_replace:
            continue
        _pack_layer_experts(layer_idx, config.num_experts, expert_slots, params)

    # Anything still parked was not claimed by any (layer, expert) pair the config
    # describes, e.g. an expert index >= num_experts: checkpoint and config disagree.
    if expert_slots:
        leftover = sorted(expert_slots)
        more = " ..." if len(leftover) > 8 else ""
        raise ValueError(
            f"Checkpoint has expert weights not covered by the config "
            f"(num_hidden_layers={config.num_hidden_layers}, "
            f"num_experts={config.num_experts}): {leftover[:8]}{more}"
        )

    leaves = list(params.items())
    model.update(tree_unflatten(leaves))
    mx.eval(model.parameters())
    if verbose:
        print(f"[weights] loaded {len(leaves)} tensors")
    return model


def _pack_layer_experts(layer_idx: int, num_experts: int, expert_slots: dict, params: dict) -> None:
    """Stack one MoE layer's per-expert gate/up/down weights into packed arrays in `params`.

    Pops entries `(layer_idx, 0..num_experts-1)` from `expert_slots`. Writes
    `model.layers.{layer_idx}.mlp.experts_{gate,up,down}_w`, each stacked along a
    new leading axis (axis 0) in expert-index order.

    Raises:
        ValueError: If an expert slot, or one of its three projections, is missing.
    """
    slots = []
    for e in range(num_experts):
        # pop() so consumed entries leave expert_slots; the caller treats leftovers as an error.
        slot = expert_slots.pop((layer_idx, e), None)
        if slot is None:
            raise ValueError(f"Missing expert ({layer_idx}, {e})")
        missing = [name for name in ("gate", "up", "down") if name not in slot]
        if missing:
            raise ValueError(f"Expert ({layer_idx}, {e}) is missing projection(s): {missing}")
        slots.append(slot)

    prefix = f"model.layers.{layer_idx}.mlp"
    packed = []
    for name in ("gate", "up", "down"):
        key = f"{prefix}.experts_{name}_w"
        params[key] = mx.stack([slot[name] for slot in slots], axis=0)
        packed.append(params[key])
    del slots
    # Evaluate eagerly, layer by layer, so MLX's lazy graph does not keep every
    # per-expert input alive until the final model-wide eval.
    mx.eval(*packed)
| def _route_weight(hf_key: str, tensor: mx.array, params: dict, expert_slots: dict, config): | |
| """Map a single HF key -> our MLX key (and packing for experts).""" | |
| # Examples of HF keys: | |
| # model.word_embeddings.weight | |
| # model.layers.3.attention.query_key_value.weight | |
| # model.layers.3.attention.dense.weight | |
| # model.layers.3.attention.query_layernorm.weight | |
| # model.layers.3.attention.key_layernorm.weight | |
| # model.layers.3.input_layernorm.weight | |
| # model.layers.3.post_attention_layernorm.weight | |
| # model.layers.0.mlp.{gate,up,down}_proj.weight (dense layer) | |
| # model.layers.3.mlp.experts.{e}.{gate,up,down}_proj.weight (routed expert) | |
| # model.layers.3.mlp.gate.weight | |
| # model.layers.3.mlp.gate.expert_bias | |
| # model.layers.3.mlp.shared_experts.{gate,up,down}_proj.weight | |
| # model.norm.weight | |
| # lm_head.weight | |
| # Our LLaDA2Model stores the backbone as `self.model` so flattened param keys | |
| # look like `model.word_embeddings.weight`, `model.layers.{i}.*`. The HF checkpoint | |
| # also uses `model.*` at the top level, so the prefix passes through unchanged. | |
| if hf_key == "model.word_embeddings.weight": | |
| params["model.word_embeddings.weight"] = tensor | |
| return | |
| if hf_key == "model.norm.weight": | |
| params["model.norm.weight"] = tensor | |
| return | |
| if hf_key == "lm_head.weight": | |
| params["lm_head.weight"] = tensor | |
| return | |
| if not hf_key.startswith("model.layers."): | |
| raise ValueError(f"Unknown HF key prefix: {hf_key}") | |
| parts = hf_key.split(".") | |
| layer_idx = int(parts[2]) | |
| rest = parts[3:] | |
| prefix = f"model.layers.{layer_idx}" | |
| # Per-layer norms | |
| if rest == ["input_layernorm", "weight"]: | |
| params[f"{prefix}.input_layernorm.weight"] = tensor | |
| return | |
| if rest == ["post_attention_layernorm", "weight"]: | |
| params[f"{prefix}.post_attention_layernorm.weight"] = tensor | |
| return | |
| # Attention | |
| if rest[:1] == ["attention"]: | |
| sub = rest[1:] | |
| if sub == ["query_key_value", "weight"]: | |
| params[f"{prefix}.attention.query_key_value.weight"] = tensor | |
| return | |
| if sub == ["dense", "weight"]: | |
| params[f"{prefix}.attention.dense.weight"] = tensor | |
| return | |
| if sub == ["query_layernorm", "weight"]: | |
| params[f"{prefix}.attention.query_layernorm.weight"] = tensor | |
| return | |
| if sub == ["key_layernorm", "weight"]: | |
| params[f"{prefix}.attention.key_layernorm.weight"] = tensor | |
| return | |
| raise ValueError(f"Unknown attention subkey: {hf_key}") | |
| # MLP | |
| if rest[:1] == ["mlp"]: | |
| sub = rest[1:] | |
| if layer_idx < config.first_k_dense_replace: | |
| # Dense MLP | |
| if sub == ["gate_proj", "weight"]: | |
| params[f"{prefix}.mlp.gate_proj.weight"] = tensor | |
| return | |
| if sub == ["up_proj", "weight"]: | |
| params[f"{prefix}.mlp.up_proj.weight"] = tensor | |
| return | |
| if sub == ["down_proj", "weight"]: | |
| params[f"{prefix}.mlp.down_proj.weight"] = tensor | |
| return | |
| raise ValueError(f"Unknown dense-mlp subkey: {hf_key}") | |
| # MoE layer | |
| if sub[:1] == ["gate"]: | |
| if sub == ["gate", "weight"]: | |
| params[f"{prefix}.mlp.gate.weight"] = tensor | |
| return | |
| if sub == ["gate", "expert_bias"]: | |
| params[f"{prefix}.mlp.gate.expert_bias"] = tensor | |
| return | |
| raise ValueError(f"Unknown gate subkey: {hf_key}") | |
| if sub[:1] == ["shared_experts"]: | |
| which = sub[1] | |
| if which == "gate_proj": | |
| params[f"{prefix}.mlp.shared_gate_proj.weight"] = tensor; return | |
| if which == "up_proj": | |
| params[f"{prefix}.mlp.shared_up_proj.weight"] = tensor; return | |
| if which == "down_proj": | |
| params[f"{prefix}.mlp.shared_down_proj.weight"] = tensor; return | |
| raise ValueError(f"Unknown shared_experts subkey: {hf_key}") | |
| if sub[:1] == ["experts"]: | |
| # experts.{e}.{gate|up|down}_proj.weight | |
| expert_idx = int(sub[1]) | |
| which = sub[2] | |
| slot = expert_slots.setdefault((layer_idx, expert_idx), {}) | |
| if which == "gate_proj": | |
| slot["gate"] = tensor; return | |
| if which == "up_proj": | |
| slot["up"] = tensor; return | |
| if which == "down_proj": | |
| slot["down"] = tensor; return | |
| raise ValueError(f"Unknown expert subkey: {hf_key}") | |
| raise ValueError(f"Unknown mlp subkey: {hf_key}") | |
| raise ValueError(f"Unhandled key: {hf_key}") | |