tuandunghcmut committed on
Commit ff5d469 · verified · 1 Parent(s): cbbdd48

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. VILA/llava/model/__pycache__/__init__.cpython-310.pyc +0 -0
  2. VILA/llava/model/__pycache__/__init__.cpython-39.pyc +0 -0
  3. VILA/llava/model/__pycache__/builder.cpython-310.pyc +0 -0
  4. VILA/llava/model/__pycache__/configuration_llava.cpython-310.pyc +0 -0
  5. VILA/llava/model/__pycache__/llava_arch.cpython-310.pyc +0 -0
  6. VILA/llava/model/__pycache__/utils.cpython-310.pyc +0 -0
  7. VILA/llava/model/language_model/__pycache__/builder.cpython-310.pyc +0 -0
  8. VILA/llava/model/language_model/__pycache__/llava_llama.cpython-310.pyc +0 -0
  9. VILA/llava/model/language_model/__pycache__/llava_llama.cpython-39.pyc +0 -0
  10. VILA/llava/model/language_model/__pycache__/llava_mistral.cpython-310.pyc +0 -0
  11. VILA/llava/model/language_model/__pycache__/llava_mixtral.cpython-310.pyc +0 -0
  12. VILA/llava/model/language_model/__pycache__/modeling_mixtral_long_context.cpython-310.pyc +0 -0
  13. VILA/llava/model/language_model/builder.py +114 -0
  14. VILA/llava/model/language_model/llava_gemma.py +153 -0
  15. VILA/llava/model/language_model/llava_llama.py +186 -0
  16. VILA/llava/model/language_model/llava_mistral.py +137 -0
  17. VILA/llava/model/language_model/llava_mixtral.py +136 -0
  18. VILA/llava/model/language_model/llava_mpt.py +160 -0
  19. VILA/llava/model/language_model/modeling_mixtral_long_context.py +1657 -0
  20. VILA/llava/model/language_model/mpt/adapt_tokenizer.py +61 -0
  21. VILA/llava/model/language_model/mpt/attention.py +480 -0
  22. VILA/llava/model/language_model/mpt/blocks.py +100 -0
  23. VILA/llava/model/language_model/mpt/configuration_mpt.py +184 -0
  24. VILA/llava/model/language_model/mpt/custom_embedding.py +27 -0
  25. VILA/llava/model/language_model/mpt/flash_attn_triton.py +947 -0
  26. VILA/llava/model/language_model/mpt/hf_prefixlm_converter.py +657 -0
  27. VILA/llava/model/language_model/mpt/meta_init_context.py +118 -0
  28. VILA/llava/model/language_model/mpt/modeling_mpt.py +483 -0
  29. VILA/llava/model/language_model/mpt/norm.py +89 -0
  30. VILA/llava/model/language_model/mpt/param_init_fns.py +399 -0
  31. VILA/llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc +0 -0
  32. VILA/llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc +0 -0
  33. VILA/llava/model/multimodal_encoder/__pycache__/image_processor.cpython-310.pyc +0 -0
  34. VILA/llava/model/multimodal_encoder/__pycache__/intern_encoder.cpython-310.pyc +0 -0
  35. VILA/llava/model/multimodal_encoder/__pycache__/radio_encoder.cpython-310.pyc +0 -0
  36. VILA/llava/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc +0 -0
  37. VILA/llava/model/multimodal_encoder/__pycache__/vision_encoder.cpython-310.pyc +0 -0
  38. VILA/llava/model/multimodal_encoder/__pycache__/visualize_features.cpython-310.pyc +0 -0
  39. VILA/llava/model/multimodal_encoder/builder.py +64 -0
  40. VILA/llava/model/multimodal_encoder/clip_encoder.py +42 -0
  41. VILA/llava/model/multimodal_encoder/image_processor.py +546 -0
  42. VILA/llava/model/multimodal_encoder/intern/__pycache__/configuration_intern_vit.cpython-310.pyc +0 -0
  43. VILA/llava/model/multimodal_encoder/intern/__pycache__/flash_attention.cpython-310.pyc +0 -0
  44. VILA/llava/model/multimodal_encoder/intern/__pycache__/modeling_intern_vit.cpython-310.pyc +0 -0
  45. VILA/llava/model/multimodal_encoder/intern/configuration_intern_vit.py +117 -0
  46. VILA/llava/model/multimodal_encoder/intern/flash_attention.py +105 -0
  47. VILA/llava/model/multimodal_encoder/intern/modeling_intern_vit.py +543 -0
  48. VILA/llava/model/multimodal_encoder/intern_encoder.py +71 -0
  49. VILA/llava/model/multimodal_encoder/radio_encoder.py +334 -0
  50. VILA/llava/model/multimodal_encoder/radio_torchhub_encoder.py +375 -0
VILA/llava/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (450 Bytes).
 
VILA/llava/model/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (448 Bytes).
 
VILA/llava/model/__pycache__/builder.cpython-310.pyc ADDED
Binary file (5.81 kB).
 
VILA/llava/model/__pycache__/configuration_llava.cpython-310.pyc ADDED
Binary file (1.28 kB).
 
VILA/llava/model/__pycache__/llava_arch.cpython-310.pyc ADDED
Binary file (20.8 kB).
 
VILA/llava/model/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.45 kB).
 
VILA/llava/model/language_model/__pycache__/builder.cpython-310.pyc ADDED
Binary file (2.42 kB).
 
VILA/llava/model/language_model/__pycache__/llava_llama.cpython-310.pyc ADDED
Binary file (3.8 kB).
 
VILA/llava/model/language_model/__pycache__/llava_llama.cpython-39.pyc ADDED
Binary file (3.7 kB).
 
VILA/llava/model/language_model/__pycache__/llava_mistral.cpython-310.pyc ADDED
Binary file (3.67 kB).
 
VILA/llava/model/language_model/__pycache__/llava_mixtral.cpython-310.pyc ADDED
Binary file (3.57 kB).
 
VILA/llava/model/language_model/__pycache__/modeling_mixtral_long_context.cpython-310.pyc ADDED
Binary file (46 kB).
 
VILA/llava/model/language_model/builder.py ADDED
@@ -0,0 +1,114 @@
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import math
+import os.path as osp
+import warnings
+from typing import Tuple
+
+import torch
+from huggingface_hub import file_exists, repo_exists
+from huggingface_hub.utils import HFValidationError
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    PretrainedConfig,
+    PreTrainedModel,
+    PreTrainedTokenizer,
+)
+
+
+def has_tokenizer(repo_id_or_path: str) -> bool:
+    # Check if the tokenizer is in a local directory
+    if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")):
+        return True
+
+    # Check if the tokenizer is in a Hugging Face Hub repo
+    try:
+        return repo_exists(repo_id_or_path) and file_exists(repo_id_or_path, "tokenizer_config.json")
+    except HFValidationError:
+        return False
+
+
+def context_length_extension(config):
+    orig_ctx_len = getattr(config, "max_position_embeddings", None)
+    model_max_length = getattr(config, "model_max_length", None)
+    if orig_ctx_len and model_max_length > orig_ctx_len:
+        print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
+        scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
+        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
+    return config
+
+
+def build_llm_and_tokenizer(
+    model_name_or_path: str,
+    config: PretrainedConfig,
+    attn_implementation=None,
+    model_max_length=None,
+    *args,
+    **kwargs,
+) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+    llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
+    llm_cfg._attn_implementation = attn_implementation
+    llm_cfg.model_max_length = model_max_length
+    if model_max_length is not None:
+        context_length_extension(llm_cfg)
+
+    llm = AutoModelForCausalLM.from_pretrained(
+        model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
+    )
+
+    # Locate the tokenizer.
+    llm_path = model_name_or_path
+    if not has_tokenizer(llm_path):
+        llm_path = osp.join(llm_path, "llm")
+    if not has_tokenizer(llm_path):
+        raise ValueError(f"Cannot find tokenizer in {llm_path}.")
+
+    # TODO(ligeng): use LLM class to judge to better compability.
+    try:
+        llm_arch = getattr(llm_cfg, "architectures")[0].lower()
+    except BaseException:
+        warnings.warn(f'Cannot find LLM architecture, please check the "config.json" under "{llm_path}".')
+
+    if "mpt" in llm_arch:
+        tokenizer = AutoTokenizer.from_pretrained(
+            llm_path,
+            model_max_length=llm_cfg.model_max_length,
+            padding_side="right",
+        )
+    elif "yi" in llm_path or (
+        getattr(llm_cfg, "num_hidden_layers", -1) == 60 and getattr(llm_cfg, "num_attention_heads", -1) == 56
+    ):
+        tokenizer = AutoTokenizer.from_pretrained(
+            llm_path,
+            model_max_length=llm_cfg.model_max_length,
+            padding_side="right",
+            use_fast=False,
+        )
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(
+            llm_path,
+            model_max_length=llm_cfg.model_max_length,
+            padding_side="right",
+            use_fast=False,
+            legacy=False,
+        )
+
+    # TODO(ligeng): is this necessary for llava?
+    config.hidden_size = llm.config.hidden_size
+    return llm, tokenizer
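For orientation, a minimal sketch of how build_llm_and_tokenizer above might be invoked. This is not part of the commit: the checkpoint name is a placeholder, and the config is a hypothetical minimal object carrying only the model_dtype string that the builder eval()s and the hidden_size it writes back.

from transformers import PretrainedConfig

from llava.model.language_model.builder import build_llm_and_tokenizer

# Hypothetical minimal config: the builder only reads model_dtype (a string such as
# "torch.bfloat16" that it eval()s) and later writes hidden_size back onto it.
config = PretrainedConfig(model_dtype="torch.bfloat16")

llm, tokenizer = build_llm_and_tokenizer(
    "meta-llama/Llama-2-7b-hf",  # placeholder: any causal-LM repo or local path with a tokenizer
    config,
    attn_implementation="flash_attention_2",
    model_max_length=8192,  # larger than max_position_embeddings triggers linear RoPE scaling
)
print(type(llm).__name__, tokenizer.model_max_length)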
VILA/llava/model/language_model/llava_gemma.py ADDED
@@ -0,0 +1,153 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+PAD_TOKEN_ID = 0
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from transformers import AutoConfig, AutoModelForCausalLM
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.models.gemma import GemmaConfig, GemmaForCausalLM, GemmaModel
+
+from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
+
+
+class LlavaGemmaConfig(GemmaConfig):
+    model_type = "llava_gemma"
+
+
+class LlavaGemmaModel(GemmaModel, LlavaMetaModel):
+    config_class = LlavaGemmaConfig
+
+    def __init__(self, config: GemmaConfig):
+        super().__init__(config)
+
+
+class LlavaGemmaForCausalLM(GemmaForCausalLM, LlavaMetaForCausalLM):
+    config_class = LlavaGemmaConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = LlavaGemmaModel(config)
+        self.pretraining_tp = 1
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_model(self):
+        return self.model
+
+    def get_lm_head(self):
+        return self.lm_head
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        seqlens_in_batch: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        images: Optional[torch.FloatTensor] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        if inputs_embeds is None:
+            (
+                input_ids,
+                position_ids,
+                attention_mask,
+                past_key_values,
+                inputs_embeds,
+                labels,
+            ) = self.prepare_inputs_labels_for_multimodal(
+                input_ids, position_ids, attention_mask, past_key_values, labels, images
+            )
+        # TODO (kentang-mit@): fuse this function into the previous one.
+        # current design makes unit-test easier.
+        if self.training:
+            (
+                _,
+                new_position_ids,
+                new_attention_mask,
+                _,
+                new_inputs_embeds,
+                new_labels,
+                sorted_seqlens_in_batch,
+            ) = self.repack_multimodal_data(
+                input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels
+            )
+            if sorted_seqlens_in_batch is None:
+                sorted_seqlens_in_batch = seqlens_in_batch
+            new_input_ids = None
+            past_key_values = None
+            new_cache_position = None
+        else:
+            new_attention_mask = attention_mask
+            new_position_ids = position_ids
+            new_inputs_embeds = inputs_embeds
+            new_labels = labels
+            if attention_mask is not None:
+                sorted_seqlens_in_batch = attention_mask.sum(-1).int()
+            else:
+                sorted_seqlens_in_batch = None
+            new_input_ids = input_ids
+            # kentang-mit@: This only works for batch=1 currently
+            # model.generate of gemma does not correctly handle decoding stage currently
+            # need to manually adjust decoding stage input = 1 token
+            if past_key_values is not None:
+                if new_inputs_embeds is not None:
+                    new_inputs_embeds = new_inputs_embeds[:, [-1]]
+                # kentang-mit@: seems to be a problem unique to gemma
+                if new_position_ids is not None:
+                    new_position_ids = new_position_ids[:, [-1]]
+                new_cache_position = new_position_ids[0]
+
+        outputs = super().forward(
+            input_ids=new_input_ids,
+            attention_mask=new_attention_mask,
+            position_ids=new_position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=new_inputs_embeds,
+            labels=new_labels,
+            use_cache=use_cache,
+            cache_position=new_cache_position,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            seqlens_in_batch=sorted_seqlens_in_batch,
+        )
+        return outputs
+
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+        images = kwargs.pop("images", None)
+        _inputs = super().prepare_inputs_for_generation(
+            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
+        )
+        if images is not None:
+            _inputs["images"] = images
+        return _inputs
+
+
+AutoConfig.register("llava_gemma", LlavaGemmaConfig)
+AutoModelForCausalLM.register(LlavaGemmaConfig, LlavaGemmaForCausalLM)
VILA/llava/model/language_model/llava_llama.py ADDED
@@ -0,0 +1,186 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+
+import inspect
+import os
+from typing import List, Optional, Tuple, Union
+
+import torch
+from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from ...train.utils import calculate_loss_weight
+from ..configuration_llava import LlavaConfig
+from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
+
+
+class LlavaLlamaConfig(LlavaConfig):
+    model_type = "llava_llama"
+
+
+## FIXME we will follow the convention to add a new class for CausalLM in the future
+class LlavaLlamaModel(LlavaMetaModel, LlavaMetaForCausalLM, PreTrainedModel):
+    config_class = LlavaLlamaConfig
+    main_input_name = "input_embeds"
+    supports_gradient_checkpointing = True
+
+    def __init__(self, config: LlavaLlamaConfig = None, *args, **kwargs) -> None:
+        super().__init__(config)
+        return self.init_vlm(config=config, *args, **kwargs)
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+        *model_args,
+        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
+        cache_dir: Optional[Union[str, os.PathLike]] = None,
+        ignore_mismatched_sizes: bool = False,
+        force_download: bool = False,
+        local_files_only: bool = False,
+        token: Optional[Union[str, bool]] = None,
+        revision: str = "main",
+        use_safetensors: bool = None,
+        **kwargs,
+    ):
+        if hasattr(cls, "load_pretrained"):
+            return cls.load_pretrained(
+                pretrained_model_name_or_path,
+                *model_args,
+                config=config,
+                cache_dir=cache_dir,
+                ignore_mismatched_sizes=ignore_mismatched_sizes,
+                force_download=force_download,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                use_safetensors=use_safetensors,
+                **kwargs,
+            )
+        return super(LlavaLlamaModel).from_pretrained(
+            pretrained_model_name_or_path,
+            *model_args,
+            config=config,
+            cache_dir=cache_dir,
+            ignore_mismatched_sizes=ignore_mismatched_sizes,
+            force_download=force_download,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            use_safetensors=use_safetensors,
+            **kwargs,
+        )
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        images: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        seqlens_in_batch: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        dpo_forward: bool = False,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        self.freezed_module_patch()
+        if inputs_embeds is None:
+            (
+                input_ids,
+                position_ids,
+                attention_mask,
+                past_key_values,
+                inputs_embeds,
+                labels,
+            ) = self.prepare_inputs_labels_for_multimodal(
+                input_ids, position_ids, attention_mask, past_key_values, labels, images
+            )
+
+        support_packing = "seqlens_in_batch" in inspect.signature(self.llm.forward).parameters
+
+        if self.training and support_packing and not dpo_forward:
+            (
+                _,
+                new_position_ids,
+                new_attention_mask,
+                _,
+                new_inputs_embeds,
+                new_labels,
+                sorted_seqlens_in_batch,
+            ) = self.repack_multimodal_data(
+                input_ids,
+                position_ids,
+                attention_mask,
+                past_key_values,
+                inputs_embeds,
+                labels,
+            )
+            if sorted_seqlens_in_batch is None:
+                sorted_seqlens_in_batch = seqlens_in_batch
+            new_input_ids = None
+            past_key_values = None
+        else:
+            new_attention_mask = attention_mask
+            new_position_ids = position_ids
+            new_inputs_embeds = inputs_embeds
+            new_labels = labels
+            sorted_seqlens_in_batch = attention_mask.sum(-1).int()
+            new_input_ids = input_ids
+
+        if support_packing:
+            outputs = self.llm.forward(
+                input_ids=new_input_ids,
+                attention_mask=new_attention_mask,
+                position_ids=new_position_ids,
+                past_key_values=past_key_values,
+                inputs_embeds=new_inputs_embeds,
+                labels=new_labels,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                seqlens_in_batch=sorted_seqlens_in_batch,
+            )
+        else:
+            outputs = self.llm.forward(
+                input_ids=new_input_ids,
+                attention_mask=new_attention_mask,
+                position_ids=new_position_ids,
+                past_key_values=past_key_values,
+                inputs_embeds=new_inputs_embeds,
+                labels=new_labels,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+
+        # Loss rescale for SP & DP loss match
+        loss_weight = calculate_loss_weight(new_labels)
+        outputs.loss = outputs.loss * loss_weight
+
+        if dpo_forward:
+            return outputs.logits, new_labels
+        return outputs
+
+
+AutoConfig.register("llava_llama", LlavaLlamaConfig)
+AutoModel.register(LlavaLlamaConfig, LlavaLlamaModel)
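Because the file ends by registering LlavaLlamaConfig and LlavaLlamaModel with the Auto classes, a checkpoint whose config.json declares model_type "llava_llama" can be resolved through the transformers Auto machinery once this module has been imported. A rough sketch of that dispatch (the repository id is illustrative only, and instantiating the full VLM will in practice pull in its vision tower and LLM submodules):

from transformers import AutoConfig, AutoModel

import llava.model.language_model.llava_llama  # noqa: F401  # importing runs the register() calls above

config = AutoConfig.from_pretrained("Efficient-Large-Model/VILA1.5-3b")  # illustrative repo id
assert config.model_type == "llava_llama"
model = AutoModel.from_config(config)  # dispatched to LlavaLlamaModel via the registration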
VILA/llava/model/language_model/llava_mistral.py ADDED
@@ -0,0 +1,137 @@
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from transformers import AutoConfig, AutoModelForCausalLM, MistralConfig, MistralForCausalLM, MistralModel
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
+from .modeling_mixtral_long_context import MixtralForCausalLM, MixtralModel
+
+
+class LlavaMistralConfig(MistralConfig):
+    model_type = "llava_mistral"
+    pretraining_tp = 1
+
+
+class LlavaMistralModel(MistralModel, LlavaMetaModel):
+    config_class = LlavaMistralConfig
+
+    def __init__(self, config: MistralConfig):
+        super().__init__(config)
+
+
+class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
+    config_class = LlavaMistralConfig
+
+    def __init__(self, config):
+        super(MistralForCausalLM, self).__init__(config)
+        self.model = LlavaMistralModel(config)
+        self.pretraining_tp = config.pretraining_tp
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_model(self):
+        return self.model
+
+    def get_lm_head(self):
+        return self.lm_head
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        seqlens_in_batch: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        images: Optional[torch.FloatTensor] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        if inputs_embeds is None:
+            (
+                input_ids,
+                position_ids,
+                attention_mask,
+                past_key_values,
+                inputs_embeds,
+                labels,
+            ) = self.prepare_inputs_labels_for_multimodal(
+                input_ids, position_ids, attention_mask, past_key_values, labels, images
+            )
+        if self.training:
+            (
+                _,
+                new_position_ids,
+                new_attention_mask,
+                _,
+                new_inputs_embeds,
+                new_labels,
+                sorted_seqlens_in_batch,
+            ) = self.repack_multimodal_data(
+                input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels
+            )
+            if sorted_seqlens_in_batch is None:
+                sorted_seqlens_in_batch = seqlens_in_batch
+            new_input_ids = None
+            past_key_values = None
+        else:
+            new_attention_mask = attention_mask
+            new_position_ids = position_ids
+            new_inputs_embeds = inputs_embeds
+            new_labels = labels
+            sorted_seqlens_in_batch = attention_mask.sum(-1).int()
+            new_input_ids = input_ids
+
+        outputs = super().forward(
+            input_ids=new_input_ids,
+            attention_mask=new_attention_mask,
+            position_ids=new_position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=new_inputs_embeds,
+            labels=new_labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            seqlens_in_batch=sorted_seqlens_in_batch,
+        )
+        return outputs
+
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+        images = kwargs.pop("images", None)
+        _inputs = super().prepare_inputs_for_generation(
+            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
+        )
+        if images is not None:
+            _inputs["images"] = images
+        return _inputs
+
+
+AutoConfig.register("llava_mistral", LlavaMistralConfig)
+AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
VILA/llava/model/language_model/llava_mixtral.py ADDED
@@ -0,0 +1,136 @@
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from transformers import AutoConfig, AutoModelForCausalLM, MixtralConfig, MixtralForCausalLM, MixtralModel
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
+
+
+class LlavaMixtralConfig(MixtralConfig):
+    model_type = "llava_mixtral"
+    pretraining_tp = 1
+
+
+class LlavaMixtralModel(MixtralModel, LlavaMetaModel):
+    config_class = LlavaMixtralConfig
+
+    def __init__(self, config: MixtralConfig):
+        super().__init__(config)
+
+
+class LlavaMixtralForCausalLM(MixtralForCausalLM, LlavaMetaForCausalLM):
+    config_class = LlavaMixtralConfig
+
+    def __init__(self, config):
+        super(MixtralForCausalLM, self).__init__(config)
+        self.model = LlavaMixtralModel(config)
+        self.pretraining_tp = config.pretraining_tp
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_model(self):
+        return self.model
+
+    def get_lm_head(self):
+        return self.lm_head
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        seqlens_in_batch: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        images: Optional[torch.FloatTensor] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        if inputs_embeds is None:
+            (
+                input_ids,
+                position_ids,
+                attention_mask,
+                past_key_values,
+                inputs_embeds,
+                labels,
+            ) = self.prepare_inputs_labels_for_multimodal(
+                input_ids, position_ids, attention_mask, past_key_values, labels, images
+            )
+        if self.training:
+            (
+                _,
+                new_position_ids,
+                new_attention_mask,
+                _,
+                new_inputs_embeds,
+                new_labels,
+                sorted_seqlens_in_batch,
+            ) = self.repack_multimodal_data(
+                input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels
+            )
+            if sorted_seqlens_in_batch is None:
+                sorted_seqlens_in_batch = seqlens_in_batch
+            new_input_ids = None
+            past_key_values = None
+        else:
+            new_attention_mask = attention_mask
+            new_position_ids = position_ids
+            new_inputs_embeds = inputs_embeds
+            new_labels = labels
+            sorted_seqlens_in_batch = attention_mask.sum(-1).int()
+            new_input_ids = input_ids
+
+        outputs = super().forward(
+            input_ids=new_input_ids,
+            attention_mask=new_attention_mask,
+            position_ids=new_position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=new_inputs_embeds,
+            labels=new_labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            seqlens_in_batch=sorted_seqlens_in_batch,
+        )
+        return outputs
+
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+        images = kwargs.pop("images", None)
+        _inputs = super().prepare_inputs_for_generation(
+            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
+        )
+        if images is not None:
+            _inputs["images"] = images
+        return _inputs
+
+
+AutoConfig.register("llava_mixtral", LlavaMixtralConfig)
+AutoModelForCausalLM.register(LlavaMixtralConfig, LlavaMixtralForCausalLM)
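In all of these wrappers, prepare_inputs_for_generation simply carries an extra images tensor alongside the usual text inputs, which is how pixel features reach forward() on every decoding step. A rough usage sketch, assuming a loaded model and tokenizer are already in scope and that the tokenizer knows an <image> placeholder token:

import torch

prompt = "<image>\nDescribe the scene."
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
images = torch.randn(1, 3, 384, 384, dtype=model.dtype, device=model.device)  # placeholder pixel tensor
output_ids = model.generate(input_ids, images=images, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))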
VILA/llava/model/language_model/llava_mpt.py ADDED
@@ -0,0 +1,160 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+
+import math
+import warnings
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from transformers import AutoConfig, AutoModelForCausalLM
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from llava.model.llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
+
+from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel
+
+
+class LlavaMPTConfig(MPTConfig):
+    model_type = "llava_mpt"
+
+
+class LlavaMPTModel(MPTModel, LlavaMetaModel):
+    config_class = LlavaMPTConfig
+
+    def __init__(self, config: MPTConfig):
+        config.hidden_size = config.d_model
+        super().__init__(config)
+
+    def embed_tokens(self, x):
+        return self.wte(x)
+
+
+class LlavaMPTForCausalLM(MPTForCausalLM, LlavaMetaForCausalLM):
+    config_class = LlavaMPTConfig
+    supports_gradient_checkpointing = True
+
+    def __init__(self, config):
+        super(MPTForCausalLM, self).__init__(config)
+
+        if not config.tie_word_embeddings:
+            raise ValueError("MPTForCausalLM only supports tied word embeddings")
+        self.transformer = LlavaMPTModel(config)
+        self.logit_scale = None
+        if config.logit_scale is not None:
+            logit_scale = config.logit_scale
+            if isinstance(logit_scale, str):
+                if logit_scale == "inv_sqrt_d_model":
+                    logit_scale = 1 / math.sqrt(config.d_model)
+                else:
+                    raise ValueError(
+                        f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
+                    )
+            self.logit_scale = logit_scale
+
+    def get_model(self):
+        return self.transformer
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, LlavaMPTModel):
+            module.gradient_checkpointing = value
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
+        attention_mask: Optional[torch.ByteTensor] = None,
+        prefix_mask: Optional[torch.ByteTensor] = None,
+        sequence_id: Optional[torch.LongTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        return_dict: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        use_cache: Optional[bool] = None,
+        images=None,
+    ):
+        return_dict = return_dict if return_dict is not None else self.config.return_dict
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        (
+            input_ids,
+            _,
+            attention_mask,
+            past_key_values,
+            inputs_embeds,
+            labels,
+        ) = self.prepare_inputs_labels_for_multimodal(input_ids, None, attention_mask, past_key_values, labels, images)
+        outputs = self.transformer(
+            input_ids=input_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            prefix_mask=prefix_mask,
+            sequence_id=sequence_id,
+            return_dict=return_dict,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
+        )
+        # FIXME: this is a hack to fix the multiple gpu inference issue in https://github.com/haotian-liu/LLaVA/issues/338
+        logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight)
+        if self.logit_scale is not None:
+            if self.logit_scale == 0:
+                warnings.warn(
+                    f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs."
+                )
+            logits *= self.logit_scale
+        loss = None
+        if labels is not None:
+            labels = torch.roll(labels, shifts=-1)
+            labels[:, -1] = -100
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
+        return CausalLMOutputWithPast(
+            loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states
+        )
+
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+        if inputs_embeds is not None:
+            raise NotImplementedError("inputs_embeds is not implemented for MPT yet")
+        attention_mask = kwargs["attention_mask"].bool()
+        if attention_mask[:, -1].sum() != attention_mask.shape[0]:
+            raise NotImplementedError("MPT does not support generation with right padding.")
+        if self.transformer.attn_uses_sequence_id and self.training:
+            sequence_id = torch.zeros_like(input_ids[:1])
+        else:
+            sequence_id = None
+        if past_key_values is not None:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+        if self.transformer.prefix_lm:
+            prefix_mask = torch.ones_like(attention_mask)
+            if kwargs.get("use_cache") == False:
+                raise NotImplementedError("MPT with prefix_lm=True does not support use_cache=False.")
+        else:
+            prefix_mask = None
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "prefix_mask": prefix_mask,
+            "sequence_id": sequence_id,
+            "past_key_values": past_key_values,
+            "use_cache": kwargs.get("use_cache", True),
+            "images": kwargs.get("images", None),
+        }
+
+
+AutoConfig.register("llava_mpt", LlavaMPTConfig)
+AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM)
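The MPT head above computes logits with F.linear against the shared embedding table (self.transformer.wte.weight) instead of a separate lm_head, i.e. weight tying. A small self-contained illustration of that pattern, unrelated to any specific checkpoint:

import torch
import torch.nn.functional as F

vocab_size, d_model = 100, 16
wte = torch.nn.Embedding(vocab_size, d_model)  # shared input embedding table
hidden = torch.randn(2, 5, d_model)            # [batch, seq, d_model] output of the transformer
logits = F.linear(hidden, wte.weight)          # tied output projection -> [batch, seq, vocab]
assert logits.shape == (2, 5, vocab_size)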
VILA/llava/model/language_model/modeling_mixtral_long_context.py ADDED
@@ -0,0 +1,1657 @@
1
+ # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
4
+ # and OPT implementations in this library. It has been modified from its
5
+ # original forms to accommodate minor architectural differences compared
6
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ """ PyTorch Mixtral model."""
20
+ import inspect
21
+ import math
22
+ import random
23
+ import warnings
24
+ from typing import List, Optional, Tuple, Union
25
+
26
+ import torch
27
+ import torch.nn.functional as F
28
+ import torch.utils.checkpoint
29
+ from torch import nn
30
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
+ from transformers.activations import ACT2FN
32
+ from transformers.cache_utils import Cache, DynamicCache
33
+ from transformers.modeling_attn_mask_utils import (
34
+ _prepare_4d_causal_attention_mask,
35
+ _prepare_4d_causal_attention_mask_for_sdpa,
36
+ )
37
+ from transformers.modeling_outputs import (
38
+ MoeCausalLMOutputWithPast,
39
+ MoeModelOutputWithPast,
40
+ SequenceClassifierOutputWithPast,
41
+ )
42
+ from transformers.modeling_utils import PreTrainedModel
43
+ from transformers.models.mixtral.configuration_mixtral import MixtralConfig
44
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13
45
+ from transformers.utils import (
46
+ add_start_docstrings,
47
+ add_start_docstrings_to_model_forward,
48
+ is_flash_attn_2_available,
49
+ is_flash_attn_greater_or_equal_2_10,
50
+ logging,
51
+ replace_return_docstrings,
52
+ )
53
+ from transformers.utils.import_utils import is_torch_fx_available
54
+
55
+ if is_flash_attn_2_available():
56
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
57
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
58
+
59
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
60
+
61
+ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
62
+ # It means that the function will not be traced through and simply appear as a node in the graph.
63
+ if is_torch_fx_available():
64
+ if not is_torch_greater_or_equal_than_1_13:
65
+ import torch.fx
66
+
67
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
68
+
69
+
70
+ logger = logging.get_logger(__name__)
71
+
72
+ _CONFIG_FOR_DOC = "MixtralConfig"
73
+
74
+
75
+ def load_balancing_loss_func(
76
+ gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
77
+ ) -> float:
78
+ r"""
79
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
80
+
81
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
82
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
83
+ experts is too unbalanced.
84
+
85
+ Args:
86
+ gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
87
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
88
+ shape [batch_size X sequence_length, num_experts].
89
+ attention_mask (`torch.Tensor`, None):
90
+ The attention_mask used in forward function
91
+ shape [batch_size X sequence_length] if not None.
92
+ num_experts (`int`, *optional*):
93
+ Number of experts
94
+
95
+ Returns:
96
+ The auxiliary loss.
97
+ """
98
+ if gate_logits is None or not isinstance(gate_logits, tuple):
99
+ return 0
100
+
101
+ if isinstance(gate_logits, tuple):
102
+ compute_device = gate_logits[0].device
103
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
104
+
105
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
106
+
107
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
108
+
109
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
110
+
111
+ if attention_mask is None:
112
+ # Compute the percentage of tokens routed to each experts
113
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
114
+
115
+ # Compute the average probability of routing to these experts
116
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
117
+ else:
118
+ batch_size, sequence_length = attention_mask.shape
119
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
120
+
121
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
122
+ expert_attention_mask = (
123
+ attention_mask[None, :, :, None, None]
124
+ .expand((num_hidden_layers, batch_size, sequence_length, 2, num_experts))
125
+ .reshape(-1, 2, num_experts)
126
+ .to(compute_device)
127
+ )
128
+
129
+ # Compute the percentage of tokens routed to each experts
130
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
131
+ expert_attention_mask, dim=0
132
+ )
133
+
134
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
135
+ router_per_expert_attention_mask = (
136
+ attention_mask[None, :, :, None]
137
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
138
+ .reshape(-1, num_experts)
139
+ .to(compute_device)
140
+ )
141
+
142
+ # Compute the average probability of routing to these experts
143
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
144
+ router_per_expert_attention_mask, dim=0
145
+ )
146
+
147
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
148
+ return overall_loss * num_experts
149
+
150
+
151
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
152
+ def _get_unpad_data(attention_mask):
153
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
154
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
155
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
156
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
157
+ return (
158
+ indices,
159
+ cu_seqlens,
160
+ max_seqlen_in_batch,
161
+ )
162
+
163
+
164
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral
165
+ class MixtralRMSNorm(nn.Module):
166
+ def __init__(self, hidden_size, eps=1e-6):
167
+ """
168
+ MixtralRMSNorm is equivalent to T5LayerNorm
169
+ """
170
+ super().__init__()
171
+ self.weight = nn.Parameter(torch.ones(hidden_size))
172
+ self.variance_epsilon = eps
173
+
174
+ def forward(self, hidden_states):
175
+ input_dtype = hidden_states.dtype
176
+ hidden_states = hidden_states.to(torch.float32)
177
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
178
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
179
+ return self.weight * hidden_states.to(input_dtype)
180
+
181
+
182
+ # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral
183
+ class MixtralRotaryEmbedding(nn.Module):
184
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
185
+ super().__init__()
186
+
187
+ self.dim = dim
188
+ self.max_position_embeddings = max_position_embeddings
189
+ self.base = base
190
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
191
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
192
+
193
+ # Build here to make `torch.jit.trace` work.
194
+ self._set_cos_sin_cache(
195
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
196
+ )
197
+
198
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
199
+ self.max_seq_len_cached = seq_len
200
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
201
+
202
+ freqs = torch.outer(t, self.inv_freq)
203
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
204
+ emb = torch.cat((freqs, freqs), dim=-1)
205
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
206
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
207
+
208
+ def forward(self, x, seq_len=None):
209
+ # x: [bs, num_attention_heads, seq_len, head_size]
210
+ if seq_len > self.max_seq_len_cached:
211
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
212
+
213
+ return (
214
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
215
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
216
+ )
217
+
218
+
219
+ class MixtralLinearScalingRotaryEmbedding(MixtralRotaryEmbedding):
220
+ """MixtralRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
221
+
222
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
223
+ self.scaling_factor = scaling_factor
224
+ super().__init__(dim, max_position_embeddings, base, device)
225
+
226
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
227
+ self.max_seq_len_cached = seq_len
228
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
229
+ t = t / self.scaling_factor
230
+
231
+ freqs = torch.outer(t, self.inv_freq)
232
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
233
+ emb = torch.cat((freqs, freqs), dim=-1)
234
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
235
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
236
+
237
+
238
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
239
+ def rotate_half(x):
240
+ """Rotates half the hidden dims of the input."""
241
+ x1 = x[..., : x.shape[-1] // 2]
242
+ x2 = x[..., x.shape[-1] // 2 :]
243
+ return torch.cat((-x2, x1), dim=-1)
244
+
245
+
246
+ # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
247
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
248
+ """Applies Rotary Position Embedding to the query and key tensors.
249
+
250
+ Args:
251
+ q (`torch.Tensor`): The query tensor.
252
+ k (`torch.Tensor`): The key tensor.
253
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
254
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
255
+ position_ids (`torch.Tensor`):
256
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
257
+ used to pass offsetted position ids when working with a KV-cache.
258
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
259
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
260
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
261
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
262
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
263
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
264
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
265
+ Returns:
266
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
267
+ """
268
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
269
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
270
+ q_embed = (q * cos) + (rotate_half(q) * sin)
271
+ k_embed = (k * cos) + (rotate_half(k) * sin)
272
+ return q_embed, k_embed
273
+
274
+
275
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
276
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
277
+ """
278
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
279
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
280
+ """
281
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
282
+ if n_rep == 1:
283
+ return hidden_states
284
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
285
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
286
+
287
+
288
+ # Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
289
+ class MixtralAttention(nn.Module):
290
+ """
291
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
292
+ and "Generating Long Sequences with Sparse Transformers".
293
+ """
294
+
295
+ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
296
+ super().__init__()
297
+ self.config = config
298
+ self.layer_idx = layer_idx
299
+ if layer_idx is None:
300
+ logger.warning_once(
301
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
302
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
303
+ "when creating this class."
304
+ )
305
+
306
+ self.hidden_size = config.hidden_size
307
+ self.num_heads = config.num_attention_heads
308
+ self.head_dim = self.hidden_size // self.num_heads
309
+ self.num_key_value_heads = config.num_key_value_heads
310
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
311
+ self.max_position_embeddings = config.max_position_embeddings
312
+ self.rope_theta = config.rope_theta
313
+ self.is_causal = True
314
+ self.attention_dropout = config.attention_dropout
315
+
316
+ if (self.head_dim * self.num_heads) != self.hidden_size:
317
+ raise ValueError(
318
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
319
+ f" and `num_heads`: {self.num_heads})."
320
+ )
321
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
322
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
323
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
324
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
325
+
326
+ self._init_rope()
327
+
328
+ def _init_rope(self):
329
+ if self.config.rope_scaling is None:
330
+ self.rotary_emb = MixtralRotaryEmbedding(
331
+ self.head_dim,
332
+ max_position_embeddings=self.max_position_embeddings,
333
+ base=self.rope_theta,
334
+ )
335
+ else:
336
+ scaling_type = self.config.rope_scaling["type"]
337
+ scaling_factor = self.config.rope_scaling["factor"]
338
+ if scaling_type == "linear":
339
+ self.rotary_emb = MixtralLinearScalingRotaryEmbedding(
340
+ self.head_dim,
341
+ max_position_embeddings=self.max_position_embeddings,
342
+ scaling_factor=scaling_factor,
343
+ base=self.rope_theta,
344
+ )
345
+ elif scaling_type == "randomlinear":
346
+ self.rotary_emb = None
347
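# Note: for the "randomlinear" scaling type no per-layer rotary embedding is created here. The
# flash-attention and SDPA forward paths below take an optional `rotary_emb` argument and only
# fall back to `self.rotary_emb` when it is None, so a shared rotary module is presumably
# supplied externally in that case.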
+ else:
348
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
349
+
350
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
351
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
352
+
353
+ def forward(
354
+ self,
355
+ hidden_states: torch.Tensor,
356
+ attention_mask: Optional[torch.Tensor] = None,
357
+ position_ids: Optional[torch.LongTensor] = None,
358
+ past_key_value: Optional[Cache] = None,
359
+ output_attentions: bool = False,
360
+ use_cache: bool = False,
361
+ **kwargs,
362
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
363
+ if "padding_mask" in kwargs:
364
+ warnings.warn(
365
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
366
+ )
367
+ bsz, q_len, _ = hidden_states.size()
368
+
369
+ query_states = self.q_proj(hidden_states)
370
+ key_states = self.k_proj(hidden_states)
371
+ value_states = self.v_proj(hidden_states)
372
+
373
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
374
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
375
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
376
+
377
+ kv_seq_len = key_states.shape[-2]
378
+ if past_key_value is not None:
379
+ if self.layer_idx is None:
380
+ raise ValueError(
381
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
382
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
383
+ "with a layer index."
384
+ )
385
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
386
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
387
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
388
+
389
+ if past_key_value is not None:
390
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
391
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
392
+
393
+ # repeat k/v heads if n_kv_heads < n_heads
394
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
395
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
396
+
397
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
398
+
399
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
400
+ raise ValueError(
401
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
402
+ f" {attn_weights.size()}"
403
+ )
404
+
405
+ if attention_mask is not None:
406
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
407
+ raise ValueError(
408
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
409
+ )
410
+
411
+ attn_weights = attn_weights + attention_mask
412
+
413
+ # upcast attention to fp32
414
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
415
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
416
+ attn_output = torch.matmul(attn_weights, value_states)
417
+
418
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
419
+ raise ValueError(
420
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
421
+ f" {attn_output.size()}"
422
+ )
423
+
424
+ attn_output = attn_output.transpose(1, 2).contiguous()
425
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
426
+
427
+ attn_output = self.o_proj(attn_output)
428
+
429
+ if not output_attentions:
430
+ attn_weights = None
431
+
432
+ return attn_output, attn_weights, past_key_value
433
+
434
+
435
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral
436
+ class MixtralFlashAttention2(MixtralAttention):
437
+ """
438
+ Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stay
439
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
440
+ flash attention and deal with padding tokens in case the input contains any of them.
441
+ """
442
+
443
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
444
+ def __init__(self, *args, **kwargs):
445
+ super().__init__(*args, **kwargs)
446
+
447
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
448
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
449
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
450
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
451
+
452
+ def forward(
453
+ self,
454
+ hidden_states: torch.Tensor,
455
+ attention_mask: Optional[torch.Tensor] = None,
456
+ position_ids: Optional[torch.LongTensor] = None,
457
+ past_key_value: Optional[Cache] = None,
458
+ output_attentions: bool = False,
459
+ use_cache: bool = False,
460
+ rotary_emb=None,
461
+ **kwargs,
462
+ ):
463
+ if "padding_mask" in kwargs:
464
+ warnings.warn(
465
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
466
+ )
467
+
468
+ # overwrite attention_mask with padding_mask
469
+ attention_mask = kwargs.pop("padding_mask")
470
+ bsz, q_len, _ = hidden_states.size()
471
+
472
+ query_states = self.q_proj(hidden_states)
473
+ key_states = self.k_proj(hidden_states)
474
+ value_states = self.v_proj(hidden_states)
475
+
476
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
477
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
478
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
479
+
480
+ kv_seq_len = key_states.shape[-2]
481
+ if past_key_value is not None:
482
+ if self.layer_idx is None:
483
+ raise ValueError(
484
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
485
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
486
+ "with a layer index."
487
+ )
488
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
489
+
490
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
491
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
492
+ if rotary_emb is None:
493
+ rotary_emb = self.rotary_emb
494
+ cos, sin = rotary_emb(value_states, seq_len=rotary_seq_len)
495
+
496
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
497
+
498
+ use_sliding_windows = (
499
+ _flash_supports_window_size
500
+ and getattr(self.config, "sliding_window", None) is not None
501
+ and kv_seq_len > self.config.sliding_window
502
+ )
503
+
504
+ if not _flash_supports_window_size:
505
+ logger.warning_once(
506
+ "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
507
+ " make sure to upgrade flash-attn library."
508
+ )
509
+
510
+ if past_key_value is not None:
511
+ # Activate cache slicing only if the config has a `sliding_window` attribute set
512
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
513
+ if (
514
+ getattr(self.config, "sliding_window", None) is not None
515
+ and kv_seq_len > self.config.sliding_window
516
+ and cache_has_contents
517
+ ):
518
+ slicing_tokens = 1 - self.config.sliding_window
519
+
520
+ past_key = past_key_value[self.layer_idx][0]
521
+ past_value = past_key_value[self.layer_idx][1]
522
+
523
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
524
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
525
+
526
+ if past_key.shape[-2] != self.config.sliding_window - 1:
527
+ raise ValueError(
528
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
529
+ f" {past_key.shape}"
530
+ )
531
+
532
+ if attention_mask is not None:
533
+ attention_mask = attention_mask[:, slicing_tokens:]
534
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
535
+
536
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
537
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
538
+
539
+ # repeat k/v heads if n_kv_heads < n_heads
540
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
541
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
542
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
543
+
544
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
545
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
546
+ # cast them back to the compute dtype (e.g. float16) just to be sure everything works as expected.
547
+ input_dtype = query_states.dtype
548
+ if input_dtype == torch.float32:
549
+ if torch.is_autocast_enabled():
550
+ target_dtype = torch.get_autocast_gpu_dtype()
551
+ # Handle the case where the model is quantized
552
+ elif hasattr(self.config, "_pre_quantization_dtype"):
553
+ target_dtype = self.config._pre_quantization_dtype
554
+ else:
555
+ target_dtype = self.q_proj.weight.dtype
556
+
557
+ logger.warning_once(
558
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
559
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
560
+ f" {target_dtype}."
561
+ )
562
+
563
+ query_states = query_states.to(target_dtype)
564
+ key_states = key_states.to(target_dtype)
565
+ value_states = value_states.to(target_dtype)
566
+
567
+ # Reshape to the expected shape for Flash Attention
568
+ query_states = query_states.transpose(1, 2)
569
+ key_states = key_states.transpose(1, 2)
570
+ value_states = value_states.transpose(1, 2)
571
+
572
+ attn_output = self._flash_attention_forward(
573
+ query_states,
574
+ key_states,
575
+ value_states,
576
+ attention_mask,
577
+ q_len,
578
+ dropout=dropout_rate,
579
+ use_sliding_windows=use_sliding_windows,
580
+ )
581
+
582
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
583
+ attn_output = self.o_proj(attn_output)
584
+
585
+ if not output_attentions:
586
+ attn_weights = None
587
+
588
+ return attn_output, attn_weights, past_key_value
589
+
590
+ def _flash_attention_forward(
591
+ self,
592
+ query_states,
593
+ key_states,
594
+ value_states,
595
+ attention_mask,
596
+ query_length,
597
+ dropout=0.0,
598
+ softmax_scale=None,
599
+ use_sliding_windows=False,
600
+ ):
601
+ """
602
+ Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
603
+ the input is first unpadded, the attention scores are computed, and the output is then padded back.
604
+
605
+ Args:
606
+ query_states (`torch.Tensor`):
607
+ Input query states to be passed to Flash Attention API
608
+ key_states (`torch.Tensor`):
609
+ Input key states to be passed to Flash Attention API
610
+ value_states (`torch.Tensor`):
611
+ Input value states to be passed to Flash Attention API
612
+ attention_mask (`torch.Tensor`):
613
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
614
+ position of padding tokens and 1 for the position of non-padding tokens.
615
+ dropout (`float`, *optional*):
616
+ Attention dropout
617
+ softmax_scale (`float`, *optional*):
618
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
619
+ use_sliding_windows (`bool`, *optional*):
620
+ Whether to activate sliding window attention.
621
+ """
622
+ if not self._flash_attn_uses_top_left_mask:
623
+ causal = self.is_causal
624
+ else:
625
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
626
+ causal = self.is_causal and query_length != 1
627
+
628
+ # Contains at least one padding token in the sequence
629
+ if attention_mask is not None:
630
+ batch_size = query_states.shape[0]
631
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
632
+ query_states, key_states, value_states, attention_mask, query_length
633
+ )
634
+
635
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
636
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
637
+
638
+ if not use_sliding_windows:
639
+ attn_output_unpad = flash_attn_varlen_func(
640
+ query_states,
641
+ key_states,
642
+ value_states,
643
+ cu_seqlens_q=cu_seqlens_q,
644
+ cu_seqlens_k=cu_seqlens_k,
645
+ max_seqlen_q=max_seqlen_in_batch_q,
646
+ max_seqlen_k=max_seqlen_in_batch_k,
647
+ dropout_p=dropout,
648
+ softmax_scale=softmax_scale,
649
+ causal=causal,
650
+ )
651
+ else:
652
+ attn_output_unpad = flash_attn_varlen_func(
653
+ query_states,
654
+ key_states,
655
+ value_states,
656
+ cu_seqlens_q=cu_seqlens_q,
657
+ cu_seqlens_k=cu_seqlens_k,
658
+ max_seqlen_q=max_seqlen_in_batch_q,
659
+ max_seqlen_k=max_seqlen_in_batch_k,
660
+ dropout_p=dropout,
661
+ softmax_scale=softmax_scale,
662
+ causal=causal,
663
+ window_size=(self.config.sliding_window, self.config.sliding_window),
664
+ )
665
+
666
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
667
+ else:
668
+ if not use_sliding_windows:
669
+ attn_output = flash_attn_func(
670
+ query_states,
671
+ key_states,
672
+ value_states,
673
+ dropout,
674
+ softmax_scale=softmax_scale,
675
+ causal=causal,
676
+ )
677
+ else:
678
+ attn_output = flash_attn_func(
679
+ query_states,
680
+ key_states,
681
+ value_states,
682
+ dropout,
683
+ softmax_scale=softmax_scale,
684
+ causal=causal,
685
+ window_size=(self.config.sliding_window, self.config.sliding_window),
686
+ )
687
+
688
+ return attn_output
689
+
690
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
691
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
692
+
693
+ # On the first iteration we need to properly re-create the padding mask
694
+ # by slicing it at the proper place
695
+ if kv_seq_len != attention_mask.shape[-1]:
696
+ attention_mask_num_tokens = attention_mask.shape[-1]
697
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
698
+
699
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
700
+
701
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
702
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
703
+
704
+ if query_length == kv_seq_len:
705
+ query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
706
+ cu_seqlens_q = cu_seqlens_k
707
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
708
+ indices_q = indices_k
709
+ elif query_length == 1:
710
+ max_seqlen_in_batch_q = 1
711
+ cu_seqlens_q = torch.arange(
712
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
713
+ ) # There is a memcpy here, that is very bad.
714
+ indices_q = cu_seqlens_q[:-1]
715
+ query_layer = query_layer.squeeze(1)
716
+ else:
717
+ # The -q_len: slice assumes left padding.
718
+ attention_mask = attention_mask[:, -query_length:]
719
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
720
+
721
+ return (
722
+ query_layer,
723
+ key_layer,
724
+ value_layer,
725
+ indices_q,
726
+ (cu_seqlens_q, cu_seqlens_k),
727
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
728
+ )
729
+
730
+
731
+ # Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral
732
+ class MixtralSdpaAttention(MixtralAttention):
733
+ """
734
+ Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
735
+ `MixtralAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
736
+ the SDPA API.
737
+ """
738
+
739
+ # Adapted from MixtralAttention.forward
740
+ def forward(
741
+ self,
742
+ hidden_states: torch.Tensor,
743
+ attention_mask: Optional[torch.Tensor] = None,
744
+ position_ids: Optional[torch.LongTensor] = None,
745
+ past_key_value: Optional[Cache] = None,
746
+ output_attentions: bool = False,
747
+ use_cache: bool = False,
748
+ rotary_emb=None,
749
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
750
+ if output_attentions:
751
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
752
+ logger.warning_once(
753
+ "MixtralModel is using MixtralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
754
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
755
+ )
756
+ return super().forward(
757
+ hidden_states=hidden_states,
758
+ attention_mask=attention_mask,
759
+ position_ids=position_ids,
760
+ past_key_value=past_key_value,
761
+ output_attentions=output_attentions,
762
+ use_cache=use_cache,
763
+ )
764
+
765
+ bsz, q_len, _ = hidden_states.size()
766
+
767
+ query_states = self.q_proj(hidden_states)
768
+ key_states = self.k_proj(hidden_states)
769
+ value_states = self.v_proj(hidden_states)
770
+
771
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
772
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
773
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
774
+
775
+ kv_seq_len = key_states.shape[-2]
776
+ if past_key_value is not None:
777
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
778
+
779
+ if rotary_emb is None:
780
+ rotary_emb = self.rotary_emb
781
+ cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
782
+
783
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
784
+
785
+ if past_key_value is not None:
786
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
787
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
788
+
789
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
790
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
791
+
792
+ if attention_mask is not None:
793
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
794
+ raise ValueError(
795
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
796
+ )
797
+
798
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
799
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
800
+ if query_states.device.type == "cuda" and attention_mask is not None:
801
+ query_states = query_states.contiguous()
802
+ key_states = key_states.contiguous()
803
+ value_states = value_states.contiguous()
804
+
805
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
806
+ query_states,
807
+ key_states,
808
+ value_states,
809
+ attn_mask=attention_mask,
810
+ dropout_p=self.attention_dropout if self.training else 0.0,
811
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
812
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
813
+ )
814
+
815
+ attn_output = attn_output.transpose(1, 2).contiguous()
816
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
817
+
818
+ attn_output = self.o_proj(attn_output)
819
+
820
+ return attn_output, None, past_key_value
821
+
822
+
823
+ MIXTRAL_ATTENTION_CLASSES = {
824
+ "eager": MixtralAttention,
825
+ "flash_attention_2": MixtralFlashAttention2,
826
+ "sdpa": MixtralSdpaAttention,
827
+ }
828
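For reference, a hedged sketch of how this dispatch table is consumed (the decoder layer below does exactly this via `config._attn_implementation`). The import of `MixtralConfig` from transformers and the toy sizes are assumptions for illustration; `rope_scaling` and `attention_dropout` are passed explicitly because this module reads them from the config:

    from transformers import MixtralConfig

    config = MixtralConfig(
        hidden_size=128, num_attention_heads=8, num_key_value_heads=2,
        rope_scaling=None, attention_dropout=0.0,   # attributes this module reads from the config
    )
    attn_impl = getattr(config, "_attn_implementation", "eager")
    attn_layer = MIXTRAL_ATTENTION_CLASSES[attn_impl](config, layer_idx=0)  # "eager" / "sdpa" / "flash_attention_2"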
+
829
+
830
+ class MixtralBlockSparseTop2MLP(nn.Module):
831
+ def __init__(self, config: MixtralConfig):
832
+ super().__init__()
833
+ self.ffn_dim = config.intermediate_size
834
+ self.hidden_dim = config.hidden_size
835
+
836
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
837
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
838
+ self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
839
+
840
+ self.act_fn = ACT2FN[config.hidden_act]
841
+
842
+ def forward(self, hidden_states):
843
+ current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
844
+ current_hidden_states = self.w2(current_hidden_states)
845
+ return current_hidden_states
846
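The expert MLP above is the usual gated ("SwiGLU"-style) feed-forward; a small illustrative equivalence check with toy sizes (the `MixtralConfig` import is an assumption for the sketch):

    import torch
    from transformers import MixtralConfig

    cfg = MixtralConfig(hidden_size=16, intermediate_size=32)   # default hidden_act="silu"
    mlp = MixtralBlockSparseTop2MLP(cfg)
    x = torch.randn(3, cfg.hidden_size)
    # forward computes w2(act(w1(x)) * w3(x))
    assert torch.allclose(mlp(x), mlp.w2(mlp.act_fn(mlp.w1(x)) * mlp.w3(x)))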
+
847
+
848
+ class MixtralBLockSparseTop2MLP(MixtralBlockSparseTop2MLP):
849
+ def __init__(self, *args, **kwargs):
850
+ logger.warning_once(
851
+ "MixtralBLockSparseTop2MLP is deprecated by MixtralBlockSparseTop2MLP and will be removed in v4.40."
852
+ )
853
+ super().__init__(*args, **kwargs)
854
+
855
+
856
+ class MixtralSparseMoeBlock(nn.Module):
857
+ """
858
+ This implementation is
859
+ strictly equivalent to standard MoE with full capacity (no
860
+ dropped tokens). It's faster since it formulates MoE operations
861
+ in terms of block-sparse operations to accommodate imbalanced
862
+ assignments of tokens to experts, whereas standard MoE either
863
+ (1) drops tokens at the cost of reduced performance or (2) sets the
864
+ capacity factor to the number of experts and thus wastes computation
865
+ and memory on padding.
866
+ """
867
+
868
+ def __init__(self, config):
869
+ super().__init__()
870
+ self.hidden_dim = config.hidden_size
871
+ self.ffn_dim = config.intermediate_size
872
+ self.num_experts = config.num_local_experts
873
+ self.top_k = config.num_experts_per_tok
874
+
875
+ # gating
876
+ self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
877
+
878
+ self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
879
+
880
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
881
+ """ """
882
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
883
+ hidden_states = hidden_states.view(-1, hidden_dim)
884
+ # router_logits: (batch * sequence_length, n_experts)
885
+ router_logits = self.gate(hidden_states)
886
+
887
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
888
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
889
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
890
+ # we cast back to the input dtype
891
+ routing_weights = routing_weights.to(hidden_states.dtype)
892
+
893
+ final_hidden_states = torch.zeros(
894
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
895
+ )
896
+
897
+ # One hot encode the selected experts to create an expert mask
898
+ # this will be used to easily index which expert is going to be solicited
899
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
900
+
901
+ # Loop over all available experts in the model and perform the computation on each expert
902
+ for expert_idx in range(self.num_experts):
903
+ expert_layer = self.experts[expert_idx]
904
+ idx, top_x = torch.where(expert_mask[expert_idx])
905
+
906
+ if top_x.shape[0] == 0:
907
+ if self.training:
908
+ top_x_ = torch.zeros(1).to(hidden_states.device).to(torch.int32)
909
+ top_x_list = top_x_.tolist()
910
+ current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
911
+ fake_state = expert_layer(current_state * 0)
912
+ final_hidden_states.index_add_(0, top_x_, fake_state.to(hidden_states.dtype))
913
+ continue
914
+
915
+ # in torch it is faster to index using lists than torch tensors
916
+ top_x_list = top_x.tolist()
917
+ idx_list = idx.tolist()
918
+
919
+ # Index the correct hidden states and compute the expert hidden state for
920
+ # the current expert. We need to make sure to multiply the output hidden
921
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
922
+ current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
923
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
924
+
925
+ # However `index_add_` only supports torch tensors for indexing so we'll use
926
+ # the `top_x` tensor here.
927
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
928
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
929
+ return final_hidden_states, router_logits
930
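To make the routing above concrete, a small standalone sketch (toy config, not part of the uploaded file; the `MixtralConfig` import is assumed) running the sparse MoE block and checking the shapes it produces:

    import torch
    from transformers import MixtralConfig

    cfg = MixtralConfig(hidden_size=32, intermediate_size=64,
                        num_local_experts=4, num_experts_per_tok=2)
    moe = MixtralSparseMoeBlock(cfg).eval()
    x = torch.randn(2, 5, cfg.hidden_size)                  # (batch, seq_len, hidden)
    out, router_logits = moe(x)
    assert out.shape == x.shape                             # tokens re-combined per routing weights
    assert router_logits.shape == (2 * 5, cfg.num_local_experts)  # one row of logits per token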
+
931
+
932
+ class MixtralDecoderLayer(nn.Module):
933
+ def __init__(self, config: MixtralConfig, layer_idx: int):
934
+ super().__init__()
935
+ self.hidden_size = config.hidden_size
936
+
937
+ self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
938
+
939
+ self.block_sparse_moe = MixtralSparseMoeBlock(config)
940
+ self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
941
+ self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
942
+
943
+ def forward(
944
+ self,
945
+ hidden_states: torch.Tensor,
946
+ attention_mask: Optional[torch.Tensor] = None,
947
+ position_ids: Optional[torch.LongTensor] = None,
948
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
949
+ output_attentions: Optional[bool] = False,
950
+ output_router_logits: Optional[bool] = False,
951
+ use_cache: Optional[bool] = False,
952
+ rotary_emb=None,
953
+ **kwargs,
954
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
955
+ if "padding_mask" in kwargs:
956
+ warnings.warn(
957
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
958
+ )
959
+ """
960
+ Args:
961
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
962
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
963
+ `(batch, sequence_length)` where padding elements are indicated by 0.
964
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
965
+ output_attentions (`bool`, *optional*):
966
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
967
+ returned tensors for more detail.
968
+ output_router_logits (`bool`, *optional*):
969
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
970
+ should not be returned during inference.
971
+ use_cache (`bool`, *optional*):
972
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
973
+ (see `past_key_values`).
974
+ """
975
+
976
+ residual = hidden_states
977
+
978
+ hidden_states = self.input_layernorm(hidden_states)
979
+
980
+ # Self Attention
981
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
982
+ hidden_states=hidden_states,
983
+ attention_mask=attention_mask,
984
+ position_ids=position_ids,
985
+ past_key_value=past_key_value,
986
+ output_attentions=output_attentions,
987
+ use_cache=use_cache,
988
+ rotary_emb=rotary_emb,
989
+ )
990
+ hidden_states = residual + hidden_states
991
+
992
+ # Fully Connected
993
+ residual = hidden_states
994
+ hidden_states = self.post_attention_layernorm(hidden_states)
995
+ hidden_states, router_logits = self.block_sparse_moe(hidden_states)
996
+ hidden_states = residual + hidden_states
997
+
998
+ outputs = (hidden_states,)
999
+
1000
+ if output_attentions:
1001
+ outputs += (self_attn_weights,)
1002
+
1003
+ if use_cache:
1004
+ outputs += (present_key_value,)
1005
+
1006
+ if output_router_logits:
1007
+ outputs += (router_logits,)
1008
+
1009
+ return outputs
1010
+
1011
+
1012
+ MIXTRAL_START_DOCSTRING = r"""
1013
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1014
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1015
+ etc.)
1016
+
1017
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1018
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1019
+ and behavior.
1020
+
1021
+ Parameters:
1022
+ config ([`MixtralConfig`]):
1023
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1024
+ load the weights associated with the model, only the configuration. Check out the
1025
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1026
+ """
1027
+
1028
+
1029
+ @add_start_docstrings(
1030
+ "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
1031
+ MIXTRAL_START_DOCSTRING,
1032
+ )
1033
+ # Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral
1034
+ class MixtralPreTrainedModel(PreTrainedModel):
1035
+ config_class = MixtralConfig
1036
+ base_model_prefix = "model"
1037
+ supports_gradient_checkpointing = True
1038
+ _no_split_modules = ["MixtralDecoderLayer"]
1039
+ _skip_keys_device_placement = "past_key_values"
1040
+ _supports_flash_attn_2 = True
1041
+ _supports_sdpa = True
1042
+ _supports_cache_class = True
1043
+
1044
+ def _init_weights(self, module):
1045
+ std = self.config.initializer_range
1046
+ if isinstance(module, nn.Linear):
1047
+ module.weight.data.normal_(mean=0.0, std=std)
1048
+ if module.bias is not None:
1049
+ module.bias.data.zero_()
1050
+ elif isinstance(module, nn.Embedding):
1051
+ module.weight.data.normal_(mean=0.0, std=std)
1052
+ if module.padding_idx is not None:
1053
+ module.weight.data[module.padding_idx].zero_()
1054
+
1055
+
1056
+ MIXTRAL_INPUTS_DOCSTRING = r"""
1057
+ Args:
1058
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1059
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1060
+ it.
1061
+
1062
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1063
+ [`PreTrainedTokenizer.__call__`] for details.
1064
+
1065
+ [What are input IDs?](../glossary#input-ids)
1066
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1067
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1068
+
1069
+ - 1 for tokens that are **not masked**,
1070
+ - 0 for tokens that are **masked**.
1071
+
1072
+ [What are attention masks?](../glossary#attention-mask)
1073
+
1074
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1075
+ [`PreTrainedTokenizer.__call__`] for details.
1076
+
1077
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
1078
+ `past_key_values`).
1079
+
1080
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1081
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1082
+ information on the default strategy.
1083
+
1084
+ - 1 indicates the head is **not masked**,
1085
+ - 0 indicates the head is **masked**.
1086
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1087
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1088
+ config.n_positions - 1]`.
1089
+
1090
+ [What are position IDs?](../glossary#position-ids)
1091
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1092
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
1093
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
1094
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
1095
+
1096
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1097
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1098
+
1099
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1100
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1101
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1102
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1103
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1104
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1105
+ model's internal embedding lookup matrix.
1106
+ use_cache (`bool`, *optional*):
1107
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1108
+ `past_key_values`).
1109
+ output_attentions (`bool`, *optional*):
1110
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1111
+ tensors for more detail.
1112
+ output_hidden_states (`bool`, *optional*):
1113
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1114
+ more detail.
1115
+ output_router_logits (`bool`, *optional*):
1116
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
1117
+ should not be returned during inference.
1118
+ return_dict (`bool`, *optional*):
1119
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1120
+ """
1121
+
1122
+
1123
+ @add_start_docstrings(
1124
+ "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
1125
+ MIXTRAL_START_DOCSTRING,
1126
+ )
1127
+ # Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral
1128
+ class MixtralModel(MixtralPreTrainedModel):
1129
+ """
1130
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]
1131
+
1132
+ Args:
1133
+ config: MixtralConfig
1134
+ """
1135
+
1136
+ def __init__(self, config: MixtralConfig):
1137
+ super().__init__(config)
1138
+ self.padding_idx = config.pad_token_id
1139
+ self.vocab_size = config.vocab_size
1140
+
1141
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1142
+ self.layers = nn.ModuleList(
1143
+ [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1144
+ )
1145
+ self._attn_implementation = config._attn_implementation
1146
+ self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1147
+
1148
+ self.gradient_checkpointing = False
1149
+ # Initialize weights and apply final processing
1150
+ self.post_init()
1151
+
1152
+ def get_input_embeddings(self):
1153
+ return self.embed_tokens
1154
+
1155
+ def set_input_embeddings(self, value):
1156
+ self.embed_tokens = value
1157
+
1158
+ # Ignore copy
1159
+ @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
1160
+ def forward(
1161
+ self,
1162
+ input_ids: torch.LongTensor = None,
1163
+ attention_mask: Optional[torch.Tensor] = None,
1164
+ position_ids: Optional[torch.LongTensor] = None,
1165
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1166
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1167
+ use_cache: Optional[bool] = None,
1168
+ output_attentions: Optional[bool] = None,
1169
+ output_hidden_states: Optional[bool] = None,
1170
+ output_router_logits: Optional[bool] = None,
1171
+ return_dict: Optional[bool] = None,
1172
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
1173
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1174
+ output_router_logits = (
1175
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
1176
+ )
1177
+ output_hidden_states = (
1178
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1179
+ )
1180
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1181
+
1182
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1183
+
1184
+ # retrieve input_ids and inputs_embeds
1185
+ if input_ids is not None and inputs_embeds is not None:
1186
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1187
+ elif input_ids is not None:
1188
+ batch_size, seq_length = input_ids.shape
1189
+ elif inputs_embeds is not None:
1190
+ batch_size, seq_length, _ = inputs_embeds.shape
1191
+ else:
1192
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1193
+
1194
+ past_key_values_length = 0
1195
+
1196
+ if self.gradient_checkpointing and self.training:
1197
+ if use_cache:
1198
+ logger.warning_once(
1199
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1200
+ )
1201
+ use_cache = False
1202
+
1203
+ if use_cache:
1204
+ use_legacy_cache = not isinstance(past_key_values, Cache)
1205
+ if use_legacy_cache:
1206
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1207
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
1208
+
1209
+ if position_ids is None:
1210
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1211
+ position_ids = torch.arange(
1212
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1213
+ )
1214
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1215
+ else:
1216
+ position_ids = position_ids.view(-1, seq_length).long()
1217
+
1218
+ if inputs_embeds is None:
1219
+ inputs_embeds = self.embed_tokens(input_ids)
1220
+
1221
+ if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
1222
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1223
+ if is_padding_right:
1224
+ raise ValueError(
1225
+ "You are attempting to perform batched generation with padding_side='right'"
1226
+ " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
1227
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1228
+ )
1229
+
1230
+ if self._attn_implementation == "flash_attention_2":
1231
+ # 2d mask is passed through the layers
1232
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1233
+ elif self._attn_implementation == "sdpa" and not output_attentions:
1234
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
1235
+ # the manual implementation that requires a 4D causal mask in all cases.
1236
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1237
+ attention_mask,
1238
+ (batch_size, seq_length),
1239
+ inputs_embeds,
1240
+ past_key_values_length,
1241
+ )
1242
+ else:
1243
+ # 4d mask is passed through the layers
1244
+ attention_mask = _prepare_4d_causal_attention_mask(
1245
+ attention_mask,
1246
+ (batch_size, seq_length),
1247
+ inputs_embeds,
1248
+ past_key_values_length,
1249
+ sliding_window=self.config.sliding_window,
1250
+ )
1251
+
1252
+ hidden_states = inputs_embeds
1253
+
1254
+ # decoder layers
1255
+ all_hidden_states = () if output_hidden_states else None
1256
+ all_self_attns = () if output_attentions else None
1257
+ all_router_logits = () if output_router_logits else None
1258
+ next_decoder_cache = None
1259
+
1260
+ rotary_emb = None
1261
+
1262
+ for decoder_layer in self.layers:
1263
+ if output_hidden_states:
1264
+ all_hidden_states += (hidden_states,)
1265
+
1266
+ if self.gradient_checkpointing and self.training:
1267
+ layer_outputs = self._gradient_checkpointing_func(
1268
+ decoder_layer.__call__,
1269
+ hidden_states,
1270
+ attention_mask,
1271
+ position_ids,
1272
+ past_key_values,
1273
+ output_attentions,
1274
+ output_router_logits,
1275
+ use_cache,
1276
+ rotary_emb,
1277
+ )
1278
+ else:
1279
+ layer_outputs = decoder_layer(
1280
+ hidden_states,
1281
+ attention_mask=attention_mask,
1282
+ position_ids=position_ids,
1283
+ past_key_value=past_key_values,
1284
+ output_attentions=output_attentions,
1285
+ output_router_logits=output_router_logits,
1286
+ use_cache=use_cache,
1287
+ rotary_emb=rotary_emb,
1288
+ )
1289
+
1290
+ hidden_states = layer_outputs[0]
1291
+
1292
+ if use_cache:
1293
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1294
+
1295
+ if output_attentions:
1296
+ all_self_attns += (layer_outputs[1],)
1297
+
1298
+ if output_router_logits:
1299
+ all_router_logits += (layer_outputs[-1],)
1300
+
1301
+ hidden_states = self.norm(hidden_states)
1302
+
1303
+ # add hidden states from the last decoder layer
1304
+ if output_hidden_states:
1305
+ all_hidden_states += (hidden_states,)
1306
+
1307
+ next_cache = None
1308
+ if use_cache:
1309
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1310
+
1311
+ if not return_dict:
1312
+ return tuple(
1313
+ v
1314
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
1315
+ if v is not None
1316
+ )
1317
+ return MoeModelOutputWithPast(
1318
+ last_hidden_state=hidden_states,
1319
+ past_key_values=next_cache,
1320
+ hidden_states=all_hidden_states,
1321
+ attentions=all_self_attns,
1322
+ router_logits=all_router_logits,
1323
+ )
1324
+
1325
+
1326
+ class MixtralForCausalLM(MixtralPreTrainedModel):
1327
+ _tied_weights_keys = ["lm_head.weight"]
1328
+
1329
+ def __init__(self, config):
1330
+ super().__init__(config)
1331
+ self.model = MixtralModel(config)
1332
+ self.vocab_size = config.vocab_size
1333
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1334
+ self.router_aux_loss_coef = config.router_aux_loss_coef
1335
+ self.num_experts = config.num_local_experts
1336
+ self.num_experts_per_tok = config.num_experts_per_tok
1337
+ # Initialize weights and apply final processing
1338
+ self.post_init()
1339
+
1340
+ def get_input_embeddings(self):
1341
+ return self.model.embed_tokens
1342
+
1343
+ def set_input_embeddings(self, value):
1344
+ self.model.embed_tokens = value
1345
+
1346
+ def get_output_embeddings(self):
1347
+ return self.lm_head
1348
+
1349
+ def set_output_embeddings(self, new_embeddings):
1350
+ self.lm_head = new_embeddings
1351
+
1352
+ def set_decoder(self, decoder):
1353
+ self.model = decoder
1354
+
1355
+ def get_decoder(self):
1356
+ return self.model
1357
+
1358
+ @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
1359
+ @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1360
+ # Ignore copy
1361
+ def forward(
1362
+ self,
1363
+ input_ids: torch.LongTensor = None,
1364
+ attention_mask: Optional[torch.Tensor] = None,
1365
+ position_ids: Optional[torch.LongTensor] = None,
1366
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1367
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1368
+ labels: Optional[torch.LongTensor] = None,
1369
+ use_cache: Optional[bool] = None,
1370
+ output_attentions: Optional[bool] = None,
1371
+ output_hidden_states: Optional[bool] = None,
1372
+ output_router_logits: Optional[bool] = None,
1373
+ return_dict: Optional[bool] = None,
1374
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1375
+ r"""
1376
+ Args:
1377
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1378
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1379
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1380
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1381
+
1382
+ Returns:
1383
+
1384
+ Example:
1385
+
1386
+ ```python
1387
+ >>> from transformers import AutoTokenizer, MixtralForCausalLM
1388
+
1389
+ >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
1390
+ >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
1391
+
1392
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1393
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1394
+
1395
+ >>> # Generate
1396
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1397
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1398
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1399
+ ```"""
1400
+
1401
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1402
+ output_router_logits = (
1403
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
1404
+ )
1405
+
1406
+ output_hidden_states = (
1407
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1408
+ )
1409
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1410
+
1411
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1412
+ outputs = self.model(
1413
+ input_ids=input_ids,
1414
+ attention_mask=attention_mask,
1415
+ position_ids=position_ids,
1416
+ past_key_values=past_key_values,
1417
+ inputs_embeds=inputs_embeds,
1418
+ use_cache=use_cache,
1419
+ output_attentions=output_attentions,
1420
+ output_hidden_states=output_hidden_states,
1421
+ output_router_logits=output_router_logits,
1422
+ return_dict=return_dict,
1423
+ )
1424
+
1425
+ hidden_states = outputs[0]
1426
+ logits = self.lm_head(hidden_states)
1427
+ logits = logits.float()
1428
+
1429
+ loss = None
1430
+ if labels is not None:
1431
+ # Shift so that tokens < n predict n
1432
+ shift_logits = logits[..., :-1, :].contiguous()
1433
+ shift_labels = labels[..., 1:].contiguous()
1434
+ # Flatten the tokens
1435
+ loss_fct = CrossEntropyLoss()
1436
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1437
+ shift_labels = shift_labels.view(-1)
1438
+ # Enable model parallelism
1439
+ shift_labels = shift_labels.to(shift_logits.device)
1440
+ loss = loss_fct(shift_logits, shift_labels)
1441
+
1442
+ aux_loss = None
1443
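# Note: the auxiliary load-balancing (router) loss is disabled in this file; the guard below is
# hard-coded to False, with the original `output_router_logits` condition kept as a comment.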
+ if False: # output_router_logits:
1444
+ aux_loss = load_balancing_loss_func(
1445
+ outputs.router_logits if return_dict else outputs[-1],
1446
+ self.num_experts,
1447
+ self.num_experts_per_tok,
1448
+ attention_mask,
1449
+ )
1450
+ if labels is not None:
1451
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure it resides on the same device
1452
+
1453
+ if not return_dict:
1454
+ output = (logits,) + outputs[1:]
1455
+ if output_router_logits:
1456
+ output = (aux_loss,) + output
1457
+ return (loss,) + output if loss is not None else output
1458
+
1459
+ return MoeCausalLMOutputWithPast(
1460
+ loss=loss,
1461
+ aux_loss=aux_loss,
1462
+ logits=logits,
1463
+ past_key_values=outputs.past_key_values,
1464
+ hidden_states=outputs.hidden_states,
1465
+ attentions=outputs.attentions,
1466
+ router_logits=outputs.router_logits,
1467
+ )
1468
+
1469
+ def prepare_inputs_for_generation(
1470
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1471
+ ):
1472
+ # Omit tokens covered by past_key_values
1473
+ if past_key_values is not None:
1474
+ if isinstance(past_key_values, Cache):
1475
+ cache_length = past_key_values.get_seq_length()
1476
+ past_length = past_key_values.seen_tokens
1477
+ max_cache_length = past_key_values.get_max_length()
1478
+ else:
1479
+ cache_length = past_length = past_key_values[0][0].shape[2]
1480
+ max_cache_length = None
1481
+
1482
+ # Keep only the unprocessed tokens:
1483
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1484
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1485
+ # input)
1486
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1487
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1488
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1489
+ # input_ids based on the past_length.
1490
+ elif past_length < input_ids.shape[1]:
1491
+ input_ids = input_ids[:, past_length:]
1492
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1493
+
1494
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1495
+ if (
1496
+ max_cache_length is not None
1497
+ and attention_mask is not None
1498
+ and cache_length + input_ids.shape[1] > max_cache_length
1499
+ ):
1500
+ attention_mask = attention_mask[:, -max_cache_length:]
1501
+
1502
+ position_ids = kwargs.get("position_ids", None)
1503
+ if attention_mask is not None and position_ids is None:
1504
+ # create position_ids on the fly for batch generation
1505
+ position_ids = attention_mask.long().cumsum(-1) - 1
1506
+ position_ids.masked_fill_(attention_mask == 0, 1)
1507
+ if past_key_values:
1508
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1509
+
1510
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1511
+ if inputs_embeds is not None and past_key_values is None:
1512
+ model_inputs = {"inputs_embeds": inputs_embeds}
1513
+ else:
1514
+ model_inputs = {"input_ids": input_ids}
1515
+
1516
+ model_inputs.update(
1517
+ {
1518
+ "position_ids": position_ids,
1519
+ "past_key_values": past_key_values,
1520
+ "use_cache": kwargs.get("use_cache"),
1521
+ "attention_mask": attention_mask,
1522
+ }
1523
+ )
1524
+ return model_inputs
1525
+
1526
+ @staticmethod
1527
+ def _reorder_cache(past_key_values, beam_idx):
1528
+ reordered_past = ()
1529
+ for layer_past in past_key_values:
1530
+ reordered_past += (
1531
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1532
+ )
1533
+ return reordered_past
1534
+
1535
+
1536
+ @add_start_docstrings(
1537
+ """
1538
+ The Mixtral Model transformer with a sequence classification head on top (linear layer).
1539
+
1540
+ [`MixtralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1541
+ (e.g. GPT-2) do.
1542
+
1543
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1544
+     `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+     no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+     padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+     each row of the batch).
+     """,
+     MIXTRAL_START_DOCSTRING,
+ )
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL
+ class MixtralForSequenceClassification(MixtralPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.model = MixtralModel(config)
+         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+
+     @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         transformer_outputs = self.model(
+             input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         hidden_states = transformer_outputs[0]
+         logits = self.score(hidden_states)
+
+         if input_ids is not None:
+             batch_size = input_ids.shape[0]
+         else:
+             batch_size = inputs_embeds.shape[0]
+
+         if self.config.pad_token_id is None and batch_size != 1:
+             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+         if self.config.pad_token_id is None:
+             sequence_lengths = -1
+         else:
+             if input_ids is not None:
+                 # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                 sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                 sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                 sequence_lengths = sequence_lengths.to(logits.device)
+             else:
+                 sequence_lengths = -1
+
+         pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+         loss = None
+         if labels is not None:
+             labels = labels.to(logits.device)
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(pooled_logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(pooled_logits, labels)
+         if not return_dict:
+             output = (pooled_logits,) + transformer_outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutputWithPast(
+             loss=loss,
+             logits=pooled_logits,
+             past_key_values=transformer_outputs.past_key_values,
+             hidden_states=transformer_outputs.hidden_states,
+             attentions=transformer_outputs.attentions,
+         )
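
The docstring above describes how the classification head picks one position per row: the last non-padding token when `pad_token_id` is known, otherwise the last position. A minimal standalone sketch of that pooling rule (made-up example tensors, outside the model) is:

# Sketch of the last-non-padding-token pooling used above; values are illustrative only.
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],    # 3 real tokens
                          [8, 9, 0, 0, 0]])   # 2 real tokens
logits = torch.randn(2, 5, 3)                 # (batch, seq_len, num_labels)

# argmax(-1) finds the first pad position; subtracting 1 gives the last real token.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
# The modulo keeps the index non-negative when a row contains no padding at all.
sequence_lengths = sequence_lengths % input_ids.shape[-1]

pooled = logits[torch.arange(input_ids.shape[0]), sequence_lengths]
print(sequence_lengths.tolist())  # [2, 1]
print(pooled.shape)               # torch.Size([2, 3])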
VILA/llava/model/language_model/mpt/adapt_tokenizer.py ADDED
@@ -0,0 +1,61 @@
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # SPDX-License-Identifier: Apache-2.0
+
+ from typing import Union
+
+ from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
+
+ Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
+ NUM_SENTINEL_TOKENS: int = 100
+
+
+ def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
+     """Adds sentinel tokens and padding token (if missing).
+
+     Expands the tokenizer vocabulary to include sentinel tokens
+     used in mixture-of-denoiser tasks as well as a padding token.
+
+     All added tokens are added as special tokens. No tokens are
+     added if sentinel tokens and padding token already exist.
+     """
+     sentinels_to_add = [f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)]
+     tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
+     if tokenizer.pad_token is None:
+         tokenizer.add_tokens("<pad>", special_tokens=True)
+         tokenizer.pad_token = "<pad>"
+         assert tokenizer.pad_token_id is not None
+     sentinels = "".join([f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)])
+     _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
+     tokenizer.sentinel_token_ids = _sentinel_token_ids
+
+
+ class AutoTokenizerForMOD(AutoTokenizer):
+     """AutoTokenizer + Adaptation for MOD.
+
+     A simple wrapper around AutoTokenizer to make instantiating
+     an MOD-adapted tokenizer a bit easier.
+
+     MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>),
+     a padding token, and a property to get the token ids of the
+     sentinel tokens.
+     """
+
+     @classmethod
+     def from_pretrained(cls, *args, **kwargs):
+         """See `AutoTokenizer.from_pretrained` docstring."""
+         tokenizer = super().from_pretrained(*args, **kwargs)
+         adapt_tokenizer_for_denoising(tokenizer)
+         return tokenizer
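
A brief usage sketch of the adapter above, assuming a standard Hugging Face tokenizer checkpoint (the checkpoint name below is only an example, not something this file requires):

# Hypothetical usage of adapt_tokenizer_for_denoising; "EleutherAI/gpt-neox-20b" is just an example.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
adapt_tokenizer_for_denoising(tokenizer)

# The tokenizer now carries 100 sentinel tokens, a pad token, and their cached ids.
print(tokenizer.pad_token)                # "<pad>" (added only if it was missing)
print(len(tokenizer.sentinel_token_ids))  # 100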
VILA/llava/model/language_model/mpt/attention.py ADDED
@@ -0,0 +1,480 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ """Attention layers."""
18
+ import math
19
+ import warnings
20
+ from typing import Optional
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ from einops import rearrange
25
+ from packaging import version
26
+ from torch import nn
27
+
28
+ from .norm import LPLayerNorm
29
+
30
+
31
+ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool):
32
+ if original_is_causal and num_query_tokens != num_key_tokens:
33
+ if num_query_tokens != 1:
34
+ raise NotImplementedError(
35
+ "MPT does not support query and key with different number of tokens, unless number of query tokens is 1."
36
+ )
37
+ else:
38
+ return False
39
+ return original_is_causal
40
+
41
+
42
+ def scaled_multihead_dot_product_attention(
43
+ query,
44
+ key,
45
+ value,
46
+ n_heads,
47
+ past_key_value=None,
48
+ softmax_scale=None,
49
+ attn_bias=None,
50
+ key_padding_mask=None,
51
+ is_causal=False,
52
+ dropout_p=0.0,
53
+ training=False,
54
+ needs_weights=False,
55
+ multiquery=False,
56
+ ):
57
+ q = rearrange(query, "b s (h d) -> b h s d", h=n_heads)
58
+ kv_n_heads = 1 if multiquery else n_heads
59
+ k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads)
60
+ v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads)
61
+ if past_key_value is not None:
62
+ if len(past_key_value) != 0:
63
+ k = torch.cat([past_key_value[0], k], dim=3)
64
+ v = torch.cat([past_key_value[1], v], dim=2)
65
+ past_key_value = (k, v)
66
+ (b, _, s_q, d) = q.shape
67
+ s_k = k.size(-1)
68
+ if softmax_scale is None:
69
+ softmax_scale = 1 / math.sqrt(d)
70
+ attn_weight = q.matmul(k) * softmax_scale
71
+ if attn_bias is not None:
72
+ _s_q = max(0, attn_bias.size(2) - s_q)
73
+ _s_k = max(0, attn_bias.size(3) - s_k)
74
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
75
+ if (
76
+ attn_bias.size(-1) != 1
77
+ and attn_bias.size(-1) != s_k
78
+ or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q)
79
+ ):
80
+ raise RuntimeError(
81
+ f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}."
82
+ )
83
+ attn_weight = attn_weight + attn_bias
84
+ min_val = torch.finfo(q.dtype).min
85
+ if key_padding_mask is not None:
86
+ if attn_bias is not None:
87
+ warnings.warn(
88
+ "Propagating key_padding_mask to the attention module "
89
+ + "and applying it within the attention module can cause "
90
+ + "unneccessary computation/memory usage. Consider integrating "
91
+ "unnecessary computation/memory usage. Consider integrating "
92
+ + "module instead."
93
+ )
94
+ attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
95
+ if is_causal and (not q.size(2) == 1):
96
+ s = max(s_q, s_k)
97
+ causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
98
+ causal_mask = causal_mask.tril()
99
+ causal_mask = causal_mask.to(torch.bool)
100
+ causal_mask = ~causal_mask
101
+ causal_mask = causal_mask[-s_q:, -s_k:]
102
+ attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
103
+ attn_weight = torch.softmax(attn_weight, dim=-1)
104
+ if dropout_p:
105
+ attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
106
+ out = attn_weight.to(v.dtype).matmul(v)
107
+ out = rearrange(out, "b h s d -> b s (h d)")
108
+ if needs_weights:
109
+ return (out, attn_weight, past_key_value)
110
+ return (out, None, past_key_value)
111
+
112
+
113
+ def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
114
+ for tensor in tensors:
115
+ if tensor.dtype not in valid_dtypes:
116
+ raise TypeError(f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.")
117
+ if not tensor.is_cuda:
118
+ raise TypeError(f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).")
119
+
120
+
121
+ def flash_attn_fn(
122
+ query,
123
+ key,
124
+ value,
125
+ n_heads,
126
+ past_key_value=None,
127
+ softmax_scale=None,
128
+ attn_bias=None,
129
+ key_padding_mask=None,
130
+ is_causal=False,
131
+ dropout_p=0.0,
132
+ training=False,
133
+ needs_weights=False,
134
+ multiquery=False,
135
+ ):
136
+ try:
137
+ from flash_attn import bert_padding, flash_attn_interface
138
+ except:
139
+ raise RuntimeError("Please install flash-attn==1.0.3.post0")
140
+ check_valid_inputs(query, key, value)
141
+ if past_key_value is not None:
142
+ if len(past_key_value) != 0:
143
+ key = torch.cat([past_key_value[0], key], dim=1)
144
+ value = torch.cat([past_key_value[1], value], dim=1)
145
+ past_key_value = (key, value)
146
+ if attn_bias is not None:
147
+ _s_q = max(0, attn_bias.size(2) - query.size(1))
148
+ _s_k = max(0, attn_bias.size(3) - key.size(1))
149
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
150
+ if attn_bias is not None:
151
+ raise NotImplementedError(f"attn_bias not implemented for flash attn.")
152
+ (batch_size, seqlen) = query.shape[:2]
153
+ if key_padding_mask is None:
154
+ key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
155
+ query_padding_mask = key_padding_mask[:, -query.size(1) :]
156
+ (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask)
157
+ query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads)
158
+ (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask)
159
+ key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads)
160
+ (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
161
+ value_unpad = rearrange(value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads)
162
+ if multiquery:
163
+ key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
164
+ value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1))
165
+ dropout_p = dropout_p if training else 0.0
166
+ reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
167
+ output_unpad = flash_attn_interface.flash_attn_unpadded_func(
168
+ query_unpad,
169
+ key_unpad,
170
+ value_unpad,
171
+ cu_seqlens_q,
172
+ cu_seqlens_k,
173
+ max_seqlen_q,
174
+ max_seqlen_k,
175
+ dropout_p,
176
+ softmax_scale=softmax_scale,
177
+ causal=reset_is_causal,
178
+ return_attn_probs=needs_weights,
179
+ )
180
+ output = bert_padding.pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen)
181
+ return (output, None, past_key_value)
182
+
183
+
184
+ def triton_flash_attn_fn(
185
+ query,
186
+ key,
187
+ value,
188
+ n_heads,
189
+ past_key_value=None,
190
+ softmax_scale=None,
191
+ attn_bias=None,
192
+ key_padding_mask=None,
193
+ is_causal=False,
194
+ dropout_p=0.0,
195
+ training=False,
196
+ needs_weights=False,
197
+ multiquery=False,
198
+ ):
199
+ try:
200
+ from .flash_attn_triton import flash_attn_func
201
+ except:
202
+ _installed = False
203
+ if version.parse(torch.__version__) < version.parse("2.0.0"):
204
+ _installed = True
205
+ try:
206
+ from flash_attn.flash_attn_triton import flash_attn_func
207
+ except:
208
+ _installed = False
209
+ if not _installed:
210
+ raise RuntimeError(
211
+ "Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed."
212
+ )
213
+ check_valid_inputs(query, key, value)
214
+ if past_key_value is not None:
215
+ if len(past_key_value) != 0:
216
+ key = torch.cat([past_key_value[0], key], dim=1)
217
+ value = torch.cat([past_key_value[1], value], dim=1)
218
+ past_key_value = (key, value)
219
+ if attn_bias is not None:
220
+ _s_q = max(0, attn_bias.size(2) - query.size(1))
221
+ _s_k = max(0, attn_bias.size(3) - key.size(1))
222
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
223
+ if dropout_p:
224
+ raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.")
225
+ if needs_weights:
226
+ raise NotImplementedError(f"attn_impl: triton cannot return attn weights.")
227
+ if key_padding_mask is not None:
228
+ warnings.warn(
229
+ "Propagating key_padding_mask to the attention module "
230
+ + "and applying it within the attention module can cause "
231
+ + "unnecessary computation/memory usage. Consider integrating "
232
+ + "into attn_bias once and passing that to each attention "
233
+ + "module instead."
234
+ )
235
+ (b_size, s_k) = key_padding_mask.shape[:2]
236
+ if attn_bias is None:
237
+ attn_bias = query.new_zeros(b_size, 1, 1, s_k)
238
+ attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min)
239
+ query = rearrange(query, "b s (h d) -> b s h d", h=n_heads)
240
+ key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
241
+ value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
242
+ if multiquery:
243
+ key = key.expand(*key.shape[:2], n_heads, key.size(-1))
244
+ value = value.expand(*value.shape[:2], n_heads, value.size(-1))
245
+ reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
246
+ attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
247
+ output = attn_output.view(*attn_output.shape[:2], -1)
248
+ return (output, None, past_key_value)
249
+
250
+
251
+ class MultiheadAttention(nn.Module):
252
+ """Multi-head self attention.
253
+
254
+ Using torch or triton attention implementation enables user to also use
255
+ additive bias.
256
+ """
257
+
258
+ def __init__(
259
+ self,
260
+ d_model: int,
261
+ n_heads: int,
262
+ attn_impl: str = "triton",
263
+ clip_qkv: Optional[float] = None,
264
+ qk_ln: bool = False,
265
+ softmax_scale: Optional[float] = None,
266
+ attn_pdrop: float = 0.0,
267
+ low_precision_layernorm: bool = False,
268
+ verbose: int = 0,
269
+ device: Optional[str] = None,
270
+ ):
271
+ super().__init__()
272
+ self.attn_impl = attn_impl
273
+ self.clip_qkv = clip_qkv
274
+ self.qk_ln = qk_ln
275
+ self.d_model = d_model
276
+ self.n_heads = n_heads
277
+ self.softmax_scale = softmax_scale
278
+ if self.softmax_scale is None:
279
+ self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
280
+ self.attn_dropout_p = attn_pdrop
281
+ self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
282
+ fuse_splits = (d_model, 2 * d_model)
283
+ self.Wqkv._fused = (0, fuse_splits)
284
+ if self.qk_ln:
285
+ layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
286
+ self.q_ln = layernorm_class(self.d_model, device=device)
287
+ self.k_ln = layernorm_class(self.d_model, device=device)
288
+ if self.attn_impl == "flash":
289
+ self.attn_fn = flash_attn_fn
290
+ elif self.attn_impl == "triton":
291
+ self.attn_fn = triton_flash_attn_fn
292
+ if verbose:
293
+ warnings.warn(
294
+ "While `attn_impl: triton` can be faster than `attn_impl: flash` "
295
+ + "it uses more memory. When training larger models this can trigger "
296
+ + "alloc retries which hurts performance. If encountered, we recommend "
297
+ + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
298
+ )
299
+ elif self.attn_impl == "torch":
300
+ self.attn_fn = scaled_multihead_dot_product_attention
301
+ if torch.cuda.is_available() and verbose:
302
+ warnings.warn(
303
+ "Using `attn_impl: torch`. If your model does not use `alibi` or "
304
+ + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
305
+ + "we recommend using `attn_impl: triton`."
306
+ )
307
+ else:
308
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
309
+ self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
310
+ self.out_proj._is_residual = True
311
+
312
+ def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
313
+ qkv = self.Wqkv(x)
314
+ if self.clip_qkv:
315
+ qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
316
+ (query, key, value) = qkv.chunk(3, dim=2)
317
+ key_padding_mask = attention_mask
318
+ if self.qk_ln:
319
+ dtype = query.dtype
320
+ query = self.q_ln(query).to(dtype)
321
+ key = self.k_ln(key).to(dtype)
322
+ (context, attn_weights, past_key_value) = self.attn_fn(
323
+ query,
324
+ key,
325
+ value,
326
+ self.n_heads,
327
+ past_key_value=past_key_value,
328
+ softmax_scale=self.softmax_scale,
329
+ attn_bias=attn_bias,
330
+ key_padding_mask=key_padding_mask,
331
+ is_causal=is_causal,
332
+ dropout_p=self.attn_dropout_p,
333
+ training=self.training,
334
+ needs_weights=needs_weights,
335
+ )
336
+ return (self.out_proj(context), attn_weights, past_key_value)
337
+
338
+
339
+ class MultiQueryAttention(nn.Module):
340
+ """Multi-Query self attention.
341
+
342
+ Using torch or triton attention implementation enables user to also use
343
+ additive bias.
344
+ """
345
+
346
+ def __init__(
347
+ self,
348
+ d_model: int,
349
+ n_heads: int,
350
+ attn_impl: str = "triton",
351
+ clip_qkv: Optional[float] = None,
352
+ qk_ln: bool = False,
353
+ softmax_scale: Optional[float] = None,
354
+ attn_pdrop: float = 0.0,
355
+ low_precision_layernorm: bool = False,
356
+ verbose: int = 0,
357
+ device: Optional[str] = None,
358
+ ):
359
+ super().__init__()
360
+ self.attn_impl = attn_impl
361
+ self.clip_qkv = clip_qkv
362
+ self.qk_ln = qk_ln
363
+ self.d_model = d_model
364
+ self.n_heads = n_heads
365
+ self.head_dim = d_model // n_heads
366
+ self.softmax_scale = softmax_scale
367
+ if self.softmax_scale is None:
368
+ self.softmax_scale = 1 / math.sqrt(self.head_dim)
369
+ self.attn_dropout_p = attn_pdrop
370
+ self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
371
+ fuse_splits = (d_model, d_model + self.head_dim)
372
+ self.Wqkv._fused = (0, fuse_splits)
373
+ if self.qk_ln:
374
+ layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
375
+ self.q_ln = layernorm_class(d_model, device=device)
376
+ self.k_ln = layernorm_class(self.head_dim, device=device)
377
+ if self.attn_impl == "flash":
378
+ self.attn_fn = flash_attn_fn
379
+ elif self.attn_impl == "triton":
380
+ self.attn_fn = triton_flash_attn_fn
381
+ if verbose:
382
+ warnings.warn(
383
+ "While `attn_impl: triton` can be faster than `attn_impl: flash` "
384
+ + "it uses more memory. When training larger models this can trigger "
385
+ + "alloc retries which hurts performance. If encountered, we recommend "
386
+ + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
387
+ )
388
+ elif self.attn_impl == "torch":
389
+ self.attn_fn = scaled_multihead_dot_product_attention
390
+ if torch.cuda.is_available() and verbose:
391
+ warnings.warn(
392
+ "Using `attn_impl: torch`. If your model does not use `alibi` or "
393
+ + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
394
+ + "we recommend using `attn_impl: triton`."
395
+ )
396
+ else:
397
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
398
+ self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
399
+ self.out_proj._is_residual = True
400
+
401
+ def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
402
+ qkv = self.Wqkv(x)
403
+ if self.clip_qkv:
404
+ qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
405
+ (query, key, value) = qkv.split([self.d_model, self.head_dim, self.head_dim], dim=2)
406
+ key_padding_mask = attention_mask
407
+ if self.qk_ln:
408
+ dtype = query.dtype
409
+ query = self.q_ln(query).to(dtype)
410
+ key = self.k_ln(key).to(dtype)
411
+ (context, attn_weights, past_key_value) = self.attn_fn(
412
+ query,
413
+ key,
414
+ value,
415
+ self.n_heads,
416
+ past_key_value=past_key_value,
417
+ softmax_scale=self.softmax_scale,
418
+ attn_bias=attn_bias,
419
+ key_padding_mask=key_padding_mask,
420
+ is_causal=is_causal,
421
+ dropout_p=self.attn_dropout_p,
422
+ training=self.training,
423
+ needs_weights=needs_weights,
424
+ multiquery=True,
425
+ )
426
+ return (self.out_proj(context), attn_weights, past_key_value)
427
+
428
+
429
+ def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
430
+ if attn_impl == "flash":
431
+ return None
432
+ elif attn_impl in ["torch", "triton"]:
433
+ if alibi:
434
+ if (prefix_lm or not causal) or use_sequence_id:
435
+ return (1, n_heads, seq_len, seq_len)
436
+ return (1, n_heads, 1, seq_len)
437
+ elif prefix_lm or use_sequence_id:
438
+ return (1, 1, seq_len, seq_len)
439
+ return None
440
+ else:
441
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
442
+
443
+
444
+ def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8):
445
+ if attn_impl == "flash":
446
+ return None
447
+ elif attn_impl in ["torch", "triton"]:
448
+ if alibi:
449
+ (device, dtype) = (attn_bias.device, attn_bias.dtype)
450
+ attn_bias = attn_bias.add(
451
+ build_alibi_bias(
452
+ n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype
453
+ )
454
+ )
455
+ return attn_bias
456
+ else:
457
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
458
+
459
+
460
+ def gen_slopes(n_heads, alibi_bias_max=8, device=None):
461
+ _n_heads = 2 ** math.ceil(math.log2(n_heads))
462
+ m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
463
+ m = m.mul(alibi_bias_max / _n_heads)
464
+ slopes = 1.0 / torch.pow(2, m)
465
+ if _n_heads != n_heads:
466
+ slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
467
+ return slopes.view(1, n_heads, 1, 1)
468
+
469
+
470
+ def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None):
471
+ alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len)
472
+ if full:
473
+ alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, seq_len, 1)
474
+ alibi_bias = alibi_bias.abs().mul(-1)
475
+ slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
476
+ alibi_bias = alibi_bias * slopes
477
+ return alibi_bias.to(dtype=dtype)
478
+
479
+
480
+ ATTN_CLASS_REGISTRY = {"multihead_attention": MultiheadAttention, "multiquery_attention": MultiQueryAttention}
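
The `gen_slopes`/`build_alibi_bias` helpers above implement the ALiBi distance penalty: one geometric slope per head multiplied by the (negative) token distance. A small sketch of the same arithmetic, with values chosen only for illustration:

# Sketch of the ALiBi slope/bias computation mirrored from gen_slopes/build_alibi_bias above.
import math
import torch

n_heads, alibi_bias_max, seq_len = 8, 8, 4

# One geometric slope per head: 2^-1, 2^-2, ..., 2^-8 when n_heads is a power of two.
_n_heads = 2 ** math.ceil(math.log2(n_heads))
m = torch.arange(1, _n_heads + 1, dtype=torch.float32) * (alibi_bias_max / _n_heads)
slopes = 1.0 / torch.pow(2, m)                                   # shape (8,)

# Distances to the current (last) position, negative so far-away tokens are penalized.
distance = torch.arange(1 - seq_len, 1).view(1, 1, 1, seq_len)   # [-3, -2, -1, 0]
alibi_bias = distance * slopes.view(1, n_heads, 1, 1)            # (1, 8, 1, 4), added to attn logits

print(slopes[:3].tolist())           # [0.5, 0.25, 0.125]
print(alibi_bias[0, 0, 0].tolist())  # [-1.5, -1.0, -0.5, 0.0]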
VILA/llava/model/language_model/mpt/blocks.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ """GPT Blocks used for the GPT Model."""
18
+ from typing import Dict, Optional, Tuple
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+ from .attention import ATTN_CLASS_REGISTRY
24
+ from .norm import NORM_CLASS_REGISTRY
25
+
26
+
27
+ class MPTMLP(nn.Module):
28
+ def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str] = None):
29
+ super().__init__()
30
+ self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
31
+ self.act = nn.GELU(approximate="none")
32
+ self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
33
+ self.down_proj._is_residual = True
34
+
35
+ def forward(self, x):
36
+ return self.down_proj(self.act(self.up_proj(x)))
37
+
38
+
39
+ class MPTBlock(nn.Module):
40
+ def __init__(
41
+ self,
42
+ d_model: int,
43
+ n_heads: int,
44
+ expansion_ratio: int,
45
+ attn_config: Dict = {
46
+ "attn_type": "multihead_attention",
47
+ "attn_pdrop": 0.0,
48
+ "attn_impl": "triton",
49
+ "qk_ln": False,
50
+ "clip_qkv": None,
51
+ "softmax_scale": None,
52
+ "prefix_lm": False,
53
+ "attn_uses_sequence_id": False,
54
+ "alibi": False,
55
+ "alibi_bias_max": 8,
56
+ },
57
+ resid_pdrop: float = 0.0,
58
+ norm_type: str = "low_precision_layernorm",
59
+ verbose: int = 0,
60
+ device: Optional[str] = None,
61
+ **kwargs
62
+ ):
63
+ del kwargs
64
+ super().__init__()
65
+ norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
66
+ attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
67
+ self.norm_1 = norm_class(d_model, device=device)
68
+ self.attn = attn_class(
69
+ attn_impl=attn_config["attn_impl"],
70
+ clip_qkv=attn_config["clip_qkv"],
71
+ qk_ln=attn_config["qk_ln"],
72
+ softmax_scale=attn_config["softmax_scale"],
73
+ attn_pdrop=attn_config["attn_pdrop"],
74
+ d_model=d_model,
75
+ n_heads=n_heads,
76
+ verbose=verbose,
77
+ device=device,
78
+ )
79
+ self.norm_2 = norm_class(d_model, device=device)
80
+ self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
81
+ self.resid_attn_dropout = nn.Dropout(resid_pdrop)
82
+ self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
83
+
84
+ def forward(
85
+ self,
86
+ x: torch.Tensor,
87
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
88
+ attn_bias: Optional[torch.Tensor] = None,
89
+ attention_mask: Optional[torch.ByteTensor] = None,
90
+ is_causal: bool = True,
91
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
92
+ a = self.norm_1(x)
93
+ (b, attn_weights, past_key_value) = self.attn(
94
+ a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal
95
+ )
96
+ x = x + self.resid_attn_dropout(b)
97
+ m = self.norm_2(x)
98
+ n = self.ffn(m)
99
+ x = x + self.resid_ffn_dropout(n)
100
+ return (x, attn_weights, past_key_value)
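
A quick smoke-test sketch for the block above, assuming the package is importable as `llava.model.language_model.mpt` (per this repo's layout) and using the pure-PyTorch attention path so it runs on CPU:

# Hypothetical MPTBlock smoke test; the import path assumes the VILA repo layout.
import torch
from llava.model.language_model.mpt.blocks import MPTBlock

attn_config = {
    "attn_type": "multihead_attention",
    "attn_pdrop": 0.0,
    "attn_impl": "torch",   # avoids the flash/triton dependencies
    "qk_ln": False,
    "clip_qkv": None,
    "softmax_scale": None,
    "prefix_lm": False,
    "attn_uses_sequence_id": False,
    "alibi": False,
    "alibi_bias_max": 8,
}
block = MPTBlock(d_model=64, n_heads=4, expansion_ratio=4, attn_config=attn_config)

x = torch.randn(2, 16, 64)            # (batch, seq_len, d_model)
y, attn_weights, past_kv = block(x)   # pre-norm residual attention followed by the MLP
print(y.shape)                        # torch.Size([2, 16, 64])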
VILA/llava/model/language_model/mpt/configuration_mpt.py ADDED
@@ -0,0 +1,184 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ """A HuggingFace-style model configuration."""
18
+ from typing import Dict, Optional, Union
19
+
20
+ from transformers import PretrainedConfig
21
+
22
+ attn_config_defaults: Dict = {
23
+ "attn_type": "multihead_attention",
24
+ "attn_pdrop": 0.0,
25
+ "attn_impl": "triton",
26
+ "qk_ln": False,
27
+ "clip_qkv": None,
28
+ "softmax_scale": None,
29
+ "prefix_lm": False,
30
+ "attn_uses_sequence_id": False,
31
+ "alibi": False,
32
+ "alibi_bias_max": 8,
33
+ }
34
+ init_config_defaults: Dict = {
35
+ "name": "kaiming_normal_",
36
+ "fan_mode": "fan_in",
37
+ "init_nonlinearity": "relu",
38
+ "init_div_is_residual": True,
39
+ "emb_init_std": None,
40
+ "emb_init_uniform_lim": None,
41
+ "init_std": None,
42
+ "init_gain": 0.0,
43
+ }
44
+
45
+
46
+ class MPTConfig(PretrainedConfig):
47
+ model_type = "mpt"
48
+
49
+ def __init__(
50
+ self,
51
+ d_model: int = 2048,
52
+ n_heads: int = 16,
53
+ n_layers: int = 24,
54
+ expansion_ratio: int = 4,
55
+ max_seq_len: int = 2048,
56
+ vocab_size: int = 50368,
57
+ resid_pdrop: float = 0.0,
58
+ emb_pdrop: float = 0.0,
59
+ learned_pos_emb: bool = True,
60
+ attn_config: Dict = attn_config_defaults,
61
+ init_device: str = "cpu",
62
+ logit_scale: Optional[Union[float, str]] = None,
63
+ no_bias: bool = False,
64
+ verbose: int = 0,
65
+ embedding_fraction: float = 1.0,
66
+ norm_type: str = "low_precision_layernorm",
67
+ use_cache: bool = False,
68
+ init_config: Dict = init_config_defaults,
69
+ **kwargs,
70
+ ):
71
+ """The MPT configuration class.
72
+
73
+ Args:
74
+ d_model (int): The size of the embedding dimension of the model.
75
+ n_heads (int): The number of attention heads.
76
+ n_layers (int): The number of layers in the model.
77
+ expansion_ratio (int): The ratio of the up/down scale in the MLP.
78
+ max_seq_len (int): The maximum sequence length of the model.
79
+ vocab_size (int): The size of the vocabulary.
80
+ resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
81
+ emb_pdrop (float): The dropout probability for the embedding layer.
82
+ learned_pos_emb (bool): Whether to use learned positional embeddings
83
+ attn_config (Dict): A dictionary used to configure the model's attention module:
84
+ attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention
85
+ attn_pdrop (float): The dropout probability for the attention layers.
86
+ attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
87
+ qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
88
+ clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
89
+ this value.
90
+ softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
91
+ use the default scale of ``1/sqrt(d_keys)``.
92
+ prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
93
+ extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
94
+ can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
95
+ attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
96
+ When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
97
+ which sub-sequence each token belongs to.
98
+ Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
99
+ alibi (bool): Whether to use the alibi bias instead of position embeddings.
100
+ alibi_bias_max (int): The maximum value of the alibi bias.
101
+ init_device (str): The device to use for parameter initialization.
102
+ logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
103
+ no_bias (bool): Whether to use bias in all layers.
104
+ verbose (int): The verbosity level. 0 is silent.
105
+ embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
106
+ norm_type (str): choose type of norm to use
107
+ multiquery_attention (bool): Whether to use multiquery attention implementation.
108
+ use_cache (bool): Whether or not the model should return the last key/values attentions
109
+ init_config (Dict): A dictionary used to configure the model initialization:
110
+ init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
111
+ 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
112
+ 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
113
+ init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
114
+ emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
115
+ emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
116
+ used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
117
+ init_std (float): The standard deviation of the normal distribution used to initialize the model,
118
+ if using the baseline_ parameter initialization scheme.
119
+ init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
120
+ fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
121
+ init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
122
+ ---
123
+ See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
124
+ """
125
+ self.d_model = d_model
126
+ self.n_heads = n_heads
127
+ self.n_layers = n_layers
128
+ self.expansion_ratio = expansion_ratio
129
+ self.max_seq_len = max_seq_len
130
+ self.vocab_size = vocab_size
131
+ self.resid_pdrop = resid_pdrop
132
+ self.emb_pdrop = emb_pdrop
133
+ self.learned_pos_emb = learned_pos_emb
134
+ self.attn_config = attn_config
135
+ self.init_device = init_device
136
+ self.logit_scale = logit_scale
137
+ self.no_bias = no_bias
138
+ self.verbose = verbose
139
+ self.embedding_fraction = embedding_fraction
140
+ self.norm_type = norm_type
141
+ self.use_cache = use_cache
142
+ self.init_config = init_config
143
+ if "name" in kwargs:
144
+ del kwargs["name"]
145
+ if "loss_fn" in kwargs:
146
+ del kwargs["loss_fn"]
147
+ super().__init__(**kwargs)
148
+ self._validate_config()
149
+
150
+ def _set_config_defaults(self, config, config_defaults):
151
+ for (k, v) in config_defaults.items():
152
+ if k not in config:
153
+ config[k] = v
154
+ return config
155
+
156
+ def _validate_config(self):
157
+ self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
158
+ self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
159
+ if self.d_model % self.n_heads != 0:
160
+ raise ValueError("d_model must be divisible by n_heads")
161
+ if any(prob < 0 or prob > 1 for prob in [self.attn_config["attn_pdrop"], self.resid_pdrop, self.emb_pdrop]):
162
+ raise ValueError(
163
+ "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1"
164
+ )
165
+ if self.attn_config["attn_impl"] not in ["torch", "flash", "triton"]:
166
+ raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
167
+ if self.attn_config["prefix_lm"] and self.attn_config["attn_impl"] not in ["torch", "triton"]:
168
+ raise NotImplementedError("prefix_lm only implemented with torch and triton attention.")
169
+ if self.attn_config["alibi"] and self.attn_config["attn_impl"] not in ["torch", "triton"]:
170
+ raise NotImplementedError("alibi only implemented with torch and triton attention.")
171
+ if self.attn_config["attn_uses_sequence_id"] and self.attn_config["attn_impl"] not in ["torch", "triton"]:
172
+ raise NotImplementedError("attn_uses_sequence_id only implemented with torch and triton attention.")
173
+ if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
174
+ raise ValueError("model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!")
175
+ if isinstance(self.logit_scale, str) and self.logit_scale != "inv_sqrt_d_model":
176
+ raise ValueError(
177
+ f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
178
+ )
179
+ if self.init_config.get("name", None) is None:
180
+ raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
181
+ if not self.learned_pos_emb and (not self.attn_config["alibi"]):
182
+ raise ValueError(
183
+ f"Positional information must be provided to the model using either learned_pos_emb or alibi."
184
+ )
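
A short configuration sketch, again assuming the VILA import path; it shows `_validate_config` filling in the attention defaults declared at the top of this file:

# Hypothetical MPTConfig construction; the import path assumes the VILA repo layout.
from llava.model.language_model.mpt.configuration_mpt import MPTConfig

# Only a few attention options are overridden; _set_config_defaults fills in the rest.
config = MPTConfig(
    d_model=1024,
    n_heads=16,
    n_layers=12,
    max_seq_len=4096,
    attn_config={"attn_type": "multihead_attention", "attn_impl": "torch", "alibi": True},
)

print(config.attn_config["alibi"])       # True (explicit override)
print(config.attn_config["attn_pdrop"])  # 0.0 (filled in from attn_config_defaults)
print(config.d_model % config.n_heads)   # 0, otherwise _validate_config raises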
VILA/llava/model/language_model/mpt/custom_embedding.py ADDED
@@ -0,0 +1,27 @@
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # SPDX-License-Identifier: Apache-2.0
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch import Tensor
+
+
+ class SharedEmbedding(nn.Embedding):
+     def forward(self, input: Tensor, unembed: bool = False) -> Tensor:
+         if unembed:
+             return F.linear(input, self.weight)
+         return super().forward(input)
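
A minimal sketch of how a tied embedding like the class above is typically used: the same weight matrix embeds token ids on the way in and produces logits on the way out (sizes are illustrative only).

# Weight tying with SharedEmbedding: embed on the way in, unembed on the way out.
import torch

vocab_size, d_model = 1000, 64
wte = SharedEmbedding(vocab_size, d_model)

input_ids = torch.randint(0, vocab_size, (2, 8))
hidden = wte(input_ids)             # (2, 8, 64) lookup, same as nn.Embedding
logits = wte(hidden, unembed=True)  # (2, 8, 1000) via F.linear with the tied weight
print(hidden.shape, logits.shape)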
VILA/llava/model/language_model/mpt/flash_attn_triton.py ADDED
@@ -0,0 +1,947 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ """
18
+ Copied from https://github.com/HazyResearch/flash-attention/blob/eff9fe6b8076df59d64d7a3f464696738a3c7c24/flash_attn/flash_attn_triton.py
19
+ update imports to use 'triton_pre_mlir'
20
+
21
+ *Experimental* implementation of FlashAttention in Triton.
22
+ Tested with triton==2.0.0.dev20221202.
23
+ Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions
24
+ other than 64:
25
+ https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
26
+ We'll update this implementation with the new Triton backend once this is fixed.
27
+
28
+ We use the FlashAttention implementation from Phil Tillet as a starting point.
29
+ https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
30
+
31
+ Changes:
32
+ - Implement both causal and non-causal attention.
33
+ - Implement both self-attention and cross-attention.
34
+ - Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
35
+ - Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
36
+ - Support attention bias.
37
+ - Speed up the forward pass a bit, and only store the LSE instead of m and l.
38
+ - Make the backward for d=128 much faster by reducing register spilling.
39
+ - Optionally parallelize the backward pass across seqlen_k, to deal with the case of
40
+ small batch size * nheads.
41
+
42
+ Caution:
43
+ - This is an *experimental* implementation. The forward pass should be quite robust but
44
+ I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
45
+ - This implementation has only been tested on A100.
46
+ - If you plan to use headdim other than 64 and 128, you should test for race conditions
47
+ (due to the Triton compiler), as done in tests/test_flash_attn.py
48
+ "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
49
+ for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
50
+ that there are none left for other head dimensions.
51
+
52
+ Differences between this Triton version and the CUDA version:
53
+ - Triton version doesn't support dropout.
54
+ - Triton forward is generally faster than CUDA forward, while Triton backward is
55
+ generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
56
+ than CUDA forward + backward.
57
+ - Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
58
+ - Triton version supports attention bias, while CUDA version doesn't.
59
+ """
60
+ import math
61
+
62
+ import torch
63
+ import triton_pre_mlir as triton
64
+ import triton_pre_mlir.language as tl
65
+
66
+
67
+ @triton.heuristics(
68
+ {
69
+ "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
70
+ "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
71
+ "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
72
+ }
73
+ )
74
+ @triton.jit
75
+ def _fwd_kernel(
76
+ Q,
77
+ K,
78
+ V,
79
+ Bias,
80
+ Out,
81
+ Lse,
82
+ TMP,
83
+ softmax_scale,
84
+ stride_qb,
85
+ stride_qh,
86
+ stride_qm,
87
+ stride_kb,
88
+ stride_kh,
89
+ stride_kn,
90
+ stride_vb,
91
+ stride_vh,
92
+ stride_vn,
93
+ stride_bb,
94
+ stride_bh,
95
+ stride_bm,
96
+ stride_ob,
97
+ stride_oh,
98
+ stride_om,
99
+ nheads,
100
+ seqlen_q,
101
+ seqlen_k,
102
+ seqlen_q_rounded,
103
+ headdim,
104
+ CACHE_KEY_SEQLEN_Q,
105
+ CACHE_KEY_SEQLEN_K,
106
+ BIAS_TYPE: tl.constexpr,
107
+ IS_CAUSAL: tl.constexpr,
108
+ BLOCK_HEADDIM: tl.constexpr,
109
+ EVEN_M: tl.constexpr,
110
+ EVEN_N: tl.constexpr,
111
+ EVEN_HEADDIM: tl.constexpr,
112
+ BLOCK_M: tl.constexpr,
113
+ BLOCK_N: tl.constexpr,
114
+ ):
115
+ start_m = tl.program_id(0)
116
+ off_hb = tl.program_id(1)
117
+ off_b = off_hb // nheads
118
+ off_h = off_hb % nheads
119
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
120
+ offs_n = tl.arange(0, BLOCK_N)
121
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
122
+ q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
123
+ k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
124
+ v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
125
+ if BIAS_TYPE == "vector":
126
+ b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
127
+ elif BIAS_TYPE == "matrix":
128
+ b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
129
+ t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
130
+ lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
131
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
132
+ acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
133
+ if EVEN_M & EVEN_N:
134
+ if EVEN_HEADDIM:
135
+ q = tl.load(q_ptrs)
136
+ else:
137
+ q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
138
+ elif EVEN_HEADDIM:
139
+ q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
140
+ else:
141
+ q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
142
+ end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
143
+ for start_n in range(0, end_n, BLOCK_N):
144
+ start_n = tl.multiple_of(start_n, BLOCK_N)
145
+ if EVEN_N & EVEN_M:
146
+ if EVEN_HEADDIM:
147
+ k = tl.load(k_ptrs + start_n * stride_kn)
148
+ else:
149
+ k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
150
+ elif EVEN_HEADDIM:
151
+ k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
152
+ else:
153
+ k = tl.load(
154
+ k_ptrs + start_n * stride_kn,
155
+ mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
156
+ other=0.0,
157
+ )
158
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
159
+ qk += tl.dot(q, k, trans_b=True)
160
+ if not EVEN_N:
161
+ qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
162
+ if IS_CAUSAL:
163
+ qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
164
+ if BIAS_TYPE != "none":
165
+ if BIAS_TYPE == "vector":
166
+ if EVEN_N:
167
+ bias = tl.load(b_ptrs + start_n).to(tl.float32)
168
+ else:
169
+ bias = tl.load(b_ptrs + start_n, mask=start_n + offs_n < seqlen_k, other=0.0).to(tl.float32)
170
+ bias = bias[None, :]
171
+ elif BIAS_TYPE == "matrix":
172
+ if EVEN_M & EVEN_N:
173
+ bias = tl.load(b_ptrs + start_n).to(tl.float32)
174
+ else:
175
+ bias = tl.load(
176
+ b_ptrs + start_n,
177
+ mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k),
178
+ other=0.0,
179
+ ).to(tl.float32)
180
+ qk = qk * softmax_scale + bias
181
+ m_ij = tl.maximum(tl.max(qk, 1), lse_i)
182
+ p = tl.exp(qk - m_ij[:, None])
183
+ else:
184
+ m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
185
+ p = tl.exp(qk * softmax_scale - m_ij[:, None])
186
+ l_ij = tl.sum(p, 1)
187
+ acc_o_scale = tl.exp(m_i - m_ij)
188
+ tl.store(t_ptrs, acc_o_scale)
189
+ acc_o_scale = tl.load(t_ptrs)
190
+ acc_o = acc_o * acc_o_scale[:, None]
191
+ if EVEN_N & EVEN_M:
192
+ if EVEN_HEADDIM:
193
+ v = tl.load(v_ptrs + start_n * stride_vn)
194
+ else:
195
+ v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
196
+ elif EVEN_HEADDIM:
197
+ v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
198
+ else:
199
+ v = tl.load(
200
+ v_ptrs + start_n * stride_vn,
201
+ mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
202
+ other=0.0,
203
+ )
204
+ p = p.to(v.dtype)
205
+ acc_o += tl.dot(p, v)
206
+ m_i = m_ij
207
+ l_i_new = tl.exp(lse_i - m_ij) + l_ij
208
+ lse_i = m_ij + tl.log(l_i_new)
209
+ o_scale = tl.exp(m_i - lse_i)
210
+ tl.store(t_ptrs, o_scale)
211
+ o_scale = tl.load(t_ptrs)
212
+ acc_o = acc_o * o_scale[:, None]
213
+ start_m = tl.program_id(0)
214
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
215
+ lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
216
+ tl.store(lse_ptrs, lse_i)
217
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
218
+ out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])
219
+ if EVEN_M:
220
+ if EVEN_HEADDIM:
221
+ tl.store(out_ptrs, acc_o)
222
+ else:
223
+ tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
224
+ elif EVEN_HEADDIM:
225
+ tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
226
+ else:
227
+ tl.store(out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
228
+
229
+
230
+ @triton.jit
231
+ def _bwd_preprocess_do_o_dot(
232
+ Out,
233
+ DO,
234
+ Delta,
235
+ stride_ob,
236
+ stride_oh,
237
+ stride_om,
238
+ stride_dob,
239
+ stride_doh,
240
+ stride_dom,
241
+ nheads,
242
+ seqlen_q,
243
+ seqlen_q_rounded,
244
+ headdim,
245
+ BLOCK_M: tl.constexpr,
246
+ BLOCK_HEADDIM: tl.constexpr,
247
+ ):
248
+ start_m = tl.program_id(0)
249
+ off_hb = tl.program_id(1)
250
+ off_b = off_hb // nheads
251
+ off_h = off_hb % nheads
252
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
253
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
254
+ o = tl.load(
255
+ Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
256
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
257
+ other=0.0,
258
+ ).to(tl.float32)
259
+ do = tl.load(
260
+ DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :],
261
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
262
+ other=0.0,
263
+ ).to(tl.float32)
264
+ delta = tl.sum(o * do, axis=1)
265
+ tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
266
+
267
+
268
+ @triton.jit
269
+ def _bwd_store_dk_dv(
270
+ dk_ptrs,
271
+ dv_ptrs,
272
+ dk,
273
+ dv,
274
+ offs_n,
275
+ offs_d,
276
+ seqlen_k,
277
+ headdim,
278
+ EVEN_M: tl.constexpr,
279
+ EVEN_N: tl.constexpr,
280
+ EVEN_HEADDIM: tl.constexpr,
281
+ ):
282
+ if EVEN_N & EVEN_M:
283
+ if EVEN_HEADDIM:
284
+ tl.store(dv_ptrs, dv)
285
+ tl.store(dk_ptrs, dk)
286
+ else:
287
+ tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
288
+ tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
289
+ elif EVEN_HEADDIM:
290
+ tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
291
+ tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
292
+ else:
293
+ tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
294
+ tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
295
+
296
+
297
+ @triton.jit
298
+ def _bwd_kernel_one_col_block(
299
+ start_n,
300
+ Q,
301
+ K,
302
+ V,
303
+ Bias,
304
+ DO,
305
+ DQ,
306
+ DK,
307
+ DV,
308
+ LSE,
309
+ D,
310
+ softmax_scale,
311
+ stride_qm,
312
+ stride_kn,
313
+ stride_vn,
314
+ stride_bm,
315
+ stride_dom,
316
+ stride_dqm,
317
+ stride_dkn,
318
+ stride_dvn,
319
+ seqlen_q,
320
+ seqlen_k,
321
+ headdim,
322
+ ATOMIC_ADD: tl.constexpr,
323
+ BIAS_TYPE: tl.constexpr,
324
+ IS_CAUSAL: tl.constexpr,
325
+ BLOCK_HEADDIM: tl.constexpr,
326
+ EVEN_M: tl.constexpr,
327
+ EVEN_N: tl.constexpr,
328
+ EVEN_HEADDIM: tl.constexpr,
329
+ BLOCK_M: tl.constexpr,
330
+ BLOCK_N: tl.constexpr,
331
+ ):
332
+ begin_m = 0 if not IS_CAUSAL else start_n * BLOCK_N // BLOCK_M * BLOCK_M
333
+ offs_qm = begin_m + tl.arange(0, BLOCK_M)
334
+ offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
335
+ offs_m = tl.arange(0, BLOCK_M)
336
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
337
+ q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
338
+ k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
339
+ v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
340
+ do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
341
+ dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
342
+ if BIAS_TYPE == "vector":
343
+ b_ptrs = Bias + offs_n
344
+ elif BIAS_TYPE == "matrix":
345
+ b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
346
+ dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
347
+ dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
348
+ if begin_m >= seqlen_q:
349
+ dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
350
+ dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
351
+ _bwd_store_dk_dv(
352
+ dk_ptrs,
353
+ dv_ptrs,
354
+ dk,
355
+ dv,
356
+ offs_n,
357
+ offs_d,
358
+ seqlen_k,
359
+ headdim,
360
+ EVEN_M=EVEN_M,
361
+ EVEN_N=EVEN_N,
362
+ EVEN_HEADDIM=EVEN_HEADDIM,
363
+ )
364
+ return
365
+ if EVEN_N & EVEN_M:
366
+ if EVEN_HEADDIM:
367
+ k = tl.load(k_ptrs)
368
+ v = tl.load(v_ptrs)
369
+ else:
370
+ k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
371
+ v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
372
+ elif EVEN_HEADDIM:
373
+ k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
374
+ v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
375
+ else:
376
+ k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
377
+ v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
378
+ num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
379
+ for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
380
+ start_m = tl.multiple_of(start_m, BLOCK_M)
381
+ offs_m_curr = start_m + offs_m
382
+ if EVEN_M & EVEN_HEADDIM:
383
+ q = tl.load(q_ptrs)
384
+ elif EVEN_HEADDIM:
385
+ q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
386
+ else:
387
+ q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
388
+ qk = tl.dot(q, k, trans_b=True)
389
+ if not EVEN_N:
390
+ qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
391
+ if IS_CAUSAL:
392
+ qk = tl.where(offs_m_curr[:, None] >= offs_n[None, :], qk, float("-inf"))
393
+ if BIAS_TYPE != "none":
394
+ tl.debug_barrier()
395
+ if BIAS_TYPE == "vector":
396
+ if EVEN_N:
397
+ bias = tl.load(b_ptrs).to(tl.float32)
398
+ else:
399
+ bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
400
+ bias = bias[None, :]
401
+ elif BIAS_TYPE == "matrix":
402
+ if EVEN_M & EVEN_N:
403
+ bias = tl.load(b_ptrs).to(tl.float32)
404
+ else:
405
+ bias = tl.load(
406
+ b_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), other=0.0
407
+ ).to(tl.float32)
408
+ qk = qk * softmax_scale + bias
409
+ if not EVEN_M & EVEN_HEADDIM:
410
+ tl.debug_barrier()
411
+ lse_i = tl.load(LSE + offs_m_curr)
412
+ if BIAS_TYPE == "none":
413
+ p = tl.exp(qk * softmax_scale - lse_i[:, None])
414
+ else:
415
+ p = tl.exp(qk - lse_i[:, None])
416
+ if EVEN_M & EVEN_HEADDIM:
417
+ do = tl.load(do_ptrs)
418
+ else:
419
+ do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
420
+ dv += tl.dot(p.to(do.dtype), do, trans_a=True)
421
+ if not EVEN_M & EVEN_HEADDIM:
422
+ tl.debug_barrier()
423
+ dp = tl.dot(do, v, trans_b=True)
424
+ if not EVEN_HEADDIM:
425
+ tl.debug_barrier()
426
+ Di = tl.load(D + offs_m_curr)
427
+ ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
428
+ dk += tl.dot(ds, q, trans_a=True)
429
+ if not EVEN_M & EVEN_HEADDIM:
430
+ tl.debug_barrier()
431
+ if not ATOMIC_ADD:
432
+ if EVEN_M & EVEN_HEADDIM:
433
+ dq = tl.load(dq_ptrs, eviction_policy="evict_last")
434
+ dq += tl.dot(ds, k)
435
+ tl.store(dq_ptrs, dq, eviction_policy="evict_last")
436
+ elif EVEN_HEADDIM:
437
+ dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0, eviction_policy="evict_last")
438
+ dq += tl.dot(ds, k)
439
+ tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q, eviction_policy="evict_last")
440
+ else:
441
+ dq = tl.load(
442
+ dq_ptrs,
443
+ mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
444
+ other=0.0,
445
+ eviction_policy="evict_last",
446
+ )
447
+ dq += tl.dot(ds, k)
448
+ tl.store(
449
+ dq_ptrs,
450
+ dq,
451
+ mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
452
+ eviction_policy="evict_last",
453
+ )
454
+ else:
455
+ dq = tl.dot(ds, k)
456
+ if EVEN_M & EVEN_HEADDIM:
457
+ tl.atomic_add(dq_ptrs, dq)
458
+ elif EVEN_HEADDIM:
459
+ tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
460
+ else:
461
+ tl.atomic_add(dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
462
+ dq_ptrs += BLOCK_M * stride_dqm
463
+ q_ptrs += BLOCK_M * stride_qm
464
+ do_ptrs += BLOCK_M * stride_dom
465
+ if BIAS_TYPE == "matrix":
466
+ b_ptrs += BLOCK_M * stride_bm
467
+ dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
468
+ dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
469
+ _bwd_store_dk_dv(
470
+ dk_ptrs,
471
+ dv_ptrs,
472
+ dk,
473
+ dv,
474
+ offs_n,
475
+ offs_d,
476
+ seqlen_k,
477
+ headdim,
478
+ EVEN_M=EVEN_M,
479
+ EVEN_N=EVEN_N,
480
+ EVEN_HEADDIM=EVEN_HEADDIM,
481
+ )
482
+
483
+
484
+ def init_to_zero(name):
485
+ return lambda nargs: nargs[name].zero_()
486
+
487
+
488
+ @triton.autotune(
489
+ configs=[
490
+ triton.Config(
491
+ {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False},
492
+ num_warps=8,
493
+ num_stages=1,
494
+ pre_hook=init_to_zero("DQ"),
495
+ ),
496
+ triton.Config(
497
+ {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True},
498
+ num_warps=8,
499
+ num_stages=1,
500
+ pre_hook=init_to_zero("DQ"),
501
+ ),
502
+ ],
503
+ key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "IS_CAUSAL", "BLOCK_HEADDIM"],
504
+ )
505
+ @triton.heuristics(
506
+ {
507
+ "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
508
+ "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
509
+ "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
510
+ }
511
+ )
512
+ @triton.jit
513
+ def _bwd_kernel(
514
+ Q,
515
+ K,
516
+ V,
517
+ Bias,
518
+ DO,
519
+ DQ,
520
+ DK,
521
+ DV,
522
+ LSE,
523
+ D,
524
+ softmax_scale,
525
+ stride_qb,
526
+ stride_qh,
527
+ stride_qm,
528
+ stride_kb,
529
+ stride_kh,
530
+ stride_kn,
531
+ stride_vb,
532
+ stride_vh,
533
+ stride_vn,
534
+ stride_bb,
535
+ stride_bh,
536
+ stride_bm,
537
+ stride_dob,
538
+ stride_doh,
539
+ stride_dom,
540
+ stride_dqb,
541
+ stride_dqh,
542
+ stride_dqm,
543
+ stride_dkb,
544
+ stride_dkh,
545
+ stride_dkn,
546
+ stride_dvb,
547
+ stride_dvh,
548
+ stride_dvn,
549
+ nheads,
550
+ seqlen_q,
551
+ seqlen_k,
552
+ seqlen_q_rounded,
553
+ headdim,
554
+ CACHE_KEY_SEQLEN_Q,
555
+ CACHE_KEY_SEQLEN_K,
556
+ BIAS_TYPE: tl.constexpr,
557
+ IS_CAUSAL: tl.constexpr,
558
+ BLOCK_HEADDIM: tl.constexpr,
559
+ SEQUENCE_PARALLEL: tl.constexpr,
560
+ EVEN_M: tl.constexpr,
561
+ EVEN_N: tl.constexpr,
562
+ EVEN_HEADDIM: tl.constexpr,
563
+ BLOCK_M: tl.constexpr,
564
+ BLOCK_N: tl.constexpr,
565
+ ):
566
+ off_hb = tl.program_id(1)
567
+ off_b = off_hb // nheads
568
+ off_h = off_hb % nheads
569
+ Q += off_b * stride_qb + off_h * stride_qh
570
+ K += off_b * stride_kb + off_h * stride_kh
571
+ V += off_b * stride_vb + off_h * stride_vh
572
+ DO += off_b * stride_dob + off_h * stride_doh
573
+ DQ += off_b * stride_dqb + off_h * stride_dqh
574
+ DK += off_b * stride_dkb + off_h * stride_dkh
575
+ DV += off_b * stride_dvb + off_h * stride_dvh
576
+ if BIAS_TYPE != "none":
577
+ Bias += off_b * stride_bb + off_h * stride_bh
578
+ D += off_hb * seqlen_q_rounded
579
+ LSE += off_hb * seqlen_q_rounded
580
+ if not SEQUENCE_PARALLEL:
581
+ num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
582
+ for start_n in range(0, num_block_n):
583
+ _bwd_kernel_one_col_block(
584
+ start_n,
585
+ Q,
586
+ K,
587
+ V,
588
+ Bias,
589
+ DO,
590
+ DQ,
591
+ DK,
592
+ DV,
593
+ LSE,
594
+ D,
595
+ softmax_scale,
596
+ stride_qm,
597
+ stride_kn,
598
+ stride_vn,
599
+ stride_bm,
600
+ stride_dom,
601
+ stride_dqm,
602
+ stride_dkn,
603
+ stride_dvn,
604
+ seqlen_q,
605
+ seqlen_k,
606
+ headdim,
607
+ ATOMIC_ADD=False,
608
+ BIAS_TYPE=BIAS_TYPE,
609
+ IS_CAUSAL=IS_CAUSAL,
610
+ BLOCK_HEADDIM=BLOCK_HEADDIM,
611
+ EVEN_M=EVEN_M,
612
+ EVEN_N=EVEN_N,
613
+ EVEN_HEADDIM=EVEN_HEADDIM,
614
+ BLOCK_M=BLOCK_M,
615
+ BLOCK_N=BLOCK_N,
616
+ )
617
+ else:
618
+ start_n = tl.program_id(0)
619
+ _bwd_kernel_one_col_block(
620
+ start_n,
621
+ Q,
622
+ K,
623
+ V,
624
+ Bias,
625
+ DO,
626
+ DQ,
627
+ DK,
628
+ DV,
629
+ LSE,
630
+ D,
631
+ softmax_scale,
632
+ stride_qm,
633
+ stride_kn,
634
+ stride_vn,
635
+ stride_bm,
636
+ stride_dom,
637
+ stride_dqm,
638
+ stride_dkn,
639
+ stride_dvn,
640
+ seqlen_q,
641
+ seqlen_k,
642
+ headdim,
643
+ ATOMIC_ADD=True,
644
+ BIAS_TYPE=BIAS_TYPE,
645
+ IS_CAUSAL=IS_CAUSAL,
646
+ BLOCK_HEADDIM=BLOCK_HEADDIM,
647
+ EVEN_M=EVEN_M,
648
+ EVEN_N=EVEN_N,
649
+ EVEN_HEADDIM=EVEN_HEADDIM,
650
+ BLOCK_M=BLOCK_M,
651
+ BLOCK_N=BLOCK_N,
652
+ )
653
+
654
+
655
+ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
656
+ (batch, seqlen_q, nheads, d) = q.shape
657
+ (_, seqlen_k, _, _) = k.shape
658
+ assert k.shape == (batch, seqlen_k, nheads, d)
659
+ assert v.shape == (batch, seqlen_k, nheads, d)
660
+ assert d <= 128, "FlashAttention only supports head dimensions up to 128"
661
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
662
+ assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
663
+ assert q.is_cuda and k.is_cuda and v.is_cuda
664
+ softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
665
+ has_bias = bias is not None
666
+ bias_type = "none"
667
+ if has_bias:
668
+ assert bias.dtype in [q.dtype, torch.float]
669
+ assert bias.is_cuda
670
+ assert bias.dim() == 4
671
+ if bias.stride(-1) != 1:
672
+ bias = bias.contiguous()
673
+ if bias.shape[2:] == (1, seqlen_k):
674
+ bias_type = "vector"
675
+ elif bias.shape[2:] == (seqlen_q, seqlen_k):
676
+ bias_type = "matrix"
677
+ else:
678
+ raise RuntimeError("Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)")
679
+ bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
680
+ bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
681
+ seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
682
+ lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
683
+ tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
684
+ o = torch.empty_like(q)
685
+ BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
686
+ BLOCK = 128
687
+ num_warps = 4 if d <= 64 else 8
688
+ grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
689
+ _fwd_kernel[grid](
690
+ q,
691
+ k,
692
+ v,
693
+ bias,
694
+ o,
695
+ lse,
696
+ tmp,
697
+ softmax_scale,
698
+ q.stride(0),
699
+ q.stride(2),
700
+ q.stride(1),
701
+ k.stride(0),
702
+ k.stride(2),
703
+ k.stride(1),
704
+ v.stride(0),
705
+ v.stride(2),
706
+ v.stride(1),
707
+ *bias_strides,
708
+ o.stride(0),
709
+ o.stride(2),
710
+ o.stride(1),
711
+ nheads,
712
+ seqlen_q,
713
+ seqlen_k,
714
+ seqlen_q_rounded,
715
+ d,
716
+ seqlen_q // 32,
717
+ seqlen_k // 32,
718
+ bias_type,
719
+ causal,
720
+ BLOCK_HEADDIM,
721
+ BLOCK_M=BLOCK,
722
+ BLOCK_N=BLOCK,
723
+ num_warps=num_warps,
724
+ num_stages=1
725
+ )
726
+ return (o, lse, softmax_scale)
727
+
728
+
729
+ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):
730
+ if do.stride(-1) != 1:
731
+ do = do.contiguous()
732
+ (batch, seqlen_q, nheads, d) = q.shape
733
+ (_, seqlen_k, _, _) = k.shape
734
+ assert d <= 128
735
+ seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
736
+ assert lse.shape == (batch, nheads, seqlen_q_rounded)
737
+ assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
738
+ assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
739
+ softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
740
+ dq_accum = torch.empty_like(q, dtype=torch.float32)
741
+ delta = torch.empty_like(lse)
742
+ BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
743
+ grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
744
+ _bwd_preprocess_do_o_dot[grid](
745
+ o,
746
+ do,
747
+ delta,
748
+ o.stride(0),
749
+ o.stride(2),
750
+ o.stride(1),
751
+ do.stride(0),
752
+ do.stride(2),
753
+ do.stride(1),
754
+ nheads,
755
+ seqlen_q,
756
+ seqlen_q_rounded,
757
+ d,
758
+ BLOCK_M=128,
759
+ BLOCK_HEADDIM=BLOCK_HEADDIM,
760
+ )
761
+ has_bias = bias is not None
762
+ bias_type = "none"
763
+ if has_bias:
764
+ assert bias.dtype in [q.dtype, torch.float]
765
+ assert bias.is_cuda
766
+ assert bias.dim() == 4
767
+ assert bias.stride(-1) == 1
768
+ if bias.shape[2:] == (1, seqlen_k):
769
+ bias_type = "vector"
770
+ elif bias.shape[2:] == (seqlen_q, seqlen_k):
771
+ bias_type = "matrix"
772
+ else:
773
+ raise RuntimeError("Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)")
774
+ bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
775
+ bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
776
+ grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, batch * nheads)
777
+ _bwd_kernel[grid](
778
+ q,
779
+ k,
780
+ v,
781
+ bias,
782
+ do,
783
+ dq_accum,
784
+ dk,
785
+ dv,
786
+ lse,
787
+ delta,
788
+ softmax_scale,
789
+ q.stride(0),
790
+ q.stride(2),
791
+ q.stride(1),
792
+ k.stride(0),
793
+ k.stride(2),
794
+ k.stride(1),
795
+ v.stride(0),
796
+ v.stride(2),
797
+ v.stride(1),
798
+ *bias_strides,
799
+ do.stride(0),
800
+ do.stride(2),
801
+ do.stride(1),
802
+ dq_accum.stride(0),
803
+ dq_accum.stride(2),
804
+ dq_accum.stride(1),
805
+ dk.stride(0),
806
+ dk.stride(2),
807
+ dk.stride(1),
808
+ dv.stride(0),
809
+ dv.stride(2),
810
+ dv.stride(1),
811
+ nheads,
812
+ seqlen_q,
813
+ seqlen_k,
814
+ seqlen_q_rounded,
815
+ d,
816
+ seqlen_q // 32,
817
+ seqlen_k // 32,
818
+ bias_type,
819
+ causal,
820
+ BLOCK_HEADDIM
821
+ )
822
+ dq.copy_(dq_accum)
823
+
824
+
825
+ class FlashAttnQKVPackedFunc(torch.autograd.Function):
826
+ @staticmethod
827
+ def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
828
+ """
829
+ qkv: (batch, seqlen, 3, nheads, headdim)
830
+ bias: optional, shape broadcastable to (batch, nheads, seqlen, seqlen).
831
+ For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
832
+ ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
833
+ """
834
+ if qkv.stride(-1) != 1:
835
+ qkv = qkv.contiguous()
836
+ (o, lse, ctx.softmax_scale) = _flash_attn_forward(
837
+ qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal, softmax_scale=softmax_scale
838
+ )
839
+ ctx.save_for_backward(qkv, o, lse, bias)
840
+ ctx.causal = causal
841
+ return o
842
+
843
+ @staticmethod
844
+ def backward(ctx, do):
845
+ (qkv, o, lse, bias) = ctx.saved_tensors
846
+ assert not ctx.needs_input_grad[1], "FlashAttention does not support bias gradient yet"
847
+ with torch.inference_mode():
848
+ dqkv = torch.empty_like(qkv)
849
+ _flash_attn_backward(
850
+ do,
851
+ qkv[:, :, 0],
852
+ qkv[:, :, 1],
853
+ qkv[:, :, 2],
854
+ o,
855
+ lse,
856
+ dqkv[:, :, 0],
857
+ dqkv[:, :, 1],
858
+ dqkv[:, :, 2],
859
+ bias=bias,
860
+ causal=ctx.causal,
861
+ softmax_scale=ctx.softmax_scale,
862
+ )
863
+ return (dqkv, None, None, None)
864
+
865
+
866
+ flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
867
+
868
+
869
+ class FlashAttnKVPackedFunc(torch.autograd.Function):
870
+ @staticmethod
871
+ def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
872
+ """
873
+ q: (batch, seqlen_q, nheads, headdim)
874
+ kv: (batch, seqlen_k, 2, nheads, headdim)
875
+ bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
876
+ For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
877
+ ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
878
+ """
879
+ (q, kv) = (x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv])
880
+ (o, lse, ctx.softmax_scale) = _flash_attn_forward(
881
+ q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale
882
+ )
883
+ ctx.save_for_backward(q, kv, o, lse, bias)
884
+ ctx.causal = causal
885
+ return o
886
+
887
+ @staticmethod
888
+ def backward(ctx, do):
889
+ (q, kv, o, lse, bias) = ctx.saved_tensors
890
+ if len(ctx.needs_input_grad) >= 3:
891
+ assert not ctx.needs_input_grad[2], "FlashAttention does not support bias gradient yet"
892
+ with torch.inference_mode():
893
+ dq = torch.empty_like(q)
894
+ dkv = torch.empty_like(kv)
895
+ _flash_attn_backward(
896
+ do,
897
+ q,
898
+ kv[:, :, 0],
899
+ kv[:, :, 1],
900
+ o,
901
+ lse,
902
+ dq,
903
+ dkv[:, :, 0],
904
+ dkv[:, :, 1],
905
+ bias=bias,
906
+ causal=ctx.causal,
907
+ softmax_scale=ctx.softmax_scale,
908
+ )
909
+ return (dq, dkv, None, None, None)
910
+
911
+
912
+ flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
913
+
914
+
915
+ class FlashAttnFunc(torch.autograd.Function):
916
+ @staticmethod
917
+ def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
918
+ """
919
+ q: (batch_size, seqlen_q, nheads, headdim)
920
+ k, v: (batch_size, seqlen_k, nheads, headdim)
921
+ bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
922
+ For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
923
+ ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
924
+ """
925
+ (q, k, v) = (x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v])
926
+ (o, lse, ctx.softmax_scale) = _flash_attn_forward(
927
+ q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
928
+ )
929
+ ctx.save_for_backward(q, k, v, o, lse, bias)
930
+ ctx.causal = causal
931
+ return o
932
+
933
+ @staticmethod
934
+ def backward(ctx, do):
935
+ (q, k, v, o, lse, bias) = ctx.saved_tensors
936
+ assert not ctx.needs_input_grad[3], "FlashAttention does not support bias gradient yet"
937
+ with torch.inference_mode():
938
+ dq = torch.empty_like(q)
939
+ dk = torch.empty_like(k)
940
+ dv = torch.empty_like(v)
941
+ _flash_attn_backward(
942
+ do, q, k, v, o, lse, dq, dk, dv, bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale
943
+ )
944
+ return (dq, dk, dv, None, None, None)
945
+
946
+
947
+ flash_attn_func = FlashAttnFunc.apply
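
The autograd wrappers above all share the same calling convention: tensors are laid out as (batch, seqlen, nheads, headdim), the optional additive bias must be broadcastable to (batch, nheads, seqlen_q, seqlen_k), and arguments are passed positionally because each function is exposed through `torch.autograd.Function.apply`. A minimal usage sketch for `flash_attn_func` follows; the import path, sizes, and the "vector" bias are illustrative assumptions, not code taken from this repo.

import math
import torch
# Module path assumed from this repo's layout (VILA/llava/...).
from llava.model.language_model.mpt.flash_attn_triton import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 16, 64  # hypothetical sizes
q, k, v = (
    torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16, requires_grad=True)
    for _ in range(3)
)
# Optional "vector" bias of shape (1, nheads, 1, seqlen_k); it is expanded internally.
# It must not require grad, since the backward pass does not produce a bias gradient.
bias = torch.randn(1, nheads, 1, seqlen, device="cuda", dtype=torch.float16)

# Positional args: (q, k, v, bias, causal, softmax_scale).
out = flash_attn_func(q, k, v, bias, True, 1.0 / math.sqrt(headdim))
out.sum().backward()  # dq/dk/dv are computed by _flash_attn_backward
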
VILA/llava/model/language_model/mpt/hf_prefixlm_converter.py ADDED
@@ -0,0 +1,657 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ """Converts Huggingface Causal LM to Prefix LM.
18
+
19
+ Conversion does lightweight surgery on a HuggingFace
20
+ Causal LM to convert it to a Prefix LM.
21
+
22
+ Prefix LMs accept a `bidirectional_mask` input in `forward`
23
+ and treat the input prompt as the prefix in `generate`.
24
+ """
25
+ import math
26
+ import warnings
27
+ from types import MethodType
28
+ from typing import Any, Dict, List, Optional, Tuple, Union
29
+
30
+ import torch
31
+ from transformers.models.bloom.modeling_bloom import (
32
+ BaseModelOutputWithPastAndCrossAttentions,
33
+ BloomForCausalLM,
34
+ BloomModel,
35
+ CausalLMOutputWithCrossAttentions,
36
+ CrossEntropyLoss,
37
+ )
38
+ from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom
39
+ from transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom
40
+ from transformers.models.bloom.modeling_bloom import logging
41
+ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
42
+ from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
43
+ from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
44
+ from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
45
+ from transformers.models.opt.modeling_opt import OPTForCausalLM
46
+ from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt
47
+ from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt
48
+
49
+ logger = logging.get_logger(__name__)
50
+ _SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM)
51
+ CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM]
52
+
53
+
54
+ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
55
+ """Converts a GPT-style Causal LM to a Prefix LM.
56
+
57
+ Supported HuggingFace model classes:
58
+ - `GPT2LMHeadModel`
59
+ - `GPTNeoForCausalLM`
60
+ - `GPTNeoXForCausalLM`
61
+ - `GPTJForCausalLM`
62
+
63
+ See `convert_hf_causal_lm_to_prefix_lm` for more details.
64
+ """
65
+ if hasattr(model, "_prefix_lm_converted"):
66
+ return model
67
+ assert isinstance(model, _SUPPORTED_GPT_MODELS)
68
+ assert model.config.add_cross_attention == False, "Only supports GPT-style decoder-only models"
69
+
70
+ def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
71
+ """Helper that gets a list of the model's attention modules.
72
+
73
+ Each module has a `bias` buffer used for causal masking. The Prefix LM
74
+ conversion adds logic to dynamically manipulate these biases to support
75
+ Prefix LM attention masking.
76
+ """
77
+ attn_modules = []
78
+ if isinstance(model, GPTNeoXForCausalLM):
79
+ blocks = model.gpt_neox.layers
80
+ else:
81
+ blocks = model.transformer.h
82
+ for block in blocks:
83
+ if isinstance(model, GPTNeoForCausalLM):
84
+ if block.attn.attention_type != "global":
85
+ continue
86
+ attn_module = block.attn.attention
87
+ elif isinstance(model, GPTNeoXForCausalLM):
88
+ attn_module = block.attention
89
+ else:
90
+ attn_module = block.attn
91
+ attn_modules.append(attn_module)
92
+ return attn_modules
93
+
94
+ setattr(model, "_original_forward", getattr(model, "forward"))
95
+ setattr(model, "_original_generate", getattr(model, "generate"))
96
+
97
+ def forward(
98
+ self: CAUSAL_GPT_TYPES,
99
+ input_ids: Optional[torch.LongTensor] = None,
100
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
101
+ attention_mask: Optional[torch.FloatTensor] = None,
102
+ bidirectional_mask: Optional[torch.Tensor] = None,
103
+ token_type_ids: Optional[torch.LongTensor] = None,
104
+ position_ids: Optional[torch.LongTensor] = None,
105
+ head_mask: Optional[torch.FloatTensor] = None,
106
+ inputs_embeds: Optional[torch.FloatTensor] = None,
107
+ labels: Optional[torch.LongTensor] = None,
108
+ use_cache: Optional[bool] = None,
109
+ output_attentions: Optional[bool] = None,
110
+ output_hidden_states: Optional[bool] = None,
111
+ return_dict: Optional[bool] = None,
112
+ ):
113
+ """Wraps original forward to enable PrefixLM attention."""
114
+
115
+ def call_og_forward():
116
+ if isinstance(self, GPTNeoXForCausalLM):
117
+ return self._original_forward(
118
+ input_ids=input_ids,
119
+ past_key_values=past_key_values,
120
+ attention_mask=attention_mask,
121
+ head_mask=head_mask,
122
+ inputs_embeds=inputs_embeds,
123
+ labels=labels,
124
+ use_cache=use_cache,
125
+ output_attentions=output_attentions,
126
+ output_hidden_states=output_hidden_states,
127
+ return_dict=return_dict,
128
+ )
129
+ else:
130
+ return self._original_forward(
131
+ input_ids=input_ids,
132
+ past_key_values=past_key_values,
133
+ attention_mask=attention_mask,
134
+ token_type_ids=token_type_ids,
135
+ position_ids=position_ids,
136
+ head_mask=head_mask,
137
+ inputs_embeds=inputs_embeds,
138
+ labels=labels,
139
+ use_cache=use_cache,
140
+ output_attentions=output_attentions,
141
+ output_hidden_states=output_hidden_states,
142
+ return_dict=return_dict,
143
+ )
144
+
145
+ if bidirectional_mask is None:
146
+ return call_og_forward()
147
+ assert isinstance(bidirectional_mask, torch.Tensor)
148
+ attn_modules = _get_attn_modules(model)
149
+ (b, s) = bidirectional_mask.shape
150
+ max_length = attn_modules[0].bias.shape[-1]
151
+ if s > max_length:
152
+ raise ValueError(
153
+ f"bidirectional_mask sequence length (={s}) exceeds the "
154
+ + f"max length allowed by the model ({max_length})."
155
+ )
156
+ assert s <= max_length
157
+ if s < max_length:
158
+ pad = torch.zeros(
159
+ (int(b), int(max_length - s)), dtype=bidirectional_mask.dtype, device=bidirectional_mask.device
160
+ )
161
+ bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
162
+ bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
163
+ for attn_module in attn_modules:
164
+ attn_module.bias.data = torch.logical_or(attn_module.bias.data, bidirectional)
165
+ output = call_og_forward()
166
+ for attn_module in attn_modules:
167
+ attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
168
+ return output
169
+
170
+ def generate(self: CAUSAL_GPT_TYPES, *args: tuple, **kwargs: Dict[str, Any]):
171
+ """Wraps original generate to enable PrefixLM attention."""
172
+ attn_modules = _get_attn_modules(model)
173
+ for attn_module in attn_modules:
174
+ attn_module.bias.data[:] = 1
175
+ output = self._original_generate(*args, **kwargs)
176
+ for attn_module in attn_modules:
177
+ attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
178
+ return output
179
+
180
+ setattr(model, "forward", MethodType(forward, model))
181
+ setattr(model, "generate", MethodType(generate, model))
182
+ setattr(model, "_prefix_lm_converted", True)
183
+ return model
184
+
185
+
186
+ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM:
187
+ """Converts a BLOOM Causal LM to a Prefix LM.
188
+
189
+ Supported HuggingFace model classes:
190
+ - `BloomForCausalLM`
191
+
192
+ See `convert_hf_causal_lm_to_prefix_lm` for more details.
193
+ """
194
+ if hasattr(model, "_prefix_lm_converted"):
195
+ return model
196
+ assert isinstance(model, BloomForCausalLM)
197
+ assert model.config.add_cross_attention == False, "Only supports BLOOM decoder-only models"
198
+
199
+ def _prepare_attn_mask(
200
+ self: BloomModel,
201
+ attention_mask: torch.Tensor,
202
+ bidirectional_mask: Optional[torch.Tensor],
203
+ input_shape: Tuple[int, int],
204
+ past_key_values_length: int,
205
+ ) -> torch.BoolTensor:
206
+ combined_attention_mask = None
207
+ device = attention_mask.device
208
+ (_, src_length) = input_shape
209
+ if src_length > 1:
210
+ combined_attention_mask = _make_causal_mask_bloom(
211
+ input_shape, device=device, past_key_values_length=past_key_values_length
212
+ )
213
+ if bidirectional_mask is not None:
214
+ assert attention_mask.shape == bidirectional_mask.shape
215
+ expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length)
216
+ combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask)
217
+ expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length)
218
+ combined_attention_mask = (
219
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
220
+ )
221
+ return combined_attention_mask
222
+
223
+ def _build_alibi_tensor(
224
+ self: BloomModel, batch_size: int, query_length: int, key_length: int, dtype: torch.dtype, device: torch.device
225
+ ) -> torch.Tensor:
226
+ num_heads = self.config.n_head
227
+ closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
228
+ base = torch.tensor(2 ** (-(2 ** (-(math.log2(closest_power_of_2) - 3)))), device=device, dtype=torch.float32)
229
+ powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
230
+ slopes = torch.pow(base, powers)
231
+ if closest_power_of_2 != num_heads:
232
+ extra_base = torch.tensor(
233
+ 2 ** (-(2 ** (-(math.log2(2 * closest_power_of_2) - 3)))), device=device, dtype=torch.float32
234
+ )
235
+ num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
236
+ extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
237
+ slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
238
+ qa = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1)
239
+ ka = torch.arange(key_length, device=device, dtype=torch.int32).view(1, -1)
240
+ diffs = qa - ka + key_length - query_length
241
+ diffs = -diffs.abs()
242
+ alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(1, 1, query_length, key_length)
243
+ alibi = alibi.expand(batch_size, -1, -1, -1).reshape(-1, query_length, key_length)
244
+ return alibi.to(dtype)
245
+
246
+ KeyValueT = Tuple[torch.Tensor, torch.Tensor]
247
+
248
+ def forward(
249
+ self: BloomModel,
250
+ input_ids: Optional[torch.LongTensor] = None,
251
+ past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
252
+ attention_mask: Optional[torch.Tensor] = None,
253
+ bidirectional_mask: Optional[torch.Tensor] = None,
254
+ head_mask: Optional[torch.LongTensor] = None,
255
+ inputs_embeds: Optional[torch.LongTensor] = None,
256
+ use_cache: Optional[bool] = None,
257
+ output_attentions: Optional[bool] = None,
258
+ output_hidden_states: Optional[bool] = None,
259
+ return_dict: Optional[bool] = None,
260
+ **deprecated_arguments,
261
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
262
+ if deprecated_arguments.pop("position_ids", False) is not False:
263
+ warnings.warn(
264
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. "
265
+ + "You can safely ignore passing `position_ids`.",
266
+ FutureWarning,
267
+ )
268
+ if len(deprecated_arguments) > 0:
269
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
270
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
271
+ output_hidden_states = (
272
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
273
+ )
274
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
275
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
276
+ if input_ids is not None and inputs_embeds is not None:
277
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
278
+ elif input_ids is not None:
279
+ (batch_size, seq_length) = input_ids.shape
280
+ elif inputs_embeds is not None:
281
+ (batch_size, seq_length, _) = inputs_embeds.shape
282
+ else:
283
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
284
+ if past_key_values is None:
285
+ past_key_values = tuple([None] * len(self.h))
286
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
287
+ if inputs_embeds is None:
288
+ inputs_embeds = self.word_embeddings(input_ids)
289
+ hidden_states = self.word_embeddings_layernorm(inputs_embeds)
290
+ presents = () if use_cache else None
291
+ all_self_attentions = () if output_attentions else None
292
+ all_hidden_states = () if output_hidden_states else None
293
+ seq_length_with_past = seq_length
294
+ past_key_values_length = 0
295
+ if past_key_values[0] is not None:
296
+ tmp = past_key_values[0][0]
297
+ past_key_values_length = tmp.shape[2]
298
+ seq_length_with_past = seq_length_with_past + past_key_values_length
299
+ if attention_mask is None:
300
+ attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
301
+ else:
302
+ attention_mask = attention_mask.to(hidden_states.device)
303
+ alibi = self._build_alibi_tensor(
304
+ batch_size=batch_size,
305
+ query_length=seq_length,
306
+ key_length=seq_length_with_past,
307
+ dtype=hidden_states.dtype,
308
+ device=hidden_states.device,
309
+ )
310
+ causal_mask = self._prepare_attn_mask(
311
+ attention_mask,
312
+ bidirectional_mask,
313
+ input_shape=(batch_size, seq_length),
314
+ past_key_values_length=past_key_values_length,
315
+ )
316
+ for (i, (block, layer_past)) in enumerate(zip(self.h, past_key_values)):
317
+ if output_hidden_states:
318
+ hst = (hidden_states,)
319
+ all_hidden_states = all_hidden_states + hst
320
+ if self.gradient_checkpointing and self.training:
321
+ if use_cache:
322
+ logger.warning(
323
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
324
+ )
325
+ use_cache = False
326
+
327
+ def create_custom_forward(module):
328
+ def custom_forward(*inputs):
329
+ return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
330
+
331
+ return custom_forward
332
+
333
+ outputs = torch.utils.checkpoint.checkpoint(
334
+ create_custom_forward(block), hidden_states, alibi, causal_mask, head_mask[i]
335
+ )
336
+ else:
337
+ outputs = block(
338
+ hidden_states,
339
+ layer_past=layer_past,
340
+ attention_mask=causal_mask,
341
+ head_mask=head_mask[i],
342
+ use_cache=use_cache,
343
+ output_attentions=output_attentions,
344
+ alibi=alibi,
345
+ )
346
+ hidden_states = outputs[0]
347
+ if use_cache is True:
348
+ presents = presents + (outputs[1],)
349
+ if output_attentions:
350
+ oa = (outputs[2 if use_cache else 1],)
351
+ all_self_attentions = all_self_attentions + oa
352
+ hidden_states = self.ln_f(hidden_states)
353
+ if output_hidden_states:
354
+ hst = (hidden_states,)
355
+ all_hidden_states = all_hidden_states + hst
356
+ if not return_dict:
357
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
358
+ return BaseModelOutputWithPastAndCrossAttentions(
359
+ last_hidden_state=hidden_states,
360
+ past_key_values=presents,
361
+ hidden_states=all_hidden_states,
362
+ attentions=all_self_attentions,
363
+ )
364
+
365
+ setattr(model.transformer, "_prepare_attn_mask", MethodType(_prepare_attn_mask, model.transformer))
366
+ setattr(model.transformer, "_build_alibi_tensor", MethodType(_build_alibi_tensor, model.transformer))
367
+ setattr(model.transformer, "forward", MethodType(forward, model.transformer))
368
+ KeyValueT = Tuple[torch.Tensor, torch.Tensor]
369
+
370
+ def forward(
371
+ self: BloomForCausalLM,
372
+ input_ids: Optional[torch.LongTensor] = None,
373
+ past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
374
+ attention_mask: Optional[torch.Tensor] = None,
375
+ bidirectional_mask: Optional[torch.Tensor] = None,
376
+ head_mask: Optional[torch.Tensor] = None,
377
+ inputs_embeds: Optional[torch.Tensor] = None,
378
+ labels: Optional[torch.Tensor] = None,
379
+ use_cache: Optional[bool] = None,
380
+ output_attentions: Optional[bool] = None,
381
+ output_hidden_states: Optional[bool] = None,
382
+ return_dict: Optional[bool] = None,
383
+ **deprecated_arguments,
384
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
385
+ """Replacement forward method for BloomCausalLM."""
386
+ if deprecated_arguments.pop("position_ids", False) is not False:
387
+ warnings.warn(
388
+ "`position_ids` have no functionality in BLOOM and will be removed "
389
+ + "in v5.0.0. You can safely ignore passing `position_ids`.",
390
+ FutureWarning,
391
+ )
392
+ if len(deprecated_arguments) > 0:
393
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
394
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
395
+ transformer_outputs = self.transformer(
396
+ input_ids,
397
+ past_key_values=past_key_values,
398
+ attention_mask=attention_mask,
399
+ bidirectional_mask=bidirectional_mask,
400
+ head_mask=head_mask,
401
+ inputs_embeds=inputs_embeds,
402
+ use_cache=use_cache,
403
+ output_attentions=output_attentions,
404
+ output_hidden_states=output_hidden_states,
405
+ return_dict=return_dict,
406
+ )
407
+ hidden_states = transformer_outputs[0]
408
+ lm_logits = self.lm_head(hidden_states)
409
+ loss = None
410
+ if labels is not None:
411
+ shift_logits = lm_logits[..., :-1, :].contiguous()
412
+ shift_labels = labels[..., 1:].contiguous()
413
+ (batch_size, seq_length, vocab_size) = shift_logits.shape
414
+ loss_fct = CrossEntropyLoss()
415
+ loss = loss_fct(
416
+ shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
417
+ )
418
+ if not return_dict:
419
+ output = (lm_logits,) + transformer_outputs[1:]
420
+ return (loss,) + output if loss is not None else output
421
+ return CausalLMOutputWithCrossAttentions(
422
+ loss=loss,
423
+ logits=lm_logits,
424
+ past_key_values=transformer_outputs.past_key_values,
425
+ hidden_states=transformer_outputs.hidden_states,
426
+ attentions=transformer_outputs.attentions,
427
+ )
428
+
429
+ def prepare_inputs_for_generation(
430
+ self: BloomForCausalLM,
431
+ input_ids: torch.LongTensor,
432
+ past: Optional[torch.Tensor] = None,
433
+ attention_mask: Optional[torch.Tensor] = None,
434
+ **kwargs,
435
+ ) -> dict:
436
+ if past:
437
+ input_ids = input_ids[:, -1].unsqueeze(-1)
438
+ bidirectional_mask = None
439
+ if past[0][0].shape[0] == input_ids.shape[0]:
440
+ past = self._convert_to_bloom_cache(past)
441
+ else:
442
+ bidirectional_mask = torch.ones_like(input_ids)
443
+ return {
444
+ "input_ids": input_ids,
445
+ "past_key_values": past,
446
+ "use_cache": True,
447
+ "attention_mask": attention_mask,
448
+ "bidirectional_mask": bidirectional_mask,
449
+ }
450
+
451
+ setattr(model, "forward", MethodType(forward, model))
452
+ setattr(model, "prepare_inputs_for_generation", MethodType(prepare_inputs_for_generation, model))
453
+ setattr(model, "_prefix_lm_converted", True)
454
+ return model
455
+
456
+
457
+ def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM:
458
+ """Converts an OPT Causal LM to a Prefix LM.
459
+
460
+ Supported HuggingFace model classes:
461
+ - `OPTForCausalLM`
462
+
463
+ See `convert_hf_causal_lm_to_prefix_lm` for more details.
464
+ """
465
+ if hasattr(model, "_prefix_lm_converted"):
466
+ return model
467
+ assert isinstance(model, OPTForCausalLM)
468
+ assert model.config.add_cross_attention == False, "Only supports OPT decoder-only models"
469
+ setattr(model, "_original_forward", getattr(model, "forward"))
470
+ setattr(model, "_original_generate", getattr(model, "generate"))
471
+ model.model.decoder.bidirectional_mask = None
472
+
473
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
474
+ combined_attention_mask = None
475
+ if input_shape[-1] > 1:
476
+ if self.bidirectional_mask == "g":
477
+ (bsz, src_length) = input_shape
478
+ combined_attention_mask = torch.zeros(
479
+ (bsz, 1, src_length, src_length + past_key_values_length),
480
+ dtype=inputs_embeds.dtype,
481
+ device=inputs_embeds.device,
482
+ )
483
+ else:
484
+ combined_attention_mask = _make_causal_mask_opt(
485
+ input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
486
+ ).to(inputs_embeds.device)
487
+ if self.bidirectional_mask is not None:
488
+ assert attention_mask.shape == self.bidirectional_mask.shape
489
+ expanded_bidirectional_mask = _expand_mask_opt(
490
+ self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
491
+ ).to(inputs_embeds.device)
492
+ combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask)
493
+ if attention_mask is not None:
494
+ expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
495
+ inputs_embeds.device
496
+ )
497
+ combined_attention_mask = (
498
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
499
+ )
500
+ return combined_attention_mask
501
+
502
+ setattr(
503
+ model.model.decoder,
504
+ "_prepare_decoder_attention_mask",
505
+ MethodType(_prepare_decoder_attention_mask, model.model.decoder),
506
+ )
507
+
508
+ def forward(
509
+ self: OPTForCausalLM,
510
+ input_ids: Optional[torch.LongTensor] = None,
511
+ attention_mask: Optional[torch.Tensor] = None,
512
+ bidirectional_mask: Optional[torch.ByteTensor] = None,
513
+ head_mask: Optional[torch.Tensor] = None,
514
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
515
+ inputs_embeds: Optional[torch.FloatTensor] = None,
516
+ labels: Optional[torch.LongTensor] = None,
517
+ use_cache: Optional[bool] = None,
518
+ output_attentions: Optional[bool] = None,
519
+ output_hidden_states: Optional[bool] = None,
520
+ return_dict: Optional[bool] = None,
521
+ ):
522
+ def call_og_forward():
523
+ return self._original_forward(
524
+ input_ids=input_ids,
525
+ attention_mask=attention_mask,
526
+ head_mask=head_mask,
527
+ past_key_values=past_key_values,
528
+ inputs_embeds=inputs_embeds,
529
+ labels=labels,
530
+ use_cache=use_cache,
531
+ output_attentions=output_attentions,
532
+ output_hidden_states=output_hidden_states,
533
+ return_dict=return_dict,
534
+ )
535
+
536
+ if bidirectional_mask is None:
537
+ return call_og_forward()
538
+ self.model.decoder.bidirectional_mask = bidirectional_mask
539
+ try:
540
+ outputs = call_og_forward()
541
+ except:
542
+ self.model.decoder.bidirectional_mask = None
543
+ raise
544
+ self.model.decoder.bidirectional_mask = None
545
+ return outputs
546
+
547
+ def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[str, Any]):
548
+ """Wraps original generate to enable PrefixLM-style attention."""
549
+ self.model.decoder.bidirectional_mask = "g"
550
+ try:
551
+ output = self._original_generate(*args, **kwargs)
552
+ except:
553
+ self.model.decoder.bidirectional_mask = None
554
+ raise
555
+ self.model.decoder.bidirectional_mask = None
556
+ return output
557
+
558
+ setattr(model, "forward", MethodType(forward, model))
559
+ setattr(model, "generate", MethodType(generate, model))
560
+ setattr(model, "_prefix_lm_converted", True)
561
+ return model
562
+
563
+
564
+ _SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM)
565
+ CAUSAL_LM_TYPES = Union[
566
+ GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM, BloomForCausalLM, OPTForCausalLM
567
+ ]
568
+
569
+
570
+ def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
571
+ """Converts a HuggingFace Causal LM to a Prefix LM.
572
+
573
+ Supported HuggingFace model classes:
574
+ - `GPT2LMHeadModel`
575
+ - `GPTNeoForCausalLM`
576
+ - `GPTNeoXForCausalLM`
577
+ - `GPTJForCausalLM`
578
+ - `BloomForCausalLM`
579
+ - `OPTForCausalLM`
580
+
581
+ Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the
582
+ `generate` method and/or select underlying methods depending on the model class.
583
+
584
+ These changes preserve the model API, but add a new input to `forward`: "bidirectional_mask".
585
+
586
+ Notes on training:
587
+ To actually train the converted model as a Prefix LM, training batches will need to indicate
588
+ the prefix/target structure by including `bidirectional_mask` as part of the batch inputs.
589
+
590
+ **This is not a standard input and requires custom layers either within or after your dataloader.**
591
+
592
+ In addition to adding `bidirectional_mask` to the batch, this custom code should modify `labels`
593
+ such that `batch['labels'][batch['bidirectional_mask'] == 1] == -100`.
594
+ That is, the prefix portion of the sequence should not generate any loss. Loss should only be
595
+ generated by the target portion of the sequence.
596
+
597
+ Notes on `GPTNeoForCausalLM`:
598
+ To simplify the implementation, "global" and "local" attention layers are handled differently.
599
+ For "global" layers, we handle conversion as described above. For "local" layers, which use a
600
+ causal attention mask within a restricted local window, we do not alter the masking.
601
+
602
+ Notes on `forward` method conversion:
603
+ After conversion, the `forward` method will handle a new input, `bidirectional_mask`,
604
+ which should be a [batch_size, seq_length] byte tensor, where 1 indicates token positions
605
+ belonging to the prefix (prefix tokens can attend to one another bidirectionally), and
606
+ 0 indicates token positions belonging to the target.
607
+
608
+ The new `forward` method will incorporate `bidirectional_mask` (if supplied) into the existing
609
+ causal mask, call the original `forward` method, and (if the causal mask is a buffer) reset
610
+ the causal masks before returning the result.
611
+
612
+ Notes on `generate` method conversion:
613
+ After conversion, the `generate` method will have the same signature but will internally
614
+ convert all causal masks to be purely bidirectional, call the original `generate` method, and
615
+ (where appropriate) reset the causal masks before returning the result.
616
+
617
+ This works thanks to the logic of the HuggingFace `generate` API, which first encodes the token
618
+ "prompt" passed to `generate` (which is treated as the prefix) and then sequentially generates
619
+ each new token. Encodings are cached as generation happens, so all prefix tokens can attend to one
620
+ another (as expected in a Prefix LM) and generated tokens can only attend to prefix tokens and
621
+ previously-generated tokens (also as expected in a Prefix LM).
622
+
623
+ To preserve the API, the original methods are renamed to `_original_forward` and
624
+ `_original_generate`, and replaced with new `forward` and `generate` methods that wrap
625
+ them, respectively, although implementation details vary by model class.
626
+ """
627
+ if isinstance(model, _SUPPORTED_GPT_MODELS):
628
+ return _convert_gpt_causal_lm_to_prefix_lm(model)
629
+ elif isinstance(model, BloomForCausalLM):
630
+ return _convert_bloom_causal_lm_to_prefix_lm(model)
631
+ elif isinstance(model, OPTForCausalLM):
632
+ return _convert_opt_causal_lm_to_prefix_lm(model)
633
+ else:
634
+ raise TypeError(
635
+ f"Cannot convert model to Prefix LM. "
636
+ + f"Model does not belong to set of supported HF models:"
637
+ + f"\n{_SUPPORTED_HF_MODELS}"
638
+ )
639
+
640
+
641
+ def add_bidirectional_mask_if_missing(batch: Dict[str, Any]):
642
+ """Attempts to add bidirectional_mask to batch if missing.
643
+
644
+ Raises:
645
+ KeyError if bidirectional_mask is missing and can't be inferred
646
+ """
647
+ if "bidirectional_mask" not in batch:
648
+ if batch.get("mode", None) == "icl_task":
649
+ batch["bidirectional_mask"] = batch["attention_mask"].clone()
650
+ for (i, continuation_indices) in enumerate(batch["continuation_indices"]):
651
+ batch["bidirectional_mask"][i, continuation_indices] = 0
652
+ elif "labels" in batch and "attention_mask" in batch:
653
+ batch["bidirectional_mask"] = torch.logical_and(
654
+ torch.eq(batch["attention_mask"], 1), torch.eq(batch["labels"], -100)
655
+ ).type_as(batch["attention_mask"])
656
+ else:
657
+ raise KeyError("No bidirectional_mask in batch and not sure how to construct one.")
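
Putting the pieces together, a converted model is driven by supplying `bidirectional_mask` alongside the usual inputs, with prefix positions masked out of the loss as described in the training notes above. The sketch below is illustrative only and assumes a small GPT-2 checkpoint; the import path and the prefix/target split are assumptions, not code from this repo.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Module path assumed from this repo's layout.
from llava.model.language_model.mpt.hf_prefixlm_converter import convert_hf_causal_lm_to_prefix_lm

model = convert_hf_causal_lm_to_prefix_lm(GPT2LMHeadModel.from_pretrained("gpt2"))
tok = GPT2Tokenizer.from_pretrained("gpt2")

prefix, target = "Translate to French: cheese ->", " fromage"
enc = tok(prefix + target, return_tensors="pt")
prefix_len = len(tok(prefix)["input_ids"])

# 1 = prefix position (attends bidirectionally), 0 = target position.
bidirectional_mask = torch.zeros_like(enc["input_ids"])
bidirectional_mask[:, :prefix_len] = 1

# Prefix tokens should not generate loss (labels[...] = -100, as noted above).
labels = enc["input_ids"].clone()
labels[bidirectional_mask == 1] = -100

out = model(
    input_ids=enc["input_ids"],
    attention_mask=enc["attention_mask"],
    bidirectional_mask=bidirectional_mask,
    labels=labels,
)
loss = out.loss  # generation treats the prompt as the prefix: model.generate(**tok(prefix, return_tensors="pt"))
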
VILA/llava/model/language_model/mpt/meta_init_context.py ADDED
@@ -0,0 +1,118 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ from contextlib import contextmanager
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+
23
+ @contextmanager
24
+ def init_empty_weights(include_buffers: bool = False):
25
+ """Meta initialization context manager.
26
+
27
+ A context manager under which models are initialized with all parameters
28
+ on the meta device, therefore creating an empty model. Useful when just
29
+ initializing the model would blow the available RAM.
30
+
31
+ Args:
32
+ include_buffers (`bool`, *optional*, defaults to `False`): Whether or
33
+ not to also put all buffers on the meta device while initializing.
34
+
35
+ Example:
36
+ ```python
37
+ import torch.nn as nn
38
+
39
+ # Initialize a model with 100 billions parameters in no time and without using any RAM.
40
+ with init_empty_weights():
41
+ tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
42
+ ```
43
+
44
+ <Tip warning={true}>
45
+
46
+ Any model created under this context manager has no weights. As such you can't do something like
47
+ `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
48
+
49
+ </Tip>
50
+ """
51
+ with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
52
+ yield f
53
+
54
+
55
+ @contextmanager
56
+ def init_on_device(device: torch.device, include_buffers: bool = False):
57
+ """Device initialization context manager.
58
+
59
+ A context manager under which models are initialized with all parameters
60
+ on the specified device.
61
+
62
+ Args:
63
+ device (`torch.device`): Device to initialize all parameters on.
64
+ include_buffers (`bool`, *optional*, defaults to `False`): Whether or
65
+ not to also put all buffers on the meta device while initializing.
66
+
67
+ Example:
68
+ ```python
69
+ import torch.nn as nn
70
+
71
+ with init_on_device(device=torch.device("cuda")):
72
+ tst = nn.Linear(100, 100) # on `cuda` device
73
+ ```
74
+ """
75
+ old_register_parameter = nn.Module.register_parameter
76
+ if include_buffers:
77
+ old_register_buffer = nn.Module.register_buffer
78
+
79
+ def register_empty_parameter(module, name, param):
80
+ old_register_parameter(module, name, param)
81
+ if param is not None:
82
+ param_cls = type(module._parameters[name])
83
+ kwargs = module._parameters[name].__dict__
84
+ module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
85
+
86
+ def register_empty_buffer(module, name, buffer):
87
+ old_register_buffer(module, name, buffer)
88
+ if buffer is not None:
89
+ module._buffers[name] = module._buffers[name].to(device)
90
+
91
+ if include_buffers:
92
+ tensor_constructors_to_patch = {
93
+ torch_function_name: getattr(torch, torch_function_name)
94
+ for torch_function_name in ["empty", "zeros", "ones", "full"]
95
+ }
96
+ else:
97
+ tensor_constructors_to_patch = {}
98
+
99
+ def patch_tensor_constructor(fn):
100
+ def wrapper(*args, **kwargs):
101
+ kwargs["device"] = device
102
+ return fn(*args, **kwargs)
103
+
104
+ return wrapper
105
+
106
+ try:
107
+ nn.Module.register_parameter = register_empty_parameter
108
+ if include_buffers:
109
+ nn.Module.register_buffer = register_empty_buffer
110
+ for torch_function_name in tensor_constructors_to_patch.keys():
111
+ setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
112
+ yield
113
+ finally:
114
+ nn.Module.register_parameter = old_register_parameter
115
+ if include_buffers:
116
+ nn.Module.register_buffer = old_register_buffer
117
+ for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items():
118
+ setattr(torch, torch_function_name, old_torch_function)
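
A short sketch of how these context managers are typically used; the module sizes are arbitrary and the import path is assumed from this repo's layout.

import torch
import torch.nn as nn
from llava.model.language_model.mpt.meta_init_context import init_empty_weights, init_on_device

# Parameters are materialized on the meta device, so no real memory is held.
with init_empty_weights():
    skeleton = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(8)])
assert all(p.device == torch.device("meta") for p in skeleton.parameters())

# With include_buffers=True, buffers and the patched tensor constructors follow the device too.
with init_on_device(torch.device("cpu"), include_buffers=True):
    bn = nn.BatchNorm1d(16)
assert bn.running_mean.device == torch.device("cpu")
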
VILA/llava/model/language_model/mpt/modeling_mpt.py ADDED
@@ -0,0 +1,483 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ """A simple, flexible implementation of a GPT model.
18
+
19
+ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
20
+ """
21
+ import math
22
+ import warnings
23
+ from typing import List, Optional, Tuple, Union
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
29
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
30
+
31
+ from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
32
+ from .attention import attn_bias_shape, build_attn_bias
33
+ from .blocks import MPTBlock
34
+ from .configuration_mpt import MPTConfig
35
+ from .custom_embedding import SharedEmbedding
36
+ from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
37
+ from .meta_init_context import init_empty_weights
38
+ from .norm import NORM_CLASS_REGISTRY
39
+ from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
40
+
41
+ try:
42
+ from .flash_attn_triton import flash_attn_func
43
+ except:
44
+ pass
45
+ Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
46
+
47
+
48
+ class MPTPreTrainedModel(PreTrainedModel):
49
+ config_class = MPTConfig
50
+ base_model_prefix = "model"
51
+ _no_split_modules = ["MPTBlock"]
52
+
53
+
54
+ class MPTModel(MPTPreTrainedModel):
55
+ def __init__(self, config: MPTConfig):
56
+ config._validate_config()
57
+ super().__init__(config)
58
+ self.attn_impl = config.attn_config["attn_impl"]
59
+ self.prefix_lm = config.attn_config["prefix_lm"]
60
+ self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
61
+ self.alibi = config.attn_config["alibi"]
62
+ self.alibi_bias_max = config.attn_config["alibi_bias_max"]
63
+ if config.init_device == "mixed":
64
+ if dist.get_local_rank() == 0:
65
+ config.init_device = "cpu"
66
+ else:
67
+ config.init_device = "meta"
68
+ if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
69
+ norm_options = " | ".join(NORM_CLASS_REGISTRY.keys())
70
+ raise NotImplementedError(
71
+ f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options})."
72
+ )
73
+ norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
74
+ self.embedding_fraction = config.embedding_fraction
75
+ self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
76
+ if not self.alibi:
77
+ self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
78
+ self.emb_drop = nn.Dropout(config.emb_pdrop)
79
+ self.blocks = nn.ModuleList(
80
+ [MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)]
81
+ )
82
+ self.norm_f = norm_class(config.d_model, device=config.init_device)
83
+ if config.init_device != "meta":
84
+ print(
85
+ f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.'
86
+ )
87
+ self.apply(self.param_init_fn)
88
+ self.is_causal = not self.prefix_lm
89
+ self._attn_bias_initialized = False
90
+ self.attn_bias = None
91
+ self.attn_bias_shape = attn_bias_shape(
92
+ self.attn_impl,
93
+ config.n_heads,
94
+ config.max_seq_len,
95
+ self.alibi,
96
+ prefix_lm=self.prefix_lm,
97
+ causal=self.is_causal,
98
+ use_sequence_id=self.attn_uses_sequence_id,
99
+ )
100
+ if config.no_bias:
101
+ for module in self.modules():
102
+ if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
103
+ if config.verbose:
104
+ warnings.warn(f"Removing bias ({module.bias}) from {module}.")
105
+ module.register_parameter("bias", None)
106
+ if config.verbose and config.verbose > 2:
107
+ print(self)
108
+ if "verbose" not in self.config.init_config:
109
+ self.config.init_config["verbose"] = self.config.verbose
110
+ if self.config.init_config["verbose"] > 1:
111
+ init_fn_name = self.config.init_config["name"]
112
+ warnings.warn(f"Using {init_fn_name} initialization.")
113
+ self.gradient_checkpointing = False
114
+
115
+ def get_input_embeddings(self):
116
+ return self.wte
117
+
118
+ def set_input_embeddings(self, value):
119
+ self.wte = value
120
+
121
+ @torch.no_grad()
122
+ def _attn_bias(
123
+ self,
124
+ device,
125
+ dtype,
126
+ attention_mask: Optional[torch.ByteTensor] = None,
127
+ prefix_mask: Optional[torch.ByteTensor] = None,
128
+ sequence_id: Optional[torch.LongTensor] = None,
129
+ ):
130
+ if not self._attn_bias_initialized:
131
+ if self.attn_bias_shape:
132
+ self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
133
+ self.attn_bias = build_attn_bias(
134
+ self.attn_impl,
135
+ self.attn_bias,
136
+ self.config.n_heads,
137
+ self.config.max_seq_len,
138
+ causal=self.is_causal,
139
+ alibi=self.alibi,
140
+ alibi_bias_max=self.alibi_bias_max,
141
+ )
142
+ self._attn_bias_initialized = True
143
+ if self.attn_impl == "flash":
144
+ return (self.attn_bias, attention_mask)
145
+ if self.attn_bias is not None:
146
+ self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
147
+ attn_bias = self.attn_bias
148
+ if self.prefix_lm:
149
+ assert isinstance(attn_bias, torch.Tensor)
150
+ assert isinstance(prefix_mask, torch.Tensor)
151
+ attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
152
+ if self.attn_uses_sequence_id and sequence_id is not None:
153
+ assert isinstance(attn_bias, torch.Tensor)
154
+ attn_bias = self._apply_sequence_id(attn_bias, sequence_id)
155
+ if attention_mask is not None:
156
+ s_k = attention_mask.shape[-1]
157
+ if attn_bias is None:
158
+ attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
159
+ else:
160
+ _s_k = max(0, attn_bias.size(-1) - s_k)
161
+ attn_bias = attn_bias[:, :, :, _s_k:]
162
+ if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
163
+ raise ValueError(
164
+ f"attention_mask shape={attention_mask.shape} "
165
+ + f"and prefix_mask shape={prefix_mask.shape} are not equal."
166
+ )
167
+ min_val = torch.finfo(attn_bias.dtype).min
168
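+ # Padded key positions get the dtype's most negative value so softmax assigns them ~zero attention weight.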
+ attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)
169
+ return (attn_bias, None)
170
+
171
+ def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
172
+ (s_k, s_q) = attn_bias.shape[-2:]
173
+ if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
174
+ raise ValueError(
175
+ "attn_bias does not match the expected shape. "
176
+ + f"The last two dimensions should both be {self.config.max_length} "
177
+ + f"but are {s_k} and {s_q}."
178
+ )
179
+ seq_len = prefix_mask.shape[-1]
180
+ if seq_len > self.config.max_seq_len:
181
+ raise ValueError(f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}")
182
+ attn_bias = attn_bias[..., :seq_len, :seq_len]
183
+ causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)).view(
184
+ 1, 1, seq_len, seq_len
185
+ )
186
+ prefix = prefix_mask.view(-1, 1, 1, seq_len)
187
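+ # Prefix-LM masking: a position may attend bidirectionally within the prefix or causally to earlier tokens; everything else is blocked.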
+ cannot_attend = ~torch.logical_or(causal, prefix.bool())
188
+ min_val = torch.finfo(attn_bias.dtype).min
189
+ attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
190
+ return attn_bias
191
+
192
+ def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor):
193
+ seq_len = sequence_id.shape[-1]
194
+ if seq_len > self.config.max_seq_len:
195
+ raise ValueError(f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}")
196
+ attn_bias = attn_bias[..., :seq_len, :seq_len]
197
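+ # Sequence-id masking for packed batches: a token may only attend to positions carrying the same sequence_id.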
+ cannot_attend = torch.logical_not(
198
+ torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))
199
+ ).unsqueeze(1)
200
+ min_val = torch.finfo(attn_bias.dtype).min
201
+ attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
202
+ return attn_bias
203
+
204
+ def forward(
205
+ self,
206
+ input_ids: torch.LongTensor,
207
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
208
+ attention_mask: Optional[torch.ByteTensor] = None,
209
+ prefix_mask: Optional[torch.ByteTensor] = None,
210
+ sequence_id: Optional[torch.LongTensor] = None,
211
+ return_dict: Optional[bool] = None,
212
+ output_attentions: Optional[bool] = None,
213
+ output_hidden_states: Optional[bool] = None,
214
+ use_cache: Optional[bool] = None,
215
+ inputs_embeds: Optional[torch.Tensor] = None,
216
+ ):
217
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
218
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
219
+ if attention_mask is not None:
220
+ attention_mask = attention_mask.bool()
221
+ if prefix_mask is not None:
222
+ prefix_mask = prefix_mask.bool()
223
+ if not return_dict:
224
+ raise NotImplementedError("return_dict False is not implemented yet for MPT")
225
+ if output_attentions:
226
+ if self.attn_impl != "torch":
227
+ raise NotImplementedError(
228
+ "output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`."
229
+ )
230
+ if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
231
+ raise NotImplementedError("MPT does not support training with left padding.")
232
+ if self.prefix_lm and prefix_mask is None:
233
+ raise ValueError("prefix_mask is a required argument when MPT is configured with prefix_lm=True.")
234
+ if self.training:
235
+ if self.attn_uses_sequence_id and sequence_id is None:
236
+ raise ValueError(
237
+ "sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True "
238
+ + "and the model is in train mode."
239
+ )
240
+ elif self.attn_uses_sequence_id is False and sequence_id is not None:
241
+ warnings.warn(
242
+ "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. "
243
+ + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True."
244
+ )
245
+ if input_ids is not None:
246
+ S = input_ids.size(1)
247
+ assert (
248
+ S <= self.config.max_seq_len
249
+ ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}"
250
+ tok_emb = self.wte(input_ids)
251
+ else:
252
+ assert inputs_embeds is not None
253
+ assert self.alibi, "inputs_embeds is not implemented for MPT unless for alibi."
254
+ S = inputs_embeds.size(1)
255
+ tok_emb = inputs_embeds
256
+ if self.alibi:
257
+ x = tok_emb
258
+ else:
259
+ past_position = 0
260
+ if past_key_values is not None:
261
+ if len(past_key_values) != self.config.n_layers:
262
+ raise ValueError(
263
+ f"past_key_values must provide a past_key_value for each attention "
264
+ + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})."
265
+ )
266
+ past_position = past_key_values[0][0].size(1)
267
+ if self.attn_impl == "torch":
268
+ past_position = past_key_values[0][0].size(3)
269
+ if S + past_position > self.config.max_seq_len:
270
+ raise ValueError(
271
+ f"Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}."
272
+ )
273
+ pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
274
+ if attention_mask is not None:
275
+ pos = torch.clamp(
276
+ pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0
277
+ )
278
+ pos_emb = self.wpe(pos)
279
+ x = tok_emb + pos_emb
280
+ if self.embedding_fraction == 1:
281
+ x = self.emb_drop(x)
282
+ else:
283
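+ # Forward value is unchanged, but only embedding_fraction of the gradient flows back into the embeddings (x.detach() carries the rest).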
+ x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
284
+ assert isinstance(self.emb_drop, nn.Module)
285
+ x = self.emb_drop(x_shrunk)
286
+ (attn_bias, attention_mask) = self._attn_bias(
287
+ device=x.device,
288
+ dtype=torch.float32,
289
+ attention_mask=attention_mask,
290
+ prefix_mask=prefix_mask,
291
+ sequence_id=sequence_id,
292
+ )
293
+ if use_cache and past_key_values is None:
294
+ past_key_values = [() for _ in range(self.config.n_layers)]
295
+ all_hidden_states = () if output_hidden_states else None
296
+ all_self_attns = () if output_attentions else None
297
+ for (b_idx, block) in enumerate(self.blocks):
298
+ if output_hidden_states:
299
+ assert all_hidden_states is not None
300
+ all_hidden_states = all_hidden_states + (x,)
301
+ past_key_value = past_key_values[b_idx] if past_key_values is not None else None
302
+ if self.gradient_checkpointing and self.training:
303
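+ # Gradient checkpointing: recompute this block's activations during backward to trade compute for memory.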
+ (x, attn_weights, past_key_value) = torch.utils.checkpoint.checkpoint(
304
+ block, x, past_key_value, attn_bias, attention_mask, self.is_causal
305
+ )
306
+ else:
307
+ (x, attn_weights, past_key_value) = block(
308
+ x,
309
+ past_key_value=past_key_value,
310
+ attn_bias=attn_bias,
311
+ attention_mask=attention_mask,
312
+ is_causal=self.is_causal,
313
+ )
314
+ if past_key_values is not None:
315
+ past_key_values[b_idx] = past_key_value
316
+ if output_attentions:
317
+ assert all_self_attns is not None
318
+ all_self_attns = all_self_attns + (attn_weights,)
319
+ x = self.norm_f(x)
320
+ if output_hidden_states:
321
+ assert all_hidden_states is not None
322
+ all_hidden_states = all_hidden_states + (x,)
323
+ return BaseModelOutputWithPast(
324
+ last_hidden_state=x,
325
+ past_key_values=past_key_values,
326
+ hidden_states=all_hidden_states,
327
+ attentions=all_self_attns,
328
+ )
329
+
330
+ def param_init_fn(self, module):
331
+ init_fn_name = self.config.init_config["name"]
332
+ MODEL_INIT_REGISTRY[init_fn_name](
333
+ module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config
334
+ )
335
+
336
+ def fsdp_wrap_fn(self, module):
337
+ return isinstance(module, MPTBlock)
338
+
339
+ def activation_checkpointing_fn(self, module):
340
+ return isinstance(module, MPTBlock)
341
+
342
+
343
+ class MPTForCausalLM(MPTPreTrainedModel):
344
+ def __init__(self, config: MPTConfig):
345
+ super().__init__(config)
346
+ if not config.tie_word_embeddings:
347
+ raise ValueError("MPTForCausalLM only supports tied word embeddings")
348
+ print(f"Instantiating an MPTForCausalLM model from {__file__}")
349
+ self.transformer = MPTModel(config)
350
+ for child in self.transformer.children():
351
+ if isinstance(child, torch.nn.ModuleList):
352
+ continue
353
+ if isinstance(child, torch.nn.Module):
354
+ child._fsdp_wrap = True
355
+ self.logit_scale = None
356
+ if config.logit_scale is not None:
357
+ logit_scale = config.logit_scale
358
+ if isinstance(logit_scale, str):
359
+ if logit_scale == "inv_sqrt_d_model":
360
+ logit_scale = 1 / math.sqrt(config.d_model)
361
+ else:
362
+ raise ValueError(
363
+ f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
364
+ )
365
+ self.logit_scale = logit_scale
366
+
367
+ def get_input_embeddings(self):
368
+ return self.transformer.wte
369
+
370
+ def set_input_embeddings(self, value):
371
+ self.transformer.wte = value
372
+
373
+ def get_output_embeddings(self):
374
+ return self.transformer.wte
375
+
376
+ def set_output_embeddings(self, new_embeddings):
377
+ self.transformer.wte = new_embeddings
378
+
379
+ def set_decoder(self, decoder):
380
+ self.transformer = decoder
381
+
382
+ def get_decoder(self):
383
+ return self.transformer
384
+
385
+ def forward(
386
+ self,
387
+ input_ids: torch.LongTensor,
388
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
389
+ attention_mask: Optional[torch.ByteTensor] = None,
390
+ prefix_mask: Optional[torch.ByteTensor] = None,
391
+ sequence_id: Optional[torch.LongTensor] = None,
392
+ labels: Optional[torch.LongTensor] = None,
393
+ return_dict: Optional[bool] = None,
394
+ output_attentions: Optional[bool] = None,
395
+ output_hidden_states: Optional[bool] = None,
396
+ use_cache: Optional[bool] = None,
397
+ inputs_embeds: Optional[torch.FloatTensor] = None,
398
+ ):
399
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
400
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
401
+ if inputs_embeds is not None:
402
+ raise NotImplementedError("inputs_embeds has to be None (for hf/peft support).")
403
+ outputs = self.transformer(
404
+ input_ids=input_ids,
405
+ past_key_values=past_key_values,
406
+ attention_mask=attention_mask,
407
+ prefix_mask=prefix_mask,
408
+ sequence_id=sequence_id,
409
+ return_dict=return_dict,
410
+ output_attentions=output_attentions,
411
+ output_hidden_states=output_hidden_states,
412
+ use_cache=use_cache,
413
+ )
414
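+ # Tied weights double as the output head: the trailing True asks SharedEmbedding to run the unembedding (linear) projection.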
+ logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
415
+ if self.logit_scale is not None:
416
+ if self.logit_scale == 0:
417
+ warnings.warn(
418
+ f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs."
419
+ )
420
+ logits *= self.logit_scale
421
+ loss = None
422
+ if labels is not None:
423
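+ # Next-token prediction: shift labels left by one; the final position gets -100, the cross-entropy ignore index.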
+ labels = torch.roll(labels, shifts=-1)
424
+ labels[:, -1] = -100
425
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
426
+ return CausalLMOutputWithPast(
427
+ loss=loss,
428
+ logits=logits,
429
+ past_key_values=outputs.past_key_values,
430
+ hidden_states=outputs.hidden_states,
431
+ attentions=outputs.attentions,
432
+ )
433
+
434
+ def param_init_fn(self, module):
435
+ init_fn_name = self.config.init_config["name"]
436
+ MODEL_INIT_REGISTRY[init_fn_name](
437
+ module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config
438
+ )
439
+
440
+ def fsdp_wrap_fn(self, module):
441
+ return isinstance(module, MPTBlock)
442
+
443
+ def activation_checkpointing_fn(self, module):
444
+ return isinstance(module, MPTBlock)
445
+
446
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
447
+ if inputs_embeds is not None:
448
+ raise NotImplementedError("inputs_embeds is not implemented for MPT yet")
449
+ attention_mask = kwargs["attention_mask"].bool()
450
+ if attention_mask[:, -1].sum() != attention_mask.shape[0]:
451
+ raise NotImplementedError("MPT does not support generation with right padding.")
452
+ if self.transformer.attn_uses_sequence_id and self.training:
453
+ sequence_id = torch.zeros_like(input_ids[:1])
454
+ else:
455
+ sequence_id = None
456
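+ # With a KV cache present, only the most recent token needs to be forwarded through the model.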
+ if past_key_values is not None:
457
+ input_ids = input_ids[:, -1].unsqueeze(-1)
458
+ if self.transformer.prefix_lm:
459
+ prefix_mask = torch.ones_like(attention_mask)
460
+ if kwargs.get("use_cache") == False:
461
+ raise NotImplementedError("MPT with prefix_lm=True does not support use_cache=False.")
462
+ else:
463
+ prefix_mask = None
464
+ return {
465
+ "input_ids": input_ids,
466
+ "attention_mask": attention_mask,
467
+ "prefix_mask": prefix_mask,
468
+ "sequence_id": sequence_id,
469
+ "past_key_values": past_key_values,
470
+ "use_cache": kwargs.get("use_cache", True),
471
+ }
472
+
473
+ @staticmethod
474
+ def _reorder_cache(past_key_values, beam_idx):
475
+ """Used by HuggingFace generate when using beam search with kv-caching.
476
+
477
+ See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
478
+ for an example in transformers.
479
+ """
480
+ reordered_past = []
481
+ for layer_past in past_key_values:
482
+ reordered_past += [tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)]
483
+ return reordered_past
VILA/llava/model/language_model/mpt/norm.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import torch
18
+
19
+
20
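+ # Mirror autocast's dtype choice for the current device so norm inputs/params match what autocast would produce.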
+ def _cast_if_autocast_enabled(tensor):
21
+ if torch.is_autocast_enabled():
22
+ if tensor.device.type == "cuda":
23
+ dtype = torch.get_autocast_gpu_dtype()
24
+ elif tensor.device.type == "cpu":
25
+ dtype = torch.get_autocast_cpu_dtype()
26
+ else:
27
+ raise NotImplementedError()
28
+ return tensor.to(dtype=dtype)
29
+ return tensor
30
+
31
+
32
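+ # Low-precision LayerNorm: under autocast, inputs and params are cast to the autocast dtype and layer_norm runs with autocast disabled.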
+ class LPLayerNorm(torch.nn.LayerNorm):
33
+ def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
34
+ super().__init__(
35
+ normalized_shape=normalized_shape,
36
+ eps=eps,
37
+ elementwise_affine=elementwise_affine,
38
+ device=device,
39
+ dtype=dtype,
40
+ )
41
+
42
+ def forward(self, x):
43
+ module_device = x.device
44
+ downcast_x = _cast_if_autocast_enabled(x)
45
+ downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
46
+ downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
47
+ with torch.autocast(enabled=False, device_type=module_device.type):
48
+ return torch.nn.functional.layer_norm(
49
+ downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps
50
+ )
51
+
52
+
53
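+ # RMSNorm: x / sqrt(mean(x^2) + eps), optionally scaled by a learned weight; no mean subtraction, unlike LayerNorm.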
+ def rms_norm(x, weight=None, eps=1e-05):
54
+ output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
55
+ if weight is not None:
56
+ return output * weight
57
+ return output
58
+
59
+
60
+ class RMSNorm(torch.nn.Module):
61
+ def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
62
+ super().__init__()
63
+ self.eps = eps
64
+ if weight:
65
+ self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))
66
+ else:
67
+ self.register_parameter("weight", None)
68
+
69
+ def forward(self, x):
70
+ return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
71
+
72
+
73
+ class LPRMSNorm(RMSNorm):
74
+ def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
75
+ super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)
76
+
77
+ def forward(self, x):
78
+ downcast_x = _cast_if_autocast_enabled(x)
79
+ downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
80
+ with torch.autocast(enabled=False, device_type=x.device.type):
81
+ return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
82
+
83
+
84
+ NORM_CLASS_REGISTRY = {
85
+ "layernorm": torch.nn.LayerNorm,
86
+ "low_precision_layernorm": LPLayerNorm,
87
+ "rmsnorm": RMSNorm,
88
+ "low_precision_rmsnorm": LPRMSNorm,
89
+ }
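+ # Example (illustrative): config.norm_type = "low_precision_layernorm" selects LPLayerNorm through this registry in MPTModel.__init__.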
VILA/llava/model/language_model/mpt/param_init_fns.py ADDED
@@ -0,0 +1,399 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import math
18
+ import warnings
19
+ from collections.abc import Sequence
20
+ from functools import partial
21
+ from typing import Optional, Tuple, Union
22
+
23
+ import torch
24
+ from torch import nn
25
+
26
+ from .norm import NORM_CLASS_REGISTRY
27
+
28
+
29
+ def torch_default_param_init_fn_(module: nn.Module, verbose: int = 0, **kwargs):
30
+ del kwargs
31
+ if verbose > 1:
32
+ warnings.warn(f"Initializing network using module's reset_parameters attribute")
33
+ if hasattr(module, "reset_parameters"):
34
+ module.reset_parameters()
35
+
36
+
37
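+ # For fused linear layers (module._fused gives the fused dim and split points), initialize each weight slice independently.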
+ def fused_init_helper_(module: nn.Module, init_fn_):
38
+ _fused = getattr(module, "_fused", None)
39
+ if _fused is None:
40
+ raise RuntimeError(f"Internal logic error")
41
+ (dim, splits) = _fused
42
+ splits = (0, *splits, module.weight.size(dim))
43
+ for (s, e) in zip(splits[:-1], splits[1:]):
44
+ slice_indices = [slice(None)] * module.weight.ndim
45
+ slice_indices[dim] = slice(s, e)
46
+ init_fn_(module.weight[slice_indices])
47
+
48
+
49
+ def generic_param_init_fn_(
50
+ module: nn.Module,
51
+ init_fn_,
52
+ n_layers: int,
53
+ d_model: Optional[int] = None,
54
+ init_div_is_residual: Union[int, float, str, bool] = True,
55
+ emb_init_std: Optional[float] = None,
56
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
57
+ verbose: int = 0,
58
+ **kwargs,
59
+ ):
60
+ del kwargs
61
+ if verbose > 1:
62
+ warnings.warn(f"If model has bias parameters they are initialized to 0.")
63
+ init_div_is_residual = init_div_is_residual
64
+ if init_div_is_residual is False:
65
+ div_is_residual = 1.0
66
+ elif init_div_is_residual is True:
67
+ div_is_residual = math.sqrt(2 * n_layers)
68
+ elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int):
69
+ div_is_residual = init_div_is_residual
70
+ elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
71
+ div_is_residual = float(init_div_is_residual)
72
+ else:
73
+ div_is_residual = 1.0
74
+ raise ValueError(f"Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}")
75
+ if init_div_is_residual is not False:
76
+ if verbose > 1:
77
+ warnings.warn(
78
+ f"Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. "
79
+ + f"Set `init_div_is_residual: false` in init config to disable this."
80
+ )
81
+ if isinstance(module, nn.Linear):
82
+ if hasattr(module, "_fused"):
83
+ fused_init_helper_(module, init_fn_)
84
+ else:
85
+ init_fn_(module.weight)
86
+ if module.bias is not None:
87
+ torch.nn.init.zeros_(module.bias)
88
+ if init_div_is_residual is not False and getattr(module, "_is_residual", False):
89
+ with torch.no_grad():
90
+ module.weight.div_(div_is_residual)
91
+ elif isinstance(module, nn.Embedding):
92
+ if emb_init_std is not None:
93
+ std = emb_init_std
94
+ if std == 0:
95
+ warnings.warn(f"Embedding layer initialized to 0.")
96
+ emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
97
+ if verbose > 1:
98
+ warnings.warn(f"Embedding layer initialized using normal distribution with mean=0 and std={std!r}.")
99
+ elif emb_init_uniform_lim is not None:
100
+ lim = emb_init_uniform_lim
101
+ if isinstance(lim, Sequence):
102
+ if len(lim) > 2:
103
+ raise ValueError(f"Uniform init requires a min and a max limit. User input: {lim}.")
104
+ if lim[0] == lim[1]:
105
+ warnings.warn(f"Embedding layer initialized to {lim[0]}.")
106
+ else:
107
+ if lim == 0:
108
+ warnings.warn(f"Embedding layer initialized to 0.")
109
+ lim = [-lim, lim]
110
+ (a, b) = lim
111
+ emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
112
+ if verbose > 1:
113
+ warnings.warn(f"Embedding layer initialized using uniform distribution in range {lim}.")
114
+ else:
115
+ emb_init_fn_ = init_fn_
116
+ emb_init_fn_(module.weight)
117
+ elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
118
+ if verbose > 1:
119
+ warnings.warn(f"Norm weights are set to 1. If norm layer has a bias it is initialized to 0.")
120
+ if hasattr(module, "weight") and module.weight is not None:
121
+ torch.nn.init.ones_(module.weight)
122
+ if hasattr(module, "bias") and module.bias is not None:
123
+ torch.nn.init.zeros_(module.bias)
124
+ elif isinstance(module, nn.MultiheadAttention):
125
+ if module._qkv_same_embed_dim:
126
+ assert module.in_proj_weight is not None
127
+ assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None)
128
+ assert d_model is not None
129
+ _d = d_model
130
+ splits = (0, _d, 2 * _d, 3 * _d)
131
+ for (s, e) in zip(splits[:-1], splits[1:]):
132
+ init_fn_(module.in_proj_weight[s:e])
133
+ else:
134
+ assert (
135
+ module.q_proj_weight is not None
136
+ and module.k_proj_weight is not None
137
+ and (module.v_proj_weight is not None)
138
+ )
139
+ assert module.in_proj_weight is None
140
+ init_fn_(module.q_proj_weight)
141
+ init_fn_(module.k_proj_weight)
142
+ init_fn_(module.v_proj_weight)
143
+ if module.in_proj_bias is not None:
144
+ torch.nn.init.zeros_(module.in_proj_bias)
145
+ if module.bias_k is not None:
146
+ torch.nn.init.zeros_(module.bias_k)
147
+ if module.bias_v is not None:
148
+ torch.nn.init.zeros_(module.bias_v)
149
+ init_fn_(module.out_proj.weight)
150
+ if init_div_is_residual is not False and getattr(module.out_proj, "_is_residual", False):
151
+ with torch.no_grad():
152
+ module.out_proj.weight.div_(div_is_residual)
153
+ if module.out_proj.bias is not None:
154
+ torch.nn.init.zeros_(module.out_proj.bias)
155
+ else:
156
+ for _ in module.parameters(recurse=False):
157
+ raise NotImplementedError(f"{module.__class__.__name__} parameters are not initialized by param_init_fn.")
158
+
159
+
160
+ def _normal_init_(std, mean=0.0):
161
+ return partial(torch.nn.init.normal_, mean=mean, std=std)
162
+
163
+
164
+ def _normal_param_init_fn_(
165
+ module: nn.Module,
166
+ std: float,
167
+ n_layers: int,
168
+ d_model: Optional[int] = None,
169
+ init_div_is_residual: Union[int, float, str, bool] = True,
170
+ emb_init_std: Optional[float] = None,
171
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
172
+ verbose: int = 0,
173
+ **kwargs,
174
+ ):
175
+ del kwargs
176
+ init_fn_ = _normal_init_(std=std)
177
+ if verbose > 1:
178
+ warnings.warn(f"Using torch.nn.init.normal_ init fn mean=0.0, std={std}")
179
+ generic_param_init_fn_(
180
+ module=module,
181
+ init_fn_=init_fn_,
182
+ d_model=d_model,
183
+ n_layers=n_layers,
184
+ init_div_is_residual=init_div_is_residual,
185
+ emb_init_std=emb_init_std,
186
+ emb_init_uniform_lim=emb_init_uniform_lim,
187
+ verbose=verbose,
188
+ )
189
+
190
+
191
+ def baseline_param_init_fn_(
192
+ module: nn.Module,
193
+ init_std: float,
194
+ n_layers: int,
195
+ d_model: Optional[int] = None,
196
+ init_div_is_residual: Union[int, float, str, bool] = True,
197
+ emb_init_std: Optional[float] = None,
198
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
199
+ verbose: int = 0,
200
+ **kwargs,
201
+ ):
202
+ del kwargs
203
+ if init_std is None:
204
+ raise ValueError(
205
+ "You must set model.init_config['init_std'] to a float value to use the default initialization scheme."
206
+ )
207
+ _normal_param_init_fn_(
208
+ module=module,
209
+ std=init_std,
210
+ d_model=d_model,
211
+ n_layers=n_layers,
212
+ init_div_is_residual=init_div_is_residual,
213
+ emb_init_std=emb_init_std,
214
+ emb_init_uniform_lim=emb_init_uniform_lim,
215
+ verbose=verbose,
216
+ )
217
+
218
+
219
+ def small_param_init_fn_(
220
+ module: nn.Module,
221
+ n_layers: int,
222
+ d_model: int,
223
+ init_div_is_residual: Union[int, float, str, bool] = True,
224
+ emb_init_std: Optional[float] = None,
225
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
226
+ verbose: int = 0,
227
+ **kwargs,
228
+ ):
229
+ del kwargs
230
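+ # "Small init": std = sqrt(2 / (5 * d_model)), the scheme of Nguyen & Salazar (2019) also used in GPT-NeoX.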
+ std = math.sqrt(2 / (5 * d_model))
231
+ _normal_param_init_fn_(
232
+ module=module,
233
+ std=std,
234
+ d_model=d_model,
235
+ n_layers=n_layers,
236
+ init_div_is_residual=init_div_is_residual,
237
+ emb_init_std=emb_init_std,
238
+ emb_init_uniform_lim=emb_init_uniform_lim,
239
+ verbose=verbose,
240
+ )
241
+
242
+
243
+ def neox_param_init_fn_(
244
+ module: nn.Module,
245
+ n_layers: int,
246
+ d_model: int,
247
+ emb_init_std: Optional[float] = None,
248
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
249
+ verbose: int = 0,
250
+ **kwargs,
251
+ ):
252
+ """From section 2.3.1 of GPT-NeoX-20B:
253
+
254
+ An Open-Source Autoregressive Language Model — Black et al. (2022)
255
+ see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
256
+ and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
257
+ """
258
+ del kwargs
259
+ residual_div = n_layers / math.sqrt(10)
260
+ if verbose > 1:
261
+ warnings.warn(f"setting init_div_is_residual to {residual_div}")
262
+ small_param_init_fn_(
263
+ module=module,
264
+ d_model=d_model,
265
+ n_layers=n_layers,
266
+ init_div_is_residual=residual_div,
267
+ emb_init_std=emb_init_std,
268
+ emb_init_uniform_lim=emb_init_uniform_lim,
269
+ verbose=verbose,
270
+ )
271
+
272
+
273
+ def kaiming_uniform_param_init_fn_(
274
+ module: nn.Module,
275
+ n_layers: int,
276
+ d_model: Optional[int] = None,
277
+ init_div_is_residual: Union[int, float, str, bool] = True,
278
+ emb_init_std: Optional[float] = None,
279
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
280
+ init_gain: float = 0,
281
+ fan_mode: str = "fan_in",
282
+ init_nonlinearity: str = "leaky_relu",
283
+ verbose: int = 0,
284
+ **kwargs,
285
+ ):
286
+ del kwargs
287
+ if verbose > 1:
288
+ warnings.warn(
289
+ f"Using nn.init.kaiming_uniform_ init fn with parameters: "
290
+ + f"a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}"
291
+ )
292
+ kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
293
+ generic_param_init_fn_(
294
+ module=module,
295
+ init_fn_=kaiming_uniform_,
296
+ d_model=d_model,
297
+ n_layers=n_layers,
298
+ init_div_is_residual=init_div_is_residual,
299
+ emb_init_std=emb_init_std,
300
+ emb_init_uniform_lim=emb_init_uniform_lim,
301
+ verbose=verbose,
302
+ )
303
+
304
+
305
+ def kaiming_normal_param_init_fn_(
306
+ module: nn.Module,
307
+ n_layers: int,
308
+ d_model: Optional[int] = None,
309
+ init_div_is_residual: Union[int, float, str, bool] = True,
310
+ emb_init_std: Optional[float] = None,
311
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
312
+ init_gain: float = 0,
313
+ fan_mode: str = "fan_in",
314
+ init_nonlinearity: str = "leaky_relu",
315
+ verbose: int = 0,
316
+ **kwargs,
317
+ ):
318
+ del kwargs
319
+ if verbose > 1:
320
+ warnings.warn(
321
+ f"Using nn.init.kaiming_normal_ init fn with parameters: "
322
+ + f"a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}"
323
+ )
324
+ kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
325
+ generic_param_init_fn_(
326
+ module=module,
327
+ init_fn_=kaiming_normal_,
328
+ d_model=d_model,
329
+ n_layers=n_layers,
330
+ init_div_is_residual=init_div_is_residual,
331
+ emb_init_std=emb_init_std,
332
+ emb_init_uniform_lim=emb_init_uniform_lim,
333
+ verbose=verbose,
334
+ )
335
+
336
+
337
+ def xavier_uniform_param_init_fn_(
338
+ module: nn.Module,
339
+ n_layers: int,
340
+ d_model: Optional[int] = None,
341
+ init_div_is_residual: Union[int, float, str, bool] = True,
342
+ emb_init_std: Optional[float] = None,
343
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
344
+ init_gain: float = 0,
345
+ verbose: int = 0,
346
+ **kwargs,
347
+ ):
348
+ del kwargs
349
+ xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
350
+ if verbose > 1:
351
+ warnings.warn(f"Using torch.nn.init.xavier_uniform_ init fn with parameters: " + f"gain={init_gain}")
352
+ generic_param_init_fn_(
353
+ module=module,
354
+ init_fn_=xavier_uniform_,
355
+ d_model=d_model,
356
+ n_layers=n_layers,
357
+ init_div_is_residual=init_div_is_residual,
358
+ emb_init_std=emb_init_std,
359
+ emb_init_uniform_lim=emb_init_uniform_lim,
360
+ verbose=verbose,
361
+ )
362
+
363
+
364
+ def xavier_normal_param_init_fn_(
365
+ module: nn.Module,
366
+ n_layers: int,
367
+ d_model: Optional[int] = None,
368
+ init_div_is_residual: Union[int, float, str, bool] = True,
369
+ emb_init_std: Optional[float] = None,
370
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
371
+ init_gain: float = 0,
372
+ verbose: int = 0,
373
+ **kwargs,
374
+ ):
375
+ xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
376
+ if verbose > 1:
377
+ warnings.warn(f"Using torch.nn.init.xavier_normal_ init fn with parameters: " + f"gain={init_gain}")
378
+ generic_param_init_fn_(
379
+ module=module,
380
+ init_fn_=xavier_normal_,
381
+ d_model=d_model,
382
+ n_layers=n_layers,
383
+ init_div_is_residual=init_div_is_residual,
384
+ emb_init_std=emb_init_std,
385
+ emb_init_uniform_lim=emb_init_uniform_lim,
386
+ verbose=verbose,
387
+ )
388
+
389
+
390
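+ # Looked up by MPTModel.param_init_fn / MPTForCausalLM.param_init_fn via config.init_config["name"].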
+ MODEL_INIT_REGISTRY = {
391
+ "default_": torch_default_param_init_fn_,
392
+ "baseline_": baseline_param_init_fn_,
393
+ "kaiming_uniform_": kaiming_uniform_param_init_fn_,
394
+ "kaiming_normal_": kaiming_normal_param_init_fn_,
395
+ "neox_init_": neox_param_init_fn_,
396
+ "small_init_": small_param_init_fn_,
397
+ "xavier_uniform_": xavier_uniform_param_init_fn_,
398
+ "xavier_normal_": xavier_normal_param_init_fn_,
399
+ }
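+ # Example (illustrative): config.init_config = {"name": "kaiming_normal_", "fan_mode": "fan_in"} picks
+ # kaiming_normal_param_init_fn_ from this registry when self.apply(self.param_init_fn) runs.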
VILA/llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc ADDED
Binary file (1.51 kB). View file
 
VILA/llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc ADDED
Binary file (1.41 kB). View file
 
VILA/llava/model/multimodal_encoder/__pycache__/image_processor.cpython-310.pyc ADDED
Binary file (18.8 kB). View file
 
VILA/llava/model/multimodal_encoder/__pycache__/intern_encoder.cpython-310.pyc ADDED
Binary file (2.71 kB). View file
 
VILA/llava/model/multimodal_encoder/__pycache__/radio_encoder.cpython-310.pyc ADDED
Binary file (8.6 kB). View file
 
VILA/llava/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc ADDED
Binary file (1.47 kB). View file
 
VILA/llava/model/multimodal_encoder/__pycache__/vision_encoder.cpython-310.pyc ADDED
Binary file (6 kB). View file
 
VILA/llava/model/multimodal_encoder/__pycache__/visualize_features.cpython-310.pyc ADDED
Binary file (9.31 kB). View file
 
VILA/llava/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,64 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
18
+
19
+ import os
20
+
21
+ from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
22
+
23
+ from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
24
+ from .intern_encoder import InternVisionTower
25
+ from .radio_encoder import RADIOVisionTower
26
+ from .siglip_encoder import SiglipVisionTower, SiglipVisionTowerS2
27
+
28
+
29
+ def build_vision_tower(model_name_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
30
+ ## skip vision tower instantiation
31
+ if model_name_or_path is None:
32
+ return None
33
+
34
+ vision_tower_arch = None
35
+ if config.resume_path and "radio" not in model_name_or_path:
36
+ assert os.path.exists(model_name_or_path), f"Resume vision tower path {model_name_or_path} does not exist!"
37
+ vision_tower_cfg = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
38
+ vision_tower_arch = vision_tower_cfg.architectures[0].lower()
39
+ vision_tower_name = vision_tower_arch if vision_tower_arch is not None else model_name_or_path
40
+
41
+ use_s2 = getattr(config, "s2", False)
42
+
43
+ if "intern" in vision_tower_name.lower():
44
+ if hasattr(config, "drop_path_rate"):
45
+ vision_tower = InternVisionTower(model_name_or_path, config=config, drop_path_rate=config.drop_path_rate)
46
+ else:
47
+ vision_tower = InternVisionTower(model_name_or_path, config=config, drop_path_rate=0.0)
48
+ elif "radio" in vision_tower_name:
49
+ vision_tower = RADIOVisionTower(model_name_or_path, config)
50
+ elif "clip" in vision_tower_name:
51
+ if use_s2:
52
+ vision_tower = CLIPVisionTowerS2(model_name_or_path, config)
53
+ else:
54
+ vision_tower = CLIPVisionTower(model_name_or_path, config)
55
+ elif "siglip" in vision_tower_name:
56
+ if use_s2:
57
+ vision_tower = SiglipVisionTowerS2(model_name_or_path, config)
58
+ else:
59
+ vision_tower = SiglipVisionTower(model_name_or_path, config)
60
+ else:
61
+ raise ValueError(f"Unknown vision tower: {model_name_or_path}")
62
+
63
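+ # Record the tower's output feature width on the shared config; S2 towers expose their own aggregated hidden_size attribute.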
+ config.mm_hidden_size = vision_tower.config.hidden_size if not use_s2 else vision_tower.hidden_size
64
+ return vision_tower
VILA/llava/model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,42 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
18
+ import torch
19
+ from transformers import CLIPImageProcessor, CLIPVisionModel, PretrainedConfig
20
+
21
+ from llava.model.multimodal_encoder.vision_encoder import VisionTower, VisionTowerS2
22
+
23
+
24
+ class CLIPVisionTower(VisionTower):
25
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig):
26
+ super().__init__(model_name_or_path, config)
27
+ self.image_processor = CLIPImageProcessor.from_pretrained(model_name_or_path)
28
+ self.vision_tower = CLIPVisionModel.from_pretrained(model_name_or_path, torch_dtype=eval(config.model_dtype))
29
+ self.is_loaded = True
30
+
31
+
32
+ class CLIPVisionTowerS2(VisionTowerS2):
33
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig):
34
+ super().__init__(model_name_or_path, config)
35
+ self.image_processor = CLIPImageProcessor.from_pretrained(model_name_or_path)
36
+ self.vision_tower = CLIPVisionModel.from_pretrained(model_name_or_path, torch_dtype=eval(config.model_dtype))
37
+
38
+ # Make sure it crops/resizes the image to the largest scale in self.scales to maintain high-res information
39
+ self.image_processor.size["shortest_edge"] = self.scales[-1]
40
+ self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.scales[-1]
41
+
42
+ self.is_loaded = True
VILA/llava/model/multimodal_encoder/image_processor.py ADDED
@@ -0,0 +1,546 @@
1
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Image processor class for RADIO."""
16
+ import math
17
+ from copy import deepcopy
18
+ from itertools import product
19
+ from typing import Any, Dict, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import PIL
23
+ from PIL.Image import Image
24
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
25
+ from transformers.image_transforms import convert_to_rgb, pad, resize, to_channel_dimension_format
26
+ from transformers.image_utils import (
27
+ IMAGENET_DEFAULT_MEAN,
28
+ IMAGENET_DEFAULT_STD,
29
+ ChannelDimension,
30
+ ImageInput,
31
+ PILImageResampling,
32
+ get_image_size,
33
+ infer_channel_dimension_format,
34
+ is_scaled_image,
35
+ make_list_of_images,
36
+ to_numpy_array,
37
+ valid_images,
38
+ )
39
+ from transformers.utils import (
40
+ TensorType,
41
+ is_tf_available,
42
+ is_torch_available,
43
+ is_torchvision_available,
44
+ logging,
45
+ requires_backends,
46
+ )
47
+
48
+ if is_torch_available():
49
+ import torch
50
+ import torch.nn.functional as F
51
+
52
+ if is_torchvision_available():
53
+ from torchvision.ops.boxes import batched_nms
54
+
55
+ # if is_tf_available():
56
+ # import tensorflow as tf
57
+ # from tensorflow.experimental import numpy as tnp
58
+
59
+ # from ...tf_utils import flatten, shape_list
60
+
61
+ logger = logging.get_logger(__name__)
62
+
63
+
64
+ def rank_print(s):
65
+ rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
66
+ print(f"[Rank {rank}] {s}")
67
+
68
+
69
+ class ImageProcessor(BaseImageProcessor):
70
+ r"""
71
+ Constructs an image processor.
72
+
73
+ Args:
74
+ do_resize (`bool`, *optional*, defaults to `True`):
75
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
76
+ `do_resize` parameter in the `preprocess` method.
77
+ size (`dict`, *optional*, defaults to `{"longest_edge": 1024}`):
78
+ Size of the output image after resizing. If "longest_edge" is specified, resizes the longest edge of the image to match
79
+ `size["longest_edge"]` while maintaining the aspect ratio. If "width" and "height" are specified, resizes the image
80
+ to that size, possibly changing the aspect ratio. Can be overridden by the `size` parameter in the
81
+ `preprocess` method.
82
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
83
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
84
+ `preprocess` method.
85
+ do_rescale (`bool`, *optional*, defaults to `True`):
86
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
87
+ `do_rescale` parameter in the `preprocess` method.
88
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
89
+ Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
90
+ overridden by the `rescale_factor` parameter in the `preprocess` method.
91
+ do_normalize (`bool`, *optional*, defaults to `True`):
92
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
93
+ method.
94
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
95
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
96
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
98
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
99
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
100
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
102
+ do_pad (`bool`, *optional*, defaults to `True`):
103
+ Whether to pad the image to the specified `pad_size`. Can be overridden by the `do_pad` parameter in the
104
+ `preprocess` method.
105
+ pad_size (`dict`, *optional*, defaults to `{"height": 1024, "width": 1024}`):
106
+ Size of the output image after padding. Can be overridden by the `pad_size` parameter in the `preprocess`
107
+ method.
108
+ pad_value (`float` or `Iterable[float]`, *optional*, defaults to `0.`):
109
+ Value of padded pixels.
110
+ pad_multiple (`int`, *optional*, defaults to `None`):
111
+ Pad to a multiple of specified number.
112
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
113
+ Whether to convert the image to RGB.
114
+ """
115
+
116
+ model_input_names = ["pixel_values"]
117
+
118
+ def __init__(
119
+ self,
120
+ do_resize: bool = True,
121
+ size: Dict[str, int] = None,
122
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
123
+ do_rescale: bool = True,
124
+ rescale_factor: Union[int, float] = 1 / 255,
125
+ do_normalize: bool = True,
126
+ image_mean: Optional[Union[float, List[float]]] = None,
127
+ image_std: Optional[Union[float, List[float]]] = None,
128
+ do_pad: bool = True,
129
+ pad_size: int = None,
130
+ pad_multiple: int = None,
131
+ pad_value: Optional[Union[float, List[float]]] = 0.0,
132
+ do_convert_rgb: bool = True,
133
+ **kwargs,
134
+ ) -> None:
135
+ super().__init__(**kwargs)
137
+ size = size if size is not None else {"longest_edge": 1024}
138
+ size = get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size
139
+
140
+ if pad_size is not None and pad_multiple is not None:
141
+ raise ValueError("pad_size and pad_multiple should not be set at the same time.")
142
+
143
+ pad_size = (
144
+ pad_size if pad_size is not None else {"height": 1024, "width": 1024} if pad_multiple is not None else None
145
+ )
146
+ if do_pad:
147
+ pad_size = get_size_dict(pad_size, default_to_square=True)
148
+
149
+ self.do_resize = do_resize
150
+ self.size = size
151
+ self.resample = resample
152
+ self.do_rescale = do_rescale
153
+ self.rescale_factor = rescale_factor
154
+ self.do_normalize = do_normalize
155
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
156
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
157
+ self.do_pad = do_pad
158
+ self.pad_multiple = pad_multiple
159
+ self.pad_size = pad_size
160
+ self.pad_value = tuple(pad_value) if isinstance(pad_value, list) else pad_value
161
+ self.do_convert_rgb = do_convert_rgb
162
+ self._valid_processor_keys = [
163
+ "images",
164
+ "segmentation_maps",
165
+ "do_resize",
166
+ "size",
167
+ "resample",
168
+ "do_rescale",
169
+ "rescale_factor",
170
+ "do_normalize",
171
+ "image_mean",
172
+ "image_std",
173
+ "do_pad",
174
+ "pad_size",
175
+ "do_convert_rgb",
176
+ "return_tensors",
177
+ "data_format",
178
+ "input_data_format",
179
+ ]
180
+
181
+ def pad_image(
182
+ self,
183
+ image: np.ndarray,
184
+ pad_size: Dict[str, int],
185
+ data_format: Optional[Union[str, ChannelDimension]] = None,
186
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
187
+ **kwargs,
188
+ ) -> np.ndarray:
189
+ """
190
+ Pad an image to `(pad_size["height"], pad_size["width"])` to the right and bottom.
191
+
192
+ Args:
193
+ image (`np.ndarray`):
194
+ Image to pad.
195
+ pad_size (`Dict[str, int]`):
196
+ Size of the output image after padding.
197
+ data_format (`str` or `ChannelDimension`, *optional*):
198
+ The data format of the image. Can be either "channels_first" or "channels_last". If `None`, the
199
+ `data_format` of the `image` will be used.
200
+ input_data_format (`str` or `ChannelDimension`, *optional*):
201
+ The channel dimension format of the input image. If not provided, it will be inferred.
202
+ """
203
+ output_height, output_width = pad_size["height"], pad_size["width"]
204
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
205
+
206
+ pad_width = output_width - input_width
207
+ pad_height = output_height - input_height
208
+
209
+ padded_image = pad(
210
+ image,
211
+ ((0, pad_height), (0, pad_width)),
212
+ data_format=data_format,
213
+ input_data_format=input_data_format,
214
+ constant_values=self.pad_value,
215
+ **kwargs,
216
+ )
217
+ return padded_image
218
+
219
+ def _get_preprocess_shape(self, old_shape: Tuple[int, int], longest_edge: int):
220
+ """
221
+ Compute the output size given input size and target long side length.
222
+ """
223
+ oldh, oldw = old_shape
224
+ scale = longest_edge * 1.0 / max(oldh, oldw)
225
+ newh, neww = oldh * scale, oldw * scale
226
+ newh = int(newh + 0.5)
227
+ neww = int(neww + 0.5)
228
+ return (newh, neww)
229
+
230
+ def resize(
231
+ self,
232
+ image: np.ndarray,
233
+ size: Dict[str, int],
234
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
235
+ data_format: Optional[Union[str, ChannelDimension]] = None,
236
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
237
+ **kwargs,
238
+ ) -> np.ndarray:
239
+ """
240
+ Resize an image to `(size["height"], size["width"])`.
241
+
242
+ Args:
243
+ image (`np.ndarray`):
244
+ Image to resize.
245
+ size (`Dict[str, int]`):
246
+ Dictionary in the format `{"longest_edge": int}` or `{"width": int, "height": int}` specifying the size
247
+ of the output image. If "longest_edge" is specified, resizes the longest edge of the image to match
248
+ `size["longest_edge"]` while maintaining the aspect ratio. If "width" and "height" are specified, resizes the image
249
+ to that size, possibly changing the aspect ratio.
250
+ resample:
251
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
252
+ data_format (`ChannelDimension` or `str`, *optional*):
253
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
254
+ image is used. Can be one of:
255
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
256
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
257
+ input_data_format (`ChannelDimension` or `str`, *optional*):
258
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
259
+ from the input image. Can be one of:
260
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
261
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
262
+
263
+ Returns:
264
+ `np.ndarray`: The resized image.
265
+ """
266
+ size = get_size_dict(size)
267
+ if "longest_edge" not in size:
268
+ if "width" not in size or "height" not in size:
269
+ raise ValueError(
270
+ f"The `size` dictionary must contain the key `longest_edge`, or `width` and `height`. Got {size.keys()}"
271
+ )
272
+ input_size = get_image_size(image, channel_dim=input_data_format)
273
+ if "longest_edge" in size:
274
+ output_height, output_width = self._get_preprocess_shape(input_size, size["longest_edge"])
275
+ else:
276
+ output_height, output_width = size["height"], size["width"]
277
+ return resize(
278
+ image,
279
+ size=(output_height, output_width),
280
+ resample=resample,
281
+ data_format=data_format,
282
+ input_data_format=input_data_format,
283
+ **kwargs,
284
+ )
285
+
286
+ def _preprocess(
287
+ self,
288
+ image: ImageInput,
289
+ do_resize: bool,
290
+ do_rescale: bool,
291
+ do_normalize: bool,
292
+ size: Optional[Dict[str, int]] = None,
293
+ resample: PILImageResampling = None,
294
+ rescale_factor: Optional[float] = None,
295
+ image_mean: Optional[Union[float, List[float]]] = None,
296
+ image_std: Optional[Union[float, List[float]]] = None,
297
+ do_pad: Optional[bool] = None,
298
+ pad_size: Optional[Dict[str, int]] = None,
299
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
300
+ ):
301
+ if do_resize:
302
+ image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
303
+ reshaped_input_size = get_image_size(image, channel_dim=input_data_format)
304
+
305
+ if do_rescale:
306
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
307
+
308
+ if do_normalize:
309
+ image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
310
+
311
+ if do_pad:
312
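+ # pad_multiple mode: round height and width up to the next multiple before padding on the right/bottom.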
+ if self.pad_multiple:
313
+ h, w = get_image_size(image, channel_dim=input_data_format)
314
+ pad_size = {
315
+ "height": math.ceil(h / self.pad_multiple) * self.pad_multiple,
316
+ "width": math.ceil(w / self.pad_multiple) * self.pad_multiple,
317
+ }
318
+
319
+ image = self.pad_image(image=image, pad_size=pad_size, input_data_format=input_data_format)
320
+
321
+ return image, reshaped_input_size
322
+
323
+ def _preprocess_image(
324
+ self,
325
+ image: ImageInput,
326
+ do_resize: Optional[bool] = None,
327
+ size: Dict[str, int] = None,
328
+ resample: PILImageResampling = None,
329
+ do_rescale: bool = None,
330
+ rescale_factor: Optional[float] = None,
331
+ do_normalize: Optional[bool] = None,
332
+ image_mean: Optional[Union[float, List[float]]] = None,
333
+ image_std: Optional[Union[float, List[float]]] = None,
334
+ do_pad: Optional[bool] = None,
335
+ pad_size: Optional[Dict[str, int]] = None,
336
+ do_convert_rgb: Optional[bool] = None,
337
+ data_format: Optional[Union[str, ChannelDimension]] = None,
338
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
339
+ ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]:
340
+ # image = to_numpy_array(image)
341
+
342
+ # import time
343
+ # if int(time.time()*1000) % 10 == 0:
344
+ # # create an PIL image of size 1x1
345
+ # image = PIL.Image.new('RGB', (1, 1))
346
+
347
+ if isinstance(image, Image):
348
+ # PIL always uses Channels Last.
349
+ input_data_format = ChannelDimension.LAST
350
+
351
+ # PIL RGBA images are converted to RGB
352
+ # mode_before = image.mode
353
+ if do_convert_rgb:
354
+ image = convert_to_rgb(image)
355
+
356
+ # All transformations expect numpy arrays.
357
+ image_ = image
358
+ image = to_numpy_array(image)
359
+
360
+ # if isinstance(image_, np.ndarray):
361
+ # rank_print(f"preprocess image type={type(image_)} shape={image_.shape} array shape={image.shape}")
362
+ # elif isinstance(image_, Image):
363
+ # rank_print(f"preprocessimage type={type(image_)} size={image_.size} mode={image_.mode} array shape={image.shape}")
364
+ # else:
365
+ # rank_print(f"preprocess unknown image type={type(image_)} array shape={image.shape}")
366
+
367
+ if len(image.shape) == 2:
368
+ h, w = image.shape
369
+ ret = np.empty((h, w, 3), dtype=np.uint8)
370
+ ret[:, :, 0] = image
371
+ ret[:, :, 1] = image
372
+ ret[:, :, 2] = image
373
+ image = ret
374
+ rank_print(f"preprocess new image shape={image.shape}")
375
+ elif len(image.shape) == 3 and image.shape[-1] == 1:
376
+ # derive h, w from the (H, W, 1) image before expanding it to three channels
+ h, w = image.shape[:2]
+ ret = np.empty((h, w, 3), dtype=np.uint8)
377
+ ret[:, :, 0] = image[:, :, 0]
378
+ ret[:, :, 1] = image[:, :, 0]
379
+ ret[:, :, 2] = image[:, :, 0]
380
+ image = ret
381
+ rank_print(f"preprocess new image shape={image.shape}")
382
+
383
+ if is_scaled_image(image) and do_rescale:
384
+ logger.warning_once(
385
+ "It looks like you are trying to rescale already rescaled images. If the input"
386
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
387
+ )
388
+
389
+ if input_data_format is None:
390
+ input_data_format = infer_channel_dimension_format(image)
391
+
392
+ original_size = get_image_size(image, channel_dim=input_data_format)
393
+
394
+ image, reshaped_input_size = self._preprocess(
395
+ image=image,
396
+ do_resize=do_resize,
397
+ size=size,
398
+ resample=resample,
399
+ do_rescale=do_rescale,
400
+ rescale_factor=rescale_factor,
401
+ do_normalize=do_normalize,
402
+ image_mean=image_mean,
403
+ image_std=image_std,
404
+ do_pad=do_pad,
405
+ pad_size=pad_size,
406
+ input_data_format=input_data_format,
407
+ )
408
+
409
+ if data_format is not None:
410
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
411
+
412
+ # rank_print(f"preprocess original_size={original_size} reshaped_input_size={reshaped_input_size} image shape={image.shape} type={type(image)}")
413
+
414
+ # if image is a single channel convert to rgb
415
+ if do_convert_rgb and image.shape[0] == 1:
416
+ c, h, w = image.shape
417
+ ret = np.empty((3, h, w), dtype=image.dtype)  # match the dtype of the (possibly rescaled/normalized) image
418
+ ret[0, :, :] = image[0, :, :]
419
+ ret[1, :, :] = image[0, :, :]
420
+ ret[2, :, :] = image[0, :, :]
421
+ image = ret
422
+ rank_print(f"preprocess final: {image.shape}")
423
+
424
+ return image, original_size, reshaped_input_size
425
+
426
+ def preprocess(
427
+ self,
428
+ images: ImageInput,
429
+ do_resize: Optional[bool] = None,
430
+ size: Optional[Dict[str, int]] = None,
431
+ resample: Optional["PILImageResampling"] = None,
432
+ do_rescale: Optional[bool] = None,
433
+ rescale_factor: Optional[Union[int, float]] = None,
434
+ do_normalize: Optional[bool] = None,
435
+ image_mean: Optional[Union[float, List[float]]] = None,
436
+ image_std: Optional[Union[float, List[float]]] = None,
437
+ do_pad: Optional[bool] = None,
438
+ pad_size: Optional[Dict[str, int]] = None,
439
+ do_convert_rgb: Optional[bool] = None,
440
+ return_tensors: Optional[Union[str, TensorType]] = None,
441
+ data_format: ChannelDimension = ChannelDimension.FIRST,
442
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
443
+ **kwargs,
444
+ ):
445
+ """
446
+ Preprocess an image or batch of images.
447
+
448
+ Args:
449
+ images (`ImageInput`):
450
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
451
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
452
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
453
+ Whether to resize the image.
454
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
455
+ Controls the size of the image after `resize`. The longest edge of the image is resized to
456
+ `size["longest_edge"]` whilst preserving the aspect ratio.
457
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
458
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
459
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
460
+ Whether to rescale the image pixel values by rescaling factor.
461
+ rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
462
+ Rescale factor to apply to the image pixel values.
463
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
464
+ Whether to normalize the image.
465
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
466
+ Image mean to normalize the image by if `do_normalize` is set to `True`.
467
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
468
+ Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
469
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
470
+ Whether to pad the image.
471
+ pad_size (`Dict[str, int]`, *optional*, defaults to `self.pad_size`):
472
+ Controls the size of the padding applied to the image. The image is padded to `pad_size["height"]` and
473
+ `pad_size["width"]` if `do_pad` is set to `True`.
474
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
475
+ Whether to convert the image to RGB.
476
+ return_tensors (`str` or `TensorType`, *optional*):
477
+ The type of tensors to return. Can be one of:
478
+ - Unset: Return a list of `np.ndarray`.
479
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
480
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
481
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
482
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
483
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
484
+ The channel dimension format for the output image. Can be one of:
485
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
486
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
487
+ - Unset: Use the channel dimension format of the input image.
488
+ input_data_format (`ChannelDimension` or `str`, *optional*):
489
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
490
+ from the input image. Can be one of:
491
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
492
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
493
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
494
+ """
495
+ do_resize = do_resize if do_resize is not None else self.do_resize
496
+ size = size if size is not None else self.size
497
+ size = get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size
498
+ resample = resample if resample is not None else self.resample
499
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
500
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
501
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
502
+ image_mean = image_mean if image_mean is not None else self.image_mean
503
+ image_std = image_std if image_std is not None else self.image_std
504
+ do_pad = do_pad if do_pad is not None else self.do_pad
505
+ pad_size = pad_size if pad_size is not None else self.pad_size
506
+ if do_pad:
507
+ pad_size = get_size_dict(pad_size, default_to_square=True)
508
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
509
+
510
+ images = make_list_of_images(images)
511
+
512
+ if not valid_images(images):
513
+ raise ValueError(
514
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
515
+ "torch.Tensor, tf.Tensor or jax.ndarray."
516
+ )
517
+
518
+ images, original_sizes, reshaped_input_sizes = zip(
519
+ *(
520
+ self._preprocess_image(
521
+ image=img,
522
+ do_resize=do_resize,
523
+ size=size,
524
+ resample=resample,
525
+ do_rescale=do_rescale,
526
+ rescale_factor=rescale_factor,
527
+ do_normalize=do_normalize,
528
+ image_mean=image_mean,
529
+ image_std=image_std,
530
+ do_pad=do_pad,
531
+ pad_size=pad_size,
532
+ do_convert_rgb=do_convert_rgb,
533
+ data_format=data_format,
534
+ input_data_format=input_data_format,
535
+ )
536
+ for img in images
537
+ )
538
+ )
539
+
540
+ data = {
541
+ "pixel_values": images,
542
+ "original_sizes": original_sizes,
543
+ "reshaped_input_sizes": reshaped_input_sizes,
544
+ }
545
+
546
+ return BatchFeature(data=data, tensor_type=return_tensors)
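A minimal usage sketch of the image processor defined above, assuming it is exposed as `ImageProcessor` in `llava/model/multimodal_encoder/image_processor.py` (the path the RADIO encoder below imports it from); the file name and sizes here are placeholders, not values used by the repository.

    from PIL import Image
    from llava.model.multimodal_encoder.image_processor import ImageProcessor

    # Resize so the longest edge is 512 px, then pad height and width up to multiples of 16,
    # mirroring the pad_multiple logic in _preprocess above.
    processor = ImageProcessor(size={"longest_edge": 512}, do_pad=True, pad_multiple=16)
    batch = processor.preprocess(Image.open("example.jpg"), return_tensors="pt")

    print(batch["pixel_values"].shape)       # (1, 3, H, W) with H and W multiples of 16
    print(batch["original_sizes"][0])        # (height, width) before resizing
    print(batch["reshaped_input_sizes"][0])  # (height, width) after resizing, before padding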
VILA/llava/model/multimodal_encoder/intern/__pycache__/configuration_intern_vit.cpython-310.pyc ADDED
Binary file (4.98 kB).
VILA/llava/model/multimodal_encoder/intern/__pycache__/flash_attention.cpython-310.pyc ADDED
Binary file (2.72 kB).
VILA/llava/model/multimodal_encoder/intern/__pycache__/modeling_intern_vit.cpython-310.pyc ADDED
Binary file (18 kB).
VILA/llava/model/multimodal_encoder/intern/configuration_intern_vit.py ADDED
@@ -0,0 +1,117 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2023 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ import os
7
+ from typing import Union
8
+
9
+ from transformers.configuration_utils import PretrainedConfig
10
+ from transformers.utils import logging
11
+
12
+ logger = logging.get_logger(__name__)
13
+
14
+
15
+ class InternVisionConfig(PretrainedConfig):
16
+ r"""
17
+ This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
18
+ instantiate a vision encoder according to the specified arguments, defining the model architecture.
19
+
20
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
21
+ documentation from [`PretrainedConfig`] for more information.
22
+
23
+ Args:
24
+ num_channels (`int`, *optional*, defaults to 3):
25
+ Number of color channels in the input images (e.g., 3 for RGB).
26
+ patch_size (`int`, *optional*, defaults to 14):
27
+ The size (resolution) of each patch.
28
+ image_size (`int`, *optional*, defaults to 224):
29
+ The size (resolution) of each image.
30
+ qkv_bias (`bool`, *optional*, defaults to `False`):
31
+ Whether to add a bias to the queries and values in the self-attention layers.
32
+ hidden_size (`int`, *optional*, defaults to 3200):
33
+ Dimensionality of the encoder layers and the pooler layer.
34
+ num_attention_heads (`int`, *optional*, defaults to 25):
35
+ Number of attention heads for each attention layer in the Transformer encoder.
36
+ intermediate_size (`int`, *optional*, defaults to 12800):
37
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
38
+ qk_normalization (`bool`, *optional*, defaults to `True`):
39
+ Whether to normalize the queries and keys in the self-attention layers.
40
+ num_hidden_layers (`int`, *optional*, defaults to 48):
41
+ Number of hidden layers in the Transformer encoder.
42
+ use_flash_attn (`bool`, *optional*, defaults to `True`):
43
+ Whether to use flash attention mechanism.
44
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
45
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
46
+ `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported.
47
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
48
+ The epsilon used by the layer normalization layers.
49
+ dropout (`float`, *optional*, defaults to 0.0):
50
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
51
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
52
+ Dropout rate for stochastic depth.
53
+ attention_dropout (`float`, *optional*, defaults to 0.0):
54
+ The dropout ratio for the attention probabilities.
55
+ initializer_range (`float`, *optional*, defaults to 0.02):
56
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
57
+ initializer_factor (`float`, *optional*, defaults to 0.1):
58
+ A factor for layer scale.
59
+ """
60
+
61
+ model_type = "intern_vit_6b"
62
+
63
+ def __init__(
64
+ self,
65
+ num_channels=3,
66
+ patch_size=14,
67
+ image_size=224,
68
+ qkv_bias=False,
69
+ hidden_size=3200,
70
+ num_attention_heads=25,
71
+ intermediate_size=12800,
72
+ qk_normalization=True,
73
+ num_hidden_layers=48,
74
+ use_flash_attn=True,
75
+ hidden_act="gelu",
76
+ layer_norm_eps=1e-6,
77
+ dropout=0.0,
78
+ drop_path_rate=0.0,
79
+ attention_dropout=0.0,
80
+ initializer_range=0.02,
81
+ initializer_factor=0.1,
82
+ **kwargs,
83
+ ):
84
+ super().__init__(**kwargs)
85
+
86
+ self.hidden_size = hidden_size
87
+ self.intermediate_size = intermediate_size
88
+ self.dropout = dropout
89
+ self.drop_path_rate = drop_path_rate
90
+ self.num_hidden_layers = num_hidden_layers
91
+ self.num_attention_heads = num_attention_heads
92
+ self.num_channels = num_channels
93
+ self.patch_size = patch_size
94
+ self.image_size = image_size
95
+ self.initializer_range = initializer_range
96
+ self.initializer_factor = initializer_factor
97
+ self.attention_dropout = attention_dropout
98
+ self.layer_norm_eps = layer_norm_eps
99
+ self.hidden_act = hidden_act
100
+ self.qkv_bias = qkv_bias
101
+ self.qk_normalization = qk_normalization
102
+ self.use_flash_attn = use_flash_attn
103
+
104
+ @classmethod
105
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
106
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
107
+
108
+ if "vision_config" in config_dict:
109
+ config_dict = config_dict["vision_config"]
110
+
111
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
112
+ logger.warning(
113
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
114
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
115
+ )
116
+
117
+ return cls.from_dict(config_dict, **kwargs)
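As a quick smoke test of the configuration class above, a reduced config can be built directly; the small values are arbitrary and only meant for illustration.

    from llava.model.multimodal_encoder.intern.configuration_intern_vit import InternVisionConfig

    config = InternVisionConfig(
        hidden_size=64,
        num_attention_heads=4,
        intermediate_size=256,
        num_hidden_layers=2,
        image_size=224,
        patch_size=14,
        use_flash_attn=False,  # fall back to the naive attention path
    )
    print(config.model_type)        # intern_vit_6b
    print(config.qk_normalization)  # True (default)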
VILA/llava/model/multimodal_encoder/intern/flash_attention.py ADDED
@@ -0,0 +1,105 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # https://github.com/Dao-AILab/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py
18
+ import torch
19
+ import torch.nn as nn
20
+ from einops import rearrange
21
+
22
+ try: # v1
23
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
24
+ except ImportError: # v2
25
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
26
+
27
+ from flash_attn.bert_padding import pad_input, unpad_input
28
+
29
+
30
+ class FlashAttention(nn.Module):
31
+ """Implement the scaled dot product attention with softmax.
32
+ Arguments
33
+ ---------
34
+ softmax_scale: The temperature to use for the softmax attention.
35
+ (default: 1/sqrt(d_keys) where d_keys is computed at
36
+ runtime)
37
+ attention_dropout: The dropout rate to apply to the attention
38
+ (default: 0.0)
39
+ """
40
+
41
+ def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
42
+ super().__init__()
43
+ self.softmax_scale = softmax_scale
44
+ self.dropout_p = attention_dropout
45
+
46
+ def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None, max_s=None, need_weights=False):
47
+ """Implements the multihead softmax attention.
48
+ Arguments
49
+ ---------
50
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
51
+ if unpadded: (nnz, 3, h, d)
52
+ key_padding_mask: a bool tensor of shape (B, S)
53
+ """
54
+ assert not need_weights
55
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
56
+ assert qkv.is_cuda
57
+
58
+ if cu_seqlens is None:
59
+ batch_size = qkv.shape[0]
60
+ seqlen = qkv.shape[1]
61
+ if key_padding_mask is None:
62
+ qkv = rearrange(qkv, "b s ... -> (b s) ...")
63
+ max_s = seqlen
64
+ cu_seqlens = torch.arange(
65
+ 0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device
66
+ )
67
+ output = flash_attn_unpadded_qkvpacked_func(
68
+ qkv,
69
+ cu_seqlens,
70
+ max_s,
71
+ self.dropout_p if self.training else 0.0,
72
+ softmax_scale=self.softmax_scale,
73
+ causal=causal,
74
+ )
75
+ output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
76
+ else:
77
+ nheads = qkv.shape[-2]
78
+ x = rearrange(qkv, "b s three h d -> b s (three h d)")
79
+ x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
80
+ x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
81
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
82
+ x_unpad,
83
+ cu_seqlens,
84
+ max_s,
85
+ self.dropout_p if self.training else 0.0,
86
+ softmax_scale=self.softmax_scale,
87
+ causal=causal,
88
+ )
89
+ output = rearrange(
90
+ pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen),
91
+ "b s (h d) -> b s h d",
92
+ h=nheads,
93
+ )
94
+ else:
95
+ assert max_s is not None
96
+ output = flash_attn_unpadded_qkvpacked_func(
97
+ qkv,
98
+ cu_seqlens,
99
+ max_s,
100
+ self.dropout_p if self.training else 0.0,
101
+ softmax_scale=self.softmax_scale,
102
+ causal=causal,
103
+ )
104
+
105
+ return output, None
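A shape-level sketch of calling the module above on packed QKV; it assumes a CUDA device and an installed flash-attn build, and the tensor sizes are illustrative only.

    import torch
    from llava.model.multimodal_encoder.intern.flash_attention import FlashAttention

    attn = FlashAttention(attention_dropout=0.0).cuda().eval()
    B, S, H, D = 2, 197, 16, 64
    qkv = torch.randn(B, S, 3, H, D, dtype=torch.bfloat16, device="cuda")  # packed q, k, v
    out, _ = attn(qkv, causal=False)  # no key_padding_mask: the fixed-length path is taken
    print(out.shape)  # torch.Size([2, 197, 16, 64])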
VILA/llava/model/multimodal_encoder/intern/modeling_intern_vit.py ADDED
@@ -0,0 +1,543 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2023 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ from typing import Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint
11
+ from einops import rearrange
12
+ from torch import nn
13
+ from transformers.activations import ACT2FN
14
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
15
+ from transformers.modeling_utils import PreTrainedModel
16
+ from transformers.utils import logging
17
+
18
+ from llava.model.multimodal_encoder.intern.configuration_intern_vit import InternVisionConfig
19
+
20
+ from .flash_attention import FlashAttention
21
+
22
+ has_flash_attn = True
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ """ DropBlock, DropPath
29
+
30
+ PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
31
+
32
+ Papers:
33
+ DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
34
+
35
+ Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
36
+
37
+ Code:
38
+ DropBlock impl inspired by two Tensorflow impl that I liked:
39
+ - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
40
+ - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
41
+
42
+ Hacked together by / Copyright 2020 Ross Wightman
43
+ """
44
+ import torch
45
+ import torch.nn as nn
46
+ import torch.nn.functional as F
47
+
48
+
49
+ def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
50
+ """generate N-D grid in dimension order.
51
+
52
+ The ndgrid function is like meshgrid except that the order of the first two input arguments are switched.
53
+
54
+ That is, the statement
55
+ [X1,X2,X3] = ndgrid(x1,x2,x3)
56
+
57
+ produces the same result as
58
+
59
+ [X2,X1,X3] = meshgrid(x2,x1,x3)
60
+
61
+ This naming is based on MATLAB; the purpose is to avoid confusion now that torch.meshgrid's default
62
+ behaviour is moving from matching ndgrid ('ij') indexing to the numpy meshgrid default of ('xy').
63
+
64
+ """
65
+ try:
66
+ return torch.meshgrid(*tensors, indexing="ij")
67
+ except TypeError:
68
+ # old PyTorch < 1.10 will follow this path as it does not have indexing arg,
69
+ # the old behaviour of meshgrid was 'ij'
70
+ return torch.meshgrid(*tensors)
71
+
72
+
73
+ def drop_block_2d(
74
+ x,
75
+ drop_prob: float = 0.1,
76
+ block_size: int = 7,
77
+ gamma_scale: float = 1.0,
78
+ with_noise: bool = False,
79
+ inplace: bool = False,
80
+ batchwise: bool = False,
81
+ ):
82
+ """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
83
+
84
+ DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
85
+ runs with success, but needs further validation and possibly optimization for lower runtime impact.
86
+ """
87
+ B, C, H, W = x.shape
88
+ total_size = W * H
89
+ clipped_block_size = min(block_size, min(W, H))
90
+ # seed_drop_rate, the gamma parameter
91
+ gamma = (
92
+ gamma_scale * drop_prob * total_size / clipped_block_size**2 / ((W - block_size + 1) * (H - block_size + 1))
93
+ )
94
+
95
+ # Forces the block to be inside the feature map.
96
+ w_i, h_i = ndgrid(torch.arange(W, device=x.device), torch.arange(H, device=x.device))
97
+ valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & (
98
+ (h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)
99
+ )
100
+ valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
101
+
102
+ if batchwise:
103
+ # one mask for whole batch, quite a bit faster
104
+ uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
105
+ else:
106
+ uniform_noise = torch.rand_like(x)
107
+ block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
108
+ block_mask = -F.max_pool2d(
109
+ -block_mask, kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2 # block_size,
110
+ )
111
+
112
+ if with_noise:
113
+ normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
114
+ if inplace:
115
+ x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
116
+ else:
117
+ x = x * block_mask + normal_noise * (1 - block_mask)
118
+ else:
119
+ normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
120
+ if inplace:
121
+ x.mul_(block_mask * normalize_scale)
122
+ else:
123
+ x = x * block_mask * normalize_scale
124
+ return x
125
+
126
+
127
+ def drop_block_fast_2d(
128
+ x: torch.Tensor,
129
+ drop_prob: float = 0.1,
130
+ block_size: int = 7,
131
+ gamma_scale: float = 1.0,
132
+ with_noise: bool = False,
133
+ inplace: bool = False,
134
+ ):
135
+ """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
136
+
137
+ DropBlock with an experimental gaussian noise option. Simplified from above, without concern for a valid
138
+ block mask at edges.
139
+ """
140
+ B, C, H, W = x.shape
141
+ total_size = W * H
142
+ clipped_block_size = min(block_size, min(W, H))
143
+ gamma = (
144
+ gamma_scale * drop_prob * total_size / clipped_block_size**2 / ((W - block_size + 1) * (H - block_size + 1))
145
+ )
146
+
147
+ block_mask = torch.empty_like(x).bernoulli_(gamma)
148
+ block_mask = F.max_pool2d(
149
+ block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2
150
+ )
151
+
152
+ if with_noise:
153
+ normal_noise = torch.empty_like(x).normal_()
154
+ if inplace:
155
+ x.mul_(1.0 - block_mask).add_(normal_noise * block_mask)
156
+ else:
157
+ x = x * (1.0 - block_mask) + normal_noise * block_mask
158
+ else:
159
+ block_mask = 1 - block_mask
160
+ normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype)
161
+ if inplace:
162
+ x.mul_(block_mask * normalize_scale)
163
+ else:
164
+ x = x * block_mask * normalize_scale
165
+ return x
166
+
167
+
168
+ class DropBlock2d(nn.Module):
169
+ """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf"""
170
+
171
+ def __init__(
172
+ self,
173
+ drop_prob: float = 0.1,
174
+ block_size: int = 7,
175
+ gamma_scale: float = 1.0,
176
+ with_noise: bool = False,
177
+ inplace: bool = False,
178
+ batchwise: bool = False,
179
+ fast: bool = True,
180
+ ):
181
+ super().__init__()
182
+ self.drop_prob = drop_prob
183
+ self.gamma_scale = gamma_scale
184
+ self.block_size = block_size
185
+ self.with_noise = with_noise
186
+ self.inplace = inplace
187
+ self.batchwise = batchwise
188
+ self.fast = fast # FIXME finish comparisons of fast vs not
189
+
190
+ def forward(self, x):
191
+ if not self.training or not self.drop_prob:
192
+ return x
193
+ if self.fast:
194
+ return drop_block_fast_2d(
195
+ x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace
196
+ )
197
+ else:
198
+ return drop_block_2d(
199
+ x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise
200
+ )
201
+
202
+
203
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
204
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
205
+
206
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
207
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
208
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
209
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
210
+ 'survival rate' as the argument.
211
+
212
+ """
213
+ if drop_prob == 0.0 or not training:
214
+ return x
215
+ keep_prob = 1 - drop_prob
216
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
217
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
218
+ if keep_prob > 0.0 and scale_by_keep:
219
+ random_tensor.div_(keep_prob)
220
+ return x * random_tensor
221
+
222
+
223
+ class DropPath(nn.Module):
224
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
225
+
226
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
227
+ super().__init__()
228
+ self.drop_prob = drop_prob
229
+ self.scale_by_keep = scale_by_keep
230
+
231
+ def forward(self, x):
232
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
233
+
234
+ def extra_repr(self):
235
+ return f"drop_prob={round(self.drop_prob,3):0.3f}"
236
+
237
+
238
+ class InternRMSNorm(nn.Module):
239
+ def __init__(self, hidden_size, eps=1e-6):
240
+ super().__init__()
241
+ self.weight = nn.Parameter(torch.ones(hidden_size))
242
+ self.variance_epsilon = eps
243
+
244
+ def forward(self, hidden_states):
245
+ input_dtype = hidden_states.dtype
246
+ hidden_states = hidden_states.to(torch.float32)
247
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
248
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
249
+ return self.weight * hidden_states.to(input_dtype)
250
+
251
+
252
+ try:
253
+ from apex.normalization import FusedRMSNorm
254
+
255
+ InternRMSNorm = FusedRMSNorm # noqa
256
+
257
+ logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm")
258
+ except ImportError:
259
+ # using the normal InternRMSNorm
260
+ pass
261
+ except Exception:
262
+ logger.warning("discovered apex but it failed to load, falling back to InternRMSNorm")
263
+ pass
264
+
265
+
266
+ class InternVisionEmbeddings(nn.Module):
267
+ def __init__(self, config: InternVisionConfig):
268
+ super().__init__()
269
+ self.config = config
270
+ self.embed_dim = config.hidden_size
271
+ self.image_size = config.image_size
272
+ self.patch_size = config.patch_size
273
+
274
+ self.class_embedding = nn.Parameter(
275
+ torch.randn(1, 1, self.embed_dim),
276
+ )
277
+
278
+ self.patch_embedding = nn.Conv2d(
279
+ in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
280
+ )
281
+
282
+ self.num_patches = (self.image_size // self.patch_size) ** 2
283
+ self.num_positions = self.num_patches + 1
284
+
285
+ self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
286
+
287
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
288
+ batch_size = pixel_values.shape[0]
289
+ target_dtype = self.patch_embedding.weight.dtype
290
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
291
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
292
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
293
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
294
+ embeddings = embeddings + self.position_embedding.to(target_dtype)
295
+ return embeddings
296
+
297
+
298
+ class InternAttention(nn.Module):
299
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
300
+
301
+ def __init__(self, config: InternVisionConfig):
302
+ super().__init__()
303
+ self.config = config
304
+ self.embed_dim = config.hidden_size
305
+ self.num_heads = config.num_attention_heads
306
+ self.use_flash_attn = config.use_flash_attn and has_flash_attn
307
+ if config.use_flash_attn and not has_flash_attn:
308
+ print("Warning: Flash Attention is not available, use_flash_attn is set to False.")
309
+ self.head_dim = self.embed_dim // self.num_heads
310
+ if self.head_dim * self.num_heads != self.embed_dim:
311
+ raise ValueError(
312
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
313
+ f" {self.num_heads})."
314
+ )
315
+
316
+ self.scale = self.head_dim**-0.5
317
+ self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
318
+ self.attn_drop = nn.Dropout(config.attention_dropout)
319
+ self.proj_drop = nn.Dropout(config.dropout)
320
+
321
+ self.qk_normalization = config.qk_normalization
322
+
323
+ if self.qk_normalization:
324
+ self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
325
+ self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
326
+
327
+ if self.use_flash_attn:
328
+ self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
329
+ self.proj = nn.Linear(self.embed_dim, self.embed_dim)
330
+
331
+ def _naive_attn(self, x):
332
+ B, N, C = x.shape
333
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
334
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
335
+
336
+ if self.qk_normalization:
337
+ B_, H_, N_, D_ = q.shape
338
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
339
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
340
+
341
+ attn = (q * self.scale) @ k.transpose(-2, -1)
342
+ attn = attn.softmax(dim=-1)
343
+ attn = self.attn_drop(attn)
344
+
345
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
346
+ x = self.proj(x)
347
+ x = self.proj_drop(x)
348
+ return x
349
+
350
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
351
+ qkv = self.qkv(x)
352
+ qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
353
+
354
+ if self.qk_normalization:
355
+ q, k, v = qkv.unbind(2)
356
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
357
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
358
+ qkv = torch.stack([q, k, v], dim=2)
359
+
360
+ context, _ = self.inner_attn(qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False)
361
+ outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
362
+ outs = self.proj_drop(outs)
363
+ return outs
364
+
365
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
366
+ x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
367
+ return x
368
+
369
+
370
+ class InternMLP(nn.Module):
371
+ def __init__(self, config: InternVisionConfig):
372
+ super().__init__()
373
+ self.config = config
374
+ self.act = ACT2FN[config.hidden_act]
375
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
376
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
377
+
378
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
379
+ hidden_states = self.fc1(hidden_states)
380
+ hidden_states = self.act(hidden_states)
381
+ hidden_states = self.fc2(hidden_states)
382
+ return hidden_states
383
+
384
+
385
+ class InternVisionEncoderLayer(nn.Module):
386
+ def __init__(self, config: InternVisionConfig, drop_path_rate: float):
387
+ super().__init__()
388
+ self.embed_dim = config.hidden_size
389
+ self.intermediate_size = config.intermediate_size
390
+
391
+ self.attn = InternAttention(config)
392
+ self.mlp = InternMLP(config)
393
+ self.norm1 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
394
+ self.norm2 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
395
+
396
+ self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
397
+ self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
398
+ self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
399
+ self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
400
+
401
+ def forward(
402
+ self,
403
+ hidden_states: torch.Tensor,
404
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
405
+ """
406
+ Args:
407
+ hidden_states (`torch.FloatTensor`): input to the layer, of shape `(batch, seq_len, embed_dim)`.
408
+ """
409
+ hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
410
+
411
+ hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
412
+
413
+ return hidden_states
414
+
415
+
416
+ class InternVisionEncoder(nn.Module):
417
+ """
418
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
419
+ [`InternEncoderLayer`].
420
+
421
+ Args:
422
+ config (`InternConfig`):
423
+ The corresponding vision configuration for the `InternEncoder`.
424
+ """
425
+
426
+ def __init__(self, config: InternVisionConfig):
427
+ super().__init__()
428
+ self.config = config
429
+ # stochastic depth decay rule
430
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
431
+ self.layers = nn.ModuleList(
432
+ [InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)]
433
+ )
434
+ self.gradient_checkpointing = True
435
+
436
+ def forward(
437
+ self,
438
+ inputs_embeds,
439
+ output_hidden_states: Optional[bool] = None,
440
+ return_dict: Optional[bool] = None,
441
+ ) -> Union[Tuple, BaseModelOutput]:
442
+ r"""
443
+ Args:
444
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
445
+ Embedded representation of the inputs. Should be float, not int tokens.
446
+ output_hidden_states (`bool`, *optional*):
447
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
448
+ for more detail.
449
+ return_dict (`bool`, *optional*):
450
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
451
+ """
452
+ output_hidden_states = (
453
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
454
+ )
455
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
456
+
457
+ encoder_states = () if output_hidden_states else None
458
+ hidden_states = inputs_embeds
459
+
460
+ for idx, encoder_layer in enumerate(self.layers):
461
+ if output_hidden_states:
462
+ encoder_states = encoder_states + (hidden_states,)
463
+ if self.gradient_checkpointing and self.training:
464
+ layer_outputs = torch.utils.checkpoint.checkpoint(encoder_layer, hidden_states)
465
+ else:
466
+ layer_outputs = encoder_layer(
467
+ hidden_states,
468
+ )
469
+ hidden_states = layer_outputs
470
+
471
+ if output_hidden_states:
472
+ encoder_states = encoder_states + (hidden_states,)
473
+
474
+ if not return_dict:
475
+ return tuple(v for v in [hidden_states, encoder_states] if v is not None)
476
+ return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states)
477
+
478
+
479
+ class InternVisionModel(PreTrainedModel):
480
+ main_input_name = "pixel_values"
481
+ config_class = InternVisionConfig
482
+ _no_split_modules = ["InternVisionEncoderLayer"]
483
+
484
+ def __init__(self, config: InternVisionConfig):
485
+ super().__init__(config)
486
+ self.config = config
487
+
488
+ self.embeddings = InternVisionEmbeddings(config)
489
+ self.encoder = InternVisionEncoder(config)
490
+
491
+ def resize_pos_embeddings(self, old_size, new_size, patch_size):
492
+ pos_emb = self.embeddings.position_embedding
493
+ _, num_positions, embed_dim = pos_emb.shape
494
+ cls_emb = pos_emb[:, :1, :]
495
+ pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
496
+ pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode="bicubic", align_corners=False)
497
+ pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
498
+ pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
499
+ self.embeddings.position_embedding = nn.Parameter(pos_emb)
500
+ logger.info(f"Resized position embeddings from {old_size} to {new_size}")
501
+
502
+ def get_input_embeddings(self):
503
+ return self.embeddings
504
+
505
+ def forward(
506
+ self,
507
+ pixel_values: Optional[torch.FloatTensor] = None,
508
+ output_hidden_states: Optional[bool] = None,
509
+ return_dict: Optional[bool] = None,
510
+ pixel_embeds: Optional[torch.FloatTensor] = None,
511
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
512
+ output_hidden_states = (
513
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
514
+ )
515
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
516
+
517
+ if pixel_values is None and pixel_embeds is None:
518
+ raise ValueError("You have to specify pixel_values or pixel_embeds")
519
+
520
+ if pixel_embeds is not None:
521
+ hidden_states = pixel_embeds
522
+ else:
523
+ if len(pixel_values.shape) == 4:
524
+ hidden_states = self.embeddings(pixel_values)
525
+ else:
526
+ raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")
527
+ encoder_outputs = self.encoder(
528
+ inputs_embeds=hidden_states,
529
+ output_hidden_states=output_hidden_states,
530
+ return_dict=return_dict,
531
+ )
532
+ last_hidden_state = encoder_outputs.last_hidden_state
533
+ pooled_output = last_hidden_state[:, 0, :]
534
+
535
+ if not return_dict:
536
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
537
+
538
+ return BaseModelOutputWithPooling(
539
+ last_hidden_state=last_hidden_state,
540
+ pooler_output=pooled_output,
541
+ hidden_states=encoder_outputs.hidden_states,
542
+ attentions=encoder_outputs.attentions,
543
+ )
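To see the encoder end to end, a tiny randomly initialized model can be run on a dummy batch. The reduced sizes are arbitrary; flash attention is disabled so the naive attention path is used, although the module still imports flash-attn at load time, so that package must be importable.

    import torch
    from llava.model.multimodal_encoder.intern.configuration_intern_vit import InternVisionConfig
    from llava.model.multimodal_encoder.intern.modeling_intern_vit import InternVisionModel

    config = InternVisionConfig(
        hidden_size=64, num_attention_heads=4, intermediate_size=128,
        num_hidden_layers=2, image_size=56, patch_size=14, use_flash_attn=False,
    )
    model = InternVisionModel(config).eval()
    pixel_values = torch.randn(1, 3, 56, 56)
    out = model(pixel_values, return_dict=True)
    print(out.last_hidden_state.shape)  # (1, 17, 64): 1 CLS token + (56 // 14) ** 2 patches
    print(out.pooler_output.shape)      # (1, 64): the CLS token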
VILA/llava/model/multimodal_encoder/intern_encoder.py ADDED
@@ -0,0 +1,71 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import torch
18
+ import torchvision.transforms as T
19
+ from torchvision.transforms.functional import InterpolationMode
20
+ from transformers import AutoConfig, AutoModel
21
+ from transformers.image_processing_utils import BaseImageProcessor
22
+
23
+ from llava.model.multimodal_encoder.intern.configuration_intern_vit import InternVisionConfig
24
+ from llava.model.multimodal_encoder.intern.modeling_intern_vit import InternVisionModel
25
+ from llava.model.multimodal_encoder.vision_encoder import VisionTower
26
+
27
+
28
+ def build_transform(input_size):
29
+ transform = T.Compose(
30
+ [
31
+ T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
32
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
33
+ T.ToTensor(),
34
+ T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
35
+ ]
36
+ )
37
+ return transform
38
+
39
+
40
+ class InternVisionPreprocessor(BaseImageProcessor):
41
+ @property
42
+ def size(self):
43
+ return {"height": 448, "width": 448}
44
+
45
+ def preprocess(self, image, return_tensors):
46
+ transform = build_transform(448)
47
+ if isinstance(image, list):
48
+ image_tensor = [transform(img) for img in image]
49
+ return {"pixel_values": image_tensor}
50
+ else:
51
+ image_tensor = transform(image)
52
+ return {"pixel_values": [image_tensor]}
53
+
54
+
55
+ class InternVisionTower(VisionTower):
56
+ def __init__(self, vision_tower, config, drop_path_rate=0.0):
57
+ super().__init__(vision_tower, config)
58
+ self._drop_path_rate = drop_path_rate
59
+
60
+ self.image_processor = InternVisionPreprocessor()
61
+ vision_config = InternVisionConfig.from_pretrained(vision_tower)
62
+ vision_config.drop_path_rate = self._drop_path_rate
63
+ self.vision_tower = InternVisionModel.from_pretrained(
64
+ vision_tower, torch_dtype=eval(config.model_dtype), config=vision_config
65
+ )
66
+
67
+ self.is_loaded = True
68
+
69
+
70
+ AutoConfig.register("intern_vit_6b", InternVisionConfig)
71
+ AutoModel.register(InternVisionConfig, InternVisionModel)
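A small sketch of the fixed 448x448 Intern preprocessing defined above; the image path is a placeholder, and importing this module also pulls in the Intern vision model (and its flash-attn import).

    from PIL import Image
    from llava.model.multimodal_encoder.intern_encoder import InternVisionPreprocessor

    processor = InternVisionPreprocessor()
    out = processor.preprocess(Image.open("example.jpg"), return_tensors=None)
    print(out["pixel_values"][0].shape)  # torch.Size([3, 448, 448]), ImageNet-normalized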
VILA/llava/model/multimodal_encoder/radio_encoder.py ADDED
@@ -0,0 +1,334 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import os
18
+ import warnings
19
+ from argparse import Namespace
20
+ from typing import Any, Dict
21
+
22
+ import numpy as np
23
+ import torch
24
+ from PIL import Image
25
+ from transformers import AutoConfig, AutoModel, CLIPVisionConfig
26
+
27
+ from llava.model.multimodal_encoder.vision_encoder import VisionTower
28
+ from llava.train.utils import mprint, rprint
29
+
30
+ from .image_processor import ImageProcessor
31
+ from .visualize_features import get_pca_map
32
+
33
+
34
+ def get_prefix_state_dict(state_dict: Dict[str, Any], prefix: str):
35
+ mod_state_dict = {k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)}
36
+ return mod_state_dict
37
+
38
+
39
+ def is_rank0():
40
+ return not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
41
+
42
+
43
+ class RADIOVisionTower(VisionTower):
44
+ """
45
+ Vision Tower for the RADIO model.
46
+
47
+ Args:
48
+ vision_tower (str): Vision tower name. This is passed on
49
+ the command line with the `--vision_tower` argument.
50
+ The string is expected in the pattern of:
51
+ `radio:<image_size>:<checkpoint>:<extra_config>`.
52
+ Where <extra_config> is a comma-separated list of key=value pairs.
53
+ <image_size> can also be a comma-separated list of resolutions in
54
+ the case of multi-res inference. Limitations apply, e.g. only two
55
+ resolutions are supported and the second resolution must be a divisor
56
+ of the first one.
57
+ args (Namespace): Arguments.
58
+ delay_load (bool): Delay loading the model.
59
+ """
60
+
61
+ def __init__(self, vision_tower, args, delay_load=False):
62
+ """Initialization Routine."""
63
+
64
+ super().__init__(vision_tower, args, delay_load)
65
+
66
+ mprint(f"RADIOVisionTower: {vision_tower}. Args: {args} Delay load: {delay_load}")
67
+
68
+ assert not delay_load
69
+
70
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
71
+
72
+ extra_config = {}
73
+
74
+ # Check if vision_tower is a valid path.
75
+ if os.path.exists(vision_tower):
76
+ self.vision_tower_name = self.vision_tower_checkpoint = vision_tower
77
+ vision_cfg = getattr(args, "vision_tower_cfg")
78
+ self.image_size = vision_cfg["image_size"]
79
+ else:
80
+ self.vision_tower_name = vision_tower[len("radio:") :]
81
+ config_items = self.vision_tower_name.split(":")
82
+ self.image_size = int(config_items[0])
83
+
84
+ self.vision_tower_checkpoint = config_items[1]
85
+
86
+ if len(config_items) > 2:
87
+ # Parse extra config items. These are provided as a comma-separated list
88
+ # of key=value pairs.
89
+ extra_config_items = config_items[2].split(",")
90
+
91
+ for item in extra_config_items:
92
+ key, value = item.split("=")
93
+ extra_config[key] = value
94
+
95
+ self.image_aspect_ratio = args.image_aspect_ratio
96
+ self.skip_layer_norm = eval(extra_config.get("skip_layer_norm", "False"))
97
+
98
+ if not delay_load:
99
+ self.load_model()
100
+ else:
101
+ raise ValueError("Delay load not supported for RADIOVisionTower.")
102
+
103
+ self.sample_count = 0
104
+ self.debug = True
105
+
106
+ def get_hidden_size(self):
107
+ if self.select_feature == "cls":
108
+ hidden_size = 5120
109
+ elif self.select_feature == "dense":
110
+ hidden_size = 4 * 1280
111
+ else:
112
+ hidden_size = 1280
113
+
114
+ return hidden_size
115
+
116
+ def load_model(self):
117
+ if self.image_aspect_ratio == "resize":
118
+ self.image_processor = ImageProcessor(
119
+ size={"width": self.image_size, "height": self.image_size},
120
+ do_pad=False,
121
+ do_normalize=True,
122
+ do_convert_rgb=True,
123
+ )
124
+ else:
125
+ self.image_processor = ImageProcessor(
126
+ size={"longest_edge": self.image_size},
127
+ do_pad=True,
128
+ pad_multiple=16,
129
+ do_normalize=True,
130
+ do_convert_rgb=True,
131
+ pad_value=0.456,
132
+ )
133
+ # For compatibility with CLIP Image Processor: the data loader uses width/height to
134
+ # create dummy blank images for samples that don't have an image.
135
+ self.image_processor.crop_size = {"width": self.image_size, "height": self.image_size}
136
+
137
+ mprint(self.image_processor)
138
+
139
+ config = AutoConfig.from_pretrained(self.vision_tower_checkpoint, trust_remote_code=True)
140
+ mprint("RADIO config", config)
141
+ self.vision_tower = AutoModel.from_pretrained(self.vision_tower_checkpoint, trust_remote_code=True)
142
+ self.vision_tower.radio_model.make_preprocessor_external()
143
+
144
+ # # NOTE: do a lazy import of Timm to avoid issues with
145
+ # # DeepSpeed's ZeRO-3.
146
+ from timm.models.vision_transformer import VisionTransformer
147
+
148
+ #
149
+ if isinstance(self.vision_tower.model, VisionTransformer):
150
+ hidden_size = self.vision_tower.model.embed_dim
151
+ else:
152
+ raise ValueError(f"Unknown model type: {self.vision_tower}")
153
+
154
+ # Override hidden size for OpenAI CLIP.
155
+ hidden_size = self.get_hidden_size()
156
+
157
+ if hasattr(self.vision_tower.model, "patch_generator"):
158
+ patch_gen = self.vision_tower.model.patch_generator
159
+ # Cropped Positional Embedding (CPE) case.
160
+ patch_size = patch_gen.patch_size
161
+ else:
162
+ # Standard ViT case.
163
+ patch_size = self.vision_tower.model.patch_embed.patch_size[0]
164
+
165
+ self.vision_tower.config.image_size = self.image_size
166
+ self.vision_tower.config.hidden_size = hidden_size
167
+ self.vision_tower.config.patch_size = patch_size
168
+
169
+ self.vision_tower.cuda().eval()
170
+ self.vision_tower.requires_grad_(False)
171
+
172
+ self.is_loaded = True
173
+ self._to_dtype = None
174
+
175
+ if self.skip_layer_norm:
176
+ mprint(f"Removing layer norm from the model: {self.vision_tower.model.norm}")
177
+ self.vision_tower.model.norm = torch.nn.Identity()
178
+
179
+ def to(self, *args, **kwargs):
180
+ # Prevent casting the RADIO model's weights
181
+ kwargs = dict(kwargs)
182
+ # self._to_dtype = kwargs.get('dtype', None)
183
+ self._to_dtype = kwargs.pop("dtype", None)
184
+ mprint(f"RADIO: bypass cast to dtype={self._to_dtype}")
185
+ super().to(*args, **kwargs)
186
+ pass
187
+
188
+ def train(self, mode=True):
189
+ """Intercept call."""
190
+ # Drop a warning if mode is True.
191
+ if mode:
192
+ warnings.warn("RADIOEncoder is always in eval mode.")
193
+ pass
194
+
195
+ def _get_summary_and_patch_from_tokens(self, tokens):
196
+ model = self.vision_tower.model
197
+ patch_gen = getattr(model, "patch_generator", None)
198
+ if patch_gen is not None:
199
+ all_summary = tokens[:, : patch_gen.num_cls_tokens]
200
+ if self.vision_tower.radio_model.summary_idxs is not None:
201
+ summary = all_summary[:, self.vision_tower.radio_model.summary_idxs]
202
+ else:
203
+ summary = all_summary
204
+ all_feat = tokens[:, patch_gen.num_skip :]
205
+ elif model.global_pool == "avg":
206
+ all_summary = tokens[:, model.num_prefix_tokens :].mean(dim=1)
207
+ summary = all_summary
208
+ all_feat = tokens
209
+ else:
210
+ all_summary = tokens[:, 0]
211
+ summary = all_summary
212
+ all_feat = tokens[:, 1:]
213
+ return summary, all_feat
214
+
215
+ @torch.no_grad()
216
+ def get_features(self, x: torch.Tensor):
217
+ x_dtype = x.dtype
218
+ x = x.float()
219
+ with torch.autocast("cuda", dtype=torch.bfloat16):
220
+ if self.select_feature == "dense":
221
+
222
+ # Layers to return activations of in case of "return_multilayer=True".
223
+ num_layers = len(self.vision_tower.model.blocks)
224
+ multilayers = [
225
+ num_layers // 4 - 1,
226
+ num_layers // 2 - 1,
227
+ num_layers // 4 * 3 - 1,
228
+ ]
229
+
230
+ features = []
231
+ intermediate_features = []
232
+
233
+ x = self.vision_tower.input_conditioner(x)
234
+ x = self.vision_tower.model.patch_generator(x)
235
+
236
+ for i, blk in enumerate(self.vision_tower.model.blocks):
237
+ x = blk(x)
238
+ _, blk_features = self._get_summary_and_patch_from_tokens(x)
239
+ intermediate_features.append(blk_features)
240
+ if i in multilayers:
241
+ intermediate_features = torch.stack(intermediate_features, dim=0)
242
+ intermediate_features = torch.sum(intermediate_features, dim=0) / intermediate_features.shape[0]
243
+ features.append(intermediate_features)
244
+ intermediate_features = []
245
+ x = self.vision_tower.model.norm(x)
246
+ last_summary, last_features = self._get_summary_and_patch_from_tokens(x)
247
+ features.append(last_features)
248
+ features = torch.cat(features, dim=-1)
249
+ summary = last_summary
250
+ else:
251
+ summary, features = self.vision_tower(x)
252
+
253
+ return summary, features.to(dtype=x_dtype)
254
+
255
+ @torch.no_grad()
256
+ def forward(self, images: torch.Tensor):
257
+ """Main forward pass."""
258
+ input_shape = images.shape
259
+
260
+ x = images
261
+ # Add a batch dimension if necessary.
262
+ if len(input_shape) == 3:
263
+ x = x.unsqueeze(0)
264
+
265
+ # Convert the input to the model's dtype (we assume
266
+ # that the model only has one dtype for all parameters).
267
+ param0 = next(self.vision_tower.parameters())
268
+
269
+ rprint(
270
+ f"input shape={input_shape}->{x.shape} device={x.device} mean={x.mean().item()} std={x.std().item()} dtype={x.dtype} param0.device={param0.device} param0.dtype={param0.dtype}"
271
+ )
272
+
273
+ summary, features = self.get_features(x) # B, T, C
274
+
275
+ if len(summary.shape) == 2:
276
+ if self.select_feature == "cls4":
277
+ # Add a token dimension if necessary.
278
+ B, C = summary.shape
279
+ summary = summary.reshape(B, 4, C // 4)
280
+ else:
281
+ # Add a token dimension if necessary.
282
+ summary = summary.unsqueeze(1)
283
+
284
+ B, _, H, W = x.shape
285
+ _, _, C = features.shape
286
+ patch_size = self.vision_tower.config.patch_size
287
+ spatial_features = features.reshape(B, H // patch_size, W // patch_size, C)
288
+ spatial_features = spatial_features.permute(0, 3, 1, 2) # B, C, H/patch_size, W/patch_size
289
+
290
+ if self.debug and is_rank0() and self.sample_count % 1000 == 0:
291
+ spatial_features_hwc = spatial_features.permute(0, 2, 3, 1)
292
+ # create the debug directory
293
+ os.makedirs("radio-debug", exist_ok=True)
294
+ torch.save(x, f"radio-debug/sample_{self.sample_count}_input.pt")
295
+ torch.save(features, f"radio-debug/sample_{self.sample_count}_features.pt")
296
+ torch.save(spatial_features_hwc, f"radio-debug/sample_{self.sample_count}_features_reshaped.pt")
297
+ for i in range(B):
298
+ image = x[i].permute(1, 2, 0).float() * 255
299
+ image = Image.fromarray(image.cpu().numpy().astype(np.uint8))
300
+ image.save(os.path.join("radio-debug/", f"sample_{self.sample_count}_preprocessed_{i}.png"))
301
+ pca_map = get_pca_map(spatial_features_hwc[i : i + 1], x.shape[-2:])
302
+ torch.save(pca_map, f"radio-debug/sample_{self.sample_count}_pca_map_{i}.pt")
303
+ image = pca_map * 255
304
+ image = Image.fromarray(image.astype(np.uint8))
305
+ image.save(os.path.join("radio-debug/", f"sample_{self.sample_count}_pca_map_{i}.png"))
306
+ pass
307
+
308
+ if self.select_feature in ["patch", "cls_patch", "dense"]:
309
+ # Ignore cls-patch for now.
310
+ pass
311
+ # elif self.select_feature == "cls_patch":
312
+ # features = torch.cat([summary, features], dim=1)
313
+ elif self.select_feature in ["cls", "cls4"]:
314
+ features = summary
315
+ else:
316
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
317
+
318
+ # Remove the batch dimension if we added it.
319
+ if len(input_shape) == 3:
320
+ features = features.squeeze(0)
321
+
322
+ # Cast back to the input's dtype.
323
+ features = features.to(images.dtype)
324
+
325
+ rprint(
326
+ f"features shape={features.shape} mean={features.mean().item()} std={features.std().item()} dtype={features.dtype}"
327
+ )
328
+
329
+ if features.shape[-1] != self.get_hidden_size():
330
+ raise ValueError(f"Unexpected hidden size: {features.shape[-1]} != {self.get_hidden_size()}")
331
+
332
+ self.sample_count += 1
333
+
334
+ return features
VILA/llava/model/multimodal_encoder/radio_torchhub_encoder.py ADDED
@@ -0,0 +1,375 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import os
18
+ import warnings
19
+ from argparse import Namespace
20
+ from typing import Any, Dict
21
+
22
+ import numpy as np
23
+ import torch
24
+ from PIL import Image
25
+ from transformers import CLIPVisionConfig
26
+
27
+ from llava.model.multimodal_encoder.vision_encoder import VisionTower
28
+ from llava.train.utils import mprint, rprint
29
+
30
+ from .image_processor import ImageProcessor
31
+ from .visualize_features import get_pca_map
32
+
33
+
34
+ def get_prefix_state_dict(state_dict: Dict[str, Any], prefix: str):
35
+ mod_state_dict = {k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)}
36
+ return mod_state_dict
37
+
38
+
39
+ def is_rank0():
40
+ return not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
41
+
42
+
43
+ class RADIOVisionTower(VisionTower):
44
+ """
45
+ Vision Tower for the RADIO model.
46
+
47
+ Args:
48
+ vision_tower (str): Vision tower name. This is passed on
49
+ the command line with the `--vision_tower` argument.
50
+ The string is expected to follow the pattern
51
+ `radio:<image_size>:<checkpoint>:<extra_config>`,
52
+ where <extra_config> is a comma-separated list of key=value pairs.
53
+ <image_size> can also be a comma-separated list of resolutions in
54
+ the case of multi-res inference. Limitations apply, e.g. only two
55
+ resolutions are supported and the second resolution must be a divisor
56
+ of the first one.
57
+ args (Namespace): Arguments.
58
+ delay_load (bool): Delay loading the model.
59
+ """
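+ # Illustrative only (hypothetical checkpoint name): a tower string such as
+ # "radio:768,384:radio_v2:adaptor=clip,pixel_unshuffle=True" selects a 768px
+ # primary and 384px secondary resolution, the "radio_v2" torch.hub version,
+ # the "clip" adaptor head, and enables pixel unshuffle.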
60
+
61
+ def __init__(self, vision_tower, args, delay_load=False):
62
+ """Initialization Routine."""
63
+
64
+ super().__init__(vision_tower, args, delay_load)
65
+
66
+ mprint(f"RADIOVisionTower: {vision_tower}. Args: {args} Delay load: {delay_load}")
67
+
68
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
69
+
70
+ self.vision_tower_name = vision_tower[len("radio:") :]
71
+ config_items = self.vision_tower_name.split(":")
72
+ self.image_sizes = [int(x) for x in config_items[0].split(",")]
73
+ if len(self.image_sizes) == 0:
74
+ raise ValueError("Expected more than zero images sizes!")
75
+ self.image_size = self.image_sizes[0]
76
+ self.image_aspect_ratio = args.image_aspect_ratio
77
+
78
+ self.downscale_factor = None
79
+ if len(self.image_sizes) > 1:
80
+ self.downscale_factor = self.image_sizes[0] // self.image_sizes[1]
81
+ assert self.downscale_factor == self.image_sizes[0] / self.image_sizes[1]
82
+ self.pool2d = torch.nn.AvgPool2d(self.downscale_factor, self.downscale_factor)
83
+ if len(self.image_sizes) > 2:
84
+ raise ValueError(f"Only up to two resolutions are supported, got {len(self.image_sizes)}")
85
+ elif self.image_size >= 512:
86
+ self.downscale_factor = 2
87
+
88
+ self.vision_tower_checkpoint = config_items[1]
89
+
90
+ extra_config = {}
91
+ if len(config_items) > 2:
92
+ # Parse extra config items. These are provided as a comma-separated list
93
+ # of key=value pairs.
94
+ extra_config_items = config_items[2].split(",")
95
+
96
+ for item in extra_config_items:
97
+ key, value = item.split("=")
98
+ extra_config[key] = value
99
+
100
+ self.adaptor_name = extra_config.get("adaptor", "backbone")
101
+ self.fuse_adaptor_with_backbone = eval(extra_config.get("fuse_adaptor_with_backbone", "False"))
102
+ self.skip_layer_norm = eval(extra_config.get("skip_layer_norm", "False"))
103
+ self.allow_pixel_unshuffle = eval(extra_config.get("pixel_unshuffle", "False"))
104
+
105
+ self.pixel_unshuffle = None
106
+ if self.allow_pixel_unshuffle and self.downscale_factor is not None:
107
+ self.pixel_unshuffle = torch.nn.PixelUnshuffle(self.downscale_factor)
108
+
109
+ if not delay_load:
110
+ self.load_model()
111
+ else:
112
+ # FIXME: This is a hack to avoid having to load the config from the checkpoint.
113
+ hidden_size = self.get_hidden_size()
114
+ patch_size = 16
115
+
116
+ self.cfg_only = CLIPVisionConfig(
117
+ **{
118
+ "hidden_size": hidden_size,
119
+ "image_size": self.image_size,
120
+ "model_type": "radio_vision_model",
121
+ "num_attention_heads": None,
122
+ "num_channels": 3,
123
+ "num_hidden_layers": None,
124
+ "patch_size": patch_size,
125
+ }
126
+ )
127
+
128
+ self.sample_count = 0
129
+
130
+ self.debug = True
131
+
132
+ def get_hidden_size(self):
133
+ if self.select_feature == "cls":
134
+ hidden_size = 5120
135
+ elif self.adaptor_name == "openai_clip":
136
+ hidden_size = 1024
137
+ elif self.adaptor_name == "clip":
138
+ hidden_size = 1280
139
+ elif self.adaptor_name == "rtx-translate":
140
+ hidden_size = 2048
141
+ elif self.adaptor_name == "backbone":
142
+ hidden_size = 1280
143
+ else:
144
+ raise ValueError(f"Unknown adaptor name: {self.adaptor_name}")
145
+
146
+ if self.fuse_adaptor_with_backbone:
147
+ hidden_size += 1280
148
+
149
+ if len(self.image_sizes) == 2:
150
+ if self.pixel_unshuffle is not None:
151
+ hidden_size = hidden_size * 5
152
+ else:
153
+ hidden_size = hidden_size * 2
154
+ elif self.pixel_unshuffle is not None:
155
+ hidden_size = hidden_size * 4
156
+
157
+ return hidden_size
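+ # Worked example (illustrative): with a backbone width of 1280, two
+ # resolutions, and pixel unshuffle at downscale factor 2, the hi-res path
+ # contributes 1280 * 2**2 = 5120 channels and the low-res path another
+ # 1280, i.e. 1280 * 5 = 6400, matching the *5 factor above.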
158
+
159
+ def load_model(self):
160
+
161
+ if self.image_aspect_ratio == "resize":
162
+ self.image_processor = ImageProcessor(
163
+ size={"width": self.image_size, "height": self.image_size},
164
+ do_pad=False,
165
+ do_normalize=False,
166
+ do_convert_rgb=True,
167
+ )
168
+ else:
169
+ self.image_processor = ImageProcessor(
170
+ size={"longest_edge": self.image_size},
171
+ do_pad=True,
172
+ pad_multiple=16,
173
+ do_normalize=False,
174
+ do_convert_rgb=True,
175
+ pad_value=0.456,
176
+ )
177
+ # For compatibility with CLIP Image Processor: the data loader uses width/height to
178
+ # create dummy blank images for samples that don't have an image.
179
+ self.image_processor.crop_size = {"width": self.image_size, "height": self.image_size}
180
+
181
+ mprint(self.image_processor)
182
+
183
+ # Load weights from checkpoint.
184
+ checkpoint_path = self.vision_tower_checkpoint
185
+ rprint(f"Loading checkpoint from {checkpoint_path}")
186
+
187
+ # NOTE: do a lazy import of Timm to avoid issues with
188
+ # DeepSpeed's ZeRO-3.
189
+ from timm.models.vision_transformer import VisionTransformer
190
+
191
+ self.vision_tower = torch.hub.load(
192
+ "NVlabs/RADIO",
193
+ "radio_model",
194
+ version=checkpoint_path,
195
+ progress=True,
196
+ adaptor_names=self.adaptor_name if self.adaptor_name != "backbone" else None,
197
+ )
198
+
199
+ if isinstance(self.vision_tower.model, VisionTransformer):
200
+ hidden_size = self.vision_tower.model.embed_dim
201
+ else:
202
+ raise ValueError(f"Unknown model type: {self.vision_tower}")
203
+
204
+ # Override the backbone embed_dim with the adaptor-dependent hidden size.
205
+ hidden_size = self.get_hidden_size()
206
+
207
+ if hasattr(self.vision_tower.model, "patch_generator"):
208
+ patch_gen = self.vision_tower.model.patch_generator
209
+ # Cropped Positional Embedding (CPE) case.
210
+ patch_size = patch_gen.patch_size
211
+ else:
212
+ # Standard ViT case.
213
+ patch_size = self.vision_tower.model.patch_embed.patch_size[0]
214
+
215
+ self.vision_tower.config = CLIPVisionConfig(
216
+ **{
217
+ "hidden_size": hidden_size,
218
+ "image_size": self.image_size,
219
+ "model_type": "radio_vision_model",
220
+ "num_attention_heads": None,
221
+ "num_channels": 3,
222
+ "num_hidden_layers": None,
223
+ "patch_size": patch_size,
224
+ }
225
+ )
226
+
227
+ self.vision_tower.eval()
228
+ self.vision_tower.requires_grad_(False)
229
+
230
+ self.is_loaded = True
231
+ self._to_dtype = None
232
+
233
+ if self.skip_layer_norm:
234
+ mprint(f"Removing layer norm from the model: {self.vision_tower.model.norm}")
235
+ self.vision_tower.model.norm = torch.nn.Identity()
236
+
237
+ def to(self, *args, **kwargs):
238
+ # Prevent casting the RADIO model's weights
239
+ kwargs = dict(kwargs)
240
+ self._to_dtype = kwargs.pop("dtype", None)
241
+ mprint(f"RADIO: bypass cast to dtype={self._to_dtype}")
242
+ super().to(*args, **kwargs)
243
+ pass
244
+
245
+ def train(self, mode=True):
246
+ """Intercept call."""
247
+ # Drop a warning if mode is True.
248
+ if mode:
249
+ warnings.warn("RADIOVisionTower is always in eval mode.")
250
+ pass
251
+
252
+ @torch.no_grad()
253
+ def get_features(self, x: torch.Tensor):
254
+ x_float = x.float()
255
+ with torch.autocast("cuda", dtype=torch.bfloat16):
256
+ output = self.vision_tower(x_float)
257
+
258
+ if isinstance(output, dict):
259
+ summary, features = output[self.adaptor_name]
260
+ if self.fuse_adaptor_with_backbone:
261
+ backbone_summary, backbone_features = output["backbone"]
262
+ summary = torch.cat([summary, backbone_summary], dim=2)
263
+ features = torch.cat([features, backbone_features], dim=2)
264
+ else:
265
+ summary, features = output
266
+
267
+ return summary, features.to(dtype=x.dtype)
268
+
269
+ @torch.no_grad()
270
+ def forward(self, images: torch.Tensor):
271
+ """Main forward pass."""
272
+ input_shape = images.shape
273
+
274
+ x = images
275
+ # Add a batch dimension if necessary.
276
+ if len(input_shape) == 3:
277
+ x = x.unsqueeze(0)
278
+
279
+ rprint(
280
+ f"input shape={input_shape}->{x.shape} device={x.device} mean={x.mean().item()} std={x.std().item()} dtype={x.dtype}"
281
+ )
282
+
283
+ summary, features = self.get_features(x) # B, T, C
284
+
285
+ if len(summary.shape) == 2:
286
+ if self.select_feature == "cls4":
287
+ # Add a token dimension if necessary.
288
+ B, C = summary.shape
289
+ summary = summary.reshape(B, 4, C // 4)
290
+ else:
291
+ # Add a token dimension if necessary.
292
+ summary = summary.unsqueeze(1)
293
+
294
+ B, _, H, W = x.shape
295
+ _, _, C = features.shape
296
+ patch_size = self.vision_tower.config.patch_size
297
+ spatial_features = features.reshape(B, H // patch_size, W // patch_size, C)
298
+ spatial_features = spatial_features.permute(0, 3, 1, 2) # B, C, H/patch_size, W/patch_size
299
+
300
+ if self.debug and is_rank0() and self.sample_count % 1000 == 0:
301
+ spatial_features_hwc = spatial_features.permute(0, 2, 3, 1)
302
+ # create the debug directory
303
+ os.makedirs("radio-debug", exist_ok=True)
304
+ torch.save(x, f"radio-debug/sample_{self.sample_count}_input.pt")
305
+ torch.save(features, f"radio-debug/sample_{self.sample_count}_features.pt")
306
+ torch.save(spatial_features_hwc, f"radio-debug/sample_{self.sample_count}_features_reshaped.pt")
307
+ for i in range(B):
308
+ image = x[i].permute(1, 2, 0).float() * 255
309
+ image = Image.fromarray(image.cpu().numpy().astype(np.uint8))
310
+ image.save(os.path.join("radio-debug/", f"sample_{self.sample_count}_preprocessed_{i}.png"))
311
+ pca_map = get_pca_map(spatial_features_hwc[i : i + 1], x.shape[-2:])
312
+ torch.save(pca_map, f"radio-debug/sample_{self.sample_count}_pca_map_{i}.pt")
313
+ image = pca_map * 255
314
+ image = Image.fromarray(image.astype(np.uint8))
315
+ image.save(os.path.join("radio-debug/", f"sample_{self.sample_count}_pca_map_{i}.png"))
316
+ pass
317
+
318
+ if self.pixel_unshuffle is not None:
319
+ spatial_features = self.pixel_unshuffle(spatial_features)
320
+ # B, C*downscale_factor**2, H/patch_size/downscale_factor, W/patch_size/downscale_factor
321
+ features = spatial_features.reshape(
322
+ B,
323
+ C * self.downscale_factor**2,
324
+ (H // patch_size // self.downscale_factor) * (W // patch_size // self.downscale_factor),
325
+ ).permute(0, 2, 1)
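+ # Shape walk-through (illustrative): a 768x768 input with patch_size=16
+ # yields a 48x48 token grid; PixelUnshuffle(2) folds it into a 24x24 grid
+ # with 4*C channels, so the reshape above produces (B, 576, 4*C).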
326
+
327
+ if len(self.image_sizes) > 1:
328
+ # Experimental support for multi-resolution inference.
329
+ if self.pixel_unshuffle is None:
330
+ # downscale features
331
+ spatial_features = self.pool2d(
332
+ spatial_features
333
+ ) # B, C, H/patch_size/downscale_factor, W/patch_size/downscale_factor
334
+ features = spatial_features.reshape(
335
+ B, C, (H // patch_size // self.downscale_factor) * (W // patch_size // self.downscale_factor)
336
+ )
337
+ features = features.permute(
338
+ 0, 2, 1
339
+ ) # B, (H/patch_size/downscale_factor) * (W/patch_size/downscale_factor), C
340
+
341
+ # Downscale the input image.
342
+ x = self.pool2d(x) # B, 3, H/downscale_factor, W/downscale_factor)
343
+ _, features_stage2 = self.get_features(
344
+ x
345
+ ) # B, (H/patch_size/downscale_factor) * (W/patch_size/downscale_factor), C
346
+
347
+ # Concatenate stage1 and stage 2 features.
348
+ features = torch.cat([features, features_stage2], dim=2)
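+ # Note: both stages share the same token grid here, and the concat over
+ # dim=2 adds the stage-2 C channels on top of the stage-1 channels (C, or
+ # 4*C after pixel unshuffle), which the *2 / *5 multipliers in
+ # get_hidden_size() account for.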
349
+
350
+ if self.select_feature in ["patch", "cls_patch"]:
351
+ # Ignore cls-patch for now.
352
+ pass
353
+ # elif self.select_feature == "cls_patch":
354
+ # features = torch.cat([summary, features], dim=1)
355
+ elif self.select_feature in ["cls", "cls4"]:
356
+ features = summary
357
+ else:
358
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
359
+
360
+ # Remove the batch dimension if we added it.
361
+ if len(input_shape) == 3:
362
+ features = features.squeeze(0)
363
+
364
+ # Cast back to the input's dtype.
365
+ features = features.to(images.dtype)
366
+
367
+ adaptor_name = f"{self.adaptor_name}{'+backbone' if self.fuse_adaptor_with_backbone else ''}"
368
+ rprint(
369
+ f"features ({adaptor_name}) shape={features.shape} mean={features.mean().item()} std={features.std().item()} dtype={features.dtype}"
370
+ )
371
+
372
+ assert features.shape[-1] == self.get_hidden_size()
373
+ self.sample_count += 1
374
+
375
+ return features