Navyabhat committed on
Commit
02419e1
·
1 Parent(s): e7fc315

Upload 5 files

Browse files
Files changed (5) hide show
  1. app.py +111 -0
  2. features.npy +3 -0
  3. photo_ids.csv +0 -0
  4. photos.tsv000 +0 -0
  5. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Acknowledgments:
2
+ #This project is inspired by:
3
+ #1. https://github.com/haltakov/natural-language-image-search by Vladimir Haltakov
4
+ #2. OpenAI's CLIP
5
+
6
+
7
+
8
+ #Importing all the necessary libraries
9
+ import torch
10
+ import requests
11
+ import numpy as np
12
+ import pandas as pd
13
+ import gradio as gr
14
+ from io import BytesIO
15
+ from PIL import Image as PILIMAGE
16
+
17
+ from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
18
+ from sentence_transformers import SentenceTransformer, util
19
+
20
+
21
+
22
# Run on GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define model: CLIP ViT-B/32 plus its paired image processor and text tokenizer.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Load data:
#  - photos: Unsplash metadata table (includes "photo_id" and "photo_image_url" columns, read below)
#  - photo_features: precomputed CLIP image embeddings, one row per photo
#  - photo_ids: row-aligned photo ids, converted to a plain list for positional indexing
# NOTE(review): photo_features rows are assumed to align 1:1 with photo_ids rows — verify
# against whatever script produced features.npy / photo_ids.csv.
photos = pd.read_csv("./photos.tsv000", sep='\t', header=0)
photo_features = np.load("./features.npy")
photo_ids = pd.read_csv("./photo_ids.csv")
photo_ids = list(photo_ids['photo_id'])
34
+
35
+
36
+
37
def encode_text(text):
    """Encode a text description into a CLIP text feature vector.

    Args:
        text: the search string entered by the user.

    Returns:
        A 2-D numpy array (batch of one) holding the CLIP text embedding.
    """
    with torch.no_grad():
        # Encode the description using the CLIP processor.
        # (A separate `tokenizer([text], ...)` call used to precede this line,
        # but its result was immediately overwritten — dead code, removed.)
        inputs = processor(text=[text], images=None, return_tensors="pt", padding=True)
        # Move inputs to the model's device so this also works under CUDA,
        # then bring the result back to CPU before converting to numpy.
        inputs = {key: tensor.to(device) for key, tensor in inputs.items()}
        text_encoded = model.get_text_features(**inputs).detach().cpu().numpy()
    return text_encoded
44
+
45
def encode_image(image):
    """Encode an uploaded image into a normalized CLIP image feature vector.

    Args:
        image: numpy array from the Gradio image widget.

    Returns:
        A 2-D numpy array (batch of one) with the L2-normalized CLIP embedding.
    """
    pil_image = PILIMAGE.fromarray(image.astype('uint8'), 'RGB')
    with torch.no_grad():
        # Preprocess exactly as CLIP expects, then embed on the model's device.
        pixel_values = processor(text=None, images=pil_image, return_tensors="pt", padding=True)["pixel_values"]
        features = model.get_image_features(pixel_values.to(device))
        # Normalize so a plain dot product behaves like cosine similarity.
        features /= features.norm(dim=-1, keepdim=True)
        return features.cpu().numpy()
53
+
54
# Labels for the two search modes offered in the UI.
T2I = "Text2Image"
I2I = "Image2Image"


def similarity(feature, photo_features):
    """Score one query feature against every photo feature.

    Args:
        feature: 2-D array holding a single query embedding (shape (1, d)).
        photo_features: 2-D array of photo embeddings (shape (n, d)).

    Returns:
        A list of n per-photo similarity scores (dot products).
    """
    scores = feature @ photo_features.T
    return list(scores.squeeze(0))
60
+
61
def find_best_matches(image, mode, text):
    """Return the dataset photos that best match the query.

    Args:
        image: numpy image from the Gradio image input (used when mode is
            Image2Image; may be None in Text2Image mode).
        mode: one of the T2I / I2I radio labels; anything other than T2I
            falls through to image search, matching the original behavior.
        text: the text query (used when mode is Text2Image).

    Returns:
        A list of up to three PIL images downloaded from Unsplash, ordered
        by decreasing similarity to the query.
    """
    # Encode the query with the encoder that matches the selected mode.
    # (Removed: a debug print of the mode and `feature = ...` assignments
    # that were never read.)
    if mode == T2I:
        query_features = encode_text(text)
    else:
        query_features = encode_image(image)

    # Compute the similarity between the description and each photo.
    similarities = similarity(query_features, photo_features)

    # Sort the photos by their similarity score, best first.
    best_photos = sorted(zip(similarities, range(photo_features.shape[0])),
                         key=lambda pair: pair[0], reverse=True)

    matched_images = []
    # Slicing (rather than indexing 0..2) also copes with datasets of fewer
    # than three photos instead of raising IndexError.
    for _score, idx in best_photos[:3]:
        # Retrieve the photo ID, then all metadata for this photo.
        photo_id = photo_ids[idx]
        photo_data = photos[photos["photo_id"] == photo_id].iloc[0]

        # Download a 640px-wide rendition for display; bounded timeout so a
        # stalled CDN request cannot hang the UI indefinitely.
        response = requests.get(photo_data["photo_image_url"] + "?w=640", timeout=10)
        img = PILIMAGE.open(BytesIO(response.content))
        matched_images.append(img)
    return matched_images
96
+
97
+
98
+
99
+
100
# Build and launch the Gradio UI: an optional image input, a mode selector,
# and a text box; the output is a gallery showing the top-three matches.
# NOTE(review): `gr.Image(optional=True)`, `Gallery(...).style(...)` and the
# `enable_queue=` argument belong to the Gradio 3.x API and were removed in
# Gradio 4.x — pin the gradio version in requirements.txt or migrate. Also
# presumably requirements.txt should list gradio/pandas/requests/Pillow,
# which this app imports but the diff above does not include — verify.
gr.Interface(fn=find_best_matches,
             inputs=[
                 gr.Image(label="Image to search", optional=True),
                 gr.Radio([T2I, I2I]),
                 gr.Textbox(lines=1, label="Text query", placeholder="Introduce the search text...",
                            )],
             theme="grass",
             outputs=[gr.Gallery(
                 label="Generated images", show_label=False, elem_id="gallery"
             ).style(grid=[2], height="auto")], enable_queue=True, title="CLIP Image Search",
             description="This application displays TOP THREE images from Unsplash dataset that best match the search query provided by the user. Moreover, the input can be provided via two modes ie text or image form.").launch()
111
+
features.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31ac381e52fa007821a642b5808ac9a6eaf7163322ab340d36bcc3c2a94a38c8
3
+ size 25596032
photo_ids.csv ADDED
The diff for this file is too large to render. See raw diff
 
photos.tsv000 ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ sentence-transformers
2
+ transformers
3
+ torch
4
+ numpy
5
+ ftfy