Upload folder using huggingface_hub
Browse files- ChatWorld/ChatWorld.py +6 -8
- ChatWorld/NaiveDB.py +5 -2
- ChatWorld/models.py +8 -5
- app.py +31 -5
ChatWorld/ChatWorld.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
from jinja2 import Template
|
| 2 |
import torch
|
| 3 |
|
| 4 |
-
from .models import GLM
|
| 5 |
|
| 6 |
from .NaiveDB import NaiveDB
|
| 7 |
from .utils import *
|
|
@@ -19,7 +19,7 @@ class ChatWorld:
|
|
| 19 |
|
| 20 |
self.history = []
|
| 21 |
|
| 22 |
-
self.client =
|
| 23 |
self.model = GLM()
|
| 24 |
self.db = NaiveDB()
|
| 25 |
self.prompt = Template(('Please be aware that your codename in this conversation is "{{model_role_name}}"'
|
|
@@ -81,17 +81,15 @@ class ChatWorld:
|
|
| 81 |
return {"role": "system", "content": self.prompt.render(model_role_name=self.model_role_name, model_role_nickname=self.model_role_nickname, role_name=role_name, role_nickname=role_nick_name, RAG=rag)}
|
| 82 |
|
| 83 |
def chat(self, text: str, user_role_name: str, user_role_nick_name: str = None, use_local_model=False):
|
|
|
|
|
|
|
| 84 |
message = [self.getSystemPrompt(text,
|
| 85 |
-
user_role_name, user_role_nick_name)
|
| 86 |
-
print(message)
|
| 87 |
if use_local_model:
|
| 88 |
response = self.model.get_response(message)
|
| 89 |
else:
|
| 90 |
-
response = self.client.chat(
|
| 91 |
-
user_role_name, text, user_role_nick_name)
|
| 92 |
|
| 93 |
-
self.history.append(
|
| 94 |
-
{"role": "user", "content": f"{user_role_name}:「{text}」"})
|
| 95 |
self.history.append(
|
| 96 |
{"role": "assistant", "content": f"{self.model_role_name}:「{response}」"})
|
| 97 |
return response
|
|
|
|
| 1 |
from jinja2 import Template
|
| 2 |
import torch
|
| 3 |
|
| 4 |
+
from .models import GLM, GLM_api
|
| 5 |
|
| 6 |
from .NaiveDB import NaiveDB
|
| 7 |
from .utils import *
|
|
|
|
| 19 |
|
| 20 |
self.history = []
|
| 21 |
|
| 22 |
+
self.client = GLM_api()
|
| 23 |
self.model = GLM()
|
| 24 |
self.db = NaiveDB()
|
| 25 |
self.prompt = Template(('Please be aware that your codename in this conversation is "{{model_role_name}}"'
|
|
|
|
| 81 |
return {"role": "system", "content": self.prompt.render(model_role_name=self.model_role_name, model_role_nickname=self.model_role_nickname, role_name=role_name, role_nickname=role_nick_name, RAG=rag)}
|
| 82 |
|
| 83 |
def chat(self, text: str, user_role_name: str, user_role_nick_name: str = None, use_local_model=False):
|
| 84 |
+
self.history.append(
|
| 85 |
+
{"role": "user", "content": f"{user_role_name}:「{text}」"})
|
| 86 |
message = [self.getSystemPrompt(text,
|
| 87 |
+
user_role_name, user_role_nick_name), {"role": "user", "content": f"{user_role_name}:「{text}」"}]
|
|
|
|
| 88 |
if use_local_model:
|
| 89 |
response = self.model.get_response(message)
|
| 90 |
else:
|
| 91 |
+
response = self.client.chat(message)
|
|
|
|
| 92 |
|
|
|
|
|
|
|
| 93 |
self.history.append(
|
| 94 |
{"role": "assistant", "content": f"{self.model_role_name}:「{response}」"})
|
| 95 |
return response
|
ChatWorld/NaiveDB.py
CHANGED
|
@@ -81,7 +81,10 @@ class NaiveDB:
|
|
| 81 |
similarities.sort(key=lambda x: x[0], reverse=True)
|
| 82 |
self.last_search_ids = [x[1] for x in similarities[:n_results]]
|
| 83 |
|
| 84 |
-
|
|
|
|
|
|
|
| 85 |
|
| 86 |
-
top_stories = [self.stories[
|
|
|
|
| 87 |
return top_stories
|
|
|
|
| 81 |
similarities.sort(key=lambda x: x[0], reverse=True)
|
| 82 |
self.last_search_ids = [x[1] for x in similarities[:n_results]]
|
| 83 |
|
| 84 |
+
stories_length = len(self.stories)
|
| 85 |
+
search_id_range = [(max(0, i-3), min(i+4, stories_length))
|
| 86 |
+
for i in self.last_search_ids]
|
| 87 |
|
| 88 |
+
top_stories = ["\n".join(self.stories[start:end+1])
|
| 89 |
+
for start, end in search_id_range]
|
| 90 |
return top_stories
|
ChatWorld/models.py
CHANGED
|
@@ -40,7 +40,7 @@ class GLM():
|
|
| 40 |
|
| 41 |
self.client = client.eval()
|
| 42 |
|
| 43 |
-
def message2query(messages) -> str:
|
| 44 |
# [{'role': 'user', 'content': '老师: 同学请自我介绍一下'}]
|
| 45 |
# <|system|>
|
| 46 |
# You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
|
|
@@ -53,7 +53,9 @@ class GLM():
|
|
| 53 |
return "".join([template.substitute(message) for message in messages])
|
| 54 |
|
| 55 |
def get_response(self, message):
|
| 56 |
-
response, history = self.client.chat(
|
|
|
|
|
|
|
| 57 |
return response
|
| 58 |
|
| 59 |
|
|
@@ -62,7 +64,8 @@ class GLM_api:
|
|
| 62 |
self.client = ZhipuAI(api_key=os.environ["ZHIPU_API_KEY"])
|
| 63 |
self.model = model_name
|
| 64 |
|
| 65 |
-
def
|
|
|
|
| 66 |
response = self.client.chat.completions.create(
|
| 67 |
-
model=self.model,
|
| 68 |
-
return response.choices[0].message
|
|
|
|
| 40 |
|
| 41 |
self.client = client.eval()
|
| 42 |
|
| 43 |
+
def message2query(self, messages) -> str:
|
| 44 |
# [{'role': 'user', 'content': '老师: 同学请自我介绍一下'}]
|
| 45 |
# <|system|>
|
| 46 |
# You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
|
|
|
|
| 53 |
return "".join([template.substitute(message) for message in messages])
|
| 54 |
|
| 55 |
def get_response(self, message):
|
| 56 |
+
response, history = self.client.chat(
|
| 57 |
+
self.tokenizer, self.message2query(message))
|
| 58 |
+
print(self.message2query(message))
|
| 59 |
return response
|
| 60 |
|
| 61 |
|
|
|
|
| 64 |
self.client = ZhipuAI(api_key=os.environ["ZHIPU_API_KEY"])
|
| 65 |
self.model = model_name
|
| 66 |
|
| 67 |
+
def chat(self, message):
|
| 68 |
+
print(message)
|
| 69 |
response = self.client.chat.completions.create(
|
| 70 |
+
model=self.model, messages=message)
|
| 71 |
+
return response.choices[0].message.content
|
app.py
CHANGED
|
@@ -11,6 +11,8 @@ logging.basicConfig(level=logging.INFO, filename="demo.log", filemode="w",
|
|
| 11 |
|
| 12 |
chatWorld = ChatWorld()
|
| 13 |
|
|
|
|
|
|
|
| 14 |
|
| 15 |
def getContent(input_file):
|
| 16 |
# 读取文件内容
|
|
@@ -31,33 +33,57 @@ def getContent(input_file):
|
|
| 31 |
role_name_list = [i for i in role_name_set if i != ""]
|
| 32 |
logging.info(f"role_name_list: {role_name_list}")
|
| 33 |
|
|
|
|
|
|
|
|
|
|
| 34 |
return gr.Radio(choices=role_name_list, interactive=True, value=role_name_list[0]), gr.Radio(choices=role_name_list, interactive=True, value=role_name_list[-1])
|
| 35 |
|
| 36 |
|
| 37 |
def submit_message(message, history, model_role_name, role_name, model_role_nickname, role_nickname):
|
|
|
|
| 38 |
chatWorld.setRoleName(model_role_name, model_role_nickname)
|
| 39 |
response = chatWorld.chat(message,
|
| 40 |
role_name, role_nickname, use_local_model=True)
|
| 41 |
return response
|
| 42 |
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
with gr.Blocks() as demo:
|
| 45 |
|
| 46 |
upload_c = gr.File(label="上传文档文件")
|
| 47 |
|
| 48 |
with gr.Row():
|
| 49 |
-
model_role_name = gr.Radio(
|
| 50 |
model_role_nickname = gr.Textbox(label="模型角色昵称")
|
| 51 |
|
| 52 |
with gr.Row():
|
| 53 |
-
role_name = gr.Radio(
|
| 54 |
role_nickname = gr.Textbox(label="角色昵称")
|
| 55 |
|
| 56 |
upload_c.upload(fn=getContent, inputs=upload_c,
|
| 57 |
outputs=[model_role_name, role_name])
|
| 58 |
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
|
| 63 |
-
demo.launch(
|
|
|
|
| 11 |
|
| 12 |
chatWorld = ChatWorld()
|
| 13 |
|
| 14 |
+
role_name_list_global = None
|
| 15 |
+
|
| 16 |
|
| 17 |
def getContent(input_file):
|
| 18 |
# 读取文件内容
|
|
|
|
| 33 |
role_name_list = [i for i in role_name_set if i != ""]
|
| 34 |
logging.info(f"role_name_list: {role_name_list}")
|
| 35 |
|
| 36 |
+
global role_name_list_global
|
| 37 |
+
role_name_list_global = role_name_list
|
| 38 |
+
|
| 39 |
return gr.Radio(choices=role_name_list, interactive=True, value=role_name_list[0]), gr.Radio(choices=role_name_list, interactive=True, value=role_name_list[-1])
|
| 40 |
|
| 41 |
|
| 42 |
def submit_message(message, history, model_role_name, role_name, model_role_nickname, role_nickname):
|
| 43 |
+
print(f"history: {history}")
|
| 44 |
chatWorld.setRoleName(model_role_name, model_role_nickname)
|
| 45 |
response = chatWorld.chat(message,
|
| 46 |
role_name, role_nickname, use_local_model=True)
|
| 47 |
return response
|
| 48 |
|
| 49 |
|
| 50 |
+
def submit_message_api(message, history, model_role_name, role_name, model_role_nickname, role_nickname):
|
| 51 |
+
print(f"history: {history}")
|
| 52 |
+
chatWorld.setRoleName(model_role_name, model_role_nickname)
|
| 53 |
+
response = chatWorld.chat(message,
|
| 54 |
+
role_name, role_nickname, use_local_model=False)
|
| 55 |
+
return response
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def get_role_list():
|
| 59 |
+
global role_name_list_global
|
| 60 |
+
if role_name_list_global:
|
| 61 |
+
return role_name_list_global
|
| 62 |
+
else:
|
| 63 |
+
return []
|
| 64 |
+
|
| 65 |
+
|
| 66 |
with gr.Blocks() as demo:
|
| 67 |
|
| 68 |
upload_c = gr.File(label="上传文档文件")
|
| 69 |
|
| 70 |
with gr.Row():
|
| 71 |
+
model_role_name = gr.Radio(get_role_list(), label="模型角色名")
|
| 72 |
model_role_nickname = gr.Textbox(label="模型角色昵称")
|
| 73 |
|
| 74 |
with gr.Row():
|
| 75 |
+
role_name = gr.Radio(get_role_list(), label="角色名")
|
| 76 |
role_nickname = gr.Textbox(label="角色昵称")
|
| 77 |
|
| 78 |
upload_c.upload(fn=getContent, inputs=upload_c,
|
| 79 |
outputs=[model_role_name, role_name])
|
| 80 |
|
| 81 |
+
with gr.Row():
|
| 82 |
+
chatBox_local = gr.ChatInterface(
|
| 83 |
+
submit_message, chatbot=gr.Chatbot(height=400, label="本地模型", render=False), additional_inputs=[model_role_name, role_name, model_role_nickname, role_nickname])
|
| 84 |
+
|
| 85 |
+
chatBox_api = gr.ChatInterface(
|
| 86 |
+
submit_message_api, chatbot=gr.Chatbot(height=400, label="API模型", render=False), additional_inputs=[model_role_name, role_name, model_role_nickname, role_nickname])
|
| 87 |
|
| 88 |
|
| 89 |
+
demo.launch(server_name="0.0.0.0")
|