NickMystic committed on
Commit
bdff6f4
·
verified ·
1 Parent(s): e520ebb

Upload folder using huggingface_hub

Browse files
README.md CHANGED
@@ -162,7 +162,7 @@ python dream.py --input love.jpg \
162
 
163
  ## 💾 Weight Conversion & Efficiency
164
 
165
- We didn't just wrap existing libs. We wrote custom exporters (`export_*.py`) to rip weights from standard PyTorch/Torchvision archives and serialize them into optimized MLX `.npz` arrays.
166
 
167
  ### 50% Smaller Weights (FP16)
168
  We now support **Float16** (Half-Precision) weights by default. This cuts model size in half with zero visual loss for DeepDreaming.
@@ -186,9 +186,8 @@ You need to fine-tune the base model on a new dataset.
186
  **Current Workflow:**
187
  1. Train your model in PyTorch (standard ImageNet training or custom dataset).
188
  2. Save the `.pth` checkpoint.
189
- 3. Modify our `export_*.py` scripts to load your custom checkpoint.
190
- 4. Export to `.npz`.
191
- 5. Dream.
192
 
193
  *A dedicated `train_dream.py` script is on the roadmap.*
194
 
 
162
 
163
  ## 💾 Weight Conversion & Efficiency
164
 
165
+ We didn't just wrap existing libs. We wrote a custom exporter (`export_models.py`) to rip weights from standard PyTorch/Torchvision archives and serialize them into optimized MLX `.npz` arrays.
166
 
167
  ### 50% Smaller Weights (FP16)
168
  We now support **Float16** (Half-Precision) weights by default. This cuts model size in half with zero visual loss for DeepDreaming.
 
186
  **Current Workflow:**
187
  1. Train your model in PyTorch (standard ImageNet training or custom dataset).
188
  2. Save the `.pth` checkpoint.
189
+ 3. Use `export_models.py` to load your custom checkpoint and export to MLX.
190
+ 4. Dream.
 
191
 
192
  *A dedicated `train_dream.py` script is on the roadmap.*
193
 
dream.py CHANGED
@@ -7,10 +7,10 @@ import mlx.core as mx
7
  import mlx.nn as nn
8
  import numpy as np
9
  import scipy.ndimage as nd
10
- from mlx_resnet50 import ResNet50
11
  from PIL import Image
12
 
13
  from mlx_googlenet import GoogLeNet
 
14
  from mlx_vgg16 import VGG16
15
  from mlx_vgg19 import VGG19
16
 
@@ -62,7 +62,7 @@ def gaussian_kernel(sigma, truncate=4.0, fixed_radius=None):
62
  radius = fixed_radius
63
  else:
64
  radius = int(truncate * sigma + 0.5)
65
-
66
  x = mx.arange(-radius, radius + 1)
67
  kernel = mx.exp(-0.5 * (x / sigma) ** 2)
68
  kernel = kernel / kernel.sum()
@@ -75,14 +75,14 @@ def gaussian_blur_2d(x, sigma, fixed_radius=None):
75
  kernel = kernel.astype(x.dtype)
76
  k_size = kernel.shape[0]
77
  C = x.shape[-1]
78
-
79
  k_x = kernel.reshape(1, 1, k_size, 1)
80
  k_x = mx.repeat(k_x, C, axis=0)
81
  k_y = kernel.reshape(1, k_size, 1, 1)
82
  k_y = mx.repeat(k_y, C, axis=0)
83
-
84
  pad = k_size // 2
85
-
86
  x = mx.conv2d(x, k_x, stride=1, padding=(0, pad), groups=C)
87
  x = mx.conv2d(x, k_y, stride=1, padding=(pad, 0), groups=C)
88
  return x
@@ -94,7 +94,7 @@ def smooth_gradients(grad, sigma, fixed_radius=None):
94
  smoothed = []
95
  for s in sigmas:
96
  smoothed.append(gaussian_blur_2d(grad, s, fixed_radius=fixed_radius))
97
-
98
  g_total = smoothed[0]
99
  for i in range(1, len(smoothed)):
100
  g_total = g_total + smoothed[i]
@@ -135,7 +135,7 @@ def deepdream(
135
  if guide_img_np is not None:
136
  guide_resized = resize_bilinear(preprocess(guide_img_np), nh, nw)
137
  _, guide_features = model.forward_with_endpoints(guide_resized)
138
-
139
  def loss_fn(x):
140
  endpoints = model.forward_with_endpoints(x)[1]
141
  loss = mx.zeros(())
@@ -165,35 +165,79 @@ def deepdream(
165
  for it in range(steps):
166
  ox, oy = np.random.randint(-jitter, jitter + 1, 2)
167
  rolled = mx.roll(mx.roll(img, ox, axis=1), oy, axis=2)
168
-
169
  sigma_val = ((it + 1) / steps) * 2.0 + smoothing
170
-
171
  rolled, loss = update_step(rolled, mx.array(sigma_val))
172
-
173
  img = mx.roll(mx.roll(rolled, -ox, axis=1), -oy, axis=2)
174
-
175
  return deprocess(img)
176
 
177
 
178
  def get_weights_path(model_name, explicit_path=None):
 
 
179
  if explicit_path:
 
 
180
  return explicit_path
 
 
181
 
182
- # 1. Try bf16 (Efficient)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  bf16_path = f"{model_name}_mlx_bf16.npz"
 
 
184
  if os.path.exists(bf16_path):
 
 
185
  return bf16_path
 
 
186
 
187
- # 2. Try standard float32
 
 
 
 
188
  fp32_path = f"{model_name}_mlx.npz"
 
 
189
  if os.path.exists(fp32_path):
 
 
190
  return fp32_path
 
 
191
 
192
- return fp32_path # Default fallback for error message
 
 
 
193
 
194
  def run_dream_for_model(model_name, args, img_np):
195
  print(f"--- Running DeepDream with {model_name} ---")
196
-
197
  # ... (PRESETS dict remains here) ...
198
  # Notebook presets
199
  PRESETS = {
@@ -249,7 +293,7 @@ def run_dream_for_model(model_name, args, img_np):
249
  current_scale = p["scale"]
250
  current_jitter = p["jitter"]
251
  current_smoothing = p["smoothing"]
252
-
253
  elif model_name == "vgg19":
254
  model = VGG19()
255
  weights = get_weights_path("vgg19", args.weights)
@@ -263,13 +307,13 @@ def run_dream_for_model(model_name, args, img_np):
263
  current_scale = p["scale"]
264
  current_jitter = p["jitter"]
265
  current_smoothing = p["smoothing"]
266
-
267
  elif model_name == "resnet50":
268
  model = ResNet50()
269
  weights = get_weights_path("resnet50", args.weights)
270
  default_layers = ["layer4_2"]
271
-
272
- else: # googlenet
273
  model = GoogLeNet()
274
  weights = get_weights_path("googlenet", args.weights)
275
  default_layers = ["inception3b", "inception4c", "inception4d"]
@@ -277,7 +321,7 @@ def run_dream_for_model(model_name, args, img_np):
277
  if not os.path.exists(weights):
278
  print(f"Error: Weights NPZ not found: {weights}. Skipping {model_name}.")
279
  return
280
-
281
  print(f"Loading weights from: {weights}")
282
  model.load_npz(weights)
283
 
@@ -301,10 +345,10 @@ def run_dream_for_model(model_name, args, img_np):
301
  smoothing=current_smoothing,
302
  guide_img_np=guide_img_np,
303
  )
304
-
305
  end_time = time.time()
306
  elapsed = end_time - start_time
307
-
308
  if args.output:
309
  out = args.output
310
  else:
@@ -323,10 +367,17 @@ def parse_args():
323
  p.add_argument("--input", required=True, help="Input image path")
324
  p.add_argument("--output", help="Output image path (optional)")
325
  p.add_argument("--guide", help="Guide image for guided dreaming")
326
-
327
- p.add_argument("--width", type=int, default=None, help="Resize input to width (maintains aspect ratio)")
328
- p.add_argument("--img_width", type=int, help="Alias for --width", dest="width") # Alias
329
-
 
 
 
 
 
 
 
330
  p.add_argument(
331
  "--model",
332
  choices=["vgg16", "vgg19", "googlenet", "resnet50", "all"],
@@ -334,25 +385,40 @@ def parse_args():
334
  help="Model to use. 'all' runs all models.",
335
  )
336
  p.add_argument("--preset", choices=["nb14", "nb20", "nb28"], help="VGG16 presets")
337
-
338
  p.add_argument("--layers", nargs="+", help="Layers to maximize")
339
- p.add_argument("--steps", type=int, default=10, help="Gradient ascent steps per octave")
 
 
340
  p.add_argument("--lr", type=float, default=0.09, help="Learning rate (step size)")
341
-
342
  p.add_argument("--octaves", type=int, default=4, help="Number of image octaves")
343
- p.add_argument("--pyramid_size", type=int, dest="octaves", help="Alias for --octaves") # Alias
344
-
 
 
345
  p.add_argument("--scale", type=float, default=1.8, help="Octave scale factor")
346
- p.add_argument("--pyramid_ratio", type=float, dest="scale", help="Alias for --scale") # Alias
347
- p.add_argument("--octave_scale", type=float, dest="scale", help="Alias for --scale") # Alias
348
-
 
 
 
 
349
  p.add_argument("--jitter", type=int, default=32, help="Jitter amount (pixels)")
350
-
351
- p.add_argument("--smoothing", type=float, default=0.5, help="Gradient smoothing strength")
352
- p.add_argument("--smoothing_coefficient", type=float, dest="smoothing", help="Alias for --smoothing") # Alias
353
-
 
 
 
 
 
 
 
354
  p.add_argument("--weights", help="Custom weights path")
355
-
356
  return p.parse_args()
357
 
358
 
@@ -360,11 +426,13 @@ def main():
360
  args = parse_args()
361
  img_np = load_image(args.input, args.width)
362
 
363
- if args.model == 'all':
364
  models = ["vgg16", "vgg19", "googlenet", "resnet50"]
365
  if args.output:
366
- print("Warning: --output argument ignored because --model='all' was selected.")
367
- args.output = None
 
 
368
  for m in models:
369
  run_dream_for_model(m, args, img_np)
370
  else:
@@ -372,4 +440,4 @@ def main():
372
 
373
 
374
  if __name__ == "__main__":
375
- main()
 
7
  import mlx.nn as nn
8
  import numpy as np
9
  import scipy.ndimage as nd
 
10
  from PIL import Image
11
 
12
  from mlx_googlenet import GoogLeNet
13
+ from mlx_resnet50 import ResNet50
14
  from mlx_vgg16 import VGG16
15
  from mlx_vgg19 import VGG19
16
 
 
62
  radius = fixed_radius
63
  else:
64
  radius = int(truncate * sigma + 0.5)
65
+
66
  x = mx.arange(-radius, radius + 1)
67
  kernel = mx.exp(-0.5 * (x / sigma) ** 2)
68
  kernel = kernel / kernel.sum()
 
75
  kernel = kernel.astype(x.dtype)
76
  k_size = kernel.shape[0]
77
  C = x.shape[-1]
78
+
79
  k_x = kernel.reshape(1, 1, k_size, 1)
80
  k_x = mx.repeat(k_x, C, axis=0)
81
  k_y = kernel.reshape(1, k_size, 1, 1)
82
  k_y = mx.repeat(k_y, C, axis=0)
83
+
84
  pad = k_size // 2
85
+
86
  x = mx.conv2d(x, k_x, stride=1, padding=(0, pad), groups=C)
87
  x = mx.conv2d(x, k_y, stride=1, padding=(pad, 0), groups=C)
88
  return x
 
94
  smoothed = []
95
  for s in sigmas:
96
  smoothed.append(gaussian_blur_2d(grad, s, fixed_radius=fixed_radius))
97
+
98
  g_total = smoothed[0]
99
  for i in range(1, len(smoothed)):
100
  g_total = g_total + smoothed[i]
 
135
  if guide_img_np is not None:
136
  guide_resized = resize_bilinear(preprocess(guide_img_np), nh, nw)
137
  _, guide_features = model.forward_with_endpoints(guide_resized)
138
+
139
  def loss_fn(x):
140
  endpoints = model.forward_with_endpoints(x)[1]
141
  loss = mx.zeros(())
 
165
  for it in range(steps):
166
  ox, oy = np.random.randint(-jitter, jitter + 1, 2)
167
  rolled = mx.roll(mx.roll(img, ox, axis=1), oy, axis=2)
168
+
169
  sigma_val = ((it + 1) / steps) * 2.0 + smoothing
170
+
171
  rolled, loss = update_step(rolled, mx.array(sigma_val))
172
+
173
  img = mx.roll(mx.roll(rolled, -ox, axis=1), -oy, axis=2)
174
+
175
  return deprocess(img)
176
 
177
 
178
  def get_weights_path(model_name, explicit_path=None):
179
+
180
+
181
  if explicit_path:
182
+
183
+
184
  return explicit_path
185
+
186
+
187
 
188
+
189
+
190
+ # 1. Try int8 (Maximum Efficiency / Smallest)
191
+
192
+
193
+ int8_path = f"{model_name}_mlx_int8.npz"
194
+
195
+
196
+ if os.path.exists(int8_path):
197
+
198
+
199
+ return int8_path
200
+
201
+
202
+
203
+
204
+
205
+ # 2. Try bf16 (Standard Efficient)
206
+
207
+
208
  bf16_path = f"{model_name}_mlx_bf16.npz"
209
+
210
+
211
  if os.path.exists(bf16_path):
212
+
213
+
214
  return bf16_path
215
+
216
+
217
 
218
+
219
+
220
+ # 3. Try standard float32
221
+
222
+
223
  fp32_path = f"{model_name}_mlx.npz"
224
+
225
+
226
  if os.path.exists(fp32_path):
227
+
228
+
229
  return fp32_path
230
+
231
+
232
 
233
+
234
+
235
+ return int8_path # Return preferred default for error message context
236
+
237
 
238
  def run_dream_for_model(model_name, args, img_np):
239
  print(f"--- Running DeepDream with {model_name} ---")
240
+
241
  # ... (PRESETS dict remains here) ...
242
  # Notebook presets
243
  PRESETS = {
 
293
  current_scale = p["scale"]
294
  current_jitter = p["jitter"]
295
  current_smoothing = p["smoothing"]
296
+
297
  elif model_name == "vgg19":
298
  model = VGG19()
299
  weights = get_weights_path("vgg19", args.weights)
 
307
  current_scale = p["scale"]
308
  current_jitter = p["jitter"]
309
  current_smoothing = p["smoothing"]
310
+
311
  elif model_name == "resnet50":
312
  model = ResNet50()
313
  weights = get_weights_path("resnet50", args.weights)
314
  default_layers = ["layer4_2"]
315
+
316
+ else: # googlenet
317
  model = GoogLeNet()
318
  weights = get_weights_path("googlenet", args.weights)
319
  default_layers = ["inception3b", "inception4c", "inception4d"]
 
321
  if not os.path.exists(weights):
322
  print(f"Error: Weights NPZ not found: {weights}. Skipping {model_name}.")
323
  return
324
+
325
  print(f"Loading weights from: {weights}")
326
  model.load_npz(weights)
327
 
 
345
  smoothing=current_smoothing,
346
  guide_img_np=guide_img_np,
347
  )
348
+
349
  end_time = time.time()
350
  elapsed = end_time - start_time
351
+
352
  if args.output:
353
  out = args.output
354
  else:
 
367
  p.add_argument("--input", required=True, help="Input image path")
368
  p.add_argument("--output", help="Output image path (optional)")
369
  p.add_argument("--guide", help="Guide image for guided dreaming")
370
+
371
+ p.add_argument(
372
+ "--width",
373
+ type=int,
374
+ default=None,
375
+ help="Resize input to width (maintains aspect ratio)",
376
+ )
377
+ p.add_argument(
378
+ "--img_width", type=int, help="Alias for --width", dest="width"
379
+ ) # Alias
380
+
381
  p.add_argument(
382
  "--model",
383
  choices=["vgg16", "vgg19", "googlenet", "resnet50", "all"],
 
385
  help="Model to use. 'all' runs all models.",
386
  )
387
  p.add_argument("--preset", choices=["nb14", "nb20", "nb28"], help="VGG16 presets")
388
+
389
  p.add_argument("--layers", nargs="+", help="Layers to maximize")
390
+ p.add_argument(
391
+ "--steps", type=int, default=10, help="Gradient ascent steps per octave"
392
+ )
393
  p.add_argument("--lr", type=float, default=0.09, help="Learning rate (step size)")
394
+
395
  p.add_argument("--octaves", type=int, default=4, help="Number of image octaves")
396
+ p.add_argument(
397
+ "--pyramid_size", type=int, dest="octaves", help="Alias for --octaves"
398
+ ) # Alias
399
+
400
  p.add_argument("--scale", type=float, default=1.8, help="Octave scale factor")
401
+ p.add_argument(
402
+ "--pyramid_ratio", type=float, dest="scale", help="Alias for --scale"
403
+ ) # Alias
404
+ p.add_argument(
405
+ "--octave_scale", type=float, dest="scale", help="Alias for --scale"
406
+ ) # Alias
407
+
408
  p.add_argument("--jitter", type=int, default=32, help="Jitter amount (pixels)")
409
+
410
+ p.add_argument(
411
+ "--smoothing", type=float, default=0.5, help="Gradient smoothing strength"
412
+ )
413
+ p.add_argument(
414
+ "--smoothing_coefficient",
415
+ type=float,
416
+ dest="smoothing",
417
+ help="Alias for --smoothing",
418
+ ) # Alias
419
+
420
  p.add_argument("--weights", help="Custom weights path")
421
+
422
  return p.parse_args()
423
 
424
 
 
426
  args = parse_args()
427
  img_np = load_image(args.input, args.width)
428
 
429
+ if args.model == "all":
430
  models = ["vgg16", "vgg19", "googlenet", "resnet50"]
431
  if args.output:
432
+ print(
433
+ "Warning: --output argument ignored because --model='all' was selected."
434
+ )
435
+ args.output = None
436
  for m in models:
437
  run_dream_for_model(m, args, img_np)
438
  else:
 
440
 
441
 
442
  if __name__ == "__main__":
443
+ main()
export_models.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unified export script for converting PyTorch models to MLX .npz format.
3
+ Supports VGG16, VGG19, GoogLeNet, and ResNet50.
4
+ Handles float32, float16 (the command-line default; `bf16` is also written as float16), and int8 exports.
5
+
6
+ Usage:
7
+ python export_models.py --model all --dtype float16
8
+ python export_models.py --model vgg16
9
+ """
10
+
11
+ import argparse
12
+ import os
13
+ import numpy as np
14
+ import torch
15
+ import torchvision.models as models
16
+
17
+ def get_model_info(model_name):
18
+ if model_name == "vgg16":
19
+ return models.vgg16, models.VGG16_Weights.IMAGENET1K_V1
20
+ elif model_name == "vgg19":
21
+ return models.vgg19, models.VGG19_Weights.IMAGENET1K_V1
22
+ elif model_name == "googlenet":
23
+ return models.googlenet, models.GoogLeNet_Weights.IMAGENET1K_V1
24
+ elif model_name == "resnet50":
25
+ return models.resnet50, models.ResNet50_Weights.IMAGENET1K_V1
26
+ else:
27
+ raise ValueError(f"Unknown model: {model_name}")
28
+
29
+ def export_model(model_name, dtype="float32"):
30
+ print(f"Exporting {model_name} ({dtype})...")
31
+ model_fn, weights = get_model_info(model_name)
32
+ model = model_fn(weights=weights)
33
+ model.eval()
34
+
35
+ state = model.state_dict()
36
+ converted_state = {}
37
+
38
+ target_type = np.float32
39
+ suffix = ""
40
+ quantize_int8 = False
41
+
42
+ if dtype in ["float16", "bf16", "half"]:
43
+ target_type = np.float16
44
+ suffix = "_bf16" # Keep legacy suffix for compatibility with dream.py logic
45
+ elif dtype == "int8":
46
+ target_type = np.float16 # Base type for scales/biases
47
+ suffix = "_int8"
48
+ quantize_int8 = True
49
+
50
+ for k, v in state.items():
51
+ v_np = v.cpu().detach().numpy()
52
+
53
+ if quantize_int8 and "weight" in k and v_np.ndim >= 2:
54
+ # Quantize to INT8
55
+ v_abs = np.abs(v_np)
56
+ v_max = np.max(v_abs)
57
+
58
+ # Scale to range [-127, 127]
59
+ # Avoid div by zero
60
+ if v_max == 0:
61
+ scale = 1.0
62
+ else:
63
+ scale = v_max / 127.0
64
+
65
+ v_int8 = (v_np / scale).astype(np.int8)
66
+
67
+ converted_state[f"{k}_int8"] = v_int8
68
+ converted_state[f"{k}_scale"] = np.array(scale).astype(target_type)
69
+ else:
70
+ converted_state[k] = v_np.astype(target_type)
71
+
72
+ out_name = f"{model_name}_mlx{suffix}.npz"
73
+ np.savez(out_name, **converted_state)
74
+
75
+ original_size = sum(v.numel() * 4 for v in state.values()) / (1024*1024)
76
+ new_size = os.path.getsize(out_name) / (1024*1024)
77
+
78
+ print(f"✅ Saved {out_name}")
79
+ print(f" Size: {new_size:.1f} MB (Original: ~{original_size:.1f} MB)")
80
+
81
+ def main():
82
+ parser = argparse.ArgumentParser(description="Export PyTorch models to MLX")
83
+ parser.add_argument("--model", choices=["vgg16", "vgg19", "googlenet", "resnet50", "all"], default="all")
84
+ parser.add_argument("--dtype", choices=["float32", "float16", "bf16", "int8"], default="float16", help="Output data type")
85
+ args = parser.parse_args()
86
+
87
+ models_to_export = ["vgg16", "vgg19", "googlenet", "resnet50"] if args.model == "all" else [args.model]
88
+
89
+ for m in models_to_export:
90
+ export_model(m, args.dtype)
91
+
92
+ if __name__ == "__main__":
93
+ main()
googlenet_mlx_int8.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0fb7a656a2a69cfbbd42804d38e475d09eade67681fb813b9e8f78f1930da22
3
+ size 6791204
mlx_googlenet.py CHANGED
@@ -110,19 +110,36 @@ class GoogLeNet(nn.Module):
110
  def load_npz(self, path: str):
111
  data = np.load(path)
112
 
113
- def to_mlx_weight(w):
114
- # PyTorch Conv2d weights are (out_channels, in_channels, kH, kW)
115
- # MLX expects channel-last filters: (out_channels, kH, kW, in_channels)
116
- return np.transpose(w, (0, 2, 3, 1)) if w.ndim == 4 else w
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  def load_conv_bn(prefix, seq_mod: nn.Sequential):
119
  conv = seq_mod.layers[0]
120
  bn = seq_mod.layers[1]
121
- conv.weight = mx.array(to_mlx_weight(data[f"{prefix}.conv.weight"]))
122
- bn.weight = mx.array(data[f"{prefix}.bn.weight"])
123
- bn.bias = mx.array(data[f"{prefix}.bn.bias"])
124
- bn.running_mean = mx.array(data[f"{prefix}.bn.running_mean"])
125
- bn.running_var = mx.array(data[f"{prefix}.bn.running_var"])
 
 
126
 
127
  load_conv_bn("conv1", self.conv1)
128
  load_conv_bn("conv2", self.conv2)
 
110
  def load_npz(self, path: str):
111
  data = np.load(path)
112
 
113
+ def load_weight(key, target_module, param_name="weight", transpose=False):
114
+ # Check for standard float16/32 key
115
+ if key in data:
116
+ w = data[key]
117
+ # Check for int8 quantized key
118
+ elif f"{key}_int8" in data:
119
+ w_int8 = data[f"{key}_int8"]
120
+ scale = data[f"{key}_scale"]
121
+ # Dequantize
122
+ w = w_int8.astype(scale.dtype) * scale
123
+ else:
124
+ raise ValueError(f"Missing key {key} (or {key}_int8) in npz")
125
+
126
+ # Transpose for Conv2d weights if needed (PyTorch [O,I,H,W] -> MLX [O,H,W,I])
127
+ if transpose and w.ndim == 4:
128
+ w = np.transpose(w, (0, 2, 3, 1))
129
+
130
+ # Assign to module
131
+ target_module[param_name] = mx.array(w)
132
 
133
  def load_conv_bn(prefix, seq_mod: nn.Sequential):
134
  conv = seq_mod.layers[0]
135
  bn = seq_mod.layers[1]
136
+
137
+ load_weight(f"{prefix}.conv.weight", conv, transpose=True)
138
+
139
+ load_weight(f"{prefix}.bn.weight", bn)
140
+ load_weight(f"{prefix}.bn.bias", bn, param_name="bias")
141
+ load_weight(f"{prefix}.bn.running_mean", bn, param_name="running_mean")
142
+ load_weight(f"{prefix}.bn.running_var", bn, param_name="running_var")
143
 
144
  load_conv_bn("conv1", self.conv1)
145
  load_conv_bn("conv2", self.conv2)
mlx_resnet50.py CHANGED
@@ -114,17 +114,28 @@ class ResNet(nn.Module):
114
  def load_npz(self, path: str):
115
  data = np.load(path)
116
 
117
- def to_mlx_weight(w):
118
- return np.transpose(w, (0, 2, 3, 1)) if w.ndim == 4 else w
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  def load_bn(prefix, bn):
121
- bn.weight = mx.array(data[f"{prefix}.weight"])
122
- bn.bias = mx.array(data[f"{prefix}.bias"])
123
- bn.running_mean = mx.array(data[f"{prefix}.running_mean"])
124
- bn.running_var = mx.array(data[f"{prefix}.running_var"])
125
 
126
  def load_conv(prefix, conv):
127
- conv.weight = mx.array(to_mlx_weight(data[f"{prefix}.weight"]))
128
 
129
  # Initial layers
130
  load_conv("conv1", self.conv1)
 
114
  def load_npz(self, path: str):
115
  data = np.load(path)
116
 
117
+ def load_weight(key, transpose=False):
118
+ if key in data:
119
+ w = data[key]
120
+ elif f"{key}_int8" in data:
121
+ w_int8 = data[f"{key}_int8"]
122
+ scale = data[f"{key}_scale"]
123
+ w = w_int8.astype(scale.dtype) * scale
124
+ else:
125
+ raise ValueError(f"Missing key {key} in npz")
126
+
127
+ if transpose and w.ndim == 4:
128
+ w = np.transpose(w, (0, 2, 3, 1))
129
+ return mx.array(w)
130
 
131
  def load_bn(prefix, bn):
132
+ bn.weight = load_weight(f"{prefix}.weight")
133
+ bn.bias = load_weight(f"{prefix}.bias")
134
+ bn.running_mean = load_weight(f"{prefix}.running_mean")
135
+ bn.running_var = load_weight(f"{prefix}.running_var")
136
 
137
  def load_conv(prefix, conv):
138
+ conv.weight = load_weight(f"{prefix}.weight", transpose=True)
139
 
140
  # Initial layers
141
  load_conv("conv1", self.conv1)
mlx_vgg16.py CHANGED
@@ -79,13 +79,25 @@ class VGG16(nn.Module):
79
  def load_npz(self, path: str):
80
  data = np.load(path)
81
 
82
- def to_mlx_weight(w):
83
- return np.transpose(w, (0, 2, 3, 1)) if w.ndim == 4 else w
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  conv_indices = [0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28]
86
  for idx in conv_indices:
87
  conv = self.layers[idx]
88
  weight_key = f"features.{idx}.weight"
89
  bias_key = f"features.{idx}.bias"
90
- conv.weight = mx.array(to_mlx_weight(data[weight_key]))
91
- conv.bias = mx.array(data[bias_key])
 
 
79
  def load_npz(self, path: str):
80
  data = np.load(path)
81
 
82
+ def load_weight(key, transpose=False):
83
+ if key in data:
84
+ w = data[key]
85
+ elif f"{key}_int8" in data:
86
+ w_int8 = data[f"{key}_int8"]
87
+ scale = data[f"{key}_scale"]
88
+ w = w_int8.astype(scale.dtype) * scale
89
+ else:
90
+ raise ValueError(f"Missing key {key} in npz")
91
+
92
+ if transpose and w.ndim == 4:
93
+ w = np.transpose(w, (0, 2, 3, 1))
94
+ return mx.array(w)
95
 
96
  conv_indices = [0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28]
97
  for idx in conv_indices:
98
  conv = self.layers[idx]
99
  weight_key = f"features.{idx}.weight"
100
  bias_key = f"features.{idx}.bias"
101
+
102
+ conv.weight = load_weight(weight_key, transpose=True)
103
+ conv.bias = load_weight(bias_key)
mlx_vgg19.py CHANGED
@@ -92,13 +92,25 @@ class VGG19(nn.Module):
92
  def load_npz(self, path: str):
93
  data = np.load(path)
94
 
95
- def to_mlx_weight(w):
96
- return np.transpose(w, (0, 2, 3, 1)) if w.ndim == 4 else w
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  conv_indices = [0, 2, 5, 7, 10, 12, 14, 16, 19, 21, 23, 25, 28, 30, 32, 34]
99
  for idx in conv_indices:
100
  conv = self.layers[idx]
101
  weight_key = f"features.{idx}.weight"
102
  bias_key = f"features.{idx}.bias"
103
- conv.weight = mx.array(to_mlx_weight(data[weight_key]))
104
- conv.bias = mx.array(data[bias_key])
 
 
92
  def load_npz(self, path: str):
93
  data = np.load(path)
94
 
95
+ def load_weight(key, transpose=False):
96
+ if key in data:
97
+ w = data[key]
98
+ elif f"{key}_int8" in data:
99
+ w_int8 = data[f"{key}_int8"]
100
+ scale = data[f"{key}_scale"]
101
+ w = w_int8.astype(scale.dtype) * scale
102
+ else:
103
+ raise ValueError(f"Missing key {key} in npz")
104
+
105
+ if transpose and w.ndim == 4:
106
+ w = np.transpose(w, (0, 2, 3, 1))
107
+ return mx.array(w)
108
 
109
  conv_indices = [0, 2, 5, 7, 10, 12, 14, 16, 19, 21, 23, 25, 28, 30, 32, 34]
110
  for idx in conv_indices:
111
  conv = self.layers[idx]
112
  weight_key = f"features.{idx}.weight"
113
  bias_key = f"features.{idx}.bias"
114
+
115
+ conv.weight = load_weight(weight_key, transpose=True)
116
+ conv.bias = load_weight(bias_key)
quantize_experiment.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+ import mlx.core as mx
5
+
6
+ import mlx.nn as nn
7
+
8
+ import numpy as np
9
+
10
+ from mlx_googlenet import GoogLeNet
11
+
12
+ import os
13
+
14
+
15
+
16
+ def main():
17
+
18
+ print("--- Attempting Extreme Quantization (4-bit / 8-bit) ---")
19
+
20
+
21
+
22
+ # Load standard model
23
+
24
+ model = GoogLeNet()
25
+
26
+ model.load_npz("googlenet_mlx_bf16.npz")
27
+
28
+
29
+
30
+ print("Original Weights Loaded.")
31
+
32
+
33
+
34
+ print("\nStrategy: Quantize weights to INT8 (Storage Optimization)")
35
+
36
+ # We will effectively store weights as (int8_weight, float16_scale)
37
+
38
+ # On load, we will do: weight = int8_weight.astype(fp16) * scale
39
+
40
+
41
+
42
+ state = model.parameters()
43
+
44
+ compressed_state = {}
45
+
46
+
47
+
48
+ total_original = 0
49
+
50
+ total_compressed = 0
51
+
52
+
53
+
54
+ for k, v in state.items():
55
+
56
+ # Flatten keys for parameters() which returns nested dicts if using trees,
57
+
58
+ # but model.parameters() returns nested dict of arrays?
59
+
60
+ # No, mlx model.parameters() returns a dict of {name: array} if flattened?
61
+
62
+ # Actually model.parameters() returns a generator or dict?
63
+
64
+ # model.parameters() returns a dict of arrays recursively?
65
+
66
+ # Let's use flatten logic manually or just iterate what we have.
67
+
68
+ pass
69
+
70
+
71
+
72
+ # Actually model.state_dict() is better for flat keys
73
+
74
+ # Wait, MLX doesn't have state_dict() like PyTorch exactly?
75
+
76
+ # mlx.nn.utils.tree_flatten(model.parameters()) gives list.
77
+
78
+
79
+
80
+ # Let's assume we work on the flattened dict structure we used for saving npz
81
+
82
+ # Our export script did: np.savez(out, **{k: v})
83
+
84
+ # Our load_npz in models does: data[key]
85
+
86
+
87
+
88
+ # So we should load the .npz FILE directly and process it,
89
+
90
+ # rather than traversing the model object which might be complex.
91
+
92
+
93
+
94
+ data = np.load("googlenet_mlx_bf16.npz")
95
+
96
+
97
+
98
+ for k in data.files:
99
+
100
+ v = mx.array(data[k])
101
+
102
+
103
+
104
+ # Check if it's a weight (conv or linear)
105
+
106
+ # Heuristic: name contains "weight" and ndim >= 2
107
+
108
+ if "weight" in k and v.ndim >= 2:
109
+
110
+ # Quantize to INT8
111
+
112
+ v_abs = mx.abs(v)
113
+
114
+ v_max = mx.max(v_abs)
115
+
116
+
117
+
118
+ # Scale to range [-127, 127]
119
+
120
+ # Avoid div by zero
121
+
122
+ scale = v_max / 127.0
123
+
124
+ scale = mx.where(scale == 0, 1.0, scale)
125
+
126
+
127
+
128
+ v_int8 = (v / scale).astype(mx.int8)
129
+
130
+
131
+
132
+ # Save components
133
+
134
+ compressed_state[f"{k}_int8"] = np.array(v_int8)
135
+
136
+ compressed_state[f"{k}_scale"] = np.array(scale.astype(mx.float16))
137
+
138
+
139
+
140
+ original_bytes = v.nbytes
141
+
142
+ new_bytes = v_int8.nbytes + 2 # scale size
143
+
144
+
145
+
146
+ total_original += original_bytes
147
+
148
+ total_compressed += new_bytes
149
+
150
+
151
+
152
+ else:
153
+
154
+ # Save as is (float16)
155
+
156
+ compressed_state[k] = np.array(v.astype(mx.float16))
157
+
158
+ total_original += v.nbytes
159
+
160
+ total_compressed += v.nbytes
161
+
162
+
163
+
164
+ out_name = "googlenet_mlx_int8.npz"
165
+
166
+ np.savez(out_name, **compressed_state)
167
+
168
+
169
+
170
+ print(f"\n✅ Saved {out_name}")
171
+
172
+ print(f" Original Size: {total_original / (1024*1024):.2f} MB")
173
+
174
+ print(f" Quantized Size: {total_compressed / (1024*1024):.2f} MB")
175
+
176
+ print(f" Reduction: {100 * (1 - total_compressed/total_original):.1f}%")
177
+
178
+
179
+
180
+ if __name__ == "__main__":
181
+
182
+ main()
resnet50_mlx_int8.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1ab804f8257e78f03244ea033cdd55ed6b285317cf444c04234b3ce1d0e3961
3
+ size 25822834
vgg16_mlx_int8.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:17f8012268ac3cb74fd3c8ce5d243970b13141492b2e0e84fab1924a786ec25f
3
+ size 138384160
vgg19_mlx_int8.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13309dd1b75cf316b0025db0c5791d6a89c654145a5f3a486f488b4bcd822b93
3
+ size 143697608