Unconditional Image Generation
Diffusers
Safetensors
English
edm2
image-generation
class-conditional
imagenet
Instructions to use BiliSakura/EDM2-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/EDM2-diffusers with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("BiliSakura/EDM2-diffusers", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +7 -0
- README.md +199 -0
- edm2-img512-l-dino/demo.png +3 -0
- edm2-img512-l-dino/model_index.json +19 -0
- edm2-img512-l-dino/pipeline.py +406 -0
- edm2-img512-l-dino/scheduler/scheduler_config.json +11 -0
- edm2-img512-l-dino/unet/config.json +31 -0
- edm2-img512-l-dino/unet/diffusion_pytorch_model.safetensors +3 -0
- edm2-img512-l-dino/unet/unet_edm2.py +434 -0
- edm2-img512-l-dino/vae/config.json +38 -0
- edm2-img512-l-dino/vae/diffusion_pytorch_model.safetensors +3 -0
- edm2-img512-l-fid/generator_test.png +3 -0
- edm2-img512-l-fid/model_index.json +19 -0
- edm2-img512-l-fid/pipeline.py +406 -0
- edm2-img512-l-fid/scheduler/scheduler_config.json +11 -0
- edm2-img512-l-fid/unet/config.json +31 -0
- edm2-img512-l-fid/unet/diffusion_pytorch_model.safetensors +3 -0
- edm2-img512-l-fid/unet/unet_edm2.py +434 -0
- edm2-img512-l-fid/vae/config.json +38 -0
- edm2-img512-l-fid/vae/diffusion_pytorch_model.safetensors +3 -0
- edm2-img512-m-fid/demo.png +3 -0
- edm2-img512-m-fid/model_index.json +19 -0
- edm2-img512-m-fid/pipeline.py +406 -0
- edm2-img512-m-fid/scheduler/scheduler_config.json +11 -0
- edm2-img512-m-fid/unet/config.json +31 -0
- edm2-img512-m-fid/unet/diffusion_pytorch_model.safetensors +3 -0
- edm2-img512-m-fid/unet/unet_edm2.py +434 -0
- edm2-img512-m-fid/vae/config.json +38 -0
- edm2-img512-m-fid/vae/diffusion_pytorch_model.safetensors +3 -0
- edm2-img512-s-fid/demo.png +3 -0
- edm2-img512-s-fid/model_index.json +19 -0
- edm2-img512-s-fid/pipeline.py +406 -0
- edm2-img512-s-fid/scheduler/scheduler_config.json +11 -0
- edm2-img512-s-fid/unet/config.json +31 -0
- edm2-img512-s-fid/unet/diffusion_pytorch_model.safetensors +3 -0
- edm2-img512-s-fid/unet/unet_edm2.py +434 -0
- edm2-img512-s-fid/vae/config.json +38 -0
- edm2-img512-s-fid/vae/diffusion_pytorch_model.safetensors +3 -0
- edm2-img512-xl-fid/demo.png +3 -0
- edm2-img512-xl-fid/model_index.json +19 -0
- edm2-img512-xl-fid/pipeline.py +406 -0
- edm2-img512-xl-fid/scheduler/scheduler_config.json +11 -0
- edm2-img512-xl-fid/unet/config.json +31 -0
- edm2-img512-xl-fid/unet/diffusion_pytorch_model.safetensors +3 -0
- edm2-img512-xl-fid/unet/unet_edm2.py +434 -0
- edm2-img512-xl-fid/vae/config.json +38 -0
- edm2-img512-xl-fid/vae/diffusion_pytorch_model.safetensors +3 -0
- edm2-img512-xs-fid/demo.png +3 -0
- edm2-img512-xs-fid/model_index.json +19 -0
- edm2-img512-xs-fid/pipeline.py +406 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
edm2-img512-l-dino/demo.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
edm2-img512-l-fid/generator_test.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
edm2-img512-m-fid/demo.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
edm2-img512-s-fid/demo.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
edm2-img512-xl-fid/demo.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
edm2-img512-xs-fid/demo.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
edm2-img512-xxl-fid/demo.png filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: cc-by-nc-sa-4.0
|
| 3 |
+
library_name: diffusers
|
| 4 |
+
pipeline_tag: unconditional-image-generation
|
| 5 |
+
tags:
|
| 6 |
+
- diffusers
|
| 7 |
+
- edm2
|
| 8 |
+
- image-generation
|
| 9 |
+
- class-conditional
|
| 10 |
+
- imagenet
|
| 11 |
+
inference: true
|
| 12 |
+
widget:
|
| 13 |
+
- output:
|
| 14 |
+
url: edm2-img512-xxl-fid/demo.png
|
| 15 |
+
language:
|
| 16 |
+
- en
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
# EDM2-diffusers
|
| 20 |
+
|
| 21 |
+
Diffusers-ready checkpoints for **EDM2** ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)),
|
| 22 |
+
converted from [NVlabs/edm2](https://github.com/NVlabs/edm2) post-hoc reconstructions.
|
| 23 |
+
|
| 24 |
+
Official source weights: `https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions/`
|
| 25 |
+
|
| 26 |
+
This root folder is a model collection that contains:
|
| 27 |
+
|
| 28 |
+
- `edm2-img512-xs-fid`
|
| 29 |
+
- `edm2-img512-s-fid`
|
| 30 |
+
- `edm2-img512-m-fid`
|
| 31 |
+
- `edm2-img512-l-fid`
|
| 32 |
+
- `edm2-img512-l-dino`
|
| 33 |
+
- `edm2-img512-xl-fid`
|
| 34 |
+
- `edm2-img512-xxl-fid`
|
| 35 |
+
|
| 36 |
+
Each subfolder is a self-contained Diffusers model repo with:
|
| 37 |
+
|
| 38 |
+
- `pipeline.py`
|
| 39 |
+
- `unet/unet_edm2.py`
|
| 40 |
+
- `scheduler/scheduler_config.json` (`EDMEulerScheduler`)
|
| 41 |
+
- `unet/diffusion_pytorch_model.safetensors`
|
| 42 |
+
- `vae/diffusion_pytorch_model.safetensors`
|
| 43 |
+
|
| 44 |
+
## Demo
|
| 45 |
+
|
| 46 |
+

|
| 47 |
+
|
| 48 |
+
Class-conditional sample (ImageNet class **207**, golden retriever), EDM2-XXL at 512×512, 32 steps, guidance 1.0, seed 42.
|
| 49 |
+
|
| 50 |
+
## Model Paths
|
| 51 |
+
|
| 52 |
+
Use paths relative to this root README:
|
| 53 |
+
|
| 54 |
+
| Model | NVlabs preset | FID | Local path |
|
| 55 |
+
| --- | --- | ---: | --- |
|
| 56 |
+
| EDM2-XS | `edm2-img512-xs-fid` | 3.53 | `./edm2-img512-xs-fid` |
|
| 57 |
+
| EDM2-S | `edm2-img512-s-fid` | 2.56 | `./edm2-img512-s-fid` |
|
| 58 |
+
| EDM2-M | `edm2-img512-m-fid` | 2.25 | `./edm2-img512-m-fid` |
|
| 59 |
+
| EDM2-L | `edm2-img512-l-fid` | 2.06 | `./edm2-img512-l-fid` |
|
| 60 |
+
| EDM2-L (DINO) | `edm2-img512-l-dino` | — | `./edm2-img512-l-dino` |
|
| 61 |
+
| EDM2-XL | `edm2-img512-xl-fid` | 1.96 | `./edm2-img512-xl-fid` |
|
| 62 |
+
| EDM2-XXL | `edm2-img512-xxl-fid` | 1.91 | `./edm2-img512-xxl-fid` |
|
| 63 |
+
|
| 64 |
+
## Inference Demo (Diffusers)
|
| 65 |
+
|
| 66 |
+
### 1) Load a local subfolder checkpoint
|
| 67 |
+
|
| 68 |
+
```python
|
| 69 |
+
from pathlib import Path
|
| 70 |
+
import torch
|
| 71 |
+
from diffusers import DiffusionPipeline
|
| 72 |
+
|
| 73 |
+
model_dir = Path("./edm2-img512-xxl-fid") # change to any path in the table above
|
| 74 |
+
pipe = DiffusionPipeline.from_pretrained(
|
| 75 |
+
str(model_dir),
|
| 76 |
+
local_files_only=True,
|
| 77 |
+
trust_remote_code=True,
|
| 78 |
+
torch_dtype=torch.bfloat16,
|
| 79 |
+
).to("cuda")
|
| 80 |
+
|
| 81 |
+
generator = torch.Generator(device="cuda").manual_seed(42)
|
| 82 |
+
image = pipe(
|
| 83 |
+
class_labels=207, # golden retriever (ImageNet id); omit for random class
|
| 84 |
+
num_inference_steps=32,
|
| 85 |
+
guidance_scale=1.0, # >1.0 requires a gnet/ checkpoint
|
| 86 |
+
generator=generator,
|
| 87 |
+
).images[0]
|
| 88 |
+
image.save("demo.png")
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
Official inference defaults (`generate_images.py`): `num_steps=32`, `sigma_min=0.002`,
|
| 92 |
+
`sigma_max=80`, `rho=7`, `guidance=1.0` (no gnet), `S_churn=0`. Heun sampling runs in
|
| 93 |
+
float32 internally even when UNet/VAE weights are loaded in bf16/fp16.
|
| 94 |
+
|
| 95 |
+
Guided presets require a converted `gnet/` folder and `guidance_scale` matching the
|
| 96 |
+
NVlabs preset.
|
| 97 |
+
|
| 98 |
+
### 2) Convert a legacy `.pkl`
|
| 99 |
+
|
| 100 |
+
```bash
|
| 101 |
+
python scripts/convert_edm2_to_diffusers.py \
|
| 102 |
+
--checkpoint models/BiliSakura/EDM2-diffusers/edm2-img512-xs-2147483-0.135.pkl \
|
| 103 |
+
--output models/BiliSakura/EDM2-diffusers
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
Creates `edm2-img512-xs-fid/` automatically from the NVlabs preset mapping.
|
| 107 |
+
|
| 108 |
+
## Checkpoint preset mapping
|
| 109 |
+
|
| 110 |
+
Maps NVlabs `--preset=...` names from [`generate_images.py`](https://github.com/NVlabs/edm2/blob/main/generate_images.py)
|
| 111 |
+
to source pickle filenames and local Diffusers directories.
|
| 112 |
+
|
| 113 |
+
### EDM2 paper — ImageNet-512 (conditional)
|
| 114 |
+
|
| 115 |
+
| NVlabs preset | Source `.pkl` (net) | Diffusers dir | Metric |
|
| 116 |
+
| --- | --- | --- | --- |
|
| 117 |
+
| `edm2-img512-xs-fid` | `edm2-img512-xs-2147483-0.135.pkl` | `edm2-img512-xs-fid/` | FID 3.53 |
|
| 118 |
+
| `edm2-img512-xs-dino` | `edm2-img512-xs-2147483-0.200.pkl` | — | FD<sub>DINOv2</sub> 103.39 |
|
| 119 |
+
| `edm2-img512-s-fid` | `edm2-img512-s-2147483-0.130.pkl` | `edm2-img512-s-fid/` | FID 2.56 |
|
| 120 |
+
| `edm2-img512-s-dino` | `edm2-img512-s-2147483-0.190.pkl` | — | FD<sub>DINOv2</sub> 68.64 |
|
| 121 |
+
| `edm2-img512-m-fid` | `edm2-img512-m-2147483-0.100.pkl` | `edm2-img512-m-fid/` | FID 2.25 |
|
| 122 |
+
| `edm2-img512-m-dino` | `edm2-img512-m-2147483-0.155.pkl` | — | FD<sub>DINOv2</sub> 58.44 |
|
| 123 |
+
| `edm2-img512-l-fid` | `edm2-img512-l-1879048-0.085.pkl` | `edm2-img512-l-fid/` | FID 2.06 |
|
| 124 |
+
| `edm2-img512-l-dino` | `edm2-img512-l-1879048-0.155.pkl` | `edm2-img512-l-dino/` | FD<sub>DINOv2</sub> 52.25 |
|
| 125 |
+
| `edm2-img512-xl-fid` | `edm2-img512-xl-1342177-0.085.pkl` | `edm2-img512-xl-fid/` | FID 1.96 |
|
| 126 |
+
| `edm2-img512-xl-dino` | `edm2-img512-xl-1342177-0.155.pkl` | — | FD<sub>DINOv2</sub> 45.96 |
|
| 127 |
+
| `edm2-img512-xxl-fid` | `edm2-img512-xxl-0939524-0.070.pkl` | `edm2-img512-xxl-fid/` | FID 1.91 |
|
| 128 |
+
| `edm2-img512-xxl-dino` | `edm2-img512-xxl-0939524-0.150.pkl` | — | FD<sub>DINOv2</sub> 42.84 |
|
| 129 |
+
|
| 130 |
+
### EDM2 paper — ImageNet-64 (conditional)
|
| 131 |
+
|
| 132 |
+
| NVlabs preset | Source `.pkl` (net) | Metric |
|
| 133 |
+
| --- | --- | --- |
|
| 134 |
+
| `edm2-img64-s-fid` | `edm2-img64-s-1073741-0.075.pkl` | FID 1.58 |
|
| 135 |
+
| `edm2-img64-m-fid` | `edm2-img64-m-2147483-0.060.pkl` | FID 1.43 |
|
| 136 |
+
| `edm2-img64-l-fid` | `edm2-img64-l-1073741-0.040.pkl` | FID 1.33 |
|
| 137 |
+
| `edm2-img64-xl-fid` | `edm2-img64-xl-0671088-0.040.pkl` | FID 1.33 |
|
| 138 |
+
|
| 139 |
+
### EDM2 paper — classifier-free guidance (ImageNet-512)
|
| 140 |
+
|
| 141 |
+
Use `guidance_scale` below and include the converted `gnet/` checkpoint.
|
| 142 |
+
|
| 143 |
+
| NVlabs preset | Source `.pkl` (net) | Source `.pkl` (gnet) | Guidance | Metric |
|
| 144 |
+
| --- | --- | --- | ---: | --- |
|
| 145 |
+
| `edm2-img512-xs-guid-fid` | `edm2-img512-xs-2147483-0.045.pkl` | `edm2-img512-xs-uncond-2147483-0.045.pkl` | 1.40 | FID 2.91 |
|
| 146 |
+
| `edm2-img512-xs-guid-dino` | `edm2-img512-xs-2147483-0.150.pkl` | `edm2-img512-xs-uncond-2147483-0.150.pkl` | 1.70 | FD<sub>DINOv2</sub> 79.94 |
|
| 147 |
+
| `edm2-img512-s-guid-fid` | `edm2-img512-s-2147483-0.025.pkl` | `edm2-img512-xs-uncond-2147483-0.025.pkl` | 1.40 | FID 2.23 |
|
| 148 |
+
| `edm2-img512-s-guid-dino` | `edm2-img512-s-2147483-0.085.pkl` | `edm2-img512-xs-uncond-2147483-0.085.pkl` | 1.90 | FD<sub>DINOv2</sub> 52.32 |
|
| 149 |
+
| `edm2-img512-m-guid-fid` | `edm2-img512-m-2147483-0.030.pkl` | `edm2-img512-xs-uncond-2147483-0.030.pkl` | 1.20 | FID 2.01 |
|
| 150 |
+
| `edm2-img512-m-guid-dino` | `edm2-img512-m-2147483-0.015.pkl` | `edm2-img512-xs-uncond-2147483-0.015.pkl` | 2.00 | FD<sub>DINOv2</sub> 41.98 |
|
| 151 |
+
| `edm2-img512-l-guid-fid` | `edm2-img512-l-1879048-0.015.pkl` | `edm2-img512-xs-uncond-2147483-0.015.pkl` | 1.20 | FID 1.88 |
|
| 152 |
+
| `edm2-img512-l-guid-dino` | `edm2-img512-l-1879048-0.035.pkl` | `edm2-img512-xs-uncond-2147483-0.035.pkl` | 1.70 | FD<sub>DINOv2</sub> 38.20 |
|
| 153 |
+
| `edm2-img512-xl-guid-fid` | `edm2-img512-xl-1342177-0.020.pkl` | `edm2-img512-xs-uncond-2147483-0.020.pkl` | 1.20 | FID 1.85 |
|
| 154 |
+
| `edm2-img512-xl-guid-dino` | `edm2-img512-xl-1342177-0.030.pkl` | `edm2-img512-xs-uncond-2147483-0.030.pkl` | 1.70 | FD<sub>DINOv2</sub> 35.67 |
|
| 155 |
+
| `edm2-img512-xxl-guid-fid` | `edm2-img512-xxl-0939524-0.015.pkl` | `edm2-img512-xs-uncond-2147483-0.015.pkl` | 1.20 | FID 1.81 |
|
| 156 |
+
| `edm2-img512-xxl-guid-dino` | `edm2-img512-xxl-0939524-0.015.pkl` | `edm2-img512-xs-uncond-2147483-0.015.pkl` | 1.70 | FD<sub>DINOv2</sub> 33.09 |
|
| 157 |
+
|
| 158 |
+
### Autoguidance paper
|
| 159 |
+
|
| 160 |
+
| NVlabs preset | Source `.pkl` (net) | Source `.pkl` (gnet) | Guidance | Metric |
|
| 161 |
+
| --- | --- | --- | ---: | --- |
|
| 162 |
+
| `edm2-img512-s-autog-fid` | `edm2-img512-s-2147483-0.070.pkl` | `edm2-img512-xs-0134217-0.125.pkl` | 2.10 | FID 1.34 |
|
| 163 |
+
| `edm2-img512-s-autog-dino` | `edm2-img512-s-2147483-0.120.pkl` | `edm2-img512-xs-0134217-0.165.pkl` | 2.45 | FD<sub>DINOv2</sub> 36.67 |
|
| 164 |
+
| `edm2-img512-xxl-autog-fid` | `edm2-img512-xxl-0939524-0.075.pkl` | `edm2-img512-m-0268435-0.155.pkl` | 2.05 | FID 1.25 |
|
| 165 |
+
| `edm2-img512-xxl-autog-dino` | `edm2-img512-xxl-0939524-0.130.pkl` | `edm2-img512-m-0268435-0.205.pkl` | 2.30 | FD<sub>DINOv2</sub> 24.18 |
|
| 166 |
+
| `edm2-img512-s-uncond-autog-fid` | `edm2-img512-s-uncond-2147483-0.070.pkl` | `edm2-img512-xs-uncond-0134217-0.110.pkl` | 2.85 | FID 3.86 |
|
| 167 |
+
| `edm2-img512-s-uncond-autog-dino` | `edm2-img512-s-uncond-2147483-0.090.pkl` | `edm2-img512-xs-uncond-0134217-0.125.pkl` | 2.90 | FD<sub>DINOv2</sub> 90.39 |
|
| 168 |
+
| `edm2-img64-s-autog-fid` | `edm2-img64-s-1073741-0.045.pkl` | `edm2-img64-xs-0134217-0.110.pkl` | 1.70 | FID 1.01 |
|
| 169 |
+
| `edm2-img64-s-autog-dino` | `edm2-img64-s-1073741-0.105.pkl` | `edm2-img64-xs-0134217-0.175.pkl` | 2.20 | FD<sub>DINOv2</sub> 31.85 |
|
| 170 |
+
|
| 171 |
+
### NVlabs preset shorthand
|
| 172 |
+
|
| 173 |
+
```text
|
| 174 |
+
# EDM2 paper
|
| 175 |
+
edm2-img512-{xs|s|m|l|xl|xxl}-{fid|dino}
|
| 176 |
+
edm2-img64-{s|m|l|xl}-fid
|
| 177 |
+
edm2-img512-{xs|s|m|l|xl|xxl}-guid-{fid|dino}
|
| 178 |
+
|
| 179 |
+
# Autoguidance paper
|
| 180 |
+
edm2-img512-{s|xxl}-autog-{fid|dino}
|
| 181 |
+
edm2-img512-s-uncond-autog-{fid|dino}
|
| 182 |
+
edm2-img64-s-autog-{fid|dino}
|
| 183 |
+
```
|
| 184 |
+
|
| 185 |
+
Example NVlabs command:
|
| 186 |
+
|
| 187 |
+
```bash
|
| 188 |
+
python generate_images.py --preset=edm2-img512-s-guid-dino --outdir=out
|
| 189 |
+
```
|
| 190 |
+
|
| 191 |
+
Equivalent expanded form:
|
| 192 |
+
|
| 193 |
+
```bash
|
| 194 |
+
python generate_images.py \
|
| 195 |
+
--net=https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions/edm2-img512-s-2147483-0.085.pkl \
|
| 196 |
+
--gnet=https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions/edm2-img512-xs-uncond-2147483-0.085.pkl \
|
| 197 |
+
--guidance=1.9 \
|
| 198 |
+
--outdir=out
|
| 199 |
+
```
|
edm2-img512-l-dino/demo.png
ADDED
|
Git LFS Details
|
edm2-img512-l-dino/model_index.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": [
|
| 3 |
+
"pipeline",
|
| 4 |
+
"EDM2Pipeline"
|
| 5 |
+
],
|
| 6 |
+
"_diffusers_version": "0.31.0",
|
| 7 |
+
"scheduler": [
|
| 8 |
+
"diffusers",
|
| 9 |
+
"EDMEulerScheduler"
|
| 10 |
+
],
|
| 11 |
+
"unet": [
|
| 12 |
+
"unet_edm2",
|
| 13 |
+
"EDM2UNet2DModel"
|
| 14 |
+
],
|
| 15 |
+
"vae": [
|
| 16 |
+
"diffusers",
|
| 17 |
+
"AutoencoderKL"
|
| 18 |
+
]
|
| 19 |
+
}
|
edm2-img512-l-dino/pipeline.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hub custom pipeline: EDM2Pipeline.
|
| 2 |
+
Load with native Hugging Face diffusers and trust_remote_code=True.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
|
| 21 |
+
import json
|
| 22 |
+
import os
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 30 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 31 |
+
from diffusers.utils import replace_example_docstring
|
| 32 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 33 |
+
|
| 34 |
+
EXAMPLE_DOC_STRING = """
|
| 35 |
+
Examples:
|
| 36 |
+
```py
|
| 37 |
+
>>> from pathlib import Path
|
| 38 |
+
>>> import torch
|
| 39 |
+
>>> from diffusers import DiffusionPipeline
|
| 40 |
+
|
| 41 |
+
>>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve()
|
| 42 |
+
>>> pipe = DiffusionPipeline.from_pretrained(
|
| 43 |
+
... str(model_dir),
|
| 44 |
+
... local_files_only=True,
|
| 45 |
+
... custom_pipeline=str(model_dir / "pipeline.py"),
|
| 46 |
+
... trust_remote_code=True,
|
| 47 |
+
... torch_dtype=torch.float32,
|
| 48 |
+
... )
|
| 49 |
+
>>> pipe.to("cuda")
|
| 50 |
+
|
| 51 |
+
>>> generator = torch.Generator(device="cuda").manual_seed(42)
|
| 52 |
+
>>> image = pipe(
|
| 53 |
+
... class_labels=207,
|
| 54 |
+
... num_inference_steps=32,
|
| 55 |
+
... guidance_scale=1.0,
|
| 56 |
+
... generator=generator,
|
| 57 |
+
... ).images[0]
|
| 58 |
+
>>> image.save("demo.png")
|
| 59 |
+
```
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
# Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py).
|
| 63 |
+
_STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28])
|
| 64 |
+
_STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE
|
| 65 |
+
|
| 66 |
+
class EDM2Pipeline(DiffusionPipeline):
|
| 67 |
+
r"""
|
| 68 |
+
Pipeline for class-conditional image generation with EDM2
|
| 69 |
+
([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)).
|
| 70 |
+
|
| 71 |
+
Parameters:
|
| 72 |
+
unet ([`EDM2UNet2DModel`]):
|
| 73 |
+
Main magnitude-preserving U-Net with EDM preconditioning.
|
| 74 |
+
scheduler ([`EDMEulerScheduler`]):
|
| 75 |
+
Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in
|
| 76 |
+
the pipeline because the UNet returns denoised latents rather than noise predictions.
|
| 77 |
+
vae ([`AutoencoderKL`], *optional*):
|
| 78 |
+
Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`.
|
| 79 |
+
gnet ([`EDM2UNet2DModel`], *optional*):
|
| 80 |
+
Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`).
|
| 81 |
+
id2label (`dict[int, str]`, *optional*):
|
| 82 |
+
ImageNet class id to English label mapping.
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
model_cpu_offload_seq = "unet->gnet->vae"
|
| 86 |
+
_optional_components = ["vae", "gnet"]
|
| 87 |
+
|
| 88 |
+
def __init__(
|
| 89 |
+
self,
|
| 90 |
+
unet,
|
| 91 |
+
scheduler,
|
| 92 |
+
vae=None,
|
| 93 |
+
gnet=None,
|
| 94 |
+
id2label: Optional[Dict[Union[int, str], str]] = None,
|
| 95 |
+
) -> None:
|
| 96 |
+
super().__init__()
|
| 97 |
+
self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet)
|
| 98 |
+
self._id2label = self._normalize_id2label(id2label)
|
| 99 |
+
self.labels = self._build_label2id(self._id2label)
|
| 100 |
+
self._labels_loaded_from_model_index = bool(self._id2label)
|
| 101 |
+
self.vae_scale_factor = 8 if self.vae is not None else 1
|
| 102 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
|
| 103 |
+
|
| 104 |
+
@staticmethod
|
| 105 |
+
def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
|
| 106 |
+
if not id2label:
|
| 107 |
+
return {}
|
| 108 |
+
return {int(key): value for key, value in id2label.items()}
|
| 109 |
+
|
| 110 |
+
@staticmethod
|
| 111 |
+
def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
|
| 112 |
+
label2id: Dict[str, int] = {}
|
| 113 |
+
for class_id, value in id2label.items():
|
| 114 |
+
for synonym in value.split(","):
|
| 115 |
+
synonym = synonym.strip()
|
| 116 |
+
if synonym:
|
| 117 |
+
label2id[synonym] = int(class_id)
|
| 118 |
+
return dict(sorted(label2id.items()))
|
| 119 |
+
|
| 120 |
+
def _ensure_labels_loaded(self) -> None:
|
| 121 |
+
if self._labels_loaded_from_model_index:
|
| 122 |
+
return
|
| 123 |
+
loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
|
| 124 |
+
if loaded:
|
| 125 |
+
self._id2label = loaded
|
| 126 |
+
self.labels = self._build_label2id(self._id2label)
|
| 127 |
+
self._labels_loaded_from_model_index = True
|
| 128 |
+
|
| 129 |
+
@staticmethod
|
| 130 |
+
def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
|
| 131 |
+
if not variant_path:
|
| 132 |
+
return {}
|
| 133 |
+
model_index_path = Path(variant_path).resolve() / "model_index.json"
|
| 134 |
+
if not model_index_path.is_file():
|
| 135 |
+
return {}
|
| 136 |
+
raw = json.loads(model_index_path.read_text(encoding="utf-8"))
|
| 137 |
+
id2label = raw.get("id2label")
|
| 138 |
+
if not isinstance(id2label, dict):
|
| 139 |
+
return {}
|
| 140 |
+
return {int(key): value for key, value in id2label.items()}
|
| 141 |
+
|
| 142 |
+
@property
|
| 143 |
+
def id2label(self) -> Dict[int, str]:
|
| 144 |
+
r"""ImageNet class id to English label string (comma-separated synonyms)."""
|
| 145 |
+
self._ensure_labels_loaded()
|
| 146 |
+
return self._id2label
|
| 147 |
+
|
| 148 |
+
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
|
| 149 |
+
r"""
|
| 150 |
+
Map ImageNet label strings to class ids.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
label (`str` or `list[str]`):
|
| 154 |
+
One or more English label strings that match entries in `id2label`.
|
| 155 |
+
"""
|
| 156 |
+
self._ensure_labels_loaded()
|
| 157 |
+
if not self.labels:
|
| 158 |
+
raise ValueError("No English labels loaded. Add `id2label` to model_index.json.")
|
| 159 |
+
labels = [label] if isinstance(label, str) else list(label)
|
| 160 |
+
missing = [item for item in labels if item not in self.labels]
|
| 161 |
+
if missing:
|
| 162 |
+
preview = ", ".join(list(self.labels.keys())[:8])
|
| 163 |
+
raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
|
| 164 |
+
return [self.labels[item] for item in labels]
|
| 165 |
+
|
| 166 |
+
def _default_image_size(self) -> int:
|
| 167 |
+
latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64)))
|
| 168 |
+
return latent_size * self.vae_scale_factor
|
| 169 |
+
|
| 170 |
+
def check_inputs(
|
| 171 |
+
self,
|
| 172 |
+
height: int,
|
| 173 |
+
width: int,
|
| 174 |
+
num_inference_steps: int,
|
| 175 |
+
guidance_scale: float,
|
| 176 |
+
output_type: str,
|
| 177 |
+
) -> None:
|
| 178 |
+
if num_inference_steps < 1:
|
| 179 |
+
raise ValueError("num_inference_steps must be >= 1.")
|
| 180 |
+
if guidance_scale < 1.0:
|
| 181 |
+
raise ValueError("guidance_scale must be >= 1.0.")
|
| 182 |
+
if guidance_scale > 1.0 and self.gnet is None:
|
| 183 |
+
raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).")
|
| 184 |
+
if output_type not in {"pil", "np", "pt", "latent"}:
|
| 185 |
+
raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
|
| 186 |
+
|
| 187 |
+
native_size = self._default_image_size()
|
| 188 |
+
if height != native_size or width != native_size:
|
| 189 |
+
raise ValueError(
|
| 190 |
+
f"EDM2 expects native resolution height=width={native_size}. "
|
| 191 |
+
f"Got height={height}, width={width}."
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
def _normalize_class_labels(
|
| 195 |
+
self,
|
| 196 |
+
class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]],
|
| 197 |
+
batch_size: int,
|
| 198 |
+
device: torch.device,
|
| 199 |
+
) -> Optional[torch.Tensor]:
|
| 200 |
+
label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0)))
|
| 201 |
+
if label_dim == 0:
|
| 202 |
+
return None
|
| 203 |
+
if class_labels is None:
|
| 204 |
+
indices = torch.randint(label_dim, size=(batch_size,), device=device)
|
| 205 |
+
return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
|
| 206 |
+
|
| 207 |
+
if isinstance(class_labels, str):
|
| 208 |
+
class_labels = self.get_label_ids(class_labels)[0]
|
| 209 |
+
elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str):
|
| 210 |
+
class_labels = self.get_label_ids(list(class_labels))
|
| 211 |
+
|
| 212 |
+
if isinstance(class_labels, int):
|
| 213 |
+
indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long)
|
| 214 |
+
elif isinstance(class_labels, torch.Tensor):
|
| 215 |
+
if class_labels.ndim == 2:
|
| 216 |
+
labels = class_labels.to(device=device, dtype=torch.float32)
|
| 217 |
+
if labels.shape[0] != batch_size:
|
| 218 |
+
raise ValueError(f"class_labels batch must match batch_size={batch_size}.")
|
| 219 |
+
return labels
|
| 220 |
+
indices = class_labels.to(device=device, dtype=torch.long).flatten()
|
| 221 |
+
else:
|
| 222 |
+
indices = torch.tensor(list(class_labels), device=device, dtype=torch.long)
|
| 223 |
+
|
| 224 |
+
if indices.numel() == 1 and batch_size > 1:
|
| 225 |
+
indices = indices.repeat(batch_size)
|
| 226 |
+
if indices.numel() != batch_size:
|
| 227 |
+
raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.")
|
| 228 |
+
return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
|
| 229 |
+
|
| 230 |
+
def prepare_latents(
|
| 231 |
+
self,
|
| 232 |
+
batch_size: int,
|
| 233 |
+
height: int,
|
| 234 |
+
width: int,
|
| 235 |
+
dtype: torch.dtype,
|
| 236 |
+
device: torch.device,
|
| 237 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]],
|
| 238 |
+
) -> torch.Tensor:
|
| 239 |
+
in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4)))
|
| 240 |
+
latent_size = height // self.vae_scale_factor
|
| 241 |
+
return randn_tensor(
|
| 242 |
+
(batch_size, in_channels, latent_size, latent_size),
|
| 243 |
+
generator=generator,
|
| 244 |
+
device=device,
|
| 245 |
+
dtype=torch.float32,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
|
| 249 |
+
if output_type == "latent":
|
| 250 |
+
return latents
|
| 251 |
+
|
| 252 |
+
in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 3)))
|
| 253 |
+
if self.vae is None:
|
| 254 |
+
image = (latents.to(torch.float32) * 127.5 + 128).clip(0, 255) / 255.0
|
| 255 |
+
return self.image_processor.postprocess(image, output_type=output_type)
|
| 256 |
+
|
| 257 |
+
if in_channels == 4:
|
| 258 |
+
x = latents.to(torch.float32)
|
| 259 |
+
scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
|
| 260 |
+
bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
|
| 261 |
+
x = (x - bias) / scale
|
| 262 |
+
else:
|
| 263 |
+
x = latents.to(torch.float32)
|
| 264 |
+
|
| 265 |
+
vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype
|
| 266 |
+
image = self.vae.decode(x.to(dtype=vae_dtype)).sample.to(torch.float32).clamp(0, 1)
|
| 267 |
+
|
| 268 |
+
return self.image_processor.postprocess(image, output_type=output_type)
|
| 269 |
+
|
| 270 |
+
@staticmethod
|
| 271 |
+
def _apply_autoguidance(
|
| 272 |
+
main: torch.Tensor,
|
| 273 |
+
ref: torch.Tensor,
|
| 274 |
+
guidance_scale: float,
|
| 275 |
+
) -> torch.Tensor:
|
| 276 |
+
return ref.lerp(main, guidance_scale)
|
| 277 |
+
|
| 278 |
+
@staticmethod
|
| 279 |
+
def _sample_edm2_heun(
|
| 280 |
+
denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
| 281 |
+
noise: torch.Tensor,
|
| 282 |
+
sigmas: torch.Tensor,
|
| 283 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 284 |
+
progress_bar: Optional[Callable[[Iterable], Iterable]] = None,
|
| 285 |
+
dtype: torch.dtype = torch.float32,
|
| 286 |
+
) -> torch.Tensor:
|
| 287 |
+
"""NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0)."""
|
| 288 |
+
x_next = noise.to(dtype) * sigmas[0]
|
| 289 |
+
|
| 290 |
+
sigma_pairs = list(zip(sigmas[:-1], sigmas[1:]))
|
| 291 |
+
if progress_bar is not None:
|
| 292 |
+
sigma_pairs = progress_bar(sigma_pairs)
|
| 293 |
+
|
| 294 |
+
num_steps = len(sigma_pairs)
|
| 295 |
+
for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs):
|
| 296 |
+
x_hat, sigma_hat = x_next, sigma_cur
|
| 297 |
+
d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat
|
| 298 |
+
x_next = x_hat + (sigma_next - sigma_hat) * d_cur
|
| 299 |
+
if i < num_steps - 1:
|
| 300 |
+
d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next
|
| 301 |
+
x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime)
|
| 302 |
+
return x_next
|
| 303 |
+
|
| 304 |
+
@torch.inference_mode()
|
| 305 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 306 |
+
def __call__(
|
| 307 |
+
self,
|
| 308 |
+
class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None,
|
| 309 |
+
batch_size: int = 1,
|
| 310 |
+
height: Optional[int] = None,
|
| 311 |
+
width: Optional[int] = None,
|
| 312 |
+
num_inference_steps: int = 32,
|
| 313 |
+
guidance_scale: float = 1.0,
|
| 314 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 315 |
+
output_type: str = "pil",
|
| 316 |
+
return_dict: bool = True,
|
| 317 |
+
) -> Union[ImagePipelineOutput, Tuple]:
|
| 318 |
+
r"""
|
| 319 |
+
Generate class-conditional images with EDM2.
|
| 320 |
+
|
| 321 |
+
Args:
|
| 322 |
+
class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*):
|
| 323 |
+
ImageNet class indices, English label strings, or one-hot float tensors.
|
| 324 |
+
Random classes are sampled when omitted on conditional models.
|
| 325 |
+
batch_size (`int`, defaults to `1`):
|
| 326 |
+
Number of images to generate.
|
| 327 |
+
height (`int`, *optional*):
|
| 328 |
+
Output height in pixels. Defaults to the pretrained native resolution.
|
| 329 |
+
width (`int`, *optional*):
|
| 330 |
+
Output width in pixels. Defaults to the pretrained native resolution.
|
| 331 |
+
num_inference_steps (`int`, defaults to `32`):
|
| 332 |
+
Number of EDM2 Heun steps (NVlabs default).
|
| 333 |
+
guidance_scale (`float`, defaults to `1.0`):
|
| 334 |
+
Autoguidance strength. Values above `1.0` blend the main net with `gnet`
|
| 335 |
+
via `gnet_output.lerp(unet_output, guidance_scale)`.
|
| 336 |
+
generator (`torch.Generator`, *optional*):
|
| 337 |
+
RNG for reproducibility.
|
| 338 |
+
output_type (`str`, defaults to `"pil"`):
|
| 339 |
+
`"pil"`, `"np"`, `"pt"`, or `"latent"`.
|
| 340 |
+
return_dict (`bool`, defaults to `True`):
|
| 341 |
+
Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if True.
|
| 342 |
+
|
| 343 |
+
Examples:
|
| 344 |
+
<!-- this section is replaced by replace_example_docstring -->
|
| 345 |
+
"""
|
| 346 |
+
default_size = self._default_image_size()
|
| 347 |
+
height = int(height or default_size)
|
| 348 |
+
width = int(width or default_size)
|
| 349 |
+
self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type)
|
| 350 |
+
|
| 351 |
+
device = self._execution_device
|
| 352 |
+
dtype = self.unet.dtype
|
| 353 |
+
labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
|
| 354 |
+
noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
|
| 355 |
+
|
| 356 |
+
def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
|
| 357 |
+
sigma_batch = sigma.reshape(1).expand(batch_size)
|
| 358 |
+
main = self.unet(
|
| 359 |
+
sample=x,
|
| 360 |
+
sigma=sigma_batch,
|
| 361 |
+
class_labels=labels,
|
| 362 |
+
force_fp32=True,
|
| 363 |
+
).sample
|
| 364 |
+
if guidance_scale == 1.0 or self.gnet is None:
|
| 365 |
+
return main.to(torch.float32)
|
| 366 |
+
ref = self.gnet(
|
| 367 |
+
sample=x,
|
| 368 |
+
sigma=sigma_batch,
|
| 369 |
+
class_labels=labels,
|
| 370 |
+
force_fp32=True,
|
| 371 |
+
).sample
|
| 372 |
+
return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32)
|
| 373 |
+
|
| 374 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 375 |
+
latents = self._sample_edm2_heun(
|
| 376 |
+
denoise_fn=denoise_fn,
|
| 377 |
+
noise=noise,
|
| 378 |
+
sigmas=self.scheduler.sigmas.to(device),
|
| 379 |
+
generator=generator,
|
| 380 |
+
progress_bar=self.progress_bar,
|
| 381 |
+
dtype=torch.float32,
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
image = self.decode_latents(latents, output_type=output_type)
|
| 385 |
+
if not return_dict:
|
| 386 |
+
return (image, latents)
|
| 387 |
+
return ImagePipelineOutput(images=image)
|
| 388 |
+
|
| 389 |
+
@classmethod
|
| 390 |
+
def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None):
|
| 391 |
+
vae_dir = os.path.join(pretrained_model_name_or_path, "vae")
|
| 392 |
+
if os.path.isdir(vae_dir):
|
| 393 |
+
try:
|
| 394 |
+
|
| 395 |
+
return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype)
|
| 396 |
+
except Exception:
|
| 397 |
+
return None
|
| 398 |
+
|
| 399 |
+
vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt")
|
| 400 |
+
if os.path.isfile(vae_hint):
|
| 401 |
+
with open(vae_hint, "r", encoding="utf-8") as f:
|
| 402 |
+
hub_id = f.read().strip()
|
| 403 |
+
if hub_id:
|
| 404 |
+
|
| 405 |
+
return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype)
|
| 406 |
+
return None
|
edm2-img512-l-dino/scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "EDMEulerScheduler",
|
| 3 |
+
"final_sigmas_type": "zero",
|
| 4 |
+
"num_train_timesteps": 1000,
|
| 5 |
+
"prediction_type": "epsilon",
|
| 6 |
+
"rho": 7.0,
|
| 7 |
+
"sigma_data": 0.5,
|
| 8 |
+
"sigma_max": 80.0,
|
| 9 |
+
"sigma_min": 0.002,
|
| 10 |
+
"sigma_schedule": "karras"
|
| 11 |
+
}
|
edm2-img512-l-dino/unet/config.json
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "EDM2UNet2DModel",
|
| 3 |
+
"attn_balance": 0.3,
|
| 4 |
+
"attn_resolutions": [
|
| 5 |
+
16,
|
| 6 |
+
8
|
| 7 |
+
],
|
| 8 |
+
"channel_mult": [
|
| 9 |
+
1,
|
| 10 |
+
2,
|
| 11 |
+
3,
|
| 12 |
+
4
|
| 13 |
+
],
|
| 14 |
+
"channel_mult_emb": 4,
|
| 15 |
+
"channel_mult_noise": 1,
|
| 16 |
+
"channels_per_head": 64,
|
| 17 |
+
"clip_act": 256,
|
| 18 |
+
"concat_balance": 0.5,
|
| 19 |
+
"dropout": 0.0,
|
| 20 |
+
"in_channels": 4,
|
| 21 |
+
"label_balance": 0.5,
|
| 22 |
+
"logvar_channels": 128,
|
| 23 |
+
"model_channels": 320,
|
| 24 |
+
"num_blocks": 3,
|
| 25 |
+
"num_class_embeds": 1000,
|
| 26 |
+
"out_channels": 4,
|
| 27 |
+
"res_balance": 0.3,
|
| 28 |
+
"sample_size": 64,
|
| 29 |
+
"sigma_data": 0.5,
|
| 30 |
+
"use_fp16": true
|
| 31 |
+
}
|
edm2-img512-l-dino/unet/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f13f83377a74d74e1205843e241ce6d6e4bc9e49c2661944e49fdbe4d515ba33
|
| 3 |
+
size 3110018564
|
edm2-img512-l-dino/unet/unet_edm2.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 12 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 13 |
+
from diffusers.utils import BaseOutput
|
| 14 |
+
except ImportError: # pragma: no cover
|
| 15 |
+
class ModelMixin(torch.nn.Module):
|
| 16 |
+
pass
|
| 17 |
+
|
| 18 |
+
class ConfigMixin:
|
| 19 |
+
config = {}
|
| 20 |
+
|
| 21 |
+
def register_to_config(self, **kwargs):
|
| 22 |
+
self.config = kwargs
|
| 23 |
+
|
| 24 |
+
def register_to_config(func):
|
| 25 |
+
return func
|
| 26 |
+
|
| 27 |
+
@dataclass
|
| 28 |
+
class BaseOutput:
|
| 29 |
+
pass
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor:
|
| 33 |
+
if dim is None:
|
| 34 |
+
dim = list(range(1, x.ndim))
|
| 35 |
+
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
|
| 36 |
+
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
|
| 37 |
+
return x / norm.to(x.dtype)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor:
|
| 41 |
+
if mode == "keep":
|
| 42 |
+
return x
|
| 43 |
+
filt = np.float32(f)
|
| 44 |
+
pad = (len(filt) - 1) // 2
|
| 45 |
+
filt = filt / filt.sum()
|
| 46 |
+
filt = np.outer(filt, filt)[np.newaxis, np.newaxis, :, :]
|
| 47 |
+
filt = torch.as_tensor(filt, dtype=x.dtype, device=x.device)
|
| 48 |
+
c = x.shape[1]
|
| 49 |
+
if mode == "down":
|
| 50 |
+
return torch.nn.functional.conv2d(x, filt.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
|
| 51 |
+
return torch.nn.functional.conv_transpose2d(x, (filt * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def mp_silu(x: torch.Tensor) -> torch.Tensor:
|
| 55 |
+
return torch.nn.functional.silu(x) / 0.596
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor:
|
| 59 |
+
return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor:
|
| 63 |
+
na = a.shape[dim]
|
| 64 |
+
nb = b.shape[dim]
|
| 65 |
+
c = math.sqrt((na + nb) / ((1 - t) ** 2 + t ** 2))
|
| 66 |
+
wa = c / math.sqrt(na) * (1 - t)
|
| 67 |
+
wb = c / math.sqrt(nb) * t
|
| 68 |
+
return torch.cat([wa * a, wb * b], dim=dim)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class MPFourier(torch.nn.Module):
|
| 72 |
+
def __init__(self, num_channels: int, bandwidth: float = 1):
|
| 73 |
+
super().__init__()
|
| 74 |
+
self.register_buffer("freqs", 2 * math.pi * torch.randn(num_channels) * bandwidth)
|
| 75 |
+
self.register_buffer("phases", 2 * math.pi * torch.rand(num_channels))
|
| 76 |
+
|
| 77 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 78 |
+
y = x.to(torch.float32).ger(self.freqs.to(torch.float32))
|
| 79 |
+
y = y + self.phases.to(torch.float32)
|
| 80 |
+
y = y.cos() * math.sqrt(2)
|
| 81 |
+
return y.to(x.dtype)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class MPConv(torch.nn.Module):
|
| 85 |
+
def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]):
|
| 86 |
+
super().__init__()
|
| 87 |
+
self.out_channels = out_channels
|
| 88 |
+
self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))
|
| 89 |
+
|
| 90 |
+
def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor:
|
| 91 |
+
w = self.weight.to(torch.float32)
|
| 92 |
+
if self.training:
|
| 93 |
+
with torch.no_grad():
|
| 94 |
+
self.weight.copy_(normalize(w))
|
| 95 |
+
w = normalize(w)
|
| 96 |
+
w = w * (gain / math.sqrt(w[0].numel()))
|
| 97 |
+
w = w.to(x.dtype)
|
| 98 |
+
if w.ndim == 2:
|
| 99 |
+
return x @ w.t()
|
| 100 |
+
return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1] // 2,))
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class Block(torch.nn.Module):
|
| 104 |
+
def __init__(
|
| 105 |
+
self,
|
| 106 |
+
in_channels: int,
|
| 107 |
+
out_channels: int,
|
| 108 |
+
emb_channels: int,
|
| 109 |
+
flavor: str = "enc",
|
| 110 |
+
resample_mode: str = "keep",
|
| 111 |
+
resample_filter: List[float] = [1, 1],
|
| 112 |
+
attention: bool = False,
|
| 113 |
+
channels_per_head: int = 64,
|
| 114 |
+
dropout: float = 0.0,
|
| 115 |
+
res_balance: float = 0.3,
|
| 116 |
+
attn_balance: float = 0.3,
|
| 117 |
+
clip_act: Optional[float] = 256,
|
| 118 |
+
):
|
| 119 |
+
super().__init__()
|
| 120 |
+
self.out_channels = out_channels
|
| 121 |
+
self.flavor = flavor
|
| 122 |
+
self.resample_filter = resample_filter
|
| 123 |
+
self.resample_mode = resample_mode
|
| 124 |
+
self.num_heads = out_channels // channels_per_head if attention else 0
|
| 125 |
+
self.dropout = dropout
|
| 126 |
+
self.res_balance = res_balance
|
| 127 |
+
self.attn_balance = attn_balance
|
| 128 |
+
self.clip_act = clip_act
|
| 129 |
+
self.emb_gain = torch.nn.Parameter(torch.zeros([]))
|
| 130 |
+
self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3))
|
| 131 |
+
self.emb_linear = MPConv(emb_channels, out_channels, kernel=())
|
| 132 |
+
self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3))
|
| 133 |
+
self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None
|
| 134 |
+
self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None
|
| 135 |
+
self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None
|
| 136 |
+
|
| 137 |
+
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
|
| 138 |
+
x = resample(x, f=self.resample_filter, mode=self.resample_mode)
|
| 139 |
+
if self.flavor == "enc":
|
| 140 |
+
if self.conv_skip is not None:
|
| 141 |
+
x = self.conv_skip(x)
|
| 142 |
+
x = normalize(x, dim=[1])
|
| 143 |
+
|
| 144 |
+
y = self.conv_res0(mp_silu(x))
|
| 145 |
+
c = self.emb_linear(emb, gain=self.emb_gain) + 1
|
| 146 |
+
y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype))
|
| 147 |
+
if self.training and self.dropout:
|
| 148 |
+
y = torch.nn.functional.dropout(y, p=self.dropout)
|
| 149 |
+
y = self.conv_res1(y)
|
| 150 |
+
|
| 151 |
+
if self.flavor == "dec" and self.conv_skip is not None:
|
| 152 |
+
x = self.conv_skip(x)
|
| 153 |
+
x = mp_sum(x, y, t=self.res_balance)
|
| 154 |
+
|
| 155 |
+
if self.num_heads:
|
| 156 |
+
y = self.attn_qkv(x)
|
| 157 |
+
y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3])
|
| 158 |
+
q, k, v = normalize(y, dim=[2]).unbind(3)
|
| 159 |
+
w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3)
|
| 160 |
+
y = torch.einsum("nhqk,nhck->nhcq", w, v)
|
| 161 |
+
y = self.attn_proj(y.reshape(*x.shape))
|
| 162 |
+
x = mp_sum(x, y, t=self.attn_balance)
|
| 163 |
+
|
| 164 |
+
if self.clip_act is not None:
|
| 165 |
+
x = x.clip_(-self.clip_act, self.clip_act)
|
| 166 |
+
return x
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class EDM2UNet(torch.nn.Module):
|
| 170 |
+
def __init__(
|
| 171 |
+
self,
|
| 172 |
+
img_resolution: int,
|
| 173 |
+
img_channels: int,
|
| 174 |
+
label_dim: int,
|
| 175 |
+
model_channels: int = 192,
|
| 176 |
+
channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
|
| 177 |
+
channel_mult_noise: Optional[int] = None,
|
| 178 |
+
channel_mult_emb: Optional[int] = None,
|
| 179 |
+
num_blocks: int = 3,
|
| 180 |
+
attn_resolutions: Tuple[int, ...] = (16, 8),
|
| 181 |
+
label_balance: float = 0.5,
|
| 182 |
+
concat_balance: float = 0.5,
|
| 183 |
+
**block_kwargs,
|
| 184 |
+
):
|
| 185 |
+
super().__init__()
|
| 186 |
+
cblock = [model_channels * x for x in channel_mult]
|
| 187 |
+
cnoise = model_channels * channel_mult_noise if channel_mult_noise is not None else cblock[0]
|
| 188 |
+
cemb = model_channels * channel_mult_emb if channel_mult_emb is not None else max(cblock)
|
| 189 |
+
self.label_balance = label_balance
|
| 190 |
+
self.concat_balance = concat_balance
|
| 191 |
+
self.out_gain = torch.nn.Parameter(torch.zeros([]))
|
| 192 |
+
|
| 193 |
+
self.emb_fourier = MPFourier(cnoise)
|
| 194 |
+
self.emb_noise = MPConv(cnoise, cemb, kernel=())
|
| 195 |
+
self.emb_label = MPConv(label_dim, cemb, kernel=()) if label_dim else None
|
| 196 |
+
|
| 197 |
+
self.enc = torch.nn.ModuleDict()
|
| 198 |
+
cout = img_channels + 1
|
| 199 |
+
for level, channels in enumerate(cblock):
|
| 200 |
+
res = img_resolution >> level
|
| 201 |
+
if level == 0:
|
| 202 |
+
cin = cout
|
| 203 |
+
cout = channels
|
| 204 |
+
self.enc[f"{res}x{res}_conv"] = MPConv(cin, cout, kernel=(3, 3))
|
| 205 |
+
else:
|
| 206 |
+
self.enc[f"{res}x{res}_down"] = Block(cout, cout, cemb, flavor="enc", resample_mode="down", **block_kwargs)
|
| 207 |
+
for idx in range(num_blocks):
|
| 208 |
+
cin = cout
|
| 209 |
+
cout = channels
|
| 210 |
+
self.enc[f"{res}x{res}_block{idx}"] = Block(
|
| 211 |
+
cin,
|
| 212 |
+
cout,
|
| 213 |
+
cemb,
|
| 214 |
+
flavor="enc",
|
| 215 |
+
attention=(res in attn_resolutions),
|
| 216 |
+
**block_kwargs,
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
self.dec = torch.nn.ModuleDict()
|
| 220 |
+
skips = [block.out_channels for block in self.enc.values()]
|
| 221 |
+
for level, channels in reversed(list(enumerate(cblock))):
|
| 222 |
+
res = img_resolution >> level
|
| 223 |
+
if level == len(cblock) - 1:
|
| 224 |
+
self.dec[f"{res}x{res}_in0"] = Block(cout, cout, cemb, flavor="dec", attention=True, **block_kwargs)
|
| 225 |
+
self.dec[f"{res}x{res}_in1"] = Block(cout, cout, cemb, flavor="dec", **block_kwargs)
|
| 226 |
+
else:
|
| 227 |
+
self.dec[f"{res}x{res}_up"] = Block(cout, cout, cemb, flavor="dec", resample_mode="up", **block_kwargs)
|
| 228 |
+
for idx in range(num_blocks + 1):
|
| 229 |
+
cin = cout + skips.pop()
|
| 230 |
+
cout = channels
|
| 231 |
+
self.dec[f"{res}x{res}_block{idx}"] = Block(
|
| 232 |
+
cin,
|
| 233 |
+
cout,
|
| 234 |
+
cemb,
|
| 235 |
+
flavor="dec",
|
| 236 |
+
attention=(res in attn_resolutions),
|
| 237 |
+
**block_kwargs,
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
self.out_conv = MPConv(cout, img_channels, kernel=(3, 3))
|
| 241 |
+
|
| 242 |
+
def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor:
|
| 243 |
+
emb = self.emb_noise(self.emb_fourier(noise_labels))
|
| 244 |
+
if self.emb_label is not None:
|
| 245 |
+
if class_labels is None:
|
| 246 |
+
raise ValueError("class_labels are required for conditional EDM2UNet.")
|
| 247 |
+
emb = mp_sum(emb, self.emb_label(class_labels * math.sqrt(class_labels.shape[1])), t=self.label_balance)
|
| 248 |
+
emb = mp_silu(emb)
|
| 249 |
+
|
| 250 |
+
x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1)
|
| 251 |
+
skips = []
|
| 252 |
+
for name, block in self.enc.items():
|
| 253 |
+
x = block(x) if "conv" in name else block(x, emb)
|
| 254 |
+
skips.append(x)
|
| 255 |
+
|
| 256 |
+
for name, block in self.dec.items():
|
| 257 |
+
if "block" in name:
|
| 258 |
+
x = mp_cat(x, skips.pop(), t=self.concat_balance)
|
| 259 |
+
x = block(x, emb)
|
| 260 |
+
return self.out_conv(x, gain=self.out_gain)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
@dataclass
|
| 264 |
+
class EDM2UNet2DOutput(BaseOutput):
|
| 265 |
+
sample: torch.Tensor
|
| 266 |
+
logvar: Optional[torch.Tensor] = None
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
_CONFIG_KEYS = (
|
| 271 |
+
"sample_size",
|
| 272 |
+
"in_channels",
|
| 273 |
+
"out_channels",
|
| 274 |
+
"num_class_embeds",
|
| 275 |
+
"use_fp16",
|
| 276 |
+
"sigma_data",
|
| 277 |
+
"logvar_channels",
|
| 278 |
+
"model_channels",
|
| 279 |
+
"channel_mult",
|
| 280 |
+
"channel_mult_noise",
|
| 281 |
+
"channel_mult_emb",
|
| 282 |
+
"num_blocks",
|
| 283 |
+
"attn_resolutions",
|
| 284 |
+
"label_balance",
|
| 285 |
+
"concat_balance",
|
| 286 |
+
"dropout",
|
| 287 |
+
"channels_per_head",
|
| 288 |
+
"res_balance",
|
| 289 |
+
"attn_balance",
|
| 290 |
+
"clip_act",
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class EDM2UNet2DModel(ModelMixin, ConfigMixin):
|
| 295 |
+
@register_to_config
|
| 296 |
+
def __init__(
|
| 297 |
+
self,
|
| 298 |
+
sample_size: int = 64,
|
| 299 |
+
in_channels: int = 4,
|
| 300 |
+
out_channels: int = 4,
|
| 301 |
+
num_class_embeds: int = 0,
|
| 302 |
+
use_fp16: bool = True,
|
| 303 |
+
sigma_data: float = 0.5,
|
| 304 |
+
logvar_channels: int = 128,
|
| 305 |
+
model_channels: int = 192,
|
| 306 |
+
channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
|
| 307 |
+
channel_mult_noise: Optional[int] = None,
|
| 308 |
+
channel_mult_emb: Optional[int] = None,
|
| 309 |
+
num_blocks: int = 3,
|
| 310 |
+
attn_resolutions: Tuple[int, ...] = (16, 8),
|
| 311 |
+
label_balance: float = 0.5,
|
| 312 |
+
concat_balance: float = 0.5,
|
| 313 |
+
dropout: float = 0.0,
|
| 314 |
+
channels_per_head: int = 64,
|
| 315 |
+
res_balance: float = 0.3,
|
| 316 |
+
attn_balance: float = 0.3,
|
| 317 |
+
clip_act: Optional[float] = 256,
|
| 318 |
+
):
|
| 319 |
+
super().__init__()
|
| 320 |
+
self.sample_size = sample_size
|
| 321 |
+
self.in_channels = in_channels
|
| 322 |
+
self.out_channels = out_channels
|
| 323 |
+
self.num_class_embeds = num_class_embeds
|
| 324 |
+
self.use_fp16 = use_fp16
|
| 325 |
+
self.sigma_data = sigma_data
|
| 326 |
+
self.model_channels = model_channels
|
| 327 |
+
self.channel_mult = channel_mult
|
| 328 |
+
self.channel_mult_noise = channel_mult_noise
|
| 329 |
+
self.channel_mult_emb = channel_mult_emb
|
| 330 |
+
self.num_blocks = num_blocks
|
| 331 |
+
self.attn_resolutions = attn_resolutions
|
| 332 |
+
self.label_balance = label_balance
|
| 333 |
+
self.concat_balance = concat_balance
|
| 334 |
+
self.dropout = dropout
|
| 335 |
+
self.channels_per_head = channels_per_head
|
| 336 |
+
self.res_balance = res_balance
|
| 337 |
+
self.attn_balance = attn_balance
|
| 338 |
+
self.clip_act = clip_act
|
| 339 |
+
self.unet = EDM2UNet(
|
| 340 |
+
img_resolution=sample_size,
|
| 341 |
+
img_channels=in_channels,
|
| 342 |
+
label_dim=num_class_embeds,
|
| 343 |
+
model_channels=model_channels,
|
| 344 |
+
channel_mult=channel_mult,
|
| 345 |
+
channel_mult_noise=channel_mult_noise,
|
| 346 |
+
channel_mult_emb=channel_mult_emb,
|
| 347 |
+
num_blocks=num_blocks,
|
| 348 |
+
attn_resolutions=attn_resolutions,
|
| 349 |
+
label_balance=label_balance,
|
| 350 |
+
concat_balance=concat_balance,
|
| 351 |
+
dropout=dropout,
|
| 352 |
+
channels_per_head=channels_per_head,
|
| 353 |
+
res_balance=res_balance,
|
| 354 |
+
attn_balance=attn_balance,
|
| 355 |
+
clip_act=clip_act,
|
| 356 |
+
)
|
| 357 |
+
self.logvar_fourier = MPFourier(logvar_channels)
|
| 358 |
+
self.logvar_linear = MPConv(logvar_channels, 1, kernel=())
|
| 359 |
+
|
| 360 |
+
def forward(
|
| 361 |
+
self,
|
| 362 |
+
sample: torch.Tensor,
|
| 363 |
+
sigma: torch.Tensor,
|
| 364 |
+
class_labels: Optional[torch.Tensor] = None,
|
| 365 |
+
force_fp32: bool = False,
|
| 366 |
+
return_logvar: bool = False,
|
| 367 |
+
return_dict: bool = True,
|
| 368 |
+
) -> EDM2UNet2DOutput:
|
| 369 |
+
x = sample.to(torch.float32)
|
| 370 |
+
sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
|
| 371 |
+
if self.num_class_embeds == 0:
|
| 372 |
+
class_labels = None
|
| 373 |
+
else:
|
| 374 |
+
if class_labels is None:
|
| 375 |
+
class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device)
|
| 376 |
+
class_labels = class_labels.to(torch.float32).reshape(-1, self.num_class_embeds)
|
| 377 |
+
dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") else torch.float32
|
| 378 |
+
|
| 379 |
+
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
|
| 380 |
+
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
|
| 381 |
+
c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
|
| 382 |
+
c_noise = sigma.flatten().log() / 4
|
| 383 |
+
|
| 384 |
+
x_in = (c_in * x).to(dtype)
|
| 385 |
+
f_x = self.unet(x_in, c_noise, class_labels)
|
| 386 |
+
d_x = c_skip * x + c_out * f_x.to(torch.float32)
|
| 387 |
+
|
| 388 |
+
logvar = None
|
| 389 |
+
if return_logvar:
|
| 390 |
+
logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)
|
| 391 |
+
|
| 392 |
+
if not return_dict:
|
| 393 |
+
return (d_x, logvar)
|
| 394 |
+
return EDM2UNet2DOutput(sample=d_x, logvar=logvar)
|
| 395 |
+
|
| 396 |
+
@classmethod
|
| 397 |
+
def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs):
|
| 398 |
+
subfolder = kwargs.pop("subfolder", None)
|
| 399 |
+
model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path
|
| 400 |
+
with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f:
|
| 401 |
+
config = json.load(f)
|
| 402 |
+
init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS}
|
| 403 |
+
model = cls(**init_kwargs)
|
| 404 |
+
weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors")
|
| 405 |
+
if os.path.isfile(weight_file):
|
| 406 |
+
from safetensors.torch import load_file
|
| 407 |
+
|
| 408 |
+
state_dict = load_file(weight_file)
|
| 409 |
+
else:
|
| 410 |
+
state_dict = torch.load(os.path.join(model_dir, "diffusion_pytorch_model.bin"), map_location="cpu")
|
| 411 |
+
model.load_state_dict(state_dict, strict=True)
|
| 412 |
+
if torch_dtype is not None:
|
| 413 |
+
model = model.to(dtype=torch_dtype)
|
| 414 |
+
return model
|
| 415 |
+
|
| 416 |
+
def save_pretrained(self, save_directory: str, safe_serialization: bool = True):
|
| 417 |
+
os.makedirs(save_directory, exist_ok=True)
|
| 418 |
+
stored = dict(getattr(self, "config", {}))
|
| 419 |
+
config = {"_class_name": self.__class__.__name__}
|
| 420 |
+
for key in _CONFIG_KEYS:
|
| 421 |
+
if key in stored:
|
| 422 |
+
config[key] = stored[key]
|
| 423 |
+
elif hasattr(self, key):
|
| 424 |
+
config[key] = getattr(self, key)
|
| 425 |
+
with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f:
|
| 426 |
+
json.dump(config, f, indent=2, sort_keys=True)
|
| 427 |
+
f.write("\n")
|
| 428 |
+
state_dict = self.state_dict()
|
| 429 |
+
if safe_serialization:
|
| 430 |
+
from safetensors.torch import save_file
|
| 431 |
+
|
| 432 |
+
save_file(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.safetensors"))
|
| 433 |
+
else:
|
| 434 |
+
torch.save(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.bin"))
|
edm2-img512-l-dino/vae/config.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKL",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"_name_or_path": "stabilityai/sd-vae-ft-mse",
|
| 5 |
+
"act_fn": "silu",
|
| 6 |
+
"block_out_channels": [
|
| 7 |
+
128,
|
| 8 |
+
256,
|
| 9 |
+
512,
|
| 10 |
+
512
|
| 11 |
+
],
|
| 12 |
+
"down_block_types": [
|
| 13 |
+
"DownEncoderBlock2D",
|
| 14 |
+
"DownEncoderBlock2D",
|
| 15 |
+
"DownEncoderBlock2D",
|
| 16 |
+
"DownEncoderBlock2D"
|
| 17 |
+
],
|
| 18 |
+
"force_upcast": true,
|
| 19 |
+
"in_channels": 3,
|
| 20 |
+
"latent_channels": 4,
|
| 21 |
+
"latents_mean": null,
|
| 22 |
+
"latents_std": null,
|
| 23 |
+
"layers_per_block": 2,
|
| 24 |
+
"mid_block_add_attention": true,
|
| 25 |
+
"norm_num_groups": 32,
|
| 26 |
+
"out_channels": 3,
|
| 27 |
+
"sample_size": 256,
|
| 28 |
+
"scaling_factor": 0.18215,
|
| 29 |
+
"shift_factor": null,
|
| 30 |
+
"up_block_types": [
|
| 31 |
+
"UpDecoderBlock2D",
|
| 32 |
+
"UpDecoderBlock2D",
|
| 33 |
+
"UpDecoderBlock2D",
|
| 34 |
+
"UpDecoderBlock2D"
|
| 35 |
+
],
|
| 36 |
+
"use_post_quant_conv": true,
|
| 37 |
+
"use_quant_conv": true
|
| 38 |
+
}
|
edm2-img512-l-dino/vae/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815
|
| 3 |
+
size 334643276
|
edm2-img512-l-fid/generator_test.png
ADDED
|
Git LFS Details
|
edm2-img512-l-fid/model_index.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": [
|
| 3 |
+
"pipeline",
|
| 4 |
+
"EDM2Pipeline"
|
| 5 |
+
],
|
| 6 |
+
"_diffusers_version": "0.31.0",
|
| 7 |
+
"scheduler": [
|
| 8 |
+
"diffusers",
|
| 9 |
+
"EDMEulerScheduler"
|
| 10 |
+
],
|
| 11 |
+
"unet": [
|
| 12 |
+
"unet_edm2",
|
| 13 |
+
"EDM2UNet2DModel"
|
| 14 |
+
],
|
| 15 |
+
"vae": [
|
| 16 |
+
"diffusers",
|
| 17 |
+
"AutoencoderKL"
|
| 18 |
+
]
|
| 19 |
+
}
|
edm2-img512-l-fid/pipeline.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hub custom pipeline: EDM2Pipeline.
|
| 2 |
+
Load with native Hugging Face diffusers and trust_remote_code=True.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
|
| 21 |
+
import json
|
| 22 |
+
import os
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 30 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 31 |
+
from diffusers.utils import replace_example_docstring
|
| 32 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 33 |
+
|
| 34 |
+
EXAMPLE_DOC_STRING = """
|
| 35 |
+
Examples:
|
| 36 |
+
```py
|
| 37 |
+
>>> from pathlib import Path
|
| 38 |
+
>>> import torch
|
| 39 |
+
>>> from diffusers import DiffusionPipeline
|
| 40 |
+
|
| 41 |
+
>>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve()
|
| 42 |
+
>>> pipe = DiffusionPipeline.from_pretrained(
|
| 43 |
+
... str(model_dir),
|
| 44 |
+
... local_files_only=True,
|
| 45 |
+
... custom_pipeline=str(model_dir / "pipeline.py"),
|
| 46 |
+
... trust_remote_code=True,
|
| 47 |
+
... torch_dtype=torch.float32,
|
| 48 |
+
... )
|
| 49 |
+
>>> pipe.to("cuda")
|
| 50 |
+
|
| 51 |
+
>>> generator = torch.Generator(device="cuda").manual_seed(42)
|
| 52 |
+
>>> image = pipe(
|
| 53 |
+
... class_labels=207,
|
| 54 |
+
... num_inference_steps=32,
|
| 55 |
+
... guidance_scale=1.0,
|
| 56 |
+
... generator=generator,
|
| 57 |
+
... ).images[0]
|
| 58 |
+
>>> image.save("demo.png")
|
| 59 |
+
```
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
# Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py).
|
| 63 |
+
_STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28])
|
| 64 |
+
_STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE
|
| 65 |
+
|
| 66 |
+
class EDM2Pipeline(DiffusionPipeline):
|
| 67 |
+
r"""
|
| 68 |
+
Pipeline for class-conditional image generation with EDM2
|
| 69 |
+
([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)).
|
| 70 |
+
|
| 71 |
+
Parameters:
|
| 72 |
+
unet ([`EDM2UNet2DModel`]):
|
| 73 |
+
Main magnitude-preserving U-Net with EDM preconditioning.
|
| 74 |
+
scheduler ([`EDMEulerScheduler`]):
|
| 75 |
+
Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in
|
| 76 |
+
the pipeline because the UNet returns denoised latents rather than noise predictions.
|
| 77 |
+
vae ([`AutoencoderKL`], *optional*):
|
| 78 |
+
Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`.
|
| 79 |
+
gnet ([`EDM2UNet2DModel`], *optional*):
|
| 80 |
+
Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`).
|
| 81 |
+
id2label (`dict[int, str]`, *optional*):
|
| 82 |
+
ImageNet class id to English label mapping.
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
model_cpu_offload_seq = "unet->gnet->vae"
|
| 86 |
+
_optional_components = ["vae", "gnet"]
|
| 87 |
+
|
| 88 |
+
def __init__(
|
| 89 |
+
self,
|
| 90 |
+
unet,
|
| 91 |
+
scheduler,
|
| 92 |
+
vae=None,
|
| 93 |
+
gnet=None,
|
| 94 |
+
id2label: Optional[Dict[Union[int, str], str]] = None,
|
| 95 |
+
) -> None:
|
| 96 |
+
super().__init__()
|
| 97 |
+
self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet)
|
| 98 |
+
self._id2label = self._normalize_id2label(id2label)
|
| 99 |
+
self.labels = self._build_label2id(self._id2label)
|
| 100 |
+
self._labels_loaded_from_model_index = bool(self._id2label)
|
| 101 |
+
self.vae_scale_factor = 8 if self.vae is not None else 1
|
| 102 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
|
| 103 |
+
|
| 104 |
+
@staticmethod
|
| 105 |
+
def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
|
| 106 |
+
if not id2label:
|
| 107 |
+
return {}
|
| 108 |
+
return {int(key): value for key, value in id2label.items()}
|
| 109 |
+
|
| 110 |
+
@staticmethod
|
| 111 |
+
def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
|
| 112 |
+
label2id: Dict[str, int] = {}
|
| 113 |
+
for class_id, value in id2label.items():
|
| 114 |
+
for synonym in value.split(","):
|
| 115 |
+
synonym = synonym.strip()
|
| 116 |
+
if synonym:
|
| 117 |
+
label2id[synonym] = int(class_id)
|
| 118 |
+
return dict(sorted(label2id.items()))
|
| 119 |
+
|
| 120 |
+
def _ensure_labels_loaded(self) -> None:
|
| 121 |
+
if self._labels_loaded_from_model_index:
|
| 122 |
+
return
|
| 123 |
+
loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
|
| 124 |
+
if loaded:
|
| 125 |
+
self._id2label = loaded
|
| 126 |
+
self.labels = self._build_label2id(self._id2label)
|
| 127 |
+
self._labels_loaded_from_model_index = True
|
| 128 |
+
|
| 129 |
+
@staticmethod
|
| 130 |
+
def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
|
| 131 |
+
if not variant_path:
|
| 132 |
+
return {}
|
| 133 |
+
model_index_path = Path(variant_path).resolve() / "model_index.json"
|
| 134 |
+
if not model_index_path.is_file():
|
| 135 |
+
return {}
|
| 136 |
+
raw = json.loads(model_index_path.read_text(encoding="utf-8"))
|
| 137 |
+
id2label = raw.get("id2label")
|
| 138 |
+
if not isinstance(id2label, dict):
|
| 139 |
+
return {}
|
| 140 |
+
return {int(key): value for key, value in id2label.items()}
|
| 141 |
+
|
| 142 |
+
@property
|
| 143 |
+
def id2label(self) -> Dict[int, str]:
|
| 144 |
+
r"""ImageNet class id to English label string (comma-separated synonyms)."""
|
| 145 |
+
self._ensure_labels_loaded()
|
| 146 |
+
return self._id2label
|
| 147 |
+
|
| 148 |
+
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
|
| 149 |
+
r"""
|
| 150 |
+
Map ImageNet label strings to class ids.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
label (`str` or `list[str]`):
|
| 154 |
+
One or more English label strings that match entries in `id2label`.
|
| 155 |
+
"""
|
| 156 |
+
self._ensure_labels_loaded()
|
| 157 |
+
if not self.labels:
|
| 158 |
+
raise ValueError("No English labels loaded. Add `id2label` to model_index.json.")
|
| 159 |
+
labels = [label] if isinstance(label, str) else list(label)
|
| 160 |
+
missing = [item for item in labels if item not in self.labels]
|
| 161 |
+
if missing:
|
| 162 |
+
preview = ", ".join(list(self.labels.keys())[:8])
|
| 163 |
+
raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
|
| 164 |
+
return [self.labels[item] for item in labels]
|
| 165 |
+
|
| 166 |
+
def _default_image_size(self) -> int:
|
| 167 |
+
latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64)))
|
| 168 |
+
return latent_size * self.vae_scale_factor
|
| 169 |
+
|
| 170 |
+
def check_inputs(
|
| 171 |
+
self,
|
| 172 |
+
height: int,
|
| 173 |
+
width: int,
|
| 174 |
+
num_inference_steps: int,
|
| 175 |
+
guidance_scale: float,
|
| 176 |
+
output_type: str,
|
| 177 |
+
) -> None:
|
| 178 |
+
if num_inference_steps < 1:
|
| 179 |
+
raise ValueError("num_inference_steps must be >= 1.")
|
| 180 |
+
if guidance_scale < 1.0:
|
| 181 |
+
raise ValueError("guidance_scale must be >= 1.0.")
|
| 182 |
+
if guidance_scale > 1.0 and self.gnet is None:
|
| 183 |
+
raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).")
|
| 184 |
+
if output_type not in {"pil", "np", "pt", "latent"}:
|
| 185 |
+
raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
|
| 186 |
+
|
| 187 |
+
native_size = self._default_image_size()
|
| 188 |
+
if height != native_size or width != native_size:
|
| 189 |
+
raise ValueError(
|
| 190 |
+
f"EDM2 expects native resolution height=width={native_size}. "
|
| 191 |
+
f"Got height={height}, width={width}."
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
def _normalize_class_labels(
|
| 195 |
+
self,
|
| 196 |
+
class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]],
|
| 197 |
+
batch_size: int,
|
| 198 |
+
device: torch.device,
|
| 199 |
+
) -> Optional[torch.Tensor]:
|
| 200 |
+
label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0)))
|
| 201 |
+
if label_dim == 0:
|
| 202 |
+
return None
|
| 203 |
+
if class_labels is None:
|
| 204 |
+
indices = torch.randint(label_dim, size=(batch_size,), device=device)
|
| 205 |
+
return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
|
| 206 |
+
|
| 207 |
+
if isinstance(class_labels, str):
|
| 208 |
+
class_labels = self.get_label_ids(class_labels)[0]
|
| 209 |
+
elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str):
|
| 210 |
+
class_labels = self.get_label_ids(list(class_labels))
|
| 211 |
+
|
| 212 |
+
if isinstance(class_labels, int):
|
| 213 |
+
indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long)
|
| 214 |
+
elif isinstance(class_labels, torch.Tensor):
|
| 215 |
+
if class_labels.ndim == 2:
|
| 216 |
+
labels = class_labels.to(device=device, dtype=torch.float32)
|
| 217 |
+
if labels.shape[0] != batch_size:
|
| 218 |
+
raise ValueError(f"class_labels batch must match batch_size={batch_size}.")
|
| 219 |
+
return labels
|
| 220 |
+
indices = class_labels.to(device=device, dtype=torch.long).flatten()
|
| 221 |
+
else:
|
| 222 |
+
indices = torch.tensor(list(class_labels), device=device, dtype=torch.long)
|
| 223 |
+
|
| 224 |
+
if indices.numel() == 1 and batch_size > 1:
|
| 225 |
+
indices = indices.repeat(batch_size)
|
| 226 |
+
if indices.numel() != batch_size:
|
| 227 |
+
raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.")
|
| 228 |
+
return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
|
| 229 |
+
|
| 230 |
+
def prepare_latents(
|
| 231 |
+
self,
|
| 232 |
+
batch_size: int,
|
| 233 |
+
height: int,
|
| 234 |
+
width: int,
|
| 235 |
+
dtype: torch.dtype,
|
| 236 |
+
device: torch.device,
|
| 237 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]],
|
| 238 |
+
) -> torch.Tensor:
|
| 239 |
+
in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4)))
|
| 240 |
+
latent_size = height // self.vae_scale_factor
|
| 241 |
+
return randn_tensor(
|
| 242 |
+
(batch_size, in_channels, latent_size, latent_size),
|
| 243 |
+
generator=generator,
|
| 244 |
+
device=device,
|
| 245 |
+
dtype=torch.float32,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
|
| 249 |
+
if output_type == "latent":
|
| 250 |
+
return latents
|
| 251 |
+
|
| 252 |
+
in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 3)))
|
| 253 |
+
if self.vae is None:
|
| 254 |
+
image = (latents.to(torch.float32) * 127.5 + 128).clip(0, 255) / 255.0
|
| 255 |
+
return self.image_processor.postprocess(image, output_type=output_type)
|
| 256 |
+
|
| 257 |
+
if in_channels == 4:
|
| 258 |
+
x = latents.to(torch.float32)
|
| 259 |
+
scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
|
| 260 |
+
bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
|
| 261 |
+
x = (x - bias) / scale
|
| 262 |
+
else:
|
| 263 |
+
x = latents.to(torch.float32)
|
| 264 |
+
|
| 265 |
+
vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype
|
| 266 |
+
image = self.vae.decode(x.to(dtype=vae_dtype)).sample.to(torch.float32).clamp(0, 1)
|
| 267 |
+
|
| 268 |
+
return self.image_processor.postprocess(image, output_type=output_type)
|
| 269 |
+
|
| 270 |
+
@staticmethod
|
| 271 |
+
def _apply_autoguidance(
|
| 272 |
+
main: torch.Tensor,
|
| 273 |
+
ref: torch.Tensor,
|
| 274 |
+
guidance_scale: float,
|
| 275 |
+
) -> torch.Tensor:
|
| 276 |
+
return ref.lerp(main, guidance_scale)
|
| 277 |
+
|
| 278 |
+
@staticmethod
|
| 279 |
+
def _sample_edm2_heun(
|
| 280 |
+
denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
| 281 |
+
noise: torch.Tensor,
|
| 282 |
+
sigmas: torch.Tensor,
|
| 283 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 284 |
+
progress_bar: Optional[Callable[[Iterable], Iterable]] = None,
|
| 285 |
+
dtype: torch.dtype = torch.float32,
|
| 286 |
+
) -> torch.Tensor:
|
| 287 |
+
"""NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0)."""
|
| 288 |
+
x_next = noise.to(dtype) * sigmas[0]
|
| 289 |
+
|
| 290 |
+
sigma_pairs = list(zip(sigmas[:-1], sigmas[1:]))
|
| 291 |
+
if progress_bar is not None:
|
| 292 |
+
sigma_pairs = progress_bar(sigma_pairs)
|
| 293 |
+
|
| 294 |
+
num_steps = len(sigma_pairs)
|
| 295 |
+
for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs):
|
| 296 |
+
x_hat, sigma_hat = x_next, sigma_cur
|
| 297 |
+
d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat
|
| 298 |
+
x_next = x_hat + (sigma_next - sigma_hat) * d_cur
|
| 299 |
+
if i < num_steps - 1:
|
| 300 |
+
d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next
|
| 301 |
+
x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime)
|
| 302 |
+
return x_next
|
| 303 |
+
|
| 304 |
+
@torch.inference_mode()
|
| 305 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 306 |
+
def __call__(
|
| 307 |
+
self,
|
| 308 |
+
class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None,
|
| 309 |
+
batch_size: int = 1,
|
| 310 |
+
height: Optional[int] = None,
|
| 311 |
+
width: Optional[int] = None,
|
| 312 |
+
num_inference_steps: int = 32,
|
| 313 |
+
guidance_scale: float = 1.0,
|
| 314 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 315 |
+
output_type: str = "pil",
|
| 316 |
+
return_dict: bool = True,
|
| 317 |
+
) -> Union[ImagePipelineOutput, Tuple]:
|
| 318 |
+
r"""
|
| 319 |
+
Generate class-conditional images with EDM2.
|
| 320 |
+
|
| 321 |
+
Args:
|
| 322 |
+
class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*):
|
| 323 |
+
ImageNet class indices, English label strings, or one-hot float tensors.
|
| 324 |
+
Random classes are sampled when omitted on conditional models.
|
| 325 |
+
batch_size (`int`, defaults to `1`):
|
| 326 |
+
Number of images to generate.
|
| 327 |
+
height (`int`, *optional*):
|
| 328 |
+
Output height in pixels. Defaults to the pretrained native resolution.
|
| 329 |
+
width (`int`, *optional*):
|
| 330 |
+
Output width in pixels. Defaults to the pretrained native resolution.
|
| 331 |
+
num_inference_steps (`int`, defaults to `32`):
|
| 332 |
+
Number of EDM2 Heun steps (NVlabs default).
|
| 333 |
+
guidance_scale (`float`, defaults to `1.0`):
|
| 334 |
+
Autoguidance strength. Values above `1.0` blend the main net with `gnet`
|
| 335 |
+
via `gnet_output.lerp(unet_output, guidance_scale)`.
|
| 336 |
+
generator (`torch.Generator`, *optional*):
|
| 337 |
+
RNG for reproducibility.
|
| 338 |
+
output_type (`str`, defaults to `"pil"`):
|
| 339 |
+
`"pil"`, `"np"`, `"pt"`, or `"latent"`.
|
| 340 |
+
return_dict (`bool`, defaults to `True`):
|
| 341 |
+
Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if True.
|
| 342 |
+
|
| 343 |
+
Examples:
|
| 344 |
+
<!-- this section is replaced by replace_example_docstring -->
|
| 345 |
+
"""
|
| 346 |
+
default_size = self._default_image_size()
|
| 347 |
+
height = int(height or default_size)
|
| 348 |
+
width = int(width or default_size)
|
| 349 |
+
self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type)
|
| 350 |
+
|
| 351 |
+
device = self._execution_device
|
| 352 |
+
dtype = self.unet.dtype
|
| 353 |
+
labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
|
| 354 |
+
noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
|
| 355 |
+
|
| 356 |
+
def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
|
| 357 |
+
sigma_batch = sigma.reshape(1).expand(batch_size)
|
| 358 |
+
main = self.unet(
|
| 359 |
+
sample=x,
|
| 360 |
+
sigma=sigma_batch,
|
| 361 |
+
class_labels=labels,
|
| 362 |
+
force_fp32=True,
|
| 363 |
+
).sample
|
| 364 |
+
if guidance_scale == 1.0 or self.gnet is None:
|
| 365 |
+
return main.to(torch.float32)
|
| 366 |
+
ref = self.gnet(
|
| 367 |
+
sample=x,
|
| 368 |
+
sigma=sigma_batch,
|
| 369 |
+
class_labels=labels,
|
| 370 |
+
force_fp32=True,
|
| 371 |
+
).sample
|
| 372 |
+
return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32)
|
| 373 |
+
|
| 374 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 375 |
+
latents = self._sample_edm2_heun(
|
| 376 |
+
denoise_fn=denoise_fn,
|
| 377 |
+
noise=noise,
|
| 378 |
+
sigmas=self.scheduler.sigmas.to(device),
|
| 379 |
+
generator=generator,
|
| 380 |
+
progress_bar=self.progress_bar,
|
| 381 |
+
dtype=torch.float32,
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
image = self.decode_latents(latents, output_type=output_type)
|
| 385 |
+
if not return_dict:
|
| 386 |
+
return (image, latents)
|
| 387 |
+
return ImagePipelineOutput(images=image)
|
| 388 |
+
|
| 389 |
+
@classmethod
|
| 390 |
+
def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None):
|
| 391 |
+
vae_dir = os.path.join(pretrained_model_name_or_path, "vae")
|
| 392 |
+
if os.path.isdir(vae_dir):
|
| 393 |
+
try:
|
| 394 |
+
|
| 395 |
+
return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype)
|
| 396 |
+
except Exception:
|
| 397 |
+
return None
|
| 398 |
+
|
| 399 |
+
vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt")
|
| 400 |
+
if os.path.isfile(vae_hint):
|
| 401 |
+
with open(vae_hint, "r", encoding="utf-8") as f:
|
| 402 |
+
hub_id = f.read().strip()
|
| 403 |
+
if hub_id:
|
| 404 |
+
|
| 405 |
+
return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype)
|
| 406 |
+
return None
|
edm2-img512-l-fid/scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "EDMEulerScheduler",
|
| 3 |
+
"final_sigmas_type": "zero",
|
| 4 |
+
"num_train_timesteps": 1000,
|
| 5 |
+
"prediction_type": "epsilon",
|
| 6 |
+
"rho": 7.0,
|
| 7 |
+
"sigma_data": 0.5,
|
| 8 |
+
"sigma_max": 80.0,
|
| 9 |
+
"sigma_min": 0.002,
|
| 10 |
+
"sigma_schedule": "karras"
|
| 11 |
+
}
|
edm2-img512-l-fid/unet/config.json
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "EDM2UNet2DModel",
|
| 3 |
+
"attn_balance": 0.3,
|
| 4 |
+
"attn_resolutions": [
|
| 5 |
+
16,
|
| 6 |
+
8
|
| 7 |
+
],
|
| 8 |
+
"channel_mult": [
|
| 9 |
+
1,
|
| 10 |
+
2,
|
| 11 |
+
3,
|
| 12 |
+
4
|
| 13 |
+
],
|
| 14 |
+
"channel_mult_emb": 4,
|
| 15 |
+
"channel_mult_noise": 1,
|
| 16 |
+
"channels_per_head": 64,
|
| 17 |
+
"clip_act": 256,
|
| 18 |
+
"concat_balance": 0.5,
|
| 19 |
+
"dropout": 0.0,
|
| 20 |
+
"in_channels": 4,
|
| 21 |
+
"label_balance": 0.5,
|
| 22 |
+
"logvar_channels": 128,
|
| 23 |
+
"model_channels": 320,
|
| 24 |
+
"num_blocks": 3,
|
| 25 |
+
"num_class_embeds": 1000,
|
| 26 |
+
"out_channels": 4,
|
| 27 |
+
"res_balance": 0.3,
|
| 28 |
+
"sample_size": 64,
|
| 29 |
+
"sigma_data": 0.5,
|
| 30 |
+
"use_fp16": true
|
| 31 |
+
}
|
edm2-img512-l-fid/unet/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7a3e3f5127c12027e4796bef297e247a38ddd13bb7b8445c5d41169106b94389
|
| 3 |
+
size 3110018564
|
edm2-img512-l-fid/unet/unet_edm2.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 12 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 13 |
+
from diffusers.utils import BaseOutput
|
| 14 |
+
except ImportError: # pragma: no cover
|
| 15 |
+
class ModelMixin(torch.nn.Module):
|
| 16 |
+
pass
|
| 17 |
+
|
| 18 |
+
class ConfigMixin:
|
| 19 |
+
config = {}
|
| 20 |
+
|
| 21 |
+
def register_to_config(self, **kwargs):
|
| 22 |
+
self.config = kwargs
|
| 23 |
+
|
| 24 |
+
def register_to_config(func):
|
| 25 |
+
return func
|
| 26 |
+
|
| 27 |
+
@dataclass
|
| 28 |
+
class BaseOutput:
|
| 29 |
+
pass
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor:
|
| 33 |
+
if dim is None:
|
| 34 |
+
dim = list(range(1, x.ndim))
|
| 35 |
+
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
|
| 36 |
+
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
|
| 37 |
+
return x / norm.to(x.dtype)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor:
|
| 41 |
+
if mode == "keep":
|
| 42 |
+
return x
|
| 43 |
+
filt = np.float32(f)
|
| 44 |
+
pad = (len(filt) - 1) // 2
|
| 45 |
+
filt = filt / filt.sum()
|
| 46 |
+
filt = np.outer(filt, filt)[np.newaxis, np.newaxis, :, :]
|
| 47 |
+
filt = torch.as_tensor(filt, dtype=x.dtype, device=x.device)
|
| 48 |
+
c = x.shape[1]
|
| 49 |
+
if mode == "down":
|
| 50 |
+
return torch.nn.functional.conv2d(x, filt.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
|
| 51 |
+
return torch.nn.functional.conv_transpose2d(x, (filt * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def mp_silu(x: torch.Tensor) -> torch.Tensor:
|
| 55 |
+
return torch.nn.functional.silu(x) / 0.596
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor:
|
| 59 |
+
return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor:
|
| 63 |
+
na = a.shape[dim]
|
| 64 |
+
nb = b.shape[dim]
|
| 65 |
+
c = math.sqrt((na + nb) / ((1 - t) ** 2 + t ** 2))
|
| 66 |
+
wa = c / math.sqrt(na) * (1 - t)
|
| 67 |
+
wb = c / math.sqrt(nb) * t
|
| 68 |
+
return torch.cat([wa * a, wb * b], dim=dim)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class MPFourier(torch.nn.Module):
|
| 72 |
+
def __init__(self, num_channels: int, bandwidth: float = 1):
|
| 73 |
+
super().__init__()
|
| 74 |
+
self.register_buffer("freqs", 2 * math.pi * torch.randn(num_channels) * bandwidth)
|
| 75 |
+
self.register_buffer("phases", 2 * math.pi * torch.rand(num_channels))
|
| 76 |
+
|
| 77 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 78 |
+
y = x.to(torch.float32).ger(self.freqs.to(torch.float32))
|
| 79 |
+
y = y + self.phases.to(torch.float32)
|
| 80 |
+
y = y.cos() * math.sqrt(2)
|
| 81 |
+
return y.to(x.dtype)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class MPConv(torch.nn.Module):
|
| 85 |
+
def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]):
|
| 86 |
+
super().__init__()
|
| 87 |
+
self.out_channels = out_channels
|
| 88 |
+
self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))
|
| 89 |
+
|
| 90 |
+
def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor:
|
| 91 |
+
w = self.weight.to(torch.float32)
|
| 92 |
+
if self.training:
|
| 93 |
+
with torch.no_grad():
|
| 94 |
+
self.weight.copy_(normalize(w))
|
| 95 |
+
w = normalize(w)
|
| 96 |
+
w = w * (gain / math.sqrt(w[0].numel()))
|
| 97 |
+
w = w.to(x.dtype)
|
| 98 |
+
if w.ndim == 2:
|
| 99 |
+
return x @ w.t()
|
| 100 |
+
return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1] // 2,))
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class Block(torch.nn.Module):
|
| 104 |
+
def __init__(
|
| 105 |
+
self,
|
| 106 |
+
in_channels: int,
|
| 107 |
+
out_channels: int,
|
| 108 |
+
emb_channels: int,
|
| 109 |
+
flavor: str = "enc",
|
| 110 |
+
resample_mode: str = "keep",
|
| 111 |
+
resample_filter: List[float] = [1, 1],
|
| 112 |
+
attention: bool = False,
|
| 113 |
+
channels_per_head: int = 64,
|
| 114 |
+
dropout: float = 0.0,
|
| 115 |
+
res_balance: float = 0.3,
|
| 116 |
+
attn_balance: float = 0.3,
|
| 117 |
+
clip_act: Optional[float] = 256,
|
| 118 |
+
):
|
| 119 |
+
super().__init__()
|
| 120 |
+
self.out_channels = out_channels
|
| 121 |
+
self.flavor = flavor
|
| 122 |
+
self.resample_filter = resample_filter
|
| 123 |
+
self.resample_mode = resample_mode
|
| 124 |
+
self.num_heads = out_channels // channels_per_head if attention else 0
|
| 125 |
+
self.dropout = dropout
|
| 126 |
+
self.res_balance = res_balance
|
| 127 |
+
self.attn_balance = attn_balance
|
| 128 |
+
self.clip_act = clip_act
|
| 129 |
+
self.emb_gain = torch.nn.Parameter(torch.zeros([]))
|
| 130 |
+
self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3))
|
| 131 |
+
self.emb_linear = MPConv(emb_channels, out_channels, kernel=())
|
| 132 |
+
self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3))
|
| 133 |
+
self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None
|
| 134 |
+
self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None
|
| 135 |
+
self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None
|
| 136 |
+
|
| 137 |
+
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
|
| 138 |
+
x = resample(x, f=self.resample_filter, mode=self.resample_mode)
|
| 139 |
+
if self.flavor == "enc":
|
| 140 |
+
if self.conv_skip is not None:
|
| 141 |
+
x = self.conv_skip(x)
|
| 142 |
+
x = normalize(x, dim=[1])
|
| 143 |
+
|
| 144 |
+
y = self.conv_res0(mp_silu(x))
|
| 145 |
+
c = self.emb_linear(emb, gain=self.emb_gain) + 1
|
| 146 |
+
y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype))
|
| 147 |
+
if self.training and self.dropout:
|
| 148 |
+
y = torch.nn.functional.dropout(y, p=self.dropout)
|
| 149 |
+
y = self.conv_res1(y)
|
| 150 |
+
|
| 151 |
+
if self.flavor == "dec" and self.conv_skip is not None:
|
| 152 |
+
x = self.conv_skip(x)
|
| 153 |
+
x = mp_sum(x, y, t=self.res_balance)
|
| 154 |
+
|
| 155 |
+
if self.num_heads:
|
| 156 |
+
y = self.attn_qkv(x)
|
| 157 |
+
y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3])
|
| 158 |
+
q, k, v = normalize(y, dim=[2]).unbind(3)
|
| 159 |
+
w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3)
|
| 160 |
+
y = torch.einsum("nhqk,nhck->nhcq", w, v)
|
| 161 |
+
y = self.attn_proj(y.reshape(*x.shape))
|
| 162 |
+
x = mp_sum(x, y, t=self.attn_balance)
|
| 163 |
+
|
| 164 |
+
if self.clip_act is not None:
|
| 165 |
+
x = x.clip_(-self.clip_act, self.clip_act)
|
| 166 |
+
return x
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class EDM2UNet(torch.nn.Module):
|
| 170 |
+
def __init__(
|
| 171 |
+
self,
|
| 172 |
+
img_resolution: int,
|
| 173 |
+
img_channels: int,
|
| 174 |
+
label_dim: int,
|
| 175 |
+
model_channels: int = 192,
|
| 176 |
+
channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
|
| 177 |
+
channel_mult_noise: Optional[int] = None,
|
| 178 |
+
channel_mult_emb: Optional[int] = None,
|
| 179 |
+
num_blocks: int = 3,
|
| 180 |
+
attn_resolutions: Tuple[int, ...] = (16, 8),
|
| 181 |
+
label_balance: float = 0.5,
|
| 182 |
+
concat_balance: float = 0.5,
|
| 183 |
+
**block_kwargs,
|
| 184 |
+
):
|
| 185 |
+
super().__init__()
|
| 186 |
+
cblock = [model_channels * x for x in channel_mult]
|
| 187 |
+
cnoise = model_channels * channel_mult_noise if channel_mult_noise is not None else cblock[0]
|
| 188 |
+
cemb = model_channels * channel_mult_emb if channel_mult_emb is not None else max(cblock)
|
| 189 |
+
self.label_balance = label_balance
|
| 190 |
+
self.concat_balance = concat_balance
|
| 191 |
+
self.out_gain = torch.nn.Parameter(torch.zeros([]))
|
| 192 |
+
|
| 193 |
+
self.emb_fourier = MPFourier(cnoise)
|
| 194 |
+
self.emb_noise = MPConv(cnoise, cemb, kernel=())
|
| 195 |
+
self.emb_label = MPConv(label_dim, cemb, kernel=()) if label_dim else None
|
| 196 |
+
|
| 197 |
+
self.enc = torch.nn.ModuleDict()
|
| 198 |
+
cout = img_channels + 1
|
| 199 |
+
for level, channels in enumerate(cblock):
|
| 200 |
+
res = img_resolution >> level
|
| 201 |
+
if level == 0:
|
| 202 |
+
cin = cout
|
| 203 |
+
cout = channels
|
| 204 |
+
self.enc[f"{res}x{res}_conv"] = MPConv(cin, cout, kernel=(3, 3))
|
| 205 |
+
else:
|
| 206 |
+
self.enc[f"{res}x{res}_down"] = Block(cout, cout, cemb, flavor="enc", resample_mode="down", **block_kwargs)
|
| 207 |
+
for idx in range(num_blocks):
|
| 208 |
+
cin = cout
|
| 209 |
+
cout = channels
|
| 210 |
+
self.enc[f"{res}x{res}_block{idx}"] = Block(
|
| 211 |
+
cin,
|
| 212 |
+
cout,
|
| 213 |
+
cemb,
|
| 214 |
+
flavor="enc",
|
| 215 |
+
attention=(res in attn_resolutions),
|
| 216 |
+
**block_kwargs,
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
self.dec = torch.nn.ModuleDict()
|
| 220 |
+
skips = [block.out_channels for block in self.enc.values()]
|
| 221 |
+
for level, channels in reversed(list(enumerate(cblock))):
|
| 222 |
+
res = img_resolution >> level
|
| 223 |
+
if level == len(cblock) - 1:
|
| 224 |
+
self.dec[f"{res}x{res}_in0"] = Block(cout, cout, cemb, flavor="dec", attention=True, **block_kwargs)
|
| 225 |
+
self.dec[f"{res}x{res}_in1"] = Block(cout, cout, cemb, flavor="dec", **block_kwargs)
|
| 226 |
+
else:
|
| 227 |
+
self.dec[f"{res}x{res}_up"] = Block(cout, cout, cemb, flavor="dec", resample_mode="up", **block_kwargs)
|
| 228 |
+
for idx in range(num_blocks + 1):
|
| 229 |
+
cin = cout + skips.pop()
|
| 230 |
+
cout = channels
|
| 231 |
+
self.dec[f"{res}x{res}_block{idx}"] = Block(
|
| 232 |
+
cin,
|
| 233 |
+
cout,
|
| 234 |
+
cemb,
|
| 235 |
+
flavor="dec",
|
| 236 |
+
attention=(res in attn_resolutions),
|
| 237 |
+
**block_kwargs,
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
self.out_conv = MPConv(cout, img_channels, kernel=(3, 3))
|
| 241 |
+
|
| 242 |
+
def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor:
|
| 243 |
+
emb = self.emb_noise(self.emb_fourier(noise_labels))
|
| 244 |
+
if self.emb_label is not None:
|
| 245 |
+
if class_labels is None:
|
| 246 |
+
raise ValueError("class_labels are required for conditional EDM2UNet.")
|
| 247 |
+
emb = mp_sum(emb, self.emb_label(class_labels * math.sqrt(class_labels.shape[1])), t=self.label_balance)
|
| 248 |
+
emb = mp_silu(emb)
|
| 249 |
+
|
| 250 |
+
x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1)
|
| 251 |
+
skips = []
|
| 252 |
+
for name, block in self.enc.items():
|
| 253 |
+
x = block(x) if "conv" in name else block(x, emb)
|
| 254 |
+
skips.append(x)
|
| 255 |
+
|
| 256 |
+
for name, block in self.dec.items():
|
| 257 |
+
if "block" in name:
|
| 258 |
+
x = mp_cat(x, skips.pop(), t=self.concat_balance)
|
| 259 |
+
x = block(x, emb)
|
| 260 |
+
return self.out_conv(x, gain=self.out_gain)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
@dataclass
|
| 264 |
+
class EDM2UNet2DOutput(BaseOutput):
|
| 265 |
+
sample: torch.Tensor
|
| 266 |
+
logvar: Optional[torch.Tensor] = None
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
_CONFIG_KEYS = (
|
| 271 |
+
"sample_size",
|
| 272 |
+
"in_channels",
|
| 273 |
+
"out_channels",
|
| 274 |
+
"num_class_embeds",
|
| 275 |
+
"use_fp16",
|
| 276 |
+
"sigma_data",
|
| 277 |
+
"logvar_channels",
|
| 278 |
+
"model_channels",
|
| 279 |
+
"channel_mult",
|
| 280 |
+
"channel_mult_noise",
|
| 281 |
+
"channel_mult_emb",
|
| 282 |
+
"num_blocks",
|
| 283 |
+
"attn_resolutions",
|
| 284 |
+
"label_balance",
|
| 285 |
+
"concat_balance",
|
| 286 |
+
"dropout",
|
| 287 |
+
"channels_per_head",
|
| 288 |
+
"res_balance",
|
| 289 |
+
"attn_balance",
|
| 290 |
+
"clip_act",
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class EDM2UNet2DModel(ModelMixin, ConfigMixin):
|
| 295 |
+
@register_to_config
|
| 296 |
+
def __init__(
|
| 297 |
+
self,
|
| 298 |
+
sample_size: int = 64,
|
| 299 |
+
in_channels: int = 4,
|
| 300 |
+
out_channels: int = 4,
|
| 301 |
+
num_class_embeds: int = 0,
|
| 302 |
+
use_fp16: bool = True,
|
| 303 |
+
sigma_data: float = 0.5,
|
| 304 |
+
logvar_channels: int = 128,
|
| 305 |
+
model_channels: int = 192,
|
| 306 |
+
channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
|
| 307 |
+
channel_mult_noise: Optional[int] = None,
|
| 308 |
+
channel_mult_emb: Optional[int] = None,
|
| 309 |
+
num_blocks: int = 3,
|
| 310 |
+
attn_resolutions: Tuple[int, ...] = (16, 8),
|
| 311 |
+
label_balance: float = 0.5,
|
| 312 |
+
concat_balance: float = 0.5,
|
| 313 |
+
dropout: float = 0.0,
|
| 314 |
+
channels_per_head: int = 64,
|
| 315 |
+
res_balance: float = 0.3,
|
| 316 |
+
attn_balance: float = 0.3,
|
| 317 |
+
clip_act: Optional[float] = 256,
|
| 318 |
+
):
|
| 319 |
+
super().__init__()
|
| 320 |
+
self.sample_size = sample_size
|
| 321 |
+
self.in_channels = in_channels
|
| 322 |
+
self.out_channels = out_channels
|
| 323 |
+
self.num_class_embeds = num_class_embeds
|
| 324 |
+
self.use_fp16 = use_fp16
|
| 325 |
+
self.sigma_data = sigma_data
|
| 326 |
+
self.model_channels = model_channels
|
| 327 |
+
self.channel_mult = channel_mult
|
| 328 |
+
self.channel_mult_noise = channel_mult_noise
|
| 329 |
+
self.channel_mult_emb = channel_mult_emb
|
| 330 |
+
self.num_blocks = num_blocks
|
| 331 |
+
self.attn_resolutions = attn_resolutions
|
| 332 |
+
self.label_balance = label_balance
|
| 333 |
+
self.concat_balance = concat_balance
|
| 334 |
+
self.dropout = dropout
|
| 335 |
+
self.channels_per_head = channels_per_head
|
| 336 |
+
self.res_balance = res_balance
|
| 337 |
+
self.attn_balance = attn_balance
|
| 338 |
+
self.clip_act = clip_act
|
| 339 |
+
self.unet = EDM2UNet(
|
| 340 |
+
img_resolution=sample_size,
|
| 341 |
+
img_channels=in_channels,
|
| 342 |
+
label_dim=num_class_embeds,
|
| 343 |
+
model_channels=model_channels,
|
| 344 |
+
channel_mult=channel_mult,
|
| 345 |
+
channel_mult_noise=channel_mult_noise,
|
| 346 |
+
channel_mult_emb=channel_mult_emb,
|
| 347 |
+
num_blocks=num_blocks,
|
| 348 |
+
attn_resolutions=attn_resolutions,
|
| 349 |
+
label_balance=label_balance,
|
| 350 |
+
concat_balance=concat_balance,
|
| 351 |
+
dropout=dropout,
|
| 352 |
+
channels_per_head=channels_per_head,
|
| 353 |
+
res_balance=res_balance,
|
| 354 |
+
attn_balance=attn_balance,
|
| 355 |
+
clip_act=clip_act,
|
| 356 |
+
)
|
| 357 |
+
self.logvar_fourier = MPFourier(logvar_channels)
|
| 358 |
+
self.logvar_linear = MPConv(logvar_channels, 1, kernel=())
|
| 359 |
+
|
| 360 |
+
def forward(
|
| 361 |
+
self,
|
| 362 |
+
sample: torch.Tensor,
|
| 363 |
+
sigma: torch.Tensor,
|
| 364 |
+
class_labels: Optional[torch.Tensor] = None,
|
| 365 |
+
force_fp32: bool = False,
|
| 366 |
+
return_logvar: bool = False,
|
| 367 |
+
return_dict: bool = True,
|
| 368 |
+
) -> EDM2UNet2DOutput:
|
| 369 |
+
x = sample.to(torch.float32)
|
| 370 |
+
sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
|
| 371 |
+
if self.num_class_embeds == 0:
|
| 372 |
+
class_labels = None
|
| 373 |
+
else:
|
| 374 |
+
if class_labels is None:
|
| 375 |
+
class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device)
|
| 376 |
+
class_labels = class_labels.to(torch.float32).reshape(-1, self.num_class_embeds)
|
| 377 |
+
dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") else torch.float32
|
| 378 |
+
|
| 379 |
+
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
|
| 380 |
+
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
|
| 381 |
+
c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
|
| 382 |
+
c_noise = sigma.flatten().log() / 4
|
| 383 |
+
|
| 384 |
+
x_in = (c_in * x).to(dtype)
|
| 385 |
+
f_x = self.unet(x_in, c_noise, class_labels)
|
| 386 |
+
d_x = c_skip * x + c_out * f_x.to(torch.float32)
|
| 387 |
+
|
| 388 |
+
logvar = None
|
| 389 |
+
if return_logvar:
|
| 390 |
+
logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)
|
| 391 |
+
|
| 392 |
+
if not return_dict:
|
| 393 |
+
return (d_x, logvar)
|
| 394 |
+
return EDM2UNet2DOutput(sample=d_x, logvar=logvar)
|
| 395 |
+
|
| 396 |
+
@classmethod
|
| 397 |
+
def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs):
|
| 398 |
+
subfolder = kwargs.pop("subfolder", None)
|
| 399 |
+
model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path
|
| 400 |
+
with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f:
|
| 401 |
+
config = json.load(f)
|
| 402 |
+
init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS}
|
| 403 |
+
model = cls(**init_kwargs)
|
| 404 |
+
weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors")
|
| 405 |
+
if os.path.isfile(weight_file):
|
| 406 |
+
from safetensors.torch import load_file
|
| 407 |
+
|
| 408 |
+
state_dict = load_file(weight_file)
|
| 409 |
+
else:
|
| 410 |
+
state_dict = torch.load(os.path.join(model_dir, "diffusion_pytorch_model.bin"), map_location="cpu")
|
| 411 |
+
model.load_state_dict(state_dict, strict=True)
|
| 412 |
+
if torch_dtype is not None:
|
| 413 |
+
model = model.to(dtype=torch_dtype)
|
| 414 |
+
return model
|
| 415 |
+
|
| 416 |
+
def save_pretrained(self, save_directory: str, safe_serialization: bool = True):
|
| 417 |
+
os.makedirs(save_directory, exist_ok=True)
|
| 418 |
+
stored = dict(getattr(self, "config", {}))
|
| 419 |
+
config = {"_class_name": self.__class__.__name__}
|
| 420 |
+
for key in _CONFIG_KEYS:
|
| 421 |
+
if key in stored:
|
| 422 |
+
config[key] = stored[key]
|
| 423 |
+
elif hasattr(self, key):
|
| 424 |
+
config[key] = getattr(self, key)
|
| 425 |
+
with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f:
|
| 426 |
+
json.dump(config, f, indent=2, sort_keys=True)
|
| 427 |
+
f.write("\n")
|
| 428 |
+
state_dict = self.state_dict()
|
| 429 |
+
if safe_serialization:
|
| 430 |
+
from safetensors.torch import save_file
|
| 431 |
+
|
| 432 |
+
save_file(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.safetensors"))
|
| 433 |
+
else:
|
| 434 |
+
torch.save(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.bin"))
|
edm2-img512-l-fid/vae/config.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKL",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"_name_or_path": "stabilityai/sd-vae-ft-mse",
|
| 5 |
+
"act_fn": "silu",
|
| 6 |
+
"block_out_channels": [
|
| 7 |
+
128,
|
| 8 |
+
256,
|
| 9 |
+
512,
|
| 10 |
+
512
|
| 11 |
+
],
|
| 12 |
+
"down_block_types": [
|
| 13 |
+
"DownEncoderBlock2D",
|
| 14 |
+
"DownEncoderBlock2D",
|
| 15 |
+
"DownEncoderBlock2D",
|
| 16 |
+
"DownEncoderBlock2D"
|
| 17 |
+
],
|
| 18 |
+
"force_upcast": true,
|
| 19 |
+
"in_channels": 3,
|
| 20 |
+
"latent_channels": 4,
|
| 21 |
+
"latents_mean": null,
|
| 22 |
+
"latents_std": null,
|
| 23 |
+
"layers_per_block": 2,
|
| 24 |
+
"mid_block_add_attention": true,
|
| 25 |
+
"norm_num_groups": 32,
|
| 26 |
+
"out_channels": 3,
|
| 27 |
+
"sample_size": 256,
|
| 28 |
+
"scaling_factor": 0.18215,
|
| 29 |
+
"shift_factor": null,
|
| 30 |
+
"up_block_types": [
|
| 31 |
+
"UpDecoderBlock2D",
|
| 32 |
+
"UpDecoderBlock2D",
|
| 33 |
+
"UpDecoderBlock2D",
|
| 34 |
+
"UpDecoderBlock2D"
|
| 35 |
+
],
|
| 36 |
+
"use_post_quant_conv": true,
|
| 37 |
+
"use_quant_conv": true
|
| 38 |
+
}
|
edm2-img512-l-fid/vae/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815
|
| 3 |
+
size 334643276
|
edm2-img512-m-fid/demo.png
ADDED
|
Git LFS Details
|
edm2-img512-m-fid/model_index.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": [
|
| 3 |
+
"pipeline",
|
| 4 |
+
"EDM2Pipeline"
|
| 5 |
+
],
|
| 6 |
+
"_diffusers_version": "0.31.0",
|
| 7 |
+
"scheduler": [
|
| 8 |
+
"diffusers",
|
| 9 |
+
"EDMEulerScheduler"
|
| 10 |
+
],
|
| 11 |
+
"unet": [
|
| 12 |
+
"unet_edm2",
|
| 13 |
+
"EDM2UNet2DModel"
|
| 14 |
+
],
|
| 15 |
+
"vae": [
|
| 16 |
+
"diffusers",
|
| 17 |
+
"AutoencoderKL"
|
| 18 |
+
]
|
| 19 |
+
}
|
edm2-img512-m-fid/pipeline.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hub custom pipeline: EDM2Pipeline.
|
| 2 |
+
Load with native Hugging Face diffusers and trust_remote_code=True.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
|
| 21 |
+
import json
|
| 22 |
+
import os
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 30 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 31 |
+
from diffusers.utils import replace_example_docstring
|
| 32 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 33 |
+
|
| 34 |
+
EXAMPLE_DOC_STRING = """
|
| 35 |
+
Examples:
|
| 36 |
+
```py
|
| 37 |
+
>>> from pathlib import Path
|
| 38 |
+
>>> import torch
|
| 39 |
+
>>> from diffusers import DiffusionPipeline
|
| 40 |
+
|
| 41 |
+
>>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve()
|
| 42 |
+
>>> pipe = DiffusionPipeline.from_pretrained(
|
| 43 |
+
... str(model_dir),
|
| 44 |
+
... local_files_only=True,
|
| 45 |
+
... custom_pipeline=str(model_dir / "pipeline.py"),
|
| 46 |
+
... trust_remote_code=True,
|
| 47 |
+
... torch_dtype=torch.float32,
|
| 48 |
+
... )
|
| 49 |
+
>>> pipe.to("cuda")
|
| 50 |
+
|
| 51 |
+
>>> generator = torch.Generator(device="cuda").manual_seed(42)
|
| 52 |
+
>>> image = pipe(
|
| 53 |
+
... class_labels=207,
|
| 54 |
+
... num_inference_steps=32,
|
| 55 |
+
... guidance_scale=1.0,
|
| 56 |
+
... generator=generator,
|
| 57 |
+
... ).images[0]
|
| 58 |
+
>>> image.save("demo.png")
|
| 59 |
+
```
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
# Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py).
|
| 63 |
+
_STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28])
|
| 64 |
+
_STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE
|
| 65 |
+
|
| 66 |
+
class EDM2Pipeline(DiffusionPipeline):
|
| 67 |
+
r"""
|
| 68 |
+
Pipeline for class-conditional image generation with EDM2
|
| 69 |
+
([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)).
|
| 70 |
+
|
| 71 |
+
Parameters:
|
| 72 |
+
unet ([`EDM2UNet2DModel`]):
|
| 73 |
+
Main magnitude-preserving U-Net with EDM preconditioning.
|
| 74 |
+
scheduler ([`EDMEulerScheduler`]):
|
| 75 |
+
Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in
|
| 76 |
+
the pipeline because the UNet returns denoised latents rather than noise predictions.
|
| 77 |
+
vae ([`AutoencoderKL`], *optional*):
|
| 78 |
+
Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`.
|
| 79 |
+
gnet ([`EDM2UNet2DModel`], *optional*):
|
| 80 |
+
Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`).
|
| 81 |
+
id2label (`dict[int, str]`, *optional*):
|
| 82 |
+
ImageNet class id to English label mapping.
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
model_cpu_offload_seq = "unet->gnet->vae"
|
| 86 |
+
_optional_components = ["vae", "gnet"]
|
| 87 |
+
|
| 88 |
+
def __init__(
|
| 89 |
+
self,
|
| 90 |
+
unet,
|
| 91 |
+
scheduler,
|
| 92 |
+
vae=None,
|
| 93 |
+
gnet=None,
|
| 94 |
+
id2label: Optional[Dict[Union[int, str], str]] = None,
|
| 95 |
+
) -> None:
|
| 96 |
+
super().__init__()
|
| 97 |
+
self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet)
|
| 98 |
+
self._id2label = self._normalize_id2label(id2label)
|
| 99 |
+
self.labels = self._build_label2id(self._id2label)
|
| 100 |
+
self._labels_loaded_from_model_index = bool(self._id2label)
|
| 101 |
+
self.vae_scale_factor = 8 if self.vae is not None else 1
|
| 102 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
|
| 103 |
+
|
| 104 |
+
@staticmethod
|
| 105 |
+
def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
|
| 106 |
+
if not id2label:
|
| 107 |
+
return {}
|
| 108 |
+
return {int(key): value for key, value in id2label.items()}
|
| 109 |
+
|
| 110 |
+
@staticmethod
|
| 111 |
+
def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
|
| 112 |
+
label2id: Dict[str, int] = {}
|
| 113 |
+
for class_id, value in id2label.items():
|
| 114 |
+
for synonym in value.split(","):
|
| 115 |
+
synonym = synonym.strip()
|
| 116 |
+
if synonym:
|
| 117 |
+
label2id[synonym] = int(class_id)
|
| 118 |
+
return dict(sorted(label2id.items()))
|
| 119 |
+
|
| 120 |
+
def _ensure_labels_loaded(self) -> None:
|
| 121 |
+
if self._labels_loaded_from_model_index:
|
| 122 |
+
return
|
| 123 |
+
loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
|
| 124 |
+
if loaded:
|
| 125 |
+
self._id2label = loaded
|
| 126 |
+
self.labels = self._build_label2id(self._id2label)
|
| 127 |
+
self._labels_loaded_from_model_index = True
|
| 128 |
+
|
| 129 |
+
@staticmethod
|
| 130 |
+
def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
|
| 131 |
+
if not variant_path:
|
| 132 |
+
return {}
|
| 133 |
+
model_index_path = Path(variant_path).resolve() / "model_index.json"
|
| 134 |
+
if not model_index_path.is_file():
|
| 135 |
+
return {}
|
| 136 |
+
raw = json.loads(model_index_path.read_text(encoding="utf-8"))
|
| 137 |
+
id2label = raw.get("id2label")
|
| 138 |
+
if not isinstance(id2label, dict):
|
| 139 |
+
return {}
|
| 140 |
+
return {int(key): value for key, value in id2label.items()}
|
| 141 |
+
|
| 142 |
+
@property
|
| 143 |
+
def id2label(self) -> Dict[int, str]:
|
| 144 |
+
r"""ImageNet class id to English label string (comma-separated synonyms)."""
|
| 145 |
+
self._ensure_labels_loaded()
|
| 146 |
+
return self._id2label
|
| 147 |
+
|
| 148 |
+
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
|
| 149 |
+
r"""
|
| 150 |
+
Map ImageNet label strings to class ids.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
label (`str` or `list[str]`):
|
| 154 |
+
One or more English label strings that match entries in `id2label`.
|
| 155 |
+
"""
|
| 156 |
+
self._ensure_labels_loaded()
|
| 157 |
+
if not self.labels:
|
| 158 |
+
raise ValueError("No English labels loaded. Add `id2label` to model_index.json.")
|
| 159 |
+
labels = [label] if isinstance(label, str) else list(label)
|
| 160 |
+
missing = [item for item in labels if item not in self.labels]
|
| 161 |
+
if missing:
|
| 162 |
+
preview = ", ".join(list(self.labels.keys())[:8])
|
| 163 |
+
raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
|
| 164 |
+
return [self.labels[item] for item in labels]
|
| 165 |
+
|
| 166 |
+
def _default_image_size(self) -> int:
|
| 167 |
+
latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64)))
|
| 168 |
+
return latent_size * self.vae_scale_factor
|
| 169 |
+
|
| 170 |
+
def check_inputs(
|
| 171 |
+
self,
|
| 172 |
+
height: int,
|
| 173 |
+
width: int,
|
| 174 |
+
num_inference_steps: int,
|
| 175 |
+
guidance_scale: float,
|
| 176 |
+
output_type: str,
|
| 177 |
+
) -> None:
|
| 178 |
+
if num_inference_steps < 1:
|
| 179 |
+
raise ValueError("num_inference_steps must be >= 1.")
|
| 180 |
+
if guidance_scale < 1.0:
|
| 181 |
+
raise ValueError("guidance_scale must be >= 1.0.")
|
| 182 |
+
if guidance_scale > 1.0 and self.gnet is None:
|
| 183 |
+
raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).")
|
| 184 |
+
if output_type not in {"pil", "np", "pt", "latent"}:
|
| 185 |
+
raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
|
| 186 |
+
|
| 187 |
+
native_size = self._default_image_size()
|
| 188 |
+
if height != native_size or width != native_size:
|
| 189 |
+
raise ValueError(
|
| 190 |
+
f"EDM2 expects native resolution height=width={native_size}. "
|
| 191 |
+
f"Got height={height}, width={width}."
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
def _normalize_class_labels(
|
| 195 |
+
self,
|
| 196 |
+
class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]],
|
| 197 |
+
batch_size: int,
|
| 198 |
+
device: torch.device,
|
| 199 |
+
) -> Optional[torch.Tensor]:
|
| 200 |
+
label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0)))
|
| 201 |
+
if label_dim == 0:
|
| 202 |
+
return None
|
| 203 |
+
if class_labels is None:
|
| 204 |
+
indices = torch.randint(label_dim, size=(batch_size,), device=device)
|
| 205 |
+
return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
|
| 206 |
+
|
| 207 |
+
if isinstance(class_labels, str):
|
| 208 |
+
class_labels = self.get_label_ids(class_labels)[0]
|
| 209 |
+
elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str):
|
| 210 |
+
class_labels = self.get_label_ids(list(class_labels))
|
| 211 |
+
|
| 212 |
+
if isinstance(class_labels, int):
|
| 213 |
+
indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long)
|
| 214 |
+
elif isinstance(class_labels, torch.Tensor):
|
| 215 |
+
if class_labels.ndim == 2:
|
| 216 |
+
labels = class_labels.to(device=device, dtype=torch.float32)
|
| 217 |
+
if labels.shape[0] != batch_size:
|
| 218 |
+
raise ValueError(f"class_labels batch must match batch_size={batch_size}.")
|
| 219 |
+
return labels
|
| 220 |
+
indices = class_labels.to(device=device, dtype=torch.long).flatten()
|
| 221 |
+
else:
|
| 222 |
+
indices = torch.tensor(list(class_labels), device=device, dtype=torch.long)
|
| 223 |
+
|
| 224 |
+
if indices.numel() == 1 and batch_size > 1:
|
| 225 |
+
indices = indices.repeat(batch_size)
|
| 226 |
+
if indices.numel() != batch_size:
|
| 227 |
+
raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.")
|
| 228 |
+
return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
|
| 229 |
+
|
| 230 |
+
def prepare_latents(
|
| 231 |
+
self,
|
| 232 |
+
batch_size: int,
|
| 233 |
+
height: int,
|
| 234 |
+
width: int,
|
| 235 |
+
dtype: torch.dtype,
|
| 236 |
+
device: torch.device,
|
| 237 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]],
|
| 238 |
+
) -> torch.Tensor:
|
| 239 |
+
in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4)))
|
| 240 |
+
latent_size = height // self.vae_scale_factor
|
| 241 |
+
return randn_tensor(
|
| 242 |
+
(batch_size, in_channels, latent_size, latent_size),
|
| 243 |
+
generator=generator,
|
| 244 |
+
device=device,
|
| 245 |
+
dtype=torch.float32,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
|
| 249 |
+
if output_type == "latent":
|
| 250 |
+
return latents
|
| 251 |
+
|
| 252 |
+
in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 3)))
|
| 253 |
+
if self.vae is None:
|
| 254 |
+
image = (latents.to(torch.float32) * 127.5 + 128).clip(0, 255) / 255.0
|
| 255 |
+
return self.image_processor.postprocess(image, output_type=output_type)
|
| 256 |
+
|
| 257 |
+
if in_channels == 4:
|
| 258 |
+
x = latents.to(torch.float32)
|
| 259 |
+
scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
|
| 260 |
+
bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
|
| 261 |
+
x = (x - bias) / scale
|
| 262 |
+
else:
|
| 263 |
+
x = latents.to(torch.float32)
|
| 264 |
+
|
| 265 |
+
vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype
|
| 266 |
+
image = self.vae.decode(x.to(dtype=vae_dtype)).sample.to(torch.float32).clamp(0, 1)
|
| 267 |
+
|
| 268 |
+
return self.image_processor.postprocess(image, output_type=output_type)
|
| 269 |
+
|
| 270 |
+
@staticmethod
|
| 271 |
+
def _apply_autoguidance(
|
| 272 |
+
main: torch.Tensor,
|
| 273 |
+
ref: torch.Tensor,
|
| 274 |
+
guidance_scale: float,
|
| 275 |
+
) -> torch.Tensor:
|
| 276 |
+
return ref.lerp(main, guidance_scale)
|
| 277 |
+
|
| 278 |
+
@staticmethod
|
| 279 |
+
def _sample_edm2_heun(
|
| 280 |
+
denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
| 281 |
+
noise: torch.Tensor,
|
| 282 |
+
sigmas: torch.Tensor,
|
| 283 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 284 |
+
progress_bar: Optional[Callable[[Iterable], Iterable]] = None,
|
| 285 |
+
dtype: torch.dtype = torch.float32,
|
| 286 |
+
) -> torch.Tensor:
|
| 287 |
+
"""NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0)."""
|
| 288 |
+
x_next = noise.to(dtype) * sigmas[0]
|
| 289 |
+
|
| 290 |
+
sigma_pairs = list(zip(sigmas[:-1], sigmas[1:]))
|
| 291 |
+
if progress_bar is not None:
|
| 292 |
+
sigma_pairs = progress_bar(sigma_pairs)
|
| 293 |
+
|
| 294 |
+
num_steps = len(sigma_pairs)
|
| 295 |
+
for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs):
|
| 296 |
+
x_hat, sigma_hat = x_next, sigma_cur
|
| 297 |
+
d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat
|
| 298 |
+
x_next = x_hat + (sigma_next - sigma_hat) * d_cur
|
| 299 |
+
if i < num_steps - 1:
|
| 300 |
+
d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next
|
| 301 |
+
x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime)
|
| 302 |
+
return x_next
|
| 303 |
+
|
| 304 |
+
@torch.inference_mode()
|
| 305 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 306 |
+
def __call__(
|
| 307 |
+
self,
|
| 308 |
+
class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None,
|
| 309 |
+
batch_size: int = 1,
|
| 310 |
+
height: Optional[int] = None,
|
| 311 |
+
width: Optional[int] = None,
|
| 312 |
+
num_inference_steps: int = 32,
|
| 313 |
+
guidance_scale: float = 1.0,
|
| 314 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 315 |
+
output_type: str = "pil",
|
| 316 |
+
return_dict: bool = True,
|
| 317 |
+
) -> Union[ImagePipelineOutput, Tuple]:
|
| 318 |
+
r"""
|
| 319 |
+
Generate class-conditional images with EDM2.
|
| 320 |
+
|
| 321 |
+
Args:
|
| 322 |
+
class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*):
|
| 323 |
+
ImageNet class indices, English label strings, or one-hot float tensors.
|
| 324 |
+
Random classes are sampled when omitted on conditional models.
|
| 325 |
+
batch_size (`int`, defaults to `1`):
|
| 326 |
+
Number of images to generate.
|
| 327 |
+
height (`int`, *optional*):
|
| 328 |
+
Output height in pixels. Defaults to the pretrained native resolution.
|
| 329 |
+
width (`int`, *optional*):
|
| 330 |
+
Output width in pixels. Defaults to the pretrained native resolution.
|
| 331 |
+
num_inference_steps (`int`, defaults to `32`):
|
| 332 |
+
Number of EDM2 Heun steps (NVlabs default).
|
| 333 |
+
guidance_scale (`float`, defaults to `1.0`):
|
| 334 |
+
Autoguidance strength. Values above `1.0` blend the main net with `gnet`
|
| 335 |
+
via `gnet_output.lerp(unet_output, guidance_scale)`.
|
| 336 |
+
generator (`torch.Generator`, *optional*):
|
| 337 |
+
RNG for reproducibility.
|
| 338 |
+
output_type (`str`, defaults to `"pil"`):
|
| 339 |
+
`"pil"`, `"np"`, `"pt"`, or `"latent"`.
|
| 340 |
+
return_dict (`bool`, defaults to `True`):
|
| 341 |
+
Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if True.
|
| 342 |
+
|
| 343 |
+
Examples:
|
| 344 |
+
<!-- this section is replaced by replace_example_docstring -->
|
| 345 |
+
"""
|
| 346 |
+
default_size = self._default_image_size()
|
| 347 |
+
height = int(height or default_size)
|
| 348 |
+
width = int(width or default_size)
|
| 349 |
+
self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type)
|
| 350 |
+
|
| 351 |
+
device = self._execution_device
|
| 352 |
+
dtype = self.unet.dtype
|
| 353 |
+
labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
|
| 354 |
+
noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
|
| 355 |
+
|
| 356 |
+
def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
|
| 357 |
+
sigma_batch = sigma.reshape(1).expand(batch_size)
|
| 358 |
+
main = self.unet(
|
| 359 |
+
sample=x,
|
| 360 |
+
sigma=sigma_batch,
|
| 361 |
+
class_labels=labels,
|
| 362 |
+
force_fp32=True,
|
| 363 |
+
).sample
|
| 364 |
+
if guidance_scale == 1.0 or self.gnet is None:
|
| 365 |
+
return main.to(torch.float32)
|
| 366 |
+
ref = self.gnet(
|
| 367 |
+
sample=x,
|
| 368 |
+
sigma=sigma_batch,
|
| 369 |
+
class_labels=labels,
|
| 370 |
+
force_fp32=True,
|
| 371 |
+
).sample
|
| 372 |
+
return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32)
|
| 373 |
+
|
| 374 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 375 |
+
latents = self._sample_edm2_heun(
|
| 376 |
+
denoise_fn=denoise_fn,
|
| 377 |
+
noise=noise,
|
| 378 |
+
sigmas=self.scheduler.sigmas.to(device),
|
| 379 |
+
generator=generator,
|
| 380 |
+
progress_bar=self.progress_bar,
|
| 381 |
+
dtype=torch.float32,
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
image = self.decode_latents(latents, output_type=output_type)
|
| 385 |
+
if not return_dict:
|
| 386 |
+
return (image, latents)
|
| 387 |
+
return ImagePipelineOutput(images=image)
|
| 388 |
+
|
| 389 |
+
@classmethod
|
| 390 |
+
def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None):
|
| 391 |
+
vae_dir = os.path.join(pretrained_model_name_or_path, "vae")
|
| 392 |
+
if os.path.isdir(vae_dir):
|
| 393 |
+
try:
|
| 394 |
+
|
| 395 |
+
return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype)
|
| 396 |
+
except Exception:
|
| 397 |
+
return None
|
| 398 |
+
|
| 399 |
+
vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt")
|
| 400 |
+
if os.path.isfile(vae_hint):
|
| 401 |
+
with open(vae_hint, "r", encoding="utf-8") as f:
|
| 402 |
+
hub_id = f.read().strip()
|
| 403 |
+
if hub_id:
|
| 404 |
+
|
| 405 |
+
return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype)
|
| 406 |
+
return None
|
edm2-img512-m-fid/scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "EDMEulerScheduler",
|
| 3 |
+
"final_sigmas_type": "zero",
|
| 4 |
+
"num_train_timesteps": 1000,
|
| 5 |
+
"prediction_type": "epsilon",
|
| 6 |
+
"rho": 7.0,
|
| 7 |
+
"sigma_data": 0.5,
|
| 8 |
+
"sigma_max": 80.0,
|
| 9 |
+
"sigma_min": 0.002,
|
| 10 |
+
"sigma_schedule": "karras"
|
| 11 |
+
}
|
edm2-img512-m-fid/unet/config.json
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "EDM2UNet2DModel",
|
| 3 |
+
"attn_balance": 0.3,
|
| 4 |
+
"attn_resolutions": [
|
| 5 |
+
16,
|
| 6 |
+
8
|
| 7 |
+
],
|
| 8 |
+
"channel_mult": [
|
| 9 |
+
1,
|
| 10 |
+
2,
|
| 11 |
+
3,
|
| 12 |
+
4
|
| 13 |
+
],
|
| 14 |
+
"channel_mult_emb": 4,
|
| 15 |
+
"channel_mult_noise": 1,
|
| 16 |
+
"channels_per_head": 64,
|
| 17 |
+
"clip_act": 256,
|
| 18 |
+
"concat_balance": 0.5,
|
| 19 |
+
"dropout": 0.0,
|
| 20 |
+
"in_channels": 4,
|
| 21 |
+
"label_balance": 0.5,
|
| 22 |
+
"logvar_channels": 128,
|
| 23 |
+
"model_channels": 256,
|
| 24 |
+
"num_blocks": 3,
|
| 25 |
+
"num_class_embeds": 1000,
|
| 26 |
+
"out_channels": 4,
|
| 27 |
+
"res_balance": 0.3,
|
| 28 |
+
"sample_size": 64,
|
| 29 |
+
"sigma_data": 0.5,
|
| 30 |
+
"use_fp16": true
|
| 31 |
+
}
|
edm2-img512-m-fid/unet/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4733c8b2d2823cd6ce7a67e2b89b0e9b94d50fdf595b0e0b17299e198da3bcfc
|
| 3 |
+
size 1991256788
|
edm2-img512-m-fid/unet/unet_edm2.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 12 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 13 |
+
from diffusers.utils import BaseOutput
|
| 14 |
+
except ImportError: # pragma: no cover
|
| 15 |
+
class ModelMixin(torch.nn.Module):
|
| 16 |
+
pass
|
| 17 |
+
|
| 18 |
+
class ConfigMixin:
|
| 19 |
+
config = {}
|
| 20 |
+
|
| 21 |
+
def register_to_config(self, **kwargs):
|
| 22 |
+
self.config = kwargs
|
| 23 |
+
|
| 24 |
+
def register_to_config(func):
|
| 25 |
+
return func
|
| 26 |
+
|
| 27 |
+
@dataclass
|
| 28 |
+
class BaseOutput:
|
| 29 |
+
pass
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor:
|
| 33 |
+
if dim is None:
|
| 34 |
+
dim = list(range(1, x.ndim))
|
| 35 |
+
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
|
| 36 |
+
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
|
| 37 |
+
return x / norm.to(x.dtype)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor:
|
| 41 |
+
if mode == "keep":
|
| 42 |
+
return x
|
| 43 |
+
filt = np.float32(f)
|
| 44 |
+
pad = (len(filt) - 1) // 2
|
| 45 |
+
filt = filt / filt.sum()
|
| 46 |
+
filt = np.outer(filt, filt)[np.newaxis, np.newaxis, :, :]
|
| 47 |
+
filt = torch.as_tensor(filt, dtype=x.dtype, device=x.device)
|
| 48 |
+
c = x.shape[1]
|
| 49 |
+
if mode == "down":
|
| 50 |
+
return torch.nn.functional.conv2d(x, filt.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
|
| 51 |
+
return torch.nn.functional.conv_transpose2d(x, (filt * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def mp_silu(x: torch.Tensor) -> torch.Tensor:
|
| 55 |
+
return torch.nn.functional.silu(x) / 0.596
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor:
|
| 59 |
+
return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor:
|
| 63 |
+
na = a.shape[dim]
|
| 64 |
+
nb = b.shape[dim]
|
| 65 |
+
c = math.sqrt((na + nb) / ((1 - t) ** 2 + t ** 2))
|
| 66 |
+
wa = c / math.sqrt(na) * (1 - t)
|
| 67 |
+
wb = c / math.sqrt(nb) * t
|
| 68 |
+
return torch.cat([wa * a, wb * b], dim=dim)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class MPFourier(torch.nn.Module):
|
| 72 |
+
def __init__(self, num_channels: int, bandwidth: float = 1):
|
| 73 |
+
super().__init__()
|
| 74 |
+
self.register_buffer("freqs", 2 * math.pi * torch.randn(num_channels) * bandwidth)
|
| 75 |
+
self.register_buffer("phases", 2 * math.pi * torch.rand(num_channels))
|
| 76 |
+
|
| 77 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 78 |
+
y = x.to(torch.float32).ger(self.freqs.to(torch.float32))
|
| 79 |
+
y = y + self.phases.to(torch.float32)
|
| 80 |
+
y = y.cos() * math.sqrt(2)
|
| 81 |
+
return y.to(x.dtype)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class MPConv(torch.nn.Module):
|
| 85 |
+
def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]):
|
| 86 |
+
super().__init__()
|
| 87 |
+
self.out_channels = out_channels
|
| 88 |
+
self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))
|
| 89 |
+
|
| 90 |
+
def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor:
|
| 91 |
+
w = self.weight.to(torch.float32)
|
| 92 |
+
if self.training:
|
| 93 |
+
with torch.no_grad():
|
| 94 |
+
self.weight.copy_(normalize(w))
|
| 95 |
+
w = normalize(w)
|
| 96 |
+
w = w * (gain / math.sqrt(w[0].numel()))
|
| 97 |
+
w = w.to(x.dtype)
|
| 98 |
+
if w.ndim == 2:
|
| 99 |
+
return x @ w.t()
|
| 100 |
+
return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1] // 2,))
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class Block(torch.nn.Module):
|
| 104 |
+
def __init__(
|
| 105 |
+
self,
|
| 106 |
+
in_channels: int,
|
| 107 |
+
out_channels: int,
|
| 108 |
+
emb_channels: int,
|
| 109 |
+
flavor: str = "enc",
|
| 110 |
+
resample_mode: str = "keep",
|
| 111 |
+
resample_filter: List[float] = [1, 1],
|
| 112 |
+
attention: bool = False,
|
| 113 |
+
channels_per_head: int = 64,
|
| 114 |
+
dropout: float = 0.0,
|
| 115 |
+
res_balance: float = 0.3,
|
| 116 |
+
attn_balance: float = 0.3,
|
| 117 |
+
clip_act: Optional[float] = 256,
|
| 118 |
+
):
|
| 119 |
+
super().__init__()
|
| 120 |
+
self.out_channels = out_channels
|
| 121 |
+
self.flavor = flavor
|
| 122 |
+
self.resample_filter = resample_filter
|
| 123 |
+
self.resample_mode = resample_mode
|
| 124 |
+
self.num_heads = out_channels // channels_per_head if attention else 0
|
| 125 |
+
self.dropout = dropout
|
| 126 |
+
self.res_balance = res_balance
|
| 127 |
+
self.attn_balance = attn_balance
|
| 128 |
+
self.clip_act = clip_act
|
| 129 |
+
self.emb_gain = torch.nn.Parameter(torch.zeros([]))
|
| 130 |
+
self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3))
|
| 131 |
+
self.emb_linear = MPConv(emb_channels, out_channels, kernel=())
|
| 132 |
+
self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3))
|
| 133 |
+
self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None
|
| 134 |
+
self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None
|
| 135 |
+
self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None
|
| 136 |
+
|
| 137 |
+
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
|
| 138 |
+
x = resample(x, f=self.resample_filter, mode=self.resample_mode)
|
| 139 |
+
if self.flavor == "enc":
|
| 140 |
+
if self.conv_skip is not None:
|
| 141 |
+
x = self.conv_skip(x)
|
| 142 |
+
x = normalize(x, dim=[1])
|
| 143 |
+
|
| 144 |
+
y = self.conv_res0(mp_silu(x))
|
| 145 |
+
c = self.emb_linear(emb, gain=self.emb_gain) + 1
|
| 146 |
+
y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype))
|
| 147 |
+
if self.training and self.dropout:
|
| 148 |
+
y = torch.nn.functional.dropout(y, p=self.dropout)
|
| 149 |
+
y = self.conv_res1(y)
|
| 150 |
+
|
| 151 |
+
if self.flavor == "dec" and self.conv_skip is not None:
|
| 152 |
+
x = self.conv_skip(x)
|
| 153 |
+
x = mp_sum(x, y, t=self.res_balance)
|
| 154 |
+
|
| 155 |
+
if self.num_heads:
|
| 156 |
+
y = self.attn_qkv(x)
|
| 157 |
+
y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3])
|
| 158 |
+
q, k, v = normalize(y, dim=[2]).unbind(3)
|
| 159 |
+
w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3)
|
| 160 |
+
y = torch.einsum("nhqk,nhck->nhcq", w, v)
|
| 161 |
+
y = self.attn_proj(y.reshape(*x.shape))
|
| 162 |
+
x = mp_sum(x, y, t=self.attn_balance)
|
| 163 |
+
|
| 164 |
+
if self.clip_act is not None:
|
| 165 |
+
x = x.clip_(-self.clip_act, self.clip_act)
|
| 166 |
+
return x
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class EDM2UNet(torch.nn.Module):
|
| 170 |
+
def __init__(
|
| 171 |
+
self,
|
| 172 |
+
img_resolution: int,
|
| 173 |
+
img_channels: int,
|
| 174 |
+
label_dim: int,
|
| 175 |
+
model_channels: int = 192,
|
| 176 |
+
channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
|
| 177 |
+
channel_mult_noise: Optional[int] = None,
|
| 178 |
+
channel_mult_emb: Optional[int] = None,
|
| 179 |
+
num_blocks: int = 3,
|
| 180 |
+
attn_resolutions: Tuple[int, ...] = (16, 8),
|
| 181 |
+
label_balance: float = 0.5,
|
| 182 |
+
concat_balance: float = 0.5,
|
| 183 |
+
**block_kwargs,
|
| 184 |
+
):
|
| 185 |
+
super().__init__()
|
| 186 |
+
cblock = [model_channels * x for x in channel_mult]
|
| 187 |
+
cnoise = model_channels * channel_mult_noise if channel_mult_noise is not None else cblock[0]
|
| 188 |
+
cemb = model_channels * channel_mult_emb if channel_mult_emb is not None else max(cblock)
|
| 189 |
+
self.label_balance = label_balance
|
| 190 |
+
self.concat_balance = concat_balance
|
| 191 |
+
self.out_gain = torch.nn.Parameter(torch.zeros([]))
|
| 192 |
+
|
| 193 |
+
self.emb_fourier = MPFourier(cnoise)
|
| 194 |
+
self.emb_noise = MPConv(cnoise, cemb, kernel=())
|
| 195 |
+
self.emb_label = MPConv(label_dim, cemb, kernel=()) if label_dim else None
|
| 196 |
+
|
| 197 |
+
self.enc = torch.nn.ModuleDict()
|
| 198 |
+
cout = img_channels + 1
|
| 199 |
+
for level, channels in enumerate(cblock):
|
| 200 |
+
res = img_resolution >> level
|
| 201 |
+
if level == 0:
|
| 202 |
+
cin = cout
|
| 203 |
+
cout = channels
|
| 204 |
+
self.enc[f"{res}x{res}_conv"] = MPConv(cin, cout, kernel=(3, 3))
|
| 205 |
+
else:
|
| 206 |
+
self.enc[f"{res}x{res}_down"] = Block(cout, cout, cemb, flavor="enc", resample_mode="down", **block_kwargs)
|
| 207 |
+
for idx in range(num_blocks):
|
| 208 |
+
cin = cout
|
| 209 |
+
cout = channels
|
| 210 |
+
self.enc[f"{res}x{res}_block{idx}"] = Block(
|
| 211 |
+
cin,
|
| 212 |
+
cout,
|
| 213 |
+
cemb,
|
| 214 |
+
flavor="enc",
|
| 215 |
+
attention=(res in attn_resolutions),
|
| 216 |
+
**block_kwargs,
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
self.dec = torch.nn.ModuleDict()
|
| 220 |
+
skips = [block.out_channels for block in self.enc.values()]
|
| 221 |
+
for level, channels in reversed(list(enumerate(cblock))):
|
| 222 |
+
res = img_resolution >> level
|
| 223 |
+
if level == len(cblock) - 1:
|
| 224 |
+
self.dec[f"{res}x{res}_in0"] = Block(cout, cout, cemb, flavor="dec", attention=True, **block_kwargs)
|
| 225 |
+
self.dec[f"{res}x{res}_in1"] = Block(cout, cout, cemb, flavor="dec", **block_kwargs)
|
| 226 |
+
else:
|
| 227 |
+
self.dec[f"{res}x{res}_up"] = Block(cout, cout, cemb, flavor="dec", resample_mode="up", **block_kwargs)
|
| 228 |
+
for idx in range(num_blocks + 1):
|
| 229 |
+
cin = cout + skips.pop()
|
| 230 |
+
cout = channels
|
| 231 |
+
self.dec[f"{res}x{res}_block{idx}"] = Block(
|
| 232 |
+
cin,
|
| 233 |
+
cout,
|
| 234 |
+
cemb,
|
| 235 |
+
flavor="dec",
|
| 236 |
+
attention=(res in attn_resolutions),
|
| 237 |
+
**block_kwargs,
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
self.out_conv = MPConv(cout, img_channels, kernel=(3, 3))
|
| 241 |
+
|
| 242 |
+
def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor:
|
| 243 |
+
emb = self.emb_noise(self.emb_fourier(noise_labels))
|
| 244 |
+
if self.emb_label is not None:
|
| 245 |
+
if class_labels is None:
|
| 246 |
+
raise ValueError("class_labels are required for conditional EDM2UNet.")
|
| 247 |
+
emb = mp_sum(emb, self.emb_label(class_labels * math.sqrt(class_labels.shape[1])), t=self.label_balance)
|
| 248 |
+
emb = mp_silu(emb)
|
| 249 |
+
|
| 250 |
+
x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1)
|
| 251 |
+
skips = []
|
| 252 |
+
for name, block in self.enc.items():
|
| 253 |
+
x = block(x) if "conv" in name else block(x, emb)
|
| 254 |
+
skips.append(x)
|
| 255 |
+
|
| 256 |
+
for name, block in self.dec.items():
|
| 257 |
+
if "block" in name:
|
| 258 |
+
x = mp_cat(x, skips.pop(), t=self.concat_balance)
|
| 259 |
+
x = block(x, emb)
|
| 260 |
+
return self.out_conv(x, gain=self.out_gain)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
@dataclass
|
| 264 |
+
class EDM2UNet2DOutput(BaseOutput):
|
| 265 |
+
sample: torch.Tensor
|
| 266 |
+
logvar: Optional[torch.Tensor] = None
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
_CONFIG_KEYS = (
|
| 271 |
+
"sample_size",
|
| 272 |
+
"in_channels",
|
| 273 |
+
"out_channels",
|
| 274 |
+
"num_class_embeds",
|
| 275 |
+
"use_fp16",
|
| 276 |
+
"sigma_data",
|
| 277 |
+
"logvar_channels",
|
| 278 |
+
"model_channels",
|
| 279 |
+
"channel_mult",
|
| 280 |
+
"channel_mult_noise",
|
| 281 |
+
"channel_mult_emb",
|
| 282 |
+
"num_blocks",
|
| 283 |
+
"attn_resolutions",
|
| 284 |
+
"label_balance",
|
| 285 |
+
"concat_balance",
|
| 286 |
+
"dropout",
|
| 287 |
+
"channels_per_head",
|
| 288 |
+
"res_balance",
|
| 289 |
+
"attn_balance",
|
| 290 |
+
"clip_act",
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class EDM2UNet2DModel(ModelMixin, ConfigMixin):
|
| 295 |
+
@register_to_config
|
| 296 |
+
def __init__(
|
| 297 |
+
self,
|
| 298 |
+
sample_size: int = 64,
|
| 299 |
+
in_channels: int = 4,
|
| 300 |
+
out_channels: int = 4,
|
| 301 |
+
num_class_embeds: int = 0,
|
| 302 |
+
use_fp16: bool = True,
|
| 303 |
+
sigma_data: float = 0.5,
|
| 304 |
+
logvar_channels: int = 128,
|
| 305 |
+
model_channels: int = 192,
|
| 306 |
+
channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
|
| 307 |
+
channel_mult_noise: Optional[int] = None,
|
| 308 |
+
channel_mult_emb: Optional[int] = None,
|
| 309 |
+
num_blocks: int = 3,
|
| 310 |
+
attn_resolutions: Tuple[int, ...] = (16, 8),
|
| 311 |
+
label_balance: float = 0.5,
|
| 312 |
+
concat_balance: float = 0.5,
|
| 313 |
+
dropout: float = 0.0,
|
| 314 |
+
channels_per_head: int = 64,
|
| 315 |
+
res_balance: float = 0.3,
|
| 316 |
+
attn_balance: float = 0.3,
|
| 317 |
+
clip_act: Optional[float] = 256,
|
| 318 |
+
):
|
| 319 |
+
super().__init__()
|
| 320 |
+
self.sample_size = sample_size
|
| 321 |
+
self.in_channels = in_channels
|
| 322 |
+
self.out_channels = out_channels
|
| 323 |
+
self.num_class_embeds = num_class_embeds
|
| 324 |
+
self.use_fp16 = use_fp16
|
| 325 |
+
self.sigma_data = sigma_data
|
| 326 |
+
self.model_channels = model_channels
|
| 327 |
+
self.channel_mult = channel_mult
|
| 328 |
+
self.channel_mult_noise = channel_mult_noise
|
| 329 |
+
self.channel_mult_emb = channel_mult_emb
|
| 330 |
+
self.num_blocks = num_blocks
|
| 331 |
+
self.attn_resolutions = attn_resolutions
|
| 332 |
+
self.label_balance = label_balance
|
| 333 |
+
self.concat_balance = concat_balance
|
| 334 |
+
self.dropout = dropout
|
| 335 |
+
self.channels_per_head = channels_per_head
|
| 336 |
+
self.res_balance = res_balance
|
| 337 |
+
self.attn_balance = attn_balance
|
| 338 |
+
self.clip_act = clip_act
|
| 339 |
+
self.unet = EDM2UNet(
|
| 340 |
+
img_resolution=sample_size,
|
| 341 |
+
img_channels=in_channels,
|
| 342 |
+
label_dim=num_class_embeds,
|
| 343 |
+
model_channels=model_channels,
|
| 344 |
+
channel_mult=channel_mult,
|
| 345 |
+
channel_mult_noise=channel_mult_noise,
|
| 346 |
+
channel_mult_emb=channel_mult_emb,
|
| 347 |
+
num_blocks=num_blocks,
|
| 348 |
+
attn_resolutions=attn_resolutions,
|
| 349 |
+
label_balance=label_balance,
|
| 350 |
+
concat_balance=concat_balance,
|
| 351 |
+
dropout=dropout,
|
| 352 |
+
channels_per_head=channels_per_head,
|
| 353 |
+
res_balance=res_balance,
|
| 354 |
+
attn_balance=attn_balance,
|
| 355 |
+
clip_act=clip_act,
|
| 356 |
+
)
|
| 357 |
+
self.logvar_fourier = MPFourier(logvar_channels)
|
| 358 |
+
self.logvar_linear = MPConv(logvar_channels, 1, kernel=())
|
| 359 |
+
|
| 360 |
+
def forward(
|
| 361 |
+
self,
|
| 362 |
+
sample: torch.Tensor,
|
| 363 |
+
sigma: torch.Tensor,
|
| 364 |
+
class_labels: Optional[torch.Tensor] = None,
|
| 365 |
+
force_fp32: bool = False,
|
| 366 |
+
return_logvar: bool = False,
|
| 367 |
+
return_dict: bool = True,
|
| 368 |
+
) -> EDM2UNet2DOutput:
|
| 369 |
+
x = sample.to(torch.float32)
|
| 370 |
+
sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
|
| 371 |
+
if self.num_class_embeds == 0:
|
| 372 |
+
class_labels = None
|
| 373 |
+
else:
|
| 374 |
+
if class_labels is None:
|
| 375 |
+
class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device)
|
| 376 |
+
class_labels = class_labels.to(torch.float32).reshape(-1, self.num_class_embeds)
|
| 377 |
+
dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") else torch.float32
|
| 378 |
+
|
| 379 |
+
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
|
| 380 |
+
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
|
| 381 |
+
c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
|
| 382 |
+
c_noise = sigma.flatten().log() / 4
|
| 383 |
+
|
| 384 |
+
x_in = (c_in * x).to(dtype)
|
| 385 |
+
f_x = self.unet(x_in, c_noise, class_labels)
|
| 386 |
+
d_x = c_skip * x + c_out * f_x.to(torch.float32)
|
| 387 |
+
|
| 388 |
+
logvar = None
|
| 389 |
+
if return_logvar:
|
| 390 |
+
logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)
|
| 391 |
+
|
| 392 |
+
if not return_dict:
|
| 393 |
+
return (d_x, logvar)
|
| 394 |
+
return EDM2UNet2DOutput(sample=d_x, logvar=logvar)
|
| 395 |
+
|
| 396 |
+
@classmethod
|
| 397 |
+
def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs):
|
| 398 |
+
subfolder = kwargs.pop("subfolder", None)
|
| 399 |
+
model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path
|
| 400 |
+
with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f:
|
| 401 |
+
config = json.load(f)
|
| 402 |
+
init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS}
|
| 403 |
+
model = cls(**init_kwargs)
|
| 404 |
+
weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors")
|
| 405 |
+
if os.path.isfile(weight_file):
|
| 406 |
+
from safetensors.torch import load_file
|
| 407 |
+
|
| 408 |
+
state_dict = load_file(weight_file)
|
| 409 |
+
else:
|
| 410 |
+
state_dict = torch.load(os.path.join(model_dir, "diffusion_pytorch_model.bin"), map_location="cpu")
|
| 411 |
+
model.load_state_dict(state_dict, strict=True)
|
| 412 |
+
if torch_dtype is not None:
|
| 413 |
+
model = model.to(dtype=torch_dtype)
|
| 414 |
+
return model
|
| 415 |
+
|
| 416 |
+
def save_pretrained(self, save_directory: str, safe_serialization: bool = True):
|
| 417 |
+
os.makedirs(save_directory, exist_ok=True)
|
| 418 |
+
stored = dict(getattr(self, "config", {}))
|
| 419 |
+
config = {"_class_name": self.__class__.__name__}
|
| 420 |
+
for key in _CONFIG_KEYS:
|
| 421 |
+
if key in stored:
|
| 422 |
+
config[key] = stored[key]
|
| 423 |
+
elif hasattr(self, key):
|
| 424 |
+
config[key] = getattr(self, key)
|
| 425 |
+
with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f:
|
| 426 |
+
json.dump(config, f, indent=2, sort_keys=True)
|
| 427 |
+
f.write("\n")
|
| 428 |
+
state_dict = self.state_dict()
|
| 429 |
+
if safe_serialization:
|
| 430 |
+
from safetensors.torch import save_file
|
| 431 |
+
|
| 432 |
+
save_file(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.safetensors"))
|
| 433 |
+
else:
|
| 434 |
+
torch.save(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.bin"))
|
edm2-img512-m-fid/vae/config.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKL",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"_name_or_path": "stabilityai/sd-vae-ft-mse",
|
| 5 |
+
"act_fn": "silu",
|
| 6 |
+
"block_out_channels": [
|
| 7 |
+
128,
|
| 8 |
+
256,
|
| 9 |
+
512,
|
| 10 |
+
512
|
| 11 |
+
],
|
| 12 |
+
"down_block_types": [
|
| 13 |
+
"DownEncoderBlock2D",
|
| 14 |
+
"DownEncoderBlock2D",
|
| 15 |
+
"DownEncoderBlock2D",
|
| 16 |
+
"DownEncoderBlock2D"
|
| 17 |
+
],
|
| 18 |
+
"force_upcast": true,
|
| 19 |
+
"in_channels": 3,
|
| 20 |
+
"latent_channels": 4,
|
| 21 |
+
"latents_mean": null,
|
| 22 |
+
"latents_std": null,
|
| 23 |
+
"layers_per_block": 2,
|
| 24 |
+
"mid_block_add_attention": true,
|
| 25 |
+
"norm_num_groups": 32,
|
| 26 |
+
"out_channels": 3,
|
| 27 |
+
"sample_size": 256,
|
| 28 |
+
"scaling_factor": 0.18215,
|
| 29 |
+
"shift_factor": null,
|
| 30 |
+
"up_block_types": [
|
| 31 |
+
"UpDecoderBlock2D",
|
| 32 |
+
"UpDecoderBlock2D",
|
| 33 |
+
"UpDecoderBlock2D",
|
| 34 |
+
"UpDecoderBlock2D"
|
| 35 |
+
],
|
| 36 |
+
"use_post_quant_conv": true,
|
| 37 |
+
"use_quant_conv": true
|
| 38 |
+
}
|
edm2-img512-m-fid/vae/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815
|
| 3 |
+
size 334643276
|
edm2-img512-s-fid/demo.png
ADDED
|
Git LFS Details
|
edm2-img512-s-fid/model_index.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": [
|
| 3 |
+
"pipeline",
|
| 4 |
+
"EDM2Pipeline"
|
| 5 |
+
],
|
| 6 |
+
"_diffusers_version": "0.31.0",
|
| 7 |
+
"scheduler": [
|
| 8 |
+
"diffusers",
|
| 9 |
+
"EDMEulerScheduler"
|
| 10 |
+
],
|
| 11 |
+
"unet": [
|
| 12 |
+
"unet_edm2",
|
| 13 |
+
"EDM2UNet2DModel"
|
| 14 |
+
],
|
| 15 |
+
"vae": [
|
| 16 |
+
"diffusers",
|
| 17 |
+
"AutoencoderKL"
|
| 18 |
+
]
|
| 19 |
+
}
|
edm2-img512-s-fid/pipeline.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hub custom pipeline: EDM2Pipeline.
|
| 2 |
+
Load with native Hugging Face diffusers and trust_remote_code=True.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
|
| 21 |
+
import json
|
| 22 |
+
import os
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 30 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 31 |
+
from diffusers.utils import replace_example_docstring
|
| 32 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 33 |
+
|
| 34 |
+
EXAMPLE_DOC_STRING = """
|
| 35 |
+
Examples:
|
| 36 |
+
```py
|
| 37 |
+
>>> from pathlib import Path
|
| 38 |
+
>>> import torch
|
| 39 |
+
>>> from diffusers import DiffusionPipeline
|
| 40 |
+
|
| 41 |
+
>>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve()
|
| 42 |
+
>>> pipe = DiffusionPipeline.from_pretrained(
|
| 43 |
+
... str(model_dir),
|
| 44 |
+
... local_files_only=True,
|
| 45 |
+
... custom_pipeline=str(model_dir / "pipeline.py"),
|
| 46 |
+
... trust_remote_code=True,
|
| 47 |
+
... torch_dtype=torch.float32,
|
| 48 |
+
... )
|
| 49 |
+
>>> pipe.to("cuda")
|
| 50 |
+
|
| 51 |
+
>>> generator = torch.Generator(device="cuda").manual_seed(42)
|
| 52 |
+
>>> image = pipe(
|
| 53 |
+
... class_labels=207,
|
| 54 |
+
... num_inference_steps=32,
|
| 55 |
+
... guidance_scale=1.0,
|
| 56 |
+
... generator=generator,
|
| 57 |
+
... ).images[0]
|
| 58 |
+
>>> image.save("demo.png")
|
| 59 |
+
```
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
# Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py).
|
| 63 |
+
_STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28])
|
| 64 |
+
_STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE
|
| 65 |
+
|
| 66 |
+
class EDM2Pipeline(DiffusionPipeline):
|
| 67 |
+
r"""
|
| 68 |
+
Pipeline for class-conditional image generation with EDM2
|
| 69 |
+
([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)).
|
| 70 |
+
|
| 71 |
+
Parameters:
|
| 72 |
+
unet ([`EDM2UNet2DModel`]):
|
| 73 |
+
Main magnitude-preserving U-Net with EDM preconditioning.
|
| 74 |
+
scheduler ([`EDMEulerScheduler`]):
|
| 75 |
+
Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in
|
| 76 |
+
the pipeline because the UNet returns denoised latents rather than noise predictions.
|
| 77 |
+
vae ([`AutoencoderKL`], *optional*):
|
| 78 |
+
Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`.
|
| 79 |
+
gnet ([`EDM2UNet2DModel`], *optional*):
|
| 80 |
+
Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`).
|
| 81 |
+
id2label (`dict[int, str]`, *optional*):
|
| 82 |
+
ImageNet class id to English label mapping.
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
model_cpu_offload_seq = "unet->gnet->vae"
|
| 86 |
+
_optional_components = ["vae", "gnet"]
|
| 87 |
+
|
| 88 |
+
def __init__(
|
| 89 |
+
self,
|
| 90 |
+
unet,
|
| 91 |
+
scheduler,
|
| 92 |
+
vae=None,
|
| 93 |
+
gnet=None,
|
| 94 |
+
id2label: Optional[Dict[Union[int, str], str]] = None,
|
| 95 |
+
) -> None:
|
| 96 |
+
super().__init__()
|
| 97 |
+
self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet)
|
| 98 |
+
self._id2label = self._normalize_id2label(id2label)
|
| 99 |
+
self.labels = self._build_label2id(self._id2label)
|
| 100 |
+
self._labels_loaded_from_model_index = bool(self._id2label)
|
| 101 |
+
self.vae_scale_factor = 8 if self.vae is not None else 1
|
| 102 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
|
| 103 |
+
|
| 104 |
+
@staticmethod
|
| 105 |
+
def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
|
| 106 |
+
if not id2label:
|
| 107 |
+
return {}
|
| 108 |
+
return {int(key): value for key, value in id2label.items()}
|
| 109 |
+
|
| 110 |
+
@staticmethod
|
| 111 |
+
def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
|
| 112 |
+
label2id: Dict[str, int] = {}
|
| 113 |
+
for class_id, value in id2label.items():
|
| 114 |
+
for synonym in value.split(","):
|
| 115 |
+
synonym = synonym.strip()
|
| 116 |
+
if synonym:
|
| 117 |
+
label2id[synonym] = int(class_id)
|
| 118 |
+
return dict(sorted(label2id.items()))
|
| 119 |
+
|
| 120 |
+
def _ensure_labels_loaded(self) -> None:
|
| 121 |
+
if self._labels_loaded_from_model_index:
|
| 122 |
+
return
|
| 123 |
+
loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
|
| 124 |
+
if loaded:
|
| 125 |
+
self._id2label = loaded
|
| 126 |
+
self.labels = self._build_label2id(self._id2label)
|
| 127 |
+
self._labels_loaded_from_model_index = True
|
| 128 |
+
|
| 129 |
+
@staticmethod
|
| 130 |
+
def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
|
| 131 |
+
if not variant_path:
|
| 132 |
+
return {}
|
| 133 |
+
model_index_path = Path(variant_path).resolve() / "model_index.json"
|
| 134 |
+
if not model_index_path.is_file():
|
| 135 |
+
return {}
|
| 136 |
+
raw = json.loads(model_index_path.read_text(encoding="utf-8"))
|
| 137 |
+
id2label = raw.get("id2label")
|
| 138 |
+
if not isinstance(id2label, dict):
|
| 139 |
+
return {}
|
| 140 |
+
return {int(key): value for key, value in id2label.items()}
|
| 141 |
+
|
| 142 |
+
@property
|
| 143 |
+
def id2label(self) -> Dict[int, str]:
|
| 144 |
+
r"""ImageNet class id to English label string (comma-separated synonyms)."""
|
| 145 |
+
self._ensure_labels_loaded()
|
| 146 |
+
return self._id2label
|
| 147 |
+
|
| 148 |
+
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
|
| 149 |
+
r"""
|
| 150 |
+
Map ImageNet label strings to class ids.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
label (`str` or `list[str]`):
|
| 154 |
+
One or more English label strings that match entries in `id2label`.
|
| 155 |
+
"""
|
| 156 |
+
self._ensure_labels_loaded()
|
| 157 |
+
if not self.labels:
|
| 158 |
+
raise ValueError("No English labels loaded. Add `id2label` to model_index.json.")
|
| 159 |
+
labels = [label] if isinstance(label, str) else list(label)
|
| 160 |
+
missing = [item for item in labels if item not in self.labels]
|
| 161 |
+
if missing:
|
| 162 |
+
preview = ", ".join(list(self.labels.keys())[:8])
|
| 163 |
+
raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
|
| 164 |
+
return [self.labels[item] for item in labels]
|
| 165 |
+
|
| 166 |
+
def _default_image_size(self) -> int:
|
| 167 |
+
latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64)))
|
| 168 |
+
return latent_size * self.vae_scale_factor
|
| 169 |
+
|
| 170 |
+
def check_inputs(
|
| 171 |
+
self,
|
| 172 |
+
height: int,
|
| 173 |
+
width: int,
|
| 174 |
+
num_inference_steps: int,
|
| 175 |
+
guidance_scale: float,
|
| 176 |
+
output_type: str,
|
| 177 |
+
) -> None:
|
| 178 |
+
if num_inference_steps < 1:
|
| 179 |
+
raise ValueError("num_inference_steps must be >= 1.")
|
| 180 |
+
if guidance_scale < 1.0:
|
| 181 |
+
raise ValueError("guidance_scale must be >= 1.0.")
|
| 182 |
+
if guidance_scale > 1.0 and self.gnet is None:
|
| 183 |
+
raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).")
|
| 184 |
+
if output_type not in {"pil", "np", "pt", "latent"}:
|
| 185 |
+
raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
|
| 186 |
+
|
| 187 |
+
native_size = self._default_image_size()
|
| 188 |
+
if height != native_size or width != native_size:
|
| 189 |
+
raise ValueError(
|
| 190 |
+
f"EDM2 expects native resolution height=width={native_size}. "
|
| 191 |
+
f"Got height={height}, width={width}."
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
def _normalize_class_labels(
|
| 195 |
+
self,
|
| 196 |
+
class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]],
|
| 197 |
+
batch_size: int,
|
| 198 |
+
device: torch.device,
|
| 199 |
+
) -> Optional[torch.Tensor]:
|
| 200 |
+
label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0)))
|
| 201 |
+
if label_dim == 0:
|
| 202 |
+
return None
|
| 203 |
+
if class_labels is None:
|
| 204 |
+
indices = torch.randint(label_dim, size=(batch_size,), device=device)
|
| 205 |
+
return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
|
| 206 |
+
|
| 207 |
+
if isinstance(class_labels, str):
|
| 208 |
+
class_labels = self.get_label_ids(class_labels)[0]
|
| 209 |
+
elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str):
|
| 210 |
+
class_labels = self.get_label_ids(list(class_labels))
|
| 211 |
+
|
| 212 |
+
if isinstance(class_labels, int):
|
| 213 |
+
indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long)
|
| 214 |
+
elif isinstance(class_labels, torch.Tensor):
|
| 215 |
+
if class_labels.ndim == 2:
|
| 216 |
+
labels = class_labels.to(device=device, dtype=torch.float32)
|
| 217 |
+
if labels.shape[0] != batch_size:
|
| 218 |
+
raise ValueError(f"class_labels batch must match batch_size={batch_size}.")
|
| 219 |
+
return labels
|
| 220 |
+
indices = class_labels.to(device=device, dtype=torch.long).flatten()
|
| 221 |
+
else:
|
| 222 |
+
indices = torch.tensor(list(class_labels), device=device, dtype=torch.long)
|
| 223 |
+
|
| 224 |
+
if indices.numel() == 1 and batch_size > 1:
|
| 225 |
+
indices = indices.repeat(batch_size)
|
| 226 |
+
if indices.numel() != batch_size:
|
| 227 |
+
raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.")
|
| 228 |
+
return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
|
| 229 |
+
|
| 230 |
+
def prepare_latents(
|
| 231 |
+
self,
|
| 232 |
+
batch_size: int,
|
| 233 |
+
height: int,
|
| 234 |
+
width: int,
|
| 235 |
+
dtype: torch.dtype,
|
| 236 |
+
device: torch.device,
|
| 237 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]],
|
| 238 |
+
) -> torch.Tensor:
|
| 239 |
+
in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4)))
|
| 240 |
+
latent_size = height // self.vae_scale_factor
|
| 241 |
+
return randn_tensor(
|
| 242 |
+
(batch_size, in_channels, latent_size, latent_size),
|
| 243 |
+
generator=generator,
|
| 244 |
+
device=device,
|
| 245 |
+
dtype=torch.float32,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
|
| 249 |
+
if output_type == "latent":
|
| 250 |
+
return latents
|
| 251 |
+
|
| 252 |
+
in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 3)))
|
| 253 |
+
if self.vae is None:
|
| 254 |
+
image = (latents.to(torch.float32) * 127.5 + 128).clip(0, 255) / 255.0
|
| 255 |
+
return self.image_processor.postprocess(image, output_type=output_type)
|
| 256 |
+
|
| 257 |
+
if in_channels == 4:
|
| 258 |
+
x = latents.to(torch.float32)
|
| 259 |
+
scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
|
| 260 |
+
bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
|
| 261 |
+
x = (x - bias) / scale
|
| 262 |
+
else:
|
| 263 |
+
x = latents.to(torch.float32)
|
| 264 |
+
|
| 265 |
+
vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype
|
| 266 |
+
image = self.vae.decode(x.to(dtype=vae_dtype)).sample.to(torch.float32).clamp(0, 1)
|
| 267 |
+
|
| 268 |
+
return self.image_processor.postprocess(image, output_type=output_type)
|
| 269 |
+
|
| 270 |
+
@staticmethod
|
| 271 |
+
def _apply_autoguidance(
|
| 272 |
+
main: torch.Tensor,
|
| 273 |
+
ref: torch.Tensor,
|
| 274 |
+
guidance_scale: float,
|
| 275 |
+
) -> torch.Tensor:
|
| 276 |
+
return ref.lerp(main, guidance_scale)
|
| 277 |
+
|
| 278 |
+
@staticmethod
|
| 279 |
+
def _sample_edm2_heun(
|
| 280 |
+
denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
| 281 |
+
noise: torch.Tensor,
|
| 282 |
+
sigmas: torch.Tensor,
|
| 283 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 284 |
+
progress_bar: Optional[Callable[[Iterable], Iterable]] = None,
|
| 285 |
+
dtype: torch.dtype = torch.float32,
|
| 286 |
+
) -> torch.Tensor:
|
| 287 |
+
"""NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0)."""
|
| 288 |
+
x_next = noise.to(dtype) * sigmas[0]
|
| 289 |
+
|
| 290 |
+
sigma_pairs = list(zip(sigmas[:-1], sigmas[1:]))
|
| 291 |
+
if progress_bar is not None:
|
| 292 |
+
sigma_pairs = progress_bar(sigma_pairs)
|
| 293 |
+
|
| 294 |
+
num_steps = len(sigma_pairs)
|
| 295 |
+
for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs):
|
| 296 |
+
x_hat, sigma_hat = x_next, sigma_cur
|
| 297 |
+
d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat
|
| 298 |
+
x_next = x_hat + (sigma_next - sigma_hat) * d_cur
|
| 299 |
+
if i < num_steps - 1:
|
| 300 |
+
d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next
|
| 301 |
+
x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime)
|
| 302 |
+
return x_next
|
| 303 |
+
|
| 304 |
+
@torch.inference_mode()
|
| 305 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 306 |
+
def __call__(
|
| 307 |
+
self,
|
| 308 |
+
class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None,
|
| 309 |
+
batch_size: int = 1,
|
| 310 |
+
height: Optional[int] = None,
|
| 311 |
+
width: Optional[int] = None,
|
| 312 |
+
num_inference_steps: int = 32,
|
| 313 |
+
guidance_scale: float = 1.0,
|
| 314 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 315 |
+
output_type: str = "pil",
|
| 316 |
+
return_dict: bool = True,
|
| 317 |
+
) -> Union[ImagePipelineOutput, Tuple]:
|
| 318 |
+
r"""
|
| 319 |
+
Generate class-conditional images with EDM2.
|
| 320 |
+
|
| 321 |
+
Args:
|
| 322 |
+
class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*):
|
| 323 |
+
ImageNet class indices, English label strings, or one-hot float tensors.
|
| 324 |
+
Random classes are sampled when omitted on conditional models.
|
| 325 |
+
batch_size (`int`, defaults to `1`):
|
| 326 |
+
Number of images to generate.
|
| 327 |
+
height (`int`, *optional*):
|
| 328 |
+
Output height in pixels. Defaults to the pretrained native resolution.
|
| 329 |
+
width (`int`, *optional*):
|
| 330 |
+
Output width in pixels. Defaults to the pretrained native resolution.
|
| 331 |
+
num_inference_steps (`int`, defaults to `32`):
|
| 332 |
+
Number of EDM2 Heun steps (NVlabs default).
|
| 333 |
+
guidance_scale (`float`, defaults to `1.0`):
|
| 334 |
+
Autoguidance strength. Values above `1.0` blend the main net with `gnet`
|
| 335 |
+
via `gnet_output.lerp(unet_output, guidance_scale)`.
|
| 336 |
+
generator (`torch.Generator`, *optional*):
|
| 337 |
+
RNG for reproducibility.
|
| 338 |
+
output_type (`str`, defaults to `"pil"`):
|
| 339 |
+
`"pil"`, `"np"`, `"pt"`, or `"latent"`.
|
| 340 |
+
return_dict (`bool`, defaults to `True`):
|
| 341 |
+
Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if True.
|
| 342 |
+
|
| 343 |
+
Examples:
|
| 344 |
+
<!-- this section is replaced by replace_example_docstring -->
|
| 345 |
+
"""
|
| 346 |
+
default_size = self._default_image_size()
|
| 347 |
+
height = int(height or default_size)
|
| 348 |
+
width = int(width or default_size)
|
| 349 |
+
self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type)
|
| 350 |
+
|
| 351 |
+
device = self._execution_device
|
| 352 |
+
dtype = self.unet.dtype
|
| 353 |
+
labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
|
| 354 |
+
noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
|
| 355 |
+
|
| 356 |
+
def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
|
| 357 |
+
sigma_batch = sigma.reshape(1).expand(batch_size)
|
| 358 |
+
main = self.unet(
|
| 359 |
+
sample=x,
|
| 360 |
+
sigma=sigma_batch,
|
| 361 |
+
class_labels=labels,
|
| 362 |
+
force_fp32=True,
|
| 363 |
+
).sample
|
| 364 |
+
if guidance_scale == 1.0 or self.gnet is None:
|
| 365 |
+
return main.to(torch.float32)
|
| 366 |
+
ref = self.gnet(
|
| 367 |
+
sample=x,
|
| 368 |
+
sigma=sigma_batch,
|
| 369 |
+
class_labels=labels,
|
| 370 |
+
force_fp32=True,
|
| 371 |
+
).sample
|
| 372 |
+
return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32)
|
| 373 |
+
|
| 374 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 375 |
+
latents = self._sample_edm2_heun(
|
| 376 |
+
denoise_fn=denoise_fn,
|
| 377 |
+
noise=noise,
|
| 378 |
+
sigmas=self.scheduler.sigmas.to(device),
|
| 379 |
+
generator=generator,
|
| 380 |
+
progress_bar=self.progress_bar,
|
| 381 |
+
dtype=torch.float32,
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
image = self.decode_latents(latents, output_type=output_type)
|
| 385 |
+
if not return_dict:
|
| 386 |
+
return (image, latents)
|
| 387 |
+
return ImagePipelineOutput(images=image)
|
| 388 |
+
|
| 389 |
+
@classmethod
|
| 390 |
+
def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None):
|
| 391 |
+
vae_dir = os.path.join(pretrained_model_name_or_path, "vae")
|
| 392 |
+
if os.path.isdir(vae_dir):
|
| 393 |
+
try:
|
| 394 |
+
|
| 395 |
+
return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype)
|
| 396 |
+
except Exception:
|
| 397 |
+
return None
|
| 398 |
+
|
| 399 |
+
vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt")
|
| 400 |
+
if os.path.isfile(vae_hint):
|
| 401 |
+
with open(vae_hint, "r", encoding="utf-8") as f:
|
| 402 |
+
hub_id = f.read().strip()
|
| 403 |
+
if hub_id:
|
| 404 |
+
|
| 405 |
+
return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype)
|
| 406 |
+
return None
|
edm2-img512-s-fid/scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "EDMEulerScheduler",
|
| 3 |
+
"final_sigmas_type": "zero",
|
| 4 |
+
"num_train_timesteps": 1000,
|
| 5 |
+
"prediction_type": "epsilon",
|
| 6 |
+
"rho": 7.0,
|
| 7 |
+
"sigma_data": 0.5,
|
| 8 |
+
"sigma_max": 80.0,
|
| 9 |
+
"sigma_min": 0.002,
|
| 10 |
+
"sigma_schedule": "karras"
|
| 11 |
+
}
|
edm2-img512-s-fid/unet/config.json
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "EDM2UNet2DModel",
|
| 3 |
+
"attn_balance": 0.3,
|
| 4 |
+
"attn_resolutions": [
|
| 5 |
+
16,
|
| 6 |
+
8
|
| 7 |
+
],
|
| 8 |
+
"channel_mult": [
|
| 9 |
+
1,
|
| 10 |
+
2,
|
| 11 |
+
3,
|
| 12 |
+
4
|
| 13 |
+
],
|
| 14 |
+
"channel_mult_emb": 4,
|
| 15 |
+
"channel_mult_noise": 1,
|
| 16 |
+
"channels_per_head": 64,
|
| 17 |
+
"clip_act": 256,
|
| 18 |
+
"concat_balance": 0.5,
|
| 19 |
+
"dropout": 0.0,
|
| 20 |
+
"in_channels": 4,
|
| 21 |
+
"label_balance": 0.5,
|
| 22 |
+
"logvar_channels": 128,
|
| 23 |
+
"model_channels": 192,
|
| 24 |
+
"num_blocks": 3,
|
| 25 |
+
"num_class_embeds": 1000,
|
| 26 |
+
"out_channels": 4,
|
| 27 |
+
"res_balance": 0.3,
|
| 28 |
+
"sample_size": 64,
|
| 29 |
+
"sigma_data": 0.5,
|
| 30 |
+
"use_fp16": true
|
| 31 |
+
}
|
edm2-img512-s-fid/unet/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5dee937e117e2367ede680aae4edf96635ff4debb9ae73f2617111991aa83d61
|
| 3 |
+
size 1120876188
|
edm2-img512-s-fid/unet/unet_edm2.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 12 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 13 |
+
from diffusers.utils import BaseOutput
|
| 14 |
+
except ImportError: # pragma: no cover
|
| 15 |
+
class ModelMixin(torch.nn.Module):
|
| 16 |
+
pass
|
| 17 |
+
|
| 18 |
+
class ConfigMixin:
|
| 19 |
+
config = {}
|
| 20 |
+
|
| 21 |
+
def register_to_config(self, **kwargs):
|
| 22 |
+
self.config = kwargs
|
| 23 |
+
|
| 24 |
+
def register_to_config(func):
|
| 25 |
+
return func
|
| 26 |
+
|
| 27 |
+
@dataclass
|
| 28 |
+
class BaseOutput:
|
| 29 |
+
pass
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor:
|
| 33 |
+
if dim is None:
|
| 34 |
+
dim = list(range(1, x.ndim))
|
| 35 |
+
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
|
| 36 |
+
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
|
| 37 |
+
return x / norm.to(x.dtype)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor:
|
| 41 |
+
if mode == "keep":
|
| 42 |
+
return x
|
| 43 |
+
filt = np.float32(f)
|
| 44 |
+
pad = (len(filt) - 1) // 2
|
| 45 |
+
filt = filt / filt.sum()
|
| 46 |
+
filt = np.outer(filt, filt)[np.newaxis, np.newaxis, :, :]
|
| 47 |
+
filt = torch.as_tensor(filt, dtype=x.dtype, device=x.device)
|
| 48 |
+
c = x.shape[1]
|
| 49 |
+
if mode == "down":
|
| 50 |
+
return torch.nn.functional.conv2d(x, filt.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
|
| 51 |
+
return torch.nn.functional.conv_transpose2d(x, (filt * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def mp_silu(x: torch.Tensor) -> torch.Tensor:
|
| 55 |
+
return torch.nn.functional.silu(x) / 0.596
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor:
|
| 59 |
+
return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor:
|
| 63 |
+
na = a.shape[dim]
|
| 64 |
+
nb = b.shape[dim]
|
| 65 |
+
c = math.sqrt((na + nb) / ((1 - t) ** 2 + t ** 2))
|
| 66 |
+
wa = c / math.sqrt(na) * (1 - t)
|
| 67 |
+
wb = c / math.sqrt(nb) * t
|
| 68 |
+
return torch.cat([wa * a, wb * b], dim=dim)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class MPFourier(torch.nn.Module):
|
| 72 |
+
def __init__(self, num_channels: int, bandwidth: float = 1):
|
| 73 |
+
super().__init__()
|
| 74 |
+
self.register_buffer("freqs", 2 * math.pi * torch.randn(num_channels) * bandwidth)
|
| 75 |
+
self.register_buffer("phases", 2 * math.pi * torch.rand(num_channels))
|
| 76 |
+
|
| 77 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 78 |
+
y = x.to(torch.float32).ger(self.freqs.to(torch.float32))
|
| 79 |
+
y = y + self.phases.to(torch.float32)
|
| 80 |
+
y = y.cos() * math.sqrt(2)
|
| 81 |
+
return y.to(x.dtype)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class MPConv(torch.nn.Module):
|
| 85 |
+
def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]):
|
| 86 |
+
super().__init__()
|
| 87 |
+
self.out_channels = out_channels
|
| 88 |
+
self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))
|
| 89 |
+
|
| 90 |
+
def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor:
|
| 91 |
+
w = self.weight.to(torch.float32)
|
| 92 |
+
if self.training:
|
| 93 |
+
with torch.no_grad():
|
| 94 |
+
self.weight.copy_(normalize(w))
|
| 95 |
+
w = normalize(w)
|
| 96 |
+
w = w * (gain / math.sqrt(w[0].numel()))
|
| 97 |
+
w = w.to(x.dtype)
|
| 98 |
+
if w.ndim == 2:
|
| 99 |
+
return x @ w.t()
|
| 100 |
+
return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1] // 2,))
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class Block(torch.nn.Module):
|
| 104 |
+
def __init__(
|
| 105 |
+
self,
|
| 106 |
+
in_channels: int,
|
| 107 |
+
out_channels: int,
|
| 108 |
+
emb_channels: int,
|
| 109 |
+
flavor: str = "enc",
|
| 110 |
+
resample_mode: str = "keep",
|
| 111 |
+
resample_filter: List[float] = [1, 1],
|
| 112 |
+
attention: bool = False,
|
| 113 |
+
channels_per_head: int = 64,
|
| 114 |
+
dropout: float = 0.0,
|
| 115 |
+
res_balance: float = 0.3,
|
| 116 |
+
attn_balance: float = 0.3,
|
| 117 |
+
clip_act: Optional[float] = 256,
|
| 118 |
+
):
|
| 119 |
+
super().__init__()
|
| 120 |
+
self.out_channels = out_channels
|
| 121 |
+
self.flavor = flavor
|
| 122 |
+
self.resample_filter = resample_filter
|
| 123 |
+
self.resample_mode = resample_mode
|
| 124 |
+
self.num_heads = out_channels // channels_per_head if attention else 0
|
| 125 |
+
self.dropout = dropout
|
| 126 |
+
self.res_balance = res_balance
|
| 127 |
+
self.attn_balance = attn_balance
|
| 128 |
+
self.clip_act = clip_act
|
| 129 |
+
self.emb_gain = torch.nn.Parameter(torch.zeros([]))
|
| 130 |
+
self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3))
|
| 131 |
+
self.emb_linear = MPConv(emb_channels, out_channels, kernel=())
|
| 132 |
+
self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3))
|
| 133 |
+
self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None
|
| 134 |
+
self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None
|
| 135 |
+
self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None
|
| 136 |
+
|
| 137 |
+
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
|
| 138 |
+
x = resample(x, f=self.resample_filter, mode=self.resample_mode)
|
| 139 |
+
if self.flavor == "enc":
|
| 140 |
+
if self.conv_skip is not None:
|
| 141 |
+
x = self.conv_skip(x)
|
| 142 |
+
x = normalize(x, dim=[1])
|
| 143 |
+
|
| 144 |
+
y = self.conv_res0(mp_silu(x))
|
| 145 |
+
c = self.emb_linear(emb, gain=self.emb_gain) + 1
|
| 146 |
+
y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype))
|
| 147 |
+
if self.training and self.dropout:
|
| 148 |
+
y = torch.nn.functional.dropout(y, p=self.dropout)
|
| 149 |
+
y = self.conv_res1(y)
|
| 150 |
+
|
| 151 |
+
if self.flavor == "dec" and self.conv_skip is not None:
|
| 152 |
+
x = self.conv_skip(x)
|
| 153 |
+
x = mp_sum(x, y, t=self.res_balance)
|
| 154 |
+
|
| 155 |
+
if self.num_heads:
|
| 156 |
+
y = self.attn_qkv(x)
|
| 157 |
+
y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3])
|
| 158 |
+
q, k, v = normalize(y, dim=[2]).unbind(3)
|
| 159 |
+
w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3)
|
| 160 |
+
y = torch.einsum("nhqk,nhck->nhcq", w, v)
|
| 161 |
+
y = self.attn_proj(y.reshape(*x.shape))
|
| 162 |
+
x = mp_sum(x, y, t=self.attn_balance)
|
| 163 |
+
|
| 164 |
+
if self.clip_act is not None:
|
| 165 |
+
x = x.clip_(-self.clip_act, self.clip_act)
|
| 166 |
+
return x
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class EDM2UNet(torch.nn.Module):
|
| 170 |
+
def __init__(
|
| 171 |
+
self,
|
| 172 |
+
img_resolution: int,
|
| 173 |
+
img_channels: int,
|
| 174 |
+
label_dim: int,
|
| 175 |
+
model_channels: int = 192,
|
| 176 |
+
channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
|
| 177 |
+
channel_mult_noise: Optional[int] = None,
|
| 178 |
+
channel_mult_emb: Optional[int] = None,
|
| 179 |
+
num_blocks: int = 3,
|
| 180 |
+
attn_resolutions: Tuple[int, ...] = (16, 8),
|
| 181 |
+
label_balance: float = 0.5,
|
| 182 |
+
concat_balance: float = 0.5,
|
| 183 |
+
**block_kwargs,
|
| 184 |
+
):
|
| 185 |
+
super().__init__()
|
| 186 |
+
cblock = [model_channels * x for x in channel_mult]
|
| 187 |
+
cnoise = model_channels * channel_mult_noise if channel_mult_noise is not None else cblock[0]
|
| 188 |
+
cemb = model_channels * channel_mult_emb if channel_mult_emb is not None else max(cblock)
|
| 189 |
+
self.label_balance = label_balance
|
| 190 |
+
self.concat_balance = concat_balance
|
| 191 |
+
self.out_gain = torch.nn.Parameter(torch.zeros([]))
|
| 192 |
+
|
| 193 |
+
self.emb_fourier = MPFourier(cnoise)
|
| 194 |
+
self.emb_noise = MPConv(cnoise, cemb, kernel=())
|
| 195 |
+
self.emb_label = MPConv(label_dim, cemb, kernel=()) if label_dim else None
|
| 196 |
+
|
| 197 |
+
self.enc = torch.nn.ModuleDict()
|
| 198 |
+
cout = img_channels + 1
|
| 199 |
+
for level, channels in enumerate(cblock):
|
| 200 |
+
res = img_resolution >> level
|
| 201 |
+
if level == 0:
|
| 202 |
+
cin = cout
|
| 203 |
+
cout = channels
|
| 204 |
+
self.enc[f"{res}x{res}_conv"] = MPConv(cin, cout, kernel=(3, 3))
|
| 205 |
+
else:
|
| 206 |
+
self.enc[f"{res}x{res}_down"] = Block(cout, cout, cemb, flavor="enc", resample_mode="down", **block_kwargs)
|
| 207 |
+
for idx in range(num_blocks):
|
| 208 |
+
cin = cout
|
| 209 |
+
cout = channels
|
| 210 |
+
self.enc[f"{res}x{res}_block{idx}"] = Block(
|
| 211 |
+
cin,
|
| 212 |
+
cout,
|
| 213 |
+
cemb,
|
| 214 |
+
flavor="enc",
|
| 215 |
+
attention=(res in attn_resolutions),
|
| 216 |
+
**block_kwargs,
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
self.dec = torch.nn.ModuleDict()
|
| 220 |
+
skips = [block.out_channels for block in self.enc.values()]
|
| 221 |
+
for level, channels in reversed(list(enumerate(cblock))):
|
| 222 |
+
res = img_resolution >> level
|
| 223 |
+
if level == len(cblock) - 1:
|
| 224 |
+
self.dec[f"{res}x{res}_in0"] = Block(cout, cout, cemb, flavor="dec", attention=True, **block_kwargs)
|
| 225 |
+
self.dec[f"{res}x{res}_in1"] = Block(cout, cout, cemb, flavor="dec", **block_kwargs)
|
| 226 |
+
else:
|
| 227 |
+
self.dec[f"{res}x{res}_up"] = Block(cout, cout, cemb, flavor="dec", resample_mode="up", **block_kwargs)
|
| 228 |
+
for idx in range(num_blocks + 1):
|
| 229 |
+
cin = cout + skips.pop()
|
| 230 |
+
cout = channels
|
| 231 |
+
self.dec[f"{res}x{res}_block{idx}"] = Block(
|
| 232 |
+
cin,
|
| 233 |
+
cout,
|
| 234 |
+
cemb,
|
| 235 |
+
flavor="dec",
|
| 236 |
+
attention=(res in attn_resolutions),
|
| 237 |
+
**block_kwargs,
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
self.out_conv = MPConv(cout, img_channels, kernel=(3, 3))
|
| 241 |
+
|
| 242 |
+
def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor:
|
| 243 |
+
emb = self.emb_noise(self.emb_fourier(noise_labels))
|
| 244 |
+
if self.emb_label is not None:
|
| 245 |
+
if class_labels is None:
|
| 246 |
+
raise ValueError("class_labels are required for conditional EDM2UNet.")
|
| 247 |
+
emb = mp_sum(emb, self.emb_label(class_labels * math.sqrt(class_labels.shape[1])), t=self.label_balance)
|
| 248 |
+
emb = mp_silu(emb)
|
| 249 |
+
|
| 250 |
+
x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1)
|
| 251 |
+
skips = []
|
| 252 |
+
for name, block in self.enc.items():
|
| 253 |
+
x = block(x) if "conv" in name else block(x, emb)
|
| 254 |
+
skips.append(x)
|
| 255 |
+
|
| 256 |
+
for name, block in self.dec.items():
|
| 257 |
+
if "block" in name:
|
| 258 |
+
x = mp_cat(x, skips.pop(), t=self.concat_balance)
|
| 259 |
+
x = block(x, emb)
|
| 260 |
+
return self.out_conv(x, gain=self.out_gain)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
@dataclass
|
| 264 |
+
class EDM2UNet2DOutput(BaseOutput):
|
| 265 |
+
sample: torch.Tensor
|
| 266 |
+
logvar: Optional[torch.Tensor] = None
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
_CONFIG_KEYS = (
|
| 271 |
+
"sample_size",
|
| 272 |
+
"in_channels",
|
| 273 |
+
"out_channels",
|
| 274 |
+
"num_class_embeds",
|
| 275 |
+
"use_fp16",
|
| 276 |
+
"sigma_data",
|
| 277 |
+
"logvar_channels",
|
| 278 |
+
"model_channels",
|
| 279 |
+
"channel_mult",
|
| 280 |
+
"channel_mult_noise",
|
| 281 |
+
"channel_mult_emb",
|
| 282 |
+
"num_blocks",
|
| 283 |
+
"attn_resolutions",
|
| 284 |
+
"label_balance",
|
| 285 |
+
"concat_balance",
|
| 286 |
+
"dropout",
|
| 287 |
+
"channels_per_head",
|
| 288 |
+
"res_balance",
|
| 289 |
+
"attn_balance",
|
| 290 |
+
"clip_act",
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class EDM2UNet2DModel(ModelMixin, ConfigMixin):
|
| 295 |
+
@register_to_config
|
| 296 |
+
def __init__(
|
| 297 |
+
self,
|
| 298 |
+
sample_size: int = 64,
|
| 299 |
+
in_channels: int = 4,
|
| 300 |
+
out_channels: int = 4,
|
| 301 |
+
num_class_embeds: int = 0,
|
| 302 |
+
use_fp16: bool = True,
|
| 303 |
+
sigma_data: float = 0.5,
|
| 304 |
+
logvar_channels: int = 128,
|
| 305 |
+
model_channels: int = 192,
|
| 306 |
+
channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
|
| 307 |
+
channel_mult_noise: Optional[int] = None,
|
| 308 |
+
channel_mult_emb: Optional[int] = None,
|
| 309 |
+
num_blocks: int = 3,
|
| 310 |
+
attn_resolutions: Tuple[int, ...] = (16, 8),
|
| 311 |
+
label_balance: float = 0.5,
|
| 312 |
+
concat_balance: float = 0.5,
|
| 313 |
+
dropout: float = 0.0,
|
| 314 |
+
channels_per_head: int = 64,
|
| 315 |
+
res_balance: float = 0.3,
|
| 316 |
+
attn_balance: float = 0.3,
|
| 317 |
+
clip_act: Optional[float] = 256,
|
| 318 |
+
):
|
| 319 |
+
super().__init__()
|
| 320 |
+
self.sample_size = sample_size
|
| 321 |
+
self.in_channels = in_channels
|
| 322 |
+
self.out_channels = out_channels
|
| 323 |
+
self.num_class_embeds = num_class_embeds
|
| 324 |
+
self.use_fp16 = use_fp16
|
| 325 |
+
self.sigma_data = sigma_data
|
| 326 |
+
self.model_channels = model_channels
|
| 327 |
+
self.channel_mult = channel_mult
|
| 328 |
+
self.channel_mult_noise = channel_mult_noise
|
| 329 |
+
self.channel_mult_emb = channel_mult_emb
|
| 330 |
+
self.num_blocks = num_blocks
|
| 331 |
+
self.attn_resolutions = attn_resolutions
|
| 332 |
+
self.label_balance = label_balance
|
| 333 |
+
self.concat_balance = concat_balance
|
| 334 |
+
self.dropout = dropout
|
| 335 |
+
self.channels_per_head = channels_per_head
|
| 336 |
+
self.res_balance = res_balance
|
| 337 |
+
self.attn_balance = attn_balance
|
| 338 |
+
self.clip_act = clip_act
|
| 339 |
+
self.unet = EDM2UNet(
|
| 340 |
+
img_resolution=sample_size,
|
| 341 |
+
img_channels=in_channels,
|
| 342 |
+
label_dim=num_class_embeds,
|
| 343 |
+
model_channels=model_channels,
|
| 344 |
+
channel_mult=channel_mult,
|
| 345 |
+
channel_mult_noise=channel_mult_noise,
|
| 346 |
+
channel_mult_emb=channel_mult_emb,
|
| 347 |
+
num_blocks=num_blocks,
|
| 348 |
+
attn_resolutions=attn_resolutions,
|
| 349 |
+
label_balance=label_balance,
|
| 350 |
+
concat_balance=concat_balance,
|
| 351 |
+
dropout=dropout,
|
| 352 |
+
channels_per_head=channels_per_head,
|
| 353 |
+
res_balance=res_balance,
|
| 354 |
+
attn_balance=attn_balance,
|
| 355 |
+
clip_act=clip_act,
|
| 356 |
+
)
|
| 357 |
+
self.logvar_fourier = MPFourier(logvar_channels)
|
| 358 |
+
self.logvar_linear = MPConv(logvar_channels, 1, kernel=())
|
| 359 |
+
|
| 360 |
+
def forward(
|
| 361 |
+
self,
|
| 362 |
+
sample: torch.Tensor,
|
| 363 |
+
sigma: torch.Tensor,
|
| 364 |
+
class_labels: Optional[torch.Tensor] = None,
|
| 365 |
+
force_fp32: bool = False,
|
| 366 |
+
return_logvar: bool = False,
|
| 367 |
+
return_dict: bool = True,
|
| 368 |
+
) -> EDM2UNet2DOutput:
|
| 369 |
+
x = sample.to(torch.float32)
|
| 370 |
+
sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
|
| 371 |
+
if self.num_class_embeds == 0:
|
| 372 |
+
class_labels = None
|
| 373 |
+
else:
|
| 374 |
+
if class_labels is None:
|
| 375 |
+
class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device)
|
| 376 |
+
class_labels = class_labels.to(torch.float32).reshape(-1, self.num_class_embeds)
|
| 377 |
+
dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") else torch.float32
|
| 378 |
+
|
| 379 |
+
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
|
| 380 |
+
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
|
| 381 |
+
c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
|
| 382 |
+
c_noise = sigma.flatten().log() / 4
|
| 383 |
+
|
| 384 |
+
x_in = (c_in * x).to(dtype)
|
| 385 |
+
f_x = self.unet(x_in, c_noise, class_labels)
|
| 386 |
+
d_x = c_skip * x + c_out * f_x.to(torch.float32)
|
| 387 |
+
|
| 388 |
+
logvar = None
|
| 389 |
+
if return_logvar:
|
| 390 |
+
logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)
|
| 391 |
+
|
| 392 |
+
if not return_dict:
|
| 393 |
+
return (d_x, logvar)
|
| 394 |
+
return EDM2UNet2DOutput(sample=d_x, logvar=logvar)
|
| 395 |
+
|
| 396 |
+
@classmethod
|
| 397 |
+
def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs):
|
| 398 |
+
subfolder = kwargs.pop("subfolder", None)
|
| 399 |
+
model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path
|
| 400 |
+
with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f:
|
| 401 |
+
config = json.load(f)
|
| 402 |
+
init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS}
|
| 403 |
+
model = cls(**init_kwargs)
|
| 404 |
+
weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors")
|
| 405 |
+
if os.path.isfile(weight_file):
|
| 406 |
+
from safetensors.torch import load_file
|
| 407 |
+
|
| 408 |
+
state_dict = load_file(weight_file)
|
| 409 |
+
else:
|
| 410 |
+
state_dict = torch.load(os.path.join(model_dir, "diffusion_pytorch_model.bin"), map_location="cpu")
|
| 411 |
+
model.load_state_dict(state_dict, strict=True)
|
| 412 |
+
if torch_dtype is not None:
|
| 413 |
+
model = model.to(dtype=torch_dtype)
|
| 414 |
+
return model
|
| 415 |
+
|
| 416 |
+
def save_pretrained(self, save_directory: str, safe_serialization: bool = True):
|
| 417 |
+
os.makedirs(save_directory, exist_ok=True)
|
| 418 |
+
stored = dict(getattr(self, "config", {}))
|
| 419 |
+
config = {"_class_name": self.__class__.__name__}
|
| 420 |
+
for key in _CONFIG_KEYS:
|
| 421 |
+
if key in stored:
|
| 422 |
+
config[key] = stored[key]
|
| 423 |
+
elif hasattr(self, key):
|
| 424 |
+
config[key] = getattr(self, key)
|
| 425 |
+
with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f:
|
| 426 |
+
json.dump(config, f, indent=2, sort_keys=True)
|
| 427 |
+
f.write("\n")
|
| 428 |
+
state_dict = self.state_dict()
|
| 429 |
+
if safe_serialization:
|
| 430 |
+
from safetensors.torch import save_file
|
| 431 |
+
|
| 432 |
+
save_file(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.safetensors"))
|
| 433 |
+
else:
|
| 434 |
+
torch.save(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.bin"))
|
edm2-img512-s-fid/vae/config.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKL",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"_name_or_path": "stabilityai/sd-vae-ft-mse",
|
| 5 |
+
"act_fn": "silu",
|
| 6 |
+
"block_out_channels": [
|
| 7 |
+
128,
|
| 8 |
+
256,
|
| 9 |
+
512,
|
| 10 |
+
512
|
| 11 |
+
],
|
| 12 |
+
"down_block_types": [
|
| 13 |
+
"DownEncoderBlock2D",
|
| 14 |
+
"DownEncoderBlock2D",
|
| 15 |
+
"DownEncoderBlock2D",
|
| 16 |
+
"DownEncoderBlock2D"
|
| 17 |
+
],
|
| 18 |
+
"force_upcast": true,
|
| 19 |
+
"in_channels": 3,
|
| 20 |
+
"latent_channels": 4,
|
| 21 |
+
"latents_mean": null,
|
| 22 |
+
"latents_std": null,
|
| 23 |
+
"layers_per_block": 2,
|
| 24 |
+
"mid_block_add_attention": true,
|
| 25 |
+
"norm_num_groups": 32,
|
| 26 |
+
"out_channels": 3,
|
| 27 |
+
"sample_size": 256,
|
| 28 |
+
"scaling_factor": 0.18215,
|
| 29 |
+
"shift_factor": null,
|
| 30 |
+
"up_block_types": [
|
| 31 |
+
"UpDecoderBlock2D",
|
| 32 |
+
"UpDecoderBlock2D",
|
| 33 |
+
"UpDecoderBlock2D",
|
| 34 |
+
"UpDecoderBlock2D"
|
| 35 |
+
],
|
| 36 |
+
"use_post_quant_conv": true,
|
| 37 |
+
"use_quant_conv": true
|
| 38 |
+
}
|
edm2-img512-s-fid/vae/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815
|
| 3 |
+
size 334643276
|
edm2-img512-xl-fid/demo.png
ADDED
|
Git LFS Details
|
edm2-img512-xl-fid/model_index.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": [
|
| 3 |
+
"pipeline",
|
| 4 |
+
"EDM2Pipeline"
|
| 5 |
+
],
|
| 6 |
+
"_diffusers_version": "0.31.0",
|
| 7 |
+
"scheduler": [
|
| 8 |
+
"diffusers",
|
| 9 |
+
"EDMEulerScheduler"
|
| 10 |
+
],
|
| 11 |
+
"unet": [
|
| 12 |
+
"unet_edm2",
|
| 13 |
+
"EDM2UNet2DModel"
|
| 14 |
+
],
|
| 15 |
+
"vae": [
|
| 16 |
+
"diffusers",
|
| 17 |
+
"AutoencoderKL"
|
| 18 |
+
]
|
| 19 |
+
}
|
edm2-img512-xl-fid/pipeline.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hub custom pipeline: EDM2Pipeline.
|
| 2 |
+
Load with native Hugging Face diffusers and trust_remote_code=True.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
|
| 21 |
+
import json
|
| 22 |
+
import os
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 30 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 31 |
+
from diffusers.utils import replace_example_docstring
|
| 32 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 33 |
+
|
| 34 |
+
EXAMPLE_DOC_STRING = """
|
| 35 |
+
Examples:
|
| 36 |
+
```py
|
| 37 |
+
>>> from pathlib import Path
|
| 38 |
+
>>> import torch
|
| 39 |
+
>>> from diffusers import DiffusionPipeline
|
| 40 |
+
|
| 41 |
+
>>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve()
|
| 42 |
+
>>> pipe = DiffusionPipeline.from_pretrained(
|
| 43 |
+
... str(model_dir),
|
| 44 |
+
... local_files_only=True,
|
| 45 |
+
... custom_pipeline=str(model_dir / "pipeline.py"),
|
| 46 |
+
... trust_remote_code=True,
|
| 47 |
+
... torch_dtype=torch.float32,
|
| 48 |
+
... )
|
| 49 |
+
>>> pipe.to("cuda")
|
| 50 |
+
|
| 51 |
+
>>> generator = torch.Generator(device="cuda").manual_seed(42)
|
| 52 |
+
>>> image = pipe(
|
| 53 |
+
... class_labels=207,
|
| 54 |
+
... num_inference_steps=32,
|
| 55 |
+
... guidance_scale=1.0,
|
| 56 |
+
... generator=generator,
|
| 57 |
+
... ).images[0]
|
| 58 |
+
>>> image.save("demo.png")
|
| 59 |
+
```
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
# Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py).
|
| 63 |
+
_STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28])
|
| 64 |
+
_STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE
|
| 65 |
+
|
| 66 |
+
class EDM2Pipeline(DiffusionPipeline):
|
| 67 |
+
r"""
|
| 68 |
+
Pipeline for class-conditional image generation with EDM2
|
| 69 |
+
([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)).
|
| 70 |
+
|
| 71 |
+
Parameters:
|
| 72 |
+
unet ([`EDM2UNet2DModel`]):
|
| 73 |
+
Main magnitude-preserving U-Net with EDM preconditioning.
|
| 74 |
+
scheduler ([`EDMEulerScheduler`]):
|
| 75 |
+
Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in
|
| 76 |
+
the pipeline because the UNet returns denoised latents rather than noise predictions.
|
| 77 |
+
vae ([`AutoencoderKL`], *optional*):
|
| 78 |
+
Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`.
|
| 79 |
+
gnet ([`EDM2UNet2DModel`], *optional*):
|
| 80 |
+
Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`).
|
| 81 |
+
id2label (`dict[int, str]`, *optional*):
|
| 82 |
+
ImageNet class id to English label mapping.
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
model_cpu_offload_seq = "unet->gnet->vae"
|
| 86 |
+
_optional_components = ["vae", "gnet"]
|
| 87 |
+
|
| 88 |
+
def __init__(
|
| 89 |
+
self,
|
| 90 |
+
unet,
|
| 91 |
+
scheduler,
|
| 92 |
+
vae=None,
|
| 93 |
+
gnet=None,
|
| 94 |
+
id2label: Optional[Dict[Union[int, str], str]] = None,
|
| 95 |
+
) -> None:
|
| 96 |
+
super().__init__()
|
| 97 |
+
self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet)
|
| 98 |
+
self._id2label = self._normalize_id2label(id2label)
|
| 99 |
+
self.labels = self._build_label2id(self._id2label)
|
| 100 |
+
self._labels_loaded_from_model_index = bool(self._id2label)
|
| 101 |
+
self.vae_scale_factor = 8 if self.vae is not None else 1
|
| 102 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
|
| 103 |
+
|
| 104 |
+
@staticmethod
|
| 105 |
+
def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
|
| 106 |
+
if not id2label:
|
| 107 |
+
return {}
|
| 108 |
+
return {int(key): value for key, value in id2label.items()}
|
| 109 |
+
|
| 110 |
+
@staticmethod
|
| 111 |
+
def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
|
| 112 |
+
label2id: Dict[str, int] = {}
|
| 113 |
+
for class_id, value in id2label.items():
|
| 114 |
+
for synonym in value.split(","):
|
| 115 |
+
synonym = synonym.strip()
|
| 116 |
+
if synonym:
|
| 117 |
+
label2id[synonym] = int(class_id)
|
| 118 |
+
return dict(sorted(label2id.items()))
|
| 119 |
+
|
| 120 |
+
def _ensure_labels_loaded(self) -> None:
|
| 121 |
+
if self._labels_loaded_from_model_index:
|
| 122 |
+
return
|
| 123 |
+
loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
|
| 124 |
+
if loaded:
|
| 125 |
+
self._id2label = loaded
|
| 126 |
+
self.labels = self._build_label2id(self._id2label)
|
| 127 |
+
self._labels_loaded_from_model_index = True
|
| 128 |
+
|
| 129 |
+
@staticmethod
|
| 130 |
+
def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
|
| 131 |
+
if not variant_path:
|
| 132 |
+
return {}
|
| 133 |
+
model_index_path = Path(variant_path).resolve() / "model_index.json"
|
| 134 |
+
if not model_index_path.is_file():
|
| 135 |
+
return {}
|
| 136 |
+
raw = json.loads(model_index_path.read_text(encoding="utf-8"))
|
| 137 |
+
id2label = raw.get("id2label")
|
| 138 |
+
if not isinstance(id2label, dict):
|
| 139 |
+
return {}
|
| 140 |
+
return {int(key): value for key, value in id2label.items()}
|
| 141 |
+
|
| 142 |
+
@property
|
| 143 |
+
def id2label(self) -> Dict[int, str]:
|
| 144 |
+
r"""ImageNet class id to English label string (comma-separated synonyms)."""
|
| 145 |
+
self._ensure_labels_loaded()
|
| 146 |
+
return self._id2label
|
| 147 |
+
|
| 148 |
+
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
|
| 149 |
+
r"""
|
| 150 |
+
Map ImageNet label strings to class ids.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
label (`str` or `list[str]`):
|
| 154 |
+
One or more English label strings that match entries in `id2label`.
|
| 155 |
+
"""
|
| 156 |
+
self._ensure_labels_loaded()
|
| 157 |
+
if not self.labels:
|
| 158 |
+
raise ValueError("No English labels loaded. Add `id2label` to model_index.json.")
|
| 159 |
+
labels = [label] if isinstance(label, str) else list(label)
|
| 160 |
+
missing = [item for item in labels if item not in self.labels]
|
| 161 |
+
if missing:
|
| 162 |
+
preview = ", ".join(list(self.labels.keys())[:8])
|
| 163 |
+
raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
|
| 164 |
+
return [self.labels[item] for item in labels]
|
| 165 |
+
|
| 166 |
+
def _default_image_size(self) -> int:
|
| 167 |
+
latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64)))
|
| 168 |
+
return latent_size * self.vae_scale_factor
|
| 169 |
+
|
| 170 |
+
def check_inputs(
|
| 171 |
+
self,
|
| 172 |
+
height: int,
|
| 173 |
+
width: int,
|
| 174 |
+
num_inference_steps: int,
|
| 175 |
+
guidance_scale: float,
|
| 176 |
+
output_type: str,
|
| 177 |
+
) -> None:
|
| 178 |
+
if num_inference_steps < 1:
|
| 179 |
+
raise ValueError("num_inference_steps must be >= 1.")
|
| 180 |
+
if guidance_scale < 1.0:
|
| 181 |
+
raise ValueError("guidance_scale must be >= 1.0.")
|
| 182 |
+
if guidance_scale > 1.0 and self.gnet is None:
|
| 183 |
+
raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).")
|
| 184 |
+
if output_type not in {"pil", "np", "pt", "latent"}:
|
| 185 |
+
raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
|
| 186 |
+
|
| 187 |
+
native_size = self._default_image_size()
|
| 188 |
+
if height != native_size or width != native_size:
|
| 189 |
+
raise ValueError(
|
| 190 |
+
f"EDM2 expects native resolution height=width={native_size}. "
|
| 191 |
+
f"Got height={height}, width={width}."
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
def _normalize_class_labels(
|
| 195 |
+
self,
|
| 196 |
+
class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]],
|
| 197 |
+
batch_size: int,
|
| 198 |
+
device: torch.device,
|
| 199 |
+
) -> Optional[torch.Tensor]:
|
| 200 |
+
label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0)))
|
| 201 |
+
if label_dim == 0:
|
| 202 |
+
return None
|
| 203 |
+
if class_labels is None:
|
| 204 |
+
indices = torch.randint(label_dim, size=(batch_size,), device=device)
|
| 205 |
+
return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
|
| 206 |
+
|
| 207 |
+
if isinstance(class_labels, str):
|
| 208 |
+
class_labels = self.get_label_ids(class_labels)[0]
|
| 209 |
+
elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str):
|
| 210 |
+
class_labels = self.get_label_ids(list(class_labels))
|
| 211 |
+
|
| 212 |
+
if isinstance(class_labels, int):
|
| 213 |
+
indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long)
|
| 214 |
+
elif isinstance(class_labels, torch.Tensor):
|
| 215 |
+
if class_labels.ndim == 2:
|
| 216 |
+
labels = class_labels.to(device=device, dtype=torch.float32)
|
| 217 |
+
if labels.shape[0] != batch_size:
|
| 218 |
+
raise ValueError(f"class_labels batch must match batch_size={batch_size}.")
|
| 219 |
+
return labels
|
| 220 |
+
indices = class_labels.to(device=device, dtype=torch.long).flatten()
|
| 221 |
+
else:
|
| 222 |
+
indices = torch.tensor(list(class_labels), device=device, dtype=torch.long)
|
| 223 |
+
|
| 224 |
+
if indices.numel() == 1 and batch_size > 1:
|
| 225 |
+
indices = indices.repeat(batch_size)
|
| 226 |
+
if indices.numel() != batch_size:
|
| 227 |
+
raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.")
|
| 228 |
+
return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
|
| 229 |
+
|
| 230 |
+
def prepare_latents(
|
| 231 |
+
self,
|
| 232 |
+
batch_size: int,
|
| 233 |
+
height: int,
|
| 234 |
+
width: int,
|
| 235 |
+
dtype: torch.dtype,
|
| 236 |
+
device: torch.device,
|
| 237 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]],
|
| 238 |
+
) -> torch.Tensor:
|
| 239 |
+
in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4)))
|
| 240 |
+
latent_size = height // self.vae_scale_factor
|
| 241 |
+
return randn_tensor(
|
| 242 |
+
(batch_size, in_channels, latent_size, latent_size),
|
| 243 |
+
generator=generator,
|
| 244 |
+
device=device,
|
| 245 |
+
dtype=torch.float32,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
|
| 249 |
+
if output_type == "latent":
|
| 250 |
+
return latents
|
| 251 |
+
|
| 252 |
+
in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 3)))
|
| 253 |
+
if self.vae is None:
|
| 254 |
+
image = (latents.to(torch.float32) * 127.5 + 128).clip(0, 255) / 255.0
|
| 255 |
+
return self.image_processor.postprocess(image, output_type=output_type)
|
| 256 |
+
|
| 257 |
+
if in_channels == 4:
|
| 258 |
+
x = latents.to(torch.float32)
|
| 259 |
+
scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
|
| 260 |
+
bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
|
| 261 |
+
x = (x - bias) / scale
|
| 262 |
+
else:
|
| 263 |
+
x = latents.to(torch.float32)
|
| 264 |
+
|
| 265 |
+
vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype
|
| 266 |
+
image = self.vae.decode(x.to(dtype=vae_dtype)).sample.to(torch.float32).clamp(0, 1)
|
| 267 |
+
|
| 268 |
+
return self.image_processor.postprocess(image, output_type=output_type)
|
| 269 |
+
|
| 270 |
+
@staticmethod
|
| 271 |
+
def _apply_autoguidance(
|
| 272 |
+
main: torch.Tensor,
|
| 273 |
+
ref: torch.Tensor,
|
| 274 |
+
guidance_scale: float,
|
| 275 |
+
) -> torch.Tensor:
|
| 276 |
+
return ref.lerp(main, guidance_scale)
|
| 277 |
+
|
| 278 |
+
@staticmethod
|
| 279 |
+
def _sample_edm2_heun(
|
| 280 |
+
denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
| 281 |
+
noise: torch.Tensor,
|
| 282 |
+
sigmas: torch.Tensor,
|
| 283 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 284 |
+
progress_bar: Optional[Callable[[Iterable], Iterable]] = None,
|
| 285 |
+
dtype: torch.dtype = torch.float32,
|
| 286 |
+
) -> torch.Tensor:
|
| 287 |
+
"""NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0)."""
|
| 288 |
+
x_next = noise.to(dtype) * sigmas[0]
|
| 289 |
+
|
| 290 |
+
sigma_pairs = list(zip(sigmas[:-1], sigmas[1:]))
|
| 291 |
+
if progress_bar is not None:
|
| 292 |
+
sigma_pairs = progress_bar(sigma_pairs)
|
| 293 |
+
|
| 294 |
+
num_steps = len(sigma_pairs)
|
| 295 |
+
for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs):
|
| 296 |
+
x_hat, sigma_hat = x_next, sigma_cur
|
| 297 |
+
d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat
|
| 298 |
+
x_next = x_hat + (sigma_next - sigma_hat) * d_cur
|
| 299 |
+
if i < num_steps - 1:
|
| 300 |
+
d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next
|
| 301 |
+
x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime)
|
| 302 |
+
return x_next
|
| 303 |
+
|
| 304 |
+
@torch.inference_mode()
|
| 305 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 306 |
+
def __call__(
|
| 307 |
+
self,
|
| 308 |
+
class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None,
|
| 309 |
+
batch_size: int = 1,
|
| 310 |
+
height: Optional[int] = None,
|
| 311 |
+
width: Optional[int] = None,
|
| 312 |
+
num_inference_steps: int = 32,
|
| 313 |
+
guidance_scale: float = 1.0,
|
| 314 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 315 |
+
output_type: str = "pil",
|
| 316 |
+
return_dict: bool = True,
|
| 317 |
+
) -> Union[ImagePipelineOutput, Tuple]:
|
| 318 |
+
r"""
|
| 319 |
+
Generate class-conditional images with EDM2.
|
| 320 |
+
|
| 321 |
+
Args:
|
| 322 |
+
class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*):
|
| 323 |
+
ImageNet class indices, English label strings, or one-hot float tensors.
|
| 324 |
+
Random classes are sampled when omitted on conditional models.
|
| 325 |
+
batch_size (`int`, defaults to `1`):
|
| 326 |
+
Number of images to generate.
|
| 327 |
+
height (`int`, *optional*):
|
| 328 |
+
Output height in pixels. Defaults to the pretrained native resolution.
|
| 329 |
+
width (`int`, *optional*):
|
| 330 |
+
Output width in pixels. Defaults to the pretrained native resolution.
|
| 331 |
+
num_inference_steps (`int`, defaults to `32`):
|
| 332 |
+
Number of EDM2 Heun steps (NVlabs default).
|
| 333 |
+
guidance_scale (`float`, defaults to `1.0`):
|
| 334 |
+
Autoguidance strength. Values above `1.0` blend the main net with `gnet`
|
| 335 |
+
via `gnet_output.lerp(unet_output, guidance_scale)`.
|
| 336 |
+
generator (`torch.Generator`, *optional*):
|
| 337 |
+
RNG for reproducibility.
|
| 338 |
+
output_type (`str`, defaults to `"pil"`):
|
| 339 |
+
`"pil"`, `"np"`, `"pt"`, or `"latent"`.
|
| 340 |
+
return_dict (`bool`, defaults to `True`):
|
| 341 |
+
Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if True.
|
| 342 |
+
|
| 343 |
+
Examples:
|
| 344 |
+
<!-- this section is replaced by replace_example_docstring -->
|
| 345 |
+
"""
|
| 346 |
+
default_size = self._default_image_size()
|
| 347 |
+
height = int(height or default_size)
|
| 348 |
+
width = int(width or default_size)
|
| 349 |
+
self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type)
|
| 350 |
+
|
| 351 |
+
device = self._execution_device
|
| 352 |
+
dtype = self.unet.dtype
|
| 353 |
+
labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
|
| 354 |
+
noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
|
| 355 |
+
|
| 356 |
+
def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
|
| 357 |
+
sigma_batch = sigma.reshape(1).expand(batch_size)
|
| 358 |
+
main = self.unet(
|
| 359 |
+
sample=x,
|
| 360 |
+
sigma=sigma_batch,
|
| 361 |
+
class_labels=labels,
|
| 362 |
+
force_fp32=True,
|
| 363 |
+
).sample
|
| 364 |
+
if guidance_scale == 1.0 or self.gnet is None:
|
| 365 |
+
return main.to(torch.float32)
|
| 366 |
+
ref = self.gnet(
|
| 367 |
+
sample=x,
|
| 368 |
+
sigma=sigma_batch,
|
| 369 |
+
class_labels=labels,
|
| 370 |
+
force_fp32=True,
|
| 371 |
+
).sample
|
| 372 |
+
return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32)
|
| 373 |
+
|
| 374 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 375 |
+
latents = self._sample_edm2_heun(
|
| 376 |
+
denoise_fn=denoise_fn,
|
| 377 |
+
noise=noise,
|
| 378 |
+
sigmas=self.scheduler.sigmas.to(device),
|
| 379 |
+
generator=generator,
|
| 380 |
+
progress_bar=self.progress_bar,
|
| 381 |
+
dtype=torch.float32,
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
image = self.decode_latents(latents, output_type=output_type)
|
| 385 |
+
if not return_dict:
|
| 386 |
+
return (image, latents)
|
| 387 |
+
return ImagePipelineOutput(images=image)
|
| 388 |
+
|
| 389 |
+
@classmethod
|
| 390 |
+
def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None):
|
| 391 |
+
vae_dir = os.path.join(pretrained_model_name_or_path, "vae")
|
| 392 |
+
if os.path.isdir(vae_dir):
|
| 393 |
+
try:
|
| 394 |
+
|
| 395 |
+
return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype)
|
| 396 |
+
except Exception:
|
| 397 |
+
return None
|
| 398 |
+
|
| 399 |
+
vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt")
|
| 400 |
+
if os.path.isfile(vae_hint):
|
| 401 |
+
with open(vae_hint, "r", encoding="utf-8") as f:
|
| 402 |
+
hub_id = f.read().strip()
|
| 403 |
+
if hub_id:
|
| 404 |
+
|
| 405 |
+
return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype)
|
| 406 |
+
return None
|
edm2-img512-xl-fid/scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "EDMEulerScheduler",
|
| 3 |
+
"final_sigmas_type": "zero",
|
| 4 |
+
"num_train_timesteps": 1000,
|
| 5 |
+
"prediction_type": "epsilon",
|
| 6 |
+
"rho": 7.0,
|
| 7 |
+
"sigma_data": 0.5,
|
| 8 |
+
"sigma_max": 80.0,
|
| 9 |
+
"sigma_min": 0.002,
|
| 10 |
+
"sigma_schedule": "karras"
|
| 11 |
+
}
|
edm2-img512-xl-fid/unet/config.json
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "EDM2UNet2DModel",
|
| 3 |
+
"attn_balance": 0.3,
|
| 4 |
+
"attn_resolutions": [
|
| 5 |
+
16,
|
| 6 |
+
8
|
| 7 |
+
],
|
| 8 |
+
"channel_mult": [
|
| 9 |
+
1,
|
| 10 |
+
2,
|
| 11 |
+
3,
|
| 12 |
+
4
|
| 13 |
+
],
|
| 14 |
+
"channel_mult_emb": 4,
|
| 15 |
+
"channel_mult_noise": 1,
|
| 16 |
+
"channels_per_head": 64,
|
| 17 |
+
"clip_act": 256,
|
| 18 |
+
"concat_balance": 0.5,
|
| 19 |
+
"dropout": 0.0,
|
| 20 |
+
"in_channels": 4,
|
| 21 |
+
"label_balance": 0.5,
|
| 22 |
+
"logvar_channels": 128,
|
| 23 |
+
"model_channels": 384,
|
| 24 |
+
"num_blocks": 3,
|
| 25 |
+
"num_class_embeds": 1000,
|
| 26 |
+
"out_channels": 4,
|
| 27 |
+
"res_balance": 0.3,
|
| 28 |
+
"sample_size": 64,
|
| 29 |
+
"sigma_data": 0.5,
|
| 30 |
+
"use_fp16": true
|
| 31 |
+
}
|
edm2-img512-xl-fid/unet/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3c7402d8a4e91781b5c94fa2a5beee5820970ad99d2249141e191364885f222a
|
| 3 |
+
size 4477161892
|
edm2-img512-xl-fid/unet/unet_edm2.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 12 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 13 |
+
from diffusers.utils import BaseOutput
|
| 14 |
+
except ImportError: # pragma: no cover
|
| 15 |
+
class ModelMixin(torch.nn.Module):
|
| 16 |
+
pass
|
| 17 |
+
|
| 18 |
+
class ConfigMixin:
|
| 19 |
+
config = {}
|
| 20 |
+
|
| 21 |
+
def register_to_config(self, **kwargs):
|
| 22 |
+
self.config = kwargs
|
| 23 |
+
|
| 24 |
+
def register_to_config(func):
|
| 25 |
+
return func
|
| 26 |
+
|
| 27 |
+
@dataclass
|
| 28 |
+
class BaseOutput:
|
| 29 |
+
pass
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor:
|
| 33 |
+
if dim is None:
|
| 34 |
+
dim = list(range(1, x.ndim))
|
| 35 |
+
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
|
| 36 |
+
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
|
| 37 |
+
return x / norm.to(x.dtype)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor:
|
| 41 |
+
if mode == "keep":
|
| 42 |
+
return x
|
| 43 |
+
filt = np.float32(f)
|
| 44 |
+
pad = (len(filt) - 1) // 2
|
| 45 |
+
filt = filt / filt.sum()
|
| 46 |
+
filt = np.outer(filt, filt)[np.newaxis, np.newaxis, :, :]
|
| 47 |
+
filt = torch.as_tensor(filt, dtype=x.dtype, device=x.device)
|
| 48 |
+
c = x.shape[1]
|
| 49 |
+
if mode == "down":
|
| 50 |
+
return torch.nn.functional.conv2d(x, filt.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
|
| 51 |
+
return torch.nn.functional.conv_transpose2d(x, (filt * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def mp_silu(x: torch.Tensor) -> torch.Tensor:
|
| 55 |
+
return torch.nn.functional.silu(x) / 0.596
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor:
|
| 59 |
+
return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor:
|
| 63 |
+
na = a.shape[dim]
|
| 64 |
+
nb = b.shape[dim]
|
| 65 |
+
c = math.sqrt((na + nb) / ((1 - t) ** 2 + t ** 2))
|
| 66 |
+
wa = c / math.sqrt(na) * (1 - t)
|
| 67 |
+
wb = c / math.sqrt(nb) * t
|
| 68 |
+
return torch.cat([wa * a, wb * b], dim=dim)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class MPFourier(torch.nn.Module):
|
| 72 |
+
def __init__(self, num_channels: int, bandwidth: float = 1):
|
| 73 |
+
super().__init__()
|
| 74 |
+
self.register_buffer("freqs", 2 * math.pi * torch.randn(num_channels) * bandwidth)
|
| 75 |
+
self.register_buffer("phases", 2 * math.pi * torch.rand(num_channels))
|
| 76 |
+
|
| 77 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 78 |
+
y = x.to(torch.float32).ger(self.freqs.to(torch.float32))
|
| 79 |
+
y = y + self.phases.to(torch.float32)
|
| 80 |
+
y = y.cos() * math.sqrt(2)
|
| 81 |
+
return y.to(x.dtype)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class MPConv(torch.nn.Module):
|
| 85 |
+
def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]):
|
| 86 |
+
super().__init__()
|
| 87 |
+
self.out_channels = out_channels
|
| 88 |
+
self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))
|
| 89 |
+
|
| 90 |
+
def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor:
|
| 91 |
+
w = self.weight.to(torch.float32)
|
| 92 |
+
if self.training:
|
| 93 |
+
with torch.no_grad():
|
| 94 |
+
self.weight.copy_(normalize(w))
|
| 95 |
+
w = normalize(w)
|
| 96 |
+
w = w * (gain / math.sqrt(w[0].numel()))
|
| 97 |
+
w = w.to(x.dtype)
|
| 98 |
+
if w.ndim == 2:
|
| 99 |
+
return x @ w.t()
|
| 100 |
+
return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1] // 2,))
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class Block(torch.nn.Module):
|
| 104 |
+
def __init__(
|
| 105 |
+
self,
|
| 106 |
+
in_channels: int,
|
| 107 |
+
out_channels: int,
|
| 108 |
+
emb_channels: int,
|
| 109 |
+
flavor: str = "enc",
|
| 110 |
+
resample_mode: str = "keep",
|
| 111 |
+
resample_filter: List[float] = [1, 1],
|
| 112 |
+
attention: bool = False,
|
| 113 |
+
channels_per_head: int = 64,
|
| 114 |
+
dropout: float = 0.0,
|
| 115 |
+
res_balance: float = 0.3,
|
| 116 |
+
attn_balance: float = 0.3,
|
| 117 |
+
clip_act: Optional[float] = 256,
|
| 118 |
+
):
|
| 119 |
+
super().__init__()
|
| 120 |
+
self.out_channels = out_channels
|
| 121 |
+
self.flavor = flavor
|
| 122 |
+
self.resample_filter = resample_filter
|
| 123 |
+
self.resample_mode = resample_mode
|
| 124 |
+
self.num_heads = out_channels // channels_per_head if attention else 0
|
| 125 |
+
self.dropout = dropout
|
| 126 |
+
self.res_balance = res_balance
|
| 127 |
+
self.attn_balance = attn_balance
|
| 128 |
+
self.clip_act = clip_act
|
| 129 |
+
self.emb_gain = torch.nn.Parameter(torch.zeros([]))
|
| 130 |
+
self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3))
|
| 131 |
+
self.emb_linear = MPConv(emb_channels, out_channels, kernel=())
|
| 132 |
+
self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3))
|
| 133 |
+
self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None
|
| 134 |
+
self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None
|
| 135 |
+
self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None
|
| 136 |
+
|
| 137 |
+
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
|
| 138 |
+
x = resample(x, f=self.resample_filter, mode=self.resample_mode)
|
| 139 |
+
if self.flavor == "enc":
|
| 140 |
+
if self.conv_skip is not None:
|
| 141 |
+
x = self.conv_skip(x)
|
| 142 |
+
x = normalize(x, dim=[1])
|
| 143 |
+
|
| 144 |
+
y = self.conv_res0(mp_silu(x))
|
| 145 |
+
c = self.emb_linear(emb, gain=self.emb_gain) + 1
|
| 146 |
+
y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype))
|
| 147 |
+
if self.training and self.dropout:
|
| 148 |
+
y = torch.nn.functional.dropout(y, p=self.dropout)
|
| 149 |
+
y = self.conv_res1(y)
|
| 150 |
+
|
| 151 |
+
if self.flavor == "dec" and self.conv_skip is not None:
|
| 152 |
+
x = self.conv_skip(x)
|
| 153 |
+
x = mp_sum(x, y, t=self.res_balance)
|
| 154 |
+
|
| 155 |
+
if self.num_heads:
|
| 156 |
+
y = self.attn_qkv(x)
|
| 157 |
+
y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3])
|
| 158 |
+
q, k, v = normalize(y, dim=[2]).unbind(3)
|
| 159 |
+
w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3)
|
| 160 |
+
y = torch.einsum("nhqk,nhck->nhcq", w, v)
|
| 161 |
+
y = self.attn_proj(y.reshape(*x.shape))
|
| 162 |
+
x = mp_sum(x, y, t=self.attn_balance)
|
| 163 |
+
|
| 164 |
+
if self.clip_act is not None:
|
| 165 |
+
x = x.clip_(-self.clip_act, self.clip_act)
|
| 166 |
+
return x
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class EDM2UNet(torch.nn.Module):
|
| 170 |
+
def __init__(
|
| 171 |
+
self,
|
| 172 |
+
img_resolution: int,
|
| 173 |
+
img_channels: int,
|
| 174 |
+
label_dim: int,
|
| 175 |
+
model_channels: int = 192,
|
| 176 |
+
channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
|
| 177 |
+
channel_mult_noise: Optional[int] = None,
|
| 178 |
+
channel_mult_emb: Optional[int] = None,
|
| 179 |
+
num_blocks: int = 3,
|
| 180 |
+
attn_resolutions: Tuple[int, ...] = (16, 8),
|
| 181 |
+
label_balance: float = 0.5,
|
| 182 |
+
concat_balance: float = 0.5,
|
| 183 |
+
**block_kwargs,
|
| 184 |
+
):
|
| 185 |
+
super().__init__()
|
| 186 |
+
cblock = [model_channels * x for x in channel_mult]
|
| 187 |
+
cnoise = model_channels * channel_mult_noise if channel_mult_noise is not None else cblock[0]
|
| 188 |
+
cemb = model_channels * channel_mult_emb if channel_mult_emb is not None else max(cblock)
|
| 189 |
+
self.label_balance = label_balance
|
| 190 |
+
self.concat_balance = concat_balance
|
| 191 |
+
self.out_gain = torch.nn.Parameter(torch.zeros([]))
|
| 192 |
+
|
| 193 |
+
self.emb_fourier = MPFourier(cnoise)
|
| 194 |
+
self.emb_noise = MPConv(cnoise, cemb, kernel=())
|
| 195 |
+
self.emb_label = MPConv(label_dim, cemb, kernel=()) if label_dim else None
|
| 196 |
+
|
| 197 |
+
self.enc = torch.nn.ModuleDict()
|
| 198 |
+
cout = img_channels + 1
|
| 199 |
+
for level, channels in enumerate(cblock):
|
| 200 |
+
res = img_resolution >> level
|
| 201 |
+
if level == 0:
|
| 202 |
+
cin = cout
|
| 203 |
+
cout = channels
|
| 204 |
+
self.enc[f"{res}x{res}_conv"] = MPConv(cin, cout, kernel=(3, 3))
|
| 205 |
+
else:
|
| 206 |
+
self.enc[f"{res}x{res}_down"] = Block(cout, cout, cemb, flavor="enc", resample_mode="down", **block_kwargs)
|
| 207 |
+
for idx in range(num_blocks):
|
| 208 |
+
cin = cout
|
| 209 |
+
cout = channels
|
| 210 |
+
self.enc[f"{res}x{res}_block{idx}"] = Block(
|
| 211 |
+
cin,
|
| 212 |
+
cout,
|
| 213 |
+
cemb,
|
| 214 |
+
flavor="enc",
|
| 215 |
+
attention=(res in attn_resolutions),
|
| 216 |
+
**block_kwargs,
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
self.dec = torch.nn.ModuleDict()
|
| 220 |
+
skips = [block.out_channels for block in self.enc.values()]
|
| 221 |
+
for level, channels in reversed(list(enumerate(cblock))):
|
| 222 |
+
res = img_resolution >> level
|
| 223 |
+
if level == len(cblock) - 1:
|
| 224 |
+
self.dec[f"{res}x{res}_in0"] = Block(cout, cout, cemb, flavor="dec", attention=True, **block_kwargs)
|
| 225 |
+
self.dec[f"{res}x{res}_in1"] = Block(cout, cout, cemb, flavor="dec", **block_kwargs)
|
| 226 |
+
else:
|
| 227 |
+
self.dec[f"{res}x{res}_up"] = Block(cout, cout, cemb, flavor="dec", resample_mode="up", **block_kwargs)
|
| 228 |
+
for idx in range(num_blocks + 1):
|
| 229 |
+
cin = cout + skips.pop()
|
| 230 |
+
cout = channels
|
| 231 |
+
self.dec[f"{res}x{res}_block{idx}"] = Block(
|
| 232 |
+
cin,
|
| 233 |
+
cout,
|
| 234 |
+
cemb,
|
| 235 |
+
flavor="dec",
|
| 236 |
+
attention=(res in attn_resolutions),
|
| 237 |
+
**block_kwargs,
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
self.out_conv = MPConv(cout, img_channels, kernel=(3, 3))
|
| 241 |
+
|
| 242 |
+
def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor:
|
| 243 |
+
emb = self.emb_noise(self.emb_fourier(noise_labels))
|
| 244 |
+
if self.emb_label is not None:
|
| 245 |
+
if class_labels is None:
|
| 246 |
+
raise ValueError("class_labels are required for conditional EDM2UNet.")
|
| 247 |
+
emb = mp_sum(emb, self.emb_label(class_labels * math.sqrt(class_labels.shape[1])), t=self.label_balance)
|
| 248 |
+
emb = mp_silu(emb)
|
| 249 |
+
|
| 250 |
+
x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1)
|
| 251 |
+
skips = []
|
| 252 |
+
for name, block in self.enc.items():
|
| 253 |
+
x = block(x) if "conv" in name else block(x, emb)
|
| 254 |
+
skips.append(x)
|
| 255 |
+
|
| 256 |
+
for name, block in self.dec.items():
|
| 257 |
+
if "block" in name:
|
| 258 |
+
x = mp_cat(x, skips.pop(), t=self.concat_balance)
|
| 259 |
+
x = block(x, emb)
|
| 260 |
+
return self.out_conv(x, gain=self.out_gain)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
@dataclass
|
| 264 |
+
class EDM2UNet2DOutput(BaseOutput):
|
| 265 |
+
sample: torch.Tensor
|
| 266 |
+
logvar: Optional[torch.Tensor] = None
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
_CONFIG_KEYS = (
|
| 271 |
+
"sample_size",
|
| 272 |
+
"in_channels",
|
| 273 |
+
"out_channels",
|
| 274 |
+
"num_class_embeds",
|
| 275 |
+
"use_fp16",
|
| 276 |
+
"sigma_data",
|
| 277 |
+
"logvar_channels",
|
| 278 |
+
"model_channels",
|
| 279 |
+
"channel_mult",
|
| 280 |
+
"channel_mult_noise",
|
| 281 |
+
"channel_mult_emb",
|
| 282 |
+
"num_blocks",
|
| 283 |
+
"attn_resolutions",
|
| 284 |
+
"label_balance",
|
| 285 |
+
"concat_balance",
|
| 286 |
+
"dropout",
|
| 287 |
+
"channels_per_head",
|
| 288 |
+
"res_balance",
|
| 289 |
+
"attn_balance",
|
| 290 |
+
"clip_act",
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class EDM2UNet2DModel(ModelMixin, ConfigMixin):
|
| 295 |
+
@register_to_config
|
| 296 |
+
def __init__(
|
| 297 |
+
self,
|
| 298 |
+
sample_size: int = 64,
|
| 299 |
+
in_channels: int = 4,
|
| 300 |
+
out_channels: int = 4,
|
| 301 |
+
num_class_embeds: int = 0,
|
| 302 |
+
use_fp16: bool = True,
|
| 303 |
+
sigma_data: float = 0.5,
|
| 304 |
+
logvar_channels: int = 128,
|
| 305 |
+
model_channels: int = 192,
|
| 306 |
+
channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
|
| 307 |
+
channel_mult_noise: Optional[int] = None,
|
| 308 |
+
channel_mult_emb: Optional[int] = None,
|
| 309 |
+
num_blocks: int = 3,
|
| 310 |
+
attn_resolutions: Tuple[int, ...] = (16, 8),
|
| 311 |
+
label_balance: float = 0.5,
|
| 312 |
+
concat_balance: float = 0.5,
|
| 313 |
+
dropout: float = 0.0,
|
| 314 |
+
channels_per_head: int = 64,
|
| 315 |
+
res_balance: float = 0.3,
|
| 316 |
+
attn_balance: float = 0.3,
|
| 317 |
+
clip_act: Optional[float] = 256,
|
| 318 |
+
):
|
| 319 |
+
super().__init__()
|
| 320 |
+
self.sample_size = sample_size
|
| 321 |
+
self.in_channels = in_channels
|
| 322 |
+
self.out_channels = out_channels
|
| 323 |
+
self.num_class_embeds = num_class_embeds
|
| 324 |
+
self.use_fp16 = use_fp16
|
| 325 |
+
self.sigma_data = sigma_data
|
| 326 |
+
self.model_channels = model_channels
|
| 327 |
+
self.channel_mult = channel_mult
|
| 328 |
+
self.channel_mult_noise = channel_mult_noise
|
| 329 |
+
self.channel_mult_emb = channel_mult_emb
|
| 330 |
+
self.num_blocks = num_blocks
|
| 331 |
+
self.attn_resolutions = attn_resolutions
|
| 332 |
+
self.label_balance = label_balance
|
| 333 |
+
self.concat_balance = concat_balance
|
| 334 |
+
self.dropout = dropout
|
| 335 |
+
self.channels_per_head = channels_per_head
|
| 336 |
+
self.res_balance = res_balance
|
| 337 |
+
self.attn_balance = attn_balance
|
| 338 |
+
self.clip_act = clip_act
|
| 339 |
+
self.unet = EDM2UNet(
|
| 340 |
+
img_resolution=sample_size,
|
| 341 |
+
img_channels=in_channels,
|
| 342 |
+
label_dim=num_class_embeds,
|
| 343 |
+
model_channels=model_channels,
|
| 344 |
+
channel_mult=channel_mult,
|
| 345 |
+
channel_mult_noise=channel_mult_noise,
|
| 346 |
+
channel_mult_emb=channel_mult_emb,
|
| 347 |
+
num_blocks=num_blocks,
|
| 348 |
+
attn_resolutions=attn_resolutions,
|
| 349 |
+
label_balance=label_balance,
|
| 350 |
+
concat_balance=concat_balance,
|
| 351 |
+
dropout=dropout,
|
| 352 |
+
channels_per_head=channels_per_head,
|
| 353 |
+
res_balance=res_balance,
|
| 354 |
+
attn_balance=attn_balance,
|
| 355 |
+
clip_act=clip_act,
|
| 356 |
+
)
|
| 357 |
+
self.logvar_fourier = MPFourier(logvar_channels)
|
| 358 |
+
self.logvar_linear = MPConv(logvar_channels, 1, kernel=())
|
| 359 |
+
|
| 360 |
+
def forward(
|
| 361 |
+
self,
|
| 362 |
+
sample: torch.Tensor,
|
| 363 |
+
sigma: torch.Tensor,
|
| 364 |
+
class_labels: Optional[torch.Tensor] = None,
|
| 365 |
+
force_fp32: bool = False,
|
| 366 |
+
return_logvar: bool = False,
|
| 367 |
+
return_dict: bool = True,
|
| 368 |
+
) -> EDM2UNet2DOutput:
|
| 369 |
+
x = sample.to(torch.float32)
|
| 370 |
+
sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
|
| 371 |
+
if self.num_class_embeds == 0:
|
| 372 |
+
class_labels = None
|
| 373 |
+
else:
|
| 374 |
+
if class_labels is None:
|
| 375 |
+
class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device)
|
| 376 |
+
class_labels = class_labels.to(torch.float32).reshape(-1, self.num_class_embeds)
|
| 377 |
+
dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") else torch.float32
|
| 378 |
+
|
| 379 |
+
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
|
| 380 |
+
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
|
| 381 |
+
c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
|
| 382 |
+
c_noise = sigma.flatten().log() / 4
|
| 383 |
+
|
| 384 |
+
x_in = (c_in * x).to(dtype)
|
| 385 |
+
f_x = self.unet(x_in, c_noise, class_labels)
|
| 386 |
+
d_x = c_skip * x + c_out * f_x.to(torch.float32)
|
| 387 |
+
|
| 388 |
+
logvar = None
|
| 389 |
+
if return_logvar:
|
| 390 |
+
logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)
|
| 391 |
+
|
| 392 |
+
if not return_dict:
|
| 393 |
+
return (d_x, logvar)
|
| 394 |
+
return EDM2UNet2DOutput(sample=d_x, logvar=logvar)
|
| 395 |
+
|
| 396 |
+
@classmethod
|
| 397 |
+
def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs):
|
| 398 |
+
subfolder = kwargs.pop("subfolder", None)
|
| 399 |
+
model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path
|
| 400 |
+
with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f:
|
| 401 |
+
config = json.load(f)
|
| 402 |
+
init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS}
|
| 403 |
+
model = cls(**init_kwargs)
|
| 404 |
+
weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors")
|
| 405 |
+
if os.path.isfile(weight_file):
|
| 406 |
+
from safetensors.torch import load_file
|
| 407 |
+
|
| 408 |
+
state_dict = load_file(weight_file)
|
| 409 |
+
else:
|
| 410 |
+
state_dict = torch.load(os.path.join(model_dir, "diffusion_pytorch_model.bin"), map_location="cpu")
|
| 411 |
+
model.load_state_dict(state_dict, strict=True)
|
| 412 |
+
if torch_dtype is not None:
|
| 413 |
+
model = model.to(dtype=torch_dtype)
|
| 414 |
+
return model
|
| 415 |
+
|
| 416 |
+
def save_pretrained(self, save_directory: str, safe_serialization: bool = True):
|
| 417 |
+
os.makedirs(save_directory, exist_ok=True)
|
| 418 |
+
stored = dict(getattr(self, "config", {}))
|
| 419 |
+
config = {"_class_name": self.__class__.__name__}
|
| 420 |
+
for key in _CONFIG_KEYS:
|
| 421 |
+
if key in stored:
|
| 422 |
+
config[key] = stored[key]
|
| 423 |
+
elif hasattr(self, key):
|
| 424 |
+
config[key] = getattr(self, key)
|
| 425 |
+
with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f:
|
| 426 |
+
json.dump(config, f, indent=2, sort_keys=True)
|
| 427 |
+
f.write("\n")
|
| 428 |
+
state_dict = self.state_dict()
|
| 429 |
+
if safe_serialization:
|
| 430 |
+
from safetensors.torch import save_file
|
| 431 |
+
|
| 432 |
+
save_file(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.safetensors"))
|
| 433 |
+
else:
|
| 434 |
+
torch.save(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.bin"))
|
edm2-img512-xl-fid/vae/config.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKL",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"_name_or_path": "stabilityai/sd-vae-ft-mse",
|
| 5 |
+
"act_fn": "silu",
|
| 6 |
+
"block_out_channels": [
|
| 7 |
+
128,
|
| 8 |
+
256,
|
| 9 |
+
512,
|
| 10 |
+
512
|
| 11 |
+
],
|
| 12 |
+
"down_block_types": [
|
| 13 |
+
"DownEncoderBlock2D",
|
| 14 |
+
"DownEncoderBlock2D",
|
| 15 |
+
"DownEncoderBlock2D",
|
| 16 |
+
"DownEncoderBlock2D"
|
| 17 |
+
],
|
| 18 |
+
"force_upcast": true,
|
| 19 |
+
"in_channels": 3,
|
| 20 |
+
"latent_channels": 4,
|
| 21 |
+
"latents_mean": null,
|
| 22 |
+
"latents_std": null,
|
| 23 |
+
"layers_per_block": 2,
|
| 24 |
+
"mid_block_add_attention": true,
|
| 25 |
+
"norm_num_groups": 32,
|
| 26 |
+
"out_channels": 3,
|
| 27 |
+
"sample_size": 256,
|
| 28 |
+
"scaling_factor": 0.18215,
|
| 29 |
+
"shift_factor": null,
|
| 30 |
+
"up_block_types": [
|
| 31 |
+
"UpDecoderBlock2D",
|
| 32 |
+
"UpDecoderBlock2D",
|
| 33 |
+
"UpDecoderBlock2D",
|
| 34 |
+
"UpDecoderBlock2D"
|
| 35 |
+
],
|
| 36 |
+
"use_post_quant_conv": true,
|
| 37 |
+
"use_quant_conv": true
|
| 38 |
+
}
|
edm2-img512-xl-fid/vae/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815
|
| 3 |
+
size 334643276
|
edm2-img512-xs-fid/demo.png
ADDED
|
Git LFS Details
|
edm2-img512-xs-fid/model_index.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": [
|
| 3 |
+
"pipeline",
|
| 4 |
+
"EDM2Pipeline"
|
| 5 |
+
],
|
| 6 |
+
"_diffusers_version": "0.31.0",
|
| 7 |
+
"scheduler": [
|
| 8 |
+
"diffusers",
|
| 9 |
+
"EDMEulerScheduler"
|
| 10 |
+
],
|
| 11 |
+
"unet": [
|
| 12 |
+
"unet_edm2",
|
| 13 |
+
"EDM2UNet2DModel"
|
| 14 |
+
],
|
| 15 |
+
"vae": [
|
| 16 |
+
"diffusers",
|
| 17 |
+
"AutoencoderKL"
|
| 18 |
+
]
|
| 19 |
+
}
|
edm2-img512-xs-fid/pipeline.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hub custom pipeline: EDM2Pipeline.
|
| 2 |
+
Load with native Hugging Face diffusers and trust_remote_code=True.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
|
| 21 |
+
import json
|
| 22 |
+
import os
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 30 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 31 |
+
from diffusers.utils import replace_example_docstring
|
| 32 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 33 |
+
|
| 34 |
+
EXAMPLE_DOC_STRING = """
|
| 35 |
+
Examples:
|
| 36 |
+
```py
|
| 37 |
+
>>> from pathlib import Path
|
| 38 |
+
>>> import torch
|
| 39 |
+
>>> from diffusers import DiffusionPipeline
|
| 40 |
+
|
| 41 |
+
>>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve()
|
| 42 |
+
>>> pipe = DiffusionPipeline.from_pretrained(
|
| 43 |
+
... str(model_dir),
|
| 44 |
+
... local_files_only=True,
|
| 45 |
+
... custom_pipeline=str(model_dir / "pipeline.py"),
|
| 46 |
+
... trust_remote_code=True,
|
| 47 |
+
... torch_dtype=torch.float32,
|
| 48 |
+
... )
|
| 49 |
+
>>> pipe.to("cuda")
|
| 50 |
+
|
| 51 |
+
>>> generator = torch.Generator(device="cuda").manual_seed(42)
|
| 52 |
+
>>> image = pipe(
|
| 53 |
+
... class_labels=207,
|
| 54 |
+
... num_inference_steps=32,
|
| 55 |
+
... guidance_scale=1.0,
|
| 56 |
+
... generator=generator,
|
| 57 |
+
... ).images[0]
|
| 58 |
+
>>> image.save("demo.png")
|
| 59 |
+
```
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
# Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py).
|
| 63 |
+
_STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28])
|
| 64 |
+
_STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE
|
| 65 |
+
|
| 66 |
+
class EDM2Pipeline(DiffusionPipeline):
|
| 67 |
+
r"""
|
| 68 |
+
Pipeline for class-conditional image generation with EDM2
|
| 69 |
+
([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)).
|
| 70 |
+
|
| 71 |
+
Parameters:
|
| 72 |
+
unet ([`EDM2UNet2DModel`]):
|
| 73 |
+
Main magnitude-preserving U-Net with EDM preconditioning.
|
| 74 |
+
scheduler ([`EDMEulerScheduler`]):
|
| 75 |
+
Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in
|
| 76 |
+
the pipeline because the UNet returns denoised latents rather than noise predictions.
|
| 77 |
+
vae ([`AutoencoderKL`], *optional*):
|
| 78 |
+
Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`.
|
| 79 |
+
gnet ([`EDM2UNet2DModel`], *optional*):
|
| 80 |
+
Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`).
|
| 81 |
+
id2label (`dict[int, str]`, *optional*):
|
| 82 |
+
ImageNet class id to English label mapping.
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
model_cpu_offload_seq = "unet->gnet->vae"
|
| 86 |
+
_optional_components = ["vae", "gnet"]
|
| 87 |
+
|
| 88 |
+
def __init__(
|
| 89 |
+
self,
|
| 90 |
+
unet,
|
| 91 |
+
scheduler,
|
| 92 |
+
vae=None,
|
| 93 |
+
gnet=None,
|
| 94 |
+
id2label: Optional[Dict[Union[int, str], str]] = None,
|
| 95 |
+
) -> None:
|
| 96 |
+
super().__init__()
|
| 97 |
+
self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet)
|
| 98 |
+
self._id2label = self._normalize_id2label(id2label)
|
| 99 |
+
self.labels = self._build_label2id(self._id2label)
|
| 100 |
+
self._labels_loaded_from_model_index = bool(self._id2label)
|
| 101 |
+
self.vae_scale_factor = 8 if self.vae is not None else 1
|
| 102 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
|
| 103 |
+
|
| 104 |
+
@staticmethod
|
| 105 |
+
def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
|
| 106 |
+
if not id2label:
|
| 107 |
+
return {}
|
| 108 |
+
return {int(key): value for key, value in id2label.items()}
|
| 109 |
+
|
| 110 |
+
@staticmethod
|
| 111 |
+
def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
|
| 112 |
+
label2id: Dict[str, int] = {}
|
| 113 |
+
for class_id, value in id2label.items():
|
| 114 |
+
for synonym in value.split(","):
|
| 115 |
+
synonym = synonym.strip()
|
| 116 |
+
if synonym:
|
| 117 |
+
label2id[synonym] = int(class_id)
|
| 118 |
+
return dict(sorted(label2id.items()))
|
| 119 |
+
|
| 120 |
+
def _ensure_labels_loaded(self) -> None:
|
| 121 |
+
if self._labels_loaded_from_model_index:
|
| 122 |
+
return
|
| 123 |
+
loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
|
| 124 |
+
if loaded:
|
| 125 |
+
self._id2label = loaded
|
| 126 |
+
self.labels = self._build_label2id(self._id2label)
|
| 127 |
+
self._labels_loaded_from_model_index = True
|
| 128 |
+
|
| 129 |
+
@staticmethod
|
| 130 |
+
def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
|
| 131 |
+
if not variant_path:
|
| 132 |
+
return {}
|
| 133 |
+
model_index_path = Path(variant_path).resolve() / "model_index.json"
|
| 134 |
+
if not model_index_path.is_file():
|
| 135 |
+
return {}
|
| 136 |
+
raw = json.loads(model_index_path.read_text(encoding="utf-8"))
|
| 137 |
+
id2label = raw.get("id2label")
|
| 138 |
+
if not isinstance(id2label, dict):
|
| 139 |
+
return {}
|
| 140 |
+
return {int(key): value for key, value in id2label.items()}
|
| 141 |
+
|
| 142 |
+
@property
|
| 143 |
+
def id2label(self) -> Dict[int, str]:
|
| 144 |
+
r"""ImageNet class id to English label string (comma-separated synonyms)."""
|
| 145 |
+
self._ensure_labels_loaded()
|
| 146 |
+
return self._id2label
|
| 147 |
+
|
| 148 |
+
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
|
| 149 |
+
r"""
|
| 150 |
+
Map ImageNet label strings to class ids.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
label (`str` or `list[str]`):
|
| 154 |
+
One or more English label strings that match entries in `id2label`.
|
| 155 |
+
"""
|
| 156 |
+
self._ensure_labels_loaded()
|
| 157 |
+
if not self.labels:
|
| 158 |
+
raise ValueError("No English labels loaded. Add `id2label` to model_index.json.")
|
| 159 |
+
labels = [label] if isinstance(label, str) else list(label)
|
| 160 |
+
missing = [item for item in labels if item not in self.labels]
|
| 161 |
+
if missing:
|
| 162 |
+
preview = ", ".join(list(self.labels.keys())[:8])
|
| 163 |
+
raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
|
| 164 |
+
return [self.labels[item] for item in labels]
|
| 165 |
+
|
| 166 |
+
def _default_image_size(self) -> int:
|
| 167 |
+
latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64)))
|
| 168 |
+
return latent_size * self.vae_scale_factor
|
| 169 |
+
|
| 170 |
+
def check_inputs(
|
| 171 |
+
self,
|
| 172 |
+
height: int,
|
| 173 |
+
width: int,
|
| 174 |
+
num_inference_steps: int,
|
| 175 |
+
guidance_scale: float,
|
| 176 |
+
output_type: str,
|
| 177 |
+
) -> None:
|
| 178 |
+
if num_inference_steps < 1:
|
| 179 |
+
raise ValueError("num_inference_steps must be >= 1.")
|
| 180 |
+
if guidance_scale < 1.0:
|
| 181 |
+
raise ValueError("guidance_scale must be >= 1.0.")
|
| 182 |
+
if guidance_scale > 1.0 and self.gnet is None:
|
| 183 |
+
raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).")
|
| 184 |
+
if output_type not in {"pil", "np", "pt", "latent"}:
|
| 185 |
+
raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
|
| 186 |
+
|
| 187 |
+
native_size = self._default_image_size()
|
| 188 |
+
if height != native_size or width != native_size:
|
| 189 |
+
raise ValueError(
|
| 190 |
+
f"EDM2 expects native resolution height=width={native_size}. "
|
| 191 |
+
f"Got height={height}, width={width}."
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
def _normalize_class_labels(
|
| 195 |
+
self,
|
| 196 |
+
class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]],
|
| 197 |
+
batch_size: int,
|
| 198 |
+
device: torch.device,
|
| 199 |
+
) -> Optional[torch.Tensor]:
|
| 200 |
+
label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0)))
|
| 201 |
+
if label_dim == 0:
|
| 202 |
+
return None
|
| 203 |
+
if class_labels is None:
|
| 204 |
+
indices = torch.randint(label_dim, size=(batch_size,), device=device)
|
| 205 |
+
return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
|
| 206 |
+
|
| 207 |
+
if isinstance(class_labels, str):
|
| 208 |
+
class_labels = self.get_label_ids(class_labels)[0]
|
| 209 |
+
elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str):
|
| 210 |
+
class_labels = self.get_label_ids(list(class_labels))
|
| 211 |
+
|
| 212 |
+
if isinstance(class_labels, int):
|
| 213 |
+
indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long)
|
| 214 |
+
elif isinstance(class_labels, torch.Tensor):
|
| 215 |
+
if class_labels.ndim == 2:
|
| 216 |
+
labels = class_labels.to(device=device, dtype=torch.float32)
|
| 217 |
+
if labels.shape[0] != batch_size:
|
| 218 |
+
raise ValueError(f"class_labels batch must match batch_size={batch_size}.")
|
| 219 |
+
return labels
|
| 220 |
+
indices = class_labels.to(device=device, dtype=torch.long).flatten()
|
| 221 |
+
else:
|
| 222 |
+
indices = torch.tensor(list(class_labels), device=device, dtype=torch.long)
|
| 223 |
+
|
| 224 |
+
if indices.numel() == 1 and batch_size > 1:
|
| 225 |
+
indices = indices.repeat(batch_size)
|
| 226 |
+
if indices.numel() != batch_size:
|
| 227 |
+
raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.")
|
| 228 |
+
return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
|
| 229 |
+
|
| 230 |
+
def prepare_latents(
|
| 231 |
+
self,
|
| 232 |
+
batch_size: int,
|
| 233 |
+
height: int,
|
| 234 |
+
width: int,
|
| 235 |
+
dtype: torch.dtype,
|
| 236 |
+
device: torch.device,
|
| 237 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]],
|
| 238 |
+
) -> torch.Tensor:
|
| 239 |
+
in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4)))
|
| 240 |
+
latent_size = height // self.vae_scale_factor
|
| 241 |
+
return randn_tensor(
|
| 242 |
+
(batch_size, in_channels, latent_size, latent_size),
|
| 243 |
+
generator=generator,
|
| 244 |
+
device=device,
|
| 245 |
+
dtype=torch.float32,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
|
| 249 |
+
if output_type == "latent":
|
| 250 |
+
return latents
|
| 251 |
+
|
| 252 |
+
in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 3)))
|
| 253 |
+
if self.vae is None:
|
| 254 |
+
image = (latents.to(torch.float32) * 127.5 + 128).clip(0, 255) / 255.0
|
| 255 |
+
return self.image_processor.postprocess(image, output_type=output_type)
|
| 256 |
+
|
| 257 |
+
if in_channels == 4:
|
| 258 |
+
x = latents.to(torch.float32)
|
| 259 |
+
scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
|
| 260 |
+
bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
|
| 261 |
+
x = (x - bias) / scale
|
| 262 |
+
else:
|
| 263 |
+
x = latents.to(torch.float32)
|
| 264 |
+
|
| 265 |
+
vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype
|
| 266 |
+
image = self.vae.decode(x.to(dtype=vae_dtype)).sample.to(torch.float32).clamp(0, 1)
|
| 267 |
+
|
| 268 |
+
return self.image_processor.postprocess(image, output_type=output_type)
|
| 269 |
+
|
| 270 |
+
@staticmethod
|
| 271 |
+
def _apply_autoguidance(
|
| 272 |
+
main: torch.Tensor,
|
| 273 |
+
ref: torch.Tensor,
|
| 274 |
+
guidance_scale: float,
|
| 275 |
+
) -> torch.Tensor:
|
| 276 |
+
return ref.lerp(main, guidance_scale)
|
| 277 |
+
|
| 278 |
+
@staticmethod
|
| 279 |
+
def _sample_edm2_heun(
|
| 280 |
+
denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
| 281 |
+
noise: torch.Tensor,
|
| 282 |
+
sigmas: torch.Tensor,
|
| 283 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 284 |
+
progress_bar: Optional[Callable[[Iterable], Iterable]] = None,
|
| 285 |
+
dtype: torch.dtype = torch.float32,
|
| 286 |
+
) -> torch.Tensor:
|
| 287 |
+
"""NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0)."""
|
| 288 |
+
x_next = noise.to(dtype) * sigmas[0]
|
| 289 |
+
|
| 290 |
+
sigma_pairs = list(zip(sigmas[:-1], sigmas[1:]))
|
| 291 |
+
if progress_bar is not None:
|
| 292 |
+
sigma_pairs = progress_bar(sigma_pairs)
|
| 293 |
+
|
| 294 |
+
num_steps = len(sigma_pairs)
|
| 295 |
+
for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs):
|
| 296 |
+
x_hat, sigma_hat = x_next, sigma_cur
|
| 297 |
+
d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat
|
| 298 |
+
x_next = x_hat + (sigma_next - sigma_hat) * d_cur
|
| 299 |
+
if i < num_steps - 1:
|
| 300 |
+
d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next
|
| 301 |
+
x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime)
|
| 302 |
+
return x_next
|
| 303 |
+
|
| 304 |
+
@torch.inference_mode()
|
| 305 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 306 |
+
def __call__(
|
| 307 |
+
self,
|
| 308 |
+
class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None,
|
| 309 |
+
batch_size: int = 1,
|
| 310 |
+
height: Optional[int] = None,
|
| 311 |
+
width: Optional[int] = None,
|
| 312 |
+
num_inference_steps: int = 32,
|
| 313 |
+
guidance_scale: float = 1.0,
|
| 314 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 315 |
+
output_type: str = "pil",
|
| 316 |
+
return_dict: bool = True,
|
| 317 |
+
) -> Union[ImagePipelineOutput, Tuple]:
|
| 318 |
+
r"""
|
| 319 |
+
Generate class-conditional images with EDM2.
|
| 320 |
+
|
| 321 |
+
Args:
|
| 322 |
+
class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*):
|
| 323 |
+
ImageNet class indices, English label strings, or one-hot float tensors.
|
| 324 |
+
Random classes are sampled when omitted on conditional models.
|
| 325 |
+
batch_size (`int`, defaults to `1`):
|
| 326 |
+
Number of images to generate.
|
| 327 |
+
height (`int`, *optional*):
|
| 328 |
+
Output height in pixels. Defaults to the pretrained native resolution.
|
| 329 |
+
width (`int`, *optional*):
|
| 330 |
+
Output width in pixels. Defaults to the pretrained native resolution.
|
| 331 |
+
num_inference_steps (`int`, defaults to `32`):
|
| 332 |
+
Number of EDM2 Heun steps (NVlabs default).
|
| 333 |
+
guidance_scale (`float`, defaults to `1.0`):
|
| 334 |
+
Autoguidance strength. Values above `1.0` blend the main net with `gnet`
|
| 335 |
+
via `gnet_output.lerp(unet_output, guidance_scale)`.
|
| 336 |
+
generator (`torch.Generator`, *optional*):
|
| 337 |
+
RNG for reproducibility.
|
| 338 |
+
output_type (`str`, defaults to `"pil"`):
|
| 339 |
+
`"pil"`, `"np"`, `"pt"`, or `"latent"`.
|
| 340 |
+
return_dict (`bool`, defaults to `True`):
|
| 341 |
+
Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if True.
|
| 342 |
+
|
| 343 |
+
Examples:
|
| 344 |
+
<!-- this section is replaced by replace_example_docstring -->
|
| 345 |
+
"""
|
| 346 |
+
default_size = self._default_image_size()
|
| 347 |
+
height = int(height or default_size)
|
| 348 |
+
width = int(width or default_size)
|
| 349 |
+
self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type)
|
| 350 |
+
|
| 351 |
+
device = self._execution_device
|
| 352 |
+
dtype = self.unet.dtype
|
| 353 |
+
labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
|
| 354 |
+
noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
|
| 355 |
+
|
| 356 |
+
def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
|
| 357 |
+
sigma_batch = sigma.reshape(1).expand(batch_size)
|
| 358 |
+
main = self.unet(
|
| 359 |
+
sample=x,
|
| 360 |
+
sigma=sigma_batch,
|
| 361 |
+
class_labels=labels,
|
| 362 |
+
force_fp32=True,
|
| 363 |
+
).sample
|
| 364 |
+
if guidance_scale == 1.0 or self.gnet is None:
|
| 365 |
+
return main.to(torch.float32)
|
| 366 |
+
ref = self.gnet(
|
| 367 |
+
sample=x,
|
| 368 |
+
sigma=sigma_batch,
|
| 369 |
+
class_labels=labels,
|
| 370 |
+
force_fp32=True,
|
| 371 |
+
).sample
|
| 372 |
+
return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32)
|
| 373 |
+
|
| 374 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 375 |
+
latents = self._sample_edm2_heun(
|
| 376 |
+
denoise_fn=denoise_fn,
|
| 377 |
+
noise=noise,
|
| 378 |
+
sigmas=self.scheduler.sigmas.to(device),
|
| 379 |
+
generator=generator,
|
| 380 |
+
progress_bar=self.progress_bar,
|
| 381 |
+
dtype=torch.float32,
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
image = self.decode_latents(latents, output_type=output_type)
|
| 385 |
+
if not return_dict:
|
| 386 |
+
return (image, latents)
|
| 387 |
+
return ImagePipelineOutput(images=image)
|
| 388 |
+
|
| 389 |
+
@classmethod
|
| 390 |
+
def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None):
|
| 391 |
+
vae_dir = os.path.join(pretrained_model_name_or_path, "vae")
|
| 392 |
+
if os.path.isdir(vae_dir):
|
| 393 |
+
try:
|
| 394 |
+
|
| 395 |
+
return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype)
|
| 396 |
+
except Exception:
|
| 397 |
+
return None
|
| 398 |
+
|
| 399 |
+
vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt")
|
| 400 |
+
if os.path.isfile(vae_hint):
|
| 401 |
+
with open(vae_hint, "r", encoding="utf-8") as f:
|
| 402 |
+
hub_id = f.read().strip()
|
| 403 |
+
if hub_id:
|
| 404 |
+
|
| 405 |
+
return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype)
|
| 406 |
+
return None
|