|
|
|
|
|
|
|
|
|
|
|
import mlx.core as mx |
|
|
|
|
|
import mlx.nn as nn |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
from mlx_googlenet import GoogLeNet |
|
|
|
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
def _quantize_int8(weight):
    """Quantize one tensor to symmetric, per-tensor INT8.

    Args:
        weight: An ``mx.array`` of a floating dtype (the source checkpoint
            is presumably bf16, per its file name).

    Returns:
        ``(q, scale)`` where ``q`` is an ``int8`` array with values in
        ``[-127, 127]`` and ``scale`` is a ``float16`` scalar array such that
        ``weight ~= q.astype(float32) * scale``.
    """
    # Work in float32: bf16 keeps only 8 significand bits, so dividing in bf16
    # would add error comparable to the INT8 step size itself.
    w = weight.astype(mx.float32)

    # Round the scale to float16 *before* quantizing, so the stored scale is
    # exactly the one the integers were computed with.
    scale = (mx.max(mx.abs(w)) / 127.0).astype(mx.float16)
    # All-zero tensor (or a scale that underflows float16): use 1.0 to avoid a
    # division by zero; every quantized value then rounds to 0.
    scale = mx.where(scale == 0, mx.array(1.0, dtype=mx.float16), scale)

    # Round to nearest (a bare astype truncates toward zero) and clip, because
    # float16 rounding of the scale can push |w / scale| slightly above 127.
    q = mx.clip(mx.round(w / scale.astype(mx.float32)), -127, 127)
    return q.astype(mx.int8), scale


def main(src="googlenet_mlx_bf16.npz", dst="googlenet_mlx_int8.npz"):
    """Convert a GoogLeNet checkpoint into an INT8-weight ``.npz`` file.

    Tensors whose key contains ``"weight"`` and whose rank is >= 2 are stored
    as ``<key>_int8`` (int8) plus ``<key>_scale`` (float16 scalar). Every other
    tensor is stored under its original key as float16.

    Args:
        src: Path of the source checkpoint.
        dst: Path the compressed checkpoint is written to.
    """
    print("--- Attempting Extreme Quantization (4-bit / 8-bit) ---")

    # Sanity check that the checkpoint loads into the architecture before it
    # is rewritten. NOTE(review): assumes GoogLeNet.load_npz validates the
    # keys/shapes — confirm in mlx_googlenet.
    model = GoogLeNet()
    model.load_npz(src)
    del model  # only needed for the check above; free it before quantizing
    print("Original Weights Loaded.")

    print("\nStrategy: Quantize weights to INT8 (Storage Optimization)")

    # mx.load understands MLX's bfloat16 encoding; np.load returns an opaque
    # void dtype for such entries that mx.array cannot convert.
    weights = mx.load(src)

    compressed_state = {}
    total_original = 0
    total_compressed = 0

    for k, v in weights.items():
        if "weight" in k and v.ndim >= 2:
            q, scale = _quantize_int8(v)
            stored = {f"{k}_int8": np.array(q), f"{k}_scale": np.array(scale)}
        else:
            stored = {k: np.array(v.astype(mx.float16))}
        compressed_state.update(stored)
        total_original += v.nbytes
        # Count the bytes actually written, not the source dtype's size.
        total_compressed += sum(a.nbytes for a in stored.values())

    np.savez(dst, **compressed_state)

    mib = 1024 * 1024
    print(f"\n✅ Saved {dst}")
    print(f" Original Size: {total_original / mib:.2f} MB")
    print(f" Quantized Size: {total_compressed / mib:.2f} MB")
    if total_original:
        print(f" Reduction: {100 * (1 - total_compressed / total_original):.1f}%")
    else:
        print(" Reduction: n/a (no tensors found)")


if __name__ == "__main__":
    main()
|
|
|