Bettensor commited on
Commit
69aac37
·
verified ·
1 Parent(s): 0a662f6

Added v1 base model source code

Browse files
Files changed (1) hide show
  1. Podos_v1_model.py +25 -0
Podos_v1_model.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from huggingface_hub import PyTorchModelHubMixin
4
+
5
+ class PodosTransformer(nn.Module,PyTorchModelHubMixin):
6
+ def __init__(self, input_dim, model_dim, num_classes, num_heads=4, num_layers=2, dropout=0.1,temperature=1):
7
+ super(PodosTransformer, self).__init__()
8
+ self.temperature = temperature
9
+
10
+ self.projection = nn.Linear(input_dim, model_dim)
11
+ encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads, dropout=dropout)
12
+ self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
13
+ self.fc = nn.Linear(model_dim, num_classes)
14
+
15
+ def forward(self, x):
16
+ x = self.projection(x)
17
+ x = x.unsqueeze(1)
18
+ x = self.transformer_encoder(x)
19
+ x = x.mean(dim=1)
20
+ x = self.fc(x)
21
+
22
+ if self.temperature != 1.0:
23
+ x = x / self.temperature
24
+
25
+ return x