diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..9926f7e8b88c5be1aec0dbb4af4bae3f0ed5d623 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +edm2-img512-l-dino/demo.png filter=lfs diff=lfs merge=lfs -text +edm2-img512-l-fid/generator_test.png filter=lfs diff=lfs merge=lfs -text +edm2-img512-m-fid/demo.png filter=lfs diff=lfs merge=lfs -text +edm2-img512-s-fid/demo.png filter=lfs diff=lfs merge=lfs -text +edm2-img512-xl-fid/demo.png filter=lfs diff=lfs merge=lfs -text +edm2-img512-xs-fid/demo.png filter=lfs diff=lfs merge=lfs -text +edm2-img512-xxl-fid/demo.png filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..aed779e2b11ced8045fe03219eddd7a494c0e89f --- /dev/null +++ b/README.md @@ -0,0 +1,199 @@ +--- +license: cc-by-nc-sa-4.0 +library_name: diffusers +pipeline_tag: unconditional-image-generation +tags: + - diffusers + - edm2 + - image-generation + - class-conditional + - imagenet +inference: true +widget: + - output: + url: edm2-img512-xxl-fid/demo.png +language: + - en +--- + +# EDM2-diffusers + +Diffusers-ready checkpoints for **EDM2** ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)), +converted from [NVlabs/edm2](https://github.com/NVlabs/edm2) post-hoc reconstructions. + +Official source weights: `https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions/` + +This root folder is a model collection that contains: + +- `edm2-img512-xs-fid` +- `edm2-img512-s-fid` +- `edm2-img512-m-fid` +- `edm2-img512-l-fid` +- `edm2-img512-l-dino` +- `edm2-img512-xl-fid` +- `edm2-img512-xxl-fid` + +Each subfolder is a self-contained Diffusers model repo with: + +- `pipeline.py` +- `unet/unet_edm2.py` +- `scheduler/scheduler_config.json` (`EDMEulerScheduler`) +- `unet/diffusion_pytorch_model.safetensors` +- `vae/diffusion_pytorch_model.safetensors` + +## Demo + +![edm2-img512-xxl-fid demo](edm2-img512-xxl-fid/demo.png) + +Class-conditional sample (ImageNet class **207**, golden retriever), EDM2-XXL at 512×512, 32 steps, guidance 1.0, seed 42. + +## Model Paths + +Use paths relative to this root README: + +| Model | NVlabs preset | FID | Local path | +| --- | --- | ---: | --- | +| EDM2-XS | `edm2-img512-xs-fid` | 3.53 | `./edm2-img512-xs-fid` | +| EDM2-S | `edm2-img512-s-fid` | 2.56 | `./edm2-img512-s-fid` | +| EDM2-M | `edm2-img512-m-fid` | 2.25 | `./edm2-img512-m-fid` | +| EDM2-L | `edm2-img512-l-fid` | 2.06 | `./edm2-img512-l-fid` | +| EDM2-L (DINO) | `edm2-img512-l-dino` | — | `./edm2-img512-l-dino` | +| EDM2-XL | `edm2-img512-xl-fid` | 1.96 | `./edm2-img512-xl-fid` | +| EDM2-XXL | `edm2-img512-xxl-fid` | 1.91 | `./edm2-img512-xxl-fid` | + +## Inference Demo (Diffusers) + +### 1) Load a local subfolder checkpoint + +```python +from pathlib import Path +import torch +from diffusers import DiffusionPipeline + +model_dir = Path("./edm2-img512-xxl-fid") # change to any path in the table above +pipe = DiffusionPipeline.from_pretrained( + str(model_dir), + local_files_only=True, + trust_remote_code=True, + torch_dtype=torch.bfloat16, +).to("cuda") + +generator = torch.Generator(device="cuda").manual_seed(42) +image = pipe( + class_labels=207, # golden retriever (ImageNet id); omit for random class + num_inference_steps=32, + guidance_scale=1.0, # >1.0 requires a gnet/ checkpoint + generator=generator, +).images[0] +image.save("demo.png") +``` + +Official inference defaults (`generate_images.py`): `num_steps=32`, `sigma_min=0.002`, +`sigma_max=80`, `rho=7`, `guidance=1.0` (no gnet), `S_churn=0`. Heun sampling runs in +float32 internally even when UNet/VAE weights are loaded in bf16/fp16. + +Guided presets require a converted `gnet/` folder and `guidance_scale` matching the +NVlabs preset. + +### 2) Convert a legacy `.pkl` + +```bash +python scripts/convert_edm2_to_diffusers.py \ + --checkpoint models/BiliSakura/EDM2-diffusers/edm2-img512-xs-2147483-0.135.pkl \ + --output models/BiliSakura/EDM2-diffusers +``` + +Creates `edm2-img512-xs-fid/` automatically from the NVlabs preset mapping. + +## Checkpoint preset mapping + +Maps NVlabs `--preset=...` names from [`generate_images.py`](https://github.com/NVlabs/edm2/blob/main/generate_images.py) +to source pickle filenames and local Diffusers directories. + +### EDM2 paper — ImageNet-512 (conditional) + +| NVlabs preset | Source `.pkl` (net) | Diffusers dir | Metric | +| --- | --- | --- | --- | +| `edm2-img512-xs-fid` | `edm2-img512-xs-2147483-0.135.pkl` | `edm2-img512-xs-fid/` | FID 3.53 | +| `edm2-img512-xs-dino` | `edm2-img512-xs-2147483-0.200.pkl` | — | FDDINOv2 103.39 | +| `edm2-img512-s-fid` | `edm2-img512-s-2147483-0.130.pkl` | `edm2-img512-s-fid/` | FID 2.56 | +| `edm2-img512-s-dino` | `edm2-img512-s-2147483-0.190.pkl` | — | FDDINOv2 68.64 | +| `edm2-img512-m-fid` | `edm2-img512-m-2147483-0.100.pkl` | `edm2-img512-m-fid/` | FID 2.25 | +| `edm2-img512-m-dino` | `edm2-img512-m-2147483-0.155.pkl` | — | FDDINOv2 58.44 | +| `edm2-img512-l-fid` | `edm2-img512-l-1879048-0.085.pkl` | `edm2-img512-l-fid/` | FID 2.06 | +| `edm2-img512-l-dino` | `edm2-img512-l-1879048-0.155.pkl` | `edm2-img512-l-dino/` | FDDINOv2 52.25 | +| `edm2-img512-xl-fid` | `edm2-img512-xl-1342177-0.085.pkl` | `edm2-img512-xl-fid/` | FID 1.96 | +| `edm2-img512-xl-dino` | `edm2-img512-xl-1342177-0.155.pkl` | — | FDDINOv2 45.96 | +| `edm2-img512-xxl-fid` | `edm2-img512-xxl-0939524-0.070.pkl` | `edm2-img512-xxl-fid/` | FID 1.91 | +| `edm2-img512-xxl-dino` | `edm2-img512-xxl-0939524-0.150.pkl` | — | FDDINOv2 42.84 | + +### EDM2 paper — ImageNet-64 (conditional) + +| NVlabs preset | Source `.pkl` (net) | Metric | +| --- | --- | --- | +| `edm2-img64-s-fid` | `edm2-img64-s-1073741-0.075.pkl` | FID 1.58 | +| `edm2-img64-m-fid` | `edm2-img64-m-2147483-0.060.pkl` | FID 1.43 | +| `edm2-img64-l-fid` | `edm2-img64-l-1073741-0.040.pkl` | FID 1.33 | +| `edm2-img64-xl-fid` | `edm2-img64-xl-0671088-0.040.pkl` | FID 1.33 | + +### EDM2 paper — classifier-free guidance (ImageNet-512) + +Use `guidance_scale` below and include the converted `gnet/` checkpoint. + +| NVlabs preset | Source `.pkl` (net) | Source `.pkl` (gnet) | Guidance | Metric | +| --- | --- | --- | ---: | --- | +| `edm2-img512-xs-guid-fid` | `edm2-img512-xs-2147483-0.045.pkl` | `edm2-img512-xs-uncond-2147483-0.045.pkl` | 1.40 | FID 2.91 | +| `edm2-img512-xs-guid-dino` | `edm2-img512-xs-2147483-0.150.pkl` | `edm2-img512-xs-uncond-2147483-0.150.pkl` | 1.70 | FDDINOv2 79.94 | +| `edm2-img512-s-guid-fid` | `edm2-img512-s-2147483-0.025.pkl` | `edm2-img512-xs-uncond-2147483-0.025.pkl` | 1.40 | FID 2.23 | +| `edm2-img512-s-guid-dino` | `edm2-img512-s-2147483-0.085.pkl` | `edm2-img512-xs-uncond-2147483-0.085.pkl` | 1.90 | FDDINOv2 52.32 | +| `edm2-img512-m-guid-fid` | `edm2-img512-m-2147483-0.030.pkl` | `edm2-img512-xs-uncond-2147483-0.030.pkl` | 1.20 | FID 2.01 | +| `edm2-img512-m-guid-dino` | `edm2-img512-m-2147483-0.015.pkl` | `edm2-img512-xs-uncond-2147483-0.015.pkl` | 2.00 | FDDINOv2 41.98 | +| `edm2-img512-l-guid-fid` | `edm2-img512-l-1879048-0.015.pkl` | `edm2-img512-xs-uncond-2147483-0.015.pkl` | 1.20 | FID 1.88 | +| `edm2-img512-l-guid-dino` | `edm2-img512-l-1879048-0.035.pkl` | `edm2-img512-xs-uncond-2147483-0.035.pkl` | 1.70 | FDDINOv2 38.20 | +| `edm2-img512-xl-guid-fid` | `edm2-img512-xl-1342177-0.020.pkl` | `edm2-img512-xs-uncond-2147483-0.020.pkl` | 1.20 | FID 1.85 | +| `edm2-img512-xl-guid-dino` | `edm2-img512-xl-1342177-0.030.pkl` | `edm2-img512-xs-uncond-2147483-0.030.pkl` | 1.70 | FDDINOv2 35.67 | +| `edm2-img512-xxl-guid-fid` | `edm2-img512-xxl-0939524-0.015.pkl` | `edm2-img512-xs-uncond-2147483-0.015.pkl` | 1.20 | FID 1.81 | +| `edm2-img512-xxl-guid-dino` | `edm2-img512-xxl-0939524-0.015.pkl` | `edm2-img512-xs-uncond-2147483-0.015.pkl` | 1.70 | FDDINOv2 33.09 | + +### Autoguidance paper + +| NVlabs preset | Source `.pkl` (net) | Source `.pkl` (gnet) | Guidance | Metric | +| --- | --- | --- | ---: | --- | +| `edm2-img512-s-autog-fid` | `edm2-img512-s-2147483-0.070.pkl` | `edm2-img512-xs-0134217-0.125.pkl` | 2.10 | FID 1.34 | +| `edm2-img512-s-autog-dino` | `edm2-img512-s-2147483-0.120.pkl` | `edm2-img512-xs-0134217-0.165.pkl` | 2.45 | FDDINOv2 36.67 | +| `edm2-img512-xxl-autog-fid` | `edm2-img512-xxl-0939524-0.075.pkl` | `edm2-img512-m-0268435-0.155.pkl` | 2.05 | FID 1.25 | +| `edm2-img512-xxl-autog-dino` | `edm2-img512-xxl-0939524-0.130.pkl` | `edm2-img512-m-0268435-0.205.pkl` | 2.30 | FDDINOv2 24.18 | +| `edm2-img512-s-uncond-autog-fid` | `edm2-img512-s-uncond-2147483-0.070.pkl` | `edm2-img512-xs-uncond-0134217-0.110.pkl` | 2.85 | FID 3.86 | +| `edm2-img512-s-uncond-autog-dino` | `edm2-img512-s-uncond-2147483-0.090.pkl` | `edm2-img512-xs-uncond-0134217-0.125.pkl` | 2.90 | FDDINOv2 90.39 | +| `edm2-img64-s-autog-fid` | `edm2-img64-s-1073741-0.045.pkl` | `edm2-img64-xs-0134217-0.110.pkl` | 1.70 | FID 1.01 | +| `edm2-img64-s-autog-dino` | `edm2-img64-s-1073741-0.105.pkl` | `edm2-img64-xs-0134217-0.175.pkl` | 2.20 | FDDINOv2 31.85 | + +### NVlabs preset shorthand + +```text +# EDM2 paper +edm2-img512-{xs|s|m|l|xl|xxl}-{fid|dino} +edm2-img64-{s|m|l|xl}-fid +edm2-img512-{xs|s|m|l|xl|xxl}-guid-{fid|dino} + +# Autoguidance paper +edm2-img512-{s|xxl}-autog-{fid|dino} +edm2-img512-s-uncond-autog-{fid|dino} +edm2-img64-s-autog-{fid|dino} +``` + +Example NVlabs command: + +```bash +python generate_images.py --preset=edm2-img512-s-guid-dino --outdir=out +``` + +Equivalent expanded form: + +```bash +python generate_images.py \ + --net=https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions/edm2-img512-s-2147483-0.085.pkl \ + --gnet=https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions/edm2-img512-xs-uncond-2147483-0.085.pkl \ + --guidance=1.9 \ + --outdir=out +``` diff --git a/edm2-img512-l-dino/demo.png b/edm2-img512-l-dino/demo.png new file mode 100644 index 0000000000000000000000000000000000000000..60c5c3da1e339e32d3ea1c0fab558f809783d63d --- /dev/null +++ b/edm2-img512-l-dino/demo.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:12a2dab2ca0e5ec5a6eebe9f7c10b440232622055866192ecc5c8b3dc289db4d +size 389147 diff --git a/edm2-img512-l-dino/model_index.json b/edm2-img512-l-dino/model_index.json new file mode 100644 index 0000000000000000000000000000000000000000..24c0305eea3ec64abbc2f870476afa4f5e462f42 --- /dev/null +++ b/edm2-img512-l-dino/model_index.json @@ -0,0 +1,19 @@ +{ + "_class_name": [ + "pipeline", + "EDM2Pipeline" + ], + "_diffusers_version": "0.31.0", + "scheduler": [ + "diffusers", + "EDMEulerScheduler" + ], + "unet": [ + "unet_edm2", + "EDM2UNet2DModel" + ], + "vae": [ + "diffusers", + "AutoencoderKL" + ] +} diff --git a/edm2-img512-l-dino/pipeline.py b/edm2-img512-l-dino/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..6b4131215d07770c74efa0674d00f3d6423d9c06 --- /dev/null +++ b/edm2-img512-l-dino/pipeline.py @@ -0,0 +1,406 @@ +"""Hub custom pipeline: EDM2Pipeline. +Load with native Hugging Face diffusers and trust_remote_code=True. +""" + +from __future__ import annotations + +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from pathlib import Path +from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch + +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers.utils import replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from pathlib import Path + >>> import torch + >>> from diffusers import DiffusionPipeline + + >>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve() + >>> pipe = DiffusionPipeline.from_pretrained( + ... str(model_dir), + ... local_files_only=True, + ... custom_pipeline=str(model_dir / "pipeline.py"), + ... trust_remote_code=True, + ... torch_dtype=torch.float32, + ... ) + >>> pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(42) + >>> image = pipe( + ... class_labels=207, + ... num_inference_steps=32, + ... guidance_scale=1.0, + ... generator=generator, + ... ).images[0] + >>> image.save("demo.png") + ``` +""" + +# Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py). +_STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28]) +_STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE + +class EDM2Pipeline(DiffusionPipeline): + r""" + Pipeline for class-conditional image generation with EDM2 + ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)). + + Parameters: + unet ([`EDM2UNet2DModel`]): + Main magnitude-preserving U-Net with EDM preconditioning. + scheduler ([`EDMEulerScheduler`]): + Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in + the pipeline because the UNet returns denoised latents rather than noise predictions. + vae ([`AutoencoderKL`], *optional*): + Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`. + gnet ([`EDM2UNet2DModel`], *optional*): + Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`). + id2label (`dict[int, str]`, *optional*): + ImageNet class id to English label mapping. + """ + + model_cpu_offload_seq = "unet->gnet->vae" + _optional_components = ["vae", "gnet"] + + def __init__( + self, + unet, + scheduler, + vae=None, + gnet=None, + id2label: Optional[Dict[Union[int, str], str]] = None, + ) -> None: + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet) + self._id2label = self._normalize_id2label(id2label) + self.labels = self._build_label2id(self._id2label) + self._labels_loaded_from_model_index = bool(self._id2label) + self.vae_scale_factor = 8 if self.vae is not None else 1 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) + + @staticmethod + def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]: + if not id2label: + return {} + return {int(key): value for key, value in id2label.items()} + + @staticmethod + def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]: + label2id: Dict[str, int] = {} + for class_id, value in id2label.items(): + for synonym in value.split(","): + synonym = synonym.strip() + if synonym: + label2id[synonym] = int(class_id) + return dict(sorted(label2id.items())) + + def _ensure_labels_loaded(self) -> None: + if self._labels_loaded_from_model_index: + return + loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None)) + if loaded: + self._id2label = loaded + self.labels = self._build_label2id(self._id2label) + self._labels_loaded_from_model_index = True + + @staticmethod + def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]: + if not variant_path: + return {} + model_index_path = Path(variant_path).resolve() / "model_index.json" + if not model_index_path.is_file(): + return {} + raw = json.loads(model_index_path.read_text(encoding="utf-8")) + id2label = raw.get("id2label") + if not isinstance(id2label, dict): + return {} + return {int(key): value for key, value in id2label.items()} + + @property + def id2label(self) -> Dict[int, str]: + r"""ImageNet class id to English label string (comma-separated synonyms).""" + self._ensure_labels_loaded() + return self._id2label + + def get_label_ids(self, label: Union[str, List[str]]) -> List[int]: + r""" + Map ImageNet label strings to class ids. + + Args: + label (`str` or `list[str]`): + One or more English label strings that match entries in `id2label`. + """ + self._ensure_labels_loaded() + if not self.labels: + raise ValueError("No English labels loaded. Add `id2label` to model_index.json.") + labels = [label] if isinstance(label, str) else list(label) + missing = [item for item in labels if item not in self.labels] + if missing: + preview = ", ".join(list(self.labels.keys())[:8]) + raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...") + return [self.labels[item] for item in labels] + + def _default_image_size(self) -> int: + latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64))) + return latent_size * self.vae_scale_factor + + def check_inputs( + self, + height: int, + width: int, + num_inference_steps: int, + guidance_scale: float, + output_type: str, + ) -> None: + if num_inference_steps < 1: + raise ValueError("num_inference_steps must be >= 1.") + if guidance_scale < 1.0: + raise ValueError("guidance_scale must be >= 1.0.") + if guidance_scale > 1.0 and self.gnet is None: + raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).") + if output_type not in {"pil", "np", "pt", "latent"}: + raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.") + + native_size = self._default_image_size() + if height != native_size or width != native_size: + raise ValueError( + f"EDM2 expects native resolution height=width={native_size}. " + f"Got height={height}, width={width}." + ) + + def _normalize_class_labels( + self, + class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]], + batch_size: int, + device: torch.device, + ) -> Optional[torch.Tensor]: + label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0))) + if label_dim == 0: + return None + if class_labels is None: + indices = torch.randint(label_dim, size=(batch_size,), device=device) + return torch.eye(label_dim, device=device, dtype=torch.float32)[indices] + + if isinstance(class_labels, str): + class_labels = self.get_label_ids(class_labels)[0] + elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str): + class_labels = self.get_label_ids(list(class_labels)) + + if isinstance(class_labels, int): + indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long) + elif isinstance(class_labels, torch.Tensor): + if class_labels.ndim == 2: + labels = class_labels.to(device=device, dtype=torch.float32) + if labels.shape[0] != batch_size: + raise ValueError(f"class_labels batch must match batch_size={batch_size}.") + return labels + indices = class_labels.to(device=device, dtype=torch.long).flatten() + else: + indices = torch.tensor(list(class_labels), device=device, dtype=torch.long) + + if indices.numel() == 1 and batch_size > 1: + indices = indices.repeat(batch_size) + if indices.numel() != batch_size: + raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.") + return torch.eye(label_dim, device=device, dtype=torch.float32)[indices] + + def prepare_latents( + self, + batch_size: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]], + ) -> torch.Tensor: + in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4))) + latent_size = height // self.vae_scale_factor + return randn_tensor( + (batch_size, in_channels, latent_size, latent_size), + generator=generator, + device=device, + dtype=torch.float32, + ) + + def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"): + if output_type == "latent": + return latents + + in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 3))) + if self.vae is None: + image = (latents.to(torch.float32) * 127.5 + 128).clip(0, 255) / 255.0 + return self.image_processor.postprocess(image, output_type=output_type) + + if in_channels == 4: + x = latents.to(torch.float32) + scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1) + bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1) + x = (x - bias) / scale + else: + x = latents.to(torch.float32) + + vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype + image = self.vae.decode(x.to(dtype=vae_dtype)).sample.to(torch.float32).clamp(0, 1) + + return self.image_processor.postprocess(image, output_type=output_type) + + @staticmethod + def _apply_autoguidance( + main: torch.Tensor, + ref: torch.Tensor, + guidance_scale: float, + ) -> torch.Tensor: + return ref.lerp(main, guidance_scale) + + @staticmethod + def _sample_edm2_heun( + denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + noise: torch.Tensor, + sigmas: torch.Tensor, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + progress_bar: Optional[Callable[[Iterable], Iterable]] = None, + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0).""" + x_next = noise.to(dtype) * sigmas[0] + + sigma_pairs = list(zip(sigmas[:-1], sigmas[1:])) + if progress_bar is not None: + sigma_pairs = progress_bar(sigma_pairs) + + num_steps = len(sigma_pairs) + for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs): + x_hat, sigma_hat = x_next, sigma_cur + d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat + x_next = x_hat + (sigma_next - sigma_hat) * d_cur + if i < num_steps - 1: + d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next + x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime) + return x_next + + @torch.inference_mode() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None, + batch_size: int = 1, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 32, + guidance_scale: float = 1.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: str = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Generate class-conditional images with EDM2. + + Args: + class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*): + ImageNet class indices, English label strings, or one-hot float tensors. + Random classes are sampled when omitted on conditional models. + batch_size (`int`, defaults to `1`): + Number of images to generate. + height (`int`, *optional*): + Output height in pixels. Defaults to the pretrained native resolution. + width (`int`, *optional*): + Output width in pixels. Defaults to the pretrained native resolution. + num_inference_steps (`int`, defaults to `32`): + Number of EDM2 Heun steps (NVlabs default). + guidance_scale (`float`, defaults to `1.0`): + Autoguidance strength. Values above `1.0` blend the main net with `gnet` + via `gnet_output.lerp(unet_output, guidance_scale)`. + generator (`torch.Generator`, *optional*): + RNG for reproducibility. + output_type (`str`, defaults to `"pil"`): + `"pil"`, `"np"`, `"pt"`, or `"latent"`. + return_dict (`bool`, defaults to `True`): + Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if True. + + Examples: + + """ + default_size = self._default_image_size() + height = int(height or default_size) + width = int(width or default_size) + self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type) + + device = self._execution_device + dtype = self.unet.dtype + labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device) + noise = self.prepare_latents(batch_size, height, width, dtype, device, generator) + + def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + sigma_batch = sigma.reshape(1).expand(batch_size) + main = self.unet( + sample=x, + sigma=sigma_batch, + class_labels=labels, + force_fp32=True, + ).sample + if guidance_scale == 1.0 or self.gnet is None: + return main.to(torch.float32) + ref = self.gnet( + sample=x, + sigma=sigma_batch, + class_labels=labels, + force_fp32=True, + ).sample + return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + latents = self._sample_edm2_heun( + denoise_fn=denoise_fn, + noise=noise, + sigmas=self.scheduler.sigmas.to(device), + generator=generator, + progress_bar=self.progress_bar, + dtype=torch.float32, + ) + + image = self.decode_latents(latents, output_type=output_type) + if not return_dict: + return (image, latents) + return ImagePipelineOutput(images=image) + + @classmethod + def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None): + vae_dir = os.path.join(pretrained_model_name_or_path, "vae") + if os.path.isdir(vae_dir): + try: + + return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype) + except Exception: + return None + + vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt") + if os.path.isfile(vae_hint): + with open(vae_hint, "r", encoding="utf-8") as f: + hub_id = f.read().strip() + if hub_id: + + return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype) + return None diff --git a/edm2-img512-l-dino/scheduler/scheduler_config.json b/edm2-img512-l-dino/scheduler/scheduler_config.json new file mode 100644 index 0000000000000000000000000000000000000000..08cbde97955d06d66b0b72be217db84bb797d711 --- /dev/null +++ b/edm2-img512-l-dino/scheduler/scheduler_config.json @@ -0,0 +1,11 @@ +{ + "_class_name": "EDMEulerScheduler", + "final_sigmas_type": "zero", + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "rho": 7.0, + "sigma_data": 0.5, + "sigma_max": 80.0, + "sigma_min": 0.002, + "sigma_schedule": "karras" +} diff --git a/edm2-img512-l-dino/unet/config.json b/edm2-img512-l-dino/unet/config.json new file mode 100644 index 0000000000000000000000000000000000000000..939ef1d82a8da253ce5dc9b30d3ac644f49dc4ed --- /dev/null +++ b/edm2-img512-l-dino/unet/config.json @@ -0,0 +1,31 @@ +{ + "_class_name": "EDM2UNet2DModel", + "attn_balance": 0.3, + "attn_resolutions": [ + 16, + 8 + ], + "channel_mult": [ + 1, + 2, + 3, + 4 + ], + "channel_mult_emb": 4, + "channel_mult_noise": 1, + "channels_per_head": 64, + "clip_act": 256, + "concat_balance": 0.5, + "dropout": 0.0, + "in_channels": 4, + "label_balance": 0.5, + "logvar_channels": 128, + "model_channels": 320, + "num_blocks": 3, + "num_class_embeds": 1000, + "out_channels": 4, + "res_balance": 0.3, + "sample_size": 64, + "sigma_data": 0.5, + "use_fp16": true +} diff --git a/edm2-img512-l-dino/unet/diffusion_pytorch_model.safetensors b/edm2-img512-l-dino/unet/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..09f449c4dab91a9ddfd36f4204f4a3bc472f5208 --- /dev/null +++ b/edm2-img512-l-dino/unet/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f13f83377a74d74e1205843e241ce6d6e4bc9e49c2661944e49fdbe4d515ba33 +size 3110018564 diff --git a/edm2-img512-l-dino/unet/unet_edm2.py b/edm2-img512-l-dino/unet/unet_edm2.py new file mode 100644 index 0000000000000000000000000000000000000000..49d9451b443c4d6ad3b6ba2db19a7a804b07b6de --- /dev/null +++ b/edm2-img512-l-dino/unet/unet_edm2.py @@ -0,0 +1,434 @@ +import math +import json +import os +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import numpy as np +import torch + +try: + from diffusers.configuration_utils import ConfigMixin, register_to_config + from diffusers.models.modeling_utils import ModelMixin + from diffusers.utils import BaseOutput +except ImportError: # pragma: no cover + class ModelMixin(torch.nn.Module): + pass + + class ConfigMixin: + config = {} + + def register_to_config(self, **kwargs): + self.config = kwargs + + def register_to_config(func): + return func + + @dataclass + class BaseOutput: + pass + + +def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor: + if dim is None: + dim = list(range(1, x.ndim)) + norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) + norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel())) + return x / norm.to(x.dtype) + + +def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor: + if mode == "keep": + return x + filt = np.float32(f) + pad = (len(filt) - 1) // 2 + filt = filt / filt.sum() + filt = np.outer(filt, filt)[np.newaxis, np.newaxis, :, :] + filt = torch.as_tensor(filt, dtype=x.dtype, device=x.device) + c = x.shape[1] + if mode == "down": + return torch.nn.functional.conv2d(x, filt.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,)) + return torch.nn.functional.conv_transpose2d(x, (filt * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,)) + + +def mp_silu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.silu(x) / 0.596 + + +def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor: + return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2) + + +def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor: + na = a.shape[dim] + nb = b.shape[dim] + c = math.sqrt((na + nb) / ((1 - t) ** 2 + t ** 2)) + wa = c / math.sqrt(na) * (1 - t) + wb = c / math.sqrt(nb) * t + return torch.cat([wa * a, wb * b], dim=dim) + + +class MPFourier(torch.nn.Module): + def __init__(self, num_channels: int, bandwidth: float = 1): + super().__init__() + self.register_buffer("freqs", 2 * math.pi * torch.randn(num_channels) * bandwidth) + self.register_buffer("phases", 2 * math.pi * torch.rand(num_channels)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = x.to(torch.float32).ger(self.freqs.to(torch.float32)) + y = y + self.phases.to(torch.float32) + y = y.cos() * math.sqrt(2) + return y.to(x.dtype) + + +class MPConv(torch.nn.Module): + def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]): + super().__init__() + self.out_channels = out_channels + self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel)) + + def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor: + w = self.weight.to(torch.float32) + if self.training: + with torch.no_grad(): + self.weight.copy_(normalize(w)) + w = normalize(w) + w = w * (gain / math.sqrt(w[0].numel())) + w = w.to(x.dtype) + if w.ndim == 2: + return x @ w.t() + return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1] // 2,)) + + +class Block(torch.nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + emb_channels: int, + flavor: str = "enc", + resample_mode: str = "keep", + resample_filter: List[float] = [1, 1], + attention: bool = False, + channels_per_head: int = 64, + dropout: float = 0.0, + res_balance: float = 0.3, + attn_balance: float = 0.3, + clip_act: Optional[float] = 256, + ): + super().__init__() + self.out_channels = out_channels + self.flavor = flavor + self.resample_filter = resample_filter + self.resample_mode = resample_mode + self.num_heads = out_channels // channels_per_head if attention else 0 + self.dropout = dropout + self.res_balance = res_balance + self.attn_balance = attn_balance + self.clip_act = clip_act + self.emb_gain = torch.nn.Parameter(torch.zeros([])) + self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3)) + self.emb_linear = MPConv(emb_channels, out_channels, kernel=()) + self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3)) + self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None + self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None + self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + x = resample(x, f=self.resample_filter, mode=self.resample_mode) + if self.flavor == "enc": + if self.conv_skip is not None: + x = self.conv_skip(x) + x = normalize(x, dim=[1]) + + y = self.conv_res0(mp_silu(x)) + c = self.emb_linear(emb, gain=self.emb_gain) + 1 + y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype)) + if self.training and self.dropout: + y = torch.nn.functional.dropout(y, p=self.dropout) + y = self.conv_res1(y) + + if self.flavor == "dec" and self.conv_skip is not None: + x = self.conv_skip(x) + x = mp_sum(x, y, t=self.res_balance) + + if self.num_heads: + y = self.attn_qkv(x) + y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3]) + q, k, v = normalize(y, dim=[2]).unbind(3) + w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3) + y = torch.einsum("nhqk,nhck->nhcq", w, v) + y = self.attn_proj(y.reshape(*x.shape)) + x = mp_sum(x, y, t=self.attn_balance) + + if self.clip_act is not None: + x = x.clip_(-self.clip_act, self.clip_act) + return x + + +class EDM2UNet(torch.nn.Module): + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int, + model_channels: int = 192, + channel_mult: Tuple[int, ...] = (1, 2, 3, 4), + channel_mult_noise: Optional[int] = None, + channel_mult_emb: Optional[int] = None, + num_blocks: int = 3, + attn_resolutions: Tuple[int, ...] = (16, 8), + label_balance: float = 0.5, + concat_balance: float = 0.5, + **block_kwargs, + ): + super().__init__() + cblock = [model_channels * x for x in channel_mult] + cnoise = model_channels * channel_mult_noise if channel_mult_noise is not None else cblock[0] + cemb = model_channels * channel_mult_emb if channel_mult_emb is not None else max(cblock) + self.label_balance = label_balance + self.concat_balance = concat_balance + self.out_gain = torch.nn.Parameter(torch.zeros([])) + + self.emb_fourier = MPFourier(cnoise) + self.emb_noise = MPConv(cnoise, cemb, kernel=()) + self.emb_label = MPConv(label_dim, cemb, kernel=()) if label_dim else None + + self.enc = torch.nn.ModuleDict() + cout = img_channels + 1 + for level, channels in enumerate(cblock): + res = img_resolution >> level + if level == 0: + cin = cout + cout = channels + self.enc[f"{res}x{res}_conv"] = MPConv(cin, cout, kernel=(3, 3)) + else: + self.enc[f"{res}x{res}_down"] = Block(cout, cout, cemb, flavor="enc", resample_mode="down", **block_kwargs) + for idx in range(num_blocks): + cin = cout + cout = channels + self.enc[f"{res}x{res}_block{idx}"] = Block( + cin, + cout, + cemb, + flavor="enc", + attention=(res in attn_resolutions), + **block_kwargs, + ) + + self.dec = torch.nn.ModuleDict() + skips = [block.out_channels for block in self.enc.values()] + for level, channels in reversed(list(enumerate(cblock))): + res = img_resolution >> level + if level == len(cblock) - 1: + self.dec[f"{res}x{res}_in0"] = Block(cout, cout, cemb, flavor="dec", attention=True, **block_kwargs) + self.dec[f"{res}x{res}_in1"] = Block(cout, cout, cemb, flavor="dec", **block_kwargs) + else: + self.dec[f"{res}x{res}_up"] = Block(cout, cout, cemb, flavor="dec", resample_mode="up", **block_kwargs) + for idx in range(num_blocks + 1): + cin = cout + skips.pop() + cout = channels + self.dec[f"{res}x{res}_block{idx}"] = Block( + cin, + cout, + cemb, + flavor="dec", + attention=(res in attn_resolutions), + **block_kwargs, + ) + + self.out_conv = MPConv(cout, img_channels, kernel=(3, 3)) + + def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor: + emb = self.emb_noise(self.emb_fourier(noise_labels)) + if self.emb_label is not None: + if class_labels is None: + raise ValueError("class_labels are required for conditional EDM2UNet.") + emb = mp_sum(emb, self.emb_label(class_labels * math.sqrt(class_labels.shape[1])), t=self.label_balance) + emb = mp_silu(emb) + + x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1) + skips = [] + for name, block in self.enc.items(): + x = block(x) if "conv" in name else block(x, emb) + skips.append(x) + + for name, block in self.dec.items(): + if "block" in name: + x = mp_cat(x, skips.pop(), t=self.concat_balance) + x = block(x, emb) + return self.out_conv(x, gain=self.out_gain) + + +@dataclass +class EDM2UNet2DOutput(BaseOutput): + sample: torch.Tensor + logvar: Optional[torch.Tensor] = None + + + +_CONFIG_KEYS = ( + "sample_size", + "in_channels", + "out_channels", + "num_class_embeds", + "use_fp16", + "sigma_data", + "logvar_channels", + "model_channels", + "channel_mult", + "channel_mult_noise", + "channel_mult_emb", + "num_blocks", + "attn_resolutions", + "label_balance", + "concat_balance", + "dropout", + "channels_per_head", + "res_balance", + "attn_balance", + "clip_act", +) + + +class EDM2UNet2DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + sample_size: int = 64, + in_channels: int = 4, + out_channels: int = 4, + num_class_embeds: int = 0, + use_fp16: bool = True, + sigma_data: float = 0.5, + logvar_channels: int = 128, + model_channels: int = 192, + channel_mult: Tuple[int, ...] = (1, 2, 3, 4), + channel_mult_noise: Optional[int] = None, + channel_mult_emb: Optional[int] = None, + num_blocks: int = 3, + attn_resolutions: Tuple[int, ...] = (16, 8), + label_balance: float = 0.5, + concat_balance: float = 0.5, + dropout: float = 0.0, + channels_per_head: int = 64, + res_balance: float = 0.3, + attn_balance: float = 0.3, + clip_act: Optional[float] = 256, + ): + super().__init__() + self.sample_size = sample_size + self.in_channels = in_channels + self.out_channels = out_channels + self.num_class_embeds = num_class_embeds + self.use_fp16 = use_fp16 + self.sigma_data = sigma_data + self.model_channels = model_channels + self.channel_mult = channel_mult + self.channel_mult_noise = channel_mult_noise + self.channel_mult_emb = channel_mult_emb + self.num_blocks = num_blocks + self.attn_resolutions = attn_resolutions + self.label_balance = label_balance + self.concat_balance = concat_balance + self.dropout = dropout + self.channels_per_head = channels_per_head + self.res_balance = res_balance + self.attn_balance = attn_balance + self.clip_act = clip_act + self.unet = EDM2UNet( + img_resolution=sample_size, + img_channels=in_channels, + label_dim=num_class_embeds, + model_channels=model_channels, + channel_mult=channel_mult, + channel_mult_noise=channel_mult_noise, + channel_mult_emb=channel_mult_emb, + num_blocks=num_blocks, + attn_resolutions=attn_resolutions, + label_balance=label_balance, + concat_balance=concat_balance, + dropout=dropout, + channels_per_head=channels_per_head, + res_balance=res_balance, + attn_balance=attn_balance, + clip_act=clip_act, + ) + self.logvar_fourier = MPFourier(logvar_channels) + self.logvar_linear = MPConv(logvar_channels, 1, kernel=()) + + def forward( + self, + sample: torch.Tensor, + sigma: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + force_fp32: bool = False, + return_logvar: bool = False, + return_dict: bool = True, + ) -> EDM2UNet2DOutput: + x = sample.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + if self.num_class_embeds == 0: + class_labels = None + else: + if class_labels is None: + class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device) + class_labels = class_labels.to(torch.float32).reshape(-1, self.num_class_embeds) + dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") else torch.float32 + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.flatten().log() / 4 + + x_in = (c_in * x).to(dtype) + f_x = self.unet(x_in, c_noise, class_labels) + d_x = c_skip * x + c_out * f_x.to(torch.float32) + + logvar = None + if return_logvar: + logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1) + + if not return_dict: + return (d_x, logvar) + return EDM2UNet2DOutput(sample=d_x, logvar=logvar) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs): + subfolder = kwargs.pop("subfolder", None) + model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path + with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f: + config = json.load(f) + init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS} + model = cls(**init_kwargs) + weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors") + if os.path.isfile(weight_file): + from safetensors.torch import load_file + + state_dict = load_file(weight_file) + else: + state_dict = torch.load(os.path.join(model_dir, "diffusion_pytorch_model.bin"), map_location="cpu") + model.load_state_dict(state_dict, strict=True) + if torch_dtype is not None: + model = model.to(dtype=torch_dtype) + return model + + def save_pretrained(self, save_directory: str, safe_serialization: bool = True): + os.makedirs(save_directory, exist_ok=True) + stored = dict(getattr(self, "config", {})) + config = {"_class_name": self.__class__.__name__} + for key in _CONFIG_KEYS: + if key in stored: + config[key] = stored[key] + elif hasattr(self, key): + config[key] = getattr(self, key) + with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f: + json.dump(config, f, indent=2, sort_keys=True) + f.write("\n") + state_dict = self.state_dict() + if safe_serialization: + from safetensors.torch import save_file + + save_file(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.safetensors")) + else: + torch.save(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.bin")) diff --git a/edm2-img512-l-dino/vae/config.json b/edm2-img512-l-dino/vae/config.json new file mode 100644 index 0000000000000000000000000000000000000000..ae2ab37de2acb8bc348d3de809ae8385324be962 --- /dev/null +++ b/edm2-img512-l-dino/vae/config.json @@ -0,0 +1,38 @@ +{ + "_class_name": "AutoencoderKL", + "_diffusers_version": "0.36.0", + "_name_or_path": "stabilityai/sd-vae-ft-mse", + "act_fn": "silu", + "block_out_channels": [ + 128, + 256, + 512, + 512 + ], + "down_block_types": [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D" + ], + "force_upcast": true, + "in_channels": 3, + "latent_channels": 4, + "latents_mean": null, + "latents_std": null, + "layers_per_block": 2, + "mid_block_add_attention": true, + "norm_num_groups": 32, + "out_channels": 3, + "sample_size": 256, + "scaling_factor": 0.18215, + "shift_factor": null, + "up_block_types": [ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D" + ], + "use_post_quant_conv": true, + "use_quant_conv": true +} diff --git a/edm2-img512-l-dino/vae/diffusion_pytorch_model.safetensors b/edm2-img512-l-dino/vae/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..90464d67ac7303d0ee4696334df13da130a948ea --- /dev/null +++ b/edm2-img512-l-dino/vae/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815 +size 334643276 diff --git a/edm2-img512-l-fid/generator_test.png b/edm2-img512-l-fid/generator_test.png new file mode 100644 index 0000000000000000000000000000000000000000..c7201a46a9b325e4b7129290fa7a3f13549cc7d4 --- /dev/null +++ b/edm2-img512-l-fid/generator_test.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cabf3ca8019e86c4a85855d5c3fd2c6de6d25ac51682da208d20db23533e6578 +size 378707 diff --git a/edm2-img512-l-fid/model_index.json b/edm2-img512-l-fid/model_index.json new file mode 100644 index 0000000000000000000000000000000000000000..24c0305eea3ec64abbc2f870476afa4f5e462f42 --- /dev/null +++ b/edm2-img512-l-fid/model_index.json @@ -0,0 +1,19 @@ +{ + "_class_name": [ + "pipeline", + "EDM2Pipeline" + ], + "_diffusers_version": "0.31.0", + "scheduler": [ + "diffusers", + "EDMEulerScheduler" + ], + "unet": [ + "unet_edm2", + "EDM2UNet2DModel" + ], + "vae": [ + "diffusers", + "AutoencoderKL" + ] +} diff --git a/edm2-img512-l-fid/pipeline.py b/edm2-img512-l-fid/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..6b4131215d07770c74efa0674d00f3d6423d9c06 --- /dev/null +++ b/edm2-img512-l-fid/pipeline.py @@ -0,0 +1,406 @@ +"""Hub custom pipeline: EDM2Pipeline. +Load with native Hugging Face diffusers and trust_remote_code=True. +""" + +from __future__ import annotations + +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from pathlib import Path +from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch + +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers.utils import replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from pathlib import Path + >>> import torch + >>> from diffusers import DiffusionPipeline + + >>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve() + >>> pipe = DiffusionPipeline.from_pretrained( + ... str(model_dir), + ... local_files_only=True, + ... custom_pipeline=str(model_dir / "pipeline.py"), + ... trust_remote_code=True, + ... torch_dtype=torch.float32, + ... ) + >>> pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(42) + >>> image = pipe( + ... class_labels=207, + ... num_inference_steps=32, + ... guidance_scale=1.0, + ... generator=generator, + ... ).images[0] + >>> image.save("demo.png") + ``` +""" + +# Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py). +_STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28]) +_STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE + +class EDM2Pipeline(DiffusionPipeline): + r""" + Pipeline for class-conditional image generation with EDM2 + ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)). + + Parameters: + unet ([`EDM2UNet2DModel`]): + Main magnitude-preserving U-Net with EDM preconditioning. + scheduler ([`EDMEulerScheduler`]): + Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in + the pipeline because the UNet returns denoised latents rather than noise predictions. + vae ([`AutoencoderKL`], *optional*): + Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`. + gnet ([`EDM2UNet2DModel`], *optional*): + Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`). + id2label (`dict[int, str]`, *optional*): + ImageNet class id to English label mapping. + """ + + model_cpu_offload_seq = "unet->gnet->vae" + _optional_components = ["vae", "gnet"] + + def __init__( + self, + unet, + scheduler, + vae=None, + gnet=None, + id2label: Optional[Dict[Union[int, str], str]] = None, + ) -> None: + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet) + self._id2label = self._normalize_id2label(id2label) + self.labels = self._build_label2id(self._id2label) + self._labels_loaded_from_model_index = bool(self._id2label) + self.vae_scale_factor = 8 if self.vae is not None else 1 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) + + @staticmethod + def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]: + if not id2label: + return {} + return {int(key): value for key, value in id2label.items()} + + @staticmethod + def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]: + label2id: Dict[str, int] = {} + for class_id, value in id2label.items(): + for synonym in value.split(","): + synonym = synonym.strip() + if synonym: + label2id[synonym] = int(class_id) + return dict(sorted(label2id.items())) + + def _ensure_labels_loaded(self) -> None: + if self._labels_loaded_from_model_index: + return + loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None)) + if loaded: + self._id2label = loaded + self.labels = self._build_label2id(self._id2label) + self._labels_loaded_from_model_index = True + + @staticmethod + def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]: + if not variant_path: + return {} + model_index_path = Path(variant_path).resolve() / "model_index.json" + if not model_index_path.is_file(): + return {} + raw = json.loads(model_index_path.read_text(encoding="utf-8")) + id2label = raw.get("id2label") + if not isinstance(id2label, dict): + return {} + return {int(key): value for key, value in id2label.items()} + + @property + def id2label(self) -> Dict[int, str]: + r"""ImageNet class id to English label string (comma-separated synonyms).""" + self._ensure_labels_loaded() + return self._id2label + + def get_label_ids(self, label: Union[str, List[str]]) -> List[int]: + r""" + Map ImageNet label strings to class ids. + + Args: + label (`str` or `list[str]`): + One or more English label strings that match entries in `id2label`. + """ + self._ensure_labels_loaded() + if not self.labels: + raise ValueError("No English labels loaded. Add `id2label` to model_index.json.") + labels = [label] if isinstance(label, str) else list(label) + missing = [item for item in labels if item not in self.labels] + if missing: + preview = ", ".join(list(self.labels.keys())[:8]) + raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...") + return [self.labels[item] for item in labels] + + def _default_image_size(self) -> int: + latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64))) + return latent_size * self.vae_scale_factor + + def check_inputs( + self, + height: int, + width: int, + num_inference_steps: int, + guidance_scale: float, + output_type: str, + ) -> None: + if num_inference_steps < 1: + raise ValueError("num_inference_steps must be >= 1.") + if guidance_scale < 1.0: + raise ValueError("guidance_scale must be >= 1.0.") + if guidance_scale > 1.0 and self.gnet is None: + raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).") + if output_type not in {"pil", "np", "pt", "latent"}: + raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.") + + native_size = self._default_image_size() + if height != native_size or width != native_size: + raise ValueError( + f"EDM2 expects native resolution height=width={native_size}. " + f"Got height={height}, width={width}." + ) + + def _normalize_class_labels( + self, + class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]], + batch_size: int, + device: torch.device, + ) -> Optional[torch.Tensor]: + label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0))) + if label_dim == 0: + return None + if class_labels is None: + indices = torch.randint(label_dim, size=(batch_size,), device=device) + return torch.eye(label_dim, device=device, dtype=torch.float32)[indices] + + if isinstance(class_labels, str): + class_labels = self.get_label_ids(class_labels)[0] + elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str): + class_labels = self.get_label_ids(list(class_labels)) + + if isinstance(class_labels, int): + indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long) + elif isinstance(class_labels, torch.Tensor): + if class_labels.ndim == 2: + labels = class_labels.to(device=device, dtype=torch.float32) + if labels.shape[0] != batch_size: + raise ValueError(f"class_labels batch must match batch_size={batch_size}.") + return labels + indices = class_labels.to(device=device, dtype=torch.long).flatten() + else: + indices = torch.tensor(list(class_labels), device=device, dtype=torch.long) + + if indices.numel() == 1 and batch_size > 1: + indices = indices.repeat(batch_size) + if indices.numel() != batch_size: + raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.") + return torch.eye(label_dim, device=device, dtype=torch.float32)[indices] + + def prepare_latents( + self, + batch_size: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]], + ) -> torch.Tensor: + in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4))) + latent_size = height // self.vae_scale_factor + return randn_tensor( + (batch_size, in_channels, latent_size, latent_size), + generator=generator, + device=device, + dtype=torch.float32, + ) + + def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"): + if output_type == "latent": + return latents + + in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 3))) + if self.vae is None: + image = (latents.to(torch.float32) * 127.5 + 128).clip(0, 255) / 255.0 + return self.image_processor.postprocess(image, output_type=output_type) + + if in_channels == 4: + x = latents.to(torch.float32) + scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1) + bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1) + x = (x - bias) / scale + else: + x = latents.to(torch.float32) + + vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype + image = self.vae.decode(x.to(dtype=vae_dtype)).sample.to(torch.float32).clamp(0, 1) + + return self.image_processor.postprocess(image, output_type=output_type) + + @staticmethod + def _apply_autoguidance( + main: torch.Tensor, + ref: torch.Tensor, + guidance_scale: float, + ) -> torch.Tensor: + return ref.lerp(main, guidance_scale) + + @staticmethod + def _sample_edm2_heun( + denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + noise: torch.Tensor, + sigmas: torch.Tensor, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + progress_bar: Optional[Callable[[Iterable], Iterable]] = None, + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0).""" + x_next = noise.to(dtype) * sigmas[0] + + sigma_pairs = list(zip(sigmas[:-1], sigmas[1:])) + if progress_bar is not None: + sigma_pairs = progress_bar(sigma_pairs) + + num_steps = len(sigma_pairs) + for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs): + x_hat, sigma_hat = x_next, sigma_cur + d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat + x_next = x_hat + (sigma_next - sigma_hat) * d_cur + if i < num_steps - 1: + d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next + x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime) + return x_next + + @torch.inference_mode() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None, + batch_size: int = 1, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 32, + guidance_scale: float = 1.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: str = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Generate class-conditional images with EDM2. + + Args: + class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*): + ImageNet class indices, English label strings, or one-hot float tensors. + Random classes are sampled when omitted on conditional models. + batch_size (`int`, defaults to `1`): + Number of images to generate. + height (`int`, *optional*): + Output height in pixels. Defaults to the pretrained native resolution. + width (`int`, *optional*): + Output width in pixels. Defaults to the pretrained native resolution. + num_inference_steps (`int`, defaults to `32`): + Number of EDM2 Heun steps (NVlabs default). + guidance_scale (`float`, defaults to `1.0`): + Autoguidance strength. Values above `1.0` blend the main net with `gnet` + via `gnet_output.lerp(unet_output, guidance_scale)`. + generator (`torch.Generator`, *optional*): + RNG for reproducibility. + output_type (`str`, defaults to `"pil"`): + `"pil"`, `"np"`, `"pt"`, or `"latent"`. + return_dict (`bool`, defaults to `True`): + Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if True. + + Examples: + + """ + default_size = self._default_image_size() + height = int(height or default_size) + width = int(width or default_size) + self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type) + + device = self._execution_device + dtype = self.unet.dtype + labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device) + noise = self.prepare_latents(batch_size, height, width, dtype, device, generator) + + def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + sigma_batch = sigma.reshape(1).expand(batch_size) + main = self.unet( + sample=x, + sigma=sigma_batch, + class_labels=labels, + force_fp32=True, + ).sample + if guidance_scale == 1.0 or self.gnet is None: + return main.to(torch.float32) + ref = self.gnet( + sample=x, + sigma=sigma_batch, + class_labels=labels, + force_fp32=True, + ).sample + return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + latents = self._sample_edm2_heun( + denoise_fn=denoise_fn, + noise=noise, + sigmas=self.scheduler.sigmas.to(device), + generator=generator, + progress_bar=self.progress_bar, + dtype=torch.float32, + ) + + image = self.decode_latents(latents, output_type=output_type) + if not return_dict: + return (image, latents) + return ImagePipelineOutput(images=image) + + @classmethod + def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None): + vae_dir = os.path.join(pretrained_model_name_or_path, "vae") + if os.path.isdir(vae_dir): + try: + + return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype) + except Exception: + return None + + vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt") + if os.path.isfile(vae_hint): + with open(vae_hint, "r", encoding="utf-8") as f: + hub_id = f.read().strip() + if hub_id: + + return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype) + return None diff --git a/edm2-img512-l-fid/scheduler/scheduler_config.json b/edm2-img512-l-fid/scheduler/scheduler_config.json new file mode 100644 index 0000000000000000000000000000000000000000..08cbde97955d06d66b0b72be217db84bb797d711 --- /dev/null +++ b/edm2-img512-l-fid/scheduler/scheduler_config.json @@ -0,0 +1,11 @@ +{ + "_class_name": "EDMEulerScheduler", + "final_sigmas_type": "zero", + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "rho": 7.0, + "sigma_data": 0.5, + "sigma_max": 80.0, + "sigma_min": 0.002, + "sigma_schedule": "karras" +} diff --git a/edm2-img512-l-fid/unet/config.json b/edm2-img512-l-fid/unet/config.json new file mode 100644 index 0000000000000000000000000000000000000000..939ef1d82a8da253ce5dc9b30d3ac644f49dc4ed --- /dev/null +++ b/edm2-img512-l-fid/unet/config.json @@ -0,0 +1,31 @@ +{ + "_class_name": "EDM2UNet2DModel", + "attn_balance": 0.3, + "attn_resolutions": [ + 16, + 8 + ], + "channel_mult": [ + 1, + 2, + 3, + 4 + ], + "channel_mult_emb": 4, + "channel_mult_noise": 1, + "channels_per_head": 64, + "clip_act": 256, + "concat_balance": 0.5, + "dropout": 0.0, + "in_channels": 4, + "label_balance": 0.5, + "logvar_channels": 128, + "model_channels": 320, + "num_blocks": 3, + "num_class_embeds": 1000, + "out_channels": 4, + "res_balance": 0.3, + "sample_size": 64, + "sigma_data": 0.5, + "use_fp16": true +} diff --git a/edm2-img512-l-fid/unet/diffusion_pytorch_model.safetensors b/edm2-img512-l-fid/unet/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..3b76a348ae7a2f1b1d650d963890cf0e6e98ad5e --- /dev/null +++ b/edm2-img512-l-fid/unet/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a3e3f5127c12027e4796bef297e247a38ddd13bb7b8445c5d41169106b94389 +size 3110018564 diff --git a/edm2-img512-l-fid/unet/unet_edm2.py b/edm2-img512-l-fid/unet/unet_edm2.py new file mode 100644 index 0000000000000000000000000000000000000000..49d9451b443c4d6ad3b6ba2db19a7a804b07b6de --- /dev/null +++ b/edm2-img512-l-fid/unet/unet_edm2.py @@ -0,0 +1,434 @@ +import math +import json +import os +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import numpy as np +import torch + +try: + from diffusers.configuration_utils import ConfigMixin, register_to_config + from diffusers.models.modeling_utils import ModelMixin + from diffusers.utils import BaseOutput +except ImportError: # pragma: no cover + class ModelMixin(torch.nn.Module): + pass + + class ConfigMixin: + config = {} + + def register_to_config(self, **kwargs): + self.config = kwargs + + def register_to_config(func): + return func + + @dataclass + class BaseOutput: + pass + + +def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor: + if dim is None: + dim = list(range(1, x.ndim)) + norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) + norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel())) + return x / norm.to(x.dtype) + + +def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor: + if mode == "keep": + return x + filt = np.float32(f) + pad = (len(filt) - 1) // 2 + filt = filt / filt.sum() + filt = np.outer(filt, filt)[np.newaxis, np.newaxis, :, :] + filt = torch.as_tensor(filt, dtype=x.dtype, device=x.device) + c = x.shape[1] + if mode == "down": + return torch.nn.functional.conv2d(x, filt.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,)) + return torch.nn.functional.conv_transpose2d(x, (filt * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,)) + + +def mp_silu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.silu(x) / 0.596 + + +def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor: + return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2) + + +def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor: + na = a.shape[dim] + nb = b.shape[dim] + c = math.sqrt((na + nb) / ((1 - t) ** 2 + t ** 2)) + wa = c / math.sqrt(na) * (1 - t) + wb = c / math.sqrt(nb) * t + return torch.cat([wa * a, wb * b], dim=dim) + + +class MPFourier(torch.nn.Module): + def __init__(self, num_channels: int, bandwidth: float = 1): + super().__init__() + self.register_buffer("freqs", 2 * math.pi * torch.randn(num_channels) * bandwidth) + self.register_buffer("phases", 2 * math.pi * torch.rand(num_channels)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = x.to(torch.float32).ger(self.freqs.to(torch.float32)) + y = y + self.phases.to(torch.float32) + y = y.cos() * math.sqrt(2) + return y.to(x.dtype) + + +class MPConv(torch.nn.Module): + def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]): + super().__init__() + self.out_channels = out_channels + self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel)) + + def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor: + w = self.weight.to(torch.float32) + if self.training: + with torch.no_grad(): + self.weight.copy_(normalize(w)) + w = normalize(w) + w = w * (gain / math.sqrt(w[0].numel())) + w = w.to(x.dtype) + if w.ndim == 2: + return x @ w.t() + return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1] // 2,)) + + +class Block(torch.nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + emb_channels: int, + flavor: str = "enc", + resample_mode: str = "keep", + resample_filter: List[float] = [1, 1], + attention: bool = False, + channels_per_head: int = 64, + dropout: float = 0.0, + res_balance: float = 0.3, + attn_balance: float = 0.3, + clip_act: Optional[float] = 256, + ): + super().__init__() + self.out_channels = out_channels + self.flavor = flavor + self.resample_filter = resample_filter + self.resample_mode = resample_mode + self.num_heads = out_channels // channels_per_head if attention else 0 + self.dropout = dropout + self.res_balance = res_balance + self.attn_balance = attn_balance + self.clip_act = clip_act + self.emb_gain = torch.nn.Parameter(torch.zeros([])) + self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3)) + self.emb_linear = MPConv(emb_channels, out_channels, kernel=()) + self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3)) + self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None + self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None + self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + x = resample(x, f=self.resample_filter, mode=self.resample_mode) + if self.flavor == "enc": + if self.conv_skip is not None: + x = self.conv_skip(x) + x = normalize(x, dim=[1]) + + y = self.conv_res0(mp_silu(x)) + c = self.emb_linear(emb, gain=self.emb_gain) + 1 + y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype)) + if self.training and self.dropout: + y = torch.nn.functional.dropout(y, p=self.dropout) + y = self.conv_res1(y) + + if self.flavor == "dec" and self.conv_skip is not None: + x = self.conv_skip(x) + x = mp_sum(x, y, t=self.res_balance) + + if self.num_heads: + y = self.attn_qkv(x) + y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3]) + q, k, v = normalize(y, dim=[2]).unbind(3) + w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3) + y = torch.einsum("nhqk,nhck->nhcq", w, v) + y = self.attn_proj(y.reshape(*x.shape)) + x = mp_sum(x, y, t=self.attn_balance) + + if self.clip_act is not None: + x = x.clip_(-self.clip_act, self.clip_act) + return x + + +class EDM2UNet(torch.nn.Module): + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int, + model_channels: int = 192, + channel_mult: Tuple[int, ...] = (1, 2, 3, 4), + channel_mult_noise: Optional[int] = None, + channel_mult_emb: Optional[int] = None, + num_blocks: int = 3, + attn_resolutions: Tuple[int, ...] = (16, 8), + label_balance: float = 0.5, + concat_balance: float = 0.5, + **block_kwargs, + ): + super().__init__() + cblock = [model_channels * x for x in channel_mult] + cnoise = model_channels * channel_mult_noise if channel_mult_noise is not None else cblock[0] + cemb = model_channels * channel_mult_emb if channel_mult_emb is not None else max(cblock) + self.label_balance = label_balance + self.concat_balance = concat_balance + self.out_gain = torch.nn.Parameter(torch.zeros([])) + + self.emb_fourier = MPFourier(cnoise) + self.emb_noise = MPConv(cnoise, cemb, kernel=()) + self.emb_label = MPConv(label_dim, cemb, kernel=()) if label_dim else None + + self.enc = torch.nn.ModuleDict() + cout = img_channels + 1 + for level, channels in enumerate(cblock): + res = img_resolution >> level + if level == 0: + cin = cout + cout = channels + self.enc[f"{res}x{res}_conv"] = MPConv(cin, cout, kernel=(3, 3)) + else: + self.enc[f"{res}x{res}_down"] = Block(cout, cout, cemb, flavor="enc", resample_mode="down", **block_kwargs) + for idx in range(num_blocks): + cin = cout + cout = channels + self.enc[f"{res}x{res}_block{idx}"] = Block( + cin, + cout, + cemb, + flavor="enc", + attention=(res in attn_resolutions), + **block_kwargs, + ) + + self.dec = torch.nn.ModuleDict() + skips = [block.out_channels for block in self.enc.values()] + for level, channels in reversed(list(enumerate(cblock))): + res = img_resolution >> level + if level == len(cblock) - 1: + self.dec[f"{res}x{res}_in0"] = Block(cout, cout, cemb, flavor="dec", attention=True, **block_kwargs) + self.dec[f"{res}x{res}_in1"] = Block(cout, cout, cemb, flavor="dec", **block_kwargs) + else: + self.dec[f"{res}x{res}_up"] = Block(cout, cout, cemb, flavor="dec", resample_mode="up", **block_kwargs) + for idx in range(num_blocks + 1): + cin = cout + skips.pop() + cout = channels + self.dec[f"{res}x{res}_block{idx}"] = Block( + cin, + cout, + cemb, + flavor="dec", + attention=(res in attn_resolutions), + **block_kwargs, + ) + + self.out_conv = MPConv(cout, img_channels, kernel=(3, 3)) + + def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor: + emb = self.emb_noise(self.emb_fourier(noise_labels)) + if self.emb_label is not None: + if class_labels is None: + raise ValueError("class_labels are required for conditional EDM2UNet.") + emb = mp_sum(emb, self.emb_label(class_labels * math.sqrt(class_labels.shape[1])), t=self.label_balance) + emb = mp_silu(emb) + + x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1) + skips = [] + for name, block in self.enc.items(): + x = block(x) if "conv" in name else block(x, emb) + skips.append(x) + + for name, block in self.dec.items(): + if "block" in name: + x = mp_cat(x, skips.pop(), t=self.concat_balance) + x = block(x, emb) + return self.out_conv(x, gain=self.out_gain) + + +@dataclass +class EDM2UNet2DOutput(BaseOutput): + sample: torch.Tensor + logvar: Optional[torch.Tensor] = None + + + +_CONFIG_KEYS = ( + "sample_size", + "in_channels", + "out_channels", + "num_class_embeds", + "use_fp16", + "sigma_data", + "logvar_channels", + "model_channels", + "channel_mult", + "channel_mult_noise", + "channel_mult_emb", + "num_blocks", + "attn_resolutions", + "label_balance", + "concat_balance", + "dropout", + "channels_per_head", + "res_balance", + "attn_balance", + "clip_act", +) + + +class EDM2UNet2DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + sample_size: int = 64, + in_channels: int = 4, + out_channels: int = 4, + num_class_embeds: int = 0, + use_fp16: bool = True, + sigma_data: float = 0.5, + logvar_channels: int = 128, + model_channels: int = 192, + channel_mult: Tuple[int, ...] = (1, 2, 3, 4), + channel_mult_noise: Optional[int] = None, + channel_mult_emb: Optional[int] = None, + num_blocks: int = 3, + attn_resolutions: Tuple[int, ...] = (16, 8), + label_balance: float = 0.5, + concat_balance: float = 0.5, + dropout: float = 0.0, + channels_per_head: int = 64, + res_balance: float = 0.3, + attn_balance: float = 0.3, + clip_act: Optional[float] = 256, + ): + super().__init__() + self.sample_size = sample_size + self.in_channels = in_channels + self.out_channels = out_channels + self.num_class_embeds = num_class_embeds + self.use_fp16 = use_fp16 + self.sigma_data = sigma_data + self.model_channels = model_channels + self.channel_mult = channel_mult + self.channel_mult_noise = channel_mult_noise + self.channel_mult_emb = channel_mult_emb + self.num_blocks = num_blocks + self.attn_resolutions = attn_resolutions + self.label_balance = label_balance + self.concat_balance = concat_balance + self.dropout = dropout + self.channels_per_head = channels_per_head + self.res_balance = res_balance + self.attn_balance = attn_balance + self.clip_act = clip_act + self.unet = EDM2UNet( + img_resolution=sample_size, + img_channels=in_channels, + label_dim=num_class_embeds, + model_channels=model_channels, + channel_mult=channel_mult, + channel_mult_noise=channel_mult_noise, + channel_mult_emb=channel_mult_emb, + num_blocks=num_blocks, + attn_resolutions=attn_resolutions, + label_balance=label_balance, + concat_balance=concat_balance, + dropout=dropout, + channels_per_head=channels_per_head, + res_balance=res_balance, + attn_balance=attn_balance, + clip_act=clip_act, + ) + self.logvar_fourier = MPFourier(logvar_channels) + self.logvar_linear = MPConv(logvar_channels, 1, kernel=()) + + def forward( + self, + sample: torch.Tensor, + sigma: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + force_fp32: bool = False, + return_logvar: bool = False, + return_dict: bool = True, + ) -> EDM2UNet2DOutput: + x = sample.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + if self.num_class_embeds == 0: + class_labels = None + else: + if class_labels is None: + class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device) + class_labels = class_labels.to(torch.float32).reshape(-1, self.num_class_embeds) + dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") else torch.float32 + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.flatten().log() / 4 + + x_in = (c_in * x).to(dtype) + f_x = self.unet(x_in, c_noise, class_labels) + d_x = c_skip * x + c_out * f_x.to(torch.float32) + + logvar = None + if return_logvar: + logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1) + + if not return_dict: + return (d_x, logvar) + return EDM2UNet2DOutput(sample=d_x, logvar=logvar) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs): + subfolder = kwargs.pop("subfolder", None) + model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path + with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f: + config = json.load(f) + init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS} + model = cls(**init_kwargs) + weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors") + if os.path.isfile(weight_file): + from safetensors.torch import load_file + + state_dict = load_file(weight_file) + else: + state_dict = torch.load(os.path.join(model_dir, "diffusion_pytorch_model.bin"), map_location="cpu") + model.load_state_dict(state_dict, strict=True) + if torch_dtype is not None: + model = model.to(dtype=torch_dtype) + return model + + def save_pretrained(self, save_directory: str, safe_serialization: bool = True): + os.makedirs(save_directory, exist_ok=True) + stored = dict(getattr(self, "config", {})) + config = {"_class_name": self.__class__.__name__} + for key in _CONFIG_KEYS: + if key in stored: + config[key] = stored[key] + elif hasattr(self, key): + config[key] = getattr(self, key) + with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f: + json.dump(config, f, indent=2, sort_keys=True) + f.write("\n") + state_dict = self.state_dict() + if safe_serialization: + from safetensors.torch import save_file + + save_file(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.safetensors")) + else: + torch.save(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.bin")) diff --git a/edm2-img512-l-fid/vae/config.json b/edm2-img512-l-fid/vae/config.json new file mode 100644 index 0000000000000000000000000000000000000000..ae2ab37de2acb8bc348d3de809ae8385324be962 --- /dev/null +++ b/edm2-img512-l-fid/vae/config.json @@ -0,0 +1,38 @@ +{ + "_class_name": "AutoencoderKL", + "_diffusers_version": "0.36.0", + "_name_or_path": "stabilityai/sd-vae-ft-mse", + "act_fn": "silu", + "block_out_channels": [ + 128, + 256, + 512, + 512 + ], + "down_block_types": [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D" + ], + "force_upcast": true, + "in_channels": 3, + "latent_channels": 4, + "latents_mean": null, + "latents_std": null, + "layers_per_block": 2, + "mid_block_add_attention": true, + "norm_num_groups": 32, + "out_channels": 3, + "sample_size": 256, + "scaling_factor": 0.18215, + "shift_factor": null, + "up_block_types": [ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D" + ], + "use_post_quant_conv": true, + "use_quant_conv": true +} diff --git a/edm2-img512-l-fid/vae/diffusion_pytorch_model.safetensors b/edm2-img512-l-fid/vae/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..90464d67ac7303d0ee4696334df13da130a948ea --- /dev/null +++ b/edm2-img512-l-fid/vae/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815 +size 334643276 diff --git a/edm2-img512-m-fid/demo.png b/edm2-img512-m-fid/demo.png new file mode 100644 index 0000000000000000000000000000000000000000..6f729fad3723c801849aef1ab04df055413e8955 --- /dev/null +++ b/edm2-img512-m-fid/demo.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bda2cb48c7ab17b37fbfa0599c7fec6d1f8d7de6848990f870e8ff4b613c929d +size 369586 diff --git a/edm2-img512-m-fid/model_index.json b/edm2-img512-m-fid/model_index.json new file mode 100644 index 0000000000000000000000000000000000000000..24c0305eea3ec64abbc2f870476afa4f5e462f42 --- /dev/null +++ b/edm2-img512-m-fid/model_index.json @@ -0,0 +1,19 @@ +{ + "_class_name": [ + "pipeline", + "EDM2Pipeline" + ], + "_diffusers_version": "0.31.0", + "scheduler": [ + "diffusers", + "EDMEulerScheduler" + ], + "unet": [ + "unet_edm2", + "EDM2UNet2DModel" + ], + "vae": [ + "diffusers", + "AutoencoderKL" + ] +} diff --git a/edm2-img512-m-fid/pipeline.py b/edm2-img512-m-fid/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..6b4131215d07770c74efa0674d00f3d6423d9c06 --- /dev/null +++ b/edm2-img512-m-fid/pipeline.py @@ -0,0 +1,406 @@ +"""Hub custom pipeline: EDM2Pipeline. +Load with native Hugging Face diffusers and trust_remote_code=True. +""" + +from __future__ import annotations + +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from pathlib import Path +from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch + +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers.utils import replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from pathlib import Path + >>> import torch + >>> from diffusers import DiffusionPipeline + + >>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve() + >>> pipe = DiffusionPipeline.from_pretrained( + ... str(model_dir), + ... local_files_only=True, + ... custom_pipeline=str(model_dir / "pipeline.py"), + ... trust_remote_code=True, + ... torch_dtype=torch.float32, + ... ) + >>> pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(42) + >>> image = pipe( + ... class_labels=207, + ... num_inference_steps=32, + ... guidance_scale=1.0, + ... generator=generator, + ... ).images[0] + >>> image.save("demo.png") + ``` +""" + +# Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py). +_STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28]) +_STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE + +class EDM2Pipeline(DiffusionPipeline): + r""" + Pipeline for class-conditional image generation with EDM2 + ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)). + + Parameters: + unet ([`EDM2UNet2DModel`]): + Main magnitude-preserving U-Net with EDM preconditioning. + scheduler ([`EDMEulerScheduler`]): + Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in + the pipeline because the UNet returns denoised latents rather than noise predictions. + vae ([`AutoencoderKL`], *optional*): + Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`. + gnet ([`EDM2UNet2DModel`], *optional*): + Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`). + id2label (`dict[int, str]`, *optional*): + ImageNet class id to English label mapping. + """ + + model_cpu_offload_seq = "unet->gnet->vae" + _optional_components = ["vae", "gnet"] + + def __init__( + self, + unet, + scheduler, + vae=None, + gnet=None, + id2label: Optional[Dict[Union[int, str], str]] = None, + ) -> None: + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet) + self._id2label = self._normalize_id2label(id2label) + self.labels = self._build_label2id(self._id2label) + self._labels_loaded_from_model_index = bool(self._id2label) + self.vae_scale_factor = 8 if self.vae is not None else 1 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) + + @staticmethod + def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]: + if not id2label: + return {} + return {int(key): value for key, value in id2label.items()} + + @staticmethod + def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]: + label2id: Dict[str, int] = {} + for class_id, value in id2label.items(): + for synonym in value.split(","): + synonym = synonym.strip() + if synonym: + label2id[synonym] = int(class_id) + return dict(sorted(label2id.items())) + + def _ensure_labels_loaded(self) -> None: + if self._labels_loaded_from_model_index: + return + loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None)) + if loaded: + self._id2label = loaded + self.labels = self._build_label2id(self._id2label) + self._labels_loaded_from_model_index = True + + @staticmethod + def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]: + if not variant_path: + return {} + model_index_path = Path(variant_path).resolve() / "model_index.json" + if not model_index_path.is_file(): + return {} + raw = json.loads(model_index_path.read_text(encoding="utf-8")) + id2label = raw.get("id2label") + if not isinstance(id2label, dict): + return {} + return {int(key): value for key, value in id2label.items()} + + @property + def id2label(self) -> Dict[int, str]: + r"""ImageNet class id to English label string (comma-separated synonyms).""" + self._ensure_labels_loaded() + return self._id2label + + def get_label_ids(self, label: Union[str, List[str]]) -> List[int]: + r""" + Map ImageNet label strings to class ids. + + Args: + label (`str` or `list[str]`): + One or more English label strings that match entries in `id2label`. + """ + self._ensure_labels_loaded() + if not self.labels: + raise ValueError("No English labels loaded. Add `id2label` to model_index.json.") + labels = [label] if isinstance(label, str) else list(label) + missing = [item for item in labels if item not in self.labels] + if missing: + preview = ", ".join(list(self.labels.keys())[:8]) + raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...") + return [self.labels[item] for item in labels] + + def _default_image_size(self) -> int: + latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64))) + return latent_size * self.vae_scale_factor + + def check_inputs( + self, + height: int, + width: int, + num_inference_steps: int, + guidance_scale: float, + output_type: str, + ) -> None: + if num_inference_steps < 1: + raise ValueError("num_inference_steps must be >= 1.") + if guidance_scale < 1.0: + raise ValueError("guidance_scale must be >= 1.0.") + if guidance_scale > 1.0 and self.gnet is None: + raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).") + if output_type not in {"pil", "np", "pt", "latent"}: + raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.") + + native_size = self._default_image_size() + if height != native_size or width != native_size: + raise ValueError( + f"EDM2 expects native resolution height=width={native_size}. " + f"Got height={height}, width={width}." + ) + + def _normalize_class_labels( + self, + class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]], + batch_size: int, + device: torch.device, + ) -> Optional[torch.Tensor]: + label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0))) + if label_dim == 0: + return None + if class_labels is None: + indices = torch.randint(label_dim, size=(batch_size,), device=device) + return torch.eye(label_dim, device=device, dtype=torch.float32)[indices] + + if isinstance(class_labels, str): + class_labels = self.get_label_ids(class_labels)[0] + elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str): + class_labels = self.get_label_ids(list(class_labels)) + + if isinstance(class_labels, int): + indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long) + elif isinstance(class_labels, torch.Tensor): + if class_labels.ndim == 2: + labels = class_labels.to(device=device, dtype=torch.float32) + if labels.shape[0] != batch_size: + raise ValueError(f"class_labels batch must match batch_size={batch_size}.") + return labels + indices = class_labels.to(device=device, dtype=torch.long).flatten() + else: + indices = torch.tensor(list(class_labels), device=device, dtype=torch.long) + + if indices.numel() == 1 and batch_size > 1: + indices = indices.repeat(batch_size) + if indices.numel() != batch_size: + raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.") + return torch.eye(label_dim, device=device, dtype=torch.float32)[indices] + + def prepare_latents( + self, + batch_size: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]], + ) -> torch.Tensor: + in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4))) + latent_size = height // self.vae_scale_factor + return randn_tensor( + (batch_size, in_channels, latent_size, latent_size), + generator=generator, + device=device, + dtype=torch.float32, + ) + + def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"): + if output_type == "latent": + return latents + + in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 3))) + if self.vae is None: + image = (latents.to(torch.float32) * 127.5 + 128).clip(0, 255) / 255.0 + return self.image_processor.postprocess(image, output_type=output_type) + + if in_channels == 4: + x = latents.to(torch.float32) + scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1) + bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1) + x = (x - bias) / scale + else: + x = latents.to(torch.float32) + + vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype + image = self.vae.decode(x.to(dtype=vae_dtype)).sample.to(torch.float32).clamp(0, 1) + + return self.image_processor.postprocess(image, output_type=output_type) + + @staticmethod + def _apply_autoguidance( + main: torch.Tensor, + ref: torch.Tensor, + guidance_scale: float, + ) -> torch.Tensor: + return ref.lerp(main, guidance_scale) + + @staticmethod + def _sample_edm2_heun( + denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + noise: torch.Tensor, + sigmas: torch.Tensor, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + progress_bar: Optional[Callable[[Iterable], Iterable]] = None, + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0).""" + x_next = noise.to(dtype) * sigmas[0] + + sigma_pairs = list(zip(sigmas[:-1], sigmas[1:])) + if progress_bar is not None: + sigma_pairs = progress_bar(sigma_pairs) + + num_steps = len(sigma_pairs) + for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs): + x_hat, sigma_hat = x_next, sigma_cur + d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat + x_next = x_hat + (sigma_next - sigma_hat) * d_cur + if i < num_steps - 1: + d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next + x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime) + return x_next + + @torch.inference_mode() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None, + batch_size: int = 1, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 32, + guidance_scale: float = 1.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: str = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Generate class-conditional images with EDM2. + + Args: + class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*): + ImageNet class indices, English label strings, or one-hot float tensors. + Random classes are sampled when omitted on conditional models. + batch_size (`int`, defaults to `1`): + Number of images to generate. + height (`int`, *optional*): + Output height in pixels. Defaults to the pretrained native resolution. + width (`int`, *optional*): + Output width in pixels. Defaults to the pretrained native resolution. + num_inference_steps (`int`, defaults to `32`): + Number of EDM2 Heun steps (NVlabs default). + guidance_scale (`float`, defaults to `1.0`): + Autoguidance strength. Values above `1.0` blend the main net with `gnet` + via `gnet_output.lerp(unet_output, guidance_scale)`. + generator (`torch.Generator`, *optional*): + RNG for reproducibility. + output_type (`str`, defaults to `"pil"`): + `"pil"`, `"np"`, `"pt"`, or `"latent"`. + return_dict (`bool`, defaults to `True`): + Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if True. + + Examples: + + """ + default_size = self._default_image_size() + height = int(height or default_size) + width = int(width or default_size) + self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type) + + device = self._execution_device + dtype = self.unet.dtype + labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device) + noise = self.prepare_latents(batch_size, height, width, dtype, device, generator) + + def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + sigma_batch = sigma.reshape(1).expand(batch_size) + main = self.unet( + sample=x, + sigma=sigma_batch, + class_labels=labels, + force_fp32=True, + ).sample + if guidance_scale == 1.0 or self.gnet is None: + return main.to(torch.float32) + ref = self.gnet( + sample=x, + sigma=sigma_batch, + class_labels=labels, + force_fp32=True, + ).sample + return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + latents = self._sample_edm2_heun( + denoise_fn=denoise_fn, + noise=noise, + sigmas=self.scheduler.sigmas.to(device), + generator=generator, + progress_bar=self.progress_bar, + dtype=torch.float32, + ) + + image = self.decode_latents(latents, output_type=output_type) + if not return_dict: + return (image, latents) + return ImagePipelineOutput(images=image) + + @classmethod + def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None): + vae_dir = os.path.join(pretrained_model_name_or_path, "vae") + if os.path.isdir(vae_dir): + try: + + return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype) + except Exception: + return None + + vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt") + if os.path.isfile(vae_hint): + with open(vae_hint, "r", encoding="utf-8") as f: + hub_id = f.read().strip() + if hub_id: + + return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype) + return None diff --git a/edm2-img512-m-fid/scheduler/scheduler_config.json b/edm2-img512-m-fid/scheduler/scheduler_config.json new file mode 100644 index 0000000000000000000000000000000000000000..08cbde97955d06d66b0b72be217db84bb797d711 --- /dev/null +++ b/edm2-img512-m-fid/scheduler/scheduler_config.json @@ -0,0 +1,11 @@ +{ + "_class_name": "EDMEulerScheduler", + "final_sigmas_type": "zero", + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "rho": 7.0, + "sigma_data": 0.5, + "sigma_max": 80.0, + "sigma_min": 0.002, + "sigma_schedule": "karras" +} diff --git a/edm2-img512-m-fid/unet/config.json b/edm2-img512-m-fid/unet/config.json new file mode 100644 index 0000000000000000000000000000000000000000..0173e5acb7a02db0cd9bfde8ed828ce386e96899 --- /dev/null +++ b/edm2-img512-m-fid/unet/config.json @@ -0,0 +1,31 @@ +{ + "_class_name": "EDM2UNet2DModel", + "attn_balance": 0.3, + "attn_resolutions": [ + 16, + 8 + ], + "channel_mult": [ + 1, + 2, + 3, + 4 + ], + "channel_mult_emb": 4, + "channel_mult_noise": 1, + "channels_per_head": 64, + "clip_act": 256, + "concat_balance": 0.5, + "dropout": 0.0, + "in_channels": 4, + "label_balance": 0.5, + "logvar_channels": 128, + "model_channels": 256, + "num_blocks": 3, + "num_class_embeds": 1000, + "out_channels": 4, + "res_balance": 0.3, + "sample_size": 64, + "sigma_data": 0.5, + "use_fp16": true +} diff --git a/edm2-img512-m-fid/unet/diffusion_pytorch_model.safetensors b/edm2-img512-m-fid/unet/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..ce30a1a7af868a3d6f1cae538169c82793ff12f5 --- /dev/null +++ b/edm2-img512-m-fid/unet/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4733c8b2d2823cd6ce7a67e2b89b0e9b94d50fdf595b0e0b17299e198da3bcfc +size 1991256788 diff --git a/edm2-img512-m-fid/unet/unet_edm2.py b/edm2-img512-m-fid/unet/unet_edm2.py new file mode 100644 index 0000000000000000000000000000000000000000..49d9451b443c4d6ad3b6ba2db19a7a804b07b6de --- /dev/null +++ b/edm2-img512-m-fid/unet/unet_edm2.py @@ -0,0 +1,434 @@ +import math +import json +import os +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import numpy as np +import torch + +try: + from diffusers.configuration_utils import ConfigMixin, register_to_config + from diffusers.models.modeling_utils import ModelMixin + from diffusers.utils import BaseOutput +except ImportError: # pragma: no cover + class ModelMixin(torch.nn.Module): + pass + + class ConfigMixin: + config = {} + + def register_to_config(self, **kwargs): + self.config = kwargs + + def register_to_config(func): + return func + + @dataclass + class BaseOutput: + pass + + +def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor: + if dim is None: + dim = list(range(1, x.ndim)) + norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) + norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel())) + return x / norm.to(x.dtype) + + +def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor: + if mode == "keep": + return x + filt = np.float32(f) + pad = (len(filt) - 1) // 2 + filt = filt / filt.sum() + filt = np.outer(filt, filt)[np.newaxis, np.newaxis, :, :] + filt = torch.as_tensor(filt, dtype=x.dtype, device=x.device) + c = x.shape[1] + if mode == "down": + return torch.nn.functional.conv2d(x, filt.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,)) + return torch.nn.functional.conv_transpose2d(x, (filt * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,)) + + +def mp_silu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.silu(x) / 0.596 + + +def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor: + return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2) + + +def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor: + na = a.shape[dim] + nb = b.shape[dim] + c = math.sqrt((na + nb) / ((1 - t) ** 2 + t ** 2)) + wa = c / math.sqrt(na) * (1 - t) + wb = c / math.sqrt(nb) * t + return torch.cat([wa * a, wb * b], dim=dim) + + +class MPFourier(torch.nn.Module): + def __init__(self, num_channels: int, bandwidth: float = 1): + super().__init__() + self.register_buffer("freqs", 2 * math.pi * torch.randn(num_channels) * bandwidth) + self.register_buffer("phases", 2 * math.pi * torch.rand(num_channels)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = x.to(torch.float32).ger(self.freqs.to(torch.float32)) + y = y + self.phases.to(torch.float32) + y = y.cos() * math.sqrt(2) + return y.to(x.dtype) + + +class MPConv(torch.nn.Module): + def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]): + super().__init__() + self.out_channels = out_channels + self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel)) + + def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor: + w = self.weight.to(torch.float32) + if self.training: + with torch.no_grad(): + self.weight.copy_(normalize(w)) + w = normalize(w) + w = w * (gain / math.sqrt(w[0].numel())) + w = w.to(x.dtype) + if w.ndim == 2: + return x @ w.t() + return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1] // 2,)) + + +class Block(torch.nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + emb_channels: int, + flavor: str = "enc", + resample_mode: str = "keep", + resample_filter: List[float] = [1, 1], + attention: bool = False, + channels_per_head: int = 64, + dropout: float = 0.0, + res_balance: float = 0.3, + attn_balance: float = 0.3, + clip_act: Optional[float] = 256, + ): + super().__init__() + self.out_channels = out_channels + self.flavor = flavor + self.resample_filter = resample_filter + self.resample_mode = resample_mode + self.num_heads = out_channels // channels_per_head if attention else 0 + self.dropout = dropout + self.res_balance = res_balance + self.attn_balance = attn_balance + self.clip_act = clip_act + self.emb_gain = torch.nn.Parameter(torch.zeros([])) + self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3)) + self.emb_linear = MPConv(emb_channels, out_channels, kernel=()) + self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3)) + self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None + self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None + self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + x = resample(x, f=self.resample_filter, mode=self.resample_mode) + if self.flavor == "enc": + if self.conv_skip is not None: + x = self.conv_skip(x) + x = normalize(x, dim=[1]) + + y = self.conv_res0(mp_silu(x)) + c = self.emb_linear(emb, gain=self.emb_gain) + 1 + y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype)) + if self.training and self.dropout: + y = torch.nn.functional.dropout(y, p=self.dropout) + y = self.conv_res1(y) + + if self.flavor == "dec" and self.conv_skip is not None: + x = self.conv_skip(x) + x = mp_sum(x, y, t=self.res_balance) + + if self.num_heads: + y = self.attn_qkv(x) + y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3]) + q, k, v = normalize(y, dim=[2]).unbind(3) + w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3) + y = torch.einsum("nhqk,nhck->nhcq", w, v) + y = self.attn_proj(y.reshape(*x.shape)) + x = mp_sum(x, y, t=self.attn_balance) + + if self.clip_act is not None: + x = x.clip_(-self.clip_act, self.clip_act) + return x + + +class EDM2UNet(torch.nn.Module): + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int, + model_channels: int = 192, + channel_mult: Tuple[int, ...] = (1, 2, 3, 4), + channel_mult_noise: Optional[int] = None, + channel_mult_emb: Optional[int] = None, + num_blocks: int = 3, + attn_resolutions: Tuple[int, ...] = (16, 8), + label_balance: float = 0.5, + concat_balance: float = 0.5, + **block_kwargs, + ): + super().__init__() + cblock = [model_channels * x for x in channel_mult] + cnoise = model_channels * channel_mult_noise if channel_mult_noise is not None else cblock[0] + cemb = model_channels * channel_mult_emb if channel_mult_emb is not None else max(cblock) + self.label_balance = label_balance + self.concat_balance = concat_balance + self.out_gain = torch.nn.Parameter(torch.zeros([])) + + self.emb_fourier = MPFourier(cnoise) + self.emb_noise = MPConv(cnoise, cemb, kernel=()) + self.emb_label = MPConv(label_dim, cemb, kernel=()) if label_dim else None + + self.enc = torch.nn.ModuleDict() + cout = img_channels + 1 + for level, channels in enumerate(cblock): + res = img_resolution >> level + if level == 0: + cin = cout + cout = channels + self.enc[f"{res}x{res}_conv"] = MPConv(cin, cout, kernel=(3, 3)) + else: + self.enc[f"{res}x{res}_down"] = Block(cout, cout, cemb, flavor="enc", resample_mode="down", **block_kwargs) + for idx in range(num_blocks): + cin = cout + cout = channels + self.enc[f"{res}x{res}_block{idx}"] = Block( + cin, + cout, + cemb, + flavor="enc", + attention=(res in attn_resolutions), + **block_kwargs, + ) + + self.dec = torch.nn.ModuleDict() + skips = [block.out_channels for block in self.enc.values()] + for level, channels in reversed(list(enumerate(cblock))): + res = img_resolution >> level + if level == len(cblock) - 1: + self.dec[f"{res}x{res}_in0"] = Block(cout, cout, cemb, flavor="dec", attention=True, **block_kwargs) + self.dec[f"{res}x{res}_in1"] = Block(cout, cout, cemb, flavor="dec", **block_kwargs) + else: + self.dec[f"{res}x{res}_up"] = Block(cout, cout, cemb, flavor="dec", resample_mode="up", **block_kwargs) + for idx in range(num_blocks + 1): + cin = cout + skips.pop() + cout = channels + self.dec[f"{res}x{res}_block{idx}"] = Block( + cin, + cout, + cemb, + flavor="dec", + attention=(res in attn_resolutions), + **block_kwargs, + ) + + self.out_conv = MPConv(cout, img_channels, kernel=(3, 3)) + + def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor: + emb = self.emb_noise(self.emb_fourier(noise_labels)) + if self.emb_label is not None: + if class_labels is None: + raise ValueError("class_labels are required for conditional EDM2UNet.") + emb = mp_sum(emb, self.emb_label(class_labels * math.sqrt(class_labels.shape[1])), t=self.label_balance) + emb = mp_silu(emb) + + x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1) + skips = [] + for name, block in self.enc.items(): + x = block(x) if "conv" in name else block(x, emb) + skips.append(x) + + for name, block in self.dec.items(): + if "block" in name: + x = mp_cat(x, skips.pop(), t=self.concat_balance) + x = block(x, emb) + return self.out_conv(x, gain=self.out_gain) + + +@dataclass +class EDM2UNet2DOutput(BaseOutput): + sample: torch.Tensor + logvar: Optional[torch.Tensor] = None + + + +_CONFIG_KEYS = ( + "sample_size", + "in_channels", + "out_channels", + "num_class_embeds", + "use_fp16", + "sigma_data", + "logvar_channels", + "model_channels", + "channel_mult", + "channel_mult_noise", + "channel_mult_emb", + "num_blocks", + "attn_resolutions", + "label_balance", + "concat_balance", + "dropout", + "channels_per_head", + "res_balance", + "attn_balance", + "clip_act", +) + + +class EDM2UNet2DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + sample_size: int = 64, + in_channels: int = 4, + out_channels: int = 4, + num_class_embeds: int = 0, + use_fp16: bool = True, + sigma_data: float = 0.5, + logvar_channels: int = 128, + model_channels: int = 192, + channel_mult: Tuple[int, ...] = (1, 2, 3, 4), + channel_mult_noise: Optional[int] = None, + channel_mult_emb: Optional[int] = None, + num_blocks: int = 3, + attn_resolutions: Tuple[int, ...] = (16, 8), + label_balance: float = 0.5, + concat_balance: float = 0.5, + dropout: float = 0.0, + channels_per_head: int = 64, + res_balance: float = 0.3, + attn_balance: float = 0.3, + clip_act: Optional[float] = 256, + ): + super().__init__() + self.sample_size = sample_size + self.in_channels = in_channels + self.out_channels = out_channels + self.num_class_embeds = num_class_embeds + self.use_fp16 = use_fp16 + self.sigma_data = sigma_data + self.model_channels = model_channels + self.channel_mult = channel_mult + self.channel_mult_noise = channel_mult_noise + self.channel_mult_emb = channel_mult_emb + self.num_blocks = num_blocks + self.attn_resolutions = attn_resolutions + self.label_balance = label_balance + self.concat_balance = concat_balance + self.dropout = dropout + self.channels_per_head = channels_per_head + self.res_balance = res_balance + self.attn_balance = attn_balance + self.clip_act = clip_act + self.unet = EDM2UNet( + img_resolution=sample_size, + img_channels=in_channels, + label_dim=num_class_embeds, + model_channels=model_channels, + channel_mult=channel_mult, + channel_mult_noise=channel_mult_noise, + channel_mult_emb=channel_mult_emb, + num_blocks=num_blocks, + attn_resolutions=attn_resolutions, + label_balance=label_balance, + concat_balance=concat_balance, + dropout=dropout, + channels_per_head=channels_per_head, + res_balance=res_balance, + attn_balance=attn_balance, + clip_act=clip_act, + ) + self.logvar_fourier = MPFourier(logvar_channels) + self.logvar_linear = MPConv(logvar_channels, 1, kernel=()) + + def forward( + self, + sample: torch.Tensor, + sigma: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + force_fp32: bool = False, + return_logvar: bool = False, + return_dict: bool = True, + ) -> EDM2UNet2DOutput: + x = sample.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + if self.num_class_embeds == 0: + class_labels = None + else: + if class_labels is None: + class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device) + class_labels = class_labels.to(torch.float32).reshape(-1, self.num_class_embeds) + dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") else torch.float32 + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.flatten().log() / 4 + + x_in = (c_in * x).to(dtype) + f_x = self.unet(x_in, c_noise, class_labels) + d_x = c_skip * x + c_out * f_x.to(torch.float32) + + logvar = None + if return_logvar: + logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1) + + if not return_dict: + return (d_x, logvar) + return EDM2UNet2DOutput(sample=d_x, logvar=logvar) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs): + subfolder = kwargs.pop("subfolder", None) + model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path + with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f: + config = json.load(f) + init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS} + model = cls(**init_kwargs) + weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors") + if os.path.isfile(weight_file): + from safetensors.torch import load_file + + state_dict = load_file(weight_file) + else: + state_dict = torch.load(os.path.join(model_dir, "diffusion_pytorch_model.bin"), map_location="cpu") + model.load_state_dict(state_dict, strict=True) + if torch_dtype is not None: + model = model.to(dtype=torch_dtype) + return model + + def save_pretrained(self, save_directory: str, safe_serialization: bool = True): + os.makedirs(save_directory, exist_ok=True) + stored = dict(getattr(self, "config", {})) + config = {"_class_name": self.__class__.__name__} + for key in _CONFIG_KEYS: + if key in stored: + config[key] = stored[key] + elif hasattr(self, key): + config[key] = getattr(self, key) + with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f: + json.dump(config, f, indent=2, sort_keys=True) + f.write("\n") + state_dict = self.state_dict() + if safe_serialization: + from safetensors.torch import save_file + + save_file(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.safetensors")) + else: + torch.save(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.bin")) diff --git a/edm2-img512-m-fid/vae/config.json b/edm2-img512-m-fid/vae/config.json new file mode 100644 index 0000000000000000000000000000000000000000..ae2ab37de2acb8bc348d3de809ae8385324be962 --- /dev/null +++ b/edm2-img512-m-fid/vae/config.json @@ -0,0 +1,38 @@ +{ + "_class_name": "AutoencoderKL", + "_diffusers_version": "0.36.0", + "_name_or_path": "stabilityai/sd-vae-ft-mse", + "act_fn": "silu", + "block_out_channels": [ + 128, + 256, + 512, + 512 + ], + "down_block_types": [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D" + ], + "force_upcast": true, + "in_channels": 3, + "latent_channels": 4, + "latents_mean": null, + "latents_std": null, + "layers_per_block": 2, + "mid_block_add_attention": true, + "norm_num_groups": 32, + "out_channels": 3, + "sample_size": 256, + "scaling_factor": 0.18215, + "shift_factor": null, + "up_block_types": [ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D" + ], + "use_post_quant_conv": true, + "use_quant_conv": true +} diff --git a/edm2-img512-m-fid/vae/diffusion_pytorch_model.safetensors b/edm2-img512-m-fid/vae/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..90464d67ac7303d0ee4696334df13da130a948ea --- /dev/null +++ b/edm2-img512-m-fid/vae/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815 +size 334643276 diff --git a/edm2-img512-s-fid/demo.png b/edm2-img512-s-fid/demo.png new file mode 100644 index 0000000000000000000000000000000000000000..aed38cdc126d7bffa89fbadc5e8f61c3afd45c16 --- /dev/null +++ b/edm2-img512-s-fid/demo.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:58bdb49e30c85b02b9e3619a11b39b1ec760452e8ad96cea1c5856e99df39d42 +size 381489 diff --git a/edm2-img512-s-fid/model_index.json b/edm2-img512-s-fid/model_index.json new file mode 100644 index 0000000000000000000000000000000000000000..24c0305eea3ec64abbc2f870476afa4f5e462f42 --- /dev/null +++ b/edm2-img512-s-fid/model_index.json @@ -0,0 +1,19 @@ +{ + "_class_name": [ + "pipeline", + "EDM2Pipeline" + ], + "_diffusers_version": "0.31.0", + "scheduler": [ + "diffusers", + "EDMEulerScheduler" + ], + "unet": [ + "unet_edm2", + "EDM2UNet2DModel" + ], + "vae": [ + "diffusers", + "AutoencoderKL" + ] +} diff --git a/edm2-img512-s-fid/pipeline.py b/edm2-img512-s-fid/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..6b4131215d07770c74efa0674d00f3d6423d9c06 --- /dev/null +++ b/edm2-img512-s-fid/pipeline.py @@ -0,0 +1,406 @@ +"""Hub custom pipeline: EDM2Pipeline. +Load with native Hugging Face diffusers and trust_remote_code=True. +""" + +from __future__ import annotations + +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from pathlib import Path +from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch + +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers.utils import replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from pathlib import Path + >>> import torch + >>> from diffusers import DiffusionPipeline + + >>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve() + >>> pipe = DiffusionPipeline.from_pretrained( + ... str(model_dir), + ... local_files_only=True, + ... custom_pipeline=str(model_dir / "pipeline.py"), + ... trust_remote_code=True, + ... torch_dtype=torch.float32, + ... ) + >>> pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(42) + >>> image = pipe( + ... class_labels=207, + ... num_inference_steps=32, + ... guidance_scale=1.0, + ... generator=generator, + ... ).images[0] + >>> image.save("demo.png") + ``` +""" + +# Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py). +_STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28]) +_STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE + +class EDM2Pipeline(DiffusionPipeline): + r""" + Pipeline for class-conditional image generation with EDM2 + ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)). + + Parameters: + unet ([`EDM2UNet2DModel`]): + Main magnitude-preserving U-Net with EDM preconditioning. + scheduler ([`EDMEulerScheduler`]): + Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in + the pipeline because the UNet returns denoised latents rather than noise predictions. + vae ([`AutoencoderKL`], *optional*): + Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`. + gnet ([`EDM2UNet2DModel`], *optional*): + Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`). + id2label (`dict[int, str]`, *optional*): + ImageNet class id to English label mapping. + """ + + model_cpu_offload_seq = "unet->gnet->vae" + _optional_components = ["vae", "gnet"] + + def __init__( + self, + unet, + scheduler, + vae=None, + gnet=None, + id2label: Optional[Dict[Union[int, str], str]] = None, + ) -> None: + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet) + self._id2label = self._normalize_id2label(id2label) + self.labels = self._build_label2id(self._id2label) + self._labels_loaded_from_model_index = bool(self._id2label) + self.vae_scale_factor = 8 if self.vae is not None else 1 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) + + @staticmethod + def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]: + if not id2label: + return {} + return {int(key): value for key, value in id2label.items()} + + @staticmethod + def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]: + label2id: Dict[str, int] = {} + for class_id, value in id2label.items(): + for synonym in value.split(","): + synonym = synonym.strip() + if synonym: + label2id[synonym] = int(class_id) + return dict(sorted(label2id.items())) + + def _ensure_labels_loaded(self) -> None: + if self._labels_loaded_from_model_index: + return + loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None)) + if loaded: + self._id2label = loaded + self.labels = self._build_label2id(self._id2label) + self._labels_loaded_from_model_index = True + + @staticmethod + def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]: + if not variant_path: + return {} + model_index_path = Path(variant_path).resolve() / "model_index.json" + if not model_index_path.is_file(): + return {} + raw = json.loads(model_index_path.read_text(encoding="utf-8")) + id2label = raw.get("id2label") + if not isinstance(id2label, dict): + return {} + return {int(key): value for key, value in id2label.items()} + + @property + def id2label(self) -> Dict[int, str]: + r"""ImageNet class id to English label string (comma-separated synonyms).""" + self._ensure_labels_loaded() + return self._id2label + + def get_label_ids(self, label: Union[str, List[str]]) -> List[int]: + r""" + Map ImageNet label strings to class ids. + + Args: + label (`str` or `list[str]`): + One or more English label strings that match entries in `id2label`. + """ + self._ensure_labels_loaded() + if not self.labels: + raise ValueError("No English labels loaded. Add `id2label` to model_index.json.") + labels = [label] if isinstance(label, str) else list(label) + missing = [item for item in labels if item not in self.labels] + if missing: + preview = ", ".join(list(self.labels.keys())[:8]) + raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...") + return [self.labels[item] for item in labels] + + def _default_image_size(self) -> int: + latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64))) + return latent_size * self.vae_scale_factor + + def check_inputs( + self, + height: int, + width: int, + num_inference_steps: int, + guidance_scale: float, + output_type: str, + ) -> None: + if num_inference_steps < 1: + raise ValueError("num_inference_steps must be >= 1.") + if guidance_scale < 1.0: + raise ValueError("guidance_scale must be >= 1.0.") + if guidance_scale > 1.0 and self.gnet is None: + raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).") + if output_type not in {"pil", "np", "pt", "latent"}: + raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.") + + native_size = self._default_image_size() + if height != native_size or width != native_size: + raise ValueError( + f"EDM2 expects native resolution height=width={native_size}. " + f"Got height={height}, width={width}." + ) + + def _normalize_class_labels( + self, + class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]], + batch_size: int, + device: torch.device, + ) -> Optional[torch.Tensor]: + label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0))) + if label_dim == 0: + return None + if class_labels is None: + indices = torch.randint(label_dim, size=(batch_size,), device=device) + return torch.eye(label_dim, device=device, dtype=torch.float32)[indices] + + if isinstance(class_labels, str): + class_labels = self.get_label_ids(class_labels)[0] + elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str): + class_labels = self.get_label_ids(list(class_labels)) + + if isinstance(class_labels, int): + indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long) + elif isinstance(class_labels, torch.Tensor): + if class_labels.ndim == 2: + labels = class_labels.to(device=device, dtype=torch.float32) + if labels.shape[0] != batch_size: + raise ValueError(f"class_labels batch must match batch_size={batch_size}.") + return labels + indices = class_labels.to(device=device, dtype=torch.long).flatten() + else: + indices = torch.tensor(list(class_labels), device=device, dtype=torch.long) + + if indices.numel() == 1 and batch_size > 1: + indices = indices.repeat(batch_size) + if indices.numel() != batch_size: + raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.") + return torch.eye(label_dim, device=device, dtype=torch.float32)[indices] + + def prepare_latents( + self, + batch_size: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]], + ) -> torch.Tensor: + in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4))) + latent_size = height // self.vae_scale_factor + return randn_tensor( + (batch_size, in_channels, latent_size, latent_size), + generator=generator, + device=device, + dtype=torch.float32, + ) + + def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"): + if output_type == "latent": + return latents + + in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 3))) + if self.vae is None: + image = (latents.to(torch.float32) * 127.5 + 128).clip(0, 255) / 255.0 + return self.image_processor.postprocess(image, output_type=output_type) + + if in_channels == 4: + x = latents.to(torch.float32) + scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1) + bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1) + x = (x - bias) / scale + else: + x = latents.to(torch.float32) + + vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype + image = self.vae.decode(x.to(dtype=vae_dtype)).sample.to(torch.float32).clamp(0, 1) + + return self.image_processor.postprocess(image, output_type=output_type) + + @staticmethod + def _apply_autoguidance( + main: torch.Tensor, + ref: torch.Tensor, + guidance_scale: float, + ) -> torch.Tensor: + return ref.lerp(main, guidance_scale) + + @staticmethod + def _sample_edm2_heun( + denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + noise: torch.Tensor, + sigmas: torch.Tensor, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + progress_bar: Optional[Callable[[Iterable], Iterable]] = None, + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0).""" + x_next = noise.to(dtype) * sigmas[0] + + sigma_pairs = list(zip(sigmas[:-1], sigmas[1:])) + if progress_bar is not None: + sigma_pairs = progress_bar(sigma_pairs) + + num_steps = len(sigma_pairs) + for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs): + x_hat, sigma_hat = x_next, sigma_cur + d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat + x_next = x_hat + (sigma_next - sigma_hat) * d_cur + if i < num_steps - 1: + d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next + x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime) + return x_next + + @torch.inference_mode() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None, + batch_size: int = 1, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 32, + guidance_scale: float = 1.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: str = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Generate class-conditional images with EDM2. + + Args: + class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*): + ImageNet class indices, English label strings, or one-hot float tensors. + Random classes are sampled when omitted on conditional models. + batch_size (`int`, defaults to `1`): + Number of images to generate. + height (`int`, *optional*): + Output height in pixels. Defaults to the pretrained native resolution. + width (`int`, *optional*): + Output width in pixels. Defaults to the pretrained native resolution. + num_inference_steps (`int`, defaults to `32`): + Number of EDM2 Heun steps (NVlabs default). + guidance_scale (`float`, defaults to `1.0`): + Autoguidance strength. Values above `1.0` blend the main net with `gnet` + via `gnet_output.lerp(unet_output, guidance_scale)`. + generator (`torch.Generator`, *optional*): + RNG for reproducibility. + output_type (`str`, defaults to `"pil"`): + `"pil"`, `"np"`, `"pt"`, or `"latent"`. + return_dict (`bool`, defaults to `True`): + Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if True. + + Examples: + + """ + default_size = self._default_image_size() + height = int(height or default_size) + width = int(width or default_size) + self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type) + + device = self._execution_device + dtype = self.unet.dtype + labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device) + noise = self.prepare_latents(batch_size, height, width, dtype, device, generator) + + def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + sigma_batch = sigma.reshape(1).expand(batch_size) + main = self.unet( + sample=x, + sigma=sigma_batch, + class_labels=labels, + force_fp32=True, + ).sample + if guidance_scale == 1.0 or self.gnet is None: + return main.to(torch.float32) + ref = self.gnet( + sample=x, + sigma=sigma_batch, + class_labels=labels, + force_fp32=True, + ).sample + return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + latents = self._sample_edm2_heun( + denoise_fn=denoise_fn, + noise=noise, + sigmas=self.scheduler.sigmas.to(device), + generator=generator, + progress_bar=self.progress_bar, + dtype=torch.float32, + ) + + image = self.decode_latents(latents, output_type=output_type) + if not return_dict: + return (image, latents) + return ImagePipelineOutput(images=image) + + @classmethod + def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None): + vae_dir = os.path.join(pretrained_model_name_or_path, "vae") + if os.path.isdir(vae_dir): + try: + + return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype) + except Exception: + return None + + vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt") + if os.path.isfile(vae_hint): + with open(vae_hint, "r", encoding="utf-8") as f: + hub_id = f.read().strip() + if hub_id: + + return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype) + return None diff --git a/edm2-img512-s-fid/scheduler/scheduler_config.json b/edm2-img512-s-fid/scheduler/scheduler_config.json new file mode 100644 index 0000000000000000000000000000000000000000..08cbde97955d06d66b0b72be217db84bb797d711 --- /dev/null +++ b/edm2-img512-s-fid/scheduler/scheduler_config.json @@ -0,0 +1,11 @@ +{ + "_class_name": "EDMEulerScheduler", + "final_sigmas_type": "zero", + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "rho": 7.0, + "sigma_data": 0.5, + "sigma_max": 80.0, + "sigma_min": 0.002, + "sigma_schedule": "karras" +} diff --git a/edm2-img512-s-fid/unet/config.json b/edm2-img512-s-fid/unet/config.json new file mode 100644 index 0000000000000000000000000000000000000000..4f7eb106dcadc3dd4c2744ae18e255f6cbe420da --- /dev/null +++ b/edm2-img512-s-fid/unet/config.json @@ -0,0 +1,31 @@ +{ + "_class_name": "EDM2UNet2DModel", + "attn_balance": 0.3, + "attn_resolutions": [ + 16, + 8 + ], + "channel_mult": [ + 1, + 2, + 3, + 4 + ], + "channel_mult_emb": 4, + "channel_mult_noise": 1, + "channels_per_head": 64, + "clip_act": 256, + "concat_balance": 0.5, + "dropout": 0.0, + "in_channels": 4, + "label_balance": 0.5, + "logvar_channels": 128, + "model_channels": 192, + "num_blocks": 3, + "num_class_embeds": 1000, + "out_channels": 4, + "res_balance": 0.3, + "sample_size": 64, + "sigma_data": 0.5, + "use_fp16": true +} diff --git a/edm2-img512-s-fid/unet/diffusion_pytorch_model.safetensors b/edm2-img512-s-fid/unet/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..f7bc7bc5c9314fce180ac5f46ae971dfe08fb183 --- /dev/null +++ b/edm2-img512-s-fid/unet/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5dee937e117e2367ede680aae4edf96635ff4debb9ae73f2617111991aa83d61 +size 1120876188 diff --git a/edm2-img512-s-fid/unet/unet_edm2.py b/edm2-img512-s-fid/unet/unet_edm2.py new file mode 100644 index 0000000000000000000000000000000000000000..49d9451b443c4d6ad3b6ba2db19a7a804b07b6de --- /dev/null +++ b/edm2-img512-s-fid/unet/unet_edm2.py @@ -0,0 +1,434 @@ +import math +import json +import os +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import numpy as np +import torch + +try: + from diffusers.configuration_utils import ConfigMixin, register_to_config + from diffusers.models.modeling_utils import ModelMixin + from diffusers.utils import BaseOutput +except ImportError: # pragma: no cover + class ModelMixin(torch.nn.Module): + pass + + class ConfigMixin: + config = {} + + def register_to_config(self, **kwargs): + self.config = kwargs + + def register_to_config(func): + return func + + @dataclass + class BaseOutput: + pass + + +def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor: + if dim is None: + dim = list(range(1, x.ndim)) + norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) + norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel())) + return x / norm.to(x.dtype) + + +def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor: + if mode == "keep": + return x + filt = np.float32(f) + pad = (len(filt) - 1) // 2 + filt = filt / filt.sum() + filt = np.outer(filt, filt)[np.newaxis, np.newaxis, :, :] + filt = torch.as_tensor(filt, dtype=x.dtype, device=x.device) + c = x.shape[1] + if mode == "down": + return torch.nn.functional.conv2d(x, filt.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,)) + return torch.nn.functional.conv_transpose2d(x, (filt * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,)) + + +def mp_silu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.silu(x) / 0.596 + + +def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor: + return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2) + + +def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor: + na = a.shape[dim] + nb = b.shape[dim] + c = math.sqrt((na + nb) / ((1 - t) ** 2 + t ** 2)) + wa = c / math.sqrt(na) * (1 - t) + wb = c / math.sqrt(nb) * t + return torch.cat([wa * a, wb * b], dim=dim) + + +class MPFourier(torch.nn.Module): + def __init__(self, num_channels: int, bandwidth: float = 1): + super().__init__() + self.register_buffer("freqs", 2 * math.pi * torch.randn(num_channels) * bandwidth) + self.register_buffer("phases", 2 * math.pi * torch.rand(num_channels)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = x.to(torch.float32).ger(self.freqs.to(torch.float32)) + y = y + self.phases.to(torch.float32) + y = y.cos() * math.sqrt(2) + return y.to(x.dtype) + + +class MPConv(torch.nn.Module): + def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]): + super().__init__() + self.out_channels = out_channels + self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel)) + + def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor: + w = self.weight.to(torch.float32) + if self.training: + with torch.no_grad(): + self.weight.copy_(normalize(w)) + w = normalize(w) + w = w * (gain / math.sqrt(w[0].numel())) + w = w.to(x.dtype) + if w.ndim == 2: + return x @ w.t() + return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1] // 2,)) + + +class Block(torch.nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + emb_channels: int, + flavor: str = "enc", + resample_mode: str = "keep", + resample_filter: List[float] = [1, 1], + attention: bool = False, + channels_per_head: int = 64, + dropout: float = 0.0, + res_balance: float = 0.3, + attn_balance: float = 0.3, + clip_act: Optional[float] = 256, + ): + super().__init__() + self.out_channels = out_channels + self.flavor = flavor + self.resample_filter = resample_filter + self.resample_mode = resample_mode + self.num_heads = out_channels // channels_per_head if attention else 0 + self.dropout = dropout + self.res_balance = res_balance + self.attn_balance = attn_balance + self.clip_act = clip_act + self.emb_gain = torch.nn.Parameter(torch.zeros([])) + self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3)) + self.emb_linear = MPConv(emb_channels, out_channels, kernel=()) + self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3)) + self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None + self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None + self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + x = resample(x, f=self.resample_filter, mode=self.resample_mode) + if self.flavor == "enc": + if self.conv_skip is not None: + x = self.conv_skip(x) + x = normalize(x, dim=[1]) + + y = self.conv_res0(mp_silu(x)) + c = self.emb_linear(emb, gain=self.emb_gain) + 1 + y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype)) + if self.training and self.dropout: + y = torch.nn.functional.dropout(y, p=self.dropout) + y = self.conv_res1(y) + + if self.flavor == "dec" and self.conv_skip is not None: + x = self.conv_skip(x) + x = mp_sum(x, y, t=self.res_balance) + + if self.num_heads: + y = self.attn_qkv(x) + y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3]) + q, k, v = normalize(y, dim=[2]).unbind(3) + w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3) + y = torch.einsum("nhqk,nhck->nhcq", w, v) + y = self.attn_proj(y.reshape(*x.shape)) + x = mp_sum(x, y, t=self.attn_balance) + + if self.clip_act is not None: + x = x.clip_(-self.clip_act, self.clip_act) + return x + + +class EDM2UNet(torch.nn.Module): + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int, + model_channels: int = 192, + channel_mult: Tuple[int, ...] = (1, 2, 3, 4), + channel_mult_noise: Optional[int] = None, + channel_mult_emb: Optional[int] = None, + num_blocks: int = 3, + attn_resolutions: Tuple[int, ...] = (16, 8), + label_balance: float = 0.5, + concat_balance: float = 0.5, + **block_kwargs, + ): + super().__init__() + cblock = [model_channels * x for x in channel_mult] + cnoise = model_channels * channel_mult_noise if channel_mult_noise is not None else cblock[0] + cemb = model_channels * channel_mult_emb if channel_mult_emb is not None else max(cblock) + self.label_balance = label_balance + self.concat_balance = concat_balance + self.out_gain = torch.nn.Parameter(torch.zeros([])) + + self.emb_fourier = MPFourier(cnoise) + self.emb_noise = MPConv(cnoise, cemb, kernel=()) + self.emb_label = MPConv(label_dim, cemb, kernel=()) if label_dim else None + + self.enc = torch.nn.ModuleDict() + cout = img_channels + 1 + for level, channels in enumerate(cblock): + res = img_resolution >> level + if level == 0: + cin = cout + cout = channels + self.enc[f"{res}x{res}_conv"] = MPConv(cin, cout, kernel=(3, 3)) + else: + self.enc[f"{res}x{res}_down"] = Block(cout, cout, cemb, flavor="enc", resample_mode="down", **block_kwargs) + for idx in range(num_blocks): + cin = cout + cout = channels + self.enc[f"{res}x{res}_block{idx}"] = Block( + cin, + cout, + cemb, + flavor="enc", + attention=(res in attn_resolutions), + **block_kwargs, + ) + + self.dec = torch.nn.ModuleDict() + skips = [block.out_channels for block in self.enc.values()] + for level, channels in reversed(list(enumerate(cblock))): + res = img_resolution >> level + if level == len(cblock) - 1: + self.dec[f"{res}x{res}_in0"] = Block(cout, cout, cemb, flavor="dec", attention=True, **block_kwargs) + self.dec[f"{res}x{res}_in1"] = Block(cout, cout, cemb, flavor="dec", **block_kwargs) + else: + self.dec[f"{res}x{res}_up"] = Block(cout, cout, cemb, flavor="dec", resample_mode="up", **block_kwargs) + for idx in range(num_blocks + 1): + cin = cout + skips.pop() + cout = channels + self.dec[f"{res}x{res}_block{idx}"] = Block( + cin, + cout, + cemb, + flavor="dec", + attention=(res in attn_resolutions), + **block_kwargs, + ) + + self.out_conv = MPConv(cout, img_channels, kernel=(3, 3)) + + def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor: + emb = self.emb_noise(self.emb_fourier(noise_labels)) + if self.emb_label is not None: + if class_labels is None: + raise ValueError("class_labels are required for conditional EDM2UNet.") + emb = mp_sum(emb, self.emb_label(class_labels * math.sqrt(class_labels.shape[1])), t=self.label_balance) + emb = mp_silu(emb) + + x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1) + skips = [] + for name, block in self.enc.items(): + x = block(x) if "conv" in name else block(x, emb) + skips.append(x) + + for name, block in self.dec.items(): + if "block" in name: + x = mp_cat(x, skips.pop(), t=self.concat_balance) + x = block(x, emb) + return self.out_conv(x, gain=self.out_gain) + + +@dataclass +class EDM2UNet2DOutput(BaseOutput): + sample: torch.Tensor + logvar: Optional[torch.Tensor] = None + + + +_CONFIG_KEYS = ( + "sample_size", + "in_channels", + "out_channels", + "num_class_embeds", + "use_fp16", + "sigma_data", + "logvar_channels", + "model_channels", + "channel_mult", + "channel_mult_noise", + "channel_mult_emb", + "num_blocks", + "attn_resolutions", + "label_balance", + "concat_balance", + "dropout", + "channels_per_head", + "res_balance", + "attn_balance", + "clip_act", +) + + +class EDM2UNet2DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + sample_size: int = 64, + in_channels: int = 4, + out_channels: int = 4, + num_class_embeds: int = 0, + use_fp16: bool = True, + sigma_data: float = 0.5, + logvar_channels: int = 128, + model_channels: int = 192, + channel_mult: Tuple[int, ...] = (1, 2, 3, 4), + channel_mult_noise: Optional[int] = None, + channel_mult_emb: Optional[int] = None, + num_blocks: int = 3, + attn_resolutions: Tuple[int, ...] = (16, 8), + label_balance: float = 0.5, + concat_balance: float = 0.5, + dropout: float = 0.0, + channels_per_head: int = 64, + res_balance: float = 0.3, + attn_balance: float = 0.3, + clip_act: Optional[float] = 256, + ): + super().__init__() + self.sample_size = sample_size + self.in_channels = in_channels + self.out_channels = out_channels + self.num_class_embeds = num_class_embeds + self.use_fp16 = use_fp16 + self.sigma_data = sigma_data + self.model_channels = model_channels + self.channel_mult = channel_mult + self.channel_mult_noise = channel_mult_noise + self.channel_mult_emb = channel_mult_emb + self.num_blocks = num_blocks + self.attn_resolutions = attn_resolutions + self.label_balance = label_balance + self.concat_balance = concat_balance + self.dropout = dropout + self.channels_per_head = channels_per_head + self.res_balance = res_balance + self.attn_balance = attn_balance + self.clip_act = clip_act + self.unet = EDM2UNet( + img_resolution=sample_size, + img_channels=in_channels, + label_dim=num_class_embeds, + model_channels=model_channels, + channel_mult=channel_mult, + channel_mult_noise=channel_mult_noise, + channel_mult_emb=channel_mult_emb, + num_blocks=num_blocks, + attn_resolutions=attn_resolutions, + label_balance=label_balance, + concat_balance=concat_balance, + dropout=dropout, + channels_per_head=channels_per_head, + res_balance=res_balance, + attn_balance=attn_balance, + clip_act=clip_act, + ) + self.logvar_fourier = MPFourier(logvar_channels) + self.logvar_linear = MPConv(logvar_channels, 1, kernel=()) + + def forward( + self, + sample: torch.Tensor, + sigma: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + force_fp32: bool = False, + return_logvar: bool = False, + return_dict: bool = True, + ) -> EDM2UNet2DOutput: + x = sample.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + if self.num_class_embeds == 0: + class_labels = None + else: + if class_labels is None: + class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device) + class_labels = class_labels.to(torch.float32).reshape(-1, self.num_class_embeds) + dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") else torch.float32 + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.flatten().log() / 4 + + x_in = (c_in * x).to(dtype) + f_x = self.unet(x_in, c_noise, class_labels) + d_x = c_skip * x + c_out * f_x.to(torch.float32) + + logvar = None + if return_logvar: + logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1) + + if not return_dict: + return (d_x, logvar) + return EDM2UNet2DOutput(sample=d_x, logvar=logvar) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs): + subfolder = kwargs.pop("subfolder", None) + model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path + with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f: + config = json.load(f) + init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS} + model = cls(**init_kwargs) + weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors") + if os.path.isfile(weight_file): + from safetensors.torch import load_file + + state_dict = load_file(weight_file) + else: + state_dict = torch.load(os.path.join(model_dir, "diffusion_pytorch_model.bin"), map_location="cpu") + model.load_state_dict(state_dict, strict=True) + if torch_dtype is not None: + model = model.to(dtype=torch_dtype) + return model + + def save_pretrained(self, save_directory: str, safe_serialization: bool = True): + os.makedirs(save_directory, exist_ok=True) + stored = dict(getattr(self, "config", {})) + config = {"_class_name": self.__class__.__name__} + for key in _CONFIG_KEYS: + if key in stored: + config[key] = stored[key] + elif hasattr(self, key): + config[key] = getattr(self, key) + with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f: + json.dump(config, f, indent=2, sort_keys=True) + f.write("\n") + state_dict = self.state_dict() + if safe_serialization: + from safetensors.torch import save_file + + save_file(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.safetensors")) + else: + torch.save(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.bin")) diff --git a/edm2-img512-s-fid/vae/config.json b/edm2-img512-s-fid/vae/config.json new file mode 100644 index 0000000000000000000000000000000000000000..ae2ab37de2acb8bc348d3de809ae8385324be962 --- /dev/null +++ b/edm2-img512-s-fid/vae/config.json @@ -0,0 +1,38 @@ +{ + "_class_name": "AutoencoderKL", + "_diffusers_version": "0.36.0", + "_name_or_path": "stabilityai/sd-vae-ft-mse", + "act_fn": "silu", + "block_out_channels": [ + 128, + 256, + 512, + 512 + ], + "down_block_types": [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D" + ], + "force_upcast": true, + "in_channels": 3, + "latent_channels": 4, + "latents_mean": null, + "latents_std": null, + "layers_per_block": 2, + "mid_block_add_attention": true, + "norm_num_groups": 32, + "out_channels": 3, + "sample_size": 256, + "scaling_factor": 0.18215, + "shift_factor": null, + "up_block_types": [ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D" + ], + "use_post_quant_conv": true, + "use_quant_conv": true +} diff --git a/edm2-img512-s-fid/vae/diffusion_pytorch_model.safetensors b/edm2-img512-s-fid/vae/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..90464d67ac7303d0ee4696334df13da130a948ea --- /dev/null +++ b/edm2-img512-s-fid/vae/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815 +size 334643276 diff --git a/edm2-img512-xl-fid/demo.png b/edm2-img512-xl-fid/demo.png new file mode 100644 index 0000000000000000000000000000000000000000..e9210202336f46b96bc0b80245451fde9342e1e9 --- /dev/null +++ b/edm2-img512-xl-fid/demo.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:551c91feb88ea0279f61d52c20463da670f01f99e37467a6f358b699f33cd526 +size 369559 diff --git a/edm2-img512-xl-fid/model_index.json b/edm2-img512-xl-fid/model_index.json new file mode 100644 index 0000000000000000000000000000000000000000..24c0305eea3ec64abbc2f870476afa4f5e462f42 --- /dev/null +++ b/edm2-img512-xl-fid/model_index.json @@ -0,0 +1,19 @@ +{ + "_class_name": [ + "pipeline", + "EDM2Pipeline" + ], + "_diffusers_version": "0.31.0", + "scheduler": [ + "diffusers", + "EDMEulerScheduler" + ], + "unet": [ + "unet_edm2", + "EDM2UNet2DModel" + ], + "vae": [ + "diffusers", + "AutoencoderKL" + ] +} diff --git a/edm2-img512-xl-fid/pipeline.py b/edm2-img512-xl-fid/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..6b4131215d07770c74efa0674d00f3d6423d9c06 --- /dev/null +++ b/edm2-img512-xl-fid/pipeline.py @@ -0,0 +1,406 @@ +"""Hub custom pipeline: EDM2Pipeline. +Load with native Hugging Face diffusers and trust_remote_code=True. +""" + +from __future__ import annotations + +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from pathlib import Path +from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch + +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers.utils import replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from pathlib import Path + >>> import torch + >>> from diffusers import DiffusionPipeline + + >>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve() + >>> pipe = DiffusionPipeline.from_pretrained( + ... str(model_dir), + ... local_files_only=True, + ... custom_pipeline=str(model_dir / "pipeline.py"), + ... trust_remote_code=True, + ... torch_dtype=torch.float32, + ... ) + >>> pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(42) + >>> image = pipe( + ... class_labels=207, + ... num_inference_steps=32, + ... guidance_scale=1.0, + ... generator=generator, + ... ).images[0] + >>> image.save("demo.png") + ``` +""" + +# Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py). +_STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28]) +_STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE + +class EDM2Pipeline(DiffusionPipeline): + r""" + Pipeline for class-conditional image generation with EDM2 + ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)). + + Parameters: + unet ([`EDM2UNet2DModel`]): + Main magnitude-preserving U-Net with EDM preconditioning. + scheduler ([`EDMEulerScheduler`]): + Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in + the pipeline because the UNet returns denoised latents rather than noise predictions. + vae ([`AutoencoderKL`], *optional*): + Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`. + gnet ([`EDM2UNet2DModel`], *optional*): + Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`). + id2label (`dict[int, str]`, *optional*): + ImageNet class id to English label mapping. + """ + + model_cpu_offload_seq = "unet->gnet->vae" + _optional_components = ["vae", "gnet"] + + def __init__( + self, + unet, + scheduler, + vae=None, + gnet=None, + id2label: Optional[Dict[Union[int, str], str]] = None, + ) -> None: + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet) + self._id2label = self._normalize_id2label(id2label) + self.labels = self._build_label2id(self._id2label) + self._labels_loaded_from_model_index = bool(self._id2label) + self.vae_scale_factor = 8 if self.vae is not None else 1 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) + + @staticmethod + def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]: + if not id2label: + return {} + return {int(key): value for key, value in id2label.items()} + + @staticmethod + def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]: + label2id: Dict[str, int] = {} + for class_id, value in id2label.items(): + for synonym in value.split(","): + synonym = synonym.strip() + if synonym: + label2id[synonym] = int(class_id) + return dict(sorted(label2id.items())) + + def _ensure_labels_loaded(self) -> None: + if self._labels_loaded_from_model_index: + return + loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None)) + if loaded: + self._id2label = loaded + self.labels = self._build_label2id(self._id2label) + self._labels_loaded_from_model_index = True + + @staticmethod + def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]: + if not variant_path: + return {} + model_index_path = Path(variant_path).resolve() / "model_index.json" + if not model_index_path.is_file(): + return {} + raw = json.loads(model_index_path.read_text(encoding="utf-8")) + id2label = raw.get("id2label") + if not isinstance(id2label, dict): + return {} + return {int(key): value for key, value in id2label.items()} + + @property + def id2label(self) -> Dict[int, str]: + r"""ImageNet class id to English label string (comma-separated synonyms).""" + self._ensure_labels_loaded() + return self._id2label + + def get_label_ids(self, label: Union[str, List[str]]) -> List[int]: + r""" + Map ImageNet label strings to class ids. + + Args: + label (`str` or `list[str]`): + One or more English label strings that match entries in `id2label`. + """ + self._ensure_labels_loaded() + if not self.labels: + raise ValueError("No English labels loaded. Add `id2label` to model_index.json.") + labels = [label] if isinstance(label, str) else list(label) + missing = [item for item in labels if item not in self.labels] + if missing: + preview = ", ".join(list(self.labels.keys())[:8]) + raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...") + return [self.labels[item] for item in labels] + + def _default_image_size(self) -> int: + latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64))) + return latent_size * self.vae_scale_factor + + def check_inputs( + self, + height: int, + width: int, + num_inference_steps: int, + guidance_scale: float, + output_type: str, + ) -> None: + if num_inference_steps < 1: + raise ValueError("num_inference_steps must be >= 1.") + if guidance_scale < 1.0: + raise ValueError("guidance_scale must be >= 1.0.") + if guidance_scale > 1.0 and self.gnet is None: + raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).") + if output_type not in {"pil", "np", "pt", "latent"}: + raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.") + + native_size = self._default_image_size() + if height != native_size or width != native_size: + raise ValueError( + f"EDM2 expects native resolution height=width={native_size}. " + f"Got height={height}, width={width}." + ) + + def _normalize_class_labels( + self, + class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]], + batch_size: int, + device: torch.device, + ) -> Optional[torch.Tensor]: + label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0))) + if label_dim == 0: + return None + if class_labels is None: + indices = torch.randint(label_dim, size=(batch_size,), device=device) + return torch.eye(label_dim, device=device, dtype=torch.float32)[indices] + + if isinstance(class_labels, str): + class_labels = self.get_label_ids(class_labels)[0] + elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str): + class_labels = self.get_label_ids(list(class_labels)) + + if isinstance(class_labels, int): + indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long) + elif isinstance(class_labels, torch.Tensor): + if class_labels.ndim == 2: + labels = class_labels.to(device=device, dtype=torch.float32) + if labels.shape[0] != batch_size: + raise ValueError(f"class_labels batch must match batch_size={batch_size}.") + return labels + indices = class_labels.to(device=device, dtype=torch.long).flatten() + else: + indices = torch.tensor(list(class_labels), device=device, dtype=torch.long) + + if indices.numel() == 1 and batch_size > 1: + indices = indices.repeat(batch_size) + if indices.numel() != batch_size: + raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.") + return torch.eye(label_dim, device=device, dtype=torch.float32)[indices] + + def prepare_latents( + self, + batch_size: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]], + ) -> torch.Tensor: + in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4))) + latent_size = height // self.vae_scale_factor + return randn_tensor( + (batch_size, in_channels, latent_size, latent_size), + generator=generator, + device=device, + dtype=torch.float32, + ) + + def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"): + if output_type == "latent": + return latents + + in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 3))) + if self.vae is None: + image = (latents.to(torch.float32) * 127.5 + 128).clip(0, 255) / 255.0 + return self.image_processor.postprocess(image, output_type=output_type) + + if in_channels == 4: + x = latents.to(torch.float32) + scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1) + bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1) + x = (x - bias) / scale + else: + x = latents.to(torch.float32) + + vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype + image = self.vae.decode(x.to(dtype=vae_dtype)).sample.to(torch.float32).clamp(0, 1) + + return self.image_processor.postprocess(image, output_type=output_type) + + @staticmethod + def _apply_autoguidance( + main: torch.Tensor, + ref: torch.Tensor, + guidance_scale: float, + ) -> torch.Tensor: + return ref.lerp(main, guidance_scale) + + @staticmethod + def _sample_edm2_heun( + denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + noise: torch.Tensor, + sigmas: torch.Tensor, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + progress_bar: Optional[Callable[[Iterable], Iterable]] = None, + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0).""" + x_next = noise.to(dtype) * sigmas[0] + + sigma_pairs = list(zip(sigmas[:-1], sigmas[1:])) + if progress_bar is not None: + sigma_pairs = progress_bar(sigma_pairs) + + num_steps = len(sigma_pairs) + for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs): + x_hat, sigma_hat = x_next, sigma_cur + d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat + x_next = x_hat + (sigma_next - sigma_hat) * d_cur + if i < num_steps - 1: + d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next + x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime) + return x_next + + @torch.inference_mode() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None, + batch_size: int = 1, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 32, + guidance_scale: float = 1.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: str = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Generate class-conditional images with EDM2. + + Args: + class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*): + ImageNet class indices, English label strings, or one-hot float tensors. + Random classes are sampled when omitted on conditional models. + batch_size (`int`, defaults to `1`): + Number of images to generate. + height (`int`, *optional*): + Output height in pixels. Defaults to the pretrained native resolution. + width (`int`, *optional*): + Output width in pixels. Defaults to the pretrained native resolution. + num_inference_steps (`int`, defaults to `32`): + Number of EDM2 Heun steps (NVlabs default). + guidance_scale (`float`, defaults to `1.0`): + Autoguidance strength. Values above `1.0` blend the main net with `gnet` + via `gnet_output.lerp(unet_output, guidance_scale)`. + generator (`torch.Generator`, *optional*): + RNG for reproducibility. + output_type (`str`, defaults to `"pil"`): + `"pil"`, `"np"`, `"pt"`, or `"latent"`. + return_dict (`bool`, defaults to `True`): + Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if True. + + Examples: + + """ + default_size = self._default_image_size() + height = int(height or default_size) + width = int(width or default_size) + self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type) + + device = self._execution_device + dtype = self.unet.dtype + labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device) + noise = self.prepare_latents(batch_size, height, width, dtype, device, generator) + + def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + sigma_batch = sigma.reshape(1).expand(batch_size) + main = self.unet( + sample=x, + sigma=sigma_batch, + class_labels=labels, + force_fp32=True, + ).sample + if guidance_scale == 1.0 or self.gnet is None: + return main.to(torch.float32) + ref = self.gnet( + sample=x, + sigma=sigma_batch, + class_labels=labels, + force_fp32=True, + ).sample + return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + latents = self._sample_edm2_heun( + denoise_fn=denoise_fn, + noise=noise, + sigmas=self.scheduler.sigmas.to(device), + generator=generator, + progress_bar=self.progress_bar, + dtype=torch.float32, + ) + + image = self.decode_latents(latents, output_type=output_type) + if not return_dict: + return (image, latents) + return ImagePipelineOutput(images=image) + + @classmethod + def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None): + vae_dir = os.path.join(pretrained_model_name_or_path, "vae") + if os.path.isdir(vae_dir): + try: + + return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype) + except Exception: + return None + + vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt") + if os.path.isfile(vae_hint): + with open(vae_hint, "r", encoding="utf-8") as f: + hub_id = f.read().strip() + if hub_id: + + return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype) + return None diff --git a/edm2-img512-xl-fid/scheduler/scheduler_config.json b/edm2-img512-xl-fid/scheduler/scheduler_config.json new file mode 100644 index 0000000000000000000000000000000000000000..08cbde97955d06d66b0b72be217db84bb797d711 --- /dev/null +++ b/edm2-img512-xl-fid/scheduler/scheduler_config.json @@ -0,0 +1,11 @@ +{ + "_class_name": "EDMEulerScheduler", + "final_sigmas_type": "zero", + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "rho": 7.0, + "sigma_data": 0.5, + "sigma_max": 80.0, + "sigma_min": 0.002, + "sigma_schedule": "karras" +} diff --git a/edm2-img512-xl-fid/unet/config.json b/edm2-img512-xl-fid/unet/config.json new file mode 100644 index 0000000000000000000000000000000000000000..c341f808b68df99197f84d725dea6b36328fdb5c --- /dev/null +++ b/edm2-img512-xl-fid/unet/config.json @@ -0,0 +1,31 @@ +{ + "_class_name": "EDM2UNet2DModel", + "attn_balance": 0.3, + "attn_resolutions": [ + 16, + 8 + ], + "channel_mult": [ + 1, + 2, + 3, + 4 + ], + "channel_mult_emb": 4, + "channel_mult_noise": 1, + "channels_per_head": 64, + "clip_act": 256, + "concat_balance": 0.5, + "dropout": 0.0, + "in_channels": 4, + "label_balance": 0.5, + "logvar_channels": 128, + "model_channels": 384, + "num_blocks": 3, + "num_class_embeds": 1000, + "out_channels": 4, + "res_balance": 0.3, + "sample_size": 64, + "sigma_data": 0.5, + "use_fp16": true +} diff --git a/edm2-img512-xl-fid/unet/diffusion_pytorch_model.safetensors b/edm2-img512-xl-fid/unet/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..c3d4d87fa62118756846646e93146ae8824f5c93 --- /dev/null +++ b/edm2-img512-xl-fid/unet/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3c7402d8a4e91781b5c94fa2a5beee5820970ad99d2249141e191364885f222a +size 4477161892 diff --git a/edm2-img512-xl-fid/unet/unet_edm2.py b/edm2-img512-xl-fid/unet/unet_edm2.py new file mode 100644 index 0000000000000000000000000000000000000000..49d9451b443c4d6ad3b6ba2db19a7a804b07b6de --- /dev/null +++ b/edm2-img512-xl-fid/unet/unet_edm2.py @@ -0,0 +1,434 @@ +import math +import json +import os +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import numpy as np +import torch + +try: + from diffusers.configuration_utils import ConfigMixin, register_to_config + from diffusers.models.modeling_utils import ModelMixin + from diffusers.utils import BaseOutput +except ImportError: # pragma: no cover + class ModelMixin(torch.nn.Module): + pass + + class ConfigMixin: + config = {} + + def register_to_config(self, **kwargs): + self.config = kwargs + + def register_to_config(func): + return func + + @dataclass + class BaseOutput: + pass + + +def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor: + if dim is None: + dim = list(range(1, x.ndim)) + norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) + norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel())) + return x / norm.to(x.dtype) + + +def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor: + if mode == "keep": + return x + filt = np.float32(f) + pad = (len(filt) - 1) // 2 + filt = filt / filt.sum() + filt = np.outer(filt, filt)[np.newaxis, np.newaxis, :, :] + filt = torch.as_tensor(filt, dtype=x.dtype, device=x.device) + c = x.shape[1] + if mode == "down": + return torch.nn.functional.conv2d(x, filt.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,)) + return torch.nn.functional.conv_transpose2d(x, (filt * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,)) + + +def mp_silu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.silu(x) / 0.596 + + +def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor: + return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2) + + +def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor: + na = a.shape[dim] + nb = b.shape[dim] + c = math.sqrt((na + nb) / ((1 - t) ** 2 + t ** 2)) + wa = c / math.sqrt(na) * (1 - t) + wb = c / math.sqrt(nb) * t + return torch.cat([wa * a, wb * b], dim=dim) + + +class MPFourier(torch.nn.Module): + def __init__(self, num_channels: int, bandwidth: float = 1): + super().__init__() + self.register_buffer("freqs", 2 * math.pi * torch.randn(num_channels) * bandwidth) + self.register_buffer("phases", 2 * math.pi * torch.rand(num_channels)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = x.to(torch.float32).ger(self.freqs.to(torch.float32)) + y = y + self.phases.to(torch.float32) + y = y.cos() * math.sqrt(2) + return y.to(x.dtype) + + +class MPConv(torch.nn.Module): + def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]): + super().__init__() + self.out_channels = out_channels + self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel)) + + def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor: + w = self.weight.to(torch.float32) + if self.training: + with torch.no_grad(): + self.weight.copy_(normalize(w)) + w = normalize(w) + w = w * (gain / math.sqrt(w[0].numel())) + w = w.to(x.dtype) + if w.ndim == 2: + return x @ w.t() + return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1] // 2,)) + + +class Block(torch.nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + emb_channels: int, + flavor: str = "enc", + resample_mode: str = "keep", + resample_filter: List[float] = [1, 1], + attention: bool = False, + channels_per_head: int = 64, + dropout: float = 0.0, + res_balance: float = 0.3, + attn_balance: float = 0.3, + clip_act: Optional[float] = 256, + ): + super().__init__() + self.out_channels = out_channels + self.flavor = flavor + self.resample_filter = resample_filter + self.resample_mode = resample_mode + self.num_heads = out_channels // channels_per_head if attention else 0 + self.dropout = dropout + self.res_balance = res_balance + self.attn_balance = attn_balance + self.clip_act = clip_act + self.emb_gain = torch.nn.Parameter(torch.zeros([])) + self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3)) + self.emb_linear = MPConv(emb_channels, out_channels, kernel=()) + self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3)) + self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None + self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None + self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + x = resample(x, f=self.resample_filter, mode=self.resample_mode) + if self.flavor == "enc": + if self.conv_skip is not None: + x = self.conv_skip(x) + x = normalize(x, dim=[1]) + + y = self.conv_res0(mp_silu(x)) + c = self.emb_linear(emb, gain=self.emb_gain) + 1 + y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype)) + if self.training and self.dropout: + y = torch.nn.functional.dropout(y, p=self.dropout) + y = self.conv_res1(y) + + if self.flavor == "dec" and self.conv_skip is not None: + x = self.conv_skip(x) + x = mp_sum(x, y, t=self.res_balance) + + if self.num_heads: + y = self.attn_qkv(x) + y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3]) + q, k, v = normalize(y, dim=[2]).unbind(3) + w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3) + y = torch.einsum("nhqk,nhck->nhcq", w, v) + y = self.attn_proj(y.reshape(*x.shape)) + x = mp_sum(x, y, t=self.attn_balance) + + if self.clip_act is not None: + x = x.clip_(-self.clip_act, self.clip_act) + return x + + +class EDM2UNet(torch.nn.Module): + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int, + model_channels: int = 192, + channel_mult: Tuple[int, ...] = (1, 2, 3, 4), + channel_mult_noise: Optional[int] = None, + channel_mult_emb: Optional[int] = None, + num_blocks: int = 3, + attn_resolutions: Tuple[int, ...] = (16, 8), + label_balance: float = 0.5, + concat_balance: float = 0.5, + **block_kwargs, + ): + super().__init__() + cblock = [model_channels * x for x in channel_mult] + cnoise = model_channels * channel_mult_noise if channel_mult_noise is not None else cblock[0] + cemb = model_channels * channel_mult_emb if channel_mult_emb is not None else max(cblock) + self.label_balance = label_balance + self.concat_balance = concat_balance + self.out_gain = torch.nn.Parameter(torch.zeros([])) + + self.emb_fourier = MPFourier(cnoise) + self.emb_noise = MPConv(cnoise, cemb, kernel=()) + self.emb_label = MPConv(label_dim, cemb, kernel=()) if label_dim else None + + self.enc = torch.nn.ModuleDict() + cout = img_channels + 1 + for level, channels in enumerate(cblock): + res = img_resolution >> level + if level == 0: + cin = cout + cout = channels + self.enc[f"{res}x{res}_conv"] = MPConv(cin, cout, kernel=(3, 3)) + else: + self.enc[f"{res}x{res}_down"] = Block(cout, cout, cemb, flavor="enc", resample_mode="down", **block_kwargs) + for idx in range(num_blocks): + cin = cout + cout = channels + self.enc[f"{res}x{res}_block{idx}"] = Block( + cin, + cout, + cemb, + flavor="enc", + attention=(res in attn_resolutions), + **block_kwargs, + ) + + self.dec = torch.nn.ModuleDict() + skips = [block.out_channels for block in self.enc.values()] + for level, channels in reversed(list(enumerate(cblock))): + res = img_resolution >> level + if level == len(cblock) - 1: + self.dec[f"{res}x{res}_in0"] = Block(cout, cout, cemb, flavor="dec", attention=True, **block_kwargs) + self.dec[f"{res}x{res}_in1"] = Block(cout, cout, cemb, flavor="dec", **block_kwargs) + else: + self.dec[f"{res}x{res}_up"] = Block(cout, cout, cemb, flavor="dec", resample_mode="up", **block_kwargs) + for idx in range(num_blocks + 1): + cin = cout + skips.pop() + cout = channels + self.dec[f"{res}x{res}_block{idx}"] = Block( + cin, + cout, + cemb, + flavor="dec", + attention=(res in attn_resolutions), + **block_kwargs, + ) + + self.out_conv = MPConv(cout, img_channels, kernel=(3, 3)) + + def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor: + emb = self.emb_noise(self.emb_fourier(noise_labels)) + if self.emb_label is not None: + if class_labels is None: + raise ValueError("class_labels are required for conditional EDM2UNet.") + emb = mp_sum(emb, self.emb_label(class_labels * math.sqrt(class_labels.shape[1])), t=self.label_balance) + emb = mp_silu(emb) + + x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1) + skips = [] + for name, block in self.enc.items(): + x = block(x) if "conv" in name else block(x, emb) + skips.append(x) + + for name, block in self.dec.items(): + if "block" in name: + x = mp_cat(x, skips.pop(), t=self.concat_balance) + x = block(x, emb) + return self.out_conv(x, gain=self.out_gain) + + +@dataclass +class EDM2UNet2DOutput(BaseOutput): + sample: torch.Tensor + logvar: Optional[torch.Tensor] = None + + + +_CONFIG_KEYS = ( + "sample_size", + "in_channels", + "out_channels", + "num_class_embeds", + "use_fp16", + "sigma_data", + "logvar_channels", + "model_channels", + "channel_mult", + "channel_mult_noise", + "channel_mult_emb", + "num_blocks", + "attn_resolutions", + "label_balance", + "concat_balance", + "dropout", + "channels_per_head", + "res_balance", + "attn_balance", + "clip_act", +) + + +class EDM2UNet2DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + sample_size: int = 64, + in_channels: int = 4, + out_channels: int = 4, + num_class_embeds: int = 0, + use_fp16: bool = True, + sigma_data: float = 0.5, + logvar_channels: int = 128, + model_channels: int = 192, + channel_mult: Tuple[int, ...] = (1, 2, 3, 4), + channel_mult_noise: Optional[int] = None, + channel_mult_emb: Optional[int] = None, + num_blocks: int = 3, + attn_resolutions: Tuple[int, ...] = (16, 8), + label_balance: float = 0.5, + concat_balance: float = 0.5, + dropout: float = 0.0, + channels_per_head: int = 64, + res_balance: float = 0.3, + attn_balance: float = 0.3, + clip_act: Optional[float] = 256, + ): + super().__init__() + self.sample_size = sample_size + self.in_channels = in_channels + self.out_channels = out_channels + self.num_class_embeds = num_class_embeds + self.use_fp16 = use_fp16 + self.sigma_data = sigma_data + self.model_channels = model_channels + self.channel_mult = channel_mult + self.channel_mult_noise = channel_mult_noise + self.channel_mult_emb = channel_mult_emb + self.num_blocks = num_blocks + self.attn_resolutions = attn_resolutions + self.label_balance = label_balance + self.concat_balance = concat_balance + self.dropout = dropout + self.channels_per_head = channels_per_head + self.res_balance = res_balance + self.attn_balance = attn_balance + self.clip_act = clip_act + self.unet = EDM2UNet( + img_resolution=sample_size, + img_channels=in_channels, + label_dim=num_class_embeds, + model_channels=model_channels, + channel_mult=channel_mult, + channel_mult_noise=channel_mult_noise, + channel_mult_emb=channel_mult_emb, + num_blocks=num_blocks, + attn_resolutions=attn_resolutions, + label_balance=label_balance, + concat_balance=concat_balance, + dropout=dropout, + channels_per_head=channels_per_head, + res_balance=res_balance, + attn_balance=attn_balance, + clip_act=clip_act, + ) + self.logvar_fourier = MPFourier(logvar_channels) + self.logvar_linear = MPConv(logvar_channels, 1, kernel=()) + + def forward( + self, + sample: torch.Tensor, + sigma: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + force_fp32: bool = False, + return_logvar: bool = False, + return_dict: bool = True, + ) -> EDM2UNet2DOutput: + x = sample.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + if self.num_class_embeds == 0: + class_labels = None + else: + if class_labels is None: + class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device) + class_labels = class_labels.to(torch.float32).reshape(-1, self.num_class_embeds) + dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") else torch.float32 + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.flatten().log() / 4 + + x_in = (c_in * x).to(dtype) + f_x = self.unet(x_in, c_noise, class_labels) + d_x = c_skip * x + c_out * f_x.to(torch.float32) + + logvar = None + if return_logvar: + logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1) + + if not return_dict: + return (d_x, logvar) + return EDM2UNet2DOutput(sample=d_x, logvar=logvar) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs): + subfolder = kwargs.pop("subfolder", None) + model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path + with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f: + config = json.load(f) + init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS} + model = cls(**init_kwargs) + weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors") + if os.path.isfile(weight_file): + from safetensors.torch import load_file + + state_dict = load_file(weight_file) + else: + state_dict = torch.load(os.path.join(model_dir, "diffusion_pytorch_model.bin"), map_location="cpu") + model.load_state_dict(state_dict, strict=True) + if torch_dtype is not None: + model = model.to(dtype=torch_dtype) + return model + + def save_pretrained(self, save_directory: str, safe_serialization: bool = True): + os.makedirs(save_directory, exist_ok=True) + stored = dict(getattr(self, "config", {})) + config = {"_class_name": self.__class__.__name__} + for key in _CONFIG_KEYS: + if key in stored: + config[key] = stored[key] + elif hasattr(self, key): + config[key] = getattr(self, key) + with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f: + json.dump(config, f, indent=2, sort_keys=True) + f.write("\n") + state_dict = self.state_dict() + if safe_serialization: + from safetensors.torch import save_file + + save_file(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.safetensors")) + else: + torch.save(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.bin")) diff --git a/edm2-img512-xl-fid/vae/config.json b/edm2-img512-xl-fid/vae/config.json new file mode 100644 index 0000000000000000000000000000000000000000..ae2ab37de2acb8bc348d3de809ae8385324be962 --- /dev/null +++ b/edm2-img512-xl-fid/vae/config.json @@ -0,0 +1,38 @@ +{ + "_class_name": "AutoencoderKL", + "_diffusers_version": "0.36.0", + "_name_or_path": "stabilityai/sd-vae-ft-mse", + "act_fn": "silu", + "block_out_channels": [ + 128, + 256, + 512, + 512 + ], + "down_block_types": [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D" + ], + "force_upcast": true, + "in_channels": 3, + "latent_channels": 4, + "latents_mean": null, + "latents_std": null, + "layers_per_block": 2, + "mid_block_add_attention": true, + "norm_num_groups": 32, + "out_channels": 3, + "sample_size": 256, + "scaling_factor": 0.18215, + "shift_factor": null, + "up_block_types": [ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D" + ], + "use_post_quant_conv": true, + "use_quant_conv": true +} diff --git a/edm2-img512-xl-fid/vae/diffusion_pytorch_model.safetensors b/edm2-img512-xl-fid/vae/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..90464d67ac7303d0ee4696334df13da130a948ea --- /dev/null +++ b/edm2-img512-xl-fid/vae/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815 +size 334643276 diff --git a/edm2-img512-xs-fid/demo.png b/edm2-img512-xs-fid/demo.png new file mode 100644 index 0000000000000000000000000000000000000000..5f4b05769047e63d94737d04b4667949a76253ef --- /dev/null +++ b/edm2-img512-xs-fid/demo.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a5ceee02aab56e93c77b73e082ca5f952897a2bd98c1b78c1899f78845561785 +size 375611 diff --git a/edm2-img512-xs-fid/model_index.json b/edm2-img512-xs-fid/model_index.json new file mode 100644 index 0000000000000000000000000000000000000000..24c0305eea3ec64abbc2f870476afa4f5e462f42 --- /dev/null +++ b/edm2-img512-xs-fid/model_index.json @@ -0,0 +1,19 @@ +{ + "_class_name": [ + "pipeline", + "EDM2Pipeline" + ], + "_diffusers_version": "0.31.0", + "scheduler": [ + "diffusers", + "EDMEulerScheduler" + ], + "unet": [ + "unet_edm2", + "EDM2UNet2DModel" + ], + "vae": [ + "diffusers", + "AutoencoderKL" + ] +} diff --git a/edm2-img512-xs-fid/pipeline.py b/edm2-img512-xs-fid/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..6b4131215d07770c74efa0674d00f3d6423d9c06 --- /dev/null +++ b/edm2-img512-xs-fid/pipeline.py @@ -0,0 +1,406 @@ +"""Hub custom pipeline: EDM2Pipeline. +Load with native Hugging Face diffusers and trust_remote_code=True. +""" + +from __future__ import annotations + +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from pathlib import Path +from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch + +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers.utils import replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from pathlib import Path + >>> import torch + >>> from diffusers import DiffusionPipeline + + >>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve() + >>> pipe = DiffusionPipeline.from_pretrained( + ... str(model_dir), + ... local_files_only=True, + ... custom_pipeline=str(model_dir / "pipeline.py"), + ... trust_remote_code=True, + ... torch_dtype=torch.float32, + ... ) + >>> pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(42) + >>> image = pipe( + ... class_labels=207, + ... num_inference_steps=32, + ... guidance_scale=1.0, + ... generator=generator, + ... ).images[0] + >>> image.save("demo.png") + ``` +""" + +# Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py). +_STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28]) +_STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE + +class EDM2Pipeline(DiffusionPipeline): + r""" + Pipeline for class-conditional image generation with EDM2 + ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)). + + Parameters: + unet ([`EDM2UNet2DModel`]): + Main magnitude-preserving U-Net with EDM preconditioning. + scheduler ([`EDMEulerScheduler`]): + Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in + the pipeline because the UNet returns denoised latents rather than noise predictions. + vae ([`AutoencoderKL`], *optional*): + Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`. + gnet ([`EDM2UNet2DModel`], *optional*): + Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`). + id2label (`dict[int, str]`, *optional*): + ImageNet class id to English label mapping. + """ + + model_cpu_offload_seq = "unet->gnet->vae" + _optional_components = ["vae", "gnet"] + + def __init__( + self, + unet, + scheduler, + vae=None, + gnet=None, + id2label: Optional[Dict[Union[int, str], str]] = None, + ) -> None: + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet) + self._id2label = self._normalize_id2label(id2label) + self.labels = self._build_label2id(self._id2label) + self._labels_loaded_from_model_index = bool(self._id2label) + self.vae_scale_factor = 8 if self.vae is not None else 1 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) + + @staticmethod + def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]: + if not id2label: + return {} + return {int(key): value for key, value in id2label.items()} + + @staticmethod + def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]: + label2id: Dict[str, int] = {} + for class_id, value in id2label.items(): + for synonym in value.split(","): + synonym = synonym.strip() + if synonym: + label2id[synonym] = int(class_id) + return dict(sorted(label2id.items())) + + def _ensure_labels_loaded(self) -> None: + if self._labels_loaded_from_model_index: + return + loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None)) + if loaded: + self._id2label = loaded + self.labels = self._build_label2id(self._id2label) + self._labels_loaded_from_model_index = True + + @staticmethod + def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]: + if not variant_path: + return {} + model_index_path = Path(variant_path).resolve() / "model_index.json" + if not model_index_path.is_file(): + return {} + raw = json.loads(model_index_path.read_text(encoding="utf-8")) + id2label = raw.get("id2label") + if not isinstance(id2label, dict): + return {} + return {int(key): value for key, value in id2label.items()} + + @property + def id2label(self) -> Dict[int, str]: + r"""ImageNet class id to English label string (comma-separated synonyms).""" + self._ensure_labels_loaded() + return self._id2label + + def get_label_ids(self, label: Union[str, List[str]]) -> List[int]: + r""" + Map ImageNet label strings to class ids. + + Args: + label (`str` or `list[str]`): + One or more English label strings that match entries in `id2label`. + """ + self._ensure_labels_loaded() + if not self.labels: + raise ValueError("No English labels loaded. Add `id2label` to model_index.json.") + labels = [label] if isinstance(label, str) else list(label) + missing = [item for item in labels if item not in self.labels] + if missing: + preview = ", ".join(list(self.labels.keys())[:8]) + raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...") + return [self.labels[item] for item in labels] + + def _default_image_size(self) -> int: + latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64))) + return latent_size * self.vae_scale_factor + + def check_inputs( + self, + height: int, + width: int, + num_inference_steps: int, + guidance_scale: float, + output_type: str, + ) -> None: + if num_inference_steps < 1: + raise ValueError("num_inference_steps must be >= 1.") + if guidance_scale < 1.0: + raise ValueError("guidance_scale must be >= 1.0.") + if guidance_scale > 1.0 and self.gnet is None: + raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).") + if output_type not in {"pil", "np", "pt", "latent"}: + raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.") + + native_size = self._default_image_size() + if height != native_size or width != native_size: + raise ValueError( + f"EDM2 expects native resolution height=width={native_size}. " + f"Got height={height}, width={width}." + ) + + def _normalize_class_labels( + self, + class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]], + batch_size: int, + device: torch.device, + ) -> Optional[torch.Tensor]: + label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0))) + if label_dim == 0: + return None + if class_labels is None: + indices = torch.randint(label_dim, size=(batch_size,), device=device) + return torch.eye(label_dim, device=device, dtype=torch.float32)[indices] + + if isinstance(class_labels, str): + class_labels = self.get_label_ids(class_labels)[0] + elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str): + class_labels = self.get_label_ids(list(class_labels)) + + if isinstance(class_labels, int): + indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long) + elif isinstance(class_labels, torch.Tensor): + if class_labels.ndim == 2: + labels = class_labels.to(device=device, dtype=torch.float32) + if labels.shape[0] != batch_size: + raise ValueError(f"class_labels batch must match batch_size={batch_size}.") + return labels + indices = class_labels.to(device=device, dtype=torch.long).flatten() + else: + indices = torch.tensor(list(class_labels), device=device, dtype=torch.long) + + if indices.numel() == 1 and batch_size > 1: + indices = indices.repeat(batch_size) + if indices.numel() != batch_size: + raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.") + return torch.eye(label_dim, device=device, dtype=torch.float32)[indices] + + def prepare_latents( + self, + batch_size: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]], + ) -> torch.Tensor: + in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4))) + latent_size = height // self.vae_scale_factor + return randn_tensor( + (batch_size, in_channels, latent_size, latent_size), + generator=generator, + device=device, + dtype=torch.float32, + ) + + def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"): + if output_type == "latent": + return latents + + in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 3))) + if self.vae is None: + image = (latents.to(torch.float32) * 127.5 + 128).clip(0, 255) / 255.0 + return self.image_processor.postprocess(image, output_type=output_type) + + if in_channels == 4: + x = latents.to(torch.float32) + scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1) + bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1) + x = (x - bias) / scale + else: + x = latents.to(torch.float32) + + vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype + image = self.vae.decode(x.to(dtype=vae_dtype)).sample.to(torch.float32).clamp(0, 1) + + return self.image_processor.postprocess(image, output_type=output_type) + + @staticmethod + def _apply_autoguidance( + main: torch.Tensor, + ref: torch.Tensor, + guidance_scale: float, + ) -> torch.Tensor: + return ref.lerp(main, guidance_scale) + + @staticmethod + def _sample_edm2_heun( + denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + noise: torch.Tensor, + sigmas: torch.Tensor, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + progress_bar: Optional[Callable[[Iterable], Iterable]] = None, + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0).""" + x_next = noise.to(dtype) * sigmas[0] + + sigma_pairs = list(zip(sigmas[:-1], sigmas[1:])) + if progress_bar is not None: + sigma_pairs = progress_bar(sigma_pairs) + + num_steps = len(sigma_pairs) + for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs): + x_hat, sigma_hat = x_next, sigma_cur + d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat + x_next = x_hat + (sigma_next - sigma_hat) * d_cur + if i < num_steps - 1: + d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next + x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime) + return x_next + + @torch.inference_mode() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None, + batch_size: int = 1, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 32, + guidance_scale: float = 1.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: str = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Generate class-conditional images with EDM2. + + Args: + class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*): + ImageNet class indices, English label strings, or one-hot float tensors. + Random classes are sampled when omitted on conditional models. + batch_size (`int`, defaults to `1`): + Number of images to generate. + height (`int`, *optional*): + Output height in pixels. Defaults to the pretrained native resolution. + width (`int`, *optional*): + Output width in pixels. Defaults to the pretrained native resolution. + num_inference_steps (`int`, defaults to `32`): + Number of EDM2 Heun steps (NVlabs default). + guidance_scale (`float`, defaults to `1.0`): + Autoguidance strength. Values above `1.0` blend the main net with `gnet` + via `gnet_output.lerp(unet_output, guidance_scale)`. + generator (`torch.Generator`, *optional*): + RNG for reproducibility. + output_type (`str`, defaults to `"pil"`): + `"pil"`, `"np"`, `"pt"`, or `"latent"`. + return_dict (`bool`, defaults to `True`): + Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if True. + + Examples: + + """ + default_size = self._default_image_size() + height = int(height or default_size) + width = int(width or default_size) + self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type) + + device = self._execution_device + dtype = self.unet.dtype + labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device) + noise = self.prepare_latents(batch_size, height, width, dtype, device, generator) + + def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + sigma_batch = sigma.reshape(1).expand(batch_size) + main = self.unet( + sample=x, + sigma=sigma_batch, + class_labels=labels, + force_fp32=True, + ).sample + if guidance_scale == 1.0 or self.gnet is None: + return main.to(torch.float32) + ref = self.gnet( + sample=x, + sigma=sigma_batch, + class_labels=labels, + force_fp32=True, + ).sample + return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + latents = self._sample_edm2_heun( + denoise_fn=denoise_fn, + noise=noise, + sigmas=self.scheduler.sigmas.to(device), + generator=generator, + progress_bar=self.progress_bar, + dtype=torch.float32, + ) + + image = self.decode_latents(latents, output_type=output_type) + if not return_dict: + return (image, latents) + return ImagePipelineOutput(images=image) + + @classmethod + def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None): + vae_dir = os.path.join(pretrained_model_name_or_path, "vae") + if os.path.isdir(vae_dir): + try: + + return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype) + except Exception: + return None + + vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt") + if os.path.isfile(vae_hint): + with open(vae_hint, "r", encoding="utf-8") as f: + hub_id = f.read().strip() + if hub_id: + + return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype) + return None diff --git a/edm2-img512-xs-fid/scheduler/scheduler_config.json b/edm2-img512-xs-fid/scheduler/scheduler_config.json new file mode 100644 index 0000000000000000000000000000000000000000..08cbde97955d06d66b0b72be217db84bb797d711 --- /dev/null +++ b/edm2-img512-xs-fid/scheduler/scheduler_config.json @@ -0,0 +1,11 @@ +{ + "_class_name": "EDMEulerScheduler", + "final_sigmas_type": "zero", + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "rho": 7.0, + "sigma_data": 0.5, + "sigma_max": 80.0, + "sigma_min": 0.002, + "sigma_schedule": "karras" +} diff --git a/edm2-img512-xs-fid/unet/config.json b/edm2-img512-xs-fid/unet/config.json new file mode 100644 index 0000000000000000000000000000000000000000..6fceca931c569a93854b38007d87f2f5df9fea86 --- /dev/null +++ b/edm2-img512-xs-fid/unet/config.json @@ -0,0 +1,30 @@ +{ + "_class_name": "EDM2UNet2DModel", + "attn_balance": 0.3, + "attn_resolutions": [ + 16, + 8 + ], + "channel_mult": [ + 1, + 2, + 3, + 4 + ], + "channel_mult_emb": 4, + "channel_mult_noise": 1, + "channels_per_head": 64, + "clip_act": 256, + "concat_balance": 0.5, + "dropout": 0.0, + "in_channels": 4, + "label_balance": 0.5, + "model_channels": 128, + "num_blocks": 3, + "num_class_embeds": 1000, + "out_channels": 4, + "res_balance": 0.3, + "sample_size": 64, + "sigma_data": 0.5, + "use_fp16": true +} diff --git a/edm2-img512-xs-fid/unet/diffusion_pytorch_model.safetensors b/edm2-img512-xs-fid/unet/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..0362009d25a7cbb31886d430ea8402ddd8ff951f --- /dev/null +++ b/edm2-img512-xs-fid/unet/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d5fe6a82ecdaa64b966a245a6f9179d159c702902ce9897bd60e21e80615a59b +size 498877268 diff --git a/edm2-img512-xs-fid/unet/unet_edm2.py b/edm2-img512-xs-fid/unet/unet_edm2.py new file mode 100644 index 0000000000000000000000000000000000000000..49d9451b443c4d6ad3b6ba2db19a7a804b07b6de --- /dev/null +++ b/edm2-img512-xs-fid/unet/unet_edm2.py @@ -0,0 +1,434 @@ +import math +import json +import os +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import numpy as np +import torch + +try: + from diffusers.configuration_utils import ConfigMixin, register_to_config + from diffusers.models.modeling_utils import ModelMixin + from diffusers.utils import BaseOutput +except ImportError: # pragma: no cover + class ModelMixin(torch.nn.Module): + pass + + class ConfigMixin: + config = {} + + def register_to_config(self, **kwargs): + self.config = kwargs + + def register_to_config(func): + return func + + @dataclass + class BaseOutput: + pass + + +def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor: + if dim is None: + dim = list(range(1, x.ndim)) + norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) + norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel())) + return x / norm.to(x.dtype) + + +def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor: + if mode == "keep": + return x + filt = np.float32(f) + pad = (len(filt) - 1) // 2 + filt = filt / filt.sum() + filt = np.outer(filt, filt)[np.newaxis, np.newaxis, :, :] + filt = torch.as_tensor(filt, dtype=x.dtype, device=x.device) + c = x.shape[1] + if mode == "down": + return torch.nn.functional.conv2d(x, filt.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,)) + return torch.nn.functional.conv_transpose2d(x, (filt * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,)) + + +def mp_silu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.silu(x) / 0.596 + + +def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor: + return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2) + + +def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor: + na = a.shape[dim] + nb = b.shape[dim] + c = math.sqrt((na + nb) / ((1 - t) ** 2 + t ** 2)) + wa = c / math.sqrt(na) * (1 - t) + wb = c / math.sqrt(nb) * t + return torch.cat([wa * a, wb * b], dim=dim) + + +class MPFourier(torch.nn.Module): + def __init__(self, num_channels: int, bandwidth: float = 1): + super().__init__() + self.register_buffer("freqs", 2 * math.pi * torch.randn(num_channels) * bandwidth) + self.register_buffer("phases", 2 * math.pi * torch.rand(num_channels)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = x.to(torch.float32).ger(self.freqs.to(torch.float32)) + y = y + self.phases.to(torch.float32) + y = y.cos() * math.sqrt(2) + return y.to(x.dtype) + + +class MPConv(torch.nn.Module): + def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]): + super().__init__() + self.out_channels = out_channels + self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel)) + + def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor: + w = self.weight.to(torch.float32) + if self.training: + with torch.no_grad(): + self.weight.copy_(normalize(w)) + w = normalize(w) + w = w * (gain / math.sqrt(w[0].numel())) + w = w.to(x.dtype) + if w.ndim == 2: + return x @ w.t() + return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1] // 2,)) + + +class Block(torch.nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + emb_channels: int, + flavor: str = "enc", + resample_mode: str = "keep", + resample_filter: List[float] = [1, 1], + attention: bool = False, + channels_per_head: int = 64, + dropout: float = 0.0, + res_balance: float = 0.3, + attn_balance: float = 0.3, + clip_act: Optional[float] = 256, + ): + super().__init__() + self.out_channels = out_channels + self.flavor = flavor + self.resample_filter = resample_filter + self.resample_mode = resample_mode + self.num_heads = out_channels // channels_per_head if attention else 0 + self.dropout = dropout + self.res_balance = res_balance + self.attn_balance = attn_balance + self.clip_act = clip_act + self.emb_gain = torch.nn.Parameter(torch.zeros([])) + self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3)) + self.emb_linear = MPConv(emb_channels, out_channels, kernel=()) + self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3)) + self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None + self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None + self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + x = resample(x, f=self.resample_filter, mode=self.resample_mode) + if self.flavor == "enc": + if self.conv_skip is not None: + x = self.conv_skip(x) + x = normalize(x, dim=[1]) + + y = self.conv_res0(mp_silu(x)) + c = self.emb_linear(emb, gain=self.emb_gain) + 1 + y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype)) + if self.training and self.dropout: + y = torch.nn.functional.dropout(y, p=self.dropout) + y = self.conv_res1(y) + + if self.flavor == "dec" and self.conv_skip is not None: + x = self.conv_skip(x) + x = mp_sum(x, y, t=self.res_balance) + + if self.num_heads: + y = self.attn_qkv(x) + y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3]) + q, k, v = normalize(y, dim=[2]).unbind(3) + w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3) + y = torch.einsum("nhqk,nhck->nhcq", w, v) + y = self.attn_proj(y.reshape(*x.shape)) + x = mp_sum(x, y, t=self.attn_balance) + + if self.clip_act is not None: + x = x.clip_(-self.clip_act, self.clip_act) + return x + + +class EDM2UNet(torch.nn.Module): + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int, + model_channels: int = 192, + channel_mult: Tuple[int, ...] = (1, 2, 3, 4), + channel_mult_noise: Optional[int] = None, + channel_mult_emb: Optional[int] = None, + num_blocks: int = 3, + attn_resolutions: Tuple[int, ...] = (16, 8), + label_balance: float = 0.5, + concat_balance: float = 0.5, + **block_kwargs, + ): + super().__init__() + cblock = [model_channels * x for x in channel_mult] + cnoise = model_channels * channel_mult_noise if channel_mult_noise is not None else cblock[0] + cemb = model_channels * channel_mult_emb if channel_mult_emb is not None else max(cblock) + self.label_balance = label_balance + self.concat_balance = concat_balance + self.out_gain = torch.nn.Parameter(torch.zeros([])) + + self.emb_fourier = MPFourier(cnoise) + self.emb_noise = MPConv(cnoise, cemb, kernel=()) + self.emb_label = MPConv(label_dim, cemb, kernel=()) if label_dim else None + + self.enc = torch.nn.ModuleDict() + cout = img_channels + 1 + for level, channels in enumerate(cblock): + res = img_resolution >> level + if level == 0: + cin = cout + cout = channels + self.enc[f"{res}x{res}_conv"] = MPConv(cin, cout, kernel=(3, 3)) + else: + self.enc[f"{res}x{res}_down"] = Block(cout, cout, cemb, flavor="enc", resample_mode="down", **block_kwargs) + for idx in range(num_blocks): + cin = cout + cout = channels + self.enc[f"{res}x{res}_block{idx}"] = Block( + cin, + cout, + cemb, + flavor="enc", + attention=(res in attn_resolutions), + **block_kwargs, + ) + + self.dec = torch.nn.ModuleDict() + skips = [block.out_channels for block in self.enc.values()] + for level, channels in reversed(list(enumerate(cblock))): + res = img_resolution >> level + if level == len(cblock) - 1: + self.dec[f"{res}x{res}_in0"] = Block(cout, cout, cemb, flavor="dec", attention=True, **block_kwargs) + self.dec[f"{res}x{res}_in1"] = Block(cout, cout, cemb, flavor="dec", **block_kwargs) + else: + self.dec[f"{res}x{res}_up"] = Block(cout, cout, cemb, flavor="dec", resample_mode="up", **block_kwargs) + for idx in range(num_blocks + 1): + cin = cout + skips.pop() + cout = channels + self.dec[f"{res}x{res}_block{idx}"] = Block( + cin, + cout, + cemb, + flavor="dec", + attention=(res in attn_resolutions), + **block_kwargs, + ) + + self.out_conv = MPConv(cout, img_channels, kernel=(3, 3)) + + def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor: + emb = self.emb_noise(self.emb_fourier(noise_labels)) + if self.emb_label is not None: + if class_labels is None: + raise ValueError("class_labels are required for conditional EDM2UNet.") + emb = mp_sum(emb, self.emb_label(class_labels * math.sqrt(class_labels.shape[1])), t=self.label_balance) + emb = mp_silu(emb) + + x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1) + skips = [] + for name, block in self.enc.items(): + x = block(x) if "conv" in name else block(x, emb) + skips.append(x) + + for name, block in self.dec.items(): + if "block" in name: + x = mp_cat(x, skips.pop(), t=self.concat_balance) + x = block(x, emb) + return self.out_conv(x, gain=self.out_gain) + + +@dataclass +class EDM2UNet2DOutput(BaseOutput): + sample: torch.Tensor + logvar: Optional[torch.Tensor] = None + + + +_CONFIG_KEYS = ( + "sample_size", + "in_channels", + "out_channels", + "num_class_embeds", + "use_fp16", + "sigma_data", + "logvar_channels", + "model_channels", + "channel_mult", + "channel_mult_noise", + "channel_mult_emb", + "num_blocks", + "attn_resolutions", + "label_balance", + "concat_balance", + "dropout", + "channels_per_head", + "res_balance", + "attn_balance", + "clip_act", +) + + +class EDM2UNet2DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + sample_size: int = 64, + in_channels: int = 4, + out_channels: int = 4, + num_class_embeds: int = 0, + use_fp16: bool = True, + sigma_data: float = 0.5, + logvar_channels: int = 128, + model_channels: int = 192, + channel_mult: Tuple[int, ...] = (1, 2, 3, 4), + channel_mult_noise: Optional[int] = None, + channel_mult_emb: Optional[int] = None, + num_blocks: int = 3, + attn_resolutions: Tuple[int, ...] = (16, 8), + label_balance: float = 0.5, + concat_balance: float = 0.5, + dropout: float = 0.0, + channels_per_head: int = 64, + res_balance: float = 0.3, + attn_balance: float = 0.3, + clip_act: Optional[float] = 256, + ): + super().__init__() + self.sample_size = sample_size + self.in_channels = in_channels + self.out_channels = out_channels + self.num_class_embeds = num_class_embeds + self.use_fp16 = use_fp16 + self.sigma_data = sigma_data + self.model_channels = model_channels + self.channel_mult = channel_mult + self.channel_mult_noise = channel_mult_noise + self.channel_mult_emb = channel_mult_emb + self.num_blocks = num_blocks + self.attn_resolutions = attn_resolutions + self.label_balance = label_balance + self.concat_balance = concat_balance + self.dropout = dropout + self.channels_per_head = channels_per_head + self.res_balance = res_balance + self.attn_balance = attn_balance + self.clip_act = clip_act + self.unet = EDM2UNet( + img_resolution=sample_size, + img_channels=in_channels, + label_dim=num_class_embeds, + model_channels=model_channels, + channel_mult=channel_mult, + channel_mult_noise=channel_mult_noise, + channel_mult_emb=channel_mult_emb, + num_blocks=num_blocks, + attn_resolutions=attn_resolutions, + label_balance=label_balance, + concat_balance=concat_balance, + dropout=dropout, + channels_per_head=channels_per_head, + res_balance=res_balance, + attn_balance=attn_balance, + clip_act=clip_act, + ) + self.logvar_fourier = MPFourier(logvar_channels) + self.logvar_linear = MPConv(logvar_channels, 1, kernel=()) + + def forward( + self, + sample: torch.Tensor, + sigma: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + force_fp32: bool = False, + return_logvar: bool = False, + return_dict: bool = True, + ) -> EDM2UNet2DOutput: + x = sample.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + if self.num_class_embeds == 0: + class_labels = None + else: + if class_labels is None: + class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device) + class_labels = class_labels.to(torch.float32).reshape(-1, self.num_class_embeds) + dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") else torch.float32 + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.flatten().log() / 4 + + x_in = (c_in * x).to(dtype) + f_x = self.unet(x_in, c_noise, class_labels) + d_x = c_skip * x + c_out * f_x.to(torch.float32) + + logvar = None + if return_logvar: + logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1) + + if not return_dict: + return (d_x, logvar) + return EDM2UNet2DOutput(sample=d_x, logvar=logvar) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs): + subfolder = kwargs.pop("subfolder", None) + model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path + with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f: + config = json.load(f) + init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS} + model = cls(**init_kwargs) + weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors") + if os.path.isfile(weight_file): + from safetensors.torch import load_file + + state_dict = load_file(weight_file) + else: + state_dict = torch.load(os.path.join(model_dir, "diffusion_pytorch_model.bin"), map_location="cpu") + model.load_state_dict(state_dict, strict=True) + if torch_dtype is not None: + model = model.to(dtype=torch_dtype) + return model + + def save_pretrained(self, save_directory: str, safe_serialization: bool = True): + os.makedirs(save_directory, exist_ok=True) + stored = dict(getattr(self, "config", {})) + config = {"_class_name": self.__class__.__name__} + for key in _CONFIG_KEYS: + if key in stored: + config[key] = stored[key] + elif hasattr(self, key): + config[key] = getattr(self, key) + with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f: + json.dump(config, f, indent=2, sort_keys=True) + f.write("\n") + state_dict = self.state_dict() + if safe_serialization: + from safetensors.torch import save_file + + save_file(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.safetensors")) + else: + torch.save(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.bin")) diff --git a/edm2-img512-xs-fid/vae/config.json b/edm2-img512-xs-fid/vae/config.json new file mode 100644 index 0000000000000000000000000000000000000000..ae2ab37de2acb8bc348d3de809ae8385324be962 --- /dev/null +++ b/edm2-img512-xs-fid/vae/config.json @@ -0,0 +1,38 @@ +{ + "_class_name": "AutoencoderKL", + "_diffusers_version": "0.36.0", + "_name_or_path": "stabilityai/sd-vae-ft-mse", + "act_fn": "silu", + "block_out_channels": [ + 128, + 256, + 512, + 512 + ], + "down_block_types": [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D" + ], + "force_upcast": true, + "in_channels": 3, + "latent_channels": 4, + "latents_mean": null, + "latents_std": null, + "layers_per_block": 2, + "mid_block_add_attention": true, + "norm_num_groups": 32, + "out_channels": 3, + "sample_size": 256, + "scaling_factor": 0.18215, + "shift_factor": null, + "up_block_types": [ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D" + ], + "use_post_quant_conv": true, + "use_quant_conv": true +} diff --git a/edm2-img512-xs-fid/vae/diffusion_pytorch_model.safetensors b/edm2-img512-xs-fid/vae/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..90464d67ac7303d0ee4696334df13da130a948ea --- /dev/null +++ b/edm2-img512-xs-fid/vae/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815 +size 334643276 diff --git a/edm2-img512-xxl-fid/README.md b/edm2-img512-xxl-fid/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7e6079717eab26ea8b83bd835b2c9386dc5852da --- /dev/null +++ b/edm2-img512-xxl-fid/README.md @@ -0,0 +1,67 @@ +--- +license: cc-by-nc-sa-4.0 +library_name: diffusers +pipeline_tag: unconditional-image-generation +tags: + - diffusers + - edm2 + - image-generation + - class-conditional + - imagenet +inference: true +widget: + - output: + url: demo.png +language: + - en +--- + +# edm2-img512-xxl-fid + +Self-contained Diffusers checkpoint for **EDM2-XXL** at 512×512, optimized for FID (NVlabs preset `edm2-img512-xxl-fid`). + +Converted from [NVlabs/edm2](https://github.com/NVlabs/edm2) post-hoc reconstruction +`edm2-img512-xxl-0939524-0.070.pkl` (FID 1.91). + +## Demo + +![edm2-img512-xxl-fid demo](demo.png) + +Class-conditional sample (ImageNet class **207**, golden retriever), 512×512, 32 steps, guidance 1.0, seed 42. + +## Load + +```python +from pathlib import Path +import torch +from diffusers import DiffusionPipeline + +model_dir = Path(".") +pipe = DiffusionPipeline.from_pretrained( + str(model_dir), + local_files_only=True, + trust_remote_code=True, + torch_dtype=torch.bfloat16, +).to("cuda") + +generator = torch.Generator(device="cuda").manual_seed(42) +image = pipe( + class_labels=207, + num_inference_steps=32, + guidance_scale=1.0, + generator=generator, +).images[0] +image.save("demo.png") +``` + +Official NVlabs defaults (`generate_images.py`): `num_steps=32`, `sigma_min=0.002`, `sigma_max=80`, +`rho=7`, `guidance=1.0` (no gnet), `S_churn=0`. Heun sampling runs in float32 internally even when +UNet/VAE weights are loaded in bf16/fp16. + +## Components + +- `pipeline.py` +- `unet/unet_edm2.py` +- `unet/diffusion_pytorch_model.safetensors` +- `scheduler/scheduler_config.json` (`EDMEulerScheduler`) +- `vae/diffusion_pytorch_model.safetensors` (`stabilityai/sd-vae-ft-mse`) diff --git a/edm2-img512-xxl-fid/demo.png b/edm2-img512-xxl-fid/demo.png new file mode 100644 index 0000000000000000000000000000000000000000..83e3dbac321e2b6cbcee88d03903a7e9c1c0439a --- /dev/null +++ b/edm2-img512-xxl-fid/demo.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9da85f6cb881c112c3240fa5f62b3331af9700a656b882e01ebf9df4ea05660f +size 374923 diff --git a/edm2-img512-xxl-fid/model_index.json b/edm2-img512-xxl-fid/model_index.json new file mode 100644 index 0000000000000000000000000000000000000000..24c0305eea3ec64abbc2f870476afa4f5e462f42 --- /dev/null +++ b/edm2-img512-xxl-fid/model_index.json @@ -0,0 +1,19 @@ +{ + "_class_name": [ + "pipeline", + "EDM2Pipeline" + ], + "_diffusers_version": "0.31.0", + "scheduler": [ + "diffusers", + "EDMEulerScheduler" + ], + "unet": [ + "unet_edm2", + "EDM2UNet2DModel" + ], + "vae": [ + "diffusers", + "AutoencoderKL" + ] +} diff --git a/edm2-img512-xxl-fid/pipeline.py b/edm2-img512-xxl-fid/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..6b4131215d07770c74efa0674d00f3d6423d9c06 --- /dev/null +++ b/edm2-img512-xxl-fid/pipeline.py @@ -0,0 +1,406 @@ +"""Hub custom pipeline: EDM2Pipeline. +Load with native Hugging Face diffusers and trust_remote_code=True. +""" + +from __future__ import annotations + +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from pathlib import Path +from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch + +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers.utils import replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from pathlib import Path + >>> import torch + >>> from diffusers import DiffusionPipeline + + >>> model_dir = Path("BiliSakura/EDM2-diffusers/edm2-img512-xs-fid").resolve() + >>> pipe = DiffusionPipeline.from_pretrained( + ... str(model_dir), + ... local_files_only=True, + ... custom_pipeline=str(model_dir / "pipeline.py"), + ... trust_remote_code=True, + ... torch_dtype=torch.float32, + ... ) + >>> pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(42) + >>> image = pipe( + ... class_labels=207, + ... num_inference_steps=32, + ... guidance_scale=1.0, + ... generator=generator, + ... ).images[0] + >>> image.save("demo.png") + ``` +""" + +# Default Stability VAE latent whitening used by NVlabs/edm2 (training/encoders.py). +_STABILITY_VAE_SCALE = np.float32(0.5) / np.float32([4.17, 4.62, 3.71, 3.28]) +_STABILITY_VAE_BIAS = np.float32(0.0) - np.float32([5.81, 3.25, 0.12, -2.15]) * _STABILITY_VAE_SCALE + +class EDM2Pipeline(DiffusionPipeline): + r""" + Pipeline for class-conditional image generation with EDM2 + ([Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696)). + + Parameters: + unet ([`EDM2UNet2DModel`]): + Main magnitude-preserving U-Net with EDM preconditioning. + scheduler ([`EDMEulerScheduler`]): + Built-in diffusers scheduler used for the Karras sigma schedule. EDM2 Heun sampling runs in + the pipeline because the UNet returns denoised latents rather than noise predictions. + vae ([`AutoencoderKL`], *optional*): + Decoder for 512px latent-diffusion checkpoints. Required when `unet.in_channels == 4`. + gnet ([`EDM2UNet2DModel`], *optional*): + Guiding network for autoguidance (`ref.lerp(main, guidance_scale)`). + id2label (`dict[int, str]`, *optional*): + ImageNet class id to English label mapping. + """ + + model_cpu_offload_seq = "unet->gnet->vae" + _optional_components = ["vae", "gnet"] + + def __init__( + self, + unet, + scheduler, + vae=None, + gnet=None, + id2label: Optional[Dict[Union[int, str], str]] = None, + ) -> None: + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler, vae=vae, gnet=gnet) + self._id2label = self._normalize_id2label(id2label) + self.labels = self._build_label2id(self._id2label) + self._labels_loaded_from_model_index = bool(self._id2label) + self.vae_scale_factor = 8 if self.vae is not None else 1 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) + + @staticmethod + def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]: + if not id2label: + return {} + return {int(key): value for key, value in id2label.items()} + + @staticmethod + def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]: + label2id: Dict[str, int] = {} + for class_id, value in id2label.items(): + for synonym in value.split(","): + synonym = synonym.strip() + if synonym: + label2id[synonym] = int(class_id) + return dict(sorted(label2id.items())) + + def _ensure_labels_loaded(self) -> None: + if self._labels_loaded_from_model_index: + return + loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None)) + if loaded: + self._id2label = loaded + self.labels = self._build_label2id(self._id2label) + self._labels_loaded_from_model_index = True + + @staticmethod + def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]: + if not variant_path: + return {} + model_index_path = Path(variant_path).resolve() / "model_index.json" + if not model_index_path.is_file(): + return {} + raw = json.loads(model_index_path.read_text(encoding="utf-8")) + id2label = raw.get("id2label") + if not isinstance(id2label, dict): + return {} + return {int(key): value for key, value in id2label.items()} + + @property + def id2label(self) -> Dict[int, str]: + r"""ImageNet class id to English label string (comma-separated synonyms).""" + self._ensure_labels_loaded() + return self._id2label + + def get_label_ids(self, label: Union[str, List[str]]) -> List[int]: + r""" + Map ImageNet label strings to class ids. + + Args: + label (`str` or `list[str]`): + One or more English label strings that match entries in `id2label`. + """ + self._ensure_labels_loaded() + if not self.labels: + raise ValueError("No English labels loaded. Add `id2label` to model_index.json.") + labels = [label] if isinstance(label, str) else list(label) + missing = [item for item in labels if item not in self.labels] + if missing: + preview = ", ".join(list(self.labels.keys())[:8]) + raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...") + return [self.labels[item] for item in labels] + + def _default_image_size(self) -> int: + latent_size = int(getattr(self.unet, "sample_size", getattr(self.unet.config, "sample_size", 64))) + return latent_size * self.vae_scale_factor + + def check_inputs( + self, + height: int, + width: int, + num_inference_steps: int, + guidance_scale: float, + output_type: str, + ) -> None: + if num_inference_steps < 1: + raise ValueError("num_inference_steps must be >= 1.") + if guidance_scale < 1.0: + raise ValueError("guidance_scale must be >= 1.0.") + if guidance_scale > 1.0 and self.gnet is None: + raise ValueError("guidance_scale > 1.0 requires a guiding network (`gnet`).") + if output_type not in {"pil", "np", "pt", "latent"}: + raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.") + + native_size = self._default_image_size() + if height != native_size or width != native_size: + raise ValueError( + f"EDM2 expects native resolution height=width={native_size}. " + f"Got height={height}, width={width}." + ) + + def _normalize_class_labels( + self, + class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]], + batch_size: int, + device: torch.device, + ) -> Optional[torch.Tensor]: + label_dim = int(getattr(self.unet, "num_class_embeds", getattr(self.unet.config, "num_class_embeds", 0))) + if label_dim == 0: + return None + if class_labels is None: + indices = torch.randint(label_dim, size=(batch_size,), device=device) + return torch.eye(label_dim, device=device, dtype=torch.float32)[indices] + + if isinstance(class_labels, str): + class_labels = self.get_label_ids(class_labels)[0] + elif isinstance(class_labels, Sequence) and class_labels and isinstance(class_labels[0], str): + class_labels = self.get_label_ids(list(class_labels)) + + if isinstance(class_labels, int): + indices = torch.full((batch_size,), class_labels, device=device, dtype=torch.long) + elif isinstance(class_labels, torch.Tensor): + if class_labels.ndim == 2: + labels = class_labels.to(device=device, dtype=torch.float32) + if labels.shape[0] != batch_size: + raise ValueError(f"class_labels batch must match batch_size={batch_size}.") + return labels + indices = class_labels.to(device=device, dtype=torch.long).flatten() + else: + indices = torch.tensor(list(class_labels), device=device, dtype=torch.long) + + if indices.numel() == 1 and batch_size > 1: + indices = indices.repeat(batch_size) + if indices.numel() != batch_size: + raise ValueError(f"class_labels must resolve to batch size {batch_size}, got {indices.numel()}.") + return torch.eye(label_dim, device=device, dtype=torch.float32)[indices] + + def prepare_latents( + self, + batch_size: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]], + ) -> torch.Tensor: + in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 4))) + latent_size = height // self.vae_scale_factor + return randn_tensor( + (batch_size, in_channels, latent_size, latent_size), + generator=generator, + device=device, + dtype=torch.float32, + ) + + def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"): + if output_type == "latent": + return latents + + in_channels = int(getattr(self.unet, "in_channels", getattr(self.unet.config, "in_channels", 3))) + if self.vae is None: + image = (latents.to(torch.float32) * 127.5 + 128).clip(0, 255) / 255.0 + return self.image_processor.postprocess(image, output_type=output_type) + + if in_channels == 4: + x = latents.to(torch.float32) + scale = torch.as_tensor(_STABILITY_VAE_SCALE, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1) + bias = torch.as_tensor(_STABILITY_VAE_BIAS, dtype=x.dtype, device=x.device).reshape(1, -1, 1, 1) + x = (x - bias) / scale + else: + x = latents.to(torch.float32) + + vae_dtype = getattr(self.vae, "dtype", None) or next(self.vae.parameters()).dtype + image = self.vae.decode(x.to(dtype=vae_dtype)).sample.to(torch.float32).clamp(0, 1) + + return self.image_processor.postprocess(image, output_type=output_type) + + @staticmethod + def _apply_autoguidance( + main: torch.Tensor, + ref: torch.Tensor, + guidance_scale: float, + ) -> torch.Tensor: + return ref.lerp(main, guidance_scale) + + @staticmethod + def _sample_edm2_heun( + denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + noise: torch.Tensor, + sigmas: torch.Tensor, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + progress_bar: Optional[Callable[[Iterable], Iterable]] = None, + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """NVlabs EDM2 Heun sampler (generate_images.edm_sampler, guidance=1, S_churn=0).""" + x_next = noise.to(dtype) * sigmas[0] + + sigma_pairs = list(zip(sigmas[:-1], sigmas[1:])) + if progress_bar is not None: + sigma_pairs = progress_bar(sigma_pairs) + + num_steps = len(sigma_pairs) + for i, (sigma_cur, sigma_next) in enumerate(sigma_pairs): + x_hat, sigma_hat = x_next, sigma_cur + d_cur = (x_hat - denoise_fn(x_hat, sigma_hat)) / sigma_hat + x_next = x_hat + (sigma_next - sigma_hat) * d_cur + if i < num_steps - 1: + d_prime = (x_next - denoise_fn(x_next, sigma_next)) / sigma_next + x_next = x_hat + (sigma_next - sigma_hat) * (0.5 * d_cur + 0.5 * d_prime) + return x_next + + @torch.inference_mode() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + class_labels: Optional[Union[int, str, Sequence[Union[int, str]], torch.Tensor]] = None, + batch_size: int = 1, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 32, + guidance_scale: float = 1.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: str = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Generate class-conditional images with EDM2. + + Args: + class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.Tensor`, *optional*): + ImageNet class indices, English label strings, or one-hot float tensors. + Random classes are sampled when omitted on conditional models. + batch_size (`int`, defaults to `1`): + Number of images to generate. + height (`int`, *optional*): + Output height in pixels. Defaults to the pretrained native resolution. + width (`int`, *optional*): + Output width in pixels. Defaults to the pretrained native resolution. + num_inference_steps (`int`, defaults to `32`): + Number of EDM2 Heun steps (NVlabs default). + guidance_scale (`float`, defaults to `1.0`): + Autoguidance strength. Values above `1.0` blend the main net with `gnet` + via `gnet_output.lerp(unet_output, guidance_scale)`. + generator (`torch.Generator`, *optional*): + RNG for reproducibility. + output_type (`str`, defaults to `"pil"`): + `"pil"`, `"np"`, `"pt"`, or `"latent"`. + return_dict (`bool`, defaults to `True`): + Return [`~pipelines.pipeline_utils.ImagePipelineOutput`] if True. + + Examples: + + """ + default_size = self._default_image_size() + height = int(height or default_size) + width = int(width or default_size) + self.check_inputs(height, width, num_inference_steps, guidance_scale, output_type) + + device = self._execution_device + dtype = self.unet.dtype + labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device) + noise = self.prepare_latents(batch_size, height, width, dtype, device, generator) + + def denoise_fn(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + sigma_batch = sigma.reshape(1).expand(batch_size) + main = self.unet( + sample=x, + sigma=sigma_batch, + class_labels=labels, + force_fp32=True, + ).sample + if guidance_scale == 1.0 or self.gnet is None: + return main.to(torch.float32) + ref = self.gnet( + sample=x, + sigma=sigma_batch, + class_labels=labels, + force_fp32=True, + ).sample + return self._apply_autoguidance(main, ref, guidance_scale).to(torch.float32) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + latents = self._sample_edm2_heun( + denoise_fn=denoise_fn, + noise=noise, + sigmas=self.scheduler.sigmas.to(device), + generator=generator, + progress_bar=self.progress_bar, + dtype=torch.float32, + ) + + image = self.decode_latents(latents, output_type=output_type) + if not return_dict: + return (image, latents) + return ImagePipelineOutput(images=image) + + @classmethod + def _load_vae(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None): + vae_dir = os.path.join(pretrained_model_name_or_path, "vae") + if os.path.isdir(vae_dir): + try: + + return AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch_dtype) + except Exception: + return None + + vae_hint = os.path.join(pretrained_model_name_or_path, "vae_pretrained_model_name_or_path.txt") + if os.path.isfile(vae_hint): + with open(vae_hint, "r", encoding="utf-8") as f: + hub_id = f.read().strip() + if hub_id: + + return AutoencoderKL.from_pretrained(hub_id, torch_dtype=torch_dtype) + return None diff --git a/edm2-img512-xxl-fid/scheduler/scheduler_config.json b/edm2-img512-xxl-fid/scheduler/scheduler_config.json new file mode 100644 index 0000000000000000000000000000000000000000..08cbde97955d06d66b0b72be217db84bb797d711 --- /dev/null +++ b/edm2-img512-xxl-fid/scheduler/scheduler_config.json @@ -0,0 +1,11 @@ +{ + "_class_name": "EDMEulerScheduler", + "final_sigmas_type": "zero", + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "rho": 7.0, + "sigma_data": 0.5, + "sigma_max": 80.0, + "sigma_min": 0.002, + "sigma_schedule": "karras" +} diff --git a/edm2-img512-xxl-fid/unet/config.json b/edm2-img512-xxl-fid/unet/config.json new file mode 100644 index 0000000000000000000000000000000000000000..e62e0c056bed5b2ae37ccc58d5bb2a064116d52e --- /dev/null +++ b/edm2-img512-xxl-fid/unet/config.json @@ -0,0 +1,31 @@ +{ + "_class_name": "EDM2UNet2DModel", + "attn_balance": 0.3, + "attn_resolutions": [ + 16, + 8 + ], + "channel_mult": [ + 1, + 2, + 3, + 4 + ], + "channel_mult_emb": 4, + "channel_mult_noise": 1, + "channels_per_head": 64, + "clip_act": 256, + "concat_balance": 0.5, + "dropout": 0.0, + "in_channels": 4, + "label_balance": 0.5, + "logvar_channels": 128, + "model_channels": 448, + "num_blocks": 3, + "num_class_embeds": 1000, + "out_channels": 4, + "res_balance": 0.3, + "sample_size": 64, + "sigma_data": 0.5, + "use_fp16": true +} diff --git a/edm2-img512-xxl-fid/unet/diffusion_pytorch_model.safetensors b/edm2-img512-xxl-fid/unet/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..67ed03655a4a647f47d640e3e0a97ca403eb4bfc --- /dev/null +++ b/edm2-img512-xxl-fid/unet/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:018c4f1a23a207e787667d195a2df9522be8543590d72cd47f6420590700da2e +size 6092686516 diff --git a/edm2-img512-xxl-fid/unet/unet_edm2.py b/edm2-img512-xxl-fid/unet/unet_edm2.py new file mode 100644 index 0000000000000000000000000000000000000000..49d9451b443c4d6ad3b6ba2db19a7a804b07b6de --- /dev/null +++ b/edm2-img512-xxl-fid/unet/unet_edm2.py @@ -0,0 +1,434 @@ +import math +import json +import os +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import numpy as np +import torch + +try: + from diffusers.configuration_utils import ConfigMixin, register_to_config + from diffusers.models.modeling_utils import ModelMixin + from diffusers.utils import BaseOutput +except ImportError: # pragma: no cover + class ModelMixin(torch.nn.Module): + pass + + class ConfigMixin: + config = {} + + def register_to_config(self, **kwargs): + self.config = kwargs + + def register_to_config(func): + return func + + @dataclass + class BaseOutput: + pass + + +def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 1e-4) -> torch.Tensor: + if dim is None: + dim = list(range(1, x.ndim)) + norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) + norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel())) + return x / norm.to(x.dtype) + + +def resample(x: torch.Tensor, f: List[float], mode: str = "keep") -> torch.Tensor: + if mode == "keep": + return x + filt = np.float32(f) + pad = (len(filt) - 1) // 2 + filt = filt / filt.sum() + filt = np.outer(filt, filt)[np.newaxis, np.newaxis, :, :] + filt = torch.as_tensor(filt, dtype=x.dtype, device=x.device) + c = x.shape[1] + if mode == "down": + return torch.nn.functional.conv2d(x, filt.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,)) + return torch.nn.functional.conv_transpose2d(x, (filt * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,)) + + +def mp_silu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.silu(x) / 0.596 + + +def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor: + return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2) + + +def mp_cat(a: torch.Tensor, b: torch.Tensor, dim: int = 1, t: float = 0.5) -> torch.Tensor: + na = a.shape[dim] + nb = b.shape[dim] + c = math.sqrt((na + nb) / ((1 - t) ** 2 + t ** 2)) + wa = c / math.sqrt(na) * (1 - t) + wb = c / math.sqrt(nb) * t + return torch.cat([wa * a, wb * b], dim=dim) + + +class MPFourier(torch.nn.Module): + def __init__(self, num_channels: int, bandwidth: float = 1): + super().__init__() + self.register_buffer("freqs", 2 * math.pi * torch.randn(num_channels) * bandwidth) + self.register_buffer("phases", 2 * math.pi * torch.rand(num_channels)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = x.to(torch.float32).ger(self.freqs.to(torch.float32)) + y = y + self.phases.to(torch.float32) + y = y.cos() * math.sqrt(2) + return y.to(x.dtype) + + +class MPConv(torch.nn.Module): + def __init__(self, in_channels: int, out_channels: int, kernel: Tuple[int, ...]): + super().__init__() + self.out_channels = out_channels + self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel)) + + def forward(self, x: torch.Tensor, gain: float = 1) -> torch.Tensor: + w = self.weight.to(torch.float32) + if self.training: + with torch.no_grad(): + self.weight.copy_(normalize(w)) + w = normalize(w) + w = w * (gain / math.sqrt(w[0].numel())) + w = w.to(x.dtype) + if w.ndim == 2: + return x @ w.t() + return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1] // 2,)) + + +class Block(torch.nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + emb_channels: int, + flavor: str = "enc", + resample_mode: str = "keep", + resample_filter: List[float] = [1, 1], + attention: bool = False, + channels_per_head: int = 64, + dropout: float = 0.0, + res_balance: float = 0.3, + attn_balance: float = 0.3, + clip_act: Optional[float] = 256, + ): + super().__init__() + self.out_channels = out_channels + self.flavor = flavor + self.resample_filter = resample_filter + self.resample_mode = resample_mode + self.num_heads = out_channels // channels_per_head if attention else 0 + self.dropout = dropout + self.res_balance = res_balance + self.attn_balance = attn_balance + self.clip_act = clip_act + self.emb_gain = torch.nn.Parameter(torch.zeros([])) + self.conv_res0 = MPConv(out_channels if flavor == "enc" else in_channels, out_channels, kernel=(3, 3)) + self.emb_linear = MPConv(emb_channels, out_channels, kernel=()) + self.conv_res1 = MPConv(out_channels, out_channels, kernel=(3, 3)) + self.conv_skip = MPConv(in_channels, out_channels, kernel=(1, 1)) if in_channels != out_channels else None + self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=(1, 1)) if self.num_heads else None + self.attn_proj = MPConv(out_channels, out_channels, kernel=(1, 1)) if self.num_heads else None + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + x = resample(x, f=self.resample_filter, mode=self.resample_mode) + if self.flavor == "enc": + if self.conv_skip is not None: + x = self.conv_skip(x) + x = normalize(x, dim=[1]) + + y = self.conv_res0(mp_silu(x)) + c = self.emb_linear(emb, gain=self.emb_gain) + 1 + y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype)) + if self.training and self.dropout: + y = torch.nn.functional.dropout(y, p=self.dropout) + y = self.conv_res1(y) + + if self.flavor == "dec" and self.conv_skip is not None: + x = self.conv_skip(x) + x = mp_sum(x, y, t=self.res_balance) + + if self.num_heads: + y = self.attn_qkv(x) + y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3]) + q, k, v = normalize(y, dim=[2]).unbind(3) + w = torch.einsum("nhcq,nhck->nhqk", q, k / math.sqrt(q.shape[2])).softmax(dim=3) + y = torch.einsum("nhqk,nhck->nhcq", w, v) + y = self.attn_proj(y.reshape(*x.shape)) + x = mp_sum(x, y, t=self.attn_balance) + + if self.clip_act is not None: + x = x.clip_(-self.clip_act, self.clip_act) + return x + + +class EDM2UNet(torch.nn.Module): + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int, + model_channels: int = 192, + channel_mult: Tuple[int, ...] = (1, 2, 3, 4), + channel_mult_noise: Optional[int] = None, + channel_mult_emb: Optional[int] = None, + num_blocks: int = 3, + attn_resolutions: Tuple[int, ...] = (16, 8), + label_balance: float = 0.5, + concat_balance: float = 0.5, + **block_kwargs, + ): + super().__init__() + cblock = [model_channels * x for x in channel_mult] + cnoise = model_channels * channel_mult_noise if channel_mult_noise is not None else cblock[0] + cemb = model_channels * channel_mult_emb if channel_mult_emb is not None else max(cblock) + self.label_balance = label_balance + self.concat_balance = concat_balance + self.out_gain = torch.nn.Parameter(torch.zeros([])) + + self.emb_fourier = MPFourier(cnoise) + self.emb_noise = MPConv(cnoise, cemb, kernel=()) + self.emb_label = MPConv(label_dim, cemb, kernel=()) if label_dim else None + + self.enc = torch.nn.ModuleDict() + cout = img_channels + 1 + for level, channels in enumerate(cblock): + res = img_resolution >> level + if level == 0: + cin = cout + cout = channels + self.enc[f"{res}x{res}_conv"] = MPConv(cin, cout, kernel=(3, 3)) + else: + self.enc[f"{res}x{res}_down"] = Block(cout, cout, cemb, flavor="enc", resample_mode="down", **block_kwargs) + for idx in range(num_blocks): + cin = cout + cout = channels + self.enc[f"{res}x{res}_block{idx}"] = Block( + cin, + cout, + cemb, + flavor="enc", + attention=(res in attn_resolutions), + **block_kwargs, + ) + + self.dec = torch.nn.ModuleDict() + skips = [block.out_channels for block in self.enc.values()] + for level, channels in reversed(list(enumerate(cblock))): + res = img_resolution >> level + if level == len(cblock) - 1: + self.dec[f"{res}x{res}_in0"] = Block(cout, cout, cemb, flavor="dec", attention=True, **block_kwargs) + self.dec[f"{res}x{res}_in1"] = Block(cout, cout, cemb, flavor="dec", **block_kwargs) + else: + self.dec[f"{res}x{res}_up"] = Block(cout, cout, cemb, flavor="dec", resample_mode="up", **block_kwargs) + for idx in range(num_blocks + 1): + cin = cout + skips.pop() + cout = channels + self.dec[f"{res}x{res}_block{idx}"] = Block( + cin, + cout, + cemb, + flavor="dec", + attention=(res in attn_resolutions), + **block_kwargs, + ) + + self.out_conv = MPConv(cout, img_channels, kernel=(3, 3)) + + def forward(self, x: torch.Tensor, noise_labels: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor: + emb = self.emb_noise(self.emb_fourier(noise_labels)) + if self.emb_label is not None: + if class_labels is None: + raise ValueError("class_labels are required for conditional EDM2UNet.") + emb = mp_sum(emb, self.emb_label(class_labels * math.sqrt(class_labels.shape[1])), t=self.label_balance) + emb = mp_silu(emb) + + x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1) + skips = [] + for name, block in self.enc.items(): + x = block(x) if "conv" in name else block(x, emb) + skips.append(x) + + for name, block in self.dec.items(): + if "block" in name: + x = mp_cat(x, skips.pop(), t=self.concat_balance) + x = block(x, emb) + return self.out_conv(x, gain=self.out_gain) + + +@dataclass +class EDM2UNet2DOutput(BaseOutput): + sample: torch.Tensor + logvar: Optional[torch.Tensor] = None + + + +_CONFIG_KEYS = ( + "sample_size", + "in_channels", + "out_channels", + "num_class_embeds", + "use_fp16", + "sigma_data", + "logvar_channels", + "model_channels", + "channel_mult", + "channel_mult_noise", + "channel_mult_emb", + "num_blocks", + "attn_resolutions", + "label_balance", + "concat_balance", + "dropout", + "channels_per_head", + "res_balance", + "attn_balance", + "clip_act", +) + + +class EDM2UNet2DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + sample_size: int = 64, + in_channels: int = 4, + out_channels: int = 4, + num_class_embeds: int = 0, + use_fp16: bool = True, + sigma_data: float = 0.5, + logvar_channels: int = 128, + model_channels: int = 192, + channel_mult: Tuple[int, ...] = (1, 2, 3, 4), + channel_mult_noise: Optional[int] = None, + channel_mult_emb: Optional[int] = None, + num_blocks: int = 3, + attn_resolutions: Tuple[int, ...] = (16, 8), + label_balance: float = 0.5, + concat_balance: float = 0.5, + dropout: float = 0.0, + channels_per_head: int = 64, + res_balance: float = 0.3, + attn_balance: float = 0.3, + clip_act: Optional[float] = 256, + ): + super().__init__() + self.sample_size = sample_size + self.in_channels = in_channels + self.out_channels = out_channels + self.num_class_embeds = num_class_embeds + self.use_fp16 = use_fp16 + self.sigma_data = sigma_data + self.model_channels = model_channels + self.channel_mult = channel_mult + self.channel_mult_noise = channel_mult_noise + self.channel_mult_emb = channel_mult_emb + self.num_blocks = num_blocks + self.attn_resolutions = attn_resolutions + self.label_balance = label_balance + self.concat_balance = concat_balance + self.dropout = dropout + self.channels_per_head = channels_per_head + self.res_balance = res_balance + self.attn_balance = attn_balance + self.clip_act = clip_act + self.unet = EDM2UNet( + img_resolution=sample_size, + img_channels=in_channels, + label_dim=num_class_embeds, + model_channels=model_channels, + channel_mult=channel_mult, + channel_mult_noise=channel_mult_noise, + channel_mult_emb=channel_mult_emb, + num_blocks=num_blocks, + attn_resolutions=attn_resolutions, + label_balance=label_balance, + concat_balance=concat_balance, + dropout=dropout, + channels_per_head=channels_per_head, + res_balance=res_balance, + attn_balance=attn_balance, + clip_act=clip_act, + ) + self.logvar_fourier = MPFourier(logvar_channels) + self.logvar_linear = MPConv(logvar_channels, 1, kernel=()) + + def forward( + self, + sample: torch.Tensor, + sigma: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + force_fp32: bool = False, + return_logvar: bool = False, + return_dict: bool = True, + ) -> EDM2UNet2DOutput: + x = sample.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + if self.num_class_embeds == 0: + class_labels = None + else: + if class_labels is None: + class_labels = torch.zeros([x.shape[0], self.num_class_embeds], device=x.device) + class_labels = class_labels.to(torch.float32).reshape(-1, self.num_class_embeds) + dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") else torch.float32 + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.flatten().log() / 4 + + x_in = (c_in * x).to(dtype) + f_x = self.unet(x_in, c_noise, class_labels) + d_x = c_skip * x + c_out * f_x.to(torch.float32) + + logvar = None + if return_logvar: + logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1) + + if not return_dict: + return (d_x, logvar) + return EDM2UNet2DOutput(sample=d_x, logvar=logvar) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str, torch_dtype: Optional[torch.dtype] = None, **kwargs): + subfolder = kwargs.pop("subfolder", None) + model_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path + with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f: + config = json.load(f) + init_kwargs = {k: v for k, v in config.items() if k in _CONFIG_KEYS} + model = cls(**init_kwargs) + weight_file = os.path.join(model_dir, "diffusion_pytorch_model.safetensors") + if os.path.isfile(weight_file): + from safetensors.torch import load_file + + state_dict = load_file(weight_file) + else: + state_dict = torch.load(os.path.join(model_dir, "diffusion_pytorch_model.bin"), map_location="cpu") + model.load_state_dict(state_dict, strict=True) + if torch_dtype is not None: + model = model.to(dtype=torch_dtype) + return model + + def save_pretrained(self, save_directory: str, safe_serialization: bool = True): + os.makedirs(save_directory, exist_ok=True) + stored = dict(getattr(self, "config", {})) + config = {"_class_name": self.__class__.__name__} + for key in _CONFIG_KEYS: + if key in stored: + config[key] = stored[key] + elif hasattr(self, key): + config[key] = getattr(self, key) + with open(os.path.join(save_directory, "config.json"), "w", encoding="utf-8") as f: + json.dump(config, f, indent=2, sort_keys=True) + f.write("\n") + state_dict = self.state_dict() + if safe_serialization: + from safetensors.torch import save_file + + save_file(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.safetensors")) + else: + torch.save(state_dict, os.path.join(save_directory, "diffusion_pytorch_model.bin")) diff --git a/edm2-img512-xxl-fid/vae/config.json b/edm2-img512-xxl-fid/vae/config.json new file mode 100644 index 0000000000000000000000000000000000000000..ae2ab37de2acb8bc348d3de809ae8385324be962 --- /dev/null +++ b/edm2-img512-xxl-fid/vae/config.json @@ -0,0 +1,38 @@ +{ + "_class_name": "AutoencoderKL", + "_diffusers_version": "0.36.0", + "_name_or_path": "stabilityai/sd-vae-ft-mse", + "act_fn": "silu", + "block_out_channels": [ + 128, + 256, + 512, + 512 + ], + "down_block_types": [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D" + ], + "force_upcast": true, + "in_channels": 3, + "latent_channels": 4, + "latents_mean": null, + "latents_std": null, + "layers_per_block": 2, + "mid_block_add_attention": true, + "norm_num_groups": 32, + "out_channels": 3, + "sample_size": 256, + "scaling_factor": 0.18215, + "shift_factor": null, + "up_block_types": [ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D" + ], + "use_post_quant_conv": true, + "use_quant_conv": true +} diff --git a/edm2-img512-xxl-fid/vae/diffusion_pytorch_model.safetensors b/edm2-img512-xxl-fid/vae/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..90464d67ac7303d0ee4696334df13da130a948ea --- /dev/null +++ b/edm2-img512-xxl-fid/vae/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815 +size 334643276