Spaces:
Runtime error
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from PIL import Image
import streamlit as st
import requests
from io import BytesIO
import os
import string

# Page config
st.set_page_config(page_title="Adversarial Self-Driving Test", layout="wide")

# Title & Description
st.title("Adversarial Self-Driving Car Tester")
st.markdown("Upload a traffic sign, or select from default images to **confuse the AI model** into causing a virtual accident!")


@st.cache_resource
def _load_model():
    """Build the pretrained ImageNet ResNet-18 once per server process.

    Streamlit reruns the whole script on every widget interaction; without
    caching, the model weights were reconstructed on each rerun.
    """
    m = torchvision.models.resnet18(pretrained=True)
    m.eval()  # inference mode: disables dropout/batch-norm updates
    return m


@st.cache_data
def _load_labels(url):
    """Fetch the ImageNet class-name list once and cache it.

    The original fetched on every rerun with no timeout and no status
    check, so a slow or failed request hung or half-populated the app.
    """
    resp = requests.get(url, timeout=10)
    resp.raise_for_status()
    return resp.text.strip().split("\n")


model = _load_model()
LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
labels = _load_labels(LABELS_URL)
# Preprocessing pipeline for model input: PIL image -> 224x224 float tensor
# in [0, 1]. Deliberately no ImageNet mean/std normalization, so pixel-space
# clamping of the adversarial image stays straightforward.
model_transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)
# Layout Selection
layout = st.radio("Choose Input Method:", ["Upload Image", "Select Default Image"])

image = None

if layout == "Upload Image":
    # NOTE: original label emoji was mojibake ("π·"); restored to 📷.
    uploaded_file = st.file_uploader(
        "📷 Upload a traffic sign image",
        type=["jpg", "jpeg", "png", "bmp", "webp"],
    )
    if uploaded_file:
        image = Image.open(uploaded_file).convert('RGB')
        st.image(image, caption="Uploaded Image", use_container_width=True)
        # A fresh upload supersedes any previously selected default image.
        st.session_state.selected_default_image = None

elif layout == "Select Default Image":
    supported_exts = (".jpg", ".jpeg", ".png", ".bmp", ".webp")
    # Guard against a missing ./images directory: os.listdir would raise
    # FileNotFoundError and crash the whole app.
    if os.path.isdir("images"):
        default_images = sorted(
            f for f in os.listdir("images") if f.lower().endswith(supported_exts)
        )
    else:
        default_images = []
        st.warning("No 'images' directory found next to the app, so there are no default images to choose from.")

    # Thumbnails in a 4-wide grid, each with its own "Select" button.
    cols = st.columns(4)
    for idx, img_file in enumerate(default_images):
        with cols[idx % 4]:
            img_path = os.path.join("images", img_file)
            img = Image.open(img_path).resize((200, 200))
            st.image(img, use_container_width=True)
            # Buttons are labelled A..Z; assumes at most 26 default images
            # (string.ascii_uppercase[idx] raises IndexError beyond that).
            button_label = f"Select {string.ascii_uppercase[idx]}"
            if st.button(button_label, key=f"select_{img_file}"):
                st.session_state.selected_default_image = img_path

    # Selection persists across reruns via session_state.
    if "selected_default_image" in st.session_state and st.session_state.selected_default_image:
        selected_path = st.session_state.selected_default_image
        image = Image.open(selected_path).convert('RGB')
        st.markdown("#### Selected Default Image")
        st.image(image, caption=os.path.basename(selected_path), use_container_width=True)
# Attack strength: L-infinity budget for the FGSM perturbation.
epsilon = st.slider("Perturbation Strength (epsilon)", 0.001, 0.1, 0.01, step=0.001)

# Target selector: each option pairs an ImageNet class id with the
# display label shown in the dropdown.
_TARGET_OPTIONS = [
    (919, "Stop Sign"),
    (717, "Speed Limit 60"),
    (718, "Speed Limit 80"),
    (400, "Speedboat (LOL why?)"),
]
target_class = st.selectbox(
    "Confuse the model into predicting:",
    options=_TARGET_OPTIONS,
    format_func=lambda pair: f"{pair[0]} - {pair[1]}",
)
target_class_id, target_class_label = target_class
# --- PREDICTION LOGIC ---
if image is not None:
    with st.spinner("🧠 Running AI Model & Generating Adversarial Image..."):
        # Remember the source resolution so the adversarial image is shown
        # at the same size as the original.
        original_size = image.size  # (width, height)

        # Prepare input; gradients w.r.t. the pixels are needed for FGSM.
        input_tensor = model_transform(image).unsqueeze(0)
        input_tensor.requires_grad = True

        # One forward pass serves both the clean prediction and the attack
        # gradient (the original ran the model twice on the same input).
        output = model(input_tensor)
        orig_pred_idx = output.argmax().item()
        orig_pred = labels[orig_pred_idx]

        # Targeted FGSM: to push the prediction TOWARD target_class_id we
        # must DESCEND the target-class loss, i.e. x_adv = x - eps*sign(grad).
        # The original added the gradient sign, which is the untargeted
        # direction and moves the prediction AWAY from the chosen target.
        loss = F.cross_entropy(output, torch.tensor([target_class_id]))
        loss.backward()
        perturb = epsilon * input_tensor.grad.sign()
        adv_tensor = torch.clamp(input_tensor - perturb, 0, 1)

        # Detach before PIL conversion: adv_tensor is still attached to the
        # autograd graph, and converting a grad-requiring tensor raises a
        # RuntimeError. Then resize back to the original display size.
        adv_image_pil = transforms.ToPILImage()(adv_tensor.squeeze(0).detach())
        adv_image_resized = adv_image_pil.resize(original_size)

        # Classify the displayed adversarial image through the same
        # preprocessing, so what the user sees is what the model judges.
        adv_input_resized = model_transform(adv_image_resized).unsqueeze(0)
        with torch.no_grad():
            adv_out = model(adv_input_resized)
        adv_pred_idx = adv_out.argmax().item()
        adv_pred = labels[adv_pred_idx]

    # Display Results side by side: clean vs adversarial.
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption="Original Image", use_container_width=True)
        st.success(f"✅ **Original Prediction:** `{orig_pred}`")
    with col2:
        st.image(adv_image_resized, caption="Adversarial Image", use_container_width=True)
        if orig_pred != adv_pred:
            st.warning(f"⚠️ **Adversarial Prediction:** `{adv_pred}`")
        else:
            st.success(f"✅ **Adversarial Prediction:** `{adv_pred}`")

    if orig_pred != adv_pred:
        st.markdown("#### 🚨 Accident Report")
        st.error(f"The car thought a `{orig_pred}` was a `{adv_pred}`. That's a full-on self-driving fail!")