"""Stage 4B bigger specialist student.

Depth 8, embed 384, 6 heads, MLP ratio 4. Emits a 768-D vector per image
(matches the full EUPE-ViT-B pooled layernormed output) for
cosine-similarity distillation.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbed(nn.Module):
    """Project an image into a sequence of patch tokens.

    A single strided convolution (kernel size == stride == ``patch_size``)
    maps every non-overlapping patch to one ``embed_dim`` vector.
    """

    def __init__(self, in_ch: int = 3, embed_dim: int = 384,
                 patch_size: int = 16) -> None:
        super().__init__()
        self.proj = nn.Conv2d(in_ch, embed_dim, patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return patch tokens for ``x``.

        For a batched ``(B, C, H, W)`` input the convolution yields
        ``(B, embed_dim, H', W')``; the spatial axes are flattened row-major
        and moved in front of the channel axis, giving
        ``(B, H' * W', embed_dim)``.
        """
        x = self.proj(x)
        return x.flatten(2).transpose(1, 2)


class Block(nn.Module):
    """Pre-norm transformer block: self-attention then MLP, both residual."""

    def __init__(self, dim: int, heads: int, mlp_ratio: float = 4.0) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` and return a tensor of the same shape."""
        h = self.norm1(x)
        # Attention weights are never used, so skip computing them.
        h, _ = self.attn(h, h, h, need_weights=False)
        x = x + h
        x = x + self.mlp(self.norm2(x))
        return x


class Stage4BStudent(nn.Module):
    """Outputs a 768-D vector per image to match the EUPE-ViT-B teacher.

    Pipeline: patch embedding -> learned positional embedding -> ``depth``
    transformer blocks -> LayerNorm -> max over the token axis -> linear
    head of width ``out_dim``.

    Args:
        out_dim: Width of the emitted vector.
        embed_dim: Token width used inside the transformer.
        depth: Number of transformer blocks.
        heads: Attention heads per block.
        patch_size: Side length of the square patches.
        img_size: Side length of the square image the positional table is
            sized for; the table has ``(img_size // patch_size) ** 2`` rows.
        mlp_ratio: Hidden width of each block's MLP relative to
            ``embed_dim``.
    """

    def __init__(self, out_dim: int = 768, embed_dim: int = 384,
                 depth: int = 8, heads: int = 6, patch_size: int = 16,
                 img_size: int = 768, mlp_ratio: float = 4.0) -> None:
        super().__init__()
        self.patch = PatchEmbed(3, embed_dim, patch_size)
        self.num_patches = (img_size // patch_size) ** 2
        self.pos = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
        nn.init.trunc_normal_(self.pos, std=0.02)
        self.blocks = nn.ModuleList(
            [Block(embed_dim, heads, mlp_ratio) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images into ``(B, out_dim)`` vectors.

        Inputs that produce fewer tokens than the positional table holds use
        the leading rows of the table.

        Raises:
            RuntimeError: If the input produces more patch tokens than the
                positional table has rows.
        """
        tokens = self.patch(x)
        n_tokens = tokens.shape[1]
        max_tokens = self.pos.shape[1]
        # The slice below can only shrink the table. Without this check an
        # oversized input fails inside the addition with an opaque broadcast
        # error (or broadcasts silently when the table has a single row).
        if n_tokens > max_tokens:
            raise RuntimeError(
                f"input yields {n_tokens} patch tokens but the positional "
                f"embedding table holds only {max_tokens}; use a smaller "
                f"image or construct the model with a larger img_size"
            )
        tokens = tokens + self.pos[:, :n_tokens]
        for blk in self.blocks:
            tokens = blk(tokens)
        tokens = self.norm(tokens)
        # Max over the token axis (not mean / CLS).
        # NOTE(review): confirm this matches the teacher's pooling.
        pooled = tokens.max(dim=1).values
        return self.head(pooled)


if __name__ == '__main__':
    m = Stage4BStudent()
    total = sum(p.numel() for p in m.parameters())
    print(f'Stage 4B student: {total:,} params = {total/1e6:.2f}M')
    x = torch.randn(2, 3, 768, 768)
    # Shape check only: skip autograd so activations of all blocks are not
    # retained for a backward pass that never happens.
    with torch.no_grad():
        y = m(x)
    print(f'forward OK output shape: {tuple(y.shape)}')