import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPModel
class ClipModel(nn.Module):
    """CLIP vision tower with a linear classification head.

    When ``opt.use_svd`` is set, every linear projection inside the vision
    model's self-attention blocks is replaced by an SVDResidualLinear, so
    that only the low-energy residual singular directions are fine-tuned.
    """

    def __init__(self, name, opt, num_classes=1):
        super().__init__()
        self.use_svd = opt.use_svd

        self.model = CLIPModel.from_pretrained(name)
        if self.use_svd:
            # r = 1024 - 1 keeps the top 1023 singular directions frozen and
            # leaves a rank-1 trainable residual per 1024-dim projection.
            self.model.vision_model = apply_svd_residual_to_self_attn(
                self.model.vision_model, r=1024 - 1
            )

        for param_name, param in self.model.vision_model.named_parameters():
            print('{}: {}'.format(param_name, param.requires_grad))
        num_param = sum(p.numel() for p in self.model.vision_model.parameters() if p.requires_grad)
        num_total_param = sum(p.numel() for p in self.model.vision_model.parameters())
        print('Number of total parameters: {}, tunable parameters: {}'.format(num_total_param, num_param))

        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x, return_feature=False):
        features = self.model.vision_model(x)['pooler_output']
        if return_feature:
            return features
        return self.fc(features)
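
# Minimal usage sketch (illustrative, not part of the training code). Assumes
# the openai/clip-vit-large-patch14 checkpoint, whose vision tower pools to
# 1024 dims (matching self.fc), and a hypothetical `opt` namespace:
#
#   from types import SimpleNamespace
#   opt = SimpleNamespace(use_svd=True)
#   model = ClipModel('openai/clip-vit-large-patch14', opt, num_classes=1)
#   pixel_values = torch.randn(2, 3, 224, 224)  # dummy preprocessed batch
#   logits = model(pixel_values)                # -> shape (2, 1)
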
class SVDResidualLinear(nn.Module):
    """Linear layer whose weight is a frozen main part plus a trainable SVD
    residual: W = W_main + U_residual @ diag(S_residual) @ Vh_residual."""

    def __init__(self, in_features, out_features, r, bias=True, init_weight=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.r = r

        # Frozen main weight; replace_with_svd_residual() overwrites it with
        # the top-r SVD reconstruction of the original weight.
        self.weight_main = nn.Parameter(torch.Tensor(out_features, in_features), requires_grad=False)
        if init_weight is not None:
            self.weight_main.data.copy_(init_weight)
        else:
            nn.init.kaiming_uniform_(self.weight_main, a=math.sqrt(5))

        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
            nn.init.zeros_(self.bias)
        else:
            self.register_parameter('bias', None)

        # Top-r placeholders (not populated here) and the trainable residual
        # factors, filled in by replace_with_svd_residual().
        self.S_r = None
        self.U_r = None
        self.V_r = None
        self.S_residual = None
        self.U_residual = None
        self.V_residual = None

    def forward(self, x):
        if self.S_residual is not None:
            # Reassemble the residual weight from its factors; V_residual
            # already stores the transposed factor Vh.
            residual_weight = self.U_residual @ torch.diag(self.S_residual) @ self.V_residual
            weight = self.weight_main + residual_weight
        else:
            weight = self.weight_main

        return F.linear(x, weight, self.bias)
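
# Quick sanity check for the decomposition (illustrative sketch): after
# replace_with_svd_residual(), the frozen main weight plus the reassembled
# residual should reproduce the original weight up to SVD round-off error.
#
#   layer = nn.Linear(1024, 1024)
#   svd_layer = replace_with_svd_residual(layer, r=1023)
#   residual = (svd_layer.U_residual @ torch.diag(svd_layer.S_residual)
#               @ svd_layer.V_residual)
#   assert torch.allclose(svd_layer.weight_main + residual, layer.weight, atol=1e-4)
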
def apply_svd_residual_to_self_attn(model, r):
    """Recursively replace every nn.Linear inside self-attention blocks with
    an SVDResidualLinear, then freeze all but the residual SVD factors."""
    for name, module in model.named_children():
        if 'self_attn' in name:
            for sub_name, sub_module in module.named_modules():
                if isinstance(sub_module, nn.Linear):
                    # Walk down to the direct parent of this linear layer so
                    # it can be swapped in place.
                    parent_module = module
                    sub_module_names = sub_name.split('.')
                    for module_name in sub_module_names[:-1]:
                        parent_module = getattr(parent_module, module_name)
                    setattr(parent_module, sub_module_names[-1], replace_with_svd_residual(sub_module, r))
        else:
            apply_svd_residual_to_self_attn(module, r)

    # Only the residual SVD factors remain trainable.
    for param_name, param in model.named_parameters():
        if any(x in param_name for x in ['S_residual', 'U_residual', 'V_residual']):
            param.requires_grad = True
        else:
            param.requires_grad = False
    return model
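
# Standalone application sketch (mirrors what ClipModel does internally when
# opt.use_svd is set):
#
#   clip = CLIPModel.from_pretrained('openai/clip-vit-large-patch14')
#   vision = apply_svd_residual_to_self_attn(clip.vision_model, r=1023)
#   tunable = sum(p.numel() for p in vision.parameters() if p.requires_grad)
#   print(tunable)  # only the S/U/V residual factors are counted
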
def replace_with_svd_residual(module, r):
    """Split a linear layer's weight via SVD into a frozen top-r part and a
    trainable residual, returning an equivalent SVDResidualLinear."""
    if not isinstance(module, nn.Linear):
        return module

    in_features = module.in_features
    out_features = module.out_features
    bias = module.bias is not None

    new_module = SVDResidualLinear(in_features, out_features, r, bias=bias,
                                   init_weight=module.weight.data.clone())
    if bias:
        new_module.bias.data.copy_(module.bias.data)

    # Record the Frobenius norm of the original weight for later inspection.
    new_module.weight_original_fnorm = torch.norm(module.weight.data, p='fro')

    U, S, Vh = torch.linalg.svd(module.weight.data, full_matrices=False)

    # Clamp r to the actual number of singular values.
    r = min(r, len(S))

    # Frozen main weight: reconstruction from the top-r singular triplets.
    U_r = U[:, :r]
    S_r = S[:r]
    Vh_r = Vh[:r, :]
    weight_main = U_r @ torch.diag(S_r) @ Vh_r
    new_module.weight_main_fnorm = torch.norm(weight_main, p='fro')
    new_module.weight_main.data.copy_(weight_main)

    # Trainable residual: the remaining low-energy singular triplets.
    U_residual = U[:, r:]
    S_residual = S[r:]
    Vh_residual = Vh[r:, :]

    if len(S_residual) > 0:
        new_module.S_residual = nn.Parameter(S_residual.clone())
        new_module.U_residual = nn.Parameter(U_residual.clone())
        new_module.V_residual = nn.Parameter(Vh_residual.clone())
    else:
        new_module.S_residual = None
        new_module.U_residual = None
        new_module.V_residual = None

    return new_module
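
# Parameter-count arithmetic (illustrative): for a 1024x1024 projection with
# r = 1023, the residual keeps a single singular triplet, so each converted
# layer trains U_residual (1024x1), S_residual (1,), and V_residual (1x1024):
# 1024 + 1 + 1024 = 2049 trainable parameters per linear layer.
#
#   layer = replace_with_svd_residual(nn.Linear(1024, 1024), r=1023)
#   n = sum(p.numel() for p in (layer.S_residual, layer.U_residual, layer.V_residual))
#   assert n == 2049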