import json
from pathlib import Path

import numpy as np

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten

# Import the model definition from the project root's model.py
from model import Model, ModelArgs


def load_model(model_path: str):
    """Load config and weights from `model_path`, re-applying quantization if needed."""
    model_path = Path(model_path)
    with open(model_path / "config.json", "r") as f:
        config = json.load(f)

    # Peek at the checkpoint with numpy to inspect key names without
    # materializing MLX arrays yet.
    npz_path = model_path / "weights.npz"
    with np.load(npz_path, allow_pickle=False) as npz:
        keys = list(npz.files)
    has_dual = any("g_up" in k for k in keys)

    args = ModelArgs.from_dict(config)
    args.use_dual_mlp = bool(has_dual)
    model = Model(args)

    # If quantization metadata is present, re-materialize QuantizedLinear
    # modules so the quantized weights can be loaded into them.
    qcfg = config.get("quantization") or {}
    method = qcfg.get("method")
    group_size = int(qcfg["group_size"]) if qcfg.get("group_size") is not None else 64

    if method == "uniform":
        # A single bit-width for every Linear layer.
        bits = int(qcfg.get("bits", 4))
        nn.quantize(
            model,
            group_size=group_size,
            bits=bits,
            class_predicate=lambda p, m: isinstance(m, nn.Linear),
        )
    elif method == "mixed_precision_dynamic":
        # Per-layer bit-widths: the predicate returns a parameter dict for
        # layers listed in `per_layer_bits` and False for everything else,
        # so unlisted layers are left unquantized.
        per_layer_bits = qcfg.get("per_layer_bits", {})

        def predicate(p, m):
            if not isinstance(m, nn.Linear):
                return False
            b = per_layer_bits.get(p)
            if b is None:
                return False
            return {"bits": int(b), "group_size": group_size}

        nn.quantize(
            model,
            group_size=group_size,
            bits=4,
            class_predicate=predicate,
        )

    # Now load the actual weights into MLX and update the model in place.
    weights = mx.load(str(npz_path))
    model.update(tree_unflatten(list(weights.items())))
    return model
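

# Example usage (illustrative sketch): "my_model" below is a placeholder path;
# it is assumed to point at a directory holding the config.json and weights.npz
# that load_model expects, and that Model subclasses mlx.nn.Module.
if __name__ == "__main__":
    from mlx.utils import tree_flatten  # local import, used only by this demo

    model = load_model("my_model")
    model.eval()  # put the modules into inference mode

    # Count parameters by flattening the (possibly quantized) parameter tree.
    n_params = sum(v.size for _, v in tree_flatten(model.parameters()))
    print(f"Loaded ~{n_params / 1e6:.1f}M parameters")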