Kevin Hu
commited on
Commit
·
244188b
1
Parent(s):
04c2182
refactor add LLM (#2508)
Browse files### What problem does this PR solve?
#2487
### Type of change
- [x] Refactoring
- api/apps/llm_app.py +25 -22
- rag/llm/chat_model.py +1 -1
api/apps/llm_app.py
CHANGED
|
@@ -13,6 +13,8 @@
|
|
| 13 |
# See the License for the specific language governing permissions and
|
| 14 |
# limitations under the License.
|
| 15 |
#
|
|
|
|
|
|
|
| 16 |
from flask import request
|
| 17 |
from flask_login import login_required, current_user
|
| 18 |
from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
|
|
@@ -126,55 +128,56 @@ def add_llm():
|
|
| 126 |
req = request.json
|
| 127 |
factory = req["llm_factory"]
|
| 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
if factory == "VolcEngine":
|
| 130 |
# For VolcEngine, due to its special authentication method
|
| 131 |
# Assemble ark_api_key endpoint_id into api_key
|
| 132 |
llm_name = req["llm_name"]
|
| 133 |
-
api_key =
|
|
|
|
| 134 |
elif factory == "Tencent Hunyuan":
|
| 135 |
-
api_key =
|
| 136 |
-
f'"hunyuan_sk": "{req.get("hunyuan_sk", "")}"' + '}'
|
| 137 |
-
req["api_key"] = api_key
|
| 138 |
return set_api_key()
|
|
|
|
| 139 |
elif factory == "Tencent Cloud":
|
| 140 |
-
api_key =
|
| 141 |
-
|
| 142 |
-
req["api_key"] = api_key
|
| 143 |
elif factory == "Bedrock":
|
| 144 |
# For Bedrock, due to its special authentication method
|
| 145 |
# Assemble bedrock_ak, bedrock_sk, bedrock_region
|
| 146 |
llm_name = req["llm_name"]
|
| 147 |
-
api_key =
|
| 148 |
-
|
| 149 |
-
f'"bedrock_region": "{req.get("bedrock_region", "")}", ' + '}'
|
| 150 |
elif factory == "LocalAI":
|
| 151 |
llm_name = req["llm_name"]+"___LocalAI"
|
| 152 |
api_key = "xxxxxxxxxxxxxxx"
|
|
|
|
| 153 |
elif factory == "OpenAI-API-Compatible":
|
| 154 |
llm_name = req["llm_name"]+"___OpenAI-API"
|
| 155 |
api_key = req.get("api_key","xxxxxxxxxxxxxxx")
|
|
|
|
| 156 |
elif factory =="XunFei Spark":
|
| 157 |
llm_name = req["llm_name"]
|
| 158 |
-
api_key = req.get("spark_api_password","xxxxxxxxxxxxxxx")
|
|
|
|
| 159 |
elif factory == "BaiduYiyan":
|
| 160 |
llm_name = req["llm_name"]
|
| 161 |
-
api_key =
|
| 162 |
-
|
| 163 |
elif factory == "Fish Audio":
|
| 164 |
llm_name = req["llm_name"]
|
| 165 |
-
api_key =
|
| 166 |
-
|
| 167 |
elif factory == "Google Cloud":
|
| 168 |
llm_name = req["llm_name"]
|
| 169 |
-
api_key = (
|
| 170 |
-
|
| 171 |
-
f'"google_region": "{req.get("google_region", "")}", '
|
| 172 |
-
f'"google_service_account_key": "{req.get("google_service_account_key", "")}"'
|
| 173 |
-
+ "}"
|
| 174 |
-
)
|
| 175 |
else:
|
| 176 |
llm_name = req["llm_name"]
|
| 177 |
-
api_key = req.get("api_key","xxxxxxxxxxxxxxx")
|
| 178 |
|
| 179 |
llm = {
|
| 180 |
"tenant_id": current_user.id,
|
|
|
|
| 13 |
# See the License for the specific language governing permissions and
|
| 14 |
# limitations under the License.
|
| 15 |
#
|
| 16 |
+
import json
|
| 17 |
+
|
| 18 |
from flask import request
|
| 19 |
from flask_login import login_required, current_user
|
| 20 |
from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
|
|
|
|
| 128 |
req = request.json
|
| 129 |
factory = req["llm_factory"]
|
| 130 |
|
| 131 |
+
def apikey_json(keys):
|
| 132 |
+
nonlocal req
|
| 133 |
+
return json.dumps({k: req.get(k, "") for k in keys})
|
| 134 |
+
|
| 135 |
if factory == "VolcEngine":
|
| 136 |
# For VolcEngine, due to its special authentication method
|
| 137 |
# Assemble ark_api_key endpoint_id into api_key
|
| 138 |
llm_name = req["llm_name"]
|
| 139 |
+
api_key = apikey_json(["ark_api_key", "endpoint_id"])
|
| 140 |
+
|
| 141 |
elif factory == "Tencent Hunyuan":
|
| 142 |
+
req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"])
|
|
|
|
|
|
|
| 143 |
return set_api_key()
|
| 144 |
+
|
| 145 |
elif factory == "Tencent Cloud":
|
| 146 |
+
req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"])
|
| 147 |
+
|
|
|
|
| 148 |
elif factory == "Bedrock":
|
| 149 |
# For Bedrock, due to its special authentication method
|
| 150 |
# Assemble bedrock_ak, bedrock_sk, bedrock_region
|
| 151 |
llm_name = req["llm_name"]
|
| 152 |
+
api_key = apikey_json(["bedrock_ak", "bedrock_sk", "bedrock_region"])
|
| 153 |
+
|
|
|
|
| 154 |
elif factory == "LocalAI":
|
| 155 |
llm_name = req["llm_name"]+"___LocalAI"
|
| 156 |
api_key = "xxxxxxxxxxxxxxx"
|
| 157 |
+
|
| 158 |
elif factory == "OpenAI-API-Compatible":
|
| 159 |
llm_name = req["llm_name"]+"___OpenAI-API"
|
| 160 |
api_key = req.get("api_key","xxxxxxxxxxxxxxx")
|
| 161 |
+
|
| 162 |
elif factory =="XunFei Spark":
|
| 163 |
llm_name = req["llm_name"]
|
| 164 |
+
api_key = req.get("spark_api_password","xxxxxxxxxxxxxxx")
|
| 165 |
+
|
| 166 |
elif factory == "BaiduYiyan":
|
| 167 |
llm_name = req["llm_name"]
|
| 168 |
+
api_key = apikey_json(["yiyan_ak", "yiyan_sk"])
|
| 169 |
+
|
| 170 |
elif factory == "Fish Audio":
|
| 171 |
llm_name = req["llm_name"]
|
| 172 |
+
api_key = apikey_json(["fish_audio_ak", "fish_audio_refid"])
|
| 173 |
+
|
| 174 |
elif factory == "Google Cloud":
|
| 175 |
llm_name = req["llm_name"]
|
| 176 |
+
api_key = apikey_json(["google_project_id", "google_region", "google_service_account_key"])
|
| 177 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
else:
|
| 179 |
llm_name = req["llm_name"]
|
| 180 |
+
api_key = req.get("api_key", "xxxxxxxxxxxxxxx")
|
| 181 |
|
| 182 |
llm = {
|
| 183 |
"tenant_id": current_user.id,
|
rag/llm/chat_model.py
CHANGED
|
@@ -458,7 +458,7 @@ class VolcEngineChat(Base):
|
|
| 458 |
"""
|
| 459 |
base_url = base_url if base_url else 'https://ark.cn-beijing.volces.com/api/v3'
|
| 460 |
ark_api_key = json.loads(key).get('ark_api_key', '')
|
| 461 |
-
model_name = json.loads(key).get('ep_id', '')
|
| 462 |
super().__init__(ark_api_key, model_name, base_url)
|
| 463 |
|
| 464 |
|
|
|
|
| 458 |
"""
|
| 459 |
base_url = base_url if base_url else 'https://ark.cn-beijing.volces.com/api/v3'
|
| 460 |
ark_api_key = json.loads(key).get('ark_api_key', '')
|
| 461 |
+
model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '')
|
| 462 |
super().__init__(ark_api_key, model_name, base_url)
|
| 463 |
|
| 464 |
|