M4CXR-TNNLS / tokenization_m4cxr.py
jonggwon-park's picture
debug import
795e71e
import re
import torch
from transformers import LlamaTokenizerFast
# Sentinel used in place of "no limit": callers substitute it when max_length is None,
# so that length comparisons and slicing still work but never truncate in practice.
_INFINITE = int(1e12) # infinite token length for no-truncation
# constants
# Label value written at every position that must not contribute to the loss.
IGNORE_INDEX = -100 # default ignore index of CrossEntropyLoss
# special media tokens
# Placeholder string that marks where an image appears inside a prompt.
IMAGE = "<image>"
# Media placeholder strings grouped by modality; only the "image" modality is defined.
MEDIA_TOKENS = {
"image": [IMAGE],
}
# mistral chat template
# Role markers that open a user turn and an assistant turn in the rendered prompt.
MISTRAL_USER = "[INST]"
MISTRAL_ASSISTANT = "[/INST]"
# Jinja template for rendering a message list:
#   - emits bos_token once at the start,
#   - raises unless roles strictly alternate user/assistant starting with user,
#   - wraps user content as '[INST] ' + content + ' [/INST]',
#   - emits assistant content followed by eos_token,
#   - raises for any role other than user/assistant.
MISTRAL_CHAT_TEMPLATE = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
def _pad_trunc(
    x: list[list[int]],
    padding: str,
    padding_side: str,
    pad_value: int,
    max_length: int,
) -> torch.LongTensor:
    """Pad and truncate sequences to a common length and stack them.

    Args:
        x: sequences of token ids. Each item may be a list of ints, any other
            sequence of ints, or a tensor (tensors are converted via ``tolist()``).
        padding: "longest" pads to the longest sequence in ``x``;
            "max_length" pads every sequence to exactly ``max_length``.
        padding_side: "left" or "right"; the side on which ``pad_value`` is added.
        pad_value: filler value for padded positions.
        max_length: upper bound on the output length. Required when
            ``padding == "max_length"``. With ``padding == "longest"`` it may be
            None (no truncation); when given, sequences longer than it are
            truncated, so the output length is ``min(longest, max_length)``.

    Returns:
        A long tensor with one row per input sequence. An empty ``x`` yields an
        empty tensor.

    Note:
        Truncation always keeps the first ``max_length`` items of a sequence,
        regardless of ``padding_side``.
    """
    assert padding in ["longest", "max_length"]
    assert padding_side in ["left", "right"]

    lengths = [len(sample) for sample in x]
    if padding == "longest":
        # `default=0` keeps an empty batch from raising inside max().
        longest = max(lengths, default=0)
        # Honour the caller's limit instead of silently discarding it: the
        # longest sequence sets the width only while it stays within the limit.
        max_length = longest if max_length is None else min(longest, max_length)

    new_x = []
    for sample, length in zip(x, lengths):
        # Normalise to a plain list so slicing/concatenation below is uniform.
        sample = sample.tolist() if torch.is_tensor(sample) else list(sample)

        if length >= max_length:
            new_x.append(sample[:max_length])
            continue

        pads = [pad_value] * (max_length - length)
        new_x.append(sample + pads if padding_side == "right" else pads + sample)

    return torch.as_tensor(new_x, dtype=torch.long)
"""Modified from https://github.com/khanrc/honeybee
"""
class MLLMTokenizerMixin:
def mllm_setup(self, num_visual_tokens: int, chat_template: str):
if self.pad_token is None:
self.pad_token = self.unk_token
self.num_visual_tokens = num_visual_tokens
# Currently we only support the image modality for media modality.
self.media_tokens = {
k: -int(i + 1) for i, k in enumerate(MEDIA_TOKENS["image"])
}
self.media_lengths = {
MEDIA_TOKENS["image"][0]: num_visual_tokens
} # token lengths
self.chat_template_type = chat_template
# chat template
if chat_template == "mistral":
self.chat_template = MISTRAL_CHAT_TEMPLATE
self.user_indicator = MISTRAL_USER
self.assistant_indicator = MISTRAL_ASSISTANT
else:
raise NotImplementedError
def encode_prompt(self, prompt: str, max_length: int | None, no_eos=False):
"""Tokenize prompt which consists of image-text or text only, with role tokens.
Role pattern is "AI: " or "Human: ".
Args:
prompt
max_length (int or None): here, max_length is used for truncation.
If max_length is None, no truncation is applied.
no_eos: if True, eos token is not added at the end of the prompt.
Note that eos token is still used for end-of-AI-turn token even no_eos=True.
"""
max_length = (
max_length or _INFINITE
) # if None, set to infinite for no-truncation
enc_chunk = []
label_chunk = []
# find special tokens for multi-modal chats (<image>, user, assistant)
pattern = "|".join(
map(
re.escape,
list(self.media_tokens.keys())
+ [self.user_indicator, self.assistant_indicator],
)
)
chunk_strs = re.split(f"({pattern})", prompt)
chunk_strs = [x for x in chunk_strs if len(x) > 0]
for idx, chunk_str in enumerate(chunk_strs):
if len(enc_chunk) >= max_length + 1:
break
if chunk_str in self.media_tokens:
if len(enc_chunk) + self.media_lengths[chunk_str] > max_length + 1:
break
enc_chunk += [self.media_tokens[chunk_str]] * self.media_lengths[
chunk_str
]
label_chunk += [0] * self.media_lengths[chunk_str]
else:
label = (
1
if (idx > 0 and chunk_strs[idx - 1] == self.assistant_indicator)
else 0
)
curr_chunk = self(chunk_str, add_special_tokens=False)["input_ids"]
if label == 1 and curr_chunk[-1] != self.eos_token_id:
curr_chunk += [self.eos_token_id]
enc_chunk += curr_chunk
label_chunk += [label] * len(curr_chunk)
if no_eos and enc_chunk[-1] == self.eos_token_id:
# the last token can be != eos_token_id; when the prompt is ended with `AI: `.
# in this case, there is no AI-answer, thus, no eos token is added.
enc_chunk = enc_chunk[:-1]
label_chunk = label_chunk[:-1]
enc_chunk = enc_chunk[: max_length + 1]
label_chunk = label_chunk[: max_length + 1]
L = len(enc_chunk)
assert L == len(label_chunk)
input_ids = torch.as_tensor(enc_chunk, dtype=torch.long)
loss_mask = torch.as_tensor(label_chunk, dtype=torch.long)
attention_mask = torch.ones_like(input_ids)
# Label
labels = input_ids.clone()
labels[loss_mask != 1] = IGNORE_INDEX
# The length of input_ids (L) includes <bos> and <eos> tokens.
# Since max_length does not include <bos> token, L <= max_length + 1
assert L <= max_length + 1
return {
"input_ids": input_ids, # [L]
"labels": labels, # [L]
"seq_length": L, # int
"attention_mask": attention_mask, # [L]
}
def batch_encode_prompt(
self,
prompts: list[str],
padding: str = "longest",
padding_side: str = "right",
max_length: int | None = None,
no_eos=False,
) -> dict[str, torch.LongTensor]:
"""Batch encode prompts, pad/truncate to the same length, and collate them.
Args:
prompts (list[str])
padding ("longest" or "max_length")
padding_side ("left" or "right")
pad_value (int)
max_length (int or None): if padding == "max_length", max_length should be given
"""
batch = [self.encode_prompt(prompt, max_length, no_eos) for prompt in prompts]
batch = self.batch_collate_pad(batch, padding, padding_side, max_length)
return batch
def batch_collate_pad(
self,
batch: list,
padding: str,
padding_side: str,
max_length: int | None,
) -> dict[str, torch.LongTensor]:
"""Collate batch and pad/truncate to the same length
Args:
batch
padding ("longest" or "max_length")
padding_side ("left" or "right")
pad_value (int)
max_length (int or None): if padding == "max_length", max_length should be given
"""
if padding == "max_length":
assert (
max_length is not None
), "max_length should be given if padding == 'max_length'"
else:
# if padding == 'longest' and max_length is None, set to infinite for no-truncation
max_length = max_length or _INFINITE
input_ids = [sample["input_ids"] for sample in batch]
labels = [sample["labels"] for sample in batch]
attention_mask = [sample["attention_mask"] for sample in batch]
seq_length = [sample["seq_length"] for sample in batch]
# max_length + 1 for bos_token
input_ids = _pad_trunc(
input_ids, padding, padding_side, self.pad_token_id, max_length + 1
)
labels = _pad_trunc(labels, padding, padding_side, IGNORE_INDEX, max_length + 1)
attention_mask = _pad_trunc(
attention_mask, padding, padding_side, 0, max_length + 1
)
seq_length = torch.as_tensor(seq_length, dtype=torch.long)
return {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"seq_length": seq_length,
}
class MllmTokenizer(LlamaTokenizerFast, MLLMTokenizerMixin):
    """Fast Llama tokenizer with the multi-modal prompt-encoding mixin built in."""

    def __init__(
        self, num_visual_tokens=361, chat_template_type="mistral", *args, **kwargs
    ):
        """Build the base tokenizer, then apply the multi-modal setup.

        Args:
            num_visual_tokens: forwarded to ``mllm_setup``.
            chat_template_type: forwarded to ``mllm_setup`` as ``chat_template``.
            *args, **kwargs: passed unchanged to the base tokenizer constructor.
        """
        # LlamaTokenizerFast is the next class in this class's MRO.
        super().__init__(*args, **kwargs)
        self.mllm_setup(num_visual_tokens, chat_template_type)
#################################################################
# Tokenizer builder
#################################################################
def extend_instance_(obj, mixin):
    """Add ``mixin`` to the class of an already-created instance, in place.

    The instance's class is swapped for a newly created subclass of
    ``(original class, mixin)`` that keeps the original class name. Because the
    original class comes first in the bases, its attributes take precedence
    over same-named attributes of the mixin.
    """
    current_cls = obj.__class__
    patched_cls = type(current_cls.__name__, (current_cls, mixin), {})
    obj.__class__ = patched_cls
def build_mllm_tokenizer(tokenizer, num_visual_tokens: int, chat_template: str = "mistral"):
    """Build mllm tokenizer with monkey-patch

    The given tokenizer instance is extended in place with
    ``MLLMTokenizerMixin`` and then configured through ``mllm_setup``.

    Args:
        tokenizer: an already-constructed tokenizer instance; it is mutated.
        num_visual_tokens: number of placeholder ids per image token.
        chat_template: chat template name passed to ``mllm_setup``. Defaults to
            "mistral". (The previous default was the builtin type ``str``, a
            typo for an annotation, which no template name could ever match.)

    Returns:
        The same tokenizer object, now carrying the mixin's methods.
    """
    # If use_fast=True, the tokenizer is re-constructed causing long building time (about 5min)
    # Another solution is save-and-load re-constructed fast tokenizer, but we simply use
    # normal version here.

    # monkey patch
    extend_instance_(tokenizer, MLLMTokenizerMixin)
    tokenizer.mllm_setup(num_visual_tokens, chat_template)

    return tokenizer