|
|
'''
Reference:
https://github.com/hshustc/CVPR19_Incremental_Learning/blob/master/cifar100-class-incremental/modified_linear.py
'''
import math
from abc import ABCMeta, abstractmethod
from typing import Union, Optional, Dict

import torch
from torch import nn
from torch.nn import functional as F
|
|
class SimpleLinear(nn.Module):
    '''
    Reference:
    https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py
    '''

    def __init__(self, in_features, out_features, bias=True):
        super(SimpleLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, nonlinearity='linear')
        if self.bias is not None:  # guard: bias may be disabled
            nn.init.constant_(self.bias, 0)

    def forward(self, input):
        return {'logits': F.linear(input, self.weight, self.bias)}
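
# A minimal usage sketch (illustrative, not part of the original file): every
# head in this module returns a dict rather than a raw tensor, so callers
# always read the scores via the 'logits' key.
def _demo_simple_linear():
    fc = SimpleLinear(in_features=64, out_features=10)
    logits = fc(torch.randn(4, 64))['logits']  # shape: (4, 10)
    assert logits.shape == (4, 10)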
|
|
|
|
|
|
|
|
class CosineLinear(nn.Module):
    def __init__(self, in_features, out_features, nb_proxy=1, to_reduce=False, sigma=True):
        super(CosineLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features * nb_proxy
        self.nb_proxy = nb_proxy
        self.to_reduce = to_reduce
        self.weight = nn.Parameter(torch.Tensor(self.out_features, in_features))
        if sigma:
            self.sigma = nn.Parameter(torch.Tensor(1))
        else:
            self.register_parameter('sigma', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.sigma is not None:
            self.sigma.data.fill_(1)

    def forward(self, input):
        # L2-normalizing both the input and the weight turns the linear map
        # into cosine similarities instead of raw dot products.
        out = F.linear(F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1))

        if self.to_reduce:
            # Collapse the per-proxy similarities into one score per class.
            out = reduce_proxies(out, self.nb_proxy)

        if self.sigma is not None:
            # Learnable scale: pure cosine scores lie in [-1, 1], which is too
            # flat for softmax, so sigma restores a useful logit range.
            out = self.sigma * out

        return {'logits': out}
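
# A minimal usage sketch (illustrative only, not from the original file):
# with nb_proxy=2 and to_reduce=True, the 10 * 2 raw proxy similarities are
# collapsed to one cosine score per class before sigma scaling.
def _demo_cosine_linear():
    fc = CosineLinear(in_features=64, out_features=10, nb_proxy=2, to_reduce=True)
    logits = fc(torch.randn(4, 64))['logits']  # shape: (4, 10)
    assert logits.shape == (4, 10)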
|
|
|
|
|
|
|
|
class SplitCosineLinear(nn.Module):
    def __init__(self, in_features, out_features1, out_features2, nb_proxy=1, sigma=True):
        super(SplitCosineLinear, self).__init__()
        self.in_features = in_features
        self.out_features = (out_features1 + out_features2) * nb_proxy
        self.nb_proxy = nb_proxy
        # The two heads neither reduce proxies nor apply sigma themselves
        # (to_reduce=False, sigma=False); both steps are done once here,
        # on the concatenated output.
        self.fc1 = CosineLinear(in_features, out_features1, nb_proxy, False, False)
        self.fc2 = CosineLinear(in_features, out_features2, nb_proxy, False, False)
        if sigma:
            self.sigma = nn.Parameter(torch.Tensor(1))
            self.sigma.data.fill_(1)
        else:
            self.register_parameter('sigma', None)

    def forward(self, x):
        out1 = self.fc1(x)
        out2 = self.fc2(x)

        out = torch.cat((out1['logits'], out2['logits']), dim=1)  # concatenate along the class dimension

        out = reduce_proxies(out, self.nb_proxy)

        if self.sigma is not None:
            out = self.sigma * out

        return {
            'old_scores': reduce_proxies(out1['logits'], self.nb_proxy),
            'new_scores': reduce_proxies(out2['logits'], self.nb_proxy),
            'logits': out
        }
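
# A hedged sketch (not from the original file) of how a trainer might grow a
# CosineLinear head into a SplitCosineLinear when a new task adds classes:
# the old class prototypes move into fc1, the fresh fc2 handles the new
# classes, and sigma is carried over. The helper name and surrounding
# training logic are our assumptions, for illustration only.
def _grow_head(old_fc: CosineLinear, nb_new_classes: int) -> SplitCosineLinear:
    nb_old_classes = old_fc.out_features // old_fc.nb_proxy
    new_fc = SplitCosineLinear(old_fc.in_features, nb_old_classes, nb_new_classes, old_fc.nb_proxy)
    new_fc.fc1.weight.data = old_fc.weight.data.clone()  # keep old prototypes
    if old_fc.sigma is not None and new_fc.sigma is not None:
        new_fc.sigma.data = old_fc.sigma.data.clone()
    return new_fc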
|
|
|
|
|
|
|
|
class AnalyticLinear(torch.nn.Linear, metaclass=ABCMeta):
    """
    Abstract linear module for the analytic continual learning [1-3].

    This implementation refers to the official implementation https://github.com/ZHUANGHP/Analytic-continual-learning.

    References:
    [1] Zhuang, Huiping, et al.
        "ACIL: Analytic class-incremental learning with absolute memorization and privacy protection."
        Advances in Neural Information Processing Systems 35 (2022): 11602-11614.
    [2] Zhuang, Huiping, et al.
        "GKEAL: Gaussian Kernel Embedded Analytic Learning for Few-Shot Class Incremental Task."
        Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2023.
    [3] Zhuang, Huiping, et al.
        "DS-AL: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning."
        Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 38. No. 15. 2024.
    """

    def __init__(
        self,
        in_features: int,
        gamma: float = 1e-1,
        bias: bool = False,
        device: Optional[Union[torch.device, str, int]] = None,
        dtype=torch.double,
    ) -> None:
        # Skip nn.Linear.__init__ (no trainable parameters are needed) and
        # call nn.Module.__init__ directly.
        super(torch.nn.Linear, self).__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.gamma: float = gamma
        self.bias: bool = bias
        self.dtype = dtype

        # A bias is emulated by appending a constant-1 feature to the input.
        if bias:
            in_features += 1

        # The weight is a buffer, not a Parameter: it is solved for
        # analytically in fit(), never updated by gradient descent. It starts
        # with zero output columns and grows as classes are added.
        weight = torch.zeros((in_features, 0), **factory_kwargs)
        self.register_buffer("weight", weight)

    @torch.no_grad()
    def forward(self, X: torch.Tensor) -> Dict[str, torch.Tensor]:
        X = X.to(self.weight)
        if self.bias:
            X = torch.cat((X, torch.ones(X.shape[0], 1).to(X)), dim=-1)
        return {"logits": X @ self.weight}

    @property
    def in_features(self) -> int:
        if self.bias:
            return self.weight.shape[0] - 1
        return self.weight.shape[0]

    @property
    def out_features(self) -> int:
        return self.weight.shape[1]

    def reset_parameters(self) -> None:
        # Reset to zero output columns, keeping device and dtype.
        self.weight = torch.zeros((self.weight.shape[0], 0)).to(self.weight)

    @abstractmethod
    def fit(self, X: torch.Tensor, Y: torch.Tensor) -> None:
        raise NotImplementedError()

    def after_task(self) -> None:
        assert torch.isfinite(self.weight).all(), (
            "Pay attention to the numerical stability! "
            "A possible solution is to increase the value of gamma. "
            "Setting self.dtype=torch.double also helps."
        )
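
# For reference (our summary, consistent with the ACIL formulation [1]): the
# weight these modules maintain is the closed-form ridge-regression solution
#     W = (X^T X + gamma * I)^{-1} X^T Y,
# where X stacks all features seen so far and Y their one-hot targets.
# RecursiveLinear below reproduces this W without ever storing X.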
|
|
|
|
|
|
|
|
class RecursiveLinear(AnalyticLinear):
    """
    Recursive analytic linear (ridge regression) modules for the analytic continual learning [1-3].

    This implementation refers to the official implementation https://github.com/ZHUANGHP/Analytic-continual-learning.

    References:
    [1] Zhuang, Huiping, et al.
        "ACIL: Analytic class-incremental learning with absolute memorization and privacy protection."
        Advances in Neural Information Processing Systems 35 (2022): 11602-11614.
    [2] Zhuang, Huiping, et al.
        "GKEAL: Gaussian Kernel Embedded Analytic Learning for Few-Shot Class Incremental Task."
        Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2023.
    [3] Zhuang, Huiping, et al.
        "DS-AL: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning."
        Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 38. No. 15. 2024.
    """

    def __init__(
        self,
        in_features: int,
        gamma: float = 1e-1,
        bias: bool = False,
        device: Optional[Union[torch.device, str, int]] = None,
        dtype=torch.double,
    ) -> None:
        super().__init__(in_features, gamma, bias, device, dtype)
        factory_kwargs = {"device": device, "dtype": dtype}

        # R is the regularized inverse autocorrelation matrix
        # (X^T X + gamma * I)^{-1}, maintained recursively across batches.
        self.R: torch.Tensor
        R = torch.eye(self.weight.shape[0], **factory_kwargs) / self.gamma
        self.register_buffer("R", R)

    def update_fc(self, nb_classes: int) -> None:
        # Grow the weight with zero columns for the newly added classes.
        increment_size = nb_classes - self.out_features
        assert increment_size >= 0, "The number of classes should be increasing."
        tail = torch.zeros((self.weight.shape[0], increment_size)).to(self.weight)
        self.weight = torch.cat((self.weight, tail), dim=1)

    @torch.no_grad()
    def fit(self, X: torch.Tensor, Y: torch.Tensor) -> None:
        """The core code of the ACIL [1].

        This implementation is different from, but equivalent to, the
        equations shown in the paper, and it supports mini-batch learning.
        """
        X, Y = X.to(self.weight), Y.to(self.weight)
        if self.bias:
            X = torch.cat((X, torch.ones(X.shape[0], 1).to(X)), dim=-1)

        # Woodbury matrix identity: update R by inverting a
        # (batch_size x batch_size) matrix instead of a (feature x feature) one.
        K = torch.inverse(torch.eye(X.shape[0]).to(X) + X @ self.R @ X.T)
        self.R -= self.R @ X.T @ K @ X @ self.R
        # Recursive least-squares correction of the analytic weights.
        self.weight += self.R @ X.T @ (Y - X @ self.weight)
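
# A hedged end-to-end sketch (not from the original file): fitting
# RecursiveLinear over two incremental tasks with one-hot targets. The feature
# dimension, class counts, and random data are made-up illustration values.
def _demo_recursive_linear():
    fc = RecursiveLinear(in_features=32, gamma=1e-1)
    for lo, hi in ((0, 5), (5, 10)):               # two incremental tasks
        fc.update_fc(hi)                           # grow the head to `hi` classes
        X = torch.randn(100, 32)                   # features from a frozen backbone
        Y = F.one_hot(torch.randint(lo, hi, (100,)), num_classes=hi)
        fc.fit(X, Y)                               # closed-form recursive update
        fc.after_task()                            # numerical-stability check
    logits = fc(torch.randn(4, 32))['logits']      # shape: (4, 10)
    assert logits.shape == (4, 10)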
|
|
|
|
|
|
|
|
def reduce_proxies(out, nb_proxy):
    if nb_proxy == 1:
        return out
    bs = out.shape[0]
    nb_classes = out.shape[1] / nb_proxy
    assert nb_classes.is_integer(), 'Shape error'
    nb_classes = int(nb_classes)

    # Group the per-proxy similarities by class ...
    simi_per_class = out.view(bs, nb_classes, nb_proxy)
    # ... and combine them with a softmax-weighted average, so the proxies
    # most similar to the input dominate the class score.
    attentions = F.softmax(simi_per_class, dim=-1)

    return (attentions * simi_per_class).sum(-1)
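
# A quick shape check (values made up): with 2 proxies per class, 20 raw
# similarity columns reduce to 10 class scores via softmax-weighted averaging.
def _demo_reduce_proxies():
    reduced = reduce_proxies(torch.randn(4, 20), nb_proxy=2)
    assert reduced.shape == (4, 10)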
|
|
|
|
|
|
|
|