File size: 7,705 Bytes
f4bee9e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 | """
Projected Gradient Descent (PGD) Attack
Enterprise implementation with multiple restarts and adaptive step size
"""
import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Tuple, Dict, Any, Union
from attacks.fgsm import FGSMAttack
class PGDAttack:
"""PGD attack with random restarts and adaptive step size"""
def __init__(self, model: nn.Module, config: Optional[Dict[str, Any]] = None):
"""
Initialize PGD attack
Args:
model: PyTorch model to attack
config: Attack configuration dictionary
"""
self.model = model
self.config = config or {}
# Default parameters
self.epsilon = self.config.get('epsilon', 0.3)
self.alpha = self.config.get('alpha', 0.01)
self.steps = self.config.get('steps', 10)
self.random_start = self.config.get('random_start', True)
self.targeted = self.config.get('targeted', False)
self.clip_min = self.config.get('clip_min', 0.0)
self.clip_max = self.config.get('clip_max', 1.0)
self.device = self.config.get('device', 'cpu')
self.restarts = self.config.get('restarts', 1)
self.criterion = nn.CrossEntropyLoss()
self.model.eval()
def _project_onto_l_inf_ball(self,
x: torch.Tensor,
perturbation: torch.Tensor) -> torch.Tensor:
"""Project perturbation onto Linf epsilon-ball"""
return torch.clamp(perturbation, -self.epsilon, self.epsilon)
def _random_initialization(self, x: torch.Tensor) -> torch.Tensor:
"""Random initialization within epsilon-ball"""
delta = torch.empty_like(x).uniform_(-self.epsilon, self.epsilon)
x_adv = torch.clamp(x + delta, self.clip_min, self.clip_max)
return x_adv - x # Return delta
def _single_restart(self,
images: torch.Tensor,
labels: torch.Tensor,
target_labels: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Single PGD restart"""
batch_size = images.shape[0]
# Initialize adversarial examples
if self.random_start:
delta = self._random_initialization(images)
else:
delta = torch.zeros_like(images)
x_adv = images + delta
# PGD iterations
for step in range(self.steps):
x_adv = x_adv.clone().detach().requires_grad_(True)
# Forward pass
outputs = self.model(x_adv)
# Loss calculation
if self.targeted:
loss = -self.criterion(outputs, target_labels)
else:
loss = self.criterion(outputs, labels)
# Gradient calculation
grad = torch.autograd.grad(loss, [x_adv])[0]
# PGD update: x' = x + a * sign(?x)
if self.targeted:
delta = delta - self.alpha * grad.sign()
else:
delta = delta + self.alpha * grad.sign()
# Project onto epsilon-ball
delta = self._project_onto_l_inf_ball(images, delta)
# Update adversarial examples
x_adv = torch.clamp(images + delta, self.clip_min, self.clip_max)
return x_adv
def generate(self,
images: torch.Tensor,
labels: torch.Tensor,
target_labels: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Generate adversarial examples with multiple restarts
Args:
images: Clean images
labels: True labels
target_labels: Target labels for targeted attack
Returns:
Best adversarial examples across restarts
"""
if self.targeted and target_labels is None:
raise ValueError("target_labels required for targeted attack")
images = images.clone().detach().to(self.device)
labels = labels.clone().detach().to(self.device)
if target_labels is not None:
target_labels = target_labels.clone().detach().to(self.device)
# Initialize best adversarial examples
best_adv = None
best_loss = -float('inf') if self.targeted else float('inf')
# Multiple restarts
for restart in range(self.restarts):
# Generate adversarial examples for this restart
x_adv = self._single_restart(images, labels, target_labels)
# Calculate loss
with torch.no_grad():
outputs = self.model(x_adv)
if self.targeted:
loss = -self.criterion(outputs, target_labels)
else:
loss = self.criterion(outputs, labels)
# Update best adversarial examples
if self.targeted:
if loss > best_loss:
best_loss = loss
best_adv = x_adv
else:
if loss < best_loss:
best_loss = loss
best_adv = x_adv
return best_adv
def adaptive_attack(self,
images: torch.Tensor,
labels: torch.Tensor,
initial_epsilon: float = 0.1,
max_iterations: int = 20) -> Tuple[torch.Tensor, float]:
"""
Adaptive PGD that finds minimal epsilon for successful attack
Args:
images: Clean images
labels: True labels
initial_epsilon: Starting epsilon
max_iterations: Maximum binary search iterations
Returns:
Tuple of (adversarial examples, optimal epsilon)
"""
eps_low = 0.0
eps_high = initial_epsilon * 2
# Find upper bound
for _ in range(10):
self.epsilon = eps_high
adv_images = self.generate(images, labels)
with torch.no_grad():
preds = self.model(adv_images).argmax(dim=1)
success_rate = (preds != labels).float().mean().item()
if success_rate > 0.9: # 90% success rate
break
eps_high *= 2
# Binary search for optimal epsilon
best_epsilon = eps_high
best_adv = adv_images
for _ in range(max_iterations):
epsilon = (eps_low + eps_high) / 2
self.epsilon = epsilon
adv_images = self.generate(images, labels)
with torch.no_grad():
preds = self.model(adv_images).argmax(dim=1)
success_rate = (preds != labels).float().mean().item()
if success_rate > 0.9: # 90% success threshold
eps_high = epsilon
best_epsilon = epsilon
best_adv = adv_images
else:
eps_low = epsilon
return best_adv, best_epsilon
def __call__(self, images: torch.Tensor, labels: torch.Tensor, **kwargs) -> torch.Tensor:
"""Callable interface"""
return self.generate(images, labels, **kwargs)
def create_pgd_attack(model: nn.Module, epsilon: float = 0.3, **kwargs) -> PGDAttack:
"""Factory function for creating PGD attack"""
config = {'epsilon': epsilon, **kwargs}
return PGDAttack(model, config)
|