""" Explainable AI (XAI) Inference for Nude Multi-Label Classification ================================================================== This script performs inference using a trained Swin Transformer model for multi-label classification of nude images. It also integrates Class Activation Mapping (CAM) to provide visual explanations for the model's predictions. Author: Ramaguru Radhakrishnan Date: March 2025 """ import torch import torchvision.transforms as transforms from PIL import Image import json from model import SwinTransformerMultiLabel from torchcam.methods import SmoothGradCAMpp # Explainability module import matplotlib.pyplot as plt import numpy as np # Define the number of output classes (should match the trained model) NUM_CLASSES = 18 # Load the trained model with a correct classifier head model = SwinTransformerMultiLabel(num_classes=NUM_CLASSES) # Load model weights while ignoring mismatched layers checkpoint_path = "../models/multi_nude_detector.pth" checkpoint = torch.load(checkpoint_path, map_location="cpu") model_dict = model.state_dict() # Filter out layers that do not match filtered_checkpoint = { k: v for k, v in checkpoint.items() if k in model_dict and v.shape == model_dict[k].shape } model_dict.update(filtered_checkpoint) model.load_state_dict(model_dict, strict=False) # Set the model to evaluation mode model.eval() # Define image preprocessing transformations transform = transforms.Compose([ transforms.Resize((224, 224)), # Resize to model's input size transforms.ToTensor(), # Convert to tensor transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Normalize ]) # Load class labels from JSON file with open("../data/labels.json", "r") as f: classes = sorted(set(tag for tags in json.load(f).values() for tag in tags)) # Validate that the number of classes matches if len(classes) != NUM_CLASSES: raise ValueError(f"❌ Mismatch: Model expects {NUM_CLASSES} classes, but labels.json has {len(classes)} labels!") # Load the test image img_path = "C:\\Users\\RamaguruRadhakrishna\\Videos\\STAR-main\\STAR-main\\data\\images\\442_.jpeg" image = Image.open(img_path).convert("RGB") # Ensure RGB format input_tensor = transform(image).unsqueeze(0) # Add batch dimension # Perform inference with torch.no_grad(): output = model(input_tensor) # Forward pass through model print(f"🔹 Model Output Shape: {output.shape}") # Debugging # Get predicted labels (threshold = 0.5) predicted_labels = [ classes[i] for i in range(min(len(classes), output.shape[1])) if output[0][i] > 0.5 ] predicted_indices = [i for i in range(output.shape[1]) if output[0][i] > 0.5] # Store indices # Display predicted labels print("✅ Predicted Tags:", predicted_labels) # =============================== # Explainable AI: CAM Visualization # =============================== # Print model structure to find the correct target layer print(model) # Print model architecture to identify available layers print("🔍 Model Architecture:\n") for name, module in model.named_modules(): print(name) # Uncomment to see available layers # Choose a valid convolutional layer from printed names # Example: 'features.7.3' (Update this with an actual layer from print output) valid_target_layer = "features.7.3" # Modify based on your model structure # Verify if the layer exists in the model if valid_target_layer not in dict(model.named_modules()): raise ValueError(f"❌ Layer '{valid_target_layer}' not found in model. Choose from:\n{list(dict(model.named_modules()).keys())}") # Initialize SmoothGradCAMpp with a valid layer cam_extractor = SmoothGradCAMpp(model, target_layer=valid_target_layer) print("✅ SmoothGradCAMpp initialized successfully!") # Ensure model has processed the input before extracting CAM output = model(input_tensor) # Generate CAM heatmaps for each predicted label for class_idx in predicted_indices: cam = cam_extractor(class_idx, output) cam = cam.squeeze().cpu().numpy() cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam)) # Normalize # Resize CAM to match input image dimensions cam_resized = np.array(Image.fromarray(cam * 255).resize(image.size)) # Overlay CAM on the original image plt.figure(figsize=(6, 6)) plt.imshow(image) plt.imshow(cam_resized, cmap='jet', alpha=0.5) # Heatmap overlay plt.axis("off") plt.title(f"Explainability Heatmap for '{classes[class_idx]}'") plt.show()