"""DINO token/class alignment head used by the DINAC-AE export.""" from __future__ import annotations from dataclasses import dataclass from typing import TYPE_CHECKING import torch from torch import Tensor, nn from common.norms import RMSNorm from dit.axial_rope2d import ( AxialRoPE2D, AxialRoPE2DConfig, AxialRoPE2DCoordMode, AxialRoPE2DDimLayout, AxialRoPE2DNormalizeCoords, ) from dit.blocks import DitBlock from dit.body_config import DiTConditioning from dit.position_encoding import DiTPositionEncoding from dit.xattn_blocks import CrossAttentionBlock, CrossAttentionConfig if TYPE_CHECKING: from dit.mlp_types import MLPType @dataclass(frozen=True) class DinoTokenAlignmentOutput: """Predicted DINO class token and spatial patch tokens.""" class_token: Tensor spatial_tokens: Tensor def _prepend_identity_rope_prefix( *, rope_sincos: tuple[Tensor, Tensor], prefix_token_count: int, device: torch.device, ) -> tuple[Tensor, Tensor]: """Prepend no-op RoPE entries for class/register prefix tokens.""" sin, cos = rope_sincos prefix_shape = (int(prefix_token_count), int(sin.shape[-1])) prefix_sin = torch.zeros(prefix_shape, device=device, dtype=sin.dtype) prefix_cos = torch.ones(prefix_shape, device=device, dtype=cos.dtype) match sin.dim(): case 2: return ( torch.cat([prefix_sin, sin.to(device=device)], dim=0), torch.cat([prefix_cos, cos.to(device=device)], dim=0), ) case 3: batch = int(sin.shape[0]) return ( torch.cat( [ prefix_sin.unsqueeze(0).expand(batch, -1, -1), sin.to(device=device), ], dim=1, ), torch.cat( [ prefix_cos.unsqueeze(0).expand(batch, -1, -1), cos.to(device=device), ], dim=1, ), ) case _ as unreachable: raise ValueError(f"Unsupported RoPE tensor rank: {int(unreachable)}") class DinoTokenAlignmentHead(nn.Module): """Predict DINOv3 spatial tokens and a class token from latent grids.""" in_channels: int feature_dim: int model_dim: int register_token_count: int in_proj: nn.Conv2d initial_class_token: nn.Parameter register_tokens: nn.Parameter block: DitBlock spatial_output_norm: RMSNorm class_readout: CrossAttentionBlock class_output_norm: RMSNorm _axial_rope2d: AxialRoPE2D def __init__( self, *, in_channels: int, feature_dim: int, model_dim: int, head_dim: int, mlp_ratio: float, mlp_activation: MLPType, block_index: int, register_token_count: int, ) -> None: super().__init__() if int(feature_dim) != int(model_dim): raise ValueError("DINAC-AE class head requires feature_dim == model_dim") if int(register_token_count) != 4: raise ValueError("DINAC-AE class head requires four register tokens") self.register_token_count = int(register_token_count) self.in_channels = int(in_channels) self.feature_dim = int(feature_dim) self.model_dim = int(model_dim) self.in_proj = nn.Conv2d( self.in_channels, self.model_dim, kernel_size=1, padding=0, stride=1, bias=True, ) self.initial_class_token = nn.Parameter(torch.empty((1, self.model_dim))) self.register_tokens = nn.Parameter( torch.empty((self.register_token_count, self.model_dim)) ) nn.init.normal_(self.initial_class_token, mean=0.0, std=0.02) nn.init.normal_(self.register_tokens, mean=0.0, std=0.02) conditioning = DiTConditioning.UNCOND self.block = DitBlock( d_model=self.model_dim, n_heads=int(self.model_dim // int(head_dim)), mlp_ratio=float(mlp_ratio), mlp_type=mlp_activation, block_index=int(block_index), use_norms=True, position_encoding=DiTPositionEncoding.ROPE_2D_AXIAL_UNNORMALIZED, conditioning=conditioning, ) self.spatial_output_norm = RMSNorm(self.model_dim, affine=False) self.class_readout = CrossAttentionBlock( query_dim=self.model_dim, context_dim=self.model_dim, cfg=CrossAttentionConfig( n_heads=int(self.model_dim // int(head_dim)), head_dim=int(head_dim), query_extra_dim=0, context_extra_dim=0, mlp_ratio=float(mlp_ratio), attn_dropout=0.0, mlp_type=mlp_activation, activation_config=None, use_norms=True, block_index=int(block_index) + 1, use_attn_residual=True, ), ) self.class_output_norm = RMSNorm(self.model_dim, affine=False) self._axial_rope2d = AxialRoPE2D( head_dim=int(head_dim), cfg=AxialRoPE2DConfig( base=10_000.0, min_period=None, max_period=None, coord_mode=AxialRoPE2DCoordMode.PATCH_INDICES, normalize_coords=AxialRoPE2DNormalizeCoords.MAX, dim_layout=AxialRoPE2DDimLayout.PAIR_INTERLEAVED, angle_multiplier=1.0, coord_offset=0.0, frequency_aware=None, beta_warp=None, alpha_warp=None, ), ) def compile_for_training(self, *, fullgraph: bool, dynamic: bool) -> None: """No-op hook kept for source API compatibility.""" _ = fullgraph, dynamic def compile_for_eval(self, *, fullgraph: bool, dynamic: bool) -> None: """No-op hook kept for source API compatibility.""" _ = fullgraph, dynamic def forward(self, latents: Tensor, *, t: Tensor) -> DinoTokenAlignmentOutput: """Return predicted class and spatial DINO tokens.""" y = self.in_proj(latents) batch, _channels, height, width = y.shape spatial_tokens = y.flatten(2).transpose(1, 2) class_token = self.initial_class_token.to(device=y.device, dtype=y.dtype) class_token = class_token.unsqueeze(0).expand(int(batch), -1, -1) register_tokens = self.register_tokens.to(device=y.device, dtype=y.dtype) register_tokens = register_tokens.unsqueeze(0).expand(int(batch), -1, -1) tokens = torch.cat([class_token, register_tokens, spatial_tokens], dim=1) rope_sincos = _prepend_identity_rope_prefix( rope_sincos=self._axial_rope2d(H=int(height), W=int(width), scales=None), prefix_token_count=int(1 + self.register_token_count), device=y.device, ) _ = t cond = torch.zeros( (int(batch), self.model_dim), device=y.device, dtype=y.dtype, ) tokens = self.block( tokens, hw=(int(height), int(width)), cond_vec=cond, adaln_m=None, rope_sincos=rope_sincos, generator=None, ) class_query = tokens[:, :1, :] context = tokens[:, 1:, :] class_output = self.class_readout(class_query, context)[:, 0, :] class_output = self.class_output_norm(class_output) prefix_token_count = int(1 + self.register_token_count) predicted_spatial = self.spatial_output_norm(tokens[:, prefix_token_count:, :]) return DinoTokenAlignmentOutput( class_token=class_output, spatial_tokens=predicted_spatial, ) __all__ = ["DinoTokenAlignmentHead", "DinoTokenAlignmentOutput"]