'''
Code Reference:
Adapted from https://github.com/GT-RIPL/CODA-Prompt
'''
import os
import timm
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from timm.models.vision_transformer import _cfg, PatchEmbed
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_, DropPath
from timm.models.helpers import named_apply, adapt_input_conv
from .prompt import L2P, CodaPrompt, DualPrompt
from .transformer import MultiHeadAttention_LoRA, VisionTransformer, VisionTransformer_CL_LoRA
def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
# interpolate position embedding
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = visual_encoder.patch_embed.num_patches
num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
    if orig_size != new_size:
# class_token and dist_token are kept unchanged
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        print('reshape position embedding from %d to %d' % (orig_size ** 2, new_size ** 2))
return new_pos_embed
else:
return pos_embed_checkpoint
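
# Illustrative use (hedged sketch): resizing a checkpoint's position embedding
# before loading it into a model built for a different input resolution.
# 'checkpoint.pth' and `model` are placeholders, not names from this repo.
#
#   state_dict = torch.load('checkpoint.pth', map_location='cpu')
#   state_dict['pos_embed'] = interpolate_pos_embed(state_dict['pos_embed'], model)
#   model.load_state_dict(state_dict, strict=False)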
class ViTZoo(nn.Module):
    def __init__(self, pretrained=False, model_name='vit_base_patch16_224', attn_layer='MultiHeadAttention', **kwargs):
super(ViTZoo, self).__init__()
self.task_id = None
self.feat_dim = 768
self.feat = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12,
num_heads=12, ckpt_layer=0,
drop_path_rate=0, attn_layer=attn_layer,
**kwargs
)
        if pretrained:
            print(f'Using pretrained model : {model_name}')
            ckpt_path = '/home/lvqiexuan/.cache/torch/hub/checkpoints/vit_base_patch16_224.augreg2_in21k_ft_in1k.pt'
            if model_name == 'vit_base_patch16_224.augreg2_in21k_ft_in1k' and os.path.exists(ckpt_path):
                # manually load the locally cached checkpoint
                load_dict = torch.load(ckpt_path, map_location='cpu')
            else:
                load_dict = timm.create_model(model_name, pretrained=pretrained).state_dict()
            # remap timm parameter names to this repo's VisionTransformer naming
            key_mapping = {
                ".norm1.": ".ln_1.",
                ".norm2.": ".ln_2.",
                "blocks.": "transformer.blocks."
            }
            modified_load_dict = {}
            for key in load_dict.keys():
                new_key = key
                for old_key, mapped_key in key_mapping.items():
                    if old_key in new_key:
                        new_key = new_key.replace(old_key, mapped_key)
                modified_load_dict[new_key] = load_dict[key]
            # strict=False tolerates parameters (e.g. prompt/LoRA modules) absent from the checkpoint
            self.feat.load_state_dict(modified_load_dict, strict=False)
self.prompt = None
self.prompt_flag = ''
def create_prompt(self, prompt_flag, **kwargs):
self.prompt_flag = prompt_flag
if self.prompt_flag == 'l2p':
self.prompt = L2P(**kwargs)
elif self.prompt_flag == 'dual':
self.prompt = DualPrompt(768, **kwargs)
elif self.prompt_flag == 'coda':
self.prompt = CodaPrompt(768, **kwargs)
# pen: get penultimate features
def forward(self, image, text=None, pen=False, train=False, **kwargs):
        if self.prompt_flag == 'l2p':
            with torch.no_grad():
                self.eval()
                cls_features = self.feat(image, prompt_flag=self.prompt_flag)
            if train:
                self.train()
            out, reduce_sim = self.feat(
                x=image,
                prompt=self.prompt,
                cls_features=cls_features,
                prompt_flag=self.prompt_flag
            )
            return out, reduce_sim
elif self.prompt is not None:
            with torch.no_grad():
                q, _ = self.feat(image)
                q = q[:, 0, :]  # [CLS] feature used as the prompt query
            # the frozen query drives prompt selection; task_id gates task-specific prompts
            out, prompt_loss = self.feat(image, prompt=self.prompt, q=q, train=train, task_id=self.task_id)
            out = out[:, 0, :]
else:
out, _ = self.feat(image, **kwargs)
            if len(out.shape) == 3:
                out = out[:, 0, :]
out = out.view(out.size(0), -1)
if self.prompt is not None and train:
return out, prompt_loss
else:
return out
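
# Minimal usage sketch for ViTZoo (illustrative; the create_prompt kwargs depend
# on the prompt module signatures in .prompt, which are not shown here):
#
#   model = ViTZoo(pretrained=True, model_name='vit_base_patch16_224')
#   model.create_prompt('coda', ...)      # CodaPrompt hyperparameters omitted
#   model.task_id = 0
#   feats, prompt_loss = model(images, train=True)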
class ViT_in21k_adapter(nn.Module):
def __init__(self, pretrained=False, **kwargs):
super(ViT_in21k_adapter, self).__init__()
self.task_id = None
self.feat_dim = 768
# get feature encoder
        if pretrained:
            print("Using pretrained model")
            from core.model.backbone.petl import vision_transformer_adapter
            from easydict import EasyDict
            tuning_config = EasyDict(
                # AdaptFormer
                ffn_adapt=True,
                ffn_option="parallel",
                ffn_adapter_layernorm_option="none",
                ffn_adapter_init_option="lora",
                ffn_adapter_scalar="0.1",
                ffn_num=64,
                d_model=768,
                # VPT related
                vpt_on=False,
                vpt_num=0,
            )
            zoo_model = vision_transformer_adapter.vit_base_patch16_224_in21k_adapter(
                num_classes=0, global_pool=False, drop_path_rate=0.0, tuning_config=tuning_config)
            zoo_model.out_dim = 768
            zoo_model.eval()
        else:
            # zoo_model would otherwise be undefined below; this backbone only makes sense with pretrained weights
            raise NotImplementedError('ViT_in21k_adapter requires pretrained=True')
        self.prompt = None
        # feature encoder changes if transformer vs resnet
        self.feat = zoo_model
def create_prompt(self, prompt_flag, **kwargs):
        self.prompt_flag = prompt_flag
        # create prompting module
if self.prompt_flag == 'l2p':
self.prompt = L2P(768, **kwargs)
elif self.prompt_flag == 'dual':
self.prompt = DualPrompt(768, **kwargs)
elif self.prompt_flag == 'coda':
self.prompt = CodaPrompt(768, **kwargs)
# pen: get penultimate features
def forward(self, x, pen=False, train=False):
if self.prompt is not None:
            with torch.no_grad():
                q, _ = self.feat(x)
                q = q[:, 0, :]  # [CLS] feature used as the prompt query
            out, prompt_loss = self.feat(x, prompt=self.prompt, q=q, train=train, task_id=self.task_id)
            out = out[:, 0, :]
else:
out = self.feat(x) # This implementation of adapter vit doesn't return prompt loss
out = out.view(out.size(0), -1)
# if not pen:
# out = self.last(out)
if self.prompt is not None and train:
return out, prompt_loss
else:
return out
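
# Sketch (assumes core.model.backbone.petl provides the adapter entry point used above):
#
#   model = ViT_in21k_adapter(pretrained=True)
#   feats = model(images)                 # expected shape [B, 768] penultimate features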
class ViT_CL_LoRA(nn.Module):
    def __init__(self, pretrained=False, model_name='vit_base_patch16_224', attn_layer='MultiHeadAttention', **kwargs):
super().__init__()
self.task_id = None
self.feat_dim = 768
self.feat = VisionTransformer_CL_LoRA(img_size=224, patch_size=16, embed_dim=768, depth=12,
num_heads=12, ckpt_layer=0,
drop_path_rate=0, attn_layer=attn_layer,
**kwargs
)
        if pretrained:
            print(f'Using pretrained model : {model_name}')
            ckpt_path = '/home/lvqiexuan/.cache/torch/hub/checkpoints/vit_base_patch16_224.augreg2_in21k_ft_in1k.pt'
            if model_name == 'vit_base_patch16_224.augreg2_in21k_ft_in1k' and os.path.exists(ckpt_path):
                # manually load the locally cached checkpoint
                load_dict = torch.load(ckpt_path, map_location='cpu')
            else:
                load_dict = timm.create_model(model_name, pretrained=pretrained).state_dict()
            # remap timm parameter names to this repo's VisionTransformer_CL_LoRA naming
            key_mapping = {
                ".norm1.": ".ln_1.",
                ".norm2.": ".ln_2.",
                "blocks.": "transformer.blocks."
            }
            modified_load_dict = {}
            for key in load_dict.keys():
                new_key = key
                for old_key, mapped_key in key_mapping.items():
                    if old_key in new_key:
                        new_key = new_key.replace(old_key, mapped_key)
                modified_load_dict[new_key] = load_dict[key]
            # strict=False tolerates parameters (e.g. LoRA modules) absent from the checkpoint
            self.feat.load_state_dict(modified_load_dict, strict=False)
self.prompt = None
self.prompt_flag = ''
# pen: get penultimate features
    def forward(self, image, test, text=None, pen=False, train=False, **kwargs):
        # `test` is only consumed by the plain CL-LoRA path below; the prompt branches ignore it
        if self.prompt_flag == 'l2p':
            with torch.no_grad():
                self.eval()
                cls_features = self.feat(image, prompt_flag=self.prompt_flag)
            if train:
                self.train()
            out, reduce_sim = self.feat(
                x=image,
                prompt=self.prompt,
                cls_features=cls_features,
                prompt_flag=self.prompt_flag
            )
            return out, reduce_sim
elif self.prompt is not None:
            with torch.no_grad():
                q, _ = self.feat(image)
                q = q[:, 0, :]  # [CLS] feature used as the prompt query
            # the frozen query drives prompt selection; task_id gates task-specific prompts
            out, prompt_loss = self.feat(image, prompt=self.prompt, q=q, train=train, task_id=self.task_id)
            out = out[:, 0, :]
else:
out, _ = self.feat(image, test, **kwargs)
            if len(out.shape) == 3:
                out = out[:, 0, :]
out = out.view(out.size(0), -1)
if self.prompt is not None and train:
return out, prompt_loss
else:
return out
def forward_proto(self, x, adapt_index):
return self.feat.forward_proto(x, adapt_index)
def forward_general_cls(self, x, t_idx):
return self.feat.forward_general_cls(x, t_idx)
def add_adapter_to_list(self):
self.feat.add_adapter_to_list()
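
# Sketch of one task in a CL-LoRA loop (hedged; method semantics follow the
# delegated VisionTransformer_CL_LoRA implementation, which is not shown here):
#
#   model = vit_cl_lora(pretrained=True)
#   feats = model(images, test=False, train=True)
#   model.add_adapter_to_list()           # presumably stores the current task's adapters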
def vit_pt_imnet(pretrained=False, **kwargs):
return ViTZoo(pretrained, **kwargs)
def vit_pt_imnet_in21k_adapter(pretrained=False, **kwargs):
return ViT_in21k_adapter(pretrained, **kwargs)
def vit_cl_lora(pretrained=False, **kwargs):
return ViT_CL_LoRA(pretrained, **kwargs)