# GenD-Sentinel: src/model/effort/model.py
# (commit c29babb, "init", by yermandy)
import math
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from transformers import CLIPModel
class EffortModel(nn.Module):
    """Effort detector: CLIP ViT-L/14 vision tower with SVD-residual attention plus a linear head.

    ``build_backbone`` replaces the self-attention projections of the CLIP
    vision tower with ``SVDResidualLinear`` layers and freezes every backbone
    parameter except their residual SVD factors. A 2-way linear head maps the
    pooled 1024-d feature to logits; ``get_losses`` treats label 0 as real and
    label 1 as fake.

    The SVD-layer regularizers (``SVDResidualLinear.compute_orthogonal_loss`` /
    ``compute_keepsv_loss``) and ``compute_weight_loss`` are available but are
    not added to the training loss by ``get_losses``.
    """

    def __init__(self, config=None):
        super().__init__()
        self.config = config
        self.backbone = self.build_backbone(config)
        self.head = nn.Linear(1024, 2)
        self.loss_func = nn.CrossEntropyLoss()
        # Not used within this class (presumably filled by external evaluation code).
        self.prob, self.label = [], []
        self.correct, self.total = 0, 0

    def build_backbone(self, config):
        """Load CLIP ViT-L/14 and return its vision tower with SVD-residual self-attention.

        ``config`` is currently unused.
        """
        # Weights come from ``CLIPModel.from_pretrained("openai/clip-vit-large-patch14")``.
        # The original Effort release pointed to a manual download instead:
        # https://drive.google.com/drive/folders/1fm3Jd8lFMiSP1qgdmsxfqlJZGpr_bXsx?usp=drive_link
        # Inputs are 224x224 images normalized with the CLIP statistics
        # mean [0.48145466, 0.4578275, 0.40821073], std [0.26862954, 0.26130258, 0.27577711].
        clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        # ViT-L/14 attention projections are 1024 wide: keep the top 1023 singular
        # components frozen so only the smallest one per projection is trainable.
        clip_model.vision_model = apply_svd_residual_to_self_attn(clip_model.vision_model, r=1024 - 1)
        return clip_model.vision_model

    def features(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return the backbone's pooled output for ``inputs``."""
        return self.backbone(inputs).pooler_output

    def classifier(self, features: torch.Tensor) -> torch.Tensor:
        """Map pooled features to 2-class logits."""
        return self.head(features)

    def compute_weight_loss(self):
        """Negative mean singular value of summed current weights, averaged over shape groups.

        The current weights (``weight_main`` + residual) of all ``SVDResidualLinear``
        layers in the backbone are summed per weight shape. For each group, the
        negative mean singular value of the sum is taken, and the results are
        averaged across groups.

        Returns:
            A scalar tensor, or ``0.0`` when the backbone has no ``SVDResidualLinear`` layers.
        """
        weight_sums = {}
        for module in self.backbone.modules():
            if not isinstance(module, SVDResidualLinear):
                continue
            weight = module.compute_current_weight()
            key = tuple(weight.shape)
            # Out-of-place add: ``compute_current_weight`` may return the frozen
            # ``weight_main`` parameter itself, which an in-place ``+=`` would corrupt.
            weight_sums[key] = weight if key not in weight_sums else weight_sums[key] + weight
        if not weight_sums:
            return 0.0
        loss = 0.0
        for summed in weight_sums.values():
            loss = loss - torch.linalg.svdvals(summed).mean()
        return loss / len(weight_sums)

    def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
        """Compute the overall and per-class cross-entropy losses.

        Args:
            data_dict: must hold ``"label"``, integer targets of shape [batch_size]
                (0 = real, 1 = fake).
            pred_dict: must hold ``"logits"`` of shape [batch_size, num_classes].

        Returns:
            Dict with ``"overall"`` (all samples), ``"real_loss"`` and ``"fake_loss"``.
            A per-class entry is a zero tensor when the batch has no sample of that class.
        """
        label = data_dict["label"]
        pred = pred_dict["logits"]
        return {
            "overall": self.loss_func(pred, label),
            "real_loss": self._class_loss(pred, label, 0),
            "fake_loss": self._class_loss(pred, label, 1),
        }

    def _class_loss(self, pred, label, cls):
        """Cross-entropy restricted to samples whose label equals ``cls`` (zero tensor if none)."""
        mask = label == cls
        if not mask.any():
            return torch.tensor(0.0, device=pred.device)
        return self.loss_func(pred[mask], label[mask])

    def forward(self, inputs: torch.Tensor):
        """Return ``(logits, L2-normalized pooled features)`` for ``inputs``."""
        features = self.features(inputs)
        logits = self.classifier(features)
        normalized_features = F.normalize(features, p=2, dim=1)
        return logits, normalized_features
# Custom module to represent the residual using SVD components
class SVDResidualLinear(nn.Module):
    """Linear layer whose weight is a frozen main part plus a trainable SVD residual.

    ``forward`` computes ``F.linear(x, weight_main + U_residual @ diag(S_residual) @ V_residual, bias)``,
    or uses ``weight_main`` alone while no residual is attached. The SVD factors
    (``U_residual``/``S_residual``/``V_residual`` and the frozen ``U_r``/``S_r``/``V_r``)
    are attached after construction (e.g. by ``replace_with_svd_residual``);
    until then they are ``None`` and the loss helpers return ``0.0``.

    Args:
        in_features: input size.
        out_features: output size.
        r: number of leading singular components kept in the frozen main weight
            (stored for reference; not used by this class).
        bias: whether to create a zero-initialised bias.
        init_weight: optional ``(out_features, in_features)`` tensor copied into
            ``weight_main``; Kaiming-uniform initialisation otherwise.
    """

    def __init__(self, in_features, out_features, r, bias=True, init_weight=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.r = r
        # Frozen main weight.
        self.weight_main = nn.Parameter(torch.empty(out_features, in_features), requires_grad=False)
        if init_weight is not None:
            self.weight_main.data.copy_(init_weight)
        else:
            nn.init.kaiming_uniform_(self.weight_main, a=math.sqrt(5))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter("bias", None)
        # SVD factors default to None so every method works before a residual is attached.
        for name in ("S_residual", "U_residual", "V_residual", "S_r", "U_r", "V_r"):
            self.register_parameter(name, None)
        self.weight_original_fnorm = None
        self.weight_main_fnorm = None

    def _residual_weight(self):
        """Reconstruct ``U_residual @ diag(S_residual) @ V_residual``."""
        return self.U_residual @ torch.diag(self.S_residual) @ self.V_residual

    def compute_current_weight(self):
        """Return ``weight_main`` plus the residual (``weight_main`` alone if none is attached)."""
        if self.S_residual is None:
            return self.weight_main
        return self.weight_main + self._residual_weight()

    def forward(self, x):
        return F.linear(x, self.compute_current_weight(), self.bias)

    def compute_orthogonal_loss(self):
        """Frobenius penalty pushing the full SVD factors towards orthonormality.

        With U = [U_r | U_residual] (out x k) and Vh = [V_r ; V_residual] (k x in),
        returns ``0.5 * ||UᵀU - I_k||_F + 0.5 * ||Vh Vhᵀ - I_k||_F``, or ``0.0`` when
        no residual is attached.
        """
        if self.S_residual is None:
            return 0.0
        U = torch.cat((self.U_r, self.U_residual), dim=1)
        Vh = torch.cat((self.V_r, self.V_residual), dim=0)
        # UᵀU (k x k) rather than UUᵀ (out x out): the latter cannot equal the
        # identity when out > k, and both give the same norm for square U.
        UtU = U.t() @ U
        VVt = Vh @ Vh.t()
        eye = torch.eye(UtU.size(0), device=UtU.device, dtype=UtU.dtype)
        return 0.5 * torch.norm(UtU - eye, p="fro") + 0.5 * torch.norm(VVt - eye, p="fro")

    def compute_keepsv_loss(self):
        """``| ||W_current||_F² - ||W_original||_F² |``, or ``0.0`` if residual/original norm is missing."""
        if self.S_residual is None or self.weight_original_fnorm is None:
            return 0.0
        current_fnorm = torch.norm(self.compute_current_weight(), p="fro")
        return torch.abs(current_fnorm**2 - self.weight_original_fnorm**2)

    def compute_fn_loss(self):
        """Squared Frobenius norm of the current weight, or ``0.0`` if no residual is attached."""
        if self.S_residual is None:
            return 0.0
        return torch.norm(self.compute_current_weight(), p="fro") ** 2
# Function to replace nn.Linear modules within self_attn modules with SVDResidualLinear
def apply_svd_residual_to_self_attn(model, r):
    """Swap every nn.Linear under a ``self_attn`` child for an SVD-residual layer, then freeze the rest.

    Children of ``model`` are walked recursively. For any child whose name
    contains ``"self_attn"``, every ``nn.Linear`` inside it (at any depth) is
    replaced in place by ``replace_with_svd_residual(linear, r)``. Afterwards
    only parameters whose names contain ``S_residual``, ``U_residual`` or
    ``V_residual`` keep ``requires_grad=True``; all others are frozen.

    Args:
        model: module modified in place.
        r: number of leading singular components kept frozen in each replaced layer.

    Returns:
        The same ``model`` object.
    """
    _replace_self_attn_linears(model, r)
    # Freeze once for the whole tree (instead of once per recursion level).
    trainable = ("S_residual", "U_residual", "V_residual")
    for param_name, param in model.named_parameters():
        param.requires_grad = any(key in param_name for key in trainable)
    return model


def _replace_self_attn_linears(model, r):
    """Recursively replace nn.Linear layers inside ``self_attn`` children of ``model``."""
    for name, module in model.named_children():
        if "self_attn" not in name:
            _replace_self_attn_linears(module, r)
            continue
        # Snapshot the Linear layers before mutating the module tree.
        linears = [(sub_name, sub) for sub_name, sub in module.named_modules() if isinstance(sub, nn.Linear)]
        for sub_name, linear in linears:
            *parent_path, attr = sub_name.split(".")
            parent = module
            for part in parent_path:
                parent = getattr(parent, part)
            setattr(parent, attr, replace_with_svd_residual(linear, r))
# Function to replace a module with SVDResidualLinear
def replace_with_svd_residual(module, r):
    """Return an SVDResidualLinear equivalent to ``module``, or ``module`` itself if it is not nn.Linear.

    The weight W = U diag(S) Vh (thin SVD) is split at rank ``min(r, len(S))``.
    The leading components form the frozen ``weight_main`` and are also stored
    as non-trainable ``U_r``/``S_r``/``V_r``. The remaining components become
    the ``U_residual``/``S_residual``/``V_residual`` parameters. If no component
    remains, all six attributes are set to ``None``. The bias is copied, and
    ``weight_original_fnorm``/``weight_main_fnorm`` record the Frobenius norms of
    W and of the main weight.
    """
    if not isinstance(module, nn.Linear):
        return module

    weight = module.weight.detach()
    has_bias = module.bias is not None

    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
    rank = min(r, len(S))  # r may exceed the number of singular values
    U_r = U[:, :rank]  # (out_features, rank)
    S_r = S[:rank]  # (rank,)
    Vh_r = Vh[:rank, :]  # (rank, in_features)
    weight_main = U_r @ torch.diag(S_r) @ Vh_r

    new_module = SVDResidualLinear(module.in_features, module.out_features, r, bias=has_bias, init_weight=weight_main)
    # Build on the source layer's device/dtype so the frozen weight and the
    # residual factors (taken from U/S/Vh below) are co-located.
    new_module.to(device=weight.device, dtype=weight.dtype)
    if has_bias:
        new_module.bias.data.copy_(module.bias.data)
    new_module.weight_original_fnorm = torch.norm(weight, p="fro")
    new_module.weight_main_fnorm = torch.norm(weight_main, p="fro")

    U_residual = U[:, rank:]  # (out_features, n - rank)
    S_residual = S[rank:]  # (n - rank,)
    Vh_residual = Vh[rank:, :]  # (n - rank, in_features)
    if len(S_residual) > 0:
        new_module.S_residual = nn.Parameter(S_residual.clone())
        new_module.U_residual = nn.Parameter(U_residual.clone())
        new_module.V_residual = nn.Parameter(Vh_residual.clone())
        new_module.S_r = nn.Parameter(S_r.clone(), requires_grad=False)
        new_module.U_r = nn.Parameter(U_r.clone(), requires_grad=False)
        new_module.V_r = nn.Parameter(Vh_r.clone(), requires_grad=False)
    else:
        for name in ("S_residual", "U_residual", "V_residual", "S_r", "U_r", "V_r"):
            setattr(new_module, name, None)
    return new_module
# Preprocessing used in the original Effort code; its results are nearly
# identical to those of `preprocessing`.
_preprocessing_original = T.Compose(
    transforms=[
        # Array/PIL image (HxWxC) -> float tensor (CxHxW).
        T.ToTensor(),
        # Per-channel CLIP normalization statistics.
        T.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        ),
    ]
)
def preprocessing_original(image: Image.Image) -> torch.Tensor:
    """Preprocess an image the way the original Effort code does.

    PIL images are converted to RGB first. The image is then resized to
    224x224 with OpenCV bilinear interpolation and passed through
    ``ToTensor`` (which scales uint8 pixels to [0, 1]) and CLIP normalization.

    Args:
        image: a PIL image, or an array-like image (assumed HxWx3).

    Returns:
        A float tensor of shape (3, 224, 224) for 3-channel input.
    """
    if isinstance(image, Image.Image):
        # Grayscale/palette/RGBA images would otherwise yield 1 or 4 channels,
        # which the 3-channel Normalize step cannot handle.
        image = image.convert("RGB")
    array = np.array(image)
    array = cv2.resize(array, (224, 224), interpolation=cv2.INTER_LINEAR)
    return _preprocessing_original(array)