| import math |
| import re |
| import warnings |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.distributions import Normal |
| from transformers import PreTrainedModel |
| from huggingface_hub import PyTorchModelHubMixin |
| from numba import jit |
| from .configuration import FsgConfig |
| from typing import Literal, Type, Union, List |
|
|
|
|
def batch_fn(iterable, n=1):
    """Yield successive slices of *iterable*, each of length at most *n*."""
    total = len(iterable)
    start = 0
    while start < total:
        yield iterable[start : min(start + n, total)]
        start += n
|
|
|
|
def gaussian_kernel_1d(sigma: float, num_sigmas: float = 3.0) -> torch.Tensor:
    """Return a normalized 1-D Gaussian kernel truncated at ±num_sigmas·sigma.

    The kernel has ``2 * ceil(num_sigmas * sigma) + 1`` taps and sums to 1.
    """
    radius = math.ceil(num_sigmas * sigma)
    positions = torch.arange(-radius, radius + 1, dtype=torch.float)
    density = Normal(loc=0, scale=sigma).log_prob(positions).exp()
    return density / density.sum()
|
|
|
|
def gaussian_filter_2d(img: torch.Tensor, sigma: float) -> torch.Tensor:
    """Blur a 2-D tensor with a separable Gaussian (zero padding at borders)."""
    kernel = gaussian_kernel_1d(sigma).to(img.device)
    pad = len(kernel) // 2
    out = img.unsqueeze(0).unsqueeze(0)
    # Two 1-D passes (vertical then horizontal) are equivalent to one 2-D
    # Gaussian convolution but much cheaper.
    out = F.conv2d(out, weight=kernel.view(1, 1, -1, 1), padding=(pad, 0))
    out = F.conv2d(out, weight=kernel.view(1, 1, 1, -1), padding=(0, pad))
    return out.squeeze()
|
|
|
|
class BaseModel(nn.Module):
    """Common base for patch-level forensic models.

    Records the square input patch size and the width of the optional
    classification head (0 means feature-extractor only).
    """

    def __init__(
        self,
        patch_size: int,
        num_classes: int = 0,
        **kwargs,
    ):
        super().__init__()
        # Side length (pixels) of the square patches the model consumes.
        self.patch_size = patch_size
        # Number of output classes; 0 disables the classification head.
        self.num_classes = num_classes
|
|
|
|
class ConstrainedConv(nn.Module):
    """Constrained 5x5 convolution (Bayar-Stamm style) used as a forensic front end.

    When ``is_constrained`` is True, each filter is re-projected on every forward
    pass: the center tap is forced to exactly -1 and the remaining taps are
    mean-shifted so their sum is approximately 1 (exact when the pre-projection
    center equals the row mean). This pushes the filters toward prediction-error
    (residual) filters rather than content filters.
    """

    def __init__(self, input_chan=3, num_filters=6, is_constrained=True):
        super().__init__()
        self.kernel_size = 5
        self.input_chan = input_chan
        self.num_filters = num_filters
        self.is_constrained = is_constrained
        weight = torch.empty(num_filters, input_chan, self.kernel_size, self.kernel_size)
        nn.init.xavier_normal_(weight, gain=1/3)
        self.weight = nn.Parameter(weight, requires_grad=True)
        # One-hot mask selecting the center tap of the flattened k*k kernel.
        # Computed from kernel_size (index 12 for k=5) instead of hard-coded,
        # so the kernel size can be changed in a single place.
        self.one_middle = torch.zeros(self.kernel_size * self.kernel_size)
        self.one_middle[(self.kernel_size * self.kernel_size) // 2] = 1
        # Kept as a non-trainable Parameter (not a buffer) for checkpoint
        # compatibility with existing state dicts.
        self.one_middle = nn.Parameter(self.one_middle, requires_grad=False)

    def forward(self, x):
        w = self.weight
        if self.is_constrained:
            k2 = self.kernel_size * self.kernel_size
            w = w.view(-1, k2)
            # Mean-shift each (filter, channel) row so its mean is 1/(k*k - 1).
            w = w - w.mean(1)[..., None] + 1 / (k2 - 1)
            # Overwrite the center tap with exactly -1; other taps unchanged.
            w = w - (w + 1) * self.one_middle
            w = w.view(self.num_filters, self.input_chan, self.kernel_size, self.kernel_size)
        x = nn.functional.conv2d(x, w, padding="valid")
        # NOTE(review): asymmetric pad (2 left/top, 3 right/bottom) grows the map
        # by 5 px total; the fc input sizes in MISLNet's arch table depend on
        # this exact padding — do not "fix" it to symmetric.
        x = nn.functional.pad(x, (2, 3, 2, 3))
        return x
|
|
|
|
class ConvBlock(torch.nn.Module):
    """Conv2d -> BatchNorm2d -> (Tanh|ReLU) -> 3x3 max-pool with stride 2."""

    def __init__(
        self,
        in_chans,
        out_chans,
        kernel_size,
        stride,
        padding,
        activation: Literal["tanh", "relu"],
    ):
        super().__init__()
        act_name = activation.lower()
        assert act_name in ["tanh", "relu"], "The activation layer must be either Tanh or ReLU"
        self.conv = torch.nn.Conv2d(
            in_chans,
            out_chans,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn = torch.nn.BatchNorm2d(out_chans)
        if act_name == "tanh":
            self.act = torch.nn.Tanh()
        else:
            self.act = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(kernel_size=(3, 3), stride=2)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.act(out)
        return self.maxpool(out)
|
|
|
|
class DenseBlock(torch.nn.Module):
    """Fully-connected layer followed by a Tanh or ReLU activation."""

    def __init__(
        self,
        in_chans,
        out_chans,
        activation: Literal["tanh", "relu"],
    ):
        super().__init__()
        act_name = activation.lower()
        assert act_name in ["tanh", "relu"], "The activation layer must be either Tanh or ReLU"
        self.fc = torch.nn.Linear(in_chans, out_chans)
        if act_name == "tanh":
            self.act = torch.nn.Tanh()
        else:
            self.act = torch.nn.ReLU()

    def forward(self, x):
        out = self.fc(x)
        return self.act(out)
|
|
|
|
class MISLNet(BaseModel):
    """MISLNet patch feature extractor built from a constrained first conv
    followed by conv blocks and fc blocks.

    ``arch`` maps a variant name to its layer specification:
      - conv rows: (name, in_chans, out_chans, kernel_size, stride, padding,
        activation); in_chans == -1 means "use the constrained conv's
        num_filters".
      - fc rows: (name, in_features, out_features, activation); the first fc's
        in_features is tied to the flattened conv output for that patch size.

    When ``num_classes > 0`` a final linear classification head is appended.
    """

    arch = {
        "p256": [
            ("conv1", -1, 96, 7, 2, "valid", "tanh"),
            ("conv2", 96, 64, 5, 1, "same", "tanh"),
            ("conv3", 64, 64, 5, 1, "same", "tanh"),
            ("conv4", 64, 128, 1, 1, "same", "tanh"),
            ("fc1", 6 * 6 * 128, 200, "tanh"),
            ("fc2", 200, 200, "tanh"),
        ],
        "p256_3fc_256e": [
            ("conv1", -1, 96, 7, 2, "valid", "tanh"),
            ("conv2", 96, 64, 5, 1, "same", "tanh"),
            ("conv3", 64, 64, 5, 1, "same", "tanh"),
            ("conv4", 64, 128, 1, 1, "same", "tanh"),
            ("fc1", 6 * 6 * 128, 1024, "tanh"),
            ("fc2", 1024, 512, "tanh"),
            ("fc3", 512, 256, "tanh"),
        ],
        "p128": [
            ("conv1", -1, 96, 7, 2, "valid", "tanh"),
            ("conv2", 96, 64, 5, 1, "same", "tanh"),
            ("conv3", 64, 64, 5, 1, "same", "tanh"),
            ("conv4", 64, 128, 1, 1, "same", "tanh"),
            ("fc1", 2 * 2 * 128, 200, "tanh"),
            ("fc2", 200, 200, "tanh"),
        ],
        "p96": [
            ("conv1", -1, 96, 7, 2, "valid", "tanh"),
            ("conv2", 96, 64, 5, 1, "same", "tanh"),
            ("conv3", 64, 64, 5, 1, "same", "tanh"),
            ("conv4", 64, 128, 1, 1, "same", "tanh"),
            ("fc1", 8 * 4 * 64, 200, "tanh"),
            ("fc2", 200, 200, "tanh"),
        ],
        "p64": [
            ("conv1", -1, 96, 7, 2, "valid", "tanh"),
            ("conv2", 96, 64, 5, 1, "same", "tanh"),
            ("conv3", 64, 64, 5, 1, "same", "tanh"),
            ("conv4", 64, 128, 1, 1, "same", "tanh"),
            ("fc1", 2 * 4 * 64, 200, "tanh"),
            ("fc2", 200, 200, "tanh"),
        ],
    }

    def __init__(
        self,
        patch_size: int,
        variant: str,
        num_classes=0,
        num_filters=6,
        is_constrained=True,
        **kwargs,
    ):
        super().__init__(patch_size, num_classes)
        self.variant = variant
        self.chosen_arch = self.arch[variant]
        self.num_filters = num_filters

        self.constrained_conv = ConstrainedConv(num_filters=num_filters, is_constrained=is_constrained)

        # Build in plain lists, then register once as nn.Sequential (the
        # original assigned temporary lists to module attributes first).
        conv_blocks = []
        fc_blocks = []
        for block in self.chosen_arch:
            if block[0].startswith("conv"):
                conv_blocks.append(
                    ConvBlock(
                        in_chans=(num_filters if block[1] == -1 else block[1]),
                        out_chans=block[2],
                        kernel_size=block[3],
                        stride=block[4],
                        padding=block[5],
                        activation=block[6],
                    )
                )
            elif block[0].startswith("fc"):
                fc_blocks.append(
                    DenseBlock(
                        in_chans=block[1],
                        out_chans=block[2],
                        activation=block[3],
                    )
                )

        self.conv_blocks = nn.Sequential(*conv_blocks)
        self.fc_blocks = nn.Sequential(*fc_blocks)

        # Axis permutation applied before flattening; registered as a buffer so
        # checkpoints trained with a different memory layout can override it.
        self.register_buffer("flatten_index_permutation", torch.tensor([0, 1, 2, 3], dtype=torch.long))

        if self.num_classes > 0:
            self.output = nn.Linear(self.chosen_arch[-1][2], self.num_classes)

    def forward(self, x):
        """Return fc features (or class logits when num_classes > 0) for patches *x*."""
        x = self.constrained_conv(x)
        x = self.conv_blocks(x)
        x = x.permute(*self.flatten_index_permutation)
        x = x.flatten(1, -1)
        x = self.fc_blocks(x)
        if self.num_classes > 0:
            x = self.output(x)
        return x

    def load_state_dict(self, state_dict, strict=True, assign=False):
        """Load weights, tolerating older checkpoints that predate the
        ``flatten_index_permutation`` buffer by loading non-strictly.

        Returns the missing/unexpected-keys result of
        ``nn.Module.load_state_dict`` (the original override dropped it).
        """
        if "flatten_index_permutation" not in state_dict:
            return super().load_state_dict(state_dict, False, assign)
        return super().load_state_dict(state_dict, strict, assign)
|
|
|
|
class CompareNet(nn.Module):
    """Siamese comparison head: a shared projection feeds a fused
    (left, left*right, right) representation into a 2-way classifier."""

    def __init__(self, input_dim, hidden_dim=2048, output_dim=64):
        super().__init__()
        # fc1 is applied to BOTH inputs (shared weights).
        self.fc1 = DenseBlock(input_dim, hidden_dim, "relu")
        self.fc2 = DenseBlock(hidden_dim * 3, output_dim, "relu")
        self.fc3 = nn.Linear(output_dim, 2)

    def forward(self, x1, x2):
        left = self.fc1(x1)
        right = self.fc1(x2)
        # Concatenate both projections plus their elementwise product.
        fused = torch.cat((left, left * right, right), dim=1)
        return self.fc3(self.fc2(fused))
|
|
|
|
class FSM(nn.Module):
    """
    FSM (Forensic Similarity Metric) is a neural network module that computes the similarity between two input images using a feature extraction module and a comparison network module.

    Args:
        fe_config (dict): Configuration for the feature extraction module.
        comparenet_config (dict): Configuration for the comparison network module.
        fe_ckpt (str): Path to the checkpoint file for the feature extraction module.
        **kwargs: Additional keyword arguments.
    """

    def __init__(
        self,
        fe_config,
        comparenet_config,
        fe_ckpt=None,
        **kwargs,
    ):
        super().__init__()
        fe_config["num_classes"] = 0  # the FE is used purely as a feature extractor
        self.fe: MISLNet = self.load_module_from_ckpt(MISLNet, fe_ckpt, "", **fe_config)
        self.patch_size = self.fe.patch_size
        # CompareNet's input width must match the FE's final fc output width.
        comparenet_config["input_dim"] = self.fe.fc_blocks[-1].fc.out_features
        self.comparenet = CompareNet(**comparenet_config)
        # When True, the feature extractor runs in eval mode with grads disabled.
        self.fe_freeze = True

    def load_module_state_dict(self, module: nn.Module, state_dict, module_name=""):
        """Copy matching weights from *state_dict* into *module*, tolerating key-prefix differences.

        Each checkpoint key is expanded into all of its dotted suffixes and
        matched against the module's own keys (so e.g. ``model.fe.conv1.weight``
        matches a module key ``conv1.weight``). Reports (via print) any module
        keys left unloaded and any checkpoint keys that matched nothing.
        """
        curr_model_state_dict = module.state_dict()
        curr_model_keys_status = {k: False for k in curr_model_state_dict.keys()}
        outstanding_keys = []
        for ckpt_layer_name, ckpt_layer_weights in state_dict.items():
            if module_name not in ckpt_layer_name:
                continue
            # All dotted suffixes of the checkpoint key, longest first.
            ckpt_matches = re.findall(r"(?=(?:^|\.)((?:\w+\.)*\w+)$)", ckpt_layer_name)[::-1]
            model_layer_name_match = list(set(ckpt_matches).intersection(set(curr_model_state_dict.keys())))

            if len(model_layer_name_match) == 0:
                outstanding_keys.append(ckpt_layer_name)
            else:
                # NOTE(review): if several suffixes match, set ordering makes the
                # pick arbitrary; in practice suffix matches are unique per key.
                model_layer_name = model_layer_name_match[0]
                assert (
                    curr_model_state_dict[model_layer_name].shape == ckpt_layer_weights.shape
                ), f"Ckpt layer '{ckpt_layer_name}' shape {ckpt_layer_weights.shape} does not match model layer '{model_layer_name}' shape {curr_model_state_dict[model_layer_name].shape}"
                curr_model_state_dict[model_layer_name] = ckpt_layer_weights
                curr_model_keys_status[model_layer_name] = True

        if all(curr_model_keys_status.values()):
            print(f"Success! All necessary keys for module '{module.__class__.__name__}' are loaded!")
        else:
            not_loaded_keys = [k for k, v in curr_model_keys_status.items() if not v]
            print(f"Warning! Some keys are not loaded! Not loaded keys are:\n{not_loaded_keys}")
            if len(outstanding_keys) > 0:
                print(f"Outstanding keys are: {outstanding_keys}")
        module.load_state_dict(curr_model_state_dict, strict=False)

    def load_module_from_ckpt(
        self,
        module_class: Type[nn.Module],
        ckpt_path: Union[None, str],
        module_name: str,
        *args,
        **kwargs,
    ) -> nn.Module:
        """Instantiate *module_class* and, if *ckpt_path* is given, load its weights from the checkpoint's ``state_dict`` entry."""
        module = module_class(*args, **kwargs)

        if ckpt_path is not None:
            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # trusted checkpoints (consider weights_only=True on newer torch).
            ckpt = torch.load(ckpt_path, map_location="cpu")
            ckpt_state_dict = ckpt["state_dict"]
            self.load_module_state_dict(module, ckpt_state_dict, module_name=module_name)
        return module

    def load_state_dict(self, state_dict, strict=True, assign=False):
        """Load weights normally; on failure, fall back to suffix-matched manual loading."""
        try:
            super().load_state_dict(state_dict, strict=strict, assign=assign)
        except Exception as e:
            print(f"Error loading state dict using normal method: {e}")
            print("Trying to load state dict manually...")

            self.load_module_state_dict(self, state_dict, module_name="")
            print("State dict loaded successfully!")

    def forward_fe(self, x):
        """Run the feature extractor, without gradients when it is frozen."""
        # BUG FIX: previously read `self.freeze_fe`, an attribute that is never
        # assigned (__init__ sets `self.fe_freeze`), so every call raised
        # AttributeError.
        if self.fe_freeze:
            self.fe.eval()
            with torch.no_grad():
                return self.fe(x)
        else:
            self.fe.train()
            return self.fe(x)

    def forward(self, x1, x2):
        """Return 2-way similarity logits for patch batches *x1* and *x2*."""
        x1 = self.forward_fe(x1)
        x2 = self.forward_fe(x2)
        return self.comparenet(x1, x2)
|
|
|
|
class FsgModel(
    PreTrainedModel,
    PyTorchModelHubMixin,
    repo_url="ductai199x/forensic-similarity-graph",
    pipeline_tag="image-manipulation-detection-localization",
    license="cc-by-nc-nd-4.0",
):
    """
    Forensic Similarity Graph (FSG) algorithm.
    https://ieeexplore.ieee.org/abstract/document/9113265

    This class is designed to create a graph-based representation of forensic similarity between different patches of an image, allowing for the detection of manipulated regions.
    """
    # HF config class used to (de)serialize this model's hyperparameters.
    config_class = FsgConfig

    def __init__(self, config: FsgConfig, **kwargs):
        """Build the FSM backbone and cache patch/stride/threshold settings from *config*."""
        super().__init__(config)
        self.patch_size = config.fe_config.patch_size
        # Patch sampling stride in pixels (stride_ratio of 1.0 => non-overlapping patches).
        self.stride = int(self.patch_size * config.stride_ratio)
        # When True, only the upper triangle of the patch-pair matrix is scored
        # and mirrored; otherwise all P*P ordered pairs are scored.
        self.fast_sim_mode = config.fast_sim_mode
        # Threshold applied to the smoothed localization map in patch_to_pixel_pred.
        self.loc_threshold = config.loc_threshold
        self.is_high_sim = True
        # If True and the input is in [0, 1], inputs are rescaled to [0, 255].
        self.need_input_255 = config.need_input_255
        self.model = FSM(fe_config=config.fe_config.to_dict(), comparenet_config=config.comparenet_config.to_dict())

        # NOTE(review): this silences warnings process-wide, not just for this model.
        warnings.filterwarnings("ignore")

    def get_batched_patches(self, x: torch.Tensor):
        """Unfold a (B, C, H, W) batch into (B, num_patches, C, patch, patch) tiles at self.stride."""
        B, C, H, W = x.shape

        batched_patches = (
            x.unfold(2, self.patch_size, self.stride)
            .unfold(3, self.patch_size, self.stride)
            .permute(0, 2, 3, 1, 4, 5)
        )
        batched_patches = batched_patches.contiguous().view(B, -1, C, self.patch_size, self.patch_size)
        return batched_patches

    def get_patches_single(self, x: torch.Tensor):
        """Unfold a single (C, H, W) image into (num_patches, C, patch, patch) tiles at self.stride."""
        C, H, W = x.shape
        patches = (
            x.unfold(1, self.patch_size, self.stride)
            .unfold(2, self.patch_size, self.stride)
            .permute(1, 2, 0, 3, 4)
        )
        patches = patches.contiguous().view(-1, C, self.patch_size, self.patch_size)
        return patches

    @jit(forceobj=True)
    def get_features(self, image_patches: torch.Tensor):
        """Run the feature extractor over patches in chunks of 256; returns stacked, detached features."""
        patches_features = []
        for batch in list(batch_fn(image_patches, 256)):
            batch = batch.float()
            feats = self.model.fe(batch).detach()
            patches_features.append(feats)
        patches_features = torch.vstack(patches_features)
        return patches_features

    @jit(forceobj=True)
    def get_sim_scores(self, patch_pairs):
        """Score feature pairs with CompareNet in chunks of 4096; returns softmaxed (N, 2) scores."""
        patches_sim_scores = []
        for batch in list(batch_fn(patch_pairs, 4096)):
            # (chunk, 2, feat) -> (2, chunk, feat) so the two sides unpack as
            # CompareNet's (x1, x2) arguments.
            batch = batch.permute(1, 0, 2).float()
            scores = self.model.comparenet(*batch).detach()
            scores = torch.nn.functional.softmax(scores, dim=1)
            patches_sim_scores.append(scores)
        patches_sim_scores = torch.vstack(patches_sim_scores)
        return patches_sim_scores

    def forward_single(self, patches: torch.Tensor):
        """Spectral analysis of one image's patch-similarity graph.

        Returns (img_pred, patch_pred): a scalar manipulation score
        (1 - spectral gap of the normalized Laplacian) and a per-patch 0/1
        cluster assignment from the sign of the second eigenvector (per the
        FSG paper cited on the class).
        """
        P, C, H, W = patches.shape
        features = self.get_features(patches)
        sim_mat = torch.zeros(P, P, device=patches.device)
        if self.fast_sim_mode:
            # Only score unordered pairs (i < j), then mirror below.
            upper_tri_idx = torch.triu_indices(P, P, 1).T
            patch_pairs = features[upper_tri_idx]
        else:
            # Score all ordered pairs and symmetrize by averaging.
            patch_cart_prod = torch.cartesian_prod(torch.arange(P), torch.arange(P))
            patch_pairs = features[patch_cart_prod]
        sim_scores = self.get_sim_scores(patch_pairs).detach()
        if self.fast_sim_mode:
            sim_mat[upper_tri_idx[:, 0], upper_tri_idx[:, 1]] = sim_scores[:, 1]
            sim_mat += sim_mat.clone().T
        else:
            sim_mat = sim_scores[:, 1].view(P, P)
            sim_mat = 0.5 * (sim_mat + sim_mat.T)
        if not self.is_high_sim:
            sim_mat = 1 - sim_mat
        sim_mat.fill_diagonal_(0.0)
        degree_mat = torch.diag(sim_mat.sum(axis=1))
        laplacian_mat = degree_mat - sim_mat
        # D^(-1/2); NOTE(review): a patch with zero total similarity would
        # produce inf here — assumed not to happen in practice.
        degree_sym_mat = torch.diag(sim_mat.sum(axis=1) ** -0.5)
        laplacian_sym_mat = (degree_sym_mat @ laplacian_mat) @ degree_sym_mat
        # eigh on CPU: eigenvalues ascending, so eigvals[1] - eigvals[0] is the gap.
        eigvals, eigvecs = torch.linalg.eigh(laplacian_sym_mat.cpu())
        spectral_gap = eigvals[1] - eigvals[0]
        img_pred = 1 - spectral_gap
        # Sign of the Fiedler-style second eigenvector splits patches into two groups.
        eigvec = eigvecs[:, 1]
        patch_pred = (eigvec > 0).int()
        return img_pred.detach(), patch_pred.detach()

    def forward(self, x: Union[torch.Tensor, List[torch.Tensor]]):
        """Detect and localize manipulation per image.

        Accepts a single (C, H, W) tensor or a list of them; returns
        (img_preds, loc_preds) — per-image scores and per-pixel binary masks
        at the original resolution.
        """
        if isinstance(x, torch.Tensor) and len(x.shape) == 3:
            x = [x]

        img_preds = []
        loc_preds = []
        for img in x:
            C, H, W = img.shape
            if self.need_input_255 and img.max() <= 1:
                img = img * 255

            # Top-left (x, y) pixel coordinates of every patch, in the same
            # row-major order that get_patches_single produces them.
            x_inds = torch.arange(W).unfold(0, self.patch_size, self.stride)[:, 0]
            y_inds = torch.arange(H).unfold(0, self.patch_size, self.stride)[:, 0]
            xy_inds = torch.tensor([(ii, jj) for jj in y_inds for ii in x_inds]).to(img.device)

            patches = self.get_patches_single(img)
            img_pred, patch_pred = self.forward_single(patches)
            loc_pred = self.patch_to_pixel_pred(patch_pred, xy_inds)
            # Upsample the (possibly cropped) vote map back to full image size.
            loc_pred = F.interpolate(loc_pred[None, None, ...], size=(H, W), mode="nearest").squeeze()
            img_preds.append(img_pred)
            loc_preds.append(loc_pred)
        return img_preds, loc_preds

    def patch_to_pixel_pred(self, patch_pred: torch.Tensor, xy_inds: torch.Tensor):
        """Convert per-patch 0/1 predictions into a smoothed, thresholded pixel mask.

        Overlapping patches vote per pixel; votes and coverage are Gaussian
        blurred, normalized, inverted when the positive class covers the
        majority (the manipulated region is assumed to be the minority), and
        finally thresholded at self.loc_threshold.
        """
        W, H = torch.max(xy_inds, dim=0).values + self.patch_size
        pixel_pred = torch.zeros((H, W)).to(patch_pred.device)
        coverage_map = torch.zeros((H, W)).to(patch_pred.device)
        for (x, y), pred in zip(xy_inds, patch_pred):
            pixel_pred[y : y + self.patch_size, x : x + self.patch_size] += pred
            coverage_map[y : y + self.patch_size, x : x + self.patch_size] += 1

        pixel_pred = gaussian_filter_2d(pixel_pred, sigma=32)
        coverage_map = gaussian_filter_2d(coverage_map, sigma=32)
        # Normalize votes by (smoothed) coverage, then to [0, 1]; epsilons
        # guard against division by zero.
        pixel_pred /= coverage_map + 1e-8
        pixel_pred /= pixel_pred.max() + 1e-8
        if pixel_pred.sum() > pixel_pred.numel() * 0.5:
            pixel_pred = 1 - pixel_pred
        pixel_pred = (pixel_pred > self.loc_threshold).float()
        return pixel_pred
|
|