import os

import numpy as np
import torch
import torchvision
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from torchvision import transforms

# File extensions treated as images (matched case-insensitively).
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

# Maximum number of nearest neighbours reported per query image.
TOP_K = 10

# Load pre-trained ResNet-50 model.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of `weights=torchvision.models.ResNet50_Weights.DEFAULT`; kept here so
# the script still runs on older torchvision installs.
resnet50 = torchvision.models.resnet50(pretrained=True)
resnet50.eval()  # Set model to evaluation mode

# Every child module except the final `fc` classifier. Its output is the pooled
# activation that feeds `fc` -- the penultimate-layer embedding -- rather than
# the 1000-way class logits that `resnet50.forward` returns. `resnet50` itself
# is not modified, so any other code using it still gets logits.
feature_extractor = torch.nn.Sequential(*list(resnet50.children())[:-1])
feature_extractor.eval()

# Define image transformation
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def is_image_file(filename):
    """Return True if *filename* ends with one of IMAGE_EXTENSIONS (any case)."""
    return filename.lower().endswith(IMAGE_EXTENSIONS)


def extract_features(image_path):
    """Return the penultimate-layer ResNet-50 embedding of the image at *image_path*.

    The image is converted to RGB first so that grayscale, palette or RGBA
    files yield the 3-channel tensor expected by ``transforms.Normalize`` and
    the network. The file handle is closed before inference.

    Returns:
        A 1-D numpy array (the flattened output of ``feature_extractor``).
    """
    with Image.open(image_path) as img:
        tensor = transform(img.convert("RGB"))
    batch = torch.unsqueeze(tensor, 0)  # Add batch dimension
    with torch.no_grad():  # Inference only: skip building an autograd graph
        features = feature_extractor(batch)
    return torch.flatten(features).numpy()


def most_similar(query_filename, features_by_name, top_k=TOP_K):
    """Return up to *top_k* ``(filename, score)`` pairs most similar to the query.

    Similarity is cosine similarity between feature vectors. The query's own
    entry is excluded. Results are ordered by descending score; ties keep the
    iteration order of *features_by_name*. Fewer than *top_k* pairs are
    returned when there are not enough other images (an empty list if none).

    Args:
        query_filename: Key of the query image in *features_by_name*.
        features_by_name: Mapping of filename -> 1-D feature vector.
        top_k: Maximum number of neighbours to return.
    """
    others = [name for name in features_by_name if name != query_filename]
    if not others:
        return []
    query = features_by_name[query_filename].reshape(1, -1)
    candidates = np.stack([features_by_name[name] for name in others])
    # One vectorized call: row 0 holds the query's similarity to every candidate.
    scores = cosine_similarity(query, candidates)[0]
    # Stable sort on negated scores = descending order with ties kept in order.
    order = np.argsort(-scores, kind="stable")[:top_k]
    return [(others[i], float(scores[i])) for i in order]


# Path to directory containing images
images_dir = "streamlit/images"

# Dictionary mapping image filename -> feature vector. Filenames are sorted so
# the processing order (and tie-breaking in the results) is reproducible;
# os.listdir returns entries in arbitrary order.
feature_dict = {}
for filename in sorted(os.listdir(images_dir)):
    if is_image_file(filename):
        feature_dict[filename] = extract_features(os.path.join(images_dir, filename))

# Perform nearest neighbor search for each query image, reusing the features
# computed above instead of re-running the model on every query.
for query_filename in feature_dict:
    print("Query image:", query_filename)
    for similar_filename, score in most_similar(query_filename, feature_dict):
        print("Similar image:", similar_filename, "Similarity score:", score)

# Streamlit App
# Add Streamlit app code here