tsaddev commited on
Commit
5e123e1
·
1 Parent(s): 6c373d6

Update app/Hackathon_setup/exp_recognition_model.py

Browse files
app/Hackathon_setup/exp_recognition_model.py CHANGED
@@ -25,10 +25,15 @@ class facExpRec(torch.nn.Module):
25
  self.conv1 = self.convlayer(in_channels=1, out_channels=64, kernel_size=3)
26
  self.conv2 = self.convlayer(in_channels=64, out_channels=128, kernel_size=5)
27
  self.conv3 = self.convlayer(in_channels=128, out_channels=512, kernel_size=3)
 
28
 
29
- self.fc1 = self.fclayer(512*3*3, 256)
30
  self.fc2 = self.fclayer(256, 512)
31
- self.fc3 = nn.Linear(512, out_features)
 
 
 
 
32
 
33
  def convlayer(self, in_channels, out_channels, kernel_size, max_pool=2):
34
  return nn.Sequential(
@@ -50,20 +55,13 @@ class facExpRec(torch.nn.Module):
50
 
51
  def forward(self, x):
52
  x = self.conv1(x)
53
- logger.info(x.shape)
54
  x = self.conv2(x)
55
- logger.info(x.shape)
56
  x = self.conv3(x)
57
- logger.info(x.shape)
58
- logger.info(x.shape)
59
- x = x.view(-1, 512*3*3)
60
- logger.info(x.shape)
61
  x = self.fc1(x)
62
- logger.info(x.shape)
63
  x = self.fc2(x)
64
- logger.info(x.shape)
65
- x = self.fc3(x)
66
- logger.info(x.shape)
67
  return x
68
 
69
  # Sample Helper function
 
25
  self.conv1 = self.convlayer(in_channels=1, out_channels=64, kernel_size=3)
26
  self.conv2 = self.convlayer(in_channels=64, out_channels=128, kernel_size=5)
27
  self.conv3 = self.convlayer(in_channels=128, out_channels=512, kernel_size=3)
28
+ self.conv4 = self.convlayer(in_channels=512, out_channels=512, kernel_size=3, max_pool=1)
29
 
30
+ self.fc1 = self.fclayer(512, 256)
31
  self.fc2 = self.fclayer(256, 512)
32
+ self.last = nn.Sequential(
33
+ nn.Linear(512, 256),
34
+ nn.ReLU(),
35
+ nn.Linear(256, 7)
36
+ )
37
 
38
  def convlayer(self, in_channels, out_channels, kernel_size, max_pool=2):
39
  return nn.Sequential(
 
55
 
56
  def forward(self, x):
57
  x = self.conv1(x)
 
58
  x = self.conv2(x)
 
59
  x = self.conv3(x)
60
+ x = self.conv4(x)
61
+ x = x.view(-1, 512)
 
 
62
  x = self.fc1(x)
 
63
  x = self.fc2(x)
64
+ x = self.last(x)
 
 
65
  return x
66
 
67
  # Sample Helper function