Spaces:
Configuration error
Configuration error
| import math | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.transforms as T | |
| from PIL import Image | |
| from transformers import CLIPModel | |
class EffortModel(nn.Module):
    """Binary real/fake classifier on top of a CLIP vision backbone.

    The backbone is the CLIP ViT-L/14 vision tower whose self-attention linear
    layers have been converted by ``apply_svd_residual_to_self_attn``; a single
    linear head maps the pooled backbone output to two logits
    (label 0 = real, label 1 = fake, as used by ``get_losses``).
    """

    def __init__(self, config=None):
        """Build the backbone, the 2-way head and the cross-entropy criterion.

        Args:
            config: Stored on the instance and forwarded to ``build_backbone``.
        """
        super().__init__()
        self.config = config
        self.backbone = self.build_backbone(config)
        # The head consumes the backbone's pooled output (width 1024 for ViT-L/14).
        self.head = nn.Linear(1024, 2)
        self.loss_func = nn.CrossEntropyLoss()
        self.prob, self.label = [], []
        self.correct, self.total = 0, 0

    def build_backbone(self, config):
        """Load the CLIP vision tower and apply the SVD-residual conversion.

        ``config`` is accepted for interface compatibility; it is not read here.

        Notes carried over from the original source:
          * A copy of the CLIP weights is also available at
            https://drive.google.com/drive/folders/1fm3Jd8lFMiSP1qgdmsxfqlJZGpr_bXsx?usp=drive_link
          * Input normalisation: mean [0.48145466, 0.4578275, 0.40821073],
            std [0.26862954, 0.26130258, 0.27577711]; ViT-L/14 at 224x224.

        Returns:
            The converted ``vision_model`` sub-module of the CLIP model.
        """
        clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        # Keep the top 1024 - 1 singular components fixed in every self_attn
        # linear layer; only the remaining residual component is trainable.
        clip_model.vision_model = apply_svd_residual_to_self_attn(clip_model.vision_model, r=1024 - 1)
        return clip_model.vision_model

    def features(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return the backbone's ``pooler_output`` for ``inputs``."""
        return self.backbone(inputs).pooler_output

    def classifier(self, features: torch.Tensor) -> torch.Tensor:
        """Map pooled features to the two class logits."""
        return self.head(features)

    def compute_weight_loss(self):
        """Negative mean singular value of the shape-wise summed current weights.

        The current weights of all ``SVDResidualLinear`` modules in the backbone
        are summed per weight shape; for each sum the singular values are
        computed and their negated mean is averaged over the shape groups.

        Returns:
            A scalar tensor, or ``0.0`` when the backbone contains no
            ``SVDResidualLinear`` module.
        """
        summed_by_shape = {}
        for module in self.backbone.modules():
            if not isinstance(module, SVDResidualLinear):
                continue
            weight = module.compute_current_weight()
            key = str(weight.size())
            if key in summed_by_shape:
                # Out-of-place add: ``weight`` may be the module's own frozen
                # ``weight_main`` parameter, which must never be modified here.
                summed_by_shape[key] = summed_by_shape[key] + weight
            else:
                summed_by_shape[key] = weight

        if not summed_by_shape:
            # Nothing to regularise; avoids a division by zero below.
            return 0.0

        loss = 0.0
        for summed in summed_by_shape.values():
            _, singular_values, _ = torch.linalg.svd(summed, full_matrices=False)
            loss = loss - torch.mean(singular_values)
        return loss / len(summed_by_shape)

    def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
        """Compute the cross-entropy loss overall and separately per class.

        Args:
            data_dict: Must contain ``"label"``, a tensor of shape [batch_size].
            pred_dict: Must contain ``"logits"``, a tensor of shape
                [batch_size, num_classes].

        Returns:
            Dict with ``"overall"``, ``"real_loss"`` (label 0) and
            ``"fake_loss"`` (label 1). A per-class entry is a zero tensor when
            the batch has no sample of that class. The weight loss from
            ``compute_weight_loss`` is currently not included.
        """
        label = data_dict["label"]
        pred = pred_dict["logits"]

        def class_loss(class_id):
            mask = label == class_id
            if mask.sum() > 0:
                return self.loss_func(pred[mask], label[mask])
            # No sample of this class in the batch.
            return torch.tensor(0.0, device=pred.device)

        return {
            "overall": self.loss_func(pred, label),
            "real_loss": class_loss(0),
            "fake_loss": class_loss(1),
        }

    def forward(self, inputs: torch.Tensor):
        """Return ``(logits, features)`` with the features L2-normalised along dim 1."""
        features = self.features(inputs)
        logits = self.classifier(features)
        normalized_features = F.normalize(features, p=2, dim=1)
        return logits, normalized_features
# Custom module to represent the residual using SVD components
class SVDResidualLinear(nn.Module):
    """Linear layer whose weight is a frozen main matrix plus an optional SVD residual.

    The effective weight is ``weight_main + U_residual @ diag(S_residual) @ V_residual``
    when residual components are attached, and ``weight_main`` alone otherwise.
    The residual components (and the fixed ``U_r``/``S_r``/``V_r`` factors and
    the two Frobenius-norm attributes) are attached after construction by
    ``replace_with_svd_residual``; until then they are ``None``.
    """

    def __init__(self, in_features, out_features, r, bias=True, init_weight=None):
        """Create the layer with a frozen main weight and no residual yet.

        Args:
            in_features: Input width of the linear map.
            out_features: Output width of the linear map.
            r: Number of top singular values kept out of the residual.
            bias: Whether the layer has a bias vector (initialised to zeros).
            init_weight: Optional tensor copied into ``weight_main``; when
                omitted the weight gets Kaiming-uniform initialisation.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.r = r  # Number of top singular values to exclude

        # Original weights (fixed)
        self.weight_main = nn.Parameter(torch.empty(out_features, in_features), requires_grad=False)
        if init_weight is not None:
            self.weight_main.data.copy_(init_weight)
        else:
            nn.init.kaiming_uniform_(self.weight_main, a=math.sqrt(5))

        # Bias
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter("bias", None)

        # Placeholders so that every method works on a freshly built module;
        # without them the compute_* methods raise AttributeError.
        for name in ("S_residual", "U_residual", "V_residual", "S_r", "U_r", "V_r"):
            self.register_parameter(name, None)
        self.weight_original_fnorm = None
        self.weight_main_fnorm = None

    def _has_residual(self):
        """Return True when residual singular values are attached."""
        # getattr keeps instances created before the placeholders existed usable.
        return getattr(self, "S_residual", None) is not None

    def _residual_weight(self):
        """Rebuild the residual weight matrix from its three factors."""
        return self.U_residual @ torch.diag(self.S_residual) @ self.V_residual

    def compute_current_weight(self):
        """Return the effective weight (``weight_main`` itself when there is no residual)."""
        if self._has_residual():
            return self.weight_main + self._residual_weight()
        return self.weight_main

    def forward(self, x):
        """Apply the linear map using the effective weight and the bias."""
        return F.linear(x, self.compute_current_weight(), self.bias)

    def compute_orthogonal_loss(self):
        """Penalty for the stacked U and V factors drifting away from orthogonality.

        Uses the property A A^T = I: the fixed and residual factors are
        concatenated and the Frobenius distance of their Gram matrices to the
        identity is averaged. Returns ``0.0`` when there is no residual.
        """
        if not self._has_residual():
            return 0.0
        # Shapes noted in the original for ViT-L/14 with r = 1023:
        # U_r [1024, 1023], U_residual [1024, 1]; V_r [1023, 1024], V_residual [1, 1024].
        U_full = torch.cat((self.U_r, self.U_residual), dim=1)
        V_full = torch.cat((self.V_r, self.V_residual), dim=0)
        UUT = U_full @ U_full.t()
        VVT = V_full @ V_full.t()
        UUT_identity = torch.eye(UUT.size(0), device=UUT.device)
        VVT_identity = torch.eye(VVT.size(0), device=VVT.device)
        return 0.5 * torch.norm(UUT - UUT_identity, p="fro") + 0.5 * torch.norm(VVT - VVT_identity, p="fro")

    def compute_keepsv_loss(self):
        """Absolute difference between the squared Frobenius norms of current and original weight.

        Returns ``0.0`` when there is no residual or no recorded
        ``weight_original_fnorm``.
        """
        if not self._has_residual() or getattr(self, "weight_original_fnorm", None) is None:
            return 0.0
        weight_current_fnorm = torch.norm(self.compute_current_weight(), p="fro")
        return torch.abs(weight_current_fnorm**2 - self.weight_original_fnorm**2)

    def compute_fn_loss(self):
        """Squared Frobenius norm of the effective weight (``0.0`` without a residual)."""
        if not self._has_residual():
            return 0.0
        weight_current_fnorm = torch.norm(self.compute_current_weight(), p="fro")
        return weight_current_fnorm**2
def _replace_linears_under(owner, attr_name, r):
    """Replace every nn.Linear at or below ``owner.<attr_name>`` with its SVD-residual form."""
    target = getattr(owner, attr_name)
    # Snapshot first so the module tree is not mutated while it is being walked.
    for path, sub_module in list(target.named_modules()):
        if not isinstance(sub_module, nn.Linear):
            continue
        if not path:
            # The self_attn child itself is a Linear: replace it on its owner.
            setattr(owner, attr_name, replace_with_svd_residual(sub_module, r))
            continue
        *parent_names, leaf_name = path.split(".")
        parent_module = target
        for parent_name in parent_names:
            parent_module = getattr(parent_module, parent_name)
        setattr(parent_module, leaf_name, replace_with_svd_residual(sub_module, r))


def _convert_self_attn_children(model, r):
    """Walk ``model`` recursively and convert the linears of every child whose name contains "self_attn"."""
    for name, module in list(model.named_children()):
        if "self_attn" in name:
            _replace_linears_under(model, name, r)
        else:
            _convert_self_attn_children(module, r)


# Function to replace nn.Linear modules within self_attn modules with SVDResidualLinear
def apply_svd_residual_to_self_attn(model, r):
    """Convert self-attention linears to SVD-residual layers and freeze everything else.

    Every ``nn.Linear`` found inside a child module whose name contains
    ``"self_attn"`` (at any depth of ``model``) is replaced via
    ``replace_with_svd_residual``. Afterwards only parameters whose name
    contains ``S_residual``, ``U_residual`` or ``V_residual`` are trainable;
    all other parameters of ``model`` get ``requires_grad = False``.

    Args:
        model: Module to convert in place.
        r: Number of top singular components kept fixed in each replaced layer.

    Returns:
        The same ``model`` object.
    """
    _convert_self_attn_children(model, r)

    # A single pass over the whole model is enough; it also covers every
    # sub-module visited during the recursive conversion.
    trainable_markers = ("S_residual", "U_residual", "V_residual")
    for param_name, param in model.named_parameters():
        param.requires_grad = any(marker in param_name for marker in trainable_markers)
    return model
# Function to replace a module with SVDResidualLinear
def replace_with_svd_residual(module, r):
    """Build an SVDResidualLinear equivalent of ``module``.

    Anything that is not an ``nn.Linear`` is handed back untouched. For a
    linear layer the weight is decomposed with an SVD: the leading ``r``
    singular components are recombined into the frozen ``weight_main`` and
    the remaining components become the trainable residual factors. When no
    component is left over, all factor attributes are set to ``None``.
    """
    if not isinstance(module, nn.Linear):
        return module

    original_weight = module.weight.data
    has_bias = module.bias is not None

    # Create SVDResidualLinear module
    converted = SVDResidualLinear(
        module.in_features,
        module.out_features,
        r,
        bias=has_bias,
        init_weight=original_weight.clone(),
    )
    if has_bias:
        converted.bias.data.copy_(module.bias.data)
    converted.weight_original_fnorm = torch.norm(original_weight, p="fro")

    # Perform SVD on the original weight
    U, S, Vh = torch.linalg.svd(original_weight, full_matrices=False)

    # Never keep more components than there are singular values
    keep = min(r, len(S))

    # Leading components -> fixed main weight
    U_main, S_main, Vh_main = U[:, :keep], S[:keep], Vh[:keep, :]
    main_weight = U_main @ torch.diag(S_main) @ Vh_main
    converted.weight_main_fnorm = torch.norm(main_weight.data, p="fro")
    converted.weight_main.data.copy_(main_weight)

    # Trailing components -> trainable residual
    U_rest, S_rest, Vh_rest = U[:, keep:], S[keep:], Vh[keep:, :]

    # (attribute name, source tensor, trainable) in registration order
    factors = (
        ("S_residual", S_rest, True),
        ("U_residual", U_rest, True),
        ("V_residual", Vh_rest, True),
        ("S_r", S_main, False),
        ("U_r", U_main, False),
        ("V_r", Vh_main, False),
    )
    if len(S_rest) > 0:
        for attr, tensor, trainable in factors:
            setattr(converted, attr, nn.Parameter(tensor.clone(), requires_grad=trainable))
    else:
        for attr, _, _ in factors:
            setattr(converted, attr, None)
    return converted
# Preprocessing pipeline as used in the Effort paper: convert the input to a
# tensor, then normalise each channel with the fixed mean/std below.
# Produces almost the same results as `preprocessing`.
_preprocessing_original = T.Compose(
    transforms=[
        T.ToTensor(),
        T.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        ),
    ]
)
def preprocessing_original(image: Image.Image) -> torch.Tensor:
    """Resize ``image`` to 224x224 and run it through ``_preprocessing_original``.

    Args:
        image: A PIL image. Array-like inputs are passed through unchanged to
            ``np.array`` for backward compatibility.

    Returns:
        The tensor produced by the ``_preprocessing_original`` pipeline.
    """
    if isinstance(image, Image.Image):
        # The normalisation step uses three per-channel statistics, so
        # grayscale, palette and RGBA images are brought to RGB first.
        image = image.convert("RGB")
    pixels = np.array(image)
    pixels = cv2.resize(pixels, (224, 224), interpolation=cv2.INTER_LINEAR)
    return _preprocessing_original(pixels)