import warnings

import torch
from mmengine.utils.misc import get_object_from_string
from transformers import GenerationConfig, StoppingCriteriaList

from xtuner.dataset.utils import expand2square, load_image
from xtuner.engine.hooks import EvaluateChatHook
from xtuner.model.utils import prepare_inputs_labels_for_multimodal
from xtuner.registry import BUILDER
from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
                          StopWordStoppingCriteria)


class EvaluateChatHook_withSpecialTokens(EvaluateChatHook):
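    """`EvaluateChatHook` variant that registers extra special tokens
    ([SEG], <p>, </p>, <region>, <mark>) on the tokenizer before any
    evaluation prompt is encoded."""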

    priority = 'LOW'

    def __init__(self,
                 tokenizer,
                 evaluation_inputs,
                 evaluation_images=None,
                 image_processor=None,
                 system='',
                 prompt_template=None,
                 every_n_iters=None,
                 max_new_tokens=600,
                 stop_word=None,
                 stop_words=None):
        # Use None instead of a mutable `[]` default: `stop_words` is
        # extended below, so a shared default list would accumulate entries
        # across hook instances.
        stop_words = [] if stop_words is None else list(stop_words)
        self.evaluation_inputs = evaluation_inputs
        if isinstance(self.evaluation_inputs, str):
            self.evaluation_inputs = [self.evaluation_inputs]
        self.evaluation_images = evaluation_images
        if isinstance(self.evaluation_images, str):
            self.evaluation_images = [self.evaluation_images]
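        # Either one shared image or one image per evaluation input; a
        # single image is broadcast to every input.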
        if self.evaluation_images is not None:
            assert len(
                self.evaluation_images) in [1, len(self.evaluation_inputs)]
            if len(self.evaluation_images) == 1:
                self.evaluation_images = [self.evaluation_images[0]] * len(
                    self.evaluation_inputs)
            self.evaluation_images = [
                load_image(img) for img in self.evaluation_images
            ]
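        # Resolve the prompt template (possibly given as a dotted import
        # string) and pull the instruction, system template, and stop words
        # from it.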
        if prompt_template is None:
            instruction = '{input}'
        else:
            if isinstance(prompt_template, str):
                prompt_template = get_object_from_string(prompt_template)
            instruction = prompt_template.get('INSTRUCTION', '{input}')
            if system != '':
                system = prompt_template.get(
                    'SYSTEM', '{system}\n').format(system=system)
            stop_words += prompt_template.get('STOP_WORDS', [])
        if stop_word is not None:
            warnings.warn(
                ('The `stop_word` argument is deprecated and will be removed '
                 'in v0.3.0, use `stop_words` instead.'), DeprecationWarning)
            stop_words.append(stop_word)
        self.instruction = instruction
        self.system = system
        self.every_n_iters = every_n_iters
        self.max_new_tokens = max_new_tokens
        self.tokenizer = BUILDER.build(tokenizer)
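        # Register the new special tokens immediately after the tokenizer is
        # built so every later encode/decode call can see them.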
        self._add_special_tokens()
        if image_processor is not None:
            self.image_processor = BUILDER.build(image_processor)

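        # Low-temperature sampling keeps evaluation output fairly stable;
        # fall back to EOS when the tokenizer defines no pad token.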
        self.gen_config = GenerationConfig(
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.1,
            top_p=0.75,
            top_k=40,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None else
            self.tokenizer.eos_token_id,
        )
        self.stop_criteria = StoppingCriteriaList()
        for word in stop_words:
            self.stop_criteria.append(
                StopWordStoppingCriteria(self.tokenizer, word))

        self.is_first_run = True

    def _add_special_tokens(self):
        assert hasattr(self, 'tokenizer')
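        # Token groups: a segmentation trigger, phrase delimiters, and
        # region/point placeholders.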
        segmentation_tokens = ['[SEG]']
        phrase_tokens = ['<p>', '</p>']
        region_tokens = ['<region>']
        point_tokens = ['<mark>']
        special_tokens = (segmentation_tokens + phrase_tokens +
                          region_tokens + point_tokens)
        self.tokenizer.add_tokens(special_tokens, special_tokens=True)

    def _eval_images(self,
                     runner,
                     model,
                     device,
                     max_new_tokens=None,
                     save_eval_output=False):
        if save_eval_output:
            eval_outputs = []

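        # For each (image, prompt) pair: preprocess the image, build the
        # multimodal inputs, and generate a completion.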
        for sample_image, sample_input in zip(self.evaluation_images,
                                              self.evaluation_inputs):
            # Pad to a square filled with the processor's mean color, then
            # run the standard preprocessing to obtain a pixel tensor.
            image = expand2square(
                sample_image,
                tuple(int(x * 255) for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(
                image, return_tensors='pt')['pixel_values'][0]
            image = image.to(device)
            sample_input = DEFAULT_IMAGE_TOKEN + '\n' + sample_input
            inputs = (self.system + self.instruction).format(
                input=sample_input, round=1, **runner.cfg)
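            # Encode the text on each side of the image token separately and
            # splice in IMAGE_TOKEN_INDEX as the visual-feature placeholder.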
            chunk_encode = []
            for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
                if idx == 0:
                    cur_encode = self.tokenizer.encode(chunk)
                else:
                    cur_encode = self.tokenizer.encode(
                        chunk, add_special_tokens=False)
                chunk_encode.append(cur_encode)
            assert len(chunk_encode) == 2
            input_ids = []
            for idx, cur_chunk_encode in enumerate(chunk_encode):
                input_ids.extend(cur_chunk_encode)
                if idx != len(chunk_encode) - 1:
                    input_ids.append(IMAGE_TOKEN_INDEX)
            input_ids = torch.tensor(input_ids).to(device)
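            # Some visual encoders return features directly; others return
            # an output object whose selected hidden state (minus the leading
            # CLS token) must be projected into the LLM embedding space.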
            visual_outputs = model.visual_encoder(
                image.unsqueeze(0).to(torch.bfloat16),
                output_hidden_states=True)
            if isinstance(visual_outputs, (list, tuple, torch.Tensor)):
                pixel_values = model.projector(visual_outputs)
            else:
                pixel_values = model.projector(
                    visual_outputs.hidden_states[model.visual_select_layer][:, 1:])

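            # Splice the projected image features into the token embedding
            # sequence that the LLM consumes.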
            mm_inputs = prepare_inputs_labels_for_multimodal(
                llm=model.llm,
                input_ids=input_ids.unsqueeze(0),
                pixel_values=pixel_values)

            generation_output = model.generate(
                **mm_inputs,
                max_new_tokens=max_new_tokens,
                generation_config=self.gen_config,
                bos_token_id=self.tokenizer.bos_token_id,
                stopping_criteria=self.stop_criteria)
            generation_output = self.tokenizer.decode(generation_output[0])
            runner.logger.info(f'Sample output:\n'
                               f'{inputs + generation_output}\n')
            if save_eval_output:
                eval_outputs.append(f'{inputs + generation_output}\n')

        if save_eval_output:
            self._save_eval_output(runner, eval_outputs)