| ''' |
| # author: Zhiyuan Yan |
| # email: zhiyuanyan@link.cuhk.edu.cn |
| # date: 2023-07-06 |
| # description: Class for the CLIP_PATCH_SHUFFLE_Detector, a CLIP vision-encoder detector trained on patch-shuffled images |
| |
| Functions in the Class are summarized as: |
| 1. __init__: Initialization |
| 2. build_backbone: Backbone-building |
| 3. build_loss: Loss-function-building |
| 4. features: Feature-extraction |
| 5. classifier: Classification |
| 6. get_losses: Loss-computation |
| 7. get_train_metrics: Training-metrics-computation |
| 8. get_test_metrics: Testing-metrics-computation |
| 9. forward: Forward-propagation |
| |
| Reference: |
| @inproceedings{radford2021learning, |
| title={Learning transferable visual models from natural language supervision}, |
| author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others}, |
| booktitle={International conference on machine learning}, |
| pages={8748--8763}, |
| year={2021}, |
| organization={PMLR} |
| } |
| ''' |
|
|
| import os |
| import datetime |
| import logging |
| import numpy as np |
| from sklearn import metrics |
| from typing import Union |
| from collections import defaultdict |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.optim as optim |
| from torch.nn import DataParallel |
| from torch.utils.tensorboard import SummaryWriter |
|
|
| from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train |
|
|
| from .base_detector import AbstractDetector |
| from detectors import DETECTOR |
| from networks import BACKBONE |
| from loss import LOSSFUNC |
| from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig |
| import loralib as lora |
| import copy |
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
def shuffle_patches(
    images: torch.Tensor,
    patch_levels: Union[list, tuple] = (14, 56, 224),
    probs: Union[list, tuple, None] = None,
) -> torch.Tensor:
    """
    Shuffle non-overlapping square patches of each image, with a random patch size per image.

    For every image in the batch a patch size ``ps`` is drawn from ``patch_levels``
    (weighted by ``probs``). The image is cut into an (H // ps) x (W // ps) grid of
    ``ps`` x ``ps`` patches, the patches are randomly permuted, and the grid is
    reassembled. A patch size equal to both H and W yields a single patch, i.e. the
    image is left unchanged (e.g. 224 for 224x224 inputs).

    Args:
        images: tensor of shape [B, C, H, W].
        patch_levels: candidate patch sizes. Every entry must be a positive integer
            dividing both H and W (the default levels require H and W to be
            multiples of 224, e.g. H = W = 224).
        probs: optional non-negative sampling weights, one per entry of
            ``patch_levels``. Defaults to a uniform choice among the levels.

    Returns:
        A new tensor with the same shape, dtype and device as ``images``.

    Raises:
        ValueError: if ``images`` is not 4-D, ``patch_levels`` is empty, ``probs``
            has the wrong length, or a patch size does not divide H and W.
    """
    if images.dim() != 4:
        raise ValueError(f"expected a [B, C, H, W] tensor, got shape {tuple(images.shape)}")
    levels = [int(ps) for ps in patch_levels]
    if not levels:
        raise ValueError("patch_levels must not be empty")

    B, C, H, W = images.shape
    # Validate every candidate up front so a bad level fails deterministically
    # instead of only on the batches where it happens to be sampled.
    for ps in levels:
        if ps <= 0 or H % ps or W % ps:
            raise ValueError(f"patch size {ps} must be positive and divide the image size {H}x{W}")

    if probs is None:
        weights = torch.ones(len(levels))
    else:
        weights = torch.as_tensor(probs, dtype=torch.float)
        if weights.numel() != len(levels):
            raise ValueError(
                f"probs has {weights.numel()} entries but patch_levels has {len(levels)}"
            )

    shuffled_images = torch.empty_like(images)
    if B == 0:
        # torch.multinomial rejects num_samples == 0; nothing to shuffle anyway.
        return shuffled_images

    # One level per image, sampled in a single CPU call; .tolist() gives plain
    # Python ints so the loop below never syncs with the accelerator.
    choices = torch.multinomial(weights, B, replacement=True).tolist()

    for b, choice in enumerate(choices):
        ps = levels[choice]
        n_h, n_w = H // ps, W // ps
        # [1, C, H, W] -> [1, n_h * n_w, C, ps, ps]: one entry per patch in
        # row-major grid order. reshape (not view) tolerates non-contiguous input.
        patches = (
            images[b:b + 1]
            .reshape(1, C, n_h, ps, n_w, ps)
            .permute(0, 2, 4, 1, 3, 5)
            .reshape(1, n_h * n_w, C, ps, ps)
        )
        perm = torch.randperm(n_h * n_w, device=images.device)
        patches = patches[:, perm]
        # Exact inverse of the patch extraction above.
        shuffled_images[b:b + 1] = (
            patches.reshape(1, n_h, n_w, C, ps, ps)
            .permute(0, 3, 1, 4, 2, 5)
            .reshape(1, C, H, W)
        )

    return shuffled_images
|
|
@DETECTOR.register_module(module_name='clip_patch_shuffle')
class CLIP_PATCH_SHUFFLE_Detector(AbstractDetector):
    """
    Detector built on the vision tower of a pretrained CLIP checkpoint.

    The checkpoint is named by ``config['pretrained']``. The backbone's
    ``pooler_output`` is fed to a linear head with
    ``config['backbone_config']['num_classes']`` outputs. While the module is in
    training mode, input images are patch-shuffled with ``shuffle_patches``
    before being encoded.
    """

    # Patch sizes passed to shuffle_patches during training. Three identical
    # entries mean every training image is shuffled with 14x14-pixel patches.
    train_patch_levels = (14, 14, 14)

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.backbone = self.build_backbone(config)
        # Size the head from the backbone's hidden width instead of a hard-coded
        # 1024, which only fits ViT-L-width CLIP encoders (a ViT-B checkpoint
        # produces 768-d pooled features). Fall back to the previous 1024 if the
        # backbone exposes no config.
        backbone_config = getattr(self.backbone, 'config', None)
        feat_dim = getattr(backbone_config, 'hidden_size', 1024)
        self.head = nn.Linear(feat_dim, config['backbone_config']['num_classes'])
        self.loss_func = self.build_loss(config)

    def build_backbone(self, config):
        """Return the CLIP vision tower of the checkpoint named by config['pretrained']."""
        _, backbone = get_clip_visual(model_name=config['pretrained'])
        return backbone

    def build_loss(self, config):
        """Instantiate the loss class registered in LOSSFUNC under config['loss_func']."""
        loss_class = LOSSFUNC[config['loss_func']]
        return loss_class()

    def features(self, data_dict: dict) -> torch.Tensor:
        """Encode data_dict['image'] (patch-shuffled in training mode); return the pooled output."""
        x = data_dict['image']
        if self.training:
            x = shuffle_patches(x, patch_levels=self.train_patch_levels)
        return self.backbone(x)['pooler_output']

    def classifier(self, features: torch.Tensor) -> torch.Tensor:
        """Map pooled features to class logits with the linear head."""
        return self.head(features)

    def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
        """Apply the configured loss to (logits, labels); the result is stored under 'overall'."""
        loss = self.loss_func(pred_dict['cls'], data_dict['label'])
        return {'overall': loss}

    def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
        """Return the (acc, mAP) pair from calculate_acc_for_train on detached labels/logits."""
        label = data_dict['label']
        pred = pred_dict['cls']
        acc, mAP = calculate_acc_for_train(
            label.detach(), pred.detach(), self.config['backbone_config']['num_classes']
        )
        return {'acc': acc, 'mAP': mAP}

    def forward(self, data_dict: dict, inference=False) -> dict:
        """
        Run the full pipeline on data_dict['image'].

        Args:
            data_dict: batch dict; 'image' is read here.
            inference: accepted for interface compatibility; not used by this detector.

        Returns:
            dict with 'cls' (logits), 'prob' (softmax of the logits over dim 1)
            and 'feat' (the pooled backbone features).
        """
        features = self.features(data_dict)
        pred = self.classifier(features)
        prob = torch.softmax(pred, dim=1)
        return {'cls': pred, 'prob': prob, 'feat': features}
|
|
|
|
def get_clip_visual(model_name="openai/clip-vit-base-patch16"):
    """
    Load a pretrained CLIP checkpoint and return its processor and vision tower.

    Args:
        model_name: Hugging Face hub id or local path of a CLIP checkpoint.

    Returns:
        (processor, vision_model): the checkpoint's AutoProcessor and the
        ``vision_model`` sub-module of the loaded CLIPModel.
    """
    clip_processor = AutoProcessor.from_pretrained(model_name)
    vision_tower = CLIPModel.from_pretrained(model_name).vision_model
    return clip_processor, vision_tower
|
|