"""Stage 4 specialist student architecture.

Compact ViT designed to reproduce the target dims of EUPE-ViT-B that feed the
Stage 0 classifier. Defaults: depth 6, embed 192, patch 16, 3 heads.

Emits an ``out_dim``-D vector per image via a final projection from the
max-pooled patch tokens. Designed to pair with a frozen ternary classifier
head.

NOTE(review): the design description calls for 100 target dims and a
"global pool of CLS", but the code in this file defaults to ``out_dim=40``
and defines no CLS token -- confirm which is intended before relying on
either number.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    Implemented as a single strided convolution whose kernel size and stride
    both equal ``patch_size``.
    """

    def __init__(self, in_ch=3, embed_dim=192, patch_size=16):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, embed_dim, patch_size, stride=patch_size)

    def forward(self, x):
        """Embed ``x`` of shape (B, in_ch, H_img, W_img).

        Returns:
            A tuple ``(tokens, H, W)`` where ``tokens`` has shape
            (B, H*W, embed_dim) and ``H``/``W`` are the patch-grid height and
            width produced by the convolution.
        """
        x = self.proj(x)
        H, W = x.shape[-2:]
        return x.flatten(2).transpose(1, 2), H, W  # (B, HW, C)


class Block(nn.Module):
    """Transformer encoder block: self-attention then MLP, each with a
    LayerNorm applied before the sub-layer and a residual connection after.
    """

    def __init__(self, dim, heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape (B, N, dim); shape is preserved."""
        h = self.norm1(x)
        # Attention weights are never used, so skip computing/returning them.
        h, _ = self.attn(h, h, h, need_weights=False)
        x = x + h
        x = x + self.mlp(self.norm2(x))
        return x


class SpecialistStudent(nn.Module):
    """Compact ViT that outputs an ``out_dim``-D vector per image.

    Args:
        out_dim: Size of the output vector (default 40).
        embed_dim: Token embedding width.
        depth: Number of transformer blocks.
        heads: Attention heads per block; must divide ``embed_dim``.
        patch_size: Side length of each square patch, in pixels.
        img_size: Largest (square) image side the positional table is sized
            for; the table holds ``(img_size // patch_size) ** 2`` entries.
        mlp_ratio: Hidden width of each block's MLP relative to ``embed_dim``.
    """

    def __init__(self, out_dim=40, embed_dim=192, depth=6, heads=3,
                 patch_size=16, img_size=768, mlp_ratio=4.0):
        super().__init__()
        self.patch = PatchEmbed(3, embed_dim, patch_size)
        self.num_patches = (img_size // patch_size) ** 2
        # Learned positional table, one row per patch position (no CLS slot).
        self.pos = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
        nn.init.trunc_normal_(self.pos, std=0.02)
        self.blocks = nn.ModuleList(
            [Block(embed_dim, heads, mlp_ratio) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, out_dim)

    def forward(self, x):
        """x: (B, 3, H, W). Returns (B, out_dim).

        Raises:
            ValueError: If the image produces more patch tokens than the
                positional table holds (i.e. it is larger than the
                ``img_size`` the model was built with).
        """
        tokens, _, _ = self.patch(x)
        n_tokens = tokens.shape[1]
        if n_tokens > self.num_patches:
            # Without this check the addition below fails with an opaque
            # broadcasting error.
            raise ValueError(
                f'input yields {n_tokens} patch tokens but the positional '
                f'embedding only covers {self.num_patches}; build the model '
                f'with a larger img_size'
            )
        # Smaller inputs reuse the first n_tokens rows of the flattened table
        # (plain truncation, no 2-D interpolation).
        tokens = tokens + self.pos[:, :n_tokens]
        for blk in self.blocks:
            tokens = blk(tokens)
        tokens = self.norm(tokens)  # (B, HW, embed_dim)
        pooled = tokens.max(dim=1).values  # (B, embed_dim) max-pool per image
        return self.head(pooled)  # (B, out_dim)


if __name__ == '__main__':
    m = SpecialistStudent()
    total = sum(p.numel() for p in m.parameters())
    print(f'Specialist student total params: {total:,} = {total/1e6:.2f}M')
    x = torch.randn(2, 3, 768, 768)
    # Smoke test only: no gradients needed, so avoid building the graph.
    with torch.no_grad():
        y = m(x)
    print(f'forward OK output shape: {tuple(y.shape)}')