File size: 4,519 Bytes
af59080
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
Explainable AI (XAI) Inference for Nude Multi-Label Classification
==================================================================

This script performs inference using a trained Swin Transformer model for 
multi-label classification of nude images. It also integrates Class Activation 
Mapping (CAM) to provide visual explanations for the model's predictions.

Author: Ramaguru Radhakrishnan
Date: March 2025
"""

import torch
import torchvision.transforms as transforms
from PIL import Image
import json
from model import SwinTransformerMultiLabel
from torchcam.methods import SmoothGradCAMpp  # Explainability module
import matplotlib.pyplot as plt
import numpy as np


# Number of output classes; must match the classifier head of the trained model.
NUM_CLASSES = 18

# Build the network with a classifier head sized for NUM_CLASSES.
model = SwinTransformerMultiLabel(num_classes=NUM_CLASSES)

# Read the checkpoint onto the CPU so inference does not require a GPU.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints that come from a trusted source.
checkpoint_path = "../models/multi_nude_detector.pth"
checkpoint = torch.load(checkpoint_path, map_location="cpu")

# Some training scripts wrap the weights, e.g. {"state_dict": {...}}. Without
# unwrapping, no key would match below and the model would silently keep its
# initial weights.
for wrapper_key in ("state_dict", "model_state_dict"):
    if isinstance(checkpoint, dict) and isinstance(checkpoint.get(wrapper_key), dict):
        checkpoint = checkpoint[wrapper_key]
        break

model_dict = model.state_dict()

# Keep only the entries whose name AND shape match the current model.
# getattr() guards against non-tensor entries that carry no `.shape`.
filtered_checkpoint = {
    k: v
    for k, v in checkpoint.items()
    if k in model_dict and getattr(v, "shape", None) == model_dict[k].shape
}

# Report what was dropped instead of ignoring it silently: a skipped layer
# keeps its initial weights and can make the predictions meaningless.
skipped_keys = [k for k in checkpoint if k not in filtered_checkpoint]
missing_keys = [k for k in model_dict if k not in filtered_checkpoint]
if skipped_keys:
    print(f"⚠️ Ignored {len(skipped_keys)} checkpoint entries (unknown name or shape mismatch): {skipped_keys}")
if missing_keys:
    print(f"⚠️ {len(missing_keys)} model tensors were not found in the checkpoint and keep their initial values: {missing_keys}")

model_dict.update(filtered_checkpoint)
model.load_state_dict(model_dict, strict=False)

# Switch to evaluation mode for inference.
model.eval()

# Image preprocessing applied before the forward pass.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to the model's input size
    transforms.ToTensor(),  # Convert the PIL image to a tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Per-channel normalization
])

# Load class labels from the JSON file. The file is read as a mapping whose
# values are iterables of tags; the class list is the sorted set of all tags.
# The encoding is explicit so non-ASCII tags decode the same on every platform
# (the default would otherwise be the locale encoding, e.g. cp1252 on Windows).
with open("../data/labels.json", "r", encoding="utf-8") as f:
    label_map = json.load(f)
classes = sorted({tag for tags in label_map.values() for tag in tags})

# Fail early if the label vocabulary does not line up with the model head;
# otherwise output indices would be mapped to the wrong tag names.
if len(classes) != NUM_CLASSES:
    raise ValueError(f"❌ Mismatch: Model expects {NUM_CLASSES} classes, but labels.json has {len(classes)} labels!")

# Load the test image.
# NOTE(review): this is a machine-specific absolute path; update it before
# running on another machine.
img_path = "C:\\Users\\RamaguruRadhakrishna\\Videos\\STAR-main\\STAR-main\\data\\images\\442_.jpeg"
image = Image.open(img_path).convert("RGB")  # Ensure RGB format
input_tensor = transform(image).unsqueeze(0)  # Add batch dimension

# Perform inference without tracking gradients.
with torch.no_grad():
    output = model(input_tensor)  # Forward pass through the model
    print(f"🔹 Model Output Shape: {output.shape}")  # Debugging

    # Only score positions that have a class name, so every stored index is
    # guaranteed to be a valid index into `classes`.
    num_scored = min(len(classes), output.shape[1])

    # Threshold at 0.5.
    # NOTE(review): assumes the model already outputs probabilities rather
    # than raw logits — confirm against the model definition.
    above_threshold = output[0][:num_scored] > 0.5
    predicted_indices = above_threshold.nonzero().flatten().tolist()  # Store indices

    # Derive the label names from the same index list so the two never disagree.
    predicted_labels = [classes[i] for i in predicted_indices]

# Display predicted labels
print("✅ Predicted Tags:", predicted_labels)

# ===============================
# Explainable AI: CAM Visualization
# ===============================

# Print the model structure to help pick the CAM target layer.
print(model)

# Collect the named submodules once; the mapping is reused for the listing
# below and for validating the chosen target layer.
available_layers = dict(model.named_modules())

# List every submodule name so a valid target layer can be chosen.
print("🔍 Model Architecture:\n")
for name in available_layers:
    print(name)

# Name of the layer whose activations drive the heatmap. It must be one of the
# names printed above.
valid_target_layer = "features.7.3"  # Modify based on your model structure

# Verify that the layer exists in the model before hooking it.
if valid_target_layer not in available_layers:
    raise ValueError(f"❌ Layer '{valid_target_layer}' not found in model. Choose from:\n{list(available_layers.keys())}")

# Initialize SmoothGradCAMpp on the validated layer.
cam_extractor = SmoothGradCAMpp(model, target_layer=valid_target_layer)

print("✅ SmoothGradCAMpp initialized successfully!")

# Run a forward pass after the extractor is attached (and with gradients
# enabled) so it has seen this input before CAMs are requested.
output = model(input_tensor)

if not predicted_indices:
    print("⚠️ No label exceeded the threshold; there are no heatmaps to display.")

# Generate a CAM heatmap for each predicted label.
for class_idx in predicted_indices:
    cam = cam_extractor(class_idx, output)

    # Recent torchcam releases return a list with one CAM per target layer;
    # older ones return the tensor directly. Only one layer is hooked here.
    if isinstance(cam, (list, tuple)):
        cam = cam[0]
    cam = cam.squeeze().detach().cpu().numpy()

    # Min-max normalize to [0, 1]. A flat (or NaN) map has no usable range,
    # so fall back to zeros instead of dividing by zero.
    cam_min = np.min(cam)
    cam_range = np.max(cam) - cam_min
    if cam_range > 0:
        cam = (cam - cam_min) / cam_range
    else:
        cam = np.zeros_like(cam)

    # Resize the CAM to the original image size. float32 maps to PIL's "F"
    # mode; other float widths are not accepted by Image.fromarray.
    cam_image = Image.fromarray((cam * 255).astype(np.float32))
    cam_resized = np.array(cam_image.resize(image.size))

    # Fall back to the numeric index if it has no entry in `classes`.
    class_name = classes[class_idx] if class_idx < len(classes) else f"class {class_idx}"

    # Overlay the CAM on the original image.
    plt.figure(figsize=(6, 6))
    plt.imshow(image)
    plt.imshow(cam_resized, cmap='jet', alpha=0.5)  # Heatmap overlay
    plt.axis("off")
    plt.title(f"Explainability Heatmap for '{class_name}'")
    plt.show()