File size: 2,407 Bytes
c75b31a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""Stage 4B bigger specialist student.

Depth 8, embed 384, 6 heads, MLP ratio 4. Emits a 768-D vector per image
(matches the full EUPE-ViT-B pooled layernormed output) for cosine-similarity
distillation.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbed(nn.Module):
    """Turn an image into a sequence of patch tokens.

    A single convolution whose kernel size equals its stride projects each
    non-overlapping ``patch_size`` x ``patch_size`` patch to an
    ``embed_dim``-dimensional token.
    """

    def __init__(self, in_ch=3, embed_dim=384, patch_size=16):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map images (B, in_ch, H, W) to tokens (B, (H // p) * (W // p), embed_dim).

        Tokens are ordered row-major over the patch grid.
        """
        fmap = self.proj(x)  # (B, embed_dim, H // p, W // p)
        flat = torch.flatten(fmap, start_dim=2)  # merge the two spatial grid axes
        return flat.permute(0, 2, 1)  # channels-last: one row per patch token


class Block(nn.Module):
    """Pre-norm transformer encoder block.

    ``x -> x + SelfAttn(LN(x)) -> x + MLP(LN(x))``, where the MLP is
    Linear -> GELU -> Linear with hidden width ``int(dim * mlp_ratio)``.
    Inputs are batch-first (``nn.MultiheadAttention(batch_first=True)``).
    """

    def __init__(self, dim, heads, mlp_ratio=4.0):
        super().__init__()
        # Submodules are created in a fixed order so that seeded initialisation
        # and state_dict keys stay stable.
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden),
            nn.GELU(),
            nn.Linear(mlp_hidden, dim),
        )

    def _attend(self, x):
        """Self-attention over the layer-normed input; attention weights are not computed."""
        q = self.norm1(x)
        out, _ = self.attn(q, q, q, need_weights=False)
        return out

    def forward(self, x):
        x = x + self._attend(x)
        return x + self.mlp(self.norm2(x))


class Stage4BStudent(nn.Module):
    """Outputs a 768-D vector per image to match the EUPE-ViT-B teacher.

    Pipeline: patch embedding -> learned absolute positional embedding ->
    ``depth`` pre-norm transformer blocks -> LayerNorm -> max-pool over
    tokens -> linear head to ``out_dim``.

    Args:
        out_dim: width of the output vector.
        embed_dim: token width inside the transformer.
        depth: number of transformer blocks.
        heads: attention heads per block (``nn.MultiheadAttention`` requires
            it to divide ``embed_dim``).
        patch_size: side length of the square, non-overlapping patches.
        img_size: side length of the square resolution the positional
            embedding table is laid out for (an ``img_size // patch_size``
            square grid). Other input sizes are handled by resizing that
            grid; see ``_pos_embed``.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
    """
    def __init__(self, out_dim=768, embed_dim=384, depth=8, heads=6,
                 patch_size=16, img_size=768, mlp_ratio=4.0):
        super().__init__()
        self.patch = PatchEmbed(3, embed_dim, patch_size)
        self.patch_size = patch_size
        # Side length of the square patch grid the positional table is stored for.
        self.grid_size = img_size // patch_size
        self.num_patches = self.grid_size ** 2
        self.pos = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
        nn.init.trunc_normal_(self.pos, std=0.02)
        self.blocks = nn.ModuleList([Block(embed_dim, heads, mlp_ratio) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, out_dim)

    def _pos_embed(self, grid_h, grid_w):
        """Return positional embeddings of shape (1, grid_h * grid_w, embed_dim).

        At the native ``grid_size`` x ``grid_size`` grid the learned table is
        returned unchanged. For any other grid the table is reshaped to its
        2-D layout and bicubically resized, so each token receives the
        embedding for its actual (row, col) position. Simply slicing the
        flattened table would hand out entries from the wrong rows (or fail
        when the input is larger than ``img_size``).
        """
        g = self.grid_size
        if grid_h == g and grid_w == g:
            return self.pos
        dim = self.pos.shape[-1]
        grid = self.pos.reshape(1, g, g, dim).permute(0, 3, 1, 2)  # (1, dim, g, g)
        grid = F.interpolate(grid, size=(grid_h, grid_w), mode='bicubic', align_corners=False)
        return grid.permute(0, 2, 3, 1).reshape(1, grid_h * grid_w, dim)

    def forward(self, x):
        """Embed a batch of images into ``out_dim``-D vectors.

        Args:
            x: image batch, presumably (B, 3, H, W) with H, W >= patch_size.

        Returns:
            Tensor of shape (B, out_dim).
        """
        tokens = self.patch(x)
        # The patch conv has kernel == stride == patch_size, so the token grid is
        # (H // patch_size) x (W // patch_size), flattened row-major.
        grid_h = x.shape[-2] // self.patch_size
        grid_w = x.shape[-1] // self.patch_size
        tokens = tokens + self._pos_embed(grid_h, grid_w)
        for blk in self.blocks:
            tokens = blk(tokens)
        tokens = self.norm(tokens)
        # Max-pool over the token axis: one embed_dim vector per image.
        pooled = tokens.max(dim=1).values
        return self.head(pooled)


if __name__ == '__main__':
    # Smoke test: build the default model, report its size, and run one
    # forward pass at the default 768x768 resolution.
    m = Stage4BStudent()
    total = sum(p.numel() for p in m.parameters())
    print(f'Stage 4B student: {total:,} params = {total/1e6:.2f}M')
    x = torch.randn(2, 3, 768, 768)
    # Only the output shape is inspected, so skip building the autograd graph
    # (and storing activations for every block over 48*48 = 2304 tokens/image).
    with torch.no_grad():
        y = m(x)
    print(f'forward OK  output shape: {tuple(y.shape)}')