bernabeSanchez commited on
Commit
98fb9fb
·
1 Parent(s): e5776d7

Update files/inference.py

Browse files
Files changed (1) hide show
  1. files/inference.py +1 -1
files/inference.py CHANGED
@@ -36,7 +36,7 @@ class Net(nn.Module):
36
  def forward(self, x):
37
  x = self.pool(F.relu(self.conv1(x)))
38
  x = self.pool(F.relu(self.conv2(x)))
39
- x = torch.flatten(x, 1) # flatten all dimensions except batch
40
  x = F.relu(self.fc1(x))
41
  x = F.relu(self.fc2(x))
42
  x = self.fc3(x)
 
36
  def forward(self, x):
37
  x = self.pool(F.relu(self.conv1(x)))
38
  x = self.pool(F.relu(self.conv2(x)))
39
+ x = torch.flatten(x, start_dim=1)
40
  x = F.relu(self.fc1(x))
41
  x = F.relu(self.fc2(x))
42
  x = self.fc3(x)