# Hugging Face Space upload residue (kept as a comment so the file parses):
# uploaded by jeffliulab — commit 216e171 (verified): "Fix README metadata + initial deploy"
"""
Adversarial Attack Demo — FGSM & PGD
Courses: 215 AI Safety ch1-ch2
"""
import json
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as T
import gradio as gr
from PIL import Image
# ---------------------------------------------------------------------------
# Model & preprocessing
# ---------------------------------------------------------------------------
# CPU-only inference; the demo does not assume a GPU is available.
device = torch.device("cpu")
# Pretrained ImageNet ResNet-18, frozen in eval mode (disables dropout /
# batch-norm statistic updates) — used purely as an attack target.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).eval().to(device)
# Channel statistics the pretrained weights were trained with.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Resize/crop to the 224x224 input ResNet expects. Normalization is kept OUT
# of this pipeline on purpose: the attacks operate on the [0, 1] pixel-space
# tensor and apply `normalize` separately before each forward pass.
preprocess = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
])
normalize = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
# Inverse of `normalize`: (x - (-m/s)) / (1/s) == x*s + m, i.e. maps a
# normalized tensor back to [0, 1] pixel space.
inv_normalize = T.Normalize(
    mean=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
    std=[1 / s for s in IMAGENET_STD],
)
# Load human-readable ImageNet class labels at import time; fall back to
# numeric class-id strings if the fetch fails so the demo still works offline.
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
try:
    import urllib.request
    with urllib.request.urlopen(LABELS_URL) as resp:
        LABELS = json.loads(resp.read().decode())
except Exception:
    LABELS = [str(i) for i in range(1000)]
def get_top3(logits: torch.Tensor, k: int = 3, labels=None):
    """Return the top-k (label, probability) pairs for a batch-of-1 logits tensor.

    Generalized from a hard-coded top-3: callers that pass no extra arguments
    get exactly the original behavior.

    Args:
        logits: Raw model output of shape (1, num_classes).
        k: Number of top predictions to return (default 3, matching the UI).
        labels: Optional class-name list; defaults to the module-level LABELS.

    Returns:
        List of (label, probability) tuples, highest probability first.
    """
    names = LABELS if labels is None else labels
    probs = F.softmax(logits, dim=1)[0]
    topk = torch.topk(probs, k)
    return [(names[idx], float(prob)) for prob, idx in zip(topk.values, topk.indices)]
# ---------------------------------------------------------------------------
# Attack implementations
# ---------------------------------------------------------------------------
def fgsm_attack(img_tensor: torch.Tensor, epsilon: float) -> torch.Tensor:
    """Single-step FGSM (untargeted), with a true pixel-space epsilon budget.

    Bug fixed: the original stepped `epsilon * grad.sign()` in *normalized*
    space and then de-normalized, so the actual pixel-space perturbation was
    epsilon scaled by each channel's std (~0.22x) — inconsistent with the
    "pixel space" comment and the UI's epsilon slider. We now differentiate
    with respect to the pixel-space tensor directly.

    Args:
        img_tensor: (3, 224, 224) image in [0, 1], as produced by `preprocess`.
        epsilon: L-inf perturbation budget in [0, 1] pixel units.

    Returns:
        Adversarial image tensor of the same shape, clamped to [0, 1].
    """
    pixel = img_tensor.clone().unsqueeze(0).to(device)
    pixel.requires_grad_(True)
    # Normalize inside the autograd graph so the gradient reaches pixel space.
    output = model(normalize(pixel))
    # Untargeted: ascend the loss against the model's own current prediction.
    loss = F.cross_entropy(output, output.argmax(1))
    loss.backward()
    adv = pixel + epsilon * pixel.grad.sign()
    return torch.clamp(adv.detach().squeeze(0), 0, 1)
def pgd_attack(
    img_tensor: torch.Tensor,
    epsilon: float,
    alpha: float,
    num_steps: int,
) -> torch.Tensor:
    """Multi-step PGD (untargeted) within an L-inf ball in pixel space.

    Two bugs fixed relative to the original:
    * the gradient step was taken in normalized space while the projection
      happened in pixel space, so `alpha` was silently rescaled per-channel
      by the ImageNet std;
    * the target label was recomputed from the *current* prediction on every
      step, so as soon as the class flipped the attack started ascending the
      new class's loss — pushing the image back. The target is now fixed to
      the clean-image prediction, the standard untargeted PGD formulation.

    Args:
        img_tensor: (3, 224, 224) image in [0, 1].
        epsilon: L-inf budget around the original image (pixel space).
        alpha: Per-step L-inf step size (pixel space).
        num_steps: Number of gradient-ascent steps.

    Returns:
        Adversarial image in [0, 1], within `epsilon` of the original.
    """
    orig = img_tensor.clone()
    # Fix the attack target once, from the model's clean prediction.
    with torch.no_grad():
        target = model(normalize(orig).unsqueeze(0).to(device)).argmax(1)
    perturbed = img_tensor.clone()
    for _ in range(num_steps):
        pixel = perturbed.clone().unsqueeze(0).to(device)
        pixel.requires_grad_(True)
        output = model(normalize(pixel))
        loss = F.cross_entropy(output, target)
        loss.backward()
        # Ascend in pixel space, then project onto the epsilon-ball around
        # the original image and back into the valid [0, 1] range.
        adv = pixel.detach().squeeze(0) + alpha * pixel.grad.sign().squeeze(0)
        perturbation = torch.clamp(adv - orig, -epsilon, epsilon)
        perturbed = torch.clamp(orig + perturbation, 0, 1)
    return perturbed
# ---------------------------------------------------------------------------
# Main function
# ---------------------------------------------------------------------------
def attack(
    image: Image.Image,
    method: str,
    epsilon: float,
    pgd_steps: int,
    pgd_alpha: float,
):
    """Run the selected attack and report the result.

    Returns a 4-tuple: (original image, 10x-amplified perturbation,
    adversarial image) as uint8 HWC numpy arrays, plus a markdown metrics
    table. When no image was uploaded, returns (None, None, None, "").
    """
    if image is None:
        return None, None, None, ""
    clean = preprocess(image.convert("RGB"))
    # Top-3 prediction on the unmodified image.
    with torch.no_grad():
        clean_top = get_top3(model(normalize(clean).unsqueeze(0)))
    # Craft the adversarial example with the chosen method.
    if method == "FGSM":
        adversarial = fgsm_attack(clean, epsilon)
    else:
        adversarial = pgd_attack(clean, epsilon, pgd_alpha, pgd_steps)
    # Top-3 prediction on the attacked image.
    with torch.no_grad():
        adv_top = get_top3(model(normalize(adversarial).unsqueeze(0)))

    # The raw perturbation is tiny; amplify 10x and center at mid-gray so it
    # is visible as an image.
    delta = adversarial - clean
    delta_vis = torch.clamp(delta * 10 + 0.5, 0, 1)

    def to_uint8(t: torch.Tensor) -> np.ndarray:
        # CHW float in [0, 1] -> HWC uint8 for display.
        return (t.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

    flipped = clean_top[0][0] != adv_top[0][0]
    status = "ATTACK SUCCESS" if flipped else "Attack failed (same class)"
    report = "\n".join([
        f"**{status}**",
        "",
        "| Metric | Value |",
        "|---|---|",
        f"| Original Top-1 | {clean_top[0][0]} ({clean_top[0][1]:.1%}) |",
        f"| Adversarial Top-1 | {adv_top[0][0]} ({adv_top[0][1]:.1%}) |",
        f"| L-inf | {float(delta.abs().max()):.4f} |",
        f"| L2 | {float(delta.norm(2)):.4f} |",
        f"| Epsilon | {epsilon} |",
        f"| Method | {method} |",
    ])
    return to_uint8(clean), to_uint8(delta_vis), to_uint8(adversarial), report
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks(title="Adversarial Attack Demo") as demo:
    gr.Markdown(
        "# Adversarial Attack Demo | FGSM & PGD\n"
        "Upload an image and see how imperceptible perturbations fool a ResNet-18 classifier.\n"
        "*Course: 215 AI Safety*"
    )
    with gr.Row():
        # Left column: input image and attack controls.
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Image")
            method = gr.Radio(["FGSM", "PGD"], value="FGSM", label="Attack Method")
            epsilon = gr.Slider(0.0, 0.3, value=0.03, step=0.005, label="Epsilon (perturbation budget)")
            # NOTE(review): both PGD sliders start visible even though the
            # default method is FGSM — toggle_pgd only hides them after the
            # radio changes. Consider visible=False initially; confirm intent.
            pgd_steps = gr.Slider(1, 40, value=10, step=1, label="PGD Steps", visible=True)
            pgd_alpha = gr.Slider(0.001, 0.05, value=0.007, step=0.001, label="PGD Step Size", visible=True)
            run_btn = gr.Button("Run Attack", variant="primary")
        # Right column: clean / perturbation / adversarial side by side.
        with gr.Column(scale=2):
            with gr.Row():
                orig_out = gr.Image(label="Original Image")
                pert_out = gr.Image(label="Perturbation (10x amplified)")
                adv_out = gr.Image(label="Adversarial Image")
            metrics = gr.Markdown(label="Results")

    def toggle_pgd(m):
        # Show the PGD-specific sliders only while PGD is the selected method.
        visible = m == "PGD"
        return gr.update(visible=visible), gr.update(visible=visible)

    method.change(toggle_pgd, inputs=[method], outputs=[pgd_steps, pgd_alpha])
    run_btn.click(
        fn=attack,
        inputs=[input_image, method, epsilon, pgd_steps, pgd_alpha],
        outputs=[orig_out, pert_out, adv_out, metrics],
    )
    # NOTE(review): assumes examples/cat.jpg, dog.jpg, car.jpg exist alongside
    # this file — verify they are shipped with the Space.
    gr.Examples(
        examples=[
            ["examples/cat.jpg", "FGSM", 0.03, 10, 0.007],
            ["examples/dog.jpg", "PGD", 0.02, 20, 0.005],
            ["examples/car.jpg", "FGSM", 0.05, 10, 0.007],
        ],
        inputs=[input_image, method, epsilon, pgd_steps, pgd_alpha],
        label="Try these examples",
    )

if __name__ == "__main__":
    demo.launch()