In [1]:
from scipy.spatial.distance import euclidean
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import numpy as np


(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()


train_images = train_images / 255.0
test_images = test_images / 255.0


def create_resnet18():
 model = models.Sequential()
 model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=(32, 32, 3)))
 model.add(layers.MaxPooling2D((2, 2)))
 model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))
 model.add(layers.MaxPooling2D((2, 2)))
 model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
 model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
 model.add(layers.MaxPooling2D((2, 2)))
 model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))
 model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))
 model.add(layers.MaxPooling2D((2, 2)))
 model.add(layers.Flatten())
 model.add(layers.Dense(512, activation='relu'))
 model.add(layers.Dense(10, activation='softmax'))
 return model


model = create_resnet18()


model.compile(optimizer='adam',
 loss='sparse_categorical_crossentropy',
 metrics=['accuracy'])


model.fit(train_images, train_labels, epochs=10, batch_size=64, validation_data=(test_images, test_labels))


train_embeddings = model.predict(train_images)

num_classes = 10
mean_embeddings = np.zeros((num_classes, train_embeddings.shape[1]))


for class_label in range(num_classes):
 class_indices = np.where(train_labels.flatten() == class_label)[0]
 class_embeddings = train_embeddings[class_indices]
 mean_embedding = np.mean(class_embeddings, axis=0)
 mean_embeddings[class_label] = mean_embedding


def classify_query(query_image):
 query_embedding = model.predict(np.expand_dims(query_image, axis=0))
 distances = [euclidean(query_embedding.flatten(), mean_embedding) for mean_embedding in mean_embeddings]
 predicted_class = np.argmin(distances)
 return predicted_class

mean_embeddings


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


array([[8.79073262e-01, 2.84781260e-03, 3.16492580e-02, 9.60051734e-03,
 8.88401456e-03, 4.04336443e-03, 1.89202465e-03, 2.60451157e-03,
 4.94241416e-02, 9.98295005e-03],
 [5.34512708e-03, 9.61387515e-01, 1.82328641e-03, 1.49091403e-03,
 3.62021121e-04, 2.36713630e-03, 9.99960001e-04, 3.58641875e-04,
 8.00825842e-03, 1.78600885e-02],
 [1.93752628e-02, 5.49300632e-04, 8.71673167e-01, 3.62917222e-02,
 2.81240512e-02, 1.94428060e-02, 1.17484620e-02, 7.90499710e-03,
 3.91457370e-03, 9.77586606e-04],
 [3.43155907e-03, 3.87360866e-04, 3.25083360e-02, 8.04273427e-01,
 2.26426423e-02, 1.00013100e-01, 1.48316240e-02, 1.55726429e-02,
 3.46732978e-03, 2.87183723e-03],
 [4.50019445e-03, 3.07688024e-04, 3.81155163e-02, 3.81957851e-02,
 8.43278885e-01, 2.98470370e-02, 1.07433125e-02, 3.15311365e-02,
 2.45781220e-03, 1.02343888e-03],
 [8.63184629e-04, 2.92542420e-04, 2.77942196e-02, 1.26914889e-01,
 1.68064609e-02, 7.96062946e-01, 6.89344807e-03, 2.19228752e-02,
 1.45196577e-03, 9.96942166e-04],
 [2.

In [2]:
# Calculate accuracy
correct_predictions = 0
total_predictions = 500

for i in range(500):
 predicted_class = classify_query(test_images[i])
 if predicted_class == test_labels[i]:
 correct_predictions += 1

accuracy = correct_predictions / total_predictions
print("Accuracy:", accuracy)


Accuracy: 0.744
