File size: 7,997 Bytes
1b703d5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 | """DINO token/class alignment head used by the DINAC-AE export."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import torch
from torch import Tensor, nn
from common.norms import RMSNorm
from dit.axial_rope2d import (
AxialRoPE2D,
AxialRoPE2DConfig,
AxialRoPE2DCoordMode,
AxialRoPE2DDimLayout,
AxialRoPE2DNormalizeCoords,
)
from dit.blocks import DitBlock
from dit.body_config import DiTConditioning
from dit.position_encoding import DiTPositionEncoding
from dit.xattn_blocks import CrossAttentionBlock, CrossAttentionConfig
if TYPE_CHECKING:
from dit.mlp_types import MLPType
@dataclass(frozen=True, eq=True, repr=True)
class DinoTokenAlignmentOutput:
    """Immutable pair of predictions produced by the DINO token alignment head.

    Fields are positional in declaration order (``class_token`` first), so
    ``DinoTokenAlignmentOutput(cls_pred, spatial_pred)`` is also valid.
    """

    # Predicted DINO class token. DinoTokenAlignmentHead selects it with
    # ``[:, 0, :]`` — presumably (batch, model_dim); confirm at call sites.
    class_token: Tensor
    # Predicted spatial patch tokens, one per latent-grid position —
    # presumably (batch, height * width, model_dim); confirm at call sites.
    spatial_tokens: Tensor
def _prepend_identity_rope_prefix(
*,
rope_sincos: tuple[Tensor, Tensor],
prefix_token_count: int,
device: torch.device,
) -> tuple[Tensor, Tensor]:
"""Prepend no-op RoPE entries for class/register prefix tokens."""
sin, cos = rope_sincos
prefix_shape = (int(prefix_token_count), int(sin.shape[-1]))
prefix_sin = torch.zeros(prefix_shape, device=device, dtype=sin.dtype)
prefix_cos = torch.ones(prefix_shape, device=device, dtype=cos.dtype)
match sin.dim():
case 2:
return (
torch.cat([prefix_sin, sin.to(device=device)], dim=0),
torch.cat([prefix_cos, cos.to(device=device)], dim=0),
)
case 3:
batch = int(sin.shape[0])
return (
torch.cat(
[
prefix_sin.unsqueeze(0).expand(batch, -1, -1),
sin.to(device=device),
],
dim=1,
),
torch.cat(
[
prefix_cos.unsqueeze(0).expand(batch, -1, -1),
cos.to(device=device),
],
dim=1,
),
)
case _ as unreachable:
raise ValueError(f"Unsupported RoPE tensor rank: {int(unreachable)}")
class DinoTokenAlignmentHead(nn.Module):
    """Predict DINOv3 spatial tokens and a class token from latent grids.

    Pipeline implemented by ``forward``:

    1. A 1x1 convolution (``in_proj``) maps the latent grid to ``model_dim``
       channels; the grid is flattened into a row-major token sequence.
    2. One learned class token and ``register_token_count`` learned register
       tokens are prepended to the spatial tokens.
    3. A single unconditional ``DitBlock`` mixes all tokens. Spatial tokens get
       2D axial RoPE; the prefix tokens get identity (no-op) RoPE.
    4. The class token cross-attends (``class_readout``) over the register and
       spatial tokens to produce the class prediction. The spatial tokens are
       RMS-normalized and returned directly.
    """

    in_channels: int
    feature_dim: int
    model_dim: int
    register_token_count: int
    in_proj: nn.Conv2d
    initial_class_token: nn.Parameter
    register_tokens: nn.Parameter
    block: DitBlock
    spatial_output_norm: RMSNorm
    class_readout: CrossAttentionBlock
    class_output_norm: RMSNorm
    _axial_rope2d: AxialRoPE2D

    def __init__(
        self,
        *,
        in_channels: int,
        feature_dim: int,
        model_dim: int,
        head_dim: int,
        mlp_ratio: float,
        mlp_activation: MLPType,
        block_index: int,
        register_token_count: int,
    ) -> None:
        """Build the projection, prefix tokens, mixer block, and readout.

        Args:
            in_channels: Channel count of the input latent grid (input of the
                1x1 ``in_proj`` convolution).
            feature_dim: DINO feature width; must equal ``model_dim``.
            model_dim: Token width used throughout the head.
            head_dim: Per-head attention width. It sets the head count of the
                ``DitBlock`` and the cross-attention readout, and the size of
                the axial RoPE tables. It must be positive and divide
                ``model_dim``.
            mlp_ratio: MLP expansion ratio, forwarded to both attention blocks.
            mlp_activation: MLP type, forwarded to both attention blocks.
            block_index: Index given to the ``DitBlock``; the class readout
                uses ``block_index + 1``.
            register_token_count: Number of learned register tokens; must be 4.

        Raises:
            ValueError: If ``feature_dim != model_dim``, if
                ``register_token_count != 4``, or if ``head_dim`` is not a
                positive divisor of ``model_dim``.
        """
        super().__init__()
        if int(feature_dim) != int(model_dim):
            raise ValueError("DINAC-AE class head requires feature_dim == model_dim")
        if int(register_token_count) != 4:
            raise ValueError("DINAC-AE class head requires four register tokens")
        head_dim = int(head_dim)
        # Without this check, model_dim // head_dim floors silently and gives
        # n_heads * head_dim != model_dim. The attention head layout would then
        # disagree with the RoPE tables built for ``head_dim`` below. A zero
        # head_dim would otherwise raise an opaque ZeroDivisionError.
        if head_dim <= 0 or int(model_dim) % head_dim != 0:
            raise ValueError(
                "DINAC-AE class head requires a positive head_dim that divides "
                f"model_dim (got model_dim={int(model_dim)}, head_dim={head_dim})"
            )
        self.register_token_count = int(register_token_count)
        self.in_channels = int(in_channels)
        self.feature_dim = int(feature_dim)
        self.model_dim = int(model_dim)
        n_heads = self.model_dim // head_dim

        # NOTE: submodule/parameter creation order is kept stable. It fixes the
        # .parameters() order (optimizer checkpoints) and the RNG draws made by
        # the default inits.
        self.in_proj = nn.Conv2d(
            self.in_channels,
            self.model_dim,
            kernel_size=1,
            padding=0,
            stride=1,
            bias=True,
        )
        self.initial_class_token = nn.Parameter(torch.empty((1, self.model_dim)))
        self.register_tokens = nn.Parameter(
            torch.empty((self.register_token_count, self.model_dim))
        )
        nn.init.normal_(self.initial_class_token, mean=0.0, std=0.02)
        nn.init.normal_(self.register_tokens, mean=0.0, std=0.02)
        # NOTE(review): the block is declared ROPE_2D_AXIAL_UNNORMALIZED, while
        # the RoPE module below normalizes coords with MAX. Presumably the flag
        # only tells DitBlock that RoPE is supplied externally — confirm.
        self.block = DitBlock(
            d_model=self.model_dim,
            n_heads=n_heads,
            mlp_ratio=float(mlp_ratio),
            mlp_type=mlp_activation,
            block_index=int(block_index),
            use_norms=True,
            position_encoding=DiTPositionEncoding.ROPE_2D_AXIAL_UNNORMALIZED,
            conditioning=DiTConditioning.UNCOND,
        )
        self.spatial_output_norm = RMSNorm(self.model_dim, affine=False)
        self.class_readout = CrossAttentionBlock(
            query_dim=self.model_dim,
            context_dim=self.model_dim,
            cfg=CrossAttentionConfig(
                n_heads=n_heads,
                head_dim=head_dim,
                query_extra_dim=0,
                context_extra_dim=0,
                mlp_ratio=float(mlp_ratio),
                attn_dropout=0.0,
                mlp_type=mlp_activation,
                activation_config=None,
                use_norms=True,
                block_index=int(block_index) + 1,
                use_attn_residual=True,
            ),
        )
        self.class_output_norm = RMSNorm(self.model_dim, affine=False)
        self._axial_rope2d = AxialRoPE2D(
            head_dim=head_dim,
            cfg=AxialRoPE2DConfig(
                base=10_000.0,
                min_period=None,
                max_period=None,
                coord_mode=AxialRoPE2DCoordMode.PATCH_INDICES,
                normalize_coords=AxialRoPE2DNormalizeCoords.MAX,
                dim_layout=AxialRoPE2DDimLayout.PAIR_INTERLEAVED,
                angle_multiplier=1.0,
                coord_offset=0.0,
                frequency_aware=None,
                beta_warp=None,
                alpha_warp=None,
            ),
        )

    def compile_for_training(self, *, fullgraph: bool, dynamic: bool) -> None:
        """No-op hook kept for source API compatibility."""
        _ = fullgraph, dynamic

    def compile_for_eval(self, *, fullgraph: bool, dynamic: bool) -> None:
        """No-op hook kept for source API compatibility."""
        _ = fullgraph, dynamic

    def forward(self, latents: Tensor, *, t: Tensor) -> DinoTokenAlignmentOutput:
        """Return predicted class and spatial DINO tokens.

        Args:
            latents: Batched latent grid ``(batch, in_channels, height, width)``.
            t: Accepted but ignored. The head is unconditional and feeds an
                all-zero conditioning vector to its ``DitBlock``.

        Returns:
            ``DinoTokenAlignmentOutput``.
            ``class_token`` is position 0 of the normalized class readout,
            presumably ``(batch, model_dim)``.
            ``spatial_tokens`` are the normalized non-prefix tokens, presumably
            ``(batch, height * width, model_dim)``.

        Raises:
            ValueError: If ``latents`` is not a batched 4-D tensor.
        """
        _ = t  # unconditional head: timestep is intentionally unused
        prefix_token_count = 1 + self.register_token_count

        y = self.in_proj(latents)
        # Conv2d also accepts unbatched 3-D input. Reject it explicitly rather
        # than failing later in the shape unpack with an unclear message.
        if y.dim() != 4:
            raise ValueError(
                "DinoTokenAlignmentHead expects batched latents of shape "
                f"(batch, channels, height, width); got {tuple(latents.shape)}"
            )
        batch, _channels, height, width = y.shape
        # (B, D, H, W) -> (B, H*W, D), row-major over the grid.
        spatial_tokens = y.flatten(2).transpose(1, 2)

        # Learned prefix tokens, broadcast (as views) across the batch.
        class_token = (
            self.initial_class_token.to(device=y.device, dtype=y.dtype)
            .unsqueeze(0)
            .expand(int(batch), -1, -1)
        )
        register_tokens = (
            self.register_tokens.to(device=y.device, dtype=y.dtype)
            .unsqueeze(0)
            .expand(int(batch), -1, -1)
        )
        tokens = torch.cat([class_token, register_tokens, spatial_tokens], dim=1)

        # Axial RoPE covers only the H*W grid positions. Identity rows keep the
        # class/register tokens unrotated so the tables line up with `tokens`.
        rope_sincos = _prepend_identity_rope_prefix(
            rope_sincos=self._axial_rope2d(H=int(height), W=int(width), scales=None),
            prefix_token_count=prefix_token_count,
            device=y.device,
        )
        cond = torch.zeros(
            (int(batch), self.model_dim),
            device=y.device,
            dtype=y.dtype,
        )
        tokens = self.block(
            tokens,
            hw=(int(height), int(width)),
            cond_vec=cond,
            adaln_m=None,
            rope_sincos=rope_sincos,
            generator=None,
        )

        # Class token queries the register + spatial tokens.
        class_query = tokens[:, :1, :]
        context = tokens[:, 1:, :]
        class_output = self.class_readout(class_query, context)[:, 0, :]
        class_output = self.class_output_norm(class_output)
        predicted_spatial = self.spatial_output_norm(tokens[:, prefix_token_count:, :])
        return DinoTokenAlignmentOutput(
            class_token=class_output,
            spatial_tokens=predicted_spatial,
        )
# Public API of this module.
__all__ = [
    "DinoTokenAlignmentHead",
    "DinoTokenAlignmentOutput",
]
|