# e6-visual-ratings / model.py
# Uploaded by taigasan ("Upload model files and artifacts", commit fa21e63, verified); 7.6 kB.
import math
from functools import lru_cache
import torch
from torch import nn
import torch.nn.functional as F
# ---- transformer trunk -------------------------------------------------
D_MODEL = 1280                  # token embedding width
N_LAYERS = 32                   # number of _Block layers in the backbone
N_HEADS = 20                    # attention heads per block
HEAD_DIM = D_MODEL // N_HEADS   # per-head width (64 with the values above)
D_FFN = 5120                    # hidden width of the gated MLP
# ---- tokenization ------------------------------------------------------
PATCH_SIZE = 16                 # square patch side, in pixels
N_REGISTERS = 4                 # register tokens placed after the CLS token
# ---- rotary position embedding -----------------------------------------
ROPE_THETA = 100.0              # base of the RoPE frequency schedule
ROPE_RESCALE = 2.0              # multiplier applied to normalized patch coordinates
# ---- normalization / residual scaling ----------------------------------
LN_EPS = 1e-5                   # LayerNorm epsilon
LAYERSCALE = 1.0                # initial value of the per-channel LayerScale vectors
@lru_cache(maxsize=32)
def _patch_coords_cached(h: int, w: int, device_str: str) -> torch.Tensor:
    """Return the ``(h * w, 2)`` grid of patch-center coordinates in row-major order.

    Each row is ``(y, x)``. Centers use the ``(i + 0.5) / n`` convention, are
    mapped into ``(-1, 1)`` and then multiplied by ``ROPE_RESCALE``.

    Results are memoized per ``(h, w, device_str)`` and the same tensor object
    is returned on every cache hit, so callers must not modify it in place.
    """
    dev = torch.device(device_str)
    rows = torch.arange(0.5, h, dtype=torch.float32, device=dev).div(h)
    cols = torch.arange(0.5, w, dtype=torch.float32, device=dev).div(w)
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing="ij")
    centers = torch.stack((grid_y, grid_x), dim=-1).flatten(0, 1)  # (h*w, 2)
    return (centers * 2.0 - 1.0) * ROPE_RESCALE
def _build_rope(h_patches: int, w_patches: int, dtype: torch.dtype, device: torch.device):
    """Build the 2-D rotary-embedding ``cos``/``sin`` tables for a patch grid.

    Args:
        h_patches: number of patch rows.
        w_patches: number of patch columns.
        dtype: dtype of the returned tables (the activations' dtype).
        device: device on which the tables are created.

    Returns:
        ``(cos, sin)``, each of shape ``(1, 1, h_patches * w_patches, HEAD_DIM)``,
        broadcastable against per-head patch queries/keys.

    Raises:
        ValueError: if ``HEAD_DIM`` is not divisible by 4. Each of the two
            coordinate axes gets ``HEAD_DIM // 4`` frequencies, and the table
            is then tiled twice to reach ``HEAD_DIM`` columns.
    """
    if HEAD_DIM % 4 != 0:
        raise ValueError(f"HEAD_DIM must be divisible by 4 for 2-D RoPE, got {HEAD_DIM}")

    coords = _patch_coords_cached(h_patches, w_patches, str(device))  # (HW, 2)

    # Exponents 0, 4/D, 8/D, ... (< 1), built from an integer-stepped arange so
    # there are exactly HEAD_DIM // 4 of them. A fractional step such as
    # arange(0, 1, 4 / HEAD_DIM) can gain or lose an element through float
    # rounding at the endpoint. For HEAD_DIM == 64 both forms give identical values.
    exponents = torch.arange(0, HEAD_DIM, 4, dtype=torch.float32, device=device) / HEAD_DIM
    inv_freq = 1.0 / (ROPE_THETA ** exponents)  # (HEAD_DIM // 4,)

    # (HW, 2, HEAD_DIM // 4): one angle per (patch, axis, frequency).
    angles = 2 * math.pi * coords[:, :, None] * inv_freq[None, None, :]
    # -> (HW, HEAD_DIM // 2), then duplicated to (HW, HEAD_DIM). Column j and
    # column j + HEAD_DIM // 2 share an angle, matching the pairing in _rotate_half.
    angles = angles.flatten(1, 2).tile(2)

    cos = torch.cos(angles).to(dtype).unsqueeze(0).unsqueeze(0)
    sin = torch.sin(angles).to(dtype).unsqueeze(0).unsqueeze(0)
    return cos, sin
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
h = x.shape[-1] // 2
return torch.cat((-x[..., h:], x[..., :h]), dim=-1)
def _apply_rope(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Rotate the patch-token positions of ``q`` and ``k`` with precomputed RoPE tables.

    The first ``1 + N_REGISTERS`` positions along dim -2 (the CLS and register
    tokens) pass through untouched. The remaining positions are transformed as
    ``t * cos + rotate_half(t) * sin``, so ``cos``/``sin`` must broadcast
    against that patch slice.

    Returns:
        ``(q_rotated, k_rotated)``, with the same shapes as the inputs.
    """
    prefix_len = 1 + N_REGISTERS

    def rotate(t: torch.Tensor) -> torch.Tensor:
        head, patches = t[..., :prefix_len, :], t[..., prefix_len:, :]
        patches = patches * cos + _rotate_half(patches) * sin
        return torch.cat([head, patches], dim=-2)

    return rotate(q), rotate(k)
class _Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings on the patch tokens.

    The key projection has no bias; the query, value and output projections do.
    """

    def __init__(self):
        super().__init__()
        # Creation order is kept fixed so random init and state_dict order are stable.
        self.q_proj = nn.Linear(D_MODEL, D_MODEL, bias=True)
        self.k_proj = nn.Linear(D_MODEL, D_MODEL, bias=False)
        self.v_proj = nn.Linear(D_MODEL, D_MODEL, bias=True)
        self.o_proj = nn.Linear(D_MODEL, D_MODEL, bias=True)

    def forward(self, x, cos, sin):
        batch, tokens, _ = x.shape

        def to_heads(t):
            # (B, T, D_MODEL) -> (B, N_HEADS, T, HEAD_DIM)
            return t.view(batch, tokens, N_HEADS, HEAD_DIM).transpose(1, 2)

        queries = to_heads(self.q_proj(x))
        keys = to_heads(self.k_proj(x))
        values = to_heads(self.v_proj(x))
        queries, keys = _apply_rope(queries, keys, cos, sin)

        attended = F.scaled_dot_product_attention(queries, keys, values, scale=HEAD_DIM ** -0.5)
        # (B, N_HEADS, T, HEAD_DIM) -> (B, T, D_MODEL)
        merged = attended.transpose(1, 2).reshape(batch, tokens, D_MODEL)
        return self.o_proj(merged)
class _GatedMLP(nn.Module):
    """SiLU-gated feed-forward block: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self):
        super().__init__()
        self.gate_proj = nn.Linear(D_MODEL, D_FFN, bias=True)
        self.up_proj = nn.Linear(D_MODEL, D_FFN, bias=True)
        self.down_proj = nn.Linear(D_FFN, D_MODEL, bias=True)

    def forward(self, x):
        gated = F.silu(self.gate_proj(x))
        hidden = gated * self.up_proj(x)
        return self.down_proj(hidden)
class _Block(nn.Module):
    """Pre-norm transformer block with a per-channel LayerScale on each residual branch."""

    def __init__(self):
        super().__init__()
        # Registration order is kept fixed so parameters() and state_dict order are unchanged.
        self.norm1 = nn.LayerNorm(D_MODEL, eps=LN_EPS)
        self.attention = _Attention()
        self.layer_scale1 = nn.Parameter(torch.full((D_MODEL,), LAYERSCALE))
        self.norm2 = nn.LayerNorm(D_MODEL, eps=LN_EPS)
        self.mlp = _GatedMLP()
        self.layer_scale2 = nn.Parameter(torch.full((D_MODEL,), LAYERSCALE))

    def forward(self, x, cos, sin):
        attn_branch = self.attention(self.norm1(x), cos, sin)
        x = x + attn_branch * self.layer_scale1
        mlp_branch = self.mlp(self.norm2(x))
        return x + mlp_branch * self.layer_scale2
class _Embeddings(nn.Module):
    """Patchify an image and prepend the CLS and register tokens.

    The output sequence is ``[CLS, registers..., patches...]`` along dim 1.
    """

    def __init__(self):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, D_MODEL))
        # Not used in forward(); presumably kept so checkpoints that contain it
        # load by key name — TODO confirm.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, D_MODEL))
        self.register_tokens = nn.Parameter(torch.zeros(1, N_REGISTERS, D_MODEL))
        self.patch_embeddings = nn.Conv2d(3, D_MODEL, kernel_size=PATCH_SIZE, stride=PATCH_SIZE)

    def forward(self, pixel_values):
        n = pixel_values.shape[0]
        # Cast the input to the conv weight dtype so half-precision weights accept float images.
        weight_dtype = self.patch_embeddings.weight.dtype
        grid = self.patch_embeddings(pixel_values.to(weight_dtype))
        # (B, D_MODEL, h, w) -> (B, h*w, D_MODEL)
        patch_tokens = grid.flatten(2).transpose(1, 2)
        prefix = [self.cls_token.expand(n, -1, -1), self.register_tokens.expand(n, -1, -1)]
        return torch.cat(prefix + [patch_tokens], dim=1)
class DINOv3ViTH(nn.Module):
    """ViT backbone: patch/CLS/register embedding, ``N_LAYERS`` RoPE blocks, final LayerNorm.

    ``forward(pixel_values)`` takes a 4-D ``(B, 3, H, W)`` tensor and returns
    normalized token features of shape
    ``(B, 1 + N_REGISTERS + (H // PATCH_SIZE) * (W // PATCH_SIZE), D_MODEL)``.
    """

    def __init__(self):
        super().__init__()
        self.embeddings = _Embeddings()
        self.layer = nn.ModuleList(_Block() for _ in range(N_LAYERS))
        self.norm = nn.LayerNorm(D_MODEL, eps=LN_EPS)

    def forward(self, pixel_values):
        # Unpacking also rejects inputs that are not 4-D.
        _, _, height, width = pixel_values.shape
        tokens = self.embeddings(pixel_values)
        # The strided conv floors partial patches, so the RoPE grid floors too.
        grid_h = height // PATCH_SIZE
        grid_w = width // PATCH_SIZE
        cos, sin = _build_rope(grid_h, grid_w, tokens.dtype, pixel_values.device)
        for blk in self.layer:
            tokens = blk(tokens, cos, sin)
        return self.norm(tokens)
def _split_and_clean_state_dict(sd: dict) -> tuple[dict, dict]:
backbone_sd: dict = {}
head_sd: dict = {}
for key, value in sd.items():
if key.startswith("backbone."):
new_key = key[len("backbone."):]
if new_key.startswith("model.layer."):
new_key = new_key[len("model."):]
backbone_sd[new_key] = value
else:
head_sd[key] = value
for key in list(backbone_sd.keys()):
if ".layer_scale" in key and key.endswith(".lambda1"):
backbone_sd[key[:-len(".lambda1")]] = backbone_sd.pop(key)
for key in list(backbone_sd.keys()):
if "rope_embeddings" in key:
backbone_sd.pop(key)
return backbone_sd, head_sd
class ResLinear(nn.Module):
    """Residual SiLU layer computing ``silu(linear(x)) + x``.

    Constructor arguments are forwarded verbatim to ``nn.Linear``. The linear
    output must be broadcast-compatible with the input for the residual add;
    in practice this means ``in_features == out_features``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.linear = nn.Linear(*args, **kwargs)
        self.act_fn = nn.SiLU()

    def forward(self, inputs):
        residual = inputs
        activated = self.act_fn(self.linear(inputs))
        return activated + residual
class ScoringHead(nn.Module):
    """MLP regressor that maps a feature vector to a single scalar score.

    Layout: ``Linear(in_size -> 256)`` + SiLU, nine 256-wide ``ResLinear``
    layers, then ``Linear(256 -> 1)``. Attribute names (including ``layer4``)
    determine the state_dict keys and must not change.
    """

    def __init__(self, in_size):
        super().__init__()
        hidden = 256
        self.layer1 = nn.Linear(in_size, hidden)
        self.act_fn = nn.SiLU()
        self.res_layers = nn.Sequential(*[ResLinear(hidden, hidden) for _ in range(9)])
        self.layer4 = nn.Linear(hidden, 1)

    def forward(self, x):
        hidden = self.act_fn(self.layer1(x))
        hidden = self.res_layers(hidden)
        return self.layer4(hidden)
class TaggerAestheticModel(nn.Module):
    """Wrapper that holds a ``ScoringHead`` under the ``scoring_head`` attribute.

    Because of that attribute name, this module's state_dict keys carry the
    ``"scoring_head."`` prefix.
    """

    def __init__(self, feature_dim: int):
        super().__init__()
        self.scoring_head = ScoringHead(feature_dim)

    def forward(self, features):
        head = self.scoring_head
        return head(features)
def extract_scoring_head_state_dict(state_dict: dict[str, torch.Tensor], feature_dim: int) -> dict[str, torch.Tensor]:
    """Locate the ``ScoringHead`` weights inside ``state_dict``.

    Two layouts are accepted, tried in order:

    1. ``state_dict`` has exactly the key set of ``ScoringHead(in_size=feature_dim)``.
       It is returned as-is (the same object).
    2. The head keys carry a ``"scoring_head."`` prefix, as in a
       ``TaggerAestheticModel`` state dict. Those entries are returned in a new
       dict with the prefix stripped; keys without the prefix are ignored.

    Only key names are compared; tensor shapes are not checked here.

    Args:
        state_dict: checkpoint mapping parameter names to tensors.
        feature_dim: input width of the head, used to build the reference key set.

    Returns:
        A state dict whose keys match ``ScoringHead(feature_dim)``.

    Raises:
        AssertionError: if neither layout matches. The message lists the
            missing and unexpected keys of the prefixed layout to aid debugging.
    """

    def _preview(keys: list[str], limit: int = 5) -> str:
        # Show at most `limit` key names so huge checkpoints don't flood the message.
        shown = ", ".join(keys[:limit])
        more = f", ... (+{len(keys) - limit} more)" if len(keys) > limit else ""
        return f"[{shown}{more}]"

    # Building the head is the simplest way to get its exact parameter names.
    expected_keys = set(ScoringHead(in_size=feature_dim).state_dict())

    if set(state_dict) == expected_keys:
        return state_dict

    model_prefix = "scoring_head."
    extracted = {
        key[len(model_prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(model_prefix)
    }
    extracted_keys = set(extracted)
    if extracted_keys == expected_keys:
        return extracted

    missing = sorted(expected_keys - extracted_keys)
    unexpected = sorted(extracted_keys - expected_keys)
    raise AssertionError(
        "Could not locate scoring_head weights in provided state dict "
        f"({len(state_dict)} keys given; after stripping {model_prefix!r}: "
        f"missing {_preview(missing)}, unexpected {_preview(unexpected)})"
    )