File size: 10,567 Bytes
e39ff3a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 |
import argparse
import copy
import json
import os
import sys
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from mlx_lm.quant.dynamic_quant import eval_ppl
from mlx_lm.quant.utils import load_data
from safetensors import safe_open
from tqdm import tqdm
from transformers import AutoTokenizer
# Make the project-local ``model`` module (model.py) importable: put the
# directory one level above this script's own directory at the front of
# sys.path (only if it is not already there). This has to run before the
# ``from model import`` line below.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if project_root not in sys.path:
    sys.path.insert(0, project_root)
from model import Model, ModelArgs
def estimate_sensitivities(
    model, data, low_bits, low_group_size, high_bits, high_group_size, batch_size=4
):
    """Score every ``nn.Linear`` layer by how much higher precision would help.

    A deep copy of ``model`` has each ``nn.Linear`` weight replaced by its
    ``low_bits`` quantize/dequantize round trip. Gradients of that copy's
    next-token cross-entropy loss are summed over ``data`` in batches of
    ``batch_size``. A layer's score is ``sum(grad * (w_low - w_high))``, where
    ``w_high`` is the ``high_bits`` round trip of the original weight: a
    first-order estimate of how much the loss would drop if that layer used
    the high-precision weights instead.

    Args:
        model: Float model. It is deep-copied and not modified.
        data: Token batches indexed as ``data[i:j]`` and ``batch[:, :-1]``;
            presumably an integer array of shape (num_sequences, seq_len)
            -- TODO confirm against ``load_data``.
        low_bits, low_group_size: Quantization settings for the
            low-precision copy.
        high_bits, high_group_size: Quantization settings for the
            high-precision reference weights.
        batch_size: Number of sequences per gradient step.

    Returns:
        Dict mapping module path (e.g. ``"layers.0.attention.wq"``-style
        keys from ``tree_flatten``) to a Python float score. Layers whose
        weight received no gradient are omitted.
    """

    def qdq(w, bits, group_size):
        # Quantize then immediately dequantize: the result is a float array
        # that carries exactly the quantization error of (bits, group_size).
        w_q, scales, biases = mx.quantize(w, bits=bits, group_size=group_size)
        return mx.dequantize(
            w_q, scales=scales, biases=biases, bits=bits, group_size=group_size
        )

    q_model = copy.deepcopy(model)
    linear_layers = {
        path: layer
        for path, layer in tree_flatten(
            q_model.leaf_modules(), is_leaf=nn.Module.is_module
        )
        if isinstance(layer, nn.Linear)
    }
    # Overwrite the copy's weights with plain float arrays (not packed
    # QuantizedLinear weights) so they remain trainable and receive gradients.
    for layer in linear_layers.values():
        layer.weight = qdq(layer.weight, low_bits, low_group_size)

    def loss_fn(inputs, labels):
        logits = q_model(inputs)
        return nn.losses.cross_entropy(logits, labels, reduction="mean")

    # Build the gradient function once; it is reused for every batch.
    loss_and_grad = nn.value_and_grad(q_model, loss_fn)
    grad_accum = tree_map(lambda x: mx.zeros(x.shape), q_model.trainable_parameters())
    for start in tqdm(range(0, len(data), batch_size), desc="Estimating sensitivities"):
        batch = data[start : start + batch_size]
        # The loss is measured against the ground-truth next tokens, so no
        # reference forward pass of the float ``model`` is needed here.
        inputs, labels = batch[:, :-1], batch[:, 1:]
        _, grads = loss_and_grad(inputs, labels)
        grad_accum = tree_map(lambda acc, g: acc + g, grad_accum, grads)
        # Force evaluation each step so the lazy graph does not grow across
        # batches.
        mx.eval(grad_accum)

    # Look parameters up by flattened path rather than with tree_map so a
    # missing gradient simply skips that layer.
    grad_dict = dict(tree_flatten(grad_accum))
    q_params_dict = dict(tree_flatten(q_model.parameters()))
    orig_params_dict = dict(tree_flatten(model.parameters()))

    sensitivities = {}
    for path in linear_layers:
        weight_key = f"{path}.weight"
        if weight_key not in grad_dict:
            continue
        low_q_weight = q_params_dict[weight_key]
        high_q_weight = qdq(orig_params_dict[weight_key], high_bits, high_group_size)
        alignment = (grad_dict[weight_key] * (low_q_weight - high_q_weight)).sum()
        sensitivities[path] = alignment.item()
    return sensitivities
def estimate_threshold(
    model,
    sensitivities,
    target_bpw,
    low_bits,
    low_group_size,
    high_bits,
    high_group_size,
):
    """Bisect for the sensitivity threshold whose quantized model meets ``target_bpw``.

    Linear layers whose sensitivity is strictly greater than the threshold are
    quantized with ``high_bits``/``high_group_size``; every other
    ``nn.Linear`` layer uses ``low_bits``/``low_group_size``. Each bisection
    step quantizes a fresh deep copy of ``model`` and measures bits per weight
    as (total parameter bytes * 8) / (number of parameters in the unquantized
    ``model``). A too-large bpw raises the threshold, so fewer layers are kept
    at high precision. The search stops once the bracket is no wider than
    0.1% of the sensitivity range.

    Args:
        model: Float model. Only deep copies of it are quantized.
        sensitivities: Mapping of module path -> score, as returned by
            ``estimate_sensitivities``.
        target_bpw: Desired average bits per weight of the whole model.
        low_bits, low_group_size: Settings for less sensitive layers.
        high_bits, high_group_size: Settings for more sensitive layers.

    Returns:
        The midpoint of the final bisection bracket.

    Raises:
        RuntimeError: If ``sensitivities`` is empty.
    """

    def predicate(p, m, threshold):
        if not isinstance(m, nn.Linear):
            return False
        return sensitivities.get(p, 0) > threshold

    sens_vals = list(sensitivities.values())
    if not sens_vals:
        raise RuntimeError(
            "No sensitivities were computed. This usually means gradients "
            "for Linear weights were not collected. Ensure layers are detected "
            "and weights are trainable during sensitivity estimation."
        )

    # Count logical weights on the *float* model. After quantization, weights
    # are packed into uint32 words (32 // bits values per element), so summing
    # ``.size`` over the quantized copy would under-count weights (~8x at
    # 4 bits). That would inflate bpw far above any realistic target and pin
    # the threshold at the maximum.
    num_weights = sum(p.size for _, p in tree_flatten(model.parameters()))

    lowest, highest = min(sens_vals), max(sens_vals)
    tolerance = 1e-3 * (highest - lowest)
    min_thr, max_thr = lowest, highest
    while (max_thr - min_thr) > tolerance:
        mid = (max_thr + min_thr) / 2
        q_model = copy.deepcopy(model)

        def high_predicate(p, m):
            return predicate(p, m, mid)

        def low_predicate(p, m):
            # Only remaining float nn.Linear layers qualify; modules already
            # converted in the high-precision pass are no longer nn.Linear.
            return isinstance(m, nn.Linear) and not predicate(p, m, mid)

        nn.quantize(
            q_model,
            group_size=high_group_size,
            bits=high_bits,
            class_predicate=high_predicate,
        )
        nn.quantize(
            q_model,
            group_size=low_group_size,
            bits=low_bits,
            class_predicate=low_predicate,
        )

        num_bytes = sum(p.nbytes for _, p in tree_flatten(q_model.parameters()))
        bpw = num_bytes * 8 / num_weights
        if bpw > target_bpw:
            # Too large: raise the threshold so fewer layers get high_bits.
            min_thr = mid
        else:
            max_thr = mid
    return (max_thr + min_thr) / 2
# --- Main Conversion and Saving Logic ---
def _parse_args():
    """Define the command-line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description="Convert and optionally quantize a model."
    )
    parser.add_argument(
        "--hf-path", type=str, default=".", help="Path to the Hugging Face model."
    )
    parser.add_argument(
        "--mlx-path", type=str, required=True, help="Path to save the MLX model."
    )
    parser.add_argument(
        "--quantize",
        "-q",
        action="store_true",
        help="Generate a simple uniformly quantized model.",
    )
    parser.add_argument(
        "--dynamic-quant",
        action="store_true",
        help="Use advanced mixed-precision quantization.",
    )
    parser.add_argument(
        "--report-ppl",
        action="store_true",
        help="Report perplexity before and after quantization.",
    )
    parser.add_argument(
        "--target-bpw",
        type=float,
        default=4.5,
        help="Target bits per weight for advanced quant.",
    )
    parser.add_argument(
        "--bits", "-b", type=int, default=4, help="Bits for uniform quantization."
    )
    parser.add_argument(
        "--group-size",
        "-g",
        type=int,
        default=None,
        help="Group size for quantization. If omitted, defaults to 64 when quantizing.",
    )
    return parser.parse_args()


def _remap_key(key):
    """Translate a Hugging Face checkpoint key into this project's naming.

    Each pair below is a plain substring replacement, applied in this order
    to the whole key. Keys that match none of them (e.g. ``lm_head.weight``)
    are returned unchanged.
    """
    for old, new in (
        ("model.embed_tokens", "tok_embeddings"),
        ("model.layers", "layers"),
        ("self_attn", "attention"),
        ("input_layernorm", "attention_norm"),
        ("post_attention_layernorm", "ffn_norm"),
        ("mlp.", "feed_forward."),
        ("model.norm", "norm"),
    ):
        key = key.replace(old, new)
    return key


def _load_model(hf_path, config):
    """Build ``Model`` from ``config`` and load ``hf_path/model.safetensors`` into it.

    The checkpoint is opened once. If any key ends in a ``g_up`` weight under
    ``feed_forward`` or ``mlp``, ``use_dual_mlp`` is enabled and every key
    containing ``gate_proj``, ``up_proj`` or ``down_proj`` is skipped
    (presumably superseded by the dual-MLP weights -- confirm in model.py).
    """
    weights = {}
    with safe_open(hf_path / "model.safetensors", framework="mlx") as f:
        keys = list(f.keys())
        has_dual = any(
            (".feed_forward.g_up.weight" in k) or (".mlp.g_up.weight" in k)
            for k in keys
        )
        for k in keys:
            if has_dual and ("gate_proj" in k or "up_proj" in k or "down_proj" in k):
                continue
            weights[_remap_key(k)] = f.get_tensor(k)
    # NOTE(review): only "output.weight" is dropped for tied embeddings, but
    # HF checkpoints usually call the LM head "lm_head.weight", which
    # _remap_key leaves unchanged -- confirm against Model's attribute names.
    if config.get("tie_word_embeddings", True):
        weights.pop("output.weight", None)

    model_args = ModelArgs.from_dict(config)
    model_args.use_dual_mlp = has_dual
    model = Model(model_args)
    model.update(tree_unflatten(list(weights.items())))
    return model


def _default_group_size(group_size, mode):
    """Return ``group_size``, or 64 (announcing the default) when it is None."""
    if group_size is None:
        print(f"[info] Using default group_size=64 for {mode} quantization")
        return 64
    return group_size


def _dynamic_quantize(model, calibration_data, target_bpw, group_size):
    """Quantize ``model`` in place with mixed 4/8-bit ``nn.Linear`` layers.

    Returns the ``config["quantization"]`` entry. Its ``per_layer_bits`` lists
    *every* quantized Linear layer, including any that received no
    sensitivity score (those fall back to 4 bits). This keeps the config
    consistent with what is actually stored in weights.npz.
    """
    low_bits, high_bits = 4, 8
    sensitivities = estimate_sensitivities(
        model, calibration_data, low_bits, group_size, high_bits, group_size
    )
    threshold = estimate_threshold(
        model,
        sensitivities,
        target_bpw,
        low_bits,
        group_size,
        high_bits,
        group_size,
    )

    # Decide every layer's bit width BEFORE mutating the model.
    per_layer_bits = {}
    for path, module in tree_flatten(model.leaf_modules(), is_leaf=nn.Module.is_module):
        if not isinstance(module, nn.Linear):
            continue
        score = sensitivities.get(path)
        per_layer_bits[path] = (
            high_bits if score is not None and score > threshold else low_bits
        )

    def high_predicate(p, m):
        return isinstance(m, nn.Linear) and per_layer_bits.get(p, low_bits) == high_bits

    def low_predicate(p, m):
        return isinstance(m, nn.Linear) and per_layer_bits.get(p, low_bits) == low_bits

    nn.quantize(
        model, group_size=group_size, bits=high_bits, class_predicate=high_predicate
    )
    nn.quantize(
        model, group_size=group_size, bits=low_bits, class_predicate=low_predicate
    )
    # Persist per-layer bit widths so a loader can rebuild the matching
    # QuantizedLinear modules without touching embeddings or other layers.
    return {
        "group_size": group_size,
        "method": "mixed_precision_dynamic",
        "per_layer_bits": per_layer_bits,
    }


def main():
    """Convert a Hugging Face checkpoint to MLX format, optionally quantized.

    Reads ``config.json``, ``model.safetensors`` and the tokenizer from
    ``--hf-path``, and writes ``weights.npz``, ``config.json`` and the
    tokenizer files to ``--mlx-path``. ``--dynamic-quant`` takes precedence
    over ``--quantize`` when both are given.
    """
    args = _parse_args()
    print(f"Loading model from {args.hf_path}...")
    hf_path = Path(args.hf_path)
    tokenizer = AutoTokenizer.from_pretrained(args.hf_path)
    with open(hf_path / "config.json", "r", encoding="utf-8") as f:
        config = json.load(f)
    model = _load_model(hf_path, config)

    calibration_data = None
    if args.report_ppl or args.dynamic_quant:
        print("Loading calibration data...")
        calibration_data = load_data(tokenizer, num_samples=-1, sequence_length=512)
    if args.report_ppl:
        print("Calculating perplexity of original model...")
        ppl = eval_ppl(model, data=calibration_data)
        print(f"Original PPL: {ppl:.3f}")

    if args.dynamic_quant:
        group_size = _default_group_size(args.group_size, "dynamic")
        print("Starting advanced mixed-precision quantization...")
        config["quantization"] = _dynamic_quantize(
            model, calibration_data, args.target_bpw, group_size
        )
    elif args.quantize:
        group_size = _default_group_size(args.group_size, "uniform")
        print("Starting simple uniform quantization...")
        nn.quantize(model, group_size=group_size, bits=args.bits)
        config["quantization"] = {
            "group_size": group_size,
            "bits": args.bits,
            "method": "uniform",
        }

    if args.report_ppl and (args.quantize or args.dynamic_quant):
        print("Calculating perplexity of quantized model...")
        ppl = eval_ppl(model, data=calibration_data)
        print(f"Quantized PPL: {ppl:.3f}")

    output_path = Path(args.mlx_path)
    output_path.mkdir(parents=True, exist_ok=True)
    mx.savez(str(output_path / "weights.npz"), **dict(tree_flatten(model.parameters())))
    with open(output_path / "config.json", "w", encoding="utf-8") as f:
        json.dump(config, f, indent=4)
    tokenizer.save_pretrained(output_path)
    print(f"\n✅ Model saved to {args.mlx_path}")
if __name__ == "__main__":
    # Run the converter only when this file is executed as a script, not when
    # it is imported (e.g. to reuse estimate_sensitivities/estimate_threshold).
    main()
|