Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -212,13 +212,6 @@ def re_rank_candidates(query, candidates, method):
|
|
| 212 |
key=lambda x: x[1]['cross_score'],
|
| 213 |
reverse=True
|
| 214 |
)
|
| 215 |
-
elif method == 'encoder':
|
| 216 |
-
# Filter and sort by cross_score + bi_score
|
| 217 |
-
filtered_sorted_result = sorted(
|
| 218 |
-
[(k, v) for k, v in candidates.items() if (v['cross_score'] > DEFAULT_SCORE) & (v['bi_score'] > DEFAULT_SCORE)],
|
| 219 |
-
key=lambda x: x[1]['cross_score'] + x[1]['bi_score'],
|
| 220 |
-
reverse=True
|
| 221 |
-
)
|
| 222 |
elif method == 'gms':
|
| 223 |
filtered_sorted_by_encoder = sorted(
|
| 224 |
[(k, v) for k, v in candidates.items() if (v['cross_score'] > DEFAULT_SCORE) & (v['bi_score'] > DEFAULT_SCORE)],
|
|
@@ -228,6 +221,14 @@ def re_rank_candidates(query, candidates, method):
|
|
| 228 |
# first sort by cross_score + bi_score
|
| 229 |
filtered_sorted_result = sorted(filtered_sorted_by_encoder, key=lambda x: x[1]['gms'], reverse=True
|
| 230 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
data_dicts = [{'query': item[0], **item[1]} for item in filtered_sorted_result]
|
| 232 |
# Convert the list of dictionaries into a DataFrame
|
| 233 |
df = pd.DataFrame(data_dicts)
|
|
@@ -237,7 +238,7 @@ def re_rank_candidates(query, candidates, method):
|
|
| 237 |
# st.write("## Raw Candidates:")
|
| 238 |
if st.button('Generated Expansion'):
|
| 239 |
candidates = generate_query_expansion_candidates(query = user_query)
|
| 240 |
-
df = re_rank_candidates(user_query, candidates, method='
|
| 241 |
result = list(df['query'][:maxtags_sidebar])
|
| 242 |
st.write(result)
|
| 243 |
## convert into dataframe
|
|
|
|
| 212 |
key=lambda x: x[1]['cross_score'],
|
| 213 |
reverse=True
|
| 214 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
elif method == 'gms':
|
| 216 |
filtered_sorted_by_encoder = sorted(
|
| 217 |
[(k, v) for k, v in candidates.items() if (v['cross_score'] > DEFAULT_SCORE) & (v['bi_score'] > DEFAULT_SCORE)],
|
|
|
|
| 221 |
# first sort by cross_score + bi_score
|
| 222 |
filtered_sorted_result = sorted(filtered_sorted_by_encoder, key=lambda x: x[1]['gms'], reverse=True
|
| 223 |
)
|
| 224 |
+
else:
|
| 225 |
+
# use default method cross_score + bi_score
|
| 226 |
+
# Filter and sort by cross_score + bi_score
|
| 227 |
+
filtered_sorted_result = sorted(
|
| 228 |
+
[(k, v) for k, v in candidates.items() if (v['cross_score'] > DEFAULT_SCORE) & (v['bi_score'] > DEFAULT_SCORE)],
|
| 229 |
+
key=lambda x: x[1]['cross_score'] + x[1]['bi_score'],
|
| 230 |
+
reverse=True
|
| 231 |
+
)
|
| 232 |
data_dicts = [{'query': item[0], **item[1]} for item in filtered_sorted_result]
|
| 233 |
# Convert the list of dictionaries into a DataFrame
|
| 234 |
df = pd.DataFrame(data_dicts)
|
|
|
|
| 238 |
# st.write("## Raw Candidates:")
|
| 239 |
if st.button('Generated Expansion'):
|
| 240 |
candidates = generate_query_expansion_candidates(query = user_query)
|
| 241 |
+
df = re_rank_candidates(user_query, candidates, method='cross_encoder')
|
| 242 |
result = list(df['query'][:maxtags_sidebar])
|
| 243 |
st.write(result)
|
| 244 |
## convert into dataframe
|