import os

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
from torch.utils.data import Dataset, random_split
import torchvision.models
import torchvision.datasets as dsets
import torchvision.transforms as transforms

from imblearn.over_sampling import SMOTE, RandomOverSampler
from imblearn.under_sampling import EditedNearestNeighbours


class ModifiedCIFAR10(Dataset):
    """CIFAR-10 wrapper that truncates per-class sample counts to create an
    imbalanced training set, optionally rebalanced with SMOTE and/or ENN."""

    def __init__(self, root, train=True, transform=None,
                 target_classes=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9),
                 num_samples=(500, 500, 2500, 2500, 5000, 5000, 5000, 5000, 5000, 5000),
                 oversample=True, undersample=True):
        self.original_dataset = dsets.CIFAR10(root=root, train=train,
                                              download=True, transform=transform)
        self.target_classes = target_classes
        self.num_samples = num_samples
        self.oversample = oversample
        self.undersample = undersample
        self.resampled_data = None

        # Keep only the first `num_sample` examples of each target class.
        self.sample_indices = []
        for target_class, num_sample in zip(target_classes, num_samples):
            class_indices = [i for i, label in enumerate(self.original_dataset.targets)
                             if label == target_class]
            self.sample_indices += class_indices[:num_sample]

        if self.oversample or self.undersample:
            X = [self.original_dataset[i][0].numpy() for i in self.sample_indices]
            y = [self.original_dataset[i][1] for i in self.sample_indices]
            X_resampled = np.array(X).reshape(-1, 32 * 32 * 3)
            y_resampled = np.array(y)
            if self.oversample:
                # SMOTE synthesizes minority-class samples in flattened pixel space.
                smote = SMOTE(sampling_strategy='auto', random_state=42, n_jobs=-1)
                X_resampled, y_resampled = smote.fit_resample(X_resampled, y_resampled)
            if self.undersample:
                # ENN drops samples whose nearest neighbours mostly disagree with them.
                enn = EditedNearestNeighbours(sampling_strategy='auto', n_neighbors=3, n_jobs=-1)
                X_resampled, y_resampled = enn.fit_resample(X_resampled, y_resampled)
            # Store the resampled tensors directly: SMOTE's synthetic samples have no
            # index in the original dataset, so index bookkeeping alone would
            # silently discard them.
            self.resampled_data = (
                torch.from_numpy(np.asarray(X_resampled, dtype=np.float32)).view(-1, 3, 32, 32),
                torch.as_tensor(np.asarray(y_resampled), dtype=torch.long),
            )

    def __len__(self):
        if self.resampled_data is not None:
            return len(self.resampled_data[1])
        return len(self.sample_indices)

    def __getitem__(self, idx):
        if self.resampled_data is not None:
            return self.resampled_data[0][idx], self.resampled_data[1][idx]
        return self.original_dataset[self.sample_indices[idx]]


# Training parameters
modellr = 1e-4
BATCH_SIZE = 64
EPOCHS = 20
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Track the best validation accuracy and the epoch at which it was reached
best_accuracy = 0
best_epoch = 0

np.random.seed(42)
torch.manual_seed(42)

# Data preprocessing
# Per-channel CIFAR-10 mean and std; these widely used statistics help convergence.
mean, std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]

transform_train = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ColorJitter(brightness=0.1,  # randomly jitter colour
                           contrast=0.1,
                           saturation=0.1),
    transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.1),  # randomly adjust sharpness
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
    transforms.RandomErasing()
])

transform_test = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

test_dataset = dsets.CIFAR10(root='./data', train=False, download=True,
                             transform=transform_test)

# Build an imbalanced training set: classes 0-1 are cut from 5000 to 500 samples
# and classes 2-3 to 2500, while classes 4-9 keep all 5000.
modified_train_dataset = ModifiedCIFAR10(
    root='./data', train=True, transform=transform_train,
    target_classes=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    num_samples=[500, 500, 2500, 2500, 5000, 5000, 5000, 5000, 5000, 5000])
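# Optional sanity check (editor's addition, not in the original script): print the
# per-class sample counts of the modified dataset to confirm the imbalance (and any
# resampling) came out as intended. Iterating the whole dataset applies the random
# transforms once per item, so this is slow but harmless.
from collections import Counter
label_counts = Counter(int(modified_train_dataset[i][1])
                       for i in range(len(modified_train_dataset)))
print('Per-class sample counts:', dict(sorted(label_counts.items())))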
# Split the dataset into training and validation sets (90/10).
train_size = int(0.9 * len(modified_train_dataset))
val_size = len(modified_train_dataset) - train_size
torch.manual_seed(42)
train_dataset, validation_dataset = random_split(modified_train_dataset, [train_size, val_size])

### Second balancing experiment: random oversampling of the whole training set.
# Extract features and labels from the training dataset, flattening each image to a
# 3072-dimensional vector so RandomOverSampler can operate on a 2-D feature matrix.
X, y = zip(*[(img, label) for img, label in modified_train_dataset])
X_flattened = np.array([img.reshape(-1).numpy() for img in X])
y = np.array([int(label) for label in y])

# Oversample the flattened training data until all classes have equal counts.
oversampler = RandomOverSampler(sampling_strategy='auto', random_state=42)
X_resampled, y_resampled = oversampler.fit_resample(X_flattened, y)

# Convert back to a PyTorch dataset, restoring the (3, 32, 32) image shape.
oversampled_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(X_resampled).float().view(-1, 3, 32, 32),
    torch.from_numpy(y_resampled).long())

# Split the oversampled dataset into training and validation sets (90/10).
oversampled_train_size = int(0.9 * len(oversampled_dataset))
oversampled_val_size = len(oversampled_dataset) - oversampled_train_size
torch.manual_seed(42)
oversampled_train_dataset, oversampled_validation_dataset = random_split(
    oversampled_dataset, [oversampled_train_size, oversampled_val_size])

# DataLoaders for the oversampled training and validation sets
oversampled_train_loader = torch.utils.data.DataLoader(
    dataset=oversampled_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
oversampled_val_loader = torch.utils.data.DataLoader(
    dataset=oversampled_validation_dataset, batch_size=BATCH_SIZE, shuffle=False)
###

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)
val_loader = torch.utils.data.DataLoader(dataset=validation_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Model & training settings
num_samples = [500, 500, 2500, 2500, 5000, 5000, 5000, 5000, 5000, 5000]

# First balancing method: class-weighted cross-entropy. Weights are inversely
# proportional to class frequency so that rare classes contribute more to the loss.
class_weights = torch.FloatTensor(
    [len(modified_train_dataset) / num_samples[i] for i in range(10)])
class_weights = class_weights / class_weights.sum()
criterion = nn.CrossEntropyLoss(weight=class_weights.to(DEVICE))

model = torchvision.models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)  # replace the ImageNet head with a 10-class head
model.to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=modellr)


# Learning rate adjustment (no effect here: decay only starts at epoch 50, EPOCHS = 20)
def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 50 epochs."""
    modellrnew = modellr * (0.1 ** (epoch // 50))
    print("lr:", modellrnew)
    for param_group in optimizer.param_groups:
        param_group['lr'] = modellrnew
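# Editor's note (a sketch, not part of the original script): the manual decay above
# is roughly equivalent to PyTorch's built-in StepLR scheduler, which could replace
# adjust_learning_rate() like so:
#
#     scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)
#     for epoch in range(1, EPOCHS + 1):
#         train(...)
#         val(...)
#         scheduler.step()  # decay lr by 10x every 50 epochs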
# Train for one epoch, logging batch loss every 50 batches
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    sum_loss = 0
    correct = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        sum_loss += loss.item()
        _, pred = torch.max(output.data, 1)
        correct += torch.sum(pred == target).item()
        if (batch_idx + 1) % 50 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader), loss.item()))
    accuracy = correct / total_num
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{}, loss:{}, Training Accuracy: {:.2%}'.format(epoch, ave_loss, accuracy))


# Validate; save the full model whenever validation accuracy improves
def val(model, device, val_loader, epoch):
    global best_accuracy, best_epoch
    model.eval()
    val_loss = 0
    correct = 0
    total_num = len(val_loader.dataset)
    print(total_num, len(val_loader))
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            _, pred = torch.max(output.data, 1)
            correct += torch.sum(pred == target).item()
            val_loss += loss.item()
    acc = correct / total_num
    avgloss = val_loss / len(val_loader)
    print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        avgloss, correct, total_num, 100 * acc))
    if acc > best_accuracy:
        best_accuracy, best_epoch = acc, epoch
        torch.save(model, '666cifar_model_resnet18_lr0.0001_unbalanced_crossentropy.pth')


# Evaluate the model on the held-out test set
def test(model, device, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    accuracy = correct / total
    print('Test Accuracy: {:.2%} ({}/{})'.format(accuracy, correct, total))


# Train the model, track the best checkpoint, then evaluate on the test set
for epoch in range(1, EPOCHS + 1):
    adjust_learning_rate(optimizer, epoch)
    train(model, DEVICE, oversampled_train_loader, optimizer, epoch)
    val(model, DEVICE, oversampled_val_loader, epoch)

test(model, DEVICE, test_loader)
print(f"Best model achieved at epoch {best_epoch} with accuracy: {best_accuracy * 100:.2f}%")
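# Optional follow-up (editor's sketch, not in the original script): with imbalanced
# training data, per-class test accuracy is usually more informative than the overall
# figure, since the rare classes 0-3 may lag far behind. This reuses the `model`,
# `test_loader`, and `DEVICE` defined above and assumes the 10 CIFAR-10 labels 0-9.
class_correct = torch.zeros(10)
class_total = torch.zeros(10)
model.eval()
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(DEVICE), target.to(DEVICE)
        _, predicted = torch.max(model(data), 1)
        for label, pred in zip(target.cpu(), predicted.cpu()):
            class_total[label] += 1
            class_correct[label] += int(label == pred)
for c in range(10):
    print('Class {}: {:.2%} ({}/{})'.format(
        c, (class_correct[c] / class_total[c]).item(),
        int(class_correct[c]), int(class_total[c])))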