shizzgar
Kevin Hu
commited on
Commit
·
9640d9a
1
Parent(s):
98a13e9
Added LocalAI support for rerank models (#3446)
Browse files### What problem does this PR solve?
Hi there!
LocalAI added support of rerank models
https://localai.io/features/reranker/
I've implemented LocalAIRerank class (typically copied it from
OpenAI_APIRerank class).
Also, LocalAI model response with 500 error code if len of "documents"
is less than 2 in similarity check.
So I've added the second "document" on RERANK model connection check in
`api/apps/llm_app.py`.
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
- api/apps/llm_app.py +1 -1
- rag/llm/__init__.py +1 -0
- rag/llm/rerank_model.py +37 -2
api/apps/llm_app.py
CHANGED
|
@@ -238,7 +238,7 @@ def add_llm():
|
|
| 238 |
base_url=llm["api_base"]
|
| 239 |
)
|
| 240 |
try:
|
| 241 |
-
arr, tc = mdl.similarity("Hello~ Ragflower!", ["Hi, there!"])
|
| 242 |
if len(arr) == 0 or tc == 0:
|
| 243 |
raise Exception("Not known.")
|
| 244 |
except Exception as e:
|
|
|
|
| 238 |
base_url=llm["api_base"]
|
| 239 |
)
|
| 240 |
try:
|
| 241 |
+
arr, tc = mdl.similarity("Hello~ Ragflower!", ["Hi, there!", "Ohh, my friend!"])
|
| 242 |
if len(arr) == 0 or tc == 0:
|
| 243 |
raise Exception("Not known.")
|
| 244 |
except Exception as e:
|
rag/llm/__init__.py
CHANGED
|
@@ -110,6 +110,7 @@ ChatModel = {
|
|
| 110 |
}
|
| 111 |
|
| 112 |
RerankModel = {
|
|
|
|
| 113 |
"BAAI": DefaultRerank,
|
| 114 |
"Jina": JinaRerank,
|
| 115 |
"Youdao": YoudaoRerank,
|
|
|
|
| 110 |
}
|
| 111 |
|
| 112 |
RerankModel = {
|
| 113 |
+
"LocalAI":LocalAIRerank,
|
| 114 |
"BAAI": DefaultRerank,
|
| 115 |
"Jina": JinaRerank,
|
| 116 |
"Youdao": YoudaoRerank,
|
rag/llm/rerank_model.py
CHANGED
|
@@ -185,11 +185,46 @@ class XInferenceRerank(Base):
|
|
| 185 |
|
| 186 |
class LocalAIRerank(Base):
|
| 187 |
def __init__(self, key, model_name, base_url):
|
| 188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
|
| 190 |
def similarity(self, query: str, texts: list):
|
| 191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
class NvidiaRerank(Base):
|
| 195 |
def __init__(
|
|
|
|
| 185 |
|
| 186 |
class LocalAIRerank(Base):
|
| 187 |
def __init__(self, key, model_name, base_url):
|
| 188 |
+
if base_url.find("/rerank") == -1:
|
| 189 |
+
self.base_url = urljoin(base_url, "/rerank")
|
| 190 |
+
else:
|
| 191 |
+
self.base_url = base_url
|
| 192 |
+
self.headers = {
|
| 193 |
+
"Content-Type": "application/json",
|
| 194 |
+
"Authorization": f"Bearer {key}"
|
| 195 |
+
}
|
| 196 |
+
self.model_name = model_name.replace("___LocalAI","")
|
| 197 |
|
| 198 |
def similarity(self, query: str, texts: list):
|
| 199 |
+
# noway to config Ragflow , use fix setting
|
| 200 |
+
texts = [truncate(t, 500) for t in texts]
|
| 201 |
+
data = {
|
| 202 |
+
"model": self.model_name,
|
| 203 |
+
"query": query,
|
| 204 |
+
"documents": texts,
|
| 205 |
+
"top_n": len(texts),
|
| 206 |
+
}
|
| 207 |
+
token_count = 0
|
| 208 |
+
for t in texts:
|
| 209 |
+
token_count += num_tokens_from_string(t)
|
| 210 |
+
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
| 211 |
+
rank = np.zeros(len(texts), dtype=float)
|
| 212 |
+
if 'results' not in res:
|
| 213 |
+
raise ValueError("response not contains results\n" + str(res))
|
| 214 |
+
for d in res["results"]:
|
| 215 |
+
rank[d["index"]] = d["relevance_score"]
|
| 216 |
+
|
| 217 |
+
# Normalize the rank values to the range 0 to 1
|
| 218 |
+
min_rank = np.min(rank)
|
| 219 |
+
max_rank = np.max(rank)
|
| 220 |
|
| 221 |
+
# Avoid division by zero if all ranks are identical
|
| 222 |
+
if max_rank - min_rank != 0:
|
| 223 |
+
rank = (rank - min_rank) / (max_rank - min_rank)
|
| 224 |
+
else:
|
| 225 |
+
rank = np.zeros_like(rank)
|
| 226 |
+
|
| 227 |
+
return rank, token_count
|
| 228 |
|
| 229 |
class NvidiaRerank(Base):
|
| 230 |
def __init__(
|