import copy

import torch
import torch.nn as nn

from .vit_inflora import (VisionTransformer, PatchEmbed, Block,
                          resolve_pretrained_cfg, build_model_with_cfg,
                          checkpoint_filter_fn)

class ViT_lora_co(VisionTransformer):
    def __init__(
            self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
            embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None,
            embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block, n_tasks=10, rank=64):
        super().__init__(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, num_classes=num_classes,
            global_pool=global_pool, embed_dim=embed_dim, depth=depth, num_heads=num_heads,
            mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, representation_size=representation_size,
            drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,
            weight_init=weight_init, init_values=init_values, embed_layer=embed_layer,
            norm_layer=norm_layer, act_layer=act_layer, block_fn=block_fn, n_tasks=n_tasks, rank=rank)

    def forward(self, x, task_id, register_blk=-1, get_feat=False, get_cur_feat=False):
        x = self.patch_embed(x)
        # Prepend the [CLS] token, then add positional embeddings.
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.pos_embed[:, :x.size(1), :]
        x = self.pos_drop(x)
        # Placeholder loss kept for interface compatibility; create it directly
        # on the input's device instead of allocating on CPU and moving it.
        prompt_loss = torch.zeros((1,), requires_grad=True, device=x.device)
        for i, blk in enumerate(self.blocks):
            x = blk(x, task_id, register_blk == i,
                    get_feat=get_feat, get_cur_feat=get_cur_feat)
        x = self.norm(x)
        return x, prompt_loss
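
# Usage sketch (illustrative only, not part of the module): the backbone returns
# the full token sequence together with the placeholder prompt loss. Shapes
# assume the default ViT-B/16 configuration (embed_dim=768, 14x14 patches):
#
#   vit = ViT_lora_co(n_tasks=10, rank=64)
#   tokens, prompt_loss = vit(torch.randn(2, 3, 224, 224), task_id=0)
#   # tokens: (2, 197, 768); tokens[:, 0, :] is the [CLS] feature
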
def _create_vision_transformer(variant, pretrained=False, **kwargs):
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    # NOTE this extra code supports handling of the representation size for
    # in21k pretrained models.
    # pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
    pretrained_cfg = resolve_pretrained_cfg(variant)
    default_num_classes = pretrained_cfg['num_classes']
    num_classes = kwargs.get('num_classes', default_num_classes)
    repr_size = kwargs.pop('representation_size', None)
    if repr_size is not None and num_classes != default_num_classes:
        # Drop the pre-logits layer when fine-tuning with a different head size.
        repr_size = None

    model = build_model_with_cfg(
        ViT_lora_co, variant, pretrained,
        pretrained_cfg=pretrained_cfg,
        representation_size=repr_size,
        pretrained_filter_fn=checkpoint_filter_fn,
        pretrained_custom_load='npz' in pretrained_cfg['url'],
        **kwargs)
    return model
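
# Example (hedged): building the backbone directly with in21k pretrained
# weights, mirroring how SiNet_vit constructs it below. The first call needs
# network access to download the timm checkpoint.
#
#   model = _create_vision_transformer(
#       'vit_base_patch16_224_in21k', pretrained=True,
#       patch_size=16, embed_dim=768, depth=12, num_heads=12,
#       n_tasks=10, rank=64)
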
class SiNet_vit(nn.Module):
    def __init__(self, **args):
        '''
        args is a dictionary with the required arguments:
        - "total_sessions": number of incremental sessions (tasks),
        - "rank": LoRA rank used inside the transformer blocks,
        - "init_cls": number of classes handled by each session head,
        - "embd_dim": embedding dimension of the ViT backbone (768 for ViT-B/16).
        The image encoder is defined in vit_inflora.
        '''
        super(SiNet_vit, self).__init__()

        model_kwargs = dict(patch_size=16, embed_dim=768, depth=12,
                            num_heads=12, n_tasks=args["total_sessions"], rank=args["rank"])
        self.image_encoder = _create_vision_transformer(
            'vit_base_patch16_224_in21k', pretrained=True, **model_kwargs)

        self.class_num = args["init_cls"]
        # One linear head per session, plus a parallel pool used to snapshot
        # trained heads (see classifier_backup below).
        self.classifier_pool = nn.ModuleList([
            nn.Linear(args["embd_dim"], self.class_num, bias=True)
            for _ in range(args["total_sessions"])
        ])
        self.classifier_pool_backup = nn.ModuleList([
            nn.Linear(args["embd_dim"], self.class_num, bias=True)
            for _ in range(args["total_sessions"])
        ])

        self.numtask = 0

    @property
    def feature_dim(self):
        return self.image_encoder.out_dim

    def extract_vector(self, image, task=None):
        # Defaults to the current task's branch when no task id is given.
        if task is None:
            image_features, _ = self.image_encoder(image, self.numtask - 1)
        else:
            image_features, _ = self.image_encoder(image, task)
        # Keep only the [CLS] token as the image representation.
        image_features = image_features[:, 0, :]
        return image_features
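
    # Example (hedged; `net` stands for a SiNet_vit instance, `images` for an
    # input batch): extracting [CLS] features through a specific task's branch,
    # e.g. to build class prototypes:
    #
    #   feats = net.extract_vector(images, task=0)   # (B, embd_dim)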

    def forward(self, image, get_feat=False, get_cur_feat=False, fc_only=False):
        """
        Return the output of the fully connected layer(s). With fc_only=True,
        `image` is treated as precomputed features and is passed through every
        head seen so far; otherwise only the current task's head is used.
        """
        if fc_only:
            fc_outs = []
            for ti in range(self.numtask):
                fc_out = self.classifier_pool[ti](image)
                fc_outs.append(fc_out)
            return torch.cat(fc_outs, dim=1)

        logits = []
        image_features, prompt_loss = self.image_encoder(
            image, task_id=self.numtask - 1, get_feat=get_feat, get_cur_feat=get_cur_feat)
        image_features = image_features[:, 0, :]
        image_features = image_features.view(image_features.size(0), -1)
        # During training, only the current task's head produces logits.
        for prompts in [self.classifier_pool[self.numtask - 1]]:
            logits.append(prompts(image_features))

        return {
            'logits': torch.cat(logits, dim=1),
            'features': image_features,
            'prompt_loss': prompt_loss,
        }
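
    # Sketch of a training step consuming this output (`net`, `criterion`,
    # `images`, and `labels` are placeholders; labels must be indexed within the
    # current session's classes, since only that head produces logits here):
    #
    #   out = net(images, get_cur_feat=True)
    #   loss = criterion(out['logits'], labels) + out['prompt_loss'].sum()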

    def interface(self, image):
        # Inference path: run the encoder once, then concatenate the logits of
        # every head seen so far into a single global prediction.
        image_features, _ = self.image_encoder(image, task_id=self.numtask - 1)
        image_features = image_features[:, 0, :]
        image_features = image_features.view(image_features.size(0), -1)
        logits = []
        for prompt in self.classifier_pool[:self.numtask]:
            logits.append(prompt(image_features))
        logits = torch.cat(logits, 1)
        return logits
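
    # At inference time, argmax over the concatenated logits gives a global
    # class index across all sessions seen so far (`net`/`images` placeholders):
    #
    #   preds = net.interface(images).argmax(dim=1)   # in [0, numtask * init_cls)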

    def update_fc(self, nb_classes):
        """
        Advance to the next task. `nb_classes` is unused here because every
        session head is sized at construction time from args["init_cls"].
        """
        self.numtask += 1

    def classifier_backup(self, task_id):
        self.classifier_pool_backup[task_id].load_state_dict(
            self.classifier_pool[task_id].state_dict())

    def classifier_recall(self):
        # NOTE: self.old_state_dict is not defined in this class; the trainer
        # is expected to stash a snapshot of classifier_pool.state_dict() there
        # before calling this method.
        self.classifier_pool.load_state_dict(self.old_state_dict)

    def copy(self):
        return copy.deepcopy(self)

    def freeze(self):
        for param in self.parameters():
            param.requires_grad = False
        self.eval()
        return self
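
# A minimal end-to-end sketch of how SiNet_vit is typically driven across
# incremental sessions (assumed workflow; the actual trainer lives elsewhere in
# this repo, and the hyperparameter values below are illustrative only).
if __name__ == "__main__":
    args = {
        "total_sessions": 10,   # number of incremental tasks
        "rank": 64,             # LoRA rank
        "init_cls": 10,         # classes per session head
        "embd_dim": 768,        # ViT-B/16 embedding dimension
    }
    net = SiNet_vit(**args)     # downloads pretrained in21k weights on first use

    for session in range(args["total_sessions"]):
        net.update_fc(args["init_cls"])   # advance the task counter
        # ... train on the current session's data here ...
        net.classifier_backup(session)    # snapshot the trained head

    # Frozen deep copy, e.g. as a distillation target or for evaluation.
    old_net = net.copy().freeze()
    with torch.no_grad():
        logits = old_net.interface(torch.randn(2, 3, 224, 224))
    print(logits.shape)  # torch.Size([2, 100]) = total_sessions * init_cls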