|
|
import argparse |
|
|
import copy |
|
|
import json |
|
|
import os |
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
import mlx.core as mx |
|
|
import mlx.nn as nn |
|
|
from mlx.utils import tree_flatten, tree_map, tree_unflatten |
|
|
from mlx_lm.quant.dynamic_quant import eval_ppl |
|
|
from mlx_lm.quant.utils import load_data |
|
|
from safetensors import safe_open |
|
|
from tqdm import tqdm |
|
|
from transformers import AutoTokenizer |
|
|
|
|
|
|
|
|
# Make the project root importable so the `from model import ...` below
# resolves when this script is run directly from its subdirectory.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if project_root not in sys.path:
    sys.path.insert(0, project_root)
|
|
|
|
|
from model import Model, ModelArgs |
|
|
|
|
|
|
|
|
def estimate_sensitivities(
    model, data, low_bits, low_group_size, high_bits, high_group_size, batch_size=4
):
    """Estimate a per-layer sensitivity score for mixed-precision quantization.

    A deep copy of ``model`` has every ``nn.Linear`` weight replaced by its
    low-precision quantize-dequantize image; cross-entropy gradients are then
    accumulated over ``data``, and each layer is scored with the first-order
    loss change of switching that layer from the low- to the high-precision
    setting:

        sensitivity = sum(grad * (w_low_q - w_high_q))

    Args:
        model: Full-precision reference model; left unmodified.
        data: 2-D token array; rows are calibration sequences.
        low_bits, low_group_size: Low-precision quantization settings.
        high_bits, high_group_size: High-precision quantization settings.
        batch_size: Sequences per gradient-accumulation step.

    Returns:
        Dict mapping Linear-layer path -> float sensitivity; larger values
        indicate the layer benefits more from high precision.
    """

    def qdq(w, bits, group_size):
        # Quantize-dequantize round trip: w as it would look after
        # quantization at the given precision.
        w, s, b = mx.quantize(w, bits=bits, group_size=group_size)
        return mx.dequantize(w, scales=s, biases=b, bits=bits, group_size=group_size)

    # Work on a copy so the caller's model keeps full-precision weights.
    q_model = copy.deepcopy(model)
    linear_layers = {
        k: layer
        for k, layer in tree_flatten(
            q_model.leaf_modules(), is_leaf=nn.Module.is_module
        )
        if isinstance(layer, nn.Linear)
    }

    # Simulate low-precision quantization on every Linear weight.
    for layer in linear_layers.values():
        layer.weight = qdq(layer.weight, low_bits, low_group_size)

    def loss_fn(batch, targets):
        logits = q_model(batch)
        return nn.losses.cross_entropy(logits, targets, reduction="mean")

    grad_accum = tree_map(lambda x: mx.zeros(x.shape), q_model.trainable_parameters())

    for s in tqdm(range(0, len(data), batch_size), desc="Estimating sensitivities"):
        batch = data[s : s + batch_size]
        # Next-token prediction: inputs are tokens [:-1], targets tokens [1:].
        # (The previous version also ran the unquantized reference model on
        # each batch and discarded the result — that dead forward pass is
        # removed here.)
        _, grads = nn.value_and_grad(q_model, loss_fn)(batch[:, :-1], batch[:, 1:])
        grad_accum = tree_map(lambda x, y: x + y, grad_accum, grads)
        # Force evaluation each step to keep the lazy graph from growing.
        mx.eval(grad_accum)

    def compute_sensitivity(grad, lq_w, orig_w):
        # First-order estimate of the loss change from moving this layer's
        # weight from the low- to the high-precision quantization.
        hq_w = qdq(orig_w, high_bits, high_group_size)
        return (grad * (lq_w - hq_w)).sum()

    grad_dict = dict(tree_flatten(grad_accum))
    q_params_dict = dict(tree_flatten(q_model.parameters()))
    orig_params_dict = dict(tree_flatten(model.parameters()))

    sensitivities = {}
    for path, module in linear_layers.items():
        weight_key = f"{path}.weight"
        if weight_key in grad_dict:
            grad = grad_dict[weight_key]
            q_weight = q_params_dict[weight_key]
            orig_weight = orig_params_dict[weight_key]

            sensitivity = compute_sensitivity(grad, q_weight, orig_weight)
            sensitivities[path] = sensitivity.item()

    return sensitivities
|
|
|
|
|
|
|
|
def estimate_threshold(
    model,
    sensitivities,
    target_bpw,
    low_bits,
    low_group_size,
    high_bits,
    high_group_size,
):
    """Binary-search a sensitivity threshold meeting a bits-per-weight budget.

    Layers whose sensitivity exceeds the threshold get the high-precision
    setting, the rest the low-precision one. Each candidate threshold is
    evaluated by quantizing a fresh copy of ``model`` and measuring its
    average bits per weight.

    Args:
        model: Full-precision model; left unmodified (copies are quantized).
        sensitivities: Dict of layer path -> sensitivity score.
        target_bpw: Desired average bits per weight.
        low_bits, low_group_size: Low-precision quantization settings.
        high_bits, high_group_size: High-precision quantization settings.

    Returns:
        The converged threshold; callers assign high precision to layers
        with sensitivity strictly greater than it.

    Raises:
        RuntimeError: If ``sensitivities`` is empty.
    """

    def predicate(p, m, threshold):
        if not isinstance(m, nn.Linear):
            return False
        return sensitivities.get(p, 0) > threshold

    sens_vals = list(sensitivities.values())
    if len(sens_vals) == 0:
        raise RuntimeError(
            "No sensitivities were computed. This usually means gradients "
            "for Linear weights were not collected. Ensure layers are detected "
            "and weights are trainable during sensitivity estimation."
        )
    min_thr, max_thr = min(sens_vals), max(sens_vals)
    # Convergence tolerance is a fixed fraction of the full sensitivity
    # range; it is loop-invariant, so compute it once (the bounds above
    # start at that same min/max).
    tolerance = 1e-3 * (max_thr - min_thr)

    while (max_thr - min_thr) > tolerance:
        mid = (max_thr + min_thr) / 2
        # Quantize a fresh copy at the candidate threshold to measure bpw.
        q_model = copy.deepcopy(model)

        def high_predicate(p, m):
            return predicate(p, m, mid)

        def low_predicate(p, m):
            return isinstance(m, nn.Linear) and (not predicate(p, m, mid))

        nn.quantize(
            q_model,
            group_size=high_group_size,
            bits=high_bits,
            class_predicate=high_predicate,
        )
        nn.quantize(
            q_model,
            group_size=low_group_size,
            bits=low_bits,
            class_predicate=low_predicate,
        )

        # Average bits per weight; flatten the parameter tree once per pass
        # instead of twice.
        params = [p for _, p in tree_flatten(q_model.parameters())]
        bpw = sum(p.nbytes for p in params) * 8 / sum(p.size for p in params)

        if bpw > target_bpw:
            # Over budget: raise the threshold so fewer layers get high bits.
            min_thr = mid
        else:
            max_thr = mid
    return (max_thr + min_thr) / 2
|
|
|
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: load an HF checkpoint, optionally quantize it
    (uniform or mixed-precision), and save the result in MLX format."""
    parser = argparse.ArgumentParser(
        description="Convert and optionally quantize a model."
    )
    parser.add_argument(
        "--hf-path", type=str, default=".", help="Path to the Hugging Face model."
    )
    parser.add_argument(
        "--mlx-path", type=str, required=True, help="Path to save the MLX model."
    )
    parser.add_argument(
        "--quantize",
        "-q",
        action="store_true",
        help="Generate a simple uniformly quantized model.",
    )
    parser.add_argument(
        "--dynamic-quant",
        action="store_true",
        help="Use advanced mixed-precision quantization.",
    )
    parser.add_argument(
        "--report-ppl",
        action="store_true",
        help="Report perplexity before and after quantization.",
    )
    parser.add_argument(
        "--target-bpw",
        type=float,
        default=4.5,
        help="Target bits per weight for advanced quant.",
    )
    parser.add_argument(
        "--bits", "-b", type=int, default=4, help="Bits for uniform quantization."
    )
    parser.add_argument(
        "--group-size",
        "-g",
        type=int,
        default=None,
        help="Group size for quantization. If omitted, defaults to 64 when quantizing.",
    )
    args = parser.parse_args()

    print(f"Loading model from {args.hf_path}...")
    hf_path = Path(args.hf_path)
    tokenizer = AutoTokenizer.from_pretrained(args.hf_path)

    with open(hf_path / "config.json", "r") as f:
        config = json.load(f)

    # Detect the fused "dual MLP" checkpoint layout by probing tensor names.
    with safe_open(hf_path / "model.safetensors", framework="mlx") as f:
        keys = list(f.keys())
    has_dual = any(
        (".feed_forward.g_up.weight" in k) or (".mlp.g_up.weight" in k) for k in keys
    )
    model_args = ModelArgs.from_dict(config)
    model_args.use_dual_mlp = bool(has_dual)
    model = Model(model_args)

    # Load tensors, renaming HF parameter paths to this project's layout.
    weights = {}
    with safe_open(hf_path / "model.safetensors", framework="mlx") as f:
        for k in f.keys():
            # Dual-MLP checkpoints carry fused weights; skip the split ones.
            if has_dual and ("gate_proj" in k or "up_proj" in k or "down_proj" in k):
                continue
            v = f.get_tensor(k)
            k = k.replace("model.embed_tokens", "tok_embeddings")
            k = k.replace("model.layers", "layers")
            k = k.replace("self_attn", "attention")
            k = k.replace("input_layernorm", "attention_norm")
            k = k.replace("post_attention_layernorm", "ffn_norm")
            k = k.replace("mlp.", "feed_forward.")
            k = k.replace("model.norm", "norm")
            weights[k] = v
    # Tied embeddings: drop the output head so it shares tok_embeddings.
    # NOTE(review): the default is True when the config key is absent —
    # confirm that assumption holds for all supported checkpoints.
    if config.get("tie_word_embeddings", True):
        weights.pop("output.weight", None)
    model.update(tree_unflatten(list(weights.items())))

    # Calibration data is needed both for PPL reporting and for the
    # sensitivity estimation used by dynamic quantization.
    calibration_data = None
    if args.report_ppl or args.dynamic_quant:
        print("Loading calibration data...")
        calibration_data = load_data(tokenizer, num_samples=-1, sequence_length=512)

    if args.report_ppl:
        print("Calculating perplexity of original model...")
        ppl = eval_ppl(model, data=calibration_data)
        print(f"Original PPL: {ppl:.3f}")

    if args.dynamic_quant:
        if args.group_size is None:
            args.group_size = 64
            print("[info] Using default group_size=64 for dynamic quantization")
        print("Starting advanced mixed-precision quantization...")
        # Score each Linear layer, then pick the 4-bit/8-bit split whose
        # threshold hits the requested bits-per-weight budget.
        sensitivities = estimate_sensitivities(
            model, calibration_data, 4, args.group_size, 8, args.group_size
        )

        threshold = estimate_threshold(
            model,
            sensitivities,
            args.target_bpw,
            4,
            args.group_size,
            8,
            args.group_size,
        )

        # Layers strictly above the threshold get 8 bits; the rest get 4.
        per_layer_bits = {p: (8 if s > threshold else 4) for p, s in sensitivities.items()}

        def high_predicate(p, m):
            return isinstance(m, nn.Linear) and per_layer_bits.get(p, 4) == 8

        def low_predicate(p, m):
            return isinstance(m, nn.Linear) and per_layer_bits.get(p, 4) == 4

        nn.quantize(
            model, group_size=args.group_size, bits=8, class_predicate=high_predicate
        )
        nn.quantize(
            model, group_size=args.group_size, bits=4, class_predicate=low_predicate
        )

        # Record the per-layer assignment so the model can be reloaded.
        config["quantization"] = {
            "group_size": args.group_size,
            "method": "mixed_precision_dynamic",
            "per_layer_bits": per_layer_bits,
        }

    elif args.quantize:
        if args.group_size is None:
            args.group_size = 64
            print("[info] Using default group_size=64 for uniform quantization")
        print("Starting simple uniform quantization...")
        nn.quantize(model, group_size=args.group_size, bits=args.bits)
        config["quantization"] = {
            "group_size": args.group_size,
            "bits": args.bits,
            "method": "uniform",
        }

    if args.report_ppl and (args.quantize or args.dynamic_quant):
        print("Calculating perplexity of quantized model...")
        ppl = eval_ppl(model, data=calibration_data)
        print(f"Quantized PPL: {ppl:.3f}")

    # Persist weights, the (possibly augmented) config, and the tokenizer.
    output_path = Path(args.mlx_path)
    output_path.mkdir(parents=True, exist_ok=True)
    mx.savez(str(output_path / "weights.npz"), **dict(tree_flatten(model.parameters())))
    with open(output_path / "config.json", "w") as f:
        json.dump(config, f, indent=4)
    tokenizer.save_pretrained(output_path)
    print(f"\n✅ Model saved to {args.mlx_path}")
|
|
|
|
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|
|
|