"""

Transformer-based Nude Classification Model Training Script



Author: Ramaguru Radhakrishnan

Description:

This script trains a multi-label classification model based on the Swin Transformer architecture 

to classify images into various adult content categories. The dataset and label information 

are provided as inputs, and the trained model is saved for later inference.



Usage:

    python train.py --data <path_to_dataset> --labels <path_to_labels.json> --save <path_to_save_model>



"""

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from dataset import NudeMultiLabelDataset
from model import SwinTransformerMultiLabel
import argparse
import os
import time

# Argument parser for command-line input
parser = argparse.ArgumentParser(description="Train a Transformer-based nude classification model")
parser.add_argument("--data", type=str, required=True, help="Path to dataset directory")
parser.add_argument("--labels", type=str, required=True, help="Path to labels.json file")
parser.add_argument("--save", type=str, required=True, help="Directory to save trained model")
args = parser.parse_args()

# Define image preprocessing transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to match model input size
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize image pixel values
])

# Load dataset using the custom dataset class (the dataset must exist before
# the DataLoader that wraps it can be constructed)
dataset = NudeMultiLabelDataset(args.data, args.labels, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)  # Create a data loader for batching

# Initialize the model and move it to the appropriate device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SwinTransformerMultiLabel(num_classes=len(dataset.classes)).to(device)

# Define loss function and optimizer
criterion = torch.nn.BCEWithLogitsLoss()  # Binary Cross Entropy Loss for multi-label classification
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)  # Adam optimizer with a learning rate of 0.0001

# Start measuring total training time
start_time = time.time()

# Training loop for multiple epochs
epochs = 50
model.train()  # Put the model in training mode (enables dropout, batch-norm updates)
for epoch in range(epochs):
    epoch_loss = 0.0
    epoch_start = time.time()  # Track time taken for each epoch

    for batch_idx, (imgs, labels) in enumerate(dataloader):
        imgs, labels = imgs.to(device), labels.to(device)  # Move data to the same device as the model

        optimizer.zero_grad()  # Reset gradients before backpropagation
        outputs = model(imgs)  # Forward pass: Get model predictions

        # Debugging: Print tensor shapes once (first batch of first epoch) to verify dimensions
        if epoch == 0 and batch_idx == 0:
            print(f"🔹 Outputs shape: {outputs.shape}")  # Expected: [batch_size, num_classes]
            print(f"🔹 Labels shape: {labels.shape}")  # Expected: [batch_size, num_classes]

        # Ensure output dimensions match expected shape
        if outputs.dim() > 2:
            outputs = outputs.view(outputs.size(0), -1)  # Flatten spatial dimensions if present

        # Compute loss and update model parameters
        loss = criterion(outputs, labels)
        loss.backward()  # Compute gradients
        optimizer.step()  # Update model weights

        epoch_loss += loss.item()  # Accumulate loss for this epoch

    epoch_end = time.time()  # Record epoch end time
    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss / len(dataloader)}, Time: {epoch_end - epoch_start:.2f} sec")

# End measuring total training time
end_time = time.time()
total_time = end_time - start_time

# Save trained model to the specified directory
os.makedirs(args.save, exist_ok=True)
save_path = os.path.join(args.save, "star.pth")
torch.save(model.state_dict(), save_path)
print(f"✅ Model saved at {save_path}")
print(f"⏳ Total Training Time: {total_time:.2f} seconds ({total_time/60:.2f} minutes)")