|
|
import json |
|
|
from pathlib import Path |
|
|
import numpy as np |
|
|
import mlx.core as mx |
|
|
import mlx.nn as nn |
|
|
from mlx.utils import tree_unflatten |
|
|
|
|
|
|
|
|
from model import Model, ModelArgs |
|
|
|
|
|
|
|
|
def load_model(model_path: str):
    """Load a model (optionally quantized) from a directory.

    Expects the directory to contain:
      - ``config.json``: model hyperparameters, optionally with a
        ``"quantization"`` section (``method``, ``group_size``, ``bits``,
        ``per_layer_bits``).
      - ``weights.npz``: flat parameter dict keyed by dotted paths.

    Args:
        model_path: Path to the model directory.

    Returns:
        The constructed ``Model`` with weights loaded (and quantization
        applied when the config requests it).
    """
    model_path = Path(model_path)
    with open(model_path / "config.json", "r") as f:
        config = json.load(f)

    npz_path = model_path / "weights.npz"
    # Peek at the weight keys only to detect the dual-MLP variant; use a
    # context manager so the NpzFile's underlying handle is closed.
    with np.load(npz_path, allow_pickle=False) as npz:
        has_dual = any("g_up" in k for k in npz.files)

    args = ModelArgs.from_dict(config)
    args.use_dual_mlp = bool(has_dual)
    model = Model(args)

    qcfg = config.get("quantization") or {}
    method = qcfg.get("method")
    group_size = qcfg.get("group_size")
    # Resolve the group size once: a missing value falls back to 64.
    # Previously the mixed-precision predicate called int(group_size)
    # directly and crashed with TypeError when group_size was None.
    effective_group_size = int(group_size) if group_size is not None else 64

    if method == "uniform":
        # Single bit-width for every Linear layer.
        bits = int(qcfg.get("bits", 4))
        nn.quantize(
            model,
            group_size=effective_group_size,
            bits=bits,
            class_predicate=lambda p, m: isinstance(m, nn.Linear),
        )
    elif method == "mixed_precision_dynamic":
        # Per-layer bit widths: layers absent from the map stay unquantized.
        per_layer_bits = qcfg.get("per_layer_bits", {})

        def predicate(p, m):
            # Returning a dict overrides quantization params for this module;
            # returning False skips it entirely.
            if not isinstance(m, nn.Linear):
                return False
            b = per_layer_bits.get(p)
            if b is None:
                return False
            return {"bits": int(b), "group_size": effective_group_size}

        nn.quantize(
            model,
            group_size=effective_group_size,
            bits=4,
            class_predicate=predicate,
        )

    # Load the actual parameters with MLX and splice them into the model tree.
    weights = mx.load(str(npz_path))
    model.update(tree_unflatten(list(weights.items())))
    return model
|
|
|