BiliSakura committed on
Commit
3b91ebd
·
verified ·
1 Parent(s): 01e062f

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +7 -0
  2. README.md +199 -0
  3. edm2-img512-l-dino/demo.png +3 -0
  4. edm2-img512-l-dino/model_index.json +19 -0
  5. edm2-img512-l-dino/pipeline.py +406 -0
  6. edm2-img512-l-dino/scheduler/scheduler_config.json +11 -0
  7. edm2-img512-l-dino/unet/config.json +31 -0
  8. edm2-img512-l-dino/unet/diffusion_pytorch_model.safetensors +3 -0
  9. edm2-img512-l-dino/unet/unet_edm2.py +434 -0
  10. edm2-img512-l-dino/vae/config.json +38 -0
  11. edm2-img512-l-dino/vae/diffusion_pytorch_model.safetensors +3 -0
  12. edm2-img512-l-fid/generator_test.png +3 -0
  13. edm2-img512-l-fid/model_index.json +19 -0
  14. edm2-img512-l-fid/pipeline.py +406 -0
  15. edm2-img512-l-fid/scheduler/scheduler_config.json +11 -0
  16. edm2-img512-l-fid/unet/config.json +31 -0
  17. edm2-img512-l-fid/unet/diffusion_pytorch_model.safetensors +3 -0
  18. edm2-img512-l-fid/unet/unet_edm2.py +434 -0
  19. edm2-img512-l-fid/vae/config.json +38 -0
  20. edm2-img512-l-fid/vae/diffusion_pytorch_model.safetensors +3 -0
  21. edm2-img512-m-fid/demo.png +3 -0
  22. edm2-img512-m-fid/model_index.json +19 -0
  23. edm2-img512-m-fid/pipeline.py +406 -0
  24. edm2-img512-m-fid/scheduler/scheduler_config.json +11 -0
  25. edm2-img512-m-fid/unet/config.json +31 -0
  26. edm2-img512-m-fid/unet/diffusion_pytorch_model.safetensors +3 -0
  27. edm2-img512-m-fid/unet/unet_edm2.py +434 -0
  28. edm2-img512-m-fid/vae/config.json +38 -0
  29. edm2-img512-m-fid/vae/diffusion_pytorch_model.safetensors +3 -0
  30. edm2-img512-s-fid/demo.png +3 -0
  31. edm2-img512-s-fid/model_index.json +19 -0
  32. edm2-img512-s-fid/pipeline.py +406 -0
  33. edm2-img512-s-fid/scheduler/scheduler_config.json +11 -0
  34. edm2-img512-s-fid/unet/config.json +31 -0
  35. edm2-img512-s-fid/unet/diffusion_pytorch_model.safetensors +3 -0
  36. edm2-img512-s-fid/unet/unet_edm2.py +434 -0
  37. edm2-img512-s-fid/vae/config.json +38 -0
  38. edm2-img512-s-fid/vae/diffusion_pytorch_model.safetensors +3 -0
  39. edm2-img512-xl-fid/demo.png +3 -0
  40. edm2-img512-xl-fid/model_index.json +19 -0
  41. edm2-img512-xl-fid/pipeline.py +406 -0
  42. edm2-img512-xl-fid/scheduler/scheduler_config.json +11 -0
  43. edm2-img512-xl-fid/unet/config.json +31 -0
  44. edm2-img512-xl-fid/unet/diffusion_pytorch_model.safetensors +3 -0
  45. edm2-img512-xl-fid/unet/unet_edm2.py +434 -0
  46. edm2-img512-xl-fid/vae/config.json +38 -0
  47. edm2-img512-xl-fid/vae/diffusion_pytorch_model.safetensors +3 -0
  48. edm2-img512-xs-fid/demo.png +3 -0
  49. edm2-img512-xs-fid/model_index.json +19 -0
  50. edm2-img512-xs-fid/pipeline.py +406 -0
.gitattributes CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ edm2-img512-l-dino/demo.png filter=lfs diff=lfs merge=lfs -text
37
+ edm2-img512-l-fid/generator_test.png filter=lfs diff=lfs merge=lfs -text
38
+ edm2-img512-m-fid/demo.png filter=lfs diff=lfs merge=lfs -text
39
+ edm2-img512-s-fid/demo.png filter=lfs diff=lfs merge=lfs -text
40
+ edm2-img512-xl-fid/demo.png filter=lfs diff=lfs merge=lfs -text
41
+ edm2-img512-xs-fid/demo.png filter=lfs diff=lfs merge=lfs -text
42
+ edm2-img512-xxl-fid/demo.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: cc-by-nc-sa-4.0
3
+ library_name: diffusers
4
+ pipeline_tag: unconditional-image-generation
5
+ tags:
6
+ - diffusers
7
+ - edm2
8
+ - image-generation
9
+ - class-conditional
10
+ - imagenet
11
+ inference: true
12
+ widget:
13
+ - output:
14
+ url: edm2-img512-xxl-fid/demo.png
15
+ language:
16
+ - en
17
+ ---
18
+
19
+ # EDM2-diffusers
20
+
21
+ Diffusers-ready checkpoints for **EDM2** ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)),
22
+ converted from [NVlabs/edm2](https://github.com/NVlabs/edm2) post-hoc reconstructions.
23
+
24
+ Official source weights: `https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions/`
25
+
26
+ This root folder is a model collection that contains:
27
+
28
+ - `edm2-img512-xs-fid`
29
+ - `edm2-img512-s-fid`
30
+ - `edm2-img512-m-fid`
31
+ - `edm2-img512-l-fid`
32
+ - `edm2-img512-l-dino`
33
+ - `edm2-img512-xl-fid`
34
+ - `edm2-img512-xxl-fid`
35
+
36
+ Each subfolder is a self-contained Diffusers model repo with:
37
+
38
+ - `pipeline.py`
39
+ - `unet/unet_edm2.py`
40
+ - `scheduler/scheduler_config.json` (`EDMEulerScheduler`)
41
+ - `unet/diffusion_pytorch_model.safetensors`
42
+ - `vae/diffusion_pytorch_model.safetensors`
43
+
44
+ ## Demo
45
+
46
+ ![edm2-img512-xxl-fid demo](edm2-img512-xxl-fid/demo.png)
47
+
48
+ Class-conditional sample (ImageNet class **207**, golden retriever), EDM2-XXL at 512×512, 32 steps, guidance 1.0, seed 42.
49
+
50
+ ## Model Paths
51
+
52
+ Use paths relative to this root README:
53
+
54
+ | Model | NVlabs preset | FID | Local path |
55
+ | --- | --- | ---: | --- |
56
+ | EDM2-XS | `edm2-img512-xs-fid` | 3.53 | `./edm2-img512-xs-fid` |
57
+ | EDM2-S | `edm2-img512-s-fid` | 2.56 | `./edm2-img512-s-fid` |
58
+ | EDM2-M | `edm2-img512-m-fid` | 2.25 | `./edm2-img512-m-fid` |
59
+ | EDM2-L | `edm2-img512-l-fid` | 2.06 | `./edm2-img512-l-fid` |
60
+ | EDM2-L (DINO) | `edm2-img512-l-dino` | — | `./edm2-img512-l-dino` |
61
+ | EDM2-XL | `edm2-img512-xl-fid` | 1.96 | `./edm2-img512-xl-fid` |
62
+ | EDM2-XXL | `edm2-img512-xxl-fid` | 1.91 | `./edm2-img512-xxl-fid` |
63
+
64
+ ## Inference Demo (Diffusers)
65
+
66
+ ### 1) Load a local subfolder checkpoint
67
+
68
+ ```python
69
+ from pathlib import Path
70
+ import torch
71
+ from diffusers import DiffusionPipeline
72
+
73
+ model_dir = Path("./edm2-img512-xxl-fid") # change to any path in the table above
74
+ pipe = DiffusionPipeline.from_pretrained(
75
+ str(model_dir),
76
+ local_files_only=True,
77
+ trust_remote_code=True,
78
+ torch_dtype=torch.bfloat16,
79
+ ).to("cuda")
80
+
81
+ generator = torch.Generator(device="cuda").manual_seed(42)
82
+ image = pipe(
83
+ class_labels=207, # golden retriever (ImageNet id); omit for random class
84
+ num_inference_steps=32,
85
+ guidance_scale=1.0, # >1.0 requires a gnet/ checkpoint
86
+ generator=generator,
87
+ ).images[0]
88
+ image.save("demo.png")
89
+ ```
90
+
91
+ Official inference defaults (`generate_images.py`): `num_steps=32`, `sigma_min=0.002`,
92
+ `sigma_max=80`, `rho=7`, `guidance=1.0` (no gnet), `S_churn=0`. Heun sampling runs in
93
+ float32 internally even when UNet/VAE weights are loaded in bf16/fp16.
94
+
95
+ Guided presets require a converted `gnet/` folder and `guidance_scale` matching the
96
+ NVlabs preset.
97
+
98
+ ### 2) Convert a legacy `.pkl`
99
+
100
+ ```bash
101
+ python scripts/convert_edm2_to_diffusers.py \
102
+ --checkpoint models/BiliSakura/EDM2-diffusers/edm2-img512-xs-2147483-0.135.pkl \
103
+ --output models/BiliSakura/EDM2-diffusers
104
+ ```
105
+
106
+ Creates `edm2-img512-xs-fid/` automatically from the NVlabs preset mapping.
107
+
108
+ ## Checkpoint preset mapping
109
+
110
+ Maps NVlabs `--preset=...` names from [`generate_images.py`](https://github.com/NVlabs/edm2/blob/main/generate_images.py)
111
+ to source pickle filenames and local Diffusers directories.
112
+
113
+ ### EDM2 paper — ImageNet-512 (conditional)
114
+
115
+ | NVlabs preset | Source `.pkl` (net) | Diffusers dir | Metric |
116
+ | --- | --- | --- | --- |
117
+ | `edm2-img512-xs-fid` | `edm2-img512-xs-2147483-0.135.pkl` | `edm2-img512-xs-fid/` | FID 3.53 |
118
+ | `edm2-img512-xs-dino` | `edm2-img512-xs-2147483-0.200.pkl` | — | FD<sub>DINOv2</sub> 103.39 |
119
+ | `edm2-img512-s-fid` | `edm2-img512-s-2147483-0.130.pkl` | `edm2-img512-s-fid/` | FID 2.56 |
120
+ | `edm2-img512-s-dino` | `edm2-img512-s-2147483-0.190.pkl` | — | FD<sub>DINOv2</sub> 68.64 |
121
+ | `edm2-img512-m-fid` | `edm2-img512-m-2147483-0.100.pkl` | `edm2-img512-m-fid/` | FID 2.25 |
122
+ | `edm2-img512-m-dino` | `edm2-img512-m-2147483-0.155.pkl` | — | FD<sub>DINOv2</sub> 58.44 |
123
+ | `edm2-img512-l-fid` | `edm2-img512-l-1879048-0.085.pkl` | `edm2-img512-l-fid/` | FID 2.06 |
124
+ | `edm2-img512-l-dino` | `edm2-img512-l-1879048-0.155.pkl` | `edm2-img512-l-dino/` | FD<sub>DINOv2</sub> 52.25 |
125
+ | `edm2-img512-xl-fid` | `edm2-img512-xl-1342177-0.085.pkl` | `edm2-img512-xl-fid/` | FID 1.96 |
126
+ | `edm2-img512-xl-dino` | `edm2-img512-xl-1342177-0.155.pkl` | — | FD<sub>DINOv2</sub> 45.96 |
127
+ | `edm2-img512-xxl-fid` | `edm2-img512-xxl-0939524-0.070.pkl` | `edm2-img512-xxl-fid/` | FID 1.91 |
128
+ | `edm2-img512-xxl-dino` | `edm2-img512-xxl-0939524-0.150.pkl` | — | FD<sub>DINOv2</sub> 42.84 |
129
+
130
+ ### EDM2 paper — ImageNet-64 (conditional)
131
+
132
+ | NVlabs preset | Source `.pkl` (net) | Metric |
133
+ | --- | --- | --- |
134
+ | `edm2-img64-s-fid` | `edm2-img64-s-1073741-0.075.pkl` | FID 1.58 |
135
+ | `edm2-img64-m-fid` | `edm2-img64-m-2147483-0.060.pkl` | FID 1.43 |
136
+ | `edm2-img64-l-fid` | `edm2-img64-l-1073741-0.040.pkl` | FID 1.33 |
137
+ | `edm2-img64-xl-fid` | `edm2-img64-xl-0671088-0.040.pkl` | FID 1.33 |
138
+
139
+ ### EDM2 paper — classifier-free guidance (ImageNet-512)
140
+
141
+ Pass the `guidance_scale` listed in the table below and load the matching converted `gnet/` checkpoint alongside the main network.
142
+
143
+ | NVlabs preset | Source `.pkl` (net) | Source `.pkl` (gnet) | Guidance | Metric |
144
+ | --- | --- | --- | ---: | --- |
145
+ | `edm2-img512-xs-guid-fid` | `edm2-img512-xs-2147483-0.045.pkl` | `edm2-img512-xs-uncond-2147483-0.045.pkl` | 1.40 | FID 2.91 |
146
+ | `edm2-img512-xs-guid-dino` | `edm2-img512-xs-2147483-0.150.pkl` | `edm2-img512-xs-uncond-2147483-0.150.pkl` | 1.70 | FD<sub>DINOv2</sub> 79.94 |
147
+ | `edm2-img512-s-guid-fid` | `edm2-img512-s-2147483-0.025.pkl` | `edm2-img512-xs-uncond-2147483-0.025.pkl` | 1.40 | FID 2.23 |
148
+ | `edm2-img512-s-guid-dino` | `edm2-img512-s-2147483-0.085.pkl` | `edm2-img512-xs-uncond-2147483-0.085.pkl` | 1.90 | FD<sub>DINOv2</sub> 52.32 |
149
+ | `edm2-img512-m-guid-fid` | `edm2-img512-m-2147483-0.030.pkl` | `edm2-img512-xs-uncond-2147483-0.030.pkl` | 1.20 | FID 2.01 |
150
+ | `edm2-img512-m-guid-dino` | `edm2-img512-m-2147483-0.015.pkl` | `edm2-img512-xs-uncond-2147483-0.015.pkl` | 2.00 | FD<sub>DINOv2</sub> 41.98 |
151
+ | `edm2-img512-l-guid-fid` | `edm2-img512-l-1879048-0.015.pkl` | `edm2-img512-xs-uncond-2147483-0.015.pkl` | 1.20 | FID 1.88 |
152
+ | `edm2-img512-l-guid-dino` | `edm2-img512-l-1879048-0.035.pkl` | `edm2-img512-xs-uncond-2147483-0.035.pkl` | 1.70 | FD<sub>DINOv2</sub> 38.20 |
153
+ | `edm2-img512-xl-guid-fid` | `edm2-img512-xl-1342177-0.020.pkl` | `edm2-img512-xs-uncond-2147483-0.020.pkl` | 1.20 | FID 1.85 |
154
+ | `edm2-img512-xl-guid-dino` | `edm2-img512-xl-1342177-0.030.pkl` | `edm2-img512-xs-uncond-2147483-0.030.pkl` | 1.70 | FD<sub>DINOv2</sub> 35.67 |
155
+ | `edm2-img512-xxl-guid-fid` | `edm2-img512-xxl-0939524-0.015.pkl` | `edm2-img512-xs-uncond-2147483-0.015.pkl` | 1.20 | FID 1.81 |
156
+ | `edm2-img512-xxl-guid-dino` | `edm2-img512-xxl-0939524-0.015.pkl` | `edm2-img512-xs-uncond-2147483-0.015.pkl` | 1.70 | FD<sub>DINOv2</sub> 33.09 |
157
+
158
+ ### Autoguidance paper
159
+
160
+ | NVlabs preset | Source `.pkl` (net) | Source `.pkl` (gnet) | Guidance | Metric |
161
+ | --- | --- | --- | ---: | --- |
162
+ | `edm2-img512-s-autog-fid` | `edm2-img512-s-2147483-0.070.pkl` | `edm2-img512-xs-0134217-0.125.pkl` | 2.10 | FID 1.34 |
163
+ | `edm2-img512-s-autog-dino` | `edm2-img512-s-2147483-0.120.pkl` | `edm2-img512-xs-0134217-0.165.pkl` | 2.45 | FD<sub>DINOv2</sub> 36.67 |
164
+ | `edm2-img512-xxl-autog-fid` | `edm2-img512-xxl-0939524-0.075.pkl` | `edm2-img512-m-0268435-0.155.pkl` | 2.05 | FID 1.25 |
165
+ | `edm2-img512-xxl-autog-dino` | `edm2-img512-xxl-0939524-0.130.pkl` | `edm2-img512-m-0268435-0.205.pkl` | 2.30 | FD<sub>DINOv2</sub> 24.18 |
166
+ | `edm2-img512-s-uncond-autog-fid` | `edm2-img512-s-uncond-2147483-0.070.pkl` | `edm2-img512-xs-uncond-0134217-0.110.pkl` | 2.85 | FID 3.86 |
167
+ | `edm2-img512-s-uncond-autog-dino` | `edm2-img512-s-uncond-2147483-0.090.pkl` | `edm2-img512-xs-uncond-0134217-0.125.pkl` | 2.90 | FD<sub>DINOv2</sub> 90.39 |
168
+ | `edm2-img64-s-autog-fid` | `edm2-img64-s-1073741-0.045.pkl` | `edm2-img64-xs-0134217-0.110.pkl` | 1.70 | FID 1.01 |
169
+ | `edm2-img64-s-autog-dino` | `edm2-img64-s-1073741-0.105.pkl` | `edm2-img64-xs-0134217-0.175.pkl` | 2.20 | FD<sub>DINOv2</sub> 31.85 |
170
+
171
+ ### NVlabs preset shorthand
172
+
173
+ ```text
174
+ # EDM2 paper
175
+ edm2-img512-{xs|s|m|l|xl|xxl}-{fid|dino}
176
+ edm2-img64-{s|m|l|xl}-fid
177
+ edm2-img512-{xs|s|m|l|xl|xxl}-guid-{fid|dino}
178
+
179
+ # Autoguidance paper
180
+ edm2-img512-{s|xxl}-autog-{fid|dino}
181
+ edm2-img512-s-uncond-autog-{fid|dino}
182
+ edm2-img64-s-autog-{fid|dino}
183
+ ```
184
+
185
+ Example NVlabs command:
186
+
187
+ ```bash
188
+ python generate_images.py --preset=edm2-img512-s-guid-dino --outdir=out
189
+ ```
190
+
191
+ Equivalent expanded form:
192
+
193
+ ```bash
194
+ python generate_images.py \
195
+ --net=https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions/edm2-img512-s-2147483-0.085.pkl \
196
+ --gnet=https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions/edm2-img512-xs-uncond-2147483-0.085.pkl \
197
+ --guidance=1.9 \
198
+ --outdir=out
199
+ ```
edm2-img512-l-dino/demo.png ADDED

Git LFS Details

  • SHA256: 12a2dab2ca0e5ec5a6eebe9f7c10b440232622055866192ecc5c8b3dc289db4d
  • Pointer size: 131 Bytes
  • Size of remote file: 389 kB
edm2-img512-l-dino/model_index.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "EDM2Pipeline"
5
+ ],
6
+ "_diffusers_version": "0.31.0",
7
+ "scheduler": [
8
+ "diffusers",
9
+ "EDMEulerScheduler"
10
+ ],
11
+ "unet": [
12
+ "unet_edm2",
13
+ "EDM2UNet2DModel"
14
+ ],
15
+ "vae": [
16
+ "diffusers",
17
+ "AutoencoderKL"
18
+ ]
19
+ }
edm2-img512-l-dino/pipeline.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hub custom pipeline: EDM2Pipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ import json
22
+ import os
23
+ from pathlib import Path
24
+ from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
25
+
26
+ import numpy as np
27
+ import torch
28
+
29
+ from diffusers.image_processor import VaeImageProcessor
30
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
31
+ from diffusers.utils import replace_example_docstring
32
+ from diffusers.utils.torch_utils import randn_tensor
33
+
34
+ EXAMPLE_DOC_STRING = """
35
+ Examples:
36
+ ```py
37
+ >>> from pathlib import Path
38
+ >>> import torch
39
+ >>> from diffusers import DiffusionPipeline
40
+
41
+ >>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve()
42
+ >>> pipe = DiffusionPipeline.from_pretrained(
43
+ ... str(model_dir),
44
+ ... local_files_only=True,
45
+ ... custom_pipeline=str(model_dir / "pipeline.py"),
46
+ ... trust_remote_code=True,
47
+ ... torch_dtype=torch.float32,
48
+ ... )
49
+ >>> pipe.to("cuda")
50
+
51
+ >>> generator = torch.Generator(device="cuda").manual_seed(42)
52
+ >>> image = pipe(
53
+ ... class_labels=207,
54
+ ... num_inference_steps=32,
55
+ ... guidance_scale=1.0,
56
+ ... generator=generator,
57
+ ... ).images[0]
58
+ >>> image.save("demo.png")
59
+ ```
60
+ """
61
+
62
+ # Per-channel latent whitening constants of the default Stability VAE, as used by NVlabs/edm2 (training/encoders.py); `decode_latents` inverts them via (x - bias) / scale.
63
+ _STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28])
64
+ _STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE
65
+
66
class EDM2Pipeline(DiffusionPipeline):
    r"""
    Pipeline for class-conditional image generation with EDM2
    ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)).

    Parameters:
        unet ([`EDM2UNet2DModel`]):
            Main magnitude-preserving U-Net with EDM preconditioning.
        scheduler ([`EDMEulerScheduler`]):
            Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in
            the pipeline because the UNet returns denoised latents rather than noise predictions.
        vae ([`AutoencoderKL`], *optional*):
            Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`.
        gnet ([`EDM2UNet2DModel`], *optional*):
            Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`).
        id2label (`dict[int, str]`, *optional*):
            ImageNet class id to English label mapping.
    """

    model_cpu_offload_seq = "unet->gnet->vae"
    _optional_components = ["vae", "gnet"]

    def __init__(
        self,
        unet,
        scheduler,
        vae=None,
        gnet=None,
        id2label: Optional[Dict[Union[int, str], str]] = None,
    ) -> None:
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet)
        # Keys may arrive as strings (e.g. from JSON); store them as ints.
        self._id2label = self._normalize_id2label(id2label)
        # Reverse lookup: individual synonym -> class id.
        self.labels = self._build_label2id(self._id2label)
        # When no mapping was passed in, `_ensure_labels_loaded` tries model_index.json lazily.
        self._labels_loaded_from_model_index = bool(self._id2label)
        # Pixel models (no VAE) work directly at image resolution.
        self.vae_scale_factor = 8 if self.vae is not None else 1
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)

    @staticmethod
    def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
        """Return a copy of `id2label` with integer keys; an empty dict when nothing was given."""
        if not id2label:
            return {}
        return {int(key): value for key, value in id2label.items()}

    @staticmethod
    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
        """Split each comma-separated label string into synonyms and map every synonym to its class id.

        The result is sorted by synonym. When the same synonym appears under several ids, the id
        seen last while iterating `id2label` wins.
        """
        label2id: Dict[str, int] = {}
        for class_id, value in id2label.items():
            for synonym in value.split(","):
                synonym = synonym.strip()
                if synonym:
                    label2id[synonym] = int(class_id)
        return dict(sorted(label2id.items()))

    def _ensure_labels_loaded(self) -> None:
        """Lazily populate the label tables from `model_index.json` if they are not loaded yet."""
        if self._labels_loaded_from_model_index:
            return
        loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
        if loaded:
            self._id2label = loaded
            self.labels = self._build_label2id(self._id2label)
            self._labels_loaded_from_model_index = True

    @staticmethod
    def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
        """Read the `id2label` dict from `<variant_path>/model_index.json`.

        Returns an empty dict when the path is missing, the file does not exist, or the file has
        no dict-valued `id2label` entry.
        """
        if not variant_path:
            return {}
        model_index_path = Path(variant_path).resolve() / "model_index.json"
        if not model_index_path.is_file():
            return {}
        raw = json.loads(model_index_path.read_text(encoding="utf-8"))
        id2label = raw.get("id2label")
        if not isinstance(id2label, dict):
            return {}
        return {int(key): value for key, value in id2label.items()}

    @property
    def id2label(self) -> Dict[int, str]:
        r"""ImageNet class id to English label string (comma-separated synonyms)."""
        self._ensure_labels_loaded()
        return self._id2label

    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
        r"""
        Map ImageNet label strings to class ids.

        Args:
            label (`str` or `list[str]`):
                One or more English label strings that match entries in `id2label`.

        Returns:
            `list[int]`: One class id per requested label, in the same order.

        Raises:
            ValueError: If no labels are available or any requested label is unknown.
        """
        self._ensure_labels_loaded()
        if not self.labels:
            raise ValueError("No English labels loaded. Add `id2label` to model_index.json.")
        labels = [label] if isinstance(label, str) else list(label)
        missing = [item for item in labels if item not in self.labels]
        if missing:
            preview = ", ".join(list(self.labels.keys())[:8])
            raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
        return [self.labels[item] for item in labels]

    def _default_image_size(self) -> int:
        """Return the native output size in pixels: UNet `sample_size` times the VAE scale factor."""
        latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64)))
        return latent_size * self.vae_scale_factor

    def check_inputs(
        self,
        height: int,
        width: int,
        num_inference_steps: int,
        guidance_scale: float,
        output_type: str,
    ) -> None:
        """Validate call arguments and raise `ValueError` on the first problem found.

        Checks the step count, the guidance scale (which needs `gnet` when above 1.0), the
        output type, and that the requested size equals the model's native square resolution.
        """
        if num_inference_steps < 1:
            raise ValueError("num_inference_steps must be >= 1.")
        if guidance_scale < 1.0:
            raise ValueError("guidance_scale must be >= 1.0.")
        if guidance_scale > 1.0 and self.gnet is None:
            raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).")
        if output_type not in {"pil", "np", "pt", "latent"}:
            raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")

        native_size = self._default_image_size()
        if height != native_size or width != native_size:
            raise ValueError(
                f"EDM2 expects native resolution height=width={native_size}. "
                f"Got height={height}, width={width}."
            )

    def _normalize_class_labels(
        self,
        class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]],
        batch_size: int,
        device: torch.device,
    ) -> Optional[torch.Tensor]:
        """Convert any supported `class_labels` form into a float32 one-hot tensor.

        Args:
            class_labels: `None` (random classes), a class id, a label string, a sequence of
                either, a 1-D tensor of ids, or a 2-D tensor that is used as-is.
            batch_size: Number of samples; a single label is repeated to this size.
            device: Device on which the result is created.

        Returns:
            A `(batch_size, num_class_embeds)` float32 tensor, or `None` when the UNet reports
            `num_class_embeds == 0`.

        Raises:
            ValueError: If the number of labels does not match `batch_size`, or a class id lies
                outside `[0, num_class_embeds)`.
        """
        label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0)))
        if label_dim == 0:
            return None
        if class_labels is None:
            indices = torch.randint(label_dim, size=(batch_size,), device=device)
            return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]

        # Resolve label strings to class ids first so the rest only deals with integers/tensors.
        if isinstance(class_labels, str):
            class_labels = self.get_label_ids(class_labels)[0]
        elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str):
            class_labels = self.get_label_ids(list(class_labels))

        if isinstance(class_labels, int):
            indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long)
        elif isinstance(class_labels, torch.Tensor):
            if class_labels.ndim == 2:
                # A 2-D tensor is taken as ready-made label vectors and passed through.
                labels = class_labels.to(device=device, dtype=torch.float32)
                if labels.shape[0] != batch_size:
                    raise ValueError(f"class_labels batch must match batch_size={batch_size}.")
                return labels
            indices = class_labels.to(device=device, dtype=torch.long).flatten()
        else:
            indices = torch.tensor(list(class_labels), device=device, dtype=torch.long)

        if indices.numel() == 1 and batch_size > 1:
            indices = indices.repeat(batch_size)
        if indices.numel() != batch_size:
            raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.")
        # Fail early with a readable message: an out-of-range id would otherwise surface as an
        # opaque indexing error, and a negative id would silently wrap to a different class.
        if indices.numel() > 0 and (int(indices.min()) < 0 or int(indices.max()) >= label_dim):
            raise ValueError(
                f"class_labels must be in [0, {label_dim - 1}], "
                f"got values in [{int(indices.min())}, {int(indices.max())}]."
            )
        return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]

    def prepare_latents(
        self,
        batch_size: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: torch.device,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]],
    ) -> torch.Tensor:
        """Draw the initial Gaussian noise for sampling.

        The noise is always created in float32 regardless of `dtype`, which is accepted only for
        signature compatibility. The spatial size is the pixel size divided by the VAE scale factor.
        """
        in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4)))
        latent_height = height // self.vae_scale_factor
        latent_width = width // self.vae_scale_factor
        return randn_tensor(
            (batch_size, in_channels, latent_height, latent_width),
            generator=generator,
            device=device,
            dtype=torch.float32,
        )

    def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
        """Turn sampler output into images of the requested `output_type`.

        `"latent"` returns the input unchanged. Without a VAE the tensor is treated as pixels and
        rescaled to `[0, 1]`. With a VAE, 4-channel latents are first un-whitened with the
        Stability VAE constants, then decoded and clamped to `[0, 1]`.
        """
        if output_type == "latent":
            return latents

        if self.vae is None:
            # Same mapping as `x * 127.5 + 128` clipped to uint8 range, expressed in [0, 1].
            image = (latents.to(torch.float32) * 127.5 + 128).clip(0, 255) / 255.0
            return self.image_processor.postprocess(image, output_type=output_type)

        in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 3)))
        x = latents.to(torch.float32)
        if in_channels == 4:
            # Undo the per-channel whitening before handing the latents to the VAE decoder.
            scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
            bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
            x = (x - bias) / scale

        vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype
        image = self.vae.decode(x.to(dtype=vae_dtype)).sample.to(torch.float32).clamp(0, 1)

        return self.image_processor.postprocess(image, output_type=output_type)

    @staticmethod
    def _apply_autoguidance(
        main: torch.Tensor,
        ref: torch.Tensor,
        guidance_scale: float,
    ) -> torch.Tensor:
        """Blend guiding and main outputs: `ref + guidance_scale * (main - ref)`."""
        return ref.lerp(main, guidance_scale)

    @staticmethod
    def _sample_edm2_heun(
        denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        noise: torch.Tensor,
        sigmas: torch.Tensor,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        progress_bar: Optional[Callable[[Iterable], Iterable]] = None,
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0).

        Args:
            denoise_fn: Maps `(x, sigma)` to the denoised estimate of `x`.
            noise: Unit-variance starting noise; it is scaled by `sigmas[0]`.
            sigmas: Noise levels, visited pairwise from first to last.
            generator: Unused here (no stochastic churn); kept for signature compatibility.
            progress_bar: Optional wrapper applied to the list of steps.
            dtype: Dtype in which the sample is integrated.

        Returns:
            The sample after the final step.
        """
        x_next = noise.to(dtype) * sigmas[0]

        sigma_pairs = list(zip(sigmas[:-1], sigmas[1:]))
        # Take the count before wrapping: the wrapper is only required to return an iterable.
        num_steps = len(sigma_pairs)
        steps = sigma_pairs if progress_bar is None else progress_bar(sigma_pairs)

        for i, (sigma_cur, sigma_next) in enumerate(steps):
            x_hat, sigma_hat = x_next, sigma_cur
            # Euler step.
            d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat
            x_next = x_hat + (sigma_next - sigma_hat) * d_cur
            # Second-order correction on every step but the last, which would divide by the
            # final sigma.
            if i < num_steps - 1:
                d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next
                x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime)
        return x_next

    @torch.inference_mode()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None,
        batch_size: int = 1,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 32,
        guidance_scale: float = 1.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: str = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Generate class-conditional images with EDM2.

        Args:
            class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*):
                ImageNet class indices, English label strings, or one-hot float tensors.
                Random classes are sampled when omitted on conditional models.
            batch_size (`int`, defaults to `1`):
                Number of images to generate.
            height (`int`, *optional*):
                Output height in pixels. Defaults to the pretrained native resolution.
            width (`int`, *optional*):
                Output width in pixels. Defaults to the pretrained native resolution.
            num_inference_steps (`int`, defaults to `32`):
                Number of EDM2 Heun steps (NVlabs default).
            guidance_scale (`float`, defaults to `1.0`):
                Autoguidance strength. Values above `1.0` blend the main net with `gnet`
                via `gnet_output.lerp(unet_output, guidance_scale)`.
            generator (`torch.Generator`, *optional*):
                RNG for reproducibility.
            output_type (`str`, defaults to `"pil"`):
                `"pil"`, `"np"`, `"pt"`, or `"latent"`.
            return_dict (`bool`, defaults to `True`):
                Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if True.

        Returns:
            [`~pipelines.pipeline_utils.ImagePipelineOutput`] when `return_dict` is True, otherwise
            the tuple `(images, latents)`.

        Examples:
            <!-- this section is replaced by replace_example_docstring -->
        """
        default_size = self._default_image_size()
        height = int(height or default_size)
        width = int(width or default_size)
        self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type)

        device = self._execution_device
        dtype = self.unet.dtype
        labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
        noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)

        def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
            # The sampler passes one sigma per step; the networks take one value per sample.
            sigma_batch = sigma.reshape(1).expand(batch_size)
            main = self.unet(
                sample=x,
                sigma=sigma_batch,
                class_labels=labels,
                force_fp32=True,
            ).sample
            if guidance_scale == 1.0 or self.gnet is None:
                return main.to(torch.float32)
            ref = self.gnet(
                sample=x,
                sigma=sigma_batch,
                class_labels=labels,
                force_fp32=True,
            ).sample
            return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32)

        # The scheduler is used only for its sigma schedule; stepping happens in the pipeline.
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        latents = self._sample_edm2_heun(
            denoise_fn=denoise_fn,
            noise=noise,
            sigmas=self.scheduler.sigmas.to(device),
            generator=generator,
            progress_bar=self.progress_bar,
            dtype=torch.float32,
        )

        image = self.decode_latents(latents, output_type=output_type)
        if not return_dict:
            return (image, latents)
        return ImagePipelineOutput(images=image)

    @classmethod
    def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None):
        """Best-effort loader for the VAE that belongs to a local checkpoint folder.

        Looks for a `vae/` subfolder first. If there is none, falls back to a Hub id stored in
        `vae_pretrained_model_name_or_path.txt`.

        Returns:
            The loaded `AutoencoderKL`, or `None` when no VAE source is found or the local `vae/`
            folder cannot be loaded (a warning is emitted in that case).
        """
        import warnings

        # Imported here because this module does not import AutoencoderKL at top level; without
        # it both branches below fail with NameError.
        from diffusers import AutoencoderKL

        vae_dir = os.path.join(pretrained_model_name_or_path, "vae")
        if os.path.isdir(vae_dir):
            try:
                return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype)
            except Exception as err:  # deliberate best-effort: report instead of raising
                warnings.warn(f"Could not load VAE from '{vae_dir}': {err}")
                return None

        vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt")
        if os.path.isfile(vae_hint):
            with open(vae_hint, "r", encoding="utf-8") as f:
                hub_id = f.read().strip()
            if hub_id:
                return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype)
        return None
edm2-img512-l-dino/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "EDMEulerScheduler",
3
+ "final_sigmas_type": "zero",
4
+ "num_train_timesteps": 1000,
5
+ "prediction_type": "epsilon",
6
+ "rho": 7.0,
7
+ "sigma_data": 0.5,
8
+ "sigma_max": 80.0,
9
+ "sigma_min": 0.002,
10
+ "sigma_schedule": "karras"
11
+ }
edm2-img512-l-dino/unet/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "EDM2UNet2DModel",
3
+ "attn_balance": 0.3,
4
+ "attn_resolutions": [
5
+ 16,
6
+ 8
7
+ ],
8
+ "channel_mult": [
9
+ 1,
10
+ 2,
11
+ 3,
12
+ 4
13
+ ],
14
+ "channel_mult_emb": 4,
15
+ "channel_mult_noise": 1,
16
+ "channels_per_head": 64,
17
+ "clip_act": 256,
18
+ "concat_balance": 0.5,
19
+ "dropout": 0.0,
20
+ "in_channels": 4,
21
+ "label_balance": 0.5,
22
+ "logvar_channels": 128,
23
+ "model_channels": 320,
24
+ "num_blocks": 3,
25
+ "num_class_embeds": 1000,
26
+ "out_channels": 4,
27
+ "res_balance": 0.3,
28
+ "sample_size": 64,
29
+ "sigma_data": 0.5,
30
+ "use_fp16": true
31
+ }
edm2-img512-l-dino/unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f13f83377a74d74e1205843e241ce6d6e4bc9e49c2661944e49fdbe4d515ba33
3
+ size 3110018564
edm2-img512-l-dino/unet/unet_edm2.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import json
3
+ import os
4
+ from dataclasses import dataclass
5
+ from typing import List, Optional, Tuple
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ try:
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.models.modeling_utils import ModelMixin
13
+ from diffusers.utils import BaseOutput
14
+ except ImportError: # pragma: no cover
15
+ class ModelMixin(torch.nn.Module):
16
+ pass
17
+
18
+ class ConfigMixin:
19
+ config = {}
20
+
21
+ def register_to_config(self, **kwargs):
22
+ self.config = kwargs
23
+
24
+ def register_to_config(func):
25
+ return func
26
+
27
+ @dataclass
28
+ class BaseOutput:
29
+ pass
30
+
31
+
32
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor:
    """Rescale ``x`` to (approximately) unit root-mean-square magnitude.

    Args:
        x: Tensor to normalise.
        dim: Dimensions reduced when measuring the magnitude. ``None`` means
            every dimension except the first.
        eps: Constant added to the magnitude to avoid division by zero.

    Returns:
        ``x`` divided by its per-slice RMS magnitude, in the dtype of ``x``.
    """
    reduce_dims = list(range(1, x.ndim)) if dim is None else dim
    # The norm is always accumulated in float32, whatever the dtype of x.
    magnitude = torch.linalg.vector_norm(x, dim=reduce_dims, keepdim=True, dtype=torch.float32)
    # numel(norm) / numel(x) == 1 / (elements reduced per slice), so this
    # factor converts the L2 norm into an RMS value.
    rms_factor = math.sqrt(magnitude.numel() / x.numel())
    magnitude = torch.add(eps, magnitude, alpha=rms_factor)
    return x / magnitude.to(x.dtype)
38
+
39
+
40
def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor:
    """Resample a feature map by a factor of two with a separable FIR filter.

    Args:
        x: Input tensor; channels are read from ``x.shape[1]`` and the last two
            dimensions are convolved.
        f: 1-D filter taps. They are normalised to sum to one and expanded to a
            2-D kernel via an outer product.
        mode: ``"keep"`` returns ``x`` unchanged, ``"down"`` applies a stride-2
            depthwise convolution, ``"up"`` applies a stride-2 depthwise
            transposed convolution (kernel scaled by 4).

    Returns:
        The resampled tensor (``x`` itself for ``"keep"``).

    Raises:
        ValueError: If ``mode`` is not one of the three supported values.
            Previously any unknown mode silently fell through to upsampling.
    """
    if mode not in ("keep", "down", "up"):
        raise ValueError(f"Unknown resample mode {mode!r}; expected 'keep', 'down' or 'up'.")
    if mode == "keep":
        return x

    taps = np.asarray(f, dtype=np.float32)
    pad = (len(taps) - 1) // 2
    taps = taps / taps.sum()
    kernel = np.outer(taps, taps)[np.newaxis, np.newaxis, :, :]
    kernel = torch.as_tensor(kernel, dtype=x.dtype, device=x.device)
    channels = x.shape[1]

    # groups == channels: every channel is filtered independently with the same kernel.
    if mode == "down":
        return torch.nn.functional.conv2d(
            x, kernel.tile([channels, 1, 1, 1]), groups=channels, stride=2, padding=(pad,)
        )
    return torch.nn.functional.conv_transpose2d(
        x, (kernel * 4).tile([channels, 1, 1, 1]), groups=channels, stride=2, padding=(pad,)
    )
52
+
53
+
54
def mp_silu(x: torch.Tensor) -> torch.Tensor:
    """Apply SiLU and divide the result by the fixed constant 0.596."""
    activated = torch.nn.functional.silu(x)
    return activated / 0.596
56
+
57
+
58
def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor:
    """Blend ``a`` and ``b`` with weight ``t`` and renormalise the result.

    The linear interpolation ``(1 - t) * a + t * b`` is divided by
    ``sqrt((1 - t)**2 + t**2)``.
    """
    blended = torch.lerp(a, b, t)
    return blended / math.sqrt((1 - t) ** 2 + t ** 2)
60
+
61
+
62
def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor:
    """Concatenate ``a`` and ``b`` along ``dim`` after rescaling each input.

    Each input gets a scalar weight that depends on its size along ``dim`` and
    on the balance ``t`` (``1 - t`` for ``a``, ``t`` for ``b``).
    """
    size_a, size_b = a.shape[dim], b.shape[dim]
    overall = math.sqrt((size_a + size_b) / ((1 - t) ** 2 + t ** 2))
    weight_a = overall / math.sqrt(size_a) * (1 - t)
    weight_b = overall / math.sqrt(size_b) * t
    return torch.cat([weight_a * a, weight_b * b], dim=dim)
69
+
70
+
71
class MPFourier(torch.nn.Module):
    """Random Fourier features of a 1-D input: ``sqrt(2) * cos(x * freq + phase)``.

    Frequencies and phases are drawn once at construction time and stored as
    buffers, so they are saved with the module but are not trainable.
    """

    def __init__(self, num_channels: int, bandwidth: float = 1):
        super().__init__()
        # Draw order (randn first, then rand) is kept so seeded construction is reproducible.
        freqs = 2 * math.pi * torch.randn(num_channels) * bandwidth
        phases = 2 * math.pi * torch.rand(num_channels)
        self.register_buffer("freqs", freqs)
        self.register_buffer("phases", phases)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a 1-D tensor of length N to an ``(N, num_channels)`` feature matrix.

        The computation runs in float32; the result is cast back to ``x.dtype``.
        """
        freqs = self.freqs.to(torch.float32)
        phases = self.phases.to(torch.float32)
        angles = torch.outer(x.to(torch.float32), freqs) + phases
        features = angles.cos() * math.sqrt(2)
        return features.to(x.dtype)
82
+
83
+
84
class MPConv(torch.nn.Module):
    """Linear / 2-D convolution layer whose weight is re-normalised on every call.

    ``kernel=()`` yields a 2-D weight that is applied as a matrix product; any
    other kernel yields a weight applied with ``conv2d`` and "same"-style padding.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]):
        super().__init__()
        self.out_channels = out_channels
        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))

    def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor:
        """Apply the layer to ``x`` with the normalised weight scaled by ``gain``."""
        weight = self.weight.to(torch.float32)
        if self.training:
            # Forced weight normalisation: the stored parameter itself is
            # overwritten with its normalised value while training.
            with torch.no_grad():
                self.weight.copy_(normalize(weight))
        fan_in = weight[0].numel()
        weight = normalize(weight) * (gain / math.sqrt(fan_in))
        weight = weight.to(x.dtype)
        if weight.ndim != 2:
            return torch.nn.functional.conv2d(x, weight, padding=(weight.shape[-1] // 2,))
        return x @ weight.t()
101
+
102
+
103
class Block(torch.nn.Module):
    """Residual U-Net block with embedding modulation and optional self-attention.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        emb_channels: Width of the embedding vector passed to ``forward``.
        flavor: ``"enc"`` applies the 1x1 skip convolution and a channel-wise
            ``normalize`` before the residual branch; ``"dec"`` applies the skip
            convolution after it.
        resample_mode: ``"keep"``, ``"down"`` or ``"up"``; forwarded to ``resample``.
        resample_filter: Filter taps forwarded to ``resample``.
        attention: Enables the self-attention branch.
        channels_per_head: Channels per attention head; the head count is
            ``out_channels // channels_per_head``.
        dropout: Dropout probability, applied only in training mode.
        res_balance: Blend weight of the residual branch in ``mp_sum``.
        attn_balance: Blend weight of the attention branch in ``mp_sum``.
        clip_act: Output is clamped to ``[-clip_act, clip_act]``; ``None`` disables it.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        emb_channels: int,
        flavor: str = "enc",
        resample_mode: str = "keep",
        resample_filter: List[float] = [1, 1],
        attention: bool = False,
        channels_per_head: int = 64,
        dropout: float = 0.0,
        res_balance: float = 0.3,
        attn_balance: float = 0.3,
        clip_act: Optional[float] = 256,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.flavor = flavor
        # Store a private copy: the default list object is shared by every call
        # of __init__, so keeping a reference would let a mutation on one block
        # leak into all other blocks (and into later constructions).
        self.resample_filter = list(resample_filter)
        self.resample_mode = resample_mode
        self.num_heads = out_channels // channels_per_head if attention else 0
        self.dropout = dropout
        self.res_balance = res_balance
        self.attn_balance = attn_balance
        self.clip_act = clip_act
        self.emb_gain = torch.nn.Parameter(torch.zeros([]))
        # Encoder blocks convolve after the skip projection (out_channels);
        # decoder blocks convolve the raw input (in_channels).
        self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3))
        self.emb_linear = MPConv(emb_channels, out_channels, kernel=())
        self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3))
        self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None
        self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None
        self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """Run the block on feature map ``x`` conditioned on embedding ``emb``."""
        # Main branch.
        x = resample(x, f=self.resample_filter, mode=self.resample_mode)
        if self.flavor == "enc":
            if self.conv_skip is not None:
                x = self.conv_skip(x)
            x = normalize(x, dim=[1])

        # Residual branch; the embedding acts as a per-channel multiplicative gain (1 + ...).
        y = self.conv_res0(mp_silu(x))
        c = self.emb_linear(emb, gain=self.emb_gain) + 1
        y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype))
        if self.training and self.dropout:
            y = torch.nn.functional.dropout(y, p=self.dropout)
        y = self.conv_res1(y)

        # Connect the branches.
        if self.flavor == "dec" and self.conv_skip is not None:
            x = self.conv_skip(x)
        x = mp_sum(x, y, t=self.res_balance)

        # Self-attention over the flattened spatial positions.
        if self.num_heads:
            y = self.attn_qkv(x)
            y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3])
            q, k, v = normalize(y, dim=[2]).unbind(3)
            w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3)
            y = torch.einsum("nhqk,nhck->nhcq", w, v)
            y = self.attn_proj(y.reshape(*x.shape))
            x = mp_sum(x, y, t=self.attn_balance)

        # Clip activations (in place) to keep them within range.
        if self.clip_act is not None:
            x = x.clip_(-self.clip_act, self.clip_act)
        return x
167
+
168
+
169
class EDM2UNet(torch.nn.Module):
    """Encoder/decoder U-Net assembled from ``MPConv`` and ``Block`` layers.

    Layers live in two ``ModuleDict`` containers (``enc`` and ``dec``) whose keys
    follow the pattern ``"{res}x{res}_{role}"``. ``forward`` uses those key names
    to decide how each layer is called, so they must not be changed.
    """

    def __init__(
        self,
        img_resolution: int,
        img_channels: int,
        label_dim: int,
        model_channels: int = 192,
        channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
        channel_mult_noise: Optional[int] = None,
        channel_mult_emb: Optional[int] = None,
        num_blocks: int = 3,
        attn_resolutions: Tuple[int, ...] = (16, 8),
        label_balance: float = 0.5,
        concat_balance: float = 0.5,
        **block_kwargs,
    ):
        super().__init__()
        level_widths = [model_channels * mult for mult in channel_mult]
        if channel_mult_noise is None:
            noise_width = level_widths[0]
        else:
            noise_width = model_channels * channel_mult_noise
        if channel_mult_emb is None:
            emb_width = max(level_widths)
        else:
            emb_width = model_channels * channel_mult_emb
        self.label_balance = label_balance
        self.concat_balance = concat_balance
        self.out_gain = torch.nn.Parameter(torch.zeros([]))

        # Embedding layers.
        self.emb_fourier = MPFourier(noise_width)
        self.emb_noise = MPConv(noise_width, emb_width, kernel=())
        self.emb_label = MPConv(label_dim, emb_width, kernel=()) if label_dim else None

        # Encoder. The input gets one extra constant channel (see forward).
        self.enc = torch.nn.ModuleDict()
        width = img_channels + 1
        for level, level_width in enumerate(level_widths):
            res = img_resolution >> level
            tag = f"{res}x{res}"
            if level == 0:
                self.enc[f"{tag}_conv"] = MPConv(width, level_width, kernel=(3, 3))
                width = level_width
            else:
                self.enc[f"{tag}_down"] = Block(
                    width, width, emb_width, flavor="enc", resample_mode="down", **block_kwargs
                )
            for idx in range(num_blocks):
                self.enc[f"{tag}_block{idx}"] = Block(
                    width,
                    level_width,
                    emb_width,
                    flavor="enc",
                    attention=(res in attn_resolutions),
                    **block_kwargs,
                )
                width = level_width

        # Decoder. Every "block" layer consumes one encoder output as a skip connection.
        self.dec = torch.nn.ModuleDict()
        skip_widths = [module.out_channels for module in self.enc.values()]
        deepest = len(level_widths) - 1
        for level in reversed(range(len(level_widths))):
            level_width = level_widths[level]
            res = img_resolution >> level
            tag = f"{res}x{res}"
            if level == deepest:
                self.dec[f"{tag}_in0"] = Block(width, width, emb_width, flavor="dec", attention=True, **block_kwargs)
                self.dec[f"{tag}_in1"] = Block(width, width, emb_width, flavor="dec", **block_kwargs)
            else:
                self.dec[f"{tag}_up"] = Block(
                    width, width, emb_width, flavor="dec", resample_mode="up", **block_kwargs
                )
            for idx in range(num_blocks + 1):
                self.dec[f"{tag}_block{idx}"] = Block(
                    width + skip_widths.pop(),
                    level_width,
                    emb_width,
                    flavor="dec",
                    attention=(res in attn_resolutions),
                    **block_kwargs,
                )
                width = level_width

        self.out_conv = MPConv(width, img_channels, kernel=(3, 3))

    def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the network.

        Args:
            x: Input feature map.
            noise_labels: 1-D noise conditioning values fed to the Fourier embedding.
            class_labels: Class conditioning matrix; required when the network
                was built with a non-zero ``label_dim``.

        Raises:
            ValueError: If the network is conditional and ``class_labels`` is ``None``.
        """
        emb = self.emb_noise(self.emb_fourier(noise_labels))
        if self.emb_label is not None:
            if class_labels is None:
                raise ValueError("class_labels are required for conditional EDM2UNet.")
            label_emb = self.emb_label(class_labels * math.sqrt(class_labels.shape[1]))
            emb = mp_sum(emb, label_emb, t=self.label_balance)
        emb = mp_silu(emb)

        # Append a constant channel of ones to the input.
        x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1)
        skips = []
        for name, layer in self.enc.items():
            # The stem convolution is the only encoder layer that takes no embedding.
            x = layer(x) if "conv" in name else layer(x, emb)
            skips.append(x)

        for name, layer in self.dec.items():
            if "block" in name:
                x = mp_cat(x, skips.pop(), t=self.concat_balance)
            x = layer(x, emb)
        return self.out_conv(x, gain=self.out_gain)
261
+
262
+
263
@dataclass
class EDM2UNet2DOutput(BaseOutput):
    """Return value of ``EDM2UNet2DModel.forward`` when ``return_dict=True``."""

    # Preconditioned network output (``c_skip * x + c_out * F(x)`` in ``forward``).
    sample: torch.Tensor
    # Filled only when ``forward`` is called with ``return_logvar=True``; otherwise ``None``.
    logvar: Optional[torch.Tensor] = None
267
+
268
+
269
+
270
# Names of the ``EDM2UNet2DModel.__init__`` arguments that are persisted in
# ``config.json``. ``from_pretrained`` keeps only these keys when building the
# constructor kwargs, and ``save_pretrained`` writes exactly these keys.
_CONFIG_KEYS = (
    "sample_size",
    "in_channels",
    "out_channels",
    "num_class_embeds",
    "use_fp16",
    "sigma_data",
    "logvar_channels",
    "model_channels",
    "channel_mult",
    "channel_mult_noise",
    "channel_mult_emb",
    "num_blocks",
    "attn_resolutions",
    "label_balance",
    "concat_balance",
    "dropout",
    "channels_per_head",
    "res_balance",
    "attn_balance",
    "clip_act",
)
292
+
293
+
294
class EDM2UNet2DModel(ModelMixin, ConfigMixin):
    """EDM-preconditioned wrapper around :class:`EDM2UNet`.

    ``forward`` scales the input and output of the raw network with the
    ``c_in`` / ``c_skip`` / ``c_out`` / ``c_noise`` factors computed from
    ``sigma`` and ``sigma_data``, so the model returns a denoised sample.
    The class also provides minimal local-directory ``from_pretrained`` /
    ``save_pretrained`` implementations driven by ``_CONFIG_KEYS``.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: int = 64,
        in_channels: int = 4,
        out_channels: int = 4,
        num_class_embeds: int = 0,
        use_fp16: bool = True,
        sigma_data: float = 0.5,
        logvar_channels: int = 128,
        model_channels: int = 192,
        channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
        channel_mult_noise: Optional[int] = None,
        channel_mult_emb: Optional[int] = None,
        num_blocks: int = 3,
        attn_resolutions: Tuple[int, ...] = (16, 8),
        label_balance: float = 0.5,
        concat_balance: float = 0.5,
        dropout: float = 0.0,
        channels_per_head: int = 64,
        res_balance: float = 0.3,
        attn_balance: float = 0.3,
        clip_act: Optional[float] = 256,
    ):
        super().__init__()
        # Mirror the constructor arguments as plain attributes; save_pretrained
        # falls back to them when a key is missing from ``self.config``.
        self.sample_size = sample_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_class_embeds = num_class_embeds
        self.use_fp16 = use_fp16
        self.sigma_data = sigma_data
        self.model_channels = model_channels
        self.channel_mult = channel_mult
        self.channel_mult_noise = channel_mult_noise
        self.channel_mult_emb = channel_mult_emb
        self.num_blocks = num_blocks
        self.attn_resolutions = attn_resolutions
        self.label_balance = label_balance
        self.concat_balance = concat_balance
        self.dropout = dropout
        self.channels_per_head = channels_per_head
        self.res_balance = res_balance
        self.attn_balance = attn_balance
        self.clip_act = clip_act
        self.unet = EDM2UNet(
            img_resolution=sample_size,
            img_channels=in_channels,
            label_dim=num_class_embeds,
            model_channels=model_channels,
            channel_mult=channel_mult,
            channel_mult_noise=channel_mult_noise,
            channel_mult_emb=channel_mult_emb,
            num_blocks=num_blocks,
            attn_resolutions=attn_resolutions,
            label_balance=label_balance,
            concat_balance=concat_balance,
            dropout=dropout,
            channels_per_head=channels_per_head,
            res_balance=res_balance,
            attn_balance=attn_balance,
            clip_act=clip_act,
        )
        self.logvar_fourier = MPFourier(logvar_channels)
        self.logvar_linear = MPConv(logvar_channels, 1, kernel=())

    def forward(
        self,
        sample: torch.Tensor,
        sigma: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        force_fp32: bool = False,
        return_logvar: bool = False,
        return_dict: bool = True,
    ) -> EDM2UNet2DOutput:
        """Denoise ``sample`` at noise level ``sigma``.

        Args:
            sample: Noisy input; converted to float32.
            sigma: Noise level(s); reshaped to ``(-1, 1, 1, 1)`` for broadcasting.
            class_labels: Class conditioning, reshaped to
                ``(-1, num_class_embeds)``. Ignored for unconditional models;
                replaced by all-zeros when omitted for a conditional model.
            force_fp32: Run the raw network in float32 even when ``use_fp16`` is set.
            return_logvar: Also compute the ``logvar`` head.
            return_dict: Return an ``EDM2UNet2DOutput`` instead of a tuple.

        Returns:
            ``EDM2UNet2DOutput(sample, logvar)`` or the tuple ``(sample, logvar)``.
        """
        x = sample.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        if self.num_class_embeds == 0:
            class_labels = None
        else:
            if class_labels is None:
                class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device)
            class_labels = class_labels.to(torch.float32).reshape(-1, self.num_class_embeds)
        # Half precision is only used on CUDA and only when not explicitly disabled.
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") else torch.float32

        # Preconditioning weights.
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
        c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
        c_noise = sigma.flatten().log() / 4

        x_in = (c_in * x).to(dtype)
        f_x = self.unet(x_in, c_noise, class_labels)
        d_x = c_skip * x + c_out * f_x.to(torch.float32)

        logvar = None
        if return_logvar:
            logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)

        if not return_dict:
            return (d_x, logvar)
        return EDM2UNet2DOutput(sample=d_x, logvar=logvar)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs):
        """Load the model from a local directory.

        Reads ``config.json`` and ``diffusion_pytorch_model.safetensors`` (or the
        ``.bin`` fallback) from ``pretrained_model_name_or_path`` joined with the
        optional ``subfolder`` kwarg. All other kwargs are accepted and ignored.

        The returned model is in evaluation mode; call ``.train()`` to fine-tune.
        """
        subfolder = kwargs.pop("subfolder", None)
        model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path
        with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f:
            config = json.load(f)
        init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS}
        model = cls(**init_kwargs)
        weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors")
        if os.path.isfile(weight_file):
            from safetensors.torch import load_file

            state_dict = load_file(weight_file)
        else:
            # weights_only=True restricts unpickling to tensors and plain
            # containers, so a tampered .bin file cannot execute arbitrary code.
            state_dict = torch.load(
                os.path.join(model_dir, "diffusion_pytorch_model.bin"),
                map_location="cpu",
                weights_only=True,
            )
        model.load_state_dict(state_dict, strict=True)
        if torch_dtype is not None:
            model = model.to(dtype=torch_dtype)
        # A freshly constructed nn.Module is in training mode. Left that way,
        # every MPConv would overwrite its stored weights in place during
        # inference forwards, so switch to eval mode before handing it out.
        model.eval()
        return model

    def save_pretrained(self, save_directory: str, safe_serialization: bool = True):
        """Write ``config.json`` and the weights into ``save_directory``.

        Config values come from ``self.config`` when present and otherwise from
        the same-named instance attribute. Weights are stored as safetensors
        unless ``safe_serialization`` is ``False``.
        """
        os.makedirs(save_directory, exist_ok=True)
        stored = dict(getattr(self, "config", {}))
        config = {"_class_name": self.__class__.__name__}
        for key in _CONFIG_KEYS:
            if key in stored:
                config[key] = stored[key]
            elif hasattr(self, key):
                config[key] = getattr(self, key)
        with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2, sort_keys=True)
            f.write("\n")
        state_dict = self.state_dict()
        if safe_serialization:
            from safetensors.torch import save_file

            save_file(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.safetensors"))
        else:
            torch.save(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.bin"))
edm2-img512-l-dino/vae/config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.36.0",
4
+ "_name_or_path": "stabilityai/sd-vae-ft-mse",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 256,
9
+ 512,
10
+ 512
11
+ ],
12
+ "down_block_types": [
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D"
17
+ ],
18
+ "force_upcast": true,
19
+ "in_channels": 3,
20
+ "latent_channels": 4,
21
+ "latents_mean": null,
22
+ "latents_std": null,
23
+ "layers_per_block": 2,
24
+ "mid_block_add_attention": true,
25
+ "norm_num_groups": 32,
26
+ "out_channels": 3,
27
+ "sample_size": 256,
28
+ "scaling_factor": 0.18215,
29
+ "shift_factor": null,
30
+ "up_block_types": [
31
+ "UpDecoderBlock2D",
32
+ "UpDecoderBlock2D",
33
+ "UpDecoderBlock2D",
34
+ "UpDecoderBlock2D"
35
+ ],
36
+ "use_post_quant_conv": true,
37
+ "use_quant_conv": true
38
+ }
edm2-img512-l-dino/vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815
3
+ size 334643276
edm2-img512-l-fid/generator_test.png ADDED

Git LFS Details

  • SHA256: cabf3ca8019e86c4a85855d5c3fd2c6de6d25ac51682da208d20db23533e6578
  • Pointer size: 131 Bytes
  • Size of remote file: 379 kB
edm2-img512-l-fid/model_index.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "EDM2Pipeline"
5
+ ],
6
+ "_diffusers_version": "0.31.0",
7
+ "scheduler": [
8
+ "diffusers",
9
+ "EDMEulerScheduler"
10
+ ],
11
+ "unet": [
12
+ "unet_edm2",
13
+ "EDM2UNet2DModel"
14
+ ],
15
+ "vae": [
16
+ "diffusers",
17
+ "AutoencoderKL"
18
+ ]
19
+ }
edm2-img512-l-fid/pipeline.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hub custom pipeline: EDM2Pipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ import json
22
+ import os
23
+ from pathlib import Path
24
+ from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
25
+
26
+ import numpy as np
27
+ import torch
28
+
29
+ from diffusers.image_processor import VaeImageProcessor
30
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
31
+ from diffusers.utils import replace_example_docstring
32
+ from diffusers.utils.torch_utils import randn_tensor
33
+
34
+ EXAMPLE_DOC_STRING = """
35
+ Examples:
36
+ ```py
37
+ >>> from pathlib import Path
38
+ >>> import torch
39
+ >>> from diffusers import DiffusionPipeline
40
+
41
+ >>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve()
42
+ >>> pipe = DiffusionPipeline.from_pretrained(
43
+ ... str(model_dir),
44
+ ... local_files_only=True,
45
+ ... custom_pipeline=str(model_dir / "pipeline.py"),
46
+ ... trust_remote_code=True,
47
+ ... torch_dtype=torch.float32,
48
+ ... )
49
+ >>> pipe.to("cuda")
50
+
51
+ >>> generator = torch.Generator(device="cuda").manual_seed(42)
52
+ >>> image = pipe(
53
+ ... class_labels=207,
54
+ ... num_inference_steps=32,
55
+ ... guidance_scale=1.0,
56
+ ... generator=generator,
57
+ ... ).images[0]
58
+ >>> image.save("demo.png")
59
+ ```
60
+ """
61
+
62
# Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py).
# Per-channel affine map: whitened = raw * scale + bias; `decode_latents`
# inverts it with (x - bias) / scale before calling the VAE decoder.
# NOTE(review): the two 4-element lists are presumably the raw latent std and
# mean per channel, with 0.5 / 0.0 the target std / mean -- confirm upstream.
_STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28])
_STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE
65
+
66
+ class EDM2Pipeline(DiffusionPipeline):
67
+ r"""
68
+ Pipeline for class-conditional image generation with EDM2
69
+ ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)).
70
+
71
+ Parameters:
72
+ unet ([`EDM2UNet2DModel`]):
73
+ Main magnitude-preserving U-Net with EDM preconditioning.
74
+ scheduler ([`EDMEulerScheduler`]):
75
+ Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in
76
+ the pipeline because the UNet returns denoised latents rather than noise predictions.
77
+ vae ([`AutoencoderKL`], *optional*):
78
+ Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`.
79
+ gnet ([`EDM2UNet2DModel`], *optional*):
80
+ Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`).
81
+ id2label (`dict[int, str]`, *optional*):
82
+ ImageNet class id to English label mapping.
83
+ """
84
+
85
+ model_cpu_offload_seq = "unet->gnet->vae"
86
+ _optional_components = ["vae", "gnet"]
87
+
88
def __init__(
    self,
    unet,
    scheduler,
    vae=None,
    gnet=None,
    id2label: Optional[Dict[Union[int, str], str]] = None,
) -> None:
    """Register the pipeline components and build the label lookup tables.

    Args:
        unet: Main denoising network.
        scheduler: Scheduler that supplies the sigma schedule.
        vae: Optional latent decoder; its presence selects a scale factor of 8.
        gnet: Optional guiding network.
        id2label: Optional class-id to label-string mapping; keys may be ints or
            numeric strings.
    """
    super().__init__()
    self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet)

    id_to_name = self._normalize_id2label(id2label)
    self._id2label = id_to_name
    self.labels = self._build_label2id(id_to_name)
    # When no mapping was passed in, it is looked up lazily from model_index.json.
    self._labels_loaded_from_model_index = bool(id_to_name)

    if self.vae is None:
        self.vae_scale_factor = 1
    else:
        self.vae_scale_factor = 8
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
103
+
104
+ @staticmethod
105
+ def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
106
+ if not id2label:
107
+ return {}
108
+ return {int(key): value for key, value in id2label.items()}
109
+
110
+ @staticmethod
111
+ def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
112
+ label2id: Dict[str, int] = {}
113
+ for class_id, value in id2label.items():
114
+ for synonym in value.split(","):
115
+ synonym = synonym.strip()
116
+ if synonym:
117
+ label2id[synonym] = int(class_id)
118
+ return dict(sorted(label2id.items()))
119
+
120
+ def _ensure_labels_loaded(self) -> None:
121
+ if self._labels_loaded_from_model_index:
122
+ return
123
+ loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
124
+ if loaded:
125
+ self._id2label = loaded
126
+ self.labels = self._build_label2id(self._id2label)
127
+ self._labels_loaded_from_model_index = True
128
+
129
+ @staticmethod
130
+ def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
131
+ if not variant_path:
132
+ return {}
133
+ model_index_path = Path(variant_path).resolve() / "model_index.json"
134
+ if not model_index_path.is_file():
135
+ return {}
136
+ raw = json.loads(model_index_path.read_text(encoding="utf-8"))
137
+ id2label = raw.get("id2label")
138
+ if not isinstance(id2label, dict):
139
+ return {}
140
+ return {int(key): value for key, value in id2label.items()}
141
+
142
@property
def id2label(self) -> Dict[int, str]:
    """Class id -> label string (comma-separated synonyms), loaded on first access."""
    self._ensure_labels_loaded()
    mapping = self._id2label
    return mapping
147
+
148
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
    """Translate label strings into class ids.

    Args:
        label: One label string, or a list of label strings. Each must be a key
            of ``self.labels`` (an individual synonym from ``id2label``).

    Returns:
        The class ids in the same order as the requested labels.

    Raises:
        ValueError: If no labels are loaded, or if any requested label is unknown.
    """
    self._ensure_labels_loaded()
    known = self.labels
    if not known:
        raise ValueError("No English labels loaded. Add `id2label` to model_index.json.")

    requested = [label] if isinstance(label, str) else list(label)
    missing = [name for name in requested if name not in known]
    if missing:
        # Show a handful of valid labels so the caller can spot the expected format.
        preview = ", ".join(list(known.keys())[:8])
        raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")

    resolved: List[int] = []
    for name in requested:
        resolved.append(known[name])
    return resolved
165
+
166
+ def _default_image_size(self) -> int:
167
+ latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64)))
168
+ return latent_size * self.vae_scale_factor
169
+
170
def check_inputs(
    self,
    height: int,
    width: int,
    num_inference_steps: int,
    guidance_scale: float,
    output_type: str,
) -> None:
    """Validate the arguments of a generation call.

    Raises:
        ValueError: If the step count is below 1, the guidance scale is below
            1.0, guidance is requested without a ``gnet``, the output type is
            unsupported, or the requested size differs from the native size.
    """
    if num_inference_steps < 1:
        raise ValueError("num_inference_steps must be >= 1.")
    if guidance_scale < 1.0:
        raise ValueError("guidance_scale must be >= 1.0.")
    guidance_requested = guidance_scale > 1.0
    if guidance_requested and self.gnet is None:
        raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).")
    supported_outputs = ("pil", "np", "pt", "latent")
    if output_type not in supported_outputs:
        raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")

    # Only the model's native square resolution is accepted.
    native_size = self._default_image_size()
    if not (height == native_size and width == native_size):
        raise ValueError(
            f"EDM2 expects native resolution height=width={native_size}. "
            f"Got height={height}, width={width}."
        )
193
+
194
+ def _normalize_class_labels(
195
+ self,
196
+ class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]],
197
+ batch_size: int,
198
+ device: torch.device,
199
+ ) -> Optional[torch.Tensor]:
200
+ label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0)))
201
+ if label_dim == 0:
202
+ return None
203
+ if class_labels is None:
204
+ indices = torch.randint(label_dim, size=(batch_size,), device=device)
205
+ return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
206
+
207
+ if isinstance(class_labels, str):
208
+ class_labels = self.get_label_ids(class_labels)[0]
209
+ elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str):
210
+ class_labels = self.get_label_ids(list(class_labels))
211
+
212
+ if isinstance(class_labels, int):
213
+ indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long)
214
+ elif isinstance(class_labels, torch.Tensor):
215
+ if class_labels.ndim == 2:
216
+ labels = class_labels.to(device=device, dtype=torch.float32)
217
+ if labels.shape[0] != batch_size:
218
+ raise ValueError(f"class_labels batch must match batch_size={batch_size}.")
219
+ return labels
220
+ indices = class_labels.to(device=device, dtype=torch.long).flatten()
221
+ else:
222
+ indices = torch.tensor(list(class_labels), device=device, dtype=torch.long)
223
+
224
+ if indices.numel() == 1 and batch_size > 1:
225
+ indices = indices.repeat(batch_size)
226
+ if indices.numel() != batch_size:
227
+ raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.")
228
+ return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
229
+
230
def prepare_latents(
    self,
    batch_size: int,
    height: int,
    width: int,
    dtype: torch.dtype,
    device: torch.device,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]],
) -> torch.Tensor:
    """Draw the initial unit-variance Gaussian noise for sampling.

    Args:
        batch_size: Number of samples.
        height: Output image height in pixels.
        width: Output image width in pixels.
        dtype: Accepted for API compatibility; the noise is always generated in
            float32.
        device: Device of the returned tensor.
        generator: Optional generator(s) forwarded to ``randn_tensor``.

    Returns:
        A float32 tensor of shape
        ``(batch_size, in_channels, height // scale, width // scale)`` where
        ``scale`` is ``self.vae_scale_factor``.
    """
    in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4)))
    # Derive each latent dimension from its own argument; previously `width`
    # was ignored and the latent was always square with side `height // scale`.
    latent_height = height // self.vae_scale_factor
    latent_width = width // self.vae_scale_factor
    return randn_tensor(
        (batch_size, in_channels, latent_height, latent_width),
        generator=generator,
        device=device,
        dtype=torch.float32,
    )
247
+
248
def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
    """Turn sampler output into images of the requested ``output_type``.

    * ``output_type == "latent"``: ``latents`` is returned untouched.
    * No VAE: ``latents`` is treated as pixel data, mapped with
      ``x * 127.5 + 128``, clipped to ``[0, 255]`` and scaled to ``[0, 1]``.
    * With a VAE: 4-channel latents are first un-whitened with the module-level
      scale/bias constants, then decoded and clamped to ``[0, 1]``.

    The ``[0, 1]`` image is finally passed to ``self.image_processor.postprocess``.
    """
    if output_type == "latent":
        return latents

    in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 3)))
    values = latents.to(torch.float32)

    if self.vae is None:
        pixels = (values * 127.5 + 128).clip(0, 255) / 255.0
        return self.image_processor.postprocess(pixels, output_type=output_type)

    if in_channels == 4:
        # Undo the latent whitening: raw = (whitened - bias) / scale, per channel.
        per_channel = (1, -1, 1, 1)
        scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=values.dtype, device=values.device).reshape(*per_channel)
        bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=values.dtype, device=values.device).reshape(*per_channel)
        values = (values - bias) / scale

    vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype
    decoded = self.vae.decode(values.to(dtype=vae_dtype)).sample
    pixels = decoded.to(torch.float32).clamp(0, 1)

    return self.image_processor.postprocess(pixels, output_type=output_type)
269
+
270
+ @staticmethod
271
+ def _apply_autoguidance(
272
+ main: torch.Tensor,
273
+ ref: torch.Tensor,
274
+ guidance_scale: float,
275
+ ) -> torch.Tensor:
276
+ return ref.lerp(main, guidance_scale)
277
+
278
+ @staticmethod
279
+ def _sample_edm2_heun(
280
+ denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
281
+ noise: torch.Tensor,
282
+ sigmas: torch.Tensor,
283
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
284
+ progress_bar: Optional[Callable[[Iterable], Iterable]] = None,
285
+ dtype: torch.dtype = torch.float32,
286
+ ) -> torch.Tensor:
287
+ """NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0)."""
288
+ x_next = noise.to(dtype) * sigmas[0]
289
+
290
+ sigma_pairs = list(zip(sigmas[:-1], sigmas[1:]))
291
+ if progress_bar is not None:
292
+ sigma_pairs = progress_bar(sigma_pairs)
293
+
294
+ num_steps = len(sigma_pairs)
295
+ for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs):
296
+ x_hat, sigma_hat = x_next, sigma_cur
297
+ d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat
298
+ x_next = x_hat + (sigma_next - sigma_hat) * d_cur
299
+ if i < num_steps - 1:
300
+ d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next
301
+ x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime)
302
+ return x_next
303
+
304
@torch.inference_mode()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None,
    batch_size: int = 1,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 32,
    guidance_scale: float = 1.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    output_type: str = "pil",
    return_dict: bool = True,
) -> Union[ImagePipelineOutput, Tuple]:
    r"""
    Generate class-conditional images with EDM2.

    Args:
        class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*):
            ImageNet class indices, English label strings, or one-hot float tensors.
            Random classes are sampled when omitted on conditional models.
        batch_size (`int`, defaults to `1`):
            Number of images to generate.
        height (`int`, *optional*):
            Output height in pixels. Defaults to the pretrained native resolution.
        width (`int`, *optional*):
            Output width in pixels. Defaults to the pretrained native resolution.
        num_inference_steps (`int`, defaults to `32`):
            Number of EDM2 Heun steps (NVlabs default).
        guidance_scale (`float`, defaults to `1.0`):
            Autoguidance strength. Values above `1.0` blend the main net with `gnet`
            via `gnet_output.lerp(unet_output, guidance_scale)`.
        generator (`torch.Generator`, *optional*):
            RNG for reproducibility.
        output_type (`str`, defaults to `"pil"`):
            `"pil"`, `"np"`, `"pt"`, or `"latent"`.
        return_dict (`bool`, defaults to `True`):
            Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if True.
            Otherwise a tuple `(images, latents)` is returned.

    Examples:
        <!-- this section is replaced by replace_example_docstring -->
    """
    # Resolve the output size and validate the request up front.
    native_size = self._default_image_size()
    height = int(height or native_size)
    width = int(width or native_size)
    self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type)

    device = self._execution_device
    dtype = self.unet.dtype
    labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
    noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)

    def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        # Both networks receive identical inputs; sigma is broadcast to the batch.
        net_inputs = {
            "sample": x,
            "sigma": sigma.reshape(1).expand(batch_size),
            "class_labels": labels,
            "force_fp32": True,
        }
        main = self.unet(**net_inputs).sample
        if guidance_scale == 1.0 or self.gnet is None:
            return main.to(torch.float32)
        ref = self.gnet(**net_inputs).sample
        guided = self._apply_autoguidance(main, ref, guidance_scale)
        return guided.to(torch.float32)

    # The scheduler only supplies the sigma schedule; stepping happens here.
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    latents = self._sample_edm2_heun(
        denoise_fn=denoise_fn,
        noise=noise,
        sigmas=self.scheduler.sigmas.to(device),
        generator=generator,
        progress_bar=self.progress_bar,
        dtype=torch.float32,
    )

    image = self.decode_latents(latents, output_type=output_type)
    if return_dict:
        return ImagePipelineOutput(images=image)
    return (image, latents)
388
+
389
+ @classmethod
390
+ def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None):
391
+ vae_dir = os.path.join(pretrained_model_name_or_path, "vae")
392
+ if os.path.isdir(vae_dir):
393
+ try:
394
+
395
+ return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype)
396
+ except Exception:
397
+ return None
398
+
399
+ vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt")
400
+ if os.path.isfile(vae_hint):
401
+ with open(vae_hint, "r", encoding="utf-8") as f:
402
+ hub_id = f.read().strip()
403
+ if hub_id:
404
+
405
+ return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype)
406
+ return None
edm2-img512-l-fid/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "EDMEulerScheduler",
3
+ "final_sigmas_type": "zero",
4
+ "num_train_timesteps": 1000,
5
+ "prediction_type": "epsilon",
6
+ "rho": 7.0,
7
+ "sigma_data": 0.5,
8
+ "sigma_max": 80.0,
9
+ "sigma_min": 0.002,
10
+ "sigma_schedule": "karras"
11
+ }
edm2-img512-l-fid/unet/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "EDM2UNet2DModel",
3
+ "attn_balance": 0.3,
4
+ "attn_resolutions": [
5
+ 16,
6
+ 8
7
+ ],
8
+ "channel_mult": [
9
+ 1,
10
+ 2,
11
+ 3,
12
+ 4
13
+ ],
14
+ "channel_mult_emb": 4,
15
+ "channel_mult_noise": 1,
16
+ "channels_per_head": 64,
17
+ "clip_act": 256,
18
+ "concat_balance": 0.5,
19
+ "dropout": 0.0,
20
+ "in_channels": 4,
21
+ "label_balance": 0.5,
22
+ "logvar_channels": 128,
23
+ "model_channels": 320,
24
+ "num_blocks": 3,
25
+ "num_class_embeds": 1000,
26
+ "out_channels": 4,
27
+ "res_balance": 0.3,
28
+ "sample_size": 64,
29
+ "sigma_data": 0.5,
30
+ "use_fp16": true
31
+ }
edm2-img512-l-fid/unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a3e3f5127c12027e4796bef297e247a38ddd13bb7b8445c5d41169106b94389
3
+ size 3110018564
edm2-img512-l-fid/unet/unet_edm2.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import json
3
+ import os
4
+ from dataclasses import dataclass
5
+ from typing import List, Optional, Tuple
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
try:
    from diffusers.configuration_utils import ConfigMixin, register_to_config
    from diffusers.models.modeling_utils import ModelMixin
    from diffusers.utils import BaseOutput
except ImportError:  # pragma: no cover
    # Minimal stand-ins so this module can still be imported without diffusers.

    class ModelMixin(torch.nn.Module):
        """Fallback base class: a plain `torch.nn.Module`."""

    class ConfigMixin:
        """Fallback config holder that keeps the registered kwargs in `config`."""

        config = {}

        def register_to_config(self, **kwargs):
            self.config = kwargs

    def register_to_config(func):
        """Fallback decorator: returns the wrapped function unchanged."""
        return func

    @dataclass
    class BaseOutput:
        """Fallback empty output base class."""
30
+
31
+
32
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor:
    """Divide `x` by its (rescaled) vector norm taken over `dim`.

    When `dim` is omitted the norm is taken over every dimension except the
    first. The norm is computed in float32, multiplied by
    `sqrt(norm.numel() / x.numel())`, offset by `eps`, and cast back to the
    dtype of `x` before the division.
    """
    reduce_dims = list(range(1, x.ndim)) if dim is None else dim
    magnitude = torch.linalg.vector_norm(x, dim=reduce_dims, keepdim=True, dtype=torch.float32)
    rescale = math.sqrt(magnitude.numel() / x.numel())
    magnitude = torch.add(eps, magnitude, alpha=rescale)
    return x / magnitude.to(x.dtype)
38
+
39
+
40
def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor:
    """Down- or upsample `x` by a factor of two with a separable filter.

    Args:
        x: Input tensor; dimension 1 is used as the channel count for the
            grouped convolution.
        f: 1-D filter taps. They are normalised to sum to one and expanded to
            a 2-D kernel via an outer product.
        mode: `"keep"` returns `x` unchanged, `"down"` applies a stride-2
            convolution, `"up"` applies a stride-2 transposed convolution with
            the kernel scaled by 4.

    Raises:
        ValueError: If `mode` is not one of `"keep"`, `"down"`, `"up"`.
    """
    if mode == "keep":
        return x
    if mode not in ("down", "up"):
        # Any unrecognised mode used to fall through to the upsampling branch.
        raise ValueError(f"Unknown resample mode {mode!r}; expected 'keep', 'down' or 'up'.")

    taps = np.asarray(f, dtype=np.float32)
    pad = (len(taps) - 1) // 2
    taps = taps / taps.sum()
    kernel = np.outer(taps, taps)[np.newaxis, np.newaxis, :, :]
    kernel = torch.as_tensor(kernel, dtype=x.dtype, device=x.device)
    channels = x.shape[1]

    if mode == "down":
        return torch.nn.functional.conv2d(
            x, kernel.tile([channels, 1, 1, 1]), groups=channels, stride=2, padding=(pad,)
        )
    return torch.nn.functional.conv_transpose2d(
        x, (kernel * 4).tile([channels, 1, 1, 1]), groups=channels, stride=2, padding=(pad,)
    )
52
+
53
+
54
def mp_silu(x: torch.Tensor) -> torch.Tensor:
    """Apply SiLU and divide the result by the constant 0.596."""
    activated = torch.nn.functional.silu(x)
    return activated / 0.596
56
+
57
+
58
def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor:
    """Blend `a` and `b` with weight `t`, then divide by `sqrt((1 - t)^2 + t^2)`."""
    blended = torch.lerp(a, b, t)
    return blended / math.sqrt((1 - t) ** 2 + t ** 2)
60
+
61
+
62
def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor:
    """Concatenate `a` and `b` along `dim` after weighting each input.

    The weights depend on the blend factor `t` and on the size of each input
    along `dim`.
    """
    size_a, size_b = a.shape[dim], b.shape[dim]
    common = math.sqrt((size_a + size_b) / ((1 - t) ** 2 + t ** 2))
    weight_a = common / math.sqrt(size_a) * (1 - t)
    weight_b = common / math.sqrt(size_b) * t
    return torch.cat([weight_a * a, weight_b * b], dim=dim)
69
+
70
+
71
class MPFourier(torch.nn.Module):
    """Random Fourier features with fixed (buffer) frequencies and phases."""

    def __init__(self, num_channels: int, bandwidth: float = 1):
        super().__init__()
        # Buffers, not parameters: drawn once at construction time.
        self.register_buffer("freqs", 2 * math.pi * torch.randn(num_channels) * bandwidth)
        self.register_buffer("phases", 2 * math.pi * torch.rand(num_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return `sqrt(2) * cos(outer(x, freqs) + phases)` in the dtype of `x`.

        The computation itself is carried out in float32.
        """
        values = x.to(torch.float32)
        freqs = self.freqs.to(torch.float32)
        phases = self.phases.to(torch.float32)
        features = torch.outer(values, freqs) + phases
        features = features.cos() * math.sqrt(2)
        return features.to(x.dtype)
82
+
83
+
84
class MPConv(torch.nn.Module):
    """Linear or 2-D convolution layer whose weight is normalised on every call."""

    def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]):
        super().__init__()
        self.out_channels = out_channels
        # An empty `kernel` gives a 2-D weight, which `forward` applies as a matmul.
        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))

    def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor:
        """Apply the normalised weight to `x`, scaled by `gain / sqrt(fan_in)`."""
        weight = self.weight.to(torch.float32)
        if self.training:
            # Write the normalised weight back into the parameter while training.
            with torch.no_grad():
                self.weight.copy_(normalize(weight))
        weight = normalize(weight)
        fan_in = weight[0].numel()
        weight = (weight * (gain / math.sqrt(fan_in))).to(x.dtype)
        if weight.ndim == 2:
            return x @ weight.t()
        return torch.nn.functional.conv2d(x, weight, padding=(weight.shape[-1] // 2,))
101
+
102
+
103
class Block(torch.nn.Module):
    """Residual U-Net block with embedding modulation and optional self-attention.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        emb_channels: Width of the conditioning embedding.
        flavor: `"enc"` or `"dec"`; selects where the skip projection is applied
            and whether the input is normalised first.
        resample_mode: Passed to `resample` (`"keep"`, `"down"` or `"up"`).
        resample_filter: Filter taps passed to `resample`; `None` means `[1, 1]`.
        attention: Whether to add the self-attention branch.
        channels_per_head: Channels per attention head.
        dropout: Dropout probability, applied only in training mode.
        res_balance: Blend factor between the input and the residual branch.
        attn_balance: Blend factor between the input and the attention branch.
        clip_act: If not `None`, the output is clipped to `[-clip_act, clip_act]`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        emb_channels: int,
        flavor: str = "enc",
        resample_mode: str = "keep",
        resample_filter: Optional[List[float]] = None,
        attention: bool = False,
        channels_per_head: int = 64,
        dropout: float = 0.0,
        res_balance: float = 0.3,
        attn_balance: float = 0.3,
        clip_act: Optional[float] = 256,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.flavor = flavor
        # The default used to be a shared mutable list literal in the signature;
        # each instance now gets its own copy.
        self.resample_filter = [1, 1] if resample_filter is None else list(resample_filter)
        self.resample_mode = resample_mode
        self.num_heads = out_channels // channels_per_head if attention else 0
        self.dropout = dropout
        self.res_balance = res_balance
        self.attn_balance = attn_balance
        self.clip_act = clip_act
        self.emb_gain = torch.nn.Parameter(torch.zeros([]))
        self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3))
        self.emb_linear = MPConv(emb_channels, out_channels, kernel=())
        self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3))
        self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None
        self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None
        self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """Run the block on feature map `x` conditioned on embedding `emb`."""
        x = resample(x, f=self.resample_filter, mode=self.resample_mode)
        if self.flavor == "enc":
            # Encoder blocks project to the output width first, then normalise.
            if self.conv_skip is not None:
                x = self.conv_skip(x)
            x = normalize(x, dim=[1])

        # Residual branch, modulated by the embedding.
        y = self.conv_res0(mp_silu(x))
        c = self.emb_linear(emb, gain=self.emb_gain) + 1
        y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype))
        if self.training and self.dropout:
            y = torch.nn.functional.dropout(y, p=self.dropout)
        y = self.conv_res1(y)

        # Decoder blocks project the skip path only after the residual branch.
        if self.flavor == "dec" and self.conv_skip is not None:
            x = self.conv_skip(x)
        x = mp_sum(x, y, t=self.res_balance)

        # Self-attention over the flattened spatial positions.
        if self.num_heads:
            y = self.attn_qkv(x)
            y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3])
            q, k, v = normalize(y, dim=[2]).unbind(3)
            w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3)
            y = torch.einsum("nhqk,nhck->nhcq", w, v)
            y = self.attn_proj(y.reshape(*x.shape))
            x = mp_sum(x, y, t=self.attn_balance)

        if self.clip_act is not None:
            x = x.clip_(-self.clip_act, self.clip_act)
        return x
167
+
168
+
169
class EDM2UNet(torch.nn.Module):
    """Encoder/decoder U-Net assembled from `MPConv` layers and `Block` modules.

    Modules are stored in two `ModuleDict`s (`enc`, `dec`) keyed by
    `"{res}x{res}_{role}"`; the key names are part of the checkpoint layout.
    """

    def __init__(
        self,
        img_resolution: int,
        img_channels: int,
        label_dim: int,
        model_channels: int = 192,
        channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
        channel_mult_noise: Optional[int] = None,
        channel_mult_emb: Optional[int] = None,
        num_blocks: int = 3,
        attn_resolutions: Tuple[int, ...] = (16, 8),
        label_balance: float = 0.5,
        concat_balance: float = 0.5,
        **block_kwargs,
    ):
        super().__init__()
        level_widths = [model_channels * mult for mult in channel_mult]
        if channel_mult_noise is None:
            noise_width = level_widths[0]
        else:
            noise_width = model_channels * channel_mult_noise
        if channel_mult_emb is None:
            emb_width = max(level_widths)
        else:
            emb_width = model_channels * channel_mult_emb
        self.label_balance = label_balance
        self.concat_balance = concat_balance
        self.out_gain = torch.nn.Parameter(torch.zeros([]))

        # Noise-level and (optional) class-label embeddings.
        self.emb_fourier = MPFourier(noise_width)
        self.emb_noise = MPConv(noise_width, emb_width, kernel=())
        self.emb_label = MPConv(label_dim, emb_width, kernel=()) if label_dim else None

        # Encoder. The input gets one extra constant channel (added in forward).
        self.enc = torch.nn.ModuleDict()
        width = img_channels + 1
        for level, level_width in enumerate(level_widths):
            res = img_resolution >> level
            prefix = f"{res}x{res}"
            use_attention = res in attn_resolutions
            if level == 0:
                self.enc[f"{prefix}_conv"] = MPConv(width, level_width, kernel=(3, 3))
                width = level_width
            else:
                self.enc[f"{prefix}_down"] = Block(
                    width, width, emb_width, flavor="enc", resample_mode="down", **block_kwargs
                )
            for idx in range(num_blocks):
                self.enc[f"{prefix}_block{idx}"] = Block(
                    width,
                    level_width,
                    emb_width,
                    flavor="enc",
                    attention=use_attention,
                    **block_kwargs,
                )
                width = level_width

        # Decoder. Every "block" module consumes one encoder output as a skip.
        self.dec = torch.nn.ModuleDict()
        skip_widths = [module.out_channels for module in self.enc.values()]
        deepest_level = len(level_widths) - 1
        for level, level_width in reversed(list(enumerate(level_widths))):
            res = img_resolution >> level
            prefix = f"{res}x{res}"
            use_attention = res in attn_resolutions
            if level == deepest_level:
                self.dec[f"{prefix}_in0"] = Block(
                    width, width, emb_width, flavor="dec", attention=True, **block_kwargs
                )
                self.dec[f"{prefix}_in1"] = Block(width, width, emb_width, flavor="dec", **block_kwargs)
            else:
                self.dec[f"{prefix}_up"] = Block(
                    width, width, emb_width, flavor="dec", resample_mode="up", **block_kwargs
                )
            for idx in range(num_blocks + 1):
                self.dec[f"{prefix}_block{idx}"] = Block(
                    width + skip_widths.pop(),
                    level_width,
                    emb_width,
                    flavor="dec",
                    attention=use_attention,
                    **block_kwargs,
                )
                width = level_width

        self.out_conv = MPConv(width, img_channels, kernel=(3, 3))

    def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the U-Net.

        Raises:
            ValueError: If the network has a label embedding but `class_labels`
                is `None`.
        """
        # Conditioning embedding.
        emb = self.emb_noise(self.emb_fourier(noise_labels))
        if self.emb_label is not None:
            if class_labels is None:
                raise ValueError("class_labels are required for conditional EDM2UNet.")
            label_emb = self.emb_label(class_labels * math.sqrt(class_labels.shape[1]))
            emb = mp_sum(emb, label_emb, t=self.label_balance)
            emb = mp_silu(emb)

        # Encoder: append a constant channel of ones, keep every output as a skip.
        x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1)
        skips = []
        for name, module in self.enc.items():
            if "conv" in name:
                x = module(x)
            else:
                x = module(x, emb)
            skips.append(x)

        # Decoder: only the "block" modules take a skip connection.
        for name, module in self.dec.items():
            if "block" in name:
                x = mp_cat(x, skips.pop(), t=self.concat_balance)
            x = module(x, emb)
        return self.out_conv(x, gain=self.out_gain)
261
+
262
+
263
@dataclass
class EDM2UNet2DOutput(BaseOutput):
    """Return type of `EDM2UNet2DModel.forward`."""

    # Denoised sample computed by the model.
    sample: torch.Tensor
    # Only populated when the forward pass is asked for it; otherwise None.
    logvar: Optional[torch.Tensor] = None
267
+
268
+
269
+
270
+ _CONFIG_KEYS = (
271
+ "sample_size",
272
+ "in_channels",
273
+ "out_channels",
274
+ "num_class_embeds",
275
+ "use_fp16",
276
+ "sigma_data",
277
+ "logvar_channels",
278
+ "model_channels",
279
+ "channel_mult",
280
+ "channel_mult_noise",
281
+ "channel_mult_emb",
282
+ "num_blocks",
283
+ "attn_resolutions",
284
+ "label_balance",
285
+ "concat_balance",
286
+ "dropout",
287
+ "channels_per_head",
288
+ "res_balance",
289
+ "attn_balance",
290
+ "clip_act",
291
+ )
292
+
293
+
294
class EDM2UNet2DModel(ModelMixin, ConfigMixin):
    """`EDM2UNet` wrapped with EDM preconditioning and a log-variance head.

    `forward` scales the input and output of the inner network with
    `c_in`, `c_skip`, `c_out` and feeds it `c_noise = log(sigma) / 4`, so the
    returned `sample` is a denoised estimate rather than a noise prediction.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: int = 64,
        in_channels: int = 4,
        out_channels: int = 4,
        num_class_embeds: int = 0,
        use_fp16: bool = True,
        sigma_data: float = 0.5,
        logvar_channels: int = 128,
        model_channels: int = 192,
        channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
        channel_mult_noise: Optional[int] = None,
        channel_mult_emb: Optional[int] = None,
        num_blocks: int = 3,
        attn_resolutions: Tuple[int, ...] = (16, 8),
        label_balance: float = 0.5,
        concat_balance: float = 0.5,
        dropout: float = 0.0,
        channels_per_head: int = 64,
        res_balance: float = 0.3,
        attn_balance: float = 0.3,
        clip_act: Optional[float] = 256,
    ):
        super().__init__()
        self.sample_size = sample_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_class_embeds = num_class_embeds
        self.use_fp16 = use_fp16
        self.sigma_data = sigma_data
        # Stored like every other config key so that `save_pretrained` can
        # recover it via `getattr` when `self.config` is empty (it was the
        # only key in `_CONFIG_KEYS` that was not mirrored as an attribute).
        self.logvar_channels = logvar_channels
        self.model_channels = model_channels
        self.channel_mult = channel_mult
        self.channel_mult_noise = channel_mult_noise
        self.channel_mult_emb = channel_mult_emb
        self.num_blocks = num_blocks
        self.attn_resolutions = attn_resolutions
        self.label_balance = label_balance
        self.concat_balance = concat_balance
        self.dropout = dropout
        self.channels_per_head = channels_per_head
        self.res_balance = res_balance
        self.attn_balance = attn_balance
        self.clip_act = clip_act
        self.unet = EDM2UNet(
            img_resolution=sample_size,
            img_channels=in_channels,
            label_dim=num_class_embeds,
            model_channels=model_channels,
            channel_mult=channel_mult,
            channel_mult_noise=channel_mult_noise,
            channel_mult_emb=channel_mult_emb,
            num_blocks=num_blocks,
            attn_resolutions=attn_resolutions,
            label_balance=label_balance,
            concat_balance=concat_balance,
            dropout=dropout,
            channels_per_head=channels_per_head,
            res_balance=res_balance,
            attn_balance=attn_balance,
            clip_act=clip_act,
        )
        self.logvar_fourier = MPFourier(logvar_channels)
        self.logvar_linear = MPConv(logvar_channels, 1, kernel=())

    def forward(
        self,
        sample: torch.Tensor,
        sigma: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        force_fp32: bool = False,
        return_logvar: bool = False,
        return_dict: bool = True,
    ) -> EDM2UNet2DOutput:
        """Denoise `sample` at noise level `sigma`.

        Args:
            sample: Noisy input; cast to float32.
            sigma: Noise level(s); reshaped to `(-1, 1, 1, 1)`.
            class_labels: Label tensor reshaped to `(-1, num_class_embeds)`.
                Ignored when the model is unconditional; replaced by zeros
                when omitted on a conditional model.
            force_fp32: Run the inner network in float32 even if `use_fp16` is set.
            return_logvar: Also compute the log-variance head.
            return_dict: Return an `EDM2UNet2DOutput` instead of a tuple.

        Returns:
            `EDM2UNet2DOutput(sample, logvar)`, or the tuple `(sample, logvar)`
            when `return_dict` is False. `logvar` is `None` unless requested.
        """
        x = sample.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        if self.num_class_embeds == 0:
            class_labels = None
        else:
            if class_labels is None:
                class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device)
            class_labels = class_labels.to(torch.float32).reshape(-1, self.num_class_embeds)
        # Half precision is only used on CUDA and only when not overridden.
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") else torch.float32

        # Preconditioning coefficients.
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
        c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
        c_noise = sigma.flatten().log() / 4

        x_in = (c_in * x).to(dtype)
        f_x = self.unet(x_in, c_noise, class_labels)
        d_x = c_skip * x + c_out * f_x.to(torch.float32)

        logvar = None
        if return_logvar:
            logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)

        if not return_dict:
            return (d_x, logvar)
        return EDM2UNet2DOutput(sample=d_x, logvar=logvar)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs):
        """Build the model from a local directory holding `config.json` and weights.

        Args:
            pretrained_model_name_or_path: Local directory path.
            torch_dtype: If given, the loaded model is cast to this dtype.
            **kwargs: Only `subfolder` is used; any other keyword is ignored.

        Returns:
            The model with weights loaded strictly from
            `diffusion_pytorch_model.safetensors` when present, otherwise from
            `diffusion_pytorch_model.bin`.
        """
        subfolder = kwargs.pop("subfolder", None)
        model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path
        with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f:
            config = json.load(f)
        # Drop bookkeeping entries such as "_class_name".
        init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS}
        model = cls(**init_kwargs)
        weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors")
        if os.path.isfile(weight_file):
            from safetensors.torch import load_file

            state_dict = load_file(weight_file)
        else:
            # NOTE(review): torch.load unpickles the file; only use this
            # fallback with checkpoints from a trusted source.
            state_dict = torch.load(os.path.join(model_dir, "diffusion_pytorch_model.bin"), map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        if torch_dtype is not None:
            model = model.to(dtype=torch_dtype)
        return model

    def save_pretrained(self, save_directory: str, safe_serialization: bool = True):
        """Write `config.json` and the weights into `save_directory`.

        Each key of `_CONFIG_KEYS` is taken from `self.config` when present,
        otherwise from the attribute of the same name.
        """
        os.makedirs(save_directory, exist_ok=True)
        stored = dict(getattr(self, "config", {}))
        config = {"_class_name": self.__class__.__name__}
        for key in _CONFIG_KEYS:
            if key in stored:
                config[key] = stored[key]
            elif hasattr(self, key):
                config[key] = getattr(self, key)
        with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2, sort_keys=True)
            f.write("\n")
        state_dict = self.state_dict()
        if safe_serialization:
            from safetensors.torch import save_file

            save_file(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.safetensors"))
        else:
            torch.save(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.bin"))
edm2-img512-l-fid/vae/config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.36.0",
4
+ "_name_or_path": "stabilityai/sd-vae-ft-mse",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 256,
9
+ 512,
10
+ 512
11
+ ],
12
+ "down_block_types": [
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D"
17
+ ],
18
+ "force_upcast": true,
19
+ "in_channels": 3,
20
+ "latent_channels": 4,
21
+ "latents_mean": null,
22
+ "latents_std": null,
23
+ "layers_per_block": 2,
24
+ "mid_block_add_attention": true,
25
+ "norm_num_groups": 32,
26
+ "out_channels": 3,
27
+ "sample_size": 256,
28
+ "scaling_factor": 0.18215,
29
+ "shift_factor": null,
30
+ "up_block_types": [
31
+ "UpDecoderBlock2D",
32
+ "UpDecoderBlock2D",
33
+ "UpDecoderBlock2D",
34
+ "UpDecoderBlock2D"
35
+ ],
36
+ "use_post_quant_conv": true,
37
+ "use_quant_conv": true
38
+ }
edm2-img512-l-fid/vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815
3
+ size 334643276
edm2-img512-m-fid/demo.png ADDED

Git LFS Details

  • SHA256: bda2cb48c7ab17b37fbfa0599c7fec6d1f8d7de6848990f870e8ff4b613c929d
  • Pointer size: 131 Bytes
  • Size of remote file: 370 kB
edm2-img512-m-fid/model_index.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "EDM2Pipeline"
5
+ ],
6
+ "_diffusers_version": "0.31.0",
7
+ "scheduler": [
8
+ "diffusers",
9
+ "EDMEulerScheduler"
10
+ ],
11
+ "unet": [
12
+ "unet_edm2",
13
+ "EDM2UNet2DModel"
14
+ ],
15
+ "vae": [
16
+ "diffusers",
17
+ "AutoencoderKL"
18
+ ]
19
+ }
edm2-img512-m-fid/pipeline.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hub custom pipeline: EDM2Pipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ import json
22
+ import os
23
+ from pathlib import Path
24
+ from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
25
+
26
+ import numpy as np
27
+ import torch
28
+
29
+ from diffusers.image_processor import VaeImageProcessor
30
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
31
+ from diffusers.utils import replace_example_docstring
32
+ from diffusers.utils.torch_utils import randn_tensor
33
+
34
+ EXAMPLE_DOC_STRING = """
35
+ Examples:
36
+ ```py
37
+ >>> from pathlib import Path
38
+ >>> import torch
39
+ >>> from diffusers import DiffusionPipeline
40
+
41
+ >>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve()
42
+ >>> pipe = DiffusionPipeline.from_pretrained(
43
+ ... str(model_dir),
44
+ ... local_files_only=True,
45
+ ... custom_pipeline=str(model_dir / "pipeline.py"),
46
+ ... trust_remote_code=True,
47
+ ... torch_dtype=torch.float32,
48
+ ... )
49
+ >>> pipe.to("cuda")
50
+
51
+ >>> generator = torch.Generator(device="cuda").manual_seed(42)
52
+ >>> image = pipe(
53
+ ... class_labels=207,
54
+ ... num_inference_steps=32,
55
+ ... guidance_scale=1.0,
56
+ ... generator=generator,
57
+ ... ).images[0]
58
+ >>> image.save("demo.png")
59
+ ```
60
+ """
61
+
62
+ # Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py).
63
+ _STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28])
64
+ _STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE
65
+
66
+ class EDM2Pipeline(DiffusionPipeline):
67
+ r"""
68
+ Pipeline for class-conditional image generation with EDM2
69
+ ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)).
70
+
71
+ Parameters:
72
+ unet ([`EDM2UNet2DModel`]):
73
+ Main magnitude-preserving U-Net with EDM preconditioning.
74
+ scheduler ([`EDMEulerScheduler`]):
75
+ Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in
76
+ the pipeline because the UNet returns denoised latents rather than noise predictions.
77
+ vae ([`AutoencoderKL`], *optional*):
78
+ Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`.
79
+ gnet ([`EDM2UNet2DModel`], *optional*):
80
+ Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`).
81
+ id2label (`dict[int, str]`, *optional*):
82
+ ImageNet class id to English label mapping.
83
+ """
84
+
85
+ model_cpu_offload_seq = "unet->gnet->vae"
86
+ _optional_components = ["vae", "gnet"]
87
+
88
def __init__(
    self,
    unet,
    scheduler,
    vae=None,
    gnet=None,
    id2label: Optional[Dict[Union[int, str], str]] = None,
) -> None:
    """Register the pipeline components and set up label lookup tables.

    `vae_scale_factor` is 8 when a VAE is present and 1 otherwise; the image
    processor is created with normalisation disabled.
    """
    super().__init__()
    self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet)

    # Label tables; when empty they may be filled lazily later on.
    id_to_label = self._normalize_id2label(id2label)
    self._id2label = id_to_label
    self.labels = self._build_label2id(id_to_label)
    self._labels_loaded_from_model_index = bool(id_to_label)

    self.vae_scale_factor = 1 if self.vae is None else 8
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
103
+
104
+ @staticmethod
105
+ def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
106
+ if not id2label:
107
+ return {}
108
+ return {int(key): value for key, value in id2label.items()}
109
+
110
+ @staticmethod
111
+ def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
112
+ label2id: Dict[str, int] = {}
113
+ for class_id, value in id2label.items():
114
+ for synonym in value.split(","):
115
+ synonym = synonym.strip()
116
+ if synonym:
117
+ label2id[synonym] = int(class_id)
118
+ return dict(sorted(label2id.items()))
119
+
120
+ def _ensure_labels_loaded(self) -> None:
121
+ if self._labels_loaded_from_model_index:
122
+ return
123
+ loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
124
+ if loaded:
125
+ self._id2label = loaded
126
+ self.labels = self._build_label2id(self._id2label)
127
+ self._labels_loaded_from_model_index = True
128
+
129
+ @staticmethod
130
+ def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
131
+ if not variant_path:
132
+ return {}
133
+ model_index_path = Path(variant_path).resolve() / "model_index.json"
134
+ if not model_index_path.is_file():
135
+ return {}
136
+ raw = json.loads(model_index_path.read_text(encoding="utf-8"))
137
+ id2label = raw.get("id2label")
138
+ if not isinstance(id2label, dict):
139
+ return {}
140
+ return {int(key): value for key, value in id2label.items()}
141
+
142
@property
def id2label(self) -> Dict[int, str]:
    r"""Mapping from class id to its label string (comma-separated synonyms).

    Accessing the property triggers the lazy label load first.
    """
    self._ensure_labels_loaded()
    mapping = self._id2label
    return mapping
147
+
148
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
    r"""
    Translate label strings into class ids.

    Args:
        label (`str` or `list[str]`):
            One or more label strings; each must be a key of `self.labels`.

    Returns:
        `list[int]`: One class id per requested label, in the same order.

    Raises:
        ValueError: If no labels are loaded or any requested label is unknown.
    """
    self._ensure_labels_loaded()
    known = self.labels
    if not known:
        raise ValueError("No English labels loaded. Add `id2label` to model_index.json.")

    requested = [label] if isinstance(label, str) else list(label)
    resolved: List[int] = []
    missing = []
    for name in requested:
        if name in known:
            resolved.append(known[name])
        else:
            missing.append(name)

    if missing:
        # Show a handful of valid names so the caller can spot a typo.
        preview = ", ".join(list(known)[:8])
        raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
    return resolved
165
+
166
+ def _default_image_size(self) -> int:
167
+ latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64)))
168
+ return latent_size * self.vae_scale_factor
169
+
170
def check_inputs(
    self,
    height: int,
    width: int,
    num_inference_steps: int,
    guidance_scale: float,
    output_type: str,
) -> None:
    """Validate the sampling arguments and raise ``ValueError`` on the first problem.

    Checks, in order: step count, guidance scale lower bound, availability of
    the guiding network when guidance is requested, the output type, and that
    the requested resolution equals the native resolution.
    """
    if num_inference_steps < 1:
        raise ValueError("num_inference_steps must be >= 1.")

    if guidance_scale < 1.0:
        raise ValueError("guidance_scale must be >= 1.0.")

    guidance_requested = guidance_scale > 1.0
    if guidance_requested and self.gnet is None:
        raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).")

    allowed_outputs = {"pil", "np", "pt", "latent"}
    if output_type not in allowed_outputs:
        raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")

    native_size = self._default_image_size()
    is_native = height == native_size and width == native_size
    if not is_native:
        raise ValueError(
            f"EDM2 expects native resolution height=width={native_size}. "
            f"Got height={height}, width={width}."
        )
193
+
194
+ def _normalize_class_labels(
195
+ self,
196
+ class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]],
197
+ batch_size: int,
198
+ device: torch.device,
199
+ ) -> Optional[torch.Tensor]:
200
+ label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0)))
201
+ if label_dim == 0:
202
+ return None
203
+ if class_labels is None:
204
+ indices = torch.randint(label_dim, size=(batch_size,), device=device)
205
+ return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
206
+
207
+ if isinstance(class_labels, str):
208
+ class_labels = self.get_label_ids(class_labels)[0]
209
+ elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str):
210
+ class_labels = self.get_label_ids(list(class_labels))
211
+
212
+ if isinstance(class_labels, int):
213
+ indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long)
214
+ elif isinstance(class_labels, torch.Tensor):
215
+ if class_labels.ndim == 2:
216
+ labels = class_labels.to(device=device, dtype=torch.float32)
217
+ if labels.shape[0] != batch_size:
218
+ raise ValueError(f"class_labels batch must match batch_size={batch_size}.")
219
+ return labels
220
+ indices = class_labels.to(device=device, dtype=torch.long).flatten()
221
+ else:
222
+ indices = torch.tensor(list(class_labels), device=device, dtype=torch.long)
223
+
224
+ if indices.numel() == 1 and batch_size > 1:
225
+ indices = indices.repeat(batch_size)
226
+ if indices.numel() != batch_size:
227
+ raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.")
228
+ return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
229
+
230
def prepare_latents(
    self,
    batch_size: int,
    height: int,
    width: int,
    dtype: torch.dtype,
    device: torch.device,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]],
) -> torch.Tensor:
    """Draw the initial unit-variance noise in latent space.

    Args:
        batch_size: Number of samples.
        height: Output height in pixels; divided by the VAE scale factor.
        width: Output width in pixels; divided by the VAE scale factor.
        dtype: Accepted for interface compatibility. The noise is always
            created in float32, matching the float32 sampler loop.
        device: Device for the noise tensor.
        generator: Optional RNG(s) forwarded to ``randn_tensor``.

    Returns:
        A float32 tensor of shape ``(batch_size, in_channels, H, W)``.
    """
    in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4)))
    latent_height = height // self.vae_scale_factor
    # Use the requested width rather than reusing the height; the two are equal
    # whenever check_inputs has been applied.
    latent_width = width // self.vae_scale_factor
    return randn_tensor(
        (batch_size, in_channels, latent_height, latent_width),
        generator=generator,
        device=device,
        dtype=torch.float32,
    )
247
+
248
def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
    """Turn sampled latents into images of the requested output type.

    ``"latent"`` returns the input untouched. Without a VAE the latents are
    mapped with ``x * 127.5 + 128``, clipped to ``[0, 255]`` and rescaled to
    ``[0, 1]``. With a VAE, 4-channel latents are first de-standardised using
    the module-level scale/bias constants, then decoded and clamped to
    ``[0, 1]``. The result is passed through the image processor.
    """
    if output_type == "latent":
        return latents

    config_channels = getattr(self.unet.config, "in_channels", 3)
    in_channels = int(getattr(self.unet, "in_channels", config_channels))
    latents_fp32 = latents.to(torch.float32)

    if self.vae is None:
        pixels = (latents_fp32 * 127.5 + 128).clip(0, 255) / 255.0
        return self.image_processor.postprocess(pixels, output_type=output_type)

    if in_channels == 4:
        # Undo the per-channel standardisation before handing latents to the VAE.
        scale = torch.as_tensor(
            _STABILITY_VAE_SCALE, dtype=latents_fp32.dtype, device=latents_fp32.device
        ).reshape(1, -1, 1, 1)
        bias = torch.as_tensor(
            _STABILITY_VAE_BIAS, dtype=latents_fp32.dtype, device=latents_fp32.device
        ).reshape(1, -1, 1, 1)
        latents_fp32 = (latents_fp32 - bias) / scale

    vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype
    decoded = self.vae.decode(latents_fp32.to(dtype=vae_dtype)).sample
    image = decoded.to(torch.float32).clamp(0, 1)
    return self.image_processor.postprocess(image, output_type=output_type)
269
+
270
+ @staticmethod
271
+ def _apply_autoguidance(
272
+ main: torch.Tensor,
273
+ ref: torch.Tensor,
274
+ guidance_scale: float,
275
+ ) -> torch.Tensor:
276
+ return ref.lerp(main, guidance_scale)
277
+
278
+ @staticmethod
279
+ def _sample_edm2_heun(
280
+ denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
281
+ noise: torch.Tensor,
282
+ sigmas: torch.Tensor,
283
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
284
+ progress_bar: Optional[Callable[[Iterable], Iterable]] = None,
285
+ dtype: torch.dtype = torch.float32,
286
+ ) -> torch.Tensor:
287
+ """NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0)."""
288
+ x_next = noise.to(dtype) * sigmas[0]
289
+
290
+ sigma_pairs = list(zip(sigmas[:-1], sigmas[1:]))
291
+ if progress_bar is not None:
292
+ sigma_pairs = progress_bar(sigma_pairs)
293
+
294
+ num_steps = len(sigma_pairs)
295
+ for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs):
296
+ x_hat, sigma_hat = x_next, sigma_cur
297
+ d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat
298
+ x_next = x_hat + (sigma_next - sigma_hat) * d_cur
299
+ if i < num_steps - 1:
300
+ d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next
301
+ x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime)
302
+ return x_next
303
+
304
@torch.inference_mode()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None,
    batch_size: int = 1,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 32,
    guidance_scale: float = 1.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    output_type: str = "pil",
    return_dict: bool = True,
) -> Union[ImagePipelineOutput, Tuple]:
    r"""
    Sample class-conditional images with EDM2.

    Args:
        class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*):
            Class indices, English label strings, or a one-hot float tensor.
            When omitted on a conditional model, classes are drawn at random.
        batch_size (`int`, defaults to `1`):
            How many images to generate.
        height (`int`, *optional*):
            Output height in pixels; falls back to the native resolution.
        width (`int`, *optional*):
            Output width in pixels; falls back to the native resolution.
        num_inference_steps (`int`, defaults to `32`):
            Number of Heun sampler steps.
        guidance_scale (`float`, defaults to `1.0`):
            Autoguidance strength. Above `1.0` the result is
            `gnet_output.lerp(unet_output, guidance_scale)`.
        generator (`torch.Generator`, *optional*):
            RNG used for the initial noise.
        output_type (`str`, defaults to `"pil"`):
            One of `"pil"`, `"np"`, `"pt"`, `"latent"`.
        return_dict (`bool`, defaults to `True`):
            If True return an [`~pipelines.pipeline_utils.ImagePipelineOutput`];
            otherwise return the tuple `(image, latents)`.

    Examples:
        <!-- this section is replaced by replace_example_docstring -->
    """
    native_size = self._default_image_size()
    height = int(height or native_size)
    width = int(width or native_size)
    self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type)

    device = self._execution_device
    unet_dtype = self.unet.dtype
    one_hot = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
    initial_noise = self.prepare_latents(batch_size, height, width, unet_dtype, device, generator)

    def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        # Both networks receive the same inputs and are forced to float32.
        net_inputs = dict(
            sample=x,
            sigma=sigma.reshape(1).expand(batch_size),
            class_labels=one_hot,
            force_fp32=True,
        )
        main = self.unet(**net_inputs).sample
        if guidance_scale != 1.0 and self.gnet is not None:
            ref = self.gnet(**net_inputs).sample
            return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32)
        return main.to(torch.float32)

    self.scheduler.set_timesteps(num_inference_steps, device=device)
    sigma_schedule = self.scheduler.sigmas.to(device)
    latents = self._sample_edm2_heun(
        denoise_fn=denoise_fn,
        noise=initial_noise,
        sigmas=sigma_schedule,
        generator=generator,
        progress_bar=self.progress_bar,
        dtype=torch.float32,
    )

    image = self.decode_latents(latents, output_type=output_type)
    if return_dict:
        return ImagePipelineOutput(images=image)
    return (image, latents)
388
+
389
@classmethod
def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None):
    """Best-effort loader for the VAE that accompanies a local checkpoint.

    Looks for a ``vae`` sub-directory first. If there is none, a
    ``vae_pretrained_model_name_or_path.txt`` hint file naming a model id is
    consulted.

    Args:
        pretrained_model_name_or_path: Local checkpoint directory.
        torch_dtype: Optional dtype forwarded to ``AutoencoderKL.from_pretrained``.

    Returns:
        The loaded ``AutoencoderKL``, or ``None`` if no VAE is available or the
        bundled one could not be loaded.
    """
    import warnings

    vae_dir = os.path.join(pretrained_model_name_or_path, "vae")
    if os.path.isdir(vae_dir):
        try:
            return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype)
        except Exception as err:
            # Stay best-effort (the caller gets None), but say why the bundled
            # VAE was skipped instead of failing silently.
            warnings.warn(
                f"Failed to load VAE from '{vae_dir}': {err!r}. Continuing without a VAE.",
                stacklevel=2,
            )
            return None

    vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt")
    if os.path.isfile(vae_hint):
        with open(vae_hint, "r", encoding="utf-8") as f:
            hub_id = f.read().strip()
        if hub_id:
            return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype)
    return None
edm2-img512-m-fid/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "EDMEulerScheduler",
3
+ "final_sigmas_type": "zero",
4
+ "num_train_timesteps": 1000,
5
+ "prediction_type": "epsilon",
6
+ "rho": 7.0,
7
+ "sigma_data": 0.5,
8
+ "sigma_max": 80.0,
9
+ "sigma_min": 0.002,
10
+ "sigma_schedule": "karras"
11
+ }
edm2-img512-m-fid/unet/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "EDM2UNet2DModel",
3
+ "attn_balance": 0.3,
4
+ "attn_resolutions": [
5
+ 16,
6
+ 8
7
+ ],
8
+ "channel_mult": [
9
+ 1,
10
+ 2,
11
+ 3,
12
+ 4
13
+ ],
14
+ "channel_mult_emb": 4,
15
+ "channel_mult_noise": 1,
16
+ "channels_per_head": 64,
17
+ "clip_act": 256,
18
+ "concat_balance": 0.5,
19
+ "dropout": 0.0,
20
+ "in_channels": 4,
21
+ "label_balance": 0.5,
22
+ "logvar_channels": 128,
23
+ "model_channels": 256,
24
+ "num_blocks": 3,
25
+ "num_class_embeds": 1000,
26
+ "out_channels": 4,
27
+ "res_balance": 0.3,
28
+ "sample_size": 64,
29
+ "sigma_data": 0.5,
30
+ "use_fp16": true
31
+ }
edm2-img512-m-fid/unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4733c8b2d2823cd6ce7a67e2b89b0e9b94d50fdf595b0e0b17299e198da3bcfc
3
+ size 1991256788
edm2-img512-m-fid/unet/unet_edm2.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import json
3
+ import os
4
+ from dataclasses import dataclass
5
+ from typing import List, Optional, Tuple
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
try:
    from diffusers.configuration_utils import ConfigMixin, register_to_config
    from diffusers.models.modeling_utils import ModelMixin
    from diffusers.utils import BaseOutput
except ImportError:  # pragma: no cover
    # Minimal stand-ins so the module can still be imported when diffusers is
    # not installed. They only provide the names used below.
    class ModelMixin(torch.nn.Module):
        pass

    class ConfigMixin:
        # Class-level default; instances start with an empty config.
        config = {}

        def register_to_config(self, **kwargs):
            self.config = kwargs

    # Decorator stand-in: returns the wrapped function unchanged, so nothing is
    # recorded into `config` automatically on this path.
    def register_to_config(func):
        return func

    @dataclass
    class BaseOutput:
        pass
30
+
31
+
32
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor:
    """Divide ``x`` by its norm over ``dim``, scaled by the reduced element count.

    Args:
        x: Input tensor.
        dim: Dimensions to reduce; defaults to every dimension except the first.
        eps: Constant added to the scaled norm before dividing.

    Returns:
        A tensor with the shape and dtype of ``x``.
    """
    reduce_dims = list(range(1, x.ndim)) if dim is None else dim
    magnitude = torch.linalg.vector_norm(x, dim=reduce_dims, keepdim=True, dtype=torch.float32)
    # sqrt(number of norms / number of elements) == 1 / sqrt(elements per norm).
    rescale = math.sqrt(magnitude.numel() / x.numel())
    magnitude = torch.add(eps, magnitude, alpha=rescale)
    return x / magnitude.to(x.dtype)
38
+
39
+
40
def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor:
    """Resample a 4-D tensor by a factor of two with a separable filter.

    Args:
        x: Input tensor whose dimension 1 is the channel dimension.
        f: 1-D filter taps; they are normalised to sum to one and expanded to a
            2-D kernel via an outer product.
        mode: ``"keep"`` returns ``x`` unchanged, ``"down"`` applies a stride-2
            depthwise convolution, ``"up"`` a stride-2 depthwise transposed
            convolution with the kernel scaled by 4.

    Returns:
        The resampled tensor.

    Raises:
        ValueError: If ``mode`` is not one of the three supported values.
    """
    if mode == "keep":
        return x
    if mode not in ("down", "up"):
        # Previously any unknown mode silently fell through to upsampling.
        raise ValueError(f"Unknown resample mode {mode!r}; expected 'keep', 'down' or 'up'.")

    filt = np.float32(f)
    pad = (len(filt) - 1) // 2
    filt = filt / filt.sum()
    filt = np.outer(filt, filt)[np.newaxis, np.newaxis, :, :]
    filt = torch.as_tensor(filt, dtype=x.dtype, device=x.device)
    c = x.shape[1]
    if mode == "down":
        return torch.nn.functional.conv2d(x, filt.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
    return torch.nn.functional.conv_transpose2d(x, (filt * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
52
+
53
+
54
def mp_silu(x: torch.Tensor) -> torch.Tensor:
    """Apply SiLU and divide the result by the constant 0.596."""
    activated = torch.nn.functional.silu(x)
    return activated / 0.596
56
+
57
+
58
def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor:
    """Blend ``a`` and ``b`` with weight ``t`` and renormalise.

    Computes ``lerp(a, b, t) / sqrt((1 - t)^2 + t^2)``.
    """
    blended = torch.lerp(a, b, t)
    scale = math.sqrt((1 - t) ** 2 + t ** 2)
    return blended / scale
60
+
61
+
62
def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor:
    """Concatenate ``a`` and ``b`` along ``dim`` after weighting each input.

    The weights depend on the sizes of the two inputs along ``dim`` and on the
    balance ``t`` (0 favours ``a``, 1 favours ``b``).
    """
    size_a, size_b = a.shape[dim], b.shape[dim]
    common = math.sqrt((size_a + size_b) / ((1 - t) ** 2 + t ** 2))
    weight_a = common / math.sqrt(size_a) * (1 - t)
    weight_b = common / math.sqrt(size_b) * t
    return torch.cat([weight_a * a, weight_b * b], dim=dim)
69
+
70
+
71
class MPFourier(torch.nn.Module):
    """Random Fourier features of a 1-D input.

    Frequencies and phases are drawn once at construction and stored as
    buffers, so they are part of the state dict but are not trained.
    """

    def __init__(self, num_channels: int, bandwidth: float = 1):
        super().__init__()
        freqs = 2 * math.pi * torch.randn(num_channels) * bandwidth
        phases = 2 * math.pi * torch.rand(num_channels)
        self.register_buffer("freqs", freqs)
        self.register_buffer("phases", phases)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``sqrt(2) * cos(outer(x, freqs) + phases)`` in the dtype of ``x``."""
        # Evaluate in float32 regardless of the input dtype, then cast back.
        angles = torch.outer(x.to(torch.float32), self.freqs.to(torch.float32))
        angles = angles + self.phases.to(torch.float32)
        features = angles.cos() * math.sqrt(2)
        return features.to(x.dtype)
82
+
83
+
84
class MPConv(torch.nn.Module):
    """Convolution or linear layer whose weight is normalised on every call.

    An empty ``kernel`` tuple gives a 2-D weight, applied as a matrix product;
    otherwise the weight is applied as a 2-D convolution with padding of half
    the kernel width.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]):
        super().__init__()
        self.out_channels = out_channels
        weight_shape = (out_channels, in_channels, *kernel)
        self.weight = torch.nn.Parameter(torch.randn(*weight_shape))

    def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor:
        """Apply the normalised, gain-scaled weight to ``x``."""
        weight = self.weight.to(torch.float32)
        if self.training:
            # Write the normalised weight back into the parameter so the stored
            # value itself stays normalised during training.
            with torch.no_grad():
                self.weight.copy_(normalize(weight))
        fan_in = weight[0].numel()
        weight = normalize(weight) * (gain / math.sqrt(fan_in))
        weight = weight.to(x.dtype)
        if weight.ndim != 2:
            return torch.nn.functional.conv2d(x, weight, padding=(weight.shape[-1] // 2,))
        return x @ weight.t()
101
+
102
+
103
class Block(torch.nn.Module):
    """Residual block of the EDM2 U-Net with optional self-attention.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by the block.
        emb_channels: Size of the conditioning embedding.
        flavor: ``"enc"`` or ``"dec"``; decides where the 1x1 skip convolution
            and the pixel normalisation are applied.
        resample_mode: ``"keep"``, ``"down"`` or ``"up"``, forwarded to ``resample``.
        resample_filter: Filter taps for resampling; defaults to ``[1, 1]``.
        attention: Whether to add a self-attention branch.
        channels_per_head: Channels per attention head.
        dropout: Dropout probability, applied only in training mode.
        res_balance: Blend weight for the residual branch.
        attn_balance: Blend weight for the attention branch.
        clip_act: Output is clipped to ``[-clip_act, clip_act]`` unless ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        emb_channels: int,
        flavor: str = "enc",
        resample_mode: str = "keep",
        resample_filter: Optional[List[float]] = None,
        attention: bool = False,
        channels_per_head: int = 64,
        dropout: float = 0.0,
        res_balance: float = 0.3,
        attn_balance: float = 0.3,
        clip_act: Optional[float] = 256,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.flavor = flavor
        # Give every instance its own list; a mutable default argument would be
        # shared by all blocks.
        self.resample_filter = [1, 1] if resample_filter is None else list(resample_filter)
        self.resample_mode = resample_mode
        self.num_heads = out_channels // channels_per_head if attention else 0
        self.dropout = dropout
        self.res_balance = res_balance
        self.attn_balance = attn_balance
        self.clip_act = clip_act
        self.emb_gain = torch.nn.Parameter(torch.zeros([]))
        self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3))
        self.emb_linear = MPConv(emb_channels, out_channels, kernel=())
        self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3))
        self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None
        self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None
        self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """Run the block on feature map ``x`` conditioned on embedding ``emb``."""
        # Main branch.
        x = resample(x, f=self.resample_filter, mode=self.resample_mode)
        if self.flavor == "enc":
            if self.conv_skip is not None:
                x = self.conv_skip(x)
            x = normalize(x, dim=[1])

        # Residual branch, modulated per channel by the embedding.
        y = self.conv_res0(mp_silu(x))
        c = self.emb_linear(emb, gain=self.emb_gain) + 1
        y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype))
        if self.training and self.dropout:
            y = torch.nn.functional.dropout(y, p=self.dropout)
        y = self.conv_res1(y)

        # Decoder blocks apply the skip convolution after the residual branch
        # has consumed the un-projected input.
        if self.flavor == "dec" and self.conv_skip is not None:
            x = self.conv_skip(x)
        x = mp_sum(x, y, t=self.res_balance)

        # Self-attention over the flattened spatial positions.
        if self.num_heads:
            y = self.attn_qkv(x)
            y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3])
            q, k, v = normalize(y, dim=[2]).unbind(3)
            w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3)
            y = torch.einsum("nhqk,nhck->nhcq", w, v)
            y = self.attn_proj(y.reshape(*x.shape))
            x = mp_sum(x, y, t=self.attn_balance)

        if self.clip_act is not None:
            x = x.clip_(-self.clip_act, self.clip_act)
        return x
167
+
168
+
169
class EDM2UNet(torch.nn.Module):
    """Encoder/decoder U-Net built from ``MPConv`` and ``Block`` modules.

    Sub-modules are stored in two ``ModuleDict``s (``enc`` and ``dec``) whose
    keys encode the resolution and role (``"{res}x{res}_conv"``, ``_down``,
    ``_block{idx}``, ``_in0``, ``_in1``, ``_up``). ``forward`` dispatches on
    substrings of these keys, and they also determine the state-dict names.
    """

    def __init__(
        self,
        img_resolution: int,
        img_channels: int,
        label_dim: int,
        model_channels: int = 192,
        channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
        channel_mult_noise: Optional[int] = None,
        channel_mult_emb: Optional[int] = None,
        num_blocks: int = 3,
        attn_resolutions: Tuple[int, ...] = (16, 8),
        label_balance: float = 0.5,
        concat_balance: float = 0.5,
        **block_kwargs,
    ):
        super().__init__()
        # Channel counts per resolution level, for the noise embedding, and for
        # the conditioning embedding.
        cblock = [model_channels * x for x in channel_mult]
        cnoise = model_channels * channel_mult_noise if channel_mult_noise is not None else cblock[0]
        cemb = model_channels * channel_mult_emb if channel_mult_emb is not None else max(cblock)
        self.label_balance = label_balance
        self.concat_balance = concat_balance
        # Zero-initialised gain applied to the output convolution.
        self.out_gain = torch.nn.Parameter(torch.zeros([]))

        # Embedding layers; the label embedding only exists when label_dim > 0.
        self.emb_fourier = MPFourier(cnoise)
        self.emb_noise = MPConv(cnoise, cemb, kernel=())
        self.emb_label = MPConv(label_dim, cemb, kernel=()) if label_dim else None

        # Encoder. The input gets one extra channel (see forward).
        self.enc = torch.nn.ModuleDict()
        cout = img_channels + 1
        for level, channels in enumerate(cblock):
            res = img_resolution >> level
            if level == 0:
                cin = cout
                cout = channels
                self.enc[f"{res}x{res}_conv"] = MPConv(cin, cout, kernel=(3, 3))
            else:
                self.enc[f"{res}x{res}_down"] = Block(cout, cout, cemb, flavor="enc", resample_mode="down", **block_kwargs)
            for idx in range(num_blocks):
                cin = cout
                cout = channels
                self.enc[f"{res}x{res}_block{idx}"] = Block(
                    cin,
                    cout,
                    cemb,
                    flavor="enc",
                    attention=(res in attn_resolutions),
                    **block_kwargs,
                )

        # Decoder. `skips` lists the output channels of every encoder module so
        # that each decoder "block" can account for the skip tensor it receives.
        self.dec = torch.nn.ModuleDict()
        skips = [block.out_channels for block in self.enc.values()]
        for level, channels in reversed(list(enumerate(cblock))):
            res = img_resolution >> level
            if level == len(cblock) - 1:
                self.dec[f"{res}x{res}_in0"] = Block(cout, cout, cemb, flavor="dec", attention=True, **block_kwargs)
                self.dec[f"{res}x{res}_in1"] = Block(cout, cout, cemb, flavor="dec", **block_kwargs)
            else:
                self.dec[f"{res}x{res}_up"] = Block(cout, cout, cemb, flavor="dec", resample_mode="up", **block_kwargs)
            for idx in range(num_blocks + 1):
                cin = cout + skips.pop()
                cout = channels
                self.dec[f"{res}x{res}_block{idx}"] = Block(
                    cin,
                    cout,
                    cemb,
                    flavor="dec",
                    attention=(res in attn_resolutions),
                    **block_kwargs,
                )

        self.out_conv = MPConv(cout, img_channels, kernel=(3, 3))

    def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the network on ``x`` given noise-level and class conditioning.

        Raises:
            ValueError: If the model has a label embedding and ``class_labels``
                is ``None``.
        """
        # Conditioning embedding: noise level, optionally blended with the label.
        emb = self.emb_noise(self.emb_fourier(noise_labels))
        if self.emb_label is not None:
            if class_labels is None:
                raise ValueError("class_labels are required for conditional EDM2UNet.")
            emb = mp_sum(emb, self.emb_label(class_labels * math.sqrt(class_labels.shape[1])), t=self.label_balance)
        emb = mp_silu(emb)

        # Encoder: append a constant channel of ones, then record every output
        # for the skip connections.
        x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1)
        skips = []
        for name, block in self.enc.items():
            # The "conv" entry is a plain MPConv and takes no embedding.
            x = block(x) if "conv" in name else block(x, emb)
            skips.append(x)

        # Decoder: only "block" entries consume a skip tensor.
        for name, block in self.dec.items():
            if "block" in name:
                x = mp_cat(x, skips.pop(), t=self.concat_balance)
            x = block(x, emb)
        return self.out_conv(x, gain=self.out_gain)
261
+
262
+
263
@dataclass
class EDM2UNet2DOutput(BaseOutput):
    """Return value of ``EDM2UNet2DModel.forward`` when ``return_dict`` is true."""

    # Denoised sample.
    sample: torch.Tensor
    # Log-variance head output; None unless `return_logvar` was requested.
    logvar: Optional[torch.Tensor] = None
267
+
268
+
269
+
270
# Constructor arguments of EDM2UNet2DModel. `from_pretrained` keeps only these
# keys from config.json, and `save_pretrained` writes only these keys back.
_CONFIG_KEYS = (
    "sample_size",
    "in_channels",
    "out_channels",
    "num_class_embeds",
    "use_fp16",
    "sigma_data",
    "logvar_channels",
    "model_channels",
    "channel_mult",
    "channel_mult_noise",
    "channel_mult_emb",
    "num_blocks",
    "attn_resolutions",
    "label_balance",
    "concat_balance",
    "dropout",
    "channels_per_head",
    "res_balance",
    "attn_balance",
    "clip_act",
)
292
+
293
+
294
class EDM2UNet2DModel(ModelMixin, ConfigMixin):
    """Preconditioned EDM2 denoiser wrapping ``EDM2UNet``.

    ``forward`` scales the input and output of the raw network with
    coefficients derived from ``sigma`` and ``sigma_data`` and can also return
    a log-variance estimate. Local loading and saving of ``config.json`` plus
    weights is implemented directly on this class.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: int = 64,
        in_channels: int = 4,
        out_channels: int = 4,
        num_class_embeds: int = 0,
        use_fp16: bool = True,
        sigma_data: float = 0.5,
        logvar_channels: int = 128,
        model_channels: int = 192,
        channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
        channel_mult_noise: Optional[int] = None,
        channel_mult_emb: Optional[int] = None,
        num_blocks: int = 3,
        attn_resolutions: Tuple[int, ...] = (16, 8),
        label_balance: float = 0.5,
        concat_balance: float = 0.5,
        dropout: float = 0.0,
        channels_per_head: int = 64,
        res_balance: float = 0.3,
        attn_balance: float = 0.3,
        clip_act: Optional[float] = 256,
    ):
        super().__init__()
        self.sample_size = sample_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_class_embeds = num_class_embeds
        self.use_fp16 = use_fp16
        self.sigma_data = sigma_data
        # Kept on the instance so save_pretrained can still write it out when
        # `config` is empty (the fallback path without diffusers).
        self.logvar_channels = logvar_channels
        self.model_channels = model_channels
        self.channel_mult = channel_mult
        self.channel_mult_noise = channel_mult_noise
        self.channel_mult_emb = channel_mult_emb
        self.num_blocks = num_blocks
        self.attn_resolutions = attn_resolutions
        self.label_balance = label_balance
        self.concat_balance = concat_balance
        self.dropout = dropout
        self.channels_per_head = channels_per_head
        self.res_balance = res_balance
        self.attn_balance = attn_balance
        self.clip_act = clip_act
        self.unet = EDM2UNet(
            img_resolution=sample_size,
            img_channels=in_channels,
            label_dim=num_class_embeds,
            model_channels=model_channels,
            channel_mult=channel_mult,
            channel_mult_noise=channel_mult_noise,
            channel_mult_emb=channel_mult_emb,
            num_blocks=num_blocks,
            attn_resolutions=attn_resolutions,
            label_balance=label_balance,
            concat_balance=concat_balance,
            dropout=dropout,
            channels_per_head=channels_per_head,
            res_balance=res_balance,
            attn_balance=attn_balance,
            clip_act=clip_act,
        )
        self.logvar_fourier = MPFourier(logvar_channels)
        self.logvar_linear = MPConv(logvar_channels, 1, kernel=())

    def forward(
        self,
        sample: torch.Tensor,
        sigma: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        force_fp32: bool = False,
        return_logvar: bool = False,
        return_dict: bool = True,
    ) -> EDM2UNet2DOutput:
        """Denoise ``sample`` at noise level ``sigma``.

        Args:
            sample: Noisy input; converted to float32.
            sigma: Noise level(s); reshaped to ``(-1, 1, 1, 1)``.
            class_labels: Label batch; replaced by zeros when missing on a
                conditional model and ignored on an unconditional one.
            force_fp32: Disable the float16 path even if ``use_fp16`` is set.
            return_logvar: Also compute the log-variance head.
            return_dict: Return an ``EDM2UNet2DOutput`` instead of a tuple.

        Returns:
            ``EDM2UNet2DOutput`` or the tuple ``(denoised, logvar)``.
        """
        x = sample.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        if self.num_class_embeds == 0:
            class_labels = None
        else:
            if class_labels is None:
                class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device)
            class_labels = class_labels.to(torch.float32).reshape(-1, self.num_class_embeds)
        # float16 is only used on CUDA and only when not explicitly overridden.
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") else torch.float32

        # Preconditioning coefficients.
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
        c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
        c_noise = sigma.flatten().log() / 4

        x_in = (c_in * x).to(dtype)
        f_x = self.unet(x_in, c_noise, class_labels)
        d_x = c_skip * x + c_out * f_x.to(torch.float32)

        logvar = None
        if return_logvar:
            logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)

        if not return_dict:
            return (d_x, logvar)
        return EDM2UNet2DOutput(sample=d_x, logvar=logvar)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs):
        """Load a model from a local directory containing ``config.json`` and weights.

        Only the ``subfolder`` keyword is honoured; other keyword arguments are
        ignored. ``diffusion_pytorch_model.safetensors`` is preferred over
        ``diffusion_pytorch_model.bin``.
        """
        subfolder = kwargs.pop("subfolder", None)
        model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path
        with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f:
            config = json.load(f)
        init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS}
        model = cls(**init_kwargs)
        weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors")
        if os.path.isfile(weight_file):
            from safetensors.torch import load_file

            state_dict = load_file(weight_file)
        else:
            # NOTE(security): torch.load unpickles the file and can execute
            # arbitrary code; only load .bin checkpoints from trusted sources.
            state_dict = torch.load(os.path.join(model_dir, "diffusion_pytorch_model.bin"), map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        if torch_dtype is not None:
            model = model.to(dtype=torch_dtype)
        return model

    def save_pretrained(self, save_directory: str, safe_serialization: bool = True):
        """Write ``config.json`` and the weights into ``save_directory``.

        Config values come from ``self.config`` when present there and from the
        instance attribute of the same name otherwise.
        """
        os.makedirs(save_directory, exist_ok=True)
        stored = dict(getattr(self, "config", {}))
        config = {"_class_name": self.__class__.__name__}
        for key in _CONFIG_KEYS:
            if key in stored:
                config[key] = stored[key]
            elif hasattr(self, key):
                config[key] = getattr(self, key)
        with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2, sort_keys=True)
            f.write("\n")
        state_dict = self.state_dict()
        if safe_serialization:
            from safetensors.torch import save_file

            save_file(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.safetensors"))
        else:
            torch.save(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.bin"))
edm2-img512-m-fid/vae/config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.36.0",
4
+ "_name_or_path": "stabilityai/sd-vae-ft-mse",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 256,
9
+ 512,
10
+ 512
11
+ ],
12
+ "down_block_types": [
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D"
17
+ ],
18
+ "force_upcast": true,
19
+ "in_channels": 3,
20
+ "latent_channels": 4,
21
+ "latents_mean": null,
22
+ "latents_std": null,
23
+ "layers_per_block": 2,
24
+ "mid_block_add_attention": true,
25
+ "norm_num_groups": 32,
26
+ "out_channels": 3,
27
+ "sample_size": 256,
28
+ "scaling_factor": 0.18215,
29
+ "shift_factor": null,
30
+ "up_block_types": [
31
+ "UpDecoderBlock2D",
32
+ "UpDecoderBlock2D",
33
+ "UpDecoderBlock2D",
34
+ "UpDecoderBlock2D"
35
+ ],
36
+ "use_post_quant_conv": true,
37
+ "use_quant_conv": true
38
+ }
edm2-img512-m-fid/vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815
3
+ size 334643276
edm2-img512-s-fid/demo.png ADDED

Git LFS Details

  • SHA256: 58bdb49e30c85b02b9e3619a11b39b1ec760452e8ad96cea1c5856e99df39d42
  • Pointer size: 131 Bytes
  • Size of remote file: 381 kB
edm2-img512-s-fid/model_index.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "EDM2Pipeline"
5
+ ],
6
+ "_diffusers_version": "0.31.0",
7
+ "scheduler": [
8
+ "diffusers",
9
+ "EDMEulerScheduler"
10
+ ],
11
+ "unet": [
12
+ "unet_edm2",
13
+ "EDM2UNet2DModel"
14
+ ],
15
+ "vae": [
16
+ "diffusers",
17
+ "AutoencoderKL"
18
+ ]
19
+ }
edm2-img512-s-fid/pipeline.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hub custom pipeline: EDM2Pipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ import json
22
+ import os
23
+ from pathlib import Path
24
+ from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
25
+
26
+ import numpy as np
27
+ import torch
28
+
29
+ from diffusers.image_processor import VaeImageProcessor
30
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
31
+ from diffusers.utils import replace_example_docstring
32
+ from diffusers.utils.torch_utils import randn_tensor
33
+
34
+ EXAMPLE_DOC_STRING = """
35
+ Examples:
36
+ ```py
37
+ >>> from pathlib import Path
38
+ >>> import torch
39
+ >>> from diffusers import DiffusionPipeline
40
+
41
+ >>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve()
42
+ >>> pipe = DiffusionPipeline.from_pretrained(
43
+ ... str(model_dir),
44
+ ... local_files_only=True,
45
+ ... custom_pipeline=str(model_dir / "pipeline.py"),
46
+ ... trust_remote_code=True,
47
+ ... torch_dtype=torch.float32,
48
+ ... )
49
+ >>> pipe.to("cuda")
50
+
51
+ >>> generator = torch.Generator(device="cuda").manual_seed(42)
52
+ >>> image = pipe(
53
+ ... class_labels=207,
54
+ ... num_inference_steps=32,
55
+ ... guidance_scale=1.0,
56
+ ... generator=generator,
57
+ ... ).images[0]
58
+ >>> image.save("demo.png")
59
+ ```
60
+ """
61
+
62
# Per-channel affine constants for Stability VAE latents, following NVlabs/edm2
# (training/encoders.py). `decode_latents` undoes the affine map with
# `(x - bias) / scale` before handing latents to the VAE decoder.
_STABILITY_VAE_SCALE = np.float32(0.5) / np.array([4.17, 4.62, 3.71, 3.28], dtype=np.float32)
_STABILITY_VAE_BIAS = (
    np.float32(0.0) - np.array([5.81, 3.25, 0.12, -2.15], dtype=np.float32) * _STABILITY_VAE_SCALE
)
65
+
66
+ class EDM2Pipeline(DiffusionPipeline):
67
+ r"""
68
+ Pipeline for class-conditional image generation with EDM2
69
+ ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)).
70
+
71
+ Parameters:
72
+ unet ([`EDM2UNet2DModel`]):
73
+ Main magnitude-preserving U-Net with EDM preconditioning.
74
+ scheduler ([`EDMEulerScheduler`]):
75
+ Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in
76
+ the pipeline because the UNet returns denoised latents rather than noise predictions.
77
+ vae ([`AutoencoderKL`], *optional*):
78
+ Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`.
79
+ gnet ([`EDM2UNet2DModel`], *optional*):
80
+ Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`).
81
+ id2label (`dict[int, str]`, *optional*):
82
+ ImageNet class id to English label mapping.
83
+ """
84
+
85
+ model_cpu_offload_seq = "unet->gnet->vae"
86
+ _optional_components = ["vae", "gnet"]
87
+
88
def __init__(
    self,
    unet,
    scheduler,
    vae=None,
    gnet=None,
    id2label: Optional[Dict[Union[int, str], str]] = None,
) -> None:
    """Register the pipeline components and prepare label lookup tables.

    Args:
        unet: Main denoising network.
        scheduler: Scheduler that supplies the sigma schedule.
        vae: Optional latent decoder; when absent the pipeline works with a
            scale factor of 1.
        gnet: Optional guiding network used for autoguidance.
        id2label: Optional mapping from class id (int or numeric string) to a
            comma-separated label string.
    """
    super().__init__()
    self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet)

    # Label tables: id -> label string, and every synonym -> id.
    normalized = self._normalize_id2label(id2label)
    self._id2label = normalized
    self.labels = self._build_label2id(normalized)
    # When no mapping was supplied, labels are looked up lazily from model_index.json.
    self._labels_loaded_from_model_index = bool(normalized)

    if self.vae is not None:
        self.vae_scale_factor = 8
    else:
        self.vae_scale_factor = 1
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
103
+
104
+ @staticmethod
105
+ def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
106
+ if not id2label:
107
+ return {}
108
+ return {int(key): value for key, value in id2label.items()}
109
+
110
+ @staticmethod
111
+ def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
112
+ label2id: Dict[str, int] = {}
113
+ for class_id, value in id2label.items():
114
+ for synonym in value.split(","):
115
+ synonym = synonym.strip()
116
+ if synonym:
117
+ label2id[synonym] = int(class_id)
118
+ return dict(sorted(label2id.items()))
119
+
120
+ def _ensure_labels_loaded(self) -> None:
121
+ if self._labels_loaded_from_model_index:
122
+ return
123
+ loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
124
+ if loaded:
125
+ self._id2label = loaded
126
+ self.labels = self._build_label2id(self._id2label)
127
+ self._labels_loaded_from_model_index = True
128
+
129
+ @staticmethod
130
+ def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
131
+ if not variant_path:
132
+ return {}
133
+ model_index_path = Path(variant_path).resolve() / "model_index.json"
134
+ if not model_index_path.is_file():
135
+ return {}
136
+ raw = json.loads(model_index_path.read_text(encoding="utf-8"))
137
+ id2label = raw.get("id2label")
138
+ if not isinstance(id2label, dict):
139
+ return {}
140
+ return {int(key): value for key, value in id2label.items()}
141
+
142
+ @property
143
+ def id2label(self) -> Dict[int, str]:
144
+ r"""ImageNet class id to English label string (comma-separated synonyms)."""
145
+ self._ensure_labels_loaded()
146
+ return self._id2label
147
+
148
+ def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
149
+ r"""
150
+ Map ImageNet label strings to class ids.
151
+
152
+ Args:
153
+ label (`str` or `list[str]`):
154
+ One or more English label strings that match entries in `id2label`.
155
+ """
156
+ self._ensure_labels_loaded()
157
+ if not self.labels:
158
+ raise ValueError("No English labels loaded. Add `id2label` to model_index.json.")
159
+ labels = [label] if isinstance(label, str) else list(label)
160
+ missing = [item for item in labels if item not in self.labels]
161
+ if missing:
162
+ preview = ", ".join(list(self.labels.keys())[:8])
163
+ raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
164
+ return [self.labels[item] for item in labels]
165
+
166
+ def _default_image_size(self) -> int:
167
+ latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64)))
168
+ return latent_size * self.vae_scale_factor
169
+
170
+ def check_inputs(
171
+ self,
172
+ height: int,
173
+ width: int,
174
+ num_inference_steps: int,
175
+ guidance_scale: float,
176
+ output_type: str,
177
+ ) -> None:
178
+ if num_inference_steps < 1:
179
+ raise ValueError("num_inference_steps must be >= 1.")
180
+ if guidance_scale < 1.0:
181
+ raise ValueError("guidance_scale must be >= 1.0.")
182
+ if guidance_scale > 1.0 and self.gnet is None:
183
+ raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).")
184
+ if output_type not in {"pil", "np", "pt", "latent"}:
185
+ raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
186
+
187
+ native_size = self._default_image_size()
188
+ if height != native_size or width != native_size:
189
+ raise ValueError(
190
+ f"EDM2 expects native resolution height=width={native_size}. "
191
+ f"Got height={height}, width={width}."
192
+ )
193
+
194
+ def _normalize_class_labels(
195
+ self,
196
+ class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]],
197
+ batch_size: int,
198
+ device: torch.device,
199
+ ) -> Optional[torch.Tensor]:
200
+ label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0)))
201
+ if label_dim == 0:
202
+ return None
203
+ if class_labels is None:
204
+ indices = torch.randint(label_dim, size=(batch_size,), device=device)
205
+ return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
206
+
207
+ if isinstance(class_labels, str):
208
+ class_labels = self.get_label_ids(class_labels)[0]
209
+ elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str):
210
+ class_labels = self.get_label_ids(list(class_labels))
211
+
212
+ if isinstance(class_labels, int):
213
+ indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long)
214
+ elif isinstance(class_labels, torch.Tensor):
215
+ if class_labels.ndim == 2:
216
+ labels = class_labels.to(device=device, dtype=torch.float32)
217
+ if labels.shape[0] != batch_size:
218
+ raise ValueError(f"class_labels batch must match batch_size={batch_size}.")
219
+ return labels
220
+ indices = class_labels.to(device=device, dtype=torch.long).flatten()
221
+ else:
222
+ indices = torch.tensor(list(class_labels), device=device, dtype=torch.long)
223
+
224
+ if indices.numel() == 1 and batch_size > 1:
225
+ indices = indices.repeat(batch_size)
226
+ if indices.numel() != batch_size:
227
+ raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.")
228
+ return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
229
+
230
def prepare_latents(
    self,
    batch_size: int,
    height: int,
    width: int,
    dtype: torch.dtype,
    device: torch.device,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]],
) -> torch.Tensor:
    """Draw the initial Gaussian noise in latent space.

    Args:
        batch_size: Number of samples.
        height: Output height in pixels; divided by `vae_scale_factor`.
        width: Output width in pixels; divided by `vae_scale_factor`.
        dtype: Accepted for interface compatibility; the noise is always
            created in float32.
        device: Device for the noise tensor.
        generator: Optional RNG(s) forwarded to `randn_tensor`.

    Returns:
        A float32 tensor of shape
        `(batch_size, in_channels, height // scale, width // scale)`.
    """
    in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4)))
    # Derive each spatial dimension from its own argument rather than using
    # `height` for both.
    latent_height = height // self.vae_scale_factor
    latent_width = width // self.vae_scale_factor
    return randn_tensor(
        (batch_size, in_channels, latent_height, latent_width),
        generator=generator,
        device=device,
        dtype=torch.float32,
    )
247
+
248
+ def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
249
+ if output_type == "latent":
250
+ return latents
251
+
252
+ in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 3)))
253
+ if self.vae is None:
254
+ image = (latents.to(torch.float32) * 127.5 + 128).clip(0, 255) / 255.0
255
+ return self.image_processor.postprocess(image, output_type=output_type)
256
+
257
+ if in_channels == 4:
258
+ x = latents.to(torch.float32)
259
+ scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
260
+ bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
261
+ x = (x - bias) / scale
262
+ else:
263
+ x = latents.to(torch.float32)
264
+
265
+ vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype
266
+ image = self.vae.decode(x.to(dtype=vae_dtype)).sample.to(torch.float32).clamp(0, 1)
267
+
268
+ return self.image_processor.postprocess(image, output_type=output_type)
269
+
270
+ @staticmethod
271
+ def _apply_autoguidance(
272
+ main: torch.Tensor,
273
+ ref: torch.Tensor,
274
+ guidance_scale: float,
275
+ ) -> torch.Tensor:
276
+ return ref.lerp(main, guidance_scale)
277
+
278
+ @staticmethod
279
+ def _sample_edm2_heun(
280
+ denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
281
+ noise: torch.Tensor,
282
+ sigmas: torch.Tensor,
283
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
284
+ progress_bar: Optional[Callable[[Iterable], Iterable]] = None,
285
+ dtype: torch.dtype = torch.float32,
286
+ ) -> torch.Tensor:
287
+ """NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0)."""
288
+ x_next = noise.to(dtype) * sigmas[0]
289
+
290
+ sigma_pairs = list(zip(sigmas[:-1], sigmas[1:]))
291
+ if progress_bar is not None:
292
+ sigma_pairs = progress_bar(sigma_pairs)
293
+
294
+ num_steps = len(sigma_pairs)
295
+ for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs):
296
+ x_hat, sigma_hat = x_next, sigma_cur
297
+ d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat
298
+ x_next = x_hat + (sigma_next - sigma_hat) * d_cur
299
+ if i < num_steps - 1:
300
+ d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next
301
+ x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime)
302
+ return x_next
303
+
304
@torch.inference_mode()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None,
    batch_size: int = 1,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 32,
    guidance_scale: float = 1.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    output_type: str = "pil",
    return_dict: bool = True,
) -> Union[ImagePipelineOutput, Tuple]:
    r"""
    Run the EDM2 sampler and decode the result.

    Args:
        class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*):
            Class indices, label strings, or a 2-D one-hot float tensor.
            When omitted on a conditional model, random classes are drawn.
        batch_size (`int`, defaults to `1`):
            How many images to generate.
        height (`int`, *optional*):
            Output height in pixels; defaults to the native resolution.
        width (`int`, *optional*):
            Output width in pixels; defaults to the native resolution.
        num_inference_steps (`int`, defaults to `32`):
            Number of steps requested from the scheduler.
        guidance_scale (`float`, defaults to `1.0`):
            Autoguidance strength. Above `1.0` the `unet` output is blended with
            the `gnet` output via `gnet_output.lerp(unet_output, guidance_scale)`.
        generator (`torch.Generator`, *optional*):
            RNG used for the initial noise.
        output_type (`str`, defaults to `"pil"`):
            One of `"pil"`, `"np"`, `"pt"`, `"latent"`.
        return_dict (`bool`, defaults to `True`):
            When True an [`~pipelines.pipeline_utils.ImagePipelineOutput`] is
            returned, otherwise the tuple `(images, latents)`.

    Examples:
        <!-- this section is replaced by replace_example_docstring -->
    """
    native_size = self._default_image_size()
    height = int(height or native_size)
    width = int(width or native_size)
    self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type)

    device = self._execution_device
    unet_dtype = self.unet.dtype
    one_hot = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
    initial_noise = self.prepare_latents(batch_size, height, width, unet_dtype, device, generator)

    def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        # The networks take one sigma per sample.
        net_inputs = dict(
            sample=x,
            sigma=sigma.reshape(1).expand(batch_size),
            class_labels=one_hot,
            force_fp32=True,
        )
        main = self.unet(**net_inputs).sample
        guidance_active = not (guidance_scale == 1.0 or self.gnet is None)
        if not guidance_active:
            return main.to(torch.float32)
        ref = self.gnet(**net_inputs).sample
        return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32)

    # The scheduler only provides the sigma schedule; stepping happens in the pipeline.
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    schedule = self.scheduler.sigmas.to(device)
    latents = self._sample_edm2_heun(
        denoise_fn=denoise_fn,
        noise=initial_noise,
        sigmas=schedule,
        generator=generator,
        progress_bar=self.progress_bar,
        dtype=torch.float32,
    )

    image = self.decode_latents(latents, output_type=output_type)
    if return_dict:
        return ImagePipelineOutput(images=image)
    return (image, latents)
388
+
389
+ @classmethod
390
+ def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None):
391
+ vae_dir = os.path.join(pretrained_model_name_or_path, "vae")
392
+ if os.path.isdir(vae_dir):
393
+ try:
394
+
395
+ return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype)
396
+ except Exception:
397
+ return None
398
+
399
+ vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt")
400
+ if os.path.isfile(vae_hint):
401
+ with open(vae_hint, "r", encoding="utf-8") as f:
402
+ hub_id = f.read().strip()
403
+ if hub_id:
404
+
405
+ return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype)
406
+ return None
edm2-img512-s-fid/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "EDMEulerScheduler",
3
+ "final_sigmas_type": "zero",
4
+ "num_train_timesteps": 1000,
5
+ "prediction_type": "epsilon",
6
+ "rho": 7.0,
7
+ "sigma_data": 0.5,
8
+ "sigma_max": 80.0,
9
+ "sigma_min": 0.002,
10
+ "sigma_schedule": "karras"
11
+ }
edm2-img512-s-fid/unet/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "EDM2UNet2DModel",
3
+ "attn_balance": 0.3,
4
+ "attn_resolutions": [
5
+ 16,
6
+ 8
7
+ ],
8
+ "channel_mult": [
9
+ 1,
10
+ 2,
11
+ 3,
12
+ 4
13
+ ],
14
+ "channel_mult_emb": 4,
15
+ "channel_mult_noise": 1,
16
+ "channels_per_head": 64,
17
+ "clip_act": 256,
18
+ "concat_balance": 0.5,
19
+ "dropout": 0.0,
20
+ "in_channels": 4,
21
+ "label_balance": 0.5,
22
+ "logvar_channels": 128,
23
+ "model_channels": 192,
24
+ "num_blocks": 3,
25
+ "num_class_embeds": 1000,
26
+ "out_channels": 4,
27
+ "res_balance": 0.3,
28
+ "sample_size": 64,
29
+ "sigma_data": 0.5,
30
+ "use_fp16": true
31
+ }
edm2-img512-s-fid/unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5dee937e117e2367ede680aae4edf96635ff4debb9ae73f2617111991aa83d61
3
+ size 1120876188
edm2-img512-s-fid/unet/unet_edm2.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import json
3
+ import os
4
+ from dataclasses import dataclass
5
+ from typing import List, Optional, Tuple
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
# Prefer the real diffusers base classes; when diffusers is not installed,
# define minimal stand-ins under the same names.
try:
    from diffusers.configuration_utils import ConfigMixin, register_to_config
    from diffusers.models.modeling_utils import ModelMixin
    from diffusers.utils import BaseOutput
except ImportError:  # pragma: no cover
    class ModelMixin(torch.nn.Module):
        # Stand-in: a plain torch module without diffusers' extras.
        pass

    class ConfigMixin:
        # Stand-in: keeps the config as a plain dict.
        config = {}

        def register_to_config(self, **kwargs):
            # Store the given keyword arguments as the instance config.
            self.config = kwargs

    def register_to_config(func):
        # Stand-in decorator: returns the decorated function unchanged, so
        # nothing is recorded into `config` automatically.
        return func

    @dataclass
    class BaseOutput:
        # Stand-in: empty dataclass base for output containers.
        pass
30
+
31
+
32
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor:
    """Divide `x` by its vector norm over `dim`, rescaled by element count.

    Args:
        x: Input tensor.
        dim: Dimensions to reduce over; defaults to every dimension except 0.
        eps: Constant added to the rescaled norm before dividing.

    Returns:
        A tensor with the shape and dtype of `x`.
    """
    reduce_dims = list(range(1, x.ndim)) if dim is None else dim
    magnitude = torch.linalg.vector_norm(x, dim=reduce_dims, keepdim=True, dtype=torch.float32)
    # Scale by sqrt(#norm entries / #elements) before adding eps.
    rescale = math.sqrt(magnitude.numel() / x.numel())
    denominator = torch.add(eps, magnitude, alpha=rescale)
    return x / denominator.to(x.dtype)
38
+
39
+
40
def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor:
    """Resample `x` spatially by a factor of two with a separable filter.

    Args:
        x: Input tensor; channels are read from `x.shape[1]`.
        f: 1-D filter taps. They are normalized to sum to one and expanded to
            a 2-D kernel via an outer product.
        mode: `"keep"` returns `x` unchanged, `"down"` applies a stride-2
            depthwise convolution, `"up"` applies a stride-2 depthwise
            transposed convolution with the kernel scaled by 4.

    Returns:
        The resampled tensor (or `x` itself for `"keep"`).

    Raises:
        ValueError: If `mode` is not one of the three supported values.
    """
    if mode == "keep":
        return x
    # Reject unknown modes instead of silently treating them as "up".
    if mode not in ("down", "up"):
        raise ValueError(f"Unknown resample mode {mode!r}; expected 'keep', 'down' or 'up'.")

    filt = np.float32(f)
    pad = (len(filt) - 1) // 2
    filt = filt / filt.sum()
    filt = np.outer(filt, filt)[np.newaxis, np.newaxis, :, :]
    filt = torch.as_tensor(filt, dtype=x.dtype, device=x.device)
    c = x.shape[1]
    if mode == "down":
        return torch.nn.functional.conv2d(x, filt.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
    return torch.nn.functional.conv_transpose2d(x, (filt * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
52
+
53
+
54
def mp_silu(x: torch.Tensor) -> torch.Tensor:
    """Apply SiLU and divide the result by the fixed constant 0.596.

    NOTE(review): 0.596 is presumably the magnitude-preserving normalizer from
    the EDM2 reference implementation; confirm against upstream.
    """
    activated = torch.nn.functional.silu(x)
    return activated / 0.596
56
+
57
+
58
def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor:
    """Blend `a` and `b` and renormalize the result.

    Computes `lerp(a, b, t)` and divides it by `sqrt((1 - t)**2 + t**2)`.
    """
    blended = torch.lerp(a, b, t)
    normalizer = math.sqrt((1 - t) ** 2 + t ** 2)
    return blended / normalizer
60
+
61
+
62
def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor:
    """Concatenate `a` and `b` along `dim` after weighting each input.

    The weights depend on the sizes of the two inputs along `dim` and on the
    balance `t` (`t=0` favours `a`, `t=1` favours `b`).
    """
    size_a, size_b = a.shape[dim], b.shape[dim]
    overall = math.sqrt((size_a + size_b) / ((1 - t) ** 2 + t ** 2))
    weight_a = overall / math.sqrt(size_a) * (1 - t)
    weight_b = overall / math.sqrt(size_b) * t
    return torch.cat((weight_a * a, weight_b * b), dim=dim)
69
+
70
+
71
class MPFourier(torch.nn.Module):
    """Random Fourier features of a 1-D input.

    Frequencies and phases are drawn once at construction and stored as
    buffers named `freqs` and `phases`.
    """

    def __init__(self, num_channels: int, bandwidth: float = 1):
        super().__init__()
        freqs = 2 * math.pi * torch.randn(num_channels) * bandwidth
        phases = 2 * math.pi * torch.rand(num_channels)
        self.register_buffer("freqs", freqs)
        self.register_buffer("phases", phases)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return `sqrt(2) * cos(outer(x, freqs) + phases)` in the dtype of `x`.

        The computation itself is carried out in float32.
        """
        angles = torch.outer(x.to(torch.float32), self.freqs.to(torch.float32))
        angles = angles + self.phases.to(torch.float32)
        features = angles.cos() * math.sqrt(2)
        return features.to(x.dtype)
82
+
83
+
84
class MPConv(torch.nn.Module):
    """Linear or 2-D convolution layer whose weight is normalized on every call.

    `kernel=()` gives a 2-D weight used as a linear layer; a two-element
    kernel gives a 4-D weight used by `conv2d`.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]):
        super().__init__()
        self.out_channels = out_channels
        weight_shape = (out_channels, in_channels, *kernel)
        self.weight = torch.nn.Parameter(torch.randn(*weight_shape))

    def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor:
        """Apply the layer to `x` with the normalized weight scaled by `gain`.

        In training mode the stored parameter is first overwritten in place
        with its normalized value (without recording gradients).
        """
        weight = self.weight.to(torch.float32)
        if self.training:
            with torch.no_grad():
                self.weight.copy_(normalize(weight))
        # Normalize, then divide by sqrt(fan-in) and apply the gain.
        fan_in = weight[0].numel()
        weight = normalize(weight) * (gain / math.sqrt(fan_in))
        weight = weight.to(x.dtype)
        if weight.ndim != 2:
            return torch.nn.functional.conv2d(x, weight, padding=(weight.shape[-1] // 2,))
        return x @ weight.t()
101
+
102
+
103
class Block(torch.nn.Module):
    """Residual U-Net block with optional resampling and self-attention.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        emb_channels: Width of the conditioning embedding.
        flavor: `"enc"` or `"dec"`; decides where the 1x1 skip convolution is
            applied and whether the input is normalized.
        resample_mode: `"keep"`, `"down"` or `"up"`, forwarded to `resample`.
        resample_filter: Filter taps for `resample`; defaults to `[1, 1]`.
        attention: Whether to add a self-attention stage.
        channels_per_head: Channels per attention head.
        dropout: Dropout probability applied in training mode.
        res_balance: Blend weight between the input and the residual branch.
        attn_balance: Blend weight between the input and the attention branch.
        clip_act: If not `None`, the output is clipped to `[-clip_act, clip_act]`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        emb_channels: int,
        flavor: str = "enc",
        resample_mode: str = "keep",
        resample_filter: Optional[List[float]] = None,
        attention: bool = False,
        channels_per_head: int = 64,
        dropout: float = 0.0,
        res_balance: float = 0.3,
        attn_balance: float = 0.3,
        clip_act: Optional[float] = 256,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.flavor = flavor
        # A fresh list per instance: a list literal as the default argument
        # would be shared by every Block.
        self.resample_filter = [1, 1] if resample_filter is None else list(resample_filter)
        self.resample_mode = resample_mode
        self.num_heads = out_channels // channels_per_head if attention else 0
        self.dropout = dropout
        self.res_balance = res_balance
        self.attn_balance = attn_balance
        self.clip_act = clip_act
        self.emb_gain = torch.nn.Parameter(torch.zeros([]))
        # Encoder blocks apply the skip conv before conv_res0, so conv_res0
        # already sees `out_channels` there.
        self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3))
        self.emb_linear = MPConv(emb_channels, out_channels, kernel=())
        self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3))
        self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None
        self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None
        self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """Run the block on feature map `x` conditioned on embedding `emb`."""
        # Main branch.
        x = resample(x, f=self.resample_filter, mode=self.resample_mode)
        if self.flavor == "enc":
            if self.conv_skip is not None:
                x = self.conv_skip(x)
            x = normalize(x, dim=[1])

        # Residual branch, modulated by the embedding.
        y = self.conv_res0(mp_silu(x))
        c = self.emb_linear(emb, gain=self.emb_gain) + 1
        y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype))
        if self.training and self.dropout:
            y = torch.nn.functional.dropout(y, p=self.dropout)
        y = self.conv_res1(y)

        # Connect the branches.
        if self.flavor == "dec" and self.conv_skip is not None:
            x = self.conv_skip(x)
        x = mp_sum(x, y, t=self.res_balance)

        # Self-attention over the flattened spatial positions.
        if self.num_heads:
            y = self.attn_qkv(x)
            y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3])
            q, k, v = normalize(y, dim=[2]).unbind(3)
            w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3)
            y = torch.einsum("nhqk,nhck->nhcq", w, v)
            y = self.attn_proj(y.reshape(*x.shape))
            x = mp_sum(x, y, t=self.attn_balance)

        # Clip activations in place.
        if self.clip_act is not None:
            x = x.clip_(-self.clip_act, self.clip_act)
        return x
167
+
168
+
169
class EDM2UNet(torch.nn.Module):
    """Encoder/decoder U-Net assembled from `MPConv` and `Block` modules.

    Modules are stored in two `ModuleDict`s (`enc`, `dec`) keyed by
    `"{res}x{res}_{kind}"`. The insertion order of those dicts is the execution
    order in `forward` and also determines the state-dict key layout.
    """

    def __init__(
        self,
        img_resolution: int,
        img_channels: int,
        label_dim: int,
        model_channels: int = 192,
        channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
        channel_mult_noise: Optional[int] = None,
        channel_mult_emb: Optional[int] = None,
        num_blocks: int = 3,
        attn_resolutions: Tuple[int, ...] = (16, 8),
        label_balance: float = 0.5,
        concat_balance: float = 0.5,
        **block_kwargs,
    ):
        super().__init__()
        # Channel counts per level, for the noise embedding, and for the final embedding.
        cblock = [model_channels * x for x in channel_mult]
        cnoise = model_channels * channel_mult_noise if channel_mult_noise is not None else cblock[0]
        cemb = model_channels * channel_mult_emb if channel_mult_emb is not None else max(cblock)
        self.label_balance = label_balance
        self.concat_balance = concat_balance
        # Learned output gain, initialized to zero.
        self.out_gain = torch.nn.Parameter(torch.zeros([]))

        # Embedding: Fourier features of the noise label, plus an optional class embedding.
        self.emb_fourier = MPFourier(cnoise)
        self.emb_noise = MPConv(cnoise, cemb, kernel=())
        self.emb_label = MPConv(label_dim, cemb, kernel=()) if label_dim else None

        # Encoder.
        self.enc = torch.nn.ModuleDict()
        # +1 for the constant channel of ones that `forward` appends to the input.
        cout = img_channels + 1
        for level, channels in enumerate(cblock):
            res = img_resolution >> level
            if level == 0:
                cin = cout
                cout = channels
                self.enc[f"{res}x{res}_conv"] = MPConv(cin, cout, kernel=(3, 3))
            else:
                self.enc[f"{res}x{res}_down"] = Block(cout, cout, cemb, flavor="enc", resample_mode="down", **block_kwargs)
            for idx in range(num_blocks):
                cin = cout
                cout = channels
                self.enc[f"{res}x{res}_block{idx}"] = Block(
                    cin,
                    cout,
                    cemb,
                    flavor="enc",
                    attention=(res in attn_resolutions),
                    **block_kwargs,
                )

        # Decoder.
        self.dec = torch.nn.ModuleDict()
        # Output widths of every encoder module, consumed in reverse by the skip connections.
        skips = [block.out_channels for block in self.enc.values()]
        for level, channels in reversed(list(enumerate(cblock))):
            res = img_resolution >> level
            if level == len(cblock) - 1:
                # Lowest resolution: two bottleneck blocks, the first with attention.
                self.dec[f"{res}x{res}_in0"] = Block(cout, cout, cemb, flavor="dec", attention=True, **block_kwargs)
                self.dec[f"{res}x{res}_in1"] = Block(cout, cout, cemb, flavor="dec", **block_kwargs)
            else:
                self.dec[f"{res}x{res}_up"] = Block(cout, cout, cemb, flavor="dec", resample_mode="up", **block_kwargs)
            for idx in range(num_blocks + 1):
                # Each decoder block takes the running features concatenated with one skip.
                cin = cout + skips.pop()
                cout = channels
                self.dec[f"{res}x{res}_block{idx}"] = Block(
                    cin,
                    cout,
                    cemb,
                    flavor="dec",
                    attention=(res in attn_resolutions),
                    **block_kwargs,
                )

        self.out_conv = MPConv(cout, img_channels, kernel=(3, 3))

    def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the network.

        Args:
            x: Input feature map.
            noise_labels: 1-D noise conditioning values fed to the Fourier embedding.
            class_labels: 2-D class conditioning; required when the model was
                built with a non-zero `label_dim`.

        Raises:
            ValueError: If the model is conditional and `class_labels` is `None`.
        """
        # Embedding.
        emb = self.emb_noise(self.emb_fourier(noise_labels))
        if self.emb_label is not None:
            if class_labels is None:
                raise ValueError("class_labels are required for conditional EDM2UNet.")
            emb = mp_sum(emb, self.emb_label(class_labels * math.sqrt(class_labels.shape[1])), t=self.label_balance)
        emb = mp_silu(emb)

        # Encoder: append a channel of ones, then record every module output as a skip.
        x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1)
        skips = []
        for name, block in self.enc.items():
            # The "conv" entry is a plain MPConv and takes no embedding.
            x = block(x) if "conv" in name else block(x, emb)
            skips.append(x)

        # Decoder: only the "block" entries consume a skip connection.
        for name, block in self.dec.items():
            if "block" in name:
                x = mp_cat(x, skips.pop(), t=self.concat_balance)
            x = block(x, emb)
        return self.out_conv(x, gain=self.out_gain)
261
+
262
+
263
@dataclass
class EDM2UNet2DOutput(BaseOutput):
    """Return container of `EDM2UNet2DModel.forward` when `return_dict=True`."""

    # Output of the preconditioned network (`c_skip * x + c_out * F(x)`).
    sample: torch.Tensor
    # Filled only when `forward` is called with `return_logvar=True`; otherwise None.
    logvar: Optional[torch.Tensor] = None
267
+
268
+
269
+
270
# Names of the `EDM2UNet2DModel.__init__` arguments that are read from and
# written to `config.json` by `from_pretrained` / `save_pretrained`. Any other
# key found in a config file is ignored when the model is constructed.
_CONFIG_KEYS = (
    "sample_size",
    "in_channels",
    "out_channels",
    "num_class_embeds",
    "use_fp16",
    "sigma_data",
    "logvar_channels",
    "model_channels",
    "channel_mult",
    "channel_mult_noise",
    "channel_mult_emb",
    "num_blocks",
    "attn_resolutions",
    "label_balance",
    "concat_balance",
    "dropout",
    "channels_per_head",
    "res_balance",
    "attn_balance",
    "clip_act",
)
292
+
293
+
294
class EDM2UNet2DModel(ModelMixin, ConfigMixin):
    """Diffusers-style wrapper around `EDM2UNet` with EDM preconditioning.

    `forward` returns the denoised sample `c_skip * x + c_out * F(c_in * x)`
    rather than a noise prediction. Loading and saving use a local directory
    containing `config.json` and either `diffusion_pytorch_model.safetensors`
    or `diffusion_pytorch_model.bin`.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: int = 64,
        in_channels: int = 4,
        out_channels: int = 4,
        num_class_embeds: int = 0,
        use_fp16: bool = True,
        sigma_data: float = 0.5,
        logvar_channels: int = 128,
        model_channels: int = 192,
        channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
        channel_mult_noise: Optional[int] = None,
        channel_mult_emb: Optional[int] = None,
        num_blocks: int = 3,
        attn_resolutions: Tuple[int, ...] = (16, 8),
        label_balance: float = 0.5,
        concat_balance: float = 0.5,
        dropout: float = 0.0,
        channels_per_head: int = 64,
        res_balance: float = 0.3,
        attn_balance: float = 0.3,
        clip_act: Optional[float] = 256,
    ):
        super().__init__()
        # Mirror every hyper-parameter as a plain attribute so callers can read
        # it without going through `config`.
        self.sample_size = sample_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_class_embeds = num_class_embeds
        self.use_fp16 = use_fp16
        self.sigma_data = sigma_data
        self.model_channels = model_channels
        self.channel_mult = channel_mult
        self.channel_mult_noise = channel_mult_noise
        self.channel_mult_emb = channel_mult_emb
        self.num_blocks = num_blocks
        self.attn_resolutions = attn_resolutions
        self.label_balance = label_balance
        self.concat_balance = concat_balance
        self.dropout = dropout
        self.channels_per_head = channels_per_head
        self.res_balance = res_balance
        self.attn_balance = attn_balance
        self.clip_act = clip_act
        # NOTE: the backbone is built with `in_channels` for both input and
        # output; `out_channels` is stored but not passed on.
        self.unet = EDM2UNet(
            img_resolution=sample_size,
            img_channels=in_channels,
            label_dim=num_class_embeds,
            model_channels=model_channels,
            channel_mult=channel_mult,
            channel_mult_noise=channel_mult_noise,
            channel_mult_emb=channel_mult_emb,
            num_blocks=num_blocks,
            attn_resolutions=attn_resolutions,
            label_balance=label_balance,
            concat_balance=concat_balance,
            dropout=dropout,
            channels_per_head=channels_per_head,
            res_balance=res_balance,
            attn_balance=attn_balance,
            clip_act=clip_act,
        )
        # Small head that maps the noise conditioning to a log-variance value.
        self.logvar_fourier = MPFourier(logvar_channels)
        self.logvar_linear = MPConv(logvar_channels, 1, kernel=())

    def forward(
        self,
        sample: torch.Tensor,
        sigma: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        force_fp32: bool = False,
        return_logvar: bool = False,
        return_dict: bool = True,
    ) -> EDM2UNet2DOutput:
        """Denoise `sample` at noise level `sigma`.

        Args:
            sample: Noisy input; cast to float32.
            sigma: Noise level(s); reshaped to `(-1, 1, 1, 1)`.
            class_labels: Class conditioning. Ignored for unconditional
                models; replaced by zeros when omitted on a conditional model.
            force_fp32: Run the backbone in float32 even if `use_fp16` is set.
            return_logvar: Also compute the log-variance head.
            return_dict: Return an `EDM2UNet2DOutput` instead of a tuple.

        Returns:
            `EDM2UNet2DOutput(sample, logvar)` or the tuple `(sample, logvar)`.
        """
        x = sample.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        if self.num_class_embeds == 0:
            class_labels = None
        else:
            if class_labels is None:
                class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device)
            class_labels = class_labels.to(torch.float32).reshape(-1, self.num_class_embeds)
        # Half precision is only used on CUDA.
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") else torch.float32

        # EDM preconditioning coefficients.
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
        c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
        c_noise = sigma.flatten().log() / 4

        x_in = (c_in * x).to(dtype)
        f_x = self.unet(x_in, c_noise, class_labels)
        d_x = c_skip * x + c_out * f_x.to(torch.float32)

        logvar = None
        if return_logvar:
            logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)

        if not return_dict:
            return (d_x, logvar)
        return EDM2UNet2DOutput(sample=d_x, logvar=logvar)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs):
        """Load a model from a local directory.

        Args:
            pretrained_model_name_or_path: Directory with `config.json` and weights.
            torch_dtype: If given, the model is cast to this dtype after loading.
            **kwargs: Only `subfolder` is used; other keys are ignored.

        Returns:
            The model in evaluation mode, with weights loaded strictly.
        """
        subfolder = kwargs.pop("subfolder", None)
        model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path
        with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f:
            config = json.load(f)
        init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS}
        model = cls(**init_kwargs)
        weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors")
        if os.path.isfile(weight_file):
            from safetensors.torch import load_file

            state_dict = load_file(weight_file)
        else:
            # `.bin` checkpoints are pickles; weights_only=True keeps the
            # unpickler from executing arbitrary code in an untrusted file.
            state_dict = torch.load(
                os.path.join(model_dir, "diffusion_pytorch_model.bin"),
                map_location="cpu",
                weights_only=True,
            )
        model.load_state_dict(state_dict, strict=True)
        if torch_dtype is not None:
            model = model.to(dtype=torch_dtype)
        # Return in eval mode: in training mode MPConv overwrites its stored
        # weights on every forward pass, which is unwanted during inference.
        # Call `.train()` explicitly before fine-tuning.
        model.eval()
        return model

    def save_pretrained(self, save_directory: str, safe_serialization: bool = True):
        """Write `config.json` and the weights to `save_directory`.

        Args:
            save_directory: Target directory; created if missing.
            safe_serialization: Save as safetensors when True, otherwise as a
                `torch.save` `.bin` file.
        """
        os.makedirs(save_directory, exist_ok=True)
        stored = dict(getattr(self, "config", {}))
        config = {"_class_name": self.__class__.__name__}
        # Prefer the registered config value; fall back to the mirrored attribute.
        for key in _CONFIG_KEYS:
            if key in stored:
                config[key] = stored[key]
            elif hasattr(self, key):
                config[key] = getattr(self, key)
        with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2, sort_keys=True)
            f.write("\n")
        state_dict = self.state_dict()
        if safe_serialization:
            from safetensors.torch import save_file

            save_file(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.safetensors"))
        else:
            torch.save(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.bin"))
edm2-img512-s-fid/vae/config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.36.0",
4
+ "_name_or_path": "stabilityai/sd-vae-ft-mse",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 256,
9
+ 512,
10
+ 512
11
+ ],
12
+ "down_block_types": [
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D"
17
+ ],
18
+ "force_upcast": true,
19
+ "in_channels": 3,
20
+ "latent_channels": 4,
21
+ "latents_mean": null,
22
+ "latents_std": null,
23
+ "layers_per_block": 2,
24
+ "mid_block_add_attention": true,
25
+ "norm_num_groups": 32,
26
+ "out_channels": 3,
27
+ "sample_size": 256,
28
+ "scaling_factor": 0.18215,
29
+ "shift_factor": null,
30
+ "up_block_types": [
31
+ "UpDecoderBlock2D",
32
+ "UpDecoderBlock2D",
33
+ "UpDecoderBlock2D",
34
+ "UpDecoderBlock2D"
35
+ ],
36
+ "use_post_quant_conv": true,
37
+ "use_quant_conv": true
38
+ }
edm2-img512-s-fid/vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815
3
+ size 334643276
edm2-img512-xl-fid/demo.png ADDED

Git LFS Details

  • SHA256: 551c91feb88ea0279f61d52c20463da670f01f99e37467a6f358b699f33cd526
  • Pointer size: 131 Bytes
  • Size of remote file: 370 kB
edm2-img512-xl-fid/model_index.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "EDM2Pipeline"
5
+ ],
6
+ "_diffusers_version": "0.31.0",
7
+ "scheduler": [
8
+ "diffusers",
9
+ "EDMEulerScheduler"
10
+ ],
11
+ "unet": [
12
+ "unet_edm2",
13
+ "EDM2UNet2DModel"
14
+ ],
15
+ "vae": [
16
+ "diffusers",
17
+ "AutoencoderKL"
18
+ ]
19
+ }
edm2-img512-xl-fid/pipeline.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hub custom pipeline: EDM2Pipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ import json
22
+ import os
23
+ from pathlib import Path
24
+ from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
25
+
26
+ import numpy as np
27
+ import torch
28
+
29
+ from diffusers.image_processor import VaeImageProcessor
30
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
31
+ from diffusers.utils import replace_example_docstring
32
+ from diffusers.utils.torch_utils import randn_tensor
33
+
34
+ EXAMPLE_DOC_STRING = """
35
+ Examples:
36
+ ```py
37
+ >>> from pathlib import Path
38
+ >>> import torch
39
+ >>> from diffusers import DiffusionPipeline
40
+
41
+ >>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve()
42
+ >>> pipe = DiffusionPipeline.from_pretrained(
43
+ ... str(model_dir),
44
+ ... local_files_only=True,
45
+ ... custom_pipeline=str(model_dir / "pipeline.py"),
46
+ ... trust_remote_code=True,
47
+ ... torch_dtype=torch.float32,
48
+ ... )
49
+ >>> pipe.to("cuda")
50
+
51
+ >>> generator = torch.Generator(device="cuda").manual_seed(42)
52
+ >>> image = pipe(
53
+ ... class_labels=207,
54
+ ... num_inference_steps=32,
55
+ ... guidance_scale=1.0,
56
+ ... generator=generator,
57
+ ... ).images[0]
58
+ >>> image.save("demo.png")
59
+ ```
60
+ """
61
+
62
+ # Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py).
63
+ _STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28])
64
+ _STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE
65
+
66
+ class EDM2Pipeline(DiffusionPipeline):
67
+ r"""
68
+ Pipeline for class-conditional image generation with EDM2
69
+ ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)).
70
+
71
+ Parameters:
72
+ unet ([`EDM2UNet2DModel`]):
73
+ Main magnitude-preserving U-Net with EDM preconditioning.
74
+ scheduler ([`EDMEulerScheduler`]):
75
+ Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in
76
+ the pipeline because the UNet returns denoised latents rather than noise predictions.
77
+ vae ([`AutoencoderKL`], *optional*):
78
+ Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`.
79
+ gnet ([`EDM2UNet2DModel`], *optional*):
80
+ Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`).
81
+ id2label (`dict[int, str]`, *optional*):
82
+ ImageNet class id to English label mapping.
83
+ """
84
+
85
+ model_cpu_offload_seq = "unet->gnet->vae"
86
+ _optional_components = ["vae", "gnet"]
87
+
88
def __init__(
    self,
    unet,
    scheduler,
    vae=None,
    gnet=None,
    id2label: Optional[Dict[Union[int, str], str]] = None,
) -> None:
    """Register the sub-models and derive the label tables and image processor.

    Args:
        unet: Main denoising network.
        scheduler: Scheduler that provides the sigma schedule.
        vae: Optional autoencoder; when present the scale factor is 8, else 1.
        gnet: Optional guiding network.
        id2label: Optional class-id to label mapping (keys may be strings).
    """
    super().__init__()
    self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet)

    id_to_name = self._normalize_id2label(id2label)
    self._id2label = id_to_name
    self.labels = self._build_label2id(id_to_name)
    # An empty mapping means the labels may still be read lazily later on.
    self._labels_loaded_from_model_index = bool(id_to_name)

    if self.vae is not None:
        self.vae_scale_factor = 8
    else:
        self.vae_scale_factor = 1
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
103
+
104
+ @staticmethod
105
+ def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
106
+ if not id2label:
107
+ return {}
108
+ return {int(key): value for key, value in id2label.items()}
109
+
110
+ @staticmethod
111
+ def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
112
+ label2id: Dict[str, int] = {}
113
+ for class_id, value in id2label.items():
114
+ for synonym in value.split(","):
115
+ synonym = synonym.strip()
116
+ if synonym:
117
+ label2id[synonym] = int(class_id)
118
+ return dict(sorted(label2id.items()))
119
+
120
+ def _ensure_labels_loaded(self) -> None:
121
+ if self._labels_loaded_from_model_index:
122
+ return
123
+ loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
124
+ if loaded:
125
+ self._id2label = loaded
126
+ self.labels = self._build_label2id(self._id2label)
127
+ self._labels_loaded_from_model_index = True
128
+
129
+ @staticmethod
130
+ def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
131
+ if not variant_path:
132
+ return {}
133
+ model_index_path = Path(variant_path).resolve() / "model_index.json"
134
+ if not model_index_path.is_file():
135
+ return {}
136
+ raw = json.loads(model_index_path.read_text(encoding="utf-8"))
137
+ id2label = raw.get("id2label")
138
+ if not isinstance(id2label, dict):
139
+ return {}
140
+ return {int(key): value for key, value in id2label.items()}
141
+
142
@property
def id2label(self) -> Dict[int, str]:
    r"""Mapping from class id to its label string (comma-separated synonyms).

    Accessing the property triggers the lazy label load first.
    """
    self._ensure_labels_loaded()
    mapping = self._id2label
    return mapping
147
+
148
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
    r"""
    Translate label strings into class ids.

    Args:
        label (`str` or `list[str]`):
            One label or several labels; each must be a key of `self.labels`.

    Returns:
        `list[int]`: One class id per requested label, in request order.

    Raises:
        ValueError: If no labels are available or a label is unknown.
    """
    self._ensure_labels_loaded()
    known = self.labels
    if not known:
        raise ValueError("No English labels loaded. Add `id2label` to model_index.json.")

    requested = list(label) if not isinstance(label, str) else [label]
    resolved: List[int] = []
    missing = []
    for name in requested:
        if name in known:
            resolved.append(known[name])
        else:
            missing.append(name)

    if missing:
        # Show a handful of valid labels so the caller can spot typos.
        preview = ", ".join(list(known)[:8])
        raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
    return resolved
165
+
166
+ def _default_image_size(self) -> int:
167
+ latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64)))
168
+ return latent_size * self.vae_scale_factor
169
+
170
def check_inputs(
    self,
    height: int,
    width: int,
    num_inference_steps: int,
    guidance_scale: float,
    output_type: str,
) -> None:
    """Validate the sampling arguments and raise ``ValueError`` on the first problem.

    Checks, in order: step count, guidance scale lower bound, presence of
    ``gnet`` when guidance is enabled, the output type, and finally that the
    requested size equals the native size from ``_default_image_size``.
    """
    # Predicates are lazy, so evaluation stops at the first failing rule.
    argument_rules = (
        (lambda: num_inference_steps < 1, "num_inference_steps must be >= 1."),
        (lambda: guidance_scale < 1.0, "guidance_scale must be >= 1.0."),
        (
            lambda: guidance_scale > 1.0 and self.gnet is None,
            "guidance_scale > 1.0 requires a guiding network (`gnet`).",
        ),
        (
            lambda: output_type not in {"pil", "np", "pt", "latent"},
            "output_type must be one of: 'pil', 'np', 'pt', 'latent'.",
        ),
    )
    for is_violated, message in argument_rules:
        if is_violated():
            raise ValueError(message)

    expected = self._default_image_size()
    if height != expected or width != expected:
        problem = f"EDM2 expects native resolution height=width={expected}. "
        problem += f"Got height={height}, width={width}."
        raise ValueError(problem)
193
+
194
+ def _normalize_class_labels(
195
+ self,
196
+ class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]],
197
+ batch_size: int,
198
+ device: torch.device,
199
+ ) -> Optional[torch.Tensor]:
200
+ label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0)))
201
+ if label_dim == 0:
202
+ return None
203
+ if class_labels is None:
204
+ indices = torch.randint(label_dim, size=(batch_size,), device=device)
205
+ return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
206
+
207
+ if isinstance(class_labels, str):
208
+ class_labels = self.get_label_ids(class_labels)[0]
209
+ elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str):
210
+ class_labels = self.get_label_ids(list(class_labels))
211
+
212
+ if isinstance(class_labels, int):
213
+ indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long)
214
+ elif isinstance(class_labels, torch.Tensor):
215
+ if class_labels.ndim == 2:
216
+ labels = class_labels.to(device=device, dtype=torch.float32)
217
+ if labels.shape[0] != batch_size:
218
+ raise ValueError(f"class_labels batch must match batch_size={batch_size}.")
219
+ return labels
220
+ indices = class_labels.to(device=device, dtype=torch.long).flatten()
221
+ else:
222
+ indices = torch.tensor(list(class_labels), device=device, dtype=torch.long)
223
+
224
+ if indices.numel() == 1 and batch_size > 1:
225
+ indices = indices.repeat(batch_size)
226
+ if indices.numel() != batch_size:
227
+ raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.")
228
+ return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
229
+
230
def prepare_latents(
    self,
    batch_size: int,
    height: int,
    width: int,
    dtype: torch.dtype,
    device: torch.device,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]],
) -> torch.Tensor:
    """Draw the initial Gaussian noise for sampling.

    Args:
        batch_size: Number of samples.
        height: Image height in pixels; divided by ``self.vae_scale_factor``.
        width: Image width in pixels; divided by ``self.vae_scale_factor``.
        dtype: Accepted for signature compatibility. The noise is always
            created in float32.
        device: Device for the noise tensor.
        generator: Optional RNG(s) forwarded to ``randn_tensor``.

    Returns:
        A float32 tensor of shape ``(batch_size, in_channels, height // f, width // f)``
        with ``f = self.vae_scale_factor``.
    """
    in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4)))
    latent_height = height // self.vae_scale_factor
    # Use the requested width for the last dimension instead of reusing the
    # height; identical for the square sizes the pipeline accepts.
    latent_width = width // self.vae_scale_factor
    return randn_tensor(
        (batch_size, in_channels, latent_height, latent_width),
        generator=generator,
        device=device,
        dtype=torch.float32,
    )
247
+
248
def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
    """Turn sampler output into images of the requested ``output_type``.

    ``"latent"`` returns the input unchanged. Without a VAE the tensor is
    mapped with ``x * 127.5 + 128``, clipped to ``[0, 255]`` and scaled to
    ``[0, 1]``. With a VAE, 4-channel latents are first un-whitened with the
    module-level scale/bias constants, then decoded and clamped to ``[0, 1]``.
    The result goes through ``self.image_processor.postprocess``.
    """
    if output_type == "latent":
        return latents

    unet = self.unet
    channels = int(getattr(unet, "in_channels", getattr(unet.config, "in_channels", 3)))
    decoded = latents.to(torch.float32)

    if self.vae is None:
        pixels = (decoded * 127.5 + 128).clip(0, 255) / 255.0
        return self.image_processor.postprocess(pixels, output_type=output_type)

    if channels == 4:

        def per_channel(values):
            # Broadcastable (1, C, 1, 1) constant matching the latents.
            constant = torch.as_tensor(values, dtype=decoded.dtype, device=decoded.device)
            return constant.reshape(1, -1, 1, 1)

        scale = per_channel(_STABILITY_VAE_SCALE)
        bias = per_channel(_STABILITY_VAE_BIAS)
        decoded = (decoded - bias) / scale

    vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype
    vae_output = self.vae.decode(decoded.to(dtype=vae_dtype)).sample
    pixels = vae_output.to(torch.float32).clamp(0, 1)

    return self.image_processor.postprocess(pixels, output_type=output_type)
269
+
270
+ @staticmethod
271
+ def _apply_autoguidance(
272
+ main: torch.Tensor,
273
+ ref: torch.Tensor,
274
+ guidance_scale: float,
275
+ ) -> torch.Tensor:
276
+ return ref.lerp(main, guidance_scale)
277
+
278
+ @staticmethod
279
+ def _sample_edm2_heun(
280
+ denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
281
+ noise: torch.Tensor,
282
+ sigmas: torch.Tensor,
283
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
284
+ progress_bar: Optional[Callable[[Iterable], Iterable]] = None,
285
+ dtype: torch.dtype = torch.float32,
286
+ ) -> torch.Tensor:
287
+ """NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0)."""
288
+ x_next = noise.to(dtype) * sigmas[0]
289
+
290
+ sigma_pairs = list(zip(sigmas[:-1], sigmas[1:]))
291
+ if progress_bar is not None:
292
+ sigma_pairs = progress_bar(sigma_pairs)
293
+
294
+ num_steps = len(sigma_pairs)
295
+ for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs):
296
+ x_hat, sigma_hat = x_next, sigma_cur
297
+ d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat
298
+ x_next = x_hat + (sigma_next - sigma_hat) * d_cur
299
+ if i < num_steps - 1:
300
+ d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next
301
+ x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime)
302
+ return x_next
303
+
304
@torch.inference_mode()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None,
    batch_size: int = 1,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 32,
    guidance_scale: float = 1.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    output_type: str = "pil",
    return_dict: bool = True,
) -> Union[ImagePipelineOutput, Tuple]:
    r"""
    Run the full sampling loop and decode the result.

    Args:
        class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*):
            Class indices, label strings, or an already one-hot float tensor.
            When omitted on a conditional model, classes are drawn at random.
        batch_size (`int`, defaults to `1`):
            How many images to produce.
        height (`int`, *optional*):
            Image height in pixels; falls back to the native size.
        width (`int`, *optional*):
            Image width in pixels; falls back to the native size.
        num_inference_steps (`int`, defaults to `32`):
            Number of steps passed to the scheduler.
        guidance_scale (`float`, defaults to `1.0`):
            Autoguidance weight. Above `1.0` the `unet` output is blended with
            the `gnet` output via `gnet_output.lerp(unet_output, guidance_scale)`.
        generator (`torch.Generator`, *optional*):
            RNG used for the initial noise.
        output_type (`str`, defaults to `"pil"`):
            One of `"pil"`, `"np"`, `"pt"`, `"latent"`.
        return_dict (`bool`, defaults to `True`):
            Return an [`~pipelines.pipeline_utils.ImagePipelineOutput`] when
            true, otherwise the tuple `(image, latents)`.

    Examples:
        <!-- this section is replaced by replace_example_docstring -->
    """
    native_size = self._default_image_size()
    height = int(height or native_size)
    width = int(width or native_size)
    self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type)

    device = self._execution_device
    dtype = self.unet.dtype
    labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
    noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)

    def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        # Both networks receive exactly the same inputs.
        net_inputs = dict(
            sample=x,
            sigma=sigma.reshape(1).expand(batch_size),
            class_labels=labels,
            force_fp32=True,
        )
        prediction = self.unet(**net_inputs).sample
        if guidance_scale != 1.0 and self.gnet is not None:
            guide = self.gnet(**net_inputs).sample
            prediction = self._apply_autoguidance(prediction, guide, guidance_scale)
        return prediction.to(torch.float32)

    self.scheduler.set_timesteps(num_inference_steps, device=device)
    schedule = self.scheduler.sigmas.to(device)
    latents = self._sample_edm2_heun(
        denoise_fn=denoise_fn,
        noise=noise,
        sigmas=schedule,
        generator=generator,
        progress_bar=self.progress_bar,
        dtype=torch.float32,
    )

    image = self.decode_latents(latents, output_type=output_type)
    if return_dict:
        return ImagePipelineOutput(images=image)
    return (image, latents)
388
+
389
@classmethod
def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None):
    """Best-effort loader for the VAE that belongs to a checkpoint directory.

    Lookup order:
      1. ``<path>/vae`` directory, loaded with ``AutoencoderKL.from_pretrained``;
         a failure while loading yields ``None``.
      2. ``<path>/vae_pretrained_model_name_or_path.txt`` whose stripped
         content is used as the identifier to load from.

    Args:
        pretrained_model_name_or_path: Local checkpoint directory.
        torch_dtype: Optional dtype forwarded to ``from_pretrained``.

    Returns:
        The loaded ``AutoencoderKL`` or ``None`` when nothing could be loaded.
    """
    vae_dir = os.path.join(pretrained_model_name_or_path, "vae")
    if os.path.isdir(vae_dir):
        # ``AutoencoderKL`` is not imported at module level. Import it outside
        # the ``try`` so a missing name or broken install is not silently
        # turned into "no VAE".
        from diffusers import AutoencoderKL

        try:
            return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype)
        except Exception:
            # Deliberate best-effort: an unreadable VAE folder means "no VAE".
            return None

    vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt")
    if os.path.isfile(vae_hint):
        with open(vae_hint, "r", encoding="utf-8") as f:
            hub_id = f.read().strip()
        if hub_id:
            from diffusers import AutoencoderKL

            return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype)
    return None
edm2-img512-xl-fid/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "EDMEulerScheduler",
3
+ "final_sigmas_type": "zero",
4
+ "num_train_timesteps": 1000,
5
+ "prediction_type": "epsilon",
6
+ "rho": 7.0,
7
+ "sigma_data": 0.5,
8
+ "sigma_max": 80.0,
9
+ "sigma_min": 0.002,
10
+ "sigma_schedule": "karras"
11
+ }
edm2-img512-xl-fid/unet/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "EDM2UNet2DModel",
3
+ "attn_balance": 0.3,
4
+ "attn_resolutions": [
5
+ 16,
6
+ 8
7
+ ],
8
+ "channel_mult": [
9
+ 1,
10
+ 2,
11
+ 3,
12
+ 4
13
+ ],
14
+ "channel_mult_emb": 4,
15
+ "channel_mult_noise": 1,
16
+ "channels_per_head": 64,
17
+ "clip_act": 256,
18
+ "concat_balance": 0.5,
19
+ "dropout": 0.0,
20
+ "in_channels": 4,
21
+ "label_balance": 0.5,
22
+ "logvar_channels": 128,
23
+ "model_channels": 384,
24
+ "num_blocks": 3,
25
+ "num_class_embeds": 1000,
26
+ "out_channels": 4,
27
+ "res_balance": 0.3,
28
+ "sample_size": 64,
29
+ "sigma_data": 0.5,
30
+ "use_fp16": true
31
+ }
edm2-img512-xl-fid/unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3c7402d8a4e91781b5c94fa2a5beee5820970ad99d2249141e191364885f222a
3
+ size 4477161892
edm2-img512-xl-fid/unet/unet_edm2.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import json
3
+ import os
4
+ from dataclasses import dataclass
5
+ from typing import List, Optional, Tuple
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ try:
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.models.modeling_utils import ModelMixin
13
+ from diffusers.utils import BaseOutput
14
+ except ImportError: # pragma: no cover
15
+ class ModelMixin(torch.nn.Module):
16
+ pass
17
+
18
+ class ConfigMixin:
19
+ config = {}
20
+
21
+ def register_to_config(self, **kwargs):
22
+ self.config = kwargs
23
+
24
+ def register_to_config(func):
25
+ return func
26
+
27
+ @dataclass
28
+ class BaseOutput:
29
+ pass
30
+
31
+
32
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor:
    """Divide ``x`` by its scaled vector norm taken over ``dim``.

    Args:
        x: Input tensor.
        dim: Dimensions to reduce; defaults to every dimension except the first.
        eps: Constant added to the scaled norm to avoid division by zero.

    Returns:
        ``x / (eps + norm * sqrt(norm.numel() / x.numel()))`` in the dtype of ``x``.
    """
    reduce_dims = list(range(1, x.ndim)) if dim is None else dim
    magnitude = torch.linalg.vector_norm(x, dim=reduce_dims, keepdim=True, dtype=torch.float32)
    # sqrt(#norm entries / #elements) converts the L2 norm into an RMS value.
    rescale = math.sqrt(magnitude.numel() / x.numel())
    denominator = torch.add(eps, magnitude, alpha=rescale).to(x.dtype)
    return x / denominator
38
+
39
+
40
def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor:
    """Down- or upsample ``x`` by a factor of two with a separable filter.

    Args:
        x: Tensor with channels in dimension 1 and two trailing spatial dims.
        f: 1-D filter taps; normalised to sum to one and expanded to 2-D.
        mode: ``"keep"`` returns ``x`` untouched, ``"down"`` applies a strided
            depthwise convolution, any other value a strided transposed one.
    """
    if mode == "keep":
        return x

    taps = np.float32(f)
    padding = (len(taps) - 1) // 2
    taps = taps / taps.sum()
    kernel = np.outer(taps, taps)[np.newaxis, np.newaxis, :, :]
    kernel = torch.as_tensor(kernel, dtype=x.dtype, device=x.device)

    channels = x.shape[1]
    functional = torch.nn.functional
    if mode != "down":
        # Gain of 4 compensates for the zeros inserted by the 2x2 upsampling.
        up_kernel = (kernel * 4).tile([channels, 1, 1, 1])
        return functional.conv_transpose2d(x, up_kernel, groups=channels, stride=2, padding=(padding,))
    down_kernel = kernel.tile([channels, 1, 1, 1])
    return functional.conv2d(x, down_kernel, groups=channels, stride=2, padding=(padding,))
52
+
53
+
54
def mp_silu(x: torch.Tensor) -> torch.Tensor:
    """Apply SiLU and divide the result by the constant 0.596."""
    activated = torch.nn.functional.silu(x)
    return activated / 0.596
56
+
57
+
58
def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor:
    """Blend ``a`` and ``b`` as ``lerp(a, b, t)`` divided by ``sqrt((1 - t)^2 + t^2)``."""
    blended = torch.lerp(a, b, t)
    return blended / math.sqrt((1 - t) ** 2 + t ** 2)
60
+
61
+
62
def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor:
    """Concatenate ``a`` and ``b`` along ``dim`` after weighting each input.

    The weights depend on the two sizes along ``dim`` and on the balance
    ``t`` (0 keeps only ``a``, 1 keeps only ``b``).
    """
    size_a, size_b = a.shape[dim], b.shape[dim]
    common = math.sqrt((size_a + size_b) / ((1 - t) ** 2 + t ** 2))
    weight_a = common / math.sqrt(size_a) * (1 - t)
    weight_b = common / math.sqrt(size_b) * t
    weighted = [weight_a * a, weight_b * b]
    return torch.cat(weighted, dim=dim)
69
+
70
+
71
class MPFourier(torch.nn.Module):
    """Random Fourier features with fixed (buffered) frequencies and phases."""

    def __init__(self, num_channels: int, bandwidth: float = 1):
        """Draw ``num_channels`` frequencies (normal) and phases (uniform)."""
        super().__init__()
        frequencies = 2 * math.pi * torch.randn(num_channels) * bandwidth
        self.register_buffer("freqs", frequencies)
        offsets = 2 * math.pi * torch.rand(num_channels)
        self.register_buffer("phases", offsets)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``sqrt(2) * cos(outer(x, freqs) + phases)`` in the dtype of ``x``.

        The computation itself runs in float32.
        """
        values = x.to(torch.float32)
        freqs = self.freqs.to(torch.float32)
        phases = self.phases.to(torch.float32)
        features = (torch.outer(values, freqs) + phases).cos() * math.sqrt(2)
        return features.to(x.dtype)
82
+
83
+
84
class MPConv(torch.nn.Module):
    """Linear or 2-D convolution layer whose weight is normalised on every call.

    An empty ``kernel`` tuple gives a 2-D weight that is applied as a matrix
    product; otherwise the weight is used with ``conv2d`` and "same" padding.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]):
        super().__init__()
        self.out_channels = out_channels
        weight_shape = (out_channels, in_channels, *kernel)
        self.weight = torch.nn.Parameter(torch.randn(*weight_shape))

    def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor:
        """Apply the layer to ``x``; ``gain`` scales the normalised weight."""
        weight = self.weight.to(torch.float32)
        if self.training:
            # Forced weight normalisation: overwrite the stored parameter
            # without tracking the update in autograd.
            with torch.no_grad():
                self.weight.copy_(normalize(weight))
        fan_in = weight[0].numel()
        weight = normalize(weight)
        weight = (weight * (gain / math.sqrt(fan_in))).to(x.dtype)
        if weight.ndim != 2:
            return torch.nn.functional.conv2d(x, weight, padding=(weight.shape[-1] // 2,))
        return x @ weight.t()
101
+
102
+
103
class Block(torch.nn.Module):
    """Residual block with optional resampling, embedding modulation and attention.

    ``flavor`` selects where the 1x1 skip convolution runs: ``"enc"`` applies
    it (followed by a channel-wise ``normalize``) before the residual branch,
    ``"dec"`` applies it after the residual branch has been computed.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        emb_channels: int,
        flavor: str = "enc",
        resample_mode: str = "keep",
        # NOTE: list default is shared between instances; in this class it is
        # only stored and passed to `resample`, never mutated.
        resample_filter: List[float] = [1, 1],
        attention: bool = False,
        channels_per_head: int = 64,
        dropout: float = 0.0,
        res_balance: float = 0.3,
        attn_balance: float = 0.3,
        clip_act: Optional[float] = 256,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.flavor = flavor
        self.resample_filter = resample_filter
        self.resample_mode = resample_mode
        # Zero heads disables the attention branch entirely.
        self.num_heads = out_channels // channels_per_head if attention else 0
        self.dropout = dropout
        self.res_balance = res_balance
        self.attn_balance = attn_balance
        self.clip_act = clip_act
        # Starts at zero, so the embedding modulation below is initially `+ 1` only.
        self.emb_gain = torch.nn.Parameter(torch.zeros([]))
        # In "enc" blocks the skip conv has already mapped x to out_channels.
        self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3))
        self.emb_linear = MPConv(emb_channels, out_channels, kernel=())
        self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3))
        self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None
        self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None
        self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """Run the block on feature map ``x`` conditioned on embedding ``emb``."""
        # Main branch: optional resampling, then (encoder only) skip conv + channel norm.
        x = resample(x, f=self.resample_filter, mode=self.resample_mode)
        if self.flavor == "enc":
            if self.conv_skip is not None:
                x = self.conv_skip(x)
            x = normalize(x, dim=[1])

        # Residual branch, scaled per channel by the embedding.
        y = self.conv_res0(mp_silu(x))
        c = self.emb_linear(emb, gain=self.emb_gain) + 1
        y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype))
        if self.training and self.dropout:
            y = torch.nn.functional.dropout(y, p=self.dropout)
        y = self.conv_res1(y)

        # Decoder blocks change the channel count of the skip path only now.
        if self.flavor == "dec" and self.conv_skip is not None:
            x = self.conv_skip(x)
        x = mp_sum(x, y, t=self.res_balance)

        # Self-attention over the flattened spatial positions.
        if self.num_heads:
            y = self.attn_qkv(x)
            # -> (batch, heads, channels per head, 3 [q/k/v], positions)
            y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3])
            q, k, v = normalize(y, dim=[2]).unbind(3)
            w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3)
            y = torch.einsum("nhqk,nhck->nhcq", w, v)
            y = self.attn_proj(y.reshape(*x.shape))
            x = mp_sum(x, y, t=self.attn_balance)

        # In-place clamp of the activations to [-clip_act, clip_act].
        if self.clip_act is not None:
            x = x.clip_(-self.clip_act, self.clip_act)
        return x
167
+
168
+
169
class EDM2UNet(torch.nn.Module):
    """Encoder/decoder network built from `MPConv` and `Block` modules.

    Submodules live in two ``ModuleDict`` objects (``enc`` and ``dec``) whose
    keys encode the resolution and role (``"{res}x{res}_conv"``, ``"_down"``,
    ``"_up"``, ``"_in0"``, ``"_in1"``, ``"_block{idx}"``); `forward` relies on
    those names to decide how each entry is called.
    """

    def __init__(
        self,
        img_resolution: int,
        img_channels: int,
        label_dim: int,
        model_channels: int = 192,
        channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
        channel_mult_noise: Optional[int] = None,
        channel_mult_emb: Optional[int] = None,
        num_blocks: int = 3,
        attn_resolutions: Tuple[int, ...] = (16, 8),
        label_balance: float = 0.5,
        concat_balance: float = 0.5,
        **block_kwargs,
    ):
        super().__init__()
        # Channel counts per level, for the noise embedding and for the shared embedding.
        cblock = [model_channels * x for x in channel_mult]
        cnoise = model_channels * channel_mult_noise if channel_mult_noise is not None else cblock[0]
        cemb = model_channels * channel_mult_emb if channel_mult_emb is not None else max(cblock)
        self.label_balance = label_balance
        self.concat_balance = concat_balance
        # Zero-initialised gain applied to the output convolution in `forward`.
        self.out_gain = torch.nn.Parameter(torch.zeros([]))

        # Embedding layers; the label layer only exists for conditional models.
        self.emb_fourier = MPFourier(cnoise)
        self.emb_noise = MPConv(cnoise, cemb, kernel=())
        self.emb_label = MPConv(label_dim, cemb, kernel=()) if label_dim else None

        # Encoder. The extra input channel matches the constant ones channel
        # that `forward` concatenates to the input.
        self.enc = torch.nn.ModuleDict()
        cout = img_channels + 1
        for level, channels in enumerate(cblock):
            res = img_resolution >> level
            if level == 0:
                cin = cout
                cout = channels
                self.enc[f"{res}x{res}_conv"] = MPConv(cin, cout, kernel=(3, 3))
            else:
                self.enc[f"{res}x{res}_down"] = Block(cout, cout, cemb, flavor="enc", resample_mode="down", **block_kwargs)
            for idx in range(num_blocks):
                cin = cout
                cout = channels
                self.enc[f"{res}x{res}_block{idx}"] = Block(
                    cin,
                    cout,
                    cemb,
                    flavor="enc",
                    attention=(res in attn_resolutions),
                    **block_kwargs,
                )

        # Decoder. `skips` holds the output width of every encoder entry and is
        # consumed from the end, mirroring the order `forward` pops activations.
        self.dec = torch.nn.ModuleDict()
        skips = [block.out_channels for block in self.enc.values()]
        for level, channels in reversed(list(enumerate(cblock))):
            res = img_resolution >> level
            if level == len(cblock) - 1:
                self.dec[f"{res}x{res}_in0"] = Block(cout, cout, cemb, flavor="dec", attention=True, **block_kwargs)
                self.dec[f"{res}x{res}_in1"] = Block(cout, cout, cemb, flavor="dec", **block_kwargs)
            else:
                self.dec[f"{res}x{res}_up"] = Block(cout, cout, cemb, flavor="dec", resample_mode="up", **block_kwargs)
            for idx in range(num_blocks + 1):
                cin = cout + skips.pop()
                cout = channels
                self.dec[f"{res}x{res}_block{idx}"] = Block(
                    cin,
                    cout,
                    cemb,
                    flavor="dec",
                    attention=(res in attn_resolutions),
                    **block_kwargs,
                )

        self.out_conv = MPConv(cout, img_channels, kernel=(3, 3))

    def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the network.

        Args:
            x: Input feature map.
            noise_labels: 1-D tensor fed to the Fourier embedding.
            class_labels: Class conditioning; required when the model has a
                label embedding. Presumably one-hot rows (scaled below by the
                square root of their width) — TODO confirm against callers.

        Raises:
            ValueError: If the model is conditional and ``class_labels`` is ``None``.
        """
        # Embedding.
        emb = self.emb_noise(self.emb_fourier(noise_labels))
        if self.emb_label is not None:
            if class_labels is None:
                raise ValueError("class_labels are required for conditional EDM2UNet.")
            emb = mp_sum(emb, self.emb_label(class_labels * math.sqrt(class_labels.shape[1])), t=self.label_balance)
        emb = mp_silu(emb)

        # Encoder: append a constant ones channel, then record every output as a skip.
        x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1)
        skips = []
        for name, block in self.enc.items():
            # The stem "conv" entry is a plain MPConv and takes no embedding.
            x = block(x) if "conv" in name else block(x, emb)
            skips.append(x)

        # Decoder: only "block" entries consume a skip connection.
        for name, block in self.dec.items():
            if "block" in name:
                x = mp_cat(x, skips.pop(), t=self.concat_balance)
            x = block(x, emb)
        return self.out_conv(x, gain=self.out_gain)
261
+
262
+
263
@dataclass
class EDM2UNet2DOutput(BaseOutput):
    """Return container for the EDM2 UNet model wrapper."""

    # Primary model output tensor.
    sample: torch.Tensor
    # Optional secondary tensor; presumably a log-variance estimate, judging
    # by the name — TODO confirm against the producing forward().
    logvar: Optional[torch.Tensor] = None
267
+
268
+
269
+
270
+ _CONFIG_KEYS = (
271
+ "sample_size",
272
+ "in_channels",
273
+ "out_channels",
274
+ "num_class_embeds",
275
+ "use_fp16",
276
+ "sigma_data",
277
+ "logvar_channels",
278
+ "model_channels",
279
+ "channel_mult",
280
+ "channel_mult_noise",
281
+ "channel_mult_emb",
282
+ "num_blocks",
283
+ "attn_resolutions",
284
+ "label_balance",
285
+ "concat_balance",
286
+ "dropout",
287
+ "channels_per_head",
288
+ "res_balance",
289
+ "attn_balance",
290
+ "clip_act",
291
+ )
292
+
293
+
294
class EDM2UNet2DModel(ModelMixin, ConfigMixin):
    """EDM-preconditioned wrapper around `EDM2UNet` with a diffusers-style load/save API.

    `forward` returns the denoised sample `c_skip * x + c_out * F(c_in * x, c_noise)`, where
    `F` is the wrapped `EDM2UNet` and the coefficients are derived from `sigma` and
    `sigma_data`. An auxiliary Fourier + linear head produces an optional log-variance.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: int = 64,
        in_channels: int = 4,
        out_channels: int = 4,
        num_class_embeds: int = 0,
        use_fp16: bool = True,
        sigma_data: float = 0.5,
        logvar_channels: int = 128,
        model_channels: int = 192,
        channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
        channel_mult_noise: Optional[int] = None,
        channel_mult_emb: Optional[int] = None,
        num_blocks: int = 3,
        attn_resolutions: Tuple[int, ...] = (16, 8),
        label_balance: float = 0.5,
        concat_balance: float = 0.5,
        dropout: float = 0.0,
        channels_per_head: int = 64,
        res_balance: float = 0.3,
        attn_balance: float = 0.3,
        clip_act: Optional[float] = 256,
    ):
        """Store the configuration and build the inner U-Net and the log-variance head.

        `sample_size`, `in_channels` and `num_class_embeds` are passed to `EDM2UNet` as
        `img_resolution`, `img_channels` and `label_dim`; `model_channels` through
        `clip_act` are forwarded under their own names. `out_channels` is only stored.
        """
        super().__init__()
        # Mirror the config as plain attributes so callers can read e.g. `model.sample_size`.
        self.sample_size = sample_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_class_embeds = num_class_embeds
        self.use_fp16 = use_fp16
        self.sigma_data = sigma_data
        self.model_channels = model_channels
        self.channel_mult = channel_mult
        self.channel_mult_noise = channel_mult_noise
        self.channel_mult_emb = channel_mult_emb
        self.num_blocks = num_blocks
        self.attn_resolutions = attn_resolutions
        self.label_balance = label_balance
        self.concat_balance = concat_balance
        self.dropout = dropout
        self.channels_per_head = channels_per_head
        self.res_balance = res_balance
        self.attn_balance = attn_balance
        self.clip_act = clip_act
        self.unet = EDM2UNet(
            img_resolution=sample_size,
            img_channels=in_channels,
            label_dim=num_class_embeds,
            model_channels=model_channels,
            channel_mult=channel_mult,
            channel_mult_noise=channel_mult_noise,
            channel_mult_emb=channel_mult_emb,
            num_blocks=num_blocks,
            attn_resolutions=attn_resolutions,
            label_balance=label_balance,
            concat_balance=concat_balance,
            dropout=dropout,
            channels_per_head=channels_per_head,
            res_balance=res_balance,
            attn_balance=attn_balance,
            clip_act=clip_act,
        )
        self.logvar_fourier = MPFourier(logvar_channels)
        self.logvar_linear = MPConv(logvar_channels, 1, kernel=())

    def forward(
        self,
        sample: torch.Tensor,
        sigma: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        force_fp32: bool = False,
        return_logvar: bool = False,
        return_dict: bool = True,
    ) -> EDM2UNet2DOutput:
        """Denoise `sample` at noise level `sigma`.

        Args:
            sample: Noisy input; converted to float32 internally.
            sigma: Noise level(s); reshaped to `(-1, 1, 1, 1)` so it broadcasts over `sample`.
            class_labels: Class conditioning, reshaped to `(-1, num_class_embeds)`. Ignored
                when `num_class_embeds == 0`; replaced by all-zeros when omitted on a
                conditional model.
            force_fp32: Disable the fp16 path even if `use_fp16` is set.
            return_logvar: Also compute the log-variance head output.
            return_dict: Return an `EDM2UNet2DOutput` instead of a `(sample, logvar)` tuple.

        Returns:
            `EDM2UNet2DOutput(sample, logvar)` or the equivalent tuple; `logvar` is None
            unless `return_logvar` is True.
        """
        x = sample.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        if self.num_class_embeds == 0:
            class_labels = None
        else:
            if class_labels is None:
                class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device)
            class_labels = class_labels.to(torch.float32).reshape(-1, self.num_class_embeds)
        # fp16 is used only when enabled, not overridden, and the input lives on CUDA.
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") else torch.float32

        # Preconditioning coefficients.
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
        c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
        c_noise = sigma.flatten().log() / 4

        x_in = (c_in * x).to(dtype)
        f_x = self.unet(x_in, c_noise, class_labels)
        # Combine in float32 regardless of the network's compute dtype.
        d_x = c_skip * x + c_out * f_x.to(torch.float32)

        logvar = None
        if return_logvar:
            logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)

        if not return_dict:
            return (d_x, logvar)
        return EDM2UNet2DOutput(sample=d_x, logvar=logvar)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs):
        """Load the model from a local directory containing `config.json` and weights.

        Args:
            pretrained_model_name_or_path: Local directory (joined with `subfolder` if given).
            torch_dtype: Optional dtype the loaded model is cast to.
            **kwargs: Only `subfolder` is consumed; any other keyword is accepted and ignored.

        Returns:
            The model with weights loaded strictly, switched to evaluation mode.
        """
        subfolder = kwargs.pop("subfolder", None)
        model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path
        with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f:
            config = json.load(f)
        # Drop bookkeeping entries such as "_class_name" that __init__ does not accept.
        init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS}
        model = cls(**init_kwargs)
        weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors")
        if os.path.isfile(weight_file):
            from safetensors.torch import load_file

            state_dict = load_file(weight_file)
        else:
            # NOTE(security): torch.load unpickles the file; only load `.bin` checkpoints
            # from trusted sources (prefer the safetensors file above).
            state_dict = torch.load(os.path.join(model_dir, "diffusion_pytorch_model.bin"), map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        if torch_dtype is not None:
            model = model.to(dtype=torch_dtype)
        # A freshly constructed nn.Module is in training mode; match the diffusers
        # `from_pretrained` convention of returning models in evaluation mode.
        model.eval()
        return model

    def save_pretrained(self, save_directory: str, safe_serialization: bool = True):
        """Write `config.json` and the weights into `save_directory`.

        Args:
            save_directory: Target directory; created if missing.
            safe_serialization: Save as `diffusion_pytorch_model.safetensors` when True,
                otherwise as `diffusion_pytorch_model.bin` via `torch.save`.
        """
        os.makedirs(save_directory, exist_ok=True)
        stored = dict(getattr(self, "config", {}))
        config = {"_class_name": self.__class__.__name__}
        for key in _CONFIG_KEYS:
            # Prefer the registered config; fall back to the mirrored attribute.
            if key in stored:
                config[key] = stored[key]
            elif hasattr(self, key):
                config[key] = getattr(self, key)
        with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2, sort_keys=True)
            f.write("\n")
        state_dict = self.state_dict()
        if safe_serialization:
            from safetensors.torch import save_file

            save_file(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.safetensors"))
        else:
            torch.save(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.bin"))
edm2-img512-xl-fid/vae/config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.36.0",
4
+ "_name_or_path": "stabilityai/sd-vae-ft-mse",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 256,
9
+ 512,
10
+ 512
11
+ ],
12
+ "down_block_types": [
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D"
17
+ ],
18
+ "force_upcast": true,
19
+ "in_channels": 3,
20
+ "latent_channels": 4,
21
+ "latents_mean": null,
22
+ "latents_std": null,
23
+ "layers_per_block": 2,
24
+ "mid_block_add_attention": true,
25
+ "norm_num_groups": 32,
26
+ "out_channels": 3,
27
+ "sample_size": 256,
28
+ "scaling_factor": 0.18215,
29
+ "shift_factor": null,
30
+ "up_block_types": [
31
+ "UpDecoderBlock2D",
32
+ "UpDecoderBlock2D",
33
+ "UpDecoderBlock2D",
34
+ "UpDecoderBlock2D"
35
+ ],
36
+ "use_post_quant_conv": true,
37
+ "use_quant_conv": true
38
+ }
edm2-img512-xl-fid/vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815
3
+ size 334643276
edm2-img512-xs-fid/demo.png ADDED

Git LFS Details

  • SHA256: a5ceee02aab56e93c77b73e082ca5f952897a2bd98c1b78c1899f78845561785
  • Pointer size: 131 Bytes
  • Size of remote file: 376 kB
edm2-img512-xs-fid/model_index.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "EDM2Pipeline"
5
+ ],
6
+ "_diffusers_version": "0.31.0",
7
+ "scheduler": [
8
+ "diffusers",
9
+ "EDMEulerScheduler"
10
+ ],
11
+ "unet": [
12
+ "unet_edm2",
13
+ "EDM2UNet2DModel"
14
+ ],
15
+ "vae": [
16
+ "diffusers",
17
+ "AutoencoderKL"
18
+ ]
19
+ }
edm2-img512-xs-fid/pipeline.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hub custom pipeline: EDM2Pipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ import json
22
+ import os
23
+ from pathlib import Path
24
+ from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
25
+
26
+ import numpy as np
27
+ import torch
28
+
29
+ from diffusers.image_processor import VaeImageProcessor
30
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
31
+ from diffusers.utils import replace_example_docstring
32
+ from diffusers.utils.torch_utils import randn_tensor
33
+
34
+ EXAMPLE_DOC_STRING = """
35
+ Examples:
36
+ ```py
37
+ >>> from pathlib import Path
38
+ >>> import torch
39
+ >>> from diffusers import DiffusionPipeline
40
+
41
+ >>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve()
42
+ >>> pipe = DiffusionPipeline.from_pretrained(
43
+ ... str(model_dir),
44
+ ... local_files_only=True,
45
+ ... custom_pipeline=str(model_dir / "pipeline.py"),
46
+ ... trust_remote_code=True,
47
+ ... torch_dtype=torch.float32,
48
+ ... )
49
+ >>> pipe.to("cuda")
50
+
51
+ >>> generator = torch.Generator(device="cuda").manual_seed(42)
52
+ >>> image = pipe(
53
+ ... class_labels=207,
54
+ ... num_inference_steps=32,
55
+ ... guidance_scale=1.0,
56
+ ... generator=generator,
57
+ ... ).images[0]
58
+ >>> image.save("demo.png")
59
+ ```
60
+ """
61
+
62
# Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py).
# With per-channel constants s and m: scale = 0.5 / s and bias = -m * scale.
# EDM2Pipeline.decode_latents undoes the whitening with (x - bias) / scale.
_STABILITY_VAE_SCALE = np.float32(0.5) / np.asarray([4.17, 4.62, 3.71, 3.28], dtype=np.float32)
_STABILITY_VAE_BIAS = (
    np.float32(0.0) - np.asarray([5.81, 3.25, 0.12, -2.15], dtype=np.float32) * _STABILITY_VAE_SCALE
)
65
+
66
class EDM2Pipeline(DiffusionPipeline):
    r"""
    Pipeline for class-conditional image generation with EDM2
    ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)).

    Parameters:
        unet ([`EDM2UNet2DModel`]):
            Main magnitude-preserving U-Net with EDM preconditioning.
        scheduler ([`EDMEulerScheduler`]):
            Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in
            the pipeline because the UNet returns denoised latents rather than noise predictions.
        vae ([`AutoencoderKL`], *optional*):
            Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`.
        gnet ([`EDM2UNet2DModel`], *optional*):
            Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`).
        id2label (`dict[int, str]`, *optional*):
            ImageNet class id to English label mapping.
    """

    model_cpu_offload_seq = "unet->gnet->vae"
    _optional_components = ["vae", "gnet"]

    def __init__(
        self,
        unet,
        scheduler,
        vae=None,
        gnet=None,
        id2label: Optional[Dict[Union[int, str], str]] = None,
    ) -> None:
        """Register the modules and prepare label lookup tables and the image processor."""
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet)
        self._id2label = self._normalize_id2label(id2label)
        self.labels = self._build_label2id(self._id2label)
        # When no mapping was passed in, it is looked up lazily from model_index.json.
        self._labels_loaded_from_model_index = bool(self._id2label)
        # Without a VAE the U-Net operates directly at image resolution.
        self.vae_scale_factor = 8 if self.vae is not None else 1
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)

    @staticmethod
    def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
        """Return a copy of `id2label` with integer keys (empty dict for None/empty input)."""
        if not id2label:
            return {}
        return {int(key): value for key, value in id2label.items()}

    @staticmethod
    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
        """Invert `id2label`, splitting each value on commas into individual synonyms.

        The result is sorted by label; a synonym shared by several classes maps to the
        class that appears last in `id2label`.
        """
        label2id: Dict[str, int] = {}
        for class_id, value in id2label.items():
            for synonym in value.split(","):
                synonym = synonym.strip()
                if synonym:
                    label2id[synonym] = int(class_id)
        return dict(sorted(label2id.items()))

    def _ensure_labels_loaded(self) -> None:
        """Lazily populate the label tables from `model_index.json` if not yet loaded."""
        if self._labels_loaded_from_model_index:
            return
        loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
        if loaded:
            self._id2label = loaded
            self.labels = self._build_label2id(self._id2label)
            self._labels_loaded_from_model_index = True

    @staticmethod
    def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
        """Read the `id2label` dict from `<variant_path>/model_index.json`.

        Returns an empty dict when the path is missing, the file does not exist, or the
        file has no dict-valued `id2label` entry.
        """
        if not variant_path:
            return {}
        model_index_path = Path(variant_path).resolve() / "model_index.json"
        if not model_index_path.is_file():
            return {}
        raw = json.loads(model_index_path.read_text(encoding="utf-8"))
        id2label = raw.get("id2label")
        if not isinstance(id2label, dict):
            return {}
        return {int(key): value for key, value in id2label.items()}

    @property
    def id2label(self) -> Dict[int, str]:
        r"""ImageNet class id to English label string (comma-separated synonyms)."""
        self._ensure_labels_loaded()
        return self._id2label

    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
        r"""
        Map ImageNet label strings to class ids.

        Args:
            label (`str` or `list[str]`):
                One or more English label strings that match entries in `id2label`.

        Returns:
            `list[int]`: Class ids in the same order as the requested labels.

        Raises:
            ValueError: If no labels are available or any requested label is unknown.
        """
        self._ensure_labels_loaded()
        if not self.labels:
            raise ValueError("No English labels loaded. Add `id2label` to model_index.json.")
        labels = [label] if isinstance(label, str) else list(label)
        missing = [item for item in labels if item not in self.labels]
        if missing:
            preview = ", ".join(list(self.labels.keys())[:8])
            raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
        return [self.labels[item] for item in labels]

    def _default_image_size(self) -> int:
        """Return the native output size: U-Net `sample_size` times `vae_scale_factor`."""
        latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64)))
        return latent_size * self.vae_scale_factor

    def check_inputs(
        self,
        height: int,
        width: int,
        num_inference_steps: int,
        guidance_scale: float,
        output_type: str,
    ) -> None:
        """Validate call arguments; raises `ValueError` on the first violation.

        Only the native square resolution is accepted, and `guidance_scale > 1.0`
        requires a guiding network.
        """
        if num_inference_steps < 1:
            raise ValueError("num_inference_steps must be >= 1.")
        if guidance_scale < 1.0:
            raise ValueError("guidance_scale must be >= 1.0.")
        if guidance_scale > 1.0 and self.gnet is None:
            raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).")
        if output_type not in {"pil", "np", "pt", "latent"}:
            raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")

        native_size = self._default_image_size()
        if height != native_size or width != native_size:
            raise ValueError(
                f"EDM2 expects native resolution height=width={native_size}. "
                f"Got height={height}, width={width}."
            )

    def _normalize_class_labels(
        self,
        class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]],
        batch_size: int,
        device: torch.device,
    ) -> Optional[torch.Tensor]:
        """Convert any supported `class_labels` form into a float32 one-hot batch.

        Returns None for unconditional models (`num_class_embeds == 0`), random classes
        when `class_labels` is None, and a 2-D tensor unchanged (apart from device/dtype)
        when one is supplied. A single index is repeated to fill the batch.

        Raises:
            ValueError: On batch-size mismatch, unknown label strings, or class indices
                outside `[-label_dim, label_dim)`.
        """
        label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0)))
        if label_dim == 0:
            return None
        if class_labels is None:
            indices = torch.randint(label_dim, size=(batch_size,), device=device)
            return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]

        # Resolve English label strings to integer class ids first.
        if isinstance(class_labels, str):
            class_labels = self.get_label_ids(class_labels)[0]
        elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str):
            class_labels = self.get_label_ids(list(class_labels))

        if isinstance(class_labels, int):
            indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long)
        elif isinstance(class_labels, torch.Tensor):
            if class_labels.ndim == 2:
                # Already a (batch, label_dim)-style tensor: pass it through as-is.
                labels = class_labels.to(device=device, dtype=torch.float32)
                if labels.shape[0] != batch_size:
                    raise ValueError(f"class_labels batch must match batch_size={batch_size}.")
                return labels
            indices = class_labels.to(device=device, dtype=torch.long).flatten()
        else:
            indices = torch.tensor(list(class_labels), device=device, dtype=torch.long)

        if indices.numel() == 1 and batch_size > 1:
            indices = indices.repeat(batch_size)
        if indices.numel() != batch_size:
            raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.")
        # Validate before indexing: an out-of-range index would otherwise surface as an
        # opaque IndexError on CPU or a device-side assert on CUDA. Negative indices that
        # are valid for Python-style wrap-around keep working as before.
        if indices.numel() > 0:
            lowest, highest = int(indices.min()), int(indices.max())
            if lowest < -label_dim or highest >= label_dim:
                raise ValueError(
                    f"class_labels must be class indices below {label_dim}; "
                    f"got values in [{lowest}, {highest}]."
                )
        return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]

    def prepare_latents(
        self,
        batch_size: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: torch.device,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]],
    ) -> torch.Tensor:
        """Draw unit-variance float32 noise of shape `(batch, in_channels, size, size)`.

        `size` is `height // vae_scale_factor`; `width` and `dtype` are accepted for
        signature compatibility but not used (latents are always square and float32).
        """
        in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4)))
        latent_size = height // self.vae_scale_factor
        return randn_tensor(
            (batch_size, in_channels, latent_size, latent_size),
            generator=generator,
            device=device,
            dtype=torch.float32,
        )

    def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
        """Convert sampler output to the requested `output_type`.

        `"latent"` returns `latents` untouched. Without a VAE the tensor is mapped to
        `[0, 1]` via `(x * 127.5 + 128) / 255` with clipping. With a VAE, 4-channel
        latents are first un-whitened with the Stability VAE scale/bias constants, then
        decoded and clamped to `[0, 1]`.
        """
        if output_type == "latent":
            return latents

        in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 3)))
        if self.vae is None:
            image = (latents.to(torch.float32) * 127.5 + 128).clip(0, 255) / 255.0
            return self.image_processor.postprocess(image, output_type=output_type)

        if in_channels == 4:
            x = latents.to(torch.float32)
            scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
            bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
            x = (x - bias) / scale
        else:
            x = latents.to(torch.float32)

        vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype
        image = self.vae.decode(x.to(dtype=vae_dtype)).sample.to(torch.float32).clamp(0, 1)

        return self.image_processor.postprocess(image, output_type=output_type)

    @staticmethod
    def _apply_autoguidance(
        main: torch.Tensor,
        ref: torch.Tensor,
        guidance_scale: float,
    ) -> torch.Tensor:
        """Blend main and guiding outputs: `ref + guidance_scale * (main - ref)`."""
        return ref.lerp(main, guidance_scale)

    @staticmethod
    def _sample_edm2_heun(
        denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        noise: torch.Tensor,
        sigmas: torch.Tensor,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        progress_bar: Optional[Callable[[Iterable], Iterable]] = None,
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0).

        Args:
            denoise_fn: Maps `(x, sigma)` to the denoised estimate of `x`.
            noise: Initial unit-variance noise; scaled by `sigmas[0]` to start sampling.
            sigmas: Noise levels visited in order; consecutive entries form the steps.
            generator: Unused; kept for signature compatibility.
            progress_bar: Optional wrapper applied to the step iterable.
            dtype: Dtype the initial noise is cast to.

        Returns:
            The sample after the final step.
        """
        x_next = noise.to(dtype) * sigmas[0]

        sigma_pairs = list(zip(sigmas[:-1], sigmas[1:]))
        # Count the steps before wrapping: a progress wrapper need not support len().
        num_steps = len(sigma_pairs)
        steps = progress_bar(sigma_pairs) if progress_bar is not None else sigma_pairs

        for i, (sigma_cur, sigma_next) in enumerate(steps):
            x_hat, sigma_hat = x_next, sigma_cur
            # Euler step.
            d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat
            x_next = x_hat + (sigma_next - sigma_hat) * d_cur
            # Second-order correction, skipped on the last step.
            if i < num_steps - 1:
                d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next
                x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime)
        return x_next

    @torch.inference_mode()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None,
        batch_size: int = 1,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 32,
        guidance_scale: float = 1.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: str = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Generate class-conditional images with EDM2.

        Args:
            class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*):
                ImageNet class indices, English label strings, or one-hot float tensors.
                Random classes are sampled when omitted on conditional models.
            batch_size (`int`, defaults to `1`):
                Number of images to generate.
            height (`int`, *optional*):
                Output height in pixels. Defaults to the pretrained native resolution.
            width (`int`, *optional*):
                Output width in pixels. Defaults to the pretrained native resolution.
            num_inference_steps (`int`, defaults to `32`):
                Number of EDM2 Heun steps (NVlabs default).
            guidance_scale (`float`, defaults to `1.0`):
                Autoguidance strength. Values above `1.0` blend the main net with `gnet`
                via `gnet_output.lerp(unet_output, guidance_scale)`.
            generator (`torch.Generator`, *optional*):
                RNG for reproducibility.
            output_type (`str`, defaults to `"pil"`):
                `"pil"`, `"np"`, `"pt"`, or `"latent"`.
            return_dict (`bool`, defaults to `True`):
                Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if True.

        Returns:
            [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
                The output object when `return_dict` is True, otherwise `(images, latents)`.

        Examples:
            <!-- this section is replaced by replace_example_docstring -->
        """
        default_size = self._default_image_size()
        height = int(height or default_size)
        width = int(width or default_size)
        self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type)

        device = self._execution_device
        dtype = self.unet.dtype
        labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
        noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)

        def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
            # The sampler passes a scalar sigma; broadcast it to one value per sample.
            sigma_batch = sigma.reshape(1).expand(batch_size)
            main = self.unet(
                sample=x,
                sigma=sigma_batch,
                class_labels=labels,
                force_fp32=True,
            ).sample
            if guidance_scale == 1.0 or self.gnet is None:
                return main.to(torch.float32)
            ref = self.gnet(
                sample=x,
                sigma=sigma_batch,
                class_labels=labels,
                force_fp32=True,
            ).sample
            return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32)

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        latents = self._sample_edm2_heun(
            denoise_fn=denoise_fn,
            noise=noise,
            sigmas=self.scheduler.sigmas.to(device),
            generator=generator,
            progress_bar=self.progress_bar,
            dtype=torch.float32,
        )

        image = self.decode_latents(latents, output_type=output_type)
        # Release accelerate offload hooks so models return to the offload device
        # when `enable_model_cpu_offload()` is active; a no-op otherwise.
        self.maybe_free_model_hooks()
        if not return_dict:
            return (image, latents)
        return ImagePipelineOutput(images=image)

    @classmethod
    def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None):
        """Best-effort loader for the pipeline's VAE.

        Tries `<path>/vae` first; if that directory is absent, falls back to the hub id
        stored in `<path>/vae_pretrained_model_name_or_path.txt`.

        Returns:
            The loaded `AutoencoderKL`, or None when no source is available or the
            local `vae` directory cannot be loaded.
        """
        import warnings

        # Imported here because the module-level imports do not bring in AutoencoderKL;
        # without this the name lookup below fails with NameError.
        from diffusers import AutoencoderKL

        vae_dir = os.path.join(pretrained_model_name_or_path, "vae")
        if os.path.isdir(vae_dir):
            try:
                return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype)
            except Exception as err:
                # Loading stays best-effort, but the failure is no longer silent.
                warnings.warn(f"Could not load VAE from '{vae_dir}': {err}", stacklevel=2)
                return None

        vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt")
        if os.path.isfile(vae_hint):
            with open(vae_hint, "r", encoding="utf-8") as f:
                hub_id = f.read().strip()
            if hub_id:
                return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype)
        return None