# ==============================================================================
# DATDA-INDEX: Adversarial Perturbation Severity Estimator
# Author: Qamar Muneer Akbar
# ORCID: 0009-0003-6671-9253 | www.ftiuae.com
# Description:
#   Computes the DATDA Index, a quantitative measure (0-1)
#   representing adversarial perturbation or anomaly strength in images.
#   Designed to run independently before/after DATDA defense.
# ==============================================================================

import math

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import ToTensor


# ------------------------------------------------------------
# Utility Functions
# ------------------------------------------------------------
def _to_tensor(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a ``[1, C, H, W]`` float tensor.

    ``ToTensor`` scales 8-bit images to ``[0, 1]``. The result is placed on
    CUDA when it is available, otherwise on the CPU.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return ToTensor()(img).unsqueeze(0).to(device)


# ------------------------------------------------------------
# DATDA Index Calculator
# ------------------------------------------------------------
class DATDAIndex:
    """Compute a DATDA Index score in ``[0, 1]`` indicating perturbation severity.

    The score is a weighted combination of four features. For inputs in
    ``[0, 1]``, each feature also lies in ``[0, 1]``:

    1. the share of spectral (FFT) magnitude outside a centred low-frequency disk;
    2. the mean 3x3 local variance, scaled by 2 and clamped;
    3. the mean absolute residual against a 3x3 box-blurred baseline;
    4. the entropy of a 256-bin intensity histogram, normalised by ``log(256)``.

    Args:
        high_freq_ratio: Radius of the low-frequency disk, as a fraction of
            ``min(H, W)``. Spectral magnitude outside this radius counts as
            high frequency.
        normalize_features: If ``True`` (the default, and the original
            behaviour), the four features are min-max rescaled *against each
            other* before weighting. The smallest feature therefore always
            maps to 0 and the largest to 1, so the score reflects the image's
            own feature spread. If ``False``, the raw features are weighted
            directly, giving an absolute score that is comparable across
            images.
    """

    # Feature weights, in order: spectral, local variance, residual, entropy.
    _WEIGHTS = (0.35, 0.25, 0.25, 0.15)

    def __init__(self, high_freq_ratio: float = 0.35, *, normalize_features: bool = True):
        self.high_freq_ratio = high_freq_ratio
        self.normalize_features = normalize_features

    @torch.no_grad()
    def compute(self, x: torch.Tensor) -> float:
        """Return the DATDA Index of an image tensor.

        Args:
            x: Image tensor with values in ``[0, 1]``. Accepted shapes are
                ``[B, C, H, W]`` (normally ``B == 1``), ``[C, H, W]`` or
                ``[H, W]``. Inputs with fewer than 3 channels are repeated to
                3 channels. Only the first 3 channels of wider inputs
                (e.g. RGBA) are used. All batch items are pooled into a
                single score.

        Returns:
            float: DATDA Index in ``[0.0, 1.0]``, rounded to 4 decimals.
        """
        if x.dim() == 2:
            x = x.unsqueeze(0)
        if x.dim() == 3:
            x = x.unsqueeze(0)
        if x.shape[1] > 3:
            # Same result as repeating, without allocating extra channels.
            x = x[:, :3]
        elif x.shape[1] != 3:
            x = x.repeat(1, 3, 1, 1)

        _, _, height, width = x.shape
        device = x.device

        # Grayscale via the ITU-R BT.601 luma weights -> [B, 1, H, W].
        x_gray = (0.299 * x[:, 0] + 0.587 * x[:, 1] + 0.114 * x[:, 2]).unsqueeze(1)

        # (1) High-frequency spectral ratio. Shift only the spatial axes, so
        # the batch axis is left untouched.
        mag = torch.fft.fftshift(torch.fft.fft2(x_gray.squeeze(1)), dim=(-2, -1)).abs()
        dy = (torch.arange(height, device=device) - height // 2).float()
        dx = (torch.arange(width, device=device) - width // 2).float()
        dist = torch.sqrt(dy[:, None] ** 2 + dx[None, :] ** 2)  # [H, W]
        radius = int(self.high_freq_ratio * min(height, width))
        high_mask = (dist > radius).float()
        freq_ratio = (mag * high_mask).sum() / (mag.sum() + 1e-8)

        # (2) Local variance over 3x3 neighbourhoods. Reflect padding needs
        # every padded dim >= 2, so fall back to replicate for 1-pixel inputs.
        pad_mode = "reflect" if min(height, width) >= 2 else "replicate"
        x_pad = F.pad(x_gray, (1, 1, 1, 1), mode=pad_mode)
        patches = x_pad.unfold(2, 3, 1).unfold(3, 3, 1)  # [B, 1, H, W, 3, 3]
        local_mean = patches.mean(dim=(-1, -2), keepdim=True)
        var_local = (patches - local_mean).pow(2).mean()
        var_score = torch.clamp(var_local * 2.0, 0, 1)

        # (3) Residual against a 3x3 box-blurred baseline. Reuse the padded
        # neighbourhood means from step 2. The zero-padded avg_pool2d used
        # previously counted the padding zeros, which inflated border
        # residuals in proportion to edge brightness.
        baseline = local_mean.squeeze(-1).squeeze(-1)  # [B, 1, H, W]
        res = (x_gray - baseline).abs().mean()

        # (4) Histogram entropy, normalised so a uniform 256-bin histogram -> 1.
        levels = (x_gray.reshape(-1) * 255).long().clamp(0, 255)
        hist = torch.bincount(levels, minlength=256).float()
        probs = hist / (hist.sum() + 1e-8)
        entropy = -(probs * (probs + 1e-12).log()).sum() / math.log(256.0)

        # Combine the features (weights sum to 1, so the index stays in [0, 1]).
        feats = torch.stack([freq_ratio, var_score, res, entropy]).float()
        weights = torch.tensor(self._WEIGHTS, device=device, dtype=feats.dtype)
        if self.normalize_features:
            feats = (feats - feats.min()) / (feats.max() - feats.min() + 1e-8)
        index = (feats * weights).sum().item()
        return round(float(np.clip(index, 0, 1)), 4)
# ------------------------------------------------------------
# Example CLI use
# ------------------------------------------------------------
def main(argv=None) -> int:
    """Print the DATDA Index of the image given on the command line.

    Args:
        argv: Argument list, excluding the program name. Defaults to
            ``sys.argv[1:]``. Only the first argument is used.

    Returns:
        int: Process exit status: 0 on success, 1 if the image cannot be
        read, 2 on a usage error.
    """
    import sys

    args = sys.argv[1:] if argv is None else list(argv)
    if not args:
        print("Usage: python datda-index.py <image_path>", file=sys.stderr)
        return 2

    img_path = args[0]
    try:
        # The context manager closes the file handle that Image.open keeps open.
        with Image.open(img_path) as src:
            img = src.convert("RGB")
    except OSError as err:  # missing file, or PIL.UnidentifiedImageError
        print(f"Error: cannot read image {img_path!r}: {err}", file=sys.stderr)
        return 1

    calc = DATDAIndex()
    idx = calc.compute(_to_tensor(img))
    print(f"DATDA Index: {idx:.4f}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())