Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. Β
See raw diff
- .gitattributes +8 -0
- Experiments/HOG+KNN.ipynb +0 -0
- Experiments/Logs/RESNET50_3HLL_CIFAR.out +0 -0
- Experiments/Logs/RESNET50_CIFAR.out +0 -0
- Experiments/PCA+HOG+KNN.ipynb +0 -0
- Experiments/Resnet50_classification.py +128 -0
- Experiments/__init__.py +1 -0
- Experiments/__pycache__/Resnet50_classification.cpython-311.pyc +0 -0
- Experiments/__pycache__/__init__.cpython-311.pyc +0 -0
- Experiments/__pycache__/ann_cifar.cpython-311.pyc +0 -0
- Experiments/ann_cifar.py +91 -0
- Experiments/centroid_res18.py +85 -0
- Experiments/cifar_10_knn.ipynb +0 -0
- Experiments/cifar_10_knn_resnet18_73_percent.ipynb +197 -0
- Experiments/cifar_10_pca_knn.ipynb +0 -0
- Experiments/cifar_10_res50_m.py +37 -0
- Experiments/cnn-resnet34-cifar10.ipynb +749 -0
- Experiments/quicknet_cifar10_centroid.ipynb +698 -0
- Experiments/quicknet_knn.py +37 -0
- Experiments/train_resnet50_3hll.py +163 -0
- Experiments/train_resnet50_iter_1.py +156 -0
- LICENSE +21 -0
- MidTerm_Report.pdf +3 -0
- Model/CIFAR.pt +3 -0
- Model/Resnet50_train_features.pt +3 -0
- Model/__pycache__/centroid_app.cpython-311.pyc +0 -0
- Model/ann.pt +3 -0
- Model/centroid_app.py +80 -0
- Model/data/batches.meta +0 -0
- Model/data/data_batch_1 +3 -0
- Model/data/data_batch_2 +3 -0
- Model/data/data_batch_3 +3 -0
- Model/data/data_batch_4 +3 -0
- Model/data/data_batch_5 +3 -0
- Model/data/mean_embeddings.pkl +3 -0
- Model/data/test_batch +3 -0
- Model/data/train_embeddings_resnet18.pkl +3 -0
- Model/pretrained_model_weights.h5 +3 -0
- Model/resnet18.h5 +3 -0
- Model/test.py +49 -0
- Preprocessing/cifar_eda.ipynb +0 -0
- Preprocessing/k_means.py +46 -0
- Preprocessing/utils.py +29 -0
- README.md +44 -12
- __init__.py +1 -0
- flagged/log.csv +2 -0
- images/akshat.jpeg +0 -0
- images/dog.jpg +0 -0
- images/harshiv.jpg +0 -0
- images/horse.jpg +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
images/image2image.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
MidTerm_Report.pdf filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
Model/data/data_batch_1 filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
Model/data/data_batch_2 filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
Model/data/data_batch_3 filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
Model/data/data_batch_4 filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
Model/data/data_batch_5 filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
Model/data/test_batch filter=lfs diff=lfs merge=lfs -text
|
Experiments/HOG+KNN.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Experiments/Logs/RESNET50_3HLL_CIFAR.out
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Experiments/Logs/RESNET50_CIFAR.out
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Experiments/PCA+HOG+KNN.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Experiments/Resnet50_classification.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torchvision.datasets as datasets
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import tensorflow
|
| 4 |
+
import numpy as np
|
| 5 |
+
import warnings
|
| 6 |
+
warnings.filterwarnings("ignore")
|
| 7 |
+
import torchvision.models as models
|
| 8 |
+
import torch
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torchvision.transforms as v2
|
| 12 |
+
from torch.utils.data import DataLoader
|
| 13 |
+
import torchvision.datasets as datasets
|
| 14 |
+
from sklearn.metrics.pairwise import euclidean_distances
|
| 15 |
+
from PIL import Image,ImageFilter
|
| 16 |
+
import torchvision.transforms as transforms
|
| 17 |
+
import pickle
|
| 18 |
+
import os
|
| 19 |
+
os.system("gdown --id 1qO2OLR7skDibo1LaMKD3CiOl_jaCTZ0h")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
IMAGE_SIZE = 224
|
| 23 |
+
mean, std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]
|
| 24 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def unpickle(file):
|
| 28 |
+
with open(file, 'rb') as fo:
|
| 29 |
+
dict = pickle.load(fo, encoding='bytes')
|
| 30 |
+
return dict
|
| 31 |
+
|
| 32 |
+
class HiddenLayer(nn.Module):
|
| 33 |
+
def __init__(self, pretrained_model):
|
| 34 |
+
super().__init__()
|
| 35 |
+
self.premodel = pretrained_model
|
| 36 |
+
self.new_layer = nn.Sequential(
|
| 37 |
+
nn.Linear(1000, 512),
|
| 38 |
+
nn.LeakyReLU(),
|
| 39 |
+
nn.Linear(512, 512),
|
| 40 |
+
nn.LeakyReLU(),
|
| 41 |
+
nn.Linear(512, 256),
|
| 42 |
+
nn.LeakyReLU(),
|
| 43 |
+
nn.Linear(256, 10)
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
def forward(self, x):
|
| 47 |
+
out = self.premodel(x)
|
| 48 |
+
out_new_layer = self.new_layer(out)
|
| 49 |
+
return out_new_layer
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def predict(features_path,image):
|
| 53 |
+
batch1 = unpickle(r"Model\data\data_batch_1")
|
| 54 |
+
batch2 = unpickle(r"Model\data\data_batch_2")
|
| 55 |
+
batch3 = unpickle(r"Model\data\data_batch_3")
|
| 56 |
+
batch4 = unpickle(r"Model\data\data_batch_4")
|
| 57 |
+
batch5 = unpickle(r"Model\data\data_batch_5")
|
| 58 |
+
test_batch = unpickle(r"Model\data\test_batch")
|
| 59 |
+
train_batch = [batch1,batch2,batch3,batch4,batch5]
|
| 60 |
+
train_y = []
|
| 61 |
+
train_x = []
|
| 62 |
+
for batch in train_batch:
|
| 63 |
+
y_data = batch[b'labels']
|
| 64 |
+
x_data = batch[b'data']
|
| 65 |
+
x_data = x_data.reshape(len(x_data),3,32,32).transpose(0,2,3,1)
|
| 66 |
+
|
| 67 |
+
for i in range(len(y_data)):
|
| 68 |
+
train_y.append(y_data[i])
|
| 69 |
+
|
| 70 |
+
for i in range(len(y_data)):
|
| 71 |
+
train_x.append(x_data[i])
|
| 72 |
+
|
| 73 |
+
features = torch.load(features_path)
|
| 74 |
+
|
| 75 |
+
resnet_train_data = []
|
| 76 |
+
for i in range(len(features)):
|
| 77 |
+
resnet_train_data.append((features[i],train_y[i]))
|
| 78 |
+
|
| 79 |
+
class_images_dict = {}
|
| 80 |
+
for batch_idx, (images, labels) in enumerate(resnet_train_data):
|
| 81 |
+
if labels not in class_images_dict:
|
| 82 |
+
class_images_dict[labels] = []
|
| 83 |
+
class_images_dict[labels].append(batch_idx)
|
| 84 |
+
|
| 85 |
+
transform = transforms.Compose([
|
| 86 |
+
transforms.Resize((224, 224)),
|
| 87 |
+
transforms.ToTensor(),
|
| 88 |
+
transforms.Normalize(mean, std)
|
| 89 |
+
])
|
| 90 |
+
pil_image = Image.fromarray(image)
|
| 91 |
+
image_tensor = transform(pil_image).unsqueeze(0)
|
| 92 |
+
resnet = models.resnet50(pretrained=True)
|
| 93 |
+
model_check = HiddenLayer(resnet)
|
| 94 |
+
model_check.load_state_dict(torch.load("CIFAR_end_hll.pt"))
|
| 95 |
+
model_check.eval()
|
| 96 |
+
|
| 97 |
+
with torch.no_grad():
|
| 98 |
+
z = model_check(image_tensor)
|
| 99 |
+
_, test_label = torch.max(z, 1)
|
| 100 |
+
|
| 101 |
+
return test_label,z,features,class_images_dict,train_x
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def retrieve(image,k,feature_path=r"Model\Resnet50_train_features.pt"):
|
| 106 |
+
print(image.shape)
|
| 107 |
+
test_label,z,features,class_images_dict,train_x = predict(feature_path,image)
|
| 108 |
+
class_indices = class_images_dict[test_label.item()]
|
| 109 |
+
class_features = [(features[idx], idx) for idx in class_indices]
|
| 110 |
+
test_features = z.cpu().detach().numpy()
|
| 111 |
+
distances = euclidean_distances(test_features, [f[0].cpu() for f in class_features])
|
| 112 |
+
sorted_indices = np.argsort(distances.flatten())[:k]
|
| 113 |
+
closest_indices = [class_features[idx][1] for idx in sorted_indices]
|
| 114 |
+
retrieved_images = []
|
| 115 |
+
for i, idx in enumerate(closest_indices):
|
| 116 |
+
closest_image = Image.fromarray(train_x[idx])
|
| 117 |
+
sharpened_closest_image = closest_image.filter(ImageFilter.SHARPEN)
|
| 118 |
+
retrieved_images.append(sharpened_closest_image)
|
| 119 |
+
|
| 120 |
+
return retrieved_images
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# test_image = Image.open("/kaggle/input/planes/download.jpeg")
|
| 128 |
+
# retrieved_images = retrieve(test_image,3)
|
Experiments/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
print("Initializing")
|
Experiments/__pycache__/Resnet50_classification.cpython-311.pyc
ADDED
|
Binary file (8.55 kB). View file
|
|
|
Experiments/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (222 Bytes). View file
|
|
|
Experiments/__pycache__/ann_cifar.cpython-311.pyc
ADDED
|
Binary file (6.22 kB). View file
|
|
|
Experiments/ann_cifar.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torchvision.datasets as datasets
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import tensorflow
|
| 4 |
+
import numpy as np
|
| 5 |
+
import warnings
|
| 6 |
+
warnings.filterwarnings("ignore")
|
| 7 |
+
import torchvision.models as models
|
| 8 |
+
import torch
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torchvision.transforms as v2
|
| 12 |
+
from torch.utils.data import DataLoader
|
| 13 |
+
import torchvision.datasets as datasets
|
| 14 |
+
from sklearn.metrics.pairwise import euclidean_distances
|
| 15 |
+
from PIL import Image,ImageFilter
|
| 16 |
+
import torchvision.transforms as transforms
|
| 17 |
+
import pickle
|
| 18 |
+
import os
|
| 19 |
+
import cv2
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def unpickle(file):
|
| 23 |
+
with open(file, 'rb') as fo:
|
| 24 |
+
dict = pickle.load(fo, encoding='bytes')
|
| 25 |
+
return dict
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class ANN(nn.Module):
|
| 29 |
+
def __init__(self):
|
| 30 |
+
super(ANN, self).__init__()
|
| 31 |
+
self.layer_1 = nn.Linear(32 * 32 * 3, 512) # Input size is 32x32x3
|
| 32 |
+
self.relu1 = nn.ReLU()
|
| 33 |
+
self.layer_2 = nn.Linear(512, 256)
|
| 34 |
+
self.relu2 = nn.ReLU()
|
| 35 |
+
self.layer_3 = nn.Linear(256, 10) # Output size is 10 (number of classes)
|
| 36 |
+
|
| 37 |
+
def forward(self, x):
|
| 38 |
+
x = x.view(-1, 32 * 32 * 3) # Flatten the input tensor
|
| 39 |
+
x=self.layer_1(x)
|
| 40 |
+
x=self.relu1(x)
|
| 41 |
+
x=self.layer_2(x)
|
| 42 |
+
x=self.relu2(x)
|
| 43 |
+
x=self.layer_3(x)
|
| 44 |
+
return x
|
| 45 |
+
|
| 46 |
+
def retrieve(test_data,model,k_value=10):
|
| 47 |
+
test_data = cv2.resize(test_data,(32,32))
|
| 48 |
+
test_data = torch.tensor(test_data,dtype=torch.float32).unsqueeze(dim=0)
|
| 49 |
+
print(test_data.shape)
|
| 50 |
+
outputs = model(test_data)
|
| 51 |
+
_, predicted = torch.max(outputs.data, 1)
|
| 52 |
+
|
| 53 |
+
batch1 = unpickle(r"Model\data\data_batch_1")
|
| 54 |
+
batch2 = unpickle(r"Model\data\data_batch_2")
|
| 55 |
+
batch3 = unpickle(r"Model\data\data_batch_3")
|
| 56 |
+
batch4 = unpickle(r"Model\data\data_batch_4")
|
| 57 |
+
batch5 = unpickle(r"Model\data\data_batch_5")
|
| 58 |
+
# test_batch = unpickle(r"Model\data\test_batch")
|
| 59 |
+
train_batch = [batch1,batch2,batch3,batch4,batch5]
|
| 60 |
+
|
| 61 |
+
train_data = []
|
| 62 |
+
for batch in train_batch:
|
| 63 |
+
y_data = batch[b'labels']
|
| 64 |
+
x_data = batch[b'data']
|
| 65 |
+
x_data = x_data.reshape(len(x_data),3,32,32).transpose(0,2,3,1)
|
| 66 |
+
|
| 67 |
+
for i in range(len(y_data)):
|
| 68 |
+
train_data.append((x_data[i],y_data[i]))
|
| 69 |
+
|
| 70 |
+
x_with_specific_y = [x for x, y in train_data if y-1 == predicted]
|
| 71 |
+
x_with_specific_y = np.array(x_with_specific_y)
|
| 72 |
+
test_data.squeeze().permute(1,2,0)
|
| 73 |
+
point = np.array(test_data.squeeze())
|
| 74 |
+
|
| 75 |
+
distance_with_label_and_index = []
|
| 76 |
+
|
| 77 |
+
for i,x_train in enumerate(x_with_specific_y):
|
| 78 |
+
train_point = np.array(x_train[1])
|
| 79 |
+
distance_with_label_and_index.append((i,np.linalg.norm(point-train_point)))
|
| 80 |
+
|
| 81 |
+
#sorting based on distance
|
| 82 |
+
distance_with_label_and_index_sorted=sorted(distance_with_label_and_index,key=lambda x: x[1])
|
| 83 |
+
k_nearest_points = distance_with_label_and_index_sorted[0:k_value]
|
| 84 |
+
|
| 85 |
+
retrived_images = []
|
| 86 |
+
#calculating accuracy
|
| 87 |
+
for i,(index,distance) in enumerate(k_nearest_points):
|
| 88 |
+
# print(np.array(train_data[index]).shape)
|
| 89 |
+
retrived_images.append(np.array(x_with_specific_y[index]))
|
| 90 |
+
|
| 91 |
+
return retrived_images
|
Experiments/centroid_res18.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
import random
|
| 3 |
+
import numpy as np
|
| 4 |
+
import cv2
|
| 5 |
+
from tensorflow.keras import models, layers
|
| 6 |
+
|
| 7 |
+
# Load the trained model architecture
|
| 8 |
+
def create_resnet18():
|
| 9 |
+
model = models.Sequential()
|
| 10 |
+
model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=(32, 32, 3)))
|
| 11 |
+
model.add(layers.MaxPooling2D((2, 2)))
|
| 12 |
+
model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))
|
| 13 |
+
model.add(layers.MaxPooling2D((2, 2)))
|
| 14 |
+
model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
|
| 15 |
+
model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
|
| 16 |
+
model.add(layers.MaxPooling2D((2, 2)))
|
| 17 |
+
model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))
|
| 18 |
+
model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))
|
| 19 |
+
model.add(layers.MaxPooling2D((2, 2)))
|
| 20 |
+
model.add(layers.Flatten())
|
| 21 |
+
model.add(layers.Dense(512, activation='relu'))
|
| 22 |
+
model.add(layers.Dense(10, activation='softmax'))
|
| 23 |
+
return model
|
| 24 |
+
|
| 25 |
+
# Load the pretrained weights
|
| 26 |
+
def load_pretrained_weights(model, weights_path):
|
| 27 |
+
model.load_weights(weights_path)
|
| 28 |
+
|
| 29 |
+
# Function to unpickle a file
|
| 30 |
+
def unpickle(file):
|
| 31 |
+
with open(file, 'rb') as fo:
|
| 32 |
+
dict = pickle.load(fo, encoding='bytes')
|
| 33 |
+
return dict
|
| 34 |
+
|
| 35 |
+
# Function to load images from the unpickled data batch file of a specific class
|
| 36 |
+
def load_class_images(class_index, train_batches):
|
| 37 |
+
images = []
|
| 38 |
+
for batch in train_batches:
|
| 39 |
+
if b'data' in batch and b'labels' in batch:
|
| 40 |
+
data = batch[b'data']
|
| 41 |
+
labels = batch[b'labels']
|
| 42 |
+
for i, label in enumerate(labels):
|
| 43 |
+
if label == class_index:
|
| 44 |
+
img = data[i].reshape(3, 32, 32).transpose(1, 2, 0) # Reshape and transpose the image
|
| 45 |
+
images.append(img)
|
| 46 |
+
return images
|
| 47 |
+
|
| 48 |
+
# Function to calculate the distance of the mean embeddings with a query image
|
| 49 |
+
def classify_query(query_image, model, mean_embeddings):
|
| 50 |
+
query_embedding = model.predict(np.expand_dims(query_image, axis=0))
|
| 51 |
+
distances = [np.linalg.norm(query_embedding.flatten() - mean_embedding) for mean_embedding in mean_embeddings]
|
| 52 |
+
predicted_class = np.argmin(distances)
|
| 53 |
+
return predicted_class
|
| 54 |
+
|
| 55 |
+
from google.colab.patches import cv2_imshow
|
| 56 |
+
|
| 57 |
+
def main():
|
| 58 |
+
model = create_resnet18()
|
| 59 |
+
load_pretrained_weights(model, 'pretrained_model_weights.h5')
|
| 60 |
+
|
| 61 |
+
mean_embeddings = pickle.load(open('mean_embeddings.pkl', 'rb'))
|
| 62 |
+
|
| 63 |
+
query_image_path = '/content/airplane_8925.png'
|
| 64 |
+
query_image = cv2.imread(query_image_path)
|
| 65 |
+
query_image = cv2.resize(query_image, (32, 32)) / 255.0 # Resize and normalize the image
|
| 66 |
+
|
| 67 |
+
predicted_class = classify_query(query_image, model, mean_embeddings)
|
| 68 |
+
print("Predicted Class:", predicted_class)
|
| 69 |
+
|
| 70 |
+
# Load random images of the predicted class
|
| 71 |
+
train_batches = [unpickle(f"/content/data_batch_{i}") for i in range(1,6)]
|
| 72 |
+
class_images = load_class_images(predicted_class+1, train_batches)
|
| 73 |
+
|
| 74 |
+
if class_images:
|
| 75 |
+
random_images = random.sample(class_images, 3) # Select 3 random images
|
| 76 |
+
for img in random_images:
|
| 77 |
+
if img is not None:
|
| 78 |
+
cv2_imshow(img)
|
| 79 |
+
else:
|
| 80 |
+
print("Random image is None.")
|
| 81 |
+
else:
|
| 82 |
+
print("No images found for the predicted class.")
|
| 83 |
+
|
| 84 |
+
if __name__ == "__main__":
|
| 85 |
+
main()
|
Experiments/cifar_10_knn.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Experiments/cifar_10_knn_resnet18_73_percent.ipynb
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"import tensorflow as tf\n",
|
| 10 |
+
"from scipy.spatial.distance import euclidean\n",
|
| 11 |
+
"from tensorflow.keras import datasets, layers, models\n",
|
| 12 |
+
"import numpy as np\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"\n",
|
| 15 |
+
"(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()\n",
|
| 16 |
+
"\n",
|
| 17 |
+
"\n",
|
| 18 |
+
"train_images = train_images / 255.0\n",
|
| 19 |
+
"test_images = test_images / 255.0\n",
|
| 20 |
+
"\n",
|
| 21 |
+
"\n",
|
| 22 |
+
"def create_resnet18():\n",
|
| 23 |
+
" model = models.Sequential()\n",
|
| 24 |
+
" model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=(32, 32, 3)))\n",
|
| 25 |
+
" model.add(layers.MaxPooling2D((2, 2)))\n",
|
| 26 |
+
" model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))\n",
|
| 27 |
+
" model.add(layers.MaxPooling2D((2, 2)))\n",
|
| 28 |
+
" model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))\n",
|
| 29 |
+
" model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))\n",
|
| 30 |
+
" model.add(layers.MaxPooling2D((2, 2)))\n",
|
| 31 |
+
" model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))\n",
|
| 32 |
+
" model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))\n",
|
| 33 |
+
" model.add(layers.MaxPooling2D((2, 2)))\n",
|
| 34 |
+
" model.add(layers.Flatten())\n",
|
| 35 |
+
" model.add(layers.Dense(512, activation='relu'))\n",
|
| 36 |
+
" model.add(layers.Dense(10, activation='softmax'))\n",
|
| 37 |
+
" return model\n",
|
| 38 |
+
"\n",
|
| 39 |
+
"\n",
|
| 40 |
+
"model = create_resnet18()\n",
|
| 41 |
+
"\n",
|
| 42 |
+
"\n",
|
| 43 |
+
"model.compile(optimizer='adam',\n",
|
| 44 |
+
" loss='sparse_categorical_crossentropy',\n",
|
| 45 |
+
" metrics=['accuracy'])\n",
|
| 46 |
+
"\n",
|
| 47 |
+
"\n",
|
| 48 |
+
"model.fit(train_images, train_labels, epochs=10, batch_size=64, validation_data=(test_images, test_labels))\n",
|
| 49 |
+
"\n",
|
| 50 |
+
"\n",
|
| 51 |
+
"train_embeddings = model.predict(train_images)\n"
|
| 52 |
+
]
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"cell_type": "code",
|
| 56 |
+
"execution_count": null,
|
| 57 |
+
"metadata": {},
|
| 58 |
+
"outputs": [],
|
| 59 |
+
"source": [
|
| 60 |
+
"test_embeddings = model.predict(test_images)\n",
|
| 61 |
+
"test_embeddings.shape"
|
| 62 |
+
]
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"cell_type": "code",
|
| 66 |
+
"execution_count": null,
|
| 67 |
+
"metadata": {},
|
| 68 |
+
"outputs": [],
|
| 69 |
+
"source": [
|
| 70 |
+
"resnet_train_data = []\n",
|
| 71 |
+
"for i in range(len(train_embeddings)):\n",
|
| 72 |
+
" resnet_train_data.append((train_embeddings[i],train_labels[i]))\n",
|
| 73 |
+
"resnet_test_data = []\n",
|
| 74 |
+
"for i in range(len(test_embeddings)):\n",
|
| 75 |
+
" resnet_test_data.append((test_embeddings[i],test_labels[i]))"
|
| 76 |
+
]
|
| 77 |
+
},
|
| 78 |
+
{
|
| 79 |
+
"cell_type": "code",
|
| 80 |
+
"execution_count": null,
|
| 81 |
+
"metadata": {},
|
| 82 |
+
"outputs": [],
|
| 83 |
+
"source": [
|
| 84 |
+
"import matplotlib.pyplot as plt\n",
|
| 85 |
+
"def find_knn_and_plot_image(k_value:int =10,train_data=resnet_train_data,test_data=resnet_test_data[0]):\n",
|
| 86 |
+
" correct = 0\n",
|
| 87 |
+
" x_test,y_test = test_data\n",
|
| 88 |
+
" point = np.array(x_test)\n",
|
| 89 |
+
"\n",
|
| 90 |
+
" distance_with_label_and_index = []\n",
|
| 91 |
+
"\n",
|
| 92 |
+
" for i,(x_train,y_train) in enumerate(train_data):\n",
|
| 93 |
+
" train_point = np.array(x_train)\n",
|
| 94 |
+
" distance_with_label_and_index.append(((y_train,i),np.linalg.norm(point-train_point)))\n",
|
| 95 |
+
"\n",
|
| 96 |
+
" #sorting based on distance\n",
|
| 97 |
+
" distance_with_label_and_index_sorted=sorted(distance_with_label_and_index,key=lambda x: x[1])\n",
|
| 98 |
+
" k_nearest_points = distance_with_label_and_index_sorted[0:k_value]\n",
|
| 99 |
+
"\n",
|
| 100 |
+
" \n",
|
| 101 |
+
" #calculating accuracy\n",
|
| 102 |
+
" fig, axes = plt.subplots(1, k_value, figsize=(15, 3)) # Adjust figsize as needed\n",
|
| 103 |
+
" for i,((label,index),distance) in enumerate(k_nearest_points):\n",
|
| 104 |
+
" if(label == y_test):\n",
|
| 105 |
+
" correct+=1\n",
|
| 106 |
+
" axes[i].imshow(train_images[index])\n",
|
| 107 |
+
" axes[i].set_title(f\"Index: {index}\")\n",
|
| 108 |
+
" axes[i].axis('off')\n",
|
| 109 |
+
" plt.show()\n",
|
| 110 |
+
" return correct/(k_value)"
|
| 111 |
+
]
|
| 112 |
+
},
|
| 113 |
+
{
|
| 114 |
+
"cell_type": "code",
|
| 115 |
+
"execution_count": null,
|
| 116 |
+
"metadata": {},
|
| 117 |
+
"outputs": [],
|
| 118 |
+
"source": [
|
| 119 |
+
"index =91\n",
|
| 120 |
+
"plt.imshow(test_images[index])\n",
|
| 121 |
+
"print(test_labels[index])\n",
|
| 122 |
+
"find_knn_and_plot_image(10,resnet_train_data,resnet_test_data[index])"
|
| 123 |
+
]
|
| 124 |
+
},
|
| 125 |
+
{
|
| 126 |
+
"cell_type": "code",
|
| 127 |
+
"execution_count": null,
|
| 128 |
+
"metadata": {},
|
| 129 |
+
"outputs": [],
|
| 130 |
+
"source": [
|
| 131 |
+
"def find_knn(k_value:int =10,train_data=resnet_train_data,test_data=resnet_test_data):\n",
|
| 132 |
+
" results =[]\n",
|
| 133 |
+
" correct = 0\n",
|
| 134 |
+
" for idx, (x_test,y_test) in enumerate(tqdm(test_data)):\n",
|
| 135 |
+
" point = np.array(x_test)\n",
|
| 136 |
+
"\n",
|
| 137 |
+
" distance_with_label_and_index = []\n",
|
| 138 |
+
"\n",
|
| 139 |
+
" for i,(x_train,y_train) in enumerate(train_data):\n",
|
| 140 |
+
" train_point = np.array(x_train)\n",
|
| 141 |
+
" distance_with_label_and_index.append(((y_train,i),np.linalg.norm(point-train_point)))\n",
|
| 142 |
+
"\n",
|
| 143 |
+
" #sorting based on distance\n",
|
| 144 |
+
" distance_with_label_and_index_sorted=sorted(distance_with_label_and_index,key=lambda x: x[1])\n",
|
| 145 |
+
" k_nearest_points = distance_with_label_and_index_sorted[0:k_value]\n",
|
| 146 |
+
"\n",
|
| 147 |
+
"\n",
|
| 148 |
+
" #calculating accuracy\n",
|
| 149 |
+
" results.append([])\n",
|
| 150 |
+
" for i,((label,index),distance) in enumerate(k_nearest_points):\n",
|
| 151 |
+
" if(label == y_test):\n",
|
| 152 |
+
" correct+=1\n",
|
| 153 |
+
" results[idx].append((label,index))\n",
|
| 154 |
+
" return results,correct/(k_value*len(test_data))"
|
| 155 |
+
]
|
| 156 |
+
},
|
| 157 |
+
{
|
| 158 |
+
"cell_type": "code",
|
| 159 |
+
"execution_count": null,
|
| 160 |
+
"metadata": {},
|
| 161 |
+
"outputs": [],
|
| 162 |
+
"source": [
|
| 163 |
+
"results,accuracy = find_knn(10)"
|
| 164 |
+
]
|
| 165 |
+
},
|
| 166 |
+
{
|
| 167 |
+
"cell_type": "code",
|
| 168 |
+
"execution_count": null,
|
| 169 |
+
"metadata": {},
|
| 170 |
+
"outputs": [],
|
| 171 |
+
"source": [
|
| 172 |
+
"accuracy"
|
| 173 |
+
]
|
| 174 |
+
}
|
| 175 |
+
],
|
| 176 |
+
"metadata": {
|
| 177 |
+
"kernelspec": {
|
| 178 |
+
"display_name": "Python 3",
|
| 179 |
+
"language": "python",
|
| 180 |
+
"name": "python3"
|
| 181 |
+
},
|
| 182 |
+
"language_info": {
|
| 183 |
+
"codemirror_mode": {
|
| 184 |
+
"name": "ipython",
|
| 185 |
+
"version": 3
|
| 186 |
+
},
|
| 187 |
+
"file_extension": ".py",
|
| 188 |
+
"mimetype": "text/x-python",
|
| 189 |
+
"name": "python",
|
| 190 |
+
"nbconvert_exporter": "python",
|
| 191 |
+
"pygments_lexer": "ipython3",
|
| 192 |
+
"version": "3.11.4"
|
| 193 |
+
}
|
| 194 |
+
},
|
| 195 |
+
"nbformat": 4,
|
| 196 |
+
"nbformat_minor": 2
|
| 197 |
+
}
|
Experiments/cifar_10_pca_knn.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Experiments/cifar_10_res50_m.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from tensorflow.keras import datasets, layers, models
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
# Load CIFAR-10 dataset
|
| 6 |
+
(train_images, train_labels), (_, _) = datasets.cifar10.load_data()
|
| 7 |
+
|
| 8 |
+
# Normalize pixel values to be between 0 and 1
|
| 9 |
+
train_images = train_images / 255.0
|
| 10 |
+
|
| 11 |
+
# Load a pre-trained model (e.g., ResNet50)
|
| 12 |
+
base_model = tf.keras.applications.ResNet50(weights='imagenet', include_top=False, input_shape=(32, 32, 3))
|
| 13 |
+
|
| 14 |
+
# Extract embeddings
|
| 15 |
+
embeddings_model = models.Sequential([
|
| 16 |
+
base_model,
|
| 17 |
+
layers.GlobalAveragePooling2D(), # Using GlobalAveragePooling2D to reduce spatial dimensions
|
| 18 |
+
layers.Dense(512, activation='relu') # You can adjust the size of the dense layer according to your need
|
| 19 |
+
])
|
| 20 |
+
|
| 21 |
+
# Compute embeddings
|
| 22 |
+
embeddings = embeddings_model.predict(train_images)
|
| 23 |
+
|
| 24 |
+
# Initialize array to store mean embeddings for each class
|
| 25 |
+
num_classes = 10
|
| 26 |
+
mean_embeddings = np.zeros((num_classes, embeddings.shape[1]))
|
| 27 |
+
|
| 28 |
+
# Calculate mean embeddings for each class
|
| 29 |
+
for class_label in range(num_classes):
|
| 30 |
+
class_indices = np.where(train_labels == class_label)[0]
|
| 31 |
+
class_embeddings = embeddings[class_indices]
|
| 32 |
+
mean_embedding = np.mean(class_embeddings, axis=0)
|
| 33 |
+
mean_embeddings[class_label] = mean_embedding
|
| 34 |
+
|
| 35 |
+
print("Mean embeddings for each class:")
|
| 36 |
+
print(mean_embeddings)
|
| 37 |
+
mean_embeddings.shape
|
Experiments/cnn-resnet34-cifar10.ipynb
ADDED
|
@@ -0,0 +1,749 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {
|
| 6 |
+
"colab_type": "text",
|
| 7 |
+
"id": "MkoGLH_Tj5wn"
|
| 8 |
+
},
|
| 9 |
+
"source": [
|
| 10 |
+
"## Imports"
|
| 11 |
+
]
|
| 12 |
+
},
|
| 13 |
+
{
|
| 14 |
+
"cell_type": "code",
|
| 15 |
+
"execution_count": 2,
|
| 16 |
+
"metadata": {
|
| 17 |
+
"colab": {
|
| 18 |
+
"autoexec": {
|
| 19 |
+
"startup": false,
|
| 20 |
+
"wait_interval": 0
|
| 21 |
+
}
|
| 22 |
+
},
|
| 23 |
+
"colab_type": "code",
|
| 24 |
+
"id": "ORj09gnrj5wp"
|
| 25 |
+
},
|
| 26 |
+
"outputs": [],
|
| 27 |
+
"source": [
|
| 28 |
+
"import os\n",
|
| 29 |
+
"import time\n",
|
| 30 |
+
"\n",
|
| 31 |
+
"import numpy as np\n",
|
| 32 |
+
"import pandas as pd\n",
|
| 33 |
+
"\n",
|
| 34 |
+
"import torch\n",
|
| 35 |
+
"import torch.nn as nn\n",
|
| 36 |
+
"import torch.nn.functional as F\n",
|
| 37 |
+
"from torch.utils.data import DataLoader\n",
|
| 38 |
+
"\n",
|
| 39 |
+
"from torchvision import datasets\n",
|
| 40 |
+
"from torchvision import transforms\n",
|
| 41 |
+
"\n",
|
| 42 |
+
"import matplotlib.pyplot as plt\n",
|
| 43 |
+
"from PIL import Image\n",
|
| 44 |
+
"\n",
|
| 45 |
+
"\n",
|
| 46 |
+
"if torch.cuda.is_available():\n",
|
| 47 |
+
" torch.backends.cudnn.deterministic = True"
|
| 48 |
+
]
|
| 49 |
+
},
|
| 50 |
+
{
|
| 51 |
+
"cell_type": "markdown",
|
| 52 |
+
"metadata": {
|
| 53 |
+
"colab_type": "text",
|
| 54 |
+
"id": "I6hghKPxj5w0"
|
| 55 |
+
},
|
| 56 |
+
"source": [
|
| 57 |
+
"## Model Settings"
|
| 58 |
+
]
|
| 59 |
+
},
|
| 60 |
+
{
|
| 61 |
+
"cell_type": "code",
|
| 62 |
+
"execution_count": 3,
|
| 63 |
+
"metadata": {
|
| 64 |
+
"colab": {
|
| 65 |
+
"autoexec": {
|
| 66 |
+
"startup": false,
|
| 67 |
+
"wait_interval": 0
|
| 68 |
+
},
|
| 69 |
+
"base_uri": "https://localhost:8080/",
|
| 70 |
+
"height": 85
|
| 71 |
+
},
|
| 72 |
+
"colab_type": "code",
|
| 73 |
+
"executionInfo": {
|
| 74 |
+
"elapsed": 23936,
|
| 75 |
+
"status": "ok",
|
| 76 |
+
"timestamp": 1524974497505,
|
| 77 |
+
"user": {
|
| 78 |
+
"displayName": "Sebastian Raschka",
|
| 79 |
+
"photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
|
| 80 |
+
"userId": "118404394130788869227"
|
| 81 |
+
},
|
| 82 |
+
"user_tz": 240
|
| 83 |
+
},
|
| 84 |
+
"id": "NnT0sZIwj5wu",
|
| 85 |
+
"outputId": "55aed925-d17e-4c6a-8c71-0d9b3bde5637"
|
| 86 |
+
},
|
| 87 |
+
"outputs": [],
|
| 88 |
+
"source": [
|
| 89 |
+
"##########################\n",
|
| 90 |
+
"### SETTINGS\n",
|
| 91 |
+
"##########################\n",
|
| 92 |
+
"\n",
|
| 93 |
+
"# Hyperparameters\n",
|
| 94 |
+
"RANDOM_SEED = 1\n",
|
| 95 |
+
"LEARNING_RATE = 0.001\n",
|
| 96 |
+
"BATCH_SIZE = 256\n",
|
| 97 |
+
"NUM_EPOCHS = 10\n",
|
| 98 |
+
"\n",
|
| 99 |
+
"# Architecture\n",
|
| 100 |
+
"NUM_FEATURES = 28*28\n",
|
| 101 |
+
"NUM_CLASSES = 10\n",
|
| 102 |
+
"\n",
|
| 103 |
+
"# Other\n",
|
| 104 |
+
"DEVICE = \"cuda:1\"\n",
|
| 105 |
+
"GRAYSCALE = False"
|
| 106 |
+
]
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"cell_type": "code",
|
| 110 |
+
"execution_count": 4,
|
| 111 |
+
"metadata": {},
|
| 112 |
+
"outputs": [],
|
| 113 |
+
"source": [
|
| 114 |
+
"##########################\n",
|
| 115 |
+
"### MODEL\n",
|
| 116 |
+
"##########################\n",
|
| 117 |
+
"\n",
|
| 118 |
+
"\n",
|
| 119 |
+
"def conv3x3(in_planes, out_planes, stride=1):\n",
|
| 120 |
+
" \"\"\"3x3 convolution with padding\"\"\"\n",
|
| 121 |
+
" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n",
|
| 122 |
+
" padding=1, bias=False)\n",
|
| 123 |
+
"\n",
|
| 124 |
+
"\n",
|
| 125 |
+
"class BasicBlock(nn.Module):\n",
|
| 126 |
+
" expansion = 1\n",
|
| 127 |
+
"\n",
|
| 128 |
+
" def __init__(self, inplanes, planes, stride=1, downsample=None):\n",
|
| 129 |
+
" super(BasicBlock, self).__init__()\n",
|
| 130 |
+
" self.conv1 = conv3x3(inplanes, planes, stride)\n",
|
| 131 |
+
" self.bn1 = nn.BatchNorm2d(planes)\n",
|
| 132 |
+
" self.relu = nn.ReLU(inplace=True)\n",
|
| 133 |
+
" self.conv2 = conv3x3(planes, planes)\n",
|
| 134 |
+
" self.bn2 = nn.BatchNorm2d(planes)\n",
|
| 135 |
+
" self.downsample = downsample\n",
|
| 136 |
+
" self.stride = stride\n",
|
| 137 |
+
"\n",
|
| 138 |
+
" def forward(self, x):\n",
|
| 139 |
+
" residual = x\n",
|
| 140 |
+
"\n",
|
| 141 |
+
" out = self.conv1(x)\n",
|
| 142 |
+
" out = self.bn1(out)\n",
|
| 143 |
+
" out = self.relu(out)\n",
|
| 144 |
+
"\n",
|
| 145 |
+
" out = self.conv2(out)\n",
|
| 146 |
+
" out = self.bn2(out)\n",
|
| 147 |
+
"\n",
|
| 148 |
+
" if self.downsample is not None:\n",
|
| 149 |
+
" residual = self.downsample(x)\n",
|
| 150 |
+
"\n",
|
| 151 |
+
" out += residual\n",
|
| 152 |
+
" out = self.relu(out)\n",
|
| 153 |
+
"\n",
|
| 154 |
+
" return out\n",
|
| 155 |
+
"\n",
|
| 156 |
+
"\n",
|
| 157 |
+
"\n",
|
| 158 |
+
"\n",
|
| 159 |
+
"class ResNet(nn.Module):\n",
|
| 160 |
+
"\n",
|
| 161 |
+
" def __init__(self, block, layers, num_classes, grayscale):\n",
|
| 162 |
+
" self.inplanes = 64\n",
|
| 163 |
+
" if grayscale:\n",
|
| 164 |
+
" in_dim = 1\n",
|
| 165 |
+
" else:\n",
|
| 166 |
+
" in_dim = 3\n",
|
| 167 |
+
" super(ResNet, self).__init__()\n",
|
| 168 |
+
" self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3,\n",
|
| 169 |
+
" bias=False)\n",
|
| 170 |
+
" self.bn1 = nn.BatchNorm2d(64)\n",
|
| 171 |
+
" self.relu = nn.ReLU(inplace=True)\n",
|
| 172 |
+
" self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
|
| 173 |
+
" self.layer1 = self._make_layer(block, 64, layers[0])\n",
|
| 174 |
+
" self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n",
|
| 175 |
+
" self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n",
|
| 176 |
+
" self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n",
|
| 177 |
+
" self.avgpool = nn.AvgPool2d(7, stride=1)\n",
|
| 178 |
+
" self.fc = nn.Linear(512 * block.expansion, num_classes)\n",
|
| 179 |
+
"\n",
|
| 180 |
+
" for m in self.modules():\n",
|
| 181 |
+
" if isinstance(m, nn.Conv2d):\n",
|
| 182 |
+
" n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n",
|
| 183 |
+
" m.weight.data.normal_(0, (2. / n)**.5)\n",
|
| 184 |
+
" elif isinstance(m, nn.BatchNorm2d):\n",
|
| 185 |
+
" m.weight.data.fill_(1)\n",
|
| 186 |
+
" m.bias.data.zero_()\n",
|
| 187 |
+
"\n",
|
| 188 |
+
" def _make_layer(self, block, planes, blocks, stride=1):\n",
|
| 189 |
+
" downsample = None\n",
|
| 190 |
+
" if stride != 1 or self.inplanes != planes * block.expansion:\n",
|
| 191 |
+
" downsample = nn.Sequential(\n",
|
| 192 |
+
" nn.Conv2d(self.inplanes, planes * block.expansion,\n",
|
| 193 |
+
" kernel_size=1, stride=stride, bias=False),\n",
|
| 194 |
+
" nn.BatchNorm2d(planes * block.expansion),\n",
|
| 195 |
+
" )\n",
|
| 196 |
+
"\n",
|
| 197 |
+
" layers = []\n",
|
| 198 |
+
" layers.append(block(self.inplanes, planes, stride, downsample))\n",
|
| 199 |
+
" self.inplanes = planes * block.expansion\n",
|
| 200 |
+
" for i in range(1, blocks):\n",
|
| 201 |
+
" layers.append(block(self.inplanes, planes))\n",
|
| 202 |
+
"\n",
|
| 203 |
+
" return nn.Sequential(*layers)\n",
|
| 204 |
+
"\n",
|
| 205 |
+
" def forward(self, x):\n",
|
| 206 |
+
" x = self.conv1(x)\n",
|
| 207 |
+
" x = self.bn1(x)\n",
|
| 208 |
+
" x = self.relu(x)\n",
|
| 209 |
+
" x = self.maxpool(x)\n",
|
| 210 |
+
"\n",
|
| 211 |
+
" x = self.layer1(x)\n",
|
| 212 |
+
" x = self.layer2(x)\n",
|
| 213 |
+
" x = self.layer3(x)\n",
|
| 214 |
+
" x = self.layer4(x)\n",
|
| 215 |
+
" \n",
|
| 216 |
+
" x = x.view(x.size(0), -1)\n",
|
| 217 |
+
" logits = self.fc(x)\n",
|
| 218 |
+
" probas = F.softmax(logits, dim=1)\n",
|
| 219 |
+
" return logits, probas\n",
|
| 220 |
+
"\n",
|
| 221 |
+
"\n",
|
| 222 |
+
"\n",
|
| 223 |
+
"def resnet34(num_classes):\n",
|
| 224 |
+
" \"\"\"Constructs a ResNet-34 model.\"\"\"\n",
|
| 225 |
+
" model = ResNet(block=BasicBlock, \n",
|
| 226 |
+
" layers=[3, 4, 6, 3],\n",
|
| 227 |
+
" num_classes=NUM_CLASSES,\n",
|
| 228 |
+
" grayscale=GRAYSCALE)\n",
|
| 229 |
+
" return model\n"
|
| 230 |
+
]
|
| 231 |
+
},
|
| 232 |
+
{
|
| 233 |
+
"cell_type": "markdown",
|
| 234 |
+
"metadata": {
|
| 235 |
+
"colab_type": "text",
|
| 236 |
+
"id": "RAodboScj5w6"
|
| 237 |
+
},
|
| 238 |
+
"source": [
|
| 239 |
+
"## Training without Pinned Memory"
|
| 240 |
+
]
|
| 241 |
+
},
|
| 242 |
+
{
|
| 243 |
+
"cell_type": "code",
|
| 244 |
+
"execution_count": 5,
|
| 245 |
+
"metadata": {},
|
| 246 |
+
"outputs": [
|
| 247 |
+
{
|
| 248 |
+
"name": "stdout",
|
| 249 |
+
"output_type": "stream",
|
| 250 |
+
"text": [
|
| 251 |
+
"Files already downloaded and verified\n",
|
| 252 |
+
"Image batch dimensions: torch.Size([256, 3, 32, 32])\n",
|
| 253 |
+
"Image label dimensions: torch.Size([256])\n",
|
| 254 |
+
"Image batch dimensions: torch.Size([256, 3, 32, 32])\n",
|
| 255 |
+
"Image label dimensions: torch.Size([256])\n"
|
| 256 |
+
]
|
| 257 |
+
}
|
| 258 |
+
],
|
| 259 |
+
"source": [
|
| 260 |
+
"##########################\n",
|
| 261 |
+
"### CIFAR-10 Dataset\n",
|
| 262 |
+
"##########################\n",
|
| 263 |
+
"\n",
|
| 264 |
+
"\n",
|
| 265 |
+
"# Note transforms.ToTensor() scales input images\n",
|
| 266 |
+
"# to 0-1 range\n",
|
| 267 |
+
"train_dataset = datasets.CIFAR10(root='data', \n",
|
| 268 |
+
" train=True, \n",
|
| 269 |
+
" transform=transforms.ToTensor(),\n",
|
| 270 |
+
" download=True)\n",
|
| 271 |
+
"\n",
|
| 272 |
+
"test_dataset = datasets.CIFAR10(root='data', \n",
|
| 273 |
+
" train=False, \n",
|
| 274 |
+
" transform=transforms.ToTensor())\n",
|
| 275 |
+
"\n",
|
| 276 |
+
"\n",
|
| 277 |
+
"train_loader = DataLoader(dataset=train_dataset, \n",
|
| 278 |
+
" batch_size=BATCH_SIZE, \n",
|
| 279 |
+
" num_workers=8,\n",
|
| 280 |
+
" shuffle=True)\n",
|
| 281 |
+
"\n",
|
| 282 |
+
"test_loader = DataLoader(dataset=test_dataset, \n",
|
| 283 |
+
" batch_size=BATCH_SIZE,\n",
|
| 284 |
+
" num_workers=8,\n",
|
| 285 |
+
" shuffle=False)\n",
|
| 286 |
+
"\n",
|
| 287 |
+
"# Checking the dataset\n",
|
| 288 |
+
"for images, labels in train_loader: \n",
|
| 289 |
+
" print('Image batch dimensions:', images.shape)\n",
|
| 290 |
+
" print('Image label dimensions:', labels.shape)\n",
|
| 291 |
+
" break\n",
|
| 292 |
+
"\n",
|
| 293 |
+
"# Checking the dataset\n",
|
| 294 |
+
"for images, labels in train_loader: \n",
|
| 295 |
+
" print('Image batch dimensions:', images.shape)\n",
|
| 296 |
+
" print('Image label dimensions:', labels.shape)\n",
|
| 297 |
+
" break"
|
| 298 |
+
]
|
| 299 |
+
},
|
| 300 |
+
{
|
| 301 |
+
"cell_type": "code",
|
| 302 |
+
"execution_count": 6,
|
| 303 |
+
"metadata": {
|
| 304 |
+
"colab": {
|
| 305 |
+
"autoexec": {
|
| 306 |
+
"startup": false,
|
| 307 |
+
"wait_interval": 0
|
| 308 |
+
}
|
| 309 |
+
},
|
| 310 |
+
"colab_type": "code",
|
| 311 |
+
"id": "_lza9t_uj5w1"
|
| 312 |
+
},
|
| 313 |
+
"outputs": [],
|
| 314 |
+
"source": [
|
| 315 |
+
"torch.manual_seed(RANDOM_SEED)\n",
|
| 316 |
+
"\n",
|
| 317 |
+
"model = resnet34(NUM_CLASSES)\n",
|
| 318 |
+
"model.to(DEVICE)\n",
|
| 319 |
+
"\n",
|
| 320 |
+
"optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) "
|
| 321 |
+
]
|
| 322 |
+
},
|
| 323 |
+
{
|
| 324 |
+
"cell_type": "code",
|
| 325 |
+
"execution_count": 7,
|
| 326 |
+
"metadata": {
|
| 327 |
+
"colab": {
|
| 328 |
+
"autoexec": {
|
| 329 |
+
"startup": false,
|
| 330 |
+
"wait_interval": 0
|
| 331 |
+
},
|
| 332 |
+
"base_uri": "https://localhost:8080/",
|
| 333 |
+
"height": 1547
|
| 334 |
+
},
|
| 335 |
+
"colab_type": "code",
|
| 336 |
+
"executionInfo": {
|
| 337 |
+
"elapsed": 2384585,
|
| 338 |
+
"status": "ok",
|
| 339 |
+
"timestamp": 1524976888520,
|
| 340 |
+
"user": {
|
| 341 |
+
"displayName": "Sebastian Raschka",
|
| 342 |
+
"photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
|
| 343 |
+
"userId": "118404394130788869227"
|
| 344 |
+
},
|
| 345 |
+
"user_tz": 240
|
| 346 |
+
},
|
| 347 |
+
"id": "Dzh3ROmRj5w7",
|
| 348 |
+
"outputId": "5f8fd8c9-b076-403a-b0b7-fd2d498b48d7",
|
| 349 |
+
"scrolled": false
|
| 350 |
+
},
|
| 351 |
+
"outputs": [
|
| 352 |
+
{
|
| 353 |
+
"name": "stdout",
|
| 354 |
+
"output_type": "stream",
|
| 355 |
+
"text": [
|
| 356 |
+
"Epoch: 001/010 | Batch 0000/0196 | Cost: 2.6021\n",
|
| 357 |
+
"Epoch: 001/010 | Batch 0150/0196 | Cost: 1.3961\n",
|
| 358 |
+
"Epoch: 001/010 | Train: 45.084%\n",
|
| 359 |
+
"Time elapsed: 0.26 min\n",
|
| 360 |
+
"Epoch: 002/010 | Batch 0000/0196 | Cost: 1.1228\n",
|
| 361 |
+
"Epoch: 002/010 | Batch 0150/0196 | Cost: 1.0426\n",
|
| 362 |
+
"Epoch: 002/010 | Train: 56.166%\n",
|
| 363 |
+
"Time elapsed: 0.52 min\n",
|
| 364 |
+
"Epoch: 003/010 | Batch 0000/0196 | Cost: 0.9980\n",
|
| 365 |
+
"Epoch: 003/010 | Batch 0150/0196 | Cost: 0.8279\n",
|
| 366 |
+
"Epoch: 003/010 | Train: 66.702%\n",
|
| 367 |
+
"Time elapsed: 0.80 min\n",
|
| 368 |
+
"Epoch: 004/010 | Batch 0000/0196 | Cost: 0.6384\n",
|
| 369 |
+
"Epoch: 004/010 | Batch 0150/0196 | Cost: 0.7103\n",
|
| 370 |
+
"Epoch: 004/010 | Train: 65.330%\n",
|
| 371 |
+
"Time elapsed: 1.08 min\n",
|
| 372 |
+
"Epoch: 005/010 | Batch 0000/0196 | Cost: 0.6308\n",
|
| 373 |
+
"Epoch: 005/010 | Batch 0150/0196 | Cost: 0.5913\n",
|
| 374 |
+
"Epoch: 005/010 | Train: 79.636%\n",
|
| 375 |
+
"Time elapsed: 1.36 min\n",
|
| 376 |
+
"Epoch: 006/010 | Batch 0000/0196 | Cost: 0.4409\n",
|
| 377 |
+
"Epoch: 006/010 | Batch 0150/0196 | Cost: 0.5557\n",
|
| 378 |
+
"Epoch: 006/010 | Train: 76.456%\n",
|
| 379 |
+
"Time elapsed: 1.62 min\n",
|
| 380 |
+
"Epoch: 007/010 | Batch 0000/0196 | Cost: 0.4778\n",
|
| 381 |
+
"Epoch: 007/010 | Batch 0150/0196 | Cost: 0.4815\n",
|
| 382 |
+
"Epoch: 007/010 | Train: 65.890%\n",
|
| 383 |
+
"Time elapsed: 1.89 min\n",
|
| 384 |
+
"Epoch: 008/010 | Batch 0000/0196 | Cost: 0.3782\n",
|
| 385 |
+
"Epoch: 008/010 | Batch 0150/0196 | Cost: 0.4339\n",
|
| 386 |
+
"Epoch: 008/010 | Train: 85.200%\n",
|
| 387 |
+
"Time elapsed: 2.16 min\n",
|
| 388 |
+
"Epoch: 009/010 | Batch 0000/0196 | Cost: 0.3083\n",
|
| 389 |
+
"Epoch: 009/010 | Batch 0150/0196 | Cost: 0.3290\n",
|
| 390 |
+
"Epoch: 009/010 | Train: 78.108%\n",
|
| 391 |
+
"Time elapsed: 2.42 min\n",
|
| 392 |
+
"Epoch: 010/010 | Batch 0000/0196 | Cost: 0.2229\n",
|
| 393 |
+
"Epoch: 010/010 | Batch 0150/0196 | Cost: 0.1945\n",
|
| 394 |
+
"Epoch: 010/010 | Train: 87.384%\n",
|
| 395 |
+
"Time elapsed: 2.70 min\n",
|
| 396 |
+
"Total Training Time: 2.70 min\n",
|
| 397 |
+
"Test accuracy: 70.67%\n",
|
| 398 |
+
"Total Time: 2.71 min\n"
|
| 399 |
+
]
|
| 400 |
+
}
|
| 401 |
+
],
|
| 402 |
+
"source": [
|
| 403 |
+
"def compute_accuracy(model, data_loader, device):\n",
|
| 404 |
+
" correct_pred, num_examples = 0, 0\n",
|
| 405 |
+
" for i, (features, targets) in enumerate(data_loader):\n",
|
| 406 |
+
" \n",
|
| 407 |
+
" features = features.to(device)\n",
|
| 408 |
+
" targets = targets.to(device)\n",
|
| 409 |
+
"\n",
|
| 410 |
+
" logits, probas = model(features)\n",
|
| 411 |
+
" _, predicted_labels = torch.max(probas, 1)\n",
|
| 412 |
+
" num_examples += targets.size(0)\n",
|
| 413 |
+
" correct_pred += (predicted_labels == targets).sum()\n",
|
| 414 |
+
" return correct_pred.float()/num_examples * 100\n",
|
| 415 |
+
" \n",
|
| 416 |
+
"\n",
|
| 417 |
+
"start_time = time.time()\n",
|
| 418 |
+
"for epoch in range(NUM_EPOCHS):\n",
|
| 419 |
+
" \n",
|
| 420 |
+
" model.train()\n",
|
| 421 |
+
" for batch_idx, (features, targets) in enumerate(train_loader):\n",
|
| 422 |
+
" \n",
|
| 423 |
+
" features = features.to(DEVICE)\n",
|
| 424 |
+
" targets = targets.to(DEVICE)\n",
|
| 425 |
+
" \n",
|
| 426 |
+
" ### FORWARD AND BACK PROP\n",
|
| 427 |
+
" logits, probas = model(features)\n",
|
| 428 |
+
" cost = F.cross_entropy(logits, targets)\n",
|
| 429 |
+
" optimizer.zero_grad()\n",
|
| 430 |
+
" \n",
|
| 431 |
+
" cost.backward()\n",
|
| 432 |
+
" \n",
|
| 433 |
+
" ### UPDATE MODEL PARAMETERS\n",
|
| 434 |
+
" optimizer.step()\n",
|
| 435 |
+
" \n",
|
| 436 |
+
" ### LOGGING\n",
|
| 437 |
+
" if not batch_idx % 150:\n",
|
| 438 |
+
" print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f' \n",
|
| 439 |
+
" %(epoch+1, NUM_EPOCHS, batch_idx, \n",
|
| 440 |
+
" len(train_loader), cost))\n",
|
| 441 |
+
"\n",
|
| 442 |
+
" \n",
|
| 443 |
+
"\n",
|
| 444 |
+
" model.eval()\n",
|
| 445 |
+
" with torch.set_grad_enabled(False): # save memory during inference\n",
|
| 446 |
+
" print('Epoch: %03d/%03d | Train: %.3f%%' % (\n",
|
| 447 |
+
" epoch+1, NUM_EPOCHS, \n",
|
| 448 |
+
" compute_accuracy(model, train_loader, device=DEVICE)))\n",
|
| 449 |
+
" \n",
|
| 450 |
+
" print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
|
| 451 |
+
" \n",
|
| 452 |
+
"print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))\n",
|
| 453 |
+
"\n",
|
| 454 |
+
"\n",
|
| 455 |
+
"with torch.set_grad_enabled(False): # save memory during inference\n",
|
| 456 |
+
" print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader, device=DEVICE)))\n",
|
| 457 |
+
" \n",
|
| 458 |
+
"print('Total Time: %.2f min' % ((time.time() - start_time)/60))"
|
| 459 |
+
]
|
| 460 |
+
},
|
| 461 |
+
{
|
| 462 |
+
"cell_type": "markdown",
|
| 463 |
+
"metadata": {},
|
| 464 |
+
"source": [
|
| 465 |
+
"## Training with Pinned Memory"
|
| 466 |
+
]
|
| 467 |
+
},
|
| 468 |
+
{
|
| 469 |
+
"cell_type": "code",
|
| 470 |
+
"execution_count": 8,
|
| 471 |
+
"metadata": {},
|
| 472 |
+
"outputs": [
|
| 473 |
+
{
|
| 474 |
+
"name": "stdout",
|
| 475 |
+
"output_type": "stream",
|
| 476 |
+
"text": [
|
| 477 |
+
"Files already downloaded and verified\n",
|
| 478 |
+
"Image batch dimensions: torch.Size([256, 3, 32, 32])\n",
|
| 479 |
+
"Image label dimensions: torch.Size([256])\n",
|
| 480 |
+
"Image batch dimensions: torch.Size([256, 3, 32, 32])\n",
|
| 481 |
+
"Image label dimensions: torch.Size([256])\n"
|
| 482 |
+
]
|
| 483 |
+
}
|
| 484 |
+
],
|
| 485 |
+
"source": [
|
| 486 |
+
"##########################\n",
|
| 487 |
+
"### CIFAR-10 Dataset\n",
|
| 488 |
+
"##########################\n",
|
| 489 |
+
"\n",
|
| 490 |
+
"\n",
|
| 491 |
+
"# Note transforms.ToTensor() scales input images\n",
|
| 492 |
+
"# to 0-1 range\n",
|
| 493 |
+
"train_dataset = datasets.CIFAR10(root='data', \n",
|
| 494 |
+
" train=True, \n",
|
| 495 |
+
" transform=transforms.ToTensor(),\n",
|
| 496 |
+
" download=True)\n",
|
| 497 |
+
"\n",
|
| 498 |
+
"test_dataset = datasets.CIFAR10(root='data', \n",
|
| 499 |
+
" train=False, \n",
|
| 500 |
+
" transform=transforms.ToTensor())\n",
|
| 501 |
+
"\n",
|
| 502 |
+
"\n",
|
| 503 |
+
"train_loader = DataLoader(dataset=train_dataset, \n",
|
| 504 |
+
" batch_size=BATCH_SIZE, \n",
|
| 505 |
+
" pin_memory=True,\n",
|
| 506 |
+
" shuffle=True)\n",
|
| 507 |
+
"\n",
|
| 508 |
+
"test_loader = DataLoader(dataset=test_dataset, \n",
|
| 509 |
+
" batch_size=BATCH_SIZE,\n",
|
| 510 |
+
" pin_memory=True,\n",
|
| 511 |
+
" shuffle=False)\n",
|
| 512 |
+
"\n",
|
| 513 |
+
"# Checking the dataset\n",
|
| 514 |
+
"for images, labels in train_loader: \n",
|
| 515 |
+
" print('Image batch dimensions:', images.shape)\n",
|
| 516 |
+
" print('Image label dimensions:', labels.shape)\n",
|
| 517 |
+
" break\n",
|
| 518 |
+
"\n",
|
| 519 |
+
"# Checking the dataset\n",
|
| 520 |
+
"for images, labels in train_loader: \n",
|
| 521 |
+
" print('Image batch dimensions:', images.shape)\n",
|
| 522 |
+
" print('Image label dimensions:', labels.shape)\n",
|
| 523 |
+
" break"
|
| 524 |
+
]
|
| 525 |
+
},
|
| 526 |
+
{
|
| 527 |
+
"cell_type": "code",
|
| 528 |
+
"execution_count": 9,
|
| 529 |
+
"metadata": {
|
| 530 |
+
"colab": {
|
| 531 |
+
"autoexec": {
|
| 532 |
+
"startup": false,
|
| 533 |
+
"wait_interval": 0
|
| 534 |
+
}
|
| 535 |
+
},
|
| 536 |
+
"colab_type": "code",
|
| 537 |
+
"id": "_lza9t_uj5w1"
|
| 538 |
+
},
|
| 539 |
+
"outputs": [],
|
| 540 |
+
"source": [
|
| 541 |
+
"torch.manual_seed(RANDOM_SEED)\n",
|
| 542 |
+
"\n",
|
| 543 |
+
"model = resnet34(NUM_CLASSES)\n",
|
| 544 |
+
"model.to(DEVICE)\n",
|
| 545 |
+
"\n",
|
| 546 |
+
"optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) "
|
| 547 |
+
]
|
| 548 |
+
},
|
| 549 |
+
{
|
| 550 |
+
"cell_type": "code",
|
| 551 |
+
"execution_count": 10,
|
| 552 |
+
"metadata": {
|
| 553 |
+
"colab": {
|
| 554 |
+
"autoexec": {
|
| 555 |
+
"startup": false,
|
| 556 |
+
"wait_interval": 0
|
| 557 |
+
},
|
| 558 |
+
"base_uri": "https://localhost:8080/",
|
| 559 |
+
"height": 1547
|
| 560 |
+
},
|
| 561 |
+
"colab_type": "code",
|
| 562 |
+
"executionInfo": {
|
| 563 |
+
"elapsed": 2384585,
|
| 564 |
+
"status": "ok",
|
| 565 |
+
"timestamp": 1524976888520,
|
| 566 |
+
"user": {
|
| 567 |
+
"displayName": "Sebastian Raschka",
|
| 568 |
+
"photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
|
| 569 |
+
"userId": "118404394130788869227"
|
| 570 |
+
},
|
| 571 |
+
"user_tz": 240
|
| 572 |
+
},
|
| 573 |
+
"id": "Dzh3ROmRj5w7",
|
| 574 |
+
"outputId": "5f8fd8c9-b076-403a-b0b7-fd2d498b48d7",
|
| 575 |
+
"scrolled": false
|
| 576 |
+
},
|
| 577 |
+
"outputs": [
|
| 578 |
+
{
|
| 579 |
+
"name": "stdout",
|
| 580 |
+
"output_type": "stream",
|
| 581 |
+
"text": [
|
| 582 |
+
"Epoch: 001/010 | Batch 0000/0196 | Cost: 2.6021\n",
|
| 583 |
+
"Epoch: 001/010 | Batch 0150/0196 | Cost: 1.3961\n",
|
| 584 |
+
"Epoch: 001/010 | Train: 45.084%\n",
|
| 585 |
+
"Time elapsed: 0.39 min\n",
|
| 586 |
+
"Epoch: 002/010 | Batch 0000/0196 | Cost: 1.1228\n",
|
| 587 |
+
"Epoch: 002/010 | Batch 0150/0196 | Cost: 1.0426\n",
|
| 588 |
+
"Epoch: 002/010 | Train: 56.166%\n",
|
| 589 |
+
"Time elapsed: 0.77 min\n",
|
| 590 |
+
"Epoch: 003/010 | Batch 0000/0196 | Cost: 0.9980\n",
|
| 591 |
+
"Epoch: 003/010 | Batch 0150/0196 | Cost: 0.8279\n",
|
| 592 |
+
"Epoch: 003/010 | Train: 66.702%\n",
|
| 593 |
+
"Time elapsed: 1.16 min\n",
|
| 594 |
+
"Epoch: 004/010 | Batch 0000/0196 | Cost: 0.6384\n",
|
| 595 |
+
"Epoch: 004/010 | Batch 0150/0196 | Cost: 0.7103\n",
|
| 596 |
+
"Epoch: 004/010 | Train: 65.330%\n",
|
| 597 |
+
"Time elapsed: 1.55 min\n",
|
| 598 |
+
"Epoch: 005/010 | Batch 0000/0196 | Cost: 0.6308\n",
|
| 599 |
+
"Epoch: 005/010 | Batch 0150/0196 | Cost: 0.5913\n",
|
| 600 |
+
"Epoch: 005/010 | Train: 79.636%\n",
|
| 601 |
+
"Time elapsed: 1.94 min\n",
|
| 602 |
+
"Epoch: 006/010 | Batch 0000/0196 | Cost: 0.4409\n",
|
| 603 |
+
"Epoch: 006/010 | Batch 0150/0196 | Cost: 0.5557\n",
|
| 604 |
+
"Epoch: 006/010 | Train: 76.456%\n",
|
| 605 |
+
"Time elapsed: 2.33 min\n",
|
| 606 |
+
"Epoch: 007/010 | Batch 0000/0196 | Cost: 0.4778\n",
|
| 607 |
+
"Epoch: 007/010 | Batch 0150/0196 | Cost: 0.4815\n",
|
| 608 |
+
"Epoch: 007/010 | Train: 65.890%\n",
|
| 609 |
+
"Time elapsed: 2.71 min\n",
|
| 610 |
+
"Epoch: 008/010 | Batch 0000/0196 | Cost: 0.3782\n",
|
| 611 |
+
"Epoch: 008/010 | Batch 0150/0196 | Cost: 0.4339\n",
|
| 612 |
+
"Epoch: 008/010 | Train: 85.200%\n",
|
| 613 |
+
"Time elapsed: 3.10 min\n",
|
| 614 |
+
"Epoch: 009/010 | Batch 0000/0196 | Cost: 0.3083\n",
|
| 615 |
+
"Epoch: 009/010 | Batch 0150/0196 | Cost: 0.3290\n",
|
| 616 |
+
"Epoch: 009/010 | Train: 78.108%\n",
|
| 617 |
+
"Time elapsed: 3.49 min\n",
|
| 618 |
+
"Epoch: 010/010 | Batch 0000/0196 | Cost: 0.2229\n",
|
| 619 |
+
"Epoch: 010/010 | Batch 0150/0196 | Cost: 0.1945\n",
|
| 620 |
+
"Epoch: 010/010 | Train: 87.384%\n",
|
| 621 |
+
"Time elapsed: 3.88 min\n",
|
| 622 |
+
"Total Training Time: 3.88 min\n",
|
| 623 |
+
"Test accuracy: 70.67%\n",
|
| 624 |
+
"Total Time: 3.91 min\n"
|
| 625 |
+
]
|
| 626 |
+
}
|
| 627 |
+
],
|
| 628 |
+
"source": [
|
| 629 |
+
"def compute_accuracy(model, data_loader, device):\n",
|
| 630 |
+
" correct_pred, num_examples = 0, 0\n",
|
| 631 |
+
" for i, (features, targets) in enumerate(data_loader):\n",
|
| 632 |
+
" \n",
|
| 633 |
+
" features = features.to(device)\n",
|
| 634 |
+
" targets = targets.to(device)\n",
|
| 635 |
+
"\n",
|
| 636 |
+
" logits, probas = model(features)\n",
|
| 637 |
+
" _, predicted_labels = torch.max(probas, 1)\n",
|
| 638 |
+
" num_examples += targets.size(0)\n",
|
| 639 |
+
" correct_pred += (predicted_labels == targets).sum()\n",
|
| 640 |
+
" return correct_pred.float()/num_examples * 100\n",
|
| 641 |
+
" \n",
|
| 642 |
+
"\n",
|
| 643 |
+
"start_time = time.time()\n",
|
| 644 |
+
"for epoch in range(NUM_EPOCHS):\n",
|
| 645 |
+
" \n",
|
| 646 |
+
" model.train()\n",
|
| 647 |
+
" for batch_idx, (features, targets) in enumerate(train_loader):\n",
|
| 648 |
+
" \n",
|
| 649 |
+
" features = features.to(DEVICE)\n",
|
| 650 |
+
" targets = targets.to(DEVICE)\n",
|
| 651 |
+
" \n",
|
| 652 |
+
" ### FORWARD AND BACK PROP\n",
|
| 653 |
+
" logits, probas = model(features)\n",
|
| 654 |
+
" cost = F.cross_entropy(logits, targets)\n",
|
| 655 |
+
" optimizer.zero_grad()\n",
|
| 656 |
+
" \n",
|
| 657 |
+
" cost.backward()\n",
|
| 658 |
+
" \n",
|
| 659 |
+
" ### UPDATE MODEL PARAMETERS\n",
|
| 660 |
+
" optimizer.step()\n",
|
| 661 |
+
" \n",
|
| 662 |
+
" ### LOGGING\n",
|
| 663 |
+
" if not batch_idx % 150:\n",
|
| 664 |
+
" print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f' \n",
|
| 665 |
+
" %(epoch+1, NUM_EPOCHS, batch_idx, \n",
|
| 666 |
+
" len(train_loader), cost))\n",
|
| 667 |
+
"\n",
|
| 668 |
+
" \n",
|
| 669 |
+
"\n",
|
| 670 |
+
" model.eval()\n",
|
| 671 |
+
" with torch.set_grad_enabled(False): # save memory during inference\n",
|
| 672 |
+
" print('Epoch: %03d/%03d | Train: %.3f%%' % (\n",
|
| 673 |
+
" epoch+1, NUM_EPOCHS, \n",
|
| 674 |
+
" compute_accuracy(model, train_loader, device=DEVICE)))\n",
|
| 675 |
+
" \n",
|
| 676 |
+
" print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
|
| 677 |
+
" \n",
|
| 678 |
+
"print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))\n",
|
| 679 |
+
"\n",
|
| 680 |
+
"\n",
|
| 681 |
+
"with torch.set_grad_enabled(False): # save memory during inference\n",
|
| 682 |
+
" print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader, device=DEVICE)))\n",
|
| 683 |
+
" \n",
|
| 684 |
+
"print('Total Time: %.2f min' % ((time.time() - start_time)/60))"
|
| 685 |
+
]
|
| 686 |
+
},
|
| 687 |
+
{
|
| 688 |
+
"cell_type": "markdown",
|
| 689 |
+
"metadata": {},
|
| 690 |
+
"source": [
|
| 691 |
+
"## Conclusions"
|
| 692 |
+
]
|
| 693 |
+
},
|
| 694 |
+
{
|
| 695 |
+
"cell_type": "markdown",
|
| 696 |
+
"metadata": {},
|
| 697 |
+
"source": [
|
| 698 |
+
"Based on the training time without and with `pin_memory=True`, there doesn't seem to be a speed-up when using page-locked (or \"pinned\") memory -- in fact, pinning the memory even slowed down the training. (I reran the code in the opposite order, i.e., `pin_memory=True` first, and got the same results.) This could be due to the relatively small dataset size, batch size.\n"
|
| 699 |
+
]
|
| 700 |
+
}
|
| 701 |
+
],
|
| 702 |
+
"metadata": {
|
| 703 |
+
"accelerator": "GPU",
|
| 704 |
+
"colab": {
|
| 705 |
+
"collapsed_sections": [],
|
| 706 |
+
"default_view": {},
|
| 707 |
+
"name": "convnet-vgg16.ipynb",
|
| 708 |
+
"provenance": [],
|
| 709 |
+
"version": "0.3.2",
|
| 710 |
+
"views": {}
|
| 711 |
+
},
|
| 712 |
+
"kernelspec": {
|
| 713 |
+
"display_name": "Python 3",
|
| 714 |
+
"language": "python",
|
| 715 |
+
"name": "python3"
|
| 716 |
+
},
|
| 717 |
+
"language_info": {
|
| 718 |
+
"codemirror_mode": {
|
| 719 |
+
"name": "ipython",
|
| 720 |
+
"version": 3
|
| 721 |
+
},
|
| 722 |
+
"file_extension": ".py",
|
| 723 |
+
"mimetype": "text/x-python",
|
| 724 |
+
"name": "python",
|
| 725 |
+
"nbconvert_exporter": "python",
|
| 726 |
+
"pygments_lexer": "ipython3",
|
| 727 |
+
"version": "3.7.1"
|
| 728 |
+
},
|
| 729 |
+
"toc": {
|
| 730 |
+
"nav_menu": {},
|
| 731 |
+
"number_sections": true,
|
| 732 |
+
"sideBar": true,
|
| 733 |
+
"skip_h1_title": false,
|
| 734 |
+
"title_cell": "Table of Contents",
|
| 735 |
+
"title_sidebar": "Contents",
|
| 736 |
+
"toc_cell": true,
|
| 737 |
+
"toc_position": {
|
| 738 |
+
"height": "calc(100% - 180px)",
|
| 739 |
+
"left": "10px",
|
| 740 |
+
"top": "150px",
|
| 741 |
+
"width": "371px"
|
| 742 |
+
},
|
| 743 |
+
"toc_section_display": true,
|
| 744 |
+
"toc_window_display": true
|
| 745 |
+
}
|
| 746 |
+
},
|
| 747 |
+
"nbformat": 4,
|
| 748 |
+
"nbformat_minor": 2
|
| 749 |
+
}
|
Experiments/quicknet_cifar10_centroid.ipynb
ADDED
|
@@ -0,0 +1,698 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"nbformat": 4,
|
| 3 |
+
"nbformat_minor": 0,
|
| 4 |
+
"metadata": {
|
| 5 |
+
"colab": {
|
| 6 |
+
"provenance": [],
|
| 7 |
+
"gpuType": "T4"
|
| 8 |
+
},
|
| 9 |
+
"kernelspec": {
|
| 10 |
+
"name": "python3",
|
| 11 |
+
"display_name": "Python 3"
|
| 12 |
+
},
|
| 13 |
+
"language_info": {
|
| 14 |
+
"name": "python"
|
| 15 |
+
},
|
| 16 |
+
"accelerator": "GPU"
|
| 17 |
+
},
|
| 18 |
+
"cells": [
|
| 19 |
+
{
|
| 20 |
+
"cell_type": "code",
|
| 21 |
+
"execution_count": 1,
|
| 22 |
+
"metadata": {
|
| 23 |
+
"colab": {
|
| 24 |
+
"base_uri": "https://localhost:8080/"
|
| 25 |
+
},
|
| 26 |
+
"id": "5aR8PYXi684E",
|
| 27 |
+
"outputId": "a6b59bdf-41f5-4b86-e7b1-47ceb6fb6893"
|
| 28 |
+
},
|
| 29 |
+
"outputs": [
|
| 30 |
+
{
|
| 31 |
+
"output_type": "stream",
|
| 32 |
+
"name": "stdout",
|
| 33 |
+
"text": [
|
| 34 |
+
"Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz\n",
|
| 35 |
+
"170498071/170498071 [==============================] - 5s 0us/step\n",
|
| 36 |
+
"Epoch 1/10\n",
|
| 37 |
+
"782/782 [==============================] - 14s 10ms/step - loss: 1.6260 - accuracy: 0.3923 - val_loss: 1.2804 - val_accuracy: 0.5374\n",
|
| 38 |
+
"Epoch 2/10\n",
|
| 39 |
+
"782/782 [==============================] - 6s 8ms/step - loss: 1.1254 - accuracy: 0.5914 - val_loss: 1.0466 - val_accuracy: 0.6335\n",
|
| 40 |
+
"Epoch 3/10\n",
|
| 41 |
+
"782/782 [==============================] - 6s 8ms/step - loss: 0.9187 - accuracy: 0.6723 - val_loss: 0.9976 - val_accuracy: 0.6503\n",
|
| 42 |
+
"Epoch 4/10\n",
|
| 43 |
+
"782/782 [==============================] - 6s 8ms/step - loss: 0.7763 - accuracy: 0.7259 - val_loss: 0.8091 - val_accuracy: 0.7194\n",
|
| 44 |
+
"Epoch 5/10\n",
|
| 45 |
+
"782/782 [==============================] - 6s 8ms/step - loss: 0.6660 - accuracy: 0.7670 - val_loss: 0.7875 - val_accuracy: 0.7268\n",
|
| 46 |
+
"Epoch 6/10\n",
|
| 47 |
+
"782/782 [==============================] - 6s 8ms/step - loss: 0.5799 - accuracy: 0.7962 - val_loss: 0.8075 - val_accuracy: 0.7319\n",
|
| 48 |
+
"Epoch 7/10\n",
|
| 49 |
+
"782/782 [==============================] - 6s 8ms/step - loss: 0.5006 - accuracy: 0.8232 - val_loss: 0.8128 - val_accuracy: 0.7300\n",
|
| 50 |
+
"Epoch 8/10\n",
|
| 51 |
+
"782/782 [==============================] - 6s 8ms/step - loss: 0.4253 - accuracy: 0.8497 - val_loss: 0.8981 - val_accuracy: 0.7261\n",
|
| 52 |
+
"Epoch 9/10\n",
|
| 53 |
+
"782/782 [==============================] - 6s 8ms/step - loss: 0.3633 - accuracy: 0.8707 - val_loss: 0.8523 - val_accuracy: 0.7377\n",
|
| 54 |
+
"Epoch 10/10\n",
|
| 55 |
+
"782/782 [==============================] - 7s 9ms/step - loss: 0.3099 - accuracy: 0.8887 - val_loss: 0.8689 - val_accuracy: 0.7499\n",
|
| 56 |
+
"1563/1563 [==============================] - 4s 2ms/step\n"
|
| 57 |
+
]
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"output_type": "execute_result",
|
| 61 |
+
"data": {
|
| 62 |
+
"text/plain": [
|
| 63 |
+
"array([[8.79073262e-01, 2.84781260e-03, 3.16492580e-02, 9.60051734e-03,\n",
|
| 64 |
+
" 8.88401456e-03, 4.04336443e-03, 1.89202465e-03, 2.60451157e-03,\n",
|
| 65 |
+
" 4.94241416e-02, 9.98295005e-03],\n",
|
| 66 |
+
" [5.34512708e-03, 9.61387515e-01, 1.82328641e-03, 1.49091403e-03,\n",
|
| 67 |
+
" 3.62021121e-04, 2.36713630e-03, 9.99960001e-04, 3.58641875e-04,\n",
|
| 68 |
+
" 8.00825842e-03, 1.78600885e-02],\n",
|
| 69 |
+
" [1.93752628e-02, 5.49300632e-04, 8.71673167e-01, 3.62917222e-02,\n",
|
| 70 |
+
" 2.81240512e-02, 1.94428060e-02, 1.17484620e-02, 7.90499710e-03,\n",
|
| 71 |
+
" 3.91457370e-03, 9.77586606e-04],\n",
|
| 72 |
+
" [3.43155907e-03, 3.87360866e-04, 3.25083360e-02, 8.04273427e-01,\n",
|
| 73 |
+
" 2.26426423e-02, 1.00013100e-01, 1.48316240e-02, 1.55726429e-02,\n",
|
| 74 |
+
" 3.46732978e-03, 2.87183723e-03],\n",
|
| 75 |
+
" [4.50019445e-03, 3.07688024e-04, 3.81155163e-02, 3.81957851e-02,\n",
|
| 76 |
+
" 8.43278885e-01, 2.98470370e-02, 1.07433125e-02, 3.15311365e-02,\n",
|
| 77 |
+
" 2.45781220e-03, 1.02343888e-03],\n",
|
| 78 |
+
" [8.63184629e-04, 2.92542420e-04, 2.77942196e-02, 1.26914889e-01,\n",
|
| 79 |
+
" 1.68064609e-02, 7.96062946e-01, 6.89344807e-03, 2.19228752e-02,\n",
|
| 80 |
+
" 1.45196577e-03, 9.96942166e-04],\n",
|
| 81 |
+
" [2.25282880e-03, 1.00196269e-03, 4.27705161e-02, 4.95258309e-02,\n",
|
| 82 |
+
" 1.27582047e-02, 1.77529752e-02, 8.66206527e-01, 2.50199367e-03,\n",
|
| 83 |
+
" 2.86106160e-03, 2.37041595e-03],\n",
|
| 84 |
+
" [1.76379981e-03, 2.93744437e-04, 8.86812247e-03, 1.57593563e-02,\n",
|
| 85 |
+
" 1.64906941e-02, 2.32193433e-02, 5.39854867e-04, 9.30097342e-01,\n",
|
| 86 |
+
" 8.36823776e-04, 2.13308516e-03],\n",
|
| 87 |
+
" [1.30738802e-02, 4.98946663e-03, 2.41551897e-03, 5.38546871e-03,\n",
|
| 88 |
+
" 1.31231989e-03, 1.98870571e-03, 1.08822726e-03, 5.06214972e-04,\n",
|
| 89 |
+
" 9.64531541e-01, 4.71282937e-03],\n",
|
| 90 |
+
" [1.20871738e-02, 2.48907283e-02, 2.02297443e-03, 6.83936710e-03,\n",
|
| 91 |
+
" 1.44265452e-03, 4.34162840e-03, 1.16528722e-03, 2.80632032e-03,\n",
|
| 92 |
+
" 1.10650901e-02, 9.33342278e-01]])"
|
| 93 |
+
]
|
| 94 |
+
},
|
| 95 |
+
"metadata": {},
|
| 96 |
+
"execution_count": 1
|
| 97 |
+
}
|
| 98 |
+
],
|
| 99 |
+
"source": [
|
| 100 |
+
"from scipy.spatial.distance import euclidean\n",
|
| 101 |
+
"import tensorflow as tf\n",
|
| 102 |
+
"from tensorflow.keras import datasets, layers, models\n",
|
| 103 |
+
"import numpy as np\n",
|
| 104 |
+
"\n",
|
| 105 |
+
"\n",
|
| 106 |
+
"(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()\n",
|
| 107 |
+
"\n",
|
| 108 |
+
"\n",
|
| 109 |
+
"train_images = train_images / 255.0\n",
|
| 110 |
+
"test_images = test_images / 255.0\n",
|
| 111 |
+
"\n",
|
| 112 |
+
"\n",
|
| 113 |
+
"def create_resnet18():\n",
|
| 114 |
+
" model = models.Sequential()\n",
|
| 115 |
+
" model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=(32, 32, 3)))\n",
|
| 116 |
+
" model.add(layers.MaxPooling2D((2, 2)))\n",
|
| 117 |
+
" model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))\n",
|
| 118 |
+
" model.add(layers.MaxPooling2D((2, 2)))\n",
|
| 119 |
+
" model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))\n",
|
| 120 |
+
" model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))\n",
|
| 121 |
+
" model.add(layers.MaxPooling2D((2, 2)))\n",
|
| 122 |
+
" model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))\n",
|
| 123 |
+
" model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))\n",
|
| 124 |
+
" model.add(layers.MaxPooling2D((2, 2)))\n",
|
| 125 |
+
" model.add(layers.Flatten())\n",
|
| 126 |
+
" model.add(layers.Dense(512, activation='relu'))\n",
|
| 127 |
+
" model.add(layers.Dense(10, activation='softmax'))\n",
|
| 128 |
+
" return model\n",
|
| 129 |
+
"\n",
|
| 130 |
+
"\n",
|
| 131 |
+
"model = create_resnet18()\n",
|
| 132 |
+
"\n",
|
| 133 |
+
"\n",
|
| 134 |
+
"model.compile(optimizer='adam',\n",
|
| 135 |
+
" loss='sparse_categorical_crossentropy',\n",
|
| 136 |
+
" metrics=['accuracy'])\n",
|
| 137 |
+
"\n",
|
| 138 |
+
"\n",
|
| 139 |
+
"model.fit(train_images, train_labels, epochs=10, batch_size=64, validation_data=(test_images, test_labels))\n",
|
| 140 |
+
"\n",
|
| 141 |
+
"\n",
|
| 142 |
+
"train_embeddings = model.predict(train_images)\n",
|
| 143 |
+
"\n",
|
| 144 |
+
"num_classes = 10\n",
|
| 145 |
+
"mean_embeddings = np.zeros((num_classes, train_embeddings.shape[1]))\n",
|
| 146 |
+
"\n",
|
| 147 |
+
"\n",
|
| 148 |
+
"for class_label in range(num_classes):\n",
|
| 149 |
+
" class_indices = np.where(train_labels.flatten() == class_label)[0]\n",
|
| 150 |
+
" class_embeddings = train_embeddings[class_indices]\n",
|
| 151 |
+
" mean_embedding = np.mean(class_embeddings, axis=0)\n",
|
| 152 |
+
" mean_embeddings[class_label] = mean_embedding\n",
|
| 153 |
+
"\n",
|
| 154 |
+
"\n",
|
| 155 |
+
"def classify_query(query_image):\n",
|
| 156 |
+
" query_embedding = model.predict(np.expand_dims(query_image, axis=0))\n",
|
| 157 |
+
" distances = [euclidean(query_embedding.flatten(), mean_embedding) for mean_embedding in mean_embeddings]\n",
|
| 158 |
+
" predicted_class = np.argmin(distances)\n",
|
| 159 |
+
" return predicted_class\n",
|
| 160 |
+
"\n",
|
| 161 |
+
"mean_embeddings\n"
|
| 162 |
+
]
|
| 163 |
+
},
|
| 164 |
+
{
|
| 165 |
+
"cell_type": "code",
|
| 166 |
+
"source": [
|
| 167 |
+
"# Calculate accuracy\n",
|
| 168 |
+
"correct_predictions = 0\n",
|
| 169 |
+
"total_predictions = 500\n",
|
| 170 |
+
"\n",
|
| 171 |
+
"for i in range(500):\n",
|
| 172 |
+
" predicted_class = classify_query(test_images[i])\n",
|
| 173 |
+
" if predicted_class == test_labels[i]:\n",
|
| 174 |
+
" correct_predictions += 1\n",
|
| 175 |
+
"\n",
|
| 176 |
+
"accuracy = correct_predictions / total_predictions\n",
|
| 177 |
+
"print(\"Accuracy:\", accuracy)\n"
|
| 178 |
+
],
|
| 179 |
+
"metadata": {
|
| 180 |
+
"colab": {
|
| 181 |
+
"base_uri": "https://localhost:8080/"
|
| 182 |
+
},
|
| 183 |
+
"id": "1-b_E44RrxuV",
|
| 184 |
+
"outputId": "125e4835-b44a-48bf-ff43-1103c2b1b6f4"
|
| 185 |
+
},
|
| 186 |
+
"execution_count": 2,
|
| 187 |
+
"outputs": [
|
| 188 |
+
{
|
| 189 |
+
"output_type": "stream",
|
| 190 |
+
"name": "stdout",
|
| 191 |
+
"text": [
|
| 192 |
+
"1/1 [==============================] - 0s 388ms/step\n",
|
| 193 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 194 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 195 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 196 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 197 |
+
"1/1 [==============================] - 0s 25ms/step\n",
|
| 198 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 199 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 200 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 201 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 202 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 203 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 204 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 205 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 206 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 207 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 208 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 209 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 210 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 211 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 212 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 213 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 214 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 215 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 216 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 217 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 218 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 219 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 220 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 221 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 222 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 223 |
+
"1/1 [==============================] - 0s 20ms/step\n",
|
| 224 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 225 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 226 |
+
"1/1 [==============================] - 0s 21ms/step\n",
|
| 227 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 228 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 229 |
+
"1/1 [==============================] - 0s 20ms/step\n",
|
| 230 |
+
"1/1 [==============================] - 0s 24ms/step\n",
|
| 231 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 232 |
+
"1/1 [==============================] - 0s 20ms/step\n",
|
| 233 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 234 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 235 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 236 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 237 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 238 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 239 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 240 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 241 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 242 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 243 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 244 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 245 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 246 |
+
"1/1 [==============================] - 0s 20ms/step\n",
|
| 247 |
+
"1/1 [==============================] - 0s 25ms/step\n",
|
| 248 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 249 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 250 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 251 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 252 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 253 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 254 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 255 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 256 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 257 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 258 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 259 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 260 |
+
"1/1 [==============================] - 0s 16ms/step\n",
|
| 261 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 262 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 263 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 264 |
+
"1/1 [==============================] - 0s 23ms/step\n",
|
| 265 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 266 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 267 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 268 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 269 |
+
"1/1 [==============================] - 0s 25ms/step\n",
|
| 270 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 271 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 272 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 273 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 274 |
+
"1/1 [==============================] - 0s 20ms/step\n",
|
| 275 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 276 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 277 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 278 |
+
"1/1 [==============================] - 0s 29ms/step\n",
|
| 279 |
+
"1/1 [==============================] - 0s 31ms/step\n",
|
| 280 |
+
"1/1 [==============================] - 0s 31ms/step\n",
|
| 281 |
+
"1/1 [==============================] - 0s 28ms/step\n",
|
| 282 |
+
"1/1 [==============================] - 0s 27ms/step\n",
|
| 283 |
+
"1/1 [==============================] - 0s 24ms/step\n",
|
| 284 |
+
"1/1 [==============================] - 0s 26ms/step\n",
|
| 285 |
+
"1/1 [==============================] - 0s 25ms/step\n",
|
| 286 |
+
"1/1 [==============================] - 0s 26ms/step\n",
|
| 287 |
+
"1/1 [==============================] - 0s 25ms/step\n",
|
| 288 |
+
"1/1 [==============================] - 0s 29ms/step\n",
|
| 289 |
+
"1/1 [==============================] - 0s 25ms/step\n",
|
| 290 |
+
"1/1 [==============================] - 0s 26ms/step\n",
|
| 291 |
+
"1/1 [==============================] - 0s 48ms/step\n",
|
| 292 |
+
"1/1 [==============================] - 0s 32ms/step\n",
|
| 293 |
+
"1/1 [==============================] - 0s 27ms/step\n",
|
| 294 |
+
"1/1 [==============================] - 0s 25ms/step\n",
|
| 295 |
+
"1/1 [==============================] - 0s 28ms/step\n",
|
| 296 |
+
"1/1 [==============================] - 0s 24ms/step\n",
|
| 297 |
+
"1/1 [==============================] - 0s 28ms/step\n",
|
| 298 |
+
"1/1 [==============================] - 0s 31ms/step\n",
|
| 299 |
+
"1/1 [==============================] - 0s 27ms/step\n",
|
| 300 |
+
"1/1 [==============================] - 0s 29ms/step\n",
|
| 301 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 302 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 303 |
+
"1/1 [==============================] - 0s 28ms/step\n",
|
| 304 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 305 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 306 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 307 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 308 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 309 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 310 |
+
"1/1 [==============================] - 0s 21ms/step\n",
|
| 311 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 312 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 313 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 314 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 315 |
+
"1/1 [==============================] - 0s 16ms/step\n",
|
| 316 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 317 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 318 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 319 |
+
"1/1 [==============================] - 0s 20ms/step\n",
|
| 320 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 321 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 322 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 323 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 324 |
+
"1/1 [==============================] - 0s 16ms/step\n",
|
| 325 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 326 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 327 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 328 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 329 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 330 |
+
"1/1 [==============================] - 0s 21ms/step\n",
|
| 331 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 332 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 333 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 334 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 335 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 336 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 337 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 338 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 339 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 340 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 341 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 342 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 343 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 344 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 345 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 346 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 347 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 348 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 349 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 350 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 351 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 352 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 353 |
+
"1/1 [==============================] - 0s 24ms/step\n",
|
| 354 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 355 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 356 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 357 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 358 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 359 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 360 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 361 |
+
"1/1 [==============================] - 0s 20ms/step\n",
|
| 362 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 363 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 364 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 365 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 366 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 367 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 368 |
+
"1/1 [==============================] - 0s 20ms/step\n",
|
| 369 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 370 |
+
"1/1 [==============================] - 0s 24ms/step\n",
|
| 371 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 372 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 373 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 374 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 375 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 376 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 377 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 378 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 379 |
+
"1/1 [==============================] - 0s 16ms/step\n",
|
| 380 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 381 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 382 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 383 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 384 |
+
"1/1 [==============================] - 0s 21ms/step\n",
|
| 385 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 386 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 387 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 388 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 389 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 390 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 391 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 392 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 393 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 394 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 395 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 396 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 397 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 398 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 399 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 400 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 401 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 402 |
+
"1/1 [==============================] - 0s 24ms/step\n",
|
| 403 |
+
"1/1 [==============================] - 0s 26ms/step\n",
|
| 404 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 405 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 406 |
+
"1/1 [==============================] - 0s 16ms/step\n",
|
| 407 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 408 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 409 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 410 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 411 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 412 |
+
"1/1 [==============================] - 0s 16ms/step\n",
|
| 413 |
+
"1/1 [==============================] - 0s 16ms/step\n",
|
| 414 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 415 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 416 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 417 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 418 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 419 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 420 |
+
"1/1 [==============================] - 0s 20ms/step\n",
|
| 421 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 422 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 423 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 424 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 425 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 426 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 427 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 428 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 429 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 430 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 431 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 432 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 433 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 434 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 435 |
+
"1/1 [==============================] - 0s 21ms/step\n",
|
| 436 |
+
"1/1 [==============================] - 0s 25ms/step\n",
|
| 437 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 438 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 439 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 440 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 441 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 442 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 443 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 444 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 445 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 446 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 447 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 448 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 449 |
+
"1/1 [==============================] - 0s 20ms/step\n",
|
| 450 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 451 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 452 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 453 |
+
"1/1 [==============================] - 0s 20ms/step\n",
|
| 454 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 455 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 456 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 457 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 458 |
+
"1/1 [==============================] - 0s 27ms/step\n",
|
| 459 |
+
"1/1 [==============================] - 0s 30ms/step\n",
|
| 460 |
+
"1/1 [==============================] - 0s 31ms/step\n",
|
| 461 |
+
"1/1 [==============================] - 0s 26ms/step\n",
|
| 462 |
+
"1/1 [==============================] - 0s 35ms/step\n",
|
| 463 |
+
"1/1 [==============================] - 0s 28ms/step\n",
|
| 464 |
+
"1/1 [==============================] - 0s 26ms/step\n",
|
| 465 |
+
"1/1 [==============================] - 0s 25ms/step\n",
|
| 466 |
+
"1/1 [==============================] - 0s 28ms/step\n",
|
| 467 |
+
"1/1 [==============================] - 0s 25ms/step\n",
|
| 468 |
+
"1/1 [==============================] - 0s 30ms/step\n",
|
| 469 |
+
"1/1 [==============================] - 0s 26ms/step\n",
|
| 470 |
+
"1/1 [==============================] - 0s 27ms/step\n",
|
| 471 |
+
"1/1 [==============================] - 0s 26ms/step\n",
|
| 472 |
+
"1/1 [==============================] - 0s 28ms/step\n",
|
| 473 |
+
"1/1 [==============================] - 0s 32ms/step\n",
|
| 474 |
+
"1/1 [==============================] - 0s 29ms/step\n",
|
| 475 |
+
"1/1 [==============================] - 0s 26ms/step\n",
|
| 476 |
+
"1/1 [==============================] - 0s 34ms/step\n",
|
| 477 |
+
"1/1 [==============================] - 0s 30ms/step\n",
|
| 478 |
+
"1/1 [==============================] - 0s 27ms/step\n",
|
| 479 |
+
"1/1 [==============================] - 0s 26ms/step\n",
|
| 480 |
+
"1/1 [==============================] - 0s 35ms/step\n",
|
| 481 |
+
"1/1 [==============================] - 0s 30ms/step\n",
|
| 482 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 483 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 484 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 485 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 486 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 487 |
+
"1/1 [==============================] - 0s 21ms/step\n",
|
| 488 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 489 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 490 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 491 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 492 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 493 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 494 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 495 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 496 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 497 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 498 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 499 |
+
"1/1 [==============================] - 0s 20ms/step\n",
|
| 500 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 501 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 502 |
+
"1/1 [==============================] - 0s 20ms/step\n",
|
| 503 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 504 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 505 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 506 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 507 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 508 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 509 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 510 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 511 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 512 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 513 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 514 |
+
"1/1 [==============================] - 0s 20ms/step\n",
|
| 515 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 516 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 517 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 518 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 519 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 520 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 521 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 522 |
+
"1/1 [==============================] - 0s 16ms/step\n",
|
| 523 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 524 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 525 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 526 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 527 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 528 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 529 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 530 |
+
"1/1 [==============================] - 0s 20ms/step\n",
|
| 531 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 532 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 533 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 534 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 535 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 536 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 537 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 538 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 539 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 540 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 541 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 542 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 543 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 544 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 545 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 546 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 547 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 548 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 549 |
+
"1/1 [==============================] - 0s 21ms/step\n",
|
| 550 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 551 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 552 |
+
"1/1 [==============================] - 0s 23ms/step\n",
|
| 553 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 554 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 555 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 556 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 557 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 558 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 559 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 560 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 561 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 562 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 563 |
+
"1/1 [==============================] - 0s 22ms/step\n",
|
| 564 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 565 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 566 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 567 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 568 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 569 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 570 |
+
"1/1 [==============================] - 0s 22ms/step\n",
|
| 571 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 572 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 573 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 574 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 575 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 576 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 577 |
+
"1/1 [==============================] - 0s 23ms/step\n",
|
| 578 |
+
"1/1 [==============================] - 0s 20ms/step\n",
|
| 579 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 580 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 581 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 582 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 583 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 584 |
+
"1/1 [==============================] - 0s 21ms/step\n",
|
| 585 |
+
"1/1 [==============================] - 0s 28ms/step\n",
|
| 586 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 587 |
+
"1/1 [==============================] - 0s 20ms/step\n",
|
| 588 |
+
"1/1 [==============================] - 0s 20ms/step\n",
|
| 589 |
+
"1/1 [==============================] - 0s 20ms/step\n",
|
| 590 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 591 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 592 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 593 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 594 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 595 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 596 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 597 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 598 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 599 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 600 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 601 |
+
"1/1 [==============================] - 0s 24ms/step\n",
|
| 602 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 603 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 604 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 605 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 606 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 607 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 608 |
+
"1/1 [==============================] - 0s 20ms/step\n",
|
| 609 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 610 |
+
"1/1 [==============================] - 0s 20ms/step\n",
|
| 611 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 612 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 613 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 614 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 615 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 616 |
+
"1/1 [==============================] - 0s 16ms/step\n",
|
| 617 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 618 |
+
"1/1 [==============================] - 0s 23ms/step\n",
|
| 619 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 620 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 621 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 622 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 623 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 624 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 625 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 626 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 627 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 628 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 629 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 630 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 631 |
+
"1/1 [==============================] - 0s 20ms/step\n",
|
| 632 |
+
"1/1 [==============================] - 0s 16ms/step\n",
|
| 633 |
+
"1/1 [==============================] - 0s 16ms/step\n",
|
| 634 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 635 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 636 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 637 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 638 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 639 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 640 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 641 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 642 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 643 |
+
"1/1 [==============================] - 0s 30ms/step\n",
|
| 644 |
+
"1/1 [==============================] - 0s 28ms/step\n",
|
| 645 |
+
"1/1 [==============================] - 0s 29ms/step\n",
|
| 646 |
+
"1/1 [==============================] - 0s 38ms/step\n",
|
| 647 |
+
"1/1 [==============================] - 0s 30ms/step\n",
|
| 648 |
+
"1/1 [==============================] - 0s 31ms/step\n",
|
| 649 |
+
"1/1 [==============================] - 0s 27ms/step\n",
|
| 650 |
+
"1/1 [==============================] - 0s 24ms/step\n",
|
| 651 |
+
"1/1 [==============================] - 0s 24ms/step\n",
|
| 652 |
+
"1/1 [==============================] - 0s 25ms/step\n",
|
| 653 |
+
"1/1 [==============================] - 0s 25ms/step\n",
|
| 654 |
+
"1/1 [==============================] - 0s 25ms/step\n",
|
| 655 |
+
"1/1 [==============================] - 0s 30ms/step\n",
|
| 656 |
+
"1/1 [==============================] - 0s 25ms/step\n",
|
| 657 |
+
"1/1 [==============================] - 0s 25ms/step\n",
|
| 658 |
+
"1/1 [==============================] - 0s 26ms/step\n",
|
| 659 |
+
"1/1 [==============================] - 0s 26ms/step\n",
|
| 660 |
+
"1/1 [==============================] - 0s 36ms/step\n",
|
| 661 |
+
"1/1 [==============================] - 0s 28ms/step\n",
|
| 662 |
+
"1/1 [==============================] - 0s 31ms/step\n",
|
| 663 |
+
"1/1 [==============================] - 0s 29ms/step\n",
|
| 664 |
+
"1/1 [==============================] - 0s 29ms/step\n",
|
| 665 |
+
"1/1 [==============================] - 0s 28ms/step\n",
|
| 666 |
+
"1/1 [==============================] - 0s 28ms/step\n",
|
| 667 |
+
"1/1 [==============================] - 0s 27ms/step\n",
|
| 668 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 669 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 670 |
+
"1/1 [==============================] - 0s 22ms/step\n",
|
| 671 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 672 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 673 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 674 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 675 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 676 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 677 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 678 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 679 |
+
"1/1 [==============================] - 0s 16ms/step\n",
|
| 680 |
+
"1/1 [==============================] - 0s 18ms/step\n",
|
| 681 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 682 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 683 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 684 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 685 |
+
"1/1 [==============================] - 0s 17ms/step\n",
|
| 686 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 687 |
+
"1/1 [==============================] - 0s 16ms/step\n",
|
| 688 |
+
"1/1 [==============================] - 0s 16ms/step\n",
|
| 689 |
+
"1/1 [==============================] - 0s 16ms/step\n",
|
| 690 |
+
"1/1 [==============================] - 0s 19ms/step\n",
|
| 691 |
+
"1/1 [==============================] - 0s 25ms/step\n",
|
| 692 |
+
"Accuracy: 0.744\n"
|
| 693 |
+
]
|
| 694 |
+
}
|
| 695 |
+
]
|
| 696 |
+
}
|
| 697 |
+
]
|
| 698 |
+
}
|
Experiments/quicknet_knn.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pickle
|
| 4 |
+
|
| 5 |
+
def find_knn_and_plot_image(test_data,model,k_value:int =10):
|
| 6 |
+
x_test = test_data
|
| 7 |
+
point = np.array(x_test)
|
| 8 |
+
point = np.expand_dims(point, axis=0)
|
| 9 |
+
|
| 10 |
+
test_embedding = model.predict(point)
|
| 11 |
+
|
| 12 |
+
with open(r'Model\data\train_embeddings_resnet18.pkl', 'rb') as f:
|
| 13 |
+
data = pickle.load(f)
|
| 14 |
+
|
| 15 |
+
# Convert data to NumPy array
|
| 16 |
+
train_data = np.array(data)
|
| 17 |
+
|
| 18 |
+
# Convert NumPy array to TensorFlow tensor
|
| 19 |
+
# data_tensor = tf.convert_to_tensor(data_array, dtype=tf.float32)
|
| 20 |
+
|
| 21 |
+
distance_with_label_and_index = []
|
| 22 |
+
print(train_data)
|
| 23 |
+
for i,(x_train) in enumerate(train_data):
|
| 24 |
+
train_point = np.array(x_train)
|
| 25 |
+
distance_with_label_and_index.append((i,np.linalg.norm(test_embedding-train_point)))
|
| 26 |
+
|
| 27 |
+
#sorting based on distance
|
| 28 |
+
distance_with_label_and_index_sorted=sorted(distance_with_label_and_index,key=lambda x: x[1])
|
| 29 |
+
k_nearest_points = distance_with_label_and_index_sorted[0:k_value]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
#calculating accuracy
|
| 33 |
+
result =[]
|
| 34 |
+
for i,(index,distance) in enumerate(k_nearest_points):
|
| 35 |
+
print(index)
|
| 36 |
+
result.append(train_image[index])
|
| 37 |
+
return result
|
Experiments/train_resnet50_3hll.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
import wandb
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torchvision.transforms as v2
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
import torchvision.datasets as datasets
|
| 8 |
+
|
| 9 |
+
wandb.login()
|
| 10 |
+
|
| 11 |
+
wandb.init(project="CIFAR", entity="gjyotin305")
|
| 12 |
+
mps_device = torch.device("cuda")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
IMAGE_SIZE = 224
|
| 16 |
+
mean, std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]
|
| 17 |
+
wandb.init()
|
| 18 |
+
transform_test = v2.Compose(
|
| 19 |
+
[v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
|
| 20 |
+
v2.ToTensor(),
|
| 21 |
+
v2.Normalize(mean, std)])
|
| 22 |
+
|
| 23 |
+
transform_train = v2.Compose(
|
| 24 |
+
[v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
|
| 25 |
+
v2.RandomRotation(20),
|
| 26 |
+
v2.RandomHorizontalFlip(0.1),
|
| 27 |
+
v2.ColorJitter(brightness = 0.1, contrast=0.1, saturation=0.1),
|
| 28 |
+
v2.RandomAdjustSharpness(sharpness_factor=2, p=0.1),
|
| 29 |
+
v2.ToTensor(),
|
| 30 |
+
v2.Normalize(mean, std),
|
| 31 |
+
v2.RandomErasing(p=0.75, scale=(0.02, 0.1), value=1.0, inplace=False)])
|
| 32 |
+
|
| 33 |
+
cifar_trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
|
| 34 |
+
cifar_testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
|
| 35 |
+
|
| 36 |
+
batch_size=64
|
| 37 |
+
|
| 38 |
+
trainLoader = DataLoader(cifar_trainset, batch_size=batch_size, shuffle=True, num_workers=8)
|
| 39 |
+
testLoader = DataLoader(cifar_testset, batch_size=batch_size, shuffle=True, num_workers=8)
|
| 40 |
+
|
| 41 |
+
classes = ('plane', 'car', 'bird', 'cat',
|
| 42 |
+
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
|
| 43 |
+
|
| 44 |
+
print(next(iter(trainLoader))[0].shape)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
import torchvision.datasets as datasets
|
| 48 |
+
import torchvision.models as models
|
| 49 |
+
|
| 50 |
+
resnet = models.resnet50(pretrained=True)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class JHARMNet(nn.Module):
|
| 54 |
+
def __init__(self, pretrained_model, num_classes):
|
| 55 |
+
super().__init__()
|
| 56 |
+
self.premodel = pretrained_model
|
| 57 |
+
self.premodel.fc = nn.Linear(2048, num_classes)
|
| 58 |
+
nn.init.xavier_uniform_(self.premodel.fc.weight)
|
| 59 |
+
|
| 60 |
+
def forward(self, x):
|
| 61 |
+
out = self.premodel(x)
|
| 62 |
+
return out
|
| 63 |
+
|
| 64 |
+
class HiddenLayer(nn.Module):
|
| 65 |
+
def __init__(self, pretrained_model):
|
| 66 |
+
super().__init__()
|
| 67 |
+
self.premodel = pretrained_model
|
| 68 |
+
self.new_layer = nn.Sequential(
|
| 69 |
+
nn.Linear(1000, 512),
|
| 70 |
+
nn.LeakyReLU(),
|
| 71 |
+
nn.Linear(512, 512),
|
| 72 |
+
nn.LeakyReLU(),
|
| 73 |
+
nn.Linear(512, 256),
|
| 74 |
+
nn.LeakyReLU(),
|
| 75 |
+
nn.Linear(256, 10)
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
def forward(self, x):
|
| 79 |
+
out = self.premodel(x)
|
| 80 |
+
out_new_layer = self.new_layer(out)
|
| 81 |
+
return out_new_layer
|
| 82 |
+
|
| 83 |
+
model_check = HiddenLayer(resnet)
|
| 84 |
+
criterion = nn.CrossEntropyLoss()
|
| 85 |
+
learning_rate=0.001
|
| 86 |
+
|
| 87 |
+
optimizer = torch.optim.SGD(model_check.parameters(),lr=learning_rate, momentum=0.2)
|
| 88 |
+
|
| 89 |
+
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
|
| 90 |
+
|
| 91 |
+
def train_model(model, train_loader, validation_loader, optimizer, n_epochs = 20):
|
| 92 |
+
|
| 93 |
+
# Global variable
|
| 94 |
+
N_test = len(cifar_testset)
|
| 95 |
+
accuracy_list = []
|
| 96 |
+
train_loss_list = []
|
| 97 |
+
model = model.to(mps_device)
|
| 98 |
+
train_cost_list = []
|
| 99 |
+
val_cost_list = []
|
| 100 |
+
|
| 101 |
+
for epoch in range(n_epochs):
|
| 102 |
+
train_COST = 0
|
| 103 |
+
print(f"Training Epoch: {epoch+1}")
|
| 104 |
+
for x,y in tqdm(train_loader):
|
| 105 |
+
x = x.to(mps_device)
|
| 106 |
+
y = y.to(mps_device)
|
| 107 |
+
model.train()
|
| 108 |
+
optimizer.zero_grad()
|
| 109 |
+
z = model(x)
|
| 110 |
+
loss = criterion(z,y)
|
| 111 |
+
loss.backward()
|
| 112 |
+
wandb.log({"train_loss": loss.item()})
|
| 113 |
+
optimizer.step()
|
| 114 |
+
train_COST+=loss.item()
|
| 115 |
+
|
| 116 |
+
train_COST = train_COST/len(train_loader)
|
| 117 |
+
train_cost_list.append(train_COST)
|
| 118 |
+
correct = 0
|
| 119 |
+
print(f"Validation Loop")
|
| 120 |
+
# Perform the prediction on the validation data
|
| 121 |
+
val_COST = 0
|
| 122 |
+
for x_test, y_test in tqdm(validation_loader):
|
| 123 |
+
model.eval()
|
| 124 |
+
x_test = x_test.to(mps_device)
|
| 125 |
+
y_test = y_test.to(mps_device)
|
| 126 |
+
z = model(x_test)
|
| 127 |
+
val_loss = criterion(z, y_test)
|
| 128 |
+
# scheduler.step(val_loss)
|
| 129 |
+
wandb.log({"val_loss": val_loss})
|
| 130 |
+
_, yhat = torch.max(z.data, 1)
|
| 131 |
+
correct += (yhat==y_test).sum().item()
|
| 132 |
+
val_COST+=val_loss.item()
|
| 133 |
+
|
| 134 |
+
val_COST = val_COST/ len(validation_loader)
|
| 135 |
+
val_cost_list.append(val_COST)
|
| 136 |
+
accuracy = correct / N_test
|
| 137 |
+
accuracy_list.append(accuracy)
|
| 138 |
+
wandb.log({"val_acc": accuracy})
|
| 139 |
+
if (epoch+1)%5 == 0:
|
| 140 |
+
torch.save(model.state_dict(), f"/root/test_cifar/log_model/CIFAR_{epoch+1}_end_hll.pt")
|
| 141 |
+
print(f"Saved Model as CIFAR_{epoch+1}_end_hll.pt")
|
| 142 |
+
|
| 143 |
+
print("--> Epoch Number : {}".format(epoch + 1),
|
| 144 |
+
" | Training Loss : {}".format(round(train_COST,4)),
|
| 145 |
+
" | Validation Loss : {}".format(round(val_COST,4)),
|
| 146 |
+
" | Validation Accuracy : {}%".format(round(accuracy * 100, 2)))
|
| 147 |
+
|
| 148 |
+
return accuracy_list, train_cost_list, val_cost_list, model
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
accuracy_list_normalv5, train_cost_listv5, val_cost_listv5, model_to_save=train_model(model=model_check,
|
| 153 |
+
n_epochs=100,
|
| 154 |
+
train_loader=trainLoader,
|
| 155 |
+
validation_loader=testLoader,
|
| 156 |
+
optimizer=optimizer)
|
| 157 |
+
|
| 158 |
+
torch.save(model_to_save.state_dict(), "CIFAR_end_hll.pt")
|
| 159 |
+
model = torch.load("CIFAR_end_hll.pt")
|
| 160 |
+
model.eval()
|
| 161 |
+
|
| 162 |
+
pred = model(next(iter(testLoader))[0])
|
| 163 |
+
print(pred[0][0], next(iter(testLoader))[1][0])
|
Experiments/train_resnet50_iter_1.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torchvision.models as models
|
| 2 |
+
import torch
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torchvision.transforms as v2
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
import torchvision.datasets as datasets
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
mps_device = torch.device("cuda")
|
| 11 |
+
|
| 12 |
+
IMAGE_SIZE = 224
|
| 13 |
+
mean, std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]
|
| 14 |
+
|
| 15 |
+
transform_test = v2.Compose(
|
| 16 |
+
[v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
|
| 17 |
+
v2.ToTensor(),
|
| 18 |
+
v2.Normalize(mean, std)])
|
| 19 |
+
|
| 20 |
+
transform_train = v2.Compose(
|
| 21 |
+
[v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
|
| 22 |
+
v2.RandomRotation(20),
|
| 23 |
+
v2.RandomHorizontalFlip(0.1),
|
| 24 |
+
v2.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
|
| 25 |
+
v2.RandomAdjustSharpness(sharpness_factor=2, p=0.1),
|
| 26 |
+
v2.ToTensor(),
|
| 27 |
+
v2.Normalize(mean, std),
|
| 28 |
+
v2.RandomErasing(p=0.75, scale=(0.02, 0.1), value=1.0, inplace=False)])
|
| 29 |
+
|
| 30 |
+
cifar_trainset = datasets.CIFAR10(
|
| 31 |
+
root='./data',
|
| 32 |
+
train=True,
|
| 33 |
+
download=True,
|
| 34 |
+
transform=transform_train)
|
| 35 |
+
cifar_testset = datasets.CIFAR10(
|
| 36 |
+
root='./data',
|
| 37 |
+
train=False,
|
| 38 |
+
download=True,
|
| 39 |
+
transform=transform_test)
|
| 40 |
+
|
| 41 |
+
batch_size = 64
|
| 42 |
+
|
| 43 |
+
trainLoader = DataLoader(
|
| 44 |
+
cifar_trainset,
|
| 45 |
+
batch_size=batch_size,
|
| 46 |
+
shuffle=True,
|
| 47 |
+
num_workers=8)
|
| 48 |
+
testLoader = DataLoader(
|
| 49 |
+
cifar_testset,
|
| 50 |
+
batch_size=batch_size,
|
| 51 |
+
shuffle=True,
|
| 52 |
+
num_workers=8)
|
| 53 |
+
|
| 54 |
+
classes = ('plane', 'car', 'bird', 'cat',
|
| 55 |
+
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
|
| 56 |
+
|
| 57 |
+
print(next(iter(trainLoader))[0].shape)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
resnet = models.resnet50(pretrained=True)
|
| 61 |
+
# print(resnet)
|
| 62 |
+
# print(resnet.fc)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class JHARMNet(nn.Module):
|
| 66 |
+
def __init__(self, pretrained_model, num_classes):
|
| 67 |
+
super().__init__()
|
| 68 |
+
self.premodel = pretrained_model
|
| 69 |
+
self.premodel.fc = nn.Linear(2048, num_classes)
|
| 70 |
+
nn.init.xavier_uniform_(self.premodel.fc.weight)
|
| 71 |
+
|
| 72 |
+
def forward(self, x):
|
| 73 |
+
out = self.premodel(x)
|
| 74 |
+
return out
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
model_check = JHARMNet(resnet, 10)
|
| 78 |
+
criterion = nn.CrossEntropyLoss()
|
| 79 |
+
learning_rate = 0.1
|
| 80 |
+
|
| 81 |
+
optimizer = torch.optim.SGD(
|
| 82 |
+
model_check.parameters(),
|
| 83 |
+
lr=learning_rate,
|
| 84 |
+
momentum=0.2)
|
| 85 |
+
|
| 86 |
+
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def train_model(
|
| 90 |
+
model,
|
| 91 |
+
train_loader,
|
| 92 |
+
validation_loader,
|
| 93 |
+
optimizer,
|
| 94 |
+
n_epochs=20):
|
| 95 |
+
|
| 96 |
+
# Global variable
|
| 97 |
+
N_test = len(cifar_testset)
|
| 98 |
+
accuracy_list = []
|
| 99 |
+
train_loss_list = []
|
| 100 |
+
model = model.to(mps_device)
|
| 101 |
+
train_cost_list = []
|
| 102 |
+
val_cost_list = []
|
| 103 |
+
|
| 104 |
+
for epoch in range(n_epochs):
|
| 105 |
+
train_COST = 0
|
| 106 |
+
print(f"Training Epoch: {epoch}")
|
| 107 |
+
for x, y in tqdm(train_loader):
|
| 108 |
+
x = x.to(mps_device)
|
| 109 |
+
y = y.to(mps_device)
|
| 110 |
+
model.train()
|
| 111 |
+
optimizer.zero_grad()
|
| 112 |
+
z = model(x)
|
| 113 |
+
loss = criterion(z, y)
|
| 114 |
+
loss.backward()
|
| 115 |
+
optimizer.step()
|
| 116 |
+
train_COST += loss.item()
|
| 117 |
+
|
| 118 |
+
train_COST = train_COST / len(train_loader)
|
| 119 |
+
train_cost_list.append(train_COST)
|
| 120 |
+
correct = 0
|
| 121 |
+
print(f"Validation Loop")
|
| 122 |
+
# Perform the prediction on the validation data
|
| 123 |
+
val_COST = 0
|
| 124 |
+
for x_test, y_test in tqdm(validation_loader):
|
| 125 |
+
model.eval()
|
| 126 |
+
x_test = x_test.to(mps_device)
|
| 127 |
+
y_test = y_test.to(mps_device)
|
| 128 |
+
z = model(x_test)
|
| 129 |
+
val_loss = criterion(z, y_test)
|
| 130 |
+
scheduler.step(val_loss)
|
| 131 |
+
_, yhat = torch.max(z.data, 1)
|
| 132 |
+
correct += (yhat == y_test).sum().item()
|
| 133 |
+
val_COST += val_loss.item()
|
| 134 |
+
|
| 135 |
+
val_COST = val_COST / len(validation_loader)
|
| 136 |
+
val_cost_list.append(val_COST)
|
| 137 |
+
accuracy = correct / N_test
|
| 138 |
+
accuracy_list.append(accuracy)
|
| 139 |
+
|
| 140 |
+
print("--> Epoch Number : {}".format(epoch + 1),
|
| 141 |
+
" | Training Loss : {}".format(round(train_COST, 4)),
|
| 142 |
+
" | Validation Loss : {}".format(round(val_COST, 4)),
|
| 143 |
+
" | Validation Accuracy : {}%".format(round(accuracy * 100, 2)))
|
| 144 |
+
|
| 145 |
+
return accuracy_list, train_cost_list, val_cost_list, model
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
accuracy_list_normalv5, train_cost_listv5, val_cost_listv5, model_to_save = train_model(
|
| 149 |
+
model=model_check, n_epochs=100, train_loader=trainLoader, validation_loader=testLoader, optimizer=optimizer)
|
| 150 |
+
|
| 151 |
+
torch.save(model_to_save.state_dict(), "CIFAR.pt")
|
| 152 |
+
model = torch.load("CIFAR.pt")
|
| 153 |
+
model.eval()
|
| 154 |
+
|
| 155 |
+
pred = model(next(iter(testLoader))[0])
|
| 156 |
+
print(pred[0][0], next(iter(testLoader))[1][0])
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 Jyotin Goel
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
MidTerm_Report.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e171b9314127b3662ce813589398106b1c1dbabd996bc36aadc475cfda02ed8d
|
| 3 |
+
size 1110513
|
Model/CIFAR.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9fc375fa01813882af9e2a88654b765fb58bd8cd2f48731963ffa56fb1652434
|
| 3 |
+
size 94418218
|
Model/Resnet50_train_features.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:50516ceab17cd7ca8eaa2738b30e201fad30a89972a256cc370a56236a70c427
|
| 3 |
+
size 2001185
|
Model/__pycache__/centroid_app.cpython-311.pyc
ADDED
|
Binary file (5.62 kB). View file
|
|
|
Model/ann.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6375c0313b719e85a904bc2d733ee741f084599b3f5671b8dd4b978a5a4101b1
|
| 3 |
+
size 6832736
|
Model/centroid_app.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import pickle
|
| 3 |
+
import random
|
| 4 |
+
import numpy as np
|
| 5 |
+
import cv2
|
| 6 |
+
from tensorflow.keras import models, layers
|
| 7 |
+
|
| 8 |
+
# Load the trained model architecture
|
| 9 |
+
def create_resnet18():
|
| 10 |
+
model = models.Sequential()
|
| 11 |
+
model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=(32, 32, 3)))
|
| 12 |
+
model.add(layers.MaxPooling2D((2, 2)))
|
| 13 |
+
model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))
|
| 14 |
+
model.add(layers.MaxPooling2D((2, 2)))
|
| 15 |
+
model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
|
| 16 |
+
model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
|
| 17 |
+
model.add(layers.MaxPooling2D((2, 2)))
|
| 18 |
+
model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))
|
| 19 |
+
model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))
|
| 20 |
+
model.add(layers.MaxPooling2D((2, 2)))
|
| 21 |
+
model.add(layers.Flatten())
|
| 22 |
+
model.add(layers.Dense(512, activation='relu'))
|
| 23 |
+
model.add(layers.Dense(10, activation='softmax'))
|
| 24 |
+
return model
|
| 25 |
+
|
| 26 |
+
# Load the pretrained weights
|
| 27 |
+
def load_pretrained_weights(model, weights_path):
|
| 28 |
+
model.load_weights(weights_path)
|
| 29 |
+
|
| 30 |
+
# Function to unpickle a file
|
| 31 |
+
def unpickle(file):
|
| 32 |
+
with open(file, 'rb') as fo:
|
| 33 |
+
dict = pickle.load(fo, encoding='bytes')
|
| 34 |
+
return dict
|
| 35 |
+
|
| 36 |
+
# Function to load images from the unpickled data batch file of a specific class
|
| 37 |
+
def load_class_images(class_index, train_batches):
|
| 38 |
+
images = []
|
| 39 |
+
for batch in train_batches:
|
| 40 |
+
if b'data' in batch and b'labels' in batch:
|
| 41 |
+
data = batch[b'data']
|
| 42 |
+
labels = batch[b'labels']
|
| 43 |
+
for i, label in enumerate(labels):
|
| 44 |
+
if label == class_index:
|
| 45 |
+
img = data[i].reshape(3, 32, 32).transpose(1, 2, 0) # Reshape and transpose the image
|
| 46 |
+
images.append(img)
|
| 47 |
+
return images
|
| 48 |
+
|
| 49 |
+
# Function to calculate the distance of the mean embeddings with a query image
|
| 50 |
+
def classify_query(query_image, model, mean_embeddings):
|
| 51 |
+
query_embedding = model.predict(np.expand_dims(query_image, axis=0))
|
| 52 |
+
distances = [np.linalg.norm(query_embedding.flatten() - mean_embedding) for mean_embedding in mean_embeddings]
|
| 53 |
+
predicted_class = np.argmin(distances)
|
| 54 |
+
return predicted_class
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def retrieve(query_image,k=3):
|
| 58 |
+
model = create_resnet18()
|
| 59 |
+
load_pretrained_weights(model, 'Model\pretrained_model_weights.h5')
|
| 60 |
+
|
| 61 |
+
mean_embeddings = pickle.load(open('Model\data\mean_embeddings.pkl', 'rb'))
|
| 62 |
+
|
| 63 |
+
# query_image_path = '/content/airplane_8925.png'
|
| 64 |
+
# query_image = cv2.imread(query_image_path)
|
| 65 |
+
query_image = cv2.resize(query_image, (32, 32)) / 255.0 # Resize and normalize the image
|
| 66 |
+
|
| 67 |
+
predicted_class = classify_query(query_image, model, mean_embeddings)
|
| 68 |
+
# print("Predicted Class:", predicted_class)
|
| 69 |
+
|
| 70 |
+
# Load random images of the predicted class
|
| 71 |
+
train_batches = [unpickle(rf"Model\data\data_batch_{i}") for i in range(1,6)]
|
| 72 |
+
class_images = load_class_images(predicted_class, train_batches)
|
| 73 |
+
|
| 74 |
+
if class_images:
|
| 75 |
+
random_images = random.sample(class_images, k) # Select 3 random images
|
| 76 |
+
return random_images
|
| 77 |
+
else:
|
| 78 |
+
print("No images found for the predicted class.")
|
| 79 |
+
|
| 80 |
+
|
Model/data/batches.meta
ADDED
|
Binary file (158 Bytes). View file
|
|
|
Model/data/data_batch_1
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:54636561a3ce25bd3e19253c6b0d8538147b0ae398331ac4a2d86c6d987368cd
|
| 3 |
+
size 31035704
|
Model/data/data_batch_2
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:766b2cef9fbc745cf056b3152224f7cf77163b330ea9a15f9392beb8b89bc5a8
|
| 3 |
+
size 31035320
|
Model/data/data_batch_3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0f00d98ebfb30b3ec0ad19f9756dc2630b89003e10525f5e148445e82aa6a1f9
|
| 3 |
+
size 31035999
|
Model/data/data_batch_4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3f7bb240661948b8f4d53e36ec720d8306f5668bd0071dcb4e6c947f78e9682b
|
| 3 |
+
size 31035696
|
Model/data/data_batch_5
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d91802434d8376bbaeeadf58a737e3a1b12ac839077e931237e0dcd43adcb154
|
| 3 |
+
size 31035623
|
Model/data/mean_embeddings.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:272f9a5f5bd64519f68c8ac5397a6d98a920175d15ad67e19af110f54ac9dfaf
|
| 3 |
+
size 952
|
Model/data/test_batch
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f53d8d457504f7cff4ea9e021afcf0e0ad8e24a91f3fc42091b8adef61157831
|
| 3 |
+
size 31035526
|
Model/data/train_embeddings_resnet18.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8238e4c8fb43aa0175def221191fe3c9f7e137683746e1f33086681488a266f4
|
| 3 |
+
size 2000162
|
Model/pretrained_model_weights.h5
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3c60f7cc223c7256a4a3de446b2820d5cd05b552dcbe4a60fcea2f83465f1538
|
| 3 |
+
size 6735320
|
Model/resnet18.h5
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d8415a29fcbd661a67bc004b26e9bcccc7a2133bffa99cd1dc1811c8d1a37dd5
|
| 3 |
+
size 20154704
|
Model/test.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torchvision.datasets as datasets
|
| 5 |
+
import torchvision.transforms as transforms
|
| 6 |
+
import torchvision.models as models
|
| 7 |
+
|
| 8 |
+
os.system("gdown --id 1qO2OLR7skDibo1LaMKD3CiOl_jaCTZ0h")
|
| 9 |
+
|
| 10 |
+
class JHARMNet(nn.Module):
|
| 11 |
+
def __init__(self, pretrained_model, num_classes):
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.premodel = pretrained_model
|
| 14 |
+
self.premodel.fc = nn.Linear(2048, num_classes)
|
| 15 |
+
nn.init.xavier_uniform_(self.premodel.fc.weight)
|
| 16 |
+
|
| 17 |
+
def forward(self, x):
|
| 18 |
+
out = self.premodel(x)
|
| 19 |
+
return out
|
| 20 |
+
|
| 21 |
+
class HiddenLayer(nn.Module):
|
| 22 |
+
def __init__(self, pretrained_model):
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.premodel = pretrained_model
|
| 25 |
+
self.new_layer = nn.Sequential(
|
| 26 |
+
nn.Linear(1000, 512),
|
| 27 |
+
nn.LeakyReLU(),
|
| 28 |
+
nn.Linear(512, 512),
|
| 29 |
+
nn.LeakyReLU(),
|
| 30 |
+
nn.Linear(512, 256),
|
| 31 |
+
nn.LeakyReLU(),
|
| 32 |
+
nn.Linear(256, 10)
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
def forward(self, x):
|
| 36 |
+
out = self.premodel(x)
|
| 37 |
+
out_new_layer = self.new_layer(out)
|
| 38 |
+
return out_new_layer
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
resnet = models.resnet50(pretrained=True)
|
| 42 |
+
|
| 43 |
+
x = torch.randn((2, 3, 32, 32))
|
| 44 |
+
|
| 45 |
+
print(x.shape)
|
| 46 |
+
model_check = HiddenLayer(resnet)
|
| 47 |
+
model_check.load_state_dict(torch.load("CIFAR_end_hll.pt"))
|
| 48 |
+
model_check.eval()
|
| 49 |
+
print(model_check.forward(x).shape)
|
Preprocessing/cifar_eda.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Preprocessing/k_means.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sklearn.cluster import KMeans
|
| 2 |
+
from sklearn.utils import shuffle
|
| 3 |
+
from sklearn.metrics import pairwise_distances_argmin
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import numpy as np
|
| 6 |
+
import matplotlib.image as mpimg
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import random
|
| 9 |
+
|
| 10 |
+
def kmeans_image_compression(image, n_colors=8):
|
| 11 |
+
|
| 12 |
+
# Convert image to numpy array
|
| 13 |
+
image_np = np.array(image)
|
| 14 |
+
|
| 15 |
+
# Reshape the image to a 2D array of pixels
|
| 16 |
+
image_reshaped = image_np.reshape(-1, 3)
|
| 17 |
+
# print(image_reshaped.shape)
|
| 18 |
+
|
| 19 |
+
# Shuffle the pixels
|
| 20 |
+
image_reshaped_sample = shuffle(image_reshaped, random_state=0)[:image_reshaped.shape[0]]
|
| 21 |
+
|
| 22 |
+
# Apply KMeans clustering using sklearn Library
|
| 23 |
+
kmeans = KMeans(n_clusters=n_colors, random_state=0,n_init=10).fit(image_reshaped_sample)
|
| 24 |
+
labels = kmeans.predict(image_reshaped)
|
| 25 |
+
centroids = kmeans.cluster_centers_
|
| 26 |
+
|
| 27 |
+
# Replace each pixel with its corresponding centroid color
|
| 28 |
+
compressed_image = centroids[labels]
|
| 29 |
+
|
| 30 |
+
# Reshape the compressed image back to its original shape
|
| 31 |
+
compressed_image = compressed_image.reshape(image_np.shape)
|
| 32 |
+
|
| 33 |
+
# Convert the compressed image array to uint8 data type
|
| 34 |
+
compressed_image = compressed_image.astype(np.uint8)
|
| 35 |
+
return compressed_image
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
compressed_train_data =[]
|
| 39 |
+
for x,y in train_data:
|
| 40 |
+
compressed_x=kmeans_image_compression(x,6)
|
| 41 |
+
compressed_train_data.append((compressed_x,y))
|
| 42 |
+
|
| 43 |
+
compressed_test_data = []
|
| 44 |
+
for x,y in test_data:
|
| 45 |
+
compressed_x=kmeans_image_compression(x,6)
|
| 46 |
+
compressed_test_data.append((compressed_x,y))
|
Preprocessing/utils.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sklearn.decomposition import PCA
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def pca_for_rgb(img, n_components):
|
| 6 |
+
red = img[:, :, 0]
|
| 7 |
+
green = img[:, :, 1]
|
| 8 |
+
blue = img[:, :, 2]
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
pca = PCA(n_components=n_components)
|
| 12 |
+
|
| 13 |
+
red_transformed = pca.fit_transform(red)
|
| 14 |
+
red_inverted = pca.inverse_transform(red_transformed)
|
| 15 |
+
|
| 16 |
+
green_transformed = pca.fit_transform(green)
|
| 17 |
+
green_inverted = pca.inverse_transform(green_transformed)
|
| 18 |
+
|
| 19 |
+
blue_transformed = pca.fit_transform(blue)
|
| 20 |
+
blue_inverted = pca.inverse_transform(blue_transformed)
|
| 21 |
+
img_compressed = (
|
| 22 |
+
np.dstack(
|
| 23 |
+
(red_inverted,
|
| 24 |
+
green_inverted,
|
| 25 |
+
blue_inverted))).astype(
|
| 26 |
+
np.uint8)
|
| 27 |
+
return img_compressed
|
| 28 |
+
except BaseException:
|
| 29 |
+
print("Kindly put valid number of components")
|
README.md
CHANGED
|
@@ -1,12 +1,44 @@
|
|
| 1 |
-
---
|
| 2 |
-
title:
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Deep_Image_Search_System
|
| 3 |
+
app_file: ui_gradio.py
|
| 4 |
+
sdk: gradio
|
| 5 |
+
sdk_version: 4.36.1
|
| 6 |
+
---
|
| 7 |
+
# CSL2050 CourseProject
|
| 8 |
+
|
| 9 |
+
## TOPIC : Image Retrieval
|
| 10 |
+
|
| 11 |
+
## Team Members:
|
| 12 |
+
- Akshat Jain (B22CS096)
|
| 13 |
+
- Harshiv Shah (B22CS098)
|
| 14 |
+
- Mehta Jay Kamalkumar (B22CS034)
|
| 15 |
+
- Jyotin Goel (B22AI063)
|
| 16 |
+
- Rhythm Baghel (B22CS042)
|
| 17 |
+
|
| 18 |
+
## Repository Structure
|
| 19 |
+
|
| 20 |
+
```bash
|
| 21 |
+
βββ Experiments
|
| 22 |
+
β βββ cifar_10_knn.ipynb
|
| 23 |
+
β βββ cifar_10_pca_knn.ipynb
|
| 24 |
+
β βββ cnn-resnet34-cifar10.ipynb
|
| 25 |
+
β βββ cifar_10_mean.ipynb
|
| 26 |
+
β βββ cifar_10_centroid.ipynb
|
| 27 |
+
β βββ HOG+KNN.ipynb
|
| 28 |
+
β βββ Logs
|
| 29 |
+
β β βββ RESNET50_CIFAR.out
|
| 30 |
+
β βββ PCA+HOG+KNN.ipynb
|
| 31 |
+
β βββ train.py
|
| 32 |
+
βββ LICENSE
|
| 33 |
+
βββ MidTerm_Report.pdf
|
| 34 |
+
βββ Model
|
| 35 |
+
β βββ CIFAR.pt
|
| 36 |
+
β βββ test.py
|
| 37 |
+
βββGradio_Interface
|
| 38 |
+
β βββ ui_gradio.py
|
| 39 |
+
βββ Preprocessing
|
| 40 |
+
β βββ cifar_eda.ipynb
|
| 41 |
+
β βββ k_means.py
|
| 42 |
+
β βββ utils.py
|
| 43 |
+
βββ README.md
|
| 44 |
+
```
|
__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
print("Check")
|
flagged/log.csv
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Query Image,Retrieved Images,flag,username,timestamp
|
| 2 |
+
flagged\Query Image\184c0788f490a0624d75\Screenshot 2024-04-20 153136.png,"[{""image"": ""flagged\\Retrieved Images\\6efd8d4fa439705c4bec\\image.webp"", ""caption"": null}, {""image"": ""flagged\\Retrieved Images\\911f2dbf15da7790c9f6\\image.webp"", ""caption"": null}, {""image"": ""flagged\\Retrieved Images\\d194d62692564b86e213\\image.webp"", ""caption"": null}]",,,2024-04-20 15:56:52.109627
|
images/akshat.jpeg
ADDED
|
images/dog.jpg
ADDED
|
images/harshiv.jpg
ADDED
|
images/horse.jpg
ADDED
|