|
|
""" |
|
|
DeepFool Attack Implementation |
|
|
Enterprise-grade with support for multi-class and binary classification |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import numpy as np |
|
|
from typing import Optional, Dict, Any, Tuple, List |
|
|
import warnings |
|
|
|
|
|
class DeepFoolAttack:
    """Untargeted DeepFool attack that searches for a small L2 perturbation.

    For each sample the classifier is linearised around the current point,
    the closest (linearised) decision boundary is located, and the input is
    pushed just across it. The procedure repeats until the predicted class
    changes or ``max_iter`` is reached.

    Recognised ``config`` keys (all optional):
        max_iter: Maximum DeepFool iterations per sample (default 50).
        overshoot: Relative overshoot applied to the accumulated
            perturbation so the sample actually crosses the boundary
            (default 0.02).
        num_classes: Upper bound on the number of classes considered per
            sample, i.e. the original class plus the highest-scoring other
            classes (default 10). It is capped by the number of model
            outputs.
        clip_min / clip_max: Valid input range (defaults 0.0 / 1.0).
        device: Device the inputs are moved to (default ``'cpu'``). The
            model itself is NOT moved; the caller must place it there.

    After :meth:`generate` has run, ``self.metrics`` holds the statistics of
    the most recent call.
    """

    def __init__(self, model: nn.Module, config: Optional[Dict[str, Any]] = None):
        """
        Initialize DeepFool attack

        Args:
            model: PyTorch model to attack (switched to eval mode here)
            config: Attack configuration dictionary
        """
        self.model = model
        self.config = config or {}

        self.max_iter = self.config.get('max_iter', 50)
        self.overshoot = self.config.get('overshoot', 0.02)
        self.num_classes = self.config.get('num_classes', 10)
        self.clip_min = self.config.get('clip_min', 0.0)
        self.clip_max = self.config.get('clip_max', 1.0)
        self.device = self.config.get('device', 'cpu')

        # Always present so callers can read it before the first generate().
        self.metrics: Dict[str, Any] = {}

        self.model.eval()

    def _compute_gradients(self,
                           x: torch.Tensor,
                           target_class: Optional[int] = None,
                           classes: Optional[List[int]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute input gradients of the selected class scores

        Args:
            x: Input tensor holding a single sample (batch dimension of 1)
            target_class: Optional class index to leave out of the result
            classes: Class indices to differentiate, in the order the
                gradients should be returned. Defaults to the first
                ``min(num_classes, n_outputs)`` class indices.

        Returns:
            Tuple of (gradients, outputs). ``gradients`` stacks one gradient
            (same shape as ``x``) per requested class along dim 0;
            ``outputs`` is the detached model output for ``x``.

        Raises:
            ValueError: If no class is left to differentiate.
        """
        x = x.clone().detach().requires_grad_(True)
        outputs = self.model(x)

        if classes is None:
            classes = list(range(min(self.num_classes, outputs.shape[1])))

        gradients = []
        for k in classes:
            if target_class is not None and k == target_class:
                continue
            # autograd.grad differentiates w.r.t. the input only, so the
            # model parameters' .grad buffers are left untouched.
            grad, = torch.autograd.grad(outputs[0, k], x, retain_graph=True)
            gradients.append(grad)

        if not gradients:
            raise ValueError("No classes selected for gradient computation")

        return torch.stack(gradients, dim=0), outputs.detach()

    def _binary_search(self,
                       x: torch.Tensor,
                       perturbation: torch.Tensor,
                       original_class: int,
                       target_class: int,
                       max_search_iter: int = 10) -> torch.Tensor:
        """
        Binary search for minimal perturbation

        Scales ``perturbation`` by a factor in (0, 1) and keeps the smallest
        tested factor for which the model predicts ``target_class``. The
        search assumes the prediction switches once along the segment.

        Args:
            x: Original image (single sample)
            perturbation: Initial perturbation (used unchanged as fallback)
            original_class: Original predicted class (currently unused; kept
                for interface compatibility)
            target_class: Class the scaled sample must be predicted as
            max_search_iter: Maximum binary search iterations

        Returns:
            The smallest successful scaled perturbation, or ``perturbation``
            itself if no smaller scale reached ``target_class``.
        """
        eps_low = 0.0
        eps_high = 1.0
        best_perturbation = perturbation

        for _ in range(max_search_iter):
            eps = (eps_low + eps_high) / 2
            x_adv = torch.clamp(x + eps * perturbation, self.clip_min, self.clip_max)

            with torch.no_grad():
                pred_class = self.model(x_adv).argmax(dim=1).item()

            if pred_class == target_class:
                # Still adversarial: try an even smaller scale.
                eps_high = eps
                best_perturbation = eps * perturbation
            else:
                eps_low = eps

        return best_perturbation

    def _candidate_classes(self, logits: torch.Tensor, original_class: int) -> List[int]:
        """Return the highest-scoring classes other than ``original_class``.

        At most ``num_classes - 1`` indices are returned (fewer if the model
        has fewer outputs), ordered by decreasing score in ``logits`` (1-D).
        """
        budget = min(int(self.num_classes), logits.shape[0]) - 1
        if budget <= 0:
            return []
        ranked = torch.argsort(logits, descending=True).tolist()
        return [c for c in ranked if c != original_class][:budget]

    def _deepfool_single(self, x: torch.Tensor, original_class: int) -> Tuple[torch.Tensor, int, int]:
        """
        DeepFool for a single sample

        Args:
            x: Input tensor holding one sample (batch dimension of 1)
            original_class: Class the perturbation should move away from

        Returns:
            Tuple of (perturbation, target_class, iterations).
            ``target_class`` is the class predicted for the perturbed sample
            when the attack succeeded, otherwise the class of the last
            boundary that was targeted (or the current prediction when no
            iteration ran).
        """
        x = x.to(self.device).clone().detach()
        r_total = torch.zeros_like(x)
        iterations = 0

        with torch.no_grad():
            logits = self.model(x)
        current_class = logits.argmax(dim=1).item()

        # Defined up front so the return value is valid even when the loop
        # never runs (already misclassified input, max_iter == 0, ...).
        target_class = current_class

        candidates = self._candidate_classes(logits[0], original_class)
        tracked = [original_class] + candidates
        x_adv = x

        while candidates and current_class == original_class and iterations < self.max_iter:
            gradients, outputs = self._compute_gradients(x_adv, classes=tracked)

            grad_original = gradients[0]  # tracked[0] is the original class
            f_original = outputs[0, original_class]

            # Pick the class whose linearised boundary is closest.
            best = None
            for idx, k in enumerate(candidates, start=1):
                w_k = gradients[idx] - grad_original
                f_diff = torch.abs(outputs[0, k] - f_original)
                w_norm = torch.norm(w_k.flatten())
                distance = (f_diff / (w_norm + 1e-8)).item()
                if best is None or distance < best[0]:
                    best = (distance, k, w_k, f_diff, w_norm)

            _, target_class, w, f_diff, w_norm = best

            # Minimal step onto the linearised boundary of target_class.
            r_total = r_total + (f_diff + 1e-8) / (w_norm ** 2 + 1e-8) * w

            # Re-derive x_adv from the clean input so that it always equals
            # clamp(x + returned perturbation); the overshoot pushes the
            # sample across the boundary instead of merely onto it.
            x_adv = torch.clamp(x + (1 + self.overshoot) * r_total,
                                self.clip_min, self.clip_max)

            with torch.no_grad():
                current_class = self.model(x_adv).argmax(dim=1).item()

            iterations += 1

        if current_class != original_class:
            # Report (and refine towards) the class actually reached.
            target_class = current_class

        r_total = (1 + self.overshoot) * r_total

        if iterations > 0:
            r_total = self._binary_search(x, r_total, original_class, target_class)

        return r_total, target_class, iterations

    def generate(self, images: torch.Tensor, labels: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Generate adversarial examples

        Samples are attacked one at a time, starting from the model's own
        prediction on the clean input. Statistics of the run are stored in
        ``self.metrics`` with keys ``success_rate`` (percent of samples whose
        prediction changed), ``avg_iterations``, ``avg_perturbation`` (mean
        L2 norm) and ``original_accuracy`` (percent, ``None`` without
        labels).

        Args:
            images: Clean images, batch along dim 0
            labels: Optional ground-truth labels, only used for the
                ``original_accuracy`` metric

        Returns:
            Adversarial images on ``self.device``, clipped to
            [clip_min, clip_max]
        """
        images = images.clone().detach().to(self.device)
        batch_size = images.shape[0]

        if batch_size == 0:
            # Nothing to attack; avoid torch.cat([]) and division by zero.
            self.metrics = {
                'success_rate': 0.0,
                'avg_iterations': 0.0,
                'avg_perturbation': 0.0,
                'original_accuracy': None
            }
            return images

        with torch.no_grad():
            original_classes = self.model(images).argmax(dim=1)

        adversarial_images = []
        total_iterations = 0

        for i, original_class in enumerate(original_classes.tolist()):
            x = images[i:i + 1]
            perturbation, _, iterations = self._deepfool_single(x, original_class)
            adversarial_images.append(
                torch.clamp(x + perturbation, self.clip_min, self.clip_max)
            )
            total_iterations += iterations

        adversarial_images = torch.cat(adversarial_images, dim=0)

        with torch.no_grad():
            adv_classes = self.model(adversarial_images).argmax(dim=1)

        # Success is measured on the final images, not on the class that the
        # last DeepFool step was aiming for.
        success_rate = (adv_classes != original_classes).float().mean().item() * 100
        avg_iterations = total_iterations / batch_size

        perturbation_norm = torch.norm(
            (adversarial_images - images).reshape(batch_size, -1),
            p=2, dim=1
        ).mean().item()

        original_accuracy = None
        if labels is not None:
            labels = labels.to(original_classes.device)
            original_accuracy = (original_classes == labels).float().mean().item() * 100

        self.metrics = {
            'success_rate': success_rate,
            'avg_iterations': avg_iterations,
            'avg_perturbation': perturbation_norm,
            'original_accuracy': original_accuracy
        }

        return adversarial_images

    def get_minimal_perturbation(self,
                                 images: torch.Tensor,
                                 target_accuracy: float = 10.0) -> Tuple[torch.Tensor, float]:
        """
        Run the attack and report the resulting L-infinity perturbation size

        DeepFool has no epsilon budget; this method exists for interface
        parity with epsilon-based attacks and always emits a warning.

        Args:
            images: Clean images
            target_accuracy: Accepted for interface compatibility; not used

        Returns:
            Tuple of (adversarial images, epsilon), where epsilon is the
            mean per-sample L-infinity norm of the perturbation
        """
        warnings.warn("DeepFool doesn't use epsilon parameter like FGSM/PGD")

        adv_images = self.generate(images)

        # generate() returns tensors on self.device; align before subtracting.
        perturbation = adv_images - images.to(adv_images.device)
        epsilon = torch.norm(perturbation.reshape(perturbation.shape[0], -1),
                             p=float('inf'), dim=1).mean().item()

        return adv_images, epsilon

    def __call__(self, images: torch.Tensor, **kwargs) -> torch.Tensor:
        """Callable interface; forwards to :meth:`generate`."""
        return self.generate(images, **kwargs)
|
|
|
|
|
def create_deepfool_attack(model: nn.Module, max_iter: int = 50, **kwargs) -> DeepFoolAttack:
    """Build a DeepFoolAttack from keyword options.

    ``max_iter`` and every extra keyword argument are collected into the
    configuration dictionary handed to ``DeepFoolAttack``.
    """
    settings = dict(max_iter=max_iter)
    settings.update(kwargs)
    return DeepFoolAttack(model, settings)
|
|
|