Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -84,7 +84,10 @@ windowSize = 3
|
|
| 84 |
numOfKeywords = 3
|
| 85 |
|
| 86 |
custom_kw_extractor = yake.KeywordExtractor(lan=language, n=max_ngram_size, dedupLim=deduplication_threshold, dedupFunc=deduplication_algo, windowsSize=windowSize, top=numOfKeywords, features=None)
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
| 88 |
# We lower case our text and remove stop-words from indexing
|
| 89 |
def bm25_tokenizer(text):
|
| 90 |
tokenized_doc = []
|
|
@@ -123,6 +126,13 @@ def clean_string(input_string):
|
|
| 123 |
output_string.append(string_strip)
|
| 124 |
return output_string
|
| 125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
def generate_query_expansion_candidates(query):
|
| 127 |
print("Input query:", query)
|
| 128 |
expanded_query_set = {}
|
|
@@ -170,21 +180,16 @@ def generate_query_expansion_candidates(query):
|
|
| 170 |
# remove the query itself from candidates
|
| 171 |
if query in final_candidates:
|
| 172 |
del final_candidates[query]
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
# Total Results
|
| 175 |
st.write("E-Commerce Query Expansion Candidates: \n")
|
| 176 |
return final_candidates
|
| 177 |
|
| 178 |
-
with open('query_gms.json', 'r') as file:
|
| 179 |
-
query_gms_dict = json.load(file)
|
| 180 |
-
|
| 181 |
-
def add_gms_score_for_candidates(candidates, query_gms_dict):
|
| 182 |
-
for query_candidate in candidates:
|
| 183 |
-
value = candidates[query_candidate]
|
| 184 |
-
value['gms'] = query_gms_dict.get(query_candidate, 0)
|
| 185 |
-
candidates[query_candidate] = value
|
| 186 |
-
return candidates
|
| 187 |
-
|
| 188 |
def re_rank_candidates(query, candidates, method):
|
| 189 |
if method == 'bm25':
|
| 190 |
# Filter and sort by bm25_score
|
|
@@ -229,36 +234,19 @@ def re_rank_candidates(query, candidates, method):
|
|
| 229 |
return df
|
| 230 |
|
| 231 |
|
| 232 |
-
# def reranking():
|
| 233 |
-
# rerank_list = []
|
| 234 |
-
# reres = []
|
| 235 |
-
# rerank_list = search_nolog(query = user_query)
|
| 236 |
-
# unique_list = list(set(rerank_list))
|
| 237 |
-
# new_unique_list = [item for item in unique_list if item != user_query]
|
| 238 |
-
# Lowercasing_list = [item.lower() for item in new_unique_list]
|
| 239 |
-
|
| 240 |
-
# # st.write("E-Commerce Query Expansion Results: \n")
|
| 241 |
-
# st.write(Lowercasing_list[0:maxtags_sidebar])
|
| 242 |
-
|
| 243 |
-
# for i in Lowercasing_list[0:maxtags_sidebar]:
|
| 244 |
-
# reres.append(i)
|
| 245 |
-
# np.random.seed(7)
|
| 246 |
-
# np.random.shuffle(reres)
|
| 247 |
-
# test_res = {'front door': 0.5, 'family':0.3}
|
| 248 |
-
# st.write("Reranking Results: \n")
|
| 249 |
-
# st.write(test_res)
|
| 250 |
-
|
| 251 |
-
|
| 252 |
# st.write("## Raw Candidates:")
|
| 253 |
if st.button('Generated Expansion'):
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
|
| 258 |
if st.button('Rerank By Encoder'):
|
| 259 |
-
raw_candidates = generate_query_expansion_candidates(query = user_query)
|
| 260 |
-
candidates = add_gms_score_for_candidates(raw_candidates, query_gms_dict)
|
| 261 |
out_res = re_rank_candidates(user_query, candidates, method='encoder')
|
| 262 |
st.write("Reranking By Encoder: \n")
|
| 263 |
-
st.write(out_res[:maxtags_sidebar])
|
| 264 |
-
st.success(out_res)
|
|
|
|
| 84 |
numOfKeywords = 3
|
| 85 |
|
| 86 |
custom_kw_extractor = yake.KeywordExtractor(lan=language, n=max_ngram_size, dedupLim=deduplication_threshold, dedupFunc=deduplication_algo, windowsSize=windowSize, top=numOfKeywords, features=None)
|
| 87 |
+
# Load the per-query GMS lookup table once at start-up.
# The encoding is pinned to UTF-8 so that decoding does not depend on the
# host's locale default, which is what open() falls back to otherwise.
with open('query_gms.json', 'r', encoding='utf-8') as file:
    query_gms_dict = json.load(file)
|
| 90 |
+
|
| 91 |
# We lower case our text and remove stop-words from indexing
|
| 92 |
def bm25_tokenizer(text):
|
| 93 |
tokenized_doc = []
|
|
|
|
| 126 |
output_string.append(string_strip)
|
| 127 |
return output_string
|
| 128 |
|
| 129 |
+
def add_gms_score_for_candidates(candidates, query_gms_dict):
    """Attach a ``'gms'`` entry to every candidate record, in place.

    Args:
        candidates: Mapping of candidate query to a dict of that candidate's
            fields. Each inner dict is mutated.
        query_gms_dict: Mapping of query to its GMS value.

    Returns:
        The same ``candidates`` mapping that was passed in. Every record now
        has a ``'gms'`` key, set to 0 when the candidate is missing from
        ``query_gms_dict``.
    """
    for candidate_query, record in candidates.items():
        record['gms'] = query_gms_dict.get(candidate_query, 0)
    return candidates
|
| 135 |
+
|
| 136 |
def generate_query_expansion_candidates(query):
|
| 137 |
print("Input query:", query)
|
| 138 |
expanded_query_set = {}
|
|
|
|
| 180 |
# remove the query itself from candidates
|
| 181 |
if query in final_candidates:
|
| 182 |
del final_candidates[query]
|
| 183 |
+
|
| 184 |
+
# add gms column
|
| 185 |
+
for query_candidate in final_candidates:
|
| 186 |
+
value = final_candidates[query_candidate]
|
| 187 |
+
value['gms'] = query_gms_dict.get(query_candidate, 0)
|
| 188 |
+
final_candidates[query_candidate] = value
|
| 189 |
# Total Results
|
| 190 |
st.write("E-Commerce Query Expansion Candidates: \n")
|
| 191 |
return final_candidates
|
| 192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
def re_rank_candidates(query, candidates, method):
|
| 194 |
if method == 'bm25':
|
| 195 |
# Filter and sort by bm25_score
|
|
|
|
| 234 |
return df
|
| 235 |
|
| 236 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
# st.write("## Raw Candidates:")
|
| 238 |
if st.button('Generated Expansion'):
    # Build the expansion candidates for the query entered by the user.
    candidates = generate_query_expansion_candidates(query=user_query)
    # Flatten {candidate: {field: value}} into one row per candidate, keeping
    # the candidate text under 'query', so it can be shown as a table.
    data_dicts = [
        {'query': candidate, **fields}
        for candidate, fields in candidates.items()
    ]
    df = pd.DataFrame(data_dicts)
    st.write(df)
|
| 246 |
|
| 247 |
if st.button('Rerank By Encoder'):
    # Regenerate the candidates inside this branch. The only other visible
    # assignment of `candidates` sits under the 'Generated Expansion' button,
    # and st.button is True only during the rerun triggered by its own click,
    # so that branch has not executed when this one does. Without this line
    # the call below fails with NameError.
    # generate_query_expansion_candidates already fills in the 'gms' field,
    # so no separate GMS step is needed here.
    candidates = generate_query_expansion_candidates(query=user_query)
    out_res = re_rank_candidates(user_query, candidates, method='encoder')
    st.write("Reranking By Encoder: \n")
    # Show at most `maxtags_sidebar` re-ranked rows.
    st.write(out_res[:maxtags_sidebar])
|
|
|