# ADM-diffusers / ADM-G-256 / pipeline.py
# Last commit by BiliSakura: cc5e6bd (verified)
#   "Fix generator determinism: forward generator through scheduler steps and seeded noise"
"""Hub custom pipeline: ADMPipeline.
Load with native Hugging Face diffusers and trust_remote_code=True.
"""
from __future__ import annotations
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from diffusers.utils import replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
# Usage example injected into ``ADMPipeline.__call__``'s docstring by the
# ``@replace_example_docstring(EXAMPLE_DOC_STRING)`` decorator applied to ``__call__``.
# This text becomes part of the runtime docstring, so edit it as user-facing documentation.
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from pathlib import Path
>>> import torch
>>> from diffusers import DiffusionPipeline
>>> model_dir = Path("path/to/BiliSakura/ADM-diffusers/ADM-G-256")
>>> pipe = DiffusionPipeline.from_pretrained(
... str(model_dir),
... local_files_only=True,
... custom_pipeline=str(model_dir / "pipeline.py"),
... trust_remote_code=True,
... torch_dtype=torch.bfloat16,
... )
>>> pipe = pipe.to("cuda")
>>> class_id = pipe.get_label_ids("golden retriever")[0]
>>> image = pipe(class_labels=class_id, guidance_scale=1.0).images[0]
```
"""
class ADMPipeline(DiffusionPipeline):
    r"""ADM/ADM-G pipeline compatible with Diffusers custom pipeline loading.

    Components:
        unet: ADM denoiser. Its ``config`` is read for ``image_size`` (default 256),
            ``in_channels`` (default 3) and ``class_cond`` (default False). When its output
            has ``2 * in_channels`` channels, the second half is treated as the predicted
            variance.
        scheduler: Diffusers scheduler driving the reverse process. Classifier guidance is
            supported for DDIM-like schedulers (``step`` accepts ``eta``) and for schedulers
            exposing ``_get_variance`` (e.g. ``DDPMScheduler``).
        classifier: Optional noisy-image classifier for ADM-G guidance. It must provide
            ``guidance_gradient(latents, timesteps, class_labels, classifier_scale=...)``
            (assumed to return a scaled gradient broadcastable to ``latents``).
        id2label: Optional mapping from class id to a comma-separated list of label names.
        null_class_id: Stored in the pipeline config; not otherwise used by this class.
    """

    model_cpu_offload_seq = "classifier->unet"
    _optional_components = ["classifier"]

    def __init__(
        self,
        unet,
        scheduler,
        classifier: Optional[Any] = None,
        id2label: Optional[Dict[str, str]] = None,
        null_class_id: int = 1000,
    ) -> None:
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler, classifier=classifier)
        self.register_to_config(null_class_id=int(null_class_id))
        # Outputs are mapped to [0, 1] in __call__, so no further normalization is applied.
        self.image_processor = VaeImageProcessor(vae_scale_factor=1, do_normalize=False)
        self._id2label = {int(k): v for k, v in (id2label or {}).items()}
        self.labels = self._build_label2id(self._id2label)

    @staticmethod
    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
        """Invert ``id2label``: every comma-separated synonym maps to its class id (sorted by name)."""
        label2id: Dict[str, int] = {}
        for class_id, value in id2label.items():
            for synonym in value.split(","):
                synonym = synonym.strip()
                if synonym:
                    label2id[synonym] = int(class_id)
        return dict(sorted(label2id.items()))

    @property
    def id2label(self) -> Dict[int, str]:
        """Mapping from integer class id to its raw (comma-separated) label string."""
        return self._id2label

    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
        """Return the class ids for one label name or a list of names.

        Raises:
            ValueError: If the checkpoint has no ``id2label`` mapping or a name is unknown.
        """
        if not self.labels:
            raise ValueError("No id2label mapping is available in this checkpoint.")
        labels = [label] if isinstance(label, str) else label
        missing = [item for item in labels if item not in self.labels]
        if missing:
            preview = ", ".join(list(self.labels.keys())[:8])
            raise ValueError(f"Unknown labels: {missing}. Example valid labels: {preview}, ...")
        return [self.labels[item] for item in labels]

    @staticmethod
    def prepare_extra_step_kwargs(
        scheduler,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]],
        eta: float,
    ) -> Dict[str, Any]:
        """Collect ``eta``/``generator`` kwargs, keeping only those ``scheduler.step`` accepts."""
        kwargs: Dict[str, Any] = {}
        step_params = set(inspect.signature(scheduler.step).parameters.keys())
        if "eta" in step_params:
            kwargs["eta"] = eta
        if "generator" in step_params:
            kwargs["generator"] = generator
        return kwargs

    @staticmethod
    def _is_ddim_like(step_params: Set[str]) -> bool:
        """Treat a scheduler as DDIM-like when its ``step`` accepts ``eta``."""
        return "eta" in step_params

    @staticmethod
    def _prepare_model_output_for_scheduler(
        model_output: torch.Tensor,
        channels: int,
        scheduler,
    ) -> torch.Tensor:
        """Drop the predicted-variance half of ``model_output`` unless the scheduler consumes it.

        Only ``DDPMScheduler`` with a learned ``variance_type`` receives the full
        ``2 * channels`` output; every other scheduler gets the first ``channels`` channels.
        """
        if model_output.shape[1] != 2 * channels:
            return model_output
        variance_type = getattr(scheduler.config, "variance_type", None)
        if scheduler.__class__.__name__ == "DDPMScheduler" and variance_type in ("learned", "learned_range"):
            return model_output
        model_output, _ = torch.split(model_output, channels, dim=1)
        return model_output

    @staticmethod
    def _expand_timestep(timestep, batch: int, device: torch.device) -> torch.Tensor:
        """Return ``timestep`` as a 1-D tensor of length ``batch`` on ``device``.

        Python scalars become ``long`` tensors. Tensors are moved to ``device`` and flattened
        so that 0-dim, ``(1,)`` and ``(batch,)`` inputs all expand correctly.
        """
        if not torch.is_tensor(timestep):
            timestep = torch.tensor([timestep], dtype=torch.long, device=device)
        timestep = timestep.to(device=device).reshape(-1)
        return timestep.expand(batch)

    def _resolve_class_labels(self, class_labels):
        """Translate label names into class ids; ids and tensors pass through unchanged.

        A list/tuple may mix ids and names; all names are resolved in one ``get_label_ids``
        call so an error reports every unknown name at once. Tuples are returned as lists.
        """
        if isinstance(class_labels, str):
            return self.get_label_ids(class_labels)[0]
        if isinstance(class_labels, (list, tuple)):
            names = [item for item in class_labels if isinstance(item, str)]
            name_ids = iter(self.get_label_ids(names)) if names else iter(())
            return [next(name_ids) if isinstance(item, str) else item for item in class_labels]
        return class_labels

    @staticmethod
    def _prepare_class_tensor(
        class_labels, batch_size: int, device: torch.device
    ) -> Tuple[Optional[torch.Tensor], int]:
        """Build the flat ``long`` tensor of class ids and the batch size it implies.

        A single id (Python int or 0-dim tensor) is repeated ``batch_size`` times. A list or
        1-D tensor defines the batch itself. With no labels, ``batch_size`` is used as given.

        Raises:
            ValueError: If the flattened labels do not match the batch, or the batch is empty.
        """
        class_tensor = None
        if class_labels is not None:
            if isinstance(class_labels, int) or (torch.is_tensor(class_labels) and class_labels.ndim == 0):
                # One label, many samples: broadcast instead of silently ignoring batch_size.
                class_labels = [int(class_labels)] * int(batch_size)
            if torch.is_tensor(class_labels):
                batch_size = int(class_labels.shape[0])
                class_tensor = class_labels
            else:
                batch_size = len(class_labels)
                class_tensor = torch.tensor(class_labels, dtype=torch.long)
            class_tensor = class_tensor.to(device=device, dtype=torch.long).reshape(-1)
            if class_tensor.shape[0] != batch_size:
                raise ValueError("class_labels batch must match requested batch_size")
        if batch_size < 1:
            raise ValueError(f"Cannot sample an empty batch (batch_size={batch_size}).")
        return class_tensor, int(batch_size)

    @staticmethod
    def _guidance_variance(scheduler, timestep, model_output: torch.Tensor, channels: int) -> torch.Tensor:
        """Return the reverse-step variance Sigma_t used to shift the mean for ADM-G guidance.

        ``scheduler._get_variance`` returns different quantities depending on
        ``variance_type``; this normalizes them all to a variance.

        Raises:
            ValueError: If a learned ``variance_type`` is configured but the UNet output has
                no variance channels.
        """
        pred_var = None
        if model_output.shape[1] == 2 * channels:
            _, pred_var = torch.split(model_output, channels, dim=1)
        variance_type = getattr(scheduler.config, "variance_type", None)
        if variance_type in ("learned", "learned_range") and pred_var is None:
            raise ValueError(
                f"Scheduler variance_type={variance_type!r} requires a UNet that predicts variance "
                f"({2 * channels} output channels), got {model_output.shape[1]}."
            )
        value = scheduler._get_variance(int(timestep), predicted_variance=pred_var)
        if variance_type in ("learned_range", "fixed_large_log"):
            # DDPMScheduler returns a log-variance for these types.
            return torch.exp(value)
        if variance_type == "fixed_small_log":
            # DDPMScheduler returns a standard deviation for this type.
            return value**2
        return value

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        class_labels: Optional[Union[int, str, List[Union[int, str]], torch.Tensor]] = None,
        batch_size: int = 1,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 250,
        guidance_scale: float = 1.0,
        classifier_guidance_scale: float = 0.0,
        eta: float = 0.0,
        clip_denoised: bool = True,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: str = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Generate samples from the ADM/ADM-G checkpoint.

        Args:
            class_labels: Class id(s) or label name(s). A single id/name (or 0-dim tensor) is
                repeated ``batch_size`` times; a list, tuple or 1-D tensor defines the batch.
                Lists may mix ids and names. Required when the UNet is class-conditional or
                classifier guidance is active.
            batch_size: Number of samples when ``class_labels`` is ``None`` or a single label.
            height: Output height in pixels (multiple of 8); defaults to ``unet.config.image_size``.
            width: Output width in pixels (multiple of 8); defaults to ``unet.config.image_size``.
            num_inference_steps: Number of scheduler steps.
            guidance_scale: Classifier-guidance strength, used when ``classifier_guidance_scale``
                is not positive; ``0`` disables guidance. This checkpoint does not use
                classifier-free guidance.
            classifier_guidance_scale: Overrides ``guidance_scale`` when positive.
            eta: Passed to ``scheduler.step`` only if its signature accepts ``eta``.
            clip_denoised: Accepted for API compatibility but not read by this pipeline; any
                clipping of the predicted sample is left to the scheduler's configuration.
            generator: Generator(s) for the initial noise, also forwarded to ``scheduler.step``
                when it accepts ``generator``.
            latents: Optional initial noise of shape ``(batch, in_channels, height, width)``;
                it is scaled by ``scheduler.init_noise_sigma``.
            output_type: ``"pil"``, ``"np"``, ``"pt"`` (tensor in ``[0, 1]``) or ``"latent"``
                (raw final sample).
            return_dict: Return an ``ImagePipelineOutput`` instead of a tuple.

        Examples:
        <!-- this section is replaced by replace_example_docstring -->

        Returns:
            ``ImagePipelineOutput`` when ``return_dict`` is true, otherwise ``(images,)``.

        Raises:
            ValueError: On invalid sizes, output type, labels or batch, or an unsupported
                guidance setup (missing classifier/labels, incompatible scheduler).
        """
        # Stage 1: check inputs (all validation happens before any model work)
        class_labels = self._resolve_class_labels(class_labels)
        native_size = int(getattr(self.unet.config, "image_size", 256))
        height = native_size if height is None else int(height)
        width = native_size if width is None else int(width)
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"height and width must be divisible by 8, got ({height}, {width}).")
        if output_type not in {"pil", "np", "pt", "latent"}:
            raise ValueError(f"Unsupported output_type: {output_type}")
        # This checkpoint does not use classifier-free guidance (CFG).
        # Keep classifier_guidance_scale for compatibility, but treat guidance_scale
        # as the primary classifier-guidance strength.
        effective_classifier_guidance_scale = (
            float(classifier_guidance_scale) if classifier_guidance_scale > 0 else float(guidance_scale)
        )
        use_guidance = effective_classifier_guidance_scale > 0
        class_cond = bool(getattr(self.unet.config, "class_cond", False))
        if class_labels is None and (class_cond or use_guidance):
            raise ValueError("class_labels are required for class-conditional sampling and ADM-G guidance.")
        if use_guidance and self.classifier is None:
            raise ValueError("guidance_scale requires both classifier and class_labels.")
        scheduler = self.scheduler
        step_params = set(inspect.signature(scheduler.step).parameters.keys())
        ddim_like = self._is_ddim_like(step_params)
        if use_guidance and not ddim_like and not hasattr(scheduler, "_get_variance"):
            raise ValueError(
                "guidance_scale is not supported for the current scheduler. "
                "Use a DDPM/DDIM-compatible scheduler or disable classifier guidance."
            )

        # Stage 2: define call parameters
        device = self._execution_device
        channels = int(getattr(self.unet.config, "in_channels", 3))
        dtype = self.unet.dtype

        # Stage 3: prepare class conditioning (this also fixes the effective batch size)
        class_tensor, batch_size = self._prepare_class_tensor(class_labels, batch_size, device)
        class_input = class_tensor if class_cond else None

        # Stage 4: prepare timesteps
        scheduler.set_timesteps(num_inference_steps, device=device)

        # Stage 5: prepare latent variables
        shape = (batch_size, channels, height, width)
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if tuple(latents.shape) != shape:
                raise ValueError(f"Unexpected latents shape {tuple(latents.shape)}; expected {shape}.")
            latents = latents.to(device=device, dtype=dtype)
        latents = latents * scheduler.init_noise_sigma

        # Stage 6: prepare extra step kwargs
        extra_step_kwargs = self.prepare_extra_step_kwargs(scheduler, generator, eta)

        # Stage 7: denoising loop
        for timestep in self.progress_bar(scheduler.timesteps):
            model_input = scheduler.scale_model_input(latents, timestep)
            timestep_input = self._expand_timestep(timestep, model_input.shape[0], model_input.device)
            model_output = self.unet(model_input, timestep_input, class_labels=class_input, return_dict=True).sample

            step_model_output = model_output
            mean_shift = None
            if use_guidance:
                grad_t = self._expand_timestep(timestep, batch_size, latents.device)
                cond_grad = self.classifier.guidance_gradient(
                    latents, grad_t, class_tensor, classifier_scale=effective_classifier_guidance_scale
                )
                if cond_grad is not None and ddim_like:
                    # Score conditioning: eps <- eps - sqrt(1 - alpha_bar_t) * grad.
                    eps = model_output[:, :channels] if model_output.shape[1] == 2 * channels else model_output
                    alpha_bar_t = scheduler.alphas_cumprod[int(timestep)].to(
                        device=latents.device, dtype=latents.dtype
                    )
                    step_model_output = eps - (1 - alpha_bar_t).sqrt() * cond_grad
                elif cond_grad is not None:
                    # Mean conditioning: mu <- mu + Sigma_t * grad. It is applied to prev_sample
                    # after the step: the step noise is additive, so this is exactly a mean shift.
                    # Shifting x_t before the step would instead rescale the shift and let the
                    # scheduler's x0 clipping discard part of it.
                    variance = self._guidance_variance(scheduler, timestep, model_output, channels)
                    mean_shift = variance * cond_grad

            step_model_output = self._prepare_model_output_for_scheduler(step_model_output, channels, scheduler)
            latents = scheduler.step(step_model_output, timestep, latents, return_dict=True, **extra_step_kwargs).prev_sample
            if mean_shift is not None:
                # Keep latents in the UNet dtype for the next forward pass.
                latents = latents + mean_shift.to(dtype=latents.dtype)

        # Stage 8: post-process; samples are assumed to live in [-1, 1].
        image = latents if output_type == "latent" else (latents / 2 + 0.5).clamp(0, 1)
        if output_type in {"pil", "np"}:
            image = self.image_processor.postprocess(image, output_type=output_type)
        self.maybe_free_model_hooks()
        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)