ramagururadhakrishnan commited on
Commit
8623a7a
·
verified ·
1 Parent(s): cadc21b

- Error in XAI

Files changed (1) hide show
  1. app.py +17 -4
app.py CHANGED
@@ -7,7 +7,7 @@ import json
7
  import matplotlib.pyplot as plt
8
  import cv2
9
  from src.model import SwinTransformerMultiLabel
10
- from torchcam.methods import SmoothGradCAMpp # Explainability Module
11
 
12
  # Title and description
13
  st.title("STAR Multi-Label Classifier with Sensitive Content Blurring")
@@ -28,8 +28,21 @@ model = SwinTransformerMultiLabel(num_classes=NUM_CLASSES)
28
  model.load_state_dict(torch.load(model_path, map_location="cpu"))
29
  model.eval()
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  # Initialize CAM explainability
32
- cam_extractor = SmoothGradCAMpp(model, target_layer="model.features.7.3") # Choose a valid layer
33
 
34
  # Define image preprocessing transformations
35
  transform = transforms.Compose([
@@ -64,14 +77,14 @@ if uploaded_file is not None:
64
  predicted_indices = [i for i in range(NUM_CLASSES) if output[0][i] > 0.5]
65
  predicted_labels = [class_labels[i] for i in predicted_indices]
66
 
67
- # Generate CAM heatmap
68
  blurred_image = image # Default to original if no prediction
69
  if predicted_indices:
70
  cam = cam_extractor(predicted_indices[0], output).squeeze().cpu().numpy()
71
  cam = (cam - np.min(cam)) / (np.max(cam)) # Normalize
72
  blurred_image = blur_sensitive_parts(image, cam) # Apply blur
73
 
74
- # Display the blurred image instead of the original
75
  st.image(blurred_image, caption="Blurred Output Image", use_column_width=True)
76
 
77
  # Display predictions
 
7
  import matplotlib.pyplot as plt
8
  import cv2
9
  from src.model import SwinTransformerMultiLabel
10
+ from torchcam.methods import SmoothGradCAMpp
11
 
12
  # Title and description
13
  st.title("STAR Multi-Label Classifier with Sensitive Content Blurring")
 
28
  model.load_state_dict(torch.load(model_path, map_location="cpu"))
29
  model.eval()
30
 
31
+ # Print model architecture to find a valid layer for CAM
32
+ st.write("🔍 Finding a valid layer for CAM...")
33
+ valid_layer = None
34
+ for name, module in model.named_modules():
35
+ if "conv" in name.lower() or "features" in name: # Look for convolutional layers
36
+ valid_layer = name
37
+
38
+ # Ensure a valid layer was found
39
+ if not valid_layer:
40
+ raise ValueError("❌ No valid convolutional layer found for CAM!")
41
+
42
+ st.write(f"✅ Using layer '{valid_layer}' for explainability.")
43
+
44
  # Initialize CAM explainability
45
+ cam_extractor = SmoothGradCAMpp(model, target_layer=valid_layer) # Use detected layer
46
 
47
  # Define image preprocessing transformations
48
  transform = transforms.Compose([
 
77
  predicted_indices = [i for i in range(NUM_CLASSES) if output[0][i] > 0.5]
78
  predicted_labels = [class_labels[i] for i in predicted_indices]
79
 
80
+ # Generate CAM heatmap and blur the image
81
  blurred_image = image # Default to original if no prediction
82
  if predicted_indices:
83
  cam = cam_extractor(predicted_indices[0], output).squeeze().cpu().numpy()
84
  cam = (cam - np.min(cam)) / (np.max(cam)) # Normalize
85
  blurred_image = blur_sensitive_parts(image, cam) # Apply blur
86
 
87
+ # Display only the blurred image
88
  st.image(blurred_image, caption="Blurred Output Image", use_column_width=True)
89
 
90
  # Display predictions