import mlx.core as mx
import numpy as np
from mlx_googlenet import GoogLeNet
def main():
print("--- Attempting Extreme Quantization (4-bit / 8-bit) ---")
# Load standard model
model = GoogLeNet()
model.load_npz("googlenet_mlx_bf16.npz")
print("Original Weights Loaded.")
print("\nStrategy: Quantize weights to INT8 (Storage Optimization)")
# We will effectively store weights as (int8_weight, float16_scale)
# On load, we will do: weight = int8_weight.astype(fp16) * scale
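
    # Tiny worked example of the scheme (illustration only; the values are
    # hypothetical, not taken from the checkpoint):
    demo_w = mx.array([0.5, -1.0, 0.25])
    demo_s = mx.max(mx.abs(demo_w)) / 127.0  # per-tensor symmetric scale
    demo_q = mx.round(demo_w / demo_s).astype(mx.int8)
    print("Demo round-trip:", (demo_q.astype(mx.float32) * demo_s).tolist())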
    # The export script saved a flat dict via np.savez(out, **{k: v}), while
    # model.parameters() returns a nested tree that would have to be
    # flattened (e.g. with mlx.utils.tree_flatten) to match those keys.
    # It is simpler to load the .npz file directly and process its flat keys.
    data = np.load("googlenet_mlx_bf16.npz")

    compressed_state = {}
    total_original = 0
    total_compressed = 0
    for k in data.files:
        v = mx.array(data[k])
        # Heuristic for conv / linear weights: the parameter name contains
        # "weight" and the tensor has at least 2 dimensions.
        if "weight" in k and v.ndim >= 2:
            # Quantize to INT8 with a per-tensor symmetric scale that maps
            # [-max|v|, max|v|] onto [-127, 127]; guard against an all-zero
            # tensor (division by zero).
            v_max = mx.max(mx.abs(v))
            scale = v_max / 127.0
            scale = mx.where(scale == 0, 1.0, scale)
            # Round to nearest and clip before casting: a bare astype()
            # truncates toward zero, adding up to one full step of error.
            v_int8 = mx.clip(mx.round(v / scale), -127, 127).astype(mx.int8)
            # Save the pair (int8 weights, float16 scale)
            compressed_state[f"{k}_int8"] = np.array(v_int8)
            compressed_state[f"{k}_scale"] = np.array(scale.astype(mx.float16))
            original_bytes = v.nbytes
            new_bytes = v_int8.nbytes + 2  # one float16 scale per tensor
            total_original += original_bytes
            total_compressed += new_bytes
        else:
            # Not a quantization target (biases, BN params, 1-D tensors):
            # store as float16, and count the float16 size on the compressed
            # side so the totals reflect what is actually written.
            v_fp16 = v.astype(mx.float16)
            compressed_state[k] = np.array(v_fp16)
            total_original += v.nbytes
            total_compressed += v_fp16.nbytes
out_name = "googlenet_mlx_int8.npz"
np.savez(out_name, **compressed_state)
print(f"\n✅ Saved {out_name}")
print(f" Original Size: {total_original / (1024*1024):.2f} MB")
print(f" Quantized Size: {total_compressed / (1024*1024):.2f} MB")
print(f" Reduction: {100 * (1 - total_compressed/total_original):.1f}%")
if __name__ == "__main__":
    main()