# boringKey's picture
# Upload 236 files
# 5fee096 verified
"""
@inproceedings{10.24963/ijcai.2024/456,
author = {Hong, Chenxing and Jin, Yan and Kang, Zhiqi and Chen, Yizhou and Li, Mengke and Lu, Yang and Wang, Hanzi},
title = {Dynamically anchored prompting for task-imbalanced continual learning},
booktitle = {Proceedings of the Thirty-Third International Joint Conference on Artificial Intelligence},
year = {2025},
}
https://dl.acm.org/doi/10.24963/ijcai.2024/456
Adapted from https://github.com/chenxing6666/dap
"""
import math
import copy
import torch
import torch.nn.functional as F
from .finetune import Finetune
import numpy as np
from torch.utils.data import DataLoader
# Module-level scalars kept alongside the DAP learner.
# NOTE(review): the DAP class in this file never reads or assigns any of these
# three names; they look like leftovers from the reference implementation --
# confirm no other module imports them before removing.
global_max_dist = torch.zeros((), dtype=torch.int64)   # 0-dim int64 zero
global_max_dist2 = torch.zeros((), dtype=torch.int64)  # separate 0-dim int64 zero
global_lam = 0.25
class DAP(Finetune):
    """Dynamically Anchored Prompting learner for task-imbalanced continual learning.

    A frozen deep copy of the backbone (``original_model``) produces the
    ``pre_logits`` features that are handed to the prompted network as
    ``cls_features``.  When ``train_gprompt`` is set, the general prompt is
    additionally pulled towards two anchors: the current task prompt and the
    running centre of the earlier task prompts (``prompt_center``).

    Required ``kwargs`` keys read here: ``train_mask``, ``task_inc``,
    ``pull_constraint``, ``pull_constraint_coeff``, ``task_num`` and
    ``freeze``.

    NOTE(review): ``self.backbone``, ``self.device``, ``self.num_class`` and
    ``self.loss_fn`` are assumed to be set by ``Finetune.__init__`` -- confirm.
    NOTE(review): ``self.prompt_center`` starts as ``None`` and is never
    updated inside this class; presumably the trainer assigns the result of
    ``cal_center`` to it between tasks -- verify against the caller.
    """

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        """Build the learner, the per-task class masks and the frozen reference model.

        Raises:
            ValueError: if ``num_class`` is not divisible by ``kwargs['task_num']``.
        """
        super().__init__(backbone, feat_dim, num_class, **kwargs)
        self.kwargs = kwargs
        self.network = backbone

        self.train_mask = kwargs['train_mask']
        self.task_inc = kwargs['task_inc']
        self.pull_constraint = kwargs['pull_constraint']
        self.pull_constraint_coeff = kwargs['pull_constraint_coeff']

        self.task_idx = 0
        # Training-set size of every task seen so far (appended in before_task).
        self.task_data_count = []
        self.prompt_center = None

        # Initialize class_mask: task i owns a contiguous, equally sized range
        # of class indices.
        task_num = kwargs['task_num']
        if self.num_class % task_num != 0:
            raise ValueError('Number of classes must be divisible by number of tasks')
        classes_per_task = self.num_class // task_num
        self.class_mask = [
            list(range(i * classes_per_task, (i + 1) * classes_per_task))
            for i in range(task_num)
        ]

        # Reference copy of the backbone used only for feature extraction.
        self.original_model = copy.deepcopy(self.backbone)
        self.original_model.to(self.device)
        self.original_model.eval()

        if kwargs['freeze']:
            # All parameters are frozen for the original (reference) model.
            for p in self.original_model.parameters():
                p.requires_grad = False
            # Freeze the trainable network's parameters whose names start with
            # any of the configured prefixes.
            frozen_prefixes = tuple(kwargs['freeze'])
            for n, p in self.network.named_parameters():
                if n.startswith(frozen_prefixes):
                    p.requires_grad = False

        self.loss_fn.to(self.device)

    def _extract_cls_features(self, x):
        """Return ``pre_logits`` of the reference model for ``x`` (no gradients), or ``None``."""
        with torch.no_grad():
            if self.original_model is None:
                return None
            return self.original_model(x)['pre_logits']

    def observe(self, data, train_gprompt=False, gen=False):
        """Run one training forward pass and compute the loss.

        Args:
            data: mapping with ``'image'`` and ``'label'`` entries.
            train_gprompt: when true, add the two anchor losses for the
                general prompt to the base criterion.
            gen: forwarded to the network as ``gen=gen`` only when truthy, so
                the network's own default applies otherwise.

        Returns:
            ``(pred, accuracy, loss)`` where ``accuracy`` is the fraction of
            correct predictions in the batch and ``loss`` is the loss tensor.

        Raises:
            RuntimeError: if the loss is not finite.
        """
        x = data['image'].to(self.device)
        y = data['label'].to(self.device)

        cls_features = self._extract_cls_features(x)

        network_kwargs = {
            'task_id': self.task_idx,
            'cls_features': cls_features,
            'train': True,
        }
        if gen:
            network_kwargs['gen'] = gen
        output = self.network(x, **network_kwargs)
        logits = output['logits']

        # Here is the trick to mask out classes of non-current tasks.
        if self.train_mask and self.class_mask is not None:
            mask = self.class_mask[self.task_idx]
            not_mask = np.setdiff1d(np.arange(self.num_class), mask)
            not_mask = torch.as_tensor(not_mask, dtype=torch.int64, device=self.device)
            logits = logits.index_fill(dim=1, index=not_mask, value=float('-inf'))

        # Base criterion (self.loss_fn).
        loss = self.loss_fn(logits, y)

        if train_gprompt:
            pla_similarity_loss = self.cal_latestsimilarity_loss(
                model=self.network, task_id=self.task_idx)['similarity']
            sta_similarity_loss = self.cal_similarity_loss(
                model=self.network, task_id=self.task_idx,
                prompt_center=self.prompt_center)['avg_similarity']

            # alpha is 0 when the latest task is the smallest seen so far and
            # (up to epsilon) 1 when it is the largest; it weights the
            # centre anchor, and (1 - alpha) weights the current-task anchor.
            min_data_count = min(self.task_data_count)
            max_data_count = max(self.task_data_count)
            last_data_count = self.task_data_count[-1]
            epsilon = 1e-10  # avoids 0/0 when all task sizes are equal
            alpha = (last_data_count - min_data_count) / (
                max_data_count - min_data_count + epsilon)

            loss = loss + alpha * sta_similarity_loss + (1 - alpha) * pla_similarity_loss

        if self.pull_constraint and 'reduce_sim' in output:
            loss = loss - self.pull_constraint_coeff * output['reduce_sim']

        if not math.isfinite(loss.item()):
            raise RuntimeError(f'Loss is {loss.item()}, stopping training')

        pred = torch.argmax(logits, dim=1)
        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0), loss

    def inference(self, data):
        """Predict labels for a batch without tracking gradients.

        Args:
            data: mapping with ``'image'`` and ``'label'`` entries.

        Returns:
            ``(pred, accuracy)`` where ``accuracy`` is the fraction of correct
            predictions in the batch.
        """
        x = data['image'].to(self.device)
        y = data['label'].to(self.device)

        with torch.no_grad():
            cls_features = self._extract_cls_features(x)
            output = self.network(
                x, task_id=self.task_idx, cls_features=cls_features, gen=True)
            logits = output['logits']

            # Adding mask to output logits: in the task-incremental setting
            # only the current task's classes may be predicted.
            if self.task_inc and self.class_mask is not None:
                mask = torch.as_tensor(
                    self.class_mask[self.task_idx], dtype=torch.int64, device=self.device)
                logits_mask = torch.full_like(logits, float('-inf'), device=self.device)
                logits_mask = logits_mask.index_fill(1, mask, 0.0)
                logits = logits + logits_mask

            pred = torch.argmax(logits, dim=1)

        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0)

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Record the new task index and the size of its training set.

        ``buffer`` and ``test_loaders`` are accepted for interface
        compatibility and are not used here.
        """
        self.task_idx = task_idx
        self.network.task_id = task_idx
        self.task_data_count.append(len(train_loader.dataset))

    @staticmethod
    def cal_latestsimilarity_loss(model: torch.nn.Module, task_id=-1):
        """Cosine distance between the general prompt and task ``task_id``'s prompt.

        The task prompt is detached, so gradients reach only the general
        prompt.

        Returns:
            dict with key ``'similarity'`` holding ``1 - cosine_similarity``
            of the two flattened prompts.
        """
        gprompt_flat = model.prompt.generalprompt.view(-1)
        tprompt_flat = model.prompt.taskprompt[task_id].detach().view(-1)
        similarity = 1 - F.cosine_similarity(gprompt_flat, tprompt_flat, dim=0)
        return {'similarity': similarity}

    @staticmethod
    def cal_center(model: torch.nn.Module, task_id=-1, task_data_count=None, prompt_center=None):
        """Update the running centre of the task prompts learned before ``task_id``.

        For ``task_id > 0`` the prompt of task ``task_id - 1`` is blended into
        ``prompt_center``.  With ``task_data_count`` the blend weights are the
        normalised inverse training-set sizes of the first ``task_id`` tasks;
        without it every task counts equally (uniform running mean).

        Args:
            model: network exposing ``model.prompt.taskprompt``.
            task_id: index of the task about to be trained.
            task_data_count: optional per-task training-set sizes.
            prompt_center: previous centre (flattened), or ``None`` to start
                from the flattened prompt of task 0.

        Returns:
            The new flattened centre; a zero vector shaped like the flattened
            first task prompt when ``task_id <= 0``.
        """
        tprompt = model.prompt.taskprompt

        if task_id <= 0:
            return torch.zeros_like(tprompt[0].detach().view(-1))

        if prompt_center is None:
            prompt_center = tprompt[0].detach().view(-1)
        current_tprompt = tprompt[task_id - 1].detach().view(-1)

        if task_data_count:
            weights = [1 / count for count in task_data_count[:task_id]]
            total_weight = sum(weights)
            normalized_weight = weights[-1] / total_weight
            weights2 = sum(weights[:-1]) / total_weight
        else:
            # Uniform fallback.  The previous-centre weight used to be left
            # undefined on this path, which raised UnboundLocalError below.
            normalized_weight = 1.0 / task_id
            weights2 = (task_id - 1) / task_id

        return (prompt_center * weights2) + (current_tprompt * normalized_weight)

    @staticmethod
    def cal_similarity_loss(model: torch.nn.Module, task_id=-1, prompt_center=None):
        """Cosine distance between the general prompt and ``prompt_center``.

        Args:
            model: network exposing ``model.prompt.generalprompt``.
            task_id: current task index; no centre exists for ``task_id <= 0``.
            prompt_center: flattened centre tensor; must not be ``None`` when
                ``task_id > 0``.

        Returns:
            dict with keys ``'similarity'`` and ``'avg_similarity'`` (the same
            value).  For ``task_id <= 0`` both are a zero scalar tensor with
            the general prompt's dtype and device, so callers always receive
            tensors.
        """
        gprompt = model.prompt.generalprompt
        if task_id > 0:
            similarity = 1 - F.cosine_similarity(gprompt.view(-1), prompt_center, dim=0)
        else:
            similarity = gprompt.new_zeros(())
        return {'similarity': similarity, 'avg_similarity': similarity}