import os

import torch
import torch.nn as nn
from typing import Dict

from .layers import LoRALayer, PlainMultiheadAttentionLoRA
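
# Transformer block indices that receive LoRA modules, keyed by the `position`
# value read in apply_lora() below (e.g. 'top3' adapts only the last three
# blocks). The vision table is additionally keyed by the CLIP backbone name,
# since ViT-L/14 has 24 residual blocks instead of 12.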
INDEX_POSITIONS_TEXT = {
    'top1': [11],
    'top2': [10, 11],
    'top3': [9, 10, 11],
    'bottom': [0, 1, 2, 3],
    'mid': [4, 5, 6, 7],
    'up': [8, 9, 10, 11],
    'half-up': [6, 7, 8, 9, 10, 11],
    'half-bottom': [0, 1, 2, 3, 4, 5],
    'all': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]}

INDEX_POSITIONS_VISION = {
    'ViT-B/16': {
        'top': [11],
        'top3': [9, 10, 11],
        'bottom': [0, 1, 2, 3],
        'mid': [4, 5, 6, 7],
        'up': [8, 9, 10, 11],
        'half-up': [6, 7, 8, 9, 10, 11],
        'half-bottom': [0, 1, 2, 3, 4, 5],
        'all': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]},
    'ViT-B/32': {
        'bottom': [0, 1, 2, 3],
        'mid': [4, 5, 6, 7],
        'up': [8, 9, 10, 11],
        'half-up': [6, 7, 8, 9, 10, 11],
        'half-bottom': [0, 1, 2, 3, 4, 5],
        'all': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]},
    'ViT-L/14': {
        'half-up': [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
        'half-bottom': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
        'all': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]}
}


def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
    """Freeze every parameter except the LoRA factors (and, optionally, biases).

    bias='none' freezes all biases, 'all' keeps every bias trainable, and
    'lora_only' keeps only the biases of LoRA-adapted modules trainable.
    """
    for n, p in model.named_parameters():
        if 'lora_' not in n:
            p.requires_grad = False
    if bias == 'none':
        return
    elif bias == 'all':
        for n, p in model.named_parameters():
            if 'bias' in n:
                p.requires_grad = True
    elif bias == 'lora_only':
        for m in model.modules():
            if isinstance(m, LoRALayer) and \
                    hasattr(m, 'bias') and \
                    m.bias is not None:
                m.bias.requires_grad = True
    else:
        raise NotImplementedError
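
# Usage sketch (assumes apply_lora() below has already injected the LoRA
# modules into `clip_model`):
#
#   mark_only_lora_as_trainable(clip_model, bias='none')
#   trainable = [n for n, p in clip_model.named_parameters() if p.requires_grad]
#   # `trainable` should now contain only names with 'lora_' in them.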


def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
    """Return a state dict restricted to the LoRA parameters.

    With bias='all' every bias is included as well; with bias='lora_only'
    only the biases belonging to LoRA-adapted modules are kept.
    """
    my_state_dict = model.state_dict()
    if bias == 'none':
        return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k}
    elif bias == 'all':
        return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k}
    elif bias == 'lora_only':
        to_return = {}
        for k in my_state_dict:
            if 'lora_' in k:
                to_return[k] = my_state_dict[k]
                bias_name = k.split('lora_')[0] + 'bias'
                if bias_name in my_state_dict:
                    to_return[bias_name] = my_state_dict[bias_name]
        return to_return
    else:
        raise NotImplementedError
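
# Usage sketch: persist only the adapter weights instead of the full model
# (the file name is illustrative).
#
#   torch.save(lora_state_dict(clip_model, bias='none'), 'lora_only.pt')
#   clip_model.load_state_dict(torch.load('lora_only.pt'), strict=False)
#
# strict=False is needed because the saved dict deliberately omits the
# frozen backbone weights.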


def get_lora_parameters(model, bias='none'):
    """Collect the LoRA parameters (and optionally biases) as a list,
    e.g. to hand them directly to an optimizer."""
    params = []
    for name, param in model.named_parameters():
        if bias == 'none':
            if 'lora_' in name:
                params.append(param)
        elif bias == 'all':
            if 'lora_' in name or 'bias' in name:
                params.append(param)
        elif bias == 'lora_only':
            if 'lora_' in name:
                params.append(param)
                bias_name = name.split('lora_')[0] + 'bias'
                if bias_name in model.state_dict():
                    bias_param = dict(model.named_parameters())[bias_name]
                    params.append(bias_param)
        else:
            raise NotImplementedError
    return params
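
# Usage sketch: optimize only the LoRA factors. The learning rate and weight
# decay below are illustrative, not values prescribed by this file.
#
#   optimizer = torch.optim.AdamW(get_lora_parameters(clip_model),
#                                 lr=2e-4, weight_decay=1e-2)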


def apply_lora(args, clip_model):
    """Replace nn.MultiheadAttention modules in the selected transformer blocks
    of the CLIP text and/or vision encoder with PlainMultiheadAttentionLoRA
    wrappers, and return the list of newly created LoRA layers."""
    list_lora_layers = []
    if args.encoder == 'text' or args.encoder == 'both':
        indices = INDEX_POSITIONS_TEXT[args.position]
        text_encoder = clip_model.transformer
        for i, block in enumerate(text_encoder.resblocks):
            print(f"Residual Attention Block {i}: {block}")
            if i in indices:
                for name, submodule in block.named_children():
                    if isinstance(submodule, nn.MultiheadAttention):
                        new_multi_head_lora = PlainMultiheadAttentionLoRA(
                            submodule, enable_lora=args.params, r=args.r,
                            lora_alpha=args.alpha, dropout_rate=args.dropout_rate)
                        setattr(block, name, new_multi_head_lora)
                        list_lora_layers.append(new_multi_head_lora)

    if args.encoder == 'vision' or args.encoder == 'both':
        indices = INDEX_POSITIONS_VISION[args.backbone][args.position]
        vision_encoder = clip_model.visual.transformer
        for i, block in enumerate(vision_encoder.resblocks):
            print(f"Residual Attention Block {i}: {block}")
            if i in indices:
                for name, submodule in block.named_children():
                    if isinstance(submodule, nn.MultiheadAttention):
                        new_multi_head_lora = PlainMultiheadAttentionLoRA(
                            submodule, enable_lora=args.params, r=args.r,
                            lora_alpha=args.alpha, dropout_rate=args.dropout_rate)
                        setattr(block, name, new_multi_head_lora)
                        list_lora_layers.append(new_multi_head_lora)
    return list_lora_layers
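
# Usage sketch (argument values are illustrative; `args` is normally the parsed
# command-line namespace of the training script, and `clip` is the official
# OpenAI CLIP package, assumed to be installed):
#
#   from types import SimpleNamespace
#   import clip
#
#   clip_model, _ = clip.load('ViT-B/16')
#   args = SimpleNamespace(encoder='both', position='all', backbone='ViT-B/16',
#                          params=['q', 'k', 'v'], r=2, alpha=1, dropout_rate=0.25)
#   lora_layers = apply_lora(args, clip_model)
#   mark_only_lora_as_trainable(clip_model)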


def save_lora(args, list_lora_layers):
    """Serialize the LoRA A/B factors of every adapted attention layer, together
    with the metadata needed to check compatibility when loading."""
    weights = {}
    for i, layer in enumerate(list_lora_layers):
        layer_weights = {}
        if 'q' in args.params:
            layer_weights['q_proj'] = {
                'w_lora_A': layer.q_proj.w_lora_A.data,
                'w_lora_B': layer.q_proj.w_lora_B.data
            }
        if 'k' in args.params:
            layer_weights['k_proj'] = {
                'w_lora_A': layer.k_proj.w_lora_A.data,
                'w_lora_B': layer.k_proj.w_lora_B.data
            }
        if 'v' in args.params:
            layer_weights['v_proj'] = {
                'w_lora_A': layer.v_proj.w_lora_A.data,
                'w_lora_B': layer.v_proj.w_lora_B.data
            }
        if 'o' in args.params:
            layer_weights['proj'] = {
                'w_lora_A': layer.proj.w_lora_A.data,
                'w_lora_B': layer.proj.w_lora_B.data
            }
        weights[f'layer_{i}'] = layer_weights

    metadata = {
        'r': args.r,
        'alpha': args.alpha,
        'encoder': args.encoder,
        'params': args.params,
        'position': args.position
    }
    save_data = {
        'weights': weights,
        'metadata': metadata
    }

    # to manage names like ViT-B/16
    backbone = args.backbone.replace('/', '').replace('-', '').lower()
    save_dir = f'{args.save_path}/{backbone}/{args.dataset}/{args.shots}shots/seed{args.seed}'
    os.makedirs(save_dir, exist_ok=True)

    save_path = f'{save_dir}/{args.filename}.pt'
    torch.save(save_data, save_path)
    print(f'LoRA weights saved to {save_path}')
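
# Resulting layout (derived from the path built above), e.g. for ViT-B/16:
#   <save_path>/vitb16/<dataset>/<shots>shots/seed<seed>/<filename>.pt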


def load_lora(args, list_lora_layers):
    """Load LoRA weights saved by save_lora() into the layers returned by
    apply_lora(), after verifying that the stored metadata matches `args`."""
    # to manage names like ViT-B/16
    backbone = args.backbone.replace('/', '').replace('-', '').lower()
    load_path = f'{args.save_path}/{backbone}/{args.dataset}/{args.shots}shots/seed{args.seed}/{args.filename}.pt'

    if not os.path.exists(load_path):
        raise FileNotFoundError(f'File {load_path} does not exist.')

    loaded_data = torch.load(load_path)

    metadata = loaded_data['metadata']
    if metadata['r'] != args.r:
        raise ValueError(
            f"r mismatch: expected {args.r}, found {metadata['r']}")
    if metadata['alpha'] != args.alpha:
        raise ValueError(
            f"alpha mismatch: expected {args.alpha}, found {metadata['alpha']}")
    if metadata['encoder'] != args.encoder:
        raise ValueError(
            f"Encoder mismatch: expected {args.encoder}, found {metadata['encoder']}")
    if metadata['params'] != args.params:
        raise ValueError(
            f"Params mismatch: expected {args.params}, found {metadata['params']}")
    if metadata['position'] != args.position:
        raise ValueError(
            f"Position mismatch: expected {args.position}, found {metadata['position']}")

    weights = loaded_data['weights']
    for i, layer in enumerate(list_lora_layers):
        layer_weights = weights[f'layer_{i}']
        if 'q' in args.params and 'q_proj' in layer_weights:
            layer.q_proj.w_lora_A.data.copy_(
                layer_weights['q_proj']['w_lora_A'])
            layer.q_proj.w_lora_B.data.copy_(
                layer_weights['q_proj']['w_lora_B'])
        if 'k' in args.params and 'k_proj' in layer_weights:
            layer.k_proj.w_lora_A.data.copy_(
                layer_weights['k_proj']['w_lora_A'])
            layer.k_proj.w_lora_B.data.copy_(
                layer_weights['k_proj']['w_lora_B'])
        if 'v' in args.params and 'v_proj' in layer_weights:
            layer.v_proj.w_lora_A.data.copy_(
                layer_weights['v_proj']['w_lora_A'])
            layer.v_proj.w_lora_B.data.copy_(
                layer_weights['v_proj']['w_lora_B'])
        if 'o' in args.params and 'proj' in layer_weights:
            layer.proj.w_lora_A.data.copy_(layer_weights['proj']['w_lora_A'])
            layer.proj.w_lora_B.data.copy_(layer_weights['proj']['w_lora_B'])

    print(f'LoRA weights loaded from {load_path}')
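
# End-to-end sketch (illustrative): recreate the LoRA modules with the same
# arguments used at training time, then restore the trained factors for
# evaluation.
#
#   lora_layers = apply_lora(args, clip_model)
#   load_lora(args, lora_layers)
#   clip_model.eval()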