# Uploaded by BiliSakura using huggingface_hub (commit 9dc3cb9, verified).
# Copyright 2026 The HuggingFace Team. All rights reserved.
from __future__ import annotations
from dataclasses import dataclass
from functools import lru_cache
from typing import Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
def _modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
return x * (1 + scale) + shift
class NerfEmbedder(nn.Module):
    """Embed patch tokens together with fixed DCT-style positional features.

    Every one of the ``patch_size ** 2`` tokens of a patch is concatenated with
    ``max_freqs ** 2`` cosine features of its (x, y) position on a ``[0, 1]`` grid
    inside the patch. The result is projected to ``hidden_size_input`` channels by a
    single linear layer.

    Args:
        in_channels: Channels per input token.
        hidden_size_input: Width of the produced embedding.
        max_freqs: Frequencies per axis; the positional table has ``max_freqs ** 2``
            features per token.
    """

    def __init__(self, in_channels: int, hidden_size_input: int, max_freqs: int):
        super().__init__()
        self.max_freqs = max_freqs
        self.embedder = nn.Sequential(nn.Linear(in_channels + max_freqs**2, hidden_size_input, bias=True))
        # Per-instance cache of positional tables keyed by (patch_size, device, dtype).
        # A method-level functools.lru_cache would be shared by all instances and would
        # keep every instance alive, because ``self`` is part of its key.
        self._pos_cache: dict = {}

    def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Return the ``(1, patch_size**2, max_freqs**2)`` positional table.

        The table is computed once per ``(patch_size, device, dtype)`` and then reused.
        The returned tensor is shared, so callers must not modify it in place.
        """
        cache = getattr(self, "_pos_cache", None)
        if cache is None:
            # Instances restored from pickles made before this cache existed lack it.
            cache = {}
            self._pos_cache = cache
        key = (patch_size, device, dtype)
        table = cache.get(key)
        if table is None:
            table = self._build_pos(patch_size, device, dtype)
            cache[key] = table
        return table

    def _build_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Compute the positional table without caching (see ``fetch_pos``)."""
        pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
        pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
        pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
        # NOTE(review): max_freqs points spread over [0, max_freqs] (not integer steps).
        # This presumably matches the features the pretrained weights were trained on,
        # so do not change it without retraining.
        freqs = torch.linspace(0, self.max_freqs, self.max_freqs, dtype=dtype, device=device)
        freqs_x = freqs[None, :, None]
        freqs_y = freqs[None, None, :]
        # Each (fx, fy) product is weighted by 1 / (1 + fx * fy).
        coeffs = (1 + freqs_x * freqs_y) ** -1
        dct = (
            torch.cos(pos_x.reshape(-1, 1, 1) * freqs_x * torch.pi)
            * torch.cos(pos_y.reshape(-1, 1, 1) * freqs_y * torch.pi)
            * coeffs
        ).view(1, -1, self.max_freqs**2)
        return dct

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Embed ``inputs`` of shape ``(batch, tokens, in_channels)``.

        The patch side length is inferred as ``int(tokens ** 0.5)``, so ``tokens`` is
        expected to be a perfect square.
        """
        batch_size, patch_tokens, _ = inputs.shape
        patch_size = int(patch_tokens**0.5)
        dct = self.fetch_pos(patch_size, inputs.device, inputs.dtype)
        # expand() broadcasts the shared table over the batch as a view (no copy).
        # torch.cat materialises a new tensor, so the cached table is never written.
        return self.embedder(torch.cat([inputs, dct.expand(batch_size, -1, -1)], dim=-1))
class ResBlock(nn.Module):
    """Residual MLP block conditioned through adaptive LayerNorm.

    The conditioning tensor ``y`` is mapped to a shift, scale and gate (each
    ``channels`` wide). These modulate the normalised input before the MLP and gate
    the MLP output before the residual add.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
        mlp_layers = [
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        ]
        self.mlp = nn.Sequential(*mlp_layers)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(channels, 3 * channels, bias=True),
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        modulation = self.adaLN_modulation(y)
        shift, scale, gate = modulation.chunk(3, dim=-1)
        modulated = _modulate(self.in_ln(x), shift, scale)
        residual = self.mlp(modulated)
        return x + gate * residual
class DecoderFinalLayer(nn.Module):
    """Output head: a parameter-free LayerNorm followed by a linear projection.

    Args:
        model_channels: Width of the incoming features.
        out_channels: Width of the projected output.
    """

    def __init__(self, model_channels: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(
            normalized_shape=model_channels,
            eps=1e-6,
            elementwise_affine=False,
        )
        self.linear = nn.Linear(in_features=model_channels, out_features=out_channels, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normalized = self.norm_final(x)
        projected = self.linear(normalized)
        return projected
class SimpleMLPAdaLN(nn.Module):
    """Stack of adaLN-conditioned residual MLP blocks with a linear output head.

    ``cond_embed`` maps each conditioning vector of width ``z_channels`` to
    ``patch_size ** 2`` conditioning tokens of width ``model_channels``. Every
    ``ResBlock`` receives these tokens. The adaLN projections and the final linear layer
    are zero-initialised, so at initialisation every ResBlock is an identity map and the
    model outputs zeros.

    Args:
        in_channels: Width of the input tokens.
        model_channels: Hidden width of the residual stack.
        out_channels: Width of the output tokens.
        z_channels: Width of the conditioning vector ``c``.
        num_res_blocks: Number of ``ResBlock`` layers.
        patch_size: Side length of a patch; the conditioning is reshaped to
            ``patch_size ** 2`` tokens.
        grad_checkpointing: Recompute each ResBlock during backward to save memory.
            Only takes effect while autograd is enabled.
    """

    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        out_channels: int,
        z_channels: int,
        num_res_blocks: int,
        patch_size: int,
        grad_checkpointing: bool = False,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.grad_checkpointing = grad_checkpointing
        self.cond_embed = nn.Linear(z_channels, patch_size**2 * model_channels)
        self.input_proj = nn.Linear(in_channels, model_channels)
        self.res_blocks = nn.ModuleList([ResBlock(model_channels) for _ in range(num_res_blocks)])
        self.final_layer = DecoderFinalLayer(model_channels, out_channels)
        self._init_weights()

    def _init_weights(self) -> None:
        """Zero the adaLN projections and the output layer (identity / zero at init)."""
        for block in self.res_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Decode tokens ``x`` conditioned on ``c``.

        Args:
            x: Input tokens, last dimension ``in_channels``. The token axis is expected
                to broadcast against the ``patch_size ** 2`` conditioning tokens —
                TODO confirm callers always pass exactly ``patch_size ** 2`` tokens.
            c: Conditioning of shape ``(batch, z_channels)``.

        Returns:
            Tensor whose last dimension is ``out_channels``.
        """
        x = self.input_proj(x)
        y = self.cond_embed(c).reshape(c.shape[0], self.patch_size**2, -1)
        # Checkpointing only saves memory when a backward pass will follow. Under
        # no_grad/inference it is pure overhead, so skip it there.
        use_checkpoint = (
            self.grad_checkpointing and torch.is_grad_enabled() and not torch.jit.is_scripting()
        )
        for block in self.res_blocks:
            if use_checkpoint:
                # Pass use_reentrant explicitly; PyTorch recommends the non-reentrant
                # variant and warns when the argument is omitted.
                x = checkpoint(block, x, y, use_reentrant=False)
            else:
                x = block(x, y)
        return self.final_layer(x)
@dataclass
class DeCoPatchDecoderOutput(BaseOutput):
    """Output container returned by `DeCoPatchDecoderModel.forward` when `return_dict=True`.

    Args:
        sample (`torch.Tensor`):
            The decoder output produced for each input patch token.
    """

    # Field order/name defines the BaseOutput tuple/dict view; keep `sample` first.
    sample: torch.Tensor
class DeCoPatchDecoderModel(ModelMixin, ConfigMixin):
    """Per-patch RGB decoder for DeCo (NerfEmbedder + AdaLN MLP).

    Patch pixels are first embedded with positional features (``x_embedder``). They are
    then decoded by an adaLN-conditioned residual MLP (``dec_net``) that maps back to
    ``in_channels`` per token. The number of tokens per patch should equal
    ``patch_size ** 2``: ``dec_net`` reshapes the conditioning into that many tokens.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        hidden_size_x: int = 32,
        z_channels: int = 1152,
        num_res_blocks: int = 3,
        patch_size: int = 16,
        max_freqs: int = 8,
    ):
        super().__init__()
        self.x_embedder = NerfEmbedder(in_channels, hidden_size_x, max_freqs=max_freqs)
        decoder_kwargs = dict(
            in_channels=hidden_size_x,
            model_channels=hidden_size_x,
            out_channels=in_channels,
            z_channels=z_channels,
            num_res_blocks=num_res_blocks,
            patch_size=patch_size,
        )
        self.dec_net = SimpleMLPAdaLN(**decoder_kwargs)

    def forward(
        self,
        patch_pixels: torch.Tensor,
        conditioning: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[DeCoPatchDecoderOutput, tuple[torch.Tensor]]:
        """
        Args:
            patch_pixels (`torch.Tensor`):
                Flattened patch pixels of shape `(batch * num_patches, patch_size ** 2, in_channels)`.
            conditioning (`torch.Tensor`):
                Per-patch conditioning of shape `(batch * num_patches, z_channels)`.
            return_dict (`bool`, defaults to `True`):
                Return a `DeCoPatchDecoderOutput` instead of a one-element tuple.

        Returns:
            `DeCoPatchDecoderOutput` or `tuple`: the decoded tokens, whose last
            dimension is `in_channels`.
        """
        embedded = self.x_embedder(patch_pixels)
        decoded = self.dec_net(embedded, conditioning)
        if return_dict:
            return DeCoPatchDecoderOutput(sample=decoded)
        return (decoded,)