| | import re |
| | from transformers import AutoTokenizer |
| |
|
def extract_separators(template):
    """
    Pull the literal separator expressions that directly precede
    ``message['content']`` in a Jinja chat template string.
    """
    prefix_pattern = r"\{\{\s*([^{}]+?)\s*\+ message\['content'\]"
    found = [item.strip() for item in re.findall(prefix_pattern, template)]

    # When the separator embeds the role expression, expand it into one
    # concrete separator per conversation role (quotes removed).
    if any("message['role']" in item for item in found):
        return [
            found[0].replace(" + message['role'] + ", role).replace("'", "")
            for role in ("system", "user", "assistant")
        ]

    return found
| |
|
def detect_eos_token(jinja_template, tokenizer):
    """
    Guess the end-of-sequence marker used by a chat template.

    Known literal markers win; a template that mentions ``eos_token`` or
    ``<|endoftext|>`` falls back to the tokenizer's own EOS token, and
    anything else defaults to the literal ``<|endoftext|>`` string.
    """
    # Literal markers take precedence, checked in the original priority order.
    for marker in ("<|im_end|>", "</s>"):
        if marker in jinja_template:
            return marker
    if "eos_token" in jinja_template or "<|endoftext|>" in jinja_template:
        return tokenizer.eos_token
    return "<|endoftext|>"
| |
|
def recover_messages(formatted_message, separators, eos_token):
    """
    Reconstruct the original list of chat message dicts from a
    template-formatted conversation string.
    """
    chunks = formatted_message.split(eos_token)

    # Discard a trailing chunk that is only whitespace (text after the last EOS).
    if chunks and not chunks[-1].strip():
        chunks.pop()

    turn_roles = ("user", "assistant")
    messages = []

    for position, raw_chunk in enumerate(chunks):
        # First chunk is the system prompt; the rest alternate user/assistant.
        role = "system" if position == 0 else turn_roles[(position - 1) % 2]

        # Peel off each known separator (quotes stripped) from the chunk once.
        content = raw_chunk.strip()
        for sep in separators:
            content = content.replace(sep.strip("'"), "", 1).strip()

        messages.append({"role": role, "content": content})

    return messages
| |
|
def recover_chat_messages(tokenized_chat, tokenizer):
    """
    Given a tokenized_chat string and a tokenizer, returns the list of
    message dictionaries recovered from the formatted text.
    """
    jinja_template = tokenizer.chat_template
    separators = extract_separators(jinja_template)
    # Fix: the original had a duplicated assignment ("eos_token = eos_token = ...").
    eos_token = detect_eos_token(jinja_template, tokenizer)
    return recover_messages(tokenized_chat, separators, eos_token)
| |
|
| | |
if __name__ == "__main__":
    # Demo: round-trip a conversation through a real chat template.
    checkpoint = "Qwen/Qwen1.5-0.5B"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    demo_messages = [
        {
            "role": "system",
            "content": "You are a friendly chatbot who always responds in the style of a pirate",
        },
        {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
    ]

    formatted = tokenizer.apply_chat_template(demo_messages, tokenize=False)
    print(formatted)

    # Recover the original message dicts from the formatted string.
    print(recover_chat_messages(formatted, tokenizer))
| |
|