import torchvision


class ResNet18(torchvision.models.ResNet):
    """ResNet-18 built on torchvision's generic ``ResNet`` class.

    Uses ``torchvision.models.resnet.BasicBlock`` with ``[2, 2, 2, 2]``
    blocks per stage (the ResNet-18 layout). No pretrained weights are
    loaded; parameters come from ``ResNet``'s own initialisation.

    Args:
        num_classes: Number of outputs of the final fully-connected layer,
            forwarded to ``ResNet``.
        weight: Accepted for call-site compatibility but currently unused.
            NOTE(review): presumably intended for pretrained weights —
            confirm with callers before wiring it up.
        zero_init_residual: Keyword-only. Forwarded to ``ResNet.__init__``;
            when True, the last BatchNorm weight in each residual branch is
            zero-initialised so every block starts close to an identity
            mapping. Defaults to True, matching the original
            ``self.zero_init_residual = True`` assignment's intent.
    """

    def __init__(self, num_classes=1000, weight=None, *, zero_init_residual=True):
        # torchvision performs the zero-init of residual BN layers inside
        # ResNet.__init__, so the flag must be passed to the base
        # constructor; assigning the attribute afterwards (as the original
        # code did) leaves the weights untouched.
        super().__init__(
            block=torchvision.models.resnet.BasicBlock,
            layers=[2, 2, 2, 2],
            num_classes=num_classes,
            zero_init_residual=zero_init_residual,
        )
        # Kept so existing code that reads this attribute still works.
        self.zero_init_residual = zero_init_residual

    def forward(self, x):
        """Run the base class's ``_forward_impl`` on ``x`` and return its output."""
        return self._forward_impl(x)