| """
|
| DDPMUnconditionalPipeline for unconditional image generation.
|
| pipeline.py is in the repo — use custom_pipeline="pipeline" (relative path).
|
|
|
| Usage::
|
|
|
| from diffusers import DiffusionPipeline
|
|
|
| pipe = DiffusionPipeline.from_pretrained(
|
| "BiliSakura/ddpm-cd-pretrained-256",
|
| custom_pipeline="pipeline",
|
| trust_remote_code=True,
|
| )
|
| images = pipe.generate(batch_size=4, image_size=256)
|
| """
|
|
|
| import math
|
| from inspect import isfunction
|
|
|
| import numpy as np
|
| import torch
|
| import torch.nn as nn
|
| from diffusers import DDPMScheduler
|
| from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| from diffusers.models.modeling_utils import ModelMixin
|
| from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| from tqdm.auto import tqdm
|
|
|
|
|
|
|
|
|
|
|
|
|
def _exists(x):
    """Return True when ``x`` is anything other than ``None``."""
    return not (x is None)
|
|
|
|
|
def _default(val, d):
    """Return ``val`` unless it is ``None``, otherwise fall back to ``d``.

    If ``d`` is a plain Python function (``inspect.isfunction``, which includes
    lambdas), it is called with no arguments to produce the fallback. Any other
    object, including classes and builtins, is returned unchanged.
    """
    if val is None:
        return d() if isfunction(d) else d
    return val
|
|
|
|
|
class PositionalEncoding(nn.Module):
    """Sinusoidal features of a noise level.

    For ``half = dim // 2`` frequencies ``10000 ** (-k / half)`` (k = 0..half-1),
    returns ``[sin(level * f), cos(level * f)]`` concatenated on the last axis.
    ``noise_level`` gets a new axis inserted at position 1 before broadcasting
    against the frequencies.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, noise_level):
        half = self.dim // 2
        exponents = torch.arange(half, dtype=noise_level.dtype, device=noise_level.device) / half
        # Row vector of frequencies, broadcast against the unsqueezed levels.
        freqs = torch.exp(-math.log(1e4) * exponents.unsqueeze(0))
        angles = noise_level.unsqueeze(1) * freqs
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
|
|
|
|
|
class FeatureWiseAffine(nn.Module):
    """Inject a noise embedding into a feature map via a per-channel projection.

    Without ``use_affine_level`` the projection is added to ``x``. With it, the
    projection is split into ``(gamma, beta)`` and applied as
    ``(1 + gamma) * x + beta``.
    """

    def __init__(self, in_channels, out_channels, use_affine_level=False):
        super().__init__()
        self.use_affine_level = use_affine_level
        # Twice the width when an affine (scale + shift) modulation is requested.
        width = (1 + self.use_affine_level) * out_channels
        # Kept inside a Sequential so parameter names stay "noise_func.0.*".
        self.noise_func = nn.Sequential(nn.Linear(in_channels, width))

    def forward(self, x, noise_embed):
        # Reshape the projection so it broadcasts over the spatial dimensions.
        proj = self.noise_func(noise_embed).view(x.shape[0], -1, 1, 1)
        if not self.use_affine_level:
            return x + proj
        gamma, beta = proj.chunk(2, dim=1)
        return (1 + gamma) * x + beta
|
|
|
|
|
class Swish(nn.Module):
    """Element-wise ``x * sigmoid(x)`` activation."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
|
|
|
|
|
class Upsample(nn.Module):
    """Double the spatial size with nearest-neighbour resampling, then a 3x3 conv."""

    def __init__(self, dim):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1)

    def forward(self, x):
        enlarged = self.up(x)
        return self.conv(enlarged)
|
|
|
|
|
class Downsample(nn.Module):
    """Halve the spatial size (rounding up) with a stride-2 3x3 convolution."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        reduced = self.conv(x)
        return reduced
|
|
|
|
|
class Block(nn.Module):
    """GroupNorm -> Swish -> (Dropout if dropout != 0) -> 3x3 conv."""

    def __init__(self, dim, dim_out, groups=32, dropout=0):
        super().__init__()
        # The Sequential indices below determine parameter names (block.0.*, block.3.*),
        # so the layer order must stay fixed for checkpoints to load.
        layers = [
            nn.GroupNorm(groups, dim),
            Swish(),
            nn.Identity() if dropout == 0 else nn.Dropout(dropout),
            nn.Conv2d(dim, dim_out, kernel_size=3, padding=1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out
|
|
|
|
|
class ResnetBlock(nn.Module):
    """Two conv ``Block``s with noise-level injection between them and a residual path.

    Args:
        dim: input channels.
        dim_out: output channels.
        noise_level_emb_dim: width of the noise embedding passed to ``forward``.
            When ``None``, no noise projection is built and ``time_emb`` is ignored.
        dropout: dropout probability used in the second ``Block``.
        use_affine_level: forwarded to ``FeatureWiseAffine`` (scale+shift vs. add).
        norm_groups: GroupNorm groups for both ``Block``s.
    """

    def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0, use_affine_level=False, norm_groups=32):
        super().__init__()
        # nn.Linear(None, ...) raises, so only build the projection when an
        # embedding width exists (e.g. UNet(with_noise_level_emb=False) passes None).
        self.noise_func = (
            FeatureWiseAffine(noise_level_emb_dim, dim_out, use_affine_level)
            if noise_level_emb_dim is not None
            else None
        )
        self.block1 = Block(dim, dim_out, groups=norm_groups)
        self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
        # 1x1 conv only when the channel count changes; otherwise a plain identity skip.
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb):
        h = self.block1(x)
        if self.noise_func is not None:
            h = self.noise_func(h, time_emb)
        h = self.block2(h)
        return h + self.res_conv(x)
|
|
|
|
|
class SelfAttention(nn.Module):
    """Spatial self-attention over all H*W positions, added back residually.

    The input is GroupNorm'd, then projected to q/k/v with a 1x1 conv and split
    into ``n_head`` heads. Scores are divided by ``sqrt(channel)`` (the full
    channel count, not the per-head width). The attended values go through a
    1x1 output conv.
    """

    def __init__(self, in_channel, n_head=1, norm_groups=32):
        super().__init__()
        self.n_head = n_head
        self.norm = nn.GroupNorm(norm_groups, in_channel)
        self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False)
        self.out = nn.Conv2d(in_channel, in_channel, 1)

    def forward(self, input):
        b, c, h, w = input.shape
        heads = self.n_head
        per_head = c // heads

        projected = self.qkv(self.norm(input)).view(b, heads, per_head * 3, h, w)
        q, k, v = projected.chunk(3, dim=2)

        # scores[b, n, h, w, y, x]: similarity of query position (h, w) to key position (y, x).
        scores = torch.einsum("bnchw,bncyx->bnhwyx", q, k).contiguous()
        scores = scores / math.sqrt(c)
        # Softmax jointly over all key positions, then restore the 6-D layout.
        weights = torch.softmax(scores.view(b, heads, h, w, -1), -1)
        weights = weights.view(b, heads, h, w, h, w)

        attended = torch.einsum("bnhwyx,bncyx->bnchw", weights, v).contiguous()
        return self.out(attended.view(b, c, h, w)) + input
|
|
|
|
|
class ResnetBlocWithAttn(nn.Module):
    """``ResnetBlock`` optionally followed by ``SelfAttention`` (default single head)."""

    def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False):
        super().__init__()
        self.with_attn = with_attn
        self.res_block = ResnetBlock(
            dim,
            dim_out,
            noise_level_emb_dim,
            norm_groups=norm_groups,
            dropout=dropout,
        )
        # Stored as None when attention is disabled, so it contributes no parameters.
        self.attn = SelfAttention(dim_out, norm_groups=norm_groups) if with_attn else None

    def forward(self, x, time_emb):
        hidden = self.res_block(x, time_emb)
        if not self.with_attn:
            return hidden
        return self.attn(hidden)
|
|
|
|
|
class UNet(ModelMixin, ConfigMixin):
    """SR3-style UNet with noise-level conditioning.

    Encoder/decoder with skip connections. Every ``ResnetBlocWithAttn`` receives
    the embedding that ``noise_level_mlp`` produces from the ``time`` argument
    of ``forward``.

    Constructor arguments are recorded in the config by ``register_to_config``:
        in_channel: channels of the input tensor.
        out_channel: channels produced by ``final_conv``. Falls back to
            ``in_channel`` when None.
        inner_channel: base width. Level ``i`` uses ``inner_channel * channel_mults[i]``
            channels, and this is also the noise-embedding width.
        norm_groups: GroupNorm group count used by every block.
        channel_mults: per-level width multipliers. One ``Downsample``/``Upsample``
            sits between consecutive levels.
        attn_res: resolutions (tracked from ``image_size``, halved per level)
            at which residual blocks also apply self-attention.
        res_blocks: residual blocks per encoder level. The decoder uses
            ``res_blocks + 1`` per level.
        dropout: dropout probability passed to every residual block.
        with_noise_level_emb: whether to build ``noise_level_mlp``. When False,
            ``None`` is used as the embedding width and as the embedding.
        image_size: only used to track resolutions for ``attn_res``.
    """

    @register_to_config
    def __init__(
        self,
        in_channel=6,
        out_channel=3,
        inner_channel=32,
        norm_groups=32,
        channel_mults=(1, 2, 4, 8, 8),
        attn_res=(8,),
        res_blocks=3,
        dropout=0,
        with_noise_level_emb=True,
        image_size=128,
    ):
        super().__init__()
        # Width of the embedding handed to every residual block (None when disabled).
        noise_level_channel = inner_channel if with_noise_level_emb else None
        # Noise level -> sinusoidal features -> Linear/Swish/Linear MLP of width inner_channel.
        self.noise_level_mlp = (
            nn.Sequential(
                PositionalEncoding(inner_channel),
                nn.Linear(inner_channel, inner_channel * 4),
                Swish(),
                nn.Linear(inner_channel * 4, inner_channel),
            )
            if with_noise_level_emb
            else None
        )

        num_mults = len(channel_mults)
        # feat_channels mirrors the skip stack that forward() builds: one entry per
        # pushed tensor (init_conv output, each encoder block, each Downsample).
        # now_res follows the spatial size only to decide where attention is used.
        pre_channel, feat_channels, now_res = inner_channel, [inner_channel], image_size
        self.init_conv = nn.Conv2d(in_channel, inner_channel, 3, padding=1)

        # Encoder. NOTE: ModuleList indices become parameter names (downs.<i>.*),
        # so changing the construction order breaks loading of pretrained weights.
        downs = []
        for ind in range(num_mults):
            use_attn = now_res in attn_res
            channel_mult = inner_channel * channel_mults[ind]
            for _ in range(res_blocks):
                downs.append(
                    ResnetBlocWithAttn(
                        pre_channel, channel_mult,
                        noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
                        dropout=dropout, with_attn=use_attn,
                    )
                )
                feat_channels.append(channel_mult)
                pre_channel = channel_mult
            # No downsampling after the deepest level.
            if ind < num_mults - 1:
                downs.append(Downsample(pre_channel))
                feat_channels.append(pre_channel)
                now_res = now_res // 2
        self.downs = nn.ModuleList(downs)

        # Bottleneck: one block with self-attention, then one without.
        self.mid = nn.ModuleList([
            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
                               norm_groups=norm_groups, dropout=dropout, with_attn=True),
            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
                               norm_groups=norm_groups, dropout=dropout, with_attn=False),
        ])

        # Decoder: (res_blocks + 1) blocks per level consume, in total, exactly the
        # num_mults * (res_blocks + 1) entries pushed onto feat_channels above.
        ups = []
        for ind in reversed(range(num_mults)):
            use_attn = now_res in attn_res
            channel_mult = inner_channel * channel_mults[ind]
            for _ in range(res_blocks + 1):
                ups.append(
                    ResnetBlocWithAttn(
                        # Input = current features concatenated with the matching skip tensor.
                        pre_channel + feat_channels.pop(), channel_mult,
                        noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
                        dropout=dropout, with_attn=use_attn,
                    )
                )
                pre_channel = channel_mult
            if ind > 0:
                ups.append(Upsample(pre_channel))
                now_res = now_res * 2
        self.ups = nn.ModuleList(ups)
        # Output head: GroupNorm -> Swish -> 3x3 conv down to out_channel (or in_channel).
        self.final_conv = Block(pre_channel, _default(out_channel, lambda: in_channel), groups=norm_groups)

    def forward(self, x, time):
        """Run the UNet on ``x`` conditioned on the noise level ``time``.

        Args:
            x: input tensor with ``in_channel`` channels. Spatial dims presumably
                must be divisible by 2 ** (len(channel_mults) - 1) so skip tensors
                line up on concatenation. TODO confirm.
            time: noise-level tensor fed to ``noise_level_mlp``. The pipeline in
                this file passes shape (batch, 1). Unused when the MLP is disabled.

        Returns:
            Output of ``final_conv``.
        """
        t = self.noise_level_mlp(time) if _exists(self.noise_level_mlp) else None
        x = self.init_conv(x)
        # Skip stack: init_conv output plus the output of every encoder layer.
        feats = [x]
        for layer in self.downs:
            x = layer(x, t) if isinstance(layer, ResnetBlocWithAttn) else layer(x)
            feats.append(x)
        for layer in self.mid:
            x = layer(x, t) if isinstance(layer, ResnetBlocWithAttn) else layer(x)
        for layer in self.ups:
            if isinstance(layer, ResnetBlocWithAttn):
                # Concatenate the most recently pushed unused skip along channels.
                x = layer(torch.cat((x, feats.pop()), dim=1), t)
            else:
                x = layer(x)
        return self.final_conv(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
def _precompute_alpha_tables(scheduler):
    """Return ``sqrt([1, alphas_cumprod[0], ..., alphas_cumprod[-1]])`` as float64 NumPy.

    Entry ``t + 1`` is ``sqrt(alphas_cumprod[t])`` and entry 0 is 1.0. This lets
    callers look up the noise level for timestep ``t`` at index ``t + 1``.

    Args:
        scheduler: any object exposing ``alphas_cumprod`` as a torch tensor
            (on any device, any float dtype, grad or not) or an array-like.
    """
    ac = scheduler.alphas_cumprod
    if isinstance(ac, torch.Tensor):
        # Schedulers may move alphas_cumprod to the sample device (DDPMScheduler.add_noise
        # does), and .numpy() only works on CPU tensors without grad and with a NumPy dtype.
        # Casting to float64 first matches what np.append would promote to anyway.
        ac = ac.detach().to(device="cpu", dtype=torch.float64).numpy()
    return np.sqrt(np.append(1.0, np.asarray(ac, dtype=np.float64)))
|
|
|
|
|
class DDPMUnconditionalPipeline(DiffusionPipeline):
    """Unconditional DDPM image generation. Load with custom_pipeline and trust_remote_code=True."""

    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def generate(self, batch_size=1, in_channels=3, image_size=256, num_inference_steps=None, generator=None):
        """Generate images by iteratively denoising Gaussian noise.

        Args:
            batch_size: number of images.
            in_channels: channels of the sampled tensor (must match the UNet input).
            image_size: height and width of the square samples.
            num_inference_steps: steps passed to ``scheduler.set_timesteps``.
                Defaults to ``scheduler.config.num_train_timesteps`` (also when 0).
            generator: optional ``torch.Generator``. It seeds the initial noise and
                every scheduler step, so a seeded generator gives reproducible output.
                It may live on a different device than the UNet.

        Returns:
            Tensor of shape (batch_size, in_channels, image_size, image_size) holding
            the final ``prev_sample``. It is not rescaled here (presumably in [-1, 1]
            when the scheduler clips samples; verify against the scheduler config).
        """
        param = next(self.unet.parameters())
        device, dtype = param.device, param.dtype
        steps = num_inference_steps or self.scheduler.config.num_train_timesteps
        sqrt_a = _precompute_alpha_tables(self.scheduler)

        shape = (batch_size, in_channels, image_size, image_size)
        # torch.randn requires the generator and output device to match, so draw on the
        # generator's device and move the result (no-op when they already agree).
        noise_device = generator.device if generator is not None else device
        image = torch.randn(shape, generator=generator, device=noise_device, dtype=dtype).to(device)

        self.scheduler.set_timesteps(steps)
        for t in tqdm(self.scheduler.timesteps, desc="Sampling"):
            # sqrt_a[t + 1] == sqrt(alphas_cumprod[t]); clamp guards the upper end.
            idx = min(int(t) + 1, len(sqrt_a) - 1)
            # The UNet is conditioned on a (batch, 1) noise level in the model's dtype.
            lvl = torch.full((batch_size, 1), float(sqrt_a[idx]), device=device, dtype=dtype)
            noise_pred = self.unet(image, lvl)
            # Forward the generator so the per-step variance noise is reproducible too.
            image = self.scheduler.step(noise_pred, t, image, generator=generator).prev_sample
        return image

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`generate`, following the DiffusionPipeline calling convention."""
        return self.generate(*args, **kwargs)
|
|
|