Upload folder using huggingface_hub
Browse files- ChatWorld/ChatWorld.py +32 -20
- ChatWorld/NaiveDB.py +4 -2
- ChatWorld/models.py +44 -1
- app.py +2 -2
- run_gradio.sh +1 -0
ChatWorld/ChatWorld.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
from jinja2 import Template
|
| 2 |
import torch
|
| 3 |
|
| 4 |
-
from .models import
|
| 5 |
|
| 6 |
from .NaiveDB import NaiveDB
|
| 7 |
from .utils import *
|
|
@@ -20,7 +20,7 @@ class ChatWorld:
|
|
| 20 |
self.history = []
|
| 21 |
|
| 22 |
self.client = None
|
| 23 |
-
self.model =
|
| 24 |
self.db = NaiveDB()
|
| 25 |
self.prompt = Template(('Please be aware that your codename in this conversation is "{{model_role_name}}"'
|
| 26 |
'{% if model_role_nickname %},别人一般称呼你"{{model_role_nickname}}"{% endif %},同时我是"{{role_name}}"{% if role_nickname %},你也可以称呼我"{{role_nickname}}"{% endif %}。\n'
|
|
@@ -30,6 +30,7 @@ class ChatWorld:
|
|
| 30 |
'如果我问的问题超出聊天记录的范围,模仿{{model_role_name}}的语气进行回复。\n'
|
| 31 |
'请不要回答你是语言模型,永远记住你就是{{model_role_name}}。\n'
|
| 32 |
'请你永远只以{{model_role_name}}身份,进行任何的回复。\n'
|
|
|
|
| 33 |
))
|
| 34 |
|
| 35 |
def getEmbeddingsFromStory(self, stories: list[str]):
|
|
@@ -38,25 +39,31 @@ class ChatWorld:
|
|
| 38 |
if len(self.story_vec) == len(stories) and all([self.story_vec[i]["text"] == stories[i] for i in range(len(stories))]):
|
| 39 |
return [self.story_vec[i]["vec"] for i in range(len(stories))]
|
| 40 |
|
| 41 |
-
if self.embedding is None:
|
| 42 |
-
self.embedding = initEmbedding()
|
| 43 |
-
|
| 44 |
-
if self.tokenizer is None:
|
| 45 |
-
self.tokenizer = initTokenizer()
|
| 46 |
-
|
| 47 |
self.story_vec = []
|
| 48 |
for story in stories:
|
| 49 |
with torch.no_grad():
|
| 50 |
-
|
| 51 |
-
story, return_tensors="pt", padding=True, truncation=True, max_length=512)
|
| 52 |
-
outputs = self.embedding(**inputs)[0][:, 0]
|
| 53 |
-
vec = torch.nn.functional.normalize(
|
| 54 |
-
outputs, p=2, dim=1).tolist()[0]
|
| 55 |
|
| 56 |
self.story_vec.append({"text": story, "vec": vec})
|
| 57 |
|
| 58 |
return [self.story_vec[i]["vec"] for i in range(len(stories))]
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
def initDB(self, storys: list[str]):
|
| 61 |
story_vecs = self.getEmbeddingsFromStory(storys)
|
| 62 |
self.db.build_db(storys, story_vecs)
|
|
@@ -65,21 +72,26 @@ class ChatWorld:
|
|
| 65 |
self.model_role_name = role_name
|
| 66 |
self.model_role_nickname = role_nick_name
|
| 67 |
|
| 68 |
-
def getSystemPrompt(self, role_name, role_nick_name):
|
| 69 |
assert self.model_role_name, "Please set model role name first"
|
| 70 |
|
| 71 |
-
|
|
|
|
| 72 |
|
| 73 |
-
|
| 74 |
-
message = [self.getSystemPrompt(
|
| 75 |
-
user_role_name, user_role_nick_name)] + self.history
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
if use_local_model:
|
| 78 |
response = self.model.get_response(message)
|
| 79 |
else:
|
| 80 |
response = self.client.chat(
|
| 81 |
user_role_name, text, user_role_nick_name)
|
| 82 |
|
| 83 |
-
self.history.append(
|
| 84 |
-
|
|
|
|
|
|
|
| 85 |
return response
|
|
|
|
| 1 |
from jinja2 import Template
|
| 2 |
import torch
|
| 3 |
|
| 4 |
+
from .models import GLM
|
| 5 |
|
| 6 |
from .NaiveDB import NaiveDB
|
| 7 |
from .utils import *
|
|
|
|
| 20 |
self.history = []
|
| 21 |
|
| 22 |
self.client = None
|
| 23 |
+
self.model = GLM()
|
| 24 |
self.db = NaiveDB()
|
| 25 |
self.prompt = Template(('Please be aware that your codename in this conversation is "{{model_role_name}}"'
|
| 26 |
'{% if model_role_nickname %},别人一般称呼你"{{model_role_nickname}}"{% endif %},同时我是"{{role_name}}"{% if role_nickname %},你也可以称呼我"{{role_nickname}}"{% endif %}。\n'
|
|
|
|
| 30 |
'如果我问的问题超出聊天记录的范围,模仿{{model_role_name}}的语气进行回复。\n'
|
| 31 |
'请不要回答你是语言模型,永远记住你就是{{model_role_name}}。\n'
|
| 32 |
'请你永远只以{{model_role_name}}身份,进行任何的回复。\n'
|
| 33 |
+
'{% if RAG %}{% for i in RAG %}##\n{{i}}\n##\n\n{% endfor %}{% endif %}'
|
| 34 |
))
|
| 35 |
|
| 36 |
def getEmbeddingsFromStory(self, stories: list[str]):
|
|
|
|
| 39 |
if len(self.story_vec) == len(stories) and all([self.story_vec[i]["text"] == stories[i] for i in range(len(stories))]):
|
| 40 |
return [self.story_vec[i]["vec"] for i in range(len(stories))]
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
self.story_vec = []
|
| 43 |
for story in stories:
|
| 44 |
with torch.no_grad():
|
| 45 |
+
vec = self.getEmbedding(story)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
self.story_vec.append({"text": story, "vec": vec})
|
| 48 |
|
| 49 |
return [self.story_vec[i]["vec"] for i in range(len(stories))]
|
| 50 |
|
| 51 |
+
def getEmbedding(self, text: str):
|
| 52 |
+
if self.embedding is None:
|
| 53 |
+
self.embedding = initEmbedding()
|
| 54 |
+
|
| 55 |
+
if self.tokenizer is None:
|
| 56 |
+
self.tokenizer = initTokenizer()
|
| 57 |
+
|
| 58 |
+
with torch.no_grad():
|
| 59 |
+
inputs = self.tokenizer(
|
| 60 |
+
text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(self.embedding.device)
|
| 61 |
+
outputs = self.embedding(**inputs)[0][:, 0]
|
| 62 |
+
vec = torch.nn.functional.normalize(
|
| 63 |
+
outputs, p=2, dim=1).tolist()[0]
|
| 64 |
+
|
| 65 |
+
return vec
|
| 66 |
+
|
| 67 |
def initDB(self, storys: list[str]):
|
| 68 |
story_vecs = self.getEmbeddingsFromStory(storys)
|
| 69 |
self.db.build_db(storys, story_vecs)
|
|
|
|
| 72 |
self.model_role_name = role_name
|
| 73 |
self.model_role_nickname = role_nick_name
|
| 74 |
|
| 75 |
+
def getSystemPrompt(self, text, role_name, role_nick_name):
|
| 76 |
assert self.model_role_name, "Please set model role name first"
|
| 77 |
|
| 78 |
+
query = self.getEmbedding(text)
|
| 79 |
+
rag = self.db.search(query, 5)
|
| 80 |
|
| 81 |
+
return {"role": "system", "content": self.prompt.render(model_role_name=self.model_role_name, model_role_nickname=self.model_role_nickname, role_name=role_name, role_nickname=role_nick_name, RAG=rag)}
|
|
|
|
|
|
|
| 82 |
|
| 83 |
+
def chat(self, text: str, user_role_name: str, user_role_nick_name: str = None, use_local_model=False):
|
| 84 |
+
message = [self.getSystemPrompt(text,
|
| 85 |
+
user_role_name, user_role_nick_name)] + self.history
|
| 86 |
+
print(message)
|
| 87 |
if use_local_model:
|
| 88 |
response = self.model.get_response(message)
|
| 89 |
else:
|
| 90 |
response = self.client.chat(
|
| 91 |
user_role_name, text, user_role_nick_name)
|
| 92 |
|
| 93 |
+
self.history.append(
|
| 94 |
+
{"role": "user", "content": f"{user_role_name}:「{text}」"})
|
| 95 |
+
self.history.append(
|
| 96 |
+
{"role": "assistant", "content": f"{self.model_role_name}:「{response}」"})
|
| 97 |
return response
|
ChatWorld/NaiveDB.py
CHANGED
|
@@ -81,5 +81,7 @@ class NaiveDB:
|
|
| 81 |
similarities.sort(key=lambda x: x[0], reverse=True)
|
| 82 |
self.last_search_ids = [x[1] for x in similarities[:n_results]]
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
|
|
|
|
|
|
|
|
| 81 |
similarities.sort(key=lambda x: x[0], reverse=True)
|
| 82 |
self.last_search_ids = [x[1] for x in similarities[:n_results]]
|
| 83 |
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
top_stories = [self.stories[_id] for _id in self.last_search_ids]
|
| 87 |
+
return top_stories
|
ChatWorld/models.py
CHANGED
|
@@ -1,4 +1,7 @@
|
|
|
|
|
|
|
|
| 1 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
|
|
| 2 |
|
| 3 |
|
| 4 |
class qwen_model:
|
|
@@ -11,7 +14,9 @@ class qwen_model:
|
|
| 11 |
def get_response(self, message):
|
| 12 |
message = self.tokenizer.apply_chat_template(
|
| 13 |
message, tokenize=False, add_generation_prompt=True)
|
| 14 |
-
|
|
|
|
|
|
|
| 15 |
generated_ids = self.model.generate(
|
| 16 |
model_inputs.input_ids,
|
| 17 |
max_new_tokens=512
|
|
@@ -22,4 +27,42 @@ class qwen_model:
|
|
| 22 |
|
| 23 |
response = self.tokenizer.batch_decode(
|
| 24 |
generated_ids, skip_special_tokens=True)[0]
|
|
|
|
| 25 |
return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from string import Template
|
| 3 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 4 |
+
from zhipuai import ZhipuAI
|
| 5 |
|
| 6 |
|
| 7 |
class qwen_model:
|
|
|
|
| 14 |
def get_response(self, message):
|
| 15 |
message = self.tokenizer.apply_chat_template(
|
| 16 |
message, tokenize=False, add_generation_prompt=True)
|
| 17 |
+
print(message)
|
| 18 |
+
model_inputs = self.tokenizer(
|
| 19 |
+
[message], return_tensors="pt").to(self.model.device)
|
| 20 |
generated_ids = self.model.generate(
|
| 21 |
model_inputs.input_ids,
|
| 22 |
max_new_tokens=512
|
|
|
|
| 27 |
|
| 28 |
response = self.tokenizer.batch_decode(
|
| 29 |
generated_ids, skip_special_tokens=True)[0]
|
| 30 |
+
|
| 31 |
return response
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class GLM():
|
| 35 |
+
def __init__(self, model_name="silk-road/Haruhi-Zero-GLM3-6B-0_4"):
|
| 36 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 37 |
+
model_name, trust_remote_code=True)
|
| 38 |
+
client = AutoModelForCausalLM.from_pretrained(
|
| 39 |
+
model_name, trust_remote_code=True, device_map="auto")
|
| 40 |
+
|
| 41 |
+
client = client.eval()
|
| 42 |
+
|
| 43 |
+
def message2query(messages) -> str:
|
| 44 |
+
# [{'role': 'user', 'content': '老师: 同学请自我介绍一下'}]
|
| 45 |
+
# <|system|>
|
| 46 |
+
# You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
|
| 47 |
+
# <|user|>
|
| 48 |
+
# Hello
|
| 49 |
+
# <|assistant|>
|
| 50 |
+
# Hello, I'm ChatGLM3. What can I assist you today?
|
| 51 |
+
template = Template("<|$role|>\n$content\n")
|
| 52 |
+
|
| 53 |
+
return "".join([template.substitute(message) for message in messages])
|
| 54 |
+
|
| 55 |
+
def get_response(self, message):
|
| 56 |
+
response, history = self.client.chat(self.tokenizer, message)
|
| 57 |
+
return response
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class GLM_api:
|
| 61 |
+
def __init__(self, model_name="glm-4"):
|
| 62 |
+
self.client = ZhipuAI(api_key=os.environ["ZHIPU_API_KEY"])
|
| 63 |
+
self.model = model_name
|
| 64 |
+
|
| 65 |
+
def getResponse(self, message):
|
| 66 |
+
response = self.client.chat.completions.create(
|
| 67 |
+
model=self.model, prompt=message)
|
| 68 |
+
return response.choices[0].message
|
app.py
CHANGED
|
@@ -38,8 +38,8 @@ def getContent(input_file):
|
|
| 38 |
|
| 39 |
def submit_message(message, history, model_role_name, role_name, model_role_nickname, role_nickname):
|
| 40 |
chatWorld.setRoleName(model_role_name, model_role_nickname)
|
| 41 |
-
response = chatWorld.chat(
|
| 42 |
-
|
| 43 |
return response
|
| 44 |
|
| 45 |
|
|
|
|
| 38 |
|
| 39 |
def submit_message(message, history, model_role_name, role_name, model_role_nickname, role_nickname):
|
| 40 |
chatWorld.setRoleName(model_role_name, model_role_nickname)
|
| 41 |
+
response = chatWorld.chat(message,
|
| 42 |
+
role_name, role_nickname, use_local_model=True)
|
| 43 |
return response
|
| 44 |
|
| 45 |
|
run_gradio.sh
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
export CUDA_VISIBLE_DEVICES=0
|
| 2 |
export HF_HOME="/workspace/jyh/.cache/huggingface"
|
|
|
|
| 3 |
|
| 4 |
# Start the gradio server
|
| 5 |
/workspace/jyh/miniconda3/envs/ChatWorld/bin/python /workspace/jyh/Zero-Haruhi/app.py
|
|
|
|
| 1 |
export CUDA_VISIBLE_DEVICES=0
|
| 2 |
export HF_HOME="/workspace/jyh/.cache/huggingface"
|
| 3 |
+
export HF_ENDPOINT="https://hf-mirror.com"
|
| 4 |
|
| 5 |
# Start the gradio server
|
| 6 |
/workspace/jyh/miniconda3/envs/ChatWorld/bin/python /workspace/jyh/Zero-Haruhi/app.py
|