Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -9,7 +9,6 @@ import gzip
|
|
| 9 |
import os
|
| 10 |
import torch
|
| 11 |
import pickle
|
| 12 |
-
import yake
|
| 13 |
|
| 14 |
############
|
| 15 |
## Main page
|
|
@@ -36,7 +35,7 @@ user_query = st.text_input("Enter a query for the generated text: e.g., gift, ho
|
|
| 36 |
# Add selectbox in streamlit
|
| 37 |
option1 = st.sidebar.selectbox(
|
| 38 |
'Which transformers model would you like to be selected?',
|
| 39 |
-
('multi-qa-MiniLM-L6-cos-v1','
|
| 40 |
|
| 41 |
option2 = st.sidebar.selectbox(
|
| 42 |
'Which corss-encoder model would you like to be selected?',
|
|
@@ -65,20 +64,52 @@ with open(embedding_cache_path, "rb") as fIn:
|
|
| 65 |
passages = cache_data['sentences']
|
| 66 |
corpus_embeddings = cache_data['embeddings']
|
| 67 |
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
# This function will search all wikipedia articles for passages that
|
| 76 |
# answer the query
|
| 77 |
def search(query):
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
##### Semantic Search #####
|
| 80 |
# Encode the query using the bi-encoder and find potentially relevant passages
|
| 81 |
query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
|
|
|
|
| 82 |
hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)
|
| 83 |
hits = hits[0] # Get the hits for the first query
|
| 84 |
|
|
@@ -91,28 +122,33 @@ def search(query):
|
|
| 91 |
for idx in range(len(cross_scores)):
|
| 92 |
hits[idx]['cross-score'] = cross_scores[idx]
|
| 93 |
|
| 94 |
-
# Output of top-
|
| 95 |
-
#
|
| 96 |
-
#
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
|
| 105 |
-
|
| 106 |
-
# st.write("\t{:.3f}\t{}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " ")))
|
| 107 |
-
hit_res = []
|
| 108 |
for hit in hits[0:1000]:
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
st.write("## Results:")
|
| 118 |
if st.button('Generated Expansion'):
|
|
|
|
| 9 |
import os
|
| 10 |
import torch
|
| 11 |
import pickle
|
|
|
|
| 12 |
|
| 13 |
############
|
| 14 |
## Main page
|
|
|
|
| 35 |
# Add selectbox in streamlit
|
| 36 |
option1 = st.sidebar.selectbox(
|
| 37 |
'Which transformers model would you like to be selected?',
|
| 38 |
+
('multi-qa-MiniLM-L6-cos-v1','null','null'))
|
| 39 |
|
| 40 |
option2 = st.sidebar.selectbox(
|
| 41 |
'Which corss-encoder model would you like to be selected?',
|
|
|
|
| 64 |
passages = cache_data['sentences']
|
| 65 |
corpus_embeddings = cache_data['embeddings']
|
| 66 |
|
| 67 |
+
from rank_bm25 import BM25Okapi
|
| 68 |
+
from sklearn.feature_extraction import _stop_words
|
| 69 |
+
import string
|
| 70 |
+
from tqdm.autonotebook import tqdm
|
| 71 |
+
import numpy as np
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# We lower case our text and remove stop-words from indexing
|
| 75 |
+
def bm25_tokenizer(text):
    """Tokenize *text* for BM25 lexical indexing.

    Lower-cases the text, splits on whitespace, strips leading/trailing
    punctuation from each token, and drops empty tokens as well as English
    stop-words (sklearn's ``ENGLISH_STOP_WORDS``).

    Parameters
    ----------
    text : str
        Raw passage or query text.

    Returns
    -------
    list[str]
        Cleaned tokens suitable for feeding to ``BM25Okapi``.
    """
    stripped = (tok.strip(string.punctuation) for tok in text.lower().split())
    # Same filter as the original append loop (non-empty and not a
    # stop-word), expressed as a comprehension over the generator above.
    return [tok for tok in stripped
            if tok and tok not in _stop_words.ENGLISH_STOP_WORDS]
|
| 83 |
|
| 84 |
# This function will search all wikipedia articles for passages that
|
| 85 |
# answer the query
|
| 86 |
def search(query):
|
| 87 |
+
print("Input query:", query)
|
| 88 |
+
total_qe = []
|
| 89 |
+
|
| 90 |
+
##### BM25 search (lexical search) #####
|
| 91 |
+
bm25_scores = bm25.get_scores(bm25_tokenizer(query))
|
| 92 |
+
top_n = np.argpartition(bm25_scores, -5)[-5:]
|
| 93 |
+
bm25_hits = [{'corpus_id': idx, 'score': bm25_scores[idx]} for idx in top_n]
|
| 94 |
+
bm25_hits = sorted(bm25_hits, key=lambda x: x['score'], reverse=True)
|
| 95 |
+
|
| 96 |
+
#print("Top-10 lexical search (BM25) hits")
|
| 97 |
+
qe_string = []
|
| 98 |
+
for hit in bm25_hits[0:1000]:
|
| 99 |
+
if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
|
| 100 |
+
qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
|
| 101 |
+
|
| 102 |
+
sub_string = []
|
| 103 |
+
for item in qe_string:
|
| 104 |
+
for sub_item in item.split(","):
|
| 105 |
+
sub_string.append(sub_item)
|
| 106 |
+
#print(sub_string)
|
| 107 |
+
total_qe.append(sub_string)
|
| 108 |
+
|
| 109 |
##### Semantic Search #####
|
| 110 |
# Encode the query using the bi-encoder and find potentially relevant passages
|
| 111 |
query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
|
| 112 |
+
query_embedding = query_embedding.cuda()
|
| 113 |
hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)
|
| 114 |
hits = hits[0] # Get the hits for the first query
|
| 115 |
|
|
|
|
| 122 |
for idx in range(len(cross_scores)):
|
| 123 |
hits[idx]['cross-score'] = cross_scores[idx]
|
| 124 |
|
| 125 |
+
# Output of top-10 hits from bi-encoder
|
| 126 |
+
#print("\n-------------------------\n")
|
| 127 |
+
#print("Top-N Bi-Encoder Retrieval hits")
|
| 128 |
+
hits = sorted(hits, key=lambda x: x['score'], reverse=True)
|
| 129 |
+
qe_string = []
|
| 130 |
+
for hit in hits[0:1000]:
|
| 131 |
+
if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
|
| 132 |
+
qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
|
| 133 |
+
#print(qe_string)
|
| 134 |
+
total_qe.append(qe_string)
|
| 135 |
+
|
| 136 |
+
# Output of top-10 hits from re-ranker
|
| 137 |
+
#print("\n-------------------------\n")
|
| 138 |
+
#print("Top-N Cross-Encoder Re-ranker hits")
|
| 139 |
hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
|
| 140 |
+
qe_string = []
|
|
|
|
|
|
|
| 141 |
for hit in hits[0:1000]:
|
| 142 |
+
if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
|
| 143 |
+
qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
|
| 144 |
+
#print(qe_string)
|
| 145 |
+
total_qe.append(qe_string)
|
| 146 |
+
|
| 147 |
+
# Total Results
|
| 148 |
+
total_qe.append(qe_string)
|
| 149 |
+
print("E-Commerce Query Expansion Results: \n")
|
| 150 |
+
print(total_qe)
|
| 151 |
+
|
| 152 |
|
| 153 |
st.write("## Results:")
|
| 154 |
if st.button('Generated Expansion'):
|