# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
"""Noisy image classifier used for ADM-G classifier guidance."""

import math
from abc import abstractmethod
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as torch_checkpoint

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput


NUM_CLASSES = 1000


def conv_nd(dims: int, *args, **kwargs):
    """Create a 1D, 2D or 3D convolution depending on ``dims``.

    Raises:
        ValueError: if ``dims`` is not 1, 2 or 3.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    if dims == 2:
        return nn.Conv2d(*args, **kwargs)
    if dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """Thin wrapper around ``nn.Linear``."""
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims: int, *args, **kwargs):
    """Create a 1D, 2D or 3D average-pooling layer depending on ``dims``.

    Raises:
        ValueError: if ``dims`` is not 1, 2 or 3.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    if dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    if dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


class GroupNorm32(nn.GroupNorm):
    """GroupNorm that always computes in float32 and casts the result back to the input dtype."""

    def forward(self, x):
        # Upcast both the activations and the affine parameters so reduced-precision inputs
        # are normalized in float32.
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        y = F.group_norm(x.float(), self.num_groups, weight, bias, self.eps)
        return y.to(dtype=x.dtype)


def normalization(channels: int):
    """Return the standard normalization layer: ``GroupNorm32`` with 32 groups over ``channels``."""
    return GroupNorm32(32, channels)


def zero_module(module: nn.Module):
    """Zero every parameter of ``module`` in place and return the same module."""
    for p in module.parameters():
        p.detach().zero_()
    return module


def convert_module_to_f16(module: nn.Module):
    """Cast the weight and bias of a Conv1d/2d/3d module to float16; other modules are left untouched."""
    if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        module.weight.data = module.weight.data.half()
        if module.bias is not None:
            module.bias.data = module.bias.data.half()


def convert_module_to_f32(module: nn.Module):
    """Cast the weight and bias of a Conv1d/2d/3d module back to float32; other modules are left untouched."""
    if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        module.weight.data = module.weight.data.float()
        if module.bias is not None:
            module.bias.data = module.bias.data.float()


class TimestepBlock(nn.Module):
    """Base class for modules whose ``forward`` also receives the timestep embedding ``emb``."""

    @abstractmethod
    def forward(self, x, emb):
        raise NotImplementedError


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """``nn.Sequential`` that passes ``emb`` to children that are ``TimestepBlock`` instances."""

    def forward(self, x, emb):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            else:
                x = layer(x)
        return x


class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling followed by an optional 3x3 convolution.

    For ``dims == 3`` only the last two axes are doubled; the third-from-last axis is kept.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    """2x downsampling via a strided 3x3 convolution (``use_conv=True``) or average pooling.

    For ``dims == 3`` the stride is ``(1, 2, 2)`` so only the last two axes are halved.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1)
        else:
            # Pooling cannot change the number of channels.
            assert self.channels == self.out_channels, "pooling Downsample cannot change channel count"
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class ResBlock(TimestepBlock):
    """Residual block conditioned on a timestep embedding, optionally changing resolution.

    Args:
        channels: number of input channels.
        emb_channels: width of the timestep embedding fed to ``emb_layers``.
        dropout: dropout probability applied just before the last (zero-initialized) convolution.
        out_channels: number of output channels (defaults to ``channels``).
        use_conv: when the channel count changes, use a 3x3 instead of a 1x1 conv on the skip path.
        use_scale_shift_norm: inject the embedding as ``norm(h) * (1 + scale) + shift`` instead of
            adding it to the hidden state before the output layers.
        dims: spatial dimensionality of the convolutions.
        use_checkpoint: recompute activations in backward when the input requires grad.
        up: upsample both the hidden and skip paths (no convolution).
        down: downsample both the hidden and skip paths (no convolution).
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        # With scale/shift norm the embedding provides both a scale and a shift, hence 2x width.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(emb_channels, 2 * self.out_channels if use_scale_shift_norm else self.out_channels),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        if self.use_checkpoint and x.requires_grad:
            return torch_checkpoint(self._forward, x, emb, use_reentrant=False)
        return self._forward(x, emb)

    def _forward(self, x, emb):
        if self.updown:
            # Resample between the activation and the conv of in_layers; resample the skip input too.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Append trailing singleton axes so the embedding broadcasts over the spatial axes of h.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h


class QKVAttentionLegacy(nn.Module):
    """Multi-head attention on a packed ``(batch, 3 * heads * ch, length)`` qkv tensor.

    Legacy layout: the channels are grouped per head first, then split into q, k, v.
    Returns a ``(batch, heads * ch, length)`` tensor.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        # Split the 1/sqrt(ch) scale between q and k (presumably to keep reduced-precision products in range).
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)


class QKVAttention(nn.Module):
    """Multi-head attention on a packed ``(batch, 3 * heads * ch, length)`` qkv tensor.

    New layout: the channels are split into q, k, v first, then grouped per head.
    Returns a ``(batch, heads * ch, length)`` tensor.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        # Split the 1/sqrt(ch) scale between q and k (presumably to keep reduced-precision products in range).
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)


class AttentionBlock(nn.Module):
    """Self-attention over all spatial positions of ``x`` with a residual connection.

    Args:
        channels: channel count of the input.
        num_heads: number of heads, used when ``num_head_channels == -1``.
        num_head_channels: if not -1, heads are ``channels // num_head_channels``.
        use_checkpoint: recompute activations in backward when the input requires grad.
        use_new_attention_order: use ``QKVAttention`` instead of ``QKVAttentionLegacy``.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        self.attention = QKVAttention(self.num_heads) if use_new_attention_order else QKVAttentionLegacy(self.num_heads)
        # Zero-initialized output projection: the block starts as the identity.
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        if self.use_checkpoint and x.requires_grad:
            return torch_checkpoint(self._forward, x, use_reentrant=False)
        return self._forward(x)

    def _forward(self, x):
        b, c, *spatial = x.shape
        # Flatten all spatial axes into one sequence axis.
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)


class AttentionPool2d(nn.Module):
    """CLIP-style attention pooling used by ADM noisy classifiers.

    Prepends the spatial mean as an extra token, adds learned positional embeddings, runs
    multi-head attention, and returns the projected output at the mean-token position.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads_channels: int, output_dim: int = None):
        super().__init__()
        # One embedding per spatial position plus one for the prepended mean token.
        self.positional_embedding = nn.Parameter(torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)
        x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        return x[:, :, 0]


class EncoderUNetModel(nn.Module):
    """Noisy image classifier backbone for ADM-G (classifier guidance).

    Args:
        image_size: input resolution; only used to size the ``"attention"`` pooling head.
        in_channels: channels of the input image.
        model_channels: base width; level widths are ``mult * model_channels``.
        out_channels: size of the output vector (number of classes).
        num_res_blocks: residual blocks per resolution level.
        attention_resolutions: downsampling factors (``ds``) at which attention blocks are added.
        dropout: dropout probability inside residual blocks.
        channel_mult: per-level width multipliers; one 2x downsampling between consecutive levels.
        conv_resample: use strided convs (instead of pooling) for non-ResBlock downsampling.
        dims: spatial dimensionality of the convolutions.
        use_checkpoint: gradient checkpointing in residual/attention blocks.
        use_fp16: only recorded as ``self.dtype``; weights are converted by ``convert_to_fp16``.
        num_heads: attention heads when ``num_head_channels == -1``.
        num_head_channels: channels per attention head (required for ``pool="attention"``).
        use_scale_shift_norm: scale/shift timestep conditioning in residual blocks.
        resblock_updown: downsample with ResBlocks instead of ``Downsample``.
        use_new_attention_order: use ``QKVAttention`` instead of the legacy layout.
        pool: output head, one of ``"adaptive"``, ``"attention"``, ``"spatial"``, ``"spatial_v2"``.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        pool="adaptive",
    ):
        super().__init__()
        self.model_channels = model_channels
        self.use_checkpoint = use_checkpoint
        self.dtype = torch.float16 if use_fp16 else torch.float32

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        ch = int(channel_mult[0] * model_channels)
        self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))])
        # Sum of the widths of every input block and the middle block; this is the input
        # size of the "spatial" pooling heads, which concatenate one pooled vector per block.
        self._feature_size = ch
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=int(mult * model_channels),
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(mult * model_channels)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                )
                ch = out_ch
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        self.pool = pool
        if pool == "adaptive":
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                zero_module(conv_nd(dims, ch, out_channels, 1)),
                nn.Flatten(),
            )
        elif pool == "attention":
            assert num_head_channels != -1, "pool='attention' requires num_head_channels != -1"
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels),
            )
        elif pool == "spatial":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                nn.ReLU(),
                nn.Linear(2048, out_channels),
            )
        elif pool == "spatial_v2":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                normalization(2048),
                nn.SiLU(),
                nn.Linear(2048, out_channels),
            )
        else:
            raise NotImplementedError(f"Unexpected {pool} pooling")

    def convert_to_fp16(self):
        """Convert the convolution weights of the input and middle blocks to float16."""
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """Convert the convolution weights of the input and middle blocks back to float32."""
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(self, x, timesteps):
        """Run the classifier.

        Args:
            x: input batch; the ``"spatial*"`` heads average over dims 2 and 3.
            timesteps: 1-D tensor with one timestep per batch element.

        Returns:
            A ``(batch, out_channels)`` tensor produced by ``self.out``.
        """
        # The time-embedding MLP's dtype is used as the compute dtype for the whole network.
        compute_dtype = self.time_embed[0].weight.dtype
        # Match the reference ADM (guided-diffusion) timestep_embedding: cos half first, then sin,
        # with frequencies exp(-log(10000) * i / half). diffusers' defaults (sin first,
        # downscale_freq_shift=1) produce a different embedding that ADM-trained weights never saw.
        emb = get_timestep_embedding(
            timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0
        ).to(dtype=compute_dtype)
        emb = self.time_embed(emb)

        results = []
        h = x.to(dtype=compute_dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            if self.pool.startswith("spatial"):
                results.append(h.to(dtype=compute_dtype).mean(dim=(2, 3)))
        h = self.middle_block(h, emb)
        if self.pool.startswith("spatial"):
            results.append(h.to(dtype=compute_dtype).mean(dim=(2, 3)))
            h = torch.cat(results, dim=-1)
            return self.out(h)
        h = h.to(dtype=compute_dtype)
        return self.out(h)


def _default_channel_mult(image_size: int):
    """Return the per-level channel multipliers used for a given classifier resolution.

    Raises:
        ValueError: if ``image_size`` is not 64, 128, 256 or 512.
    """
    if image_size == 512:
        return (0.5, 1, 1, 2, 2, 4, 4)
    if image_size == 256:
        return (1, 1, 2, 2, 4, 4)
    if image_size == 128:
        return (1, 1, 2, 3, 4)
    if image_size == 64:
        return (1, 2, 3, 4)
    raise ValueError(f"unsupported image size: {image_size}")


def create_adm_classifier_model(
    image_size: int,
    classifier_width: int = 128,
    classifier_depth: int = 2,
    classifier_attention_resolutions: str = "32,16,8",
    classifier_use_scale_shift_norm: bool = True,
    classifier_resblock_updown: bool = True,
    classifier_pool: str = "attention",
    use_fp16: bool = False,
    num_classes: int = NUM_CLASSES,
):
    """Build an ``EncoderUNetModel`` classifier from ADM-style hyperparameters.

    ``classifier_attention_resolutions`` is a comma-separated list of feature-map resolutions;
    each is converted to the downsampling factor ``image_size // res`` used by the backbone.
    """
    channel_mult = _default_channel_mult(image_size)
    attention_ds = tuple(image_size // int(res) for res in classifier_attention_resolutions.split(","))
    return EncoderUNetModel(
        image_size=image_size,
        in_channels=3,
        model_channels=classifier_width,
        out_channels=num_classes,
        num_res_blocks=classifier_depth,
        attention_resolutions=attention_ds,
        channel_mult=channel_mult,
        use_fp16=use_fp16,
        num_head_channels=64,
        use_scale_shift_norm=classifier_use_scale_shift_norm,
        resblock_updown=classifier_resblock_updown,
        pool=classifier_pool,
    )


@dataclass
class ADMClassifierOutput(BaseOutput):
    """
    Output of the ADM noisy image classifier.

    Args:
        logits (`torch.Tensor` of shape `(batch_size, num_classes)`):
            Class logits for the noisy input.
    """

    logits: torch.FloatTensor


class ADMClassifierModel(ModelMixin, ConfigMixin):
    """
    Noisy ImageNet classifier for ADM-G classifier guidance.

    This model predicts class labels from noisy images `x_t` and is used to compute gradients that
    steer an unconditional ADM diffusion model toward a target class.
    """

    @register_to_config
    def __init__(
        self,
        image_size: int = 128,
        classifier_width: int = 128,
        classifier_depth: int = 2,
        classifier_attention_resolutions: str = "32,16,8",
        classifier_use_scale_shift_norm: bool = True,
        classifier_resblock_updown: bool = True,
        classifier_pool: str = "attention",
        use_fp16: bool = False,
        num_classes: int = NUM_CLASSES,
    ):
        super().__init__()
        self.model = create_adm_classifier_model(
            image_size=image_size,
            classifier_width=classifier_width,
            classifier_depth=classifier_depth,
            classifier_attention_resolutions=classifier_attention_resolutions,
            classifier_use_scale_shift_norm=classifier_use_scale_shift_norm,
            classifier_resblock_updown=classifier_resblock_updown,
            classifier_pool=classifier_pool,
            use_fp16=use_fp16,
            num_classes=num_classes,
        )

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @staticmethod
    def _prepare_timestep(sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]) -> torch.Tensor:
        """Normalize ``timestep`` to a 1-D tensor on ``sample.device`` with one entry per sample.

        Python floats keep their fractional value (float32); other non-tensor scalars become int64.
        A single timestep is broadcast across the batch.
        """
        if not torch.is_tensor(timestep):
            dtype = torch.float32 if isinstance(timestep, float) else torch.long
            timestep = torch.tensor([timestep], device=sample.device, dtype=dtype)
        else:
            # Always move to the sample's device: the embedding is built on the timestep's device.
            timestep = timestep.to(device=sample.device)
            if timestep.ndim == 0:
                timestep = timestep.reshape(1)
        if timestep.shape[0] == 1 and sample.shape[0] > 1:
            timestep = timestep.expand(sample.shape[0])
        return timestep

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        return_dict: bool = True,
    ) -> Union[ADMClassifierOutput, Tuple[torch.Tensor, ...]]:
        """
        Args:
            sample (`torch.Tensor`):
                Noisy image `(batch_size, 3, height, width)` in `[-1, 1]`.
            timestep (`torch.Tensor` or `float` or `int`):
                Diffusion timestep indices (respaced indices during ADM-G sampling). A scalar or a
                single-element tensor is broadcast to the whole batch.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return an [`ADMClassifierOutput`].

        Returns:
            [`ADMClassifierOutput`] or `tuple`: Classifier logits.
        """
        timestep = self._prepare_timestep(sample, timestep)
        logits = self.model(sample, timestep)
        if not return_dict:
            return (logits,)
        return ADMClassifierOutput(logits=logits)

    def guidance_gradient(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.Tensor,
        classifier_scale: float = 1.0,
    ) -> torch.Tensor:
        """
        Compute `classifier_scale * grad_x log p(y | x_t)` for classifier guidance (ADM-G).

        Args:
            sample (`torch.Tensor`):
                Current noisy sample `x_t`.
            timestep (`torch.Tensor`):
                Respaced diffusion timestep indices. Scalars and single-element tensors are
                broadcast to the batch, as in `forward`.
            class_labels (`torch.Tensor`):
                Target ImageNet class indices of shape `(batch_size,)`.
            classifier_scale (`float`, *optional*, defaults to 1.0):
                Guidance strength (OpenAI `classifier_scale`).

        Returns:
            `torch.Tensor`: Gradient with respect to `sample`, same shape as `sample`.
        """
        timestep = self._prepare_timestep(sample, timestep)
        # Gradients are needed even if the caller runs sampling under torch.no_grad().
        with torch.enable_grad():
            x_in = sample.detach().requires_grad_(True)
            logits = self.model(x_in, timestep)
            log_probs = F.log_softmax(logits, dim=-1)
            labels = class_labels.reshape(-1).to(device=logits.device)
            selected = log_probs[torch.arange(logits.shape[0], device=logits.device), labels]
            grad = torch.autograd.grad(selected.sum(), x_in)[0]
        return grad * classifier_scale