File size: 7,276 Bytes
216e171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
"""
Adversarial Attack Demo — FGSM & PGD
Courses: 215 AI Safety ch1-ch2
"""

import json
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as T
import gradio as gr
from PIL import Image

# ---------------------------------------------------------------------------
# Model & preprocessing
# ---------------------------------------------------------------------------
# CPU-only on purpose: a demo-sized workload, keeps the app deployable anywhere.
device = torch.device("cpu")
# Frozen pretrained classifier used as the attack target; eval() disables
# dropout/batch-norm updates so gradients reflect inference behavior.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).eval().to(device)

# Standard ImageNet channel statistics expected by the pretrained weights.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Produces a (3, 224, 224) float tensor in [0, 1] ("pixel space").
# Normalization is deliberately kept OUT of this pipeline so the attacks
# can perturb and clamp in pixel space.
preprocess = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
])

normalize = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
# Algebraic inverse of `normalize`: x = (x_norm * std) + mean, expressed as
# another Normalize with mean' = -mean/std and std' = 1/std.
inv_normalize = T.Normalize(
    mean=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
    std=[1 / s for s in IMAGENET_STD],
)

# Load ImageNet class labels
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
try:
    import urllib.request
    with urllib.request.urlopen(LABELS_URL) as resp:
        LABELS = json.loads(resp.read().decode())
except Exception:
    # Best-effort fallback: offline (or fetch failure) degrades to numeric
    # class names rather than crashing the demo at import time.
    LABELS = [str(i) for i in range(1000)]


def get_top3(logits: torch.Tensor, k: int = 3):
    """Return the top-k predictions for a single-image logits tensor.

    Args:
        logits: Raw model output of shape (1, num_classes).
        k: Number of predictions to return (default 3, preserving the
           original behavior behind the function's name).

    Returns:
        List of (label, probability) tuples, ordered by descending probability.
    """
    probs = F.softmax(logits, dim=1)[0]
    topk = torch.topk(probs, k)
    # int(idx) converts the 0-dim index tensor to a plain int before
    # indexing the LABELS list.
    return [(LABELS[int(idx)], float(prob)) for prob, idx in zip(topk.values, topk.indices)]


# ---------------------------------------------------------------------------
# Attack implementations
# ---------------------------------------------------------------------------
def fgsm_attack(img_tensor: torch.Tensor, epsilon: float) -> torch.Tensor:
    """Single-step FGSM (untargeted).

    The gradient is taken with respect to the *pixel-space* image (the
    normalization is applied inside the autograd graph), so `epsilon` is a
    true L-inf budget in [0, 1] pixel units. The previous implementation
    stepped in normalized space, which silently scaled the budget by
    1/std per channel and disagreed with the pixel-space L-inf metric
    reported to the user.

    Args:
        img_tensor: Pixel-space image tensor of shape (C, H, W) in [0, 1].
        epsilon: Maximum per-pixel perturbation (pixel space).

    Returns:
        Adversarial image tensor of shape (C, H, W), clamped to [0, 1].
    """
    pixel = img_tensor.clone().unsqueeze(0).to(device)
    pixel.requires_grad = True
    # T.Normalize (inplace=False) clones internally, so gradients flow
    # back through it to `pixel`.
    output = model(normalize(pixel))
    # Untargeted attack: ascend the loss w.r.t. the model's own top-1
    # prediction. detach() keeps the target out of the graph.
    target = output.argmax(1).detach()
    loss = F.cross_entropy(output, target)
    loss.backward()
    perturbed = pixel + epsilon * pixel.grad.sign()
    # Keep the result a valid image.
    return torch.clamp(perturbed.squeeze(0).detach(), 0, 1)


def pgd_attack(
    img_tensor: torch.Tensor,
    epsilon: float,
    alpha: float,
    num_steps: int,
) -> torch.Tensor:
    """Multi-step PGD (untargeted), operating entirely in pixel space.

    Two fixes over the previous implementation:
    1. The gradient step is taken w.r.t. the pixel-space image, so `alpha`
       and `epsilon` live in the same [0, 1] units as the projection and
       the reported metrics (previously the step was applied in
       normalized space while the projection was in pixel space).
    2. The attack target is the *clean* image's predicted class, computed
       once up front. Re-deriving the target from the current adversarial
       output each step made the loss chase a moving label as soon as the
       prediction flipped, partially undoing the attack.

    Args:
        img_tensor: Pixel-space image tensor of shape (C, H, W) in [0, 1].
        epsilon: L-inf radius of the projection ball (pixel space).
        alpha: Per-step size (pixel space).
        num_steps: Number of PGD iterations.

    Returns:
        Adversarial image tensor of shape (C, H, W), clamped to [0, 1].
    """
    orig = img_tensor.clone()

    # Fixed untargeted label: the model's prediction on the clean image.
    with torch.no_grad():
        target = model(normalize(orig.clone()).unsqueeze(0).to(device)).argmax(1)

    perturbed = img_tensor.clone()
    for _ in range(num_steps):
        pixel = perturbed.clone().unsqueeze(0).to(device)
        pixel.requires_grad = True
        output = model(normalize(pixel))
        loss = F.cross_entropy(output, target)
        loss.backward()
        # Ascend the loss in pixel space.
        stepped = (pixel + alpha * pixel.grad.sign()).squeeze(0).detach()
        # Project back onto the L-inf epsilon-ball around the original,
        # then clamp to the valid image range.
        perturbation = torch.clamp(stepped - orig, -epsilon, epsilon)
        perturbed = torch.clamp(orig + perturbation, 0, 1)

    return perturbed


# ---------------------------------------------------------------------------
# Main function
# ---------------------------------------------------------------------------
def attack(
    image: Image.Image,
    method: str,
    epsilon: float,
    pgd_steps: int,
    pgd_alpha: float,
):
    """Run the selected attack on *image*.

    Returns a 4-tuple for the Gradio outputs: the original image, the
    10x-amplified perturbation, the adversarial image (all HxWxC uint8
    arrays), and a markdown metrics report.
    """
    # Guard clause: nothing uploaded yet.
    if image is None:
        return None, None, None, ""

    clean = preprocess(image.convert("RGB"))

    def predict_top3(pixel_tensor):
        # Classify a pixel-space tensor without tracking gradients.
        with torch.no_grad():
            return get_top3(model(normalize(pixel_tensor).unsqueeze(0)))

    def to_uint8(pixel_tensor):
        # (C, H, W) float in [0, 1] -> (H, W, C) uint8 for display.
        return (pixel_tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

    orig_pred = predict_top3(clean)

    # Craft the adversarial example with the chosen method.
    adv = (
        fgsm_attack(clean, epsilon)
        if method == "FGSM"
        else pgd_attack(clean, epsilon, pgd_alpha, pgd_steps)
    )

    adv_pred = predict_top3(adv)

    # Perturbation visualization: amplify 10x and re-center around gray
    # so small signed differences become visible.
    delta = adv - clean
    pert_vis = torch.clamp(delta * 10 + 0.5, 0, 1)

    flipped = orig_pred[0][0] != adv_pred[0][0]
    success = "ATTACK SUCCESS" if flipped else "Attack failed (same class)"

    metrics_text = "\n".join([
        f"**{success}**",
        "",
        "| Metric | Value |",
        "|---|---|",
        f"| Original Top-1 | {orig_pred[0][0]} ({orig_pred[0][1]:.1%}) |",
        f"| Adversarial Top-1 | {adv_pred[0][0]} ({adv_pred[0][1]:.1%}) |",
        f"| L-inf | {float(delta.abs().max()):.4f} |",
        f"| L2 | {float(delta.norm(2)):.4f} |",
        f"| Epsilon | {epsilon} |",
        f"| Method | {method} |",
    ])

    return to_uint8(clean), to_uint8(pert_vis), to_uint8(adv), metrics_text


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Layout: controls on the left (image, method, budget sliders), results on
# the right (three images side by side plus a markdown metrics panel).
with gr.Blocks(title="Adversarial Attack Demo") as demo:
    gr.Markdown(
        "# Adversarial Attack Demo | FGSM & PGD\n"
        "Upload an image and see how imperceptible perturbations fool a ResNet-18 classifier.\n"
        "*Course: 215 AI Safety*"
    )

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Image")
            method = gr.Radio(["FGSM", "PGD"], value="FGSM", label="Attack Method")
            epsilon = gr.Slider(0.0, 0.3, value=0.03, step=0.005, label="Epsilon (perturbation budget)")
            # PGD-only controls; toggled by `toggle_pgd` below. NOTE(review):
            # they start visible even though FGSM is the default method —
            # arguably these should start with visible=False to match.
            pgd_steps = gr.Slider(1, 40, value=10, step=1, label="PGD Steps", visible=True)
            pgd_alpha = gr.Slider(0.001, 0.05, value=0.007, step=0.001, label="PGD Step Size", visible=True)
            run_btn = gr.Button("Run Attack", variant="primary")

        with gr.Column(scale=2):
            with gr.Row():
                orig_out = gr.Image(label="Original Image")
                pert_out = gr.Image(label="Perturbation (10x amplified)")
                adv_out = gr.Image(label="Adversarial Image")
            metrics = gr.Markdown(label="Results")

    def toggle_pgd(m):
        # Show the PGD step/alpha sliders only when PGD is selected.
        visible = m == "PGD"
        return gr.update(visible=visible), gr.update(visible=visible)

    method.change(toggle_pgd, inputs=[method], outputs=[pgd_steps, pgd_alpha])

    run_btn.click(
        fn=attack,
        inputs=[input_image, method, epsilon, pgd_steps, pgd_alpha],
        outputs=[orig_out, pert_out, adv_out, metrics],
    )

    # Preset example rows: (image path, method, epsilon, pgd_steps, pgd_alpha).
    # NOTE(review): assumes an examples/ directory ships alongside this file —
    # verify it exists in the deployment.
    gr.Examples(
        examples=[
            ["examples/cat.jpg", "FGSM", 0.03, 10, 0.007],
            ["examples/dog.jpg", "PGD", 0.02, 20, 0.005],
            ["examples/car.jpg", "FGSM", 0.05, 10, 0.007],
        ],
        inputs=[input_image, method, epsilon, pgd_steps, pgd_alpha],
        label="Try these examples",
    )

if __name__ == "__main__":
    demo.launch()