yinlinfu committed
Commit 3c44dc3 · 1 Parent(s): 8147572

Update app.py

Files changed (1)
  1. app.py +135 -180
app.py CHANGED
@@ -106,200 +106,155 @@ def word_len(s):
  # This function will search all wikipedia articles for passages that
  # answer the query
- def search(query):
      print("Input query:", query)
-     total_qe = []

      ##### BM25 search (lexical search) #####
      bm25_scores = bm25.get_scores(bm25_tokenizer(query))
-     top_n = np.argpartition(bm25_scores, -5)[-5:]
-     bm25_hits = [{'corpus_id': idx, 'score': bm25_scores[idx]} for idx in top_n]
-     bm25_hits = sorted(bm25_hits, key=lambda x: x['score'], reverse=True)

-     #print("Top-10 lexical search (BM25) hits")
-     qe_string = []
-     for hit in bm25_hits[0:1000]:
-         if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
-             qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
-
-     sub_string = []
-     for item in qe_string:
-         for sub_item in item.split(","):
-             sub_string.append(sub_item)
-     #print(sub_string)
-     total_qe.append(sub_string)
-
      ##### Semantic Search #####
      # Encode the query using the bi-encoder and find potentially relevant passages
      query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
-     hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)
-     hits = hits[0]  # Get the hits for the first query

-     ##### Re-Ranking #####
-     # Now, score all retrieved passages with the cross_encoder
-     cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
      cross_scores = cross_encoder.predict(cross_inp)
-
-     # Sort results by the cross-encoder scores
      for idx in range(len(cross_scores)):
-         hits[idx]['cross-score'] = cross_scores[idx]
-
-     # Output of top-10 hits from bi-encoder
-     #print("\n-------------------------\n")
-     #print("Top-N Bi-Encoder Retrieval hits")
-     hits = sorted(hits, key=lambda x: x['score'], reverse=True)
-     qe_string = []
-     for hit in hits[0:1000]:
-         if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
-             qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
-     #print(qe_string)
-     total_qe.append(qe_string)
-
-     # Output of top-10 hits from re-ranker
-     #print("\n-------------------------\n")
-     #print("Top-N Cross-Encoder Re-ranker hits")
-     hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
-     qe_string = []
-     for hit in hits[0:1000]:
-         if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
-             qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
-     #print(qe_string)
-     total_qe.append(qe_string)
-
      # Total Results
-     total_qe.append(qe_string)
      st.write("E-Commerce Query Expansion Results: \n")

-     res = []
-     for sub_list in total_qe:
-         for i in sub_list:
-             rs = re.sub("([^\u0030-\u0039\u0041-\u007a])", ' ', i)
-             rs_final = re.sub("\x20\x20", "\n", rs)
-             #st.write(rs_final.strip())
-             res.append(rs_final.strip())
-
-     res_clean = []
-     for out in res:
-         if len(out) > 20:
-             keywords = custom_kw_extractor.extract_keywords(out)
-             for key in keywords:
-                 res_clean.append(key[0])
-         else:
-             res_clean.append(out)
-
-     show_out = []
-     for i in res_clean:
-         num = word_len(i)
-         if num > 1:
-             show_out.append(i)
-     unique_list = list(set(show_out))
-     new_unique_list = [item for item in unique_list if item != query]
-     Lowercasing_list = [item.lower() for item in new_unique_list]
-     st.write(Lowercasing_list[0:maxtags_sidebar])
-
-     return Lowercasing_list
-
- def search_nolog(query):
-     total_qe = []
-     ##### BM25 search (lexical search) #####
-     bm25_scores = bm25.get_scores(bm25_tokenizer(query))
-     top_n = np.argpartition(bm25_scores, -5)[-5:]
-     bm25_hits = [{'corpus_id': idx, 'score': bm25_scores[idx]} for idx in top_n]
-     bm25_hits = sorted(bm25_hits, key=lambda x: x['score'], reverse=True)

-     qe_string = []
-     for hit in bm25_hits[0:1000]:
-         if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
-             qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
-
-     sub_string = []
-     for item in qe_string:
-         for sub_item in item.split(","):
-             sub_string.append(sub_item)
-     total_qe.append(sub_string)
-
-     ##### Semantic Search #####
-     # Encode the query using the bi-encoder and find potentially relevant passages
-     query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
-     hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)
-     hits = hits[0]  # Get the hits for the first query
-
-     ##### Re-Ranking #####
-     # Now, score all retrieved passages with the cross_encoder
-     cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
-     cross_scores = cross_encoder.predict(cross_inp)
-
-     # Sort results by the cross-encoder scores
-     for idx in range(len(cross_scores)):
-         hits[idx]['cross-score'] = cross_scores[idx]
-
-     # Output of top-10 hits from bi-encoder
-     hits = sorted(hits, key=lambda x: x['score'], reverse=True)
-     qe_string = []
-     for hit in hits[0:1000]:
-         if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
-             qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
-     total_qe.append(qe_string)
-
-     # Output of top-10 hits from re-ranker
-     hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
-     qe_string = []
-     for hit in hits[0:1000]:
-         if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
-             qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
-     total_qe.append(qe_string)
-
-     # Total Results
-     total_qe.append(qe_string)
-
-     res = []
-     for sub_list in total_qe:
-         for i in sub_list:
-             rs = re.sub("([^\u0030-\u0039\u0041-\u007a])", ' ', i)
-             rs_final = re.sub("\x20\x20", "\n", rs)
-             res.append(rs_final.strip())
-
-     res_clean = []
-     for out in res:
-         if len(out) > 20:
-             keywords = custom_kw_extractor.extract_keywords(out)
-             for key in keywords:
-                 res_clean.append(key[0])
-         else:
-             res_clean.append(out)
-
-     show_out = []
-     for i in res_clean:
-         num = word_len(i)
-         if num > 1:
-             show_out.append(i)
-
-     return show_out
-
- def reranking():
-     rerank_list = []
-     reres = []
-     rerank_list = search_nolog(query = user_query)
-     unique_list = list(set(rerank_list))
-     new_unique_list = [item for item in unique_list if item != user_query]
-     Lowercasing_list = [item.lower() for item in new_unique_list]
-
-     # st.write("E-Commerce Query Expansion Results: \n")
-     st.write(Lowercasing_list[0:maxtags_sidebar])
-
-     for i in Lowercasing_list[0:maxtags_sidebar]:
-         reres.append(i)
-     np.random.seed(7)
-     np.random.shuffle(reres)
-     test_res = {'front door': 0.5, 'family':0.3}
      st.write("Reranking Results: \n")
-     st.write(test_res)
-
- st.write("## Results:")

  if st.button('Generated Expansion'):
-     out_res = search(query = user_query)
-     #st.success(out_res)
-
- if st.button('Rerank'):
-     out_res = reranking()
-     #st.success(out_res)
 
 
 
  # This function will search all wikipedia articles for passages that
  # answer the query
+ DEFAULT_SCORE = -100.0
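+ # DEFAULT_SCORE is a sentinel for scores that a retrieval source did not produce.
+
+ # clean_string(): strip non-alphanumerics, lowercase, and for strings longer than 20 characters
+ # keep only the multi-word keyphrases returned by custom_kw_extractor; shorter strings pass through as-is.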
+ def clean_string(input_string):
+     string_sub1 = re.sub("([^\u0030-\u0039\u0041-\u007a])", ' ', input_string)
+     string_sub2 = re.sub("\x20\x20", "\n", string_sub1)
+     string_strip = string_sub2.strip().lower()
+     output_string = []
+     if len(string_strip) > 20:
+         keywords = custom_kw_extractor.extract_keywords(string_strip)
+         for tokens in keywords:
+             string_clean = tokens[0]
+             if word_len(string_clean) > 1:
+                 output_string.append(string_clean)
+     else:
+         output_string.append(string_strip)
+     return output_string
+
+ def generate_query_expansion_candidates(query):
      print("Input query:", query)
+     expanded_query_set = {}

      ##### BM25 search (lexical search) #####
      bm25_scores = bm25.get_scores(bm25_tokenizer(query))
+     # finds the indices of the top n scores
+     top_n_indices = np.argpartition(bm25_scores, -5)[-5:]
+     bm25_hits = [{'corpus_id': idx, 'bm25_score': bm25_scores[idx]} for idx in top_n_indices]
+     # bm25_hits = sorted(bm25_hits, key=lambda x: x['score'], reverse=True)
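+     # The BM25 hits are left unsorted here; ordering happens later in re_rank_candidates.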
+
      ##### Semantic Search #####
      # Encode the query using the bi-encoder and find potentially relevant passages
      query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
+     query_embedding = query_embedding.cuda()
+     # Get the hits for the first query
+     encoder_hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
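+     # Note: .cuda() assumes a GPU is available; semantic_search scores the query embedding
+     # against the precomputed corpus_embeddings and returns the top_k closest passages.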

+     # For all retrieved passages, add the cross_encoder scores
+     cross_inp = [[query, passages[hit['corpus_id']]] for hit in encoder_hits]
      cross_scores = cross_encoder.predict(cross_inp)
      for idx in range(len(cross_scores)):
+         encoder_hits[idx]['cross_score'] = cross_scores[idx]
+
+     candidates = {}
+     for hit in bm25_hits:
+         corpus_id = hit['corpus_id']
+         if corpus_id not in candidates:
+             candidates[corpus_id] = {'bm25_score': hit['bm25_score'], 'bi_score': DEFAULT_SCORE, 'cross_score': DEFAULT_SCORE}
+     for hit in encoder_hits:
+         corpus_id = hit['corpus_id']
+         if corpus_id not in candidates:
+             candidates[corpus_id] = {'bm25_score': DEFAULT_SCORE, 'bi_score': hit['score'], 'cross_score': hit['cross_score']}
+         else:
+             bm25_score = candidates[corpus_id]['bm25_score']
+             candidates[corpus_id].update({'bm25_score': bm25_score, 'bi_score': hit['score'], 'cross_score': hit['cross_score']})
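+     # candidates now holds one entry per corpus_id, merging BM25 and encoder hits;
+     # DEFAULT_SCORE fills in whichever score a retrieval source did not provide.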
+
+     final_candidates = {}
+     for key, value in candidates.items():
+         input_string = passages[key].replace("\n", "")
+         string_set = set(clean_string(input_string))
+         for item in string_set:
+             final_candidates[item] = value
+     # remove the query itself from candidates
+     if query in final_candidates:
+         del final_candidates[query]
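+     # final_candidates maps each cleaned keyphrase to the score dict of the passage it came from;
+     # a phrase produced by several passages keeps the scores of the last one seen.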
+
      # Total Results
      st.write("E-Commerce Query Expansion Results: \n")
+     st.write(list(final_candidates.keys())[0:maxtags_sidebar])
+     return final_candidates

+ with open('query_gms.json', 'r') as file:
+     query_gms_dict = json.load(file)
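+ # query_gms_dict maps candidate query strings to their GMS value; candidates missing
+ # from the file default to 0 in add_gms_score_for_candidates.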
 
+ def add_gms_score_for_candidates(candidates, query_gms_dict):
+     for query_candidate in candidates:
+         value = candidates[query_candidate]
+         value['gms'] = query_gms_dict.get(query_candidate, 0)
+         candidates[query_candidate] = value
+     return candidates
+
+ def re_rank_candidates(query, candidates, method):
+     if method == 'bm25':
+         # Filter and sort by bm25_score
+         filtered_sorted_result = sorted(
+             [(k, v) for k, v in candidates.items() if v['bm25_score'] > DEFAULT_SCORE],
+             key=lambda x: x[1]['bm25_score'],
+             reverse=True
+         )
+     elif method == 'bi_encoder':
+         # Filter and sort by bi_score
+         filtered_sorted_result = sorted(
+             [(k, v) for k, v in candidates.items() if v['bi_score'] > DEFAULT_SCORE],
+             key=lambda x: x[1]['bi_score'],
+             reverse=True
+         )
+     elif method == 'cross_encoder':
+         # Filter and sort by cross_score
+         filtered_sorted_result = sorted(
+             [(k, v) for k, v in candidates.items() if v['cross_score'] > DEFAULT_SCORE],
+             key=lambda x: x[1]['cross_score'],
+             reverse=True
+         )
+     elif method == 'encoder':
+         # Filter and sort by cross_score + bi_score
+         filtered_sorted_result = sorted(
+             [(k, v) for k, v in candidates.items() if (v['cross_score'] > DEFAULT_SCORE) & (v['bi_score'] > DEFAULT_SCORE)],
+             key=lambda x: x[1]['cross_score'] + x[1]['bi_score'],
+             reverse=True
+         )
+     elif method == 'gms':
+         # first filter and sort by cross_score + bi_score
+         filtered_sorted_by_encoder = sorted(
+             [(k, v) for k, v in candidates.items() if (v['cross_score'] > DEFAULT_SCORE) & (v['bi_score'] > DEFAULT_SCORE)],
+             key=lambda x: x[1]['cross_score'] + x[1]['bi_score'],
+             reverse=True
+         )
+         # then re-rank the encoder-filtered candidates by gms
+         filtered_sorted_result = sorted(filtered_sorted_by_encoder, key=lambda x: x[1]['gms'], reverse=True)
      st.write("Reranking Results: \n")
+     st.write(filtered_sorted_result)
+     return filtered_sorted_result
+
+ # def reranking():
+ #     rerank_list = []
+ #     reres = []
+ #     rerank_list = search_nolog(query = user_query)
+ #     unique_list = list(set(rerank_list))
+ #     new_unique_list = [item for item in unique_list if item != user_query]
+ #     Lowercasing_list = [item.lower() for item in new_unique_list]
+
+ #     # st.write("E-Commerce Query Expansion Results: \n")
+ #     st.write(Lowercasing_list[0:maxtags_sidebar])
+
+ #     for i in Lowercasing_list[0:maxtags_sidebar]:
+ #         reres.append(i)
+ #     np.random.seed(7)
+ #     np.random.shuffle(reres)
+ #     test_res = {'front door': 0.5, 'family':0.3}
+ #     st.write("Reranking Results: \n")
+ #     st.write(test_res)
+
+ raw_candidates = generate_query_expansion_candidates(query = user_query)
+ candidates = add_gms_score_for_candidates(raw_candidates, query_gms_dict)
+
+ st.write("## Raw Candidates:")
  if st.button('Generated Expansion'):
+     out_res = raw_candidates
+     st.success(out_res)
+
+ if st.button('Rerank By Encoder'):
+     out_res = re_rank_candidates(user_query, candidates, method='encoder')
+     st.write("Reranking By Encoder: \n")
+     st.write(out_res)
+     st.success(out_res)
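
Note on the data dependency: the updated app.py loads query_gms.json at import time and looks candidates up with query_gms_dict.get(candidate, 0). A minimal sketch of the assumed file shape (a flat JSON object mapping candidate query strings to numeric GMS values; the example queries and numbers below are illustrative, not from the commit):

import json

# Hypothetical GMS table; real values come from the app's own data.
example_gms = {"front door": 1250.0, "storm door": 830.5}
with open("query_gms.json", "w") as f:
    json.dump(example_gms, f)

with open("query_gms.json", "r") as f:
    query_gms_dict = json.load(f)

print(query_gms_dict.get("front door", 0))  # 1250.0
print(query_gms_dict.get("side door", 0))   # falls back to 0 for unseen candidates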