# Last change by ratneshpasi03: "Update app.py" (commit e66f8a1, verified).
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from PIL import Image
import streamlit as st
import requests
from io import BytesIO
import os
import string
# Browser-tab title and wide layout; set_page_config must be the first
# Streamlit call of the script.
st.set_page_config(
    layout="wide",
    page_title="Adversarial Self-Driving Test",
)
# Heading plus the one-line instructions shown above the input controls.
st.title("Adversarial Self-Driving Car Tester")
intro_text = (
    "Upload a traffic sign, or select from default images to "
    "**confuse the AI model** into causing a virtual accident!"
)
st.markdown(intro_text)
# Load model + labels.
# Channel statistics the pretrained ImageNet weights were trained with.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"


@st.cache_resource
def _load_model():
    """Build the pretrained ResNet-18 classifier once per server process.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the network would be rebuilt on each click.

    The returned module takes RGB tensors scaled to [0, 1] (the output of
    ``transforms.ToTensor``) and applies ImageNet normalization itself. Keeping
    normalization inside the model lets the adversarial attack work, and clamp,
    directly in pixel space.

    Returns:
        torch.nn.Module in eval mode with frozen weights, producing logits.
    """
    backbone = torchvision.models.resnet18(
        weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
    )
    net = torch.nn.Sequential(
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        backbone,
    )
    net.eval()
    # Only gradients w.r.t. the input image are needed; freezing the weights
    # keeps parameter gradients from being computed on the shared cached model.
    net.requires_grad_(False)
    return net


@st.cache_data
def _load_labels(url):
    """Download the ImageNet class names, one per line, ordered by class index.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the download stalls.
    """
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    return response.text.strip().splitlines()


model = _load_model()
labels = _load_labels(LABELS_URL)
# Preprocessing for the classifier: stretch to the fixed 224x224 input size
# (aspect ratio is not preserved), then turn the PIL image into a float tensor
# scaled to [0, 1].
_resize_224 = transforms.Resize((224, 224))
_pil_to_tensor = transforms.ToTensor()
model_transform = transforms.Compose([_resize_224, _pil_to_tensor])
# Layout Selection: the user either uploads an image or picks one from the
# bundled `images` folder. `image` stays None until one of those happens.
layout = st.radio("Choose Input Method:", ["Upload Image", "Select Default Image"])
image = None
if layout == "Upload Image":
    uploaded_file = st.file_uploader(
        "📷 Upload a traffic sign image",
        type=["jpg", "jpeg", "png", "bmp", "webp"],
    )
    if uploaded_file:
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded Image", use_container_width=True)
        # An upload supersedes any gallery pick made earlier in this session.
        st.session_state.selected_default_image = None
elif layout == "Select Default Image":
    supported_exts = (".jpg", ".jpeg", ".png", ".bmp", ".webp")
    try:
        default_images = sorted(
            f for f in os.listdir("images") if f.lower().endswith(supported_exts)
        )
    except FileNotFoundError:
        # No bundled gallery: tell the user instead of crashing the app.
        default_images = []
        st.warning("No `images` folder found, so there are no default images to choose from.")
    cols = st.columns(4)
    for idx, img_file in enumerate(default_images):
        with cols[idx % 4]:
            img_path = os.path.join("images", img_file)
            # `with` releases the file handle once the thumbnail is built.
            with Image.open(img_path) as src:
                img = src.resize((200, 200))
            st.image(img, use_container_width=True)
            # Letters A-Z tag the first 26 images; later ones get their 1-based
            # position instead of raising IndexError.
            if idx < len(string.ascii_uppercase):
                tag = string.ascii_uppercase[idx]
            else:
                tag = str(idx + 1)
            if st.button(f"Select {tag}", key=f"select_{img_file}"):
                # st.button is only True on the rerun triggered by the click,
                # so the choice is persisted in session state.
                st.session_state.selected_default_image = img_path
    selected_path = st.session_state.get("selected_default_image")
    if selected_path:
        with Image.open(selected_path) as src:
            image = src.convert("RGB")
        st.markdown("#### Selected Default Image")
        st.image(image, caption=os.path.basename(selected_path), use_container_width=True)
# Epsilon slider: per-pixel size of the one-step perturbation, in the same
# [0, 1] units as the tensors produced by model_transform.
epsilon = st.slider("Perturbation Strength (epsilon)", 0.001, 0.1, 0.01, step=0.001)
# Target class selector. Each option is (ImageNet-1k class index, display name).
# ImageNet-1k has no stop-sign or speed-limit classes, so every option is a real
# ImageNet class whose display name matches its index in imagenet_classes.txt
# (919 "street sign", 920 "traffic light", 717 "pickup", 814 "speedboat").
target_class = st.selectbox(
    "Confuse the model into predicting:",
    options=[
        (919, "Street Sign"),
        (920, "Traffic Light"),
        (717, "Pickup Truck"),
        (814, "Speedboat (LOL why?)"),
    ],
    format_func=lambda x: f"{x[0]} - {x[1]}",
)
target_class_id, target_class_label = target_class
# --- PREDICTION LOGIC ---
def _targeted_fgsm(net, x, target_idx, eps):
    """Run one step of the targeted fast gradient sign method (FGSM).

    Args:
        net: classifier mapping a batch of pixel-space images to logits.
        x: batch of images in [0, 1] with ``requires_grad`` enabled.
        target_idx: class index the prediction should be pushed toward.
        eps: per-pixel L-infinity step size.

    Returns:
        (clean_logits, adv_x): detached logits for ``x`` and the detached
        adversarial batch, clamped back into [0, 1].
    """
    clean_logits = net(x)
    targets = torch.full((x.shape[0],), target_idx, dtype=torch.long)
    loss = F.cross_entropy(clean_logits, targets)
    # autograd.grad returns only d(loss)/dx and leaves parameter .grad untouched.
    (grad,) = torch.autograd.grad(loss, x)
    with torch.no_grad():
        # Targeted attack: *descend* the loss toward the target class.
        # Adding the gradient sign (untargeted FGSM) would push the image
        # away from the target instead.
        adv_x = torch.clamp(x - eps * grad.sign(), 0, 1)
    return clean_logits.detach(), adv_x


if image is not None:
    with st.spinner("🧠 Running AI Model & Generating Adversarial Image..."):
        # Save original size so the adversarial image can be displayed at it.
        original_size = image.size  # (width, height)
        # Prepare input: a (1, 3, 224, 224) pixel-space batch of one image.
        input_tensor = model_transform(image).unsqueeze(0)
        input_tensor.requires_grad = True
        # Original prediction; the same forward pass also drives the attack.
        orig_out, adv_tensor = _targeted_fgsm(model, input_tensor, target_class_id, epsilon)
        orig_pred_idx = orig_out.argmax().item()
        orig_pred = labels[orig_pred_idx]
        # Adversarial prediction on the attack output itself. Round-tripping it
        # through an 8-bit PIL image, resized to the original size and back to
        # 224x224, would quantize and resample away most of an eps-sized change.
        with torch.no_grad():
            adv_out = model(adv_tensor)
        adv_pred_idx = adv_out.argmax().item()
        adv_pred = labels[adv_pred_idx]
        # Resize perturbed tensor back to original image size (display only).
        adv_image_pil = transforms.ToPILImage()(adv_tensor.squeeze(0))
        adv_image_resized = adv_image_pil.resize(original_size)
    # Display Results
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption="Original Image", use_container_width=True)
        st.success(f"✅ **Original Prediction:** `{orig_pred}`")
    with col2:
        st.image(adv_image_resized, caption="Adversarial Image", use_container_width=True)
        if orig_pred != adv_pred:
            st.warning(f"⚠️ **Adversarial Prediction:** `{adv_pred}`")
        else:
            st.success(f"✅ **Adversarial Prediction:** `{adv_pred}`")
    if orig_pred != adv_pred:
        st.markdown("#### 🚨 Accident Report")
        st.error(f"The car thought a `{orig_pred}` was a `{adv_pred}`. That's a full-on self-driving fail!")