KevinHuSh
commited on
Commit
·
bfb0635
1
Parent(s):
05dad97
fix mem leak for local reranker (#1295)
Browse files### What problem does this PR solve?
#1288
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- rag/llm/rerank_model.py +15 -9
rag/llm/rerank_model.py
CHANGED
|
@@ -39,6 +39,7 @@ class Base(ABC):
|
|
| 39 |
class DefaultRerank(Base):
|
| 40 |
_model = None
|
| 41 |
_model_lock = threading.Lock()
|
|
|
|
| 42 |
def __init__(self, key, model_name, **kwargs):
|
| 43 |
"""
|
| 44 |
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
|
@@ -102,19 +103,24 @@ class JinaRerank(Base):
|
|
| 102 |
|
| 103 |
class YoudaoRerank(DefaultRerank):
|
| 104 |
_model = None
|
|
|
|
| 105 |
|
| 106 |
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
|
| 107 |
from BCEmbedding import RerankerModel
|
| 108 |
if not YoudaoRerank._model:
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
def similarity(self, query: str, texts: list):
|
| 120 |
pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
|
|
|
|
| 39 |
class DefaultRerank(Base):
|
| 40 |
_model = None
|
| 41 |
_model_lock = threading.Lock()
|
| 42 |
+
|
| 43 |
def __init__(self, key, model_name, **kwargs):
|
| 44 |
"""
|
| 45 |
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
|
|
|
| 103 |
|
| 104 |
class YoudaoRerank(DefaultRerank):
|
| 105 |
_model = None
|
| 106 |
+
_model_lock = threading.Lock()
|
| 107 |
|
| 108 |
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
|
| 109 |
from BCEmbedding import RerankerModel
|
| 110 |
if not YoudaoRerank._model:
|
| 111 |
+
with YoudaoRerank._model_lock:
|
| 112 |
+
if not YoudaoRerank._model:
|
| 113 |
+
try:
|
| 114 |
+
print("LOADING BCE...")
|
| 115 |
+
YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
|
| 116 |
+
get_home_cache_dir(),
|
| 117 |
+
re.sub(r"^[a-zA-Z]+/", "", model_name)))
|
| 118 |
+
except Exception as e:
|
| 119 |
+
YoudaoRerank._model = RerankerModel(
|
| 120 |
+
model_name_or_path=model_name.replace(
|
| 121 |
+
"maidalun1020", "InfiniFlow"))
|
| 122 |
+
|
| 123 |
+
self._model = YoudaoRerank._model
|
| 124 |
|
| 125 |
def similarity(self, query: str, texts: list):
|
| 126 |
pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
|