File size: 1,912 Bytes
e39ff3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import json
from pathlib import Path
import numpy as np
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten

# We must import from the project root's model.py
from model import Model, ModelArgs


def load_model(model_path: str):
    """Build a ``Model`` from ``config.json`` and load ``weights.npz`` into it.

    The weights archive is first inspected with NumPy (key names only) to
    decide whether the dual-MLP variant is needed: it is enabled when any
    weight key contains ``"g_up"``. If ``config["quantization"]`` describes a
    quantization scheme, matching ``nn.Linear`` modules are converted to
    quantized modules *before* the weights are loaded, so the saved quantized
    tensors have matching parameters to be written into.

    Args:
        model_path: Directory containing ``config.json`` and ``weights.npz``.
            A ``pathlib.Path`` is accepted as well.

    Returns:
        The constructed ``Model`` with its parameters updated from
        ``weights.npz``.
    """
    model_path = Path(model_path)
    with open(model_path / "config.json", "r", encoding="utf-8") as f:
        config = json.load(f)

    npz_path = model_path / "weights.npz"
    # Peek at the key names with numpy before materializing MLX arrays.
    # The context manager closes the NpzFile's underlying zip handle, which
    # would otherwise stay open for the lifetime of the object.
    with np.load(npz_path, allow_pickle=False) as npz:
        has_dual = any("g_up" in k for k in npz.files)

    args = ModelArgs.from_dict(config)
    args.use_dual_mlp = bool(has_dual)
    model = Model(args)

    # Quantized layers must exist before loading, otherwise the quantized
    # tensors in the archive have no matching parameters to land in.
    _apply_quantization(model, config.get("quantization") or {})

    # Now load the actual weights into MLX and update
    weights = mx.load(str(npz_path))
    model.update(tree_unflatten(list(weights.items())))
    return model


def _apply_quantization(model, qcfg: dict) -> None:
    """Convert ``nn.Linear`` modules of *model* in place as described by *qcfg*.

    Supported ``qcfg["method"]`` values:

    * ``"uniform"`` — every ``nn.Linear`` is quantized with ``qcfg["bits"]``
      (default 4) and the shared group size.
    * ``"mixed_precision_dynamic"`` — only ``nn.Linear`` modules whose path
      appears in ``qcfg["per_layer_bits"]`` are quantized, each with its own
      bit width; all others are left unquantized.

    Any other method (including a missing one) leaves the model unchanged.
    ``qcfg["group_size"]`` defaults to 64 when absent.
    """
    method = qcfg.get("method")
    raw_group_size = qcfg.get("group_size")
    # Resolve the default once so both the nn.quantize call and the per-layer
    # predicate agree (previously the predicate crashed on a missing value).
    group_size = int(raw_group_size) if raw_group_size is not None else 64

    if method == "uniform":
        bits = int(qcfg.get("bits", 4))
        nn.quantize(
            model,
            group_size=group_size,
            bits=bits,
            class_predicate=lambda p, m: isinstance(m, nn.Linear),
        )
    elif method == "mixed_precision_dynamic":
        per_layer_bits = qcfg.get("per_layer_bits") or {}

        def predicate(path, module):
            # Only Linear layers explicitly listed in the config are quantized;
            # returning a dict overrides bits/group_size for that module.
            if not isinstance(module, nn.Linear):
                return False
            layer_bits = per_layer_bits.get(path)
            if layer_bits is None:
                return False
            return {"bits": int(layer_bits), "group_size": group_size}

        nn.quantize(
            model,
            group_size=group_size,
            bits=4,
            class_predicate=predicate,
        )
    # NOTE(review): an unrecognized non-None method is silently ignored, which
    # leaves plain Linear layers in place while the archive may hold quantized
    # tensors — consider raising instead once callers are confirmed.