# phanerozoic
# Stage 4: specialist student (3.27M params, F1 0.710 vs 0.894 baseline)
# 864ba61 verified
"""Stage 4 specialist student architecture.

Compact ViT designed to reproduce the target dims of EUPE-ViT-B that feed the
Stage 0 classifier. Depth 6, embed 192, patch 16, 3 heads. Emits an
``out_dim``-D vector per image via a final linear projection of the max-pooled,
layer-normed patch tokens. There is no CLS token. The code default for
``out_dim`` is 40; earlier notes referred to 100 target dims, so confirm which
the trained checkpoint uses. Designed to pair with a frozen ternary classifier
head.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A single strided convolution (kernel == stride == ``patch_size``) maps the
    ``in_ch`` input channels of every patch to an ``embed_dim`` vector.
    """

    def __init__(self, in_ch=3, embed_dim=192, patch_size=16):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, embed_dim, kernel_size=patch_size,
                              stride=patch_size)

    def forward(self, x):
        """Embed ``x`` of shape (B, in_ch, H_img, W_img).

        Returns ``(tokens, H, W)``: ``tokens`` is (B, H*W, embed_dim) in
        row-major patch order, and ``H``/``W`` are the patch-grid height and
        width produced by the convolution.
        """
        feat = self.proj(x)
        grid_h, grid_w = feat.shape[-2:]
        tokens = feat.flatten(start_dim=2).transpose(1, 2)
        return tokens, grid_h, grid_w
class Block(nn.Module):
    """Pre-norm transformer encoder block.

    Computes ``x + MHSA(LN(x))`` followed by ``x + MLP(LN(x))``, where the MLP
    is Linear -> GELU -> Linear with hidden width ``int(dim * mlp_ratio)``.
    """

    def __init__(self, dim, heads, mlp_ratio=4.0):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        # Registration order (norm1, attn, norm2, mlp) fixes state_dict keys
        # and parameter-init order; keep it stable.
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape (B, N, dim); returns the same shape."""
        normed = self.norm1(x)
        attn_out = self.attn(normed, normed, normed, need_weights=False)[0]
        x = x + attn_out
        return x + self.mlp(self.norm2(x))
class SpecialistStudent(nn.Module):
    """Compact ViT that outputs an ``out_dim``-D vector per image.

    Pipeline: patch embedding -> learned absolute positional embedding ->
    ``depth`` pre-norm transformer blocks -> final LayerNorm -> max-pool over
    the patch tokens -> linear head. There is no CLS token.

    Args:
        out_dim: Size of the output vector (default 40).
        embed_dim: Token width.
        depth: Number of transformer blocks.
        heads: Attention heads per block.
        patch_size: Side length of the square patches.
        img_size: Side length of the square training resolution; fixes the
            size of the learned positional-embedding grid.
        mlp_ratio: Hidden-width multiplier for the block MLPs.
    """

    def __init__(self, out_dim=40, embed_dim=192, depth=6, heads=3,
                 patch_size=16, img_size=768, mlp_ratio=4.0):
        super().__init__()
        self.patch = PatchEmbed(3, embed_dim, patch_size)
        grid = img_size // patch_size
        # The positional table is laid out as a grid x grid patch layout,
        # flattened row-major (the same order PatchEmbed emits tokens in).
        self.grid_size = (grid, grid)
        self.num_patches = grid ** 2
        self.pos = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
        nn.init.trunc_normal_(self.pos, std=0.02)
        self.blocks = nn.ModuleList(
            [Block(embed_dim, heads, mlp_ratio) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, out_dim)

    def _pos_embed_for(self, H, W):
        """Return positional embeddings (1, H*W, embed_dim) for an H x W patch grid.

        At the training grid the learned table is returned unchanged. For any
        other grid, the table is reshaped to its 2-D layout and resampled
        bicubically. Every token then receives an embedding from its true
        spatial location, instead of a prefix of the flattened table, which
        would be spatially wrong for smaller inputs and fail for larger ones.
        """
        gh, gw = self.grid_size
        if (H, W) == (gh, gw):
            return self.pos
        pos = self.pos.reshape(1, gh, gw, -1).permute(0, 3, 1, 2)
        # Resample in float32: bicubic interpolation may not be implemented
        # for low-precision dtypes on every device.
        resized = F.interpolate(pos.float(), size=(H, W), mode='bicubic',
                                align_corners=False).to(pos.dtype)
        return resized.permute(0, 2, 3, 1).reshape(1, H * W, -1)

    def forward(self, x):
        """x: (B, 3, H, W). Returns (B, out_dim)."""
        tokens, H, W = self.patch(x)              # (B, H*W, embed_dim)
        tokens = tokens + self._pos_embed_for(H, W)
        for blk in self.blocks:
            tokens = blk(tokens)
        tokens = self.norm(tokens)                # (B, H*W, embed_dim)
        pooled = tokens.max(dim=1).values         # (B, embed_dim) max over patches
        return self.head(pooled)                  # (B, out_dim)
if __name__ == '__main__':
    m = SpecialistStudent()
    total = sum(p.numel() for p in m.parameters())
    print(f'Specialist student total params: {total:,} = {total/1e6:.2f}M')
    x = torch.randn(2, 3, 768, 768)
    # Smoke test only: disable autograd so the activations of the
    # 2 x 2304-token forward are not retained for a backward pass that
    # never happens.
    with torch.no_grad():
        y = m(x)
    print(f'forward OK output shape: {tuple(y.shape)}')