# NOTE(review): the original paste carried Hugging Face Spaces build-log residue
# here ("Spaces:" / "Build error" x2); converted to a comment so the file parses.
import os
import pickle

import faiss
import numpy as np
import streamlit as st
import torch
import torch.nn.functional as F
from PIL import Image
from sklearn.neighbors import NearestNeighbors
from torchvision import transforms, models
# Deterministic preprocessing pipeline applied to every image before feature
# extraction: fixed 224x224 resize, tensor conversion, ImageNet normalization.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Random augmentation pipeline used when building the index: each call yields
# a flipped / rotated / color-jittered / re-cropped variant of a PIL image.
augment_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(20),
        transforms.ColorJitter(
            brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1
        ),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0), ratio=(0.75, 1.33)),
    ]
)
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| def load_model(): | |
| model = models.efficientnet_b0(pretrained=True) | |
| model.classifier = torch.nn.Identity() # Remove the final classification layer | |
| model = model.to(device) | |
| model.eval() | |
| return model | |
| model = load_model() | |
def extract_features(img):
    """Return an L2-normalized embedding for a single PIL image.

    The image is preprocessed with `transform`, run through the shared model
    without gradient tracking, normalized to unit length (so inner product
    equals cosine similarity), and returned as a 1-D numpy array on the CPU.
    """
    batch = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = model(batch)
        embedding = F.normalize(embedding, p=2, dim=1)
    return embedding.squeeze().cpu().numpy()
def generate_augmented_images(img, num_augmented=5):
    """Return `num_augmented` independently augmented variants of a PIL image.

    Each variant is produced by a fresh pass through `augment_transform`,
    so the random flip/rotation/jitter/crop differ between variants.
    """
    return [augment_transform(img) for _ in range(num_augmented)]
| # def load_and_index_images(root_dir): #without adding data augmented images | |
| # image_paths = [] | |
| # features_list = [] | |
| # categories = [] | |
| # for category in os.listdir(root_dir): | |
| # category_path = os.path.join(root_dir, category) | |
| # if os.path.isdir(category_path): | |
| # for img_name in os.listdir(category_path): | |
| # img_path = os.path.join(category_path, img_name) | |
| # img = Image.open(img_path).convert('RGB') | |
| # features = extract_features(img) | |
| # image_paths.append(img_path) | |
| # features_list.append(features) | |
| # categories.append(category) | |
| # features_array = np.array(features_list).astype('float32') | |
| # d = features_array.shape[1] # dimension | |
| # index = faiss.IndexFlatIP(d) # use inner product (cosine similarity on normalized vectors) | |
| # index.add(features_array) | |
| # return index, image_paths, categories | |
def load_and_index_images(root_dir):
    """Build a FAISS index over every image under root_dir/<category>/*.

    For each readable image, the original embedding plus the embeddings of
    several augmented variants are indexed; augmented rows reuse the original
    file path so search results always map back to a real file on disk.

    Args:
        root_dir: directory whose immediate subdirectories are category names.

    Returns:
        (index, image_paths, categories): a faiss.IndexFlatIP plus two lists
        parallel to the index rows.

    Raises:
        ValueError: if no readable image is found under root_dir.
    """
    image_paths = []
    features_list = []
    categories = []
    for category in os.listdir(root_dir):
        category_path = os.path.join(root_dir, category)
        if not os.path.isdir(category_path):
            continue
        for img_name in os.listdir(category_path):
            img_path = os.path.join(category_path, img_name)
            try:
                img = Image.open(img_path).convert('RGB')
            except (OSError, ValueError):
                # Skip non-image files (e.g. .DS_Store, stray text files) and
                # corrupt images instead of aborting the whole indexing run.
                continue
            # Original image first, then augmented variants; all rows for one
            # file share the same path and category.
            for variant in [img] + generate_augmented_images(img):
                features_list.append(extract_features(variant))
                image_paths.append(img_path)
                categories.append(category)
    if not features_list:
        # np.array([]) would produce a 1-D array and a confusing shape error
        # below; fail early with an actionable message instead.
        raise ValueError(f"No readable images found under {root_dir!r}")
    features_array = np.array(features_list).astype('float32')
    d = features_array.shape[1]  # embedding dimension
    # Inner product == cosine similarity because embeddings are L2-normalized.
    index = faiss.IndexFlatIP(d)
    index.add(features_array)
    return index, image_paths, categories
def save_index_and_metadata(nn, image_paths, categories, index_file, metadata_file):
    """Persist the FAISS index and its parallel metadata lists to disk.

    The index is converted to a byte buffer with faiss.serialize_index before
    pickling: pickling the raw SWIG-backed index object directly is not
    reliably supported across faiss versions.

    Args:
        nn: the faiss index to save (name kept for backward compatibility).
        image_paths, categories: lists parallel to the index rows.
        index_file, metadata_file: destination file paths.
    """
    with open(index_file, 'wb') as f:
        pickle.dump(faiss.serialize_index(nn), f)
    with open(metadata_file, 'wb') as f:
        pickle.dump((image_paths, categories), f)


def load_index_and_metadata(index_file, metadata_file):
    """Load the FAISS index and metadata written by save_index_and_metadata.

    Accepts both the current format (pickled faiss.serialize_index buffer)
    and legacy files that pickled the index object directly.

    Returns:
        (index, image_paths, categories)
    """
    with open(index_file, 'rb') as f:
        payload = pickle.load(f)
    # serialize_index produces a numpy uint8 buffer; a legacy file yields the
    # index object itself and is used as-is.
    if isinstance(payload, (bytes, bytearray, np.ndarray)):
        nn = faiss.deserialize_index(np.asarray(payload, dtype='uint8'))
    else:
        nn = payload
    with open(metadata_file, 'rb') as f:
        image_paths, categories = pickle.load(f)
    return nn, image_paths, categories
def search_similar_images(index, image_paths, categories, query_features, k=20):
    """Return the top-k most similar indexed images for a query embedding.

    Args:
        index: FAISS index whose rows parallel image_paths/categories.
        image_paths: file path per index row.
        categories: category label per index row.
        query_features: 1-D query embedding.
        k: number of neighbours requested; clamped to the index size.

    Returns:
        (similar_images, similarity_scores, similar_categories) — parallel
        results of at most k entries.
    """
    query = np.asarray(query_features, dtype='float32').reshape(1, -1)
    # When k exceeds the number of stored vectors, FAISS pads the result with
    # -1 indices, which would silently resolve to image_paths[-1]. Clamp k and
    # drop any invalid slots (padding is trailing, so scores stay aligned).
    k = min(k, index.ntotal)
    similarities, indices = index.search(query, k)
    valid = [i for i in indices[0] if i >= 0]
    similar_images = [image_paths[i] for i in valid]
    similarity_scores = similarities[0][:len(valid)]
    similar_categories = [categories[i] for i in valid]
    return similar_images, similarity_scores, similar_categories
def index_files_exist(index_file, metadata_file):
    """Return True only when both the index and metadata files exist on disk."""
    return all(os.path.exists(path) for path in (index_file, metadata_file))
def main():
    """Streamlit app: build or load the image index, then run a similarity
    search on an uploaded query image and display the matches."""
    st.title("Image Classification and Similarity Search")
    index_file = "faiss-d2-nn_index.pkl"
    metadata_file = "faiss-d2-image_metadata.pkl"
    # Build the index on first run; on later runs load it from disk.
    if not index_files_exist(index_file, metadata_file):
        st.warning("Index files not found. Creating new index...")
        root_dir = "Dataset2"  # Replace with your dataset path
        index, image_paths, categories = load_and_index_images(root_dir)
        save_index_and_metadata(index, image_paths, categories, index_file, metadata_file)
        st.success("Index created and saved successfully!")
    else:
        index, image_paths, categories = load_index_and_metadata(index_file, metadata_file)
        st.success("Index loaded successfully!")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        image = Image.open(uploaded_file).convert('RGB')
        query_features = extract_features(image)
        # Retrieve a generous candidate pool (k=50); filtering happens below.
        similar_images, similarities, similar_categories = search_similar_images(index, image_paths, categories, query_features, k=50)
        # Predicted class = most common category among the top-5 candidates.
        predicted_class = max(set(similar_categories[:5]), key=similar_categories[:5].count)
        # Show the query image and its single best match side by side.
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Query Image")
            st.image(image, caption="Uploaded Image", use_column_width=True)
            st.write(f"Image ID: {uploaded_file.name}")
        with col2:
            if similar_images:
                st.subheader("Matched Image")
                matched_image_path = similar_images[0]
                st.image(Image.open(matched_image_path),
                         caption=f"Matched Image (Similarity: {similarities[0]:.2f})",
                         use_column_width=True)
                st.write(f"Image ID: {os.path.basename(matched_image_path)}")
            else:
                st.write("No matched image found")
        st.subheader(f"Product Category: {predicted_class}")
        similarity_threshold = st.slider("Similarity threshold", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
        # Filter candidates: above threshold, same predicted class, and unique
        # file names (the index stores several augmented rows per file, all
        # sharing the original path).
        query_file_name = uploaded_file.name
        seen_file_names = set([query_file_name])  # exclude the query itself
        filtered_results = []
        # Start from index 1: slot 0 is normally the query's own best match.
        for img, sim, cat in zip(similar_images[1:], similarities[1:], similar_categories[1:]):
            file_name = os.path.basename(img)
            if sim >= similarity_threshold and cat == predicted_class and file_name not in seen_file_names:
                filtered_results.append((img, sim))
                seen_file_names.add(file_name)
        if filtered_results:
            max_images = len(filtered_results)
            num_display = st.slider("Number of similar images to display", min_value=0, max_value=max_images, value=min(20, max_images))
            st.subheader("Similar Images")
            st.info(f"Displaying {num_display} out of {max_images} unique similar images found for the uploaded query image.")
            # Lay the surviving matches out in a 5-column grid.
            num_cols = 5
            num_rows = (num_display + num_cols - 1) // num_cols
            for row in range(num_rows):
                cols = st.columns(num_cols)
                for col in range(num_cols):
                    idx = row * num_cols + col
                    if idx < num_display:
                        img_path, sim = filtered_results[idx]
                        with cols[col]:
                            st.image(Image.open(img_path), use_column_width=True)
                            st.write(f"Similarity: {sim:.2f}")
                            st.write(f"Image ID: {os.path.basename(img_path)}")
        else:
            st.info("No similar images found above the similarity threshold in the predicted class.")


if __name__ == "__main__":
    main()