import warnings

import torch
from mmengine.utils.misc import get_object_from_string
from transformers import GenerationConfig, StoppingCriteriaList

from xtuner.dataset.utils import expand2square, load_image
from xtuner.engine.hooks import EvaluateChatHook
from xtuner.model.utils import prepare_inputs_labels_for_multimodal
from xtuner.registry import BUILDER
from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
                          StopWordStoppingCriteria)


class EvaluateChatHook_withSpecialTokens(EvaluateChatHook):

    priority = 'LOW'

    def __init__(self,
                 tokenizer,
                 evaluation_inputs,
                 evaluation_images=None,
                 image_processor=None,
                 system='',
                 prompt_template=None,
                 every_n_iters=None,
                 max_new_tokens=600,
                 stop_word=None,
                 stop_words=None):
        # Default to None instead of a mutable [] so the stop-word list is
        # not shared across instances; it is extended below.
        stop_words = [] if stop_words is None else list(stop_words)

        self.evaluation_inputs = evaluation_inputs
        if isinstance(self.evaluation_inputs, str):
            self.evaluation_inputs = [self.evaluation_inputs]
        self.evaluation_images = evaluation_images
        if isinstance(self.evaluation_images, str):
            self.evaluation_images = [self.evaluation_images]
        if self.evaluation_images is not None:
            assert len(
                self.evaluation_images) in [1, len(self.evaluation_inputs)]
            if len(self.evaluation_images) == 1:
                self.evaluation_images = [self.evaluation_images[0]] * len(
                    self.evaluation_inputs)
            self.evaluation_images = [
                load_image(img) for img in self.evaluation_images
            ]
        if prompt_template is None:
            instruction = '{input}'
        else:
            if isinstance(prompt_template, str):  # for resume
                prompt_template = get_object_from_string(prompt_template)
            instruction = prompt_template.get('INSTRUCTION', '{input}')
            if system != '':
                system = prompt_template.get(
                    'SYSTEM', '{system}\n').format(system=system)
            stop_words += prompt_template.get('STOP_WORDS', [])
        if stop_word is not None:
            # TODO: deprecation, v0.3.0
            warnings.warn(
                ('The `stop_word` argument is deprecated and will be removed '
                 'in v0.3.0, use `stop_words` instead.'), DeprecationWarning)
            stop_words.append(stop_word)
        self.instruction = instruction
        self.system = system
        self.every_n_iters = every_n_iters
        self.max_new_tokens = max_new_tokens
        self.tokenizer = BUILDER.build(tokenizer)
        self._add_special_tokens()
        if image_processor is not None:
            self.image_processor = BUILDER.build(image_processor)

        # default generation config
        self.gen_config = GenerationConfig(
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.1,
            top_p=0.75,
            top_k=40,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None else
            self.tokenizer.eos_token_id,
        )

        self.stop_criteria = StoppingCriteriaList()
        for word in stop_words:
            self.stop_criteria.append(
                StopWordStoppingCriteria(self.tokenizer, word))

        self.is_first_run = True

    def _add_special_tokens(self):
        assert hasattr(self, 'tokenizer')
        # Adding special tokens for pixel grounding
        segmentation_tokens = ['[SEG]']
        # Adding tokens for GCG
        phrase_tokens = ['<p>', '</p>']
        # add for visual prompt
        region_tokens = ['<region>']
        point_tokens = ['<mark>']
        special_tokens = (segmentation_tokens + phrase_tokens +
                          region_tokens + point_tokens)
        self.tokenizer.add_tokens(special_tokens, special_tokens=True)

    def _eval_images(self,
                     runner,
                     model,
                     device,
                     max_new_tokens=None,
                     save_eval_output=False):
        if save_eval_output:
            eval_outputs = []
        for sample_image, sample_input in zip(self.evaluation_images,
                                              self.evaluation_inputs):
            # Pad the image to a square with the processor's mean color,
            # then preprocess it into a pixel tensor.
            image = expand2square(
                sample_image,
                tuple(int(x * 255) for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(
                image, return_tensors='pt')['pixel_values'][0]
            image = image.to(device)
            sample_input = DEFAULT_IMAGE_TOKEN + '\n' + sample_input
            inputs = (self.system + self.instruction).format(
                input=sample_input, round=1, **runner.cfg)

            # Encode the prompt on either side of the image token, then
            # splice in IMAGE_TOKEN_INDEX as the image placeholder.
            chunk_encode = []
            for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
                if idx == 0:
                    cur_encode = self.tokenizer.encode(chunk)
                else:
                    cur_encode = self.tokenizer.encode(
                        chunk, add_special_tokens=False)
                chunk_encode.append(cur_encode)
            assert len(chunk_encode) == 2
            input_ids = []
            for idx, cur_chunk_encode in enumerate(chunk_encode):
                input_ids.extend(cur_chunk_encode)
                if idx != len(chunk_encode) - 1:
                    input_ids.append(IMAGE_TOKEN_INDEX)
            input_ids = torch.tensor(input_ids).to(device)

            visual_outputs = model.visual_encoder(
                # image.unsqueeze(0).to(model.visual_encoder.dtype),
                image.unsqueeze(0).to(torch.bfloat16),
                output_hidden_states=True)
            # Some visual encoders return features directly; others return
            # an output object whose hidden states must be selected.
            if isinstance(visual_outputs, (list, tuple, torch.Tensor)):
                pixel_values = model.projector(visual_outputs)
            else:
                pixel_values = model.projector(
                    visual_outputs.hidden_states[model.visual_select_layer]
                    [:, 1:])

            mm_inputs = prepare_inputs_labels_for_multimodal(
                llm=model.llm,
                input_ids=input_ids.unsqueeze(0),
                pixel_values=pixel_values)

            generation_output = model.generate(
                **mm_inputs,
                max_new_tokens=max_new_tokens,
                generation_config=self.gen_config,
                bos_token_id=self.tokenizer.bos_token_id,
                stopping_criteria=self.stop_criteria)
            generation_output = self.tokenizer.decode(generation_output[0])
            runner.logger.info(f'Sample output:\n'
                               f'{inputs + generation_output}\n')
            if save_eval_output:
                eval_outputs.append(f'{inputs + generation_output}\n')

        if save_eval_output:
            self._save_eval_output(runner, eval_outputs)
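

# Usage sketch: XTuner evaluation hooks are registered through `custom_hooks`
# in a training config. The snippet below mirrors that pattern for this
# subclass; `tokenizer`, `image_processor`, `prompt_template`, and the
# evaluation image path are placeholders that a real config defines elsewhere
# (typically as `dict(type=..., ...)` specs resolved by BUILDER).
#
# custom_hooks = [
#     dict(
#         type=EvaluateChatHook_withSpecialTokens,
#         tokenizer=tokenizer,              # e.g. dict(type=AutoTokenizer.from_pretrained, ...)
#         image_processor=image_processor,  # e.g. dict(type=CLIPImageProcessor.from_pretrained, ...)
#         every_n_iters=500,
#         evaluation_inputs=['Please describe this picture.'],
#         evaluation_images=['path/to/example.jpg'],
#         system='',
#         prompt_template=prompt_template,
#     )
# ]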