# -*- coding: utf-8 -*-
"""
@inproceedings{DBLP:conf/cvpr/0002ZL0SRSPDP22,
author = {Zifeng Wang and
Zizhao Zhang and
Chen{-}Yu Lee and
Han Zhang and
Ruoxi Sun and
Xiaoqi Ren and
Guolong Su and
Vincent Perot and
Jennifer G. Dy and
Tomas Pfister},
title = {Learning to Prompt for Continual Learning},
booktitle = {{IEEE/CVF} Conference on Computer Vision and Pattern Recognition,
{CVPR} 2022, New Orleans, LA, USA, June 18-24, 2022},
pages = {139--149},
publisher = {{IEEE}},
year = {2022}
}
https://arxiv.org/abs/2112.08654
Adapted from https://github.com/GT-RIPL/CODA-Prompt
"""
import math
import copy
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from core.model.backbone.resnet import *
class Model(nn.Module):
    """Backbone + linear head wrapper used by the L2P learner.

    The backbone's forward is expected to return a feature tensor together
    with a prompt-similarity scalar (``reduce_sim``); the linear head maps
    features to logits over every class seen across all tasks.
    """

    def __init__(self, backbone, embed_dim, total_cls_num):
        super().__init__()
        self.backbone = backbone
        # One shared classifier head covering all classes of all tasks.
        self.classifier = nn.Linear(embed_dim, total_cls_num, bias=True)

    def forward(self, x, train=True):
        """Return ``(logits, reduce_sim)`` for a batch ``x``."""
        features, similarity = self.backbone(x, train=train)
        logits = self.classifier(features)
        return logits, similarity
class L2P(nn.Module):
    """Learning to Prompt (L2P, CVPR 2022) continual learner.

    Keeps a frozen backbone and trains only the prompt pool and the
    classifier head. During ``observe`` the logits of classes outside the
    current task are masked to ``-inf``, and the prompt-key pull term
    (``reduce_sim``) is subtracted from the cross-entropy loss.
    """

    def __init__(self, backbone, device, **kwargs):
        super().__init__()
        self.device = device
        self.init_cls_num = kwargs['init_cls_num']
        self.inc_cls_num = kwargs['inc_cls_num']
        self.total_cls_num = kwargs['num_class']
        self.task_num = kwargs['task_num']
        self.embed_dim = kwargs['feat_dim']
        self.pull_constraint_coeff = kwargs['pull_constraint_coeff']
        # Task bookkeeping: index of the task being learned and how many
        # classes were already seen before it.
        self.cur_task_id = 0
        self._known_classes = 0

        self.network = Model(backbone, self.embed_dim, self.total_cls_num)
        # Attach the L2P prompt pool to the backbone
        # (length = L_p, pool_size = M, top_k = N in the paper's notation).
        self.network.backbone.create_prompt(
            prompt_flag='l2p',
            length=kwargs['prompt_length'],     # L_p
            prompt_init=nn.init.uniform_,
            pool_size=kwargs['pool_size'],      # M
            top_k=kwargs['top_k'],              # N
            num_layers=1,
            embed_dim=self.embed_dim,
        )
        self.network.to(self.device)

        # Freeze everything except the prompt pool and the classifier head.
        self.unfrezeed_params = []
        for name, param in self.network.named_parameters():
            trainable = 'prompt' in name or 'classifier' in name
            param.requires_grad_(trainable)
            if trainable:
                self.unfrezeed_params.append(param)

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Record which task is about to be trained."""
        self.cur_task_id = task_idx

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Advance the count of classes seen so far."""
        grown = self.init_cls_num if task_idx == 0 else self.inc_cls_num
        self._known_classes += grown

    def observe(self, data):
        """One training step: returns ``(pred, acc, loss)``.

        Note: only computes gradients (``loss.backward`` + clipping);
        the optimizer step is expected to happen in the caller.
        """
        images = data['image'].to(self.device)
        labels = data['label'].to(self.device)
        logits, reduce_sim = self.network(images, train=True)

        # Restrict the softmax to the classes of the current task.
        if self.cur_task_id == 0:
            valid = np.arange(self.init_cls_num)
        else:
            valid = self._known_classes + np.arange(self.inc_cls_num)
        blocked = np.setdiff1d(np.arange(self.total_cls_num), valid)
        blocked = torch.tensor(blocked, dtype=torch.int64).to(self.device)
        logits = logits.index_fill(dim=1, index=blocked, value=float('-inf'))

        # Cross-entropy minus the prompt-key pull constraint.
        loss = F.cross_entropy(logits, labels) - self.pull_constraint_coeff * reduce_sim
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.unfrezeed_params, 1.0)

        predictions = logits.argmax(dim=1)
        accuracy = (predictions == labels).sum().item() / images.size(0)
        return predictions, accuracy, loss

    def inference(self, data):
        """Evaluation step over all classes: returns ``(pred, acc)``."""
        images = data['image'].to(self.device)
        labels = data['label'].to(self.device)
        logits, _ = self.network(images, train=False)
        predictions = logits.argmax(dim=1)
        accuracy = (predictions == labels).sum().item() / images.size(0)
        return predictions, accuracy

    def get_parameters(self, config):
        """Return the trainable (prompt + classifier) parameters."""
        return self.unfrezeed_params