YAML Metadata Warning: empty or missing yaml metadata in repo card (https://huggingface.co/docs/hub/model-cards#model-card-metadata)

🤭 Please refer to https://github.com/svjack/Genshin-Impact-Character-Chat to get more info

Install

pip install peft transformers bitsandbytes ipykernel rapidfuzz

Run by transformers

import json
from dataclasses import dataclass
from enum import Enum
from typing import List, Dict, Tuple, Literal

class Roles(Enum):
    system = "system"
    user = "user"
    assistant = "assistant"
    tool = "tool"

class MessagesFormatterType(Enum):
    """
    Enum representing different types of predefined messages formatters.
    """

    MISTRAL = 1

@dataclass
class PromptMarkers:
    start: str
    end: str

class MessagesFormatter:
    def __init__(
            self,
            pre_prompt: str,
            prompt_markers: Dict[Roles, PromptMarkers],
            include_sys_prompt_in_first_user_message: bool,
            default_stop_sequences: List[str],
            use_user_role_for_function_call_result: bool = True,
            strip_prompt: bool = True,
            bos_token: str = "<s>",
            eos_token: str = "</s>"
    ):
        self.pre_prompt = pre_prompt
        self.prompt_markers = prompt_markers
        self.include_sys_prompt_in_first_user_message = include_sys_prompt_in_first_user_message
        self.default_stop_sequences = default_stop_sequences
        self.use_user_role_for_function_call_result = use_user_role_for_function_call_result
        self.strip_prompt = strip_prompt
        self.bos_token = bos_token
        self.eos_token = eos_token
        self.added_system_prompt = False

    def get_bos_token(self) -> str:
        return self.bos_token

    def format_conversation(
            self,
            messages: List[Dict[str, str]],
            response_role: Literal[Roles.user, Roles.assistant] | None = None,
    ) -> Tuple[str, Roles]:
        formatted_messages = self.pre_prompt
        last_role = Roles.assistant
        self.added_system_prompt = False
        for message in messages:
            role = Roles(message["role"])
            content = self._format_message_content(message["content"], role)

            if role == Roles.system:
                formatted_messages += self._format_system_message(content)
                last_role = Roles.system
            elif role == Roles.user:
                formatted_messages += self._format_user_message(content)
                last_role = Roles.user
            elif role == Roles.assistant:
                formatted_messages += self._format_assistant_message(content)
                last_role = Roles.assistant
            elif role == Roles.tool:
                formatted_messages += self._format_tool_message(content)
                last_role = Roles.tool

        return self._format_response(formatted_messages, last_role, response_role)

    def _format_message_content(self, content: str, role: Roles) -> str:
        if self.strip_prompt:
            return content.strip()
        return content

    def _format_system_message(self, content: str) -> str:
        formatted_message = self.prompt_markers[Roles.system].start + content + self.prompt_markers[Roles.system].end
        self.added_system_prompt = True
        if self.include_sys_prompt_in_first_user_message:
            formatted_message = self.prompt_markers[Roles.user].start + formatted_message
        return formatted_message

    def _format_user_message(self, content: str) -> str:
        if self.include_sys_prompt_in_first_user_message and self.added_system_prompt:
            self.added_system_prompt = False
            return content + self.prompt_markers[Roles.user].end
        return self.prompt_markers[Roles.user].start + content + self.prompt_markers[Roles.user].end

    def _format_assistant_message(self, content: str) -> str:
        return self.prompt_markers[Roles.assistant].start + content + self.prompt_markers[Roles.assistant].end

    def _format_tool_message(self, content: str) -> str:
        if isinstance(content, list):
            content = "\n".join(json.dumps(m, indent=2) for m in content)
        if self.use_user_role_for_function_call_result:
            return self._format_user_message(content)
        else:
            return self.prompt_markers[Roles.tool].start + content + self.prompt_markers[Roles.tool].end

    def _format_response(
            self,
            formatted_messages: str,
            last_role: Roles,
            response_role: Literal[Roles.user, Roles.assistant] | None = None,
    ) -> Tuple[str, Roles]:
        if response_role is None:
            response_role = Roles.assistant if last_role != Roles.assistant else Roles.user

        prompt_start = self.prompt_markers[response_role].start.strip() if self.strip_prompt else self.prompt_markers[
            response_role].start
        return formatted_messages + prompt_start, response_role

mixtral_prompt_markers = {
    Roles.system: PromptMarkers("", """\n\n"""),
    Roles.user: PromptMarkers("""[INST] """, """ [/INST]"""),
    Roles.assistant: PromptMarkers("""""", """</s>"""),
    Roles.tool: PromptMarkers("", ""),
}

mixtral_formatter = MessagesFormatter(
    "",
    mixtral_prompt_markers,
    True,
    ["</s>"],
)

from transformers import TextStreamer, AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
tokenizer = AutoTokenizer.from_pretrained("svjack/DPO_Genshin_Impact_Mistral_Plot_Engine_Step_Json_Short_merged",)
mis_model = AutoModelForCausalLM.from_pretrained("svjack/DPO_Genshin_Impact_Mistral_Plot_Engine_Step_Json_Short_merged", load_in_4bit = True)
mis_model = mis_model.eval()

streamer = TextStreamer(tokenizer)

def mistral_hf_predict(messages, mis_model = mis_model,
    tokenizer = tokenizer, streamer = streamer,
    do_sample = True,
    top_p = 0.95,
    top_k = 40,
    max_new_tokens = 512,
    max_input_length = 3500,
    temperature = 0.9,
    repetition_penalty = 1.0,
    device = "cuda"):

    #encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
    #model_inputs = encodeds.to(device)
    prompt, _ = mixtral_formatter.format_conversation(messages)
    model_inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)

    generated_ids = mis_model.generate(model_inputs, max_new_tokens=max_new_tokens,
                                do_sample=do_sample,
                                  streamer = streamer,
                                  top_p = top_p,
                                  top_k = top_k,
                                  temperature = temperature,
                                  repetition_penalty = repetition_penalty,
                                  )
    out = tokenizer.batch_decode(generated_ids)[0].split("[/INST]")[-1].replace("</s>", "").strip()
    return out

from rapidfuzz import fuzz 
from IPython.display import clear_output
def run_step_infer_times(x, times = 5, temperature = 0.01,
                        repetition_penalty = 1.0,
                        sim_val = 70
                        ):
    req = []
    for _ in range(times):
        clear_output(wait = True)
        out = mistral_hf_predict([
                {
                    "role": "system",
                    "content": ""
                },
                {
                    "role": "user",
                    "content": x
                },
            ],
            repetition_penalty = repetition_penalty,
            temperature = temperature,
            max_new_tokens = 2070,
            max_input_length = 6000,
        )
        if req:
            val = max(map(lambda x: fuzz.ratio(x, out), req))
            #print(val)
            #print(req)
            if val < sim_val:
                req.append(out.strip())
            x = x.strip() + "\n" + out.strip()
        else:
            req.append(out.strip())
            x = x.strip() + "\n" + out.strip()
    return req

out_l = run_step_infer_times(
'''
故事标题:归乡
故事背景:在须弥城门口,派蒙与纳西妲偶遇并帮助一只昏迷的元素生命找寻家园。过程中揭示了这只生物并非普通的蕈兽,而是元素生物,并且它们曾受到过‘末日’的影响,家园被侵蚀。纳西妲回忆起晶体里的力量可能与一个预言有关,为了拯救它们的家园,她必须解决‘禁忌知识’问题,但这个过程对她自身也会产生干扰。
参与角色:派蒙、纳西妲、浮游水蕈兽、旅行者
''',
    temperature=0.1,
    repetition_penalty = 1.0,
    times = 10
)
clear_output(wait = True)

print("\n".join(out_l))

Output

{'参与者1': '派蒙', '参与者2': '纳西妲', '当前故事背景': '在须弥城门口,派蒙发现了一只昏迷的浮游水蕈兽,并询问它的家园。纳西妲确认这是元素生物,并解释了它们的特殊性和‘末日’的影响。纳西妲提出要帮助它们找到家园,但这需要解决‘禁忌知识’问题,这对她个人也有所挑战。'}
{'参与者1': '派蒙', '参与者2': '浮游水蕈兽', '当前故事背景': '派蒙询问浮游水蕈兽是否记得家园,它似乎对此感到疑惑,并且纳西妲解释了它们的特殊性和‘末日’后的变化。'}
{'参与者1': '纳西妲', '参与者2': '旅行者', '当前故事背景': '纳西妲提到‘禁忌知识’与她的力量有关,她需要解决这个问题以帮助元素生物。旅行者对此表示愿意协助,并在后续的对话中提到了‘晶体’的力量可能与纳西妲的预言有关。'}
{'参与者1': '纳西妲', '参与者2': '派蒙', '当前故事背景': '纳西妲提到‘晶体’中的力量可能与她的预言有联系,这让派蒙感到惊讶,并询问具体内容。'}
{'参与者1': '纳西妲', '参与者2': '旅行者', '当前故事背景': '纳西妲提出要去寻找‘禁忌知识’,旅行者表示愿意帮忙,两人准备前往目标地点。'}
{'参与者1': '派蒙', '参与者2': '纳西妲', '当前故事背景': '在寻找过程中,派蒙对纳西妲的力量和预言感到疑惑,纳西妲解释了这是为了帮助元素生物。'}
{'参与者1': '旅行者', '参与者2': '纳西妲', '当前故事背景': '旅行者表示愿意帮助解决‘禁忌知识’问题,两人的合作关系在故事中得到了展现。'}
Downloads last month
-
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support