import inspect
import os
import time
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.optim
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as dsets
import torchvision.models
import torchvision.transforms as transforms
from PIL import Image  # was `ImageS`, which does not exist and raised ImportError
from sklearn.metrics import confusion_matrix
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset, Subset, random_split


# Per-class sample budget for the imbalanced training set: classes 0 and 1 are
# capped at 500 images, classes 2 and 3 at 2500, and classes 4-9 at 5000.
TARGET_CLASSES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
NUM_SAMPLES = [500, 500, 2500, 2500, 5000, 5000, 5000, 5000, 5000, 5000]


class ModifiedCIFAR10(Dataset):
    """CIFAR-10 restricted to a chosen number of samples per class.

    For each entry of ``target_classes`` the first ``num_samples[k]`` images of
    that class (in the wrapped dataset's order) are kept, which makes it easy
    to build deliberately imbalanced datasets.

    Args:
        root: Directory CIFAR-10 is stored in / downloaded to.
        train: Use the training split if True, otherwise the test split.
        transform: Transform applied to each image by the wrapped dataset.
        target_classes: Class labels to keep. Defaults to ``TARGET_CLASSES``.
        num_samples: Maximum number of images kept for each entry of
            ``target_classes``. Defaults to ``NUM_SAMPLES``.

    Raises:
        ValueError: If ``target_classes`` and ``num_samples`` differ in length
            (``zip`` would otherwise silently drop the extra entries).
    """

    def __init__(self, root, train=True, transform=None,
                 target_classes=None, num_samples=None):
        # None defaults avoid one mutable list being shared by every instance.
        if target_classes is None:
            target_classes = list(TARGET_CLASSES)
        if num_samples is None:
            num_samples = list(NUM_SAMPLES)
        if len(target_classes) != len(num_samples):
            raise ValueError(
                f"target_classes ({len(target_classes)}) and num_samples "
                f"({len(num_samples)}) must have the same length"
            )

        self.original_dataset = dsets.CIFAR10(root=root, train=train, download=True, transform=transform)
        self.target_classes = target_classes
        self.num_samples = num_samples

        # Group indices by label in a single pass instead of rescanning every
        # target once per requested class.
        indices_by_label = defaultdict(list)
        for i, label in enumerate(self.original_dataset.targets):
            indices_by_label[label].append(i)

        self.sample_indices = []
        for target_class, num_sample in zip(target_classes, num_samples):
            self.sample_indices += indices_by_label[target_class][:num_sample]

    def __len__(self):
        return len(self.sample_indices)

    def __getitem__(self, idx):
        # Map the compact index onto the wrapped CIFAR-10 dataset.
        original_idx = self.sample_indices[idx]
        return self.original_dataset[original_idx]


# Training hyper-parameters.
modellr = 1e-4
BATCH_SIZE = 64
EPOCHS = 20
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Best validation accuracy seen so far and the epoch it was reached in
# (bookkeeping for a training loop).
best_accuracy = 0
best_epoch = 0

np.random.seed(42)
torch.manual_seed(42)

# Data preprocessing.
# Widely used per-channel CIFAR-10 mean / std for input normalization.
mean, std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]

transform_train = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    # Randomly jitter brightness, contrast and saturation.
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    # With probability 0.1, sharpen the image (sharpness_factor=2).
    transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
    # Operates on tensors, so it has to come after ToTensor.
    transforms.RandomErasing(),
])

transform_test = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

test_dataset = dsets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# Imbalanced training set built from the TARGET_CLASSES / NUM_SAMPLES budget.
modified_train_dataset = ModifiedCIFAR10(
    root='./data',
    train=True,
    transform=transform_train,
    target_classes=TARGET_CLASSES,
    num_samples=NUM_SAMPLES,
)

# 90/10 train/validation split. The global seed is reset right before the
# split so the partition is reproducible across runs.
train_size = int(0.9 * len(modified_train_dataset))
val_size = len(modified_train_dataset) - train_size
torch.manual_seed(42)
train_dataset, _val_split = random_split(modified_train_dataset, [train_size, val_size])

# A Subset shares its parent's transform, so validating on `_val_split`
# directly would apply the random training augmentation. Re-index the very
# same samples from a copy that uses the deterministic test transform.
_val_source = ModifiedCIFAR10(
    root='./data',
    train=True,
    transform=transform_test,
    target_classes=TARGET_CLASSES,
    num_samples=NUM_SAMPLES,
)
validation_dataset = Subset(_val_source, _val_split.indices)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)
val_loader = torch.utils.data.DataLoader(dataset=validation_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Model & training settings.
criterion = nn.CrossEntropyLoss()

# Alternative balancing method: weight each class by its inverse frequency so
# the under-sampled classes are not drowned out. (An earlier sketch used
# num_samples[i] / N, which is the class *frequency* and would up-weight the
# majority classes instead.)
# counts = torch.tensor(modified_train_dataset.num_samples, dtype=torch.float)
# class_weights = counts.sum() / (len(counts) * counts)
# criterion = nn.CrossEntropyLoss(weight=class_weights.to(DEVICE))

# ImageNet-pretrained ResNet-18 with a fresh 10-way classification head.
model = torchvision.models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)
model.to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=modellr)


def adjust_learning_rate(optimizer, epoch):
    """Set every param group's LR to ``modellr`` decayed by 10x every 50 epochs.

    Not called in this script as shown; with EPOCHS = 20 the decay would never
    trigger anyway (``epoch // 50`` stays 0).
    """
    modellrnew = modellr * (0.1 ** (epoch // 50))
    print("lr:", modellrnew)
    for param_group in optimizer.param_groups:
        param_group['lr'] = modellrnew


def load_full_model(path, device):
    """Load a whole model saved with ``torch.save(model, path)`` onto ``device``.

    SECURITY: this unpickles arbitrary Python objects - only load checkpoints
    from a trusted source.
    """
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    kwargs = {'map_location': device}
    # torch >= 2.6 defaults to weights_only=True, which rejects pickled
    # nn.Module objects; releases before 1.13 do not accept the argument.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        kwargs['weights_only'] = False
    return torch.load(path, **kwargs)


model = load_full_model("666cifar_model_resnet18_lr0.0001_unbalanced_crossentropy.pth", DEVICE)


def get_predictions(model, device, data_loader):
    """Run ``model`` over ``data_loader`` and collect predicted and true labels.

    Switches the model to eval mode and moves it to ``device``; gradients are
    disabled for the whole pass.

    Returns:
        ``(predictions, targets)``: two 1-D NumPy arrays aligned by sample.
    """
    model.eval()
    model.to(device)
    all_predictions = []
    all_targets = []
    with torch.no_grad():
        for data, target in data_loader:
            outputs = model(data.to(device))
            # Predicted class = index of the highest output score.
            predicted = outputs.argmax(dim=1)
            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
    return np.array(all_predictions), np.array(all_targets)


# Get predictions and targets on the test set.
predictions, targets = get_predictions(model, DEVICE, test_loader)

# Pass labels explicitly so the matrix is always 10x10 and lines up with the
# tick labels, even if a class is absent from both targets and predictions.
num_classes = 10
conf_matrix = confusion_matrix(targets, predictions, labels=list(range(num_classes)))

# Plot heatmap.
plt.figure(figsize=(10, 8))
sn.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
           xticklabels=range(num_classes), yticklabels=range(num_classes))
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix')
plt.show()