Vvaann commited on
Commit
5678dec
·
verified ·
1 Parent(s): 007294c

Update resnet_lightning.py

Browse files
Files changed (1) hide show
  1. resnet_lightning.py +1 -2
resnet_lightning.py CHANGED
@@ -77,8 +77,7 @@ class ResNet18Model(L.LightningModule):
77
  self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
78
  self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
79
  self.linear = nn.Linear(512*block.expansion, num_classes)
80
-
81
- self.accuracy = Accuracy(task="multiclass", num_classes=10)
82
 
83
  def _make_layer(self, block, planes, num_blocks, stride):
84
  strides = [stride] + [1]*(num_blocks-1)
 
77
  self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
78
  self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
79
  self.linear = nn.Linear(512*block.expansion, num_classes)
80
+ self.accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=10)
 
81
 
82
  def _make_layer(self, block, planes, num_blocks, stride):
83
  strides = [stride] + [1]*(num_blocks-1)