| | from transformers import AutoTokenizer
|
| | import torchvision.transforms as transforms
|
| |
|
| | from data.custom_transforms import DynamicResize, SplitImage, GlobalAndSplitImages
|
| |
|
# Process-wide cache of loaded tokenizers, populated and read by get_tokenizer
# so repeated calls do not re-download / re-initialize the same tokenizer.
TOKENIZERS_CACHE = {}
|
| |
|
def get_tokenizer(name, extra_special_tokens=None, chat_template=None):
    """Load a HuggingFace tokenizer, memoized in ``TOKENIZERS_CACHE``.

    Args:
        name: Model identifier or path passed to ``AutoTokenizer.from_pretrained``.
        extra_special_tokens: Optional extra special tokens forwarded to the
            tokenizer constructor.
        chat_template: Optional chat template string forwarded to the
            tokenizer constructor.

    Returns:
        A fast tokenizer with ``pad_token`` set to ``eos_token``.
    """
    # Key the cache on the full configuration, not just the name: the previous
    # name-only key meant a second call with different extra_special_tokens or
    # chat_template silently got the first configuration's tokenizer back.
    # repr() makes a potentially unhashable (dict) argument usable as a key.
    cache_key = (name, repr(extra_special_tokens), chat_template)
    if cache_key not in TOKENIZERS_CACHE:
        tokenizer_init_kwargs = {"use_fast": True}
        if extra_special_tokens is not None:
            tokenizer_init_kwargs["extra_special_tokens"] = extra_special_tokens
        if chat_template is not None:
            tokenizer_init_kwargs["chat_template"] = chat_template
        tokenizer = AutoTokenizer.from_pretrained(name, **tokenizer_init_kwargs)
        # Many causal-LM tokenizers ship without a pad token; reuse EOS so
        # padding works downstream.
        tokenizer.pad_token = tokenizer.eos_token
        TOKENIZERS_CACHE[cache_key] = tokenizer
    return TOKENIZERS_CACHE[cache_key]
|
| |
|
def get_image_processor(max_img_size, splitted_image_size, resize_to_max_side_len=False):
    """Build the image preprocessing pipeline.

    The pipeline dynamically resizes the input image, converts it to a
    tensor, and finally produces the global view plus split tiles.
    """
    pipeline_steps = [
        DynamicResize(splitted_image_size, max_img_size, resize_to_max_side_len),
        transforms.ToTensor(),
        GlobalAndSplitImages(splitted_image_size),
    ]
    return transforms.Compose(pipeline_steps)
|
| |
|
def get_image_string(tokenizer, splitted_image_counts, mp_image_token_length):
    """Build the placeholder-token string for one or more (split) images.

    Args:
        tokenizer: Object exposing ``image_token``, per-tile ``r{i}c{j}``
            attributes and, optionally, ``global_image_token``.
        splitted_image_counts: Sequence of ``(n_h, n_w)`` tile-grid sizes,
            one per image.
        mp_image_token_length: Number of image tokens emitted per patch.

    Returns:
        The concatenated placeholder string for all images.
    """
    # Collect fragments and join once at the end; the previous repeated
    # `+=` concatenation is quadratic in the number of fragments.
    parts = []
    for idx, (n_h, n_w) in enumerate(splitted_image_counts):
        # Only tag images with an index when there is more than one image.
        if len(splitted_image_counts) > 1:
            parts.append(f"<image: {idx}>")
        if hasattr(tokenizer, "global_image_token"):
            parts.append(tokenizer.global_image_token)
        parts.append(tokenizer.image_token * mp_image_token_length)
        # A 1x1 grid has no tiles beyond the global view.
        if n_h == 1 and n_w == 1:
            continue
        for i in range(n_h):
            for j in range(n_w):
                # Tile-position marker (1-indexed), e.g. r1c2 = row 1, col 2.
                parts.append(getattr(tokenizer, f'r{i+1}c{j+1}'))
                parts.append(tokenizer.image_token * mp_image_token_length)
    return "".join(parts)
|
| |
|