"""Stage 4 specialist student architecture.

Compact ViT designed to reproduce the target dims of EUPE-ViT-B that feed the
Stage 0 classifier (specified as 100 dims; note that ``SpecialistStudent``
defaults to ``out_dim=40`` -- confirm which is current). Depth 6, embed 192,
patch 16, 3 heads. Emits one vector per image via a final linear projection
of the max-pooled patch tokens (the model has no CLS token). Designed to pair
with a frozen ternary classifier head.
"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
|
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    Patch extraction and the linear projection are fused into a single
    ``Conv2d`` whose kernel size and stride both equal ``patch_size``.
    """

    def __init__(self, in_ch=3, embed_dim=192, patch_size=16):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, embed_dim, patch_size, stride=patch_size)

    def forward(self, x):
        """Turn a batch of images into a token sequence.

        Returns ``(tokens, H, W)``: ``tokens`` has shape ``(B, H * W, embed_dim)``
        in row-major (H-then-W) order, and ``H``/``W`` are the patch-grid
        height and width. The projected map must be 4-D ``(B, C, H, W)``;
        anything else fails at the unpack below.
        """
        feat = self.proj(x)
        _, _, grid_h, grid_w = feat.shape
        tokens = feat.flatten(start_dim=2).transpose(1, 2)
        return tokens, grid_h, grid_w
|
|
|
|
class Block(nn.Module):
    """Pre-norm transformer encoder block.

    Computes ``x + MHSA(LN(x))`` followed by ``x + MLP(LN(x))``, where the MLP
    is Linear -> GELU -> Linear with hidden width ``int(dim * mlp_ratio)``.
    Expects batch-first token tensors (the attention uses ``batch_first=True``).
    """

    def __init__(self, dim, heads, mlp_ratio=4.0):
        super().__init__()
        hidden_width = int(dim * mlp_ratio)
        # Registration order (norm1, attn, norm2, mlp) fixes state_dict keys
        # and the order in which parameters draw from the RNG.
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_width),
            nn.GELU(),
            nn.Linear(hidden_width, dim),
        )

    def forward(self, x):
        """Apply the attention and MLP residual sub-blocks to ``x``."""
        normed = self.norm1(x)
        # need_weights=False: only the attended values are used.
        attended = self.attn(normed, normed, normed, need_weights=False)[0]
        x = x + attended
        return x + self.mlp(self.norm2(x))
|
|
|
|
class SpecialistStudent(nn.Module):
    """Compact ViT that maps an image to an ``out_dim``-D vector.

    Pipeline: patch embedding -> learned absolute positional embedding ->
    ``depth`` pre-norm transformer blocks -> LayerNorm -> max-pool over the
    patch tokens -> linear head. There is no CLS token.

    NOTE(review): the 100-D Stage 0 target mentioned for this model does not
    match the default ``out_dim=40``; confirm which the downstream head expects.

    Args:
        out_dim: Size of the output vector.
        embed_dim: Token width.
        depth: Number of transformer blocks.
        heads: Attention heads per block (``nn.MultiheadAttention`` requires
            it to divide ``embed_dim``).
        patch_size: Side length of the square patches.
        img_size: Native square input resolution; the learned positional
            table holds ``(img_size // patch_size) ** 2`` entries.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
    """

    def __init__(self, out_dim=40, embed_dim=192, depth=6, heads=3,
                 patch_size=16, img_size=768, mlp_ratio=4.0):
        super().__init__()
        self.patch = PatchEmbed(3, embed_dim, patch_size)
        self.num_patches = (img_size // patch_size) ** 2
        self.pos = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
        nn.init.trunc_normal_(self.pos, std=0.02)
        self.blocks = nn.ModuleList(
            [Block(embed_dim, heads, mlp_ratio) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, out_dim)

    def _pos_embed_for(self, grid_h, grid_w):
        """Return positional embeddings of shape ``(1, grid_h * grid_w, embed_dim)``.

        At the native grid this is the learned table itself (unchanged). For
        any other grid, the square table is resampled bicubically in 2-D. Each
        token then receives the embedding of its spatial location rather than
        of its flat index. Row-major layout matches ``PatchEmbed``'s flatten
        order.
        """
        pos = self.pos
        side = math.isqrt(pos.shape[1])
        if grid_h == side and grid_w == side:
            return pos
        dim = pos.shape[-1]
        table = pos.reshape(1, side, side, dim).permute(0, 3, 1, 2)
        table = F.interpolate(table, size=(grid_h, grid_w),
                              mode='bicubic', align_corners=False)
        return table.permute(0, 2, 3, 1).reshape(1, grid_h * grid_w, dim)

    def forward(self, x):
        """Map images to output vectors.

        Args:
            x: ``(B, 3, H, W)`` images. Inputs whose patch grid differs from
                the native ``img_size`` grid (smaller, larger or non-square)
                get a 2-D-resampled positional table.

        Returns:
            ``(B, out_dim)`` tensor.
        """
        tokens, grid_h, grid_w = self.patch(x)
        # Slicing the flat table by token count would misplace embeddings for
        # any non-native grid, and would be too short for larger inputs.
        # Resample by 2-D position instead.
        tokens = tokens + self._pos_embed_for(grid_h, grid_w)
        for blk in self.blocks:
            tokens = blk(tokens)
        tokens = self.norm(tokens)
        # Max over the token axis: one embed_dim vector per image.
        pooled = tokens.max(dim=1).values
        return self.head(pooled)
|
|
|
|
if __name__ == '__main__':
    # Smoke test: report the parameter count, then run one forward pass at
    # the default 768x768 resolution.
    m = SpecialistStudent()
    total = sum(p.numel() for p in m.parameters())
    print(f'Specialist student total params: {total:,} = {total/1e6:.2f}M')
    x = torch.randn(2, 3, 768, 768)
    # Only the output shape is checked. eval() + no_grad() avoid building an
    # autograd graph: at 2304 tokens per image, the attention activations it
    # would retain are large.
    m.eval()
    with torch.no_grad():
        y = m(x)
    print(f'forward OK output shape: {tuple(y.shape)}')
|
|