import warnings

import torch
from mmengine.utils.misc import get_object_from_string
from transformers import GenerationConfig, StoppingCriteriaList

from xtuner.dataset.utils import expand2square, load_image
from xtuner.engine.hooks import EvaluateChatHook
from xtuner.model.utils import prepare_inputs_labels_for_multimodal
from xtuner.registry import BUILDER
from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
                          StopWordStoppingCriteria)


class EvaluateChatHook_withSpecialTokens(EvaluateChatHook):
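    """``EvaluateChatHook`` variant that registers extra special tokens.

    Adds the segmentation, phrase, region and point tokens used for pixel
    grounding and visual prompting to the tokenizer right after it is
    built, so evaluation prompts containing these tokens are encoded
    consistently with training.
    """
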
    priority = 'LOW'

def __init__(self,
tokenizer,
evaluation_inputs,
evaluation_images=None,
image_processor=None,
system='',
prompt_template=None,
every_n_iters=None,
max_new_tokens=600,
stop_word=None,
                 stop_words=None):
        # Avoid a shared mutable default list: `stop_words` is extended
        # below, which would otherwise leak state across instances.
        stop_words = [] if stop_words is None else list(stop_words)
        self.evaluation_inputs = evaluation_inputs
        if isinstance(self.evaluation_inputs, str):
            self.evaluation_inputs = [self.evaluation_inputs]
self.evaluation_images = evaluation_images
if isinstance(self.evaluation_images, str):
self.evaluation_images = [self.evaluation_images]
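        # A single image may be broadcast to every evaluation prompt;
        # otherwise one image is required per prompt.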
if self.evaluation_images is not None:
assert len(
self.evaluation_images) in [1, len(self.evaluation_inputs)]
if len(self.evaluation_images) == 1:
self.evaluation_images = [self.evaluation_images[0]] * len(
self.evaluation_inputs)
self.evaluation_images = [
load_image(img) for img in self.evaluation_images
]
if prompt_template is None:
instruction = '{input}'
else:
if isinstance(prompt_template, str): # for resume
prompt_template = get_object_from_string(prompt_template)
instruction = prompt_template.get('INSTRUCTION', '{input}')
if system != '':
system = prompt_template.get(
'SYSTEM', '{system}\n').format(system=system)
stop_words += prompt_template.get('STOP_WORDS', [])
if stop_word is not None:
# TODO: deprecation, v0.3.0
warnings.warn(
('The `stop_word` argument is deprecated and will be removed '
'in v0.3.0, use `stop_words` instead.'), DeprecationWarning)
stop_words.append(stop_word)
self.instruction = instruction
self.system = system
self.every_n_iters = every_n_iters
self.max_new_tokens = max_new_tokens
self.tokenizer = BUILDER.build(tokenizer)
self._add_special_tokens()
if image_processor is not None:
self.image_processor = BUILDER.build(image_processor)
# default generation config
self.gen_config = GenerationConfig(
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=0.1,
top_p=0.75,
top_k=40,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id
if self.tokenizer.pad_token_id is not None else
self.tokenizer.eos_token_id,
)
self.stop_criteria = StoppingCriteriaList()
for word in stop_words:
self.stop_criteria.append(
StopWordStoppingCriteria(self.tokenizer, word))
        self.is_first_run = True

def _add_special_tokens(self):
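        """Register the task-specific special tokens on the tokenizer.

        Must run before any prompt is encoded so that tokens such as
        ``[SEG]`` map to single ids instead of being split into pieces.
        """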
assert hasattr(self, "tokenizer")
# Adding special tokens for pixel grounding
segmentation_tokens = ['[SEG]']
# Adding tokens for GCG
phrase_tokens = ['<p>', '</p>']
# add for visual prompt
region_tokens = ['<region>']
point_tokens = ['<mark>']
special_tokens = segmentation_tokens + phrase_tokens + region_tokens + point_tokens
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
return
def _eval_images(self,
runner,
model,
device,
max_new_tokens=None,
save_eval_output=False):
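        """Generate a sample answer for each (image, prompt) pair and log it.

        Mirrors ``EvaluateChatHook._eval_images`` but casts images to
        bfloat16 and handles visual encoders that return plain tensors.
        """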
if save_eval_output:
eval_outputs = []
for sample_image, sample_input in zip(self.evaluation_images,
self.evaluation_inputs):
image = expand2square(
sample_image,
tuple(int(x * 255) for x in self.image_processor.image_mean))
image = self.image_processor.preprocess(
image, return_tensors='pt')['pixel_values'][0]
image = image.to(device)
sample_input = DEFAULT_IMAGE_TOKEN + '\n' + sample_input
inputs = (self.system + self.instruction).format(
input=sample_input, round=1, **runner.cfg)
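            # Encode the prompt in chunks split on DEFAULT_IMAGE_TOKEN so the
            # placeholder itself is never tokenized; IMAGE_TOKEN_INDEX is
            # spliced in between chunks instead. The assert below enforces
            # exactly one image token per prompt.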
chunk_encode = []
for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
if idx == 0:
cur_encode = self.tokenizer.encode(chunk)
else:
cur_encode = self.tokenizer.encode(
chunk, add_special_tokens=False)
chunk_encode.append(cur_encode)
assert len(chunk_encode) == 2
input_ids = []
for idx, cur_chunk_encode in enumerate(chunk_encode):
input_ids.extend(cur_chunk_encode)
if idx != len(chunk_encode) - 1:
input_ids.append(IMAGE_TOKEN_INDEX)
input_ids = torch.tensor(input_ids).to(device)
            # NOTE: the image is hard-coded to bfloat16 (rather than cast to
            # `model.visual_encoder.dtype`), so the visual encoder is assumed
            # to run in bf16.
            visual_outputs = model.visual_encoder(
                image.unsqueeze(0).to(torch.bfloat16),
                output_hidden_states=True)
            # Some encoders return the features directly; others return an
            # output object whose hidden states must be indexed and stripped
            # of the CLS token before projection.
            if isinstance(visual_outputs, (list, tuple, torch.Tensor)):
pixel_values = model.projector(visual_outputs)
else:
pixel_values = model.projector(
visual_outputs.hidden_states[model.visual_select_layer][:, 1:])
mm_inputs = prepare_inputs_labels_for_multimodal(
llm=model.llm,
input_ids=input_ids.unsqueeze(0),
pixel_values=pixel_values)
generation_output = model.generate(
**mm_inputs,
max_new_tokens=max_new_tokens,
generation_config=self.gen_config,
bos_token_id=self.tokenizer.bos_token_id,
stopping_criteria=self.stop_criteria)
generation_output = self.tokenizer.decode(generation_output[0])
runner.logger.info(f'Sample output:\n'
f'{inputs + generation_output}\n')
if save_eval_output:
eval_outputs.append(f'{inputs + generation_output}\n')
if save_eval_output:
self._save_eval_output(runner, eval_outputs)
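

# ---------------------------------------------------------------------------
# Usage sketch (an illustrative assumption, not part of the original file):
# like the stock EvaluateChatHook, this hook would be registered in an
# XTuner config via `custom_hooks`. The `tokenizer`, `image_processor` and
# image path below are hypothetical placeholders that must match the rest
# of the config.
#
# from xtuner.utils import PROMPT_TEMPLATE
#
# custom_hooks = [
#     dict(
#         type=EvaluateChatHook_withSpecialTokens,
#         tokenizer=tokenizer,
#         image_processor=image_processor,
#         every_n_iters=500,
#         evaluation_inputs=['Please segment the dog in the image.'],
#         evaluation_images=['path/to/example.jpg'],
#         system='',
#         prompt_template=PROMPT_TEMPLATE.vicuna)
# ]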