Spaces:
Runtime error
Runtime error
| import pickle | |
| import random | |
| import numpy as np | |
| import cv2 | |
| from tensorflow.keras import models, layers | |
# Build the model architecture that the saved weights were trained with.
def create_resnet18():
    """Return the (uncompiled) Keras CNN used for embedding/classification.

    Note: despite its name this is not a ResNet-18 -- there are no residual
    connections. It is a plain sequential stack of 3x3, 'same'-padded ReLU
    convolutions grouped into four stages, each stage followed by 2x2
    max-pooling, then Flatten -> Dense(512, relu) -> Dense(10, softmax).
    Inputs are 32x32x3. The layer sequence must stay identical to the one
    used when the weights file was written, otherwise the saved weights will
    not load correctly.
    """
    # Filter counts of the convolutions in each stage; every stage ends in a pool.
    conv_stages = [(64,), (64,), (128, 128), (256, 256)]
    model = models.Sequential()
    first_layer = True
    for stage in conv_stages:
        for n_filters in stage:
            if first_layer:
                # Only the very first layer declares the input shape.
                model.add(layers.Conv2D(n_filters, (3, 3), activation='relu',
                                        padding='same', input_shape=(32, 32, 3)))
                first_layer = False
            else:
                model.add(layers.Conv2D(n_filters, (3, 3), activation='relu',
                                        padding='same'))
        model.add(layers.MaxPooling2D((2, 2)))
    # Classifier head.
    model.add(layers.Flatten())
    model.add(layers.Dense(512, activation='relu'))
    model.add(layers.Dense(10, activation='softmax'))
    return model
# Restore previously saved weights into a freshly built model.
def load_pretrained_weights(model, weights_path):
    """Load the weights stored at *weights_path* into *model* in place.

    Args:
        model: A model exposing ``load_weights`` (here, a Keras model whose
            architecture matches the saved file).
        weights_path: Path to the saved weights, e.g. an ``.h5`` file.

    Returns:
        None; *model* is modified in place.
    """
    restore = model.load_weights
    restore(weights_path)
    return None
# Read one pickled file (e.g. a CIFAR-10 python batch) from disk.
def unpickle(file):
    """Load and return the object pickled in *file*.

    ``encoding='bytes'`` makes strings pickled by Python 2 decode as
    ``bytes``. The CIFAR-10 python batches are such pickles, which is why the
    rest of this script looks keys up as ``b'data'`` / ``b'labels'``.

    Security: ``pickle.load`` can execute arbitrary code -- only call this on
    trusted files.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The unpickled object (a dict for CIFAR batches).
    """
    with open(file, 'rb') as fo:
        # Named ``batch`` rather than ``dict`` to avoid shadowing the builtin.
        batch = pickle.load(fo, encoding='bytes')
    return batch
# Collect every training image carrying a given label from unpickled batches.
def load_class_images(class_index, train_batches):
    """Return all images labelled *class_index*, in batch/row order.

    Each batch is expected to map ``b'data'`` to a 2-D array with one
    flattened image per row and ``b'labels'`` to the matching labels.
    A row is reshaped channel-first to (3, 32, 32) and transposed to
    (32, 32, 3), so channels keep their stored order (RGB for the CIFAR-10
    python batches). Batches missing either key are skipped silently.

    Args:
        class_index: Label value to select.
        train_batches: Iterable of unpickled batch dicts.

    Returns:
        A list of (32, 32, 3) arrays (views into the batch data); empty if
        no image matches.
    """
    collected = []
    for batch in train_batches:
        # Skip anything that doesn't look like a data batch.
        if b'data' not in batch or b'labels' not in batch:
            continue
        pixels, batch_labels = batch[b'data'], batch[b'labels']
        collected.extend(
            pixels[idx].reshape(3, 32, 32).transpose(1, 2, 0)
            for idx, lbl in enumerate(batch_labels)
            if lbl == class_index
        )
    return collected
# Nearest-centroid classification: pick the class whose mean embedding is
# closest (Euclidean distance) to the query image's embedding.
def classify_query(query_image, model, mean_embeddings):
    """Return the index of the mean embedding nearest to the query's embedding.

    Args:
        query_image: A single image array without a batch axis; a batch axis
            of size 1 is added before calling ``model.predict``.
        model: Object with a Keras-style ``predict`` method.
        mean_embeddings: Sequence of per-class mean embeddings; each is
            compared with the flattened model output.

    Returns:
        The position (as returned by ``np.argmin``) of the closest mean
        embedding; ties resolve to the lowest index.
    """
    batch = np.expand_dims(query_image, axis=0)
    flat_embedding = model.predict(batch).flatten()
    distances = []
    for centroid in mean_embeddings:
        distances.append(np.linalg.norm(flat_embedding - centroid))
    return np.argmin(distances)
# cv2_imshow only exists inside Google Colab; anywhere else an unconditional
# import raises ImportError at module load and aborts the whole script.
try:
    from google.colab.patches import cv2_imshow
except ImportError:
    def cv2_imshow(img):
        """Fallback for Colab's ``cv2_imshow``: show *img* in an OpenCV window.

        Waits for a key press, then closes the window. Needs a GUI-capable
        OpenCV build; presumably fails on a headless machine.
        """
        cv2.imshow('image', img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
def main():
    """Classify one query image and show a few training images of its class.

    Steps:
      1. Rebuild the CNN and load its saved weights.
      2. Load the per-class mean embeddings from ``mean_embeddings.pkl``.
      3. Read the query image, resize it to 32x32, scale it to [0, 1] and
         assign it to the class with the nearest mean embedding.
      4. Load the five training batches and display up to 3 random images
         labelled with the predicted class.

    Raises:
        FileNotFoundError: if the query image cannot be read.
    """
    model = create_resnet18()
    load_pretrained_weights(model, 'pretrained_model_weights.h5')

    # Use a context manager so the file handle is closed deterministically.
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open('mean_embeddings.pkl', 'rb') as fh:
        mean_embeddings = pickle.load(fh)

    query_image_path = '/content/airplane_8925.png'
    query_image = cv2.imread(query_image_path)
    # cv2.imread signals failure by returning None instead of raising.
    if query_image is None:
        raise FileNotFoundError(f"Could not read query image: {query_image_path}")
    # NOTE(review): cv2.imread yields BGR while the CIFAR batches decode to RGB;
    # confirm which channel order the model / mean embeddings were built with.
    query_image = cv2.resize(query_image, (32, 32)) / 255.0  # Resize and normalize the image

    predicted_class = classify_query(query_image, model, mean_embeddings)
    print("Predicted Class:", predicted_class)

    # Load random images of the predicted class.
    train_batches = [unpickle(f"/content/data_batch_{i}") for i in range(1, 6)]
    # np.argmin over mean_embeddings is already 0-based, like CIFAR labels, so
    # no +1 offset (assumes mean_embeddings[i] belongs to label i).
    class_images = load_class_images(predicted_class, train_batches)
    if not class_images:
        print("No images found for the predicted class.")
        return

    # random.sample raises ValueError if asked for more items than exist.
    random_images = random.sample(class_images, min(3, len(class_images)))
    for img in random_images:
        # Batch images are RGB; cv2_imshow expects OpenCV's BGR order.
        cv2_imshow(np.ascontiguousarray(img[:, :, ::-1]))
# Run the demo only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()