|
|
from dataclasses import dataclass |
|
|
from typing import Literal, Optional, Tuple |
|
|
|
|
|
from transformers import PreTrainedModel, PreTrainedTokenizer |
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class ChatMlSpecialTokens:
    """Special tokens for the ChatML conversation format (system/user/assistant
    headers, plus bos, eos, and pad tokens)."""

    # Marker that opens every turn, immediately followed by the role name.
    bos_token: str = "<|im_start|>"
    # Marker that closes every turn.
    eos_token: str = "<|im_end|>"
    # ChatML reuses the end-of-turn marker as the padding token.
    pad_token: str = "<|im_end|>"

    @property
    def system(self):
        """Header string opening a system turn."""
        return self.bos_token + "system"

    @property
    def user(self):
        """Header string opening a user turn."""
        return self.bos_token + "user"

    @property
    def assistant(self):
        """Header string opening an assistant turn."""
        return self.bos_token + "assistant"

    @property
    def chat_template(self):
        """Jinja chat template rendering each message as
        ``<bos>role\\ncontent<eos>\\n`` and, when ``add_generation_prompt`` is
        set, appending an assistant header to prompt a completion."""
        # Quadruple braces emit literal Jinja "{{ ... }}" expression delimiters.
        loop_body = (
            f"{{{{'{self.bos_token}' + message['role'] + '\n' + "
            f"message['content'] + '{self.eos_token}' + '\n'}}}}"
        )
        generation_prompt = f"{{{{ '{self.assistant}\n' }}}}"
        return "".join(
            [
                "{% for message in messages %}",
                loop_body,
                "{% endfor %}",
                "{% if add_generation_prompt %}",
                generation_prompt,
                "{% endif %}",
            ]
        )
|
|
|
|
|
|
|
|
# Registry of supported chat formats -> their special-token dataclasses.
# setup_chat_format validates its `format` argument against this mapping;
# register new entries here to support additional formats.
FORMAT_MAPPING = {"chatml": ChatMlSpecialTokens}
|
|
|
|
|
|
|
|
def setup_chat_format(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    format: Optional[Literal["chatml"]] = "chatml",
    resize_to_multiple_of: Optional[int] = None,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """
    Setup chat format by adding special tokens to the tokenizer, setting the correct
    chat template, and extending the embedding layer of the model to fit the new tokens.

    Args:
        model (`~transformers.PreTrainedModel`): The model to be modified.
        tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified.
        format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml".
        resize_to_multiple_of (`Optional[int]`): Number to pad the embedding matrix size
            to a multiple of (useful for hardware efficiency). Defaults to None (no padding).

    Returns:
        model (`~transformers.PreTrainedModel`): The modified model.
        tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer.

    Raises:
        ValueError: If `format` is not a key of `FORMAT_MAPPING`.
    """
    if format not in FORMAT_MAPPING:
        raise ValueError(f"Format {format} not available. Please use one of {FORMAT_MAPPING.keys()}")

    chat_format = FORMAT_MAPPING[format]()

    # Register the format's special tokens and chat template on the tokenizer.
    tokenizer.eos_token = chat_format.eos_token
    tokenizer.pad_token = chat_format.pad_token
    tokenizer.bos_token = chat_format.bos_token
    tokenizer.add_special_tokens({"additional_special_tokens": [chat_format.bos_token, chat_format.eos_token]})
    tokenizer.chat_template = chat_format.chat_template

    # Grow the embedding matrix for the newly added tokens. `pad_to_multiple_of`
    # accepts None directly, so no conditional is needed here.
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=resize_to_multiple_of)

    # Keep model.config in sync with the tokenizer so the new token ids persist
    # when the model (and its config) is saved and reloaded.
    if getattr(model, "config", None) is not None:
        model.config.bos_token_id = tokenizer.bos_token_id
        model.config.eos_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = tokenizer.pad_token_id

    # Mirror the same ids onto the generation config so generation stops/pads correctly.
    if getattr(model, "generation_config", None) is not None:
        model.generation_config.bos_token_id = tokenizer.bos_token_id
        model.generation_config.eos_token_id = tokenizer.eos_token_id
        model.generation_config.pad_token_id = tokenizer.pad_token_id

    return model, tokenizer
|
|
|