# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict, List, Sequence
import torch
from ..utils import logging
from .constants import IGNORE_INDEX
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
# Module-level logger obtained from the project's logging helper.
logger = logging.get_logger(__name__)
# Role names declared as supported.
# NOTE(review): this list is not referenced by any template in the visible part of this file.
ROLE_SUPPORTED = ["system", "user", "assistant", "tool"]
class ChatTemplate(ABC):
    """
    Abstract class for chat template.

    A chat template wraps a tokenizer and defines how a list of messages is
    encoded into model inputs (``encode_messages``) and which jinja template
    string is stored on the tokenizer when it is saved (``get_jinja_template``).
    """

    def __init__(self, tokenizer: "PreTrainedTokenizer") -> None:
        """
        Stores the tokenizer used for encoding and saving.
        """
        self.tokenizer = tokenizer

    def save_pretrained(self, output_dir: str) -> None:
        """
        Sets the tokenizer's chat template and saves the tokenizer to output_dir.

        Saving is best-effort: a failure is logged as a warning (including the
        underlying error) and is not re-raised.
        """
        self.tokenizer.chat_template = self.get_jinja_template()
        try:
            self.tokenizer.save_pretrained(output_dir)
        except Exception as exc:
            # Keep the original non-fatal behaviour, but report what went wrong and where.
            logger.warning(f"Failed to save tokenizer to {output_dir}: {exc!r}")

    @abstractmethod
    def encode_messages(self, messages: Sequence[Dict[str, str]], max_seq_len: int = 8192) -> Dict[str, List[int]]:
        """
        Encodes messages to a dictionary of input_ids, attention_mask, and labels.
        """
        ...

    @abstractmethod
    def get_jinja_template(self) -> str:
        """
        Gets the jinja template for the chat template.
        """
        ...
class DefaultTemplate(ChatTemplate):
    """
    Plain-text template: every turn is rendered as the title-cased role, a colon
    and a space, the stripped content, the tokenizer's eos token and a newline.
    """

    def encode_messages(self, messages: Sequence[Dict[str, str]], max_seq_len: int = 8192) -> Dict[str, List[int]]:
        """
        Encodes messages to a dictionary of input_ids, attention_mask, and labels.

        Turns whose ``loss_mask`` equals 1 contribute their token ids to the
        labels; all other turns contribute IGNORE_INDEX. Only the last
        ``max_seq_len`` entries of each list are returned.
        """
        token_ids: List[int] = []
        targets: List[int] = []
        for turn in messages:
            rendered = "".join(
                (turn["role"].title(), ": ", turn["content"].strip(), self.tokenizer.eos_token, "\n")
            )
            turn_ids = self.tokenizer.encode(rendered, add_special_tokens=False)
            token_ids.extend(turn_ids)
            supervised = turn["loss_mask"] == 1
            targets.extend(turn_ids if supervised else [IGNORE_INDEX] * len(turn_ids))

        # Every token is attended to, so the mask is all ones with the same length as the ids.
        tail = slice(-max_seq_len, None)
        return {
            "input_ids": token_ids[tail],
            "attention_mask": ([1] * len(token_ids))[tail],
            "labels": targets[tail],
        }

    def get_jinja_template(self) -> str:
        """
        Gets the jinja template for the chat template.
        """
        pieces = [
            "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}",
            "{% for message in messages %}",
            "{{ message['role'].title() + ': ' + message['content'] | trim + eos_token + '\n' }}",
            "{% endfor %}",
            "{% if add_generation_prompt %}{{ 'Assistant: ' }}{% endif %}",
        ]
        return "".join(pieces)
class Llama2Template(ChatTemplate):
    """
    Llama-2 style chat template.

    System content is wrapped in the ``<<SYS>>`` / ``<</SYS>>`` markers, user
    content in ``[INST] ... [/INST]`` and tool content in ``[TOOL] ... [/TOOL]``
    (both preceded by the tokenizer's bos token); assistant content is followed
    by the tokenizer's eos token.
    """

    def encode_messages(self, messages: Sequence[Dict[str, str]], max_seq_len: int = 8192) -> Dict[str, List[int]]:
        """
        Encodes messages to a dictionary of input_ids, attention_mask, and labels.

        Turns whose ``loss_mask`` equals 1 contribute their token ids to the
        labels; all other turns contribute IGNORE_INDEX. Only the last
        ``max_seq_len`` entries of each list are returned.

        Raises:
            ValueError: If a message role is not system, user, assistant or tool.
        """
        input_ids, attention_mask, labels = [], [], []
        for message in messages:
            role = message["role"]
            if role == "system":
                # Llama 2 system-prompt markers; a bare "<>" is not part of the Llama 2 prompt format.
                content_str = "<<SYS>>\n" + message["content"].strip() + "\n<</SYS>>\n\n"
            elif role == "user":
                content_str = self.tokenizer.bos_token + "[INST] " + message["content"].strip() + " [/INST]"
            elif role == "assistant":
                content_str = " " + message["content"].strip() + " " + self.tokenizer.eos_token
            elif role == "tool":
                content_str = self.tokenizer.bos_token + "[TOOL] " + message["content"].strip() + " [/TOOL]"
            else:
                raise ValueError(
                    f"Unknown role {message['role']}, should be one of {{system, user, assistant, tool}}."
                )

            content_ids = self.tokenizer.encode(content_str, add_special_tokens=False)
            input_ids += content_ids
            attention_mask += [1] * len(content_ids)
            if message["loss_mask"] == 1:
                labels += content_ids
            else:
                labels += [IGNORE_INDEX] * len(content_ids)

        model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
        # Keep the tail of the sequence when it is longer than max_seq_len.
        model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()}
        return model_inputs

    def get_jinja_template(self) -> str:
        """
        Gets the jinja template for the chat template.

        Uses the same system, user, tool and assistant wrappers as
        ``encode_messages``; messages with any other role render nothing.
        """
        return (
            "{% if messages[0]['role'] == 'system' %}"
            "{{ '<<SYS>>\n' + messages[0]['content'] | trim + '\n<</SYS>>\n\n' }}"
            "{% set loop_messages = messages[1:] %}"
            "{% else %}"
            "{% set loop_messages = messages %}"
            "{% endif %}"
            "{% for message in loop_messages %}"
            "{% set content = message['content'] %}"
            "{% if message['role'] == 'user' %}"
            "{{ bos_token + '[INST] ' + content | trim + ' [/INST]' }}"
            "{% elif message['role'] == 'tool' %}"
            "{{ bos_token + '[TOOL] ' + content | trim + ' [/TOOL]' }}"
            "{% elif message['role'] == 'assistant' %}"
            "{{ ' ' + content | trim + ' ' + eos_token }}"
            "{% endif %}"
            "{% endfor %}"
        )
class JanusTemplate(ChatTemplate):
    """
    Chat template for Janus-style data, producing image masks in addition to
    input_ids, attention_mask and labels.
    """

    def encode_messages(
        self, messages: Sequence[Dict[str, str]], max_seq_len: int = 8192, task_type: str = ""
    ) -> Dict[str, List[int]]:
        """
        Encodes messages to a dictionary of input_ids, attention_mask, labels,
        images_seq_mask and images_emb_mask.

        Roles are matched by substring ("assistant", "user", "system"). When
        ``task_type`` contains "wikihow_generation" or "interleave_generation",
        only the first assistant turn gets the "Assistant: " prefix and only the
        last message is terminated with the end-of-sentence separator.

        Args:
            messages: Messages with "role", "content" and "loss_mask" keys.
            max_seq_len: Only the last ``max_seq_len`` entries of every returned list are kept.
            task_type: Task name; ``None`` is treated as the empty string.

        Raises:
            ValueError: If a message with non-empty content has a role containing
                none of "assistant", "user" or "system".
        """
        input_ids, attention_mask, labels = [], [], []
        images_seq_mask, images_emb_mask = [], []
        seps = ["\n\n", "<|end▁of▁sentence|>"]
        task_type = task_type or ""
        is_generation_task = "wikihow_generation" in task_type or "interleave_generation" in task_type
        assistant_cnt = 0
        for idx, message in enumerate(messages):
            role = message["role"]
            if message["content"] == "":
                # Empty content: emit only the raw role followed by a colon.
                content_str = role + ":"
            elif "assistant" in role and is_generation_task:
                prefix = "Assistant: " if assistant_cnt == 0 else ""
                suffix = seps[1] if idx + 1 == len(messages) else seps[0]
                content_str = prefix + message["content"].strip() + suffix
                assistant_cnt += 1
            elif "assistant" in role:
                content_str = "Assistant" + ": " + message["content"].strip() + seps[1]
            elif "user" in role:
                content_str = "User" + ": " + message["content"].strip() + seps[0]
            elif "system" in role and "wikihow_generation" in task_type:
                content_str = (
                    message["content"].strip()
                    + seps[0]
                    + "Please generate a step-by-step tutorial with images for the following question."
                    + seps[0]
                )
            elif "system" in role:
                content_str = message["content"].strip() + seps[0]
            else:
                # Without this branch an unknown role either hit an unbound
                # variable or silently re-encoded the previous message's text.
                raise ValueError(
                    f"Unknown role {message['role']}, should contain one of {{system, user, assistant}}."
                )

            if "system" in role:
                # System turns use the tokenizer's default special-token handling.
                content_ids = self.tokenizer.encode(content_str)
            else:
                content_ids = self.tokenizer.encode(content_str, add_special_tokens=False)
            input_ids += content_ids
            attention_mask += [1] * len(content_ids)

            # NOTE(review): both vocab lookups below use an empty-string key as
            # written; confirm the intended image special-token string(s).
            image_token_id = self.tokenizer.vocab.get("")
            content_ids_tensor = torch.tensor(content_ids)
            images_seq_mask += (content_ids_tensor == image_token_id).tolist()

            image_token_id = self.tokenizer.vocab.get("")
            num_image_tokens = torch.sum(content_ids_tensor == image_token_id).item()
            # 576 is presumably the number of tokens per image -- TODO confirm.
            n_image = num_image_tokens // 576
            if n_image > 0:
                images_emb_mask.append([True] * num_image_tokens)

            if message["loss_mask"] == 1:
                if image_token_id in content_ids and not is_generation_task:
                    # Outside the generation tasks, a supervised turn containing
                    # image tokens keeps only those tokens as labels.
                    labels += [image_token_id if x == image_token_id else IGNORE_INDEX for x in content_ids]
                else:
                    labels += content_ids
            else:
                labels += [IGNORE_INDEX] * len(content_ids)

        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "images_seq_mask": images_seq_mask,
            "images_emb_mask": images_emb_mask,
        }
        # NOTE(review): images_emb_mask holds one list per message, so this slice
        # keeps its last max_seq_len per-message lists rather than truncating tokens.
        model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()}
        return model_inputs

    def get_jinja_template(self) -> str:
        """
        Gets the jinja template for the chat template.

        NOTE(review): this is an ``<|im_start|>`` / ``<|im_end|>`` template and does
        not mirror the "User:" / "Assistant:" layout built by ``encode_messages``;
        confirm this is intended.
        """
        return (
            "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}"
            "{% for message in messages %}"
            "{{ '<|im_start|>' + message['role'] + '\n' + message['content'] | trim + '<|im_end|>\n' }}"
            "{% endfor %}"
            "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
        )
class ChatmlTemplate(ChatTemplate):
    """
    ChatML template: every turn is an ``<|im_start|>`` header carrying the role,
    a newline, the stripped content, and an ``<|im_end|>`` marker plus newline.
    """

    def encode_messages(self, messages: Sequence[Dict[str, str]], max_seq_len: int = 8192) -> Dict[str, List[int]]:
        """
        Encodes messages to a dictionary of input_ids, attention_mask, and labels.

        Turns whose ``loss_mask`` equals 1 contribute their token ids to the
        labels; all other turns contribute IGNORE_INDEX. Only the last
        ``max_seq_len`` entries of each list are returned.
        """
        features: Dict[str, List[int]] = {"input_ids": [], "attention_mask": [], "labels": []}
        for turn in messages:
            header = "<|im_start|>" + turn["role"] + "\n"
            body = turn["content"].strip() + "<|im_end|>\n"
            piece = self.tokenizer.encode(header + body, add_special_tokens=False)

            features["input_ids"].extend(piece)
            features["attention_mask"].extend([1] * len(piece))
            if turn["loss_mask"] == 1:
                features["labels"].extend(piece)
            else:
                features["labels"].extend([IGNORE_INDEX] * len(piece))

        # Keep only the trailing max_seq_len entries of every feature list.
        for key in list(features):
            features[key] = features[key][-max_seq_len:]
        return features

    def get_jinja_template(self) -> str:
        """
        Gets the jinja template for the chat template.
        """
        pieces = [
            "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}",
            "{% for message in messages %}",
            "{{ '<|im_start|>' + message['role'] + '\n' + message['content'] | trim + '<|im_end|>\n' }}",
            "{% endfor %}",
            "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
        ]
        return "".join(pieces)
# Registry mapping a template name to its ChatTemplate subclass; build_chat_template
# looks names up here.
# NOTE(review): "Janus" is capitalised while the other keys are lowercase, and the
# lookup is an exact key match.
TEMPLATES = {
    "default": DefaultTemplate,
    "llama2": Llama2Template,
    "chatml": ChatmlTemplate,
    "Janus": JanusTemplate,
}
def build_chat_template(template_name: str, tokenizer: "PreTrainedTokenizer") -> "ChatTemplate":
    """
    Builds the chat template registered under template_name for the given tokenizer.

    Raises:
        ValueError: If template_name is not a key of TEMPLATES.
    """
    template_cls = TEMPLATES.get(template_name)
    if template_cls is None:
        raise ValueError(f"Unknown chat template: {template_name}")
    return template_cls(tokenizer)