"""
Transformer-based Nude Classification Model Training Script

Author: Ramaguru Radhakrishnan

Description:
    Trains a multi-label classification model based on the Swin Transformer
    architecture to classify images into various adult content categories.
    The dataset and label information are provided as inputs, and the trained
    model is saved for later inference.

Usage:
    python train.py --data <dataset_dir> --labels <labels.json> --save <out_dir>
"""
import argparse
import os
import time

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from dataset import NudeMultiLabelDataset
from model import SwinTransformerMultiLabel


def main() -> None:
    """Parse CLI arguments, train the classifier, and save its weights."""
    parser = argparse.ArgumentParser(
        description="Train a Transformer-based nude classification model"
    )
    parser.add_argument("--data", type=str, required=True, help="Path to dataset directory")
    parser.add_argument("--labels", type=str, required=True, help="Path to labels.json file")
    parser.add_argument("--save", type=str, required=True, help="Directory to save trained model")
    args = parser.parse_args()

    # Preprocessing: resize to the model's 224x224 input size, convert to a
    # tensor, and normalize with the ImageNet channel statistics.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # BUG FIX: the dataset must be constructed BEFORE the DataLoader that
    # wraps it — the original referenced `dataset` prior to assignment,
    # which raised NameError at startup.
    dataset = NudeMultiLabelDataset(args.data, args.labels, transform=transform)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Use the GPU when available, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SwinTransformerMultiLabel(num_classes=len(dataset.classes)).to(device)

    # BCE-with-logits treats each class as an independent sigmoid output,
    # which is the standard loss for multi-label classification.
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    start_time = time.time()  # Measure total training time

    epochs = 50
    for epoch in range(epochs):
        epoch_loss = 0.0
        epoch_start = time.time()  # Track time taken for each epoch
        for imgs, labels in dataloader:
            # Move the batch to the same device as the model.
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(imgs)
            # Flatten any trailing spatial dimensions so the output matches
            # the [batch_size, num_classes] shape BCEWithLogitsLoss expects.
            if outputs.dim() > 2:
                outputs = outputs.view(outputs.size(0), -1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        epoch_end = time.time()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss / len(dataloader)}, Time: {epoch_end - epoch_start:.2f} sec")

    total_time = time.time() - start_time

    # Persist the trained weights under the requested output directory.
    os.makedirs(args.save, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(args.save, "star.pth"))
    print(f"✅ Model saved at {args.save}/star.pth")
    print(f"⏳ Total Training Time: {total_time:.2f} seconds ({total_time/60:.2f} minutes)")


if __name__ == "__main__":
    main()