# Provenance: commit e39ff3a by robbiemu -- "add mlx and mlx-lm support"
import argparse
import json
from pathlib import Path
import numpy as np
def _bucket_by_bits(per_layer_bits):
    """Group layer names by bit width under the keys 4, 8 and "other"."""
    buckets = {4: [], 8: [], "other": []}
    for name, bits in per_layer_bits.items():
        if bits == 4:
            buckets[4].append(name)
        elif bits == 8:
            buckets[8].append(name)
        else:
            buckets["other"].append(name)
    return buckets


def main(argv=None):
    """Print a summary of the quantization layout stored in a model directory.

    Reads ``config.json`` under ``--model-path`` and prints the
    ``quantization`` section's method and group size. For the "uniform"
    method only the single bit width is printed. Otherwise the
    ``per_layer_bits`` mapping is tallied into 4-bit, 8-bit and other
    layers, up to ``--show`` example names are listed for the 8-bit and
    4-bit groups, and ``weights.npz`` is probed (best effort) for tensors
    whose names end in ``.scales`` or ``.biases``.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, so existing
            ``main()`` callers behave exactly as before.

    Raises:
        OSError: If ``config.json`` cannot be read.
        json.JSONDecodeError: If ``config.json`` is not valid JSON.
    """
    parser = argparse.ArgumentParser(description="Summarize MLX-LM quantization layout")
    parser.add_argument("--model-path", required=True, help="Path to converted MLX model")
    parser.add_argument("--show", type=int, default=10, help="Show up to N entries per group")
    args = parser.parse_args(argv)

    model_dir = Path(args.model_path)
    cfg = json.loads((model_dir / "config.json").read_text())
    # `or {}` also covers an explicit `"quantization": null` in the config.
    quant = cfg.get("quantization") or {}
    method = quant.get("method", "none")
    per_layer_bits = quant.get("per_layer_bits", {})

    print(f"Method: {method}")
    print(f"Group size: {quant.get('group_size')}")

    if method == "uniform":
        print(f"Uniform bits: {quant.get('bits')}")
        return
    if not per_layer_bits:
        print("No per-layer bits found in config.")
        return

    # Basic counts
    buckets = _bucket_by_bits(per_layer_bits)
    total = sum(len(names) for names in buckets.values())
    print(f"Total linear layers: {total}")
    print(f"4-bit layers: {len(buckets[4])}")
    print(f"8-bit layers: {len(buckets[8])}")
    if buckets["other"]:
        print(f"Other-bit layers: {len(buckets['other'])}")

    # Optional: show a few examples (8-bit group first, then 4-bit)
    for bits in (8, 4):
        names = sorted(buckets[bits])
        if not names:
            continue
        print(f"\nExamples ({bits}-bit):")
        for name in names[: args.show]:
            print(f"- {name}")

    # Optional: sanity-check against npz contents. This is deliberately
    # best-effort: any failure to open or inspect the archive is reported
    # as a note instead of aborting the summary.
    try:
        # np.load keeps the archive open; the context manager closes it.
        with np.load(model_dir / "weights.npz", allow_pickle=False) as npz:
            has_q = any(key.endswith((".scales", ".biases")) for key in npz.files)
    except Exception as e:
        print(f"Note: could not open weights.npz: {e}")
    else:
        print(f"\nweights.npz contains quantized tensors: {has_q}")
# Script entry point: run the summary only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()