Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -200,10 +200,87 @@ def search(query):
|
|
| 200 |
|
| 201 |
return show_out
|
| 202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
def reranking():
|
| 204 |
rerank_list = []
|
| 205 |
-
rerank_list =
|
| 206 |
-
|
| 207 |
random.shuffle(rerank_list[0:maxtags_sidebar])
|
| 208 |
for i in rerank_list[0:maxtags_sidebar]:
|
| 209 |
st.write(i)
|
|
|
|
| 200 |
|
| 201 |
return show_out
|
| 202 |
|
| 203 |
+
def _unique_passages(hits, limit=1000):
    """Return newline-stripped passage texts for the first *limit* hits, without repeats."""
    seen = set()
    unique = []
    for hit in hits[0:limit]:
        # Normalise once and use the same form for both the membership test and
        # the stored value; testing the " "-joined form against ""-joined
        # entries never matches for passages that contain a newline.
        text = passages[hit['corpus_id']].replace("\n", "")
        if text not in seen:
            seen.add(text)
            unique.append(text)
    return unique


def search_nolog(query):
    """Build a list of multi-word suggestion strings for *query*.

    Three candidate lists are gathered from the module-level ``passages``:
    the best BM25 matches (split on commas), the bi-encoder hits ordered by
    ``score``, and the same hits ordered by the cross-encoder ``cross-score``.
    Every candidate is then cleaned with two regex substitutions, strings
    longer than 20 characters are replaced by the keywords returned from
    ``custom_kw_extractor``, and only entries for which ``word_len`` reports
    more than one word are returned.

    Args:
        query: The search text, passed to ``bm25_tokenizer``,
            ``bi_encoder.encode`` and the cross-encoder input pairs.

    Returns:
        A list of strings (duplicates are possible across the candidate lists).
    """
    total_qe = []

    ##### BM25 search (lexical search) #####
    bm25_scores = bm25.get_scores(bm25_tokenizer(query))
    n_scores = len(bm25_scores)
    if n_scores >= 5:
        top_n = np.argpartition(bm25_scores, -5)[-5:]
    else:
        # argpartition raises when kth is out of bounds, so with fewer than
        # five scored passages simply keep all of them.
        top_n = range(n_scores)
    bm25_hits = [{'corpus_id': idx, 'score': bm25_scores[idx]} for idx in top_n]
    bm25_hits = sorted(bm25_hits, key=lambda x: x['score'], reverse=True)

    # BM25 passages are further broken into their comma-separated parts.
    sub_string = []
    for item in _unique_passages(bm25_hits):
        sub_string.extend(item.split(","))
    total_qe.append(sub_string)

    ##### Semantic Search #####
    # Encode the query using the bi-encoder and find potentially relevant passages
    query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)
    hits = hits[0]  # Get the hits for the first query

    ##### Re-Ranking #####
    # Now, score all retrieved passages with the cross_encoder
    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)

    # Attach each cross-encoder score to the hit it was computed for
    for hit, cross_score in zip(hits, cross_scores):
        hit['cross-score'] = cross_score

    # Candidates ordered by the bi-encoder score
    hits = sorted(hits, key=lambda x: x['score'], reverse=True)
    total_qe.append(_unique_passages(hits))

    # Candidates ordered by the re-ranker (cross-encoder) score
    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
    qe_string = _unique_passages(hits)
    total_qe.append(qe_string)

    # Total Results
    # NOTE(review): this appends the re-ranked list a second time, so those
    # candidates appear twice in the output. Kept as-is to preserve the
    # existing result list; confirm whether the duplication is intended.
    total_qe.append(qe_string)

    res = []
    for sub_list in total_qe:
        for i in sub_list:
            # Replace every character outside the 0-9 / A-z code-point ranges
            # with a space, then turn each pair of spaces into a newline.
            rs = re.sub("([^\u0030-\u0039\u0041-\u007a])", ' ', i)
            rs_final = re.sub("\x20\x20", "\n", rs)
            res.append(rs_final.strip())

    res_clean = []
    for out in res:
        if len(out) > 20:
            # Long strings are swapped for their extracted keywords; only the
            # first element of each returned item is kept.
            keywords = custom_kw_extractor.extract_keywords(out)
            res_clean.extend(key[0] for key in keywords)
        else:
            res_clean.append(out)

    # Keep only entries that word_len counts as more than one word.
    show_out = [i for i in res_clean if word_len(i) > 1]

    return show_out
|
| 279 |
+
|
| 280 |
def reranking():
    """Write a seeded-shuffled selection of suggestions for the current query.

    Calls ``search_nolog`` with the module-level ``user_query``, takes the
    first ``maxtags_sidebar`` results, shuffles that selection with a fixed
    seed and writes each entry with ``st.write``.
    """
    rerank_list = search_nolog(query=user_query)

    # Slice into a named list before shuffling: random.shuffle works in place,
    # so shuffling ``rerank_list[0:maxtags_sidebar]`` directly only reorders a
    # throwaway copy and the displayed order never changes.
    shown = rerank_list[0:maxtags_sidebar]
    random.seed(7)  # fixed seed: the same input yields the same order
    random.shuffle(shown)
    for i in shown:
        st.write(i)
|