Spaces:
Sleeping
Sleeping
Commit ·
415d5ea
1
Parent(s): e392687
batch process, faiss, gpu support, optimise
Browse files- .gitignore +1 -0
- app.py +86 -39
- requirements.txt +4 -2
.gitignore
CHANGED
|
@@ -2,5 +2,6 @@ env/
|
|
| 2 |
images/
|
| 3 |
__pycache__/
|
| 4 |
*.tree
|
|
|
|
| 5 |
secrets.toml
|
| 6 |
kaggle.json
|
|
|
|
| 2 |
images/
|
| 3 |
__pycache__/
|
| 4 |
*.tree
|
| 5 |
+
*.index
|
| 6 |
secrets.toml
|
| 7 |
kaggle.json
|
app.py
CHANGED
|
@@ -2,7 +2,7 @@ import streamlit as st
|
|
| 2 |
import torch
|
| 3 |
import os
|
| 4 |
import torchvision
|
| 5 |
-
|
| 6 |
from PIL import Image
|
| 7 |
import traceback
|
| 8 |
from tqdm import tqdm
|
|
@@ -11,27 +11,41 @@ from slugify import slugify
|
|
| 11 |
import opendatasets as od
|
| 12 |
import json
|
| 13 |
import argparse
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
-
|
|
|
|
| 16 |
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 17 |
FOLDER = "images/"
|
| 18 |
NUM_TREES = 100
|
| 19 |
FEATURES = 1000
|
| 20 |
FILETYPES = [".png", ".jpg", ".jpeg", ".tiff", ".bmp"]
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
@st.cache_resource
|
| 26 |
def dl_embeddings():
|
| 27 |
"""dl pretrained embeddings in production environment instead of creating"""
|
| 28 |
# Connect to your Blob Storage account
|
|
|
|
|
|
|
|
|
|
| 29 |
connect_str = st.secrets["connectionstring"]
|
| 30 |
blob_service_client = BlobServiceClient.from_connection_string(connect_str)
|
| 31 |
|
| 32 |
# Specify container and blob names
|
| 33 |
container_name = "imagessearch"
|
| 34 |
-
blob_name = f"{slugify(FOLDER)}.
|
| 35 |
|
| 36 |
# Get a reference to the blob
|
| 37 |
blob_client = blob_service_client.get_blob_client(
|
|
@@ -39,7 +53,7 @@ def dl_embeddings():
|
|
| 39 |
)
|
| 40 |
|
| 41 |
# Download the binary data
|
| 42 |
-
download_file_path = f"{slugify(FOLDER)}.
|
| 43 |
with open(download_file_path, "wb") as download_file:
|
| 44 |
download_file.write(blob_client.download_blob().readall())
|
| 45 |
|
|
@@ -56,16 +70,18 @@ def load_dataset():
|
|
| 56 |
},
|
| 57 |
f,
|
| 58 |
)
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
| 63 |
|
| 64 |
|
| 65 |
# Load a pre-trained image feature extractor model
|
| 66 |
@st.cache_resource
|
| 67 |
def load_model():
|
| 68 |
"""Loads a pre-trained image feature extractor model."""
|
|
|
|
| 69 |
model = torch.hub.load(
|
| 70 |
"NVIDIA/DeepLearningExamples:torchhub",
|
| 71 |
"nvidia_efficientnet_b0",
|
|
@@ -104,9 +120,19 @@ def load_images(file_paths):
|
|
| 104 |
return images
|
| 105 |
|
| 106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
# Function to preprocess images
|
| 108 |
def preprocess_image(image):
|
| 109 |
"""Preprocesses an image for feature extraction."""
|
|
|
|
| 110 |
if image.mode == "RGB": # Already has 3 channels
|
| 111 |
pass # No need to modify
|
| 112 |
elif image.mode == "L": # Grayscale image
|
|
@@ -128,57 +154,77 @@ def preprocess_image(image):
|
|
| 128 |
return preprocess(image)
|
| 129 |
|
| 130 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
# Extract features from a list of images
|
| 132 |
-
def extract_features(
|
| 133 |
"""Extracts features from a list of images."""
|
| 134 |
print("Extracting features:")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
features = []
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
|
|
|
| 141 |
|
| 142 |
|
| 143 |
# Build an Annoy index for efficient similarity search
|
| 144 |
def build_annoy_index(features):
|
| 145 |
"""Builds an Annoy index for efficient similarity search."""
|
| 146 |
-
print("Building
|
| 147 |
f = features[0].shape[0] # Feature dimensionality
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
|
|
|
| 153 |
|
| 154 |
|
| 155 |
# Perform reverse image search
|
| 156 |
-
def search_similar_images(
|
| 157 |
"""Finds similar images based on a query image feature."""
|
| 158 |
-
index =
|
| 159 |
-
|
| 160 |
-
query_image = Image.open(uploaded_file)
|
| 161 |
-
model = load_model()
|
| 162 |
# Extract features and search
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
)
|
| 166 |
-
|
| 167 |
-
query_feature,
|
|
|
|
| 168 |
)
|
| 169 |
-
return query_image, nearest_neighbors, distances
|
| 170 |
|
| 171 |
|
| 172 |
@st.cache_data
|
| 173 |
def save_embedding(folder=FOLDER):
|
| 174 |
-
if os.path.isfile(f"{slugify(FOLDER)}.
|
|
|
|
| 175 |
return
|
|
|
|
| 176 |
model = load_model() # Load the model once
|
| 177 |
file_paths = get_all_file_paths(folder_path=folder)
|
| 178 |
-
images = load_images(file_paths)
|
| 179 |
-
features = extract_features(
|
| 180 |
index = build_annoy_index(features)
|
| 181 |
-
|
| 182 |
|
| 183 |
|
| 184 |
def display_image(idx, dist):
|
|
@@ -214,11 +260,12 @@ if __name__ == "__main__":
|
|
| 214 |
)
|
| 215 |
|
| 216 |
if uploaded_file is not None:
|
|
|
|
|
|
|
| 217 |
query_image, nearest_neighbors, distances = search_similar_images(
|
| 218 |
-
|
| 219 |
)
|
| 220 |
|
| 221 |
-
st.image(query_image.resize([256, 256]), caption="Query Image", width=200)
|
| 222 |
st.subheader("Similar Images:")
|
| 223 |
cols = st.columns([1] * 5)
|
| 224 |
for i, (idx, dist) in enumerate(
|
|
|
|
| 2 |
import torch
|
| 3 |
import os
|
| 4 |
import torchvision
|
| 5 |
+
import faiss
|
| 6 |
from PIL import Image
|
| 7 |
import traceback
|
| 8 |
from tqdm import tqdm
|
|
|
|
| 11 |
import opendatasets as od
|
| 12 |
import json
|
| 13 |
import argparse
|
| 14 |
+
from streamlit_cropper import st_cropper
|
| 15 |
+
from azure.storage.blob import BlobServiceClient
|
| 16 |
+
from torch.utils.data import Dataset, DataLoader
|
| 17 |
+
import torchvision.transforms
|
| 18 |
+
import numpy as np
|
| 19 |
+
import faiss.contrib.torch_utils
|
| 20 |
|
| 21 |
+
BATCH_SIZE = 200
|
| 22 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 23 |
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 24 |
FOLDER = "images/"
|
| 25 |
NUM_TREES = 100
|
| 26 |
FEATURES = 1000
|
| 27 |
FILETYPES = [".png", ".jpg", ".jpeg", ".tiff", ".bmp"]
|
| 28 |
+
LIBRARIES = [
|
| 29 |
+
"https://www.kaggle.com/datasets/athota1/caltech101",
|
| 30 |
+
"https://www.kaggle.com/datasets/gpiosenka/sports-classification",
|
| 31 |
+
"https://www.kaggle.com/datasets/puneet6060/intel-image-classification",
|
| 32 |
+
"https://www.kaggle.com/datasets/kkhandekar/image-dataset",
|
| 33 |
+
]
|
| 34 |
|
| 35 |
|
| 36 |
@st.cache_resource
|
| 37 |
def dl_embeddings():
|
| 38 |
"""dl pretrained embeddings in production environment instead of creating"""
|
| 39 |
# Connect to your Blob Storage account
|
| 40 |
+
if os.path.isfile(f"{slugify(FOLDER)}.index"):
|
| 41 |
+
print("Embeddings files already exists, skip download")
|
| 42 |
+
return
|
| 43 |
connect_str = st.secrets["connectionstring"]
|
| 44 |
blob_service_client = BlobServiceClient.from_connection_string(connect_str)
|
| 45 |
|
| 46 |
# Specify container and blob names
|
| 47 |
container_name = "imagessearch"
|
| 48 |
+
blob_name = f"{slugify(FOLDER)}.index"
|
| 49 |
|
| 50 |
# Get a reference to the blob
|
| 51 |
blob_client = blob_service_client.get_blob_client(
|
|
|
|
| 53 |
)
|
| 54 |
|
| 55 |
# Download the binary data
|
| 56 |
+
download_file_path = f"{slugify(FOLDER)}.index" # Path to save the downloaded file
|
| 57 |
with open(download_file_path, "wb") as download_file:
|
| 58 |
download_file.write(blob_client.download_blob().readall())
|
| 59 |
|
|
|
|
| 70 |
},
|
| 71 |
f,
|
| 72 |
)
|
| 73 |
+
for lib in LIBRARIES:
|
| 74 |
+
od.download(
|
| 75 |
+
lib,
|
| 76 |
+
"images/",
|
| 77 |
+
)
|
| 78 |
|
| 79 |
|
| 80 |
# Load a pre-trained image feature extractor model
|
| 81 |
@st.cache_resource
|
| 82 |
def load_model():
|
| 83 |
"""Loads a pre-trained image feature extractor model."""
|
| 84 |
+
print("Loading pretrained model...")
|
| 85 |
model = torch.hub.load(
|
| 86 |
"NVIDIA/DeepLearningExamples:torchhub",
|
| 87 |
"nvidia_efficientnet_b0",
|
|
|
|
| 120 |
return images
|
| 121 |
|
| 122 |
|
| 123 |
+
def load_image(file_path):
    """Load a single image from disk and resize it to 224x224.

    Args:
        file_path: Path of the image file to open.

    Returns:
        The resized PIL image, or None when the file cannot be opened or
        decoded (the error is printed instead of raised).
    """
    try:
        # The context manager closes the underlying file handle promptly;
        # resize() returns a new, fully loaded image that outlives it.
        with Image.open(file_path) as image:
            return image.resize([224, 224])
    except Exception as e:
        # Deliberately not BaseException: KeyboardInterrupt and SystemExit
        # must still be able to stop a long indexing run.
        print("Error loading ", file_path, e)
        return None
|
| 130 |
+
|
| 131 |
+
|
| 132 |
# Function to preprocess images
|
| 133 |
def preprocess_image(image):
|
| 134 |
"""Preprocesses an image for feature extraction."""
|
| 135 |
+
|
| 136 |
if image.mode == "RGB": # Already has 3 channels
|
| 137 |
pass # No need to modify
|
| 138 |
elif image.mode == "L": # Grayscale image
|
|
|
|
| 154 |
return preprocess(image)
|
| 155 |
|
| 156 |
|
| 157 |
+
class ImageLoader(Dataset):
    """Dataset that lazily loads and transforms one image per file path.

    Args:
        image_files: Indexable sequence of image file paths.
        transform: Callable applied to each loaded image.
        load_image: Callable that takes a file path and returns the loaded
            image, or None when loading fails.
    """

    def __init__(self, image_files, transform, load_image):
        self.transform = transform
        self.load_image = load_image
        self.image_files = image_files

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.image_files)

    def __getitem__(self, index):
        """Load and transform the image at position ``index``.

        Raises:
            RuntimeError: If the loader returns None for this path.
        """
        path = self.image_files[index]
        image = self.load_image(path)
        if image is None:
            # Fail with the offending path instead of letting the transform
            # blow up on None with an unrelated AttributeError.
            raise RuntimeError(f"Could not load image: {path}")
        return self.transform(image)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
# Extract features from a list of images
|
| 171 |
+
def extract_features(file_paths, model):
    """Extract feature vectors for the images at ``file_paths`` in batches.

    Images are loaded and preprocessed lazily through ``ImageLoader`` and fed
    to the model ``BATCH_SIZE`` at a time on ``DEVICE``.

    Args:
        file_paths: Sequence of image file paths.
        model: Feature-extractor module; it is moved to ``DEVICE``.

    Returns:
        A single tensor holding the concatenated model outputs of all
        batches, in ``file_paths`` order (the loader is not shuffled).
    """
    print("Extracting features:")
    loader = DataLoader(
        ImageLoader(file_paths, transform=preprocess_image, load_image=load_image),
        batch_size=BATCH_SIZE,
    )
    model = model.to(DEVICE)
    # Inference mode for BatchNorm/dropout so embeddings are deterministic;
    # a no-op if the model is already in eval mode.
    model.eval()
    features = []
    with torch.no_grad():
        for images in tqdm(loader):
            features.append(model(images.to(DEVICE)))
    return torch.cat(features)
|
| 185 |
|
| 186 |
|
| 187 |
# Build a faiss index for efficient similarity search
|
| 188 |
def build_annoy_index(features):
    """Build a faiss inner-product index over the given feature vectors.

    The function name is historical and kept for existing callers; the index
    built here is a faiss ``IndexIDMap`` wrapping an ``IndexFlatIP``.

    Args:
        features: 2-D torch tensor with one feature vector per row.

    Returns:
        The populated faiss index. Each vector's id is its row position in
        ``features``.
    """
    print("Building faiss index:")
    # faiss expects C-contiguous float32 vectors and int64 ids.
    vectors = np.ascontiguousarray(
        features.cpu().detach().numpy(), dtype=np.float32
    )
    dim = vectors.shape[1]  # Feature dimensionality
    # arange with an explicit dtype avoids the platform-dependent integer
    # width that np.array(range(...)) produces.
    ids = np.arange(len(vectors), dtype=np.int64)
    index = faiss.IndexIDMap(faiss.IndexFlatIP(dim))
    index.add_with_ids(vectors, ids)
    print("built faiss index:")
    return index
|
| 198 |
|
| 199 |
|
| 200 |
# Perform reverse image search
|
| 201 |
+
def search_similar_images(query_image, num_results, f=FEATURES):
    """Find the indexed images most similar to ``query_image``.

    Reads the faiss index from ``{slugify(FOLDER)}.index``, embeds the query
    with the pretrained model and searches the index.

    Args:
        query_image: PIL image to search for.
        num_results: Number of neighbours to request from the index.
        f: Unused; kept so existing callers that pass it keep working.

    Returns:
        A tuple ``(query_image, nearest_neighbors, distances)`` where the
        last two are the id and score arrays for the single query.
        NOTE(review): scores are inner products if the index file was
        produced by ``build_annoy_index`` — confirm for downloaded indexes.
    """
    index = faiss.read_index(f"{slugify(FOLDER)}.index")
    model = load_model().to(DEVICE)
    # Inference mode for BatchNorm/dropout; no-op if already set.
    model.eval()
    # Extract features and search
    proc_image = preprocess_image(query_image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():  # no autograd graph needed for a lookup
        query_feature = model(proc_image)
    query_feature = query_feature.cpu().numpy()
    distances, nearest_neighbors = index.search(
        query_feature,
        num_results,
    )
    return query_image, nearest_neighbors[0], distances[0]
|
| 214 |
|
| 215 |
|
| 216 |
@st.cache_data
def save_embedding(folder=FOLDER):
    """Create and persist the faiss index unless the index file exists.

    Images are read from ``folder``. The index file name is derived from the
    module-level ``FOLDER`` constant, not from the ``folder`` argument.

    Args:
        folder: Directory scanned for image files.
    """
    index_exists = os.path.isfile(f"{slugify(FOLDER)}.index")
    if index_exists:
        print("skipping recreating image embeddings")
        return

    print("Performing image embeddings")
    extractor = load_model()  # Load the model once
    image_paths = get_all_file_paths(folder_path=folder)
    embeddings = extract_features(image_paths, extractor)
    built_index = build_annoy_index(embeddings)
    faiss.write_index(built_index, f"{slugify(FOLDER)}.index")
|
| 228 |
|
| 229 |
|
| 230 |
def display_image(idx, dist):
|
|
|
|
| 260 |
)
|
| 261 |
|
| 262 |
if uploaded_file is not None:
|
| 263 |
+
query_image = Image.open(uploaded_file)
|
| 264 |
+
cropped = st_cropper(query_image)
|
| 265 |
query_image, nearest_neighbors, distances = search_similar_images(
|
| 266 |
+
cropped.resize([256, 256]), n_matches
|
| 267 |
)
|
| 268 |
|
|
|
|
| 269 |
st.subheader("Similar Images:")
|
| 270 |
cols = st.columns([1] * 5)
|
| 271 |
for i, (idx, dist) in enumerate(
|
requirements.txt
CHANGED
|
@@ -1,8 +1,10 @@
|
|
| 1 |
-
|
|
|
|
| 2 |
torch
|
| 3 |
torchvision
|
| 4 |
streamlit
|
| 5 |
tqdm
|
| 6 |
python-slugify
|
| 7 |
opendatasets
|
| 8 |
-
azure-storage-blob
|
|
|
|
|
|
| 1 |
+
faiss-cpu
|
| 2 |
+
faiss-gpu
|
| 3 |
torch
|
| 4 |
torchvision
|
| 5 |
streamlit
|
| 6 |
tqdm
|
| 7 |
python-slugify
|
| 8 |
opendatasets
|
| 9 |
+
azure-storage-blob
|
| 10 |
+
streamlit-cropper
|