"""
Transformer-based Nude Classification Model Training Script
Author: Ramaguru Radhakrishnan
Description:
This script trains a multi-label classification model based on the Swin Transformer architecture
to classify images into various adult content categories. The dataset and label information
are provided as inputs, and the trained model is saved for later inference.
Usage:
python train.py --data <path_to_dataset> --labels <path_to_labels.json> --save <path_to_save_model>
"""
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from dataset import NudeMultiLabelDataset
from model import SwinTransformerMultiLabel
import argparse
import os
import time
# Command-line interface: where the data lives, where the labels are,
# and where the trained model should be written.
parser = argparse.ArgumentParser(description="Train a Transformer-based nude classification model")
for flag, helptext in (
    ("--data", "Path to dataset directory"),
    ("--labels", "Path to labels.json file"),
    ("--save", "Directory to save trained model"),
):
    parser.add_argument(flag, type=str, required=True, help=helptext)
args = parser.parse_args()
# Image preprocessing: resize to the 224x224 Swin input size, convert to a
# tensor, and normalize with ImageNet statistics (the values the backbone
# was pretrained with).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# BUG FIX: the dataset must be constructed BEFORE the DataLoader that wraps it.
# The original created the DataLoader first, referencing `dataset` one line
# before its definition, which raises NameError at import time.
dataset = NudeMultiLabelDataset(args.data, args.labels, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)  # Shuffled batches of 32
# Pick the compute device: GPU when CUDA is available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# One output logit per class; move the weights onto the chosen device.
model = SwinTransformerMultiLabel(num_classes=len(dataset.classes)).to(device)
# Multi-label objective: each class is an independent binary decision,
# so binary cross-entropy on raw logits is the appropriate loss.
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # Adam, learning rate 1e-4
# Wall-clock timer covering the whole training run.
start_time = time.time()
# Number of full passes over the dataset.
epochs = 50
for epoch in range(epochs):
    epoch_loss = 0.0
    epoch_start = time.time()  # Per-epoch wall-clock timing
    for imgs, labels in dataloader:
        # Move the batch onto the same device as the model.
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()   # Clear gradients from the previous step
        outputs = model(imgs)   # Forward pass -> raw logits
        # FIX: dropped the per-batch debug prints of outputs/labels shapes —
        # they emitted two lines per batch for all 50 epochs, flooding stdout
        # and slowing training.
        # If the backbone returns extra spatial dimensions, flatten them so the
        # loss sees a 2-D [batch_size, num_classes] tensor.
        if outputs.dim() > 2:
            outputs = outputs.view(outputs.size(0), -1)
        loss = criterion(outputs, labels)
        loss.backward()         # Backpropagate
        optimizer.step()        # Apply the Adam update
        epoch_loss += loss.item()  # Accumulate the detached scalar batch loss
    epoch_end = time.time()
    # Report the mean batch loss and how long the epoch took.
    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss / len(dataloader)}, Time: {epoch_end - epoch_start:.2f} sec")
# Total wall-clock time spent training.
total_time = time.time() - start_time
# Persist the trained weights as <save_dir>/star.pth, creating the
# directory if it does not exist yet.
os.makedirs(args.save, exist_ok=True)
checkpoint_path = os.path.join(args.save, "star.pth")
torch.save(model.state_dict(), checkpoint_path)
print(f"✅ Model saved at {args.save}/star.pth")
print(f"⏳ Total Training Time: {total_time:.2f} seconds ({total_time/60:.2f} minutes)")
|