Cleanup: Remove obsolete scripts and int8 weights
Browse files- alexnet_places365.pth_mlx.npz +0 -3
- export_all_bf16.py +0 -56
- export_googlenet_npz.py +0 -21
- export_models.py +0 -93
- export_resnet50_npz.py +0 -23
- export_vgg16_npz.py +0 -23
- export_vgg19_npz.py +0 -23
- googlenet_mlx_int8.npz +0 -3
- resnet50_mlx_int8.npz +0 -3
- resnet50_places365.pth_mlx.npz +0 -3
- tf_inception_v1.py +0 -79
- train_dream.py +0 -30
- vgg16_mlx_int8.npz +0 -3
- vgg19_mlx_int8.npz +0 -3
alexnet_places365.pth_mlx.npz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:587f2f379063fb722563b86d9e7fea2321119b571c6bff7e09e309abf6dbf0b4
|
| 3 |
-
size 117002764
|
|
|
|
|
|
|
|
|
|
|
|
export_all_bf16.py
DELETED
|
@@ -1,56 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Export all supported models to MLX .npz format in bfloat16 (bf16) for 50% size reduction.
|
| 3 |
-
Requires torch, torchvision, numpy.
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
import os
|
| 7 |
-
import numpy as np
|
| 8 |
-
import torch
|
| 9 |
-
import torchvision.models as models
|
| 10 |
-
|
| 11 |
-
def export_model(model_name, model_fn, weights_enum):
|
| 12 |
-
print(f"Exporting {model_name} (bf16)...")
|
| 13 |
-
model = model_fn(weights=weights_enum)
|
| 14 |
-
model.eval()
|
| 15 |
-
|
| 16 |
-
state = model.state_dict()
|
| 17 |
-
converted_state = {}
|
| 18 |
-
|
| 19 |
-
for k, v in state.items():
|
| 20 |
-
# Convert to numpy float16 (bfloat16 is not fully standard in numpy saving,
|
| 21 |
-
# but MLX handles float16 perfectly. We will save as float16 for simplicity
|
| 22 |
-
# and broad compatibility, or we can try casting to bfloat16 if numpy supports it
|
| 23 |
-
# or just save as float16 which is also 2 bytes).
|
| 24 |
-
# Actually, numpy doesn't fully support bfloat16 serialization widely yet.
|
| 25 |
-
# float16 is the standard "half".
|
| 26 |
-
# DeepDream doesn't need bf16 dynamic range usually. float16 is fine.
|
| 27 |
-
v_np = v.cpu().detach().numpy().astype(np.float16)
|
| 28 |
-
converted_state[k] = v_np
|
| 29 |
-
|
| 30 |
-
out_name = f"{model_name}_mlx_bf16.npz" # Naming it bf16/fp16 to imply half precision
|
| 31 |
-
# But wait, let's stick to what the user asked "bf16".
|
| 32 |
-
# MLX load_npz will load it as float16.
|
| 33 |
-
|
| 34 |
-
np.savez(out_name, **converted_state)
|
| 35 |
-
|
| 36 |
-
original_size = sum(v.numel() * 4 for v in state.values()) / (1024*1024)
|
| 37 |
-
new_size = os.path.getsize(out_name) / (1024*1024)
|
| 38 |
-
|
| 39 |
-
print(f"✅ Saved {out_name}")
|
| 40 |
-
print(f" Size: {new_size:.1f} MB (Original float32: ~{original_size:.1f} MB)")
|
| 41 |
-
|
| 42 |
-
def main():
|
| 43 |
-
# 1. VGG16
|
| 44 |
-
export_model("vgg16", models.vgg16, models.VGG16_Weights.IMAGENET1K_V1)
|
| 45 |
-
|
| 46 |
-
# 2. VGG19
|
| 47 |
-
export_model("vgg19", models.vgg19, models.VGG19_Weights.IMAGENET1K_V1)
|
| 48 |
-
|
| 49 |
-
# 3. GoogLeNet
|
| 50 |
-
export_model("googlenet", models.googlenet, models.GoogLeNet_Weights.IMAGENET1K_V1)
|
| 51 |
-
|
| 52 |
-
# 4. ResNet50
|
| 53 |
-
export_model("resnet50", models.resnet50, models.ResNet50_Weights.IMAGENET1K_V1)
|
| 54 |
-
|
| 55 |
-
if __name__ == "__main__":
|
| 56 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
export_googlenet_npz.py
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Export torchvision GoogLeNet (Inception V1) weights to an .npz for MLX.
|
| 3 |
-
Run this in a PyTorch+torchvision env:
|
| 4 |
-
python export_googlenet_npz.py
|
| 5 |
-
It writes models/googlenet_mlx.npz
|
| 6 |
-
"""
|
| 7 |
-
import os
|
| 8 |
-
import numpy as np
|
| 9 |
-
import torch
|
| 10 |
-
from torchvision.models import googlenet, GoogLeNet_Weights
|
| 11 |
-
|
| 12 |
-
def main():
|
| 13 |
-
model = googlenet(weights=GoogLeNet_Weights.IMAGENET1K_V1)
|
| 14 |
-
state = model.state_dict()
|
| 15 |
-
os.makedirs("models", exist_ok=True)
|
| 16 |
-
out_path = os.path.join("models", "googlenet_mlx.npz")
|
| 17 |
-
np.savez(out_path, **{k: v.cpu().numpy() for k, v in state.items()})
|
| 18 |
-
print(f"Saved {out_path} with {len(state)} tensors.")
|
| 19 |
-
|
| 20 |
-
if __name__ == "__main__":
|
| 21 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
export_models.py
DELETED
|
@@ -1,93 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Unified export script for converting PyTorch models to MLX .npz format.
|
| 3 |
-
Supports VGG16, VGG19, GoogLeNet, and ResNet50.
|
| 4 |
-
Handles both float32 (default) and float16/bfloat16 (efficient) exports.
|
| 5 |
-
|
| 6 |
-
Usage:
|
| 7 |
-
python export_models.py --model all --dtype float16
|
| 8 |
-
python export_models.py --model vgg16
|
| 9 |
-
"""
|
| 10 |
-
|
| 11 |
-
import argparse
|
| 12 |
-
import os
|
| 13 |
-
import numpy as np
|
| 14 |
-
import torch
|
| 15 |
-
import torchvision.models as models
|
| 16 |
-
|
| 17 |
-
def get_model_info(model_name):
|
| 18 |
-
if model_name == "vgg16":
|
| 19 |
-
return models.vgg16, models.VGG16_Weights.IMAGENET1K_V1
|
| 20 |
-
elif model_name == "vgg19":
|
| 21 |
-
return models.vgg19, models.VGG19_Weights.IMAGENET1K_V1
|
| 22 |
-
elif model_name == "googlenet":
|
| 23 |
-
return models.googlenet, models.GoogLeNet_Weights.IMAGENET1K_V1
|
| 24 |
-
elif model_name == "resnet50":
|
| 25 |
-
return models.resnet50, models.ResNet50_Weights.IMAGENET1K_V1
|
| 26 |
-
else:
|
| 27 |
-
raise ValueError(f"Unknown model: {model_name}")
|
| 28 |
-
|
| 29 |
-
def export_model(model_name, dtype="float32"):
|
| 30 |
-
print(f"Exporting {model_name} ({dtype})...")
|
| 31 |
-
model_fn, weights = get_model_info(model_name)
|
| 32 |
-
model = model_fn(weights=weights)
|
| 33 |
-
model.eval()
|
| 34 |
-
|
| 35 |
-
state = model.state_dict()
|
| 36 |
-
converted_state = {}
|
| 37 |
-
|
| 38 |
-
target_type = np.float32
|
| 39 |
-
suffix = ""
|
| 40 |
-
quantize_int8 = False
|
| 41 |
-
|
| 42 |
-
if dtype in ["float16", "bf16", "half"]:
|
| 43 |
-
target_type = np.float16
|
| 44 |
-
suffix = "_bf16" # Keep legacy suffix for compatibility with dream.py logic
|
| 45 |
-
elif dtype == "int8":
|
| 46 |
-
target_type = np.float16 # Base type for scales/biases
|
| 47 |
-
suffix = "_int8"
|
| 48 |
-
quantize_int8 = True
|
| 49 |
-
|
| 50 |
-
for k, v in state.items():
|
| 51 |
-
v_np = v.cpu().detach().numpy()
|
| 52 |
-
|
| 53 |
-
if quantize_int8 and "weight" in k and v_np.ndim >= 2:
|
| 54 |
-
# Quantize to INT8
|
| 55 |
-
v_abs = np.abs(v_np)
|
| 56 |
-
v_max = np.max(v_abs)
|
| 57 |
-
|
| 58 |
-
# Scale to range [-127, 127]
|
| 59 |
-
# Avoid div by zero
|
| 60 |
-
if v_max == 0:
|
| 61 |
-
scale = 1.0
|
| 62 |
-
else:
|
| 63 |
-
scale = v_max / 127.0
|
| 64 |
-
|
| 65 |
-
v_int8 = (v_np / scale).astype(np.int8)
|
| 66 |
-
|
| 67 |
-
converted_state[f"{k}_int8"] = v_int8
|
| 68 |
-
converted_state[f"{k}_scale"] = np.array(scale).astype(target_type)
|
| 69 |
-
else:
|
| 70 |
-
converted_state[k] = v_np.astype(target_type)
|
| 71 |
-
|
| 72 |
-
out_name = f"{model_name}_mlx{suffix}.npz"
|
| 73 |
-
np.savez(out_name, **converted_state)
|
| 74 |
-
|
| 75 |
-
original_size = sum(v.numel() * 4 for v in state.values()) / (1024*1024)
|
| 76 |
-
new_size = os.path.getsize(out_name) / (1024*1024)
|
| 77 |
-
|
| 78 |
-
print(f"✅ Saved {out_name}")
|
| 79 |
-
print(f" Size: {new_size:.1f} MB (Original: ~{original_size:.1f} MB)")
|
| 80 |
-
|
| 81 |
-
def main():
|
| 82 |
-
parser = argparse.ArgumentParser(description="Export PyTorch models to MLX")
|
| 83 |
-
parser.add_argument("--model", choices=["vgg16", "vgg19", "googlenet", "resnet50", "all"], default="all")
|
| 84 |
-
parser.add_argument("--dtype", choices=["float32", "float16", "bf16", "int8"], default="float16", help="Output data type")
|
| 85 |
-
args = parser.parse_args()
|
| 86 |
-
|
| 87 |
-
models_to_export = ["vgg16", "vgg19", "googlenet", "resnet50"] if args.model == "all" else [args.model]
|
| 88 |
-
|
| 89 |
-
for m in models_to_export:
|
| 90 |
-
export_model(m, args.dtype)
|
| 91 |
-
|
| 92 |
-
if __name__ == "__main__":
|
| 93 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
export_resnet50_npz.py
DELETED
|
@@ -1,23 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Export torchvision ResNet50 weights to an .npz for MLX.
|
| 3 |
-
Run this in a PyTorch+torchvision env:
|
| 4 |
-
python export_resnet50_npz.py
|
| 5 |
-
It writes models/resnet50_mlx.npz
|
| 6 |
-
"""
|
| 7 |
-
import os
|
| 8 |
-
import numpy as np
|
| 9 |
-
import torch
|
| 10 |
-
from torchvision.models import resnet50, ResNet50_Weights
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def main():
|
| 14 |
-
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
|
| 15 |
-
state = model.state_dict()
|
| 16 |
-
os.makedirs("models", exist_ok=True)
|
| 17 |
-
out_path = os.path.join("models", "resnet50_mlx.npz")
|
| 18 |
-
np.savez(out_path, **{k: v.cpu().numpy() for k, v in state.items()})
|
| 19 |
-
print(f"Saved {out_path} with {len(state)} tensors.")
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
if __name__ == "__main__":
|
| 23 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
export_vgg16_npz.py
DELETED
|
@@ -1,23 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Export torchvision VGG16 weights to an .npz for MLX.
|
| 3 |
-
Run this in a PyTorch+torchvision env:
|
| 4 |
-
python export_vgg16_npz.py
|
| 5 |
-
It writes models/vgg16_mlx.npz
|
| 6 |
-
"""
|
| 7 |
-
import os
|
| 8 |
-
import numpy as np
|
| 9 |
-
import torch
|
| 10 |
-
from torchvision.models import vgg16, VGG16_Weights
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def main():
|
| 14 |
-
model = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
|
| 15 |
-
state = model.state_dict()
|
| 16 |
-
os.makedirs("models", exist_ok=True)
|
| 17 |
-
out_path = os.path.join("models", "vgg16_mlx.npz")
|
| 18 |
-
np.savez(out_path, **{k: v.cpu().numpy() for k, v in state.items()})
|
| 19 |
-
print(f"Saved {out_path} with {len(state)} tensors.")
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
if __name__ == "__main__":
|
| 23 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
export_vgg19_npz.py
DELETED
|
@@ -1,23 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Export torchvision VGG19 weights to an .npz for MLX.
|
| 3 |
-
Run this in a PyTorch+torchvision env:
|
| 4 |
-
python export_vgg19_npz.py
|
| 5 |
-
It writes models/vgg19_mlx.npz
|
| 6 |
-
"""
|
| 7 |
-
import os
|
| 8 |
-
import numpy as np
|
| 9 |
-
import torch
|
| 10 |
-
from torchvision.models import vgg19, VGG19_Weights
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def main():
|
| 14 |
-
model = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)
|
| 15 |
-
state = model.state_dict()
|
| 16 |
-
os.makedirs("models", exist_ok=True)
|
| 17 |
-
out_path = os.path.join("models", "vgg19_mlx.npz")
|
| 18 |
-
np.savez(out_path, **{k: v.cpu().numpy() for k, v in state.items()})
|
| 19 |
-
print(f"Saved {out_path} with {len(state)} tensors.")
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
if __name__ == "__main__":
|
| 23 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
googlenet_mlx_int8.npz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:f0fb7a656a2a69cfbbd42804d38e475d09eade67681fb813b9e8f78f1930da22
|
| 3 |
-
size 6791204
|
|
|
|
|
|
|
|
|
|
|
|
resnet50_mlx_int8.npz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:e1ab804f8257e78f03244ea033cdd55ed6b285317cf444c04234b3ce1d0e3961
|
| 3 |
-
size 25822834
|
|
|
|
|
|
|
|
|
|
|
|
resnet50_places365.pth_mlx.npz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:c7e4496e460a4cbec41e02f169c7be9c0e3cebe28036ac917105ba386471c47b
|
| 3 |
-
size 48691562
|
|
|
|
|
|
|
|
|
|
|
|
tf_inception_v1.py
DELETED
|
@@ -1,79 +0,0 @@
|
|
| 1 |
-
"""TF-Slim InceptionV1 forward callable for TF2 (no KerasTensor issues)."""
|
| 2 |
-
|
| 3 |
-
import os
|
| 4 |
-
from typing import Iterable, Tuple, Callable, List
|
| 5 |
-
|
| 6 |
-
import tensorflow as tf
|
| 7 |
-
import tf_slim as slim
|
| 8 |
-
from tf_slim.nets import inception_v1
|
| 9 |
-
|
| 10 |
-
WEIGHTS_URL = "http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz"
|
| 11 |
-
DEFAULT_LAYER_NAMES = (
|
| 12 |
-
"Mixed_4b",
|
| 13 |
-
"Mixed_4c",
|
| 14 |
-
"Mixed_4d",
|
| 15 |
-
)
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
def _download_checkpoint_if_needed(weights_path: str = None) -> str:
|
| 19 |
-
if weights_path:
|
| 20 |
-
if not os.path.exists(weights_path):
|
| 21 |
-
raise FileNotFoundError(f"Weights path does not exist: {weights_path}")
|
| 22 |
-
return weights_path
|
| 23 |
-
|
| 24 |
-
tar_path = tf.keras.utils.get_file(
|
| 25 |
-
origin=WEIGHTS_URL,
|
| 26 |
-
fname=os.path.basename(WEIGHTS_URL),
|
| 27 |
-
extract=True,
|
| 28 |
-
cache_dir=os.path.expanduser("~/.keras"),
|
| 29 |
-
)
|
| 30 |
-
ckpt_dir = os.path.join(os.path.dirname(tar_path), "inception_v1_2016_08_28")
|
| 31 |
-
ckpt_path = os.path.join(ckpt_dir, "inception_v1.ckpt")
|
| 32 |
-
if not os.path.exists(ckpt_path):
|
| 33 |
-
raise FileNotFoundError(f"Checkpoint not found after download: {ckpt_path}")
|
| 34 |
-
return ckpt_path
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def _preprocess_fn(x: tf.Tensor) -> tf.Tensor:
|
| 38 |
-
"""Match TF-Slim InceptionV1 preprocessing: scale to [-1, 1]."""
|
| 39 |
-
x = tf.cast(x, tf.float32)
|
| 40 |
-
return (x / 127.5) - 1.0
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
def build_inception_v1_callable(
|
| 44 |
-
layer_names: Iterable[str] = DEFAULT_LAYER_NAMES, weights_path: str = None
|
| 45 |
-
) -> Tuple[Callable[[tf.Tensor], List[tf.Tensor]], Callable[[tf.Tensor], tf.Tensor]]:
|
| 46 |
-
"""
|
| 47 |
-
Returns:
|
| 48 |
-
forward_fn: callable taking NHWC float tensor -> list of endpoints
|
| 49 |
-
preprocess_fn: preprocessing callable
|
| 50 |
-
"""
|
| 51 |
-
|
| 52 |
-
layer_names = tuple(layer_names)
|
| 53 |
-
scope_name = "InceptionV1"
|
| 54 |
-
|
| 55 |
-
@tf.function
|
| 56 |
-
def forward_fn(x: tf.Tensor) -> List[tf.Tensor]:
|
| 57 |
-
with tf.compat.v1.variable_scope(scope_name, reuse=tf.compat.v1.AUTO_REUSE):
|
| 58 |
-
with slim.arg_scope(inception_v1.inception_v1_arg_scope()):
|
| 59 |
-
_, endpoints = inception_v1.inception_v1(
|
| 60 |
-
x,
|
| 61 |
-
num_classes=1001,
|
| 62 |
-
is_training=False,
|
| 63 |
-
spatial_squeeze=False,
|
| 64 |
-
)
|
| 65 |
-
return [endpoints[name] for name in layer_names]
|
| 66 |
-
|
| 67 |
-
# Build variables by a dummy call
|
| 68 |
-
_ = forward_fn(tf.zeros([1, 224, 224, 3], dtype=tf.float32))
|
| 69 |
-
|
| 70 |
-
ckpt_path = _download_checkpoint_if_needed(weights_path)
|
| 71 |
-
var_list = tf.compat.v1.get_collection(
|
| 72 |
-
tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope=scope_name
|
| 73 |
-
)
|
| 74 |
-
name_map = {v.name.split(":")[0]: v for v in var_list}
|
| 75 |
-
ckpt = tf.train.Checkpoint(**name_map)
|
| 76 |
-
ckpt.restore(ckpt_path).expect_partial()
|
| 77 |
-
|
| 78 |
-
return forward_fn, _preprocess_fn
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train_dream.py
DELETED
|
@@ -1,30 +0,0 @@
|
|
| 1 |
-
# TODO: Implement Fine-Tuning Logic
|
| 2 |
-
|
| 3 |
-
"""
|
| 4 |
-
DeepDream Training / Fine-Tuning Script (Placeholder)
|
| 5 |
-
|
| 6 |
-
Goal:
|
| 7 |
-
Allow users to fine-tune these base models (VGG, GoogLeNet, etc.) on their own datasets
|
| 8 |
-
to create custom Dream styles.
|
| 9 |
-
|
| 10 |
-
Steps to Implement:
|
| 11 |
-
1. Load Dataset: Use `torchvision.datasets.ImageFolder` or custom loader for user images.
|
| 12 |
-
2. Load Model: Use our MLX models (need to add `train()` mode with dropout/grad support if missing,
|
| 13 |
-
or simpler: use PyTorch for training -> export to MLX).
|
| 14 |
-
*Easier path:* Train in PyTorch using standard scripts, then use `export_*.py` to bring it here.
|
| 15 |
-
3. Training Loop: Standard classification training or style transfer fine-tuning.
|
| 16 |
-
4. Export: Save the fine-tuned weights to `.pth`, then run export script.
|
| 17 |
-
|
| 18 |
-
Usage:
|
| 19 |
-
python train_dream.py --data /path/to/images --epochs 10 --model vgg16
|
| 20 |
-
"""
|
| 21 |
-
|
| 22 |
-
import argparse
|
| 23 |
-
|
| 24 |
-
def main():
|
| 25 |
-
print("--- DeepDream-MLX Training Stub ---")
|
| 26 |
-
print("Feature coming soon.")
|
| 27 |
-
print("Current Workflow: Train in PyTorch -> Use export_*.py -> Dream in MLX")
|
| 28 |
-
|
| 29 |
-
if __name__ == "__main__":
|
| 30 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vgg16_mlx_int8.npz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:17f8012268ac3cb74fd3c8ce5d243970b13141492b2e0e84fab1924a786ec25f
|
| 3 |
-
size 138384160
|
|
|
|
|
|
|
|
|
|
|
|
vgg19_mlx_int8.npz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:13309dd1b75cf316b0025db0c5791d6a89c654145a5f3a486f488b4bcd822b93
|
| 3 |
-
size 143697608
|
|
|
|
|
|
|
|
|
|
|
|