boringKey's picture
Upload 236 files
5fee096 verified
'''
Code Reference:
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
https://github.com/G-U-N/PyCIL/blob/master/convs/resnet.py
'''
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
# Names exported by a star-import of this module: the backbone factory
# functions plus the two cosine classifier heads.
__all__ = ['resnet18', 'resnet34', 'resnet50', 'cifar_resnet20', 'cifar_resnet32', 'cifar_resnet32_V2', 'resnet32_V2', 'resnet18_AML', 'CosineLinear', 'SplitCosineLinear']
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    The padding is set equal to ``dilation``, so with ``stride=1`` the
    spatial size of the input is preserved.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
        groups: Number of channel groups.
        dilation: Kernel dilation (also used as the padding).

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_options)
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d``.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise
class BasicBlock(nn.Module):
    """Two-layer 3x3 residual block (conv-norm-ReLU, conv-norm, add, ReLU).

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to produce the
            shortcut; when ``None`` the input itself is the shortcut.
        groups: Must be 1.
        base_width: Must be 64.
        dilation: Must not exceed 1.
        norm_layer: Callable building a normalization layer from a channel
            count; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: If ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: If ``dilation > 1``.
    """

    expansion = 1
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        unsupported_width = groups != 1 or base_width != 64
        if unsupported_width:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # When stride != 1, conv1 (and the downsample module, if given)
        # are the layers that reduce the spatial resolution.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = make_norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = make_norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(residual_branch(x) + shortcut(x))``."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))

        shortcut = x if self.downsample is None else self.downsample(x)

        branch += shortcut
        return self.relu(branch)
class Bottleneck(nn.Module):
    """Three-layer bottleneck residual block (1x1, 3x3, 1x1).

    The block outputs ``planes * expansion`` channels. The 3x3 convolution
    carries the stride, groups and dilation.

    Args:
        inplanes: Number of input channels.
        planes: Base channel count; the output has ``planes * 4`` channels.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input to produce the
            shortcut; when ``None`` the input itself is the shortcut.
        groups: Number of groups of the 3x3 convolution.
        base_width: Width per group, relative to the reference value 64.
        dilation: Dilation of the 3x3 convolution.
        norm_layer: Callable building a normalization layer from a channel
            count; defaults to ``nn.BatchNorm2d``.
    """

    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.)) * groups
        out_planes = planes * self.expansion

        # When stride != 1, conv2 (and the downsample module, if given)
        # are the layers that reduce the spatial resolution.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = make_norm(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = make_norm(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = make_norm(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(residual_branch(x) + shortcut(x))``."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))

        shortcut = x if self.downsample is None else self.downsample(x)

        branch += shortcut
        return self.relu(branch)
class ResNet(nn.Module):
    """ResNet backbone adapted from torchvision that returns features, not logits.

    ``forward`` returns a dict holding the four stage outputs (``'fmaps'``)
    and the globally pooled, flattened features (``'features'``). The input
    stem ``self.conv1`` is selected from ``args["dataset"]``.

    Args:
        block: Residual block class (e.g. ``BasicBlock`` or ``Bottleneck``);
            it must expose an ``expansion`` class attribute.
        layers: Sequence of four ints, the number of blocks per stage.
        num_classes: Accepted for signature compatibility; not used.
        zero_init_residual: If True, zero the weight of the last norm layer
            of every ``Bottleneck`` / ``BasicBlock`` so each block starts as
            an identity mapping.
        groups: Forwarded to every block.
        width_per_group: Forwarded to every block as ``base_width``.
        replace_stride_with_dilation: ``None`` or three booleans; a True
            entry replaces the stride of stage 2/3/4 by dilation.
        norm_layer: Callable building the normalization layers of the
            residual stages; defaults to ``nn.BatchNorm2d``. The stem
            always uses ``nn.BatchNorm2d``.
        args: Mapping with key ``"dataset"`` and, for ImageNet-like
            datasets, ``"init_cls_num"`` and ``"inc_cls_num"``.

    Raises:
        AssertionError: If ``args`` is ``None``.
        ValueError: If ``replace_stride_with_dilation`` does not have three
            entries, or if ``args["dataset"]`` matches no known stem.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, args=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        if args is None:
            raise AssertionError("you should pass args to resnet")

        self.conv1 = self._build_stem(args, self.inplanes)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.out_dim = 512 * block.expansion

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

        self.neck = nn.ModuleList()
        # The classifier input must match the pooled feature width, which is
        # 512 only for blocks with expansion 1 (2048 for Bottleneck).
        # The head is not applied in forward(); the output width of 20 is
        # kept as in the original code.
        self.fc = nn.Linear(self.out_dim, 20)

    @staticmethod
    def _build_stem(args, out_channels):
        """Return the input stem selected by ``args["dataset"]``."""
        dataset = args["dataset"]

        if 'cifar' in dataset or '5-datasets' in dataset:
            # Small-image stem: a single 3x3 convolution, no pooling.
            return nn.Sequential(
                nn.Conv2d(3, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        if 'imagenet' in dataset:
            if args["init_cls_num"] == args["inc_cls_num"]:
                # Standard 7x7 / stride-2 stem followed by max pooling.
                return nn.Sequential(
                    nn.Conv2d(3, out_channels, kernel_size=7, stride=2, padding=3, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                )
            # Lighter 3x3 / stride-1 stem, still followed by max pooling.
            return nn.Sequential(
                nn.Conv2d(3, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            )

        # Fail at construction time instead of leaving ``conv1`` undefined,
        # which would only surface later as an AttributeError in forward().
        raise ValueError(
            "Unsupported dataset {!r}: expected a name containing 'cifar', "
            "'5-datasets' or 'imagenet'".format(dataset)
        )

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and the projection
        shortcut; ``self.inplanes`` is updated to the stage output width.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation so the resolution is preserved.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        """Run the stem and the four stages; return fmaps and pooled features."""
        x = self.conv1(x)

        x_1 = self.layer1(x)
        x_2 = self.layer2(x_1)
        x_3 = self.layer3(x_2)
        x_4 = self.layer4(x_3)

        pooled = self.avgpool(x_4)
        features = torch.flatten(pooled, 1)

        return {
            'fmaps': [x_1, x_2, x_3, x_4],
            'features': features
        }

    def forward(self, x):
        """Return ``{'fmaps': [x_1, x_2, x_3, x_4], 'features': features}``."""
        return self._forward_impl(x)

    def feature(self, x):
        """Return only the pooled, flattened feature tensor for ``x``."""
        return self._forward_impl(x)['features']

    @property
    def last_conv(self):
        """Last convolution of the final block (``conv3`` if present, else ``conv2``)."""
        final_block = self.layer4[-1]
        if hasattr(final_block, 'conv3'):
            return final_block.conv3
        return final_block.conv2
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Construct a ``ResNet`` and optionally swap its head for a cosine classifier.

    Args:
        arch: Architecture name; not used here.
        block: Residual block class passed to ``ResNet``.
        layers: Number of blocks per stage, passed to ``ResNet``.
        pretrained: Must be falsy; pretrained weights are not available.
        progress: Not used here.
        **kwargs: Forwarded to ``ResNet``; ``kwargs['args']`` is also read
            for the optional ``'cosine_fc'`` flag.

    Returns:
        The constructed model. When ``kwargs['args']['cosine_fc']`` is
        truthy, ``model.fc`` is a ``CosineLinear`` of the same shape.

    Raises:
        NotImplementedError: If ``pretrained`` is truthy.
    """
    net = ResNet(block, layers, **kwargs)
    if pretrained:
        raise NotImplementedError

    config = kwargs['args']
    wants_cosine_head = 'cosine_fc' in config.keys() and config['cosine_fc']
    if wants_cosine_head:
        head = net.fc
        net.fc = CosineLinear(head.in_features, head.out_features)
    return net
def resnet18(pretrained=False, progress=True, **kwargs):
    r"""Build a ResNet-18 backbone (``BasicBlock``, stages of 2-2-2-2 blocks).

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): Must be False; ``_resnet`` raises
            ``NotImplementedError`` for a truthy value.
        progress (bool): Passed through to ``_resnet``, which ignores it.
        **kwargs: Forwarded to ``ResNet``; must include ``args``.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return _resnet('resnet18', BasicBlock, blocks_per_stage, pretrained,
                   progress, **kwargs)
def resnet34(pretrained=False, progress=True, **kwargs):
    r"""Build a ResNet-34 backbone (``BasicBlock``, stages of 3-4-6-3 blocks).

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): Must be False; ``_resnet`` raises
            ``NotImplementedError`` for a truthy value.
        progress (bool): Passed through to ``_resnet``, which ignores it.
        **kwargs: Forwarded to ``ResNet``; must include ``args``.
    """
    blocks_per_stage = [3, 4, 6, 3]
    return _resnet('resnet34', BasicBlock, blocks_per_stage, pretrained,
                   progress, **kwargs)
def resnet50(pretrained=False, progress=True, **kwargs):
    r"""Build a ResNet-50 backbone (``Bottleneck``, stages of 3-4-6-3 blocks).

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): Must be False; ``_resnet`` raises
            ``NotImplementedError`` for a truthy value.
        progress (bool): Passed through to ``_resnet``, which ignores it.
        **kwargs: Forwarded to ``ResNet``; must include ``args``.
    """
    blocks_per_stage = [3, 4, 6, 3]
    return _resnet('resnet50', Bottleneck, blocks_per_stage, pretrained,
                   progress, **kwargs)
class ResNetBasicblock(nn.Module):
    """Residual block of two 3x3 conv + BatchNorm layers, used by ``CifarResNet``.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to produce the
            shortcut; when ``None`` the input itself is the shortcut.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        shared = dict(kernel_size=3, padding=1, bias=False)

        self.conv_a = nn.Conv2d(inplanes, planes, stride=stride, **shared)
        self.bn_a = nn.BatchNorm2d(planes)

        self.conv_b = nn.Conv2d(planes, planes, stride=1, **shared)
        self.bn_b = nn.BatchNorm2d(planes)

        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(shortcut(x) + residual_branch(x))``."""
        branch = F.relu(self.bn_a(self.conv_a(x)), inplace=True)
        branch = self.bn_b(self.conv_b(branch))

        shortcut = x if self.downsample is None else self.downsample(x)

        return F.relu(shortcut + branch, inplace=True)
'''
Code Reference:
https://github.com/G-U-N/PyCIL/blob/master/convs/resnet.py
We keep this version of ResNet because it achieves better accuracy.
'''
class CifarResNet(nn.Module):
    """
    Three-stage ResNet for CIFAR-sized inputs, following
    https://arxiv.org/abs/1512.03385.pdf

    ``forward`` returns a dict with the three stage outputs (``'fmaps'``)
    and the pooled, flattened features (``'features'``).

    Args:
        block: Residual block class; must expose ``expansion`` and accept
            ``(inplanes, planes, stride, downsample)``.
        depth: Total network depth; must satisfy ``(depth - 2) % 6 == 0``.
        channels: Number of input image channels.
    """

    def __init__(self, block, depth, channels=3):
        super().__init__()

        # The depth fixes how many blocks each of the three stages holds.
        assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
        blocks_per_stage = (depth - 2) // 6

        self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_1 = nn.BatchNorm2d(16)

        self.inplanes = 16
        self.stage_1 = self._make_layer(block, 16, blocks_per_stage, 1)
        self.stage_2 = self._make_layer(block, 32, blocks_per_stage, 2)
        self.stage_3 = self._make_layer(block, 64, blocks_per_stage, 2)
        self.avgpool = nn.AvgPool2d(8)
        self.out_dim = 64 * block.expansion

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                # He-style normal init scaled by kernel area * output channels.
                kernel_h, kernel_w = module.kernel_size[0], module.kernel_size[1]
                fan = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight)
                module.bias.data.zero_()

        self.neck = nn.ModuleList()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage; only its first block is strided / projected."""
        stage_width = planes * block.expansion

        needs_projection = stride != 1 or self.inplanes != stage_width
        if needs_projection:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, stage_width,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(stage_width),
            )
        else:
            downsample = None

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = stage_width
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Return ``{'fmaps': [x_1, x_2, x_3], 'features': features}``."""
        stem = F.relu(self.bn_1(self.conv_1_3x3(x)), inplace=True)

        fmaps = []
        current = stem
        for stage in (self.stage_1, self.stage_2, self.stage_3):
            current = stage(current)
            fmaps.append(current)

        pooled = self.avgpool(current)
        flat = pooled.view(pooled.size(0), -1)

        return {
            'fmaps': fmaps,
            'features': flat
        }

    def feature(self, x):
        """Return only the pooled, flattened feature tensor for ``x``."""
        return self.forward(x)['features']

    @property
    def last_conv(self):
        """The ``conv_b`` layer of the last block in stage 3."""
        final_block = self.stage_3[-1]
        return final_block.conv_b
'''
Code Reference:
https://github.com/hshustc/CVPR19_Incremental_Learning/blob/master/cifar100-class-incremental/modified_linear.py
'''
class CosineLinear(nn.Module):
    """Linear layer whose outputs are (optionally scaled) cosine similarities.

    Both the input rows and the weight rows are L2-normalized before the
    matrix product; the result is multiplied by the learnable scalar
    ``sigma`` when that parameter is enabled.

    Args:
        in_features: Size of each input row.
        out_features: Number of output units (rows of ``weight``).
        sigma: If truthy, register a learnable scale initialized to 1;
            otherwise ``sigma`` is registered as ``None``.
    """

    def __init__(self, in_features, out_features, sigma=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if not sigma:
            self.register_parameter('sigma', None)
        else:
            self.sigma = Parameter(torch.Tensor(1))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw ``weight`` uniformly in ``[-1/sqrt(in), 1/sqrt(in)]`` and set sigma to 1."""
        bound = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-bound, bound)
        if self.sigma is None:
            return
        self.sigma.data.fill_(1)

    def forward(self, input):
        """Return cosine similarities between ``input`` rows and weight rows."""
        unit_input = F.normalize(input, p=2, dim=1)
        unit_weight = F.normalize(self.weight, p=2, dim=1)
        cosine = F.linear(unit_input, unit_weight)
        if self.sigma is None:
            return cosine
        return self.sigma * cosine
class SplitCosineLinear(nn.Module):
    """Two ``CosineLinear`` heads whose outputs are concatenated.

    The two heads are created without their own scale; a single shared
    learnable ``sigma`` (initialized to 1) multiplies the concatenated
    output when enabled.

    Args:
        in_features: Size of each input row.
        out_features1: Number of outputs of the first head (``fc1``).
        out_features2: Number of outputs of the second head (``fc2``).
        sigma: If truthy, register the shared scale; otherwise ``sigma``
            is registered as ``None``.
    """

    def __init__(self, in_features, out_features1, out_features2, sigma=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features1 + out_features2
        self.fc1 = CosineLinear(in_features, out_features1, False)
        self.fc2 = CosineLinear(in_features, out_features2, False)
        if not sigma:
            self.register_parameter('sigma', None)
        else:
            self.sigma = Parameter(torch.Tensor(1))
            self.sigma.data.fill_(1)

    def forward(self, x):
        """Return ``[fc1(x), fc2(x)]`` joined along dim 1, scaled by sigma if set."""
        heads = (self.fc1(x), self.fc2(x))
        joined = torch.cat(heads, dim=1)
        if self.sigma is None:
            return joined
        return self.sigma * joined
'''
Code Reference:
https://github.com/hshustc/CVPR19_Incremental_Learning
This is the version of ResNet used in the official LUCIR repository; using a different version leads to a decrease in performance.
'''
class modified_BasicBlock(nn.Module):
    """Basic residual block whose final ReLU can be disabled.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to produce the
            shortcut; when ``None`` the input itself is the shortcut.
        last: If True, the ReLU after the residual addition is skipped.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, last=False):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.last = last

    def forward(self, x):
        """Return the block output, with the trailing ReLU unless ``last``."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))

        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut

        if self.last:
            # The final block of the network leaves its output un-rectified.
            return branch
        return self.relu(branch)
class modified_ResNet(nn.Module):
    """Three-stage ResNet whose very last block omits its final ReLU.

    ``forward`` returns ``{"features": tensor}`` with the pooled, flattened
    output of stage 3; no classifier is attached.

    Args:
        block: Residual block class; must expose ``expansion`` and accept
            ``(inplanes, planes, stride, downsample, last=...)``.
        layers: Three ints, the number of blocks in each stage.
        num_classes: Accepted for signature compatibility; not used.
    """

    def __init__(self, block, layers, num_classes=10):
        # Initialize the Module machinery before assigning any attribute.
        super(modified_ResNet, self).__init__()
        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2, last_phase=True)
        self.avgpool = nn.AvgPool2d(8, stride=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, last_phase=False):
        """Build one stage of residual blocks.

        Only the first block receives ``stride`` and the projection
        shortcut. With ``last_phase=True`` the final block of the stage is
        created with ``last=True``. At least one block is always created.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        if last_phase and blocks <= 1:
            # A single-block stage: the first block is also the last one.
            # (Previously an extra block was appended here, yielding two
            # blocks for blocks=1.)
            first = block(self.inplanes, planes, stride, downsample, last=True)
        else:
            first = block(self.inplanes, planes, stride, downsample)
        layers = [first]
        self.inplanes = planes * block.expansion

        final_index = blocks - 1
        for index in range(1, blocks):
            if last_phase and index == final_index:
                layers.append(block(self.inplanes, planes, last=True))
            else:
                layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``{"features": pooled_flattened_stage3_output}``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        return {"features": x}

    def feature(self, x):
        """Return only the feature tensor that ``forward`` computes for ``x``."""
        return self.forward(x)["features"]
# Temporary Resnet for BIC, TODO: Merge
class BiasLayer(nn.Module):
    """Learnable scalar affine transform ``alpha * x + beta``.

    ``alpha`` starts at 1 and ``beta`` at 0, so a fresh layer is the
    identity.
    """

    def __init__(self):
        super().__init__()
        scale_init = torch.ones(1, requires_grad=True)
        shift_init = torch.zeros(1, requires_grad=True)
        self.alpha = nn.Parameter(scale_init)
        self.beta = nn.Parameter(shift_init)

    def forward(self, x):
        """Scale ``x`` by ``alpha`` and shift it by ``beta``."""
        scaled = self.alpha * x
        return scaled + self.beta
class BasicBlock2(nn.Module):
    """Pre-activation residual block: (BN, ReLU, 3x3 conv) twice, then add.

    No activation is applied after the residual addition.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to produce the
            shortcut; when ``None`` the input itself is the shortcut.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        conv_common = dict(kernel_size=3, padding=1, bias=False)

        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(inplanes, planes, stride=stride, **conv_common)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, stride=1, **conv_common)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``residual_branch(x) + shortcut(x)``."""
        branch = self.conv1(self.relu(self.bn1(x)))
        branch = self.conv2(self.relu(self.bn2(branch)))

        shortcut = x if self.downsample is None else self.downsample(x)

        branch += shortcut
        return branch
class ResNet_BIC(nn.Module):
    """Three-stage pre-activation ResNet used for BIC.

    ``forward`` returns the pooled, flattened output of stage 3 after a
    final BatchNorm + ReLU.

    Args:
        depth: Total network depth; must satisfy ``(depth - 2) % 6 == 0``.
        block_name: Block type, compared case-insensitively. Only
            ``'BasicBlock2'`` is supported.

    Raises:
        AssertionError: If ``depth`` is not of the form ``6n + 2``, or if
            ``block_name`` is ``'bottleneck'`` (unsupported for BIC).
        ValueError: If ``block_name`` is not recognized.
    """

    def __init__(self, depth, block_name='BasicBlock2'):
        super().__init__()

        # The checks raise AssertionError explicitly instead of using
        # ``assert`` so they still run under ``python -O``; the exception
        # types are the same as before.
        kind = block_name.lower()
        if kind == 'basicblock2':
            if (depth - 2) % 6 != 0:
                raise AssertionError('When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202')
            blocks_per_stage = (depth - 2) // 6
            block = BasicBlock2
        elif kind == 'bottleneck':
            raise AssertionError('bottleneck is called, should not happen in method : BIC')
        else:
            raise ValueError('block_name shoule be Basicblock or Bottleneck')

        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 16, blocks_per_stage)
        self.layer2 = self._make_layer(block, 32, blocks_per_stage, stride=2)
        self.layer3 = self._make_layer(block, 64, blocks_per_stage, stride=2)
        self.bn = nn.BatchNorm2d(64 * block.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)

        # NOTE(review): fixed value, not derived from the layers above; it
        # presumably depends on the input resolution (the original comment
        # mentions 3136 for ImageNet-R) - confirm against the caller.
        self.feat_dim = 256  # final feature's dim

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-style normal init scaled by kernel area * output channels.
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage; only its first block is strided / projected.

        The projection shortcut is a bare 1x1 convolution without a norm
        layer.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the pooled, flattened feature tensor for ``x``."""
        x = self.conv1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.bn(x)
        x = self.relu(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        return x
# Temporary Resnet for AML (ERACE, ERAML), TODO: Merge
class BasicBlock_AML(nn.Module):
    """Basic residual block that builds its own shortcut.

    The shortcut is an empty ``nn.Sequential`` (identity) unless the stride
    or the channel count changes, in which case it is a 1x1 convolution
    followed by BatchNorm.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels.
        stride: Stride of the first convolution (and of the projection).
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        shape_changes = stride != 1 or in_planes != out_planes
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes)
            )
        else:
            self.shortcut = nn.Sequential()

        self.activation = nn.ReLU()

    def forward(self, x):
        """Return ``relu(residual_branch(x) + shortcut(x))``."""
        hidden = self.activation(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        summed = hidden + self.shortcut(x)
        return self.activation(summed)
class ResNet_AML(nn.Module):
    """Four-stage reduced-width ResNet feature extractor.

    Stage widths are ``nf``, ``2*nf``, ``4*nf`` and ``8*nf``. ``forward``
    returns the flattened output of a 4x4 average pooling; ``out_dim`` is
    measured at construction with a dummy input of shape ``input_size``.

    Args:
        block: Residual block class; must expose ``expansion`` and accept
            ``(in_planes, planes, stride)``.
        num_blocks: Four ints, the number of blocks in each stage.
        num_classes: Accepted for signature compatibility; not used.
        nf: Base number of feature channels.
        input_size: ``(channels, height, width)`` of one input sample.
    """

    def __init__(self, block, num_blocks, num_classes, nf=20, input_size=(3, 32, 32)):
        super().__init__()
        self.in_planes = nf
        self.input_size = input_size

        image_channels = input_size[0]
        self.conv1 = conv3x3(image_channels, nf * 1)
        self.bn1 = nn.BatchNorm2d(nf * 1)
        self.layer1 = self._make_layer(block, nf * 1, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, nf * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, nf * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2)
        self.activation = nn.ReLU()

        # Probe the network once to learn the flattened feature size.
        # NOTE: a freshly built module is in training mode, so this
        # all-zero pass also updates the BatchNorm running statistics.
        with torch.no_grad():
            probe_input = torch.zeros(1, *self.input_size)
            probe_output = self.forward(probe_input)
            self.out_dim = probe_output.view(1, -1).shape[1]

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``."""
        per_block_strides = [stride] + [1] * (num_blocks - 1)
        stage = []
        for block_stride in per_block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the flattened, average-pooled output of the last stage."""
        hidden = self.activation(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            hidden = stage(hidden)
        pooled = F.avg_pool2d(hidden, 4)
        return pooled.view(pooled.size(0), -1)
def cifar_resnet20(pretrained=False, **kwargs):
    """Build a depth-20 ``CifarResNet`` made of ``ResNetBasicblock`` blocks.

    Args:
        pretrained: Ignored; no pretrained weights are loaded.
        **kwargs: Ignored.
    """
    # The unused local ``n = 3`` was removed; the depth alone fixes the
    # number of blocks per stage inside CifarResNet.
    model = CifarResNet(ResNetBasicblock, 20)
    return model
def cifar_resnet32(pretrained=False, **kwargs):
    """Build a depth-32 ``CifarResNet`` of ``ResNetBasicblock`` blocks (used for EWC).

    Args:
        pretrained: Ignored; no pretrained weights are loaded.
        **kwargs: Ignored.
    """
    return CifarResNet(ResNetBasicblock, 32)
def cifar_resnet32_V2(pretrained=False, **kwargs):
    """Build the depth-32 ``ResNet_BIC`` backbone (used for BIC).

    Args:
        pretrained: Ignored; no pretrained weights are loaded.
        **kwargs: Ignored.
    """
    backbone = ResNet_BIC(32)
    return backbone
def resnet32_V2(pretrained=False, **kwargs):
    """Build the LUCIR ResNet-32: ``modified_ResNet`` with 5 blocks per stage.

    Args:
        pretrained: Ignored; no pretrained weights are loaded.
        **kwargs: Ignored.
    """
    blocks_per_stage = 5
    stage_layout = [blocks_per_stage] * 3
    return modified_ResNet(modified_BasicBlock, stage_layout, num_classes=50)
def resnet18_AML(pretrained=False, **kwargs):
    """Build the reduced ResNet-18 (``ResNet_AML`` with 2-2-2-2 blocks).

    Args:
        pretrained: Ignored; no pretrained weights are loaded.
        **kwargs: Optional ``input_size`` (defaults to ``[3, 32, 32]``) and
            optional ``num_classes``. ``num_classes`` is only passed
            through to ``ResNet_AML``, which does not use it, so omitting
            it no longer raises ``KeyError``.
    """
    input_size = kwargs.get('input_size', [3, 32, 32])
    num_classes = kwargs.get('num_classes')
    return ResNet_AML(BasicBlock_AML, [2, 2, 2, 2], num_classes, input_size=input_size)