import math from functools import lru_cache import torch from torch import nn import torch.nn.functional as F D_MODEL = 1280 N_HEADS = 20 HEAD_DIM = D_MODEL // N_HEADS N_LAYERS = 32 D_FFN = 5120 N_REGISTERS = 4 PATCH_SIZE = 16 ROPE_THETA = 100.0 ROPE_RESCALE = 2.0 LN_EPS = 1e-5 LAYERSCALE = 1.0 @lru_cache(maxsize=32) def _patch_coords_cached(h: int, w: int, device_str: str) -> torch.Tensor: device = torch.device(device_str) cy = torch.arange(0.5, h, dtype=torch.float32, device=device) / h cx = torch.arange(0.5, w, dtype=torch.float32, device=device) / w coords = torch.stack(torch.meshgrid(cy, cx, indexing="ij"), dim=-1).flatten(0, 1) coords = 2.0 * coords - 1.0 coords = coords * ROPE_RESCALE return coords def _build_rope(h_patches: int, w_patches: int, dtype: torch.dtype, device: torch.device): coords = _patch_coords_cached(h_patches, w_patches, str(device)) inv_freq = 1.0 / (ROPE_THETA ** torch.arange(0, 1, 4 / HEAD_DIM, dtype=torch.float32, device=device)) angles = 2 * math.pi * coords[:, :, None] * inv_freq[None, None, :] angles = angles.flatten(1, 2).tile(2) cos = torch.cos(angles).to(dtype).unsqueeze(0).unsqueeze(0) sin = torch.sin(angles).to(dtype).unsqueeze(0).unsqueeze(0) return cos, sin def _rotate_half(x: torch.Tensor) -> torch.Tensor: h = x.shape[-1] // 2 return torch.cat((-x[..., h:], x[..., :h]), dim=-1) def _apply_rope(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): n_pre = 1 + N_REGISTERS q_pre, q_pat = q[..., :n_pre, :], q[..., n_pre:, :] k_pre, k_pat = k[..., :n_pre, :], k[..., n_pre:, :] q_pat = q_pat * cos + _rotate_half(q_pat) * sin k_pat = k_pat * cos + _rotate_half(k_pat) * sin return torch.cat([q_pre, q_pat], dim=-2), torch.cat([k_pre, k_pat], dim=-2) class _Attention(nn.Module): def __init__(self): super().__init__() self.q_proj = nn.Linear(D_MODEL, D_MODEL, bias=True) self.k_proj = nn.Linear(D_MODEL, D_MODEL, bias=False) self.v_proj = nn.Linear(D_MODEL, D_MODEL, bias=True) self.o_proj = nn.Linear(D_MODEL, D_MODEL, bias=True) def forward(self, x, cos, sin): bsz, seq_len, _ = x.shape q = self.q_proj(x).view(bsz, seq_len, N_HEADS, HEAD_DIM).transpose(1, 2) k = self.k_proj(x).view(bsz, seq_len, N_HEADS, HEAD_DIM).transpose(1, 2) v = self.v_proj(x).view(bsz, seq_len, N_HEADS, HEAD_DIM).transpose(1, 2) q, k = _apply_rope(q, k, cos, sin) out = F.scaled_dot_product_attention(q, k, v, scale=HEAD_DIM ** -0.5) return self.o_proj(out.transpose(1, 2).reshape(bsz, seq_len, D_MODEL)) class _GatedMLP(nn.Module): def __init__(self): super().__init__() self.gate_proj = nn.Linear(D_MODEL, D_FFN, bias=True) self.up_proj = nn.Linear(D_MODEL, D_FFN, bias=True) self.down_proj = nn.Linear(D_FFN, D_MODEL, bias=True) def forward(self, x): return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) class _Block(nn.Module): def __init__(self): super().__init__() self.norm1 = nn.LayerNorm(D_MODEL, eps=LN_EPS) self.attention = _Attention() self.layer_scale1 = nn.Parameter(torch.full((D_MODEL,), LAYERSCALE)) self.norm2 = nn.LayerNorm(D_MODEL, eps=LN_EPS) self.mlp = _GatedMLP() self.layer_scale2 = nn.Parameter(torch.full((D_MODEL,), LAYERSCALE)) def forward(self, x, cos, sin): x = x + self.attention(self.norm1(x), cos, sin) * self.layer_scale1 x = x + self.mlp(self.norm2(x)) * self.layer_scale2 return x class _Embeddings(nn.Module): def __init__(self): super().__init__() self.cls_token = nn.Parameter(torch.zeros(1, 1, D_MODEL)) self.mask_token = nn.Parameter(torch.zeros(1, 1, D_MODEL)) self.register_tokens = nn.Parameter(torch.zeros(1, N_REGISTERS, D_MODEL)) self.patch_embeddings = nn.Conv2d(3, D_MODEL, kernel_size=PATCH_SIZE, stride=PATCH_SIZE) def forward(self, pixel_values): batch = pixel_values.shape[0] dtype = self.patch_embeddings.weight.dtype patches = self.patch_embeddings(pixel_values.to(dtype)).flatten(2).transpose(1, 2) cls = self.cls_token.expand(batch, -1, -1) regs = self.register_tokens.expand(batch, -1, -1) return torch.cat([cls, regs, patches], dim=1) class DINOv3ViTH(nn.Module): def __init__(self): super().__init__() self.embeddings = _Embeddings() self.layer = nn.ModuleList([_Block() for _ in range(N_LAYERS)]) self.norm = nn.LayerNorm(D_MODEL, eps=LN_EPS) def forward(self, pixel_values): _, _, height, width = pixel_values.shape x = self.embeddings(pixel_values) h_p, w_p = height // PATCH_SIZE, width // PATCH_SIZE cos, sin = _build_rope(h_p, w_p, x.dtype, pixel_values.device) for block in self.layer: x = block(x, cos, sin) return self.norm(x) def _split_and_clean_state_dict(sd: dict) -> tuple[dict, dict]: backbone_sd: dict = {} head_sd: dict = {} for key, value in sd.items(): if key.startswith("backbone."): new_key = key[len("backbone."):] if new_key.startswith("model.layer."): new_key = new_key[len("model."):] backbone_sd[new_key] = value else: head_sd[key] = value for key in list(backbone_sd.keys()): if ".layer_scale" in key and key.endswith(".lambda1"): backbone_sd[key[:-len(".lambda1")]] = backbone_sd.pop(key) for key in list(backbone_sd.keys()): if "rope_embeddings" in key: backbone_sd.pop(key) return backbone_sd, head_sd class ResLinear(nn.Module): def __init__(self, *args, **kwargs): super().__init__() self.linear = nn.Linear(*args, **kwargs) self.act_fn = nn.SiLU() def forward(self, inputs): return self.act_fn(self.linear(inputs)) + inputs class ScoringHead(nn.Module): def __init__(self, in_size): super().__init__() self.layer1 = nn.Linear(in_size, 256) self.act_fn = nn.SiLU() self.res_layers = nn.Sequential( ResLinear(256, 256), ResLinear(256, 256), ResLinear(256, 256), ResLinear(256, 256), ResLinear(256, 256), ResLinear(256, 256), ResLinear(256, 256), ResLinear(256, 256), ResLinear(256, 256), ) self.layer4 = nn.Linear(256, 1) def forward(self, x): x = self.layer1(x) x = self.act_fn(x) x = self.res_layers(x) return self.layer4(x) class TaggerAestheticModel(nn.Module): def __init__(self, feature_dim: int): super().__init__() self.scoring_head = ScoringHead(feature_dim) def forward(self, features): return self.scoring_head(features) def extract_scoring_head_state_dict(state_dict: dict[str, torch.Tensor], feature_dim: int) -> dict[str, torch.Tensor]: expected_head = ScoringHead(in_size=feature_dim).state_dict() expected_keys = set(expected_head.keys()) full_keys = set(state_dict.keys()) if full_keys == expected_keys: return state_dict model_prefix = "scoring_head." extracted = {k[len(model_prefix):]: v for k, v in state_dict.items() if k.startswith(model_prefix)} if set(extracted.keys()) == expected_keys: return extracted raise AssertionError("Could not locate scoring_head weights in provided state dict")