"""Fine-tune a pretrained ResNet-18 on a class-imbalanced subset of CIFAR-10.

The script builds an imbalanced training set by capping the number of samples
taken from each class, holds out 10% of it for validation, trains with Adam and
cross-entropy, saves the model whenever validation accuracy improves, and
finally reports accuracy on the standard CIFAR-10 test split.
"""

import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.optim
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as dsets
import torchvision.models
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import random_split


class ModifiedCIFAR10(Dataset):
    """CIFAR-10 restricted to a capped number of samples per class.

    For every ``(target_class, num_sample)`` pair, the first ``num_sample``
    images of that class (in the order they appear in the underlying
    torchvision dataset) are kept. Indices are concatenated class by class in
    the order given by ``target_classes``.

    Args:
        root: Directory where CIFAR-10 is stored / downloaded to.
        train: Use the training split if True, the test split otherwise.
        transform: Transform applied to each image by the wrapped dataset.
        target_classes: Class labels to keep. Defaults to all ten classes.
        num_samples: Per-class sample caps, parallel to ``target_classes``.
            Defaults to ``DEFAULT_NUM_SAMPLES``.

    Raises:
        ValueError: If ``target_classes`` and ``num_samples`` differ in length.
    """

    DEFAULT_TARGET_CLASSES = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
    DEFAULT_NUM_SAMPLES = (500, 500, 2500, 2500, 5000, 5000, 5000, 5000, 5000, 5000)

    def __init__(self, root, train=True, transform=None,
                 target_classes=None, num_samples=None):
        # None sentinels instead of list defaults: a list default is created
        # once and shared (and stored on every instance) across calls.
        if target_classes is None:
            target_classes = list(self.DEFAULT_TARGET_CLASSES)
        if num_samples is None:
            num_samples = list(self.DEFAULT_NUM_SAMPLES)
        if len(target_classes) != len(num_samples):
            # zip() below would otherwise silently drop the unmatched tail.
            raise ValueError(
                "target_classes and num_samples must have the same length "
                "(got {} and {})".format(len(target_classes), len(num_samples))
            )

        self.original_dataset = dsets.CIFAR10(
            root=root, train=train, download=True, transform=transform
        )
        self.target_classes = target_classes
        self.num_samples = num_samples

        # Group dataset indices by label in a single pass rather than
        # rescanning every target once per requested class.
        indices_by_class = {}
        for index, label in enumerate(self.original_dataset.targets):
            indices_by_class.setdefault(label, []).append(index)

        self.sample_indices = []
        for target_class, num_sample in zip(target_classes, num_samples):
            class_indices = indices_by_class.get(target_class, [])
            self.sample_indices += class_indices[:num_sample]

    def __len__(self):
        """Return the number of retained samples."""
        return len(self.sample_indices)

    def __getitem__(self, idx):
        """Return the item of the wrapped dataset for the ``idx``-th retained sample."""
        original_idx = self.sample_indices[idx]
        return self.original_dataset[original_idx]


# Training parameters
modellr = 1e-4
BATCH_SIZE = 64
EPOCHS = 20
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Best validation accuracy seen so far and the epoch it was reached in
# (updated by val()).
best_accuracy = 0
best_epoch = 0

np.random.seed(42)
torch.manual_seed(42)

# Data preprocessing
# Per-channel mean / std used for normalisation (commonly used CIFAR-10 values).
mean, std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]

transform_train = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    # Randomly jitter brightness, contrast and saturation.
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    # Randomly adjust sharpness (applied with probability 0.1).
    transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
    transforms.RandomErasing(),
])
transform_test = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

test_dataset = dsets.CIFAR10(root='./data', train=False, download=True,
                             transform=transform_test)

# Imbalanced training set: classes 0-1 keep 500 samples, classes 2-3 keep
# 2500, and classes 4-9 keep 5000.
modified_train_dataset = ModifiedCIFAR10(
    root='./data',
    train=True,
    transform=transform_train,
    target_classes=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    num_samples=[500, 500, 2500, 2500, 5000, 5000, 5000, 5000, 5000, 5000],
)

# Split the dataset into training and validation sets (90% / 10%).
train_size = int(0.9 * len(modified_train_dataset))
val_size = len(modified_train_dataset) - train_size
torch.manual_seed(42)
train_dataset, validation_dataset = random_split(
    modified_train_dataset, [train_size, val_size]
)

# random_split() returns views onto modified_train_dataset, so the validation
# subset would otherwise be run through the random training augmentations
# (rotation, erasing, ...), making validation accuracy -- and hence the
# best-checkpoint choice -- noisy. Re-point the same indices at a copy of the
# dataset that uses the deterministic evaluation transform instead.
eval_view_dataset = ModifiedCIFAR10(
    root='./data',
    train=True,
    transform=transform_test,
    target_classes=modified_train_dataset.target_classes,
    num_samples=modified_train_dataset.num_samples,
)
validation_dataset = torch.utils.data.Subset(
    eval_view_dataset, validation_dataset.indices
)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=BATCH_SIZE, shuffle=False)
val_loader = torch.utils.data.DataLoader(dataset=validation_dataset,
                                         batch_size=BATCH_SIZE, shuffle=False)

# Model & training settings
criterion = nn.CrossEntropyLoss()

# Alternative ("first balance method"), currently disabled: class-weighted loss.
# NOTE(review): as written these weights are proportional to class frequency,
# which up-weights the majority classes; inverse-frequency weights are the
# usual choice for rebalancing -- confirm intent before enabling.
# num_samples = [500, 500, 2500, 2500, 5000, 5000, 5000, 5000, 5000, 5000]
# class_weights = torch.FloatTensor(
#     [num_samples[i] / len(modified_train_dataset) for i in range(10)])
# criterion = nn.CrossEntropyLoss(weight=class_weights.to(DEVICE))

model = torchvision.models.resnet18(pretrained=True)
# Replace the classification head with a 10-way linear layer.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)
model.to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr=modellr)


def adjust_learning_rate(optimizer, epoch):
    """Set the learning rate to the initial LR decayed by 10 every 50 epochs.

    With ``EPOCHS = 20`` the exponent ``epoch // 50`` is always 0, so the
    learning rate stays at ``modellr`` for the whole run.

    Args:
        optimizer: Optimizer whose parameter groups are updated in place.
        epoch: Current epoch number.
    """
    modellrnew = modellr * (0.1 ** (epoch // 50))
    print("lr:", modellrnew)
    for param_group in optimizer.param_groups:
        param_group['lr'] = modellrnew


def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch and print running loss and epoch accuracy.

    Uses the module-level ``criterion`` as the loss function.

    Args:
        model: Network to train (updated in place).
        device: Device the batches are moved to.
        train_loader: DataLoader yielding ``(data, target)`` batches.
        optimizer: Optimizer used for the parameter updates.
        epoch: Current epoch number (used only for logging).
    """
    model.train()
    sum_loss = 0
    correct = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        # Tensors carry autograd state themselves; the Variable wrapper is a
        # deprecated no-op and is not needed.
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print_loss = loss.item()
        sum_loss += print_loss
        _, pred = torch.max(output.data, 1)
        # Keep the running count as a Python int rather than a device tensor.
        correct += (pred == target).sum().item()
        if (batch_idx + 1) % 50 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader), loss.item()))
    accuracy = correct / total_num
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{}, loss:{}, Training Accuracy: {:.2%}'.format(epoch, ave_loss, accuracy))


def val(model, device, test_loader, epoch):
    """Evaluate on a validation loader and checkpoint the model on improvement.

    Uses the module-level ``criterion``. When accuracy exceeds the module-level
    ``best_accuracy``, both ``best_accuracy`` and ``best_epoch`` are updated and
    the whole model object is written with ``torch.save``.

    Args:
        model: Network to evaluate.
        device: Device the batches are moved to.
        test_loader: DataLoader yielding ``(data, target)`` validation batches.
        epoch: Current epoch number (recorded as ``best_epoch`` on improvement).
    """
    global best_accuracy, best_epoch
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            _, pred = torch.max(output.data, 1)
            # Accumulate as a Python int so no tensor-only attribute access is
            # needed after the loop.
            correct += (pred == target).sum().item()
            print_loss = loss.item()
            test_loss += print_loss
        acc = correct / total_num
        avgloss = test_loss / len(test_loader)
        print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            avgloss, correct, len(test_loader.dataset), 100 * acc))
        if acc > best_accuracy:
            best_accuracy, best_epoch = acc, epoch
            torch.save(model, '666cifar_model_resnet18_lr0.0001_unbalanced_crossentropy.pth')


def test(model, device, test_loader):
    """Print the classification accuracy of ``model`` on ``test_loader``.

    Args:
        model: Network to evaluate.
        device: Device the batches are moved to.
        test_loader: DataLoader yielding ``(data, target)`` test batches.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    accuracy = correct / total
    print('Test Accuracy: {:.2%} ({}/{})'.format(accuracy, correct, total))


# Train the model and track the best model
for epoch in range(1, EPOCHS + 1):
    adjust_learning_rate(optimizer, epoch)
    train(model, DEVICE, train_loader, optimizer, epoch)
    val(model, DEVICE, val_loader, epoch)

# NOTE: this evaluates the weights from the final epoch, not the saved
# best-validation checkpoint.
test(model, DEVICE, test_loader)
print(f"Best model achieved at epoch {best_epoch} with accuracy: {best_accuracy * 100:.2f}%")