Added Source Folder
Browse files- src/dataset.py +78 -0
- src/inference.py +121 -0
- src/model.py +56 -0
- src/train.py +91 -0
src/dataset.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.data import Dataset, DataLoader
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import torchvision.transforms as transforms
|
| 7 |
+
|
| 8 |
+
class NudeMultiLabelDataset(Dataset):
    """Multi-label image dataset backed by a JSON mapping of image name -> tag list.

    Each sample is an ``(image, multi-hot label tensor)`` pair. The label
    vector has one slot per unique tag found in the label file, in sorted
    order, so the class ordering is stable across runs.
    """

    def __init__(self, data_dir, label_file, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.label_file = label_file

        # The label file maps image filename -> list of tag strings.
        with open(label_file, "r") as f:
            self.labels = json.load(f)

        self.image_paths = list(self.labels.keys())

        # Collect every tag across all images, deduplicated and sorted.
        all_tags = set()
        for tag_list in self.labels.values():
            all_tags.update(tag_list)
        self.classes = sorted(all_tags)
        self.class_to_idx = {tag: idx for idx, tag in enumerate(self.classes)}

        # Summary of what was loaded (useful when running standalone).
        print(f"📂 Dataset loaded from: {data_dir}")
        print(f"📄 Labels loaded from: {label_file}")
        print(f"🖼️ Total images: {len(self.image_paths)}")
        print(f"🏷️ Unique labels: {len(self.classes)}")
        print(f"🔹 Label-to-Index Mapping: {self.class_to_idx}")

        # Fetch one sample so shape/label problems surface at construction time.
        if self.image_paths:
            example_img, example_label = self.__getitem__(0)
            print(f"✅ Example Image Shape: {example_img.shape}")
            print(f"✅ Example Label: {example_label}")

    def __len__(self):
        """Number of images listed in the label file."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, multi-hot label tensor)`` for the image at *idx*."""
        img_name = self.image_paths[idx]
        image = Image.open(os.path.join(self.data_dir, img_name)).convert("RGB")

        # Multi-hot encode: 1.0 at each index whose tag applies to this image.
        label_tensor = torch.zeros(len(self.classes))
        for tag in self.labels[img_name]:
            tag_idx = self.class_to_idx.get(tag)
            if tag_idx is not None:
                label_tensor[tag_idx] = 1  # Multi-label

        if self.transform:
            image = self.transform(image)

        return image, label_tensor
|
| 54 |
+
|
| 55 |
+
# 🔹 Main function to test the dataset independently
|
| 56 |
+
if __name__ == "__main__":
    # Standalone smoke test: load the dataset and inspect a single batch.
    DATA_DIR = "../data/images"  # Change to actual path
    LABEL_FILE = "../data/labels.json"  # Change to actual path

    # Standard ImageNet-style preprocessing pipeline.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    dataset = NudeMultiLabelDataset(DATA_DIR, LABEL_FILE, transform=transform)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    # Inspect exactly one batch, then stop.
    for images, labels in dataloader:
        print(f"🖼️ Batch Image Shape: {images.shape}")  # Should be [batch_size, 3, 224, 224]
        print(f"🏷️ Batch Labels: {labels}")
        break
|
src/inference.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Explainable AI (XAI) Inference for Nude Multi-Label Classification
|
| 3 |
+
==================================================================
|
| 4 |
+
|
| 5 |
+
This script performs inference using a trained Swin Transformer model for
|
| 6 |
+
multi-label classification of nude images. It also integrates Class Activation
|
| 7 |
+
Mapping (CAM) to provide visual explanations for the model's predictions.
|
| 8 |
+
|
| 9 |
+
Author: Ramaguru Radhakrishnan
|
| 10 |
+
Date: March 2025
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torchvision.transforms as transforms
|
| 15 |
+
from PIL import Image
|
| 16 |
+
import json
|
| 17 |
+
from model import SwinTransformerMultiLabel
|
| 18 |
+
from torchcam.methods import SmoothGradCAMpp # Explainability module
|
| 19 |
+
import matplotlib.pyplot as plt
|
| 20 |
+
import numpy as np
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# ===============================
# Model setup
# ===============================

# Number of output classes (must match the trained model checkpoint).
NUM_CLASSES = 18

# Build the model with the correct classifier head.
model = SwinTransformerMultiLabel(num_classes=NUM_CLASSES)

# Load checkpoint weights, skipping any layers whose shapes no longer match
# (e.g. a classifier head trained with a different class count).
checkpoint_path = "../models/multi_nude_detector.pth"
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model_dict = model.state_dict()
filtered_checkpoint = {
    k: v for k, v in checkpoint.items() if k in model_dict and v.shape == model_dict[k].shape
}
model_dict.update(filtered_checkpoint)
model.load_state_dict(model_dict, strict=False)

# Evaluation mode: disables dropout / freezes batch-norm statistics.
model.eval()

# Preprocessing must mirror the training-time transform exactly.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to model's input size
    transforms.ToTensor(),  # Convert to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize
])

# Class names are derived the same way dataset.py derives them:
# the sorted set of all tags appearing in labels.json.
with open("../data/labels.json", "r") as f:
    classes = sorted(set(tag for tags in json.load(f).values() for tag in tags))

# Validate that the number of classes matches the model head.
if len(classes) != NUM_CLASSES:
    raise ValueError(f"❌ Mismatch: Model expects {NUM_CLASSES} classes, but labels.json has {len(classes)} labels!")

# Load and preprocess the test image.
img_path = "C:\\Users\\RamaguruRadhakrishna\\Videos\\STAR-main\\STAR-main\\data\\images\\442_.jpeg"
image = Image.open(img_path).convert("RGB")  # Ensure RGB format
input_tensor = transform(image).unsqueeze(0)  # Add batch dimension

# ===============================
# Inference
# ===============================

with torch.no_grad():
    output = model(input_tensor)  # Forward pass through model
    print(f"🔹 Model Output Shape: {output.shape}")  # Debugging

# FIX: the model is trained with BCEWithLogitsLoss, so it emits raw logits.
# Thresholding the logits at 0.5 (as before) actually thresholds probability
# at ~0.62 — apply a sigmoid first, then threshold the probabilities at 0.5.
probs = torch.sigmoid(output)
predicted_indices = [
    i for i in range(min(len(classes), probs.shape[1])) if probs[0][i] > 0.5
]
predicted_labels = [classes[i] for i in predicted_indices]

# Display predicted labels
print("✅ Predicted Tags:", predicted_labels)

# ===============================
# Explainable AI: CAM Visualization
# ===============================

# Dump the module names so a valid target layer can be chosen below.
print("🔍 Model Architecture:\n")
for name, module in model.named_modules():
    print(name)

# NOTE(review): pick a real layer from the printout above. Since the Swin
# backbone lives under the `model` attribute of the wrapper, names are
# typically of the form "model.features.<stage>.<block>" — confirm against
# the printed list before running.
valid_target_layer = "features.7.3"  # Modify based on your model structure

# Fail fast, listing valid choices, if the layer does not exist.
if valid_target_layer not in dict(model.named_modules()):
    raise ValueError(f"❌ Layer '{valid_target_layer}' not found in model. Choose from:\n{list(dict(model.named_modules()).keys())}")

# Initialize SmoothGradCAMpp with the chosen layer.
cam_extractor = SmoothGradCAMpp(model, target_layer=valid_target_layer)
print("✅ SmoothGradCAMpp initialized successfully!")

# Re-run the forward pass *with* gradients enabled: SmoothGradCAMpp needs
# them, and the earlier pass ran under torch.no_grad().
output = model(input_tensor)

# Generate one heatmap per predicted label.
for class_idx in predicted_indices:
    cam = cam_extractor(class_idx, output)
    # Recent torchcam versions return a list of maps (one per target layer).
    if isinstance(cam, list):
        cam = cam[0]
    cam = cam.squeeze().cpu().numpy()

    # Normalize to [0, 1]; guard against a flat map (max == min would
    # otherwise divide by zero and fill the heatmap with NaNs).
    cam_min = np.min(cam)
    cam_range = np.max(cam) - cam_min
    cam = (cam - cam_min) / (cam_range if cam_range > 0 else 1.0)

    # FIX: convert to uint8 before building the PIL image — the previous
    # float array produced a 32-bit float raster rather than grayscale.
    cam_resized = np.array(Image.fromarray((cam * 255).astype(np.uint8)).resize(image.size))

    # Overlay the CAM on the original image.
    plt.figure(figsize=(6, 6))
    plt.imshow(image)
    plt.imshow(cam_resized, cmap='jet', alpha=0.5)  # Heatmap overlay
    plt.axis("off")
    plt.title(f"Explainability Heatmap for '{classes[class_idx]}'")
    plt.show()
|
src/model.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torchvision.models import swin_t
|
| 4 |
+
|
| 5 |
+
class SwinTransformerMultiLabel(nn.Module):
    """Swin-T backbone with its classification head resized for multi-label output.

    The final Linear layer is replaced *in place*, so the stock torchvision
    forward pipeline (features -> norm -> permute -> avgpool -> flatten -> head)
    can be reused unchanged. Outputs are raw logits — pair with
    ``BCEWithLogitsLoss`` for multi-label training, and apply a sigmoid at
    inference time.
    """

    def __init__(self, num_classes):
        """Build a Swin-T pretrained on ImageNet with a *num_classes*-way head."""
        super(SwinTransformerMultiLabel, self).__init__()
        self.model = swin_t(weights="IMAGENET1K_V1")

        # Swap the ImageNet head (768 -> 1000) for one sized to our label set.
        in_features = self.model.head.in_features  # Should be 768
        self.model.head = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return per-class logits of shape (batch_size, num_classes).

        FIX: delegate to the backbone's own forward so its final LayerNorm
        (`self.model.norm`) is applied before pooling. The previous
        hand-rolled ``features(x).mean(dim=[1, 2])`` path skipped that norm,
        degrading the pretrained features; it also printed debug output on
        every forward pass.
        """
        return self.model(x)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def main():
    """Smoke-test the model: build it and push dummy batches through it."""
    num_classes = 2

    # Build the model and switch to inference mode.
    model = SwinTransformerMultiLabel(num_classes)
    model.eval()

    # One forward pass with a dummy batch of 5 RGB 224x224 images.
    dummy_input = torch.randn(5, 3, 224, 224)
    output = model(dummy_input)
    print(f"✅ Model output shape: {output.shape}")  # Expected: (5, 2)

    # Confirm the classification head was replaced as intended.
    print(f"✅ Model classification head: {model.model.head}")

    # The model should accept arbitrary batch sizes.
    for batch_size in [1, 8, 16]:
        dummy_input = torch.randn(batch_size, 3, 224, 224)
        output = model(dummy_input)
        print(f"✅ Batch Size {batch_size} -> Output Shape: {output.shape}")


if __name__ == "__main__":
    main()
|
src/train.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Transformer-based Nude Classification Model Training Script
|
| 3 |
+
|
| 4 |
+
Author: Ramaguru Radhakrishnan
|
| 5 |
+
Description:
|
| 6 |
+
This script trains a multi-label classification model based on the Swin Transformer architecture
|
| 7 |
+
to classify images into various adult content categories. The dataset and label information
|
| 8 |
+
are provided as inputs, and the trained model is saved for later inference.
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
python train.py --data <path_to_dataset> --labels <path_to_labels.json> --save <path_to_save_model>
|
| 12 |
+
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torchvision.transforms as transforms
|
| 17 |
+
from torch.utils.data import DataLoader
|
| 18 |
+
from dataset import NudeMultiLabelDataset
|
| 19 |
+
from model import SwinTransformerMultiLabel
|
| 20 |
+
import argparse
|
| 21 |
+
import os
|
| 22 |
+
import time
|
| 23 |
+
|
| 24 |
+
# ===============================
# Command-line arguments
# ===============================
parser = argparse.ArgumentParser(description="Train a Transformer-based nude classification model")
parser.add_argument("--data", type=str, required=True, help="Path to dataset directory")
parser.add_argument("--labels", type=str, required=True, help="Path to labels.json file")
parser.add_argument("--save", type=str, required=True, help="Directory to save trained model")
args = parser.parse_args()

# Preprocessing: must stay in sync with the inference-time transform.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to match model input size
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize image pixel values
])

# FIX: build the dataset *before* the loader — the loader takes the dataset
# instance as its first argument, and the previous ordering referenced
# `dataset` one line before it was defined (guaranteed NameError at startup).
dataset = NudeMultiLabelDataset(args.data, args.labels, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)  # Create a data loader for batching

# Initialize the model on GPU if available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SwinTransformerMultiLabel(num_classes=len(dataset.classes)).to(device)

# BCE-with-logits: the canonical loss for independent multi-label targets.
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)  # Adam optimizer with a learning rate of 0.0001

# Start measuring total training time.
start_time = time.time()

# Training loop.
epochs = 50
for epoch in range(epochs):
    epoch_loss = 0.0
    epoch_start = time.time()  # Track time taken for each epoch

    for imgs, labels in dataloader:
        imgs, labels = imgs.to(device), labels.to(device)  # Move data to the model's device

        optimizer.zero_grad()  # Reset gradients before backpropagation
        outputs = model(imgs)  # Raw logits, expected shape [batch_size, num_classes]

        # Defensive flatten in case the backbone ever returns spatial dims.
        # (The per-batch shape debug prints were removed — they ran on every
        # iteration of every epoch and drowned out the progress output.)
        if outputs.dim() > 2:
            outputs = outputs.view(outputs.size(0), -1)

        # Compute loss and update model parameters.
        loss = criterion(outputs, labels)
        loss.backward()  # Compute gradients
        optimizer.step()  # Update model weights

        epoch_loss += loss.item()  # Accumulate loss for this epoch

    epoch_end = time.time()
    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss / len(dataloader)}, Time: {epoch_end - epoch_start:.2f} sec")

# Total wall-clock training time.
end_time = time.time()
total_time = end_time - start_time

# Save trained model weights to the specified directory.
os.makedirs(args.save, exist_ok=True)
torch.save(model.state_dict(), os.path.join(args.save, "star.pth"))
print(f"✅ Model saved at {args.save}/star.pth")
print(f"⏳ Total Training Time: {total_time:.2f} seconds ({total_time/60:.2f} minutes)")
|