# ADM-diffusers / ADM-G-512 / classifier / classifier_adm.py
# Uploaded by BiliSakura using huggingface_hub (commit 7fc7e34, verified).
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
import math
from abc import abstractmethod
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
# Default size of the classifier's output head: the number of ImageNet classes.
NUM_CLASSES: int = 1_000
def conv_nd(dims: int, *args, **kwargs):
    """Build a 1D, 2D or 3D convolution layer.

    ``dims`` chooses between ``nn.Conv1d``, ``nn.Conv2d`` and ``nn.Conv3d``; all
    remaining positional and keyword arguments go to that constructor unchanged.

    Raises:
        ValueError: if ``dims`` is not 1, 2 or 3.
    """
    for rank, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
    """Return an ``nn.Linear`` built from the forwarded constructor arguments."""
    layer = nn.Linear(*args, **kwargs)
    return layer
def avg_pool_nd(dims: int, *args, **kwargs):
    """Build a 1D, 2D or 3D average-pooling layer.

    ``dims`` chooses between ``nn.AvgPool1d``, ``nn.AvgPool2d`` and ``nn.AvgPool3d``;
    all remaining arguments are forwarded to that constructor unchanged.

    Raises:
        ValueError: if ``dims`` is not 1, 2 or 3.
    """
    for rank, pool_cls in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
class GroupNorm32(nn.GroupNorm):
    """``nn.GroupNorm`` that always performs the normalization in float32.

    The input and the affine weight/bias (when present) are upcast to float32,
    normalized, and the result is cast back to the input's original dtype.
    """

    def forward(self, x):
        def _as_f32(t):
            return None if t is None else t.float()

        normed = F.group_norm(
            x.float(),
            self.num_groups,
            _as_f32(self.weight),
            _as_f32(self.bias),
            self.eps,
        )
        return normed.to(dtype=x.dtype)
def normalization(channels: int):
    """Create the standard norm layer: a float32 ``GroupNorm32`` with 32 groups over ``channels``."""
    num_groups = 32
    return GroupNorm32(num_groups, channels)
def zero_module(module: nn.Module):
    """Set every parameter of ``module`` to zero in place and return the same module.

    Applied in this file to the last conv of residual branches, attention output
    projections and the adaptive-pool classifier head.
    """
    params = module.parameters()
    for tensor in params:
        tensor.detach().zero_()
    return module
def convert_module_to_f16(module: nn.Module):
    """Cast the weight and bias of a 1D/2D/3D conv layer to float16 in place.

    Modules of any other type are left untouched.
    """
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
    if not isinstance(module, conv_types):
        return
    module.weight.data = module.weight.data.half()
    if module.bias is not None:
        module.bias.data = module.bias.data.half()
def convert_module_to_f32(module: nn.Module):
    """Cast the weight and bias of a 1D/2D/3D conv layer back to float32 in place.

    Modules of any other type are left untouched.
    """
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
    if not isinstance(module, conv_types):
        return
    module.weight.data = module.weight.data.float()
    if module.bias is not None:
        module.bias.data = module.bias.data.float()
class TimestepBlock(nn.Module):
    """Base class for layers whose ``forward`` also receives a timestep embedding.

    ``TimestepEmbedSequential`` checks for this type to decide whether a child
    is called as ``layer(x, emb)`` or as ``layer(x)``.
    """

    @abstractmethod
    def forward(self, x, emb):
        """Apply the layer to ``x`` conditioned on ``emb``; subclasses must override this."""
        raise NotImplementedError
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """``nn.Sequential`` that passes the timestep embedding to children that accept it.

    Children that are ``TimestepBlock`` instances are called as ``layer(x, emb)``;
    every other child is called as ``layer(x)``.
    """

    def forward(self, x, emb):
        h = x
        for child in self:
            h = child(h, emb) if isinstance(child, TimestepBlock) else child(h)
        return h
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling, optionally followed by a 3x3 convolution.

    With ``dims == 3`` only the last two axes are doubled and the third-from-last
    axis keeps its size; otherwise every spatial axis is scaled by 2.

    Args:
        channels: expected number of input channels (asserted in ``forward``).
        use_conv: apply a 3x3 conv (padding 1) after upsampling.
        dims: signal dimensionality passed to ``conv_nd``.
        out_channels: output channels of the conv; falls back to ``channels``.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            target_size = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            upsampled = F.interpolate(x, target_size, mode="nearest")
        else:
            upsampled = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(upsampled) if self.use_conv else upsampled
class Downsample(nn.Module):
    """2x downsampling with either a strided 3x3 conv or average pooling.

    With ``dims == 3`` the stride is ``(1, 2, 2)``, so only the last two axes are
    halved; otherwise every spatial axis is halved.

    Args:
        channels: expected number of input channels (asserted in ``forward``).
        use_conv: use a strided conv; otherwise average pooling, which requires
            ``out_channels == channels``.
        dims: signal dimensionality.
        out_channels: output channels of the conv; falls back to ``channels``.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        stride = (1, 2, 2) if dims == 3 else 2
        if not use_conv:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
class ResBlock(TimestepBlock):
    """Residual block conditioned on a timestep embedding, optionally resampling by 2x.

    Args:
        channels: number of input channels.
        emb_channels: size of the last axis of the embedding ``emb``.
        dropout: dropout probability applied before the final convolution.
        out_channels: number of output channels; falls back to ``channels``.
        use_conv: when ``out_channels != channels``, use a 3x3 conv on the skip path
            instead of a 1x1 conv.
        use_scale_shift_norm: condition as ``norm(h) * (1 + scale) + shift`` with scale
            and shift predicted from ``emb``; otherwise the projected ``emb`` is added
            to ``h`` before the output layers.
        dims: signal dimensionality passed to ``conv_nd``.
        use_checkpoint: use activation checkpointing when ``x.requires_grad``.
        up: upsample both the residual and skip paths (nearest neighbour, no conv).
        down: downsample both paths (average pooling, no conv).
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm
        # norm -> SiLU -> 3x3 conv; the conv changes channels -> out_channels.
        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )
        self.updown = up or down
        # `up` takes precedence when both flags are set; the resamplers have no conv.
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()
        # Projects emb to one value per output channel, or to two per channel
        # (scale and shift) when use_scale_shift_norm is set.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(emb_channels, 2 * self.out_channels if use_scale_shift_norm else self.out_channels),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # Zeroed at construction, so the residual branch starts out producing zeros.
            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
        )
        # Skip path: identity when channel counts match, otherwise a 3x3 (use_conv)
        # or 1x1 convolution to reach out_channels.
        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """Return ``skip_connection(x) + residual(x, emb)``, checkpointed when enabled."""
        # Checkpointing is only used when x participates in autograd.
        if self.use_checkpoint and x.requires_grad:
            return torch_checkpoint(self._forward, x, emb, use_reentrant=False)
        return self._forward(x, emb)

    def _forward(self, x, emb):
        if self.updown:
            # Resample after norm + SiLU but before the conv, so the conv runs at the
            # new resolution; the skip input is resampled the same way.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Append trailing singleton axes so emb_out broadcasts over h's spatial axes.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            # Normalize first, then modulate with (1 + scale) and shift, then run the
            # remaining output layers (SiLU, dropout, conv).
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
class QKVAttentionLegacy(nn.Module):
    """Multi-head QKV attention using the legacy head-major channel layout.

    The fused ``qkv`` tensor of shape ``(batch, heads * 3 * ch, length)`` is split
    into heads first, so each head's channels are laid out as ``[q | k | v]``.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """Attend over ``length``; returns a tensor of shape ``(batch, heads * ch, length)``."""
        batch, width, length = qkv.shape
        heads = self.n_heads
        assert width % (3 * heads) == 0
        head_dim = width // (3 * heads)
        per_head = qkv.reshape(batch * heads, head_dim * 3, length)
        q, k, v = per_head.split(head_dim, dim=1)
        # The 1/sqrt(head_dim) scaling is split evenly between q and k
        # (head_dim ** -0.25 each) instead of being applied once to the product.
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        logits = torch.einsum("bct,bcs->bts", q * scale, k * scale)
        # Softmax in float32, then back to the working dtype.
        probs = torch.softmax(logits.float(), dim=-1).type(logits.dtype)
        attended = torch.einsum("bts,bcs->bct", probs, v)
        return attended.reshape(batch, -1, length)
class QKVAttention(nn.Module):
    """Multi-head QKV attention that splits q, k and v before splitting heads.

    ``qkv`` has shape ``(batch, 3 * heads * ch, length)`` laid out as ``[q | k | v]``
    along the channel axis; each of q, k and v is then divided into ``heads`` groups.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """Attend over ``length``; returns a tensor of shape ``(batch, heads * ch, length)``."""
        batch, width, length = qkv.shape
        heads = self.n_heads
        assert width % (3 * heads) == 0
        head_dim = width // (3 * heads)
        q, k, v = qkv.chunk(3, dim=1)
        # The 1/sqrt(head_dim) scaling is split evenly between q and k.
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        folded = (batch * heads, head_dim, length)
        q_heads = (q * scale).view(*folded)
        k_heads = (k * scale).view(*folded)
        logits = torch.einsum("bct,bcs->bts", q_heads, k_heads)
        # Softmax in float32, then back to the working dtype.
        probs = torch.softmax(logits.float(), dim=-1).type(logits.dtype)
        attended = torch.einsum("bts,bcs->bct", probs, v.reshape(*folded))
        return attended.reshape(batch, -1, length)
class AttentionBlock(nn.Module):
    """Self-attention across all spatial positions, added back to the input.

    Args:
        channels: number of input and output channels.
        num_heads: number of heads, used only when ``num_head_channels == -1``.
        num_head_channels: channels per head; when not -1 it determines the head
            count as ``channels // num_head_channels`` (must divide ``channels``).
        use_checkpoint: use activation checkpointing when ``x.requires_grad``.
        use_new_attention_order: use ``QKVAttention`` instead of ``QKVAttentionLegacy``.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        if num_head_channels != -1:
            assert channels % num_head_channels == 0
            self.num_heads = channels // num_head_channels
        else:
            self.num_heads = num_heads
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        attention_cls = QKVAttention if use_new_attention_order else QKVAttentionLegacy
        self.attention = attention_cls(self.num_heads)
        # Zeroed output projection: at construction the block returns its input unchanged.
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        if not (self.use_checkpoint and x.requires_grad):
            return self._forward(x)
        return torch_checkpoint(self._forward, x, use_reentrant=False)

    def _forward(self, x):
        batch, channels, *spatial = x.shape
        # Flatten all spatial axes into one token axis.
        tokens = x.reshape(batch, channels, -1)
        attended = self.attention(self.qkv(self.norm(tokens)))
        return (tokens + self.proj_out(attended)).reshape(batch, channels, *spatial)
class AttentionPool2d(nn.Module):
    """CLIP-style attention pooling used by ADM noisy classifiers.

    A mean-pooled token is prepended to the flattened spatial tokens, a learned
    positional embedding is added, multi-head attention runs over all tokens, and
    the projected output at the mean token (index 0) is returned.

    Args:
        spacial_dim: the flattened input must have ``spacial_dim ** 2`` positions; the
            positional embedding covers those plus the prepended mean token.
        embed_dim: number of input channels.
        num_heads_channels: channels per attention head.
        output_dim: output channels of the final projection; falls back to ``embed_dim``.
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: Optional[int] = None,
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        """Pool ``x`` of shape ``(batch, embed_dim, *spatial)`` to ``(batch, output_dim or embed_dim)``."""
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)
        # Prepend the mean over positions as an extra token.
        x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        # The pooled representation is the output at the mean token.
        return x[:, :, 0]
class EncoderUNetModel(nn.Module):
    """Noisy image classifier backbone for ADM-G (classifier guidance).

    The encoder half of an ADM UNet (input blocks + middle block) followed by a
    pooling head chosen by ``pool``.

    Args:
        image_size: input resolution; sizes the attention-pool positional embedding.
        in_channels: channels of the input images.
        model_channels: base channel count; the time embedding has
            ``4 * model_channels`` features.
        out_channels: number of output logits.
        num_res_blocks: residual blocks per resolution level.
        attention_resolutions: downsampling factors (``ds``) at which attention
            blocks are inserted.
        dropout: dropout probability inside residual blocks.
        channel_mult: per-level multipliers of ``model_channels``; the feature map is
            downsampled by 2 between consecutive levels.
        conv_resample: use strided convs instead of average pooling when
            ``resblock_updown`` is False.
        dims: signal dimensionality.
        use_checkpoint: activation checkpointing in residual/attention blocks.
        use_fp16: only sets ``self.dtype``; conv weights are cast by ``convert_to_fp16()``.
        num_heads: attention heads, used when ``num_head_channels == -1``.
        num_head_channels: channels per attention head (required for ``pool="attention"``).
        use_scale_shift_norm: scale/shift conditioning in residual blocks.
        resblock_updown: downsample with residual blocks instead of ``Downsample``.
        use_new_attention_order: use ``QKVAttention`` instead of the legacy layout.
        pool: one of ``"adaptive"``, ``"attention"``, ``"spatial"``, ``"spatial_v2"``.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        pool="adaptive",
    ):
        super().__init__()
        self.model_channels = model_channels
        self.use_checkpoint = use_checkpoint
        self.dtype = torch.float16 if use_fp16 else torch.float32
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )
        ch = int(channel_mult[0] * model_channels)
        self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))])
        # Sum of channel counts of every block output; sizes the "spatial" pooling heads.
        self._feature_size = ch
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=int(mult * model_channels),
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(mult * model_channels)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
            if level != len(channel_mult) - 1:
                # Downsample between levels (not after the last one).
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                )
                ch = out_ch
                ds *= 2
                self._feature_size += ch
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch
        self.pool = pool
        if pool == "adaptive":
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                zero_module(conv_nd(dims, ch, out_channels, 1)),
                nn.Flatten(),
            )
        elif pool == "attention":
            assert num_head_channels != -1
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                # image_size // ds is the side length of the final feature map.
                AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels),
            )
        elif pool == "spatial":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                nn.ReLU(),
                nn.Linear(2048, out_channels),
            )
        elif pool == "spatial_v2":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                normalization(2048),
                nn.SiLU(),
                nn.Linear(2048, out_channels),
            )
        else:
            raise NotImplementedError(f"Unexpected {pool} pooling")

    def convert_to_fp16(self):
        """Cast the conv layers of the input and middle blocks to float16 (the head is untouched)."""
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """Cast the conv layers of the input and middle blocks back to float32."""
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(self, x, timesteps):
        """Return class logits of shape ``(batch, out_channels)``.

        Args:
            x: input batch ``(batch, in_channels, *spatial)``.
            timesteps: 1-D tensor of diffusion timesteps, one per batch element.
        """
        head_dtype = self.time_embed[0].weight.dtype
        # Match OpenAI ADM's `timestep_embedding`: cos terms first and frequencies
        # exp(-ln(10000) * i / half). The diffusers defaults (sin first, i / (half - 1))
        # would feed pretrained ADM weights a different embedding.
        t_emb = get_timestep_embedding(
            timesteps,
            self.model_channels,
            flip_sin_to_cos=True,
            downscale_freq_shift=0,
        )
        emb = self.time_embed(t_emb.to(dtype=head_dtype))
        # Feed the trunk in the dtype of its first conv. `convert_to_fp16()` casts only
        # convs, so the time-embedding Linear's dtype would not match them.
        h = x.to(dtype=self.input_blocks[0][0].weight.dtype)
        collect_spatial = self.pool.startswith("spatial")
        results = []
        for module in self.input_blocks:
            h = module(h, emb)
            if collect_spatial:
                # Mean over axes 2 and 3: spatial pooling assumes 2-D feature maps.
                results.append(h.to(dtype=head_dtype).mean(dim=(2, 3)))
        h = self.middle_block(h, emb)
        if collect_spatial:
            results.append(h.to(dtype=head_dtype).mean(dim=(2, 3)))
            return self.out(torch.cat(results, dim=-1))
        return self.out(h.to(dtype=head_dtype))
def _default_channel_mult(image_size: int):
    """Return the ADM classifier's per-level channel multipliers for ``image_size``.

    Raises:
        ValueError: for any size other than 512, 256, 128 or 64.
    """
    presets = (
        (512, (0.5, 1, 1, 2, 2, 4, 4)),
        (256, (1, 1, 2, 2, 4, 4)),
        (128, (1, 1, 2, 3, 4)),
        (64, (1, 2, 3, 4)),
    )
    for size, mult in presets:
        if image_size == size:
            return mult
    raise ValueError(f"unsupported image size: {image_size}")
def create_adm_classifier_model(
    image_size: int,
    classifier_width: int = 128,
    classifier_depth: int = 2,
    classifier_attention_resolutions: str = "32,16,8",
    classifier_use_scale_shift_norm: bool = True,
    classifier_resblock_updown: bool = True,
    classifier_pool: str = "attention",
    use_fp16: bool = False,
    num_classes: int = NUM_CLASSES,
):
    """Build the ADM noisy-image classifier backbone (an ``EncoderUNetModel``).

    Args:
        image_size: input resolution; selects the channel multipliers.
        classifier_width: base channel count (``model_channels``).
        classifier_depth: residual blocks per resolution level.
        classifier_attention_resolutions: comma-separated feature-map sizes
            (e.g. ``"32,16,8"``); each is converted to the downsampling factor
            ``image_size // res`` at which attention is inserted.
        classifier_use_scale_shift_norm: scale/shift conditioning in residual blocks.
        classifier_resblock_updown: downsample with residual blocks.
        classifier_pool: pooling head type passed as ``pool``.
        use_fp16: forwarded as ``use_fp16``.
        num_classes: number of output logits.

    The backbone always takes 3 input channels and uses 64 channels per attention head.
    """
    channel_mult = _default_channel_mult(image_size)
    attention_ds = []
    for res in classifier_attention_resolutions.split(","):
        attention_ds.append(image_size // int(res))
    backbone_kwargs = dict(
        image_size=image_size,
        in_channels=3,
        model_channels=classifier_width,
        out_channels=num_classes,
        num_res_blocks=classifier_depth,
        attention_resolutions=tuple(attention_ds),
        channel_mult=channel_mult,
        use_fp16=use_fp16,
        num_head_channels=64,
        use_scale_shift_norm=classifier_use_scale_shift_norm,
        resblock_updown=classifier_resblock_updown,
        pool=classifier_pool,
    )
    return EncoderUNetModel(**backbone_kwargs)
@dataclass
class ADMClassifierOutput(BaseOutput):
    """
    Result container returned by [`ADMClassifierModel`] when `return_dict=True`.

    Args:
        logits (`torch.Tensor` of shape `(batch_size, num_classes)`):
            Unnormalized class scores (pre-softmax) predicted for the noisy input images.
    """

    logits: torch.FloatTensor
class ADMClassifierModel(ModelMixin, ConfigMixin):
    """
    Noisy ImageNet classifier for ADM-G classifier guidance.

    This model predicts class labels from noisy images `x_t` and is used to compute gradients that steer
    an unconditional ADM diffusion model toward a target class.
    """

    @register_to_config
    def __init__(
        self,
        image_size: int = 128,
        classifier_width: int = 128,
        classifier_depth: int = 2,
        classifier_attention_resolutions: str = "32,16,8",
        classifier_use_scale_shift_norm: bool = True,
        classifier_resblock_updown: bool = True,
        classifier_pool: str = "attention",
        use_fp16: bool = False,
        num_classes: int = 1000,
    ):
        super().__init__()
        self.model = create_adm_classifier_model(
            image_size=image_size,
            classifier_width=classifier_width,
            classifier_depth=classifier_depth,
            classifier_attention_resolutions=classifier_attention_resolutions,
            classifier_use_scale_shift_norm=classifier_use_scale_shift_norm,
            classifier_resblock_updown=classifier_resblock_updown,
            classifier_pool=classifier_pool,
            use_fp16=use_fp16,
            num_classes=num_classes,
        )

    @property
    def dtype(self) -> torch.dtype:
        """dtype of the first registered parameter."""
        return next(self.parameters()).dtype

    @staticmethod
    def _prepare_timestep(
        timestep: Union[torch.Tensor, float, int],
        sample: torch.Tensor,
    ) -> torch.Tensor:
        """Normalize `timestep` to a 1-D tensor on `sample`'s device, one entry per batch element."""
        if not torch.is_tensor(timestep):
            # Keep fractional timesteps instead of truncating them to integers.
            dtype = torch.float32 if isinstance(timestep, float) else torch.long
            timestep = torch.tensor([timestep], device=sample.device, dtype=dtype)
        else:
            # Move every tensor (not only 0-d ones) to the sample's device.
            timestep = timestep.to(device=sample.device)
            if timestep.ndim == 0:
                timestep = timestep.reshape(1)
        if timestep.shape[0] == 1 and sample.shape[0] > 1:
            timestep = timestep.expand(sample.shape[0])
        return timestep

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        return_dict: bool = True,
    ) -> Union[ADMClassifierOutput, Tuple[torch.Tensor, ...]]:
        """
        Args:
            sample (`torch.Tensor`):
                Noisy image `(batch_size, 3, height, width)` in `[-1, 1]`.
            timestep (`torch.Tensor` or `float` or `int`):
                Diffusion timestep indices (respaced indices during ADM-G sampling). A scalar, or a
                single-element tensor, is broadcast to the whole batch.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return an [`ADMClassifierOutput`].
        Returns:
            [`ADMClassifierOutput`] or `tuple`:
                Classifier logits.
        """
        timestep = self._prepare_timestep(timestep, sample)
        logits = self.model(sample, timestep)
        if not return_dict:
            return (logits,)
        return ADMClassifierOutput(logits=logits)

    def guidance_gradient(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        class_labels: torch.Tensor,
        classifier_scale: float = 1.0,
    ) -> torch.Tensor:
        """
        Compute `classifier_scale * grad_x log p(y | x_t)` for classifier guidance (ADM-G).
        Args:
            sample (`torch.Tensor`):
                Current noisy sample `x_t`.
            timestep (`torch.Tensor` or `float` or `int`):
                Respaced diffusion timestep indices; scalars and 0-d tensors are broadcast to the batch.
            class_labels (`torch.Tensor`):
                Target ImageNet class indices of shape `(batch_size,)`.
            classifier_scale (`float`, *optional*, defaults to 1.0):
                Guidance strength (OpenAI `classifier_scale`).
        Returns:
            `torch.Tensor`:
                Gradient with respect to `sample`, same shape as `sample`.
        """
        # Apply the same timestep normalization as `forward`. Without it, a 0-d timestep
        # (as a scheduler loop typically yields) fails the 1-D requirement of the
        # timestep embedding.
        timestep = self._prepare_timestep(timestep, sample)
        with torch.enable_grad():
            x_in = sample.detach().requires_grad_(True)
            logits = self.model(x_in, timestep)
            # Compute log-softmax in float32 for numerical stability under half precision.
            log_probs = F.log_softmax(logits.float(), dim=-1)
            labels = class_labels.to(device=logits.device).reshape(-1)
            batch_index = torch.arange(logits.shape[0], device=logits.device)
            selected = log_probs[batch_index, labels]
            grad = torch.autograd.grad(selected.sum(), x_in)[0]
        return grad * classifier_scale