| | from typing import List, Optional, Tuple, Union |
| | import warnings, os, torch |
| | import torch.nn as nn |
| |
|
| | from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, AutoModelForCausalLM, AutoTokenizer |
| | from transformers.modeling_utils import ContextManagers, no_init_weights |
| | from transformers.modeling_outputs import CausalLMOutputWithPast |
| | from transformers.generation.utils import GenerateOutput |
| | from .configuration_apollo import ApolloConfig |
| |
|
| | from .vision_tower import ApolloVisionTower |
| | from .mm_connector import MMConnector |
| |
|
# Label value written at positions that carry no training target: the slots
# that hold spliced-in vision features and the right/left padding.
# NOTE(review): presumably matches the loss function's ignore index - confirm.
IGNORE_INDEX = -100
# Placeholder id inside `input_ids` marking where one vision feature row is
# spliced into the text embedding sequence.
X_TOKEN_INDEX = -200
| |
|
| |
|
def get_model_config(config):
    """Load the sub-model configs referenced by a composite model config.

    For each of ``llm_cfg``, ``vision_tower_cfg`` and ``mm_connector_cfg`` the
    attribute on ``config`` decides where the sub-config is loaded from:

    * ``dict`` or ``PretrainedConfig``: ``<root>/<key without "_cfg">``, where
      ``<root>`` is ``config._name_or_path`` when it has at least two
      characters and ``config.resume_path`` otherwise.
    * ``str``: the string itself is used as the path.
    * anything else (including a missing attribute): the key is skipped, so
      the returned list may be shorter than three entries.

    Args:
        config: Composite config object exposing the attributes above.

    Returns:
        A list of configs loaded with ``AutoConfig.from_pretrained``, in the
        order LLM, vision tower, connector (minus skipped keys).

    Raises:
        ValueError: A ``dict`` sub-config needs a root path but none is
            available on ``config``.
    """
    default_keys = ["llm_cfg", "vision_tower_cfg", "mm_connector_cfg"]

    # Read both candidates leniently: the root path is only needed for
    # dict/PretrainedConfig entries, so its absence must not fail up front.
    name_or_path = getattr(config, "_name_or_path", None)
    if name_or_path and len(name_or_path) >= 2:
        root_path = name_or_path
    else:
        root_path = getattr(config, "resume_path", None)

    return_pths = []
    for key in default_keys:
        cfg = getattr(config, key, None)
        subdir = key[: -len("_cfg")]  # "llm_cfg" -> "llm"
        if isinstance(cfg, dict):
            try:
                return_pths.append(os.path.join(root_path, subdir))
            except TypeError as err:
                # os.path.join rejects a missing (None) root path.
                raise ValueError(f"Cannot find resume path in config for {key}!") from err
        elif isinstance(cfg, PretrainedConfig):
            return_pths.append(os.path.join(root_path, subdir))
        elif isinstance(cfg, str):
            return_pths.append(cfg)

    # trust_remote_code=True lets the referenced repositories supply their own
    # config classes, so the paths must point at trusted checkpoints.
    return [AutoConfig.from_pretrained(pth, trust_remote_code=True) for pth in return_pths]
| |
|
| |
|
def _resolve_torch_dtype(dtype):
    """Turn a dtype spec such as ``"torch.float16"`` into a ``torch.dtype``."""
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        name = dtype.strip()
        if name.startswith("torch."):
            name = name[len("torch."):]
        resolved = getattr(torch, name, None) if name.isidentifier() else None
        if isinstance(resolved, torch.dtype):
            return resolved
    raise ValueError(f"Unsupported model_dtype: {dtype!r}")


def build_llm_and_tokenizer(
    llm_cfg: "PretrainedConfig",
    config: "PretrainedConfig",
    attn_implementation=None,
    model_max_length=None,
    *args,
    **kwargs,
) -> Tuple["PreTrainedModel", object]:
    """Instantiate the language model and its tokenizer from ``llm_cfg``.

    Args:
        llm_cfg: Config of the language model; ``llm_cfg._name_or_path`` is the
            checkpoint location for both the model and the tokenizer.
        config: Composite model config; ``config.model_dtype`` (a
            ``torch.dtype`` or a name such as ``"torch.bfloat16"``) selects the
            weight dtype.
        attn_implementation: Forwarded to the model loader when not ``None``.
        model_max_length: Tokenizer maximum length; falls back to
            ``llm_cfg.model_max_length`` when ``None``.
        *args: Extra positional arguments for the model loader.
        **kwargs: Extra keyword arguments, passed to both the model loader and
            the tokenizer loader.

    Returns:
        ``(llm, tokenizer)``.

    Raises:
        ValueError: ``config.model_dtype`` does not name a ``torch.dtype``.
    """
    # SECURITY: the dtype used to be obtained with eval(config.model_dtype),
    # which executes whatever string the config file contains. It is resolved
    # by attribute lookup on `torch` instead, before anything is loaded.
    torch_dtype = _resolve_torch_dtype(config.model_dtype)
    llm_path = llm_cfg._name_or_path

    llm_kwargs = dict(kwargs)
    if attn_implementation is not None:
        llm_kwargs["attn_implementation"] = attn_implementation
    llm = AutoModelForCausalLM.from_pretrained(
        llm_path, *args, config=llm_cfg, torch_dtype=torch_dtype, **llm_kwargs
    )

    if model_max_length is None:
        model_max_length = getattr(llm_cfg, "model_max_length", None)
    tokenizer = AutoTokenizer.from_pretrained(
        llm_path,
        model_max_length=model_max_length,
        padding_side="right",
        use_fast=False,
        legacy=False,
        **kwargs,
    )

    return llm, tokenizer
| |
|
| |
|
class ApolloForCausalLM(PreTrainedModel):
    """Causal language model that accepts vision features alongside text.

    The model is composed of a vision tower, a connector that maps the tower's
    output into the language model's embedding space, and a language model.
    Text positions holding ``X_TOKEN_INDEX`` are replaced, one placeholder per
    feature row, by the connector output before the language model runs.
    """

    def __init__(self, config: ApolloConfig, *args, **kwargs):
        """Build all sub-modules from the sub-configs referenced by ``config``.

        Extra ``*args`` / ``**kwargs`` are forwarded to
        ``build_llm_and_tokenizer``.
        """
        super().__init__(config)
        llm_cfg, vision_tower_cfg, mm_connector_cfg = get_model_config(config)
        if not hasattr(config, "model_dtype"):
            warnings.warn("model_dtype not found in config, defaulting to torch.float16.")
            config.model_dtype = "torch.float16"

        self.lm_head = nn.Linear(llm_cfg.hidden_size, config.vocab_size, bias=False)
        self.vision_tower = ApolloVisionTower(config, vision_tower_cfg)
        self.mm_connector = MMConnector.from_pretrained(mm_connector_cfg._name_or_path)
        self.llm, self.tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs)
        self.post_init()
        # Checked by `load_pretrained` to confirm construction completed.
        self.is_loaded = True

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        vision_input: Optional[List[torch.FloatTensor]] = None,
        data_types: Optional[List[str]] = None,
        return_dict: Optional[bool] = None,
        cache_position=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the language model, splicing in vision features when needed.

        When ``inputs_embeds`` is not supplied, the text/vision inputs are first
        merged by ``prepare_inputs_labels_for_multimodal``. ``cache_position``
        is accepted for interface compatibility but is not forwarded to the
        language model.
        """
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                vision_input,
                data_types,
            )

        return self.get_llm().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        vision_input: Optional[List[torch.Tensor]] = None,
        data_types: Optional[List[str]] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate with the language model from text ids and optional vision input.

        Raises:
            NotImplementedError: ``inputs_embeds`` was passed in ``kwargs``.
        """
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        inputs_embeds = None
        if vision_input is not None:
            (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(
                inputs, position_ids, attention_mask, None, None, vision_input, data_types=data_types
            )
        if inputs_embeds is None:
            # Text-only call, or the multimodal preparation short-circuited and
            # handed the token ids back: embed them with the LLM's own table
            # (the same one the multimodal path uses).
            inputs_embeds = self.get_llm().model.embed_tokens(inputs)

        return self.get_llm().generate(
            position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Delegate to the language model, re-attaching the vision arguments."""
        vision_input = kwargs.pop("vision_input", None)
        data_types = kwargs.pop("data_types", None)
        inputs = self.get_llm().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if vision_input is not None:
            inputs["vision_input"] = vision_input
        if data_types is not None:
            inputs["data_types"] = data_types
        return inputs

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: bool = None,
        **kwargs,
    ):
        """Load the model; a thin wrapper that forwards everything to ``load_pretrained``."""
        return cls.load_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            **kwargs,
        )

    def get_llm(self):
        """Return the language model sub-module."""
        return self.llm

    def get_vision_tower(self):
        """Return the vision tower sub-module."""
        return self.vision_tower

    def get_mm_connector(self):
        """Return the multimodal connector sub-module."""
        return self.mm_connector

    @classmethod
    def load_pretrained(cls, model_path_or_config, *args, **kwargs):
        """Construct the model from a checkpoint path or an ``ApolloConfig``.

        Args:
            model_path_or_config: Path (``str`` or ``os.PathLike``) whose config
                is loaded with ``AutoConfig``, or a ready ``ApolloConfig``.
            *args: Forwarded to the constructor.
            **kwargs: Forwarded to ``AutoConfig.from_pretrained`` (path case)
                and to the constructor. A ``config`` entry is discarded.

        Returns:
            The constructed model.

        Raises:
            NotImplementedError: ``model_path_or_config`` has an unsupported type.
            RuntimeError: The constructed object is missing a sub-module or is
                not flagged as loaded.
        """
        # The config is always derived from `model_path_or_config`.
        kwargs.pop("config", None)

        if isinstance(model_path_or_config, (str, os.PathLike)):
            config = AutoConfig.from_pretrained(model_path_or_config, trust_remote_code=True, **kwargs)
        elif isinstance(model_path_or_config, ApolloConfig):
            config = model_path_or_config
        else:
            raise NotImplementedError(
                f"wrong type, {type(model_path_or_config)}; expected a path or an ApolloConfig"
            )

        if not hasattr(config, "model_dtype"):
            warnings.warn("model_dtype not found in config, defaulting to torch.float16.")
            config.model_dtype = "torch.float16"

        with ContextManagers([no_init_weights(_enable=True)]):
            vlm = cls(config, *args, **kwargs)

        has_all_parts = all(hasattr(vlm, part) for part in ("llm", "vision_tower", "mm_connector"))
        if not (has_all_parts and getattr(vlm, "is_loaded", False)):
            # Fail loudly instead of printing and returning None, which only
            # postponed the crash to the caller's first use of the model.
            raise RuntimeError("loading model failed!")
        return vlm

    def _encode_mm(self, x):
        """Run the vision tower and then the connector on ``x``."""
        x = self.get_vision_tower()(x)
        x = self.mm_connector(x)
        return x

    def encode_mm_minibatch(self, x):
        """Encode all vision inputs in chunks of ``config.encode_batch_size``.

        Args:
            x: One entry per sample; each entry holds one tensor per vision
                encoder (indexed ``0 .. num_vision_encoders - 1``).

        Returns:
            One 2-D tensor per sample, obtained by flattening every leading
            dimension of that sample's features into rows of width
            ``features.shape[2]``.
        """
        # Rows each sample contributes along dim 0, measured on encoder 0.
        split_sizes = [x_s[0].shape[0] for x_s in x]
        # Per encoder: concatenate every sample, then cut into mini-batches.
        per_encoder_chunks = [
            torch.split(torch.cat([x_s[i] for x_s in x], dim=0), self.config.encode_batch_size)
            for i in range(self.get_vision_tower().num_vision_encoders)
        ]
        # Regroup so each mini-batch carries one chunk per encoder.
        minibatches = [
            [chunks[i] for chunks in per_encoder_chunks] for i in range(len(per_encoder_chunks[0]))
        ]

        features = [self._encode_mm(minibatch) for minibatch in minibatches]
        merged = torch.cat(features, dim=0)
        per_sample = torch.split(merged, split_sizes, dim=0)
        return [feat.contiguous().view(-1, feat.shape[2]) for feat in per_sample]

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels, vision_input, data_types
    ):
        """Merge text embeddings and vision features into one padded batch.

        Every ``X_TOKEN_INDEX`` in ``input_ids`` is replaced by one row of that
        sample's vision features; labels at those positions become
        ``IGNORE_INDEX``. Sequences are then truncated to the LLM config's
        ``tokenizer_model_max_length`` (if set) and padded to a common length.

        Returns:
            ``(input_ids, position_ids, attention_mask, past_key_values,
            inputs_embeds, labels)``. On the multimodal path ``input_ids`` is
            ``None`` and ``inputs_embeds`` is set; on the short-circuit path
            (no vision input, or a single-token decoding step) it is the other
            way round. ``labels``, ``attention_mask`` and ``position_ids`` are
            returned as ``None`` on the multimodal path when they were passed
            as ``None``.

        Raises:
            NotImplementedError: The language model is frozen while
                ``use_mm_start_end`` is enabled in the config.
            AssertionError: The number of placeholders in a sample differs from
                its number of vision feature rows.
        """
        vision_tower = self.get_vision_tower()
        if vision_tower is None or vision_input is None or input_ids.shape[1] == 1:
            if (
                past_key_values is not None
                and vision_tower is not None
                and vision_input is not None
                and input_ids.shape[1] == 1
            ):
                # Single-token step with a cache: grow the mask to the cached
                # length plus the new token and derive the position from it.
                target_shape = past_key_values[-1][-1].shape[-2] + 1
                attention_mask = torch.cat(
                    (
                        attention_mask,
                        torch.ones(
                            (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
                            dtype=attention_mask.dtype,
                            device=attention_mask.device,
                        ),
                    ),
                    dim=1,
                )
                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
            return (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                None,
                labels,
            )

        # Layout notes carried over from the original inline documentation:
        #   data_types is a list of strings, e.g. ['image', 'video', 'video', ..., 'text'].
        #   vision_input is a list with one entry per sample, each entry holding
        #   one tensor per vision encoder, e.g.
        #       [image(1, T, C, H, W), image(1, T, C, H, W), image(1, T, C, H, W)],
        #       [video(Nc1, C, T, H, W), video(Nc1, T, C, H, W), video(Nc1, T, C, H, W)],
        #       [video(Nc2, C, T, H, W), video(Nc2, T, C, H, W), video(Nc2, T, C, H, W)],
        #   Video encoders typically expect (C, T, H, W); image encoders expect (C, H, W).
        #   NOTE(review): the original note mixes axis orders - confirm against the data loader.
        merged_mm_features = self.encode_mm_minibatch(vision_input)

        if not getattr(self.config, "tune_language_model", True) and getattr(
            self.config, "use_mm_start_end", False
        ):
            raise NotImplementedError

        # Remember which optional inputs were absent so they can be handed
        # back as None at the end.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # The placeholder id is negative and cannot be embedded, so it is
        # mapped to id 0 on a copy; those rows are dropped again below.
        input_ids_copy = input_ids.clone()
        input_ids_copy[input_ids_copy == X_TOKEN_INDEX] = 0
        input_embeds = self.get_llm().model.embed_tokens(input_ids_copy)

        # Strip padding: keep only the attended positions of every sample.
        input_ids = [ids[mask] for ids, mask in zip(input_ids, attention_mask)]
        input_embeds_1 = [embeds[mask] for embeds, mask in zip(input_embeds, attention_mask)]
        labels = [lbl[mask] for lbl, mask in zip(labels, attention_mask)]

        new_labels = []
        new_input_embeds = []

        for batch_idx, (cur_labels, cur_input_ids, mm_features) in enumerate(
            zip(labels, input_ids, merged_mm_features)
        ):
            cur_input_embeds = input_embeds_1[batch_idx]
            num_mm = int((cur_input_ids == X_TOKEN_INDEX).sum())
            if num_mm == 0:
                # Concatenating an empty slice leaves the values untouched.
                # NOTE(review): presumably this keeps the vision branch in the
                # autograd graph for text-only samples - confirm.
                new_input_embeds.append(torch.cat([cur_input_embeds, mm_features[0:0]], dim=0))
                new_labels.append(cur_labels)
                continue

            data_type = data_types[batch_idx] if data_types is not None else None
            if mm_features.shape[0] != num_mm:
                print(data_type)
            assert num_mm == len(mm_features), (
                f'Error in {data_type}{num_mm}=/={len(mm_features)} '
                f'not the same number of vision tokens in and vision embeddings!'
            )

            # Cut the text at every placeholder; the sentinels -1 and len make
            # the first and last segment fall out of the same slicing rule.
            boundaries = (
                [-1] + torch.where(cur_input_ids == X_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            )
            cur_labels_noim = []
            cur_input_embeds_no_im = []
            for start, end in zip(boundaries[:-1], boundaries[1:]):
                cur_labels_noim.append(cur_labels[start + 1: end])
                cur_input_embeds_no_im.append(cur_input_embeds[start + 1: end])

            # Interleave: text segment, one feature row, text segment, ...
            cur_new_input_embeds = []
            cur_new_labels = []
            for i in range(num_mm + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_mm:
                    cur_image_features = mm_features[i:i + 1]
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(
                        torch.full(
                            (cur_image_features.shape[0],),
                            IGNORE_INDEX,
                            device=cur_labels.device,
                            dtype=cur_labels.dtype,
                        )
                    )

            new_input_embeds.append(torch.cat(cur_new_input_embeds))
            new_labels.append(torch.cat(cur_new_labels))

        tokenizer_model_max_length = getattr(self.get_llm().config, "tokenizer_model_max_length", None)
        if tokenizer_model_max_length is not None:
            if any(len(x) > tokenizer_model_max_length for x in new_input_embeds):
                warnings.warn("Inputs truncated!")
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full(
            (batch_size, max_len),
            IGNORE_INDEX,
            dtype=new_labels[0].dtype,
            device=new_labels[0].device,
        )
        attention_mask = torch.zeros(
            (batch_size, max_len),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
        pad_left = getattr(self.get_llm().config, "tokenizer_padding_side", "right") == "left"

        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            padding = torch.zeros(
                (max_len - cur_len, cur_new_embed.shape[1]),
                dtype=cur_new_embed.dtype,
                device=cur_new_embed.device,
            )
            if pad_left:
                new_input_embeds_padded.append(torch.cat((padding, cur_new_embed), dim=0))
                # A `-0:` slice would select the whole row, hence the guard.
                valid = slice(-cur_len, None) if cur_len > 0 else None
            else:
                new_input_embeds_padded.append(torch.cat((cur_new_embed, padding), dim=0))
                valid = slice(0, cur_len) if cur_len > 0 else None

            if valid is not None:
                new_labels_padded[i, valid] = cur_new_labels
                attention_mask[i, valid] = True
                position_ids[i, valid] = torch.arange(
                    0, cur_len, dtype=position_ids.dtype, device=position_ids.device
                )

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        new_labels = None if _labels is None else new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        return (
            None,
            position_ids,
            attention_mask,
            past_key_values,
            new_input_embeds,
            new_labels,
        )
|