Upload folder using huggingface_hub

- .gitignore +14 -2
- Agents.md +62 -0
- README.md +60 -142
- alexnet_places365.pth_mlx.npz +3 -0
- alexnet_places365_mlx.npz +3 -0
- benchmark.py +180 -0
- comparisons/torch_dream.py +144 -0
- convert.py +217 -0
- dream.py +15 -51
- dream_video.py +130 -0
- mlx_alexnet.py +88 -0
- resnet50_places365.pth_mlx.npz +3 -0
- resnet50_places365_mlx.npz +3 -0
- resnet50_places365_t7_mlx.npz +3 -0
- toConvert/.gitkeep +0 -0
.gitignore
CHANGED

```diff
@@ -2,13 +2,25 @@ venv/
 __pycache__/
 *.DS_Store
 pics/
-
+borrowFrom/
+benchmark_results/
+
+# Large Model Files (Source)
+*.pth
+*.tar
+*.t7
+*.caffemodel
+*.ckpt
+
+# Ignore contents of toConvert but keep the folder
+toConvert/*
+!toConvert/.gitkeep

 # Ignore images generally
 *.jpg
 *.png
 *.gif

-# Un-ignore specific
+# Un-ignore specific assets
 !assets/
 !input/
```
Agents.md
ADDED

@@ -0,0 +1,62 @@

# DeepDream MLX: Agents

## 1. The Mission
To resurrect the 2015 DeepDream aesthetic on modern 2025 Apple Silicon hardware, bypassing archaic frameworks like Caffe or Torch7 by porting everything to native MLX.

## 2. Training & Fine-Tuning Plan (The "Punch-Card" Revival)
In the "classic" days (the Intel Caffe era), training a custom DeepDream model meant fine-tuning a GoogLeNet on a dataset of specific objects (e.g., slugs, eyes, cars) so the network would hallucinate *those specific things* when dreaming.

**The Roadmap for MLX Training:**

### Phase 1: Dataset Prep
The `dream-creator` logic (from ProGamerGov) is still sound. We need:
1. **Structure:** `dataset/class_name/*.jpg` (standard PyTorch ImageFolder format).
2. **Cleaning:** Remove corrupt images, deduplicate.
3. **Resizing:** Resize to ~224x224 or 256x256.
4. **Stats:** Calculate the per-channel mean and standard deviation.
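The stats step above can be sketched as a streaming pass over the dataset. This is an illustrative helper, not code from the repo; the function name `dataset_stats` and the `[0, 1]` value range are my assumptions.

```python
import numpy as np
from pathlib import Path
from PIL import Image

def dataset_stats(root):
    """Per-channel mean/std over dataset/class_name/*.jpg, values in [0, 1].

    Accumulates sums and sums of squares so arbitrarily large datasets
    fit in constant memory (illustrative sketch, not from the repo).
    """
    total = np.zeros(3)
    total_sq = np.zeros(3)
    n_pixels = 0
    for path in Path(root).glob("*/*.jpg"):
        img = np.asarray(Image.open(path).convert("RGB"), dtype=np.float64) / 255.0
        total += img.sum(axis=(0, 1))
        total_sq += (img ** 2).sum(axis=(0, 1))
        n_pixels += img.shape[0] * img.shape[1]
    mean = total / n_pixels
    std = np.sqrt(total_sq / n_pixels - mean ** 2)
    return mean, std
```

The resulting `(mean, std)` pair would then be baked into the trainer's normalization, matching whatever `dream.py` uses at dream time.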
### Phase 2: The Trainer (`train_dream.py`)
We need to write a native MLX training loop.
* **Base Model:** Load `googlenet_mlx.npz`.
* **Architecture:** InceptionV1 (GoogLeNet).
* **Layer Freezing:**
  - **Critical:** Freeze early layers (`conv1`, `conv2`, `inception3a/b`) to preserve the "visual vocabulary" (edges, textures).
  - **Train:** Retrain only the higher layers (`inception4c`, `inception5b`, `fc`) and the auxiliary classifiers.
* **Auxiliary Classifiers:** Inception has two side branches (`aux1`, `aux2`) used for training stability. We must support either training them or stripping them.
* **Loss:** Cross-entropy.
* **Optimizer:** SGD with momentum (classic) or Adam.
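Freezing plus classic SGD-with-momentum reduces to: keep a velocity term per parameter and simply skip the update for frozen names. A framework-neutral numpy sketch (the real trainer would presumably lean on MLX's own optimizer and module-freezing facilities; `sgd_momentum_step` and the parameter names here are hypothetical):

```python
import numpy as np

def sgd_momentum_step(params, grads, velocity, frozen, lr=0.01, momentum=0.9):
    """One SGD-with-momentum step; parameters named in `frozen` stay untouched."""
    for name, g in grads.items():
        if name in frozen:
            continue  # e.g. conv1/conv2/inception3a keep their visual vocabulary
        velocity[name] = momentum * velocity.get(name, 0.0) - lr * g
        params[name] = params[name] + velocity[name]
    return params, velocity

# Toy example: only "fc" is trainable
params = {"conv1": np.ones(2), "fc": np.zeros(2)}
grads = {"conv1": np.ones(2), "fc": np.ones(2)}
params, vel = sgd_momentum_step(params, grads, {}, frozen={"conv1"}, lr=0.1)
```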
### Phase 3: "Decorrelation" (The Secret Sauce)
`dream-creator` confirms that color decorrelation is key.
* **Matrix:** A 3x3 matrix computed from the training-set color covariance.
* **Effect:** "Whitens" the input-image gradients during dreaming, preventing the image from converging to a mono-color blob.
* **Implementation:** Port `data_tools/calc_cm.py` to MLX.
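A sketch of what that 3x3 matrix computation could look like. This is my reading of the general technique (covariance, then a Cholesky factor mapping decorrelated color space back to RGB), not a port of `calc_cm.py`; the normalization choice is an assumption.

```python
import numpy as np

def color_decorrelation_matrix(pixels):
    """pixels: (N, 3) float array of RGB samples in [0, 1].

    Returns a 3x3 lower-triangular matrix mapping decorrelated
    color space -> RGB (illustrative, not the calc_cm.py algorithm).
    """
    cov = np.cov(pixels, rowvar=False)   # 3x3 color covariance
    chol = np.linalg.cholesky(cov)       # lower-triangular factor
    # Normalize so applying the transform doesn't rescale gradients wildly
    return chol / np.linalg.norm(chol, axis=0).max()
```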
## 3. Animation & Video Strategy
The "zoom" video effect is the second pillar of DeepDream.
* **Logic:** A feedback loop.
  1. Dream on frame N.
  2. Zoom (scale + center-crop) frame N to create frame N+1.
  3. Repeat.
* **Implementation:** A dedicated `dream_video.py` script.
* **Tech:** Use `scipy.ndimage.zoom` (same as the original 2015 code) for the scaling, as MLX's `resize` may differ slightly in sub-pixel interpolation.
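The feedback loop above, sketched with `scipy.ndimage.zoom`. The `dream` callable is a placeholder for whatever optimizer `dream.py` actually exposes; function names here are hypothetical.

```python
import numpy as np
import scipy.ndimage as nd

def zoom_frame(frame, scale=1.05):
    """Scale an HWC frame up, then center-crop back to the original size."""
    h, w = frame.shape[:2]
    zoomed = nd.zoom(frame, (scale, scale, 1), order=1)  # bilinear, channels untouched
    zh, zw = zoomed.shape[:2]
    top, left = (zh - h) // 2, (zw - w) // 2
    return zoomed[top:top + h, left:left + w]

def dream_video(first_frame, dream, n_frames=30, scale=1.05):
    """Feedback loop: dream on frame N, zoom it to seed frame N+1."""
    frames = [first_frame]
    for _ in range(n_frames - 1):
        dreamed = dream(frames[-1])  # placeholder for dream.py's optimizer
        frames.append(zoom_frame(dreamed, scale))
    return frames
```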
## 4. Available Models & Wishlist
**Current:**
* `alexnet`: The raw, chaotic ancestor.
* `googlenet` (InceptionV1): The classic "slugs and dogs".
* `vgg16/19`: The "painterly" style-transfer beast.
* `resnet50`: Modern, sharp, geometric.

**Wishlist (To Convert):**
* `inception_v3`: More refined hallucinations.
* `googlenet_places365`: Hallucinates landscapes/interiors. (Verified working via `convert.py --download googlenet` once the URL is fixed/found.)

## 5. Hugging Face Hygiene
* **Repo:** `NickMystic/DeepDream-MLX`
* **LFS:** Track `*.npz`.
* **Cleanup:** Ensure `toConvert/` is empty of large raw files.
* **Banner:** `assets/deepdream_header.jpg`.

---
*Docs derived from deep analysis of `dream-creator` and classic Caffe workflows.*
README.md
CHANGED

@@ -12,184 +12,102 @@ tags:

- deepdream
pipeline_tag: image-to-image
---
# DeepDream-MLX

<img src="assets/deepdream_header.jpg" alt="DeepDream Header" width="100%"/>

**Status:** Fast. Native.
**Vibe:** 2015 Hallucinations // 2025 Silicon.

```bash
# 1. Install
pip install

# 2. Dream (VGG16
python dream.py --input

# 3.
python dream.py --input
```

## 🔮 The

```text
[Model-family diagram, mostly lost in the diff view: VGG16 "The Painter"
(Philosophy: "Deeper"), VGG19 "The Stylist", Inception V1 "The Hallucinator"
(Philosophy: "Wider"), ResNet 50 "The Modernist" (Philosophy: "Identity"),
mapping to vgg16_mlx.npz, vgg19_mlx.npz, googlenet_mlx.npz, resnet50_mlx.npz]
```

## 🧠 The Models

* **VGG16:** General purpose image features.
* **GoogLeNet (InceptionV1):** The classic DeepDream model.
* **VGG19:** Deeper VGG features.
* **ResNet50:** Modern deep features.

## 🧪 Recipes

### 1. Classic Inception Patterns (GoogLeNet)
*This setup targets various Inception layers for recognizable DeepDream shapes.*

```bash
python dream.py --input \
  --model googlenet \
  --steps 22 \
  --lr 0.061 \
  --octaves 4 \
  --scale 1.8 \
  --jitter 26 \
  --smoothing 0.08 \
  --layers inception3a inception4e inception5b
```

### 2.

```bash
python dream.py --input \
  --model vgg16 \
  --steps 24 \
  --lr 0.07 \
  --octaves 4 \
  --scale 1.8 \
  --jitter 36 \
  --smoothing 0.19 \
  --layers relu4_2
```

### 3.

```bash
python dream.py --input \
  --model vgg19 \
  --steps 14 \
  --lr 0.045 \
  --octaves 2 \
  --scale 1.5 \
  --jitter 27 \
  --smoothing 0.41 \
  --layers relu5_2
```

```bash
  --steps 24 \
  --lr 0.069 \
  --octaves 4 \
  --scale 1.8 \
  --jitter 10 \
  --smoothing 0.41 \
  --layers relu5_1
```

```bash
python \
  --model resnet50 \
  --steps 22 \
  --lr 0.13 \
  --octaves 4 \
  --scale 2 \
  --jitter 83 \
  --smoothing 0.47 \
  --layers layer3_2 layer3_5
```

##

We didn't just wrap existing libs. We wrote a custom exporter (`export_models.py`) to rip weights from standard PyTorch/Torchvision archives and serialize them into optimized MLX `.npz` arrays.

### 50% Smaller Weights (FP16)
We now support **Float16** (half-precision) weights by default. This cuts model size in half with zero visual loss for DeepDreaming.
* **VGG16:** 528MB → **264MB**
* **ResNet50:** 98MB → **49MB**

`dream.py` automatically detects and loads `_bf16.npz` files if present.

## 🔎 Where to find models?

You can convert *any* standard PyTorch model to run here.
1. **Torchvision:** The source of our VGG/GoogLeNet/ResNet weights.
2. **Hugging Face Hub:** Massive repo of pretrained models.
3. **Caffe Model Zoo (Historical):** If you have `.caffemodel` files, load them into PyTorch (using tools like `load_caffe`) and then export.

## 🎓 Training & Fine-Tuning (TODO)

Want your DeepDream to see things *differently*? (e.g., dogs instead of slugs?)
You need to fine-tune the base model on a new dataset.

4. Dream.

*

---
*
- deepdream
pipeline_tag: image-to-image
---

# DeepDream-MLX

<img src="assets/deepdream_header.jpg" alt="DeepDream Header" width="100%"/>

**Status:** Fast. Native.
**Vibe:** 2015 Hallucinations // 2025 Silicon.

DeepDream-MLX brings the classic psychedelic computer-vision algorithm to modern Apple Silicon, running natively on the GPU via the [MLX](https://github.com/ml-explore/mlx) framework. No Caffe, no slow conversion layers—just pure tensor operations.

## ⚡️ Quick Start

```bash
# 1. Install
pip install -r requirements.txt

# 2. Dream (Default VGG16)
python dream.py --input assets/demo_googlenet.jpg

# 3. Explore Models
python dream.py --input assets/demo_googlenet.jpg --model googlenet --layers inception4c
```

## 🔮 The Evolution of Vision

We support the classic ancestors of modern Computer Vision.

```text
TIMELINE   MODEL                     PARAMS   PHILOSOPHY
─────────────────────────────────────────────────────────────
1998       LeNet-5                   60K      "Digits."
2012       AlexNet (Available)       60M      "Deep."
2014       VGG16                     138M     "Deeper."
2014       GoogLeNet (Inception)     7M       "Wide & Efficient."
2015       ResNet50 (Modern Standard) 25M     "Identity & Residuals."
```

## 🧪 Recipes

### 1. The Classic (GoogLeNet)
The original DeepDream look. Eyes, slugs, and pagodas.
```bash
python dream.py --input img.jpg --model googlenet --layers inception4c --octaves 4 --scale 1.4
```

### 2. The Painter (VGG16)
Dense, rich textures. Great for artistic style-transfer-like effects.
```bash
python dream.py --input img.jpg --model vgg16 --layers relu4_3 --steps 20
```

### 3. The Modernist (ResNet50)
Sharp, geometric, and sometimes abstract architectural hallucinations.
```bash
python dream.py --input img.jpg --model resnet50 --layers layer4_2
```

## 🛠 Advanced Usage

### Converting Models
We include a universal converter that ingests standard PyTorch (`.pth`) and legacy Torch7 (`.t7`) models, optimizing them into MLX format (`float16` by default).

```bash
# Convert a local file
python convert.py --scan path/to/models

# Download & Convert Places365 (AlexNet, ResNet, etc.)
python convert.py --download all
```

### Benchmarking
Verify performance on your machine.
```bash
python benchmark.py
```

## ⚖️ Performance (M2 Max)

| Framework | Model | Precision | Speed |
| :--- | :--- | :--- | :--- |
| **MLX** | GoogLeNet | **float16** | **~3.6s** |
| PyTorch (MPS) | GoogLeNet | float32 | ~4.5s |

*Benchmarks run at 400px width, 10 iterations.*

---
*Built for the dreamers.*
alexnet_places365.pth_mlx.npz
ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:587f2f379063fb722563b86d9e7fea2321119b571c6bff7e09e309abf6dbf0b4
size 117002764

alexnet_places365_mlx.npz
ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:587f2f379063fb722563b86d9e7fea2321119b571c6bff7e09e309abf6dbf0b4
size 117002764
benchmark.py
ADDED

@@ -0,0 +1,180 @@

```python
#!/usr/bin/env python3
import os
import subprocess
import time
from datetime import datetime

# Benchmark Configuration
MODELS = ["googlenet", "vgg16", "resnet50"]  # vgg19 is usually close to vgg16; skipped for speed
PRECISIONS = ["int8", "bf16", "float32"]
INPUT_IMAGE = "assets/demo_googlenet.jpg"  # standard asset if available, else fallback
OUTPUT_DIR = "benchmark_results"

def ensure_asset():
    """Ensures a test image exists, falling back to any jpg in assets/."""
    if not os.path.exists(INPUT_IMAGE):
        candidates = [f for f in os.listdir("assets") if f.endswith(".jpg")]
        if candidates:
            return os.path.join("assets", candidates[0])
        raise FileNotFoundError("No test image found in assets/")
    return INPUT_IMAGE

def get_weight_file(model, precision):
    """Maps model+precision to the expected weight filename."""
    suffix = {
        "int8": "_mlx_int8.npz",
        "bf16": "_mlx_bf16.npz",
        "float32": "_mlx.npz",
    }[precision]
    return f"{model}{suffix}"

def run_benchmark():
    if not os.path.exists(OUTPUT_DIR):
        os.makedirs(OUTPUT_DIR)

    test_img = ensure_asset()
    results = []

    print(f"Starting Benchmark on {test_img}...")
    print(f"{'Model':<15} {'Precision':<10} {'Time (s)':<10} {'Status':<10}")
    print("-" * 50)

    for model in MODELS:
        for prec in PRECISIONS:
            weight_file = get_weight_file(model, prec)

            if not os.path.exists(weight_file):
                print(f"{model:<15} {prec:<10} {'---':<10} Missing Weights")
                continue

            # Run dream.py with fixed settings: 10 steps at 400px width
            # gives realistic timing without taking too long, and dream.py
            # is deterministic given the same arguments.
            out_path = os.path.join(OUTPUT_DIR, f"bench_{model}_{prec}.jpg")

            cmd = [
                "python", "dream.py",
                "--input", test_img,
                "--output", out_path,
                "--model", model,
                "--weights", weight_file,
                "--steps", "10",
                "--width", "400",
            ]

            start_t = time.time()
            try:
                # Capture output to avoid clutter
                subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
                duration = time.time() - start_t
                print(f"{model:<15} {prec:<10} {duration:<10.2f} OK")
                results.append({
                    "model": model,
                    "precision": prec,
                    "time": duration,
                    "image": out_path,
                })
            except subprocess.CalledProcessError:
                print(f"{model:<15} {prec:<10} {'Error':<10} Failed")

    generate_report(results)
    create_composite_image(results)

def generate_report(results):
    report_path = os.path.join(OUTPUT_DIR, "BENCHMARK_REPORT.md")
    with open(report_path, "w") as f:
        f.write("# DeepDream MLX Benchmark Report\n\n")
        f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
        f.write("| Model | Precision | Time (s) | Result |\n")
        f.write("|-------|-----------|----------|--------|\n")
        for r in results:
            rel_img = os.path.basename(r['image'])
            f.write(f"| {r['model']} | {r['precision']} | {r['time']:.2f} | <img src='{rel_img}' width='100'/> |\n")
    print(f"\nReport generated at {report_path}")

def create_composite_image(results):
    try:
        from PIL import Image, ImageDraw, ImageFont
    except ImportError:
        print("PIL not installed, skipping composite image.")
        return

    # Organize data: matrix[model][precision] = image_path
    matrix = {}
    all_models = sorted(set(r['model'] for r in results))
    all_precs = sorted(set(r['precision'] for r in results))

    for r in results:
        matrix.setdefault(r['model'], {})[r['precision']] = r['image']

    if not matrix:
        return

    # Assume all images are roughly the same size; read the first one
    sample_img = Image.open(results[0]['image'])
    w, h = sample_img.size

    # Layout: header row (precisions), left column (models)
    padding = 50
    header_height = 60
    label_width = 120

    grid_w = label_width + len(all_precs) * (w + padding)
    grid_h = header_height + len(all_models) * (h + padding)

    composite = Image.new("RGB", (grid_w, grid_h), (255, 255, 255))
    draw = ImageDraw.Draw(composite)

    # Try to load a font, else default
    try:
        font = ImageFont.truetype("Arial", 24)
    except IOError:
        font = ImageFont.load_default()

    # Draw header (precision labels)
    for i, prec in enumerate(all_precs):
        x = label_width + i * (w + padding)
        draw.text((x + w // 2 - 20, 20), prec, fill=(0, 0, 0), font=font)

    # Draw rows (model label, then one cell per precision)
    for j, model in enumerate(all_models):
        y = header_height + j * (h + padding)
        draw.text((10, y + h // 2), model, fill=(0, 0, 0), font=font)

        for i, prec in enumerate(all_precs):
            x = label_width + i * (w + padding)
            if prec in matrix[model]:
                img_path = matrix[model][prec]
                if os.path.exists(img_path):
                    img = Image.open(img_path)
                    if img.size != (w, h):
                        img = img.resize((w, h))
                    composite.paste(img, (x, y))

                # Draw the timing under each cell
                time_val = next(r['time'] for r in results if r['model'] == model and r['precision'] == prec)
                draw.text((x + 5, y + h + 5), f"{time_val:.2f}s", fill=(0, 0, 0), font=font)

    comp_path = os.path.join(OUTPUT_DIR, "benchmark_composite.jpg")
    composite.save(comp_path)
    print(f"Composite benchmark image saved to {comp_path}")

if __name__ == "__main__":
    run_benchmark()
```
comparisons/torch_dream.py
ADDED

@@ -0,0 +1,144 @@

```python
#!/usr/bin/env python3
import argparse
import time
import numpy as np
import torch
from PIL import Image
from torchvision import models

# MPS support for some ops can be patchy; fall back to CPU when unavailable.
DEVICE = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).to(DEVICE).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).to(DEVICE).view(1, 3, 1, 1)

def preprocess(img_np):
    # HWC -> CHW, add batch dim, normalize
    x = torch.from_numpy(img_np).float().permute(2, 0, 1).unsqueeze(0) / 255.0
    x = x.to(DEVICE)
    x = (x - IMAGENET_MEAN) / IMAGENET_STD
    return x

def deprocess(x):
    x = x * IMAGENET_STD + IMAGENET_MEAN
    x = torch.clamp(x, 0, 1)
    x = x.squeeze(0).permute(1, 2, 0).cpu().detach().numpy()
    return (x * 255).astype(np.uint8)

def get_model(name):
    if name == "googlenet":
        model = models.googlenet(weights='DEFAULT')
    elif name == "vgg16":
        model = models.vgg16(weights='DEFAULT')
    elif name == "resnet50":
        model = models.resnet50(weights='DEFAULT')
    else:
        raise ValueError(name)

    model.to(DEVICE)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    return model

class Hook:
    """Captures the forward activation of a module."""
    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)
        self.activation = None

    def hook_fn(self, module, input, output):
        self.activation = output

    def close(self):
        self.hook.remove()

def deepdream(args):
    img = Image.open(args.input).convert('RGB')
    if args.width:
        w, h = img.size
        scale = args.width / w
        img = img.resize((args.width, int(h * scale)), Image.LANCZOS)

    img_np = np.array(img)
    model = get_model(args.model)

    # Simplified layer selection for the benchmark: one hardcoded,
    # representative module per architecture (torchvision layer names
    # don't map cleanly onto the MLX naming).
    if args.model == "googlenet":
        target_modules = [model.inception4c]
    elif args.model == "vgg16":
        target_modules = [model.features[20]]  # roughly relu4_2
    elif args.model == "resnet50":
        target_modules = [model.layer4]

    hooks = [Hook(m) for m in target_modules]
    input_tensor = preprocess(img_np).requires_grad_(True)

    print(f"Running Torch ({DEVICE}) Dream on {args.model}...")
    start_t = time.time()

    # Single-scale run only: octave handling is hard to replicate
    # pixel-perfect against the MLX version due to resize differences,
    # and multi-scale mostly adds CPU-bound resize overhead. This
    # benchmarks pure iteration speed.
    optimizer = torch.optim.SGD([input_tensor], lr=args.lr)

    for i in range(args.steps):
        optimizer.zero_grad()
        model(input_tensor)

        # DeepDream *maximizes* activation energy; SGD minimizes,
        # so negate the loss to perform gradient ascent.
        loss = 0
        for h in hooks:
            loss = loss - h.activation.pow(2).mean()

        loss.backward()

        # Normalize the gradient by its std (standard DeepDream trick)
        g = input_tensor.grad
        g /= (torch.std(g) + 1e-8)
        input_tensor.grad = g

        optimizer.step()

    # Make sure queued GPU work finishes before stopping the clock
    if DEVICE.type == "mps":
        torch.mps.synchronize()

    duration = time.time() - start_t
    print(f"Time: {duration:.4f}s")

    out = deprocess(input_tensor)
    Image.fromarray(out).save(args.output)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True)
    parser.add_argument("--output", default="torch_out.jpg")
    parser.add_argument("--model", default="googlenet")
    parser.add_argument("--steps", type=int, default=10)
    parser.add_argument("--lr", type=float, default=0.05)
    parser.add_argument("--width", type=int, default=400)
    args = parser.parse_args()

    deepdream(args)
```
convert.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env python3
"""
Universal Model Converter for DeepDream-MLX.
Converts PyTorch (.pth) and Torch7 (.t7) models to MLX (.npz).
Also supports auto-downloading standard Places365 models.
Defaults to float16 for optimal performance on Apple Silicon.
"""

import os
import argparse
import glob
import numpy as np
import torch
import torchvision.models as models
from torch.hub import download_url_to_file

# Optional torchfile for .t7 support
try:
    import torchfile
except ImportError:
    torchfile = None

# --- Configuration ---
PLACES365_URLS = {
    "alexnet": "http://places2.csail.mit.edu/models_places365/alexnet_places365.pth.tar",
    "resnet50": "http://places2.csail.mit.edu/models_places365/resnet50_places365.pth.tar",
    "vgg16": "http://places2.csail.mit.edu/models_places365/vgg16_places365.pth.tar",
    "googlenet": "http://places2.csail.mit.edu/models_places365/googlenet_places365.pth.tar",
}

# --- Helper Functions ---

def convert_tensor(tensor, target_dtype=np.float16):
    """Converts a tensor/array to the target numpy dtype."""
    if isinstance(tensor, torch.Tensor):
        return tensor.cpu().detach().numpy().astype(target_dtype)
    elif isinstance(tensor, np.ndarray):
        return tensor.astype(target_dtype)
    else:
        return np.array(tensor).astype(target_dtype)

def clean_state_dict(state_dict):
    """
    Flattens the state dictionary and removes common prefix artifacts
    like 'module.' from DataParallel wrapping.
    """
    new_dict = {}
    for k, v in state_dict.items():
        # Remove 'module.' anywhere in the key
        name = k.replace("module.", "")
        new_dict[name] = convert_tensor(v)
    return new_dict

def get_places365_model_skeleton(arch):
    """Returns a standard PyTorch model structure for Places365."""
    if arch == "alexnet":
        return models.alexnet(num_classes=365)
    elif arch == "resnet50":
        return models.resnet50(num_classes=365)
    elif arch == "vgg16":
        return models.vgg16(num_classes=365)
    elif arch == "googlenet":
        return models.googlenet(num_classes=365, aux_logits=False)
    else:
        raise ValueError(f"Unknown architecture: {arch}")

# --- Conversion Logic ---

def convert_torch7(filepath, target_dir):
    if torchfile is None:
        print(f"⚠️ Skipping {filepath}: 'torchfile' not installed. Run `pip install torchfile`.")
        return

    print(f"Processing Torch7 file: {filepath}")
    try:
        model_obj = torchfile.load(filepath)
        converted_state = {}

        def extract_layers(layer, prefix=""):
            if hasattr(layer, 'weight') and layer.weight is not None:
                converted_state[f"{prefix}.weight"] = convert_tensor(layer.weight)
            if hasattr(layer, 'bias') and layer.bias is not None:
                converted_state[f"{prefix}.bias"] = convert_tensor(layer.bias)

            if hasattr(layer, 'modules') and layer.modules:
                for i, sublayer in enumerate(layer.modules):
                    # 0-based indexing for compatibility
                    next_prefix = f"{prefix}.{i}" if prefix else f"{i}"
                    extract_layers(sublayer, next_prefix)

        extract_layers(model_obj)

        if not converted_state:
            print(f"❌ No weights found in {filepath}.")
            return

        name_base = os.path.splitext(os.path.basename(filepath))[0]
        out_path = os.path.join(target_dir, f"{name_base}_t7_mlx.npz")
        np.savez(out_path, **converted_state)
        print(f"✅ Saved {out_path} ({len(converted_state)} tensors)")

    except Exception as e:
        print(f"❌ Failed to convert {filepath}: {e}")

def convert_pytorch(filepath, target_dir):
    print(f"Processing PyTorch file: {filepath}")
    try:
        checkpoint = torch.load(filepath, map_location="cpu")

        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        elif isinstance(checkpoint, dict):
            state_dict = checkpoint
        else:
            print(f"❌ Unknown checkpoint format in {filepath}")
            return

        clean_dict = clean_state_dict(state_dict)

        name_base = os.path.splitext(os.path.basename(filepath))[0]
        # Avoid double extension if file was .pth.tar
        if name_base.endswith(".pth"):
            name_base = os.path.splitext(name_base)[0]

        out_path = os.path.join(target_dir, f"{name_base}_mlx.npz")
        np.savez(out_path, **clean_dict)

        size_mb = os.path.getsize(out_path) / (1024 * 1024)
        print(f"✅ Saved {out_path} ({size_mb:.1f} MB)")

    except Exception as e:
        print(f"❌ Failed to convert {filepath}: {e}")

def download_and_convert_places365(arch, download_dir, target_dir):
    url = PLACES365_URLS.get(arch)
    if not url:
        print(f"No URL for {arch}")
        return

    filename = os.path.join(download_dir, os.path.basename(url))

    # 1. Download
    if not os.path.exists(filename):
        print(f"Downloading {arch} from {url}...")
        try:
            download_url_to_file(url, filename)
        except Exception as e:
            print(f"Download failed: {e}")
            return
    else:
        print(f"Found cached {filename}")

    # 2. Load into standard skeleton (ensures structural correctness)
    print(f"Loading {arch} into PyTorch structure...")
    try:
        model = get_places365_model_skeleton(arch)
        checkpoint = torch.load(filename, map_location="cpu")
        state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

        # Robust load: strip DataParallel prefixes, fall back to non-strict
        new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        try:
            model.load_state_dict(new_state_dict, strict=True)
        except RuntimeError:
            model.load_state_dict(new_state_dict, strict=False)

        # 3. Export
        model.eval()
        final_dict = clean_state_dict(model.state_dict())
        out_path = os.path.join(target_dir, f"{arch}_places365_mlx.npz")
        np.savez(out_path, **final_dict)
        print(f"✅ Saved {out_path}")

    except Exception as e:
        print(f"Failed to process {arch}: {e}")

# --- Main CLI ---

def main():
    parser = argparse.ArgumentParser(description="DeepDream-MLX Model Converter")
    parser.add_argument("--scan", default="toConvert", help="Directory to scan for local files")
    parser.add_argument("--download", choices=["alexnet", "resnet50", "vgg16", "googlenet", "all"],
                        help="Download and convert specific Places365 models")
    parser.add_argument("--dest", default=".", help="Output directory for .npz files")
    args = parser.parse_args()

    if not os.path.exists(args.dest):
        os.makedirs(args.dest)

    # 1. Handle downloads
    if args.download:
        if not os.path.exists(args.scan):
            os.makedirs(args.scan)

        targets = ["alexnet", "resnet50", "vgg16", "googlenet"] if args.download == "all" else [args.download]
        for t in targets:
            download_and_convert_places365(t, args.scan, args.dest)

    # 2. Handle local scan
    if os.path.exists(args.scan):
        print(f"\nScanning '{args.scan}' for local models...")
        for f in glob.glob(os.path.join(args.scan, "*")):
            if os.path.isdir(f):
                continue
            ext = os.path.splitext(f)[1].lower()

            if ext == ".t7":
                convert_torch7(f, args.dest)
            elif ext in [".pth", ".pt", ".tar", ".pkl"]:
                # Note: Places365 checkpoints downloaded above also land here
                # and are converted again under a different output name.
                convert_pytorch(f, args.dest)
            elif ext == ".caffemodel":
                print(f"⚠️ Skipping Caffe model {os.path.basename(f)} (not supported)")

if __name__ == "__main__":
    main()
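The heart of convert.py is one small transformation: strip the `module.` prefix that `nn.DataParallel` adds to every key, and downcast each tensor to float16. A numpy-only sketch of that logic (`fake_state_dict` is an illustrative stand-in for a real checkpoint, not an actual Places365 state dict):

```python
import numpy as np

def clean_state_dict_np(state_dict, target_dtype=np.float16):
    # Strip the DataParallel 'module.' prefix and downcast each tensor
    return {
        k.replace("module.", ""): np.asarray(v).astype(target_dtype)
        for k, v in state_dict.items()
    }

# Illustrative stand-in for a loaded checkpoint
fake_state_dict = {
    "module.features.0.weight": np.ones((64, 3, 11, 11), dtype=np.float32),
    "module.features.0.bias": np.zeros(64, dtype=np.float32),
}
clean = clean_state_dict_np(fake_state_dict)
print(sorted(clean))  # ['features.0.bias', 'features.0.weight']
```

The float16 downcast roughly halves the .npz size relative to float32 with no visible effect on dreamed images.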
dream.py
CHANGED
@@ -1,3 +1,4 @@
+#!/usr/bin/env python3
 import argparse
 import os
 import time
@@ -13,6 +14,7 @@ from mlx_googlenet import GoogLeNet
 from mlx_resnet50 import ResNet50
 from mlx_vgg16 import VGG16
 from mlx_vgg19 import VGG19
+from mlx_alexnet import AlexNet

 IMAGENET_MEAN = mx.array([0.485, 0.456, 0.406])
 IMAGENET_STD = mx.array([0.229, 0.224, 0.225])
@@ -176,63 +178,20 @@ def deepdream(


 def get_weights_path(model_name, explicit_path=None):
     if explicit_path:
         return explicit_path

-    # 1. Try int8 (Maximum Efficiency / Smallest)
-    int8_path = f"{model_name}_mlx_int8.npz"
-    if os.path.exists(int8_path):
-        return int8_path
-
-    # 2. Try bf16 (Standard Efficient)
-    bf16_path = f"{model_name}_mlx_bf16.npz"
-    if os.path.exists(bf16_path):
-        return bf16_path
-
-    # 3. Try standard float32
-    fp32_path = f"{model_name}_mlx.npz"
-    if os.path.exists(fp32_path):
-        return fp32_path
-
-    return int8_path  # Return preferred default for error message context
+    # 1. Try standard MLX export (float16/bf16 default)
+    path = f"{model_name}_mlx.npz"
+    if os.path.exists(path):
+        return path
+
+    # 2. Try explicit bf16 suffix (legacy)
+    bf16_path = f"{model_name}_mlx_bf16.npz"
+    if os.path.exists(bf16_path):
+        return bf16_path

+    return path  # Return default for error message context


 def run_dream_for_model(model_name, args, img_np):
@@ -313,6 +272,11 @@ def run_dream_for_model(model_name, args, img_np):
         weights = get_weights_path("resnet50", args.weights)
         default_layers = ["layer4_2"]

+    elif model_name == "alexnet":
+        model = AlexNet()
+        weights = get_weights_path("alexnet", args.weights)
+        default_layers = ["relu5"]
+
     else:  # googlenet
         model = GoogLeNet()
         weights = get_weights_path("googlenet", args.weights)
@@ -380,7 +344,7 @@ def parse_args():
     p.add_argument(
         "--model",
-        choices=["vgg16", "vgg19", "googlenet", "resnet50", "all"],
+        choices=["vgg16", "vgg19", "googlenet", "resnet50", "alexnet", "all"],
         default="vgg16",
         help="Model to use. 'all' runs all models.",
     )
@@ -427,7 +391,7 @@ def main():
     img_np = load_image(args.input, args.width)

     if args.model == "all":
-        models = ["vgg16", "vgg19", "googlenet", "resnet50"]
+        models = ["vgg16", "vgg19", "googlenet", "resnet50", "alexnet"]
         if args.output:
             print(
                 "Warning: --output argument ignored because --model='all' was selected."
dream_video.py
ADDED
|
@@ -0,0 +1,130 @@
#!/usr/bin/env python3
import argparse
import os
import time
import numpy as np
import mlx.core as mx
import scipy.ndimage as nd
from PIL import Image
from dream import deepdream, load_image, deprocess, get_weights_path
from mlx_googlenet import GoogLeNet
from mlx_resnet50 import ResNet50
from mlx_vgg16 import VGG16
from mlx_vgg19 import VGG19
from mlx_alexnet import AlexNet

def run_video_dream(args):
    print("--- DeepDream Video Generator ---")
    print(f"Model: {args.model}")
    print(f"Zoom: {args.zoom_factor}")
    print(f"Frames: {args.frames}")

    # 1. Load model
    if args.model == "vgg16":
        model = VGG16()
        default_layers = ["relu4_3"]
    elif args.model == "vgg19":
        model = VGG19()
        default_layers = ["relu4_4"]
    elif args.model == "resnet50":
        model = ResNet50()
        default_layers = ["layer4_2"]
    elif args.model == "alexnet":
        model = AlexNet()
        default_layers = ["relu5"]
    else:
        model = GoogLeNet()
        default_layers = ["inception4c"]

    weights = get_weights_path(args.model, args.weights)
    if not os.path.exists(weights):
        print(f"Error: Weights {weights} not found.")
        return

    print(f"Loading weights: {weights}")
    model.load_npz(weights)

    # 2. Prepare input
    img_np = load_image(args.input, args.width)

    # 3. Prepare output dir
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    current_img = img_np.astype(np.float32)

    # 4. Loop
    for i in range(args.frames):
        start_t = time.time()

        # Dream
        dreamed = deepdream(
            model,
            current_img,
            layers=args.layers or default_layers,
            steps=args.steps,
            lr=args.lr,
            num_octaves=args.octaves,
            scale=args.scale,
            jitter=args.jitter,
            smoothing=args.smoothing,
        )

        # Save frame
        frame_name = f"frame_{i:04d}.jpg"
        out_path = os.path.join(args.output_dir, frame_name)
        Image.fromarray(dreamed).save(out_path)

        elapsed = time.time() - start_t
        print(f"Frame {i+1}/{args.frames}: {frame_name} ({elapsed:.2f}s)")

        # Transform for next frame (zoom):
        # 1. Scale up by zoom_factor
        # 2. Center-crop back to the original size
        if i < args.frames - 1:
            # dreamed is (H, W, 3) uint8; convert back to float for the
            # zoom to avoid precision loss
            next_input = dreamed.astype(np.float32)

            # scipy zoom (order=1 is bilinear, usually sufficient and fast);
            # zoom H and W, keep the channel dimension fixed (zoom=1)
            zf = args.zoom_factor
            next_input = nd.zoom(next_input, (zf, zf, 1), order=1)

            # Center crop
            h_new, w_new, _ = next_input.shape
            h_orig, w_orig, _ = img_np.shape

            start_h = (h_new - h_orig) // 2
            start_w = (w_new - w_orig) // 2

            current_img = next_input[start_h:start_h + h_orig, start_w:start_w + w_orig, :]

    print(f"\nDone! Frames saved to {args.output_dir}/\n")
    print("To create a video (requires ffmpeg):")
    print(f"ffmpeg -framerate 15 -i {args.output_dir}/frame_%04d.jpg -c:v libx264 -pix_fmt yuv420p video.mp4")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True)
    parser.add_argument("--output_dir", default="frames")
    parser.add_argument("--frames", type=int, default=30)
    parser.add_argument("--zoom_factor", type=float, default=1.05)

    # Shared dream args
    parser.add_argument("--width", type=int, default=None)
    parser.add_argument("--model", default="googlenet")
    parser.add_argument("--weights", default=None)
    parser.add_argument("--layers", nargs="+")
    parser.add_argument("--steps", type=int, default=5)  # fewer steps usually gives smoother video
    parser.add_argument("--lr", type=float, default=0.05)
    parser.add_argument("--octaves", type=int, default=2)  # fewer octaves for speed
    parser.add_argument("--scale", type=float, default=1.4)
    parser.add_argument("--jitter", type=int, default=32)
    parser.add_argument("--smoothing", type=float, default=0.5)

    args = parser.parse_args()
    run_video_dream(args)
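The zoom-and-crop step is the only state carried between frames: upscale by `zoom_factor`, then center-crop back to the original size so the frame dimensions never drift. A numpy-only sketch of the same arithmetic (nearest-neighbor index sampling stands in for `scipy.ndimage.zoom` to keep the example dependency-free):

```python
import numpy as np

def zoom_center_crop(frame, zoom_factor):
    """Upscale by zoom_factor and center-crop back to the original size."""
    h, w, _ = frame.shape
    h_new = int(round(h * zoom_factor))
    w_new = int(round(w * zoom_factor))
    # Nearest-neighbor resize via index sampling (stand-in for nd.zoom)
    rows = (np.arange(h_new) * h / h_new).astype(int)
    cols = (np.arange(w_new) * w / w_new).astype(int)
    zoomed = frame[rows][:, cols]
    # Center crop back to (h, w)
    sh = (h_new - h) // 2
    sw = (w_new - w) // 2
    return zoomed[sh:sh + h, sw:sw + w]

frame = np.random.rand(120, 160, 3).astype(np.float32)
out = zoom_center_crop(frame, 1.05)
print(out.shape)  # (120, 160, 3)
```

Because the output shape equals the input shape every iteration, each dreamed frame can be fed straight back in, which is what produces the endless-zoom effect.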
mlx_alexnet.py
ADDED
|
@@ -0,0 +1,88 @@
"""
AlexNet in MLX with endpoints for relu1, relu2, relu3, relu4, relu5.
Loads weights from a torchvision-exported npz.
"""

import mlx.core as mx
import mlx.nn as nn
import numpy as np


def _conv(in_ch, out_ch, kernel_size, stride=1, padding=0):
    return nn.Conv2d(
        in_ch,
        out_ch,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=True,
    )


class AlexNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [
            _conv(3, 64, kernel_size=11, stride=4, padding=2),  # 0
            nn.ReLU(),                              # 1 (relu1)
            nn.MaxPool2d(kernel_size=3, stride=2),  # 2
            _conv(64, 192, kernel_size=5, padding=2),   # 3
            nn.ReLU(),                              # 4 (relu2)
            nn.MaxPool2d(kernel_size=3, stride=2),  # 5
            _conv(192, 384, kernel_size=3, padding=1),  # 6
            nn.ReLU(),                              # 7 (relu3)
            _conv(384, 256, kernel_size=3, padding=1),  # 8
            nn.ReLU(),                              # 9 (relu4)
            _conv(256, 256, kernel_size=3, padding=1),  # 10
            nn.ReLU(),                              # 11 (relu5)
            nn.MaxPool2d(kernel_size=3, stride=2),  # 12
        ]

        self.endpoint_indices = {
            "relu1": 1,
            "relu2": 4,
            "relu3": 7,
            "relu4": 9,
            "relu5": 11,
        }

    def forward_with_endpoints(self, x):
        endpoints = {}
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            for name, i in self.endpoint_indices.items():
                if idx == i:
                    endpoints[name] = x
        return x, endpoints

    def __call__(self, x):
        _, endpoints = self.forward_with_endpoints(x)
        return endpoints

    def load_npz(self, path: str):
        data = np.load(path)

        def load_weight(key, transpose=False):
            if key in data:
                w = data[key]
            elif f"{key}_int8" in data:
                # Dequantize: int8 weights stored alongside a float scale
                w_int8 = data[f"{key}_int8"]
                scale = data[f"{key}_scale"]
                w = w_int8.astype(scale.dtype) * scale
            else:
                raise ValueError(f"Missing key {key} in npz")

            # torchvision stores conv weights as (O, I, kH, kW);
            # MLX expects (O, kH, kW, I)
            if transpose and w.ndim == 4:
                w = np.transpose(w, (0, 2, 3, 1))
            return mx.array(w)

        # Map layer indices to 'features.X' in standard torchvision keys
        conv_indices = [0, 3, 6, 8, 10]

        for idx in conv_indices:
            conv = self.layers[idx]
            conv.weight = load_weight(f"features.{idx}.weight", transpose=True)
            conv.bias = load_weight(f"features.{idx}.bias")
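The int8 branch of `load_weight` reconstructs approximate float weights as `q * scale`. A round-trip sketch of that scheme (the symmetric per-tensor quantizer here is an assumption for illustration; only the dequantize step mirrors the loader):

```python
import numpy as np

def dequantize(w_int8, scale):
    # Mirrors the loader's int8 branch: w ≈ q * scale
    return w_int8.astype(scale.dtype) * scale

# Round-trip: quantize a float16 tensor to int8, then dequantize
w = np.linspace(-1.0, 1.0, 8, dtype=np.float16)
scale = np.float16(np.abs(w).max() / 127.0)          # symmetric per-tensor scale (assumed)
q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
w_back = dequantize(q, scale)

err = float(np.abs(w - w_back).max())
print(err < 0.01)  # True
```

The reconstruction error is bounded by roughly half the quantization step, which is why int8 storage is viable for DeepDream: the gradient-ascent loop is insensitive to weight noise at that scale.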
resnet50_places365.pth_mlx.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c7e4496e460a4cbec41e02f169c7be9c0e3cebe28036ac917105ba386471c47b
size 48691562

resnet50_places365_mlx.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c7e4496e460a4cbec41e02f169c7be9c0e3cebe28036ac917105ba386471c47b
size 48691562

resnet50_places365_t7_mlx.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cbfc6e4d63fb8824df3a8c60d82581106679b2061a654fd9d9ab62d798b94f99
size 48536532

toConvert/.gitkeep
ADDED
File without changes