diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..9926f7e8b88c5be1aec0dbb4af4bae3f0ed5d623 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+edm2-img512-l-dino/demo.png filter=lfs diff=lfs merge=lfs -text
+edm2-img512-l-fid/generator_test.png filter=lfs diff=lfs merge=lfs -text
+edm2-img512-m-fid/demo.png filter=lfs diff=lfs merge=lfs -text
+edm2-img512-s-fid/demo.png filter=lfs diff=lfs merge=lfs -text
+edm2-img512-xl-fid/demo.png filter=lfs diff=lfs merge=lfs -text
+edm2-img512-xs-fid/demo.png filter=lfs diff=lfs merge=lfs -text
+edm2-img512-xxl-fid/demo.png filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..aed779e2b11ced8045fe03219eddd7a494c0e89f
--- /dev/null
+++ b/README.md
@@ -0,0 +1,199 @@
+---
+license: cc-by-nc-sa-4.0
+library_name: diffusers
+pipeline_tag: unconditional-image-generation
+tags:
+ - diffusers
+ - edm2
+ - image-generation
+ - class-conditional
+ - imagenet
+inference: true
+widget:
+ - output:
+ url: edm2-img512-xxl-fid/demo.png
+language:
+ - en
+---
+
+# EDM2-diffusers
+
+Diffusers-ready checkpoints for **EDM2** ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)),
+converted from [NVlabs/edm2](https://github.com/NVlabs/edm2) post-hoc reconstructions.
+
+Official source weights: `https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions/`
+
+This root folder is a model collection that contains:
+
+- `edm2-img512-xs-fid`
+- `edm2-img512-s-fid`
+- `edm2-img512-m-fid`
+- `edm2-img512-l-fid`
+- `edm2-img512-l-dino`
+- `edm2-img512-xl-fid`
+- `edm2-img512-xxl-fid`
+
+Each subfolder is a self-contained Diffusers model repo with:
+
+- `pipeline.py`
+- `unet/unet_edm2.py`
+- `scheduler/scheduler_config.json` (`EDMEulerScheduler`)
+- `unet/diffusion_pytorch_model.safetensors`
+- `vae/diffusion_pytorch_model.safetensors`
+
+## Demo
+
+![EDM2-XXL sample: golden retriever (ImageNet class 207)](edm2-img512-xxl-fid/demo.png)
+
+Class-conditional sample (ImageNet class **207**, golden retriever), EDM2-XXL at 512×512, 32 steps, guidance 1.0, seed 42.
+
+## Model Paths
+
+Use paths relative to this root README:
+
+| Model | NVlabs preset | FID | Local path |
+| --- | --- | ---: | --- |
+| EDM2-XS | `edm2-img512-xs-fid` | 3.53 | `./edm2-img512-xs-fid` |
+| EDM2-S | `edm2-img512-s-fid` | 2.56 | `./edm2-img512-s-fid` |
+| EDM2-M | `edm2-img512-m-fid` | 2.25 | `./edm2-img512-m-fid` |
+| EDM2-L | `edm2-img512-l-fid` | 2.06 | `./edm2-img512-l-fid` |
+| EDM2-L (DINO) | `edm2-img512-l-dino` | — | `./edm2-img512-l-dino` |
+| EDM2-XL | `edm2-img512-xl-fid` | 1.96 | `./edm2-img512-xl-fid` |
+| EDM2-XXL | `edm2-img512-xxl-fid` | 1.91 | `./edm2-img512-xxl-fid` |
+
+## Inference Demo (Diffusers)
+
+### 1) Load a local subfolder checkpoint
+
+```python
+from pathlib import Path
+import torch
+from diffusers import DiffusionPipeline
+
+model_dir = Path("./edm2-img512-xxl-fid") # change to any path in the table above
+pipe = DiffusionPipeline.from_pretrained(
+ str(model_dir),
+ local_files_only=True,
+ trust_remote_code=True,
+ torch_dtype=torch.bfloat16,
+).to("cuda")
+
+generator = torch.Generator(device="cuda").manual_seed(42)
+image = pipe(
+ class_labels=207, # golden retriever (ImageNet id); omit for random class
+ num_inference_steps=32,
+ guidance_scale=1.0, # >1.0 requires a gnet/ checkpoint
+ generator=generator,
+).images[0]
+image.save("demo.png")
+```
+
+Official inference defaults (`generate_images.py`): `num_steps=32`, `sigma_min=0.002`,
+`sigma_max=80`, `rho=7`, `guidance=1.0` (no gnet), `S_churn=0`. Heun sampling runs in
+float32 internally even when UNet/VAE weights are loaded in bf16/fp16.
+
+Guided presets require a converted `gnet/` folder and `guidance_scale` matching the
+NVlabs preset.
+
+### 2) Convert a legacy `.pkl`
+
+```bash
+python scripts/convert_edm2_to_diffusers.py \
+ --checkpoint models/BiliSakura/EDM2-diffusers/edm2-img512-xs-2147483-0.135.pkl \
+ --output models/BiliSakura/EDM2-diffusers
+```
+
+Creates `edm2-img512-xs-fid/` automatically from the NVlabs preset mapping.
+
+## Checkpoint preset mapping
+
+Maps NVlabs `--preset=...` names from [`generate_images.py`](https://github.com/NVlabs/edm2/blob/main/generate_images.py)
+to source pickle filenames and local Diffusers directories.
+
+### EDM2 paper — ImageNet-512 (conditional)
+
+| NVlabs preset | Source `.pkl` (net) | Diffusers dir | Metric |
+| --- | --- | --- | --- |
+| `edm2-img512-xs-fid` | `edm2-img512-xs-2147483-0.135.pkl` | `edm2-img512-xs-fid/` | FID 3.53 |
+| `edm2-img512-xs-dino` | `edm2-img512-xs-2147483-0.200.pkl` | — | FDDINOv2 103.39 |
+| `edm2-img512-s-fid` | `edm2-img512-s-2147483-0.130.pkl` | `edm2-img512-s-fid/` | FID 2.56 |
+| `edm2-img512-s-dino` | `edm2-img512-s-2147483-0.190.pkl` | — | FDDINOv2 68.64 |
+| `edm2-img512-m-fid` | `edm2-img512-m-2147483-0.100.pkl` | `edm2-img512-m-fid/` | FID 2.25 |
+| `edm2-img512-m-dino` | `edm2-img512-m-2147483-0.155.pkl` | — | FDDINOv2 58.44 |
+| `edm2-img512-l-fid` | `edm2-img512-l-1879048-0.085.pkl` | `edm2-img512-l-fid/` | FID 2.06 |
+| `edm2-img512-l-dino` | `edm2-img512-l-1879048-0.155.pkl` | `edm2-img512-l-dino/` | FDDINOv2 52.25 |
+| `edm2-img512-xl-fid` | `edm2-img512-xl-1342177-0.085.pkl` | `edm2-img512-xl-fid/` | FID 1.96 |
+| `edm2-img512-xl-dino` | `edm2-img512-xl-1342177-0.155.pkl` | — | FDDINOv2 45.96 |
+| `edm2-img512-xxl-fid` | `edm2-img512-xxl-0939524-0.070.pkl` | `edm2-img512-xxl-fid/` | FID 1.91 |
+| `edm2-img512-xxl-dino` | `edm2-img512-xxl-0939524-0.150.pkl` | — | FDDINOv2 42.84 |
+
+### EDM2 paper — ImageNet-64 (conditional)
+
+| NVlabs preset | Source `.pkl` (net) | Metric |
+| --- | --- | --- |
+| `edm2-img64-s-fid` | `edm2-img64-s-1073741-0.075.pkl` | FID 1.58 |
+| `edm2-img64-m-fid` | `edm2-img64-m-2147483-0.060.pkl` | FID 1.43 |
+| `edm2-img64-l-fid` | `edm2-img64-l-1073741-0.040.pkl` | FID 1.33 |
+| `edm2-img64-xl-fid` | `edm2-img64-xl-0671088-0.040.pkl` | FID 1.33 |
+
+### EDM2 paper — classifier-free guidance (ImageNet-512)
+
+Set `guidance_scale` to the value in the Guidance column below and include the converted `gnet/` checkpoint.
+
+| NVlabs preset | Source `.pkl` (net) | Source `.pkl` (gnet) | Guidance | Metric |
+| --- | --- | --- | ---: | --- |
+| `edm2-img512-xs-guid-fid` | `edm2-img512-xs-2147483-0.045.pkl` | `edm2-img512-xs-uncond-2147483-0.045.pkl` | 1.40 | FID 2.91 |
+| `edm2-img512-xs-guid-dino` | `edm2-img512-xs-2147483-0.150.pkl` | `edm2-img512-xs-uncond-2147483-0.150.pkl` | 1.70 | FDDINOv2 79.94 |
+| `edm2-img512-s-guid-fid` | `edm2-img512-s-2147483-0.025.pkl` | `edm2-img512-xs-uncond-2147483-0.025.pkl` | 1.40 | FID 2.23 |
+| `edm2-img512-s-guid-dino` | `edm2-img512-s-2147483-0.085.pkl` | `edm2-img512-xs-uncond-2147483-0.085.pkl` | 1.90 | FDDINOv2 52.32 |
+| `edm2-img512-m-guid-fid` | `edm2-img512-m-2147483-0.030.pkl` | `edm2-img512-xs-uncond-2147483-0.030.pkl` | 1.20 | FID 2.01 |
+| `edm2-img512-m-guid-dino` | `edm2-img512-m-2147483-0.015.pkl` | `edm2-img512-xs-uncond-2147483-0.015.pkl` | 2.00 | FDDINOv2 41.98 |
+| `edm2-img512-l-guid-fid` | `edm2-img512-l-1879048-0.015.pkl` | `edm2-img512-xs-uncond-2147483-0.015.pkl` | 1.20 | FID 1.88 |
+| `edm2-img512-l-guid-dino` | `edm2-img512-l-1879048-0.035.pkl` | `edm2-img512-xs-uncond-2147483-0.035.pkl` | 1.70 | FDDINOv2 38.20 |
+| `edm2-img512-xl-guid-fid` | `edm2-img512-xl-1342177-0.020.pkl` | `edm2-img512-xs-uncond-2147483-0.020.pkl` | 1.20 | FID 1.85 |
+| `edm2-img512-xl-guid-dino` | `edm2-img512-xl-1342177-0.030.pkl` | `edm2-img512-xs-uncond-2147483-0.030.pkl` | 1.70 | FDDINOv2 35.67 |
+| `edm2-img512-xxl-guid-fid` | `edm2-img512-xxl-0939524-0.015.pkl` | `edm2-img512-xs-uncond-2147483-0.015.pkl` | 1.20 | FID 1.81 |
+| `edm2-img512-xxl-guid-dino` | `edm2-img512-xxl-0939524-0.015.pkl` | `edm2-img512-xs-uncond-2147483-0.015.pkl` | 1.70 | FDDINOv2 33.09 |
+
+### Autoguidance paper
+
+| NVlabs preset | Source `.pkl` (net) | Source `.pkl` (gnet) | Guidance | Metric |
+| --- | --- | --- | ---: | --- |
+| `edm2-img512-s-autog-fid` | `edm2-img512-s-2147483-0.070.pkl` | `edm2-img512-xs-0134217-0.125.pkl` | 2.10 | FID 1.34 |
+| `edm2-img512-s-autog-dino` | `edm2-img512-s-2147483-0.120.pkl` | `edm2-img512-xs-0134217-0.165.pkl` | 2.45 | FDDINOv2 36.67 |
+| `edm2-img512-xxl-autog-fid` | `edm2-img512-xxl-0939524-0.075.pkl` | `edm2-img512-m-0268435-0.155.pkl` | 2.05 | FID 1.25 |
+| `edm2-img512-xxl-autog-dino` | `edm2-img512-xxl-0939524-0.130.pkl` | `edm2-img512-m-0268435-0.205.pkl` | 2.30 | FDDINOv2 24.18 |
+| `edm2-img512-s-uncond-autog-fid` | `edm2-img512-s-uncond-2147483-0.070.pkl` | `edm2-img512-xs-uncond-0134217-0.110.pkl` | 2.85 | FID 3.86 |
+| `edm2-img512-s-uncond-autog-dino` | `edm2-img512-s-uncond-2147483-0.090.pkl` | `edm2-img512-xs-uncond-0134217-0.125.pkl` | 2.90 | FDDINOv2 90.39 |
+| `edm2-img64-s-autog-fid` | `edm2-img64-s-1073741-0.045.pkl` | `edm2-img64-xs-0134217-0.110.pkl` | 1.70 | FID 1.01 |
+| `edm2-img64-s-autog-dino` | `edm2-img64-s-1073741-0.105.pkl` | `edm2-img64-xs-0134217-0.175.pkl` | 2.20 | FDDINOv2 31.85 |
+
+### NVlabs preset shorthand
+
+```text
+# EDM2 paper
+edm2-img512-{xs|s|m|l|xl|xxl}-{fid|dino}
+edm2-img64-{s|m|l|xl}-fid
+edm2-img512-{xs|s|m|l|xl|xxl}-guid-{fid|dino}
+
+# Autoguidance paper
+edm2-img512-{s|xxl}-autog-{fid|dino}
+edm2-img512-s-uncond-autog-{fid|dino}
+edm2-img64-s-autog-{fid|dino}
+```
+
+Example NVlabs command:
+
+```bash
+python generate_images.py --preset=edm2-img512-s-guid-dino --outdir=out
+```
+
+Equivalent expanded form:
+
+```bash
+python generate_images.py \
+ --net=https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions/edm2-img512-s-2147483-0.085.pkl \
+ --gnet=https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions/edm2-img512-xs-uncond-2147483-0.085.pkl \
+ --guidance=1.9 \
+ --outdir=out
+```
diff --git a/edm2-img512-l-dino/demo.png b/edm2-img512-l-dino/demo.png
new file mode 100644
index 0000000000000000000000000000000000000000..60c5c3da1e339e32d3ea1c0fab558f809783d63d
--- /dev/null
+++ b/edm2-img512-l-dino/demo.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:12a2dab2ca0e5ec5a6eebe9f7c10b440232622055866192ecc5c8b3dc289db4d
+size 389147
diff --git a/edm2-img512-l-dino/model_index.json b/edm2-img512-l-dino/model_index.json
new file mode 100644
index 0000000000000000000000000000000000000000..24c0305eea3ec64abbc2f870476afa4f5e462f42
--- /dev/null
+++ b/edm2-img512-l-dino/model_index.json
@@ -0,0 +1,19 @@
+{
+ "_class_name": [
+ "pipeline",
+ "EDM2Pipeline"
+ ],
+ "_diffusers_version": "0.31.0",
+ "scheduler": [
+ "diffusers",
+ "EDMEulerScheduler"
+ ],
+ "unet": [
+ "unet_edm2",
+ "EDM2UNet2DModel"
+ ],
+ "vae": [
+ "diffusers",
+ "AutoencoderKL"
+ ]
+}
diff --git a/edm2-img512-l-dino/pipeline.py b/edm2-img512-l-dino/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b4131215d07770c74efa0674d00f3d6423d9c06
--- /dev/null
+++ b/edm2-img512-l-dino/pipeline.py
@@ -0,0 +1,406 @@
+"""Hub custom pipeline: EDM2Pipeline.
+Load with native Hugging Face diffusers and trust_remote_code=True.
+"""
+
+from __future__ import annotations
+
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from pathlib import Path
+from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from diffusers.utils import replace_example_docstring
+from diffusers.utils.torch_utils import randn_tensor
+
+# Usage example passed to `replace_example_docstring` on `EDM2Pipeline.__call__`.
+# This is a runtime string: its text is part of the generated docstring, so keep
+# edits to it deliberate.
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from pathlib import Path
+ >>> import torch
+ >>> from diffusers import DiffusionPipeline
+
+ >>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve()
+ >>> pipe = DiffusionPipeline.from_pretrained(
+ ... str(model_dir),
+ ... local_files_only=True,
+ ... custom_pipeline=str(model_dir / "pipeline.py"),
+ ... trust_remote_code=True,
+ ... torch_dtype=torch.float32,
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> generator = torch.Generator(device="cuda").manual_seed(42)
+ >>> image = pipe(
+ ... class_labels=207,
+ ... num_inference_steps=32,
+ ... guidance_scale=1.0,
+ ... generator=generator,
+ ... ).images[0]
+ >>> image.save("demo.png")
+ ```
+"""
+
+# Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py).
+# Four per-channel values each; `decode_latents` undoes the whitening with
+# `(x - bias) / scale` before handing latents to the VAE decoder.
+# NOTE(review): the literals look like per-channel latent std / mean statistics
+# (scale = 0.5 / std, bias = -mean * scale) -- confirm against the NVlabs source.
+_STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28])
+_STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE
+
+class EDM2Pipeline(DiffusionPipeline):
+    r"""
+    Pipeline for class-conditional image generation with EDM2
+    ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)).
+
+    Parameters:
+        unet ([`EDM2UNet2DModel`]):
+            Main magnitude-preserving U-Net with EDM preconditioning.
+        scheduler ([`EDMEulerScheduler`]):
+            Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in
+            the pipeline because the UNet returns denoised latents rather than noise predictions.
+        vae ([`AutoencoderKL`], *optional*):
+            Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`.
+        gnet ([`EDM2UNet2DModel`], *optional*):
+            Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`).
+        id2label (`dict[int, str]`, *optional*):
+            ImageNet class id to English label mapping.
+    """
+
+    model_cpu_offload_seq = "unet->gnet->vae"
+    _optional_components = ["vae", "gnet"]
+
+    def __init__(
+        self,
+        unet,
+        scheduler,
+        vae=None,
+        gnet=None,
+        id2label: Optional[Dict[Union[int, str], str]] = None,
+    ) -> None:
+        """Register the sub-models and build the label lookup tables."""
+        super().__init__()
+        self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet)
+        self._id2label = self._normalize_id2label(id2label)
+        self.labels = self._build_label2id(self._id2label)
+        # True once labels are available, so `_ensure_labels_loaded` can return early.
+        self._labels_loaded_from_model_index = bool(self._id2label)
+        # With a VAE the UNet works on latents 8x smaller than the image; without one it works on pixels.
+        self.vae_scale_factor = 8 if self.vae is not None else 1
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
+
+    @staticmethod
+    def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
+        """Return `id2label` with integer keys (JSON stores keys as strings); `{}` when empty or `None`."""
+        if not id2label:
+            return {}
+        return {int(key): value for key, value in id2label.items()}
+
+    @staticmethod
+    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
+        """Invert `id2label`, splitting comma-separated synonyms into individual keys.
+
+        When the same synonym appears under several class ids, the id seen last wins.
+        The result is sorted by label string.
+        """
+        label2id: Dict[str, int] = {}
+        for class_id, value in id2label.items():
+            for synonym in value.split(","):
+                synonym = synonym.strip()
+                if synonym:
+                    label2id[synonym] = int(class_id)
+        return dict(sorted(label2id.items()))
+
+    def _ensure_labels_loaded(self) -> None:
+        """Lazily load labels from `model_index.json` when none were passed to `__init__`."""
+        if self._labels_loaded_from_model_index:
+            return
+        loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
+        if loaded:
+            self._id2label = loaded
+            self.labels = self._build_label2id(self._id2label)
+            self._labels_loaded_from_model_index = True
+
+    @staticmethod
+    def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
+        """Read the `id2label` mapping from `<variant_path>/model_index.json`.
+
+        Returns `{}` when the path is empty, the file does not exist, or the file has no
+        dict-valued `id2label` entry.
+        """
+        if not variant_path:
+            return {}
+        model_index_path = Path(variant_path).resolve() / "model_index.json"
+        if not model_index_path.is_file():
+            return {}
+        raw = json.loads(model_index_path.read_text(encoding="utf-8"))
+        id2label = raw.get("id2label")
+        if not isinstance(id2label, dict):
+            return {}
+        return {int(key): value for key, value in id2label.items()}
+
+    @property
+    def id2label(self) -> Dict[int, str]:
+        r"""ImageNet class id to English label string (comma-separated synonyms)."""
+        self._ensure_labels_loaded()
+        return self._id2label
+
+    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
+        r"""
+        Map ImageNet label strings to class ids.
+
+        Args:
+            label (`str` or `list[str]`):
+                One or more English label strings that match entries in `id2label`.
+
+        Raises:
+            ValueError: If no labels are available or any requested label is unknown.
+        """
+        self._ensure_labels_loaded()
+        if not self.labels:
+            raise ValueError("No English labels loaded. Add `id2label` to model_index.json.")
+        labels = [label] if isinstance(label, str) else list(label)
+        missing = [item for item in labels if item not in self.labels]
+        if missing:
+            preview = ", ".join(list(self.labels.keys())[:8])
+            raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
+        return [self.labels[item] for item in labels]
+
+    def _unet_setting(self, name: str, default: int) -> int:
+        """Read an integer UNet setting: model attribute first, then `unet.config`, then `default`."""
+        return int(getattr(self.unet, name, getattr(self.unet.config, name, default)))
+
+    def _default_image_size(self) -> int:
+        """Native output resolution in pixels (UNet sample size times the VAE scale factor)."""
+        return self._unet_setting("sample_size", 64) * self.vae_scale_factor
+
+    def check_inputs(
+        self,
+        height: int,
+        width: int,
+        num_inference_steps: int,
+        guidance_scale: float,
+        output_type: str,
+    ) -> None:
+        """Validate `__call__` arguments; raises `ValueError` on the first violation."""
+        if num_inference_steps < 1:
+            raise ValueError("num_inference_steps must be >= 1.")
+        if guidance_scale < 1.0:
+            raise ValueError("guidance_scale must be >= 1.0.")
+        if guidance_scale > 1.0 and self.gnet is None:
+            raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).")
+        if output_type not in {"pil", "np", "pt", "latent"}:
+            raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
+
+        native_size = self._default_image_size()
+        if height != native_size or width != native_size:
+            raise ValueError(
+                f"EDM2 expects native resolution height=width={native_size}. "
+                f"Got height={height}, width={width}."
+            )
+
+    def _normalize_class_labels(
+        self,
+        class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]],
+        batch_size: int,
+        device: torch.device,
+    ) -> Optional[torch.Tensor]:
+        """Convert any supported `class_labels` form into a float32 one-hot batch.
+
+        Returns `None` for unconditional UNets (`num_class_embeds == 0`). When `class_labels`
+        is `None`, random classes are drawn. A 2-D tensor is treated as ready-made one-hot
+        rows and only checked for batch size.
+
+        Raises:
+            ValueError: If the labels do not resolve to `batch_size` entries or an index is
+                outside the UNet's class range.
+        """
+        label_dim = self._unet_setting("num_class_embeds", 0)
+        if label_dim == 0:
+            return None
+        if class_labels is None:
+            indices = torch.randint(label_dim, size=(batch_size,), device=device)
+            return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
+
+        if isinstance(class_labels, str):
+            class_labels = self.get_label_ids(class_labels)[0]
+        elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str):
+            class_labels = self.get_label_ids(list(class_labels))
+
+        # NumPy integer scalars are accepted alongside plain ints.
+        if isinstance(class_labels, (int, np.integer)):
+            indices = torch.full((batch_size,), int(class_labels), device=device, dtype=torch.long)
+        elif isinstance(class_labels, torch.Tensor):
+            if class_labels.ndim == 2:
+                labels = class_labels.to(device=device, dtype=torch.float32)
+                if labels.shape[0] != batch_size:
+                    raise ValueError(f"class_labels batch must match batch_size={batch_size}.")
+                return labels
+            indices = class_labels.to(device=device, dtype=torch.long).flatten()
+        else:
+            indices = torch.tensor(list(class_labels), device=device, dtype=torch.long)
+
+        if indices.numel() == 1 and batch_size > 1:
+            indices = indices.repeat(batch_size)
+        if indices.numel() != batch_size:
+            raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.")
+        # Fail with a readable message before indexing; out-of-range ids would otherwise
+        # surface as an opaque indexing error. Negative ids that wrap are still allowed,
+        # as they were before this check existed.
+        if indices.numel() > 0 and (int(indices.min()) < -label_dim or int(indices.max()) >= label_dim):
+            raise ValueError(f"class_labels must be valid indices into {label_dim} classes.")
+        return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
+
+    def prepare_latents(
+        self,
+        batch_size: int,
+        height: int,
+        width: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]],
+    ) -> torch.Tensor:
+        """Draw unit-variance starting noise of shape `(batch, channels, height/f, width/f)`.
+
+        `dtype` is accepted for interface compatibility but the noise is always float32,
+        because the sampler runs in float32.
+        """
+        in_channels = self._unet_setting("in_channels", 4)
+        latent_height = height // self.vae_scale_factor
+        latent_width = width // self.vae_scale_factor
+        return randn_tensor(
+            (batch_size, in_channels, latent_height, latent_width),
+            generator=generator,
+            device=device,
+            dtype=torch.float32,
+        )
+
+    def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
+        """Turn sampler output into images of the requested `output_type`.
+
+        `"latent"` returns the input untouched. Without a VAE the latents are treated as
+        pixels and mapped into [0, 1] via `x * 127.5 + 128`. With a VAE, 4-channel latents
+        are un-whitened first and the decoder output is clamped to [0, 1].
+        """
+        if output_type == "latent":
+            return latents
+
+        # Same fallback (4) as `prepare_latents`, so both agree when the UNet exposes no value.
+        in_channels = self._unet_setting("in_channels", 4)
+        if self.vae is None:
+            image = (latents.to(torch.float32) * 127.5 + 128).clip(0, 255) / 255.0
+            return self.image_processor.postprocess(image, output_type=output_type)
+
+        x = latents.to(torch.float32)
+        if in_channels == 4:
+            scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
+            bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
+            x = (x - bias) / scale
+
+        vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype
+        image = self.vae.decode(x.to(dtype=vae_dtype)).sample.to(torch.float32).clamp(0, 1)
+
+        return self.image_processor.postprocess(image, output_type=output_type)
+
+    @staticmethod
+    def _apply_autoguidance(
+        main: torch.Tensor,
+        ref: torch.Tensor,
+        guidance_scale: float,
+    ) -> torch.Tensor:
+        """Blend from `ref` toward (and past, for scale > 1) `main`: `ref + scale * (main - ref)`."""
+        return ref.lerp(main, guidance_scale)
+
+    @staticmethod
+    def _sample_edm2_heun(
+        denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+        noise: torch.Tensor,
+        sigmas: torch.Tensor,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        progress_bar: Optional[Callable[[Iterable], Iterable]] = None,
+        dtype: torch.dtype = torch.float32,
+    ) -> torch.Tensor:
+        """NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0).
+
+        Args:
+            denoise_fn: Maps `(x, sigma)` to the denoised estimate of `x`.
+            noise: Unit-variance starting noise; it is scaled by `sigmas[0]`.
+            sigmas: Noise levels to step through, consumed as consecutive pairs.
+            generator: Unused; kept for interface compatibility.
+            progress_bar: Optional wrapper applied to the iterable of sigma pairs.
+            dtype: Working dtype of the sampled state.
+        """
+        x_next = noise.to(dtype) * sigmas[0]
+
+        sigma_pairs = list(zip(sigmas[:-1], sigmas[1:]))
+        # Count steps before wrapping: the wrapper is only required to be iterable.
+        num_steps = len(sigma_pairs)
+        if progress_bar is not None:
+            sigma_pairs = progress_bar(sigma_pairs)
+
+        for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs):
+            x_hat, sigma_hat = x_next, sigma_cur
+            # Euler step.
+            d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat
+            x_next = x_hat + (sigma_next - sigma_hat) * d_cur
+            # Second-order correction, skipped on the final pair.
+            if i < num_steps - 1:
+                d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next
+                x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime)
+        return x_next
+
+    @torch.inference_mode()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None,
+        batch_size: int = 1,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 32,
+        guidance_scale: float = 1.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        output_type: str = "pil",
+        return_dict: bool = True,
+    ) -> Union[ImagePipelineOutput, Tuple]:
+        r"""
+        Generate class-conditional images with EDM2.
+
+        Args:
+            class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*):
+                ImageNet class indices, English label strings, or one-hot float tensors.
+                Random classes are sampled when omitted on conditional models.
+            batch_size (`int`, defaults to `1`):
+                Number of images to generate.
+            height (`int`, *optional*):
+                Output height in pixels. Defaults to the pretrained native resolution.
+            width (`int`, *optional*):
+                Output width in pixels. Defaults to the pretrained native resolution.
+            num_inference_steps (`int`, defaults to `32`):
+                Number of EDM2 Heun steps (NVlabs default).
+            guidance_scale (`float`, defaults to `1.0`):
+                Autoguidance strength. Values above `1.0` blend the main net with `gnet`
+                via `gnet_output.lerp(unet_output, guidance_scale)`.
+            generator (`torch.Generator`, *optional*):
+                RNG for reproducibility.
+            output_type (`str`, defaults to `"pil"`):
+                `"pil"`, `"np"`, `"pt"`, or `"latent"`.
+            return_dict (`bool`, defaults to `True`):
+                Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if True; otherwise a
+                tuple `(images, latents)`.
+
+        Examples:
+
+        """
+        default_size = self._default_image_size()
+        height = int(height or default_size)
+        width = int(width or default_size)
+        self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type)
+
+        device = self._execution_device
+        dtype = self.unet.dtype
+        labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
+        noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
+
+        def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+            # The sampler passes a scalar sigma; the networks take one value per batch item.
+            sigma_batch = sigma.reshape(1).expand(batch_size)
+            main = self.unet(
+                sample=x,
+                sigma=sigma_batch,
+                class_labels=labels,
+                force_fp32=True,
+            ).sample
+            if guidance_scale == 1.0 or self.gnet is None:
+                return main.to(torch.float32)
+            ref = self.gnet(
+                sample=x,
+                sigma=sigma_batch,
+                class_labels=labels,
+                force_fp32=True,
+            ).sample
+            return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32)
+
+        # The scheduler only supplies the sigma schedule; stepping is done by `_sample_edm2_heun`.
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        latents = self._sample_edm2_heun(
+            denoise_fn=denoise_fn,
+            noise=noise,
+            sigmas=self.scheduler.sigmas.to(device),
+            generator=generator,
+            progress_bar=self.progress_bar,
+            dtype=torch.float32,
+        )
+
+        image = self.decode_latents(latents, output_type=output_type)
+        if not return_dict:
+            return (image, latents)
+        return ImagePipelineOutput(images=image)
+
+    @classmethod
+    def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None):
+        """Best-effort VAE loader for a local checkpoint directory.
+
+        Looks for a `vae/` subfolder first. Failing that, it reads a Hub id from
+        `vae_pretrained_model_name_or_path.txt`. Returns `None` when neither source exists
+        or the local `vae/` folder cannot be loaded (a warning is emitted in that case).
+        Errors from loading the Hub id are propagated.
+        """
+        import warnings
+
+        # Imported here because the module does not import AutoencoderKL at top level;
+        # without this both branches below fail with NameError.
+        from diffusers import AutoencoderKL
+
+        vae_dir = os.path.join(pretrained_model_name_or_path, "vae")
+        if os.path.isdir(vae_dir):
+            try:
+                return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype)
+            except Exception as exc:
+                # Still best-effort, but say why instead of silently returning None.
+                warnings.warn(f"Could not load VAE from {vae_dir!r}: {exc}")
+                return None
+
+        vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt")
+        if os.path.isfile(vae_hint):
+            with open(vae_hint, "r", encoding="utf-8") as f:
+                hub_id = f.read().strip()
+            if hub_id:
+                return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype)
+        return None
diff --git a/edm2-img512-l-dino/scheduler/scheduler_config.json b/edm2-img512-l-dino/scheduler/scheduler_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..08cbde97955d06d66b0b72be217db84bb797d711
--- /dev/null
+++ b/edm2-img512-l-dino/scheduler/scheduler_config.json
@@ -0,0 +1,11 @@
+{
+ "_class_name": "EDMEulerScheduler",
+ "final_sigmas_type": "zero",
+ "num_train_timesteps": 1000,
+ "prediction_type": "epsilon",
+ "rho": 7.0,
+ "sigma_data": 0.5,
+ "sigma_max": 80.0,
+ "sigma_min": 0.002,
+ "sigma_schedule": "karras"
+}
diff --git a/edm2-img512-l-dino/unet/config.json b/edm2-img512-l-dino/unet/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..939ef1d82a8da253ce5dc9b30d3ac644f49dc4ed
--- /dev/null
+++ b/edm2-img512-l-dino/unet/config.json
@@ -0,0 +1,31 @@
+{
+ "_class_name": "EDM2UNet2DModel",
+ "attn_balance": 0.3,
+ "attn_resolutions": [
+ 16,
+ 8
+ ],
+ "channel_mult": [
+ 1,
+ 2,
+ 3,
+ 4
+ ],
+ "channel_mult_emb": 4,
+ "channel_mult_noise": 1,
+ "channels_per_head": 64,
+ "clip_act": 256,
+ "concat_balance": 0.5,
+ "dropout": 0.0,
+ "in_channels": 4,
+ "label_balance": 0.5,
+ "logvar_channels": 128,
+ "model_channels": 320,
+ "num_blocks": 3,
+ "num_class_embeds": 1000,
+ "out_channels": 4,
+ "res_balance": 0.3,
+ "sample_size": 64,
+ "sigma_data": 0.5,
+ "use_fp16": true
+}
diff --git a/edm2-img512-l-dino/unet/diffusion_pytorch_model.safetensors b/edm2-img512-l-dino/unet/diffusion_pytorch_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..09f449c4dab91a9ddfd36f4204f4a3bc472f5208
--- /dev/null
+++ b/edm2-img512-l-dino/unet/diffusion_pytorch_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f13f83377a74d74e1205843e241ce6d6e4bc9e49c2661944e49fdbe4d515ba33
+size 3110018564
diff --git a/edm2-img512-l-dino/unet/unet_edm2.py b/edm2-img512-l-dino/unet/unet_edm2.py
new file mode 100644
index 0000000000000000000000000000000000000000..49d9451b443c4d6ad3b6ba2db19a7a804b07b6de
--- /dev/null
+++ b/edm2-img512-l-dino/unet/unet_edm2.py
@@ -0,0 +1,434 @@
+import math
+import json
+import os
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+
+try:
+    from diffusers.configuration_utils import ConfigMixin, register_to_config
+    from diffusers.models.modeling_utils import ModelMixin
+    from diffusers.utils import BaseOutput
+except ImportError:  # pragma: no cover
+    # Minimal stand-ins so this module stays importable and usable as a plain
+    # PyTorch model when diffusers is not installed.
+    import functools
+    import inspect
+
+    class ModelMixin(torch.nn.Module):
+        """Fallback model base class: a plain ``torch.nn.Module``."""
+
+    class ConfigMixin:
+        """Fallback config holder exposing ``config`` as a plain dict."""
+
+        config = {}
+
+        def register_to_config(self, **kwargs):
+            """Store ``kwargs`` as this instance's config."""
+            self.config = kwargs
+
+    def register_to_config(func):
+        """Decorate ``__init__`` so its arguments are recorded in ``self.config``.
+
+        The previous fallback returned ``func`` unchanged, which left
+        ``self.config`` as the empty class-level dict. Recording the bound
+        arguments (defaults included) gives ``save_pretrained`` a complete
+        config to write.
+        """
+        signature = inspect.signature(func)
+        self_name = next(iter(signature.parameters))
+
+        @functools.wraps(func)
+        def wrapper(self, *args, **kwargs):
+            func(self, *args, **kwargs)
+            bound = signature.bind(self, *args, **kwargs)
+            bound.apply_defaults()
+            values = dict(bound.arguments)
+            values.pop(self_name, None)
+            # Assign per instance so the shared class-level dict is never mutated.
+            self.config = values
+
+        return wrapper
+
+    @dataclass
+    class BaseOutput:
+        """Fallback output base class (an empty dataclass)."""
+
+
+def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor:
+    """Scale ``x`` to unit magnitude over ``dim``.
+
+    Args:
+        x: Tensor to normalize.
+        dim: Dimensions to reduce over; defaults to every dimension except 0.
+        eps: Constant added to the rescaled norm to avoid division by zero.
+
+    Returns:
+        ``x`` divided by ``eps + sqrt(norm.numel() / x.numel()) * norm``, in
+        the dtype of ``x``. The norm itself is computed in float32.
+    """
+    reduce_dims = list(range(1, x.ndim)) if dim is None else dim
+    magnitude = torch.linalg.vector_norm(x, dim=reduce_dims, keepdim=True, dtype=torch.float32)
+    # Rescale so that a unit-variance input yields a denominator close to 1.
+    rescale = math.sqrt(magnitude.numel() / x.numel())
+    denominator = torch.add(eps, magnitude, alpha=rescale)
+    return x / denominator.to(x.dtype)
+
+
+def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor:
+    """Resample ``x`` by a factor of two using a separable filter.
+
+    Args:
+        x: Input tensor; channels are read from ``x.shape[1]``.
+        f: 1-D filter taps; normalized to sum to one and expanded to 2-D.
+        mode: ``"keep"`` returns ``x`` unchanged, ``"down"`` applies a
+            stride-2 depthwise convolution, and any other value applies a
+            stride-2 depthwise transposed convolution.
+
+    Returns:
+        The resampled tensor.
+    """
+    if mode == "keep":
+        return x
+
+    taps = np.float32(f)
+    pad = (len(taps) - 1) // 2
+    taps = taps / taps.sum()
+    kernel_2d = np.outer(taps, taps)[np.newaxis, np.newaxis, :, :]
+    kernel = torch.as_tensor(kernel_2d, dtype=x.dtype, device=x.device)
+    channels = x.shape[1]
+
+    if mode == "down":
+        depthwise = kernel.tile([channels, 1, 1, 1])
+        return torch.nn.functional.conv2d(x, depthwise, groups=channels, stride=2, padding=(pad,))
+
+    # The filter is scaled by 4 for upsampling.
+    depthwise = (kernel * 4).tile([channels, 1, 1, 1])
+    return torch.nn.functional.conv_transpose2d(x, depthwise, groups=channels, stride=2, padding=(pad,))
+
+
+def mp_silu(x: torch.Tensor) -> torch.Tensor:
+    """Return ``silu(x)`` divided by the constant 0.596."""
+    activated = torch.nn.functional.silu(x)
+    return activated / 0.596
+
+
+def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor:
+    """Blend ``a`` and ``b`` with weight ``t`` and renormalize.
+
+    Computes ``lerp(a, b, t) / sqrt((1 - t)^2 + t^2)``.
+    """
+    blended = torch.lerp(a, b, t)
+    scale = math.sqrt((1 - t) ** 2 + t ** 2)
+    return blended / scale
+
+
+def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor:
+    """Concatenate ``a`` and ``b`` along ``dim`` with balance-dependent weights.
+
+    Each input is scaled by a weight derived from its size along ``dim`` and
+    from the blend factor ``t`` (``1 - t`` for ``a``, ``t`` for ``b``).
+    """
+    size_a = a.shape[dim]
+    size_b = b.shape[dim]
+    common = math.sqrt((size_a + size_b) / ((1 - t) ** 2 + t ** 2))
+    weight_a = common / math.sqrt(size_a) * (1 - t)
+    weight_b = common / math.sqrt(size_b) * t
+    return torch.cat([weight_a * a, weight_b * b], dim=dim)
+
+
+class MPFourier(torch.nn.Module):
+    """Random Fourier feature embedding with fixed frequencies and phases."""
+
+    def __init__(self, num_channels: int, bandwidth: float = 1):
+        """Draw ``num_channels`` random frequencies and phases as buffers."""
+        super().__init__()
+        two_pi = 2 * math.pi
+        # Draw order (randn, then rand) is kept so seeded initialization is unchanged.
+        freqs = two_pi * torch.randn(num_channels) * bandwidth
+        phases = two_pi * torch.rand(num_channels)
+        self.register_buffer("freqs", freqs)
+        self.register_buffer("phases", phases)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Return ``sqrt(2) * cos(outer(x, freqs) + phases)`` in the dtype of ``x``.
+
+        ``x`` must be 1-D; the result has one column per channel. The
+        computation runs in float32.
+        """
+        values = x.to(torch.float32)
+        angles = torch.outer(values, self.freqs.to(torch.float32))
+        angles = angles + self.phases.to(torch.float32)
+        features = angles.cos() * math.sqrt(2)
+        return features.to(x.dtype)
+
+
+class MPConv(torch.nn.Module):
+    """Linear or 2-D convolution layer whose weight is normalized on every call."""
+
+    def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]):
+        """Create a random weight of shape ``(out_channels, in_channels, *kernel)``.
+
+        An empty ``kernel`` yields a 2-D weight, i.e. a linear layer.
+        """
+        super().__init__()
+        self.out_channels = out_channels
+        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))
+
+    def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor:
+        """Apply the layer with a normalized weight scaled by ``gain``.
+
+        In training mode the stored parameter is first overwritten in place
+        (without gradient) by its normalized value.
+        """
+        weight = self.weight.to(torch.float32)
+        if self.training:
+            with torch.no_grad():
+                self.weight.copy_(normalize(weight))
+        weight = normalize(weight)
+        # Divide by sqrt(fan-in), the element count of one output filter.
+        fan_in = weight[0].numel()
+        weight = (weight * (gain / math.sqrt(fan_in))).to(x.dtype)
+        if weight.ndim != 2:
+            return torch.nn.functional.conv2d(x, weight, padding=(weight.shape[-1] // 2,))
+        return x @ weight.t()
+
+
+class Block(torch.nn.Module):
+    """Residual U-Net block with embedding modulation and optional self-attention.
+
+    Args:
+        in_channels: Number of input channels.
+        out_channels: Number of output channels.
+        emb_channels: Width of the conditioning embedding.
+        flavor: ``"enc"`` or ``"dec"``; selects where the skip convolution
+            and the pixel normalization are applied.
+        resample_mode: ``"keep"``, ``"down"`` or ``"up"`` (see ``resample``).
+        resample_filter: 1-D filter taps passed to ``resample``.
+        attention: Whether to add a self-attention branch.
+        channels_per_head: Channels per attention head.
+        dropout: Dropout probability, applied in training mode only.
+        res_balance: Blend factor between the input and the residual branch.
+        attn_balance: Blend factor between the input and the attention branch.
+        clip_act: Symmetric clamp applied to the output; ``None`` disables it.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        emb_channels: int,
+        flavor: str = "enc",
+        resample_mode: str = "keep",
+        resample_filter: Tuple[float, ...] = (1, 1),
+        attention: bool = False,
+        channels_per_head: int = 64,
+        dropout: float = 0.0,
+        res_balance: float = 0.3,
+        attn_balance: float = 0.3,
+        clip_act: Optional[float] = 256,
+    ):
+        super().__init__()
+        self.out_channels = out_channels
+        self.flavor = flavor
+        # Immutable default plus a private copy: the taps are never shared
+        # between instances or with the caller's sequence.
+        self.resample_filter = list(resample_filter)
+        self.resample_mode = resample_mode
+        self.num_heads = out_channels // channels_per_head if attention else 0
+        self.dropout = dropout
+        self.res_balance = res_balance
+        self.attn_balance = attn_balance
+        self.clip_act = clip_act
+        self.emb_gain = torch.nn.Parameter(torch.zeros([]))
+        self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3))
+        self.emb_linear = MPConv(emb_channels, out_channels, kernel=())
+        self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3))
+        self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None
+        self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None
+        self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None
+
+    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+        """Run the block on feature map ``x`` conditioned on embedding ``emb``."""
+        x = resample(x, f=self.resample_filter, mode=self.resample_mode)
+        if self.flavor == "enc":
+            # Encoder blocks change width first, then normalize each pixel.
+            if self.conv_skip is not None:
+                x = self.conv_skip(x)
+            x = normalize(x, dim=[1])
+
+        # Residual branch, modulated by the embedding.
+        y = self.conv_res0(mp_silu(x))
+        c = self.emb_linear(emb, gain=self.emb_gain) + 1
+        y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype))
+        if self.training and self.dropout:
+            y = torch.nn.functional.dropout(y, p=self.dropout)
+        y = self.conv_res1(y)
+
+        # Decoder blocks change width on the skip path after the residual branch.
+        if self.flavor == "dec" and self.conv_skip is not None:
+            x = self.conv_skip(x)
+        x = mp_sum(x, y, t=self.res_balance)
+
+        if self.num_heads:
+            # Self-attention over spatial positions with normalized q, k, v.
+            y = self.attn_qkv(x)
+            y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3])
+            q, k, v = normalize(y, dim=[2]).unbind(3)
+            w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3)
+            y = torch.einsum("nhqk,nhck->nhcq", w, v)
+            y = self.attn_proj(y.reshape(*x.shape))
+            x = mp_sum(x, y, t=self.attn_balance)
+
+        if self.clip_act is not None:
+            # In-place clamp is safe here: x is a fresh tensor from mp_sum.
+            x = x.clip_(-self.clip_act, self.clip_act)
+        return x
+
+
+class EDM2UNet(torch.nn.Module):
+    """Magnitude-preserving U-Net backbone (raw network, no preconditioning).
+
+    Encoder and decoder layers live in ``ModuleDict``s keyed by
+    ``"{res}x{res}_{role}"``. Insertion order is the execution order, and the
+    keys appear in state-dict names, so both are kept stable.
+    """
+
+    def __init__(
+        self,
+        img_resolution: int,
+        img_channels: int,
+        label_dim: int,
+        model_channels: int = 192,
+        channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
+        channel_mult_noise: Optional[int] = None,
+        channel_mult_emb: Optional[int] = None,
+        num_blocks: int = 3,
+        attn_resolutions: Tuple[int, ...] = (16, 8),
+        label_balance: float = 0.5,
+        concat_balance: float = 0.5,
+        **block_kwargs,
+    ):
+        super().__init__()
+        widths = [model_channels * mult for mult in channel_mult]
+        if channel_mult_noise is None:
+            noise_width = widths[0]
+        else:
+            noise_width = model_channels * channel_mult_noise
+        if channel_mult_emb is None:
+            emb_width = max(widths)
+        else:
+            emb_width = model_channels * channel_mult_emb
+        self.label_balance = label_balance
+        self.concat_balance = concat_balance
+        self.out_gain = torch.nn.Parameter(torch.zeros([]))
+
+        # Embedding layers.
+        self.emb_fourier = MPFourier(noise_width)
+        self.emb_noise = MPConv(noise_width, emb_width, kernel=())
+        self.emb_label = MPConv(label_dim, emb_width, kernel=()) if label_dim else None
+
+        # Encoder. The input gets one extra constant channel (see forward).
+        self.enc = torch.nn.ModuleDict()
+        width = img_channels + 1
+        for level, level_width in enumerate(widths):
+            res = img_resolution >> level
+            prefix = f"{res}x{res}"
+            if level == 0:
+                self.enc[f"{prefix}_conv"] = MPConv(width, level_width, kernel=(3, 3))
+                width = level_width
+            else:
+                self.enc[f"{prefix}_down"] = Block(
+                    width, width, emb_width, flavor="enc", resample_mode="down", **block_kwargs
+                )
+            for idx in range(num_blocks):
+                self.enc[f"{prefix}_block{idx}"] = Block(
+                    width,
+                    level_width,
+                    emb_width,
+                    flavor="enc",
+                    attention=(res in attn_resolutions),
+                    **block_kwargs,
+                )
+                width = level_width
+
+        # Decoder. Every "block" layer consumes one encoder skip connection.
+        self.dec = torch.nn.ModuleDict()
+        skip_widths = [layer.out_channels for layer in self.enc.values()]
+        deepest = len(widths) - 1
+        for level in range(deepest, -1, -1):
+            level_width = widths[level]
+            res = img_resolution >> level
+            prefix = f"{res}x{res}"
+            if level == deepest:
+                self.dec[f"{prefix}_in0"] = Block(
+                    width, width, emb_width, flavor="dec", attention=True, **block_kwargs
+                )
+                self.dec[f"{prefix}_in1"] = Block(width, width, emb_width, flavor="dec", **block_kwargs)
+            else:
+                self.dec[f"{prefix}_up"] = Block(
+                    width, width, emb_width, flavor="dec", resample_mode="up", **block_kwargs
+                )
+            for idx in range(num_blocks + 1):
+                merged_width = width + skip_widths.pop()
+                self.dec[f"{prefix}_block{idx}"] = Block(
+                    merged_width,
+                    level_width,
+                    emb_width,
+                    flavor="dec",
+                    attention=(res in attn_resolutions),
+                    **block_kwargs,
+                )
+                width = level_width
+
+        self.out_conv = MPConv(width, img_channels, kernel=(3, 3))
+
+    def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor:
+        """Run the network.
+
+        Args:
+            x: Input feature map.
+            noise_labels: 1-D noise-level conditioning values.
+            class_labels: 2-D class conditioning; required when the network
+                was built with a non-zero ``label_dim``.
+
+        Raises:
+            ValueError: If the network is conditional and ``class_labels`` is ``None``.
+        """
+        emb = self.emb_noise(self.emb_fourier(noise_labels))
+        if self.emb_label is not None:
+            if class_labels is None:
+                raise ValueError("class_labels are required for conditional EDM2UNet.")
+            scaled_labels = class_labels * math.sqrt(class_labels.shape[1])
+            emb = mp_sum(emb, self.emb_label(scaled_labels), t=self.label_balance)
+        emb = mp_silu(emb)
+
+        # Append a constant channel of ones.
+        ones = torch.ones_like(x[:, :1])
+        x = torch.cat([x, ones], dim=1)
+
+        skips = []
+        for name, layer in self.enc.items():
+            if "conv" in name:
+                x = layer(x)
+            else:
+                x = layer(x, emb)
+            skips.append(x)
+
+        for name, layer in self.dec.items():
+            if "block" in name:
+                x = mp_cat(x, skips.pop(), t=self.concat_balance)
+            x = layer(x, emb)
+        return self.out_conv(x, gain=self.out_gain)
+
+
+# Return type of EDM2UNet2DModel.forward when return_dict=True.
+@dataclass
+class EDM2UNet2DOutput(BaseOutput):
+ # Denoised sample produced by the preconditioned network.
+ sample: torch.Tensor
+ # Output of the log-variance head; None unless return_logvar=True was requested.
+ logvar: Optional[torch.Tensor] = None
+
+
+
+# Constructor arguments of EDM2UNet2DModel. from_pretrained passes only these
+# keys from config.json to __init__, and save_pretrained writes only these keys
+# (plus "_class_name") back out.
+_CONFIG_KEYS = (
+ "sample_size",
+ "in_channels",
+ "out_channels",
+ "num_class_embeds",
+ "use_fp16",
+ "sigma_data",
+ "logvar_channels",
+ "model_channels",
+ "channel_mult",
+ "channel_mult_noise",
+ "channel_mult_emb",
+ "num_blocks",
+ "attn_resolutions",
+ "label_balance",
+ "concat_balance",
+ "dropout",
+ "channels_per_head",
+ "res_balance",
+ "attn_balance",
+ "clip_act",
+)
+
+
+class EDM2UNet2DModel(ModelMixin, ConfigMixin):
+    """EDM2 U-Net with EDM preconditioning and a log-variance head.
+
+    Wraps ``EDM2UNet`` so that ``forward`` maps a noisy sample and its noise
+    level ``sigma`` to a denoised sample. ``out_channels`` is recorded in the
+    config, but the network output width is taken from ``in_channels``.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        sample_size: int = 64,
+        in_channels: int = 4,
+        out_channels: int = 4,
+        num_class_embeds: int = 0,
+        use_fp16: bool = True,
+        sigma_data: float = 0.5,
+        logvar_channels: int = 128,
+        model_channels: int = 192,
+        channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
+        channel_mult_noise: Optional[int] = None,
+        channel_mult_emb: Optional[int] = None,
+        num_blocks: int = 3,
+        attn_resolutions: Tuple[int, ...] = (16, 8),
+        label_balance: float = 0.5,
+        concat_balance: float = 0.5,
+        dropout: float = 0.0,
+        channels_per_head: int = 64,
+        res_balance: float = 0.3,
+        attn_balance: float = 0.3,
+        clip_act: Optional[float] = 256,
+    ):
+        super().__init__()
+        # Mirror every constructor argument as a plain attribute; save_pretrained
+        # falls back to these when a key is missing from self.config.
+        self.sample_size = sample_size
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.num_class_embeds = num_class_embeds
+        self.use_fp16 = use_fp16
+        self.sigma_data = sigma_data
+        # Previously omitted, so the attribute fallback could not save it.
+        self.logvar_channels = logvar_channels
+        self.model_channels = model_channels
+        self.channel_mult = channel_mult
+        self.channel_mult_noise = channel_mult_noise
+        self.channel_mult_emb = channel_mult_emb
+        self.num_blocks = num_blocks
+        self.attn_resolutions = attn_resolutions
+        self.label_balance = label_balance
+        self.concat_balance = concat_balance
+        self.dropout = dropout
+        self.channels_per_head = channels_per_head
+        self.res_balance = res_balance
+        self.attn_balance = attn_balance
+        self.clip_act = clip_act
+        self.unet = EDM2UNet(
+            img_resolution=sample_size,
+            img_channels=in_channels,
+            label_dim=num_class_embeds,
+            model_channels=model_channels,
+            channel_mult=channel_mult,
+            channel_mult_noise=channel_mult_noise,
+            channel_mult_emb=channel_mult_emb,
+            num_blocks=num_blocks,
+            attn_resolutions=attn_resolutions,
+            label_balance=label_balance,
+            concat_balance=concat_balance,
+            dropout=dropout,
+            channels_per_head=channels_per_head,
+            res_balance=res_balance,
+            attn_balance=attn_balance,
+            clip_act=clip_act,
+        )
+        self.logvar_fourier = MPFourier(logvar_channels)
+        self.logvar_linear = MPConv(logvar_channels, 1, kernel=())
+
+    def forward(
+        self,
+        sample: torch.Tensor,
+        sigma: torch.Tensor,
+        class_labels: Optional[torch.Tensor] = None,
+        force_fp32: bool = False,
+        return_logvar: bool = False,
+        return_dict: bool = True,
+    ) -> EDM2UNet2DOutput:
+        """Denoise ``sample`` at noise level ``sigma``.
+
+        Args:
+            sample: Noisy input; cast to float32 internally.
+            sigma: Noise level as a tensor (scalar or one value per sample)
+                or a Python number.
+            class_labels: Class conditioning, reshaped to
+                ``(-1, num_class_embeds)``. Ignored for unconditional models;
+                replaced by zeros when omitted on a conditional model.
+            force_fp32: Run the network in float32 even if ``use_fp16`` is set.
+            return_logvar: Also evaluate the log-variance head.
+            return_dict: Return an ``EDM2UNet2DOutput`` rather than a tuple.
+
+        Returns:
+            ``EDM2UNet2DOutput(sample, logvar)``, or the tuple
+            ``(sample, logvar)`` when ``return_dict`` is ``False``.
+        """
+        x = sample.to(torch.float32)
+        # as_tensor also accepts Python floats and moves sigma onto x's device.
+        sigma = torch.as_tensor(sigma, dtype=torch.float32, device=x.device).reshape(-1, 1, 1, 1)
+        if self.num_class_embeds == 0:
+            class_labels = None
+        else:
+            if class_labels is None:
+                class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device)
+            class_labels = class_labels.to(device=x.device, dtype=torch.float32)
+            class_labels = class_labels.reshape(-1, self.num_class_embeds)
+        # Half precision is only used on CUDA.
+        use_half = self.use_fp16 and not force_fp32 and x.device.type == "cuda"
+        dtype = torch.float16 if use_half else torch.float32
+
+        # EDM preconditioning coefficients.
+        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
+        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
+        c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
+        c_noise = sigma.flatten().log() / 4
+
+        x_in = (c_in * x).to(dtype)
+        f_x = self.unet(x_in, c_noise, class_labels)
+        d_x = c_skip * x + c_out * f_x.to(torch.float32)
+
+        logvar = None
+        if return_logvar:
+            logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)
+
+        if not return_dict:
+            return (d_x, logvar)
+        return EDM2UNet2DOutput(sample=d_x, logvar=logvar)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs):
+        """Load a model from a local directory.
+
+        Reads ``config.json`` and then ``diffusion_pytorch_model.safetensors``,
+        falling back to ``diffusion_pytorch_model.bin``. Only the ``subfolder``
+        keyword is honoured; all other ``kwargs`` are ignored.
+
+        Args:
+            pretrained_model_name_or_path: Local directory of the checkpoint.
+            torch_dtype: Optional dtype to cast the loaded model to.
+        """
+        subfolder = kwargs.pop("subfolder", None)
+        model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path
+        with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f:
+            config = json.load(f)
+        init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS}
+        model = cls(**init_kwargs)
+        weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors")
+        if os.path.isfile(weight_file):
+            from safetensors.torch import load_file
+
+            state_dict = load_file(weight_file)
+        else:
+            # NOTE(security): torch.load unpickles the file; only use the .bin
+            # fallback with checkpoints from a trusted source.
+            state_dict = torch.load(os.path.join(model_dir, "diffusion_pytorch_model.bin"), map_location="cpu")
+        model.load_state_dict(state_dict, strict=True)
+        if torch_dtype is not None:
+            model = model.to(dtype=torch_dtype)
+        return model
+
+    def save_pretrained(self, save_directory: str, safe_serialization: bool = True):
+        """Write ``config.json`` and the weights to ``save_directory``.
+
+        Args:
+            save_directory: Target directory; created if missing.
+            safe_serialization: Write safetensors if true, else a ``.bin``
+                file via ``torch.save``.
+        """
+        os.makedirs(save_directory, exist_ok=True)
+        stored = dict(getattr(self, "config", {}))
+        config = {"_class_name": self.__class__.__name__}
+        for key in _CONFIG_KEYS:
+            # Prefer the registered config; fall back to the mirrored attribute.
+            if key in stored:
+                config[key] = stored[key]
+            elif hasattr(self, key):
+                config[key] = getattr(self, key)
+        with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f:
+            json.dump(config, f, indent=2, sort_keys=True)
+            f.write("\n")
+        state_dict = self.state_dict()
+        if safe_serialization:
+            from safetensors.torch import save_file
+
+            save_file(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.safetensors"))
+        else:
+            torch.save(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.bin"))
diff --git a/edm2-img512-l-dino/vae/config.json b/edm2-img512-l-dino/vae/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..ae2ab37de2acb8bc348d3de809ae8385324be962
--- /dev/null
+++ b/edm2-img512-l-dino/vae/config.json
@@ -0,0 +1,38 @@
+{
+ "_class_name": "AutoencoderKL",
+ "_diffusers_version": "0.36.0",
+ "_name_or_path": "stabilityai/sd-vae-ft-mse",
+ "act_fn": "silu",
+ "block_out_channels": [
+ 128,
+ 256,
+ 512,
+ 512
+ ],
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D"
+ ],
+ "force_upcast": true,
+ "in_channels": 3,
+ "latent_channels": 4,
+ "latents_mean": null,
+ "latents_std": null,
+ "layers_per_block": 2,
+ "mid_block_add_attention": true,
+ "norm_num_groups": 32,
+ "out_channels": 3,
+ "sample_size": 256,
+ "scaling_factor": 0.18215,
+ "shift_factor": null,
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D"
+ ],
+ "use_post_quant_conv": true,
+ "use_quant_conv": true
+}
diff --git a/edm2-img512-l-dino/vae/diffusion_pytorch_model.safetensors b/edm2-img512-l-dino/vae/diffusion_pytorch_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..90464d67ac7303d0ee4696334df13da130a948ea
--- /dev/null
+++ b/edm2-img512-l-dino/vae/diffusion_pytorch_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815
+size 334643276
diff --git a/edm2-img512-l-fid/generator_test.png b/edm2-img512-l-fid/generator_test.png
new file mode 100644
index 0000000000000000000000000000000000000000..c7201a46a9b325e4b7129290fa7a3f13549cc7d4
--- /dev/null
+++ b/edm2-img512-l-fid/generator_test.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cabf3ca8019e86c4a85855d5c3fd2c6de6d25ac51682da208d20db23533e6578
+size 378707
diff --git a/edm2-img512-l-fid/model_index.json b/edm2-img512-l-fid/model_index.json
new file mode 100644
index 0000000000000000000000000000000000000000..24c0305eea3ec64abbc2f870476afa4f5e462f42
--- /dev/null
+++ b/edm2-img512-l-fid/model_index.json
@@ -0,0 +1,19 @@
+{
+ "_class_name": [
+ "pipeline",
+ "EDM2Pipeline"
+ ],
+ "_diffusers_version": "0.31.0",
+ "scheduler": [
+ "diffusers",
+ "EDMEulerScheduler"
+ ],
+ "unet": [
+ "unet_edm2",
+ "EDM2UNet2DModel"
+ ],
+ "vae": [
+ "diffusers",
+ "AutoencoderKL"
+ ]
+}
diff --git a/edm2-img512-l-fid/pipeline.py b/edm2-img512-l-fid/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b4131215d07770c74efa0674d00f3d6423d9c06
--- /dev/null
+++ b/edm2-img512-l-fid/pipeline.py
@@ -0,0 +1,406 @@
+"""Hub custom pipeline: EDM2Pipeline.
+Load with native Hugging Face diffusers and trust_remote_code=True.
+"""
+
+from __future__ import annotations
+
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from pathlib import Path
+from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from diffusers.utils import replace_example_docstring
+from diffusers.utils.torch_utils import randn_tensor
+
+# Usage example passed to replace_example_docstring on EDM2Pipeline.__call__.
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from pathlib import Path
+ >>> import torch
+ >>> from diffusers import DiffusionPipeline
+
+ >>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve()
+ >>> pipe = DiffusionPipeline.from_pretrained(
+ ... str(model_dir),
+ ... local_files_only=True,
+ ... custom_pipeline=str(model_dir / "pipeline.py"),
+ ... trust_remote_code=True,
+ ... torch_dtype=torch.float32,
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> generator = torch.Generator(device="cuda").manual_seed(42)
+ >>> image = pipe(
+ ... class_labels=207,
+ ... num_inference_steps=32,
+ ... guidance_scale=1.0,
+ ... generator=generator,
+ ... ).images[0]
+ >>> image.save("demo.png")
+ ```
+"""
+
+# Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py).
+# Whitened latents are `raw * scale + bias`; decode_latents inverts this with
+# `(x - bias) / scale` before calling the VAE decoder.
+# NOTE(review): the two 4-element lists are presumably per-channel raw latent
+# std and mean, with a target std of 0.5 and mean of 0 - confirm against NVlabs/edm2.
+_STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28])
+_STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE
+
+class EDM2Pipeline(DiffusionPipeline):
+    r"""
+    Pipeline for class-conditional image generation with EDM2
+    ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)).
+
+    Parameters:
+        unet ([`EDM2UNet2DModel`]):
+            Main magnitude-preserving U-Net with EDM preconditioning.
+        scheduler ([`EDMEulerScheduler`]):
+            Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in
+            the pipeline because the UNet returns denoised latents rather than noise predictions.
+        vae ([`AutoencoderKL`], *optional*):
+            Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`.
+        gnet ([`EDM2UNet2DModel`], *optional*):
+            Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`).
+        id2label (`dict[int, str]`, *optional*):
+            ImageNet class id to English label mapping.
+    """
+
+    model_cpu_offload_seq = "unet->gnet->vae"
+    _optional_components = ["vae", "gnet"]
+
+    def __init__(
+        self,
+        unet,
+        scheduler,
+        vae=None,
+        gnet=None,
+        id2label: Optional[Dict[Union[int, str], str]] = None,
+    ) -> None:
+        super().__init__()
+        self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet)
+        self._id2label = self._normalize_id2label(id2label)
+        self.labels = self._build_label2id(self._id2label)
+        self._labels_loaded_from_model_index = bool(self._id2label)
+        # Without a VAE the U-Net works directly in pixel space.
+        self.vae_scale_factor = 8 if self.vae is not None else 1
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
+
+    @staticmethod
+    def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
+        """Return ``id2label`` with integer keys (JSON keys arrive as strings)."""
+        if not id2label:
+            return {}
+        return {int(key): value for key, value in id2label.items()}
+
+    @staticmethod
+    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
+        """Map each comma-separated synonym in ``id2label`` to its class id, sorted by name."""
+        label2id: Dict[str, int] = {}
+        for class_id, value in id2label.items():
+            for synonym in value.split(","):
+                synonym = synonym.strip()
+                if synonym:
+                    label2id[synonym] = int(class_id)
+        return dict(sorted(label2id.items()))
+
+    def _ensure_labels_loaded(self) -> None:
+        """Lazily load ``id2label`` from ``model_index.json`` if none was given."""
+        if self._labels_loaded_from_model_index:
+            return
+        loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
+        if loaded:
+            self._id2label = loaded
+            self.labels = self._build_label2id(self._id2label)
+            self._labels_loaded_from_model_index = True
+
+    @staticmethod
+    def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
+        """Read the ``id2label`` mapping from ``<variant_path>/model_index.json``.
+
+        Returns an empty dict when the path, the file, or the key is missing.
+        """
+        if not variant_path:
+            return {}
+        model_index_path = Path(variant_path).resolve() / "model_index.json"
+        if not model_index_path.is_file():
+            return {}
+        raw = json.loads(model_index_path.read_text(encoding="utf-8"))
+        id2label = raw.get("id2label")
+        if not isinstance(id2label, dict):
+            return {}
+        return {int(key): value for key, value in id2label.items()}
+
+    @property
+    def id2label(self) -> Dict[int, str]:
+        r"""ImageNet class id to English label string (comma-separated synonyms)."""
+        self._ensure_labels_loaded()
+        return self._id2label
+
+    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
+        r"""
+        Map ImageNet label strings to class ids.
+
+        Args:
+            label (`str` or `list[str]`):
+                One or more English label strings that match entries in `id2label`.
+
+        Raises:
+            ValueError: If no labels are available or a label is unknown.
+        """
+        self._ensure_labels_loaded()
+        if not self.labels:
+            raise ValueError("No English labels loaded. Add `id2label` to model_index.json.")
+        labels = [label] if isinstance(label, str) else list(label)
+        missing = [item for item in labels if item not in self.labels]
+        if missing:
+            preview = ", ".join(list(self.labels.keys())[:8])
+            raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
+        return [self.labels[item] for item in labels]
+
+    def _default_image_size(self) -> int:
+        """Return the native output size: U-Net sample size times the VAE scale factor."""
+        latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64)))
+        return latent_size * self.vae_scale_factor
+
+    def check_inputs(
+        self,
+        height: int,
+        width: int,
+        num_inference_steps: int,
+        guidance_scale: float,
+        output_type: str,
+    ) -> None:
+        """Validate call arguments; raise ``ValueError`` on any violation."""
+        if num_inference_steps < 1:
+            raise ValueError("num_inference_steps must be >= 1.")
+        if guidance_scale < 1.0:
+            raise ValueError("guidance_scale must be >= 1.0.")
+        if guidance_scale > 1.0 and self.gnet is None:
+            raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).")
+        if output_type not in {"pil", "np", "pt", "latent"}:
+            raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
+
+        native_size = self._default_image_size()
+        if height != native_size or width != native_size:
+            raise ValueError(
+                f"EDM2 expects native resolution height=width={native_size}. "
+                f"Got height={height}, width={width}."
+            )
+
+    @staticmethod
+    def _sample_class_indices(
+        label_dim: int,
+        batch_size: int,
+        device: torch.device,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]],
+    ) -> torch.Tensor:
+        """Draw ``batch_size`` random class ids, honouring ``generator`` if given."""
+        if generator is None:
+            return torch.randint(label_dim, size=(batch_size,), device=device)
+        generators = list(generator) if isinstance(generator, (list, tuple)) else [generator]
+        if len(generators) == 1:
+            gen = generators[0]
+            # Sample on the generator's own device, then move to the target device.
+            return torch.randint(label_dim, size=(batch_size,), generator=gen, device=gen.device).to(device)
+        if len(generators) != batch_size:
+            raise ValueError(
+                f"Got {len(generators)} generators for batch_size={batch_size}; the counts must match."
+            )
+        draws = [torch.randint(label_dim, size=(1,), generator=gen, device=gen.device) for gen in generators]
+        return torch.cat([draw.to(device) for draw in draws])
+
+    def _normalize_class_labels(
+        self,
+        class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]],
+        batch_size: int,
+        device: torch.device,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+    ) -> Optional[torch.Tensor]:
+        """Convert user class labels to a float32 one-hot matrix.
+
+        Args:
+            class_labels: Class ids, label strings, a 1-D id tensor, or a
+                2-D one-hot tensor. ``None`` samples random classes.
+            batch_size: Number of rows to produce.
+            device: Device of the returned tensor.
+            generator: Optional RNG used when classes are sampled randomly.
+
+        Returns:
+            A ``(batch_size, num_classes)`` tensor, or ``None`` for
+            unconditional models.
+
+        Raises:
+            ValueError: On batch-size mismatch or class ids out of range.
+        """
+        label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0)))
+        if label_dim == 0:
+            return None
+        if class_labels is None:
+            indices = self._sample_class_indices(label_dim, batch_size, device, generator)
+            return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
+
+        # Resolve label strings to integer ids first.
+        if isinstance(class_labels, str):
+            class_labels = self.get_label_ids(class_labels)[0]
+        elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str):
+            class_labels = self.get_label_ids(list(class_labels))
+
+        if isinstance(class_labels, int):
+            indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long)
+        elif isinstance(class_labels, torch.Tensor):
+            if class_labels.ndim == 2:
+                # Already one-hot (or soft) labels: pass through unchanged.
+                labels = class_labels.to(device=device, dtype=torch.float32)
+                if labels.shape[0] != batch_size:
+                    raise ValueError(f"class_labels batch must match batch_size={batch_size}.")
+                return labels
+            indices = class_labels.to(device=device, dtype=torch.long).flatten()
+        else:
+            indices = torch.tensor(list(class_labels), device=device, dtype=torch.long)
+
+        if indices.numel() == 1 and batch_size > 1:
+            indices = indices.repeat(batch_size)
+        if indices.numel() != batch_size:
+            raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.")
+        # Fail with a clear message instead of an opaque indexing error.
+        # Negative ids that tensor indexing would accept stay allowed.
+        out_of_range = (indices < -label_dim) | (indices >= label_dim)
+        if bool(out_of_range.any()):
+            raise ValueError(f"class_labels contain ids outside the valid range for {label_dim} classes.")
+        return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
+
+    def prepare_latents(
+        self,
+        batch_size: int,
+        height: int,
+        width: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]],
+    ) -> torch.Tensor:
+        """Draw unit-variance float32 noise of the latent shape.
+
+        The latent is square and derived from ``height`` only; ``width`` and
+        ``dtype`` are accepted for interface symmetry but not used.
+        """
+        in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4)))
+        latent_size = height // self.vae_scale_factor
+        return randn_tensor(
+            (batch_size, in_channels, latent_size, latent_size),
+            generator=generator,
+            device=device,
+            dtype=torch.float32,
+        )
+
+    def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
+        """Convert sampled latents to images in the requested ``output_type``.
+
+        ``"latent"`` returns the input unchanged. Without a VAE the latents
+        are treated as pixels. With one, 4-channel latents are un-whitened
+        before decoding.
+        """
+        if output_type == "latent":
+            return latents
+
+        x = latents.to(torch.float32)
+        if self.vae is None:
+            image = (x * 127.5 + 128).clip(0, 255) / 255.0
+            return self.image_processor.postprocess(image, output_type=output_type)
+
+        in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 3)))
+        if in_channels == 4:
+            # Undo the latent whitening applied during training.
+            scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
+            bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
+            x = (x - bias) / scale
+
+        vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype
+        image = self.vae.decode(x.to(dtype=vae_dtype)).sample.to(torch.float32).clamp(0, 1)
+
+        return self.image_processor.postprocess(image, output_type=output_type)
+
+    @staticmethod
+    def _apply_autoguidance(
+        main: torch.Tensor,
+        ref: torch.Tensor,
+        guidance_scale: float,
+    ) -> torch.Tensor:
+        """Extrapolate from the guiding output ``ref`` towards ``main``."""
+        return ref.lerp(main, guidance_scale)
+
+    @staticmethod
+    def _sample_edm2_heun(
+        denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+        noise: torch.Tensor,
+        sigmas: torch.Tensor,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        progress_bar: Optional[Callable[[Iterable], Iterable]] = None,
+        dtype: torch.dtype = torch.float32,
+    ) -> torch.Tensor:
+        """NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0).
+
+        Args:
+            denoise_fn: Maps ``(x, sigma)`` to the denoised estimate of ``x``.
+            noise: Unit-variance starting noise; scaled by ``sigmas[0]``.
+            sigmas: Noise levels, consumed as consecutive pairs.
+            generator: Unused; the sampler is deterministic.
+            progress_bar: Optional wrapper applied to the step iterable.
+            dtype: Dtype the noise is cast to before sampling.
+
+        Returns:
+            The sample after the final step.
+        """
+        x_next = noise.to(dtype) * sigmas[0]
+
+        sigma_pairs = list(zip(sigmas[:-1], sigmas[1:]))
+        num_steps = len(sigma_pairs)
+        if progress_bar is not None:
+            sigma_pairs = progress_bar(sigma_pairs)
+
+        for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs):
+            x_hat, sigma_hat = x_next, sigma_cur
+            # Euler step.
+            d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat
+            x_next = x_hat + (sigma_next - sigma_hat) * d_cur
+            # Second-order correction on every step but the last, which
+            # avoids dividing by the final sigma.
+            if i < num_steps - 1:
+                d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next
+                x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime)
+        return x_next
+
+    @torch.inference_mode()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None,
+        batch_size: int = 1,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 32,
+        guidance_scale: float = 1.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        output_type: str = "pil",
+        return_dict: bool = True,
+    ) -> Union[ImagePipelineOutput, Tuple]:
+        r"""
+        Generate class-conditional images with EDM2.
+
+        Args:
+            class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*):
+                ImageNet class indices, English label strings, or one-hot float tensors.
+                Random classes are sampled when omitted on conditional models.
+            batch_size (`int`, defaults to `1`):
+                Number of images to generate.
+            height (`int`, *optional*):
+                Output height in pixels. Defaults to the pretrained native resolution.
+            width (`int`, *optional*):
+                Output width in pixels. Defaults to the pretrained native resolution.
+            num_inference_steps (`int`, defaults to `32`):
+                Number of EDM2 Heun steps (NVlabs default).
+            guidance_scale (`float`, defaults to `1.0`):
+                Autoguidance strength. Values above `1.0` blend the main net with `gnet`
+                via `gnet_output.lerp(unet_output, guidance_scale)`.
+            generator (`torch.Generator`, *optional*):
+                RNG for reproducibility. Also used to sample random classes.
+            output_type (`str`, defaults to `"pil"`):
+                `"pil"`, `"np"`, `"pt"`, or `"latent"`.
+            return_dict (`bool`, defaults to `True`):
+                Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if True,
+                otherwise the tuple `(images, latents)`.
+
+        Examples:
+
+        """
+        default_size = self._default_image_size()
+        height = int(height or default_size)
+        width = int(width or default_size)
+        self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type)
+
+        device = self._execution_device
+        dtype = self.unet.dtype
+        # Draw the noise before the labels so a seeded generator yields the
+        # same noise as before labels started consuming it.
+        noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
+        labels = self._normalize_class_labels(
+            class_labels, batch_size=batch_size, device=device, generator=generator
+        )
+
+        def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+            sigma_batch = sigma.reshape(1).expand(batch_size)
+            main = self.unet(
+                sample=x,
+                sigma=sigma_batch,
+                class_labels=labels,
+                force_fp32=True,
+            ).sample
+            if guidance_scale == 1.0 or self.gnet is None:
+                return main.to(torch.float32)
+            ref = self.gnet(
+                sample=x,
+                sigma=sigma_batch,
+                class_labels=labels,
+                force_fp32=True,
+            ).sample
+            return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32)
+
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        latents = self._sample_edm2_heun(
+            denoise_fn=denoise_fn,
+            noise=noise,
+            sigmas=self.scheduler.sigmas.to(device),
+            generator=generator,
+            progress_bar=self.progress_bar,
+            dtype=torch.float32,
+        )
+
+        image = self.decode_latents(latents, output_type=output_type)
+        if not return_dict:
+            return (image, latents)
+        return ImagePipelineOutput(images=image)
+
+    @classmethod
+    def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None):
+        """Best-effort VAE loader for a local checkpoint directory.
+
+        Tries ``<path>/vae`` first, then a Hub id stored in
+        ``<path>/vae_pretrained_model_name_or_path.txt``.
+
+        Returns:
+            The loaded ``AutoencoderKL``, or ``None`` if neither source
+            exists or the local ``vae`` folder cannot be loaded.
+        """
+        import warnings
+
+        # Imported here because the module never imports AutoencoderKL; the
+        # bare name previously raised NameError on every call.
+        from diffusers import AutoencoderKL
+
+        vae_dir = os.path.join(pretrained_model_name_or_path, "vae")
+        if os.path.isdir(vae_dir):
+            try:
+                return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype)
+            except Exception as err:
+                # Stay best-effort, but report the failure instead of hiding it.
+                warnings.warn(f"Could not load VAE from {vae_dir!r}: {err}")
+                return None
+
+        vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt")
+        if os.path.isfile(vae_hint):
+            with open(vae_hint, "r", encoding="utf-8") as f:
+                hub_id = f.read().strip()
+            if hub_id:
+                return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype)
+        return None
diff --git a/edm2-img512-l-fid/scheduler/scheduler_config.json b/edm2-img512-l-fid/scheduler/scheduler_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..08cbde97955d06d66b0b72be217db84bb797d711
--- /dev/null
+++ b/edm2-img512-l-fid/scheduler/scheduler_config.json
@@ -0,0 +1,11 @@
+{
+ "_class_name": "EDMEulerScheduler",
+ "final_sigmas_type": "zero",
+ "num_train_timesteps": 1000,
+ "prediction_type": "epsilon",
+ "rho": 7.0,
+ "sigma_data": 0.5,
+ "sigma_max": 80.0,
+ "sigma_min": 0.002,
+ "sigma_schedule": "karras"
+}
diff --git a/edm2-img512-l-fid/unet/config.json b/edm2-img512-l-fid/unet/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..939ef1d82a8da253ce5dc9b30d3ac644f49dc4ed
--- /dev/null
+++ b/edm2-img512-l-fid/unet/config.json
@@ -0,0 +1,31 @@
+{
+ "_class_name": "EDM2UNet2DModel",
+ "attn_balance": 0.3,
+ "attn_resolutions": [
+ 16,
+ 8
+ ],
+ "channel_mult": [
+ 1,
+ 2,
+ 3,
+ 4
+ ],
+ "channel_mult_emb": 4,
+ "channel_mult_noise": 1,
+ "channels_per_head": 64,
+ "clip_act": 256,
+ "concat_balance": 0.5,
+ "dropout": 0.0,
+ "in_channels": 4,
+ "label_balance": 0.5,
+ "logvar_channels": 128,
+ "model_channels": 320,
+ "num_blocks": 3,
+ "num_class_embeds": 1000,
+ "out_channels": 4,
+ "res_balance": 0.3,
+ "sample_size": 64,
+ "sigma_data": 0.5,
+ "use_fp16": true
+}
diff --git a/edm2-img512-l-fid/unet/diffusion_pytorch_model.safetensors b/edm2-img512-l-fid/unet/diffusion_pytorch_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..3b76a348ae7a2f1b1d650d963890cf0e6e98ad5e
--- /dev/null
+++ b/edm2-img512-l-fid/unet/diffusion_pytorch_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a3e3f5127c12027e4796bef297e247a38ddd13bb7b8445c5d41169106b94389
+size 3110018564
diff --git a/edm2-img512-l-fid/unet/unet_edm2.py b/edm2-img512-l-fid/unet/unet_edm2.py
new file mode 100644
index 0000000000000000000000000000000000000000..49d9451b443c4d6ad3b6ba2db19a7a804b07b6de
--- /dev/null
+++ b/edm2-img512-l-fid/unet/unet_edm2.py
@@ -0,0 +1,434 @@
+import math
+import json
+import os
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+
+# Prefer the real diffusers base classes. When diffusers cannot be imported,
+# define minimal stand-ins with the same names so the module still imports.
+try:
+    from diffusers.configuration_utils import ConfigMixin, register_to_config
+    from diffusers.models.modeling_utils import ModelMixin
+    from diffusers.utils import BaseOutput
+except ImportError:  # pragma: no cover
+    # Stand-in model base: a plain torch module.
+    class ModelMixin(torch.nn.Module):
+        pass
+
+    # Stand-in config base: ``config`` is an empty class-level dict until
+    # ``register_to_config`` is called on an instance.
+    class ConfigMixin:
+        config = {}
+
+        def register_to_config(self, **kwargs):
+            self.config = kwargs
+
+    # Stand-in decorator: returns the wrapped function unchanged, so nothing is
+    # recorded in ``config`` automatically.
+    def register_to_config(func):
+        return func
+
+    # Stand-in output base: an empty dataclass.
+    @dataclass
+    class BaseOutput:
+        pass
+
+
+def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor:
+    """Rescale ``x`` to roughly unit root-mean-square magnitude along ``dim``.
+
+    Args:
+        x: Tensor to normalize.
+        dim: Dimensions reduced when computing the norm. Defaults to every
+            dimension except the first.
+        eps: Constant added to the scaled norm before dividing.
+
+    Returns:
+        ``x`` divided by ``eps + scaled_norm``, in the dtype of ``x``.
+    """
+    if dim is None:
+        dim = list(range(1, x.ndim))
+    # The norm is accumulated in float32 regardless of the dtype of x.
+    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
+    # sqrt(norm.numel() / x.numel()) equals 1 / sqrt(elements reduced per norm),
+    # which turns the L2 norm into an RMS value. torch.add(eps, norm, alpha=a)
+    # computes eps + a * norm.
+    norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
+    return x / norm.to(x.dtype)
+
+
+def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor:
+    """Resample ``x`` by a factor of two with a separable filter.
+
+    Args:
+        x: Input passed to ``conv2d`` / ``conv_transpose2d``; the channel count
+            is read from ``x.shape[1]``.
+        f: 1-D filter taps. They are normalized to sum to one and expanded to a
+            2-D kernel with an outer product.
+        mode: ``"keep"`` returns ``x`` unchanged, ``"down"`` applies a stride-2
+            depthwise convolution, and any other value applies a stride-2
+            depthwise transposed convolution (callers in this file pass ``"up"``).
+
+    Returns:
+        The resampled tensor, or ``x`` itself for ``mode="keep"``.
+    """
+    if mode == "keep":
+        return x
+    filt = np.float32(f)
+    pad = (len(filt) - 1) // 2
+    filt = filt / filt.sum()
+    # Outer product gives the 2-D kernel; two leading axes make it (1, 1, k, k).
+    filt = np.outer(filt, filt)[np.newaxis, np.newaxis, :, :]
+    filt = torch.as_tensor(filt, dtype=x.dtype, device=x.device)
+    c = x.shape[1]
+    # groups=c with the kernel tiled c times filters every channel independently.
+    if mode == "down":
+        return torch.nn.functional.conv2d(x, filt.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
+    # The kernel is scaled by 4 on the transposed (upsampling) path.
+    return torch.nn.functional.conv_transpose2d(x, (filt * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
+
+
+def mp_silu(x: torch.Tensor) -> torch.Tensor:
+    """Return ``silu(x)`` divided by the fixed constant ``0.596``."""
+    activated = torch.nn.functional.silu(x)
+    return activated / 0.596
+
+
+def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor:
+    """Blend ``a`` and ``b`` with weight ``t`` and rescale the result.
+
+    Computes ``lerp(a, b, t) / sqrt((1 - t)**2 + t**2)``.
+    """
+    blended = torch.lerp(a, b, t)
+    denominator = math.sqrt((1 - t) ** 2 + t ** 2)
+    return blended / denominator
+
+
+def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor:
+    """Concatenate ``a`` and ``b`` along ``dim`` after weighting each input.
+
+    The weights depend on the blend factor ``t`` and on the size of each input
+    along ``dim``.
+    """
+    size_a = a.shape[dim]
+    size_b = b.shape[dim]
+    common = math.sqrt((size_a + size_b) / ((1 - t) ** 2 + t ** 2))
+    weight_a = common / math.sqrt(size_a) * (1 - t)
+    weight_b = common / math.sqrt(size_b) * t
+    weighted = [weight_a * a, weight_b * b]
+    return torch.cat(weighted, dim=dim)
+
+
+class MPFourier(torch.nn.Module):
+    """Random Fourier feature embedding with fixed frequencies and phases.
+
+    ``freqs`` and ``phases`` are registered as buffers (not parameters) and are
+    drawn randomly at construction time.
+    """
+
+    def __init__(self, num_channels: int, bandwidth: float = 1):
+        super().__init__()
+        # freqs = 2*pi * N(0, 1) * bandwidth; phases = 2*pi * U[0, 1).
+        self.register_buffer("freqs", 2 * math.pi * torch.randn(num_channels) * bandwidth)
+        self.register_buffer("phases", 2 * math.pi * torch.rand(num_channels))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Return ``sqrt(2) * cos(outer(x, freqs) + phases)`` in the dtype of ``x``."""
+        # The outer product and cosine are evaluated in float32, then cast back.
+        y = x.to(torch.float32).ger(self.freqs.to(torch.float32))
+        y = y + self.phases.to(torch.float32)
+        y = y.cos() * math.sqrt(2)
+        return y.to(x.dtype)
+
+
+class MPConv(torch.nn.Module):
+    """Linear or convolution layer whose weight is normalized on every call.
+
+    An empty ``kernel`` gives a 2-D weight (used as a linear layer); any other
+    kernel gives a convolution weight of shape ``(out, in, *kernel)``.
+    """
+
+    def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]):
+        super().__init__()
+        self.out_channels = out_channels
+        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))
+
+    def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor:
+        """Apply the layer to ``x`` with the normalized weight scaled by ``gain``."""
+        w = self.weight.to(torch.float32)
+        if self.training:
+            # In training mode the stored parameter is overwritten in place with
+            # its normalized value, outside of autograd.
+            with torch.no_grad():
+                self.weight.copy_(normalize(w))
+        # The weight is normalized again on the differentiable path, in both
+        # training and eval mode.
+        w = normalize(w)
+        # w[0].numel() is the number of weights feeding one output channel.
+        w = w * (gain / math.sqrt(w[0].numel()))
+        w = w.to(x.dtype)
+        if w.ndim == 2:
+            return x @ w.t()
+        # Padding of kernel_size // 2 on the last weight dimension.
+        return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1] // 2,))
+
+
+class Block(torch.nn.Module):
+    """Residual U-Net block with embedding modulation and optional self-attention.
+
+    Args:
+        in_channels: Channel count of the incoming tensor.
+        out_channels: Channel count produced by the block.
+        emb_channels: Width of the conditioning embedding.
+        flavor: ``"enc"`` or ``"dec"``; selects where the skip convolution and the
+            channel normalization are applied.
+        resample_mode: Passed to ``resample`` (``"keep"``, ``"down"`` or ``"up"``).
+        resample_filter: Filter taps passed to ``resample``. ``None`` selects
+            ``[1, 1]``.
+        attention: Enable self-attention in this block.
+        channels_per_head: Channels per attention head.
+        dropout: Dropout probability, applied in training mode only.
+        res_balance: Blend weight for the residual branch.
+        attn_balance: Blend weight for the attention branch.
+        clip_act: Output clipping bound; ``None`` disables clipping.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        emb_channels: int,
+        flavor: str = "enc",
+        resample_mode: str = "keep",
+        resample_filter: Optional[List[float]] = None,
+        attention: bool = False,
+        channels_per_head: int = 64,
+        dropout: float = 0.0,
+        res_balance: float = 0.3,
+        attn_balance: float = 0.3,
+        clip_act: Optional[float] = 256,
+    ):
+        super().__init__()
+        self.out_channels = out_channels
+        self.flavor = flavor
+        # A fresh list per instance replaces the former shared mutable default.
+        self.resample_filter = [1, 1] if resample_filter is None else resample_filter
+        self.resample_mode = resample_mode
+        self.num_heads = out_channels // channels_per_head if attention else 0
+        self.dropout = dropout
+        self.res_balance = res_balance
+        self.attn_balance = attn_balance
+        self.clip_act = clip_act
+        # Zero-initialized gain: the embedding modulation starts as identity.
+        self.emb_gain = torch.nn.Parameter(torch.zeros([]))
+        self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3))
+        self.emb_linear = MPConv(emb_channels, out_channels, kernel=())
+        self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3))
+        self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None
+        self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None
+        self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None
+
+    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+        """Run the block on feature map ``x`` conditioned on embedding ``emb``."""
+        x = resample(x, f=self.resample_filter, mode=self.resample_mode)
+        if self.flavor == "enc":
+            # Encoder blocks change the channel count first, then normalize
+            # along the channel dimension.
+            if self.conv_skip is not None:
+                x = self.conv_skip(x)
+            x = normalize(x, dim=[1])
+
+        # Residual branch, modulated per channel by (embedding projection + 1).
+        y = self.conv_res0(mp_silu(x))
+        c = self.emb_linear(emb, gain=self.emb_gain) + 1
+        y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype))
+        if self.training and self.dropout:
+            y = torch.nn.functional.dropout(y, p=self.dropout)
+        y = self.conv_res1(y)
+
+        # Decoder blocks change the channel count of the skip path only now.
+        if self.flavor == "dec" and self.conv_skip is not None:
+            x = self.conv_skip(x)
+        x = mp_sum(x, y, t=self.res_balance)
+
+        if self.num_heads:
+            y = self.attn_qkv(x)
+            # (N, heads, channels_per_head, 3, H*W); q, k, v are split on axis 3
+            # after normalizing along the per-head channel axis.
+            y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3])
+            q, k, v = normalize(y, dim=[2]).unbind(3)
+            w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3)
+            y = torch.einsum("nhqk,nhck->nhcq", w, v)
+            y = self.attn_proj(y.reshape(*x.shape))
+            x = mp_sum(x, y, t=self.attn_balance)
+
+        if self.clip_act is not None:
+            # In-place clamp of the block output.
+            x = x.clip_(-self.clip_act, self.clip_act)
+        return x
+
+
+class EDM2UNet(torch.nn.Module):
+    """Encoder/decoder U-Net assembled from ``MPConv`` and ``Block`` modules.
+
+    Args:
+        img_resolution: Spatial size at the first level; halved at each level.
+        img_channels: Channel count of the input and of the output.
+        label_dim: Width of the class-label input; ``0`` disables the label
+            embedding.
+        model_channels: Base channel multiplier.
+        channel_mult: Per-level multipliers of ``model_channels``.
+        channel_mult_noise: Multiplier for the noise embedding width; defaults
+            to the first level's channel count.
+        channel_mult_emb: Multiplier for the final embedding width; defaults to
+            the largest per-level channel count.
+        num_blocks: Residual blocks per encoder level.
+        attn_resolutions: Resolutions at which blocks use self-attention.
+        label_balance: Blend weight between noise and label embeddings.
+        concat_balance: Blend weight used when concatenating skip connections.
+        **block_kwargs: Extra keyword arguments forwarded to every ``Block``.
+    """
+
+    def __init__(
+        self,
+        img_resolution: int,
+        img_channels: int,
+        label_dim: int,
+        model_channels: int = 192,
+        channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
+        channel_mult_noise: Optional[int] = None,
+        channel_mult_emb: Optional[int] = None,
+        num_blocks: int = 3,
+        attn_resolutions: Tuple[int, ...] = (16, 8),
+        label_balance: float = 0.5,
+        concat_balance: float = 0.5,
+        **block_kwargs,
+    ):
+        super().__init__()
+        cblock = [model_channels * x for x in channel_mult]
+        cnoise = model_channels * channel_mult_noise if channel_mult_noise is not None else cblock[0]
+        cemb = model_channels * channel_mult_emb if channel_mult_emb is not None else max(cblock)
+        self.label_balance = label_balance
+        self.concat_balance = concat_balance
+        # Zero-initialized gain applied to the output convolution.
+        self.out_gain = torch.nn.Parameter(torch.zeros([]))
+
+        self.emb_fourier = MPFourier(cnoise)
+        self.emb_noise = MPConv(cnoise, cemb, kernel=())
+        self.emb_label = MPConv(label_dim, cemb, kernel=()) if label_dim else None
+
+        # Encoder. Module names encode resolution and role; they become
+        # state-dict keys.
+        self.enc = torch.nn.ModuleDict()
+        # One extra input channel: forward() appends a channel of ones.
+        cout = img_channels + 1
+        for level, channels in enumerate(cblock):
+            res = img_resolution >> level
+            if level == 0:
+                cin = cout
+                cout = channels
+                self.enc[f"{res}x{res}_conv"] = MPConv(cin, cout, kernel=(3, 3))
+            else:
+                self.enc[f"{res}x{res}_down"] = Block(cout, cout, cemb, flavor="enc", resample_mode="down", **block_kwargs)
+            for idx in range(num_blocks):
+                cin = cout
+                cout = channels
+                self.enc[f"{res}x{res}_block{idx}"] = Block(
+                    cin,
+                    cout,
+                    cemb,
+                    flavor="enc",
+                    attention=(res in attn_resolutions),
+                    **block_kwargs,
+                )
+
+        # Decoder. ``skips`` lists the output width of every encoder module so
+        # each decoder block knows how many channels it receives after mp_cat.
+        self.dec = torch.nn.ModuleDict()
+        skips = [block.out_channels for block in self.enc.values()]
+        for level, channels in reversed(list(enumerate(cblock))):
+            res = img_resolution >> level
+            if level == len(cblock) - 1:
+                # Deepest level: two extra blocks, the first one with attention.
+                self.dec[f"{res}x{res}_in0"] = Block(cout, cout, cemb, flavor="dec", attention=True, **block_kwargs)
+                self.dec[f"{res}x{res}_in1"] = Block(cout, cout, cemb, flavor="dec", **block_kwargs)
+            else:
+                self.dec[f"{res}x{res}_up"] = Block(cout, cout, cemb, flavor="dec", resample_mode="up", **block_kwargs)
+            for idx in range(num_blocks + 1):
+                cin = cout + skips.pop()
+                cout = channels
+                self.dec[f"{res}x{res}_block{idx}"] = Block(
+                    cin,
+                    cout,
+                    cemb,
+                    flavor="dec",
+                    attention=(res in attn_resolutions),
+                    **block_kwargs,
+                )
+
+        self.out_conv = MPConv(cout, img_channels, kernel=(3, 3))
+
+    def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor:
+        """Run the U-Net.
+
+        Args:
+            x: Input feature map.
+            noise_labels: 1-D noise conditioning values fed to ``MPFourier``.
+            class_labels: Label tensor of width ``label_dim``; required when the
+                model has a label embedding.
+
+        Raises:
+            ValueError: If the model is conditional and ``class_labels`` is ``None``.
+        """
+        emb = self.emb_noise(self.emb_fourier(noise_labels))
+        if self.emb_label is not None:
+            if class_labels is None:
+                raise ValueError("class_labels are required for conditional EDM2UNet.")
+            # Labels are scaled by sqrt(label width) before the embedding layer.
+            emb = mp_sum(emb, self.emb_label(class_labels * math.sqrt(class_labels.shape[1])), t=self.label_balance)
+        emb = mp_silu(emb)
+
+        # Append a constant channel of ones to the input.
+        x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1)
+        skips = []
+        for name, block in self.enc.items():
+            # The first encoder module is a plain MPConv and takes no embedding.
+            x = block(x) if "conv" in name else block(x, emb)
+            skips.append(x)
+
+        for name, block in self.dec.items():
+            # Only the "*_block*" decoder modules consume a skip connection.
+            if "block" in name:
+                x = mp_cat(x, skips.pop(), t=self.concat_balance)
+            x = block(x, emb)
+        return self.out_conv(x, gain=self.out_gain)
+
+
+@dataclass
+class EDM2UNet2DOutput(BaseOutput):
+    """Return value of ``EDM2UNet2DModel.forward`` when ``return_dict`` is true."""
+
+    # Denoised prediction computed by the model.
+    sample: torch.Tensor
+    # Set only when ``forward`` is called with ``return_logvar=True``.
+    logvar: Optional[torch.Tensor] = None
+
+
+
+# Constructor arguments of EDM2UNet2DModel that are read from and written to
+# ``config.json`` by ``from_pretrained`` / ``save_pretrained``. Any other key
+# found in the file is ignored when loading.
+_CONFIG_KEYS = (
+    "sample_size",
+    "in_channels",
+    "out_channels",
+    "num_class_embeds",
+    "use_fp16",
+    "sigma_data",
+    "logvar_channels",
+    "model_channels",
+    "channel_mult",
+    "channel_mult_noise",
+    "channel_mult_emb",
+    "num_blocks",
+    "attn_resolutions",
+    "label_balance",
+    "concat_balance",
+    "dropout",
+    "channels_per_head",
+    "res_balance",
+    "attn_balance",
+    "clip_act",
+)
+
+
+class EDM2UNet2DModel(ModelMixin, ConfigMixin):
+    @register_to_config
+    def __init__(
+        self,
+        sample_size: int = 64,
+        in_channels: int = 4,
+        out_channels: int = 4,
+        num_class_embeds: int = 0,
+        use_fp16: bool = True,
+        sigma_data: float = 0.5,
+        logvar_channels: int = 128,
+        model_channels: int = 192,
+        channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
+        channel_mult_noise: Optional[int] = None,
+        channel_mult_emb: Optional[int] = None,
+        num_blocks: int = 3,
+        attn_resolutions: Tuple[int, ...] = (16, 8),
+        label_balance: float = 0.5,
+        concat_balance: float = 0.5,
+        dropout: float = 0.0,
+        channels_per_head: int = 64,
+        res_balance: float = 0.3,
+        attn_balance: float = 0.3,
+        clip_act: Optional[float] = 256,
+    ):
+        """Build the preconditioned EDM2 U-Net wrapper.
+
+        Args:
+            sample_size: Spatial resolution handed to the inner ``EDM2UNet``.
+            in_channels: Channel count of the input; also used as the inner
+                U-Net's image channel count.
+            out_channels: Stored on the instance; the inner U-Net is built from
+                ``in_channels``.
+            num_class_embeds: Width of the class-label input; ``0`` means
+                unconditional.
+            use_fp16: Allow float16 execution of the inner U-Net on CUDA.
+            sigma_data: Data standard deviation used by the preconditioning.
+            logvar_channels: Width of the Fourier features for the log-variance
+                head.
+            model_channels, channel_mult, channel_mult_noise, channel_mult_emb,
+            num_blocks, attn_resolutions, label_balance, concat_balance:
+                Forwarded to ``EDM2UNet``.
+            dropout, channels_per_head, res_balance, attn_balance, clip_act:
+                Forwarded to every ``Block`` through ``EDM2UNet``.
+        """
+        super().__init__()
+        # Every constructor argument is mirrored as a plain attribute so that
+        # ``save_pretrained`` can fall back to ``getattr`` for keys that are not
+        # present in ``self.config``.
+        self.sample_size = sample_size
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.num_class_embeds = num_class_embeds
+        self.use_fp16 = use_fp16
+        self.sigma_data = sigma_data
+        # Previously missing: without it the getattr fallback silently dropped
+        # ``logvar_channels`` from the saved config.
+        self.logvar_channels = logvar_channels
+        self.model_channels = model_channels
+        self.channel_mult = channel_mult
+        self.channel_mult_noise = channel_mult_noise
+        self.channel_mult_emb = channel_mult_emb
+        self.num_blocks = num_blocks
+        self.attn_resolutions = attn_resolutions
+        self.label_balance = label_balance
+        self.concat_balance = concat_balance
+        self.dropout = dropout
+        self.channels_per_head = channels_per_head
+        self.res_balance = res_balance
+        self.attn_balance = attn_balance
+        self.clip_act = clip_act
+        self.unet = EDM2UNet(
+            img_resolution=sample_size,
+            img_channels=in_channels,
+            label_dim=num_class_embeds,
+            model_channels=model_channels,
+            channel_mult=channel_mult,
+            channel_mult_noise=channel_mult_noise,
+            channel_mult_emb=channel_mult_emb,
+            num_blocks=num_blocks,
+            attn_resolutions=attn_resolutions,
+            label_balance=label_balance,
+            concat_balance=concat_balance,
+            dropout=dropout,
+            channels_per_head=channels_per_head,
+            res_balance=res_balance,
+            attn_balance=attn_balance,
+            clip_act=clip_act,
+        )
+        # Log-variance head, evaluated only when forward() is asked for it.
+        self.logvar_fourier = MPFourier(logvar_channels)
+        self.logvar_linear = MPConv(logvar_channels, 1, kernel=())
+
+    def forward(
+        self,
+        sample: torch.Tensor,
+        sigma: torch.Tensor,
+        class_labels: Optional[torch.Tensor] = None,
+        force_fp32: bool = False,
+        return_logvar: bool = False,
+        return_dict: bool = True,
+    ) -> EDM2UNet2DOutput:
+        """Denoise ``sample`` at noise level ``sigma``.
+
+        Computes ``c_skip * x + c_out * unet(c_in * x, c_noise, class_labels)``
+        with the coefficients defined in the body.
+
+        Args:
+            sample: Noisy input; cast to float32.
+            sigma: Noise level(s); reshaped to ``(-1, 1, 1, 1)``.
+            class_labels: Label tensor reshaped to ``(-1, num_class_embeds)``.
+                Ignored for unconditional models; replaced by zeros when omitted
+                on a conditional model.
+            force_fp32: Run the inner U-Net in float32 even if ``use_fp16`` is set.
+            return_logvar: Also evaluate the log-variance head.
+            return_dict: Return ``EDM2UNet2DOutput`` instead of a tuple.
+
+        Returns:
+            ``EDM2UNet2DOutput(sample, logvar)``, or the tuple ``(sample, logvar)``
+            when ``return_dict`` is false. ``logvar`` is ``None`` unless requested.
+        """
+        x = sample.to(torch.float32)
+        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
+        if self.num_class_embeds == 0:
+            class_labels = None
+        else:
+            if class_labels is None:
+                class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device)
+            class_labels = class_labels.to(torch.float32).reshape(-1, self.num_class_embeds)
+        # float16 is used only on CUDA, when enabled and not overridden.
+        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") else torch.float32
+
+        # Preconditioning coefficients as functions of sigma and sigma_data.
+        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
+        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
+        c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
+        c_noise = sigma.flatten().log() / 4
+
+        # Only the U-Net input is cast to the compute dtype; the skip path and
+        # the final combination stay in float32.
+        x_in = (c_in * x).to(dtype)
+        f_x = self.unet(x_in, c_noise, class_labels)
+        d_x = c_skip * x + c_out * f_x.to(torch.float32)
+
+        logvar = None
+        if return_logvar:
+            logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)
+
+        if not return_dict:
+            return (d_x, logvar)
+        return EDM2UNet2DOutput(sample=d_x, logvar=logvar)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs):
+        """Load a model from a local directory written by ``save_pretrained``.
+
+        Args:
+            pretrained_model_name_or_path: Local directory holding ``config.json``
+                and the weights, or its parent when ``subfolder`` is given.
+            torch_dtype: Optional dtype the loaded model is cast to.
+            **kwargs: Only ``subfolder`` is used; other keyword arguments are
+                accepted and ignored.
+
+        Returns:
+            The model with weights loaded strictly.
+
+        Raises:
+            FileNotFoundError: If ``config.json`` or the weight file is missing.
+        """
+        subfolder = kwargs.pop("subfolder", None)
+        model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path
+        config_path = os.path.join(model_dir, "config.json")
+        if not os.path.isfile(config_path):
+            # This loader reads from the local filesystem only, so say so rather
+            # than surfacing a bare open() failure.
+            raise FileNotFoundError(
+                f"No config.json found at {config_path!r}; "
+                "EDM2UNet2DModel.from_pretrained only supports local directories."
+            )
+        with open(config_path, "r", encoding="utf-8") as f:
+            config = json.load(f)
+        # Keys outside _CONFIG_KEYS (e.g. "_class_name") are not constructor args.
+        init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS}
+        model = cls(**init_kwargs)
+        weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors")
+        if os.path.isfile(weight_file):
+            from safetensors.torch import load_file
+
+            state_dict = load_file(weight_file)
+        else:
+            # weights_only=True restricts unpickling to tensors and plain
+            # containers, so a crafted .bin file cannot execute code on load.
+            state_dict = torch.load(
+                os.path.join(model_dir, "diffusion_pytorch_model.bin"),
+                map_location="cpu",
+                weights_only=True,
+            )
+        model.load_state_dict(state_dict, strict=True)
+        if torch_dtype is not None:
+            model = model.to(dtype=torch_dtype)
+        return model
+
+    def save_pretrained(self, save_directory: str, safe_serialization: bool = True):
+        """Write ``config.json`` and the model weights into ``save_directory``.
+
+        Args:
+            save_directory: Target directory; created when missing.
+            safe_serialization: Write safetensors when true, otherwise a
+                ``torch.save`` ``.bin`` file.
+        """
+        os.makedirs(save_directory, exist_ok=True)
+
+        # Registered config values take precedence; instance attributes are the
+        # fallback for keys the config does not contain.
+        registered = dict(getattr(self, "config", {}))
+        payload = {"_class_name": self.__class__.__name__}
+        for name in _CONFIG_KEYS:
+            if name in registered:
+                payload[name] = registered[name]
+                continue
+            if hasattr(self, name):
+                payload[name] = getattr(self, name)
+
+        config_path = os.path.join(save_directory, "config.json")
+        with open(config_path, "w", encoding="utf-8") as handle:
+            json.dump(payload, handle, indent=2, sort_keys=True)
+            handle.write("\n")
+
+        weights = self.state_dict()
+        if not safe_serialization:
+            torch.save(weights, os.path.join(save_directory, "diffusion_pytorch_model.bin"))
+            return
+
+        from safetensors.torch import save_file
+
+        save_file(weights, os.path.join(save_directory, "diffusion_pytorch_model.safetensors"))
diff --git a/edm2-img512-l-fid/vae/config.json b/edm2-img512-l-fid/vae/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..ae2ab37de2acb8bc348d3de809ae8385324be962
--- /dev/null
+++ b/edm2-img512-l-fid/vae/config.json
@@ -0,0 +1,38 @@
+{
+ "_class_name": "AutoencoderKL",
+ "_diffusers_version": "0.36.0",
+ "_name_or_path": "stabilityai/sd-vae-ft-mse",
+ "act_fn": "silu",
+ "block_out_channels": [
+ 128,
+ 256,
+ 512,
+ 512
+ ],
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D"
+ ],
+ "force_upcast": true,
+ "in_channels": 3,
+ "latent_channels": 4,
+ "latents_mean": null,
+ "latents_std": null,
+ "layers_per_block": 2,
+ "mid_block_add_attention": true,
+ "norm_num_groups": 32,
+ "out_channels": 3,
+ "sample_size": 256,
+ "scaling_factor": 0.18215,
+ "shift_factor": null,
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D"
+ ],
+ "use_post_quant_conv": true,
+ "use_quant_conv": true
+}
diff --git a/edm2-img512-l-fid/vae/diffusion_pytorch_model.safetensors b/edm2-img512-l-fid/vae/diffusion_pytorch_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..90464d67ac7303d0ee4696334df13da130a948ea
--- /dev/null
+++ b/edm2-img512-l-fid/vae/diffusion_pytorch_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815
+size 334643276
diff --git a/edm2-img512-m-fid/demo.png b/edm2-img512-m-fid/demo.png
new file mode 100644
index 0000000000000000000000000000000000000000..6f729fad3723c801849aef1ab04df055413e8955
--- /dev/null
+++ b/edm2-img512-m-fid/demo.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bda2cb48c7ab17b37fbfa0599c7fec6d1f8d7de6848990f870e8ff4b613c929d
+size 369586
diff --git a/edm2-img512-m-fid/model_index.json b/edm2-img512-m-fid/model_index.json
new file mode 100644
index 0000000000000000000000000000000000000000..24c0305eea3ec64abbc2f870476afa4f5e462f42
--- /dev/null
+++ b/edm2-img512-m-fid/model_index.json
@@ -0,0 +1,19 @@
+{
+ "_class_name": [
+ "pipeline",
+ "EDM2Pipeline"
+ ],
+ "_diffusers_version": "0.31.0",
+ "scheduler": [
+ "diffusers",
+ "EDMEulerScheduler"
+ ],
+ "unet": [
+ "unet_edm2",
+ "EDM2UNet2DModel"
+ ],
+ "vae": [
+ "diffusers",
+ "AutoencoderKL"
+ ]
+}
diff --git a/edm2-img512-m-fid/pipeline.py b/edm2-img512-m-fid/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b4131215d07770c74efa0674d00f3d6423d9c06
--- /dev/null
+++ b/edm2-img512-m-fid/pipeline.py
@@ -0,0 +1,406 @@
+"""Hub custom pipeline: EDM2Pipeline.
+Load with native Hugging Face diffusers and trust_remote_code=True.
+"""
+
+from __future__ import annotations
+
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from pathlib import Path
+from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from diffusers.utils import replace_example_docstring
+from diffusers.utils.torch_utils import randn_tensor
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from pathlib import Path
+ >>> import torch
+ >>> from diffusers import DiffusionPipeline
+
+ >>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve()
+ >>> pipe = DiffusionPipeline.from_pretrained(
+ ... str(model_dir),
+ ... local_files_only=True,
+ ... custom_pipeline=str(model_dir / "pipeline.py"),
+ ... trust_remote_code=True,
+ ... torch_dtype=torch.float32,
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> generator = torch.Generator(device="cuda").manual_seed(42)
+ >>> image = pipe(
+ ... class_labels=207,
+ ... num_inference_steps=32,
+ ... guidance_scale=1.0,
+ ... generator=generator,
+ ... ).images[0]
+ >>> image.save("demo.png")
+ ```
+"""
+
+# Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py).
+# Per-channel constants for the 4 latent channels. ``decode_latents`` undoes the
+# whitening with ``(x - bias) / scale`` before calling the VAE decoder.
+_STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28])
+_STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE
+
+class EDM2Pipeline(DiffusionPipeline):
+ r"""
+ Pipeline for class-conditional image generation with EDM2
+ ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)).
+
+ Parameters:
+ unet ([`EDM2UNet2DModel`]):
+ Main magnitude-preserving U-Net with EDM preconditioning.
+ scheduler ([`EDMEulerScheduler`]):
+ Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in
+ the pipeline because the UNet returns denoised latents rather than noise predictions.
+ vae ([`AutoencoderKL`], *optional*):
+ Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`.
+ gnet ([`EDM2UNet2DModel`], *optional*):
+ Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`).
+ id2label (`dict[int, str]`, *optional*):
+ ImageNet class id to English label mapping.
+ """
+
+ model_cpu_offload_seq = "unet->gnet->vae"
+ _optional_components = ["vae", "gnet"]
+
+    def __init__(
+        self,
+        unet,
+        scheduler,
+        vae=None,
+        gnet=None,
+        id2label: Optional[Dict[Union[int, str], str]] = None,
+    ) -> None:
+        """Register the pipeline components and build the label lookup tables.
+
+        Args:
+            unet: Main denoising network.
+            scheduler: Scheduler whose ``sigmas`` drive the sampler.
+            vae: Optional decoder; when present the VAE scale factor is 8.
+            gnet: Optional guiding network used when ``guidance_scale > 1``.
+            id2label: Optional class id to label mapping; keys may be ints or
+                numeric strings.
+        """
+        super().__init__()
+        self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet)
+        self._id2label = self._normalize_id2label(id2label)
+        # ``labels`` maps each individual synonym to its class id.
+        self.labels = self._build_label2id(self._id2label)
+        # True once labels are available, so later lookups skip the file read.
+        self._labels_loaded_from_model_index = bool(self._id2label)
+        # Without a VAE the network output is already at output resolution.
+        self.vae_scale_factor = 8 if self.vae is not None else 1
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
+
+    @staticmethod
+    def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
+        """Return a copy of ``id2label`` with integer keys, or ``{}`` when empty."""
+        if id2label:
+            return {int(class_id): text for class_id, text in id2label.items()}
+        return {}
+
+    @staticmethod
+    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
+        """Invert ``id2label`` into a synonym-to-id mapping sorted by synonym.
+
+        Each value is split on commas; blank pieces are dropped. When a synonym
+        occurs more than once, the last occurrence wins.
+        """
+        collected: Dict[str, int] = {}
+        for class_id, names in id2label.items():
+            cleaned = (piece.strip() for piece in names.split(","))
+            collected.update({name: int(class_id) for name in cleaned if name})
+        ordered = sorted(collected.items())
+        return dict(ordered)
+
+    def _ensure_labels_loaded(self) -> None:
+        """Load ``id2label`` from ``model_index.json`` if no labels are set yet."""
+        if not self._labels_loaded_from_model_index:
+            source_path = getattr(self.config, "_name_or_path", None)
+            found = self._read_id2label_from_model_index(source_path)
+            if found:
+                self._id2label = found
+                self.labels = self._build_label2id(self._id2label)
+                self._labels_loaded_from_model_index = True
+
+    @staticmethod
+    def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
+        """Read ``id2label`` from ``<variant_path>/model_index.json``.
+
+        Returns ``{}`` when the path is empty, the file is missing, or the
+        ``id2label`` entry is not a dict.
+        """
+        if not variant_path:
+            return {}
+        index_file = Path(variant_path).resolve() / "model_index.json"
+        if index_file.is_file():
+            payload = json.loads(index_file.read_text(encoding="utf-8"))
+            mapping = payload.get("id2label")
+            if isinstance(mapping, dict):
+                return {int(class_id): text for class_id, text in mapping.items()}
+        return {}
+
+    @property
+    def id2label(self) -> Dict[int, str]:
+        r"""ImageNet class id to English label string (comma-separated synonyms).
+
+        Accessing the property first tries to load the mapping from
+        ``model_index.json`` when no labels have been set yet.
+        """
+        self._ensure_labels_loaded()
+        return self._id2label
+
+    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
+        r"""
+        Translate ImageNet label strings into class ids.
+
+        Args:
+            label (`str` or `list[str]`):
+                One or more English label strings that match entries in `id2label`.
+
+        Returns:
+            `list[int]`: One class id per requested label, in request order.
+
+        Raises:
+            ValueError: If no labels are available or a label is unknown.
+        """
+        self._ensure_labels_loaded()
+        known = self.labels
+        if not known:
+            raise ValueError("No English labels loaded. Add `id2label` to model_index.json.")
+        requested = [label] if isinstance(label, str) else list(label)
+        unknown = [name for name in requested if name not in known]
+        if unknown:
+            # Show a handful of valid labels to help the caller correct the input.
+            preview = ", ".join(list(known.keys())[:8])
+            raise ValueError(f"Unknown English label(s): {unknown}. Example valid labels: {preview}, ...")
+        return [known[name] for name in requested]
+
+    def _default_image_size(self) -> int:
+        """Return the native output size: latent ``sample_size`` times the VAE scale factor."""
+        # The config value (default 64) is only the fallback; an attribute on the
+        # unet itself wins.
+        from_config = getattr(self.unet.config, "sample_size", 64)
+        latent_size = int(getattr(self.unet, "sample_size", from_config))
+        return latent_size * self.vae_scale_factor
+
+    def check_inputs(
+        self,
+        height: int,
+        width: int,
+        num_inference_steps: int,
+        guidance_scale: float,
+        output_type: str,
+    ) -> None:
+        """Validate call arguments; raise ``ValueError`` on the first violation.
+
+        Checks, in order: step count, guidance scale, availability of ``gnet``
+        when guidance is requested, output type, and native resolution.
+        """
+        if num_inference_steps < 1:
+            raise ValueError("num_inference_steps must be >= 1.")
+        if guidance_scale < 1.0:
+            raise ValueError("guidance_scale must be >= 1.0.")
+        needs_guide = guidance_scale > 1.0
+        if needs_guide and self.gnet is None:
+            raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).")
+        allowed_outputs = ("pil", "np", "pt", "latent")
+        if output_type not in allowed_outputs:
+            raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
+
+        # Only the model's native square resolution is accepted.
+        native_size = self._default_image_size()
+        is_native = height == native_size and width == native_size
+        if not is_native:
+            raise ValueError(
+                f"EDM2 expects native resolution height=width={native_size}. "
+                f"Got height={height}, width={width}."
+            )
+
+    def _normalize_class_labels(
+        self,
+        class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]],
+        batch_size: int,
+        device: torch.device,
+    ) -> Optional[torch.Tensor]:
+        """Convert user-supplied class labels into a float32 label batch.
+
+        Args:
+            class_labels: ``None`` (random classes), a class index, a label
+                string, a sequence of either, a tensor of indices, or a 2-D
+                tensor that is used as the label batch directly.
+            batch_size: Number of samples to generate.
+            device: Device of the returned tensor.
+
+        Returns:
+            ``None`` when the UNet has ``num_class_embeds == 0``, otherwise a
+            ``(batch_size, num_class_embeds)`` float32 tensor on ``device``.
+
+        Raises:
+            ValueError: If the labels do not match ``batch_size``, a 2-D tensor
+                has the wrong width, or a class index lies outside
+                ``[0, num_class_embeds)``.
+        """
+        label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0)))
+        if label_dim == 0:
+            return None
+        if class_labels is None:
+            indices = torch.randint(label_dim, size=(batch_size,), device=device)
+            return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
+
+        # Resolve English label strings to class ids first.
+        if isinstance(class_labels, str):
+            class_labels = self.get_label_ids(class_labels)[0]
+        elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str):
+            class_labels = self.get_label_ids(list(class_labels))
+
+        if isinstance(class_labels, int):
+            indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long)
+        elif isinstance(class_labels, torch.Tensor):
+            if class_labels.ndim == 2:
+                labels = class_labels.to(device=device, dtype=torch.float32)
+                if labels.shape[0] != batch_size:
+                    raise ValueError(f"class_labels batch must match batch_size={batch_size}.")
+                # The UNet reshapes labels to (-1, label_dim); a wrong width would
+                # be silently regrouped into a different batch size.
+                if labels.shape[1] != label_dim:
+                    raise ValueError(
+                        f"class_labels must have {label_dim} columns, got {labels.shape[1]}."
+                    )
+                return labels
+            indices = class_labels.to(device=device, dtype=torch.long).flatten()
+        else:
+            indices = torch.tensor(list(class_labels), device=device, dtype=torch.long)
+
+        # A single label is broadcast to the whole batch.
+        if indices.numel() == 1 and batch_size > 1:
+            indices = indices.repeat(batch_size)
+        if indices.numel() != batch_size:
+            raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.")
+        # Reject invalid ids here: indexing would otherwise fail with an opaque
+        # error (a device-side assert on CUDA) or wrap around for negative ids.
+        if indices.numel() > 0:
+            lowest, highest = int(indices.min()), int(indices.max())
+            if lowest < 0 or highest >= label_dim:
+                raise ValueError(
+                    f"class_labels must be in [0, {label_dim - 1}], got values between {lowest} and {highest}."
+                )
+        return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
+
+    def prepare_latents(
+        self,
+        batch_size: int,
+        height: int,
+        width: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]],
+    ) -> torch.Tensor:
+        """Draw the initial Gaussian noise for sampling.
+
+        Args:
+            batch_size: Number of samples.
+            height: Output height in pixels.
+            width: Output width in pixels.
+            dtype: Accepted for signature compatibility; the noise is always
+                drawn in float32.
+            device: Device of the returned tensor.
+            generator: Optional RNG(s) forwarded to ``randn_tensor``.
+
+        Returns:
+            Float32 noise of shape
+            ``(batch_size, in_channels, height // f, width // f)`` where ``f`` is
+            the VAE scale factor.
+        """
+        in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4)))
+        # Derive each latent dimension from its own pixel dimension; previously
+        # ``width`` was ignored and the latent width was taken from ``height``.
+        latent_height = height // self.vae_scale_factor
+        latent_width = width // self.vae_scale_factor
+        return randn_tensor(
+            (batch_size, in_channels, latent_height, latent_width),
+            generator=generator,
+            device=device,
+            dtype=torch.float32,
+        )
+
+    def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
+        """Turn sampler output into images of the requested ``output_type``.
+
+        ``"latent"`` returns ``latents`` untouched. Without a VAE the latents are
+        mapped with ``x * 127.5 + 128`` and clipped to 0..255, then scaled to
+        0..1. With a VAE, 4-channel latents are first un-whitened with the module
+        constants and then decoded, and the decoded image is clamped to 0..1.
+        """
+        if output_type == "latent":
+            return latents
+
+        in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 3)))
+        values = latents.to(torch.float32)
+
+        if self.vae is None:
+            pixels = (values * 127.5 + 128).clip(0, 255) / 255.0
+            return self.image_processor.postprocess(pixels, output_type=output_type)
+
+        if in_channels == 4:
+            # Invert the per-channel whitening: raw = (x - bias) / scale.
+            scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=values.dtype, device=values.device).reshape(1, -1, 1, 1)
+            bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=values.dtype, device=values.device).reshape(1, -1, 1, 1)
+            values = (values - bias) / scale
+
+        vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype
+        decoded = self.vae.decode(values.to(dtype=vae_dtype)).sample
+        image = decoded.to(torch.float32).clamp(0, 1)
+
+        return self.image_processor.postprocess(image, output_type=output_type)
+
+    @staticmethod
+    def _apply_autoguidance(
+        main: torch.Tensor,
+        ref: torch.Tensor,
+        guidance_scale: float,
+    ) -> torch.Tensor:
+        """Return ``ref + guidance_scale * (main - ref)``, computed with ``lerp``."""
+        guided = torch.lerp(ref, main, guidance_scale)
+        return guided
+
+    @staticmethod
+    def _sample_edm2_heun(
+        denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+        noise: torch.Tensor,
+        sigmas: torch.Tensor,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        progress_bar: Optional[Callable[[Iterable], Iterable]] = None,
+        dtype: torch.dtype = torch.float32,
+    ) -> torch.Tensor:
+        """NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0).
+
+        Args:
+            denoise_fn: Maps ``(x, sigma)`` to the denoised estimate of ``x``.
+            noise: Unit-variance noise; scaled by ``sigmas[0]`` to start sampling.
+            sigmas: 1-D noise levels; consecutive pairs define the steps.
+            generator: Not used by this function.
+            progress_bar: Optional wrapper applied to the list of sigma pairs.
+            dtype: Dtype the noise is cast to before sampling.
+
+        Returns:
+            The sample after the final step.
+        """
+        x_next = noise.to(dtype) * sigmas[0]
+
+        sigma_pairs = list(zip(sigmas[:-1], sigmas[1:]))
+        if progress_bar is not None:
+            sigma_pairs = progress_bar(sigma_pairs)
+
+        num_steps = len(sigma_pairs)
+        for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs):
+            x_hat, sigma_hat = x_next, sigma_cur
+            # First-order (Euler) step from sigma_hat to sigma_next.
+            d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat
+            x_next = x_hat + (sigma_next - sigma_hat) * d_cur
+            # Second-order correction on every step except the last one, which
+            # avoids dividing by sigma_next on the final step.
+            if i < num_steps - 1:
+                d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next
+                x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime)
+        return x_next
+
+    @torch.inference_mode()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None,
+        batch_size: int = 1,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 32,
+        guidance_scale: float = 1.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        output_type: str = "pil",
+        return_dict: bool = True,
+    ) -> Union[ImagePipelineOutput, Tuple]:
+        r"""
+        Generate class-conditional images with EDM2.
+
+        Args:
+            class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*):
+                ImageNet class indices, English label strings, or one-hot float tensors.
+                Random classes are sampled when omitted on conditional models.
+            batch_size (`int`, defaults to `1`):
+                Number of images to generate.
+            height (`int`, *optional*):
+                Output height in pixels. Defaults to the pretrained native resolution.
+            width (`int`, *optional*):
+                Output width in pixels. Defaults to the pretrained native resolution.
+            num_inference_steps (`int`, defaults to `32`):
+                Number of EDM2 Heun steps (NVlabs default).
+            guidance_scale (`float`, defaults to `1.0`):
+                Autoguidance strength. Values above `1.0` blend the main net with `gnet`
+                via `gnet_output.lerp(unet_output, guidance_scale)`.
+            generator (`torch.Generator`, *optional*):
+                RNG for reproducibility.
+            output_type (`str`, defaults to `"pil"`):
+                `"pil"`, `"np"`, `"pt"`, or `"latent"`.
+            return_dict (`bool`, defaults to `True`):
+                Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if True.
+
+        Examples:
+
+        """
+        default_size = self._default_image_size()
+        height = int(height or default_size)
+        width = int(width or default_size)
+        self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type)
+
+        device = self._execution_device
+        dtype = self.unet.dtype
+        labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
+        noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
+
+        # Closure handed to the sampler: returns the float32 denoised estimate,
+        # optionally blended with the guiding network.
+        def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+            # The sampler passes one sigma per step; expand it to the batch.
+            sigma_batch = sigma.reshape(1).expand(batch_size)
+            main = self.unet(
+                sample=x,
+                sigma=sigma_batch,
+                class_labels=labels,
+                force_fp32=True,
+            ).sample
+            # No guidance requested, or no guiding network available.
+            if guidance_scale == 1.0 or self.gnet is None:
+                return main.to(torch.float32)
+            ref = self.gnet(
+                sample=x,
+                sigma=sigma_batch,
+                class_labels=labels,
+                force_fp32=True,
+            ).sample
+            return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32)
+
+        # The scheduler is used only for its sigma schedule; the update steps
+        # themselves run in _sample_edm2_heun.
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        latents = self._sample_edm2_heun(
+            denoise_fn=denoise_fn,
+            noise=noise,
+            sigmas=self.scheduler.sigmas.to(device),
+            generator=generator,
+            progress_bar=self.progress_bar,
+            dtype=torch.float32,
+        )
+
+        image = self.decode_latents(latents, output_type=output_type)
+        # The tuple form also exposes the raw latents.
+        if not return_dict:
+            return (image, latents)
+        return ImagePipelineOutput(images=image)
+
+    @classmethod
+    def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None):
+        """Load the optional VAE stored next to a local checkpoint.
+
+        A ``vae`` subfolder is tried first. Otherwise a Hub id is read from
+        ``vae_pretrained_model_name_or_path.txt`` in the same directory.
+
+        Args:
+            pretrained_model_name_or_path: Local directory of the pipeline checkpoint.
+            torch_dtype: Optional dtype forwarded to ``AutoencoderKL.from_pretrained``.
+
+        Returns:
+            The loaded ``AutoencoderKL``, or ``None`` when no VAE is configured or
+            the local ``vae`` folder cannot be loaded.
+        """
+        # Imported locally so the name is guaranteed to be in scope here; without
+        # it both branches below fail with NameError.
+        from diffusers import AutoencoderKL
+
+        vae_dir = os.path.join(pretrained_model_name_or_path, "vae")
+        if os.path.isdir(vae_dir):
+            try:
+                return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype)
+            except Exception as err:
+                # Still best effort (returns None), but report the reason instead of
+                # hiding it from the caller.
+                import warnings
+
+                warnings.warn(f"Could not load VAE from {vae_dir!r}: {err}")
+                return None
+
+        vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt")
+        if os.path.isfile(vae_hint):
+            with open(vae_hint, "r", encoding="utf-8") as f:
+                hub_id = f.read().strip()
+            if hub_id:
+                return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype)
+        return None
diff --git a/edm2-img512-m-fid/scheduler/scheduler_config.json b/edm2-img512-m-fid/scheduler/scheduler_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..08cbde97955d06d66b0b72be217db84bb797d711
--- /dev/null
+++ b/edm2-img512-m-fid/scheduler/scheduler_config.json
@@ -0,0 +1,11 @@
+{
+ "_class_name": "EDMEulerScheduler",
+ "final_sigmas_type": "zero",
+ "num_train_timesteps": 1000,
+ "prediction_type": "epsilon",
+ "rho": 7.0,
+ "sigma_data": 0.5,
+ "sigma_max": 80.0,
+ "sigma_min": 0.002,
+ "sigma_schedule": "karras"
+}
diff --git a/edm2-img512-m-fid/unet/config.json b/edm2-img512-m-fid/unet/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..0173e5acb7a02db0cd9bfde8ed828ce386e96899
--- /dev/null
+++ b/edm2-img512-m-fid/unet/config.json
@@ -0,0 +1,31 @@
+{
+ "_class_name": "EDM2UNet2DModel",
+ "attn_balance": 0.3,
+ "attn_resolutions": [
+ 16,
+ 8
+ ],
+ "channel_mult": [
+ 1,
+ 2,
+ 3,
+ 4
+ ],
+ "channel_mult_emb": 4,
+ "channel_mult_noise": 1,
+ "channels_per_head": 64,
+ "clip_act": 256,
+ "concat_balance": 0.5,
+ "dropout": 0.0,
+ "in_channels": 4,
+ "label_balance": 0.5,
+ "logvar_channels": 128,
+ "model_channels": 256,
+ "num_blocks": 3,
+ "num_class_embeds": 1000,
+ "out_channels": 4,
+ "res_balance": 0.3,
+ "sample_size": 64,
+ "sigma_data": 0.5,
+ "use_fp16": true
+}
diff --git a/edm2-img512-m-fid/unet/diffusion_pytorch_model.safetensors b/edm2-img512-m-fid/unet/diffusion_pytorch_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..ce30a1a7af868a3d6f1cae538169c82793ff12f5
--- /dev/null
+++ b/edm2-img512-m-fid/unet/diffusion_pytorch_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4733c8b2d2823cd6ce7a67e2b89b0e9b94d50fdf595b0e0b17299e198da3bcfc
+size 1991256788
diff --git a/edm2-img512-m-fid/unet/unet_edm2.py b/edm2-img512-m-fid/unet/unet_edm2.py
new file mode 100644
index 0000000000000000000000000000000000000000..49d9451b443c4d6ad3b6ba2db19a7a804b07b6de
--- /dev/null
+++ b/edm2-img512-m-fid/unet/unet_edm2.py
@@ -0,0 +1,434 @@
+import math
+import json
+import os
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+
+# Use the real diffusers base classes when they can be imported; the ``except`` branch
+# defines minimal stand-ins under the same names so this module still imports without them.
+try:
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.utils import BaseOutput
+except ImportError: # pragma: no cover
+ # Stand-in for ``ModelMixin``: a plain ``torch.nn.Module`` with no extra behaviour.
+ class ModelMixin(torch.nn.Module):
+ pass
+
+ # Stand-in for ``ConfigMixin``: ``config`` starts out as an empty dict.
+ class ConfigMixin:
+ config = {}
+
+ # Stores the given keyword arguments as the instance's ``config``.
+ def register_to_config(self, **kwargs):
+ self.config = kwargs
+
+ # No-op decorator stand-in: returns the decorated function unchanged.
+ def register_to_config(func):
+ return func
+
+ # Stand-in for ``BaseOutput``: an empty dataclass used as a base for output containers.
+ @dataclass
+ class BaseOutput:
+ pass
+
+
+def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor:
+    """Divide ``x`` by ``eps`` plus its root-mean-square magnitude taken over ``dim``.
+
+    Args:
+        x: Input tensor.
+        dim: Dimensions to reduce over; defaults to every dimension except the first.
+        eps: Constant added to the magnitude before dividing.
+
+    Returns:
+        A tensor with the shape and dtype of ``x``.
+    """
+    reduce_dims = list(range(1, x.ndim)) if dim is None else dim
+    magnitude = torch.linalg.vector_norm(x, dim=reduce_dims, keepdim=True, dtype=torch.float32)
+    # numel(magnitude) / numel(x) is 1 / (elements reduced per norm), so the scaled
+    # L2 norm below is the RMS of the reduced elements.
+    rms_factor = math.sqrt(magnitude.numel() / x.numel())
+    denom = torch.add(eps, magnitude, alpha=rms_factor)
+    return x / denom.to(x.dtype)
+
+
+def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor:
+    """Resample ``x`` by a factor of two with a separable filter, or return it unchanged.
+
+    Args:
+        x: Input tensor; the channel count is read from ``x.shape[1]`` and the last two
+            dimensions are filtered.
+        f: 1-D filter taps. They are normalised to sum to one and expanded to a 2-D
+            kernel via an outer product.
+        mode: ``"keep"`` returns ``x`` untouched, ``"down"`` applies a stride-2 depthwise
+            convolution, ``"up"`` applies a stride-2 depthwise transposed convolution
+            (kernel scaled by 4).
+
+    Returns:
+        The resampled tensor (``x`` itself for ``"keep"``).
+
+    Raises:
+        ValueError: If ``mode`` is not one of ``"keep"``, ``"down"`` or ``"up"``.
+    """
+    if mode == "keep":
+        return x
+    # Previously any unrecognised mode silently fell through to the upsampling path.
+    if mode not in ("down", "up"):
+        raise ValueError(f"Unknown resample mode {mode!r}; expected 'keep', 'down' or 'up'.")
+
+    taps = np.asarray(f, dtype=np.float32)
+    pad = (len(taps) - 1) // 2
+    taps = taps / taps.sum()
+    kernel = np.outer(taps, taps)[np.newaxis, np.newaxis, :, :]
+    kernel = torch.as_tensor(kernel, dtype=x.dtype, device=x.device)
+    channels = x.shape[1]
+    if mode == "down":
+        # One copy of the kernel per channel (groups == channels).
+        return torch.nn.functional.conv2d(
+            x, kernel.tile([channels, 1, 1, 1]), groups=channels, stride=2, padding=(pad,)
+        )
+    return torch.nn.functional.conv_transpose2d(
+        x, (kernel * 4).tile([channels, 1, 1, 1]), groups=channels, stride=2, padding=(pad,)
+    )
+
+
+def mp_silu(x: torch.Tensor) -> torch.Tensor:
+    """Apply SiLU and divide the result by the fixed constant ``0.596``."""
+    activated = torch.nn.functional.silu(x)
+    return activated / 0.596
+
+
+def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor:
+    """Blend ``a`` towards ``b`` with weight ``t`` and rescale by ``1 / sqrt((1 - t)^2 + t^2)``."""
+    blended = a.lerp(b, t)
+    scale = math.sqrt((1 - t) ** 2 + t ** 2)
+    return blended / scale
+
+
+def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor:
+    """Concatenate ``a`` and ``b`` along ``dim`` after weighting each input.
+
+    The weights depend on the blend factor ``t`` and on the sizes of ``a`` and ``b``
+    along ``dim``: ``a`` is scaled by ``c / sqrt(na) * (1 - t)`` and ``b`` by
+    ``c / sqrt(nb) * t`` with ``c = sqrt((na + nb) / ((1 - t)^2 + t^2))``.
+    """
+    size_a, size_b = a.shape[dim], b.shape[dim]
+    common = math.sqrt((size_a + size_b) / ((1 - t) ** 2 + t ** 2))
+    weight_a = common / math.sqrt(size_a) * (1 - t)
+    weight_b = common / math.sqrt(size_b) * t
+    pieces = [weight_a * a, weight_b * b]
+    return torch.cat(pieces, dim=dim)
+
+
+class MPFourier(torch.nn.Module):
+    """Random Fourier features: ``sqrt(2) * cos(x * freqs + phases)``.
+
+    ``freqs`` and ``phases`` are drawn once at construction and stored as buffers,
+    so they are saved in the state dict but are not parameters.
+    """
+
+    def __init__(self, num_channels: int, bandwidth: float = 1):
+        super().__init__()
+        # Draw order (randn, then rand) is kept so a fixed seed yields the same buffers.
+        freqs = 2 * math.pi * torch.randn(num_channels) * bandwidth
+        phases = 2 * math.pi * torch.rand(num_channels)
+        self.register_buffer("freqs", freqs)
+        self.register_buffer("phases", phases)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Map a 1-D tensor ``x`` to features of shape ``(len(x), num_channels)`` in ``x``'s dtype."""
+        # The features are evaluated in float32 and cast back at the end.
+        values = x.to(torch.float32)
+        angles = torch.outer(values, self.freqs.to(torch.float32))
+        angles = angles + self.phases.to(torch.float32)
+        features = angles.cos() * math.sqrt(2)
+        return features.to(x.dtype)
+
+
+class MPConv(torch.nn.Module):
+    """Linear or 2-D convolution layer whose weight is normalised on every forward pass.
+
+    An empty ``kernel`` gives a 2-D weight that is applied as a matrix product; any other
+    kernel gives a weight applied with ``conv2d`` and "same"-style padding.
+    """
+
+    def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]):
+        super().__init__()
+        self.out_channels = out_channels
+        weight_shape = (out_channels, in_channels, *kernel)
+        self.weight = torch.nn.Parameter(torch.randn(weight_shape))
+
+    def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor:
+        """Apply the layer to ``x``; ``gain`` multiplies the normalised weight."""
+        weight = self.weight.to(torch.float32)
+        if self.training:
+            # In training mode the stored parameter itself is overwritten with its
+            # normalised value (outside autograd).
+            with torch.no_grad():
+                self.weight.copy_(normalize(weight))
+        weight = normalize(weight)
+        fan_in = weight[0].numel()
+        weight = (weight * (gain / math.sqrt(fan_in))).to(x.dtype)
+        if weight.ndim != 2:
+            return torch.nn.functional.conv2d(x, weight, padding=(weight.shape[-1] // 2,))
+        return x @ weight.t()
+
+
+class Block(torch.nn.Module):
+    """Residual U-Net block with embedding modulation and optional self-attention.
+
+    Args:
+        in_channels: Channels of the incoming feature map.
+        out_channels: Channels produced by the block.
+        emb_channels: Width of the conditioning embedding passed to ``forward``.
+        flavor: ``"enc"`` applies the skip convolution and a channel-wise ``normalize``
+            before the residual branch; ``"dec"`` applies the skip convolution after it.
+        resample_mode: Forwarded to ``resample`` (``"keep"``, ``"down"`` or ``"up"``).
+        resample_filter: Filter taps forwarded to ``resample``; defaults to ``[1, 1]``.
+        attention: Whether to add the self-attention stage.
+        channels_per_head: Channels per attention head (``out_channels // channels_per_head`` heads).
+        dropout: Dropout probability, applied only in training mode.
+        res_balance: ``t`` used by ``mp_sum`` when merging the residual branch.
+        attn_balance: ``t`` used by ``mp_sum`` when merging the attention branch.
+        clip_act: Symmetric clamp applied to the output; ``None`` disables it.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        emb_channels: int,
+        flavor: str = "enc",
+        resample_mode: str = "keep",
+        resample_filter: Optional[List[float]] = None,
+        attention: bool = False,
+        channels_per_head: int = 64,
+        dropout: float = 0.0,
+        res_balance: float = 0.3,
+        attn_balance: float = 0.3,
+        clip_act: Optional[float] = 256,
+    ):
+        super().__init__()
+        self.out_channels = out_channels
+        self.flavor = flavor
+        # A list literal as the default would be one object shared by every Block;
+        # each instance now owns its own copy.
+        self.resample_filter = [1, 1] if resample_filter is None else list(resample_filter)
+        self.resample_mode = resample_mode
+        self.num_heads = out_channels // channels_per_head if attention else 0
+        self.dropout = dropout
+        self.res_balance = res_balance
+        self.attn_balance = attn_balance
+        self.clip_act = clip_act
+        self.emb_gain = torch.nn.Parameter(torch.zeros([]))
+        self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3))
+        self.emb_linear = MPConv(emb_channels, out_channels, kernel=())
+        self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3))
+        self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None
+        self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None
+        self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None
+
+    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+        """Run the block on feature map ``x`` conditioned on embedding ``emb``."""
+        # Main branch: optional resampling, then (encoder only) skip conv + channel normalisation.
+        x = resample(x, f=self.resample_filter, mode=self.resample_mode)
+        if self.flavor == "enc":
+            if self.conv_skip is not None:
+                x = self.conv_skip(x)
+            x = normalize(x, dim=[1])
+
+        # Residual branch: conv -> per-channel modulation by (emb_linear(emb) + 1) -> conv.
+        y = self.conv_res0(mp_silu(x))
+        c = self.emb_linear(emb, gain=self.emb_gain) + 1
+        y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype))
+        if self.training and self.dropout:
+            y = torch.nn.functional.dropout(y, p=self.dropout)
+        y = self.conv_res1(y)
+
+        # Decoder blocks apply the skip conv only now, after the residual branch used the raw input.
+        if self.flavor == "dec" and self.conv_skip is not None:
+            x = self.conv_skip(x)
+        x = mp_sum(x, y, t=self.res_balance)
+
+        if self.num_heads:
+            # Split the 1x1 conv output into (heads, channels, q/k/v, pixels) and normalise per head.
+            y = self.attn_qkv(x)
+            y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3])
+            q, k, v = normalize(y, dim=[2]).unbind(3)
+            w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3)
+            y = torch.einsum("nhqk,nhck->nhcq", w, v)
+            y = self.attn_proj(y.reshape(*x.shape))
+            x = mp_sum(x, y, t=self.attn_balance)
+
+        if self.clip_act is not None:
+            x = x.clip_(-self.clip_act, self.clip_act)
+        return x
+
+
+class EDM2UNet(torch.nn.Module):
+    """Encoder/decoder U-Net assembled from ``MPConv`` and ``Block`` layers.
+
+    Layers live in two ``ModuleDict``s (``enc`` and ``dec``) keyed by
+    ``"{res}x{res}_{role}"``. ``forward`` walks both dicts in insertion order and
+    decides from the key how to call each layer, so construction order matters.
+
+    Args:
+        img_resolution: Resolution at the first level; halved (``>> level``) per level.
+        img_channels: Input and output channels (one extra all-ones channel is appended internally).
+        label_dim: Width of the class-label input; ``0`` disables label conditioning.
+        model_channels: Base channel count multiplied by each ``channel_mult`` entry.
+        channel_mult: Per-level channel multipliers.
+        channel_mult_noise: Multiplier for the noise-embedding width (default: first level width).
+        channel_mult_emb: Multiplier for the embedding width (default: widest level).
+        num_blocks: Encoder blocks per level; the decoder uses ``num_blocks + 1``.
+        attn_resolutions: Resolutions whose blocks use self-attention.
+        label_balance: ``t`` for ``mp_sum`` when mixing noise and label embeddings.
+        concat_balance: ``t`` for ``mp_cat`` when joining skip connections.
+        **block_kwargs: Forwarded unchanged to every ``Block``.
+    """
+
+    def __init__(
+        self,
+        img_resolution: int,
+        img_channels: int,
+        label_dim: int,
+        model_channels: int = 192,
+        channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
+        channel_mult_noise: Optional[int] = None,
+        channel_mult_emb: Optional[int] = None,
+        num_blocks: int = 3,
+        attn_resolutions: Tuple[int, ...] = (16, 8),
+        label_balance: float = 0.5,
+        concat_balance: float = 0.5,
+        **block_kwargs,
+    ):
+        super().__init__()
+        level_channels = [model_channels * mult for mult in channel_mult]
+        if channel_mult_noise is not None:
+            noise_channels = model_channels * channel_mult_noise
+        else:
+            noise_channels = level_channels[0]
+        if channel_mult_emb is not None:
+            emb_channels = model_channels * channel_mult_emb
+        else:
+            emb_channels = max(level_channels)
+        self.label_balance = label_balance
+        self.concat_balance = concat_balance
+        self.out_gain = torch.nn.Parameter(torch.zeros([]))
+
+        self.emb_fourier = MPFourier(noise_channels)
+        self.emb_noise = MPConv(noise_channels, emb_channels, kernel=())
+        self.emb_label = MPConv(label_dim, emb_channels, kernel=()) if label_dim else None
+
+        # Encoder: ``width`` tracks the channel count flowing out of the last layer added.
+        self.enc = torch.nn.ModuleDict()
+        width = img_channels + 1
+        for level, channels in enumerate(level_channels):
+            res = img_resolution >> level
+            prefix = f"{res}x{res}"
+            use_attention = res in attn_resolutions
+            if level > 0:
+                self.enc[f"{prefix}_down"] = Block(
+                    width, width, emb_channels, flavor="enc", resample_mode="down", **block_kwargs
+                )
+            else:
+                self.enc[f"{prefix}_conv"] = MPConv(width, channels, kernel=(3, 3))
+                width = channels
+            for idx in range(num_blocks):
+                self.enc[f"{prefix}_block{idx}"] = Block(
+                    width,
+                    channels,
+                    emb_channels,
+                    flavor="enc",
+                    attention=use_attention,
+                    **block_kwargs,
+                )
+                width = channels
+
+        # Decoder: levels are visited deepest-first; every "block" layer also consumes one
+        # encoder output, popped from the end of ``skip_widths``.
+        self.dec = torch.nn.ModuleDict()
+        skip_widths = [layer.out_channels for layer in self.enc.values()]
+        deepest = len(level_channels) - 1
+        for level in range(deepest, -1, -1):
+            channels = level_channels[level]
+            res = img_resolution >> level
+            prefix = f"{res}x{res}"
+            use_attention = res in attn_resolutions
+            if level != deepest:
+                self.dec[f"{prefix}_up"] = Block(
+                    width, width, emb_channels, flavor="dec", resample_mode="up", **block_kwargs
+                )
+            else:
+                self.dec[f"{prefix}_in0"] = Block(
+                    width, width, emb_channels, flavor="dec", attention=True, **block_kwargs
+                )
+                self.dec[f"{prefix}_in1"] = Block(width, width, emb_channels, flavor="dec", **block_kwargs)
+            for idx in range(num_blocks + 1):
+                merged = width + skip_widths.pop()
+                self.dec[f"{prefix}_block{idx}"] = Block(
+                    merged,
+                    channels,
+                    emb_channels,
+                    flavor="dec",
+                    attention=use_attention,
+                    **block_kwargs,
+                )
+                width = channels
+
+        self.out_conv = MPConv(width, img_channels, kernel=(3, 3))
+
+    def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor:
+        """Run the network.
+
+        Args:
+            x: Input feature map.
+            noise_labels: 1-D noise conditioning values fed to the Fourier embedding.
+            class_labels: 2-D label tensor; required when the model was built with ``label_dim > 0``.
+
+        Raises:
+            ValueError: If the model is label-conditioned and ``class_labels`` is ``None``.
+        """
+        emb = self.emb_noise(self.emb_fourier(noise_labels))
+        if self.emb_label is not None:
+            if class_labels is None:
+                raise ValueError("class_labels are required for conditional EDM2UNet.")
+            label_emb = self.emb_label(class_labels * math.sqrt(class_labels.shape[1]))
+            emb = mp_sum(emb, label_emb, t=self.label_balance)
+        emb = mp_silu(emb)
+
+        # Append a constant all-ones channel to the input.
+        ones = torch.ones_like(x[:, :1])
+        h = torch.cat([x, ones], dim=1)
+        skips = []
+        for name, layer in self.enc.items():
+            # Only the plain input convolution ("..._conv") takes no embedding.
+            if "conv" in name:
+                h = layer(h)
+            else:
+                h = layer(h, emb)
+            skips.append(h)
+
+        for name, layer in self.dec.items():
+            if "block" in name:
+                h = mp_cat(h, skips.pop(), t=self.concat_balance)
+            h = layer(h, emb)
+        return self.out_conv(h, gain=self.out_gain)
+
+
+# Return container of ``EDM2UNet2DModel.forward`` (used when ``return_dict`` is true).
+@dataclass
+class EDM2UNet2DOutput(BaseOutput):
+ # Preconditioned model output ``c_skip * x + c_out * F(x)`` as computed in ``forward``.
+ sample: torch.Tensor
+ # Filled only when ``forward`` is called with ``return_logvar=True``; ``None`` otherwise.
+ logvar: Optional[torch.Tensor] = None
+
+
+
+# Constructor argument names of ``EDM2UNet2DModel`` that are persisted: ``from_pretrained``
+# keeps only these keys from ``config.json`` and ``save_pretrained`` writes only these back.
+_CONFIG_KEYS = (
+ "sample_size",
+ "in_channels",
+ "out_channels",
+ "num_class_embeds",
+ "use_fp16",
+ "sigma_data",
+ "logvar_channels",
+ "model_channels",
+ "channel_mult",
+ "channel_mult_noise",
+ "channel_mult_emb",
+ "num_blocks",
+ "attn_resolutions",
+ "label_balance",
+ "concat_balance",
+ "dropout",
+ "channels_per_head",
+ "res_balance",
+ "attn_balance",
+ "clip_act",
+)
+
+
+class EDM2UNet2DModel(ModelMixin, ConfigMixin):
+    """EDM-preconditioned wrapper around ``EDM2UNet``.
+
+    ``forward`` takes a noisy sample and its noise level ``sigma`` and returns
+    ``c_skip * x + c_out * F(c_in * x, log(sigma) / 4, labels)``, where the ``c_*``
+    coefficients are derived from ``sigma`` and ``sigma_data``.
+
+    Args:
+        sample_size: Spatial resolution passed to the inner U-Net.
+        in_channels: Channels of the input sample (also used as the U-Net output width).
+        out_channels: Stored in the config; not used to build the network here.
+        num_class_embeds: Width of the class-label input; ``0`` means unconditional.
+        use_fp16: Run the inner U-Net in float16 on CUDA unless ``force_fp32`` is set.
+        sigma_data: Data standard deviation used by the preconditioning coefficients.
+        logvar_channels: Width of the Fourier features feeding the log-variance head.
+        model_channels, channel_mult, channel_mult_noise, channel_mult_emb, num_blocks,
+        attn_resolutions, label_balance, concat_balance, dropout, channels_per_head,
+        res_balance, attn_balance, clip_act: Forwarded to ``EDM2UNet``.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        sample_size: int = 64,
+        in_channels: int = 4,
+        out_channels: int = 4,
+        num_class_embeds: int = 0,
+        use_fp16: bool = True,
+        sigma_data: float = 0.5,
+        logvar_channels: int = 128,
+        model_channels: int = 192,
+        channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
+        channel_mult_noise: Optional[int] = None,
+        channel_mult_emb: Optional[int] = None,
+        num_blocks: int = 3,
+        attn_resolutions: Tuple[int, ...] = (16, 8),
+        label_balance: float = 0.5,
+        concat_balance: float = 0.5,
+        dropout: float = 0.0,
+        channels_per_head: int = 64,
+        res_balance: float = 0.3,
+        attn_balance: float = 0.3,
+        clip_act: Optional[float] = 256,
+    ):
+        super().__init__()
+        self.sample_size = sample_size
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.num_class_embeds = num_class_embeds
+        self.use_fp16 = use_fp16
+        self.sigma_data = sigma_data
+        # Kept as an attribute so ``save_pretrained`` can still record it when ``config``
+        # is empty (the attribute fallback there looks it up with ``hasattr``).
+        self.logvar_channels = logvar_channels
+        self.model_channels = model_channels
+        self.channel_mult = channel_mult
+        self.channel_mult_noise = channel_mult_noise
+        self.channel_mult_emb = channel_mult_emb
+        self.num_blocks = num_blocks
+        self.attn_resolutions = attn_resolutions
+        self.label_balance = label_balance
+        self.concat_balance = concat_balance
+        self.dropout = dropout
+        self.channels_per_head = channels_per_head
+        self.res_balance = res_balance
+        self.attn_balance = attn_balance
+        self.clip_act = clip_act
+        self.unet = EDM2UNet(
+            img_resolution=sample_size,
+            img_channels=in_channels,
+            label_dim=num_class_embeds,
+            model_channels=model_channels,
+            channel_mult=channel_mult,
+            channel_mult_noise=channel_mult_noise,
+            channel_mult_emb=channel_mult_emb,
+            num_blocks=num_blocks,
+            attn_resolutions=attn_resolutions,
+            label_balance=label_balance,
+            concat_balance=concat_balance,
+            dropout=dropout,
+            channels_per_head=channels_per_head,
+            res_balance=res_balance,
+            attn_balance=attn_balance,
+            clip_act=clip_act,
+        )
+        self.logvar_fourier = MPFourier(logvar_channels)
+        self.logvar_linear = MPConv(logvar_channels, 1, kernel=())
+
+    def forward(
+        self,
+        sample: torch.Tensor,
+        sigma: torch.Tensor,
+        class_labels: Optional[torch.Tensor] = None,
+        force_fp32: bool = False,
+        return_logvar: bool = False,
+        return_dict: bool = True,
+    ) -> EDM2UNet2DOutput:
+        """Denoise ``sample`` at noise level ``sigma``.
+
+        Args:
+            sample: Noisy input; converted to float32.
+            sigma: Noise level(s); reshaped to ``(-1, 1, 1, 1)`` for broadcasting.
+            class_labels: Label tensor reshaped to ``(-1, num_class_embeds)``. Ignored for
+                unconditional models; replaced by zeros when omitted on conditional ones.
+            force_fp32: Disable the float16 path even when ``use_fp16`` is set.
+            return_logvar: Also evaluate the log-variance head.
+            return_dict: Return ``EDM2UNet2DOutput`` instead of a ``(sample, logvar)`` tuple.
+        """
+        x = sample.to(torch.float32)
+        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
+        if self.num_class_embeds == 0:
+            class_labels = None
+        else:
+            if class_labels is None:
+                class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device)
+            class_labels = class_labels.to(torch.float32).reshape(-1, self.num_class_embeds)
+        # float16 is only used on CUDA; everything outside the inner U-Net stays in float32.
+        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") else torch.float32
+
+        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
+        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
+        c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
+        c_noise = sigma.flatten().log() / 4
+
+        x_in = (c_in * x).to(dtype)
+        f_x = self.unet(x_in, c_noise, class_labels)
+        d_x = c_skip * x + c_out * f_x.to(torch.float32)
+
+        logvar = None
+        if return_logvar:
+            logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)
+
+        if not return_dict:
+            return (d_x, logvar)
+        return EDM2UNet2DOutput(sample=d_x, logvar=logvar)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs):
+        """Build the model from a local directory holding ``config.json`` and weights.
+
+        Only local paths are handled: the config and weight files are opened directly
+        from ``pretrained_model_name_or_path`` (joined with ``subfolder`` if given).
+        Keyword arguments other than ``subfolder`` are accepted but not used.
+
+        Args:
+            pretrained_model_name_or_path: Directory containing the model files.
+            torch_dtype: Optional dtype the loaded model is cast to.
+            **kwargs: ``subfolder`` is honoured; everything else is ignored.
+
+        Returns:
+            The model with weights loaded (strict key matching), in evaluation mode.
+        """
+        subfolder = kwargs.pop("subfolder", None)
+        model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path
+        with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f:
+            config = json.load(f)
+        init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS}
+        model = cls(**init_kwargs)
+        weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors")
+        if os.path.isfile(weight_file):
+            from safetensors.torch import load_file
+
+            state_dict = load_file(weight_file)
+        else:
+            # SECURITY NOTE: ``.bin`` checkpoints are pickle files; loading one from an
+            # untrusted source can execute arbitrary code. Prefer the safetensors file.
+            state_dict = torch.load(os.path.join(model_dir, "diffusion_pytorch_model.bin"), map_location="cpu")
+        model.load_state_dict(state_dict, strict=True)
+        if torch_dtype is not None:
+            model = model.to(dtype=torch_dtype)
+        # A freshly constructed module is in training mode, in which ``MPConv.forward``
+        # overwrites its stored weights on every call; loaded checkpoints must not be mutated.
+        model.eval()
+        return model
+
+    def save_pretrained(self, save_directory: str, safe_serialization: bool = True):
+        """Write ``config.json`` and the weights into ``save_directory``.
+
+        Args:
+            save_directory: Target directory; created if missing.
+            safe_serialization: Write ``.safetensors`` when true, a ``torch.save`` ``.bin`` otherwise.
+        """
+        os.makedirs(save_directory, exist_ok=True)
+        stored = dict(getattr(self, "config", {}))
+        config = {"_class_name": self.__class__.__name__}
+        for key in _CONFIG_KEYS:
+            # Prefer the registered config; fall back to the attribute of the same name.
+            if key in stored:
+                config[key] = stored[key]
+            elif hasattr(self, key):
+                config[key] = getattr(self, key)
+        with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f:
+            json.dump(config, f, indent=2, sort_keys=True)
+            f.write("\n")
+        state_dict = self.state_dict()
+        if safe_serialization:
+            from safetensors.torch import save_file
+
+            save_file(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.safetensors"))
+        else:
+            torch.save(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.bin"))
diff --git a/edm2-img512-m-fid/vae/config.json b/edm2-img512-m-fid/vae/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..ae2ab37de2acb8bc348d3de809ae8385324be962
--- /dev/null
+++ b/edm2-img512-m-fid/vae/config.json
@@ -0,0 +1,38 @@
+{
+ "_class_name": "AutoencoderKL",
+ "_diffusers_version": "0.36.0",
+ "_name_or_path": "stabilityai/sd-vae-ft-mse",
+ "act_fn": "silu",
+ "block_out_channels": [
+ 128,
+ 256,
+ 512,
+ 512
+ ],
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D"
+ ],
+ "force_upcast": true,
+ "in_channels": 3,
+ "latent_channels": 4,
+ "latents_mean": null,
+ "latents_std": null,
+ "layers_per_block": 2,
+ "mid_block_add_attention": true,
+ "norm_num_groups": 32,
+ "out_channels": 3,
+ "sample_size": 256,
+ "scaling_factor": 0.18215,
+ "shift_factor": null,
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D"
+ ],
+ "use_post_quant_conv": true,
+ "use_quant_conv": true
+}
diff --git a/edm2-img512-m-fid/vae/diffusion_pytorch_model.safetensors b/edm2-img512-m-fid/vae/diffusion_pytorch_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..90464d67ac7303d0ee4696334df13da130a948ea
--- /dev/null
+++ b/edm2-img512-m-fid/vae/diffusion_pytorch_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815
+size 334643276
diff --git a/edm2-img512-s-fid/demo.png b/edm2-img512-s-fid/demo.png
new file mode 100644
index 0000000000000000000000000000000000000000..aed38cdc126d7bffa89fbadc5e8f61c3afd45c16
--- /dev/null
+++ b/edm2-img512-s-fid/demo.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:58bdb49e30c85b02b9e3619a11b39b1ec760452e8ad96cea1c5856e99df39d42
+size 381489
diff --git a/edm2-img512-s-fid/model_index.json b/edm2-img512-s-fid/model_index.json
new file mode 100644
index 0000000000000000000000000000000000000000..24c0305eea3ec64abbc2f870476afa4f5e462f42
--- /dev/null
+++ b/edm2-img512-s-fid/model_index.json
@@ -0,0 +1,19 @@
+{
+ "_class_name": [
+ "pipeline",
+ "EDM2Pipeline"
+ ],
+ "_diffusers_version": "0.31.0",
+ "scheduler": [
+ "diffusers",
+ "EDMEulerScheduler"
+ ],
+ "unet": [
+ "unet_edm2",
+ "EDM2UNet2DModel"
+ ],
+ "vae": [
+ "diffusers",
+ "AutoencoderKL"
+ ]
+}
diff --git a/edm2-img512-s-fid/pipeline.py b/edm2-img512-s-fid/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b4131215d07770c74efa0674d00f3d6423d9c06
--- /dev/null
+++ b/edm2-img512-s-fid/pipeline.py
@@ -0,0 +1,406 @@
+"""Hub custom pipeline: EDM2Pipeline.
+Load with native Hugging Face diffusers and trust_remote_code=True.
+"""
+
+from __future__ import annotations
+
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from pathlib import Path
+from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from diffusers.utils import replace_example_docstring
+from diffusers.utils.torch_utils import randn_tensor
+
+# Usage example text; ``EDM2Pipeline.__call__`` is decorated with
+# ``replace_example_docstring(EXAMPLE_DOC_STRING)``.
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from pathlib import Path
+ >>> import torch
+ >>> from diffusers import DiffusionPipeline
+
+ >>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve()
+ >>> pipe = DiffusionPipeline.from_pretrained(
+ ... str(model_dir),
+ ... local_files_only=True,
+ ... custom_pipeline=str(model_dir / "pipeline.py"),
+ ... trust_remote_code=True,
+ ... torch_dtype=torch.float32,
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> generator = torch.Generator(device="cuda").manual_seed(42)
+ >>> image = pipe(
+ ... class_labels=207,
+ ... num_inference_steps=32,
+ ... guidance_scale=1.0,
+ ... generator=generator,
+ ... ).images[0]
+ >>> image.save("demo.png")
+ ```
+"""
+
+# Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py).
+# Four-entry per-channel arrays: scale is 0.5 divided by the first list, bias is minus the
+# second list times scale. ``decode_latents`` undoes the whitening with ``(x - bias) / scale``.
+_STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28])
+_STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE
+
+class EDM2Pipeline(DiffusionPipeline):
+    r"""
+    Pipeline for class-conditional image generation with EDM2
+    ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)).
+
+    Parameters:
+        unet ([`EDM2UNet2DModel`]):
+            Main magnitude-preserving U-Net with EDM preconditioning.
+        scheduler ([`EDMEulerScheduler`]):
+            Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in
+            the pipeline because the UNet returns denoised latents rather than noise predictions.
+        vae ([`AutoencoderKL`], *optional*):
+            Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`.
+        gnet ([`EDM2UNet2DModel`], *optional*):
+            Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`).
+        id2label (`dict[int, str]`, *optional*):
+            ImageNet class id to English label mapping.
+    """
+
+    model_cpu_offload_seq = "unet->gnet->vae"
+    _optional_components = ["vae", "gnet"]
+
+    def __init__(
+        self,
+        unet,
+        scheduler,
+        vae=None,
+        gnet=None,
+        id2label: Optional[Dict[Union[int, str], str]] = None,
+    ) -> None:
+        super().__init__()
+        self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet)
+        self._id2label = self._normalize_id2label(id2label)
+        self.labels = self._build_label2id(self._id2label)
+        self._labels_loaded_from_model_index = bool(self._id2label)
+        # With a VAE the pipeline works on latents that are 8x smaller than the image.
+        self.vae_scale_factor = 8 if self.vae is not None else 1
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
+
+    @staticmethod
+    def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
+        """Return ``id2label`` with integer keys; ``None`` or empty input gives ``{}``."""
+        if not id2label:
+            return {}
+        return {int(key): value for key, value in id2label.items()}
+
+    @staticmethod
+    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
+        """Invert ``id2label``: every comma-separated synonym maps to its class id.
+
+        The result is sorted by label. If a synonym occurs for several ids, the id
+        seen last wins.
+        """
+        label2id: Dict[str, int] = {}
+        for class_id, value in id2label.items():
+            for synonym in value.split(","):
+                synonym = synonym.strip()
+                if synonym:
+                    label2id[synonym] = int(class_id)
+        return dict(sorted(label2id.items()))
+
+    def _ensure_labels_loaded(self) -> None:
+        """Lazily load ``id2label`` from ``model_index.json`` if it was not passed to ``__init__``."""
+        if self._labels_loaded_from_model_index:
+            return
+        loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
+        if loaded:
+            self._id2label = loaded
+            self.labels = self._build_label2id(self._id2label)
+            self._labels_loaded_from_model_index = True
+
+    @staticmethod
+    def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
+        """Read the ``id2label`` dict from ``<variant_path>/model_index.json``.
+
+        Returns ``{}`` when the path is empty, the file is missing, or the entry is not a dict.
+        """
+        if not variant_path:
+            return {}
+        model_index_path = Path(variant_path).resolve() / "model_index.json"
+        if not model_index_path.is_file():
+            return {}
+        raw = json.loads(model_index_path.read_text(encoding="utf-8"))
+        id2label = raw.get("id2label")
+        if not isinstance(id2label, dict):
+            return {}
+        return EDM2Pipeline._normalize_id2label(id2label)
+
+    @property
+    def id2label(self) -> Dict[int, str]:
+        r"""ImageNet class id to English label string (comma-separated synonyms)."""
+        self._ensure_labels_loaded()
+        return self._id2label
+
+    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
+        r"""
+        Map ImageNet label strings to class ids.
+
+        Args:
+            label (`str` or `list[str]`):
+                One or more English label strings that match entries in `id2label`.
+
+        Raises:
+            ValueError: If no labels are available or any requested label is unknown.
+        """
+        self._ensure_labels_loaded()
+        if not self.labels:
+            raise ValueError("No English labels loaded. Add `id2label` to model_index.json.")
+        labels = [label] if isinstance(label, str) else list(label)
+        missing = [item for item in labels if item not in self.labels]
+        if missing:
+            preview = ", ".join(list(self.labels.keys())[:8])
+            raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
+        return [self.labels[item] for item in labels]
+
+    def _default_image_size(self) -> int:
+        """Native output size in pixels: the UNet ``sample_size`` times ``vae_scale_factor``."""
+        latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64)))
+        return latent_size * self.vae_scale_factor
+
+    def check_inputs(
+        self,
+        height: int,
+        width: int,
+        num_inference_steps: int,
+        guidance_scale: float,
+        output_type: str,
+    ) -> None:
+        """Validate the call arguments and raise ``ValueError`` on the first violation."""
+        if num_inference_steps < 1:
+            raise ValueError("num_inference_steps must be >= 1.")
+        if guidance_scale < 1.0:
+            raise ValueError("guidance_scale must be >= 1.0.")
+        if guidance_scale > 1.0 and self.gnet is None:
+            raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).")
+        if output_type not in {"pil", "np", "pt", "latent"}:
+            raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
+
+        native_size = self._default_image_size()
+        if height != native_size or width != native_size:
+            raise ValueError(
+                f"EDM2 expects native resolution height=width={native_size}. "
+                f"Got height={height}, width={width}."
+            )
+
+    def _normalize_class_labels(
+        self,
+        class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]],
+        batch_size: int,
+        device: torch.device,
+    ) -> Optional[torch.Tensor]:
+        """Convert any supported label input into a float32 one-hot tensor.
+
+        Args:
+            class_labels: ``None`` (random classes), a class index, a label string, a
+                sequence of either, a 1-D index tensor, or a 2-D tensor used as given.
+            batch_size: Number of samples; a single label is repeated to this size.
+            device: Device of the returned tensor.
+
+        Returns:
+            A ``(batch_size, num_class_embeds)`` tensor, or ``None`` for unconditional models.
+
+        Raises:
+            ValueError: On a batch-size mismatch, a 2-D tensor of the wrong width, or a
+                class index outside ``[0, num_class_embeds)``.
+        """
+        label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0)))
+        if label_dim == 0:
+            return None
+        if class_labels is None:
+            indices = torch.randint(label_dim, size=(batch_size,), device=device)
+            return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
+
+        # Label strings are resolved to class ids first.
+        if isinstance(class_labels, str):
+            class_labels = self.get_label_ids(class_labels)[0]
+        elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str):
+            class_labels = self.get_label_ids(list(class_labels))
+
+        if isinstance(class_labels, int):
+            indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long)
+        elif isinstance(class_labels, torch.Tensor):
+            if class_labels.ndim == 2:
+                labels = class_labels.to(device=device, dtype=torch.float32)
+                if labels.shape[0] != batch_size:
+                    raise ValueError(f"class_labels batch must match batch_size={batch_size}.")
+                if labels.shape[1] != label_dim:
+                    raise ValueError(f"class_labels must have {label_dim} columns, got {labels.shape[1]}.")
+                return labels
+            indices = class_labels.to(device=device, dtype=torch.long).flatten()
+        else:
+            indices = torch.tensor(list(class_labels), device=device, dtype=torch.long)
+
+        # Reject bad ids up front: negative ids would silently wrap around when indexing,
+        # and too-large ids would otherwise fail inside the indexing op itself.
+        if indices.numel() > 0 and (int(indices.min()) < 0 or int(indices.max()) >= label_dim):
+            raise ValueError(f"class_labels must be class indices in [0, {label_dim - 1}].")
+
+        if indices.numel() == 1 and batch_size > 1:
+            indices = indices.repeat(batch_size)
+        if indices.numel() != batch_size:
+            raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.")
+        return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
+
+    def prepare_latents(
+        self,
+        batch_size: int,
+        height: int,
+        width: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]],
+    ) -> torch.Tensor:
+        """Draw the initial unit-variance noise in latent space.
+
+        The noise is always float32; ``dtype`` is accepted for signature compatibility
+        but not used. ``height`` and ``width`` are in pixels.
+        """
+        in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4)))
+        latent_height = height // self.vae_scale_factor
+        # ``width`` used to be ignored; identical for the square sizes ``check_inputs`` allows.
+        latent_width = width // self.vae_scale_factor
+        return randn_tensor(
+            (batch_size, in_channels, latent_height, latent_width),
+            generator=generator,
+            device=device,
+            dtype=torch.float32,
+        )
+
+    def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
+        """Turn sampler output into images of the requested ``output_type``.
+
+        ``"latent"`` returns ``latents`` unchanged. Without a VAE the latents are
+        treated as pixels; with one, 4-channel latents are un-whitened before decoding.
+        """
+        if output_type == "latent":
+            return latents
+
+        in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 3)))
+        if self.vae is None:
+            # Pixel-space model: map values to 0..255, clip, and rescale to 0..1.
+            image = (latents.to(torch.float32) * 127.5 + 128).clip(0, 255) / 255.0
+            return self.image_processor.postprocess(image, output_type=output_type)
+
+        x = latents.to(torch.float32)
+        if in_channels == 4:
+            # Invert the per-channel latent whitening (x * scale + bias) before decoding.
+            scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
+            bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
+            x = (x - bias) / scale
+
+        vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype
+        image = self.vae.decode(x.to(dtype=vae_dtype)).sample.to(torch.float32).clamp(0, 1)
+
+        return self.image_processor.postprocess(image, output_type=output_type)
+
+    @staticmethod
+    def _apply_autoguidance(
+        main: torch.Tensor,
+        ref: torch.Tensor,
+        guidance_scale: float,
+    ) -> torch.Tensor:
+        """Combine the outputs of the main and guiding networks: ``ref.lerp(main, guidance_scale)``."""
+        return ref.lerp(main, guidance_scale)
+
+    @staticmethod
+    def _sample_edm2_heun(
+        denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+        noise: torch.Tensor,
+        sigmas: torch.Tensor,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        progress_bar: Optional[Callable[[Iterable], Iterable]] = None,
+        dtype: torch.dtype = torch.float32,
+    ) -> torch.Tensor:
+        """NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0).
+
+        Args:
+            denoise_fn: Maps ``(x, sigma)`` to the denoised estimate of ``x``.
+            noise: Initial noise; scaled by ``sigmas[0]`` to start the trajectory.
+            sigmas: 1-D noise schedule; consecutive entries form the integration steps.
+            generator: Accepted for API symmetry; the sampler itself draws no randomness.
+            progress_bar: Optional wrapper applied to the list of steps (e.g. a tqdm factory).
+            dtype: Dtype the trajectory is computed in.
+        """
+        x_next = noise.to(dtype) * sigmas[0]
+
+        sigma_pairs = list(zip(sigmas[:-1], sigmas[1:]))
+        # Count the steps on the list itself: a progress-bar wrapper is only promised to
+        # be iterable and need not implement ``__len__``.
+        num_steps = len(sigma_pairs)
+        if progress_bar is not None:
+            sigma_pairs = progress_bar(sigma_pairs)
+
+        for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs):
+            x_hat, sigma_hat = x_next, sigma_cur
+            # Euler step along the slope at the current noise level.
+            d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat
+            x_next = x_hat + (sigma_next - sigma_hat) * d_cur
+            # Second-order correction on every step but the last, which keeps the plain
+            # Euler update and so never divides by ``sigma_next``.
+            if i < num_steps - 1:
+                d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next
+                x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime)
+        return x_next
+
+    @torch.inference_mode()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None,
+        batch_size: int = 1,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 32,
+        guidance_scale: float = 1.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        output_type: str = "pil",
+        return_dict: bool = True,
+    ) -> Union[ImagePipelineOutput, Tuple]:
+        r"""
+        Generate class-conditional images with EDM2.
+
+        Args:
+            class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*):
+                ImageNet class indices, English label strings, or one-hot float tensors.
+                Random classes are sampled when omitted on conditional models.
+            batch_size (`int`, defaults to `1`):
+                Number of images to generate.
+            height (`int`, *optional*):
+                Output height in pixels. Defaults to the pretrained native resolution.
+            width (`int`, *optional*):
+                Output width in pixels. Defaults to the pretrained native resolution.
+            num_inference_steps (`int`, defaults to `32`):
+                Number of EDM2 Heun steps (NVlabs default).
+            guidance_scale (`float`, defaults to `1.0`):
+                Autoguidance strength. Values above `1.0` blend the main net with `gnet`
+                via `gnet_output.lerp(unet_output, guidance_scale)`.
+            generator (`torch.Generator`, *optional*):
+                RNG for reproducibility.
+            output_type (`str`, defaults to `"pil"`):
+                `"pil"`, `"np"`, `"pt"`, or `"latent"`.
+            return_dict (`bool`, defaults to `True`):
+                Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if True.
+
+        Examples:
+
+        """
+        default_size = self._default_image_size()
+        height = int(height or default_size)
+        width = int(width or default_size)
+        self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type)
+
+        device = self._execution_device
+        dtype = self.unet.dtype
+        labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
+        noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
+
+        def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+            # The sampler passes a scalar sigma; the UNet gets one value per sample.
+            sigma_batch = sigma.reshape(1).expand(batch_size)
+            main = self.unet(
+                sample=x,
+                sigma=sigma_batch,
+                class_labels=labels,
+                force_fp32=True,
+            ).sample
+            if guidance_scale == 1.0 or self.gnet is None:
+                return main.to(torch.float32)
+            ref = self.gnet(
+                sample=x,
+                sigma=sigma_batch,
+                class_labels=labels,
+                force_fp32=True,
+            ).sample
+            return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32)
+
+        # The scheduler is used only for its sigma schedule; stepping happens in the sampler.
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        latents = self._sample_edm2_heun(
+            denoise_fn=denoise_fn,
+            noise=noise,
+            sigmas=self.scheduler.sigmas.to(device),
+            generator=generator,
+            progress_bar=self.progress_bar,
+            dtype=torch.float32,
+        )
+
+        image = self.decode_latents(latents, output_type=output_type)
+        if not return_dict:
+            return (image, latents)
+        return ImagePipelineOutput(images=image)
+
+    @classmethod
+    def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None):
+        """Best-effort loader for the VAE that accompanies a local checkpoint folder.
+
+        A ``vae`` sub-directory is tried first; otherwise a Hub id is read from
+        ``vae_pretrained_model_name_or_path.txt`` in the same folder.
+
+        Returns:
+            The loaded ``AutoencoderKL``, or ``None`` when no VAE source exists or the
+            bundled ``vae`` folder fails to load.
+        """
+        # ``AutoencoderKL`` is not imported at module level; without this binding the
+        # ``except Exception`` below swallowed the resulting NameError and returned None.
+        from diffusers import AutoencoderKL
+
+        vae_dir = os.path.join(pretrained_model_name_or_path, "vae")
+        if os.path.isdir(vae_dir):
+            try:
+                return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype)
+            except Exception as exc:
+                # Loading stays best-effort (``None`` is returned), but the reason is reported.
+                import warnings
+
+                warnings.warn(f"Could not load VAE from {vae_dir!r}: {exc}")
+                return None
+
+        vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt")
+        if os.path.isfile(vae_hint):
+            with open(vae_hint, "r", encoding="utf-8") as f:
+                hub_id = f.read().strip()
+            if hub_id:
+                # Errors from the Hub download propagate to the caller, as before.
+                return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype)
+        return None
diff --git a/edm2-img512-s-fid/scheduler/scheduler_config.json b/edm2-img512-s-fid/scheduler/scheduler_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..08cbde97955d06d66b0b72be217db84bb797d711
--- /dev/null
+++ b/edm2-img512-s-fid/scheduler/scheduler_config.json
@@ -0,0 +1,11 @@
+{
+ "_class_name": "EDMEulerScheduler",
+ "final_sigmas_type": "zero",
+ "num_train_timesteps": 1000,
+ "prediction_type": "epsilon",
+ "rho": 7.0,
+ "sigma_data": 0.5,
+ "sigma_max": 80.0,
+ "sigma_min": 0.002,
+ "sigma_schedule": "karras"
+}
diff --git a/edm2-img512-s-fid/unet/config.json b/edm2-img512-s-fid/unet/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..4f7eb106dcadc3dd4c2744ae18e255f6cbe420da
--- /dev/null
+++ b/edm2-img512-s-fid/unet/config.json
@@ -0,0 +1,31 @@
+{
+ "_class_name": "EDM2UNet2DModel",
+ "attn_balance": 0.3,
+ "attn_resolutions": [
+ 16,
+ 8
+ ],
+ "channel_mult": [
+ 1,
+ 2,
+ 3,
+ 4
+ ],
+ "channel_mult_emb": 4,
+ "channel_mult_noise": 1,
+ "channels_per_head": 64,
+ "clip_act": 256,
+ "concat_balance": 0.5,
+ "dropout": 0.0,
+ "in_channels": 4,
+ "label_balance": 0.5,
+ "logvar_channels": 128,
+ "model_channels": 192,
+ "num_blocks": 3,
+ "num_class_embeds": 1000,
+ "out_channels": 4,
+ "res_balance": 0.3,
+ "sample_size": 64,
+ "sigma_data": 0.5,
+ "use_fp16": true
+}
diff --git a/edm2-img512-s-fid/unet/diffusion_pytorch_model.safetensors b/edm2-img512-s-fid/unet/diffusion_pytorch_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..f7bc7bc5c9314fce180ac5f46ae971dfe08fb183
--- /dev/null
+++ b/edm2-img512-s-fid/unet/diffusion_pytorch_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5dee937e117e2367ede680aae4edf96635ff4debb9ae73f2617111991aa83d61
+size 1120876188
diff --git a/edm2-img512-s-fid/unet/unet_edm2.py b/edm2-img512-s-fid/unet/unet_edm2.py
new file mode 100644
index 0000000000000000000000000000000000000000..49d9451b443c4d6ad3b6ba2db19a7a804b07b6de
--- /dev/null
+++ b/edm2-img512-s-fid/unet/unet_edm2.py
@@ -0,0 +1,434 @@
+import math
+import json
+import os
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+
+try:
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.utils import BaseOutput
+except ImportError: # pragma: no cover
+ class ModelMixin(torch.nn.Module):
+ pass
+
+ class ConfigMixin:
+ config = {}
+
+ def register_to_config(self, **kwargs):
+ self.config = kwargs
+
+ def register_to_config(func):
+ return func
+
+ @dataclass
+ class BaseOutput:
+ pass
+
+
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor:
    """Scale ``x`` so that its elements have roughly unit magnitude along ``dim``.

    The L2 norm is taken over ``dim`` (all dimensions except the first when
    ``dim`` is ``None``) in float32, rescaled by ``sqrt(norm.numel() / x.numel())``
    and offset by ``eps`` before dividing.

    Args:
        x: Tensor to normalize.
        dim: Dimensions reduced by the norm; defaults to ``1 .. x.ndim - 1``.
        eps: Additive term that keeps the divisor away from zero.

    Returns:
        ``x`` divided by the adjusted norm, in ``x``'s dtype.
    """
    reduce_dims = list(range(1, x.ndim)) if dim is None else dim
    magnitude = torch.linalg.vector_norm(x, dim=reduce_dims, keepdim=True, dtype=torch.float32)
    # Turns the vector norm into a per-element RMS-style magnitude.
    rescale = math.sqrt(magnitude.numel() / x.numel())
    divisor = torch.add(eps, magnitude, alpha=rescale)
    return x / divisor.to(x.dtype)
+
+
def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor:
    """Resample a feature map by a factor of two with a separable filter.

    Args:
        x: Input tensor; channels are read from ``x.shape[1]``.
        f: 1D filter taps. They are normalized to sum to one and expanded to a
            2D kernel via an outer product.
        mode: ``"keep"`` returns ``x`` unchanged, ``"down"`` applies a stride-2
            depthwise convolution, ``"up"`` applies a stride-2 depthwise
            transposed convolution.

    Returns:
        The resampled tensor (or ``x`` itself for ``"keep"``).

    Raises:
        ValueError: If ``mode`` is not one of ``"keep"``, ``"down"``, ``"up"``.
    """
    if mode == "keep":
        return x
    if mode not in ("down", "up"):
        # Previously any unknown mode silently fell through to upsampling.
        raise ValueError(f"Unsupported resample mode {mode!r}; expected 'keep', 'down' or 'up'.")

    taps = np.float32(f)
    pad = (len(taps) - 1) // 2
    taps = taps / taps.sum()
    kernel = np.outer(taps, taps)[np.newaxis, np.newaxis, :, :]
    kernel = torch.as_tensor(kernel, dtype=x.dtype, device=x.device)
    channels = x.shape[1]
    if mode == "down":
        return torch.nn.functional.conv2d(
            x, kernel.tile([channels, 1, 1, 1]), groups=channels, stride=2, padding=(pad,)
        )
    # The factor 4 compensates for the zeros inserted by the stride-2 transposed conv.
    return torch.nn.functional.conv_transpose2d(
        x, (kernel * 4).tile([channels, 1, 1, 1]), groups=channels, stride=2, padding=(pad,)
    )
+
+
def mp_silu(x: torch.Tensor) -> torch.Tensor:
    """Apply SiLU and divide the result by the fixed constant 0.596."""
    activated = torch.nn.functional.silu(x)
    return activated / 0.596
+
+
def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor:
    """Blend ``a`` and ``b`` with weight ``t`` and renormalize the result.

    Computes ``lerp(a, b, t) / sqrt((1 - t)^2 + t^2)``.
    """
    blended = torch.lerp(a, b, t)
    renorm = math.sqrt((1 - t) ** 2 + t ** 2)
    return blended / renorm
+
+
def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor:
    """Concatenate ``a`` and ``b`` along ``dim`` after weighting each input.

    The weights depend on the sizes of both inputs along ``dim`` and on the
    balance ``t`` (``t = 0`` keeps only ``a``'s contribution, ``t = 1`` only ``b``'s).
    """
    size_a, size_b = a.shape[dim], b.shape[dim]
    common = math.sqrt((size_a + size_b) / ((1 - t) ** 2 + t ** 2))
    weight_a = common / math.sqrt(size_a) * (1 - t)
    weight_b = common / math.sqrt(size_b) * t
    return torch.cat([weight_a * a, weight_b * b], dim=dim)
+
+
class MPFourier(torch.nn.Module):
    """Random Fourier features of a 1D input.

    Frequencies and phases are drawn once at construction and stored as
    buffers named ``freqs`` and ``phases``.
    """

    def __init__(self, num_channels: int, bandwidth: float = 1):
        super().__init__()
        # Draw order (randn, then rand) is kept so seeded construction is unchanged.
        freqs = 2 * math.pi * torch.randn(num_channels) * bandwidth
        phases = 2 * math.pi * torch.rand(num_channels)
        self.register_buffer("freqs", freqs)
        self.register_buffer("phases", phases)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``sqrt(2) * cos(outer(x, freqs) + phases)`` in ``x``'s dtype.

        The computation is done in float32; ``x`` must be 1D for the outer product.
        """
        angles = torch.outer(x.to(torch.float32), self.freqs.to(torch.float32))
        angles = angles + self.phases.to(torch.float32)
        features = angles.cos() * math.sqrt(2)
        return features.to(x.dtype)
+
+
class MPConv(torch.nn.Module):
    """Linear or 2D convolution layer whose weight is normalized on every call.

    ``kernel=()`` produces a 2D weight used as a matrix multiply; a two-element
    kernel produces a 4D weight used with ``conv2d`` and "same"-style padding.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]):
        super().__init__()
        self.out_channels = out_channels
        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))

    def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor:
        """Apply the layer to ``x``, scaling the normalized weight by ``gain``."""
        weight = self.weight.to(torch.float32)
        if self.training:
            # Write the normalized weight back into the parameter during training.
            with torch.no_grad():
                self.weight.copy_(normalize(weight))
        weight = normalize(weight)
        fan_in = weight[0].numel()
        weight = (weight * (gain / math.sqrt(fan_in))).to(x.dtype)
        if weight.ndim == 2:
            return x @ weight.t()
        half_kernel = weight.shape[-1] // 2
        return torch.nn.functional.conv2d(x, weight, padding=(half_kernel,))
+
+
class Block(torch.nn.Module):
    """Residual U-Net block with embedding modulation and optional self-attention.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        emb_channels: Width of the conditioning embedding.
        flavor: ``"enc"`` applies the skip projection and pixel normalization
            before the residual branch; ``"dec"`` applies the skip projection after it.
        resample_mode: Passed to ``resample`` (``"keep"``, ``"down"`` or ``"up"``).
        resample_filter: Filter taps passed to ``resample``; defaults to ``[1, 1]``.
        attention: Enables the self-attention sub-block.
        channels_per_head: Channels per attention head.
        dropout: Dropout probability, applied only in training mode.
        res_balance: Blend weight of the residual branch in ``mp_sum``.
        attn_balance: Blend weight of the attention branch in ``mp_sum``.
        clip_act: Symmetric clamp applied to the output; ``None`` disables it.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        emb_channels: int,
        flavor: str = "enc",
        resample_mode: str = "keep",
        resample_filter: Optional[List[float]] = None,
        attention: bool = False,
        channels_per_head: int = 64,
        dropout: float = 0.0,
        res_balance: float = 0.3,
        attn_balance: float = 0.3,
        clip_act: Optional[float] = 256,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.flavor = flavor
        # A fresh list per instance avoids sharing one mutable default between blocks.
        self.resample_filter = [1, 1] if resample_filter is None else list(resample_filter)
        self.resample_mode = resample_mode
        self.num_heads = out_channels // channels_per_head if attention else 0
        self.dropout = dropout
        self.res_balance = res_balance
        self.attn_balance = attn_balance
        self.clip_act = clip_act
        self.emb_gain = torch.nn.Parameter(torch.zeros([]))
        self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3))
        self.emb_linear = MPConv(emb_channels, out_channels, kernel=())
        self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3))
        self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None
        self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None
        self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """Run the block on feature map ``x`` conditioned on embedding ``emb``."""
        # Main branch.
        x = resample(x, f=self.resample_filter, mode=self.resample_mode)
        if self.flavor == "enc":
            if self.conv_skip is not None:
                x = self.conv_skip(x)
            x = normalize(x, dim=[1])

        # Residual branch, modulated per channel by the embedding.
        y = self.conv_res0(mp_silu(x))
        c = self.emb_linear(emb, gain=self.emb_gain) + 1
        y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype))
        if self.training and self.dropout:
            y = torch.nn.functional.dropout(y, p=self.dropout)
        y = self.conv_res1(y)

        # Connect the branches.
        if self.flavor == "dec" and self.conv_skip is not None:
            x = self.conv_skip(x)
        x = mp_sum(x, y, t=self.res_balance)

        # Self-attention over the flattened spatial positions.
        if self.num_heads:
            y = self.attn_qkv(x)
            y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3])
            q, k, v = normalize(y, dim=[2]).unbind(3)
            w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3)
            y = torch.einsum("nhqk,nhck->nhcq", w, v)
            y = self.attn_proj(y.reshape(*x.shape))
            x = mp_sum(x, y, t=self.attn_balance)

        # Clip activations in place.
        if self.clip_act is not None:
            x = x.clip_(-self.clip_act, self.clip_act)
        return x
+
+
class EDM2UNet(torch.nn.Module):
    """Encoder/decoder U-Net assembled from ``MPConv`` and ``Block`` modules.

    Encoder and decoder layers live in ``ModuleDict``s keyed by
    ``"{res}x{res}_{role}"``; the decoder consumes encoder outputs as skip
    connections in reverse order. Extra keyword arguments are forwarded to every
    ``Block``.
    """

    def __init__(
        self,
        img_resolution: int,
        img_channels: int,
        label_dim: int,
        model_channels: int = 192,
        channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
        channel_mult_noise: Optional[int] = None,
        channel_mult_emb: Optional[int] = None,
        num_blocks: int = 3,
        attn_resolutions: Tuple[int, ...] = (16, 8),
        label_balance: float = 0.5,
        concat_balance: float = 0.5,
        **block_kwargs,
    ):
        super().__init__()
        level_channels = [model_channels * mult for mult in channel_mult]
        if channel_mult_noise is not None:
            noise_channels = model_channels * channel_mult_noise
        else:
            noise_channels = level_channels[0]
        if channel_mult_emb is not None:
            emb_channels = model_channels * channel_mult_emb
        else:
            emb_channels = max(level_channels)
        self.label_balance = label_balance
        self.concat_balance = concat_balance
        self.out_gain = torch.nn.Parameter(torch.zeros([]))

        # Embedding.
        self.emb_fourier = MPFourier(noise_channels)
        self.emb_noise = MPConv(noise_channels, emb_channels, kernel=())
        self.emb_label = MPConv(label_dim, emb_channels, kernel=()) if label_dim else None

        # Encoder. The input gains one constant channel (see ``forward``).
        self.enc = torch.nn.ModuleDict()
        width = img_channels + 1
        for level, channels in enumerate(level_channels):
            res = img_resolution >> level
            prefix = f"{res}x{res}"
            if level == 0:
                previous = width
                width = channels
                self.enc[f"{prefix}_conv"] = MPConv(previous, width, kernel=(3, 3))
            else:
                self.enc[f"{prefix}_down"] = Block(
                    width, width, emb_channels, flavor="enc", resample_mode="down", **block_kwargs
                )
            for idx in range(num_blocks):
                previous = width
                width = channels
                self.enc[f"{prefix}_block{idx}"] = Block(
                    previous,
                    width,
                    emb_channels,
                    flavor="enc",
                    attention=(res in attn_resolutions),
                    **block_kwargs,
                )

        # Decoder. Skip widths are popped in reverse encoder order.
        self.dec = torch.nn.ModuleDict()
        skip_widths = [layer.out_channels for layer in self.enc.values()]
        deepest = len(level_channels) - 1
        for level, channels in reversed(list(enumerate(level_channels))):
            res = img_resolution >> level
            prefix = f"{res}x{res}"
            if level == deepest:
                self.dec[f"{prefix}_in0"] = Block(
                    width, width, emb_channels, flavor="dec", attention=True, **block_kwargs
                )
                self.dec[f"{prefix}_in1"] = Block(width, width, emb_channels, flavor="dec", **block_kwargs)
            else:
                self.dec[f"{prefix}_up"] = Block(
                    width, width, emb_channels, flavor="dec", resample_mode="up", **block_kwargs
                )
            for idx in range(num_blocks + 1):
                previous = width + skip_widths.pop()
                width = channels
                self.dec[f"{prefix}_block{idx}"] = Block(
                    previous,
                    width,
                    emb_channels,
                    flavor="dec",
                    attention=(res in attn_resolutions),
                    **block_kwargs,
                )

        self.out_conv = MPConv(width, img_channels, kernel=(3, 3))

    def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the U-Net.

        Args:
            x: Input feature map.
            noise_labels: 1D noise-level conditioning fed to the Fourier embedding.
            class_labels: Class conditioning; required when the model was built
                with a non-zero ``label_dim``.

        Raises:
            ValueError: If the model is conditional and ``class_labels`` is ``None``.
        """
        emb = self.emb_noise(self.emb_fourier(noise_labels))
        if self.emb_label is not None:
            if class_labels is None:
                raise ValueError("class_labels are required for conditional EDM2UNet.")
            label_emb = self.emb_label(class_labels * math.sqrt(class_labels.shape[1]))
            emb = mp_sum(emb, label_emb, t=self.label_balance)
            emb = mp_silu(emb)

        # Encoder: append a constant channel of ones, then record every output.
        x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1)
        skips = []
        for name, layer in self.enc.items():
            if "conv" in name:
                x = layer(x)
            else:
                x = layer(x, emb)
            skips.append(x)

        # Decoder: only the "block" layers consume a skip connection.
        for name, layer in self.dec.items():
            if "block" in name:
                x = mp_cat(x, skips.pop(), t=self.concat_balance)
            x = layer(x, emb)
        return self.out_conv(x, gain=self.out_gain)
+
+
@dataclass
class EDM2UNet2DOutput(BaseOutput):
    """Return value of ``EDM2UNet2DModel.forward`` when ``return_dict=True``."""

    # Denoised prediction computed as c_skip * x + c_out * F(x).
    sample: torch.Tensor
    # Output of the logvar head; only populated when ``return_logvar=True``.
    logvar: Optional[torch.Tensor] = None
+
+
+
# Constructor arguments of EDM2UNet2DModel that are read from / written to
# config.json by from_pretrained and save_pretrained. Any other key in the
# file (e.g. "_class_name") is ignored when the model is instantiated.
_CONFIG_KEYS = (
    "sample_size",
    "in_channels",
    "out_channels",
    "num_class_embeds",
    "use_fp16",
    "sigma_data",
    "logvar_channels",
    "model_channels",
    "channel_mult",
    "channel_mult_noise",
    "channel_mult_emb",
    "num_blocks",
    "attn_resolutions",
    "label_balance",
    "concat_balance",
    "dropout",
    "channels_per_head",
    "res_balance",
    "attn_balance",
    "clip_act",
)
+
+
+class EDM2UNet2DModel(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: int = 64,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ num_class_embeds: int = 0,
+ use_fp16: bool = True,
+ sigma_data: float = 0.5,
+ logvar_channels: int = 128,
+ model_channels: int = 192,
+ channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
+ channel_mult_noise: Optional[int] = None,
+ channel_mult_emb: Optional[int] = None,
+ num_blocks: int = 3,
+ attn_resolutions: Tuple[int, ...] = (16, 8),
+ label_balance: float = 0.5,
+ concat_balance: float = 0.5,
+ dropout: float = 0.0,
+ channels_per_head: int = 64,
+ res_balance: float = 0.3,
+ attn_balance: float = 0.3,
+ clip_act: Optional[float] = 256,
+ ):
+ super().__init__()
+ self.sample_size = sample_size
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_class_embeds = num_class_embeds
+ self.use_fp16 = use_fp16
+ self.sigma_data = sigma_data
+ self.model_channels = model_channels
+ self.channel_mult = channel_mult
+ self.channel_mult_noise = channel_mult_noise
+ self.channel_mult_emb = channel_mult_emb
+ self.num_blocks = num_blocks
+ self.attn_resolutions = attn_resolutions
+ self.label_balance = label_balance
+ self.concat_balance = concat_balance
+ self.dropout = dropout
+ self.channels_per_head = channels_per_head
+ self.res_balance = res_balance
+ self.attn_balance = attn_balance
+ self.clip_act = clip_act
+ self.unet = EDM2UNet(
+ img_resolution=sample_size,
+ img_channels=in_channels,
+ label_dim=num_class_embeds,
+ model_channels=model_channels,
+ channel_mult=channel_mult,
+ channel_mult_noise=channel_mult_noise,
+ channel_mult_emb=channel_mult_emb,
+ num_blocks=num_blocks,
+ attn_resolutions=attn_resolutions,
+ label_balance=label_balance,
+ concat_balance=concat_balance,
+ dropout=dropout,
+ channels_per_head=channels_per_head,
+ res_balance=res_balance,
+ attn_balance=attn_balance,
+ clip_act=clip_act,
+ )
+ self.logvar_fourier = MPFourier(logvar_channels)
+ self.logvar_linear = MPConv(logvar_channels, 1, kernel=())
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ sigma: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ force_fp32: bool = False,
+ return_logvar: bool = False,
+ return_dict: bool = True,
+ ) -> EDM2UNet2DOutput:
+ x = sample.to(torch.float32)
+ sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
+ if self.num_class_embeds == 0:
+ class_labels = None
+ else:
+ if class_labels is None:
+ class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device)
+ class_labels = class_labels.to(torch.float32).reshape(-1, self.num_class_embeds)
+ dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") else torch.float32
+
+ c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
+ c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
+ c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
+ c_noise = sigma.flatten().log() / 4
+
+ x_in = (c_in * x).to(dtype)
+ f_x = self.unet(x_in, c_noise, class_labels)
+ d_x = c_skip * x + c_out * f_x.to(torch.float32)
+
+ logvar = None
+ if return_logvar:
+ logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)
+
+ if not return_dict:
+ return (d_x, logvar)
+ return EDM2UNet2DOutput(sample=d_x, logvar=logvar)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs):
+ subfolder = kwargs.pop("subfolder", None)
+ model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path
+ with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f:
+ config = json.load(f)
+ init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS}
+ model = cls(**init_kwargs)
+ weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors")
+ if os.path.isfile(weight_file):
+ from safetensors.torch import load_file
+
+ state_dict = load_file(weight_file)
+ else:
+ state_dict = torch.load(os.path.join(model_dir, "diffusion_pytorch_model.bin"), map_location="cpu")
+ model.load_state_dict(state_dict, strict=True)
+ if torch_dtype is not None:
+ model = model.to(dtype=torch_dtype)
+ return model
+
+ def save_pretrained(self, save_directory: str, safe_serialization: bool = True):
+ os.makedirs(save_directory, exist_ok=True)
+ stored = dict(getattr(self, "config", {}))
+ config = {"_class_name": self.__class__.__name__}
+ for key in _CONFIG_KEYS:
+ if key in stored:
+ config[key] = stored[key]
+ elif hasattr(self, key):
+ config[key] = getattr(self, key)
+ with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f:
+ json.dump(config, f, indent=2, sort_keys=True)
+ f.write("\n")
+ state_dict = self.state_dict()
+ if safe_serialization:
+ from safetensors.torch import save_file
+
+ save_file(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.safetensors"))
+ else:
+ torch.save(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.bin"))
diff --git a/edm2-img512-s-fid/vae/config.json b/edm2-img512-s-fid/vae/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..ae2ab37de2acb8bc348d3de809ae8385324be962
--- /dev/null
+++ b/edm2-img512-s-fid/vae/config.json
@@ -0,0 +1,38 @@
+{
+ "_class_name": "AutoencoderKL",
+ "_diffusers_version": "0.36.0",
+ "_name_or_path": "stabilityai/sd-vae-ft-mse",
+ "act_fn": "silu",
+ "block_out_channels": [
+ 128,
+ 256,
+ 512,
+ 512
+ ],
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D"
+ ],
+ "force_upcast": true,
+ "in_channels": 3,
+ "latent_channels": 4,
+ "latents_mean": null,
+ "latents_std": null,
+ "layers_per_block": 2,
+ "mid_block_add_attention": true,
+ "norm_num_groups": 32,
+ "out_channels": 3,
+ "sample_size": 256,
+ "scaling_factor": 0.18215,
+ "shift_factor": null,
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D"
+ ],
+ "use_post_quant_conv": true,
+ "use_quant_conv": true
+}
diff --git a/edm2-img512-s-fid/vae/diffusion_pytorch_model.safetensors b/edm2-img512-s-fid/vae/diffusion_pytorch_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..90464d67ac7303d0ee4696334df13da130a948ea
--- /dev/null
+++ b/edm2-img512-s-fid/vae/diffusion_pytorch_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815
+size 334643276
diff --git a/edm2-img512-xl-fid/demo.png b/edm2-img512-xl-fid/demo.png
new file mode 100644
index 0000000000000000000000000000000000000000..e9210202336f46b96bc0b80245451fde9342e1e9
--- /dev/null
+++ b/edm2-img512-xl-fid/demo.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:551c91feb88ea0279f61d52c20463da670f01f99e37467a6f358b699f33cd526
+size 369559
diff --git a/edm2-img512-xl-fid/model_index.json b/edm2-img512-xl-fid/model_index.json
new file mode 100644
index 0000000000000000000000000000000000000000..24c0305eea3ec64abbc2f870476afa4f5e462f42
--- /dev/null
+++ b/edm2-img512-xl-fid/model_index.json
@@ -0,0 +1,19 @@
+{
+ "_class_name": [
+ "pipeline",
+ "EDM2Pipeline"
+ ],
+ "_diffusers_version": "0.31.0",
+ "scheduler": [
+ "diffusers",
+ "EDMEulerScheduler"
+ ],
+ "unet": [
+ "unet_edm2",
+ "EDM2UNet2DModel"
+ ],
+ "vae": [
+ "diffusers",
+ "AutoencoderKL"
+ ]
+}
diff --git a/edm2-img512-xl-fid/pipeline.py b/edm2-img512-xl-fid/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b4131215d07770c74efa0674d00f3d6423d9c06
--- /dev/null
+++ b/edm2-img512-xl-fid/pipeline.py
@@ -0,0 +1,406 @@
+"""Hub custom pipeline: EDM2Pipeline.
+Load with native Hugging Face diffusers and trust_remote_code=True.
+"""
+
+from __future__ import annotations
+
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from pathlib import Path
+from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from diffusers.utils import replace_example_docstring
+from diffusers.utils.torch_utils import randn_tensor
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from pathlib import Path
+ >>> import torch
+ >>> from diffusers import DiffusionPipeline
+
+ >>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve()
+ >>> pipe = DiffusionPipeline.from_pretrained(
+ ... str(model_dir),
+ ... local_files_only=True,
+ ... custom_pipeline=str(model_dir / "pipeline.py"),
+ ... trust_remote_code=True,
+ ... torch_dtype=torch.float32,
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> generator = torch.Generator(device="cuda").manual_seed(42)
+ >>> image = pipe(
+ ... class_labels=207,
+ ... num_inference_steps=32,
+ ... guidance_scale=1.0,
+ ... generator=generator,
+ ... ).images[0]
+ >>> image.save("demo.png")
+ ```
+"""
+
+# Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py).
+_STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28])
+_STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE
+
+class EDM2Pipeline(DiffusionPipeline):
+ r"""
+ Pipeline for class-conditional image generation with EDM2
+ ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)).
+
+ Parameters:
+ unet ([`EDM2UNet2DModel`]):
+ Main magnitude-preserving U-Net with EDM preconditioning.
+ scheduler ([`EDMEulerScheduler`]):
+ Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in
+ the pipeline because the UNet returns denoised latents rather than noise predictions.
+ vae ([`AutoencoderKL`], *optional*):
+ Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`.
+ gnet ([`EDM2UNet2DModel`], *optional*):
+ Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`).
+ id2label (`dict[int, str]`, *optional*):
+ ImageNet class id to English label mapping.
+ """
+
+ model_cpu_offload_seq = "unet->gnet->vae"
+ _optional_components = ["vae", "gnet"]
+
+ def __init__(
+ self,
+ unet,
+ scheduler,
+ vae=None,
+ gnet=None,
+ id2label: Optional[Dict[Union[int, str], str]] = None,
+ ) -> None:
+ super().__init__()
+ self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet)
+ self._id2label = self._normalize_id2label(id2label)
+ self.labels = self._build_label2id(self._id2label)
+ self._labels_loaded_from_model_index = bool(self._id2label)
+ self.vae_scale_factor = 8 if self.vae is not None else 1
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
+
+ @staticmethod
+ def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
+ if not id2label:
+ return {}
+ return {int(key): value for key, value in id2label.items()}
+
+ @staticmethod
+ def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
+ label2id: Dict[str, int] = {}
+ for class_id, value in id2label.items():
+ for synonym in value.split(","):
+ synonym = synonym.strip()
+ if synonym:
+ label2id[synonym] = int(class_id)
+ return dict(sorted(label2id.items()))
+
+ def _ensure_labels_loaded(self) -> None:
+ if self._labels_loaded_from_model_index:
+ return
+ loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
+ if loaded:
+ self._id2label = loaded
+ self.labels = self._build_label2id(self._id2label)
+ self._labels_loaded_from_model_index = True
+
+ @staticmethod
+ def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
+ if not variant_path:
+ return {}
+ model_index_path = Path(variant_path).resolve() / "model_index.json"
+ if not model_index_path.is_file():
+ return {}
+ raw = json.loads(model_index_path.read_text(encoding="utf-8"))
+ id2label = raw.get("id2label")
+ if not isinstance(id2label, dict):
+ return {}
+ return {int(key): value for key, value in id2label.items()}
+
+ @property
+ def id2label(self) -> Dict[int, str]:
+ r"""ImageNet class id to English label string (comma-separated synonyms)."""
+ self._ensure_labels_loaded()
+ return self._id2label
+
+ def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
+ r"""
+ Map ImageNet label strings to class ids.
+
+ Args:
+ label (`str` or `list[str]`):
+ One or more English label strings that match entries in `id2label`.
+ """
+ self._ensure_labels_loaded()
+ if not self.labels:
+ raise ValueError("No English labels loaded. Add `id2label` to model_index.json.")
+ labels = [label] if isinstance(label, str) else list(label)
+ missing = [item for item in labels if item not in self.labels]
+ if missing:
+ preview = ", ".join(list(self.labels.keys())[:8])
+ raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
+ return [self.labels[item] for item in labels]
+
+ def _default_image_size(self) -> int:
+ latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64)))
+ return latent_size * self.vae_scale_factor
+
+ def check_inputs(
+ self,
+ height: int,
+ width: int,
+ num_inference_steps: int,
+ guidance_scale: float,
+ output_type: str,
+ ) -> None:
+ if num_inference_steps < 1:
+ raise ValueError("num_inference_steps must be >= 1.")
+ if guidance_scale < 1.0:
+ raise ValueError("guidance_scale must be >= 1.0.")
+ if guidance_scale > 1.0 and self.gnet is None:
+ raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).")
+ if output_type not in {"pil", "np", "pt", "latent"}:
+ raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
+
+ native_size = self._default_image_size()
+ if height != native_size or width != native_size:
+ raise ValueError(
+ f"EDM2 expects native resolution height=width={native_size}. "
+ f"Got height={height}, width={width}."
+ )
+
+ def _normalize_class_labels(
+ self,
+ class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]],
+ batch_size: int,
+ device: torch.device,
+ ) -> Optional[torch.Tensor]:
+ label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0)))
+ if label_dim == 0:
+ return None
+ if class_labels is None:
+ indices = torch.randint(label_dim, size=(batch_size,), device=device)
+ return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
+
+ if isinstance(class_labels, str):
+ class_labels = self.get_label_ids(class_labels)[0]
+ elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str):
+ class_labels = self.get_label_ids(list(class_labels))
+
+ if isinstance(class_labels, int):
+ indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long)
+ elif isinstance(class_labels, torch.Tensor):
+ if class_labels.ndim == 2:
+ labels = class_labels.to(device=device, dtype=torch.float32)
+ if labels.shape[0] != batch_size:
+ raise ValueError(f"class_labels batch must match batch_size={batch_size}.")
+ return labels
+ indices = class_labels.to(device=device, dtype=torch.long).flatten()
+ else:
+ indices = torch.tensor(list(class_labels), device=device, dtype=torch.long)
+
+ if indices.numel() == 1 and batch_size > 1:
+ indices = indices.repeat(batch_size)
+ if indices.numel() != batch_size:
+ raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.")
+ return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ height: int,
+ width: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
+ ) -> torch.Tensor:
+ in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4)))
+ latent_size = height // self.vae_scale_factor
+ return randn_tensor(
+ (batch_size, in_channels, latent_size, latent_size),
+ generator=generator,
+ device=device,
+ dtype=torch.float32,
+ )
+
+ def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
+ if output_type == "latent":
+ return latents
+
+ in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 3)))
+ if self.vae is None:
+ image = (latents.to(torch.float32) * 127.5 + 128).clip(0, 255) / 255.0
+ return self.image_processor.postprocess(image, output_type=output_type)
+
+ if in_channels == 4:
+ x = latents.to(torch.float32)
+ scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
+ bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
+ x = (x - bias) / scale
+ else:
+ x = latents.to(torch.float32)
+
+ vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype
+ image = self.vae.decode(x.to(dtype=vae_dtype)).sample.to(torch.float32).clamp(0, 1)
+
+ return self.image_processor.postprocess(image, output_type=output_type)
+
+ @staticmethod
+ def _apply_autoguidance(
+ main: torch.Tensor,
+ ref: torch.Tensor,
+ guidance_scale: float,
+ ) -> torch.Tensor:
+ return ref.lerp(main, guidance_scale)
+
+ @staticmethod
+ def _sample_edm2_heun(
+ denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+ noise: torch.Tensor,
+ sigmas: torch.Tensor,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ progress_bar: Optional[Callable[[Iterable], Iterable]] = None,
+ dtype: torch.dtype = torch.float32,
+ ) -> torch.Tensor:
+ """NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0)."""
+ x_next = noise.to(dtype) * sigmas[0]
+
+ sigma_pairs = list(zip(sigmas[:-1], sigmas[1:]))
+ if progress_bar is not None:
+ sigma_pairs = progress_bar(sigma_pairs)
+
+ num_steps = len(sigma_pairs)
+ for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs):
+ x_hat, sigma_hat = x_next, sigma_cur
+ d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat
+ x_next = x_hat + (sigma_next - sigma_hat) * d_cur
+ if i < num_steps - 1:
+ d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next
+ x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime)
+ return x_next
+
+ @torch.inference_mode()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None,
+ batch_size: int = 1,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 32,
+ guidance_scale: float = 1.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ output_type: str = "pil",
+ return_dict: bool = True,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Generate class-conditional images with EDM2.
+
+ Args:
+ class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*):
+ ImageNet class indices, English label strings, or one-hot float tensors.
+ Random classes are sampled when omitted on conditional models.
+ batch_size (`int`, defaults to `1`):
+ Number of images to generate.
+ height (`int`, *optional*):
+ Output height in pixels. Defaults to the pretrained native resolution.
+ width (`int`, *optional*):
+ Output width in pixels. Defaults to the pretrained native resolution.
+ num_inference_steps (`int`, defaults to `32`):
+ Number of EDM2 Heun steps (NVlabs default).
+ guidance_scale (`float`, defaults to `1.0`):
+ Autoguidance strength. Values above `1.0` blend the main net with `gnet`
+ via `gnet_output.lerp(unet_output, guidance_scale)`.
+ generator (`torch.Generator`, *optional*):
+ RNG for reproducibility.
+ output_type (`str`, defaults to `"pil"`):
+ `"pil"`, `"np"`, `"pt"`, or `"latent"`.
+ return_dict (`bool`, defaults to `True`):
+ Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if True.
+
+ Examples:
+
+ """
+ default_size = self._default_image_size()
+ height = int(height or default_size)
+ width = int(width or default_size)
+ self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type)
+
+ device = self._execution_device
+ dtype = self.unet.dtype
+ labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
+ noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
+
+ def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+ sigma_batch = sigma.reshape(1).expand(batch_size)
+ main = self.unet(
+ sample=x,
+ sigma=sigma_batch,
+ class_labels=labels,
+ force_fp32=True,
+ ).sample
+ if guidance_scale == 1.0 or self.gnet is None:
+ return main.to(torch.float32)
+ ref = self.gnet(
+ sample=x,
+ sigma=sigma_batch,
+ class_labels=labels,
+ force_fp32=True,
+ ).sample
+ return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32)
+
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ latents = self._sample_edm2_heun(
+ denoise_fn=denoise_fn,
+ noise=noise,
+ sigmas=self.scheduler.sigmas.to(device),
+ generator=generator,
+ progress_bar=self.progress_bar,
+ dtype=torch.float32,
+ )
+
+ image = self.decode_latents(latents, output_type=output_type)
+ if not return_dict:
+ return (image, latents)
+ return ImagePipelineOutput(images=image)
+
+ @classmethod
+ def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None):
+ vae_dir = os.path.join(pretrained_model_name_or_path, "vae")
+ if os.path.isdir(vae_dir):
+ try:
+
+ return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype)
+ except Exception:
+ return None
+
+ vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt")
+ if os.path.isfile(vae_hint):
+ with open(vae_hint, "r", encoding="utf-8") as f:
+ hub_id = f.read().strip()
+ if hub_id:
+
+ return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype)
+ return None
diff --git a/edm2-img512-xl-fid/scheduler/scheduler_config.json b/edm2-img512-xl-fid/scheduler/scheduler_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..08cbde97955d06d66b0b72be217db84bb797d711
--- /dev/null
+++ b/edm2-img512-xl-fid/scheduler/scheduler_config.json
@@ -0,0 +1,11 @@
+{
+ "_class_name": "EDMEulerScheduler",
+ "final_sigmas_type": "zero",
+ "num_train_timesteps": 1000,
+ "prediction_type": "epsilon",
+ "rho": 7.0,
+ "sigma_data": 0.5,
+ "sigma_max": 80.0,
+ "sigma_min": 0.002,
+ "sigma_schedule": "karras"
+}
diff --git a/edm2-img512-xl-fid/unet/config.json b/edm2-img512-xl-fid/unet/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..c341f808b68df99197f84d725dea6b36328fdb5c
--- /dev/null
+++ b/edm2-img512-xl-fid/unet/config.json
@@ -0,0 +1,31 @@
+{
+ "_class_name": "EDM2UNet2DModel",
+ "attn_balance": 0.3,
+ "attn_resolutions": [
+ 16,
+ 8
+ ],
+ "channel_mult": [
+ 1,
+ 2,
+ 3,
+ 4
+ ],
+ "channel_mult_emb": 4,
+ "channel_mult_noise": 1,
+ "channels_per_head": 64,
+ "clip_act": 256,
+ "concat_balance": 0.5,
+ "dropout": 0.0,
+ "in_channels": 4,
+ "label_balance": 0.5,
+ "logvar_channels": 128,
+ "model_channels": 384,
+ "num_blocks": 3,
+ "num_class_embeds": 1000,
+ "out_channels": 4,
+ "res_balance": 0.3,
+ "sample_size": 64,
+ "sigma_data": 0.5,
+ "use_fp16": true
+}
diff --git a/edm2-img512-xl-fid/unet/diffusion_pytorch_model.safetensors b/edm2-img512-xl-fid/unet/diffusion_pytorch_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..c3d4d87fa62118756846646e93146ae8824f5c93
--- /dev/null
+++ b/edm2-img512-xl-fid/unet/diffusion_pytorch_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3c7402d8a4e91781b5c94fa2a5beee5820970ad99d2249141e191364885f222a
+size 4477161892
diff --git a/edm2-img512-xl-fid/unet/unet_edm2.py b/edm2-img512-xl-fid/unet/unet_edm2.py
new file mode 100644
index 0000000000000000000000000000000000000000..49d9451b443c4d6ad3b6ba2db19a7a804b07b6de
--- /dev/null
+++ b/edm2-img512-xl-fid/unet/unet_edm2.py
@@ -0,0 +1,434 @@
+import math
+import json
+import os
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+
+# Prefer the real diffusers base classes. When diffusers cannot be imported, the
+# except-branch defines minimal stand-ins under the same names so the rest of this
+# module can still be defined.
+try:
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.utils import BaseOutput
+except ImportError: # pragma: no cover
+ # Stand-in model base: a plain torch module with no extra behavior.
+ class ModelMixin(torch.nn.Module):
+ pass
+
+ # Stand-in config base: keeps the keyword arguments it is given in `config`.
+ class ConfigMixin:
+ config = {}
+
+ def register_to_config(self, **kwargs):
+ self.config = kwargs
+
+ # Stand-in decorator: returns the decorated function unchanged.
+ def register_to_config(func):
+ return func
+
+ # Stand-in output base: an empty dataclass.
+ @dataclass
+ class BaseOutput:
+ pass
+
+
+def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor:
+ if dim is None:
+ dim = list(range(1, x.ndim))
+ norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
+ norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
+ return x / norm.to(x.dtype)
+
+
+def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor:
+ if mode == "keep":
+ return x
+ filt = np.float32(f)
+ pad = (len(filt) - 1) // 2
+ filt = filt / filt.sum()
+ filt = np.outer(filt, filt)[np.newaxis, np.newaxis, :, :]
+ filt = torch.as_tensor(filt, dtype=x.dtype, device=x.device)
+ c = x.shape[1]
+ if mode == "down":
+ return torch.nn.functional.conv2d(x, filt.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
+ return torch.nn.functional.conv_transpose2d(x, (filt * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
+
+
+def mp_silu(x: torch.Tensor) -> torch.Tensor:
+ return torch.nn.functional.silu(x) / 0.596
+
+
+def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor:
+ return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2)
+
+
+def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor:
+ na = a.shape[dim]
+ nb = b.shape[dim]
+ c = math.sqrt((na + nb) / ((1 - t) ** 2 + t ** 2))
+ wa = c / math.sqrt(na) * (1 - t)
+ wb = c / math.sqrt(nb) * t
+ return torch.cat([wa * a, wb * b], dim=dim)
+
+
+class MPFourier(torch.nn.Module):
+ def __init__(self, num_channels: int, bandwidth: float = 1):
+ super().__init__()
+ self.register_buffer("freqs", 2 * math.pi * torch.randn(num_channels) * bandwidth)
+ self.register_buffer("phases", 2 * math.pi * torch.rand(num_channels))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ y = x.to(torch.float32).ger(self.freqs.to(torch.float32))
+ y = y + self.phases.to(torch.float32)
+ y = y.cos() * math.sqrt(2)
+ return y.to(x.dtype)
+
+
+class MPConv(torch.nn.Module):
+ """Convolution / linear layer whose weight is normalized on every forward pass.
+
+ The weight has shape ``(out_channels, in_channels, *kernel)``. With an empty
+ ``kernel`` the weight is 2-D and ``forward`` applies it as a matrix product
+ instead of a convolution.
+ """
+
+ def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]):
+ super().__init__()
+ self.out_channels = out_channels
+ self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))
+
+ def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor:
+ # Do the weight arithmetic in float32 regardless of the parameter's dtype.
+ w = self.weight.to(torch.float32)
+ if self.training:
+ # While training, write the normalized weight back into the parameter itself;
+ # no_grad keeps this in-place update out of the autograd graph.
+ with torch.no_grad():
+ self.weight.copy_(normalize(w))
+ # This normalization runs outside no_grad, so it is part of the traced computation.
+ w = normalize(w)
+ # w[0].numel() is the number of weights feeding a single output channel.
+ w = w * (gain / math.sqrt(w[0].numel()))
+ w = w.to(x.dtype)
+ if w.ndim == 2:
+ # 2-D weight: plain linear map.
+ return x @ w.t()
+ # Pad by half the kernel width on each side.
+ return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1] // 2,))
+
+
+class Block(torch.nn.Module):
+ """U-Net block: optional resampling, an embedding-modulated residual branch, and
+ optional self-attention.
+
+ ``flavor == "enc"`` applies the 1x1 skip convolution and a channel-wise normalization
+ before the residual branch; ``flavor == "dec"`` applies the skip convolution after
+ the residual branch has been computed. Attention is enabled only when ``attention``
+ is true, with ``out_channels // channels_per_head`` heads.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ emb_channels: int,
+ flavor: str = "enc",
+ resample_mode: str = "keep",
+ # NOTE(review): list default is shared between calls; it is only stored and read here.
+ resample_filter: List[float] = [1, 1],
+ attention: bool = False,
+ channels_per_head: int = 64,
+ dropout: float = 0.0,
+ res_balance: float = 0.3,
+ attn_balance: float = 0.3,
+ clip_act: Optional[float] = 256,
+ ):
+ super().__init__()
+ self.out_channels = out_channels
+ self.flavor = flavor
+ self.resample_filter = resample_filter
+ self.resample_mode = resample_mode
+ # Zero heads doubles as the "no attention" switch in forward().
+ self.num_heads = out_channels // channels_per_head if attention else 0
+ self.dropout = dropout
+ self.res_balance = res_balance
+ self.attn_balance = attn_balance
+ self.clip_act = clip_act
+ # Scalar gain of the embedding projection, initialized to zero.
+ self.emb_gain = torch.nn.Parameter(torch.zeros([]))
+ # In "enc" blocks the skip conv runs first, so conv_res0 already sees out_channels.
+ self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3))
+ self.emb_linear = MPConv(emb_channels, out_channels, kernel=())
+ self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3))
+ # The skip conv exists only when the channel count changes.
+ self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None
+ self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None
+ self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None
+
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+ """Run the block on feature map ``x`` conditioned on embedding ``emb``."""
+ x = resample(x, f=self.resample_filter, mode=self.resample_mode)
+ if self.flavor == "enc":
+ if self.conv_skip is not None:
+ x = self.conv_skip(x)
+ # Normalize each pixel's channel vector.
+ x = normalize(x, dim=[1])
+
+ # Residual branch.
+ y = self.conv_res0(mp_silu(x))
+ # Per-channel modulation from the embedding; "+ 1" makes it the identity while emb_gain is zero.
+ c = self.emb_linear(emb, gain=self.emb_gain) + 1
+ # Broadcast the modulation over the two spatial dimensions.
+ y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype))
+ if self.training and self.dropout:
+ y = torch.nn.functional.dropout(y, p=self.dropout)
+ y = self.conv_res1(y)
+
+ # In "dec" blocks the skip conv is applied only now, after the residual branch used the raw input.
+ if self.flavor == "dec" and self.conv_skip is not None:
+ x = self.conv_skip(x)
+ x = mp_sum(x, y, t=self.res_balance)
+
+ if self.num_heads:
+ y = self.attn_qkv(x)
+ # Split into (batch, heads, channels per head, q/k/v, flattened spatial positions).
+ y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3])
+ # Normalize over the per-head channel axis, then separate q, k and v.
+ q, k, v = normalize(y, dim=[2]).unbind(3)
+ # Attention weights: softmax over key positions.
+ w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3)
+ y = torch.einsum("nhqk,nhck->nhcq", w, v)
+ y = self.attn_proj(y.reshape(*x.shape))
+ x = mp_sum(x, y, t=self.attn_balance)
+
+ if self.clip_act is not None:
+ # In-place clamp of the output to [-clip_act, clip_act].
+ x = x.clip_(-self.clip_act, self.clip_act)
+ return x
+
+
+class EDM2UNet(torch.nn.Module):
+ """Encoder/decoder U-Net assembled from `MPConv` and `Block` modules.
+
+ Modules are stored in two `ModuleDict`s (`enc`, `dec`) keyed by
+ ``"{res}x{res}_{role}"``; `forward` walks them in insertion order, so the order in
+ which they are created below also fixes the execution order and the parameter names.
+ """
+
+ def __init__(
+ self,
+ img_resolution: int,
+ img_channels: int,
+ label_dim: int,
+ model_channels: int = 192,
+ channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
+ channel_mult_noise: Optional[int] = None,
+ channel_mult_emb: Optional[int] = None,
+ num_blocks: int = 3,
+ attn_resolutions: Tuple[int, ...] = (16, 8),
+ label_balance: float = 0.5,
+ concat_balance: float = 0.5,
+ **block_kwargs,
+ ):
+ super().__init__()
+ # Channel count per resolution level.
+ cblock = [model_channels * x for x in channel_mult]
+ # Width of the noise-level Fourier features; defaults to the first level's width.
+ cnoise = model_channels * channel_mult_noise if channel_mult_noise is not None else cblock[0]
+ # Width of the conditioning embedding; defaults to the widest level.
+ cemb = model_channels * channel_mult_emb if channel_mult_emb is not None else max(cblock)
+ self.label_balance = label_balance
+ self.concat_balance = concat_balance
+ # Scalar gain of the output convolution, initialized to zero.
+ self.out_gain = torch.nn.Parameter(torch.zeros([]))
+
+ self.emb_fourier = MPFourier(cnoise)
+ self.emb_noise = MPConv(cnoise, cemb, kernel=())
+ # No label embedding when label_dim is 0.
+ self.emb_label = MPConv(label_dim, cemb, kernel=()) if label_dim else None
+
+ # Encoder.
+ self.enc = torch.nn.ModuleDict()
+ # "+ 1" accounts for the constant channel that forward() concatenates to the input.
+ cout = img_channels + 1
+ for level, channels in enumerate(cblock):
+ res = img_resolution >> level
+ if level == 0:
+ cin = cout
+ cout = channels
+ self.enc[f"{res}x{res}_conv"] = MPConv(cin, cout, kernel=(3, 3))
+ else:
+ self.enc[f"{res}x{res}_down"] = Block(cout, cout, cemb, flavor="enc", resample_mode="down", **block_kwargs)
+ for idx in range(num_blocks):
+ cin = cout
+ cout = channels
+ self.enc[f"{res}x{res}_block{idx}"] = Block(
+ cin,
+ cout,
+ cemb,
+ flavor="enc",
+ attention=(res in attn_resolutions),
+ **block_kwargs,
+ )
+
+ # Decoder.
+ self.dec = torch.nn.ModuleDict()
+ # Output width of every encoder module, consumed last-in-first-out by the decoder.
+ skips = [block.out_channels for block in self.enc.values()]
+ for level, channels in reversed(list(enumerate(cblock))):
+ res = img_resolution >> level
+ if level == len(cblock) - 1:
+ # Deepest level: two extra blocks, the first one always with attention.
+ self.dec[f"{res}x{res}_in0"] = Block(cout, cout, cemb, flavor="dec", attention=True, **block_kwargs)
+ self.dec[f"{res}x{res}_in1"] = Block(cout, cout, cemb, flavor="dec", **block_kwargs)
+ else:
+ self.dec[f"{res}x{res}_up"] = Block(cout, cout, cemb, flavor="dec", resample_mode="up", **block_kwargs)
+ # One more block per level than the encoder, each taking one skip connection.
+ for idx in range(num_blocks + 1):
+ cin = cout + skips.pop()
+ cout = channels
+ self.dec[f"{res}x{res}_block{idx}"] = Block(
+ cin,
+ cout,
+ cemb,
+ flavor="dec",
+ attention=(res in attn_resolutions),
+ **block_kwargs,
+ )
+
+ self.out_conv = MPConv(cout, img_channels, kernel=(3, 3))
+
+ def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor:
+ """Run the U-Net.
+
+ Raises:
+ ValueError: If the network has a label embedding and ``class_labels`` is None.
+ """
+ # Conditioning embedding from the noise level, optionally mixed with the class labels.
+ emb = self.emb_noise(self.emb_fourier(noise_labels))
+ if self.emb_label is not None:
+ if class_labels is None:
+ raise ValueError("class_labels are required for conditional EDM2UNet.")
+ # Labels are scaled by sqrt(number of label columns) before being embedded.
+ emb = mp_sum(emb, self.emb_label(class_labels * math.sqrt(class_labels.shape[1])), t=self.label_balance)
+ emb = mp_silu(emb)
+
+ # Append a channel of ones to the input.
+ x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1)
+ skips = []
+ for name, block in self.enc.items():
+ # The "*_conv" entry is a bare MPConv and takes no embedding.
+ x = block(x) if "conv" in name else block(x, emb)
+ skips.append(x)
+
+ for name, block in self.dec.items():
+ # Only "*_block{idx}" entries consume a skip connection.
+ if "block" in name:
+ x = mp_cat(x, skips.pop(), t=self.concat_balance)
+ x = block(x, emb)
+ return self.out_conv(x, gain=self.out_gain)
+
+
+@dataclass
+class EDM2UNet2DOutput(BaseOutput):
+ """Value returned by `EDM2UNet2DModel.forward` when `return_dict` is true."""
+
+ # Preconditioned network output (`d_x` in `EDM2UNet2DModel.forward`).
+ sample: torch.Tensor
+ # Output of the `logvar_linear` head; None unless `return_logvar=True` was passed.
+ logvar: Optional[torch.Tensor] = None
+
+
+
+# Names of the `EDM2UNet2DModel.__init__` arguments that are read from and written to
+# `config.json`: `from_pretrained` drops every other key found in the file, and
+# `save_pretrained` writes only these keys (plus `_class_name`).
+_CONFIG_KEYS = (
+ "sample_size",
+ "in_channels",
+ "out_channels",
+ "num_class_embeds",
+ "use_fp16",
+ "sigma_data",
+ "logvar_channels",
+ "model_channels",
+ "channel_mult",
+ "channel_mult_noise",
+ "channel_mult_emb",
+ "num_blocks",
+ "attn_resolutions",
+ "label_balance",
+ "concat_balance",
+ "dropout",
+ "channels_per_head",
+ "res_balance",
+ "attn_balance",
+ "clip_act",
+)
+
+
+class EDM2UNet2DModel(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: int = 64,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ num_class_embeds: int = 0,
+ use_fp16: bool = True,
+ sigma_data: float = 0.5,
+ logvar_channels: int = 128,
+ model_channels: int = 192,
+ channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
+ channel_mult_noise: Optional[int] = None,
+ channel_mult_emb: Optional[int] = None,
+ num_blocks: int = 3,
+ attn_resolutions: Tuple[int, ...] = (16, 8),
+ label_balance: float = 0.5,
+ concat_balance: float = 0.5,
+ dropout: float = 0.0,
+ channels_per_head: int = 64,
+ res_balance: float = 0.3,
+ attn_balance: float = 0.3,
+ clip_act: Optional[float] = 256,
+ ):
+ super().__init__()
+ self.sample_size = sample_size
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_class_embeds = num_class_embeds
+ self.use_fp16 = use_fp16
+ self.sigma_data = sigma_data
+ self.model_channels = model_channels
+ self.channel_mult = channel_mult
+ self.channel_mult_noise = channel_mult_noise
+ self.channel_mult_emb = channel_mult_emb
+ self.num_blocks = num_blocks
+ self.attn_resolutions = attn_resolutions
+ self.label_balance = label_balance
+ self.concat_balance = concat_balance
+ self.dropout = dropout
+ self.channels_per_head = channels_per_head
+ self.res_balance = res_balance
+ self.attn_balance = attn_balance
+ self.clip_act = clip_act
+ self.unet = EDM2UNet(
+ img_resolution=sample_size,
+ img_channels=in_channels,
+ label_dim=num_class_embeds,
+ model_channels=model_channels,
+ channel_mult=channel_mult,
+ channel_mult_noise=channel_mult_noise,
+ channel_mult_emb=channel_mult_emb,
+ num_blocks=num_blocks,
+ attn_resolutions=attn_resolutions,
+ label_balance=label_balance,
+ concat_balance=concat_balance,
+ dropout=dropout,
+ channels_per_head=channels_per_head,
+ res_balance=res_balance,
+ attn_balance=attn_balance,
+ clip_act=clip_act,
+ )
+ self.logvar_fourier = MPFourier(logvar_channels)
+ self.logvar_linear = MPConv(logvar_channels, 1, kernel=())
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ sigma: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ force_fp32: bool = False,
+ return_logvar: bool = False,
+ return_dict: bool = True,
+ ) -> EDM2UNet2DOutput:
+ x = sample.to(torch.float32)
+ sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
+ if self.num_class_embeds == 0:
+ class_labels = None
+ else:
+ if class_labels is None:
+ class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device)
+ class_labels = class_labels.to(torch.float32).reshape(-1, self.num_class_embeds)
+ dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") else torch.float32
+
+ c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
+ c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
+ c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
+ c_noise = sigma.flatten().log() / 4
+
+ x_in = (c_in * x).to(dtype)
+ f_x = self.unet(x_in, c_noise, class_labels)
+ d_x = c_skip * x + c_out * f_x.to(torch.float32)
+
+ logvar = None
+ if return_logvar:
+ logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)
+
+ if not return_dict:
+ return (d_x, logvar)
+ return EDM2UNet2DOutput(sample=d_x, logvar=logvar)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs):
+ subfolder = kwargs.pop("subfolder", None)
+ model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path
+ with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f:
+ config = json.load(f)
+ init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS}
+ model = cls(**init_kwargs)
+ weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors")
+ if os.path.isfile(weight_file):
+ from safetensors.torch import load_file
+
+ state_dict = load_file(weight_file)
+ else:
+ state_dict = torch.load(os.path.join(model_dir, "diffusion_pytorch_model.bin"), map_location="cpu")
+ model.load_state_dict(state_dict, strict=True)
+ if torch_dtype is not None:
+ model = model.to(dtype=torch_dtype)
+ return model
+
+ def save_pretrained(self, save_directory: str, safe_serialization: bool = True):
+ os.makedirs(save_directory, exist_ok=True)
+ stored = dict(getattr(self, "config", {}))
+ config = {"_class_name": self.__class__.__name__}
+ for key in _CONFIG_KEYS:
+ if key in stored:
+ config[key] = stored[key]
+ elif hasattr(self, key):
+ config[key] = getattr(self, key)
+ with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f:
+ json.dump(config, f, indent=2, sort_keys=True)
+ f.write("\n")
+ state_dict = self.state_dict()
+ if safe_serialization:
+ from safetensors.torch import save_file
+
+ save_file(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.safetensors"))
+ else:
+ torch.save(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.bin"))
diff --git a/edm2-img512-xl-fid/vae/config.json b/edm2-img512-xl-fid/vae/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..ae2ab37de2acb8bc348d3de809ae8385324be962
--- /dev/null
+++ b/edm2-img512-xl-fid/vae/config.json
@@ -0,0 +1,38 @@
+{
+ "_class_name": "AutoencoderKL",
+ "_diffusers_version": "0.36.0",
+ "_name_or_path": "stabilityai/sd-vae-ft-mse",
+ "act_fn": "silu",
+ "block_out_channels": [
+ 128,
+ 256,
+ 512,
+ 512
+ ],
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D"
+ ],
+ "force_upcast": true,
+ "in_channels": 3,
+ "latent_channels": 4,
+ "latents_mean": null,
+ "latents_std": null,
+ "layers_per_block": 2,
+ "mid_block_add_attention": true,
+ "norm_num_groups": 32,
+ "out_channels": 3,
+ "sample_size": 256,
+ "scaling_factor": 0.18215,
+ "shift_factor": null,
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D"
+ ],
+ "use_post_quant_conv": true,
+ "use_quant_conv": true
+}
diff --git a/edm2-img512-xl-fid/vae/diffusion_pytorch_model.safetensors b/edm2-img512-xl-fid/vae/diffusion_pytorch_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..90464d67ac7303d0ee4696334df13da130a948ea
--- /dev/null
+++ b/edm2-img512-xl-fid/vae/diffusion_pytorch_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815
+size 334643276
diff --git a/edm2-img512-xs-fid/demo.png b/edm2-img512-xs-fid/demo.png
new file mode 100644
index 0000000000000000000000000000000000000000..5f4b05769047e63d94737d04b4667949a76253ef
--- /dev/null
+++ b/edm2-img512-xs-fid/demo.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a5ceee02aab56e93c77b73e082ca5f952897a2bd98c1b78c1899f78845561785
+size 375611
diff --git a/edm2-img512-xs-fid/model_index.json b/edm2-img512-xs-fid/model_index.json
new file mode 100644
index 0000000000000000000000000000000000000000..24c0305eea3ec64abbc2f870476afa4f5e462f42
--- /dev/null
+++ b/edm2-img512-xs-fid/model_index.json
@@ -0,0 +1,19 @@
+{
+ "_class_name": [
+ "pipeline",
+ "EDM2Pipeline"
+ ],
+ "_diffusers_version": "0.31.0",
+ "scheduler": [
+ "diffusers",
+ "EDMEulerScheduler"
+ ],
+ "unet": [
+ "unet_edm2",
+ "EDM2UNet2DModel"
+ ],
+ "vae": [
+ "diffusers",
+ "AutoencoderKL"
+ ]
+}
diff --git a/edm2-img512-xs-fid/pipeline.py b/edm2-img512-xs-fid/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b4131215d07770c74efa0674d00f3d6423d9c06
--- /dev/null
+++ b/edm2-img512-xs-fid/pipeline.py
@@ -0,0 +1,406 @@
+"""Hub custom pipeline: EDM2Pipeline.
+Load with native Hugging Face diffusers and trust_remote_code=True.
+"""
+
+from __future__ import annotations
+
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from pathlib import Path
+from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from diffusers.utils import replace_example_docstring
+from diffusers.utils.torch_utils import randn_tensor
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from pathlib import Path
+ >>> import torch
+ >>> from diffusers import DiffusionPipeline
+
+ >>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve()
+ >>> pipe = DiffusionPipeline.from_pretrained(
+ ... str(model_dir),
+ ... local_files_only=True,
+ ... custom_pipeline=str(model_dir / "pipeline.py"),
+ ... trust_remote_code=True,
+ ... torch_dtype=torch.float32,
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> generator = torch.Generator(device="cuda").manual_seed(42)
+ >>> image = pipe(
+ ... class_labels=207,
+ ... num_inference_steps=32,
+ ... guidance_scale=1.0,
+ ... generator=generator,
+ ... ).images[0]
+ >>> image.save("demo.png")
+ ```
+"""
+
+# Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py).
+# Both arrays have one entry per latent channel: scale is 0.5 divided by the first list,
+# bias is 0 minus the second list times scale. `decode_latents` undoes the whitening
+# with (x - bias) / scale.
+# NOTE(review): presumably the lists are the raw latent std and mean; confirm upstream.
+_STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28])
+_STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE
+
+class EDM2Pipeline(DiffusionPipeline):
+ r"""
+ Pipeline for class-conditional image generation with EDM2
+ ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)).
+
+ Parameters:
+ unet ([`EDM2UNet2DModel`]):
+ Main magnitude-preserving U-Net with EDM preconditioning.
+ scheduler ([`EDMEulerScheduler`]):
+ Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in
+ the pipeline because the UNet returns denoised latents rather than noise predictions.
+ vae ([`AutoencoderKL`], *optional*):
+ Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`.
+ gnet ([`EDM2UNet2DModel`], *optional*):
+ Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`).
+ id2label (`dict[int, str]`, *optional*):
+ ImageNet class id to English label mapping.
+ """
+
+ model_cpu_offload_seq = "unet->gnet->vae"
+ _optional_components = ["vae", "gnet"]
+
+ def __init__(
+ self,
+ unet,
+ scheduler,
+ vae=None,
+ gnet=None,
+ id2label: Optional[Dict[Union[int, str], str]] = None,
+ ) -> None:
+ super().__init__()
+ self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet)
+ self._id2label = self._normalize_id2label(id2label)
+ self.labels = self._build_label2id(self._id2label)
+ self._labels_loaded_from_model_index = bool(self._id2label)
+ self.vae_scale_factor = 8 if self.vae is not None else 1
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
+
+ @staticmethod
+ def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
+ if not id2label:
+ return {}
+ return {int(key): value for key, value in id2label.items()}
+
+ @staticmethod
+ def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
+ label2id: Dict[str, int] = {}
+ for class_id, value in id2label.items():
+ for synonym in value.split(","):
+ synonym = synonym.strip()
+ if synonym:
+ label2id[synonym] = int(class_id)
+ return dict(sorted(label2id.items()))
+
+ def _ensure_labels_loaded(self) -> None:
+ if self._labels_loaded_from_model_index:
+ return
+ loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
+ if loaded:
+ self._id2label = loaded
+ self.labels = self._build_label2id(self._id2label)
+ self._labels_loaded_from_model_index = True
+
+ @staticmethod
+ def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
+ if not variant_path:
+ return {}
+ model_index_path = Path(variant_path).resolve() / "model_index.json"
+ if not model_index_path.is_file():
+ return {}
+ raw = json.loads(model_index_path.read_text(encoding="utf-8"))
+ id2label = raw.get("id2label")
+ if not isinstance(id2label, dict):
+ return {}
+ return {int(key): value for key, value in id2label.items()}
+
+ @property
+ def id2label(self) -> Dict[int, str]:
+ r"""ImageNet class id to English label string (comma-separated synonyms)."""
+ self._ensure_labels_loaded()
+ return self._id2label
+
+ def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
+ r"""
+ Map ImageNet label strings to class ids.
+
+ Args:
+ label (`str` or `list[str]`):
+ One or more English label strings that match entries in `id2label`.
+ """
+ self._ensure_labels_loaded()
+ if not self.labels:
+ raise ValueError("No English labels loaded. Add `id2label` to model_index.json.")
+ labels = [label] if isinstance(label, str) else list(label)
+ missing = [item for item in labels if item not in self.labels]
+ if missing:
+ preview = ", ".join(list(self.labels.keys())[:8])
+ raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
+ return [self.labels[item] for item in labels]
+
+ def _default_image_size(self) -> int:
+ latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64)))
+ return latent_size * self.vae_scale_factor
+
+ def check_inputs(
+ self,
+ height: int,
+ width: int,
+ num_inference_steps: int,
+ guidance_scale: float,
+ output_type: str,
+ ) -> None:
+ if num_inference_steps < 1:
+ raise ValueError("num_inference_steps must be >= 1.")
+ if guidance_scale < 1.0:
+ raise ValueError("guidance_scale must be >= 1.0.")
+ if guidance_scale > 1.0 and self.gnet is None:
+ raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).")
+ if output_type not in {"pil", "np", "pt", "latent"}:
+ raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
+
+ native_size = self._default_image_size()
+ if height != native_size or width != native_size:
+ raise ValueError(
+ f"EDM2 expects native resolution height=width={native_size}. "
+ f"Got height={height}, width={width}."
+ )
+
+ def _normalize_class_labels(
+ self,
+ class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]],
+ batch_size: int,
+ device: torch.device,
+ ) -> Optional[torch.Tensor]:
+ label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0)))
+ if label_dim == 0:
+ return None
+ if class_labels is None:
+ indices = torch.randint(label_dim, size=(batch_size,), device=device)
+ return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
+
+ if isinstance(class_labels, str):
+ class_labels = self.get_label_ids(class_labels)[0]
+ elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str):
+ class_labels = self.get_label_ids(list(class_labels))
+
+ if isinstance(class_labels, int):
+ indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long)
+ elif isinstance(class_labels, torch.Tensor):
+ if class_labels.ndim == 2:
+ labels = class_labels.to(device=device, dtype=torch.float32)
+ if labels.shape[0] != batch_size:
+ raise ValueError(f"class_labels batch must match batch_size={batch_size}.")
+ return labels
+ indices = class_labels.to(device=device, dtype=torch.long).flatten()
+ else:
+ indices = torch.tensor(list(class_labels), device=device, dtype=torch.long)
+
+ if indices.numel() == 1 and batch_size > 1:
+ indices = indices.repeat(batch_size)
+ if indices.numel() != batch_size:
+ raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.")
+ return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ height: int,
+ width: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
+ ) -> torch.Tensor:
+ in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4)))
+ latent_size = height // self.vae_scale_factor
+ return randn_tensor(
+ (batch_size, in_channels, latent_size, latent_size),
+ generator=generator,
+ device=device,
+ dtype=torch.float32,
+ )
+
+ def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
+ if output_type == "latent":
+ return latents
+
+ in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 3)))
+ if self.vae is None:
+ image = (latents.to(torch.float32) * 127.5 + 128).clip(0, 255) / 255.0
+ return self.image_processor.postprocess(image, output_type=output_type)
+
+ if in_channels == 4:
+ x = latents.to(torch.float32)
+ scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
+ bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
+ x = (x - bias) / scale
+ else:
+ x = latents.to(torch.float32)
+
+ vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype
+ image = self.vae.decode(x.to(dtype=vae_dtype)).sample.to(torch.float32).clamp(0, 1)
+
+ return self.image_processor.postprocess(image, output_type=output_type)
+
+ @staticmethod
+ def _apply_autoguidance(
+ main: torch.Tensor,
+ ref: torch.Tensor,
+ guidance_scale: float,
+ ) -> torch.Tensor:
+ return ref.lerp(main, guidance_scale)
+
+ @staticmethod
+ def _sample_edm2_heun(
+ denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+ noise: torch.Tensor,
+ sigmas: torch.Tensor,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ progress_bar: Optional[Callable[[Iterable], Iterable]] = None,
+ dtype: torch.dtype = torch.float32,
+ ) -> torch.Tensor:
+ """NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0)."""
+ x_next = noise.to(dtype) * sigmas[0]
+
+ sigma_pairs = list(zip(sigmas[:-1], sigmas[1:]))
+ if progress_bar is not None:
+ sigma_pairs = progress_bar(sigma_pairs)
+
+ num_steps = len(sigma_pairs)
+ for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs):
+ x_hat, sigma_hat = x_next, sigma_cur
+ d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat
+ x_next = x_hat + (sigma_next - sigma_hat) * d_cur
+ if i < num_steps - 1:
+ d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next
+ x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime)
+ return x_next
+
+ @torch.inference_mode()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None,
+ batch_size: int = 1,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 32,
+ guidance_scale: float = 1.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ output_type: str = "pil",
+ return_dict: bool = True,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Generate class-conditional images with EDM2.
+
+ Args:
+ class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*):
+ ImageNet class indices, English label strings, or one-hot float tensors.
+ Random classes are sampled when omitted on conditional models.
+ batch_size (`int`, defaults to `1`):
+ Number of images to generate.
+ height (`int`, *optional*):
+ Output height in pixels. Defaults to the pretrained native resolution.
+ width (`int`, *optional*):
+ Output width in pixels. Defaults to the pretrained native resolution.
+ num_inference_steps (`int`, defaults to `32`):
+ Number of EDM2 Heun steps (NVlabs default).
+ guidance_scale (`float`, defaults to `1.0`):
+ Autoguidance strength. Values above `1.0` blend the main net with `gnet`
+ via `gnet_output.lerp(unet_output, guidance_scale)`.
+ generator (`torch.Generator`, *optional*):
+ RNG for reproducibility.
+ output_type (`str`, defaults to `"pil"`):
+ `"pil"`, `"np"`, `"pt"`, or `"latent"`.
+ return_dict (`bool`, defaults to `True`):
+ Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if True.
+
+ Examples:
+
+ """
+ default_size = self._default_image_size()
+ height = int(height or default_size)
+ width = int(width or default_size)
+ self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type)
+
+ device = self._execution_device
+ dtype = self.unet.dtype
+ labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
+ noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
+
+ def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+ sigma_batch = sigma.reshape(1).expand(batch_size)
+ main = self.unet(
+ sample=x,
+ sigma=sigma_batch,
+ class_labels=labels,
+ force_fp32=True,
+ ).sample
+ if guidance_scale == 1.0 or self.gnet is None:
+ return main.to(torch.float32)
+ ref = self.gnet(
+ sample=x,
+ sigma=sigma_batch,
+ class_labels=labels,
+ force_fp32=True,
+ ).sample
+ return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32)
+
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ latents = self._sample_edm2_heun(
+ denoise_fn=denoise_fn,
+ noise=noise,
+ sigmas=self.scheduler.sigmas.to(device),
+ generator=generator,
+ progress_bar=self.progress_bar,
+ dtype=torch.float32,
+ )
+
+ image = self.decode_latents(latents, output_type=output_type)
+ if not return_dict:
+ return (image, latents)
+ return ImagePipelineOutput(images=image)
+
+ @classmethod
+ def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None):
+ vae_dir = os.path.join(pretrained_model_name_or_path, "vae")
+ if os.path.isdir(vae_dir):
+ try:
+
+ return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype)
+ except Exception:
+ return None
+
+ vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt")
+ if os.path.isfile(vae_hint):
+ with open(vae_hint, "r", encoding="utf-8") as f:
+ hub_id = f.read().strip()
+ if hub_id:
+
+ return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype)
+ return None
diff --git a/edm2-img512-xs-fid/scheduler/scheduler_config.json b/edm2-img512-xs-fid/scheduler/scheduler_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..08cbde97955d06d66b0b72be217db84bb797d711
--- /dev/null
+++ b/edm2-img512-xs-fid/scheduler/scheduler_config.json
@@ -0,0 +1,11 @@
+{
+ "_class_name": "EDMEulerScheduler",
+ "final_sigmas_type": "zero",
+ "num_train_timesteps": 1000,
+ "prediction_type": "epsilon",
+ "rho": 7.0,
+ "sigma_data": 0.5,
+ "sigma_max": 80.0,
+ "sigma_min": 0.002,
+ "sigma_schedule": "karras"
+}
diff --git a/edm2-img512-xs-fid/unet/config.json b/edm2-img512-xs-fid/unet/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..6fceca931c569a93854b38007d87f2f5df9fea86
--- /dev/null
+++ b/edm2-img512-xs-fid/unet/config.json
@@ -0,0 +1,30 @@
+{
+ "_class_name": "EDM2UNet2DModel",
+ "attn_balance": 0.3,
+ "attn_resolutions": [
+ 16,
+ 8
+ ],
+ "channel_mult": [
+ 1,
+ 2,
+ 3,
+ 4
+ ],
+ "channel_mult_emb": 4,
+ "channel_mult_noise": 1,
+ "channels_per_head": 64,
+ "clip_act": 256,
+ "concat_balance": 0.5,
+ "dropout": 0.0,
+ "in_channels": 4,
+ "label_balance": 0.5,
+ "model_channels": 128,
+ "num_blocks": 3,
+ "num_class_embeds": 1000,
+ "out_channels": 4,
+ "res_balance": 0.3,
+ "sample_size": 64,
+ "sigma_data": 0.5,
+ "use_fp16": true
+}
diff --git a/edm2-img512-xs-fid/unet/diffusion_pytorch_model.safetensors b/edm2-img512-xs-fid/unet/diffusion_pytorch_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..0362009d25a7cbb31886d430ea8402ddd8ff951f
--- /dev/null
+++ b/edm2-img512-xs-fid/unet/diffusion_pytorch_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d5fe6a82ecdaa64b966a245a6f9179d159c702902ce9897bd60e21e80615a59b
+size 498877268
diff --git a/edm2-img512-xs-fid/unet/unet_edm2.py b/edm2-img512-xs-fid/unet/unet_edm2.py
new file mode 100644
index 0000000000000000000000000000000000000000..49d9451b443c4d6ad3b6ba2db19a7a804b07b6de
--- /dev/null
+++ b/edm2-img512-xs-fid/unet/unet_edm2.py
@@ -0,0 +1,434 @@
+import math
+import json
+import os
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+
+try:
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.utils import BaseOutput
+except ImportError: # pragma: no cover
+ class ModelMixin(torch.nn.Module):
+ pass
+
+ class ConfigMixin:
+ config = {}
+
+ def register_to_config(self, **kwargs):
+ self.config = kwargs
+
+ def register_to_config(func):
+ return func
+
+ @dataclass
+ class BaseOutput:
+ pass
+
+
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor:
    """Divide ``x`` by its (scaled) vector norm taken over ``dim``.

    Args:
        x: Tensor to normalize.
        dim: Dimensions to reduce over; every dimension except the first when ``None``.
        eps: Constant added to the scaled norm before dividing.

    Returns:
        A tensor with the shape and dtype of ``x``.
    """
    reduce_dims = list(range(1, x.ndim)) if dim is None else dim
    # The norm itself is always accumulated in float32.
    magnitude = torch.linalg.vector_norm(x, dim=reduce_dims, keepdim=True, dtype=torch.float32)
    # magnitude.numel() / x.numel() is 1 / (number of elements reduced per norm),
    # so the scaled norm is the root-mean-square rather than the raw L2 norm.
    rms_scale = math.sqrt(magnitude.numel() / x.numel())
    denominator = torch.add(eps, magnitude, alpha=rms_scale)
    return x / denominator.to(x.dtype)
+
+
def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor:
    """Resample ``x`` by a factor of two with a separable filter, or pass it through.

    Args:
        x: Input tensor; channels are read from ``x.shape[1]`` and the tensor is
            fed to 2-D (transposed) convolutions.
        f: 1-D filter taps. They are normalized to sum to one and expanded to a
            2-D kernel through an outer product.
        mode: ``"keep"`` returns ``x`` unchanged, ``"down"`` applies a stride-2
            depthwise convolution, ``"up"`` applies a stride-2 depthwise
            transposed convolution (kernel scaled by 4).

    Returns:
        The resampled tensor (``x`` itself for ``"keep"``).

    Raises:
        ValueError: If ``mode`` is not one of ``"keep"``, ``"down"`` or ``"up"``.
    """
    if mode == "keep":
        return x
    if mode not in ("down", "up"):
        # Previously any unknown mode silently fell through to upsampling.
        raise ValueError(f"Unsupported resample mode {mode!r}; expected 'keep', 'down' or 'up'.")

    filt = np.float32(f)
    pad = (len(filt) - 1) // 2
    filt = filt / filt.sum()
    filt = np.outer(filt, filt)[np.newaxis, np.newaxis, :, :]
    filt = torch.as_tensor(filt, dtype=x.dtype, device=x.device)
    c = x.shape[1]
    # groups=c with the kernel tiled c times: every channel is filtered independently.
    if mode == "down":
        return torch.nn.functional.conv2d(x, filt.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
    return torch.nn.functional.conv_transpose2d(x, (filt * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
+
+
def mp_silu(x: torch.Tensor) -> torch.Tensor:
    """Apply SiLU and divide the result by the constant 0.596."""
    activated = torch.nn.functional.silu(x)
    return activated / 0.596
+
+
def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor:
    """Blend ``a`` and ``b`` with weight ``t`` and rescale the result.

    Computes ``lerp(a, b, t) / sqrt((1 - t)^2 + t^2)``.
    """
    blended = torch.lerp(a, b, t)
    rescale = math.sqrt((1 - t) ** 2 + t ** 2)
    return blended / rescale
+
+
+def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor:
+ na = a.shape[dim]
+ nb = b.shape[dim]
+ c = math.sqrt((na + nb) / ((1 - t) ** 2 + t ** 2))
+ wa = c / math.sqrt(na) * (1 - t)
+ wb = c / math.sqrt(nb) * t
+ return torch.cat([wa * a, wb * b], dim=dim)
+
+
class MPFourier(torch.nn.Module):
    """Random Fourier feature embedding of a 1-D tensor of scalars.

    ``freqs`` and ``phases`` are sampled once at construction time and stored as
    buffers, so they are saved and restored with the module state.
    """

    def __init__(self, num_channels: int, bandwidth: float = 1):
        super().__init__()
        self.register_buffer("freqs", 2 * math.pi * torch.randn(num_channels) * bandwidth)
        self.register_buffer("phases", 2 * math.pi * torch.rand(num_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` (1-D, length N) into an ``(N, num_channels)`` tensor.

        The features are evaluated in float32 and cast back to ``x.dtype``.
        """
        # torch.outer replaces the deprecated Tensor.ger alias (same semantics).
        y = torch.outer(x.to(torch.float32), self.freqs.to(torch.float32))
        y = y + self.phases.to(torch.float32)
        y = y.cos() * math.sqrt(2)
        return y.to(x.dtype)
+
+
class MPConv(torch.nn.Module):
    """Linear or 2-D convolution layer whose weight is normalized on every call.

    ``kernel`` selects the flavour: an empty tuple gives a 2-D weight (applied
    as a matrix product), a ``(k, k)`` tuple gives a 4-D convolution weight.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]):
        super().__init__()
        self.out_channels = out_channels
        weight_shape = (out_channels, in_channels, *kernel)
        self.weight = torch.nn.Parameter(torch.randn(*weight_shape))

    def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor:
        """Apply the layer to ``x`` with the normalized weight scaled by ``gain``."""
        weight = self.weight.to(torch.float32)
        if self.training:
            # While training, the stored parameter itself is overwritten with
            # its normalized version (outside of autograd).
            with torch.no_grad():
                self.weight.copy_(normalize(weight))
        weight = normalize(weight)
        fan_in = weight[0].numel()
        weight = (weight * (gain / math.sqrt(fan_in))).to(x.dtype)
        if weight.ndim != 2:
            same_padding = (weight.shape[-1] // 2,)
            return torch.nn.functional.conv2d(x, weight, padding=same_padding)
        return x @ weight.t()
+
+
class Block(torch.nn.Module):
    """U-Net block: optional resampling, a residual branch and optional self-attention.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        emb_channels: Width of the conditioning embedding fed to ``emb_linear``.
        flavor: ``"enc"`` applies the skip projection and a channel-wise
            normalization *before* the residual branch; ``"dec"`` applies the
            skip projection *after* it.
        resample_mode: Forwarded to ``resample`` (``"keep"``, ``"down"``, ``"up"``).
        resample_filter: Filter taps forwarded to ``resample``; ``[1, 1]`` when ``None``.
        attention: Enables the attention branch (``out_channels // channels_per_head`` heads).
        channels_per_head: Channels per attention head.
        dropout: Dropout probability, applied only in training mode.
        res_balance: Blend factor of ``mp_sum`` for the residual branch.
        attn_balance: Blend factor of ``mp_sum`` for the attention branch.
        clip_act: Output is clipped in place to ``[-clip_act, clip_act]``; ``None`` disables it.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        emb_channels: int,
        flavor: str = "enc",
        resample_mode: str = "keep",
        resample_filter: Optional[List[float]] = None,
        attention: bool = False,
        channels_per_head: int = 64,
        dropout: float = 0.0,
        res_balance: float = 0.3,
        attn_balance: float = 0.3,
        clip_act: Optional[float] = 256,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.flavor = flavor
        # A fresh list per instance: the former ``[1, 1]`` default was a single
        # list object shared by every Block built without an explicit filter.
        self.resample_filter = [1, 1] if resample_filter is None else resample_filter
        self.resample_mode = resample_mode
        self.num_heads = out_channels // channels_per_head if attention else 0
        self.dropout = dropout
        self.res_balance = res_balance
        self.attn_balance = attn_balance
        self.clip_act = clip_act
        # Zero-initialized scalar gain: the embedding modulation starts at ``c == 1``.
        self.emb_gain = torch.nn.Parameter(torch.zeros([]))
        # Encoder blocks project to ``out_channels`` before conv_res0, decoder blocks after.
        self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3))
        self.emb_linear = MPConv(emb_channels, out_channels, kernel=())
        self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3))
        self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None
        self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None
        self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """Run the block on feature map ``x`` conditioned on embedding ``emb``."""
        # Main branch.
        x = resample(x, f=self.resample_filter, mode=self.resample_mode)
        if self.flavor == "enc":
            if self.conv_skip is not None:
                x = self.conv_skip(x)
            x = normalize(x, dim=[1])

        # Residual branch, modulated per channel by the embedding.
        y = self.conv_res0(mp_silu(x))
        c = self.emb_linear(emb, gain=self.emb_gain) + 1
        y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype))
        if self.training and self.dropout:
            y = torch.nn.functional.dropout(y, p=self.dropout)
        y = self.conv_res1(y)

        # Connect the branches.
        if self.flavor == "dec" and self.conv_skip is not None:
            x = self.conv_skip(x)
        x = mp_sum(x, y, t=self.res_balance)

        # Self-attention over the flattened spatial positions.
        if self.num_heads:
            y = self.attn_qkv(x)
            y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3])
            q, k, v = normalize(y, dim=[2]).unbind(3)
            w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3)
            y = torch.einsum("nhqk,nhck->nhcq", w, v)
            y = self.attn_proj(y.reshape(*x.shape))
            x = mp_sum(x, y, t=self.attn_balance)

        # Clip activations (in place).
        if self.clip_act is not None:
            x = x.clip_(-self.clip_act, self.clip_act)
        return x
+
+
class EDM2UNet(torch.nn.Module):
    """Encoder/decoder U-Net assembled from ``MPConv`` and ``Block`` modules.

    Encoder and decoder layers live in two ``ModuleDict``s keyed by
    ``"{res}x{res}_{role}"``. ``forward`` relies on the insertion order of these
    dicts and on the substrings ``"conv"`` / ``"block"`` in the keys, so both the
    construction order and the key names are part of the model's behaviour.

    Args:
        img_resolution: Spatial size at level 0; level ``l`` uses ``img_resolution >> l``.
        img_channels: Channels of the input and of the output.
        label_dim: Width of the class-label vector; ``0`` disables label conditioning.
        model_channels: Base channel count.
        channel_mult: Per-level multipliers of ``model_channels``.
        channel_mult_noise: Multiplier for the noise-embedding width
            (defaults to the first level's channel count).
        channel_mult_emb: Multiplier for the final embedding width
            (defaults to the largest level's channel count).
        num_blocks: Blocks per encoder level (decoder levels use ``num_blocks + 1``).
        attn_resolutions: Resolutions at which blocks use self-attention.
        label_balance: ``mp_sum`` blend factor between noise and label embeddings.
        concat_balance: ``mp_cat`` blend factor for decoder skip connections.
        **block_kwargs: Forwarded to every ``Block``.
    """

    def __init__(
        self,
        img_resolution: int,
        img_channels: int,
        label_dim: int,
        model_channels: int = 192,
        channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
        channel_mult_noise: Optional[int] = None,
        channel_mult_emb: Optional[int] = None,
        num_blocks: int = 3,
        attn_resolutions: Tuple[int, ...] = (16, 8),
        label_balance: float = 0.5,
        concat_balance: float = 0.5,
        **block_kwargs,
    ):
        super().__init__()
        # Channel counts: per level, for the noise embedding, and for the joint embedding.
        cblock = [model_channels * x for x in channel_mult]
        cnoise = model_channels * channel_mult_noise if channel_mult_noise is not None else cblock[0]
        cemb = model_channels * channel_mult_emb if channel_mult_emb is not None else max(cblock)
        self.label_balance = label_balance
        self.concat_balance = concat_balance
        # Zero-initialized scalar gain applied to the output convolution.
        self.out_gain = torch.nn.Parameter(torch.zeros([]))

        # Embedding.
        self.emb_fourier = MPFourier(cnoise)
        self.emb_noise = MPConv(cnoise, cemb, kernel=())
        self.emb_label = MPConv(label_dim, cemb, kernel=()) if label_dim else None

        # Encoder.
        self.enc = torch.nn.ModuleDict()
        # +1 for the constant channel of ones that forward() appends to the input.
        cout = img_channels + 1
        for level, channels in enumerate(cblock):
            res = img_resolution >> level
            if level == 0:
                # Level 0 starts with a plain convolution (called without the embedding).
                cin = cout
                cout = channels
                self.enc[f"{res}x{res}_conv"] = MPConv(cin, cout, kernel=(3, 3))
            else:
                # Every deeper level starts with a downsampling block.
                self.enc[f"{res}x{res}_down"] = Block(cout, cout, cemb, flavor="enc", resample_mode="down", **block_kwargs)
            for idx in range(num_blocks):
                cin = cout
                cout = channels
                self.enc[f"{res}x{res}_block{idx}"] = Block(
                    cin,
                    cout,
                    cemb,
                    flavor="enc",
                    attention=(res in attn_resolutions),
                    **block_kwargs,
                )

        # Decoder.
        self.dec = torch.nn.ModuleDict()
        # Output widths of all encoder modules, consumed from the end (deepest first).
        skips = [block.out_channels for block in self.enc.values()]
        for level, channels in reversed(list(enumerate(cblock))):
            res = img_resolution >> level
            if level == len(cblock) - 1:
                # Deepest level: two extra blocks, the first one always with attention.
                self.dec[f"{res}x{res}_in0"] = Block(cout, cout, cemb, flavor="dec", attention=True, **block_kwargs)
                self.dec[f"{res}x{res}_in1"] = Block(cout, cout, cemb, flavor="dec", **block_kwargs)
            else:
                self.dec[f"{res}x{res}_up"] = Block(cout, cout, cemb, flavor="dec", resample_mode="up", **block_kwargs)
            for idx in range(num_blocks + 1):
                # Each "block" entry receives the current features concatenated with one skip.
                cin = cout + skips.pop()
                cout = channels
                self.dec[f"{res}x{res}_block{idx}"] = Block(
                    cin,
                    cout,
                    cemb,
                    flavor="dec",
                    attention=(res in attn_resolutions),
                    **block_kwargs,
                )

        self.out_conv = MPConv(cout, img_channels, kernel=(3, 3))

    def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the U-Net.

        Args:
            x: Input feature map; a channel of ones is appended along dim 1.
            noise_labels: Values embedded by ``emb_fourier``.
            class_labels: Label vectors; required when the model was built with
                a non-zero ``label_dim``.

        Returns:
            Output of ``out_conv`` scaled by the learned ``out_gain``.

        Raises:
            ValueError: If the model is label-conditional and ``class_labels`` is ``None``.
        """
        # Embedding.
        emb = self.emb_noise(self.emb_fourier(noise_labels))
        if self.emb_label is not None:
            if class_labels is None:
                raise ValueError("class_labels are required for conditional EDM2UNet.")
            emb = mp_sum(emb, self.emb_label(class_labels * math.sqrt(class_labels.shape[1])), t=self.label_balance)
        emb = mp_silu(emb)

        # Encoder: every module's output is stored as a skip connection.
        x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1)
        skips = []
        for name, block in self.enc.items():
            # Only the level-0 "conv" entry is an MPConv and takes no embedding.
            x = block(x) if "conv" in name else block(x, emb)
            skips.append(x)

        # Decoder: "block" entries consume one skip each, in reverse encoder order.
        for name, block in self.dec.items():
            if "block" in name:
                x = mp_cat(x, skips.pop(), t=self.concat_balance)
            x = block(x, emb)
        return self.out_conv(x, gain=self.out_gain)
+
+
@dataclass
class EDM2UNet2DOutput(BaseOutput):
    """Return type of ``EDM2UNet2DModel.forward`` when ``return_dict=True``."""

    # Denoised sample computed as ``c_skip * x + c_out * F(x)``.
    sample: torch.Tensor
    # Output of the log-variance head; ``None`` unless ``return_logvar=True`` was passed.
    logvar: Optional[torch.Tensor] = None
+
+
+
# Names of the ``EDM2UNet2DModel.__init__`` arguments that are persisted in
# ``config.json``. ``from_pretrained`` keeps only these keys from the loaded
# config (dropping e.g. ``_class_name``) and ``save_pretrained`` writes exactly
# these keys, so the tuple must stay in sync with the constructor signature.
_CONFIG_KEYS = (
    "sample_size",
    "in_channels",
    "out_channels",
    "num_class_embeds",
    "use_fp16",
    "sigma_data",
    "logvar_channels",
    "model_channels",
    "channel_mult",
    "channel_mult_noise",
    "channel_mult_emb",
    "num_blocks",
    "attn_resolutions",
    "label_balance",
    "concat_balance",
    "dropout",
    "channels_per_head",
    "res_balance",
    "attn_balance",
    "clip_act",
)
+
+
class EDM2UNet2DModel(ModelMixin, ConfigMixin):
    """Preconditioned EDM2 denoiser wrapping ``EDM2UNet``.

    ``forward`` scales the input, runs the U-Net and recombines its output with
    the input, returning a denoised sample. A small log-variance head
    (``logvar_fourier`` / ``logvar_linear``) can additionally be evaluated.

    Args:
        sample_size: Spatial resolution passed to ``EDM2UNet`` as ``img_resolution``.
        in_channels: Passed to ``EDM2UNet`` as ``img_channels``.
        out_channels: Stored on the model; not forwarded to ``EDM2UNet``.
        num_class_embeds: Width of the class-label vector (``0`` = unconditional).
        use_fp16: Allow float16 execution of the U-Net on CUDA devices.
        sigma_data: Constant used by the preconditioning coefficients.
        logvar_channels: Width of the log-variance Fourier embedding.
        model_channels, channel_mult, channel_mult_noise, channel_mult_emb,
        num_blocks, attn_resolutions, label_balance, concat_balance, dropout,
        channels_per_head, res_balance, attn_balance, clip_act:
            Forwarded unchanged to ``EDM2UNet`` / ``Block``.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: int = 64,
        in_channels: int = 4,
        out_channels: int = 4,
        num_class_embeds: int = 0,
        use_fp16: bool = True,
        sigma_data: float = 0.5,
        logvar_channels: int = 128,
        model_channels: int = 192,
        channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
        channel_mult_noise: Optional[int] = None,
        channel_mult_emb: Optional[int] = None,
        num_blocks: int = 3,
        attn_resolutions: Tuple[int, ...] = (16, 8),
        label_balance: float = 0.5,
        concat_balance: float = 0.5,
        dropout: float = 0.0,
        channels_per_head: int = 64,
        res_balance: float = 0.3,
        attn_balance: float = 0.3,
        clip_act: Optional[float] = 256,
    ):
        super().__init__()
        # The hyper-parameters are mirrored as plain attributes; save_pretrained
        # falls back to them when a key is missing from ``self.config``.
        self.sample_size = sample_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_class_embeds = num_class_embeds
        self.use_fp16 = use_fp16
        self.sigma_data = sigma_data
        self.model_channels = model_channels
        self.channel_mult = channel_mult
        self.channel_mult_noise = channel_mult_noise
        self.channel_mult_emb = channel_mult_emb
        self.num_blocks = num_blocks
        self.attn_resolutions = attn_resolutions
        self.label_balance = label_balance
        self.concat_balance = concat_balance
        self.dropout = dropout
        self.channels_per_head = channels_per_head
        self.res_balance = res_balance
        self.attn_balance = attn_balance
        self.clip_act = clip_act
        self.unet = EDM2UNet(
            img_resolution=sample_size,
            img_channels=in_channels,
            label_dim=num_class_embeds,
            model_channels=model_channels,
            channel_mult=channel_mult,
            channel_mult_noise=channel_mult_noise,
            channel_mult_emb=channel_mult_emb,
            num_blocks=num_blocks,
            attn_resolutions=attn_resolutions,
            label_balance=label_balance,
            concat_balance=concat_balance,
            dropout=dropout,
            channels_per_head=channels_per_head,
            res_balance=res_balance,
            attn_balance=attn_balance,
            clip_act=clip_act,
        )
        self.logvar_fourier = MPFourier(logvar_channels)
        self.logvar_linear = MPConv(logvar_channels, 1, kernel=())

    def forward(
        self,
        sample: torch.Tensor,
        sigma: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        force_fp32: bool = False,
        return_logvar: bool = False,
        return_dict: bool = True,
    ) -> EDM2UNet2DOutput:
        """Denoise ``sample`` at noise level ``sigma``.

        Args:
            sample: Noisy input; cast to float32.
            sigma: Noise level(s); reshaped to ``(-1, 1, 1, 1)`` for broadcasting.
            class_labels: Label vectors reshaped to ``(-1, num_class_embeds)``.
                Ignored for unconditional models; an all-zero vector is used
                when omitted on a conditional model.
            force_fp32: Run the U-Net in float32 even if ``use_fp16`` is set.
            return_logvar: Also evaluate the log-variance head.
            return_dict: Return an ``EDM2UNet2DOutput`` instead of a tuple.

        Returns:
            ``EDM2UNet2DOutput(sample, logvar)``, or the tuple ``(sample, logvar)``
            when ``return_dict`` is ``False``.
        """
        x = sample.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        if self.num_class_embeds == 0:
            class_labels = None
        else:
            if class_labels is None:
                class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device)
            class_labels = class_labels.to(torch.float32).reshape(-1, self.num_class_embeds)
        # float16 is only used on CUDA, and only when not explicitly disabled.
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") else torch.float32

        # Preconditioning coefficients.
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
        c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
        c_noise = sigma.flatten().log() / 4

        # The network runs in ``dtype``; the recombination is done in float32.
        x_in = (c_in * x).to(dtype)
        f_x = self.unet(x_in, c_noise, class_labels)
        d_x = c_skip * x + c_out * f_x.to(torch.float32)

        logvar = None
        if return_logvar:
            logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)

        if not return_dict:
            return (d_x, logvar)
        return EDM2UNet2DOutput(sample=d_x, logvar=logvar)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs):
        """Load a model from a local directory.

        Reads ``config.json`` and then ``diffusion_pytorch_model.safetensors``
        (falling back to ``diffusion_pytorch_model.bin``) from the directory.

        Args:
            pretrained_model_name_or_path: Local directory of the checkpoint.
            torch_dtype: If given, the model is cast to this dtype after loading.
            **kwargs: Only ``subfolder`` is honoured; other keys are ignored.

        Returns:
            The model with weights loaded strictly.
        """
        subfolder = kwargs.pop("subfolder", None)
        model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path
        with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f:
            config = json.load(f)
        # Drop bookkeeping entries such as ``_class_name``.
        init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS}
        model = cls(**init_kwargs)
        weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors")
        if os.path.isfile(weight_file):
            from safetensors.torch import load_file

            state_dict = load_file(weight_file)
        else:
            # SECURITY: ``.bin`` checkpoints are pickle files. ``weights_only=True``
            # restricts unpickling to tensors and plain containers so a
            # downloaded checkpoint cannot execute arbitrary code on load.
            state_dict = torch.load(
                os.path.join(model_dir, "diffusion_pytorch_model.bin"),
                map_location="cpu",
                weights_only=True,
            )
        model.load_state_dict(state_dict, strict=True)
        if torch_dtype is not None:
            model = model.to(dtype=torch_dtype)
        return model

    def save_pretrained(self, save_directory: str, safe_serialization: bool = True):
        """Write ``config.json`` and the weights to ``save_directory``.

        Args:
            save_directory: Target directory (created if missing).
            safe_serialization: Write ``.safetensors`` when true, otherwise a
                ``torch.save`` ``.bin`` file.
        """
        os.makedirs(save_directory, exist_ok=True)
        stored = dict(getattr(self, "config", {}))
        config = {"_class_name": self.__class__.__name__}
        for key in _CONFIG_KEYS:
            # Prefer the registered config; fall back to the mirrored attribute.
            if key in stored:
                config[key] = stored[key]
            elif hasattr(self, key):
                config[key] = getattr(self, key)
        with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2, sort_keys=True)
            f.write("\n")
        state_dict = self.state_dict()
        if safe_serialization:
            from safetensors.torch import save_file

            save_file(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.safetensors"))
        else:
            torch.save(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.bin"))
diff --git a/edm2-img512-xs-fid/vae/config.json b/edm2-img512-xs-fid/vae/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..ae2ab37de2acb8bc348d3de809ae8385324be962
--- /dev/null
+++ b/edm2-img512-xs-fid/vae/config.json
@@ -0,0 +1,38 @@
+{
+ "_class_name": "AutoencoderKL",
+ "_diffusers_version": "0.36.0",
+ "_name_or_path": "stabilityai/sd-vae-ft-mse",
+ "act_fn": "silu",
+ "block_out_channels": [
+ 128,
+ 256,
+ 512,
+ 512
+ ],
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D"
+ ],
+ "force_upcast": true,
+ "in_channels": 3,
+ "latent_channels": 4,
+ "latents_mean": null,
+ "latents_std": null,
+ "layers_per_block": 2,
+ "mid_block_add_attention": true,
+ "norm_num_groups": 32,
+ "out_channels": 3,
+ "sample_size": 256,
+ "scaling_factor": 0.18215,
+ "shift_factor": null,
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D"
+ ],
+ "use_post_quant_conv": true,
+ "use_quant_conv": true
+}
diff --git a/edm2-img512-xs-fid/vae/diffusion_pytorch_model.safetensors b/edm2-img512-xs-fid/vae/diffusion_pytorch_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..90464d67ac7303d0ee4696334df13da130a948ea
--- /dev/null
+++ b/edm2-img512-xs-fid/vae/diffusion_pytorch_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815
+size 334643276
diff --git a/edm2-img512-xxl-fid/README.md b/edm2-img512-xxl-fid/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7e6079717eab26ea8b83bd835b2c9386dc5852da
--- /dev/null
+++ b/edm2-img512-xxl-fid/README.md
@@ -0,0 +1,67 @@
+---
+license: cc-by-nc-sa-4.0
+library_name: diffusers
+pipeline_tag: unconditional-image-generation
+tags:
+ - diffusers
+ - edm2
+ - image-generation
+ - class-conditional
+ - imagenet
+inference: true
+widget:
+ - output:
+ url: demo.png
+language:
+ - en
+---
+
+# edm2-img512-xxl-fid
+
+Self-contained Diffusers checkpoint for **EDM2-XXL** at 512×512, optimized for FID (NVlabs preset `edm2-img512-xxl-fid`).
+
+Converted from [NVlabs/edm2](https://github.com/NVlabs/edm2) post-hoc reconstruction
+`edm2-img512-xxl-0939524-0.070.pkl` (FID 1.91).
+
+## Demo
+
+
+
+Class-conditional sample (ImageNet class **207**, golden retriever), 512×512, 32 steps, guidance 1.0, seed 42.
+
+## Load
+
+```python
+from pathlib import Path
+import torch
+from diffusers import DiffusionPipeline
+
+model_dir = Path(".")
+pipe = DiffusionPipeline.from_pretrained(
+ str(model_dir),
+ local_files_only=True,
+ trust_remote_code=True,
+ torch_dtype=torch.bfloat16,
+).to("cuda")
+
+generator = torch.Generator(device="cuda").manual_seed(42)
+image = pipe(
+ class_labels=207,
+ num_inference_steps=32,
+ guidance_scale=1.0,
+ generator=generator,
+).images[0]
+image.save("demo.png")
+```
+
+Official NVlabs defaults (`generate_images.py`): `num_steps=32`, `sigma_min=0.002`, `sigma_max=80`,
+`rho=7`, `guidance=1.0` (no gnet), `S_churn=0`. Heun sampling runs in float32 internally even when
+UNet/VAE weights are loaded in bf16/fp16.
+
+## Components
+
+- `pipeline.py`
+- `unet/unet_edm2.py`
+- `unet/diffusion_pytorch_model.safetensors`
+- `scheduler/scheduler_config.json` (`EDMEulerScheduler`)
+- `vae/diffusion_pytorch_model.safetensors` (`stabilityai/sd-vae-ft-mse`)
diff --git a/edm2-img512-xxl-fid/demo.png b/edm2-img512-xxl-fid/demo.png
new file mode 100644
index 0000000000000000000000000000000000000000..83e3dbac321e2b6cbcee88d03903a7e9c1c0439a
--- /dev/null
+++ b/edm2-img512-xxl-fid/demo.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9da85f6cb881c112c3240fa5f62b3331af9700a656b882e01ebf9df4ea05660f
+size 374923
diff --git a/edm2-img512-xxl-fid/model_index.json b/edm2-img512-xxl-fid/model_index.json
new file mode 100644
index 0000000000000000000000000000000000000000..24c0305eea3ec64abbc2f870476afa4f5e462f42
--- /dev/null
+++ b/edm2-img512-xxl-fid/model_index.json
@@ -0,0 +1,19 @@
+{
+ "_class_name": [
+ "pipeline",
+ "EDM2Pipeline"
+ ],
+ "_diffusers_version": "0.31.0",
+ "scheduler": [
+ "diffusers",
+ "EDMEulerScheduler"
+ ],
+ "unet": [
+ "unet_edm2",
+ "EDM2UNet2DModel"
+ ],
+ "vae": [
+ "diffusers",
+ "AutoencoderKL"
+ ]
+}
diff --git a/edm2-img512-xxl-fid/pipeline.py b/edm2-img512-xxl-fid/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b4131215d07770c74efa0674d00f3d6423d9c06
--- /dev/null
+++ b/edm2-img512-xxl-fid/pipeline.py
@@ -0,0 +1,406 @@
+"""Hub custom pipeline: EDM2Pipeline.
+Load with native Hugging Face diffusers and trust_remote_code=True.
+"""
+
+from __future__ import annotations
+
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from pathlib import Path
+from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from diffusers.utils import replace_example_docstring
+from diffusers.utils.torch_utils import randn_tensor
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from pathlib import Path
+ >>> import torch
+ >>> from diffusers import DiffusionPipeline
+
+ >>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve()
+ >>> pipe = DiffusionPipeline.from_pretrained(
+ ... str(model_dir),
+ ... local_files_only=True,
+ ... custom_pipeline=str(model_dir / "pipeline.py"),
+ ... trust_remote_code=True,
+ ... torch_dtype=torch.float32,
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> generator = torch.Generator(device="cuda").manual_seed(42)
+ >>> image = pipe(
+ ... class_labels=207,
+ ... num_inference_steps=32,
+ ... guidance_scale=1.0,
+ ... generator=generator,
+ ... ).images[0]
+ >>> image.save("demo.png")
+ ```
+"""
+
# Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py).
# Per-channel affine map for the 4 latent channels; `decode_latents` inverts it
# with `(x - bias) / scale` before calling the VAE decoder.
# NOTE(review): the two literal vectors are presumably the per-channel std and
# mean of the raw VAE latents - confirm against NVlabs/edm2.
_STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28])
_STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE
+
class EDM2Pipeline(DiffusionPipeline):
    r"""
    Pipeline for class-conditional image generation with EDM2
    ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)).

    Parameters:
        unet ([`EDM2UNet2DModel`]):
            Main magnitude-preserving U-Net with EDM preconditioning.
        scheduler ([`EDMEulerScheduler`]):
            Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in
            the pipeline because the UNet returns denoised latents rather than noise predictions.
        vae ([`AutoencoderKL`], *optional*):
            Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`.
        gnet ([`EDM2UNet2DModel`], *optional*):
            Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`).
        id2label (`dict[int, str]`, *optional*):
            ImageNet class id to English label mapping.
    """

    model_cpu_offload_seq = "unet->gnet->vae"
    _optional_components = ["vae", "gnet"]

    def __init__(
        self,
        unet,
        scheduler,
        vae=None,
        gnet=None,
        id2label: Optional[Dict[Union[int, str], str]] = None,
    ) -> None:
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet)
        self._id2label = self._normalize_id2label(id2label)
        self.labels = self._build_label2id(self._id2label)
        # When no mapping is passed in, it is read lazily from model_index.json.
        self._labels_loaded_from_model_index = bool(self._id2label)
        # Without a VAE the UNet works directly in pixel space (scale factor 1).
        self.vae_scale_factor = 8 if self.vae is not None else 1
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)

    @staticmethod
    def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
        """Return a copy of ``id2label`` with integer keys (``{}`` for empty input)."""
        if not id2label:
            return {}
        return {int(key): value for key, value in id2label.items()}

    @staticmethod
    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
        """Invert ``id2label``, mapping every comma-separated synonym to its class id.

        The result is sorted by label; a synonym that appears under several ids
        keeps the id encountered last.
        """
        label2id: Dict[str, int] = {}
        for class_id, value in id2label.items():
            for synonym in value.split(","):
                synonym = synonym.strip()
                if synonym:
                    label2id[synonym] = int(class_id)
        return dict(sorted(label2id.items()))

    def _ensure_labels_loaded(self) -> None:
        """Populate the label tables from ``model_index.json`` if not done yet."""
        if self._labels_loaded_from_model_index:
            return
        loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
        if loaded:
            self._id2label = loaded
            self.labels = self._build_label2id(self._id2label)
            self._labels_loaded_from_model_index = True

    @staticmethod
    def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
        """Read the optional ``id2label`` dict from ``<variant_path>/model_index.json``.

        Returns ``{}`` when the path is empty, the file is missing or the entry
        is not a dict.
        """
        if not variant_path:
            return {}
        model_index_path = Path(variant_path).resolve() / "model_index.json"
        if not model_index_path.is_file():
            return {}
        raw = json.loads(model_index_path.read_text(encoding="utf-8"))
        id2label = raw.get("id2label")
        if not isinstance(id2label, dict):
            return {}
        return {int(key): value for key, value in id2label.items()}

    @property
    def id2label(self) -> Dict[int, str]:
        r"""ImageNet class id to English label string (comma-separated synonyms)."""
        self._ensure_labels_loaded()
        return self._id2label

    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
        r"""
        Map ImageNet label strings to class ids.

        Args:
            label (`str` or `list[str]`):
                One or more English label strings that match entries in `id2label`.
        """
        self._ensure_labels_loaded()
        if not self.labels:
            raise ValueError("No English labels loaded. Add `id2label` to model_index.json.")
        labels = [label] if isinstance(label, str) else list(label)
        missing = [item for item in labels if item not in self.labels]
        if missing:
            preview = ", ".join(list(self.labels.keys())[:8])
            raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
        return [self.labels[item] for item in labels]

    def _default_image_size(self) -> int:
        """Native output size in pixels: UNet ``sample_size`` times the VAE scale factor."""
        latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64)))
        return latent_size * self.vae_scale_factor

    def check_inputs(
        self,
        height: int,
        width: int,
        num_inference_steps: int,
        guidance_scale: float,
        output_type: str,
    ) -> None:
        """Validate call arguments; raises ``ValueError`` on the first violation."""
        if num_inference_steps < 1:
            raise ValueError("num_inference_steps must be >= 1.")
        if guidance_scale < 1.0:
            raise ValueError("guidance_scale must be >= 1.0.")
        if guidance_scale > 1.0 and self.gnet is None:
            raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).")
        if output_type not in {"pil", "np", "pt", "latent"}:
            raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")

        # Only the square native resolution is accepted.
        native_size = self._default_image_size()
        if height != native_size or width != native_size:
            raise ValueError(
                f"EDM2 expects native resolution height=width={native_size}. "
                f"Got height={height}, width={width}."
            )

    def _normalize_class_labels(
        self,
        class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]],
        batch_size: int,
        device: torch.device,
    ) -> Optional[torch.Tensor]:
        """Convert any accepted label format into a ``(batch_size, label_dim)`` float tensor.

        Args:
            class_labels: ``None`` (random classes), a class index, a label
                string, a sequence of either, a tensor of indices, or a 2-D
                tensor that is used as-is.
            batch_size: Required number of rows.
            device: Device of the returned tensor.

        Returns:
            One-hot float32 labels, or ``None`` for an unconditional UNet.

        Raises:
            ValueError: If the labels do not match ``batch_size`` or an index
                falls outside the valid class range.
        """
        label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0)))
        if label_dim == 0:
            return None
        if class_labels is None:
            indices = torch.randint(label_dim, size=(batch_size,), device=device)
            return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]

        # Resolve label strings to class ids first.
        if isinstance(class_labels, str):
            class_labels = self.get_label_ids(class_labels)[0]
        elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str):
            class_labels = self.get_label_ids(list(class_labels))

        if isinstance(class_labels, int):
            indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long)
        elif isinstance(class_labels, torch.Tensor):
            if class_labels.ndim == 2:
                # 2-D tensors are treated as ready-made label vectors.
                labels = class_labels.to(device=device, dtype=torch.float32)
                if labels.shape[0] != batch_size:
                    raise ValueError(f"class_labels batch must match batch_size={batch_size}.")
                return labels
            indices = class_labels.to(device=device, dtype=torch.long).flatten()
        else:
            indices = torch.tensor(list(class_labels), device=device, dtype=torch.long)

        # A single label is broadcast to the whole batch.
        if indices.numel() == 1 and batch_size > 1:
            indices = indices.repeat(batch_size)
        if indices.numel() != batch_size:
            raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.")
        # Reject out-of-range ids up front: the indexing below would otherwise
        # fail with an opaque IndexError (or a device-side assert on CUDA).
        # Negative ids down to -label_dim remain valid, as with plain indexing.
        if indices.numel() > 0:
            lowest, highest = int(indices.min()), int(indices.max())
            if lowest < -label_dim or highest >= label_dim:
                raise ValueError(
                    f"class_labels out of range for {label_dim} classes: got values in [{lowest}, {highest}]."
                )
        return torch.eye(label_dim, device=device, dtype=torch.float32)[indices]

    def prepare_latents(
        self,
        batch_size: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: torch.device,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]],
    ) -> torch.Tensor:
        """Draw the initial float32 noise of shape ``(batch, in_channels, size, size)``.

        ``size`` is ``height // vae_scale_factor``; ``width`` and ``dtype`` are
        accepted but not used by this implementation.
        """
        in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4)))
        latent_size = height // self.vae_scale_factor
        return randn_tensor(
            (batch_size, in_channels, latent_size, latent_size),
            generator=generator,
            device=device,
            dtype=torch.float32,
        )

    def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
        """Turn sampler output into images of the requested ``output_type``.

        ``"latent"`` returns ``latents`` untouched. Without a VAE the latents
        are interpreted as pixels and mapped to ``[0, 1]``. With a VAE,
        4-channel latents are first un-whitened with the Stability VAE
        constants, then decoded and clamped to ``[0, 1]``.
        """
        if output_type == "latent":
            return latents

        in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 3)))
        if self.vae is None:
            image = (latents.to(torch.float32) * 127.5 + 128).clip(0, 255) / 255.0
            return self.image_processor.postprocess(image, output_type=output_type)

        if in_channels == 4:
            x = latents.to(torch.float32)
            scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
            bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1)
            x = (x - bias) / scale
        else:
            x = latents.to(torch.float32)

        # Decode in the VAE's own dtype, then return to float32 for post-processing.
        vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype
        image = self.vae.decode(x.to(dtype=vae_dtype)).sample.to(torch.float32).clamp(0, 1)

        return self.image_processor.postprocess(image, output_type=output_type)

    @staticmethod
    def _apply_autoguidance(
        main: torch.Tensor,
        ref: torch.Tensor,
        guidance_scale: float,
    ) -> torch.Tensor:
        """Extrapolate from the guiding-network output towards the main output."""
        return ref.lerp(main, guidance_scale)

    @staticmethod
    def _sample_edm2_heun(
        denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        noise: torch.Tensor,
        sigmas: torch.Tensor,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        progress_bar: Optional[Callable[[Iterable], Iterable]] = None,
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0).

        Args:
            denoise_fn: Maps ``(x, sigma)`` to the denoised estimate of ``x``.
            noise: Unit-variance starting noise; scaled by ``sigmas[0]``.
            sigmas: Decreasing noise levels; consecutive pairs define the steps.
            generator: Accepted for interface symmetry; not used here.
            progress_bar: Optional wrapper applied to the list of step pairs.
            dtype: Dtype the noise is cast to before sampling.

        Returns:
            The sample after the final step.
        """
        x_next = noise.to(dtype) * sigmas[0]

        sigma_pairs = list(zip(sigmas[:-1], sigmas[1:]))
        # Count the steps on the plain list: the progress-bar wrapper is only
        # required to be iterable and may not support len().
        num_steps = len(sigma_pairs)
        if progress_bar is not None:
            sigma_pairs = progress_bar(sigma_pairs)

        for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs):
            x_hat, sigma_hat = x_next, sigma_cur
            # Euler step.
            d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat
            x_next = x_hat + (sigma_next - sigma_hat) * d_cur
            # Second-order correction, skipped on the last step.
            if i < num_steps - 1:
                d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next
                x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime)
        return x_next

    @torch.inference_mode()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None,
        batch_size: int = 1,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 32,
        guidance_scale: float = 1.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: str = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Generate class-conditional images with EDM2.

        Args:
            class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*):
                ImageNet class indices, English label strings, or one-hot float tensors.
                Random classes are sampled when omitted on conditional models.
            batch_size (`int`, defaults to `1`):
                Number of images to generate.
            height (`int`, *optional*):
                Output height in pixels. Defaults to the pretrained native resolution.
            width (`int`, *optional*):
                Output width in pixels. Defaults to the pretrained native resolution.
            num_inference_steps (`int`, defaults to `32`):
                Number of EDM2 Heun steps (NVlabs default).
            guidance_scale (`float`, defaults to `1.0`):
                Autoguidance strength. Values above `1.0` blend the main net with `gnet`
                via `gnet_output.lerp(unet_output, guidance_scale)`.
            generator (`torch.Generator`, *optional*):
                RNG for reproducibility.
            output_type (`str`, defaults to `"pil"`):
                `"pil"`, `"np"`, `"pt"`, or `"latent"`.
            return_dict (`bool`, defaults to `True`):
                Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if True.

        Examples:

        """
        default_size = self._default_image_size()
        height = int(height or default_size)
        width = int(width or default_size)
        self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type)

        device = self._execution_device
        dtype = self.unet.dtype
        labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
        noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)

        def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
            # The sampler passes a scalar sigma; the UNet receives one value per sample.
            sigma_batch = sigma.reshape(1).expand(batch_size)
            main = self.unet(
                sample=x,
                sigma=sigma_batch,
                class_labels=labels,
                force_fp32=True,
            ).sample
            if guidance_scale == 1.0 or self.gnet is None:
                return main.to(torch.float32)
            ref = self.gnet(
                sample=x,
                sigma=sigma_batch,
                class_labels=labels,
                force_fp32=True,
            ).sample
            return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32)

        # The scheduler only provides the sigma schedule; stepping is done by the sampler.
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        latents = self._sample_edm2_heun(
            denoise_fn=denoise_fn,
            noise=noise,
            sigmas=self.scheduler.sigmas.to(device),
            generator=generator,
            progress_bar=self.progress_bar,
            dtype=torch.float32,
        )

        image = self.decode_latents(latents, output_type=output_type)
        if not return_dict:
            return (image, latents)
        return ImagePipelineOutput(images=image)

    @classmethod
    def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None):
        """Best-effort loader for the VAE that belongs to a checkpoint directory.

        Tries ``<path>/vae`` first, then the hub id or path stored in
        ``<path>/vae_pretrained_model_name_or_path.txt``.

        Returns:
            The loaded ``AutoencoderKL``, or ``None`` when no VAE source is
            found or the local ``vae`` folder cannot be loaded.
        """
        # AutoencoderKL is not imported at module level; without this import
        # every branch below failed with NameError (silently turned into
        # ``None`` by the first branch's ``except``).
        from diffusers import AutoencoderKL

        vae_dir = os.path.join(pretrained_model_name_or_path, "vae")
        if os.path.isdir(vae_dir):
            try:
                return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype)
            except Exception as exc:
                # Stay best-effort, but make the failure visible.
                import warnings

                warnings.warn(f"Could not load VAE from {vae_dir!r}: {exc}")
                return None

        vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt")
        if os.path.isfile(vae_hint):
            with open(vae_hint, "r", encoding="utf-8") as f:
                hub_id = f.read().strip()
            if hub_id:
                return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype)
        return None
diff --git a/edm2-img512-xxl-fid/scheduler/scheduler_config.json b/edm2-img512-xxl-fid/scheduler/scheduler_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..08cbde97955d06d66b0b72be217db84bb797d711
--- /dev/null
+++ b/edm2-img512-xxl-fid/scheduler/scheduler_config.json
@@ -0,0 +1,11 @@
+{
+ "_class_name": "EDMEulerScheduler",
+ "final_sigmas_type": "zero",
+ "num_train_timesteps": 1000,
+ "prediction_type": "epsilon",
+ "rho": 7.0,
+ "sigma_data": 0.5,
+ "sigma_max": 80.0,
+ "sigma_min": 0.002,
+ "sigma_schedule": "karras"
+}
diff --git a/edm2-img512-xxl-fid/unet/config.json b/edm2-img512-xxl-fid/unet/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..e62e0c056bed5b2ae37ccc58d5bb2a064116d52e
--- /dev/null
+++ b/edm2-img512-xxl-fid/unet/config.json
@@ -0,0 +1,31 @@
+{
+ "_class_name": "EDM2UNet2DModel",
+ "attn_balance": 0.3,
+ "attn_resolutions": [
+ 16,
+ 8
+ ],
+ "channel_mult": [
+ 1,
+ 2,
+ 3,
+ 4
+ ],
+ "channel_mult_emb": 4,
+ "channel_mult_noise": 1,
+ "channels_per_head": 64,
+ "clip_act": 256,
+ "concat_balance": 0.5,
+ "dropout": 0.0,
+ "in_channels": 4,
+ "label_balance": 0.5,
+ "logvar_channels": 128,
+ "model_channels": 448,
+ "num_blocks": 3,
+ "num_class_embeds": 1000,
+ "out_channels": 4,
+ "res_balance": 0.3,
+ "sample_size": 64,
+ "sigma_data": 0.5,
+ "use_fp16": true
+}
diff --git a/edm2-img512-xxl-fid/unet/diffusion_pytorch_model.safetensors b/edm2-img512-xxl-fid/unet/diffusion_pytorch_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..67ed03655a4a647f47d640e3e0a97ca403eb4bfc
--- /dev/null
+++ b/edm2-img512-xxl-fid/unet/diffusion_pytorch_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:018c4f1a23a207e787667d195a2df9522be8543590d72cd47f6420590700da2e
+size 6092686516
diff --git a/edm2-img512-xxl-fid/unet/unet_edm2.py b/edm2-img512-xxl-fid/unet/unet_edm2.py
new file mode 100644
index 0000000000000000000000000000000000000000..49d9451b443c4d6ad3b6ba2db19a7a804b07b6de
--- /dev/null
+++ b/edm2-img512-xxl-fid/unet/unet_edm2.py
@@ -0,0 +1,434 @@
+import math
+import json
+import os
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+
+# Use the real diffusers base classes when diffusers can be imported; otherwise
+# define minimal stand-ins with the same names so this module still imports.
+try:
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.utils import BaseOutput
+except ImportError: # pragma: no cover
+ class ModelMixin(torch.nn.Module):
+ pass
+
+ class ConfigMixin:
+ config = {}
+
+ def register_to_config(self, **kwargs):
+ self.config = kwargs
+
+ # Stand-in decorator: returns the function unchanged, so nothing is recorded
+ # in ``config`` when it wraps ``__init__``.
+ def register_to_config(func):
+ return func
+
+ @dataclass
+ class BaseOutput:
+ pass
+
+
+def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor:
+    """Divide ``x`` by its rescaled vector norm plus ``eps``.
+
+    Args:
+        x: Tensor to normalize.
+        dim: Dimensions the norm is taken over; every dimension except the
+            first when ``None``.
+        eps: Constant added to the rescaled norm before dividing.
+
+    Returns:
+        ``x`` divided by ``eps + norm * sqrt(norm.numel() / x.numel())``, with
+        the divisor cast to ``x.dtype``.
+    """
+    reduce_dims = list(range(1, x.ndim)) if dim is None else dim
+    # The norm is accumulated in float32 regardless of the dtype of ``x``.
+    magnitude = torch.linalg.vector_norm(x, dim=reduce_dims, keepdim=True, dtype=torch.float32)
+    rescale = math.sqrt(magnitude.numel() / x.numel())
+    divisor = torch.add(eps, magnitude, alpha=rescale)
+    return x / divisor.to(x.dtype)
+
+
+def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor:
+    """Resample ``x`` by a factor of two with a separable filter, or pass it through.
+
+    Args:
+        x: Input tensor; the code reads ``x.shape[1]`` as the channel count and
+            applies 2D convolutions, so a 4D layout is expected.
+        f: 1D filter taps; they are normalized to sum to one and expanded to a
+            2D kernel via an outer product.
+        mode: ``"keep"`` returns ``x`` unchanged, ``"down"`` applies a stride-2
+            grouped convolution, and any other value applies a stride-2 grouped
+            transposed convolution with the kernel scaled by 4.
+
+    Returns:
+        The resampled tensor (or ``x`` itself for ``"keep"``).
+    """
+    if mode == "keep":
+        return x
+
+    taps = np.float32(f)
+    pad = (len(taps) - 1) // 2
+    taps = taps / taps.sum()
+    kernel_2d = np.outer(taps, taps)[np.newaxis, np.newaxis, :, :]
+    kernel = torch.as_tensor(kernel_2d, dtype=x.dtype, device=x.device)
+    channels = x.shape[1]
+    functional = torch.nn.functional
+
+    if mode != "down":
+        up_kernel = (kernel * 4).tile([channels, 1, 1, 1])
+        return functional.conv_transpose2d(x, up_kernel, groups=channels, stride=2, padding=(pad,))
+    down_kernel = kernel.tile([channels, 1, 1, 1])
+    return functional.conv2d(x, down_kernel, groups=channels, stride=2, padding=(pad,))
+
+
+def mp_silu(x: torch.Tensor) -> torch.Tensor:
+    """Return ``silu(x)`` divided by the constant 0.596."""
+    activated = torch.nn.functional.silu(x)
+    return activated / 0.596
+
+
+def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor:
+    """Blend ``a`` and ``b`` with weight ``t`` and rescale the result.
+
+    Returns ``lerp(a, b, t) / sqrt((1 - t)^2 + t^2)``.
+    """
+    blended = a.lerp(b, t)
+    rescale = math.sqrt((1 - t) ** 2 + t ** 2)
+    return blended / rescale
+
+
+def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor:
+    """Concatenate ``a`` and ``b`` along ``dim`` after weighting each input.
+
+    The weights depend on the sizes of the two inputs along ``dim`` and on the
+    balance ``t``: ``a`` is scaled by ``c / sqrt(na) * (1 - t)`` and ``b`` by
+    ``c / sqrt(nb) * t``, with ``c = sqrt((na + nb) / ((1 - t)^2 + t^2))``.
+    """
+    size_a, size_b = a.shape[dim], b.shape[dim]
+    common = math.sqrt((size_a + size_b) / ((1 - t) ** 2 + t ** 2))
+    weight_a = common / math.sqrt(size_a) * (1 - t)
+    weight_b = common / math.sqrt(size_b) * t
+    weighted = [weight_a * a, weight_b * b]
+    return torch.cat(weighted, dim=dim)
+
+
+class MPFourier(torch.nn.Module):
+    """Map a 1D tensor of scalars to cosine features with fixed random frequencies.
+
+    ``freqs`` and ``phases`` are registered as buffers and drawn randomly at
+    construction time.
+    """
+
+    def __init__(self, num_channels: int, bandwidth: float = 1):
+        super().__init__()
+        two_pi = 2 * math.pi
+        # Draw order (randn, then rand) is kept so the RNG stream is consumed identically.
+        self.register_buffer("freqs", two_pi * torch.randn(num_channels) * bandwidth)
+        self.register_buffer("phases", two_pi * torch.rand(num_channels))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Return ``cos(outer(x, freqs) + phases) * sqrt(2)`` cast back to ``x.dtype``."""
+        values = x.to(torch.float32)
+        freqs = self.freqs.to(torch.float32)
+        phases = self.phases.to(torch.float32)
+        features = (torch.outer(values, freqs) + phases).cos() * math.sqrt(2)
+        return features.to(x.dtype)
+
+
+class MPConv(torch.nn.Module):
+    """Linear or 2D-convolution layer whose weight is passed through ``normalize``.
+
+    An empty ``kernel`` gives a 2D weight applied as a matrix product; otherwise
+    the weight is used with ``conv2d``.
+    """
+
+    def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]):
+        super().__init__()
+        self.out_channels = out_channels
+        weight_shape = (out_channels, in_channels, *kernel)
+        self.weight = torch.nn.Parameter(torch.randn(*weight_shape))
+
+    def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor:
+        """Apply the normalized weight to ``x``.
+
+        Args:
+            x: Input tensor; the weight is cast to ``x.dtype`` before use.
+            gain: Multiplier applied to the weight together with
+                ``1 / sqrt(weight[0].numel())``.
+        """
+        weight = self.weight.to(torch.float32)
+        if self.training:
+            # While training, the stored parameter itself is overwritten with
+            # its normalized value (outside of autograd).
+            with torch.no_grad():
+                self.weight.copy_(normalize(weight))
+        weight = normalize(weight)
+        fan_in = weight[0].numel()
+        weight = (weight * (gain / math.sqrt(fan_in))).to(x.dtype)
+        if weight.ndim != 2:
+            half_kernel = weight.shape[-1] // 2
+            return torch.nn.functional.conv2d(x, weight, padding=(half_kernel,))
+        return x @ weight.t()
+
+
+class Block(torch.nn.Module):
+    """Residual block with embedding modulation and optional self-attention.
+
+    Args:
+        in_channels: Channel count of the incoming features.
+        out_channels: Channel count produced by the block.
+        emb_channels: Width of the conditioning embedding fed to ``emb_linear``.
+        flavor: ``"enc"`` applies the skip projection and a channel-wise
+            ``normalize`` before the residual branch; ``"dec"`` applies the skip
+            projection after the residual branch.
+        resample_mode: Mode forwarded to ``resample`` at the start of ``forward``.
+        resample_filter: Filter taps forwarded to ``resample``; ``[1, 1]`` when
+            ``None``.
+        attention: Whether to add the self-attention branch. The number of heads
+            is ``out_channels // channels_per_head``, so attention is disabled if
+            that quotient is zero.
+        channels_per_head: Channels per attention head.
+        dropout: Dropout probability, applied only in training mode.
+        res_balance: Blend weight passed to ``mp_sum`` for the residual branch.
+        attn_balance: Blend weight passed to ``mp_sum`` for the attention branch.
+        clip_act: Output is clipped in place to ``[-clip_act, clip_act]``; no
+            clipping when ``None``.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        emb_channels: int,
+        flavor: str = "enc",
+        resample_mode: str = "keep",
+        resample_filter: Optional[List[float]] = None,
+        attention: bool = False,
+        channels_per_head: int = 64,
+        dropout: float = 0.0,
+        res_balance: float = 0.3,
+        attn_balance: float = 0.3,
+        clip_act: Optional[float] = 256,
+    ):
+        super().__init__()
+        self.out_channels = out_channels
+        self.flavor = flavor
+        # A fresh list per instance avoids sharing one mutable default object
+        # between every Block.
+        self.resample_filter = [1, 1] if resample_filter is None else resample_filter
+        self.resample_mode = resample_mode
+        self.num_heads = out_channels // channels_per_head if attention else 0
+        self.dropout = dropout
+        self.res_balance = res_balance
+        self.attn_balance = attn_balance
+        self.clip_act = clip_act
+        self.emb_gain = torch.nn.Parameter(torch.zeros([]))
+        # Encoder blocks project the skip path first, so conv_res0 already sees
+        # out_channels; decoder blocks feed it the unprojected input.
+        self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3))
+        self.emb_linear = MPConv(emb_channels, out_channels, kernel=())
+        self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3))
+        self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None
+        self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None
+        self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None
+
+    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+        """Run the block on features ``x`` conditioned on embedding ``emb``."""
+        x = resample(x, f=self.resample_filter, mode=self.resample_mode)
+        if self.flavor == "enc":
+            if self.conv_skip is not None:
+                x = self.conv_skip(x)
+            x = normalize(x, dim=[1])
+
+        # Residual branch: conv -> embedding-dependent scaling -> conv.
+        y = self.conv_res0(mp_silu(x))
+        c = self.emb_linear(emb, gain=self.emb_gain) + 1
+        y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype))
+        if self.training and self.dropout:
+            y = torch.nn.functional.dropout(y, p=self.dropout)
+        y = self.conv_res1(y)
+
+        if self.flavor == "dec" and self.conv_skip is not None:
+            x = self.conv_skip(x)
+        x = mp_sum(x, y, t=self.res_balance)
+
+        if self.num_heads:
+            # Self-attention over the flattened spatial positions.
+            y = self.attn_qkv(x)
+            y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3])
+            q, k, v = normalize(y, dim=[2]).unbind(3)
+            w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3)
+            y = torch.einsum("nhqk,nhck->nhcq", w, v)
+            y = self.attn_proj(y.reshape(*x.shape))
+            x = mp_sum(x, y, t=self.attn_balance)
+
+        if self.clip_act is not None:
+            x = x.clip_(-self.clip_act, self.clip_act)
+        return x
+
+
+class EDM2UNet(torch.nn.Module):
+ """Encoder/decoder network assembled from ``MPConv`` and ``Block`` modules.
+
+ The network is conditioned on a noise embedding and, when ``label_dim`` is
+ non-zero, on a class-label embedding. Module names inside ``enc`` and ``dec``
+ are derived from the resolution of each level.
+ """
+
+ def __init__(
+ self,
+ img_resolution: int,
+ img_channels: int,
+ label_dim: int,
+ model_channels: int = 192,
+ channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
+ channel_mult_noise: Optional[int] = None,
+ channel_mult_emb: Optional[int] = None,
+ num_blocks: int = 3,
+ attn_resolutions: Tuple[int, ...] = (16, 8),
+ label_balance: float = 0.5,
+ concat_balance: float = 0.5,
+ **block_kwargs,
+ ):
+ super().__init__()
+ # Channel count per level, and widths of the noise / conditioning embeddings
+ # (falling back to the first level width and the widest level respectively).
+ cblock = [model_channels * x for x in channel_mult]
+ cnoise = model_channels * channel_mult_noise if channel_mult_noise is not None else cblock[0]
+ cemb = model_channels * channel_mult_emb if channel_mult_emb is not None else max(cblock)
+ self.label_balance = label_balance
+ self.concat_balance = concat_balance
+ self.out_gain = torch.nn.Parameter(torch.zeros([]))
+
+ self.emb_fourier = MPFourier(cnoise)
+ self.emb_noise = MPConv(cnoise, cemb, kernel=())
+ self.emb_label = MPConv(label_dim, cemb, kernel=()) if label_dim else None
+
+ self.enc = torch.nn.ModuleDict()
+ # One extra input channel: forward() appends a constant ones channel to x.
+ cout = img_channels + 1
+ for level, channels in enumerate(cblock):
+ # Resolution halves at every level.
+ res = img_resolution >> level
+ if level == 0:
+ cin = cout
+ cout = channels
+ self.enc[f"{res}x{res}_conv"] = MPConv(cin, cout, kernel=(3, 3))
+ else:
+ self.enc[f"{res}x{res}_down"] = Block(cout, cout, cemb, flavor="enc", resample_mode="down", **block_kwargs)
+ for idx in range(num_blocks):
+ cin = cout
+ cout = channels
+ self.enc[f"{res}x{res}_block{idx}"] = Block(
+ cin,
+ cout,
+ cemb,
+ flavor="enc",
+ attention=(res in attn_resolutions),
+ **block_kwargs,
+ )
+
+ self.dec = torch.nn.ModuleDict()
+ # Output widths of every encoder module, popped in reverse by the decoder blocks.
+ skips = [block.out_channels for block in self.enc.values()]
+ for level, channels in reversed(list(enumerate(cblock))):
+ res = img_resolution >> level
+ if level == len(cblock) - 1:
+ self.dec[f"{res}x{res}_in0"] = Block(cout, cout, cemb, flavor="dec", attention=True, **block_kwargs)
+ self.dec[f"{res}x{res}_in1"] = Block(cout, cout, cemb, flavor="dec", **block_kwargs)
+ else:
+ self.dec[f"{res}x{res}_up"] = Block(cout, cout, cemb, flavor="dec", resample_mode="up", **block_kwargs)
+ for idx in range(num_blocks + 1):
+ # Each decoder block receives the running features concatenated with one skip.
+ cin = cout + skips.pop()
+ cout = channels
+ self.dec[f"{res}x{res}_block{idx}"] = Block(
+ cin,
+ cout,
+ cemb,
+ flavor="dec",
+ attention=(res in attn_resolutions),
+ **block_kwargs,
+ )
+
+ self.out_conv = MPConv(cout, img_channels, kernel=(3, 3))
+
+ def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor:
+ """Run the network.
+
+ Args:
+ x: Input features.
+ noise_labels: Per-sample noise conditioning values fed to ``emb_fourier``.
+ class_labels: Class conditioning; indexed as ``shape[1]``, so a 2D tensor is
+ expected. Required when the model has a label embedding.
+
+ Raises:
+ ValueError: If the model has a label embedding and ``class_labels`` is ``None``.
+ """
+ emb = self.emb_noise(self.emb_fourier(noise_labels))
+ if self.emb_label is not None:
+ if class_labels is None:
+ raise ValueError("class_labels are required for conditional EDM2UNet.")
+ # Labels are scaled by sqrt(label width) before the linear embedding.
+ emb = mp_sum(emb, self.emb_label(class_labels * math.sqrt(class_labels.shape[1])), t=self.label_balance)
+ emb = mp_silu(emb)
+
+ # Append a constant ones channel (matches the ``img_channels + 1`` stem input).
+ x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1)
+ skips = []
+ for name, block in self.enc.items():
+ # The stem MPConv takes no embedding; Block modules do.
+ x = block(x) if "conv" in name else block(x, emb)
+ skips.append(x)
+
+ for name, block in self.dec.items():
+ # Only the "*_block*" entries consume an encoder skip.
+ if "block" in name:
+ x = mp_cat(x, skips.pop(), t=self.concat_balance)
+ x = block(x, emb)
+ return self.out_conv(x, gain=self.out_gain)
+
+
+@dataclass
+class EDM2UNet2DOutput(BaseOutput):
+    """Container returned by ``EDM2UNet2DModel.forward`` when ``return_dict`` is true."""
+
+    # Preconditioned network output.
+    sample: torch.Tensor
+    # Output of the log-variance head; ``None`` unless it was requested.
+    logvar: Optional[torch.Tensor] = None
+
+
+
+# Constructor arguments of ``EDM2UNet2DModel`` that are read from and written to
+# ``config.json`` by ``from_pretrained`` / ``save_pretrained``.
+_CONFIG_KEYS = (
+    # Wrapper-level settings.
+    "sample_size",
+    "in_channels",
+    "out_channels",
+    "num_class_embeds",
+    "use_fp16",
+    "sigma_data",
+    "logvar_channels",
+    # Network layout.
+    "model_channels",
+    "channel_mult",
+    "channel_mult_noise",
+    "channel_mult_emb",
+    "num_blocks",
+    "attn_resolutions",
+    "label_balance",
+    "concat_balance",
+    # Per-block settings.
+    "dropout",
+    "channels_per_head",
+    "res_balance",
+    "attn_balance",
+    "clip_act",
+)
+
+
+class EDM2UNet2DModel(ModelMixin, ConfigMixin):
+    """Wrapper that applies sigma-dependent preconditioning around ``EDM2UNet``.
+
+    It also owns a small log-variance head (``logvar_fourier`` + ``logvar_linear``)
+    and provides local-directory ``from_pretrained`` / ``save_pretrained``.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        sample_size: int = 64,
+        in_channels: int = 4,
+        out_channels: int = 4,
+        num_class_embeds: int = 0,
+        use_fp16: bool = True,
+        sigma_data: float = 0.5,
+        logvar_channels: int = 128,
+        model_channels: int = 192,
+        channel_mult: Tuple[int, ...] = (1, 2, 3, 4),
+        channel_mult_noise: Optional[int] = None,
+        channel_mult_emb: Optional[int] = None,
+        num_blocks: int = 3,
+        attn_resolutions: Tuple[int, ...] = (16, 8),
+        label_balance: float = 0.5,
+        concat_balance: float = 0.5,
+        dropout: float = 0.0,
+        channels_per_head: int = 64,
+        res_balance: float = 0.3,
+        attn_balance: float = 0.3,
+        clip_act: Optional[float] = 256,
+    ):
+        super().__init__()
+        self.sample_size = sample_size
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.num_class_embeds = num_class_embeds
+        self.use_fp16 = use_fp16
+        self.sigma_data = sigma_data
+        # Kept on the instance so save_pretrained can still write it when no
+        # config was registered (its attribute fallback needs every key).
+        self.logvar_channels = logvar_channels
+        self.model_channels = model_channels
+        self.channel_mult = channel_mult
+        self.channel_mult_noise = channel_mult_noise
+        self.channel_mult_emb = channel_mult_emb
+        self.num_blocks = num_blocks
+        self.attn_resolutions = attn_resolutions
+        self.label_balance = label_balance
+        self.concat_balance = concat_balance
+        self.dropout = dropout
+        self.channels_per_head = channels_per_head
+        self.res_balance = res_balance
+        self.attn_balance = attn_balance
+        self.clip_act = clip_act
+        self.unet = EDM2UNet(
+            img_resolution=sample_size,
+            img_channels=in_channels,
+            label_dim=num_class_embeds,
+            model_channels=model_channels,
+            channel_mult=channel_mult,
+            channel_mult_noise=channel_mult_noise,
+            channel_mult_emb=channel_mult_emb,
+            num_blocks=num_blocks,
+            attn_resolutions=attn_resolutions,
+            label_balance=label_balance,
+            concat_balance=concat_balance,
+            dropout=dropout,
+            channels_per_head=channels_per_head,
+            res_balance=res_balance,
+            attn_balance=attn_balance,
+            clip_act=clip_act,
+        )
+        self.logvar_fourier = MPFourier(logvar_channels)
+        self.logvar_linear = MPConv(logvar_channels, 1, kernel=())
+
+    def forward(
+        self,
+        sample: torch.Tensor,
+        sigma: torch.Tensor,
+        class_labels: Optional[torch.Tensor] = None,
+        force_fp32: bool = False,
+        return_logvar: bool = False,
+        return_dict: bool = True,
+    ) -> EDM2UNet2DOutput:
+        """Compute the preconditioned network output for ``sample`` at noise level ``sigma``.
+
+        Args:
+            sample: Input tensor; cast to float32.
+            sigma: Noise level(s); cast to float32 and reshaped to ``(-1, 1, 1, 1)``.
+            class_labels: Class conditioning, reshaped to
+                ``(-1, num_class_embeds)``. Ignored when ``num_class_embeds`` is 0;
+                replaced by zeros when ``None`` on a conditional model.
+            force_fp32: Disable the float16 path even if ``use_fp16`` is set.
+            return_logvar: Also evaluate the log-variance head.
+            return_dict: Return an ``EDM2UNet2DOutput`` instead of a tuple.
+
+        Returns:
+            ``EDM2UNet2DOutput(sample, logvar)``, or the tuple ``(sample, logvar)``
+            when ``return_dict`` is false. ``logvar`` is ``None`` unless requested.
+        """
+        x = sample.to(torch.float32)
+        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
+        if self.num_class_embeds == 0:
+            class_labels = None
+        else:
+            if class_labels is None:
+                class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device)
+            class_labels = class_labels.to(torch.float32).reshape(-1, self.num_class_embeds)
+        # float16 is only used on CUDA devices and only when not overridden.
+        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") else torch.float32
+
+        # Preconditioning coefficients, all functions of sigma and sigma_data.
+        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
+        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
+        c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
+        c_noise = sigma.flatten().log() / 4
+
+        x_in = (c_in * x).to(dtype)
+        f_x = self.unet(x_in, c_noise, class_labels)
+        d_x = c_skip * x + c_out * f_x.to(torch.float32)
+
+        logvar = None
+        if return_logvar:
+            logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)
+
+        if not return_dict:
+            return (d_x, logvar)
+        return EDM2UNet2DOutput(sample=d_x, logvar=logvar)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs):
+        """Build the model from a local directory holding ``config.json`` and weights.
+
+        Args:
+            pretrained_model_name_or_path: Directory to load from.
+            torch_dtype: If given, the loaded model is cast to this dtype.
+            **kwargs: Only ``subfolder`` is read; it is joined onto the directory.
+
+        Returns:
+            The model with weights loaded strictly from
+            ``diffusion_pytorch_model.safetensors`` or, if that file is absent,
+            ``diffusion_pytorch_model.bin``.
+        """
+        subfolder = kwargs.pop("subfolder", None)
+        model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path
+        with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f:
+            config = json.load(f)
+        # Keys such as "_class_name" are not constructor arguments and are dropped.
+        init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS}
+        model = cls(**init_kwargs)
+        weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors")
+        if os.path.isfile(weight_file):
+            from safetensors.torch import load_file
+
+            state_dict = load_file(weight_file)
+        else:
+            # NOTE: .bin checkpoints are pickle files; weights_only=True restricts
+            # unpickling to tensors and plain containers.
+            state_dict = torch.load(
+                os.path.join(model_dir, "diffusion_pytorch_model.bin"),
+                map_location="cpu",
+                weights_only=True,
+            )
+        model.load_state_dict(state_dict, strict=True)
+        if torch_dtype is not None:
+            model = model.to(dtype=torch_dtype)
+        return model
+
+    def save_pretrained(self, save_directory: str, safe_serialization: bool = True):
+        """Write ``config.json`` and the model weights into ``save_directory``.
+
+        Args:
+            save_directory: Target directory; created if missing.
+            safe_serialization: Write ``diffusion_pytorch_model.safetensors`` when
+                true, otherwise ``diffusion_pytorch_model.bin`` via ``torch.save``.
+        """
+        os.makedirs(save_directory, exist_ok=True)
+        stored = dict(getattr(self, "config", {}))
+        config = {"_class_name": self.__class__.__name__}
+        for key in _CONFIG_KEYS:
+            # Prefer the registered config; fall back to the instance attribute.
+            if key in stored:
+                config[key] = stored[key]
+            elif hasattr(self, key):
+                config[key] = getattr(self, key)
+        with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f:
+            json.dump(config, f, indent=2, sort_keys=True)
+            f.write("\n")
+        state_dict = self.state_dict()
+        if safe_serialization:
+            from safetensors.torch import save_file
+
+            save_file(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.safetensors"))
+        else:
+            torch.save(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.bin"))
diff --git a/edm2-img512-xxl-fid/vae/config.json b/edm2-img512-xxl-fid/vae/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..ae2ab37de2acb8bc348d3de809ae8385324be962
--- /dev/null
+++ b/edm2-img512-xxl-fid/vae/config.json
@@ -0,0 +1,38 @@
+{
+ "_class_name": "AutoencoderKL",
+ "_diffusers_version": "0.36.0",
+ "_name_or_path": "stabilityai/sd-vae-ft-mse",
+ "act_fn": "silu",
+ "block_out_channels": [
+ 128,
+ 256,
+ 512,
+ 512
+ ],
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D"
+ ],
+ "force_upcast": true,
+ "in_channels": 3,
+ "latent_channels": 4,
+ "latents_mean": null,
+ "latents_std": null,
+ "layers_per_block": 2,
+ "mid_block_add_attention": true,
+ "norm_num_groups": 32,
+ "out_channels": 3,
+ "sample_size": 256,
+ "scaling_factor": 0.18215,
+ "shift_factor": null,
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D"
+ ],
+ "use_post_quant_conv": true,
+ "use_quant_conv": true
+}
diff --git a/edm2-img512-xxl-fid/vae/diffusion_pytorch_model.safetensors b/edm2-img512-xxl-fid/vae/diffusion_pytorch_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..90464d67ac7303d0ee4696334df13da130a948ea
--- /dev/null
+++ b/edm2-img512-xxl-fid/vae/diffusion_pytorch_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815
+size 334643276