File size: 2,800 Bytes
864ba61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""Stage 4 specialist student architecture.

Compact ViT designed to reproduce the 100 target dims of EUPE-ViT-B that
feed the Stage 0 classifier. Depth 6, embed 192, patch 16, 3 heads. Emits an
``out_dim``-D vector per image via a final linear projection of the
max-pooled, layer-normalized patch tokens (there is no CLS token).
NOTE(review): ``SpecialistStudent`` defaults to ``out_dim=40``, not 100 —
confirm which width is intended. Designed to pair with a frozen ternary
classifier head.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A convolution whose kernel and stride both equal ``patch_size`` turns a
    (B, in_ch, H, W) image into a (B, embed_dim, gh, gw) feature map, one
    column per patch, which is then unrolled into a token sequence.
    """

    def __init__(self, in_ch=3, embed_dim=192, patch_size=16):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Return ``(tokens, gh, gw)``; tokens are (B, gh * gw, embed_dim), row-major."""
        fmap = self.proj(x)
        batch, chans, grid_h, grid_w = fmap.shape
        seq = fmap.reshape(batch, chans, grid_h * grid_w).permute(0, 2, 1)
        return seq, grid_h, grid_w


class Block(nn.Module):
    """Pre-norm transformer layer: x += MHSA(LN(x)); x += MLP(LN(x)).

    Args:
        dim: token width.
        heads: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
    """

    def __init__(self, dim, heads, mlp_ratio=4.0):
        super().__init__()
        mlp_width = int(dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        # Linear / GELU / Linear, in this order, so state_dict keys stay mlp.0 / mlp.2.
        layers = [nn.Linear(dim, mlp_width), nn.GELU(), nn.Linear(mlp_width, dim)]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the attention and MLP residual updates to batch-first tokens ``x``."""
        normed = self.norm1(x)
        attn_out = self.attn(normed, normed, normed, need_weights=False)[0]
        x = x + attn_out
        return x + self.mlp(self.norm2(x))


class SpecialistStudent(nn.Module):
    """Compact ViT that maps an image to an ``out_dim``-D vector.

    Pipeline: patch embedding -> learned absolute position embedding ->
    ``depth`` pre-norm transformer blocks -> LayerNorm -> max-pool over all
    patch tokens -> linear head. There is no CLS token.

    NOTE(review): the module docstring describes a 100-D target, but
    ``out_dim`` defaults to 40 here — confirm which default is intended.

    Args:
        out_dim: width of the output vector.
        embed_dim: token width.
        depth: number of transformer blocks.
        heads: attention heads per block.
        patch_size: side length of the square patches.
        img_size: side length of the square image the position table is laid
            out for (an ``img_size // patch_size`` square patch grid). Inputs
            producing a different patch grid get the table resampled in 2-D.
        mlp_ratio: MLP hidden-width multiplier in each block.
    """

    def __init__(self, out_dim=40, embed_dim=192, depth=6, heads=3,
                 patch_size=16, img_size=768, mlp_ratio=4.0):
        super().__init__()
        self.patch = PatchEmbed(3, embed_dim, patch_size)
        self.num_patches = (img_size // patch_size) ** 2
        self.pos = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
        nn.init.trunc_normal_(self.pos, std=0.02)
        self.blocks = nn.ModuleList([Block(embed_dim, heads, mlp_ratio) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, out_dim)

    @staticmethod
    def _resize_pos_embed(pos, H, W):
        """Lay the position table ``pos`` out for an ``H x W`` patch grid.

        ``pos`` is (1, G*G, C), row-major over a square G x G grid. If the
        requested grid is G x G it is returned unchanged. Otherwise it is
        bicubically resampled in 2-D, so every token still gets the embedding
        for its own spatial location. Slicing the flattened table by sequence
        length would give a narrower grid the wrong rows' embeddings and fail
        outright for a larger grid.

        Returns a (1, H * W, C) tensor.
        """
        grid_size = math.isqrt(pos.shape[1])
        if H == grid_size and W == grid_size:
            return pos
        dim = pos.shape[-1]
        grid = pos.reshape(1, grid_size, grid_size, dim).permute(0, 3, 1, 2)
        grid = F.interpolate(grid, size=(H, W), mode='bicubic', align_corners=False)
        return grid.permute(0, 2, 3, 1).reshape(1, H * W, dim)

    def forward(self, x):
        """x: (B, 3, H, W). Returns (B, out_dim)."""
        tokens, H, W = self.patch(x)
        tokens = tokens + self._resize_pos_embed(self.pos, H, W)
        for blk in self.blocks:
            tokens = blk(tokens)
        tokens = self.norm(tokens)               # (B, HW, embed_dim)
        pooled = tokens.max(dim=1).values        # (B, embed_dim) max-pool per image
        return self.head(pooled)                 # (B, out_dim)


if __name__ == '__main__':
    # Smoke test: build the default model, report its size, and run one
    # forward pass on random 768x768 inputs.
    m = SpecialistStudent()
    total = sum(p.numel() for p in m.parameters())
    print(f'Specialist student total params: {total:,} = {total/1e6:.2f}M')
    x = torch.randn(2, 3, 768, 768)
    # No gradients are needed here. Skipping autograd avoids keeping every
    # block's activations for 2 x 2304 tokens alive in memory.
    with torch.no_grad():
        y = m(x)
    print(f'forward OK  output shape: {tuple(y.shape)}')