import os

import numpy as np
import torch
import torchvision
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from torchvision import transforms

# Load ResNet-50 with the ImageNet weights that the deprecated ``pretrained=True``
# flag selects (IMAGENET1K_V1), using the weights-enum API where it exists.
try:
    _resnet50_weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1
except AttributeError:  # torchvision < 0.13 has no weights enums
    resnet50 = torchvision.models.resnet50(pretrained=True)
else:
    resnet50 = torchvision.models.resnet50(weights=_resnet50_weights)

# The feature-extraction code below wants penultimate-layer features, not the
# 1000 ImageNet class logits produced by the final ``fc`` layer. Swapping ``fc``
# for an identity makes the model return the pooled feature vector instead.
resnet50.fc = torch.nn.Identity()
resnet50.eval()  # Set model to evaluation mode (inference only)
# Preprocessing pipeline applied to every PIL image before it is fed to ``resnet50``:
# force RGB, resize, center-crop to 224x224, convert to a tensor and normalize
# with the ImageNet per-channel mean/std.
transform = transforms.Compose([
    # Palette, grayscale or RGBA images would otherwise become 1- or 4-channel
    # tensors that do not match the 3-value mean/std in Normalize below.
    transforms.Lambda(lambda img: img.convert("RGB")),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Path to directory containing images
images_dir = "streamlit/images"

# Filename suffixes treated as images (matched case-insensitively).
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

# Maps image filename -> feature vector (NumPy array) produced by ``resnet50``.
feature_dict = {}

# Pure inference: disabling autograd avoids building a gradient graph per image.
with torch.no_grad():
    # Sorted so the index (and the printed results that follow) is deterministic.
    for filename in sorted(os.listdir(images_dir)):
        if not filename.lower().endswith(IMAGE_EXTENSIONS):
            continue
        image_path = os.path.join(images_dir, filename)
        # The context manager releases the file handle. The RGB conversion keeps
        # non-RGB images from breaking the 3-channel normalization.
        with Image.open(image_path) as img:
            batch = transform(img.convert("RGB")).unsqueeze(0)  # add batch dimension
        # Call the module rather than .forward() so any registered hooks run.
        features = resnet50(batch).squeeze(0)  # remove batch dimension
        feature_dict[filename] = features.numpy()
# Nearest-neighbour search for each indexed image. Query vectors are taken from
# feature_dict, so no image is loaded or run through the model a second time.
TOP_K = 10  # maximum number of similar images reported per query

indexed_names = list(feature_dict)
if indexed_names:  # np.stack rejects an empty list; nothing to search anyway
    # One row per image. Every vector comes from the same model, so rows share a length.
    feature_matrix = np.stack([feature_dict[name] for name in indexed_names])
    for query_index, query_filename in enumerate(indexed_names):
        # Cosine similarity of this query against every indexed image (one row).
        scores = cosine_similarity(
            feature_matrix[query_index:query_index + 1], feature_matrix
        )[0]
        # Highest score first. A stable sort keeps ties in index order.
        # The query image itself is skipped.
        ranked = [
            (indexed_names[i], float(scores[i]))
            for i in np.argsort(-scores, kind="stable")
            if i != query_index
        ]
        print("Query image:", query_filename)
        # Slicing (not range(TOP_K)) avoids IndexError with fewer than TOP_K other images.
        for filename, similarity in ranked[:TOP_K]:
            print("Similar image:", filename, "Similarity score:", similarity)
# Streamlit app
# TODO: add the Streamlit app code here.