"""
Carlini & Wagner (C&W) L2 Attack
Enterprise implementation with full error handling and optimization
Reference: Carlini & Wagner, "Towards Evaluating the Robustness of Neural Networks" (2017)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional, Dict, Any, Tuple
import time
class CarliniWagnerL2:
    """
    Carlini & Wagner L2 attack.

    Reference: Carlini & Wagner, "Towards Evaluating the Robustness of
    Neural Networks" (2017).

    Features:
    - CPU-friendly optimization with optional early stopping
    - Binary search over the trade-off constant ``c``
    - Confidence (kappa) margin support
    - Metrics helper (:meth:`attack_success_rate`)
    """

    def __init__(self, model: nn.Module, config: Optional[Dict[str, Any]] = None):
        """
        Initialize the C&W attack.

        Args:
            model: PyTorch classifier to attack; switched to eval mode and
                moved to the configured device.
            config: Optional attack configuration dictionary (defaults below).
        """
        self.model = model
        self.config = config or {}
        # Attack parameters with defaults
        self.confidence = self.config.get('confidence', 0.0)  # kappa margin
        self.max_iterations = self.config.get('max_iterations', 100)
        self.learning_rate = self.config.get('learning_rate', 0.01)
        self.binary_search_steps = self.config.get('binary_search_steps', 9)
        self.initial_const = self.config.get('initial_const', 1e-3)
        self.abort_early = self.config.get('abort_early', True)
        self.device = self.config.get('device', 'cpu')
        # Box constraint (valid pixel range)
        self.box_min = self.config.get('box_min', 0.0)
        self.box_max = self.config.get('box_max', 1.0)
        self.model.eval()
        self.model.to(self.device)

    def _tanh_space(self, x: torch.Tensor, boxmin: float, boxmax: float) -> torch.Tensor:
        """Map an unconstrained tensor into the [boxmin, boxmax] box via tanh."""
        return torch.tanh(x) * (boxmax - boxmin) / 2 + (boxmax + boxmin) / 2

    def _inverse_tanh_space(self, x: torch.Tensor, boxmin: float, boxmax: float) -> torch.Tensor:
        """Inverse of :meth:`_tanh_space`; clamps the argument to keep atanh finite."""
        return torch.atanh((2 * (x - boxmin) / (boxmax - boxmin) - 1).clamp(-1 + 1e-7, 1 - 1e-7))

    def _compute_loss(self,
                      adv_images: torch.Tensor,
                      images: torch.Tensor,
                      labels: torch.Tensor,
                      const: torch.Tensor,
                      targeted: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute the C&W objective: L2 distance + c * f(x').

        Args:
            adv_images: Candidate adversarial batch.
            images: Clean batch.
            labels: True labels (untargeted) or target labels (targeted).
            const: Per-sample trade-off constants ``c``.
            targeted: Whether ``labels`` are attack targets.

        Returns:
            total_loss, distance_loss, classification_loss
        """
        # L2 distance term
        l2_dist = torch.norm((adv_images - images).view(images.size(0), -1), p=2, dim=1)
        distance_loss = l2_dist.sum()
        logits = self.model(adv_images)
        # Logit of the label class, per sample (squeeze(1) keeps the batch dim)
        label_logits = logits.gather(1, labels.unsqueeze(1)).squeeze(1)
        # Highest logit among the *other* classes. Mask with -inf rather than
        # multiplying by 0: 0-masking is wrong when all other logits are negative.
        one_hot = F.one_hot(labels, num_classes=logits.size(1)).bool()
        other_logits = logits.masked_fill(one_hot, float('-inf')).max(dim=1)[0]
        if targeted:
            # Targeted: push the target class above every other class.
            margin = other_logits - label_logits
        else:
            # Untargeted: push the true class below some other class.
            # (The previous formulation had this sign inverted, which rewarded
            # correct classification instead of attacking it.)
            margin = label_logits - other_logits
        # f(x') = max(margin + kappa, 0)
        classification_loss = torch.clamp(margin + self.confidence, min=0.0)
        classification_loss = (const * classification_loss).sum()
        total_loss = distance_loss + classification_loss
        return total_loss, distance_loss, classification_loss

    def _optimize_single(self,
                         images: torch.Tensor,
                         labels: torch.Tensor,
                         const: float,
                         targeted: bool = False,
                         early_stop_threshold: float = 1e-4) -> Tuple[torch.Tensor, float, bool]:
        """
        One optimization run with a fixed constant ``c``.

        Intended to be called with a batch of one image (see :meth:`generate`);
        ``best_l2`` is tracked as a single scalar across the batch.

        Returns:
            adversarial_images, best_l2, attack_successful
        """
        batch_size = images.size(0)
        # Optimize in tanh space so the box constraint is satisfied by construction
        w = self._inverse_tanh_space(images, self.box_min, self.box_max).detach()
        w.requires_grad = True
        optimizer = torch.optim.Adam([w], lr=self.learning_rate)
        prev_loss = float('inf')  # loss snapshot for the early-stopping check
        best_l2 = float('inf')
        best_adv = images.clone()
        const_tensor = torch.full((batch_size,), const, device=self.device)
        attack_successful = False
        for iteration in range(self.max_iterations):
            adv_images = self._tanh_space(w, self.box_min, self.box_max)
            total_loss, distance_loss, classification_loss = self._compute_loss(
                adv_images, images, labels, const_tensor, targeted=targeted
            )
            # Record the smallest-perturbation successful example seen so far
            with torch.no_grad():
                preds = self.model(adv_images).argmax(dim=1)
                # Targeted attacks succeed on a *match* with the target label
                success_mask = (preds == labels) if targeted else (preds != labels)
                current_l2 = torch.norm((adv_images - images).view(batch_size, -1), p=2, dim=1)
                for i in range(batch_size):
                    if success_mask[i] and current_l2[i] < best_l2:
                        best_l2 = current_l2[i].item()
                        best_adv[i] = adv_images[i]
                        attack_successful = True
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            # Early stopping: compare against the loss from 10 iterations ago.
            # (Updating prev_loss every iteration, as before, made this check
            # compare consecutive iterations and defeat its purpose.)
            if self.abort_early and iteration % 10 == 0:
                if total_loss.item() > prev_loss * 0.9999:
                    break
                prev_loss = total_loss.item()
        return best_adv, best_l2, attack_successful

    def generate(self,
                 images: torch.Tensor,
                 labels: torch.Tensor,
                 targeted: bool = False,
                 target_labels: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Generate adversarial examples using the C&W attack.

        Args:
            images: Clean images [batch, channels, height, width].
            labels: True labels for the non-targeted attack.
            targeted: Whether to perform a targeted attack.
            target_labels: Target labels (required when ``targeted``).

        Returns:
            Adversarial images; samples where the attack never succeeded keep
            their clean image.

        Raises:
            ValueError: If ``targeted`` is set without ``target_labels``.
        """
        if targeted and target_labels is None:
            raise ValueError("target_labels required for targeted attack")
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)
        if targeted:
            # From here on, ``labels`` hold the attack targets.
            labels = target_labels.clone().detach().to(self.device)
        batch_size = images.size(0)
        # Binary search bounds for the per-sample constant c
        const_lower_bound = torch.zeros(batch_size, device=self.device)
        const_upper_bound = torch.ones(batch_size, device=self.device) * 1e10
        const = torch.ones(batch_size, device=self.device) * self.initial_const
        # Best results tracking
        best_l2 = torch.ones(batch_size, device=self.device) * float('inf')
        best_adv = images.clone()
        for binary_step in range(self.binary_search_steps):
            print(f" Binary search step {binary_step + 1}/{self.binary_search_steps}")
            # Optimize each sample independently with its current c
            for i in range(batch_size):
                const_i = const[i].item()
                adv_i, l2_i, success_i = self._optimize_single(
                    images[i:i+1], labels[i:i+1], const_i, targeted=targeted
                )
                if success_i:
                    # Success: tighten the upper bound and try a smaller c
                    const_upper_bound[i] = min(const_upper_bound[i], const_i)
                    if const_upper_bound[i] < 1e9:
                        const[i] = (const_lower_bound[i] + const_upper_bound[i]) / 2
                    # Keep the smallest-L2 adversarial example found
                    if l2_i < best_l2[i]:
                        best_l2[i] = l2_i
                        best_adv[i] = adv_i[0]
                else:
                    # Failure: raise the lower bound; grow c if no success yet
                    const_lower_bound[i] = max(const_lower_bound[i], const_i)
                    if const_upper_bound[i] < 1e9:
                        const[i] = (const_lower_bound[i] + const_upper_bound[i]) / 2
                    else:
                        const[i] = const[i] * 10
        return best_adv

    def attack_success_rate(self,
                            images: torch.Tensor,
                            labels: torch.Tensor,
                            adversarial_images: torch.Tensor) -> Dict[str, float]:
        """
        Calculate attack success metrics.

        NOTE(review): success here is measured as ``pred != true label``, i.e.
        untargeted success; for targeted attacks interpret accordingly.

        Args:
            images: Original images.
            labels: True labels.
            adversarial_images: Generated adversarial images.

        Returns:
            Dictionary of accuracy, success-rate, perturbation-norm and
            confidence metrics (percentages where noted).
        """
        images = images.to(self.device)
        labels = labels.to(self.device)
        adversarial_images = adversarial_images.to(self.device)
        with torch.no_grad():
            # Clean predictions
            orig_outputs = self.model(images)
            orig_preds = orig_outputs.argmax(dim=1)
            orig_accuracy = (orig_preds == labels).float().mean().item()
            # Adversarial predictions
            adv_outputs = self.model(adversarial_images)
            adv_preds = adv_outputs.argmax(dim=1)
            success_rate = (adv_preds != labels).float().mean().item()
            # Perturbation norms (per sample, then averaged)
            perturbation = adversarial_images - images
            l2_norm = torch.norm(perturbation.view(perturbation.size(0), -1), p=2, dim=1)
            linf_norm = torch.norm(perturbation.view(perturbation.size(0), -1), p=float('inf'), dim=1)
            # Mean top-1 softmax confidence before/after the attack
            orig_probs = F.softmax(orig_outputs, dim=1)
            adv_probs = F.softmax(adv_outputs, dim=1)
            orig_confidence = orig_probs.max(dim=1)[0].mean().item()
            adv_confidence = adv_probs.max(dim=1)[0].mean().item()
            # Norms restricted to successfully attacked samples
            success_mask = (adv_preds != labels)
            if success_mask.any():
                successful_l2 = l2_norm[success_mask].mean().item()
                successful_linf = linf_norm[success_mask].mean().item()
            else:
                successful_l2 = 0.0
                successful_linf = 0.0
        return {
            'original_accuracy': orig_accuracy * 100,
            'attack_success_rate': success_rate * 100,
            'avg_l2_perturbation': l2_norm.mean().item(),
            'avg_linf_perturbation': linf_norm.mean().item(),
            'successful_l2_perturbation': successful_l2,
            'successful_linf_perturbation': successful_linf,
            'original_confidence': orig_confidence,
            'adversarial_confidence': adv_confidence,
            'confidence_threshold': self.confidence
        }

    def __call__(self, images: torch.Tensor, labels: torch.Tensor, **kwargs) -> torch.Tensor:
        """Callable interface; forwards to :meth:`generate`."""
        return self.generate(images, labels, **kwargs)
class FastCarliniWagnerL2:
    """
    Faster C&W variant for CPU: fixed constant ``c`` and fewer iterations.

    Performs an untargeted attack and assumes inputs lie in [0, 1].
    Suitable for larger batches and quicker evaluations.
    """

    def __init__(self, model: nn.Module, config: Optional[Dict[str, Any]] = None):
        """
        Args:
            model: PyTorch classifier to attack (eval mode is enforced).
            config: Optional dict with 'const', 'iterations',
                'learning_rate' and 'device' keys.
        """
        self.model = model
        self.config = config or {}
        self.const = self.config.get('const', 1.0)
        self.iterations = self.config.get('iterations', 50)
        self.learning_rate = self.config.get('learning_rate', 0.01)
        self.device = self.config.get('device', 'cpu')
        self.model.eval()
        self.model.to(self.device)

    def generate(self, images: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Run the fixed-constant C&W optimization.

        Args:
            images: Clean images in [0, 1].
            labels: True labels.

        Returns:
            Adversarial images (detached, guaranteed to lie in [0, 1]).
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)
        batch_size = images.size(0)
        # Optimize in tanh space so outputs always stay inside [0, 1];
        # 2*x - 1 maps [0, 1] onto [-1, 1], clamped to keep atanh finite.
        w = torch.atanh((2 * images - 1).clamp(-1 + 1e-7, 1 - 1e-7)).detach()
        w.requires_grad = True
        optimizer = torch.optim.Adam([w], lr=self.learning_rate)
        for _ in range(self.iterations):
            adv_images = torch.tanh(w) * 0.5 + 0.5
            # L2 distance to the clean image
            l2_dist = torch.norm((adv_images - images).view(batch_size, -1), p=2, dim=1)
            logits = self.model(adv_images)
            correct_logits = logits.gather(1, labels.unsqueeze(1)).squeeze(1)
            # Best competing logit; mask the true class with -inf (masking by
            # multiplying with 0 breaks when every other logit is negative).
            one_hot = F.one_hot(labels, num_classes=logits.size(1)).bool()
            other_logits = logits.masked_fill(one_hot, float('-inf')).max(dim=1)[0]
            # Untargeted C&W loss: positive while the true class still wins.
            # (The previous sign was inverted, giving zero gradient on every
            # correctly-classified input, so the attack never made progress.)
            classification_loss = torch.clamp(correct_logits - other_logits, min=0.0)
            loss = torch.mean(self.const * classification_loss + l2_dist)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Detach so callers don't accidentally keep the optimization graph
        return (torch.tanh(w) * 0.5 + 0.5).detach()
# Factory functions
def create_cw_attack(model: nn.Module, const: float = 1e-3, **kwargs) -> CarliniWagnerL2:
    """Build a :class:`CarliniWagnerL2` attack with the given initial constant."""
    settings = {'initial_const': const}
    # Explicit kwargs win over the positional const, matching dict-literal order.
    settings.update(kwargs)
    return CarliniWagnerL2(model, settings)
def create_fast_cw_attack(model: nn.Module, const: float = 1.0, **kwargs) -> FastCarliniWagnerL2:
    """Build a :class:`FastCarliniWagnerL2` attack with the given constant."""
    settings = {'const': const}
    # Explicit kwargs win over the positional const, matching dict-literal order.
    settings.update(kwargs)
    return FastCarliniWagnerL2(model, settings)