"""
Adversarial Attack Demo — FGSM & PGD
Courses: 215 AI Safety ch1-ch2
"""

import json
import urllib.request
import warnings

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image

# ---------------------------------------------------------------------------
# Model & preprocessing
# ---------------------------------------------------------------------------

device = torch.device("cpu")
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).eval().to(device)
# The attacks only need gradients w.r.t. the *input image*. Freezing the weights
# keeps every attack step from allocating/accumulating parameter .grad tensors.
model.requires_grad_(False)

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Pixel-space preprocessing: the result is a (3, 224, 224) float tensor in [0, 1].
# Normalization is applied separately so the attacks can operate in pixel space.
preprocess = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
])
normalize = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
# Inverse of `normalize` (maps a normalized tensor back to pixel space). The
# attacks below no longer need it, but it is kept as part of the module's API.
inv_normalize = T.Normalize(
    mean=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
    std=[1 / s for s in IMAGENET_STD],
)

# Load ImageNet class labels
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
NUM_CLASSES = 1000
_LABELS_TIMEOUT_S = 10


def _load_labels() -> list:
    """Fetch human-readable ImageNet class names, falling back to index strings.

    This is deliberately best-effort: a network failure or an unexpected payload
    produces a warning and the labels "0" .. "999", so the demo still starts
    offline instead of hanging or crashing at import time.
    """
    try:
        with urllib.request.urlopen(LABELS_URL, timeout=_LABELS_TIMEOUT_S) as resp:
            labels = json.loads(resp.read().decode("utf-8"))
        if not isinstance(labels, list) or len(labels) != NUM_CLASSES:
            raise ValueError("unexpected label payload")
    except Exception as err:  # best-effort: never block app start-up on labels
        warnings.warn(f"Could not load ImageNet labels ({err}); using class indices.")
        return [str(i) for i in range(NUM_CLASSES)]
    return [str(label) for label in labels]


LABELS = _load_labels()


def get_top3(logits: torch.Tensor):
    """Return the top-3 (label, probability) pairs for batch element 0 of *logits*.

    *logits* is indexed as (batch, classes); softmax is taken over dim 1.
    """
    probs = F.softmax(logits, dim=1)[0]
    top3 = torch.topk(probs, 3)
    return [(LABELS[int(idx)], float(prob)) for prob, idx in zip(top3.values, top3.indices)]


# ---------------------------------------------------------------------------
# Attack implementations
# ---------------------------------------------------------------------------

def _predict_class(pixels: torch.Tensor) -> torch.Tensor:
    """Return the model's top-1 class index (shape (1,)) for a pixel-space CHW image."""
    with torch.no_grad():
        logits = model(normalize(pixels).unsqueeze(0).to(device))
    return logits.argmax(1)


def _loss_grad_sign(pixels: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
    """Return sign(d CE(model(normalize(pixels)), label) / d pixels), shaped like *pixels*.

    The gradient is taken w.r.t. the *pixel-space* image (normalization is part of
    the differentiated graph), so attack steps and epsilon budgets are expressed in
    pixel units. `torch.autograd.grad` returns only the input gradient and leaves
    no graph attached to the result.
    """
    x = pixels.detach().clone().to(device).requires_grad_(True)
    loss = F.cross_entropy(model(normalize(x).unsqueeze(0)), label.to(device))
    (grad,) = torch.autograd.grad(loss, x)
    return grad.sign().to(pixels.device)


def fgsm_attack(img_tensor: torch.Tensor, epsilon: float) -> torch.Tensor:
    """Single-step FGSM (untargeted), performed in pixel space.

    Args:
        img_tensor: pixel-space CHW image, e.g. the output of `preprocess`
            (values in [0, 1]).
        epsilon: L-inf perturbation budget in pixel units (0..1 scale).

    Returns:
        A detached adversarial image clamped to [0, 1]. For an input in [0, 1],
        every pixel differs from *img_tensor* by at most *epsilon*.
    """
    clean = img_tensor.detach()
    # Untargeted: increase the loss of the model's own current prediction.
    label = _predict_class(clean)
    adv = clean + epsilon * _loss_grad_sign(clean, label)
    return torch.clamp(adv, 0, 1)


def pgd_attack(
    img_tensor: torch.Tensor,
    epsilon: float,
    alpha: float,
    num_steps: int,
) -> torch.Tensor:
    """Multi-step PGD (untargeted, L-inf, no random start), performed in pixel space.

    Args:
        img_tensor: pixel-space CHW image, e.g. the output of `preprocess`
            (values in [0, 1]).
        epsilon: L-inf radius of the projection ball, in pixel units.
        alpha: per-step size, in pixel units.
        num_steps: number of gradient steps; float values (e.g. from a UI
            slider) are truncated to int. Zero steps returns a copy of the input.

    Returns:
        A detached adversarial image in [0, 1] within *epsilon* (L-inf) of the input.
    """
    orig = img_tensor.detach()
    # Fix the label once. Re-deriving it from the current adversarial image each
    # step would, after the class flips, push away from the *new* class (often
    # back toward the original one) and undo the attack.
    label = _predict_class(orig)
    perturbed = orig.clone()
    for _ in range(int(num_steps)):
        stepped = perturbed + alpha * _loss_grad_sign(perturbed, label)
        # Project onto the epsilon-ball around the original, then the valid pixel range.
        delta = torch.clamp(stepped - orig, -epsilon, epsilon)
        perturbed = torch.clamp(orig + delta, 0, 1)
    return perturbed


# ---------------------------------------------------------------------------
# Main function
# ---------------------------------------------------------------------------

def _to_uint8(img: torch.Tensor) -> np.ndarray:
    """Convert a CHW float image in [0, 1] to an HWC uint8 array (rounded, not truncated)."""
    arr = img.detach().cpu().permute(1, 2, 0).numpy() * 255
    return np.clip(np.round(arr), 0, 255).astype(np.uint8)


def attack(
    image: Image.Image,
    method: str,
    epsilon: float,
    pgd_steps: int,
    pgd_alpha: float,
):
    """Run the selected attack on *image* and build the Gradio outputs.

    Args:
        image: uploaded PIL image, or None when nothing was uploaded.
        method: "FGSM" selects FGSM; any other value runs PGD.
        epsilon: L-inf budget in pixel units.
        pgd_steps: PGD iteration count (ignored for FGSM).
        pgd_alpha: PGD step size in pixel units (ignored for FGSM).

    Returns:
        (original image, 10x-amplified perturbation, adversarial image) as HWC
        uint8 arrays, plus a Markdown metrics table. Returns
        (None, None, None, "") when *image* is None.
    """
    if image is None:
        return None, None, None, ""

    img_tensor = preprocess(image.convert("RGB"))

    # Original prediction
    with torch.no_grad():
        orig_logits = model(normalize(img_tensor).unsqueeze(0).to(device))
    orig_pred = get_top3(orig_logits)

    # Attack
    if method == "FGSM":
        adv_tensor = fgsm_attack(img_tensor, epsilon)
    else:
        adv_tensor = pgd_attack(img_tensor, epsilon, pgd_alpha, int(pgd_steps))

    # Adversarial prediction
    with torch.no_grad():
        adv_logits = model(normalize(adv_tensor).unsqueeze(0).to(device))
    adv_pred = get_top3(adv_logits)

    # Perturbation visualization (amplified 10x, centred on mid-grey)
    diff = (adv_tensor - img_tensor).detach()
    perturbation = torch.clamp(diff * 10 + 0.5, 0, 1)

    orig_img = _to_uint8(img_tensor)
    pert_img = _to_uint8(perturbation)
    adv_img = _to_uint8(adv_tensor)

    # Metrics
    linf = float(diff.abs().max())
    l2 = float(diff.norm(2))
    # Compare class indices rather than label strings: label names are not
    # guaranteed to be unique across classes.
    flipped = int(orig_logits.argmax(1)) != int(adv_logits.argmax(1))
    success = "ATTACK SUCCESS" if flipped else "Attack failed (same class)"

    metrics_text = (
        f"**{success}**\n\n"
        f"| Metric | Value |\n|---|---|\n"
        f"| Original Top-1 | {orig_pred[0][0]} ({orig_pred[0][1]:.1%}) |\n"
        f"| Adversarial Top-1 | {adv_pred[0][0]} ({adv_pred[0][1]:.1%}) |\n"
        f"| L-inf | {linf:.4f} |\n"
        f"| L2 | {l2:.4f} |\n"
        f"| Epsilon | {epsilon} |\n"
        f"| Method | {method} |"
    )

    return orig_img, pert_img, adv_img, metrics_text


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------

DEFAULT_METHOD = "FGSM"

with gr.Blocks(title="Adversarial Attack Demo") as demo:
    gr.Markdown(
        "# Adversarial Attack Demo | FGSM & PGD\n"
        "Upload an image and see how imperceptible perturbations fool a ResNet-18 classifier.\n"
        "*Course: 215 AI Safety*"
    )

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Image")
            method = gr.Radio(["FGSM", "PGD"], value=DEFAULT_METHOD, label="Attack Method")
            epsilon = gr.Slider(0.0, 0.3, value=0.03, step=0.005, label="Epsilon (perturbation budget)")
            # PGD-only controls start hidden/visible to match the default method;
            # `toggle_pgd` keeps them in sync afterwards.
            pgd_steps = gr.Slider(
                1, 40, value=10, step=1, label="PGD Steps",
                visible=DEFAULT_METHOD == "PGD",
            )
            pgd_alpha = gr.Slider(
                0.001, 0.05, value=0.007, step=0.001, label="PGD Step Size",
                visible=DEFAULT_METHOD == "PGD",
            )
            run_btn = gr.Button("Run Attack", variant="primary")

        with gr.Column(scale=2):
            with gr.Row():
                orig_out = gr.Image(label="Original Image")
                pert_out = gr.Image(label="Perturbation (10x amplified)")
                adv_out = gr.Image(label="Adversarial Image")
            metrics = gr.Markdown(label="Results")

    def toggle_pgd(m):
        """Show the PGD-only sliders exactly when PGD is selected."""
        visible = m == "PGD"
        return gr.update(visible=visible), gr.update(visible=visible)

    method.change(toggle_pgd, inputs=[method], outputs=[pgd_steps, pgd_alpha])

    run_btn.click(
        fn=attack,
        inputs=[input_image, method, epsilon, pgd_steps, pgd_alpha],
        outputs=[orig_out, pert_out, adv_out, metrics],
    )

    gr.Examples(
        examples=[
            ["examples/cat.jpg", "FGSM", 0.03, 10, 0.007],
            ["examples/dog.jpg", "PGD", 0.02, 20, 0.005],
            ["examples/car.jpg", "FGSM", 0.05, 10, 0.007],
        ],
        inputs=[input_image, method, epsilon, pgd_steps, pgd_alpha],
        label="Try these examples",
    )

if __name__ == "__main__":
    demo.launch()