dboris commited on
Commit
9f004e6
·
verified ·
1 Parent(s): 3f5fa19

Upload src/heads/mlp_head.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/heads/mlp_head.py +22 -0
src/heads/mlp_head.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MLP classification head — shared across all backbones.
3
+
4
+ LayerNorm → Linear → GELU → Dropout → Linear → num_classes
5
+ """
6
+
7
+ import torch.nn as nn
8
+
9
+
10
+ class MLPHead(nn.Module):
11
+ def __init__(self, embed_dim: int, num_classes: int, hidden_dim: int = 512, dropout: float = 0.3):
12
+ super().__init__()
13
+ self.head = nn.Sequential(
14
+ nn.LayerNorm(embed_dim),
15
+ nn.Linear(embed_dim, hidden_dim),
16
+ nn.GELU(),
17
+ nn.Dropout(dropout),
18
+ nn.Linear(hidden_dim, num_classes),
19
+ )
20
+
21
+ def forward(self, x):
22
+ return self.head(x)