Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes. See the raw diff.
- VILA/llava/model/__pycache__/__init__.cpython-310.pyc +0 -0
- VILA/llava/model/__pycache__/__init__.cpython-39.pyc +0 -0
- VILA/llava/model/__pycache__/builder.cpython-310.pyc +0 -0
- VILA/llava/model/__pycache__/configuration_llava.cpython-310.pyc +0 -0
- VILA/llava/model/__pycache__/llava_arch.cpython-310.pyc +0 -0
- VILA/llava/model/__pycache__/utils.cpython-310.pyc +0 -0
- VILA/llava/model/language_model/__pycache__/builder.cpython-310.pyc +0 -0
- VILA/llava/model/language_model/__pycache__/llava_llama.cpython-310.pyc +0 -0
- VILA/llava/model/language_model/__pycache__/llava_llama.cpython-39.pyc +0 -0
- VILA/llava/model/language_model/__pycache__/llava_mistral.cpython-310.pyc +0 -0
- VILA/llava/model/language_model/__pycache__/llava_mixtral.cpython-310.pyc +0 -0
- VILA/llava/model/language_model/__pycache__/modeling_mixtral_long_context.cpython-310.pyc +0 -0
- VILA/llava/model/language_model/builder.py +114 -0
- VILA/llava/model/language_model/llava_gemma.py +153 -0
- VILA/llava/model/language_model/llava_llama.py +186 -0
- VILA/llava/model/language_model/llava_mistral.py +137 -0
- VILA/llava/model/language_model/llava_mixtral.py +136 -0
- VILA/llava/model/language_model/llava_mpt.py +160 -0
- VILA/llava/model/language_model/modeling_mixtral_long_context.py +1657 -0
- VILA/llava/model/language_model/mpt/adapt_tokenizer.py +61 -0
- VILA/llava/model/language_model/mpt/attention.py +480 -0
- VILA/llava/model/language_model/mpt/blocks.py +100 -0
- VILA/llava/model/language_model/mpt/configuration_mpt.py +184 -0
- VILA/llava/model/language_model/mpt/custom_embedding.py +27 -0
- VILA/llava/model/language_model/mpt/flash_attn_triton.py +947 -0
- VILA/llava/model/language_model/mpt/hf_prefixlm_converter.py +657 -0
- VILA/llava/model/language_model/mpt/meta_init_context.py +118 -0
- VILA/llava/model/language_model/mpt/modeling_mpt.py +483 -0
- VILA/llava/model/language_model/mpt/norm.py +89 -0
- VILA/llava/model/language_model/mpt/param_init_fns.py +399 -0
- VILA/llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc +0 -0
- VILA/llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc +0 -0
- VILA/llava/model/multimodal_encoder/__pycache__/image_processor.cpython-310.pyc +0 -0
- VILA/llava/model/multimodal_encoder/__pycache__/intern_encoder.cpython-310.pyc +0 -0
- VILA/llava/model/multimodal_encoder/__pycache__/radio_encoder.cpython-310.pyc +0 -0
- VILA/llava/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc +0 -0
- VILA/llava/model/multimodal_encoder/__pycache__/vision_encoder.cpython-310.pyc +0 -0
- VILA/llava/model/multimodal_encoder/__pycache__/visualize_features.cpython-310.pyc +0 -0
- VILA/llava/model/multimodal_encoder/builder.py +64 -0
- VILA/llava/model/multimodal_encoder/clip_encoder.py +42 -0
- VILA/llava/model/multimodal_encoder/image_processor.py +546 -0
- VILA/llava/model/multimodal_encoder/intern/__pycache__/configuration_intern_vit.cpython-310.pyc +0 -0
- VILA/llava/model/multimodal_encoder/intern/__pycache__/flash_attention.cpython-310.pyc +0 -0
- VILA/llava/model/multimodal_encoder/intern/__pycache__/modeling_intern_vit.cpython-310.pyc +0 -0
- VILA/llava/model/multimodal_encoder/intern/configuration_intern_vit.py +117 -0
- VILA/llava/model/multimodal_encoder/intern/flash_attention.py +105 -0
- VILA/llava/model/multimodal_encoder/intern/modeling_intern_vit.py +543 -0
- VILA/llava/model/multimodal_encoder/intern_encoder.py +71 -0
- VILA/llava/model/multimodal_encoder/radio_encoder.py +334 -0
- VILA/llava/model/multimodal_encoder/radio_torchhub_encoder.py +375 -0
VILA/llava/model/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (450 Bytes). View file
|
|
|
VILA/llava/model/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (448 Bytes). View file
|
|
|
VILA/llava/model/__pycache__/builder.cpython-310.pyc
ADDED
|
Binary file (5.81 kB). View file
|
|
|
VILA/llava/model/__pycache__/configuration_llava.cpython-310.pyc
ADDED
|
Binary file (1.28 kB). View file
|
|
|
VILA/llava/model/__pycache__/llava_arch.cpython-310.pyc
ADDED
|
Binary file (20.8 kB). View file
|
|
|
VILA/llava/model/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (2.45 kB). View file
|
|
|
VILA/llava/model/language_model/__pycache__/builder.cpython-310.pyc
ADDED
|
Binary file (2.42 kB). View file
|
|
|
VILA/llava/model/language_model/__pycache__/llava_llama.cpython-310.pyc
ADDED
|
Binary file (3.8 kB). View file
|
|
|
VILA/llava/model/language_model/__pycache__/llava_llama.cpython-39.pyc
ADDED
|
Binary file (3.7 kB). View file
|
|
|
VILA/llava/model/language_model/__pycache__/llava_mistral.cpython-310.pyc
ADDED
|
Binary file (3.67 kB). View file
|
|
|
VILA/llava/model/language_model/__pycache__/llava_mixtral.cpython-310.pyc
ADDED
|
Binary file (3.57 kB). View file
|
|
|
VILA/llava/model/language_model/__pycache__/modeling_mixtral_long_context.cpython-310.pyc
ADDED
|
Binary file (46 kB). View file
|
|
|
VILA/llava/model/language_model/builder.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
import os.path as osp
|
| 19 |
+
import warnings
|
| 20 |
+
from typing import Tuple
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from huggingface_hub import file_exists, repo_exists
|
| 24 |
+
from huggingface_hub.utils import HFValidationError
|
| 25 |
+
from transformers import (
|
| 26 |
+
AutoConfig,
|
| 27 |
+
AutoModelForCausalLM,
|
| 28 |
+
AutoTokenizer,
|
| 29 |
+
PretrainedConfig,
|
| 30 |
+
PreTrainedModel,
|
| 31 |
+
PreTrainedTokenizer,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def has_tokenizer(repo_id_or_path: str) -> bool:
    """Return True if a tokenizer config exists for *repo_id_or_path*.

    A local directory containing ``tokenizer_config.json`` wins; otherwise
    the Hugging Face Hub is queried. An invalid repo id is treated as
    "no tokenizer found" rather than an error.
    """
    local_config = osp.join(repo_id_or_path, "tokenizer_config.json")
    if osp.exists(local_config):
        return True

    try:
        on_hub = repo_exists(repo_id_or_path)
        return on_hub and file_exists(repo_id_or_path, "tokenizer_config.json")
    except HFValidationError:
        # Not a syntactically valid repo id / path at all.
        return False
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def context_length_extension(config):
    """Extend the model's context window via linear RoPE scaling.

    If ``config.model_max_length`` exceeds the model's native
    ``max_position_embeddings``, attach a ``rope_scaling`` entry with a
    ceiled linear scaling factor. The config is mutated in place and also
    returned for convenience.
    """
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    model_max_length = getattr(config, "model_max_length", None)
    # Guard model_max_length as well: the original `model_max_length > orig_ctx_len`
    # raised TypeError when model_max_length was None but orig_ctx_len was set.
    if orig_ctx_len and model_max_length and model_max_length > orig_ctx_len:
        print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
        scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
    return config
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def build_llm_and_tokenizer(
    model_name_or_path: str,
    config: PretrainedConfig,
    attn_implementation=None,
    model_max_length=None,
    *args,
    **kwargs,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """Load the language model and its tokenizer for a VILA checkpoint.

    Args:
        model_name_or_path: Local path or Hub id of the LLM checkpoint.
        config: The top-level VILA config; ``config.model_dtype`` selects the
            torch dtype and ``config.hidden_size`` is synchronized from the
            loaded LLM before returning.
        attn_implementation: Forwarded to the LLM config's private
            ``_attn_implementation`` field.
        model_max_length: If given, stored on the LLM config and used to
            apply linear RoPE scaling via ``context_length_extension``.

    Returns:
        ``(llm, tokenizer)``.

    Raises:
        ValueError: If no tokenizer can be found at the model path or in its
            ``llm`` subfolder.
    """
    llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
    llm_cfg._attn_implementation = attn_implementation
    llm_cfg.model_max_length = model_max_length
    if model_max_length is not None:
        context_length_extension(llm_cfg)

    # NOTE(review): eval() on a config field executes arbitrary code if the
    # config file is untrusted; it expects strings like "torch.float16".
    # Consider resolving via getattr(torch, ...) instead.
    llm = AutoModelForCausalLM.from_pretrained(
        model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
    )

    # Locate the tokenizer: try the model path itself, then its "llm" subfolder.
    llm_path = model_name_or_path
    if not has_tokenizer(llm_path):
        llm_path = osp.join(llm_path, "llm")
    if not has_tokenizer(llm_path):
        raise ValueError(f"Cannot find tokenizer in {llm_path}.")

    # TODO(ligeng): use LLM class to judge to better compability.
    # BUG FIX: llm_arch was previously left unbound when the lookup failed
    # (the except only warned), so the dispatch below raised NameError.
    # Also narrowed BaseException -> Exception so Ctrl-C is not swallowed.
    llm_arch = ""
    try:
        llm_arch = getattr(llm_cfg, "architectures")[0].lower()
    except Exception:
        warnings.warn(f'Cannot find LLM architecture, please check the "config.json" under "{llm_path}".')

    if "mpt" in llm_arch:
        tokenizer = AutoTokenizer.from_pretrained(
            llm_path,
            model_max_length=llm_cfg.model_max_length,
            padding_side="right",
        )
    elif "yi" in llm_path or (
        getattr(llm_cfg, "num_hidden_layers", -1) == 60 and getattr(llm_cfg, "num_attention_heads", -1) == 56
    ):
        # Yi models (also detected by their 60-layer / 56-head shape) need
        # the slow tokenizer.
        tokenizer = AutoTokenizer.from_pretrained(
            llm_path,
            model_max_length=llm_cfg.model_max_length,
            padding_side="right",
            use_fast=False,
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            llm_path,
            model_max_length=llm_cfg.model_max_length,
            padding_side="right",
            use_fast=False,
            legacy=False,
        )

    # TODO(ligeng): is this necessary for llava?
    config.hidden_size = llm.config.hidden_size
    return llm, tokenizer
|
VILA/llava/model/language_model/llava_gemma.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 Haotian Liu
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Pad token id constant. NOTE(review): not referenced anywhere in this file —
# confirm external users before removing.
PAD_TOKEN_ID = 0
|
| 17 |
+
|
| 18 |
+
from typing import List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
| 23 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 24 |
+
from transformers.models.gemma import GemmaConfig, GemmaForCausalLM, GemmaModel
|
| 25 |
+
|
| 26 |
+
from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class LlavaGemmaConfig(GemmaConfig):
    """Gemma config tagged for LLaVA-Gemma Auto* dispatch (registered below)."""

    model_type = "llava_gemma"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class LlavaGemmaModel(GemmaModel, LlavaMetaModel):
    """Gemma backbone combined with the LLaVA multimodal mixin."""

    config_class = LlavaGemmaConfig

    def __init__(self, config: GemmaConfig):
        super().__init__(config)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class LlavaGemmaForCausalLM(GemmaForCausalLM, LlavaMetaForCausalLM):
    """Causal-LM head for LLaVA on a Gemma backbone.

    The forward pass first merges image features into the token embedding
    sequence, then (in training) repacks the batch for sequence packing
    before delegating to ``GemmaForCausalLM.forward``.
    """

    config_class = LlavaGemmaConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = LlavaGemmaModel(config)
        self.pretraining_tp = 1
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        # Accessor used by the LLaVA meta classes.
        return self.model

    def get_lm_head(self):
        return self.lm_head

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        seqlens_in_batch: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Multimodal forward: splice image features, optionally pack, then
        run the Gemma LM."""
        if inputs_embeds is None:
            # Replace image placeholder tokens with projected vision features;
            # returns updated embeddings and correspondingly adjusted masks/labels.
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values, labels, images
            )
        # TODO (kentang-mit@): fuse this function into the previous one.
        # current design makes unit-test easier.
        if self.training:
            # Sequence packing: concatenate samples to remove padding; the
            # packed per-sample lengths are forwarded as seqlens_in_batch.
            (
                _,
                new_position_ids,
                new_attention_mask,
                _,
                new_inputs_embeds,
                new_labels,
                sorted_seqlens_in_batch,
            ) = self.repack_multimodal_data(
                input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels
            )
            if sorted_seqlens_in_batch is None:
                sorted_seqlens_in_batch = seqlens_in_batch
            new_input_ids = None
            past_key_values = None
            new_cache_position = None
        else:
            # Inference path: pass tensors through unchanged.
            new_attention_mask = attention_mask
            new_position_ids = position_ids
            new_inputs_embeds = inputs_embeds
            new_labels = labels
            if attention_mask is not None:
                sorted_seqlens_in_batch = attention_mask.sum(-1).int()
            else:
                sorted_seqlens_in_batch = None
            new_input_ids = input_ids
            # kentang-mit@: This only works for batch=1 currently
            # model.generate of gemma does not correctly handle decoding stage currently
            # need to manually adjust decoding stage input = 1 token
            if past_key_values is not None:
                if new_inputs_embeds is not None:
                    new_inputs_embeds = new_inputs_embeds[:, [-1]]
                # kentang-mit@: seems to be a problem unique to gemma
                if new_position_ids is not None:
                    new_position_ids = new_position_ids[:, [-1]]
            # NOTE(review): assumes new_position_ids is not None on this path —
            # confirm callers always supply position ids at inference time.
            new_cache_position = new_position_ids[0]

        outputs = super().forward(
            input_ids=new_input_ids,
            attention_mask=new_attention_mask,
            position_ids=new_position_ids,
            past_key_values=past_key_values,
            inputs_embeds=new_inputs_embeds,
            labels=new_labels,
            use_cache=use_cache,
            cache_position=new_cache_position,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            seqlens_in_batch=sorted_seqlens_in_batch,
        )
        return outputs

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Thread the `images` kwarg through HF generation's input prep."""
        images = kwargs.pop("images", None)
        _inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            _inputs["images"] = images
        return _inputs
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# Make the custom config/model discoverable through the HF Auto* factories.
AutoConfig.register("llava_gemma", LlavaGemmaConfig)
AutoModelForCausalLM.register(LlavaGemmaConfig, LlavaGemmaForCausalLM)
|
VILA/llava/model/language_model/llava_llama.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 Haotian Liu
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# This file is modified from https://github.com/haotian-liu/LLaVA/
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import inspect
|
| 19 |
+
import os
|
| 20 |
+
from typing import List, Optional, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
| 24 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 25 |
+
|
| 26 |
+
from ...train.utils import calculate_loss_weight
|
| 27 |
+
from ..configuration_llava import LlavaConfig
|
| 28 |
+
from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class LlavaLlamaConfig(LlavaConfig):
    """LLaVA config tagged for the Llama-backed model (registered below)."""

    model_type = "llava_llama"
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
## FIXME we will follow the convention to add a new class for CausalLM in the future
|
| 36 |
+
class LlavaLlamaModel(LlavaMetaModel, LlavaMetaForCausalLM, PreTrainedModel):
    """LLaVA vision-language model on a Llama LLM.

    Despite the *Model* suffix this class acts as the causal-LM entry point
    (see the module-level FIXME): ``init_vlm`` wires up the vision encoder,
    projector and wrapped ``self.llm``.
    """

    config_class = LlavaLlamaConfig
    # NOTE(review): HF convention spells this "inputs_embeds" — confirm
    # nothing depends on the current spelling before changing it.
    main_input_name = "input_embeds"
    supports_gradient_checkpointing = True

    def __init__(self, config: LlavaLlamaConfig = None, *args, **kwargs) -> None:
        super().__init__(config)
        return self.init_vlm(config=config, *args, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: bool = None,
        **kwargs,
    ):
        """Prefer the project's ``load_pretrained`` hook when present,
        otherwise fall back to the standard HF loader."""
        if hasattr(cls, "load_pretrained"):
            return cls.load_pretrained(
                pretrained_model_name_or_path,
                *model_args,
                config=config,
                cache_dir=cache_dir,
                ignore_mismatched_sizes=ignore_mismatched_sizes,
                force_download=force_download,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                use_safetensors=use_safetensors,
                **kwargs,
            )
        # BUG FIX: the original used super(LlavaLlamaModel).from_pretrained —
        # the one-argument super() is unbound and has no from_pretrained
        # attribute, so this branch raised AttributeError. Binding cls
        # dispatches to PreTrainedModel.from_pretrained as intended.
        return super(LlavaLlamaModel, cls).from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            **kwargs,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        images: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        seqlens_in_batch: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        dpo_forward: bool = False,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Multimodal forward: splice image features, optionally pack the
        batch, and run the wrapped LLM.

        When ``dpo_forward`` is True, returns ``(logits, labels)`` instead
        of the usual model output (for DPO training loops).
        """
        self.freezed_module_patch()
        if inputs_embeds is None:
            # Replace image placeholder tokens with projected vision features.
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values, labels, images
            )

        # Sequence packing is only possible when the wrapped LLM's forward
        # accepts a seqlens_in_batch argument.
        support_packing = "seqlens_in_batch" in inspect.signature(self.llm.forward).parameters

        if self.training and support_packing and not dpo_forward:
            (
                _,
                new_position_ids,
                new_attention_mask,
                _,
                new_inputs_embeds,
                new_labels,
                sorted_seqlens_in_batch,
            ) = self.repack_multimodal_data(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
            )
            if sorted_seqlens_in_batch is None:
                sorted_seqlens_in_batch = seqlens_in_batch
            new_input_ids = None
            past_key_values = None
        else:
            new_attention_mask = attention_mask
            new_position_ids = position_ids
            new_inputs_embeds = inputs_embeds
            new_labels = labels
            # NOTE(review): assumes attention_mask is not None here (the
            # gemma variant guards this case) — confirm callers always pass one.
            sorted_seqlens_in_batch = attention_mask.sum(-1).int()
            new_input_ids = input_ids

        if support_packing:
            outputs = self.llm.forward(
                input_ids=new_input_ids,
                attention_mask=new_attention_mask,
                position_ids=new_position_ids,
                past_key_values=past_key_values,
                inputs_embeds=new_inputs_embeds,
                labels=new_labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                seqlens_in_batch=sorted_seqlens_in_batch,
            )
        else:
            outputs = self.llm.forward(
                input_ids=new_input_ids,
                attention_mask=new_attention_mask,
                position_ids=new_position_ids,
                past_key_values=past_key_values,
                inputs_embeds=new_inputs_embeds,
                labels=new_labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # Loss rescale for SP & DP loss match
        # NOTE(review): assumes labels were provided so outputs.loss is set —
        # verify against label-free (pure inference) callers.
        loss_weight = calculate_loss_weight(new_labels)
        outputs.loss = outputs.loss * loss_weight

        if dpo_forward:
            return outputs.logits, new_labels
        return outputs
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# Make the custom config/model discoverable through the HF Auto* factories.
AutoConfig.register("llava_llama", LlavaLlamaConfig)
AutoModel.register(LlavaLlamaConfig, LlavaLlamaModel)
|
VILA/llava/model/language_model/llava_mistral.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
# This file is modified from https://github.com/haotian-liu/LLaVA/
|
| 18 |
+
|
| 19 |
+
from typing import List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
from transformers import AutoConfig, AutoModelForCausalLM, MistralConfig, MistralForCausalLM, MistralModel
|
| 24 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 25 |
+
|
| 26 |
+
from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
|
| 27 |
+
from .modeling_mixtral_long_context import MixtralForCausalLM, MixtralModel
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class LlavaMistralConfig(MistralConfig):
    """Mistral config tagged for LLaVA-Mistral Auto* dispatch (registered below)."""

    model_type = "llava_mistral"
    # Tensor-parallel degree read by __init__ below; fixed to 1 here.
    pretraining_tp = 1
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class LlavaMistralModel(MistralModel, LlavaMetaModel):
    """Mistral backbone combined with the LLaVA multimodal mixin."""

    config_class = LlavaMistralConfig

    def __init__(self, config: MistralConfig):
        super().__init__(config)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
|
| 43 |
+
config_class = LlavaMistralConfig
|
| 44 |
+
|
| 45 |
+
def __init__(self, config):
|
| 46 |
+
super(MistralForCausalLM, self).__init__(config)
|
| 47 |
+
self.model = LlavaMistralModel(config)
|
| 48 |
+
self.pretraining_tp = config.pretraining_tp
|
| 49 |
+
self.vocab_size = config.vocab_size
|
| 50 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 51 |
+
|
| 52 |
+
# Initialize weights and apply final processing
|
| 53 |
+
self.post_init()
|
| 54 |
+
|
| 55 |
+
def get_model(self):
|
| 56 |
+
return self.model
|
| 57 |
+
|
| 58 |
+
def get_lm_head(self):
|
| 59 |
+
return self.lm_head
|
| 60 |
+
|
| 61 |
+
def forward(
|
| 62 |
+
self,
|
| 63 |
+
input_ids: torch.LongTensor = None,
|
| 64 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 65 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 66 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 67 |
+
seqlens_in_batch: Optional[torch.LongTensor] = None,
|
| 68 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 69 |
+
labels: Optional[torch.LongTensor] = None,
|
| 70 |
+
use_cache: Optional[bool] = None,
|
| 71 |
+
output_attentions: Optional[bool] = None,
|
| 72 |
+
output_hidden_states: Optional[bool] = None,
|
| 73 |
+
images: Optional[torch.FloatTensor] = None,
|
| 74 |
+
return_dict: Optional[bool] = None,
|
| 75 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 76 |
+
if inputs_embeds is None:
|
| 77 |
+
(
|
| 78 |
+
input_ids,
|
| 79 |
+
position_ids,
|
| 80 |
+
attention_mask,
|
| 81 |
+
past_key_values,
|
| 82 |
+
inputs_embeds,
|
| 83 |
+
labels,
|
| 84 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
| 85 |
+
input_ids, position_ids, attention_mask, past_key_values, labels, images
|
| 86 |
+
)
|
| 87 |
+
if self.training:
|
| 88 |
+
(
|
| 89 |
+
_,
|
| 90 |
+
new_position_ids,
|
| 91 |
+
new_attention_mask,
|
| 92 |
+
_,
|
| 93 |
+
new_inputs_embeds,
|
| 94 |
+
new_labels,
|
| 95 |
+
sorted_seqlens_in_batch,
|
| 96 |
+
) = self.repack_multimodal_data(
|
| 97 |
+
input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels
|
| 98 |
+
)
|
| 99 |
+
if sorted_seqlens_in_batch is None:
|
| 100 |
+
sorted_seqlens_in_batch = seqlens_in_batch
|
| 101 |
+
new_input_ids = None
|
| 102 |
+
past_key_values = None
|
| 103 |
+
else:
|
| 104 |
+
new_attention_mask = attention_mask
|
| 105 |
+
new_position_ids = position_ids
|
| 106 |
+
new_inputs_embeds = inputs_embeds
|
| 107 |
+
new_labels = labels
|
| 108 |
+
sorted_seqlens_in_batch = attention_mask.sum(-1).int()
|
| 109 |
+
new_input_ids = input_ids
|
| 110 |
+
|
| 111 |
+
outputs = super().forward(
|
| 112 |
+
input_ids=new_input_ids,
|
| 113 |
+
attention_mask=new_attention_mask,
|
| 114 |
+
position_ids=new_position_ids,
|
| 115 |
+
past_key_values=past_key_values,
|
| 116 |
+
inputs_embeds=new_inputs_embeds,
|
| 117 |
+
labels=new_labels,
|
| 118 |
+
use_cache=use_cache,
|
| 119 |
+
output_attentions=output_attentions,
|
| 120 |
+
output_hidden_states=output_hidden_states,
|
| 121 |
+
return_dict=return_dict,
|
| 122 |
+
seqlens_in_batch=sorted_seqlens_in_batch,
|
| 123 |
+
)
|
| 124 |
+
return outputs
|
| 125 |
+
|
| 126 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
    """Build generation-step inputs, threading the optional ``images`` kwarg through to forward()."""
    images = kwargs.pop("images", None)
    model_inputs = super().prepare_inputs_for_generation(
        input_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        **kwargs,
    )
    if images is not None:
        model_inputs["images"] = images
    return model_inputs
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
# Register the custom config/model with HF's Auto* factories so that
# from_pretrained can resolve checkpoints whose model_type is "llava_mistral".
AutoConfig.register("llava_mistral", LlavaMistralConfig)
AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
|
VILA/llava/model/language_model/llava_mixtral.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
# This file is modified from https://github.com/haotian-liu/LLaVA/
|
| 18 |
+
|
| 19 |
+
from typing import List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
from transformers import AutoConfig, AutoModelForCausalLM, MixtralConfig, MixtralForCausalLM, MixtralModel
|
| 24 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 25 |
+
|
| 26 |
+
from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class LlavaMixtralConfig(MixtralConfig):
    """MixtralConfig with a distinct model_type so AutoConfig routes LLaVA-Mixtral checkpoints here."""

    model_type = "llava_mixtral"
    # NOTE(review): presumably mirrors LlamaConfig.pretraining_tp for interface
    # parity (1 = no tensor-parallel sharding) — confirm against callers.
    pretraining_tp = 1
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class LlavaMixtralModel(MixtralModel, LlavaMetaModel):
    """Mixtral backbone mixed with the LLaVA meta-model (multimodal plumbing)."""

    config_class = LlavaMixtralConfig

    def __init__(self, config: MixtralConfig):
        # MRO dispatches to MixtralModel.__init__ first; LlavaMetaModel adds no
        # extra construction arguments here.
        super().__init__(config)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class LlavaMixtralForCausalLM(MixtralForCausalLM, LlavaMetaForCausalLM):
    """Mixtral causal LM extended with LLaVA-style multimodal (image) inputs."""

    config_class = LlavaMixtralConfig

    def __init__(self, config):
        # Call the grandparent (MixtralPreTrainedModel) initializer directly so we
        # can install the multimodal LlavaMixtralModel instead of a plain MixtralModel.
        super(MixtralForCausalLM, self).__init__(config)
        self.model = LlavaMixtralModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        # Accessor used by the LLaVA meta classes to reach the backbone.
        return self.model

    def get_lm_head(self):
        return self.lm_head

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        seqlens_in_batch: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the LM; when raw ids are given, first splice image features into the embeddings.

        During training the (multimodal) batch is additionally repacked, after
        which the model is driven purely by ``inputs_embeds`` (ids are dropped).
        """
        if inputs_embeds is None:
            # Merge text tokens and image features into one embedding sequence;
            # this may rewrite ids/masks/labels to account for inserted image tokens.
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values, labels, images
            )
        if self.training:
            # Repack variable-length samples (helper defined on the LLaVA meta class)
            # and drop ids / cache for the training pass.
            (
                _,
                new_position_ids,
                new_attention_mask,
                _,
                new_inputs_embeds,
                new_labels,
                sorted_seqlens_in_batch,
            ) = self.repack_multimodal_data(
                input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels
            )
            if sorted_seqlens_in_batch is None:
                sorted_seqlens_in_batch = seqlens_in_batch
            new_input_ids = None
            past_key_values = None
        else:
            # Inference path: pass inputs through unchanged; derive per-sample
            # sequence lengths from the attention mask.
            new_attention_mask = attention_mask
            new_position_ids = position_ids
            new_inputs_embeds = inputs_embeds
            new_labels = labels
            sorted_seqlens_in_batch = attention_mask.sum(-1).int()
            new_input_ids = input_ids

        outputs = super().forward(
            input_ids=new_input_ids,
            attention_mask=new_attention_mask,
            position_ids=new_position_ids,
            past_key_values=past_key_values,
            inputs_embeds=new_inputs_embeds,
            labels=new_labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            seqlens_in_batch=sorted_seqlens_in_batch,
        )
        return outputs

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        # Thread the optional "images" kwarg through to forward() alongside the
        # standard generation inputs assembled by the parent class.
        images = kwargs.pop("images", None)
        _inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            _inputs["images"] = images
        return _inputs
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# Register with HF's Auto* factories so from_pretrained can resolve
# checkpoints whose model_type is "llava_mixtral".
AutoConfig.register("llava_mixtral", LlavaMixtralConfig)
AutoModelForCausalLM.register(LlavaMixtralConfig, LlavaMixtralForCausalLM)
|
VILA/llava/model/language_model/llava_mpt.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 Haotian Liu
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# This file is modified from https://github.com/haotian-liu/LLaVA/
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import math
|
| 19 |
+
import warnings
|
| 20 |
+
from typing import List, Optional, Tuple
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
| 25 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 26 |
+
|
| 27 |
+
from llava.model.llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
|
| 28 |
+
|
| 29 |
+
from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class LlavaMPTConfig(MPTConfig):
    """MPTConfig with a distinct model_type so AutoConfig routes LLaVA-MPT checkpoints here."""

    model_type = "llava_mpt"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class LlavaMPTModel(MPTModel, LlavaMetaModel):
    """MPT backbone mixed with the LLaVA meta-model (multimodal plumbing)."""

    config_class = LlavaMPTConfig

    def __init__(self, config: MPTConfig):
        # MPT configs use d_model; alias it to hidden_size, the attribute name
        # the LLaVA code expects.
        config.hidden_size = config.d_model
        super().__init__(config)

    def embed_tokens(self, x):
        # MPT names its token-embedding table `wte`; expose it under the
        # HF-standard method name used by the LLaVA code.
        return self.wte(x)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class LlavaMPTForCausalLM(MPTForCausalLM, LlavaMetaForCausalLM):
    """MPT causal LM extended with LLaVA-style multimodal (image) inputs.

    Logits are produced by tying the output projection to the token-embedding
    weights (see forward), so untied embeddings are rejected at construction.
    """

    config_class = LlavaMPTConfig
    supports_gradient_checkpointing = True

    def __init__(self, config):
        # Call the grandparent initializer directly so we can install the
        # multimodal LlavaMPTModel as `self.transformer`.
        super(MPTForCausalLM, self).__init__(config)

        if not config.tie_word_embeddings:
            raise ValueError("MPTForCausalLM only supports tied word embeddings")
        self.transformer = LlavaMPTModel(config)
        self.logit_scale = None
        if config.logit_scale is not None:
            logit_scale = config.logit_scale
            if isinstance(logit_scale, str):
                if logit_scale == "inv_sqrt_d_model":
                    logit_scale = 1 / math.sqrt(config.d_model)
                else:
                    raise ValueError(
                        f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
                    )
            self.logit_scale = logit_scale

    def get_model(self):
        return self.transformer

    def _set_gradient_checkpointing(self, module, value=False):
        # Toggle gradient checkpointing only on the backbone module.
        if isinstance(module, LlavaMPTModel):
            module.gradient_checkpointing = value

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        prefix_mask: Optional[torch.ByteTensor] = None,
        sequence_id: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        images=None,
    ):
        """Splice image features into the inputs, run MPT, and compute tied-weight logits/loss."""
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # Merge text tokens and image features; position_ids slot is unused for MPT.
        (
            input_ids,
            _,
            attention_mask,
            past_key_values,
            inputs_embeds,
            labels,
        ) = self.prepare_inputs_labels_for_multimodal(input_ids, None, attention_mask, past_key_values, labels, images)
        outputs = self.transformer(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            prefix_mask=prefix_mask,
            sequence_id=sequence_id,
            return_dict=return_dict,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
        )
        # FIXME: this is a hack to fix the multiple gpu inference issue in https://github.com/haotian-liu/LLaVA/issues/338
        # Tied-weight output projection: reuse the embedding matrix as lm head.
        logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight)
        if self.logit_scale is not None:
            if self.logit_scale == 0:
                warnings.warn(
                    f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs."
                )
            logits *= self.logit_scale
        loss = None
        if labels is not None:
            # Shift labels left by one for next-token prediction; the rolled-in
            # last position is masked with -100 (cross_entropy's ignore index).
            labels = torch.roll(labels, shifts=-1)
            labels[:, -1] = -100
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
        return CausalLMOutputWithPast(
            loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Assemble MPT-specific generation inputs (prefix/sequence masks, images)."""
        if inputs_embeds is not None:
            raise NotImplementedError("inputs_embeds is not implemented for MPT yet")
        attention_mask = kwargs["attention_mask"].bool()
        # MPT's cache layout assumes left padding: every sample must end on a real token.
        if attention_mask[:, -1].sum() != attention_mask.shape[0]:
            raise NotImplementedError("MPT does not support generation with right padding.")
        if self.transformer.attn_uses_sequence_id and self.training:
            sequence_id = torch.zeros_like(input_ids[:1])
        else:
            sequence_id = None
        if past_key_values is not None:
            # With a cache, only the newest token needs to be fed.
            input_ids = input_ids[:, -1].unsqueeze(-1)
        if self.transformer.prefix_lm:
            prefix_mask = torch.ones_like(attention_mask)
            if kwargs.get("use_cache") == False:
                raise NotImplementedError("MPT with prefix_lm=True does not support use_cache=False.")
        else:
            prefix_mask = None
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prefix_mask": prefix_mask,
            "sequence_id": sequence_id,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
            "images": kwargs.get("images", None),
        }
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# Register with HF's Auto* factories so from_pretrained can resolve
# checkpoints whose model_type is "llava_mpt".
AutoConfig.register("llava_mpt", LlavaMPTConfig)
AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM)
|
VILA/llava/model/language_model/modeling_mixtral_long_context.py
ADDED
|
@@ -0,0 +1,1657 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 4 |
+
# and OPT implementations in this library. It has been modified from its
|
| 5 |
+
# original forms to accommodate minor architectural differences compared
|
| 6 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
""" PyTorch Mixtral model."""
|
| 20 |
+
import inspect
|
| 21 |
+
import math
|
| 22 |
+
import random
|
| 23 |
+
import warnings
|
| 24 |
+
from typing import List, Optional, Tuple, Union
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
import torch.utils.checkpoint
|
| 29 |
+
from torch import nn
|
| 30 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 31 |
+
from transformers.activations import ACT2FN
|
| 32 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 33 |
+
from transformers.modeling_attn_mask_utils import (
|
| 34 |
+
_prepare_4d_causal_attention_mask,
|
| 35 |
+
_prepare_4d_causal_attention_mask_for_sdpa,
|
| 36 |
+
)
|
| 37 |
+
from transformers.modeling_outputs import (
|
| 38 |
+
MoeCausalLMOutputWithPast,
|
| 39 |
+
MoeModelOutputWithPast,
|
| 40 |
+
SequenceClassifierOutputWithPast,
|
| 41 |
+
)
|
| 42 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 43 |
+
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
|
| 44 |
+
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13
|
| 45 |
+
from transformers.utils import (
|
| 46 |
+
add_start_docstrings,
|
| 47 |
+
add_start_docstrings_to_model_forward,
|
| 48 |
+
is_flash_attn_2_available,
|
| 49 |
+
is_flash_attn_greater_or_equal_2_10,
|
| 50 |
+
logging,
|
| 51 |
+
replace_return_docstrings,
|
| 52 |
+
)
|
| 53 |
+
from transformers.utils.import_utils import is_torch_fx_available
|
| 54 |
+
|
| 55 |
+
if is_flash_attn_2_available():
|
| 56 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 57 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
| 58 |
+
|
| 59 |
+
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
| 60 |
+
|
| 61 |
+
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
|
| 62 |
+
# It means that the function will not be traced through and simply appear as a node in the graph.
|
| 63 |
+
if is_torch_fx_available():
|
| 64 |
+
if not is_torch_greater_or_equal_than_1_13:
|
| 65 |
+
import torch.fx
|
| 66 |
+
|
| 67 |
+
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
logger = logging.get_logger(__name__)
|
| 71 |
+
|
| 72 |
+
_CONFIG_FOR_DOC = "MixtralConfig"
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def load_balancing_loss_func(
    gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
) -> float:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
            shape [batch_size X sequence_length, num_experts].
        num_experts (`int`, *optional*):
            Number of experts.
        top_k (`int`, *optional*, defaults to 2):
            Number of experts each token is routed to.
        attention_mask (`torch.Tensor`, None):
            The attention_mask used in forward function,
            shape [batch_size X sequence_length] if not None.

    Returns:
        The auxiliary loss (a scalar tensor), or `0` when `gate_logits` is missing/not a tuple.
    """
    if gate_logits is None or not isinstance(gate_logits, tuple):
        return 0

    # Stack the per-layer router logits into one [layers * tokens, num_experts] tensor.
    compute_device = gate_logits[0].device
    concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)

    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    # [tokens, top_k, num_experts] one-hot of the chosen experts.
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the percentage of tokens routed to each expert
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)

        # Mask that zeroes out padding tokens, with the same shape as expert_mask.
        # BUGFIX: expand with `top_k` (previously a hard-coded 2), so the mask
        # broadcasts against expert_mask for any top_k value.
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )

        # Compute the percentage of tokens routed to each expert
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
            expert_attention_mask, dim=0
        )

        # Mask that zeroes out padding tokens, with the same shape as routing_weights.
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
            router_per_expert_attention_mask, dim=0
        )

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
| 152 |
+
def _get_unpad_data(attention_mask):
|
| 153 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
| 154 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
| 155 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
| 156 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
| 157 |
+
return (
|
| 158 |
+
indices,
|
| 159 |
+
cu_seqlens,
|
| 160 |
+
max_seqlen_in_batch,
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral
|
| 165 |
+
class MixtralRMSNorm(nn.Module):
    """Root-mean-square layer norm (equivalent to T5LayerNorm).

    Scales the input by the reciprocal RMS of its last dimension, then by a
    learned per-channel weight. Unlike LayerNorm there is no mean subtraction
    and no bias term.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """
        Args:
            hidden_size: size of the normalized (last) dimension.
            eps: small constant added to the variance for numerical stability.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        # Normalize in float32 for stability, then cast back to the input dtype.
        fp32_states = hidden_states.to(torch.float32)
        mean_square = fp32_states.pow(2).mean(-1, keepdim=True)
        normalized = fp32_states * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral
|
| 183 |
+
class MixtralRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) with a precomputed cos/sin cache.

    The tables are built up to ``max_position_embeddings`` at construction and
    transparently regrown whenever a longer sequence length is requested in
    ``forward``.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Inverse frequencies base**(-2i/dim) over the even channel indices.
        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim
        self.register_buffer("inv_freq", 1.0 / (self.base**exponents), persistent=False)

        # Build the cache eagerly so `torch.jit.trace` sees fully-formed buffers.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        positions = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        angles = torch.outer(positions, self.inv_freq)
        # Duplicating (rather than interleaving) the angle table differs from
        # the paper's permutation but yields the same result with rotate_half.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]; only its dtype and
        # device are consulted here.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class MixtralLinearScalingRotaryEmbedding(MixtralRotaryEmbedding):
    """MixtralRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Must be assigned before super().__init__, which triggers the first
        # cache build and therefore calls our _set_cos_sin_cache override.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        positions = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        # Linear scaling: stretch positions so the original context window
        # covers scaling_factor times as many tokens.
        positions = positions / self.scaling_factor

        angles = torch.outer(positions, self.inv_freq)
        # Same duplicated-halves layout as the base class, matching rotate_half.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
| 239 |
+
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    Splits the last dimension in two and returns ``(-second_half, first_half)``
    concatenated — the 90-degree rotation used by RoPE's duplicated-halves layout.
    """
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
|
| 247 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they broadcast against
            q and k. cos[position_ids] / sin[position_ids] have shape [batch_size, seq_len, head_dim]; use
            unsqueeze_dim=1 when q and k are [batch_size, heads, seq_len, head_dim], and unsqueeze_dim=2 when they
            are [batch_size, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    # Gather the per-position tables, then add a broadcast axis for the heads.
    cos_selected = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_selected = sin[position_ids].unsqueeze(unsqueeze_dim)
    rotated_q = (q * cos_selected) + (rotate_half(q) * sin_selected)
    rotated_k = (k * cos_selected) + (rotate_half(k) * sin_selected)
    return rotated_q, rotated_k
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
| 276 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        # Nothing to do for multi-head attention (one query head per KV head).
        return hidden_states
    # expand() creates a broadcast view (no copy); reshape then materializes
    # the repeated heads adjacent to their source head.
    expanded = hidden_states[:, :, None, :, :].expand(bsz, kv_heads, n_rep, seq_len, dim)
    return expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
|
| 289 |
+
class MixtralAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
    and "Generating Long Sequences with Sparse Transformers".

    Eager (matmul + softmax) implementation; the flash-attn and SDPA variants
    subclass this and only replace `forward`.
    """

    def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
        """
        Args:
            config: model configuration; head counts, hidden size, RoPE and
                dropout settings are read from it.
            layer_idx: index of this layer in the model — required by the
                v4.36+ `Cache` API to address the per-layer KV entries.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Grouped-query attention: each KV head serves this many query heads.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        # All projections are bias-free (Mixtral convention).
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self._init_rope()

    def _init_rope(self):
        # Select the rotary-embedding implementation from config.rope_scaling.
        if self.config.rope_scaling is None:
            self.rotary_emb = MixtralRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = MixtralLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "randomlinear":
                # NOTE(review): no module is created for "randomlinear"; the
                # caller is presumably expected to pass `rotary_emb` into the
                # flash-attention forward (which accepts that argument). This
                # eager forward would crash on self.rotary_emb(...) — confirm.
                self.rotary_emb = None
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # (bsz, seq, hidden) -> (bsz, num_heads, seq, head_dim)
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Eager attention forward pass.

        Returns a tuple of (attn_output, attn_weights or None, past_key_value).
        `attention_mask`, when given, must be a 4-D additive mask of shape
        (bsz, 1, q_len, kv_seq_len).
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Reshape to (bsz, heads, q_len, head_dim) for per-head attention.
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            # Account for previously cached tokens when sizing RoPE tables.
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Scaled dot-product scores: (bsz, heads, q_len, kv_seq_len).
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

            # Additive mask: masked positions hold large negative values.
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Back to (bsz, q_len, hidden) for the output projection.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral
|
| 436 |
+
class MixtralFlashAttention2(MixtralAttention):
    """
    Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        rotary_emb=None,
        **kwargs,
    ):
        """Flash-attention forward pass.

        Unlike the eager variant, `attention_mask` here is the 2-D padding
        mask (bsz, seq_len) consumed by the unpad helpers. The extra
        `rotary_emb` argument lets the caller inject an external rotary
        module (long-context modification); when None, self.rotary_emb is used.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

            # overwrite attention_mask with padding_mask
            attention_mask = kwargs.pop("padding_mask")
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, heads, q_len, head_dim) layout for RoPE / cache update.
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        # Because the input can be padded, the absolute sequence length depends on the max position id.
        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
        if rotary_emb is None:
            rotary_emb = self.rotary_emb
        cos, sin = rotary_emb(value_states, seq_len=rotary_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        # Sliding-window kernels are only usable when flash-attn supports them
        # and the sequence actually exceeds the window.
        use_sliding_windows = (
            _flash_supports_window_size
            and getattr(self.config, "sliding_window", None) is not None
            and kv_seq_len > self.config.sliding_window
        )

        if not _flash_supports_window_size:
            logger.warning_once(
                "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
                " make sure to upgrade flash-attn library."
            )

        if past_key_value is not None:
            # Activate slicing cache only if the config has a value `sliding_windows` attribute
            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
            if (
                getattr(self.config, "sliding_window", None) is not None
                and kv_seq_len > self.config.sliding_window
                and cache_has_contents
            ):
                # Keep only the last (sliding_window - 1) cached positions.
                slicing_tokens = 1 - self.config.sliding_window

                past_key = past_key_value[self.layer_idx][0]
                past_value = past_key_value[self.layer_idx][1]

                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                past_value = past_value[:, :, slicing_tokens:, :].contiguous()

                if past_key.shape[-2] != self.config.sliding_window - 1:
                    raise ValueError(
                        f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
                        f" {past_key.shape}"
                    )

                if attention_mask is not None:
                    # Trim the mask to the window and extend it for the new token.
                    attention_mask = attention_mask[:, slicing_tokens:]
                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)

            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reashape to the expected shape for Flash Attention
        # flash-attn expects (bsz, seq_len, heads, head_dim).
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            use_sliding_windows=use_sliding_windows,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        # NOTE(review): attn_weights is only bound in this branch; with
        # output_attentions=True this would raise NameError (flash-attn does
        # not return weights) — confirm callers never request attentions here.
        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
        use_sliding_windows=False,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`int`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
            use_sliding_windows (`bool`, *optional*):
                Whether to activate sliding window attention.
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            # Strip padding so flash-attn sees packed variable-length sequences.
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            if not use_sliding_windows:
                attn_output_unpad = flash_attn_varlen_func(
                    query_states,
                    key_states,
                    value_states,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_q=max_seqlen_in_batch_q,
                    max_seqlen_k=max_seqlen_in_batch_k,
                    dropout_p=dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                )
            else:
                attn_output_unpad = flash_attn_varlen_func(
                    query_states,
                    key_states,
                    value_states,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_q=max_seqlen_in_batch_q,
                    max_seqlen_k=max_seqlen_in_batch_k,
                    dropout_p=dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                    window_size=(self.config.sliding_window, self.config.sliding_window),
                )

            # Re-pad the packed output back to (batch, query_length, ...).
            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            if not use_sliding_windows:
                attn_output = flash_attn_func(
                    query_states,
                    key_states,
                    value_states,
                    dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                )
            else:
                attn_output = flash_attn_func(
                    query_states,
                    key_states,
                    value_states,
                    dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                    window_size=(self.config.sliding_window, self.config.sliding_window),
                )

        return attn_output

    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        """Strip padding tokens from q/k/v and build the varlen metadata
        (indices, cumulative lengths, max lengths) flash-attn's varlen API needs."""
        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape

        # On the first iteration we need to properly re-create the padding mask
        # by slicing it on the proper place
        if kv_seq_len != attention_mask.shape[-1]:
            attention_mask_num_tokens = attention_mask.shape[-1]
            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]

        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)

        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)

        if query_length == kv_seq_len:
            # Prefill: queries share the key layout, reuse the key metadata.
            query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # Decoding: exactly one query token per sequence.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
|
| 729 |
+
|
| 730 |
+
|
| 731 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral
class MixtralSdpaAttention(MixtralAttention):
    """
    Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `MixtralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Adapted from MixtralAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        rotary_emb=None,  # optional externally shared rotary embedding; falls back to self.rotary_emb when None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # SDPA cannot return attention weights, so fall back to the eager (manual) path when they are requested.
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "MixtralModel is using MixtralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        bsz, q_len, _ = hidden_states.size()

        # Project to query/key/value and split into heads: (bsz, heads, q_len, head_dim).
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Total key/value length includes any tokens already held in the cache for this layer.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        if rotary_emb is None:
            rotary_emb = self.rotary_emb
        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Duplicate KV heads so grouped-query attention matches the number of query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
            is_causal=self.is_causal and attention_mask is None and q_len > 1,
        )

        # Back to (bsz, q_len, hidden_size), then apply the output projection.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        # Attention weights are always None on the SDPA path.
        return attn_output, None, past_key_value
|
| 821 |
+
|
| 822 |
+
|
| 823 |
+
# Maps `config._attn_implementation` to the attention class instantiated by each decoder layer.
MIXTRAL_ATTENTION_CLASSES = {
    "eager": MixtralAttention,
    "flash_attention_2": MixtralFlashAttention2,
    "sdpa": MixtralSdpaAttention,
}
|
| 828 |
+
|
| 829 |
+
|
| 830 |
+
class MixtralBlockSparseTop2MLP(nn.Module):
    """One Mixtral expert: a gated feed-forward network.

    `w1` and `w3` project hidden -> intermediate (gate and up branches),
    `w2` projects the gated product back down to the hidden size.
    """

    def __init__(self, config: MixtralConfig):
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size

        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        """Compute w2(act(w1(x)) * w3(x))."""
        gated = self.act_fn(self.w1(hidden_states))
        gated = gated * self.w3(hidden_states)
        return self.w2(gated)
|
| 846 |
+
|
| 847 |
+
|
| 848 |
+
class MixtralBLockSparseTop2MLP(MixtralBlockSparseTop2MLP):
    """Deprecated alias for `MixtralBlockSparseTop2MLP` (note the misspelled capital "L").

    Warns once on construction, then behaves exactly like the replacement class.
    """

    def __init__(self, *args, **kwargs):
        logger.warning_once(
            "MixtralBLockSparseTop2MLP is deprecated by MixtralBlockSparseTop2MLP and will be removed in v4.40."
        )
        super().__init__(*args, **kwargs)
|
| 854 |
+
|
| 855 |
+
|
| 856 |
+
class MixtralSparseMoeBlock(nn.Module):
    """
    This implementation is
    strictly equivalent to standard MoE with full capacity (no
    dropped tokens). It's faster since it formulates MoE operations
    in terms of block-sparse operations to accomodate imbalanced
    assignments of tokens to experts, whereas standard MoE either
    (1) drop tokens at the cost of reduced performance or (2) set
    capacity factor to number of experts and thus waste computation
    and memory on padding.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok  # experts routed per token

        # gating
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

        self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route each token to its top-k experts; return (mixed hidden states, router logits)."""
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        # Flatten to (batch * seq_len, hidden_dim) so tokens are routed individually.
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        # Softmax in float32 for stability, keep only the top-k experts per token,
        # and renormalize the kept weights to sum to 1.
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be sollicitated
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            # top_x: flattened token indices routed to this expert; idx: which top-k slot chose it.
            idx, top_x = torch.where(expert_mask[expert_idx])

            if top_x.shape[0] == 0:
                if self.training:
                    # NOTE(review): when an expert receives no tokens during training, it is run
                    # on a zeroed token and the result added at index 0. The output is exactly
                    # zero (the expert MLP is bias-free, so f(0) == 0), so values are unchanged —
                    # presumably this keeps every expert in the autograd graph (e.g. for
                    # distributed gradient sync); confirm against the training setup.
                    top_x_ = torch.zeros(1).to(hidden_states.device).to(torch.int32)
                    top_x_list = top_x_.tolist()
                    current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
                    fake_state = expert_layer(current_state * 0)
                    final_hidden_states.index_add_(0, top_x_, fake_state.to(hidden_states.dtype))
                continue

            # in torch it is faster to index using lists than torch tensors
            top_x_list = top_x.tolist()
            idx_list = idx.tolist()

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]

            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits
|
| 930 |
+
|
| 931 |
+
|
| 932 |
+
class MixtralDecoderLayer(nn.Module):
    """One Mixtral decoder layer: pre-norm self-attention with a residual
    connection, followed by a pre-norm sparse-MoE feed-forward block with a
    residual connection."""

    def __init__(self, config: MixtralConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # Attention implementation (eager / flash_attention_2 / sdpa) is chosen from the config.
        self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

        self.block_sparse_moe = MixtralSparseMoeBlock(config)
        self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        rotary_emb=None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_router_logits (`bool`, *optional*):
                Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
                should not be returned during inference.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
        """
        # Fix: this docstring previously sat *after* the `padding_mask` check below, so
        # Python treated it as a dead string expression rather than the function docstring.
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention (pre-norm; residual added afterwards)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            rotary_emb=rotary_emb,
        )
        hidden_states = residual + hidden_states

        # Fully Connected: sparse-MoE block (pre-norm; residual added afterwards)
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states, router_logits = self.block_sparse_moe(hidden_states)
        hidden_states = residual + hidden_states

        # Assemble outputs in the positional order downstream code relies on:
        # (hidden_states[, attn_weights][, present_key_value][, router_logits])
        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        if output_router_logits:
            outputs += (router_logits,)

        return outputs
|
| 1010 |
+
|
| 1011 |
+
|
| 1012 |
+
MIXTRAL_START_DOCSTRING = r"""
|
| 1013 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 1014 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 1015 |
+
etc.)
|
| 1016 |
+
|
| 1017 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 1018 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 1019 |
+
and behavior.
|
| 1020 |
+
|
| 1021 |
+
Parameters:
|
| 1022 |
+
config ([`MixtralConfig`]):
|
| 1023 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 1024 |
+
load the weights associated with the model, only the configuration. Check out the
|
| 1025 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 1026 |
+
"""
|
| 1027 |
+
|
| 1028 |
+
|
| 1029 |
+
@add_start_docstrings(
    "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
    MIXTRAL_START_DOCSTRING,
)
# Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral
class MixtralPreTrainedModel(PreTrainedModel):
    config_class = MixtralConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["MixtralDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialize `nn.Linear`/`nn.Embedding` weights from N(0, initializer_range);
        zero out biases and the padding embedding row."""
        std = self.config.initializer_range
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
|
| 1054 |
+
|
| 1055 |
+
|
| 1056 |
+
MIXTRAL_INPUTS_DOCSTRING = r"""
|
| 1057 |
+
Args:
|
| 1058 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 1059 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 1060 |
+
it.
|
| 1061 |
+
|
| 1062 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 1063 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 1064 |
+
|
| 1065 |
+
[What are input IDs?](../glossary#input-ids)
|
| 1066 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1067 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 1068 |
+
|
| 1069 |
+
- 1 for tokens that are **not masked**,
|
| 1070 |
+
- 0 for tokens that are **masked**.
|
| 1071 |
+
|
| 1072 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 1073 |
+
|
| 1074 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 1075 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 1076 |
+
|
| 1077 |
+
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
| 1078 |
+
`past_key_values`).
|
| 1079 |
+
|
| 1080 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
| 1081 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
| 1082 |
+
information on the default strategy.
|
| 1083 |
+
|
| 1084 |
+
- 1 indicates the head is **not masked**,
|
| 1085 |
+
- 0 indicates the head is **masked**.
|
| 1086 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1087 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 1088 |
+
config.n_positions - 1]`.
|
| 1089 |
+
|
| 1090 |
+
[What are position IDs?](../glossary#position-ids)
|
| 1091 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 1092 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
| 1093 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
| 1094 |
+
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
| 1095 |
+
|
| 1096 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
| 1097 |
+
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
| 1098 |
+
|
| 1099 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
| 1100 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
| 1101 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
| 1102 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 1103 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 1104 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 1105 |
+
model's internal embedding lookup matrix.
|
| 1106 |
+
use_cache (`bool`, *optional*):
|
| 1107 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 1108 |
+
`past_key_values`).
|
| 1109 |
+
output_attentions (`bool`, *optional*):
|
| 1110 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 1111 |
+
tensors for more detail.
|
| 1112 |
+
output_hidden_states (`bool`, *optional*):
|
| 1113 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 1114 |
+
more detail.
|
| 1115 |
+
output_router_logits (`bool`, *optional*):
|
| 1116 |
+
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
|
| 1117 |
+
should not be returned during inference.
|
| 1118 |
+
return_dict (`bool`, *optional*):
|
| 1119 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 1120 |
+
"""
|
| 1121 |
+
|
| 1122 |
+
|
| 1123 |
+
@add_start_docstrings(
    "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
    MIXTRAL_START_DOCSTRING,
)
# Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral
class MixtralModel(MixtralPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]

    Args:
        config: MixtralConfig
    """

    def __init__(self, config: MixtralConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # Ignore copy
    @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MoeModelOutputWithPast]:
        # Resolve per-call flags, falling back to the config defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        past_key_values_length = 0

        # Caching is incompatible with gradient checkpointing; disable it during checkpointed training.
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if use_cache:
            # Accept both legacy tuple caches and the newer `Cache` objects; remember which
            # format the caller used so the same format is returned below.
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            # Positions continue from the cached prefix length.
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Flash attention requires left padding when generating with a cache.
        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
            if is_padding_right:
                raise ValueError(
                    "You are attempting to perform batched generation with padding_side='right'"
                    " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                )

        if self._attn_implementation == "flash_attention_2":
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif self._attn_implementation == "sdpa" and not output_attentions:
            # output_attentions=True can not be supported when using SDPA, and we fall back on
            # the manual implementation that requires a 4D causal mask in all cases.
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
                sliding_window=self.config.sliding_window,
            )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits else None
        next_decoder_cache = None

        # NOTE(review): rotary_emb stays None here, so each attention layer falls back to
        # its own `self.rotary_emb` (see the `if rotary_emb is None` branch in the
        # attention forward).
        rotary_emb = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    output_router_logits,
                    use_cache,
                    rotary_emb,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    output_router_logits=output_router_logits,
                    use_cache=use_cache,
                    rotary_emb=rotary_emb,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache sits right after the optional attention weights in the layer output tuple.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

            if output_router_logits:
                all_router_logits += (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            # Return the cache in the same format the caller supplied it.
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache

        if not return_dict:
            # Tuple output: drop the entries that were not requested (they are None).
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
                if v is not None
            )
        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            router_logits=all_router_logits,
        )
|
| 1324 |
+
|
| 1325 |
+
|
| 1326 |
+
class MixtralForCausalLM(MixtralPreTrainedModel):
|
| 1327 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 1328 |
+
|
| 1329 |
+
    def __init__(self, config):
        """Build the causal-LM model: `MixtralModel` backbone plus a linear LM head."""
        super().__init__(config)
        self.model = MixtralModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Router auxiliary (load-balancing) loss settings, read from the config.
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_local_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        # Initialize weights and apply final processing
        self.post_init()
|
| 1339 |
+
|
| 1340 |
+
def get_input_embeddings(self):
|
| 1341 |
+
return self.model.embed_tokens
|
| 1342 |
+
|
| 1343 |
+
    def set_input_embeddings(self, value):
        """Replace the decoder's token-embedding module with `value`."""
        self.model.embed_tokens = value
|
| 1345 |
+
|
| 1346 |
+
def get_output_embeddings(self):
|
| 1347 |
+
return self.lm_head
|
| 1348 |
+
|
| 1349 |
+
    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head projection layer with `new_embeddings`."""
        self.lm_head = new_embeddings
|
| 1351 |
+
|
| 1352 |
+
    def set_decoder(self, decoder):
        """Swap in a different decoder backbone."""
        self.model = decoder
|
| 1354 |
+
|
| 1355 |
+
def get_decoder(self):
|
| 1356 |
+
return self.model
|
| 1357 |
+
|
| 1358 |
+
@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
|
| 1359 |
+
@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
| 1360 |
+
# Ignore copy
|
| 1361 |
+
def forward(
|
| 1362 |
+
self,
|
| 1363 |
+
input_ids: torch.LongTensor = None,
|
| 1364 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1365 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1366 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 1367 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1368 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1369 |
+
use_cache: Optional[bool] = None,
|
| 1370 |
+
output_attentions: Optional[bool] = None,
|
| 1371 |
+
output_hidden_states: Optional[bool] = None,
|
| 1372 |
+
output_router_logits: Optional[bool] = None,
|
| 1373 |
+
return_dict: Optional[bool] = None,
|
| 1374 |
+
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
|
| 1375 |
+
r"""
|
| 1376 |
+
Args:
|
| 1377 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1378 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 1379 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 1380 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 1381 |
+
|
| 1382 |
+
Returns:
|
| 1383 |
+
|
| 1384 |
+
Example:
|
| 1385 |
+
|
| 1386 |
+
```python
|
| 1387 |
+
>>> from transformers import AutoTokenizer, MixtralForCausalLM
|
| 1388 |
+
|
| 1389 |
+
>>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
|
| 1390 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
|
| 1391 |
+
|
| 1392 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
| 1393 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 1394 |
+
|
| 1395 |
+
>>> # Generate
|
| 1396 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 1397 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 1398 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
| 1399 |
+
```"""
|
| 1400 |
+
|
| 1401 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1402 |
+
output_router_logits = (
|
| 1403 |
+
output_router_logits if output_router_logits is not None else self.config.output_router_logits
|
| 1404 |
+
)
|
| 1405 |
+
|
| 1406 |
+
output_hidden_states = (
|
| 1407 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1408 |
+
)
|
| 1409 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1410 |
+
|
| 1411 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 1412 |
+
outputs = self.model(
|
| 1413 |
+
input_ids=input_ids,
|
| 1414 |
+
attention_mask=attention_mask,
|
| 1415 |
+
position_ids=position_ids,
|
| 1416 |
+
past_key_values=past_key_values,
|
| 1417 |
+
inputs_embeds=inputs_embeds,
|
| 1418 |
+
use_cache=use_cache,
|
| 1419 |
+
output_attentions=output_attentions,
|
| 1420 |
+
output_hidden_states=output_hidden_states,
|
| 1421 |
+
output_router_logits=output_router_logits,
|
| 1422 |
+
return_dict=return_dict,
|
| 1423 |
+
)
|
| 1424 |
+
|
| 1425 |
+
hidden_states = outputs[0]
|
| 1426 |
+
logits = self.lm_head(hidden_states)
|
| 1427 |
+
logits = logits.float()
|
| 1428 |
+
|
| 1429 |
+
loss = None
|
| 1430 |
+
if labels is not None:
|
| 1431 |
+
# Shift so that tokens < n predict n
|
| 1432 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 1433 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 1434 |
+
# Flatten the tokens
|
| 1435 |
+
loss_fct = CrossEntropyLoss()
|
| 1436 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
| 1437 |
+
shift_labels = shift_labels.view(-1)
|
| 1438 |
+
# Enable model parallelism
|
| 1439 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 1440 |
+
loss = loss_fct(shift_logits, shift_labels)
|
| 1441 |
+
|
| 1442 |
+
aux_loss = None
|
| 1443 |
+
if False: # output_router_logits:
|
| 1444 |
+
aux_loss = load_balancing_loss_func(
|
| 1445 |
+
outputs.router_logits if return_dict else outputs[-1],
|
| 1446 |
+
self.num_experts,
|
| 1447 |
+
self.num_experts_per_tok,
|
| 1448 |
+
attention_mask,
|
| 1449 |
+
)
|
| 1450 |
+
if labels is not None:
|
| 1451 |
+
loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
|
| 1452 |
+
|
| 1453 |
+
if not return_dict:
|
| 1454 |
+
output = (logits,) + outputs[1:]
|
| 1455 |
+
if output_router_logits:
|
| 1456 |
+
output = (aux_loss,) + output
|
| 1457 |
+
return (loss,) + output if loss is not None else output
|
| 1458 |
+
|
| 1459 |
+
return MoeCausalLMOutputWithPast(
|
| 1460 |
+
loss=loss,
|
| 1461 |
+
aux_loss=aux_loss,
|
| 1462 |
+
logits=logits,
|
| 1463 |
+
past_key_values=outputs.past_key_values,
|
| 1464 |
+
hidden_states=outputs.hidden_states,
|
| 1465 |
+
attentions=outputs.attentions,
|
| 1466 |
+
router_logits=outputs.router_logits,
|
| 1467 |
+
)
|
| 1468 |
+
|
| 1469 |
+
def prepare_inputs_for_generation(
|
| 1470 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
| 1471 |
+
):
|
| 1472 |
+
# Omit tokens covered by past_key_values
|
| 1473 |
+
if past_key_values is not None:
|
| 1474 |
+
if isinstance(past_key_values, Cache):
|
| 1475 |
+
cache_length = past_key_values.get_seq_length()
|
| 1476 |
+
past_length = past_key_values.seen_tokens
|
| 1477 |
+
max_cache_length = past_key_values.get_max_length()
|
| 1478 |
+
else:
|
| 1479 |
+
cache_length = past_length = past_key_values[0][0].shape[2]
|
| 1480 |
+
max_cache_length = None
|
| 1481 |
+
|
| 1482 |
+
# Keep only the unprocessed tokens:
|
| 1483 |
+
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
| 1484 |
+
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
| 1485 |
+
# input)
|
| 1486 |
+
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
| 1487 |
+
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
| 1488 |
+
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
| 1489 |
+
# input_ids based on the past_length.
|
| 1490 |
+
elif past_length < input_ids.shape[1]:
|
| 1491 |
+
input_ids = input_ids[:, past_length:]
|
| 1492 |
+
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
| 1493 |
+
|
| 1494 |
+
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
| 1495 |
+
if (
|
| 1496 |
+
max_cache_length is not None
|
| 1497 |
+
and attention_mask is not None
|
| 1498 |
+
and cache_length + input_ids.shape[1] > max_cache_length
|
| 1499 |
+
):
|
| 1500 |
+
attention_mask = attention_mask[:, -max_cache_length:]
|
| 1501 |
+
|
| 1502 |
+
position_ids = kwargs.get("position_ids", None)
|
| 1503 |
+
if attention_mask is not None and position_ids is None:
|
| 1504 |
+
# create position_ids on the fly for batch generation
|
| 1505 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 1506 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 1507 |
+
if past_key_values:
|
| 1508 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
| 1509 |
+
|
| 1510 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 1511 |
+
if inputs_embeds is not None and past_key_values is None:
|
| 1512 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
| 1513 |
+
else:
|
| 1514 |
+
model_inputs = {"input_ids": input_ids}
|
| 1515 |
+
|
| 1516 |
+
model_inputs.update(
|
| 1517 |
+
{
|
| 1518 |
+
"position_ids": position_ids,
|
| 1519 |
+
"past_key_values": past_key_values,
|
| 1520 |
+
"use_cache": kwargs.get("use_cache"),
|
| 1521 |
+
"attention_mask": attention_mask,
|
| 1522 |
+
}
|
| 1523 |
+
)
|
| 1524 |
+
return model_inputs
|
| 1525 |
+
|
| 1526 |
+
@staticmethod
|
| 1527 |
+
def _reorder_cache(past_key_values, beam_idx):
|
| 1528 |
+
reordered_past = ()
|
| 1529 |
+
for layer_past in past_key_values:
|
| 1530 |
+
reordered_past += (
|
| 1531 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
| 1532 |
+
)
|
| 1533 |
+
return reordered_past
|
| 1534 |
+
|
| 1535 |
+
|
| 1536 |
+
@add_start_docstrings(
    """
    The Mixtral Model transformer with a sequence classification head on top (linear layer).

    [`MixtralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    MIXTRAL_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL
class MixtralForSequenceClassification(MixtralPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = MixtralModel(config)
        # Classification head over the last (non-padding) token's hidden state.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        # Per-token logits; pooled to one position per sequence below.
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        # Without a pad token we cannot locate the last real token per row,
        # so batched inputs are ambiguous.
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                # inputs_embeds: padding positions unknowable, fall back to last token.
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type once and cache it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
VILA/llava/model/language_model/mpt/adapt_tokenizer.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
from typing import Union
|
| 18 |
+
|
| 19 |
+
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
|
| 20 |
+
|
| 21 |
+
Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
|
| 22 |
+
NUM_SENTINEL_TOKENS: int = 100
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
    """Extend a tokenizer with MOD sentinel tokens and a padding token.

    Registers ``<extra_id_0>`` … ``<extra_id_{N-1}>`` as special tokens for
    mixture-of-denoiser tasks, adds a ``<pad>`` token when the tokenizer has
    none, and stores the sentinel token ids on the tokenizer as
    ``sentinel_token_ids``. Re-running on an already-adapted tokenizer adds
    nothing new.
    """
    sentinel_tokens = [f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)]
    tokenizer.add_tokens(sentinel_tokens, special_tokens=True)
    if tokenizer.pad_token is None:
        tokenizer.add_tokens("<pad>", special_tokens=True)
        tokenizer.pad_token = "<pad>"
        assert tokenizer.pad_token_id is not None
    # Tokenize all sentinels as one string and cache their ids for later lookup.
    joined_sentinels = "".join(f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS))
    tokenizer.sentinel_token_ids = tokenizer(joined_sentinels, add_special_tokens=False).input_ids
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class AutoTokenizerForMOD(AutoTokenizer):
    """AutoTokenizer + Adaptation for MOD.

    A simple wrapper around AutoTokenizer to make instantiating
    an MOD-adapted tokenizer a bit easier.

    MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>),
    a padding token, and a property to get the token ids of the
    sentinel tokens.
    """

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """See `AutoTokenizer.from_pretrained` docstring."""
        # Load normally, then apply the MOD adaptation in place before returning.
        tokenizer = super().from_pretrained(*args, **kwargs)
        adapt_tokenizer_for_denoising(tokenizer)
        return tokenizer
|
VILA/llava/model/language_model/mpt/attention.py
ADDED
|
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
"""Attention layers."""
|
| 18 |
+
import math
|
| 19 |
+
import warnings
|
| 20 |
+
from typing import Optional
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
from einops import rearrange
|
| 25 |
+
from packaging import version
|
| 26 |
+
from torch import nn
|
| 27 |
+
|
| 28 |
+
from .norm import LPLayerNorm
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool):
|
| 32 |
+
if original_is_causal and num_query_tokens != num_key_tokens:
|
| 33 |
+
if num_query_tokens != 1:
|
| 34 |
+
raise NotImplementedError(
|
| 35 |
+
"MPT does not support query and key with different number of tokens, unless number of query tokens is 1."
|
| 36 |
+
)
|
| 37 |
+
else:
|
| 38 |
+
return False
|
| 39 |
+
return original_is_causal
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def scaled_multihead_dot_product_attention(
    query,
    key,
    value,
    n_heads,
    past_key_value=None,
    softmax_scale=None,
    attn_bias=None,
    key_padding_mask=None,
    is_causal=False,
    dropout_p=0.0,
    training=False,
    needs_weights=False,
    multiquery=False,
):
    """Reference (pure torch) scaled dot-product attention with optional
    additive bias, key-padding mask, causal mask, and KV cache.

    Inputs are ``(batch, seq, n_heads * head_dim)``; returns a tuple
    ``(output, attn_weights_or_None, past_key_value)``.
    """
    # q: (b, h, s_q, d); k is laid out transposed as (b, h, d, s_k) so that
    # q.matmul(k) directly yields the (s_q, s_k) score matrix.
    q = rearrange(query, "b s (h d) -> b h s d", h=n_heads)
    kv_n_heads = 1 if multiquery else n_heads  # multiquery: one shared K/V head
    k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads)
    v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads)
    if past_key_value is not None:
        # Prepend cached keys/values; note k concatenates on its seq dim (3)
        # while v concatenates on its seq dim (2) due to the layouts above.
        if len(past_key_value) != 0:
            k = torch.cat([past_key_value[0], k], dim=3)
            v = torch.cat([past_key_value[1], v], dim=2)
        past_key_value = (k, v)
    (b, _, s_q, d) = q.shape
    s_k = k.size(-1)
    if softmax_scale is None:
        softmax_scale = 1 / math.sqrt(d)
    attn_weight = q.matmul(k) * softmax_scale
    if attn_bias is not None:
        # Crop the bias to the trailing (s_q, s_k) window, then require it to
        # be broadcastable over the score matrix.
        _s_q = max(0, attn_bias.size(2) - s_q)
        _s_k = max(0, attn_bias.size(3) - s_k)
        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
        if (
            attn_bias.size(-1) != 1
            and attn_bias.size(-1) != s_k
            or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q)
        ):
            raise RuntimeError(
                f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}."
            )
        attn_weight = attn_weight + attn_bias
    min_val = torch.finfo(q.dtype).min
    if key_padding_mask is not None:
        if attn_bias is not None:
            warnings.warn(
                "Propogating key_padding_mask to the attention module "
                + "and applying it within the attention module can cause "
                + "unneccessary computation/memory usage. Consider integrating "
                + "into attn_bias once and passing that to each attention "
                + "module instead."
            )
        # Mask out padded key positions with the dtype's minimum before softmax.
        attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
    if is_causal and (not q.size(2) == 1):
        # Build a lower-triangular keep-mask, invert it, and take the trailing
        # (s_q, s_k) corner so cached prefixes stay fully visible.
        s = max(s_q, s_k)
        causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
        causal_mask = causal_mask.tril()
        causal_mask = causal_mask.to(torch.bool)
        causal_mask = ~causal_mask
        causal_mask = causal_mask[-s_q:, -s_k:]
        attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
    attn_weight = torch.softmax(attn_weight, dim=-1)
    if dropout_p:
        attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
    out = attn_weight.to(v.dtype).matmul(v)
    # Fold heads back into the hidden dimension: (b, s_q, h * d).
    out = rearrange(out, "b h s d -> b s (h d)")
    if needs_weights:
        return (out, attn_weight, past_key_value)
    return (out, None, past_key_value)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def check_valid_inputs(*tensors, valid_dtypes=(torch.float16, torch.bfloat16)):
    """Validate that every tensor is a CUDA tensor of an allowed dtype.

    Args:
        *tensors: tensors to validate.
        valid_dtypes: acceptable dtypes; defaults to fp16/bf16 (a tuple rather
            than a mutable list default, per Python best practice).

    Raises:
        TypeError: if any tensor has a disallowed dtype or is not on CUDA.
    """
    for tensor in tensors:
        if tensor.dtype not in valid_dtypes:
            raise TypeError(f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.")
        if not tensor.is_cuda:
            raise TypeError(f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).")
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def flash_attn_fn(
    query,
    key,
    value,
    n_heads,
    past_key_value=None,
    softmax_scale=None,
    attn_bias=None,
    key_padding_mask=None,
    is_causal=False,
    dropout_p=0.0,
    training=False,
    needs_weights=False,
    multiquery=False,
):
    """Attention via the flash-attn (v1) unpadded kernel.

    Inputs are ``(batch, seq, n_heads * head_dim)``. ``attn_bias`` is not
    supported by this backend and raises. Returns
    ``(output, None, past_key_value)`` — attention weights are never returned.
    """
    # flash-attn is an optional dependency; fail with an actionable message.
    try:
        from flash_attn import bert_padding, flash_attn_interface
    except:
        raise RuntimeError("Please install flash-attn==1.0.3.post0")
    check_valid_inputs(query, key, value)
    if past_key_value is not None:
        # Cache layout here is (b, s, h*d): concat along the sequence dim.
        if len(past_key_value) != 0:
            key = torch.cat([past_key_value[0], key], dim=1)
            value = torch.cat([past_key_value[1], value], dim=1)
        past_key_value = (key, value)
    if attn_bias is not None:
        _s_q = max(0, attn_bias.size(2) - query.size(1))
        _s_k = max(0, attn_bias.size(3) - key.size(1))
        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
    # NOTE: the crop above is unreachable in effect — any non-None attn_bias
    # is rejected here because the flash kernel has no bias input.
    if attn_bias is not None:
        raise NotImplementedError(f"attn_bias not implemented for flash attn.")
    (batch_size, seqlen) = query.shape[:2]
    if key_padding_mask is None:
        key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
    # The query mask is the trailing slice of the key mask (queries are the
    # most recent positions when a cache is in use).
    query_padding_mask = key_padding_mask[:, -query.size(1) :]
    # Remove padding so the kernel runs over packed "nnz" tokens.
    (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask)
    query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads)
    (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask)
    key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads)
    (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
    value_unpad = rearrange(value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads)
    if multiquery:
        # Broadcast the single K/V head across all query heads (a view, not a copy).
        key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
        value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1))
    dropout_p = dropout_p if training else 0.0
    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
    output_unpad = flash_attn_interface.flash_attn_unpadded_func(
        query_unpad,
        key_unpad,
        value_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale=softmax_scale,
        causal=reset_is_causal,
        return_attn_probs=needs_weights,
    )
    # Re-pad the packed output back to (batch, seq, h*d).
    output = bert_padding.pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen)
    return (output, None, past_key_value)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def triton_flash_attn_fn(
    query,
    key,
    value,
    n_heads,
    past_key_value=None,
    softmax_scale=None,
    attn_bias=None,
    key_padding_mask=None,
    is_causal=False,
    dropout_p=0.0,
    training=False,
    needs_weights=False,
    multiquery=False,
):
    """Attention via the Triton flash-attention kernel.

    Supports an additive ``attn_bias`` (unlike the flash-attn v1 path) but
    not dropout or returning attention weights. A ``key_padding_mask`` is
    folded into ``attn_bias``. Returns ``(output, None, past_key_value)``.
    """
    # Prefer the vendored Triton kernel; fall back to flash_attn's copy on
    # pre-2.0 torch, otherwise fail with install guidance.
    try:
        from .flash_attn_triton import flash_attn_func
    except:
        _installed = False
        if version.parse(torch.__version__) < version.parse("2.0.0"):
            _installed = True
            try:
                from flash_attn.flash_attn_triton import flash_attn_func
            except:
                _installed = False
        if not _installed:
            raise RuntimeError(
                "Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed."
            )
    check_valid_inputs(query, key, value)
    if past_key_value is not None:
        # Cache layout is (b, s, h*d): concat along the sequence dim.
        if len(past_key_value) != 0:
            key = torch.cat([past_key_value[0], key], dim=1)
            value = torch.cat([past_key_value[1], value], dim=1)
        past_key_value = (key, value)
    if attn_bias is not None:
        # Crop the bias to the trailing (s_q, s_k) window.
        _s_q = max(0, attn_bias.size(2) - query.size(1))
        _s_k = max(0, attn_bias.size(3) - key.size(1))
        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
    if dropout_p:
        raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.")
    if needs_weights:
        raise NotImplementedError(f"attn_impl: triton cannot return attn weights.")
    if key_padding_mask is not None:
        warnings.warn(
            "Propagating key_padding_mask to the attention module "
            + "and applying it within the attention module can cause "
            + "unnecessary computation/memory usage. Consider integrating "
            + "into attn_bias once and passing that to each attention "
            + "module instead."
        )
        (b_size, s_k) = key_padding_mask.shape[:2]
        if attn_bias is None:
            attn_bias = query.new_zeros(b_size, 1, 1, s_k)
        # Fold the padding mask into the additive bias as -inf (dtype min).
        attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min)
    # Triton kernel expects (b, s, h, d) layout.
    query = rearrange(query, "b s (h d) -> b s h d", h=n_heads)
    key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
    value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
    if multiquery:
        # Broadcast the single K/V head across all query heads (a view, not a copy).
        key = key.expand(*key.shape[:2], n_heads, key.size(-1))
        value = value.expand(*value.shape[:2], n_heads, value.size(-1))
    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
    attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
    # Merge heads back into the hidden dimension.
    output = attn_output.view(*attn_output.shape[:2], -1)
    return (output, None, past_key_value)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class MultiheadAttention(nn.Module):
    """Multi-head self attention with a fused QKV projection.

    With the torch or triton attention implementation, callers may also pass
    an additive attention bias (e.g. alibi); the flash implementation takes
    no bias.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        attn_impl: str = "triton",
        clip_qkv: Optional[float] = None,
        qk_ln: bool = False,
        softmax_scale: Optional[float] = None,
        attn_pdrop: float = 0.0,
        low_precision_layernorm: bool = False,
        verbose: int = 0,
        device: Optional[str] = None,
    ):
        super().__init__()
        self.attn_impl = attn_impl
        self.clip_qkv = clip_qkv
        self.qk_ln = qk_ln
        self.d_model = d_model
        self.n_heads = n_heads
        # Default softmax scale is 1/sqrt(head_dim).
        self.softmax_scale = softmax_scale if softmax_scale is not None else 1 / math.sqrt(d_model / n_heads)
        self.attn_dropout_p = attn_pdrop
        # Single fused projection producing q, k and v in one matmul; `_fused`
        # records the split offsets for downstream init/fusion-aware logic.
        self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
        self.Wqkv._fused = (0, (d_model, 2 * d_model))
        if self.qk_ln:
            norm_cls = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
            self.q_ln = norm_cls(self.d_model, device=device)
            self.k_ln = norm_cls(self.d_model, device=device)
        if attn_impl == "flash":
            self.attn_fn = flash_attn_fn
        elif attn_impl == "triton":
            self.attn_fn = triton_flash_attn_fn
            if verbose:
                warnings.warn(
                    "While `attn_impl: triton` can be faster than `attn_impl: flash` "
                    "it uses more memory. When training larger models this can trigger "
                    "alloc retries which hurts performance. If encountered, we recommend "
                    "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
                )
        elif attn_impl == "torch":
            self.attn_fn = scaled_multihead_dot_product_attention
            if torch.cuda.is_available() and verbose:
                warnings.warn(
                    "Using `attn_impl: torch`. If your model does not use `alibi` or "
                    "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
                    "we recommend using `attn_impl: triton`."
                )
        else:
            raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
        self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
        self.out_proj._is_residual = True

    def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
        """Project x to q/k/v, optionally layernorm q and k, run the configured attention fn.

        Returns (output, attn_weights, past_key_value) matching the attention
        function's contract.
        """
        qkv = self.Wqkv(x)
        if self.clip_qkv:
            # In-place clamp of the fused projection before splitting.
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        query, key, value = qkv.chunk(3, dim=2)
        if self.qk_ln:
            # Normalize in the layernorm's precision, then cast back.
            orig_dtype = query.dtype
            query = self.q_ln(query).to(orig_dtype)
            key = self.k_ln(key).to(orig_dtype)
        context, attn_weights, past_key_value = self.attn_fn(
            query,
            key,
            value,
            self.n_heads,
            past_key_value=past_key_value,
            softmax_scale=self.softmax_scale,
            attn_bias=attn_bias,
            key_padding_mask=attention_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout_p,
            training=self.training,
            needs_weights=needs_weights,
        )
        return (self.out_proj(context), attn_weights, past_key_value)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
class MultiQueryAttention(nn.Module):
    """Multi-Query self attention: many query heads share one key/value head.

    With the torch or triton attention implementation, callers may also pass
    an additive attention bias; the flash implementation takes no bias.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        attn_impl: str = "triton",
        clip_qkv: Optional[float] = None,
        qk_ln: bool = False,
        softmax_scale: Optional[float] = None,
        attn_pdrop: float = 0.0,
        low_precision_layernorm: bool = False,
        verbose: int = 0,
        device: Optional[str] = None,
    ):
        super().__init__()
        self.attn_impl = attn_impl
        self.clip_qkv = clip_qkv
        self.qk_ln = qk_ln
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # Default softmax scale is 1/sqrt(head_dim).
        self.softmax_scale = softmax_scale if softmax_scale is not None else 1 / math.sqrt(self.head_dim)
        self.attn_dropout_p = attn_pdrop
        # Fused projection: full-width queries plus a single head of k and v.
        self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
        self.Wqkv._fused = (0, (d_model, d_model + self.head_dim))
        if self.qk_ln:
            norm_cls = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
            self.q_ln = norm_cls(d_model, device=device)
            # Keys are a single head wide.
            self.k_ln = norm_cls(self.head_dim, device=device)
        if attn_impl == "flash":
            self.attn_fn = flash_attn_fn
        elif attn_impl == "triton":
            self.attn_fn = triton_flash_attn_fn
            if verbose:
                warnings.warn(
                    "While `attn_impl: triton` can be faster than `attn_impl: flash` "
                    "it uses more memory. When training larger models this can trigger "
                    "alloc retries which hurts performance. If encountered, we recommend "
                    "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
                )
        elif attn_impl == "torch":
            self.attn_fn = scaled_multihead_dot_product_attention
            if torch.cuda.is_available() and verbose:
                warnings.warn(
                    "Using `attn_impl: torch`. If your model does not use `alibi` or "
                    "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
                    "we recommend using `attn_impl: triton`."
                )
        else:
            raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
        self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
        self.out_proj._is_residual = True

    def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
        """Project x, optionally layernorm q and k, run attention with multiquery=True.

        Returns (output, attn_weights, past_key_value) matching the attention
        function's contract.
        """
        qkv = self.Wqkv(x)
        if self.clip_qkv:
            # In-place clamp of the fused projection before splitting.
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        query, key, value = qkv.split([self.d_model, self.head_dim, self.head_dim], dim=2)
        if self.qk_ln:
            # Normalize in the layernorm's precision, then cast back.
            orig_dtype = query.dtype
            query = self.q_ln(query).to(orig_dtype)
            key = self.k_ln(key).to(orig_dtype)
        context, attn_weights, past_key_value = self.attn_fn(
            query,
            key,
            value,
            self.n_heads,
            past_key_value=past_key_value,
            softmax_scale=self.softmax_scale,
            attn_bias=attn_bias,
            key_padding_mask=attention_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout_p,
            training=self.training,
            needs_weights=needs_weights,
            multiquery=True,
        )
        return (self.out_proj(context), attn_weights, past_key_value)
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
    """Return the broadcastable shape of the additive attention bias, or None.

    flash takes no bias; torch/triton need a per-head (alibi) and/or a full
    seq x seq bias when the mask cannot be expressed as a single key row
    (prefix LM, per-sequence ids, or non-causal alibi).
    """
    if attn_impl == "flash":
        return None
    if attn_impl not in ("torch", "triton"):
        raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
    needs_full_mask = prefix_lm or use_sequence_id
    if alibi:
        if needs_full_mask or not causal:
            return (1, n_heads, seq_len, seq_len)
        # Causal alibi only needs one bias row per key position.
        return (1, n_heads, 1, seq_len)
    if needs_full_mask:
        return (1, 1, seq_len, seq_len)
    return None
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8):
    """Fill the preallocated attention bias for the given implementation.

    Returns None for flash (no bias supported); otherwise returns
    ``attn_bias``, with the alibi term added out-of-place when requested.
    """
    if attn_impl == "flash":
        return None
    if attn_impl not in ("torch", "triton"):
        raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
    if alibi:
        alibi_term = build_alibi_bias(
            n_heads,
            seq_len,
            full=not causal,
            alibi_bias_max=alibi_bias_max,
            device=attn_bias.device,
            dtype=attn_bias.dtype,
        )
        attn_bias = attn_bias.add(alibi_term)
    return attn_bias
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def gen_slopes(n_heads, alibi_bias_max=8, device=None):
    """Return the per-head alibi slopes with shape (1, n_heads, 1, 1).

    Slopes form a geometric sequence computed for the next power of two of
    n_heads; when n_heads is not a power of two, the sequence is interleaved
    (odd-index slopes first) and truncated, following the alibi recipe.
    """
    pow2_heads = 2 ** math.ceil(math.log2(n_heads))
    exponents = torch.arange(1, pow2_heads + 1, dtype=torch.float32, device=device) * (alibi_bias_max / pow2_heads)
    slopes = torch.pow(2, exponents).reciprocal()
    if pow2_heads != n_heads:
        # Interleave so dropped heads come from the tail of the schedule.
        slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
    return slopes.view(1, n_heads, 1, 1)
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None):
    """Build the (1, n_heads, {1|seq_len}, seq_len) alibi bias tensor.

    With full=False only a single row of relative key distances is produced
    (enough for causal attention); with full=True the bias covers every
    query/key pair using negative absolute distance.
    """
    positions = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device)
    bias = positions.view(1, 1, 1, seq_len)
    if full:
        # Negative |i - j| distance for every query/key pair.
        bias = (bias - positions.view(1, 1, seq_len, 1)).abs().mul(-1)
    slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
    return (bias * slopes).to(dtype=dtype)
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
ATTN_CLASS_REGISTRY = {"multihead_attention": MultiheadAttention, "multiquery_attention": MultiQueryAttention}
|
VILA/llava/model/language_model/mpt/blocks.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
"""GPT Blocks used for the GPT Model."""
|
| 18 |
+
from typing import Dict, Optional, Tuple
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
|
| 23 |
+
from .attention import ATTN_CLASS_REGISTRY
|
| 24 |
+
from .norm import NORM_CLASS_REGISTRY
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class MPTMLP(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear."""

    def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str] = None):
        super().__init__()
        hidden_dim = expansion_ratio * d_model
        self.up_proj = nn.Linear(d_model, hidden_dim, device=device)
        # Exact (non-tanh) GELU.
        self.act = nn.GELU(approximate="none")
        self.down_proj = nn.Linear(hidden_dim, d_model, device=device)
        # Marker consumed by the residual-aware init scheme.
        self.down_proj._is_residual = True

    def forward(self, x):
        hidden = self.act(self.up_proj(x))
        return self.down_proj(hidden)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class MPTBlock(nn.Module):
    """One transformer block: pre-norm attention and pre-norm MLP, each with
    a residual connection and optional residual dropout."""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expansion_ratio: int,
        attn_config: Dict = {
            "attn_type": "multihead_attention",
            "attn_pdrop": 0.0,
            "attn_impl": "triton",
            "qk_ln": False,
            "clip_qkv": None,
            "softmax_scale": None,
            "prefix_lm": False,
            "attn_uses_sequence_id": False,
            "alibi": False,
            "alibi_bias_max": 8,
        },
        resid_pdrop: float = 0.0,
        norm_type: str = "low_precision_layernorm",
        verbose: int = 0,
        device: Optional[str] = None,
        **kwargs
    ):
        # Extra kwargs are accepted for config-forwarding compatibility and
        # deliberately discarded.
        del kwargs
        super().__init__()
        norm_cls = NORM_CLASS_REGISTRY[norm_type.lower()]
        attn_cls = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
        self.norm_1 = norm_cls(d_model, device=device)
        self.attn = attn_cls(
            d_model=d_model,
            n_heads=n_heads,
            attn_impl=attn_config["attn_impl"],
            clip_qkv=attn_config["clip_qkv"],
            qk_ln=attn_config["qk_ln"],
            softmax_scale=attn_config["softmax_scale"],
            attn_pdrop=attn_config["attn_pdrop"],
            verbose=verbose,
            device=device,
        )
        self.norm_2 = norm_cls(d_model, device=device)
        self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_bias: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        is_causal: bool = True,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
        """Apply pre-norm attention then pre-norm MLP, each added residually.

        Returns (hidden_states, attn_weights, past_key_value).
        """
        attn_out, attn_weights, past_key_value = self.attn(
            self.norm_1(x),
            past_key_value=past_key_value,
            attn_bias=attn_bias,
            attention_mask=attention_mask,
            is_causal=is_causal,
        )
        x = x + self.resid_attn_dropout(attn_out)
        x = x + self.resid_ffn_dropout(self.ffn(self.norm_2(x)))
        return (x, attn_weights, past_key_value)
|
VILA/llava/model/language_model/mpt/configuration_mpt.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
"""A HuggingFace-style model configuration."""
|
| 18 |
+
from typing import Dict, Optional, Union
|
| 19 |
+
|
| 20 |
+
from transformers import PretrainedConfig
|
| 21 |
+
|
| 22 |
+
attn_config_defaults: Dict = {
|
| 23 |
+
"attn_type": "multihead_attention",
|
| 24 |
+
"attn_pdrop": 0.0,
|
| 25 |
+
"attn_impl": "triton",
|
| 26 |
+
"qk_ln": False,
|
| 27 |
+
"clip_qkv": None,
|
| 28 |
+
"softmax_scale": None,
|
| 29 |
+
"prefix_lm": False,
|
| 30 |
+
"attn_uses_sequence_id": False,
|
| 31 |
+
"alibi": False,
|
| 32 |
+
"alibi_bias_max": 8,
|
| 33 |
+
}
|
| 34 |
+
init_config_defaults: Dict = {
|
| 35 |
+
"name": "kaiming_normal_",
|
| 36 |
+
"fan_mode": "fan_in",
|
| 37 |
+
"init_nonlinearity": "relu",
|
| 38 |
+
"init_div_is_residual": True,
|
| 39 |
+
"emb_init_std": None,
|
| 40 |
+
"emb_init_uniform_lim": None,
|
| 41 |
+
"init_std": None,
|
| 42 |
+
"init_gain": 0.0,
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class MPTConfig(PretrainedConfig):
    # HF model-type key used for auto-class registration and serialization.
    model_type = "mpt"

    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 16,
        n_layers: int = 24,
        expansion_ratio: int = 4,
        max_seq_len: int = 2048,
        vocab_size: int = 50368,
        resid_pdrop: float = 0.0,
        emb_pdrop: float = 0.0,
        learned_pos_emb: bool = True,
        attn_config: Dict = attn_config_defaults,
        init_device: str = "cpu",
        logit_scale: Optional[Union[float, str]] = None,
        no_bias: bool = False,
        verbose: int = 0,
        embedding_fraction: float = 1.0,
        norm_type: str = "low_precision_layernorm",
        use_cache: bool = False,
        init_config: Dict = init_config_defaults,
        **kwargs,
    ):
        """The MPT configuration class.

        Args:
            d_model (int): The size of the embedding dimension of the model.
            n_heads (int): The number of attention heads.
            n_layers (int): The number of layers in the model.
            expansion_ratio (int): The ratio of the up/down scale in the MLP.
            max_seq_len (int): The maximum sequence length of the model.
            vocab_size (int): The size of the vocabulary.
            resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
            emb_pdrop (float): The dropout probability for the embedding layer.
            learned_pos_emb (bool): Whether to use learned positional embeddings
            attn_config (Dict): A dictionary used to configure the model's attention module:
                attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention
                attn_pdrop (float): The dropout probability for the attention layers.
                attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
                qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
                clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
                    this value.
                softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
                    use the default scale of ``1/sqrt(d_keys)``.
                prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
                    extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
                    can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
                attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
                    When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
                    which sub-sequence each token belongs to.
                    Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
                alibi (bool): Whether to use the alibi bias instead of position embeddings.
                alibi_bias_max (int): The maximum value of the alibi bias.
            init_device (str): The device to use for parameter initialization.
            logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
            no_bias (bool): Whether to use bias in all layers.
            verbose (int): The verbosity level. 0 is silent.
            embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
            norm_type (str): choose type of norm to use
            multiquery_attention (bool): Whether to use multiquery attention implementation.
            use_cache (bool): Whether or not the model should return the last key/values attentions
            init_config (Dict): A dictionary used to configure the model initialization:
                init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
                    'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
                    'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
                init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
                emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
                emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
                    used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
                init_std (float): The standard deviation of the normal distribution used to initialize the model,
                    if using the baseline_ parameter initialization scheme.
                init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
                fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
                init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
                ---
                See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
        """
        # Attributes are assigned before super().__init__ so that
        # PretrainedConfig serialization (to_dict/save_pretrained) sees them.
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.expansion_ratio = expansion_ratio
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.learned_pos_emb = learned_pos_emb
        self.attn_config = attn_config
        self.init_device = init_device
        self.logit_scale = logit_scale
        self.no_bias = no_bias
        self.verbose = verbose
        self.embedding_fraction = embedding_fraction
        self.norm_type = norm_type
        self.use_cache = use_cache
        self.init_config = init_config
        # Strip trainer-specific keys that may appear in serialized configs
        # but are not accepted by PretrainedConfig.__init__.
        if "name" in kwargs:
            del kwargs["name"]
        if "loss_fn" in kwargs:
            del kwargs["loss_fn"]
        super().__init__(**kwargs)
        self._validate_config()

    def _set_config_defaults(self, config, config_defaults):
        # Fill any missing keys in-place and return the same dict.
        # NOTE(review): when the caller relied on the module-level default
        # (attn_config_defaults / init_config_defaults), `config` IS that
        # shared dict, so later external mutation would leak across
        # instances — confirm this sharing is intended.
        for (k, v) in config_defaults.items():
            if k not in config:
                config[k] = v
        return config

    def _validate_config(self):
        # Backfill missing sub-config keys, then check cross-field invariants.
        self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
        self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
        if self.d_model % self.n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        if any(prob < 0 or prob > 1 for prob in [self.attn_config["attn_pdrop"], self.resid_pdrop, self.emb_pdrop]):
            raise ValueError(
                "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1"
            )
        if self.attn_config["attn_impl"] not in ["torch", "flash", "triton"]:
            raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
        # prefix_lm / alibi / sequence-id masking all need an additive bias,
        # which the flash implementation does not support.
        if self.attn_config["prefix_lm"] and self.attn_config["attn_impl"] not in ["torch", "triton"]:
            raise NotImplementedError("prefix_lm only implemented with torch and triton attention.")
        if self.attn_config["alibi"] and self.attn_config["attn_impl"] not in ["torch", "triton"]:
            raise NotImplementedError("alibi only implemented with torch and triton attention.")
        if self.attn_config["attn_uses_sequence_id"] and self.attn_config["attn_impl"] not in ["torch", "triton"]:
            raise NotImplementedError("attn_uses_sequence_id only implemented with torch and triton attention.")
        if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
            raise ValueError("model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!")
        if isinstance(self.logit_scale, str) and self.logit_scale != "inv_sqrt_d_model":
            raise ValueError(
                f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
            )
        if self.init_config.get("name", None) is None:
            raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
        # The model needs at least one source of positional information.
        if not self.learned_pos_emb and (not self.attn_config["alibi"]):
            raise ValueError(
                f"Positional information must be provided to the model using either learned_pos_emb or alibi."
            )
|
VILA/llava/model/language_model/mpt/custom_embedding.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
from torch import Tensor
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class SharedEmbedding(nn.Embedding):
    """Embedding whose weight doubles as the output (unembedding) projection.

    With ``unembed=False`` this behaves exactly like ``nn.Embedding``; with
    ``unembed=True`` the input is projected onto the vocabulary via the
    transpose of the embedding matrix (weight tying).
    """

    def forward(self, input: Tensor, unembed: bool = False) -> Tensor:
        if not unembed:
            return super().forward(input)
        # Bias-free linear map using the tied embedding weight.
        return F.linear(input, self.weight)
|
VILA/llava/model/language_model/mpt/flash_attn_triton.py
ADDED
|
@@ -0,0 +1,947 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
"""
|
| 18 |
+
Copied from https://github.com/HazyResearch/flash-attention/blob/eff9fe6b8076df59d64d7a3f464696738a3c7c24/flash_attn/flash_attn_triton.py
|
| 19 |
+
update imports to use 'triton_pre_mlir'
|
| 20 |
+
|
| 21 |
+
*Experimental* implementation of FlashAttention in Triton.
|
| 22 |
+
Tested with triton==2.0.0.dev20221202.
|
| 23 |
+
Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions
|
| 24 |
+
other than 64:
|
| 25 |
+
https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
|
| 26 |
+
We'll update this implementation with the new Triton backend once this is fixed.
|
| 27 |
+
|
| 28 |
+
We use the FlashAttention implementation from Phil Tillet a starting point.
|
| 29 |
+
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
|
| 30 |
+
|
| 31 |
+
Changes:
|
| 32 |
+
- Implement both causal and non-causal attention.
|
| 33 |
+
- Implement both self-attention and cross-attention.
|
| 34 |
+
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
|
| 35 |
+
- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
|
| 36 |
+
- Support attention bias.
|
| 37 |
+
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
|
| 38 |
+
- Make the backward for d=128 much faster by reducing register spilling.
|
| 39 |
+
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
|
| 40 |
+
small batch size * nheads.
|
| 41 |
+
|
| 42 |
+
Caution:
|
| 43 |
+
- This is an *experimental* implementation. The forward pass should be quite robust but
|
| 44 |
+
I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
|
| 45 |
+
- This implementation has only been tested on A100.
|
| 46 |
+
- If you plan to use headdim other than 64 and 128, you should test for race conditions
|
| 47 |
+
(due to the Triton compiler), as done in tests/test_flash_attn.py
|
| 48 |
+
"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
|
| 49 |
+
for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
|
| 50 |
+
that there are none left for other head dimensions.
|
| 51 |
+
|
| 52 |
+
Differences between this Triton version and the CUDA version:
|
| 53 |
+
- Triton version doesn't support dropout.
|
| 54 |
+
- Triton forward is generally faster than CUDA forward, while Triton backward is
|
| 55 |
+
generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
|
| 56 |
+
than CUDA forward + backward.
|
| 57 |
+
- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
|
| 58 |
+
- Triton version supports attention bias, while CUDA version doesn't.
|
| 59 |
+
"""
|
| 60 |
+
import math
|
| 61 |
+
|
| 62 |
+
import torch
|
| 63 |
+
import triton_pre_mlir as triton
|
| 64 |
+
import triton_pre_mlir.language as tl
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@triton.heuristics(
    {
        # EVEN_M / EVEN_N: the query / key sequence length divides the tile size,
        # so boundary masks can be skipped on loads and stores.
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        # EVEN_HEADDIM: the head dimension exactly fills the padded BLOCK_HEADDIM.
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    Bias,
    Out,
    Lse,
    TMP,
    softmax_scale,
    stride_qb,
    stride_qh,
    stride_qm,
    stride_kb,
    stride_kh,
    stride_kn,
    stride_vb,
    stride_vh,
    stride_vn,
    stride_bb,
    stride_bh,
    stride_bm,
    stride_ob,
    stride_oh,
    stride_om,
    nheads,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    headdim,
    CACHE_KEY_SEQLEN_Q,
    CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Forward attention kernel.

    Each program instance handles one BLOCK_M tile of query rows for one
    (batch, head) pair (grid axis 1 is batch * nheads) and streams over K/V in
    BLOCK_N tiles, maintaining an online softmax: running max ``m_i``, running
    log-sum-exp ``lse_i``, and a rescaled output accumulator ``acc_o``.
    Writes the normalized output tile to ``Out`` and the per-row log-sum-exp to
    ``Lse`` (padded to ``seqlen_q_rounded``).  ``TMP`` is per-row scratch;
    values are stored and immediately reloaded through it (presumably a Triton
    compiler workaround — TODO confirm against upstream).
    """
    # Locate this program's query tile and (batch, head) pair.
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Base pointers for this (batch, head); the head dim is assumed contiguous
    # (stride 1), which the host-side launcher's stride asserts guarantee.
    q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
    if BIAS_TYPE == "vector":
        # One bias value per key position, broadcast over query rows.
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
    elif BIAS_TYPE == "matrix":
        # Full (seqlen_q, seqlen_k) bias.
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
    # Online-softmax state: log-sum-exp and running max start at -inf.
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # Load the query tile once; mask only the dimensions that may overrun.
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    elif EVEN_HEADDIM:
        q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
    else:
        q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
    # Causal attention only needs key blocks up to the end of this query tile.
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # Load the key tile (masking as needed), then compute raw scores.
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        elif EVEN_HEADDIM:
            k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
        else:
            k = tl.load(
                k_ptrs + start_n * stride_kn,
                mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                other=0.0,
            )
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
        # Mask out-of-range keys and (optionally) future positions with -inf
        # so they vanish under exp().
        if not EVEN_N:
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
        if IS_CAUSAL:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
        if BIAS_TYPE != "none":
            if BIAS_TYPE == "vector":
                if EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs + start_n, mask=start_n + offs_n < seqlen_k, other=0.0).to(tl.float32)
                bias = bias[None, :]
            elif BIAS_TYPE == "matrix":
                if EVEN_M & EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(
                        b_ptrs + start_n,
                        mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k),
                        other=0.0,
                    ).to(tl.float32)
            # With bias, scale the scores first so lse_i stays in the biased
            # (already-scaled) domain; without bias, fold the scale into exp().
            qk = qk * softmax_scale + bias
            m_ij = tl.maximum(tl.max(qk, 1), lse_i)
            p = tl.exp(qk - m_ij[:, None])
        else:
            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
            p = tl.exp(qk * softmax_scale - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # Rescale the running output by exp(old_max - new_max); the
        # store-then-load through TMP is deliberate (see docstring).
        acc_o_scale = tl.exp(m_i - m_ij)
        tl.store(t_ptrs, acc_o_scale)
        acc_o_scale = tl.load(t_ptrs)
        acc_o = acc_o * acc_o_scale[:, None]
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
        elif EVEN_HEADDIM:
            v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
        else:
            v = tl.load(
                v_ptrs + start_n * stride_vn,
                mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                other=0.0,
            )
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)
        # Fold this block into the running max / log-sum-exp.
        m_i = m_ij
        l_i_new = tl.exp(lse_i - m_ij) + l_ij
        lse_i = m_ij + tl.log(l_i_new)
    # Final normalization: acc_o currently holds sum(exp(s - m_i) * v);
    # multiplying by exp(m_i - lse_i) divides by the softmax denominator.
    o_scale = tl.exp(m_i - lse_i)
    tl.store(t_ptrs, o_scale)
    o_scale = tl.load(t_ptrs)
    acc_o = acc_o * o_scale[:, None]
    # Recompute row offsets (same values as above) before the final stores.
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
    tl.store(lse_ptrs, lse_i)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    elif EVEN_HEADDIM:
        tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
    else:
        tl.store(out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
@triton.jit
def _bwd_preprocess_do_o_dot(
    Out,
    DO,
    Delta,
    stride_ob,
    stride_oh,
    stride_om,
    stride_dob,
    stride_doh,
    stride_dom,
    nheads,
    seqlen_q,
    seqlen_q_rounded,
    headdim,
    BLOCK_M: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
):
    """Backward preprocessing: Delta[b, h, m] = sum_d(Out * dOut) per query row.

    One program instance handles one BLOCK_M tile of query rows for one
    (batch, head) pair.  ``Delta`` is laid out (batch * nheads,
    seqlen_q_rounded) flat, matching the LSE buffer, and is consumed later by
    the main backward kernel as the ``D`` term.
    """
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Load the output and output-gradient tiles in fp32; out-of-range rows and
    # head-dim lanes contribute 0 to the dot product.
    o = tl.load(
        Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
        mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
        other=0.0,
    ).to(tl.float32)
    do = tl.load(
        DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :],
        mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
        other=0.0,
    ).to(tl.float32)
    # Row-wise dot product over the head dimension.
    delta = tl.sum(o * do, axis=1)
    tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
@triton.jit
def _bwd_store_dk_dv(
    dk_ptrs,
    dv_ptrs,
    dk,
    dv,
    offs_n,
    offs_d,
    seqlen_k,
    headdim,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
):
    """Store the accumulated dK/dV tiles, masking only the dimensions that can
    overrun (key rows beyond seqlen_k, head-dim lanes beyond headdim).

    Branches mirror the load-side masking in the backward kernel so each
    EVEN_* specialization pays only for the masks it needs.  Note EVEN_M
    participates in the first condition even though this helper stores along
    the key axis — kept as-is to match the caller's specialization.
    """
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            # Fully in-bounds: unmasked stores.
            tl.store(dv_ptrs, dv)
            tl.store(dk_ptrs, dk)
        else:
            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
    elif EVEN_HEADDIM:
        tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
        tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
    else:
        tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
        tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
@triton.jit
def _bwd_kernel_one_col_block(
    start_n,
    Q,
    K,
    V,
    Bias,
    DO,
    DQ,
    DK,
    DV,
    LSE,
    D,
    softmax_scale,
    stride_qm,
    stride_kn,
    stride_vn,
    stride_bm,
    stride_dom,
    stride_dqm,
    stride_dkn,
    stride_dvn,
    seqlen_q,
    seqlen_k,
    headdim,
    ATOMIC_ADD: tl.constexpr,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Backward pass for one BLOCK_N column of keys/values.

    Recomputes the attention probabilities from Q/K and the saved LSE, then
    accumulates dK and dV locally over all relevant query blocks and
    accumulates dQ either in-place (ATOMIC_ADD=False, sequential over columns)
    or via ``tl.atomic_add`` (ATOMIC_ADD=True, columns processed in parallel).
    All pointers are already offset to the current (batch, head).  The
    scattered ``tl.debug_barrier()`` calls work around compiler race
    conditions mentioned in the module docstring — do not remove or reorder.
    """
    # For causal attention, query rows before this key column cannot attend to
    # it, so start at the first query block that can.
    begin_m = 0 if not IS_CAUSAL else start_n * BLOCK_N // BLOCK_M * BLOCK_M
    offs_qm = begin_m + tl.arange(0, BLOCK_M)
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
    do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
    dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
    if BIAS_TYPE == "vector":
        b_ptrs = Bias + offs_n
    elif BIAS_TYPE == "matrix":
        b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
    # dK/dV accumulate in fp32 across the whole query loop.
    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    if begin_m >= seqlen_q:
        # No query row attends to this key column (causal cutoff): write the
        # zero gradients and exit early.
        dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
        dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
        _bwd_store_dk_dv(
            dk_ptrs,
            dv_ptrs,
            dk,
            dv,
            offs_n,
            offs_d,
            seqlen_k,
            headdim,
            EVEN_M=EVEN_M,
            EVEN_N=EVEN_N,
            EVEN_HEADDIM=EVEN_HEADDIM,
        )
        return
    # Load this column's K and V tiles once; they are reused for every query block.
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs)
            v = tl.load(v_ptrs)
        else:
            k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
            v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    elif EVEN_HEADDIM:
        k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
        v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
    else:
        k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
        v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
    num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
    for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        offs_m_curr = start_m + offs_m
        # Load the current query tile.
        if EVEN_M & EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        elif EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
        # Recompute raw attention scores for this (query tile, key column).
        qk = tl.dot(q, k, trans_b=True)
        if not EVEN_N:
            qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
        if IS_CAUSAL:
            qk = tl.where(offs_m_curr[:, None] >= offs_n[None, :], qk, float("-inf"))
        if BIAS_TYPE != "none":
            tl.debug_barrier()  # race-condition workaround (see docstring)
            if BIAS_TYPE == "vector":
                if EVEN_N:
                    bias = tl.load(b_ptrs).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
                bias = bias[None, :]
            elif BIAS_TYPE == "matrix":
                if EVEN_M & EVEN_N:
                    bias = tl.load(b_ptrs).to(tl.float32)
                else:
                    bias = tl.load(
                        b_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), other=0.0
                    ).to(tl.float32)
            qk = qk * softmax_scale + bias
        if not EVEN_M & EVEN_HEADDIM:
            tl.debug_barrier()
        # Reconstruct softmax probabilities from the saved log-sum-exp; scores
        # are pre-scaled when a bias was applied (mirrors the forward kernel).
        lse_i = tl.load(LSE + offs_m_curr)
        if BIAS_TYPE == "none":
            p = tl.exp(qk * softmax_scale - lse_i[:, None])
        else:
            p = tl.exp(qk - lse_i[:, None])
        if EVEN_M & EVEN_HEADDIM:
            do = tl.load(do_ptrs)
        else:
            do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
        # dV += P^T @ dO
        dv += tl.dot(p.to(do.dtype), do, trans_a=True)
        if not EVEN_M & EVEN_HEADDIM:
            tl.debug_barrier()
        # dP = dO @ V^T ; dS = P * (dP - Delta) * scale
        dp = tl.dot(do, v, trans_b=True)
        if not EVEN_HEADDIM:
            tl.debug_barrier()
        Di = tl.load(D + offs_m_curr)
        ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
        # dK += dS^T @ Q
        dk += tl.dot(ds, q, trans_a=True)
        if not EVEN_M & EVEN_HEADDIM:
            tl.debug_barrier()
        # dQ += dS @ K — either read-modify-write (sequential over key columns)
        # or atomic adds (key columns run in parallel).
        if not ATOMIC_ADD:
            if EVEN_M & EVEN_HEADDIM:
                dq = tl.load(dq_ptrs, eviction_policy="evict_last")
                dq += tl.dot(ds, k)
                tl.store(dq_ptrs, dq, eviction_policy="evict_last")
            elif EVEN_HEADDIM:
                dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0, eviction_policy="evict_last")
                dq += tl.dot(ds, k)
                tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q, eviction_policy="evict_last")
            else:
                dq = tl.load(
                    dq_ptrs,
                    mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                    other=0.0,
                    eviction_policy="evict_last",
                )
                dq += tl.dot(ds, k)
                tl.store(
                    dq_ptrs,
                    dq,
                    mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                    eviction_policy="evict_last",
                )
        else:
            dq = tl.dot(ds, k)
            if EVEN_M & EVEN_HEADDIM:
                tl.atomic_add(dq_ptrs, dq)
            elif EVEN_HEADDIM:
                tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
            else:
                tl.atomic_add(dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
        # Advance all query-axis pointers to the next BLOCK_M tile.
        dq_ptrs += BLOCK_M * stride_dqm
        q_ptrs += BLOCK_M * stride_qm
        do_ptrs += BLOCK_M * stride_dom
        if BIAS_TYPE == "matrix":
            b_ptrs += BLOCK_M * stride_bm
    # Write the accumulated dK/dV for this key column.
    dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
    dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
    _bwd_store_dk_dv(
        dk_ptrs,
        dv_ptrs,
        dk,
        dv,
        offs_n,
        offs_d,
        seqlen_k,
        headdim,
        EVEN_M=EVEN_M,
        EVEN_N=EVEN_N,
        EVEN_HEADDIM=EVEN_HEADDIM,
    )
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
def init_to_zero(name):
    """Build an autotune pre-hook that zeroes the kernel argument called *name*.

    Used on the dQ buffer so each autotune trial starts from zeroed
    accumulation memory.
    """

    def _pre_hook(nargs):
        return nargs[name].zero_()

    return _pre_hook
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
@triton.autotune(
    configs=[
        # Two candidate configs differing only in whether key columns are
        # processed in parallel; DQ must be zeroed before each trial because
        # the kernel accumulates into it.
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False},
            num_warps=8,
            num_stages=1,
            pre_hook=init_to_zero("DQ"),
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True},
            num_warps=8,
            num_stages=1,
            pre_hook=init_to_zero("DQ"),
        ),
    ],
    key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "IS_CAUSAL", "BLOCK_HEADDIM"],
)
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _bwd_kernel(
    Q,
    K,
    V,
    Bias,
    DO,
    DQ,
    DK,
    DV,
    LSE,
    D,
    softmax_scale,
    stride_qb,
    stride_qh,
    stride_qm,
    stride_kb,
    stride_kh,
    stride_kn,
    stride_vb,
    stride_vh,
    stride_vn,
    stride_bb,
    stride_bh,
    stride_bm,
    stride_dob,
    stride_doh,
    stride_dom,
    stride_dqb,
    stride_dqh,
    stride_dqm,
    stride_dkb,
    stride_dkh,
    stride_dkn,
    stride_dvb,
    stride_dvh,
    stride_dvn,
    nheads,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    headdim,
    CACHE_KEY_SEQLEN_Q,
    CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    SEQUENCE_PARALLEL: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Top-level backward kernel: offsets all pointers to this program's
    (batch, head), then delegates to ``_bwd_kernel_one_col_block``.

    SEQUENCE_PARALLEL=False: one program per (batch, head) loops over all key
    columns sequentially and can read-modify-write dQ directly.
    SEQUENCE_PARALLEL=True: grid axis 0 spreads key columns across programs,
    so dQ contributions must use atomic adds.
    """
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # Advance every tensor pointer to this (batch, head) slice.
    Q += off_b * stride_qb + off_h * stride_qh
    K += off_b * stride_kb + off_h * stride_kh
    V += off_b * stride_vb + off_h * stride_vh
    DO += off_b * stride_dob + off_h * stride_doh
    DQ += off_b * stride_dqb + off_h * stride_dqh
    DK += off_b * stride_dkb + off_h * stride_dkh
    DV += off_b * stride_dvb + off_h * stride_dvh
    if BIAS_TYPE != "none":
        Bias += off_b * stride_bb + off_h * stride_bh
    # D (delta) and LSE are flat (batch * nheads, seqlen_q_rounded) buffers.
    D += off_hb * seqlen_q_rounded
    LSE += off_hb * seqlen_q_rounded
    if not SEQUENCE_PARALLEL:
        num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
        for start_n in range(0, num_block_n):
            _bwd_kernel_one_col_block(
                start_n,
                Q,
                K,
                V,
                Bias,
                DO,
                DQ,
                DK,
                DV,
                LSE,
                D,
                softmax_scale,
                stride_qm,
                stride_kn,
                stride_vn,
                stride_bm,
                stride_dom,
                stride_dqm,
                stride_dkn,
                stride_dvn,
                seqlen_q,
                seqlen_k,
                headdim,
                ATOMIC_ADD=False,
                BIAS_TYPE=BIAS_TYPE,
                IS_CAUSAL=IS_CAUSAL,
                BLOCK_HEADDIM=BLOCK_HEADDIM,
                EVEN_M=EVEN_M,
                EVEN_N=EVEN_N,
                EVEN_HEADDIM=EVEN_HEADDIM,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )
    else:
        start_n = tl.program_id(0)
        _bwd_kernel_one_col_block(
            start_n,
            Q,
            K,
            V,
            Bias,
            DO,
            DQ,
            DK,
            DV,
            LSE,
            D,
            softmax_scale,
            stride_qm,
            stride_kn,
            stride_vn,
            stride_bm,
            stride_dom,
            stride_dqm,
            stride_dkn,
            stride_dvn,
            seqlen_q,
            seqlen_k,
            headdim,
            ATOMIC_ADD=True,
            BIAS_TYPE=BIAS_TYPE,
            IS_CAUSAL=IS_CAUSAL,
            BLOCK_HEADDIM=BLOCK_HEADDIM,
            EVEN_M=EVEN_M,
            EVEN_N=EVEN_N,
            EVEN_HEADDIM=EVEN_HEADDIM,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
        )
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
    """Validate inputs and launch the Triton forward kernel.

    q, k, v: CUDA tensors of shape (batch, seqlen, nheads, headdim), all the
    same fp16/bf16 dtype, headdim <= 128.  ``bias`` (optional) must be 4-D and
    broadcastable to (batch, nheads, seqlen_q, seqlen_k): last two dims
    (1, seqlen_k) select the "vector" kernel path, (seqlen_q, seqlen_k) the
    "matrix" path.  Returns (out, lse, softmax_scale), where ``lse`` holds the
    per-row log-sum-exp padded to a multiple of 128 along the query axis.
    """
    batch, len_q, heads, hdim = q.shape
    len_k = k.shape[1]
    assert k.shape == (batch, len_k, heads, hdim)
    assert v.shape == (batch, len_k, heads, hdim)
    assert hdim <= 128, "FlashAttention only support head dimensions up to 128"
    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
    assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
    assert q.is_cuda and k.is_cuda and v.is_cuda
    # NOTE: a falsy softmax_scale (None or 0) falls back to 1/sqrt(headdim).
    softmax_scale = softmax_scale or 1.0 / math.sqrt(hdim)

    bias_kind = "none"
    if bias is not None:
        assert bias.dtype in [q.dtype, torch.float]
        assert bias.is_cuda
        assert bias.dim() == 4
        # The kernel indexes the last bias dim with stride 1.
        if bias.stride(-1) != 1:
            bias = bias.contiguous()
        if bias.shape[2:] == (1, len_k):
            bias_kind = "vector"
        elif bias.shape[2:] == (len_q, len_k):
            bias_kind = "matrix"
        else:
            raise RuntimeError("Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)")
        bias = bias.expand(batch, heads, len_q, len_k)
        bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2))
    else:
        bias_strides = (0, 0, 0)

    # LSE/TMP are padded to a multiple of 128 query rows, matching the kernel's
    # seqlen_q_rounded layout.
    len_q_rounded = math.ceil(len_q / 128) * 128
    lse = torch.empty((batch, heads, len_q_rounded), device=q.device, dtype=torch.float32)
    tmp = torch.empty((batch, heads, len_q_rounded), device=q.device, dtype=torch.float32)
    out = torch.empty_like(q)

    BLOCK_HEADDIM = max(triton.next_power_of_2(hdim), 16)
    BLOCK = 128
    num_warps = 4 if hdim <= 64 else 8

    def grid(meta):
        # One program per query tile per (batch, head).
        return (triton.cdiv(len_q, meta["BLOCK_M"]), batch * heads)

    _fwd_kernel[grid](
        q,
        k,
        v,
        bias,
        out,
        lse,
        tmp,
        softmax_scale,
        q.stride(0),
        q.stride(2),
        q.stride(1),
        k.stride(0),
        k.stride(2),
        k.stride(1),
        v.stride(0),
        v.stride(2),
        v.stride(1),
        *bias_strides,
        out.stride(0),
        out.stride(2),
        out.stride(1),
        heads,
        len_q,
        len_k,
        len_q_rounded,
        hdim,
        # Autotune cache keys: bucket sequence lengths by 32 so nearby lengths
        # reuse the same compiled kernel.
        len_q // 32,
        len_k // 32,
        bias_kind,
        causal,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    return (out, lse, softmax_scale)
|
| 727 |
+
|
| 728 |
+
|
| 729 |
+
def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):
    """Launch the Triton backward kernels for flash attention.

    Fills the preallocated gradient tensors `dq`, `dk`, `dv` in place given the
    upstream gradient `do`, the forward output `o` and the saved per-row
    log-sum-exp `lse`.

    do: (batch, seqlen_q, nheads, headdim) gradient w.r.t. the forward output.
    q: (batch, seqlen_q, nheads, headdim); k, v: (batch, seqlen_k, nheads, headdim).
    o: forward output, same shape as q.
    lse: (batch, nheads, seqlen_q_rounded) log-sum-exp saved by the forward.
    dq, dk, dv: preallocated outputs, written in place (no return value).
    bias: optional 4-D attention bias whose last two dims are either
        (1, seqlen_k) ("vector", e.g. causal ALiBi) or (seqlen_q, seqlen_k)
        ("matrix").
    causal: apply causal masking inside the kernel.
    softmax_scale: defaults to 1/sqrt(headdim).
    """
    # The kernels require the innermost (headdim) dimension to be contiguous.
    if do.stride(-1) != 1:
        do = do.contiguous()
    (batch, seqlen_q, nheads, d) = q.shape
    (_, seqlen_k, _, _) = k.shape
    # Kernel only supports head dimensions up to 128.
    assert d <= 128
    # lse was padded to a multiple of 128 rows by the forward pass.
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    assert lse.shape == (batch, nheads, seqlen_q_rounded)
    assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
    assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
    # dq is accumulated in float32 for numerical stability (possibly across
    # sequence-parallel blocks), then copied into the caller's dq at the end.
    dq_accum = torch.empty_like(q, dtype=torch.float32)
    delta = torch.empty_like(lse)
    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    # Precompute delta = rowsum(do * o), consumed by the main backward kernel.
    # NOTE: stride(2) (heads) is passed before stride(1) (sequence) because
    # the kernels index as (batch, heads, seq, headdim).
    _bwd_preprocess_do_o_dot[grid](
        o,
        do,
        delta,
        o.stride(0),
        o.stride(2),
        o.stride(1),
        do.stride(0),
        do.stride(2),
        do.stride(1),
        nheads,
        seqlen_q,
        seqlen_q_rounded,
        d,
        BLOCK_M=128,
        BLOCK_HEADDIM=BLOCK_HEADDIM,
    )
    has_bias = bias is not None
    bias_type = "none"
    if has_bias:
        assert bias.dtype in [q.dtype, torch.float]
        assert bias.is_cuda
        assert bias.dim() == 4
        assert bias.stride(-1) == 1
        # Classify the bias layout so the kernel can specialize its loads.
        if bias.shape[2:] == (1, seqlen_k):
            bias_type = "vector"
        elif bias.shape[2:] == (seqlen_q, seqlen_k):
            bias_type = "matrix"
        else:
            raise RuntimeError("Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)")
        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
    # With SEQUENCE_PARALLEL the kernel also parallelizes over k/v column
    # blocks; otherwise a single program per (batch, head) walks all columns.
    grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, batch * nheads)
    _bwd_kernel[grid](
        q,
        k,
        v,
        bias,
        do,
        dq_accum,
        dk,
        dv,
        lse,
        delta,
        softmax_scale,
        q.stride(0),
        q.stride(2),
        q.stride(1),
        k.stride(0),
        k.stride(2),
        k.stride(1),
        v.stride(0),
        v.stride(2),
        v.stride(1),
        *bias_strides,
        do.stride(0),
        do.stride(2),
        do.stride(1),
        dq_accum.stride(0),
        dq_accum.stride(2),
        dq_accum.stride(1),
        dk.stride(0),
        dk.stride(2),
        dk.stride(1),
        dv.stride(0),
        dv.stride(2),
        dv.stride(1),
        nheads,
        seqlen_q,
        seqlen_k,
        seqlen_q_rounded,
        d,
        # seqlen // 32 args: presumably used as autotune/cache keys so kernels
        # are re-specialized only when lengths change by >=32 — TODO confirm
        # against the kernel signature.
        seqlen_q // 32,
        seqlen_k // 32,
        bias_type,
        causal,
        BLOCK_HEADDIM
    )
    # Downcast/copy the float32 accumulator into the caller-provided dq.
    dq.copy_(dq_accum)
|
| 823 |
+
|
| 824 |
+
|
| 825 |
+
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper around the Triton flash-attention kernels for a
    packed qkv tensor (q, k and v stacked along dim 2)."""

    @staticmethod
    def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
        """
        qkv: (batch, seqlen, 3, nheads, headdim)
        bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen).
            For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
            ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
        """
        # Kernels require a contiguous last (headdim) dimension.
        if qkv.stride(-1) != 1:
            qkv = qkv.contiguous()
        out, logsumexp, ctx.softmax_scale = _flash_attn_forward(
            qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal, softmax_scale=softmax_scale
        )
        ctx.save_for_backward(qkv, out, logsumexp, bias)
        ctx.causal = causal
        return out

    @staticmethod
    def backward(ctx, do):
        qkv, out, logsumexp, bias = ctx.saved_tensors
        assert not ctx.needs_input_grad[1], "FlashAttention does not support bias gradient yet"
        # The backward kernel writes into preallocated buffers; run it under
        # inference mode since it is not itself autograd-traceable.
        with torch.inference_mode():
            dqkv = torch.empty_like(qkv)
            _flash_attn_backward(
                do,
                qkv[:, :, 0],
                qkv[:, :, 1],
                qkv[:, :, 2],
                out,
                logsumexp,
                dqkv[:, :, 0],
                dqkv[:, :, 1],
                dqkv[:, :, 2],
                bias=bias,
                causal=ctx.causal,
                softmax_scale=ctx.softmax_scale,
            )
        return dqkv, None, None, None


flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper around the Triton flash-attention kernels for a
    separate query tensor and a packed kv tensor (k, v stacked on dim 2)."""

    @staticmethod
    def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
        """
        q: (batch, seqlen_q, nheads, headdim)
        kv: (batch, seqlen_k, 2, nheads, headdim)
        bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
            For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
            ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
        """
        # Kernels require a contiguous last (headdim) dimension.
        q = q if q.stride(-1) == 1 else q.contiguous()
        kv = kv if kv.stride(-1) == 1 else kv.contiguous()
        out, logsumexp, ctx.softmax_scale = _flash_attn_forward(
            q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale
        )
        ctx.save_for_backward(q, kv, out, logsumexp, bias)
        ctx.causal = causal
        return out

    @staticmethod
    def backward(ctx, do):
        q, kv, out, logsumexp, bias = ctx.saved_tensors
        if len(ctx.needs_input_grad) >= 3:
            assert not ctx.needs_input_grad[2], "FlashAttention does not support bias gradient yet"
        # The backward kernel writes into preallocated buffers; run it under
        # inference mode since it is not itself autograd-traceable.
        with torch.inference_mode():
            dq = torch.empty_like(q)
            dkv = torch.empty_like(kv)
            _flash_attn_backward(
                do,
                q,
                kv[:, :, 0],
                kv[:, :, 1],
                out,
                logsumexp,
                dq,
                dkv[:, :, 0],
                dkv[:, :, 1],
                bias=bias,
                causal=ctx.causal,
                softmax_scale=ctx.softmax_scale,
            )
        return dq, dkv, None, None, None


flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
|
| 913 |
+
|
| 914 |
+
|
| 915 |
+
class FlashAttnFunc(torch.autograd.Function):
    """Autograd wrapper around the Triton flash-attention kernels for
    separate q, k and v tensors."""

    @staticmethod
    def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
        """
        q: (batch_size, seqlen_q, nheads, headdim)
        k, v: (batch_size, seqlen_k, nheads, headdim)
        bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
            For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
            ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
        """
        # Kernels require a contiguous last (headdim) dimension.
        q, k, v = (t if t.stride(-1) == 1 else t.contiguous() for t in (q, k, v))
        out, logsumexp, ctx.softmax_scale = _flash_attn_forward(
            q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
        )
        ctx.save_for_backward(q, k, v, out, logsumexp, bias)
        ctx.causal = causal
        return out

    @staticmethod
    def backward(ctx, do):
        q, k, v, out, logsumexp, bias = ctx.saved_tensors
        assert not ctx.needs_input_grad[3], "FlashAttention does not support bias gradient yet"
        # The backward kernel writes into preallocated buffers; run it under
        # inference mode since it is not itself autograd-traceable.
        with torch.inference_mode():
            dq = torch.empty_like(q)
            dk = torch.empty_like(k)
            dv = torch.empty_like(v)
            _flash_attn_backward(
                do, q, k, v, out, logsumexp, dq, dk, dv, bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale
            )
        return dq, dk, dv, None, None, None


flash_attn_func = FlashAttnFunc.apply
|
VILA/llava/model/language_model/mpt/hf_prefixlm_converter.py
ADDED
|
@@ -0,0 +1,657 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
"""Converts Huggingface Causal LM to Prefix LM.
|
| 18 |
+
|
| 19 |
+
Conversion does lightweight surgery on a HuggingFace
|
| 20 |
+
Causal LM to convert it to a Prefix LM.
|
| 21 |
+
|
| 22 |
+
Prefix LMs accepts a `bidirectional_mask` input in `forward`
|
| 23 |
+
and treat the input prompt as the prefix in `generate`.
|
| 24 |
+
"""
|
| 25 |
+
import math
|
| 26 |
+
import warnings
|
| 27 |
+
from types import MethodType
|
| 28 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
from transformers.models.bloom.modeling_bloom import (
|
| 32 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 33 |
+
BloomForCausalLM,
|
| 34 |
+
BloomModel,
|
| 35 |
+
CausalLMOutputWithCrossAttentions,
|
| 36 |
+
CrossEntropyLoss,
|
| 37 |
+
)
|
| 38 |
+
from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom
|
| 39 |
+
from transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom
|
| 40 |
+
from transformers.models.bloom.modeling_bloom import logging
|
| 41 |
+
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
|
| 42 |
+
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
|
| 43 |
+
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
|
| 44 |
+
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
|
| 45 |
+
from transformers.models.opt.modeling_opt import OPTForCausalLM
|
| 46 |
+
from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt
|
| 47 |
+
from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt
|
| 48 |
+
|
| 49 |
+
logger = logging.get_logger(__name__)
|
| 50 |
+
_SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM)
|
| 51 |
+
CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
    """Converts a GPT-style Causal LM to a Prefix LM.

    Supported HuggingFace model classes:
        - `GPT2LMHeadModel`
        - `GPTNeoForCausalLM`
        - `GPTNeoXForCausalLM`
        - `GPTJForCausalLM`

    Works by wrapping `forward`/`generate` so that, when a `bidirectional_mask`
    is given, each attention module's causal `bias` buffer is temporarily opened
    up over the prefix region and restored afterwards.

    See `convert_hf_causal_lm_to_prefix_lm` for more details.
    """
    # Idempotent: a model that was already converted is returned unchanged.
    if hasattr(model, "_prefix_lm_converted"):
        return model
    assert isinstance(model, _SUPPORTED_GPT_MODELS)
    assert model.config.add_cross_attention == False, "Only supports GPT-style decoder-only models"

    def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
        """Helper that gets a list of the model's attention modules.

        Each module has a `bias` buffer used for causal masking. The Prefix LM
        conversion adds logic to dynamically manipulate these biases to support
        Prefix LM attention masking.
        """
        attn_modules = []
        # The block list lives in a different attribute per architecture.
        if isinstance(model, GPTNeoXForCausalLM):
            blocks = model.gpt_neox.layers
        else:
            blocks = model.transformer.h
        for block in blocks:
            if isinstance(model, GPTNeoForCausalLM):
                # GPT-Neo alternates attention types; only "global" layers
                # carry the causal bias buffer that gets patched here.
                if block.attn.attention_type != "global":
                    continue
                attn_module = block.attn.attention
            elif isinstance(model, GPTNeoXForCausalLM):
                attn_module = block.attention
            else:
                attn_module = block.attn
            attn_modules.append(attn_module)
        return attn_modules

    # Keep handles to the original methods so the wrappers can delegate.
    setattr(model, "_original_forward", getattr(model, "forward"))
    setattr(model, "_original_generate", getattr(model, "generate"))

    def forward(
        self: CAUSAL_GPT_TYPES,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        bidirectional_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Wraps original forward to enable PrefixLM attention.

        `bidirectional_mask` marks (with 1s) the prefix positions that may
        attend bidirectionally; all other arguments are passed through to
        the model's original forward.
        """

        def call_og_forward():
            # GPT-NeoX's forward signature lacks token_type_ids/position_ids,
            # so it must be called with a reduced keyword set.
            if isinstance(self, GPTNeoXForCausalLM):
                return self._original_forward(
                    input_ids=input_ids,
                    past_key_values=past_key_values,
                    attention_mask=attention_mask,
                    head_mask=head_mask,
                    inputs_embeds=inputs_embeds,
                    labels=labels,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                )
            else:
                return self._original_forward(
                    input_ids=input_ids,
                    past_key_values=past_key_values,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                    position_ids=position_ids,
                    head_mask=head_mask,
                    inputs_embeds=inputs_embeds,
                    labels=labels,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                )

        # Without a bidirectional mask this behaves exactly like a causal LM.
        if bidirectional_mask is None:
            return call_og_forward()
        assert isinstance(bidirectional_mask, torch.Tensor)
        attn_modules = _get_attn_modules(model)
        (b, s) = bidirectional_mask.shape
        # The causal `bias` buffer is square with side = max model length.
        max_length = attn_modules[0].bias.shape[-1]
        if s > max_length:
            raise ValueError(
                f"bidirectional_mask sequence length (={s}) exceeds the "
                + f"max length allowed by the model ({max_length})."
            )
        assert s <= max_length
        if s < max_length:
            # Right-pad the mask with zeros (causal-only) out to max_length so
            # it broadcasts against the full bias buffer.
            pad = torch.zeros(
                (int(b), int(max_length - s)), dtype=bidirectional_mask.dtype, device=bidirectional_mask.device
            )
            bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
        bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
        # Temporarily open up the causal mask over the prefix columns...
        for attn_module in attn_modules:
            attn_module.bias.data = torch.logical_or(attn_module.bias.data, bidirectional)
        output = call_og_forward()
        # ...then restore the pure lower-triangular causal mask afterwards.
        # NOTE(review): this mutates shared buffers, so concurrent forward
        # calls on the same model would race — presumably single-threaded use
        # is assumed.
        for attn_module in attn_modules:
            attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
        return output

    def generate(self: CAUSAL_GPT_TYPES, *args: tuple, **kwargs: Dict[str, Any]):
        """Wraps original generate to enable PrefixLM attention.

        During generation the whole prompt is treated as the prefix, so the
        attention bias is made fully bidirectional for the duration of the
        call and restored to lower-triangular afterwards.
        """
        attn_modules = _get_attn_modules(model)
        for attn_module in attn_modules:
            attn_module.bias.data[:] = 1
        output = self._original_generate(*args, **kwargs)
        for attn_module in attn_modules:
            attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
        return output

    # Install the wrappers as bound methods and mark the model as converted.
    setattr(model, "forward", MethodType(forward, model))
    setattr(model, "generate", MethodType(generate, model))
    setattr(model, "_prefix_lm_converted", True)
    return model
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM:
|
| 187 |
+
"""Converts a BLOOM Causal LM to a Prefix LM.
|
| 188 |
+
|
| 189 |
+
Supported HuggingFace model classes:
|
| 190 |
+
- `BloomForCausalLM`
|
| 191 |
+
|
| 192 |
+
See `convert_hf_causal_lm_to_prefix_lm` for more details.
|
| 193 |
+
"""
|
| 194 |
+
if hasattr(model, "_prefix_lm_converted"):
|
| 195 |
+
return model
|
| 196 |
+
assert isinstance(model, BloomForCausalLM)
|
| 197 |
+
assert model.config.add_cross_attention == False, "Only supports BLOOM decoder-only models"
|
| 198 |
+
|
| 199 |
+
def _prepare_attn_mask(
|
| 200 |
+
self: BloomModel,
|
| 201 |
+
attention_mask: torch.Tensor,
|
| 202 |
+
bidirectional_mask: Optional[torch.Tensor],
|
| 203 |
+
input_shape: Tuple[int, int],
|
| 204 |
+
past_key_values_length: int,
|
| 205 |
+
) -> torch.BoolTensor:
|
| 206 |
+
combined_attention_mask = None
|
| 207 |
+
device = attention_mask.device
|
| 208 |
+
(_, src_length) = input_shape
|
| 209 |
+
if src_length > 1:
|
| 210 |
+
combined_attention_mask = _make_causal_mask_bloom(
|
| 211 |
+
input_shape, device=device, past_key_values_length=past_key_values_length
|
| 212 |
+
)
|
| 213 |
+
if bidirectional_mask is not None:
|
| 214 |
+
assert attention_mask.shape == bidirectional_mask.shape
|
| 215 |
+
expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length)
|
| 216 |
+
combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask)
|
| 217 |
+
expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length)
|
| 218 |
+
combined_attention_mask = (
|
| 219 |
+
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
|
| 220 |
+
)
|
| 221 |
+
return combined_attention_mask
|
| 222 |
+
|
| 223 |
+
def _build_alibi_tensor(
|
| 224 |
+
self: BloomModel, batch_size: int, query_length: int, key_length: int, dtype: torch.dtype, device: torch.device
|
| 225 |
+
) -> torch.Tensor:
|
| 226 |
+
num_heads = self.config.n_head
|
| 227 |
+
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
|
| 228 |
+
base = torch.tensor(2 ** (-(2 ** (-(math.log2(closest_power_of_2) - 3)))), device=device, dtype=torch.float32)
|
| 229 |
+
powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
|
| 230 |
+
slopes = torch.pow(base, powers)
|
| 231 |
+
if closest_power_of_2 != num_heads:
|
| 232 |
+
extra_base = torch.tensor(
|
| 233 |
+
2 ** (-(2 ** (-(math.log2(2 * closest_power_of_2) - 3)))), device=device, dtype=torch.float32
|
| 234 |
+
)
|
| 235 |
+
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
|
| 236 |
+
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
|
| 237 |
+
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
| 238 |
+
qa = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1)
|
| 239 |
+
ka = torch.arange(key_length, device=device, dtype=torch.int32).view(1, -1)
|
| 240 |
+
diffs = qa - ka + key_length - query_length
|
| 241 |
+
diffs = -diffs.abs()
|
| 242 |
+
alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(1, 1, query_length, key_length)
|
| 243 |
+
alibi = alibi.expand(batch_size, -1, -1, -1).reshape(-1, query_length, key_length)
|
| 244 |
+
return alibi.to(dtype)
|
| 245 |
+
|
| 246 |
+
KeyValueT = Tuple[torch.Tensor, torch.Tensor]
|
| 247 |
+
|
| 248 |
+
def forward(
|
| 249 |
+
self: BloomModel,
|
| 250 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 251 |
+
past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
|
| 252 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 253 |
+
bidirectional_mask: Optional[torch.Tensor] = None,
|
| 254 |
+
head_mask: Optional[torch.LongTensor] = None,
|
| 255 |
+
inputs_embeds: Optional[torch.LongTensor] = None,
|
| 256 |
+
use_cache: Optional[bool] = None,
|
| 257 |
+
output_attentions: Optional[bool] = None,
|
| 258 |
+
output_hidden_states: Optional[bool] = None,
|
| 259 |
+
return_dict: Optional[bool] = None,
|
| 260 |
+
**deprecated_arguments,
|
| 261 |
+
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
|
| 262 |
+
if deprecated_arguments.pop("position_ids", False) is not False:
|
| 263 |
+
warnings.warn(
|
| 264 |
+
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. "
|
| 265 |
+
+ "You can safely ignore passing `position_ids`.",
|
| 266 |
+
FutureWarning,
|
| 267 |
+
)
|
| 268 |
+
if len(deprecated_arguments) > 0:
|
| 269 |
+
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
|
| 270 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 271 |
+
output_hidden_states = (
|
| 272 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 273 |
+
)
|
| 274 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 275 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 276 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 277 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 278 |
+
elif input_ids is not None:
|
| 279 |
+
(batch_size, seq_length) = input_ids.shape
|
| 280 |
+
elif inputs_embeds is not None:
|
| 281 |
+
(batch_size, seq_length, _) = inputs_embeds.shape
|
| 282 |
+
else:
|
| 283 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 284 |
+
if past_key_values is None:
|
| 285 |
+
past_key_values = tuple([None] * len(self.h))
|
| 286 |
+
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
| 287 |
+
if inputs_embeds is None:
|
| 288 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
| 289 |
+
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
| 290 |
+
presents = () if use_cache else None
|
| 291 |
+
all_self_attentions = () if output_attentions else None
|
| 292 |
+
all_hidden_states = () if output_hidden_states else None
|
| 293 |
+
seq_length_with_past = seq_length
|
| 294 |
+
past_key_values_length = 0
|
| 295 |
+
if past_key_values[0] is not None:
|
| 296 |
+
tmp = past_key_values[0][0]
|
| 297 |
+
past_key_values_length = tmp.shape[2]
|
| 298 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
| 299 |
+
if attention_mask is None:
|
| 300 |
+
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
|
| 301 |
+
else:
|
| 302 |
+
attention_mask = attention_mask.to(hidden_states.device)
|
| 303 |
+
alibi = self._build_alibi_tensor(
|
| 304 |
+
batch_size=batch_size,
|
| 305 |
+
query_length=seq_length,
|
| 306 |
+
key_length=seq_length_with_past,
|
| 307 |
+
dtype=hidden_states.dtype,
|
| 308 |
+
device=hidden_states.device,
|
| 309 |
+
)
|
| 310 |
+
causal_mask = self._prepare_attn_mask(
|
| 311 |
+
attention_mask,
|
| 312 |
+
bidirectional_mask,
|
| 313 |
+
input_shape=(batch_size, seq_length),
|
| 314 |
+
past_key_values_length=past_key_values_length,
|
| 315 |
+
)
|
| 316 |
+
for (i, (block, layer_past)) in enumerate(zip(self.h, past_key_values)):
|
| 317 |
+
if output_hidden_states:
|
| 318 |
+
hst = (hidden_states,)
|
| 319 |
+
all_hidden_states = all_hidden_states + hst
|
| 320 |
+
if self.gradient_checkpointing and self.training:
|
| 321 |
+
if use_cache:
|
| 322 |
+
logger.warning(
|
| 323 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 324 |
+
)
|
| 325 |
+
use_cache = False
|
| 326 |
+
|
| 327 |
+
def create_custom_forward(module):
|
| 328 |
+
def custom_forward(*inputs):
|
| 329 |
+
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
|
| 330 |
+
|
| 331 |
+
return custom_forward
|
| 332 |
+
|
| 333 |
+
outputs = torch.utils.checkpoint.checkpoint(
|
| 334 |
+
create_custom_forward(block), hidden_states, alibi, causal_mask, head_mask[i]
|
| 335 |
+
)
|
| 336 |
+
else:
|
| 337 |
+
outputs = block(
|
| 338 |
+
hidden_states,
|
| 339 |
+
layer_past=layer_past,
|
| 340 |
+
attention_mask=causal_mask,
|
| 341 |
+
head_mask=head_mask[i],
|
| 342 |
+
use_cache=use_cache,
|
| 343 |
+
output_attentions=output_attentions,
|
| 344 |
+
alibi=alibi,
|
| 345 |
+
)
|
| 346 |
+
hidden_states = outputs[0]
|
| 347 |
+
if use_cache is True:
|
| 348 |
+
presents = presents + (outputs[1],)
|
| 349 |
+
if output_attentions:
|
| 350 |
+
oa = (outputs[2 if use_cache else 1],)
|
| 351 |
+
all_self_attentions = all_self_attentions + oa
|
| 352 |
+
hidden_states = self.ln_f(hidden_states)
|
| 353 |
+
if output_hidden_states:
|
| 354 |
+
hst = (hidden_states,)
|
| 355 |
+
all_hidden_states = all_hidden_states + hst
|
| 356 |
+
if not return_dict:
|
| 357 |
+
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
| 358 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
| 359 |
+
last_hidden_state=hidden_states,
|
| 360 |
+
past_key_values=presents,
|
| 361 |
+
hidden_states=all_hidden_states,
|
| 362 |
+
attentions=all_self_attentions,
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
setattr(model.transformer, "_prepare_attn_mask", MethodType(_prepare_attn_mask, model.transformer))
|
| 366 |
+
setattr(model.transformer, "_build_alibi_tensor", MethodType(_build_alibi_tensor, model.transformer))
|
| 367 |
+
setattr(model.transformer, "forward", MethodType(forward, model.transformer))
|
| 368 |
+
KeyValueT = Tuple[torch.Tensor, torch.Tensor]
|
| 369 |
+
|
| 370 |
+
def forward(
    self: BloomForCausalLM,
    input_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    bidirectional_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    **deprecated_arguments,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
    """Replacement forward method for BloomCausalLM."""
    # `position_ids` is accepted only to mirror the deprecated upstream API;
    # it carries no meaning for BLOOM (which uses ALiBi, not positions).
    if deprecated_arguments.pop("position_ids", False) is not False:
        warnings.warn(
            "`position_ids` have no functionality in BLOOM and will be removed "
            + "in v5.0.0. You can safely ignore passing `position_ids`.",
            FutureWarning,
        )
    if deprecated_arguments:
        raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
    if return_dict is None:
        return_dict = self.config.use_return_dict
    # Run the (already prefix-LM-patched) transformer trunk, forwarding the
    # extra `bidirectional_mask` input through to it.
    transformer_outputs = self.transformer(
        input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        bidirectional_mask=bidirectional_mask,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    hidden_states = transformer_outputs[0]
    logits = self.lm_head(hidden_states)
    loss = None
    if labels is not None:
        # Standard next-token objective: predict token t+1 from position t.
        shifted_logits = logits[..., :-1, :].contiguous()
        shifted_labels = labels[..., 1:].contiguous()
        (bsz, seq_len, n_vocab) = shifted_logits.shape
        loss = CrossEntropyLoss()(
            shifted_logits.view(bsz * seq_len, n_vocab),
            shifted_labels.view(bsz * seq_len),
        )
    if not return_dict:
        output = (logits,) + transformer_outputs[1:]
        return (loss,) + output if loss is not None else output
    return CausalLMOutputWithCrossAttentions(
        loss=loss,
        logits=logits,
        past_key_values=transformer_outputs.past_key_values,
        hidden_states=transformer_outputs.hidden_states,
        attentions=transformer_outputs.attentions,
    )
|
| 428 |
+
|
| 429 |
+
def prepare_inputs_for_generation(
    self: BloomForCausalLM,
    input_ids: torch.LongTensor,
    past: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> dict:
    """Assemble the model inputs for one decoding step of `generate`."""
    if past:
        # A cache exists: only the newest token needs to be fed, and the
        # prefix (bidirectional) region is already encoded in the cache.
        input_ids = input_ids[:, -1].unsqueeze(-1)
        bidirectional_mask = None
        if past[0][0].shape[0] == input_ids.shape[0]:
            past = self._convert_to_bloom_cache(past)
    else:
        # First step: every prompt token belongs to the bidirectional prefix.
        bidirectional_mask = torch.ones_like(input_ids)
    return dict(
        input_ids=input_ids,
        past_key_values=past,
        use_cache=True,
        attention_mask=attention_mask,
        bidirectional_mask=bidirectional_mask,
    )
|
| 450 |
+
|
| 451 |
+
setattr(model, "forward", MethodType(forward, model))
|
| 452 |
+
setattr(model, "prepare_inputs_for_generation", MethodType(prepare_inputs_for_generation, model))
|
| 453 |
+
setattr(model, "_prefix_lm_converted", True)
|
| 454 |
+
return model
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM:
    """Converts an OPT Causal LM to a Prefix LM.

    Supported HuggingFace model classes:
        - `OPTForCausalLM`

    See `convert_hf_causal_lm_to_prefix_lm` for more details.
    """
    if hasattr(model, "_prefix_lm_converted"):
        return model
    assert isinstance(model, OPTForCausalLM)
    assert model.config.add_cross_attention == False, "Only supports OPT decoder-only models"
    # Keep handles to the original methods so the wrappers below can delegate.
    setattr(model, "_original_forward", getattr(model, "forward"))
    setattr(model, "_original_generate", getattr(model, "generate"))
    # Sentinel attribute read by `_prepare_decoder_attention_mask`:
    #   None   -> plain causal masking
    #   "g"    -> fully bidirectional (used during `generate`)
    #   tensor -> per-token prefix mask supplied by the caller
    model.model.decoder.bidirectional_mask = None

    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # Build the additive attention mask, optionally relaxed for the prefix.
        combined_attention_mask = None
        if input_shape[-1] > 1:
            if self.bidirectional_mask == "g":
                # Generation prompt encoding: no causal restriction at all.
                (bsz, src_length) = input_shape
                combined_attention_mask = torch.zeros(
                    (bsz, 1, src_length, src_length + past_key_values_length),
                    dtype=inputs_embeds.dtype,
                    device=inputs_embeds.device,
                )
            else:
                combined_attention_mask = _make_causal_mask_opt(
                    input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
                ).to(inputs_embeds.device)
                if self.bidirectional_mask is not None:
                    # Prefix positions attend bidirectionally: elementwise max
                    # lets the 0-bias of the prefix override the causal -inf.
                    assert attention_mask.shape == self.bidirectional_mask.shape
                    expanded_bidirectional_mask = _expand_mask_opt(
                        self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
                    ).to(inputs_embeds.device)
                    combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask)
        if attention_mask is not None:
            expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )
        return combined_attention_mask

    setattr(
        model.model.decoder,
        "_prepare_decoder_attention_mask",
        MethodType(_prepare_decoder_attention_mask, model.model.decoder),
    )

    def forward(
        self: OPTForCausalLM,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        bidirectional_mask: Optional[torch.ByteTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Prefix-aware replacement for `OPTForCausalLM.forward`."""

        def call_og_forward():
            return self._original_forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        if bidirectional_mask is None:
            return call_og_forward()
        # Stash the mask where the decoder can see it. `try/finally` replaces
        # the original bare-`except` + duplicated reset: it guarantees the
        # sentinel is cleared on both the success and the error path without
        # catching BaseException subclasses like KeyboardInterrupt.
        self.model.decoder.bidirectional_mask = bidirectional_mask
        try:
            outputs = call_og_forward()
        finally:
            self.model.decoder.bidirectional_mask = None
        return outputs

    def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[str, Any]):
        """Wraps original generate to enable PrefixLM-style attention."""
        # "g" makes the prompt encoding fully bidirectional; cleanup is
        # guaranteed by `finally` even if generation raises.
        self.model.decoder.bidirectional_mask = "g"
        try:
            output = self._original_generate(*args, **kwargs)
        finally:
            self.model.decoder.bidirectional_mask = None
        return output

    setattr(model, "forward", MethodType(forward, model))
    setattr(model, "generate", MethodType(generate, model))
    setattr(model, "_prefix_lm_converted", True)
    return model
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
# Full set of HF classes accepted by `convert_hf_causal_lm_to_prefix_lm`:
# the GPT-style classes plus the BLOOM and OPT causal LMs.
_SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM)
# Union type used purely to annotate the public conversion entry point below.
CAUSAL_LM_TYPES = Union[
    GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM, BloomForCausalLM, OPTForCausalLM
]
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
    """Converts a HuggingFace Causal LM to a Prefix LM.

    Supported HuggingFace model classes:
        - `GPT2LMHeadModel`
        - `GPTNeoForCausalLM`
        - `GPTNeoXForCausalLM`
        - `GPTJForCausalLM`
        - `BloomForCausalLM`
        - `OPTForCausalLM`

    Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the
    `generate` method and/or select underlying methods depending on the model class.

    These changes preserve the model API, but add a new input to `forward`: "bidirectional_mask".

    Notes on training:
        To actually train the converted model as a Prefix LM, training batches will need to indicate
        the prefix/target structure by including `bidirectional_mask` as part of the batch inputs.

        **This is not a standard input and requires custom layers either within or after your dataloader.**

        In addition to adding `bidirectional_mask` to the batch, this custom code should modify `labels`
        such that `batch['labels'][batch['bidirectional_mask'] == 1] == -100`.
        That is, the prefix portion of the sequence should not generate any loss. Loss should only be
        generated by the target portion of the sequence.

    Notes on `GPTNeoForCausalLM`:
        To simplify the implementation, "global" and "local" attention layers are handled differently.
        For "global" layers, we handle conversion as described above. For "local" layers, which use a
        causal attention mask within a restricted local window, we do not alter the masking.

    Notes on `forward` method conversion:
        After conversion, the `forward` method will handle a new input, `bidirectional_mask`,
        which should be a [batch_size, seq_length] byte tensor, where 1 indicates token positions
        belonging to the prefix (prefix tokens can attend to one another bidirectionally), and
        0 indicates token positions belonging to the target.

        The new `forward` method will incorporate `bidirectional_mask` (if supplied) into the existing
        causal mask, call the original `forward` method, and (if the causal mask is a buffer) reset
        the causal masks before returning the result.

    Notes on `generate` method conversion:
        After conversion, the `generate` method will have the same signature but will internally
        convert all causal masks to be purely bidirectional, call the original `generate` method, and
        (where appropriate) reset the causal masks before returning the result.

        This works thanks to the logic of the HuggingFace `generate` API, which first encodes the token
        "prompt" passed to `generate` (which is treated as the prefix) and then sequentially generates
        each new token. Encodings are cached as generation happens, so all prefix tokens can attend to one
        another (as expected in a Prefix LM) and generated tokens can only attend to prefix tokens and
        previously-generated tokens (also as expected in a Prefix LM).

    To preserve the API, the original methods are renamed to `_original_forward` and
    `_original_generate`, and replaced with new `forward` and `generate` methods that wrap
    them, respectively. Although implementation details vary by model class.
    """
    # Dispatch on the concrete model class; each helper patches the model
    # in place and returns it (idempotent via `_prefix_lm_converted`).
    if isinstance(model, _SUPPORTED_GPT_MODELS):
        return _convert_gpt_causal_lm_to_prefix_lm(model)
    elif isinstance(model, BloomForCausalLM):
        return _convert_bloom_causal_lm_to_prefix_lm(model)
    elif isinstance(model, OPTForCausalLM):
        return _convert_opt_causal_lm_to_prefix_lm(model)
    else:
        # Unsupported class: fail loudly rather than silently returning an
        # unconverted model.
        raise TypeError(
            f"Cannot convert model to Prefix LM. "
            + f"Model does not belong to set of supported HF models:"
            + f"\n{_SUPPORTED_HF_MODELS}"
        )
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
def add_bidirectional_mask_if_missing(batch: Dict[str, Any]):
    """Attempts to add bidirectional_mask to batch if missing.

    Raises:
        KeyError if bidirectional_mask is missing and can't be inferred
    """
    if "bidirectional_mask" in batch:
        return
    if batch.get("mode", None) == "icl_task":
        # ICL eval batches: everything is prefix except the continuation spans.
        mask = batch["attention_mask"].clone()
        for row, continuation_indices in enumerate(batch["continuation_indices"]):
            mask[row, continuation_indices] = 0
        batch["bidirectional_mask"] = mask
    elif "labels" in batch and "attention_mask" in batch:
        # Training batches: prefix tokens are the attended positions whose
        # labels are masked out (-100).
        is_prefix = torch.logical_and(torch.eq(batch["attention_mask"], 1), torch.eq(batch["labels"], -100))
        batch["bidirectional_mask"] = is_prefix.type_as(batch["attention_mask"])
    else:
        raise KeyError("No bidirectional_mask in batch and not sure how to construct one.")
|
VILA/llava/model/language_model/mpt/meta_init_context.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
from contextlib import contextmanager
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@contextmanager
def init_empty_weights(include_buffers: bool = False):
    """Meta initialization context manager.

    A context manager under which models are initialized with all parameters
    on the meta device, therefore creating an empty model. Useful when just
    initializing the model would blow the available RAM.

    Args:
        include_buffers (`bool`, *optional*, defaults to `False`): Whether or
            not to also put all buffers on the meta device while initializing.

    Example:
    ```python
    import torch.nn as nn

    # Initialize a model with 100 billion parameters in no time and without using any RAM.
    with init_empty_weights():
        tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
    ```

    <Tip warning={true}>

    Any model created under this context manager has no weights. As such you can't do something like
    `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].

    </Tip>
    """
    # Thin wrapper: delegate to `init_on_device` (defined below) with the
    # meta device, forwarding `include_buffers` unchanged.
    with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
        yield f
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@contextmanager
def init_on_device(device: torch.device, include_buffers: bool = False):
    """Device initialization context manager.

    A context manager under which models are initialized with all parameters
    on the specified device.

    Args:
        device (`torch.device`): Device to initialize all parameters on.
        include_buffers (`bool`, *optional*, defaults to `False`): Whether or
            not to also put all buffers on the meta device while initializing.

    Example:
    ```python
    import torch.nn as nn

    with init_on_device(device=torch.device("cuda")):
        tst = nn.Linear(100, 100)  # on `cuda` device
    ```
    """
    old_register_parameter = nn.Module.register_parameter
    # Capture unconditionally so `register_empty_buffer` can never reference
    # an unbound name (the original only bound this when include_buffers=True,
    # while the nested function referenced it unconditionally).
    old_register_buffer = nn.Module.register_buffer

    def register_empty_parameter(module, name, param):
        # Register normally first, then rebuild the parameter on `device`,
        # preserving the concrete Parameter subclass and its extra attributes.
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)

    def register_empty_buffer(module, name, buffer, persistent: bool = True):
        # Accept and forward `persistent`, which newer torch versions pass
        # to `register_buffer`; defaulting to True matches torch's signature.
        old_register_buffer(module, name, buffer, persistent=persistent)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(device)

    if include_buffers:
        # Some buffers are created via bare torch factory calls, so those
        # constructors are patched to force the target device as well.
        tensor_constructors_to_patch = {
            torch_function_name: getattr(torch, torch_function_name)
            for torch_function_name in ["empty", "zeros", "ones", "full"]
        }
    else:
        tensor_constructors_to_patch = {}

    def patch_tensor_constructor(fn):
        def wrapper(*args, **kwargs):
            kwargs["device"] = device
            return fn(*args, **kwargs)

        return wrapper

    try:
        nn.Module.register_parameter = register_empty_parameter
        if include_buffers:
            nn.Module.register_buffer = register_empty_buffer
        for torch_function_name in tensor_constructors_to_patch.keys():
            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
        yield
    finally:
        # Always restore the patched globals, even if model init raised.
        nn.Module.register_parameter = old_register_parameter
        if include_buffers:
            nn.Module.register_buffer = old_register_buffer
        for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items():
            setattr(torch, torch_function_name, old_torch_function)
|
VILA/llava/model/language_model/mpt/modeling_mpt.py
ADDED
|
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
"""A simple, flexible implementation of a GPT model.
|
| 18 |
+
|
| 19 |
+
Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
|
| 20 |
+
"""
|
| 21 |
+
import math
|
| 22 |
+
import warnings
|
| 23 |
+
from typing import List, Optional, Tuple, Union
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn as nn
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
|
| 29 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 30 |
+
|
| 31 |
+
from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
|
| 32 |
+
from .attention import attn_bias_shape, build_attn_bias
|
| 33 |
+
from .blocks import MPTBlock
|
| 34 |
+
from .configuration_mpt import MPTConfig
|
| 35 |
+
from .custom_embedding import SharedEmbedding
|
| 36 |
+
from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
|
| 37 |
+
from .meta_init_context import init_empty_weights
|
| 38 |
+
from .norm import NORM_CLASS_REGISTRY
|
| 39 |
+
from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
|
| 40 |
+
|
| 41 |
+
try:
|
| 42 |
+
from .flash_attn_triton import flash_attn_func
|
| 43 |
+
except:
|
| 44 |
+
pass
|
| 45 |
+
Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class MPTPreTrainedModel(PreTrainedModel):
    """Hugging Face plumbing shared by the MPT model classes."""

    # Config class instantiated when loading checkpoints via from_pretrained.
    config_class = MPTConfig
    base_model_prefix = "model"
    # Keep each transformer block intact when sharding with accelerate/device_map.
    _no_split_modules = ["MPTBlock"]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class MPTModel(MPTPreTrainedModel):
|
| 55 |
+
def __init__(self, config: MPTConfig):
    """Build the MPT trunk (embeddings, transformer blocks, final norm)."""
    # Validate before calling the parent constructor so bad configs fail fast.
    config._validate_config()
    super().__init__(config)
    # Cache frequently-read attention options from the config dict.
    self.attn_impl = config.attn_config["attn_impl"]
    self.prefix_lm = config.attn_config["prefix_lm"]
    self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
    self.alibi = config.attn_config["alibi"]
    self.alibi_bias_max = config.attn_config["alibi_bias_max"]
    if config.init_device == "mixed":
        # NOTE(review): `dist` is not among this module's visible imports —
        # presumably a distributed utility (e.g. composer's); confirm, or
        # this branch raises NameError when init_device == "mixed".
        if dist.get_local_rank() == 0:
            config.init_device = "cpu"
        else:
            config.init_device = "meta"
    if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
        norm_options = " | ".join(NORM_CLASS_REGISTRY.keys())
        raise NotImplementedError(
            f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options})."
        )
    norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
    self.embedding_fraction = config.embedding_fraction
    # Token embeddings; `SharedEmbedding` lets the LM head reuse this weight.
    self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
    if not self.alibi:
        # Learned absolute position embeddings, only needed without ALiBi.
        self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
    self.emb_drop = nn.Dropout(config.emb_pdrop)
    self.blocks = nn.ModuleList(
        [MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)]
    )
    self.norm_f = norm_class(config.d_model, device=config.init_device)
    if config.init_device != "meta":
        print(
            f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.'
        )
        # Parameter init is only possible on a materialized (non-meta) device.
        self.apply(self.param_init_fn)
    self.is_causal = not self.prefix_lm
    # Static attention-bias cache, built lazily on first call to `_attn_bias`.
    self._attn_bias_initialized = False
    self.attn_bias = None
    self.attn_bias_shape = attn_bias_shape(
        self.attn_impl,
        config.n_heads,
        config.max_seq_len,
        self.alibi,
        prefix_lm=self.prefix_lm,
        causal=self.is_causal,
        use_sequence_id=self.attn_uses_sequence_id,
    )
    if config.no_bias:
        # Strip every bias parameter when the config requests a bias-free model.
        for module in self.modules():
            if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
                if config.verbose:
                    warnings.warn(f"Removing bias ({module.bias}) from {module}.")
                module.register_parameter("bias", None)
    if config.verbose and config.verbose > 2:
        print(self)
    if "verbose" not in self.config.init_config:
        self.config.init_config["verbose"] = self.config.verbose
    if self.config.init_config["verbose"] > 1:
        init_fn_name = self.config.init_config["name"]
        warnings.warn(f"Using {init_fn_name} initialization.")
    self.gradient_checkpointing = False
|
| 114 |
+
|
| 115 |
+
def get_input_embeddings(self):
    """Return the token-embedding module (`self.wte`)."""
    return self.wte
|
| 117 |
+
|
| 118 |
+
def set_input_embeddings(self, value):
    """Replace the token-embedding module with `value`."""
    self.wte = value
|
| 120 |
+
|
| 121 |
+
@torch.no_grad()
def _attn_bias(
    self,
    device,
    dtype,
    attention_mask: Optional[torch.ByteTensor] = None,
    prefix_mask: Optional[torch.ByteTensor] = None,
    sequence_id: Optional[torch.LongTensor] = None,
):
    """Return the (bias, mask) pair consumed by the attention layers.

    Lazily builds and caches the static attention bias on first use, then
    layers per-batch masking (prefix, sequence-id, padding) on top of it.
    """
    if not self._attn_bias_initialized:
        if self.attn_bias_shape:
            self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
            self.attn_bias = build_attn_bias(
                self.attn_impl,
                self.attn_bias,
                self.config.n_heads,
                self.config.max_seq_len,
                causal=self.is_causal,
                alibi=self.alibi,
                alibi_bias_max=self.alibi_bias_max,
            )
        self._attn_bias_initialized = True
    # The flash implementation consumes the raw attention_mask directly.
    if self.attn_impl == "flash":
        return (self.attn_bias, attention_mask)
    if self.attn_bias is not None:
        # Keep the cached bias on the caller's device/dtype.
        self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
    attn_bias = self.attn_bias
    if self.prefix_lm:
        assert isinstance(attn_bias, torch.Tensor)
        assert isinstance(prefix_mask, torch.Tensor)
        attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
    if self.attn_uses_sequence_id and sequence_id is not None:
        assert isinstance(attn_bias, torch.Tensor)
        attn_bias = self._apply_sequence_id(attn_bias, sequence_id)
    if attention_mask is not None:
        s_k = attention_mask.shape[-1]
        if attn_bias is None:
            attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
        else:
            # Keep only the trailing s_k key positions of the cached bias
            # (relevant when generating with a KV cache).
            _s_k = max(0, attn_bias.size(-1) - s_k)
            attn_bias = attn_bias[:, :, :, _s_k:]
        if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
            raise ValueError(
                f"attention_mask shape={attention_mask.shape} "
                + f"and prefix_mask shape={prefix_mask.shape} are not equal."
            )
        min_val = torch.finfo(attn_bias.dtype).min
        # Padding positions receive the large negative bias so softmax ignores them.
        attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)
    return (attn_bias, None)
|
| 170 |
+
|
| 171 |
+
def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
|
| 172 |
+
(s_k, s_q) = attn_bias.shape[-2:]
|
| 173 |
+
if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
|
| 174 |
+
raise ValueError(
|
| 175 |
+
"attn_bias does not match the expected shape. "
|
| 176 |
+
+ f"The last two dimensions should both be {self.config.max_length} "
|
| 177 |
+
+ f"but are {s_k} and {s_q}."
|
| 178 |
+
)
|
| 179 |
+
seq_len = prefix_mask.shape[-1]
|
| 180 |
+
if seq_len > self.config.max_seq_len:
|
| 181 |
+
raise ValueError(f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}")
|
| 182 |
+
attn_bias = attn_bias[..., :seq_len, :seq_len]
|
| 183 |
+
causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)).view(
|
| 184 |
+
1, 1, seq_len, seq_len
|
| 185 |
+
)
|
| 186 |
+
prefix = prefix_mask.view(-1, 1, 1, seq_len)
|
| 187 |
+
cannot_attend = ~torch.logical_or(causal, prefix.bool())
|
| 188 |
+
min_val = torch.finfo(attn_bias.dtype).min
|
| 189 |
+
attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
|
| 190 |
+
return attn_bias
|
| 191 |
+
|
| 192 |
+
def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor):
|
| 193 |
+
seq_len = sequence_id.shape[-1]
|
| 194 |
+
if seq_len > self.config.max_seq_len:
|
| 195 |
+
raise ValueError(f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}")
|
| 196 |
+
attn_bias = attn_bias[..., :seq_len, :seq_len]
|
| 197 |
+
cannot_attend = torch.logical_not(
|
| 198 |
+
torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))
|
| 199 |
+
).unsqueeze(1)
|
| 200 |
+
min_val = torch.finfo(attn_bias.dtype).min
|
| 201 |
+
attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
|
| 202 |
+
return attn_bias
|
| 203 |
+
|
| 204 |
+
    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        prefix_mask: Optional[torch.ByteTensor] = None,
        sequence_id: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ):
        """Run the MPT transformer stack and return a ``BaseModelOutputWithPast``.

        Either ``input_ids`` or ``inputs_embeds`` must be provided; the
        ``inputs_embeds`` path is only supported when ALiBi is enabled (no
        learned positional embedding is added in that case).

        Raises:
            NotImplementedError: for ``return_dict=False``, left-padded
                training input, or ``output_attentions`` with a non-torch
                attention implementation.
            ValueError: for missing ``prefix_mask``/``sequence_id`` when the
                configuration requires them, or mismatched cache length.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        # Masks may arrive as byte/int tensors from HF; normalize to bool.
        if attention_mask is not None:
            attention_mask = attention_mask.bool()
        if prefix_mask is not None:
            prefix_mask = prefix_mask.bool()
        if not return_dict:
            raise NotImplementedError("return_dict False is not implemented yet for MPT")
        if output_attentions:
            if self.attn_impl != "torch":
                raise NotImplementedError(
                    "output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`."
                )
        # A zero in the first column of attention_mask indicates left padding.
        if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
            raise NotImplementedError("MPT does not support training with left padding.")
        if self.prefix_lm and prefix_mask is None:
            raise ValueError("prefix_mask is a required argument when MPT is configured with prefix_lm=True.")
        if self.training:
            if self.attn_uses_sequence_id and sequence_id is None:
                raise ValueError(
                    "sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True "
                    + "and the model is in train mode."
                )
            elif self.attn_uses_sequence_id is False and sequence_id is not None:
                warnings.warn(
                    "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. "
                    + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True."
                )
        if input_ids is not None:
            S = input_ids.size(1)
            assert (
                S <= self.config.max_seq_len
            ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}"
            tok_emb = self.wte(input_ids)
        else:
            assert inputs_embeds is not None
            assert self.alibi, "inputs_embeds is not implemented for MPT unless for alibi."
            S = inputs_embeds.size(1)
            tok_emb = inputs_embeds
        if self.alibi:
            # ALiBi encodes position via the attention bias, so no learned
            # positional embedding is added here.
            x = tok_emb
        else:
            past_position = 0
            if past_key_values is not None:
                if len(past_key_values) != self.config.n_layers:
                    raise ValueError(
                        f"past_key_values must provide a past_key_value for each attention "
                        + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})."
                    )
                # Cache layout differs per attention impl: seq dim is 1 for
                # flash/triton caches, 3 for the torch impl.
                past_position = past_key_values[0][0].size(1)
                if self.attn_impl == "torch":
                    past_position = past_key_values[0][0].size(3)
            if S + past_position > self.config.max_seq_len:
                raise ValueError(
                    f"Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}."
                )
            pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
            if attention_mask is not None:
                # Shift positions so padding tokens do not consume position ids.
                pos = torch.clamp(
                    pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0
                )
            pos_emb = self.wpe(pos)
            x = tok_emb + pos_emb
        if self.embedding_fraction == 1:
            x = self.emb_drop(x)
        else:
            # Scale only a fraction of the embedding gradient (GLM-style trick);
            # forward value is unchanged because the two terms sum to x.
            x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
            assert isinstance(self.emb_drop, nn.Module)
            x = self.emb_drop(x_shrunk)
        (attn_bias, attention_mask) = self._attn_bias(
            device=x.device,
            dtype=torch.float32,
            attention_mask=attention_mask,
            prefix_mask=prefix_mask,
            sequence_id=sequence_id,
        )
        if use_cache and past_key_values is None:
            # Seed an empty per-layer cache so blocks can populate it below.
            past_key_values = [() for _ in range(self.config.n_layers)]
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        for (b_idx, block) in enumerate(self.blocks):
            if output_hidden_states:
                assert all_hidden_states is not None
                all_hidden_states = all_hidden_states + (x,)
            past_key_value = past_key_values[b_idx] if past_key_values is not None else None
            if self.gradient_checkpointing and self.training:
                (x, attn_weights, past_key_value) = torch.utils.checkpoint.checkpoint(
                    block, x, past_key_value, attn_bias, attention_mask, self.is_causal
                )
            else:
                (x, attn_weights, past_key_value) = block(
                    x,
                    past_key_value=past_key_value,
                    attn_bias=attn_bias,
                    attention_mask=attention_mask,
                    is_causal=self.is_causal,
                )
            if past_key_values is not None:
                past_key_values[b_idx] = past_key_value
            if output_attentions:
                assert all_self_attns is not None
                all_self_attns = all_self_attns + (attn_weights,)
        x = self.norm_f(x)
        if output_hidden_states:
            assert all_hidden_states is not None
            all_hidden_states = all_hidden_states + (x,)
        return BaseModelOutputWithPast(
            last_hidden_state=x,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
|
| 329 |
+
|
| 330 |
+
def param_init_fn(self, module):
|
| 331 |
+
init_fn_name = self.config.init_config["name"]
|
| 332 |
+
MODEL_INIT_REGISTRY[init_fn_name](
|
| 333 |
+
module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
def fsdp_wrap_fn(self, module):
|
| 337 |
+
return isinstance(module, MPTBlock)
|
| 338 |
+
|
| 339 |
+
def activation_checkpointing_fn(self, module):
|
| 340 |
+
return isinstance(module, MPTBlock)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
class MPTForCausalLM(MPTPreTrainedModel):
    """MPT with a (tied) language-modeling head for causal generation.

    The output projection reuses the input embedding (``transformer.wte``)
    rather than a separate ``lm_head``; untied embeddings are rejected.
    """

    def __init__(self, config: MPTConfig):
        super().__init__(config)
        if not config.tie_word_embeddings:
            raise ValueError("MPTForCausalLM only supports tied word embeddings")
        print(f"Instantiating an MPTForCausalLM model from {__file__}")
        self.transformer = MPTModel(config)
        # Mark direct children (except the block list, which is wrapped
        # per-block by fsdp_wrap_fn) for FSDP wrapping.
        for child in self.transformer.children():
            if isinstance(child, torch.nn.ModuleList):
                continue
            if isinstance(child, torch.nn.Module):
                child._fsdp_wrap = True
        # Optional scalar multiplier on the logits; the string option
        # "inv_sqrt_d_model" resolves to 1/sqrt(d_model).
        self.logit_scale = None
        if config.logit_scale is not None:
            logit_scale = config.logit_scale
            if isinstance(logit_scale, str):
                if logit_scale == "inv_sqrt_d_model":
                    logit_scale = 1 / math.sqrt(config.d_model)
                else:
                    raise ValueError(
                        f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
                    )
            self.logit_scale = logit_scale

    def get_input_embeddings(self):
        return self.transformer.wte

    def set_input_embeddings(self, value):
        self.transformer.wte = value

    def get_output_embeddings(self):
        # Embeddings are tied, so the input embedding is also the output head.
        return self.transformer.wte

    def set_output_embeddings(self, new_embeddings):
        self.transformer.wte = new_embeddings

    def set_decoder(self, decoder):
        self.transformer = decoder

    def get_decoder(self):
        return self.transformer

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        prefix_mask: Optional[torch.ByteTensor] = None,
        sequence_id: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Run the transformer, decode logits via the tied embedding, and
        optionally compute the next-token cross-entropy loss from ``labels``.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if inputs_embeds is not None:
            raise NotImplementedError("inputs_embeds has to be None (for hf/peft support).")
        outputs = self.transformer(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            prefix_mask=prefix_mask,
            sequence_id=sequence_id,
            return_dict=return_dict,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
        )
        # The second (True) argument asks the shared embedding to act as the
        # output projection — presumably a decode flag on the custom wte;
        # NOTE(review): confirm against SharedEmbedding's forward signature.
        logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
        if self.logit_scale is not None:
            if self.logit_scale == 0:
                warnings.warn(
                    f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs."
                )
            logits *= self.logit_scale
        loss = None
        if labels is not None:
            # Shift labels left by one so position t predicts token t+1; the
            # rolled-in final position is ignored via the -100 index.
            labels = torch.roll(labels, shifts=-1)
            labels[:, -1] = -100
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def param_init_fn(self, module):
        """Initialize *module* using the init scheme named in ``config.init_config``."""
        init_fn_name = self.config.init_config["name"]
        MODEL_INIT_REGISTRY[init_fn_name](
            module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config
        )

    def fsdp_wrap_fn(self, module):
        # FSDP policy hook: wrap each transformer block as its own unit.
        return isinstance(module, MPTBlock)

    def activation_checkpointing_fn(self, module):
        # Checkpoint activations at transformer-block granularity.
        return isinstance(module, MPTBlock)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Assemble the kwargs dict HF ``generate`` passes to ``forward``.

        Rejects right padding and ``inputs_embeds``; with a populated cache,
        only the newest token is fed forward.
        """
        if inputs_embeds is not None:
            raise NotImplementedError("inputs_embeds is not implemented for MPT yet")
        attention_mask = kwargs["attention_mask"].bool()
        # A zero in the last column means right padding, which breaks caching.
        if attention_mask[:, -1].sum() != attention_mask.shape[0]:
            raise NotImplementedError("MPT does not support generation with right padding.")
        if self.transformer.attn_uses_sequence_id and self.training:
            sequence_id = torch.zeros_like(input_ids[:1])
        else:
            sequence_id = None
        if past_key_values is not None:
            # With a cache, only the most recent token needs to be processed.
            input_ids = input_ids[:, -1].unsqueeze(-1)
        if self.transformer.prefix_lm:
            # Leverage hf's MPT callflow to treat the prompt as the prefix.
            prefix_mask = torch.ones_like(attention_mask)
            if kwargs.get("use_cache") == False:
                raise NotImplementedError("MPT with prefix_lm=True does not support use_cache=False.")
        else:
            prefix_mask = None
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prefix_mask": prefix_mask,
            "sequence_id": sequence_id,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Used by HuggingFace generate when using beam search with kv-caching.

        See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
        for an example in transformers.
        """
        reordered_past = []
        for layer_past in past_key_values:
            reordered_past += [tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)]
        return reordered_past
|
VILA/llava/model/language_model/mpt/norm.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _cast_if_autocast_enabled(tensor):
|
| 21 |
+
if torch.is_autocast_enabled():
|
| 22 |
+
if tensor.device.type == "cuda":
|
| 23 |
+
dtype = torch.get_autocast_gpu_dtype()
|
| 24 |
+
elif tensor.device.type == "cpu":
|
| 25 |
+
dtype = torch.get_autocast_cpu_dtype()
|
| 26 |
+
else:
|
| 27 |
+
raise NotImplementedError()
|
| 28 |
+
return tensor.to(dtype=dtype)
|
| 29 |
+
return tensor
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class LPLayerNorm(torch.nn.LayerNorm):
    """LayerNorm that executes in low precision under autocast.

    The input and affine parameters are downcast to the active autocast
    dtype, then the normalization itself runs with autocast disabled so the
    kernel sees the already-cast operands.
    """

    def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
        super().__init__(
            normalized_shape=normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
            device=device,
            dtype=dtype,
        )

    def forward(self, x):
        inputs = _cast_if_autocast_enabled(x)
        weight = self.weight if self.weight is None else _cast_if_autocast_enabled(self.weight)
        bias = self.bias if self.bias is None else _cast_if_autocast_enabled(self.bias)
        with torch.autocast(enabled=False, device_type=x.device.type):
            return torch.nn.functional.layer_norm(inputs, self.normalized_shape, weight, bias, self.eps)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def rms_norm(x, weight=None, eps=1e-05):
    """Root-mean-square normalization over the last dimension.

    Scales ``x`` by ``1/sqrt(mean(x^2) + eps)``; when ``weight`` is given,
    the normalized result is multiplied element-wise by it.
    """
    inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    normalized = x * inv_rms
    if weight is None:
        return normalized
    return normalized * weight
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class RMSNorm(torch.nn.Module):
    """RMS normalization layer.

    Computation happens in float32 for stability and the result is cast back
    to the input dtype. ``weight=False`` disables the learnable scale.
    """

    def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
        super().__init__()
        self.eps = eps
        if not weight:
            self.register_parameter("weight", None)
        else:
            self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))

    def forward(self, x):
        normalized = rms_norm(x.float(), self.weight, self.eps)
        return normalized.to(dtype=x.dtype)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class LPRMSNorm(RMSNorm):
    """RMSNorm variant that downcasts operands under autocast.

    Input and weight are converted to the active autocast dtype, the norm is
    computed with autocast disabled, and the result is cast back to the
    original input dtype.
    """

    def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
        super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)

    def forward(self, x):
        low_x = _cast_if_autocast_enabled(x)
        low_w = self.weight if self.weight is None else _cast_if_autocast_enabled(self.weight)
        with torch.autocast(enabled=False, device_type=x.device.type):
            return rms_norm(low_x, low_w, self.eps).to(dtype=x.dtype)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# Maps the config's norm-type string to the norm-layer class; the
# "low_precision_*" variants downcast operands under autocast before
# normalizing (see LPLayerNorm / LPRMSNorm above).
NORM_CLASS_REGISTRY = {
    "layernorm": torch.nn.LayerNorm,
    "low_precision_layernorm": LPLayerNorm,
    "rmsnorm": RMSNorm,
    "low_precision_rmsnorm": LPRMSNorm,
}
|
VILA/llava/model/language_model/mpt/param_init_fns.py
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
import warnings
|
| 19 |
+
from collections.abc import Sequence
|
| 20 |
+
from functools import partial
|
| 21 |
+
from typing import Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
from torch import nn
|
| 25 |
+
|
| 26 |
+
from .norm import NORM_CLASS_REGISTRY
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def torch_default_param_init_fn_(module: nn.Module, verbose: int = 0, **kwargs):
    """Initialize *module* via its own ``reset_parameters`` method, if any.

    Extra keyword arguments are accepted (for registry-call compatibility)
    and discarded.
    """
    del kwargs
    if verbose > 1:
        warnings.warn("Initializing network using module's reset_parameters attribute")
    reset = getattr(module, "reset_parameters", None)
    if reset is not None:
        reset()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def fused_init_helper_(module: nn.Module, init_fn_):
    """Apply ``init_fn_`` independently to each fused slice of ``module.weight``.

    ``module._fused`` is ``(dim, split_points)``: the weight is partitioned
    along ``dim`` at the given points (plus the implicit 0 and full-size
    boundaries) and each slice is initialized separately.
    """
    fused_spec = getattr(module, "_fused", None)
    if fused_spec is None:
        raise RuntimeError("Internal logic error")
    (dim, splits) = fused_spec
    boundaries = (0, *splits, module.weight.size(dim))
    for start, stop in zip(boundaries[:-1], boundaries[1:]):
        selector = [slice(None)] * module.weight.ndim
        selector[dim] = slice(start, stop)
        init_fn_(module.weight[selector])
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def generic_param_init_fn_(
    module: nn.Module,
    init_fn_,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    """Initialize a single module's parameters with ``init_fn_``.

    Handles ``nn.Linear`` (with fused-slice support), ``nn.Embedding`` (with
    optional normal/uniform overrides), registered norm classes, and
    ``nn.MultiheadAttention``. Weights of layers flagged ``_is_residual`` are
    additionally divided by ``div_is_residual`` (GPT-2-style residual scaling).

    Args:
        module: the module to initialize (biases are zeroed when present).
        init_fn_: in-place initializer applied to weight tensors.
        n_layers: depth of the network; used for the default residual divisor.
        d_model: model width; required only for fused-QKV MultiheadAttention.
        init_div_is_residual: True → divide residual weights by sqrt(2*n_layers);
            a number/numeric string → use that divisor; False → no division.
        emb_init_std / emb_init_uniform_lim: optional embedding-init overrides.
        verbose: >1 emits explanatory warnings.

    Raises:
        ValueError: for an unrecognized ``init_div_is_residual`` or bad
            uniform limits.
        NotImplementedError: for module types with parameters this function
            does not know how to initialize.
    """
    del kwargs
    if verbose > 1:
        warnings.warn("If model has bias parameters they are initialized to 0.")
    # Resolve the divisor applied to "_is_residual" output projections.
    # (Cleanup: removed the original's no-op self-assignment of
    # init_div_is_residual and the dead `div_is_residual = 1.0` that
    # immediately preceded the raise below.)
    if init_div_is_residual is False:
        div_is_residual = 1.0
    elif init_div_is_residual is True:
        div_is_residual = math.sqrt(2 * n_layers)
    elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int):
        div_is_residual = init_div_is_residual
    elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
        div_is_residual = float(init_div_is_residual)
    else:
        raise ValueError(f"Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}")
    if init_div_is_residual is not False:
        if verbose > 1:
            warnings.warn(
                f"Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. "
                + f"Set `init_div_is_residual: false` in init config to disable this."
            )
    if isinstance(module, nn.Linear):
        if hasattr(module, "_fused"):
            # Fused projections (e.g. QKV) get per-slice initialization.
            fused_init_helper_(module, init_fn_)
        else:
            init_fn_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
        if init_div_is_residual is not False and getattr(module, "_is_residual", False):
            with torch.no_grad():
                module.weight.div_(div_is_residual)
    elif isinstance(module, nn.Embedding):
        if emb_init_std is not None:
            std = emb_init_std
            if std == 0:
                warnings.warn(f"Embedding layer initialized to 0.")
            emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
            if verbose > 1:
                warnings.warn(f"Embedding layer initialized using normal distribution with mean=0 and std={std!r}.")
        elif emb_init_uniform_lim is not None:
            lim = emb_init_uniform_lim
            if isinstance(lim, Sequence):
                if len(lim) > 2:
                    raise ValueError(f"Uniform init requires a min and a max limit. User input: {lim}.")
                if lim[0] == lim[1]:
                    warnings.warn(f"Embedding layer initialized to {lim[0]}.")
            else:
                # A scalar limit means symmetric bounds [-lim, lim].
                if lim == 0:
                    warnings.warn(f"Embedding layer initialized to 0.")
                lim = [-lim, lim]
            (a, b) = lim
            emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
            if verbose > 1:
                warnings.warn(f"Embedding layer initialized using uniform distribution in range {lim}.")
        else:
            emb_init_fn_ = init_fn_
        emb_init_fn_(module.weight)
    elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
        if verbose > 1:
            warnings.warn(f"Norm weights are set to 1. If norm layer has a bias it is initialized to 0.")
        if hasattr(module, "weight") and module.weight is not None:
            torch.nn.init.ones_(module.weight)
        if hasattr(module, "bias") and module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.MultiheadAttention):
        if module._qkv_same_embed_dim:
            # Fused QKV: initialize the three d_model-sized slices separately.
            assert module.in_proj_weight is not None
            assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None)
            assert d_model is not None
            _d = d_model
            splits = (0, _d, 2 * _d, 3 * _d)
            for (s, e) in zip(splits[:-1], splits[1:]):
                init_fn_(module.in_proj_weight[s:e])
        else:
            assert (
                module.q_proj_weight is not None
                and module.k_proj_weight is not None
                and (module.v_proj_weight is not None)
            )
            assert module.in_proj_weight is None
            init_fn_(module.q_proj_weight)
            init_fn_(module.k_proj_weight)
            init_fn_(module.v_proj_weight)
        if module.in_proj_bias is not None:
            torch.nn.init.zeros_(module.in_proj_bias)
        if module.bias_k is not None:
            torch.nn.init.zeros_(module.bias_k)
        if module.bias_v is not None:
            torch.nn.init.zeros_(module.bias_v)
        init_fn_(module.out_proj.weight)
        if init_div_is_residual is not False and getattr(module.out_proj, "_is_residual", False):
            with torch.no_grad():
                module.out_proj.weight.div_(div_is_residual)
        if module.out_proj.bias is not None:
            torch.nn.init.zeros_(module.out_proj.bias)
    else:
        # Any other module type must own no direct parameters; containers
        # recurse elsewhere, so only leaf params trigger this error.
        for _ in module.parameters(recurse=False):
            raise NotImplementedError(f"{module.__class__.__name__} parameters are not initialized by param_init_fn.")
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def _normal_init_(std, mean=0.0):
|
| 161 |
+
return partial(torch.nn.init.normal_, mean=mean, std=std)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def _normal_param_init_fn_(
    module: nn.Module,
    std: float,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    """Initialize *module* from Normal(0, std) via ``generic_param_init_fn_``."""
    del kwargs
    if verbose > 1:
        warnings.warn(f"Using torch.nn.init.normal_ init fn mean=0.0, std={std}")
    generic_param_init_fn_(
        module=module,
        init_fn_=_normal_init_(std=std),
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def baseline_param_init_fn_(
    module: nn.Module,
    init_std: float,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    """Baseline init: Normal(0, init_std), requiring an explicit std.

    Raises:
        ValueError: if ``init_std`` is None (the caller must configure it).
    """
    del kwargs
    if init_std is None:
        raise ValueError(
            "You must set model.init_config['init_std'] to a float value to use the default initialization scheme."
        )
    _normal_param_init_fn_(
        module=module,
        std=init_std,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def small_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: int,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    """Small-init scheme: normal initialization with std = sqrt(2 / (5 * d_model))."""
    del kwargs  # unrecognized init-config entries are intentionally ignored
    _normal_param_init_fn_(
        module=module,
        std=math.sqrt(2 / (5 * d_model)),  # std shrinks as the model widens
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def neox_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: int,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    """GPT-NeoX-20B initialization (section 2.3.1 of Black et al., 2022,
    "GPT-NeoX-20B: An Open-Source Autoregressive Language Model").

    Small-init with residual projections additionally divided by
    ``n_layers / sqrt(10)``; see
    https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
    and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
    """
    del kwargs  # unrecognized init-config entries are intentionally ignored
    residual_div = n_layers / math.sqrt(10)
    if verbose > 1:
        warnings.warn(f"setting init_div_is_residual to {residual_div}")
    forwarded = {
        "d_model": d_model,
        "n_layers": n_layers,
        "init_div_is_residual": residual_div,
        "emb_init_std": emb_init_std,
        "emb_init_uniform_lim": emb_init_uniform_lim,
        "verbose": verbose,
    }
    small_param_init_fn_(module=module, **forwarded)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def kaiming_uniform_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    fan_mode: str = "fan_in",
    init_nonlinearity: str = "leaky_relu",
    verbose: int = 0,
    **kwargs,
):
    """Initialize parameters with ``nn.init.kaiming_uniform_``.

    ``init_gain``, ``fan_mode``, and ``init_nonlinearity`` map directly onto the
    ``a``, ``mode``, and ``nonlinearity`` arguments of the torch initializer.
    """
    del kwargs  # unrecognized init-config entries are intentionally ignored
    if verbose > 1:
        warnings.warn(
            f"Using nn.init.kaiming_uniform_ init fn with parameters: "
            + f"a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}"
        )
    generic_param_init_fn_(
        module=module,
        init_fn_=partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity),
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def kaiming_normal_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    fan_mode: str = "fan_in",
    init_nonlinearity: str = "leaky_relu",
    verbose: int = 0,
    **kwargs,
):
    """Initialize parameters with ``torch.nn.init.kaiming_normal_``.

    ``init_gain``, ``fan_mode``, and ``init_nonlinearity`` map directly onto the
    ``a``, ``mode``, and ``nonlinearity`` arguments of the torch initializer.
    """
    del kwargs  # unrecognized init-config entries are intentionally ignored
    if verbose > 1:
        warnings.warn(
            f"Using nn.init.kaiming_normal_ init fn with parameters: "
            + f"a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}"
        )
    generic_param_init_fn_(
        module=module,
        init_fn_=partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity),
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def xavier_uniform_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    verbose: int = 0,
    **kwargs,
):
    """Initialize parameters with ``torch.nn.init.xavier_uniform_`` using ``init_gain``."""
    del kwargs  # unrecognized init-config entries are intentionally ignored
    if verbose > 1:
        warnings.warn(f"Using torch.nn.init.xavier_uniform_ init fn with parameters: " + f"gain={init_gain}")
    generic_param_init_fn_(
        module=module,
        init_fn_=partial(torch.nn.init.xavier_uniform_, gain=init_gain),
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def xavier_normal_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    verbose: int = 0,
    **kwargs,
):
    """Initialize parameters with ``torch.nn.init.xavier_normal_`` using ``init_gain``.

    Delegates the scheme-independent work (residual scaling, embedding overrides)
    to ``generic_param_init_fn_``.
    """
    # Discard unrecognized init-config entries, matching every sibling
    # initializer in this module (this was the only one that kept them around).
    del kwargs
    xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
    if verbose > 1:
        warnings.warn(f"Using torch.nn.init.xavier_normal_ init fn with parameters: " + f"gain={init_gain}")
    generic_param_init_fn_(
        module=module,
        init_fn_=xavier_normal_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
# Maps the init-config "name" string to its parameter-initialization function.
# (`torch_default_param_init_fn_` is defined earlier in this module, above the
# portion shown here.)
MODEL_INIT_REGISTRY = {
    "default_": torch_default_param_init_fn_,
    "baseline_": baseline_param_init_fn_,
    "kaiming_uniform_": kaiming_uniform_param_init_fn_,
    "kaiming_normal_": kaiming_normal_param_init_fn_,
    "neox_init_": neox_param_init_fn_,
    "small_init_": small_param_init_fn_,
    "xavier_uniform_": xavier_uniform_param_init_fn_,
    "xavier_normal_": xavier_normal_param_init_fn_,
}
|
VILA/llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc
ADDED
|
Binary file (1.51 kB). View file
|
|
|
VILA/llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc
ADDED
|
Binary file (1.41 kB). View file
|
|
|
VILA/llava/model/multimodal_encoder/__pycache__/image_processor.cpython-310.pyc
ADDED
|
Binary file (18.8 kB). View file
|
|
|
VILA/llava/model/multimodal_encoder/__pycache__/intern_encoder.cpython-310.pyc
ADDED
|
Binary file (2.71 kB). View file
|
|
|
VILA/llava/model/multimodal_encoder/__pycache__/radio_encoder.cpython-310.pyc
ADDED
|
Binary file (8.6 kB). View file
|
|
|
VILA/llava/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc
ADDED
|
Binary file (1.47 kB). View file
|
|
|
VILA/llava/model/multimodal_encoder/__pycache__/vision_encoder.cpython-310.pyc
ADDED
|
Binary file (6 kB). View file
|
|
|
VILA/llava/model/multimodal_encoder/__pycache__/visualize_features.cpython-310.pyc
ADDED
|
Binary file (9.31 kB). View file
|
|
|
VILA/llava/model/multimodal_encoder/builder.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
# This file is modified from https://github.com/haotian-liu/LLaVA/
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
|
| 21 |
+
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
|
| 22 |
+
|
| 23 |
+
from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
|
| 24 |
+
from .intern_encoder import InternVisionTower
|
| 25 |
+
from .radio_encoder import RADIOVisionTower
|
| 26 |
+
from .siglip_encoder import SiglipVisionTower, SiglipVisionTowerS2
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def build_vision_tower(model_name_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
    """Instantiate the vision tower whose architecture matches ``model_name_or_path``.

    When resuming (``config.resume_path``) from a local, non-RADIO checkpoint, the
    architecture name is read from the checkpoint's config; otherwise the raw
    model name/path is matched. As a side effect, ``config.mm_hidden_size`` is set
    from the instantiated tower.

    Returns:
        The vision tower model, or ``None`` when ``model_name_or_path`` is ``None``.

    Raises:
        ValueError: if the name matches none of the known tower families.
    """
    ## skip vision tower instantiation
    if model_name_or_path is None:
        return None

    vision_tower_arch = None
    if config.resume_path and "radio" not in model_name_or_path:
        assert os.path.exists(model_name_or_path), f"Resume vision tower path {model_name_or_path} does not exist!"
        vision_tower_cfg = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        vision_tower_arch = vision_tower_cfg.architectures[0].lower()
    vision_tower_name = vision_tower_arch if vision_tower_arch is not None else model_name_or_path

    use_s2 = getattr(config, "s2", False)

    # Normalize once so every architecture check below is case-insensitive;
    # previously only the "intern" branch lowercased the name, so mixed-case
    # paths could fall through to the error branch.
    name = vision_tower_name.lower()
    if "intern" in name:
        # Older configs may lack drop_path_rate; default to no drop path.
        drop_path_rate = getattr(config, "drop_path_rate", 0.0)
        vision_tower = InternVisionTower(model_name_or_path, config=config, drop_path_rate=drop_path_rate)
    elif "radio" in name:
        vision_tower = RADIOVisionTower(model_name_or_path, config)
    elif "clip" in name:
        if use_s2:
            vision_tower = CLIPVisionTowerS2(model_name_or_path, config)
        else:
            vision_tower = CLIPVisionTower(model_name_or_path, config)
    elif "siglip" in name:
        if use_s2:
            vision_tower = SiglipVisionTowerS2(model_name_or_path, config)
        else:
            vision_tower = SiglipVisionTower(model_name_or_path, config)
    else:
        raise ValueError(f"Unknown vision tower: {model_name_or_path}")

    # S2 towers expose a combined multi-scale hidden size on the wrapper itself.
    config.mm_hidden_size = vision_tower.config.hidden_size if not use_s2 else vision_tower.hidden_size
    return vision_tower
|
VILA/llava/model/multimodal_encoder/clip_encoder.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
# This file is modified from https://github.com/haotian-liu/LLaVA/
|
| 18 |
+
import torch
|
| 19 |
+
from transformers import CLIPImageProcessor, CLIPVisionModel, PretrainedConfig
|
| 20 |
+
|
| 21 |
+
from llava.model.multimodal_encoder.vision_encoder import VisionTower, VisionTowerS2
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class CLIPVisionTower(VisionTower):
    # Thin wrapper that loads a pretrained CLIP vision backbone together with
    # its matching image processor.
    def __init__(self, model_name_or_path: str, config: PretrainedConfig):
        super().__init__(model_name_or_path, config)
        self.image_processor = CLIPImageProcessor.from_pretrained(model_name_or_path)
        # NOTE(review): eval() assumes config.model_dtype is a trusted string
        # such as "torch.float16"; never feed it untrusted config data.
        self.vision_tower = CLIPVisionModel.from_pretrained(model_name_or_path, torch_dtype=eval(config.model_dtype))
        self.is_loaded = True
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class CLIPVisionTowerS2(VisionTowerS2):
    # Multi-scale (S2) variant: same CLIP backbone, but the image processor is
    # retuned so inputs are prepared at the largest scale in self.scales.
    def __init__(self, model_name_or_path: str, config: PretrainedConfig):
        super().__init__(model_name_or_path, config)
        self.image_processor = CLIPImageProcessor.from_pretrained(model_name_or_path)
        # NOTE(review): eval() assumes config.model_dtype is a trusted string
        # such as "torch.float16"; never feed it untrusted config data.
        self.vision_tower = CLIPVisionModel.from_pretrained(model_name_or_path, torch_dtype=eval(config.model_dtype))

        # Make sure it crops/resizes the image to the largest scale in self.scales to maintain high-res information
        self.image_processor.size["shortest_edge"] = self.scales[-1]
        self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.scales[-1]

        self.is_loaded = True
|
VILA/llava/model/multimodal_encoder/image_processor.py
ADDED
|
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Image processor class for RADIO."""
|
| 16 |
+
import math
|
| 17 |
+
from copy import deepcopy
|
| 18 |
+
from itertools import product
|
| 19 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import PIL
|
| 23 |
+
from PIL.Image import Image
|
| 24 |
+
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
| 25 |
+
from transformers.image_transforms import convert_to_rgb, pad, resize, to_channel_dimension_format
|
| 26 |
+
from transformers.image_utils import (
|
| 27 |
+
IMAGENET_DEFAULT_MEAN,
|
| 28 |
+
IMAGENET_DEFAULT_STD,
|
| 29 |
+
ChannelDimension,
|
| 30 |
+
ImageInput,
|
| 31 |
+
PILImageResampling,
|
| 32 |
+
get_image_size,
|
| 33 |
+
infer_channel_dimension_format,
|
| 34 |
+
is_scaled_image,
|
| 35 |
+
make_list_of_images,
|
| 36 |
+
to_numpy_array,
|
| 37 |
+
valid_images,
|
| 38 |
+
)
|
| 39 |
+
from transformers.utils import (
|
| 40 |
+
TensorType,
|
| 41 |
+
is_tf_available,
|
| 42 |
+
is_torch_available,
|
| 43 |
+
is_torchvision_available,
|
| 44 |
+
logging,
|
| 45 |
+
requires_backends,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
if is_torch_available():
|
| 49 |
+
import torch
|
| 50 |
+
import torch.nn.functional as F
|
| 51 |
+
|
| 52 |
+
if is_torchvision_available():
|
| 53 |
+
from torchvision.ops.boxes import batched_nms
|
| 54 |
+
|
| 55 |
+
# if is_tf_available():
|
| 56 |
+
# import tensorflow as tf
|
| 57 |
+
# from tensorflow.experimental import numpy as tnp
|
| 58 |
+
|
| 59 |
+
# from ...tf_utils import flatten, shape_list
|
| 60 |
+
|
| 61 |
+
logger = logging.get_logger(__name__)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def rank_print(s):
    """Print ``s`` prefixed with the current torch.distributed rank (0 when
    the process group is not initialized)."""
    if torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
    else:
        rank = 0
    print(f"[Rank {rank}] {s}")
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class ImageProcessor(BaseImageProcessor):
|
| 70 |
+
r"""
|
| 71 |
+
Constructs an image processor.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 75 |
+
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
|
| 76 |
+
`do_resize` parameter in the `preprocess` method.
|
| 77 |
+
size (`dict`, *optional*, defaults to `{"longest_edge": 1024}`):
|
| 78 |
+
Size of the output image after resizing. If "longest_edge" is specified, resizes the longest edge of the image to match
|
| 79 |
+
`size["longest_edge"]` while maintaining the aspect ratio. If "width" and "height" are specified, resizes the image
|
| 80 |
+
to that size, possibly changing the aspect ratio. Can be overridden by the `size` parameter in the
|
| 81 |
+
`preprocess` method.
|
| 82 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
| 83 |
+
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
|
| 84 |
+
`preprocess` method.
|
| 85 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
| 86 |
+
Wwhether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
|
| 87 |
+
`do_rescale` parameter in the `preprocess` method.
|
| 88 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
| 89 |
+
Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
|
| 90 |
+
overridden by the `rescale_factor` parameter in the `preprocess` method.
|
| 91 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
| 92 |
+
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
|
| 93 |
+
method. Can be overridden by the `do_normalize` parameter in the `preprocess` method.
|
| 94 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
|
| 95 |
+
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
| 96 |
+
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be
|
| 97 |
+
overridden by the `image_mean` parameter in the `preprocess` method.
|
| 98 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
|
| 99 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
| 100 |
+
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
| 101 |
+
Can be overridden by the `image_std` parameter in the `preprocess` method.
|
| 102 |
+
do_pad (`bool`, *optional*, defaults to `True`):
|
| 103 |
+
Whether to pad the image to the specified `pad_size`. Can be overridden by the `do_pad` parameter in the
|
| 104 |
+
`preprocess` method.
|
| 105 |
+
pad_size (`dict`, *optional*, defaults to `{"height": 1024, "width": 1024}`):
|
| 106 |
+
Size of the output image after padding. Can be overridden by the `pad_size` parameter in the `preprocess`
|
| 107 |
+
method.
|
| 108 |
+
pad_value (`float` or `Iterable[float]`, *optional*, defaults to `0.`):
|
| 109 |
+
Value of padded pixels.
|
| 110 |
+
pad_multiple (`int`, *optional*, defaults to `None`):
|
| 111 |
+
Pad to a multiple of specified number.
|
| 112 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
| 113 |
+
Whether to convert the image to RGB.
|
| 114 |
+
"""
|
| 115 |
+
|
| 116 |
+
model_input_names = ["pixel_values"]
|
| 117 |
+
|
| 118 |
+
def __init__(
|
| 119 |
+
self,
|
| 120 |
+
do_resize: bool = True,
|
| 121 |
+
size: Dict[str, int] = None,
|
| 122 |
+
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
| 123 |
+
do_rescale: bool = True,
|
| 124 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
| 125 |
+
do_normalize: bool = True,
|
| 126 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 127 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 128 |
+
do_pad: bool = True,
|
| 129 |
+
pad_size: int = None,
|
| 130 |
+
pad_multiple: int = None,
|
| 131 |
+
pad_value: Optional[Union[float, List[float]]] = 0.0,
|
| 132 |
+
do_convert_rgb: bool = True,
|
| 133 |
+
**kwargs,
|
| 134 |
+
) -> None:
|
| 135 |
+
super().__init__(**kwargs)
|
| 136 |
+
x = 0
|
| 137 |
+
size = size if size is not None else {"longest_edge": 1024}
|
| 138 |
+
size = get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size
|
| 139 |
+
|
| 140 |
+
if pad_size is not None and pad_multiple is not None:
|
| 141 |
+
raise ValueError("pad_size and pad_multiple should not be set at the same time.")
|
| 142 |
+
|
| 143 |
+
pad_size = (
|
| 144 |
+
pad_size if pad_size is not None else {"height": 1024, "width": 1024} if pad_multiple is not None else None
|
| 145 |
+
)
|
| 146 |
+
if do_pad:
|
| 147 |
+
pad_size = get_size_dict(pad_size, default_to_square=True)
|
| 148 |
+
|
| 149 |
+
self.do_resize = do_resize
|
| 150 |
+
self.size = size
|
| 151 |
+
self.resample = resample
|
| 152 |
+
self.do_rescale = do_rescale
|
| 153 |
+
self.rescale_factor = rescale_factor
|
| 154 |
+
self.do_normalize = do_normalize
|
| 155 |
+
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
|
| 156 |
+
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
|
| 157 |
+
self.do_pad = do_pad
|
| 158 |
+
self.pad_multiple = pad_multiple
|
| 159 |
+
self.pad_size = pad_size
|
| 160 |
+
self.pad_value = tuple(pad_value) if isinstance(pad_value, list) else pad_value
|
| 161 |
+
self.do_convert_rgb = do_convert_rgb
|
| 162 |
+
self._valid_processor_keys = [
|
| 163 |
+
"images",
|
| 164 |
+
"segmentation_maps",
|
| 165 |
+
"do_resize",
|
| 166 |
+
"size",
|
| 167 |
+
"resample",
|
| 168 |
+
"do_rescale",
|
| 169 |
+
"rescale_factor",
|
| 170 |
+
"do_normalize",
|
| 171 |
+
"image_mean",
|
| 172 |
+
"image_std",
|
| 173 |
+
"do_pad",
|
| 174 |
+
"pad_size",
|
| 175 |
+
"do_convert_rgb",
|
| 176 |
+
"return_tensors",
|
| 177 |
+
"data_format",
|
| 178 |
+
"input_data_format",
|
| 179 |
+
]
|
| 180 |
+
|
| 181 |
+
def pad_image(
|
| 182 |
+
self,
|
| 183 |
+
image: np.ndarray,
|
| 184 |
+
pad_size: Dict[str, int],
|
| 185 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 186 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 187 |
+
**kwargs,
|
| 188 |
+
) -> np.ndarray:
|
| 189 |
+
"""
|
| 190 |
+
Pad an image to `(pad_size["height"], pad_size["width"])` to the right and bottom.
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
image (`np.ndarray`):
|
| 194 |
+
Image to pad.
|
| 195 |
+
pad_size (`Dict[str, int]`):
|
| 196 |
+
Size of the output image after padding.
|
| 197 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
| 198 |
+
The data format of the image. Can be either "channels_first" or "channels_last". If `None`, the
|
| 199 |
+
`data_format` of the `image` will be used.
|
| 200 |
+
input_data_format (`str` or `ChannelDimension`, *optional*):
|
| 201 |
+
The channel dimension format of the input image. If not provided, it will be inferred.
|
| 202 |
+
"""
|
| 203 |
+
output_height, output_width = pad_size["height"], pad_size["width"]
|
| 204 |
+
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
| 205 |
+
|
| 206 |
+
pad_width = output_width - input_width
|
| 207 |
+
pad_height = output_height - input_height
|
| 208 |
+
|
| 209 |
+
padded_image = pad(
|
| 210 |
+
image,
|
| 211 |
+
((0, pad_height), (0, pad_width)),
|
| 212 |
+
data_format=data_format,
|
| 213 |
+
input_data_format=input_data_format,
|
| 214 |
+
constant_values=self.pad_value,
|
| 215 |
+
**kwargs,
|
| 216 |
+
)
|
| 217 |
+
return padded_image
|
| 218 |
+
|
| 219 |
+
def _get_preprocess_shape(self, old_shape: Tuple[int, int], longest_edge: int):
|
| 220 |
+
"""
|
| 221 |
+
Compute the output size given input size and target long side length.
|
| 222 |
+
"""
|
| 223 |
+
oldh, oldw = old_shape
|
| 224 |
+
scale = longest_edge * 1.0 / max(oldh, oldw)
|
| 225 |
+
newh, neww = oldh * scale, oldw * scale
|
| 226 |
+
newh = int(newh + 0.5)
|
| 227 |
+
neww = int(neww + 0.5)
|
| 228 |
+
return (newh, neww)
|
| 229 |
+
|
| 230 |
+
def resize(
|
| 231 |
+
self,
|
| 232 |
+
image: np.ndarray,
|
| 233 |
+
size: Dict[str, int],
|
| 234 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
| 235 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 236 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 237 |
+
**kwargs,
|
| 238 |
+
) -> np.ndarray:
|
| 239 |
+
"""
|
| 240 |
+
Resize an image to `(size["height"], size["width"])`.
|
| 241 |
+
|
| 242 |
+
Args:
|
| 243 |
+
image (`np.ndarray`):
|
| 244 |
+
Image to resize.
|
| 245 |
+
size (`Dict[str, int]`):
|
| 246 |
+
Dictionary in the format `{"longest_edge": int}` or `{"width": int, "height": int}` specifying the size
|
| 247 |
+
of the output image. If "longest_edge" is specified, resizes the longest edge of the image to match
|
| 248 |
+
`size["longest_edge"]` while maintaining the aspect ratio. If "width" and "height" are specified, resizes the image
|
| 249 |
+
to that size, possibly changing the aspect ratio.
|
| 250 |
+
resample:
|
| 251 |
+
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
|
| 252 |
+
data_format (`ChannelDimension` or `str`, *optional*):
|
| 253 |
+
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
| 254 |
+
image is used. Can be one of:
|
| 255 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 256 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 257 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 258 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 259 |
+
from the input image. Can be one of:
|
| 260 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 261 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 262 |
+
|
| 263 |
+
Returns:
|
| 264 |
+
`np.ndarray`: The resized image.
|
| 265 |
+
"""
|
| 266 |
+
size = get_size_dict(size)
|
| 267 |
+
if "longest_edge" not in size:
|
| 268 |
+
if "width" not in size or "height" not in size:
|
| 269 |
+
raise ValueError(
|
| 270 |
+
f"The `size` dictionary must contain the key `longest_edge`, or `width` and `height`. Got {size.keys()}"
|
| 271 |
+
)
|
| 272 |
+
input_size = get_image_size(image, channel_dim=input_data_format)
|
| 273 |
+
if "longest_edge" in size:
|
| 274 |
+
output_height, output_width = self._get_preprocess_shape(input_size, size["longest_edge"])
|
| 275 |
+
else:
|
| 276 |
+
output_height, output_width = size["height"], size["width"]
|
| 277 |
+
return resize(
|
| 278 |
+
image,
|
| 279 |
+
size=(output_height, output_width),
|
| 280 |
+
resample=resample,
|
| 281 |
+
data_format=data_format,
|
| 282 |
+
input_data_format=input_data_format,
|
| 283 |
+
**kwargs,
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
def _preprocess(
    self,
    image: ImageInput,
    do_resize: bool,
    do_rescale: bool,
    do_normalize: bool,
    size: Optional[Dict[str, int]] = None,
    resample: PILImageResampling = None,
    rescale_factor: Optional[float] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    do_pad: Optional[bool] = None,
    pad_size: Optional[Dict[str, int]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
    """Apply the resize -> rescale -> normalize -> pad steps to a single image.

    Returns:
        Tuple of the transformed image and its (height, width) measured right
        after the resize step (i.e. before any padding).
    """
    if do_resize:
        image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
    # Recorded after resizing but before padding, so callers can undo the pad.
    reshaped_input_size = get_image_size(image, channel_dim=input_data_format)

    if do_rescale:
        image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
    if do_normalize:
        image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)

    if do_pad:
        if self.pad_multiple:
            # Pad up to the next multiple of `pad_multiple` rather than to a
            # fixed target size.
            height, width = get_image_size(image, channel_dim=input_data_format)
            multiple = self.pad_multiple
            pad_size = {
                "height": math.ceil(height / multiple) * multiple,
                "width": math.ceil(width / multiple) * multiple,
            }
        image = self.pad_image(image=image, pad_size=pad_size, input_data_format=input_data_format)

    return image, reshaped_input_size
|
| 322 |
+
|
| 323 |
+
def _preprocess_image(
    self,
    image: ImageInput,
    do_resize: Optional[bool] = None,
    size: Dict[str, int] = None,
    resample: PILImageResampling = None,
    do_rescale: bool = None,
    rescale_factor: Optional[float] = None,
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    do_pad: Optional[bool] = None,
    pad_size: Optional[Dict[str, int]] = None,
    do_convert_rgb: Optional[bool] = None,
    data_format: Optional[Union[str, ChannelDimension]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]:
    """Run the full preprocessing pipeline on a single image.

    Converts to RGB (optionally), promotes grayscale inputs to 3 channels,
    then resizes/rescales/normalizes/pads via `_preprocess`.

    Returns:
        Tuple of (processed `np.ndarray`, original (height, width),
        (height, width) after resizing).
    """
    if isinstance(image, Image):
        # PIL always uses Channels Last.
        input_data_format = ChannelDimension.LAST

        # PIL RGBA images are converted to RGB.
        if do_convert_rgb:
            image = convert_to_rgb(image)

    # All transformations expect numpy arrays.
    image = to_numpy_array(image)

    # Promote grayscale inputs (HxW or HxWx1) to 3 identical channels.
    if len(image.shape) == 2:
        h, w = image.shape
        ret = np.empty((h, w, 3), dtype=np.uint8)
        ret[:, :, 0] = image
        ret[:, :, 1] = image
        ret[:, :, 2] = image
        image = ret
        rank_print(f"preprocess new image shape={image.shape}")
    elif len(image.shape) == 3 and image.shape[-1] == 1:
        # BUGFIX: `h` and `w` were previously undefined in this branch (they
        # were only assigned in the 2-D case above), raising NameError for
        # HxWx1 inputs. Derive them from the image itself.
        h, w = image.shape[:2]
        ret = np.empty((h, w, 3), dtype=np.uint8)
        ret[:, :, 0] = image[:, :, 0]
        ret[:, :, 1] = image[:, :, 0]
        ret[:, :, 2] = image[:, :, 0]
        image = ret
        rank_print(f"preprocess new image shape={image.shape}")

    if is_scaled_image(image) and do_rescale:
        logger.warning_once(
            "It looks like you are trying to rescale already rescaled images. If the input"
            " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
        )

    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)

    original_size = get_image_size(image, channel_dim=input_data_format)

    image, reshaped_input_size = self._preprocess(
        image=image,
        do_resize=do_resize,
        size=size,
        resample=resample,
        do_rescale=do_rescale,
        rescale_factor=rescale_factor,
        do_normalize=do_normalize,
        image_mean=image_mean,
        image_std=image_std,
        do_pad=do_pad,
        pad_size=pad_size,
        input_data_format=input_data_format,
    )

    if data_format is not None:
        image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)

    # If the (channels-first) result is still single channel, replicate to RGB.
    if do_convert_rgb and image.shape[0] == 1:
        c, h, w = image.shape
        ret = np.empty((3, h, w), dtype=np.uint8)
        ret[0, :, :] = image[0, :, :]
        ret[1, :, :] = image[0, :, :]
        ret[2, :, :] = image[0, :, :]
        image = ret
        rank_print(f"preprocess final: {image.shape}")

    return image, original_size, reshaped_input_size
|
| 425 |
+
|
| 426 |
+
def preprocess(
    self,
    images: ImageInput,
    do_resize: Optional[bool] = None,
    size: Optional[Dict[str, int]] = None,
    resample: Optional["PILImageResampling"] = None,
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[Union[int, float]] = None,
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    do_pad: Optional[bool] = None,
    pad_size: Optional[Dict[str, int]] = None,
    do_convert_rgb: Optional[bool] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    data_format: ChannelDimension = ChannelDimension.FIRST,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
    **kwargs,
):
    """
    Preprocess an image or batch of images.

    Any argument left as `None` falls back to the corresponding instance
    attribute (`self.do_resize`, `self.size`, ...).

    Args:
        images (`ImageInput`):
            Image(s) with pixel values in [0, 255]. For inputs already scaled
            to [0, 1], pass `do_rescale=False`.
        do_resize (`bool`, *optional*, defaults to `self.do_resize`):
            Whether to resize the image.
        size (`Dict[str, int]`, *optional*, defaults to `self.size`):
            Target size after `resize`; with `size["longest_edge"]` the longest
            edge is resized while preserving the aspect ratio.
        resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
            Resampling filter used for resizing, e.g. `PILImageResampling.BILINEAR`.
        do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
            Whether to rescale pixel values by `rescale_factor`.
        rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
            Rescale factor applied to the pixel values.
        do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
            Whether to normalize the image.
        image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
            Mean used for normalization when `do_normalize` is `True`.
        image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
            Standard deviation used for normalization when `do_normalize` is `True`.
        do_pad (`bool`, *optional*, defaults to `self.do_pad`):
            Whether to pad the image.
        pad_size (`Dict[str, int]`, *optional*, defaults to `self.pad_size`):
            Target padded size (`height`/`width`) when `do_pad` is `True`.
        do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
            Whether to convert the image to RGB.
        return_tensors (`str` or `TensorType`, *optional*):
            Output tensor type: unset for `np.ndarray` lists, `'tf'`, `'pt'`,
            `'np'` or `'jax'` for framework batches.
        data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
            Channel dimension format of the output image.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            Channel dimension format of the input image; inferred when unset.

    Returns:
        `BatchFeature` with `pixel_values`, `original_sizes` and
        `reshaped_input_sizes`.
    """
    # Resolve per-call overrides against the instance defaults.
    do_resize = self.do_resize if do_resize is None else do_resize
    size = self.size if size is None else size
    if not isinstance(size, dict):
        size = get_size_dict(max_size=size, default_to_square=False)
    resample = self.resample if resample is None else resample
    do_rescale = self.do_rescale if do_rescale is None else do_rescale
    rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor
    do_normalize = self.do_normalize if do_normalize is None else do_normalize
    image_mean = self.image_mean if image_mean is None else image_mean
    image_std = self.image_std if image_std is None else image_std
    do_pad = self.do_pad if do_pad is None else do_pad
    pad_size = self.pad_size if pad_size is None else pad_size
    if do_pad:
        pad_size = get_size_dict(pad_size, default_to_square=True)
    do_convert_rgb = self.do_convert_rgb if do_convert_rgb is None else do_convert_rgb

    images = make_list_of_images(images)

    if not valid_images(images):
        raise ValueError(
            "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
            "torch.Tensor, tf.Tensor or jax.ndarray."
        )

    # Arguments shared by every per-image call.
    per_image_kwargs = dict(
        do_resize=do_resize,
        size=size,
        resample=resample,
        do_rescale=do_rescale,
        rescale_factor=rescale_factor,
        do_normalize=do_normalize,
        image_mean=image_mean,
        image_std=image_std,
        do_pad=do_pad,
        pad_size=pad_size,
        do_convert_rgb=do_convert_rgb,
        data_format=data_format,
        input_data_format=input_data_format,
    )
    images, original_sizes, reshaped_input_sizes = zip(
        *[self._preprocess_image(image=img, **per_image_kwargs) for img in images]
    )

    data = {
        "pixel_values": images,
        "original_sizes": original_sizes,
        "reshaped_input_sizes": reshaped_input_sizes,
    }

    return BatchFeature(data=data, tensor_type=return_tensors)
|
VILA/llava/model/multimodal_encoder/intern/__pycache__/configuration_intern_vit.cpython-310.pyc
ADDED
|
Binary file (4.98 kB). View file
|
|
|
VILA/llava/model/multimodal_encoder/intern/__pycache__/flash_attention.cpython-310.pyc
ADDED
|
Binary file (2.72 kB). View file
|
|
|
VILA/llava/model/multimodal_encoder/intern/__pycache__/modeling_intern_vit.cpython-310.pyc
ADDED
|
Binary file (18 kB). View file
|
|
|
VILA/llava/model/multimodal_encoder/intern/configuration_intern_vit.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# InternVL
|
| 3 |
+
# Copyright (c) 2023 OpenGVLab
|
| 4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 5 |
+
# --------------------------------------------------------
|
| 6 |
+
import os
|
| 7 |
+
from typing import Union
|
| 8 |
+
|
| 9 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 10 |
+
from transformers.utils import logging
|
| 11 |
+
|
| 12 |
+
logger = logging.get_logger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class InternVisionConfig(PretrainedConfig):
    r"""
    Configuration class to store the configuration of an [`InternVisionModel`].
    It is used to instantiate a vision encoder according to the specified
    arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        num_channels (`int`, *optional*, defaults to 3):
            Number of color channels in the input images (e.g., 3 for RGB).
        patch_size (`int`, *optional*, defaults to 14):
            The size (resolution) of each patch.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        qkv_bias (`bool`, *optional*, defaults to `False`):
            Whether to add a bias to the queries and values in the self-attention layers.
        hidden_size (`int`, *optional*, defaults to 3200):
            Dimensionality of the encoder layers and the pooler layer.
        num_attention_heads (`int`, *optional*, defaults to 25):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 12800):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        qk_normalization (`bool`, *optional*, defaults to `True`):
            Whether to normalize the queries and keys in the self-attention layers.
        num_hidden_layers (`int`, *optional*, defaults to 48):
            Number of hidden layers in the Transformer encoder.
        use_flash_attn (`bool`, *optional*, defaults to `True`):
            Whether to use the flash attention mechanism.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-6):
            The epsilon used by the layer normalization layers.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        drop_path_rate (`float`, *optional*, defaults to 0.0):
            Dropout rate for stochastic depth.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_factor (`float`, *optional*, defaults to 0.1):
            A factor for layer scale.
    """

    model_type = "intern_vit_6b"

    def __init__(
        self,
        num_channels=3,
        patch_size=14,
        image_size=224,
        qkv_bias=False,
        hidden_size=3200,
        num_attention_heads=25,
        intermediate_size=12800,
        qk_normalization=True,
        num_hidden_layers=48,
        use_flash_attn=True,
        hidden_act="gelu",
        layer_norm_eps=1e-6,
        dropout=0.0,
        drop_path_rate=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Architecture geometry.
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # Attention behaviour.
        self.qkv_bias = qkv_bias
        self.qk_normalization = qk_normalization
        self.use_flash_attn = use_flash_attn
        self.attention_dropout = attention_dropout
        # Regularization and initialization.
        self.dropout = dropout
        self.drop_path_rate = drop_path_rate
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        # Activation / normalization.
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load the config, unwrapping a nested `vision_config` key when the
        checkpoint stores a composite (vision + language) configuration."""
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        config_dict = config_dict.get("vision_config", config_dict)

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
|
VILA/llava/model/multimodal_encoder/intern/flash_attention.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
# https://github.com/Dao-AILab/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
from einops import rearrange
|
| 21 |
+
|
| 22 |
+
try: # v1
|
| 23 |
+
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
| 24 |
+
except: # v2
|
| 25 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
|
| 26 |
+
|
| 27 |
+
from flash_attn.bert_padding import pad_input, unpad_input
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class FlashAttention(nn.Module):
    """Scaled dot-product attention with softmax, backed by flash-attention.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
            (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention
            (default: 0.0)
    """

    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
        super().__init__()
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None, max_s=None, need_weights=False):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                (B, S, 3, H, D) if key_padding_mask is None; if already
                unpadded (cu_seqlens given): (nnz, 3, h, d)
            key_padding_mask: a bool tensor of shape (B, S)
        """
        assert not need_weights
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda

        # Dropout is only active during training.
        dropout = self.dropout_p if self.training else 0.0

        if cu_seqlens is not None:
            # Caller already unpadded qkv to (nnz, 3, h, d).
            assert max_s is not None
            output = flash_attn_unpadded_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_s,
                dropout,
                softmax_scale=self.softmax_scale,
                causal=causal,
            )
            return output, None

        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        if key_padding_mask is None:
            # No padding: flatten the batch and build trivial cu_seqlens.
            qkv = rearrange(qkv, "b s ... -> (b s) ...")
            max_s = seqlen
            cu_seqlens = torch.arange(
                0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device
            )
            output = flash_attn_unpadded_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_s,
                dropout,
                softmax_scale=self.softmax_scale,
                causal=causal,
            )
            output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
        else:
            # Drop padded positions before the kernel, re-pad afterwards.
            nheads = qkv.shape[-2]
            x = rearrange(qkv, "b s three h d -> b s (three h d)")
            x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
            x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
            output_unpad = flash_attn_unpadded_qkvpacked_func(
                x_unpad,
                cu_seqlens,
                max_s,
                dropout,
                softmax_scale=self.softmax_scale,
                causal=causal,
            )
            output = rearrange(
                pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen),
                "b s (h d) -> b s h d",
                h=nheads,
            )

        return output, None
|
VILA/llava/model/multimodal_encoder/intern/modeling_intern_vit.py
ADDED
|
@@ -0,0 +1,543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# InternVL
|
| 3 |
+
# Copyright (c) 2023 OpenGVLab
|
| 4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 5 |
+
# --------------------------------------------------------
|
| 6 |
+
from typing import Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import torch.utils.checkpoint
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
from torch import nn
|
| 13 |
+
from transformers.activations import ACT2FN
|
| 14 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
| 15 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 16 |
+
from transformers.utils import logging
|
| 17 |
+
|
| 18 |
+
from llava.model.multimodal_encoder.intern.configuration_intern_vit import InternVisionConfig
|
| 19 |
+
|
| 20 |
+
from .flash_attention import FlashAttention
|
| 21 |
+
|
| 22 |
+
has_flash_attn = True
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.get_logger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
""" DropBlock, DropPath
|
| 29 |
+
|
| 30 |
+
PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
|
| 31 |
+
|
| 32 |
+
Papers:
|
| 33 |
+
DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
|
| 34 |
+
|
| 35 |
+
Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
|
| 36 |
+
|
| 37 |
+
Code:
|
| 38 |
+
DropBlock impl inspired by two Tensorflow impl that I liked:
|
| 39 |
+
- https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
|
| 40 |
+
- https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
|
| 41 |
+
|
| 42 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
| 43 |
+
"""
|
| 44 |
+
import torch
|
| 45 |
+
import torch.nn as nn
|
| 46 |
+
import torch.nn.functional as F
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
    """Generate an N-D grid in dimension order ('ij' indexing).

    The ndgrid function is like meshgrid except that the order of the first two input arguments are switched.

    That is, the statement
    [X1,X2,X3] = ndgrid(x1,x2,x3)

    produces the same result as

    [X2,X1,X3] = meshgrid(x2,x1,x3)

    This naming is based on MATLAB, the purpose is to avoid confusion due to torch's change to make
    torch.meshgrid behaviour move from matching ndgrid ('ij') indexing to numpy meshgrid defaults of ('xy').
    """
    try:
        grids = torch.meshgrid(*tensors, indexing="ij")
    except TypeError:
        # PyTorch < 1.10 has no `indexing` argument, but its meshgrid already
        # used 'ij' ordering, so the plain call is equivalent.
        grids = torch.meshgrid(*tensors)
    return grids
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def drop_block_2d(
    x,
    drop_prob: float = 0.1,
    block_size: int = 7,
    gamma_scale: float = 1.0,
    with_noise: bool = False,
    inplace: bool = False,
    batchwise: bool = False,
):
    """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf

    DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
    runs with success, but needs further validation and possibly optimization for lower runtime impact.

    Args:
        x: input feature map of shape (B, C, H, W).
        drop_prob: target fraction of activations to drop.
        block_size: side length of the square region dropped around each seed.
        gamma_scale: multiplier applied to the computed seed rate gamma.
        with_noise: fill dropped blocks with gaussian noise instead of zeros.
        inplace: mutate `x` in place rather than allocating a new tensor.
        batchwise: share a single drop mask across the whole batch (faster).

    Returns:
        Tensor of the same shape as `x` with contiguous blocks dropped.
    """
    B, C, H, W = x.shape
    total_size = W * H
    # Blocks cannot be larger than the feature map itself.
    clipped_block_size = min(block_size, min(W, H))
    # seed_drop_rate, the gamma parameter
    gamma = (
        gamma_scale * drop_prob * total_size / clipped_block_size**2 / ((W - block_size + 1) * (H - block_size + 1))
    )

    # Forces the block to be inside the feature map.
    w_i, h_i = ndgrid(torch.arange(W, device=x.device), torch.arange(H, device=x.device))
    valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & (
        (h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)
    )
    valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)

    if batchwise:
        # one mask for whole batch, quite a bit faster
        uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
    else:
        uniform_noise = torch.rand_like(x)
    # 1 where the cell is KEPT: a seed fires where noise < gamma inside the
    # valid region; the expression folds seed + validity into one threshold.
    block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
    # Max-pooling the negated mask grows each zero seed into a full
    # clipped_block_size x clipped_block_size block of zeros (min-pool trick).
    block_mask = -F.max_pool2d(
        -block_mask, kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2  # block_size,
    )

    if with_noise:
        normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
        if inplace:
            x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
        else:
            x = x * block_mask + normal_noise * (1 - block_mask)
    else:
        # Rescale surviving activations so the expected activation sum is
        # preserved (epsilon guards against a fully-dropped map).
        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
        if inplace:
            x.mul_(block_mask * normalize_scale)
        else:
            x = x * block_mask * normalize_scale
    return x
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def drop_block_fast_2d(
    x: torch.Tensor,
    drop_prob: float = 0.1,
    block_size: int = 7,
    gamma_scale: float = 1.0,
    with_noise: bool = False,
    inplace: bool = False,
):
    """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf

    DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid
    block mask at edges.

    Args:
        x: input feature map of shape (B, C, H, W).
        drop_prob: target fraction of activations to drop.
        block_size: side length of the square region dropped around each seed.
        gamma_scale: multiplier applied to the computed seed rate gamma.
        with_noise: fill dropped blocks with gaussian noise instead of zeros.
        inplace: mutate `x` in place rather than allocating a new tensor.

    Returns:
        Tensor of the same shape as `x` with contiguous blocks dropped.
    """
    B, C, H, W = x.shape
    total_size = W * H
    clipped_block_size = min(block_size, min(W, H))
    # Seed rate gamma; unlike drop_block_2d, no border/validity correction.
    gamma = (
        gamma_scale * drop_prob * total_size / clipped_block_size**2 / ((W - block_size + 1) * (H - block_size + 1))
    )

    # Bernoulli seeds, then max-pool to expand each seed into a full block.
    # NOTE: here block_mask marks DROPPED cells (1 = drop), the opposite
    # convention of drop_block_2d.
    block_mask = torch.empty_like(x).bernoulli_(gamma)
    block_mask = F.max_pool2d(
        block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2
    )

    if with_noise:
        normal_noise = torch.empty_like(x).normal_()
        if inplace:
            x.mul_(1.0 - block_mask).add_(normal_noise * block_mask)
        else:
            x = x * (1.0 - block_mask) + normal_noise * block_mask
    else:
        # Invert to the "1 = keep" convention before normalizing.
        block_mask = 1 - block_mask
        # Rescale kept activations to preserve the expected activation sum.
        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype)
        if inplace:
            x.mul_(block_mask * normalize_scale)
        else:
            x = x * block_mask * normalize_scale
    return x
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class DropBlock2d(nn.Module):
    """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf

    Module wrapper that dispatches to either the exact (`drop_block_2d`) or
    the simplified fast (`drop_block_fast_2d`) functional implementation.
    Acts as the identity when not training or when drop_prob is zero.
    """

    def __init__(
        self,
        drop_prob: float = 0.1,
        block_size: int = 7,
        gamma_scale: float = 1.0,
        with_noise: bool = False,
        inplace: bool = False,
        batchwise: bool = False,
        fast: bool = True,
    ):
        super().__init__()
        self.drop_prob = drop_prob
        self.gamma_scale = gamma_scale
        self.block_size = block_size
        self.with_noise = with_noise
        self.inplace = inplace
        self.batchwise = batchwise
        self.fast = fast  # FIXME finish comparisons of fast vs not

    def forward(self, x):
        # Identity at eval time or when dropping is disabled.
        if not self.training or not self.drop_prob:
            return x
        if not self.fast:
            # Exact variant also honors the batchwise-mask option.
            return drop_block_2d(
                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise
            )
        return drop_block_fast_2d(x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    """
    if not training or drop_prob == 0.0:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over every non-batch dim.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    # Rescale survivors so the expected output magnitude is unchanged.
    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Thin module wrapper over the functional `drop_path`; active only while
    `self.training` is True.
    """

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # Delegate to the functional form, passing the module's training flag.
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        # Shown in repr(model) so layer drop rates are visible at a glance.
        return f"drop_prob={round(self.drop_prob,3):0.3f}"
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class InternRMSNorm(nn.Module):
    """Root-mean-square LayerNorm (no mean centering, no bias).

    The statistic is accumulated in float32 for numerical stability, then the
    result is cast back to the input dtype and scaled by a learned weight.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        # Compute the mean-square over the last dim in fp32.
        h = hidden_states.to(torch.float32)
        var = h.pow(2).mean(-1, keepdim=True)
        h = h * torch.rsqrt(var + self.variance_epsilon)
        return self.weight * h.to(orig_dtype)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
# Prefer apex's fused CUDA RMSNorm kernel when available; otherwise keep the
# pure-PyTorch InternRMSNorm defined above.
try:
    from apex.normalization import FusedRMSNorm

    InternRMSNorm = FusedRMSNorm  # noqa

    logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm")
except ImportError:
    # using the normal InternRMSNorm
    pass
except Exception:
    # apex is installed but failed to import (e.g. CUDA/ABI mismatch); keep
    # the pure-PyTorch implementation and warn.
    logger.warning("discovered apex but it failed to load, falling back to InternRMSNorm")
    pass
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
class InternVisionEmbeddings(nn.Module):
    """Patchify images with a strided Conv2d, prepend a learnable [CLS] token,
    and add learned absolute position embeddings."""

    def __init__(self, config: InternVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Learnable [CLS] token prepended to the patch sequence.
        self.class_embedding = nn.Parameter(
            torch.randn(1, 1, self.embed_dim),
        )

        # Non-overlapping patch projection: kernel == stride == patch_size.
        self.patch_embedding = nn.Conv2d(
            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1  # +1 for the [CLS] slot

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            pixel_values: (batch, 3, H, W) image tensor.

        Returns:
            (batch, num_patches + 1, embed_dim) embeddings, [CLS] first.
        """
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        # Fix: cast inputs to the conv weight dtype so fp16/bf16 checkpoints
        # accept fp32 pixel batches (previously only the [CLS] and position
        # terms were cast, leaving the conv to fail on a dtype mismatch).
        patch_embeds = self.patch_embedding(pixel_values.to(target_dtype))  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
        class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + self.position_embedding.to(target_dtype)
        return embeddings
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
class InternAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: InternVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        # Flash attention is used only when both requested via config and
        # actually importable (has_flash_attn is a module-level flag).
        self.use_flash_attn = config.use_flash_attn and has_flash_attn
        if config.use_flash_attn and not has_flash_attn:
            print("Warning: Flash Attention is not available, use_flash_attn is set to False.")
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        # Standard 1/sqrt(d) attention scaling.
        self.scale = self.head_dim**-0.5
        # Fused projection producing Q, K and V in a single matmul.
        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
        self.attn_drop = nn.Dropout(config.attention_dropout)
        self.proj_drop = nn.Dropout(config.dropout)

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            # Q/K are RMS-normalized over the full embed_dim with heads
            # flattened together (see the flatten/view dance in _naive_attn).
            self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
            self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)

        if self.use_flash_attn:
            self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
        self.proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _naive_attn(self, x):
        # x: (B, N, C) -> qkv: (3, B, num_heads, N, head_dim)
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        if self.qk_normalization:
            # Flatten (heads, head_dim) so the norm spans the full embed_dim,
            # then restore the (B, H, N, D) layout.
            B_, H_, N_, D_ = q.shape
            q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
            k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)

        attn = (q * self.scale) @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
        # Pack projections as (b, s, 3, h, d), the layout FlashAttention expects.
        qkv = self.qkv(x)
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)

        if self.qk_normalization:
            q, k, v = qkv.unbind(2)
            q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
            k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
            qkv = torch.stack([q, k, v], dim=2)

        # causal=False: bidirectional attention over image tokens.
        context, _ = self.inner_attn(qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False)
        outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
        outs = self.proj_drop(outs)
        return outs

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Dispatch to the flash kernel when enabled, else the eager path.
        x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
        return x
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
class InternMLP(nn.Module):
    """Transformer feed-forward block: Linear -> activation -> Linear."""

    def __init__(self, config: InternVisionConfig):
        super().__init__()
        self.config = config
        self.act = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Expand to intermediate_size, apply nonlinearity, project back.
        return self.fc2(self.act(self.fc1(hidden_states)))
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
class InternVisionEncoderLayer(nn.Module):
    """Pre-norm ViT block with per-channel layer-scale and stochastic depth:

        x = x + DropPath(ls1 * Attn(Norm1(x)))
        x = x + DropPath(ls2 * MLP(Norm2(x)))
    """

    def __init__(self, config: InternVisionConfig, drop_path_rate: float):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.intermediate_size = config.intermediate_size

        self.attn = InternAttention(config)
        self.mlp = InternMLP(config)
        self.norm1 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.norm2 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)

        # Layer-scale gains (one per channel), initialized to initializer_factor.
        self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        # Identity when no stochastic depth is requested for this layer.
        self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
        """
        attn_branch = self.attn(self.norm1(hidden_states)) * self.ls1
        hidden_states = hidden_states + self.drop_path1(attn_branch)

        mlp_branch = self.mlp(self.norm2(hidden_states)) * self.ls2
        hidden_states = hidden_states + self.drop_path2(mlp_branch)

        return hidden_states
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
class InternVisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`InternEncoderLayer`].

    Args:
        config (`InternConfig`):
            The corresponding vision configuration for the `InternEncoder`.
    """

    def __init__(self, config: InternVisionConfig):
        super().__init__()
        self.config = config
        # stochastic depth decay rule
        # (drop-path rate increases linearly from 0 to drop_path_rate by depth)
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
        self.layers = nn.ModuleList(
            [InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)]
        )
        # NOTE: gradient checkpointing is hard-enabled by default here; it
        # trades extra forward compute for lower training memory.
        self.gradient_checkpointing = True

    def forward(
        self,
        inputs_embeds,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded representation of the inputs. Should be float, not int tokens.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        hidden_states = inputs_embeds

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                # Record each layer's *input*, so the embeddings come first
                # (HF convention); the final output is appended after the loop.
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                # Recompute activations during backward to save memory.
                layer_outputs = torch.utils.checkpoint.checkpoint(encoder_layer, hidden_states)
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                )
            hidden_states = layer_outputs

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states] if v is not None)
        return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states)
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
class InternVisionModel(PreTrainedModel):
    """InternViT vision backbone: patch/position embeddings + transformer encoder."""

    # HF plumbing: canonical input name, config class, and modules that
    # `accelerate` must not split across devices.
    main_input_name = "pixel_values"
    config_class = InternVisionConfig
    _no_split_modules = ["InternVisionEncoderLayer"]

    def __init__(self, config: InternVisionConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = InternVisionEmbeddings(config)
        self.encoder = InternVisionEncoder(config)

    def resize_pos_embeddings(self, old_size, new_size, patch_size):
        """Bicubically resample the learned position table from `old_size` to
        `new_size` (square image sizes in pixels); the [CLS] slot is kept as-is."""
        pos_emb = self.embeddings.position_embedding
        _, num_positions, embed_dim = pos_emb.shape
        cls_emb = pos_emb[:, :1, :]
        # (1, P, D) -> (1, D, grid, grid) so F.interpolate can resample spatially.
        pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
        pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode="bicubic", align_corners=False)
        pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
        pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
        self.embeddings.position_embedding = nn.Parameter(pos_emb)
        logger.info(f"Resized position embeddings from {old_size} to {new_size}")

    def get_input_embeddings(self):
        # Returns the embedding module (not a token embedding matrix).
        return self.embeddings

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_embeds: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Encode images (or precomputed embeddings via `pixel_embeds`).

        Returns the last hidden states plus the [CLS] token as pooled output.

        Raises:
            ValueError: if neither `pixel_values` nor `pixel_embeds` is given,
                or `pixel_values` is not a 4-D (B, C, H, W) tensor.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None and pixel_embeds is None:
            raise ValueError("You have to specify pixel_values or pixel_embeds")

        if pixel_embeds is not None:
            # Caller supplied embeddings directly; skip the patchify step.
            hidden_states = pixel_embeds
        else:
            if len(pixel_values.shape) == 4:
                hidden_states = self.embeddings(pixel_values)
            else:
                raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")
        # NOTE(review): when return_dict is False the encoder returns a plain
        # tuple, yet `.last_hidden_state` is accessed below — the tuple path
        # looks broken; confirm callers always use return_dict=True.
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = encoder_outputs.last_hidden_state
        # Pool by taking the [CLS] token (sequence position 0).
        pooled_output = last_hidden_state[:, 0, :]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            # NOTE(review): the encoder never populates attentions, so this
            # field is always None here.
            attentions=encoder_outputs.attentions,
        )
|
VILA/llava/model/multimodal_encoder/intern_encoder.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torchvision.transforms as T
|
| 19 |
+
from torchvision.transforms.functional import InterpolationMode
|
| 20 |
+
from transformers import AutoConfig, AutoModel
|
| 21 |
+
from transformers.image_processing_utils import BaseImageProcessor
|
| 22 |
+
|
| 23 |
+
from llava.model.multimodal_encoder.intern.configuration_intern_vit import InternVisionConfig
|
| 24 |
+
from llava.model.multimodal_encoder.intern.modeling_intern_vit import InternVisionModel
|
| 25 |
+
from llava.model.multimodal_encoder.vision_encoder import VisionTower
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def build_transform(input_size):
    """Return the preprocessing pipeline for InternViT inputs: force RGB,
    bicubic-resize to (input_size, input_size), convert to tensor, and
    normalize with ImageNet statistics."""
    return T.Compose(
        [
            T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ]
    )
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class InternVisionPreprocessor(BaseImageProcessor):
    """Minimal HF-style image processor producing 448x448 normalized tensors."""

    @property
    def size(self):
        # Fixed InternViT input resolution.
        return {"height": 448, "width": 448}

    def preprocess(self, image, return_tensors):
        # `return_tensors` is accepted for API compatibility but unused;
        # output is always a list of tensors under "pixel_values".
        transform = build_transform(448)
        if not isinstance(image, list):
            return {"pixel_values": [transform(image)]}
        return {"pixel_values": [transform(img) for img in image]}
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class InternVisionTower(VisionTower):
    """VisionTower wrapper that loads an InternViT checkpoint with a
    configurable drop-path rate and the project's fixed preprocessor."""

    def __init__(self, vision_tower, config, drop_path_rate=0.0):
        super().__init__(vision_tower, config)
        self._drop_path_rate = drop_path_rate

        self.image_processor = InternVisionPreprocessor()
        vision_config = InternVisionConfig.from_pretrained(vision_tower)
        vision_config.drop_path_rate = self._drop_path_rate
        # Fix: resolve the dtype string without eval() — eval executes
        # arbitrary configuration text.
        self.vision_tower = InternVisionModel.from_pretrained(
            vision_tower, torch_dtype=self._resolve_dtype(config.model_dtype), config=vision_config
        )

        self.is_loaded = True

    @staticmethod
    def _resolve_dtype(model_dtype):
        """Map a dtype spec such as "torch.float16" / "bfloat16" (or an actual
        torch.dtype) to a torch.dtype.

        Raises:
            ValueError: if the spec does not name a torch dtype.
        """
        if isinstance(model_dtype, torch.dtype):
            return model_dtype
        name = str(model_dtype).split(".")[-1]
        dtype = getattr(torch, name, None)
        if not isinstance(dtype, torch.dtype):
            raise ValueError(f"Unsupported model_dtype: {model_dtype!r}")
        return dtype
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# Register InternViT with the transformers Auto* factories so that
# AutoConfig/AutoModel.from_pretrained resolve the "intern_vit_6b" model type
# to the classes defined in this package.
AutoConfig.register("intern_vit_6b", InternVisionConfig)
AutoModel.register(InternVisionConfig, InternVisionModel)
|
VILA/llava/model/multimodal_encoder/radio_encoder.py
ADDED
|
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import warnings
|
| 19 |
+
from argparse import Namespace
|
| 20 |
+
from typing import Any, Dict
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
from PIL import Image
|
| 25 |
+
from transformers import AutoConfig, AutoModel, CLIPVisionConfig
|
| 26 |
+
|
| 27 |
+
from llava.model.multimodal_encoder.vision_encoder import VisionTower
|
| 28 |
+
from llava.train.utils import mprint, rprint
|
| 29 |
+
|
| 30 |
+
from .image_processor import ImageProcessor
|
| 31 |
+
from .visualize_features import get_pca_map
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_prefix_state_dict(state_dict: Dict[str, Any], prefix: str):
    """Select entries whose key starts with `prefix`, stripping that prefix
    from the returned keys (e.g. for loading a sub-module's weights)."""
    plen = len(prefix)
    filtered = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            filtered[key[plen:]] = value
    return filtered
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def is_rank0():
    """Return True on the global rank-0 process, or whenever torch.distributed
    has not been initialized (single-process runs)."""
    if not torch.distributed.is_initialized():
        return True
    return torch.distributed.get_rank() == 0
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class RADIOVisionTower(VisionTower):
|
| 44 |
+
"""
|
| 45 |
+
Vision Tower for the RADIO model.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
vision_tower (str): Vision tower name. This is passed on
|
| 49 |
+
the command line with the `--vision_tower` argument.
|
| 50 |
+
The string is expected in the pattern of:
|
| 51 |
+
`radio:<image_size>:<checkpoint>:<extra_config>`.
|
| 52 |
+
Where <extra_config> is a comma-separated list of key=value pairs.
|
| 53 |
+
<image_size> can also be a comma-separated list of resolutions in
|
| 54 |
+
the case of multi-res inference. Limitations apply, e.g. only two
|
| 55 |
+
resolutions are supported and the second resolution must be a divisor
|
| 56 |
+
of the first one.
|
| 57 |
+
args (Namespace): Arguments.
|
| 58 |
+
delay_load (bool): Delay loading the model.
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
    def __init__(self, vision_tower, args, delay_load=False):
        """Initialization Routine.

        Args:
            vision_tower: either a local checkpoint path, or a spec string
                "radio:<image_size>:<checkpoint>[:<k=v,...>]".
            args: training arguments namespace (reads mm_vision_select_feature,
                vision_tower_cfg, image_aspect_ratio).
            delay_load: must be False; deferred loading is unsupported.
        """

        super().__init__(vision_tower, args, delay_load)

        mprint(f"RADIOVisionTower: {vision_tower}. Args: {args} Delay load: {delay_load}")

        # Deferred loading is explicitly unsupported (also re-checked below).
        assert not delay_load

        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")

        extra_config = {}

        # Check if vision_tower is a valid path.
        if os.path.exists(vision_tower):
            # Local checkpoint: the image size comes from the vision tower cfg.
            self.vision_tower_name = self.vision_tower_checkpoint = vision_tower
            vision_cfg = getattr(args, "vision_tower_cfg")
            self.image_size = vision_cfg["image_size"]
        else:
            # Spec string: strip the "radio:" scheme, then split into
            # <image_size>:<checkpoint>[:<extra_config>].
            self.vision_tower_name = vision_tower[len("radio:") :]
            config_items = self.vision_tower_name.split(":")
            self.image_size = int(config_items[0])

            self.vision_tower_checkpoint = config_items[1]

            if len(config_items) > 2:
                # Parse extra config items. These are provided as a comma-separated list
                # of key=value pairs.
                extra_config_items = config_items[2].split(",")

                for item in extra_config_items:
                    key, value = item.split("=")
                    extra_config[key] = value

        self.image_aspect_ratio = args.image_aspect_ratio
        # NOTE(review): eval() on command-line-provided text; expected values
        # are just "True"/"False" — consider a safe string comparison instead.
        self.skip_layer_norm = eval(extra_config.get("skip_layer_norm", "False"))

        if not delay_load:
            self.load_model()
        else:
            # Unreachable given the assert above; kept for explicitness.
            raise ValueError("Delay load not supported for RADIOVisionTower.")

        # Bookkeeping for debug output (presumably consumed by feature
        # visualization elsewhere — not visible in this file chunk).
        self.sample_count = 0
        self.debug = True
|
| 105 |
+
|
| 106 |
+
def get_hidden_size(self):
    """Return the channel dimension of the features produced by forward().

    "cls" yields the concatenated summary vector (5120); "dense" concatenates
    four intermediate feature groups (4 * 1280); all other modes return the
    raw backbone width (1280).
    """
    special_sizes = {"cls": 5120, "dense": 4 * 1280}
    return special_sizes.get(self.select_feature, 1280)
|
| 115 |
+
|
| 116 |
+
def load_model(self):
    """Create the image preprocessor and load the RADIO model from a HF checkpoint.

    Side effects: sets `image_processor`, `vision_tower` (frozen, CUDA, eval
    mode), `is_loaded`, and `_to_dtype`; optionally strips the final layer norm.
    """
    if self.image_aspect_ratio == "resize":
        # Square resize: no padding needed.
        self.image_processor = ImageProcessor(
            size={"width": self.image_size, "height": self.image_size},
            do_pad=False,
            do_normalize=True,
            do_convert_rgb=True,
        )
    else:
        # Aspect-preserving resize on the longest edge, padded to a multiple of 16.
        self.image_processor = ImageProcessor(
            size={"longest_edge": self.image_size},
            do_pad=True,
            pad_multiple=16,
            do_normalize=True,
            do_convert_rgb=True,
            pad_value=0.456,
        )
    # For compatibility with CLIP Image Processor: the data loader uses width/height to
    # create dummy blank images for samples that don't have an image.
    self.image_processor.crop_size = {"width": self.image_size, "height": self.image_size}

    mprint(self.image_processor)

    config = AutoConfig.from_pretrained(self.vision_tower_checkpoint, trust_remote_code=True)
    mprint("RADIO config", config)
    self.vision_tower = AutoModel.from_pretrained(self.vision_tower_checkpoint, trust_remote_code=True)
    # Input normalization is handled by our ImageProcessor, so detach the
    # model's built-in preprocessor.
    self.vision_tower.radio_model.make_preprocessor_external()

    # NOTE: do a lazy import of Timm to avoid issues with
    # DeepSpeed's ZeRO-3.
    from timm.models.vision_transformer import VisionTransformer

    if isinstance(self.vision_tower.model, VisionTransformer):
        hidden_size = self.vision_tower.model.embed_dim
    else:
        raise ValueError(f"Unknown model type: {self.vision_tower}")

    # Override hidden size for OpenAI CLIP.
    hidden_size = self.get_hidden_size()

    if hasattr(self.vision_tower.model, "patch_generator"):
        patch_gen = self.vision_tower.model.patch_generator
        # Cropped Positional Embedding (CPE) case.
        patch_size = patch_gen.patch_size
    else:
        # Standard ViT case.
        patch_size = self.vision_tower.model.patch_embed.patch_size[0]

    # Expose CLIP-compatible config fields that downstream code reads.
    self.vision_tower.config.image_size = self.image_size
    self.vision_tower.config.hidden_size = hidden_size
    self.vision_tower.config.patch_size = patch_size

    # The encoder is frozen: eval mode, no gradients.
    self.vision_tower.cuda().eval()
    self.vision_tower.requires_grad_(False)

    self.is_loaded = True
    self._to_dtype = None

    if self.skip_layer_norm:
        mprint(f"Removing layer norm from the model: {self.vision_tower.model.norm}")
        self.vision_tower.model.norm = torch.nn.Identity()
|
| 178 |
+
|
| 179 |
+
def to(self, *args, **kwargs):
    """Intercept nn.Module.to() so the RADIO weights are never dtype-cast."""
    # Copy kwargs, then strip any requested dtype before delegating; the
    # requested dtype is remembered for reference only.
    kwargs = dict(kwargs)
    self._to_dtype = kwargs.pop("dtype", None)
    mprint(f"RADIO: bypass cast to dtype={self._to_dtype}")
    super().to(*args, **kwargs)
|
| 187 |
+
|
| 188 |
+
def train(self, mode=True):
    """Intercept nn.Module.train(): this tower is frozen and stays in eval mode."""
    # Warn (rather than fail) when a caller tries to enable training mode.
    if mode:
        warnings.warn("RADIOEncoder is always in eval mode.")
|
| 194 |
+
|
| 195 |
+
def _get_summary_and_patch_from_tokens(self, tokens):
    """Split a token sequence into (summary, patch_features).

    Handles three backbone layouts: CPE patch-generator models (several
    cls tokens at the front), global-average-pool ViTs, and standard
    single-cls-token ViTs.
    """
    model = self.vision_tower.model
    patch_gen = getattr(model, "patch_generator", None)
    if patch_gen is not None:
        # CPE case: the first num_cls_tokens entries are summary tokens.
        all_summary = tokens[:, : patch_gen.num_cls_tokens]
        if self.vision_tower.radio_model.summary_idxs is not None:
            # Keep only the configured subset of summary tokens.
            summary = all_summary[:, self.vision_tower.radio_model.summary_idxs]
        else:
            summary = all_summary
        # Patch tokens start after the skipped (cls/register) tokens.
        all_feat = tokens[:, patch_gen.num_skip :]
    elif model.global_pool == "avg":
        # Average-pool ViT: summary is the mean over non-prefix tokens;
        # NOTE(review): all tokens (incl. prefix) are kept as features here.
        all_summary = tokens[:, model.num_prefix_tokens :].mean(dim=1)
        summary = all_summary
        all_feat = tokens
    else:
        # Standard ViT: first token is the cls token, the rest are patches.
        all_summary = tokens[:, 0]
        summary = all_summary
        all_feat = tokens[:, 1:]
    return summary, all_feat
|
| 214 |
+
|
| 215 |
+
@torch.no_grad()
def get_features(self, x: torch.Tensor):
    """Run the backbone and return (summary, features).

    Inputs are promoted to float32 and computed under bf16 autocast;
    features are cast back to the original input dtype before returning.
    In "dense" mode the transformer blocks are run manually and three
    averaged groups of intermediate activations are concatenated with the
    final features (4x channel width).
    """
    x_dtype = x.dtype
    x = x.float()
    with torch.autocast("cuda", dtype=torch.bfloat16):
        if self.select_feature == "dense":

            # Layers to return activations of in case of "return_multilayer=True".
            num_layers = len(self.vision_tower.model.blocks)
            multilayers = [
                num_layers // 4 - 1,
                num_layers // 2 - 1,
                num_layers // 4 * 3 - 1,
            ]

            features = []
            intermediate_features = []

            x = self.vision_tower.input_conditioner(x)
            x = self.vision_tower.model.patch_generator(x)

            # Run blocks manually, averaging the activations accumulated
            # since the previous checkpoint at each layer in `multilayers`.
            for i, blk in enumerate(self.vision_tower.model.blocks):
                x = blk(x)
                _, blk_features = self._get_summary_and_patch_from_tokens(x)
                intermediate_features.append(blk_features)
                if i in multilayers:
                    intermediate_features = torch.stack(intermediate_features, dim=0)
                    intermediate_features = torch.sum(intermediate_features, dim=0) / intermediate_features.shape[0]
                    features.append(intermediate_features)
                    intermediate_features = []
            x = self.vision_tower.model.norm(x)
            last_summary, last_features = self._get_summary_and_patch_from_tokens(x)
            features.append(last_features)
            # Channel-wise concat: 3 averaged groups + final features.
            features = torch.cat(features, dim=-1)
            summary = last_summary
        else:
            summary, features = self.vision_tower(x)

    return summary, features.to(dtype=x_dtype)
|
| 254 |
+
|
| 255 |
+
@torch.no_grad()
def forward(self, images: torch.Tensor):
    """Main forward pass.

    Args:
        images: Preprocessed image tensor, (C, H, W) or (B, C, H, W).

    Returns:
        Feature tensor in the input dtype whose last dimension equals
        get_hidden_size(); the batch dimension mirrors the input.
    """
    input_shape = images.shape

    x = images
    # Add a batch dimension if necessary.
    if len(input_shape) == 3:
        x = x.unsqueeze(0)

    # Convert the input to the model's dtype (we assume
    # that the model only has one dtype for all parameters).
    param0 = next(self.vision_tower.parameters())

    rprint(
        f"input shape={input_shape}->{x.shape} device={x.device} mean={x.mean().item()} std={x.std().item()} dtype={x.dtype} param0.device={param0.device} param0.dtype={param0.dtype}"
    )

    summary, features = self.get_features(x)  # B, T, C

    if len(summary.shape) == 2:
        if self.select_feature == "cls4":
            # Split the flat summary vector into 4 tokens.
            B, C = summary.shape
            summary = summary.reshape(B, 4, C // 4)
        else:
            # Add a token dimension if necessary.
            summary = summary.unsqueeze(1)

    B, _, H, W = x.shape
    _, _, C = features.shape
    patch_size = self.vision_tower.config.patch_size
    # Fold the flat patch tokens back onto the 2-D patch grid.
    spatial_features = features.reshape(B, H // patch_size, W // patch_size, C)
    spatial_features = spatial_features.permute(0, 3, 1, 2)  # B, C, H/patch_size, W/patch_size

    # Every 1000th sample (rank 0 only): dump inputs, raw features, and PCA
    # visualizations to "radio-debug/" for offline inspection.
    if self.debug and is_rank0() and self.sample_count % 1000 == 0:
        spatial_features_hwc = spatial_features.permute(0, 2, 3, 1)
        # create the debug directory
        os.makedirs("radio-debug", exist_ok=True)
        torch.save(x, f"radio-debug/sample_{self.sample_count}_input.pt")
        torch.save(features, f"radio-debug/sample_{self.sample_count}_features.pt")
        torch.save(spatial_features_hwc, f"radio-debug/sample_{self.sample_count}_features_reshaped.pt")
        for i in range(B):
            image = x[i].permute(1, 2, 0).float() * 255
            image = Image.fromarray(image.cpu().numpy().astype(np.uint8))
            image.save(os.path.join("radio-debug/", f"sample_{self.sample_count}_preprocessed_{i}.png"))
            pca_map = get_pca_map(spatial_features_hwc[i : i + 1], x.shape[-2:])
            torch.save(pca_map, f"radio-debug/sample_{self.sample_count}_pca_map_{i}.pt")
            image = pca_map * 255
            image = Image.fromarray(image.astype(np.uint8))
            image.save(os.path.join("radio-debug/", f"sample_{self.sample_count}_pca_map_{i}.png"))
        pass

    if self.select_feature in ["patch", "cls_patch", "dense"]:
        # Ignore cls-patch for now.
        pass
    # elif self.select_feature == "cls_patch":
    #     features = torch.cat([summary, features], dim=1)
    elif self.select_feature in ["cls", "cls4"]:
        features = summary
    else:
        raise ValueError(f"Unexpected select feature: {self.select_feature}")

    # Remove the batch dimension if we added it.
    if len(input_shape) == 3:
        features = features.squeeze(0)

    # Cast back to the input's dtype.
    features = features.to(images.dtype)

    rprint(
        f"features shape={features.shape} mean={features.mean().item()} std={features.std().item()} dtype={features.dtype}"
    )

    if features.shape[-1] != self.get_hidden_size():
        raise ValueError(f"Unexpected hidden size: {features.shape[-1]} != {self.get_hidden_size()}")

    self.sample_count += 1

    return features
|
VILA/llava/model/multimodal_encoder/radio_torchhub_encoder.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import warnings
|
| 19 |
+
from argparse import Namespace
|
| 20 |
+
from typing import Any, Dict
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
from PIL import Image
|
| 25 |
+
from transformers import CLIPVisionConfig
|
| 26 |
+
|
| 27 |
+
from llava.model.multimodal_encoder.vision_encoder import VisionTower
|
| 28 |
+
from llava.train.utils import mprint, rprint
|
| 29 |
+
|
| 30 |
+
from .image_processor import ImageProcessor
|
| 31 |
+
from .visualize_features import get_pca_map
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_prefix_state_dict(state_dict: Dict[str, Any], prefix: str):
    """Return the entries of *state_dict* under *prefix*, with the prefix stripped."""
    offset = len(prefix)
    selected = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            selected[key[offset:]] = value
    return selected
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def is_rank0():
    """True when not running under torch.distributed, or when this is global rank 0."""
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank() == 0
    return True
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class RADIOVisionTower(VisionTower):
    """
    Vision Tower for the RADIO model, loaded via torch.hub.

    Args:
        vision_tower (str): Vision tower name. This is passed on
            the command line with the `--vision_tower` argument.
            The string is expected in the pattern of:
            `radio:<image_size>:<checkpoint>:<extra_config>`.
            Where <extra_config> is a comma-separated list of key=value pairs.
            <image_size> can also be a comma-separated list of resolutions in
            the case of multi-res inference. Limitations apply, e.g. only two
            resolutions are supported and the second resolution must be a divisor
            of the first one.
        args (Namespace): Arguments.
        delay_load (bool): Delay loading the model.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        """Initialization Routine."""

        super().__init__(vision_tower, args, delay_load)

        mprint(f"RADIOVisionTower: {vision_tower}. Args: {args} Delay load: {delay_load}")

        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")

        # Parse the "radio:<image_sizes>:<checkpoint>[:<extra_config>]" spec.
        self.vision_tower_name = vision_tower[len("radio:") :]
        config_items = self.vision_tower_name.split(":")
        self.image_sizes = [int(x) for x in config_items[0].split(",")]
        if len(self.image_sizes) == 0:
            raise ValueError("Expected more than zero images sizes!")
        self.image_size = self.image_sizes[0]
        self.image_aspect_ratio = args.image_aspect_ratio

        self.downscale_factor = None
        if len(self.image_sizes) > 1:
            # Multi-resolution mode: the second resolution must evenly divide
            # the first (checked by the assert below).
            self.downscale_factor = self.image_sizes[0] // self.image_sizes[1]
            assert self.downscale_factor == self.image_sizes[0] / self.image_sizes[1]
            self.pool2d = torch.nn.AvgPool2d(self.downscale_factor, self.downscale_factor)
            if len(self.image_sizes) > 2:
                raise ValueError("Only support up to two resolutions")
        elif self.image_size >= 512:
            # Single large resolution: enable 2x downscaling (used by the
            # optional pixel-unshuffle path below).
            self.downscale_factor = 2

        self.vision_tower_checkpoint = config_items[1]

        extra_config = {}
        if len(config_items) > 2:
            # Parse extra config items. These are provided as a comma-separated list
            # of key=value pairs.
            for item in config_items[2].split(","):
                key, value = item.split("=")
                extra_config[key] = value

        self.adaptor_name = extra_config.get("adaptor", "backbone")
        # BUGFIX/SAFETY: these flags were parsed with eval() on config-provided
        # strings; literal comparisons accept the same "True"/"False" values
        # without executing arbitrary input.
        self.fuse_adaptor_with_backbone = extra_config.get("fuse_adaptor_with_backbone", "False") == "True"
        self.skip_layer_norm = extra_config.get("skip_layer_norm", "False") == "True"
        self.allow_pixel_unshuffle = extra_config.get("pixel_unshuffle", "False") == "True"

        self.pixel_unshuffle = None
        if self.allow_pixel_unshuffle and self.downscale_factor is not None:
            self.pixel_unshuffle = torch.nn.PixelUnshuffle(self.downscale_factor)

        if not delay_load:
            self.load_model()
        else:
            # FIXME: This is a hack to avoid having to load the config from the checkpoint.
            # BUGFIX: get_hidden_size() takes no arguments; the adaptor name was
            # previously (incorrectly) passed as a positional argument.
            hidden_size = self.get_hidden_size()
            patch_size = 16

            self.cfg_only = CLIPVisionConfig(
                **{
                    "hidden_size": hidden_size,
                    "image_size": self.image_size,
                    "model_type": "radio_vision_model",
                    "num_attention_heads": None,
                    "num_channels": 3,
                    "num_hidden_layers": None,
                    "patch_size": patch_size,
                }
            )

        # Counter used to throttle the periodic debug dumps in forward().
        self.sample_count = 0
        self.debug = True

    def get_hidden_size(self):
        """Return the channel dimension of features after adaptor selection,
        backbone fusion, and multi-resolution/pixel-unshuffle concatenation."""
        if self.select_feature == "cls":
            hidden_size = 5120
        elif self.adaptor_name == "openai_clip":
            hidden_size = 1024
        elif self.adaptor_name == "clip":
            hidden_size = 1280
        elif self.adaptor_name == "rtx-translate":
            hidden_size = 2048
        elif self.adaptor_name == "backbone":
            hidden_size = 1280
        else:
            raise ValueError(f"Unknown adaptor name: {self.adaptor_name}")

        if self.fuse_adaptor_with_backbone:
            hidden_size += 1280

        if len(self.image_sizes) == 2:
            # Stage-2 features are concatenated channel-wise (x2); pixel
            # unshuffle additionally multiplies channels by factor**2 (=4).
            if self.pixel_unshuffle is not None:
                hidden_size = hidden_size * 5
            else:
                hidden_size = hidden_size * 2
        elif self.pixel_unshuffle is not None:
            hidden_size = hidden_size * 4

        return hidden_size

    def load_model(self):
        """Instantiate the image preprocessor and load the RADIO model from torch.hub."""
        if self.image_aspect_ratio == "resize":
            # Square resize: no padding needed.
            self.image_processor = ImageProcessor(
                size={"width": self.image_size, "height": self.image_size},
                do_pad=False,
                do_normalize=False,
                do_convert_rgb=True,
            )
        else:
            # Aspect-preserving resize on the longest edge, padded to a multiple of 16.
            self.image_processor = ImageProcessor(
                size={"longest_edge": self.image_size},
                do_pad=True,
                pad_multiple=16,
                do_normalize=False,
                do_convert_rgb=True,
                pad_value=0.456,
            )
        # For compatibility with CLIP Image Processor: the data loader uses width/height to
        # create dummy blank images for samples that don't have an image.
        self.image_processor.crop_size = {"width": self.image_size, "height": self.image_size}

        mprint(self.image_processor)

        # Load weights from checkpoint.
        checkpoint_path = self.vision_tower_checkpoint
        rprint(f"Loading checkpoint from {checkpoint_path}")

        # NOTE: do a lazy import of Timm to avoid issues with
        # DeepSpeed's ZeRO-3.
        from timm.models.vision_transformer import VisionTransformer

        self.vision_tower = torch.hub.load(
            "NVlabs/RADIO",
            "radio_model",
            version=checkpoint_path,
            progress=True,
            adaptor_names=self.adaptor_name if self.adaptor_name != "backbone" else None,
        )

        if isinstance(self.vision_tower.model, VisionTransformer):
            hidden_size = self.vision_tower.model.embed_dim
        else:
            raise ValueError(f"Unknown model type: {self.vision_tower}")

        # Override hidden size for OpenAI CLIP.
        hidden_size = self.get_hidden_size()

        if hasattr(self.vision_tower.model, "patch_generator"):
            patch_gen = self.vision_tower.model.patch_generator
            # Cropped Positional Embedding (CPE) case.
            patch_size = patch_gen.patch_size
        else:
            # Standard ViT case.
            patch_size = self.vision_tower.model.patch_embed.patch_size[0]

        # Expose a CLIP-compatible config for downstream consumers.
        self.vision_tower.config = CLIPVisionConfig(
            **{
                "hidden_size": hidden_size,
                "image_size": self.image_size,
                "model_type": "radio_vision_model",
                "num_attention_heads": None,
                "num_channels": 3,
                "num_hidden_layers": None,
                "patch_size": patch_size,
            }
        )

        # The encoder is frozen: eval mode, no gradients.
        self.vision_tower.eval()
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True
        self._to_dtype = None

        if self.skip_layer_norm:
            # BUGFIX: was `rank0_print`, which is not defined in this module
            # (the module imports mprint/rprint); it raised a NameError
            # whenever skip_layer_norm was enabled.
            mprint(f"Removing layer norm from the model: {self.vision_tower.model.norm}")
            self.vision_tower.model.norm = torch.nn.Identity()

    def to(self, *args, **kwargs):
        """Intercept nn.Module.to() so the RADIO weights are never dtype-cast."""
        # Prevent casting the RADIO model's weights: strip any requested dtype
        # before delegating, remembering it for reference only.
        kwargs = dict(kwargs)
        self._to_dtype = kwargs.pop("dtype", None)
        mprint(f"RADIO: bypass cast to dtype={self._to_dtype}")
        super().to(*args, **kwargs)

    def train(self, mode=True):
        """Intercept nn.Module.train(): this tower is frozen and stays in eval mode."""
        # Warn (rather than fail) when a caller tries to enable training mode.
        if mode:
            warnings.warn("RADIOEncoder is always in eval mode.")

    @torch.no_grad()
    def get_features(self, x: torch.Tensor):
        """Run the tower under bf16 autocast; return (summary, features) in x's dtype."""
        x_float = x.float()
        with torch.autocast("cuda", dtype=torch.bfloat16):
            output = self.vision_tower(x_float)

        if isinstance(output, dict):
            # Adaptor mode: outputs are keyed by adaptor name.
            summary, features = output[self.adaptor_name]
            if self.fuse_adaptor_with_backbone:
                backbone_summary, backbone_features = output["backbone"]
                summary = torch.cat([summary, backbone_summary], dim=2)
                features = torch.cat([features, backbone_features], dim=2)
        else:
            summary, features = output

        return summary, features.to(dtype=x.dtype)

    def _dump_debug_artifacts(self, x, features, spatial_features):
        """Save inputs, features, and PCA maps under "radio-debug/" for inspection."""
        spatial_features_hwc = spatial_features.permute(0, 2, 3, 1)
        # create the debug directory
        os.makedirs("radio-debug", exist_ok=True)
        torch.save(x, f"radio-debug/sample_{self.sample_count}_input.pt")
        torch.save(features, f"radio-debug/sample_{self.sample_count}_features.pt")
        torch.save(spatial_features_hwc, f"radio-debug/sample_{self.sample_count}_features_reshaped.pt")
        for i in range(x.shape[0]):
            image = x[i].permute(1, 2, 0).float() * 255
            image = Image.fromarray(image.cpu().numpy().astype(np.uint8))
            image.save(os.path.join("radio-debug/", f"sample_{self.sample_count}_preprocessed_{i}.png"))
            pca_map = get_pca_map(spatial_features_hwc[i : i + 1], x.shape[-2:])
            torch.save(pca_map, f"radio-debug/sample_{self.sample_count}_pca_map_{i}.pt")
            image = pca_map * 255
            image = Image.fromarray(image.astype(np.uint8))
            image.save(os.path.join("radio-debug/", f"sample_{self.sample_count}_pca_map_{i}.png"))

    @torch.no_grad()
    def forward(self, images: torch.Tensor):
        """Main forward pass.

        Args:
            images: Preprocessed image tensor, (C, H, W) or (B, C, H, W).

        Returns:
            Feature tensor in the input dtype whose last dimension equals
            get_hidden_size(); the batch dimension mirrors the input.
        """
        input_shape = images.shape

        x = images
        # Add a batch dimension if necessary.
        if len(input_shape) == 3:
            x = x.unsqueeze(0)

        rprint(
            f"input shape={input_shape}->{x.shape} device={x.device} mean={x.mean().item()} std={x.std().item()} dtype={x.dtype}"
        )

        summary, features = self.get_features(x)  # B, T, C

        if len(summary.shape) == 2:
            if self.select_feature == "cls4":
                # Split the flat summary vector into 4 tokens.
                B, C = summary.shape
                summary = summary.reshape(B, 4, C // 4)
            else:
                # Add a token dimension if necessary.
                summary = summary.unsqueeze(1)

        B, _, H, W = x.shape
        _, _, C = features.shape
        patch_size = self.vision_tower.config.patch_size
        # Fold the flat patch tokens back onto the 2-D patch grid.
        spatial_features = features.reshape(B, H // patch_size, W // patch_size, C)
        spatial_features = spatial_features.permute(0, 3, 1, 2)  # B, C, H/patch_size, W/patch_size

        # Every 1000th sample (rank 0 only): dump debug artifacts.
        if self.debug and is_rank0() and self.sample_count % 1000 == 0:
            self._dump_debug_artifacts(x, features, spatial_features)

        if self.pixel_unshuffle is not None:
            spatial_features = self.pixel_unshuffle(spatial_features)
            # B, C*downscale_factor**2, H/patch_size/downscale_factor, W/patch_size/downscale_factor
            features = spatial_features.reshape(
                B,
                C * self.downscale_factor**2,
                (H // patch_size // self.downscale_factor) * (W // patch_size // self.downscale_factor),
            ).permute(0, 2, 1)

        if len(self.image_sizes) > 1:
            # Experimental support for multi-resolution inference.
            if self.pixel_unshuffle is None:
                # downscale features
                spatial_features = self.pool2d(
                    spatial_features
                )  # B, C, H/patch_size/downscale_factor, W/patch_size/downscale_factor
                features = spatial_features.reshape(
                    B, C, (H // patch_size // self.downscale_factor) * (W // patch_size // self.downscale_factor)
                )
                features = features.permute(
                    0, 2, 1
                )  # B, (H/patch_size/downscale_factor) * (W/patch_size/downscale_factor), C

            # Downscale the input image.
            x = self.pool2d(x)  # B, 3, H/downscale_factor, W/downscale_factor)
            # BUGFIX: get_features returns a (summary, features) tuple; the
            # original code assigned the whole tuple to features_stage2 and
            # then passed it to torch.cat, which would fail at runtime.
            _, features_stage2 = self.get_features(
                x
            )  # B, (H/patch_size/downscale_factor) * (W/patch_size/downscale_factor), C

            # Concatenate stage1 and stage 2 features.
            features = torch.cat([features, features_stage2], dim=2)

        if self.select_feature in ["patch", "cls_patch"]:
            # Ignore cls-patch for now.
            pass
        # elif self.select_feature == "cls_patch":
        #     features = torch.cat([summary, features], dim=1)
        elif self.select_feature in ["cls", "cls4"]:
            features = summary
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")

        # Remove the batch dimension if we added it.
        if len(input_shape) == 3:
            features = features.squeeze(0)

        # Cast back to the input's dtype.
        features = features.to(images.dtype)

        adaptor_name = f"{self.adaptor_name}{'+backbone' if self.fuse_adaptor_with_backbone else ''}"
        rprint(
            f"features ({adaptor_name}) shape={features.shape} mean={features.mean().item()} std={features.std().item()} dtype={features.dtype}"
        )

        # Explicit check (was a bare assert, which is stripped under -O);
        # consistent with the HF-checkpoint variant of this encoder.
        if features.shape[-1] != self.get_hidden_size():
            raise ValueError(f"Unexpected hidden size: {features.shape[-1]} != {self.get_hidden_size()}")

        self.sample_count += 1

        return features
|