# phanerozoic's picture
# Stage 4B: 15.67M student + cosine loss on 768-D, F1 0.723 (+0.013 over Stage 4)
# c75b31a verified
"""Stage 4B bigger specialist student.
Depth 8, embed 384, 6 heads, MLP ratio 4. Emits one 768-D vector per image
(the same size as the full EUPE-ViT-B teacher's pooled, layer-normalized
output) so the two can be compared for cosine-similarity distillation.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A single strided convolution (kernel size == stride == ``patch_size``)
    projects every patch to ``embed_dim`` channels; the resulting spatial
    grid is then flattened row-major into a token sequence.

    Args:
        in_ch: Number of input image channels.
        embed_dim: Width of each output token.
        patch_size: Side length, in pixels, of each square patch.
    """

    def __init__(self, in_ch=3, embed_dim=384, patch_size=16):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map ``(B, in_ch, H, W)`` images to ``(B, num_patches, embed_dim)`` tokens."""
        grid = self.proj(x)                  # (B, embed_dim, H', W')
        seq = grid.flatten(start_dim=2)      # (B, embed_dim, H' * W')
        return seq.permute(0, 2, 1)          # channels last -> one row per patch
class Block(nn.Module):
    """Pre-norm transformer encoder layer.

    Applies ``x + Attn(LN(x))`` followed by ``x + MLP(LN(x))``, where the MLP
    is Linear -> GELU -> Linear with hidden width ``int(dim * mlp_ratio)``.
    Attention uses ``batch_first=True``, so inputs are ``(batch, seq, dim)``.

    Args:
        dim: Token width.
        heads: Number of attention heads (must divide ``dim``).
        mlp_ratio: Hidden-width multiplier for the MLP.
    """

    def __init__(self, dim, heads, mlp_ratio=4.0):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        # Submodules are created in a fixed order (norm1, attn, norm2, mlp) so
        # parameter initialisation and state_dict keys stay stable.
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x):
        """Run one attention + MLP residual step on ``x``."""
        normed = self.norm1(x)
        # Attention weights are not needed; skip computing/returning them.
        attended = self.attn(normed, normed, normed, need_weights=False)[0]
        x = x + attended
        return x + self.mlp(self.norm2(x))
class Stage4BStudent(nn.Module):
    """Outputs a 768-D vector per image to match the EUPE-ViT-B teacher.

    Pipeline: patch embedding -> learned absolute position embedding ->
    ``depth`` pre-norm transformer blocks -> final LayerNorm -> max-pool over
    tokens -> linear projection to ``out_dim``.

    The position table is learned for a square ``img_size`` input, i.e. a
    ``(img_size // patch_size)``-sided patch grid. Inputs whose patch grid
    differs (smaller, larger or non-square images) get the table resampled
    to their grid with bicubic interpolation. At the native size the table
    is used unchanged.

    Args:
        out_dim: Size of the output vector.
        embed_dim: Token width inside the transformer.
        depth: Number of transformer blocks.
        heads: Attention heads per block.
        patch_size: Side length, in pixels, of each square patch.
        img_size: Square input side the position table is laid out for.
        mlp_ratio: Hidden-width multiplier for each block's MLP.
    """

    def __init__(self, out_dim=768, embed_dim=384, depth=8, heads=6,
                 patch_size=16, img_size=768, mlp_ratio=4.0):
        super().__init__()
        # Plain int (not a parameter/buffer): used to recover the patch grid
        # shape in forward(), since the patch embedding flattens it away.
        self.patch_size = patch_size
        self.patch = PatchEmbed(3, embed_dim, patch_size)
        self.num_patches = (img_size // patch_size) ** 2
        self.pos = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
        nn.init.trunc_normal_(self.pos, std=0.02)
        self.blocks = nn.ModuleList([Block(embed_dim, heads, mlp_ratio) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, out_dim)

    def _pos_embed(self, grid_h, grid_w):
        """Return position embeddings of shape ``(1, grid_h * grid_w, embed_dim)``.

        ``self.pos`` stores a square ``side x side`` grid flattened row-major.
        Slicing its first ``N`` entries for a different grid would hand tokens
        the embeddings of unrelated locations (or fail when ``N`` exceeds the
        table), so the 2-D table is resampled instead.

        Args:
            grid_h: Number of patch rows in the current input.
            grid_w: Number of patch columns in the current input.
        """
        num_pos = self.pos.shape[1]
        side = int(round(num_pos ** 0.5))
        if grid_h == side and grid_w == side:
            # Native resolution: exact, interpolation-free path.
            return self.pos
        dim = self.pos.shape[-1]
        orig_dtype = self.pos.dtype
        # Interpolate in float32; bicubic kernels may not support half dtypes.
        table = self.pos.float().reshape(1, side, side, dim).permute(0, 3, 1, 2)
        table = F.interpolate(table, size=(grid_h, grid_w), mode='bicubic',
                              align_corners=False)
        table = table.permute(0, 2, 3, 1).reshape(1, grid_h * grid_w, dim)
        return table.to(orig_dtype)

    def forward(self, x):
        """Embed a batch of ``(B, 3, H, W)`` images into ``(B, out_dim)`` vectors."""
        # Conv with kernel == stride == patch_size yields floor(H / p) x floor(W / p).
        grid_h = x.shape[-2] // self.patch_size
        grid_w = x.shape[-1] // self.patch_size
        tokens = self.patch(x)
        tokens = tokens + self._pos_embed(grid_h, grid_w)
        for blk in self.blocks:
            tokens = blk(tokens)
        tokens = self.norm(tokens)
        # Max-pool over the token axis to get one vector per image.
        pooled = tokens.max(dim=1).values
        return self.head(pooled)
if __name__ == '__main__':
    # Smoke test: report the parameter count, then push a random batch of two
    # 768x768 RGB images through the model and show the output shape.
    student = Stage4BStudent()
    n_params = sum(param.numel() for param in student.parameters())
    print(f'Stage 4B student: {n_params:,} params = {n_params/1e6:.2f}M')
    batch = torch.randn(2, 3, 768, 768)
    embedding = student(batch)
    print(f'forward OK output shape: {tuple(embedding.shape)}')