Spaces:
Runtime error
Runtime error
# Find visually similar images with ResNet-50 embeddings.
#
# Every .jpg/.jpeg/.png file in IMAGES_DIR is embedded with an
# ImageNet-pretrained ResNet-50 whose classification layer has been removed,
# a nearest-neighbour index is built over the embeddings, and for each image
# its closest other images are printed.
import os

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from sklearn.neighbors import NearestNeighbors
from torchvision.models import ResNet50_Weights, resnet50

# Directory containing images
IMAGES_DIR = "picture/"
# Neighbours requested per query. The query image itself is normally one of
# them, so at most N_NEIGHBORS - 1 matches are printed per image.
N_NEIGHBORS = 10
# File extensions treated as images (matched case-insensitively).
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

# Define image transformation (ImageNet mean/std normalisation, the
# statistics torchvision's ImageNet-pretrained weights expect).
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def load_feature_extractor():
    """Return an ImageNet-pretrained ResNet-50 that outputs pooled features.

    Uses torchvision's multi-weight API (torchvision >= 0.13); the weights
    are downloaded on first use.
    """
    model = resnet50(weights=ResNet50_Weights.DEFAULT)
    # Replace the classification head with a pass-through so the forward pass
    # returns the penultimate (pooled) feature vector instead of class logits.
    model.fc = torch.nn.Identity()
    model.eval()
    return model


def list_image_files(images_dir):
    """Return the sorted names of image files directly inside *images_dir*.

    Only regular files whose extension is in IMAGE_EXTENSIONS (any case) are
    kept. Sorting makes the output order deterministic across runs.
    """
    return sorted(
        name
        for name in os.listdir(images_dir)
        if name.lower().endswith(IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(images_dir, name))
    )


def extract_features(model, images_dir, image_files):
    """Embed each listed image with *model*.

    Returns ``(names, features)`` where ``features[i]`` is the embedding of
    ``names[i]``. Files Pillow cannot open or decode are reported and
    skipped, so *names* may be shorter than *image_files*; when nothing could
    be embedded, *names* is empty.
    """
    names, vectors = [], []
    with torch.no_grad():
        for filename in image_files:
            image_path = os.path.join(images_dir, filename)
            try:
                with Image.open(image_path) as img:
                    # Force 3 channels: RGBA/palette PNGs and grayscale images
                    # would otherwise not match the 3-channel Normalize step.
                    batch = transform(img.convert("RGB")).unsqueeze(0)
            except OSError as exc:
                print(f"Skipping {filename}: {exc}")
                continue
            # Drop the batch dimension of size 1 added by unsqueeze(0).
            vectors.append(model(batch).squeeze(0).numpy())
            names.append(filename)
    features = np.stack(vectors) if vectors else np.empty((0, 0), dtype=np.float32)
    return names, features


def print_similar_images(names, features, n_neighbors=N_NEIGHBORS):
    """Print, for every image, its closest other images.

    Distances use scikit-learn's default NearestNeighbors metric
    (Euclidean). *features* rows must align with *names*.
    """
    # NearestNeighbors rejects n_neighbors > n_samples, so cap it for small
    # collections.
    n_neighbors = min(n_neighbors, len(names))
    nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm="auto").fit(features)
    # When X is passed explicitly, each fitted point is a candidate neighbour
    # of itself (at distance 0); it is filtered out below.
    _, indices = nbrs.kneighbors(features)
    for query_idx, (query_filename, row) in enumerate(zip(names, indices)):
        # Exclude the query by index rather than assuming it ranks first:
        # exact duplicates tie at distance 0 and may be returned in any order.
        matches = [names[idx] for idx in row if idx != query_idx][: n_neighbors - 1]
        print("Query image:", query_filename)
        print("Most similar images:")
        for match in matches:
            print(match)
        print("-----")


def main():
    """Run the similarity search over IMAGES_DIR and print the results."""
    if not os.path.isdir(IMAGES_DIR):
        print("Image directory not found:", IMAGES_DIR)
        return
    image_files = list_image_files(IMAGES_DIR)
    if not image_files:
        print("No images found in directory")
        return
    model = load_feature_extractor()
    names, features = extract_features(model, IMAGES_DIR, image_files)
    if not names:
        print("No feature vectors extracted")
        return
    print_similar_images(names, features)


if __name__ == "__main__":
    main()