# SAE / convs / linears.py
# Uploaded by Ttius ("Upload 192 files", commit 998bb30, verified).
'''
Reference:
https://github.com/hshustc/CVPR19_Incremental_Learning/blob/master/cifar100-class-incremental/modified_linear.py
'''
import math
import torch
from torch import nn
from torch.nn import functional as F
from typing import Union, Optional, Dict
from abc import ABCMeta, abstractmethod
class SimpleLinear(nn.Module):
    '''
    Plain fully connected classifier head that returns ``{'logits': ...}``.

    Reference:
    https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py

    Args:
        in_features: size of each input sample (last dimension of the input).
        out_features: number of output logits.
        bias: if False, no additive bias is learned and ``self.bias`` is
            registered as ``None``.
    '''
    def __init__(self, in_features, out_features, bias=True):
        super(SimpleLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform init (linear gain) for the weight, zeros for the bias."""
        nn.init.kaiming_uniform_(self.weight, nonlinearity='linear')
        # With bias=False the bias is registered as None; nn.init.constant_
        # would raise AttributeError on None, so only initialise a real bias.
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, input):
        """Return ``{'logits': input @ weight.T (+ bias)}``."""
        return {'logits': F.linear(input, self.weight, self.bias)}
class CosineLinear(nn.Module):
    """Cosine-similarity classifier head, optionally with several proxies per class.

    The logits are the cosine similarities between the L2-normalised input
    rows and the L2-normalised weight rows. A learnable scalar ``sigma`` can
    optionally scale them.

    Args:
        in_features: size of each input feature vector.
        out_features: number of classes. The layer stores
            ``out_features * nb_proxy`` weight rows, and ``self.out_features``
            holds that product.
        nb_proxy: number of proxy weight vectors per class.
        to_reduce: if True, merge the proxy scores of each class with
            ``reduce_proxies`` before scaling.
        sigma: if True, register a learnable scale initialised to 1;
            otherwise ``sigma`` is registered as ``None``.
    """

    def __init__(self, in_features, out_features, nb_proxy=1, to_reduce=False, sigma=True):
        super(CosineLinear, self).__init__()
        total_rows = out_features * nb_proxy
        self.in_features = in_features
        self.out_features = total_rows
        self.nb_proxy = nb_proxy
        self.to_reduce = to_reduce
        self.weight = nn.Parameter(torch.Tensor(total_rows, in_features))
        if not sigma:
            self.register_parameter('sigma', None)
        else:
            self.sigma = nn.Parameter(torch.Tensor(1))
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) weights; sigma reset to 1."""
        bound = 1. / math.sqrt(self.weight.size(1))
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)
            if self.sigma is not None:
                self.sigma.fill_(1)

    def forward(self, input):
        """Return ``{'logits': scores}`` with cosine scores, optionally proxy-reduced and scaled."""
        unit_inputs = F.normalize(input, p=2, dim=1)
        unit_weights = F.normalize(self.weight, p=2, dim=1)
        scores = F.linear(unit_inputs, unit_weights)
        if self.to_reduce:
            # Merge the nb_proxy scores of each class into a single score.
            scores = reduce_proxies(scores, self.nb_proxy)
        if self.sigma is not None:
            scores = self.sigma * scores
        return {'logits': scores}
class SplitCosineLinear(nn.Module):
    """Cosine classifier split into an "old classes" head and a "new classes" head.

    Both halves are ``CosineLinear`` layers built without their own scale and
    without proxy reduction. Their raw cosine scores are concatenated along
    dim 1, and the proxies are then merged with ``reduce_proxies``. A single
    shared learnable ``sigma`` (when enabled) scales only the merged
    ``'logits'``. The returned ``'old_scores'`` / ``'new_scores'`` are
    proxy-reduced but not scaled.

    Args:
        in_features: size of each input feature vector.
        out_features1: number of classes handled by ``fc1``.
        out_features2: number of classes handled by ``fc2``.
        nb_proxy: proxies per class, used by both halves.
        sigma: if True, register a shared scale initialised to 1; otherwise
            ``sigma`` is registered as ``None``.
    """

    def __init__(self, in_features, out_features1, out_features2, nb_proxy=1, sigma=True):
        super(SplitCosineLinear, self).__init__()
        self.in_features = in_features
        self.nb_proxy = nb_proxy
        self.out_features = (out_features1 + out_features2) * nb_proxy
        self.fc1 = CosineLinear(in_features, out_features1, nb_proxy=nb_proxy, to_reduce=False, sigma=False)
        self.fc2 = CosineLinear(in_features, out_features2, nb_proxy=nb_proxy, to_reduce=False, sigma=False)
        if not sigma:
            self.register_parameter('sigma', None)
        else:
            self.sigma = nn.Parameter(torch.Tensor(1).fill_(1))

    def forward(self, x):
        """Return the merged (scaled) logits plus the per-half proxy-reduced scores."""
        old_raw = self.fc1(x)['logits']
        new_raw = self.fc2(x)['logits']
        # Concatenate along the class/channel dimension, then merge proxies.
        merged = reduce_proxies(torch.cat((old_raw, new_raw), dim=1), self.nb_proxy)
        if self.sigma is not None:
            merged = self.sigma * merged
        return {
            'old_scores': reduce_proxies(old_raw, self.nb_proxy),
            'new_scores': reduce_proxies(new_raw, self.nb_proxy),
            'logits': merged,
        }
class AnalyticLinear(torch.nn.Linear, metaclass=ABCMeta):
    """
    Abstract linear module for the analytic continual learning [1-3].

    The classifier weight is a non-trainable buffer of shape
    ``(in_features [+ 1 when bias], num_classes)``. It starts with zero
    output columns, and subclasses fill it in analytically through ``fit``.
    When ``bias`` is True, a constant-one column is appended to every input,
    so the last weight row plays the role of the bias.

    This implementation refers to the official implementation https://github.com/ZHUANGHP/Analytic-continual-learning.

    Args:
        in_features: dimensionality of the input features (without the bias column).
        gamma: regularisation coefficient, stored as ``self.gamma``.
        bias: whether to append a constant-one input column.
        device: device for the weight buffer.
        dtype: dtype for the weight buffer (double by default, for numerical stability).

    References:
    [1] Zhuang, Huiping, et al.
    "ACIL: Analytic class-incremental learning with absolute memorization and privacy protection."
    Advances in Neural Information Processing Systems 35 (2022): 11602-11614.
    [2] Zhuang, Huiping, et al.
    "GKEAL: Gaussian Kernel Embedded Analytic Learning for Few-Shot Class Incremental Task."
    Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2023.
    [3] Zhuang, Huiping, et al.
    "DS-AL: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning."
    Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 38. No. 15. 2024.
    """

    def __init__(
        self,
        in_features: int,
        gamma: float = 1e-1,
        bias: bool = False,
        device: Optional[Union[torch.device, str, int]] = None,
        dtype=torch.double,
    ) -> None:
        # Deliberately skip nn.Linear.__init__ (which would create trainable
        # Parameters) and only run nn.Module's initialisation.
        super(torch.nn.Linear, self).__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.gamma: float = gamma
        self.bias: bool = bias  # a plain flag here, not a bias tensor
        self.dtype = dtype

        # Linear layer, stored as a buffer; it starts with no output columns.
        if bias:
            in_features += 1  # extra row multiplies the appended ones column
        weight = torch.zeros((in_features, 0), **factory_kwargs)
        self.register_buffer("weight", weight)

    @torch.no_grad()
    def forward(self, X: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return ``{"logits": X @ weight}``, appending a ones column first when bias is on."""
        X = X.to(self.weight)  # match the buffer's dtype and device
        if self.bias:
            X = torch.cat((X, torch.ones(X.shape[0], 1).to(X)), dim=-1)
        return {"logits": X @ self.weight}

    @property
    def in_features(self) -> int:
        """Input feature size, excluding the internal bias column."""
        if self.bias:
            return self.weight.shape[0] - 1
        return self.weight.shape[0]

    @property
    def out_features(self) -> int:
        """Current number of output classes (columns of the weight)."""
        return self.weight.shape[1]

    def extra_repr(self) -> str:
        # nn.Linear.extra_repr reports ``bias={self.bias is not None}``, which is
        # always True here because self.bias is a bool flag, so it is overridden.
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={bool(self.bias)}, gamma={self.gamma}"
        )

    def reset_parameters(self) -> None:
        """Drop all output columns, leaving an empty ``(rows, 0)`` weight."""
        self.weight = torch.zeros((self.weight.shape[0], 0)).to(self.weight)

    @abstractmethod
    def fit(self, X: torch.Tensor, Y: torch.Tensor) -> None:
        """Analytically update the weight from features ``X`` and targets ``Y``."""
        raise NotImplementedError()

    def after_task(self) -> None:
        """Check that the fitted weight contains no NaN/Inf values."""
        assert torch.isfinite(self.weight).all(), (
            "Pay attention to the numerical stability! "
            "A possible solution is to increase the value of gamma. "
            "Setting self.dtype=torch.double also helps."
        )
class RecursiveLinear(AnalyticLinear):
    """
    Recursive analytic linear (ridge regression) modules for the analytic continual learning [1-3].

    ``self.R`` starts as ``I / gamma``. Each call to ``fit`` updates it with the
    Woodbury identity, so it always equals the inverse of the regularised
    feature autocorrelation matrix, ``(gamma * I + sum X^T X)^{-1}``, over every
    batch seen so far. The weight is then corrected with a recursive
    least-squares step.

    This implementation refers to the official implementation https://github.com/ZHUANGHP/Analytic-continual-learning.
    References:
    [1] Zhuang, Huiping, et al.
    "ACIL: Analytic class-incremental learning with absolute memorization and privacy protection."
    Advances in Neural Information Processing Systems 35 (2022): 11602-11614.
    [2] Zhuang, Huiping, et al.
    "GKEAL: Gaussian Kernel Embedded Analytic Learning for Few-Shot Class Incremental Task."
    Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2023.
    [3] Zhuang, Huiping, et al.
    "DS-AL: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning."
    Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 38. No. 15. 2024.
    """

    def __init__(
        self,
        in_features: int,
        gamma: float = 1e-1,
        bias: bool = False,
        device: Optional[Union[torch.device, str, int]] = None,
        dtype=torch.double,
    ) -> None:
        super().__init__(in_features, gamma, bias, device, dtype)
        factory_kwargs = {"device": device, "dtype": dtype}
        # Regularized Feature Autocorrelation Matrix (RFAuM), stored inverted:
        # initially (gamma * I)^{-1} = I / gamma.
        self.R: torch.Tensor
        R = torch.eye(self.weight.shape[0], **factory_kwargs) / self.gamma
        self.register_buffer("R", R)

    def update_fc(self, nb_classes: int) -> None:
        """Append zero-initialised output columns so the head has ``nb_classes`` outputs."""
        increment_size = nb_classes - self.out_features
        assert increment_size >= 0, "The number of classes should be increasing."
        tail = torch.zeros((self.weight.shape[0], increment_size)).to(self.weight)
        self.weight = torch.cat((self.weight, tail), dim=1)

    @torch.no_grad()
    def fit(self, X: torch.Tensor, Y: torch.Tensor) -> None:
        """The core code of the ACIL [1].

        This implementation differs from the equations in the paper but is
        equivalent to them, and it supports mini-batch learning.

        Args:
            X: features, one row per sample (``in_features`` columns, without
                the bias column).
            Y: 2-D targets, one row per sample. Presumably these are one-hot
                labels; confirm against the caller. If ``Y`` has more columns
                than the head, the head is grown with zero columns. If it has
                fewer, ``Y`` is zero-padded on the right. Both cases are exact,
                because classes never seen with a non-zero target have zero
                ridge weights.
        """
        X, Y = X.to(self.weight), Y.to(self.weight)
        if self.bias:
            X = torch.cat((X, torch.ones(X.shape[0], 1).to(X)), dim=-1)

        # Align the target width with the classifier width.
        num_targets = Y.shape[1]
        if num_targets > self.out_features:
            self.update_fc(num_targets)
        elif num_targets < self.out_features:
            pad = torch.zeros((Y.shape[0], self.out_features - num_targets)).to(Y)
            Y = torch.cat((Y, pad), dim=1)

        # ACIL
        # Please update your PyTorch & CUDA if the `cusolver error` occurs.
        # If you insist on using this version, doing the `torch.inverse` on CPUs might help.
        # >>> K_inv = torch.eye(X.shape[0]).to(X) + X @ self.R @ X.T
        # >>> K = torch.inverse(K_inv.cpu()).to(self.weight.device)
        K = torch.inverse(torch.eye(X.shape[0]).to(X) + X @ self.R @ X.T)
        # Equation (10) of ACIL: Woodbury update of the inverse autocorrelation.
        self.R -= self.R @ X.T @ K @ X @ self.R
        # Equation (9) of ACIL: recursive least-squares correction of the weight.
        self.weight += self.R @ X.T @ (Y - X @ self.weight)
def reduce_proxies(out, nb_proxy):
    """Collapse ``nb_proxy`` proxy scores per class into one score per class.

    ``out`` is indexed as (batch, classes * nb_proxy), and the proxies of each
    class occupy consecutive columns. Each class score is the average of its
    proxy scores, weighted by a softmax over those same scores.

    Returns ``out`` unchanged (the same object) when ``nb_proxy == 1``,
    otherwise a new (batch, classes) tensor. Raises ``AssertionError('Shape
    error')`` when the column count is not divisible by ``nb_proxy``.
    """
    if nb_proxy == 1:
        return out
    batch_size = out.shape[0]
    class_count, leftover = divmod(out.shape[1], nb_proxy)
    assert leftover == 0, 'Shape error'
    grouped = out.view(batch_size, int(class_count), nb_proxy)
    weights = F.softmax(grouped, dim=-1)
    weighted = weights * grouped
    return weighted.sum(-1)
# NOTE(review): the triple-quoted string below is dead code. It holds older
# single-proxy versions of CosineLinear / SplitCosineLinear inside an inert
# module-level string literal that is never executed. The live classes of the
# same names are defined earlier in this file. Consider deleting it.
'''
class CosineLinear(nn.Module):
def __init__(self, in_features, out_features, sigma=True):
super(CosineLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
if sigma:
self.sigma = nn.Parameter(torch.Tensor(1))
else:
self.register_parameter('sigma', None)
self.reset_parameters()
def reset_parameters(self):
stdv = 1. / math.sqrt(self.weight.size(1))
self.weight.data.uniform_(-stdv, stdv)
if self.sigma is not None:
self.sigma.data.fill_(1)
def forward(self, input):
out = F.linear(F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1))
if self.sigma is not None:
out = self.sigma * out
return {'logits': out}
class SplitCosineLinear(nn.Module):
def __init__(self, in_features, out_features1, out_features2, sigma=True):
super(SplitCosineLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features1 + out_features2
self.fc1 = CosineLinear(in_features, out_features1, False)
self.fc2 = CosineLinear(in_features, out_features2, False)
if sigma:
self.sigma = nn.Parameter(torch.Tensor(1))
self.sigma.data.fill_(1)
else:
self.register_parameter('sigma', None)
def forward(self, x):
out1 = self.fc1(x)
out2 = self.fc2(x)
out = torch.cat((out1['logits'], out2['logits']), dim=1) # concatenate along the channel
if self.sigma is not None:
out = self.sigma * out
return {
'old_scores': out1['logits'],
'new_scores': out2['logits'],
'logits': out
}
'''