File size: 18,695 Bytes
8bc3305 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 | import os
import math
import datetime
import logging
import numpy as np
from sklearn import metrics
from typing import Union
from collections import defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import DataParallel
from torch.utils.tensorboard import SummaryWriter
from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train
from .base_detector import AbstractDetector
from detectors import DETECTOR
from networks import BACKBONE
from loss import LOSSFUNC
import albumentations as A
import loralib as lora
from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig
logger = logging.getLogger(__name__)
def get_clip_visual(model_name="openai/clip-vit-base-patch16"):
    """Load a Hugging Face CLIP checkpoint and keep only its image side.

    Args:
        model_name: checkpoint id or local path handed to ``from_pretrained``.

    Returns:
        ``(processor, vision_tower)``: the ``AutoProcessor`` for the checkpoint
        and the ``vision_model`` attribute of the loaded ``CLIPModel``.
    """
    processor = AutoProcessor.from_pretrained(model_name)
    vision_tower = CLIPModel.from_pretrained(model_name).vision_model
    return processor, vision_tower
def shuffle_patches(images: torch.Tensor, patch_size: int = 14) -> torch.Tensor:
    """Randomly permute the non-overlapping square patches of each image.

    Each image is cut into a ``(H // patch_size) x (W // patch_size)`` grid of
    ``patch_size x patch_size`` tiles (e.g. 14 for ViT-L/14, 16 for ViT-B/16).
    The tiles are reordered with an independent random permutation per image.
    Pixels inside a tile keep their relative positions.

    Args:
        images: tensor of shape ``(B, C, H, W)``; ``B`` may be 0.
        patch_size: tile side length; must be positive and divide ``H`` and ``W``.

    Returns:
        A tensor with the same shape as ``images``.

    Raises:
        ValueError: if ``patch_size`` is not positive or does not divide ``H``/``W``.
    """
    B, C, H, W = images.shape
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    if H % patch_size or W % patch_size:
        raise ValueError(
            f"H ({H}) and W ({W}) must be divisible by patch_size ({patch_size})"
        )
    grid_h, grid_w = H // patch_size, W // patch_size
    num_patches = grid_h * grid_w

    # (B, C, H, W) -> (B, num_patches, C, p, p), patches in row-major grid order.
    patches = (
        images.reshape(B, C, grid_h, patch_size, grid_w, patch_size)
        .permute(0, 2, 4, 1, 3, 5)
        .reshape(B, num_patches, C, patch_size, patch_size)
    )

    # One uniform random permutation per image: argsort of i.i.d. uniform keys.
    # Unlike stacking per-image randperm() results, this also works for B == 0.
    perms = torch.rand(B, num_patches, device=images.device).argsort(dim=1)
    batch_idx = torch.arange(B, device=images.device).unsqueeze(1)  # broadcasts to (B, num_patches)
    patches = patches[batch_idx, perms]

    # Undo the patch extraction: (B, num_patches, C, p, p) -> (B, C, H, W).
    return (
        patches.reshape(B, grid_h, grid_w, C, patch_size, patch_size)
        .permute(0, 3, 1, 4, 2, 5)
        .reshape(B, C, H, W)
    )
def get_aug_transform():
    """Build the Albumentations pipeline used by ``data_aug``.

    Returns an ``A.Compose`` over the transforms listed below, in that order,
    each configured with the probability ``p`` shown.

    NOTE(review): ``A.Cutout`` and the ``quality_lower``/``quality_upper``
    arguments of ``A.ImageCompression`` are deprecated in newer Albumentations
    releases — confirm the pinned version still accepts them.
    """
    pipeline = [
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        A.HueSaturationValue(p=0.3),
        A.ImageCompression(quality_lower=40, quality_upper=100, p=0.1),
        A.GaussNoise(p=0.1),
        A.MotionBlur(p=0.1),
        A.CLAHE(p=0.1),
        A.ChannelShuffle(p=0.1),
        A.Cutout(p=0.1),
        A.RandomGamma(p=0.3),
        A.GlassBlur(p=0.3),
    ]
    return A.Compose(pipeline)
def data_aug(images: torch.Tensor, transform=None) -> torch.Tensor:
    """Apply an Albumentations pipeline to every image of a batch.

    Args:
        images: float tensor ``(B, C, H, W)``, values expected in ``[0, 1]``.
            Values outside that range are clipped (not wrapped) by the 8-bit
            conversion.
        transform: optional callable using the Albumentations convention
            ``transform(image=hwc_uint8_array)["image"]``. Defaults to a fresh
            ``get_aug_transform()``; pass one in to reuse it across calls.

    Returns:
        A ``float32`` tensor in ``[0, 1]`` laid out as ``(B, C, H', W')``, where
        the spatial size is whatever ``transform`` produces. It lives on the
        same device as ``images`` and is detached from autograd.
    """
    if images.shape[0] == 0:
        # Nothing to augment; np.stack would fail on an empty list.
        return torch.empty(images.shape, dtype=torch.float32, device=images.device)
    if transform is None:
        transform = get_aug_transform()

    # BCHW float [0, 1] -> BHWC uint8 [0, 255]. Round instead of truncating, and
    # clip so out-of-range values saturate instead of wrapping on the uint8 cast.
    batch = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
    batch = np.clip(np.rint(batch * 255.0), 0, 255).astype(np.uint8)

    # Albumentations transforms take a single HWC image per call.
    augmented = np.stack([transform(image=img)["image"] for img in batch])

    out = torch.from_numpy(augmented).permute(0, 3, 1, 2).float() / 255.0
    # Return to the input's exact device (e.g. cuda:1), not the current default CUDA device.
    return out.to(images.device)
@DETECTOR.register_module(module_name='effort_shuffle_ensemble')
class Effort_Shuffle_Ensenble_Detector(nn.Module):
    """CLIP vision-tower detector trained on patch-shuffled images.

    Backbone:
        The vision tower of ``CLIPModel.from_pretrained(config['pretrained'])``.
        Its self-attention ``nn.Linear`` layers are replaced by
        ``SVDResidualLinear`` with ``r = hidden_size - 1``, so only the smallest
        singular direction of each projection is trainable.

    Training vs. evaluation:
        In training mode the images are patch-shuffled before the backbone; in
        eval mode they are used unchanged.

    Head:
        A linear head maps the backbone's ``pooler_output`` to
        ``config['backbone_config']['num_classes']`` logits.

    The class name keeps its historical spelling ("Ensenble") because it may be
    referenced elsewhere.
    """

    # Fallbacks for backbones that do not expose a Hugging Face ``config``;
    # these are the ViT-L/14 values the detector was originally written for.
    _DEFAULT_HIDDEN_SIZE = 1024
    _DEFAULT_PATCH_SIZE = 14

    def __init__(self, config=None):
        super().__init__()
        self.config = config
        self.backbone = self.build_backbone(config)
        # Size the head from the backbone's feature width rather than a fixed 1024,
        # so checkpoints other than ViT-L/14 also get a matching head.
        feat_dim = self._config_value(self.backbone, 'hidden_size', self._DEFAULT_HIDDEN_SIZE)
        self.head = nn.Linear(feat_dim, config['backbone_config']['num_classes'])
        self.loss_func = nn.CrossEntropyLoss()
        # Not read anywhere in this class; kept for callers that may rely on them.
        self.prob, self.label = [], []
        self.correct, self.total = 0, 0

    @staticmethod
    def _config_value(vision_model, name, default):
        """Return ``vision_model.config.<name>``, or ``default`` if unavailable."""
        return getattr(getattr(vision_model, 'config', None), name, default)

    def build_clip_backbone(self, config):
        """Return the plain (non-SVD) CLIP vision tower for ``config['pretrained']``.

        Not called by this class.
        """
        _, backbone = get_clip_visual(model_name=config['pretrained'])
        return backbone

    def build_backbone(self, config):
        """Load the CLIP vision tower and wrap its self-attention in SVD residuals.

        Reference checkpoint: https://huggingface.co/openai/clip-vit-large-patch14
        (ViT-L/14 at 224x224; mean [0.48145466, 0.4578275, 0.40821073],
        std [0.26862954, 0.26130258, 0.27577711]).

        Note: the checkpoint is read from ``self.config["pretrained"]``; the
        ``config`` argument is not consulted here.
        """
        clip_model = CLIPModel.from_pretrained(self.config["pretrained"])
        vision_model = clip_model.vision_model
        hidden_size = self._config_value(vision_model, 'hidden_size', self._DEFAULT_HIDDEN_SIZE)
        # Keep the top (hidden_size - 1) singular directions of each self-attention
        # projection frozen and train the remaining one (ViT-L/14: r = 1024 - 1).
        vision_model = apply_svd_residual_to_self_attn(vision_model, r=hidden_size - 1)
        for name, param in vision_model.named_parameters():
            logger.debug('%s: %s', name, param.requires_grad)
        num_param = sum(p.numel() for p in vision_model.parameters() if p.requires_grad)
        num_total_param = sum(p.numel() for p in vision_model.parameters())
        logger.info('Number of total parameters: %d, tunable parameters: %d', num_total_param, num_param)
        return vision_model

    def features(self, data_dict: dict) -> torch.Tensor:
        """Return the backbone's ``pooler_output`` for ``data_dict['image']``.

        ``data_dict['image']`` is an image batch ``(B, C, H, W)``. In training
        mode its patches are shuffled first, so ``H`` and ``W`` must be
        divisible by the backbone's patch size.
        """
        images = data_dict['image']
        if self.training:
            # Shuffle with the backbone's own patch size so tiles align with the ViT grid.
            patch_size = self._config_value(self.backbone, 'patch_size', self._DEFAULT_PATCH_SIZE)
            images = shuffle_patches(images, patch_size)
        return self.backbone(images)['pooler_output']

    def classifier(self, features: torch.Tensor) -> torch.Tensor:
        """Map pooled features to class logits."""
        return self.head(features)

    def compute_weight_loss(self):
        """Return the negated mean singular value of summed SVD-layer weights.

        Current weights of all ``SVDResidualLinear`` layers in the backbone are
        summed per weight shape. For each shape the mean singular value of the
        sum is negated, and these terms are averaged over shapes. Returns
        ``0.0`` when the backbone has no ``SVDResidualLinear`` layer. Not used
        by ``get_losses``.
        """
        weight_sums = {}
        for module in self.backbone.modules():
            if isinstance(module, SVDResidualLinear):
                weight = module.compute_current_weight()
                key = tuple(weight.shape)
                # Out-of-place add: ``weight`` can be the frozen ``weight_main``
                # Parameter itself, which an in-place ``+=`` would overwrite.
                weight_sums[key] = weight if key not in weight_sums else weight_sums[key] + weight
        if not weight_sums:
            return 0.0
        loss = 0.0
        for weight_sum in weight_sums.values():
            _, singular_values, _ = torch.linalg.svd(weight_sum, full_matrices=False)
            loss = loss - torch.mean(singular_values)
        return loss / len(weight_sums)

    def _per_class_loss(self, pred, label, cls):
        """Cross-entropy over the samples labelled ``cls``; 0 if there are none."""
        mask = label == cls
        if not mask.any():
            return torch.tensor(0.0, device=pred.device)
        return self.loss_func(pred[mask], label[mask])

    def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
        """Return the overall loss plus per-class diagnostic losses.

        ``'overall'`` is cross-entropy over the whole batch and is the only term
        that combines all samples. ``'real_loss'`` / ``'fake_loss'`` restrict it
        to samples with label 0 / 1 and are 0 when the batch has none.
        """
        label = data_dict['label']  # (batch,)
        pred = pred_dict['cls']  # (batch, num_classes)
        return {
            'overall': self.loss_func(pred, label),
            'real_loss': self._per_class_loss(pred, label, 0),
            'fake_loss': self._per_class_loss(pred, label, 1),
        }

    def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
        """Return batch ``acc`` and ``mAP`` from ``calculate_acc_for_train``."""
        label = data_dict['label']
        pred = pred_dict['cls']
        acc, mAP = calculate_acc_for_train(
            label.detach(), pred.detach(), self.config['backbone_config']['num_classes']
        )
        return {'acc': acc, 'mAP': mAP}

    def forward(self, data_dict: dict, inference=False) -> dict:
        """Run backbone and head.

        Returns a dict with ``'cls'`` (logits), ``'prob'`` (softmax over dim 1)
        and ``'feat'`` (pooled backbone features). ``inference`` is currently
        unused.
        """
        features = self.features(data_dict)
        pred = self.classifier(features)
        prob = torch.softmax(pred, dim=1)
        return {'cls': pred, 'prob': prob, 'feat': features}
# Custom module to represent the residual using SVD components
class SVDResidualLinear(nn.Module):
    """Linear layer whose weight is a frozen main part plus an SVD residual.

    ``weight = weight_main + U_residual @ diag(S_residual) @ V_residual``

    ``weight_main`` is frozen (``requires_grad=False``). The constructor does
    not create the following attributes; they start as ``None`` and are filled
    in by ``replace_with_svd_residual``:

    - the residual factors ``S_residual``, ``U_residual``, ``V_residual``;
    - the frozen top-``r`` factors ``S_r``, ``U_r``, ``V_r``;
    - the norms ``weight_original_fnorm`` and ``weight_main_fnorm``.

    While ``S_residual`` is ``None`` the layer acts as a plain linear layer
    using ``weight_main``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        r: number of top singular components represented by ``weight_main``.
        bias: whether to add a learnable bias (initialized to zero).
        init_weight: optional ``(out_features, in_features)`` tensor copied into
            ``weight_main``; otherwise Kaiming-uniform initialization is used.
    """

    _FACTOR_NAMES = ('S_residual', 'U_residual', 'V_residual', 'S_r', 'U_r', 'V_r')

    def __init__(self, in_features, out_features, r, bias=True, init_weight=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.r = r
        self.weight_main = nn.Parameter(torch.empty(out_features, in_features), requires_grad=False)
        if init_weight is not None:
            self.weight_main.data.copy_(init_weight)
        else:
            nn.init.kaiming_uniform_(self.weight_main, a=math.sqrt(5))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)
        # Empty parameter slots so every method can test ``self.S_residual is None``
        # on a freshly built layer instead of raising AttributeError. Assigning an
        # nn.Parameter (or None) to these names later fills the slot in place.
        for name in self._FACTOR_NAMES:
            self.register_parameter(name, None)
        # Frobenius norms of the original / main weight (plain tensors, set externally).
        self.weight_original_fnorm = None
        self.weight_main_fnorm = None

    def extra_repr(self):
        return (f'in_features={self.in_features}, out_features={self.out_features}, '
                f'r={self.r}, bias={self.bias is not None}')

    def _residual_weight(self):
        """Return ``U_residual @ diag(S_residual) @ V_residual``."""
        # Scaling U's columns by S equals U @ diag(S) without building the diagonal matrix.
        return (self.U_residual * self.S_residual) @ self.V_residual

    def compute_current_weight(self):
        """Return the effective weight; ``weight_main`` itself when no residual is attached."""
        if self.S_residual is None:
            return self.weight_main
        return self.weight_main + self._residual_weight()

    def forward(self, x):
        return F.linear(x, self.compute_current_weight(), self.bias)

    def compute_orthogonal_loss(self):
        """Frobenius penalty pushing the full SVD bases towards orthonormality.

        ``U = [U_r | U_residual]`` should have orthonormal columns (``U^T U = I``)
        and ``V = [V_r ; V_residual]`` orthonormal rows (``V V^T = I``).
        Returns ``0.0`` when no residual is attached.
        """
        if self.S_residual is None:
            return 0.0
        U = torch.cat((self.U_r, self.U_residual), dim=1)  # (out_features, k)
        V = torch.cat((self.V_r, self.V_residual), dim=0)  # (k, in_features)
        # Gram matrices over the k basis vectors. U @ U^T would only be the
        # identity when U is square; U^T @ U is correct for any out_features >= k.
        gram_u = U.t() @ U
        gram_v = V @ V.t()
        eye_u = torch.eye(gram_u.size(0), device=gram_u.device, dtype=gram_u.dtype)
        eye_v = torch.eye(gram_v.size(0), device=gram_v.device, dtype=gram_v.dtype)
        return 0.5 * torch.norm(gram_u - eye_u, p='fro') + 0.5 * torch.norm(gram_v - eye_v, p='fro')

    def compute_keepsv_loss(self):
        """|‖W_current‖_F² − ‖W_original‖_F²|; ``0.0`` if residual or original norm is missing."""
        if self.S_residual is None or self.weight_original_fnorm is None:
            return 0.0
        weight_current_fnorm = torch.norm(self.compute_current_weight(), p='fro')
        return torch.abs(weight_current_fnorm ** 2 - self.weight_original_fnorm ** 2)

    def compute_fn_loss(self):
        """Squared Frobenius norm of the current weight; ``0.0`` without a residual."""
        if self.S_residual is None:
            return 0.0
        return torch.norm(self.compute_current_weight(), p='fro') ** 2
# Function to replace nn.Linear modules within self_attn modules with SVDResidualLinear
def apply_svd_residual_to_self_attn(model, r):
    """Swap every ``nn.Linear`` under a ``self_attn`` child for ``SVDResidualLinear``.

    Replacement:
        The module tree is walked recursively. For every direct child whose
        name contains ``'self_attn'``, each ``nn.Linear`` inside it (at any
        depth) is replaced by ``replace_with_svd_residual(linear, r)``.

    Freezing:
        Afterwards, exactly the parameters of ``model`` whose names contain
        ``S_residual``, ``U_residual`` or ``V_residual`` are trainable; all
        others are frozen.

    ``model`` is modified in place and returned.
    """
    _replace_self_attn_linears(model, r)
    # Freeze once at the top level; the recursive helper only swaps modules.
    trainable_keys = ('S_residual', 'U_residual', 'V_residual')
    for param_name, param in model.named_parameters():
        param.requires_grad = any(key in param_name for key in trainable_keys)
    return model


def _replace_self_attn_linears(model, r):
    """Recursively replace ``nn.Linear`` layers inside ``self_attn`` children (in place)."""
    for name, module in model.named_children():
        if 'self_attn' not in name:
            _replace_self_attn_linears(module, r)
            continue
        # Collect first so the module tree is not mutated while it is being iterated.
        linears = [(path, sub) for path, sub in module.named_modules() if isinstance(sub, nn.Linear)]
        for path, linear in linears:
            *parent_path, attr = path.split('.')
            parent = module
            for part in parent_path:
                parent = getattr(parent, part)
            setattr(parent, attr, replace_with_svd_residual(linear, r))
# Function to replace a module with SVDResidualLinear
def replace_with_svd_residual(module, r):
    """Return an ``SVDResidualLinear`` built from ``module`` if it is ``nn.Linear``.

    Decomposition:
        The linear weight ``W`` is decomposed with ``torch.linalg.svd`` as
        ``U @ diag(S) @ Vh``, and ``r`` is clamped to ``len(S)``.

    Resulting attributes:
        - The top ``r`` components form the frozen ``weight_main``.
        - The remaining ``len(S) - r`` components become the trainable
          ``S_residual``, ``U_residual`` and ``V_residual``.
        - The top-``r`` factors are stored as frozen ``S_r``, ``U_r``, ``V_r``.
        - If no residual remains, all six factor attributes are ``None``.
        - The bias is copied, and ``weight_original_fnorm`` /
          ``weight_main_fnorm`` are recorded.

    The new layer is placed on the same device as ``module.weight``.
    Non-``nn.Linear`` modules are returned unchanged.
    """
    if not isinstance(module, nn.Linear):
        return module

    weight = module.weight.detach()
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
    r = min(r, len(S))  # r cannot exceed the number of singular values

    # Top-r singular components -> frozen main weight.
    U_r, S_r, Vh_r = U[:, :r], S[:r], Vh[:r, :]  # (out, r), (r,), (r, in)
    weight_main = U_r @ torch.diag(S_r) @ Vh_r
    # Remaining components -> trainable residual.
    U_residual, S_residual, Vh_residual = U[:, r:], S[r:], Vh[r:, :]  # (out, n-r), (n-r,), (n-r, in)

    has_bias = module.bias is not None
    new_module = SVDResidualLinear(
        module.in_features, module.out_features, r, bias=has_bias, init_weight=weight_main
    )
    if has_bias:
        new_module.bias.data.copy_(module.bias.data)
    new_module.weight_original_fnorm = torch.norm(weight, p='fro')
    new_module.weight_main_fnorm = torch.norm(weight_main, p='fro')

    if len(S_residual) > 0:
        new_module.S_residual = nn.Parameter(S_residual.clone())
        new_module.U_residual = nn.Parameter(U_residual.clone())
        new_module.V_residual = nn.Parameter(Vh_residual.clone())
        new_module.S_r = nn.Parameter(S_r.clone(), requires_grad=False)
        new_module.U_r = nn.Parameter(U_r.clone(), requires_grad=False)
        new_module.V_r = nn.Parameter(Vh_r.clone(), requires_grad=False)
    else:
        new_module.S_residual = None
        new_module.U_residual = None
        new_module.V_residual = None
        new_module.S_r = None
        new_module.U_r = None
        new_module.V_r = None

    # SVDResidualLinear allocates weight_main/bias on the default device while the
    # SVD factors live on the source layer's device; align everything with the source.
    return new_module.to(module.weight.device)
|