|
|
|
|
|
import argparse |
|
|
import os |
|
|
import time |
|
|
from datetime import datetime |
|
|
|
|
|
import mlx.core as mx |
|
|
import mlx.nn as nn |
|
|
import numpy as np |
|
|
import scipy.ndimage as nd |
|
|
from PIL import Image |
|
|
|
|
|
from mlx_googlenet import GoogLeNet |
|
|
from mlx_resnet50 import ResNet50 |
|
|
from mlx_vgg16 import VGG16 |
|
|
from mlx_vgg19 import VGG19 |
|
|
from mlx_alexnet import AlexNet |
|
|
|
|
|
# Per-channel mean and standard deviation used to normalise images that have
# already been scaled to the [0, 1] range (see ``preprocess``/``deprocess``).
IMAGENET_MEAN = mx.array([0.485, 0.456, 0.406])
IMAGENET_STD = mx.array([0.229, 0.224, 0.225])

# Normalised-space values that correspond to raw intensities 0.0 and 1.0.
# Reshaped to (1, 1, 1, 3) so they broadcast against a batched,
# channels-last image when clamping.
LOWER_IMAGE_BOUND = ((0.0 - IMAGENET_MEAN) / IMAGENET_STD).reshape(1, 1, 1, 3)
UPPER_IMAGE_BOUND = ((1.0 - IMAGENET_MEAN) / IMAGENET_STD).reshape(1, 1, 1, 3)
|
|
|
|
|
|
|
|
def load_image(path, target_width=None):
    """Load an image file as an RGB NumPy array, optionally resized.

    Args:
        path: Path of the image file to read.
        target_width: When truthy, the image is resized to this width and the
            height is scaled by the same factor. ``None`` (or 0) keeps the
            original size.

    Returns:
        The ``np.array`` view of the RGB image (height x width x 3).
    """
    # ``Image.open`` is lazy and keeps the file open; the context manager
    # releases the handle, while ``convert`` yields a loaded, independent copy.
    with Image.open(path) as source:
        img = source.convert("RGB")

    if target_width:
        w, h = img.size
        scale = target_width / w
        # Clamp to one pixel: a very wide image with a small target width
        # would otherwise truncate to a zero height and make resize fail.
        new_h = max(1, int(h * scale))
        img = img.resize((target_width, new_h), Image.LANCZOS)
    return np.array(img)
|
|
|
|
|
|
|
|
def preprocess(img_np):
    """Convert an image array into a normalised, batched MLX array.

    The input is cast to float32, divided by 255, standardised with
    ``IMAGENET_MEAN`` / ``IMAGENET_STD`` and given a leading batch axis.
    """
    scaled = mx.array(img_np, dtype=mx.float32) / 255.0
    normalised = (scaled - IMAGENET_MEAN) / IMAGENET_STD
    # Prepend the batch dimension expected by the rest of the pipeline.
    return normalised[None, ...]
|
|
|
|
|
|
|
|
def deprocess(x):
    """Invert ``preprocess`` for the first batch element.

    Undoes the mean/std normalisation, clips to [0, 1], rescales to 0-255 and
    returns the result as a uint8 NumPy array.
    """
    restored = x[0] * IMAGENET_STD + IMAGENET_MEAN
    pixels = mx.clip(restored, 0.0, 1.0) * 255.0
    return np.array(pixels.astype(mx.uint8))
|
|
|
|
|
|
|
|
def resize_bilinear(x, new_h, new_w):
    """Resize a batched, channels-last array with linear interpolation.

    Args:
        x: Array unpacked as ``(batch, height, width, channels)``.
        new_h: Target height in pixels.
        new_w: Target width in pixels.

    Returns:
        A float32 MLX array of shape ``(batch, new_h, new_w, channels)``.
    """
    _, h, w, _ = x.shape
    # A zoom factor of 1 on the batch and channel axes leaves them untouched,
    # so a single call replaces the per-(batch, channel) Python loop and its
    # repeated NumPy round-trips and scatter writes.
    zoomed = nd.zoom(np.array(x), zoom=(1, new_h / h, new_w / w, 1), order=1)
    # The original implementation always produced float32 output.
    return mx.array(zoomed.astype(np.float32, copy=False))
|
|
|
|
|
|
|
|
def gaussian_kernel(sigma, truncate=4.0, fixed_radius=None):
    """Build a normalised 1D Gaussian kernel of length ``2 * radius + 1``.

    Args:
        sigma: Standard deviation of the Gaussian.
        truncate: Number of standard deviations covered when the radius is
            derived from ``sigma``.
        fixed_radius: When given, used as the radius instead of deriving it
            from ``sigma`` (keeps the kernel length independent of ``sigma``).

    Returns:
        1D MLX array of weights that sum to one.
    """
    radius = fixed_radius if fixed_radius is not None else int(truncate * sigma + 0.5)

    offsets = mx.arange(-radius, radius + 1)
    weights = mx.exp(-0.5 * (offsets / sigma) ** 2)
    return weights / weights.sum()
|
|
|
|
|
|
|
|
def gaussian_blur_2d(x, sigma, fixed_radius=None):
    """Blur ``x`` with a Gaussian via two separable 1D grouped convolutions.

    Args:
        x: Input array; its last axis is used as the channel count.
        sigma: Standard deviation forwarded to ``gaussian_kernel``.
        fixed_radius: Optional fixed kernel radius forwarded to
            ``gaussian_kernel``.

    Returns:
        The result of the two ``mx.conv2d`` passes. Padding is half the
        kernel length on the axis being filtered.
    """
    taps = gaussian_kernel(sigma, fixed_radius=fixed_radius).astype(x.dtype)
    length = taps.shape[0]
    channels = x.shape[-1]
    half = length // 2

    # One single-channel filter per input channel (groups == channels), so
    # every channel is blurred independently.
    filter_x = mx.repeat(taps.reshape(1, 1, length, 1), channels, axis=0)
    filter_y = mx.repeat(taps.reshape(1, length, 1, 1), channels, axis=0)

    first_pass = mx.conv2d(x, filter_x, stride=1, padding=(0, half), groups=channels)
    return mx.conv2d(
        first_pass, filter_y, stride=1, padding=(half, 0), groups=channels
    )
|
|
|
|
|
|
|
|
def smooth_gradients(grad, sigma, fixed_radius=None):
    """Average three Gaussian blurs of ``grad`` at 0.5x, 1x and 2x ``sigma``.

    Each blur is applied independently to the original gradient (the blurs
    are not chained) and the three results are averaged.
    """
    blurred = [
        gaussian_blur_2d(grad, sigma * factor, fixed_radius=fixed_radius)
        for factor in (0.5, 1.0, 2.0)
    ]

    total = blurred[0]
    for extra in blurred[1:]:
        total = total + extra
    return total / len(blurred)
|
|
|
|
|
|
|
|
def get_pyramid_shapes(base_shape, num_octaves, scale):
    """Return the ``(height, width)`` of every pyramid level, smallest first.

    Level ``i`` is the base shape multiplied by
    ``scale ** (i - num_octaves + 1)``, so the last level equals the base
    shape. Each dimension is rounded and clamped to at least one pixel.
    """
    base_h, base_w = base_shape
    top_level = num_octaves - 1
    factors = (scale ** (level - top_level) for level in range(num_octaves))
    return [
        (max(1, int(round(base_h * f))), max(1, int(round(base_w * f))))
        for f in factors
    ]
|
|
|
|
|
|
|
|
def deepdream(
    model,
    img_np,
    layers,
    steps,
    lr,
    num_octaves,
    scale,
    jitter=32,
    smoothing=0.5,
    guide_img_np=None,
):
    """Run multi-octave gradient ascent on ``img_np`` and return the result.

    For each pyramid level (smallest first) the image is resized, then
    ``steps`` ascent updates are applied. Each update shifts the image by a
    random circular offset, adds the smoothed and standardised gradient of the
    loss scaled by ``lr``, clamps to the valid normalised range, and undoes
    the shift.

    Args:
        model: Object exposing ``forward_with_endpoints(x)``; element ``[1]``
            of its result is indexed by layer name.
        img_np: Input image array accepted by ``preprocess``.
        layers: Names of the endpoints whose activations drive the loss.
        steps: Number of ascent updates per octave.
        lr: Step size applied to the normalised gradient.
        num_octaves: Number of pyramid levels.
        scale: Size ratio between consecutive pyramid levels.
        jitter: Maximum absolute random shift (in pixels) per axis.
        smoothing: Base sigma added to the per-step blur schedule.
        guide_img_np: Optional guide image. When given, the loss is the mean
            of ``activation * guide_activation`` instead of the mean squared
            activation.

    Returns:
        The dreamed image as produced by ``deprocess``.
    """
    img = preprocess(img_np)
    base_h, base_w = img.shape[1:3]
    pyramid_shapes = get_pyramid_shapes((base_h, base_w), num_octaves, scale)

    # Loop-invariant: normalise the guide once instead of once per octave.
    guide = preprocess(guide_img_np) if guide_img_np is not None else None

    # The sigma schedule peaks at (2.0 + smoothing) and the widest blur in
    # smooth_gradients doubles it. Fixing the radius from that maximum keeps
    # the kernel length constant while sigma is a traced array.
    max_effective_sigma = 2.0 * (2.0 + smoothing)
    fixed_radius = int(4.0 * max_effective_sigma + 0.5)

    for nh, nw in pyramid_shapes:
        img = resize_bilinear(img, nh, nw)

        guide_features = {}
        if guide is not None:
            guide_resized = resize_bilinear(guide, nh, nw)
            _, guide_features = model.forward_with_endpoints(guide_resized)

        def loss_fn(x):
            endpoints = model.forward_with_endpoints(x)[1]
            loss = mx.zeros(())
            for name in layers:
                act = endpoints[name]
                if guide_img_np is not None:
                    guide_act = guide_features[name]
                    loss = loss + mx.mean(act * guide_act)
                else:
                    loss = loss + mx.mean(act * act)
            return loss / len(layers)

        # Defined per octave because the input shape changes with each level.
        @mx.compile
        def update_step(x, sigma):
            loss, grads = mx.value_and_grad(loss_fn)(x)
            g = smooth_gradients(grads, sigma, fixed_radius=fixed_radius)
            # Standardise the gradient so ``lr`` acts as a consistent step size.
            g = g - mx.mean(g)
            g = g / (mx.std(g) + 1e-8)
            x = x + lr * g
            x = mx.minimum(mx.maximum(x, LOWER_IMAGE_BOUND), UPPER_IMAGE_BOUND)
            return x, loss

        for it in range(steps):
            # Plain Python ints for the shifts rather than NumPy scalars.
            ox, oy = (int(v) for v in np.random.randint(-jitter, jitter + 1, 2))
            rolled = mx.roll(mx.roll(img, ox, axis=1), oy, axis=2)

            # Blur strength ramps up linearly over the octave's steps.
            sigma_val = ((it + 1) / steps) * 2.0 + smoothing

            rolled, _loss = update_step(rolled, mx.array(sigma_val))

            img = mx.roll(mx.roll(rolled, -ox, axis=1), -oy, axis=2)
            # Force evaluation each step so the lazy graph does not keep
            # growing across the whole octave.
            mx.eval(img)

    return deprocess(img)
|
|
|
|
|
|
|
|
def get_weights_path(model_name, explicit_path=None):
    """Pick the weights file to load for ``model_name``.

    A truthy ``explicit_path`` is returned unchanged. Otherwise
    ``{model_name}_mlx.npz`` is preferred, then ``{model_name}_mlx_bf16.npz``,
    whichever exists first. If neither exists the default (non-bf16) name is
    returned so the caller can report it as missing.
    """
    if explicit_path:
        return explicit_path

    default_path = f"{model_name}_mlx.npz"
    for candidate in (default_path, f"{model_name}_mlx_bf16.npz"):
        if os.path.exists(candidate):
            return candidate
    return default_path
|
|
|
|
|
|
|
|
def run_dream_for_model(model_name, args, img_np):
    """Load one model, run DeepDream on ``img_np`` and save the result.

    Args:
        model_name: One of ``"vgg16"``, ``"vgg19"``, ``"resnet50"``,
            ``"alexnet"``; any other value selects GoogLeNet.
        args: Parsed command-line namespace (see ``parse_args``).
        img_np: Input image array passed on to ``deepdream``.

    The function returns ``None``. If the weights file is missing it prints
    an error and returns without dreaming.
    """
    print(f"--- Running DeepDream with {model_name} ---")

    # The three presets differ only in the target layer.
    shared_preset_values = {
        "steps": 10,
        "lr": 0.06,
        "octaves": 6,
        "scale": 1.4,
        "jitter": 32,
        "smoothing": 0.5,
    }
    PRESETS = {
        "nb14": {"layers": ["relu3_3"], **shared_preset_values},
        "nb20": {"layers": ["relu4_2"], **shared_preset_values},
        "nb28": {"layers": ["relu5_3"], **shared_preset_values},
    }

    # Command-line values; a preset (VGG models only) may replace all of them.
    settings = {
        "layers": args.layers,
        "steps": args.steps,
        "lr": args.lr,
        "octaves": args.octaves,
        "scale": args.scale,
        "jitter": args.jitter,
        "smoothing": args.smoothing,
    }

    architectures = {
        "vgg16": (VGG16, ["relu4_3"]),
        "vgg19": (VGG19, ["relu4_4"]),
        "resnet50": (ResNet50, ["layer4_2"]),
        "alexnet": (AlexNet, ["relu5"]),
    }
    if model_name in architectures:
        factory, default_layers = architectures[model_name]
        weights_key = model_name
    else:
        factory = GoogLeNet
        default_layers = ["inception3b", "inception4c", "inception4d"]
        weights_key = "googlenet"

    model = factory()
    weights = get_weights_path(weights_key, args.weights)

    # vgg16 looks the preset up strictly (unknown name raises KeyError);
    # vgg19 silently ignores a preset name that is not defined.
    if model_name == "vgg16":
        preset = PRESETS[args.preset] if args.preset else None
    elif model_name == "vgg19":
        preset = PRESETS.get(args.preset) if args.preset else None
    else:
        preset = None
    if preset is not None:
        settings.update(preset)

    if not os.path.exists(weights):
        print(f"Error: Weights NPZ not found: {weights}. Skipping {model_name}.")
        return

    print(f"Loading weights from: {weights}")
    model.load_npz(weights)

    guide_img_np = None
    if args.guide:
        print(f"Using guide image: {args.guide}")
        guide_img_np = load_image(args.guide, args.width)

    start_time = time.time()
    start_timestamp = datetime.now()

    dreamed = deepdream(
        model,
        img_np,
        layers=settings["layers"] or default_layers,
        steps=settings["steps"],
        lr=settings["lr"],
        num_octaves=settings["octaves"],
        scale=settings["scale"],
        jitter=settings["jitter"],
        smoothing=settings["smoothing"],
        guide_img_np=guide_img_np,
    )

    elapsed = time.time() - start_time

    out = args.output
    if not out:
        # Auto-generated name: input stem, model, run time and start date/time.
        base_name = os.path.splitext(os.path.basename(args.input))[0]
        formatted_time = f"{elapsed:.2f}s"
        formatted_date = start_timestamp.strftime("%m%d")
        formatted_timestamp = start_timestamp.strftime("%H%M%S")
        out = f"{base_name}_dream_{model_name}_{formatted_time}_{formatted_date}_{formatted_timestamp}.jpg"

    Image.fromarray(dreamed).save(out)
    print(f"Saved {out}\n")
|
|
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the DeepDream command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so existing callers are
            unaffected; passing a list allows programmatic use and testing.

    Returns:
        ``argparse.Namespace`` with attributes ``input``, ``output``,
        ``guide``, ``width``, ``model``, ``preset``, ``layers``, ``steps``,
        ``lr``, ``octaves``, ``scale``, ``jitter``, ``smoothing`` and
        ``weights``.
    """
    p = argparse.ArgumentParser(description="DeepDream with MLX (Compiled)")
    p.add_argument("--input", required=True, help="Input image path")
    p.add_argument("--output", help="Output image path (optional)")
    p.add_argument("--guide", help="Guide image for guided dreaming")

    # Alias options share a ``dest`` with their primary option and declare no
    # default. The primary is registered first, so its default is the one
    # that ends up in the namespace.
    p.add_argument(
        "--width",
        type=int,
        default=None,
        help="Resize input to width (maintains aspect ratio)",
    )
    p.add_argument(
        "--img_width", type=int, help="Alias for --width", dest="width"
    )

    p.add_argument(
        "--model",
        choices=["vgg16", "vgg19", "googlenet", "resnet50", "alexnet", "all"],
        default="vgg16",
        help="Model to use. 'all' runs all models.",
    )
    p.add_argument("--preset", choices=["nb14", "nb20", "nb28"], help="VGG16 presets")

    p.add_argument("--layers", nargs="+", help="Layers to maximize")
    p.add_argument(
        "--steps", type=int, default=10, help="Gradient ascent steps per octave"
    )
    p.add_argument("--lr", type=float, default=0.09, help="Learning rate (step size)")

    p.add_argument("--octaves", type=int, default=4, help="Number of image octaves")
    p.add_argument(
        "--pyramid_size", type=int, dest="octaves", help="Alias for --octaves"
    )

    p.add_argument("--scale", type=float, default=1.8, help="Octave scale factor")
    p.add_argument(
        "--pyramid_ratio", type=float, dest="scale", help="Alias for --scale"
    )
    p.add_argument(
        "--octave_scale", type=float, dest="scale", help="Alias for --scale"
    )

    p.add_argument("--jitter", type=int, default=32, help="Jitter amount (pixels)")

    p.add_argument(
        "--smoothing", type=float, default=0.5, help="Gradient smoothing strength"
    )
    p.add_argument(
        "--smoothing_coefficient",
        type=float,
        dest="smoothing",
        help="Alias for --smoothing",
    )

    p.add_argument("--weights", help="Custom weights path")

    return p.parse_args(argv)
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point: parse options, load the input, run the dream.

    With ``--model all`` every supported model is run in turn. Options that
    can only apply to a single run (``--output`` and ``--weights``) are
    ignored in that mode, with a warning.
    """
    args = parse_args()
    img_np = load_image(args.input, args.width)

    if args.model != "all":
        run_dream_for_model(args.model, args, img_np)
        return

    if args.output:
        print(
            "Warning: --output argument ignored because --model='all' was selected."
        )
        args.output = None
    if args.weights:
        # get_weights_path returns an explicit path unchanged for every model,
        # so one weights file would be loaded into all five architectures.
        print(
            "Warning: --weights argument ignored because --model='all' was selected."
        )
        args.weights = None

    for m in ("vgg16", "vgg19", "googlenet", "resnet50", "alexnet"):
        run_dream_for_model(m, args, img_np)


if __name__ == "__main__":
    main()
|
|
|