"""
# author: Zhiyuan Yan
# email: zhiyuanyan@link.cuhk.edu.cn
# date: 2023-07-06
# description: Detector built on a CLIP vision backbone and trained on randomly
#              patch-shuffled images (registered as 'clip_patch_shuffle').

Module-level functions:
1. shuffle_patches: Per-image random patch shuffling (applied in training mode only)
2. get_clip_visual: Load a Hugging Face CLIP processor and its vision tower

Methods of CLIP_PATCH_SHUFFLE_Detector:
1. __init__: Initialization
2. build_backbone: Backbone-building
3. build_loss: Loss-function-building
4. features: Feature-extraction
5. classifier: Classification
6. get_losses: Loss-computation
7. get_train_metrics: Training-metrics-computation
8. forward: Forward-propagation

Reference:
@inproceedings{radford2021learning,
  title={Learning transferable visual models from natural language supervision},
  author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others},
  booktitle={International conference on machine learning},
  pages={8748--8763},
  year={2021},
  organization={PMLR}
}
"""

import copy
import datetime
import logging
import os
from collections import defaultdict
from typing import Optional, Sequence, Union

import loralib as lora
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn import metrics
from torch.nn import DataParallel
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig

from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train
from .base_detector import AbstractDetector
from detectors import DETECTOR
from networks import BACKBONE
from loss import LOSSFUNC

logger = logging.getLogger(__name__)


def _shuffle_single_image(image: torch.Tensor, ps: int) -> torch.Tensor:
    """Randomly permute the ps x ps patches of one [C, H, W] image (H and W divisible by ps)."""
    channels, height, width = image.shape
    rows, cols = height // ps, width // ps
    num_patches = rows * cols
    # Split: [C, H, W] -> [C, rows, ps, cols, ps] -> [rows, cols, C, ps, ps] -> [N, C, ps, ps]
    patches = (
        image.reshape(channels, rows, ps, cols, ps)
        .permute(1, 3, 0, 2, 4)
        .reshape(num_patches, channels, ps, ps)
    )
    # With a single patch (ps == H == W) randperm(1) is [0], so the image is unchanged.
    perm = torch.randperm(num_patches, device=image.device)
    patches = patches[perm]
    # Inverse of the split: [N, C, ps, ps] -> [rows, cols, C, ps, ps] -> [C, H, W]
    return (
        patches.reshape(rows, cols, channels, ps, ps)
        .permute(2, 0, 3, 1, 4)
        .reshape(channels, height, width)
    )


def shuffle_patches(
    images: torch.Tensor,
    patch_levels: Sequence[int] = (14, 56, 224),
    probs: Optional[Sequence[float]] = None,
) -> torch.Tensor:
    """Shuffle non-overlapping square patches of each image, with a per-image patch size.

    Each image in the batch independently samples one patch size ``ps`` from
    ``patch_levels`` (weighted by ``probs``). It is cut into a grid of ``ps x ps``
    patches, and those patches are randomly permuted. A patch size equal to the full
    image side gives a single patch, so that image is left unchanged.

    Args:
        images: Tensor of shape [B, C, H, W]. Every entry of ``patch_levels`` must evenly
            divide both H and W (e.g. H = W = 224 for the default levels).
        patch_levels: Candidate patch sizes. Any sequence is accepted; the default is a
            tuple to avoid a shared mutable default.
        probs: Optional non-negative sampling weights, one per entry of ``patch_levels``
            (they need not sum to 1). ``None`` samples the levels uniformly. The previous
            hard-coded weights were [0.33, 0.33, 0.34], which only worked for exactly
            three levels.

    Returns:
        A new tensor with the same shape, dtype and device as ``images``.

    Raises:
        ValueError: if ``images`` is not 4-D, ``patch_levels`` is empty, a patch size does
            not evenly divide H and W, or ``probs`` has a different length than
            ``patch_levels``.
    """
    if images.dim() != 4:
        raise ValueError(f"expected images of shape [B, C, H, W], got {tuple(images.shape)}")
    levels = [int(ps) for ps in patch_levels]
    if not levels:
        raise ValueError("patch_levels must contain at least one patch size")

    batch_size, _, height, width = images.shape
    # Fail fast and deterministically. Otherwise a bad level only fails when it happens to
    # be sampled, and then with an opaque reshape error.
    for ps in levels:
        if ps <= 0 or height % ps or width % ps:
            raise ValueError(
                f"patch size {ps} does not evenly divide image size {height}x{width}"
            )

    if probs is None:
        weights = torch.ones(len(levels))
    else:
        weights = torch.as_tensor(probs, dtype=torch.float).reshape(-1)
        if weights.numel() != len(levels):
            raise ValueError(
                f"probs has {weights.numel()} entries but patch_levels has {len(levels)}"
            )

    shuffled_images = torch.empty_like(images)
    if batch_size == 0:
        return shuffled_images

    # One batched draw of level indices instead of one multinomial call per image.
    level_indices = torch.multinomial(weights, batch_size, replacement=True).tolist()
    for b, level_idx in enumerate(level_indices):
        shuffled_images[b] = _shuffle_single_image(images[b], levels[level_idx])
    return shuffled_images


@DETECTOR.register_module(module_name='clip_patch_shuffle')
class CLIP_PATCH_SHUFFLE_Detector(AbstractDetector):
    """Linear head on CLIP vision ``pooler_output``; inputs are patch-shuffled in training.

    Config keys read by this class: ``'pretrained'`` (model name passed to
    ``get_clip_visual``), ``'loss_func'`` (key into ``LOSSFUNC``) and
    ``['backbone_config']['num_classes']``.
    """

    # Patch sizes used by shuffle_patches in training mode. The previous code passed
    # [14, 14, 14], i.e. always 14; a single level expresses exactly that.
    TRAIN_PATCH_LEVELS = (14,)
    # Head input width used when the backbone exposes no ``config.hidden_size``
    # (the previously hard-coded value).
    DEFAULT_FEATURE_DIM = 1024

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.backbone = self.build_backbone(config)
        # Size the head from the backbone instead of hard-coding 1024. ViT-B CLIP models
        # (e.g. get_clip_visual's default) produce 768-d pooled features.
        self.head = nn.Linear(
            self._feature_dim(self.backbone), config['backbone_config']['num_classes']
        )
        self.loss_func = self.build_loss(config)

    @classmethod
    def _feature_dim(cls, backbone) -> int:
        """Width of ``backbone(x)['pooler_output']``, read from the backbone's HF config."""
        hidden_size = getattr(getattr(backbone, 'config', None), 'hidden_size', None)
        return int(hidden_size) if hidden_size is not None else cls.DEFAULT_FEATURE_DIM

    def build_backbone(self, config):
        """Return the CLIP vision tower for ``config['pretrained']`` (the processor is discarded)."""
        _, backbone = get_clip_visual(model_name=config['pretrained'])
        return backbone

    def build_loss(self, config):
        """Instantiate the loss class registered in ``LOSSFUNC`` under ``config['loss_func']``."""
        loss_class = LOSSFUNC[config['loss_func']]
        return loss_class()

    def features(self, data_dict: dict) -> torch.Tensor:
        """Return the backbone's ``pooler_output`` for ``data_dict['image']``.

        Patch shuffling is applied only when the module is in training mode.
        """
        x = data_dict['image']
        if self.training:
            x = shuffle_patches(x, patch_levels=self.TRAIN_PATCH_LEVELS)
        return self.backbone(x)['pooler_output']

    def classifier(self, features: torch.Tensor) -> torch.Tensor:
        """Map pooled features to per-class logits with the linear head."""
        return self.head(features)

    def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
        """Compute the configured loss between logits ``pred_dict['cls']`` and ``data_dict['label']``."""
        loss = self.loss_func(pred_dict['cls'], data_dict['label'])
        return {'overall': loss}

    def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
        """Return batch accuracy and mAP computed by ``calculate_acc_for_train``."""
        label = data_dict['label']
        pred = pred_dict['cls']
        acc, mAP = calculate_acc_for_train(
            label.detach(), pred.detach(), self.config['backbone_config']['num_classes']
        )
        return {'acc': acc, 'mAP': mAP}

    def forward(self, data_dict: dict, inference=False) -> dict:
        """Run features -> classifier and return logits, softmax probabilities and features.

        ``inference`` is accepted for interface compatibility; it is not used here.
        """
        features = self.features(data_dict)
        pred = self.classifier(features)
        # Softmax over all class columns (not only column 1).
        prob = torch.softmax(pred, dim=1)
        return {'cls': pred, 'prob': prob, 'feat': features}


def get_clip_visual(model_name="openai/clip-vit-base-patch16"):
    """Load a CLIP processor and the vision tower of the matching ``CLIPModel``.

    Returns:
        ``(processor, vision_model)``: ``AutoProcessor.from_pretrained(model_name)`` and
        ``CLIPModel.from_pretrained(model_name).vision_model``.
    """
    processor = AutoProcessor.from_pretrained(model_name)
    model = CLIPModel.from_pretrained(model_name)
    return processor, model.vision_model