fix SILICONFLOW embedding error (#2363)
Browse files
### What problem does this PR solve?
#2335 fix SILICONFLOW embedding error
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
---------
Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
- rag/llm/embedding_model.py +33 -4
rag/llm/embedding_model.py
CHANGED
|
@@ -577,11 +577,40 @@ class UpstageEmbed(OpenAIEmbed):
|
|
| 577 |
super().__init__(key, model_name, base_url)
|
| 578 |
|
| 579 |
|
| 580 |
-
class SILICONFLOWEmbed(
|
| 581 |
-
def __init__(
|
|
|
|
|
|
|
| 582 |
if not base_url:
|
| 583 |
-
base_url = "https://api.siliconflow.cn/v1"
|
| 584 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 585 |
|
| 586 |
|
| 587 |
class ReplicateEmbed(Base):
|
|
|
|
| 577 |
super().__init__(key, model_name, base_url)
|
| 578 |
|
| 579 |
|
| 580 |
+
class SILICONFLOWEmbed(Base):
    """Embedding client that POSTs texts to a SiliconFlow embeddings endpoint.

    Requests are sent as JSON with a ``Bearer`` authorization header built
    from ``key``. Responses are expected to carry a ``data`` list whose items
    have an ``embedding`` field, plus ``usage.total_tokens``.
    """

    def __init__(
        self, key, model_name, base_url="https://api.siliconflow.cn/v1/embeddings"
    ):
        """Store the endpoint, model name and request headers.

        Args:
            key: API key placed in the ``authorization`` header.
            model_name: Value sent as ``model`` in every request payload.
            base_url: Full URL of the embeddings endpoint. A falsy value
                (``None`` or ``""``) falls back to the default endpoint.
        """
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1/embeddings"
        self.headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {key}",
        }
        self.base_url = base_url
        self.model_name = model_name

    def _request_embeddings(self, inputs):
        """POST ``inputs`` to the endpoint and return the decoded JSON body."""
        payload = {
            "model": self.model_name,
            "input": inputs,
            "encoding_format": "float",
        }
        return requests.post(self.base_url, json=payload, headers=self.headers).json()

    def encode(self, texts: list, batch_size=32):
        """Embed ``texts`` in chunks of at most ``batch_size`` per request.

        Args:
            texts: List of strings to embed.
            batch_size: Maximum number of texts sent in a single request.
                Values below 1 are treated as 1.

        Returns:
            A tuple ``(embeddings, total_tokens)`` where ``embeddings`` is a
            NumPy array built from the ``embedding`` fields of all responses,
            in input order, and ``total_tokens`` is the sum of
            ``usage.total_tokens`` over all requests. An empty ``texts`` list
            yields an empty array and ``0`` without contacting the endpoint.

        Raises:
            KeyError: If a response lacks ``data`` or ``usage`` (for example
                when the service answers with an error document).
        """
        # Previously batch_size was accepted but ignored, so every text went
        # out in one request. NOTE(review): presumably the provider caps the
        # number of inputs per request -- confirm against the API reference.
        step = max(1, batch_size)
        vectors = []
        total_tokens = 0
        for start in range(0, len(texts), step):
            res = self._request_embeddings(texts[start : start + step])
            vectors.extend(d["embedding"] for d in res["data"])
            total_tokens += res["usage"]["total_tokens"]
        return np.array(vectors), total_tokens

    def encode_queries(self, text):
        """Embed a single query string.

        Args:
            text: The query passed as ``input`` in the request payload.

        Returns:
            A tuple ``(embedding, total_tokens)`` holding the first embedding
            of the response as a NumPy array and the reported token usage.

        Raises:
            KeyError: If the response lacks ``data`` or ``usage``.
        """
        res = self._request_embeddings(text)
        return np.array(res["data"][0]["embedding"]), res["usage"]["total_tokens"]
|
| 614 |
|
| 615 |
|
| 616 |
class ReplicateEmbed(Base):
|