File size: 9,857 Bytes
f4bee9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
"""
DeepFool Attack Implementation
Enterprise-grade with support for multi-class and binary classification
"""

import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Dict, Any, Tuple, List
import warnings

class DeepFoolAttack:
    """Untargeted DeepFool attack that searches for a small misclassifying perturbation.

    Each sample is pushed iteratively towards the nearest linearised decision
    boundary until the model's arg-max prediction changes or ``max_iter`` is
    reached. The accumulated perturbation is then shrunk by a bisection search.

    Recognised ``config`` keys (all optional):
        max_iter: Maximum number of boundary steps per sample (default 50).
        overshoot: Fraction by which the accumulated perturbation is enlarged
            so the sample actually crosses the boundary (default 0.02).
        num_classes: Maximum number of classes considered per sample,
            including the original class (default 10).
        clip_min, clip_max: Valid input range used for clamping (default 0.0, 1.0).
        device: Device the inputs are moved to (default 'cpu'). The model is
            not moved; it must already live on that device.
    """

    def __init__(self, model: nn.Module, config: Optional[Dict[str, Any]] = None):
        """
        Initialize DeepFool attack

        Args:
            model: PyTorch model to attack; it is switched to eval mode
            config: Attack configuration dictionary (see class docstring)
        """
        self.model = model
        self.config = config or {}

        # Default parameters
        self.max_iter = self.config.get('max_iter', 50)
        self.overshoot = self.config.get('overshoot', 0.02)
        self.num_classes = self.config.get('num_classes', 10)
        self.clip_min = self.config.get('clip_min', 0.0)
        self.clip_max = self.config.get('clip_max', 1.0)
        self.device = self.config.get('device', 'cpu')

        # Filled in by generate(); defined here so that reading it before the
        # first attack does not raise AttributeError.
        self.metrics: Dict[str, Any] = {}

        self.model.eval()

    def _predict(self, x: torch.Tensor) -> int:
        """Return the arg-max class index the model assigns to a single sample."""
        with torch.no_grad():
            return int(self.model(x).argmax(dim=1).item())

    def _compute_gradients(self,
                          x: torch.Tensor,
                          target_class: Optional[int] = None,
                          classes: Optional[List[int]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the input gradient of each selected class score

        Args:
            x: Input tensor holding a single sample (only row 0 of the model
               output is differentiated)
            target_class: Optional class index to leave out
            classes: Optional explicit list of class indices, in the order the
                     gradients should be returned. Defaults to the first
                     ``num_classes`` outputs (capped at the model's output width).

        Returns:
            Tuple of (gradients, outputs). ``gradients`` stacks one gradient
            (same shape as ``x``) per selected class along dim 0, in selection
            order; ``outputs`` is the detached model output.

        Raises:
            ValueError: If no class is left to differentiate.
        """
        x = x.clone().detach().requires_grad_(True)

        # enable_grad makes this work even when the caller wrapped the attack
        # in torch.no_grad().
        with torch.enable_grad():
            outputs = self.model(x)

            if classes is None:
                classes = list(range(min(self.num_classes, outputs.shape[1])))
            selected = [k for k in classes if target_class is None or k != target_class]
            if not selected:
                raise ValueError("No classes left to compute gradients for")

            gradients = []
            last = len(selected) - 1
            for pos, k in enumerate(selected):
                # autograd.grad differentiates w.r.t. the input only, so the
                # model parameters' .grad buffers are left untouched.
                (grad,) = torch.autograd.grad(
                    outputs[0, k], x,
                    retain_graph=pos < last,
                    allow_unused=True,
                )
                if grad is None:
                    # The score does not depend on the input at all.
                    grad = torch.zeros_like(x)
                gradients.append(grad)

        return torch.stack(gradients, dim=0), outputs.detach()

    def _binary_search(self,
                      x: torch.Tensor,
                      perturbation: torch.Tensor,
                      original_class: int,
                      target_class: int,
                      max_search_iter: int = 10) -> torch.Tensor:
        """
        Bisection search for a smaller scaling of a perturbation

        Args:
            x: Original image
            perturbation: Initial perturbation
            original_class: Original predicted class
            target_class: Unused; kept for interface compatibility. Because the
                          attack is untargeted, any class other than
                          ``original_class`` counts as misclassification.
            max_search_iter: Maximum binary search iterations

        Returns:
            The smallest tested scaling of ``perturbation`` that changes the
            prediction, or ``perturbation`` itself if no tested scaling does.
        """
        eps_low = 0.0
        eps_high = 1.0
        best_perturbation = perturbation

        for _ in range(max_search_iter):
            eps = (eps_low + eps_high) / 2
            candidate = eps * perturbation
            x_adv = torch.clamp(x + candidate, self.clip_min, self.clip_max)

            if self._predict(x_adv) != original_class:
                # Still fooled: remember it and try an even smaller scale.
                eps_high = eps
                best_perturbation = candidate
            else:
                eps_low = eps

        return best_perturbation

    def _deepfool_single(self, x: torch.Tensor, original_class: int) -> Tuple[torch.Tensor, int, int]:
        """
        DeepFool for a single sample

        Args:
            x: Input tensor holding exactly one sample (batch dimension of 1)
            original_class: Original predicted class

        Returns:
            Tuple of (perturbation, target_class, iterations). ``target_class``
            is the class whose boundary was targeted in the last iteration, or
            ``original_class`` if no iteration ran. ``perturbation`` is all
            zeros if no iteration ran.
        """
        x = x.to(self.device).detach()

        with torch.no_grad():
            logits = self.model(x)
        current_class = int(logits.argmax(dim=1).item())
        scores = logits[0].tolist()

        # Candidate boundaries: the highest-scoring classes other than the
        # original one, so that at most num_classes classes are differentiated.
        others = sorted(
            (k for k in range(len(scores)) if k != original_class),
            key=lambda k: scores[k],
            reverse=True,
        )[:max(self.num_classes - 1, 0)]
        classes = [original_class] + others

        r_total = torch.zeros_like(x)
        x_adv = x.clone()
        target_class = original_class
        iterations = 0

        # Without any competing class there is no boundary to cross.
        while others and current_class == original_class and iterations < self.max_iter:
            gradients, outputs = self._compute_gradients(x_adv, classes=classes)
            grad_orig = gradients[0]  # classes[0] is the original class
            f_orig = outputs[0, original_class]

            # Find the closest linearised decision boundary.
            closest = None
            for offset, k in enumerate(others, start=1):
                w_k = gradients[offset] - grad_orig
                f_diff = torch.abs(outputs[0, k] - f_orig)
                w_norm = torch.norm(w_k.flatten())
                ratio = (f_diff / (w_norm + 1e-8)).item()
                if closest is None or ratio < closest[0]:
                    closest = (ratio, k, w_k, f_diff, w_norm)
            _, target_class, w, f_diff, w_norm = closest

            # Minimal step onto that boundary under the linear approximation.
            step = (f_diff + 1e-8) / (w_norm ** 2 + 1e-8) * w
            r_total = r_total + step

            # The overshoot is applied inside the loop so the sample ends up
            # past the boundary instead of stalling exactly on it.
            x_adv = torch.clamp(x + (1 + self.overshoot) * r_total,
                                self.clip_min, self.clip_max)
            current_class = self._predict(x_adv)
            iterations += 1

        # Same expression as used for x_adv above, so clamp(x + perturbation)
        # reproduces the last adversarial sample.
        perturbation = (1 + self.overshoot) * r_total

        # Shrinking only makes sense once the prediction has actually flipped.
        if iterations > 0 and current_class != original_class:
            perturbation = self._binary_search(x, perturbation, original_class, target_class)

        return perturbation, target_class, iterations

    def generate(self, images: torch.Tensor, labels: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Generate adversarial examples

        Also stores ``self.metrics`` with the keys ``success_rate`` (percentage
        of samples whose prediction changed), ``avg_iterations``,
        ``avg_perturbation`` (mean L2 norm) and ``original_accuracy``
        (percentage, or None when ``labels`` is not given or the batch is empty).

        Args:
            images: Clean images, batch along dim 0
            labels: Optional labels, used only for ``original_accuracy``

        Returns:
            Adversarial images on ``self.device``, clamped to
            ``[clip_min, clip_max]``
        """
        images = images.clone().detach().to(self.device)
        batch_size = images.shape[0]

        if batch_size == 0:
            # Nothing to attack; avoid torch.cat([]) and division by zero.
            self.metrics = {
                'success_rate': 0.0,
                'avg_iterations': 0.0,
                'avg_perturbation': 0.0,
                'original_accuracy': None
            }
            return images

        # Get original predictions
        with torch.no_grad():
            original_classes = self.model(images).argmax(dim=1)

        adversarial_images = []
        total_iterations = 0

        # Process each image separately
        for i in range(batch_size):
            x = images[i:i+1]
            original_class = int(original_classes[i].item())

            perturbation, _, iterations = self._deepfool_single(x, original_class)

            adversarial_images.append(
                torch.clamp(x + perturbation, self.clip_min, self.clip_max)
            )
            total_iterations += iterations

        adversarial_images = torch.cat(adversarial_images, dim=0)

        # Calculate metrics
        with torch.no_grad():
            adv_classes = self.model(adversarial_images).argmax(dim=1)

            # Success is judged on the final adversarial images, not on
            # whether a target boundary was merely selected.
            success_rate = (adv_classes != original_classes).float().mean().item() * 100
            avg_iterations = total_iterations / batch_size

            perturbation_norm = torch.norm(
                (adversarial_images - images).reshape(batch_size, -1),
                p=2, dim=1
            ).mean().item()

            original_accuracy = None
            if labels is not None:
                labels = labels.to(original_classes.device)
                original_accuracy = (original_classes == labels).float().mean().item() * 100

        # Store metrics
        self.metrics = {
            'success_rate': success_rate,
            'avg_iterations': avg_iterations,
            'avg_perturbation': perturbation_norm,
            'original_accuracy': original_accuracy
        }

        return adversarial_images

    def get_minimal_perturbation(self,
                                images: torch.Tensor,
                                target_accuracy: float = 10.0) -> Tuple[torch.Tensor, float]:
        """
        Run the attack and report the effective epsilon it used

        Args:
            images: Clean images
            target_accuracy: Unused; DeepFool has no epsilon to tune. Kept for
                             interface compatibility.

        Returns:
            Tuple of (adversarial images, epsilon), where epsilon is the mean
            per-sample L-infinity norm of the perturbation
        """
        warnings.warn("DeepFool doesn't use epsilon parameter like FGSM/PGD")

        # Generate adversarial examples
        adv_images = self.generate(images)

        # Calculate effective epsilon (Linf norm); generate() returns tensors
        # on self.device, so compare on that device.
        perturbation = adv_images - images.to(adv_images.device)
        epsilon = torch.norm(perturbation.reshape(perturbation.shape[0], -1),
                           p=float('inf'), dim=1).mean().item()

        return adv_images, epsilon

    def __call__(self, images: torch.Tensor, **kwargs) -> torch.Tensor:
        """Callable interface; forwards to generate()"""
        return self.generate(images, **kwargs)

def create_deepfool_attack(model: nn.Module, max_iter: int = 50, **kwargs) -> DeepFoolAttack:
    """Build a DeepFoolAttack from keyword arguments.

    Args:
        model: PyTorch model to attack
        max_iter: Stored under the 'max_iter' config key
        **kwargs: Additional entries copied into the config dictionary

    Returns:
        A DeepFoolAttack constructed with the assembled config
    """
    settings: Dict[str, Any] = {'max_iter': max_iter}
    settings.update(kwargs)
    attack = DeepFoolAttack(model, settings)
    return attack