Kevin Hu
commited on
Commit
·
2be6429
1
Parent(s):
4a6bb1f
Make infinity able to cal embedding sim only. (#4644)
Browse files### What problem does this PR solve?
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- rag/utils/infinity_conn.py +18 -3
rag/utils/infinity_conn.py
CHANGED
|
@@ -273,9 +273,22 @@ class InfinityConnection(DocStoreConnection):
|
|
| 273 |
for essential_field in ["id"]:
|
| 274 |
if essential_field not in selectFields:
|
| 275 |
selectFields.append(essential_field)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
if matchExprs:
|
| 277 |
-
|
| 278 |
-
|
| 279 |
|
| 280 |
# Prepare expressions common to all tables
|
| 281 |
filter_cond = None
|
|
@@ -364,7 +377,9 @@ class InfinityConnection(DocStoreConnection):
|
|
| 364 |
self.connPool.release_conn(inf_conn)
|
| 365 |
res = concat_dataframes(df_list, selectFields)
|
| 366 |
if matchExprs:
|
| 367 |
-
res = res.sort(pl.col(
|
|
|
|
|
|
|
| 368 |
res = res.limit(limit)
|
| 369 |
logger.debug(f"INFINITY search final result: {str(res)}")
|
| 370 |
return res, total_hits_count
|
|
|
|
| 273 |
for essential_field in ["id"]:
|
| 274 |
if essential_field not in selectFields:
|
| 275 |
selectFields.append(essential_field)
|
| 276 |
+
score_func = ""
|
| 277 |
+
score_column = ""
|
| 278 |
+
for matchExpr in matchExprs:
|
| 279 |
+
if isinstance(matchExpr, MatchTextExpr):
|
| 280 |
+
score_func = "score()"
|
| 281 |
+
score_column = "SCORE"
|
| 282 |
+
break
|
| 283 |
+
if not score_func:
|
| 284 |
+
for matchExpr in matchExprs:
|
| 285 |
+
if isinstance(matchExpr, MatchDenseExpr):
|
| 286 |
+
score_func = "similarity()"
|
| 287 |
+
score_column = "SIMILARITY"
|
| 288 |
+
break
|
| 289 |
if matchExprs:
|
| 290 |
+
selectFields.append(score_func)
|
| 291 |
+
selectFields.append(PAGERANK_FLD)
|
| 292 |
|
| 293 |
# Prepare expressions common to all tables
|
| 294 |
filter_cond = None
|
|
|
|
| 377 |
self.connPool.release_conn(inf_conn)
|
| 378 |
res = concat_dataframes(df_list, selectFields)
|
| 379 |
if matchExprs:
|
| 380 |
+
res = res.sort(pl.col(score_column) + pl.col(PAGERANK_FLD), descending=True, maintain_order=True)
|
| 381 |
+
if score_column and score_column != "SCORE":
|
| 382 |
+
res = res.rename({score_column: "SCORE"})
|
| 383 |
res = res.limit(limit)
|
| 384 |
logger.debug(f"INFINITY search final result: {str(res)}")
|
| 385 |
return res, total_hits_count
|