# STAR / src / inference.py
# Repository page header (not part of the program):
#   ramagururadhakrishnan's picture
#   Added Source Folder (af59080, verified)
"""
Explainable AI (XAI) Inference for Nude Multi-Label Classification
==================================================================
This script performs inference using a trained Swin Transformer model for
multi-label classification of nude images. It also integrates Class Activation
Mapping (CAM) to provide visual explanations for the model's predictions.
Author: Ramaguru Radhakrishnan
Date: March 2025
"""
import json
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchcam.methods import SmoothGradCAMpp  # Explainability module

from model import SwinTransformerMultiLabel
# Number of output classes; must match the classifier head of the trained model.
NUM_CLASSES = 18

# Build the network with a classifier head sized for NUM_CLASSES.
model = SwinTransformerMultiLabel(num_classes=NUM_CLASSES)

# NOTE: this path is relative to the current working directory, not to this file.
checkpoint_path = "../models/multi_nude_detector.pth"

# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source (on PyTorch versions that
# support it, consider passing weights_only=True).
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model_dict = model.state_dict()

# Keep only the checkpoint entries whose name AND shape match the model;
# everything else is deliberately ignored (best-effort loading).
filtered_checkpoint = {
    k: v
    for k, v in checkpoint.items()
    if k in model_dict and v.shape == model_dict[k].shape
}

# Best-effort loading must not be silent: if entries were dropped, the model
# runs (partly) on its freshly initialised weights and predictions may be
# meaningless. Report how much was actually loaded.
skipped_keys = [k for k in checkpoint if k not in filtered_checkpoint]
missing_keys = [k for k in model_dict if k not in filtered_checkpoint]
if skipped_keys or missing_keys:
    print(
        f"⚠️ Loaded {len(filtered_checkpoint)}/{len(model_dict)} model tensors "
        f"from '{checkpoint_path}': {len(skipped_keys)} checkpoint entries "
        f"skipped, {len(missing_keys)} model tensors left at their initial values."
    )

# Overlay the matching tensors onto the model's own state and load the result.
model_dict.update(filtered_checkpoint)
model.load_state_dict(model_dict, strict=False)

# Set the model to evaluation mode
model.eval()
# Preprocessing pipeline applied to an input image before it is fed to the model.
_INPUT_SIZE = (224, 224)  # Spatial size every image is resized to
# NOTE(review): these look like the commonly used ImageNet channel statistics;
# confirm they match what the model was trained with.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

_preprocessing_steps = [
    transforms.Resize(_INPUT_SIZE),  # Resize to the fixed input size
    transforms.ToTensor(),  # PIL image -> tensor
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),  # Per-channel normalisation
]
transform = transforms.Compose(_preprocessing_steps)
# Build the label vocabulary: the sorted set of every tag found in any value
# of labels.json. The path is relative to the current working directory.
# The file is opened explicitly as UTF-8 so decoding does not depend on the
# platform's default locale encoding.
with open("../data/labels.json", "r", encoding="utf-8") as f:
    # presumably maps a sample identifier to its list of tags — verify against the dataset
    label_map = json.load(f)
classes = sorted({tag for tags in label_map.values() for tag in tags})

# NOTE(review): later code indexes `classes` by model output position, which
# assumes training used this same sorted ordering — confirm.

# Validate that the number of classes matches
if len(classes) != NUM_CLASSES:
    raise ValueError(f"❌ Mismatch: Model expects {NUM_CLASSES} classes, but labels.json has {len(classes)} labels!")
# Image to classify: the first command-line argument when one is given,
# otherwise the original hard-coded sample path (kept for backward compatibility).
img_path = (
    sys.argv[1]
    if len(sys.argv) > 1
    else "C:\\Users\\RamaguruRadhakrishna\\Videos\\STAR-main\\STAR-main\\data\\images\\442_.jpeg"
)

# Open inside a context manager so the file handle is released promptly;
# convert() returns an in-memory copy that outlives the handle.
with Image.open(img_path) as raw_image:
    image = raw_image.convert("RGB")  # Ensure RGB format
input_tensor = transform(image).unsqueeze(0)  # Add batch dimension

# Perform inference without tracking gradients
with torch.no_grad():
    output = model(input_tensor)  # Forward pass through model
print(f"🔹 Model Output Shape: {output.shape}")  # Debugging

# Get predicted labels (threshold = 0.5).
# NOTE(review): a 0.5 cut-off presumes the model already emits probabilities
# (e.g. applies a sigmoid internally) — confirm against model.py.
scores = output[0]
# Never index past the label list, so every stored index is valid for `classes`.
num_scored = min(len(classes), output.shape[1])
predicted_indices = [i for i in range(num_scored) if scores[i] > 0.5]  # Store indices
predicted_labels = [classes[i] for i in predicted_indices]

# Display predicted labels
print("✅ Predicted Tags:", predicted_labels)
# ===============================
# Explainable AI: CAM Visualization
# ===============================
# Debug aid: print the model structure, then every module name, so a valid
# target layer for the CAM extractor can be picked from the output.
print(model)
print("🔍 Model Architecture:\n")
# Collect the name -> module mapping once and reuse it for printing and validation.
available_layers = dict(model.named_modules())
for name in available_layers:
    print(name)

# Name of the module whose activations are used for the CAM. It must be one
# of the names printed above.
# NOTE(review): 'features.7.3' is an example value; confirm it exists in (and
# suits) the actual SwinTransformerMultiLabel structure.
valid_target_layer = "features.7.3"  # Modify based on your model structure

# Fail early, listing the available names, if the layer does not exist.
if valid_target_layer not in available_layers:
    raise ValueError(f"❌ Layer '{valid_target_layer}' not found in model. Choose from:\n{list(available_layers.keys())}")

# Initialize SmoothGradCAMpp with a valid layer
cam_extractor = SmoothGradCAMpp(model, target_layer=valid_target_layer)
print("✅ SmoothGradCAMpp initialized successfully!")
# Run a forward pass with the CAM extractor attached and autograd enabled, so
# the extractor has seen this input before CAMs are requested.
output = model(input_tensor)

if not predicted_indices:
    print("No labels exceeded the threshold; no heatmaps to display.")

# Generate one CAM heatmap per predicted label
for class_idx in predicted_indices:
    cam = cam_extractor(class_idx, output)
    # Recent torchcam releases return a list with one CAM per target layer,
    # older ones a single tensor; accept both and use the first (only) layer.
    if isinstance(cam, (list, tuple)):
        cam = cam[0]
    cam = cam.squeeze().detach().cpu().numpy()

    # Min-max normalise to [0, 1]. A constant map has zero range; emit zeros
    # instead of dividing by zero and filling the heatmap with NaNs.
    cam_min = np.min(cam)
    cam_range = np.max(cam) - cam_min
    if cam_range > 0:
        cam = (cam - cam_min) / cam_range
    else:
        cam = np.zeros_like(cam)

    # Resize CAM to match input image dimensions
    cam_resized = np.array(Image.fromarray(cam * 255).resize(image.size))

    # Overlay CAM on the original image
    plt.figure(figsize=(6, 6))
    plt.imshow(image)
    plt.imshow(cam_resized, cmap='jet', alpha=0.5)  # Heatmap overlay
    plt.axis("off")
    plt.title(f"Explainability Heatmap for '{classes[class_idx]}'")
    plt.show()