"""Classify a query image by nearest mean embedding and show CIFAR-10 samples.

The script rebuilds a small CNN, loads its weights, compares the model output
for one query image against a pickled collection of mean embeddings, and then
displays a few random CIFAR-10 training images of the predicted class.
"""

import pickle
import random

import cv2
import numpy as np
from google.colab.patches import cv2_imshow
from tensorflow.keras import layers, models


def create_resnet18():
    """Build the network architecture that the saved weights belong to.

    Despite the name, this is a plain sequential CNN (stacked 3x3
    convolutions and max-pooling, no residual connections) taking
    32x32x3 inputs and ending in a 10-way softmax layer.

    Returns:
        An uncompiled ``models.Sequential`` instance with untrained weights.
    """
    model = models.Sequential()
    model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same',
                            input_shape=(32, 32, 3)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
    model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))
    model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Flatten())
    model.add(layers.Dense(512, activation='relu'))
    model.add(layers.Dense(10, activation='softmax'))
    return model


def load_pretrained_weights(model, weights_path):
    """Load weights from ``weights_path`` into ``model`` in place."""
    model.load_weights(weights_path)


def unpickle(file):
    """Return the object stored in the pickle file at path ``file``.

    ``encoding='bytes'`` leaves Python 2 byte strings undecoded, so string
    keys of the loaded object come back as ``bytes`` (e.g. ``b'data'``).

    NOTE: unpickling can execute arbitrary code; only use on trusted files.
    """
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch


def load_class_images(class_index, train_batches):
    """Collect every image labelled ``class_index`` from the given batches.

    Args:
        class_index: Label value to select.
        train_batches: Iterable of batch mappings holding ``b'data'`` (rows
            of 3072 values, channel-major) and ``b'labels'``. Batches
            missing either key are skipped silently.

    Returns:
        A list of arrays of shape (32, 32, 3) in channel-last layout, with
        channels in the order stored in the batch.
    """
    images = []
    for batch in train_batches:
        if b'data' not in batch or b'labels' not in batch:
            continue
        data = batch[b'data']
        labels = batch[b'labels']
        for i, label in enumerate(labels):
            if label == class_index:
                # Each row is stored channel-major (3, 32, 32); move the
                # channel axis last to get a (32, 32, 3) image.
                images.append(data[i].reshape(3, 32, 32).transpose(1, 2, 0))
    return images


def classify_query(query_image, model, mean_embeddings):
    """Return the index of the mean embedding closest to the query image.

    The query is passed through ``model.predict`` as a batch of one; with the
    model built by ``create_resnet18`` the resulting "embedding" is the
    10-way softmax output. The Euclidean distance to every entry of
    ``mean_embeddings`` is computed and the position of the smallest one
    is returned.

    Args:
        query_image: Single image array, without the batch dimension.
        model: Object exposing ``predict``.
        mean_embeddings: Iterable of reference vectors, each compatible with
            the flattened model output.

    Returns:
        The position (NumPy integer) of the nearest mean embedding.
    """
    query_embedding = model.predict(np.expand_dims(query_image, axis=0))
    query_vector = query_embedding.flatten()
    distances = [np.linalg.norm(query_vector - mean_embedding)
                 for mean_embedding in mean_embeddings]
    predicted_class = np.argmin(distances)
    return predicted_class


def main():
    """Classify one query image and display samples of the predicted class.

    Raises:
        FileNotFoundError: If the query image cannot be read.
    """
    model = create_resnet18()
    load_pretrained_weights(model, 'pretrained_model_weights.h5')

    # Close the file deterministically instead of leaking the handle.
    with open('mean_embeddings.pkl', 'rb') as fo:
        mean_embeddings = pickle.load(fo)

    query_image_path = '/content/airplane_8925.png'
    query_image = cv2.imread(query_image_path)
    if query_image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(
            f"Could not read query image: {query_image_path}")
    # NOTE(review): cv2.imread yields BGR channel order; if the model was
    # trained on RGB data the channels should be reversed here - confirm
    # against the training pipeline.
    query_image = cv2.resize(query_image, (32, 32)) / 255.0  # Resize and normalize the image

    predicted_class = classify_query(query_image, model, mean_embeddings)
    print("Predicted Class:", predicted_class)

    # Load random images of the predicted class
    train_batches = [unpickle(f"/content/data_batch_{i}") for i in range(1, 6)]
    # CIFAR-10 labels are 0-based, like the argmin index, so no offset is
    # applied (assumes mean_embeddings is ordered by class index).
    class_images = load_class_images(predicted_class, train_batches)
    if not class_images:
        print("No images found for the predicted class.")
        return

    # random.sample raises ValueError when asked for more items than exist.
    sample_size = min(3, len(class_images))
    for img in random.sample(class_images, sample_size):
        # CIFAR-10 stores channels as R, G, B while cv2_imshow expects BGR,
        # so reverse the channel axis before display.
        cv2_imshow(np.ascontiguousarray(img[:, :, ::-1]))


if __name__ == "__main__":
    main()