Spaces:
Sleeping
Sleeping
| """ | |
| Adversarial Attack Demo — FGSM & PGD | |
| Courses: 215 AI Safety ch1-ch2 | |
| """ | |
| import json | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision.models as models | |
| import torchvision.transforms as T | |
| import gradio as gr | |
| from PIL import Image | |
# ---------------------------------------------------------------------------
# Model & preprocessing
# ---------------------------------------------------------------------------
device = torch.device("cpu")

# Frozen pretrained ImageNet classifier used as the attack target.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).eval().to(device)

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Resize/crop to the network's expected 224x224 input; pixels stay in [0, 1].
preprocess = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
])
normalize = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
# Inverse of `normalize`: maps a normalized tensor back to pixel space,
# since Normalize(x) = (x - mean) / std  =>  x = (x_norm - (-mean/std)) / (1/std).
inv_normalize = T.Normalize(
    mean=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
    std=[1 / s for s in IMAGENET_STD],
)
# Load ImageNet class labels (index -> human-readable name).
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
try:
    import urllib.request

    # Fix: add a timeout — without one a dead/slow network hangs app
    # start-up indefinitely instead of falling back to numeric labels.
    with urllib.request.urlopen(LABELS_URL, timeout=10) as resp:
        LABELS = json.loads(resp.read().decode())
except Exception:
    # Best-effort fallback: class indices as strings, so the app still runs.
    LABELS = [str(i) for i in range(1000)]
def get_top3(logits: torch.Tensor):
    """Return the top-3 ``(label, probability)`` pairs for a (1, 1000) logit tensor."""
    # Softmax over the class dimension, then drop the batch axis.
    probabilities = F.softmax(logits, dim=1)[0]
    values, indices = torch.topk(probabilities, 3)
    return [(LABELS[i], float(p)) for p, i in zip(values, indices)]
# ---------------------------------------------------------------------------
# Attack implementations
# ---------------------------------------------------------------------------
def fgsm_attack(img_tensor: torch.Tensor, epsilon: float) -> torch.Tensor:
    """Single-step FGSM (untargeted).

    Args:
        img_tensor: Pixel-space image tensor, shape (3, H, W), values in [0, 1].
        epsilon: L-infinity perturbation budget measured in *pixel* space.

    Returns:
        The perturbed pixel-space image, clamped to [0, 1].
    """
    inp = normalize(img_tensor.clone()).unsqueeze(0).to(device)
    inp.requires_grad = True
    output = model(inp)
    # Untargeted: ascend the loss w.r.t. the model's own current prediction.
    loss = F.cross_entropy(output, output.argmax(1))
    loss.backward()
    # Fix: take the epsilon step directly in pixel space. The old code added
    # epsilon*sign(grad) to the *normalized* tensor, which makes the actual
    # pixel-space perturbation epsilon*std per channel — breaking the stated
    # budget and disagreeing with pgd_attack's pixel-space projection. The
    # chain rule only rescales the gradient by the positive constants 1/std,
    # so sign(grad wrt pixels) == sign(grad wrt normalized input).
    grad_sign = inp.grad.sign().squeeze(0)
    perturbed_pixel = torch.clamp(img_tensor + epsilon * grad_sign, 0, 1)
    return perturbed_pixel
def pgd_attack(
    img_tensor: torch.Tensor,
    epsilon: float,
    alpha: float,
    num_steps: int,
) -> torch.Tensor:
    """Multi-step PGD (untargeted).

    Args:
        img_tensor: Pixel-space image tensor, shape (3, H, W), values in [0, 1].
        epsilon: L-infinity budget of the attack, in pixel space.
        alpha: Per-iteration step size, in pixel space.
        num_steps: Number of gradient-ascent iterations.

    Returns:
        The perturbed pixel-space image, projected into the epsilon-ball
        around the original and clamped to [0, 1].
    """
    orig = img_tensor.clone()
    # Fix: pin the attack target to the *clean* prediction once. The old
    # code recomputed output.argmax(1) every step ("label drift"), so as
    # soon as the attack flipped the class, the loss started chasing the
    # fooled prediction and pushed the image back — weakening the attack.
    with torch.no_grad():
        target = model(normalize(orig.clone()).unsqueeze(0).to(device)).argmax(1)
    perturbed = img_tensor.clone()
    for _ in range(num_steps):
        inp = normalize(perturbed.clone()).unsqueeze(0).to(device)
        inp.requires_grad = True
        loss = F.cross_entropy(model(inp), target)
        loss.backward()
        # Fix: step in pixel space so alpha is a true pixel-space step size
        # (the old code stepped in normalized space, scaling alpha by 1/std
        # per channel). Sign is invariant to that positive rescaling.
        adv_pixel = perturbed + alpha * inp.grad.sign().squeeze(0)
        # Project onto the epsilon-ball around the original (pixel space).
        perturbation = torch.clamp(adv_pixel - orig, -epsilon, epsilon)
        perturbed = torch.clamp(orig + perturbation, 0, 1).detach()
    return perturbed
# ---------------------------------------------------------------------------
# Main function
# ---------------------------------------------------------------------------
def attack(
    image: Image.Image,
    method: str,
    epsilon: float,
    pgd_steps: int,
    pgd_alpha: float,
):
    """Run the chosen attack on *image* and build the Gradio outputs.

    Returns a 4-tuple: original image (uint8 HWC), 10x-amplified
    perturbation image, adversarial image, and a markdown metrics table.
    Returns ``(None, None, None, "")`` when no image is supplied.
    """
    if image is None:
        return None, None, None, ""

    clean = preprocess(image.convert("RGB"))

    def _top3(tensor):
        # Top-3 prediction for a pixel-space (3, H, W) tensor.
        with torch.no_grad():
            return get_top3(model(normalize(tensor).unsqueeze(0)))

    def _to_uint8(tensor):
        # (3, H, W) float in [0, 1] -> (H, W, 3) uint8 for display.
        return (tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

    clean_top3 = _top3(clean)

    if method == "FGSM":
        adv = fgsm_attack(clean, epsilon)
    else:
        adv = pgd_attack(clean, epsilon, pgd_alpha, pgd_steps)

    adv_top3 = _top3(adv)

    # Residual between adversarial and clean image; amplified and centered
    # at gray so tiny perturbations become visible.
    delta = adv - clean
    delta_vis = torch.clamp(delta * 10 + 0.5, 0, 1)

    fooled = clean_top3[0][0] != adv_top3[0][0]
    status = "ATTACK SUCCESS" if fooled else "Attack failed (same class)"
    metrics_text = (
        f"**{status}**\n\n"
        f"| Metric | Value |\n|---|---|\n"
        f"| Original Top-1 | {clean_top3[0][0]} ({clean_top3[0][1]:.1%}) |\n"
        f"| Adversarial Top-1 | {adv_top3[0][0]} ({adv_top3[0][1]:.1%}) |\n"
        f"| L-inf | {float(delta.abs().max()):.4f} |\n"
        f"| L2 | {float(delta.norm(2)):.4f} |\n"
        f"| Epsilon | {epsilon} |\n"
        f"| Method | {method} |"
    )
    return _to_uint8(clean), _to_uint8(delta_vis), _to_uint8(adv), metrics_text
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks(title="Adversarial Attack Demo") as demo:
    gr.Markdown(
        "# Adversarial Attack Demo | FGSM & PGD\n"
        "Upload an image and see how imperceptible perturbations fool a ResNet-18 classifier.\n"
        "*Course: 215 AI Safety*"
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Image")
            method = gr.Radio(["FGSM", "PGD"], value="FGSM", label="Attack Method")
            epsilon = gr.Slider(0.0, 0.3, value=0.03, step=0.005, label="Epsilon (perturbation budget)")
            # Fix: start hidden — the default method is FGSM, and toggle_pgd
            # only reveals these once the user switches to PGD; previously
            # they were shown on load despite FGSM being selected.
            pgd_steps = gr.Slider(1, 40, value=10, step=1, label="PGD Steps", visible=False)
            pgd_alpha = gr.Slider(0.001, 0.05, value=0.007, step=0.001, label="PGD Step Size", visible=False)
            run_btn = gr.Button("Run Attack", variant="primary")
        with gr.Column(scale=2):
            with gr.Row():
                orig_out = gr.Image(label="Original Image")
                pert_out = gr.Image(label="Perturbation (10x amplified)")
                adv_out = gr.Image(label="Adversarial Image")
            metrics = gr.Markdown(label="Results")

    def toggle_pgd(m):
        """Show the PGD hyper-parameter sliders only when PGD is selected."""
        visible = m == "PGD"
        return gr.update(visible=visible), gr.update(visible=visible)

    method.change(toggle_pgd, inputs=[method], outputs=[pgd_steps, pgd_alpha])
    run_btn.click(
        fn=attack,
        inputs=[input_image, method, epsilon, pgd_steps, pgd_alpha],
        outputs=[orig_out, pert_out, adv_out, metrics],
    )
    gr.Examples(
        examples=[
            ["examples/cat.jpg", "FGSM", 0.03, 10, 0.007],
            ["examples/dog.jpg", "PGD", 0.02, 20, 0.005],
            ["examples/car.jpg", "FGSM", 0.05, 10, 0.007],
        ],
        inputs=[input_image, method, epsilon, pgd_steps, pgd_alpha],
        label="Try these examples",
    )

if __name__ == "__main__":
    demo.launch()