DeepfakeGenome_Codebase / training /detectors /clip_patch_shuffle.py
shunliwang
update
8bc3305
'''
# author: Zhiyuan Yan
# email: zhiyuanyan@link.cuhk.edu.cn
# date: 2023-07-06
# description: Class for the CLIP patch-shuffle detector (CLIP_PATCH_SHUFFLE_Detector), which shuffles image patches during training
Functions in the Class are summarized as:
1. __init__: Initialization
2. build_backbone: Backbone-building
3. build_loss: Loss-function-building
4. features: Feature-extraction
5. classifier: Classification
6. get_losses: Loss-computation
7. get_train_metrics: Training-metrics-computation
8. get_test_metrics: Testing-metrics-computation
9. forward: Forward-propagation
Reference:
@inproceedings{radford2021learning,
title={Learning transferable visual models from natural language supervision},
author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and Goh, Gabriel and Agarwal, Sandhini and Sastry, Girish and Askell, Amanda and Mishkin, Pamela and Clark, Jack and others},
booktitle={International conference on machine learning},
pages={8748--8763},
year={2021},
organization={PMLR}
}
'''
import os
import datetime
import logging
import numpy as np
from sklearn import metrics
from typing import Union
from collections import defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import DataParallel
from torch.utils.tensorboard import SummaryWriter
from metrics.base_metrics_class import calculate_metrics_for_train, calculate_acc_for_train
from .base_detector import AbstractDetector
from detectors import DETECTOR
from networks import BACKBONE
from loss import LOSSFUNC
from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig
import loralib as lora
import copy
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def _shuffle_single_image(img: torch.Tensor, ps: int) -> torch.Tensor:
    """Randomly permute the ``ps x ps`` patches of a single [C, H, W] image.

    ``ps`` must divide both H and W (checked by the caller).
    """
    C, H, W = img.shape
    num_patches_h, num_patches_w = H // ps, W // ps
    num_patches = num_patches_h * num_patches_w
    if num_patches == 1:
        # A single patch covering the whole image: nothing to permute.
        return img
    # [C, H, W] -> [num_patches, C, ps, ps]. `reshape` (not `view`) also
    # accepts non-contiguous inputs such as channels_last batches.
    patches = (
        img.reshape(C, num_patches_h, ps, num_patches_w, ps)
        .permute(1, 3, 0, 2, 4)
        .reshape(num_patches, C, ps, ps)
    )
    perm = torch.randperm(num_patches, device=img.device)
    patches = patches[perm]
    # [num_patches, C, ps, ps] -> [C, H, W]
    return (
        patches.reshape(num_patches_h, num_patches_w, C, ps, ps)
        .permute(2, 0, 3, 1, 4)
        .reshape(C, H, W)
    )


def shuffle_patches(
    images: torch.Tensor,
    patch_levels: Union[list, tuple] = (14, 56, 224),
    probs: Union[list, tuple, None] = None,
) -> torch.Tensor:
    """
    Shuffle square patches of each image, using a patch size sampled per image.

    Each image in the batch independently draws one patch size ``ps`` from
    ``patch_levels``. The image is cut into a grid of ``ps x ps`` patches,
    and the patches are randomly permuted. A patch size equal to the image
    side (e.g. 224 for 224x224 inputs) yields a single patch, so that image
    is returned unchanged.

    Args:
        images: tensor of shape [B, C, H, W]. Every entry of ``patch_levels``
            must divide both H and W (H = W = 224 works for the defaults).
        patch_levels: candidate patch sizes, as side lengths in pixels.
        probs: optional sampling weights, one per entry of ``patch_levels``
            (normalized by ``torch.multinomial``). When omitted, three levels
            use the historical weights ``[0.33, 0.33, 0.34]``; any other
            number of levels is sampled uniformly.

    Returns:
        A new tensor with the same shape, dtype and device as ``images``.

    Raises:
        ValueError: if ``patch_levels`` is empty, if ``probs`` has a different
            length than ``patch_levels``, or if a patch size is not a positive
            divisor of both H and W.
    """
    B, C, H, W = images.shape
    levels = [int(ps) for ps in patch_levels]
    if not levels:
        raise ValueError("patch_levels must contain at least one patch size")
    if probs is None:
        # Keep the original weights for the 3-level case; use uniform weights otherwise.
        probs = [0.33, 0.33, 0.34] if len(levels) == 3 else [1.0] * len(levels)
    if len(probs) != len(levels):
        raise ValueError(
            f"probs has {len(probs)} entries but patch_levels has {len(levels)}"
        )
    # Validate every level up front, so a bad configuration fails
    # deterministically instead of only when that level happens to be sampled.
    for ps in levels:
        if ps <= 0 or H % ps or W % ps:
            raise ValueError(
                f"patch size {ps} must be a positive divisor of image size {H}x{W}"
            )

    shuffled_images = torch.empty_like(images)
    if B == 0:
        return shuffled_images

    # Sample every image's level index in one call on the CPU. Keeping the
    # indices as Python ints avoids a device sync (`.item()`) per image when
    # `images` lives on a GPU.
    level_idx = torch.multinomial(
        torch.tensor(probs, dtype=torch.float), B, replacement=True
    ).tolist()

    for b, idx in enumerate(level_idx):
        shuffled_images[b] = _shuffle_single_image(images[b], levels[idx])
    return shuffled_images
@DETECTOR.register_module(module_name='clip_patch_shuffle')
class CLIP_PATCH_SHUFFLE_Detector(AbstractDetector):
    """CLIP vision-backbone classifier that patch-shuffles its inputs while training.

    In training mode every image has its 14-pixel square patches randomly
    permuted (via ``shuffle_patches``) before it reaches the backbone. In eval
    mode the images are used unchanged. The backbone's ``pooler_output`` is the
    feature vector, and a single linear head maps it to the class logits.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.backbone = self.build_backbone(config)
        # Size the head from the backbone's width rather than hard-coding 1024,
        # which only matches 1024-wide backbones (e.g. CLIP ViT-L/14). The
        # fallback keeps the previous value when the width cannot be read.
        feat_dim = getattr(getattr(self.backbone, 'config', None), 'hidden_size', 1024)
        self.head = nn.Linear(int(feat_dim), config['backbone_config']['num_classes'])
        self.loss_func = self.build_loss(config)

    def build_backbone(self, config):
        """Load the CLIP vision tower named by ``config['pretrained']``."""
        _, backbone = get_clip_visual(model_name=config['pretrained'])
        return backbone

    def build_loss(self, config):
        """Instantiate the loss registered in LOSSFUNC under ``config['loss_func']``."""
        loss_class = LOSSFUNC[config['loss_func']]
        loss_func = loss_class()
        return loss_func

    def features(self, data_dict: dict) -> torch.Tensor:
        """Return the backbone's pooled features for ``data_dict['image']``.

        Presumably ``data_dict['image']`` is a [B, C, H, W] batch whose H and W
        are multiples of 14 (required by the shuffle) — TODO confirm against
        the data pipeline.
        """
        x = data_dict['image']
        if self.training:  # shuffle only during training
            x = shuffle_patches(x, patch_levels=[14, 14, 14])
        return self.backbone(x)['pooler_output']

    def classifier(self, features: torch.Tensor) -> torch.Tensor:
        """Map pooled features to class logits with the linear head."""
        return self.head(features)

    def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
        """Compute the loss between ``pred_dict['cls']`` logits and ``data_dict['label']``."""
        label = data_dict['label']
        pred = pred_dict['cls']
        loss = self.loss_func(pred, label)
        loss_dict = {'overall': loss}
        return loss_dict

    def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
        """Compute batch accuracy and mAP over ``num_classes`` classes (on detached tensors)."""
        label = data_dict['label']
        pred = pred_dict['cls']
        acc, mAP = calculate_acc_for_train(
            label.detach(), pred.detach(), self.config['backbone_config']['num_classes']
        )
        metric_batch_dict = {'acc': acc, 'mAP': mAP}
        return metric_batch_dict

    def forward(self, data_dict: dict, inference=False) -> dict:
        """Run features -> classifier and return logits, softmax probabilities and features.

        ``inference`` is accepted for interface compatibility but is not used here.
        """
        features = self.features(data_dict)
        pred = self.classifier(features)
        # Full per-class softmax (not only the "fake" column), for multi-class outputs.
        prob = torch.softmax(pred, dim=1)
        pred_dict = {'cls': pred, 'prob': prob, 'feat': features}
        return pred_dict
def get_clip_visual(model_name="openai/clip-vit-base-patch16"):
    """Load a pretrained CLIP checkpoint and split off its vision tower.

    Args:
        model_name: HuggingFace model id or local path of the CLIP checkpoint.

    Returns:
        A ``(processor, vision_model)`` pair: the checkpoint's AutoProcessor and
        the ``vision_model`` sub-module of the loaded CLIPModel.
    """
    clip_processor = AutoProcessor.from_pretrained(model_name)
    full_clip = CLIPModel.from_pretrained(model_name)
    vision_tower = full_clip.vision_model
    return clip_processor, vision_tower