| """ |
| Fast Gradient Sign Method (FGSM) Attack |
| Fixed device validation issue |
| """ |
|
|
import torch
import torch.nn as nn
from typing import Optional, Dict, Any
|
|
class FGSMAttack:
    """FGSM attack with targeted and non-targeted variants.

    Non-targeted: x_adv = x + epsilon * sign(grad_x CE(f(x), y_true))
    Targeted:     x_adv = x - epsilon * sign(grad_x CE(f(x), y_target))
    """

    def __init__(self, model: nn.Module, config: Optional[Dict[str, Any]] = None):
        """
        Initialize FGSM attack.

        Args:
            model: PyTorch model to attack
            config: Attack configuration dictionary
        """
        self.model = model
        self.config = config or {}

        self.epsilon = self.config.get('epsilon', 0.15)
        self.targeted = self.config.get('targeted', False)
        self.clip_min = self.config.get('clip_min', 0.0)
        self.clip_max = self.config.get('clip_max', 1.0)
        self.device = self.config.get('device', 'cpu')

        self.criterion = nn.CrossEntropyLoss()
        self.model.eval()
        self.model.to(self.device)

    def _validate_inputs(self, images: torch.Tensor, labels: torch.Tensor) -> None:
        """Validate input tensor types; device placement is handled in generate()."""
        if not isinstance(images, torch.Tensor):
            raise TypeError(f"images must be torch.Tensor, got {type(images)}")
        if not isinstance(labels, torch.Tensor):
            raise TypeError(f"labels must be torch.Tensor, got {type(labels)}")

    def generate(self,
                 images: torch.Tensor,
                 labels: torch.Tensor,
                 target_labels: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Generate adversarial examples.

        Args:
            images: Clean images [batch, channels, height, width]
            labels: True labels for non-targeted attack
            target_labels: Target labels for targeted attack (optional)

        Returns:
            Adversarial images
        """
        images = images.to(self.device)
        labels = labels.to(self.device)
        if target_labels is not None:
            target_labels = target_labels.to(self.device)

        self._validate_inputs(images, labels)

        if self.targeted and target_labels is None:
            raise ValueError("target_labels required for targeted attack")

        # Detach from any existing graph so gradients flow only from this forward pass
        images = images.clone().detach()
        labels = labels.clone().detach()
        if target_labels is not None:
            target_labels = target_labels.clone().detach()

        images.requires_grad = True

        outputs = self.model(images)

        if self.targeted:
            # Negate the loss: descending the target-class loss equals ascending its
            # negation, so a single "+ perturbation" step below serves both variants.
            # (Negating the loss AND subtracting the perturbation would cancel out.)
            loss = -self.criterion(outputs, target_labels)
        else:
            loss = self.criterion(outputs, labels)

        self.model.zero_grad()
        loss.backward()

        # FGSM step: move each pixel epsilon in the sign direction of the loss gradient
        perturbation = self.epsilon * images.grad.sign()
        adversarial_images = images + perturbation

        # Keep adversarial images within the valid input range
        adversarial_images = torch.clamp(adversarial_images, self.clip_min, self.clip_max)

        return adversarial_images.detach()

    def attack_success_rate(self,
                            images: torch.Tensor,
                            labels: torch.Tensor,
                            adversarial_images: torch.Tensor,
                            target_labels: Optional[torch.Tensor] = None) -> Dict[str, float]:
        """
        Calculate attack success metrics.

        Args:
            images: Original images
            labels: True labels
            adversarial_images: Generated adversarial images
            target_labels: Target labels (required for targeted attacks)

        Returns:
            Dictionary of metrics
        """
        images = images.to(self.device)
        labels = labels.to(self.device)
        adversarial_images = adversarial_images.to(self.device)

        with torch.no_grad():
            # Clean accuracy as a baseline
            orig_outputs = self.model(images)
            orig_preds = orig_outputs.argmax(dim=1)
            orig_accuracy = (orig_preds == labels).float().mean().item()

            adv_outputs = self.model(adversarial_images)
            adv_preds = adv_outputs.argmax(dim=1)

            if self.targeted:
                # Targeted success: the model predicts the attacker's chosen class
                if target_labels is None:
                    raise ValueError("target_labels required for targeted attack metrics")
                target_labels = target_labels.to(self.device)
                success = (adv_preds == target_labels).float().mean().item()
            else:
                # Non-targeted success: the prediction no longer matches the true label
                success = (adv_preds != labels).float().mean().item()

            # Mean top-1 softmax confidence before and after the attack
            orig_confidence = torch.softmax(orig_outputs, dim=1).max(dim=1)[0].mean().item()
            adv_confidence = torch.softmax(adv_outputs, dim=1).max(dim=1)[0].mean().item()

            # Per-example perturbation norms, averaged over the batch
            perturbation = adversarial_images - images
            l2_norm = torch.norm(perturbation.view(perturbation.size(0), -1), p=2, dim=1).mean().item()
            linf_norm = torch.norm(perturbation.view(perturbation.size(0), -1), p=float('inf'), dim=1).mean().item()

        return {
            'original_accuracy': orig_accuracy * 100,
            'attack_success_rate': success * 100,
            'original_confidence': orig_confidence,
            'adversarial_confidence': adv_confidence,
            'perturbation_l2': l2_norm,
            'perturbation_linf': linf_norm,
            'epsilon': self.epsilon
        }

    def __call__(self, images: torch.Tensor, labels: torch.Tensor, **kwargs) -> torch.Tensor:
        """Callable interface; forwards to generate()."""
        return self.generate(images, labels, **kwargs)
|
|
def create_fgsm_attack(model: nn.Module, epsilon: float = 0.15, **kwargs) -> FGSMAttack:
    """Factory function for creating an FGSM attack."""
    config = {'epsilon': epsilon, **kwargs}
    return FGSMAttack(model, config)
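

# Minimal usage sketch showing the intended call sequence for both attack variants.
# `_DemoNet` and the random batch below are hypothetical stand-ins for a real model
# and dataset, chosen only for illustration; swap in your own in practice.
if __name__ == "__main__":
    class _DemoNet(nn.Module):
        """Tiny classifier for 1x28x28 inputs, used only for this demo."""

        def __init__(self, num_classes: int = 10):
            super().__init__()
            self.net = nn.Sequential(
                nn.Flatten(),
                nn.Linear(28 * 28, 64),
                nn.ReLU(),
                nn.Linear(64, num_classes),
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.net(x)

    model = _DemoNet()

    # Random stand-in batch; an untrained model gives meaningless metrics,
    # but the code path is the same as with real data.
    images = torch.rand(8, 1, 28, 28)
    labels = torch.randint(0, 10, (8,))

    # Non-targeted attack: push predictions away from the true labels
    attack = create_fgsm_attack(model, epsilon=0.1, device='cpu')
    adv_images = attack(images, labels)
    metrics = attack.attack_success_rate(images, labels, adv_images)
    for name, value in metrics.items():
        print(f"{name}: {value:.4f}")

    # Targeted attack: push every prediction toward class 0
    targeted_attack = create_fgsm_attack(model, epsilon=0.1, targeted=True)
    target_labels = torch.zeros(8, dtype=torch.long)
    adv_targeted = targeted_attack(images, labels, target_labels=target_labels)
    t_metrics = targeted_attack.attack_success_rate(images, labels, adv_targeted,
                                                    target_labels=target_labels)
    print(f"targeted success rate: {t_metrics['attack_success_rate']:.2f}%")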
|
|