File size: 349 Bytes
5df9707
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
import torchvision

class ResNet18(torchvision.models.ResNet):
    """ResNet-18 assembled from torchvision's generic ``ResNet`` class.

    Uses ``torchvision.models.resnet.BasicBlock`` residual blocks arranged in a
    ``[2, 2, 2, 2]`` layer layout.

    Args:
        num_classes: Number of outputs of the final fully connected layer.
        weight: Accepted for call-site compatibility but not used; nothing is
            loaded from it. A ``UserWarning`` is emitted when it is not
            ``None`` so the omission is not silent.
            NOTE(review): presumably intended for pretrained weights — confirm
            the expected type/semantics before wiring up loading.
        zero_init_residual: Forwarded to torchvision's ``ResNet`` constructor,
            which (when ``True``) zero-initialises the last BatchNorm weight of
            every residual block during construction. Defaults to ``True``.
    """

    def __init__(self, num_classes=1000, weight=None, zero_init_residual=True):
        # torchvision only honours zero_init_residual inside its constructor
        # (it is used in the init loop and never read again), so it has to be
        # passed here; setting the attribute after construction has no effect
        # on how the weights were initialised.
        super().__init__(
            block=torchvision.models.resnet.BasicBlock,
            layers=[2, 2, 2, 2],
            num_classes=num_classes,
            zero_init_residual=zero_init_residual,
        )
        # Kept as a public attribute for any code that reads it back.
        self.zero_init_residual = zero_init_residual
        if weight is not None:
            import warnings

            warnings.warn(
                "ResNet18: the 'weight' argument is ignored; no weights were loaded.",
                UserWarning,
                stacklevel=2,
            )

    def forward(self, x):
        """Run the standard torchvision ResNet forward pass on ``x``."""
        return self._forward_impl(x)