resberry commited on
Commit
ba30de8
·
verified ·
1 Parent(s): adfc9ff

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +2 -3
model.py CHANGED
@@ -1,12 +1,12 @@
 
1
  import torch
2
- import torchvision
3
  from torch import nn
4
  from torchvision import models
5
 
6
  class FineTunedResNet(nn.Module):
7
  def __init__(self, num_classes=3):
8
  super(FineTunedResNet, self).__init__()
9
- self.resnet = models.resnet50(pretrained=True)
10
  self.resnet.fc = nn.Sequential(
11
  nn.Linear(self.resnet.fc.in_features, 512),
12
  nn.ReLU(),
@@ -16,4 +16,3 @@ class FineTunedResNet(nn.Module):
16
 
17
  def forward(self, x):
18
  return self.resnet(x)
19
-
 
1
+ %%writefile /content/lung_disease_detection/model.py
2
  import torch
 
3
  from torch import nn
4
  from torchvision import models
5
 
6
  class FineTunedResNet(nn.Module):
7
  def __init__(self, num_classes=3):
8
  super(FineTunedResNet, self).__init__()
9
+ self.resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
10
  self.resnet.fc = nn.Sequential(
11
  nn.Linear(self.resnet.fc.in_features, 512),
12
  nn.ReLU(),
 
16
 
17
  def forward(self, x):
18
  return self.resnet(x)