ramagururadhakrishnan committed on
Commit
e6abe34
·
verified ·
1 Parent(s): 0016469

- Blur the sensitive parts in the uploaded image
- Integrated Explainable AI

Files changed (1) hide show
  1. app.py +39 -9
app.py CHANGED
@@ -1,16 +1,20 @@
1
  import streamlit as st
2
  import torch
 
3
  from PIL import Image
4
  from torchvision import transforms
5
  import json
6
- from src.model import SwinTransformerMultiLabel # Import from src folder
 
 
 
7
 
8
  # Title and description
9
- st.title("STAR Multi-Label Classifier")
10
- st.write("Upload an image to get multi-label predictions.")
11
 
12
  # Load class labels from JSON
13
- label_file = "data/labels.json" # Path to labels file
14
  with open(label_file, "r") as f:
15
  label_data = json.load(f)
16
 
@@ -24,6 +28,9 @@ model = SwinTransformerMultiLabel(num_classes=NUM_CLASSES)
24
  model.load_state_dict(torch.load(model_path, map_location="cpu"))
25
  model.eval()
26
 
 
 
 
27
  # Define image preprocessing transformations
28
  transform = transforms.Compose([
29
  transforms.Resize((224, 224)),
@@ -31,18 +38,41 @@ transform = transforms.Compose([
31
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
32
  ])
33
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  # Upload image
35
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
36
  if uploaded_file is not None:
37
  image = Image.open(uploaded_file).convert("RGB")
38
- st.image(image, caption="Uploaded Image", use_column_width=True)
39
 
40
- # Predict
41
  img_tensor = transform(image).unsqueeze(0)
 
 
42
  with torch.no_grad():
43
  output = model(img_tensor)
44
- predicted_indices = [i for i in range(NUM_CLASSES) if output[0][i] > 0.5] # Threshold = 0.5
45
- predicted_labels = [class_labels[i] for i in predicted_indices] # Convert indices to labels
 
 
 
 
 
 
 
 
 
 
46
 
47
- # Display results
48
  st.write("✅ **Predicted Labels:**", ", ".join(predicted_labels) if predicted_labels else "No labels detected")
 
1
  import streamlit as st
2
  import torch
3
+ import numpy as np
4
  from PIL import Image
5
  from torchvision import transforms
6
  import json
7
+ import matplotlib.pyplot as plt
8
+ import cv2
9
+ from src.model import SwinTransformerMultiLabel
10
+ from torchcam.methods import SmoothGradCAMpp # Explainability Module
11
 
12
  # Title and description
13
+ st.title("STAR Multi-Label Classifier with Sensitive Content Blurring")
14
+ st.write("Upload an image to classify and see the blurred output.")
15
 
16
  # Load class labels from JSON
17
+ label_file = "data/labels.json"
18
  with open(label_file, "r") as f:
19
  label_data = json.load(f)
20
 
 
28
  model.load_state_dict(torch.load(model_path, map_location="cpu"))
29
  model.eval()
30
 
31
+ # Initialize CAM explainability
32
+ cam_extractor = SmoothGradCAMpp(model, target_layer="model.features.7.3") # Choose a valid layer
33
+
34
  # Define image preprocessing transformations
35
  transform = transforms.Compose([
36
  transforms.Resize((224, 224)),
 
38
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
39
  ])
40
 
41
+ # Function to blur sensitive areas
42
+ def blur_sensitive_parts(image, cam_mask, blur_intensity=25):
43
+ image = np.array(image) # Convert PIL to NumPy
44
+ heatmap = cv2.resize(cam_mask, (image.shape[1], image.shape[0])) # Resize CAM to match image
45
+ heatmap = (heatmap > 0.6).astype(np.uint8) # Threshold for blurring region
46
+
47
+ blurred = cv2.GaussianBlur(image, (51, 51), blur_intensity) # Apply blur
48
+ mask = np.repeat(heatmap[:, :, np.newaxis], 3, axis=2) # Create 3-channel mask
49
+ result = np.where(mask == 1, blurred, image) # Blend blurred areas
50
+
51
+ return Image.fromarray(result) # Convert back to PIL Image
52
+
53
  # Upload image
54
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
55
  if uploaded_file is not None:
56
  image = Image.open(uploaded_file).convert("RGB")
 
57
 
58
+ # Preprocess image for model
59
  img_tensor = transform(image).unsqueeze(0)
60
+
61
+ # Perform inference
62
  with torch.no_grad():
63
  output = model(img_tensor)
64
+ predicted_indices = [i for i in range(NUM_CLASSES) if output[0][i] > 0.5]
65
+ predicted_labels = [class_labels[i] for i in predicted_indices]
66
+
67
+ # Generate CAM heatmap
68
+ blurred_image = image # Default to original if no prediction
69
+ if predicted_indices:
70
+ cam = cam_extractor(predicted_indices[0], output).squeeze().cpu().numpy()
71
+ cam = (cam - np.min(cam)) / (np.max(cam)) # Normalize
72
+ blurred_image = blur_sensitive_parts(image, cam) # Apply blur
73
+
74
+ # Display the blurred image instead of the original
75
+ st.image(blurred_image, caption="Blurred Output Image", use_column_width=True)
76
 
77
+ # Display predictions
78
  st.write("✅ **Predicted Labels:**", ", ".join(predicted_labels) if predicted_labels else "No labels detected")