Unconditional Image Generation
Diffusers
Safetensors
English
deco
image-generation
class-conditional
imagenet
Instructions to use BiliSakura/DeCo-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/DeCo-diffusers with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("BiliSakura/DeCo-diffusers", dtype=torch.bfloat16, device_map="cuda")
prompt = "golden retriever"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
File size: 6,128 Bytes
9dc3cb9
# Copyright 2026 The HuggingFace Team. All rights reserved.
from __future__ import annotations
from dataclasses import dataclass
from functools import lru_cache
from typing import Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
def _modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Apply an affine modulation: scale ``x`` by ``1 + scale`` and add ``shift``.

    The three tensors are combined with ordinary broadcasting; a zero
    ``scale`` and zero ``shift`` leave ``x`` unchanged.
    """
    gain = 1 + scale
    return x * gain + shift
class NerfEmbedder(nn.Module):
    """Embed per-pixel values together with a fixed cosine positional basis.

    Each token's input features are concatenated with ``max_freqs ** 2``
    cosine-product features evaluated at that token's position on a square
    grid, then projected to ``hidden_size_input`` by a single linear layer.
    """

    def __init__(self, in_channels: int, hidden_size_input: int, max_freqs: int):
        super().__init__()
        self.max_freqs = max_freqs
        self.embedder = nn.Sequential(nn.Linear(in_channels + max_freqs**2, hidden_size_input, bias=True))
        # Per-instance cache of positional bases keyed by (patch_size, device,
        # dtype). A functools.lru_cache on the method would key on ``self`` and
        # keep every instance (and its cached tensors) alive for the lifetime
        # of the class-level cache.
        self._pos_cache = {}

    def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Return the positional basis for a ``patch_size`` x ``patch_size`` grid.

        The result has shape ``(1, patch_size ** 2, max_freqs ** 2)`` and is
        cached on this instance per ``(patch_size, device, dtype)``; callers
        must not modify the returned tensor in place.
        """
        key = (patch_size, device, dtype)
        cached = self._pos_cache.get(key)
        if cached is not None:
            return cached

        pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
        pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
        pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
        # NOTE(review): ``max_freqs`` points spanning [0, max_freqs] gives
        # generally non-integer frequencies. Left unchanged on purpose, since
        # trained weights presumably depend on this exact basis - confirm
        # before altering.
        freqs = torch.linspace(0, self.max_freqs, self.max_freqs, dtype=dtype, device=device)
        freqs_x = freqs[None, :, None]
        freqs_y = freqs[None, None, :]
        # Down-weight higher-frequency products.
        coeffs = (1 + freqs_x * freqs_y) ** -1
        dct = (
            torch.cos(pos_x.reshape(-1, 1, 1) * freqs_x * torch.pi)
            * torch.cos(pos_y.reshape(-1, 1, 1) * freqs_y * torch.pi)
            * coeffs
        ).view(1, -1, self.max_freqs**2)
        self._pos_cache[key] = dct
        return dct

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Embed ``inputs`` of shape ``(batch, tokens, in_channels)``.

        Raises:
            ValueError: if ``tokens`` is not a perfect square, because the
                positional basis is built for a square grid.
        """
        batch_size, patch_tokens, _ = inputs.shape
        patch_size = int(patch_tokens**0.5)
        if patch_size * patch_size != patch_tokens:
            raise ValueError(
                f"NerfEmbedder expects a square number of tokens per patch, got {patch_tokens}."
            )
        # expand() broadcasts the cached basis over the batch without copying;
        # torch.cat materialises the result anyway.
        dct = self.fetch_pos(patch_size, inputs.device, inputs.dtype).expand(batch_size, -1, -1)
        return self.embedder(torch.cat([inputs, dct], dim=-1))
class ResBlock(nn.Module):
    """Residual MLP block whose normalised input is modulated by a condition.

    The condition ``y`` is mapped to a shift, a scale and a gate. The
    layer-normalised input is shifted and scaled, passed through a two-layer
    SiLU MLP, multiplied by the gate and added back to the input.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Submodules are created in a fixed order so that seeded
        # initialisation and state-dict key order stay stable.
        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(channels, 3 * channels, bias=True),
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the gated, condition-modulated MLP branch."""
        modulation = self.adaLN_modulation(y)
        shift_mlp, scale_mlp, gate_mlp = modulation.chunk(3, dim=-1)
        hidden = _modulate(self.in_ln(x), shift_mlp, scale_mlp)
        return x + gate_mlp * self.mlp(hidden)
class DecoderFinalLayer(nn.Module):
    """Output head: parameter-free layer norm followed by a linear projection."""

    def __init__(self, model_channels: int, out_channels: int):
        super().__init__()
        # elementwise_affine=False: the norm carries no learnable weight/bias.
        self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(model_channels, out_channels, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over its last dimension and project to ``out_channels``."""
        normed = self.norm_final(x)
        return self.linear(normed)
class SimpleMLPAdaLN(nn.Module):
    """Stack of condition-modulated residual MLP blocks with a linear output head.

    Args:
        in_channels: size of the last dimension of the input ``x``.
        model_channels: width of the residual blocks.
        out_channels: size of the last dimension of the output.
        z_channels: size of the conditioning vector ``c``.
        num_res_blocks: number of ``ResBlock`` layers.
        patch_size: side length of a patch; the conditioning is expanded to
            ``patch_size ** 2`` per-token vectors.
        grad_checkpointing: if True, each residual block is run under
            activation checkpointing.
    """

    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        out_channels: int,
        z_channels: int,
        num_res_blocks: int,
        patch_size: int,
        grad_checkpointing: bool = False,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.grad_checkpointing = grad_checkpointing
        self.cond_embed = nn.Linear(z_channels, patch_size**2 * model_channels)
        self.input_proj = nn.Linear(in_channels, model_channels)
        self.res_blocks = nn.ModuleList([ResBlock(model_channels) for _ in range(num_res_blocks)])
        self.final_layer = DecoderFinalLayer(model_channels, out_channels)
        self._init_weights()

    def _init_weights(self) -> None:
        """Zero the modulation and output projections.

        With a zero modulation layer every block's gate is zero, so each block
        starts as the identity; with a zero output projection the network
        initially outputs zeros.
        """
        for block in self.res_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Decode ``x`` conditioned on ``c``.

        ``c`` is projected and reshaped to
        ``(c.shape[0], patch_size ** 2, model_channels)``.
        Assumes ``x`` is ``(c.shape[0], patch_size ** 2, in_channels)`` so the
        two line up token by token - TODO confirm against callers.
        """
        x = self.input_proj(x)
        y = self.cond_embed(c).reshape(c.shape[0], self.patch_size**2, -1)
        use_checkpoint = self.grad_checkpointing and not torch.jit.is_scripting()
        for block in self.res_blocks:
            if use_checkpoint:
                # Pass use_reentrant explicitly: PyTorch warns when it is
                # omitted, and the non-reentrant variant is the recommended one.
                x = checkpoint(block, x, y, use_reentrant=False)
            else:
                x = block(x, y)
        return self.final_layer(x)
@dataclass
class DeCoPatchDecoderOutput(BaseOutput):
    """Return container for ``DeCoPatchDecoderModel.forward``."""

    # Decoder output tensor produced by the forward pass.
    sample: torch.Tensor
class DeCoPatchDecoderModel(ModelMixin, ConfigMixin):
    """Per-patch RGB decoder for DeCo (NerfEmbedder + AdaLN MLP).

    Patch pixels are first embedded by ``NerfEmbedder`` and then decoded by
    ``SimpleMLPAdaLN`` under a per-patch conditioning vector. The decoder
    maps back to ``in_channels`` output channels.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        hidden_size_x: int = 32,
        z_channels: int = 1152,
        num_res_blocks: int = 3,
        patch_size: int = 16,
        max_freqs: int = 8,
    ):
        super().__init__()
        self.x_embedder = NerfEmbedder(in_channels, hidden_size_x, max_freqs=max_freqs)
        # The embedder already produces hidden_size_x features, so the MLP's
        # input and model widths are the same. grad_checkpointing is left at
        # its default here.
        decoder_kwargs = {
            "in_channels": hidden_size_x,
            "model_channels": hidden_size_x,
            "out_channels": in_channels,
            "z_channels": z_channels,
            "num_res_blocks": num_res_blocks,
            "patch_size": patch_size,
        }
        self.dec_net = SimpleMLPAdaLN(**decoder_kwargs)

    def forward(
        self,
        patch_pixels: torch.Tensor,
        conditioning: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[DeCoPatchDecoderOutput, tuple[torch.Tensor]]:
        """
        Decode patch pixels under a per-patch conditioning vector.

        Args:
            patch_pixels (`torch.Tensor`):
                Flattened patch pixels of shape `(batch * num_patches, patch_size ** 2, in_channels)`.
            conditioning (`torch.Tensor`):
                Per-patch conditioning of shape `(batch * num_patches, z_channels)`.
            return_dict (`bool`, defaults to `True`):
                Whether to return a `DeCoPatchDecoderOutput` instead of a plain tuple.

        Returns:
            A `DeCoPatchDecoderOutput` holding the decoded tensor in `sample`, or a
            one-element tuple containing that tensor when `return_dict` is `False`.
        """
        embedded = self.x_embedder(patch_pixels)
        output = self.dec_net(embedded, conditioning)
        if return_dict:
            return DeCoPatchDecoderOutput(sample=output)
        return (output,)
|