Pairavi committed on
Commit
6000968
·
1 Parent(s): 54eb5f7

Upload similarSearch.py

Browse files
Files changed (1) hide show
  1. similarSearch.py +143 -0
similarSearch.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch as th
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import os
6
+ # import faiss
7
+ import copy
8
+ import cv2
9
+ import time
10
+ import torchvision.transforms as T
11
+ import io
12
+ import open_clip
13
+ from PIL import Image
14
+
15
def get_similarity_brute_force(embeddings_gallery, embeddings_query, k):
    """Rank gallery embeddings by Euclidean distance to the query.

    Args:
        embeddings_gallery: 2-D array with one gallery vector per row.
        embeddings_query: Query vector; it must broadcast against the
            gallery rows in the subtraction below.
        k: Maximum number of nearest rows to return.

    Returns:
        A ``(scores, indices)`` tuple. ``indices`` holds the positions of
        the (at most) ``k`` closest gallery rows, nearest first, and
        ``scores`` holds their distances in the same order.
    """
    print('Processing indices...')

    started = time.time()
    # One L2 distance per gallery row (the norm is taken along axis 1).
    offsets = embeddings_gallery - embeddings_query
    all_distances = np.linalg.norm(offsets, axis=1)
    # Ascending sort puts the smallest distances first.
    nearest = np.argsort(all_distances)[:k]
    nearest_distances = all_distances[nearest]
    finished = time.time()

    print(f'Finished processing indices, took {finished - started}s')
    return nearest_distances, nearest
26
+
27
+
28
def get_similarity_l2(embeddings_gallery, embeddings_query, k):
    """Return the ``k`` gallery rows closest to the query in L2 distance.

    Args:
        embeddings_gallery: 2-D array with one gallery vector per row.
        embeddings_query: Query vector; it must broadcast against the
            gallery rows when subtracted.
        k: Maximum number of neighbours to keep.

    Returns:
        ``(scores, indices)``: the distances of the selected rows and their
        positions in the gallery, both ordered from nearest to farthest.
    """
    print('Processing indices...')

    t_start = time.time()
    l2_per_row = np.linalg.norm(embeddings_gallery - embeddings_query, axis=1)
    ranking = np.argsort(l2_per_row)
    top_k = ranking[:k]
    top_k_scores = l2_per_row[top_k]
    t_end = time.time()

    print(f'Finished processing indices, took {t_end - t_start}s')
    return top_k_scores, top_k
39
+
40
def get_similarity_IP(embeddings_gallery, embeddings_query, k):
    """Rank gallery embeddings by cosine similarity to a single query.

    Args:
        embeddings_gallery: 2-D array with one gallery vector per row.
        embeddings_query: A single query vector, given either as shape
            ``(D,)`` or as a one-row batch ``(1, D)``.
        k: Maximum number of most-similar rows to return.

    Returns:
        ``(scores, indices)``. ``scores`` is a 1-D array holding the cosine
        similarity of EVERY gallery row (it is not reduced to the top ``k``);
        ``indices`` holds the positions of the (at most) ``k`` most similar
        rows, best first. Use ``scores[indices]`` for the matching scores.
    """
    print('Processing indices...')

    s = time.time()
    gallery = np.asarray(embeddings_gallery)
    # Flatten the query so a (1, D) batch behaves like a (D,) vector.
    # Otherwise the dot product is (N, 1) and dividing it by the (N,) gallery
    # norms silently broadcasts to an (N, N) matrix.
    query = np.asarray(embeddings_query).reshape(-1)

    dot_product = np.dot(gallery, query)
    norm_gallery = np.linalg.norm(gallery, axis=1)
    norm_query = np.linalg.norm(query)
    scores = dot_product / (norm_gallery * norm_query)
    # Reverse first, then slice: same order as `[-k:][::-1]` for k >= 1, but
    # k == 0 correctly yields no indices (`[-0:]` would select every row).
    indices = np.argsort(scores)[::-1][:k]
    e = time.time()

    print(f'Finished processing indices, took {e - s}s')
    return scores, indices
53
+
54
+
55
def convert_indices_to_labels(indices, labels):
    """Map every index in a nested index collection to its label.

    Args:
        indices: Iterable of mutable rows, each containing positions into
            ``labels``.
        labels: Indexable collection of labels.

    Returns:
        A deep copy of ``indices`` in which each entry has been replaced by
        ``labels[entry]``. The ``indices`` argument itself is left untouched.
    """
    labelled = copy.deepcopy(indices)
    for row in labelled:
        # Overwrite in place so the copy keeps the container type of the input.
        for pos, gallery_idx in enumerate(row):
            row[pos] = labels[gallery_idx]
    return labelled
61
+
62
def get_final_transform():
    """Build the image preprocessing pipeline applied before embedding.

    The pipeline resizes to 224x224 with antialiased bicubic interpolation,
    converts the image to a tensor, and normalises each channel.

    Returns:
        A ``torchvision.transforms.Compose`` of the three stages.
    """
    resize = T.Resize(
        size=(224, 224),
        interpolation=T.InterpolationMode.BICUBIC,
        antialias=True,
    )
    to_tensor = T.ToTensor()
    # NOTE(review): these per-channel statistics look like the ones published
    # with OpenAI CLIP; confirm they match the checkpoint in use.
    normalize = T.Normalize(
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
    )
    return T.Compose([resize, to_tensor, normalize])
75
+
76
def read_img(img_file, is_gray=False):
    """Load an image and return its pixels as a NumPy array.

    Args:
        img_file: Anything accepted by ``PIL.Image.open``.
        is_gray: When true the image is converted to mode ``'L'``;
            otherwise it is converted to ``'RGB'``.

    Returns:
        ``np.array`` of the converted image.
    """
    mode = 'L' if is_gray else 'RGB'
    # Open inside a context manager so the image's file handle is released
    # deterministically. `convert` yields a separate, fully loaded image, so
    # it remains usable after the source image is closed.
    with Image.open(img_file) as img:
        converted = img.convert(mode)
    return np.array(converted)
84
+
85
def transform_img(image):
    """Preprocess an image array into the tensor produced by
    ``get_final_transform``.

    Args:
        image: Image array accepted by ``cv2.cvtColor``.

    Returns:
        The result of applying the ``get_final_transform()`` pipeline to the
        channel-swapped image.
    """
    # NOTE(review): COLOR_BGR2RGB reverses the channel order, which is only
    # correct if callers pass BGR data. Confirm against the callers and
    # against how the gallery embeddings were generated before changing it.
    swapped = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # The torchvision pipeline is fed a PIL image, so wrap raw arrays.
    pil_ready = Image.fromarray(swapped) if isinstance(swapped, np.ndarray) else swapped

    pipeline = get_final_transform()
    return pipeline(pil_ready)
97
+
98
@th.no_grad()
def extract_embeddings(model, image, epoch=10, use_cuda=False):
    """Run ``model`` on ``image`` ``epoch`` times and stack the outputs.

    Gradient tracking is disabled for the whole call.

    Args:
        model: Callable mapping the input tensor to an output tensor.
        image: Input tensor passed unchanged to every forward pass.
        epoch: Number of forward passes to run on the same ``image``.
        use_cuda: When true, ``image`` is moved to the GPU before inference.
            The model itself is not moved by this function.

    Returns:
        A float32 NumPy array holding the ``epoch`` outputs concatenated
        along axis 0.

    Raises:
        ValueError: From ``np.concatenate`` when ``epoch`` is 0, because
            there is nothing to concatenate.
    """
    # The device transfer does not depend on the iteration, so do it once.
    if use_cuda:
        image = image.cuda()

    features = []
    for _ in range(epoch):
        # Bring each output back to the CPU and store it as float32.
        output = model(image).detach().cpu().numpy()
        features.append(output.astype(np.float32))

    return np.concatenate(features, axis=0)
111
+
112
+
113
+
114
+
115
def Model():
    """Create the ViT-H-14 visual backbone and load its weights.

    The visual tower of an ``open_clip`` ``'ViT-H-14'`` model is created
    without pretrained weights, then populated from ``./model1.pt`` and
    switched to evaluation mode.

    Returns:
        The visual backbone module, in eval mode.
    """
    backbone = open_clip.create_model_and_transforms('ViT-H-14', None)[0].visual
    # map_location="cpu" lets a checkpoint saved from a GPU load on a
    # CPU-only host; load_state_dict copies the values into the backbone's
    # existing parameters either way.
    # NOTE: th.load unpickles the file, so only load trusted checkpoints.
    state_dict = th.load("./model1.pt", map_location="cpu")
    backbone.load_state_dict(state_dict)
    # backbone.half()
    backbone.eval()
    return backbone
121
+
122
+
123
def predict(image_data, k=1000):
    """Return gallery indices ranked by similarity to the given image.

    The image is preprocessed with ``transform_img``, embedded with the
    backbone returned by ``Model()``, and compared by L2 distance against
    the embeddings stored in ``./embeddings_gallery.npy``.

    Args:
        image_data: Image convertible with ``np.array``.
        k: Maximum number of nearest gallery entries to return. Defaults to
            1000, the value this function previously hard-coded.

    Returns:
        A list of gallery indices, nearest first.
    """
    image = np.array(image_data)
    # unsqueeze(0) adds the leading batch dimension for a single image.
    image = transform_img(image).unsqueeze(0)

    # NOTE(review): the model and the gallery are reloaded on every call.
    # Consider caching them if this is called repeatedly.
    model_1 = Model()

    # A single forward pass is enough for one query image.
    embeddings_query = extract_embeddings(model_1, image, epoch=1)
    embeddings_gallery = np.load("./embeddings_gallery.npy")

    _, indices = get_similarity_l2(embeddings_gallery, embeddings_query, k)

    return indices.tolist()
138
+
139
+
140
+
141
+
142
+
143
+