| import math |
| from functools import lru_cache |
|
|
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
|
|
# Architecture hyper-parameters of the DINOv3 ViT-H backbone (DINOv3ViTH below).

# Transformer width / depth.
D_MODEL: int = 1280  # token embedding width
N_HEADS: int = 20  # attention heads per block
HEAD_DIM: int = D_MODEL // N_HEADS  # per-head width (1280 // 20 = 64)
N_LAYERS: int = 32  # number of transformer blocks
D_FFN: int = 5120  # hidden width of the gated MLP

# Tokenisation.
N_REGISTERS: int = 4  # register tokens placed right after the CLS token
PATCH_SIZE: int = 16  # patch-embedding conv kernel size and stride, in pixels

# Rotary position embedding.
ROPE_THETA: float = 100.0  # base of the RoPE inverse-frequency schedule
ROPE_RESCALE: float = 2.0  # fixed multiplier on the (-1, 1) patch-centre coordinates

# Normalisation / residual scaling.
LN_EPS: float = 1e-5  # LayerNorm epsilon
LAYERSCALE: float = 1.0  # initial value of every LayerScale vector
|
|
|
|
@lru_cache(maxsize=32)
def _patch_coords_cached(h: int, w: int, device_str: str) -> torch.Tensor:
    """Return RoPE coordinates for the centres of an ``h`` x ``w`` patch grid.

    The result has shape (h * w, 2) in row-major order, with columns (y, x).
    Each centre is first mapped from (0, 1) into (-1, 1) and then multiplied
    by ROPE_RESCALE.

    The tensor is memoised per (h, w, device_str), and the SAME object is
    returned to every caller. It must therefore never be modified in place.
    The device is passed as a string and rebuilt with ``torch.device``.

    NOTE(review): upstream DINOv3 applies coordinate rescaling only as a
    random training-time augmentation. Confirm that the fixed ROPE_RESCALE
    factor here matches how the downstream weights were trained.
    """
    dev = torch.device(device_str)
    # Patch centres at (i + 0.5) / n for i in 0..n-1.
    centers_y = torch.arange(0.5, h, dtype=torch.float32, device=dev) / h
    centers_x = torch.arange(0.5, w, dtype=torch.float32, device=dev) / w
    grid_y, grid_x = torch.meshgrid(centers_y, centers_x, indexing="ij")
    coords = torch.stack((grid_y, grid_x), dim=-1).flatten(0, 1)
    return (2.0 * coords - 1.0) * ROPE_RESCALE
|
|
|
|
def _build_rope(h_patches: int, w_patches: int, dtype: torch.dtype, device: torch.device):
    """Build the RoPE cos/sin tables for an ``h_patches`` x ``w_patches`` grid.

    Returns ``(cos, sin)``. Each has shape (1, 1, h_patches * w_patches, 4 * F),
    where F = HEAD_DIM // 4 inverse frequencies. For the configured HEAD_DIM
    of 64, F is 16 and the last dim equals HEAD_DIM.

    The two leading singleton dims broadcast over (batch, heads). Angles are
    computed in float32; only the final cos/sin are cast to ``dtype``.
    """
    coords = _patch_coords_cached(h_patches, w_patches, str(device))  # (N, 2)
    # Exponents 0, 4/HEAD_DIM, 8/HEAD_DIM, ... < 1, giving inverse frequencies.
    exponents = torch.arange(0, 1, 4 / HEAD_DIM, dtype=torch.float32, device=device)
    inv_freq = 1.0 / (ROPE_THETA ** exponents)
    # (N, 2, F): every axis coordinate times every frequency.
    angles = 2 * math.pi * coords[:, :, None] * inv_freq[None, None, :]
    # (N, 2F) -> (N, 4F): duplicated so both halves swapped by _rotate_half share angles.
    angles = angles.flatten(1, 2).tile(2)
    cos_table = torch.cos(angles).to(dtype)[None, None]
    sin_table = torch.sin(angles).to(dtype)[None, None]
    return cos_table, sin_table
|
|
|
|
| def _rotate_half(x: torch.Tensor) -> torch.Tensor: |
| h = x.shape[-1] // 2 |
| return torch.cat((-x[..., h:], x[..., :h]), dim=-1) |
|
|
|
|
def _apply_rope(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Apply rotary position embeddings to the patch tokens of ``q`` and ``k``.

    The first ``1 + N_REGISTERS`` tokens along dim -2 (CLS + registers) pass
    through unrotated. The remaining tokens are rotated with ``cos``/``sin``,
    whose token dim must broadcast against them.

    Returns the new ``(q, k)``. All ops are out-of-place, so the inputs are
    not modified.
    """
    n_prefix = 1 + N_REGISTERS

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        prefix, patches = t[..., :n_prefix, :], t[..., n_prefix:, :]
        rotated = patches * cos + _rotate_half(patches) * sin
        return torch.cat([prefix, rotated], dim=-2)

    return _rotate(q), _rotate(k)
|
|
|
|
class _Attention(nn.Module):
    """Multi-head self-attention with 2-D RoPE applied to the patch tokens.

    The attribute names q_proj/k_proj/v_proj/o_proj determine the state-dict
    keys. The key projection has no bias; the other three do.
    """

    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(D_MODEL, D_MODEL, bias=True)
        self.k_proj = nn.Linear(D_MODEL, D_MODEL, bias=False)
        self.v_proj = nn.Linear(D_MODEL, D_MODEL, bias=True)
        self.o_proj = nn.Linear(D_MODEL, D_MODEL, bias=True)

    def forward(self, x, cos, sin):
        """Attend over ``x`` of shape (batch, tokens, D_MODEL); output has the same shape."""
        batch, tokens, _ = x.shape

        def to_heads(t):
            # (batch, tokens, D_MODEL) -> (batch, N_HEADS, tokens, HEAD_DIM)
            return t.view(batch, tokens, N_HEADS, HEAD_DIM).transpose(1, 2)

        query = to_heads(self.q_proj(x))
        key = to_heads(self.k_proj(x))
        value = to_heads(self.v_proj(x))
        query, key = _apply_rope(query, key, cos, sin)
        # Explicit scale equals SDPA's default of 1 / sqrt(HEAD_DIM).
        context = F.scaled_dot_product_attention(query, key, value, scale=HEAD_DIM ** -0.5)
        merged = context.transpose(1, 2).reshape(batch, tokens, D_MODEL)
        return self.o_proj(merged)
|
|
|
|
class _GatedMLP(nn.Module):
    """SwiGLU-style feed-forward: ``down_proj(silu(gate_proj(x)) * up_proj(x))``."""

    def __init__(self):
        super().__init__()
        self.gate_proj = nn.Linear(D_MODEL, D_FFN, bias=True)
        self.up_proj = nn.Linear(D_MODEL, D_FFN, bias=True)
        self.down_proj = nn.Linear(D_FFN, D_MODEL, bias=True)

    def forward(self, x):
        gate = F.silu(self.gate_proj(x))
        hidden = gate * self.up_proj(x)
        return self.down_proj(hidden)
|
|
|
|
class _Block(nn.Module):
    """Pre-norm transformer block with per-channel LayerScale on both residual branches."""

    def __init__(self):
        super().__init__()
        # Registration order matters: it fixes state-dict order and the order
        # in which parameters are initialised.
        self.norm1 = nn.LayerNorm(D_MODEL, eps=LN_EPS)
        self.attention = _Attention()
        self.layer_scale1 = nn.Parameter(torch.full((D_MODEL,), LAYERSCALE))
        self.norm2 = nn.LayerNorm(D_MODEL, eps=LN_EPS)
        self.mlp = _GatedMLP()
        self.layer_scale2 = nn.Parameter(torch.full((D_MODEL,), LAYERSCALE))

    def forward(self, x, cos, sin):
        attn_branch = self.attention(self.norm1(x), cos, sin) * self.layer_scale1
        x = x + attn_branch
        mlp_branch = self.mlp(self.norm2(x)) * self.layer_scale2
        return x + mlp_branch
|
|
|
|
class _Embeddings(nn.Module):
    """Turn RGB images into a token sequence ordered [CLS, registers..., patches...]."""

    def __init__(self):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, D_MODEL))
        # Not used by forward(). NOTE(review): presumably kept so checkpoints
        # containing ``embeddings.mask_token`` load without key mismatches.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, D_MODEL))
        self.register_tokens = nn.Parameter(torch.zeros(1, N_REGISTERS, D_MODEL))
        self.patch_embeddings = nn.Conv2d(3, D_MODEL, kernel_size=PATCH_SIZE, stride=PATCH_SIZE)

    def forward(self, pixel_values):
        n = pixel_values.shape[0]
        conv = self.patch_embeddings
        # Cast the input to the conv weight dtype, so float32 pixels also work
        # when the model has been cast to a lower precision.
        feature_map = conv(pixel_values.to(conv.weight.dtype))
        # (n, D_MODEL, H // PATCH_SIZE, W // PATCH_SIZE) -> (n, num_patches, D_MODEL)
        patch_tokens = feature_map.flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(n, -1, -1)
        reg_tokens = self.register_tokens.expand(n, -1, -1)
        return torch.cat([cls_tokens, reg_tokens, patch_tokens], dim=1)
|
|
|
|
class DINOv3ViTH(nn.Module):
    """DINOv3 ViT-H backbone that returns final-LayerNorm tokens for every position.

    Output shape is (batch, 1 + N_REGISTERS + num_patches, D_MODEL), in token
    order [CLS, registers..., patches...].

    The submodule names ``embeddings``, ``layer`` and ``norm`` are the
    state-dict prefixes. NOTE(review): presumably weights come from the
    backbone dict returned by ``_split_and_clean_state_dict``.
    """

    def __init__(self):
        super().__init__()
        self.embeddings = _Embeddings()
        self.layer = nn.ModuleList([_Block() for _ in range(N_LAYERS)])
        self.norm = nn.LayerNorm(D_MODEL, eps=LN_EPS)

    def forward(self, pixel_values):
        _, _, height, width = pixel_values.shape
        tokens = self.embeddings(pixel_values)
        # Patch-grid size. The stride-PATCH_SIZE conv floors the same way.
        grid_h, grid_w = height // PATCH_SIZE, width // PATCH_SIZE
        cos, sin = _build_rope(grid_h, grid_w, tokens.dtype, pixel_values.device)
        for blk in self.layer:
            tokens = blk(tokens, cos, sin)
        return self.norm(tokens)
|
|
|
|
| def _split_and_clean_state_dict(sd: dict) -> tuple[dict, dict]: |
| backbone_sd: dict = {} |
| head_sd: dict = {} |
| for key, value in sd.items(): |
| if key.startswith("backbone."): |
| new_key = key[len("backbone."):] |
| if new_key.startswith("model.layer."): |
| new_key = new_key[len("model."):] |
| backbone_sd[new_key] = value |
| else: |
| head_sd[key] = value |
|
|
| for key in list(backbone_sd.keys()): |
| if ".layer_scale" in key and key.endswith(".lambda1"): |
| backbone_sd[key[:-len(".lambda1")]] = backbone_sd.pop(key) |
|
|
| for key in list(backbone_sd.keys()): |
| if "rope_embeddings" in key: |
| backbone_sd.pop(key) |
|
|
| return backbone_sd, head_sd |
|
|
|
|
class ResLinear(nn.Module):
    """Residual SiLU layer computing ``silu(linear(x)) + x``.

    All arguments are forwarded to ``nn.Linear``. The residual add requires
    the linear output to broadcast against the input, which in practice means
    in_features == out_features.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.linear = nn.Linear(*args, **kwargs)
        self.act_fn = nn.SiLU()

    def forward(self, inputs):
        activated = self.act_fn(self.linear(inputs))
        return activated + inputs
|
|
|
|
class ScoringHead(nn.Module):
    """MLP regression head that maps a feature vector to a single score.

    Layout: Linear(in_size, 256) -> SiLU -> 9 x ResLinear(256, 256) -> Linear(256, 1).
    """

    def __init__(self, in_size):
        super().__init__()
        self.layer1 = nn.Linear(in_size, 256)
        self.act_fn = nn.SiLU()
        # Nine residual layers registered as res_layers.0 .. res_layers.8.
        self.res_layers = nn.Sequential(*[ResLinear(256, 256) for _ in range(9)])
        self.layer4 = nn.Linear(256, 1)

    def forward(self, x):
        hidden = self.act_fn(self.layer1(x))
        hidden = self.res_layers(hidden)
        return self.layer4(hidden)
|
|
|
|
class TaggerAestheticModel(nn.Module):
    """Thin wrapper that scores feature vectors with a ScoringHead.

    The head lives under the ``scoring_head`` attribute, so this module's
    state-dict keys carry a ``scoring_head.`` prefix.
    """

    def __init__(self, feature_dim: int):
        super().__init__()
        self.scoring_head = ScoringHead(feature_dim)

    def forward(self, features):
        head = self.scoring_head
        return head(features)
|
|
|
|
def extract_scoring_head_state_dict(state_dict: dict[str, torch.Tensor], feature_dim: int) -> dict[str, torch.Tensor]:
    """Locate the ScoringHead weights inside ``state_dict``.

    Two layouts are recognised, checked in this order:

    1. ``state_dict`` has exactly the ScoringHead keys (``layer1.weight``,
       ``res_layers.0.linear.weight``, ...). It is returned unchanged (the
       same object, not a copy).
    2. Otherwise, keys starting with ``scoring_head.`` (the layout of a
       ``TaggerAestheticModel`` state dict) are collected with the prefix
       stripped, and all other keys are ignored. The result is returned if
       its key set matches exactly.

    Only key names are compared. Tensor shapes are NOT checked against
    ``feature_dim`` here; presumably a mismatch surfaces later, when the
    result is loaded into a module.

    Args:
        state_dict: Mapping of parameter names to tensors.
        feature_dim: Input width of the ScoringHead. It is used to build the
            reference head whose keys are expected.

    Returns:
        A state dict whose keys equal those of ``ScoringHead(feature_dim)``.

    Raises:
        AssertionError: if neither layout matches. The error is raised
            explicitly, so it also fires under ``python -O``. The message
            reports how many prefixed keys were found and which expected keys
            were missing or unexpected.
    """
    expected_keys = set(ScoringHead(in_size=feature_dim).state_dict().keys())

    if set(state_dict.keys()) == expected_keys:
        return state_dict

    model_prefix = "scoring_head."
    extracted = {k[len(model_prefix):]: v for k, v in state_dict.items() if k.startswith(model_prefix)}
    found = set(extracted.keys())
    if found == expected_keys:
        return extracted

    # Report the mismatch so a wrong or renamed checkpoint is diagnosable.
    missing = sorted(expected_keys - found)
    unexpected = sorted(found - expected_keys)
    raise AssertionError(
        "Could not locate scoring_head weights in provided state dict: "
        f"{len(found)} key(s) carry the {model_prefix!r} prefix; "
        f"missing={missing}, unexpected={unexpected}"
    )