ratneshpasi03 committed on
Commit
e66f8a1
·
verified ·
1 Parent(s): 07147a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -96
app.py CHANGED
@@ -1,103 +1,130 @@
1
import streamlit as st
from PIL import Image
import requests
from io import BytesIO
import os
import torch
from transformers import CLIPProcessor, CLIPModel
import pandas as pd
import matplotlib.pyplot as plt

# Load model — cached so Streamlit reruns don't reload the weights every time.
@st.cache_resource
def load_clip_model():
    """Load and cache the CLIP model and its processor."""
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=False)
    return model, processor

# --- Main page content ---
st.header("👖 Clothing Bias in Scene Classification 👗")
st.markdown("""
This application explores biases in scene classification models related to clothing attributes.
It leverages the CLIP model to analyze and highlight these biases.
""")

input_method = st.selectbox("Select Input Method", ["Default Images", "Upload Image", "Use Image URL"], index=0)

image = None
if input_method == "Upload Image":
    uploaded_file = st.file_uploader("Upload your own image", type=["jpg", "png", "jpeg"])
    if uploaded_file:
        image = Image.open(uploaded_file).convert("RGB")
elif input_method == "Use Image URL":
    image_url = st.text_input("Paste an image URL")
    if image_url:
        try:
            # Bounded timeout + explicit status check so a dead URL can't hang
            # the app or silently hand PIL an HTML error page.
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except (requests.RequestException, OSError):
            # Narrow exceptions only: network failures and undecodable image data.
            st.error("Couldn't load image from the provided URL.")
elif input_method == "Default Images":
    st.subheader("🖼️ Select a Default Image")

    image_dir = "default_images/clothing_bias"
    default_images = sorted([f for f in os.listdir(image_dir) if f.lower().endswith((".jpg", ".jpeg", ".png"))])

    selected_image = None
    columns = st.columns(4)  # Display images in 4 columns

    for i, image_file in enumerate(default_images):
        col = columns[i % 4]
        img_path = os.path.join(image_dir, image_file)
        with col:
            st.image(img_path, caption=image_file, use_container_width=True)
            if st.button(f"Select {image_file}", key=image_file):
                selected_image = image_file

    # Store selected image using session state so selection persists
    if selected_image:
        st.session_state.selected_image = selected_image

    if "selected_image" in st.session_state:
        image_path = os.path.join(image_dir, st.session_state.selected_image)
        image = Image.open(image_path).convert("RGB")
        st.success(f"Selected: {st.session_state.selected_image}")

# Show the image if loaded
if image is not None:
    st.image(image, caption="Input Image", width=250)

# Prompt input
st.subheader("📝 Candidate Scene Labels")
default_prompts = ["a business executive", "a festival participant"]
prompts_text = st.text_area("Enter one label per line:", "\n".join(default_prompts))
labels = [label.strip() for label in prompts_text.strip().split("\n") if label.strip()]

# Analyze button
if st.button("🔍 Analyze Image"):
    if image is None:
        st.warning("⚠️ Please upload an image, paste a URL, or choose a default image before analysis.")
    elif not labels:
        # Guard: the CLIP processor cannot run with an empty text list.
        st.warning("⚠️ Please enter at least one candidate label before analysis.")
    else:
        model, processor = load_clip_model()
        with st.spinner("Analyzing the image, please wait..."):
            inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
            with torch.no_grad():
                outputs = model(**inputs)
            # Per-label probabilities for this single image.
            probs = outputs.logits_per_image.softmax(dim=1)[0]

        # Show probabilities
        st.subheader("📊 Classification Probabilities")
        data = {"Label": labels, "Probability": probs.numpy()}
        df = pd.DataFrame(data)
        df.index += 1  # Start index from 1
        st.table(df)
        st.write("**Most likely label**:", labels[probs.argmax().item()])
        st.write("\n")

        # Bar plot
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.barh(labels, probs.numpy(), color='skyblue')
        ax.set_xlim(0, 1)
        ax.set_xlabel("Probability")
        ax.set_title("Scene Classification")
        st.pyplot(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from PIL import Image
import streamlit as st
import requests
from io import BytesIO
import os
import string

# Page config
st.set_page_config(page_title="Adversarial Self-Driving Test", layout="wide")

# Title & Description
st.title("Adversarial Self-Driving Car Tester")
st.markdown("Upload a traffic sign, or select from default images to **confuse the AI model** into causing a virtual accident!")

LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"

# Load model + labels — cached so Streamlit reruns don't re-download the
# weights and label file on every widget interaction.
@st.cache_resource
def load_model_and_labels():
    """Return an eval-mode ImageNet ResNet-18 and its class-name list."""
    # NOTE(review): `pretrained=True` is deprecated in newer torchvision in
    # favour of `weights=...`; kept as-is for compatibility with the pinned version.
    model = torchvision.models.resnet18(pretrained=True)
    model.eval()
    # Bounded timeout + status check so a slow/unreachable host can't hang the app.
    resp = requests.get(LABELS_URL, timeout=10)
    resp.raise_for_status()
    return model, resp.text.strip().split("\n")

model, labels = load_model_and_labels()

# Base transform for model input
model_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Layout Selection
layout = st.radio("Choose Input Method:", ["Upload Image", "Select Default Image"])

image = None

if layout == "Upload Image":
    uploaded_file = st.file_uploader("📷 Upload a traffic sign image", type=["jpg", "jpeg", "png", "bmp", "webp"])
    if uploaded_file:
        image = Image.open(uploaded_file).convert('RGB')
        st.image(image, caption="Uploaded Image", use_container_width=True)
        # Clear any stale default-image selection so it can't shadow the upload.
        st.session_state.selected_default_image = None
elif layout == "Select Default Image":
    supported_exts = (".jpg", ".jpeg", ".png", ".bmp", ".webp")
    default_images = sorted([f for f in os.listdir("images") if f.lower().endswith(supported_exts)])
    cols = st.columns(4)

    for idx, img_file in enumerate(default_images):
        with cols[idx % 4]:
            img_path = os.path.join("images", img_file)
            img = Image.open(img_path).resize((200, 200))
            st.image(img, use_container_width=True)
            # Wrap at 26 so more than 26 images can't raise IndexError.
            button_label = f"Select {string.ascii_uppercase[idx % 26]}"
            if st.button(button_label, key=f"select_{img_file}"):
                st.session_state.selected_default_image = img_path

    if st.session_state.get("selected_default_image"):
        selected_path = st.session_state.selected_default_image
        image = Image.open(selected_path).convert('RGB')
        st.markdown("#### Selected Default Image")
        st.image(image, caption=os.path.basename(selected_path), use_container_width=True)

# Epsilon slider
epsilon = st.slider("Perturbation Strength (epsilon)", 0.001, 0.1, 0.01, step=0.001)

# Target class selector.
# NOTE(review): these (id, label) pairs should be verified against
# imagenet_classes.txt — ImageNet-1k has no "speed limit" classes, and
# index 919 is "street sign" there. TODO: confirm/fix the mapping.
target_class = st.selectbox(
    "Confuse the model into predicting:",
    options=[
        (919, "Stop Sign"),
        (717, "Speed Limit 60"),
        (718, "Speed Limit 80"),
        (400, "Speedboat (LOL why?)"),
    ],
    format_func=lambda x: f"{x[0]} - {x[1]}"
)
target_class_id = target_class[0]
target_class_label = target_class[1]

# --- PREDICTION LOGIC ---
if image is not None:
    with st.spinner("🧠 Running AI Model & Generating Adversarial Image..."):
        # Save original size so the adversarial image can be shown at full resolution.
        original_size = image.size  # (width, height)

        # Prepare input; gradients w.r.t. the pixels are needed for the attack.
        input_tensor = model_transform(image).unsqueeze(0)
        input_tensor.requires_grad = True

        # Original prediction (inference only, no gradient tracking)
        with torch.no_grad():
            orig_out = model(input_tensor)
        orig_pred_idx = orig_out.argmax().item()
        orig_pred = labels[orig_pred_idx]

        # Targeted FGSM attack: minimise the cross-entropy w.r.t. the *target*
        # class, i.e. step AGAINST the gradient sign. (Adding the gradient
        # sign — plain untargeted FGSM — would push the image *away* from the
        # chosen target, the opposite of what the UI promises.)
        output = model(input_tensor)
        loss = F.cross_entropy(output, torch.tensor([target_class_id]))
        loss.backward()
        perturb = epsilon * input_tensor.grad.sign()
        # detach(): ToPILImage cannot convert a tensor that still tracks gradients.
        adv_tensor = torch.clamp(input_tensor - perturb, 0, 1).detach()

        # Resize perturbed tensor back to original image size for display
        adv_image_tensor = adv_tensor.squeeze(0)
        adv_image_pil = transforms.ToPILImage()(adv_image_tensor)
        adv_image_resized = adv_image_pil.resize(original_size)

        # Adversarial prediction — re-run the identical preprocessing pipeline
        # on the resized adversarial image.
        adv_input_resized = model_transform(adv_image_resized).unsqueeze(0)
        with torch.no_grad():
            adv_out = model(adv_input_resized)
        adv_pred_idx = adv_out.argmax().item()
        adv_pred = labels[adv_pred_idx]

    # Display Results
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption="Original Image", use_container_width=True)
        st.success(f"✅ **Original Prediction:** `{orig_pred}`")

    with col2:
        st.image(adv_image_resized, caption="Adversarial Image", use_container_width=True)
        if orig_pred != adv_pred:
            st.warning(f"⚠️ **Adversarial Prediction:** `{adv_pred}`")
        else:
            st.success(f"✅ **Adversarial Prediction:** `{adv_pred}`")

    if orig_pred != adv_pred:
        st.markdown("#### 🚨 Accident Report")
        st.error(f"The car thought a `{orig_pred}` was a `{adv_pred}`. That's a full-on self-driving fail!")