dikro commited on
Commit
2fe9c55
·
verified ·
1 Parent(s): 4fe00dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -25,16 +25,17 @@ DEVICE = torch.device("cpu")
25
  class PretrainedEfficientNet(nn.Module):
26
  def __init__(self, num_classes=10):
27
  super().__init__()
28
- self.net = models.efficientnet_b0(weights=None)
29
- old = self.net.features[0][0]
30
- self.net.features[0][0] = nn.Conv2d(
31
  1, old.out_channels, kernel_size=old.kernel_size,
32
  stride=old.stride, padding=old.padding, bias=False)
33
- self.net.classifier[1] = nn.Linear(
34
- self.net.classifier[1].in_features, num_classes)
35
 
36
  def forward(self, x):
37
- return self.net(x)
 
38
 
39
  model = PretrainedEfficientNet(num_classes=10)
40
  weights_path = os.path.join(os.path.dirname(__file__), "best_effnet.pth")
 
25
  class PretrainedEfficientNet(nn.Module):
26
  def __init__(self, num_classes=10):
27
  super().__init__()
28
+ self.efficientnet = models.efficientnet_b0(weights=None)
29
+ old = self.efficientnet.features[0][0]
30
+ self.efficientnet.features[0][0] = nn.Conv2d(
31
  1, old.out_channels, kernel_size=old.kernel_size,
32
  stride=old.stride, padding=old.padding, bias=False)
33
+ self.efficientnet.classifier[1] = nn.Linear(
34
+ self.efficientnet.classifier[1].in_features, num_classes)
35
 
36
  def forward(self, x):
37
+ return self.efficientnet(x)
38
+
39
 
40
  model = PretrainedEfficientNet(num_classes=10)
41
  weights_path = os.path.join(os.path.dirname(__file__), "best_effnet.pth")