# add mlx and mlx-lm support (commit e39ff3a)
import json
from pathlib import Path
import numpy as np
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
# We must import from the project root's model.py
from model import Model, ModelArgs
def load_model(model_path: str):
    """Load an MLX model from a directory produced by the export pipeline.

    Expects ``model_path`` to contain:
      - ``config.json``: model hyperparameters, optionally a ``"quantization"``
        section (``method``, ``group_size``, ``bits`` / ``per_layer_bits``);
      - ``weights.npz``: flattened parameter arrays.

    Returns the constructed ``Model`` with weights loaded. Raises
    ``FileNotFoundError`` if either file is missing and ``json.JSONDecodeError``
    on a malformed config.
    """
    model_dir = Path(model_path)
    with open(model_dir / "config.json", "r") as f:
        config = json.load(f)

    npz_path = model_dir / "weights.npz"
    # Peek with numpy to inspect key names without materializing MLX arrays.
    # NpzFile holds the archive open, so use a context manager to close it
    # (the original leaked this file handle).
    with np.load(npz_path, allow_pickle=False) as npz:
        # Presence of any "g_up" parameter signals the dual-MLP variant.
        has_dual = any("g_up" in k for k in npz.files)

    args = ModelArgs.from_dict(config)
    args.use_dual_mlp = bool(has_dual)
    model = Model(args)

    # If quantization metadata is present, re-materialize QuantizedLinear
    # modules before loading weights so parameter shapes match the archive.
    qcfg = config.get("quantization") or {}
    method = qcfg.get("method")
    group_size = qcfg.get("group_size")
    # Single fallback shared by every use below. The original applied the
    # 64 default only in the nn.quantize() calls but called int(group_size)
    # unconditionally inside the mixed-precision predicate, raising
    # TypeError whenever "group_size" was absent from the config.
    effective_group_size = int(group_size) if group_size is not None else 64

    if method == "uniform":
        bits = int(qcfg.get("bits", 4))
        nn.quantize(
            model,
            group_size=effective_group_size,
            bits=bits,
            class_predicate=lambda p, m: isinstance(m, nn.Linear),
        )
    elif method == "mixed_precision_dynamic":
        per_layer_bits = qcfg.get("per_layer_bits", {})

        def predicate(p, m):
            # Returning False skips the module; returning a dict overrides
            # the quantization parameters for that specific module.
            if not isinstance(m, nn.Linear):
                return False
            b = per_layer_bits.get(p)
            if b is None:
                return False
            return {"bits": int(b), "group_size": effective_group_size}

        nn.quantize(
            model,
            group_size=effective_group_size,
            bits=4,  # default for layers the predicate does not override
            class_predicate=predicate,
        )

    # Now load the actual weights into MLX and update the module tree.
    weights = mx.load(str(npz_path))
    model.update(tree_unflatten(list(weights.items())))
    return model