| """
|
| Transformer-based Nude Classification Model Training Script
|
|
|
| Author: Ramaguru Radhakrishnan
|
| Description:
|
| This script trains a multi-label classification model based on the Swin Transformer architecture
|
| to classify images into various adult content categories. The dataset and label information
|
| are provided as inputs, and the trained model is saved for later inference.
|
|
|
| Usage:
|
| python train.py --data <path_to_dataset> --labels <path_to_labels.json> --save <path_to_save_model>
|
|
|
| """
|
|
|
| import torch
|
| import torchvision.transforms as transforms
|
| from torch.utils.data import DataLoader
|
| from dataset import NudeMultiLabelDataset
|
| from model import SwinTransformerMultiLabel
|
| import argparse
|
| import os
|
| import time
|
|
|
|
|
# Command-line interface: all three paths are mandatory.
parser = argparse.ArgumentParser(
    description="Train a Transformer-based nude classification model"
)
for flag, help_text in (
    ("--data", "Path to dataset directory"),
    ("--labels", "Path to labels.json file"),
    ("--save", "Directory to save trained model"),
):
    parser.add_argument(flag, type=str, required=True, help=help_text)
args = parser.parse_args()
|
|
|
|
|
# Preprocessing pipeline: resize to the 224x224 input the Swin backbone
# expects, convert to a float tensor, then normalize with the standard
# ImageNet channel statistics.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
|
|
|
|
|
# Build the dataset first, THEN wrap it in a DataLoader.
# Bug fix: the original constructed the DataLoader before `dataset` was
# assigned, which raises NameError as soon as the script runs.
dataset = NudeMultiLabelDataset(args.data, args.labels, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
|
|
|
|
|
# Prefer the GPU when one is available; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# One output logit per class; the sigmoid is folded into the loss below.
model = SwinTransformerMultiLabel(num_classes=len(dataset.classes)).to(device)

# Multi-label setup: binary cross-entropy with logits (an independent
# sigmoid per class) optimized with Adam at a small learning rate.
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
|
|
|
|
|
start_time = time.time()

# Training loop: 50 full passes over the dataset, reporting the mean batch
# loss and wall-clock time per epoch.
epochs = 50
model.train()  # fix: make sure dropout/batch-norm layers are in training mode
for epoch in range(epochs):
    epoch_loss = 0.0
    epoch_start = time.time()

    for imgs, labels in dataloader:
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)

        # Some backbones emit extra spatial dimensions; flatten anything
        # beyond (batch, features) so the logits align with the labels.
        # (Removed the leftover per-batch shape debug prints — they flooded
        # stdout and slowed every iteration.)
        if outputs.dim() > 2:
            outputs = outputs.view(outputs.size(0), -1)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    epoch_end = time.time()
    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss / len(dataloader)}, Time: {epoch_end - epoch_start:.2f} sec")
|
|
|
|
|
end_time = time.time()
total_time = end_time - start_time

# Persist the trained weights, creating the target directory if needed.
# Fix: build the output path once with os.path.join and report exactly that
# path — the original message hard-coded "/" and could disagree with the
# actual file location (e.g. on Windows or with a trailing slash in --save).
os.makedirs(args.save, exist_ok=True)
model_path = os.path.join(args.save, "star.pth")
torch.save(model.state_dict(), model_path)
print(f"✅ Model saved at {model_path}")
print(f"⏳ Total Training Time: {total_time:.2f} seconds ({total_time/60:.2f} minutes)")
|
|
|