Modified FAISS indexing: removed debug prints and attached similarity scores to the results
Browse files
model.py
CHANGED
|
@@ -75,16 +75,12 @@ def faiss_add_index_cos(df, column):
|
|
| 75 |
|
| 76 |
# Create an index
|
| 77 |
index = faiss.IndexFlatIP(embeddings.shape[1])
|
| 78 |
-
|
| 79 |
-
faiss.normalize_L2(embeddings)
|
| 80 |
-
print("<<<<faiss_ after normalize")
|
| 81 |
|
| 82 |
index.train(embeddings)
|
| 83 |
-
print("<<<<faiss_ after index.train")
|
| 84 |
|
| 85 |
# Add the embeddings to the index
|
| 86 |
index.add(embeddings)
|
| 87 |
-
print("<<<<faiss_add")
|
| 88 |
|
| 89 |
# Return the index
|
| 90 |
return index
|
|
@@ -100,7 +96,7 @@ def faiss_get_top_N_images(query,
|
|
| 100 |
model, tokenizer,
|
| 101 |
device)
|
| 102 |
# Relevant columns
|
| 103 |
-
relevant_cols = ["comment", "image_name"
|
| 104 |
|
| 105 |
#faiss search with cos similarity
|
| 106 |
index = faiss_add_index_cos(data, column="text_embeddings")
|
|
@@ -113,5 +109,7 @@ def faiss_get_top_N_images(query,
|
|
| 113 |
non_repeated_images = ~data_sorted["image_name"].duplicated()
|
| 114 |
most_similar_articles = data_sorted[non_repeated_images].head(top_K)
|
| 115 |
|
| 116 |
-
result_df = most_similar_articles[relevant_cols].reset_index()
|
|
|
|
|
|
|
| 117 |
return [get_item_data(result_df, i, 'similarity') for i in range(len(result_df))]
|
|
|
|
| 75 |
|
| 76 |
# Create an index
|
| 77 |
index = faiss.IndexFlatIP(embeddings.shape[1])
|
| 78 |
+
faiss.normalize_L2(embeddings)
|
|
|
|
|
|
|
| 79 |
|
| 80 |
index.train(embeddings)
|
|
|
|
| 81 |
|
| 82 |
# Add the embeddings to the index
|
| 83 |
index.add(embeddings)
|
|
|
|
| 84 |
|
| 85 |
# Return the index
|
| 86 |
return index
|
|
|
|
| 96 |
model, tokenizer,
|
| 97 |
device)
|
| 98 |
# Relevant columns
|
| 99 |
+
relevant_cols = ["comment", "image_name"]
|
| 100 |
|
| 101 |
#faiss search with cos similarity
|
| 102 |
index = faiss_add_index_cos(data, column="text_embeddings")
|
|
|
|
| 109 |
non_repeated_images = ~data_sorted["image_name"].duplicated()
|
| 110 |
most_similar_articles = data_sorted[non_repeated_images].head(top_K)
|
| 111 |
|
| 112 |
+
result_df = most_similar_articles[relevant_cols].reset_index()
|
| 113 |
+
D = D.reshape(-1,1)[:top_K]
|
| 114 |
+
result_df = pd.concat([result_df, pd.DataFrame(D, columns=['similarity'])], axis=1)
|
| 115 |
return [get_item_data(result_df, i, 'similarity') for i in range(len(result_df))]
|