"""LoRA helpers for a CLIP-style model.

Wraps attention modules of the BERT text encoder
(``clip_model.text.transformer.encoder``) and of the timm ViT vision trunk
(``clip_model.visual.trunk``) with LoRA layers, and saves / restores the LoRA
weights together with optional loss-function state.
"""
import os
from typing import Dict

import torch
import torch.nn as nn
from timm.models.vision_transformer import Attention
from transformers.models.bert.modeling_bert import BertAttention

from .layers import LoRALayer, AttentionLoRA, BertAttentionLoRA

# Encoder block indices (0-based) that receive LoRA, keyed by the
# ``args.position`` preset name. The presets assume encoders with 12 blocks
# (indices 0-11); indices beyond an encoder's depth are simply never matched.
# Both tables accept the same names ('top' and 'top1' are aliases) so a single
# ``position`` value is valid for the text and the vision encoder alike.
INDEX_POSITIONS_TEXT = {
    'top': [11],
    'top1': [11],
    'top2': [10, 11],
    'top3': [9, 10, 11],
    'bottom': [0, 1, 2, 3],
    'mid': [4, 5, 6, 7],
    'up': [8, 9, 10, 11],
    'half-up': [6, 7, 8, 9, 10, 11],
    'half-bottom': [0, 1, 2, 3, 4, 5],
    'all': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
}

INDEX_POSITIONS_VISION = {
    'top': [11],
    'top1': [11],
    'top2': [10, 11],
    'top3': [9, 10, 11],
    'bottom': [0, 1, 2, 3],
    'mid': [4, 5, 6, 7],
    'up': [8, 9, 10, 11],
    'half-up': [6, 7, 8, 9, 10, 11],
    'half-bottom': [0, 1, 2, 3, 4, 5],
    'all': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
}

# Accepted values of the ``bias`` argument of the helpers below.
_BIAS_MODES = ('none', 'all', 'lora_only')

# (flag looked up in ``args.params``, attribute name on a LoRA attention module)
_LORA_PROJECTIONS = (
    ('q', 'q_proj'),
    ('k', 'k_proj'),
    ('v', 'v_proj'),
    ('o', 'proj'),
)

# Sub-modules of the loss function that are checkpointed when
# ``args.loss_type == 'clip_loss_ace_hgnn'``.
_HGNN_ADAPTERS = (
    'img_edge_adapter',
    'img_node_adapter',
    'text_edge_adapter',
    'text_node_adapter',
)


def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
    """Freeze every parameter of ``model`` except those whose name contains 'lora_'.

    Args:
        model: Module whose parameters' ``requires_grad`` flags are set in place.
        bias: Which biases to re-enable after freezing:
            ``'none'`` keeps all biases frozen; ``'all'`` unfreezes every
            parameter whose name contains ``'bias'``; ``'lora_only'`` unfreezes
            the ``bias`` attribute of ``LoRALayer`` modules only.

    Raises:
        NotImplementedError: If ``bias`` is not a supported mode. Checked before
            anything is frozen, so an invalid call leaves ``model`` untouched.
    """
    if bias not in _BIAS_MODES:
        raise NotImplementedError(f'Unsupported bias mode: {bias!r}')

    for n, p in model.named_parameters():
        if 'lora_' not in n:
            p.requires_grad = False

    if bias == 'all':
        for n, p in model.named_parameters():
            if 'bias' in n:
                p.requires_grad = True
    elif bias == 'lora_only':
        for m in model.modules():
            if isinstance(m, LoRALayer) and getattr(m, 'bias', None) is not None:
                m.bias.requires_grad = True


def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
    """Return the subset of ``model.state_dict()`` holding LoRA entries.

    Args:
        model: Module to extract entries from.
        bias: ``'none'`` keeps only keys containing ``'lora_'``; ``'all'`` also
            keeps every key containing ``'bias'``; ``'lora_only'`` also keeps,
            for each LoRA key, the key formed by replacing everything from
            ``'lora_'`` onwards with ``'bias'`` (when that key exists).

    Raises:
        NotImplementedError: If ``bias`` is not a supported mode.
    """
    my_state_dict = model.state_dict()
    if bias == 'none':
        return {k: v for k, v in my_state_dict.items() if 'lora_' in k}
    if bias == 'all':
        return {k: v for k, v in my_state_dict.items() if 'lora_' in k or 'bias' in k}
    if bias == 'lora_only':
        to_return = {}
        for k, v in my_state_dict.items():
            if 'lora_' in k:
                to_return[k] = v
                # NOTE(review): for names like '<prefix>w_lora_A' this looks up
                # '<prefix>w_bias' -- confirm it matches LoRALayer's naming.
                bias_name = k.split('lora_')[0] + 'bias'
                if bias_name in my_state_dict:
                    to_return[bias_name] = my_state_dict[bias_name]
        return to_return
    raise NotImplementedError(f'Unsupported bias mode: {bias!r}')


def get_lora_parameters(model, bias='none'):
    """Collect the parameters an optimizer should train for LoRA fine-tuning.

    Args:
        model: Module to collect ``nn.Parameter`` objects from.
        bias: Same modes as :func:`lora_state_dict`.

    Returns:
        A list of parameters in ``named_parameters()`` order. Each parameter
        appears at most once. For ``'lora_only'``, a derived bias that is
        shared by several LoRA tensors (e.g. A and B) is listed only once.
        A derived bias that is a buffer rather than a parameter is skipped.

    Raises:
        NotImplementedError: If ``bias`` is not a supported mode.
    """
    if bias not in _BIAS_MODES:
        raise NotImplementedError(f'Unsupported bias mode: {bias!r}')

    # Build the name -> parameter map once instead of once per parameter.
    named_params = dict(model.named_parameters())
    params = []
    seen = set()
    for name, param in named_params.items():
        candidates = []
        if 'lora_' in name:
            candidates.append(param)
            if bias == 'lora_only':
                # NOTE(review): same '<prefix>bias' derivation as lora_state_dict.
                bias_param = named_params.get(name.split('lora_')[0] + 'bias')
                if bias_param is not None:
                    candidates.append(bias_param)
        elif bias == 'all' and 'bias' in name:
            candidates.append(param)
        for p in candidates:
            # Duplicate tensors in one optimizer param group are rejected or
            # warned about by torch.optim, so deduplicate by identity.
            if id(p) not in seen:
                seen.add(id(p))
                params.append(p)
    return params


def _position_indices(presets, position, encoder_name):
    """Return the block-index set for ``position``, raising a descriptive KeyError."""
    try:
        return set(presets[position])
    except KeyError:
        raise KeyError(
            f'Unknown LoRA position {position!r} for the {encoder_name} encoder; '
            f'expected one of {sorted(presets)}.') from None


def _inject_lora(blocks, indices, attn_cls, lora_cls, args):
    """Replace each ``attn_cls`` child of the selected ``blocks`` with a ``lora_cls`` wrapper."""
    injected = []
    for i, block in enumerate(blocks):
        if i not in indices:
            continue
        # Snapshot the children so replacing one never interferes with iteration.
        for name, submodule in list(block.named_children()):
            if isinstance(submodule, attn_cls):
                lora_module = lora_cls(
                    submodule,
                    enable_lora=args.params,
                    r=args.r,
                    lora_alpha=args.alpha,
                    dropout_rate=args.dropout_rate,
                    seed=args.seed)
                setattr(block, name, lora_module)
                injected.append(lora_module)
    return injected


def apply_lora(args, clip_model):
    """Wrap selected attention modules of ``clip_model`` with LoRA, in place.

    In the text blocks (``clip_model.text.transformer.encoder.layer``), every
    direct ``BertAttention`` child is replaced by ``BertAttentionLoRA``. In the
    vision blocks (``clip_model.visual.trunk.blocks``), every direct timm
    ``Attention`` child is replaced by ``AttentionLoRA``. The blocks are chosen
    by ``args.position`` via the ``INDEX_POSITIONS_*`` tables.

    Args:
        args: Namespace providing ``position``, ``params``, ``r``, ``alpha``,
            ``dropout_rate`` and ``seed``.
        clip_model: Model to modify in place.

    Returns:
        The newly created LoRA modules: text encoder first, then vision, each in
        block order. ``save_lora`` / ``load_model`` key weights by this order.

    Raises:
        KeyError: If ``args.position`` is missing from either preset table.
            Both tables are checked before any module is replaced, so a bad
            position never leaves the model half-wrapped.
    """
    text_indices = _position_indices(INDEX_POSITIONS_TEXT, args.position, 'text')
    vision_indices = _position_indices(INDEX_POSITIONS_VISION, args.position, 'vision')
    text_blocks = clip_model.text.transformer.encoder.layer
    vision_blocks = clip_model.visual.trunk.blocks

    list_lora_layers = _inject_lora(
        text_blocks, text_indices, BertAttention, BertAttentionLoRA, args)
    list_lora_layers += _inject_lora(
        vision_blocks, vision_indices, Attention, AttentionLoRA, args)
    return list_lora_layers


def save_lora(args, list_lora_layers, loss_fn, msg, save_dir):
    """Write the LoRA A/B weights (plus optional loss state) to ``save_dir``.

    The file is ``<save_dir>/<args.filename>_<msg>.pt``. It contains
    ``{'weights': ..., 'metadata': ...}``, where ``weights['layer_<i>']`` holds
    the ``w_lora_A`` / ``w_lora_B`` tensors of ``list_lora_layers[i]`` for each
    projection enabled in ``args.params``. When
    ``args.loss_type == 'clip_loss_ace_hgnn'``, the four hgnn adapter
    state_dicts of ``loss_fn`` are stored as well. When
    ``args.learnable_logit_scale`` is set, ``loss_fn.logit_scale`` is stored
    (moved to CPU). ``save_dir`` is created if it does not exist.
    """
    weights = {}
    for i, layer in enumerate(list_lora_layers):
        layer_weights = {}
        for flag, attr in _LORA_PROJECTIONS:
            if flag in args.params:
                proj = getattr(layer, attr)
                layer_weights[attr] = {
                    'w_lora_A': proj.w_lora_A.data,
                    'w_lora_B': proj.w_lora_B.data,
                }
        weights[f'layer_{i}'] = layer_weights

    if args.loss_type == 'clip_loss_ace_hgnn':
        for adapter in _HGNN_ADAPTERS:
            weights[adapter] = getattr(loss_fn, adapter).state_dict()
    if args.learnable_logit_scale:
        weights['logit_scale'] = loss_fn.logit_scale.data.cpu()

    metadata = {
        'r': args.r,
        'topk': args.topk,
        'params': args.params,
        'position': args.position,
        'loss_type': args.loss_type,
    }
    save_data = {'weights': weights, 'metadata': metadata}

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f'{args.filename}_{msg}.pt')
    torch.save(save_data, save_path)
    print(f'LoRA weights saved to {save_path}')


def load_model(args, list_lora_layers, device, loss_fn=None):
    """Restore weights written by :func:`save_lora` into ``list_lora_layers``.

    Args:
        args: Namespace providing ``load_path``, ``params``, ``loss_type`` and
            ``learnable_logit_scale``.
        list_lora_layers: LoRA modules in the order returned by ``apply_lora``.
        device: ``map_location`` passed to ``torch.load``.
        loss_fn: Loss module to restore hgnn adapters / ``logit_scale`` into.
            Required when either of those is enabled in ``args``.

    Raises:
        FileNotFoundError: If ``args.load_path`` does not exist.
        ValueError: If loss state must be restored but ``loss_fn`` is None.
            Checked before any tensor is copied.
        KeyError: If the checkpoint lacks an entry this call needs, e.g. it
            holds fewer LoRA layers than ``list_lora_layers``.
    """
    if not os.path.exists(args.load_path):
        raise FileNotFoundError(f'File {args.load_path} does not exist.')
    if (args.loss_type == 'clip_loss_ace_hgnn' or args.learnable_logit_scale) and loss_fn is None:
        raise ValueError(
            f'load_model needs loss_fn to restore loss state '
            f'(loss_type={args.loss_type!r}, '
            f'learnable_logit_scale={args.learnable_logit_scale!r}).')

    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    loaded_data = torch.load(args.load_path, map_location=device)
    weights = loaded_data['weights']

    for i, layer in enumerate(list_lora_layers):
        key = f'layer_{i}'
        if key not in weights:
            raise KeyError(
                f'Checkpoint {args.load_path} has no entry {key!r} but '
                f'{len(list_lora_layers)} LoRA layers were given; it may have '
                f'been saved with a different position or model.')
        layer_weights = weights[key]
        for flag, attr in _LORA_PROJECTIONS:
            if flag in args.params and attr in layer_weights:
                proj = getattr(layer, attr)
                proj.w_lora_A.data.copy_(layer_weights[attr]['w_lora_A'])
                proj.w_lora_B.data.copy_(layer_weights[attr]['w_lora_B'])

    if args.loss_type == 'clip_loss_ace_hgnn':
        for adapter in _HGNN_ADAPTERS:
            getattr(loss_fn, adapter).load_state_dict(weights[adapter])
    if args.learnable_logit_scale:
        loss_fn.logit_scale.data.copy_(weights['logit_scale'])

    print(f'LoRA weights loaded from {args.load_path}')