|
|
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.optim
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as dsets
import torchvision.models
import torchvision.transforms as transforms
from PIL import Image  # fixed: was "ImageS", a typo that raises ImportError
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset, random_split
|
|
|
|
|
|
|
|
|
|
|
class ModifiedCIFAR10(Dataset):
    """CIFAR-10 wrapper that builds a class-imbalanced subset.

    For each class in ``target_classes``, the first ``num_samples[i]``
    examples of that class (in the original dataset's order) are kept and
    everything else is dropped.

    Args:
        root: Directory where the CIFAR-10 data is stored / downloaded to.
        train: Use the training split when True, the test split otherwise.
        transform: Optional torchvision transform applied by the wrapped
            dataset on access.
        target_classes: Class labels to keep. Defaults to all ten classes.
        num_samples: Per-class sample caps, aligned with ``target_classes``.
            Defaults to the original imbalanced configuration
            (500, 500, 2500, 2500, then 5000 for the remaining classes).
    """

    def __init__(self, root, train=True, transform=None,
                 target_classes=None, num_samples=None):
        # Use None sentinels instead of mutable list defaults (shared
        # mutable defaults are a classic Python pitfall); the effective
        # defaults are identical to the original ones.
        if target_classes is None:
            target_classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        if num_samples is None:
            num_samples = [500, 500, 2500, 2500, 5000, 5000, 5000, 5000, 5000, 5000]

        self.original_dataset = dsets.CIFAR10(root=root, train=train, download=True, transform=transform)
        self.target_classes = target_classes
        self.num_samples = num_samples

        # Indices into the underlying dataset that make up this subset.
        self.sample_indices = []
        for target_class, num_sample in zip(target_classes, num_samples):
            class_indices = [i for i, label in enumerate(self.original_dataset.targets)
                             if label == target_class]
            self.sample_indices += class_indices[:num_sample]

    def __len__(self):
        return len(self.sample_indices)

    def __getitem__(self, idx):
        # Delegate to the wrapped dataset so its transform is applied there.
        original_idx = self.sample_indices[idx]
        return self.original_dataset[original_idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---- Training hyper-parameters --------------------------------------------
modellr = 1e-4  # initial learning rate for the Adam optimizer

BATCH_SIZE = 64  # mini-batch size shared by train/val/test loaders

EPOCHS = 20  # NOTE(review): defined but not used in this chunk — confirm usage elsewhere

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Best-model bookkeeping (presumably updated by a training loop elsewhere).
best_accuracy = 0

best_epoch = 0

# Fix RNG seeds so augmentation and splitting are reproducible.
np.random.seed(42)

torch.manual_seed(42)
|
|
|
|
|
|
|
|
|
|
|
# Per-channel CIFAR-10 statistics used for input normalization.
mean, std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]

# Training pipeline: geometric + photometric augmentation, then tensor
# conversion and normalization.  RandomErasing operates on tensors, so it
# must stay after ToTensor().
transform_train = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ColorJitter(brightness = 0.1,
                           contrast = 0.1,
                           saturation = 0.1),
    transforms.RandomAdjustSharpness(sharpness_factor = 2, p = 0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
    transforms.RandomErasing()
])

# Evaluation pipeline: no augmentation, identical normalization.
transform_test = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
|
|
|
|
|
|
|
|
# Full CIFAR-10 test split with evaluation-only transforms.
test_dataset = dsets.CIFAR10(root='./data', train=False, download=True, transform = transform_test)

# Imbalanced training set: classes 0/1 capped at 500 samples, 2/3 at 2500,
# the remaining classes keep all 5000 samples each.
modified_train_dataset = ModifiedCIFAR10(
    root='./data',
    train=True,
    transform=transform_train,
    target_classes=[0, 1, 2, 3,4,5,6,7,8,9],
    num_samples=[500, 500, 2500, 2500,5000,5000,5000,5000,5000,5000]
)

# 90/10 train/validation split sizes.
train_size = int(0.9 * len(modified_train_dataset))

val_size = len(modified_train_dataset) - train_size

# Re-seed immediately before the split so it is reproducible regardless of
# how much randomness was consumed above.
torch.manual_seed(42)

train_dataset, validation_dataset = random_split(modified_train_dataset, [train_size, val_size])

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

val_loader = torch.utils.data.DataLoader(dataset=validation_dataset, batch_size=BATCH_SIZE, shuffle=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Plain cross-entropy loss.  NOTE(review): the training set is deliberately
# class-imbalanced — consider nn.CrossEntropyLoss(weight=...); confirm intent.
criterion = nn.CrossEntropyLoss()

# ImageNet-pretrained ResNet-18 with the final fully-connected layer replaced
# for 10 CIFAR classes.  NOTE(review): `pretrained=True` is deprecated in
# newer torchvision in favour of `weights=ResNet18_Weights.DEFAULT`.
model = torchvision.models.resnet18(pretrained=True)

num_ftrs = model.fc.in_features

model.fc = nn.Linear(num_ftrs, 10)

model.to(DEVICE)

# Adam over all parameters at the initial learning rate.
optimizer = optim.Adam(model.parameters(), lr=modellr)
|
|
|
|
|
|
|
|
|
|
|
def adjust_learning_rate(optimizer, epoch, base_lr=None, decay_factor=0.1, decay_every=50):
    """Apply step decay: ``base_lr * decay_factor ** (epoch // decay_every)``.

    Note: the original docstring claimed a decay every 30 epochs while the
    code decayed every 50; the documentation now matches the implementation,
    and the constants are exposed as backward-compatible keyword arguments.

    Args:
        optimizer: Optimizer whose param groups are updated in place.
        epoch: Current epoch number (0-based).
        base_lr: Starting learning rate; defaults to the module-level
            ``modellr`` to preserve the original behavior.
        decay_factor: Multiplier applied at each decay boundary.
        decay_every: Number of epochs between decay steps.

    Returns:
        The learning rate that was applied to every param group.
    """
    if base_lr is None:
        base_lr = modellr  # module-level initial LR (original behavior)
    modellrnew = base_lr * (decay_factor ** (epoch // decay_every))
    print("lr:", modellrnew)
    for param_group in optimizer.param_groups:
        param_group['lr'] = modellrnew
    return modellrnew
|
|
|
|
|
|
|
|
|
|
|
# Load a previously trained checkpoint, replacing the freshly built model
# above (the earlier construction is effectively discarded).  NOTE(review):
# this pickles/unpickles the whole module object, which is fragile across
# code changes; saving/loading a state_dict is the safer pattern — confirm.
model = torch.load("666cifar_model_resnet18_lr0.0001_unbalanced_crossentropy.pth")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from sklearn.metrics import confusion_matrix |
|
|
import seaborn as sn |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
def get_predictions(model, device, data_loader):
    """Run inference over ``data_loader`` and collect predictions and labels.

    Args:
        model: Classifier returning per-class logits of shape (batch, classes).
        device: Device to run inference on.
        data_loader: Iterable of ``(data, target)`` batches.

    Returns:
        Tuple ``(predictions, targets)`` of 1-D numpy arrays of class
        indices, aligned with the loader's iteration order.
    """
    model.eval()
    model.to(device)
    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            # argmax over the class dimension; the legacy `.data` indirection
            # is unnecessary (and discouraged) inside torch.no_grad().
            predicted = outputs.argmax(dim=1)
            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

    return np.array(all_predictions), np.array(all_targets)
|
|
|
|
|
|
|
|
# Evaluate the loaded model on the held-out test split.
predictions, targets = get_predictions(model, DEVICE, test_loader)

# Rows are true labels, columns are predicted labels (sklearn convention).
conf_matrix = confusion_matrix(targets, predictions)

# Heatmap of raw counts per (true, predicted) class pair.
plt.figure(figsize=(10, 8))

sn.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=range(10), yticklabels=range(10))

plt.xlabel('Predicted Label')

plt.ylabel('True Label')

plt.title('Confusion Matrix')

plt.show()