# -------------------------------------------------------- # InternVL # Copyright (c) 2024 OpenGVLab # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- import warnings from typing import Any, List, Optional, Tuple, Union from PIL import Image import torch.utils.checkpoint import transformers from torch import nn from torch.nn import CrossEntropyLoss from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM, LlamaTokenizer, Qwen2ForCausalLM) import torchvision.transforms as T from torchvision.transforms.functional import InterpolationMode from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.modeling_utils import PreTrainedModel from transformers.utils import ModelOutput, logging from .configuration_internvl_chat import InternVLChatConfig from .conversation import get_conv_template from .modeling_intern_vit import InternVisionModel, has_flash_attn from .modeling_internlm2 import InternLM2ForCausalLM logger = logging.get_logger(__name__) IMAGENET_MEAN = (0.485, 0.456, 0.406) IMAGENET_STD = (0.229, 0.224, 0.225) def version_cmp(v1, v2, op='eq'): import operator from packaging import version op_func = getattr(operator, op) return op_func(version.parse(v1), version.parse(v2)) def build_transform(input_size): MEAN, STD = IMAGENET_MEAN, IMAGENET_STD transform = T.Compose([ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=MEAN, std=STD) ]) return transform def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): best_ratio_diff = float('inf') best_ratio = (1, 1) area = width * height for ratio in target_ratios: target_aspect_ratio = ratio[0] / ratio[1] ratio_diff = abs(aspect_ratio - target_aspect_ratio) if ratio_diff < best_ratio_diff: best_ratio_diff = ratio_diff best_ratio = ratio elif ratio_diff == best_ratio_diff: if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: best_ratio = ratio return best_ratio def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False): orig_width, orig_height = image.size aspect_ratio = orig_width / orig_height # calculate the existing image aspect ratio target_ratios = set( (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num) target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) # find the closest aspect ratio to the target target_aspect_ratio = find_closest_aspect_ratio( aspect_ratio, target_ratios, orig_width, orig_height, image_size) # calculate the target width and height target_width = image_size * target_aspect_ratio[0] target_height = image_size * target_aspect_ratio[1] blocks = target_aspect_ratio[0] * target_aspect_ratio[1] # resize the image resized_img = image.resize((target_width, target_height)) processed_images = [] for i in range(blocks): box = ( (i % (target_width // image_size)) * image_size, (i // (target_width // image_size)) * image_size, ((i % (target_width // image_size)) + 1) * image_size, ((i // (target_width // image_size)) + 1) * image_size ) # split the image split_img = resized_img.crop(box) processed_images.append(split_img) assert len(processed_images) == blocks if use_thumbnail and len(processed_images) != 1: thumbnail_img = image.resize((image_size, image_size)) processed_images.append(thumbnail_img) return processed_images def load_image(image_file, input_size=448, max_num=12): image = Image.open(image_file).convert('RGB') transform = build_transform(input_size=input_size) images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) pixel_values = [transform(image) for image in images] pixel_values = torch.stack(pixel_values) return pixel_values class InternVLChatModel(PreTrainedModel): config_class = InternVLChatConfig main_input_name = 'pixel_values' _supports_flash_attn_2 = True supports_gradient_checkpointing = True _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'InternLM2DecoderLayer', 'Qwen2DecoderLayer'] def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None, use_flash_attn=True): super().__init__(config) assert version_cmp(transformers.__version__, '4.36.2', 'ge') image_size = config.force_image_size or config.vision_config.image_size patch_size = config.vision_config.patch_size self.patch_size = patch_size self.select_layer = config.select_layer self.template = config.template self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2)) self.downsample_ratio = config.downsample_ratio self.ps_version = config.ps_version use_flash_attn = use_flash_attn if has_flash_attn else False config.vision_config.use_flash_attn = True if use_flash_attn else False config.llm_config.attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager' logger.info(f'num_image_token: {self.num_image_token}') logger.info(f'ps_version: {self.ps_version}') if vision_model is not None: self.vision_model = vision_model else: self.vision_model = InternVisionModel(config.vision_config) if language_model is not None: self.language_model = language_model else: if config.llm_config.architectures[0] == 'LlamaForCausalLM': self.language_model = LlamaForCausalLM(config.llm_config) elif config.llm_config.architectures[0] == 'InternLM2ForCausalLM': self.language_model = InternLM2ForCausalLM(config.llm_config) elif config.llm_config.architectures[0] == 'Qwen2ForCausalLM': self.language_model = Qwen2ForCausalLM(config.llm_config) else: raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.') vit_hidden_size = config.vision_config.hidden_size llm_hidden_size = config.llm_config.hidden_size self.mlp1 = nn.Sequential( nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2), nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size), nn.GELU(), nn.Linear(llm_hidden_size, llm_hidden_size) ) self.img_context_token_id = None self.conv_template = get_conv_template(self.template) self.system_message = self.conv_template.system_message def forward( self, pixel_values: torch.FloatTensor, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, image_flags: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict image_flags = image_flags.squeeze(-1) input_embeds = self.language_model.get_input_embeddings()(input_ids) vit_embeds = self.extract_feature(pixel_values) vit_embeds = vit_embeds[image_flags == 1] vit_batch_size = pixel_values.shape[0] B, N, C = input_embeds.shape input_embeds = input_embeds.reshape(B * N, C) if torch.distributed.get_rank() == 0: print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}') input_ids = input_ids.reshape(B * N) selected = (input_ids == self.img_context_token_id) try: input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C) except Exception as e: vit_embeds = vit_embeds.reshape(-1, C) print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, ' f'vit_embeds.shape={vit_embeds.shape}') n_token = min(selected.sum(), vit_embeds.size(0)) input_embeds[selected][:n_token] = input_embeds[selected][:n_token] * 0.0 + vit_embeds[:n_token] input_embeds = input_embeds.reshape(B, N, C) outputs = self.language_model( inputs_embeds=input_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) logits = outputs.logits loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def pixel_shuffle(self, x, scale_factor=0.5): n, w, h, c = x.size() # N, W, H, C --> N, W, H * scale, C // scale x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) # N, W, H * scale, C // scale --> N, H * scale, W, C // scale x = x.permute(0, 2, 1, 3).contiguous() # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2) x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor))) if self.ps_version == 'v1': warnings.warn("In ps_version 'v1', the height and width have not been swapped back, " 'which results in a transposed image.') else: x = x.permute(0, 2, 1, 3).contiguous() return x def extract_feature(self, pixel_values): if self.select_layer == -1: vit_embeds = self.vision_model( pixel_values=pixel_values, output_hidden_states=False, return_dict=True).last_hidden_state else: vit_embeds = self.vision_model( pixel_values=pixel_values, output_hidden_states=True, return_dict=True).hidden_states[self.select_layer] vit_embeds = vit_embeds[:, 1:, :] h = w = int(vit_embeds.shape[1] ** 0.5) vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) vit_embeds = self.mlp1(vit_embeds) return vit_embeds def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None, history=None, return_history=False, IMG_START_TOKEN='', IMG_END_TOKEN='', IMG_CONTEXT_TOKEN='', verbose=False, image_counts=None): if history is not None or return_history: print('Now multi-turn chat is not supported in batch_chat.') raise NotImplementedError if image_counts is not None: num_patches_list = image_counts print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.') img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) self.img_context_token_id = img_context_token_id if verbose and pixel_values is not None: image_bs = pixel_values.shape[0] print(f'dynamic ViT batch size: {image_bs}') queries = [] for idx, num_patches in enumerate(num_patches_list): question = questions[idx] if pixel_values is not None and '' not in question: question = '\n' + question template = get_conv_template(self.template) template.system_message = self.system_message template.append_message(template.roles[0], question) template.append_message(template.roles[1], None) query = template.get_prompt() image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN query = query.replace('', image_tokens, 1) queries.append(query) tokenizer.padding_side = 'left' model_inputs = tokenizer(queries, return_tensors='pt', padding=True) input_ids = model_inputs['input_ids'].cuda() attention_mask = model_inputs['attention_mask'].cuda() eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip()) generation_config['eos_token_id'] = eos_token_id generation_output = self.generate( pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, **generation_config ) responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True) responses = [response.split(template.sep)[0].strip() for response in responses] return responses def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False, image_dirs=None, num_patches_list=None, IMG_START_TOKEN='', IMG_END_TOKEN='', IMG_CONTEXT_TOKEN='', verbose=False): if image_dirs is not None: print("----------------------------------") print("Using image_dirs to load images. 'pixel_values' and 'num_patches_list' will be ignored.") print("You should provide all the previous image files and the current image file in the 'image_dirs' argument.") print("----------------------------------") if isinstance(image_dirs, str): image_dirs = [image_dirs] elif isinstance(image_dirs, list): pass else: raise ValueError(f'Invalid image_dirs: {image_dirs}. It should be a string or a list of strings.') image_values = [load_image(image_file, max_num=12).to(torch.float16).cuda() for image_file in image_dirs] pixel_values = torch.cat(image_values, dim=0) num_patches_list = [image_values[i].shape[0] for i in range(len(image_values))] if history is None and pixel_values is not None and '' not in question: question = '\n' + question if num_patches_list is None: num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else [] assert pixel_values is None or len(pixel_values) == sum(num_patches_list) img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) self.img_context_token_id = img_context_token_id template = get_conv_template(self.template) template.system_message = self.system_message eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip()) history = [] if history is None else history for (old_question, old_answer) in history: template.append_message(template.roles[0], old_question) template.append_message(template.roles[1], old_answer) template.append_message(template.roles[0], question) template.append_message(template.roles[1], None) query = template.get_prompt() if verbose and pixel_values is not None: image_bs = pixel_values.shape[0] print(f'dynamic ViT batch size: {image_bs}') for num_patches in num_patches_list: image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN query = query.replace('', image_tokens, 1) model_inputs = tokenizer(query, return_tensors='pt') input_ids = model_inputs['input_ids'].cuda() attention_mask = model_inputs['attention_mask'].cuda() generation_config['eos_token_id'] = eos_token_id generation_output = self.generate( pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, **generation_config ) response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0] response = response.split(template.sep.strip())[0].strip() history.append((question, response)) if return_history: return response, history else: query_to_print = query.replace(IMG_CONTEXT_TOKEN, '') query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '') if verbose: print(query_to_print, response) return response @torch.no_grad() def generate( self, pixel_values: Optional[torch.FloatTensor] = None, input_ids: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.LongTensor] = None, visual_features: Optional[torch.FloatTensor] = None, generation_config: Optional[GenerationConfig] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **generate_kwargs, ) -> torch.LongTensor: assert self.img_context_token_id is not None if pixel_values is not None: if visual_features is not None: vit_embeds = visual_features else: vit_embeds = self.extract_feature(pixel_values) input_embeds = self.language_model.get_input_embeddings()(input_ids) B, N, C = input_embeds.shape input_embeds = input_embeds.reshape(B * N, C) input_ids = input_ids.reshape(B * N) selected = (input_ids == self.img_context_token_id) assert selected.sum() != 0 input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device) input_embeds = input_embeds.reshape(B, N, C) else: input_embeds = self.language_model.get_input_embeddings()(input_ids) outputs = self.language_model.generate( inputs_embeds=input_embeds, attention_mask=attention_mask, generation_config=generation_config, output_hidden_states=output_hidden_states, # return_dict=return_dict, # return_dict is not supported in transformers 4.44 use_cache=True, **generate_kwargs, ) return outputs @property def lm_head(self): return self.language_model.get_output_embeddings() def get_input_embeddings(self): return self.language_model.get_input_embeddings() def get_output_embeddings(self): return self.language_model.get_output_embeddings()