Initial upload: Paris MoE inference code and weights
Browse files- .gitattributes +2 -0
- .ipynb_checkpoints/instructions-checkpoint.txt +1 -0
- .ipynb_checkpoints/test_inference-checkpoint.png +3 -0
- .ipynb_checkpoints/test_int8-checkpoint.png +3 -0
- README.md +154 -0
- __pycache__/generate.cpython-312.pyc +0 -0
- benchmark.py +440 -0
- benchmark_results.md +41 -0
- generate.py +747 -0
- instructions.txt +1 -0
- quantize.py +435 -0
- requirements.txt +7 -0
- src/__init__.py +1 -0
- src/__pycache__/config.cpython-312.pyc +0 -0
- src/__pycache__/models.cpython-312.pyc +0 -0
- src/__pycache__/schedules.cpython-312.pyc +0 -0
- src/__pycache__/vae_utils.cpython-312.pyc +0 -0
- src/config.py +199 -0
- src/models.py +1913 -0
- src/schedules.py +166 -0
- src/vae_utils.py +186 -0
- weights/bf16/config.pt +3 -0
- weights/bf16/expert_0.safetensors +3 -0
- weights/bf16/expert_1.safetensors +3 -0
- weights/bf16/expert_2.safetensors +3 -0
- weights/bf16/expert_3.safetensors +3 -0
- weights/bf16/expert_4.safetensors +3 -0
- weights/bf16/expert_5.safetensors +3 -0
- weights/bf16/expert_6.safetensors +3 -0
- weights/bf16/expert_7.safetensors +3 -0
- weights/bf16/router.safetensors +3 -0
- weights/bf16/router_config.pt +3 -0
- weights/int8/expert_0.safetensors +3 -0
- weights/int8/expert_1.safetensors +3 -0
- weights/int8/expert_2.safetensors +3 -0
- weights/int8/expert_3.safetensors +3 -0
- weights/int8/expert_4.safetensors +3 -0
- weights/int8/expert_5.safetensors +3 -0
- weights/int8/expert_6.safetensors +3 -0
- weights/int8/expert_7.safetensors +3 -0
- weights/int8/router.safetensors +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
.ipynb_checkpoints/test_inference-checkpoint.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
.ipynb_checkpoints/test_int8-checkpoint.png filter=lfs diff=lfs merge=lfs -text
|
.ipynb_checkpoints/instructions-checkpoint.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
What we're now is we're going to prepare an inference folder and we're going to make an inference repository for our Paris model. This will include, we will just stick with int8 and bfloat16 and mixed int8, bfloat16 for now. And this repository will include efficient methods how to run the code. It will include quantization code that can accept either PT or save tensors or Float32 save tensors or Float32 PT. It will include a lot of different methods that can accept either PT or Float32 PT. We need to make a visualizer next which outputs a little pretty ASCII chart. We should output the ASCII chart right on the terminal every time we run the inference via this tool. Let's just say we're running the int8 inference of the mixed int8 model. By the way, we're also going to put the weights that we quantized inside this inference folder because we're going to publish this on HuggingFace. have just again, the beef flow 16 and intake weights. we might already be done this by the way. But again, I wanted to do that when we have to keep some kind of track and output a chart in the terminal, like as a little terminal visualization in ASCII. MAKE SURE WE'RE DOING ROUTING PROPERLY. Top 2 etc. Again, just to recap, we're going to make a folder that's just called inference. In this folder, we're going to put the quantized weights that we already made, because we already made them before in the last session. So the bfloat16 and the int8 weights. And we're going to put one Python file for the inference code, and it's going to have all the flags, and it's also going to have a visualized flag. And the visualized flag is actually a lot more than that, because it keeps track of which expert is being used during each inference step, and that shows like a little pretty chart. So if we're generating with 30 steps, which is going to show which experts got to use the most and the least out of eight of them. And so we want to have this in the inference code. Make sure to read files in full before like a pass inference code that we already wrote. Try to list like the most recent files that we made for that. And we also want to have the quantization code to just be an all in one utility with a very nice terminal interface as well, because we want the quantization code to be able to handle float 16 bfloat 16 float 32 weights in both safe tensors and in dot pt format. So that needs to be very smart and tested that it actually works. And also, yeah, make a read me in this folder for the Paris model, because we're going to publish this on hogging face as the inference repository. And then also read all the MD files that we have written here in full because after we do all of this and after we test that it works and it differences fine. We're going to we're going to start to play around with network inference. So that's going to be the fun next step after. So again, make a 20 point to do this for this and please make sure to include at least four or five sentences per point. So the to do list is going to be very long, naturally and very detailed. But I believe we're going to do an excellent, excellent job here.
|
.ipynb_checkpoints/test_inference-checkpoint.png
ADDED
|
Git LFS Details
|
.ipynb_checkpoints/test_int8-checkpoint.png
ADDED
|
Git LFS Details
|
README.md
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🥖 Baguette - Paris MoE Text-to-Image
|
| 2 |
+
|
| 3 |
+
A ~5 billion parameter Mixture-of-Experts diffusion model with 8 specialized experts.
|
| 4 |
+
|
| 5 |
+
## ⚡ Quick Start
|
| 6 |
+
|
| 7 |
+
```bash
|
| 8 |
+
# Install dependencies
|
| 9 |
+
pip install uv && uv pip install torch torchvision safetensors transformers diffusers accelerate tqdm
|
| 10 |
+
|
| 11 |
+
# Generate 4 cat images
|
| 12 |
+
python generate.py --prompt "a cute cat" --num_samples 4
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
That's it! Images saved to `output_bf16.png`.
|
| 16 |
+
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
## 🎨 Examples
|
| 20 |
+
|
| 21 |
+
```bash
|
| 22 |
+
# Simple generation
|
| 23 |
+
python generate.py --prompt "sunset over mountains"
|
| 24 |
+
|
| 25 |
+
# More samples, see expert routing
|
| 26 |
+
python generate.py --prompt "abstract art" --num_samples 16 --visualize
|
| 27 |
+
|
| 28 |
+
# Faster with fewer steps
|
| 29 |
+
python generate.py --prompt "a dog" --num_steps 15
|
| 30 |
+
|
| 31 |
+
# Lower memory (offload 4 experts to CPU)
|
| 32 |
+
python generate.py --prompt "portrait" --offload 4
|
| 33 |
+
|
| 34 |
+
# INT8 weights (smaller, slightly lower quality)
|
| 35 |
+
python generate.py --prompt "forest" --precision int8
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
---
|
| 39 |
+
|
| 40 |
+
## 📋 All Options
|
| 41 |
+
|
| 42 |
+
| Flag | Default | Description |
|
| 43 |
+
|------|---------|-------------|
|
| 44 |
+
| `--prompt` | "a cute cat" | What to generate |
|
| 45 |
+
| `--num_samples` | 16 | Number of images |
|
| 46 |
+
| `--num_steps` | 30 | Sampling steps (20-50 recommended) |
|
| 47 |
+
| `--cfg_scale` | 7.5 | Guidance strength (5-10 recommended) |
|
| 48 |
+
| `--precision` | bf16 | `bf16` (best) or `int8` (smaller) |
|
| 49 |
+
| `--topk` | 2 | Experts per sample (1 or 2) |
|
| 50 |
+
| `--offload` | 0 | Experts to keep on CPU (0-7) |
|
| 51 |
+
| `--visualize` | off | Show expert routing stats |
|
| 52 |
+
| `--output` | auto | Output filename |
|
| 53 |
+
| `--seed` | 999 | Random seed |
|
| 54 |
+
|
| 55 |
+
---
|
| 56 |
+
|
| 57 |
+
## 🔍 Expert Visualization
|
| 58 |
+
|
| 59 |
+
Use `--visualize` to see which experts the router selects:
|
| 60 |
+
|
| 61 |
+
```
|
| 62 |
+
╭──────────────────────────────────────────────────╮
|
| 63 |
+
│ ⚡ EXPERT USAGE DISTRIBUTION │
|
| 64 |
+
├──────────────────────────────────────────────────┤
|
| 65 |
+
│ → E4 │████████████████████████████│ 40.6% │
|
| 66 |
+
│ E2 │██████████████████████████ │ 36.7% │
|
| 67 |
+
│ E6 │██████████ │ 14.8% │
|
| 68 |
+
│ E1 │███ │ 5.5% │
|
| 69 |
+
│ E5 │█ │ 2.3% │
|
| 70 |
+
│ E0 │ │ 0.0% │
|
| 71 |
+
│ E3 │ │ 0.0% │
|
| 72 |
+
│ E7 │ │ 0.0% │
|
| 73 |
+
├──────────────────────────────────────────────────┤
|
| 74 |
+
│ Active: 5/8 experts Calls: 128 │
|
| 75 |
+
╰──────────────────────────────────────────────────╯
|
| 76 |
+
|
| 77 |
+
╭──────────────────────────────────────────────────╮
|
| 78 |
+
│ 📈 ROUTING TIMELINE │
|
| 79 |
+
├──────────────────────────────────────────────────┤
|
| 80 |
+
│ Step 0 1 2 3 4 5 6 7 8 9 10 11 ... │
|
| 81 |
+
│ ──────────────────────────────────────────── │
|
| 82 |
+
│ E0 · · · · · · · · · · · · │
|
| 83 |
+
│ E2 · · · · · · ● ● ● ● ● ● │
|
| 84 |
+
│ E4 · · ● ● ● ● · · · · · · │
|
| 85 |
+
│ E6 ● ● · · · · · · · · · · │
|
| 86 |
+
├──────────────────────────────────────────────────┤
|
| 87 |
+
│ Routing changes: 2/11 steps (18%) │
|
| 88 |
+
╰──────────────────────────────────────────────────╯
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
---
|
| 92 |
+
|
| 93 |
+
## 💾 Memory & Speed
|
| 94 |
+
|
| 95 |
+
| Config | GPU Memory | Speed |
|
| 96 |
+
|--------|-----------|-------|
|
| 97 |
+
| BF16 (all on GPU) | ~25 GB | ~3 img/s |
|
| 98 |
+
| BF16 + offload 4 | ~14 GB | ~1 img/s |
|
| 99 |
+
| INT8 (all on GPU) | ~12 GB | ~2 img/s |
|
| 100 |
+
| INT8 + offload 4 | ~8 GB | ~0.5 img/s |
|
| 101 |
+
|
| 102 |
+
---
|
| 103 |
+
|
| 104 |
+
## 🏗️ Architecture
|
| 105 |
+
|
| 106 |
+
```
|
| 107 |
+
┌─────────────────────────────────────────┐
|
| 108 |
+
│ Paris MoE Model │
|
| 109 |
+
├─────────────────────────────────────────┤
|
| 110 |
+
│ Router: DiT-B/2 (129M params) │
|
| 111 |
+
│ ↓ selects top-K experts │
|
| 112 |
+
│ Experts: 8× DiT-XL/2 (606M each) │
|
| 113 |
+
│ ↓ predicts velocity │
|
| 114 |
+
│ VAE: Stable Diffusion VAE │
|
| 115 |
+
�� ↓ decodes to pixels │
|
| 116 |
+
│ Output: 256×256 RGB │
|
| 117 |
+
└─────────────────────────────────────────┘
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
- **Total Parameters**: ~5 Billion
|
| 121 |
+
- **Latent Space**: 32×32×4
|
| 122 |
+
- **Text Encoder**: CLIP ViT-L/14
|
| 123 |
+
|
| 124 |
+
---
|
| 125 |
+
|
| 126 |
+
## 📁 Files
|
| 127 |
+
|
| 128 |
+
```
|
| 129 |
+
├── generate.py # Main generation script
|
| 130 |
+
├── benchmark.py # Performance testing
|
| 131 |
+
├── quantize.py # Weight conversion tool
|
| 132 |
+
├── src/ # Model code
|
| 133 |
+
└── weights/
|
| 134 |
+
├── bf16/ # BFloat16 weights (9.3 GB)
|
| 135 |
+
└── int8/ # INT8 weights (4.8 GB)
|
| 136 |
+
```
|
| 137 |
+
|
| 138 |
+
---
|
| 139 |
+
|
| 140 |
+
## 🔧 Convert Your Own Weights
|
| 141 |
+
|
| 142 |
+
```bash
|
| 143 |
+
# From PyTorch .pt to BF16 safetensors
|
| 144 |
+
python quantize.py --input /path/to/weights --output ./weights/bf16 --format bf16
|
| 145 |
+
|
| 146 |
+
# From BF16 to INT8
|
| 147 |
+
python quantize.py --input ./weights/bf16 --output ./weights/int8 --format int8
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
---
|
| 151 |
+
|
| 152 |
+
## 📜 License
|
| 153 |
+
|
| 154 |
+
Apache 2.0
|
__pycache__/generate.cpython-312.pyc
ADDED
|
Binary file (37.7 kB). View file
|
|
|
benchmark.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
╔══════════════════════════════════════════════════════════════════════════════╗
|
| 4 |
+
║ ║
|
| 5 |
+
║ 📊 Paris MoE - Comprehensive Benchmarking Utility 📊 ║
|
| 6 |
+
║ ║
|
| 7 |
+
║ Measures performance across precision modes, batch sizes, and configs. ║
|
| 8 |
+
║ Outputs results as both terminal display and Markdown file. ║
|
| 9 |
+
║ ║
|
| 10 |
+
╚══════════════════════════════════════════════════════════════════════════════╝
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python benchmark.py # Run all benchmarks
|
| 14 |
+
python benchmark.py --quick # Quick benchmark (fewer configs)
|
| 15 |
+
python benchmark.py --precision bf16 # Benchmark specific precision
|
| 16 |
+
python benchmark.py --output results.md # Save results to file
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import sys
|
| 21 |
+
import os
|
| 22 |
+
import time
|
| 23 |
+
import gc
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
from datetime import datetime
|
| 26 |
+
from dataclasses import dataclass
|
| 27 |
+
from typing import List, Dict, Optional
|
| 28 |
+
|
| 29 |
+
SCRIPT_DIR = Path(__file__).parent.absolute()
|
| 30 |
+
SRC_DIR = SCRIPT_DIR / "src"
|
| 31 |
+
sys.path.insert(0, str(SRC_DIR))
|
| 32 |
+
|
| 33 |
+
import torch
|
| 34 |
+
|
| 35 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 36 |
+
# DATA STRUCTURES
|
| 37 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 38 |
+
|
| 39 |
+
@dataclass
|
| 40 |
+
class BenchmarkResult:
|
| 41 |
+
"""Single benchmark result."""
|
| 42 |
+
precision: str
|
| 43 |
+
num_samples: int
|
| 44 |
+
num_steps: int
|
| 45 |
+
topk: int
|
| 46 |
+
offload: int
|
| 47 |
+
|
| 48 |
+
load_time: float # Model loading time (seconds)
|
| 49 |
+
gen_time: float # Generation time (seconds)
|
| 50 |
+
decode_time: float # VAE decoding time (seconds)
|
| 51 |
+
|
| 52 |
+
peak_memory_gb: float # Peak GPU memory usage
|
| 53 |
+
|
| 54 |
+
@property
|
| 55 |
+
def total_time(self) -> float:
|
| 56 |
+
return self.gen_time + self.decode_time
|
| 57 |
+
|
| 58 |
+
@property
|
| 59 |
+
def throughput(self) -> float:
|
| 60 |
+
"""Images per second (generation only)."""
|
| 61 |
+
return self.num_samples / self.gen_time if self.gen_time > 0 else 0
|
| 62 |
+
|
| 63 |
+
@property
|
| 64 |
+
def time_per_step(self) -> float:
|
| 65 |
+
"""Seconds per sampling step."""
|
| 66 |
+
return self.gen_time / self.num_steps if self.num_steps > 0 else 0
|
| 67 |
+
|
| 68 |
+
@property
|
| 69 |
+
def time_per_image(self) -> float:
|
| 70 |
+
"""Seconds per image (generation only)."""
|
| 71 |
+
return self.gen_time / self.num_samples if self.num_samples > 0 else 0
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 75 |
+
# BENCHMARK RUNNER
|
| 76 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 77 |
+
|
| 78 |
+
def get_gpu_memory_gb() -> float:
|
| 79 |
+
"""Get current GPU memory usage in GB."""
|
| 80 |
+
if torch.cuda.is_available():
|
| 81 |
+
return torch.cuda.max_memory_allocated() / (1024 ** 3)
|
| 82 |
+
return 0.0
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def reset_gpu_memory():
|
| 86 |
+
"""Reset GPU memory tracking."""
|
| 87 |
+
if torch.cuda.is_available():
|
| 88 |
+
torch.cuda.reset_peak_memory_stats()
|
| 89 |
+
torch.cuda.empty_cache()
|
| 90 |
+
gc.collect()
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def run_single_benchmark(precision: str, num_samples: int, num_steps: int,
|
| 94 |
+
topk: int, offload: int, device: str = 'cuda') -> BenchmarkResult:
|
| 95 |
+
"""Run a single benchmark configuration."""
|
| 96 |
+
from generate import load_sampler
|
| 97 |
+
|
| 98 |
+
reset_gpu_memory()
|
| 99 |
+
|
| 100 |
+
# Load model
|
| 101 |
+
start_load = time.time()
|
| 102 |
+
sampler = load_sampler(precision=precision, device=device, offload=offload)
|
| 103 |
+
load_time = time.time() - start_load
|
| 104 |
+
|
| 105 |
+
# Set seed for reproducibility
|
| 106 |
+
torch.manual_seed(42)
|
| 107 |
+
if torch.cuda.is_available():
|
| 108 |
+
torch.cuda.manual_seed(42)
|
| 109 |
+
|
| 110 |
+
# Warmup run
|
| 111 |
+
_ = sampler.sample(
|
| 112 |
+
num_samples=1,
|
| 113 |
+
text_prompts=["warmup"],
|
| 114 |
+
cfg_scale=7.5,
|
| 115 |
+
num_steps=2,
|
| 116 |
+
use_bf16=(precision == 'bf16'),
|
| 117 |
+
topk=topk
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
reset_gpu_memory()
|
| 121 |
+
torch.cuda.synchronize()
|
| 122 |
+
|
| 123 |
+
# Timed generation
|
| 124 |
+
start_gen = time.time()
|
| 125 |
+
latents = sampler.sample(
|
| 126 |
+
num_samples=num_samples,
|
| 127 |
+
text_prompts=["a cute cat"],
|
| 128 |
+
cfg_scale=7.5,
|
| 129 |
+
num_steps=num_steps,
|
| 130 |
+
use_bf16=(precision == 'bf16'),
|
| 131 |
+
topk=topk
|
| 132 |
+
)
|
| 133 |
+
torch.cuda.synchronize()
|
| 134 |
+
gen_time = time.time() - start_gen
|
| 135 |
+
|
| 136 |
+
# Timed decoding
|
| 137 |
+
start_decode = time.time()
|
| 138 |
+
images = sampler.vae_manager.decode(latents)
|
| 139 |
+
torch.cuda.synchronize()
|
| 140 |
+
decode_time = time.time() - start_decode
|
| 141 |
+
|
| 142 |
+
peak_memory = get_gpu_memory_gb()
|
| 143 |
+
|
| 144 |
+
# Cleanup
|
| 145 |
+
del sampler, latents, images
|
| 146 |
+
gc.collect()
|
| 147 |
+
torch.cuda.empty_cache()
|
| 148 |
+
|
| 149 |
+
return BenchmarkResult(
|
| 150 |
+
precision=precision,
|
| 151 |
+
num_samples=num_samples,
|
| 152 |
+
num_steps=num_steps,
|
| 153 |
+
topk=topk,
|
| 154 |
+
offload=offload,
|
| 155 |
+
load_time=load_time,
|
| 156 |
+
gen_time=gen_time,
|
| 157 |
+
decode_time=decode_time,
|
| 158 |
+
peak_memory_gb=peak_memory
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 163 |
+
# OUTPUT FORMATTERS
|
| 164 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 165 |
+
|
| 166 |
+
def format_terminal_results(results: List[BenchmarkResult], gpu_name: str) -> str:
|
| 167 |
+
"""Format results for terminal display."""
|
| 168 |
+
lines = []
|
| 169 |
+
|
| 170 |
+
lines.append("""
|
| 171 |
+
╔══════════════════════════════════════════════════════════════════════════════╗
|
| 172 |
+
║ 📊 PARIS MoE BENCHMARK RESULTS 📊 ║
|
| 173 |
+
╚══════════════════════════════════════════════════════════════════════════════╝
|
| 174 |
+
""")
|
| 175 |
+
|
| 176 |
+
lines.append(f" GPU: {gpu_name}")
|
| 177 |
+
lines.append(f" Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
| 178 |
+
lines.append("")
|
| 179 |
+
|
| 180 |
+
# Group by precision
|
| 181 |
+
precisions = sorted(set(r.precision for r in results))
|
| 182 |
+
|
| 183 |
+
for precision in precisions:
|
| 184 |
+
prec_results = [r for r in results if r.precision == precision]
|
| 185 |
+
|
| 186 |
+
lines.append(f"┌{'─'*78}┐")
|
| 187 |
+
lines.append(f"│ {precision.upper()} Precision{' '*65}│")
|
| 188 |
+
lines.append(f"├{'─'*78}┤")
|
| 189 |
+
lines.append(f"│ {'Samples':>8} │ {'Steps':>6} │ {'TopK':>5} │ {'Offload':>7} │ "
|
| 190 |
+
f"{'Gen(s)':>8} │ {'Img/s':>6} │ {'s/step':>6} │ {'Mem(GB)':>8} │")
|
| 191 |
+
lines.append(f"├{'─'*78}┤")
|
| 192 |
+
|
| 193 |
+
for r in prec_results:
|
| 194 |
+
lines.append(
|
| 195 |
+
f"│ {r.num_samples:>8} │ {r.num_steps:>6} │ {r.topk:>5} │ {r.offload:>7} │ "
|
| 196 |
+
f"{r.gen_time:>8.2f} │ {r.throughput:>6.2f} │ {r.time_per_step:>6.3f} │ "
|
| 197 |
+
f"{r.peak_memory_gb:>8.2f} │"
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
lines.append(f"└{'─'*78}┘")
|
| 201 |
+
lines.append("")
|
| 202 |
+
|
| 203 |
+
# Summary
|
| 204 |
+
if results:
|
| 205 |
+
fastest = min(results, key=lambda r: r.time_per_image)
|
| 206 |
+
most_efficient = min(results, key=lambda r: r.peak_memory_gb)
|
| 207 |
+
|
| 208 |
+
lines.append("┌─────────────────────────────────────────────────────────────────┐")
|
| 209 |
+
lines.append("│ 📈 SUMMARY │")
|
| 210 |
+
lines.append("├─────────────────────────────────────────────────────────────────┤")
|
| 211 |
+
lines.append(f"│ 🏆 Fastest: {fastest.precision.upper():>6} @ {fastest.throughput:.2f} img/s │")
|
| 212 |
+
lines.append(f"│ 💾 Most Efficient: {most_efficient.precision.upper():>6} @ {most_efficient.peak_memory_gb:.1f} GB peak │")
|
| 213 |
+
lines.append("└─────────────────────────────────────────────────────────────────┘")
|
| 214 |
+
|
| 215 |
+
return "\n".join(lines)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def format_markdown_results(results: List[BenchmarkResult], gpu_name: str) -> str:
|
| 219 |
+
"""Format results as Markdown."""
|
| 220 |
+
lines = []
|
| 221 |
+
|
| 222 |
+
lines.append("# 📊 Paris MoE Benchmark Results")
|
| 223 |
+
lines.append("")
|
| 224 |
+
lines.append(f"**GPU:** {gpu_name}")
|
| 225 |
+
lines.append(f"**Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
| 226 |
+
lines.append("")
|
| 227 |
+
|
| 228 |
+
lines.append("## 🏗️ Model Architecture")
|
| 229 |
+
lines.append("")
|
| 230 |
+
lines.append("| Component | Details |")
|
| 231 |
+
lines.append("|-----------|---------|")
|
| 232 |
+
lines.append("| Experts | 8× DiT-XL/2 (606M params each) |")
|
| 233 |
+
lines.append("| Router | DiT-B/2 (129M params) |")
|
| 234 |
+
lines.append("| Total | ~5 Billion parameters |")
|
| 235 |
+
lines.append("| VAE | SD-VAE (stabilityai/sd-vae-ft-mse) |")
|
| 236 |
+
lines.append("| Text Encoder | CLIP ViT-L/14 |")
|
| 237 |
+
lines.append("")
|
| 238 |
+
|
| 239 |
+
# Group by precision
|
| 240 |
+
precisions = sorted(set(r.precision for r in results))
|
| 241 |
+
|
| 242 |
+
for precision in precisions:
|
| 243 |
+
prec_results = [r for r in results if r.precision == precision]
|
| 244 |
+
|
| 245 |
+
lines.append(f"## {precision.upper()} Precision")
|
| 246 |
+
lines.append("")
|
| 247 |
+
lines.append("| Samples | Steps | TopK | Offload | Gen Time (s) | Throughput (img/s) | Time/Step (s) | Peak Memory (GB) |")
|
| 248 |
+
lines.append("|---------|-------|------|---------|--------------|-------------------|---------------|------------------|")
|
| 249 |
+
|
| 250 |
+
for r in prec_results:
|
| 251 |
+
lines.append(
|
| 252 |
+
f"| {r.num_samples} | {r.num_steps} | {r.topk} | {r.offload} | "
|
| 253 |
+
f"{r.gen_time:.2f} | {r.throughput:.2f} | {r.time_per_step:.3f} | {r.peak_memory_gb:.2f} |"
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
lines.append("")
|
| 257 |
+
|
| 258 |
+
# Summary
|
| 259 |
+
if results:
|
| 260 |
+
lines.append("## 📈 Summary")
|
| 261 |
+
lines.append("")
|
| 262 |
+
|
| 263 |
+
fastest = min(results, key=lambda r: r.time_per_image)
|
| 264 |
+
most_efficient = min(results, key=lambda r: r.peak_memory_gb)
|
| 265 |
+
|
| 266 |
+
lines.append(f"- **🏆 Fastest Configuration:** {fastest.precision.upper()}, "
|
| 267 |
+
f"{fastest.num_samples} samples @ {fastest.throughput:.2f} img/s")
|
| 268 |
+
lines.append(f"- **💾 Most Memory Efficient:** {most_efficient.precision.upper()} "
|
| 269 |
+
f"with offload={most_efficient.offload} @ {most_efficient.peak_memory_gb:.1f} GB peak")
|
| 270 |
+
lines.append("")
|
| 271 |
+
|
| 272 |
+
# Recommendations
|
| 273 |
+
lines.append("## 🎯 Recommendations")
|
| 274 |
+
lines.append("")
|
| 275 |
+
lines.append("| Use Case | Precision | Offload | Expected Performance |")
|
| 276 |
+
lines.append("|----------|-----------|---------|---------------------|")
|
| 277 |
+
|
| 278 |
+
bf16_results = [r for r in results if r.precision == 'bf16' and r.offload == 0]
|
| 279 |
+
if bf16_results:
|
| 280 |
+
r = bf16_results[0]
|
| 281 |
+
lines.append(f"| **Production (Quality)** | BF16 | 0 | {r.throughput:.2f} img/s, {r.peak_memory_gb:.1f} GB |")
|
| 282 |
+
|
| 283 |
+
int8_results = [r for r in results if r.precision == 'int8' and r.offload == 0]
|
| 284 |
+
if int8_results:
|
| 285 |
+
r = int8_results[0]
|
| 286 |
+
lines.append(f"| **Balanced** | INT8 | 0 | {r.throughput:.2f} img/s, {r.peak_memory_gb:.1f} GB |")
|
| 287 |
+
|
| 288 |
+
offload_results = [r for r in results if r.offload > 0]
|
| 289 |
+
if offload_results:
|
| 290 |
+
r = min(offload_results, key=lambda x: x.peak_memory_gb)
|
| 291 |
+
lines.append(f"| **Low VRAM** | {r.precision.upper()} | {r.offload} | {r.throughput:.2f} img/s, {r.peak_memory_gb:.1f} GB |")
|
| 292 |
+
|
| 293 |
+
lines.append("")
|
| 294 |
+
lines.append("---")
|
| 295 |
+
lines.append("*Generated by Paris MoE Benchmark Utility*")
|
| 296 |
+
|
| 297 |
+
return "\n".join(lines)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 301 |
+
# MAIN
|
| 302 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 303 |
+
|
| 304 |
+
def parse_args():
|
| 305 |
+
parser = argparse.ArgumentParser(
|
| 306 |
+
description="📊 Paris MoE - Benchmark Utility",
|
| 307 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 308 |
+
epilog="""
|
| 309 |
+
Examples:
|
| 310 |
+
python benchmark.py # Full benchmark suite
|
| 311 |
+
python benchmark.py --quick # Quick benchmark
|
| 312 |
+
python benchmark.py --precision bf16 # BF16 only
|
| 313 |
+
python benchmark.py --output results.md # Save to file
|
| 314 |
+
"""
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
parser.add_argument("--quick", action="store_true",
|
| 318 |
+
help="Run quick benchmark with fewer configurations")
|
| 319 |
+
parser.add_argument("--precision", type=str, default=None,
|
| 320 |
+
choices=["bf16", "int8", "mixed"],
|
| 321 |
+
help="Benchmark specific precision only")
|
| 322 |
+
parser.add_argument("--output", "-o", type=str, default=None,
|
| 323 |
+
help="Output Markdown file path")
|
| 324 |
+
parser.add_argument("--samples", type=int, default=None,
|
| 325 |
+
help="Override number of samples")
|
| 326 |
+
parser.add_argument("--steps", type=int, default=None,
|
| 327 |
+
help="Override number of steps")
|
| 328 |
+
|
| 329 |
+
return parser.parse_args()
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def get_benchmark_configs(args) -> List[Dict]:
|
| 333 |
+
"""Get list of benchmark configurations to run."""
|
| 334 |
+
configs = []
|
| 335 |
+
|
| 336 |
+
if args.quick:
|
| 337 |
+
# Quick benchmark: minimal configs
|
| 338 |
+
precisions = [args.precision] if args.precision else ['bf16', 'int8']
|
| 339 |
+
samples = args.samples or 4
|
| 340 |
+
steps = args.steps or 10
|
| 341 |
+
|
| 342 |
+
for precision in precisions:
|
| 343 |
+
configs.append({
|
| 344 |
+
'precision': precision,
|
| 345 |
+
'num_samples': samples,
|
| 346 |
+
'num_steps': steps,
|
| 347 |
+
'topk': 1,
|
| 348 |
+
'offload': 0
|
| 349 |
+
})
|
| 350 |
+
else:
|
| 351 |
+
# Full benchmark suite
|
| 352 |
+
precisions = [args.precision] if args.precision else ['bf16', 'int8']
|
| 353 |
+
samples_list = [args.samples] if args.samples else [4, 16]
|
| 354 |
+
steps_list = [args.steps] if args.steps else [20, 30]
|
| 355 |
+
topk_list = [1, 2]
|
| 356 |
+
offload_list = [0, 4]
|
| 357 |
+
|
| 358 |
+
for precision in precisions:
|
| 359 |
+
for samples in samples_list:
|
| 360 |
+
for steps in steps_list:
|
| 361 |
+
for topk in topk_list:
|
| 362 |
+
for offload in offload_list:
|
| 363 |
+
configs.append({
|
| 364 |
+
'precision': precision,
|
| 365 |
+
'num_samples': samples,
|
| 366 |
+
'num_steps': steps,
|
| 367 |
+
'topk': topk,
|
| 368 |
+
'offload': offload
|
| 369 |
+
})
|
| 370 |
+
|
| 371 |
+
return configs
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def main():
|
| 375 |
+
args = parse_args()
|
| 376 |
+
|
| 377 |
+
print("""
|
| 378 |
+
╔══════════════════════════════════════════════════════════════════════════════╗
|
| 379 |
+
║ ║
|
| 380 |
+
║ 📊 Paris MoE - Comprehensive Benchmarking Utility 📊 ║
|
| 381 |
+
║ ║
|
| 382 |
+
║ Measuring performance across precision modes, batch sizes, and configs. ║
|
| 383 |
+
║ ║
|
| 384 |
+
╚══════════════════════════════════════════════════════════════════════════════╝
|
| 385 |
+
""")
|
| 386 |
+
|
| 387 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 388 |
+
if device != "cuda":
|
| 389 |
+
print("⚠️ Warning: Running on CPU. Benchmarks will be slow.")
|
| 390 |
+
|
| 391 |
+
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
|
| 392 |
+
print(f"🖥️ Device: {gpu_name}")
|
| 393 |
+
|
| 394 |
+
configs = get_benchmark_configs(args)
|
| 395 |
+
print(f"📋 Running {len(configs)} benchmark configurations...\n")
|
| 396 |
+
|
| 397 |
+
results = []
|
| 398 |
+
|
| 399 |
+
for i, config in enumerate(configs):
|
| 400 |
+
print(f"[{i+1}/{len(configs)}] {config['precision'].upper()} | "
|
| 401 |
+
f"{config['num_samples']} samples | {config['num_steps']} steps | "
|
| 402 |
+
f"Top-{config['topk']} | Offload {config['offload']}")
|
| 403 |
+
|
| 404 |
+
try:
|
| 405 |
+
result = run_single_benchmark(
|
| 406 |
+
precision=config['precision'],
|
| 407 |
+
num_samples=config['num_samples'],
|
| 408 |
+
num_steps=config['num_steps'],
|
| 409 |
+
topk=config['topk'],
|
| 410 |
+
offload=config['offload'],
|
| 411 |
+
device=device
|
| 412 |
+
)
|
| 413 |
+
results.append(result)
|
| 414 |
+
print(f" ✅ {result.gen_time:.2f}s, {result.throughput:.2f} img/s, "
|
| 415 |
+
f"{result.peak_memory_gb:.1f} GB peak")
|
| 416 |
+
except Exception as e:
|
| 417 |
+
print(f" ❌ Failed: {e}")
|
| 418 |
+
|
| 419 |
+
print()
|
| 420 |
+
|
| 421 |
+
if not results:
|
| 422 |
+
print("❌ No successful benchmarks!")
|
| 423 |
+
return 1
|
| 424 |
+
|
| 425 |
+
# Print terminal results
|
| 426 |
+
terminal_output = format_terminal_results(results, gpu_name)
|
| 427 |
+
print(terminal_output)
|
| 428 |
+
|
| 429 |
+
# Save Markdown if requested
|
| 430 |
+
if args.output:
|
| 431 |
+
md_output = format_markdown_results(results, gpu_name)
|
| 432 |
+
with open(args.output, 'w') as f:
|
| 433 |
+
f.write(md_output)
|
| 434 |
+
print(f"\n✅ Results saved to: {args.output}")
|
| 435 |
+
|
| 436 |
+
return 0
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
if __name__ == "__main__":
|
| 440 |
+
exit(main())
|
benchmark_results.md
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 📊 Paris MoE Benchmark Results
|
| 2 |
+
|
| 3 |
+
**GPU:** NVIDIA RTX 6000 Ada Generation
|
| 4 |
+
**Date:** 2025-12-05 16:35:39
|
| 5 |
+
|
| 6 |
+
## 🏗️ Model Architecture
|
| 7 |
+
|
| 8 |
+
| Component | Details |
|
| 9 |
+
|-----------|---------|
|
| 10 |
+
| Experts | 8× DiT-XL/2 (606M params each) |
|
| 11 |
+
| Router | DiT-B/2 (129M params) |
|
| 12 |
+
| Total | ~5 Billion parameters |
|
| 13 |
+
| VAE | SD-VAE (stabilityai/sd-vae-ft-mse) |
|
| 14 |
+
| Text Encoder | CLIP ViT-L/14 |
|
| 15 |
+
|
| 16 |
+
## BF16 Precision
|
| 17 |
+
|
| 18 |
+
| Samples | Steps | TopK | Offload | Gen Time (s) | Throughput (img/s) | Time/Step (s) | Peak Memory (GB) |
|
| 19 |
+
|---------|-------|------|---------|--------------|-------------------|---------------|------------------|
|
| 20 |
+
| 4 | 10 | 1 | 0 | 1.49 | 2.68 | 0.149 | 10.79 |
|
| 21 |
+
|
| 22 |
+
## INT8 Precision
|
| 23 |
+
|
| 24 |
+
| Samples | Steps | TopK | Offload | Gen Time (s) | Throughput (img/s) | Time/Step (s) | Peak Memory (GB) |
|
| 25 |
+
|---------|-------|------|---------|--------------|-------------------|---------------|------------------|
|
| 26 |
+
| 4 | 10 | 1 | 0 | 2.12 | 1.89 | 0.212 | 20.17 |
|
| 27 |
+
|
| 28 |
+
## 📈 Summary
|
| 29 |
+
|
| 30 |
+
- **🏆 Fastest Configuration:** BF16, 4 samples @ 2.68 img/s
|
| 31 |
+
- **💾 Most Memory Efficient:** BF16 with offload=0 @ 10.8 GB peak
|
| 32 |
+
|
| 33 |
+
## 🎯 Recommendations
|
| 34 |
+
|
| 35 |
+
| Use Case | Precision | Offload | Expected Performance |
|
| 36 |
+
|----------|-----------|---------|---------------------|
|
| 37 |
+
| **Production (Quality)** | BF16 | 0 | 2.68 img/s, 10.8 GB |
|
| 38 |
+
| **Balanced** | INT8 | 0 | 1.89 img/s, 20.2 GB |
|
| 39 |
+
|
| 40 |
+
---
|
| 41 |
+
*Generated by Paris MoE Benchmark Utility*
|
generate.py
ADDED
|
@@ -0,0 +1,747 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
╔══════════════════════════════════════════════════════════════════════════════╗
|
| 4 |
+
║ ║
|
| 5 |
+
║ 🎨 Paris MoE - Unified Image Generation Script 🎨 ║
|
| 6 |
+
║ ║
|
| 7 |
+
║ Mixture-of-Experts Diffusion Model (8× DiT-XL/2 + DiT-B/2 Router) ║
|
| 8 |
+
║ ~5 Billion Parameters Total ║
|
| 9 |
+
║ ║
|
| 10 |
+
╚══════════════════════════════════════════════════════════════════════════════╝
|
| 11 |
+
|
| 12 |
+
Supports multiple precision modes:
|
| 13 |
+
- bf16: Best quality, 9.3GB total (~1.2GB per expert)
|
| 14 |
+
- int8: Good quality, 4.8GB total (~580MB per expert), 15x compression
|
| 15 |
+
- mixed: Router in bf16, experts in int8 (balanced)
|
| 16 |
+
|
| 17 |
+
Memory Offloading:
|
| 18 |
+
- --offload N: Keep N experts in CPU memory, move to GPU only during computation
|
| 19 |
+
- Experts are moved to GPU → compute → moved back to CPU (memory offloading)
|
| 20 |
+
- All computation happens on GPU, only storage is on CPU
|
| 21 |
+
|
| 22 |
+
Usage:
|
| 23 |
+
python generate.py --prompt "a cute cat" --precision bf16
|
| 24 |
+
python generate.py --prompt "a sunset over mountains" --precision int8 --visualize
|
| 25 |
+
python generate.py --prompt "abstract art" --precision mixed --num_samples 4 --topk 2
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
import argparse
|
| 29 |
+
import sys
|
| 30 |
+
import os
|
| 31 |
+
import time
|
| 32 |
+
from pathlib import Path
|
| 33 |
+
|
| 34 |
+
# Add src to path for model imports
|
| 35 |
+
SCRIPT_DIR = Path(__file__).parent.absolute()
|
| 36 |
+
SRC_DIR = SCRIPT_DIR / "src"
|
| 37 |
+
sys.path.insert(0, str(SRC_DIR))
|
| 38 |
+
|
| 39 |
+
import torch
|
| 40 |
+
import torch.nn.functional as F
|
| 41 |
+
from tqdm import tqdm
|
| 42 |
+
from torchvision.utils import make_grid, save_image
|
| 43 |
+
from safetensors.torch import load_file
|
| 44 |
+
from safetensors import safe_open
|
| 45 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
| 46 |
+
from collections import defaultdict
|
| 47 |
+
|
| 48 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 49 |
+
# WEIGHT PATHS
|
| 50 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 51 |
+
|
| 52 |
+
WEIGHTS_DIR = SCRIPT_DIR / "weights"
|
| 53 |
+
BF16_DIR = WEIGHTS_DIR / "bf16"
|
| 54 |
+
INT8_DIR = WEIGHTS_DIR / "int8"
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 58 |
+
# ASCII VISUALIZATION
|
| 59 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 60 |
+
|
| 61 |
+
class ExpertTracker:
|
| 62 |
+
"""Tracks which experts are used during generation for visualization."""
|
| 63 |
+
|
| 64 |
+
def __init__(self, num_experts: int = 8):
|
| 65 |
+
self.num_experts = num_experts
|
| 66 |
+
self.usage_counts = defaultdict(int)
|
| 67 |
+
self.per_step_primary = []
|
| 68 |
+
self.total_calls = 0
|
| 69 |
+
|
| 70 |
+
def record(self, expert_ids: torch.Tensor, step: int, weights: torch.Tensor = None):
|
| 71 |
+
"""Record expert usage for a batch at a given step."""
|
| 72 |
+
step_counts = defaultdict(float)
|
| 73 |
+
|
| 74 |
+
if weights is not None:
|
| 75 |
+
for eid, w in zip(expert_ids.flatten().tolist(), weights.flatten().tolist()):
|
| 76 |
+
self.usage_counts[eid] += 1
|
| 77 |
+
step_counts[eid] += w
|
| 78 |
+
self.total_calls += 1
|
| 79 |
+
else:
|
| 80 |
+
for eid in expert_ids.tolist():
|
| 81 |
+
self.usage_counts[eid] += 1
|
| 82 |
+
step_counts[eid] += 1.0
|
| 83 |
+
self.total_calls += 1
|
| 84 |
+
|
| 85 |
+
if step_counts:
|
| 86 |
+
self.per_step_primary.append(max(step_counts, key=step_counts.get))
|
| 87 |
+
|
| 88 |
+
def get_usage_chart(self) -> str:
|
| 89 |
+
"""Chart 1: Expert usage ranked by frequency."""
|
| 90 |
+
if self.total_calls == 0:
|
| 91 |
+
return ""
|
| 92 |
+
|
| 93 |
+
# Sort experts by usage
|
| 94 |
+
sorted_experts = sorted(range(8), key=lambda e: self.usage_counts.get(e, 0), reverse=True)
|
| 95 |
+
max_count = max(self.usage_counts.values()) if self.usage_counts else 1
|
| 96 |
+
unique = sum(1 for e in range(8) if self.usage_counts.get(e, 0) > 0)
|
| 97 |
+
|
| 98 |
+
lines = [
|
| 99 |
+
"",
|
| 100 |
+
"╭────────────────��─────────────────────────────────╮",
|
| 101 |
+
"│ ⚡ EXPERT USAGE DISTRIBUTION │",
|
| 102 |
+
"├──────────────────────────────────────────────────┤",
|
| 103 |
+
]
|
| 104 |
+
|
| 105 |
+
bars = ['▏', '▎', '▍', '▌', '▋', '▊', '▉', '█']
|
| 106 |
+
|
| 107 |
+
for eid in sorted_experts:
|
| 108 |
+
count = self.usage_counts.get(eid, 0)
|
| 109 |
+
pct = 100 * count / self.total_calls if self.total_calls > 0 else 0
|
| 110 |
+
|
| 111 |
+
# Create gradient bar
|
| 112 |
+
bar_width = 28
|
| 113 |
+
fill = (count / max_count) * bar_width if max_count > 0 else 0
|
| 114 |
+
full_blocks = int(fill)
|
| 115 |
+
partial = int((fill - full_blocks) * 8)
|
| 116 |
+
|
| 117 |
+
bar = '█' * full_blocks
|
| 118 |
+
if partial > 0 and full_blocks < bar_width:
|
| 119 |
+
bar += bars[partial - 1]
|
| 120 |
+
bar = bar.ljust(bar_width, ' ')
|
| 121 |
+
|
| 122 |
+
marker = "→" if count == max_count and count > 0 else " "
|
| 123 |
+
lines.append(f"│ {marker} E{eid} │{bar}│ {pct:5.1f}% │")
|
| 124 |
+
|
| 125 |
+
lines.extend([
|
| 126 |
+
"├──────────────────────────────────────────────────┤",
|
| 127 |
+
f"│ Active: {unique}/8 experts Calls: {self.total_calls:<13} │",
|
| 128 |
+
"╰──────────────────────────────────────────────────╯",
|
| 129 |
+
])
|
| 130 |
+
|
| 131 |
+
return "\n".join(lines)
|
| 132 |
+
|
| 133 |
+
def get_timeline(self) -> str:
|
| 134 |
+
"""Chart 2: Visual timeline of expert selection per step."""
|
| 135 |
+
if not self.per_step_primary:
|
| 136 |
+
return ""
|
| 137 |
+
|
| 138 |
+
num_steps = len(self.per_step_primary)
|
| 139 |
+
show_steps = min(20, num_steps)
|
| 140 |
+
|
| 141 |
+
# Count transitions
|
| 142 |
+
transitions = sum(1 for i in range(1, num_steps)
|
| 143 |
+
if self.per_step_primary[i] != self.per_step_primary[i-1])
|
| 144 |
+
|
| 145 |
+
lines = [
|
| 146 |
+
"",
|
| 147 |
+
"╭──────────────────────────────────────────────────╮",
|
| 148 |
+
"│ 📈 ROUTING TIMELINE │",
|
| 149 |
+
"├──────────────────────────────────────────────────┤",
|
| 150 |
+
]
|
| 151 |
+
|
| 152 |
+
# Compact step numbers
|
| 153 |
+
step_row = "│ Step "
|
| 154 |
+
for i in range(show_steps):
|
| 155 |
+
step_row += f"{i:2d} "
|
| 156 |
+
if num_steps > 20:
|
| 157 |
+
step_row = step_row[:48] + "..│"
|
| 158 |
+
else:
|
| 159 |
+
step_row = step_row[:48].ljust(48) + " │"
|
| 160 |
+
lines.append(step_row)
|
| 161 |
+
|
| 162 |
+
lines.append("│ " + "───" * 16 + " │")
|
| 163 |
+
|
| 164 |
+
# Show each expert's timeline
|
| 165 |
+
symbols = ['○', '●']
|
| 166 |
+
for eid in range(self.num_experts):
|
| 167 |
+
row = f"│ E{eid} "
|
| 168 |
+
for step in range(show_steps):
|
| 169 |
+
if self.per_step_primary[step] == eid:
|
| 170 |
+
row += " ● "
|
| 171 |
+
else:
|
| 172 |
+
row += " · "
|
| 173 |
+
if num_steps > 20:
|
| 174 |
+
row = row[:48] + "..│"
|
| 175 |
+
else:
|
| 176 |
+
row = row[:48].ljust(48) + " │"
|
| 177 |
+
lines.append(row)
|
| 178 |
+
|
| 179 |
+
lines.extend([
|
| 180 |
+
"├──────────────────────────────────────────────────┤",
|
| 181 |
+
f"│ Routing changes: {transitions:>3}/{num_steps-1:<3} steps ({100*transitions/(num_steps-1):.0f}%) │",
|
| 182 |
+
"╰──────────────────────────────────────────────────╯",
|
| 183 |
+
])
|
| 184 |
+
|
| 185 |
+
return "\n".join(lines)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 189 |
+
# INT8 DEQUANTIZATION
|
| 190 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 191 |
+
|
| 192 |
+
def dequantize_tensor(int8_tensor: torch.Tensor, t_min: float, t_max: float) -> torch.Tensor:
|
| 193 |
+
"""Dequantize INT8 tensor back to float32."""
|
| 194 |
+
if t_min == t_max:
|
| 195 |
+
return torch.full_like(int8_tensor, t_min, dtype=torch.float32)
|
| 196 |
+
normalized = (int8_tensor.float() + 128) / 255.0
|
| 197 |
+
return normalized * (t_max - t_min) + t_min
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def load_int8_state_dict(safetensors_path: Path) -> dict:
|
| 201 |
+
"""Load and dequantize INT8 safetensors to float32 state_dict."""
|
| 202 |
+
state_dict = {}
|
| 203 |
+
|
| 204 |
+
with safe_open(str(safetensors_path), framework="pt", device="cpu") as f:
|
| 205 |
+
keys = list(f.keys())
|
| 206 |
+
|
| 207 |
+
# Find quantized tensors (those with _min/_max companions)
|
| 208 |
+
quantized_keys = set()
|
| 209 |
+
for key in keys:
|
| 210 |
+
if key.endswith('._min'):
|
| 211 |
+
base_key = key[:-5]
|
| 212 |
+
quantized_keys.add(base_key)
|
| 213 |
+
|
| 214 |
+
# Load and dequantize
|
| 215 |
+
for key in keys:
|
| 216 |
+
# Skip metadata and quantization params
|
| 217 |
+
if key.endswith('._min') or key.endswith('._max'):
|
| 218 |
+
continue
|
| 219 |
+
if key == '_config_json':
|
| 220 |
+
continue
|
| 221 |
+
|
| 222 |
+
tensor = f.get_tensor(key)
|
| 223 |
+
|
| 224 |
+
if key in quantized_keys:
|
| 225 |
+
# Dequantize INT8 tensor
|
| 226 |
+
t_min = f.get_tensor(f"{key}._min").item()
|
| 227 |
+
t_max = f.get_tensor(f"{key}._max").item()
|
| 228 |
+
tensor = dequantize_tensor(tensor, t_min, t_max)
|
| 229 |
+
|
| 230 |
+
state_dict[key] = tensor
|
| 231 |
+
|
| 232 |
+
return state_dict
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 236 |
+
# MODEL CREATION
|
| 237 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 238 |
+
|
| 239 |
+
def create_expert(config, expert_id: int = 0):
|
| 240 |
+
"""Create a DiT expert model."""
|
| 241 |
+
from models import DiTExpert
|
| 242 |
+
return DiTExpert(config)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def create_router(config):
|
| 246 |
+
"""Create a DiT router model."""
|
| 247 |
+
from models import DiTRouter
|
| 248 |
+
return DiTRouter(config)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 252 |
+
# SAMPLER CLASS
|
| 253 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 254 |
+
|
| 255 |
+
class ParisSampler:
|
| 256 |
+
"""Unified sampler for Paris MoE model with expert tracking."""
|
| 257 |
+
|
| 258 |
+
def __init__(self, experts: dict, router, vae_manager, config, device='cuda',
|
| 259 |
+
offloaded_experts: set = None):
|
| 260 |
+
self.experts = experts
|
| 261 |
+
self.router = router
|
| 262 |
+
self.vae_manager = vae_manager
|
| 263 |
+
self.config = config
|
| 264 |
+
self.device = device
|
| 265 |
+
self.tracker = None
|
| 266 |
+
self.offloaded_experts = offloaded_experts or set() # Which experts are on CPU
|
| 267 |
+
|
| 268 |
+
# Set models to eval mode
|
| 269 |
+
for expert in self.experts.values():
|
| 270 |
+
expert.eval()
|
| 271 |
+
if self.router is not None:
|
| 272 |
+
self.router.eval()
|
| 273 |
+
|
| 274 |
+
# Precompute null embeddings for CFG
|
| 275 |
+
self._precompute_null_embeddings()
|
| 276 |
+
|
| 277 |
+
def _precompute_null_embeddings(self):
|
| 278 |
+
"""Precompute null embeddings for classifier-free guidance."""
|
| 279 |
+
try:
|
| 280 |
+
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
| 281 |
+
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
| 282 |
+
text_encoder = text_encoder.to(self.device)
|
| 283 |
+
text_encoder.eval()
|
| 284 |
+
|
| 285 |
+
with torch.no_grad():
|
| 286 |
+
null_tokens = tokenizer(
|
| 287 |
+
[""],
|
| 288 |
+
max_length=77,
|
| 289 |
+
padding='max_length',
|
| 290 |
+
truncation=True,
|
| 291 |
+
return_tensors='pt'
|
| 292 |
+
)
|
| 293 |
+
self.null_text_embeds = text_encoder(null_tokens.input_ids.to(self.device)).last_hidden_state
|
| 294 |
+
self.null_attention_mask = null_tokens.attention_mask.to(self.device)
|
| 295 |
+
|
| 296 |
+
del text_encoder, tokenizer
|
| 297 |
+
torch.cuda.empty_cache()
|
| 298 |
+
except Exception as e:
|
| 299 |
+
print(f"Warning: Could not precompute null text embeddings: {e}")
|
| 300 |
+
self.null_text_embeds = None
|
| 301 |
+
self.null_attention_mask = None
|
| 302 |
+
|
| 303 |
+
def _encode_text_prompts(self, text_prompts: list):
|
| 304 |
+
"""Encode text prompts using CLIP."""
|
| 305 |
+
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
| 306 |
+
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
| 307 |
+
|
| 308 |
+
text_encoder = text_encoder.to(self.device)
|
| 309 |
+
text_encoder.eval()
|
| 310 |
+
|
| 311 |
+
tokenizer_output = tokenizer(
|
| 312 |
+
text_prompts,
|
| 313 |
+
max_length=77,
|
| 314 |
+
padding='max_length',
|
| 315 |
+
truncation=True,
|
| 316 |
+
return_tensors='pt'
|
| 317 |
+
)
|
| 318 |
+
tokens = tokenizer_output.input_ids.to(self.device)
|
| 319 |
+
attention_mask = tokenizer_output.attention_mask.to(self.device)
|
| 320 |
+
|
| 321 |
+
with torch.no_grad():
|
| 322 |
+
text_embeds = text_encoder(tokens).last_hidden_state
|
| 323 |
+
|
| 324 |
+
del text_encoder, tokenizer
|
| 325 |
+
torch.cuda.empty_cache()
|
| 326 |
+
|
| 327 |
+
return text_embeds, attention_mask
|
| 328 |
+
|
| 329 |
+
def _move_expert_to_gpu(self, expert_id: int):
|
| 330 |
+
"""Move an offloaded expert to GPU for computation."""
|
| 331 |
+
if expert_id in self.offloaded_experts:
|
| 332 |
+
self.experts[expert_id] = self.experts[expert_id].to(self.device)
|
| 333 |
+
torch.cuda.synchronize() # Ensure transfer is complete
|
| 334 |
+
|
| 335 |
+
def _move_expert_to_cpu(self, expert_id: int):
|
| 336 |
+
"""Move expert back to CPU after computation (memory offloading)."""
|
| 337 |
+
if expert_id in self.offloaded_experts:
|
| 338 |
+
self.experts[expert_id] = self.experts[expert_id].cpu()
|
| 339 |
+
torch.cuda.empty_cache() # Free GPU memory immediately
|
| 340 |
+
|
| 341 |
+
def _run_expert_with_cfg(self, expert_id: int, samples: torch.Tensor, t: torch.Tensor,
|
| 342 |
+
text_embeds: torch.Tensor, attention_mask: torch.Tensor,
|
| 343 |
+
null_embeds: torch.Tensor, null_mask: torch.Tensor,
|
| 344 |
+
cfg_scale: float) -> torch.Tensor:
|
| 345 |
+
"""Run expert inference with optional CFG, handling memory offloading."""
|
| 346 |
+
# Move to GPU if offloaded
|
| 347 |
+
self._move_expert_to_gpu(expert_id)
|
| 348 |
+
|
| 349 |
+
expert = self.experts[expert_id]
|
| 350 |
+
|
| 351 |
+
try:
|
| 352 |
+
if cfg_scale != 1.0:
|
| 353 |
+
v_cond = expert(samples, t, text_embeds, attention_mask)
|
| 354 |
+
v_uncond = expert(samples, t, null_embeds, null_mask)
|
| 355 |
+
v_pred = v_uncond + cfg_scale * (v_cond - v_uncond)
|
| 356 |
+
else:
|
| 357 |
+
v_pred = expert(samples, t, text_embeds, attention_mask)
|
| 358 |
+
|
| 359 |
+
return v_pred
|
| 360 |
+
finally:
|
| 361 |
+
# Move back to CPU if it was offloaded (memory offloading)
|
| 362 |
+
self._move_expert_to_cpu(expert_id)
|
| 363 |
+
|
| 364 |
+
def sample(self, num_samples: int, text_prompts: list, cfg_scale: float = 7.5,
|
| 365 |
+
num_steps: int = 30, use_bf16: bool = True, track_experts: bool = False,
|
| 366 |
+
topk: int = 1):
|
| 367 |
+
"""
|
| 368 |
+
Generate samples using expert routing.
|
| 369 |
+
|
| 370 |
+
Args:
|
| 371 |
+
num_samples: Number of images to generate
|
| 372 |
+
text_prompts: List of text prompts
|
| 373 |
+
cfg_scale: Classifier-free guidance scale
|
| 374 |
+
num_steps: Number of sampling steps
|
| 375 |
+
use_bf16: Use bfloat16 precision
|
| 376 |
+
track_experts: Track and visualize expert usage
|
| 377 |
+
topk: Number of experts to use per sample (1=top-1, 2=top-2, etc.)
|
| 378 |
+
"""
|
| 379 |
+
# Initialize tracker if requested
|
| 380 |
+
if track_experts:
|
| 381 |
+
self.tracker = ExpertTracker(num_experts=8)
|
| 382 |
+
else:
|
| 383 |
+
self.tracker = None
|
| 384 |
+
|
| 385 |
+
text_embeds, attention_mask = self._encode_text_prompts(text_prompts)
|
| 386 |
+
|
| 387 |
+
latent_size = self.config.image_size
|
| 388 |
+
channels = 4
|
| 389 |
+
dtype = torch.bfloat16 if use_bf16 else torch.float32
|
| 390 |
+
|
| 391 |
+
# Start with random noise
|
| 392 |
+
samples = torch.randn(
|
| 393 |
+
num_samples, channels, latent_size, latent_size,
|
| 394 |
+
device=self.device, dtype=dtype
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
# Convert text embeds to appropriate dtype
|
| 398 |
+
text_embeds = text_embeds.to(dtype)
|
| 399 |
+
if self.null_text_embeds is not None:
|
| 400 |
+
null_text_embeds = self.null_text_embeds.to(dtype)
|
| 401 |
+
null_attention_mask = self.null_attention_mask
|
| 402 |
+
|
| 403 |
+
dt = 1.0 / num_steps
|
| 404 |
+
|
| 405 |
+
autocast_ctx = torch.amp.autocast(device_type='cuda', dtype=dtype) if use_bf16 else torch.no_grad()
|
| 406 |
+
|
| 407 |
+
with torch.no_grad(), autocast_ctx:
|
| 408 |
+
for i in tqdm(range(num_steps), desc="🎨 Generating"):
|
| 409 |
+
t = torch.ones(num_samples, device=self.device) * (1.0 - i * dt)
|
| 410 |
+
|
| 411 |
+
# Expand text embeddings if needed
|
| 412 |
+
batch_text_embeds = text_embeds.expand(num_samples, -1, -1) if text_embeds.shape[0] == 1 else text_embeds[:num_samples]
|
| 413 |
+
batch_attention_mask = attention_mask.expand(num_samples, -1) if attention_mask.shape[0] == 1 else attention_mask[:num_samples]
|
| 414 |
+
|
| 415 |
+
# Get router predictions (router expects float32)
|
| 416 |
+
with torch.amp.autocast(device_type='cuda', enabled=False):
|
| 417 |
+
router_logits = self.router(samples.float(), t.float())
|
| 418 |
+
expert_probs = F.softmax(router_logits, dim=1)
|
| 419 |
+
|
| 420 |
+
if topk == 1:
|
| 421 |
+
# Top-1 routing
|
| 422 |
+
expert_choices = torch.argmax(expert_probs, dim=1)
|
| 423 |
+
|
| 424 |
+
# Track expert usage
|
| 425 |
+
if self.tracker is not None:
|
| 426 |
+
self.tracker.record(expert_choices, i)
|
| 427 |
+
|
| 428 |
+
# Predict velocity for each sample using selected expert
|
| 429 |
+
v_pred = torch.zeros_like(samples)
|
| 430 |
+
|
| 431 |
+
for expert_id in range(8):
|
| 432 |
+
mask = (expert_choices == expert_id)
|
| 433 |
+
if mask.any():
|
| 434 |
+
mask_size = mask.sum().item()
|
| 435 |
+
null_embeds = null_text_embeds.expand(mask_size, -1, -1)
|
| 436 |
+
null_mask = null_attention_mask.expand(mask_size, -1)
|
| 437 |
+
|
| 438 |
+
v_batch = self._run_expert_with_cfg(
|
| 439 |
+
expert_id,
|
| 440 |
+
samples[mask], t[mask],
|
| 441 |
+
batch_text_embeds[mask], batch_attention_mask[mask],
|
| 442 |
+
null_embeds, null_mask,
|
| 443 |
+
cfg_scale
|
| 444 |
+
)
|
| 445 |
+
v_pred[mask] = v_batch
|
| 446 |
+
else:
|
| 447 |
+
# Top-K routing with weighted ensemble
|
| 448 |
+
topk_probs, topk_indices = torch.topk(expert_probs, k=min(topk, 8), dim=1)
|
| 449 |
+
topk_probs = topk_probs / topk_probs.sum(dim=1, keepdim=True) # Renormalize
|
| 450 |
+
|
| 451 |
+
# Track expert usage
|
| 452 |
+
if self.tracker is not None:
|
| 453 |
+
self.tracker.record(topk_indices, i, topk_probs)
|
| 454 |
+
|
| 455 |
+
v_pred = torch.zeros_like(samples)
|
| 456 |
+
|
| 457 |
+
# Process each sample
|
| 458 |
+
for sample_idx in range(num_samples):
|
| 459 |
+
v_sample = torch.zeros(channels, latent_size, latent_size,
|
| 460 |
+
device=self.device, dtype=dtype)
|
| 461 |
+
|
| 462 |
+
for k_idx in range(topk_indices.shape[1]):
|
| 463 |
+
expert_id = topk_indices[sample_idx, k_idx].item()
|
| 464 |
+
weight = topk_probs[sample_idx, k_idx].item()
|
| 465 |
+
|
| 466 |
+
null_embeds = null_text_embeds
|
| 467 |
+
null_mask = null_attention_mask
|
| 468 |
+
|
| 469 |
+
v_expert = self._run_expert_with_cfg(
|
| 470 |
+
expert_id,
|
| 471 |
+
samples[sample_idx:sample_idx+1],
|
| 472 |
+
t[sample_idx:sample_idx+1],
|
| 473 |
+
batch_text_embeds[sample_idx:sample_idx+1],
|
| 474 |
+
batch_attention_mask[sample_idx:sample_idx+1],
|
| 475 |
+
null_embeds, null_mask,
|
| 476 |
+
cfg_scale
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
v_sample += weight * v_expert.squeeze(0)
|
| 480 |
+
|
| 481 |
+
v_pred[sample_idx] = v_sample
|
| 482 |
+
|
| 483 |
+
# Euler integration step
|
| 484 |
+
samples = samples - v_pred * dt
|
| 485 |
+
|
| 486 |
+
return samples.float()
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 490 |
+
# MODEL LOADING
|
| 491 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 492 |
+
|
| 493 |
+
def load_sampler(precision: str = 'bf16', device: str = 'cuda', offload: int = 0):
|
| 494 |
+
"""
|
| 495 |
+
Load Paris MoE sampler with specified precision.
|
| 496 |
+
|
| 497 |
+
Args:
|
| 498 |
+
precision: Weight precision ('bf16', 'int8', 'mixed')
|
| 499 |
+
device: Compute device ('cuda' or 'cpu')
|
| 500 |
+
offload: Number of experts to keep in CPU memory (0-7)
|
| 501 |
+
These experts will be moved to GPU only during computation.
|
| 502 |
+
"""
|
| 503 |
+
from vae_utils import VAEManager
|
| 504 |
+
|
| 505 |
+
# Determine weight directories based on precision
|
| 506 |
+
if precision == 'bf16':
|
| 507 |
+
expert_dir = BF16_DIR
|
| 508 |
+
router_dir = BF16_DIR
|
| 509 |
+
use_int8_experts = False
|
| 510 |
+
elif precision == 'int8':
|
| 511 |
+
expert_dir = INT8_DIR
|
| 512 |
+
router_dir = BF16_DIR # Router always from bf16
|
| 513 |
+
use_int8_experts = True
|
| 514 |
+
elif precision == 'mixed':
|
| 515 |
+
expert_dir = INT8_DIR
|
| 516 |
+
router_dir = BF16_DIR
|
| 517 |
+
use_int8_experts = True
|
| 518 |
+
else:
|
| 519 |
+
raise ValueError(f"Unknown precision: {precision}. Use 'bf16', 'int8', or 'mixed'.")
|
| 520 |
+
|
| 521 |
+
# Load config
|
| 522 |
+
config_path = BF16_DIR / 'config.pt'
|
| 523 |
+
config_data = torch.load(config_path, map_location='cpu', weights_only=False)
|
| 524 |
+
config = config_data['config']
|
| 525 |
+
|
| 526 |
+
# Load router config
|
| 527 |
+
router_config_path = BF16_DIR / 'router_config.pt'
|
| 528 |
+
router_config_data = torch.load(router_config_path, map_location='cpu', weights_only=False)
|
| 529 |
+
router_config = router_config_data['config']
|
| 530 |
+
|
| 531 |
+
# Update config with router params
|
| 532 |
+
config.router_architecture = router_config.router_architecture
|
| 533 |
+
config.router_params = router_config.router_params
|
| 534 |
+
|
| 535 |
+
# Load router (always on GPU, bf16/float32)
|
| 536 |
+
print("📡 Loading router...")
|
| 537 |
+
router = create_router(config).to(device)
|
| 538 |
+
router_weights = load_file(str(router_dir / 'router.safetensors'))
|
| 539 |
+
router_weights = {k: v.float() for k, v in router_weights.items()}
|
| 540 |
+
router.load_state_dict(router_weights)
|
| 541 |
+
router.eval()
|
| 542 |
+
|
| 543 |
+
# Determine which experts to offload
|
| 544 |
+
# Offload the LAST N experts (highest IDs)
|
| 545 |
+
offloaded_experts = set(range(8 - offload, 8)) if offload > 0 else set()
|
| 546 |
+
|
| 547 |
+
# Load experts
|
| 548 |
+
experts = {}
|
| 549 |
+
for i in range(8):
|
| 550 |
+
print(f"🧠 Loading expert {i}...", end="")
|
| 551 |
+
expert = create_expert(config, expert_id=i)
|
| 552 |
+
|
| 553 |
+
if use_int8_experts:
|
| 554 |
+
expert_weights = load_int8_state_dict(expert_dir / f'expert_{i}.safetensors')
|
| 555 |
+
else:
|
| 556 |
+
expert_weights = load_file(str(expert_dir / f'expert_{i}.safetensors'))
|
| 557 |
+
|
| 558 |
+
expert.load_state_dict(expert_weights)
|
| 559 |
+
expert.eval()
|
| 560 |
+
|
| 561 |
+
# Convert to bf16 if using bf16 precision
|
| 562 |
+
if precision == 'bf16':
|
| 563 |
+
expert = expert.to(torch.bfloat16)
|
| 564 |
+
|
| 565 |
+
# Decide where to place the expert
|
| 566 |
+
if i in offloaded_experts:
|
| 567 |
+
expert = expert.cpu() # Keep in CPU memory
|
| 568 |
+
print(f" 💾 (CPU memory, GPU compute)")
|
| 569 |
+
else:
|
| 570 |
+
expert = expert.to(device) # Keep on GPU
|
| 571 |
+
print(f" 🎮 (GPU)")
|
| 572 |
+
|
| 573 |
+
experts[i] = expert
|
| 574 |
+
|
| 575 |
+
# Load VAE
|
| 576 |
+
print("🖼️ Loading VAE...")
|
| 577 |
+
vae_manager = VAEManager(device=device)
|
| 578 |
+
|
| 579 |
+
return ParisSampler(experts, router, vae_manager, config, device, offloaded_experts)
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 583 |
+
# MAIN ENTRYPOINT
|
| 584 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 585 |
+
|
| 586 |
+
def parse_args():
|
| 587 |
+
parser = argparse.ArgumentParser(
|
| 588 |
+
description="🎨 Paris MoE - Image Generation",
|
| 589 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 590 |
+
epilog="""
|
| 591 |
+
Examples:
|
| 592 |
+
python generate.py --prompt "a cute cat playing piano"
|
| 593 |
+
python generate.py --prompt "sunset over mountains" --precision int8 --visualize
|
| 594 |
+
python generate.py --prompt "abstract art" --num_samples 4 --cfg_scale 10 --topk 2
|
| 595 |
+
python generate.py --prompt "portrait" --offload 4 # Offload 4 experts to CPU memory
|
| 596 |
+
"""
|
| 597 |
+
)
|
| 598 |
+
|
| 599 |
+
parser.add_argument("--prompt", type=str, default="a cute cat",
|
| 600 |
+
help="Text prompt for generation")
|
| 601 |
+
parser.add_argument("--num_samples", type=int, default=16,
|
| 602 |
+
help="Number of images to generate (default: 16)")
|
| 603 |
+
parser.add_argument("--cfg_scale", type=float, default=7.5,
|
| 604 |
+
help="Classifier-free guidance scale (default: 7.5)")
|
| 605 |
+
parser.add_argument("--num_steps", type=int, default=30,
|
| 606 |
+
help="Number of sampling steps (default: 30)")
|
| 607 |
+
parser.add_argument("--seed", type=int, default=999,
|
| 608 |
+
help="Random seed for reproducibility")
|
| 609 |
+
parser.add_argument("--output", type=str, default=None,
|
| 610 |
+
help="Output image path (default: output_<precision>.png)")
|
| 611 |
+
parser.add_argument("--precision", type=str, default="bf16",
|
| 612 |
+
choices=["bf16", "int8", "mixed"],
|
| 613 |
+
help="Weight precision: bf16, int8, or mixed (default: bf16)")
|
| 614 |
+
parser.add_argument("--offload", type=int, default=0,
|
| 615 |
+
help="Number of experts to keep in CPU memory (0-7). Computation still on GPU.")
|
| 616 |
+
parser.add_argument("--topk", type=int, default=2,
|
| 617 |
+
help="Top-K expert routing (1=top-1, 2=top-2 ensemble, etc.) [default: 2]")
|
| 618 |
+
parser.add_argument("--visualize", action="store_true",
|
| 619 |
+
help="Show expert usage visualization")
|
| 620 |
+
parser.add_argument("--no-save", action="store_true",
|
| 621 |
+
help="Don't save output image (for testing)")
|
| 622 |
+
|
| 623 |
+
return parser.parse_args()
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
def print_header():
|
| 627 |
+
"""Print beautiful ASCII header."""
|
| 628 |
+
print("""
|
| 629 |
+
╔══════════════════════════════════════════════════════════════════════════════╗
|
| 630 |
+
║ ║
|
| 631 |
+
║ ██████╗ █████╗ ██████╗ ██╗███████╗ ███╗ ███╗ ██████╗ ███████╗ ║
|
| 632 |
+
║ ██╔══██╗██╔══██╗██╔══██╗██║██╔════╝ ████╗ ████║██╔═══██╗██╔════╝ ║
|
| 633 |
+
║ ██████╔╝███████║██████╔╝██║███████╗ ██╔████╔██║██║ ██║█████╗ ║
|
| 634 |
+
║ ██╔═══╝ ██╔══██║██╔══██╗██║╚════██║ ██║╚██╔╝██║██║ ██║██╔══╝ ║
|
| 635 |
+
║ ██║ ██║ ██║██║ ██║██║███████║ ██║ ╚═╝ ██║╚██████╔╝███████╗ ║
|
| 636 |
+
║ ╚═╝ ╚═╝ ╚═╝╚═╝ ╚═╝╚═╝╚══════╝ ╚═╝ ╚═╝ ╚═════╝ ╚══════╝ ║
|
| 637 |
+
║ ║
|
| 638 |
+
║ 🎨 Mixture-of-Experts Text-to-Image Diffusion Model ║
|
| 639 |
+
║ 📊 8× DiT-XL/2 Experts + DiT-B/2 Router (~5B Parameters) ║
|
| 640 |
+
║ ║
|
| 641 |
+
╚══════════════════════════════════════════════════════════════════════════════╝
|
| 642 |
+
""")
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
def print_config(args):
|
| 646 |
+
"""Print configuration summary."""
|
| 647 |
+
offload_str = f"{args.offload} experts (CPU mem, GPU compute)" if args.offload > 0 else "None"
|
| 648 |
+
topk_str = f"Top-{args.topk}" if args.topk > 1 else "Top-1"
|
| 649 |
+
|
| 650 |
+
print(f"""
|
| 651 |
+
┌──────────────────────────────────────────────────────────────────────────────┐
|
| 652 |
+
│ 📋 Configuration │
|
| 653 |
+
├──────────────────────────────────────────────────────────────────────────────┤
|
| 654 |
+
│ Prompt: {args.prompt[:50]:<50}│
|
| 655 |
+
│ Samples: {args.num_samples:<50}│
|
| 656 |
+
│ Steps: {args.num_steps:<50}│
|
| 657 |
+
│ CFG Scale: {args.cfg_scale:<50}│
|
| 658 |
+
│ Precision: {args.precision.upper():<50}│
|
| 659 |
+
│ Routing: {topk_str:<50}│
|
| 660 |
+
│ Seed: {args.seed:<50}│
|
| 661 |
+
│ Offload: {offload_str:<50}│
|
| 662 |
+
└──────────────────────────────────────────────────────────────────────────────┘
|
| 663 |
+
""")
|
| 664 |
+
|
| 665 |
+
|
| 666 |
+
def main():
|
| 667 |
+
args = parse_args()
|
| 668 |
+
|
| 669 |
+
# Print header
|
| 670 |
+
print_header()
|
| 671 |
+
print_config(args)
|
| 672 |
+
|
| 673 |
+
# Set seed
|
| 674 |
+
torch.manual_seed(args.seed)
|
| 675 |
+
if torch.cuda.is_available():
|
| 676 |
+
torch.cuda.manual_seed(args.seed)
|
| 677 |
+
|
| 678 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 679 |
+
print(f"🖥️ Using device: {device}")
|
| 680 |
+
|
| 681 |
+
# Load sampler
|
| 682 |
+
print(f"\n📦 Loading {args.precision.upper()} weights...")
|
| 683 |
+
start_load = time.time()
|
| 684 |
+
sampler = load_sampler(
|
| 685 |
+
precision=args.precision,
|
| 686 |
+
device=device,
|
| 687 |
+
offload=args.offload
|
| 688 |
+
)
|
| 689 |
+
load_time = time.time() - start_load
|
| 690 |
+
print(f"⏱️ Model loaded in {load_time:.1f}s")
|
| 691 |
+
|
| 692 |
+
# Generate samples
|
| 693 |
+
print(f"\n🎨 Generating {args.num_samples} images...")
|
| 694 |
+
start_gen = time.time()
|
| 695 |
+
latents = sampler.sample(
|
| 696 |
+
num_samples=args.num_samples,
|
| 697 |
+
text_prompts=[args.prompt],
|
| 698 |
+
cfg_scale=args.cfg_scale,
|
| 699 |
+
num_steps=args.num_steps,
|
| 700 |
+
use_bf16=(args.precision == 'bf16'),
|
| 701 |
+
track_experts=args.visualize,
|
| 702 |
+
topk=args.topk
|
| 703 |
+
)
|
| 704 |
+
gen_time = time.time() - start_gen
|
| 705 |
+
|
| 706 |
+
# Show visualization if requested
|
| 707 |
+
if args.visualize and sampler.tracker is not None:
|
| 708 |
+
print(sampler.tracker.get_usage_chart())
|
| 709 |
+
print(sampler.tracker.get_timeline())
|
| 710 |
+
|
| 711 |
+
# Decode latents
|
| 712 |
+
print("\n🖼️ Decoding latents...")
|
| 713 |
+
start_decode = time.time()
|
| 714 |
+
images = sampler.vae_manager.decode(latents)
|
| 715 |
+
images = (images + 1.0) / 2.0
|
| 716 |
+
images = torch.clamp(images, 0, 1)
|
| 717 |
+
decode_time = time.time() - start_decode
|
| 718 |
+
|
| 719 |
+
# Save output
|
| 720 |
+
if not args.no_save:
|
| 721 |
+
output_path = args.output or f"output_{args.precision}.png"
|
| 722 |
+
nrow = 4 if args.num_samples >= 4 else args.num_samples
|
| 723 |
+
grid = make_grid(images.cpu(), nrow=nrow, normalize=False, padding=2)
|
| 724 |
+
save_image(grid, output_path)
|
| 725 |
+
print(f"\n✅ Saved to: {output_path}")
|
| 726 |
+
|
| 727 |
+
# Print timing summary
|
| 728 |
+
total_time = load_time + gen_time + decode_time
|
| 729 |
+
throughput = args.num_samples / gen_time
|
| 730 |
+
|
| 731 |
+
print(f"""
|
| 732 |
+
╔══════════════════════════════════════════════════════════════════════════════╗
|
| 733 |
+
║ ⏱️ Timing Summary ⏱️ ║
|
| 734 |
+
╠══════════════════════════════════════════════════════════════════════════════╣
|
| 735 |
+
║ Model loading: {load_time:>6.1f}s ║
|
| 736 |
+
║ Generation: {gen_time:>6.1f}s ({throughput:.2f} img/s, {gen_time/args.num_steps:.2f}s/step) ║
|
| 737 |
+
║ VAE decoding: {decode_time:>6.1f}s ║
|
| 738 |
+
║ ────────────────────────────── ║
|
| 739 |
+
║ Total: {total_time:>6.1f}s ║
|
| 740 |
+
╚══════════════════════════════════════════════════════════════════════════════╝
|
| 741 |
+
""")
|
| 742 |
+
|
| 743 |
+
print("🎉 Done!")
|
| 744 |
+
|
| 745 |
+
|
| 746 |
+
if __name__ == "__main__":
|
| 747 |
+
main()
|
instructions.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
What we're now is we're going to prepare an inference folder and we're going to make an inference repository for our Paris model. This will include, we will just stick with int8 and bfloat16 and mixed int8, bfloat16 for now. And this repository will include efficient methods how to run the code. It will include quantization code that can accept either PT or save tensors or Float32 save tensors or Float32 PT. It will include a lot of different methods that can accept either PT or Float32 PT. We need to make a visualizer next which outputs a little pretty ASCII chart. We should output the ASCII chart right on the terminal every time we run the inference via this tool. Let's just say we're running the int8 inference of the mixed int8 model. By the way, we're also going to put the weights that we quantized inside this inference folder because we're going to publish this on HuggingFace. have just again, the beef flow 16 and intake weights. we might already be done this by the way. But again, I wanted to do that when we have to keep some kind of track and output a chart in the terminal, like as a little terminal visualization in ASCII. MAKE SURE WE'RE DOING ROUTING PROPERLY. Top 2 etc. Again, just to recap, we're going to make a folder that's just called inference. In this folder, we're going to put the quantized weights that we already made, because we already made them before in the last session. So the bfloat16 and the int8 weights. And we're going to put one Python file for the inference code, and it's going to have all the flags, and it's also going to have a visualized flag. And the visualized flag is actually a lot more than that, because it keeps track of which expert is being used during each inference step, and that shows like a little pretty chart. So if we're generating with 30 steps, which is going to show which experts got to use the most and the least out of eight of them. And so we want to have this in the inference code. Make sure to read files in full before like a pass inference code that we already wrote. Try to list like the most recent files that we made for that. And we also want to have the quantization code to just be an all in one utility with a very nice terminal interface as well, because we want the quantization code to be able to handle float 16 bfloat 16 float 32 weights in both safe tensors and in dot pt format. So that needs to be very smart and tested that it actually works. And also, yeah, make a read me in this folder for the Paris model, because we're going to publish this on hogging face as the inference repository. And then also read all the MD files that we have written here in full because after we do all of this and after we test that it works and it differences fine. We're going to we're going to start to play around with network inference. So that's going to be the fun next step after. So again, make a 20 point to do this for this and please make sure to include at least four or five sentences per point. So the to do list is going to be very long, naturally and very detailed. But I believe we're going to do an excellent, excellent job here.
|
quantize.py
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
╔══════════════════════════════════════════════════════════════════════════════╗
|
| 4 |
+
║ ║
|
| 5 |
+
║ 🔧 Paris MoE - Weight Quantization Utility 🔧 ║
|
| 6 |
+
║ ║
|
| 7 |
+
║ Converts weights between formats: ║
|
| 8 |
+
║ • Input: .pt (PyTorch) or .safetensors (F32 or BF16) ║
|
| 9 |
+
║ • Output: BF16 or INT8 safetensors ║
|
| 10 |
+
║ ║
|
| 11 |
+
╚══════════════════════════════════════════════════════════════════════════════╝
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
# Convert original .pt files to BF16 safetensors
|
| 15 |
+
python quantize.py --input /path/to/weights/ --output ./weights/bf16 --format bf16
|
| 16 |
+
|
| 17 |
+
# Convert to INT8 safetensors
|
| 18 |
+
python quantize.py --input /path/to/weights/ --output ./weights/int8 --format int8
|
| 19 |
+
|
| 20 |
+
# Convert from existing safetensors (bf16 -> int8)
|
| 21 |
+
python quantize.py --input ./weights/bf16 --output ./weights/int8 --format int8
|
| 22 |
+
|
| 23 |
+
Input Formats Supported:
|
| 24 |
+
- PyTorch .pt files (original training checkpoints)
|
| 25 |
+
- SafeTensors .safetensors files (F32 or BF16)
|
| 26 |
+
|
| 27 |
+
Output Formats:
|
| 28 |
+
- bf16: BFloat16 safetensors (best quality, ~1.2GB per expert)
|
| 29 |
+
- int8: INT8 quantized safetensors (~580MB per expert)
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
import argparse
|
| 33 |
+
import os
|
| 34 |
+
import gc
|
| 35 |
+
from pathlib import Path
|
| 36 |
+
from typing import Dict, Optional, Tuple
|
| 37 |
+
import json
|
| 38 |
+
|
| 39 |
+
import torch
|
| 40 |
+
from safetensors.torch import save_file, load_file
|
| 41 |
+
from safetensors import safe_open
|
| 42 |
+
from tqdm import tqdm
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 46 |
+
# FILE DETECTION
|
| 47 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 48 |
+
|
| 49 |
+
def detect_input_format(input_dir: Path) -> Tuple[str, Dict[str, Path]]:
|
| 50 |
+
"""
|
| 51 |
+
Detect input format and locate weight files.
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
format: 'pt' or 'safetensors'
|
| 55 |
+
files: Dict mapping 'expert_0'..'expert_7', 'router' to file paths
|
| 56 |
+
"""
|
| 57 |
+
files = {}
|
| 58 |
+
|
| 59 |
+
# Check for PyTorch .pt files (original format)
|
| 60 |
+
pt_patterns = [
|
| 61 |
+
# Pattern 1: Full training checkpoint names
|
| 62 |
+
("dit_xl2_multi_expert_pretrained_text_new_dataset_expert_{}_best.pt", "expert_{}"),
|
| 63 |
+
("laion_router_preclustered_dit_berthead_b2_improved_router_best.pt", "router"),
|
| 64 |
+
# Pattern 2: Simple names
|
| 65 |
+
("expert_{}_best.pt", "expert_{}"),
|
| 66 |
+
("expert_{}.pt", "expert_{}"),
|
| 67 |
+
("router_best.pt", "router"),
|
| 68 |
+
("router.pt", "router"),
|
| 69 |
+
]
|
| 70 |
+
|
| 71 |
+
# Check for SafeTensors files
|
| 72 |
+
st_patterns = [
|
| 73 |
+
("expert_{}.safetensors", "expert_{}"),
|
| 74 |
+
("router.safetensors", "router"),
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
+
# Try PyTorch patterns first
|
| 78 |
+
for pattern, key_pattern in pt_patterns:
|
| 79 |
+
if "{}" in pattern:
|
| 80 |
+
# Expert pattern
|
| 81 |
+
for i in range(8):
|
| 82 |
+
filename = pattern.format(i)
|
| 83 |
+
filepath = input_dir / filename
|
| 84 |
+
if filepath.exists():
|
| 85 |
+
key = key_pattern.format(i)
|
| 86 |
+
files[key] = filepath
|
| 87 |
+
else:
|
| 88 |
+
# Router pattern
|
| 89 |
+
filepath = input_dir / pattern
|
| 90 |
+
if filepath.exists():
|
| 91 |
+
files[key_pattern] = filepath
|
| 92 |
+
|
| 93 |
+
if len(files) >= 8: # At least 8 experts found
|
| 94 |
+
return 'pt', files
|
| 95 |
+
|
| 96 |
+
# Try SafeTensors patterns
|
| 97 |
+
files = {}
|
| 98 |
+
for pattern, key_pattern in st_patterns:
|
| 99 |
+
if "{}" in pattern:
|
| 100 |
+
for i in range(8):
|
| 101 |
+
filename = pattern.format(i)
|
| 102 |
+
filepath = input_dir / filename
|
| 103 |
+
if filepath.exists():
|
| 104 |
+
key = key_pattern.format(i)
|
| 105 |
+
files[key] = filepath
|
| 106 |
+
else:
|
| 107 |
+
filepath = input_dir / pattern
|
| 108 |
+
if filepath.exists():
|
| 109 |
+
files[key_pattern] = filepath
|
| 110 |
+
|
| 111 |
+
if len(files) >= 8:
|
| 112 |
+
return 'safetensors', files
|
| 113 |
+
|
| 114 |
+
# List what we found
|
| 115 |
+
print(f"Found files in {input_dir}:")
|
| 116 |
+
for f in sorted(input_dir.glob("*")):
|
| 117 |
+
print(f" {f.name}")
|
| 118 |
+
|
| 119 |
+
raise ValueError(f"Could not find weight files in {input_dir}")
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 123 |
+
# LOADING UTILITIES
|
| 124 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 125 |
+
|
| 126 |
+
def load_pt_expert(filepath: Path, expert_id: int) -> Tuple[dict, Optional[object]]:
|
| 127 |
+
"""
|
| 128 |
+
Load expert weights from PyTorch checkpoint.
|
| 129 |
+
|
| 130 |
+
Returns:
|
| 131 |
+
state_dict: Model weights
|
| 132 |
+
config: Config object if available
|
| 133 |
+
"""
|
| 134 |
+
print(f" Loading {filepath.name}...")
|
| 135 |
+
ckpt = torch.load(filepath, map_location='cpu', weights_only=False)
|
| 136 |
+
|
| 137 |
+
# Try EMA weights first (preferred for inference)
|
| 138 |
+
ema_key = f'expert_{expert_id}_ema_state_dict'
|
| 139 |
+
regular_key = f'expert_{expert_id}_state_dict'
|
| 140 |
+
|
| 141 |
+
if ema_key in ckpt:
|
| 142 |
+
state_dict = ckpt[ema_key]
|
| 143 |
+
print(f" Using EMA weights")
|
| 144 |
+
elif regular_key in ckpt:
|
| 145 |
+
state_dict = ckpt[regular_key]
|
| 146 |
+
print(f" Using regular weights (no EMA)")
|
| 147 |
+
else:
|
| 148 |
+
# Try to find any state dict key
|
| 149 |
+
for k in ckpt.keys():
|
| 150 |
+
if 'state_dict' in k and 'optimizer' not in k:
|
| 151 |
+
state_dict = ckpt[k]
|
| 152 |
+
print(f" Using key: {k}")
|
| 153 |
+
break
|
| 154 |
+
else:
|
| 155 |
+
raise KeyError(f"No state dict found in {filepath}")
|
| 156 |
+
|
| 157 |
+
config = ckpt.get('config', None)
|
| 158 |
+
return state_dict, config
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def load_pt_router(filepath: Path) -> Tuple[dict, Optional[object]]:
|
| 162 |
+
"""Load router weights from PyTorch checkpoint."""
|
| 163 |
+
print(f" Loading {filepath.name}...")
|
| 164 |
+
ckpt = torch.load(filepath, map_location='cpu', weights_only=False)
|
| 165 |
+
|
| 166 |
+
if 'router_state_dict' in ckpt:
|
| 167 |
+
state_dict = ckpt['router_state_dict']
|
| 168 |
+
else:
|
| 169 |
+
raise KeyError(f"router_state_dict not found in {filepath}")
|
| 170 |
+
|
| 171 |
+
config = ckpt.get('config', None)
|
| 172 |
+
return state_dict, config
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def load_safetensors_weights(filepath: Path) -> dict:
|
| 176 |
+
"""Load weights from SafeTensors file."""
|
| 177 |
+
print(f" Loading {filepath.name}...")
|
| 178 |
+
return load_file(str(filepath))
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 182 |
+
# QUANTIZATION
|
| 183 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 184 |
+
|
| 185 |
+
def convert_to_bf16(state_dict: dict) -> dict:
|
| 186 |
+
"""Convert all floating point tensors to bfloat16."""
|
| 187 |
+
bf16_state = {}
|
| 188 |
+
for k, v in state_dict.items():
|
| 189 |
+
if isinstance(v, torch.Tensor) and v.is_floating_point():
|
| 190 |
+
bf16_state[k] = v.to(torch.bfloat16)
|
| 191 |
+
else:
|
| 192 |
+
bf16_state[k] = v
|
| 193 |
+
return bf16_state
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def is_layernorm_key(key: str) -> bool:
|
| 197 |
+
"""Check if a key belongs to a LayerNorm layer."""
|
| 198 |
+
ln_patterns = ['norm', 'layernorm', 'layer_norm', 'ln_', 'scale_shift_table']
|
| 199 |
+
key_lower = key.lower()
|
| 200 |
+
return any(p in key_lower for p in ln_patterns)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def quantize_tensor_int8(tensor: torch.Tensor) -> Tuple[torch.Tensor, float, float]:
|
| 204 |
+
"""
|
| 205 |
+
Quantize a tensor to INT8 with min/max scaling.
|
| 206 |
+
|
| 207 |
+
Formula: int8 = round((x - min) / (max - min) * 255) - 128
|
| 208 |
+
"""
|
| 209 |
+
if tensor.numel() == 0:
|
| 210 |
+
return tensor.to(torch.int8), 0.0, 0.0
|
| 211 |
+
|
| 212 |
+
t_float = tensor.float()
|
| 213 |
+
t_min = t_float.min().item()
|
| 214 |
+
t_max = t_float.max().item()
|
| 215 |
+
|
| 216 |
+
if t_min == t_max:
|
| 217 |
+
return torch.zeros_like(tensor, dtype=torch.int8), t_min, t_max
|
| 218 |
+
|
| 219 |
+
# Quantize: map [min, max] to [-128, 127]
|
| 220 |
+
normalized = (t_float - t_min) / (t_max - t_min)
|
| 221 |
+
int8_tensor = (normalized * 255 - 128).round().clamp(-128, 127).to(torch.int8)
|
| 222 |
+
|
| 223 |
+
return int8_tensor, t_min, t_max
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def convert_to_int8(state_dict: dict) -> dict:
|
| 227 |
+
"""
|
| 228 |
+
Convert state dict to INT8 quantized format.
|
| 229 |
+
|
| 230 |
+
LayerNorm and small tensors are kept in float32.
|
| 231 |
+
Quantization parameters (_min, _max) are stored alongside.
|
| 232 |
+
"""
|
| 233 |
+
quantized = {}
|
| 234 |
+
stats = {'float32': 0, 'int8': 0}
|
| 235 |
+
|
| 236 |
+
for key, tensor in state_dict.items():
|
| 237 |
+
if not isinstance(tensor, torch.Tensor):
|
| 238 |
+
continue
|
| 239 |
+
|
| 240 |
+
# Skip LayerNorm layers - keep as float32
|
| 241 |
+
if is_layernorm_key(key):
|
| 242 |
+
quantized[key] = tensor.float()
|
| 243 |
+
stats['float32'] += tensor.numel()
|
| 244 |
+
# Only quantize weight tensors with enough elements
|
| 245 |
+
elif tensor.numel() >= 16 and tensor.dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
| 246 |
+
int8_tensor, t_min, t_max = quantize_tensor_int8(tensor)
|
| 247 |
+
quantized[key] = int8_tensor
|
| 248 |
+
quantized[f"{key}._min"] = torch.tensor([t_min], dtype=torch.float32)
|
| 249 |
+
quantized[f"{key}._max"] = torch.tensor([t_max], dtype=torch.float32)
|
| 250 |
+
stats['int8'] += tensor.numel()
|
| 251 |
+
else:
|
| 252 |
+
# Keep small tensors as float32
|
| 253 |
+
quantized[key] = tensor.float()
|
| 254 |
+
stats['float32'] += tensor.numel()
|
| 255 |
+
|
| 256 |
+
return quantized, stats
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 260 |
+
# MAIN CONVERSION
|
| 261 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 262 |
+
|
| 263 |
+
def convert_weights(input_dir: Path, output_dir: Path, output_format: str):
|
| 264 |
+
"""
|
| 265 |
+
Convert weights to specified format.
|
| 266 |
+
|
| 267 |
+
Args:
|
| 268 |
+
input_dir: Directory containing input weights
|
| 269 |
+
output_dir: Directory to write output weights
|
| 270 |
+
output_format: 'bf16' or 'int8'
|
| 271 |
+
"""
|
| 272 |
+
print(f"""
|
| 273 |
+
╔══════════════════════════════════════════════════════════════════════════════╗
|
| 274 |
+
║ 🔧 Paris MoE Weight Conversion 🔧 ║
|
| 275 |
+
╠══════════════════════════════════════════════════════════════════════════════╣
|
| 276 |
+
║ Input: {str(input_dir):<60} ║
|
| 277 |
+
║ Output: {str(output_dir):<60} ║
|
| 278 |
+
║ Format: {output_format.upper():<60} ║
|
| 279 |
+
╚══════════════════════════════════════════════════════════════════════════════╝
|
| 280 |
+
""")
|
| 281 |
+
|
| 282 |
+
# Detect input format
|
| 283 |
+
input_format, files = detect_input_format(input_dir)
|
| 284 |
+
print(f"📂 Detected input format: {input_format}")
|
| 285 |
+
print(f"📁 Found {len(files)} weight files")
|
| 286 |
+
|
| 287 |
+
# Create output directory
|
| 288 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 289 |
+
|
| 290 |
+
# Track sizes
|
| 291 |
+
sizes = {'input': 0, 'output': 0}
|
| 292 |
+
expert_config = None
|
| 293 |
+
router_config = None
|
| 294 |
+
|
| 295 |
+
# Process experts
|
| 296 |
+
print("\n🧠 Converting experts...")
|
| 297 |
+
for i in range(8):
|
| 298 |
+
key = f"expert_{i}"
|
| 299 |
+
if key not in files:
|
| 300 |
+
print(f" ⚠️ {key} not found, skipping")
|
| 301 |
+
continue
|
| 302 |
+
|
| 303 |
+
filepath = files[key]
|
| 304 |
+
sizes['input'] += filepath.stat().st_size
|
| 305 |
+
|
| 306 |
+
# Load weights
|
| 307 |
+
if input_format == 'pt':
|
| 308 |
+
state_dict, config = load_pt_expert(filepath, i)
|
| 309 |
+
if config is not None and expert_config is None:
|
| 310 |
+
expert_config = config
|
| 311 |
+
else:
|
| 312 |
+
state_dict = load_safetensors_weights(filepath)
|
| 313 |
+
|
| 314 |
+
# Convert
|
| 315 |
+
if output_format == 'bf16':
|
| 316 |
+
converted = convert_to_bf16(state_dict)
|
| 317 |
+
else:
|
| 318 |
+
converted, stats = convert_to_int8(state_dict)
|
| 319 |
+
print(f" INT8: {stats['int8']:,} params, Float32: {stats['float32']:,} params")
|
| 320 |
+
|
| 321 |
+
# Save
|
| 322 |
+
output_path = output_dir / f"expert_{i}.safetensors"
|
| 323 |
+
save_file(converted, str(output_path))
|
| 324 |
+
sizes['output'] += output_path.stat().st_size
|
| 325 |
+
|
| 326 |
+
print(f" ✅ Saved: {output_path.name} ({output_path.stat().st_size / 1e6:.1f} MB)")
|
| 327 |
+
|
| 328 |
+
# Clean up
|
| 329 |
+
del state_dict, converted
|
| 330 |
+
gc.collect()
|
| 331 |
+
|
| 332 |
+
# Process router
|
| 333 |
+
if 'router' in files:
|
| 334 |
+
print("\n📡 Converting router...")
|
| 335 |
+
filepath = files['router']
|
| 336 |
+
sizes['input'] += filepath.stat().st_size
|
| 337 |
+
|
| 338 |
+
if input_format == 'pt':
|
| 339 |
+
state_dict, config = load_pt_router(filepath)
|
| 340 |
+
if config is not None:
|
| 341 |
+
router_config = config
|
| 342 |
+
else:
|
| 343 |
+
state_dict = load_safetensors_weights(filepath)
|
| 344 |
+
|
| 345 |
+
# Router always kept in bf16/float32 for stability
|
| 346 |
+
converted = convert_to_bf16(state_dict)
|
| 347 |
+
|
| 348 |
+
output_path = output_dir / "router.safetensors"
|
| 349 |
+
save_file(converted, str(output_path))
|
| 350 |
+
sizes['output'] += output_path.stat().st_size
|
| 351 |
+
|
| 352 |
+
print(f" ✅ Saved: {output_path.name} ({output_path.stat().st_size / 1e6:.1f} MB)")
|
| 353 |
+
|
| 354 |
+
del state_dict, converted
|
| 355 |
+
gc.collect()
|
| 356 |
+
|
| 357 |
+
# Save configs if from .pt files
|
| 358 |
+
if expert_config is not None:
|
| 359 |
+
config_path = output_dir / "config.pt"
|
| 360 |
+
torch.save({'config': expert_config}, config_path)
|
| 361 |
+
print(f" ✅ Saved: config.pt")
|
| 362 |
+
|
| 363 |
+
if router_config is not None:
|
| 364 |
+
config_path = output_dir / "router_config.pt"
|
| 365 |
+
torch.save({'config': router_config}, config_path)
|
| 366 |
+
print(f" ✅ Saved: router_config.pt")
|
| 367 |
+
|
| 368 |
+
# Summary
|
| 369 |
+
compression = sizes['input'] / sizes['output'] if sizes['output'] > 0 else 1
|
| 370 |
+
print(f"""
|
| 371 |
+
╔══════════════════════════════════════════════════════════════════════════════╗
|
| 372 |
+
║ 📊 Conversion Summary 📊 ║
|
| 373 |
+
╠══════════════════════════════════════════════════════════════════════════════╣
|
| 374 |
+
║ Input size: {sizes['input']/1e9:>8.2f} GB ║
|
| 375 |
+
║ Output size: {sizes['output']/1e9:>8.2f} GB ║
|
| 376 |
+
║ Compression: {compression:>8.1f}x ║
|
| 377 |
+
╠══════════════════════════════════════════════════════════════════════════════╣
|
| 378 |
+
║ ✅ Conversion complete! ║
|
| 379 |
+
╚══════════════════════════════════════════════════════════════════════════════╝
|
| 380 |
+
""")
|
| 381 |
+
|
| 382 |
+
# List output files
|
| 383 |
+
print("📁 Output files:")
|
| 384 |
+
for f in sorted(output_dir.glob("*")):
|
| 385 |
+
print(f" {f.name}: {f.stat().st_size/1e6:.1f} MB")
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 389 |
+
# CLI
|
| 390 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 391 |
+
|
| 392 |
+
def parse_args():
|
| 393 |
+
parser = argparse.ArgumentParser(
|
| 394 |
+
description="🔧 Paris MoE - Weight Quantization Utility",
|
| 395 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 396 |
+
epilog="""
|
| 397 |
+
Examples:
|
| 398 |
+
# Convert original .pt files to BF16
|
| 399 |
+
python quantize.py --input /path/to/weights --output ./weights/bf16 --format bf16
|
| 400 |
+
|
| 401 |
+
# Convert to INT8 from .pt files
|
| 402 |
+
python quantize.py --input /path/to/weights --output ./weights/int8 --format int8
|
| 403 |
+
|
| 404 |
+
# Convert from BF16 safetensors to INT8
|
| 405 |
+
python quantize.py --input ./weights/bf16 --output ./weights/int8 --format int8
|
| 406 |
+
"""
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
parser.add_argument("--input", "-i", type=str, required=True,
|
| 410 |
+
help="Input directory containing weight files")
|
| 411 |
+
parser.add_argument("--output", "-o", type=str, required=True,
|
| 412 |
+
help="Output directory for converted weights")
|
| 413 |
+
parser.add_argument("--format", "-f", type=str, required=True,
|
| 414 |
+
choices=["bf16", "int8"],
|
| 415 |
+
help="Output format: bf16 or int8")
|
| 416 |
+
|
| 417 |
+
return parser.parse_args()
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def main():
|
| 421 |
+
args = parse_args()
|
| 422 |
+
|
| 423 |
+
input_dir = Path(args.input)
|
| 424 |
+
output_dir = Path(args.output)
|
| 425 |
+
|
| 426 |
+
if not input_dir.exists():
|
| 427 |
+
print(f"❌ Error: Input directory does not exist: {input_dir}")
|
| 428 |
+
return 1
|
| 429 |
+
|
| 430 |
+
convert_weights(input_dir, output_dir, args.format)
|
| 431 |
+
return 0
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
if __name__ == "__main__":
|
| 435 |
+
exit(main())
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0
|
| 2 |
+
torchvision
|
| 3 |
+
safetensors
|
| 4 |
+
transformers
|
| 5 |
+
diffusers
|
| 6 |
+
accelerate
|
| 7 |
+
tqdm
|
src/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Paris MoE Inference - Source modules
|
src/__pycache__/config.cpython-312.pyc
ADDED
|
Binary file (7.8 kB). View file
|
|
|
src/__pycache__/models.cpython-312.pyc
ADDED
|
Binary file (90 kB). View file
|
|
|
src/__pycache__/schedules.cpython-312.pyc
ADDED
|
Binary file (7.41 kB). View file
|
|
|
src/__pycache__/vae_utils.cpython-312.pyc
ADDED
|
Binary file (8.65 kB). View file
|
|
|
src/config.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/config.py
|
| 2 |
+
import yaml
|
| 3 |
+
from typing import Dict, Any, Optional
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
|
| 6 |
+
@dataclass
|
| 7 |
+
class Config:
|
| 8 |
+
"""Single config class - no inheritance needed"""
|
| 9 |
+
|
| 10 |
+
# Experiment
|
| 11 |
+
experiment_name: str
|
| 12 |
+
seed: int = 42
|
| 13 |
+
|
| 14 |
+
# Dataset
|
| 15 |
+
dataset_name: str = "cifar10"
|
| 16 |
+
image_size: int = 64
|
| 17 |
+
num_channels: Optional[int] = None # If None, auto-determined based on dataset/latents
|
| 18 |
+
data_path: str = "./data"
|
| 19 |
+
download: bool = True
|
| 20 |
+
use_latents: bool = False # Whether to use VAE latents instead of raw images
|
| 21 |
+
latent_data_path: Optional[str] = None # Path to latent dataset JSON
|
| 22 |
+
split_strategy: str = "global" # "global" or "per_cluster"
|
| 23 |
+
preclustered_data_path: Optional[str] = None # Path to pre-clustered data
|
| 24 |
+
train_ratio: float = 0.95 # Train/val split ratio
|
| 25 |
+
|
| 26 |
+
# Clustering (None for monolithic)
|
| 27 |
+
clustering_method: Optional[str] = None # "manual", "kmeans", <- note that we dont support dino as an on-the-fly clustering method
|
| 28 |
+
num_clusters: int = 1
|
| 29 |
+
manual_mapping: Optional[Dict[int, int]] = None
|
| 30 |
+
|
| 31 |
+
# Model
|
| 32 |
+
num_experts: int = 1 # 1 = monolithic, >1 = DDM
|
| 33 |
+
expert_architecture: str = "unet" # "unet", "dit", "simple_cnn"
|
| 34 |
+
router_architecture: str = "none" # "vit", "cnn", "dit", "none"
|
| 35 |
+
router_pretrained: bool = True
|
| 36 |
+
clip_tokenizer_name: str = "openai/clip-vit-large-patch14"
|
| 37 |
+
|
| 38 |
+
# Training
|
| 39 |
+
batch_size: int = 32
|
| 40 |
+
num_epochs: int = 20
|
| 41 |
+
learning_rate: float = 1e-4
|
| 42 |
+
optimizer: str = "adamw"
|
| 43 |
+
mixed_precision: bool = True
|
| 44 |
+
num_gpus: int = 1
|
| 45 |
+
distributed: bool = False
|
| 46 |
+
train_router_jointly: bool = False
|
| 47 |
+
weight_decay: float = 0
|
| 48 |
+
use_lr_scheduler: bool = True
|
| 49 |
+
warmup_steps: int = 0 # Learning rate warmup steps
|
| 50 |
+
warmup_factor: float = 0.1 # Learning rate warmup factor
|
| 51 |
+
grad_accum_steps: int = 1
|
| 52 |
+
use_amp: bool = True
|
| 53 |
+
imagenet_pretrain_checkpoint: Optional[str] = None
|
| 54 |
+
|
| 55 |
+
# Cluster imbalance handling
|
| 56 |
+
use_class_weights: bool = False # Enable class weighting for imbalanced clusters
|
| 57 |
+
weight_smoothing: float = 0.0 # Weight smoothing factor (0.0-1.0)
|
| 58 |
+
|
| 59 |
+
# New dataset training options
|
| 60 |
+
new_dataset_learning_rate: Optional[float] = None
|
| 61 |
+
reset_optimizer: bool = True
|
| 62 |
+
reset_scheduler: bool = True
|
| 63 |
+
reset_epoch: bool = True
|
| 64 |
+
reset_ema: bool = False
|
| 65 |
+
|
| 66 |
+
# Decentralized training
|
| 67 |
+
expert_parallel: bool = False
|
| 68 |
+
target_expert_id: int = 0
|
| 69 |
+
target_gpu_id: int = 0
|
| 70 |
+
|
| 71 |
+
# FID evaluation
|
| 72 |
+
compute_fid: bool = False
|
| 73 |
+
fid_every: int = 5000
|
| 74 |
+
fid_num_samples: int = 5000
|
| 75 |
+
fid_batch_size: int = 50
|
| 76 |
+
|
| 77 |
+
# EMA parameters
|
| 78 |
+
use_ema: bool = True
|
| 79 |
+
ema_decay: float = 0.9999
|
| 80 |
+
ema_update_every: int = 1
|
| 81 |
+
|
| 82 |
+
# Heterogeneous objectives
|
| 83 |
+
expert_objectives: Optional[Dict[int, str]] = None # {expert_id: "ddpm"|"fm"|"rf"}
|
| 84 |
+
default_objective: str = "fm" # Default if expert_objectives not specified
|
| 85 |
+
|
| 86 |
+
# Schedule configuration (NEW)
|
| 87 |
+
schedule_type: str = "linear_interp" # Default for backward compatibility
|
| 88 |
+
expert_schedule_types: Optional[Dict[int, str]] = None # Per-expert schedules for Strategy B
|
| 89 |
+
|
| 90 |
+
# Consistency loss (NEW)
|
| 91 |
+
use_consistency_loss: bool = False
|
| 92 |
+
consistency_loss_weight: float = 0.1
|
| 93 |
+
|
| 94 |
+
# Model parameters (flexible dicts)
|
| 95 |
+
expert_params: Dict[str, Any] = None
|
| 96 |
+
router_params: Dict[str, Any] = None
|
| 97 |
+
video_config: Dict[str, Any] = None # Video-specific parameters (temporal_frames, latent_height, etc.)
|
| 98 |
+
|
| 99 |
+
# Inference
|
| 100 |
+
sampling_strategy: str = "top1" # "top1", "topk", "full", "monolithic"
|
| 101 |
+
num_inference_steps: int = 50
|
| 102 |
+
|
| 103 |
+
# Diffusion settings
|
| 104 |
+
beta_start: float = 0.0001
|
| 105 |
+
beta_end: float = 0.02
|
| 106 |
+
beta_schedule: str = "linear"
|
| 107 |
+
max_text_length: int = 77
|
| 108 |
+
|
| 109 |
+
# Paths
|
| 110 |
+
checkpoint_dir: str = "./outputs/checkpoints"
|
| 111 |
+
log_dir: str = "./outputs/logs"
|
| 112 |
+
|
| 113 |
+
def __post_init__(self) -> None:
|
| 114 |
+
# Set defaults for missing fields
|
| 115 |
+
if self.expert_params is None:
|
| 116 |
+
self.expert_params = {}
|
| 117 |
+
if self.router_params is None:
|
| 118 |
+
self.router_params = {}
|
| 119 |
+
if self.video_config is None:
|
| 120 |
+
self.video_config = {}
|
| 121 |
+
|
| 122 |
+
# Auto-determine num_channels if not explicitly set
|
| 123 |
+
if self.num_channels is None:
|
| 124 |
+
if self.use_latents:
|
| 125 |
+
self.num_channels = 4 # VAE latent channels
|
| 126 |
+
elif self.dataset_name in ["mnist", "fashionmnist"]:
|
| 127 |
+
self.num_channels = 1
|
| 128 |
+
else:
|
| 129 |
+
self.num_channels = 3
|
| 130 |
+
|
| 131 |
+
# Initialize and validate expert_objectives
|
| 132 |
+
valid_objectives = {"ddpm", "fm", "rf"}
|
| 133 |
+
|
| 134 |
+
# Validate default_objective
|
| 135 |
+
if self.default_objective not in valid_objectives:
|
| 136 |
+
raise ValueError(f"default_objective must be one of {valid_objectives}, got {self.default_objective}")
|
| 137 |
+
|
| 138 |
+
# Initialize expert_objectives if None
|
| 139 |
+
if self.expert_objectives is None:
|
| 140 |
+
self.expert_objectives = {i: self.default_objective for i in range(self.num_experts)}
|
| 141 |
+
else:
|
| 142 |
+
# Validate all objective types
|
| 143 |
+
for expert_id, obj_type in self.expert_objectives.items():
|
| 144 |
+
if obj_type not in valid_objectives:
|
| 145 |
+
raise ValueError(f"Expert {expert_id} has invalid objective '{obj_type}'. Must be one of {valid_objectives}")
|
| 146 |
+
|
| 147 |
+
# Ensure all expert IDs have objectives assigned
|
| 148 |
+
for expert_id in range(self.num_experts):
|
| 149 |
+
if expert_id not in self.expert_objectives:
|
| 150 |
+
self.expert_objectives[expert_id] = self.default_objective
|
| 151 |
+
|
| 152 |
+
# Validate schedule types (NEW)
|
| 153 |
+
valid_schedules = {"cosine", "linear_beta", "linear_interp"}
|
| 154 |
+
|
| 155 |
+
# Validate default schedule_type
|
| 156 |
+
if self.schedule_type not in valid_schedules:
|
| 157 |
+
raise ValueError(f"schedule_type must be one of {valid_schedules}, got {self.schedule_type}")
|
| 158 |
+
|
| 159 |
+
# Validate expert_schedule_types if provided
|
| 160 |
+
if self.expert_schedule_types is not None:
|
| 161 |
+
for expert_id, sched_type in self.expert_schedule_types.items():
|
| 162 |
+
if sched_type not in valid_schedules:
|
| 163 |
+
raise ValueError(f"Expert {expert_id} has invalid schedule '{sched_type}'. Must be one of {valid_schedules}")
|
| 164 |
+
|
| 165 |
+
@classmethod
|
| 166 |
+
def from_yaml(cls, config_path: str) -> 'Config':
|
| 167 |
+
with open(config_path, 'r') as f:
|
| 168 |
+
config_dict = yaml.safe_load(f)
|
| 169 |
+
|
| 170 |
+
# Set defaults for missing fields
|
| 171 |
+
config_dict.setdefault('expert_params', {})
|
| 172 |
+
config_dict.setdefault('router_params', {})
|
| 173 |
+
|
| 174 |
+
# If num_experts is not specified, default to num_clusters (or 1 if num_clusters is not set)
|
| 175 |
+
if 'num_experts' not in config_dict:
|
| 176 |
+
num_clusters = config_dict.get('num_clusters', 1)
|
| 177 |
+
config_dict['num_experts'] = max(1, num_clusters)
|
| 178 |
+
|
| 179 |
+
return cls(**config_dict)
|
| 180 |
+
|
| 181 |
+
@property
|
| 182 |
+
def is_monolithic(self) -> bool:
|
| 183 |
+
return self.num_experts == 1
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
@property
|
| 187 |
+
def num_classes(self) -> int:
|
| 188 |
+
dataset_classes = {
|
| 189 |
+
"mnist": 10, "fashionmnist": 10,
|
| 190 |
+
"cifar10": 10, "cifar100": 100,
|
| 191 |
+
"celeba": 0, # No class conditioning
|
| 192 |
+
"butterfly": 1, # Single class for butterflies
|
| 193 |
+
"laion": 0 # No class conditioning for LAION
|
| 194 |
+
}
|
| 195 |
+
return dataset_classes.get(self.dataset_name, 10)
|
| 196 |
+
|
| 197 |
+
def load_config(config_path: str) -> Config:
|
| 198 |
+
"""Simple config loader"""
|
| 199 |
+
return Config.from_yaml(config_path)
|
src/models.py
ADDED
|
@@ -0,0 +1,1913 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/models.py
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from diffusers import UNet2DModel
|
| 6 |
+
from transformers import ViTForImageClassification, ViTConfig
|
| 7 |
+
import math
|
| 8 |
+
from typing import Optional, List
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
# =============================================================================
|
| 12 |
+
# TIME EMBEDDING (shared utility)
|
| 13 |
+
# =============================================================================
|
| 14 |
+
|
| 15 |
+
class TimeEmbedding(nn.Module):
|
| 16 |
+
def __init__(self, dim: int) -> None:
|
| 17 |
+
super().__init__()
|
| 18 |
+
self.dim = dim
|
| 19 |
+
|
| 20 |
+
def forward(self, t: torch.Tensor) -> torch.Tensor:
|
| 21 |
+
device = t.device
|
| 22 |
+
half_dim = self.dim // 2
|
| 23 |
+
embeddings = math.log(10000) / (half_dim - 1)
|
| 24 |
+
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
|
| 25 |
+
embeddings = t[:, None] * embeddings[None, :]
|
| 26 |
+
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
|
| 27 |
+
return embeddings
|
| 28 |
+
|
| 29 |
+
class DiTTimestepEmbedder(nn.Module):
|
| 30 |
+
def __init__(self, hidden_size, freq_dim=128, max_period=10000):
|
| 31 |
+
super().__init__()
|
| 32 |
+
self.freq_dim = freq_dim
|
| 33 |
+
self.max_period = max_period
|
| 34 |
+
self.mlp = nn.Sequential(
|
| 35 |
+
nn.Linear(2*freq_dim, hidden_size, bias=True),
|
| 36 |
+
nn.SiLU(),
|
| 37 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
| 38 |
+
)
|
| 39 |
+
def forward(self, t): # t: [B] integers (float tensor ok)
|
| 40 |
+
# standard "timestep_embedding" (like ADM/DiT)
|
| 41 |
+
half = self.freq_dim
|
| 42 |
+
device = t.device
|
| 43 |
+
# positions in radians
|
| 44 |
+
freqs = torch.exp(
|
| 45 |
+
-torch.arange(half, device=device).float() * np.log(self.max_period) / half
|
| 46 |
+
)
|
| 47 |
+
args = t.float()[:, None] * freqs[None] # [B, half]
|
| 48 |
+
emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) # [B, 2*half]
|
| 49 |
+
return self.mlp(emb)
|
| 50 |
+
|
| 51 |
+
# =============================================================================
|
| 52 |
+
# OUTPUT CONVERTER (for heterogeneous objectives)
|
| 53 |
+
# =============================================================================
|
| 54 |
+
|
| 55 |
+
class OutputConverter(nn.Module):
|
| 56 |
+
def __init__(self, schedule_type: str = 'linear_interp', use_latents: bool = False, derivative_eps: float = 1e-4):
|
| 57 |
+
super().__init__()
|
| 58 |
+
from schedules import NoiseSchedule
|
| 59 |
+
self.schedule = NoiseSchedule(schedule_type)
|
| 60 |
+
self.schedule_type = schedule_type
|
| 61 |
+
self.use_latents = use_latents
|
| 62 |
+
self.derivative_eps = derivative_eps # For finite difference derivatives
|
| 63 |
+
|
| 64 |
+
# Set clamping range based on data type
|
| 65 |
+
# VAE latents have larger range than pixel-space images
|
| 66 |
+
self.clamp_range = 20.0 if use_latents else 5.0
|
| 67 |
+
|
| 68 |
+
def _get_schedule_with_derivatives(self, t: torch.Tensor):
|
| 69 |
+
"""
|
| 70 |
+
Compute schedule coefficients and their derivatives.
|
| 71 |
+
Essential for correct velocity computation with any schedule.
|
| 72 |
+
"""
|
| 73 |
+
# Get coefficients at current time
|
| 74 |
+
alpha_t, sigma_t = self.schedule.get_schedule(t)
|
| 75 |
+
|
| 76 |
+
# Compute derivatives using finite differences
|
| 77 |
+
h = torch.full_like(t, self.derivative_eps)
|
| 78 |
+
t_plus = (t + h).clamp(0.0, 1.0)
|
| 79 |
+
t_minus = (t - h).clamp(0.0, 1.0)
|
| 80 |
+
|
| 81 |
+
alpha_plus, sigma_plus = self.schedule.get_schedule(t_plus)
|
| 82 |
+
alpha_minus, sigma_minus = self.schedule.get_schedule(t_minus)
|
| 83 |
+
|
| 84 |
+
# Derivatives
|
| 85 |
+
dt = (t_plus - t_minus).clamp(min=1e-6)
|
| 86 |
+
d_alpha_dt = (alpha_plus - alpha_minus) / dt
|
| 87 |
+
d_sigma_dt = (sigma_plus - sigma_minus) / dt
|
| 88 |
+
|
| 89 |
+
return alpha_t, sigma_t, d_alpha_dt, d_sigma_dt
|
| 90 |
+
|
| 91 |
+
def epsilon_to_velocity(self, epsilon_pred: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
| 92 |
+
"""
|
| 93 |
+
Correct ε→v conversion for ANY schedule using proper derivatives.
|
| 94 |
+
|
| 95 |
+
From ODE: dx_t/dt = d(alpha_t)/dt * x_0 + d(sigma_t)/dt * ε
|
| 96 |
+
This is the TRUE velocity for the schedule!
|
| 97 |
+
"""
|
| 98 |
+
# Get schedule coefficients AND their derivatives
|
| 99 |
+
alpha_t, sigma_t, d_alpha_dt, d_sigma_dt = self._get_schedule_with_derivatives(t)
|
| 100 |
+
|
| 101 |
+
# Reshape for broadcasting
|
| 102 |
+
alpha_t = alpha_t.view(-1, 1, 1, 1)
|
| 103 |
+
sigma_t = sigma_t.view(-1, 1, 1, 1)
|
| 104 |
+
d_alpha_dt = d_alpha_dt.view(-1, 1, 1, 1)
|
| 105 |
+
d_sigma_dt = d_sigma_dt.view(-1, 1, 1, 1)
|
| 106 |
+
|
| 107 |
+
# Numerical stability: handle small alpha_t
|
| 108 |
+
alpha_safe = torch.clamp(alpha_t, min=0.01)
|
| 109 |
+
|
| 110 |
+
# Step 1: Recover x_0 using Tweedie's formula
|
| 111 |
+
x_0_pred = (x_t - sigma_t * epsilon_pred) / alpha_safe
|
| 112 |
+
|
| 113 |
+
# Step 2: Clamp x_0 to reasonable range (prevents blow-up)
|
| 114 |
+
# Use adaptive clamping: larger range for VAE latents, tighter for pixel space
|
| 115 |
+
x_0_pred = torch.clamp(x_0_pred, -self.clamp_range, self.clamp_range)
|
| 116 |
+
|
| 117 |
+
# Step 3: Compute velocity based on schedule type
|
| 118 |
+
if self.schedule_type == 'linear_interp':
|
| 119 |
+
# For linear interpolation: x_t = (1-t)*x_0 + t*ε
|
| 120 |
+
# Velocity is simply: v = ε - x_0
|
| 121 |
+
v = epsilon_pred - x_0_pred
|
| 122 |
+
else:
|
| 123 |
+
# For cosine and other schedules: use proper derivatives
|
| 124 |
+
# v = d(alpha_t)/dt * x_0 + d(sigma_t)/dt * ε
|
| 125 |
+
v = d_alpha_dt * x_0_pred + d_sigma_dt * epsilon_pred
|
| 126 |
+
|
| 127 |
+
# Adaptive velocity scaling for cosine schedule
|
| 128 |
+
# Derivatives vary dramatically with timestep - need adaptive dampening
|
| 129 |
+
if self.schedule_type == 'cosine':
|
| 130 |
+
t_val = t[0].item() if t.numel() > 0 else 0.5
|
| 131 |
+
if t_val > 0.85:
|
| 132 |
+
# Very high noise: derivatives are large, need dampening
|
| 133 |
+
scale = 0.88
|
| 134 |
+
elif t_val > 0.6:
|
| 135 |
+
# Medium-high noise: moderate dampening
|
| 136 |
+
scale = 0.93
|
| 137 |
+
else:
|
| 138 |
+
# Low to medium noise: slight dampening
|
| 139 |
+
scale = 0.96
|
| 140 |
+
v = v * scale
|
| 141 |
+
|
| 142 |
+
# Per-channel bias correction to prevent color drift
|
| 143 |
+
# The model has inherent channel bias that gets amplified by integration
|
| 144 |
+
# Remove per-channel mean to prevent accumulation
|
| 145 |
+
# Only apply to color channels (1,2,3), preserve luminance channel (0)
|
| 146 |
+
for c in range(1, 4):
|
| 147 |
+
v[:, c] = v[:, c] - v[:, c].mean()
|
| 148 |
+
|
| 149 |
+
return v
|
| 150 |
+
|
| 151 |
+
def convert(self, prediction: torch.Tensor, objective_type: str, x_t: torch.Tensor, t: torch.Tensor):
|
| 152 |
+
"""
|
| 153 |
+
Convert any prediction to velocity space.
|
| 154 |
+
|
| 155 |
+
Args:
|
| 156 |
+
prediction: expert output
|
| 157 |
+
objective_type: 'ddpm' | 'fm' | 'rf'
|
| 158 |
+
x_t: current noisy state
|
| 159 |
+
t: current timesteps
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
v: velocity representation
|
| 163 |
+
"""
|
| 164 |
+
if objective_type == "ddpm":
|
| 165 |
+
# Proper ε→v conversion for unified integration
|
| 166 |
+
return self.epsilon_to_velocity(prediction, x_t, t)
|
| 167 |
+
elif objective_type in ["fm", "rf"]:
|
| 168 |
+
return prediction # Already velocity
|
| 169 |
+
else:
|
| 170 |
+
raise ValueError(f"Unknown objective type: {objective_type}")
|
| 171 |
+
|
| 172 |
+
# =============================================================================
|
| 173 |
+
# EXPERT MODELS
|
| 174 |
+
# =============================================================================
|
| 175 |
+
|
| 176 |
+
class UNetExpert(nn.Module):
|
| 177 |
+
"""UNet expert using diffusers"""
|
| 178 |
+
|
| 179 |
+
def __init__(self, config) -> None:
|
| 180 |
+
super().__init__()
|
| 181 |
+
|
| 182 |
+
# Default UNet params
|
| 183 |
+
default_params = {
|
| 184 |
+
"sample_size": config.image_size,
|
| 185 |
+
"in_channels": config.num_channels,
|
| 186 |
+
"out_channels": config.num_channels,
|
| 187 |
+
"layers_per_block": 2,
|
| 188 |
+
"block_out_channels": [64, 128, 256, 256],
|
| 189 |
+
"attention_head_dim": 8,
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
# Override with config params
|
| 193 |
+
params = {**default_params, **config.expert_params}
|
| 194 |
+
|
| 195 |
+
# Store objective type for heterogeneous training (and remove from params)
|
| 196 |
+
self.objective_type = params.pop("objective_type", "fm")
|
| 197 |
+
|
| 198 |
+
# Store and initialize schedule (NEW)
|
| 199 |
+
schedule_type = params.pop("schedule_type", "linear_interp")
|
| 200 |
+
from schedules import NoiseSchedule
|
| 201 |
+
self.schedule = NoiseSchedule(schedule_type)
|
| 202 |
+
|
| 203 |
+
self.unet = UNet2DModel(**params)
|
| 204 |
+
|
| 205 |
+
def forward(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 206 |
+
# Scale timesteps for diffusers (expects 0-1000)
|
| 207 |
+
# t_scaled = (t * 1000).long()
|
| 208 |
+
t_scaled = (t * 999).round().long().clamp(0, 999)
|
| 209 |
+
return self.unet(xt, t_scaled).sample
|
| 210 |
+
|
| 211 |
+
def compute_loss(self, x0: torch.Tensor) -> torch.Tensor:
|
| 212 |
+
"""Unified loss computation based on objective type"""
|
| 213 |
+
if self.objective_type == "ddpm":
|
| 214 |
+
return self.ddpm_loss(x0)
|
| 215 |
+
elif self.objective_type == "fm":
|
| 216 |
+
return self.flow_matching_loss(x0)
|
| 217 |
+
elif self.objective_type == "rf":
|
| 218 |
+
return self.rectified_flow_loss(x0)
|
| 219 |
+
else:
|
| 220 |
+
raise ValueError(f"Unknown objective type: {self.objective_type}")
|
| 221 |
+
|
| 222 |
+
def ddpm_loss(self, x0: torch.Tensor) -> torch.Tensor:
|
| 223 |
+
"""DDPM: predict noise ε"""
|
| 224 |
+
batch_size = x0.shape[0]
|
| 225 |
+
device = x0.device
|
| 226 |
+
|
| 227 |
+
t = torch.rand(batch_size, device=device)
|
| 228 |
+
|
| 229 |
+
# Use proper schedule (NEW)
|
| 230 |
+
alpha_t, sigma_t = self.schedule.get_schedule(t)
|
| 231 |
+
|
| 232 |
+
noise = torch.randn_like(x0)
|
| 233 |
+
xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise
|
| 234 |
+
|
| 235 |
+
pred_eps = self.forward(xt, t)
|
| 236 |
+
return F.mse_loss(pred_eps, noise)
|
| 237 |
+
|
| 238 |
+
def rectified_flow_loss(self, x0: torch.Tensor) -> torch.Tensor:
|
| 239 |
+
"""Rectified Flow: predict velocity v = x_1 - x_0"""
|
| 240 |
+
batch_size = x0.shape[0]
|
| 241 |
+
device = x0.device
|
| 242 |
+
|
| 243 |
+
t = torch.rand(batch_size, device=device)
|
| 244 |
+
x1 = torch.randn_like(x0)
|
| 245 |
+
xt = (1 - t).view(-1, 1, 1, 1) * x0 + t.view(-1, 1, 1, 1) * x1
|
| 246 |
+
|
| 247 |
+
pred_v = self.forward(xt, t)
|
| 248 |
+
true_v = x1 - x0
|
| 249 |
+
return F.mse_loss(pred_v, true_v)
|
| 250 |
+
|
| 251 |
+
def flow_matching_loss(self, x0: torch.Tensor) -> torch.Tensor:
|
| 252 |
+
"""Flow matching loss for training"""
|
| 253 |
+
batch_size = x0.shape[0]
|
| 254 |
+
device = x0.device
|
| 255 |
+
|
| 256 |
+
# Sample random timesteps
|
| 257 |
+
t = torch.rand(batch_size, device=device)
|
| 258 |
+
|
| 259 |
+
# Use proper schedule (NEW)
|
| 260 |
+
alpha_t, sigma_t = self.schedule.get_schedule(t)
|
| 261 |
+
|
| 262 |
+
# Add noise
|
| 263 |
+
noise = torch.randn_like(x0)
|
| 264 |
+
xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise
|
| 265 |
+
|
| 266 |
+
# Predict velocity
|
| 267 |
+
pred_v = self.forward(xt, t)
|
| 268 |
+
|
| 269 |
+
# True velocity for flow matching
|
| 270 |
+
# true_v = x0 - xt
|
| 271 |
+
true_v = noise - x0
|
| 272 |
+
|
| 273 |
+
return F.mse_loss(pred_v, true_v)
|
| 274 |
+
|
| 275 |
+
class SimpleCNNExpert(nn.Module):
|
| 276 |
+
"""Simple CNN expert for fast training"""
|
| 277 |
+
|
| 278 |
+
def __init__(self, config) -> None:
|
| 279 |
+
super().__init__()
|
| 280 |
+
|
| 281 |
+
# Default params
|
| 282 |
+
default_params = {
|
| 283 |
+
"hidden_dims": [64, 128, 256],
|
| 284 |
+
"time_dim": 64,
|
| 285 |
+
}
|
| 286 |
+
params = {**default_params, **config.expert_params}
|
| 287 |
+
|
| 288 |
+
# Store objective type for heterogeneous training
|
| 289 |
+
self.objective_type = params.get("objective_type", "fm")
|
| 290 |
+
|
| 291 |
+
# Store and initialize schedule (NEW)
|
| 292 |
+
schedule_type = params.get("schedule_type", "linear_interp")
|
| 293 |
+
from schedules import NoiseSchedule
|
| 294 |
+
self.schedule = NoiseSchedule(schedule_type)
|
| 295 |
+
|
| 296 |
+
self.time_embedding = TimeEmbedding(params["time_dim"])
|
| 297 |
+
self.target_size = config.image_size
|
| 298 |
+
|
| 299 |
+
# Simple encoder-decoder
|
| 300 |
+
self.encoder = self._build_encoder(config.num_channels, params["hidden_dims"])
|
| 301 |
+
self.decoder = self._build_decoder(params["hidden_dims"], config.num_channels)
|
| 302 |
+
|
| 303 |
+
# Time conditioning
|
| 304 |
+
self.time_mlp = nn.Sequential(
|
| 305 |
+
nn.Linear(params["time_dim"], params["hidden_dims"][-1]),
|
| 306 |
+
nn.SiLU(),
|
| 307 |
+
nn.Linear(params["hidden_dims"][-1], params["hidden_dims"][-1])
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
def _build_encoder(self, in_channels: int, hidden_dims: List[int]) -> nn.Sequential:
|
| 311 |
+
layers = []
|
| 312 |
+
prev_dim = in_channels
|
| 313 |
+
|
| 314 |
+
for dim in hidden_dims:
|
| 315 |
+
layers.extend([
|
| 316 |
+
nn.Conv2d(prev_dim, dim, 3, padding=1),
|
| 317 |
+
nn.GroupNorm(8, dim),
|
| 318 |
+
nn.SiLU(),
|
| 319 |
+
nn.Conv2d(dim, dim, 3, padding=1),
|
| 320 |
+
nn.GroupNorm(8, dim),
|
| 321 |
+
nn.SiLU(),
|
| 322 |
+
nn.MaxPool2d(2)
|
| 323 |
+
])
|
| 324 |
+
prev_dim = dim
|
| 325 |
+
|
| 326 |
+
return nn.Sequential(*layers)
|
| 327 |
+
|
| 328 |
+
def _build_decoder(self, hidden_dims: List[int], out_channels: int) -> nn.Sequential:
|
| 329 |
+
layers = []
|
| 330 |
+
reversed_dims = list(reversed(hidden_dims))
|
| 331 |
+
|
| 332 |
+
for i, dim in enumerate(reversed_dims[:-1]):
|
| 333 |
+
next_dim = reversed_dims[i + 1]
|
| 334 |
+
layers.extend([
|
| 335 |
+
nn.ConvTranspose2d(dim, next_dim, 4, stride=2, padding=1),
|
| 336 |
+
nn.GroupNorm(8, next_dim),
|
| 337 |
+
nn.SiLU(),
|
| 338 |
+
nn.Conv2d(next_dim, next_dim, 3, padding=1),
|
| 339 |
+
nn.GroupNorm(8, next_dim),
|
| 340 |
+
nn.SiLU(),
|
| 341 |
+
])
|
| 342 |
+
|
| 343 |
+
# Final layer
|
| 344 |
+
layers.append(nn.Conv2d(reversed_dims[-1], out_channels, 3, padding=1))
|
| 345 |
+
|
| 346 |
+
return nn.Sequential(*layers)
|
| 347 |
+
|
| 348 |
+
def forward(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 349 |
+
# Time embedding
|
| 350 |
+
time_emb = self.time_embedding(t)
|
| 351 |
+
time_features = self.time_mlp(time_emb)
|
| 352 |
+
|
| 353 |
+
# Encode
|
| 354 |
+
encoded = self.encoder(xt)
|
| 355 |
+
|
| 356 |
+
# Add time conditioning
|
| 357 |
+
time_features = time_features.view(time_features.shape[0], -1, 1, 1)
|
| 358 |
+
time_features = time_features.expand(-1, -1, encoded.shape[2], encoded.shape[3])
|
| 359 |
+
conditioned = encoded + time_features
|
| 360 |
+
|
| 361 |
+
# Decode
|
| 362 |
+
output = self.decoder(conditioned)
|
| 363 |
+
|
| 364 |
+
# Ensure output matches target size
|
| 365 |
+
output = F.interpolate(output, size=xt.shape[-2:], mode='bilinear', align_corners=False)
|
| 366 |
+
|
| 367 |
+
return output
|
| 368 |
+
|
| 369 |
+
def compute_loss(self, x0: torch.Tensor) -> torch.Tensor:
|
| 370 |
+
"""Unified loss computation based on objective type"""
|
| 371 |
+
if self.objective_type == "ddpm":
|
| 372 |
+
return self.ddpm_loss(x0)
|
| 373 |
+
elif self.objective_type == "fm":
|
| 374 |
+
return self.flow_matching_loss(x0)
|
| 375 |
+
elif self.objective_type == "rf":
|
| 376 |
+
return self.rectified_flow_loss(x0)
|
| 377 |
+
else:
|
| 378 |
+
raise ValueError(f"Unknown objective type: {self.objective_type}")
|
| 379 |
+
|
| 380 |
+
def ddpm_loss(self, x0: torch.Tensor) -> torch.Tensor:
|
| 381 |
+
"""DDPM: predict noise ε"""
|
| 382 |
+
batch_size = x0.shape[0]
|
| 383 |
+
device = x0.device
|
| 384 |
+
|
| 385 |
+
t = torch.rand(batch_size, device=device)
|
| 386 |
+
|
| 387 |
+
# Use proper schedule (NEW)
|
| 388 |
+
alpha_t, sigma_t = self.schedule.get_schedule(t)
|
| 389 |
+
|
| 390 |
+
noise = torch.randn_like(x0)
|
| 391 |
+
xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise
|
| 392 |
+
|
| 393 |
+
pred_eps = self.forward(xt, t)
|
| 394 |
+
|
| 395 |
+
# Ensure pred_eps matches noise shape
|
| 396 |
+
if pred_eps.shape != noise.shape:
|
| 397 |
+
pred_eps = F.interpolate(pred_eps, size=noise.shape[-2:], mode='bilinear', align_corners=False)
|
| 398 |
+
|
| 399 |
+
return F.mse_loss(pred_eps, noise)
|
| 400 |
+
|
| 401 |
+
def rectified_flow_loss(self, x0: torch.Tensor) -> torch.Tensor:
|
| 402 |
+
"""Rectified Flow: predict velocity v = x_1 - x_0"""
|
| 403 |
+
batch_size = x0.shape[0]
|
| 404 |
+
device = x0.device
|
| 405 |
+
|
| 406 |
+
t = torch.rand(batch_size, device=device)
|
| 407 |
+
x1 = torch.randn_like(x0)
|
| 408 |
+
xt = (1 - t).view(-1, 1, 1, 1) * x0 + t.view(-1, 1, 1, 1) * x1
|
| 409 |
+
|
| 410 |
+
pred_v = self.forward(xt, t)
|
| 411 |
+
true_v = x1 - x0
|
| 412 |
+
|
| 413 |
+
# Ensure pred_v matches true_v shape
|
| 414 |
+
if pred_v.shape != true_v.shape:
|
| 415 |
+
pred_v = F.interpolate(pred_v, size=true_v.shape[-2:], mode='bilinear', align_corners=False)
|
| 416 |
+
|
| 417 |
+
return F.mse_loss(pred_v, true_v)
|
| 418 |
+
|
| 419 |
+
def flow_matching_loss(self, x0: torch.Tensor) -> torch.Tensor:
|
| 420 |
+
"""Flow matching loss"""
|
| 421 |
+
batch_size = x0.shape[0]
|
| 422 |
+
device = x0.device
|
| 423 |
+
|
| 424 |
+
t = torch.rand(batch_size, device=device)
|
| 425 |
+
|
| 426 |
+
# Use proper schedule (NEW)
|
| 427 |
+
alpha_t, sigma_t = self.schedule.get_schedule(t)
|
| 428 |
+
|
| 429 |
+
noise = torch.randn_like(x0)
|
| 430 |
+
xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise
|
| 431 |
+
|
| 432 |
+
pred_v = self.forward(xt, t)
|
| 433 |
+
# true_v = x0 - xt
|
| 434 |
+
true_v = noise - x0
|
| 435 |
+
|
| 436 |
+
# Ensure pred_v matches true_v shape
|
| 437 |
+
if pred_v.shape != true_v.shape:
|
| 438 |
+
pred_v = F.interpolate(pred_v, size=true_v.shape[-2:], mode='bilinear', align_corners=False)
|
| 439 |
+
|
| 440 |
+
return F.mse_loss(pred_v, true_v)
|
| 441 |
+
|
| 442 |
+
# Helper function from original DiT
|
| 443 |
+
def modulate(x, shift, scale):
|
| 444 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
| 445 |
+
|
| 446 |
+
# Fixed sin-cos position embedding from original
|
| 447 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size):
|
| 448 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
| 449 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
| 450 |
+
grid = np.meshgrid(grid_w, grid_h)
|
| 451 |
+
grid = np.stack(grid, axis=0)
|
| 452 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
| 453 |
+
|
| 454 |
+
assert embed_dim % 2 == 0
|
| 455 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
|
| 456 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
|
| 457 |
+
emb = np.concatenate([emb_h, emb_w], axis=1)
|
| 458 |
+
return emb
|
| 459 |
+
|
| 460 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 461 |
+
assert embed_dim % 2 == 0
|
| 462 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
| 463 |
+
omega /= embed_dim / 2.
|
| 464 |
+
omega = 1. / 10000**omega
|
| 465 |
+
pos = pos.reshape(-1)
|
| 466 |
+
out = np.einsum('m,d->md', pos, omega)
|
| 467 |
+
emb_sin = np.sin(out)
|
| 468 |
+
emb_cos = np.cos(out)
|
| 469 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1)
|
| 470 |
+
return emb
|
| 471 |
+
|
| 472 |
+
# Timestep Embedder
|
| 473 |
+
class TimestepEmbedder(nn.Module):
|
| 474 |
+
def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
|
| 475 |
+
super().__init__()
|
| 476 |
+
self.frequency_embedding_size = frequency_embedding_size
|
| 477 |
+
self.mlp = nn.Sequential(
|
| 478 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
| 479 |
+
nn.SiLU(),
|
| 480 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
@staticmethod
|
| 484 |
+
def timestep_embedding(t, dim, max_period=10000):
|
| 485 |
+
half = dim // 2
|
| 486 |
+
freqs = torch.exp(-math.log(max_period) * torch.arange(0, half, dtype=torch.float32, device=t.device) / half)
|
| 487 |
+
args = t[:, None].float() * freqs[None]
|
| 488 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 489 |
+
if dim % 2:
|
| 490 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 491 |
+
return embedding
|
| 492 |
+
|
| 493 |
+
def forward(self, t: torch.Tensor) -> torch.Tensor:
|
| 494 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
| 495 |
+
return self.mlp(t_freq)
|
| 496 |
+
|
| 497 |
+
# DiTBlock with proper AdaLN-Zero
|
| 498 |
+
class DiTBlock(nn.Module):
|
| 499 |
+
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, use_text: bool = False, use_adaln_single: bool = False):
|
| 500 |
+
super().__init__()
|
| 501 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 502 |
+
self.attn = nn.MultiheadAttention(hidden_size, num_heads, dropout=0.1, batch_first=True)
|
| 503 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 504 |
+
|
| 505 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 506 |
+
self.mlp = nn.Sequential(
|
| 507 |
+
nn.Linear(hidden_size, mlp_hidden_dim),
|
| 508 |
+
nn.GELU(approximate="tanh"), # Match original
|
| 509 |
+
nn.Linear(mlp_hidden_dim, hidden_size),
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
# AdaLN modulation - either per-block MLP or AdaLN-Single embeddings
|
| 513 |
+
self.use_adaln_single = use_adaln_single
|
| 514 |
+
if use_adaln_single:
|
| 515 |
+
# AdaLN-Single: use learnable per-block embeddings instead of MLP
|
| 516 |
+
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
|
| 517 |
+
self.adaLN_modulation = None # No MLP needed
|
| 518 |
+
else:
|
| 519 |
+
# Original AdaLN with per-block MLP
|
| 520 |
+
self.adaLN_modulation = nn.Sequential(
|
| 521 |
+
nn.SiLU(),
|
| 522 |
+
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
| 523 |
+
)
|
| 524 |
+
self.scale_shift_table = None
|
| 525 |
+
|
| 526 |
+
# Optional text cross-attention
|
| 527 |
+
self.use_text = use_text
|
| 528 |
+
if use_text:
|
| 529 |
+
# Note: PixArt uses xformers which may handle unnormalized queries differently
|
| 530 |
+
# We add a simple norm for stability with PyTorch's MultiheadAttention
|
| 531 |
+
self.norm_cross = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 532 |
+
self.cross_attn = nn.MultiheadAttention(hidden_size, num_heads, dropout=0.1, batch_first=True)
|
| 533 |
+
|
| 534 |
+
def forward(self, x: torch.Tensor, c: torch.Tensor, text_emb: Optional[torch.Tensor] = None,
|
| 535 |
+
attention_mask: Optional[torch.Tensor] = None):
|
| 536 |
+
# Get modulation parameters
|
| 537 |
+
if self.use_adaln_single:
|
| 538 |
+
# AdaLN-Single: combine global time embedding with per-block parameters
|
| 539 |
+
# c should be pre-computed from global t_block with shape [B, 6*hidden_size]
|
| 540 |
+
B = x.shape[0]
|
| 541 |
+
# Chunk and squeeze to get [B, hidden_size] tensors for compatibility with PyTorch's MultiheadAttention
|
| 542 |
+
temp = (self.scale_shift_table[None] + c.reshape(B, 6, -1)).chunk(6, dim=1)
|
| 543 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [t.squeeze(1) for t in temp]
|
| 544 |
+
else:
|
| 545 |
+
# Original AdaLN: compute modulation from per-block MLP
|
| 546 |
+
# Also squeeze after chunk to get [B, hidden_size] for consistency
|
| 547 |
+
temp = self.adaLN_modulation(c).chunk(6, dim=1)
|
| 548 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [t.squeeze(1) for t in temp]
|
| 549 |
+
|
| 550 |
+
# Self-attention with modulation
|
| 551 |
+
# Both paths now use modulate function for consistency
|
| 552 |
+
x_norm = modulate(self.norm1(x), shift_msa, scale_msa)
|
| 553 |
+
attn_out, _ = self.attn(x_norm, x_norm, x_norm)
|
| 554 |
+
x = x + gate_msa.unsqueeze(1) * attn_out
|
| 555 |
+
|
| 556 |
+
# Optional cross-attention
|
| 557 |
+
if self.use_text and text_emb is not None:
|
| 558 |
+
if text_emb.dim() == 2:
|
| 559 |
+
text_emb = text_emb.unsqueeze(1)
|
| 560 |
+
# Convert attention mask to key_padding_mask format (True = ignore)
|
| 561 |
+
# attention_mask: shape [B, T]; either bool (True=keep) or 0/1 numeric (1=keep)
|
| 562 |
+
key_padding_mask = None
|
| 563 |
+
if attention_mask is not None:
|
| 564 |
+
if attention_mask.dtype is not torch.bool:
|
| 565 |
+
# Convert 0/1 (or >=1) to bool keep-mask first
|
| 566 |
+
keep_mask = attention_mask > 0
|
| 567 |
+
else:
|
| 568 |
+
keep_mask = attention_mask
|
| 569 |
+
# key_padding_mask semantics: True = ignore, False = keep
|
| 570 |
+
key_padding_mask = ~keep_mask # logical NOT, not arithmetic subtraction
|
| 571 |
+
|
| 572 |
+
# Normalize queries for stability (PixArt uses xformers which may differ)
|
| 573 |
+
x_norm = self.norm_cross(x)
|
| 574 |
+
cross_out, _ = self.cross_attn(x_norm, text_emb, text_emb, key_padding_mask=key_padding_mask)
|
| 575 |
+
x = x + cross_out
|
| 576 |
+
|
| 577 |
+
# MLP with modulation
|
| 578 |
+
# Both paths now use modulate function for consistency
|
| 579 |
+
x_norm = modulate(self.norm2(x), shift_mlp, scale_mlp)
|
| 580 |
+
mlp_out = self.mlp(x_norm)
|
| 581 |
+
x = x + gate_mlp.unsqueeze(1) * mlp_out
|
| 582 |
+
|
| 583 |
+
return x
|
| 584 |
+
|
| 585 |
+
# FinalLayer with AdaLN modulation
|
| 586 |
+
class FinalLayer(nn.Module):
|
| 587 |
+
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
| 588 |
+
super().__init__()
|
| 589 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 590 |
+
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
| 591 |
+
self.adaLN_modulation = nn.Sequential(
|
| 592 |
+
nn.SiLU(),
|
| 593 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
| 594 |
+
)
|
| 595 |
+
|
| 596 |
+
def forward(self, x: torch.Tensor, c: torch.Tensor):
|
| 597 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
| 598 |
+
x = modulate(self.norm_final(x), shift, scale)
|
| 599 |
+
x = self.linear(x)
|
| 600 |
+
return x
|
| 601 |
+
|
| 602 |
+
# T2IFinalLayer with AdaLN-Single for parameter efficiency
|
| 603 |
+
class T2IFinalLayer(nn.Module):
|
| 604 |
+
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
| 605 |
+
super().__init__()
|
| 606 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 607 |
+
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
| 608 |
+
# AdaLN-Single: use learnable embeddings instead of MLP
|
| 609 |
+
self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5)
|
| 610 |
+
self.hidden_size = hidden_size
|
| 611 |
+
|
| 612 |
+
def forward(self, x: torch.Tensor, t: torch.Tensor):
|
| 613 |
+
# t should be the original time embedding with shape [B, hidden_size]
|
| 614 |
+
# Following PixArt implementation exactly
|
| 615 |
+
shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
|
| 616 |
+
# shift and scale are [B, 1, hidden_size], use t2i_modulate style
|
| 617 |
+
x = self.norm_final(x) * (1 + scale) + shift
|
| 618 |
+
x = self.linear(x)
|
| 619 |
+
return x
|
| 620 |
+
|
| 621 |
+
# DiTExpert
|
| 622 |
+
class DiTExpert(nn.Module):
|
| 623 |
+
def __init__(self, config):
|
| 624 |
+
super().__init__()
|
| 625 |
+
default_params = {
|
| 626 |
+
"hidden_size": 768,
|
| 627 |
+
"num_layers": 12,
|
| 628 |
+
"num_heads": 12,
|
| 629 |
+
"patch_size": 2,
|
| 630 |
+
"in_channels": 4,
|
| 631 |
+
"out_channels": 4,
|
| 632 |
+
"use_text_conditioning": False,
|
| 633 |
+
"use_class_conditioning": False,
|
| 634 |
+
"num_classes": 1000, # ImageNet classes
|
| 635 |
+
"mlp_ratio": 4.0,
|
| 636 |
+
"text_embed_dim": 768,
|
| 637 |
+
"use_dit_time_embed": False,
|
| 638 |
+
}
|
| 639 |
+
params = {**default_params, **config.expert_params}
|
| 640 |
+
|
| 641 |
+
self.patch_size = params["patch_size"]
|
| 642 |
+
self.in_channels = params["in_channels"]
|
| 643 |
+
self.out_channels = params["out_channels"]
|
| 644 |
+
self.hidden_size = params["hidden_size"]
|
| 645 |
+
self.num_heads = params["num_heads"]
|
| 646 |
+
self.use_text = params.get("use_text_conditioning", False)
|
| 647 |
+
self.use_class = params.get("use_class_conditioning", False)
|
| 648 |
+
self.cfg_dropout_prob = params.get("cfg_dropout_prob", 0.1) # 10% dropout for CFG
|
| 649 |
+
self.text_embed_dim = params.get("text_embed_dim", 768)
|
| 650 |
+
self.use_adaln_single = params.get("use_adaln_single", False) # AdaLN-Single for parameter efficiency
|
| 651 |
+
self.depth = params["num_layers"]
|
| 652 |
+
|
| 653 |
+
# Store objective type for heterogeneous training
|
| 654 |
+
self.objective_type = params.get("objective_type", "fm")
|
| 655 |
+
|
| 656 |
+
# Store and initialize schedule (NEW)
|
| 657 |
+
schedule_type = params.get("schedule_type", "linear_interp")
|
| 658 |
+
from schedules import NoiseSchedule
|
| 659 |
+
self.schedule = NoiseSchedule(schedule_type)
|
| 660 |
+
|
| 661 |
+
# Validation: cannot use both text and class conditioning simultaneously
|
| 662 |
+
assert not (self.use_text and self.use_class), "Cannot use both text and class conditioning simultaneously"
|
| 663 |
+
|
| 664 |
+
# Patch embedding
|
| 665 |
+
self.patch_embed = nn.Conv2d(self.in_channels, self.hidden_size,
|
| 666 |
+
kernel_size=self.patch_size, stride=self.patch_size)
|
| 667 |
+
|
| 668 |
+
# Fixed sin-cos positional embedding
|
| 669 |
+
latent_size = getattr(config, 'image_size', 32)
|
| 670 |
+
self.num_patches = (latent_size // self.patch_size) ** 2
|
| 671 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, self.hidden_size), requires_grad=False)
|
| 672 |
+
|
| 673 |
+
# Time embedding
|
| 674 |
+
self.use_dit_time_embed = params.get("use_dit_time_embed", False)
|
| 675 |
+
if self.use_dit_time_embed:
|
| 676 |
+
self.time_embed = DiTTimestepEmbedder(self.hidden_size)
|
| 677 |
+
else:
|
| 678 |
+
self.time_embed = TimestepEmbedder(self.hidden_size)
|
| 679 |
+
|
| 680 |
+
# Global time block for AdaLN-Single
|
| 681 |
+
if self.use_adaln_single:
|
| 682 |
+
self.t_block = nn.Sequential(
|
| 683 |
+
nn.SiLU(),
|
| 684 |
+
nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=True)
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
# Optional text conditioning
|
| 688 |
+
if self.use_text:
|
| 689 |
+
self.text_proj = nn.Linear(self.text_embed_dim, self.hidden_size)
|
| 690 |
+
self.text_norm = nn.LayerNorm(self.hidden_size, elementwise_affine=False, eps=1e-6)
|
| 691 |
+
# Note: null text embedding will be provided by empty string encoding from CLIP
|
| 692 |
+
# This is handled in the training loop, not as a learnable parameter
|
| 693 |
+
|
| 694 |
+
# Optional class conditioning (ImageNet style)
|
| 695 |
+
if self.use_class:
|
| 696 |
+
# Add 1 extra embedding for null/unconditional class
|
| 697 |
+
self.class_embed = nn.Embedding(params["num_classes"] + 1, self.hidden_size)
|
| 698 |
+
self.null_class_id = params["num_classes"] # Use last index as null class
|
| 699 |
+
|
| 700 |
+
# Transformer blocks
|
| 701 |
+
self.layers = nn.ModuleList([
|
| 702 |
+
DiTBlock(self.hidden_size, self.num_heads, params.get("mlp_ratio", 4.0),
|
| 703 |
+
self.use_text, use_adaln_single=self.use_adaln_single)
|
| 704 |
+
for _ in range(self.depth)
|
| 705 |
+
])
|
| 706 |
+
|
| 707 |
+
# Final layer with modulation
|
| 708 |
+
if self.use_adaln_single:
|
| 709 |
+
self.final_layer = T2IFinalLayer(self.hidden_size, self.patch_size, self.out_channels)
|
| 710 |
+
else:
|
| 711 |
+
self.final_layer = FinalLayer(self.hidden_size, self.patch_size, self.out_channels)
|
| 712 |
+
|
| 713 |
+
# Initialize weights
|
| 714 |
+
self.initialize_weights()
|
| 715 |
+
|
| 716 |
+
def initialize_weights(self):
|
| 717 |
+
# Initialize transformer layers
|
| 718 |
+
def _basic_init(module):
|
| 719 |
+
if isinstance(module, nn.Linear):
|
| 720 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 721 |
+
if module.bias is not None:
|
| 722 |
+
nn.init.constant_(module.bias, 0)
|
| 723 |
+
self.apply(_basic_init)
|
| 724 |
+
|
| 725 |
+
# Initialize positional embedding with sin-cos
|
| 726 |
+
grid_size = int(self.num_patches ** 0.5)
|
| 727 |
+
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_size)
|
| 728 |
+
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
| 729 |
+
|
| 730 |
+
# Initialize patch_embed like nn.Linear
|
| 731 |
+
w = self.patch_embed.weight.data
|
| 732 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
| 733 |
+
if self.patch_embed.bias is not None:
|
| 734 |
+
nn.init.constant_(self.patch_embed.bias, 0)
|
| 735 |
+
|
| 736 |
+
# Initialize timestep embedding MLP
|
| 737 |
+
nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
|
| 738 |
+
nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
|
| 739 |
+
|
| 740 |
+
# Zero-out adaLN modulation layers in DiT blocks (from DiT paper)
|
| 741 |
+
for block in self.layers:
|
| 742 |
+
if block.adaLN_modulation is not None:
|
| 743 |
+
# Original AdaLN mode
|
| 744 |
+
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
| 745 |
+
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
| 746 |
+
# AdaLN-Single mode: scale_shift_table is already initialized with randn/sqrt(hidden_size)
|
| 747 |
+
|
| 748 |
+
# Zero-out cross-attention output projection (from PixArt-Alpha)
|
| 749 |
+
if self.use_text and hasattr(block, 'cross_attn'):
|
| 750 |
+
nn.init.constant_(block.cross_attn.out_proj.weight, 0)
|
| 751 |
+
nn.init.constant_(block.cross_attn.out_proj.bias, 0)
|
| 752 |
+
|
| 753 |
+
# Initialize text projection layer (analogous to PixArt's caption embedding)
|
| 754 |
+
if self.use_text and hasattr(self, 'text_proj'):
|
| 755 |
+
nn.init.normal_(self.text_proj.weight, std=0.02)
|
| 756 |
+
if self.text_proj.bias is not None:
|
| 757 |
+
nn.init.constant_(self.text_proj.bias, 0)
|
| 758 |
+
|
| 759 |
+
# Initialize class embedding layer (similar to DiT paper)
|
| 760 |
+
if self.use_class and hasattr(self, 'class_embed'):
|
| 761 |
+
nn.init.normal_(self.class_embed.weight, std=0.02)
|
| 762 |
+
|
| 763 |
+
# Initialize global t_block for AdaLN-Single
|
| 764 |
+
if self.use_adaln_single and hasattr(self, 't_block'):
|
| 765 |
+
nn.init.normal_(self.t_block[1].weight, std=0.02)
|
| 766 |
+
# Zero-out t_block initially for stability
|
| 767 |
+
nn.init.constant_(self.t_block[1].bias, 0)
|
| 768 |
+
|
| 769 |
+
# Zero-out output layers
|
| 770 |
+
if hasattr(self.final_layer, 'adaLN_modulation') and self.final_layer.adaLN_modulation is not None:
|
| 771 |
+
# Original FinalLayer
|
| 772 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
| 773 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
| 774 |
+
# T2IFinalLayer scale_shift_table is already initialized with randn/sqrt(hidden_size)
|
| 775 |
+
nn.init.constant_(self.final_layer.linear.weight, 0)
|
| 776 |
+
nn.init.constant_(self.final_layer.linear.bias, 0)
|
| 777 |
+
|
| 778 |
+
def forward(self, xt: torch.Tensor, t: torch.Tensor, text_embeds: Optional[torch.Tensor] = None,
|
| 779 |
+
attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
|
| 780 |
+
B, C, H, W = xt.shape
|
| 781 |
+
|
| 782 |
+
# Handle timestep scaling - DiT expects timesteps in [0, 999] range
|
| 783 |
+
# If t is normalized (in [0, 1]), scale it to [0, 999]
|
| 784 |
+
if t.max() <= 1.0 and t.min() >= 0.0:
|
| 785 |
+
# Normalized timesteps, scale to DiT range
|
| 786 |
+
t = t * 999.0
|
| 787 |
+
# Ensure t is in correct range for DiT
|
| 788 |
+
t = t.clamp(0, 999)
|
| 789 |
+
|
| 790 |
+
# Patchify
|
| 791 |
+
x = self.patch_embed(xt) # [B, hidden_size, H//p, W//p]
|
| 792 |
+
x = x.flatten(2).transpose(1, 2) # [B, num_patches, hidden_size]
|
| 793 |
+
x = x + self.pos_embed # Add positional embedding
|
| 794 |
+
|
| 795 |
+
# Prepare conditioning
|
| 796 |
+
time_emb = self.time_embed(t) # [B, hidden_size]
|
| 797 |
+
|
| 798 |
+
# Add class conditioning to time embedding if using class conditioning
|
| 799 |
+
if self.use_class and class_labels is not None:
|
| 800 |
+
class_emb = self.class_embed(class_labels) # [B, hidden_size]
|
| 801 |
+
time_emb = time_emb + class_emb # Additive combination following DiT paper
|
| 802 |
+
|
| 803 |
+
# Process conditioning based on AdaLN mode
|
| 804 |
+
if self.use_adaln_single:
|
| 805 |
+
# AdaLN-Single: compute global modulation once
|
| 806 |
+
c = self.t_block(time_emb) # [B, 6*hidden_size]
|
| 807 |
+
else:
|
| 808 |
+
# Original AdaLN: pass time embedding to each block
|
| 809 |
+
c = time_emb
|
| 810 |
+
|
| 811 |
+
# Prepare text tokens for cross-attention (not fused with time)
|
| 812 |
+
text_tokens = None
|
| 813 |
+
if self.use_text and text_embeds is not None:
|
| 814 |
+
if text_embeds.dim() == 3:
|
| 815 |
+
text_tokens = self.text_proj(text_embeds) # [B, T, hidden_size]
|
| 816 |
+
text_tokens = self.text_norm(text_tokens)
|
| 817 |
+
else:
|
| 818 |
+
text_tokens = self.text_proj(text_embeds).unsqueeze(1) # [B, 1, hidden_size]
|
| 819 |
+
text_tokens = self.text_norm(text_tokens)
|
| 820 |
+
|
| 821 |
+
if attention_mask is not None:
|
| 822 |
+
# cast to bool, clamp shapes to text_tokens length
|
| 823 |
+
attention_mask = attention_mask[:, :text_tokens.shape[1]].to(torch.bool)
|
| 824 |
+
# safety: avoid all-false rows (would yield NaNs in softmax)
|
| 825 |
+
all_false = attention_mask.sum(dim=1) == 0
|
| 826 |
+
if all_false.any():
|
| 827 |
+
attention_mask[all_false, 0] = True
|
| 828 |
+
|
| 829 |
+
# Apply transformer blocks
|
| 830 |
+
for layer in self.layers:
|
| 831 |
+
x = layer(x, c, text_tokens, attention_mask)
|
| 832 |
+
|
| 833 |
+
# Final projection
|
| 834 |
+
if self.use_adaln_single:
|
| 835 |
+
# T2IFinalLayer expects original time embedding, not global modulation
|
| 836 |
+
x = self.final_layer(x, time_emb) # [B, num_patches, patch_size^2 * out_channels]
|
| 837 |
+
else:
|
| 838 |
+
# Original FinalLayer expects conditioning
|
| 839 |
+
x = self.final_layer(x, c) # [B, num_patches, patch_size^2 * out_channels]
|
| 840 |
+
|
| 841 |
+
# Unpatchify
|
| 842 |
+
patch_h = patch_w = int(self.num_patches ** 0.5)
|
| 843 |
+
x = x.view(B, patch_h, patch_w, self.patch_size, self.patch_size, self.out_channels)
|
| 844 |
+
x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
|
| 845 |
+
x = x.view(B, self.out_channels, H, W)
|
| 846 |
+
|
| 847 |
+
return x
|
| 848 |
+
|
| 849 |
+
def compute_loss(self, x0: torch.Tensor, text_embeds: Optional[torch.Tensor] = None,
|
| 850 |
+
attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None,
|
| 851 |
+
null_text_embeds: Optional[torch.Tensor] = None, null_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 852 |
+
"""Unified loss computation based on objective type"""
|
| 853 |
+
if self.objective_type == "ddpm":
|
| 854 |
+
return self.ddpm_loss(x0, text_embeds, attention_mask, class_labels, null_text_embeds, null_attention_mask)
|
| 855 |
+
elif self.objective_type == "fm":
|
| 856 |
+
return self.flow_matching_loss(x0, text_embeds, attention_mask, class_labels, null_text_embeds, null_attention_mask)
|
| 857 |
+
elif self.objective_type == "rf":
|
| 858 |
+
return self.rectified_flow_loss(x0, text_embeds, attention_mask, class_labels, null_text_embeds, null_attention_mask)
|
| 859 |
+
else:
|
| 860 |
+
raise ValueError(f"Unknown objective type: {self.objective_type}")
|
| 861 |
+
|
| 862 |
+
def ddpm_loss(self, x0: torch.Tensor, text_embeds: Optional[torch.Tensor] = None,
|
| 863 |
+
attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None,
|
| 864 |
+
null_text_embeds: Optional[torch.Tensor] = None, null_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 865 |
+
"""DDPM: predict noise ε"""
|
| 866 |
+
B = x0.shape[0]
|
| 867 |
+
device = x0.device
|
| 868 |
+
|
| 869 |
+
# Sample time uniformly
|
| 870 |
+
t = torch.rand(B, device=device)
|
| 871 |
+
|
| 872 |
+
# Use proper schedule (NEW)
|
| 873 |
+
alpha_t, sigma_t = self.schedule.get_schedule(t)
|
| 874 |
+
|
| 875 |
+
noise = torch.randn_like(x0)
|
| 876 |
+
xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise
|
| 877 |
+
|
| 878 |
+
# Apply CFG dropout during training
|
| 879 |
+
if self.training and self.cfg_dropout_prob > 0:
|
| 880 |
+
if self.use_text and text_embeds is not None:
|
| 881 |
+
keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) # True = keep text
|
| 882 |
+
|
| 883 |
+
if null_text_embeds is not None:
|
| 884 |
+
# Use provided null text embeddings (from empty string CLIP encoding)
|
| 885 |
+
if null_text_embeds.shape[0] == 1:
|
| 886 |
+
null_text_embeds = null_text_embeds.expand(B, -1, -1)
|
| 887 |
+
|
| 888 |
+
# Replace dropped samples with null text embeddings
|
| 889 |
+
dropped = ~keep
|
| 890 |
+
if dropped.any():
|
| 891 |
+
text_embeds = text_embeds.clone()
|
| 892 |
+
text_embeds[dropped] = null_text_embeds[dropped]
|
| 893 |
+
|
| 894 |
+
# Use provided null attention mask or create default for empty string
|
| 895 |
+
if attention_mask is not None:
|
| 896 |
+
attention_mask = attention_mask.clone()
|
| 897 |
+
if null_attention_mask is not None:
|
| 898 |
+
if null_attention_mask.shape[0] == 1:
|
| 899 |
+
null_attention_mask = null_attention_mask.expand(B, -1)
|
| 900 |
+
attention_mask[dropped] = null_attention_mask[dropped]
|
| 901 |
+
else:
|
| 902 |
+
attention_mask[dropped] = 0
|
| 903 |
+
attention_mask[dropped, 0] = 1
|
| 904 |
+
else:
|
| 905 |
+
# Fallback to old zeroing approach if null_text_embeds not provided
|
| 906 |
+
if text_embeds.dim() == 3: # [B, T, D]
|
| 907 |
+
text_embeds = text_embeds * keep[:, None, None].to(text_embeds.dtype)
|
| 908 |
+
else: # [B, D]
|
| 909 |
+
text_embeds = text_embeds * keep[:, None].to(text_embeds.dtype)
|
| 910 |
+
|
| 911 |
+
if attention_mask is not None:
|
| 912 |
+
attention_mask = attention_mask.clone()
|
| 913 |
+
dropped = ~keep
|
| 914 |
+
if dropped.any():
|
| 915 |
+
attention_mask[dropped, 0] = 1
|
| 916 |
+
|
| 917 |
+
elif self.use_class and class_labels is not None:
|
| 918 |
+
# Apply CFG dropout to class labels using null class embedding
|
| 919 |
+
keep = (torch.rand(B, device=device) > self.cfg_dropout_prob)
|
| 920 |
+
null_class = torch.full_like(class_labels, self.null_class_id)
|
| 921 |
+
class_labels = torch.where(keep, class_labels, null_class)
|
| 922 |
+
|
| 923 |
+
# Predict noise
|
| 924 |
+
pred_eps = self.forward(xt, t, text_embeds, attention_mask, class_labels)
|
| 925 |
+
|
| 926 |
+
return F.mse_loss(pred_eps, noise)
|
| 927 |
+
|
| 928 |
+
def rectified_flow_loss(self, x0: torch.Tensor, text_embeds: Optional[torch.Tensor] = None,
|
| 929 |
+
attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None,
|
| 930 |
+
null_text_embeds: Optional[torch.Tensor] = None, null_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 931 |
+
"""Rectified Flow: predict velocity v = x_1 - x_0 (straight paths)"""
|
| 932 |
+
B = x0.shape[0]
|
| 933 |
+
device = x0.device
|
| 934 |
+
|
| 935 |
+
# Sample time uniformly
|
| 936 |
+
t = torch.rand(B, device=device)
|
| 937 |
+
|
| 938 |
+
# Straight-line interpolation
|
| 939 |
+
x1 = torch.randn_like(x0) # Gaussian noise as x_1
|
| 940 |
+
xt = (1 - t).view(-1, 1, 1, 1) * x0 + t.view(-1, 1, 1, 1) * x1
|
| 941 |
+
|
| 942 |
+
# Apply CFG dropout during training
|
| 943 |
+
if self.training and self.cfg_dropout_prob > 0:
|
| 944 |
+
if self.use_text and text_embeds is not None:
|
| 945 |
+
keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) # True = keep text
|
| 946 |
+
|
| 947 |
+
if null_text_embeds is not None:
|
| 948 |
+
# Use provided null text embeddings (from empty string CLIP encoding)
|
| 949 |
+
if null_text_embeds.shape[0] == 1:
|
| 950 |
+
null_text_embeds = null_text_embeds.expand(B, -1, -1)
|
| 951 |
+
|
| 952 |
+
# Replace dropped samples with null text embeddings
|
| 953 |
+
dropped = ~keep
|
| 954 |
+
if dropped.any():
|
| 955 |
+
text_embeds = text_embeds.clone()
|
| 956 |
+
text_embeds[dropped] = null_text_embeds[dropped]
|
| 957 |
+
|
| 958 |
+
# Use provided null attention mask or create default for empty string
|
| 959 |
+
if attention_mask is not None:
|
| 960 |
+
attention_mask = attention_mask.clone()
|
| 961 |
+
if null_attention_mask is not None:
|
| 962 |
+
if null_attention_mask.shape[0] == 1:
|
| 963 |
+
null_attention_mask = null_attention_mask.expand(B, -1)
|
| 964 |
+
attention_mask[dropped] = null_attention_mask[dropped]
|
| 965 |
+
else:
|
| 966 |
+
attention_mask[dropped] = 0
|
| 967 |
+
attention_mask[dropped, 0] = 1
|
| 968 |
+
else:
|
| 969 |
+
# Fallback to old zeroing approach if null_text_embeds not provided
|
| 970 |
+
if text_embeds.dim() == 3: # [B, T, D]
|
| 971 |
+
text_embeds = text_embeds * keep[:, None, None].to(text_embeds.dtype)
|
| 972 |
+
else: # [B, D]
|
| 973 |
+
text_embeds = text_embeds * keep[:, None].to(text_embeds.dtype)
|
| 974 |
+
|
| 975 |
+
if attention_mask is not None:
|
| 976 |
+
attention_mask = attention_mask.clone()
|
| 977 |
+
dropped = ~keep
|
| 978 |
+
if dropped.any():
|
| 979 |
+
attention_mask[dropped, 0] = 1
|
| 980 |
+
|
| 981 |
+
elif self.use_class and class_labels is not None:
|
| 982 |
+
# Apply CFG dropout to class labels using null class embedding
|
| 983 |
+
keep = (torch.rand(B, device=device) > self.cfg_dropout_prob)
|
| 984 |
+
null_class = torch.full_like(class_labels, self.null_class_id)
|
| 985 |
+
class_labels = torch.where(keep, class_labels, null_class)
|
| 986 |
+
|
| 987 |
+
# Predict velocity (x_1 - x_0)
|
| 988 |
+
pred_v = self.forward(xt, t, text_embeds, attention_mask, class_labels)
|
| 989 |
+
true_v = x1 - x0
|
| 990 |
+
|
| 991 |
+
return F.mse_loss(pred_v, true_v)
|
| 992 |
+
|
| 993 |
+
def flow_matching_loss(self, x0: torch.Tensor, text_embeds: Optional[torch.Tensor] = None,
|
| 994 |
+
attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None,
|
| 995 |
+
null_text_embeds: Optional[torch.Tensor] = None, null_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 996 |
+
"""Flow matching loss for latent space training with CFG dropout."""
|
| 997 |
+
B = x0.shape[0]
|
| 998 |
+
device = x0.device
|
| 999 |
+
|
| 1000 |
+
# Sample time uniformly
|
| 1001 |
+
t = torch.rand(B, device=device)
|
| 1002 |
+
|
| 1003 |
+
# Use proper schedule (NEW)
|
| 1004 |
+
alpha_t, sigma_t = self.schedule.get_schedule(t)
|
| 1005 |
+
|
| 1006 |
+
noise = torch.randn_like(x0)
|
| 1007 |
+
xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise
|
| 1008 |
+
|
| 1009 |
+
# Apply CFG dropout during training
|
| 1010 |
+
if self.training and self.cfg_dropout_prob > 0:
|
| 1011 |
+
if self.use_text and text_embeds is not None:
|
| 1012 |
+
keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) # True = keep text
|
| 1013 |
+
|
| 1014 |
+
if null_text_embeds is not None:
|
| 1015 |
+
# Use provided null text embeddings (from empty string CLIP encoding)
|
| 1016 |
+
# Ensure null_text_embeds matches the batch size
|
| 1017 |
+
if null_text_embeds.shape[0] == 1:
|
| 1018 |
+
null_text_embeds = null_text_embeds.expand(B, -1, -1)
|
| 1019 |
+
|
| 1020 |
+
# Replace dropped samples with null text embeddings
|
| 1021 |
+
dropped = ~keep
|
| 1022 |
+
if dropped.any():
|
| 1023 |
+
text_embeds = text_embeds.clone()
|
| 1024 |
+
text_embeds[dropped] = null_text_embeds[dropped]
|
| 1025 |
+
|
| 1026 |
+
# Use provided null attention mask or create default for empty string
|
| 1027 |
+
if attention_mask is not None:
|
| 1028 |
+
attention_mask = attention_mask.clone()
|
| 1029 |
+
if null_attention_mask is not None:
|
| 1030 |
+
# Ensure null_attention_mask matches batch size
|
| 1031 |
+
if null_attention_mask.shape[0] == 1:
|
| 1032 |
+
null_attention_mask = null_attention_mask.expand(B, -1)
|
| 1033 |
+
attention_mask[dropped] = null_attention_mask[dropped]
|
| 1034 |
+
else:
|
| 1035 |
+
# Default: For null text (empty string), typically only the first token is valid
|
| 1036 |
+
attention_mask[dropped] = 0
|
| 1037 |
+
attention_mask[dropped, 0] = 1 # Keep only first token for empty string
|
| 1038 |
+
else:
|
| 1039 |
+
# Fallback to old zeroing approach if null_text_embeds not provided
|
| 1040 |
+
if text_embeds.dim() == 3: # [B, T, D]
|
| 1041 |
+
text_embeds = text_embeds * keep[:, None, None].to(text_embeds.dtype)
|
| 1042 |
+
else: # [B, D]
|
| 1043 |
+
text_embeds = text_embeds * keep[:, None].to(text_embeds.dtype)
|
| 1044 |
+
|
| 1045 |
+
# Handle attention mask for fallback approach
|
| 1046 |
+
if attention_mask is not None:
|
| 1047 |
+
attention_mask = attention_mask.clone()
|
| 1048 |
+
dropped = ~keep
|
| 1049 |
+
if dropped.any():
|
| 1050 |
+
attention_mask[dropped, 0] = 1
|
| 1051 |
+
|
| 1052 |
+
elif self.use_class and class_labels is not None:
|
| 1053 |
+
# Apply CFG dropout to class labels using null class embedding
|
| 1054 |
+
keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) # True = keep class
|
| 1055 |
+
# Use the dedicated null class embedding for unconditional generation
|
| 1056 |
+
null_class = torch.full_like(class_labels, self.null_class_id)
|
| 1057 |
+
class_labels = torch.where(keep, class_labels, null_class)
|
| 1058 |
+
|
| 1059 |
+
# Predict velocity
|
| 1060 |
+
pred_v = self.forward(xt, t, text_embeds, attention_mask, class_labels)
|
| 1061 |
+
true_v = noise - x0
|
| 1062 |
+
|
| 1063 |
+
return F.mse_loss(pred_v, true_v)
|
| 1064 |
+
|
| 1065 |
+
# =============================================================================
|
| 1066 |
+
# ROUTER MODELS
|
| 1067 |
+
# =============================================================================
|
| 1068 |
+
|
| 1069 |
+
class ViTRouter(nn.Module):
|
| 1070 |
+
"""ViT-based router for cluster classification"""
|
| 1071 |
+
|
| 1072 |
+
def __init__(self, config) -> None:
|
| 1073 |
+
super().__init__()
|
| 1074 |
+
|
| 1075 |
+
# Default params
|
| 1076 |
+
default_params = {
|
| 1077 |
+
"hidden_size": 384,
|
| 1078 |
+
"num_layers": 6,
|
| 1079 |
+
"num_heads": 6,
|
| 1080 |
+
"patch_size": 8,
|
| 1081 |
+
"use_dit_time_embed": False, # Whether to use DiT-style time embedding
|
| 1082 |
+
}
|
| 1083 |
+
params = {**default_params, **config.router_params}
|
| 1084 |
+
|
| 1085 |
+
if config.router_pretrained:
|
| 1086 |
+
# Use pretrained ViT and adapt
|
| 1087 |
+
self.vit = ViTForImageClassification.from_pretrained(
|
| 1088 |
+
"google/vit-base-patch16-224"
|
| 1089 |
+
)
|
| 1090 |
+
self._adapt_pretrained(config, params)
|
| 1091 |
+
else:
|
| 1092 |
+
# Build from scratch
|
| 1093 |
+
vit_config = ViTConfig(
|
| 1094 |
+
image_size=config.image_size,
|
| 1095 |
+
patch_size=params["patch_size"],
|
| 1096 |
+
num_channels=config.num_channels,
|
| 1097 |
+
hidden_size=params["hidden_size"],
|
| 1098 |
+
num_hidden_layers=params["num_layers"],
|
| 1099 |
+
num_attention_heads=params["num_heads"],
|
| 1100 |
+
num_labels=config.num_clusters
|
| 1101 |
+
)
|
| 1102 |
+
self.vit = ViTForImageClassification(vit_config)
|
| 1103 |
+
|
| 1104 |
+
# Time conditioning - support both embedding styles
|
| 1105 |
+
self.use_dit_time_embed = params.get("use_dit_time_embed", False)
|
| 1106 |
+
if self.use_dit_time_embed:
|
| 1107 |
+
# Use DiT-style timestep embedding for consistency
|
| 1108 |
+
self.time_embedding = DiTTimestepEmbedder(params["hidden_size"])
|
| 1109 |
+
else:
|
| 1110 |
+
# Original simple time embedding
|
| 1111 |
+
self.time_embedding = nn.Sequential(
|
| 1112 |
+
nn.Linear(1, params["hidden_size"]),
|
| 1113 |
+
nn.SiLU(),
|
| 1114 |
+
nn.Linear(params["hidden_size"], params["hidden_size"])
|
| 1115 |
+
)
|
| 1116 |
+
|
| 1117 |
+
# Combined classifier
|
| 1118 |
+
self.classifier = nn.Sequential(
|
| 1119 |
+
nn.Linear(params["hidden_size"] * 2, params["hidden_size"]),
|
| 1120 |
+
nn.ReLU(),
|
| 1121 |
+
nn.Dropout(0.1),
|
| 1122 |
+
nn.Linear(params["hidden_size"], config.num_clusters)
|
| 1123 |
+
)
|
| 1124 |
+
|
| 1125 |
+
def _adapt_pretrained(self, config, params) -> ViTForImageClassification:
|
| 1126 |
+
"""Adapt pretrained ViT for our task"""
|
| 1127 |
+
# Modify patch embeddings if needed
|
| 1128 |
+
if config.image_size != 224 or config.num_channels != 3:
|
| 1129 |
+
self.vit.vit.embeddings.patch_embeddings.projection = nn.Conv2d(
|
| 1130 |
+
config.num_channels,
|
| 1131 |
+
self.vit.config.hidden_size,
|
| 1132 |
+
kernel_size=params["patch_size"],
|
| 1133 |
+
stride=params["patch_size"]
|
| 1134 |
+
)
|
| 1135 |
+
|
| 1136 |
+
def forward(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
| 1137 |
+
# Process image through ViT
|
| 1138 |
+
vit_outputs = self.vit.vit(xt)
|
| 1139 |
+
image_features = vit_outputs.last_hidden_state[:, 0] # CLS token
|
| 1140 |
+
|
| 1141 |
+
# Time conditioning
|
| 1142 |
+
if self.use_dit_time_embed:
|
| 1143 |
+
# DiT embedder expects raw timesteps
|
| 1144 |
+
time_features = self.time_embedding(t)
|
| 1145 |
+
else:
|
| 1146 |
+
# Original embedding needs unsqueeze
|
| 1147 |
+
time_features = self.time_embedding(t.unsqueeze(-1))
|
| 1148 |
+
|
| 1149 |
+
# Combine and classify
|
| 1150 |
+
combined = torch.cat([image_features, time_features], dim=1)
|
| 1151 |
+
return self.classifier(combined)
|
| 1152 |
+
|
| 1153 |
+
class CNNRouter(nn.Module):
|
| 1154 |
+
"""Simple CNN router for cluster classification"""
|
| 1155 |
+
|
| 1156 |
+
def __init__(self, config) -> None:
|
| 1157 |
+
super().__init__()
|
| 1158 |
+
|
| 1159 |
+
# Default params
|
| 1160 |
+
default_params = {
|
| 1161 |
+
"hidden_dims": [64, 128, 256],
|
| 1162 |
+
"use_dit_time_embed": False, # Whether to use DiT-style time embedding
|
| 1163 |
+
}
|
| 1164 |
+
params = {**default_params, **config.router_params}
|
| 1165 |
+
|
| 1166 |
+
# CNN backbone
|
| 1167 |
+
self.backbone = self._build_cnn(config.num_channels, params["hidden_dims"])
|
| 1168 |
+
|
| 1169 |
+
# Time embedding - support both styles
|
| 1170 |
+
self.use_dit_time_embed = params.get("use_dit_time_embed", False)
|
| 1171 |
+
if self.use_dit_time_embed:
|
| 1172 |
+
# Use DiT-style timestep embedding, output to 128 dims for CNN
|
| 1173 |
+
self.time_embedding = DiTTimestepEmbedder(128)
|
| 1174 |
+
else:
|
| 1175 |
+
# Original simple time embedding
|
| 1176 |
+
self.time_embedding = nn.Sequential(
|
| 1177 |
+
nn.Linear(1, 128),
|
| 1178 |
+
nn.SiLU(),
|
| 1179 |
+
nn.Linear(128, 128)
|
| 1180 |
+
)
|
| 1181 |
+
|
| 1182 |
+
# Classifier
|
| 1183 |
+
self.classifier = nn.Sequential(
|
| 1184 |
+
nn.Linear(params["hidden_dims"][-1] + 128, 256),
|
| 1185 |
+
nn.ReLU(),
|
| 1186 |
+
nn.Dropout(0.1),
|
| 1187 |
+
nn.Linear(256, config.num_clusters)
|
| 1188 |
+
)
|
| 1189 |
+
|
| 1190 |
+
def _build_cnn(self, in_channels: int, hidden_dims: List[int]) -> nn.Sequential:
|
| 1191 |
+
layers = []
|
| 1192 |
+
prev_dim = in_channels
|
| 1193 |
+
|
| 1194 |
+
for dim in hidden_dims:
|
| 1195 |
+
layers.extend([
|
| 1196 |
+
nn.Conv2d(prev_dim, dim, 3, padding=1),
|
| 1197 |
+
nn.ReLU(),
|
| 1198 |
+
nn.Conv2d(dim, dim, 3, padding=1),
|
| 1199 |
+
nn.ReLU(),
|
| 1200 |
+
nn.MaxPool2d(2)
|
| 1201 |
+
])
|
| 1202 |
+
prev_dim = dim
|
| 1203 |
+
|
| 1204 |
+
layers.append(nn.AdaptiveAvgPool2d(1))
|
| 1205 |
+
layers.append(nn.Flatten())
|
| 1206 |
+
|
| 1207 |
+
return nn.Sequential(*layers)
|
| 1208 |
+
|
| 1209 |
+
def forward(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
| 1210 |
+
# CNN features
|
| 1211 |
+
img_features = self.backbone(xt)
|
| 1212 |
+
|
| 1213 |
+
# Time features
|
| 1214 |
+
if self.use_dit_time_embed:
|
| 1215 |
+
# DiT embedder expects raw timesteps
|
| 1216 |
+
time_features = self.time_embedding(t)
|
| 1217 |
+
else:
|
| 1218 |
+
# Original embedding needs unsqueeze
|
| 1219 |
+
time_features = self.time_embedding(t.unsqueeze(-1))
|
| 1220 |
+
|
| 1221 |
+
# Combine and classify
|
| 1222 |
+
combined = torch.cat([img_features, time_features], dim=1)
|
| 1223 |
+
return self.classifier(combined)
|
| 1224 |
+
|
| 1225 |
+
class DiTRouter(nn.Module):
|
| 1226 |
+
"""DiT B/2 router for cluster classification"""
|
| 1227 |
+
|
| 1228 |
+
def __init__(self, config):
|
| 1229 |
+
super().__init__()
|
| 1230 |
+
|
| 1231 |
+
# DiT B/2 specifications
|
| 1232 |
+
default_params = {
|
| 1233 |
+
"hidden_size": 768, # DiT-B uses 768
|
| 1234 |
+
"num_layers": 12, # DiT-B uses 12 layers
|
| 1235 |
+
"num_heads": 12, # DiT-B uses 12 heads
|
| 1236 |
+
"patch_size": 2, # For latent space (32x32 -> 16x16 patches)
|
| 1237 |
+
"in_channels": 4, # VAE latent channels
|
| 1238 |
+
"mlp_ratio": 4.0,
|
| 1239 |
+
"use_dit_time_embed": False, # Whether to use DiT-style time embedding
|
| 1240 |
+
}
|
| 1241 |
+
params = {**default_params, **config.router_params}
|
| 1242 |
+
|
| 1243 |
+
self.patch_size = params["patch_size"]
|
| 1244 |
+
self.in_channels = params["in_channels"]
|
| 1245 |
+
self.hidden_size = params["hidden_size"]
|
| 1246 |
+
self.num_heads = params["num_heads"]
|
| 1247 |
+
self.num_clusters = config.num_clusters
|
| 1248 |
+
|
| 1249 |
+
# Patch embedding (same as expert)
|
| 1250 |
+
self.patch_embed = nn.Conv2d(
|
| 1251 |
+
self.in_channels, self.hidden_size,
|
| 1252 |
+
kernel_size=self.patch_size, stride=self.patch_size
|
| 1253 |
+
)
|
| 1254 |
+
|
| 1255 |
+
# Calculate number of patches
|
| 1256 |
+
latent_size = getattr(config, 'image_size', 32) # Assuming 256/8=32 for VAE
|
| 1257 |
+
self.num_patches = (latent_size // self.patch_size) ** 2
|
| 1258 |
+
|
| 1259 |
+
# Fixed sin-cos positional embedding (same as expert)
|
| 1260 |
+
self.pos_embed = nn.Parameter(
|
| 1261 |
+
torch.zeros(1, self.num_patches, self.hidden_size),
|
| 1262 |
+
requires_grad=False
|
| 1263 |
+
)
|
| 1264 |
+
|
| 1265 |
+
# CLS token (KEY ADDITION from paper)
|
| 1266 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
|
| 1267 |
+
|
| 1268 |
+
# Time embedding - match expert's choice
|
| 1269 |
+
self.use_dit_time_embed = params.get("use_dit_time_embed", False)
|
| 1270 |
+
if self.use_dit_time_embed:
|
| 1271 |
+
self.time_embed = DiTTimestepEmbedder(self.hidden_size)
|
| 1272 |
+
else:
|
| 1273 |
+
self.time_embed = TimestepEmbedder(self.hidden_size)
|
| 1274 |
+
|
| 1275 |
+
# DiT blocks with AdaLN (reuse DiTBlock from expert)
|
| 1276 |
+
# Note: Router doesn't need text conditioning
|
| 1277 |
+
self.layers = nn.ModuleList([
|
| 1278 |
+
DiTBlock(self.hidden_size, self.num_heads, params["mlp_ratio"], use_text=False)
|
| 1279 |
+
for _ in range(params["num_layers"])
|
| 1280 |
+
])
|
| 1281 |
+
|
| 1282 |
+
# Final layer norm
|
| 1283 |
+
self.norm_final = nn.LayerNorm(self.hidden_size, elementwise_affine=False, eps=1e-6)
|
| 1284 |
+
|
| 1285 |
+
# Linear classifier on CLS token (as specified in paper)
|
| 1286 |
+
# self.head = nn.Linear(self.hidden_size, self.num_clusters)
|
| 1287 |
+
self.head = nn.Sequential(
|
| 1288 |
+
nn.Linear(self.hidden_size, self.hidden_size),
|
| 1289 |
+
nn.GELU(),
|
| 1290 |
+
nn.LayerNorm(self.hidden_size),
|
| 1291 |
+
nn.Dropout(0.1),
|
| 1292 |
+
nn.Linear(self.hidden_size, self.num_clusters)
|
| 1293 |
+
)
|
| 1294 |
+
|
| 1295 |
+
# Initialize weights
|
| 1296 |
+
self.initialize_weights()
|
| 1297 |
+
|
| 1298 |
+
def initialize_weights(self):
|
| 1299 |
+
# Initialize transformer layers
|
| 1300 |
+
def _basic_init(module):
|
| 1301 |
+
if isinstance(module, nn.Linear):
|
| 1302 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 1303 |
+
if module.bias is not None:
|
| 1304 |
+
nn.init.constant_(module.bias, 0)
|
| 1305 |
+
self.apply(_basic_init)
|
| 1306 |
+
|
| 1307 |
+
# Initialize CLS token
|
| 1308 |
+
nn.init.normal_(self.cls_token, std=0.02)
|
| 1309 |
+
|
| 1310 |
+
# Initialize positional embedding with sin-cos (same as expert)
|
| 1311 |
+
grid_size = int(self.num_patches ** 0.5)
|
| 1312 |
+
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_size)
|
| 1313 |
+
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
| 1314 |
+
|
| 1315 |
+
# Initialize patch_embed like nn.Linear
|
| 1316 |
+
w = self.patch_embed.weight.data
|
| 1317 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
| 1318 |
+
if self.patch_embed.bias is not None:
|
| 1319 |
+
nn.init.constant_(self.patch_embed.bias, 0)
|
| 1320 |
+
|
| 1321 |
+
# Initialize timestep embedding MLP
|
| 1322 |
+
if hasattr(self.time_embed, 'mlp'):
|
| 1323 |
+
nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
|
| 1324 |
+
nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
|
| 1325 |
+
|
| 1326 |
+
# Zero-out adaLN modulation in blocks (following expert initialization)
|
| 1327 |
+
for block in self.layers:
|
| 1328 |
+
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
| 1329 |
+
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
| 1330 |
+
|
| 1331 |
+
# # Initialize classification head (simpler version for classification head)
|
| 1332 |
+
# nn.init.constant_(self.head.weight, 0)
|
| 1333 |
+
# nn.init.constant_(self.head.bias, 0)
|
| 1334 |
+
|
| 1335 |
+
# Initialize classification head (Sequential)
|
| 1336 |
+
# Initialize intermediate layers normally, zero-out final layer
|
| 1337 |
+
nn.init.normal_(self.head[0].weight, std=0.02) # First linear layer
|
| 1338 |
+
if self.head[0].bias is not None:
|
| 1339 |
+
nn.init.constant_(self.head[0].bias, 0)
|
| 1340 |
+
|
| 1341 |
+
# Zero-out final classification layer (following DiT paper)
|
| 1342 |
+
nn.init.constant_(self.head[-1].weight, 0) # Last linear layer
|
| 1343 |
+
if self.head[-1].bias is not None:
|
| 1344 |
+
nn.init.constant_(self.head[-1].bias, 0)
|
| 1345 |
+
|
| 1346 |
+
def forward(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
| 1347 |
+
B, C, H, W = xt.shape
|
| 1348 |
+
|
| 1349 |
+
# Match expert's timestep interpretation
|
| 1350 |
+
if t.max() <= 1.0 and t.min() >= 0.0:
|
| 1351 |
+
t = t * 999.0
|
| 1352 |
+
t = t.clamp(0, 999)
|
| 1353 |
+
|
| 1354 |
+
# Patchify
|
| 1355 |
+
x = self.patch_embed(xt) # [B, hidden_size, H//p, W//p]
|
| 1356 |
+
x = x.flatten(2).transpose(1, 2) # [B, num_patches, hidden_size]
|
| 1357 |
+
|
| 1358 |
+
# Add positional embedding
|
| 1359 |
+
x = x + self.pos_embed
|
| 1360 |
+
|
| 1361 |
+
# Prepend CLS token
|
| 1362 |
+
cls_tokens = self.cls_token.expand(B, -1, -1) # [B, 1, hidden_size]
|
| 1363 |
+
x = torch.cat([cls_tokens, x], dim=1) # [B, 1 + num_patches, hidden_size]
|
| 1364 |
+
|
| 1365 |
+
# Time conditioning
|
| 1366 |
+
c = self.time_embed(t) # [B, hidden_size]
|
| 1367 |
+
|
| 1368 |
+
# Apply DiT blocks with AdaLN modulation
|
| 1369 |
+
for layer in self.layers:
|
| 1370 |
+
x = layer(x, c, text_emb=None)
|
| 1371 |
+
|
| 1372 |
+
# Extract CLS token and apply final norm
|
| 1373 |
+
cls_output = x[:, 0] # [B, hidden_size]
|
| 1374 |
+
cls_output = self.norm_final(cls_output)
|
| 1375 |
+
|
| 1376 |
+
# Linear classification head
|
| 1377 |
+
logits = self.head(cls_output) # [B, num_clusters]
|
| 1378 |
+
|
| 1379 |
+
return logits
|
| 1380 |
+
|
| 1381 |
+
# =============================================================================
|
| 1382 |
+
# DETERMINISTIC ROUTER (for controlled experiments)
|
| 1383 |
+
# =============================================================================
|
| 1384 |
+
|
| 1385 |
+
class DeterministicTimestepRouter(nn.Module):
|
| 1386 |
+
"""
|
| 1387 |
+
Deterministic router that assigns experts based on timestep.
|
| 1388 |
+
|
| 1389 |
+
Useful for controlled experiments where you want to test specific routing strategies,
|
| 1390 |
+
such as: "high noise → DDPM expert, low noise → FM expert"
|
| 1391 |
+
|
| 1392 |
+
Args:
|
| 1393 |
+
config: Config object with router_params containing:
|
| 1394 |
+
- timestep_threshold: t value to switch experts (default: 0.5)
|
| 1395 |
+
- high_noise_expert: Expert ID for t > threshold (default: 0, typically DDPM)
|
| 1396 |
+
- low_noise_expert: Expert ID for t <= threshold (default: 1, typically FM)
|
| 1397 |
+
|
| 1398 |
+
Example config:
|
| 1399 |
+
router_architecture: "deterministic_timestep"
|
| 1400 |
+
router_params:
|
| 1401 |
+
timestep_threshold: 0.5
|
| 1402 |
+
high_noise_expert: 0 # DDPM for high noise
|
| 1403 |
+
low_noise_expert: 1 # FM for low noise
|
| 1404 |
+
"""
|
| 1405 |
+
|
| 1406 |
+
def __init__(self, config):
|
| 1407 |
+
super().__init__()
|
| 1408 |
+
self.num_experts = config.num_experts
|
| 1409 |
+
self.threshold = config.router_params.get('timestep_threshold', 0.5)
|
| 1410 |
+
self.high_noise_expert = config.router_params.get('high_noise_expert', 0)
|
| 1411 |
+
self.low_noise_expert = config.router_params.get('low_noise_expert', 1)
|
| 1412 |
+
|
| 1413 |
+
# Validate expert IDs
|
| 1414 |
+
assert 0 <= self.high_noise_expert < self.num_experts, \
|
| 1415 |
+
f"high_noise_expert {self.high_noise_expert} out of range [0, {self.num_experts})"
|
| 1416 |
+
assert 0 <= self.low_noise_expert < self.num_experts, \
|
| 1417 |
+
f"low_noise_expert {self.low_noise_expert} out of range [0, {self.num_experts})"
|
| 1418 |
+
|
| 1419 |
+
# Validate threshold
|
| 1420 |
+
assert 0.0 <= self.threshold <= 1.0, \
|
| 1421 |
+
f"timestep_threshold {self.threshold} must be in [0, 1]"
|
| 1422 |
+
|
| 1423 |
+
# This router has no trainable parameters
|
| 1424 |
+
# Register threshold as buffer (not trained, but saved with model)
|
| 1425 |
+
self.register_buffer('_threshold', torch.tensor(self.threshold))
|
| 1426 |
+
|
| 1427 |
+
print(f"DeterministicTimestepRouter initialized:")
|
| 1428 |
+
print(f" Threshold: {self.threshold}")
|
| 1429 |
+
print(f" High noise (t > {self.threshold}) → Expert {self.high_noise_expert}")
|
| 1430 |
+
print(f" Low noise (t <= {self.threshold}) → Expert {self.low_noise_expert}")
|
| 1431 |
+
|
| 1432 |
+
def forward(self, x: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 1433 |
+
"""
|
| 1434 |
+
Returns one-hot routing probabilities based on timestep.
|
| 1435 |
+
|
| 1436 |
+
Args:
|
| 1437 |
+
x: Input tensor (unused, but kept for API compatibility with other routers)
|
| 1438 |
+
t: Timesteps, shape (B,)
|
| 1439 |
+
|
| 1440 |
+
Returns:
|
| 1441 |
+
routing_probs: Shape (B, num_experts), one-hot encoded
|
| 1442 |
+
"""
|
| 1443 |
+
B = t.shape[0]
|
| 1444 |
+
device = t.device
|
| 1445 |
+
|
| 1446 |
+
# Initialize routing probabilities (all zeros)
|
| 1447 |
+
routing_probs = torch.zeros(B, self.num_experts, device=device)
|
| 1448 |
+
|
| 1449 |
+
# High noise (t > threshold) → high_noise_expert
|
| 1450 |
+
# Low noise (t <= threshold) → low_noise_expert
|
| 1451 |
+
high_noise_mask = t > self.threshold
|
| 1452 |
+
routing_probs[high_noise_mask, self.high_noise_expert] = 1.0
|
| 1453 |
+
routing_probs[~high_noise_mask, self.low_noise_expert] = 1.0
|
| 1454 |
+
|
| 1455 |
+
return routing_probs
|
| 1456 |
+
|
| 1457 |
+
def train(self, mode: bool = True):
|
| 1458 |
+
"""Override train() - this router is never trained, always in eval mode"""
|
| 1459 |
+
return super(DeterministicTimestepRouter, self).train(False)
|
| 1460 |
+
|
| 1461 |
+
# =============================================================================
|
| 1462 |
+
# ADAPTIVE VIDEO ROUTER (for Video DDM)
|
| 1463 |
+
# =============================================================================
|
| 1464 |
+
|
| 1465 |
+
class AdaptiveVideoRouter(nn.Module):
|
| 1466 |
+
"""
|
| 1467 |
+
Time-adaptive router for video DDM.
|
| 1468 |
+
|
| 1469 |
+
Key innovation: Learns optimal weighting of information sources
|
| 1470 |
+
at each noise level, solving the "motion invisible at t=1" problem.
|
| 1471 |
+
|
| 1472 |
+
Information availability is time-dependent:
|
| 1473 |
+
t ~ 1.0: Only text/first_frame informative → Route on conditioning
|
| 1474 |
+
t ~ 0.5: Structure emerging → Latent becomes useful
|
| 1475 |
+
t ~ 0.1: Near clean → Full information available
|
| 1476 |
+
|
| 1477 |
+
Expected learned behavior:
|
| 1478 |
+
| Noise Level | Text | Frame | Latent | Behavior |
|
| 1479 |
+
|-------------|------|-------|--------|-----------------------------|
|
| 1480 |
+
| t ~ 1.0 | ~0.7 | ~0.2 | ~0.1 | Routes on text semantics |
|
| 1481 |
+
| t ~ 0.5 | ~0.4 | ~0.3 | ~0.3 | Balanced; emerging structure|
|
| 1482 |
+
| t ~ 0.1 | ~0.2 | ~0.2 | ~0.6 | Trusts latent; fine-grained |
|
| 1483 |
+
|
| 1484 |
+
Enhancements:
|
| 1485 |
+
- Masked mean pooling for text (handles variable-length prompts)
|
| 1486 |
+
- Temporal-aware latent encoder (captures motion patterns)
|
| 1487 |
+
- Temperature scaling for inference control
|
| 1488 |
+
"""
|
| 1489 |
+
|
| 1490 |
+
def __init__(self, config):
|
| 1491 |
+
super().__init__()
|
| 1492 |
+
|
| 1493 |
+
# Default params
|
| 1494 |
+
default_params = {
|
| 1495 |
+
"hidden_dim": 512,
|
| 1496 |
+
"text_embed_dim": 768, # CLIP-L text embedding dimension
|
| 1497 |
+
"frame_embed_dim": 768, # DINOv2-B (base) feature dimension
|
| 1498 |
+
"latent_channels": 16, # VAE latent channels (CogVideoX uses 16)
|
| 1499 |
+
"latent_conv_dim": 64, # Intermediate conv channels for latent encoder
|
| 1500 |
+
"dropout": 0.1,
|
| 1501 |
+
"temporal_pool_mode": "attention", # "attention", "avg", or "max"
|
| 1502 |
+
"normalize_inputs": True, # L2-normalize text/frame inputs (match clustering)
|
| 1503 |
+
}
|
| 1504 |
+
params = {**default_params, **getattr(config, 'router_params', {})}
|
| 1505 |
+
|
| 1506 |
+
self.hidden_dim = params["hidden_dim"]
|
| 1507 |
+
self.num_experts = getattr(config, 'num_experts', config.num_clusters)
|
| 1508 |
+
self.latent_channels = params["latent_channels"]
|
| 1509 |
+
self.latent_conv_dim = params["latent_conv_dim"]
|
| 1510 |
+
self.temporal_pool_mode = params["temporal_pool_mode"]
|
| 1511 |
+
self.normalize_inputs = params.get("normalize_inputs", True)
|
| 1512 |
+
|
| 1513 |
+
# === Information Source Encoders ===
|
| 1514 |
+
|
| 1515 |
+
# Text pathway (always available, primary signal at high t)
|
| 1516 |
+
self.text_encoder = nn.Sequential(
|
| 1517 |
+
nn.Linear(params["text_embed_dim"], self.hidden_dim),
|
| 1518 |
+
nn.LayerNorm(self.hidden_dim),
|
| 1519 |
+
nn.GELU(),
|
| 1520 |
+
nn.Linear(self.hidden_dim, self.hidden_dim)
|
| 1521 |
+
)
|
| 1522 |
+
|
| 1523 |
+
# First frame pathway (available for I2V tasks)
|
| 1524 |
+
# Uses DINOv2 features extracted from the first frame
|
| 1525 |
+
self.frame_encoder = nn.Sequential(
|
| 1526 |
+
nn.Linear(params["frame_embed_dim"], self.hidden_dim),
|
| 1527 |
+
nn.LayerNorm(self.hidden_dim),
|
| 1528 |
+
nn.GELU(),
|
| 1529 |
+
nn.Linear(self.hidden_dim, self.hidden_dim)
|
| 1530 |
+
)
|
| 1531 |
+
|
| 1532 |
+
# === Temporal-Aware Latent Encoder ===
|
| 1533 |
+
# Captures both spatial content and temporal motion patterns
|
| 1534 |
+
|
| 1535 |
+
# Spatial feature extraction (per-frame)
|
| 1536 |
+
self.spatial_conv = nn.Sequential(
|
| 1537 |
+
nn.Conv3d(params["latent_channels"], params["latent_conv_dim"],
|
| 1538 |
+
kernel_size=(1, 3, 3), padding=(0, 1, 1)), # Spatial only
|
| 1539 |
+
nn.GroupNorm(8, params["latent_conv_dim"]),
|
| 1540 |
+
nn.GELU(),
|
| 1541 |
+
)
|
| 1542 |
+
|
| 1543 |
+
# Temporal feature extraction (motion patterns)
|
| 1544 |
+
self.temporal_conv = nn.Sequential(
|
| 1545 |
+
nn.Conv3d(params["latent_conv_dim"], params["latent_conv_dim"],
|
| 1546 |
+
kernel_size=(3, 1, 1), padding=(1, 0, 0)), # Temporal only
|
| 1547 |
+
nn.GroupNorm(8, params["latent_conv_dim"]),
|
| 1548 |
+
nn.GELU(),
|
| 1549 |
+
)
|
| 1550 |
+
|
| 1551 |
+
# Combined spatio-temporal processing
|
| 1552 |
+
self.st_conv = nn.Sequential(
|
| 1553 |
+
nn.Conv3d(params["latent_conv_dim"], params["latent_conv_dim"],
|
| 1554 |
+
kernel_size=3, padding=1), # Full 3D
|
| 1555 |
+
nn.GroupNorm(8, params["latent_conv_dim"]),
|
| 1556 |
+
nn.GELU(),
|
| 1557 |
+
)
|
| 1558 |
+
|
| 1559 |
+
# Spatial pooling (keep temporal dimension)
|
| 1560 |
+
self.spatial_pool = nn.AdaptiveAvgPool3d((None, 1, 1)) # [B, C, T, 1, 1]
|
| 1561 |
+
|
| 1562 |
+
# Temporal attention pooling (learns which frames matter for routing)
|
| 1563 |
+
if self.temporal_pool_mode == "attention":
|
| 1564 |
+
self.temporal_attn = nn.Sequential(
|
| 1565 |
+
nn.Linear(params["latent_conv_dim"], params["latent_conv_dim"] // 4),
|
| 1566 |
+
nn.Tanh(),
|
| 1567 |
+
nn.Linear(params["latent_conv_dim"] // 4, 1),
|
| 1568 |
+
)
|
| 1569 |
+
|
| 1570 |
+
# Motion feature extractor (frame differences)
|
| 1571 |
+
self.motion_encoder = nn.Sequential(
|
| 1572 |
+
nn.Linear(params["latent_conv_dim"], params["latent_conv_dim"]),
|
| 1573 |
+
nn.GELU(),
|
| 1574 |
+
nn.Linear(params["latent_conv_dim"], self.hidden_dim // 2),
|
| 1575 |
+
)
|
| 1576 |
+
|
| 1577 |
+
# Content feature projector
|
| 1578 |
+
self.content_proj = nn.Linear(params["latent_conv_dim"], self.hidden_dim // 2)
|
| 1579 |
+
|
| 1580 |
+
# Final latent projection (combines content + motion)
|
| 1581 |
+
self.latent_proj = nn.Sequential(
|
| 1582 |
+
nn.Linear(self.hidden_dim, self.hidden_dim),
|
| 1583 |
+
nn.LayerNorm(self.hidden_dim),
|
| 1584 |
+
)
|
| 1585 |
+
|
| 1586 |
+
# === Time-Dependent Weighting ===
|
| 1587 |
+
|
| 1588 |
+
# Time embedding using existing infrastructure
|
| 1589 |
+
self.time_embed = TimestepEmbedder(self.hidden_dim)
|
| 1590 |
+
|
| 1591 |
+
self.time_mlp = nn.Sequential(
|
| 1592 |
+
nn.Linear(self.hidden_dim, self.hidden_dim),
|
| 1593 |
+
nn.GELU(),
|
| 1594 |
+
nn.Linear(self.hidden_dim, self.hidden_dim)
|
| 1595 |
+
)
|
| 1596 |
+
|
| 1597 |
+
# Learns adaptive weighting: at high t → trust text; at low t → trust latent
|
| 1598 |
+
self.source_weighting = nn.Sequential(
|
| 1599 |
+
nn.Linear(self.hidden_dim, 128),
|
| 1600 |
+
nn.GELU(),
|
| 1601 |
+
nn.Linear(128, 3), # [text, frame, latent] weights
|
| 1602 |
+
nn.Softmax(dim=-1)
|
| 1603 |
+
)
|
| 1604 |
+
|
| 1605 |
+
# === Routing Head ===
|
| 1606 |
+
|
| 1607 |
+
self.router_head = nn.Sequential(
|
| 1608 |
+
nn.Linear(self.hidden_dim, self.hidden_dim),
|
| 1609 |
+
nn.GELU(),
|
| 1610 |
+
nn.LayerNorm(self.hidden_dim),
|
| 1611 |
+
nn.Dropout(params["dropout"]),
|
| 1612 |
+
nn.Linear(self.hidden_dim, self.num_experts)
|
| 1613 |
+
)
|
| 1614 |
+
|
| 1615 |
+
# Initialize weights
|
| 1616 |
+
self.initialize_weights()
|
| 1617 |
+
|
| 1618 |
+
def initialize_weights(self):
|
| 1619 |
+
"""Initialize weights following DiT conventions."""
|
| 1620 |
+
def _basic_init(module):
|
| 1621 |
+
if isinstance(module, nn.Linear):
|
| 1622 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 1623 |
+
if module.bias is not None:
|
| 1624 |
+
nn.init.constant_(module.bias, 0)
|
| 1625 |
+
elif isinstance(module, nn.Conv3d):
|
| 1626 |
+
# Flatten spatial dims for xavier init
|
| 1627 |
+
w = module.weight.data
|
| 1628 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
| 1629 |
+
if module.bias is not None:
|
| 1630 |
+
nn.init.constant_(module.bias, 0)
|
| 1631 |
+
self.apply(_basic_init)
|
| 1632 |
+
|
| 1633 |
+
# Initialize timestep embedding MLP (following DiT)
|
| 1634 |
+
if hasattr(self.time_embed, 'mlp'):
|
| 1635 |
+
nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
|
| 1636 |
+
nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
|
| 1637 |
+
|
| 1638 |
+
# Small non-zero initialization for final routing layer
|
| 1639 |
+
# (pure zeros cause uniform outputs that break temperature scaling)
|
| 1640 |
+
nn.init.normal_(self.router_head[-1].weight, std=0.01)
|
| 1641 |
+
nn.init.constant_(self.router_head[-1].bias, 0)
|
| 1642 |
+
|
| 1643 |
+
# Initialize source weighting to start roughly uniform
|
| 1644 |
+
# The softmax will make [0, 0, 0] → [0.33, 0.33, 0.33]
|
| 1645 |
+
nn.init.constant_(self.source_weighting[-2].weight, 0)
|
| 1646 |
+
nn.init.constant_(self.source_weighting[-2].bias, 0)
|
| 1647 |
+
|
| 1648 |
+
# Initialize temporal attention to uniform attention
|
| 1649 |
+
if self.temporal_pool_mode == "attention":
|
| 1650 |
+
nn.init.constant_(self.temporal_attn[-1].weight, 0)
|
| 1651 |
+
nn.init.constant_(self.temporal_attn[-1].bias, 0)
|
| 1652 |
+
|
| 1653 |
+
def _masked_mean_pool(self, embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 1654 |
+
"""
|
| 1655 |
+
Compute masked mean pooling over sequence dimension.
|
| 1656 |
+
|
| 1657 |
+
Args:
|
| 1658 |
+
embeddings: [B, seq_len, embed_dim] - Token embeddings
|
| 1659 |
+
attention_mask: [B, seq_len] - 1 for real tokens, 0 for padding
|
| 1660 |
+
|
| 1661 |
+
Returns:
|
| 1662 |
+
pooled: [B, embed_dim] - Pooled representation
|
| 1663 |
+
"""
|
| 1664 |
+
if attention_mask is None:
|
| 1665 |
+
# No mask provided, use simple mean
|
| 1666 |
+
return embeddings.mean(dim=1)
|
| 1667 |
+
|
| 1668 |
+
# Expand mask for broadcasting: [B, seq_len] -> [B, seq_len, 1]
|
| 1669 |
+
mask = attention_mask.unsqueeze(-1).to(embeddings.dtype)
|
| 1670 |
+
|
| 1671 |
+
# Masked sum
|
| 1672 |
+
masked_sum = (embeddings * mask).sum(dim=1) # [B, embed_dim]
|
| 1673 |
+
|
| 1674 |
+
# Count of valid tokens (avoid division by zero)
|
| 1675 |
+
token_counts = mask.sum(dim=1).clamp(min=1.0) # [B, 1]
|
| 1676 |
+
|
| 1677 |
+
return masked_sum / token_counts
|
| 1678 |
+
|
| 1679 |
+
def _encode_latent_temporal(self, x_t: torch.Tensor) -> torch.Tensor:
|
| 1680 |
+
"""
|
| 1681 |
+
Encode video latent with temporal awareness.
|
| 1682 |
+
|
| 1683 |
+
Extracts both:
|
| 1684 |
+
- Content features: What is in the video (spatial)
|
| 1685 |
+
- Motion features: How things move (temporal differences)
|
| 1686 |
+
|
| 1687 |
+
Args:
|
| 1688 |
+
x_t: [B, C, T, H, W] - Noisy video latent
|
| 1689 |
+
|
| 1690 |
+
Returns:
|
| 1691 |
+
latent_feat: [B, hidden_dim] - Combined latent features
|
| 1692 |
+
"""
|
| 1693 |
+
B, C, T, H, W = x_t.shape
|
| 1694 |
+
|
| 1695 |
+
# 1. Spatial feature extraction
|
| 1696 |
+
spatial_feat = self.spatial_conv(x_t) # [B, conv_dim, T, H, W]
|
| 1697 |
+
|
| 1698 |
+
# 2. Temporal feature extraction (captures local motion)
|
| 1699 |
+
temporal_feat = self.temporal_conv(spatial_feat) # [B, conv_dim, T, H, W]
|
| 1700 |
+
|
| 1701 |
+
# 3. Combined spatio-temporal processing
|
| 1702 |
+
st_feat = self.st_conv(temporal_feat) # [B, conv_dim, T, H, W]
|
| 1703 |
+
|
| 1704 |
+
# 4. Pool spatially, keep temporal: [B, conv_dim, T, 1, 1] -> [B, T, conv_dim]
|
| 1705 |
+
pooled = self.spatial_pool(st_feat).squeeze(-1).squeeze(-1) # [B, conv_dim, T]
|
| 1706 |
+
pooled = pooled.permute(0, 2, 1) # [B, T, conv_dim]
|
| 1707 |
+
|
| 1708 |
+
# 5. Temporal pooling with optional attention
|
| 1709 |
+
if self.temporal_pool_mode == "attention" and T > 1:
|
| 1710 |
+
# Learn which frames matter for routing
|
| 1711 |
+
attn_scores = self.temporal_attn(pooled).squeeze(-1) # [B, T]
|
| 1712 |
+
attn_weights = F.softmax(attn_scores, dim=-1) # [B, T]
|
| 1713 |
+
content_feat = (pooled * attn_weights.unsqueeze(-1)).sum(dim=1) # [B, conv_dim]
|
| 1714 |
+
elif self.temporal_pool_mode == "max":
|
| 1715 |
+
content_feat = pooled.max(dim=1)[0] # [B, conv_dim]
|
| 1716 |
+
else: # "avg"
|
| 1717 |
+
content_feat = pooled.mean(dim=1) # [B, conv_dim]
|
| 1718 |
+
|
| 1719 |
+
# 6. Extract motion features (frame differences)
|
| 1720 |
+
if T > 1:
|
| 1721 |
+
# Compute frame-to-frame differences
|
| 1722 |
+
frame_diffs = pooled[:, 1:] - pooled[:, :-1] # [B, T-1, conv_dim]
|
| 1723 |
+
|
| 1724 |
+
# Motion magnitude and direction encoding
|
| 1725 |
+
motion_feat = self.motion_encoder(frame_diffs.mean(dim=1)) # [B, hidden_dim//2]
|
| 1726 |
+
else:
|
| 1727 |
+
# Single frame, no motion
|
| 1728 |
+
motion_feat = torch.zeros(B, self.hidden_dim // 2, device=x_t.device)
|
| 1729 |
+
|
| 1730 |
+
# 7. Project content features
|
| 1731 |
+
content_proj = self.content_proj(content_feat) # [B, hidden_dim//2]
|
| 1732 |
+
|
| 1733 |
+
# 8. Combine content + motion
|
| 1734 |
+
combined = torch.cat([content_proj, motion_feat], dim=-1) # [B, hidden_dim]
|
| 1735 |
+
latent_feat = self.latent_proj(combined) # [B, hidden_dim]
|
| 1736 |
+
|
| 1737 |
+
return latent_feat
|
| 1738 |
+
|
| 1739 |
+
def forward(self, x_t: torch.Tensor, t: torch.Tensor,
|
| 1740 |
+
text_embed: torch.Tensor,
|
| 1741 |
+
first_frame_feat: Optional[torch.Tensor] = None,
|
| 1742 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1743 |
+
temperature: float = 1.0) -> torch.Tensor:
|
| 1744 |
+
"""
|
| 1745 |
+
Compute routing logits with time-adaptive information weighting.
|
| 1746 |
+
|
| 1747 |
+
Args:
|
| 1748 |
+
x_t: Noisy video latent [B, C, T, H, W]
|
| 1749 |
+
t: Noise level [B] in [0, 1] or [0, 999]
|
| 1750 |
+
text_embed: CLIP text embedding [B, text_embed_dim] or [B, seq_len, text_embed_dim]
|
| 1751 |
+
first_frame_feat: Optional DINOv2 features [B, frame_embed_dim]
|
| 1752 |
+
attention_mask: Optional [B, seq_len] mask for text (1=valid, 0=padding)
|
| 1753 |
+
temperature: Softmax temperature for sharper/softer routing (default: 1.0)
|
| 1754 |
+
|
| 1755 |
+
Returns:
|
| 1756 |
+
logits: Expert selection logits [B, num_experts] (scaled by temperature)
|
| 1757 |
+
"""
|
| 1758 |
+
B = x_t.shape[0]
|
| 1759 |
+
device = x_t.device
|
| 1760 |
+
|
| 1761 |
+
# === Encode each information source ===
|
| 1762 |
+
|
| 1763 |
+
# Handle both pooled [B, D] and sequence [B, seq_len, D] text embeddings
|
| 1764 |
+
if text_embed.dim() == 3:
|
| 1765 |
+
# Use masked mean pooling for sequence embeddings
|
| 1766 |
+
text_embed_pooled = self._masked_mean_pool(text_embed, attention_mask)
|
| 1767 |
+
else:
|
| 1768 |
+
# Already pooled
|
| 1769 |
+
text_embed_pooled = text_embed
|
| 1770 |
+
|
| 1771 |
+
# L2-normalize inputs to match clustering preprocessing
|
| 1772 |
+
if self.normalize_inputs:
|
| 1773 |
+
text_embed_pooled = F.normalize(text_embed_pooled, p=2, dim=-1)
|
| 1774 |
+
|
| 1775 |
+
text_feat = self.text_encoder(text_embed_pooled) # [B, hidden_dim]
|
| 1776 |
+
|
| 1777 |
+
# Frame features (optional for T2V, required for I2V)
|
| 1778 |
+
if first_frame_feat is not None:
|
| 1779 |
+
# L2-normalize to match clustering preprocessing
|
| 1780 |
+
if self.normalize_inputs:
|
| 1781 |
+
first_frame_feat = F.normalize(first_frame_feat, p=2, dim=-1)
|
| 1782 |
+
frame_feat = self.frame_encoder(first_frame_feat) # [B, hidden_dim]
|
| 1783 |
+
else:
|
| 1784 |
+
frame_feat = torch.zeros(B, self.hidden_dim, device=device)
|
| 1785 |
+
|
| 1786 |
+
# Latent features from noisy video (temporal-aware encoding)
|
| 1787 |
+
latent_feat = self._encode_latent_temporal(x_t) # [B, hidden_dim]
|
| 1788 |
+
|
| 1789 |
+
# === Time-dependent weighting ===
|
| 1790 |
+
|
| 1791 |
+
# Normalize timesteps to [0, 999] for TimestepEmbedder
|
| 1792 |
+
if t.max() <= 1.0:
|
| 1793 |
+
t_scaled = t * 999.0
|
| 1794 |
+
else:
|
| 1795 |
+
t_scaled = t
|
| 1796 |
+
t_scaled = t_scaled.clamp(0, 999)
|
| 1797 |
+
|
| 1798 |
+
# Get time features
|
| 1799 |
+
time_emb = self.time_embed(t_scaled) # [B, hidden_dim]
|
| 1800 |
+
time_feat = self.time_mlp(time_emb) # [B, hidden_dim]
|
| 1801 |
+
|
| 1802 |
+
# Compute adaptive weights based on noise level
|
| 1803 |
+
# Network learns: high t → high text weight; low t → high latent weight
|
| 1804 |
+
weights = self.source_weighting(time_feat) # [B, 3]
|
| 1805 |
+
|
| 1806 |
+
# === Adaptive combination ===
|
| 1807 |
+
|
| 1808 |
+
combined = (
|
| 1809 |
+
weights[:, 0:1] * text_feat + # Text contribution
|
| 1810 |
+
weights[:, 1:2] * frame_feat + # Frame contribution
|
| 1811 |
+
weights[:, 2:3] * latent_feat # Latent contribution
|
| 1812 |
+
)
|
| 1813 |
+
|
| 1814 |
+
# Final routing decision (incorporate time context)
|
| 1815 |
+
logits = self.router_head(combined + time_feat)
|
| 1816 |
+
|
| 1817 |
+
# Apply temperature scaling (lower temp = sharper routing)
|
| 1818 |
+
if temperature != 1.0:
|
| 1819 |
+
logits = logits / temperature
|
| 1820 |
+
|
| 1821 |
+
return logits
|
| 1822 |
+
|
| 1823 |
+
def get_source_weights(self, t: torch.Tensor) -> torch.Tensor:
|
| 1824 |
+
"""
|
| 1825 |
+
Get the learned source weights for given timesteps.
|
| 1826 |
+
Useful for debugging and visualization.
|
| 1827 |
+
|
| 1828 |
+
Args:
|
| 1829 |
+
t: Noise levels [B] in [0, 1] or [0, 999]
|
| 1830 |
+
|
| 1831 |
+
Returns:
|
| 1832 |
+
weights: Source weights [B, 3] for [text, frame, latent]
|
| 1833 |
+
"""
|
| 1834 |
+
# Normalize timesteps
|
| 1835 |
+
if t.max() <= 1.0:
|
| 1836 |
+
t_scaled = t * 999.0
|
| 1837 |
+
else:
|
| 1838 |
+
t_scaled = t
|
| 1839 |
+
t_scaled = t_scaled.clamp(0, 999)
|
| 1840 |
+
|
| 1841 |
+
time_emb = self.time_embed(t_scaled)
|
| 1842 |
+
time_feat = self.time_mlp(time_emb)
|
| 1843 |
+
weights = self.source_weighting(time_feat)
|
| 1844 |
+
|
| 1845 |
+
return weights
|
| 1846 |
+
|
| 1847 |
+
# =============================================================================
|
| 1848 |
+
# MODEL FACTORY FUNCTIONS
|
| 1849 |
+
# =============================================================================
|
| 1850 |
+
|
| 1851 |
+
def create_expert(config, expert_id: Optional[int] = None) -> nn.Module:
|
| 1852 |
+
"""
|
| 1853 |
+
Factory function to create expert model
|
| 1854 |
+
|
| 1855 |
+
Args:
|
| 1856 |
+
config: Config object
|
| 1857 |
+
expert_id: Optional expert ID for per-expert schedule/objective configuration
|
| 1858 |
+
"""
|
| 1859 |
+
# Make a copy of config to avoid modifying the original
|
| 1860 |
+
import copy
|
| 1861 |
+
config = copy.copy(config)
|
| 1862 |
+
config.expert_params = config.expert_params.copy()
|
| 1863 |
+
|
| 1864 |
+
# Inject schedule_type into expert_params if not already present
|
| 1865 |
+
if "schedule_type" not in config.expert_params:
|
| 1866 |
+
# Check for per-expert schedule first (with backward compatibility)
|
| 1867 |
+
if (hasattr(config, 'expert_schedule_types') and
|
| 1868 |
+
config.expert_schedule_types and
|
| 1869 |
+
expert_id is not None and
|
| 1870 |
+
expert_id in config.expert_schedule_types):
|
| 1871 |
+
config.expert_params["schedule_type"] = config.expert_schedule_types[expert_id]
|
| 1872 |
+
else:
|
| 1873 |
+
# Use default schedule_type (with fallback for old configs)
|
| 1874 |
+
config.expert_params["schedule_type"] = getattr(config, 'schedule_type', 'linear_interp')
|
| 1875 |
+
|
| 1876 |
+
# Inject objective_type into expert_params if not already present
|
| 1877 |
+
if "objective_type" not in config.expert_params:
|
| 1878 |
+
# Check for per-expert objectives (with backward compatibility)
|
| 1879 |
+
if (hasattr(config, 'expert_objectives') and
|
| 1880 |
+
config.expert_objectives and
|
| 1881 |
+
expert_id is not None and
|
| 1882 |
+
expert_id in config.expert_objectives):
|
| 1883 |
+
config.expert_params["objective_type"] = config.expert_objectives[expert_id]
|
| 1884 |
+
else:
|
| 1885 |
+
# Use default objective (with fallback for old configs)
|
| 1886 |
+
config.expert_params["objective_type"] = getattr(config, 'default_objective', 'fm')
|
| 1887 |
+
|
| 1888 |
+
if config.expert_architecture == "unet":
|
| 1889 |
+
return UNetExpert(config)
|
| 1890 |
+
elif config.expert_architecture == "simple_cnn":
|
| 1891 |
+
return SimpleCNNExpert(config)
|
| 1892 |
+
elif config.expert_architecture == "dit":
|
| 1893 |
+
return DiTExpert(config)
|
| 1894 |
+
else:
|
| 1895 |
+
raise ValueError(f"Unknown expert architecture: {config.expert_architecture}")
|
| 1896 |
+
|
| 1897 |
+
def create_router(config) -> Optional[nn.Module]:
|
| 1898 |
+
"""Factory function to create router model"""
|
| 1899 |
+
|
| 1900 |
+
if config.router_architecture == "none" or config.is_monolithic:
|
| 1901 |
+
return None
|
| 1902 |
+
elif config.router_architecture == "deterministic_timestep":
|
| 1903 |
+
return DeterministicTimestepRouter(config)
|
| 1904 |
+
elif config.router_architecture == "vit":
|
| 1905 |
+
return ViTRouter(config)
|
| 1906 |
+
elif config.router_architecture == "cnn":
|
| 1907 |
+
return CNNRouter(config)
|
| 1908 |
+
elif config.router_architecture == "dit":
|
| 1909 |
+
return DiTRouter(config)
|
| 1910 |
+
elif config.router_architecture == "adaptive_video":
|
| 1911 |
+
return AdaptiveVideoRouter(config)
|
| 1912 |
+
else:
|
| 1913 |
+
raise ValueError(f"Unknown router architecture: {config.router_architecture}")
|
src/schedules.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/schedules.py
|
| 2 |
+
"""
|
| 3 |
+
Centralized noise schedule manager for diffusion models.
|
| 4 |
+
|
| 5 |
+
Supports three schedules:
|
| 6 |
+
1. 'cosine': Cosine schedule (Nichol & Dhariwal 2021)
|
| 7 |
+
2. 'linear_beta': Linear beta schedule (Ho et al. 2020)
|
| 8 |
+
3. 'linear_interp': Linear interpolation - Flow Matching default
|
| 9 |
+
|
| 10 |
+
All schedules return (alpha_t, sigma_t) such that:
|
| 11 |
+
x_t = alpha_t * x_0 + sigma_t * epsilon
|
| 12 |
+
alpha_t^2 + sigma_t^2 = 1 (variance preserving)
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import math
|
| 17 |
+
from typing import Tuple
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class NoiseSchedule:
|
| 21 |
+
"""
|
| 22 |
+
Centralized noise schedule manager.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
schedule_type: One of ['cosine', 'linear_beta', 'linear_interp']
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(self, schedule_type: str = 'linear_interp'):
|
| 29 |
+
assert schedule_type in ['cosine', 'linear_beta', 'linear_interp'], \
|
| 30 |
+
f"Unknown schedule: {schedule_type}. Must be one of ['cosine', 'linear_beta', 'linear_interp']"
|
| 31 |
+
self.schedule_type = schedule_type
|
| 32 |
+
|
| 33 |
+
# Linear beta schedule parameters (if used)
|
| 34 |
+
self.beta_min = 0.0001
|
| 35 |
+
self.beta_max = 0.02
|
| 36 |
+
self.num_timesteps = 1000 # T in discrete formulation
|
| 37 |
+
|
| 38 |
+
# Cosine schedule parameter
|
| 39 |
+
self.s = 0.008 # Small offset to prevent beta from being too small near t=0
|
| 40 |
+
|
| 41 |
+
def get_schedule(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 42 |
+
"""
|
| 43 |
+
Get (alpha_t, sigma_t) for given timesteps.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
t: Tensor of timesteps in [0, 1], shape (B,)
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
alpha_t: Shape (B,), coefficient for x_0
|
| 50 |
+
sigma_t: Shape (B,), coefficient for epsilon
|
| 51 |
+
"""
|
| 52 |
+
if self.schedule_type == 'cosine':
|
| 53 |
+
return self._cosine_schedule(t)
|
| 54 |
+
elif self.schedule_type == 'linear_beta':
|
| 55 |
+
return self._linear_beta_schedule(t)
|
| 56 |
+
elif self.schedule_type == 'linear_interp':
|
| 57 |
+
return self._linear_interpolation(t)
|
| 58 |
+
|
| 59 |
+
def _cosine_schedule(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 60 |
+
"""
|
| 61 |
+
Cosine schedule: alpha_bar_t = f(t) / f(0)
|
| 62 |
+
where f(t) = cos²((t + s)/(1 + s) * π/2)
|
| 63 |
+
|
| 64 |
+
Reference: "Improved Denoising Diffusion Probabilistic Models"
|
| 65 |
+
(Nichol & Dhariwal, 2021)
|
| 66 |
+
|
| 67 |
+
This schedule provides better conditioning than linear beta schedule,
|
| 68 |
+
especially at very small and very large t values.
|
| 69 |
+
"""
|
| 70 |
+
# Compute f(t) = cos²((t + s)/(1 + s) * π/2)
|
| 71 |
+
f_t = torch.cos(((t + self.s) / (1 + self.s)) * math.pi * 0.5) ** 2
|
| 72 |
+
|
| 73 |
+
# Compute f(0) for normalization to ensure alpha_bar_0 = 1
|
| 74 |
+
f_0 = math.cos((self.s / (1 + self.s)) * math.pi * 0.5) ** 2
|
| 75 |
+
|
| 76 |
+
# Normalize: alpha_bar_t = f(t) / f(0)
|
| 77 |
+
alpha_bar_t = f_t / f_0
|
| 78 |
+
|
| 79 |
+
# Clamp to ensure numerical stability
|
| 80 |
+
alpha_bar_t = torch.clamp(alpha_bar_t, min=1e-8, max=1.0)
|
| 81 |
+
|
| 82 |
+
# Compute coefficients
|
| 83 |
+
alpha_t = torch.sqrt(alpha_bar_t)
|
| 84 |
+
sigma_t = torch.sqrt(1 - alpha_bar_t)
|
| 85 |
+
|
| 86 |
+
return alpha_t, sigma_t
|
| 87 |
+
|
| 88 |
+
def _linear_beta_schedule(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 89 |
+
"""
|
| 90 |
+
Linear beta schedule: beta_t increases linearly from beta_min to beta_max
|
| 91 |
+
|
| 92 |
+
Reference: "Denoising Diffusion Probabilistic Models" (Ho et al., 2020)
|
| 93 |
+
|
| 94 |
+
For continuous time t ∈ [0,1]:
|
| 95 |
+
beta(t) = beta_min + t * (beta_max - beta_min)
|
| 96 |
+
alpha_bar(t) = exp(-0.5 * integral_0^t beta(s) ds)
|
| 97 |
+
= exp(-0.5 * t * (beta_min + 0.5 * t * (beta_max - beta_min)))
|
| 98 |
+
"""
|
| 99 |
+
# Compute alpha_bar(t) = exp(-0.5 * integral beta(s) ds)
|
| 100 |
+
# integral_0^t (beta_min + s * (beta_max - beta_min)) ds
|
| 101 |
+
# = beta_min * t + 0.5 * t^2 * (beta_max - beta_min)
|
| 102 |
+
integral_beta = self.beta_min * t + 0.5 * t * t * (self.beta_max - self.beta_min)
|
| 103 |
+
alpha_bar_t = torch.exp(-0.5 * integral_beta * self.num_timesteps)
|
| 104 |
+
|
| 105 |
+
# Compute coefficients
|
| 106 |
+
alpha_t = torch.sqrt(alpha_bar_t)
|
| 107 |
+
sigma_t = torch.sqrt(1 - alpha_bar_t)
|
| 108 |
+
|
| 109 |
+
return alpha_t, sigma_t
|
| 110 |
+
|
| 111 |
+
def _linear_interpolation(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 112 |
+
"""
|
| 113 |
+
Linear interpolation: x_t = (1-t) * x_0 + t * epsilon
|
| 114 |
+
|
| 115 |
+
This is the default for Flow Matching but NOT a proper DDPM schedule.
|
| 116 |
+
This is what the current implementation uses.
|
| 117 |
+
"""
|
| 118 |
+
alpha_t = 1 - t
|
| 119 |
+
sigma_t = t
|
| 120 |
+
return alpha_t, sigma_t
|
| 121 |
+
|
| 122 |
+
def get_snr(self, t: torch.Tensor) -> torch.Tensor:
|
| 123 |
+
"""
|
| 124 |
+
Compute signal-to-noise ratio (SNR) = alpha_t^2 / sigma_t^2
|
| 125 |
+
|
| 126 |
+
Useful for:
|
| 127 |
+
1. Time warping between different schedules
|
| 128 |
+
2. Analysis and visualization
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
t: Tensor of timesteps in [0, 1]
|
| 132 |
+
|
| 133 |
+
Returns:
|
| 134 |
+
snr: Signal-to-noise ratio at each timestep
|
| 135 |
+
"""
|
| 136 |
+
alpha_t, sigma_t = self.get_schedule(t)
|
| 137 |
+
snr = (alpha_t ** 2) / (sigma_t ** 2 + 1e-8)
|
| 138 |
+
return snr
|
| 139 |
+
|
| 140 |
+
def alpha_to_time(self, alpha: torch.Tensor, num_steps: int = 100) -> torch.Tensor:
|
| 141 |
+
"""
|
| 142 |
+
Inverse mapping: given alpha, find t
|
| 143 |
+
|
| 144 |
+
Used for inference when you want to specify noise levels directly.
|
| 145 |
+
Uses binary search since schedules are monotonic.
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
alpha: Desired alpha values
|
| 149 |
+
num_steps: Number of steps for binary search
|
| 150 |
+
|
| 151 |
+
Returns:
|
| 152 |
+
t: Corresponding timesteps
|
| 153 |
+
"""
|
| 154 |
+
device = alpha.device
|
| 155 |
+
|
| 156 |
+
# Binary search for t
|
| 157 |
+
t_candidates = torch.linspace(0, 1, num_steps, device=device)
|
| 158 |
+
alpha_candidates, _ = self.get_schedule(t_candidates)
|
| 159 |
+
|
| 160 |
+
# Find closest match
|
| 161 |
+
distances = torch.abs(alpha_candidates.unsqueeze(0) - alpha.unsqueeze(1))
|
| 162 |
+
indices = torch.argmin(distances, dim=1)
|
| 163 |
+
t = t_candidates[indices]
|
| 164 |
+
|
| 165 |
+
return t
|
| 166 |
+
|
src/vae_utils.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/vae_utils.py
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from diffusers import AutoencoderKL
|
| 5 |
+
from typing import Optional
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
class VAEManager:
|
| 9 |
+
"""Utility class for VAE encoding/decoding operations"""
|
| 10 |
+
|
| 11 |
+
def __init__(self, model_name: str = "stabilityai/sd-vae-ft-mse", device: str = "cuda"):
|
| 12 |
+
self.device = device
|
| 13 |
+
self.model_name = model_name
|
| 14 |
+
self.vae = None
|
| 15 |
+
self._load_vae()
|
| 16 |
+
|
| 17 |
+
def _load_vae(self):
|
| 18 |
+
"""Load VAE model"""
|
| 19 |
+
print(f"Loading VAE: {self.model_name}")
|
| 20 |
+
self.vae = AutoencoderKL.from_pretrained(self.model_name)
|
| 21 |
+
self.vae = self.vae.to(self.device)
|
| 22 |
+
self.vae.eval()
|
| 23 |
+
|
| 24 |
+
# Freeze VAE parameters
|
| 25 |
+
for param in self.vae.parameters():
|
| 26 |
+
param.requires_grad = False
|
| 27 |
+
|
| 28 |
+
def encode(self, images: torch.Tensor) -> torch.Tensor:
|
| 29 |
+
"""
|
| 30 |
+
Encode images to latent space
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
images: Tensor of shape [B, 3, H, W] in range [-1, 1]
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
latents: Tensor of shape [B, 4, H//8, W//8]
|
| 37 |
+
"""
|
| 38 |
+
with torch.no_grad():
|
| 39 |
+
images = images.to(self.device)
|
| 40 |
+
latent_dist = self.vae.encode(images).latent_dist
|
| 41 |
+
latents = latent_dist.sample()
|
| 42 |
+
latents = latents * self.vae.config.scaling_factor
|
| 43 |
+
|
| 44 |
+
return latents
|
| 45 |
+
|
| 46 |
+
def decode(self, latents: torch.Tensor, upscale_factor: Optional[float] = None,
|
| 47 |
+
upscale_mode: str = 'bicubic') -> torch.Tensor:
|
| 48 |
+
"""
|
| 49 |
+
Decode latents to images
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
latents: Tensor of shape [B, 4, H, W]
|
| 53 |
+
upscale_factor: Optional upscaling factor (e.g., 2.0 for 2x, 1.5 for 1.5x)
|
| 54 |
+
If None, returns images at native resolution (H*8, W*8)
|
| 55 |
+
upscale_mode: Interpolation mode ('bicubic', 'bilinear', 'nearest')
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
images: Tensor of shape [B, 3, H*8*upscale_factor, W*8*upscale_factor] in range [-1, 1]
|
| 59 |
+
"""
|
| 60 |
+
with torch.no_grad():
|
| 61 |
+
latents = latents.to(self.device)
|
| 62 |
+
# Rescale latents
|
| 63 |
+
latents = latents / self.vae.config.scaling_factor
|
| 64 |
+
images = self.vae.decode(latents).sample
|
| 65 |
+
|
| 66 |
+
# Apply upscaling if requested
|
| 67 |
+
if upscale_factor is not None and upscale_factor != 1.0:
|
| 68 |
+
_, _, h, w = images.shape
|
| 69 |
+
new_h = int(h * upscale_factor)
|
| 70 |
+
new_w = int(w * upscale_factor)
|
| 71 |
+
images = F.interpolate(
|
| 72 |
+
images,
|
| 73 |
+
size=(new_h, new_w),
|
| 74 |
+
mode=upscale_mode,
|
| 75 |
+
align_corners=False if upscale_mode in ['bilinear', 'bicubic'] else None,
|
| 76 |
+
antialias=True if upscale_mode in ['bilinear', 'bicubic'] else False
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
return images
|
| 80 |
+
|
| 81 |
+
def decode_to_pil(self, latents: torch.Tensor, upscale_factor: Optional[float] = None,
|
| 82 |
+
upscale_mode: str = 'bicubic', target_size: Optional[tuple] = None):
|
| 83 |
+
"""
|
| 84 |
+
Decode latents to PIL images
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
latents: Tensor of shape [B, 4, H, W]
|
| 88 |
+
upscale_factor: Optional upscaling factor (e.g., 2.0 for 2x)
|
| 89 |
+
upscale_mode: Interpolation mode ('bicubic', 'bilinear', 'nearest')
|
| 90 |
+
target_size: Optional target size as (height, width). Overrides upscale_factor if provided.
|
| 91 |
+
|
| 92 |
+
Returns:
|
| 93 |
+
pil_images: List of PIL images
|
| 94 |
+
"""
|
| 95 |
+
from PIL import Image
|
| 96 |
+
|
| 97 |
+
# Decode to tensor
|
| 98 |
+
images = self.decode(latents, upscale_factor=upscale_factor, upscale_mode=upscale_mode)
|
| 99 |
+
|
| 100 |
+
# Apply target size if specified
|
| 101 |
+
if target_size is not None:
|
| 102 |
+
images = F.interpolate(
|
| 103 |
+
images,
|
| 104 |
+
size=target_size,
|
| 105 |
+
mode=upscale_mode,
|
| 106 |
+
align_corners=False if upscale_mode in ['bilinear', 'bicubic'] else None,
|
| 107 |
+
antialias=True if upscale_mode in ['bilinear', 'bicubic'] else False
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
# Convert to [0, 1] range
|
| 111 |
+
images = (images + 1.0) / 2.0
|
| 112 |
+
images = torch.clamp(images, 0, 1)
|
| 113 |
+
|
| 114 |
+
# Convert to PIL
|
| 115 |
+
pil_images = []
|
| 116 |
+
for i in range(images.shape[0]):
|
| 117 |
+
img_array = images[i].cpu().numpy().transpose(1, 2, 0)
|
| 118 |
+
img_array = (img_array * 255).astype(np.uint8)
|
| 119 |
+
pil_image = Image.fromarray(img_array)
|
| 120 |
+
pil_images.append(pil_image)
|
| 121 |
+
|
| 122 |
+
return pil_images
|
| 123 |
+
|
| 124 |
+
@property
|
| 125 |
+
def scaling_factor(self) -> float:
|
| 126 |
+
"""Get VAE scaling factor"""
|
| 127 |
+
return self.vae.config.scaling_factor
|
| 128 |
+
|
| 129 |
+
@property
|
| 130 |
+
def latent_channels(self) -> int:
|
| 131 |
+
"""Get number of latent channels"""
|
| 132 |
+
return 4 # Standard for Stable Diffusion VAE
|
| 133 |
+
|
| 134 |
+
def create_vae_manager(model_name: str = "stabilityai/sd-vae-ft-mse", device: str = "cuda") -> VAEManager:
|
| 135 |
+
"""Factory function to create VAE manager"""
|
| 136 |
+
return VAEManager(model_name, device)
|
| 137 |
+
|
| 138 |
+
def save_images_from_latents(latents: torch.Tensor, save_dir: str, vae_manager: VAEManager, prefix: str = "sample"):
|
| 139 |
+
"""
|
| 140 |
+
Save images from latents using VAE decoder
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
latents: Tensor of shape [B, 4, H, W]
|
| 144 |
+
save_dir: Directory to save images
|
| 145 |
+
vae_manager: VAE manager instance
|
| 146 |
+
prefix: Filename prefix
|
| 147 |
+
"""
|
| 148 |
+
import os
|
| 149 |
+
|
| 150 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 151 |
+
|
| 152 |
+
# Decode to PIL images
|
| 153 |
+
pil_images = vae_manager.decode_to_pil(latents)
|
| 154 |
+
|
| 155 |
+
# Save each image
|
| 156 |
+
for i, pil_image in enumerate(pil_images):
|
| 157 |
+
save_path = os.path.join(save_dir, f"{prefix}_{i:03d}.png")
|
| 158 |
+
pil_image.save(save_path)
|
| 159 |
+
|
| 160 |
+
print(f"Saved {len(pil_images)} images to {save_dir}")
|
| 161 |
+
|
| 162 |
+
def create_image_grid(latents: torch.Tensor, vae_manager: VAEManager, nrow: int = 4) -> torch.Tensor:
|
| 163 |
+
"""
|
| 164 |
+
Create image grid from latents
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
latents: Tensor of shape [B, 4, H, W]
|
| 168 |
+
vae_manager: VAE manager instance
|
| 169 |
+
nrow: Number of images per row
|
| 170 |
+
|
| 171 |
+
Returns:
|
| 172 |
+
grid: Image grid tensor
|
| 173 |
+
"""
|
| 174 |
+
import torchvision.utils as vutils
|
| 175 |
+
|
| 176 |
+
# Decode latents
|
| 177 |
+
images = vae_manager.decode(latents)
|
| 178 |
+
|
| 179 |
+
# Convert to [0, 1] range
|
| 180 |
+
images = (images + 1.0) / 2.0
|
| 181 |
+
images = torch.clamp(images, 0, 1)
|
| 182 |
+
|
| 183 |
+
# Create grid
|
| 184 |
+
grid = vutils.make_grid(images, nrow=nrow, padding=2)
|
| 185 |
+
|
| 186 |
+
return grid
|
weights/bf16/config.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bf54162afaf045deefb715e9834ed60948d7494354e866e70e76ddaebe575a78
|
| 3 |
+
size 2908
|
weights/bf16/expert_0.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4a069731935a6285a64e2379c554371997ff32ad1f6c956422cfb83a8086549d
|
| 3 |
+
size 1211979376
|
weights/bf16/expert_1.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5a5d45e5b96ce31cc3c2c9d8f903fb75c7d0b757be96212ec345ee0e78037d48
|
| 3 |
+
size 1211979376
|
weights/bf16/expert_2.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9fa3505dfa75f4b82894064cc3c3b70aa6f409796dc7cda8bc14ce3572268a44
|
| 3 |
+
size 1211979376
|
weights/bf16/expert_3.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c9f53a42c3690ff8e27187a6c42770c888a1ce2fca8c132e181433870a6b4797
|
| 3 |
+
size 1211979376
|
weights/bf16/expert_4.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:58a367b34eb486789f9e8709384ad45d69768ac302a896fac85bd512134cdb3b
|
| 3 |
+
size 1211979376
|
weights/bf16/expert_5.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:37c9ce6fa79faa97a029de00fcdedc7e96dbc5de36deabc953ad2ee95c2ab0ad
|
| 3 |
+
size 1211979376
|
weights/bf16/expert_6.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1fb83aad644a4fe22cd661cc4bedd49c73815dfc91bf81caf6a89dc21f1f90b3
|
| 3 |
+
size 1211979376
|
weights/bf16/expert_7.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c1d37b9b495d74121080237dbed32a5042ecbd7ed8ed619519cc2946f26e199b
|
| 3 |
+
size 1211979376
|
weights/bf16/router.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ff8aaa22f59e382227b3b9fe6527010a6929e8b0b7c4322213b392a0ca03a1bf
|
| 3 |
+
size 258286840
|
weights/bf16/router_config.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0e951c1c39ad5401b33bb3147f62803d76303a7b7ca0e457e4cc0aaf1e585bb5
|
| 3 |
+
size 2744
|
weights/int8/expert_0.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4a0942e1503de55b07393582bb231fc0c8358cb8f03b329c3e282f8c4a8b861c
|
| 3 |
+
size 606080694
|
weights/int8/expert_1.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e6a9ad48d7b84574a122f52ef6d619cb8c5d9f3766c1a55af5f0b5d463fbd109
|
| 3 |
+
size 606080672
|
weights/int8/expert_2.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:baa7887b97f60db4532682be3701f6e9fc9a9dec1446af00ff3f1515055f888e
|
| 3 |
+
size 606080694
|
weights/int8/expert_3.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c36196779aba27cef9ea66f5775a9ba43886a7811904fb66a9b23b0095800da9
|
| 3 |
+
size 606080694
|
weights/int8/expert_4.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d255e3a6d89f0bfbff7461dff0fb27fa206a4b9e98a83be87121327d9cac56f7
|
| 3 |
+
size 606080694
|
weights/int8/expert_5.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f5cabe2f9deba779a5ea14a4d2e038a9aefd0ed5a4d2cd1b7776cd10939ffb21
|
| 3 |
+
size 606080694
|
weights/int8/expert_6.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f0470248ca040b321a5e72ce73f05c32c9d0cbe8515115021fb0f6065cc3599d
|
| 3 |
+
size 606080694
|
weights/int8/expert_7.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6f0047faeb6e7d0e5acc0abcbdede83789a5fec6e6d93ef3c4d5903785dd4660
|
| 3 |
+
size 606080694
|
weights/int8/router.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2a52cce497dd02a88804bc81669eba0ab4957dd2b3c54b8de781dabb5a8c15b2
|
| 3 |
+
size 256740839
|