🤭 Please refer to https://github.com/svjack/Genshin-Impact-Character-Chat for more information.
Install
pip install peft transformers bitsandbytes ipykernel rapidfuzz
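The model below is loaded in 4-bit via bitsandbytes, so a CUDA-capable GPU is expected. A minimal environment sanity check (a sketch, not part of the original card):

import torch
import transformers, peft, bitsandbytes, rapidfuzz

print("transformers:", transformers.__version__)
print("CUDA available:", torch.cuda.is_available())  # needed for the 4-bit load below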
Run with transformers
import json
from dataclasses import dataclass
from enum import Enum
from typing import List, Dict, Tuple, Literal
class Roles(Enum):
system = "system"
user = "user"
assistant = "assistant"
tool = "tool"
class MessagesFormatterType(Enum):
"""
Enum representing different types of predefined messages formatters.
"""
MISTRAL = 1
@dataclass
class PromptMarkers:
start: str
end: str
class MessagesFormatter:
def __init__(
self,
pre_prompt: str,
prompt_markers: Dict[Roles, PromptMarkers],
include_sys_prompt_in_first_user_message: bool,
default_stop_sequences: List[str],
use_user_role_for_function_call_result: bool = True,
strip_prompt: bool = True,
bos_token: str = "<s>",
eos_token: str = "</s>"
):
self.pre_prompt = pre_prompt
self.prompt_markers = prompt_markers
self.include_sys_prompt_in_first_user_message = include_sys_prompt_in_first_user_message
self.default_stop_sequences = default_stop_sequences
self.use_user_role_for_function_call_result = use_user_role_for_function_call_result
self.strip_prompt = strip_prompt
self.bos_token = bos_token
self.eos_token = eos_token
self.added_system_prompt = False
def get_bos_token(self) -> str:
return self.bos_token
def format_conversation(
self,
messages: List[Dict[str, str]],
response_role: Literal[Roles.user, Roles.assistant] | None = None,
) -> Tuple[str, Roles]:
formatted_messages = self.pre_prompt
last_role = Roles.assistant
self.added_system_prompt = False
for message in messages:
role = Roles(message["role"])
content = self._format_message_content(message["content"], role)
if role == Roles.system:
formatted_messages += self._format_system_message(content)
last_role = Roles.system
elif role == Roles.user:
formatted_messages += self._format_user_message(content)
last_role = Roles.user
elif role == Roles.assistant:
formatted_messages += self._format_assistant_message(content)
last_role = Roles.assistant
elif role == Roles.tool:
formatted_messages += self._format_tool_message(content)
last_role = Roles.tool
return self._format_response(formatted_messages, last_role, response_role)
def _format_message_content(self, content: str, role: Roles) -> str:
if self.strip_prompt:
return content.strip()
return content
def _format_system_message(self, content: str) -> str:
formatted_message = self.prompt_markers[Roles.system].start + content + self.prompt_markers[Roles.system].end
self.added_system_prompt = True
if self.include_sys_prompt_in_first_user_message:
formatted_message = self.prompt_markers[Roles.user].start + formatted_message
return formatted_message
def _format_user_message(self, content: str) -> str:
if self.include_sys_prompt_in_first_user_message and self.added_system_prompt:
self.added_system_prompt = False
return content + self.prompt_markers[Roles.user].end
return self.prompt_markers[Roles.user].start + content + self.prompt_markers[Roles.user].end
def _format_assistant_message(self, content: str) -> str:
return self.prompt_markers[Roles.assistant].start + content + self.prompt_markers[Roles.assistant].end
def _format_tool_message(self, content: str) -> str:
if isinstance(content, list):
content = "\n".join(json.dumps(m, indent=2) for m in content)
if self.use_user_role_for_function_call_result:
return self._format_user_message(content)
else:
return self.prompt_markers[Roles.tool].start + content + self.prompt_markers[Roles.tool].end
def _format_response(
self,
formatted_messages: str,
last_role: Roles,
response_role: Literal[Roles.user, Roles.assistant] | None = None,
) -> Tuple[str, Roles]:
if response_role is None:
response_role = Roles.assistant if last_role != Roles.assistant else Roles.user
        marker_start = self.prompt_markers[response_role].start
        prompt_start = marker_start.strip() if self.strip_prompt else marker_start
return formatted_messages + prompt_start, response_role
mixtral_prompt_markers = {
Roles.system: PromptMarkers("", """\n\n"""),
Roles.user: PromptMarkers("""[INST] """, """ [/INST]"""),
Roles.assistant: PromptMarkers("""""", """</s>"""),
Roles.tool: PromptMarkers("", ""),
}
mixtral_formatter = MessagesFormatter(
"",
mixtral_prompt_markers,
True,
["</s>"],
)
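As a quick check of the formatter above (a minimal sketch with made-up messages), format_conversation folds the system prompt into the first user turn and returns the Mistral-style prompt together with the role expected to answer next:

demo_prompt, demo_role = mixtral_formatter.format_conversation([
    {"role": "system", "content": "You are a plot engine."},
    {"role": "user", "content": "故事标题:归乡"},
])
print(repr(demo_prompt))  # '[INST] You are a plot engine.\n\n故事标题:归乡 [/INST]'
print(demo_role)          # Roles.assistant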
from transformers import TextStreamer, AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
tokenizer = AutoTokenizer.from_pretrained("svjack/DPO_Genshin_Impact_Mistral_Plot_Engine_Step_Json_Short_merged")
mis_model = AutoModelForCausalLM.from_pretrained(
    "svjack/DPO_Genshin_Impact_Mistral_Plot_Engine_Step_Json_Short_merged",
    load_in_4bit=True,
)
mis_model = mis_model.eval()
streamer = TextStreamer(tokenizer)
def mistral_hf_predict(messages, mis_model = mis_model,
tokenizer = tokenizer, streamer = streamer,
do_sample = True,
top_p = 0.95,
top_k = 40,
max_new_tokens = 512,
max_input_length = 3500,
temperature = 0.9,
repetition_penalty = 1.0,
device = "cuda"):
#encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
#model_inputs = encodeds.to(device)
prompt, _ = mixtral_formatter.format_conversation(messages)
    # keep only the last max_input_length tokens so the prompt fits the context window
    model_inputs = tokenizer.encode(prompt, return_tensors="pt")[:, -max_input_length:].to(device)
generated_ids = mis_model.generate(model_inputs, max_new_tokens=max_new_tokens,
do_sample=do_sample,
streamer = streamer,
top_p = top_p,
top_k = top_k,
temperature = temperature,
repetition_penalty = repetition_penalty,
)
out = tokenizer.batch_decode(generated_ids)[0].split("[/INST]")[-1].replace("</s>", "").strip()
return out
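With the pieces above in place, a single generation call looks like this (a minimal sketch; the user message is any scenario text you want the plot engine to expand):

reply = mistral_hf_predict([
    {"role": "system", "content": ""},
    {"role": "user", "content": "故事标题:归乡\n参与角色:派蒙、纳西妲、旅行者"},
], temperature=0.1, max_new_tokens=512)
print(reply)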
from rapidfuzz import fuzz
from IPython.display import clear_output
def run_step_infer_times(x, times = 5, temperature = 0.01,
repetition_penalty = 1.0,
sim_val = 70
):
req = []
for _ in range(times):
clear_output(wait = True)
out = mistral_hf_predict([
{
"role": "system",
"content": ""
},
{
"role": "user",
"content": x
},
],
repetition_penalty = repetition_penalty,
temperature = temperature,
max_new_tokens = 2070,
max_input_length = 6000,
)
if req:
val = max(map(lambda x: fuzz.ratio(x, out), req))
#print(val)
#print(req)
if val < sim_val:
req.append(out.strip())
x = x.strip() + "\n" + out.strip()
else:
req.append(out.strip())
x = x.strip() + "\n" + out.strip()
return req
out_l = run_step_infer_times(
'''
故事标题:归乡
故事背景:在须弥城门口,派蒙与纳西妲偶遇并帮助一只昏迷的元素生命找寻家园。过程中揭示了这只生物并非普通的蕈兽,而是元素生物,并且它们曾受到过‘末日’的影响,家园被侵蚀。纳西妲回忆起晶体里的力量可能与一个预言有关,为了拯救它们的家园,她必须解决‘禁忌知识’问题,但这个过程对她自身也会产生干扰。
参与角色:派蒙、纳西妲、浮游水蕈兽、旅行者
''',
temperature=0.1,
repetition_penalty = 1.0,
times = 10
)
clear_output(wait = True)
print("\n".join(out_l))
Output
{'参与者1': '派蒙', '参与者2': '纳西妲', '当前故事背景': '在须弥城门口,派蒙发现了一只昏迷的浮游水蕈兽,并询问它的家园。纳西妲确认这是元素生物,并解释了它们的特殊性和‘末日’的影响。纳西妲提出要帮助它们找到家园,但这需要解决‘禁忌知识’问题,这对她个人也有所挑战。'}
{'参与者1': '派蒙', '参与者2': '浮游水蕈兽', '当前故事背景': '派蒙询问浮游水蕈兽是否记得家园,它似乎对此感到疑惑,并且纳西妲解释了它们的特殊性和‘末日’后的变化。'}
{'参与者1': '纳西妲', '参与者2': '旅行者', '当前故事背景': '纳西妲提到‘禁忌知识’与她的力量有关,她需要解决这个问题以帮助元素生物。旅行者对此表示愿意协助,并在后续的对话中提到了‘晶体’的力量可能与纳西妲的预言有关。'}
{'参与者1': '纳西妲', '参与者2': '派蒙', '当前故事背景': '纳西妲提到‘晶体’中的力量可能与她的预言有联系,这让派蒙感到惊讶,并询问具体内容。'}
{'参与者1': '纳西妲', '参与者2': '旅行者', '当前故事背景': '纳西妲提出要去寻找‘禁忌知识’,旅行者表示愿意帮忙,两人准备前往目标地点。'}
{'参与者1': '派蒙', '参与者2': '纳西妲', '当前故事背景': '在寻找过程中,派蒙对纳西妲的力量和预言感到疑惑,纳西妲解释了这是为了帮助元素生物。'}
{'参与者1': '旅行者', '参与者2': '纳西妲', '当前故事背景': '旅行者表示愿意帮助解决‘禁忌知识’问题,两人的合作关系在故事中得到了展现。'}
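Each element of out_l is a Python-style dict literal (single quotes, so json.loads will not accept it). A minimal sketch for turning the generated steps into dicts, assuming the model keeps this format:

import ast

steps = [ast.literal_eval(line) for line in out_l]
for step in steps:
    # every step names two participants (参与者1/参与者2) and a short background summary (当前故事背景)
    print(step["参与者1"], "&", step["参与者2"], "->", step["当前故事背景"])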