nbagel commited on
Commit
4dec1ca
·
verified ·
1 Parent(s): 2581971

Initial upload: Paris MoE inference code and weights

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ .ipynb_checkpoints/test_inference-checkpoint.png filter=lfs diff=lfs merge=lfs -text
37
+ .ipynb_checkpoints/test_int8-checkpoint.png filter=lfs diff=lfs merge=lfs -text
.ipynb_checkpoints/instructions-checkpoint.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ What we're now is we're going to prepare an inference folder and we're going to make an inference repository for our Paris model. This will include, we will just stick with int8 and bfloat16 and mixed int8, bfloat16 for now. And this repository will include efficient methods how to run the code. It will include quantization code that can accept either PT or save tensors or Float32 save tensors or Float32 PT. It will include a lot of different methods that can accept either PT or Float32 PT. We need to make a visualizer next which outputs a little pretty ASCII chart. We should output the ASCII chart right on the terminal every time we run the inference via this tool. Let's just say we're running the int8 inference of the mixed int8 model. By the way, we're also going to put the weights that we quantized inside this inference folder because we're going to publish this on HuggingFace. have just again, the beef flow 16 and intake weights. we might already be done this by the way. But again, I wanted to do that when we have to keep some kind of track and output a chart in the terminal, like as a little terminal visualization in ASCII. MAKE SURE WE'RE DOING ROUTING PROPERLY. Top 2 etc. Again, just to recap, we're going to make a folder that's just called inference. In this folder, we're going to put the quantized weights that we already made, because we already made them before in the last session. So the bfloat16 and the int8 weights. And we're going to put one Python file for the inference code, and it's going to have all the flags, and it's also going to have a visualized flag. And the visualized flag is actually a lot more than that, because it keeps track of which expert is being used during each inference step, and that shows like a little pretty chart. So if we're generating with 30 steps, which is going to show which experts got to use the most and the least out of eight of them. And so we want to have this in the inference code. Make sure to read files in full before like a pass inference code that we already wrote. Try to list like the most recent files that we made for that. And we also want to have the quantization code to just be an all in one utility with a very nice terminal interface as well, because we want the quantization code to be able to handle float 16 bfloat 16 float 32 weights in both safe tensors and in dot pt format. So that needs to be very smart and tested that it actually works. And also, yeah, make a read me in this folder for the Paris model, because we're going to publish this on hogging face as the inference repository. And then also read all the MD files that we have written here in full because after we do all of this and after we test that it works and it differences fine. We're going to we're going to start to play around with network inference. So that's going to be the fun next step after. So again, make a 20 point to do this for this and please make sure to include at least four or five sentences per point. So the to do list is going to be very long, naturally and very detailed. But I believe we're going to do an excellent, excellent job here.
.ipynb_checkpoints/test_inference-checkpoint.png ADDED

Git LFS Details

  • SHA256: 4dba193213b76355dcae180f5631d852e440201f4c8a01cc8c2671fd96394aef
  • Pointer size: 131 Bytes
  • Size of remote file: 365 kB
.ipynb_checkpoints/test_int8-checkpoint.png ADDED

Git LFS Details

  • SHA256: 772d2190b8cbf97e519337bc959254abc4be981a46795e762fda5c8a9efecd9b
  • Pointer size: 131 Bytes
  • Size of remote file: 136 kB
README.md ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🥖 Baguette - Paris MoE Text-to-Image
2
+
3
+ A ~5 billion parameter Mixture-of-Experts diffusion model with 8 specialized experts.
4
+
5
+ ## ⚡ Quick Start
6
+
7
+ ```bash
8
+ # Install dependencies
9
+ pip install uv && uv pip install torch torchvision safetensors transformers diffusers accelerate tqdm
10
+
11
+ # Generate 4 cat images
12
+ python generate.py --prompt "a cute cat" --num_samples 4
13
+ ```
14
+
15
+ That's it! Images saved to `output_bf16.png`.
16
+
17
+ ---
18
+
19
+ ## 🎨 Examples
20
+
21
+ ```bash
22
+ # Simple generation
23
+ python generate.py --prompt "sunset over mountains"
24
+
25
+ # More samples, see expert routing
26
+ python generate.py --prompt "abstract art" --num_samples 16 --visualize
27
+
28
+ # Faster with fewer steps
29
+ python generate.py --prompt "a dog" --num_steps 15
30
+
31
+ # Lower memory (offload 4 experts to CPU)
32
+ python generate.py --prompt "portrait" --offload 4
33
+
34
+ # INT8 weights (smaller, slightly lower quality)
35
+ python generate.py --prompt "forest" --precision int8
36
+ ```
37
+
38
+ ---
39
+
40
+ ## 📋 All Options
41
+
42
+ | Flag | Default | Description |
43
+ |------|---------|-------------|
44
+ | `--prompt` | "a cute cat" | What to generate |
45
+ | `--num_samples` | 16 | Number of images |
46
+ | `--num_steps` | 30 | Sampling steps (20-50 recommended) |
47
+ | `--cfg_scale` | 7.5 | Guidance strength (5-10 recommended) |
48
+ | `--precision` | bf16 | `bf16` (best) or `int8` (smaller) |
49
+ | `--topk` | 2 | Experts per sample (1 or 2) |
50
+ | `--offload` | 0 | Experts to keep on CPU (0-7) |
51
+ | `--visualize` | off | Show expert routing stats |
52
+ | `--output` | auto | Output filename |
53
+ | `--seed` | 999 | Random seed |
54
+
55
+ ---
56
+
57
+ ## 🔍 Expert Visualization
58
+
59
+ Use `--visualize` to see which experts the router selects:
60
+
61
+ ```
62
+ ╭──────────────────────────────────────────────────╮
63
+ │ ⚡ EXPERT USAGE DISTRIBUTION │
64
+ ├──────────────────────────────────────────────────┤
65
+ │ → E4 │████████████████████████████│ 40.6% │
66
+ │ E2 │██████████████████████████ │ 36.7% │
67
+ │ E6 │██████████ │ 14.8% │
68
+ │ E1 │███ │ 5.5% │
69
+ │ E5 │█ │ 2.3% │
70
+ │ E0 │ │ 0.0% │
71
+ │ E3 │ │ 0.0% │
72
+ │ E7 │ │ 0.0% │
73
+ ├──────────────────────────────────────────────────┤
74
+ │ Active: 5/8 experts Calls: 128 │
75
+ ╰──────────────────────────────────────────────────╯
76
+
77
+ ╭──────────────────────────────────────────────────╮
78
+ │ 📈 ROUTING TIMELINE │
79
+ ├──────────────────────────────────────────────────┤
80
+ │ Step 0 1 2 3 4 5 6 7 8 9 10 11 ... │
81
+ │ ──────────────────────────────────────────── │
82
+ │ E0 · · · · · · · · · · · · │
83
+ │ E2 · · · · · · ● ● ● ● ● ● │
84
+ │ E4 · · ● ● ● ● · · · · · · │
85
+ │ E6 ● ● · · · · · · · · · · │
86
+ ├──────────────────────────────────────────────────┤
87
+ │ Routing changes: 2/11 steps (18%) │
88
+ ╰──────────────────────────────────────────────────╯
89
+ ```
90
+
91
+ ---
92
+
93
+ ## 💾 Memory & Speed
94
+
95
+ | Config | GPU Memory | Speed |
96
+ |--------|-----------|-------|
97
+ | BF16 (all on GPU) | ~25 GB | ~3 img/s |
98
+ | BF16 + offload 4 | ~14 GB | ~1 img/s |
99
+ | INT8 (all on GPU) | ~12 GB | ~2 img/s |
100
+ | INT8 + offload 4 | ~8 GB | ~0.5 img/s |
101
+
102
+ ---
103
+
104
+ ## 🏗️ Architecture
105
+
106
+ ```
107
+ ┌─────────────────────────────────────────┐
108
+ │ Paris MoE Model │
109
+ ├─────────────────────────────────────────┤
110
+ │ Router: DiT-B/2 (129M params) │
111
+ │ ↓ selects top-K experts │
112
+ │ Experts: 8× DiT-XL/2 (606M each) │
113
+ │ ↓ predicts velocity │
114
+ │ VAE: Stable Diffusion VAE │
115
+ �� ↓ decodes to pixels │
116
+ │ Output: 256×256 RGB │
117
+ └─────────────────────────────────────────┘
118
+ ```
119
+
120
+ - **Total Parameters**: ~5 Billion
121
+ - **Latent Space**: 32×32×4
122
+ - **Text Encoder**: CLIP ViT-L/14
123
+
124
+ ---
125
+
126
+ ## 📁 Files
127
+
128
+ ```
129
+ ├── generate.py # Main generation script
130
+ ├── benchmark.py # Performance testing
131
+ ├── quantize.py # Weight conversion tool
132
+ ├── src/ # Model code
133
+ └── weights/
134
+ ├── bf16/ # BFloat16 weights (9.3 GB)
135
+ └── int8/ # INT8 weights (4.8 GB)
136
+ ```
137
+
138
+ ---
139
+
140
+ ## 🔧 Convert Your Own Weights
141
+
142
+ ```bash
143
+ # From PyTorch .pt to BF16 safetensors
144
+ python quantize.py --input /path/to/weights --output ./weights/bf16 --format bf16
145
+
146
+ # From BF16 to INT8
147
+ python quantize.py --input ./weights/bf16 --output ./weights/int8 --format int8
148
+ ```
149
+
150
+ ---
151
+
152
+ ## 📜 License
153
+
154
+ Apache 2.0
__pycache__/generate.cpython-312.pyc ADDED
Binary file (37.7 kB). View file
 
benchmark.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ ╔══════════════════════════════════════════════════════════════════════════════╗
4
+ ║ ║
5
+ ║ 📊 Paris MoE - Comprehensive Benchmarking Utility 📊 ║
6
+ ║ ║
7
+ ║ Measures performance across precision modes, batch sizes, and configs. ║
8
+ ║ Outputs results as both terminal display and Markdown file. ║
9
+ ║ ║
10
+ ╚══════════════════════════════════════════════════════════════════════════════╝
11
+
12
+ Usage:
13
+ python benchmark.py # Run all benchmarks
14
+ python benchmark.py --quick # Quick benchmark (fewer configs)
15
+ python benchmark.py --precision bf16 # Benchmark specific precision
16
+ python benchmark.py --output results.md # Save results to file
17
+ """
18
+
19
+ import argparse
20
+ import sys
21
+ import os
22
+ import time
23
+ import gc
24
+ from pathlib import Path
25
+ from datetime import datetime
26
+ from dataclasses import dataclass
27
+ from typing import List, Dict, Optional
28
+
29
+ SCRIPT_DIR = Path(__file__).parent.absolute()
30
+ SRC_DIR = SCRIPT_DIR / "src"
31
+ sys.path.insert(0, str(SRC_DIR))
32
+
33
+ import torch
34
+
35
+ # ═══════════════════════════════════════════════════════════════════════════════
36
+ # DATA STRUCTURES
37
+ # ═══════════════════════════════════════════════════════════════════════════════
38
+
39
+ @dataclass
40
+ class BenchmarkResult:
41
+ """Single benchmark result."""
42
+ precision: str
43
+ num_samples: int
44
+ num_steps: int
45
+ topk: int
46
+ offload: int
47
+
48
+ load_time: float # Model loading time (seconds)
49
+ gen_time: float # Generation time (seconds)
50
+ decode_time: float # VAE decoding time (seconds)
51
+
52
+ peak_memory_gb: float # Peak GPU memory usage
53
+
54
+ @property
55
+ def total_time(self) -> float:
56
+ return self.gen_time + self.decode_time
57
+
58
+ @property
59
+ def throughput(self) -> float:
60
+ """Images per second (generation only)."""
61
+ return self.num_samples / self.gen_time if self.gen_time > 0 else 0
62
+
63
+ @property
64
+ def time_per_step(self) -> float:
65
+ """Seconds per sampling step."""
66
+ return self.gen_time / self.num_steps if self.num_steps > 0 else 0
67
+
68
+ @property
69
+ def time_per_image(self) -> float:
70
+ """Seconds per image (generation only)."""
71
+ return self.gen_time / self.num_samples if self.num_samples > 0 else 0
72
+
73
+
74
+ # ═══════════════════════════════════════════════════════════════════════════════
75
+ # BENCHMARK RUNNER
76
+ # ═══════════════════════════════════════════════════════════════════════════════
77
+
78
+ def get_gpu_memory_gb() -> float:
79
+ """Get current GPU memory usage in GB."""
80
+ if torch.cuda.is_available():
81
+ return torch.cuda.max_memory_allocated() / (1024 ** 3)
82
+ return 0.0
83
+
84
+
85
+ def reset_gpu_memory():
86
+ """Reset GPU memory tracking."""
87
+ if torch.cuda.is_available():
88
+ torch.cuda.reset_peak_memory_stats()
89
+ torch.cuda.empty_cache()
90
+ gc.collect()
91
+
92
+
93
+ def run_single_benchmark(precision: str, num_samples: int, num_steps: int,
94
+ topk: int, offload: int, device: str = 'cuda') -> BenchmarkResult:
95
+ """Run a single benchmark configuration."""
96
+ from generate import load_sampler
97
+
98
+ reset_gpu_memory()
99
+
100
+ # Load model
101
+ start_load = time.time()
102
+ sampler = load_sampler(precision=precision, device=device, offload=offload)
103
+ load_time = time.time() - start_load
104
+
105
+ # Set seed for reproducibility
106
+ torch.manual_seed(42)
107
+ if torch.cuda.is_available():
108
+ torch.cuda.manual_seed(42)
109
+
110
+ # Warmup run
111
+ _ = sampler.sample(
112
+ num_samples=1,
113
+ text_prompts=["warmup"],
114
+ cfg_scale=7.5,
115
+ num_steps=2,
116
+ use_bf16=(precision == 'bf16'),
117
+ topk=topk
118
+ )
119
+
120
+ reset_gpu_memory()
121
+ torch.cuda.synchronize()
122
+
123
+ # Timed generation
124
+ start_gen = time.time()
125
+ latents = sampler.sample(
126
+ num_samples=num_samples,
127
+ text_prompts=["a cute cat"],
128
+ cfg_scale=7.5,
129
+ num_steps=num_steps,
130
+ use_bf16=(precision == 'bf16'),
131
+ topk=topk
132
+ )
133
+ torch.cuda.synchronize()
134
+ gen_time = time.time() - start_gen
135
+
136
+ # Timed decoding
137
+ start_decode = time.time()
138
+ images = sampler.vae_manager.decode(latents)
139
+ torch.cuda.synchronize()
140
+ decode_time = time.time() - start_decode
141
+
142
+ peak_memory = get_gpu_memory_gb()
143
+
144
+ # Cleanup
145
+ del sampler, latents, images
146
+ gc.collect()
147
+ torch.cuda.empty_cache()
148
+
149
+ return BenchmarkResult(
150
+ precision=precision,
151
+ num_samples=num_samples,
152
+ num_steps=num_steps,
153
+ topk=topk,
154
+ offload=offload,
155
+ load_time=load_time,
156
+ gen_time=gen_time,
157
+ decode_time=decode_time,
158
+ peak_memory_gb=peak_memory
159
+ )
160
+
161
+
162
+ # ═══════════════════════════════════════════════════════════════════════════════
163
+ # OUTPUT FORMATTERS
164
+ # ═══════════════════════════════════════════════════════════════════════════════
165
+
166
+ def format_terminal_results(results: List[BenchmarkResult], gpu_name: str) -> str:
167
+ """Format results for terminal display."""
168
+ lines = []
169
+
170
+ lines.append("""
171
+ ╔══════════════════════════════════════════════════════════════════════════════╗
172
+ ║ 📊 PARIS MoE BENCHMARK RESULTS 📊 ║
173
+ ╚══════════════════════════════════════════════════════════════════════════════╝
174
+ """)
175
+
176
+ lines.append(f" GPU: {gpu_name}")
177
+ lines.append(f" Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
178
+ lines.append("")
179
+
180
+ # Group by precision
181
+ precisions = sorted(set(r.precision for r in results))
182
+
183
+ for precision in precisions:
184
+ prec_results = [r for r in results if r.precision == precision]
185
+
186
+ lines.append(f"┌{'─'*78}┐")
187
+ lines.append(f"│ {precision.upper()} Precision{' '*65}│")
188
+ lines.append(f"├{'─'*78}┤")
189
+ lines.append(f"│ {'Samples':>8} │ {'Steps':>6} │ {'TopK':>5} │ {'Offload':>7} │ "
190
+ f"{'Gen(s)':>8} │ {'Img/s':>6} │ {'s/step':>6} │ {'Mem(GB)':>8} │")
191
+ lines.append(f"├{'─'*78}┤")
192
+
193
+ for r in prec_results:
194
+ lines.append(
195
+ f"│ {r.num_samples:>8} │ {r.num_steps:>6} │ {r.topk:>5} │ {r.offload:>7} │ "
196
+ f"{r.gen_time:>8.2f} │ {r.throughput:>6.2f} │ {r.time_per_step:>6.3f} │ "
197
+ f"{r.peak_memory_gb:>8.2f} │"
198
+ )
199
+
200
+ lines.append(f"└{'─'*78}┘")
201
+ lines.append("")
202
+
203
+ # Summary
204
+ if results:
205
+ fastest = min(results, key=lambda r: r.time_per_image)
206
+ most_efficient = min(results, key=lambda r: r.peak_memory_gb)
207
+
208
+ lines.append("┌─────────────────────────────────────────────────────────────────┐")
209
+ lines.append("│ 📈 SUMMARY │")
210
+ lines.append("├─────────────────────────────────────────────────────────────────┤")
211
+ lines.append(f"│ 🏆 Fastest: {fastest.precision.upper():>6} @ {fastest.throughput:.2f} img/s │")
212
+ lines.append(f"│ 💾 Most Efficient: {most_efficient.precision.upper():>6} @ {most_efficient.peak_memory_gb:.1f} GB peak │")
213
+ lines.append("└─────────────────────────────────────────────────────────────────┘")
214
+
215
+ return "\n".join(lines)
216
+
217
+
218
+ def format_markdown_results(results: List[BenchmarkResult], gpu_name: str) -> str:
219
+ """Format results as Markdown."""
220
+ lines = []
221
+
222
+ lines.append("# 📊 Paris MoE Benchmark Results")
223
+ lines.append("")
224
+ lines.append(f"**GPU:** {gpu_name}")
225
+ lines.append(f"**Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
226
+ lines.append("")
227
+
228
+ lines.append("## 🏗️ Model Architecture")
229
+ lines.append("")
230
+ lines.append("| Component | Details |")
231
+ lines.append("|-----------|---------|")
232
+ lines.append("| Experts | 8× DiT-XL/2 (606M params each) |")
233
+ lines.append("| Router | DiT-B/2 (129M params) |")
234
+ lines.append("| Total | ~5 Billion parameters |")
235
+ lines.append("| VAE | SD-VAE (stabilityai/sd-vae-ft-mse) |")
236
+ lines.append("| Text Encoder | CLIP ViT-L/14 |")
237
+ lines.append("")
238
+
239
+ # Group by precision
240
+ precisions = sorted(set(r.precision for r in results))
241
+
242
+ for precision in precisions:
243
+ prec_results = [r for r in results if r.precision == precision]
244
+
245
+ lines.append(f"## {precision.upper()} Precision")
246
+ lines.append("")
247
+ lines.append("| Samples | Steps | TopK | Offload | Gen Time (s) | Throughput (img/s) | Time/Step (s) | Peak Memory (GB) |")
248
+ lines.append("|---------|-------|------|---------|--------------|-------------------|---------------|------------------|")
249
+
250
+ for r in prec_results:
251
+ lines.append(
252
+ f"| {r.num_samples} | {r.num_steps} | {r.topk} | {r.offload} | "
253
+ f"{r.gen_time:.2f} | {r.throughput:.2f} | {r.time_per_step:.3f} | {r.peak_memory_gb:.2f} |"
254
+ )
255
+
256
+ lines.append("")
257
+
258
+ # Summary
259
+ if results:
260
+ lines.append("## 📈 Summary")
261
+ lines.append("")
262
+
263
+ fastest = min(results, key=lambda r: r.time_per_image)
264
+ most_efficient = min(results, key=lambda r: r.peak_memory_gb)
265
+
266
+ lines.append(f"- **🏆 Fastest Configuration:** {fastest.precision.upper()}, "
267
+ f"{fastest.num_samples} samples @ {fastest.throughput:.2f} img/s")
268
+ lines.append(f"- **💾 Most Memory Efficient:** {most_efficient.precision.upper()} "
269
+ f"with offload={most_efficient.offload} @ {most_efficient.peak_memory_gb:.1f} GB peak")
270
+ lines.append("")
271
+
272
+ # Recommendations
273
+ lines.append("## 🎯 Recommendations")
274
+ lines.append("")
275
+ lines.append("| Use Case | Precision | Offload | Expected Performance |")
276
+ lines.append("|----------|-----------|---------|---------------------|")
277
+
278
+ bf16_results = [r for r in results if r.precision == 'bf16' and r.offload == 0]
279
+ if bf16_results:
280
+ r = bf16_results[0]
281
+ lines.append(f"| **Production (Quality)** | BF16 | 0 | {r.throughput:.2f} img/s, {r.peak_memory_gb:.1f} GB |")
282
+
283
+ int8_results = [r for r in results if r.precision == 'int8' and r.offload == 0]
284
+ if int8_results:
285
+ r = int8_results[0]
286
+ lines.append(f"| **Balanced** | INT8 | 0 | {r.throughput:.2f} img/s, {r.peak_memory_gb:.1f} GB |")
287
+
288
+ offload_results = [r for r in results if r.offload > 0]
289
+ if offload_results:
290
+ r = min(offload_results, key=lambda x: x.peak_memory_gb)
291
+ lines.append(f"| **Low VRAM** | {r.precision.upper()} | {r.offload} | {r.throughput:.2f} img/s, {r.peak_memory_gb:.1f} GB |")
292
+
293
+ lines.append("")
294
+ lines.append("---")
295
+ lines.append("*Generated by Paris MoE Benchmark Utility*")
296
+
297
+ return "\n".join(lines)
298
+
299
+
300
+ # ═══════════════════════════════════════════════════════════════════════════════
301
+ # MAIN
302
+ # ═══════════════════════════════════════════════════════════════════════════════
303
+
304
+ def parse_args():
305
+ parser = argparse.ArgumentParser(
306
+ description="📊 Paris MoE - Benchmark Utility",
307
+ formatter_class=argparse.RawDescriptionHelpFormatter,
308
+ epilog="""
309
+ Examples:
310
+ python benchmark.py # Full benchmark suite
311
+ python benchmark.py --quick # Quick benchmark
312
+ python benchmark.py --precision bf16 # BF16 only
313
+ python benchmark.py --output results.md # Save to file
314
+ """
315
+ )
316
+
317
+ parser.add_argument("--quick", action="store_true",
318
+ help="Run quick benchmark with fewer configurations")
319
+ parser.add_argument("--precision", type=str, default=None,
320
+ choices=["bf16", "int8", "mixed"],
321
+ help="Benchmark specific precision only")
322
+ parser.add_argument("--output", "-o", type=str, default=None,
323
+ help="Output Markdown file path")
324
+ parser.add_argument("--samples", type=int, default=None,
325
+ help="Override number of samples")
326
+ parser.add_argument("--steps", type=int, default=None,
327
+ help="Override number of steps")
328
+
329
+ return parser.parse_args()
330
+
331
+
332
+ def get_benchmark_configs(args) -> List[Dict]:
333
+ """Get list of benchmark configurations to run."""
334
+ configs = []
335
+
336
+ if args.quick:
337
+ # Quick benchmark: minimal configs
338
+ precisions = [args.precision] if args.precision else ['bf16', 'int8']
339
+ samples = args.samples or 4
340
+ steps = args.steps or 10
341
+
342
+ for precision in precisions:
343
+ configs.append({
344
+ 'precision': precision,
345
+ 'num_samples': samples,
346
+ 'num_steps': steps,
347
+ 'topk': 1,
348
+ 'offload': 0
349
+ })
350
+ else:
351
+ # Full benchmark suite
352
+ precisions = [args.precision] if args.precision else ['bf16', 'int8']
353
+ samples_list = [args.samples] if args.samples else [4, 16]
354
+ steps_list = [args.steps] if args.steps else [20, 30]
355
+ topk_list = [1, 2]
356
+ offload_list = [0, 4]
357
+
358
+ for precision in precisions:
359
+ for samples in samples_list:
360
+ for steps in steps_list:
361
+ for topk in topk_list:
362
+ for offload in offload_list:
363
+ configs.append({
364
+ 'precision': precision,
365
+ 'num_samples': samples,
366
+ 'num_steps': steps,
367
+ 'topk': topk,
368
+ 'offload': offload
369
+ })
370
+
371
+ return configs
372
+
373
+
374
+ def main():
375
+ args = parse_args()
376
+
377
+ print("""
378
+ ╔══════════════════════════════════════════════════════════════════════════════╗
379
+ ║ ║
380
+ ║ 📊 Paris MoE - Comprehensive Benchmarking Utility 📊 ║
381
+ ║ ║
382
+ ║ Measuring performance across precision modes, batch sizes, and configs. ║
383
+ ║ ║
384
+ ╚══════════════════════════════════════════════════════════════════════════════╝
385
+ """)
386
+
387
+ device = "cuda" if torch.cuda.is_available() else "cpu"
388
+ if device != "cuda":
389
+ print("⚠️ Warning: Running on CPU. Benchmarks will be slow.")
390
+
391
+ gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
392
+ print(f"🖥️ Device: {gpu_name}")
393
+
394
+ configs = get_benchmark_configs(args)
395
+ print(f"📋 Running {len(configs)} benchmark configurations...\n")
396
+
397
+ results = []
398
+
399
+ for i, config in enumerate(configs):
400
+ print(f"[{i+1}/{len(configs)}] {config['precision'].upper()} | "
401
+ f"{config['num_samples']} samples | {config['num_steps']} steps | "
402
+ f"Top-{config['topk']} | Offload {config['offload']}")
403
+
404
+ try:
405
+ result = run_single_benchmark(
406
+ precision=config['precision'],
407
+ num_samples=config['num_samples'],
408
+ num_steps=config['num_steps'],
409
+ topk=config['topk'],
410
+ offload=config['offload'],
411
+ device=device
412
+ )
413
+ results.append(result)
414
+ print(f" ✅ {result.gen_time:.2f}s, {result.throughput:.2f} img/s, "
415
+ f"{result.peak_memory_gb:.1f} GB peak")
416
+ except Exception as e:
417
+ print(f" ❌ Failed: {e}")
418
+
419
+ print()
420
+
421
+ if not results:
422
+ print("❌ No successful benchmarks!")
423
+ return 1
424
+
425
+ # Print terminal results
426
+ terminal_output = format_terminal_results(results, gpu_name)
427
+ print(terminal_output)
428
+
429
+ # Save Markdown if requested
430
+ if args.output:
431
+ md_output = format_markdown_results(results, gpu_name)
432
+ with open(args.output, 'w') as f:
433
+ f.write(md_output)
434
+ print(f"\n✅ Results saved to: {args.output}")
435
+
436
+ return 0
437
+
438
+
439
+ if __name__ == "__main__":
440
+ exit(main())
benchmark_results.md ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 📊 Paris MoE Benchmark Results
2
+
3
+ **GPU:** NVIDIA RTX 6000 Ada Generation
4
+ **Date:** 2025-12-05 16:35:39
5
+
6
+ ## 🏗️ Model Architecture
7
+
8
+ | Component | Details |
9
+ |-----------|---------|
10
+ | Experts | 8× DiT-XL/2 (606M params each) |
11
+ | Router | DiT-B/2 (129M params) |
12
+ | Total | ~5 Billion parameters |
13
+ | VAE | SD-VAE (stabilityai/sd-vae-ft-mse) |
14
+ | Text Encoder | CLIP ViT-L/14 |
15
+
16
+ ## BF16 Precision
17
+
18
+ | Samples | Steps | TopK | Offload | Gen Time (s) | Throughput (img/s) | Time/Step (s) | Peak Memory (GB) |
19
+ |---------|-------|------|---------|--------------|-------------------|---------------|------------------|
20
+ | 4 | 10 | 1 | 0 | 1.49 | 2.68 | 0.149 | 10.79 |
21
+
22
+ ## INT8 Precision
23
+
24
+ | Samples | Steps | TopK | Offload | Gen Time (s) | Throughput (img/s) | Time/Step (s) | Peak Memory (GB) |
25
+ |---------|-------|------|---------|--------------|-------------------|---------------|------------------|
26
+ | 4 | 10 | 1 | 0 | 2.12 | 1.89 | 0.212 | 20.17 |
27
+
28
+ ## 📈 Summary
29
+
30
+ - **🏆 Fastest Configuration:** BF16, 4 samples @ 2.68 img/s
31
+ - **💾 Most Memory Efficient:** BF16 with offload=0 @ 10.8 GB peak
32
+
33
+ ## 🎯 Recommendations
34
+
35
+ | Use Case | Precision | Offload | Expected Performance |
36
+ |----------|-----------|---------|---------------------|
37
+ | **Production (Quality)** | BF16 | 0 | 2.68 img/s, 10.8 GB |
38
+ | **Balanced** | INT8 | 0 | 1.89 img/s, 20.2 GB |
39
+
40
+ ---
41
+ *Generated by Paris MoE Benchmark Utility*
generate.py ADDED
@@ -0,0 +1,747 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ ╔══════════════════════════════════════════════════════════════════════════════╗
4
+ ║ ║
5
+ ║ 🎨 Paris MoE - Unified Image Generation Script 🎨 ║
6
+ ║ ║
7
+ ║ Mixture-of-Experts Diffusion Model (8× DiT-XL/2 + DiT-B/2 Router) ║
8
+ ║ ~5 Billion Parameters Total ║
9
+ ║ ║
10
+ ╚══════════════════════════════════════════════════════════════════════════════╝
11
+
12
+ Supports multiple precision modes:
13
+ - bf16: Best quality, 9.3GB total (~1.2GB per expert)
14
+ - int8: Good quality, 4.8GB total (~580MB per expert), 15x compression
15
+ - mixed: Router in bf16, experts in int8 (balanced)
16
+
17
+ Memory Offloading:
18
+ - --offload N: Keep N experts in CPU memory, move to GPU only during computation
19
+ - Experts are moved to GPU → compute → moved back to CPU (memory offloading)
20
+ - All computation happens on GPU, only storage is on CPU
21
+
22
+ Usage:
23
+ python generate.py --prompt "a cute cat" --precision bf16
24
+ python generate.py --prompt "a sunset over mountains" --precision int8 --visualize
25
+ python generate.py --prompt "abstract art" --precision mixed --num_samples 4 --topk 2
26
+ """
27
+
28
+ import argparse
29
+ import sys
30
+ import os
31
+ import time
32
+ from pathlib import Path
33
+
34
+ # Add src to path for model imports
35
+ SCRIPT_DIR = Path(__file__).parent.absolute()
36
+ SRC_DIR = SCRIPT_DIR / "src"
37
+ sys.path.insert(0, str(SRC_DIR))
38
+
39
+ import torch
40
+ import torch.nn.functional as F
41
+ from tqdm import tqdm
42
+ from torchvision.utils import make_grid, save_image
43
+ from safetensors.torch import load_file
44
+ from safetensors import safe_open
45
+ from transformers import CLIPTextModel, CLIPTokenizer
46
+ from collections import defaultdict
47
+
48
+ # ═══════════════════════════════════════════════════════════════════════════════
49
+ # WEIGHT PATHS
50
+ # ═══════════════════════════════════════════════════════════════════════════════
51
+
52
+ WEIGHTS_DIR = SCRIPT_DIR / "weights"
53
+ BF16_DIR = WEIGHTS_DIR / "bf16"
54
+ INT8_DIR = WEIGHTS_DIR / "int8"
55
+
56
+
57
+ # ═══════════════════════════════════════════════════════════════════════════════
58
+ # ASCII VISUALIZATION
59
+ # ═══════════════════════════════════════════════════════════════════════════════
60
+
61
+ class ExpertTracker:
62
+ """Tracks which experts are used during generation for visualization."""
63
+
64
+ def __init__(self, num_experts: int = 8):
65
+ self.num_experts = num_experts
66
+ self.usage_counts = defaultdict(int)
67
+ self.per_step_primary = []
68
+ self.total_calls = 0
69
+
70
+ def record(self, expert_ids: torch.Tensor, step: int, weights: torch.Tensor = None):
71
+ """Record expert usage for a batch at a given step."""
72
+ step_counts = defaultdict(float)
73
+
74
+ if weights is not None:
75
+ for eid, w in zip(expert_ids.flatten().tolist(), weights.flatten().tolist()):
76
+ self.usage_counts[eid] += 1
77
+ step_counts[eid] += w
78
+ self.total_calls += 1
79
+ else:
80
+ for eid in expert_ids.tolist():
81
+ self.usage_counts[eid] += 1
82
+ step_counts[eid] += 1.0
83
+ self.total_calls += 1
84
+
85
+ if step_counts:
86
+ self.per_step_primary.append(max(step_counts, key=step_counts.get))
87
+
88
+ def get_usage_chart(self) -> str:
89
+ """Chart 1: Expert usage ranked by frequency."""
90
+ if self.total_calls == 0:
91
+ return ""
92
+
93
+ # Sort experts by usage
94
+ sorted_experts = sorted(range(8), key=lambda e: self.usage_counts.get(e, 0), reverse=True)
95
+ max_count = max(self.usage_counts.values()) if self.usage_counts else 1
96
+ unique = sum(1 for e in range(8) if self.usage_counts.get(e, 0) > 0)
97
+
98
+ lines = [
99
+ "",
100
+ "╭────────────────��─────────────────────────────────╮",
101
+ "│ ⚡ EXPERT USAGE DISTRIBUTION │",
102
+ "├──────────────────────────────────────────────────┤",
103
+ ]
104
+
105
+ bars = ['▏', '▎', '▍', '▌', '▋', '▊', '▉', '█']
106
+
107
+ for eid in sorted_experts:
108
+ count = self.usage_counts.get(eid, 0)
109
+ pct = 100 * count / self.total_calls if self.total_calls > 0 else 0
110
+
111
+ # Create gradient bar
112
+ bar_width = 28
113
+ fill = (count / max_count) * bar_width if max_count > 0 else 0
114
+ full_blocks = int(fill)
115
+ partial = int((fill - full_blocks) * 8)
116
+
117
+ bar = '█' * full_blocks
118
+ if partial > 0 and full_blocks < bar_width:
119
+ bar += bars[partial - 1]
120
+ bar = bar.ljust(bar_width, ' ')
121
+
122
+ marker = "→" if count == max_count and count > 0 else " "
123
+ lines.append(f"│ {marker} E{eid} │{bar}│ {pct:5.1f}% │")
124
+
125
+ lines.extend([
126
+ "├──────────────────────────────────────────────────┤",
127
+ f"│ Active: {unique}/8 experts Calls: {self.total_calls:<13} │",
128
+ "╰──────────────────────────────────────────────────╯",
129
+ ])
130
+
131
+ return "\n".join(lines)
132
+
133
+ def get_timeline(self) -> str:
134
+ """Chart 2: Visual timeline of expert selection per step."""
135
+ if not self.per_step_primary:
136
+ return ""
137
+
138
+ num_steps = len(self.per_step_primary)
139
+ show_steps = min(20, num_steps)
140
+
141
+ # Count transitions
142
+ transitions = sum(1 for i in range(1, num_steps)
143
+ if self.per_step_primary[i] != self.per_step_primary[i-1])
144
+
145
+ lines = [
146
+ "",
147
+ "╭──────────────────────────────────────────────────╮",
148
+ "│ 📈 ROUTING TIMELINE │",
149
+ "├──────────────────────────────────────────────────┤",
150
+ ]
151
+
152
+ # Compact step numbers
153
+ step_row = "│ Step "
154
+ for i in range(show_steps):
155
+ step_row += f"{i:2d} "
156
+ if num_steps > 20:
157
+ step_row = step_row[:48] + "..│"
158
+ else:
159
+ step_row = step_row[:48].ljust(48) + " │"
160
+ lines.append(step_row)
161
+
162
+ lines.append("│ " + "───" * 16 + " │")
163
+
164
+ # Show each expert's timeline
165
+ symbols = ['○', '●']
166
+ for eid in range(self.num_experts):
167
+ row = f"│ E{eid} "
168
+ for step in range(show_steps):
169
+ if self.per_step_primary[step] == eid:
170
+ row += " ● "
171
+ else:
172
+ row += " · "
173
+ if num_steps > 20:
174
+ row = row[:48] + "..│"
175
+ else:
176
+ row = row[:48].ljust(48) + " │"
177
+ lines.append(row)
178
+
179
+ lines.extend([
180
+ "├──────────────────────────────────────────────────┤",
181
+ f"│ Routing changes: {transitions:>3}/{num_steps-1:<3} steps ({100*transitions/(num_steps-1):.0f}%) │",
182
+ "╰──────────────────────────────────────────────────╯",
183
+ ])
184
+
185
+ return "\n".join(lines)
186
+
187
+
188
+ # ═══════════════════════════════════════════════════════════════════════════════
189
+ # INT8 DEQUANTIZATION
190
+ # ═══════════════════════════════════════════════════════════════════════════════
191
+
192
+ def dequantize_tensor(int8_tensor: torch.Tensor, t_min: float, t_max: float) -> torch.Tensor:
193
+ """Dequantize INT8 tensor back to float32."""
194
+ if t_min == t_max:
195
+ return torch.full_like(int8_tensor, t_min, dtype=torch.float32)
196
+ normalized = (int8_tensor.float() + 128) / 255.0
197
+ return normalized * (t_max - t_min) + t_min
198
+
199
+
200
+ def load_int8_state_dict(safetensors_path: Path) -> dict:
201
+ """Load and dequantize INT8 safetensors to float32 state_dict."""
202
+ state_dict = {}
203
+
204
+ with safe_open(str(safetensors_path), framework="pt", device="cpu") as f:
205
+ keys = list(f.keys())
206
+
207
+ # Find quantized tensors (those with _min/_max companions)
208
+ quantized_keys = set()
209
+ for key in keys:
210
+ if key.endswith('._min'):
211
+ base_key = key[:-5]
212
+ quantized_keys.add(base_key)
213
+
214
+ # Load and dequantize
215
+ for key in keys:
216
+ # Skip metadata and quantization params
217
+ if key.endswith('._min') or key.endswith('._max'):
218
+ continue
219
+ if key == '_config_json':
220
+ continue
221
+
222
+ tensor = f.get_tensor(key)
223
+
224
+ if key in quantized_keys:
225
+ # Dequantize INT8 tensor
226
+ t_min = f.get_tensor(f"{key}._min").item()
227
+ t_max = f.get_tensor(f"{key}._max").item()
228
+ tensor = dequantize_tensor(tensor, t_min, t_max)
229
+
230
+ state_dict[key] = tensor
231
+
232
+ return state_dict
233
+
234
+
235
+ # ═══════════════════════════════════════════════════════════════════════════════
236
+ # MODEL CREATION
237
+ # ═══════════════════════════════════════════════════════════════════════════════
238
+
239
+ def create_expert(config, expert_id: int = 0):
240
+ """Create a DiT expert model."""
241
+ from models import DiTExpert
242
+ return DiTExpert(config)
243
+
244
+
245
+ def create_router(config):
246
+ """Create a DiT router model."""
247
+ from models import DiTRouter
248
+ return DiTRouter(config)
249
+
250
+
251
+ # ═══════════════════════════════════════════════════════════════════════════════
252
+ # SAMPLER CLASS
253
+ # ═══════════════════════════════════════════════════════════════════════════════
254
+
255
+ class ParisSampler:
256
+ """Unified sampler for Paris MoE model with expert tracking."""
257
+
258
+ def __init__(self, experts: dict, router, vae_manager, config, device='cuda',
259
+ offloaded_experts: set = None):
260
+ self.experts = experts
261
+ self.router = router
262
+ self.vae_manager = vae_manager
263
+ self.config = config
264
+ self.device = device
265
+ self.tracker = None
266
+ self.offloaded_experts = offloaded_experts or set() # Which experts are on CPU
267
+
268
+ # Set models to eval mode
269
+ for expert in self.experts.values():
270
+ expert.eval()
271
+ if self.router is not None:
272
+ self.router.eval()
273
+
274
+ # Precompute null embeddings for CFG
275
+ self._precompute_null_embeddings()
276
+
277
+ def _precompute_null_embeddings(self):
278
+ """Precompute null embeddings for classifier-free guidance."""
279
+ try:
280
+ text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
281
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
282
+ text_encoder = text_encoder.to(self.device)
283
+ text_encoder.eval()
284
+
285
+ with torch.no_grad():
286
+ null_tokens = tokenizer(
287
+ [""],
288
+ max_length=77,
289
+ padding='max_length',
290
+ truncation=True,
291
+ return_tensors='pt'
292
+ )
293
+ self.null_text_embeds = text_encoder(null_tokens.input_ids.to(self.device)).last_hidden_state
294
+ self.null_attention_mask = null_tokens.attention_mask.to(self.device)
295
+
296
+ del text_encoder, tokenizer
297
+ torch.cuda.empty_cache()
298
+ except Exception as e:
299
+ print(f"Warning: Could not precompute null text embeddings: {e}")
300
+ self.null_text_embeds = None
301
+ self.null_attention_mask = None
302
+
303
+ def _encode_text_prompts(self, text_prompts: list):
304
+ """Encode text prompts using CLIP."""
305
+ text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
306
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
307
+
308
+ text_encoder = text_encoder.to(self.device)
309
+ text_encoder.eval()
310
+
311
+ tokenizer_output = tokenizer(
312
+ text_prompts,
313
+ max_length=77,
314
+ padding='max_length',
315
+ truncation=True,
316
+ return_tensors='pt'
317
+ )
318
+ tokens = tokenizer_output.input_ids.to(self.device)
319
+ attention_mask = tokenizer_output.attention_mask.to(self.device)
320
+
321
+ with torch.no_grad():
322
+ text_embeds = text_encoder(tokens).last_hidden_state
323
+
324
+ del text_encoder, tokenizer
325
+ torch.cuda.empty_cache()
326
+
327
+ return text_embeds, attention_mask
328
+
329
+ def _move_expert_to_gpu(self, expert_id: int):
330
+ """Move an offloaded expert to GPU for computation."""
331
+ if expert_id in self.offloaded_experts:
332
+ self.experts[expert_id] = self.experts[expert_id].to(self.device)
333
+ torch.cuda.synchronize() # Ensure transfer is complete
334
+
335
+ def _move_expert_to_cpu(self, expert_id: int):
336
+ """Move expert back to CPU after computation (memory offloading)."""
337
+ if expert_id in self.offloaded_experts:
338
+ self.experts[expert_id] = self.experts[expert_id].cpu()
339
+ torch.cuda.empty_cache() # Free GPU memory immediately
340
+
341
+ def _run_expert_with_cfg(self, expert_id: int, samples: torch.Tensor, t: torch.Tensor,
342
+ text_embeds: torch.Tensor, attention_mask: torch.Tensor,
343
+ null_embeds: torch.Tensor, null_mask: torch.Tensor,
344
+ cfg_scale: float) -> torch.Tensor:
345
+ """Run expert inference with optional CFG, handling memory offloading."""
346
+ # Move to GPU if offloaded
347
+ self._move_expert_to_gpu(expert_id)
348
+
349
+ expert = self.experts[expert_id]
350
+
351
+ try:
352
+ if cfg_scale != 1.0:
353
+ v_cond = expert(samples, t, text_embeds, attention_mask)
354
+ v_uncond = expert(samples, t, null_embeds, null_mask)
355
+ v_pred = v_uncond + cfg_scale * (v_cond - v_uncond)
356
+ else:
357
+ v_pred = expert(samples, t, text_embeds, attention_mask)
358
+
359
+ return v_pred
360
+ finally:
361
+ # Move back to CPU if it was offloaded (memory offloading)
362
+ self._move_expert_to_cpu(expert_id)
363
+
364
+ def sample(self, num_samples: int, text_prompts: list, cfg_scale: float = 7.5,
365
+ num_steps: int = 30, use_bf16: bool = True, track_experts: bool = False,
366
+ topk: int = 1):
367
+ """
368
+ Generate samples using expert routing.
369
+
370
+ Args:
371
+ num_samples: Number of images to generate
372
+ text_prompts: List of text prompts
373
+ cfg_scale: Classifier-free guidance scale
374
+ num_steps: Number of sampling steps
375
+ use_bf16: Use bfloat16 precision
376
+ track_experts: Track and visualize expert usage
377
+ topk: Number of experts to use per sample (1=top-1, 2=top-2, etc.)
378
+ """
379
+ # Initialize tracker if requested
380
+ if track_experts:
381
+ self.tracker = ExpertTracker(num_experts=8)
382
+ else:
383
+ self.tracker = None
384
+
385
+ text_embeds, attention_mask = self._encode_text_prompts(text_prompts)
386
+
387
+ latent_size = self.config.image_size
388
+ channels = 4
389
+ dtype = torch.bfloat16 if use_bf16 else torch.float32
390
+
391
+ # Start with random noise
392
+ samples = torch.randn(
393
+ num_samples, channels, latent_size, latent_size,
394
+ device=self.device, dtype=dtype
395
+ )
396
+
397
+ # Convert text embeds to appropriate dtype
398
+ text_embeds = text_embeds.to(dtype)
399
+ if self.null_text_embeds is not None:
400
+ null_text_embeds = self.null_text_embeds.to(dtype)
401
+ null_attention_mask = self.null_attention_mask
402
+
403
+ dt = 1.0 / num_steps
404
+
405
+ autocast_ctx = torch.amp.autocast(device_type='cuda', dtype=dtype) if use_bf16 else torch.no_grad()
406
+
407
+ with torch.no_grad(), autocast_ctx:
408
+ for i in tqdm(range(num_steps), desc="🎨 Generating"):
409
+ t = torch.ones(num_samples, device=self.device) * (1.0 - i * dt)
410
+
411
+ # Expand text embeddings if needed
412
+ batch_text_embeds = text_embeds.expand(num_samples, -1, -1) if text_embeds.shape[0] == 1 else text_embeds[:num_samples]
413
+ batch_attention_mask = attention_mask.expand(num_samples, -1) if attention_mask.shape[0] == 1 else attention_mask[:num_samples]
414
+
415
+ # Get router predictions (router expects float32)
416
+ with torch.amp.autocast(device_type='cuda', enabled=False):
417
+ router_logits = self.router(samples.float(), t.float())
418
+ expert_probs = F.softmax(router_logits, dim=1)
419
+
420
+ if topk == 1:
421
+ # Top-1 routing
422
+ expert_choices = torch.argmax(expert_probs, dim=1)
423
+
424
+ # Track expert usage
425
+ if self.tracker is not None:
426
+ self.tracker.record(expert_choices, i)
427
+
428
+ # Predict velocity for each sample using selected expert
429
+ v_pred = torch.zeros_like(samples)
430
+
431
+ for expert_id in range(8):
432
+ mask = (expert_choices == expert_id)
433
+ if mask.any():
434
+ mask_size = mask.sum().item()
435
+ null_embeds = null_text_embeds.expand(mask_size, -1, -1)
436
+ null_mask = null_attention_mask.expand(mask_size, -1)
437
+
438
+ v_batch = self._run_expert_with_cfg(
439
+ expert_id,
440
+ samples[mask], t[mask],
441
+ batch_text_embeds[mask], batch_attention_mask[mask],
442
+ null_embeds, null_mask,
443
+ cfg_scale
444
+ )
445
+ v_pred[mask] = v_batch
446
+ else:
447
+ # Top-K routing with weighted ensemble
448
+ topk_probs, topk_indices = torch.topk(expert_probs, k=min(topk, 8), dim=1)
449
+ topk_probs = topk_probs / topk_probs.sum(dim=1, keepdim=True) # Renormalize
450
+
451
+ # Track expert usage
452
+ if self.tracker is not None:
453
+ self.tracker.record(topk_indices, i, topk_probs)
454
+
455
+ v_pred = torch.zeros_like(samples)
456
+
457
+ # Process each sample
458
+ for sample_idx in range(num_samples):
459
+ v_sample = torch.zeros(channels, latent_size, latent_size,
460
+ device=self.device, dtype=dtype)
461
+
462
+ for k_idx in range(topk_indices.shape[1]):
463
+ expert_id = topk_indices[sample_idx, k_idx].item()
464
+ weight = topk_probs[sample_idx, k_idx].item()
465
+
466
+ null_embeds = null_text_embeds
467
+ null_mask = null_attention_mask
468
+
469
+ v_expert = self._run_expert_with_cfg(
470
+ expert_id,
471
+ samples[sample_idx:sample_idx+1],
472
+ t[sample_idx:sample_idx+1],
473
+ batch_text_embeds[sample_idx:sample_idx+1],
474
+ batch_attention_mask[sample_idx:sample_idx+1],
475
+ null_embeds, null_mask,
476
+ cfg_scale
477
+ )
478
+
479
+ v_sample += weight * v_expert.squeeze(0)
480
+
481
+ v_pred[sample_idx] = v_sample
482
+
483
+ # Euler integration step
484
+ samples = samples - v_pred * dt
485
+
486
+ return samples.float()
487
+
488
+
489
+ # ═══════════════════════════════════════════════════════════════════════════════
490
+ # MODEL LOADING
491
+ # ═══════════════════════════════════════════════════════════════════════════════
492
+
493
+ def load_sampler(precision: str = 'bf16', device: str = 'cuda', offload: int = 0):
494
+ """
495
+ Load Paris MoE sampler with specified precision.
496
+
497
+ Args:
498
+ precision: Weight precision ('bf16', 'int8', 'mixed')
499
+ device: Compute device ('cuda' or 'cpu')
500
+ offload: Number of experts to keep in CPU memory (0-7)
501
+ These experts will be moved to GPU only during computation.
502
+ """
503
+ from vae_utils import VAEManager
504
+
505
+ # Determine weight directories based on precision
506
+ if precision == 'bf16':
507
+ expert_dir = BF16_DIR
508
+ router_dir = BF16_DIR
509
+ use_int8_experts = False
510
+ elif precision == 'int8':
511
+ expert_dir = INT8_DIR
512
+ router_dir = BF16_DIR # Router always from bf16
513
+ use_int8_experts = True
514
+ elif precision == 'mixed':
515
+ expert_dir = INT8_DIR
516
+ router_dir = BF16_DIR
517
+ use_int8_experts = True
518
+ else:
519
+ raise ValueError(f"Unknown precision: {precision}. Use 'bf16', 'int8', or 'mixed'.")
520
+
521
+ # Load config
522
+ config_path = BF16_DIR / 'config.pt'
523
+ config_data = torch.load(config_path, map_location='cpu', weights_only=False)
524
+ config = config_data['config']
525
+
526
+ # Load router config
527
+ router_config_path = BF16_DIR / 'router_config.pt'
528
+ router_config_data = torch.load(router_config_path, map_location='cpu', weights_only=False)
529
+ router_config = router_config_data['config']
530
+
531
+ # Update config with router params
532
+ config.router_architecture = router_config.router_architecture
533
+ config.router_params = router_config.router_params
534
+
535
+ # Load router (always on GPU, bf16/float32)
536
+ print("📡 Loading router...")
537
+ router = create_router(config).to(device)
538
+ router_weights = load_file(str(router_dir / 'router.safetensors'))
539
+ router_weights = {k: v.float() for k, v in router_weights.items()}
540
+ router.load_state_dict(router_weights)
541
+ router.eval()
542
+
543
+ # Determine which experts to offload
544
+ # Offload the LAST N experts (highest IDs)
545
+ offloaded_experts = set(range(8 - offload, 8)) if offload > 0 else set()
546
+
547
+ # Load experts
548
+ experts = {}
549
+ for i in range(8):
550
+ print(f"🧠 Loading expert {i}...", end="")
551
+ expert = create_expert(config, expert_id=i)
552
+
553
+ if use_int8_experts:
554
+ expert_weights = load_int8_state_dict(expert_dir / f'expert_{i}.safetensors')
555
+ else:
556
+ expert_weights = load_file(str(expert_dir / f'expert_{i}.safetensors'))
557
+
558
+ expert.load_state_dict(expert_weights)
559
+ expert.eval()
560
+
561
+ # Convert to bf16 if using bf16 precision
562
+ if precision == 'bf16':
563
+ expert = expert.to(torch.bfloat16)
564
+
565
+ # Decide where to place the expert
566
+ if i in offloaded_experts:
567
+ expert = expert.cpu() # Keep in CPU memory
568
+ print(f" 💾 (CPU memory, GPU compute)")
569
+ else:
570
+ expert = expert.to(device) # Keep on GPU
571
+ print(f" 🎮 (GPU)")
572
+
573
+ experts[i] = expert
574
+
575
+ # Load VAE
576
+ print("🖼️ Loading VAE...")
577
+ vae_manager = VAEManager(device=device)
578
+
579
+ return ParisSampler(experts, router, vae_manager, config, device, offloaded_experts)
580
+
581
+
582
+ # ═══════════════════════════════════════════════════════════════════════════════
583
+ # MAIN ENTRYPOINT
584
+ # ═══════════════════════════════════════════════════════════════════════════════
585
+
586
+ def parse_args():
587
+ parser = argparse.ArgumentParser(
588
+ description="🎨 Paris MoE - Image Generation",
589
+ formatter_class=argparse.RawDescriptionHelpFormatter,
590
+ epilog="""
591
+ Examples:
592
+ python generate.py --prompt "a cute cat playing piano"
593
+ python generate.py --prompt "sunset over mountains" --precision int8 --visualize
594
+ python generate.py --prompt "abstract art" --num_samples 4 --cfg_scale 10 --topk 2
595
+ python generate.py --prompt "portrait" --offload 4 # Offload 4 experts to CPU memory
596
+ """
597
+ )
598
+
599
+ parser.add_argument("--prompt", type=str, default="a cute cat",
600
+ help="Text prompt for generation")
601
+ parser.add_argument("--num_samples", type=int, default=16,
602
+ help="Number of images to generate (default: 16)")
603
+ parser.add_argument("--cfg_scale", type=float, default=7.5,
604
+ help="Classifier-free guidance scale (default: 7.5)")
605
+ parser.add_argument("--num_steps", type=int, default=30,
606
+ help="Number of sampling steps (default: 30)")
607
+ parser.add_argument("--seed", type=int, default=999,
608
+ help="Random seed for reproducibility")
609
+ parser.add_argument("--output", type=str, default=None,
610
+ help="Output image path (default: output_<precision>.png)")
611
+ parser.add_argument("--precision", type=str, default="bf16",
612
+ choices=["bf16", "int8", "mixed"],
613
+ help="Weight precision: bf16, int8, or mixed (default: bf16)")
614
+ parser.add_argument("--offload", type=int, default=0,
615
+ help="Number of experts to keep in CPU memory (0-7). Computation still on GPU.")
616
+ parser.add_argument("--topk", type=int, default=2,
617
+ help="Top-K expert routing (1=top-1, 2=top-2 ensemble, etc.) [default: 2]")
618
+ parser.add_argument("--visualize", action="store_true",
619
+ help="Show expert usage visualization")
620
+ parser.add_argument("--no-save", action="store_true",
621
+ help="Don't save output image (for testing)")
622
+
623
+ return parser.parse_args()
624
+
625
+
626
+ def print_header():
627
+ """Print beautiful ASCII header."""
628
+ print("""
629
+ ╔══════════════════════════════════════════════════════════════════════════════╗
630
+ ║ ║
631
+ ║ ██████╗ █████╗ ██████╗ ██╗███████╗ ███╗ ███╗ ██████╗ ███████╗ ║
632
+ ║ ██╔══██╗██╔══██╗██╔══██╗██║██╔════╝ ████╗ ████║██╔═══██╗██╔════╝ ║
633
+ ║ ██████╔╝███████║██████╔╝██║███████╗ ██╔████╔██║██║ ██║█████╗ ║
634
+ ║ ██╔═══╝ ██╔══██║██╔══██╗██║╚════██║ ██║╚██╔╝██║██║ ██║██╔══╝ ║
635
+ ║ ██║ ██║ ██║██║ ██║██║███████║ ██║ ╚═╝ ██║╚██████╔╝███████╗ ║
636
+ ║ ╚═╝ ╚═╝ ╚═╝╚═╝ ╚═╝╚═╝╚══════╝ ╚═╝ ╚═╝ ╚═════╝ ╚══════╝ ║
637
+ ║ ║
638
+ ║ 🎨 Mixture-of-Experts Text-to-Image Diffusion Model ║
639
+ ║ 📊 8× DiT-XL/2 Experts + DiT-B/2 Router (~5B Parameters) ║
640
+ ║ ║
641
+ ╚══════════════════════════════════════════════════════════════════════════════╝
642
+ """)
643
+
644
+
645
+ def print_config(args):
646
+ """Print configuration summary."""
647
+ offload_str = f"{args.offload} experts (CPU mem, GPU compute)" if args.offload > 0 else "None"
648
+ topk_str = f"Top-{args.topk}" if args.topk > 1 else "Top-1"
649
+
650
+ print(f"""
651
+ ┌──────────────────────────────────────────────────────────────────────────────┐
652
+ │ 📋 Configuration │
653
+ ├──────────────────────────────────────────────────────────────────────────────┤
654
+ │ Prompt: {args.prompt[:50]:<50}│
655
+ │ Samples: {args.num_samples:<50}│
656
+ │ Steps: {args.num_steps:<50}│
657
+ │ CFG Scale: {args.cfg_scale:<50}│
658
+ │ Precision: {args.precision.upper():<50}│
659
+ │ Routing: {topk_str:<50}│
660
+ │ Seed: {args.seed:<50}│
661
+ │ Offload: {offload_str:<50}│
662
+ └──────────────────────────────────────────────────────────────────────────────┘
663
+ """)
664
+
665
+
666
+ def main():
667
+ args = parse_args()
668
+
669
+ # Print header
670
+ print_header()
671
+ print_config(args)
672
+
673
+ # Set seed
674
+ torch.manual_seed(args.seed)
675
+ if torch.cuda.is_available():
676
+ torch.cuda.manual_seed(args.seed)
677
+
678
+ device = "cuda" if torch.cuda.is_available() else "cpu"
679
+ print(f"🖥️ Using device: {device}")
680
+
681
+ # Load sampler
682
+ print(f"\n📦 Loading {args.precision.upper()} weights...")
683
+ start_load = time.time()
684
+ sampler = load_sampler(
685
+ precision=args.precision,
686
+ device=device,
687
+ offload=args.offload
688
+ )
689
+ load_time = time.time() - start_load
690
+ print(f"⏱️ Model loaded in {load_time:.1f}s")
691
+
692
+ # Generate samples
693
+ print(f"\n🎨 Generating {args.num_samples} images...")
694
+ start_gen = time.time()
695
+ latents = sampler.sample(
696
+ num_samples=args.num_samples,
697
+ text_prompts=[args.prompt],
698
+ cfg_scale=args.cfg_scale,
699
+ num_steps=args.num_steps,
700
+ use_bf16=(args.precision == 'bf16'),
701
+ track_experts=args.visualize,
702
+ topk=args.topk
703
+ )
704
+ gen_time = time.time() - start_gen
705
+
706
+ # Show visualization if requested
707
+ if args.visualize and sampler.tracker is not None:
708
+ print(sampler.tracker.get_usage_chart())
709
+ print(sampler.tracker.get_timeline())
710
+
711
+ # Decode latents
712
+ print("\n🖼️ Decoding latents...")
713
+ start_decode = time.time()
714
+ images = sampler.vae_manager.decode(latents)
715
+ images = (images + 1.0) / 2.0
716
+ images = torch.clamp(images, 0, 1)
717
+ decode_time = time.time() - start_decode
718
+
719
+ # Save output
720
+ if not args.no_save:
721
+ output_path = args.output or f"output_{args.precision}.png"
722
+ nrow = 4 if args.num_samples >= 4 else args.num_samples
723
+ grid = make_grid(images.cpu(), nrow=nrow, normalize=False, padding=2)
724
+ save_image(grid, output_path)
725
+ print(f"\n✅ Saved to: {output_path}")
726
+
727
+ # Print timing summary
728
+ total_time = load_time + gen_time + decode_time
729
+ throughput = args.num_samples / gen_time
730
+
731
+ print(f"""
732
+ ╔══════════════════════════════════════════════════════════════════════════════╗
733
+ ║ ⏱️ Timing Summary ⏱️ ║
734
+ ╠══════════════════════════════════════════════════════════════════════════════╣
735
+ ║ Model loading: {load_time:>6.1f}s ║
736
+ ║ Generation: {gen_time:>6.1f}s ({throughput:.2f} img/s, {gen_time/args.num_steps:.2f}s/step) ║
737
+ ║ VAE decoding: {decode_time:>6.1f}s ║
738
+ ║ ────────────────────────────── ║
739
+ ║ Total: {total_time:>6.1f}s ║
740
+ ╚══════════════════════════════════════════════════════════════════════════════╝
741
+ """)
742
+
743
+ print("🎉 Done!")
744
+
745
+
746
+ if __name__ == "__main__":
747
+ main()
instructions.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ What we're now is we're going to prepare an inference folder and we're going to make an inference repository for our Paris model. This will include, we will just stick with int8 and bfloat16 and mixed int8, bfloat16 for now. And this repository will include efficient methods how to run the code. It will include quantization code that can accept either PT or save tensors or Float32 save tensors or Float32 PT. It will include a lot of different methods that can accept either PT or Float32 PT. We need to make a visualizer next which outputs a little pretty ASCII chart. We should output the ASCII chart right on the terminal every time we run the inference via this tool. Let's just say we're running the int8 inference of the mixed int8 model. By the way, we're also going to put the weights that we quantized inside this inference folder because we're going to publish this on HuggingFace. have just again, the beef flow 16 and intake weights. we might already be done this by the way. But again, I wanted to do that when we have to keep some kind of track and output a chart in the terminal, like as a little terminal visualization in ASCII. MAKE SURE WE'RE DOING ROUTING PROPERLY. Top 2 etc. Again, just to recap, we're going to make a folder that's just called inference. In this folder, we're going to put the quantized weights that we already made, because we already made them before in the last session. So the bfloat16 and the int8 weights. And we're going to put one Python file for the inference code, and it's going to have all the flags, and it's also going to have a visualized flag. And the visualized flag is actually a lot more than that, because it keeps track of which expert is being used during each inference step, and that shows like a little pretty chart. So if we're generating with 30 steps, which is going to show which experts got to use the most and the least out of eight of them. And so we want to have this in the inference code. Make sure to read files in full before like a pass inference code that we already wrote. Try to list like the most recent files that we made for that. And we also want to have the quantization code to just be an all in one utility with a very nice terminal interface as well, because we want the quantization code to be able to handle float 16 bfloat 16 float 32 weights in both safe tensors and in dot pt format. So that needs to be very smart and tested that it actually works. And also, yeah, make a read me in this folder for the Paris model, because we're going to publish this on hogging face as the inference repository. And then also read all the MD files that we have written here in full because after we do all of this and after we test that it works and it differences fine. We're going to we're going to start to play around with network inference. So that's going to be the fun next step after. So again, make a 20 point to do this for this and please make sure to include at least four or five sentences per point. So the to do list is going to be very long, naturally and very detailed. But I believe we're going to do an excellent, excellent job here.
quantize.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ ╔══════════════════════════════════════════════════════════════════════════════╗
4
+ ║ ║
5
+ ║ 🔧 Paris MoE - Weight Quantization Utility 🔧 ║
6
+ ║ ║
7
+ ║ Converts weights between formats: ║
8
+ ║ • Input: .pt (PyTorch) or .safetensors (F32 or BF16) ║
9
+ ║ • Output: BF16 or INT8 safetensors ║
10
+ ║ ║
11
+ ╚══════════════════════════════════════════════════════════════════════════════╝
12
+
13
+ Usage:
14
+ # Convert original .pt files to BF16 safetensors
15
+ python quantize.py --input /path/to/weights/ --output ./weights/bf16 --format bf16
16
+
17
+ # Convert to INT8 safetensors
18
+ python quantize.py --input /path/to/weights/ --output ./weights/int8 --format int8
19
+
20
+ # Convert from existing safetensors (bf16 -> int8)
21
+ python quantize.py --input ./weights/bf16 --output ./weights/int8 --format int8
22
+
23
+ Input Formats Supported:
24
+ - PyTorch .pt files (original training checkpoints)
25
+ - SafeTensors .safetensors files (F32 or BF16)
26
+
27
+ Output Formats:
28
+ - bf16: BFloat16 safetensors (best quality, ~1.2GB per expert)
29
+ - int8: INT8 quantized safetensors (~580MB per expert)
30
+ """
31
+
32
+ import argparse
33
+ import os
34
+ import gc
35
+ from pathlib import Path
36
+ from typing import Dict, Optional, Tuple
37
+ import json
38
+
39
+ import torch
40
+ from safetensors.torch import save_file, load_file
41
+ from safetensors import safe_open
42
+ from tqdm import tqdm
43
+
44
+
45
+ # ═══════════════════════════════════════════════════════════════════════════════
46
+ # FILE DETECTION
47
+ # ═══════════════════════════════════════════════════════════════════════════════
48
+
49
+ def detect_input_format(input_dir: Path) -> Tuple[str, Dict[str, Path]]:
50
+ """
51
+ Detect input format and locate weight files.
52
+
53
+ Returns:
54
+ format: 'pt' or 'safetensors'
55
+ files: Dict mapping 'expert_0'..'expert_7', 'router' to file paths
56
+ """
57
+ files = {}
58
+
59
+ # Check for PyTorch .pt files (original format)
60
+ pt_patterns = [
61
+ # Pattern 1: Full training checkpoint names
62
+ ("dit_xl2_multi_expert_pretrained_text_new_dataset_expert_{}_best.pt", "expert_{}"),
63
+ ("laion_router_preclustered_dit_berthead_b2_improved_router_best.pt", "router"),
64
+ # Pattern 2: Simple names
65
+ ("expert_{}_best.pt", "expert_{}"),
66
+ ("expert_{}.pt", "expert_{}"),
67
+ ("router_best.pt", "router"),
68
+ ("router.pt", "router"),
69
+ ]
70
+
71
+ # Check for SafeTensors files
72
+ st_patterns = [
73
+ ("expert_{}.safetensors", "expert_{}"),
74
+ ("router.safetensors", "router"),
75
+ ]
76
+
77
+ # Try PyTorch patterns first
78
+ for pattern, key_pattern in pt_patterns:
79
+ if "{}" in pattern:
80
+ # Expert pattern
81
+ for i in range(8):
82
+ filename = pattern.format(i)
83
+ filepath = input_dir / filename
84
+ if filepath.exists():
85
+ key = key_pattern.format(i)
86
+ files[key] = filepath
87
+ else:
88
+ # Router pattern
89
+ filepath = input_dir / pattern
90
+ if filepath.exists():
91
+ files[key_pattern] = filepath
92
+
93
+ if len(files) >= 8: # At least 8 experts found
94
+ return 'pt', files
95
+
96
+ # Try SafeTensors patterns
97
+ files = {}
98
+ for pattern, key_pattern in st_patterns:
99
+ if "{}" in pattern:
100
+ for i in range(8):
101
+ filename = pattern.format(i)
102
+ filepath = input_dir / filename
103
+ if filepath.exists():
104
+ key = key_pattern.format(i)
105
+ files[key] = filepath
106
+ else:
107
+ filepath = input_dir / pattern
108
+ if filepath.exists():
109
+ files[key_pattern] = filepath
110
+
111
+ if len(files) >= 8:
112
+ return 'safetensors', files
113
+
114
+ # List what we found
115
+ print(f"Found files in {input_dir}:")
116
+ for f in sorted(input_dir.glob("*")):
117
+ print(f" {f.name}")
118
+
119
+ raise ValueError(f"Could not find weight files in {input_dir}")
120
+
121
+
122
+ # ═══════════════════════════════════════════════════════════════════════════════
123
+ # LOADING UTILITIES
124
+ # ═══════════════════════════════════════════════════════════════════════════════
125
+
126
+ def load_pt_expert(filepath: Path, expert_id: int) -> Tuple[dict, Optional[object]]:
127
+ """
128
+ Load expert weights from PyTorch checkpoint.
129
+
130
+ Returns:
131
+ state_dict: Model weights
132
+ config: Config object if available
133
+ """
134
+ print(f" Loading {filepath.name}...")
135
+ ckpt = torch.load(filepath, map_location='cpu', weights_only=False)
136
+
137
+ # Try EMA weights first (preferred for inference)
138
+ ema_key = f'expert_{expert_id}_ema_state_dict'
139
+ regular_key = f'expert_{expert_id}_state_dict'
140
+
141
+ if ema_key in ckpt:
142
+ state_dict = ckpt[ema_key]
143
+ print(f" Using EMA weights")
144
+ elif regular_key in ckpt:
145
+ state_dict = ckpt[regular_key]
146
+ print(f" Using regular weights (no EMA)")
147
+ else:
148
+ # Try to find any state dict key
149
+ for k in ckpt.keys():
150
+ if 'state_dict' in k and 'optimizer' not in k:
151
+ state_dict = ckpt[k]
152
+ print(f" Using key: {k}")
153
+ break
154
+ else:
155
+ raise KeyError(f"No state dict found in {filepath}")
156
+
157
+ config = ckpt.get('config', None)
158
+ return state_dict, config
159
+
160
+
161
+ def load_pt_router(filepath: Path) -> Tuple[dict, Optional[object]]:
162
+ """Load router weights from PyTorch checkpoint."""
163
+ print(f" Loading {filepath.name}...")
164
+ ckpt = torch.load(filepath, map_location='cpu', weights_only=False)
165
+
166
+ if 'router_state_dict' in ckpt:
167
+ state_dict = ckpt['router_state_dict']
168
+ else:
169
+ raise KeyError(f"router_state_dict not found in {filepath}")
170
+
171
+ config = ckpt.get('config', None)
172
+ return state_dict, config
173
+
174
+
175
+ def load_safetensors_weights(filepath: Path) -> dict:
176
+ """Load weights from SafeTensors file."""
177
+ print(f" Loading {filepath.name}...")
178
+ return load_file(str(filepath))
179
+
180
+
181
+ # ═══════════════════════════════════════════════════════════════════════════════
182
+ # QUANTIZATION
183
+ # ═══════════════════════════════════════════════════════════════════════════════
184
+
185
+ def convert_to_bf16(state_dict: dict) -> dict:
186
+ """Convert all floating point tensors to bfloat16."""
187
+ bf16_state = {}
188
+ for k, v in state_dict.items():
189
+ if isinstance(v, torch.Tensor) and v.is_floating_point():
190
+ bf16_state[k] = v.to(torch.bfloat16)
191
+ else:
192
+ bf16_state[k] = v
193
+ return bf16_state
194
+
195
+
196
+ def is_layernorm_key(key: str) -> bool:
197
+ """Check if a key belongs to a LayerNorm layer."""
198
+ ln_patterns = ['norm', 'layernorm', 'layer_norm', 'ln_', 'scale_shift_table']
199
+ key_lower = key.lower()
200
+ return any(p in key_lower for p in ln_patterns)
201
+
202
+
203
+ def quantize_tensor_int8(tensor: torch.Tensor) -> Tuple[torch.Tensor, float, float]:
204
+ """
205
+ Quantize a tensor to INT8 with min/max scaling.
206
+
207
+ Formula: int8 = round((x - min) / (max - min) * 255) - 128
208
+ """
209
+ if tensor.numel() == 0:
210
+ return tensor.to(torch.int8), 0.0, 0.0
211
+
212
+ t_float = tensor.float()
213
+ t_min = t_float.min().item()
214
+ t_max = t_float.max().item()
215
+
216
+ if t_min == t_max:
217
+ return torch.zeros_like(tensor, dtype=torch.int8), t_min, t_max
218
+
219
+ # Quantize: map [min, max] to [-128, 127]
220
+ normalized = (t_float - t_min) / (t_max - t_min)
221
+ int8_tensor = (normalized * 255 - 128).round().clamp(-128, 127).to(torch.int8)
222
+
223
+ return int8_tensor, t_min, t_max
224
+
225
+
226
+ def convert_to_int8(state_dict: dict) -> dict:
227
+ """
228
+ Convert state dict to INT8 quantized format.
229
+
230
+ LayerNorm and small tensors are kept in float32.
231
+ Quantization parameters (_min, _max) are stored alongside.
232
+ """
233
+ quantized = {}
234
+ stats = {'float32': 0, 'int8': 0}
235
+
236
+ for key, tensor in state_dict.items():
237
+ if not isinstance(tensor, torch.Tensor):
238
+ continue
239
+
240
+ # Skip LayerNorm layers - keep as float32
241
+ if is_layernorm_key(key):
242
+ quantized[key] = tensor.float()
243
+ stats['float32'] += tensor.numel()
244
+ # Only quantize weight tensors with enough elements
245
+ elif tensor.numel() >= 16 and tensor.dtype in [torch.float32, torch.float16, torch.bfloat16]:
246
+ int8_tensor, t_min, t_max = quantize_tensor_int8(tensor)
247
+ quantized[key] = int8_tensor
248
+ quantized[f"{key}._min"] = torch.tensor([t_min], dtype=torch.float32)
249
+ quantized[f"{key}._max"] = torch.tensor([t_max], dtype=torch.float32)
250
+ stats['int8'] += tensor.numel()
251
+ else:
252
+ # Keep small tensors as float32
253
+ quantized[key] = tensor.float()
254
+ stats['float32'] += tensor.numel()
255
+
256
+ return quantized, stats
257
+
258
+
259
+ # ═══════════════════════════════════════════════════════════════════════════════
260
+ # MAIN CONVERSION
261
+ # ═══════════════════════════════════════════════════════════════════════════════
262
+
263
+ def convert_weights(input_dir: Path, output_dir: Path, output_format: str):
264
+ """
265
+ Convert weights to specified format.
266
+
267
+ Args:
268
+ input_dir: Directory containing input weights
269
+ output_dir: Directory to write output weights
270
+ output_format: 'bf16' or 'int8'
271
+ """
272
+ print(f"""
273
+ ╔══════════════════════════════════════════════════════════════════════════════╗
274
+ ║ 🔧 Paris MoE Weight Conversion 🔧 ║
275
+ ╠══════════════════════════════════════════════════════════════════════════════╣
276
+ ║ Input: {str(input_dir):<60} ║
277
+ ║ Output: {str(output_dir):<60} ║
278
+ ║ Format: {output_format.upper():<60} ║
279
+ ╚══════════════════════════════════════════════════════════════════════════════╝
280
+ """)
281
+
282
+ # Detect input format
283
+ input_format, files = detect_input_format(input_dir)
284
+ print(f"📂 Detected input format: {input_format}")
285
+ print(f"📁 Found {len(files)} weight files")
286
+
287
+ # Create output directory
288
+ output_dir.mkdir(parents=True, exist_ok=True)
289
+
290
+ # Track sizes
291
+ sizes = {'input': 0, 'output': 0}
292
+ expert_config = None
293
+ router_config = None
294
+
295
+ # Process experts
296
+ print("\n🧠 Converting experts...")
297
+ for i in range(8):
298
+ key = f"expert_{i}"
299
+ if key not in files:
300
+ print(f" ⚠️ {key} not found, skipping")
301
+ continue
302
+
303
+ filepath = files[key]
304
+ sizes['input'] += filepath.stat().st_size
305
+
306
+ # Load weights
307
+ if input_format == 'pt':
308
+ state_dict, config = load_pt_expert(filepath, i)
309
+ if config is not None and expert_config is None:
310
+ expert_config = config
311
+ else:
312
+ state_dict = load_safetensors_weights(filepath)
313
+
314
+ # Convert
315
+ if output_format == 'bf16':
316
+ converted = convert_to_bf16(state_dict)
317
+ else:
318
+ converted, stats = convert_to_int8(state_dict)
319
+ print(f" INT8: {stats['int8']:,} params, Float32: {stats['float32']:,} params")
320
+
321
+ # Save
322
+ output_path = output_dir / f"expert_{i}.safetensors"
323
+ save_file(converted, str(output_path))
324
+ sizes['output'] += output_path.stat().st_size
325
+
326
+ print(f" ✅ Saved: {output_path.name} ({output_path.stat().st_size / 1e6:.1f} MB)")
327
+
328
+ # Clean up
329
+ del state_dict, converted
330
+ gc.collect()
331
+
332
+ # Process router
333
+ if 'router' in files:
334
+ print("\n📡 Converting router...")
335
+ filepath = files['router']
336
+ sizes['input'] += filepath.stat().st_size
337
+
338
+ if input_format == 'pt':
339
+ state_dict, config = load_pt_router(filepath)
340
+ if config is not None:
341
+ router_config = config
342
+ else:
343
+ state_dict = load_safetensors_weights(filepath)
344
+
345
+ # Router always kept in bf16/float32 for stability
346
+ converted = convert_to_bf16(state_dict)
347
+
348
+ output_path = output_dir / "router.safetensors"
349
+ save_file(converted, str(output_path))
350
+ sizes['output'] += output_path.stat().st_size
351
+
352
+ print(f" ✅ Saved: {output_path.name} ({output_path.stat().st_size / 1e6:.1f} MB)")
353
+
354
+ del state_dict, converted
355
+ gc.collect()
356
+
357
+ # Save configs if from .pt files
358
+ if expert_config is not None:
359
+ config_path = output_dir / "config.pt"
360
+ torch.save({'config': expert_config}, config_path)
361
+ print(f" ✅ Saved: config.pt")
362
+
363
+ if router_config is not None:
364
+ config_path = output_dir / "router_config.pt"
365
+ torch.save({'config': router_config}, config_path)
366
+ print(f" ✅ Saved: router_config.pt")
367
+
368
+ # Summary
369
+ compression = sizes['input'] / sizes['output'] if sizes['output'] > 0 else 1
370
+ print(f"""
371
+ ╔══════════════════════════════════════════════════════════════════════════════╗
372
+ ║ 📊 Conversion Summary 📊 ║
373
+ ╠══════════════════════════════════════════════════════════════════════════════╣
374
+ ║ Input size: {sizes['input']/1e9:>8.2f} GB ║
375
+ ║ Output size: {sizes['output']/1e9:>8.2f} GB ║
376
+ ║ Compression: {compression:>8.1f}x ║
377
+ ╠══════════════════════════════════════════════════════════════════════════════╣
378
+ ║ ✅ Conversion complete! ║
379
+ ╚══════════════════════════════════════════════════════════════════════════════╝
380
+ """)
381
+
382
+ # List output files
383
+ print("📁 Output files:")
384
+ for f in sorted(output_dir.glob("*")):
385
+ print(f" {f.name}: {f.stat().st_size/1e6:.1f} MB")
386
+
387
+
388
+ # ═══════════════════════════════════════════════════════════════════════════════
389
+ # CLI
390
+ # ═══════════════════════════════════════════════════════════════════════════════
391
+
392
+ def parse_args():
393
+ parser = argparse.ArgumentParser(
394
+ description="🔧 Paris MoE - Weight Quantization Utility",
395
+ formatter_class=argparse.RawDescriptionHelpFormatter,
396
+ epilog="""
397
+ Examples:
398
+ # Convert original .pt files to BF16
399
+ python quantize.py --input /path/to/weights --output ./weights/bf16 --format bf16
400
+
401
+ # Convert to INT8 from .pt files
402
+ python quantize.py --input /path/to/weights --output ./weights/int8 --format int8
403
+
404
+ # Convert from BF16 safetensors to INT8
405
+ python quantize.py --input ./weights/bf16 --output ./weights/int8 --format int8
406
+ """
407
+ )
408
+
409
+ parser.add_argument("--input", "-i", type=str, required=True,
410
+ help="Input directory containing weight files")
411
+ parser.add_argument("--output", "-o", type=str, required=True,
412
+ help="Output directory for converted weights")
413
+ parser.add_argument("--format", "-f", type=str, required=True,
414
+ choices=["bf16", "int8"],
415
+ help="Output format: bf16 or int8")
416
+
417
+ return parser.parse_args()
418
+
419
+
420
+ def main():
421
+ args = parse_args()
422
+
423
+ input_dir = Path(args.input)
424
+ output_dir = Path(args.output)
425
+
426
+ if not input_dir.exists():
427
+ print(f"❌ Error: Input directory does not exist: {input_dir}")
428
+ return 1
429
+
430
+ convert_weights(input_dir, output_dir, args.format)
431
+ return 0
432
+
433
+
434
+ if __name__ == "__main__":
435
+ exit(main())
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch>=2.0
2
+ torchvision
3
+ safetensors
4
+ transformers
5
+ diffusers
6
+ accelerate
7
+ tqdm
src/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Paris MoE Inference - Source modules
src/__pycache__/config.cpython-312.pyc ADDED
Binary file (7.8 kB). View file
 
src/__pycache__/models.cpython-312.pyc ADDED
Binary file (90 kB). View file
 
src/__pycache__/schedules.cpython-312.pyc ADDED
Binary file (7.41 kB). View file
 
src/__pycache__/vae_utils.cpython-312.pyc ADDED
Binary file (8.65 kB). View file
 
src/config.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/config.py
2
+ import yaml
3
+ from typing import Dict, Any, Optional
4
+ from dataclasses import dataclass
5
+
6
+ @dataclass
7
+ class Config:
8
+ """Single config class - no inheritance needed"""
9
+
10
+ # Experiment
11
+ experiment_name: str
12
+ seed: int = 42
13
+
14
+ # Dataset
15
+ dataset_name: str = "cifar10"
16
+ image_size: int = 64
17
+ num_channels: Optional[int] = None # If None, auto-determined based on dataset/latents
18
+ data_path: str = "./data"
19
+ download: bool = True
20
+ use_latents: bool = False # Whether to use VAE latents instead of raw images
21
+ latent_data_path: Optional[str] = None # Path to latent dataset JSON
22
+ split_strategy: str = "global" # "global" or "per_cluster"
23
+ preclustered_data_path: Optional[str] = None # Path to pre-clustered data
24
+ train_ratio: float = 0.95 # Train/val split ratio
25
+
26
+ # Clustering (None for monolithic)
27
+ clustering_method: Optional[str] = None # "manual", "kmeans", <- note that we dont support dino as an on-the-fly clustering method
28
+ num_clusters: int = 1
29
+ manual_mapping: Optional[Dict[int, int]] = None
30
+
31
+ # Model
32
+ num_experts: int = 1 # 1 = monolithic, >1 = DDM
33
+ expert_architecture: str = "unet" # "unet", "dit", "simple_cnn"
34
+ router_architecture: str = "none" # "vit", "cnn", "dit", "none"
35
+ router_pretrained: bool = True
36
+ clip_tokenizer_name: str = "openai/clip-vit-large-patch14"
37
+
38
+ # Training
39
+ batch_size: int = 32
40
+ num_epochs: int = 20
41
+ learning_rate: float = 1e-4
42
+ optimizer: str = "adamw"
43
+ mixed_precision: bool = True
44
+ num_gpus: int = 1
45
+ distributed: bool = False
46
+ train_router_jointly: bool = False
47
+ weight_decay: float = 0
48
+ use_lr_scheduler: bool = True
49
+ warmup_steps: int = 0 # Learning rate warmup steps
50
+ warmup_factor: float = 0.1 # Learning rate warmup factor
51
+ grad_accum_steps: int = 1
52
+ use_amp: bool = True
53
+ imagenet_pretrain_checkpoint: Optional[str] = None
54
+
55
+ # Cluster imbalance handling
56
+ use_class_weights: bool = False # Enable class weighting for imbalanced clusters
57
+ weight_smoothing: float = 0.0 # Weight smoothing factor (0.0-1.0)
58
+
59
+ # New dataset training options
60
+ new_dataset_learning_rate: Optional[float] = None
61
+ reset_optimizer: bool = True
62
+ reset_scheduler: bool = True
63
+ reset_epoch: bool = True
64
+ reset_ema: bool = False
65
+
66
+ # Decentralized training
67
+ expert_parallel: bool = False
68
+ target_expert_id: int = 0
69
+ target_gpu_id: int = 0
70
+
71
+ # FID evaluation
72
+ compute_fid: bool = False
73
+ fid_every: int = 5000
74
+ fid_num_samples: int = 5000
75
+ fid_batch_size: int = 50
76
+
77
+ # EMA parameters
78
+ use_ema: bool = True
79
+ ema_decay: float = 0.9999
80
+ ema_update_every: int = 1
81
+
82
+ # Heterogeneous objectives
83
+ expert_objectives: Optional[Dict[int, str]] = None # {expert_id: "ddpm"|"fm"|"rf"}
84
+ default_objective: str = "fm" # Default if expert_objectives not specified
85
+
86
+ # Schedule configuration (NEW)
87
+ schedule_type: str = "linear_interp" # Default for backward compatibility
88
+ expert_schedule_types: Optional[Dict[int, str]] = None # Per-expert schedules for Strategy B
89
+
90
+ # Consistency loss (NEW)
91
+ use_consistency_loss: bool = False
92
+ consistency_loss_weight: float = 0.1
93
+
94
+ # Model parameters (flexible dicts)
95
+ expert_params: Dict[str, Any] = None
96
+ router_params: Dict[str, Any] = None
97
+ video_config: Dict[str, Any] = None # Video-specific parameters (temporal_frames, latent_height, etc.)
98
+
99
+ # Inference
100
+ sampling_strategy: str = "top1" # "top1", "topk", "full", "monolithic"
101
+ num_inference_steps: int = 50
102
+
103
+ # Diffusion settings
104
+ beta_start: float = 0.0001
105
+ beta_end: float = 0.02
106
+ beta_schedule: str = "linear"
107
+ max_text_length: int = 77
108
+
109
+ # Paths
110
+ checkpoint_dir: str = "./outputs/checkpoints"
111
+ log_dir: str = "./outputs/logs"
112
+
113
+ def __post_init__(self) -> None:
114
+ # Set defaults for missing fields
115
+ if self.expert_params is None:
116
+ self.expert_params = {}
117
+ if self.router_params is None:
118
+ self.router_params = {}
119
+ if self.video_config is None:
120
+ self.video_config = {}
121
+
122
+ # Auto-determine num_channels if not explicitly set
123
+ if self.num_channels is None:
124
+ if self.use_latents:
125
+ self.num_channels = 4 # VAE latent channels
126
+ elif self.dataset_name in ["mnist", "fashionmnist"]:
127
+ self.num_channels = 1
128
+ else:
129
+ self.num_channels = 3
130
+
131
+ # Initialize and validate expert_objectives
132
+ valid_objectives = {"ddpm", "fm", "rf"}
133
+
134
+ # Validate default_objective
135
+ if self.default_objective not in valid_objectives:
136
+ raise ValueError(f"default_objective must be one of {valid_objectives}, got {self.default_objective}")
137
+
138
+ # Initialize expert_objectives if None
139
+ if self.expert_objectives is None:
140
+ self.expert_objectives = {i: self.default_objective for i in range(self.num_experts)}
141
+ else:
142
+ # Validate all objective types
143
+ for expert_id, obj_type in self.expert_objectives.items():
144
+ if obj_type not in valid_objectives:
145
+ raise ValueError(f"Expert {expert_id} has invalid objective '{obj_type}'. Must be one of {valid_objectives}")
146
+
147
+ # Ensure all expert IDs have objectives assigned
148
+ for expert_id in range(self.num_experts):
149
+ if expert_id not in self.expert_objectives:
150
+ self.expert_objectives[expert_id] = self.default_objective
151
+
152
+ # Validate schedule types (NEW)
153
+ valid_schedules = {"cosine", "linear_beta", "linear_interp"}
154
+
155
+ # Validate default schedule_type
156
+ if self.schedule_type not in valid_schedules:
157
+ raise ValueError(f"schedule_type must be one of {valid_schedules}, got {self.schedule_type}")
158
+
159
+ # Validate expert_schedule_types if provided
160
+ if self.expert_schedule_types is not None:
161
+ for expert_id, sched_type in self.expert_schedule_types.items():
162
+ if sched_type not in valid_schedules:
163
+ raise ValueError(f"Expert {expert_id} has invalid schedule '{sched_type}'. Must be one of {valid_schedules}")
164
+
165
+ @classmethod
166
+ def from_yaml(cls, config_path: str) -> 'Config':
167
+ with open(config_path, 'r') as f:
168
+ config_dict = yaml.safe_load(f)
169
+
170
+ # Set defaults for missing fields
171
+ config_dict.setdefault('expert_params', {})
172
+ config_dict.setdefault('router_params', {})
173
+
174
+ # If num_experts is not specified, default to num_clusters (or 1 if num_clusters is not set)
175
+ if 'num_experts' not in config_dict:
176
+ num_clusters = config_dict.get('num_clusters', 1)
177
+ config_dict['num_experts'] = max(1, num_clusters)
178
+
179
+ return cls(**config_dict)
180
+
181
+ @property
182
+ def is_monolithic(self) -> bool:
183
+ return self.num_experts == 1
184
+
185
+
186
+ @property
187
+ def num_classes(self) -> int:
188
+ dataset_classes = {
189
+ "mnist": 10, "fashionmnist": 10,
190
+ "cifar10": 10, "cifar100": 100,
191
+ "celeba": 0, # No class conditioning
192
+ "butterfly": 1, # Single class for butterflies
193
+ "laion": 0 # No class conditioning for LAION
194
+ }
195
+ return dataset_classes.get(self.dataset_name, 10)
196
+
197
+ def load_config(config_path: str) -> Config:
198
+ """Simple config loader"""
199
+ return Config.from_yaml(config_path)
src/models.py ADDED
@@ -0,0 +1,1913 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/models.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from diffusers import UNet2DModel
6
+ from transformers import ViTForImageClassification, ViTConfig
7
+ import math
8
+ from typing import Optional, List
9
+ import numpy as np
10
+
11
+ # =============================================================================
12
+ # TIME EMBEDDING (shared utility)
13
+ # =============================================================================
14
+
15
+ class TimeEmbedding(nn.Module):
16
+ def __init__(self, dim: int) -> None:
17
+ super().__init__()
18
+ self.dim = dim
19
+
20
+ def forward(self, t: torch.Tensor) -> torch.Tensor:
21
+ device = t.device
22
+ half_dim = self.dim // 2
23
+ embeddings = math.log(10000) / (half_dim - 1)
24
+ embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
25
+ embeddings = t[:, None] * embeddings[None, :]
26
+ embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
27
+ return embeddings
28
+
29
+ class DiTTimestepEmbedder(nn.Module):
30
+ def __init__(self, hidden_size, freq_dim=128, max_period=10000):
31
+ super().__init__()
32
+ self.freq_dim = freq_dim
33
+ self.max_period = max_period
34
+ self.mlp = nn.Sequential(
35
+ nn.Linear(2*freq_dim, hidden_size, bias=True),
36
+ nn.SiLU(),
37
+ nn.Linear(hidden_size, hidden_size, bias=True),
38
+ )
39
+ def forward(self, t): # t: [B] integers (float tensor ok)
40
+ # standard "timestep_embedding" (like ADM/DiT)
41
+ half = self.freq_dim
42
+ device = t.device
43
+ # positions in radians
44
+ freqs = torch.exp(
45
+ -torch.arange(half, device=device).float() * np.log(self.max_period) / half
46
+ )
47
+ args = t.float()[:, None] * freqs[None] # [B, half]
48
+ emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) # [B, 2*half]
49
+ return self.mlp(emb)
50
+
51
+ # =============================================================================
52
+ # OUTPUT CONVERTER (for heterogeneous objectives)
53
+ # =============================================================================
54
+
55
+ class OutputConverter(nn.Module):
56
+ def __init__(self, schedule_type: str = 'linear_interp', use_latents: bool = False, derivative_eps: float = 1e-4):
57
+ super().__init__()
58
+ from schedules import NoiseSchedule
59
+ self.schedule = NoiseSchedule(schedule_type)
60
+ self.schedule_type = schedule_type
61
+ self.use_latents = use_latents
62
+ self.derivative_eps = derivative_eps # For finite difference derivatives
63
+
64
+ # Set clamping range based on data type
65
+ # VAE latents have larger range than pixel-space images
66
+ self.clamp_range = 20.0 if use_latents else 5.0
67
+
68
+ def _get_schedule_with_derivatives(self, t: torch.Tensor):
69
+ """
70
+ Compute schedule coefficients and their derivatives.
71
+ Essential for correct velocity computation with any schedule.
72
+ """
73
+ # Get coefficients at current time
74
+ alpha_t, sigma_t = self.schedule.get_schedule(t)
75
+
76
+ # Compute derivatives using finite differences
77
+ h = torch.full_like(t, self.derivative_eps)
78
+ t_plus = (t + h).clamp(0.0, 1.0)
79
+ t_minus = (t - h).clamp(0.0, 1.0)
80
+
81
+ alpha_plus, sigma_plus = self.schedule.get_schedule(t_plus)
82
+ alpha_minus, sigma_minus = self.schedule.get_schedule(t_minus)
83
+
84
+ # Derivatives
85
+ dt = (t_plus - t_minus).clamp(min=1e-6)
86
+ d_alpha_dt = (alpha_plus - alpha_minus) / dt
87
+ d_sigma_dt = (sigma_plus - sigma_minus) / dt
88
+
89
+ return alpha_t, sigma_t, d_alpha_dt, d_sigma_dt
90
+
91
+ def epsilon_to_velocity(self, epsilon_pred: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
92
+ """
93
+ Correct ε→v conversion for ANY schedule using proper derivatives.
94
+
95
+ From ODE: dx_t/dt = d(alpha_t)/dt * x_0 + d(sigma_t)/dt * ε
96
+ This is the TRUE velocity for the schedule!
97
+ """
98
+ # Get schedule coefficients AND their derivatives
99
+ alpha_t, sigma_t, d_alpha_dt, d_sigma_dt = self._get_schedule_with_derivatives(t)
100
+
101
+ # Reshape for broadcasting
102
+ alpha_t = alpha_t.view(-1, 1, 1, 1)
103
+ sigma_t = sigma_t.view(-1, 1, 1, 1)
104
+ d_alpha_dt = d_alpha_dt.view(-1, 1, 1, 1)
105
+ d_sigma_dt = d_sigma_dt.view(-1, 1, 1, 1)
106
+
107
+ # Numerical stability: handle small alpha_t
108
+ alpha_safe = torch.clamp(alpha_t, min=0.01)
109
+
110
+ # Step 1: Recover x_0 using Tweedie's formula
111
+ x_0_pred = (x_t - sigma_t * epsilon_pred) / alpha_safe
112
+
113
+ # Step 2: Clamp x_0 to reasonable range (prevents blow-up)
114
+ # Use adaptive clamping: larger range for VAE latents, tighter for pixel space
115
+ x_0_pred = torch.clamp(x_0_pred, -self.clamp_range, self.clamp_range)
116
+
117
+ # Step 3: Compute velocity based on schedule type
118
+ if self.schedule_type == 'linear_interp':
119
+ # For linear interpolation: x_t = (1-t)*x_0 + t*ε
120
+ # Velocity is simply: v = ε - x_0
121
+ v = epsilon_pred - x_0_pred
122
+ else:
123
+ # For cosine and other schedules: use proper derivatives
124
+ # v = d(alpha_t)/dt * x_0 + d(sigma_t)/dt * ε
125
+ v = d_alpha_dt * x_0_pred + d_sigma_dt * epsilon_pred
126
+
127
+ # Adaptive velocity scaling for cosine schedule
128
+ # Derivatives vary dramatically with timestep - need adaptive dampening
129
+ if self.schedule_type == 'cosine':
130
+ t_val = t[0].item() if t.numel() > 0 else 0.5
131
+ if t_val > 0.85:
132
+ # Very high noise: derivatives are large, need dampening
133
+ scale = 0.88
134
+ elif t_val > 0.6:
135
+ # Medium-high noise: moderate dampening
136
+ scale = 0.93
137
+ else:
138
+ # Low to medium noise: slight dampening
139
+ scale = 0.96
140
+ v = v * scale
141
+
142
+ # Per-channel bias correction to prevent color drift
143
+ # The model has inherent channel bias that gets amplified by integration
144
+ # Remove per-channel mean to prevent accumulation
145
+ # Only apply to color channels (1,2,3), preserve luminance channel (0)
146
+ for c in range(1, 4):
147
+ v[:, c] = v[:, c] - v[:, c].mean()
148
+
149
+ return v
150
+
151
+ def convert(self, prediction: torch.Tensor, objective_type: str, x_t: torch.Tensor, t: torch.Tensor):
152
+ """
153
+ Convert any prediction to velocity space.
154
+
155
+ Args:
156
+ prediction: expert output
157
+ objective_type: 'ddpm' | 'fm' | 'rf'
158
+ x_t: current noisy state
159
+ t: current timesteps
160
+
161
+ Returns:
162
+ v: velocity representation
163
+ """
164
+ if objective_type == "ddpm":
165
+ # Proper ε→v conversion for unified integration
166
+ return self.epsilon_to_velocity(prediction, x_t, t)
167
+ elif objective_type in ["fm", "rf"]:
168
+ return prediction # Already velocity
169
+ else:
170
+ raise ValueError(f"Unknown objective type: {objective_type}")
171
+
172
+ # =============================================================================
173
+ # EXPERT MODELS
174
+ # =============================================================================
175
+
176
+ class UNetExpert(nn.Module):
177
+ """UNet expert using diffusers"""
178
+
179
+ def __init__(self, config) -> None:
180
+ super().__init__()
181
+
182
+ # Default UNet params
183
+ default_params = {
184
+ "sample_size": config.image_size,
185
+ "in_channels": config.num_channels,
186
+ "out_channels": config.num_channels,
187
+ "layers_per_block": 2,
188
+ "block_out_channels": [64, 128, 256, 256],
189
+ "attention_head_dim": 8,
190
+ }
191
+
192
+ # Override with config params
193
+ params = {**default_params, **config.expert_params}
194
+
195
+ # Store objective type for heterogeneous training (and remove from params)
196
+ self.objective_type = params.pop("objective_type", "fm")
197
+
198
+ # Store and initialize schedule (NEW)
199
+ schedule_type = params.pop("schedule_type", "linear_interp")
200
+ from schedules import NoiseSchedule
201
+ self.schedule = NoiseSchedule(schedule_type)
202
+
203
+ self.unet = UNet2DModel(**params)
204
+
205
+ def forward(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
206
+ # Scale timesteps for diffusers (expects 0-1000)
207
+ # t_scaled = (t * 1000).long()
208
+ t_scaled = (t * 999).round().long().clamp(0, 999)
209
+ return self.unet(xt, t_scaled).sample
210
+
211
+ def compute_loss(self, x0: torch.Tensor) -> torch.Tensor:
212
+ """Unified loss computation based on objective type"""
213
+ if self.objective_type == "ddpm":
214
+ return self.ddpm_loss(x0)
215
+ elif self.objective_type == "fm":
216
+ return self.flow_matching_loss(x0)
217
+ elif self.objective_type == "rf":
218
+ return self.rectified_flow_loss(x0)
219
+ else:
220
+ raise ValueError(f"Unknown objective type: {self.objective_type}")
221
+
222
+ def ddpm_loss(self, x0: torch.Tensor) -> torch.Tensor:
223
+ """DDPM: predict noise ε"""
224
+ batch_size = x0.shape[0]
225
+ device = x0.device
226
+
227
+ t = torch.rand(batch_size, device=device)
228
+
229
+ # Use proper schedule (NEW)
230
+ alpha_t, sigma_t = self.schedule.get_schedule(t)
231
+
232
+ noise = torch.randn_like(x0)
233
+ xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise
234
+
235
+ pred_eps = self.forward(xt, t)
236
+ return F.mse_loss(pred_eps, noise)
237
+
238
+ def rectified_flow_loss(self, x0: torch.Tensor) -> torch.Tensor:
239
+ """Rectified Flow: predict velocity v = x_1 - x_0"""
240
+ batch_size = x0.shape[0]
241
+ device = x0.device
242
+
243
+ t = torch.rand(batch_size, device=device)
244
+ x1 = torch.randn_like(x0)
245
+ xt = (1 - t).view(-1, 1, 1, 1) * x0 + t.view(-1, 1, 1, 1) * x1
246
+
247
+ pred_v = self.forward(xt, t)
248
+ true_v = x1 - x0
249
+ return F.mse_loss(pred_v, true_v)
250
+
251
+ def flow_matching_loss(self, x0: torch.Tensor) -> torch.Tensor:
252
+ """Flow matching loss for training"""
253
+ batch_size = x0.shape[0]
254
+ device = x0.device
255
+
256
+ # Sample random timesteps
257
+ t = torch.rand(batch_size, device=device)
258
+
259
+ # Use proper schedule (NEW)
260
+ alpha_t, sigma_t = self.schedule.get_schedule(t)
261
+
262
+ # Add noise
263
+ noise = torch.randn_like(x0)
264
+ xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise
265
+
266
+ # Predict velocity
267
+ pred_v = self.forward(xt, t)
268
+
269
+ # True velocity for flow matching
270
+ # true_v = x0 - xt
271
+ true_v = noise - x0
272
+
273
+ return F.mse_loss(pred_v, true_v)
274
+
275
+ class SimpleCNNExpert(nn.Module):
276
+ """Simple CNN expert for fast training"""
277
+
278
+ def __init__(self, config) -> None:
279
+ super().__init__()
280
+
281
+ # Default params
282
+ default_params = {
283
+ "hidden_dims": [64, 128, 256],
284
+ "time_dim": 64,
285
+ }
286
+ params = {**default_params, **config.expert_params}
287
+
288
+ # Store objective type for heterogeneous training
289
+ self.objective_type = params.get("objective_type", "fm")
290
+
291
+ # Store and initialize schedule (NEW)
292
+ schedule_type = params.get("schedule_type", "linear_interp")
293
+ from schedules import NoiseSchedule
294
+ self.schedule = NoiseSchedule(schedule_type)
295
+
296
+ self.time_embedding = TimeEmbedding(params["time_dim"])
297
+ self.target_size = config.image_size
298
+
299
+ # Simple encoder-decoder
300
+ self.encoder = self._build_encoder(config.num_channels, params["hidden_dims"])
301
+ self.decoder = self._build_decoder(params["hidden_dims"], config.num_channels)
302
+
303
+ # Time conditioning
304
+ self.time_mlp = nn.Sequential(
305
+ nn.Linear(params["time_dim"], params["hidden_dims"][-1]),
306
+ nn.SiLU(),
307
+ nn.Linear(params["hidden_dims"][-1], params["hidden_dims"][-1])
308
+ )
309
+
310
+ def _build_encoder(self, in_channels: int, hidden_dims: List[int]) -> nn.Sequential:
311
+ layers = []
312
+ prev_dim = in_channels
313
+
314
+ for dim in hidden_dims:
315
+ layers.extend([
316
+ nn.Conv2d(prev_dim, dim, 3, padding=1),
317
+ nn.GroupNorm(8, dim),
318
+ nn.SiLU(),
319
+ nn.Conv2d(dim, dim, 3, padding=1),
320
+ nn.GroupNorm(8, dim),
321
+ nn.SiLU(),
322
+ nn.MaxPool2d(2)
323
+ ])
324
+ prev_dim = dim
325
+
326
+ return nn.Sequential(*layers)
327
+
328
+ def _build_decoder(self, hidden_dims: List[int], out_channels: int) -> nn.Sequential:
329
+ layers = []
330
+ reversed_dims = list(reversed(hidden_dims))
331
+
332
+ for i, dim in enumerate(reversed_dims[:-1]):
333
+ next_dim = reversed_dims[i + 1]
334
+ layers.extend([
335
+ nn.ConvTranspose2d(dim, next_dim, 4, stride=2, padding=1),
336
+ nn.GroupNorm(8, next_dim),
337
+ nn.SiLU(),
338
+ nn.Conv2d(next_dim, next_dim, 3, padding=1),
339
+ nn.GroupNorm(8, next_dim),
340
+ nn.SiLU(),
341
+ ])
342
+
343
+ # Final layer
344
+ layers.append(nn.Conv2d(reversed_dims[-1], out_channels, 3, padding=1))
345
+
346
+ return nn.Sequential(*layers)
347
+
348
+ def forward(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
349
+ # Time embedding
350
+ time_emb = self.time_embedding(t)
351
+ time_features = self.time_mlp(time_emb)
352
+
353
+ # Encode
354
+ encoded = self.encoder(xt)
355
+
356
+ # Add time conditioning
357
+ time_features = time_features.view(time_features.shape[0], -1, 1, 1)
358
+ time_features = time_features.expand(-1, -1, encoded.shape[2], encoded.shape[3])
359
+ conditioned = encoded + time_features
360
+
361
+ # Decode
362
+ output = self.decoder(conditioned)
363
+
364
+ # Ensure output matches target size
365
+ output = F.interpolate(output, size=xt.shape[-2:], mode='bilinear', align_corners=False)
366
+
367
+ return output
368
+
369
+ def compute_loss(self, x0: torch.Tensor) -> torch.Tensor:
370
+ """Unified loss computation based on objective type"""
371
+ if self.objective_type == "ddpm":
372
+ return self.ddpm_loss(x0)
373
+ elif self.objective_type == "fm":
374
+ return self.flow_matching_loss(x0)
375
+ elif self.objective_type == "rf":
376
+ return self.rectified_flow_loss(x0)
377
+ else:
378
+ raise ValueError(f"Unknown objective type: {self.objective_type}")
379
+
380
+ def ddpm_loss(self, x0: torch.Tensor) -> torch.Tensor:
381
+ """DDPM: predict noise ε"""
382
+ batch_size = x0.shape[0]
383
+ device = x0.device
384
+
385
+ t = torch.rand(batch_size, device=device)
386
+
387
+ # Use proper schedule (NEW)
388
+ alpha_t, sigma_t = self.schedule.get_schedule(t)
389
+
390
+ noise = torch.randn_like(x0)
391
+ xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise
392
+
393
+ pred_eps = self.forward(xt, t)
394
+
395
+ # Ensure pred_eps matches noise shape
396
+ if pred_eps.shape != noise.shape:
397
+ pred_eps = F.interpolate(pred_eps, size=noise.shape[-2:], mode='bilinear', align_corners=False)
398
+
399
+ return F.mse_loss(pred_eps, noise)
400
+
401
+ def rectified_flow_loss(self, x0: torch.Tensor) -> torch.Tensor:
402
+ """Rectified Flow: predict velocity v = x_1 - x_0"""
403
+ batch_size = x0.shape[0]
404
+ device = x0.device
405
+
406
+ t = torch.rand(batch_size, device=device)
407
+ x1 = torch.randn_like(x0)
408
+ xt = (1 - t).view(-1, 1, 1, 1) * x0 + t.view(-1, 1, 1, 1) * x1
409
+
410
+ pred_v = self.forward(xt, t)
411
+ true_v = x1 - x0
412
+
413
+ # Ensure pred_v matches true_v shape
414
+ if pred_v.shape != true_v.shape:
415
+ pred_v = F.interpolate(pred_v, size=true_v.shape[-2:], mode='bilinear', align_corners=False)
416
+
417
+ return F.mse_loss(pred_v, true_v)
418
+
419
+ def flow_matching_loss(self, x0: torch.Tensor) -> torch.Tensor:
420
+ """Flow matching loss"""
421
+ batch_size = x0.shape[0]
422
+ device = x0.device
423
+
424
+ t = torch.rand(batch_size, device=device)
425
+
426
+ # Use proper schedule (NEW)
427
+ alpha_t, sigma_t = self.schedule.get_schedule(t)
428
+
429
+ noise = torch.randn_like(x0)
430
+ xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise
431
+
432
+ pred_v = self.forward(xt, t)
433
+ # true_v = x0 - xt
434
+ true_v = noise - x0
435
+
436
+ # Ensure pred_v matches true_v shape
437
+ if pred_v.shape != true_v.shape:
438
+ pred_v = F.interpolate(pred_v, size=true_v.shape[-2:], mode='bilinear', align_corners=False)
439
+
440
+ return F.mse_loss(pred_v, true_v)
441
+
442
+ # Helper function from original DiT
443
+ def modulate(x, shift, scale):
444
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
445
+
446
+ # Fixed sin-cos position embedding from original
447
+ def get_2d_sincos_pos_embed(embed_dim, grid_size):
448
+ grid_h = np.arange(grid_size, dtype=np.float32)
449
+ grid_w = np.arange(grid_size, dtype=np.float32)
450
+ grid = np.meshgrid(grid_w, grid_h)
451
+ grid = np.stack(grid, axis=0)
452
+ grid = grid.reshape([2, 1, grid_size, grid_size])
453
+
454
+ assert embed_dim % 2 == 0
455
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
456
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
457
+ emb = np.concatenate([emb_h, emb_w], axis=1)
458
+ return emb
459
+
460
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
461
+ assert embed_dim % 2 == 0
462
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
463
+ omega /= embed_dim / 2.
464
+ omega = 1. / 10000**omega
465
+ pos = pos.reshape(-1)
466
+ out = np.einsum('m,d->md', pos, omega)
467
+ emb_sin = np.sin(out)
468
+ emb_cos = np.cos(out)
469
+ emb = np.concatenate([emb_sin, emb_cos], axis=1)
470
+ return emb
471
+
472
+ # Timestep Embedder
473
+ class TimestepEmbedder(nn.Module):
474
+ def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
475
+ super().__init__()
476
+ self.frequency_embedding_size = frequency_embedding_size
477
+ self.mlp = nn.Sequential(
478
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
479
+ nn.SiLU(),
480
+ nn.Linear(hidden_size, hidden_size, bias=True),
481
+ )
482
+
483
+ @staticmethod
484
+ def timestep_embedding(t, dim, max_period=10000):
485
+ half = dim // 2
486
+ freqs = torch.exp(-math.log(max_period) * torch.arange(0, half, dtype=torch.float32, device=t.device) / half)
487
+ args = t[:, None].float() * freqs[None]
488
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
489
+ if dim % 2:
490
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
491
+ return embedding
492
+
493
+ def forward(self, t: torch.Tensor) -> torch.Tensor:
494
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
495
+ return self.mlp(t_freq)
496
+
497
+ # DiTBlock with proper AdaLN-Zero
498
+ class DiTBlock(nn.Module):
499
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, use_text: bool = False, use_adaln_single: bool = False):
500
+ super().__init__()
501
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
502
+ self.attn = nn.MultiheadAttention(hidden_size, num_heads, dropout=0.1, batch_first=True)
503
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
504
+
505
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
506
+ self.mlp = nn.Sequential(
507
+ nn.Linear(hidden_size, mlp_hidden_dim),
508
+ nn.GELU(approximate="tanh"), # Match original
509
+ nn.Linear(mlp_hidden_dim, hidden_size),
510
+ )
511
+
512
+ # AdaLN modulation - either per-block MLP or AdaLN-Single embeddings
513
+ self.use_adaln_single = use_adaln_single
514
+ if use_adaln_single:
515
+ # AdaLN-Single: use learnable per-block embeddings instead of MLP
516
+ self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
517
+ self.adaLN_modulation = None # No MLP needed
518
+ else:
519
+ # Original AdaLN with per-block MLP
520
+ self.adaLN_modulation = nn.Sequential(
521
+ nn.SiLU(),
522
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
523
+ )
524
+ self.scale_shift_table = None
525
+
526
+ # Optional text cross-attention
527
+ self.use_text = use_text
528
+ if use_text:
529
+ # Note: PixArt uses xformers which may handle unnormalized queries differently
530
+ # We add a simple norm for stability with PyTorch's MultiheadAttention
531
+ self.norm_cross = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
532
+ self.cross_attn = nn.MultiheadAttention(hidden_size, num_heads, dropout=0.1, batch_first=True)
533
+
534
+ def forward(self, x: torch.Tensor, c: torch.Tensor, text_emb: Optional[torch.Tensor] = None,
535
+ attention_mask: Optional[torch.Tensor] = None):
536
+ # Get modulation parameters
537
+ if self.use_adaln_single:
538
+ # AdaLN-Single: combine global time embedding with per-block parameters
539
+ # c should be pre-computed from global t_block with shape [B, 6*hidden_size]
540
+ B = x.shape[0]
541
+ # Chunk and squeeze to get [B, hidden_size] tensors for compatibility with PyTorch's MultiheadAttention
542
+ temp = (self.scale_shift_table[None] + c.reshape(B, 6, -1)).chunk(6, dim=1)
543
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [t.squeeze(1) for t in temp]
544
+ else:
545
+ # Original AdaLN: compute modulation from per-block MLP
546
+ # Also squeeze after chunk to get [B, hidden_size] for consistency
547
+ temp = self.adaLN_modulation(c).chunk(6, dim=1)
548
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [t.squeeze(1) for t in temp]
549
+
550
+ # Self-attention with modulation
551
+ # Both paths now use modulate function for consistency
552
+ x_norm = modulate(self.norm1(x), shift_msa, scale_msa)
553
+ attn_out, _ = self.attn(x_norm, x_norm, x_norm)
554
+ x = x + gate_msa.unsqueeze(1) * attn_out
555
+
556
+ # Optional cross-attention
557
+ if self.use_text and text_emb is not None:
558
+ if text_emb.dim() == 2:
559
+ text_emb = text_emb.unsqueeze(1)
560
+ # Convert attention mask to key_padding_mask format (True = ignore)
561
+ # attention_mask: shape [B, T]; either bool (True=keep) or 0/1 numeric (1=keep)
562
+ key_padding_mask = None
563
+ if attention_mask is not None:
564
+ if attention_mask.dtype is not torch.bool:
565
+ # Convert 0/1 (or >=1) to bool keep-mask first
566
+ keep_mask = attention_mask > 0
567
+ else:
568
+ keep_mask = attention_mask
569
+ # key_padding_mask semantics: True = ignore, False = keep
570
+ key_padding_mask = ~keep_mask # logical NOT, not arithmetic subtraction
571
+
572
+ # Normalize queries for stability (PixArt uses xformers which may differ)
573
+ x_norm = self.norm_cross(x)
574
+ cross_out, _ = self.cross_attn(x_norm, text_emb, text_emb, key_padding_mask=key_padding_mask)
575
+ x = x + cross_out
576
+
577
+ # MLP with modulation
578
+ # Both paths now use modulate function for consistency
579
+ x_norm = modulate(self.norm2(x), shift_mlp, scale_mlp)
580
+ mlp_out = self.mlp(x_norm)
581
+ x = x + gate_mlp.unsqueeze(1) * mlp_out
582
+
583
+ return x
584
+
585
+ # FinalLayer with AdaLN modulation
586
+ class FinalLayer(nn.Module):
587
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
588
+ super().__init__()
589
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
590
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
591
+ self.adaLN_modulation = nn.Sequential(
592
+ nn.SiLU(),
593
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
594
+ )
595
+
596
+ def forward(self, x: torch.Tensor, c: torch.Tensor):
597
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
598
+ x = modulate(self.norm_final(x), shift, scale)
599
+ x = self.linear(x)
600
+ return x
601
+
602
+ # T2IFinalLayer with AdaLN-Single for parameter efficiency
603
+ class T2IFinalLayer(nn.Module):
604
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
605
+ super().__init__()
606
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
607
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
608
+ # AdaLN-Single: use learnable embeddings instead of MLP
609
+ self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5)
610
+ self.hidden_size = hidden_size
611
+
612
+ def forward(self, x: torch.Tensor, t: torch.Tensor):
613
+ # t should be the original time embedding with shape [B, hidden_size]
614
+ # Following PixArt implementation exactly
615
+ shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
616
+ # shift and scale are [B, 1, hidden_size], use t2i_modulate style
617
+ x = self.norm_final(x) * (1 + scale) + shift
618
+ x = self.linear(x)
619
+ return x
620
+
621
+ # DiTExpert
622
+ class DiTExpert(nn.Module):
623
+ def __init__(self, config):
624
+ super().__init__()
625
+ default_params = {
626
+ "hidden_size": 768,
627
+ "num_layers": 12,
628
+ "num_heads": 12,
629
+ "patch_size": 2,
630
+ "in_channels": 4,
631
+ "out_channels": 4,
632
+ "use_text_conditioning": False,
633
+ "use_class_conditioning": False,
634
+ "num_classes": 1000, # ImageNet classes
635
+ "mlp_ratio": 4.0,
636
+ "text_embed_dim": 768,
637
+ "use_dit_time_embed": False,
638
+ }
639
+ params = {**default_params, **config.expert_params}
640
+
641
+ self.patch_size = params["patch_size"]
642
+ self.in_channels = params["in_channels"]
643
+ self.out_channels = params["out_channels"]
644
+ self.hidden_size = params["hidden_size"]
645
+ self.num_heads = params["num_heads"]
646
+ self.use_text = params.get("use_text_conditioning", False)
647
+ self.use_class = params.get("use_class_conditioning", False)
648
+ self.cfg_dropout_prob = params.get("cfg_dropout_prob", 0.1) # 10% dropout for CFG
649
+ self.text_embed_dim = params.get("text_embed_dim", 768)
650
+ self.use_adaln_single = params.get("use_adaln_single", False) # AdaLN-Single for parameter efficiency
651
+ self.depth = params["num_layers"]
652
+
653
+ # Store objective type for heterogeneous training
654
+ self.objective_type = params.get("objective_type", "fm")
655
+
656
+ # Store and initialize schedule (NEW)
657
+ schedule_type = params.get("schedule_type", "linear_interp")
658
+ from schedules import NoiseSchedule
659
+ self.schedule = NoiseSchedule(schedule_type)
660
+
661
+ # Validation: cannot use both text and class conditioning simultaneously
662
+ assert not (self.use_text and self.use_class), "Cannot use both text and class conditioning simultaneously"
663
+
664
+ # Patch embedding
665
+ self.patch_embed = nn.Conv2d(self.in_channels, self.hidden_size,
666
+ kernel_size=self.patch_size, stride=self.patch_size)
667
+
668
+ # Fixed sin-cos positional embedding
669
+ latent_size = getattr(config, 'image_size', 32)
670
+ self.num_patches = (latent_size // self.patch_size) ** 2
671
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, self.hidden_size), requires_grad=False)
672
+
673
+ # Time embedding
674
+ self.use_dit_time_embed = params.get("use_dit_time_embed", False)
675
+ if self.use_dit_time_embed:
676
+ self.time_embed = DiTTimestepEmbedder(self.hidden_size)
677
+ else:
678
+ self.time_embed = TimestepEmbedder(self.hidden_size)
679
+
680
+ # Global time block for AdaLN-Single
681
+ if self.use_adaln_single:
682
+ self.t_block = nn.Sequential(
683
+ nn.SiLU(),
684
+ nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=True)
685
+ )
686
+
687
+ # Optional text conditioning
688
+ if self.use_text:
689
+ self.text_proj = nn.Linear(self.text_embed_dim, self.hidden_size)
690
+ self.text_norm = nn.LayerNorm(self.hidden_size, elementwise_affine=False, eps=1e-6)
691
+ # Note: null text embedding will be provided by empty string encoding from CLIP
692
+ # This is handled in the training loop, not as a learnable parameter
693
+
694
+ # Optional class conditioning (ImageNet style)
695
+ if self.use_class:
696
+ # Add 1 extra embedding for null/unconditional class
697
+ self.class_embed = nn.Embedding(params["num_classes"] + 1, self.hidden_size)
698
+ self.null_class_id = params["num_classes"] # Use last index as null class
699
+
700
+ # Transformer blocks
701
+ self.layers = nn.ModuleList([
702
+ DiTBlock(self.hidden_size, self.num_heads, params.get("mlp_ratio", 4.0),
703
+ self.use_text, use_adaln_single=self.use_adaln_single)
704
+ for _ in range(self.depth)
705
+ ])
706
+
707
+ # Final layer with modulation
708
+ if self.use_adaln_single:
709
+ self.final_layer = T2IFinalLayer(self.hidden_size, self.patch_size, self.out_channels)
710
+ else:
711
+ self.final_layer = FinalLayer(self.hidden_size, self.patch_size, self.out_channels)
712
+
713
+ # Initialize weights
714
+ self.initialize_weights()
715
+
716
+ def initialize_weights(self):
717
+ # Initialize transformer layers
718
+ def _basic_init(module):
719
+ if isinstance(module, nn.Linear):
720
+ torch.nn.init.xavier_uniform_(module.weight)
721
+ if module.bias is not None:
722
+ nn.init.constant_(module.bias, 0)
723
+ self.apply(_basic_init)
724
+
725
+ # Initialize positional embedding with sin-cos
726
+ grid_size = int(self.num_patches ** 0.5)
727
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_size)
728
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
729
+
730
+ # Initialize patch_embed like nn.Linear
731
+ w = self.patch_embed.weight.data
732
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
733
+ if self.patch_embed.bias is not None:
734
+ nn.init.constant_(self.patch_embed.bias, 0)
735
+
736
+ # Initialize timestep embedding MLP
737
+ nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
738
+ nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
739
+
740
+ # Zero-out adaLN modulation layers in DiT blocks (from DiT paper)
741
+ for block in self.layers:
742
+ if block.adaLN_modulation is not None:
743
+ # Original AdaLN mode
744
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
745
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
746
+ # AdaLN-Single mode: scale_shift_table is already initialized with randn/sqrt(hidden_size)
747
+
748
+ # Zero-out cross-attention output projection (from PixArt-Alpha)
749
+ if self.use_text and hasattr(block, 'cross_attn'):
750
+ nn.init.constant_(block.cross_attn.out_proj.weight, 0)
751
+ nn.init.constant_(block.cross_attn.out_proj.bias, 0)
752
+
753
+ # Initialize text projection layer (analogous to PixArt's caption embedding)
754
+ if self.use_text and hasattr(self, 'text_proj'):
755
+ nn.init.normal_(self.text_proj.weight, std=0.02)
756
+ if self.text_proj.bias is not None:
757
+ nn.init.constant_(self.text_proj.bias, 0)
758
+
759
+ # Initialize class embedding layer (similar to DiT paper)
760
+ if self.use_class and hasattr(self, 'class_embed'):
761
+ nn.init.normal_(self.class_embed.weight, std=0.02)
762
+
763
+ # Initialize global t_block for AdaLN-Single
764
+ if self.use_adaln_single and hasattr(self, 't_block'):
765
+ nn.init.normal_(self.t_block[1].weight, std=0.02)
766
+ # Zero-out t_block initially for stability
767
+ nn.init.constant_(self.t_block[1].bias, 0)
768
+
769
+ # Zero-out output layers
770
+ if hasattr(self.final_layer, 'adaLN_modulation') and self.final_layer.adaLN_modulation is not None:
771
+ # Original FinalLayer
772
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
773
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
774
+ # T2IFinalLayer scale_shift_table is already initialized with randn/sqrt(hidden_size)
775
+ nn.init.constant_(self.final_layer.linear.weight, 0)
776
+ nn.init.constant_(self.final_layer.linear.bias, 0)
777
+
778
+ def forward(self, xt: torch.Tensor, t: torch.Tensor, text_embeds: Optional[torch.Tensor] = None,
779
+ attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
780
+ B, C, H, W = xt.shape
781
+
782
+ # Handle timestep scaling - DiT expects timesteps in [0, 999] range
783
+ # If t is normalized (in [0, 1]), scale it to [0, 999]
784
+ if t.max() <= 1.0 and t.min() >= 0.0:
785
+ # Normalized timesteps, scale to DiT range
786
+ t = t * 999.0
787
+ # Ensure t is in correct range for DiT
788
+ t = t.clamp(0, 999)
789
+
790
+ # Patchify
791
+ x = self.patch_embed(xt) # [B, hidden_size, H//p, W//p]
792
+ x = x.flatten(2).transpose(1, 2) # [B, num_patches, hidden_size]
793
+ x = x + self.pos_embed # Add positional embedding
794
+
795
+ # Prepare conditioning
796
+ time_emb = self.time_embed(t) # [B, hidden_size]
797
+
798
+ # Add class conditioning to time embedding if using class conditioning
799
+ if self.use_class and class_labels is not None:
800
+ class_emb = self.class_embed(class_labels) # [B, hidden_size]
801
+ time_emb = time_emb + class_emb # Additive combination following DiT paper
802
+
803
+ # Process conditioning based on AdaLN mode
804
+ if self.use_adaln_single:
805
+ # AdaLN-Single: compute global modulation once
806
+ c = self.t_block(time_emb) # [B, 6*hidden_size]
807
+ else:
808
+ # Original AdaLN: pass time embedding to each block
809
+ c = time_emb
810
+
811
+ # Prepare text tokens for cross-attention (not fused with time)
812
+ text_tokens = None
813
+ if self.use_text and text_embeds is not None:
814
+ if text_embeds.dim() == 3:
815
+ text_tokens = self.text_proj(text_embeds) # [B, T, hidden_size]
816
+ text_tokens = self.text_norm(text_tokens)
817
+ else:
818
+ text_tokens = self.text_proj(text_embeds).unsqueeze(1) # [B, 1, hidden_size]
819
+ text_tokens = self.text_norm(text_tokens)
820
+
821
+ if attention_mask is not None:
822
+ # cast to bool, clamp shapes to text_tokens length
823
+ attention_mask = attention_mask[:, :text_tokens.shape[1]].to(torch.bool)
824
+ # safety: avoid all-false rows (would yield NaNs in softmax)
825
+ all_false = attention_mask.sum(dim=1) == 0
826
+ if all_false.any():
827
+ attention_mask[all_false, 0] = True
828
+
829
+ # Apply transformer blocks
830
+ for layer in self.layers:
831
+ x = layer(x, c, text_tokens, attention_mask)
832
+
833
+ # Final projection
834
+ if self.use_adaln_single:
835
+ # T2IFinalLayer expects original time embedding, not global modulation
836
+ x = self.final_layer(x, time_emb) # [B, num_patches, patch_size^2 * out_channels]
837
+ else:
838
+ # Original FinalLayer expects conditioning
839
+ x = self.final_layer(x, c) # [B, num_patches, patch_size^2 * out_channels]
840
+
841
+ # Unpatchify
842
+ patch_h = patch_w = int(self.num_patches ** 0.5)
843
+ x = x.view(B, patch_h, patch_w, self.patch_size, self.patch_size, self.out_channels)
844
+ x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
845
+ x = x.view(B, self.out_channels, H, W)
846
+
847
+ return x
848
+
849
+ def compute_loss(self, x0: torch.Tensor, text_embeds: Optional[torch.Tensor] = None,
850
+ attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None,
851
+ null_text_embeds: Optional[torch.Tensor] = None, null_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
852
+ """Unified loss computation based on objective type"""
853
+ if self.objective_type == "ddpm":
854
+ return self.ddpm_loss(x0, text_embeds, attention_mask, class_labels, null_text_embeds, null_attention_mask)
855
+ elif self.objective_type == "fm":
856
+ return self.flow_matching_loss(x0, text_embeds, attention_mask, class_labels, null_text_embeds, null_attention_mask)
857
+ elif self.objective_type == "rf":
858
+ return self.rectified_flow_loss(x0, text_embeds, attention_mask, class_labels, null_text_embeds, null_attention_mask)
859
+ else:
860
+ raise ValueError(f"Unknown objective type: {self.objective_type}")
861
+
862
+ def ddpm_loss(self, x0: torch.Tensor, text_embeds: Optional[torch.Tensor] = None,
863
+ attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None,
864
+ null_text_embeds: Optional[torch.Tensor] = None, null_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
865
+ """DDPM: predict noise ε"""
866
+ B = x0.shape[0]
867
+ device = x0.device
868
+
869
+ # Sample time uniformly
870
+ t = torch.rand(B, device=device)
871
+
872
+ # Use proper schedule (NEW)
873
+ alpha_t, sigma_t = self.schedule.get_schedule(t)
874
+
875
+ noise = torch.randn_like(x0)
876
+ xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise
877
+
878
+ # Apply CFG dropout during training
879
+ if self.training and self.cfg_dropout_prob > 0:
880
+ if self.use_text and text_embeds is not None:
881
+ keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) # True = keep text
882
+
883
+ if null_text_embeds is not None:
884
+ # Use provided null text embeddings (from empty string CLIP encoding)
885
+ if null_text_embeds.shape[0] == 1:
886
+ null_text_embeds = null_text_embeds.expand(B, -1, -1)
887
+
888
+ # Replace dropped samples with null text embeddings
889
+ dropped = ~keep
890
+ if dropped.any():
891
+ text_embeds = text_embeds.clone()
892
+ text_embeds[dropped] = null_text_embeds[dropped]
893
+
894
+ # Use provided null attention mask or create default for empty string
895
+ if attention_mask is not None:
896
+ attention_mask = attention_mask.clone()
897
+ if null_attention_mask is not None:
898
+ if null_attention_mask.shape[0] == 1:
899
+ null_attention_mask = null_attention_mask.expand(B, -1)
900
+ attention_mask[dropped] = null_attention_mask[dropped]
901
+ else:
902
+ attention_mask[dropped] = 0
903
+ attention_mask[dropped, 0] = 1
904
+ else:
905
+ # Fallback to old zeroing approach if null_text_embeds not provided
906
+ if text_embeds.dim() == 3: # [B, T, D]
907
+ text_embeds = text_embeds * keep[:, None, None].to(text_embeds.dtype)
908
+ else: # [B, D]
909
+ text_embeds = text_embeds * keep[:, None].to(text_embeds.dtype)
910
+
911
+ if attention_mask is not None:
912
+ attention_mask = attention_mask.clone()
913
+ dropped = ~keep
914
+ if dropped.any():
915
+ attention_mask[dropped, 0] = 1
916
+
917
+ elif self.use_class and class_labels is not None:
918
+ # Apply CFG dropout to class labels using null class embedding
919
+ keep = (torch.rand(B, device=device) > self.cfg_dropout_prob)
920
+ null_class = torch.full_like(class_labels, self.null_class_id)
921
+ class_labels = torch.where(keep, class_labels, null_class)
922
+
923
+ # Predict noise
924
+ pred_eps = self.forward(xt, t, text_embeds, attention_mask, class_labels)
925
+
926
+ return F.mse_loss(pred_eps, noise)
927
+
928
+ def rectified_flow_loss(self, x0: torch.Tensor, text_embeds: Optional[torch.Tensor] = None,
929
+ attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None,
930
+ null_text_embeds: Optional[torch.Tensor] = None, null_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
931
+ """Rectified Flow: predict velocity v = x_1 - x_0 (straight paths)"""
932
+ B = x0.shape[0]
933
+ device = x0.device
934
+
935
+ # Sample time uniformly
936
+ t = torch.rand(B, device=device)
937
+
938
+ # Straight-line interpolation
939
+ x1 = torch.randn_like(x0) # Gaussian noise as x_1
940
+ xt = (1 - t).view(-1, 1, 1, 1) * x0 + t.view(-1, 1, 1, 1) * x1
941
+
942
+ # Apply CFG dropout during training
943
+ if self.training and self.cfg_dropout_prob > 0:
944
+ if self.use_text and text_embeds is not None:
945
+ keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) # True = keep text
946
+
947
+ if null_text_embeds is not None:
948
+ # Use provided null text embeddings (from empty string CLIP encoding)
949
+ if null_text_embeds.shape[0] == 1:
950
+ null_text_embeds = null_text_embeds.expand(B, -1, -1)
951
+
952
+ # Replace dropped samples with null text embeddings
953
+ dropped = ~keep
954
+ if dropped.any():
955
+ text_embeds = text_embeds.clone()
956
+ text_embeds[dropped] = null_text_embeds[dropped]
957
+
958
+ # Use provided null attention mask or create default for empty string
959
+ if attention_mask is not None:
960
+ attention_mask = attention_mask.clone()
961
+ if null_attention_mask is not None:
962
+ if null_attention_mask.shape[0] == 1:
963
+ null_attention_mask = null_attention_mask.expand(B, -1)
964
+ attention_mask[dropped] = null_attention_mask[dropped]
965
+ else:
966
+ attention_mask[dropped] = 0
967
+ attention_mask[dropped, 0] = 1
968
+ else:
969
+ # Fallback to old zeroing approach if null_text_embeds not provided
970
+ if text_embeds.dim() == 3: # [B, T, D]
971
+ text_embeds = text_embeds * keep[:, None, None].to(text_embeds.dtype)
972
+ else: # [B, D]
973
+ text_embeds = text_embeds * keep[:, None].to(text_embeds.dtype)
974
+
975
+ if attention_mask is not None:
976
+ attention_mask = attention_mask.clone()
977
+ dropped = ~keep
978
+ if dropped.any():
979
+ attention_mask[dropped, 0] = 1
980
+
981
+ elif self.use_class and class_labels is not None:
982
+ # Apply CFG dropout to class labels using null class embedding
983
+ keep = (torch.rand(B, device=device) > self.cfg_dropout_prob)
984
+ null_class = torch.full_like(class_labels, self.null_class_id)
985
+ class_labels = torch.where(keep, class_labels, null_class)
986
+
987
+ # Predict velocity (x_1 - x_0)
988
+ pred_v = self.forward(xt, t, text_embeds, attention_mask, class_labels)
989
+ true_v = x1 - x0
990
+
991
+ return F.mse_loss(pred_v, true_v)
992
+
993
+ def flow_matching_loss(self, x0: torch.Tensor, text_embeds: Optional[torch.Tensor] = None,
994
+ attention_mask: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None,
995
+ null_text_embeds: Optional[torch.Tensor] = None, null_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
996
+ """Flow matching loss for latent space training with CFG dropout."""
997
+ B = x0.shape[0]
998
+ device = x0.device
999
+
1000
+ # Sample time uniformly
1001
+ t = torch.rand(B, device=device)
1002
+
1003
+ # Use proper schedule (NEW)
1004
+ alpha_t, sigma_t = self.schedule.get_schedule(t)
1005
+
1006
+ noise = torch.randn_like(x0)
1007
+ xt = alpha_t.view(-1, 1, 1, 1) * x0 + sigma_t.view(-1, 1, 1, 1) * noise
1008
+
1009
+ # Apply CFG dropout during training
1010
+ if self.training and self.cfg_dropout_prob > 0:
1011
+ if self.use_text and text_embeds is not None:
1012
+ keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) # True = keep text
1013
+
1014
+ if null_text_embeds is not None:
1015
+ # Use provided null text embeddings (from empty string CLIP encoding)
1016
+ # Ensure null_text_embeds matches the batch size
1017
+ if null_text_embeds.shape[0] == 1:
1018
+ null_text_embeds = null_text_embeds.expand(B, -1, -1)
1019
+
1020
+ # Replace dropped samples with null text embeddings
1021
+ dropped = ~keep
1022
+ if dropped.any():
1023
+ text_embeds = text_embeds.clone()
1024
+ text_embeds[dropped] = null_text_embeds[dropped]
1025
+
1026
+ # Use provided null attention mask or create default for empty string
1027
+ if attention_mask is not None:
1028
+ attention_mask = attention_mask.clone()
1029
+ if null_attention_mask is not None:
1030
+ # Ensure null_attention_mask matches batch size
1031
+ if null_attention_mask.shape[0] == 1:
1032
+ null_attention_mask = null_attention_mask.expand(B, -1)
1033
+ attention_mask[dropped] = null_attention_mask[dropped]
1034
+ else:
1035
+ # Default: For null text (empty string), typically only the first token is valid
1036
+ attention_mask[dropped] = 0
1037
+ attention_mask[dropped, 0] = 1 # Keep only first token for empty string
1038
+ else:
1039
+ # Fallback to old zeroing approach if null_text_embeds not provided
1040
+ if text_embeds.dim() == 3: # [B, T, D]
1041
+ text_embeds = text_embeds * keep[:, None, None].to(text_embeds.dtype)
1042
+ else: # [B, D]
1043
+ text_embeds = text_embeds * keep[:, None].to(text_embeds.dtype)
1044
+
1045
+ # Handle attention mask for fallback approach
1046
+ if attention_mask is not None:
1047
+ attention_mask = attention_mask.clone()
1048
+ dropped = ~keep
1049
+ if dropped.any():
1050
+ attention_mask[dropped, 0] = 1
1051
+
1052
+ elif self.use_class and class_labels is not None:
1053
+ # Apply CFG dropout to class labels using null class embedding
1054
+ keep = (torch.rand(B, device=device) > self.cfg_dropout_prob) # True = keep class
1055
+ # Use the dedicated null class embedding for unconditional generation
1056
+ null_class = torch.full_like(class_labels, self.null_class_id)
1057
+ class_labels = torch.where(keep, class_labels, null_class)
1058
+
1059
+ # Predict velocity
1060
+ pred_v = self.forward(xt, t, text_embeds, attention_mask, class_labels)
1061
+ true_v = noise - x0
1062
+
1063
+ return F.mse_loss(pred_v, true_v)
1064
+
1065
+ # =============================================================================
1066
+ # ROUTER MODELS
1067
+ # =============================================================================
1068
+
1069
+ class ViTRouter(nn.Module):
1070
+ """ViT-based router for cluster classification"""
1071
+
1072
+ def __init__(self, config) -> None:
1073
+ super().__init__()
1074
+
1075
+ # Default params
1076
+ default_params = {
1077
+ "hidden_size": 384,
1078
+ "num_layers": 6,
1079
+ "num_heads": 6,
1080
+ "patch_size": 8,
1081
+ "use_dit_time_embed": False, # Whether to use DiT-style time embedding
1082
+ }
1083
+ params = {**default_params, **config.router_params}
1084
+
1085
+ if config.router_pretrained:
1086
+ # Use pretrained ViT and adapt
1087
+ self.vit = ViTForImageClassification.from_pretrained(
1088
+ "google/vit-base-patch16-224"
1089
+ )
1090
+ self._adapt_pretrained(config, params)
1091
+ else:
1092
+ # Build from scratch
1093
+ vit_config = ViTConfig(
1094
+ image_size=config.image_size,
1095
+ patch_size=params["patch_size"],
1096
+ num_channels=config.num_channels,
1097
+ hidden_size=params["hidden_size"],
1098
+ num_hidden_layers=params["num_layers"],
1099
+ num_attention_heads=params["num_heads"],
1100
+ num_labels=config.num_clusters
1101
+ )
1102
+ self.vit = ViTForImageClassification(vit_config)
1103
+
1104
+ # Time conditioning - support both embedding styles
1105
+ self.use_dit_time_embed = params.get("use_dit_time_embed", False)
1106
+ if self.use_dit_time_embed:
1107
+ # Use DiT-style timestep embedding for consistency
1108
+ self.time_embedding = DiTTimestepEmbedder(params["hidden_size"])
1109
+ else:
1110
+ # Original simple time embedding
1111
+ self.time_embedding = nn.Sequential(
1112
+ nn.Linear(1, params["hidden_size"]),
1113
+ nn.SiLU(),
1114
+ nn.Linear(params["hidden_size"], params["hidden_size"])
1115
+ )
1116
+
1117
+ # Combined classifier
1118
+ self.classifier = nn.Sequential(
1119
+ nn.Linear(params["hidden_size"] * 2, params["hidden_size"]),
1120
+ nn.ReLU(),
1121
+ nn.Dropout(0.1),
1122
+ nn.Linear(params["hidden_size"], config.num_clusters)
1123
+ )
1124
+
1125
+ def _adapt_pretrained(self, config, params) -> ViTForImageClassification:
1126
+ """Adapt pretrained ViT for our task"""
1127
+ # Modify patch embeddings if needed
1128
+ if config.image_size != 224 or config.num_channels != 3:
1129
+ self.vit.vit.embeddings.patch_embeddings.projection = nn.Conv2d(
1130
+ config.num_channels,
1131
+ self.vit.config.hidden_size,
1132
+ kernel_size=params["patch_size"],
1133
+ stride=params["patch_size"]
1134
+ )
1135
+
1136
+ def forward(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
1137
+ # Process image through ViT
1138
+ vit_outputs = self.vit.vit(xt)
1139
+ image_features = vit_outputs.last_hidden_state[:, 0] # CLS token
1140
+
1141
+ # Time conditioning
1142
+ if self.use_dit_time_embed:
1143
+ # DiT embedder expects raw timesteps
1144
+ time_features = self.time_embedding(t)
1145
+ else:
1146
+ # Original embedding needs unsqueeze
1147
+ time_features = self.time_embedding(t.unsqueeze(-1))
1148
+
1149
+ # Combine and classify
1150
+ combined = torch.cat([image_features, time_features], dim=1)
1151
+ return self.classifier(combined)
1152
+
1153
+ class CNNRouter(nn.Module):
1154
+ """Simple CNN router for cluster classification"""
1155
+
1156
+ def __init__(self, config) -> None:
1157
+ super().__init__()
1158
+
1159
+ # Default params
1160
+ default_params = {
1161
+ "hidden_dims": [64, 128, 256],
1162
+ "use_dit_time_embed": False, # Whether to use DiT-style time embedding
1163
+ }
1164
+ params = {**default_params, **config.router_params}
1165
+
1166
+ # CNN backbone
1167
+ self.backbone = self._build_cnn(config.num_channels, params["hidden_dims"])
1168
+
1169
+ # Time embedding - support both styles
1170
+ self.use_dit_time_embed = params.get("use_dit_time_embed", False)
1171
+ if self.use_dit_time_embed:
1172
+ # Use DiT-style timestep embedding, output to 128 dims for CNN
1173
+ self.time_embedding = DiTTimestepEmbedder(128)
1174
+ else:
1175
+ # Original simple time embedding
1176
+ self.time_embedding = nn.Sequential(
1177
+ nn.Linear(1, 128),
1178
+ nn.SiLU(),
1179
+ nn.Linear(128, 128)
1180
+ )
1181
+
1182
+ # Classifier
1183
+ self.classifier = nn.Sequential(
1184
+ nn.Linear(params["hidden_dims"][-1] + 128, 256),
1185
+ nn.ReLU(),
1186
+ nn.Dropout(0.1),
1187
+ nn.Linear(256, config.num_clusters)
1188
+ )
1189
+
1190
+ def _build_cnn(self, in_channels: int, hidden_dims: List[int]) -> nn.Sequential:
1191
+ layers = []
1192
+ prev_dim = in_channels
1193
+
1194
+ for dim in hidden_dims:
1195
+ layers.extend([
1196
+ nn.Conv2d(prev_dim, dim, 3, padding=1),
1197
+ nn.ReLU(),
1198
+ nn.Conv2d(dim, dim, 3, padding=1),
1199
+ nn.ReLU(),
1200
+ nn.MaxPool2d(2)
1201
+ ])
1202
+ prev_dim = dim
1203
+
1204
+ layers.append(nn.AdaptiveAvgPool2d(1))
1205
+ layers.append(nn.Flatten())
1206
+
1207
+ return nn.Sequential(*layers)
1208
+
1209
+ def forward(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
1210
+ # CNN features
1211
+ img_features = self.backbone(xt)
1212
+
1213
+ # Time features
1214
+ if self.use_dit_time_embed:
1215
+ # DiT embedder expects raw timesteps
1216
+ time_features = self.time_embedding(t)
1217
+ else:
1218
+ # Original embedding needs unsqueeze
1219
+ time_features = self.time_embedding(t.unsqueeze(-1))
1220
+
1221
+ # Combine and classify
1222
+ combined = torch.cat([img_features, time_features], dim=1)
1223
+ return self.classifier(combined)
1224
+
1225
+ class DiTRouter(nn.Module):
1226
+ """DiT B/2 router for cluster classification"""
1227
+
1228
+ def __init__(self, config):
1229
+ super().__init__()
1230
+
1231
+ # DiT B/2 specifications
1232
+ default_params = {
1233
+ "hidden_size": 768, # DiT-B uses 768
1234
+ "num_layers": 12, # DiT-B uses 12 layers
1235
+ "num_heads": 12, # DiT-B uses 12 heads
1236
+ "patch_size": 2, # For latent space (32x32 -> 16x16 patches)
1237
+ "in_channels": 4, # VAE latent channels
1238
+ "mlp_ratio": 4.0,
1239
+ "use_dit_time_embed": False, # Whether to use DiT-style time embedding
1240
+ }
1241
+ params = {**default_params, **config.router_params}
1242
+
1243
+ self.patch_size = params["patch_size"]
1244
+ self.in_channels = params["in_channels"]
1245
+ self.hidden_size = params["hidden_size"]
1246
+ self.num_heads = params["num_heads"]
1247
+ self.num_clusters = config.num_clusters
1248
+
1249
+ # Patch embedding (same as expert)
1250
+ self.patch_embed = nn.Conv2d(
1251
+ self.in_channels, self.hidden_size,
1252
+ kernel_size=self.patch_size, stride=self.patch_size
1253
+ )
1254
+
1255
+ # Calculate number of patches
1256
+ latent_size = getattr(config, 'image_size', 32) # Assuming 256/8=32 for VAE
1257
+ self.num_patches = (latent_size // self.patch_size) ** 2
1258
+
1259
+ # Fixed sin-cos positional embedding (same as expert)
1260
+ self.pos_embed = nn.Parameter(
1261
+ torch.zeros(1, self.num_patches, self.hidden_size),
1262
+ requires_grad=False
1263
+ )
1264
+
1265
+ # CLS token (KEY ADDITION from paper)
1266
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
1267
+
1268
+ # Time embedding - match expert's choice
1269
+ self.use_dit_time_embed = params.get("use_dit_time_embed", False)
1270
+ if self.use_dit_time_embed:
1271
+ self.time_embed = DiTTimestepEmbedder(self.hidden_size)
1272
+ else:
1273
+ self.time_embed = TimestepEmbedder(self.hidden_size)
1274
+
1275
+ # DiT blocks with AdaLN (reuse DiTBlock from expert)
1276
+ # Note: Router doesn't need text conditioning
1277
+ self.layers = nn.ModuleList([
1278
+ DiTBlock(self.hidden_size, self.num_heads, params["mlp_ratio"], use_text=False)
1279
+ for _ in range(params["num_layers"])
1280
+ ])
1281
+
1282
+ # Final layer norm
1283
+ self.norm_final = nn.LayerNorm(self.hidden_size, elementwise_affine=False, eps=1e-6)
1284
+
1285
+ # Linear classifier on CLS token (as specified in paper)
1286
+ # self.head = nn.Linear(self.hidden_size, self.num_clusters)
1287
+ self.head = nn.Sequential(
1288
+ nn.Linear(self.hidden_size, self.hidden_size),
1289
+ nn.GELU(),
1290
+ nn.LayerNorm(self.hidden_size),
1291
+ nn.Dropout(0.1),
1292
+ nn.Linear(self.hidden_size, self.num_clusters)
1293
+ )
1294
+
1295
+ # Initialize weights
1296
+ self.initialize_weights()
1297
+
1298
+ def initialize_weights(self):
1299
+ # Initialize transformer layers
1300
+ def _basic_init(module):
1301
+ if isinstance(module, nn.Linear):
1302
+ torch.nn.init.xavier_uniform_(module.weight)
1303
+ if module.bias is not None:
1304
+ nn.init.constant_(module.bias, 0)
1305
+ self.apply(_basic_init)
1306
+
1307
+ # Initialize CLS token
1308
+ nn.init.normal_(self.cls_token, std=0.02)
1309
+
1310
+ # Initialize positional embedding with sin-cos (same as expert)
1311
+ grid_size = int(self.num_patches ** 0.5)
1312
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_size)
1313
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
1314
+
1315
+ # Initialize patch_embed like nn.Linear
1316
+ w = self.patch_embed.weight.data
1317
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
1318
+ if self.patch_embed.bias is not None:
1319
+ nn.init.constant_(self.patch_embed.bias, 0)
1320
+
1321
+ # Initialize timestep embedding MLP
1322
+ if hasattr(self.time_embed, 'mlp'):
1323
+ nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
1324
+ nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
1325
+
1326
+ # Zero-out adaLN modulation in blocks (following expert initialization)
1327
+ for block in self.layers:
1328
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
1329
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
1330
+
1331
+ # # Initialize classification head (simpler version for classification head)
1332
+ # nn.init.constant_(self.head.weight, 0)
1333
+ # nn.init.constant_(self.head.bias, 0)
1334
+
1335
+ # Initialize classification head (Sequential)
1336
+ # Initialize intermediate layers normally, zero-out final layer
1337
+ nn.init.normal_(self.head[0].weight, std=0.02) # First linear layer
1338
+ if self.head[0].bias is not None:
1339
+ nn.init.constant_(self.head[0].bias, 0)
1340
+
1341
+ # Zero-out final classification layer (following DiT paper)
1342
+ nn.init.constant_(self.head[-1].weight, 0) # Last linear layer
1343
+ if self.head[-1].bias is not None:
1344
+ nn.init.constant_(self.head[-1].bias, 0)
1345
+
1346
+ def forward(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
1347
+ B, C, H, W = xt.shape
1348
+
1349
+ # Match expert's timestep interpretation
1350
+ if t.max() <= 1.0 and t.min() >= 0.0:
1351
+ t = t * 999.0
1352
+ t = t.clamp(0, 999)
1353
+
1354
+ # Patchify
1355
+ x = self.patch_embed(xt) # [B, hidden_size, H//p, W//p]
1356
+ x = x.flatten(2).transpose(1, 2) # [B, num_patches, hidden_size]
1357
+
1358
+ # Add positional embedding
1359
+ x = x + self.pos_embed
1360
+
1361
+ # Prepend CLS token
1362
+ cls_tokens = self.cls_token.expand(B, -1, -1) # [B, 1, hidden_size]
1363
+ x = torch.cat([cls_tokens, x], dim=1) # [B, 1 + num_patches, hidden_size]
1364
+
1365
+ # Time conditioning
1366
+ c = self.time_embed(t) # [B, hidden_size]
1367
+
1368
+ # Apply DiT blocks with AdaLN modulation
1369
+ for layer in self.layers:
1370
+ x = layer(x, c, text_emb=None)
1371
+
1372
+ # Extract CLS token and apply final norm
1373
+ cls_output = x[:, 0] # [B, hidden_size]
1374
+ cls_output = self.norm_final(cls_output)
1375
+
1376
+ # Linear classification head
1377
+ logits = self.head(cls_output) # [B, num_clusters]
1378
+
1379
+ return logits
1380
+
1381
+ # =============================================================================
1382
+ # DETERMINISTIC ROUTER (for controlled experiments)
1383
+ # =============================================================================
1384
+
1385
+ class DeterministicTimestepRouter(nn.Module):
1386
+ """
1387
+ Deterministic router that assigns experts based on timestep.
1388
+
1389
+ Useful for controlled experiments where you want to test specific routing strategies,
1390
+ such as: "high noise → DDPM expert, low noise → FM expert"
1391
+
1392
+ Args:
1393
+ config: Config object with router_params containing:
1394
+ - timestep_threshold: t value to switch experts (default: 0.5)
1395
+ - high_noise_expert: Expert ID for t > threshold (default: 0, typically DDPM)
1396
+ - low_noise_expert: Expert ID for t <= threshold (default: 1, typically FM)
1397
+
1398
+ Example config:
1399
+ router_architecture: "deterministic_timestep"
1400
+ router_params:
1401
+ timestep_threshold: 0.5
1402
+ high_noise_expert: 0 # DDPM for high noise
1403
+ low_noise_expert: 1 # FM for low noise
1404
+ """
1405
+
1406
+ def __init__(self, config):
1407
+ super().__init__()
1408
+ self.num_experts = config.num_experts
1409
+ self.threshold = config.router_params.get('timestep_threshold', 0.5)
1410
+ self.high_noise_expert = config.router_params.get('high_noise_expert', 0)
1411
+ self.low_noise_expert = config.router_params.get('low_noise_expert', 1)
1412
+
1413
+ # Validate expert IDs
1414
+ assert 0 <= self.high_noise_expert < self.num_experts, \
1415
+ f"high_noise_expert {self.high_noise_expert} out of range [0, {self.num_experts})"
1416
+ assert 0 <= self.low_noise_expert < self.num_experts, \
1417
+ f"low_noise_expert {self.low_noise_expert} out of range [0, {self.num_experts})"
1418
+
1419
+ # Validate threshold
1420
+ assert 0.0 <= self.threshold <= 1.0, \
1421
+ f"timestep_threshold {self.threshold} must be in [0, 1]"
1422
+
1423
+ # This router has no trainable parameters
1424
+ # Register threshold as buffer (not trained, but saved with model)
1425
+ self.register_buffer('_threshold', torch.tensor(self.threshold))
1426
+
1427
+ print(f"DeterministicTimestepRouter initialized:")
1428
+ print(f" Threshold: {self.threshold}")
1429
+ print(f" High noise (t > {self.threshold}) → Expert {self.high_noise_expert}")
1430
+ print(f" Low noise (t <= {self.threshold}) → Expert {self.low_noise_expert}")
1431
+
1432
+ def forward(self, x: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
1433
+ """
1434
+ Returns one-hot routing probabilities based on timestep.
1435
+
1436
+ Args:
1437
+ x: Input tensor (unused, but kept for API compatibility with other routers)
1438
+ t: Timesteps, shape (B,)
1439
+
1440
+ Returns:
1441
+ routing_probs: Shape (B, num_experts), one-hot encoded
1442
+ """
1443
+ B = t.shape[0]
1444
+ device = t.device
1445
+
1446
+ # Initialize routing probabilities (all zeros)
1447
+ routing_probs = torch.zeros(B, self.num_experts, device=device)
1448
+
1449
+ # High noise (t > threshold) → high_noise_expert
1450
+ # Low noise (t <= threshold) → low_noise_expert
1451
+ high_noise_mask = t > self.threshold
1452
+ routing_probs[high_noise_mask, self.high_noise_expert] = 1.0
1453
+ routing_probs[~high_noise_mask, self.low_noise_expert] = 1.0
1454
+
1455
+ return routing_probs
1456
+
1457
+ def train(self, mode: bool = True):
1458
+ """Override train() - this router is never trained, always in eval mode"""
1459
+ return super(DeterministicTimestepRouter, self).train(False)
1460
+
1461
+ # =============================================================================
1462
+ # ADAPTIVE VIDEO ROUTER (for Video DDM)
1463
+ # =============================================================================
1464
+
1465
+ class AdaptiveVideoRouter(nn.Module):
1466
+ """
1467
+ Time-adaptive router for video DDM.
1468
+
1469
+ Key innovation: Learns optimal weighting of information sources
1470
+ at each noise level, solving the "motion invisible at t=1" problem.
1471
+
1472
+ Information availability is time-dependent:
1473
+ t ~ 1.0: Only text/first_frame informative → Route on conditioning
1474
+ t ~ 0.5: Structure emerging → Latent becomes useful
1475
+ t ~ 0.1: Near clean → Full information available
1476
+
1477
+ Expected learned behavior:
1478
+ | Noise Level | Text | Frame | Latent | Behavior |
1479
+ |-------------|------|-------|--------|-----------------------------|
1480
+ | t ~ 1.0 | ~0.7 | ~0.2 | ~0.1 | Routes on text semantics |
1481
+ | t ~ 0.5 | ~0.4 | ~0.3 | ~0.3 | Balanced; emerging structure|
1482
+ | t ~ 0.1 | ~0.2 | ~0.2 | ~0.6 | Trusts latent; fine-grained |
1483
+
1484
+ Enhancements:
1485
+ - Masked mean pooling for text (handles variable-length prompts)
1486
+ - Temporal-aware latent encoder (captures motion patterns)
1487
+ - Temperature scaling for inference control
1488
+ """
1489
+
1490
+ def __init__(self, config):
1491
+ super().__init__()
1492
+
1493
+ # Default params
1494
+ default_params = {
1495
+ "hidden_dim": 512,
1496
+ "text_embed_dim": 768, # CLIP-L text embedding dimension
1497
+ "frame_embed_dim": 768, # DINOv2-B (base) feature dimension
1498
+ "latent_channels": 16, # VAE latent channels (CogVideoX uses 16)
1499
+ "latent_conv_dim": 64, # Intermediate conv channels for latent encoder
1500
+ "dropout": 0.1,
1501
+ "temporal_pool_mode": "attention", # "attention", "avg", or "max"
1502
+ "normalize_inputs": True, # L2-normalize text/frame inputs (match clustering)
1503
+ }
1504
+ params = {**default_params, **getattr(config, 'router_params', {})}
1505
+
1506
+ self.hidden_dim = params["hidden_dim"]
1507
+ self.num_experts = getattr(config, 'num_experts', config.num_clusters)
1508
+ self.latent_channels = params["latent_channels"]
1509
+ self.latent_conv_dim = params["latent_conv_dim"]
1510
+ self.temporal_pool_mode = params["temporal_pool_mode"]
1511
+ self.normalize_inputs = params.get("normalize_inputs", True)
1512
+
1513
+ # === Information Source Encoders ===
1514
+
1515
+ # Text pathway (always available, primary signal at high t)
1516
+ self.text_encoder = nn.Sequential(
1517
+ nn.Linear(params["text_embed_dim"], self.hidden_dim),
1518
+ nn.LayerNorm(self.hidden_dim),
1519
+ nn.GELU(),
1520
+ nn.Linear(self.hidden_dim, self.hidden_dim)
1521
+ )
1522
+
1523
+ # First frame pathway (available for I2V tasks)
1524
+ # Uses DINOv2 features extracted from the first frame
1525
+ self.frame_encoder = nn.Sequential(
1526
+ nn.Linear(params["frame_embed_dim"], self.hidden_dim),
1527
+ nn.LayerNorm(self.hidden_dim),
1528
+ nn.GELU(),
1529
+ nn.Linear(self.hidden_dim, self.hidden_dim)
1530
+ )
1531
+
1532
+ # === Temporal-Aware Latent Encoder ===
1533
+ # Captures both spatial content and temporal motion patterns
1534
+
1535
+ # Spatial feature extraction (per-frame)
1536
+ self.spatial_conv = nn.Sequential(
1537
+ nn.Conv3d(params["latent_channels"], params["latent_conv_dim"],
1538
+ kernel_size=(1, 3, 3), padding=(0, 1, 1)), # Spatial only
1539
+ nn.GroupNorm(8, params["latent_conv_dim"]),
1540
+ nn.GELU(),
1541
+ )
1542
+
1543
+ # Temporal feature extraction (motion patterns)
1544
+ self.temporal_conv = nn.Sequential(
1545
+ nn.Conv3d(params["latent_conv_dim"], params["latent_conv_dim"],
1546
+ kernel_size=(3, 1, 1), padding=(1, 0, 0)), # Temporal only
1547
+ nn.GroupNorm(8, params["latent_conv_dim"]),
1548
+ nn.GELU(),
1549
+ )
1550
+
1551
+ # Combined spatio-temporal processing
1552
+ self.st_conv = nn.Sequential(
1553
+ nn.Conv3d(params["latent_conv_dim"], params["latent_conv_dim"],
1554
+ kernel_size=3, padding=1), # Full 3D
1555
+ nn.GroupNorm(8, params["latent_conv_dim"]),
1556
+ nn.GELU(),
1557
+ )
1558
+
1559
+ # Spatial pooling (keep temporal dimension)
1560
+ self.spatial_pool = nn.AdaptiveAvgPool3d((None, 1, 1)) # [B, C, T, 1, 1]
1561
+
1562
+ # Temporal attention pooling (learns which frames matter for routing)
1563
+ if self.temporal_pool_mode == "attention":
1564
+ self.temporal_attn = nn.Sequential(
1565
+ nn.Linear(params["latent_conv_dim"], params["latent_conv_dim"] // 4),
1566
+ nn.Tanh(),
1567
+ nn.Linear(params["latent_conv_dim"] // 4, 1),
1568
+ )
1569
+
1570
+ # Motion feature extractor (frame differences)
1571
+ self.motion_encoder = nn.Sequential(
1572
+ nn.Linear(params["latent_conv_dim"], params["latent_conv_dim"]),
1573
+ nn.GELU(),
1574
+ nn.Linear(params["latent_conv_dim"], self.hidden_dim // 2),
1575
+ )
1576
+
1577
+ # Content feature projector
1578
+ self.content_proj = nn.Linear(params["latent_conv_dim"], self.hidden_dim // 2)
1579
+
1580
+ # Final latent projection (combines content + motion)
1581
+ self.latent_proj = nn.Sequential(
1582
+ nn.Linear(self.hidden_dim, self.hidden_dim),
1583
+ nn.LayerNorm(self.hidden_dim),
1584
+ )
1585
+
1586
+ # === Time-Dependent Weighting ===
1587
+
1588
+ # Time embedding using existing infrastructure
1589
+ self.time_embed = TimestepEmbedder(self.hidden_dim)
1590
+
1591
+ self.time_mlp = nn.Sequential(
1592
+ nn.Linear(self.hidden_dim, self.hidden_dim),
1593
+ nn.GELU(),
1594
+ nn.Linear(self.hidden_dim, self.hidden_dim)
1595
+ )
1596
+
1597
+ # Learns adaptive weighting: at high t → trust text; at low t → trust latent
1598
+ self.source_weighting = nn.Sequential(
1599
+ nn.Linear(self.hidden_dim, 128),
1600
+ nn.GELU(),
1601
+ nn.Linear(128, 3), # [text, frame, latent] weights
1602
+ nn.Softmax(dim=-1)
1603
+ )
1604
+
1605
+ # === Routing Head ===
1606
+
1607
+ self.router_head = nn.Sequential(
1608
+ nn.Linear(self.hidden_dim, self.hidden_dim),
1609
+ nn.GELU(),
1610
+ nn.LayerNorm(self.hidden_dim),
1611
+ nn.Dropout(params["dropout"]),
1612
+ nn.Linear(self.hidden_dim, self.num_experts)
1613
+ )
1614
+
1615
+ # Initialize weights
1616
+ self.initialize_weights()
1617
+
1618
+ def initialize_weights(self):
1619
+ """Initialize weights following DiT conventions."""
1620
+ def _basic_init(module):
1621
+ if isinstance(module, nn.Linear):
1622
+ torch.nn.init.xavier_uniform_(module.weight)
1623
+ if module.bias is not None:
1624
+ nn.init.constant_(module.bias, 0)
1625
+ elif isinstance(module, nn.Conv3d):
1626
+ # Flatten spatial dims for xavier init
1627
+ w = module.weight.data
1628
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
1629
+ if module.bias is not None:
1630
+ nn.init.constant_(module.bias, 0)
1631
+ self.apply(_basic_init)
1632
+
1633
+ # Initialize timestep embedding MLP (following DiT)
1634
+ if hasattr(self.time_embed, 'mlp'):
1635
+ nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
1636
+ nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
1637
+
1638
+ # Small non-zero initialization for final routing layer
1639
+ # (pure zeros cause uniform outputs that break temperature scaling)
1640
+ nn.init.normal_(self.router_head[-1].weight, std=0.01)
1641
+ nn.init.constant_(self.router_head[-1].bias, 0)
1642
+
1643
+ # Initialize source weighting to start roughly uniform
1644
+ # The softmax will make [0, 0, 0] → [0.33, 0.33, 0.33]
1645
+ nn.init.constant_(self.source_weighting[-2].weight, 0)
1646
+ nn.init.constant_(self.source_weighting[-2].bias, 0)
1647
+
1648
+ # Initialize temporal attention to uniform attention
1649
+ if self.temporal_pool_mode == "attention":
1650
+ nn.init.constant_(self.temporal_attn[-1].weight, 0)
1651
+ nn.init.constant_(self.temporal_attn[-1].bias, 0)
1652
+
1653
+ def _masked_mean_pool(self, embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
1654
+ """
1655
+ Compute masked mean pooling over sequence dimension.
1656
+
1657
+ Args:
1658
+ embeddings: [B, seq_len, embed_dim] - Token embeddings
1659
+ attention_mask: [B, seq_len] - 1 for real tokens, 0 for padding
1660
+
1661
+ Returns:
1662
+ pooled: [B, embed_dim] - Pooled representation
1663
+ """
1664
+ if attention_mask is None:
1665
+ # No mask provided, use simple mean
1666
+ return embeddings.mean(dim=1)
1667
+
1668
+ # Expand mask for broadcasting: [B, seq_len] -> [B, seq_len, 1]
1669
+ mask = attention_mask.unsqueeze(-1).to(embeddings.dtype)
1670
+
1671
+ # Masked sum
1672
+ masked_sum = (embeddings * mask).sum(dim=1) # [B, embed_dim]
1673
+
1674
+ # Count of valid tokens (avoid division by zero)
1675
+ token_counts = mask.sum(dim=1).clamp(min=1.0) # [B, 1]
1676
+
1677
+ return masked_sum / token_counts
1678
+
1679
+ def _encode_latent_temporal(self, x_t: torch.Tensor) -> torch.Tensor:
1680
+ """
1681
+ Encode video latent with temporal awareness.
1682
+
1683
+ Extracts both:
1684
+ - Content features: What is in the video (spatial)
1685
+ - Motion features: How things move (temporal differences)
1686
+
1687
+ Args:
1688
+ x_t: [B, C, T, H, W] - Noisy video latent
1689
+
1690
+ Returns:
1691
+ latent_feat: [B, hidden_dim] - Combined latent features
1692
+ """
1693
+ B, C, T, H, W = x_t.shape
1694
+
1695
+ # 1. Spatial feature extraction
1696
+ spatial_feat = self.spatial_conv(x_t) # [B, conv_dim, T, H, W]
1697
+
1698
+ # 2. Temporal feature extraction (captures local motion)
1699
+ temporal_feat = self.temporal_conv(spatial_feat) # [B, conv_dim, T, H, W]
1700
+
1701
+ # 3. Combined spatio-temporal processing
1702
+ st_feat = self.st_conv(temporal_feat) # [B, conv_dim, T, H, W]
1703
+
1704
+ # 4. Pool spatially, keep temporal: [B, conv_dim, T, 1, 1] -> [B, T, conv_dim]
1705
+ pooled = self.spatial_pool(st_feat).squeeze(-1).squeeze(-1) # [B, conv_dim, T]
1706
+ pooled = pooled.permute(0, 2, 1) # [B, T, conv_dim]
1707
+
1708
+ # 5. Temporal pooling with optional attention
1709
+ if self.temporal_pool_mode == "attention" and T > 1:
1710
+ # Learn which frames matter for routing
1711
+ attn_scores = self.temporal_attn(pooled).squeeze(-1) # [B, T]
1712
+ attn_weights = F.softmax(attn_scores, dim=-1) # [B, T]
1713
+ content_feat = (pooled * attn_weights.unsqueeze(-1)).sum(dim=1) # [B, conv_dim]
1714
+ elif self.temporal_pool_mode == "max":
1715
+ content_feat = pooled.max(dim=1)[0] # [B, conv_dim]
1716
+ else: # "avg"
1717
+ content_feat = pooled.mean(dim=1) # [B, conv_dim]
1718
+
1719
+ # 6. Extract motion features (frame differences)
1720
+ if T > 1:
1721
+ # Compute frame-to-frame differences
1722
+ frame_diffs = pooled[:, 1:] - pooled[:, :-1] # [B, T-1, conv_dim]
1723
+
1724
+ # Motion magnitude and direction encoding
1725
+ motion_feat = self.motion_encoder(frame_diffs.mean(dim=1)) # [B, hidden_dim//2]
1726
+ else:
1727
+ # Single frame, no motion
1728
+ motion_feat = torch.zeros(B, self.hidden_dim // 2, device=x_t.device)
1729
+
1730
+ # 7. Project content features
1731
+ content_proj = self.content_proj(content_feat) # [B, hidden_dim//2]
1732
+
1733
+ # 8. Combine content + motion
1734
+ combined = torch.cat([content_proj, motion_feat], dim=-1) # [B, hidden_dim]
1735
+ latent_feat = self.latent_proj(combined) # [B, hidden_dim]
1736
+
1737
+ return latent_feat
1738
+
1739
+ def forward(self, x_t: torch.Tensor, t: torch.Tensor,
1740
+ text_embed: torch.Tensor,
1741
+ first_frame_feat: Optional[torch.Tensor] = None,
1742
+ attention_mask: Optional[torch.Tensor] = None,
1743
+ temperature: float = 1.0) -> torch.Tensor:
1744
+ """
1745
+ Compute routing logits with time-adaptive information weighting.
1746
+
1747
+ Args:
1748
+ x_t: Noisy video latent [B, C, T, H, W]
1749
+ t: Noise level [B] in [0, 1] or [0, 999]
1750
+ text_embed: CLIP text embedding [B, text_embed_dim] or [B, seq_len, text_embed_dim]
1751
+ first_frame_feat: Optional DINOv2 features [B, frame_embed_dim]
1752
+ attention_mask: Optional [B, seq_len] mask for text (1=valid, 0=padding)
1753
+ temperature: Softmax temperature for sharper/softer routing (default: 1.0)
1754
+
1755
+ Returns:
1756
+ logits: Expert selection logits [B, num_experts] (scaled by temperature)
1757
+ """
1758
+ B = x_t.shape[0]
1759
+ device = x_t.device
1760
+
1761
+ # === Encode each information source ===
1762
+
1763
+ # Handle both pooled [B, D] and sequence [B, seq_len, D] text embeddings
1764
+ if text_embed.dim() == 3:
1765
+ # Use masked mean pooling for sequence embeddings
1766
+ text_embed_pooled = self._masked_mean_pool(text_embed, attention_mask)
1767
+ else:
1768
+ # Already pooled
1769
+ text_embed_pooled = text_embed
1770
+
1771
+ # L2-normalize inputs to match clustering preprocessing
1772
+ if self.normalize_inputs:
1773
+ text_embed_pooled = F.normalize(text_embed_pooled, p=2, dim=-1)
1774
+
1775
+ text_feat = self.text_encoder(text_embed_pooled) # [B, hidden_dim]
1776
+
1777
+ # Frame features (optional for T2V, required for I2V)
1778
+ if first_frame_feat is not None:
1779
+ # L2-normalize to match clustering preprocessing
1780
+ if self.normalize_inputs:
1781
+ first_frame_feat = F.normalize(first_frame_feat, p=2, dim=-1)
1782
+ frame_feat = self.frame_encoder(first_frame_feat) # [B, hidden_dim]
1783
+ else:
1784
+ frame_feat = torch.zeros(B, self.hidden_dim, device=device)
1785
+
1786
+ # Latent features from noisy video (temporal-aware encoding)
1787
+ latent_feat = self._encode_latent_temporal(x_t) # [B, hidden_dim]
1788
+
1789
+ # === Time-dependent weighting ===
1790
+
1791
+ # Normalize timesteps to [0, 999] for TimestepEmbedder
1792
+ if t.max() <= 1.0:
1793
+ t_scaled = t * 999.0
1794
+ else:
1795
+ t_scaled = t
1796
+ t_scaled = t_scaled.clamp(0, 999)
1797
+
1798
+ # Get time features
1799
+ time_emb = self.time_embed(t_scaled) # [B, hidden_dim]
1800
+ time_feat = self.time_mlp(time_emb) # [B, hidden_dim]
1801
+
1802
+ # Compute adaptive weights based on noise level
1803
+ # Network learns: high t → high text weight; low t → high latent weight
1804
+ weights = self.source_weighting(time_feat) # [B, 3]
1805
+
1806
+ # === Adaptive combination ===
1807
+
1808
+ combined = (
1809
+ weights[:, 0:1] * text_feat + # Text contribution
1810
+ weights[:, 1:2] * frame_feat + # Frame contribution
1811
+ weights[:, 2:3] * latent_feat # Latent contribution
1812
+ )
1813
+
1814
+ # Final routing decision (incorporate time context)
1815
+ logits = self.router_head(combined + time_feat)
1816
+
1817
+ # Apply temperature scaling (lower temp = sharper routing)
1818
+ if temperature != 1.0:
1819
+ logits = logits / temperature
1820
+
1821
+ return logits
1822
+
1823
+ def get_source_weights(self, t: torch.Tensor) -> torch.Tensor:
1824
+ """
1825
+ Get the learned source weights for given timesteps.
1826
+ Useful for debugging and visualization.
1827
+
1828
+ Args:
1829
+ t: Noise levels [B] in [0, 1] or [0, 999]
1830
+
1831
+ Returns:
1832
+ weights: Source weights [B, 3] for [text, frame, latent]
1833
+ """
1834
+ # Normalize timesteps
1835
+ if t.max() <= 1.0:
1836
+ t_scaled = t * 999.0
1837
+ else:
1838
+ t_scaled = t
1839
+ t_scaled = t_scaled.clamp(0, 999)
1840
+
1841
+ time_emb = self.time_embed(t_scaled)
1842
+ time_feat = self.time_mlp(time_emb)
1843
+ weights = self.source_weighting(time_feat)
1844
+
1845
+ return weights
1846
+
1847
+ # =============================================================================
1848
+ # MODEL FACTORY FUNCTIONS
1849
+ # =============================================================================
1850
+
1851
+ def create_expert(config, expert_id: Optional[int] = None) -> nn.Module:
1852
+ """
1853
+ Factory function to create expert model
1854
+
1855
+ Args:
1856
+ config: Config object
1857
+ expert_id: Optional expert ID for per-expert schedule/objective configuration
1858
+ """
1859
+ # Make a copy of config to avoid modifying the original
1860
+ import copy
1861
+ config = copy.copy(config)
1862
+ config.expert_params = config.expert_params.copy()
1863
+
1864
+ # Inject schedule_type into expert_params if not already present
1865
+ if "schedule_type" not in config.expert_params:
1866
+ # Check for per-expert schedule first (with backward compatibility)
1867
+ if (hasattr(config, 'expert_schedule_types') and
1868
+ config.expert_schedule_types and
1869
+ expert_id is not None and
1870
+ expert_id in config.expert_schedule_types):
1871
+ config.expert_params["schedule_type"] = config.expert_schedule_types[expert_id]
1872
+ else:
1873
+ # Use default schedule_type (with fallback for old configs)
1874
+ config.expert_params["schedule_type"] = getattr(config, 'schedule_type', 'linear_interp')
1875
+
1876
+ # Inject objective_type into expert_params if not already present
1877
+ if "objective_type" not in config.expert_params:
1878
+ # Check for per-expert objectives (with backward compatibility)
1879
+ if (hasattr(config, 'expert_objectives') and
1880
+ config.expert_objectives and
1881
+ expert_id is not None and
1882
+ expert_id in config.expert_objectives):
1883
+ config.expert_params["objective_type"] = config.expert_objectives[expert_id]
1884
+ else:
1885
+ # Use default objective (with fallback for old configs)
1886
+ config.expert_params["objective_type"] = getattr(config, 'default_objective', 'fm')
1887
+
1888
+ if config.expert_architecture == "unet":
1889
+ return UNetExpert(config)
1890
+ elif config.expert_architecture == "simple_cnn":
1891
+ return SimpleCNNExpert(config)
1892
+ elif config.expert_architecture == "dit":
1893
+ return DiTExpert(config)
1894
+ else:
1895
+ raise ValueError(f"Unknown expert architecture: {config.expert_architecture}")
1896
+
1897
+ def create_router(config) -> Optional[nn.Module]:
1898
+ """Factory function to create router model"""
1899
+
1900
+ if config.router_architecture == "none" or config.is_monolithic:
1901
+ return None
1902
+ elif config.router_architecture == "deterministic_timestep":
1903
+ return DeterministicTimestepRouter(config)
1904
+ elif config.router_architecture == "vit":
1905
+ return ViTRouter(config)
1906
+ elif config.router_architecture == "cnn":
1907
+ return CNNRouter(config)
1908
+ elif config.router_architecture == "dit":
1909
+ return DiTRouter(config)
1910
+ elif config.router_architecture == "adaptive_video":
1911
+ return AdaptiveVideoRouter(config)
1912
+ else:
1913
+ raise ValueError(f"Unknown router architecture: {config.router_architecture}")
src/schedules.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/schedules.py
2
+ """
3
+ Centralized noise schedule manager for diffusion models.
4
+
5
+ Supports three schedules:
6
+ 1. 'cosine': Cosine schedule (Nichol & Dhariwal 2021)
7
+ 2. 'linear_beta': Linear beta schedule (Ho et al. 2020)
8
+ 3. 'linear_interp': Linear interpolation - Flow Matching default
9
+
10
+ All schedules return (alpha_t, sigma_t) such that:
11
+ x_t = alpha_t * x_0 + sigma_t * epsilon
12
+ alpha_t^2 + sigma_t^2 = 1 (variance preserving)
13
+ """
14
+
15
+ import torch
16
+ import math
17
+ from typing import Tuple
18
+
19
+
20
+ class NoiseSchedule:
21
+ """
22
+ Centralized noise schedule manager.
23
+
24
+ Args:
25
+ schedule_type: One of ['cosine', 'linear_beta', 'linear_interp']
26
+ """
27
+
28
+ def __init__(self, schedule_type: str = 'linear_interp'):
29
+ assert schedule_type in ['cosine', 'linear_beta', 'linear_interp'], \
30
+ f"Unknown schedule: {schedule_type}. Must be one of ['cosine', 'linear_beta', 'linear_interp']"
31
+ self.schedule_type = schedule_type
32
+
33
+ # Linear beta schedule parameters (if used)
34
+ self.beta_min = 0.0001
35
+ self.beta_max = 0.02
36
+ self.num_timesteps = 1000 # T in discrete formulation
37
+
38
+ # Cosine schedule parameter
39
+ self.s = 0.008 # Small offset to prevent beta from being too small near t=0
40
+
41
+ def get_schedule(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
42
+ """
43
+ Get (alpha_t, sigma_t) for given timesteps.
44
+
45
+ Args:
46
+ t: Tensor of timesteps in [0, 1], shape (B,)
47
+
48
+ Returns:
49
+ alpha_t: Shape (B,), coefficient for x_0
50
+ sigma_t: Shape (B,), coefficient for epsilon
51
+ """
52
+ if self.schedule_type == 'cosine':
53
+ return self._cosine_schedule(t)
54
+ elif self.schedule_type == 'linear_beta':
55
+ return self._linear_beta_schedule(t)
56
+ elif self.schedule_type == 'linear_interp':
57
+ return self._linear_interpolation(t)
58
+
59
+ def _cosine_schedule(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
60
+ """
61
+ Cosine schedule: alpha_bar_t = f(t) / f(0)
62
+ where f(t) = cos²((t + s)/(1 + s) * π/2)
63
+
64
+ Reference: "Improved Denoising Diffusion Probabilistic Models"
65
+ (Nichol & Dhariwal, 2021)
66
+
67
+ This schedule provides better conditioning than linear beta schedule,
68
+ especially at very small and very large t values.
69
+ """
70
+ # Compute f(t) = cos²((t + s)/(1 + s) * π/2)
71
+ f_t = torch.cos(((t + self.s) / (1 + self.s)) * math.pi * 0.5) ** 2
72
+
73
+ # Compute f(0) for normalization to ensure alpha_bar_0 = 1
74
+ f_0 = math.cos((self.s / (1 + self.s)) * math.pi * 0.5) ** 2
75
+
76
+ # Normalize: alpha_bar_t = f(t) / f(0)
77
+ alpha_bar_t = f_t / f_0
78
+
79
+ # Clamp to ensure numerical stability
80
+ alpha_bar_t = torch.clamp(alpha_bar_t, min=1e-8, max=1.0)
81
+
82
+ # Compute coefficients
83
+ alpha_t = torch.sqrt(alpha_bar_t)
84
+ sigma_t = torch.sqrt(1 - alpha_bar_t)
85
+
86
+ return alpha_t, sigma_t
87
+
88
+ def _linear_beta_schedule(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
89
+ """
90
+ Linear beta schedule: beta_t increases linearly from beta_min to beta_max
91
+
92
+ Reference: "Denoising Diffusion Probabilistic Models" (Ho et al., 2020)
93
+
94
+ For continuous time t ∈ [0,1]:
95
+ beta(t) = beta_min + t * (beta_max - beta_min)
96
+ alpha_bar(t) = exp(-0.5 * integral_0^t beta(s) ds)
97
+ = exp(-0.5 * t * (beta_min + 0.5 * t * (beta_max - beta_min)))
98
+ """
99
+ # Compute alpha_bar(t) = exp(-0.5 * integral beta(s) ds)
100
+ # integral_0^t (beta_min + s * (beta_max - beta_min)) ds
101
+ # = beta_min * t + 0.5 * t^2 * (beta_max - beta_min)
102
+ integral_beta = self.beta_min * t + 0.5 * t * t * (self.beta_max - self.beta_min)
103
+ alpha_bar_t = torch.exp(-0.5 * integral_beta * self.num_timesteps)
104
+
105
+ # Compute coefficients
106
+ alpha_t = torch.sqrt(alpha_bar_t)
107
+ sigma_t = torch.sqrt(1 - alpha_bar_t)
108
+
109
+ return alpha_t, sigma_t
110
+
111
+ def _linear_interpolation(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
112
+ """
113
+ Linear interpolation: x_t = (1-t) * x_0 + t * epsilon
114
+
115
+ This is the default for Flow Matching but NOT a proper DDPM schedule.
116
+ This is what the current implementation uses.
117
+ """
118
+ alpha_t = 1 - t
119
+ sigma_t = t
120
+ return alpha_t, sigma_t
121
+
122
+ def get_snr(self, t: torch.Tensor) -> torch.Tensor:
123
+ """
124
+ Compute signal-to-noise ratio (SNR) = alpha_t^2 / sigma_t^2
125
+
126
+ Useful for:
127
+ 1. Time warping between different schedules
128
+ 2. Analysis and visualization
129
+
130
+ Args:
131
+ t: Tensor of timesteps in [0, 1]
132
+
133
+ Returns:
134
+ snr: Signal-to-noise ratio at each timestep
135
+ """
136
+ alpha_t, sigma_t = self.get_schedule(t)
137
+ snr = (alpha_t ** 2) / (sigma_t ** 2 + 1e-8)
138
+ return snr
139
+
140
+ def alpha_to_time(self, alpha: torch.Tensor, num_steps: int = 100) -> torch.Tensor:
141
+ """
142
+ Inverse mapping: given alpha, find t
143
+
144
+ Used for inference when you want to specify noise levels directly.
145
+ Uses binary search since schedules are monotonic.
146
+
147
+ Args:
148
+ alpha: Desired alpha values
149
+ num_steps: Number of steps for binary search
150
+
151
+ Returns:
152
+ t: Corresponding timesteps
153
+ """
154
+ device = alpha.device
155
+
156
+ # Binary search for t
157
+ t_candidates = torch.linspace(0, 1, num_steps, device=device)
158
+ alpha_candidates, _ = self.get_schedule(t_candidates)
159
+
160
+ # Find closest match
161
+ distances = torch.abs(alpha_candidates.unsqueeze(0) - alpha.unsqueeze(1))
162
+ indices = torch.argmin(distances, dim=1)
163
+ t = t_candidates[indices]
164
+
165
+ return t
166
+
src/vae_utils.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/vae_utils.py
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from diffusers import AutoencoderKL
5
+ from typing import Optional
6
+ import numpy as np
7
+
8
+ class VAEManager:
9
+ """Utility class for VAE encoding/decoding operations"""
10
+
11
+ def __init__(self, model_name: str = "stabilityai/sd-vae-ft-mse", device: str = "cuda"):
12
+ self.device = device
13
+ self.model_name = model_name
14
+ self.vae = None
15
+ self._load_vae()
16
+
17
+ def _load_vae(self):
18
+ """Load VAE model"""
19
+ print(f"Loading VAE: {self.model_name}")
20
+ self.vae = AutoencoderKL.from_pretrained(self.model_name)
21
+ self.vae = self.vae.to(self.device)
22
+ self.vae.eval()
23
+
24
+ # Freeze VAE parameters
25
+ for param in self.vae.parameters():
26
+ param.requires_grad = False
27
+
28
+ def encode(self, images: torch.Tensor) -> torch.Tensor:
29
+ """
30
+ Encode images to latent space
31
+
32
+ Args:
33
+ images: Tensor of shape [B, 3, H, W] in range [-1, 1]
34
+
35
+ Returns:
36
+ latents: Tensor of shape [B, 4, H//8, W//8]
37
+ """
38
+ with torch.no_grad():
39
+ images = images.to(self.device)
40
+ latent_dist = self.vae.encode(images).latent_dist
41
+ latents = latent_dist.sample()
42
+ latents = latents * self.vae.config.scaling_factor
43
+
44
+ return latents
45
+
46
+ def decode(self, latents: torch.Tensor, upscale_factor: Optional[float] = None,
47
+ upscale_mode: str = 'bicubic') -> torch.Tensor:
48
+ """
49
+ Decode latents to images
50
+
51
+ Args:
52
+ latents: Tensor of shape [B, 4, H, W]
53
+ upscale_factor: Optional upscaling factor (e.g., 2.0 for 2x, 1.5 for 1.5x)
54
+ If None, returns images at native resolution (H*8, W*8)
55
+ upscale_mode: Interpolation mode ('bicubic', 'bilinear', 'nearest')
56
+
57
+ Returns:
58
+ images: Tensor of shape [B, 3, H*8*upscale_factor, W*8*upscale_factor] in range [-1, 1]
59
+ """
60
+ with torch.no_grad():
61
+ latents = latents.to(self.device)
62
+ # Rescale latents
63
+ latents = latents / self.vae.config.scaling_factor
64
+ images = self.vae.decode(latents).sample
65
+
66
+ # Apply upscaling if requested
67
+ if upscale_factor is not None and upscale_factor != 1.0:
68
+ _, _, h, w = images.shape
69
+ new_h = int(h * upscale_factor)
70
+ new_w = int(w * upscale_factor)
71
+ images = F.interpolate(
72
+ images,
73
+ size=(new_h, new_w),
74
+ mode=upscale_mode,
75
+ align_corners=False if upscale_mode in ['bilinear', 'bicubic'] else None,
76
+ antialias=True if upscale_mode in ['bilinear', 'bicubic'] else False
77
+ )
78
+
79
+ return images
80
+
81
+ def decode_to_pil(self, latents: torch.Tensor, upscale_factor: Optional[float] = None,
82
+ upscale_mode: str = 'bicubic', target_size: Optional[tuple] = None):
83
+ """
84
+ Decode latents to PIL images
85
+
86
+ Args:
87
+ latents: Tensor of shape [B, 4, H, W]
88
+ upscale_factor: Optional upscaling factor (e.g., 2.0 for 2x)
89
+ upscale_mode: Interpolation mode ('bicubic', 'bilinear', 'nearest')
90
+ target_size: Optional target size as (height, width). Overrides upscale_factor if provided.
91
+
92
+ Returns:
93
+ pil_images: List of PIL images
94
+ """
95
+ from PIL import Image
96
+
97
+ # Decode to tensor
98
+ images = self.decode(latents, upscale_factor=upscale_factor, upscale_mode=upscale_mode)
99
+
100
+ # Apply target size if specified
101
+ if target_size is not None:
102
+ images = F.interpolate(
103
+ images,
104
+ size=target_size,
105
+ mode=upscale_mode,
106
+ align_corners=False if upscale_mode in ['bilinear', 'bicubic'] else None,
107
+ antialias=True if upscale_mode in ['bilinear', 'bicubic'] else False
108
+ )
109
+
110
+ # Convert to [0, 1] range
111
+ images = (images + 1.0) / 2.0
112
+ images = torch.clamp(images, 0, 1)
113
+
114
+ # Convert to PIL
115
+ pil_images = []
116
+ for i in range(images.shape[0]):
117
+ img_array = images[i].cpu().numpy().transpose(1, 2, 0)
118
+ img_array = (img_array * 255).astype(np.uint8)
119
+ pil_image = Image.fromarray(img_array)
120
+ pil_images.append(pil_image)
121
+
122
+ return pil_images
123
+
124
+ @property
125
+ def scaling_factor(self) -> float:
126
+ """Get VAE scaling factor"""
127
+ return self.vae.config.scaling_factor
128
+
129
+ @property
130
+ def latent_channels(self) -> int:
131
+ """Get number of latent channels"""
132
+ return 4 # Standard for Stable Diffusion VAE
133
+
134
+ def create_vae_manager(model_name: str = "stabilityai/sd-vae-ft-mse", device: str = "cuda") -> VAEManager:
135
+ """Factory function to create VAE manager"""
136
+ return VAEManager(model_name, device)
137
+
138
+ def save_images_from_latents(latents: torch.Tensor, save_dir: str, vae_manager: VAEManager, prefix: str = "sample"):
139
+ """
140
+ Save images from latents using VAE decoder
141
+
142
+ Args:
143
+ latents: Tensor of shape [B, 4, H, W]
144
+ save_dir: Directory to save images
145
+ vae_manager: VAE manager instance
146
+ prefix: Filename prefix
147
+ """
148
+ import os
149
+
150
+ os.makedirs(save_dir, exist_ok=True)
151
+
152
+ # Decode to PIL images
153
+ pil_images = vae_manager.decode_to_pil(latents)
154
+
155
+ # Save each image
156
+ for i, pil_image in enumerate(pil_images):
157
+ save_path = os.path.join(save_dir, f"{prefix}_{i:03d}.png")
158
+ pil_image.save(save_path)
159
+
160
+ print(f"Saved {len(pil_images)} images to {save_dir}")
161
+
162
+ def create_image_grid(latents: torch.Tensor, vae_manager: VAEManager, nrow: int = 4) -> torch.Tensor:
163
+ """
164
+ Create image grid from latents
165
+
166
+ Args:
167
+ latents: Tensor of shape [B, 4, H, W]
168
+ vae_manager: VAE manager instance
169
+ nrow: Number of images per row
170
+
171
+ Returns:
172
+ grid: Image grid tensor
173
+ """
174
+ import torchvision.utils as vutils
175
+
176
+ # Decode latents
177
+ images = vae_manager.decode(latents)
178
+
179
+ # Convert to [0, 1] range
180
+ images = (images + 1.0) / 2.0
181
+ images = torch.clamp(images, 0, 1)
182
+
183
+ # Create grid
184
+ grid = vutils.make_grid(images, nrow=nrow, padding=2)
185
+
186
+ return grid
weights/bf16/config.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf54162afaf045deefb715e9834ed60948d7494354e866e70e76ddaebe575a78
3
+ size 2908
weights/bf16/expert_0.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a069731935a6285a64e2379c554371997ff32ad1f6c956422cfb83a8086549d
3
+ size 1211979376
weights/bf16/expert_1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a5d45e5b96ce31cc3c2c9d8f903fb75c7d0b757be96212ec345ee0e78037d48
3
+ size 1211979376
weights/bf16/expert_2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9fa3505dfa75f4b82894064cc3c3b70aa6f409796dc7cda8bc14ce3572268a44
3
+ size 1211979376
weights/bf16/expert_3.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9f53a42c3690ff8e27187a6c42770c888a1ce2fca8c132e181433870a6b4797
3
+ size 1211979376
weights/bf16/expert_4.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:58a367b34eb486789f9e8709384ad45d69768ac302a896fac85bd512134cdb3b
3
+ size 1211979376
weights/bf16/expert_5.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37c9ce6fa79faa97a029de00fcdedc7e96dbc5de36deabc953ad2ee95c2ab0ad
3
+ size 1211979376
weights/bf16/expert_6.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1fb83aad644a4fe22cd661cc4bedd49c73815dfc91bf81caf6a89dc21f1f90b3
3
+ size 1211979376
weights/bf16/expert_7.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c1d37b9b495d74121080237dbed32a5042ecbd7ed8ed619519cc2946f26e199b
3
+ size 1211979376
weights/bf16/router.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff8aaa22f59e382227b3b9fe6527010a6929e8b0b7c4322213b392a0ca03a1bf
3
+ size 258286840
weights/bf16/router_config.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e951c1c39ad5401b33bb3147f62803d76303a7b7ca0e457e4cc0aaf1e585bb5
3
+ size 2744
weights/int8/expert_0.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a0942e1503de55b07393582bb231fc0c8358cb8f03b329c3e282f8c4a8b861c
3
+ size 606080694
weights/int8/expert_1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6a9ad48d7b84574a122f52ef6d619cb8c5d9f3766c1a55af5f0b5d463fbd109
3
+ size 606080672
weights/int8/expert_2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:baa7887b97f60db4532682be3701f6e9fc9a9dec1446af00ff3f1515055f888e
3
+ size 606080694
weights/int8/expert_3.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c36196779aba27cef9ea66f5775a9ba43886a7811904fb66a9b23b0095800da9
3
+ size 606080694
weights/int8/expert_4.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d255e3a6d89f0bfbff7461dff0fb27fa206a4b9e98a83be87121327d9cac56f7
3
+ size 606080694
weights/int8/expert_5.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5cabe2f9deba779a5ea14a4d2e038a9aefd0ed5a4d2cd1b7776cd10939ffb21
3
+ size 606080694
weights/int8/expert_6.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0470248ca040b321a5e72ce73f05c32c9d0cbe8515115021fb0f6065cc3599d
3
+ size 606080694
weights/int8/expert_7.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f0047faeb6e7d0e5acc0abcbdede83789a5fec6e6d93ef3c4d5903785dd4660
3
+ size 606080694
weights/int8/router.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a52cce497dd02a88804bc81669eba0ab4957dd2b3c54b8de781dabb5a8c15b2
3
+ size 256740839