| import torch |
| import torchvision |
| from torchvision import transforms |
| from PIL import Image |
| import os |
| import numpy as np |
| from sklearn.metrics.pairwise import cosine_similarity |
|
|
| |
# ResNet-50 with ImageNet-1k weights, used below as a fixed image embedder.
# torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the ``weights``
# enum; ``IMAGENET1K_V1`` is the checkpoint that ``pretrained=True`` maps to, so
# both branches build the same network. Older torchvision has no ResNet50_Weights.
try:
    from torchvision.models import ResNet50_Weights
except ImportError:  # torchvision < 0.13: only the legacy keyword exists
    resnet50 = torchvision.models.resnet50(pretrained=True)
else:
    resnet50 = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

# Inference mode for layers whose behaviour differs between training and
# evaluation (e.g. BatchNorm uses its stored running statistics).
resnet50.eval()
|
|
| |
# Per-channel mean / std applied after ToTensor(); these are the ImageNet
# statistics that torchvision's pretrained classification models expect.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing for a PIL image: scale the shorter side to 256 px, keep the
# central 224x224 crop, convert to a float tensor, then normalise per channel.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
|
|
| |
# Folder whose image files are indexed and compared; being a relative path,
# it is resolved against the current working directory.
images_dir = "streamlit/images"

# Image file name -> NumPy vector produced by the ResNet-50 forward pass.
# Populated by the indexing loop that follows.
feature_dict = dict()
|
|
| |
# File suffixes treated as images; matched case-insensitively so e.g.
# "photo.JPG" or "scan.jpeg" are indexed too.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

# Run every image in images_dir through the network once and store the
# resulting vector in feature_dict, keyed by file name. Sorting the listing
# makes the index (and the report printed later) deterministic.
with torch.no_grad():  # pure inference: don't build an autograd graph
    for filename in sorted(os.listdir(images_dir)):
        if not filename.lower().endswith(IMAGE_EXTENSIONS):
            continue

        image_path = os.path.join(images_dir, filename)
        # convert("RGB") maps RGBA / palette / grayscale files to the three
        # channels that Normalize and the network expect; the with-block
        # closes the underlying file handle.
        with Image.open(image_path) as img:
            image = transform(img.convert("RGB"))

        batch = torch.unsqueeze(image, 0)  # add batch dim -> (1, 3, 224, 224)
        # Calling the module (not .forward) keeps any registered hooks active.
        features = torch.squeeze(resnet50(batch))  # drop the batch dim again
        feature_dict[filename] = features.numpy()
|
|
| |
# Number of most-similar images reported for each query image.
TOP_K = 10

# For every indexed image, rank all *other* indexed images by the cosine
# similarity of their feature vectors and print the TOP_K best matches.
# The vectors already live in feature_dict, so they are reused here rather
# than re-running the network on each query image.
indexed_names = list(feature_dict)
if indexed_names:
    # One row per image, in feature_dict's insertion order.
    feature_matrix = np.stack([feature_dict[name].ravel() for name in indexed_names])

    for query_index, query_filename in enumerate(indexed_names):
        # Similarity of this query against every indexed image, as a 1-D array
        # (one row at a time keeps memory linear in the number of images).
        scores = cosine_similarity(
            feature_matrix[query_index : query_index + 1], feature_matrix
        )[0]

        similarities = [
            (filename, float(scores[other_index]))
            for other_index, filename in enumerate(indexed_names)
            if other_index != query_index  # never match an image with itself
        ]
        # Stable sort: images with equal scores keep feature_dict order.
        similarities.sort(key=lambda pair: pair[1], reverse=True)

        print("Query image:", query_filename)
        # Slicing tolerates folders with fewer than TOP_K + 1 images, where
        # indexing range(10) would raise IndexError.
        for filename, score in similarities[:TOP_K]:
            print("Similar image:", filename, "Similarity score:", score)
|
|
| |
| |
|
|