vu0018 commited on
Commit
077c0bf
·
verified ·
1 Parent(s): b636403

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +32 -0
model.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class ConvEncoder(nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+ self.features = nn.Sequential(
9
+ nn.Conv2d(3, 32, 3, stride=1, padding=1), nn.ReLU(),
10
+ nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),
11
+ nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(),
12
+ )
13
+
14
+ def forward(self, x):
15
+ return self.features(x)
16
+
17
+
18
+ class GenConViT(nn.Module):
19
+ def __init__(self):
20
+ super().__init__()
21
+ self.encoder = ConvEncoder()
22
+ self.classifier = nn.Sequential(
23
+ nn.Linear(128 * 56 * 56, 256),
24
+ nn.ReLU(),
25
+ nn.Linear(256, 2)
26
+ )
27
+
28
+ def forward(self, x):
29
+ feat = self.encoder(x)
30
+ feat = feat.view(feat.size(0), -1)
31
+ out = self.classifier(feat)
32
+ return out