# SAE/utils/inc_net.py
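"""Incremental-learning network definitions: the backbone factory `get_convnet`
and the model wrappers used by the trainers (DER, FOSTER, BEEF, MEMO, BiC,
UCIR-style cosine heads, ACIL and DS-AL)."""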
import copy
import logging
import torch
from torch import nn
from convs.cifar_resnet import resnet32
from convs.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
from convs.ucir_cifar_resnet import resnet32 as cosine_resnet32
from convs.ucir_resnet import resnet18 as cosine_resnet18
from convs.ucir_resnet import resnet34 as cosine_resnet34
from convs.ucir_resnet import resnet50 as cosine_resnet50
from convs.linears import SimpleLinear, SplitCosineLinear, CosineLinear
from convs.modified_represnet import resnet18_rep, resnet34_rep
from convs.resnet_cbam import resnet18_cbam, resnet34_cbam, resnet50_cbam
from convs.memo_resnet import get_resnet18_imagenet as get_memo_resnet18 # for MEMO imagenet
from convs.memo_cifar_resnet import get_resnet32_a2fc as get_memo_resnet32 # for MEMO cifar
from convs.ACL_buffer import RandomBuffer, activation_t
from convs.linears import RecursiveLinear
from typing import Dict, Any
def get_convnet(args, pretrained=False):
name = args["convnet_type"].lower()
if name == "resnet32":
return resnet32()
elif name == "resnet18":
return resnet18(pretrained=pretrained, args=args)
elif name == "resnet34":
return resnet34(pretrained=pretrained, args=args)
elif name == "resnet50":
return resnet50(pretrained=pretrained, args=args)
elif name == "cosine_resnet18":
return cosine_resnet18(pretrained=pretrained, args=args)
elif name == "cosine_resnet32":
return cosine_resnet32()
elif name == "cosine_resnet34":
return cosine_resnet34(pretrained=pretrained, args=args)
elif name == "cosine_resnet50":
return cosine_resnet50(pretrained=pretrained, args=args)
elif name == "resnet18_rep":
return resnet18_rep(pretrained=pretrained, args=args)
elif name == "resnet18_cbam":
return resnet18_cbam(pretrained=pretrained, args=args)
elif name == "resnet34_cbam":
return resnet34_cbam(pretrained=pretrained, args=args)
elif name == "resnet50_cbam":
return resnet50_cbam(pretrained=pretrained, args=args)
# MEMO benchmark backbone
elif name == 'memo_resnet18':
_basenet, _adaptive_net = get_memo_resnet18()
return _basenet, _adaptive_net
elif name == 'memo_resnet32':
_basenet, _adaptive_net = get_memo_resnet32()
return _basenet, _adaptive_net
else:
raise NotImplementedError("Unknown type {}".format(name))
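# Hedged usage sketch (the exact args schema comes from the caller's config; only
# "convnet_type" is read here, the rest is forwarded to the backbone constructors):
#
#   backbone = get_convnet({"convnet_type": "resnet32"})
#   base, adaptive = get_convnet({"convnet_type": "memo_resnet32"})
#
# Note that the MEMO variants return a (generalized, specialized) pair instead of
# a single module, so callers must unpack accordingly.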
class BaseNet(nn.Module):
def __init__(self, args, pretrained):
super(BaseNet, self).__init__()
self.args = args
self.convnet = get_convnet(args, pretrained)
self.fc = None
@property
def feature_dim(self):
return self.convnet.out_dim
def extract_vector(self, x):
return self.convnet(x)["features"]
def forward(self, x):
x = self.convnet(x)
out = self.fc(x["features"])
"""
{
'fmaps': [x_1, x_2, ..., x_n],
'features': features
'logits': logits
}
"""
out.update(x)
if self.args["attack"]:
return out["logits"]
else:
return out
def update_fc(self, nb_classes):
pass
def generate_fc(self, in_dim, out_dim):
pass
def copy(self):
return copy.deepcopy(self)
def freeze(self):
for param in self.parameters():
param.requires_grad = False
self.eval()
return self
def load_checkpoint(self, args):
if args["init_cls"] == 50:
pkl_name = "{}_{}_{}_B{}_Inc{}".format(
args["dataset"],
args["seed"],
args["convnet_type"],
0,
args["init_cls"],
)
checkpoint_name = f"checkpoints/finetune_{pkl_name}_0.pkl"
else:
checkpoint_name = f"checkpoints/finetune_{args['csv_name']}_0.pkl"
model_infos = torch.load(checkpoint_name)
self.convnet.load_state_dict(model_infos['convnet'])
self.fc.load_state_dict(model_infos['fc'])
test_acc = model_infos['test_acc']
return test_acc
class IncrementalNet(BaseNet):
def __init__(self, args, pretrained, gradcam=False):
super().__init__(args, pretrained)
self.gradcam = gradcam
if hasattr(self, "gradcam") and self.gradcam:
self._gradcam_hooks = [None, None]
self.set_gradcam_hook()
def update_fc(self, nb_classes):
fc = self.generate_fc(self.feature_dim, nb_classes)
if self.fc is not None:
nb_output = self.fc.out_features
weight = copy.deepcopy(self.fc.weight.data)
bias = copy.deepcopy(self.fc.bias.data)
fc.weight.data[:nb_output] = weight
fc.bias.data[:nb_output] = bias
del self.fc
self.fc = fc
def weight_align(self, increment):
weights = self.fc.weight.data
newnorm = torch.norm(weights[-increment:, :], p=2, dim=1)
oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1)
meannew = torch.mean(newnorm)
meanold = torch.mean(oldnorm)
gamma = meanold / meannew
print("alignweights,gamma=", gamma)
self.fc.weight.data[-increment:, :] *= gamma
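    # Weight aligning (WA) note: gamma is the ratio of the mean L2 norm of the
    # old-class weight rows to that of the new-class rows, so rescaling the new
    # rows by gamma roughly equalizes old/new logit magnitudes. Illustrative
    # numbers (assumed, not measured): mean old norm 1.2, mean new norm 2.4
    # -> gamma = 0.5, i.e. the new-class logits are halved.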
def generate_fc(self, in_dim, out_dim):
fc = SimpleLinear(in_dim, out_dim)
return fc
def forward(self, x):
x = self.convnet(x)
out = self.fc(x["features"])
out.update(x)
if hasattr(self, "gradcam") and self.gradcam:
out["gradcam_gradients"] = self._gradcam_gradients
out["gradcam_activations"] = self._gradcam_activations
# Change for Foolbox attack
if self.args["attack"]:
return out["logits"]
else:
return out
def unset_gradcam_hook(self):
self._gradcam_hooks[0].remove()
self._gradcam_hooks[1].remove()
self._gradcam_hooks[0] = None
self._gradcam_hooks[1] = None
self._gradcam_gradients, self._gradcam_activations = [None], [None]
def set_gradcam_hook(self):
self._gradcam_gradients, self._gradcam_activations = [None], [None]
def backward_hook(module, grad_input, grad_output):
self._gradcam_gradients[0] = grad_output[0]
return None
def forward_hook(module, input, output):
self._gradcam_activations[0] = output
return None
self._gradcam_hooks[0] = self.convnet.last_conv.register_backward_hook(
backward_hook
)
self._gradcam_hooks[1] = self.convnet.last_conv.register_forward_hook(
forward_hook
)
class IL2ANet(IncrementalNet):
def update_fc(self, num_old, num_total, num_aux):
fc = self.generate_fc(self.feature_dim, num_total + num_aux)
if self.fc is not None:
weight = copy.deepcopy(self.fc.weight.data)
bias = copy.deepcopy(self.fc.bias.data)
fc.weight.data[:num_old] = weight[:num_old]
fc.bias.data[:num_old] = bias[:num_old]
del self.fc
self.fc = fc
class CosineIncrementalNet(BaseNet):
def __init__(self, args, pretrained, nb_proxy=1):
super().__init__(args, pretrained)
self.nb_proxy = nb_proxy
def update_fc(self, nb_classes, task_num):
fc = self.generate_fc(self.feature_dim, nb_classes)
if self.fc is not None:
if task_num == 1:
fc.fc1.weight.data = self.fc.weight.data
fc.sigma.data = self.fc.sigma.data
else:
prev_out_features1 = self.fc.fc1.out_features
fc.fc1.weight.data[:prev_out_features1] = self.fc.fc1.weight.data
fc.fc1.weight.data[prev_out_features1:] = self.fc.fc2.weight.data
fc.sigma.data = self.fc.sigma.data
del self.fc
self.fc = fc
def generate_fc(self, in_dim, out_dim):
if self.fc is None:
fc = CosineLinear(in_dim, out_dim, self.nb_proxy, to_reduce=True)
else:
prev_out_features = self.fc.out_features // self.nb_proxy
# prev_out_features = self.fc.out_features
fc = SplitCosineLinear(
in_dim, prev_out_features, out_dim - prev_out_features, self.nb_proxy
)
return fc
class BiasLayer_BIC(nn.Module):
def __init__(self):
super(BiasLayer_BIC, self).__init__()
self.alpha = nn.Parameter(torch.ones(1, requires_grad=True))
self.beta = nn.Parameter(torch.zeros(1, requires_grad=True))
def forward(self, x, low_range, high_range):
ret_x = x.clone()
ret_x[:, low_range:high_range] = (
self.alpha * x[:, low_range:high_range] + self.beta
)
return ret_x
def get_params(self):
return (self.alpha.item(), self.beta.item())
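# BiC-style correction: one (alpha, beta) pair is learned per incremental task and
# applied only to that task's logit slice, logits[:, lo:hi] -> alpha * logits[:, lo:hi] + beta.
# Illustrative slicing (assumed task sizes, not from the code): with task_sizes = [50, 10],
# layer 0 rescales columns [0, 50) and layer 1 rescales columns [50, 60).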
class IncrementalNetWithBias(BaseNet):
def __init__(self, args, pretrained, bias_correction=False):
super().__init__(args, pretrained)
# Bias layer
self.bias_correction = bias_correction
self.bias_layers = nn.ModuleList([])
self.task_sizes = []
def forward(self, x):
x = self.convnet(x)
out = self.fc(x["features"])
if self.bias_correction:
logits = out["logits"]
for i, layer in enumerate(self.bias_layers):
logits = layer(
logits, sum(self.task_sizes[:i]), sum(self.task_sizes[: i + 1])
)
out["logits"] = logits
out.update(x)
if self.args["attack"]:
return out["logits"]
else:
return out
def update_fc(self, nb_classes):
fc = self.generate_fc(self.feature_dim, nb_classes)
if self.fc is not None:
nb_output = self.fc.out_features
weight = copy.deepcopy(self.fc.weight.data)
bias = copy.deepcopy(self.fc.bias.data)
fc.weight.data[:nb_output] = weight
fc.bias.data[:nb_output] = bias
del self.fc
self.fc = fc
new_task_size = nb_classes - sum(self.task_sizes)
self.task_sizes.append(new_task_size)
self.bias_layers.append(BiasLayer_BIC())
def generate_fc(self, in_dim, out_dim):
fc = SimpleLinear(in_dim, out_dim)
return fc
def get_bias_params(self):
params = []
for layer in self.bias_layers:
params.append(layer.get_params())
return params
def unfreeze(self):
for param in self.parameters():
param.requires_grad = True
class DERNet(nn.Module):
def __init__(self, args, pretrained):
super(DERNet, self).__init__()
self.convnet_type = args["convnet_type"]
self.convnets = nn.ModuleList()
self.pretrained = pretrained
self.out_dim = None
self.fc = None
self.aux_fc = None
self.task_sizes = []
self.args = args
@property
def feature_dim(self):
if self.out_dim is None:
return 0
return self.out_dim * len(self.convnets)
def extract_vector(self, x):
features = [convnet(x)["features"] for convnet in self.convnets]
features = torch.cat(features, 1)
return features
def forward(self, x):
features = [convnet(x)["features"] for convnet in self.convnets]
features = torch.cat(features, 1)
out = self.fc(features) # {logics: self.fc(features)}
aux_logits = self.aux_fc(features[:, -self.out_dim:])["logits"]
out.update({"aux_logits": aux_logits, "features": features})
if self.args["attack"]:
return out["logits"]
else:
return out
"""
{
'features': features
'logits': logits
'aux_logits':aux_logits
}
"""
def update_fc(self, nb_classes):
if len(self.convnets) == 0:
self.convnets.append(get_convnet(self.args))
else:
self.convnets.append(get_convnet(self.args))
self.convnets[-1].load_state_dict(self.convnets[-2].state_dict())
if self.out_dim is None:
self.out_dim = self.convnets[-1].out_dim
fc = self.generate_fc(self.feature_dim, nb_classes)
if self.fc is not None:
nb_output = self.fc.out_features
weight = copy.deepcopy(self.fc.weight.data)
bias = copy.deepcopy(self.fc.bias.data)
fc.weight.data[:nb_output, : self.feature_dim - self.out_dim] = weight
fc.bias.data[:nb_output] = bias
del self.fc
self.fc = fc
new_task_size = nb_classes - sum(self.task_sizes)
self.task_sizes.append(new_task_size)
self.aux_fc = self.generate_fc(self.out_dim, new_task_size + 1)
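    # Expansion note: each update appends a backbone initialized from the previous
    # one, widens the classifier, and copies the old weights into the columns of
    # the previously learned feature blocks. Illustrative shapes with out_dim = 512
    # (assumed backbone width): task 1 -> fc [n1, 512]; task 2 -> fc [n1 + n2, 1024],
    # with the old weights filling [:n1, :512]. The auxiliary head sees only the
    # newest backbone's features and predicts new_task_size + 1 classes (the new
    # classes plus one extra class standing in for all old ones).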
def generate_fc(self, in_dim, out_dim):
fc = SimpleLinear(in_dim, out_dim)
return fc
def copy(self):
return copy.deepcopy(self)
def freeze(self):
for param in self.parameters():
param.requires_grad = False
self.eval()
return self
def freeze_conv(self):
for param in self.convnets.parameters():
param.requires_grad = False
self.convnets.eval()
def weight_align(self, increment):
weights = self.fc.weight.data
newnorm = torch.norm(weights[-increment:, :], p=2, dim=1)
oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1)
meannew = torch.mean(newnorm)
meanold = torch.mean(oldnorm)
gamma = meanold / meannew
print("alignweights,gamma=", gamma)
self.fc.weight.data[-increment:, :] *= gamma
def load_checkpoint(self, args):
checkpoint_name = f"checkpoints/finetune_{args['csv_name']}_0.pkl"
model_infos = torch.load(checkpoint_name)
assert len(self.convnets) == 1
self.convnets[0].load_state_dict(model_infos['convnet'])
self.fc.load_state_dict(model_infos['fc'])
test_acc = model_infos['test_acc']
return test_acc
class SimpleCosineIncrementalNet(BaseNet):
def __init__(self, args, pretrained):
super().__init__(args, pretrained)
def update_fc(self, nb_classes, nextperiod_initialization=None):
fc = self.generate_fc(self.feature_dim, nb_classes).cuda()
if self.fc is not None:
nb_output = self.fc.out_features
weight = copy.deepcopy(self.fc.weight.data)
fc.sigma.data = self.fc.sigma.data
if nextperiod_initialization is not None:
weight = torch.cat([weight, nextperiod_initialization])
else:
weight = torch.cat([weight, torch.zeros(nb_classes - nb_output, self.feature_dim).cuda()])
fc.weight = nn.Parameter(weight)
del self.fc
self.fc = fc
def generate_fc(self, in_dim, out_dim):
fc = CosineLinear(in_dim, out_dim)
return fc
def regenerate_fc(self, nb_classes):
fc = self.generate_fc(self.feature_dim, nb_classes).cuda()
del self.fc
self.fc = fc
return fc
class MultiBranchCosineIncrementalNet(BaseNet):
def __init__(self, args, pretrained):
super().__init__(args, pretrained)
        # The single-branch self.convnet is not needed here; features come from
        # the dual-branch self.convnets instead.
        print(
            "Clearing the convnet in MultiBranchCosineIncrementalNet; the dual-branch self.convnets is used instead")
self.convnet = torch.nn.Identity()
for param in self.convnet.parameters():
param.requires_grad = False
self.convnets = nn.ModuleList()
self.args = args
def update_fc(self, nb_classes, nextperiod_initialization=None):
fc = self.generate_fc(self._feature_dim, nb_classes).cuda()
if self.fc is not None:
nb_output = self.fc.out_features
weight = copy.deepcopy(self.fc.weight.data)
fc.sigma.data = self.fc.sigma.data
if nextperiod_initialization is not None:
weight = torch.cat([weight, nextperiod_initialization])
else:
weight = torch.cat([weight, torch.zeros(nb_classes - nb_output, self._feature_dim).cuda()])
fc.weight = nn.Parameter(weight)
del self.fc
self.fc = fc
def generate_fc(self, in_dim, out_dim):
fc = CosineLinear(in_dim, out_dim)
return fc
def forward(self, x):
features = [convnet(x)["features"] for convnet in self.convnets]
features = torch.cat(features, 1)
out = self.fc(features)
out.update({"features": features})
if self.args["attack"]:
return out["logits"]
else:
return out
def construct_dual_branch_network(self, trained_model, tuned_model, cls_num):
self.convnets.append(trained_model.convnet)
self.convnets.append(tuned_model.convnet)
self._feature_dim = self.convnets[0].out_dim * len(self.convnets)
self.fc = self.generate_fc(self._feature_dim, cls_num)
class FOSTERNet(nn.Module):
def __init__(self, args, pretrained):
super(FOSTERNet, self).__init__()
self.convnet_type = args["convnet_type"]
self.convnets = nn.ModuleList()
self.pretrained = pretrained
self.out_dim = None
self.fc = None
self.fe_fc = None
self.task_sizes = []
self.oldfc = None
self.args = args
@property
def feature_dim(self):
if self.out_dim is None:
return 0
return self.out_dim * len(self.convnets)
def extract_vector(self, x):
features = [convnet(x)["features"] for convnet in self.convnets]
features = torch.cat(features, 1)
return features
def forward(self, x):
features = [convnet(x)["features"] for convnet in self.convnets]
features = torch.cat(features, 1)
out = self.fc(features)
fe_logits = self.fe_fc(features[:, -self.out_dim:])["logits"]
out.update({"fe_logits": fe_logits, "features": features})
if self.oldfc is not None:
old_logits = self.oldfc(features[:, : -self.out_dim])["logits"]
out.update({"old_logits": old_logits})
out.update({"eval_logits": out["logits"]})
# Change for Foolbox attack
if self.args["attack"]:
return out["logits"]
else:
return out
def update_fc(self, nb_classes):
self.convnets.append(get_convnet(self.args))
if self.out_dim is None:
self.out_dim = self.convnets[-1].out_dim
fc = self.generate_fc(self.feature_dim, nb_classes)
if self.fc is not None:
nb_output = self.fc.out_features
weight = copy.deepcopy(self.fc.weight.data)
bias = copy.deepcopy(self.fc.bias.data)
fc.weight.data[:nb_output, : self.feature_dim - self.out_dim] = weight
fc.bias.data[:nb_output] = bias
self.convnets[-1].load_state_dict(self.convnets[-2].state_dict())
self.oldfc = self.fc
self.fc = fc
new_task_size = nb_classes - sum(self.task_sizes)
self.task_sizes.append(new_task_size)
self.fe_fc = self.generate_fc(self.out_dim, nb_classes)
def generate_fc(self, in_dim, out_dim):
fc = SimpleLinear(in_dim, out_dim)
return fc
def copy(self):
return copy.deepcopy(self)
def copy_fc(self, fc):
weight = copy.deepcopy(fc.weight.data)
bias = copy.deepcopy(fc.bias.data)
n, m = weight.shape[0], weight.shape[1]
self.fc.weight.data[:n, :m] = weight
self.fc.bias.data[:n] = bias
def freeze(self):
for param in self.parameters():
param.requires_grad = False
self.eval()
return self
def freeze_conv(self):
for param in self.convnets.parameters():
param.requires_grad = False
self.convnets.eval()
def weight_align(self, old, increment, value):
weights = self.fc.weight.data
newnorm = torch.norm(weights[-increment:, :], p=2, dim=1)
oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1)
meannew = torch.mean(newnorm)
meanold = torch.mean(oldnorm)
gamma = meanold / meannew * (value ** (old / increment))
logging.info("align weights, gamma = {} ".format(gamma))
self.fc.weight.data[-increment:, :] *= gamma
def load_checkpoint(self, args):
if args["init_cls"] == 50:
pkl_name = "{}_{}_{}_B{}_Inc{}".format(
args["dataset"],
args["seed"],
args["convnet_type"],
0,
args["init_cls"],
)
checkpoint_name = f"checkpoints/finetune_{pkl_name}_0.pkl"
else:
checkpoint_name = f"checkpoints/finetune_{args['csv_name']}_0.pkl"
model_infos = torch.load(checkpoint_name)
assert len(self.convnets) == 1
self.convnets[0].load_state_dict(model_infos['convnet'])
self.fc.load_state_dict(model_infos['fc'])
test_acc = model_infos['test_acc']
return test_acc
class BiasLayer(nn.Module):
def __init__(self):
super(BiasLayer, self).__init__()
self.alpha = nn.Parameter(torch.zeros(1, requires_grad=True))
self.beta = nn.Parameter(torch.zeros(1, requires_grad=True))
def forward(self, x, bias=True):
        ret_x = (self.alpha + 1) * x
if bias:
ret_x = ret_x + self.beta
return ret_x
def get_params(self):
return (self.alpha.item(), self.beta.item())
class BEEFISONet(nn.Module):
def __init__(self, args, pretrained):
super(BEEFISONet, self).__init__()
self.convnet_type = args["convnet_type"]
self.convnets = nn.ModuleList()
self.pretrained = pretrained
self.out_dim = None
self.old_fc = None
self.new_fc = None
self.task_sizes = []
self.forward_prototypes = None
self.backward_prototypes = None
self.args = args
self.biases = nn.ModuleList()
@property
def feature_dim(self):
if self.out_dim is None:
return 0
return self.out_dim * len(self.convnets)
def extract_vector(self, x):
features = [convnet(x)["features"] for convnet in self.convnets]
features = torch.cat(features, 1)
return features
def forward(self, x):
features = [convnet(x)["features"] for convnet in self.convnets]
features = torch.cat(features, 1)
if self.old_fc is None:
fc = self.new_fc
out = fc(features)
else:
'''
merge the weights
'''
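            # Hedged reading of the merge: the old classifier is padded with zero
            # rows for the new classes (in the old-feature columns), the new-feature
            # columns are rebuilt by prepending one bias-corrected backward prototype
            # row per old class, and the two blocks are concatenated along the
            # feature dimension before the bias terms are summed.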
new_task_size = self.task_sizes[-1]
fc_weight = torch.cat(
[self.old_fc.weight, torch.zeros((new_task_size, self.feature_dim - self.out_dim)).cuda()], dim=0)
new_fc_weight = self.new_fc.weight
new_fc_bias = self.new_fc.bias
for i in range(len(self.task_sizes) - 2, -1, -1):
new_fc_weight = torch.cat([*[self.biases[i](self.backward_prototypes.weight[i].unsqueeze(0), bias=False)
for _ in range(self.task_sizes[i])], new_fc_weight], dim=0)
new_fc_bias = torch.cat([*[self.biases[i](self.backward_prototypes.bias[i].unsqueeze(0), bias=True) for
_ in range(self.task_sizes[i])], new_fc_bias])
fc_weight = torch.cat([fc_weight, new_fc_weight], dim=1)
fc_bias = torch.cat([self.old_fc.bias, torch.zeros(new_task_size).cuda()])
            fc_bias = fc_bias + new_fc_bias  # accumulate the compensated new-task biases onto the padded old biases
logits = features @ fc_weight.permute(1, 0) + fc_bias
out = {"logits": logits}
new_fc_weight = self.new_fc.weight
new_fc_bias = self.new_fc.bias
for i in range(len(self.task_sizes) - 2, -1, -1):
new_fc_weight = torch.cat([self.backward_prototypes.weight[i].unsqueeze(0), new_fc_weight], dim=0)
new_fc_bias = torch.cat([self.backward_prototypes.bias[i].unsqueeze(0), new_fc_bias])
out["train_logits"] = features[:, -self.out_dim:] @ new_fc_weight.permute(1, 0) + new_fc_bias
out.update({"eval_logits": out["logits"],
"energy_logits": self.forward_prototypes(features[:, -self.out_dim:])["logits"]})
if self.args["attack"]:
return out["logits"]
else:
return out
def update_fc_before(self, nb_classes):
new_task_size = nb_classes - sum(self.task_sizes)
self.biases = nn.ModuleList([BiasLayer() for i in range(len(self.task_sizes))])
self.convnets.append(get_convnet(self.args))
if self.out_dim is None:
self.out_dim = self.convnets[-1].out_dim
if self.new_fc is not None:
self.fe_fc = self.generate_fc(self.out_dim, nb_classes)
self.backward_prototypes = self.generate_fc(self.out_dim, len(self.task_sizes))
self.convnets[-1].load_state_dict(self.convnets[0].state_dict())
self.forward_prototypes = self.generate_fc(self.out_dim, nb_classes)
self.new_fc = self.generate_fc(self.out_dim, new_task_size)
self.task_sizes.append(new_task_size)
def generate_fc(self, in_dim, out_dim):
fc = SimpleLinear(in_dim, out_dim)
return fc
def update_fc_after(self):
if self.old_fc is not None:
old_fc = self.generate_fc(self.feature_dim, sum(self.task_sizes))
new_task_size = self.task_sizes[-1]
old_fc.weight.data = torch.cat(
[self.old_fc.weight.data, torch.zeros((new_task_size, self.feature_dim - self.out_dim)).cuda()], dim=0)
new_fc_weight = self.new_fc.weight.data
new_fc_bias = self.new_fc.bias.data
for i in range(len(self.task_sizes) - 2, -1, -1):
new_fc_weight = torch.cat([*[
self.biases[i](self.backward_prototypes.weight.data[i].unsqueeze(0), bias=False) for _ in
range(self.task_sizes[i])], new_fc_weight], dim=0)
new_fc_bias = torch.cat([*[self.biases[i](self.backward_prototypes.bias.data[i].unsqueeze(0), bias=True)
for _ in range(self.task_sizes[i])], new_fc_bias])
old_fc.weight.data = torch.cat([old_fc.weight.data, new_fc_weight], dim=1)
old_fc.bias.data = torch.cat([self.old_fc.bias.data, torch.zeros(new_task_size).cuda()])
old_fc.bias.data += new_fc_bias
self.old_fc = old_fc
else:
self.old_fc = self.new_fc
def copy(self):
return copy.deepcopy(self)
def copy_fc(self, fc):
weight = copy.deepcopy(fc.weight.data)
bias = copy.deepcopy(fc.bias.data)
n, m = weight.shape[0], weight.shape[1]
self.fc.weight.data[:n, :m] = weight
self.fc.bias.data[:n] = bias
def freeze(self):
for param in self.parameters():
param.requires_grad = False
self.eval()
return self
def freeze_conv(self):
for param in self.convnets.parameters():
param.requires_grad = False
self.convnets.eval()
def weight_align(self, old, increment, value):
weights = self.fc.weight.data
newnorm = torch.norm(weights[-increment:, :], p=2, dim=1)
oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1)
meannew = torch.mean(newnorm)
meanold = torch.mean(oldnorm)
gamma = meanold / meannew * (value ** (old / increment))
logging.info("align weights, gamma = {} ".format(gamma))
self.fc.weight.data[-increment:, :] *= gamma
class AdaptiveNet(nn.Module):
def __init__(self, args, pretrained):
super(AdaptiveNet, self).__init__()
self.convnet_type = args["convnet_type"]
self.TaskAgnosticExtractor, _ = get_convnet(args, pretrained) # Generalized blocks
self.TaskAgnosticExtractor.train()
self.AdaptiveExtractors = nn.ModuleList() # Specialized Blocks
self.pretrained = pretrained
self.out_dim = None
self.fc = None
self.aux_fc = None
self.task_sizes = []
self.args = args
@property
def feature_dim(self):
if self.out_dim is None:
return 0
return self.out_dim * len(self.AdaptiveExtractors)
def extract_vector(self, x):
base_feature_map = self.TaskAgnosticExtractor(x)
features = [extractor(base_feature_map) for extractor in self.AdaptiveExtractors]
features = torch.cat(features, 1)
return features
def forward(self, x):
base_feature_map = self.TaskAgnosticExtractor(x)
features = [extractor(base_feature_map) for extractor in self.AdaptiveExtractors]
features = torch.cat(features, 1)
out = self.fc(features) # {logits: self.fc(features)}
aux_logits = self.aux_fc(features[:, -self.out_dim:])["logits"]
out.update({"aux_logits": aux_logits, "features": features})
out.update({"base_features": base_feature_map})
if self.args["attack"]:
return out["logits"]
else:
return out
    '''
    {
        'features': features,
        'logits': logits,
        'aux_logits': aux_logits
    }
    '''
def update_fc(self, nb_classes):
_, _new_extractor = get_convnet(self.args)
if len(self.AdaptiveExtractors) == 0:
self.AdaptiveExtractors.append(_new_extractor)
else:
self.AdaptiveExtractors.append(_new_extractor)
self.AdaptiveExtractors[-1].load_state_dict(self.AdaptiveExtractors[-2].state_dict())
if self.out_dim is None:
logging.info(self.AdaptiveExtractors[-1])
self.out_dim = self.AdaptiveExtractors[-1].feature_dim
fc = self.generate_fc(self.feature_dim, nb_classes)
if self.fc is not None:
nb_output = self.fc.out_features
weight = copy.deepcopy(self.fc.weight.data)
bias = copy.deepcopy(self.fc.bias.data)
fc.weight.data[:nb_output, :self.feature_dim - self.out_dim] = weight
fc.bias.data[:nb_output] = bias
del self.fc
self.fc = fc
new_task_size = nb_classes - sum(self.task_sizes)
self.task_sizes.append(new_task_size)
self.aux_fc = self.generate_fc(self.out_dim, new_task_size + 1)
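    # MEMO note: the backbone is split into a shared TaskAgnosticExtractor
    # (generalized lower blocks) and one AdaptiveExtractor per task (specialized
    # upper blocks), so feature_dim grows by out_dim with each appended extractor
    # while the shared blocks are reused across tasks.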
def generate_fc(self, in_dim, out_dim):
fc = SimpleLinear(in_dim, out_dim)
return fc
def copy(self):
return copy.deepcopy(self)
def weight_align(self, increment):
weights = self.fc.weight.data
newnorm = (torch.norm(weights[-increment:, :], p=2, dim=1))
oldnorm = (torch.norm(weights[:-increment, :], p=2, dim=1))
meannew = torch.mean(newnorm)
meanold = torch.mean(oldnorm)
gamma = meanold / meannew
        logging.info("align weights, gamma = {}".format(gamma))
self.fc.weight.data[-increment:, :] *= gamma
def load_checkpoint(self, args):
if args["init_cls"] == 50:
pkl_name = "{}_{}_{}_B{}_Inc{}".format(
args["dataset"],
args["seed"],
args["convnet_type"],
0,
args["init_cls"],
)
checkpoint_name = f"checkpoints/finetune_{pkl_name}_0.pkl"
else:
checkpoint_name = f"checkpoints/finetune_{args['csv_name']}_0.pkl"
checkpoint_name = checkpoint_name.replace("memo_", "")
model_infos = torch.load(checkpoint_name)
model_dict = model_infos['convnet']
assert len(self.AdaptiveExtractors) == 1
base_state_dict = self.TaskAgnosticExtractor.state_dict()
adap_state_dict = self.AdaptiveExtractors[0].state_dict()
pretrained_base_dict = {
k: v
for k, v in model_dict.items()
if k in base_state_dict
}
pretrained_adap_dict = {
k: v
for k, v in model_dict.items()
if k in adap_state_dict
}
base_state_dict.update(pretrained_base_dict)
adap_state_dict.update(pretrained_adap_dict)
self.TaskAgnosticExtractor.load_state_dict(base_state_dict)
self.AdaptiveExtractors[0].load_state_dict(adap_state_dict)
self.fc.load_state_dict(model_infos['fc'])
test_acc = model_infos['test_acc']
return test_acc
class ACILNet(BaseNet):
"""
Network structure of the ACIL [1].
This implementation refers to the official implementation https://github.com/ZHUANGHP/Analytic-continual-learning.
References:
[1] Zhuang, Huiping, et al.
"ACIL: Analytic class-incremental learning with absolute memorization and privacy protection."
Advances in Neural Information Processing Systems 35 (2022): 11602-11614.
"""
def __init__(
self,
args: Dict[str, Any],
buffer_size: int = 8192,
gamma: float = 0.1,
pretrained: bool = False,
device=None,
dtype=torch.double,
) -> None:
super().__init__(args, pretrained)
assert isinstance(
self.convnet, torch.nn.Module
), "The backbone network `convnet` must be a `torch.nn.Module`."
self.convnet: torch.nn.Module = self.convnet.to(device, non_blocking=True)
self.args = args
self.buffer_size: int = buffer_size
self.gamma: float = gamma
self.device = device
self.dtype = dtype
@torch.no_grad()
def forward(self, X: torch.Tensor) -> Dict[str, torch.Tensor]:
X = self.convnet(X)["features"]
X = self.buffer(X)
X = self.fc(X)["logits"]
return {"logits": X}
def update_fc(self, nb_classes: int) -> None:
self.fc.update_fc(nb_classes)
def generate_fc(self, *_) -> None:
self.fc = RecursiveLinear(
self.buffer_size,
self.gamma,
bias=False,
device=self.device,
dtype=self.dtype,
)
def generate_buffer(self) -> None:
self.buffer = RandomBuffer(
self.feature_dim, self.buffer_size, device=self.device, dtype=self.dtype
)
def after_task(self) -> None:
self.fc.after_task()
@torch.no_grad()
def fit(self, X: torch.Tensor, y: torch.Tensor) -> None:
X = self.convnet(X)["features"]
X = self.buffer(X)
Y: torch.Tensor = torch.nn.functional.one_hot(y, self.fc.out_features)
self.fc.fit(X, Y)
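# Hedged usage sketch for the analytic pipeline (the loop and loader below are
# illustrative assumptions; the method names are the ones defined above):
#
#   net = ACILNet(args, buffer_size=8192, gamma=0.1, device=device)
#   net.generate_buffer()        # random projection of backbone features
#   net.generate_fc()            # recursive least-squares classifier
#   net.update_fc(nb_classes)    # widen the analytic classifier for the task
#   for X, y in loader:          # assumed DataLoader yielding (images, labels)
#       net.fit(X.to(device), y.to(device))
#   net.after_task()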
class DSALNet(ACILNet):
"""
Network structure of the DS-AL [1].
This implementation refers to the official implementation https://github.com/ZHUANGHP/Analytic-continual-learning.
References:
[1] Zhuang, Huiping, et al.
"DS-AL: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning."
Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 38. No. 15. 2024.
"""
def __init__(
self,
args: Dict[str, Any],
buffer_size: int = 8192,
gamma_main: float = 1e-3,
gamma_comp: float = 1e-3,
C: float = 1,
activation_main: activation_t = torch.relu,
activation_comp: activation_t = torch.tanh,
pretrained: bool = False,
device=None,
dtype=torch.double,
) -> None:
self.C = C
self.gamma_comp = gamma_comp
self.activation_main = activation_main
self.activation_comp = activation_comp
super().__init__(args, buffer_size, gamma_main, pretrained, device, dtype)
@torch.no_grad()
def forward(self, X: torch.Tensor) -> Dict[str, torch.Tensor]:
X = self.buffer(self.convnet(X)["features"])
X_main = self.fc(self.activation_main(X))["logits"]
X_comp = self.fc_comp(self.activation_comp(X))["logits"]
return {"logits": X_main + self.C * X_comp}
@torch.no_grad()
def fit(self, X: torch.Tensor, y: torch.Tensor) -> None:
num_classes = max(self.fc.out_features, int(y.max().item()) + 1)
Y_main = torch.nn.functional.one_hot(y, num_classes=num_classes)
X = self.buffer(self.convnet(X)["features"])
# Train the main stream
X_main = self.activation_main(X)
self.fc.fit(X_main, Y_main)
self.fc.after_task()
# Previous label cleansing (PLC)
Y_comp = Y_main - self.fc(X_main)["logits"]
Y_comp[:, : -self.increment_size] = 0
# Train the compensation stream
X_comp = self.activation_comp(X)
self.fc_comp.fit(X_comp, Y_comp)
@torch.no_grad()
def after_task(self) -> None:
self.fc.after_task()
self.fc_comp.after_task()
def generate_buffer(self) -> None:
self.buffer = RandomBuffer(
self.feature_dim,
self.buffer_size,
activation=None,
device=self.device,
dtype=self.dtype,
)
def generate_fc(self, *_) -> None:
# Main stream
self.fc = RecursiveLinear(
self.buffer_size,
self.gamma,
bias=False,
device=self.device,
dtype=self.dtype,
)
# Compensation stream
self.fc_comp = RecursiveLinear(
self.buffer_size,
self.gamma_comp,
bias=False,
device=self.device,
dtype=self.dtype,
)
def update_fc(self, nb_classes) -> None:
self.increment_size = nb_classes - self.fc.out_features
self.fc.update_fc(nb_classes)
self.fc_comp.update_fc(nb_classes)
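# DS-AL note: `fit` first trains the main stream on one-hot targets, then trains
# the compensation stream on the residual (targets minus main-stream predictions)
# with the old-class columns zeroed out ("previous label cleansing"). At inference
# the streams are combined as logits_main + C * logits_comp (see `forward`).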