import os
import torch
import torch.nn as nn
from typing import Dict
from .layers import LoRALayer, PlainMultiheadAttentionLoRA
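# Transformer block indices that receive LoRA modules, keyed by the `position`
# argument (and, for the vision encoder, additionally by the CLIP backbone name).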
INDEX_POSITIONS_TEXT = {
'top1': [11],
'top2': [10, 11],
'top3': [9, 10, 11],
'bottom': [0, 1, 2, 3],
'mid': [4, 5, 6, 7],
'up': [8, 9, 10, 11],
'half-up': [6, 7, 8, 9, 10, 11],
'half-bottom': [0, 1, 2, 3, 4, 5],
'all': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]}
INDEX_POSITIONS_VISION = {
'ViT-B/16': {
'top': [11],
'top3': [9, 10, 11],
'bottom': [0, 1, 2, 3],
'mid': [4, 5, 6, 7],
'up': [8, 9, 10, 11],
'half-up': [6, 7, 8, 9, 10, 11],
'half-bottom': [0, 1, 2, 3, 4, 5],
'all': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]},
'ViT-B/32': {
'bottom': [0, 1, 2, 3],
'mid': [4, 5, 6, 7],
'up': [8, 9, 10, 11],
'half-up': [6, 7, 8, 9, 10, 11],
'half-bottom': [0, 1, 2, 3, 4, 5],
'all': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]},
'ViT-L/14': {
'half-up': [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
'half-bottom': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
'all': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]}
}
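# Example lookups (illustrative):
#   INDEX_POSITIONS_TEXT['top3']              -> [9, 10, 11]
#   INDEX_POSITIONS_VISION['ViT-B/16']['top'] -> [11]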
def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
    """Freeze every parameter that is not a LoRA weight; `bias` selects whether
    no biases ('none'), all biases ('all'), or only the biases of LoRA layers
    ('lora_only') remain trainable."""
for n, p in model.named_parameters():
if 'lora_' not in n:
p.requires_grad = False
if bias == 'none':
return
elif bias == 'all':
for n, p in model.named_parameters():
if 'bias' in n:
p.requires_grad = True
elif bias == 'lora_only':
for m in model.modules():
if isinstance(m, LoRALayer) and \
hasattr(m, 'bias') and \
m.bias is not None:
m.bias.requires_grad = True
    else:
        raise NotImplementedError(f"Unknown bias mode: {bias}")
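# Example (illustrative; the learning rate is a placeholder): freeze everything
# except the LoRA weights, then build an optimizer over whatever stays trainable.
#
#   mark_only_lora_as_trainable(clip_model, bias='none')
#   optimizer = torch.optim.AdamW(
#       (p for p in clip_model.parameters() if p.requires_grad), lr=2e-4)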
def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
    """Return a state dict restricted to LoRA tensors (plus, depending on
    `bias`, the associated bias terms)."""
my_state_dict = model.state_dict()
if bias == 'none':
return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k}
elif bias == 'all':
return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k}
elif bias == 'lora_only':
to_return = {}
for k in my_state_dict:
if 'lora_' in k:
to_return[k] = my_state_dict[k]
                # derive the owning module's bias name; LoRA tensors are named
                # '<module>.w_lora_A' / '<module>.w_lora_B' in this codebase
                bias_name = k.split('w_lora_')[0] + 'bias'
                if bias_name in my_state_dict:
                    to_return[bias_name] = my_state_dict[bias_name]
return to_return
    else:
        raise NotImplementedError(f"Unknown bias mode: {bias}")
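# Example (illustrative; the checkpoint name is a placeholder): save only the
# LoRA tensors instead of the full CLIP model, and restore them later with a
# non-strict load.
#
#   torch.save(lora_state_dict(clip_model, bias='none'), 'lora_only.pt')
#   clip_model.load_state_dict(torch.load('lora_only.pt'), strict=False)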
def get_lora_parameters(model, bias='none'):
    """Collect the LoRA parameters (and, depending on `bias`, bias parameters)
    so they can be handed to an optimizer."""
params = []
for name, param in model.named_parameters():
if bias == 'none':
if 'lora_' in name:
params.append(param)
elif bias == 'all':
if 'lora_' in name or 'bias' in name:
params.append(param)
elif bias == 'lora_only':
if 'lora_' in name:
params.append(param)
                # derive the owning module's bias name (same naming convention
                # as in lora_state_dict above)
                bias_name = name.split('w_lora_')[0] + 'bias'
                if bias_name in model.state_dict():
                    bias_param = dict(model.named_parameters())[bias_name]
                    params.append(bias_param)
        else:
            raise NotImplementedError(f"Unknown bias mode: {bias}")
return params
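# Example (illustrative; optimizer settings are placeholders): train only the
# LoRA parameters collected above.
#
#   optimizer = torch.optim.SGD(get_lora_parameters(clip_model, bias='none'),
#                               lr=1e-3, momentum=0.9)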
def apply_lora(args, clip_model):
    """Replace the selected nn.MultiheadAttention modules in CLIP's text and/or
    vision transformer blocks with PlainMultiheadAttentionLoRA wrappers and
    return the list of new LoRA modules."""
list_lora_layers = []
if args.encoder == 'text' or args.encoder == 'both':
indices = INDEX_POSITIONS_TEXT[args.position]
text_encoder = clip_model.transformer
for i, block in enumerate(text_encoder.resblocks):
print(f"Residual Attention Block {i}: {block}")
if i in indices:
for name, submodule in block.named_children():
if isinstance(submodule, nn.MultiheadAttention):
new_multi_head_lora = PlainMultiheadAttentionLoRA(
submodule, enable_lora=args.params, r=args.r, lora_alpha=args.alpha, dropout_rate=args.dropout_rate)
setattr(block, name, new_multi_head_lora)
list_lora_layers.append(new_multi_head_lora)
if args.encoder == 'vision' or args.encoder == 'both':
indices = INDEX_POSITIONS_VISION[args.backbone][args.position]
vision_encoder = clip_model.visual.transformer
for i, block in enumerate(vision_encoder.resblocks):
print(f"Residual Attention Block {i}: {block}")
if i in indices:
for name, submodule in block.named_children():
if isinstance(submodule, nn.MultiheadAttention):
new_multi_head_lora = PlainMultiheadAttentionLoRA(
submodule, enable_lora=args.params, r=args.r, lora_alpha=args.alpha, dropout_rate=args.dropout_rate)
setattr(block, name, new_multi_head_lora)
list_lora_layers.append(new_multi_head_lora)
return list_lora_layers
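# Example (illustrative sketch, assuming the OpenAI `clip` package; the
# argparse-style namespace only mirrors the fields read above, and all values
# are placeholders):
#
#   import argparse, clip
#   clip_model, _ = clip.load('ViT-B/16')
#   args = argparse.Namespace(
#       encoder='both', position='all', backbone='ViT-B/16',
#       params=['q', 'k', 'v'], r=2, alpha=1, dropout_rate=0.25)
#   lora_layers = apply_lora(args, clip_model)
#   mark_only_lora_as_trainable(clip_model)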
def save_lora(args, list_lora_layers):
    """Gather the LoRA A/B matrices of each wrapped attention layer and save
    them, together with the run metadata, under args.save_path."""
weights = {}
for i, layer in enumerate(list_lora_layers):
layer_weights = {}
if 'q' in args.params:
layer_weights['q_proj'] = {
'w_lora_A': layer.q_proj.w_lora_A.data,
'w_lora_B': layer.q_proj.w_lora_B.data
}
if 'k' in args.params:
layer_weights['k_proj'] = {
'w_lora_A': layer.k_proj.w_lora_A.data,
'w_lora_B': layer.k_proj.w_lora_B.data
}
if 'v' in args.params:
layer_weights['v_proj'] = {
'w_lora_A': layer.v_proj.w_lora_A.data,
'w_lora_B': layer.v_proj.w_lora_B.data
}
if 'o' in args.params:
layer_weights['proj'] = {
'w_lora_A': layer.proj.w_lora_A.data,
'w_lora_B': layer.proj.w_lora_B.data
}
weights[f'layer_{i}'] = layer_weights
metadata = {
'r': args.r,
'alpha': args.alpha,
'encoder': args.encoder,
'params': args.params,
'position': args.position
}
save_data = {
'weights': weights,
'metadata': metadata
}
    # normalize backbone names like 'ViT-B/16' so they can be used in file paths
    backbone = args.backbone.replace('/', '').replace('-', '').lower()
save_dir = f'{args.save_path}/{backbone}/{args.dataset}/{args.shots}shots/seed{args.seed}'
os.makedirs(save_dir, exist_ok=True)
save_path = f'{save_dir}/{args.filename}.pt'
torch.save(save_data, save_path)
print(f'LoRA weights saved to {save_path}')
def load_lora(args, list_lora_layers):
    """Load previously saved LoRA weights, check that the checkpoint metadata
    matches the current arguments, and copy the tensors into the live layers."""
    # normalize backbone names like 'ViT-B/16' so they can be used in file paths
    backbone = args.backbone.replace('/', '').replace('-', '').lower()
load_path = f'{args.save_path}/{backbone}/{args.dataset}/{args.shots}shots/seed{args.seed}/{args.filename}.pt'
if not os.path.exists(load_path):
raise FileNotFoundError(f'File {load_path} does not exist.')
loaded_data = torch.load(load_path)
metadata = loaded_data['metadata']
if metadata['r'] != args.r:
raise ValueError(
f"r mismatch: expected {args.r}, found {metadata['r']}")
if metadata['alpha'] != args.alpha:
raise ValueError(
f"alpha mismatch: expected {args.alpha}, found {metadata['alpha']}")
if metadata['encoder'] != args.encoder:
raise ValueError(
f"Encoder mismatch: expected {args.encoder}, found {metadata['encoder']}")
if metadata['params'] != args.params:
raise ValueError(
f"Params mismatch: expected {args.params}, found {metadata['params']}")
if metadata['position'] != args.position:
raise ValueError(
f"Position mismatch: expected {args.position}, found {metadata['position']}")
weights = loaded_data['weights']
for i, layer in enumerate(list_lora_layers):
layer_weights = weights[f'layer_{i}']
if 'q' in args.params and 'q_proj' in layer_weights:
layer.q_proj.w_lora_A.data.copy_(
layer_weights['q_proj']['w_lora_A'])
layer.q_proj.w_lora_B.data.copy_(
layer_weights['q_proj']['w_lora_B'])
if 'k' in args.params and 'k_proj' in layer_weights:
layer.k_proj.w_lora_A.data.copy_(
layer_weights['k_proj']['w_lora_A'])
layer.k_proj.w_lora_B.data.copy_(
layer_weights['k_proj']['w_lora_B'])
if 'v' in args.params and 'v_proj' in layer_weights:
layer.v_proj.w_lora_A.data.copy_(
layer_weights['v_proj']['w_lora_A'])
layer.v_proj.w_lora_B.data.copy_(
layer_weights['v_proj']['w_lora_B'])
if 'o' in args.params and 'proj' in layer_weights:
layer.proj.w_lora_A.data.copy_(layer_weights['proj']['w_lora_A'])
layer.proj.w_lora_B.data.copy_(layer_weights['proj']['w_lora_B'])
print(f'LoRA weights loaded from {load_path}')
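# Example (illustrative round trip, extending the hypothetical `args` from the
# apply_lora sketch with the checkpoint-path fields used above; every value is
# a placeholder):
#
#   args.save_path, args.dataset = 'checkpoints', 'imagenet'
#   args.shots, args.seed, args.filename = 16, 1, 'lora_weights'
#   save_lora(args, lora_layers)   # -> checkpoints/vitb16/imagenet/16shots/seed1/lora_weights.pt
#   load_lora(args, lora_layers)   # restores the saved tensors into the live layers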