AJain1234 committed on
Commit
4e78e8c
Β·
verified Β·
1 Parent(s): c82d05c

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes. Β  See raw diff
Files changed (50) hide show
  1. .gitattributes +8 -0
  2. Experiments/HOG+KNN.ipynb +0 -0
  3. Experiments/Logs/RESNET50_3HLL_CIFAR.out +0 -0
  4. Experiments/Logs/RESNET50_CIFAR.out +0 -0
  5. Experiments/PCA+HOG+KNN.ipynb +0 -0
  6. Experiments/Resnet50_classification.py +128 -0
  7. Experiments/__init__.py +1 -0
  8. Experiments/__pycache__/Resnet50_classification.cpython-311.pyc +0 -0
  9. Experiments/__pycache__/__init__.cpython-311.pyc +0 -0
  10. Experiments/__pycache__/ann_cifar.cpython-311.pyc +0 -0
  11. Experiments/ann_cifar.py +91 -0
  12. Experiments/centroid_res18.py +85 -0
  13. Experiments/cifar_10_knn.ipynb +0 -0
  14. Experiments/cifar_10_knn_resnet18_73_percent.ipynb +197 -0
  15. Experiments/cifar_10_pca_knn.ipynb +0 -0
  16. Experiments/cifar_10_res50_m.py +37 -0
  17. Experiments/cnn-resnet34-cifar10.ipynb +749 -0
  18. Experiments/quicknet_cifar10_centroid.ipynb +698 -0
  19. Experiments/quicknet_knn.py +37 -0
  20. Experiments/train_resnet50_3hll.py +163 -0
  21. Experiments/train_resnet50_iter_1.py +156 -0
  22. LICENSE +21 -0
  23. MidTerm_Report.pdf +3 -0
  24. Model/CIFAR.pt +3 -0
  25. Model/Resnet50_train_features.pt +3 -0
  26. Model/__pycache__/centroid_app.cpython-311.pyc +0 -0
  27. Model/ann.pt +3 -0
  28. Model/centroid_app.py +80 -0
  29. Model/data/batches.meta +0 -0
  30. Model/data/data_batch_1 +3 -0
  31. Model/data/data_batch_2 +3 -0
  32. Model/data/data_batch_3 +3 -0
  33. Model/data/data_batch_4 +3 -0
  34. Model/data/data_batch_5 +3 -0
  35. Model/data/mean_embeddings.pkl +3 -0
  36. Model/data/test_batch +3 -0
  37. Model/data/train_embeddings_resnet18.pkl +3 -0
  38. Model/pretrained_model_weights.h5 +3 -0
  39. Model/resnet18.h5 +3 -0
  40. Model/test.py +49 -0
  41. Preprocessing/cifar_eda.ipynb +0 -0
  42. Preprocessing/k_means.py +46 -0
  43. Preprocessing/utils.py +29 -0
  44. README.md +44 -12
  45. __init__.py +1 -0
  46. flagged/log.csv +2 -0
  47. images/akshat.jpeg +0 -0
  48. images/dog.jpg +0 -0
  49. images/harshiv.jpg +0 -0
  50. images/horse.jpg +0 -0
.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ images/image2image.png filter=lfs diff=lfs merge=lfs -text
37
+ MidTerm_Report.pdf filter=lfs diff=lfs merge=lfs -text
38
+ Model/data/data_batch_1 filter=lfs diff=lfs merge=lfs -text
39
+ Model/data/data_batch_2 filter=lfs diff=lfs merge=lfs -text
40
+ Model/data/data_batch_3 filter=lfs diff=lfs merge=lfs -text
41
+ Model/data/data_batch_4 filter=lfs diff=lfs merge=lfs -text
42
+ Model/data/data_batch_5 filter=lfs diff=lfs merge=lfs -text
43
+ Model/data/test_batch filter=lfs diff=lfs merge=lfs -text
Experiments/HOG+KNN.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Experiments/Logs/RESNET50_3HLL_CIFAR.out ADDED
The diff for this file is too large to render. See raw diff
 
Experiments/Logs/RESNET50_CIFAR.out ADDED
The diff for this file is too large to render. See raw diff
 
Experiments/PCA+HOG+KNN.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Experiments/Resnet50_classification.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision.datasets as datasets
2
+ import matplotlib.pyplot as plt
3
+ import tensorflow
4
+ import numpy as np
5
+ import warnings
6
+ warnings.filterwarnings("ignore")
7
+ import torchvision.models as models
8
+ import torch
9
+ from tqdm import tqdm
10
+ import torch.nn as nn
11
+ import torchvision.transforms as v2
12
+ from torch.utils.data import DataLoader
13
+ import torchvision.datasets as datasets
14
+ from sklearn.metrics.pairwise import euclidean_distances
15
+ from PIL import Image,ImageFilter
16
+ import torchvision.transforms as transforms
17
+ import pickle
18
+ import os
19
+ os.system("gdown --id 1qO2OLR7skDibo1LaMKD3CiOl_jaCTZ0h")
20
+
21
+
22
+ IMAGE_SIZE = 224
23
+ mean, std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]
24
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25
+
26
+
27
def unpickle(file):
    """Load a CIFAR-10 batch file pickled with bytes-encoded keys.

    Parameters
    ----------
    file : str
        Path to the pickled CIFAR-10 batch file.

    Returns
    -------
    dict
        The unpickled batch dictionary (keys are ``bytes``, e.g. b'data').
    """
    # NOTE: pickle.load on untrusted input is unsafe; these files are the
    # trusted CIFAR-10 distribution batches.
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')  # renamed: don't shadow builtin `dict`
    return batch
31
+
32
class HiddenLayer(nn.Module):
    """Classification head stacked on top of a pretrained backbone.

    The backbone's 1000-dim output is mapped to 10 CIFAR-10 logits
    through three LeakyReLU-activated hidden layers.
    """

    def __init__(self, pretrained_model):
        super().__init__()
        self.premodel = pretrained_model
        # Head dimensions: 1000 -> 512 -> 512 -> 256 -> 10.
        head = [
            nn.Linear(1000, 512), nn.LeakyReLU(),
            nn.Linear(512, 512), nn.LeakyReLU(),
            nn.Linear(512, 256), nn.LeakyReLU(),
            nn.Linear(256, 10),
        ]
        self.new_layer = nn.Sequential(*head)

    def forward(self, x):
        """Run the backbone, then the classification head."""
        backbone_out = self.premodel(x)
        return self.new_layer(backbone_out)
50
+
51
+
52
def predict(features_path,image):
    """Classify `image` with the ResNet50+HiddenLayer model and gather the
    training-set artifacts needed for retrieval.

    Parameters
    ----------
    features_path : str
        Path to a torch file holding one feature vector per training image.
    image : numpy.ndarray
        Query image convertible by ``PIL.Image.fromarray``.

    Returns
    -------
    tuple
        ``(test_label, z, features, class_images_dict, train_x)`` where
        ``test_label`` is the argmax class tensor, ``z`` the raw logits,
        ``features`` the loaded training feature vectors,
        ``class_images_dict`` maps class label -> list of training indices,
        and ``train_x`` is the list of raw 32x32x3 training images.
    """
    # NOTE(review): backslash raw-string paths are Windows-only — consider
    # os.path.join for portability.
    batch1 = unpickle(r"Model\data\data_batch_1")
    batch2 = unpickle(r"Model\data\data_batch_2")
    batch3 = unpickle(r"Model\data\data_batch_3")
    batch4 = unpickle(r"Model\data\data_batch_4")
    batch5 = unpickle(r"Model\data\data_batch_5")
    test_batch = unpickle(r"Model\data\test_batch")  # loaded but never used below
    train_batch = [batch1,batch2,batch3,batch4,batch5]
    train_y = []
    train_x = []
    for batch in train_batch:
        y_data = batch[b'labels']
        x_data = batch[b'data']
        # CIFAR rows are flat 3072-vectors; reshape to NCHW, then to NHWC.
        x_data = x_data.reshape(len(x_data),3,32,32).transpose(0,2,3,1)

        for i in range(len(y_data)):
            train_y.append(y_data[i])

        for i in range(len(y_data)):
            train_x.append(x_data[i])

    features = torch.load(features_path)

    # Pair each precomputed feature vector with its training label.
    resnet_train_data = []
    for i in range(len(features)):
        resnet_train_data.append((features[i],train_y[i]))

    # Index training samples by class label for class-restricted retrieval.
    class_images_dict = {}
    for batch_idx, (images, labels) in enumerate(resnet_train_data):
        if labels not in class_images_dict:
            class_images_dict[labels] = []
        class_images_dict[labels].append(batch_idx)

    # Same preprocessing pipeline the backbone expects (224x224, CIFAR stats).
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    pil_image = Image.fromarray(image)
    image_tensor = transform(pil_image).unsqueeze(0)
    resnet = models.resnet50(pretrained=True)
    model_check = HiddenLayer(resnet)
    # NOTE(review): checkpoint path "CIFAR_end_hll.pt" is relative to the CWD
    # — confirm where callers run this from.
    model_check.load_state_dict(torch.load("CIFAR_end_hll.pt"))
    model_check.eval()  # inference mode: freeze batch-norm stats / dropout

    with torch.no_grad():
        z = model_check(image_tensor)
        _, test_label = torch.max(z, 1)

    return test_label,z,features,class_images_dict,train_x
102
+
103
+
104
+
105
def retrieve(image,k,feature_path=r"Model\Resnet50_train_features.pt"):
    """Return the `k` training images closest to `image` in feature space.

    The query is classified first; the nearest-neighbour search is then
    restricted to training samples of the predicted class.

    Parameters
    ----------
    image : numpy.ndarray
        Query image convertible by ``PIL.Image.fromarray``.
    k : int
        Number of neighbours to return.
    feature_path : str, optional
        Path to the precomputed training feature vectors.

    Returns
    -------
    list[PIL.Image.Image]
        The k nearest training images, sharpened for display.
    """
    print(image.shape)  # debug trace of the incoming query shape
    test_label,z,features,class_images_dict,train_x = predict(feature_path,image)
    # Restrict the search to training samples of the predicted class.
    class_indices = class_images_dict[test_label.item()]
    class_features = [(features[idx], idx) for idx in class_indices]
    test_features = z.cpu().detach().numpy()
    distances = euclidean_distances(test_features, [f[0].cpu() for f in class_features])
    # Indices of the k smallest distances within the class subset.
    sorted_indices = np.argsort(distances.flatten())[:k]
    closest_indices = [class_features[idx][1] for idx in sorted_indices]
    retrieved_images = []
    for idx in closest_indices:  # enumerate removed: the counter was unused
        closest_image = Image.fromarray(train_x[idx])
        # Sharpen the 32x32 thumbnails so they display a little crisper.
        retrieved_images.append(closest_image.filter(ImageFilter.SHARPEN))

    return retrieved_images
121
+
122
+
123
+
124
+
125
+
126
+
127
+ # test_image = Image.open("/kaggle/input/planes/download.jpeg")
128
+ # retrieved_images = retrieve(test_image,3)
Experiments/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ print("Initializing")
Experiments/__pycache__/Resnet50_classification.cpython-311.pyc ADDED
Binary file (8.55 kB). View file
 
Experiments/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (222 Bytes). View file
 
Experiments/__pycache__/ann_cifar.cpython-311.pyc ADDED
Binary file (6.22 kB). View file
 
Experiments/ann_cifar.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision.datasets as datasets
2
+ import matplotlib.pyplot as plt
3
+ import tensorflow
4
+ import numpy as np
5
+ import warnings
6
+ warnings.filterwarnings("ignore")
7
+ import torchvision.models as models
8
+ import torch
9
+ from tqdm import tqdm
10
+ import torch.nn as nn
11
+ import torchvision.transforms as v2
12
+ from torch.utils.data import DataLoader
13
+ import torchvision.datasets as datasets
14
+ from sklearn.metrics.pairwise import euclidean_distances
15
+ from PIL import Image,ImageFilter
16
+ import torchvision.transforms as transforms
17
+ import pickle
18
+ import os
19
+ import cv2
20
+
21
+
22
def unpickle(file):
    """Load a CIFAR-10 batch file pickled with bytes-encoded keys.

    Parameters
    ----------
    file : str
        Path to the pickled CIFAR-10 batch file.

    Returns
    -------
    dict
        The unpickled batch dictionary (keys are ``bytes``, e.g. b'data').
    """
    # NOTE: pickle.load on untrusted input is unsafe; these files are the
    # trusted CIFAR-10 distribution batches.
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')  # renamed: don't shadow builtin `dict`
    return batch
26
+
27
+
28
class ANN(nn.Module):
    """Fully-connected CIFAR-10 classifier: 3072 -> 512 -> 256 -> 10 logits."""

    def __init__(self):
        super(ANN, self).__init__()
        self.layer_1 = nn.Linear(32 * 32 * 3, 512)  # flattened 32x32 RGB input
        self.relu1 = nn.ReLU()
        self.layer_2 = nn.Linear(512, 256)
        self.relu2 = nn.ReLU()
        self.layer_3 = nn.Linear(256, 10)  # one logit per CIFAR-10 class

    def forward(self, x):
        """Flatten the input and run it through the three linear layers."""
        flat = x.view(-1, 32 * 32 * 3)
        hidden = self.relu1(self.layer_1(flat))
        hidden = self.relu2(self.layer_2(hidden))
        return self.layer_3(hidden)
45
+
46
def retrieve(test_data,model,k_value=10):
    """Retrieve the `k_value` CIFAR-10 training images nearest to `test_data`.

    The query is classified with `model`; the pixel-space nearest-neighbour
    search is then restricted to training images of the predicted class.

    Parameters
    ----------
    test_data : numpy.ndarray
        Query image (any size; resized to 32x32 here).
    model : torch.nn.Module
        Trained classifier whose argmax output is a 0-based CIFAR-10 label.
    k_value : int, optional
        Number of neighbours to return (default 10).

    Returns
    -------
    list[numpy.ndarray]
        The k nearest 32x32x3 training images of the predicted class.
    """
    test_data = cv2.resize(test_data,(32,32))
    test_data = torch.tensor(test_data,dtype=torch.float32).unsqueeze(dim=0)
    print(test_data.shape)  # debug trace
    outputs = model(test_data)
    _, predicted = torch.max(outputs.data, 1)

    # NOTE(review): backslash raw-string paths are Windows-only.
    batch1 = unpickle(r"Model\data\data_batch_1")
    batch2 = unpickle(r"Model\data\data_batch_2")
    batch3 = unpickle(r"Model\data\data_batch_3")
    batch4 = unpickle(r"Model\data\data_batch_4")
    batch5 = unpickle(r"Model\data\data_batch_5")
    # test_batch = unpickle(r"Model\data\test_batch")
    train_batch = [batch1,batch2,batch3,batch4,batch5]

    train_data = []
    for batch in train_batch:
        y_data = batch[b'labels']
        x_data = batch[b'data']
        # CIFAR rows are flat 3072-vectors; reshape to NCHW, then to NHWC.
        x_data = x_data.reshape(len(x_data),3,32,32).transpose(0,2,3,1)

        for i in range(len(y_data)):
            train_data.append((x_data[i],y_data[i]))

    # BUGFIX: CIFAR-10 labels are already 0-based, so compare `y` directly
    # (the old `y-1 == predicted` shifted every class by one), and compare
    # against the Python scalar via .item() rather than the tensor.
    x_with_specific_y = [x for x, y in train_data if y == predicted.item()]
    x_with_specific_y = np.array(x_with_specific_y)
    # (removed dead `test_data.squeeze().permute(1,2,0)` — its result was discarded)
    point = np.array(test_data.squeeze())  # (32, 32, 3) query in pixel space

    distance_with_label_and_index = []

    for i,x_train in enumerate(x_with_specific_y):
        # BUGFIX: measure distance on the whole image; the old code used
        # x_train[1], i.e. only the second pixel row.
        train_point = np.array(x_train)
        distance_with_label_and_index.append((i,np.linalg.norm(point-train_point)))

    # Sort candidates by distance and keep the k nearest.
    distance_with_label_and_index_sorted=sorted(distance_with_label_and_index,key=lambda x: x[1])
    k_nearest_points = distance_with_label_and_index_sorted[0:k_value]

    retrieved_images = []
    for index, distance in k_nearest_points:
        retrieved_images.append(np.array(x_with_specific_y[index]))

    return retrieved_images
Experiments/centroid_res18.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import random
3
+ import numpy as np
4
+ import cv2
5
+ from tensorflow.keras import models, layers
6
+
7
+ # Load the trained model architecture
8
def create_resnet18():
    """Build the small sequential CNN used as the embedding/classifier model.

    (Despite the name, this is a plain VGG-style convolutional stack,
    not a true ResNet-18.)
    """
    model = models.Sequential()
    # Two single-conv 64-filter stages, each followed by 2x2 pooling.
    model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=(32, 32, 3)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))
    model.add(layers.MaxPooling2D((2, 2)))
    # Two double-conv stages with 128 and 256 filters.
    for filters in (128, 256):
        model.add(layers.Conv2D(filters, (3, 3), activation='relu', padding='same'))
        model.add(layers.Conv2D(filters, (3, 3), activation='relu', padding='same'))
        model.add(layers.MaxPooling2D((2, 2)))
    # Dense classification head: 512 hidden units, 10-way softmax.
    model.add(layers.Flatten())
    model.add(layers.Dense(512, activation='relu'))
    model.add(layers.Dense(10, activation='softmax'))
    return model
24
+
25
# Load the pretrained weights
def load_pretrained_weights(model, weights_path):
    """Load weights from `weights_path` (HDF5 checkpoint) into `model` in place."""
    model.load_weights(weights_path)
28
+
29
# Function to unpickle a file
def unpickle(file):
    """Load a CIFAR-10 batch file pickled with bytes-encoded keys.

    Parameters
    ----------
    file : str
        Path to the pickled CIFAR-10 batch file.

    Returns
    -------
    dict
        The unpickled batch dictionary (keys are ``bytes``, e.g. b'data').
    """
    # NOTE: pickle.load on untrusted input is unsafe; these files are the
    # trusted CIFAR-10 distribution batches.
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')  # renamed: don't shadow builtin `dict`
    return batch
34
+
35
# Collect every training image belonging to a given class.
def load_class_images(class_index, train_batches):
    """Return all images labelled `class_index` across the given batches.

    Each batch is a CIFAR-style dict with b'data' (rows of 3072 values)
    and b'labels'; matching rows are reshaped to 32x32x3 HWC arrays.
    """
    images = []
    for batch in train_batches:
        if b'data' not in batch or b'labels' not in batch:
            continue  # skip batches missing the expected keys
        data = batch[b'data']
        labels = batch[b'labels']
        for row, label in zip(data, labels):
            if label == class_index:
                # Flat 3072-vector -> CHW -> HWC image.
                images.append(row.reshape(3, 32, 32).transpose(1, 2, 0))
    return images
47
+
48
# Classify a query image by nearest mean embedding.
def classify_query(query_image, model, mean_embeddings):
    """Return the class index whose mean embedding is closest to the query's."""
    batched = np.expand_dims(query_image, axis=0)  # model expects a batch dim
    query_embedding = model.predict(batched).flatten()
    distances = []
    for mean_embedding in mean_embeddings:
        distances.append(np.linalg.norm(query_embedding - mean_embedding))
    return np.argmin(distances)
54
+
55
# NOTE(review): Colab-only import placed mid-file; move to the top-of-file
# imports if this script ever runs outside Colab.
from google.colab.patches import cv2_imshow

def main():
    """End-to-end demo: classify a query image by nearest mean embedding,
    then display 3 random training images of the predicted class."""
    model = create_resnet18()
    load_pretrained_weights(model, 'pretrained_model_weights.h5')

    # NOTE(review): file handle is never closed — prefer `with open(...)`.
    mean_embeddings = pickle.load(open('mean_embeddings.pkl', 'rb'))

    query_image_path = '/content/airplane_8925.png'
    query_image = cv2.imread(query_image_path)
    query_image = cv2.resize(query_image, (32, 32)) / 255.0  # Resize and normalize the image

    predicted_class = classify_query(query_image, model, mean_embeddings)
    print("Predicted Class:", predicted_class)

    # Load random images of the predicted class
    # NOTE(review): `predicted_class+1` assumes 1-based labels on disk while
    # classify_query returns a 0-based argmin — confirm the label convention.
    train_batches = [unpickle(f"/content/data_batch_{i}") for i in range(1,6)]
    class_images = load_class_images(predicted_class+1, train_batches)

    if class_images:
        random_images = random.sample(class_images, 3)  # Select 3 random images
        for img in random_images:
            if img is not None:
                cv2_imshow(img)
            else:
                print("Random image is None.")
    else:
        print("No images found for the predicted class.")

if __name__ == "__main__":
    main()
Experiments/cifar_10_knn.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Experiments/cifar_10_knn_resnet18_73_percent.ipynb ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import tensorflow as tf\n",
10
+ "from scipy.spatial.distance import euclidean\n",
11
+ "from tensorflow.keras import datasets, layers, models\n",
12
+ "import numpy as np\n",
13
+ "\n",
14
+ "\n",
15
+ "(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()\n",
16
+ "\n",
17
+ "\n",
18
+ "train_images = train_images / 255.0\n",
19
+ "test_images = test_images / 255.0\n",
20
+ "\n",
21
+ "\n",
22
+ "def create_resnet18():\n",
23
+ " model = models.Sequential()\n",
24
+ " model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=(32, 32, 3)))\n",
25
+ " model.add(layers.MaxPooling2D((2, 2)))\n",
26
+ " model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))\n",
27
+ " model.add(layers.MaxPooling2D((2, 2)))\n",
28
+ " model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))\n",
29
+ " model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))\n",
30
+ " model.add(layers.MaxPooling2D((2, 2)))\n",
31
+ " model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))\n",
32
+ " model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))\n",
33
+ " model.add(layers.MaxPooling2D((2, 2)))\n",
34
+ " model.add(layers.Flatten())\n",
35
+ " model.add(layers.Dense(512, activation='relu'))\n",
36
+ " model.add(layers.Dense(10, activation='softmax'))\n",
37
+ " return model\n",
38
+ "\n",
39
+ "\n",
40
+ "model = create_resnet18()\n",
41
+ "\n",
42
+ "\n",
43
+ "model.compile(optimizer='adam',\n",
44
+ " loss='sparse_categorical_crossentropy',\n",
45
+ " metrics=['accuracy'])\n",
46
+ "\n",
47
+ "\n",
48
+ "model.fit(train_images, train_labels, epochs=10, batch_size=64, validation_data=(test_images, test_labels))\n",
49
+ "\n",
50
+ "\n",
51
+ "train_embeddings = model.predict(train_images)\n"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": [
60
+ "test_embeddings = model.predict(test_images)\n",
61
+ "test_embeddings.shape"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": null,
67
+ "metadata": {},
68
+ "outputs": [],
69
+ "source": [
70
+ "resnet_train_data = []\n",
71
+ "for i in range(len(train_embeddings)):\n",
72
+ " resnet_train_data.append((train_embeddings[i],train_labels[i]))\n",
73
+ "resnet_test_data = []\n",
74
+ "for i in range(len(test_embeddings)):\n",
75
+ " resnet_test_data.append((test_embeddings[i],test_labels[i]))"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": null,
81
+ "metadata": {},
82
+ "outputs": [],
83
+ "source": [
84
+ "import matplotlib.pyplot as plt\n",
85
+ "def find_knn_and_plot_image(k_value:int =10,train_data=resnet_train_data,test_data=resnet_test_data[0]):\n",
86
+ " correct = 0\n",
87
+ " x_test,y_test = test_data\n",
88
+ " point = np.array(x_test)\n",
89
+ "\n",
90
+ " distance_with_label_and_index = []\n",
91
+ "\n",
92
+ " for i,(x_train,y_train) in enumerate(train_data):\n",
93
+ " train_point = np.array(x_train)\n",
94
+ " distance_with_label_and_index.append(((y_train,i),np.linalg.norm(point-train_point)))\n",
95
+ "\n",
96
+ " #sorting based on distance\n",
97
+ " distance_with_label_and_index_sorted=sorted(distance_with_label_and_index,key=lambda x: x[1])\n",
98
+ " k_nearest_points = distance_with_label_and_index_sorted[0:k_value]\n",
99
+ "\n",
100
+ " \n",
101
+ " #calculating accuracy\n",
102
+ " fig, axes = plt.subplots(1, k_value, figsize=(15, 3)) # Adjust figsize as needed\n",
103
+ " for i,((label,index),distance) in enumerate(k_nearest_points):\n",
104
+ " if(label == y_test):\n",
105
+ " correct+=1\n",
106
+ " axes[i].imshow(train_images[index])\n",
107
+ " axes[i].set_title(f\"Index: {index}\")\n",
108
+ " axes[i].axis('off')\n",
109
+ " plt.show()\n",
110
+ " return correct/(k_value)"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "code",
115
+ "execution_count": null,
116
+ "metadata": {},
117
+ "outputs": [],
118
+ "source": [
119
+ "index =91\n",
120
+ "plt.imshow(test_images[index])\n",
121
+ "print(test_labels[index])\n",
122
+ "find_knn_and_plot_image(10,resnet_train_data,resnet_test_data[index])"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "execution_count": null,
128
+ "metadata": {},
129
+ "outputs": [],
130
+ "source": [
131
+ "def find_knn(k_value:int =10,train_data=resnet_train_data,test_data=resnet_test_data):\n",
132
+ " results =[]\n",
133
+ " correct = 0\n",
134
+ " for idx, (x_test,y_test) in enumerate(tqdm(test_data)):\n",
135
+ " point = np.array(x_test)\n",
136
+ "\n",
137
+ " distance_with_label_and_index = []\n",
138
+ "\n",
139
+ " for i,(x_train,y_train) in enumerate(train_data):\n",
140
+ " train_point = np.array(x_train)\n",
141
+ " distance_with_label_and_index.append(((y_train,i),np.linalg.norm(point-train_point)))\n",
142
+ "\n",
143
+ " #sorting based on distance\n",
144
+ " distance_with_label_and_index_sorted=sorted(distance_with_label_and_index,key=lambda x: x[1])\n",
145
+ " k_nearest_points = distance_with_label_and_index_sorted[0:k_value]\n",
146
+ "\n",
147
+ "\n",
148
+ " #calculating accuracy\n",
149
+ " results.append([])\n",
150
+ " for i,((label,index),distance) in enumerate(k_nearest_points):\n",
151
+ " if(label == y_test):\n",
152
+ " correct+=1\n",
153
+ " results[idx].append((label,index))\n",
154
+ " return results,correct/(k_value*len(test_data))"
155
+ ]
156
+ },
157
+ {
158
+ "cell_type": "code",
159
+ "execution_count": null,
160
+ "metadata": {},
161
+ "outputs": [],
162
+ "source": [
163
+ "results,accuracy = find_knn(10)"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": null,
169
+ "metadata": {},
170
+ "outputs": [],
171
+ "source": [
172
+ "accuracy"
173
+ ]
174
+ }
175
+ ],
176
+ "metadata": {
177
+ "kernelspec": {
178
+ "display_name": "Python 3",
179
+ "language": "python",
180
+ "name": "python3"
181
+ },
182
+ "language_info": {
183
+ "codemirror_mode": {
184
+ "name": "ipython",
185
+ "version": 3
186
+ },
187
+ "file_extension": ".py",
188
+ "mimetype": "text/x-python",
189
+ "name": "python",
190
+ "nbconvert_exporter": "python",
191
+ "pygments_lexer": "ipython3",
192
+ "version": "3.11.4"
193
+ }
194
+ },
195
+ "nbformat": 4,
196
+ "nbformat_minor": 2
197
+ }
Experiments/cifar_10_pca_knn.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Experiments/cifar_10_res50_m.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import numpy as np

# Load CIFAR-10 dataset
(train_images, train_labels), (_, _) = datasets.cifar10.load_data()

# Normalize pixel values to be between 0 and 1
train_images = train_images / 255.0

# Load a pre-trained model (e.g., ResNet50)
# NOTE(review): 32x32 inputs are far below the 224x224 the ImageNet weights
# were trained on — confirm the resulting embeddings are usable.
base_model = tf.keras.applications.ResNet50(weights='imagenet', include_top=False, input_shape=(32, 32, 3))

# Extract embeddings
embeddings_model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),  # Using GlobalAveragePooling2D to reduce spatial dimensions
    layers.Dense(512, activation='relu')  # You can adjust the size of the dense layer according to your need
])

# Compute embeddings
# NOTE(review): the Dense layer is randomly initialized and never trained, so
# its output is an untrained projection of the pooled features.
embeddings = embeddings_model.predict(train_images)

# Initialize array to store mean embeddings for each class
num_classes = 10
mean_embeddings = np.zeros((num_classes, embeddings.shape[1]))

# Calculate mean embeddings for each class
for class_label in range(num_classes):
    class_indices = np.where(train_labels == class_label)[0]
    class_embeddings = embeddings[class_indices]
    mean_embedding = np.mean(class_embeddings, axis=0)
    mean_embeddings[class_label] = mean_embedding

print("Mean embeddings for each class:")
print(mean_embeddings)
mean_embeddings.shape  # NOTE(review): bare expression; has no effect outside a notebook cell
Experiments/cnn-resnet34-cifar10.ipynb ADDED
@@ -0,0 +1,749 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "colab_type": "text",
7
+ "id": "MkoGLH_Tj5wn"
8
+ },
9
+ "source": [
10
+ "## Imports"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 2,
16
+ "metadata": {
17
+ "colab": {
18
+ "autoexec": {
19
+ "startup": false,
20
+ "wait_interval": 0
21
+ }
22
+ },
23
+ "colab_type": "code",
24
+ "id": "ORj09gnrj5wp"
25
+ },
26
+ "outputs": [],
27
+ "source": [
28
+ "import os\n",
29
+ "import time\n",
30
+ "\n",
31
+ "import numpy as np\n",
32
+ "import pandas as pd\n",
33
+ "\n",
34
+ "import torch\n",
35
+ "import torch.nn as nn\n",
36
+ "import torch.nn.functional as F\n",
37
+ "from torch.utils.data import DataLoader\n",
38
+ "\n",
39
+ "from torchvision import datasets\n",
40
+ "from torchvision import transforms\n",
41
+ "\n",
42
+ "import matplotlib.pyplot as plt\n",
43
+ "from PIL import Image\n",
44
+ "\n",
45
+ "\n",
46
+ "if torch.cuda.is_available():\n",
47
+ " torch.backends.cudnn.deterministic = True"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "markdown",
52
+ "metadata": {
53
+ "colab_type": "text",
54
+ "id": "I6hghKPxj5w0"
55
+ },
56
+ "source": [
57
+ "## Model Settings"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": 3,
63
+ "metadata": {
64
+ "colab": {
65
+ "autoexec": {
66
+ "startup": false,
67
+ "wait_interval": 0
68
+ },
69
+ "base_uri": "https://localhost:8080/",
70
+ "height": 85
71
+ },
72
+ "colab_type": "code",
73
+ "executionInfo": {
74
+ "elapsed": 23936,
75
+ "status": "ok",
76
+ "timestamp": 1524974497505,
77
+ "user": {
78
+ "displayName": "Sebastian Raschka",
79
+ "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
80
+ "userId": "118404394130788869227"
81
+ },
82
+ "user_tz": 240
83
+ },
84
+ "id": "NnT0sZIwj5wu",
85
+ "outputId": "55aed925-d17e-4c6a-8c71-0d9b3bde5637"
86
+ },
87
+ "outputs": [],
88
+ "source": [
89
+ "##########################\n",
90
+ "### SETTINGS\n",
91
+ "##########################\n",
92
+ "\n",
93
+ "# Hyperparameters\n",
94
+ "RANDOM_SEED = 1\n",
95
+ "LEARNING_RATE = 0.001\n",
96
+ "BATCH_SIZE = 256\n",
97
+ "NUM_EPOCHS = 10\n",
98
+ "\n",
99
+ "# Architecture\n",
100
+ "NUM_FEATURES = 28*28\n",
101
+ "NUM_CLASSES = 10\n",
102
+ "\n",
103
+ "# Other\n",
104
+ "DEVICE = \"cuda:1\"\n",
105
+ "GRAYSCALE = False"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": 4,
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": [
114
+ "##########################\n",
115
+ "### MODEL\n",
116
+ "##########################\n",
117
+ "\n",
118
+ "\n",
119
+ "def conv3x3(in_planes, out_planes, stride=1):\n",
120
+ " \"\"\"3x3 convolution with padding\"\"\"\n",
121
+ " return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n",
122
+ " padding=1, bias=False)\n",
123
+ "\n",
124
+ "\n",
125
+ "class BasicBlock(nn.Module):\n",
126
+ " expansion = 1\n",
127
+ "\n",
128
+ " def __init__(self, inplanes, planes, stride=1, downsample=None):\n",
129
+ " super(BasicBlock, self).__init__()\n",
130
+ " self.conv1 = conv3x3(inplanes, planes, stride)\n",
131
+ " self.bn1 = nn.BatchNorm2d(planes)\n",
132
+ " self.relu = nn.ReLU(inplace=True)\n",
133
+ " self.conv2 = conv3x3(planes, planes)\n",
134
+ " self.bn2 = nn.BatchNorm2d(planes)\n",
135
+ " self.downsample = downsample\n",
136
+ " self.stride = stride\n",
137
+ "\n",
138
+ " def forward(self, x):\n",
139
+ " residual = x\n",
140
+ "\n",
141
+ " out = self.conv1(x)\n",
142
+ " out = self.bn1(out)\n",
143
+ " out = self.relu(out)\n",
144
+ "\n",
145
+ " out = self.conv2(out)\n",
146
+ " out = self.bn2(out)\n",
147
+ "\n",
148
+ " if self.downsample is not None:\n",
149
+ " residual = self.downsample(x)\n",
150
+ "\n",
151
+ " out += residual\n",
152
+ " out = self.relu(out)\n",
153
+ "\n",
154
+ " return out\n",
155
+ "\n",
156
+ "\n",
157
+ "\n",
158
+ "\n",
159
+ "class ResNet(nn.Module):\n",
160
+ "\n",
161
+ " def __init__(self, block, layers, num_classes, grayscale):\n",
162
+ " self.inplanes = 64\n",
163
+ " if grayscale:\n",
164
+ " in_dim = 1\n",
165
+ " else:\n",
166
+ " in_dim = 3\n",
167
+ " super(ResNet, self).__init__()\n",
168
+ " self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3,\n",
169
+ " bias=False)\n",
170
+ " self.bn1 = nn.BatchNorm2d(64)\n",
171
+ " self.relu = nn.ReLU(inplace=True)\n",
172
+ " self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
173
+ " self.layer1 = self._make_layer(block, 64, layers[0])\n",
174
+ " self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n",
175
+ " self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n",
176
+ " self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n",
177
+ " self.avgpool = nn.AvgPool2d(7, stride=1)\n",
178
+ " self.fc = nn.Linear(512 * block.expansion, num_classes)\n",
179
+ "\n",
180
+ " for m in self.modules():\n",
181
+ " if isinstance(m, nn.Conv2d):\n",
182
+ " n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n",
183
+ " m.weight.data.normal_(0, (2. / n)**.5)\n",
184
+ " elif isinstance(m, nn.BatchNorm2d):\n",
185
+ " m.weight.data.fill_(1)\n",
186
+ " m.bias.data.zero_()\n",
187
+ "\n",
188
+ " def _make_layer(self, block, planes, blocks, stride=1):\n",
189
+ " downsample = None\n",
190
+ " if stride != 1 or self.inplanes != planes * block.expansion:\n",
191
+ " downsample = nn.Sequential(\n",
192
+ " nn.Conv2d(self.inplanes, planes * block.expansion,\n",
193
+ " kernel_size=1, stride=stride, bias=False),\n",
194
+ " nn.BatchNorm2d(planes * block.expansion),\n",
195
+ " )\n",
196
+ "\n",
197
+ " layers = []\n",
198
+ " layers.append(block(self.inplanes, planes, stride, downsample))\n",
199
+ " self.inplanes = planes * block.expansion\n",
200
+ " for i in range(1, blocks):\n",
201
+ " layers.append(block(self.inplanes, planes))\n",
202
+ "\n",
203
+ " return nn.Sequential(*layers)\n",
204
+ "\n",
205
+ " def forward(self, x):\n",
206
+ " x = self.conv1(x)\n",
207
+ " x = self.bn1(x)\n",
208
+ " x = self.relu(x)\n",
209
+ " x = self.maxpool(x)\n",
210
+ "\n",
211
+ " x = self.layer1(x)\n",
212
+ " x = self.layer2(x)\n",
213
+ " x = self.layer3(x)\n",
214
+ " x = self.layer4(x)\n",
215
+ " \n",
216
+ " x = x.view(x.size(0), -1)\n",
217
+ " logits = self.fc(x)\n",
218
+ " probas = F.softmax(logits, dim=1)\n",
219
+ " return logits, probas\n",
220
+ "\n",
221
+ "\n",
222
+ "\n",
223
+ "def resnet34(num_classes):\n",
224
+ " \"\"\"Constructs a ResNet-34 model.\"\"\"\n",
225
+ " model = ResNet(block=BasicBlock, \n",
226
+ " layers=[3, 4, 6, 3],\n",
227
+ " num_classes=NUM_CLASSES,\n",
228
+ " grayscale=GRAYSCALE)\n",
229
+ " return model\n"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "markdown",
234
+ "metadata": {
235
+ "colab_type": "text",
236
+ "id": "RAodboScj5w6"
237
+ },
238
+ "source": [
239
+ "## Training without Pinned Memory"
240
+ ]
241
+ },
242
+ {
243
+ "cell_type": "code",
244
+ "execution_count": 5,
245
+ "metadata": {},
246
+ "outputs": [
247
+ {
248
+ "name": "stdout",
249
+ "output_type": "stream",
250
+ "text": [
251
+ "Files already downloaded and verified\n",
252
+ "Image batch dimensions: torch.Size([256, 3, 32, 32])\n",
253
+ "Image label dimensions: torch.Size([256])\n",
254
+ "Image batch dimensions: torch.Size([256, 3, 32, 32])\n",
255
+ "Image label dimensions: torch.Size([256])\n"
256
+ ]
257
+ }
258
+ ],
259
+ "source": [
260
+ "##########################\n",
261
+ "### CIFAR-10 Dataset\n",
262
+ "##########################\n",
263
+ "\n",
264
+ "\n",
265
+ "# Note transforms.ToTensor() scales input images\n",
266
+ "# to 0-1 range\n",
267
+ "train_dataset = datasets.CIFAR10(root='data', \n",
268
+ " train=True, \n",
269
+ " transform=transforms.ToTensor(),\n",
270
+ " download=True)\n",
271
+ "\n",
272
+ "test_dataset = datasets.CIFAR10(root='data', \n",
273
+ " train=False, \n",
274
+ " transform=transforms.ToTensor())\n",
275
+ "\n",
276
+ "\n",
277
+ "train_loader = DataLoader(dataset=train_dataset, \n",
278
+ " batch_size=BATCH_SIZE, \n",
279
+ " num_workers=8,\n",
280
+ " shuffle=True)\n",
281
+ "\n",
282
+ "test_loader = DataLoader(dataset=test_dataset, \n",
283
+ " batch_size=BATCH_SIZE,\n",
284
+ " num_workers=8,\n",
285
+ " shuffle=False)\n",
286
+ "\n",
287
+ "# Checking the dataset\n",
288
+ "for images, labels in train_loader: \n",
289
+ " print('Image batch dimensions:', images.shape)\n",
290
+ " print('Image label dimensions:', labels.shape)\n",
291
+ " break\n",
292
+ "\n",
293
+ "# Checking the dataset\n",
294
+ "for images, labels in train_loader: \n",
295
+ " print('Image batch dimensions:', images.shape)\n",
296
+ " print('Image label dimensions:', labels.shape)\n",
297
+ " break"
298
+ ]
299
+ },
300
+ {
301
+ "cell_type": "code",
302
+ "execution_count": 6,
303
+ "metadata": {
304
+ "colab": {
305
+ "autoexec": {
306
+ "startup": false,
307
+ "wait_interval": 0
308
+ }
309
+ },
310
+ "colab_type": "code",
311
+ "id": "_lza9t_uj5w1"
312
+ },
313
+ "outputs": [],
314
+ "source": [
315
+ "torch.manual_seed(RANDOM_SEED)\n",
316
+ "\n",
317
+ "model = resnet34(NUM_CLASSES)\n",
318
+ "model.to(DEVICE)\n",
319
+ "\n",
320
+ "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) "
321
+ ]
322
+ },
323
+ {
324
+ "cell_type": "code",
325
+ "execution_count": 7,
326
+ "metadata": {
327
+ "colab": {
328
+ "autoexec": {
329
+ "startup": false,
330
+ "wait_interval": 0
331
+ },
332
+ "base_uri": "https://localhost:8080/",
333
+ "height": 1547
334
+ },
335
+ "colab_type": "code",
336
+ "executionInfo": {
337
+ "elapsed": 2384585,
338
+ "status": "ok",
339
+ "timestamp": 1524976888520,
340
+ "user": {
341
+ "displayName": "Sebastian Raschka",
342
+ "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
343
+ "userId": "118404394130788869227"
344
+ },
345
+ "user_tz": 240
346
+ },
347
+ "id": "Dzh3ROmRj5w7",
348
+ "outputId": "5f8fd8c9-b076-403a-b0b7-fd2d498b48d7",
349
+ "scrolled": false
350
+ },
351
+ "outputs": [
352
+ {
353
+ "name": "stdout",
354
+ "output_type": "stream",
355
+ "text": [
356
+ "Epoch: 001/010 | Batch 0000/0196 | Cost: 2.6021\n",
357
+ "Epoch: 001/010 | Batch 0150/0196 | Cost: 1.3961\n",
358
+ "Epoch: 001/010 | Train: 45.084%\n",
359
+ "Time elapsed: 0.26 min\n",
360
+ "Epoch: 002/010 | Batch 0000/0196 | Cost: 1.1228\n",
361
+ "Epoch: 002/010 | Batch 0150/0196 | Cost: 1.0426\n",
362
+ "Epoch: 002/010 | Train: 56.166%\n",
363
+ "Time elapsed: 0.52 min\n",
364
+ "Epoch: 003/010 | Batch 0000/0196 | Cost: 0.9980\n",
365
+ "Epoch: 003/010 | Batch 0150/0196 | Cost: 0.8279\n",
366
+ "Epoch: 003/010 | Train: 66.702%\n",
367
+ "Time elapsed: 0.80 min\n",
368
+ "Epoch: 004/010 | Batch 0000/0196 | Cost: 0.6384\n",
369
+ "Epoch: 004/010 | Batch 0150/0196 | Cost: 0.7103\n",
370
+ "Epoch: 004/010 | Train: 65.330%\n",
371
+ "Time elapsed: 1.08 min\n",
372
+ "Epoch: 005/010 | Batch 0000/0196 | Cost: 0.6308\n",
373
+ "Epoch: 005/010 | Batch 0150/0196 | Cost: 0.5913\n",
374
+ "Epoch: 005/010 | Train: 79.636%\n",
375
+ "Time elapsed: 1.36 min\n",
376
+ "Epoch: 006/010 | Batch 0000/0196 | Cost: 0.4409\n",
377
+ "Epoch: 006/010 | Batch 0150/0196 | Cost: 0.5557\n",
378
+ "Epoch: 006/010 | Train: 76.456%\n",
379
+ "Time elapsed: 1.62 min\n",
380
+ "Epoch: 007/010 | Batch 0000/0196 | Cost: 0.4778\n",
381
+ "Epoch: 007/010 | Batch 0150/0196 | Cost: 0.4815\n",
382
+ "Epoch: 007/010 | Train: 65.890%\n",
383
+ "Time elapsed: 1.89 min\n",
384
+ "Epoch: 008/010 | Batch 0000/0196 | Cost: 0.3782\n",
385
+ "Epoch: 008/010 | Batch 0150/0196 | Cost: 0.4339\n",
386
+ "Epoch: 008/010 | Train: 85.200%\n",
387
+ "Time elapsed: 2.16 min\n",
388
+ "Epoch: 009/010 | Batch 0000/0196 | Cost: 0.3083\n",
389
+ "Epoch: 009/010 | Batch 0150/0196 | Cost: 0.3290\n",
390
+ "Epoch: 009/010 | Train: 78.108%\n",
391
+ "Time elapsed: 2.42 min\n",
392
+ "Epoch: 010/010 | Batch 0000/0196 | Cost: 0.2229\n",
393
+ "Epoch: 010/010 | Batch 0150/0196 | Cost: 0.1945\n",
394
+ "Epoch: 010/010 | Train: 87.384%\n",
395
+ "Time elapsed: 2.70 min\n",
396
+ "Total Training Time: 2.70 min\n",
397
+ "Test accuracy: 70.67%\n",
398
+ "Total Time: 2.71 min\n"
399
+ ]
400
+ }
401
+ ],
402
+ "source": [
403
+ "def compute_accuracy(model, data_loader, device):\n",
404
+ " correct_pred, num_examples = 0, 0\n",
405
+ " for i, (features, targets) in enumerate(data_loader):\n",
406
+ " \n",
407
+ " features = features.to(device)\n",
408
+ " targets = targets.to(device)\n",
409
+ "\n",
410
+ " logits, probas = model(features)\n",
411
+ " _, predicted_labels = torch.max(probas, 1)\n",
412
+ " num_examples += targets.size(0)\n",
413
+ " correct_pred += (predicted_labels == targets).sum()\n",
414
+ " return correct_pred.float()/num_examples * 100\n",
415
+ " \n",
416
+ "\n",
417
+ "start_time = time.time()\n",
418
+ "for epoch in range(NUM_EPOCHS):\n",
419
+ " \n",
420
+ " model.train()\n",
421
+ " for batch_idx, (features, targets) in enumerate(train_loader):\n",
422
+ " \n",
423
+ " features = features.to(DEVICE)\n",
424
+ " targets = targets.to(DEVICE)\n",
425
+ " \n",
426
+ " ### FORWARD AND BACK PROP\n",
427
+ " logits, probas = model(features)\n",
428
+ " cost = F.cross_entropy(logits, targets)\n",
429
+ " optimizer.zero_grad()\n",
430
+ " \n",
431
+ " cost.backward()\n",
432
+ " \n",
433
+ " ### UPDATE MODEL PARAMETERS\n",
434
+ " optimizer.step()\n",
435
+ " \n",
436
+ " ### LOGGING\n",
437
+ " if not batch_idx % 150:\n",
438
+ " print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f' \n",
439
+ " %(epoch+1, NUM_EPOCHS, batch_idx, \n",
440
+ " len(train_loader), cost))\n",
441
+ "\n",
442
+ " \n",
443
+ "\n",
444
+ " model.eval()\n",
445
+ " with torch.set_grad_enabled(False): # save memory during inference\n",
446
+ " print('Epoch: %03d/%03d | Train: %.3f%%' % (\n",
447
+ " epoch+1, NUM_EPOCHS, \n",
448
+ " compute_accuracy(model, train_loader, device=DEVICE)))\n",
449
+ " \n",
450
+ " print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
451
+ " \n",
452
+ "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))\n",
453
+ "\n",
454
+ "\n",
455
+ "with torch.set_grad_enabled(False): # save memory during inference\n",
456
+ " print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader, device=DEVICE)))\n",
457
+ " \n",
458
+ "print('Total Time: %.2f min' % ((time.time() - start_time)/60))"
459
+ ]
460
+ },
461
+ {
462
+ "cell_type": "markdown",
463
+ "metadata": {},
464
+ "source": [
465
+ "## Training with Pinned Memory"
466
+ ]
467
+ },
468
+ {
469
+ "cell_type": "code",
470
+ "execution_count": 8,
471
+ "metadata": {},
472
+ "outputs": [
473
+ {
474
+ "name": "stdout",
475
+ "output_type": "stream",
476
+ "text": [
477
+ "Files already downloaded and verified\n",
478
+ "Image batch dimensions: torch.Size([256, 3, 32, 32])\n",
479
+ "Image label dimensions: torch.Size([256])\n",
480
+ "Image batch dimensions: torch.Size([256, 3, 32, 32])\n",
481
+ "Image label dimensions: torch.Size([256])\n"
482
+ ]
483
+ }
484
+ ],
485
+ "source": [
486
+ "##########################\n",
487
+ "### CIFAR-10 Dataset\n",
488
+ "##########################\n",
489
+ "\n",
490
+ "\n",
491
+ "# Note transforms.ToTensor() scales input images\n",
492
+ "# to 0-1 range\n",
493
+ "train_dataset = datasets.CIFAR10(root='data', \n",
494
+ " train=True, \n",
495
+ " transform=transforms.ToTensor(),\n",
496
+ " download=True)\n",
497
+ "\n",
498
+ "test_dataset = datasets.CIFAR10(root='data', \n",
499
+ " train=False, \n",
500
+ " transform=transforms.ToTensor())\n",
501
+ "\n",
502
+ "\n",
503
+ "train_loader = DataLoader(dataset=train_dataset, \n",
504
+ " batch_size=BATCH_SIZE, \n",
505
+ " pin_memory=True,\n",
506
+ " shuffle=True)\n",
507
+ "\n",
508
+ "test_loader = DataLoader(dataset=test_dataset, \n",
509
+ " batch_size=BATCH_SIZE,\n",
510
+ " pin_memory=True,\n",
511
+ " shuffle=False)\n",
512
+ "\n",
513
+ "# Checking the dataset\n",
514
+ "for images, labels in train_loader: \n",
515
+ " print('Image batch dimensions:', images.shape)\n",
516
+ " print('Image label dimensions:', labels.shape)\n",
517
+ " break\n",
518
+ "\n",
519
+ "# Checking the dataset\n",
520
+ "for images, labels in train_loader: \n",
521
+ " print('Image batch dimensions:', images.shape)\n",
522
+ " print('Image label dimensions:', labels.shape)\n",
523
+ " break"
524
+ ]
525
+ },
526
+ {
527
+ "cell_type": "code",
528
+ "execution_count": 9,
529
+ "metadata": {
530
+ "colab": {
531
+ "autoexec": {
532
+ "startup": false,
533
+ "wait_interval": 0
534
+ }
535
+ },
536
+ "colab_type": "code",
537
+ "id": "_lza9t_uj5w1"
538
+ },
539
+ "outputs": [],
540
+ "source": [
541
+ "torch.manual_seed(RANDOM_SEED)\n",
542
+ "\n",
543
+ "model = resnet34(NUM_CLASSES)\n",
544
+ "model.to(DEVICE)\n",
545
+ "\n",
546
+ "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) "
547
+ ]
548
+ },
549
+ {
550
+ "cell_type": "code",
551
+ "execution_count": 10,
552
+ "metadata": {
553
+ "colab": {
554
+ "autoexec": {
555
+ "startup": false,
556
+ "wait_interval": 0
557
+ },
558
+ "base_uri": "https://localhost:8080/",
559
+ "height": 1547
560
+ },
561
+ "colab_type": "code",
562
+ "executionInfo": {
563
+ "elapsed": 2384585,
564
+ "status": "ok",
565
+ "timestamp": 1524976888520,
566
+ "user": {
567
+ "displayName": "Sebastian Raschka",
568
+ "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
569
+ "userId": "118404394130788869227"
570
+ },
571
+ "user_tz": 240
572
+ },
573
+ "id": "Dzh3ROmRj5w7",
574
+ "outputId": "5f8fd8c9-b076-403a-b0b7-fd2d498b48d7",
575
+ "scrolled": false
576
+ },
577
+ "outputs": [
578
+ {
579
+ "name": "stdout",
580
+ "output_type": "stream",
581
+ "text": [
582
+ "Epoch: 001/010 | Batch 0000/0196 | Cost: 2.6021\n",
583
+ "Epoch: 001/010 | Batch 0150/0196 | Cost: 1.3961\n",
584
+ "Epoch: 001/010 | Train: 45.084%\n",
585
+ "Time elapsed: 0.39 min\n",
586
+ "Epoch: 002/010 | Batch 0000/0196 | Cost: 1.1228\n",
587
+ "Epoch: 002/010 | Batch 0150/0196 | Cost: 1.0426\n",
588
+ "Epoch: 002/010 | Train: 56.166%\n",
589
+ "Time elapsed: 0.77 min\n",
590
+ "Epoch: 003/010 | Batch 0000/0196 | Cost: 0.9980\n",
591
+ "Epoch: 003/010 | Batch 0150/0196 | Cost: 0.8279\n",
592
+ "Epoch: 003/010 | Train: 66.702%\n",
593
+ "Time elapsed: 1.16 min\n",
594
+ "Epoch: 004/010 | Batch 0000/0196 | Cost: 0.6384\n",
595
+ "Epoch: 004/010 | Batch 0150/0196 | Cost: 0.7103\n",
596
+ "Epoch: 004/010 | Train: 65.330%\n",
597
+ "Time elapsed: 1.55 min\n",
598
+ "Epoch: 005/010 | Batch 0000/0196 | Cost: 0.6308\n",
599
+ "Epoch: 005/010 | Batch 0150/0196 | Cost: 0.5913\n",
600
+ "Epoch: 005/010 | Train: 79.636%\n",
601
+ "Time elapsed: 1.94 min\n",
602
+ "Epoch: 006/010 | Batch 0000/0196 | Cost: 0.4409\n",
603
+ "Epoch: 006/010 | Batch 0150/0196 | Cost: 0.5557\n",
604
+ "Epoch: 006/010 | Train: 76.456%\n",
605
+ "Time elapsed: 2.33 min\n",
606
+ "Epoch: 007/010 | Batch 0000/0196 | Cost: 0.4778\n",
607
+ "Epoch: 007/010 | Batch 0150/0196 | Cost: 0.4815\n",
608
+ "Epoch: 007/010 | Train: 65.890%\n",
609
+ "Time elapsed: 2.71 min\n",
610
+ "Epoch: 008/010 | Batch 0000/0196 | Cost: 0.3782\n",
611
+ "Epoch: 008/010 | Batch 0150/0196 | Cost: 0.4339\n",
612
+ "Epoch: 008/010 | Train: 85.200%\n",
613
+ "Time elapsed: 3.10 min\n",
614
+ "Epoch: 009/010 | Batch 0000/0196 | Cost: 0.3083\n",
615
+ "Epoch: 009/010 | Batch 0150/0196 | Cost: 0.3290\n",
616
+ "Epoch: 009/010 | Train: 78.108%\n",
617
+ "Time elapsed: 3.49 min\n",
618
+ "Epoch: 010/010 | Batch 0000/0196 | Cost: 0.2229\n",
619
+ "Epoch: 010/010 | Batch 0150/0196 | Cost: 0.1945\n",
620
+ "Epoch: 010/010 | Train: 87.384%\n",
621
+ "Time elapsed: 3.88 min\n",
622
+ "Total Training Time: 3.88 min\n",
623
+ "Test accuracy: 70.67%\n",
624
+ "Total Time: 3.91 min\n"
625
+ ]
626
+ }
627
+ ],
628
+ "source": [
629
+ "def compute_accuracy(model, data_loader, device):\n",
630
+ " correct_pred, num_examples = 0, 0\n",
631
+ " for i, (features, targets) in enumerate(data_loader):\n",
632
+ " \n",
633
+ " features = features.to(device)\n",
634
+ " targets = targets.to(device)\n",
635
+ "\n",
636
+ " logits, probas = model(features)\n",
637
+ " _, predicted_labels = torch.max(probas, 1)\n",
638
+ " num_examples += targets.size(0)\n",
639
+ " correct_pred += (predicted_labels == targets).sum()\n",
640
+ " return correct_pred.float()/num_examples * 100\n",
641
+ " \n",
642
+ "\n",
643
+ "start_time = time.time()\n",
644
+ "for epoch in range(NUM_EPOCHS):\n",
645
+ " \n",
646
+ " model.train()\n",
647
+ " for batch_idx, (features, targets) in enumerate(train_loader):\n",
648
+ " \n",
649
+ " features = features.to(DEVICE)\n",
650
+ " targets = targets.to(DEVICE)\n",
651
+ " \n",
652
+ " ### FORWARD AND BACK PROP\n",
653
+ " logits, probas = model(features)\n",
654
+ " cost = F.cross_entropy(logits, targets)\n",
655
+ " optimizer.zero_grad()\n",
656
+ " \n",
657
+ " cost.backward()\n",
658
+ " \n",
659
+ " ### UPDATE MODEL PARAMETERS\n",
660
+ " optimizer.step()\n",
661
+ " \n",
662
+ " ### LOGGING\n",
663
+ " if not batch_idx % 150:\n",
664
+ " print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f' \n",
665
+ " %(epoch+1, NUM_EPOCHS, batch_idx, \n",
666
+ " len(train_loader), cost))\n",
667
+ "\n",
668
+ " \n",
669
+ "\n",
670
+ " model.eval()\n",
671
+ " with torch.set_grad_enabled(False): # save memory during inference\n",
672
+ " print('Epoch: %03d/%03d | Train: %.3f%%' % (\n",
673
+ " epoch+1, NUM_EPOCHS, \n",
674
+ " compute_accuracy(model, train_loader, device=DEVICE)))\n",
675
+ " \n",
676
+ " print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
677
+ " \n",
678
+ "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))\n",
679
+ "\n",
680
+ "\n",
681
+ "with torch.set_grad_enabled(False): # save memory during inference\n",
682
+ " print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader, device=DEVICE)))\n",
683
+ " \n",
684
+ "print('Total Time: %.2f min' % ((time.time() - start_time)/60))"
685
+ ]
686
+ },
687
+ {
688
+ "cell_type": "markdown",
689
+ "metadata": {},
690
+ "source": [
691
+ "## Conclusions"
692
+ ]
693
+ },
694
+ {
695
+ "cell_type": "markdown",
696
+ "metadata": {},
697
+ "source": [
698
+ "Based on the training time without and with `pin_memory=True`, there doesn't seem to be a speed-up when using page-locked (or \"pinned\") memory -- in fact, pinning the memory even slowed down the training. (I reran the code in the opposite order, i.e., `pin_memory=True` first, and got the same results.) This could be due to the relatively small dataset size, batch size.\n"
699
+ ]
700
+ }
701
+ ],
702
+ "metadata": {
703
+ "accelerator": "GPU",
704
+ "colab": {
705
+ "collapsed_sections": [],
706
+ "default_view": {},
707
+ "name": "convnet-vgg16.ipynb",
708
+ "provenance": [],
709
+ "version": "0.3.2",
710
+ "views": {}
711
+ },
712
+ "kernelspec": {
713
+ "display_name": "Python 3",
714
+ "language": "python",
715
+ "name": "python3"
716
+ },
717
+ "language_info": {
718
+ "codemirror_mode": {
719
+ "name": "ipython",
720
+ "version": 3
721
+ },
722
+ "file_extension": ".py",
723
+ "mimetype": "text/x-python",
724
+ "name": "python",
725
+ "nbconvert_exporter": "python",
726
+ "pygments_lexer": "ipython3",
727
+ "version": "3.7.1"
728
+ },
729
+ "toc": {
730
+ "nav_menu": {},
731
+ "number_sections": true,
732
+ "sideBar": true,
733
+ "skip_h1_title": false,
734
+ "title_cell": "Table of Contents",
735
+ "title_sidebar": "Contents",
736
+ "toc_cell": true,
737
+ "toc_position": {
738
+ "height": "calc(100% - 180px)",
739
+ "left": "10px",
740
+ "top": "150px",
741
+ "width": "371px"
742
+ },
743
+ "toc_section_display": true,
744
+ "toc_window_display": true
745
+ }
746
+ },
747
+ "nbformat": 4,
748
+ "nbformat_minor": 2
749
+ }
Experiments/quicknet_cifar10_centroid.ipynb ADDED
@@ -0,0 +1,698 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 1,
22
+ "metadata": {
23
+ "colab": {
24
+ "base_uri": "https://localhost:8080/"
25
+ },
26
+ "id": "5aR8PYXi684E",
27
+ "outputId": "a6b59bdf-41f5-4b86-e7b1-47ceb6fb6893"
28
+ },
29
+ "outputs": [
30
+ {
31
+ "output_type": "stream",
32
+ "name": "stdout",
33
+ "text": [
34
+ "Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz\n",
35
+ "170498071/170498071 [==============================] - 5s 0us/step\n",
36
+ "Epoch 1/10\n",
37
+ "782/782 [==============================] - 14s 10ms/step - loss: 1.6260 - accuracy: 0.3923 - val_loss: 1.2804 - val_accuracy: 0.5374\n",
38
+ "Epoch 2/10\n",
39
+ "782/782 [==============================] - 6s 8ms/step - loss: 1.1254 - accuracy: 0.5914 - val_loss: 1.0466 - val_accuracy: 0.6335\n",
40
+ "Epoch 3/10\n",
41
+ "782/782 [==============================] - 6s 8ms/step - loss: 0.9187 - accuracy: 0.6723 - val_loss: 0.9976 - val_accuracy: 0.6503\n",
42
+ "Epoch 4/10\n",
43
+ "782/782 [==============================] - 6s 8ms/step - loss: 0.7763 - accuracy: 0.7259 - val_loss: 0.8091 - val_accuracy: 0.7194\n",
44
+ "Epoch 5/10\n",
45
+ "782/782 [==============================] - 6s 8ms/step - loss: 0.6660 - accuracy: 0.7670 - val_loss: 0.7875 - val_accuracy: 0.7268\n",
46
+ "Epoch 6/10\n",
47
+ "782/782 [==============================] - 6s 8ms/step - loss: 0.5799 - accuracy: 0.7962 - val_loss: 0.8075 - val_accuracy: 0.7319\n",
48
+ "Epoch 7/10\n",
49
+ "782/782 [==============================] - 6s 8ms/step - loss: 0.5006 - accuracy: 0.8232 - val_loss: 0.8128 - val_accuracy: 0.7300\n",
50
+ "Epoch 8/10\n",
51
+ "782/782 [==============================] - 6s 8ms/step - loss: 0.4253 - accuracy: 0.8497 - val_loss: 0.8981 - val_accuracy: 0.7261\n",
52
+ "Epoch 9/10\n",
53
+ "782/782 [==============================] - 6s 8ms/step - loss: 0.3633 - accuracy: 0.8707 - val_loss: 0.8523 - val_accuracy: 0.7377\n",
54
+ "Epoch 10/10\n",
55
+ "782/782 [==============================] - 7s 9ms/step - loss: 0.3099 - accuracy: 0.8887 - val_loss: 0.8689 - val_accuracy: 0.7499\n",
56
+ "1563/1563 [==============================] - 4s 2ms/step\n"
57
+ ]
58
+ },
59
+ {
60
+ "output_type": "execute_result",
61
+ "data": {
62
+ "text/plain": [
63
+ "array([[8.79073262e-01, 2.84781260e-03, 3.16492580e-02, 9.60051734e-03,\n",
64
+ " 8.88401456e-03, 4.04336443e-03, 1.89202465e-03, 2.60451157e-03,\n",
65
+ " 4.94241416e-02, 9.98295005e-03],\n",
66
+ " [5.34512708e-03, 9.61387515e-01, 1.82328641e-03, 1.49091403e-03,\n",
67
+ " 3.62021121e-04, 2.36713630e-03, 9.99960001e-04, 3.58641875e-04,\n",
68
+ " 8.00825842e-03, 1.78600885e-02],\n",
69
+ " [1.93752628e-02, 5.49300632e-04, 8.71673167e-01, 3.62917222e-02,\n",
70
+ " 2.81240512e-02, 1.94428060e-02, 1.17484620e-02, 7.90499710e-03,\n",
71
+ " 3.91457370e-03, 9.77586606e-04],\n",
72
+ " [3.43155907e-03, 3.87360866e-04, 3.25083360e-02, 8.04273427e-01,\n",
73
+ " 2.26426423e-02, 1.00013100e-01, 1.48316240e-02, 1.55726429e-02,\n",
74
+ " 3.46732978e-03, 2.87183723e-03],\n",
75
+ " [4.50019445e-03, 3.07688024e-04, 3.81155163e-02, 3.81957851e-02,\n",
76
+ " 8.43278885e-01, 2.98470370e-02, 1.07433125e-02, 3.15311365e-02,\n",
77
+ " 2.45781220e-03, 1.02343888e-03],\n",
78
+ " [8.63184629e-04, 2.92542420e-04, 2.77942196e-02, 1.26914889e-01,\n",
79
+ " 1.68064609e-02, 7.96062946e-01, 6.89344807e-03, 2.19228752e-02,\n",
80
+ " 1.45196577e-03, 9.96942166e-04],\n",
81
+ " [2.25282880e-03, 1.00196269e-03, 4.27705161e-02, 4.95258309e-02,\n",
82
+ " 1.27582047e-02, 1.77529752e-02, 8.66206527e-01, 2.50199367e-03,\n",
83
+ " 2.86106160e-03, 2.37041595e-03],\n",
84
+ " [1.76379981e-03, 2.93744437e-04, 8.86812247e-03, 1.57593563e-02,\n",
85
+ " 1.64906941e-02, 2.32193433e-02, 5.39854867e-04, 9.30097342e-01,\n",
86
+ " 8.36823776e-04, 2.13308516e-03],\n",
87
+ " [1.30738802e-02, 4.98946663e-03, 2.41551897e-03, 5.38546871e-03,\n",
88
+ " 1.31231989e-03, 1.98870571e-03, 1.08822726e-03, 5.06214972e-04,\n",
89
+ " 9.64531541e-01, 4.71282937e-03],\n",
90
+ " [1.20871738e-02, 2.48907283e-02, 2.02297443e-03, 6.83936710e-03,\n",
91
+ " 1.44265452e-03, 4.34162840e-03, 1.16528722e-03, 2.80632032e-03,\n",
92
+ " 1.10650901e-02, 9.33342278e-01]])"
93
+ ]
94
+ },
95
+ "metadata": {},
96
+ "execution_count": 1
97
+ }
98
+ ],
99
+ "source": [
100
+ "from scipy.spatial.distance import euclidean\n",
101
+ "import tensorflow as tf\n",
102
+ "from tensorflow.keras import datasets, layers, models\n",
103
+ "import numpy as np\n",
104
+ "\n",
105
+ "\n",
106
+ "(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()\n",
107
+ "\n",
108
+ "\n",
109
+ "train_images = train_images / 255.0\n",
110
+ "test_images = test_images / 255.0\n",
111
+ "\n",
112
+ "\n",
113
+ "def create_resnet18():\n",
114
+ " model = models.Sequential()\n",
115
+ " model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=(32, 32, 3)))\n",
116
+ " model.add(layers.MaxPooling2D((2, 2)))\n",
117
+ " model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))\n",
118
+ " model.add(layers.MaxPooling2D((2, 2)))\n",
119
+ " model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))\n",
120
+ " model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))\n",
121
+ " model.add(layers.MaxPooling2D((2, 2)))\n",
122
+ " model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))\n",
123
+ " model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))\n",
124
+ " model.add(layers.MaxPooling2D((2, 2)))\n",
125
+ " model.add(layers.Flatten())\n",
126
+ " model.add(layers.Dense(512, activation='relu'))\n",
127
+ " model.add(layers.Dense(10, activation='softmax'))\n",
128
+ " return model\n",
129
+ "\n",
130
+ "\n",
131
+ "model = create_resnet18()\n",
132
+ "\n",
133
+ "\n",
134
+ "model.compile(optimizer='adam',\n",
135
+ " loss='sparse_categorical_crossentropy',\n",
136
+ " metrics=['accuracy'])\n",
137
+ "\n",
138
+ "\n",
139
+ "model.fit(train_images, train_labels, epochs=10, batch_size=64, validation_data=(test_images, test_labels))\n",
140
+ "\n",
141
+ "\n",
142
+ "train_embeddings = model.predict(train_images)\n",
143
+ "\n",
144
+ "num_classes = 10\n",
145
+ "mean_embeddings = np.zeros((num_classes, train_embeddings.shape[1]))\n",
146
+ "\n",
147
+ "\n",
148
+ "for class_label in range(num_classes):\n",
149
+ " class_indices = np.where(train_labels.flatten() == class_label)[0]\n",
150
+ " class_embeddings = train_embeddings[class_indices]\n",
151
+ " mean_embedding = np.mean(class_embeddings, axis=0)\n",
152
+ " mean_embeddings[class_label] = mean_embedding\n",
153
+ "\n",
154
+ "\n",
155
+ "def classify_query(query_image):\n",
156
+ " query_embedding = model.predict(np.expand_dims(query_image, axis=0))\n",
157
+ " distances = [euclidean(query_embedding.flatten(), mean_embedding) for mean_embedding in mean_embeddings]\n",
158
+ " predicted_class = np.argmin(distances)\n",
159
+ " return predicted_class\n",
160
+ "\n",
161
+ "mean_embeddings\n"
162
+ ]
163
+ },
164
+ {
165
+ "cell_type": "code",
166
+ "source": [
167
+ "# Calculate accuracy\n",
168
+ "correct_predictions = 0\n",
169
+ "total_predictions = 500\n",
170
+ "\n",
171
+ "for i in range(500):\n",
172
+ " predicted_class = classify_query(test_images[i])\n",
173
+ " if predicted_class == test_labels[i]:\n",
174
+ " correct_predictions += 1\n",
175
+ "\n",
176
+ "accuracy = correct_predictions / total_predictions\n",
177
+ "print(\"Accuracy:\", accuracy)\n"
178
+ ],
179
+ "metadata": {
180
+ "colab": {
181
+ "base_uri": "https://localhost:8080/"
182
+ },
183
+ "id": "1-b_E44RrxuV",
184
+ "outputId": "125e4835-b44a-48bf-ff43-1103c2b1b6f4"
185
+ },
186
+ "execution_count": 2,
187
+ "outputs": [
188
+ {
189
+ "output_type": "stream",
190
+ "name": "stdout",
191
+ "text": [
192
+ "1/1 [==============================] - 0s 388ms/step\n",
193
+ "1/1 [==============================] - 0s 17ms/step\n",
194
+ "1/1 [==============================] - 0s 17ms/step\n",
195
+ "1/1 [==============================] - 0s 19ms/step\n",
196
+ "1/1 [==============================] - 0s 17ms/step\n",
197
+ "1/1 [==============================] - 0s 25ms/step\n",
198
+ "1/1 [==============================] - 0s 18ms/step\n",
199
+ "1/1 [==============================] - 0s 19ms/step\n",
200
+ "1/1 [==============================] - 0s 18ms/step\n",
201
+ "1/1 [==============================] - 0s 17ms/step\n",
202
+ "1/1 [==============================] - 0s 19ms/step\n",
203
+ "1/1 [==============================] - 0s 18ms/step\n",
204
+ "1/1 [==============================] - 0s 19ms/step\n",
205
+ "1/1 [==============================] - 0s 17ms/step\n",
206
+ "1/1 [==============================] - 0s 18ms/step\n",
207
+ "1/1 [==============================] - 0s 18ms/step\n",
208
+ "1/1 [==============================] - 0s 17ms/step\n",
209
+ "1/1 [==============================] - 0s 17ms/step\n",
210
+ "1/1 [==============================] - 0s 18ms/step\n",
211
+ "1/1 [==============================] - 0s 18ms/step\n",
212
+ "1/1 [==============================] - 0s 19ms/step\n",
213
+ "1/1 [==============================] - 0s 18ms/step\n",
214
+ "1/1 [==============================] - 0s 17ms/step\n",
215
+ "1/1 [==============================] - 0s 17ms/step\n",
216
+ "1/1 [==============================] - 0s 17ms/step\n",
217
+ "1/1 [==============================] - 0s 18ms/step\n",
218
+ "1/1 [==============================] - 0s 18ms/step\n",
219
+ "1/1 [==============================] - 0s 18ms/step\n",
220
+ "1/1 [==============================] - 0s 19ms/step\n",
221
+ "1/1 [==============================] - 0s 18ms/step\n",
222
+ "1/1 [==============================] - 0s 17ms/step\n",
223
+ "1/1 [==============================] - 0s 20ms/step\n",
224
+ "1/1 [==============================] - 0s 18ms/step\n",
225
+ "1/1 [==============================] - 0s 17ms/step\n",
226
+ "1/1 [==============================] - 0s 21ms/step\n",
227
+ "1/1 [==============================] - 0s 18ms/step\n",
228
+ "1/1 [==============================] - 0s 18ms/step\n",
229
+ "1/1 [==============================] - 0s 20ms/step\n",
230
+ "1/1 [==============================] - 0s 24ms/step\n",
231
+ "1/1 [==============================] - 0s 18ms/step\n",
232
+ "1/1 [==============================] - 0s 20ms/step\n",
233
+ "1/1 [==============================] - 0s 19ms/step\n",
234
+ "1/1 [==============================] - 0s 18ms/step\n",
235
+ "1/1 [==============================] - 0s 18ms/step\n",
236
+ "1/1 [==============================] - 0s 17ms/step\n",
237
+ "1/1 [==============================] - 0s 18ms/step\n",
238
+ "1/1 [==============================] - 0s 18ms/step\n",
239
+ "1/1 [==============================] - 0s 18ms/step\n",
240
+ "1/1 [==============================] - 0s 17ms/step\n",
241
+ "1/1 [==============================] - 0s 17ms/step\n",
242
+ "1/1 [==============================] - 0s 17ms/step\n",
243
+ "1/1 [==============================] - 0s 17ms/step\n",
244
+ "1/1 [==============================] - 0s 18ms/step\n",
245
+ "1/1 [==============================] - 0s 18ms/step\n",
246
+ "1/1 [==============================] - 0s 20ms/step\n",
247
+ "1/1 [==============================] - 0s 25ms/step\n",
248
+ "1/1 [==============================] - 0s 17ms/step\n",
249
+ "1/1 [==============================] - 0s 17ms/step\n",
250
+ "1/1 [==============================] - 0s 17ms/step\n",
251
+ "1/1 [==============================] - 0s 17ms/step\n",
252
+ "1/1 [==============================] - 0s 17ms/step\n",
253
+ "1/1 [==============================] - 0s 17ms/step\n",
254
+ "1/1 [==============================] - 0s 17ms/step\n",
255
+ "1/1 [==============================] - 0s 18ms/step\n",
256
+ "1/1 [==============================] - 0s 19ms/step\n",
257
+ "1/1 [==============================] - 0s 17ms/step\n",
258
+ "1/1 [==============================] - 0s 17ms/step\n",
259
+ "1/1 [==============================] - 0s 17ms/step\n",
260
+ "1/1 [==============================] - 0s 16ms/step\n",
261
+ "1/1 [==============================] - 0s 17ms/step\n",
262
+ "1/1 [==============================] - 0s 17ms/step\n",
263
+ "1/1 [==============================] - 0s 17ms/step\n",
264
+ "1/1 [==============================] - 0s 23ms/step\n",
265
+ "1/1 [==============================] - 0s 19ms/step\n",
266
+ "1/1 [==============================] - 0s 18ms/step\n",
267
+ "1/1 [==============================] - 0s 17ms/step\n",
268
+ "1/1 [==============================] - 0s 17ms/step\n",
269
+ "1/1 [==============================] - 0s 25ms/step\n",
270
+ "1/1 [==============================] - 0s 17ms/step\n",
271
+ "1/1 [==============================] - 0s 18ms/step\n",
272
+ "1/1 [==============================] - 0s 18ms/step\n",
273
+ "1/1 [==============================] - 0s 17ms/step\n",
274
+ "1/1 [==============================] - 0s 20ms/step\n",
275
+ "1/1 [==============================] - 0s 17ms/step\n",
276
+ "1/1 [==============================] - 0s 17ms/step\n",
277
+ "1/1 [==============================] - 0s 17ms/step\n",
278
+ "1/1 [==============================] - 0s 29ms/step\n",
279
+ "1/1 [==============================] - 0s 31ms/step\n",
280
+ "1/1 [==============================] - 0s 31ms/step\n",
281
+ "1/1 [==============================] - 0s 28ms/step\n",
282
+ "1/1 [==============================] - 0s 27ms/step\n",
283
+ "1/1 [==============================] - 0s 24ms/step\n",
284
+ "1/1 [==============================] - 0s 26ms/step\n",
285
+ "1/1 [==============================] - 0s 25ms/step\n",
286
+ "1/1 [==============================] - 0s 26ms/step\n",
287
+ "1/1 [==============================] - 0s 25ms/step\n",
288
+ "1/1 [==============================] - 0s 29ms/step\n",
289
+ "1/1 [==============================] - 0s 25ms/step\n",
290
+ "1/1 [==============================] - 0s 26ms/step\n",
291
+ "1/1 [==============================] - 0s 48ms/step\n",
292
+ "1/1 [==============================] - 0s 32ms/step\n",
293
+ "1/1 [==============================] - 0s 27ms/step\n",
294
+ "1/1 [==============================] - 0s 25ms/step\n",
295
+ "1/1 [==============================] - 0s 28ms/step\n",
296
+ "1/1 [==============================] - 0s 24ms/step\n",
297
+ "1/1 [==============================] - 0s 28ms/step\n",
298
+ "1/1 [==============================] - 0s 31ms/step\n",
299
+ "1/1 [==============================] - 0s 27ms/step\n",
300
+ "1/1 [==============================] - 0s 29ms/step\n",
301
+ "1/1 [==============================] - 0s 18ms/step\n",
302
+ "1/1 [==============================] - 0s 18ms/step\n",
303
+ "1/1 [==============================] - 0s 28ms/step\n",
304
+ "1/1 [==============================] - 0s 19ms/step\n",
305
+ "1/1 [==============================] - 0s 19ms/step\n",
306
+ "1/1 [==============================] - 0s 17ms/step\n",
307
+ "1/1 [==============================] - 0s 17ms/step\n",
308
+ "1/1 [==============================] - 0s 17ms/step\n",
309
+ "1/1 [==============================] - 0s 18ms/step\n",
310
+ "1/1 [==============================] - 0s 21ms/step\n",
311
+ "1/1 [==============================] - 0s 18ms/step\n",
312
+ "1/1 [==============================] - 0s 17ms/step\n",
313
+ "1/1 [==============================] - 0s 17ms/step\n",
314
+ "1/1 [==============================] - 0s 17ms/step\n",
315
+ "1/1 [==============================] - 0s 16ms/step\n",
316
+ "1/1 [==============================] - 0s 17ms/step\n",
317
+ "1/1 [==============================] - 0s 17ms/step\n",
318
+ "1/1 [==============================] - 0s 17ms/step\n",
319
+ "1/1 [==============================] - 0s 20ms/step\n",
320
+ "1/1 [==============================] - 0s 19ms/step\n",
321
+ "1/1 [==============================] - 0s 17ms/step\n",
322
+ "1/1 [==============================] - 0s 17ms/step\n",
323
+ "1/1 [==============================] - 0s 17ms/step\n",
324
+ "1/1 [==============================] - 0s 16ms/step\n",
325
+ "1/1 [==============================] - 0s 19ms/step\n",
326
+ "1/1 [==============================] - 0s 17ms/step\n",
327
+ "1/1 [==============================] - 0s 17ms/step\n",
328
+ "1/1 [==============================] - 0s 17ms/step\n",
329
+ "1/1 [==============================] - 0s 18ms/step\n",
330
+ "1/1 [==============================] - 0s 21ms/step\n",
331
+ "1/1 [==============================] - 0s 18ms/step\n",
332
+ "1/1 [==============================] - 0s 17ms/step\n",
333
+ "1/1 [==============================] - 0s 17ms/step\n",
334
+ "1/1 [==============================] - 0s 17ms/step\n",
335
+ "1/1 [==============================] - 0s 18ms/step\n",
336
+ "1/1 [==============================] - 0s 17ms/step\n",
337
+ "1/1 [==============================] - 0s 17ms/step\n",
338
+ "1/1 [==============================] - 0s 18ms/step\n",
339
+ "1/1 [==============================] - 0s 17ms/step\n",
340
+ "1/1 [==============================] - 0s 18ms/step\n",
341
+ "1/1 [==============================] - 0s 18ms/step\n",
342
+ "1/1 [==============================] - 0s 18ms/step\n",
343
+ "1/1 [==============================] - 0s 18ms/step\n",
344
+ "1/1 [==============================] - 0s 17ms/step\n",
345
+ "1/1 [==============================] - 0s 17ms/step\n",
346
+ "1/1 [==============================] - 0s 17ms/step\n",
347
+ "1/1 [==============================] - 0s 18ms/step\n",
348
+ "1/1 [==============================] - 0s 19ms/step\n",
349
+ "1/1 [==============================] - 0s 19ms/step\n",
350
+ "1/1 [==============================] - 0s 19ms/step\n",
351
+ "1/1 [==============================] - 0s 18ms/step\n",
352
+ "1/1 [==============================] - 0s 18ms/step\n",
353
+ "1/1 [==============================] - 0s 24ms/step\n",
354
+ "1/1 [==============================] - 0s 18ms/step\n",
355
+ "1/1 [==============================] - 0s 18ms/step\n",
356
+ "1/1 [==============================] - 0s 18ms/step\n",
357
+ "1/1 [==============================] - 0s 17ms/step\n",
358
+ "1/1 [==============================] - 0s 17ms/step\n",
359
+ "1/1 [==============================] - 0s 18ms/step\n",
360
+ "1/1 [==============================] - 0s 19ms/step\n",
361
+ "1/1 [==============================] - 0s 20ms/step\n",
362
+ "1/1 [==============================] - 0s 17ms/step\n",
363
+ "1/1 [==============================] - 0s 17ms/step\n",
364
+ "1/1 [==============================] - 0s 17ms/step\n",
365
+ "1/1 [==============================] - 0s 17ms/step\n",
366
+ "1/1 [==============================] - 0s 18ms/step\n",
367
+ "1/1 [==============================] - 0s 19ms/step\n",
368
+ "1/1 [==============================] - 0s 20ms/step\n",
369
+ "1/1 [==============================] - 0s 18ms/step\n",
370
+ "1/1 [==============================] - 0s 24ms/step\n",
371
+ "1/1 [==============================] - 0s 17ms/step\n",
372
+ "1/1 [==============================] - 0s 17ms/step\n",
373
+ "1/1 [==============================] - 0s 17ms/step\n",
374
+ "1/1 [==============================] - 0s 18ms/step\n",
375
+ "1/1 [==============================] - 0s 17ms/step\n",
376
+ "1/1 [==============================] - 0s 18ms/step\n",
377
+ "1/1 [==============================] - 0s 18ms/step\n",
378
+ "1/1 [==============================] - 0s 17ms/step\n",
379
+ "1/1 [==============================] - 0s 16ms/step\n",
380
+ "1/1 [==============================] - 0s 17ms/step\n",
381
+ "1/1 [==============================] - 0s 19ms/step\n",
382
+ "1/1 [==============================] - 0s 18ms/step\n",
383
+ "1/1 [==============================] - 0s 18ms/step\n",
384
+ "1/1 [==============================] - 0s 21ms/step\n",
385
+ "1/1 [==============================] - 0s 18ms/step\n",
386
+ "1/1 [==============================] - 0s 17ms/step\n",
387
+ "1/1 [==============================] - 0s 17ms/step\n",
388
+ "1/1 [==============================] - 0s 17ms/step\n",
389
+ "1/1 [==============================] - 0s 17ms/step\n",
390
+ "1/1 [==============================] - 0s 18ms/step\n",
391
+ "1/1 [==============================] - 0s 18ms/step\n",
392
+ "1/1 [==============================] - 0s 18ms/step\n",
393
+ "1/1 [==============================] - 0s 17ms/step\n",
394
+ "1/1 [==============================] - 0s 18ms/step\n",
395
+ "1/1 [==============================] - 0s 18ms/step\n",
396
+ "1/1 [==============================] - 0s 17ms/step\n",
397
+ "1/1 [==============================] - 0s 17ms/step\n",
398
+ "1/1 [==============================] - 0s 18ms/step\n",
399
+ "1/1 [==============================] - 0s 17ms/step\n",
400
+ "1/1 [==============================] - 0s 18ms/step\n",
401
+ "1/1 [==============================] - 0s 19ms/step\n",
402
+ "1/1 [==============================] - 0s 24ms/step\n",
403
+ "1/1 [==============================] - 0s 26ms/step\n",
404
+ "1/1 [==============================] - 0s 18ms/step\n",
405
+ "1/1 [==============================] - 0s 18ms/step\n",
406
+ "1/1 [==============================] - 0s 16ms/step\n",
407
+ "1/1 [==============================] - 0s 17ms/step\n",
408
+ "1/1 [==============================] - 0s 18ms/step\n",
409
+ "1/1 [==============================] - 0s 18ms/step\n",
410
+ "1/1 [==============================] - 0s 19ms/step\n",
411
+ "1/1 [==============================] - 0s 17ms/step\n",
412
+ "1/1 [==============================] - 0s 16ms/step\n",
413
+ "1/1 [==============================] - 0s 16ms/step\n",
414
+ "1/1 [==============================] - 0s 18ms/step\n",
415
+ "1/1 [==============================] - 0s 18ms/step\n",
416
+ "1/1 [==============================] - 0s 17ms/step\n",
417
+ "1/1 [==============================] - 0s 19ms/step\n",
418
+ "1/1 [==============================] - 0s 17ms/step\n",
419
+ "1/1 [==============================] - 0s 19ms/step\n",
420
+ "1/1 [==============================] - 0s 20ms/step\n",
421
+ "1/1 [==============================] - 0s 18ms/step\n",
422
+ "1/1 [==============================] - 0s 17ms/step\n",
423
+ "1/1 [==============================] - 0s 17ms/step\n",
424
+ "1/1 [==============================] - 0s 17ms/step\n",
425
+ "1/1 [==============================] - 0s 17ms/step\n",
426
+ "1/1 [==============================] - 0s 18ms/step\n",
427
+ "1/1 [==============================] - 0s 18ms/step\n",
428
+ "1/1 [==============================] - 0s 18ms/step\n",
429
+ "1/1 [==============================] - 0s 17ms/step\n",
430
+ "1/1 [==============================] - 0s 18ms/step\n",
431
+ "1/1 [==============================] - 0s 17ms/step\n",
432
+ "1/1 [==============================] - 0s 18ms/step\n",
433
+ "1/1 [==============================] - 0s 19ms/step\n",
434
+ "1/1 [==============================] - 0s 18ms/step\n",
435
+ "1/1 [==============================] - 0s 21ms/step\n",
436
+ "1/1 [==============================] - 0s 25ms/step\n",
437
+ "1/1 [==============================] - 0s 18ms/step\n",
438
+ "1/1 [==============================] - 0s 18ms/step\n",
439
+ "1/1 [==============================] - 0s 18ms/step\n",
440
+ "1/1 [==============================] - 0s 17ms/step\n",
441
+ "1/1 [==============================] - 0s 18ms/step\n",
442
+ "1/1 [==============================] - 0s 18ms/step\n",
443
+ "1/1 [==============================] - 0s 17ms/step\n",
444
+ "1/1 [==============================] - 0s 18ms/step\n",
445
+ "1/1 [==============================] - 0s 17ms/step\n",
446
+ "1/1 [==============================] - 0s 17ms/step\n",
447
+ "1/1 [==============================] - 0s 18ms/step\n",
448
+ "1/1 [==============================] - 0s 19ms/step\n",
449
+ "1/1 [==============================] - 0s 20ms/step\n",
450
+ "1/1 [==============================] - 0s 17ms/step\n",
451
+ "1/1 [==============================] - 0s 17ms/step\n",
452
+ "1/1 [==============================] - 0s 18ms/step\n",
453
+ "1/1 [==============================] - 0s 20ms/step\n",
454
+ "1/1 [==============================] - 0s 18ms/step\n",
455
+ "1/1 [==============================] - 0s 18ms/step\n",
456
+ "1/1 [==============================] - 0s 17ms/step\n",
457
+ "1/1 [==============================] - 0s 18ms/step\n",
458
+ "1/1 [==============================] - 0s 27ms/step\n",
459
+ "1/1 [==============================] - 0s 30ms/step\n",
460
+ "1/1 [==============================] - 0s 31ms/step\n",
461
+ "1/1 [==============================] - 0s 26ms/step\n",
462
+ "1/1 [==============================] - 0s 35ms/step\n",
463
+ "1/1 [==============================] - 0s 28ms/step\n",
464
+ "1/1 [==============================] - 0s 26ms/step\n",
465
+ "1/1 [==============================] - 0s 25ms/step\n",
466
+ "1/1 [==============================] - 0s 28ms/step\n",
467
+ "1/1 [==============================] - 0s 25ms/step\n",
468
+ "1/1 [==============================] - 0s 30ms/step\n",
469
+ "1/1 [==============================] - 0s 26ms/step\n",
470
+ "1/1 [==============================] - 0s 27ms/step\n",
471
+ "1/1 [==============================] - 0s 26ms/step\n",
472
+ "1/1 [==============================] - 0s 28ms/step\n",
473
+ "1/1 [==============================] - 0s 32ms/step\n",
474
+ "1/1 [==============================] - 0s 29ms/step\n",
475
+ "1/1 [==============================] - 0s 26ms/step\n",
476
+ "1/1 [==============================] - 0s 34ms/step\n",
477
+ "1/1 [==============================] - 0s 30ms/step\n",
478
+ "1/1 [==============================] - 0s 27ms/step\n",
479
+ "1/1 [==============================] - 0s 26ms/step\n",
480
+ "1/1 [==============================] - 0s 35ms/step\n",
481
+ "1/1 [==============================] - 0s 30ms/step\n",
482
+ "1/1 [==============================] - 0s 19ms/step\n",
483
+ "1/1 [==============================] - 0s 19ms/step\n",
484
+ "1/1 [==============================] - 0s 17ms/step\n",
485
+ "1/1 [==============================] - 0s 17ms/step\n",
486
+ "1/1 [==============================] - 0s 17ms/step\n",
487
+ "1/1 [==============================] - 0s 21ms/step\n",
488
+ "1/1 [==============================] - 0s 19ms/step\n",
489
+ "1/1 [==============================] - 0s 17ms/step\n",
490
+ "1/1 [==============================] - 0s 17ms/step\n",
491
+ "1/1 [==============================] - 0s 17ms/step\n",
492
+ "1/1 [==============================] - 0s 18ms/step\n",
493
+ "1/1 [==============================] - 0s 17ms/step\n",
494
+ "1/1 [==============================] - 0s 17ms/step\n",
495
+ "1/1 [==============================] - 0s 17ms/step\n",
496
+ "1/1 [==============================] - 0s 17ms/step\n",
497
+ "1/1 [==============================] - 0s 18ms/step\n",
498
+ "1/1 [==============================] - 0s 18ms/step\n",
499
+ "1/1 [==============================] - 0s 20ms/step\n",
500
+ "1/1 [==============================] - 0s 18ms/step\n",
501
+ "1/1 [==============================] - 0s 17ms/step\n",
502
+ "1/1 [==============================] - 0s 20ms/step\n",
503
+ "1/1 [==============================] - 0s 19ms/step\n",
504
+ "1/1 [==============================] - 0s 18ms/step\n",
505
+ "1/1 [==============================] - 0s 18ms/step\n",
506
+ "1/1 [==============================] - 0s 17ms/step\n",
507
+ "1/1 [==============================] - 0s 17ms/step\n",
508
+ "1/1 [==============================] - 0s 17ms/step\n",
509
+ "1/1 [==============================] - 0s 17ms/step\n",
510
+ "1/1 [==============================] - 0s 17ms/step\n",
511
+ "1/1 [==============================] - 0s 17ms/step\n",
512
+ "1/1 [==============================] - 0s 17ms/step\n",
513
+ "1/1 [==============================] - 0s 17ms/step\n",
514
+ "1/1 [==============================] - 0s 20ms/step\n",
515
+ "1/1 [==============================] - 0s 18ms/step\n",
516
+ "1/1 [==============================] - 0s 17ms/step\n",
517
+ "1/1 [==============================] - 0s 18ms/step\n",
518
+ "1/1 [==============================] - 0s 19ms/step\n",
519
+ "1/1 [==============================] - 0s 18ms/step\n",
520
+ "1/1 [==============================] - 0s 18ms/step\n",
521
+ "1/1 [==============================] - 0s 17ms/step\n",
522
+ "1/1 [==============================] - 0s 16ms/step\n",
523
+ "1/1 [==============================] - 0s 18ms/step\n",
524
+ "1/1 [==============================] - 0s 18ms/step\n",
525
+ "1/1 [==============================] - 0s 18ms/step\n",
526
+ "1/1 [==============================] - 0s 19ms/step\n",
527
+ "1/1 [==============================] - 0s 19ms/step\n",
528
+ "1/1 [==============================] - 0s 17ms/step\n",
529
+ "1/1 [==============================] - 0s 19ms/step\n",
530
+ "1/1 [==============================] - 0s 20ms/step\n",
531
+ "1/1 [==============================] - 0s 19ms/step\n",
532
+ "1/1 [==============================] - 0s 19ms/step\n",
533
+ "1/1 [==============================] - 0s 17ms/step\n",
534
+ "1/1 [==============================] - 0s 18ms/step\n",
535
+ "1/1 [==============================] - 0s 18ms/step\n",
536
+ "1/1 [==============================] - 0s 19ms/step\n",
537
+ "1/1 [==============================] - 0s 18ms/step\n",
538
+ "1/1 [==============================] - 0s 17ms/step\n",
539
+ "1/1 [==============================] - 0s 17ms/step\n",
540
+ "1/1 [==============================] - 0s 17ms/step\n",
541
+ "1/1 [==============================] - 0s 17ms/step\n",
542
+ "1/1 [==============================] - 0s 17ms/step\n",
543
+ "1/1 [==============================] - 0s 17ms/step\n",
544
+ "1/1 [==============================] - 0s 18ms/step\n",
545
+ "1/1 [==============================] - 0s 18ms/step\n",
546
+ "1/1 [==============================] - 0s 18ms/step\n",
547
+ "1/1 [==============================] - 0s 19ms/step\n",
548
+ "1/1 [==============================] - 0s 18ms/step\n",
549
+ "1/1 [==============================] - 0s 21ms/step\n",
550
+ "1/1 [==============================] - 0s 17ms/step\n",
551
+ "1/1 [==============================] - 0s 19ms/step\n",
552
+ "1/1 [==============================] - 0s 23ms/step\n",
553
+ "1/1 [==============================] - 0s 18ms/step\n",
554
+ "1/1 [==============================] - 0s 17ms/step\n",
555
+ "1/1 [==============================] - 0s 18ms/step\n",
556
+ "1/1 [==============================] - 0s 18ms/step\n",
557
+ "1/1 [==============================] - 0s 18ms/step\n",
558
+ "1/1 [==============================] - 0s 17ms/step\n",
559
+ "1/1 [==============================] - 0s 17ms/step\n",
560
+ "1/1 [==============================] - 0s 17ms/step\n",
561
+ "1/1 [==============================] - 0s 17ms/step\n",
562
+ "1/1 [==============================] - 0s 17ms/step\n",
563
+ "1/1 [==============================] - 0s 22ms/step\n",
564
+ "1/1 [==============================] - 0s 18ms/step\n",
565
+ "1/1 [==============================] - 0s 18ms/step\n",
566
+ "1/1 [==============================] - 0s 18ms/step\n",
567
+ "1/1 [==============================] - 0s 18ms/step\n",
568
+ "1/1 [==============================] - 0s 17ms/step\n",
569
+ "1/1 [==============================] - 0s 17ms/step\n",
570
+ "1/1 [==============================] - 0s 22ms/step\n",
571
+ "1/1 [==============================] - 0s 17ms/step\n",
572
+ "1/1 [==============================] - 0s 17ms/step\n",
573
+ "1/1 [==============================] - 0s 17ms/step\n",
574
+ "1/1 [==============================] - 0s 18ms/step\n",
575
+ "1/1 [==============================] - 0s 17ms/step\n",
576
+ "1/1 [==============================] - 0s 17ms/step\n",
577
+ "1/1 [==============================] - 0s 23ms/step\n",
578
+ "1/1 [==============================] - 0s 20ms/step\n",
579
+ "1/1 [==============================] - 0s 17ms/step\n",
580
+ "1/1 [==============================] - 0s 17ms/step\n",
581
+ "1/1 [==============================] - 0s 17ms/step\n",
582
+ "1/1 [==============================] - 0s 17ms/step\n",
583
+ "1/1 [==============================] - 0s 18ms/step\n",
584
+ "1/1 [==============================] - 0s 21ms/step\n",
585
+ "1/1 [==============================] - 0s 28ms/step\n",
586
+ "1/1 [==============================] - 0s 18ms/step\n",
587
+ "1/1 [==============================] - 0s 20ms/step\n",
588
+ "1/1 [==============================] - 0s 20ms/step\n",
589
+ "1/1 [==============================] - 0s 20ms/step\n",
590
+ "1/1 [==============================] - 0s 19ms/step\n",
591
+ "1/1 [==============================] - 0s 18ms/step\n",
592
+ "1/1 [==============================] - 0s 19ms/step\n",
593
+ "1/1 [==============================] - 0s 18ms/step\n",
594
+ "1/1 [==============================] - 0s 17ms/step\n",
595
+ "1/1 [==============================] - 0s 18ms/step\n",
596
+ "1/1 [==============================] - 0s 18ms/step\n",
597
+ "1/1 [==============================] - 0s 17ms/step\n",
598
+ "1/1 [==============================] - 0s 17ms/step\n",
599
+ "1/1 [==============================] - 0s 18ms/step\n",
600
+ "1/1 [==============================] - 0s 18ms/step\n",
601
+ "1/1 [==============================] - 0s 24ms/step\n",
602
+ "1/1 [==============================] - 0s 17ms/step\n",
603
+ "1/1 [==============================] - 0s 17ms/step\n",
604
+ "1/1 [==============================] - 0s 17ms/step\n",
605
+ "1/1 [==============================] - 0s 19ms/step\n",
606
+ "1/1 [==============================] - 0s 17ms/step\n",
607
+ "1/1 [==============================] - 0s 17ms/step\n",
608
+ "1/1 [==============================] - 0s 20ms/step\n",
609
+ "1/1 [==============================] - 0s 17ms/step\n",
610
+ "1/1 [==============================] - 0s 20ms/step\n",
611
+ "1/1 [==============================] - 0s 18ms/step\n",
612
+ "1/1 [==============================] - 0s 18ms/step\n",
613
+ "1/1 [==============================] - 0s 18ms/step\n",
614
+ "1/1 [==============================] - 0s 17ms/step\n",
615
+ "1/1 [==============================] - 0s 17ms/step\n",
616
+ "1/1 [==============================] - 0s 16ms/step\n",
617
+ "1/1 [==============================] - 0s 17ms/step\n",
618
+ "1/1 [==============================] - 0s 23ms/step\n",
619
+ "1/1 [==============================] - 0s 17ms/step\n",
620
+ "1/1 [==============================] - 0s 18ms/step\n",
621
+ "1/1 [==============================] - 0s 17ms/step\n",
622
+ "1/1 [==============================] - 0s 17ms/step\n",
623
+ "1/1 [==============================] - 0s 17ms/step\n",
624
+ "1/1 [==============================] - 0s 17ms/step\n",
625
+ "1/1 [==============================] - 0s 17ms/step\n",
626
+ "1/1 [==============================] - 0s 17ms/step\n",
627
+ "1/1 [==============================] - 0s 17ms/step\n",
628
+ "1/1 [==============================] - 0s 17ms/step\n",
629
+ "1/1 [==============================] - 0s 17ms/step\n",
630
+ "1/1 [==============================] - 0s 17ms/step\n",
631
+ "1/1 [==============================] - 0s 20ms/step\n",
632
+ "1/1 [==============================] - 0s 16ms/step\n",
633
+ "1/1 [==============================] - 0s 16ms/step\n",
634
+ "1/1 [==============================] - 0s 17ms/step\n",
635
+ "1/1 [==============================] - 0s 18ms/step\n",
636
+ "1/1 [==============================] - 0s 17ms/step\n",
637
+ "1/1 [==============================] - 0s 17ms/step\n",
638
+ "1/1 [==============================] - 0s 18ms/step\n",
639
+ "1/1 [==============================] - 0s 17ms/step\n",
640
+ "1/1 [==============================] - 0s 18ms/step\n",
641
+ "1/1 [==============================] - 0s 18ms/step\n",
642
+ "1/1 [==============================] - 0s 17ms/step\n",
643
+ "1/1 [==============================] - 0s 30ms/step\n",
644
+ "1/1 [==============================] - 0s 28ms/step\n",
645
+ "1/1 [==============================] - 0s 29ms/step\n",
646
+ "1/1 [==============================] - 0s 38ms/step\n",
647
+ "1/1 [==============================] - 0s 30ms/step\n",
648
+ "1/1 [==============================] - 0s 31ms/step\n",
649
+ "1/1 [==============================] - 0s 27ms/step\n",
650
+ "1/1 [==============================] - 0s 24ms/step\n",
651
+ "1/1 [==============================] - 0s 24ms/step\n",
652
+ "1/1 [==============================] - 0s 25ms/step\n",
653
+ "1/1 [==============================] - 0s 25ms/step\n",
654
+ "1/1 [==============================] - 0s 25ms/step\n",
655
+ "1/1 [==============================] - 0s 30ms/step\n",
656
+ "1/1 [==============================] - 0s 25ms/step\n",
657
+ "1/1 [==============================] - 0s 25ms/step\n",
658
+ "1/1 [==============================] - 0s 26ms/step\n",
659
+ "1/1 [==============================] - 0s 26ms/step\n",
660
+ "1/1 [==============================] - 0s 36ms/step\n",
661
+ "1/1 [==============================] - 0s 28ms/step\n",
662
+ "1/1 [==============================] - 0s 31ms/step\n",
663
+ "1/1 [==============================] - 0s 29ms/step\n",
664
+ "1/1 [==============================] - 0s 29ms/step\n",
665
+ "1/1 [==============================] - 0s 28ms/step\n",
666
+ "1/1 [==============================] - 0s 28ms/step\n",
667
+ "1/1 [==============================] - 0s 27ms/step\n",
668
+ "1/1 [==============================] - 0s 18ms/step\n",
669
+ "1/1 [==============================] - 0s 18ms/step\n",
670
+ "1/1 [==============================] - 0s 22ms/step\n",
671
+ "1/1 [==============================] - 0s 18ms/step\n",
672
+ "1/1 [==============================] - 0s 17ms/step\n",
673
+ "1/1 [==============================] - 0s 19ms/step\n",
674
+ "1/1 [==============================] - 0s 18ms/step\n",
675
+ "1/1 [==============================] - 0s 19ms/step\n",
676
+ "1/1 [==============================] - 0s 19ms/step\n",
677
+ "1/1 [==============================] - 0s 19ms/step\n",
678
+ "1/1 [==============================] - 0s 18ms/step\n",
679
+ "1/1 [==============================] - 0s 16ms/step\n",
680
+ "1/1 [==============================] - 0s 18ms/step\n",
681
+ "1/1 [==============================] - 0s 17ms/step\n",
682
+ "1/1 [==============================] - 0s 17ms/step\n",
683
+ "1/1 [==============================] - 0s 17ms/step\n",
684
+ "1/1 [==============================] - 0s 17ms/step\n",
685
+ "1/1 [==============================] - 0s 17ms/step\n",
686
+ "1/1 [==============================] - 0s 19ms/step\n",
687
+ "1/1 [==============================] - 0s 16ms/step\n",
688
+ "1/1 [==============================] - 0s 16ms/step\n",
689
+ "1/1 [==============================] - 0s 16ms/step\n",
690
+ "1/1 [==============================] - 0s 19ms/step\n",
691
+ "1/1 [==============================] - 0s 25ms/step\n",
692
+ "Accuracy: 0.744\n"
693
+ ]
694
+ }
695
+ ]
696
+ }
697
+ ]
698
+ }
Experiments/quicknet_knn.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import pickle
4
+
5
def find_knn_and_plot_image(test_data, model, k_value: int = 10,
                            train_images=None, train_embeddings=None):
    """Return the k training images nearest to a test sample in embedding space.

    The test sample is embedded with ``model.predict`` and compared (Euclidean
    distance) against the stored training embeddings; the images of the
    ``k_value`` closest neighbours are returned, nearest first.

    Parameters
    ----------
    test_data : array-like
        A single image/sample (no batch dimension).
    model : object
        Embedding model exposing ``predict(batch)`` (Keras-style).
    k_value : int, optional
        Number of neighbours to return (default 10).
    train_images : sequence, optional
        Images aligned index-for-index with the training embeddings.  The
        original code referenced an undefined global ``train_image`` here,
        which raised NameError; passing the images explicitly fixes that.
        If omitted, the module-level ``train_image`` is still consulted for
        backward compatibility.
    train_embeddings : array-like, optional
        Precomputed training embeddings.  If omitted, they are loaded from
        the pickled ResNet-18 embeddings file.

    Returns
    -------
    list
        The ``k_value`` nearest training images, closest first.
    """
    # model.predict expects a batch, so add a leading batch dimension.
    point = np.expand_dims(np.array(test_data), axis=0)
    test_embedding = model.predict(point)

    if train_embeddings is None:
        # NOTE(review): backslash literal is Windows-only — verify this path
        # on other platforms.
        with open(r'Model\data\train_embeddings_resnet18.pkl', 'rb') as f:
            train_embeddings = pickle.load(f)

    train_data = np.array(train_embeddings)

    # Euclidean distance from the test embedding to every training embedding.
    distance_with_index = [
        (i, np.linalg.norm(test_embedding - np.array(emb)))
        for i, emb in enumerate(train_data)
    ]
    # Nearest neighbours first.
    distance_with_index.sort(key=lambda pair: pair[1])
    k_nearest_points = distance_with_index[:k_value]

    if train_images is None:
        # Backward-compatible fallback to the module-level global the
        # original code relied on (undefined unless the caller sets it).
        train_images = train_image  # noqa: F821

    return [train_images[index] for index, _ in k_nearest_points]
Experiments/train_resnet50_3hll.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from tqdm import tqdm
3
+ import wandb
4
+ import torch.nn as nn
5
+ import torchvision.transforms as v2
6
+ from torch.utils.data import DataLoader
7
+ import torchvision.datasets as datasets
8
+
9
wandb.login()

# One W&B run for the whole training script.  (The original script called
# wandb.init() a second time further down; a single call is sufficient.)
wandb.init(project="CIFAR", entity="gjyotin305")

# NOTE(review): the name says "mps" but the device is CUDA — kept unchanged
# because the rest of the module references `mps_device`.
mps_device = torch.device("cuda")


IMAGE_SIZE = 224
# CIFAR-10 per-channel mean/std used for input normalisation.
mean, std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]

# Deterministic preprocessing for evaluation: resize + normalise only.
transform_test = v2.Compose(
    [v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
     v2.ToTensor(),
     v2.Normalize(mean, std)])

# Training-time augmentation: mild geometric/photometric jitter plus
# random erasing (applied after normalisation, on the tensor).
transform_train = v2.Compose(
    [v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
     v2.RandomRotation(20),
     v2.RandomHorizontalFlip(0.1),
     v2.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
     v2.RandomAdjustSharpness(sharpness_factor=2, p=0.1),
     v2.ToTensor(),
     v2.Normalize(mean, std),
     v2.RandomErasing(p=0.75, scale=(0.02, 0.1), value=1.0, inplace=False)])

cifar_trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
cifar_testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

batch_size = 64

trainLoader = DataLoader(cifar_trainset, batch_size=batch_size, shuffle=True, num_workers=8)
testLoader = DataLoader(cifar_testset, batch_size=batch_size, shuffle=True, num_workers=8)

# Human-readable CIFAR-10 class names (index-aligned with dataset labels).
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Sanity check: one batch should be (batch_size, 3, IMAGE_SIZE, IMAGE_SIZE).
print(next(iter(trainLoader))[0].shape)


# (The duplicate `import torchvision.datasets as datasets` that was here
# has been removed; it is already imported at the top of the file.)
import torchvision.models as models

# ImageNet-pretrained ResNet-50 backbone.  `pretrained=True` is deprecated
# in newer torchvision in favour of `weights=`; kept for compatibility with
# the torchvision version this script was written against.
resnet = models.resnet50(pretrained=True)
51
+
52
+
53
class JHARMNet(nn.Module):
    """Pretrained backbone with its final FC head swapped for a fresh one.

    Replaces ``pretrained_model.fc`` (assumed 2048-in, ResNet-50 style) with
    a newly Xavier-initialised ``nn.Linear(2048, num_classes)`` and forwards
    straight through the modified backbone.
    """

    def __init__(self, pretrained_model, num_classes):
        super().__init__()
        self.premodel = pretrained_model
        # Re-head the backbone for the target class count; Xavier init on
        # the new weight matrix.
        self.premodel.fc = nn.Linear(2048, num_classes)
        nn.init.xavier_uniform_(self.premodel.fc.weight)

    def forward(self, x):
        # Delegate entirely to the (re-headed) backbone.
        return self.premodel(x)
63
+
64
class HiddenLayer(nn.Module):
    """Pretrained backbone followed by a 3-hidden-layer MLP classifier.

    The head maps the backbone's 1000-way output down to 10 CIFAR classes
    through three LeakyReLU hidden layers (1000 -> 512 -> 512 -> 256 -> 10).
    """

    def __init__(self, pretrained_model):
        super().__init__()
        self.premodel = pretrained_model
        # MLP head stacked on top of the backbone's 1000-dim logits.
        self.new_layer = nn.Sequential(
            nn.Linear(1000, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        # Backbone features, then the MLP head, in one expression.
        return self.new_layer(self.premodel(x))
82
+
83
# Transfer-learning model: ResNet-50 backbone + 3-hidden-layer MLP head.
model_check = HiddenLayer(resnet)
# Multi-class classification loss; expects raw (unsoftmaxed) logits.
criterion = nn.CrossEntropyLoss()
learning_rate=0.001

# Plain SGD with light momentum over ALL parameters (backbone is fine-tuned
# too, not frozen).
optimizer = torch.optim.SGD(model_check.parameters(),lr=learning_rate, momentum=0.2)

# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
90
+
91
def train_model(model, train_loader, validation_loader, optimizer, n_epochs = 20):
    """Train ``model`` and validate it once per epoch.

    Uses the module-level ``criterion``, ``mps_device`` and ``wandb`` run.
    Checkpoints the state_dict every 5 epochs.

    Parameters
    ----------
    model : nn.Module
        Model to train (moved to ``mps_device`` here).
    train_loader, validation_loader : DataLoader
        Training / validation batches.
    optimizer : torch.optim.Optimizer
        Optimizer over ``model``'s parameters.
    n_epochs : int, optional
        Number of epochs (default 20).

    Returns
    -------
    tuple
        (accuracy_list, train_cost_list, val_cost_list, model).
    """
    # Size the accuracy denominator from the loader actually passed in.
    # (The original read the module-level `cifar_testset`, which silently
    # produced wrong accuracies for any other validation loader.)
    N_test = len(validation_loader.dataset)
    accuracy_list = []
    train_cost_list = []
    val_cost_list = []
    model = model.to(mps_device)

    for epoch in range(n_epochs):
        train_COST = 0
        print(f"Training Epoch: {epoch+1}")
        # Set the mode once per phase instead of once per batch.
        model.train()
        for x, y in tqdm(train_loader):
            x = x.to(mps_device)
            y = y.to(mps_device)
            optimizer.zero_grad()
            z = model(x)
            loss = criterion(z, y)
            loss.backward()
            wandb.log({"train_loss": loss.item()})
            optimizer.step()
            train_COST += loss.item()

        train_COST = train_COST / len(train_loader)
        train_cost_list.append(train_COST)

        correct = 0
        val_COST = 0
        print(f"Validation Loop")
        model.eval()
        # no_grad: skip autograd bookkeeping during evaluation — identical
        # numbers, much less memory (the original built a graph per batch).
        with torch.no_grad():
            for x_test, y_test in tqdm(validation_loader):
                x_test = x_test.to(mps_device)
                y_test = y_test.to(mps_device)
                z = model(x_test)
                val_loss = criterion(z, y_test)
                # Log a Python float, not a live tensor.
                wandb.log({"val_loss": val_loss.item()})
                _, yhat = torch.max(z.data, 1)
                correct += (yhat == y_test).sum().item()
                val_COST += val_loss.item()

        val_COST = val_COST / len(validation_loader)
        val_cost_list.append(val_COST)
        accuracy = correct / N_test
        accuracy_list.append(accuracy)
        wandb.log({"val_acc": accuracy})

        # Periodic checkpoint (state_dict only) every 5 epochs.
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), f"/root/test_cifar/log_model/CIFAR_{epoch+1}_end_hll.pt")
            print(f"Saved Model as CIFAR_{epoch+1}_end_hll.pt")

        print("--> Epoch Number : {}".format(epoch + 1),
              " | Training Loss : {}".format(round(train_COST, 4)),
              " | Validation Loss : {}".format(round(val_COST, 4)),
              " | Validation Accuracy : {}%".format(round(accuracy * 100, 2)))

    return accuracy_list, train_cost_list, val_cost_list, model
149
+
150
+
151
+
152
# Run the full fine-tuning job and keep the per-epoch curves.
accuracy_list_normalv5, train_cost_listv5, val_cost_listv5, model_to_save = train_model(
    model=model_check,
    n_epochs=100,
    train_loader=trainLoader,
    validation_loader=testLoader,
    optimizer=optimizer)

# Persist only the weights (state_dict), not the pickled module.
torch.save(model_to_save.state_dict(), "CIFAR_end_hll.pt")

# Because a state_dict was saved, torch.load returns an OrderedDict of
# tensors — NOT a module.  The original code called .eval() directly on
# that dict, which raises AttributeError.  Reload the weights into the
# existing module instead to verify the checkpoint round-trips.
model = model_check
model.load_state_dict(torch.load("CIFAR_end_hll.pt"))
model.eval()

# Smoke-test one batch.  The model lives on `mps_device` after training,
# so the inputs must be moved there too; the label printed alongside comes
# from the SAME batch (the original drew a fresh batch for the label).
sample_x, sample_y = next(iter(testLoader))
with torch.no_grad():
    pred = model(sample_x.to(mps_device))
print(pred[0][0], sample_y[0])
Experiments/train_resnet50_iter_1.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision.models as models
2
+ import torch
3
+ from tqdm import tqdm
4
+ import torch.nn as nn
5
+ import torchvision.transforms as v2
6
+ from torch.utils.data import DataLoader
7
+ import torchvision.datasets as datasets
8
+
9
+
10
+ mps_device = torch.device("cuda")
11
+
12
+ IMAGE_SIZE = 224
13
+ mean, std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]
14
+
15
+ transform_test = v2.Compose(
16
+ [v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
17
+ v2.ToTensor(),
18
+ v2.Normalize(mean, std)])
19
+
20
+ transform_train = v2.Compose(
21
+ [v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
22
+ v2.RandomRotation(20),
23
+ v2.RandomHorizontalFlip(0.1),
24
+ v2.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
25
+ v2.RandomAdjustSharpness(sharpness_factor=2, p=0.1),
26
+ v2.ToTensor(),
27
+ v2.Normalize(mean, std),
28
+ v2.RandomErasing(p=0.75, scale=(0.02, 0.1), value=1.0, inplace=False)])
29
+
30
+ cifar_trainset = datasets.CIFAR10(
31
+ root='./data',
32
+ train=True,
33
+ download=True,
34
+ transform=transform_train)
35
+ cifar_testset = datasets.CIFAR10(
36
+ root='./data',
37
+ train=False,
38
+ download=True,
39
+ transform=transform_test)
40
+
41
+ batch_size = 64
42
+
43
+ trainLoader = DataLoader(
44
+ cifar_trainset,
45
+ batch_size=batch_size,
46
+ shuffle=True,
47
+ num_workers=8)
48
+ testLoader = DataLoader(
49
+ cifar_testset,
50
+ batch_size=batch_size,
51
+ shuffle=True,
52
+ num_workers=8)
53
+
54
+ classes = ('plane', 'car', 'bird', 'cat',
55
+ 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
56
+
57
+ print(next(iter(trainLoader))[0].shape)
58
+
59
+
60
+ resnet = models.resnet50(pretrained=True)
61
+ # print(resnet)
62
+ # print(resnet.fc)
63
+
64
+
65
+ class JHARMNet(nn.Module):
66
+ def __init__(self, pretrained_model, num_classes):
67
+ super().__init__()
68
+ self.premodel = pretrained_model
69
+ self.premodel.fc = nn.Linear(2048, num_classes)
70
+ nn.init.xavier_uniform_(self.premodel.fc.weight)
71
+
72
+ def forward(self, x):
73
+ out = self.premodel(x)
74
+ return out
75
+
76
+
77
+ model_check = JHARMNet(resnet, 10)
78
+ criterion = nn.CrossEntropyLoss()
79
+ learning_rate = 0.1
80
+
81
+ optimizer = torch.optim.SGD(
82
+ model_check.parameters(),
83
+ lr=learning_rate,
84
+ momentum=0.2)
85
+
86
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
87
+
88
+
89
+ def train_model(
90
+ model,
91
+ train_loader,
92
+ validation_loader,
93
+ optimizer,
94
+ n_epochs=20):
95
+
96
+ # Global variable
97
+ N_test = len(cifar_testset)
98
+ accuracy_list = []
99
+ train_loss_list = []
100
+ model = model.to(mps_device)
101
+ train_cost_list = []
102
+ val_cost_list = []
103
+
104
+ for epoch in range(n_epochs):
105
+ train_COST = 0
106
+ print(f"Training Epoch: {epoch}")
107
+ for x, y in tqdm(train_loader):
108
+ x = x.to(mps_device)
109
+ y = y.to(mps_device)
110
+ model.train()
111
+ optimizer.zero_grad()
112
+ z = model(x)
113
+ loss = criterion(z, y)
114
+ loss.backward()
115
+ optimizer.step()
116
+ train_COST += loss.item()
117
+
118
+ train_COST = train_COST / len(train_loader)
119
+ train_cost_list.append(train_COST)
120
+ correct = 0
121
+ print(f"Validation Loop")
122
+ # Perform the prediction on the validation data
123
+ val_COST = 0
124
+ for x_test, y_test in tqdm(validation_loader):
125
+ model.eval()
126
+ x_test = x_test.to(mps_device)
127
+ y_test = y_test.to(mps_device)
128
+ z = model(x_test)
129
+ val_loss = criterion(z, y_test)
130
+ scheduler.step(val_loss)
131
+ _, yhat = torch.max(z.data, 1)
132
+ correct += (yhat == y_test).sum().item()
133
+ val_COST += val_loss.item()
134
+
135
+ val_COST = val_COST / len(validation_loader)
136
+ val_cost_list.append(val_COST)
137
+ accuracy = correct / N_test
138
+ accuracy_list.append(accuracy)
139
+
140
+ print("--> Epoch Number : {}".format(epoch + 1),
141
+ " | Training Loss : {}".format(round(train_COST, 4)),
142
+ " | Validation Loss : {}".format(round(val_COST, 4)),
143
+ " | Validation Accuracy : {}%".format(round(accuracy * 100, 2)))
144
+
145
+ return accuracy_list, train_cost_list, val_cost_list, model
146
+
147
+
148
+ accuracy_list_normalv5, train_cost_listv5, val_cost_listv5, model_to_save = train_model(
149
+ model=model_check, n_epochs=100, train_loader=trainLoader, validation_loader=testLoader, optimizer=optimizer)
150
+
151
+ torch.save(model_to_save.state_dict(), "CIFAR.pt")
152
+ model = torch.load("CIFAR.pt")
153
+ model.eval()
154
+
155
+ pred = model(next(iter(testLoader))[0])
156
+ print(pred[0][0], next(iter(testLoader))[1][0])
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Jyotin Goel
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
MidTerm_Report.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e171b9314127b3662ce813589398106b1c1dbabd996bc36aadc475cfda02ed8d
3
+ size 1110513
Model/CIFAR.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9fc375fa01813882af9e2a88654b765fb58bd8cd2f48731963ffa56fb1652434
3
+ size 94418218
Model/Resnet50_train_features.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50516ceab17cd7ca8eaa2738b30e201fad30a89972a256cc370a56236a70c427
3
+ size 2001185
Model/__pycache__/centroid_app.cpython-311.pyc ADDED
Binary file (5.62 kB). View file
 
Model/ann.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6375c0313b719e85a904bc2d733ee741f084599b3f5671b8dd4b978a5a4101b1
3
+ size 6832736
Model/centroid_app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import pickle
3
+ import random
4
+ import numpy as np
5
+ import cv2
6
+ from tensorflow.keras import models, layers
7
+
8
+ # Load the trained model architecture
9
+ def create_resnet18():
10
+ model = models.Sequential()
11
+ model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=(32, 32, 3)))
12
+ model.add(layers.MaxPooling2D((2, 2)))
13
+ model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))
14
+ model.add(layers.MaxPooling2D((2, 2)))
15
+ model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
16
+ model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
17
+ model.add(layers.MaxPooling2D((2, 2)))
18
+ model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))
19
+ model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))
20
+ model.add(layers.MaxPooling2D((2, 2)))
21
+ model.add(layers.Flatten())
22
+ model.add(layers.Dense(512, activation='relu'))
23
+ model.add(layers.Dense(10, activation='softmax'))
24
+ return model
25
+
26
+ # Load the pretrained weights
27
+ def load_pretrained_weights(model, weights_path):
28
+ model.load_weights(weights_path)
29
+
30
+ # Function to unpickle a file
31
+ def unpickle(file):
32
+ with open(file, 'rb') as fo:
33
+ dict = pickle.load(fo, encoding='bytes')
34
+ return dict
35
+
36
+ # Function to load images from the unpickled data batch file of a specific class
37
+ def load_class_images(class_index, train_batches):
38
+ images = []
39
+ for batch in train_batches:
40
+ if b'data' in batch and b'labels' in batch:
41
+ data = batch[b'data']
42
+ labels = batch[b'labels']
43
+ for i, label in enumerate(labels):
44
+ if label == class_index:
45
+ img = data[i].reshape(3, 32, 32).transpose(1, 2, 0) # Reshape and transpose the image
46
+ images.append(img)
47
+ return images
48
+
49
+ # Function to calculate the distance of the mean embeddings with a query image
50
+ def classify_query(query_image, model, mean_embeddings):
51
+ query_embedding = model.predict(np.expand_dims(query_image, axis=0))
52
+ distances = [np.linalg.norm(query_embedding.flatten() - mean_embedding) for mean_embedding in mean_embeddings]
53
+ predicted_class = np.argmin(distances)
54
+ return predicted_class
55
+
56
+
57
+ def retrieve(query_image,k=3):
58
+ model = create_resnet18()
59
+ load_pretrained_weights(model, 'Model\pretrained_model_weights.h5')
60
+
61
+ mean_embeddings = pickle.load(open('Model\data\mean_embeddings.pkl', 'rb'))
62
+
63
+ # query_image_path = '/content/airplane_8925.png'
64
+ # query_image = cv2.imread(query_image_path)
65
+ query_image = cv2.resize(query_image, (32, 32)) / 255.0 # Resize and normalize the image
66
+
67
+ predicted_class = classify_query(query_image, model, mean_embeddings)
68
+ # print("Predicted Class:", predicted_class)
69
+
70
+ # Load random images of the predicted class
71
+ train_batches = [unpickle(rf"Model\data\data_batch_{i}") for i in range(1,6)]
72
+ class_images = load_class_images(predicted_class, train_batches)
73
+
74
+ if class_images:
75
+ random_images = random.sample(class_images, k) # Select 3 random images
76
+ return random_images
77
+ else:
78
+ print("No images found for the predicted class.")
79
+
80
+
Model/data/batches.meta ADDED
Binary file (158 Bytes). View file
 
Model/data/data_batch_1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54636561a3ce25bd3e19253c6b0d8538147b0ae398331ac4a2d86c6d987368cd
3
+ size 31035704
Model/data/data_batch_2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:766b2cef9fbc745cf056b3152224f7cf77163b330ea9a15f9392beb8b89bc5a8
3
+ size 31035320
Model/data/data_batch_3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f00d98ebfb30b3ec0ad19f9756dc2630b89003e10525f5e148445e82aa6a1f9
3
+ size 31035999
Model/data/data_batch_4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f7bb240661948b8f4d53e36ec720d8306f5668bd0071dcb4e6c947f78e9682b
3
+ size 31035696
Model/data/data_batch_5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d91802434d8376bbaeeadf58a737e3a1b12ac839077e931237e0dcd43adcb154
3
+ size 31035623
Model/data/mean_embeddings.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:272f9a5f5bd64519f68c8ac5397a6d98a920175d15ad67e19af110f54ac9dfaf
3
+ size 952
Model/data/test_batch ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f53d8d457504f7cff4ea9e021afcf0e0ad8e24a91f3fc42091b8adef61157831
3
+ size 31035526
Model/data/train_embeddings_resnet18.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8238e4c8fb43aa0175def221191fe3c9f7e137683746e1f33086681488a266f4
3
+ size 2000162
Model/pretrained_model_weights.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3c60f7cc223c7256a4a3de446b2820d5cd05b552dcbe4a60fcea2f83465f1538
3
+ size 6735320
Model/resnet18.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d8415a29fcbd661a67bc004b26e9bcccc7a2133bffa99cd1dc1811c8d1a37dd5
3
+ size 20154704
Model/test.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import torch.nn as nn
4
+ import torchvision.datasets as datasets
5
+ import torchvision.transforms as transforms
6
+ import torchvision.models as models
7
+
8
+ os.system("gdown --id 1qO2OLR7skDibo1LaMKD3CiOl_jaCTZ0h")
9
+
10
+ class JHARMNet(nn.Module):
11
+ def __init__(self, pretrained_model, num_classes):
12
+ super().__init__()
13
+ self.premodel = pretrained_model
14
+ self.premodel.fc = nn.Linear(2048, num_classes)
15
+ nn.init.xavier_uniform_(self.premodel.fc.weight)
16
+
17
+ def forward(self, x):
18
+ out = self.premodel(x)
19
+ return out
20
+
21
+ class HiddenLayer(nn.Module):
22
+ def __init__(self, pretrained_model):
23
+ super().__init__()
24
+ self.premodel = pretrained_model
25
+ self.new_layer = nn.Sequential(
26
+ nn.Linear(1000, 512),
27
+ nn.LeakyReLU(),
28
+ nn.Linear(512, 512),
29
+ nn.LeakyReLU(),
30
+ nn.Linear(512, 256),
31
+ nn.LeakyReLU(),
32
+ nn.Linear(256, 10)
33
+ )
34
+
35
+ def forward(self, x):
36
+ out = self.premodel(x)
37
+ out_new_layer = self.new_layer(out)
38
+ return out_new_layer
39
+
40
+
41
+ resnet = models.resnet50(pretrained=True)
42
+
43
+ x = torch.randn((2, 3, 32, 32))
44
+
45
+ print(x.shape)
46
+ model_check = HiddenLayer(resnet)
47
+ model_check.load_state_dict(torch.load("CIFAR_end_hll.pt"))
48
+ model_check.eval()
49
+ print(model_check.forward(x).shape)
Preprocessing/cifar_eda.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Preprocessing/k_means.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn.cluster import KMeans
2
+ from sklearn.utils import shuffle
3
+ from sklearn.metrics import pairwise_distances_argmin
4
+ from PIL import Image
5
+ import numpy as np
6
+ import matplotlib.image as mpimg
7
+ import matplotlib.pyplot as plt
8
+ import random
9
+
10
+ def kmeans_image_compression(image, n_colors=8):
11
+
12
+ # Convert image to numpy array
13
+ image_np = np.array(image)
14
+
15
+ # Reshape the image to a 2D array of pixels
16
+ image_reshaped = image_np.reshape(-1, 3)
17
+ # print(image_reshaped.shape)
18
+
19
+ # Shuffle the pixels
20
+ image_reshaped_sample = shuffle(image_reshaped, random_state=0)[:image_reshaped.shape[0]]
21
+
22
+ # Apply KMeans clustering using sklearn Library
23
+ kmeans = KMeans(n_clusters=n_colors, random_state=0,n_init=10).fit(image_reshaped_sample)
24
+ labels = kmeans.predict(image_reshaped)
25
+ centroids = kmeans.cluster_centers_
26
+
27
+ # Replace each pixel with its corresponding centroid color
28
+ compressed_image = centroids[labels]
29
+
30
+ # Reshape the compressed image back to its original shape
31
+ compressed_image = compressed_image.reshape(image_np.shape)
32
+
33
+ # Convert the compressed image array to uint8 data type
34
+ compressed_image = compressed_image.astype(np.uint8)
35
+ return compressed_image
36
+
37
+
38
+ compressed_train_data =[]
39
+ for x,y in train_data:
40
+ compressed_x=kmeans_image_compression(x,6)
41
+ compressed_train_data.append((compressed_x,y))
42
+
43
+ compressed_test_data = []
44
+ for x,y in test_data:
45
+ compressed_x=kmeans_image_compression(x,6)
46
+ compressed_test_data.append((compressed_x,y))
Preprocessing/utils.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn.decomposition import PCA
2
+ import numpy as np
3
+
4
+
5
+ def pca_for_rgb(img, n_components):
6
+ red = img[:, :, 0]
7
+ green = img[:, :, 1]
8
+ blue = img[:, :, 2]
9
+
10
+ try:
11
+ pca = PCA(n_components=n_components)
12
+
13
+ red_transformed = pca.fit_transform(red)
14
+ red_inverted = pca.inverse_transform(red_transformed)
15
+
16
+ green_transformed = pca.fit_transform(green)
17
+ green_inverted = pca.inverse_transform(green_transformed)
18
+
19
+ blue_transformed = pca.fit_transform(blue)
20
+ blue_inverted = pca.inverse_transform(blue_transformed)
21
+ img_compressed = (
22
+ np.dstack(
23
+ (red_inverted,
24
+ green_inverted,
25
+ blue_inverted))).astype(
26
+ np.uint8)
27
+ return img_compressed
28
+ except BaseException:
29
+ print("Kindly put valid number of components")
README.md CHANGED
@@ -1,12 +1,44 @@
1
- ---
2
- title: Deep Image Search System
3
- emoji: 👀
4
- colorFrom: gray
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 4.37.2
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Deep_Image_Search_System
3
+ app_file: ui_gradio.py
4
+ sdk: gradio
5
+ sdk_version: 4.36.1
6
+ ---
7
+ # CSL2050 CourseProject
8
+
9
+ ## TOPIC : Image Retrieval
10
+
11
+ ## Team Members:
12
+ - Akshat Jain (B22CS096)
13
+ - Harshiv Shah (B22CS098)
14
+ - Mehta Jay Kamalkumar (B22CS034)
15
+ - Jyotin Goel (B22AI063)
16
+ - Rhythm Baghel (B22CS042)
17
+
18
+ ## Repository Structure
19
+
20
+ ```bash
21
+ ├── Experiments
22
+ │   ├── cifar_10_knn.ipynb
23
+ │   ├── cifar_10_pca_knn.ipynb
24
+ │   ├── cnn-resnet34-cifar10.ipynb
25
+ │   ├── cifar_10_mean.ipynb
26
+ │   ├── cifar_10_centroid.ipynb
27
+ │   ├── HOG+KNN.ipynb
28
+ │   ├── Logs
29
+ │   │   └── RESNET50_CIFAR.out
30
+ │   ├── PCA+HOG+KNN.ipynb
31
+ │   └── train.py
32
+ ├── LICENSE
33
+ ├── MidTerm_Report.pdf
34
+ ├── Model
35
+ │   ├── CIFAR.pt
36
+ │   └── test.py
37
+ ├── Gradio_Interface
38
+ │   ├── ui_gradio.py
39
+ ├── Preprocessing
40
+ │   ├── cifar_eda.ipynb
41
+ │   ├── k_means.py
42
+ │   └── utils.py
43
+ └── README.md
44
+ ```
__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ print("Check")
flagged/log.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ Query Image,Retrieved Images,flag,username,timestamp
2
+ flagged\Query Image\184c0788f490a0624d75\Screenshot 2024-04-20 153136.png,"[{""image"": ""flagged\\Retrieved Images\\6efd8d4fa439705c4bec\\image.webp"", ""caption"": null}, {""image"": ""flagged\\Retrieved Images\\911f2dbf15da7790c9f6\\image.webp"", ""caption"": null}, {""image"": ""flagged\\Retrieved Images\\d194d62692564b86e213\\image.webp"", ""caption"": null}]",,,2024-04-20 15:56:52.109627
images/akshat.jpeg ADDED
images/dog.jpg ADDED
images/harshiv.jpg ADDED
images/horse.jpg ADDED