#!/usr/bin/env python
"""Quantise a converted component's safetensors (DiT, patch_encoder, ...) to
mlx int4/int8, matching mlx's default Linear eligibility so the Swift
`quantize(model:)` call picks the same layers.

Quantises 2D `.weight` tensors whose in-features are divisible by group_size
(Linear weights [out, in]); leaves norms (1D), conv (3D) and biases in fp.
Writes config.json with the mlx_lm `quantization` block.

Usage: quantize_component.py <src_dir> <dst_dir> <bits> <group_size>

    src_dir     directory containing model.safetensors (and optionally
                config.json, which is carried over and extended)
    dst_dir     output directory; created if missing
    bits        quantisation bit width passed to mx.quantize
    group_size  quantisation group size passed to mx.quantize
"""
import argparse
import json
import sys
from pathlib import Path

import mlx.core as mx

_WEIGHTS_FILE = "model.safetensors"
_CONFIG_FILE = "config.json"
_WEIGHT_SUFFIX = ".weight"


def _parse_args(argv):
    """Parse and validate the four positional CLI arguments.

    Exits with a usage message (status 2) on missing/malformed arguments
    instead of surfacing a bare IndexError/ValueError traceback.
    """
    parser = argparse.ArgumentParser(
        prog="quantize_component.py",
        description="Quantise a component's model.safetensors with mlx.",
    )
    parser.add_argument("src", type=Path, help="source component directory")
    parser.add_argument("dst", type=Path, help="destination directory")
    parser.add_argument("bits", type=int, help="quantisation bit width")
    parser.add_argument("group_size", type=int, help="quantisation group size")
    args = parser.parse_args(argv)

    # group_size is used as a modulus below; zero/negative is never valid.
    if args.group_size <= 0:
        parser.error("group_size must be a positive integer")
    # Fail before creating dst so a bad src path leaves nothing behind.
    if not (args.src / _WEIGHTS_FILE).is_file():
        parser.error(f"{args.src / _WEIGHTS_FILE} not found")
    return args


def _quantize_weights(weights, bits, group_size):
    """Quantise every eligible tensor in *weights*.

    A tensor is eligible when its key ends in ``.weight``, it is 2D, and its
    last dimension is divisible by *group_size*. Eligible tensors are replaced
    by the packed weight and gain sibling ``.scales`` / ``.biases`` entries;
    everything else is passed through unchanged.

    Returns:
        (out, n_quantized, skipped) where ``out`` is the new tensor dict,
        ``n_quantized`` the number of tensors quantised, and ``skipped`` a
        list of ``(key, shape)`` for 2D weights left in fp because their last
        dimension is not divisible by *group_size*.
    """
    out = {}
    n_quantized = 0
    skipped = []
    for name, tensor in weights.items():
        is_2d_weight = name.endswith(_WEIGHT_SUFFIX) and tensor.ndim == 2
        if is_2d_weight and tensor.shape[-1] % group_size == 0:
            packed, scales, biases = mx.quantize(
                tensor, group_size=group_size, bits=bits
            )
            base = name[: -len(_WEIGHT_SUFFIX)]
            out[name] = packed
            out[base + ".scales"] = scales
            out[base + ".biases"] = biases
            n_quantized += 1
        else:
            out[name] = tensor
            if is_2d_weight:
                skipped.append((name, tuple(tensor.shape)))
    return out, n_quantized, skipped


def main(argv=None):
    """Run the quantisation; *argv* defaults to ``sys.argv[1:]``.

    Reads ``<src>/model.safetensors``, writes the quantised tensors to
    ``<dst>/model.safetensors`` and writes ``<dst>/config.json`` containing
    the source config (if any) plus a ``quantization`` entry.
    """
    args = _parse_args(argv)
    src, dst = args.src, args.dst
    bits, gs = args.bits, args.group_size

    dst.mkdir(parents=True, exist_ok=True)

    weights = mx.load(str(src / _WEIGHTS_FILE))
    out, nq, skipped = _quantize_weights(weights, bits, gs)
    mx.save_safetensors(str(dst / _WEIGHTS_FILE), out)

    # Carry over the source config when present, then record how the weights
    # were quantised.
    cfg = {}
    src_cfg = src / _CONFIG_FILE
    if src_cfg.exists():
        cfg = json.loads(src_cfg.read_text())
    cfg["quantization"] = {"group_size": gs, "bits": bits, "mode": "affine"}
    (dst / _CONFIG_FILE).write_text(json.dumps(cfg, indent=2))

    print(f"{src.name}: quantised {nq} linears -> {dst} ({bits}-bit g{gs})")
    if skipped:
        print(" skipped 2D weights (in not divisible by gs):", skipped)
    return 0


if __name__ == "__main__":
    sys.exit(main())