File size: 1,665 Bytes
39057fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#!/usr/bin/env python
"""Quantise a converted component's safetensors (DiT, patch_encoder, ...) to mlx
int4/int8, matching mlx's default Linear eligibility so the Swift `quantize(model:)`
call picks the same layers. Quantises 2D `.weight` tensors whose in-features are
divisible by group_size (Linear weights [out, in]); leaves norms (1D), conv (3D)
and biases in fp. Writes config.json with the mlx_lm `quantization` block.

Usage: quantize_component.py <src_dir> <dst_dir> <bits> <group_size>
"""
import json
import sys
from pathlib import Path

import mlx.core as mx

# ---- Command-line handling ---------------------------------------------------
# Fail with the usage line instead of a bare IndexError/ValueError traceback.
# Extra trailing arguments are tolerated, as they always were.
_USAGE = "Usage: quantize_component.py <src_dir> <dst_dir> <bits> <group_size>"
if len(sys.argv) < 5:
    sys.exit(_USAGE)

src = Path(sys.argv[1])
dst = Path(sys.argv[2])
try:
    bits = int(sys.argv[3])
    gs = int(sys.argv[4])
except ValueError:
    sys.exit(f"<bits> and <group_size> must be integers\n{_USAGE}")
# group_size is used as a modulus below, so 0 would raise ZeroDivisionError;
# non-positive values make no sense for either setting.
if bits <= 0 or gs <= 0:
    sys.exit(f"<bits> and <group_size> must be positive\n{_USAGE}")

# Check the input before touching the output location so a mistyped source
# directory does not leave an empty destination directory behind.
weights_path = src / "model.safetensors"
if not weights_path.is_file():
    sys.exit(f"{weights_path}: weights file not found")
dst.mkdir(parents=True, exist_ok=True)

# ---- Quantise eligible tensors -----------------------------------------------
w = mx.load(str(weights_path))
out = {}        # tensors to write: quantised triples plus untouched originals
nq = 0          # number of tensors that were quantised
skipped = []    # (name, shape) of 2D weights left alone because of group size
for k, v in w.items():
    # Candidate = a 2D tensor stored under a "<prefix>.weight" key.
    is_matrix_weight = k.endswith(".weight") and v.ndim == 2
    if is_matrix_weight and v.shape[-1] % gs == 0:
        # Last axis splits evenly into groups of `gs`: store the quantised
        # weight under the original key and its scales/biases next to it.
        wq, scales, biases = mx.quantize(v, group_size=gs, bits=bits)
        base = k[: -len(".weight")]
        out[k] = wq
        out[base + ".scales"] = scales
        out[base + ".biases"] = biases
        nq += 1
    else:
        # Everything else is copied through unchanged.
        out[k] = v
        if is_matrix_weight:
            # Only reported for the summary printed at the end.
            skipped.append((k, tuple(v.shape)))

# ---- Write outputs -----------------------------------------------------------
mx.save_safetensors(str(dst / "model.safetensors"), out)

# Carry over the source config (if any) and record the settings used above.
# JSON is read and written as UTF-8 explicitly rather than via the locale.
cfg = {}
src_cfg = src / "config.json"
if src_cfg.exists():
    cfg = json.loads(src_cfg.read_text(encoding="utf-8"))
cfg["quantization"] = {"group_size": gs, "bits": bits, "mode": "affine"}
(dst / "config.json").write_text(json.dumps(cfg, indent=2), encoding="utf-8")

print(f"{src.name}: quantised {nq} linears -> {dst} ({bits}-bit g{gs})")
if skipped:
    print("  skipped 2D weights (in not divisible by gs):", skipped)