CSATv2 / ResNet18.py
sosigikiller's picture
change_folder
f8cea41
raw
history blame
349 Bytes
import torchvision
class ResNet18(torchvision.models.ResNet):
    """ResNet-18: torchvision ``ResNet`` with ``BasicBlock`` and layers ``[2, 2, 2, 2]``.

    Args:
        num_classes: Output size of the final fully connected classifier layer.
        weight: Currently ignored; accepted so existing call sites that pass it
            keep working. NOTE(review): no pretrained weights are loaded here —
            confirm whether callers expect this argument to load a checkpoint.
        zero_init_residual: Forwarded to ``torchvision.models.ResNet.__init__``,
            which zero-initialises the last BatchNorm in each residual branch
            when this is true. Defaults to ``True``.
    """

    def __init__(self, num_classes=1000, weight=None, *, zero_init_residual=True):
        # zero_init_residual must reach the base constructor: torchvision only
        # consults it inside ResNet.__init__ while initialising weights, so
        # assigning the attribute after construction has no effect on the model.
        super().__init__(
            block=torchvision.models.resnet.BasicBlock,
            layers=[2, 2, 2, 2],
            num_classes=num_classes,
            zero_init_residual=zero_init_residual,
        )
        # Kept as an attribute for any code that reads model.zero_init_residual.
        self.zero_init_residual = zero_init_residual

    def forward(self, x):
        """Run the torchvision ResNet forward pass (``_forward_impl``) on ``x``."""
        return self._forward_impl(x)