Deep_Image_Search_System / Model /centroid_app.py
AJain1234's picture
Update Model/centroid_app.py
26e20fb verified
import pickle
import random
import numpy as np
import cv2
from tensorflow.keras import models, layers
# Build the model architecture (weights are loaded separately below)
def create_resnet18():
    """Build the (untrained) convolutional classifier used by this module.

    Despite the name, the network is a plain stack of layers inside a
    ``Sequential`` container; it has no residual/skip connections.

    Layout, in order:
      * 3x3 conv (64 filters) -> 2x2 max-pool, twice
      * two 3x3 convs (128 filters) -> 2x2 max-pool
      * two 3x3 convs (256 filters) -> 2x2 max-pool
      * flatten -> dense(512, relu) -> dense(10, softmax)

    The first layer declares an input shape of ``(32, 32, 3)``.

    Returns:
        A Keras ``Sequential`` model with freshly initialised weights.
    """
    def conv(filters, **extra):
        # Every convolution shares kernel size, activation and padding.
        return layers.Conv2D(filters, (3, 3), activation='relu', padding='same', **extra)

    def pool():
        return layers.MaxPooling2D((2, 2))

    # Layers are instantiated in the same order as they are stacked, so a
    # weights file saved for this topology lines up layer by layer.
    return models.Sequential([
        conv(64, input_shape=(32, 32, 3)),
        pool(),
        conv(64),
        pool(),
        conv(128),
        conv(128),
        pool(),
        conv(256),
        conv(256),
        pool(),
        layers.Flatten(),
        layers.Dense(512, activation='relu'),
        layers.Dense(10, activation='softmax'),
    ])
# Load the pretrained weights
def load_pretrained_weights(model, weights_path):
    """Load saved weights into ``model`` in place.

    Args:
        model: Object exposing a ``load_weights(path)`` method.
        weights_path: Location of the weights file, forwarded unchanged to
            ``model.load_weights``.

    Returns:
        None. Whatever ``load_weights`` returns is discarded; the model is
        updated as a side effect.
    """
    model.load_weights(weights_path)
# Function to unpickle a file
def unpickle(file):
    """Deserialize and return the object stored in a pickle file.

    The file is opened in binary mode and decoded with ``encoding='bytes'``,
    so 8-bit strings written by Python 2 come back as ``bytes`` objects.
    The rest of this module relies on that when it looks up ``b'data'`` and
    ``b'labels'`` keys.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    # NOTE(security): pickle can execute arbitrary code while loading; only
    # call this on files from a trusted source.
    with open(file, 'rb') as fo:
        # Local name deliberately avoids shadowing the builtin ``dict``.
        batch = pickle.load(fo, encoding='bytes')
    return batch
# Function to load images from the unpickled data batch file of a specific class
def load_class_images(class_index, train_batches):
    """Collect every image labelled ``class_index`` from the given batches.

    Each batch is expected to be a mapping with ``b'data'`` (indexable rows,
    one flat image per row) and ``b'labels'`` (one label per row). Batches
    missing either key are skipped silently.

    Each matching row is reshaped to ``(3, 32, 32)`` and transposed to
    ``(32, 32, 3)``, so a row must hold exactly 3 * 32 * 32 values and
    support ``reshape`` / ``transpose`` (e.g. a NumPy array).

    Args:
        class_index: Label value to select.
        train_batches: Iterable of batch mappings as described above.

    Returns:
        A list of ``(32, 32, 3)`` images in batch order, then row order.
        Empty when nothing matches.
    """
    matches = []
    for batch in train_batches:
        # Guard clause: ignore anything that is not a complete data batch.
        if b'data' not in batch or b'labels' not in batch:
            continue
        rows = batch[b'data']
        matches.extend(
            # channel-first flat row -> height x width x channel image
            rows[idx].reshape(3, 32, 32).transpose(1, 2, 0)
            for idx, lbl in enumerate(batch[b'labels'])
            if lbl == class_index
        )
    return matches
# Function to calculate the distance of the mean embeddings with a query image
def classify_query(query_image, model, mean_embeddings):
    """Pick the class whose mean embedding is closest to the query's.

    The query is wrapped in a leading batch axis of size one and passed to
    ``model.predict``. The flattened output is compared to each entry of
    ``mean_embeddings`` by Euclidean (L2) distance.

    Args:
        query_image: Array-like image already preprocessed for ``model``.
        model: Object exposing ``predict(batch)`` that returns an array.
        mean_embeddings: Iterable of per-class mean vectors, each
            subtractable from the flattened model output. The position in
            the iteration order is used as the class index.

    Returns:
        The index (as returned by ``np.argmin``) of the nearest mean
        embedding; the first one wins on ties.
    """
    single_batch = np.expand_dims(query_image, axis=0)
    # Flatten once; the query vector is the same for every comparison.
    query_vector = model.predict(single_batch).flatten()

    distances = []
    for centroid in mean_embeddings:
        distances.append(np.linalg.norm(query_vector - centroid))
    return np.argmin(distances)
def retrieve(query_image, k=3):
    """Return ``k`` random training images from the query's predicted class.

    Steps:
      1. Build the model and load weights from
         ``Model/pretrained_model_weights.h5``.
      2. Load the per-class mean embeddings from
         ``Model/data/mean_embeddings.pkl``.
      3. Resize the query to 32x32, scale it by 1/255, and classify it by
         nearest mean embedding.
      4. Read ``Model/data/data_batch_1`` .. ``data_batch_5`` and sample
         ``k`` images of the predicted class without replacement.

    All paths are relative to the current working directory.

    Args:
        query_image: Image array accepted by ``cv2.resize``.
            NOTE(review): presumably its channel order should match the
            data the model was trained on -- confirm with the caller.
        k: Number of images to return (default 3).

    Returns:
        A list of ``k`` images of shape ``(32, 32, 3)``, or ``None`` (after
        printing a message) when the predicted class has no images.

    Raises:
        ValueError: From ``random.sample`` if ``k`` is negative or larger
            than the number of images available for the predicted class.
    """
    # NOTE(review): the model, its weights and all five batches are reloaded
    # on every call; consider caching them if this is called repeatedly.
    model = create_resnet18()
    load_pretrained_weights(model, 'Model/pretrained_model_weights.h5')

    # Use a context manager so the file handle is closed promptly.
    # NOTE(security): pickle can execute arbitrary code while loading; this
    # file must come from a trusted source.
    with open('Model/data/mean_embeddings.pkl', 'rb') as fh:
        mean_embeddings = pickle.load(fh)

    # Resize to the network's input size and scale pixel values by 1/255.
    query_image = cv2.resize(query_image, (32, 32)) / 255.0
    predicted_class = classify_query(query_image, model, mean_embeddings)

    train_batches = [unpickle(f"Model/data/data_batch_{i}") for i in range(1, 6)]
    class_images = load_class_images(predicted_class, train_batches)
    if not class_images:
        print("No images found for the predicted class.")
        return None

    # Draw k distinct images at random from the predicted class.
    return random.sample(class_images, k)