import argparse
import json
from pathlib import Path
import numpy as np


def main():
    """Print the quantization layout recorded in a converted model's config.json."""
    p = argparse.ArgumentParser(description="Summarize MLX-LM quantization layout")
    p.add_argument("--model-path", required=True, help="Path to converted MLX model")
    p.add_argument("--show", type=int, default=10, help="Show up to N entries per group")
    args = p.parse_args()

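    mpath = Path(args.model_path)
    # Quantization settings are read from the "quantization" key of config.json.
    cfg = json.loads((mpath / "config.json").read_text())
    q = cfg.get("quantization") or {}
    method = q.get("method", "none")
    gsize = q.get("group_size")
    # Mapping of layer name -> bit width; a missing or null entry becomes empty.
    plb = q.get("per_layer_bits") or {}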

    print(f"Method: {method}")
    print(f"Group size: {gsize}")
    if method == "uniform":
        print(f"Uniform bits: {q.get('bits')}")
        return

    if not plb:
        print("No per-layer bits found in config.")
        return

    # Group layer names by bit width: 4-bit, 8-bit, and everything else.
    buckets = {4: [], 8: [], "other": []}
    for k, b in plb.items():
        buckets[b if b in (4, 8) else "other"].append(k)

    total = sum(len(v) for v in buckets.values())
    print(f"Total linear layers: {total}")
    print(f"4-bit layers: {len(buckets[4])}")
    print(f"8-bit layers: {len(buckets[8])}")
    if buckets["other"]:
        print(f"Other-bit layers: {len(buckets['other'])}")

    # List up to --show layer names per group, 8-bit first.
    for b in (8, 4):
        items = sorted(buckets[b])
        if not items:
            continue
        print(f"\nExamples ({b}-bit):")
        for k in items[: args.show]:
            print(f"- {k}")

    # Best-effort sanity check: quantized layers store ".scales"/".biases"
    # tensors next to their weights, so look for those keys in weights.npz.
    try:
        # Use a context manager so the underlying archive file is closed.
        with np.load(mpath / "weights.npz", allow_pickle=False) as npz:
            has_q = any(k.endswith((".scales", ".biases")) for k in npz.files)
        print(f"\nweights.npz contains quantized tensors: {has_q}")
    except Exception as e:
        print(f"Note: could not open weights.npz: {e}")


if __name__ == "__main__":
    main()