| from typing import List, Optional, Tuple, Union |
| import warnings, os, torch |
| import torch.nn as nn |
|
|
| from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, AutoModelForCausalLM, AutoTokenizer |
| from transformers.modeling_utils import ContextManagers, no_init_weights |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
| from transformers.generation.utils import GenerateOutput |
| from .configuration_apollo import ApolloConfig |
|
|
| from .vision_tower import ApolloVisionTower |
| from .mm_connector import MMConnector |
|
|
| IGNORE_INDEX = -100 |
| X_TOKEN_INDEX = -200 |
|
|
|
|
| def get_model_config(config): |
| default_keys = ["llm_cfg", "vision_tower_cfg", "mm_connector_cfg"] |
| if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2: |
| root_path = config._name_or_path |
| else: |
| root_path = config.resume_path |
|
|
| return_pths = [] |
| for key in default_keys: |
| cfg = getattr(config, key, None) |
| if isinstance(cfg, dict): |
| try: |
| return_pths.append(os.path.join(root_path, key[:-4])) |
| except: |
| raise ValueError(f"Cannot find resume path in config for {key}!") |
| elif isinstance(cfg, PretrainedConfig): |
| return_pths.append(os.path.join(root_path, key[:-4])) |
| elif isinstance(cfg, str): |
| return_pths.append(cfg) |
|
|
| return_list = [] |
| for pth in return_pths: |
| return_list.append(AutoConfig.from_pretrained(pth, trust_remote_code=True)) |
|
|
| return return_list |
|
|
|
|
| def build_llm_and_tokenizer( |
| llm_cfg: str, |
| config: PretrainedConfig, |
| attn_implementation=None, |
| model_max_length=None, |
| *args, |
| **kwargs, |
| ) -> PreTrainedModel: |
| llm_arch = getattr(llm_cfg, "architectures")[0].lower() |
| |
| llm_path = llm_cfg._name_or_path |
| llm = AutoModelForCausalLM.from_pretrained( |
| llm_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs |
| ) |
|
|
| tokenizer = AutoTokenizer.from_pretrained( |
| llm_path, |
| model_max_length=llm_cfg.model_max_length, |
| padding_side="right", |
| use_fast=False, |
| legacy=False, |
| **kwargs |
| ) |
|
|
| |
| return llm, tokenizer |
|
|
|
|
| class ApolloForCausalLM(PreTrainedModel): |
| def __init__(self, config: ApolloConfig, *args, **kwargs): |
| super().__init__(config) |
| llm_cfg, vision_tower_cfg, mm_connector_cfg = get_model_config(config) |
| model_dtype = getattr(config, "model_dtype", "torch.float16") |
| if not hasattr(config, "model_dtype"): |
| warnings.warn("model_dtype not found in config, defaulting to torch.float16.") |
| config.model_dtype = model_dtype |
| |
|
|
| self.lm_head = nn.Linear(llm_cfg.hidden_size, config.vocab_size, bias=False) |
| self.vision_tower = ApolloVisionTower(config, vision_tower_cfg) |
| self.mm_connector = MMConnector.from_pretrained(mm_connector_cfg._name_or_path) |
| self.llm, self.tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs) |
| self.post_init() |
| self.is_loaded = True |
|
|
| def forward( |
| self, |
| input_ids: torch.LongTensor = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[List[torch.FloatTensor]] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| labels: Optional[torch.LongTensor] = None, |
| use_cache: Optional[bool] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| vision_input: Optional[List[torch.FloatTensor]] = None, |
| data_types: Optional[List[str]] = None, |
| return_dict: Optional[bool] = None, |
| cache_position=None, |
| ) -> Union[Tuple, CausalLMOutputWithPast]: |
| |
| if inputs_embeds is None: |
| ( |
| input_ids, |
| position_ids, |
| attention_mask, |
| past_key_values, |
| inputs_embeds, |
| labels |
| ) = self.prepare_inputs_labels_for_multimodal( |
| input_ids, |
| position_ids, |
| attention_mask, |
| past_key_values, |
| labels, |
| vision_input, |
| data_types |
| ) |
|
|
| return self.get_llm().forward( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, |
| labels=labels, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| ) |
|
|
| @torch.no_grad() |
| def generate( |
| self, |
| inputs: Optional[torch.Tensor] = None, |
| vision_input: Optional[List[torch.Tensor]] = None, |
| data_types: Optional[List[str]] = None, |
| **kwargs, |
| ) -> Union[GenerateOutput, torch.LongTensor]: |
| position_ids = kwargs.pop("position_ids", None) |
| attention_mask = kwargs.pop("attention_mask", None) |
| if "inputs_embeds" in kwargs: |
| raise NotImplementedError("`inputs_embeds` is not supported") |
|
|
| if vision_input is not None: |
| (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal( |
| inputs, position_ids, attention_mask, None, None, vision_input, data_types=data_types) |
| else: |
| inputs_embeds = self.embed_tokens(inputs) |
|
|
| return self.get_llm().generate(position_ids=position_ids, attention_mask=attention_mask, |
| inputs_embeds=inputs_embeds, **kwargs) |
|
|
| def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): |
| vision_input = kwargs.pop("vision_input", None) |
| data_types = kwargs.pop("data_types", None) |
| inputs = self.get_llm().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, **kwargs) |
| if vision_input is not None: |
| inputs["vision_input"] = vision_input |
| if data_types is not None: |
| inputs["data_types"] = data_types |
| return inputs |
|
|
| @classmethod |
| def from_pretrained( |
| cls, |
| pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], |
| *model_args, |
| config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, |
| cache_dir: Optional[Union[str, os.PathLike]] = None, |
| ignore_mismatched_sizes: bool = False, |
| force_download: bool = False, |
| local_files_only: bool = False, |
| token: Optional[Union[str, bool]] = None, |
| revision: str = "main", |
| use_safetensors: bool = None, |
| **kwargs, |
| ): |
|
|
| return cls.load_pretrained( |
| pretrained_model_name_or_path, |
| *model_args, |
| config=config, |
| cache_dir=cache_dir, |
| ignore_mismatched_sizes=ignore_mismatched_sizes, |
| force_download=force_download, |
| local_files_only=local_files_only, |
| token=token, |
| revision=revision, |
| use_safetensors=use_safetensors, |
| **kwargs, |
| ) |
|
|
| def get_llm(self): |
| return self.llm |
|
|
| def get_vision_tower(self): |
| return self.vision_tower |
|
|
| def get_mm_connector(self): |
| return self.mm_connector |
|
|
| @classmethod |
| def load_pretrained(cls, model_path_or_config, *args, **kwargs): |
| kwargs.pop("config", None) |
| |
| if isinstance(model_path_or_config, str): |
| config = AutoConfig.from_pretrained(model_path_or_config, trust_remote_code=True, **kwargs) |
| elif isinstance(model_path_or_config, ApolloConfig): |
| config = model_path_or_config |
| else: |
| raise NotImplementedError(f"wrong type, {type(model_path_or_config)} \ |
| {isinstance(model_path_or_config, ApolloConfig)}") |
|
|
| model_dtype = getattr(config, "model_dtype", "torch.float16") |
| if not hasattr(config, "model_dtype"): |
| warnings.warn("model_dtype not found in config, defaulting to torch.float16.") |
| config.model_dtype = model_dtype |
|
|
| with ContextManagers([no_init_weights(_enable=True), ]): |
| vlm = cls(config, *args, **kwargs) |
|
|
| if hasattr(vlm, "llm") and hasattr(vlm, "vision_tower") and hasattr(vlm, "mm_connector"): |
| if vlm.is_loaded: |
| return vlm |
| else: |
| print('loading model failed!') |
| else: |
| print('loading model failed!') |
|
|
| def _encode_mm(self, x): |
| x = self.get_vision_tower()(x) |
| x = self.mm_connector(x) |
| return x |
|
|
| def encode_mm_minibatch(self, x): |
| split_sizes = [x_s[0].shape[0] for x_s in x] |
| x = [torch.split(torch.cat([x_s[i] for x_s in x], dim=0), self.config.encode_batch_size) for i in |
| range(self.get_vision_tower().num_vision_encoders)] |
| swapped_x = [] |
| for i in range(len(x[0])): |
| swapped_x.append([x_s[i] for x_s in x]) |
|
|
| features = [] |
| for xx in swapped_x: |
| xx = self._encode_mm(xx) |
| features.append(xx) |
| x = torch.cat(features, dim=0) |
| x = torch.split(x, split_sizes, dim=0) |
| return [xx.contiguous().view(-1, xx.shape[2]) for xx in x] |
|
|
| def prepare_inputs_labels_for_multimodal( |
| self, input_ids, position_ids, attention_mask, past_key_values, labels, vision_input, data_types |
| ): |
| vision_tower = self.get_vision_tower() |
| if vision_tower is None or vision_input is None or input_ids.shape[1] == 1: |
| if ( |
| past_key_values is not None |
| and vision_tower is not None |
| and vision_input is not None |
| and input_ids.shape[1] == 1 |
| ): |
| target_shape = past_key_values[-1][-1].shape[-2] + 1 |
| attention_mask = torch.cat( |
| ( |
| attention_mask, |
| torch.ones( |
| ( |
| attention_mask.shape[0], |
| target_shape - attention_mask.shape[1], |
| ), |
| dtype=attention_mask.dtype, |
| device=attention_mask.device, |
| ), |
| ), |
| dim=1, |
| ) |
| position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 |
| return ( |
| input_ids, |
| position_ids, |
| attention_mask, |
| past_key_values, |
| None, |
| labels, |
| ) |
|
|
| ''' |
| vision_input is a list of tuples, and data_type is a list of strings: |
| data_type = ['image', 'video', 'video'..., 'text'] |
| (for one video and two image encoders) |
| vision_input = |
| [ |
| [image(1, T, C, H, W), image(1, T, C, H, W), image(1, T, C, H, W)], |
| [video(Nc1, C, T, H, W), video(Nc1, T, C, H, W), video(Nc1, T, C, H, W)], |
| [video(Nc2, C, T, H, W), video(Nc2, T, C, H, W), video(Nc2, T, C, H, W)], |
| ] |
| -> video encoders typlically expect (C,T,H,W), images expect (C,H,W). |
| ''' |
| |
| merged_mm_features = self.encode_mm_minibatch(vision_input) |
|
|
| if not getattr(self.config, "tune_language_model", True) and getattr(self.config, "use_mm_start_end", False): |
| raise NotImplementedError |
| |
| |
| |
| |
| |
| _labels = labels |
| _position_ids = position_ids |
| _attention_mask = attention_mask |
| if attention_mask is None: |
| attention_mask = torch.ones_like(input_ids, dtype=torch.bool) |
| else: |
| attention_mask = attention_mask.bool() |
| if position_ids is None: |
| position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) |
| if labels is None: |
| labels = torch.full_like(input_ids, IGNORE_INDEX) |
|
|
| |
| input_ids_copy = input_ids.clone() |
| |
| input_ids_copy[input_ids_copy == X_TOKEN_INDEX] = 0 |
| input_embeds = self.get_llm().model.embed_tokens(input_ids_copy) |
|
|
| input_ids = [ |
| cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask) |
| ] |
| input_embeds_1 = [ |
| cur_input_embeds[cur_attention_mask] |
| for cur_input_embeds, cur_attention_mask in zip(input_embeds, attention_mask) |
| ] |
| labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] |
| |
| new_labels = [] |
| new_input_embeds = [] |
| |
| |
| for batch_idx, (cur_labels, cur_input_ids, mm_features) in enumerate( |
| zip(labels, input_ids, merged_mm_features)): |
| cur_input_ids = input_ids[batch_idx] |
| num_mm = (cur_input_ids == X_TOKEN_INDEX).sum() |
| if num_mm == 0: |
| cur_input_embeds_1 = input_embeds_1[batch_idx] |
| cur_input_embeds = torch.cat([cur_input_embeds_1, mm_features[0:0]], dim=0) |
| new_input_embeds.append(cur_input_embeds) |
| new_labels.append(cur_labels) |
| |
| continue |
|
|
| if mm_features.shape[0] != num_mm: |
| print(data_types[batch_idx]) |
| assert num_mm == len( |
| mm_features), f'Error in {data_types[batch_idx]}{num_mm}=/={len(mm_features)} not the same number of vision tokens in and vision embeddings!' |
|
|
| cur_input_embeds = input_embeds_1[batch_idx] |
| image_token_indices = ( |
| [-1] + torch.where(cur_input_ids == X_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] |
| ) |
| cur_input_ids_noim = [] |
| cur_labels = labels[batch_idx] |
| cur_labels_noim = [] |
| cur_input_embeds_no_im = [] |
| for i in range(len(image_token_indices) - 1): |
| cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1: image_token_indices[i + 1]]) |
| cur_labels_noim.append(cur_labels[image_token_indices[i] + 1: image_token_indices[i + 1]]) |
| cur_input_embeds_no_im.append(cur_input_embeds[image_token_indices[i] + 1: image_token_indices[i + 1]]) |
|
|
| cur_new_input_embeds = [] |
| cur_new_labels = [] |
| for i in range(num_mm + 1): |
| cur_new_input_embeds.append(cur_input_embeds_no_im[i]) |
| |
| cur_new_labels.append(cur_labels_noim[i]) |
| if i < num_mm: |
| cur_image_features = mm_features[i:i + 1] |
| cur_new_input_embeds.append(cur_image_features) |
| |
| cur_new_labels.append( |
| torch.full( |
| (cur_image_features.shape[0],), |
| IGNORE_INDEX, |
| device=cur_labels.device, |
| dtype=cur_labels.dtype, |
| ) |
| ) |
|
|
| cur_new_input_embeds = torch.cat(cur_new_input_embeds) |
| cur_new_labels = torch.cat(cur_new_labels) |
|
|
| new_input_embeds.append(cur_new_input_embeds) |
| new_labels.append(cur_new_labels) |
|
|
| |
| tokenizer_model_max_length = getattr(self.get_llm().config, "tokenizer_model_max_length", None) |
| if tokenizer_model_max_length is not None: |
| if any(len(x) > tokenizer_model_max_length for x in new_input_embeds): |
| priny("Inputs truncated!") |
| new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] |
| new_labels = [x[:tokenizer_model_max_length] for x in new_labels] |
| |
| max_len = max(x.shape[0] for x in new_input_embeds) |
| batch_size = len(new_input_embeds) |
|
|
| new_input_embeds_padded = [] |
| new_labels_padded = torch.full( |
| (batch_size, max_len), |
| IGNORE_INDEX, |
| dtype=new_labels[0].dtype, |
| device=new_labels[0].device, |
| ) |
| attention_mask = torch.zeros( |
| (batch_size, max_len), |
| dtype=attention_mask.dtype, |
| device=attention_mask.device, |
| ) |
| position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) |
| for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): |
| cur_len = cur_new_embed.shape[0] |
| if getattr(self.get_llm().config, "tokenizer_padding_side", "right") == "left": |
| new_input_embeds_padded.append( |
| torch.cat( |
| ( |
| torch.zeros( |
| (max_len - cur_len, cur_new_embed.shape[1]), |
| dtype=cur_new_embed.dtype, |
| device=cur_new_embed.device, |
| ), |
| cur_new_embed, |
| ), |
| dim=0, |
| ) |
| ) |
| if cur_len > 0: |
| new_labels_padded[i, -cur_len:] = cur_new_labels |
| attention_mask[i, -cur_len:] = True |
| position_ids[i, -cur_len:] = torch.arange( |
| 0, cur_len, dtype=position_ids.dtype, device=position_ids.device |
| ) |
| else: |
| new_input_embeds_padded.append( |
| torch.cat( |
| ( |
| cur_new_embed, |
| torch.zeros( |
| (max_len - cur_len, cur_new_embed.shape[1]), |
| dtype=cur_new_embed.dtype, |
| device=cur_new_embed.device, |
| ), |
| ), |
| dim=0, |
| ) |
| ) |
| if cur_len > 0: |
| new_labels_padded[i, :cur_len] = cur_new_labels |
| attention_mask[i, :cur_len] = True |
| position_ids[i, :cur_len] = torch.arange( |
| 0, cur_len, dtype=position_ids.dtype, device=position_ids.device |
| ) |
|
|
| new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) |
|
|
| if _labels is None: |
| new_labels = None |
| else: |
| new_labels = new_labels_padded |
|
|
| if _attention_mask is None: |
| attention_mask = None |
| else: |
| attention_mask = attention_mask.to(dtype=_attention_mask.dtype) |
|
|
| if _position_ids is None: |
| position_ids = None |
|
|
| return ( |
| None, |
| position_ids, |
| attention_mask, |
| past_key_values, |
| new_input_embeds, |
| new_labels, |
| ) |
|
|