Wang Baoling
KevinHuSh
committed on
Commit
·
04d3b7e
1
Parent(s):
ac8a9f7
Fix: bug #991 (#1013)
Browse files

### What problem does this PR solve?
issue #991
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
---------
Co-authored-by: KevinHuSh <kevinhu.sh@gmail.com>
- api/db/init_data.py +1 -1
- rag/llm/rerank_model.py +14 -0
api/db/init_data.py
CHANGED
|
@@ -386,7 +386,7 @@ def init_llm_factory():
|
|
| 386 |
"fid": factory_infos[7]["name"],
|
| 387 |
"llm_name": "maidalun1020/bce-reranker-base_v1",
|
| 388 |
"tags": "RE-RANK, 8K",
|
| 389 |
-
"max_tokens":
|
| 390 |
"model_type": LLMType.RERANK.value
|
| 391 |
},
|
| 392 |
# ------------------------ DeepSeek -----------------------
|
|
|
|
| 386 |
"fid": factory_infos[7]["name"],
|
| 387 |
"llm_name": "maidalun1020/bce-reranker-base_v1",
|
| 388 |
"tags": "RE-RANK, 8K",
|
| 389 |
+
"max_tokens": 512,
|
| 390 |
"model_type": LLMType.RERANK.value
|
| 391 |
},
|
| 392 |
# ------------------------ DeepSeek -----------------------
|
rag/llm/rerank_model.py
CHANGED
|
@@ -113,4 +113,18 @@ class YoudaoRerank(DefaultRerank):
|
|
| 113 |
YoudaoRerank._model = RerankerModel(
|
| 114 |
model_name_or_path=model_name.replace(
|
| 115 |
"maidalun1020", "InfiniFlow"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
|
|
|
| 113 |
YoudaoRerank._model = RerankerModel(
|
| 114 |
model_name_or_path=model_name.replace(
|
| 115 |
"maidalun1020", "InfiniFlow"))
|
| 116 |
+
|
| 117 |
+
def similarity(self, query: str, texts: list):
|
| 118 |
+
pairs = [(query,truncate(t, self._model.max_length)) for t in texts]
|
| 119 |
+
token_count = 0
|
| 120 |
+
for _, t in pairs:
|
| 121 |
+
token_count += num_tokens_from_string(t)
|
| 122 |
+
batch_size = 32
|
| 123 |
+
res = []
|
| 124 |
+
for i in range(0, len(pairs), batch_size):
|
| 125 |
+
scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length)
|
| 126 |
+
scores = sigmoid(np.array(scores)).tolist()
|
| 127 |
+
res.extend(scores)
|
| 128 |
+
return np.array(res), token_count
|
| 129 |
+
|
| 130 |
|