yinlinfu commited on
Commit
e25dc5a
·
1 Parent(s): 2109781

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -212,13 +212,6 @@ def re_rank_candidates(query, candidates, method):
212
  key=lambda x: x[1]['cross_score'],
213
  reverse=True
214
  )
215
- elif method == 'encoder':
216
- # Filter and sort by cross_score + bi_score
217
- filtered_sorted_result = sorted(
218
- [(k, v) for k, v in candidates.items() if (v['cross_score'] > DEFAULT_SCORE) & (v['bi_score'] > DEFAULT_SCORE)],
219
- key=lambda x: x[1]['cross_score'] + x[1]['bi_score'],
220
- reverse=True
221
- )
222
  elif method == 'gms':
223
  filtered_sorted_by_encoder = sorted(
224
  [(k, v) for k, v in candidates.items() if (v['cross_score'] > DEFAULT_SCORE) & (v['bi_score'] > DEFAULT_SCORE)],
@@ -228,6 +221,14 @@ def re_rank_candidates(query, candidates, method):
228
  # first sort by cross_score + bi_score
229
  filtered_sorted_result = sorted(filtered_sorted_by_encoder, key=lambda x: x[1]['gms'], reverse=True
230
  )
 
 
 
 
 
 
 
 
231
  data_dicts = [{'query': item[0], **item[1]} for item in filtered_sorted_result]
232
  # Convert the list of dictionaries into a DataFrame
233
  df = pd.DataFrame(data_dicts)
@@ -237,7 +238,7 @@ def re_rank_candidates(query, candidates, method):
237
  # st.write("## Raw Candidates:")
238
  if st.button('Generated Expansion'):
239
  candidates = generate_query_expansion_candidates(query = user_query)
240
- df = re_rank_candidates(user_query, candidates, method='cross_score')
241
  result = list(df['query'][:maxtags_sidebar])
242
  st.write(result)
243
  ## convert into dataframe
 
212
  key=lambda x: x[1]['cross_score'],
213
  reverse=True
214
  )
 
 
 
 
 
 
 
215
  elif method == 'gms':
216
  filtered_sorted_by_encoder = sorted(
217
  [(k, v) for k, v in candidates.items() if (v['cross_score'] > DEFAULT_SCORE) & (v['bi_score'] > DEFAULT_SCORE)],
 
221
  # first sort by cross_score + bi_score
222
  filtered_sorted_result = sorted(filtered_sorted_by_encoder, key=lambda x: x[1]['gms'], reverse=True
223
  )
224
+ else:
225
+ # use default method cross_score + bi_score
226
+ # Filter and sort by cross_score + bi_score
227
+ filtered_sorted_result = sorted(
228
+ [(k, v) for k, v in candidates.items() if (v['cross_score'] > DEFAULT_SCORE) & (v['bi_score'] > DEFAULT_SCORE)],
229
+ key=lambda x: x[1]['cross_score'] + x[1]['bi_score'],
230
+ reverse=True
231
+ )
232
  data_dicts = [{'query': item[0], **item[1]} for item in filtered_sorted_result]
233
  # Convert the list of dictionaries into a DataFrame
234
  df = pd.DataFrame(data_dicts)
 
238
  # st.write("## Raw Candidates:")
239
  if st.button('Generated Expansion'):
240
  candidates = generate_query_expansion_candidates(query = user_query)
241
+ df = re_rank_candidates(user_query, candidates, method='cross_encoder')
242
  result = list(df['query'][:maxtags_sidebar])
243
  st.write(result)
244
  ## convert into dataframe