ramagururadhakrishnan committed on
Commit
af59080
·
verified ·
1 Parent(s): 1eaff06

Added Source Folder

Browse files
Files changed (4) hide show
  1. src/dataset.py +78 -0
  2. src/inference.py +121 -0
  3. src/model.py +56 -0
  4. src/train.py +91 -0
src/dataset.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from PIL import Image
6
+ import torchvision.transforms as transforms
7
+
8
+ class NudeMultiLabelDataset(Dataset):
9
+ def __init__(self, data_dir, label_file, transform=None):
10
+ self.data_dir = data_dir
11
+ self.transform = transform
12
+ self.label_file = label_file
13
+
14
+ # Load labels
15
+ with open(label_file, "r") as f:
16
+ self.labels = json.load(f)
17
+
18
+ self.image_paths = list(self.labels.keys())
19
+ self.classes = sorted(set(tag for tags in self.labels.values() for tag in tags))
20
+ self.class_to_idx = {tag: idx for idx, tag in enumerate(self.classes)}
21
+
22
+ # Print dataset info
23
+ print(f"📂 Dataset loaded from: {data_dir}")
24
+ print(f"📄 Labels loaded from: {label_file}")
25
+ print(f"🖼️ Total images: {len(self.image_paths)}")
26
+ print(f"🏷️ Unique labels: {len(self.classes)}")
27
+ print(f"🔹 Label-to-Index Mapping: {self.class_to_idx}")
28
+
29
+ # Print example data
30
+ if self.image_paths:
31
+ example_img, example_label = self.__getitem__(0)
32
+ print(f"✅ Example Image Shape: {example_img.shape}")
33
+ print(f"✅ Example Label: {example_label}")
34
+
35
+ def __len__(self):
36
+ return len(self.image_paths)
37
+
38
+ def __getitem__(self, idx):
39
+ img_name = self.image_paths[idx]
40
+ img_path = os.path.join(self.data_dir, img_name)
41
+ image = Image.open(img_path).convert("RGB")
42
+
43
+ # Convert labels to multi-hot encoding
44
+ labels = self.labels[img_name]
45
+ label_tensor = torch.zeros(len(self.classes))
46
+ for tag in labels:
47
+ if tag in self.class_to_idx:
48
+ label_tensor[self.class_to_idx[tag]] = 1 # Multi-label
49
+
50
+ if self.transform:
51
+ image = self.transform(image)
52
+
53
+ return image, label_tensor
54
+
# 🔹 Main function to test the dataset independently
if __name__ == "__main__":
    # Paths to the sample data.
    DATA_DIR = "../data/images"  # Change to actual path
    LABEL_FILE = "../data/labels.json"  # Change to actual path

    # Standard ImageNet preprocessing: resize, tensorize, normalize.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Build the dataset and wrap it in a small test loader.
    dataset = NudeMultiLabelDataset(DATA_DIR, LABEL_FILE, transform=transform)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    # Pull exactly one batch and report its shapes.
    images, labels = next(iter(dataloader))
    print(f"🖼️ Batch Image Shape: {images.shape}")  # Should be [batch_size, 3, 224, 224]
    print(f"🏷️ Batch Labels: {labels}")
src/inference.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
Explainable AI (XAI) Inference for Nude Multi-Label Classification
==================================================================

This script performs inference using a trained Swin Transformer model for
multi-label classification of nude images. It also integrates Class Activation
Mapping (CAM) to provide visual explanations for the model's predictions.

Author: Ramaguru Radhakrishnan
Date: March 2025
"""

import torch
import torchvision.transforms as transforms
from PIL import Image
import json
from model import SwinTransformerMultiLabel
from torchcam.methods import SmoothGradCAMpp  # Explainability module
import matplotlib.pyplot as plt
import numpy as np


# Define the number of output classes (should match the trained model)
NUM_CLASSES = 18

# Load the trained model with a correct classifier head
model = SwinTransformerMultiLabel(num_classes=NUM_CLASSES)

# Load model weights while ignoring mismatched layers
checkpoint_path = "../models/multi_nude_detector.pth"
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model_dict = model.state_dict()

# Keep only checkpoint entries whose name AND shape match the current model.
filtered_checkpoint = {
    k: v for k, v in checkpoint.items() if k in model_dict and v.shape == model_dict[k].shape
}
model_dict.update(filtered_checkpoint)
model.load_state_dict(model_dict, strict=False)

# Set the model to evaluation mode
model.eval()

# Define image preprocessing transformations (must match training)
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to model's input size
    transforms.ToTensor(),  # Convert to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize
])

# Load class labels from JSON file — same sorted-union ordering the dataset uses,
# so index i here matches output column i of the trained model.
with open("../data/labels.json", "r") as f:
    classes = sorted(set(tag for tags in json.load(f).values() for tag in tags))

# Validate that the number of classes matches
if len(classes) != NUM_CLASSES:
    raise ValueError(f"❌ Mismatch: Model expects {NUM_CLASSES} classes, but labels.json has {len(classes)} labels!")

# Load the test image
img_path = "C:\\Users\\RamaguruRadhakrishna\\Videos\\STAR-main\\STAR-main\\data\\images\\442_.jpeg"
image = Image.open(img_path).convert("RGB")  # Ensure RGB format
input_tensor = transform(image).unsqueeze(0)  # Add batch dimension

# Perform inference
with torch.no_grad():
    output = model(input_tensor)  # Forward pass through model
    print(f"🔹 Model Output Shape: {output.shape}")  # Debugging
    # BUG FIX: the model is trained with BCEWithLogitsLoss (see train.py), so
    # it emits raw logits. Map them to probabilities before thresholding —
    # comparing raw logits against 0.5 gives meaningless predictions.
    probs = torch.sigmoid(output)

# Get predicted labels (probability threshold = 0.5)
predicted_indices = [
    i for i in range(min(len(classes), probs.shape[1])) if probs[0][i] > 0.5
]
predicted_labels = [classes[i] for i in predicted_indices]

# Display predicted labels
print("✅ Predicted Tags:", predicted_labels)

# ===============================
# Explainable AI: CAM Visualization
# ===============================

# Print model structure to find the correct target layer
print(model)

# Print model architecture to identify available layers
print("🔍 Model Architecture:\n")
for name, module in model.named_modules():
    print(name)

# Choose a valid layer from the printed names.
# NOTE(review): the wrapper stores the backbone as self.model, so every
# submodule name is prefixed with "model." — the original "features.7.3"
# can never exist and always hit the ValueError below. Confirm the exact
# stage index against the printed architecture.
valid_target_layer = "model.features.7.1"  # Modify based on your model structure

# Verify if the layer exists in the model
if valid_target_layer not in dict(model.named_modules()):
    raise ValueError(f"❌ Layer '{valid_target_layer}' not found in model. Choose from:\n{list(dict(model.named_modules()).keys())}")

# Initialize SmoothGradCAMpp with a valid layer
cam_extractor = SmoothGradCAMpp(model, target_layer=valid_target_layer)

print("✅ SmoothGradCAMpp initialized successfully!")

# Re-run the forward pass WITH gradients enabled: SmoothGradCAMpp is a
# gradient-based method, so the earlier no_grad output cannot be reused.
output = model(input_tensor)

# Generate CAM heatmaps for each predicted label
for class_idx in predicted_indices:
    # BUG FIX: torchcam extractors return a list with one CAM per target
    # layer; the original called .squeeze() on the list itself and crashed.
    cam = cam_extractor(class_idx, output)[0]
    cam = cam.squeeze().cpu().numpy()
    # Normalize to [0, 1]; the epsilon guards against a flat (constant) map.
    cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam) + 1e-8)

    # Resize CAM to match input image dimensions.
    # BUG FIX: cast to uint8 first — Image.fromarray on a float array yields
    # mode 'F', which resizes and renders unreliably.
    cam_resized = np.array(Image.fromarray((cam * 255).astype(np.uint8)).resize(image.size))

    # Overlay CAM on the original image
    plt.figure(figsize=(6, 6))
    plt.imshow(image)
    plt.imshow(cam_resized, cmap='jet', alpha=0.5)  # Heatmap overlay
    plt.axis("off")
    plt.title(f"Explainability Heatmap for '{classes[class_idx]}'")
    plt.show()
src/model.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision.models import swin_t
4
+
5
+ class SwinTransformerMultiLabel(nn.Module):
6
+ def __init__(self, num_classes):
7
+ super(SwinTransformerMultiLabel, self).__init__()
8
+ self.model = swin_t(weights="IMAGENET1K_V1")
9
+
10
+ # Adjust final classification layer
11
+ in_features = self.model.head.in_features # Should be 768
12
+ self.model.head = nn.Linear(in_features, num_classes)
13
+
14
+ def forward(self, x):
15
+ x = self.model.features(x) # Extract features
16
+
17
+ print(f"🔹 Feature map shape before flattening: {x.shape}") # Debugging output
18
+
19
+ # ✅ Correctly apply GAP over height & width
20
+ x = x.mean(dim=[1, 2]) # Now shape is (batch_size, 768)
21
+ print(f"🔹 Feature shape after GAP: {x.shape}")
22
+
23
+ x = self.model.head(x) # Classification layer
24
+ return x
25
+
26
+
def main():
    """Smoke-test the model: push dummy batches through and print shapes."""
    # Define number of classes
    num_classes = 2

    # Build the model and switch to inference mode.
    model = SwinTransformerMultiLabel(num_classes)
    model.eval()

    # One forward pass with a dummy batch of 5 RGB 224x224 images.
    dummy_input = torch.randn(5, 3, 224, 224)
    output = model(dummy_input)
    print(f"✅ Model output shape: {output.shape}")  # Expected: (5, 2)

    # Show the classification head that replaced the pretrained one.
    print(f"✅ Model classification head: {model.model.head}")

    # Confirm the model handles several batch sizes.
    for batch_size in (1, 8, 16):
        output = model(torch.randn(batch_size, 3, 224, 224))
        print(f"✅ Batch Size {batch_size} -> Output Shape: {output.shape}")

if __name__ == "__main__":
    main()
src/train.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
Transformer-based Nude Classification Model Training Script

Author: Ramaguru Radhakrishnan
Description:
    This script trains a multi-label classification model based on the Swin Transformer architecture
    to classify images into various adult content categories. The dataset and label information
    are provided as inputs, and the trained model is saved for later inference.

Usage:
    python train.py --data <path_to_dataset> --labels <path_to_labels.json> --save <path_to_save_model>

"""

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from dataset import NudeMultiLabelDataset
from model import SwinTransformerMultiLabel
import argparse
import os
import time

# Argument parser for command-line input
parser = argparse.ArgumentParser(description="Train a Transformer-based nude classification model")
parser.add_argument("--data", type=str, required=True, help="Path to dataset directory")
parser.add_argument("--labels", type=str, required=True, help="Path to labels.json file")
parser.add_argument("--save", type=str, required=True, help="Directory to save trained model")
args = parser.parse_args()

# Define image preprocessing transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to match model input size
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize image pixel values
])

# BUG FIX: the dataset must be constructed BEFORE the DataLoader that wraps
# it — the original created the DataLoader first and crashed at startup with
# NameError: name 'dataset' is not defined.
dataset = NudeMultiLabelDataset(args.data, args.labels, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)  # Create a data loader for batching

# Initialize the model and move it to the appropriate device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SwinTransformerMultiLabel(num_classes=len(dataset.classes)).to(device)

# Define loss function and optimizer.
# BCEWithLogitsLoss expects raw logits from the model and multi-hot targets.
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Start measuring total training time
start_time = time.time()

# Training loop for multiple epochs
epochs = 50
for epoch in range(epochs):
    epoch_loss = 0.0
    epoch_start = time.time()  # Track time taken for each epoch

    for imgs, labels in dataloader:
        imgs, labels = imgs.to(device), labels.to(device)  # Move data to the model's device

        optimizer.zero_grad()  # Reset gradients before backpropagation
        outputs = model(imgs)  # Forward pass: Get model predictions

        # BUG FIX: removed the per-batch shape print()s — they flooded stdout
        # for every batch of every epoch.

        # Flatten stray spatial dimensions so the loss sees [batch, num_classes].
        if outputs.dim() > 2:
            outputs = outputs.view(outputs.size(0), -1)

        # Compute loss and update model parameters
        loss = criterion(outputs, labels)
        loss.backward()  # Compute gradients
        optimizer.step()  # Update model weights

        epoch_loss += loss.item()  # Accumulate loss for this epoch

    epoch_end = time.time()  # Record epoch end time
    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss / len(dataloader)}, Time: {epoch_end - epoch_start:.2f} sec")

# End measuring total training time
end_time = time.time()
total_time = end_time - start_time

# Save trained model to the specified directory
os.makedirs(args.save, exist_ok=True)
torch.save(model.state_dict(), os.path.join(args.save, "star.pth"))
print(f"✅ Model saved at {args.save}/star.pth")
print(f"⏳ Total Training Time: {total_time:.2f} seconds ({total_time/60:.2f} minutes)")