Spaces:
Sleeping
Sleeping
Tuchuanhuhuhu committed on
Commit ·
8728d12
1
Parent(s): 3675c9f
bugfix: 加入LLaMA.cpp
Browse files
- modules/models/OpenAI.py +14 -13
- modules/models/XMChat.py +2 -2
- modules/models/midjourney.py +3 -4
- modules/models/models.py +2 -2
modules/models/OpenAI.py
CHANGED
|
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
| 2 |
|
| 3 |
import json
|
| 4 |
import logging
|
|
|
|
| 5 |
|
| 6 |
import colorama
|
| 7 |
import requests
|
|
@@ -85,11 +86,11 @@ class OpenAIClient(BaseLLMModel):
|
|
| 85 |
|
| 86 |
# return i18n("**本月使用金额** ") + f"\u3000 ${rounded_usage}"
|
| 87 |
return get_html("billing_info.html").format(
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
except requests.exceptions.ConnectTimeout:
|
| 94 |
status_text = (
|
| 95 |
STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
|
|
@@ -161,6 +162,7 @@ class OpenAIClient(BaseLLMModel):
|
|
| 161 |
timeout=timeout,
|
| 162 |
)
|
| 163 |
except:
|
|
|
|
| 164 |
return None
|
| 165 |
return response
|
| 166 |
|
|
@@ -170,6 +172,7 @@ class OpenAIClient(BaseLLMModel):
|
|
| 170 |
"Authorization": f"Bearer {sensitive_id}",
|
| 171 |
}
|
| 172 |
|
|
|
|
| 173 |
def _get_billing_data(self, billing_url):
|
| 174 |
with retrieve_proxy():
|
| 175 |
response = requests.get(
|
|
@@ -240,6 +243,7 @@ class OpenAIClient(BaseLLMModel):
|
|
| 240 |
|
| 241 |
return response
|
| 242 |
|
|
|
|
| 243 |
def auto_name_chat_history(self, name_chat_method, user_question, chatbot, user_name, single_turn_checkbox):
|
| 244 |
if len(self.history) == 2 and not single_turn_checkbox:
|
| 245 |
user_question = self.history[0]["content"]
|
|
@@ -247,22 +251,19 @@ class OpenAIClient(BaseLLMModel):
|
|
| 247 |
ai_answer = self.history[1]["content"]
|
| 248 |
try:
|
| 249 |
history = [
|
| 250 |
-
{"role": "system", "content": SUMMARY_CHAT_SYSTEM_PROMPT},
|
| 251 |
-
{"role": "user", "content": f"Please write a title based on the following conversation:\n---\nUser: {user_question}\nAssistant: {ai_answer}"}
|
| 252 |
]
|
| 253 |
-
response = self._single_query_at_once(
|
| 254 |
-
history, temperature=0.0)
|
| 255 |
response = json.loads(response.text)
|
| 256 |
content = response["choices"][0]["message"]["content"]
|
| 257 |
filename = replace_special_symbols(content) + ".json"
|
| 258 |
except Exception as e:
|
| 259 |
logging.info(f"自动命名失败。{e}")
|
| 260 |
-
filename = replace_special_symbols(user_question)[
|
| 261 |
-
:16] + ".json"
|
| 262 |
return self.rename_chat_history(filename, chatbot, user_name)
|
| 263 |
elif name_chat_method == i18n("第一条提问"):
|
| 264 |
-
filename = replace_special_symbols(user_question)[
|
| 265 |
-
:16] + ".json"
|
| 266 |
return self.rename_chat_history(filename, chatbot, user_name)
|
| 267 |
else:
|
| 268 |
return gr.update()
|
|
|
|
| 2 |
|
| 3 |
import json
|
| 4 |
import logging
|
| 5 |
+
import traceback
|
| 6 |
|
| 7 |
import colorama
|
| 8 |
import requests
|
|
|
|
| 86 |
|
| 87 |
# return i18n("**本月使用金额** ") + f"\u3000 ${rounded_usage}"
|
| 88 |
return get_html("billing_info.html").format(
|
| 89 |
+
label = i18n("本月使用金额"),
|
| 90 |
+
usage_percent = usage_percent,
|
| 91 |
+
rounded_usage = rounded_usage,
|
| 92 |
+
usage_limit = usage_limit
|
| 93 |
+
)
|
| 94 |
except requests.exceptions.ConnectTimeout:
|
| 95 |
status_text = (
|
| 96 |
STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
|
|
|
|
| 162 |
timeout=timeout,
|
| 163 |
)
|
| 164 |
except:
|
| 165 |
+
traceback.print_exc()
|
| 166 |
return None
|
| 167 |
return response
|
| 168 |
|
|
|
|
| 172 |
"Authorization": f"Bearer {sensitive_id}",
|
| 173 |
}
|
| 174 |
|
| 175 |
+
|
| 176 |
def _get_billing_data(self, billing_url):
|
| 177 |
with retrieve_proxy():
|
| 178 |
response = requests.get(
|
|
|
|
| 243 |
|
| 244 |
return response
|
| 245 |
|
| 246 |
+
|
| 247 |
def auto_name_chat_history(self, name_chat_method, user_question, chatbot, user_name, single_turn_checkbox):
|
| 248 |
if len(self.history) == 2 and not single_turn_checkbox:
|
| 249 |
user_question = self.history[0]["content"]
|
|
|
|
| 251 |
ai_answer = self.history[1]["content"]
|
| 252 |
try:
|
| 253 |
history = [
|
| 254 |
+
{ "role": "system", "content": SUMMARY_CHAT_SYSTEM_PROMPT},
|
| 255 |
+
{ "role": "user", "content": f"Please write a title based on the following conversation:\n---\nUser: {user_question}\nAssistant: {ai_answer}"}
|
| 256 |
]
|
| 257 |
+
response = self._single_query_at_once(history, temperature=0.0)
|
|
|
|
| 258 |
response = json.loads(response.text)
|
| 259 |
content = response["choices"][0]["message"]["content"]
|
| 260 |
filename = replace_special_symbols(content) + ".json"
|
| 261 |
except Exception as e:
|
| 262 |
logging.info(f"自动命名失败。{e}")
|
| 263 |
+
filename = replace_special_symbols(user_question)[:16] + ".json"
|
|
|
|
| 264 |
return self.rename_chat_history(filename, chatbot, user_name)
|
| 265 |
elif name_chat_method == i18n("第一条提问"):
|
| 266 |
+
filename = replace_special_symbols(user_question)[:16] + ".json"
|
|
|
|
| 267 |
return self.rename_chat_history(filename, chatbot, user_name)
|
| 268 |
else:
|
| 269 |
return gr.update()
|
modules/models/XMChat.py
CHANGED
|
@@ -16,7 +16,7 @@ from ..utils import *
|
|
| 16 |
from .base_model import BaseLLMModel
|
| 17 |
|
| 18 |
|
| 19 |
-
class XMChatClient(BaseLLMModel):
|
| 20 |
def __init__(self, api_key, user_name=""):
|
| 21 |
super().__init__(model_name="xmchat", user=user_name)
|
| 22 |
self.api_key = api_key
|
|
@@ -31,7 +31,7 @@ class XMChatClient(BaseLLMModel):
|
|
| 31 |
def reset(self):
|
| 32 |
self.session_id = str(uuid.uuid4())
|
| 33 |
self.last_conv_id = None
|
| 34 |
-
return
|
| 35 |
|
| 36 |
def image_to_base64(self, image_path):
|
| 37 |
# 打开并加载图片
|
|
|
|
| 16 |
from .base_model import BaseLLMModel
|
| 17 |
|
| 18 |
|
| 19 |
+
class XMChat(BaseLLMModel):
|
| 20 |
def __init__(self, api_key, user_name=""):
|
| 21 |
super().__init__(model_name="xmchat", user=user_name)
|
| 22 |
self.api_key = api_key
|
|
|
|
| 31 |
def reset(self):
|
| 32 |
self.session_id = str(uuid.uuid4())
|
| 33 |
self.last_conv_id = None
|
| 34 |
+
return super().reset()
|
| 35 |
|
| 36 |
def image_to_base64(self, image_path):
|
| 37 |
# 打开并加载图片
|
modules/models/midjourney.py
CHANGED
|
@@ -2,11 +2,10 @@ import base64
|
|
| 2 |
import io
|
| 3 |
import json
|
| 4 |
import logging
|
|
|
|
| 5 |
import pathlib
|
| 6 |
-
import time
|
| 7 |
import tempfile
|
| 8 |
-
import
|
| 9 |
-
|
| 10 |
from datetime import datetime
|
| 11 |
|
| 12 |
import requests
|
|
@@ -14,7 +13,7 @@ import tiktoken
|
|
| 14 |
from PIL import Image
|
| 15 |
|
| 16 |
from modules.config import retrieve_proxy
|
| 17 |
-
from modules.models.
|
| 18 |
|
| 19 |
mj_proxy_api_base = os.getenv("MIDJOURNEY_PROXY_API_BASE")
|
| 20 |
mj_discord_proxy_url = os.getenv("MIDJOURNEY_DISCORD_PROXY_URL")
|
|
|
|
| 2 |
import io
|
| 3 |
import json
|
| 4 |
import logging
|
| 5 |
+
import os
|
| 6 |
import pathlib
|
|
|
|
| 7 |
import tempfile
|
| 8 |
+
import time
|
|
|
|
| 9 |
from datetime import datetime
|
| 10 |
|
| 11 |
import requests
|
|
|
|
| 13 |
from PIL import Image
|
| 14 |
|
| 15 |
from modules.config import retrieve_proxy
|
| 16 |
+
from modules.models.XMChat import XMChat
|
| 17 |
|
| 18 |
mj_proxy_api_base = os.getenv("MIDJOURNEY_PROXY_API_BASE")
|
| 19 |
mj_discord_proxy_url = os.getenv("MIDJOURNEY_DISCORD_PROXY_URL")
|
modules/models/models.py
CHANGED
|
@@ -69,10 +69,10 @@ def get_model(
|
|
| 69 |
model = LLaMA_Client(
|
| 70 |
model_name, lora_model_path, user_name=user_name)
|
| 71 |
elif model_type == ModelType.XMChat:
|
| 72 |
-
from .XMChat import XMChatClient
|
| 73 |
if os.environ.get("XMCHAT_API_KEY") != "":
|
| 74 |
access_key = os.environ.get("XMCHAT_API_KEY")
|
| 75 |
-
model = XMChatClient(api_key=access_key, user_name=user_name)
|
| 76 |
elif model_type == ModelType.StableLM:
|
| 77 |
from .StableLM import StableLM_Client
|
| 78 |
model = StableLM_Client(model_name, user_name=user_name)
|
|
|
|
| 69 |
model = LLaMA_Client(
|
| 70 |
model_name, lora_model_path, user_name=user_name)
|
| 71 |
elif model_type == ModelType.XMChat:
|
| 72 |
+
from .XMChat import XMChat
|
| 73 |
if os.environ.get("XMCHAT_API_KEY") != "":
|
| 74 |
access_key = os.environ.get("XMCHAT_API_KEY")
|
| 75 |
+
model = XMChat(api_key=access_key, user_name=user_name)
|
| 76 |
elif model_type == ModelType.StableLM:
|
| 77 |
from .StableLM import StableLM_Client
|
| 78 |
model = StableLM_Client(model_name, user_name=user_name)
|