# dinac_ae/dit/repa_projection.py
# Uploaded by data-archetype: "Upload DINAC-AE export package" (commit 1b703d5)
"""DINO token/class alignment head used by the DINAC-AE export."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import torch
from torch import Tensor, nn
from common.norms import RMSNorm
from dit.axial_rope2d import (
AxialRoPE2D,
AxialRoPE2DConfig,
AxialRoPE2DCoordMode,
AxialRoPE2DDimLayout,
AxialRoPE2DNormalizeCoords,
)
from dit.blocks import DitBlock
from dit.body_config import DiTConditioning
from dit.position_encoding import DiTPositionEncoding
from dit.xattn_blocks import CrossAttentionBlock, CrossAttentionConfig
if TYPE_CHECKING:
from dit.mlp_types import MLPType
@dataclass(frozen=True)
class DinoTokenAlignmentOutput:
    """Immutable container for the two predictions of the DINO alignment head.

    Attributes:
        class_token: Predicted DINO class token.
        spatial_tokens: Predicted DINO spatial (patch) tokens.
    """

    # Predicted DINO class token.
    class_token: Tensor
    # Predicted DINO spatial patch tokens.
    spatial_tokens: Tensor
def _prepend_identity_rope_prefix(
    *,
    rope_sincos: tuple[Tensor, Tensor],
    prefix_token_count: int,
    device: torch.device,
) -> tuple[Tensor, Tensor]:
    """Prepend identity rotations (sin=0, cos=1) for class/register prefix tokens.

    Args:
        rope_sincos: ``(sin, cos)`` RoPE tables. Rank 2 tables have the token
            axis first; rank 3 tables carry a leading batch axis and the token
            axis second.
        prefix_token_count: Number of prefix rows to insert before the tables.
        device: Device on which the prefix is created and the tables are placed.

    Returns:
        ``(sin, cos)`` with ``prefix_token_count`` extra rows on the token axis.
        The prefix rows leave those tokens unrotated.

    Raises:
        ValueError: If the sin table is neither rank 2 nor rank 3.
    """
    sin_table, cos_table = rope_sincos
    rows = int(prefix_token_count)
    width = int(sin_table.shape[-1])
    identity_sin = torch.zeros((rows, width), device=device, dtype=sin_table.dtype)
    identity_cos = torch.ones((rows, width), device=device, dtype=cos_table.dtype)

    rank = sin_table.dim()
    if rank == 2:
        token_axis = 0
    elif rank == 3:
        # Batched tables: broadcast the shared prefix across the batch axis.
        token_axis = 1
        batch = int(sin_table.shape[0])
        identity_sin = identity_sin.unsqueeze(0).expand(batch, -1, -1)
        identity_cos = identity_cos.unsqueeze(0).expand(batch, -1, -1)
    else:
        raise ValueError(f"Unsupported RoPE tensor rank: {int(rank)}")

    return (
        torch.cat([identity_sin, sin_table.to(device=device)], dim=token_axis),
        torch.cat([identity_cos, cos_table.to(device=device)], dim=token_axis),
    )
class DinoTokenAlignmentHead(nn.Module):
    """Predict DINOv3 spatial tokens and a class token from latent grids.

    ``forward`` does the following:

    - Projects the latent grid to ``model_dim`` channels with a 1x1 convolution.
    - Flattens the grid into spatial tokens.
    - Prepends one learned class token and ``register_token_count`` learned
      register tokens.
    - Runs a single unconditional ``DitBlock`` with 2D axial RoPE; prefix
      tokens receive identity rotations.
    - Computes the class prediction with a cross-attention readout that takes
      the class token as query and all remaining tokens as context.
    - Passes both outputs through a non-affine ``RMSNorm``.
    """

    # Channel count of the incoming latent grid.
    in_channels: int
    # Width of the predicted DINO features (validated equal to ``model_dim``).
    feature_dim: int
    # Hidden width of every token inside the head.
    model_dim: int
    # Number of learned register tokens (validated to be exactly 4).
    register_token_count: int
    in_proj: nn.Conv2d
    initial_class_token: nn.Parameter
    register_tokens: nn.Parameter
    block: DitBlock
    spatial_output_norm: RMSNorm
    class_readout: CrossAttentionBlock
    class_output_norm: RMSNorm
    _axial_rope2d: AxialRoPE2D

    def __init__(
        self,
        *,
        in_channels: int,
        feature_dim: int,
        model_dim: int,
        head_dim: int,
        mlp_ratio: float,
        mlp_activation: MLPType,
        block_index: int,
        register_token_count: int,
    ) -> None:
        """Build the projection, prefix tokens, DiT block, readout and RoPE table.

        Args:
            in_channels: Channel count of the latent grid passed to ``forward``.
            feature_dim: Width of the predicted DINO features; must equal
                ``model_dim``.
            model_dim: Hidden width of the head.
            head_dim: Per-head attention width used for both attention blocks
                and the axial RoPE table; must be positive and divide
                ``model_dim``.
            mlp_ratio: MLP expansion ratio for the DiT block and the readout.
            mlp_activation: MLP type for the DiT block and the readout.
            block_index: Index given to the DiT block; the readout block uses
                ``block_index + 1``.
            register_token_count: Number of learned register tokens; must be 4.

        Raises:
            ValueError: If ``feature_dim != model_dim``,
                ``register_token_count != 4``, or ``head_dim`` is not a
                positive divisor of ``model_dim``.
        """
        super().__init__()
        if int(feature_dim) != int(model_dim):
            raise ValueError("DINAC-AE class head requires feature_dim == model_dim")
        if int(register_token_count) != 4:
            raise ValueError("DINAC-AE class head requires four register tokens")
        head_dim = int(head_dim)
        # Without this check, ``model_dim // head_dim`` silently truncates, so
        # n_heads * head_dim != model_dim while the RoPE table below is still
        # built for ``head_dim``. A zero head_dim would otherwise surface as a
        # bare ZeroDivisionError.
        if head_dim <= 0 or int(model_dim) % head_dim != 0:
            raise ValueError(
                "DINAC-AE class head requires head_dim > 0 dividing model_dim, "
                f"got head_dim={head_dim}, model_dim={int(model_dim)}"
            )
        self.register_token_count = int(register_token_count)
        self.in_channels = int(in_channels)
        self.feature_dim = int(feature_dim)
        self.model_dim = int(model_dim)
        # Head count shared by the self-attention block and the readout.
        n_heads = self.model_dim // head_dim

        self.in_proj = nn.Conv2d(
            self.in_channels,
            self.model_dim,
            kernel_size=1,
            padding=0,
            stride=1,
            bias=True,
        )
        self.initial_class_token = nn.Parameter(torch.empty((1, self.model_dim)))
        self.register_tokens = nn.Parameter(
            torch.empty((self.register_token_count, self.model_dim))
        )
        nn.init.normal_(self.initial_class_token, mean=0.0, std=0.02)
        nn.init.normal_(self.register_tokens, mean=0.0, std=0.02)

        self.block = DitBlock(
            d_model=self.model_dim,
            n_heads=n_heads,
            mlp_ratio=float(mlp_ratio),
            mlp_type=mlp_activation,
            block_index=int(block_index),
            use_norms=True,
            position_encoding=DiTPositionEncoding.ROPE_2D_AXIAL_UNNORMALIZED,
            conditioning=DiTConditioning.UNCOND,
        )
        self.spatial_output_norm = RMSNorm(self.model_dim, affine=False)
        self.class_readout = CrossAttentionBlock(
            query_dim=self.model_dim,
            context_dim=self.model_dim,
            cfg=CrossAttentionConfig(
                n_heads=n_heads,
                head_dim=head_dim,
                query_extra_dim=0,
                context_extra_dim=0,
                mlp_ratio=float(mlp_ratio),
                attn_dropout=0.0,
                mlp_type=mlp_activation,
                activation_config=None,
                use_norms=True,
                block_index=int(block_index) + 1,
                use_attn_residual=True,
            ),
        )
        self.class_output_norm = RMSNorm(self.model_dim, affine=False)
        self._axial_rope2d = AxialRoPE2D(
            head_dim=head_dim,
            cfg=AxialRoPE2DConfig(
                base=10_000.0,
                min_period=None,
                max_period=None,
                coord_mode=AxialRoPE2DCoordMode.PATCH_INDICES,
                normalize_coords=AxialRoPE2DNormalizeCoords.MAX,
                dim_layout=AxialRoPE2DDimLayout.PAIR_INTERLEAVED,
                angle_multiplier=1.0,
                coord_offset=0.0,
                frequency_aware=None,
                beta_warp=None,
                alpha_warp=None,
            ),
        )

    def compile_for_training(self, *, fullgraph: bool, dynamic: bool) -> None:
        """No-op hook kept for source API compatibility."""
        _ = fullgraph, dynamic

    def compile_for_eval(self, *, fullgraph: bool, dynamic: bool) -> None:
        """No-op hook kept for source API compatibility."""
        _ = fullgraph, dynamic

    def forward(self, latents: Tensor, *, t: Tensor) -> DinoTokenAlignmentOutput:
        """Return predicted class and spatial DINO tokens.

        Args:
            latents: Latent grid fed to the 1x1 ``in_proj`` convolution. The
                projection is unpacked as ``(batch, channels, height, width)``.
            t: Accepted for call-site compatibility and ignored. The DiT block
                always receives a zero conditioning vector.

        Returns:
            A ``DinoTokenAlignmentOutput`` with two fields:

            - ``class_token``: position 0 of the cross-attention readout,
              normalized.
            - ``spatial_tokens``: the non-prefix tokens after the DiT block,
              normalized. That is ``height * width`` tokens, assuming the block
              preserves sequence length.
        """
        del t  # the head is timestep-agnostic
        y = self.in_proj(latents)
        batch, _channels, height, width = y.shape
        batch, height, width = int(batch), int(height), int(width)
        # Sequence layout: [class token, register tokens..., spatial tokens...].
        prefix_token_count = 1 + self.register_token_count

        # (batch, model_dim, H, W) -> (batch, H * W, model_dim)
        spatial_tokens = y.flatten(2).transpose(1, 2)
        class_token = self.initial_class_token.to(device=y.device, dtype=y.dtype)
        class_token = class_token.unsqueeze(0).expand(batch, -1, -1)
        register_tokens = self.register_tokens.to(device=y.device, dtype=y.dtype)
        register_tokens = register_tokens.unsqueeze(0).expand(batch, -1, -1)
        tokens = torch.cat([class_token, register_tokens, spatial_tokens], dim=1)

        # The axial RoPE table covers only the H x W grid. Prefix tokens get
        # sin=0 / cos=1 rows so they pass through the rotation unchanged.
        rope_sincos = _prepend_identity_rope_prefix(
            rope_sincos=self._axial_rope2d(H=height, W=width, scales=None),
            prefix_token_count=prefix_token_count,
            device=y.device,
        )
        # Unconditional block: a zero conditioning vector stands in for t.
        cond = torch.zeros((batch, self.model_dim), device=y.device, dtype=y.dtype)
        tokens = self.block(
            tokens,
            hw=(height, width),
            cond_vec=cond,
            adaln_m=None,
            rope_sincos=rope_sincos,
            generator=None,
        )

        # Class readout: query = class token, context = registers + spatial tokens.
        class_query = tokens[:, :1, :]
        context = tokens[:, 1:, :]
        class_output = self.class_readout(class_query, context)[:, 0, :]
        class_output = self.class_output_norm(class_output)
        predicted_spatial = self.spatial_output_norm(tokens[:, prefix_token_count:, :])
        return DinoTokenAlignmentOutput(
            class_token=class_output,
            spatial_tokens=predicted_spatial,
        )
# Public API exported by ``from ... import *``.
__all__ = [
    "DinoTokenAlignmentHead",
    "DinoTokenAlignmentOutput",
]