Unconditional Image Generation
Diffusers
Safetensors
English
deco
image-generation
class-conditional
imagenet
Instructions to use BiliSakura/DeCo-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/DeCo-diffusers with Diffusers:
```bash
pip install -U diffusers transformers accelerate
```
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for Apple devices
pipe = DiffusionPipeline.from_pretrained("BiliSakura/DeCo-diffusers", dtype=torch.bfloat16, device_map="cuda")

prompt = "golden retriever"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| # Copyright 2026 The HuggingFace Team. All rights reserved. | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from functools import lru_cache | |
| from typing import Union | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.checkpoint import checkpoint | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| from diffusers.models.modeling_utils import ModelMixin | |
| from diffusers.utils import BaseOutput | |
| def _modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: | |
| return x * (1 + scale) + shift | |
class NerfEmbedder(nn.Module):
    """Embed flattened patch pixels together with a fixed 2D cosine positional basis.

    Every pixel of a ``patch_size x patch_size`` patch is concatenated with ``max_freqs ** 2``
    cosine features of its (x, y) position in ``[0, 1]^2`` and linearly projected to
    ``hidden_size_input`` channels.

    Args:
        in_channels: Number of channels per input pixel.
        hidden_size_input: Width of the projected embedding.
        max_freqs: Number of frequencies per spatial axis; ``max_freqs ** 2`` features are appended.
    """

    def __init__(self, in_channels: int, hidden_size_input: int, max_freqs: int):
        super().__init__()
        self.max_freqs = max_freqs
        # Wrapped in nn.Sequential, so the parameter names are ``embedder.0.weight`` /
        # ``embedder.0.bias``; changing the structure would change the state-dict keys.
        self.embedder = nn.Sequential(nn.Linear(in_channels + max_freqs**2, hidden_size_input, bias=True))

    def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Build the positional basis for one patch.

        Token ``t`` corresponds to row ``t // patch_size`` and column ``t % patch_size``
        (``meshgrid(..., indexing="ij")`` flattened row-major).

        Args:
            patch_size: Side length of the patch.
            device: Device of the returned tensor.
            dtype: Dtype of the returned tensor (the basis is computed in this dtype).

        Returns:
            Tensor of shape ``(1, patch_size ** 2, max_freqs ** 2)`` whose entries are
            ``cos(pi * x * fx) * cos(pi * y * fy) / (1 + fx * fy)``.
        """
        pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
        pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
        pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
        # NOTE(review): linspace(0, max_freqs, max_freqs) gives non-integer frequencies
        # (step max_freqs / (max_freqs - 1)). It is left as-is because changing it would alter
        # the features the projection weights were presumably trained on.
        freqs = torch.linspace(0, self.max_freqs, self.max_freqs, dtype=dtype, device=device)
        freqs_x = freqs[None, :, None]
        freqs_y = freqs[None, None, :]
        # Damping factor that shrinks as the frequency product grows.
        coeffs = (1 + freqs_x * freqs_y) ** -1
        dct = (
            torch.cos(pos_x.reshape(-1, 1, 1) * freqs_x * torch.pi)
            * torch.cos(pos_y.reshape(-1, 1, 1) * freqs_y * torch.pi)
            * coeffs
        ).view(1, -1, self.max_freqs**2)
        return dct

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Embed flattened patch pixels.

        Args:
            inputs: Tensor of shape ``(batch, patch_tokens, in_channels)``; ``patch_tokens`` must
                be a perfect square (``patch_size ** 2``).

        Returns:
            Tensor of shape ``(batch, patch_tokens, hidden_size_input)``.

        Raises:
            ValueError: If ``patch_tokens`` is not a perfect square.
        """
        batch_size, patch_tokens, _ = inputs.shape
        # Round instead of truncating the float sqrt, then verify exactly with integers.
        patch_size = round(patch_tokens**0.5)
        if patch_size * patch_size != patch_tokens:
            raise ValueError(
                f"NerfEmbedder expects a square patch, but got {patch_tokens} tokens per patch."
            )
        # expand() broadcasts the shared basis over the batch without materializing copies;
        # torch.cat produces the same values as with repeat().
        dct = self.fetch_pos(patch_size, inputs.device, inputs.dtype).expand(batch_size, -1, -1)
        return self.embedder(torch.cat([inputs, dct], dim=-1))
class ResBlock(nn.Module):
    """Residual MLP block with AdaLN conditioning.

    Computes ``x + gate * MLP(LN(x) * (1 + scale) + shift)``, where ``shift``, ``scale`` and
    ``gate`` are the three equal chunks of a SiLU + Linear projection of the conditioning ``y``.

    Args:
        channels: Feature width of both ``x`` and the conditioning ``y``.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
        hidden_layers = [
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        ]
        self.mlp = nn.Sequential(*hidden_layers)
        # Produces shift, scale and gate in one projection (3 * channels outputs).
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True))

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        modulation = self.adaLN_modulation(y)
        shift_mlp, scale_mlp, gate_mlp = modulation.chunk(3, dim=-1)
        conditioned = _modulate(self.in_ln(x), shift_mlp, scale_mlp)
        update = self.mlp(conditioned)
        return x + gate_mlp * update
class DecoderFinalLayer(nn.Module):
    """Output head: a parameter-free LayerNorm followed by a linear projection.

    Args:
        model_channels: Width of the incoming features.
        out_channels: Width of the produced output.
    """

    def __init__(self, model_channels: int, out_channels: int):
        super().__init__()
        # elementwise_affine=False: normalization only, no learnable gain/bias.
        self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(model_channels, out_channels, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normalized = self.norm_final(x)
        projected = self.linear(normalized)
        return projected
class SimpleMLPAdaLN(nn.Module):
    """Per-patch MLP decoder with AdaLN conditioning.

    The conditioning vector of each patch is projected to one ``model_channels``-wide
    modulation vector per pixel (``patch_size ** 2`` of them), which conditions every
    residual block.

    Args:
        in_channels: Channel width of the input features ``x``.
        model_channels: Hidden width of the residual blocks.
        out_channels: Channel width of the output.
        z_channels: Channel width of the conditioning ``c``.
        num_res_blocks: Number of residual blocks.
        patch_size: Side length of a patch; ``x`` is expected to hold ``patch_size ** 2`` tokens
            so it lines up with the per-pixel conditioning.
        grad_checkpointing: If True, each residual block is recomputed during the backward pass
            (activation checkpointing) instead of storing its activations.
    """

    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        out_channels: int,
        z_channels: int,
        num_res_blocks: int,
        patch_size: int,
        grad_checkpointing: bool = False,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.grad_checkpointing = grad_checkpointing
        self.cond_embed = nn.Linear(z_channels, patch_size**2 * model_channels)
        self.input_proj = nn.Linear(in_channels, model_channels)
        self.res_blocks = nn.ModuleList([ResBlock(model_channels) for _ in range(num_res_blocks)])
        self.final_layer = DecoderFinalLayer(model_channels, out_channels)
        self._init_weights()

    def _init_weights(self) -> None:
        """Zero the AdaLN projections and the output head.

        With zeroed AdaLN projections every gate is 0, so each residual block starts as the
        identity; with a zeroed head the decoder initially outputs zeros.
        """
        for block in self.res_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Decode per-pixel features under per-patch conditioning.

        Args:
            x: Per-pixel features, shape ``(N, patch_size ** 2, in_channels)``.
            c: Per-patch conditioning, shape ``(N, z_channels)``.

        Returns:
            Tensor of shape ``(N, patch_size ** 2, out_channels)``.
        """
        x = self.input_proj(x)
        # One model_channels-wide modulation vector per pixel of the patch.
        y = self.cond_embed(c).reshape(c.shape[0], self.patch_size**2, -1)
        for block in self.res_blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # Pass use_reentrant explicitly. PyTorch warns when it is omitted and recommends
                # the non-reentrant variant, which also still computes parameter gradients when
                # no input tensor requires grad.
                x = checkpoint(block, x, y, use_reentrant=False)
            else:
                x = block(x, y)
        return self.final_layer(x)
@dataclass
class DeCoPatchDecoderOutput(BaseOutput):
    """
    Output of `DeCoPatchDecoderModel`.

    Declared as a dataclass, like other `BaseOutput` subclasses, so that `sample` is a real
    field (keyword construction, attribute and index access).

    Args:
        sample (`torch.Tensor`):
            Decoded per-patch pixels of shape `(batch * num_patches, patch_size ** 2, in_channels)`.
    """

    sample: torch.Tensor
class DeCoPatchDecoderModel(ModelMixin, ConfigMixin):
    """Per-patch RGB decoder for DeCo (NerfEmbedder + AdaLN MLP).

    The pixels of each patch are embedded together with a fixed cosine positional basis and then
    decoded by a residual MLP conditioned on one vector per patch.

    Args:
        in_channels (`int`, defaults to 3):
            Channels per pixel. Used both as the input width and as the output width.
        hidden_size_x (`int`, defaults to 32):
            Width of the pixel embedding and of the residual MLP.
        z_channels (`int`, defaults to 1152):
            Width of the per-patch conditioning vector.
        num_res_blocks (`int`, defaults to 3):
            Number of AdaLN residual blocks.
        patch_size (`int`, defaults to 16):
            Side length of a patch; each patch holds `patch_size ** 2` pixels.
        max_freqs (`int`, defaults to 8):
            Frequencies per axis of the positional basis (`max_freqs ** 2` features per pixel).
    """

    config_name = "config.json"

    # Registers the constructor arguments as this ConfigMixin's config, which the diffusers
    # save/load machinery relies on.
    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        hidden_size_x: int = 32,
        z_channels: int = 1152,
        num_res_blocks: int = 3,
        patch_size: int = 16,
        max_freqs: int = 8,
    ):
        super().__init__()
        self.x_embedder = NerfEmbedder(in_channels, hidden_size_x, max_freqs=max_freqs)
        self.dec_net = SimpleMLPAdaLN(
            in_channels=hidden_size_x,
            model_channels=hidden_size_x,
            out_channels=in_channels,
            z_channels=z_channels,
            num_res_blocks=num_res_blocks,
            patch_size=patch_size,
        )

    def forward(
        self,
        patch_pixels: torch.Tensor,
        conditioning: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[DeCoPatchDecoderOutput, tuple[torch.Tensor]]:
        """
        Args:
            patch_pixels (`torch.Tensor`):
                Flattened patch pixels of shape `(batch * num_patches, patch_size ** 2, in_channels)`.
                The token count should equal the configured `patch_size ** 2`.
            conditioning (`torch.Tensor`):
                Per-patch conditioning of shape `(batch * num_patches, z_channels)`.
            return_dict (`bool`, defaults to `True`):
                Whether to return a `DeCoPatchDecoderOutput` instead of a plain tuple.

        Returns:
            `DeCoPatchDecoderOutput` if `return_dict` is `True`, otherwise a one-element tuple.
            In both cases the decoded tensor has shape
            `(batch * num_patches, patch_size ** 2, in_channels)`.
        """
        output = self.dec_net(self.x_embedder(patch_pixels), conditioning)
        if not return_dict:
            return (output,)
        return DeCoPatchDecoderOutput(sample=output)