Spaces:
Sleeping
Sleeping
Commit ·
d3992a1
1
Parent(s): e41b4ca
Update app.py
Browse files
app.py
CHANGED
|
@@ -33,10 +33,52 @@ else:
|
|
| 33 |
# Print some statistics
|
| 34 |
print(f"Photos loaded: {len(photo_ids)}")
|
| 35 |
|
|
|
|
| 36 |
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
# Encode the search query
|
| 39 |
-
if not query_text and not query_photo_id:
|
| 40 |
return []
|
| 41 |
|
| 42 |
text_features = encode_search_query(model, query_text)
|
|
@@ -53,8 +95,12 @@ def search_by_text_and_photo(query_text, query_img, query_photo_id=None, photo_w
|
|
| 53 |
# Find the best match
|
| 54 |
best_photo_ids = find_best_matches(search_features, photo_features, photo_ids, 10)
|
| 55 |
|
| 56 |
-
elif
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
query_photo_features = query_photo_features / query_photo_features.norm(dim=1, keepdim=True)
|
| 59 |
|
| 60 |
# Combine the test and photo queries and normalize again
|
|
@@ -66,7 +112,7 @@ def search_by_text_and_photo(query_text, query_img, query_photo_id=None, photo_w
|
|
| 66 |
else:
|
| 67 |
# Display the results
|
| 68 |
print("Test search result")
|
| 69 |
-
best_photo_ids = search_unslash(query_text, photo_features, photo_ids, 10)
|
| 70 |
|
| 71 |
return best_photo_ids
|
| 72 |
|
|
@@ -76,20 +122,21 @@ with gr.Blocks() as app:
|
|
| 76 |
gr.Markdown(
|
| 77 |
"""
|
| 78 |
# CLIP Image Search Engine!
|
| 79 |
-
### Enter search query or/and
|
| 80 |
""")
|
| 81 |
|
| 82 |
with gr.Row(visible=True):
|
| 83 |
with gr.Column():
|
| 84 |
with gr.Row():
|
| 85 |
-
search_text = gr.Textbox(value='', placeholder='Search..', label='Enter
|
| 86 |
|
| 87 |
with gr.Row():
|
| 88 |
submit_btn = gr.Button("Submit", variant='primary')
|
| 89 |
clear_btn = gr.ClearButton()
|
| 90 |
|
| 91 |
-
with gr.Column():
|
| 92 |
-
search_image = gr.Image(label='
|
|
|
|
| 93 |
|
| 94 |
with gr.Row(visible=True):
|
| 95 |
output_images = gr.Gallery(allow_preview=False, label='Results.. ', info='',
|
|
@@ -102,44 +149,75 @@ with gr.Blocks() as app:
|
|
| 102 |
return {
|
| 103 |
search_image: None,
|
| 104 |
output_images: None,
|
| 105 |
-
search_text: None
|
|
|
|
|
|
|
| 106 |
}
|
| 107 |
|
| 108 |
|
| 109 |
-
clear_btn.click(clear_data, None, [search_image, output_images, search_text])
|
| 110 |
|
| 111 |
|
| 112 |
def on_select(evt: gr.SelectData, output_image_ids):
|
| 113 |
return {
|
| 114 |
-
search_image: f"https://unsplash.com/photos/{output_image_ids[evt.index]}/download?w=
|
|
|
|
|
|
|
| 115 |
}
|
| 116 |
|
| 117 |
|
| 118 |
-
output_images.select(on_select, output_image_ids, search_image)
|
| 119 |
|
| 120 |
|
| 121 |
-
def func_search(query, img):
|
| 122 |
-
best_photo_ids =
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
| 134 |
|
| 135 |
|
| 136 |
submit_btn.click(
|
| 137 |
func_search,
|
| 138 |
-
[search_text, search_image],
|
| 139 |
[output_images, output_image_ids]
|
| 140 |
)
|
| 141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
'''
|
| 143 |
Launch the app
|
| 144 |
'''
|
| 145 |
app.launch()
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
# Print some statistics
|
| 34 |
print(f"Photos loaded: {len(photo_ids)}")
|
| 35 |
|
| 36 |
+
from PIL import Image
|
| 37 |
|
| 38 |
+
|
| 39 |
+
def encode_search_query(net, search_query):
|
| 40 |
+
with torch.no_grad():
|
| 41 |
+
tokenized_query = clip.tokenize(search_query)
|
| 42 |
+
# print("tokenized_query: ", tokenized_query.shape)
|
| 43 |
+
# Encode and normalize the search query using CLIP
|
| 44 |
+
text_encoded = net.encode_text(tokenized_query.to(device))
|
| 45 |
+
text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
|
| 46 |
+
|
| 47 |
+
# Retrieve the feature vector
|
| 48 |
+
# print("text_encoded: ", text_encoded.shape)
|
| 49 |
+
return text_encoded
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def find_best_matches(text_features, photo_features, photo_ids, results_count=5):
|
| 53 |
+
# Compute the similarity between the search query and each photo using the Cosine similarity
|
| 54 |
+
# print("text_features: ", text_features.shape)
|
| 55 |
+
# print("photo_features: ", photo_features.shape)
|
| 56 |
+
similarities = (photo_features @ text_features.T).squeeze(1)
|
| 57 |
+
|
| 58 |
+
# Sort the photos by their similarity score
|
| 59 |
+
best_photo_idx = (-similarities).argsort()
|
| 60 |
+
# print("best_photo_idx: ", best_photo_idx.shape)
|
| 61 |
+
# print("best_photo_idx: ", best_photo_idx[:results_count])
|
| 62 |
+
|
| 63 |
+
result_list = [photo_ids[i] for i in best_photo_idx[:results_count]]
|
| 64 |
+
# print("result_list: ", len(result_list))
|
| 65 |
+
# Return the photo IDs of the best matches
|
| 66 |
+
return result_list
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def search_unslash(net, search_query, photo_features, photo_ids, results_count=10):
|
| 70 |
+
# Encode the search query
|
| 71 |
+
text_features = encode_search_query(net, search_query)
|
| 72 |
+
|
| 73 |
+
# Find the best matches
|
| 74 |
+
best_photo_ids = find_best_matches(text_features, photo_features, photo_ids, results_count)
|
| 75 |
+
|
| 76 |
+
return best_photo_ids
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def search_by_text_and_photo(query_text, query_photo=None, query_photo_id=None, photo_weight=0.5):
|
| 80 |
# Encode the search query
|
| 81 |
+
if not query_text and query_photo is None and not query_photo_id:
|
| 82 |
return []
|
| 83 |
|
| 84 |
text_features = encode_search_query(model, query_text)
|
|
|
|
| 95 |
# Find the best match
|
| 96 |
best_photo_ids = find_best_matches(search_features, photo_features, photo_ids, 10)
|
| 97 |
|
| 98 |
+
elif query_photo is not None:
|
| 99 |
+
query_photo = preprocess(query_photo)
|
| 100 |
+
query_photo = torch.tensor(query_photo).permute(2, 0, 1)
|
| 101 |
+
|
| 102 |
+
print(query_photo.shape)
|
| 103 |
+
query_photo_features = model.encode_image(query_photo)
|
| 104 |
query_photo_features = query_photo_features / query_photo_features.norm(dim=1, keepdim=True)
|
| 105 |
|
| 106 |
# Combine the test and photo queries and normalize again
|
|
|
|
| 112 |
else:
|
| 113 |
# Display the results
|
| 114 |
print("Test search result")
|
| 115 |
+
best_photo_ids = search_unslash(model, query_text, photo_features, photo_ids, 10)
|
| 116 |
|
| 117 |
return best_photo_ids
|
| 118 |
|
|
|
|
| 122 |
gr.Markdown(
|
| 123 |
"""
|
| 124 |
# CLIP Image Search Engine!
|
| 125 |
+
### Enter search query or/and select image to find the similar images
|
| 126 |
""")
|
| 127 |
|
| 128 |
with gr.Row(visible=True):
|
| 129 |
with gr.Column():
|
| 130 |
with gr.Row():
|
| 131 |
+
search_text = gr.Textbox(value='', placeholder='Search..', label='Enter search query')
|
| 132 |
|
| 133 |
with gr.Row():
|
| 134 |
submit_btn = gr.Button("Submit", variant='primary')
|
| 135 |
clear_btn = gr.ClearButton()
|
| 136 |
|
| 137 |
+
with gr.Column(visible=True) as input_image_col:
|
| 138 |
+
search_image = gr.Image(label='Select from results', interactive=False)
|
| 139 |
+
search_image_id = gr.State(None)
|
| 140 |
|
| 141 |
with gr.Row(visible=True):
|
| 142 |
output_images = gr.Gallery(allow_preview=False, label='Results.. ', info='',
|
|
|
|
| 149 |
return {
|
| 150 |
search_image: None,
|
| 151 |
output_images: None,
|
| 152 |
+
search_text: None,
|
| 153 |
+
search_image_id: None,
|
| 154 |
+
input_image_col: gr.update(visible=True)
|
| 155 |
}
|
| 156 |
|
| 157 |
|
| 158 |
+
clear_btn.click(clear_data, None, [search_image, output_images, search_text, search_image_id, input_image_col])
|
| 159 |
|
| 160 |
|
| 161 |
def on_select(evt: gr.SelectData, output_image_ids):
|
| 162 |
return {
|
| 163 |
+
search_image: f"https://unsplash.com/photos/{output_image_ids[evt.index]}/download?w=320",
|
| 164 |
+
search_image_id: output_image_ids[evt.index],
|
| 165 |
+
input_image_col: gr.update(visible=True)
|
| 166 |
}
|
| 167 |
|
| 168 |
|
| 169 |
+
output_images.select(on_select, output_image_ids, [search_image, search_image_id, input_image_col])
|
| 170 |
|
| 171 |
|
| 172 |
+
def func_search(query, img, img_id):
|
| 173 |
+
best_photo_ids = []
|
| 174 |
+
if img_id:
|
| 175 |
+
best_photo_ids = search_by_text_and_photo(query, query_photo_id=img_id)
|
| 176 |
+
elif img is not None:
|
| 177 |
+
img = Image.open(img)
|
| 178 |
+
best_photo_ids = search_by_text_and_photo(query, query_photo=img)
|
| 179 |
+
elif query:
|
| 180 |
+
best_photo_ids = search_by_text_and_photo(query)
|
| 181 |
|
| 182 |
+
if len(best_photo_ids) == 0:
|
| 183 |
+
print("Invalid Search Request")
|
| 184 |
+
return {
|
| 185 |
+
output_image_ids: [],
|
| 186 |
+
output_images: []
|
| 187 |
+
}
|
| 188 |
+
else:
|
| 189 |
+
img_urls = []
|
| 190 |
+
for p_id in best_photo_ids:
|
| 191 |
+
url = f"https://unsplash.com/photos/{p_id}/download?w=20"
|
| 192 |
+
img_urls.append(url)
|
| 193 |
|
| 194 |
+
valid_images = filter_invalid_urls(img_urls, best_photo_ids)
|
| 195 |
+
|
| 196 |
+
return {
|
| 197 |
+
output_image_ids: valid_images['image_ids'],
|
| 198 |
+
output_images: valid_images['image_urls']
|
| 199 |
+
}
|
| 200 |
|
| 201 |
|
| 202 |
submit_btn.click(
|
| 203 |
func_search,
|
| 204 |
+
[search_text, search_image, search_image_id],
|
| 205 |
[output_images, output_image_ids]
|
| 206 |
)
|
| 207 |
|
| 208 |
+
|
| 209 |
+
def on_upload(evt: gr.SelectData):
|
| 210 |
+
return {
|
| 211 |
+
search_image_id: None
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
search_image.upload(on_upload, None, search_image_id)
|
| 216 |
+
|
| 217 |
'''
|
| 218 |
Launch the app
|
| 219 |
'''
|
| 220 |
app.launch()
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
|