jcsun1 committed
Commit 28c0d1b · 1 Parent(s): d9493c9

update mllm

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. mllm/flamingo/__init__.py +48 -0
  2. mllm/flamingo/config.json +21 -0
  3. mllm/flamingo/configuration_flamingo.py +100 -0
  4. mllm/flamingo/converting_flamingo_to_bf16.py +30 -0
  5. mllm/flamingo/converting_flamingo_to_hf.py +61 -0
  6. mllm/flamingo/converting_flamingo_to_lora.py +68 -0
  7. mllm/flamingo/falcon/__init__.py +0 -0
  8. mllm/flamingo/falcon/__pycache__/__init__.cpython-39.pyc +0 -0
  9. mllm/flamingo/falcon/__pycache__/configuration_RW.cpython-39.pyc +0 -0
  10. mllm/flamingo/falcon/__pycache__/modelling_RW.cpython-39.pyc +0 -0
  11. mllm/flamingo/falcon/configuration_RW.py +79 -0
  12. mllm/flamingo/falcon/modelling_RW.py +1064 -0
  13. mllm/flamingo/flamingo-falcon-7B.json +112 -0
  14. mllm/flamingo/flamingo-llama2-chat-13B.json +114 -0
  15. mllm/flamingo/flamingo-llama2-chat-7B.json +115 -0
  16. mllm/flamingo/flamingo-mpt-1B-redpajama.json +131 -0
  17. mllm/flamingo/flamingo-mpt-30B-bf16.json +195 -0
  18. mllm/flamingo/flamingo-mpt-30B.json +195 -0
  19. mllm/flamingo/flamingo-mpt-7B.json +195 -0
  20. mllm/flamingo/flamingo-vicuna-33B-v1.3.json +111 -0
  21. mllm/flamingo/flamingo-vicuna-7B-v1.3.json +111 -0
  22. mllm/flamingo/injecting_falcon_into_flamingo.py +49 -0
  23. mllm/flamingo/injecting_llama2_into_flamingo.py +95 -0
  24. mllm/flamingo/injecting_mpt-1B-redpajama_into_flamingo.py +97 -0
  25. mllm/flamingo/injecting_mpt_into_flamingo.py +109 -0
  26. mllm/flamingo/injecting_vicuna_into_flamingo.py +100 -0
  27. mllm/flamingo/modeling_flamingo.py +966 -0
  28. mllm/flamingo/mpt/__init__.py +0 -0
  29. mllm/flamingo/mpt/__pycache__/__init__.cpython-39.pyc +0 -0
  30. mllm/flamingo/mpt/__pycache__/attention.cpython-39.pyc +0 -0
  31. mllm/flamingo/mpt/__pycache__/blocks.cpython-39.pyc +0 -0
  32. mllm/flamingo/mpt/__pycache__/configuration_mpt.cpython-39.pyc +0 -0
  33. mllm/flamingo/mpt/__pycache__/custom_embedding.cpython-39.pyc +0 -0
  34. mllm/flamingo/mpt/__pycache__/flash_attn_triton.cpython-39.pyc +0 -0
  35. mllm/flamingo/mpt/__pycache__/modeling_mpt.cpython-39.pyc +0 -0
  36. mllm/flamingo/mpt/__pycache__/norm.cpython-39.pyc +0 -0
  37. mllm/flamingo/mpt/__pycache__/param_init_fns.cpython-39.pyc +0 -0
  38. mllm/flamingo/mpt/adapt_tokenizer.py +44 -0
  39. mllm/flamingo/mpt/attention.py +450 -0
  40. mllm/flamingo/mpt/blocks.py +82 -0
  41. mllm/flamingo/mpt/configuration_mpt.py +161 -0
  42. mllm/flamingo/mpt/custom_embedding.py +11 -0
  43. mllm/flamingo/mpt/flash_attn_triton.py +841 -0
  44. mllm/flamingo/mpt/hf_prefixlm_converter.py +575 -0
  45. mllm/flamingo/mpt/meta_init_context.py +98 -0
  46. mllm/flamingo/mpt/modeling_mpt.py +496 -0
  47. mllm/flamingo/mpt/norm.py +60 -0
  48. mllm/flamingo/mpt/param_init_fns.py +369 -0
  49. mllm/flamingo/mpt_redpajama/__init__.py +0 -0
  50. mllm/flamingo/mpt_redpajama/__pycache__/__init__.cpython-39.pyc +0 -0
mllm/flamingo/__init__.py ADDED
@@ -0,0 +1,48 @@
+ from typing import TYPE_CHECKING
+
+ from transformers.utils import (
+     OptionalDependencyNotAvailable,
+     _LazyModule,
+     is_torch_available,
+ )
+
+
+ _import_structure = {
+     "configuration_flamingo": [
+         "FlamingoConfig",
+     ],
+ }
+
+ try:
+     if not is_torch_available():
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     pass
+ else:
+     _import_structure["modeling_flamingo"] = [
+         "FlamingoModel",
+         "FlamingoPreTrainedModel",
+         "FlamingoForConditionalGeneration",
+     ]
+
+ if TYPE_CHECKING:
+     from .configuration_flamingo import FlamingoConfig
+
+     # from .processing_flamingo import FlamingoProcessor
+
+     try:
+         if not is_torch_available():
+             raise OptionalDependencyNotAvailable()
+     except OptionalDependencyNotAvailable:
+         pass
+     else:
+         from .modeling_flamingo import (
+             FlamingoForConditionalGeneration,
+             FlamingoModel,
+             FlamingoPreTrainedModel,
+         )
+
+ else:
+     import sys
+
+     sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
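Editor's note: a hypothetical usage sketch of the lazy-module setup above (not part of this commit; assumes the repo root is on the import path). With `_LazyModule`, the submodules are only loaded when a name is first accessed:

```python
# Illustrative only: these imports resolve configuration_flamingo and
# modeling_flamingo lazily; the second one requires torch to be installed.
from mllm.flamingo import FlamingoConfig
from mllm.flamingo import FlamingoForConditionalGeneration
```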
mllm/flamingo/config.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "model_type": "flamingo",
+   "cross_attn_every_n_layers": 4,
+   "tie_word_embeddings": false,
+   "use_media_placement_augmentation": true,
+   "only_attend_previous": true,
+   "text_config": {
+     "_name_or_path": "luodian/llama-7b-hf",
+     "model_type": "llama"
+   },
+   "vision_config": {
+     "_name_or_path": "openai/clip-vit-large-patch14",
+     "model_type": "clip_vision_model",
+     "hidden_size": 1024,
+     "intermediate_size": 4096,
+     "num_attention_heads": 16,
+     "num_hidden_layers": 24,
+     "image_size": 224,
+     "patch_size": 14
+   }
+ }
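Editor's note: a minimal sketch (assumed local path, not part of this commit) of loading this shipped config. Since this `text_config` has no `"architectures"` key, `FlamingoConfig.__init__` falls through to the `CONFIG_MAPPING` branch and builds a `LlamaConfig`:

```python
# Minimal sketch, assuming the repo root is on sys.path and "mllm/flamingo"
# is the directory containing the config.json above.
from mllm.flamingo.configuration_flamingo import FlamingoConfig

config = FlamingoConfig.from_pretrained("mllm/flamingo")
print(config.cross_attn_every_n_layers)   # 4
print(type(config.text_config).__name__)  # LlamaConfig, via CONFIG_MAPPING["llama"]
print(config.vision_config.image_size)    # 224
```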
mllm/flamingo/configuration_flamingo.py ADDED
@@ -0,0 +1,100 @@
+ import copy
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+ from transformers.models.auto import CONFIG_MAPPING
+ from transformers.models.clip import CLIPVisionConfig
+ import sys
+
+ from .falcon.configuration_RW import RWConfig
+ from .mpt.configuration_mpt import MPTConfig
+ from .mpt_redpajama.configuration_mosaic_gpt import MosaicGPTConfig
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class FlamingoConfig(PretrainedConfig):
+     r"""
+     [`FlamingoConfig`] is the configuration class to store the configuration of a [`FlamingoForConditionalGeneration`]. It is
+     used to instantiate a Flamingo model according to the specified arguments, defining the vision model and language model configs. Instantiating a configuration with the defaults will yield a similar configuration to
+     that of the Flamingo architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         vision_config (`dict`, *optional*):
+             Dictionary of configuration options used to initialize [`CLIPVisionConfig`].
+         text_config (`dict`, *optional*):
+             Dictionary of configuration options used to initialize the language model's [`PretrainedConfig`].
+         cross_attn_every_n_layers (`int`, *optional*, defaults to 4):
+             How frequently cross-attention layers are added: one is inserted after every `cross_attn_every_n_layers` transformer layers.
+
+         kwargs (*optional*):
+             Dictionary of keyword arguments.
+
+     Example:
+
+     ```python
+     >>> from transformers import (
+     ...     PretrainedConfig,
+     ...     OPTConfig,
+     ...     FlamingoConfig,
+     ...     FlamingoForConditionalGeneration,
+     ... )
+
+     >>> # Initializing a FlamingoConfig with Salesforce/Flamingo-opt-2.7b style configuration
+     >>> configuration = FlamingoConfig()
+
+     >>> # Initializing a FlamingoForConditionalGeneration (with random weights) from the Salesforce/Flamingo-opt-2.7b style configuration
+     >>> model = FlamingoForConditionalGeneration(configuration)
+     ```"""
+     model_type = "flamingo"
+     is_composition = True
+
+     def __init__(self, vision_config=None, text_config=None, cross_attn_every_n_layers: int = 4, use_media_placement_augmentation: bool = True, **kwargs):
+         super().__init__(**kwargs)
+         if vision_config is None:
+             vision_config = {}
+             logger.info("vision_config is None. Initializing the vision config with default values.")
+
+         if text_config is None:
+             text_config = {}
+             logger.info("text_config is None. Initializing the text config with default values.")
+
+         self.vision_config = CLIPVisionConfig(**vision_config)
+         if "architectures" in text_config.keys() and text_config["architectures"] is not None:
+             if text_config["architectures"][0] == "MPTForCausalLM":
+                 self.text_config = MPTConfig(**text_config)
+             elif text_config["architectures"][0] == "MosaicGPT":
+                 self.text_config = MosaicGPTConfig(**text_config)
+             elif text_config["architectures"][0] == "RWForCausalLM":
+                 self.text_config = RWConfig(**text_config)
+             elif text_config["architectures"][0] == "LlamaForCausalLM":
+                 self.text_config = CONFIG_MAPPING[text_config.pop("model_type")](**text_config)
+             else:
+                 import pdb
+
+                 pdb.set_trace()
+         else:
+             self.text_config = CONFIG_MAPPING[text_config.pop("model_type")](**text_config)
+
+         self.cross_attn_every_n_layers = cross_attn_every_n_layers
+         self.use_media_placement_augmentation = use_media_placement_augmentation
+
+     def to_dict(self):
+         """
+         Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
+
+         Returns:
+             `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+         """
+         output = copy.deepcopy(self.__dict__)
+         output["vision_config"] = self.vision_config.to_dict()
+         output["text_config"] = self.text_config.to_dict()
+         output["model_type"] = self.__class__.model_type
+         output["cross_attn_every_n_layers"] = self.cross_attn_every_n_layers
+         output["use_media_placement_augmentation"] = self.use_media_placement_augmentation
+         return output
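Editor's note: to make the `text_config` dispatch above concrete, a hedged sketch follows. When the dict carries an `"architectures"` entry, `__init__` picks the matching config class directly instead of consulting `CONFIG_MAPPING`. The values below are illustrative, not from this commit:

```python
from mllm.flamingo.configuration_flamingo import FlamingoConfig

# "architectures" selects the MPTConfig branch; the remaining keys are
# forwarded to MPTConfig(**text_config).
config = FlamingoConfig(
    text_config={
        "architectures": ["MPTForCausalLM"],
        "model_type": "mpt",
        "d_model": 4096,
        "n_heads": 32,
        "n_layers": 32,
    }
)
print(type(config.text_config).__name__)  # MPTConfig
```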
mllm/flamingo/converting_flamingo_to_bf16.py ADDED
@@ -0,0 +1,30 @@
+ import argparse
+ import os
+
+ import torch
+
+ from .configuration_flamingo import FlamingoConfig
+ from .modeling_flamingo import FlamingoForConditionalGeneration
+
+ parser = argparse.ArgumentParser(description="Load model with precision")
+ parser.add_argument("--load_bit", type=str, choices=["fp16", "bf16"], required=True, help="Choose either 'fp16' or 'bf16'")
+ parser.add_argument("--pretrained_model_path", type=str, default="/home/luodian/projects/checkpoints/flamingo-mpt-7B-instruct-init", required=True)
+ parser.add_argument("--saved_model_path", type=str, default="/home/luodian/projects/checkpoints/flamingo-mpt-7B-instruct-init", required=True)
+ args = parser.parse_args()
+
+ load_bit = args.load_bit
+ pretrained_model_path = args.pretrained_model_path
+
+ if load_bit == "fp16":
+     precision = {"torch_dtype": torch.float16}
+ elif load_bit == "bf16":
+     precision = {"torch_dtype": torch.bfloat16}
+
+ root_dir = os.environ["AZP"]
+ print(root_dir)
+ device_id = "cpu"
+ model = FlamingoForConditionalGeneration.from_pretrained(pretrained_model_path, device_map={"": device_id}, **precision)
+
+ # save model to same folder
+ checkpoint_path = pretrained_model_path + f"-{load_bit}"
+ model.save_pretrained(checkpoint_path, max_shard_size="10GB")
mllm/flamingo/converting_flamingo_to_hf.py ADDED
@@ -0,0 +1,61 @@
+ """convert from otter pt to otter hf. Will remove after we use otter hf model to train.
+ """
+
+ import re
+ import argparse
+ import os
+
+ import torch
+ import torch.nn as nn
+ from transformers import CLIPVisionModel, LlamaForCausalLM, LlamaTokenizer
+
+ import sys
+ from modeling_flamingo import FlamingoForConditionalGeneration
+
+ from configuration_flamingo import FlamingoConfig
+
+
+ @torch.no_grad()
+ def dump_hf_model(pretrained_model_path: str, old_ckpt_path: str, new_folder_path: str) -> None:
+     old_ckpt = torch.load(old_ckpt_path, map_location="cpu")
+     if old_ckpt.get("model_state_dict", None) is not None:
+         old_ckpt = old_ckpt["model_state_dict"]
+     new_ckpt = old_ckpt
+     folder_path = os.path.dirname(old_ckpt_path)
+     # config_path = os.path.join(folder_path, "config.json") if os.path.exists(os.path.join(folder_path, "config.json")) else "flamingo/config.json"
+     model = FlamingoForConditionalGeneration.from_pretrained(
+         pretrained_model_path,  # use the function argument instead of the global `args`
+         device_map="auto",
+     )
+     _ = model.load_state_dict(new_ckpt, strict=False)
+     print(f"Saving HF model to {new_folder_path}")
+     model.save_pretrained(new_folder_path)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--old_ckpt_path",
+         "-old",
+         type=str,
+         required=True,
+         help="Path to the pt checkpoint",
+     )
+     parser.add_argument(
+         "--new_hf_path",
+         "-new",
+         type=str,
+         required=True,
+         help="Path to the hf folder",
+     )
+     parser.add_argument(
+         "--pretrained_model_path",
+         "-pretrained",
+         type=str,
+         required=True,
+         help="Path to the pretrained model folder",
+     )
+     args = parser.parse_args()
+     if not os.path.exists(os.path.dirname(args.new_hf_path)):
+         os.makedirs(os.path.dirname(args.new_hf_path))
+     dump_hf_model(args.pretrained_model_path, args.old_ckpt_path, args.new_hf_path)
mllm/flamingo/converting_flamingo_to_lora.py ADDED
@@ -0,0 +1,68 @@
+ import argparse
+ import torch
+ import sys
+
+ from .modeling_flamingo import FlamingoForConditionalGeneration
+ from peft import get_peft_model, LoraConfig, TaskType
+
+ MODEL_CLASSES = {
+     "LlamaForCausalLM": "llama",
+     "OPTForCausalLM": "opt",
+     "GPTJForCausalLM": "gptj",
+     "GPTNeoXForCausalLM": "gpt_neox",
+     "MPTForCausalLM": "mpt",
+ }
+
+ # Define argument parser
+ parser = argparse.ArgumentParser(description="Load a model with specified precision and save it to a specified path.")
+
+ # Add arguments
+ parser.add_argument(
+     "--checkpoint_path",
+     type=str,
+     help="Path to the pre-trained model checkpoint.",
+     default="",
+ )
+ parser.add_argument(
+     "--save_path",
+     type=str,
+     default="",
+     help="Path to the converted model checkpoint.",
+ )
+
+ # Parse the input arguments
+ args = parser.parse_args()
+
+ load_bit = "bf16"
+ if load_bit == "fp16":
+     precision = {"torch_dtype": torch.float16}
+ elif load_bit == "bf16":
+     precision = {"torch_dtype": torch.bfloat16}
+
+ # Load the model
+ model = FlamingoForConditionalGeneration.from_pretrained(args.checkpoint_path, device_map="auto", **precision)
+
+ # adding lora
+ standard_modules = ["q_proj", "v_proj"]
+ lang_encoder_short_name = MODEL_CLASSES[model.config.text_config.architectures[0]]
+ model_to_lora_modules = {
+     "llama": standard_modules,
+     "opt": standard_modules,
+     "gptj": standard_modules,
+     "gpt_neox": ["query_key_value"],
+     "mpt": ["Wqkv"],
+ }
+ lora_config = LoraConfig(
+     r=16,
+     lora_alpha=32,
+     lora_dropout=0.05,
+     task_type=TaskType.CAUSAL_LM,
+     target_modules=model_to_lora_modules[lang_encoder_short_name],
+ )
+ model.config.update({"lora_config": {"r": 16, "lora_alpha": 32, "lora_dropout": 0.05}})
+ model.lang_encoder = get_peft_model(model.lang_encoder, lora_config)
+ model.lang_encoder.print_trainable_parameters()
+
+ # Save the model
+ checkpoint_path = args.save_path
+ FlamingoForConditionalGeneration.save_pretrained(model, checkpoint_path)
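Editor's note: for context, a self-contained sketch of the low-rank update that `get_peft_model` attaches to each targeted projection (here r=16, lora_alpha=32, matching the `LoraConfig` above). This illustrates the idea only; it is not the peft implementation:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Illustrative LoRA wrapper: y = W x + (alpha / r) * B(A(x))."""

    def __init__(self, base: nn.Linear, r: int = 16, lora_alpha: int = 32):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)  # frozen pretrained weight
        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))  # zero init: identity at start
        self.scaling = lora_alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Only lora_A and lora_B receive gradients.
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling

proj = LoRALinear(nn.Linear(4096, 4096, bias=False))
trainable = sum(p.numel() for p in proj.parameters() if p.requires_grad)
print(trainable)  # 131072 = 2 * 16 * 4096, vs. ~16.8M in the frozen weight
```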
mllm/flamingo/falcon/__init__.py ADDED
File without changes
mllm/flamingo/falcon/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (209 Bytes).
mllm/flamingo/falcon/__pycache__/configuration_RW.cpython-39.pyc ADDED
Binary file (1.86 kB).
mllm/flamingo/falcon/__pycache__/modelling_RW.cpython-39.pyc ADDED
Binary file (28.5 kB).
mllm/flamingo/falcon/configuration_RW.py ADDED
@@ -0,0 +1,79 @@
+ # coding=utf-8
+ # Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ Bloom configuration"""
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class RWConfig(PretrainedConfig):
+     model_type = "RefinedWebModel"
+     keys_to_ignore_at_inference = ["past_key_values"]
+     attribute_map = {
+         "num_hidden_layers": "n_layer",
+         "num_attention_heads": "n_head",
+     }
+
+     def __init__(
+         self,
+         vocab_size=250880,
+         hidden_size=64,
+         n_layer=2,
+         n_head=8,
+         layer_norm_epsilon=1e-5,
+         initializer_range=0.02,
+         use_cache=True,
+         bos_token_id=1,
+         eos_token_id=2,
+         apply_residual_connection_post_layernorm=False,
+         hidden_dropout=0.0,
+         attention_dropout=0.0,
+         multi_query=False,
+         alibi=False,
+         bias=False,
+         parallel_attn=False,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         # Backward compatibility with n_embed kwarg
+         n_embed = kwargs.pop("n_embed", None)
+         self.hidden_size = hidden_size if n_embed is None else n_embed
+         self.n_layer = n_layer
+         self.n_head = n_head
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.initializer_range = initializer_range
+         self.use_cache = use_cache
+         self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
+         self.hidden_dropout = hidden_dropout
+         self.attention_dropout = attention_dropout
+
+         self.bos_token_id = bos_token_id
+         self.eos_token_id = eos_token_id
+         self.multi_query = multi_query
+         self.alibi = alibi
+         self.bias = bias
+         self.parallel_attn = parallel_attn
+
+         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+     @property
+     def head_dim(self):
+         return self.hidden_size // self.n_head
+
+     @property
+     def rotary(self):
+         return not self.alibi
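Editor's note: the two properties above are what the attention code consumes downstream. A small hedged illustration with Falcon-7B-like values (hypothetical here, not taken from this file):

```python
from mllm.flamingo.falcon.configuration_RW import RWConfig

# Hypothetical Falcon-7B-style values for illustration.
cfg = RWConfig(hidden_size=4544, n_head=71, n_layer=32, alibi=False)
print(cfg.head_dim)  # 64  (4544 // 71)
print(cfg.rotary)    # True: rotary embeddings are used whenever alibi is off
```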
mllm/flamingo/falcon/modelling_RW.py ADDED
@@ -0,0 +1,1064 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # port of models described in RW
2
+ # We use the bloom model as a starting point for these model.
3
+ # Please refer to the bloom models for usage instructions.
4
+
5
+ import math
6
+ import warnings
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.utils.checkpoint
11
+ from torch import nn
12
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
13
+ from torch.nn import functional as F
14
+
15
+ from transformers.modeling_outputs import (
16
+ BaseModelOutputWithPastAndCrossAttentions,
17
+ CausalLMOutputWithCrossAttentions,
18
+ QuestionAnsweringModelOutput,
19
+ SequenceClassifierOutputWithPast,
20
+ TokenClassifierOutput,
21
+ )
22
+ from transformers.modeling_utils import PreTrainedModel
23
+ from transformers.utils import logging
24
+ from .configuration_RW import RWConfig
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ # NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
30
+ # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
31
+ class Linear(nn.Linear):
32
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
33
+ ret = input @ self.weight.T
34
+ if self.bias is None:
35
+ return ret
36
+ else:
37
+ return ret + self.bias
38
+
39
+
40
+ from einops import rearrange
41
+
42
+
43
+ # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
44
+ def rotate_half(x):
45
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
46
+ return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in torch < 1.8.0
47
+
48
+
49
+ class RotaryEmbedding(torch.nn.Module):
50
+ """Implementation of RotaryEmbedding from GPT-NeoX.
51
+ This implementation is design to operate on queries and keys that are compatible with
52
+ [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format).
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ head_dim: int,
58
+ base=10000,
59
+ ):
60
+ super().__init__()
61
+ inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
62
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
63
+ self.head_dim = head_dim
64
+ self.seq_len_cached = None
65
+ self.batch_size_cached = None
66
+ self.cos_cached: torch.Tensor | None = None
67
+ self.sin_cached: torch.Tensor | None = None
68
+
69
+ def cos_sin(
70
+ self,
71
+ seq_len: int,
72
+ device="cuda",
73
+ dtype=torch.bfloat16,
74
+ ) -> torch.Tensor:
75
+ if seq_len != self.seq_len_cached:
76
+ self.seq_len_cached = seq_len
77
+ t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
78
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
79
+ emb = torch.cat((freqs, freqs), dim=-1).to(device)
80
+
81
+ if dtype in [torch.float16, torch.bfloat16]:
82
+ emb = emb.float()
83
+
84
+ self.cos_cached = emb.cos()[None, :, :]
85
+ self.sin_cached = emb.sin()[None, :, :]
86
+
87
+ self.cos_cached = self.cos_cached.type(dtype)
88
+ self.sin_cached = self.sin_cached.type(dtype)
89
+
90
+ return self.cos_cached, self.sin_cached
91
+
92
+ def forward(self, q, k):
93
+ batch, seq_len, head_dim = q.shape
94
+ cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
95
+ return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
96
+
97
+
98
+ def _make_causal_mask(input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int) -> torch.BoolTensor:
99
+ batch_size, target_length = input_ids_shape
100
+ mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
101
+ # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
102
+ seq_ids = torch.arange(target_length, device=device)
103
+ mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
104
+
105
+ if past_key_values_length > 0:
106
+ mask[:, :past_key_values_length] = False
107
+
108
+ expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
109
+ return expanded_mask
110
+
111
+
112
+ def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
113
+ batch_size, src_length = mask.shape
114
+ tgt_length = tgt_length if tgt_length is not None else src_length
115
+
116
+ expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
117
+ return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
118
+
119
+
120
+ def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
121
+ batch_size, seq_length = attention_mask.shape
122
+ closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
123
+ base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32)
124
+ powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
125
+ slopes = torch.pow(base, powers)
126
+
127
+ if closest_power_of_2 != num_heads:
128
+ extra_base = torch.tensor(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32)
129
+ num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
130
+ extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
131
+ slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
132
+
133
+ # Note: alibi will added to the attention bias that will be applied to the query, key product of attention
134
+ # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
135
+ # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
136
+ # => the query_length dimension will then be broadcasted correctly
137
+ # This is more or less identical to T5's relative position bias:
138
+ # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
139
+ arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
140
+ alibi = slopes[..., None].bfloat16() * arange_tensor
141
+ return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
142
+
143
+
144
+ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
145
+ out = F.dropout(x, p=prob, training=training)
146
+ out = residual + out
147
+ return out
148
+
149
+
150
+ class Attention(nn.Module):
151
+ def __init__(self, config: RWConfig):
152
+ super().__init__()
153
+
154
+ self.hidden_size = config.hidden_size
155
+ self.num_heads = config.n_head
156
+ self.head_dim = self.hidden_size // self.num_heads
157
+ self.split_size = self.hidden_size
158
+ self.hidden_dropout = config.hidden_dropout
159
+
160
+ if self.head_dim * self.num_heads != self.hidden_size:
161
+ raise ValueError(f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" f" {self.num_heads}).")
162
+
163
+ self.maybe_rotary = RotaryEmbedding(config.head_dim) if config.rotary else lambda q, k: (q, k)
164
+
165
+ # Layer-wise attention scaling
166
+ self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
167
+ self.beta = self.inv_norm_factor
168
+
169
+ self.query_key_value = Linear(
170
+ self.hidden_size,
171
+ 3 * self.hidden_size if not config.multi_query else (self.hidden_size + 2 * self.head_dim),
172
+ bias=config.bias,
173
+ )
174
+ self.multi_query = config.multi_query
175
+ self.dense = Linear(self.hidden_size, self.hidden_size, bias=config.bias)
176
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
177
+ self.num_kv = config.n_head if not self.multi_query else 1
178
+
179
+ def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
180
+ """
181
+ Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
182
+ storage as `fused_qkv`
183
+
184
+ Args:
185
+ fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
186
+
187
+ Returns:
188
+ query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
189
+ value: [batch_size, seq_length, num_heads, head_dim]
190
+ """
191
+ if not self.multi_query:
192
+ batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
193
+ fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
194
+ return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
195
+ else:
196
+ batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
197
+ fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
198
+ return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
199
+
200
+ def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
201
+ """
202
+ Merge heads together over the last dimenstion
203
+
204
+ Args:
205
+ x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
206
+
207
+ Returns:
208
+ torch.tensor: [batch_size, seq_length, num_heads * head_dim]
209
+ """
210
+ # What we want to achieve is:
211
+ # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
212
+ batch_size_and_num_heads, seq_length, _ = x.shape
213
+ batch_size = batch_size_and_num_heads // self.num_heads
214
+
215
+ # First view to decompose the batch size
216
+ # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
217
+ x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
218
+
219
+ # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
220
+ x = x.permute(0, 2, 1, 3)
221
+
222
+ # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
223
+ return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
224
+
225
+ def forward(
226
+ self,
227
+ hidden_states: torch.Tensor,
228
+ alibi: torch.Tensor,
229
+ attention_mask: torch.Tensor,
230
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
231
+ head_mask: Optional[torch.Tensor] = None,
232
+ use_cache: bool = False,
233
+ output_attentions: bool = False,
234
+ ):
235
+ fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
236
+
237
+ # 3 x [batch_size, seq_length, num_heads, head_dim]
238
+ (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
239
+
240
+ batch_size, q_length, _, _ = query_layer.shape
241
+
242
+ query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
243
+ key_layer = key_layer.transpose(1, 2).reshape(
244
+ batch_size * self.num_kv,
245
+ q_length,
246
+ self.head_dim,
247
+ )
248
+ value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)
249
+
250
+ query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
251
+
252
+ if layer_past is not None:
253
+ past_key, past_value = layer_past
254
+ # concatenate along seq_length dimension:
255
+ # - key: [batch_size * self.num_heads, head_dim, kv_length]
256
+ # - value: [batch_size * self.num_heads, kv_length, head_dim]
257
+ key_layer = torch.cat((past_key, key_layer), dim=1)
258
+ value_layer = torch.cat((past_value, value_layer), dim=1)
259
+
260
+ _, kv_length, _ = key_layer.shape
261
+
262
+ if use_cache is True:
263
+ present = (key_layer, value_layer)
264
+ else:
265
+ present = None
266
+
267
+ if alibi is None:
268
+ query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
269
+ key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
270
+ value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
271
+
272
+ attn_output = F.scaled_dot_product_attention(query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True)
273
+
274
+ x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
275
+ x = x.permute(0, 2, 1, 3)
276
+ attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)
277
+
278
+ output_tensor = self.dense(attn_output)
279
+
280
+ outputs = (output_tensor, present)
281
+ assert not output_attentions # not supported.
282
+ return outputs
283
+ else:
284
+ attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
285
+ matmul_result = query_layer @ key_layer.transpose(-1, -2)
286
+
287
+ # change view to [batch_size, num_heads, q_length, kv_length]
288
+ attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
289
+
290
+ # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
291
+ input_dtype = attention_scores.dtype
292
+ # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
293
+ if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
294
+ attention_scores = attention_scores.to(torch.float32)
295
+ # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
296
+ attention_probs = F.softmax(
297
+ (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor + attention_mask_float,
298
+ dim=-1,
299
+ dtype=hidden_states.dtype,
300
+ )
301
+ # [batch_size, num_heads, q_length, kv_length]
302
+ attention_probs = self.attention_dropout(attention_probs)
303
+
304
+ if head_mask is not None:
305
+ attention_probs = attention_probs * head_mask
306
+
307
+ # change view [batch_size x num_heads, q_length, kv_length]
308
+ attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
309
+
310
+ # matmul: [batch_size * num_heads, q_length, head_dim]
311
+ context_layer = attention_probs_reshaped @ value_layer
312
+
313
+ # change view [batch_size, num_heads, q_length, head_dim]
314
+ context_layer = self._merge_heads(context_layer)
315
+
316
+ output_tensor = self.dense(context_layer)
317
+
318
+ outputs = (output_tensor, present)
319
+ if output_attentions:
320
+ outputs += (attention_probs,)
321
+
322
+ return outputs
323
+
324
+
325
+ class MLP(nn.Module):
326
+ def __init__(self, config: RWConfig):
327
+ super().__init__()
328
+ hidden_size = config.hidden_size
329
+
330
+ self.dense_h_to_4h = Linear(hidden_size, 4 * hidden_size, bias=config.bias)
331
+ self.act = nn.GELU()
332
+ self.dense_4h_to_h = Linear(4 * hidden_size, hidden_size, bias=config.bias)
333
+ self.hidden_dropout = config.hidden_dropout
334
+
335
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
336
+ x = self.act(self.dense_h_to_4h(x))
337
+ x = self.dense_4h_to_h(x)
338
+ return x
339
+
340
+
341
+ class DecoderLayer(nn.Module):
342
+ def __init__(self, config: RWConfig):
343
+ super().__init__()
344
+ hidden_size = config.hidden_size
345
+
346
+ self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
347
+ self.num_heads = config.n_head
348
+ self.self_attention = Attention(config)
349
+
350
+ if not config.parallel_attn:
351
+ # unused if parallel attn
352
+ self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
353
+
354
+ self.mlp = MLP(config)
355
+
356
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
357
+ self.hidden_dropout = config.hidden_dropout
358
+
359
+ self.config = config
360
+
361
+ def forward(
362
+ self,
363
+ hidden_states: torch.Tensor,
364
+ alibi: torch.Tensor,
365
+ attention_mask: torch.Tensor,
366
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
367
+ head_mask: Optional[torch.Tensor] = None,
368
+ use_cache: bool = False,
369
+ output_attentions: bool = False,
370
+ ):
371
+ layernorm_output = self.input_layernorm(hidden_states)
372
+ residual = hidden_states
373
+
374
+ # Self attention.
375
+ attn_outputs = self.self_attention(
376
+ layernorm_output,
377
+ layer_past=layer_past,
378
+ attention_mask=attention_mask,
379
+ alibi=alibi,
380
+ head_mask=head_mask,
381
+ use_cache=use_cache,
382
+ output_attentions=output_attentions,
383
+ )
384
+
385
+ attention_output = attn_outputs[0]
386
+
387
+ if not self.config.parallel_attn:
388
+ residual = dropout_add(attention_output, residual, self.config.attention_dropout, training=self.training)
389
+ layernorm_output = self.post_attention_layernorm(residual)
390
+
391
+ outputs = attn_outputs[1:]
392
+
393
+ # MLP.
394
+ mlp_output = self.mlp(layernorm_output)
395
+
396
+ if self.config.parallel_attn:
397
+ mlp_output += attention_output
398
+
399
+ output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
400
+
401
+ if use_cache:
402
+ outputs = (output,) + outputs
403
+ else:
404
+ outputs = (output,) + outputs[1:]
405
+
406
+ return outputs # hidden_states, present, attentions
407
+
408
+
409
+ class RWPreTrainedModel(PreTrainedModel):
410
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
411
+ """
412
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
413
+ models.
414
+ """
415
+
416
+ config_class = RWConfig
417
+ base_model_prefix = "transformer"
418
+ supports_gradient_checkpointing = True
419
+ _no_split_modules = ["DecoderLayer"]
420
+
421
+ def __init__(self, *inputs, **kwargs):
422
+ super().__init__(*inputs, **kwargs)
423
+
424
+ def _init_weights(self, module: nn.Module):
425
+ """Initialize the weights."""
426
+ if isinstance(module, nn.Linear) or isinstance(module, Linear):
427
+ # Slightly different from the TF version which uses truncated_normal for initialization
428
+ # cf https://github.com/pytorch/pytorch/pull/5617
429
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
430
+ if module.bias is not None:
431
+ module.bias.data.zero_()
432
+ elif isinstance(module, nn.Embedding):
433
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
434
+ if module.padding_idx is not None:
435
+ module.weight.data[module.padding_idx].zero_()
436
+ elif isinstance(module, LayerNorm):
437
+ module.bias.data.zero_()
438
+ module.weight.data.fill_(1.0)
439
+
440
+ def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
441
+ if isinstance(module, RWModel):
442
+ module.gradient_checkpointing = value
443
+
444
+ @staticmethod
445
+ def _convert_to_standard_cache(past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
446
+ """
447
+ Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
448
+ num_heads, ...]))
449
+ """
450
+ batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
451
+ num_heads = batch_size_times_num_heads // batch_size
452
+ # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
453
+ # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
454
+ return tuple(
455
+ (
456
+ layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
457
+ layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
458
+ )
459
+ for layer_past in past_key_value
460
+ )
461
+
462
+ @staticmethod
463
+ def _convert_to_rw_cache(past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
464
+ batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
465
+ batch_size_times_num_heads = batch_size * num_heads
466
+ # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
467
+ # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
468
+ return tuple(
469
+ (
470
+ layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
471
+ layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
472
+ )
473
+ for layer_past in past_key_value
474
+ )
475
+
476
+
477
+ class RWModel(RWPreTrainedModel):
478
+ def __init__(self, config: RWConfig):
479
+ super().__init__(config)
480
+
481
+ self.embed_dim = config.hidden_size
482
+ self.num_heads = config.n_head
483
+ self.alibi = config.alibi
484
+
485
+ # Embedding + LN Embedding
486
+ self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
487
+
488
+ # Transformer blocks
489
+ self.h = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
490
+
491
+ # Final Layer Norm
492
+ self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
493
+
494
+ self.gradient_checkpointing = False
495
+
496
+ # Initialize weights and apply final processing
497
+ self.post_init()
498
+
499
+ def get_input_embeddings(self):
500
+ return self.word_embeddings
501
+
502
+ def _prepare_attn_mask(self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int) -> torch.BoolTensor:
503
+ # create causal mask
504
+ # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
505
+ combined_attention_mask = None
506
+ device = attention_mask.device
507
+ _, src_length = input_shape
508
+
509
+ if src_length > 1:
510
+ combined_attention_mask = _make_causal_mask(input_shape, device=device, past_key_values_length=past_key_values_length)
511
+
512
+ # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
513
+ expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
514
+ combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
515
+
516
+ return combined_attention_mask
517
+
518
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
519
+ self.word_embeddings = new_embeddings
520
+
521
+ def forward(
522
+ self,
523
+ input_ids: Optional[torch.LongTensor] = None,
524
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
525
+ attention_mask: Optional[torch.Tensor] = None,
526
+ head_mask: Optional[torch.LongTensor] = None,
527
+ inputs_embeds: Optional[torch.LongTensor] = None,
528
+ use_cache: Optional[bool] = None,
529
+ output_attentions: Optional[bool] = None,
530
+ output_hidden_states: Optional[bool] = None,
531
+ return_dict: Optional[bool] = None,
532
+ **deprecated_arguments,
533
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
534
+ if deprecated_arguments.pop("position_ids", False) is not False:
535
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
536
+ warnings.warn(
537
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" " passing `position_ids`.",
538
+ FutureWarning,
539
+ )
540
+ if len(deprecated_arguments) > 0:
541
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
542
+
543
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
544
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
545
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
546
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
547
+
548
+ if input_ids is not None and inputs_embeds is not None:
549
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
550
+ elif input_ids is not None:
551
+ batch_size, seq_length = input_ids.shape
552
+ elif inputs_embeds is not None:
553
+ batch_size, seq_length, _ = inputs_embeds.shape
554
+ else:
555
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
556
+
557
+ if past_key_values is None:
558
+ past_key_values = tuple([None] * len(self.h))
559
+
560
+ # Prepare head mask if needed
561
+ # 1.0 in head_mask indicate we keep the head
562
+ # attention_probs has shape batch_size x num_heads x N x N
563
+ # head_mask has shape n_layer x batch x num_heads x N x N
564
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
565
+
566
+ if inputs_embeds is None:
567
+ inputs_embeds = self.word_embeddings(input_ids)
568
+
569
+ hidden_states = inputs_embeds
570
+
571
+ presents = () if use_cache else None
572
+ all_self_attentions = () if output_attentions else None
573
+ all_hidden_states = () if output_hidden_states else None
574
+
575
+ # Compute alibi tensor: check build_alibi_tensor documentation
576
+ seq_length_with_past = seq_length
577
+ past_key_values_length = 0
578
+ if past_key_values[0] is not None:
579
+ past_key_values_length = past_key_values[0][0].shape[2]
580
+ seq_length_with_past = seq_length_with_past + past_key_values_length
581
+ if attention_mask is None:
582
+ attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
583
+ else:
584
+ attention_mask = attention_mask.to(hidden_states.device)
585
+
586
+ if self.alibi:
587
+ alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
588
+ else:
589
+ alibi = None
590
+
591
+ causal_mask = self._prepare_attn_mask(
592
+ attention_mask,
593
+ input_shape=(batch_size, seq_length),
594
+ past_key_values_length=past_key_values_length,
595
+ )
596
+
597
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
598
+ if output_hidden_states:
599
+ all_hidden_states = all_hidden_states + (hidden_states,)
600
+
601
+ if self.gradient_checkpointing and self.training:
602
+ if use_cache:
603
+ logger.warning("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
604
+ use_cache = False
605
+
606
+ def create_custom_forward(module):
607
+ def custom_forward(*inputs):
608
+ # None for past_key_value
609
+ return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
610
+
611
+ return custom_forward
612
+
613
+ outputs = torch.utils.checkpoint.checkpoint(
614
+ create_custom_forward(block),
615
+ hidden_states,
616
+ alibi,
617
+ causal_mask,
618
+ head_mask[i],
619
+ )
620
+ else:
621
+ outputs = block(
622
+ hidden_states,
623
+ layer_past=layer_past,
624
+ attention_mask=causal_mask,
625
+ head_mask=head_mask[i],
626
+ use_cache=use_cache,
627
+ output_attentions=output_attentions,
628
+ alibi=alibi,
629
+ )
630
+
631
+ hidden_states = outputs[0]
632
+ if use_cache is True:
633
+ presents = presents + (outputs[1],)
634
+
635
+ if output_attentions:
636
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
637
+
638
+ # Add last hidden state
639
+ hidden_states = self.ln_f(hidden_states)
640
+
641
+ if output_hidden_states:
642
+ all_hidden_states = all_hidden_states + (hidden_states,)
643
+
644
+ if not return_dict:
645
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
646
+
647
+ return BaseModelOutputWithPastAndCrossAttentions(
648
+ last_hidden_state=hidden_states,
649
+ past_key_values=presents,
650
+ hidden_states=all_hidden_states,
651
+ attentions=all_self_attentions,
652
+ )
653
+
654
+
655
+ class RWForCausalLM(RWPreTrainedModel):
656
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
657
+
658
+ def __init__(self, config: RWConfig):
659
+ super().__init__(config)
660
+ self.transformer = RWModel(config)
661
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
662
+
663
+ # Initialize weights and apply final processing
664
+ self.post_init()
665
+
666
+ def get_output_embeddings(self):
667
+ return self.lm_head
668
+
669
+ def set_output_embeddings(self, new_embeddings: torch.Tensor):
670
+ self.lm_head = new_embeddings
671
+
672
+ def prepare_inputs_for_generation(
673
+ self,
674
+ input_ids: torch.LongTensor,
675
+ past: Optional[torch.Tensor] = None,
676
+ attention_mask: Optional[torch.Tensor] = None,
677
+ **kwargs,
678
+ ) -> dict:
679
+ # only last token for input_ids if past is not None
680
+ if past:
681
+ input_ids = input_ids[:, -1].unsqueeze(-1)
682
+
683
+ # the cache may be in the stardard format (e.g. in contrastive search), convert to our's format if needed
684
+ if past[0][0].shape[0] == input_ids.shape[0]:
685
+ past = self._convert_to_rw_cache(past)
686
+
687
+ return {
688
+ "input_ids": input_ids,
689
+ "past_key_values": past,
690
+ "use_cache": kwargs.get("use_cache"),
691
+ "attention_mask": attention_mask,
692
+ }
693
+
694
+ def forward(
695
+ self,
696
+ input_ids: Optional[torch.LongTensor] = None,
697
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
698
+ attention_mask: Optional[torch.Tensor] = None,
699
+ head_mask: Optional[torch.Tensor] = None,
700
+ inputs_embeds: Optional[torch.Tensor] = None,
701
+ labels: Optional[torch.Tensor] = None,
702
+ use_cache: Optional[bool] = None,
703
+ output_attentions: Optional[bool] = None,
704
+ output_hidden_states: Optional[bool] = None,
705
+ return_dict: Optional[bool] = None,
706
+ **deprecated_arguments,
707
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
708
+ r"""
709
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
710
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
711
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
712
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
713
+ """
714
+ if deprecated_arguments.pop("position_ids", False) is not False:
715
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
716
+ warnings.warn(
717
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" " passing `position_ids`.",
718
+ FutureWarning,
719
+ )
720
+ if len(deprecated_arguments) > 0:
721
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
722
+
723
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
724
+
725
+ transformer_outputs = self.transformer(
726
+ input_ids,
727
+ past_key_values=past_key_values,
728
+ attention_mask=attention_mask,
729
+ head_mask=head_mask,
730
+ inputs_embeds=inputs_embeds,
731
+ use_cache=use_cache,
732
+ output_attentions=output_attentions,
733
+ output_hidden_states=output_hidden_states,
734
+ return_dict=return_dict,
735
+ )
736
+ hidden_states = transformer_outputs[0]
737
+
738
+ lm_logits = self.lm_head(hidden_states)
739
+
740
+ loss = None
741
+ if labels is not None:
742
+ # Shift so that tokens < n predict n
743
+ shift_logits = lm_logits[..., :-1, :].contiguous()
744
+ shift_labels = labels[..., 1:].contiguous()
745
+ batch_size, seq_length, vocab_size = shift_logits.shape
746
+ # Flatten the tokens
747
+ loss_fct = CrossEntropyLoss()
748
+ loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length))
749
+
750
+ if not return_dict:
751
+ output = (lm_logits,) + transformer_outputs[1:]
752
+ return ((loss,) + output) if loss is not None else output
753
+
754
+ return CausalLMOutputWithCrossAttentions(
755
+ loss=loss,
756
+ logits=lm_logits,
757
+ past_key_values=transformer_outputs.past_key_values,
758
+ hidden_states=transformer_outputs.hidden_states,
759
+ attentions=transformer_outputs.attentions,
760
+ )
761
+
762
+ def _reorder_cache(self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
763
+ """
764
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
765
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
766
+ beam_idx at every generation step.
767
+
768
+ Output shares the same memory storage as `past`.
769
+ """
770
+ standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
771
+
772
+ # Get a copy of `beam_idx` on all the devices where we need those indices.
773
+ device_to_beam_idx = {past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past}
774
+ reordered_past = tuple(
775
+ (
776
+ layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
777
+ layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
778
+ )
779
+ for layer_past in standardized_past
780
+ )
781
+ return self._convert_to_rw_cache(reordered_past)
782
+
783
+
784
+ class RWForSequenceClassification(RWPreTrainedModel):
785
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
786
+
787
+ def __init__(self, config: RWConfig):
788
+ super().__init__(config)
789
+ self.num_labels = config.num_labels
790
+ self.transformer = RWModel(config)
791
+ self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
792
+
793
+ # Initialize weights and apply final processing
794
+ self.post_init()
795
+
796
+ def forward(
797
+ self,
798
+ input_ids: Optional[torch.LongTensor] = None,
799
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
800
+ attention_mask: Optional[torch.Tensor] = None,
801
+ head_mask: Optional[torch.Tensor] = None,
802
+ inputs_embeds: Optional[torch.Tensor] = None,
803
+ labels: Optional[torch.Tensor] = None,
804
+ use_cache: Optional[bool] = None,
805
+ output_attentions: Optional[bool] = None,
806
+ output_hidden_states: Optional[bool] = None,
807
+ return_dict: Optional[bool] = None,
808
+ **deprecated_arguments,
809
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
810
+ r"""
811
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
812
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
813
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
814
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
815
+ """
816
+ if deprecated_arguments.pop("position_ids", False) is not False:
817
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
818
+ warnings.warn(
819
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" " passing `position_ids`.",
820
+ FutureWarning,
821
+ )
822
+ if len(deprecated_arguments) > 0:
823
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
824
+
825
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
826
+
827
+ transformer_outputs = self.transformer(
828
+ input_ids,
829
+ past_key_values=past_key_values,
830
+ attention_mask=attention_mask,
831
+ head_mask=head_mask,
832
+ inputs_embeds=inputs_embeds,
833
+ use_cache=use_cache,
834
+ output_attentions=output_attentions,
835
+ output_hidden_states=output_hidden_states,
836
+ return_dict=return_dict,
837
+ )
838
+
839
+ hidden_states = transformer_outputs[0]
840
+ logits = self.score(hidden_states)
841
+
842
+ if input_ids is not None:
843
+ batch_size = input_ids.shape[0]
844
+ else:
845
+ batch_size = inputs_embeds.shape[0]
846
+
847
+ if self.config.pad_token_id is None and batch_size != 1:
848
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
849
+ if self.config.pad_token_id is None:
850
+ sequence_lengths = -1
851
+ else:
852
+ if input_ids is not None:
853
+ sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1
854
+ else:
855
+ sequence_lengths = -1
856
+ logger.warning(
857
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
858
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
859
+ )
860
+
861
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
862
+
863
+ loss = None
864
+ if labels is not None:
865
+ if self.config.problem_type is None:
866
+ if self.num_labels == 1:
867
+ self.config.problem_type = "regression"
868
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
869
+ self.config.problem_type = "single_label_classification"
870
+ else:
871
+ self.config.problem_type = "multi_label_classification"
872
+
873
+ if self.config.problem_type == "regression":
874
+ loss_fct = MSELoss()
875
+ if self.num_labels == 1:
876
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
877
+ else:
878
+ loss = loss_fct(pooled_logits, labels)
879
+ elif self.config.problem_type == "single_label_classification":
880
+ loss_fct = CrossEntropyLoss()
881
+ loss = loss_fct(pooled_logits, labels)
882
+ elif self.config.problem_type == "multi_label_classification":
883
+ loss_fct = BCEWithLogitsLoss()
884
+ loss = loss_fct(pooled_logits, labels)
885
+ if not return_dict:
886
+ output = (pooled_logits,) + transformer_outputs[1:]
887
+ return ((loss,) + output) if loss is not None else output
888
+
889
+ return SequenceClassifierOutputWithPast(
890
+ loss=loss,
891
+ logits=pooled_logits,
892
+ past_key_values=transformer_outputs.past_key_values,
893
+ hidden_states=transformer_outputs.hidden_states,
894
+ attentions=transformer_outputs.attentions,
895
+ )
896
+
897
+
898
+ class RWForTokenClassification(RWPreTrainedModel):
899
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
900
+
901
+ def __init__(self, config: RWConfig):
902
+ super().__init__(config)
903
+ self.num_labels = config.num_labels
904
+
905
+ self.transformer = RWModel(config)
906
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
907
+ classifier_dropout = config.classifier_dropout
908
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
909
+ classifier_dropout = config.hidden_dropout
910
+ else:
911
+ classifier_dropout = 0.1
912
+ self.dropout = nn.Dropout(classifier_dropout)
913
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
914
+
915
+ # Initialize weights and apply final processing
916
+ self.post_init()
917
+
918
+ def forward(
919
+ self,
920
+ input_ids: Optional[torch.LongTensor] = None,
921
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
922
+ attention_mask: Optional[torch.Tensor] = None,
923
+ head_mask: Optional[torch.Tensor] = None,
924
+ inputs_embeds: Optional[torch.Tensor] = None,
925
+ labels: Optional[torch.Tensor] = None,
926
+ use_cache: Optional[bool] = None,
927
+ output_attentions: Optional[bool] = None,
928
+ output_hidden_states: Optional[bool] = None,
929
+ return_dict: Optional[bool] = None,
930
+ **deprecated_arguments,
931
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
932
+ r"""
933
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
937
+ """
938
+ if deprecated_arguments.pop("position_ids", False) is not False:
939
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
940
+ warnings.warn(
941
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" " passing `position_ids`.",
942
+ FutureWarning,
943
+ )
944
+ if len(deprecated_arguments) > 0:
945
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
946
+
947
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
948
+
949
+ transformer_outputs = self.transformer(
950
+ input_ids,
951
+ past_key_values=past_key_values,
952
+ attention_mask=attention_mask,
953
+ head_mask=head_mask,
954
+ inputs_embeds=inputs_embeds,
955
+ use_cache=use_cache,
956
+ output_attentions=output_attentions,
957
+ output_hidden_states=output_hidden_states,
958
+ return_dict=return_dict,
959
+ )
960
+
961
+ hidden_states = transformer_outputs[0]
962
+ hidden_states = self.dropout(hidden_states)
963
+ logits = self.classifier(hidden_states)
964
+
965
+ loss = None
966
+ if labels is not None:
967
+ batch_size, seq_length = labels.shape
968
+ loss_fct = CrossEntropyLoss()
969
+ loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length))
970
+
971
+ if not return_dict:
972
+ output = (logits,) + transformer_outputs[2:]
973
+ return ((loss,) + output) if loss is not None else output
974
+
975
+ return TokenClassifierOutput(
976
+ loss=loss,
977
+ logits=logits,
978
+ hidden_states=transformer_outputs.hidden_states,
979
+ attentions=transformer_outputs.attentions,
980
+ )
981
+
982
+
983
+ class RWForQuestionAnswering(RWPreTrainedModel):
984
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
985
+
986
+ def __init__(self, config):
987
+ super().__init__(config)
988
+ self.transformer = RWModel(config)
989
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
990
+
991
+ # Initialize weights and apply final processing
992
+ self.post_init()
993
+
994
+ def forward(
995
+ self,
996
+ input_ids: Optional[torch.LongTensor] = None,
997
+ attention_mask: Optional[torch.FloatTensor] = None,
998
+ position_ids: Optional[torch.LongTensor] = None,
999
+ head_mask: Optional[torch.FloatTensor] = None,
1000
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1001
+ start_positions: Optional[torch.LongTensor] = None,
1002
+ end_positions: Optional[torch.LongTensor] = None,
1003
+ output_attentions: Optional[bool] = None,
1004
+ output_hidden_states: Optional[bool] = None,
1005
+ return_dict: Optional[bool] = None,
1006
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1007
+ r"""
1008
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1009
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1010
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1011
+ are not taken into account for computing the loss.
1012
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1013
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1014
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1015
+ are not taken into account for computing the loss.
1016
+ """
1017
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1018
+
1019
+ outputs = self.transformer(
1020
+ input_ids,
1021
+ attention_mask=attention_mask,
1022
+ position_ids=position_ids,
1023
+ head_mask=head_mask,
1024
+ inputs_embeds=inputs_embeds,
1025
+ output_attentions=output_attentions,
1026
+ output_hidden_states=output_hidden_states,
1027
+ return_dict=return_dict,
1028
+ )
1029
+
1030
+ sequence_output = outputs[0]
1031
+
1032
+ logits = self.qa_outputs(sequence_output)
1033
+ start_logits, end_logits = logits.split(1, dim=-1)
1034
+ start_logits = start_logits.squeeze(-1).contiguous()
1035
+ end_logits = end_logits.squeeze(-1).contiguous()
1036
+
1037
+ total_loss = None
1038
+ if start_positions is not None and end_positions is not None:
1039
+ # If we are on multi-GPU, split add a dimension
1040
+ if len(start_positions.size()) > 1:
1041
+ start_positions = start_positions.squeeze(-1)
1042
+ if len(end_positions.size()) > 1:
1043
+ end_positions = end_positions.squeeze(-1)
1044
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1045
+ ignored_index = start_logits.size(1)
1046
+ start_positions = start_positions.clamp(0, ignored_index)
1047
+ end_positions = end_positions.clamp(0, ignored_index)
1048
+
1049
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1050
+ start_loss = loss_fct(start_logits, start_positions)
1051
+ end_loss = loss_fct(end_logits, end_positions)
1052
+ total_loss = (start_loss + end_loss) / 2
1053
+
1054
+ if not return_dict:
1055
+ output = (start_logits, end_logits) + outputs[2:]
1056
+ return ((total_loss,) + output) if total_loss is not None else output
1057
+
1058
+ return QuestionAnsweringModelOutput(
1059
+ loss=total_loss,
1060
+ start_logits=start_logits,
1061
+ end_logits=end_logits,
1062
+ hidden_states=outputs.hidden_states,
1063
+ attentions=outputs.attentions,
1064
+ )
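The three task heads above all wrap the same `RWModel` backbone and differ only in the output layer. A minimal sketch, not part of this commit and assuming the small default values in `configuration_RW.py`, of instantiating one of them directly:

# Illustrative sketch only: a randomly initialized classification head on the RW backbone.
from configuration_RW import RWConfig
from modelling_RW import RWForSequenceClassification

config = RWConfig(num_labels=3)              # tiny default config; num_labels=3 is an arbitrary assumption
model = RWForSequenceClassification(config)
print(model.score)                           # Linear(hidden_size -> num_labels, bias=False)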
mllm/flamingo/flamingo-falcon-7B.json ADDED
@@ -0,0 +1,112 @@
1
+ {
2
+ "_commit_hash": null,
3
+ "architectures": [
4
+ "FlamingoModel"
5
+ ],
6
+ "cross_attn_every_n_layers": 4,
7
+ "model_type": "flamingo",
8
+ "text_config": {
9
+ "architectures": [
10
+ "RWForCausalLM"
11
+ ],
12
+ "apply_residual_connection_post_layernorm": false,
13
+ "attention_dropout": 0.0,
14
+ "bias": false,
15
+ "bos_token_id": 11,
16
+ "eos_token_id": 11,
17
+ "hidden_dropout": 0.0,
18
+ "hidden_size": 4544,
19
+ "initializer_range": 0.02,
20
+ "layer_norm_epsilon": 1e-05,
21
+ "model_type": "RefinedWebModel",
22
+ "multi_query": true,
23
+ "n_head": 71,
24
+ "n_layer": 32,
25
+ "parallel_attn": true,
26
+ "torch_dtype": "bfloat16",
27
+ "transformers_version": "4.27.4",
28
+ "use_cache": true,
29
+ "vocab_size": 65024
30
+ },
31
+ "tie_word_embeddings": false,
32
+ "torch_dtype": "float32",
33
+ "transformers_version": null,
34
+ "use_media_placement_augmentation": true,
35
+ "vision_config": {
36
+ "_name_or_path": "openai/clip-vit-large-patch14",
37
+ "add_cross_attention": false,
38
+ "architectures": null,
39
+ "attention_dropout": 0.0,
40
+ "bad_words_ids": null,
41
+ "begin_suppress_tokens": null,
42
+ "bos_token_id": null,
43
+ "chunk_size_feed_forward": 0,
44
+ "cross_attention_hidden_size": null,
45
+ "decoder_start_token_id": null,
46
+ "diversity_penalty": 0.0,
47
+ "do_sample": false,
48
+ "early_stopping": false,
49
+ "encoder_no_repeat_ngram_size": 0,
50
+ "eos_token_id": null,
51
+ "exponential_decay_length_penalty": null,
52
+ "finetuning_task": null,
53
+ "forced_bos_token_id": null,
54
+ "forced_eos_token_id": null,
55
+ "hidden_act": "quick_gelu",
56
+ "hidden_size": 1024,
57
+ "id2label": {
58
+ "0": "LABEL_0",
59
+ "1": "LABEL_1"
60
+ },
61
+ "image_size": 224,
62
+ "initializer_factor": 1.0,
63
+ "initializer_range": 0.02,
64
+ "intermediate_size": 4096,
65
+ "is_decoder": false,
66
+ "is_encoder_decoder": false,
67
+ "label2id": {
68
+ "LABEL_0": 0,
69
+ "LABEL_1": 1
70
+ },
71
+ "layer_norm_eps": 1e-05,
72
+ "length_penalty": 1.0,
73
+ "max_length": 20,
74
+ "min_length": 0,
75
+ "model_type": "clip_vision_model",
76
+ "no_repeat_ngram_size": 0,
77
+ "num_attention_heads": 16,
78
+ "num_beam_groups": 1,
79
+ "num_beams": 1,
80
+ "num_channels": 3,
81
+ "num_hidden_layers": 24,
82
+ "num_return_sequences": 1,
83
+ "output_attentions": false,
84
+ "output_hidden_states": false,
85
+ "output_scores": false,
86
+ "pad_token_id": null,
87
+ "patch_size": 14,
88
+ "prefix": null,
89
+ "problem_type": null,
90
+ "projection_dim": 512,
91
+ "pruned_heads": {},
92
+ "remove_invalid_values": false,
93
+ "repetition_penalty": 1.0,
94
+ "return_dict": true,
95
+ "return_dict_in_generate": false,
96
+ "sep_token_id": null,
97
+ "suppress_tokens": null,
98
+ "task_specific_params": null,
99
+ "temperature": 1.0,
100
+ "tf_legacy_loss": false,
101
+ "tie_encoder_decoder": false,
102
+ "tie_word_embeddings": true,
103
+ "tokenizer_class": null,
104
+ "top_k": 50,
105
+ "top_p": 1.0,
106
+ "torch_dtype": null,
107
+ "torchscript": false,
108
+ "transformers_version": "4.28.1",
109
+ "typical_p": 1.0,
110
+ "use_bfloat16": false
111
+ }
112
+ }
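These `flamingo-*.json` files are the initialization configs consumed by the injection scripts added later in this commit. A minimal sketch, assuming it is run from inside `mllm/flamingo/`, of turning one of them into an uninitialized model skeleton:

# Sketch: build an empty Flamingo model from one of the JSON configs in this folder.
from configuration_flamingo import FlamingoConfig
from modeling_flamingo import FlamingoForConditionalGeneration

config = FlamingoConfig.from_json_file("flamingo-falcon-7B.json")
model = FlamingoForConditionalGeneration(config=config)
print(config.cross_attn_every_n_layers)      # 4 for this config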
mllm/flamingo/flamingo-llama2-chat-13B.json ADDED
@@ -0,0 +1,114 @@
1
+ {
2
+ "_commit_hash": null,
3
+ "architectures": [
4
+ "FlamingoForConditionalGeneration"
5
+ ],
6
+ "cross_attn_every_n_layers": 8,
7
+ "model_type": "flamingo",
8
+ "text_config": {
9
+ "_name_or_path": "meta-llama/Llama-2-13b-chat-hf",
10
+ "architectures": [
11
+ "LlamaForCausalLM"
12
+ ],
13
+ "bos_token_id": 1,
14
+ "eos_token_id": 2,
15
+ "hidden_act": "silu",
16
+ "hidden_size": 5120,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 13824,
19
+ "max_position_embeddings": 4096,
20
+ "model_type": "llama",
21
+ "num_attention_heads": 40,
22
+ "num_hidden_layers": 40,
23
+ "num_key_value_heads": 40,
24
+ "pad_token_id": 0,
25
+ "pretraining_tp": 1,
26
+ "rms_norm_eps": 1e-05,
27
+ "rope_scaling": null,
28
+ "tie_word_embeddings": false,
29
+ "torch_dtype": "float16",
30
+ "transformers_version": "4.30.1",
31
+ "use_cache": true,
32
+ "vocab_size": 32000
33
+ },
34
+ "torch_dtype": "float32",
35
+ "transformers_version": null,
36
+ "use_media_placement_augmentation": true,
37
+ "vision_config": {
38
+ "_name_or_path": "openai/clip-vit-large-patch14",
39
+ "add_cross_attention": false,
40
+ "architectures": null,
41
+ "attention_dropout": 0.0,
42
+ "bad_words_ids": null,
43
+ "begin_suppress_tokens": null,
44
+ "bos_token_id": null,
45
+ "chunk_size_feed_forward": 0,
46
+ "cross_attention_hidden_size": null,
47
+ "decoder_start_token_id": null,
48
+ "diversity_penalty": 0.0,
49
+ "do_sample": false,
50
+ "early_stopping": false,
51
+ "encoder_no_repeat_ngram_size": 0,
52
+ "eos_token_id": null,
53
+ "exponential_decay_length_penalty": null,
54
+ "finetuning_task": null,
55
+ "forced_bos_token_id": null,
56
+ "forced_eos_token_id": null,
57
+ "hidden_act": "quick_gelu",
58
+ "hidden_size": 1024,
59
+ "id2label": {
60
+ "0": "LABEL_0",
61
+ "1": "LABEL_1"
62
+ },
63
+ "image_size": 224,
64
+ "initializer_factor": 1.0,
65
+ "initializer_range": 0.02,
66
+ "intermediate_size": 4096,
67
+ "is_decoder": false,
68
+ "is_encoder_decoder": false,
69
+ "label2id": {
70
+ "LABEL_0": 0,
71
+ "LABEL_1": 1
72
+ },
73
+ "layer_norm_eps": 1e-05,
74
+ "length_penalty": 1.0,
75
+ "max_length": 20,
76
+ "min_length": 0,
77
+ "model_type": "clip_vision_model",
78
+ "no_repeat_ngram_size": 0,
79
+ "num_attention_heads": 16,
80
+ "num_beam_groups": 1,
81
+ "num_beams": 1,
82
+ "num_channels": 3,
83
+ "num_hidden_layers": 24,
84
+ "num_return_sequences": 1,
85
+ "output_attentions": false,
86
+ "output_hidden_states": false,
87
+ "output_scores": false,
88
+ "pad_token_id": null,
89
+ "patch_size": 14,
90
+ "prefix": null,
91
+ "problem_type": null,
92
+ "projection_dim": 512,
93
+ "pruned_heads": {},
94
+ "remove_invalid_values": false,
95
+ "repetition_penalty": 1.0,
96
+ "return_dict": true,
97
+ "return_dict_in_generate": false,
98
+ "sep_token_id": null,
99
+ "suppress_tokens": null,
100
+ "task_specific_params": null,
101
+ "temperature": 1.0,
102
+ "tf_legacy_loss": false,
103
+ "tie_encoder_decoder": false,
104
+ "tie_word_embeddings": true,
105
+ "tokenizer_class": null,
106
+ "top_k": 50,
107
+ "top_p": 1.0,
108
+ "torch_dtype": null,
109
+ "torchscript": false,
110
+ "transformers_version": "4.30.1",
111
+ "typical_p": 1.0,
112
+ "use_bfloat16": false
113
+ }
114
+ }
mllm/flamingo/flamingo-llama2-chat-7B.json ADDED
@@ -0,0 +1,115 @@
1
+ {
2
+ "_commit_hash": null,
3
+ "architectures": [
4
+ "FlamingoForConditionalGeneration"
5
+ ],
6
+ "cross_attn_every_n_layers": 4,
7
+ "model_type": "flamingo",
8
+ "text_config": {
9
+ "_name_or_path": "meta-llama/Llama-2-7b-chat-hf",
10
+ "architectures": [
11
+ "LlamaForCausalLM"
12
+ ],
13
+ "bos_token_id": 1,
14
+ "eos_token_id": 2,
15
+ "hidden_act": "silu",
16
+ "hidden_size": 4096,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 11008,
19
+ "max_length": 4096,
20
+ "max_position_embeddings": 2048,
21
+ "model_type": "llama",
22
+ "num_attention_heads": 32,
23
+ "num_hidden_layers": 32,
24
+ "num_key_value_heads": 32,
25
+ "pad_token_id": 0,
26
+ "pretraining_tp": 1,
27
+ "rms_norm_eps": 1e-05,
28
+ "rope_scaling": null,
29
+ "tie_word_embeddings": false,
30
+ "torch_dtype": "float16",
31
+ "transformers_version": "4.32.0.dev0",
32
+ "use_cache": true,
33
+ "vocab_size": 32000
34
+ },
35
+ "torch_dtype": "float32",
36
+ "transformers_version": null,
37
+ "use_media_placement_augmentation": true,
38
+ "vision_config": {
39
+ "_name_or_path": "openai/clip-vit-large-patch14",
40
+ "add_cross_attention": false,
41
+ "architectures": null,
42
+ "attention_dropout": 0.0,
43
+ "bad_words_ids": null,
44
+ "begin_suppress_tokens": null,
45
+ "bos_token_id": null,
46
+ "chunk_size_feed_forward": 0,
47
+ "cross_attention_hidden_size": null,
48
+ "decoder_start_token_id": null,
49
+ "diversity_penalty": 0.0,
50
+ "do_sample": false,
51
+ "early_stopping": false,
52
+ "encoder_no_repeat_ngram_size": 0,
53
+ "eos_token_id": null,
54
+ "exponential_decay_length_penalty": null,
55
+ "finetuning_task": null,
56
+ "forced_bos_token_id": null,
57
+ "forced_eos_token_id": null,
58
+ "hidden_act": "quick_gelu",
59
+ "hidden_size": 1024,
60
+ "id2label": {
61
+ "0": "LABEL_0",
62
+ "1": "LABEL_1"
63
+ },
64
+ "image_size": 224,
65
+ "initializer_factor": 1.0,
66
+ "initializer_range": 0.02,
67
+ "intermediate_size": 4096,
68
+ "is_decoder": false,
69
+ "is_encoder_decoder": false,
70
+ "label2id": {
71
+ "LABEL_0": 0,
72
+ "LABEL_1": 1
73
+ },
74
+ "layer_norm_eps": 1e-05,
75
+ "length_penalty": 1.0,
76
+ "max_length": 20,
77
+ "min_length": 0,
78
+ "model_type": "clip_vision_model",
79
+ "no_repeat_ngram_size": 0,
80
+ "num_attention_heads": 16,
81
+ "num_beam_groups": 1,
82
+ "num_beams": 1,
83
+ "num_channels": 3,
84
+ "num_hidden_layers": 24,
85
+ "num_return_sequences": 1,
86
+ "output_attentions": false,
87
+ "output_hidden_states": false,
88
+ "output_scores": false,
89
+ "pad_token_id": null,
90
+ "patch_size": 14,
91
+ "prefix": null,
92
+ "problem_type": null,
93
+ "projection_dim": 512,
94
+ "pruned_heads": {},
95
+ "remove_invalid_values": false,
96
+ "repetition_penalty": 1.0,
97
+ "return_dict": true,
98
+ "return_dict_in_generate": false,
99
+ "sep_token_id": null,
100
+ "suppress_tokens": null,
101
+ "task_specific_params": null,
102
+ "temperature": 1.0,
103
+ "tf_legacy_loss": false,
104
+ "tie_encoder_decoder": false,
105
+ "tie_word_embeddings": true,
106
+ "tokenizer_class": null,
107
+ "top_k": 50,
108
+ "top_p": 1.0,
109
+ "torch_dtype": null,
110
+ "torchscript": false,
111
+ "transformers_version": "4.30.1",
112
+ "typical_p": 1.0,
113
+ "use_bfloat16": false
114
+ }
115
+ }
mllm/flamingo/flamingo-mpt-1B-redpajama.json ADDED
@@ -0,0 +1,131 @@
1
+ {
2
+ "_commit_hash": null,
3
+ "architectures": [
4
+ "FlamingoForConditionalGeneration"
5
+ ],
6
+ "cross_attn_every_n_layers": 1,
7
+ "model_type": "flamingo",
8
+ "text_config": {
9
+ "_name_or_path": "",
10
+ "alibi": true,
11
+ "alibi_bias_max": 8,
12
+ "architectures": [
13
+ "MosaicGPT"
14
+ ],
15
+ "attn_clip_qkv": null,
16
+ "attn_impl": "torch",
17
+ "attn_pdrop": 0,
18
+ "attn_qk_ln": true,
19
+ "attn_uses_sequence_id": false,
20
+ "d_model": 2048,
21
+ "hidden_size": 2048,
22
+ "emb_init_std": null,
23
+ "emb_init_uniform_lim": null,
24
+ "emb_pdrop": 0,
25
+ "embedding_fraction": 1.0,
26
+ "fan_mode": "fan_in",
27
+ "init_device": "cpu",
28
+ "init_div_is_residual": true,
29
+ "init_gain": 0,
30
+ "init_nonlinearity": "relu",
31
+ "init_std": 0.02,
32
+ "logit_scale": null,
33
+ "low_precision_layernorm": true,
34
+ "max_seq_len": 2048,
35
+ "mlp_ratio": 4,
36
+ "model_type": "mosaic_gpt",
37
+ "n_heads": 16,
38
+ "n_layers": 24,
39
+ "no_bias": true,
40
+ "param_init_fn": "kaiming_normal_",
41
+ "prefix_lm": false,
42
+ "resid_pdrop": 0,
43
+ "softmax_scale": null,
44
+ "tokenizer_name": "EleutherAI/gpt-neox-20b",
45
+ "torch_dtype": "float32",
46
+ "transformers_version": "4.27.4",
47
+ "use_cache": false,
48
+ "verbose": 0,
49
+ "vocab_size": 50432
50
+ },
51
+ "torch_dtype": "float32",
52
+ "transformers_version": null,
53
+ "use_media_placement_augmentation": true,
54
+ "vision_config": {
55
+ "_name_or_path": "openai/clip-vit-large-patch14",
56
+ "add_cross_attention": false,
57
+ "architectures": null,
58
+ "attention_dropout": 0.0,
59
+ "bad_words_ids": null,
60
+ "begin_suppress_tokens": null,
61
+ "bos_token_id": null,
62
+ "chunk_size_feed_forward": 0,
63
+ "cross_attention_hidden_size": null,
64
+ "decoder_start_token_id": null,
65
+ "diversity_penalty": 0.0,
66
+ "do_sample": false,
67
+ "early_stopping": false,
68
+ "encoder_no_repeat_ngram_size": 0,
69
+ "eos_token_id": null,
70
+ "exponential_decay_length_penalty": null,
71
+ "finetuning_task": null,
72
+ "forced_bos_token_id": null,
73
+ "forced_eos_token_id": null,
74
+ "hidden_act": "quick_gelu",
75
+ "hidden_size": 1024,
76
+ "id2label": {
77
+ "0": "LABEL_0",
78
+ "1": "LABEL_1"
79
+ },
80
+ "image_size": 224,
81
+ "initializer_factor": 1.0,
82
+ "initializer_range": 0.02,
83
+ "intermediate_size": 4096,
84
+ "is_decoder": false,
85
+ "is_encoder_decoder": false,
86
+ "label2id": {
87
+ "LABEL_0": 0,
88
+ "LABEL_1": 1
89
+ },
90
+ "layer_norm_eps": 1e-05,
91
+ "length_penalty": 1.0,
92
+ "max_length": 20,
93
+ "min_length": 0,
94
+ "model_type": "clip_vision_model",
95
+ "no_repeat_ngram_size": 0,
96
+ "num_attention_heads": 16,
97
+ "num_beam_groups": 1,
98
+ "num_beams": 1,
99
+ "num_channels": 3,
100
+ "num_hidden_layers": 24,
101
+ "num_return_sequences": 1,
102
+ "output_attentions": false,
103
+ "output_hidden_states": false,
104
+ "output_scores": false,
105
+ "pad_token_id": null,
106
+ "patch_size": 14,
107
+ "prefix": null,
108
+ "problem_type": null,
109
+ "projection_dim": 512,
110
+ "pruned_heads": {},
111
+ "remove_invalid_values": false,
112
+ "repetition_penalty": 1.0,
113
+ "return_dict": true,
114
+ "return_dict_in_generate": false,
115
+ "sep_token_id": null,
116
+ "suppress_tokens": null,
117
+ "task_specific_params": null,
118
+ "temperature": 1.0,
119
+ "tf_legacy_loss": false,
120
+ "tie_encoder_decoder": false,
121
+ "tie_word_embeddings": true,
122
+ "tokenizer_class": null,
123
+ "top_k": 50,
124
+ "top_p": 1.0,
125
+ "torch_dtype": null,
126
+ "torchscript": false,
127
+ "transformers_version": "4.30.1",
128
+ "typical_p": 1.0,
129
+ "use_bfloat16": false
130
+ }
131
+ }
mllm/flamingo/flamingo-mpt-30B-bf16.json ADDED
@@ -0,0 +1,195 @@
1
+ {
2
+ "_commit_hash": null,
3
+ "architectures": [
4
+ "FlamingoForConditionalGeneration"
5
+ ],
6
+ "cross_attn_every_n_layers": 7,
7
+ "model_type": "flamingo",
8
+ "text_config": {
9
+ "_name_or_path": "",
10
+ "add_cross_attention": false,
11
+ "architectures": [
12
+ "MPTForCausalLM"
13
+ ],
14
+ "attn_config": {
15
+ "alibi": true,
16
+ "alibi_bias_max": 8,
17
+ "attn_impl": "torch",
18
+ "attn_pdrop": 0,
19
+ "attn_type": "multihead_attention",
20
+ "attn_uses_sequence_id": false,
21
+ "clip_qkv": null,
22
+ "prefix_lm": false,
23
+ "qk_ln": false,
24
+ "softmax_scale": null
25
+ },
26
+ "bad_words_ids": null,
27
+ "begin_suppress_tokens": null,
28
+ "bos_token_id": null,
29
+ "chunk_size_feed_forward": 0,
30
+ "cross_attention_hidden_size": null,
31
+ "d_model": 7168,
32
+ "decoder_start_token_id": null,
33
+ "diversity_penalty": 0.0,
34
+ "do_sample": false,
35
+ "early_stopping": false,
36
+ "emb_pdrop": 0,
37
+ "embedding_fraction": 1.0,
38
+ "encoder_no_repeat_ngram_size": 0,
39
+ "eos_token_id": null,
40
+ "expansion_ratio": 4,
41
+ "exponential_decay_length_penalty": null,
42
+ "finetuning_task": null,
43
+ "forced_bos_token_id": null,
44
+ "forced_eos_token_id": null,
45
+ "hidden_size": 7168,
46
+ "id2label": {
47
+ "0": "LABEL_0",
48
+ "1": "LABEL_1"
49
+ },
50
+ "init_config": {
51
+ "emb_init_std": null,
52
+ "emb_init_uniform_lim": null,
53
+ "fan_mode": "fan_in",
54
+ "init_div_is_residual": true,
55
+ "init_gain": 0.0,
56
+ "init_nonlinearity": "relu",
57
+ "init_std": null,
58
+ "name": "kaiming_normal_",
59
+ "verbose": 0
60
+ },
61
+ "init_device": "cpu",
62
+ "is_decoder": false,
63
+ "is_encoder_decoder": false,
64
+ "label2id": {
65
+ "LABEL_0": 0,
66
+ "LABEL_1": 1
67
+ },
68
+ "learned_pos_emb": true,
69
+ "length_penalty": 1.0,
70
+ "logit_scale": null,
71
+ "max_length": 20,
72
+ "max_seq_len": 8192,
73
+ "min_length": 0,
74
+ "model_type": "mpt",
75
+ "n_heads": 64,
76
+ "n_layers": 48,
77
+ "no_bias": true,
78
+ "no_repeat_ngram_size": 0,
79
+ "norm_type": "low_precision_layernorm",
80
+ "num_beam_groups": 1,
81
+ "num_beams": 1,
82
+ "num_return_sequences": 1,
83
+ "output_attentions": false,
84
+ "output_hidden_states": false,
85
+ "output_scores": false,
86
+ "pad_token_id": null,
87
+ "prefix": null,
88
+ "problem_type": null,
89
+ "pruned_heads": {},
90
+ "remove_invalid_values": false,
91
+ "repetition_penalty": 1.0,
92
+ "resid_pdrop": 0,
93
+ "return_dict": true,
94
+ "return_dict_in_generate": false,
95
+ "sep_token_id": null,
96
+ "suppress_tokens": null,
97
+ "task_specific_params": null,
98
+ "temperature": 1.0,
99
+ "tf_legacy_loss": false,
100
+ "tie_encoder_decoder": false,
101
+ "tie_word_embeddings": true,
102
+ "tokenizer_class": null,
103
+ "tokenizer_name": "EleutherAI/gpt-neox-20b",
104
+ "top_k": 50,
105
+ "top_p": 1.0,
106
+ "torch_dtype": "bfloat16",
107
+ "torchscript": false,
108
+ "transformers_version": "4.30.1",
109
+ "typical_p": 1.0,
110
+ "use_bfloat16": false,
111
+ "use_cache": false,
112
+ "verbose": 0,
113
+ "vocab_size": 50432
114
+ },
115
+ "torch_dtype": "bfloat16",
116
+ "transformers_version": null,
117
+ "use_media_placement_augmentation": true,
118
+ "vision_config": {
119
+ "_name_or_path": "openai/clip-vit-large-patch14",
120
+ "add_cross_attention": false,
121
+ "architectures": null,
122
+ "attention_dropout": 0.0,
123
+ "bad_words_ids": null,
124
+ "begin_suppress_tokens": null,
125
+ "bos_token_id": null,
126
+ "chunk_size_feed_forward": 0,
127
+ "cross_attention_hidden_size": null,
128
+ "decoder_start_token_id": null,
129
+ "diversity_penalty": 0.0,
130
+ "do_sample": false,
131
+ "early_stopping": false,
132
+ "encoder_no_repeat_ngram_size": 0,
133
+ "eos_token_id": null,
134
+ "exponential_decay_length_penalty": null,
135
+ "finetuning_task": null,
136
+ "forced_bos_token_id": null,
137
+ "forced_eos_token_id": null,
138
+ "hidden_act": "quick_gelu",
139
+ "hidden_size": 1024,
140
+ "id2label": {
141
+ "0": "LABEL_0",
142
+ "1": "LABEL_1"
143
+ },
144
+ "image_size": 224,
145
+ "initializer_factor": 1.0,
146
+ "initializer_range": 0.02,
147
+ "intermediate_size": 4096,
148
+ "is_decoder": false,
149
+ "is_encoder_decoder": false,
150
+ "label2id": {
151
+ "LABEL_0": 0,
152
+ "LABEL_1": 1
153
+ },
154
+ "layer_norm_eps": 1e-05,
155
+ "length_penalty": 1.0,
156
+ "max_length": 20,
157
+ "min_length": 0,
158
+ "model_type": "clip_vision_model",
159
+ "no_repeat_ngram_size": 0,
160
+ "num_attention_heads": 16,
161
+ "num_beam_groups": 1,
162
+ "num_beams": 1,
163
+ "num_channels": 3,
164
+ "num_hidden_layers": 24,
165
+ "num_return_sequences": 1,
166
+ "output_attentions": false,
167
+ "output_hidden_states": false,
168
+ "output_scores": false,
169
+ "pad_token_id": null,
170
+ "patch_size": 14,
171
+ "prefix": null,
172
+ "problem_type": null,
173
+ "projection_dim": 512,
174
+ "pruned_heads": {},
175
+ "remove_invalid_values": false,
176
+ "repetition_penalty": 1.0,
177
+ "return_dict": true,
178
+ "return_dict_in_generate": false,
179
+ "sep_token_id": null,
180
+ "suppress_tokens": null,
181
+ "task_specific_params": null,
182
+ "temperature": 1.0,
183
+ "tf_legacy_loss": false,
184
+ "tie_encoder_decoder": false,
185
+ "tie_word_embeddings": true,
186
+ "tokenizer_class": null,
187
+ "top_k": 50,
188
+ "top_p": 1.0,
189
+ "torch_dtype": null,
190
+ "torchscript": false,
191
+ "transformers_version": "4.30.1",
192
+ "typical_p": 1.0,
193
+ "use_bfloat16": false
194
+ }
195
+ }
mllm/flamingo/flamingo-mpt-30B.json ADDED
@@ -0,0 +1,195 @@
1
+ {
2
+ "_commit_hash": null,
3
+ "architectures": [
4
+ "FlamingoForConditionalGeneration"
5
+ ],
6
+ "cross_attn_every_n_layers": 7,
7
+ "model_type": "flamingo",
8
+ "text_config": {
9
+ "_name_or_path": "",
10
+ "add_cross_attention": false,
11
+ "architectures": [
12
+ "MPTForCausalLM"
13
+ ],
14
+ "attn_config": {
15
+ "alibi": true,
16
+ "alibi_bias_max": 8,
17
+ "attn_impl": "torch",
18
+ "attn_pdrop": 0,
19
+ "attn_type": "multihead_attention",
20
+ "attn_uses_sequence_id": false,
21
+ "clip_qkv": null,
22
+ "prefix_lm": false,
23
+ "qk_ln": false,
24
+ "softmax_scale": null
25
+ },
26
+ "bad_words_ids": null,
27
+ "begin_suppress_tokens": null,
28
+ "bos_token_id": null,
29
+ "chunk_size_feed_forward": 0,
30
+ "cross_attention_hidden_size": null,
31
+ "d_model": 7168,
32
+ "decoder_start_token_id": null,
33
+ "diversity_penalty": 0.0,
34
+ "do_sample": false,
35
+ "early_stopping": false,
36
+ "emb_pdrop": 0,
37
+ "embedding_fraction": 1.0,
38
+ "encoder_no_repeat_ngram_size": 0,
39
+ "eos_token_id": null,
40
+ "expansion_ratio": 4,
41
+ "exponential_decay_length_penalty": null,
42
+ "finetuning_task": null,
43
+ "forced_bos_token_id": null,
44
+ "forced_eos_token_id": null,
45
+ "hidden_size": 7168,
46
+ "id2label": {
47
+ "0": "LABEL_0",
48
+ "1": "LABEL_1"
49
+ },
50
+ "init_config": {
51
+ "emb_init_std": null,
52
+ "emb_init_uniform_lim": null,
53
+ "fan_mode": "fan_in",
54
+ "init_div_is_residual": true,
55
+ "init_gain": 0.0,
56
+ "init_nonlinearity": "relu",
57
+ "init_std": null,
58
+ "name": "kaiming_normal_",
59
+ "verbose": 0
60
+ },
61
+ "init_device": "cpu",
62
+ "is_decoder": false,
63
+ "is_encoder_decoder": false,
64
+ "label2id": {
65
+ "LABEL_0": 0,
66
+ "LABEL_1": 1
67
+ },
68
+ "learned_pos_emb": true,
69
+ "length_penalty": 1.0,
70
+ "logit_scale": null,
71
+ "max_length": 20,
72
+ "max_seq_len": 8192,
73
+ "min_length": 0,
74
+ "model_type": "mpt",
75
+ "n_heads": 64,
76
+ "n_layers": 48,
77
+ "no_bias": true,
78
+ "no_repeat_ngram_size": 0,
79
+ "norm_type": "low_precision_layernorm",
80
+ "num_beam_groups": 1,
81
+ "num_beams": 1,
82
+ "num_return_sequences": 1,
83
+ "output_attentions": false,
84
+ "output_hidden_states": false,
85
+ "output_scores": false,
86
+ "pad_token_id": null,
87
+ "prefix": null,
88
+ "problem_type": null,
89
+ "pruned_heads": {},
90
+ "remove_invalid_values": false,
91
+ "repetition_penalty": 1.0,
92
+ "resid_pdrop": 0,
93
+ "return_dict": true,
94
+ "return_dict_in_generate": false,
95
+ "sep_token_id": null,
96
+ "suppress_tokens": null,
97
+ "task_specific_params": null,
98
+ "temperature": 1.0,
99
+ "tf_legacy_loss": false,
100
+ "tie_encoder_decoder": false,
101
+ "tie_word_embeddings": true,
102
+ "tokenizer_class": null,
103
+ "tokenizer_name": "EleutherAI/gpt-neox-20b",
104
+ "top_k": 50,
105
+ "top_p": 1.0,
106
+ "torch_dtype": "bfloat16",
107
+ "torchscript": false,
108
+ "transformers_version": "4.30.1",
109
+ "typical_p": 1.0,
110
+ "use_bfloat16": false,
111
+ "use_cache": false,
112
+ "verbose": 0,
113
+ "vocab_size": 50432
114
+ },
115
+ "torch_dtype": "float32",
116
+ "transformers_version": null,
117
+ "use_media_placement_augmentation": true,
118
+ "vision_config": {
119
+ "_name_or_path": "openai/clip-vit-large-patch14",
120
+ "add_cross_attention": false,
121
+ "architectures": null,
122
+ "attention_dropout": 0.0,
123
+ "bad_words_ids": null,
124
+ "begin_suppress_tokens": null,
125
+ "bos_token_id": null,
126
+ "chunk_size_feed_forward": 0,
127
+ "cross_attention_hidden_size": null,
128
+ "decoder_start_token_id": null,
129
+ "diversity_penalty": 0.0,
130
+ "do_sample": false,
131
+ "early_stopping": false,
132
+ "encoder_no_repeat_ngram_size": 0,
133
+ "eos_token_id": null,
134
+ "exponential_decay_length_penalty": null,
135
+ "finetuning_task": null,
136
+ "forced_bos_token_id": null,
137
+ "forced_eos_token_id": null,
138
+ "hidden_act": "quick_gelu",
139
+ "hidden_size": 1024,
140
+ "id2label": {
141
+ "0": "LABEL_0",
142
+ "1": "LABEL_1"
143
+ },
144
+ "image_size": 224,
145
+ "initializer_factor": 1.0,
146
+ "initializer_range": 0.02,
147
+ "intermediate_size": 4096,
148
+ "is_decoder": false,
149
+ "is_encoder_decoder": false,
150
+ "label2id": {
151
+ "LABEL_0": 0,
152
+ "LABEL_1": 1
153
+ },
154
+ "layer_norm_eps": 1e-05,
155
+ "length_penalty": 1.0,
156
+ "max_length": 20,
157
+ "min_length": 0,
158
+ "model_type": "clip_vision_model",
159
+ "no_repeat_ngram_size": 0,
160
+ "num_attention_heads": 16,
161
+ "num_beam_groups": 1,
162
+ "num_beams": 1,
163
+ "num_channels": 3,
164
+ "num_hidden_layers": 24,
165
+ "num_return_sequences": 1,
166
+ "output_attentions": false,
167
+ "output_hidden_states": false,
168
+ "output_scores": false,
169
+ "pad_token_id": null,
170
+ "patch_size": 14,
171
+ "prefix": null,
172
+ "problem_type": null,
173
+ "projection_dim": 512,
174
+ "pruned_heads": {},
175
+ "remove_invalid_values": false,
176
+ "repetition_penalty": 1.0,
177
+ "return_dict": true,
178
+ "return_dict_in_generate": false,
179
+ "sep_token_id": null,
180
+ "suppress_tokens": null,
181
+ "task_specific_params": null,
182
+ "temperature": 1.0,
183
+ "tf_legacy_loss": false,
184
+ "tie_encoder_decoder": false,
185
+ "tie_word_embeddings": true,
186
+ "tokenizer_class": null,
187
+ "top_k": 50,
188
+ "top_p": 1.0,
189
+ "torch_dtype": null,
190
+ "torchscript": false,
191
+ "transformers_version": "4.30.1",
192
+ "typical_p": 1.0,
193
+ "use_bfloat16": false
194
+ }
195
+ }
mllm/flamingo/flamingo-mpt-7B.json ADDED
@@ -0,0 +1,195 @@
1
+ {
2
+ "_commit_hash": null,
3
+ "architectures": [
4
+ "FlamingoForConditionalGeneration"
5
+ ],
6
+ "cross_attn_every_n_layers": 4,
7
+ "model_type": "flamingo",
8
+ "text_config": {
9
+ "_name_or_path": "",
10
+ "add_cross_attention": false,
11
+ "architectures": [
12
+ "MPTForCausalLM"
13
+ ],
14
+ "attn_config": {
15
+ "alibi": true,
16
+ "alibi_bias_max": 8,
17
+ "attn_impl": "torch",
18
+ "attn_pdrop": 0,
19
+ "attn_type": "multihead_attention",
20
+ "attn_uses_sequence_id": false,
21
+ "clip_qkv": null,
22
+ "prefix_lm": false,
23
+ "qk_ln": false,
24
+ "softmax_scale": null
25
+ },
26
+ "bad_words_ids": null,
27
+ "begin_suppress_tokens": null,
28
+ "bos_token_id": null,
29
+ "chunk_size_feed_forward": 0,
30
+ "cross_attention_hidden_size": null,
31
+ "d_model": 4096,
32
+ "decoder_start_token_id": null,
33
+ "diversity_penalty": 0.0,
34
+ "do_sample": false,
35
+ "early_stopping": false,
36
+ "emb_pdrop": 0,
37
+ "embedding_fraction": 1.0,
38
+ "encoder_no_repeat_ngram_size": 0,
39
+ "eos_token_id": null,
40
+ "expansion_ratio": 4,
41
+ "exponential_decay_length_penalty": null,
42
+ "finetuning_task": null,
43
+ "forced_bos_token_id": null,
44
+ "forced_eos_token_id": null,
45
+ "hidden_size": 4096,
46
+ "id2label": {
47
+ "0": "LABEL_0",
48
+ "1": "LABEL_1"
49
+ },
50
+ "init_config": {
51
+ "emb_init_std": null,
52
+ "emb_init_uniform_lim": null,
53
+ "fan_mode": "fan_in",
54
+ "init_div_is_residual": true,
55
+ "init_gain": 0,
56
+ "init_nonlinearity": "relu",
57
+ "init_std": 0.02,
58
+ "name": "kaiming_normal_",
59
+ "verbose": 0
60
+ },
61
+ "init_device": "cpu",
62
+ "is_decoder": false,
63
+ "is_encoder_decoder": false,
64
+ "label2id": {
65
+ "LABEL_0": 0,
66
+ "LABEL_1": 1
67
+ },
68
+ "learned_pos_emb": true,
69
+ "length_penalty": 1.0,
70
+ "logit_scale": null,
71
+ "max_length": 20,
72
+ "max_seq_len": 2048,
73
+ "min_length": 0,
74
+ "model_type": "mpt",
75
+ "n_heads": 32,
76
+ "n_layers": 32,
77
+ "no_bias": true,
78
+ "no_repeat_ngram_size": 0,
79
+ "norm_type": "low_precision_layernorm",
80
+ "num_beam_groups": 1,
81
+ "num_beams": 1,
82
+ "num_return_sequences": 1,
83
+ "output_attentions": false,
84
+ "output_hidden_states": false,
85
+ "output_scores": false,
86
+ "pad_token_id": null,
87
+ "prefix": null,
88
+ "problem_type": null,
89
+ "pruned_heads": {},
90
+ "remove_invalid_values": false,
91
+ "repetition_penalty": 1.0,
92
+ "resid_pdrop": 0,
93
+ "return_dict": true,
94
+ "return_dict_in_generate": false,
95
+ "sep_token_id": null,
96
+ "suppress_tokens": null,
97
+ "task_specific_params": null,
98
+ "temperature": 1.0,
99
+ "tf_legacy_loss": false,
100
+ "tie_encoder_decoder": false,
101
+ "tie_word_embeddings": true,
102
+ "tokenizer_class": null,
103
+ "tokenizer_name": "EleutherAI/gpt-neox-20b",
104
+ "top_k": 50,
105
+ "top_p": 1.0,
106
+ "torch_dtype": "bfloat16",
107
+ "torchscript": false,
108
+ "transformers_version": "4.30.1",
109
+ "typical_p": 1.0,
110
+ "use_bfloat16": false,
111
+ "use_cache": false,
112
+ "verbose": 0,
113
+ "vocab_size": 50432
114
+ },
115
+ "torch_dtype": "float32",
116
+ "transformers_version": null,
117
+ "use_media_placement_augmentation": true,
118
+ "vision_config": {
119
+ "_name_or_path": "openai/clip-vit-large-patch14",
120
+ "add_cross_attention": false,
121
+ "architectures": null,
122
+ "attention_dropout": 0.0,
123
+ "bad_words_ids": null,
124
+ "begin_suppress_tokens": null,
125
+ "bos_token_id": null,
126
+ "chunk_size_feed_forward": 0,
127
+ "cross_attention_hidden_size": null,
128
+ "decoder_start_token_id": null,
129
+ "diversity_penalty": 0.0,
130
+ "do_sample": false,
131
+ "early_stopping": false,
132
+ "encoder_no_repeat_ngram_size": 0,
133
+ "eos_token_id": null,
134
+ "exponential_decay_length_penalty": null,
135
+ "finetuning_task": null,
136
+ "forced_bos_token_id": null,
137
+ "forced_eos_token_id": null,
138
+ "hidden_act": "quick_gelu",
139
+ "hidden_size": 1024,
140
+ "id2label": {
141
+ "0": "LABEL_0",
142
+ "1": "LABEL_1"
143
+ },
144
+ "image_size": 224,
145
+ "initializer_factor": 1.0,
146
+ "initializer_range": 0.02,
147
+ "intermediate_size": 4096,
148
+ "is_decoder": false,
149
+ "is_encoder_decoder": false,
150
+ "label2id": {
151
+ "LABEL_0": 0,
152
+ "LABEL_1": 1
153
+ },
154
+ "layer_norm_eps": 1e-05,
155
+ "length_penalty": 1.0,
156
+ "max_length": 20,
157
+ "min_length": 0,
158
+ "model_type": "clip_vision_model",
159
+ "no_repeat_ngram_size": 0,
160
+ "num_attention_heads": 16,
161
+ "num_beam_groups": 1,
162
+ "num_beams": 1,
163
+ "num_channels": 3,
164
+ "num_hidden_layers": 24,
165
+ "num_return_sequences": 1,
166
+ "output_attentions": false,
167
+ "output_hidden_states": false,
168
+ "output_scores": false,
169
+ "pad_token_id": null,
170
+ "patch_size": 14,
171
+ "prefix": null,
172
+ "problem_type": null,
173
+ "projection_dim": 512,
174
+ "pruned_heads": {},
175
+ "remove_invalid_values": false,
176
+ "repetition_penalty": 1.0,
177
+ "return_dict": true,
178
+ "return_dict_in_generate": false,
179
+ "sep_token_id": null,
180
+ "suppress_tokens": null,
181
+ "task_specific_params": null,
182
+ "temperature": 1.0,
183
+ "tf_legacy_loss": false,
184
+ "tie_encoder_decoder": false,
185
+ "tie_word_embeddings": true,
186
+ "tokenizer_class": null,
187
+ "top_k": 50,
188
+ "top_p": 1.0,
189
+ "torch_dtype": null,
190
+ "torchscript": false,
191
+ "transformers_version": "4.30.1",
192
+ "typical_p": 1.0,
193
+ "use_bfloat16": false
194
+ }
195
+ }
mllm/flamingo/flamingo-vicuna-33B-v1.3.json ADDED
@@ -0,0 +1,111 @@
1
+ {
2
+ "_commit_hash": null,
3
+ "architectures": [
4
+ "FlamingoForConditionalGeneration"
5
+ ],
6
+ "cross_attn_every_n_layers": 4,
7
+ "model_type": "flamingo",
8
+ "text_config": {
9
+ "_name_or_path": "/home/luodian/projects/checkpoints/vicuna-33b-v1.3",
10
+ "architectures": [
11
+ "LlamaForCausalLM"
12
+ ],
13
+ "bos_token_id": 1,
14
+ "eos_token_id": 2,
15
+ "hidden_act": "silu",
16
+ "hidden_size": 6656,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 17920,
19
+ "max_position_embeddings": 2048,
20
+ "model_type": "llama",
21
+ "num_attention_heads": 52,
22
+ "num_hidden_layers": 60,
23
+ "pad_token_id": 0,
24
+ "rms_norm_eps": 1e-06,
25
+ "tie_word_embeddings": false,
26
+ "torch_dtype": "float16",
27
+ "transformers_version": "4.28.1",
28
+ "use_cache": false,
29
+ "vocab_size": 32000
30
+ },
31
+ "torch_dtype": "float32",
32
+ "transformers_version": null,
33
+ "use_media_placement_augmentation": true,
34
+ "vision_config": {
35
+ "_name_or_path": "openai/clip-vit-large-patch14",
36
+ "add_cross_attention": false,
37
+ "architectures": null,
38
+ "attention_dropout": 0.0,
39
+ "bad_words_ids": null,
40
+ "begin_suppress_tokens": null,
41
+ "bos_token_id": null,
42
+ "chunk_size_feed_forward": 0,
43
+ "cross_attention_hidden_size": null,
44
+ "decoder_start_token_id": null,
45
+ "diversity_penalty": 0.0,
46
+ "do_sample": false,
47
+ "early_stopping": false,
48
+ "encoder_no_repeat_ngram_size": 0,
49
+ "eos_token_id": null,
50
+ "exponential_decay_length_penalty": null,
51
+ "finetuning_task": null,
52
+ "forced_bos_token_id": null,
53
+ "forced_eos_token_id": null,
54
+ "hidden_act": "quick_gelu",
55
+ "hidden_size": 1024,
56
+ "id2label": {
57
+ "0": "LABEL_0",
58
+ "1": "LABEL_1"
59
+ },
60
+ "image_size": 224,
61
+ "initializer_factor": 1.0,
62
+ "initializer_range": 0.02,
63
+ "intermediate_size": 4096,
64
+ "is_decoder": false,
65
+ "is_encoder_decoder": false,
66
+ "label2id": {
67
+ "LABEL_0": 0,
68
+ "LABEL_1": 1
69
+ },
70
+ "layer_norm_eps": 1e-05,
71
+ "length_penalty": 1.0,
72
+ "max_length": 20,
73
+ "min_length": 0,
74
+ "model_type": "clip_vision_model",
75
+ "no_repeat_ngram_size": 0,
76
+ "num_attention_heads": 16,
77
+ "num_beam_groups": 1,
78
+ "num_beams": 1,
79
+ "num_channels": 3,
80
+ "num_hidden_layers": 24,
81
+ "num_return_sequences": 1,
82
+ "output_attentions": false,
83
+ "output_hidden_states": false,
84
+ "output_scores": false,
85
+ "pad_token_id": null,
86
+ "patch_size": 14,
87
+ "prefix": null,
88
+ "problem_type": null,
89
+ "projection_dim": 512,
90
+ "pruned_heads": {},
91
+ "remove_invalid_values": false,
92
+ "repetition_penalty": 1.0,
93
+ "return_dict": true,
94
+ "return_dict_in_generate": false,
95
+ "sep_token_id": null,
96
+ "suppress_tokens": null,
97
+ "task_specific_params": null,
98
+ "temperature": 1.0,
99
+ "tf_legacy_loss": false,
100
+ "tie_encoder_decoder": false,
101
+ "tie_word_embeddings": true,
102
+ "tokenizer_class": null,
103
+ "top_k": 50,
104
+ "top_p": 1.0,
105
+ "torch_dtype": null,
106
+ "torchscript": false,
107
+ "transformers_version": "4.30.1",
108
+ "typical_p": 1.0,
109
+ "use_bfloat16": false
110
+ }
111
+ }
mllm/flamingo/flamingo-vicuna-7B-v1.3.json ADDED
@@ -0,0 +1,111 @@
1
+ {
2
+ "_commit_hash": null,
3
+ "architectures": [
4
+ "FlamingoForConditionalGeneration"
5
+ ],
6
+ "cross_attn_every_n_layers": 4,
7
+ "model_type": "flamingo",
8
+ "text_config": {
9
+ "_name_or_path": "/mnt/petrelfs/share_data/zhangyuanhan/vicuna-7b-v1.3",
10
+ "architectures": [
11
+ "LlamaForCausalLM"
12
+ ],
13
+ "bos_token_id": 1,
14
+ "eos_token_id": 2,
15
+ "hidden_act": "silu",
16
+ "hidden_size": 4096,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 11008,
19
+ "max_position_embeddings": 2048,
20
+ "model_type": "llama",
21
+ "num_attention_heads": 32,
22
+ "num_hidden_layers": 32,
23
+ "pad_token_id": 0,
24
+ "rms_norm_eps": 1e-06,
25
+ "tie_word_embeddings": false,
26
+ "torch_dtype": "float16",
27
+ "transformers_version": "4.28.1",
28
+ "use_cache": false,
29
+ "vocab_size": 32000
30
+ },
31
+ "torch_dtype": "float32",
32
+ "transformers_version": null,
33
+ "use_media_placement_augmentation": true,
34
+ "vision_config": {
35
+ "_name_or_path": "openai/clip-vit-large-patch14",
36
+ "add_cross_attention": false,
37
+ "architectures": null,
38
+ "attention_dropout": 0.0,
39
+ "bad_words_ids": null,
40
+ "begin_suppress_tokens": null,
41
+ "bos_token_id": null,
42
+ "chunk_size_feed_forward": 0,
43
+ "cross_attention_hidden_size": null,
44
+ "decoder_start_token_id": null,
45
+ "diversity_penalty": 0.0,
46
+ "do_sample": false,
47
+ "early_stopping": false,
48
+ "encoder_no_repeat_ngram_size": 0,
49
+ "eos_token_id": null,
50
+ "exponential_decay_length_penalty": null,
51
+ "finetuning_task": null,
52
+ "forced_bos_token_id": null,
53
+ "forced_eos_token_id": null,
54
+ "hidden_act": "quick_gelu",
55
+ "hidden_size": 1024,
56
+ "id2label": {
57
+ "0": "LABEL_0",
58
+ "1": "LABEL_1"
59
+ },
60
+ "image_size": 224,
61
+ "initializer_factor": 1.0,
62
+ "initializer_range": 0.02,
63
+ "intermediate_size": 4096,
64
+ "is_decoder": false,
65
+ "is_encoder_decoder": false,
66
+ "label2id": {
67
+ "LABEL_0": 0,
68
+ "LABEL_1": 1
69
+ },
70
+ "layer_norm_eps": 1e-05,
71
+ "length_penalty": 1.0,
72
+ "max_length": 20,
73
+ "min_length": 0,
74
+ "model_type": "clip_vision_model",
75
+ "no_repeat_ngram_size": 0,
76
+ "num_attention_heads": 16,
77
+ "num_beam_groups": 1,
78
+ "num_beams": 1,
79
+ "num_channels": 3,
80
+ "num_hidden_layers": 24,
81
+ "num_return_sequences": 1,
82
+ "output_attentions": false,
83
+ "output_hidden_states": false,
84
+ "output_scores": false,
85
+ "pad_token_id": null,
86
+ "patch_size": 14,
87
+ "prefix": null,
88
+ "problem_type": null,
89
+ "projection_dim": 512,
90
+ "pruned_heads": {},
91
+ "remove_invalid_values": false,
92
+ "repetition_penalty": 1.0,
93
+ "return_dict": true,
94
+ "return_dict_in_generate": false,
95
+ "sep_token_id": null,
96
+ "suppress_tokens": null,
97
+ "task_specific_params": null,
98
+ "temperature": 1.0,
99
+ "tf_legacy_loss": false,
100
+ "tie_encoder_decoder": false,
101
+ "tie_word_embeddings": true,
102
+ "tokenizer_class": null,
103
+ "top_k": 50,
104
+ "top_p": 1.0,
105
+ "torch_dtype": null,
106
+ "torchscript": false,
107
+ "transformers_version": "4.30.1",
108
+ "typical_p": 1.0,
109
+ "use_bfloat16": false
110
+ }
111
+ }
mllm/flamingo/injecting_falcon_into_flamingo.py ADDED
@@ -0,0 +1,49 @@
1
+ import os
2
+ import torch
3
+ from .configuration_flamingo import FlamingoConfig
4
+ from .modeling_flamingo import FlamingoForConditionalGeneration
5
+
6
+ root_dir = os.environ["AZP"]
7
+ print(root_dir)
8
+
9
+
10
+ config = FlamingoConfig.from_json_file("./flamingo-falcon-7B.json")
11
+ model = FlamingoForConditionalGeneration(config=config)
12
+
13
+
14
+ state_dict_files = [
15
+ f"{root_dir}/otter/checkpoints/falcon-7b/pytorch_model-00001-of-00002.bin",
16
+ f"{root_dir}/otter/checkpoints/falcon-7b/pytorch_model-00002-of-00002.bin",
17
+ ]
18
+
19
+ state_dict = {}
20
+ for file in state_dict_files:
21
+ state_dict_part = torch.load(file, map_location="cpu")
22
+ state_dict.update(state_dict_part)
23
+
24
+
25
+ state_dict_3 = torch.load(f"{root_dir}/otter/checkpoints/flamingo_9b_hf/pytorch_model-00004-of-00004.bin", map_location="cpu")
26
+ for cur_key in list(state_dict_3.keys()):
27
+ if "vision_encoder" not in cur_key:
28
+ del state_dict_3[cur_key]
29
+
30
+ _ = model.load_state_dict(
31
+ state_dict_3,
32
+ False,
33
+ )
34
+ print(_[1])
35
+
36
+ save_state_dict_1 = {}
37
+ for key in state_dict:
38
+ if ".h." in key:
39
+ _, _, layer_num, *remain_names = key.split(".")
40
+ target_key = f"transformer.h.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
41
+ else:
42
+ target_key = key
43
+ save_state_dict_1[f"{target_key}"] = state_dict[key]
44
+ _ = model.lang_encoder.load_state_dict(
45
+ save_state_dict_1,
46
+ False,
47
+ )
48
+ print(_[1])
49
+ model.save_pretrained(f"{root_dir}/otter/checkpoints/flamingo-falcon-7b/")
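The `.h.` remapping in the loop above inserts a `decoder_layer` segment so the raw Falcon keys line up with the wrapped decoder layers expected by `modeling_flamingo.py`. For illustration, using a hypothetical key name rather than one read from the checkpoint:

# Hypothetical key, shown only to illustrate the rename performed in the loop above.
key = "transformer.h.0.self_attention.query_key_value.weight"
_, _, layer_num, *remain_names = key.split(".")
print(f"transformer.h.{layer_num}.decoder_layer.{'.'.join(remain_names)}")
# -> transformer.h.0.decoder_layer.self_attention.query_key_value.weight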
mllm/flamingo/injecting_llama2_into_flamingo.py ADDED
@@ -0,0 +1,95 @@
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+ from tqdm import tqdm
6
+
7
+ import sys
8
+
9
+ from .configuration_flamingo import FlamingoConfig
10
+ from .modeling_flamingo import FlamingoForConditionalGeneration
11
+
12
+ # from .configuration_flamingo import FlamingoConfig
13
+ # from .modeling_flamingo import FlamingoForConditionalGeneration
14
+
15
+ parser = argparse.ArgumentParser(description="Convert Llama-2 chat model")
16
+ parser.add_argument("--model_choice", type=str, default="13B", help="Choose either '7B' or '13B'")
17
+ parser.add_argument("--llama2_root_dir", type=str, default="/home/luodian/projects/checkpoints")
18
+ parser.add_argument("--save_root_dir", type=str, default="/home/luodian/projects/checkpoints")
19
+ args = parser.parse_args()
20
+
21
+ # os.environ["TOKENIZERS_PARALLELISM"] = "false"
22
+
23
+ root_dir = args.llama2_root_dir
24
+ model_choice = args.model_choice
25
+ save_root_dir = args.save_root_dir
26
+
27
+ # prepare the Llama-2 chat model first
28
+ # you can visit https://huggingface.co/meta-llama to download the 7B and 13B chat checkpoints.
29
+ if model_choice == "7B":
30
+ config_file = "./flamingo/flamingo-llama2-chat-7B.json"
31
+ state_dict_files = [
32
+ f"{root_dir}/Llama-2-7b-chat-hf/pytorch_model-00001-of-00002.bin",
33
+ f"{root_dir}/Llama-2-7b-chat-hf/pytorch_model-00002-of-00002.bin",
34
+ ]
35
+ save_path = f"{save_root_dir}/flamingo-llama2-chat-7B-init"
36
+ elif model_choice == "13B":
37
+ config_file = "./flamingo/flamingo-llama2-chat-13B.json"
38
+ state_dict_files = [
39
+ f"{root_dir}/Llama-2-13b-chat-hf/pytorch_model-00001-of-00003.bin",
40
+ f"{root_dir}/Llama-2-13b-chat-hf/pytorch_model-00002-of-00003.bin",
41
+ f"{root_dir}/Llama-2-13b-chat-hf/pytorch_model-00003-of-00003.bin",
42
+ ]
43
+ save_path = f"{save_root_dir}/flamingo-llama2-chat-13B-init"
44
+ else:
45
+ raise ValueError("Invalid model_choice. Choose either '13B' or '7B'.")
46
+
47
+ config = FlamingoConfig.from_json_file(config_file)
48
+ model = FlamingoForConditionalGeneration(config=config)
49
+
50
+ # load flamingo's vision encoder from last checkpoint.
51
+ # you can visit https://huggingface.co/luodian/openflamingo-9b-hf/tree/main to download the checkpoint.
52
+ # AZP = "os.environ["AZP"]"
53
+ AZP = os.environ["AZP"]
54
+ state_dict_3 = torch.load(f"{AZP}/otter/checkpoints/flamingo_9b_hf/pytorch_model-00004-of-00004.bin", map_location="cpu")
55
+ for cur_key in list(state_dict_3.keys()):
56
+ if "vision_encoder" not in cur_key:
57
+ del state_dict_3[cur_key]
58
+
59
+ load_msg = model.load_state_dict(
60
+ state_dict_3,
61
+ False,
62
+ )
63
+ # print incompatible keys
64
+ print(load_msg[1])
65
+
66
+ # Loading Llama-2 chat weights
67
+ state_dict = {}
68
+ for file in tqdm(state_dict_files, desc="Loading state dict"):
69
+ state_dict_part = torch.load(file, map_location="cpu")
70
+ state_dict.update(state_dict_part)
71
+
72
+ save_state_dict_1 = {}
73
+ for key in state_dict:
74
+ if ".layers." in key:
75
+ _, _, layer_num, *remain_names = key.split(".")
76
+ target_key = f"model.layers.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
77
+ else:
78
+ target_key = key
79
+ save_state_dict_1[f"{target_key}"] = state_dict[key]
80
+
81
+ # Resize the token embedding to 32000 to match the Llama-2 checkpoint
82
+ model.lang_encoder.resize_token_embeddings(32000)
83
+
84
+ load_msg = model.lang_encoder.load_state_dict(
85
+ save_state_dict_1,
86
+ False,
87
+ )
88
+ # Resize the token embedding to 32002 for compatibility
89
+ model.lang_encoder.resize_token_embeddings(32002)
90
+ # print incompatible keys
91
+ print(load_msg[1])
92
+
93
+
94
+ print(f"Saving model to {save_path}...")
95
+ model.save_pretrained(save_path, max_shard_size="10GB")
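Once saved, the injected checkpoint can be reloaded like any other Hugging Face model. A brief sanity-check sketch, using the default `save_root_dir` from the argument parser above:

# Sketch: reload the freshly injected checkpoint and report its size.
from modeling_flamingo import FlamingoForConditionalGeneration

model = FlamingoForConditionalGeneration.from_pretrained(
    "/home/luodian/projects/checkpoints/flamingo-llama2-chat-13B-init"
)
print(sum(p.numel() for p in model.parameters()))  # rough total parameter count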
mllm/flamingo/injecting_mpt-1B-redpajama_into_flamingo.py ADDED
@@ -0,0 +1,97 @@
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+ from tqdm import tqdm
6
+
7
+ import sys
8
+
9
+ from configuration_flamingo import FlamingoConfig
10
+ from modeling_flamingo import FlamingoForConditionalGeneration
11
+ from utils import rename_flamingo_checkpoint
12
+
13
+
14
+ parser = argparse.ArgumentParser(description="Convert MPT model")
15
+ parser.add_argument("--mpt_root_dir", type=str, default="/home/luodian/projects/checkpoints")
16
+ parser.add_argument("--save_root_dir", type=str, default="/home/luodian/projects/checkpoints")
17
+ parser.add_argument("--flamingo_dir", type=str, default=None, help="If the pretrained flamingo weights also need to be injected")
18
+ args = parser.parse_args()
19
+
20
+
21
+ root_dir = args.mpt_root_dir
22
+ save_root_dir = args.save_root_dir
23
+
24
+ # prepare the MPT model first
25
+ # you can visit https://huggingface.co/mosaicml to download the mpt-1b-redpajama-200b-dolly checkpoint.
26
+ config_file = "./flamingo/flamingo-mpt-1B-redpajama.json"
27
+ state_dict_file = f"{root_dir}/pytorch_model.bin"
28
+ save_path = f"{save_root_dir}/flamingo-mpt-1b-redpajama-200b-dolly"
29
+
30
+ config = FlamingoConfig.from_json_file(config_file)
31
+
32
+ model = FlamingoForConditionalGeneration(config=config)
33
+
34
+ # Loading mpt weights
35
+ state_dict = torch.load(state_dict_file, map_location="cpu")
36
+ save_state_dict_1 = {}
37
+ for key in state_dict:
38
+ if ".blocks." in key:
39
+ _, _, layer_num, *remain_names = key.split(".")
40
+ target_key = f"transformer.blocks.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
41
+ else:
42
+ target_key = key
43
+ save_state_dict_1[f"{target_key}"] = state_dict[key]
44
+
45
+ load_msg = model.lang_encoder.load_state_dict(
46
+ save_state_dict_1,
47
+ False,
48
+ )
49
+
50
+ # load flamingo's vision encoder from last checkpoint.
51
+ # you can visit https://huggingface.co/luodian/openflamingo-9b-hf/tree/main to download the checkpoint.
52
+ AZP = os.environ["AZP"]
53
+ state_dict_3 = torch.load(f"{AZP}/pytorch_model-00004-of-00004.bin", map_location="cpu")
54
+ for cur_key in list(state_dict_3.keys()):
55
+ if "vision_encoder" not in cur_key:
56
+ del state_dict_3[cur_key]
57
+
58
+ load_msg = model.load_state_dict(
59
+ state_dict_3,
60
+ False,
61
+ )
62
+ # print incompatible keys
63
+ print(load_msg[1])
64
+
65
+ save_state_dict_1 = {}
66
+ for key in state_dict:
67
+ if ".blocks." in key:
68
+ _, _, layer_num, *remain_names = key.split(".")
69
+ target_key = f"transformer.blocks.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
70
+ else:
71
+ target_key = key
72
+ save_state_dict_1[f"{target_key}"] = state_dict[key]
73
+
74
+ load_msg = model.lang_encoder.load_state_dict(
75
+ save_state_dict_1,
76
+ False,
77
+ )
78
+ # print incompatible keys
79
+ print(load_msg[1])
80
+ if args.flamingo_dir is not None:
81
+ state_dict_2 = torch.load(f"{args.flamingo_dir}/checkpoint.pt", map_location="cpu")
82
+ save_state_dict_2 = rename_flamingo_checkpoint(state_dict_2)
83
+ real_vocab_size = config.text_config.vocab_size
84
+ # Resize the token embeddings to match the vocab size of the pretrained Flamingo checkpoint
85
+ model.lang_encoder.resize_token_embeddings(save_state_dict_2["lang_encoder.transformer.wte.weight"].shape[0])
86
+
87
+ load_msg = model.load_state_dict(
88
+ save_state_dict_2,
89
+ False,
90
+ )
91
+ # print incompatible keys
92
+ print(load_msg[1])
93
+ # Resize the token embeddings back to the configured vocab size (50432)
94
+ model.lang_encoder.resize_token_embeddings(real_vocab_size)
95
+
96
+ print(f"Saving model to {save_path}...")
97
+ model.save_pretrained(save_path, max_shard_size="10GB")
mllm/flamingo/injecting_mpt_into_flamingo.py ADDED
@@ -0,0 +1,109 @@
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+ from tqdm import tqdm
6
+
7
+ import sys
8
+
9
+ from configuration_flamingo import FlamingoConfig
10
+ from modeling_flamingo import FlamingoForConditionalGeneration
11
+ from utils import rename_flamingo_checkpoint
12
+
13
+ parser = argparse.ArgumentParser(description="Convert MPT model")
14
+ parser.add_argument("--model_choice", type=str, choices=["7B", "30B"], required=True, help="Choose either '7B' or '30B'")
15
+ parser.add_argument("--mpt_root_dir", type=str, default="/home/luodian/projects/checkpoints")
16
+ parser.add_argument("--save_root_dir", type=str, default="/home/luodian/projects/checkpoints")
17
+ parser.add_argument("--flamingo_dir", type=str, default=None, help="If the pretrained flamingo weights also need to be injected")
18
+ args = parser.parse_args()
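+ # Illustrative invocation (hypothetical paths, not from the original script); the relative
+ # config paths below suggest it is run from the mllm/ directory, and the AZP environment
+ # variable must point at the directory holding the downloaded OpenFlamingo checkpoint:
+ #   python flamingo/injecting_mpt_into_flamingo.py --model_choice 7B \
+ #     --mpt_root_dir /path/to/checkpoints --save_root_dir /path/to/output \
+ #     --flamingo_dir /path/to/openflamingo  # optional, also injects pretrained Flamingo weights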
19
+
20
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
21
+
22
+ root_dir = args.mpt_root_dir
23
+ model_choice = args.model_choice
24
+ save_root_dir = args.save_root_dir
25
+
26
+ # prepare the MPT model first
27
+ # you can visit https://huggingface.co/mosaicml to download the mpt-7b and mpt-30b-instruct checkpoints.
28
+ if model_choice == "30B":
29
+ config_file = "./flamingo/flamingo-mpt-30B.json"
30
+ state_dict_files = [
31
+ f"{root_dir}/mpt-30b-instruct/pytorch_model-00001-of-00007.bin",
32
+ f"{root_dir}/mpt-30b-instruct/pytorch_model-00002-of-00007.bin",
33
+ f"{root_dir}/mpt-30b-instruct/pytorch_model-00003-of-00007.bin",
34
+ f"{root_dir}/mpt-30b-instruct/pytorch_model-00004-of-00007.bin",
35
+ f"{root_dir}/mpt-30b-instruct/pytorch_model-00005-of-00007.bin",
36
+ f"{root_dir}/mpt-30b-instruct/pytorch_model-00006-of-00007.bin",
37
+ f"{root_dir}/mpt-30b-instruct/pytorch_model-00007-of-00007.bin",
38
+ ]
39
+ save_path = f"{save_root_dir}/flamingo-mpt-30B-instruct-init"
40
+ elif model_choice == "7B":
41
+ config_file = "./flamingo/flamingo-mpt-7B.json"
42
+ state_dict_files = [
43
+ f"{root_dir}/mpt-7b/pytorch_model-00001-of-00002.bin",
44
+ f"{root_dir}/mpt-7b/pytorch_model-00002-of-00002.bin",
45
+ ]
46
+ save_path = f"{save_root_dir}/flamingo-mpt-7B"
47
+ else:
48
+ raise ValueError("Invalid model_choice. Choose either '30B' or '7B'.")
49
+
50
+ config = FlamingoConfig.from_json_file(config_file)
51
+
52
+ model = FlamingoForConditionalGeneration(config=config)
53
+
54
+
55
+ # load flamingo's vision encoder from last checkpoint.
56
+ # you can visit https://huggingface.co/luodian/openflamingo-9b-hf/tree/main to download the checkpoint.
57
+ AZP = os.environ["AZP"]
58
+ state_dict_3 = torch.load(f"{AZP}/otter/checkpoints/flamingo_9b_hf/pytorch_model-00004-of-00004.bin", map_location="cpu")
59
+ for cur_key in list(state_dict_3.keys()):
60
+ if "vision_encoder" not in cur_key:
61
+ del state_dict_3[cur_key]
62
+
63
+ load_msg = model.load_state_dict(
64
+ state_dict_3,
65
+ False,
66
+ )
67
+ # print incompatible keys
68
+ print(load_msg[1])
69
+
70
+ # Loading mpt weights
71
+ state_dict = {}
72
+ for file in tqdm(state_dict_files, desc="Loading state dict"):
73
+ state_dict_part = torch.load(file, map_location="cpu")
74
+ state_dict.update(state_dict_part)
75
+
76
+ save_state_dict_1 = {}
77
+ for key in state_dict:
78
+ if ".blocks." in key:
79
+ _, _, layer_num, *remain_names = key.split(".")
80
+ target_key = f"transformer.blocks.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
81
+ else:
82
+ target_key = key
83
+ save_state_dict_1[f"{target_key}"] = state_dict[key]
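+ # Illustrative note (not part of the original file): the renaming above only rewrites
+ # block-level keys so they land inside each FlamingoLayer's wrapped decoder, e.g.
+ # (hypothetical key) "transformer.blocks.0.attn.Wqkv.weight"
+ #   -> "transformer.blocks.0.decoder_layer.attn.Wqkv.weight",
+ # while non-block keys such as "transformer.wte.weight" pass through unchanged.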
84
+
85
+ load_msg = model.lang_encoder.load_state_dict(
86
+ save_state_dict_1,
87
+ False,
88
+ )
89
+ # print incompatible keys
90
+ print(load_msg[1])
91
+ if args.flamingo_dir is not None:
92
+ state_dict_2 = torch.load(f"{args.flamingo_dir}/checkpoint.pt", map_location="cpu")
93
+ save_state_dict_2 = rename_flamingo_checkpoint(state_dict_2)
94
+
95
+ real_vocab_size = config.text_config.vocab_size
96
+ # Resize the token embeddings to match the vocab size of the pretrained Flamingo checkpoint
97
+ model.lang_encoder.resize_token_embeddings(save_state_dict_2["lang_encoder.transformer.wte.weight"].shape[0])
98
+
99
+ load_msg = model.load_state_dict(
100
+ save_state_dict_2,
101
+ False,
102
+ )
103
+ # print incompatible keys
104
+ print(load_msg[1])
105
+ # Resize the token embeddings back to the configured vocab size (50432)
106
+ model.lang_encoder.resize_token_embeddings(real_vocab_size)
107
+
108
+ print(f"Saving model to {save_path}...")
109
+ model.save_pretrained(save_path, max_shard_size="10GB")
mllm/flamingo/injecting_vicuna_into_flamingo.py ADDED
@@ -0,0 +1,100 @@
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+ from tqdm import tqdm
6
+
7
+ import sys
8
+
9
+ from .configuration_flamingo import FlamingoConfig
10
+ from .modeling_flamingo import FlamingoForConditionalGeneration
11
+
12
+ # from .configuration_flamingo import FlamingoConfig
13
+ # from .modeling_flamingo import FlamingoForConditionalGeneration
14
+
15
+ parser = argparse.ArgumentParser(description="Convert Vicuna model")
16
+ parser.add_argument("--model_choice", type=str, choices=["7B", "33B"], required=True, help="Choose either '7B' or '33B'")
17
+ parser.add_argument("--vicuna_root_dir", type=str, default="/home/luodian/projects/checkpoints")
18
+ parser.add_argument("--save_root_dir", type=str, default="/home/luodian/projects/checkpoints")
19
+ parser.add_argument("--flamingo_dir", type=str, default=None, help="If the pretrained flamingo weights also need to be injected")
20
+ args = parser.parse_args()
21
+
22
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
23
+
24
+ root_dir = args.vicuna_root_dir
25
+ model_choice = args.model_choice
26
+ save_root_dir = args.save_root_dir
27
+
28
+ # prepare vicuna model at first
29
+ # you can visit https://huggingface.co/lmsys/vicuna-33b-v1.3 to download 7B and 30B instruct checkpoints.
30
+ if model_choice == "33B":
31
+ config_file = "./flamingo/flamingo-vicuna-33B-v1.3.json"
32
+ state_dict_files = [
33
+ f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00001-of-00007.bin",
34
+ f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00002-of-00007.bin",
35
+ f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00003-of-00007.bin",
36
+ f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00004-of-00007.bin",
37
+ f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00005-of-00007.bin",
38
+ f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00006-of-00007.bin",
39
+ f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00007-of-00007.bin",
40
+ ]
41
+ save_path = f"{save_root_dir}/flamingo-vicuna-33B-v1.3-init"
42
+ elif model_choice == "7B":
43
+ config_file = "./flamingo/flamingo-vicuna-7B-v1.3.json"
44
+ state_dict_files = [
45
+ f"{root_dir}/vicuna-7b-v1.3/pytorch_model-00001-of-00002.bin",
46
+ f"{root_dir}/vicuna-7b-v1.3/pytorch_model-00002-of-00002.bin",
47
+ ]
48
+ save_path = f"{save_root_dir}/flamingo-vicuna-7B-v1.3-init"
49
+ else:
50
+ raise ValueError("Invalid model_choice. Choose either '33B' or '7B'.")
51
+
52
+ config = FlamingoConfig.from_json_file(config_file)
53
+ model = FlamingoForConditionalGeneration(config=config)
54
+
55
+ # load flamingo's vision encoder from last checkpoint.
56
+ # you can visit https://huggingface.co/luodian/openflamingo-9b-hf/tree/main to download the checkpoint.
57
+ # AZP = "os.environ["AZP"]"
58
+ AZP = os.environ["AZP"]
59
+ state_dict_3 = torch.load(f"{AZP}/otter/checkpoints/flamingo_9b_hf/pytorch_model-00004-of-00004.bin", map_location="cpu")
60
+ for cur_key in list(state_dict_3.keys()):
61
+ if "vision_encoder" not in cur_key:
62
+ del state_dict_3[cur_key]
63
+
64
+ load_msg = model.load_state_dict(
65
+ state_dict_3,
66
+ False,
67
+ )
68
+ # print incompatible keys
69
+ print(load_msg[1])
70
+
71
+ # Loading vicuna weights
72
+ state_dict = {}
73
+ for file in tqdm(state_dict_files, desc="Loading state dict"):
74
+ state_dict_part = torch.load(file, map_location="cpu")
75
+ state_dict.update(state_dict_part)
76
+
77
+ save_state_dict_1 = {}
78
+ for key in state_dict:
79
+ if ".layers." in key:
80
+ _, _, layer_num, *remain_names = key.split(".")
81
+ target_key = f"model.layers.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
82
+ else:
83
+ target_key = key
84
+ save_state_dict_1[f"{target_key}"] = state_dict[key]
85
+
86
+ # Reshape the token embedding to 50280 for compatible
87
+ model.lang_encoder.resize_token_embeddings(32000)
88
+
89
+ load_msg = model.lang_encoder.load_state_dict(
90
+ save_state_dict_1,
91
+ False,
92
+ )
93
+ # Reshape the token embedding to 32002 for compatible
94
+ model.lang_encoder.resize_token_embeddings(32002)
95
+ # print incompatible keys
96
+ print(load_msg[1])
97
+
98
+
99
+ print(f"Saving model to {save_path}...")
100
+ model.save_pretrained(save_path, max_shard_size="10GB")
mllm/flamingo/modeling_flamingo.py ADDED
@@ -0,0 +1,966 @@
1
+ import random
2
+ from dataclasses import dataclass
3
+ from typing import Callable, Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from accelerate.hooks import AlignDevicesHook, add_hook_to_module
8
+ from einops import rearrange, repeat
9
+ from transformers import CLIPVisionModel, LlamaForCausalLM, LlamaTokenizer
10
+ from transformers.modeling_outputs import CausalLMOutputWithPast
11
+ from transformers.modeling_utils import PreTrainedModel
12
+ from transformers.models.auto import AutoModel, AutoModelForCausalLM, AutoTokenizer
13
+
14
+ from .configuration_flamingo import FlamingoConfig
15
+ from .falcon.modelling_RW import RWForCausalLM
16
+ from .mpt.modeling_mpt import MPTForCausalLM
17
+ from .mpt_redpajama.mosaic_gpt import MosaicGPT
18
+
19
+ # from .configuration_flamingo import FlamingoConfig
20
+
21
+ __KNOWN_DECODER_LAYERS_ATTR_NAMES = {
22
+ "opt": "model.decoder.layers",
23
+ "gptneo": "transformer.h",
24
+ "gptj": "transformer.h",
25
+ "gpt-j": "transformer.h",
26
+ "pythia": "gpt_neox.layers",
27
+ "llama": "model.layers",
28
+ "RWForCausalLM": "transformer.h",
29
+ "MPTForCausalLM": "transformer.blocks",
30
+ "MosaicGPT": "transformer.blocks",
31
+ }
32
+
33
+
34
+ def _infer_decoder_layers_attr_name(model: nn.Module):
35
+ for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
36
+ if k.lower() in model.__class__.__name__.lower():
37
+ return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
38
+
39
+ raise ValueError(
40
+ f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
41
+ )
42
+
43
+
44
+ def extend_instance(obj, mixin):
45
+ """Apply mixins to a class instance after creation"""
46
+ base_cls = obj.__class__
47
+ base_cls_name = obj.__class__.__name__
48
+ obj.__class__ = type(base_cls_name, (mixin, base_cls), {}) # mixin needs to go first for our forward() logic to work
49
+
50
+
51
+ def getattr_recursive(obj, att):
52
+ """
53
+ Return nested attribute of obj
54
+ Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
55
+ """
56
+ if att == "":
57
+ return obj
58
+ i = att.find(".")
59
+ if i < 0:
60
+ return getattr(obj, att)
61
+ else:
62
+ return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
63
+
64
+
65
+ def setattr_recursive(obj, att, val):
66
+ """
67
+ Set nested attribute of obj
68
+ Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
69
+ """
70
+ if "." in att:
71
+ obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
72
+ setattr(obj, att.split(".")[-1], val)
73
+
74
+
75
+ def exists(val):
76
+ return val is not None
77
+
78
+
79
+ class FlamingoPerceiverBlock(nn.Module):
80
+ def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8, mult: int = 4):
81
+ super().__init__()
82
+ self.scale = dim_head**-0.5
83
+ self.heads = heads
84
+ inner_dim = dim_head * heads
85
+ ff_dim = dim * mult
86
+ self.norm_media = nn.LayerNorm(dim)
87
+ self.norm_latents = nn.LayerNorm(dim)
88
+
89
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
90
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
91
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
92
+ self.feed_forward = nn.ModuleList(
93
+ [
94
+ nn.LayerNorm(dim),
95
+ nn.Linear(dim, ff_dim, bias=False),
96
+ nn.GELU(),
97
+ nn.Linear(ff_dim, dim, bias=False),
98
+ ]
99
+ )
100
+
101
+ def forward(self, x: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
102
+ """
103
+ Args:
104
+ x (torch.Tensor): image features
105
+ shape (b, T, n1, D)
106
+ latent (torch.Tensor): latent features
107
+ shape (b, T, n2, D)
108
+ """
109
+ x = self.norm_media(x)
110
+ residual_latents = latents
111
+ latents = self.norm_latents(latents)
112
+
113
+ h = self.heads
114
+
115
+ q = self.to_q(latents)
116
+ kv_input = torch.cat((x, latents), dim=-2)
117
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
118
+ q = rearrange(q, "b t n (h d) -> b h t n d", h=h)
119
+ k = rearrange(k, "b t n (h d) -> b h t n d", h=h)
120
+ v = rearrange(v, "b t n (h d) -> b h t n d", h=h)
121
+ q = q * self.scale
122
+
123
+ # attention
124
+ sim = torch.einsum("... i d, ... j d -> ... i j", q, k)
125
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
126
+ attn = sim.softmax(dim=-1)
127
+
128
+ out = torch.einsum("... i j, ... j d -> ... i d", attn, v)
129
+ out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
130
+ out = self.to_out(out) + residual_latents
131
+ residual_out = out
132
+ for layer in self.feed_forward:
133
+ out = layer(out)
134
+ return out + residual_out
135
+
136
+
137
+ class FlamingoPerceiverResampler(nn.Module):
138
+ def __init__(
139
+ self,
140
+ *,
141
+ dim: int,
142
+ depth: int = 6,
143
+ dim_head: int = 64,
144
+ heads: int = 8,
145
+ num_latents: int = 64,
146
+ # max_num_frames: int = 128,
147
+ max_num_media: Optional[int] = None,
148
+ max_num_frames: Optional[int] = None,
149
+ ff_mult: int = 4,
150
+ ):
151
+ super().__init__()
152
+ self.latents = nn.Parameter(torch.randn(num_latents, dim))
153
+ self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None
154
+ # self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim))
155
+
156
+ self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None
157
+
158
+ self.layers = nn.ModuleList([])
159
+ for _ in range(depth):
160
+ self.layers.append(FlamingoPerceiverBlock(dim=dim, dim_head=dim_head, heads=heads, mult=ff_mult))
161
+
162
+ self.norm = nn.LayerNorm(dim)
163
+
164
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
165
+ """
166
+ Args:
167
+ x (torch.Tensor): image features
168
+ shape (b, T, F, v, D)
169
+ Returns:
170
+ shape (b, T, n, D) where n is self.num_latents
171
+ """
172
+ b, T, F, v = x.shape[:4]
173
+
174
+ # frame and media time embeddings
175
+ if exists(self.frame_embs):
176
+ frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
177
+ x = x + frame_embs
178
+ x = rearrange(x, "b T F v d -> b T (F v) d") # flatten the frame and spatial dimensions
179
+ if exists(self.media_time_embs):
180
+ x = x + self.media_time_embs[:T]
181
+
182
+ # blocks
183
+ latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
184
+ for block in self.layers:
185
+ latents = block(x, latents)
186
+ return self.norm(latents)
187
+
188
+
189
+ class FlamingoMaskedCrossAttention(nn.Module):
190
+ def __init__(
191
+ self,
192
+ *,
193
+ dim: int,
194
+ dim_visual: int,
195
+ dim_head: int = 64,
196
+ heads: int = 8,
197
+ only_attend_immediate_media: bool = True,
198
+ ):
199
+ super().__init__()
200
+ self.scale = dim_head**-0.5
201
+ self.heads = heads
202
+ inner_dim = dim_head * heads
203
+
204
+ self.norm = nn.LayerNorm(dim)
205
+
206
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
207
+ self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
208
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
209
+
210
+ # whether for text to only attend to immediate preceding image, or all previous images
211
+ self.only_attend_immediate_media = only_attend_immediate_media
212
+
213
+ def forward(
214
+ self,
215
+ x: torch.Tensor,
216
+ media: torch.Tensor,
217
+ media_locations: Optional[torch.BoolTensor] = None,
218
+ attend_previous: bool = True,
219
+ ) -> torch.Tensor:
220
+ """
221
+ Args:
222
+ x (torch.Tensor): text features
223
+ shape (B, T_txt, D_txt)
224
+ media (torch.Tensor): image features
225
+ shape (B, T_img, n, D_img) where n is the dim of the latents
226
+ media_locations: boolean mask identifying the media tokens in x
227
+ shape (B, T_txt)
228
+ attend_previous: bool
229
+ If False, text tokens ignore the immediately preceding image and instead attend to the following image
230
+ """
231
+ _, T_img, n = media.shape[:3]
232
+ h = self.heads
233
+
234
+ x = self.norm(x)
235
+
236
+ q = self.to_q(x)
237
+ media = rearrange(media, "b t n d -> b (t n) d")
238
+
239
+ k, v = self.to_kv(media).chunk(2, dim=-1)
240
+ q = rearrange(q, "b n (h d) -> b h n d", h=h)
241
+ k = rearrange(k, "b n (h d) -> b h n d", h=h)
242
+ v = rearrange(v, "b n (h d) -> b h n d", h=h)
243
+
244
+ q = q * self.scale
245
+
246
+ sim = torch.einsum("... i d, ... j d -> ... i j", q, k)
247
+
248
+ if exists(media_locations):
249
+ # at each boolean of True, increment the time counter (relative to media time)
250
+ text_time = media_locations.cumsum(dim=-1)
251
+ media_time = torch.arange(T_img, device=x.device) + 1
252
+
253
+ if not attend_previous:
254
+ text_time[~media_locations] += 1
255
+ # make sure max is still the number of images in the sequence
256
+ text_time[
257
+ text_time
258
+ > repeat(
259
+ torch.count_nonzero(media_locations, dim=1),
260
+ "b -> b i",
261
+ i=text_time.shape[1],
262
+ )
263
+ ] = 0
264
+
265
+ # text time must equal media time if only attending to most immediate image
266
+ # otherwise, as long as text time is greater than media time (if attending to all previous images / media)
267
+ mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
268
+
269
+ text_to_media_mask = mask_op(
270
+ rearrange(text_time, "b i -> b 1 i 1"),
271
+ repeat(media_time, "j -> 1 1 1 (j n)", n=n),
272
+ )
273
+ sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
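+ # Illustrative example (not part of the original code): for tokens
+ # [<image>, a, b, <image>, c], media_locations = [1, 0, 0, 1, 0] gives
+ # text_time = [1, 1, 1, 2, 2] and media_time = [1, 2]; with
+ # only_attend_immediate_media (torch.eq) tokens a and b attend only to the
+ # latents of image 1 and c only to image 2, while torch.ge lets c attend
+ # to both images.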
274
+
275
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
276
+ attn = sim.softmax(dim=-1)
277
+
278
+ if exists(media_locations) and self.only_attend_immediate_media:
279
+ # any text without a preceding media needs to have attention zeroed out
280
+ text_without_media_mask = text_time == 0
281
+ text_without_media_mask = rearrange(text_without_media_mask, "b i -> b 1 i 1")
282
+ attn = attn.masked_fill(text_without_media_mask, 0.0)
283
+
284
+ out = torch.einsum("... i j, ... j d -> ... i d", attn, v)
285
+ out = rearrange(out, "b h n d -> b n (h d)")
286
+ return self.to_out(out)
287
+
288
+
289
+ class FlamingoGatedCrossAttentionBlock(nn.Module):
290
+ def __init__(
291
+ self,
292
+ *,
293
+ dim: int,
294
+ dim_visual: int,
295
+ dim_head: int = 64,
296
+ heads: int = 8,
297
+ ff_mult: int = 4,
298
+ only_attend_immediate_media: bool = True,
299
+ ):
300
+ super().__init__()
301
+ self.attn = FlamingoMaskedCrossAttention(
302
+ dim=dim,
303
+ dim_visual=dim_visual,
304
+ dim_head=dim_head,
305
+ heads=heads,
306
+ only_attend_immediate_media=only_attend_immediate_media,
307
+ )
308
+ self.attn_gate = nn.Parameter(torch.tensor([0.0]))
309
+ self.feed_forward = nn.ModuleList(
310
+ [
311
+ nn.LayerNorm(dim),
312
+ nn.Linear(dim, dim * ff_mult, bias=False),
313
+ nn.GELU(),
314
+ nn.Linear(dim * ff_mult, dim, bias=False),
315
+ ]
316
+ )
317
+ self.ff_gate = nn.Parameter(torch.tensor([0.0]))
318
+
319
+ def forward(
320
+ self,
321
+ x: torch.Tensor,
322
+ media: torch.Tensor,
323
+ media_locations: Optional[torch.BoolTensor] = None,
324
+ attend_previous: bool = True,
325
+ ) -> torch.Tensor:
326
+ x = (
327
+ self.attn(
328
+ x,
329
+ media,
330
+ media_locations=media_locations,
331
+ attend_previous=attend_previous,
332
+ )
333
+ * self.attn_gate.tanh()
334
+ + x
335
+ )
336
+ residual_x = x
337
+ for ff in self.feed_forward:
338
+ x = ff(x)
339
+ x = x * self.ff_gate.tanh() + residual_x
340
+
341
+ return x
342
+
343
+
344
+ class FlamingoLayer(nn.Module):
345
+ def __init__(self, gated_cross_attn_layer: nn.Module, decoder_layer: nn.Module):
346
+ super().__init__()
347
+ self.gated_cross_attn_layer = gated_cross_attn_layer
348
+ self.decoder_layer = decoder_layer
349
+ self.vis_x = None
350
+ self.media_locations = None
351
+
352
+ def is_conditioned(self) -> bool:
353
+ """Check whether the layer is conditioned."""
354
+ return self.vis_x is not None
355
+
356
+ # Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/)
357
+ def condition_vis_x(self, vis_x) -> None:
358
+ self.vis_x = vis_x
359
+
360
+ def condition_media_locations(self, media_locations) -> None:
361
+ self.media_locations = media_locations
362
+
363
+ def condition_attend_previous(self, attend_previous) -> None:
364
+ self.attend_previous = attend_previous
365
+
366
+ def forward(
367
+ self,
368
+ lang_x: torch.Tensor,
369
+ attention_mask: Optional[torch.Tensor] = None,
370
+ **decoder_layer_kwargs,
371
+ ):
372
+ if self.gated_cross_attn_layer is None:
373
+ return self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs)
374
+
375
+ if self.vis_x is None:
376
+ raise ValueError("vis_x must be conditioned before forward pass")
377
+
378
+ if self.media_locations is None:
379
+ raise ValueError("media_locations must be conditioned before forward pass")
380
+
381
+ lang_x = self.gated_cross_attn_layer(
382
+ lang_x,
383
+ self.vis_x,
384
+ media_locations=self.media_locations,
385
+ attend_previous=self.attend_previous,
386
+ )
387
+ lang_x = self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs)
388
+ return lang_x
389
+
390
+
391
+ class FlamingoLMMixin(nn.Module):
392
+ """
393
+ Mixin to add cross-attention layers to a language model.
394
+ """
395
+
396
+ def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
397
+ self.decoder_layers_attr_name = decoder_layers_attr_name
398
+
399
+ def _get_decoder_layers(self):
400
+ return getattr_recursive(self, self.decoder_layers_attr_name)
401
+
402
+ def _set_decoder_layers(self, value):
403
+ setattr_recursive(self, self.decoder_layers_attr_name, value)
404
+
405
+ def init_flamingo(
406
+ self,
407
+ media_token_id: int,
408
+ vis_hidden_size: int,
409
+ cross_attn_every_n_layers: int,
410
+ use_media_placement_augmentation: bool,
411
+ ):
412
+ """
413
+ Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.
414
+ """
415
+
416
+ gated_cross_attn_layers = nn.ModuleList(
417
+ [
418
+ FlamingoGatedCrossAttentionBlock(
419
+ dim=self.config.hidden_size,
420
+ dim_visual=vis_hidden_size,
421
+ )
422
+ if (layer_idx + 1) % cross_attn_every_n_layers == 0
423
+ else None
424
+ for layer_idx, _ in enumerate(self._get_decoder_layers())
425
+ ]
426
+ )
427
+ self._set_decoder_layers(
428
+ nn.ModuleList(
429
+ [
430
+ FlamingoLayer(gated_cross_attn_layer, decoder_layer)
431
+ for gated_cross_attn_layer, decoder_layer in zip(gated_cross_attn_layers, self._get_decoder_layers())
432
+ ]
433
+ )
434
+ )
435
+ self.media_token_id = media_token_id
436
+ self.use_media_placement_augmentation = use_media_placement_augmentation
437
+ self.initialized_flamingo = True
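+ # Illustrative note (not part of the original code): e.g. with cross_attn_every_n_layers=4,
+ # decoder layers at 0-based indices 3, 7, 11, ... receive a FlamingoGatedCrossAttentionBlock,
+ # while all other layers keep gated_cross_attn_layer=None.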
438
+
439
+ def forward(self, *input, **kwargs):
440
+ """Condition the Flamingo layers on the media locations before forward()"""
441
+ if not self.initialized_flamingo:
442
+ raise ValueError("Flamingo layers are not initialized. Please call `init_flamingo` first.")
443
+
444
+ input_ids = kwargs["input_ids"] if "input_ids" in kwargs else input[0]
445
+ media_locations = input_ids == self.media_token_id
446
+ # IMPORTANT: Force `attend_previous` to True when we place training data as <image>caption<|endofchunk|>
447
+ # attend_previous = (
448
+ # (random.random() < 0.5) if self.use_media_placement_augmentation else False
449
+ # )
450
+ attend_previous = (random.random() < 0.5) if self.use_media_placement_augmentation else True
451
+ # attend_previous = self.only_attend_previous
452
+
453
+ if self.__class__.__name__ == "LlamaForCausalLM":
454
+ for layer in self.get_decoder().layers:
455
+ layer.condition_media_locations(media_locations)
456
+ layer.condition_attend_previous(attend_previous)
457
+ elif self.__class__.__name__ in ["MPTForCausalLM", "MosaicGPT"]:
458
+ for layer in self.get_decoder().blocks:
459
+ layer.condition_media_locations(media_locations)
460
+ layer.condition_attend_previous(attend_previous)
461
+ else:
462
+ print("inavaliable text encoder")
463
+ return super().forward(*input, **kwargs) # Call the other parent's forward method
464
+
465
+ def is_conditioned(self) -> bool:
466
+ """Check whether all decoder layers are already conditioned."""
467
+ return all(l.is_conditioned() for l in self._get_decoder_layers())
468
+
469
+ def clear_conditioned_layers(self) -> None:
470
+ for layer in self._get_decoder_layers():
471
+ layer.condition_vis_x(None)
472
+ layer.condition_media_locations(None)
473
+ layer.condition_attend_previous(None)
474
+
475
+
476
+ class FlamingoPreTrainedModel(PreTrainedModel):
477
+ """
478
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
479
+ models.
480
+ """
481
+
482
+ config_class = FlamingoConfig
483
+ base_model_prefix = "flamingo"
484
+ supports_gradient_checkpointing = True
485
+ _no_split_modules = ["FlamingoPerceiverBlock", "CLIPEncoderLayer", "FlamingoLayer"]
486
+
487
+ def _init_weights(self, module):
488
+ """Flamingo requires no specific initialization"""
489
+ return super()._init_weights(module)
490
+
491
+ def _set_gradient_checkpointing(self, module, value=False):
492
+ if isinstance(module, FlamingoModel):
493
+ module.gradient_checkpointing = value
494
+
495
+
496
+ class FlamingoModel(FlamingoPreTrainedModel):
497
+ config_class = FlamingoConfig
498
+
499
+ def __init__(
500
+ self,
501
+ config: FlamingoConfig,
502
+ ):
503
+ super().__init__(config)
504
+ ### TODO: give "LlamaForCausalLM" as the name of text_config.architectures of Llama_based flamingo
505
+ if "llama" not in config.text_config._name_or_path:
506
+ if config.text_config.architectures[0] == "MPTForCausalLM":
507
+ text_tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b-instruct")
508
+ lang_encoder = MPTForCausalLM(config=config.text_config)
509
+ elif config.text_config.text_config.architectures[0] == "MosaicGPT":
510
+ text_tokenizer = AutoTokenizer.from_pretrained("mosaicml/mosaic-llama-redpajama-final-candidate")
511
+ lang_encoder = MosaicGPT(config=config.text_config)
512
+ elif config.text_config.architectures[0] == "RWForCausalLM":
513
+ text_tokenizer = AutoTokenizer.from_pretrained("PATH-TO-YOUR-FALCON")
514
+ lang_encoder = RWForCausalLM(config=config.text_config)
515
+ else:
516
+ text_tokenizer = LlamaTokenizer.from_pretrained(config.text_config._name_or_path)
517
+ lang_encoder = LlamaForCausalLM(config=config.text_config)
518
+
519
+ vision_encoder = CLIPVisionModel(config=config.vision_config)
520
+ text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
521
+ if text_tokenizer.pad_token is None:
522
+ text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
523
+ self.text_tokenizer = text_tokenizer
524
+ self.eoc_token_id = text_tokenizer.encode("<|endofchunk|>")[-1]
525
+ self.media_token_id = text_tokenizer.encode("<image>")[-1]
526
+
527
+ extend_instance(lang_encoder, FlamingoLMMixin)
528
+ decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
529
+ lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
530
+ if lang_encoder.__class__.__name__ == "LlamaForCausalLM":
531
+ lang_encoder.resize_token_embeddings(len(text_tokenizer))
532
+ self.lang_encoder = lang_encoder
533
+
534
+ self.cross_attn_every_n_layers = config.cross_attn_every_n_layers if hasattr(config, "cross_attn_every_n_layers") else 4
535
+ self.use_media_placement_augmentation = config.use_media_placement_augmentation
536
+
537
+ vision_encoder.output_tokens = True
538
+ self.vision_encoder = vision_encoder
539
+
540
+ self.vis_dim = 1024
541
+ self.perceiver = FlamingoPerceiverResampler(dim=self.vis_dim)
542
+
543
+ self.lang_encoder.init_flamingo(
544
+ media_token_id=self.media_token_id,
545
+ vis_hidden_size=self.vis_dim,
546
+ cross_attn_every_n_layers=self.cross_attn_every_n_layers,
547
+ use_media_placement_augmentation=self.use_media_placement_augmentation,
548
+ )
549
+ self.post_init()
550
+
551
+ def get_input_embeddings(self) -> nn.Module:
552
+ return self.lang_encoder.get_input_embeddings()
553
+
554
+ def set_input_embeddings(self, new_embeddings):
555
+ self.lang_encoder.set_input_embeddings(new_embeddings)
556
+
557
+ def get_output_embeddings(self) -> nn.Module:
558
+ return self.lang_encoder.get_output_embeddings()
559
+
560
+ def set_output_embeddings(self, new_embeddings):
561
+ self.lang_encoder.set_output_embeddings(new_embeddings)
562
+
563
+ def get_image_encoder(self) -> nn.Module:
564
+ return self.vision_encoder
565
+
566
+ def get_lang_encoder(self) -> nn.Module:
567
+ return self.lang_encoder
568
+
569
+ # def init_weights(self):
570
+ # # Freeze all parameters in vision encoder
571
+ # for param in self.vision_encoder.parameters():
572
+ # param.requires_grad = False
573
+ # # Freeze all parameters in lang encoders except gated_cross_attn_layers
574
+ # for name, param in self.lang_encoder.named_parameters():
575
+ # if "gated_cross_attn_layer" not in name:
576
+ # param.requires_grad = False
577
+ # # Unfreeze LM input embeddings
578
+ # self.lang_encoder.get_input_embeddings().requires_grad_(True)
579
+ # ## MPTForCausalLM is tied word embedding
580
+ # if self.lang_encoder.__class__.__name__ == "LlamaForCausalLM":
581
+ # self.lang_encoder.lm_head.requires_grad_(True)
582
+ # # assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
583
+ # # print model size in billions of parameters in 2 decimal places
584
+ # print(f"Trainable param: {(sum(p.numel() for p in self.parameters() if p.requires_grad)) / 1e9:.2f} B")
585
+
586
+ def init_weights(self):
587
+ # Freeze all parameters in vision encoder
588
+ for param in self.vision_encoder.parameters():
589
+ param.requires_grad = False
590
+
591
+ if "lora_config" in self.config.__dict__:
592
+ print(f"LoRA trainable param: {(sum(p.numel() for p in self.lang_encoder.parameters() if p.requires_grad)) / 1e9:.3f} B")
593
+ # Unfreeze gated_cross_attn_layers
594
+ for layer in self.lang_encoder._get_decoder_layers():
595
+ if layer.gated_cross_attn_layer is not None:
596
+ for param in layer.gated_cross_attn_layer.parameters():
597
+ param.requires_grad = True
598
+ else:
599
+ # Freeze all parameters in lang encoders except gated_cross_attn_layers
600
+ for name, param in self.lang_encoder.named_parameters():
601
+ if "gated_cross_attn_layer" not in name:
602
+ param.requires_grad = False
603
+ # Unfreeze LM input and output embeddings
604
+ self.lang_encoder.get_input_embeddings().requires_grad_(True)
605
+ ## MPTForCausalLM uses tied word embeddings, so only Llama needs its lm_head unfrozen explicitly
606
+ if self.lang_encoder.__class__.__name__ == "LlamaForCausalLM":
607
+ self.lang_encoder.lm_head.requires_grad_(True)
608
+ # assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
609
+ # print model size in billions of parameters in 2 decimal places
610
+ print(f"Total Trainable param: {(sum(p.numel() for p in self.parameters() if p.requires_grad)) / 1e9:.3f} B")
611
+
612
+ def forward(
613
+ self,
614
+ vision_x: torch.Tensor,
615
+ lang_x: torch.Tensor,
616
+ attention_mask: Optional[torch.Tensor] = None,
617
+ labels: Optional[torch.Tensor] = None,
618
+ use_cached_vision_x: bool = False,
619
+ clear_conditioned_layers: bool = True,
620
+ past_key_values: Optional[torch.Tensor] = None,
621
+ use_cache: bool = False,
622
+ **kwargs,
623
+ ) -> CausalLMOutputWithPast:
624
+ """
625
+ Forward pass of Flamingo.
626
+
627
+ Args:
628
+ vision_x (torch.Tensor): Vision input
629
+ shape (B, T_img, F, C, H, W) with F=1
630
+ lang_x (torch.Tensor): Language input ids
631
+ shape (B, T_txt)
632
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
633
+ labels (torch.Tensor, optional): Labels. Defaults to None.
634
+ clear_conditioned_layers: if True, clear the conditioned layers
635
+ once the forward pass is completed. Set this to False if the
636
+ same set of images will be reused in another subsequent
637
+ forward pass.
638
+ past_key_values: pre-computed values to pass to language model.
639
+ See past_key_values documentation in Hugging Face
640
+ CausalLM models.
641
+ use_cache: whether to use cached key values. See use_cache
642
+ documentation in Hugging Face CausalLM models.
643
+ """
644
+ assert (vision_x is not None) or use_cached_vision_x, "Must provide vision_x or set use_cached_vision_x to True."
645
+
646
+ if use_cached_vision_x:
647
+ # Case: use cached; vision_x should be cached and other
648
+ # vision-related inputs should not be provided.
649
+ assert vision_x is None, "Expect vision_x to be None when use_cached_vision_x is True."
650
+ assert self.lang_encoder.is_conditioned()
651
+
652
+ else:
653
+ # Case: do not use caching (i.e. this is a standard forward pass);
654
+ self._encode_vision_x(vision_x=vision_x)
655
+
656
+ output = self.lang_encoder(
657
+ input_ids=lang_x,
658
+ attention_mask=attention_mask,
659
+ labels=labels,
660
+ past_key_values=past_key_values,
661
+ use_cache=use_cache,
662
+ **kwargs,
663
+ )
664
+
665
+ if clear_conditioned_layers:
666
+ self.lang_encoder.clear_conditioned_layers()
667
+
668
+ return output
669
+
670
+ def _encode_vision_x(self, vision_x: torch.Tensor):
671
+ """
672
+ Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
673
+ Args:
674
+ vision_x (torch.Tensor): Vision input
675
+ shape (B, T_img, F, C, H, W)
676
+ Images in the same chunk are collated along T_img, and frames are collated along F
677
+ Currently only F=1 is supported (single-frame videos)
678
+
679
+ rearrange code based on https://github.com/dhansmair/flamingo-mini
680
+ """
681
+
682
+ assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
683
+ b, T, F = vision_x.shape[:3]
684
+ assert F == 1, "Only single frame supported"
685
+
686
+ vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
687
+ with torch.no_grad():
688
+ vision_x = self.vision_encoder(vision_x)[0][:, 1:, :]
689
+ vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
690
+
691
+ vision_x = self.perceiver(vision_x) # reshapes to (b, T, n, d)
692
+
693
+ for layer in self.lang_encoder._get_decoder_layers():
694
+ layer.condition_vis_x(vision_x)
695
+
696
+
697
+ class FlamingoForConditionalGeneration(FlamingoPreTrainedModel):
698
+ config_class = FlamingoConfig
699
+
700
+ def __init__(
701
+ self,
702
+ config: FlamingoConfig,
703
+ ):
704
+ super().__init__(config)
705
+ # TODO: hard-coded here because loading via the Auto* classes is too slow
706
+ # vision_encoder = AutoModel.from_config(config.vision_config).vision_model
707
+ # lang_encoder = AutoModelForCausalLM.from_config(config.text_config)
708
+ # text_tokenizer = AutoTokenizer.from_pretrained(config.text_config._name_or_path)
709
+
710
+ ### TODO: give "LlamaForCausalLM" as the name of text_config.architectures of Llama_based flamingo
711
+ # assert hasattr(config.text_config, "_name_or_path")
712
+ # if "llama" not in config.text_config._name_or_path.lower():
713
+ if config.text_config.architectures[0] == "MPTForCausalLM":
714
+ text_tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b-instruct")
715
+ lang_encoder = MPTForCausalLM(config=config.text_config)
716
+ elif config.text_config.architectures[0] == "MosaicGPT":
717
+ text_tokenizer = AutoTokenizer.from_pretrained("mosaicml/mosaic-llama-redpajama-final-candidate")
718
+ lang_encoder = MosaicGPT(config=config.text_config)
719
+ elif config.text_config.architectures[0] == "RWForCausalLM":
720
+ text_tokenizer = AutoTokenizer.from_pretrained("PATH-TO-YOUR-FALCON")
721
+ lang_encoder = RWForCausalLM(config=config.text_config)
722
+ # TODO: what's the logic here?
723
+ elif config.text_config.architectures[0] == "LlamaForCausalLM":
724
+ text_tokenizer = LlamaTokenizer.from_pretrained(config.text_config._name_or_path)
725
+ lang_encoder = LlamaForCausalLM(config=config.text_config)
726
+ else:
727
+ import pdb
728
+
729
+ pdb.set_trace()
730
+ # else:
731
+ # text_tokenizer = LlamaTokenizer.from_pretrained(config.text_config._name_or_path)
732
+ # lang_encoder = LlamaForCausalLM(config=config.text_config)
733
+
734
+ vision_encoder = CLIPVisionModel(config=config.vision_config)
735
+ text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
736
+ if text_tokenizer.pad_token is None:
737
+ text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
738
+ self.text_tokenizer = text_tokenizer
739
+ self.eoc_token_id = text_tokenizer.encode("<|endofchunk|>")[-1]
740
+ self.media_token_id = text_tokenizer.encode("<image>")[-1]
741
+
742
+ extend_instance(lang_encoder, FlamingoLMMixin)
743
+ decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
744
+ lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
745
+ if "LlamaForCausalLM" in lang_encoder.__class__.__name__:
746
+ lang_encoder.resize_token_embeddings(len(text_tokenizer))
747
+ self.lang_encoder = lang_encoder
748
+
749
+ self.cross_attn_every_n_layers = config.cross_attn_every_n_layers if hasattr(config, "cross_attn_every_n_layers") else 4
750
+ self.use_media_placement_augmentation = config.use_media_placement_augmentation
751
+
752
+ vision_encoder.output_tokens = True
753
+ self.vision_encoder = vision_encoder
754
+
755
+ self.vis_dim = 1024
756
+ self.perceiver = FlamingoPerceiverResampler(dim=self.vis_dim)
757
+
758
+ self.lang_encoder.init_flamingo(
759
+ media_token_id=self.media_token_id,
760
+ vis_hidden_size=self.vis_dim,
761
+ cross_attn_every_n_layers=self.cross_attn_every_n_layers,
762
+ use_media_placement_augmentation=self.use_media_placement_augmentation,
763
+ )
764
+ self.post_init()
765
+
766
+ def get_input_embeddings(self) -> nn.Module:
767
+ return self.lang_encoder.get_input_embeddings()
768
+
769
+ def set_input_embeddings(self, new_embeddings):
770
+ self.lang_encoder.set_input_embeddings(new_embeddings)
771
+
772
+ def get_output_embeddings(self) -> nn.Module:
773
+ return self.lang_encoder.get_output_embeddings()
774
+
775
+ def set_output_embeddings(self, new_embeddings):
776
+ self.lang_encoder.set_output_embeddings(new_embeddings)
777
+
778
+ def get_image_encoder(self) -> nn.Module:
779
+ return self.vision_encoder
780
+
781
+ def get_lang_encoder(self) -> nn.Module:
782
+ return self.lang_encoder
783
+
784
+ def init_weights(self):
785
+ # Freeze all parameters in vision encoder
786
+ for param in self.vision_encoder.parameters():
787
+ param.requires_grad = False
788
+ # Freeze all parameters in lang encoders except gated_cross_attn_layers
789
+ for name, param in self.lang_encoder.named_parameters():
790
+ if "gated_cross_attn_layer" not in name:
791
+ param.requires_grad = False
792
+ # Unfreeze LM input embeddings
793
+ self.lang_encoder.get_input_embeddings().requires_grad_(True)
794
+ ## MPTForCausalLM uses tied word embeddings, so only Llama needs its lm_head unfrozen explicitly
795
+ if "LlamaForCausalLM" in self.lang_encoder.__class__.__name__:
796
+ self.lang_encoder.lm_head.requires_grad_(True)
797
+ # assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
798
+ # print model size in billions of parameters in 2 decimal places
799
+ print("====================Model Grad Part====================")
800
+ total_params = 0
801
+ for name, param in self.named_parameters():
802
+ if param.requires_grad:
803
+ total_params += param.numel()
804
+ print(f"Parameter: {name}, Size: {param.numel() / 1e6:.6f} M")
805
+ print(f"Total Trainable param: {total_params / 1e9:.4f} B")
806
+ print(f"Total Trainable param: {(sum(p.numel() for p in self.parameters() if p.requires_grad)) / 1e9:.3f} B")
807
+
808
+ def forward(
809
+ self,
810
+ vision_x: torch.Tensor,
811
+ lang_x: torch.Tensor,
812
+ attention_mask: Optional[torch.Tensor] = None,
813
+ labels: Optional[torch.Tensor] = None,
814
+ use_cached_vision_x: bool = False,
815
+ clear_conditioned_layers: bool = True,
816
+ past_key_values: Optional[torch.Tensor] = None,
817
+ use_cache: bool = False,
818
+ **kwargs,
819
+ ) -> CausalLMOutputWithPast:
820
+ """
821
+ Forward pass of Flamingo.
822
+
823
+ Args:
824
+ vision_x (torch.Tensor): Vision input
825
+ shape (B, T_img, F, C, H, W) with F=1
826
+ lang_x (torch.Tensor): Language input ids
827
+ shape (B, T_txt)
828
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
829
+ labels (torch.Tensor, optional): Labels. Defaults to None.
830
+ clear_conditioned_layers: if True, clear the conditioned layers
831
+ once the forward pass is completed. Set this to False if the
832
+ same set of images will be reused in another subsequent
833
+ forward pass.
834
+ past_key_values: pre-computed values to pass to language model.
835
+ See past_key_values documentation in Hugging Face
836
+ CausalLM models.
837
+ use_cache: whether to use cached key values. See use_cache
838
+ documentation in Hugging Face CausalLM models.
839
+ """
840
+ assert (vision_x is not None) or use_cached_vision_x, "Must provide vision_x or set use_cached_vision_x to True."
841
+
842
+ if use_cached_vision_x:
843
+ # Case: use cached; vision_x should be cached and other
844
+ # vision-related inputs should not be provided.
845
+ assert vision_x is None, "Expect vision_x to be None when use_cached_vision_x is True."
846
+ assert self.lang_encoder.is_conditioned()
847
+
848
+ else:
849
+ # Case: do not use caching (i.e. this is a standard forward pass);
850
+ self._encode_vision_x(vision_x=vision_x)
851
+
852
+ output = self.lang_encoder(
853
+ input_ids=lang_x,
854
+ attention_mask=attention_mask,
855
+ labels=labels,
856
+ past_key_values=past_key_values,
857
+ use_cache=use_cache,
858
+ **kwargs,
859
+ )
860
+
861
+ if clear_conditioned_layers:
862
+ self.lang_encoder.clear_conditioned_layers()
863
+
864
+ return output
865
+
866
+ def _encode_vision_x(self, vision_x: torch.Tensor):
867
+ """
868
+ Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
869
+ Args:
870
+ vision_x (torch.Tensor): Vision input
871
+ shape (B, T_img, F, C, H, W)
872
+ Images in the same chunk are collated along T_img, and frames are collated along F
873
+ Currently only F=1 is supported (single-frame videos)
874
+
875
+ rearrange code based on https://github.com/dhansmair/flamingo-mini
876
+ """
877
+
878
+ assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
879
+ b, T, F = vision_x.shape[:3]
880
+ # assert F == 1, "Only single frame supported"
881
+
882
+ vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
883
+ with torch.no_grad():
884
+ vision_x = self.vision_encoder(vision_x)[0][:, 1:, :]
885
+ vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
886
+
887
+ vision_x = self.perceiver(vision_x) # reshapes to (b, T, n, d)
888
+
889
+ for layer in self.lang_encoder._get_decoder_layers():
890
+ layer.condition_vis_x(vision_x)
891
+
892
+ @torch.no_grad()
893
+ def generate(
894
+ self,
895
+ vision_x: torch.Tensor,
896
+ lang_x: torch.Tensor,
897
+ attention_mask: Optional[torch.Tensor] = None,
898
+ num_beams: int = 1,
899
+ max_new_tokens: Optional[int] = None,
900
+ temperature: float = 1.0,
901
+ top_k: int = 0,
902
+ top_p: float = 1.0,
903
+ no_repeat_ngram_size: int = 0,
904
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
905
+ length_penalty: float = 1.0,
906
+ num_return_sequences: int = 1,
907
+ do_sample: bool = False,
908
+ early_stopping: bool = False,
909
+ **kwargs,
910
+ ):
911
+ """
912
+ Generate text conditioned on vision and language inputs.
913
+
914
+ Args:
915
+ vision_x (torch.Tensor): Vision input
916
+ shape (B, T_img, F, C, H, W)
917
+ images in the same chunk are collated along T_img, and frames are collated along F
918
+ currently only F=1 is supported (single-frame videos)
919
+ lang_x (torch.Tensor): Language input
920
+ shape (B, T_txt)
921
+ max_length (int, optional): Maximum length of the output. Defaults to None.
922
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
923
+ num_beams (int, optional): Number of beams. Defaults to 1.
924
+ max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
925
+ temperature (float, optional): Temperature. Defaults to 1.0.
926
+ top_k (int, optional): Top k. Defaults to 0.
927
+ top_p (float, optional): Top p. Defaults to 1.0.
928
+ no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
929
+ length_penalty (float, optional): Length penalty. Defaults to 1.0.
930
+ num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
931
+ do_sample (bool, optional): Do sample. Defaults to False.
932
+ early_stopping (bool, optional): Early stopping. Defaults to False.
933
+ Returns:
934
+ torch.Tensor: lang_x with generated tokens appended to it
935
+ """
936
+ if hasattr(self, "_hf_hook"):
937
+ # add a hook to make sure that the output of lang_encoder is mapped to the same device as the lang_x
938
+ hook = AlignDevicesHook(
939
+ execution_device=lang_x.device,
940
+ io_same_device=True,
941
+ place_submodules=False,
942
+ )
943
+ add_hook_to_module(self.lang_encoder, hook)
944
+ if num_beams > 1:
945
+ vision_x = vision_x.repeat_interleave(num_beams, dim=0)
946
+ self._encode_vision_x(vision_x=vision_x)
947
+ output = self.lang_encoder.generate(
948
+ lang_x,
949
+ attention_mask=attention_mask,
950
+ eos_token_id=self.eoc_token_id,
951
+ num_beams=num_beams,
952
+ max_new_tokens=max_new_tokens,
953
+ temperature=temperature,
954
+ top_k=top_k,
955
+ top_p=top_p,
956
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
957
+ no_repeat_ngram_size=no_repeat_ngram_size,
958
+ length_penalty=length_penalty,
959
+ num_return_sequences=num_return_sequences,
960
+ do_sample=do_sample,
961
+ early_stopping=early_stopping,
962
+ **kwargs,
963
+ )
964
+
965
+ self.lang_encoder.clear_conditioned_layers()
966
+ return output
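+ # Illustrative usage sketch (hypothetical checkpoint path and shapes, not part of the original file):
+ #   model = FlamingoForConditionalGeneration.from_pretrained("/path/to/flamingo-mpt-7B")
+ #   vision_x = torch.randn(1, 1, 1, 3, 224, 224)  # (B, T_img, F, C, H, W), a single image
+ #   lang_x = model.text_tokenizer("<image>An image of", return_tensors="pt")
+ #   out = model.generate(vision_x=vision_x, lang_x=lang_x["input_ids"],
+ #                        attention_mask=lang_x["attention_mask"], max_new_tokens=32, num_beams=3)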
mllm/flamingo/mpt/__init__.py ADDED
File without changes
mllm/flamingo/mpt/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (206 Bytes). View file
 
mllm/flamingo/mpt/__pycache__/attention.cpython-39.pyc ADDED
Binary file (12.2 kB). View file
 
mllm/flamingo/mpt/__pycache__/blocks.cpython-39.pyc ADDED
Binary file (2.81 kB). View file
 
mllm/flamingo/mpt/__pycache__/configuration_mpt.cpython-39.pyc ADDED
Binary file (8.76 kB). View file
 
mllm/flamingo/mpt/__pycache__/custom_embedding.cpython-39.pyc ADDED
Binary file (797 Bytes). View file
 
mllm/flamingo/mpt/__pycache__/flash_attn_triton.cpython-39.pyc ADDED
Binary file (20.9 kB). View file
 
mllm/flamingo/mpt/__pycache__/modeling_mpt.cpython-39.pyc ADDED
Binary file (15.3 kB). View file
 
mllm/flamingo/mpt/__pycache__/norm.cpython-39.pyc ADDED
Binary file (3.03 kB). View file
 
mllm/flamingo/mpt/__pycache__/param_init_fns.cpython-39.pyc ADDED
Binary file (9.14 kB). View file
 
mllm/flamingo/mpt/adapt_tokenizer.py ADDED
@@ -0,0 +1,44 @@
1
+ from typing import Union
2
+ from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
3
+
4
+ Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
5
+ NUM_SENTINEL_TOKENS: int = 100
6
+
7
+
8
+ def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
9
+ """Adds sentinel tokens and padding token (if missing).
10
+
11
+ Expands the tokenizer vocabulary to include sentinel tokens
12
+ used in mixture-of-denoiser tasks as well as a padding token.
13
+
14
+ All added tokens are added as special tokens. No tokens are
15
+ added if sentinel tokens and padding token already exist.
16
+ """
17
+ sentinels_to_add = [f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)]
18
+ tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
19
+ if tokenizer.pad_token is None:
20
+ tokenizer.add_tokens("<pad>", special_tokens=True)
21
+ tokenizer.pad_token = "<pad>"
22
+ assert tokenizer.pad_token_id is not None
23
+ sentinels = "".join([f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)])
24
+ _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
25
+ tokenizer.sentinel_token_ids = _sentinel_token_ids
26
+
27
+
28
+ class AutoTokenizerForMOD(AutoTokenizer):
29
+ """AutoTokenizer + Adaptation for MOD.
30
+
31
+ A simple wrapper around AutoTokenizer to make instantiating
32
+ an MOD-adapted tokenizer a bit easier.
33
+
34
+ MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>),
35
+ a padding token, and a property to get the token ids of the
36
+ sentinel tokens.
37
+ """
38
+
39
+ @classmethod
40
+ def from_pretrained(cls, *args, **kwargs):
41
+ """See `AutoTokenizer.from_pretrained` docstring."""
42
+ tokenizer = super().from_pretrained(*args, **kwargs)
43
+ adapt_tokenizer_for_denoising(tokenizer)
44
+ return tokenizer
mllm/flamingo/mpt/attention.py ADDED
@@ -0,0 +1,450 @@
1
+ """Attention layers."""
2
+ import math
3
+ import warnings
4
+ from typing import Optional
5
+ import torch
6
+ import torch.nn as nn
7
+ from einops import rearrange
8
+ from packaging import version
9
+ from torch import nn
10
+ from .norm import LPLayerNorm
11
+
12
+
13
+ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool):
14
+ if original_is_causal and num_query_tokens != num_key_tokens:
15
+ if num_query_tokens != 1:
16
+ raise NotImplementedError("MPT does not support query and key with different number of tokens, unless number of query tokens is 1.")
17
+ else:
18
+ return False
19
+ return original_is_causal
20
+
21
+
22
+ def scaled_multihead_dot_product_attention(
23
+ query,
24
+ key,
25
+ value,
26
+ n_heads,
27
+ past_key_value=None,
28
+ softmax_scale=None,
29
+ attn_bias=None,
30
+ key_padding_mask=None,
31
+ is_causal=False,
32
+ dropout_p=0.0,
33
+ training=False,
34
+ needs_weights=False,
35
+ multiquery=False,
36
+ ):
37
+ q = rearrange(query, "b s (h d) -> b h s d", h=n_heads)
38
+ kv_n_heads = 1 if multiquery else n_heads
39
+ k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads)
40
+ v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads)
41
+ if past_key_value is not None:
42
+ if len(past_key_value) != 0:
43
+ k = torch.cat([past_key_value[0], k], dim=3)
44
+ v = torch.cat([past_key_value[1], v], dim=2)
45
+ past_key_value = (k, v)
46
+ (b, _, s_q, d) = q.shape
47
+ s_k = k.size(-1)
48
+ if softmax_scale is None:
49
+ softmax_scale = 1 / math.sqrt(d)
50
+ attn_weight = q.matmul(k) * softmax_scale
51
+ if attn_bias is not None:
52
+ _s_q = max(0, attn_bias.size(2) - s_q)
53
+ _s_k = max(0, attn_bias.size(3) - s_k)
54
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
55
+ if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
56
+ raise RuntimeError(f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.")
57
+ attn_weight = attn_weight + attn_bias
58
+ min_val = torch.finfo(q.dtype).min
59
+ if key_padding_mask is not None:
60
+ if attn_bias is not None:
61
+ warnings.warn(
62
+ "Propogating key_padding_mask to the attention module "
63
+ + "and applying it within the attention module can cause "
64
+ + "unneccessary computation/memory usage. Consider integrating "
65
+ + "into attn_bias once and passing that to each attention "
66
+ + "module instead."
67
+ )
68
+ attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
69
+ if is_causal and (not q.size(2) == 1):
70
+ s = max(s_q, s_k)
71
+ causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
72
+ causal_mask = causal_mask.tril()
73
+ causal_mask = causal_mask.to(torch.bool)
74
+ causal_mask = ~causal_mask
75
+ causal_mask = causal_mask[-s_q:, -s_k:]
76
+ attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
77
+ attn_weight = torch.softmax(attn_weight, dim=-1)
78
+ if dropout_p:
79
+ attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
80
+ out = attn_weight.to(v.dtype).matmul(v)
81
+ out = rearrange(out, "b h s d -> b s (h d)")
82
+ if needs_weights:
83
+ return (out, attn_weight, past_key_value)
84
+ return (out, None, past_key_value)
85
+
86
+
87
+ def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
88
+ for tensor in tensors:
89
+ if tensor.dtype not in valid_dtypes:
90
+ raise TypeError(f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.")
91
+ if not tensor.is_cuda:
92
+ raise TypeError(f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).")
93
+
94
+
95
+ def flash_attn_fn(
96
+ query,
97
+ key,
98
+ value,
99
+ n_heads,
100
+ past_key_value=None,
101
+ softmax_scale=None,
102
+ attn_bias=None,
103
+ key_padding_mask=None,
104
+ is_causal=False,
105
+ dropout_p=0.0,
106
+ training=False,
107
+ needs_weights=False,
108
+ multiquery=False,
109
+ ):
110
+ try:
111
+ from flash_attn import bert_padding, flash_attn_interface
112
+ except:
113
+ raise RuntimeError("Please install flash-attn==1.0.3.post0")
114
+ check_valid_inputs(query, key, value)
115
+ if past_key_value is not None:
116
+ if len(past_key_value) != 0:
117
+ key = torch.cat([past_key_value[0], key], dim=1)
118
+ value = torch.cat([past_key_value[1], value], dim=1)
119
+ past_key_value = (key, value)
120
+ if attn_bias is not None:
121
+ _s_q = max(0, attn_bias.size(2) - query.size(1))
122
+ _s_k = max(0, attn_bias.size(3) - key.size(1))
123
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
124
+ if attn_bias is not None:
125
+ raise NotImplementedError(f"attn_bias not implemented for flash attn.")
126
+ (batch_size, seqlen) = query.shape[:2]
127
+ if key_padding_mask is None:
128
+ key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
129
+ query_padding_mask = key_padding_mask[:, -query.size(1) :]
130
+ (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask)
131
+ query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads)
132
+ (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask)
133
+ key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads)
134
+ (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
135
+ value_unpad = rearrange(value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads)
136
+ if multiquery:
137
+ key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
138
+ value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1))
139
+ dropout_p = dropout_p if training else 0.0
140
+ reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
141
+ output_unpad = flash_attn_interface.flash_attn_unpadded_func(
142
+ query_unpad,
143
+ key_unpad,
144
+ value_unpad,
145
+ cu_seqlens_q,
146
+ cu_seqlens_k,
147
+ max_seqlen_q,
148
+ max_seqlen_k,
149
+ dropout_p,
150
+ softmax_scale=softmax_scale,
151
+ causal=reset_is_causal,
152
+ return_attn_probs=needs_weights,
153
+ )
154
+ output = bert_padding.pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen)
155
+ return (output, None, past_key_value)
156
+
157
+
158
+ def triton_flash_attn_fn(
159
+ query,
160
+ key,
161
+ value,
162
+ n_heads,
163
+ past_key_value=None,
164
+ softmax_scale=None,
165
+ attn_bias=None,
166
+ key_padding_mask=None,
167
+ is_causal=False,
168
+ dropout_p=0.0,
169
+ training=False,
170
+ needs_weights=False,
171
+ multiquery=False,
172
+ ):
173
+ try:
174
+ from .flash_attn_triton import flash_attn_func
175
+ except:
176
+ _installed = False
177
+ if version.parse(torch.__version__) < version.parse("2.0.0"):
178
+ _installed = True
179
+ try:
180
+ from flash_attn.flash_attn_triton import flash_attn_func
181
+ except:
182
+ _installed = False
183
+ if not _installed:
184
+ raise RuntimeError(
185
+ "Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed."
186
+ )
187
+ check_valid_inputs(query, key, value)
188
+ if past_key_value is not None:
189
+ if len(past_key_value) != 0:
190
+ key = torch.cat([past_key_value[0], key], dim=1)
191
+ value = torch.cat([past_key_value[1], value], dim=1)
192
+ past_key_value = (key, value)
193
+ if attn_bias is not None:
194
+ _s_q = max(0, attn_bias.size(2) - query.size(1))
195
+ _s_k = max(0, attn_bias.size(3) - key.size(1))
196
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
197
+ if dropout_p:
198
+ raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.")
199
+ if needs_weights:
200
+ raise NotImplementedError(f"attn_impl: triton cannot return attn weights.")
201
+ if key_padding_mask is not None:
202
+ warnings.warn(
203
+ "Propagating key_padding_mask to the attention module "
204
+ + "and applying it within the attention module can cause "
205
+ + "unnecessary computation/memory usage. Consider integrating "
206
+ + "into attn_bias once and passing that to each attention "
207
+ + "module instead."
208
+ )
209
+ (b_size, s_k) = key_padding_mask.shape[:2]
210
+ if attn_bias is None:
211
+ attn_bias = query.new_zeros(b_size, 1, 1, s_k)
212
+ attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min)
213
+ query = rearrange(query, "b s (h d) -> b s h d", h=n_heads)
214
+ key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
215
+ value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
216
+ if multiquery:
217
+ key = key.expand(*key.shape[:2], n_heads, key.size(-1))
218
+ value = value.expand(*value.shape[:2], n_heads, value.size(-1))
219
+ reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
220
+ attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
221
+ output = attn_output.view(*attn_output.shape[:2], -1)
222
+ return (output, None, past_key_value)
223
+
224
+
225
+ class MultiheadAttention(nn.Module):
226
+ """Multi-head self attention.
227
+
228
+ Using the torch or triton attention implementation enables the user to also use
229
+ additive bias.
230
+ """
231
+
232
+ def __init__(
233
+ self,
234
+ d_model: int,
235
+ n_heads: int,
236
+ attn_impl: str = "triton",
237
+ clip_qkv: Optional[float] = None,
238
+ qk_ln: bool = False,
239
+ softmax_scale: Optional[float] = None,
240
+ attn_pdrop: float = 0.0,
241
+ low_precision_layernorm: bool = False,
242
+ verbose: int = 0,
243
+ device: Optional[str] = None,
244
+ ):
245
+ super().__init__()
246
+ self.attn_impl = attn_impl
247
+ self.clip_qkv = clip_qkv
248
+ self.qk_ln = qk_ln
249
+ self.d_model = d_model
250
+ self.n_heads = n_heads
251
+ self.softmax_scale = softmax_scale
252
+ if self.softmax_scale is None:
253
+ self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
254
+ self.attn_dropout_p = attn_pdrop
255
+ self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
256
+ fuse_splits = (d_model, 2 * d_model)
257
+ self.Wqkv._fused = (0, fuse_splits)
258
+ if self.qk_ln:
259
+ layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
260
+ self.q_ln = layernorm_class(self.d_model, device=device)
261
+ self.k_ln = layernorm_class(self.d_model, device=device)
262
+ if self.attn_impl == "flash":
263
+ self.attn_fn = flash_attn_fn
264
+ elif self.attn_impl == "triton":
265
+ self.attn_fn = triton_flash_attn_fn
266
+ if verbose:
267
+ warnings.warn(
268
+ "While `attn_impl: triton` can be faster than `attn_impl: flash` "
269
+ + "it uses more memory. When training larger models this can trigger "
270
+ + "alloc retries which hurts performance. If encountered, we recommend "
271
+ + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
272
+ )
273
+ elif self.attn_impl == "torch":
274
+ self.attn_fn = scaled_multihead_dot_product_attention
275
+ if torch.cuda.is_available() and verbose:
276
+ warnings.warn(
277
+ "Using `attn_impl: torch`. If your model does not use `alibi` or "
278
+ + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
279
+ + "we recommend using `attn_impl: triton`."
280
+ )
281
+ else:
282
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
283
+ self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
284
+ self.out_proj._is_residual = True
285
+
286
+ def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
287
+ qkv = self.Wqkv(x)
288
+ if self.clip_qkv:
289
+ qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
290
+ (query, key, value) = qkv.chunk(3, dim=2)
291
+ key_padding_mask = attention_mask
292
+ if self.qk_ln:
293
+ dtype = query.dtype
294
+ query = self.q_ln(query).to(dtype)
295
+ key = self.k_ln(key).to(dtype)
296
+ (context, attn_weights, past_key_value) = self.attn_fn(
297
+ query,
298
+ key,
299
+ value,
300
+ self.n_heads,
301
+ past_key_value=past_key_value,
302
+ softmax_scale=self.softmax_scale,
303
+ attn_bias=attn_bias,
304
+ key_padding_mask=key_padding_mask,
305
+ is_causal=is_causal,
306
+ dropout_p=self.attn_dropout_p,
307
+ training=self.training,
308
+ needs_weights=needs_weights,
309
+ )
310
+ return (self.out_proj(context), attn_weights, past_key_value)
311
+
312
+
313
+ class MultiQueryAttention(nn.Module):
314
+ """Multi-Query self attention.
315
+
316
+ Using the torch or triton attention implementation enables the user to also use
317
+ additive bias.
318
+ """
319
+
320
+ def __init__(
321
+ self,
322
+ d_model: int,
323
+ n_heads: int,
324
+ attn_impl: str = "triton",
325
+ clip_qkv: Optional[float] = None,
326
+ qk_ln: bool = False,
327
+ softmax_scale: Optional[float] = None,
328
+ attn_pdrop: float = 0.0,
329
+ low_precision_layernorm: bool = False,
330
+ verbose: int = 0,
331
+ device: Optional[str] = None,
332
+ ):
333
+ super().__init__()
334
+ self.attn_impl = attn_impl
335
+ self.clip_qkv = clip_qkv
336
+ self.qk_ln = qk_ln
337
+ self.d_model = d_model
338
+ self.n_heads = n_heads
339
+ self.head_dim = d_model // n_heads
340
+ self.softmax_scale = softmax_scale
341
+ if self.softmax_scale is None:
342
+ self.softmax_scale = 1 / math.sqrt(self.head_dim)
343
+ self.attn_dropout_p = attn_pdrop
344
+ self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
345
+ fuse_splits = (d_model, d_model + self.head_dim)
346
+ self.Wqkv._fused = (0, fuse_splits)
347
+ if self.qk_ln:
348
+ layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
349
+ self.q_ln = layernorm_class(d_model, device=device)
350
+ self.k_ln = layernorm_class(self.head_dim, device=device)
351
+ if self.attn_impl == "flash":
352
+ self.attn_fn = flash_attn_fn
353
+ elif self.attn_impl == "triton":
354
+ self.attn_fn = triton_flash_attn_fn
355
+ if verbose:
356
+ warnings.warn(
357
+ "While `attn_impl: triton` can be faster than `attn_impl: flash` "
358
+ + "it uses more memory. When training larger models this can trigger "
359
+ + "alloc retries which hurts performance. If encountered, we recommend "
360
+ + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
361
+ )
362
+ elif self.attn_impl == "torch":
363
+ self.attn_fn = scaled_multihead_dot_product_attention
364
+ if torch.cuda.is_available() and verbose:
365
+ warnings.warn(
366
+ "Using `attn_impl: torch`. If your model does not use `alibi` or "
367
+ + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
368
+ + "we recommend using `attn_impl: triton`."
369
+ )
370
+ else:
371
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
372
+ self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
373
+ self.out_proj._is_residual = True
374
+
375
+ def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
376
+ qkv = self.Wqkv(x)
377
+ if self.clip_qkv:
378
+ qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
379
+ (query, key, value) = qkv.split([self.d_model, self.head_dim, self.head_dim], dim=2)
380
+ key_padding_mask = attention_mask
381
+ if self.qk_ln:
382
+ dtype = query.dtype
383
+ query = self.q_ln(query).to(dtype)
384
+ key = self.k_ln(key).to(dtype)
385
+ (context, attn_weights, past_key_value) = self.attn_fn(
386
+ query,
387
+ key,
388
+ value,
389
+ self.n_heads,
390
+ past_key_value=past_key_value,
391
+ softmax_scale=self.softmax_scale,
392
+ attn_bias=attn_bias,
393
+ key_padding_mask=key_padding_mask,
394
+ is_causal=is_causal,
395
+ dropout_p=self.attn_dropout_p,
396
+ training=self.training,
397
+ needs_weights=needs_weights,
398
+ multiquery=True,
399
+ )
400
+ return (self.out_proj(context), attn_weights, past_key_value)
401
+
402
+
403
+ def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
404
+ if attn_impl == "flash":
405
+ return None
406
+ elif attn_impl in ["torch", "triton"]:
407
+ if alibi:
408
+ if (prefix_lm or not causal) or use_sequence_id:
409
+ return (1, n_heads, seq_len, seq_len)
410
+ return (1, n_heads, 1, seq_len)
411
+ elif prefix_lm or use_sequence_id:
412
+ return (1, 1, seq_len, seq_len)
413
+ return None
414
+ else:
415
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
416
+
417
+
418
+ def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8):
419
+ if attn_impl == "flash":
420
+ return None
421
+ elif attn_impl in ["torch", "triton"]:
422
+ if alibi:
423
+ (device, dtype) = (attn_bias.device, attn_bias.dtype)
424
+ attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype))
425
+ return attn_bias
426
+ else:
427
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
428
+
429
+
430
+ def gen_slopes(n_heads, alibi_bias_max=8, device=None):
431
+ _n_heads = 2 ** math.ceil(math.log2(n_heads))
432
+ m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
433
+ m = m.mul(alibi_bias_max / _n_heads)
434
+ slopes = 1.0 / torch.pow(2, m)
435
+ if _n_heads != n_heads:
436
+ slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
437
+ return slopes.view(1, n_heads, 1, 1)
438
+
439
+
440
+ def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None):
441
+ alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len)
442
+ if full:
443
+ alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, seq_len, 1)
444
+ alibi_bias = alibi_bias.abs().mul(-1)
445
+ slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
446
+ alibi_bias = alibi_bias * slopes
447
+ return alibi_bias.to(dtype=dtype)
448
+
449
+
450
+ ATTN_CLASS_REGISTRY = {"multihead_attention": MultiheadAttention, "multiquery_attention": MultiQueryAttention}
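
The registry above is what wires these attention classes into the MPT blocks. As a quick sanity check, here is a minimal usage sketch (not part of the committed file): it exercises the pure-torch attention path on CPU together with an ALiBi bias built from the helpers above. All shapes and hyperparameters are illustrative, and the `mllm.flamingo.mpt.attention` import path assumes the repository root is on `PYTHONPATH`.

```python
# Illustrative sketch only: torch attention path + ALiBi bias, on CPU.
import torch
from mllm.flamingo.mpt.attention import (
    MultiheadAttention, attn_bias_shape, build_attn_bias,
)

d_model, n_heads, seq_len, batch = 64, 4, 8, 2
attn = MultiheadAttention(d_model=d_model, n_heads=n_heads, attn_impl="torch")

# ALiBi: allocate the broadcastable bias shape, then fill it with the per-head slopes.
shape = attn_bias_shape("torch", n_heads, seq_len, alibi=True,
                        prefix_lm=False, causal=True, use_sequence_id=False)
attn_bias = build_attn_bias("torch", torch.zeros(shape), n_heads, seq_len,
                            causal=True, alibi=True)

x = torch.randn(batch, seq_len, d_model)
out, weights, past = attn(x, attn_bias=attn_bias, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 64])
```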
mllm/flamingo/mpt/blocks.py ADDED
@@ -0,0 +1,82 @@
1
+ """GPT Blocks used for the GPT Model."""
2
+ from typing import Dict, Optional, Tuple
3
+ import torch
4
+ import torch.nn as nn
5
+ from .attention import ATTN_CLASS_REGISTRY
6
+ from .norm import NORM_CLASS_REGISTRY
7
+
8
+
9
+ class MPTMLP(nn.Module):
10
+ def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str] = None):
11
+ super().__init__()
12
+ self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
13
+ # NOTE (yh): hard-coded to the default GELU instead of nn.GELU(approximate='none')
14
+ # self.act = nn.GELU(approximate='none')
15
+ self.act = nn.GELU()
16
+ self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
17
+ self.down_proj._is_residual = True
18
+
19
+ def forward(self, x):
20
+ return self.down_proj(self.act(self.up_proj(x)))
21
+
22
+
23
+ class MPTBlock(nn.Module):
24
+ def __init__(
25
+ self,
26
+ d_model: int,
27
+ n_heads: int,
28
+ expansion_ratio: int,
29
+ attn_config: Dict = {
30
+ "attn_type": "multihead_attention",
31
+ "attn_pdrop": 0.0,
32
+ "attn_impl": "triton",
33
+ "qk_ln": False,
34
+ "clip_qkv": None,
35
+ "softmax_scale": None,
36
+ "prefix_lm": False,
37
+ "attn_uses_sequence_id": False,
38
+ "alibi": False,
39
+ "alibi_bias_max": 8,
40
+ },
41
+ resid_pdrop: float = 0.0,
42
+ norm_type: str = "low_precision_layernorm",
43
+ verbose: int = 0,
44
+ device: Optional[str] = None,
45
+ **kwargs
46
+ ):
47
+ del kwargs
48
+ super().__init__()
49
+ norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
50
+ attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
51
+ self.norm_1 = norm_class(d_model, device=device)
52
+ self.attn = attn_class(
53
+ attn_impl=attn_config["attn_impl"],
54
+ clip_qkv=attn_config["clip_qkv"],
55
+ qk_ln=attn_config["qk_ln"],
56
+ softmax_scale=attn_config["softmax_scale"],
57
+ attn_pdrop=attn_config["attn_pdrop"],
58
+ d_model=d_model,
59
+ n_heads=n_heads,
60
+ verbose=verbose,
61
+ device=device,
62
+ )
63
+ self.norm_2 = norm_class(d_model, device=device)
64
+ self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
65
+ self.resid_attn_dropout = nn.Dropout(resid_pdrop)
66
+ self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
67
+
68
+ def forward(
69
+ self,
70
+ x: torch.Tensor,
71
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
72
+ attn_bias: Optional[torch.Tensor] = None,
73
+ attention_mask: Optional[torch.ByteTensor] = None,
74
+ is_causal: bool = True,
75
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
76
+ a = self.norm_1(x)
77
+ (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
78
+ x = x + self.resid_attn_dropout(b)
79
+ m = self.norm_2(x)
80
+ n = self.ffn(m)
81
+ x = x + self.resid_ffn_dropout(n)
82
+ return (x, attn_weights, past_key_value)
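
For reference, a small end-to-end sketch of a single block (again not part of the diff; sizes are assumed, and the `torch` attention implementation is chosen so it runs on CPU without the triton/flash-attn extras):

```python
# Illustrative sketch: one MPTBlock forward pass using the torch attention path.
import torch
from mllm.flamingo.mpt.blocks import MPTBlock

block = MPTBlock(
    d_model=64,
    n_heads=4,
    expansion_ratio=4,
    attn_config={
        "attn_type": "multihead_attention",
        "attn_impl": "torch",   # avoids the triton/flash-attn dependencies
        "attn_pdrop": 0.0,
        "qk_ln": False,
        "clip_qkv": None,
        "softmax_scale": None,
    },
)
x = torch.randn(2, 8, 64)                        # (batch, seq_len, d_model)
y, attn_weights, past = block(x, is_causal=True)
print(y.shape)                                    # torch.Size([2, 8, 64])
```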
mllm/flamingo/mpt/configuration_mpt.py ADDED
@@ -0,0 +1,161 @@
1
+ """A HuggingFace-style model configuration."""
2
+ from typing import Dict, Optional, Union
3
+ from transformers import PretrainedConfig
4
+
5
+ attn_config_defaults: Dict = {
6
+ "attn_type": "multihead_attention",
7
+ "attn_pdrop": 0.0,
8
+ "attn_impl": "triton",
9
+ "qk_ln": False,
10
+ "clip_qkv": None,
11
+ "softmax_scale": None,
12
+ "prefix_lm": False,
13
+ "attn_uses_sequence_id": False,
14
+ "alibi": False,
15
+ "alibi_bias_max": 8,
16
+ }
17
+ init_config_defaults: Dict = {
18
+ "name": "kaiming_normal_",
19
+ "fan_mode": "fan_in",
20
+ "init_nonlinearity": "relu",
21
+ "init_div_is_residual": True,
22
+ "emb_init_std": None,
23
+ "emb_init_uniform_lim": None,
24
+ "init_std": None,
25
+ "init_gain": 0.0,
26
+ }
27
+
28
+
29
+ class MPTConfig(PretrainedConfig):
30
+ model_type = "mpt"
31
+
32
+ def __init__(
33
+ self,
34
+ d_model: int = 2048,
35
+ n_heads: int = 16,
36
+ n_layers: int = 24,
37
+ expansion_ratio: int = 4,
38
+ max_seq_len: int = 2048,
39
+ vocab_size: int = 50368,
40
+ resid_pdrop: float = 0.0,
41
+ emb_pdrop: float = 0.0,
42
+ learned_pos_emb: bool = True,
43
+ attn_config: Dict = attn_config_defaults,
44
+ init_device: str = "cpu",
45
+ logit_scale: Optional[Union[float, str]] = None,
46
+ no_bias: bool = False,
47
+ verbose: int = 0,
48
+ embedding_fraction: float = 1.0,
49
+ norm_type: str = "low_precision_layernorm",
50
+ use_cache: bool = False,
51
+ init_config: Dict = init_config_defaults,
52
+ **kwargs,
53
+ ):
54
+ """The MPT configuration class.
55
+
56
+ Args:
57
+ d_model (int): The size of the embedding dimension of the model.
58
+ n_heads (int): The number of attention heads.
59
+ n_layers (int): The number of layers in the model.
60
+ expansion_ratio (int): The ratio of the up/down scale in the MLP.
61
+ max_seq_len (int): The maximum sequence length of the model.
62
+ vocab_size (int): The size of the vocabulary.
63
+ resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
64
+ emb_pdrop (float): The dropout probability for the embedding layer.
65
+ learned_pos_emb (bool): Whether to use learned positional embeddings
66
+ attn_config (Dict): A dictionary used to configure the model's attention module:
67
+ attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention
68
+ attn_pdrop (float): The dropout probability for the attention layers.
69
+ attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
70
+ qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
71
+ clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
72
+ this value.
73
+ softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
74
+ use the default scale of ``1/sqrt(d_keys)``.
75
+ prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
76
+ extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
77
+ can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
78
+ attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
79
+ When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
80
+ which sub-sequence each token belongs to.
81
+ Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
82
+ alibi (bool): Whether to use the alibi bias instead of position embeddings.
83
+ alibi_bias_max (int): The maximum value of the alibi bias.
84
+ init_device (str): The device to use for parameter initialization.
85
+ logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
86
+ no_bias (bool): Whether to use bias in all layers.
87
+ verbose (int): The verbosity level. 0 is silent.
88
+ embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
89
+ norm_type (str): The type of norm layer to use.
90
+ multiquery_attention (bool): Whether to use multiquery attention implementation.
91
+ use_cache (bool): Whether or not the model should return the last key/values attentions
92
+ init_config (Dict): A dictionary used to configure the model initialization:
93
+ init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
94
+ 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
95
+ 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
96
+ init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
97
+ emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
98
+ emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
99
+ used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
100
+ init_std (float): The standard deviation of the normal distribution used to initialize the model,
101
+ if using the baseline_ parameter initialization scheme.
102
+ init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
103
+ fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
104
+ init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
105
+ ---
106
+ See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
107
+ """
108
+ self.d_model = d_model
109
+ self.n_heads = n_heads
110
+ self.n_layers = n_layers
111
+ self.expansion_ratio = expansion_ratio
112
+ self.max_seq_len = max_seq_len
113
+ self.vocab_size = vocab_size
114
+ self.resid_pdrop = resid_pdrop
115
+ self.emb_pdrop = emb_pdrop
116
+ self.learned_pos_emb = learned_pos_emb
117
+ self.attn_config = attn_config
118
+ self.init_device = init_device
119
+ self.logit_scale = logit_scale
120
+ self.no_bias = no_bias
121
+ self.verbose = verbose
122
+ self.embedding_fraction = embedding_fraction
123
+ self.norm_type = norm_type
124
+ self.use_cache = use_cache
125
+ self.init_config = init_config
126
+ if "name" in kwargs:
127
+ del kwargs["name"]
128
+ if "loss_fn" in kwargs:
129
+ del kwargs["loss_fn"]
130
+ super().__init__(**kwargs)
131
+ self._validate_config()
132
+
133
+ def _set_config_defaults(self, config, config_defaults):
134
+ for k, v in config_defaults.items():
135
+ if k not in config:
136
+ config[k] = v
137
+ return config
138
+
139
+ def _validate_config(self):
140
+ self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
141
+ self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
142
+ if self.d_model % self.n_heads != 0:
143
+ raise ValueError("d_model must be divisible by n_heads")
144
+ if any((prob < 0 or prob > 1 for prob in [self.attn_config["attn_pdrop"], self.resid_pdrop, self.emb_pdrop])):
145
+ raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
146
+ if self.attn_config["attn_impl"] not in ["torch", "flash", "triton"]:
147
+ raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
148
+ if self.attn_config["prefix_lm"] and self.attn_config["attn_impl"] not in ["torch", "triton"]:
149
+ raise NotImplementedError("prefix_lm only implemented with torch and triton attention.")
150
+ if self.attn_config["alibi"] and self.attn_config["attn_impl"] not in ["torch", "triton"]:
151
+ raise NotImplementedError("alibi only implemented with torch and triton attention.")
152
+ if self.attn_config["attn_uses_sequence_id"] and self.attn_config["attn_impl"] not in ["torch", "triton"]:
153
+ raise NotImplementedError("attn_uses_sequence_id only implemented with torch and triton attention.")
154
+ if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
155
+ raise ValueError("model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!")
156
+ if isinstance(self.logit_scale, str) and self.logit_scale != "inv_sqrt_d_model":
157
+ raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
158
+ if self.init_config.get("name", None) is None:
159
+ raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
160
+ if not self.learned_pos_emb and (not self.attn_config["alibi"]):
161
+ raise ValueError(f"Positional information must be provided to the model using either learned_pos_emb or alibi.")
mllm/flamingo/mpt/custom_embedding.py ADDED
@@ -0,0 +1,11 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch import Tensor
5
+
6
+
7
+ class SharedEmbedding(nn.Embedding):
8
+ def forward(self, input: Tensor, unembed: bool = False) -> Tensor:
9
+ if unembed:
10
+ return F.linear(input, self.weight)
11
+ return super().forward(input)
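
`SharedEmbedding` is a thin wrapper around `nn.Embedding` that doubles as the tied output projection. A minimal sketch of both directions (illustrative sizes, not from the commit):

```python
# Illustrative sketch: one weight matrix embeds token ids and, with unembed=True,
# projects hidden states back to vocabulary logits (weight tying).
import torch
from mllm.flamingo.mpt.custom_embedding import SharedEmbedding

emb = SharedEmbedding(num_embeddings=100, embedding_dim=16)
tokens = torch.randint(0, 100, (2, 8))
hidden = emb(tokens)                 # (2, 8, 16)  -- standard embedding lookup
logits = emb(hidden, unembed=True)   # (2, 8, 100) -- F.linear with the same weight
print(hidden.shape, logits.shape)
```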
mllm/flamingo/mpt/flash_attn_triton.py ADDED
@@ -0,0 +1,841 @@
1
+ """
2
+ Copied from https://github.com/HazyResearch/flash-attention/blob/eff9fe6b8076df59d64d7a3f464696738a3c7c24/flash_attn/flash_attn_triton.py
3
+ update imports to use 'triton_pre_mlir'
4
+
5
+ *Experimental* implementation of FlashAttention in Triton.
6
+ Tested with triton==2.0.0.dev20221202.
7
+ Triton 2.0 has a new backend (MLIR), but it doesn't yet seem to work for head dimensions
8
+ other than 64:
9
+ https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
10
+ We'll update this implementation with the new Triton backend once this is fixed.
11
+
12
+ We use the FlashAttention implementation from Phil Tillet as a starting point.
13
+ https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
14
+
15
+ Changes:
16
+ - Implement both causal and non-causal attention.
17
+ - Implement both self-attention and cross-attention.
18
+ - Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
19
+ - Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
20
+ - Support attention bias.
21
+ - Speed up the forward pass a bit, and only store the LSE instead of m and l.
22
+ - Make the backward for d=128 much faster by reducing register spilling.
23
+ - Optionally parallelize the backward pass across seqlen_k, to deal with the case of
24
+ small batch size * nheads.
25
+
26
+ Caution:
27
+ - This is an *experimental* implementation. The forward pass should be quite robust but
28
+ I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
29
+ - This implementation has only been tested on A100.
30
+ - If you plan to use headdim other than 64 and 128, you should test for race conditions
31
+ (due to the Triton compiler), as done in tests/test_flash_attn.py
32
+ "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
33
+ for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
34
+ that there are none left for other head dimensions.
35
+
36
+ Differences between this Triton version and the CUDA version:
37
+ - Triton version doesn't support dropout.
38
+ - Triton forward is generally faster than CUDA forward, while Triton backward is
39
+ generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
40
+ than CUDA forward + backward.
41
+ - Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
42
+ - Triton version supports attention bias, while CUDA version doesn't.
43
+ """
44
+ import math
45
+ import torch
46
+ import triton_pre_mlir as triton
47
+ import triton_pre_mlir.language as tl
48
+
49
+
50
+ @triton.heuristics(
51
+ {
52
+ "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
53
+ "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
54
+ "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
55
+ }
56
+ )
57
+ @triton.jit
58
+ def _fwd_kernel(
59
+ Q,
60
+ K,
61
+ V,
62
+ Bias,
63
+ Out,
64
+ Lse,
65
+ TMP,
66
+ softmax_scale,
67
+ stride_qb,
68
+ stride_qh,
69
+ stride_qm,
70
+ stride_kb,
71
+ stride_kh,
72
+ stride_kn,
73
+ stride_vb,
74
+ stride_vh,
75
+ stride_vn,
76
+ stride_bb,
77
+ stride_bh,
78
+ stride_bm,
79
+ stride_ob,
80
+ stride_oh,
81
+ stride_om,
82
+ nheads,
83
+ seqlen_q,
84
+ seqlen_k,
85
+ seqlen_q_rounded,
86
+ headdim,
87
+ CACHE_KEY_SEQLEN_Q,
88
+ CACHE_KEY_SEQLEN_K,
89
+ BIAS_TYPE: tl.constexpr,
90
+ IS_CAUSAL: tl.constexpr,
91
+ BLOCK_HEADDIM: tl.constexpr,
92
+ EVEN_M: tl.constexpr,
93
+ EVEN_N: tl.constexpr,
94
+ EVEN_HEADDIM: tl.constexpr,
95
+ BLOCK_M: tl.constexpr,
96
+ BLOCK_N: tl.constexpr,
97
+ ):
98
+ start_m = tl.program_id(0)
99
+ off_hb = tl.program_id(1)
100
+ off_b = off_hb // nheads
101
+ off_h = off_hb % nheads
102
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
103
+ offs_n = tl.arange(0, BLOCK_N)
104
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
105
+ q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
106
+ k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
107
+ v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
108
+ if BIAS_TYPE == "vector":
109
+ b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
110
+ elif BIAS_TYPE == "matrix":
111
+ b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
112
+ t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
113
+ lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
114
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
115
+ acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
116
+ if EVEN_M & EVEN_N:
117
+ if EVEN_HEADDIM:
118
+ q = tl.load(q_ptrs)
119
+ else:
120
+ q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
121
+ elif EVEN_HEADDIM:
122
+ q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
123
+ else:
124
+ q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
125
+ end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
126
+ for start_n in range(0, end_n, BLOCK_N):
127
+ start_n = tl.multiple_of(start_n, BLOCK_N)
128
+ if EVEN_N & EVEN_M:
129
+ if EVEN_HEADDIM:
130
+ k = tl.load(k_ptrs + start_n * stride_kn)
131
+ else:
132
+ k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
133
+ elif EVEN_HEADDIM:
134
+ k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
135
+ else:
136
+ k = tl.load(k_ptrs + start_n * stride_kn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
137
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
138
+ qk += tl.dot(q, k, trans_b=True)
139
+ if not EVEN_N:
140
+ qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
141
+ if IS_CAUSAL:
142
+ qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
143
+ if BIAS_TYPE != "none":
144
+ if BIAS_TYPE == "vector":
145
+ if EVEN_N:
146
+ bias = tl.load(b_ptrs + start_n).to(tl.float32)
147
+ else:
148
+ bias = tl.load(b_ptrs + start_n, mask=start_n + offs_n < seqlen_k, other=0.0).to(tl.float32)
149
+ bias = bias[None, :]
150
+ elif BIAS_TYPE == "matrix":
151
+ if EVEN_M & EVEN_N:
152
+ bias = tl.load(b_ptrs + start_n).to(tl.float32)
153
+ else:
154
+ bias = tl.load(b_ptrs + start_n, mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), other=0.0).to(tl.float32)
155
+ qk = qk * softmax_scale + bias
156
+ m_ij = tl.maximum(tl.max(qk, 1), lse_i)
157
+ p = tl.exp(qk - m_ij[:, None])
158
+ else:
159
+ m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
160
+ p = tl.exp(qk * softmax_scale - m_ij[:, None])
161
+ l_ij = tl.sum(p, 1)
162
+ acc_o_scale = tl.exp(m_i - m_ij)
163
+ tl.store(t_ptrs, acc_o_scale)
164
+ acc_o_scale = tl.load(t_ptrs)
165
+ acc_o = acc_o * acc_o_scale[:, None]
166
+ if EVEN_N & EVEN_M:
167
+ if EVEN_HEADDIM:
168
+ v = tl.load(v_ptrs + start_n * stride_vn)
169
+ else:
170
+ v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
171
+ elif EVEN_HEADDIM:
172
+ v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
173
+ else:
174
+ v = tl.load(v_ptrs + start_n * stride_vn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
175
+ p = p.to(v.dtype)
176
+ acc_o += tl.dot(p, v)
177
+ m_i = m_ij
178
+ l_i_new = tl.exp(lse_i - m_ij) + l_ij
179
+ lse_i = m_ij + tl.log(l_i_new)
180
+ o_scale = tl.exp(m_i - lse_i)
181
+ tl.store(t_ptrs, o_scale)
182
+ o_scale = tl.load(t_ptrs)
183
+ acc_o = acc_o * o_scale[:, None]
184
+ start_m = tl.program_id(0)
185
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
186
+ lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
187
+ tl.store(lse_ptrs, lse_i)
188
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
189
+ out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])
190
+ if EVEN_M:
191
+ if EVEN_HEADDIM:
192
+ tl.store(out_ptrs, acc_o)
193
+ else:
194
+ tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
195
+ elif EVEN_HEADDIM:
196
+ tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
197
+ else:
198
+ tl.store(out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
199
+
200
+
201
+ @triton.jit
202
+ def _bwd_preprocess_do_o_dot(
203
+ Out,
204
+ DO,
205
+ Delta,
206
+ stride_ob,
207
+ stride_oh,
208
+ stride_om,
209
+ stride_dob,
210
+ stride_doh,
211
+ stride_dom,
212
+ nheads,
213
+ seqlen_q,
214
+ seqlen_q_rounded,
215
+ headdim,
216
+ BLOCK_M: tl.constexpr,
217
+ BLOCK_HEADDIM: tl.constexpr,
218
+ ):
219
+ start_m = tl.program_id(0)
220
+ off_hb = tl.program_id(1)
221
+ off_b = off_hb // nheads
222
+ off_h = off_hb % nheads
223
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
224
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
225
+ o = tl.load(
226
+ Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
227
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
228
+ other=0.0,
229
+ ).to(tl.float32)
230
+ do = tl.load(
231
+ DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :],
232
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
233
+ other=0.0,
234
+ ).to(tl.float32)
235
+ delta = tl.sum(o * do, axis=1)
236
+ tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
237
+
238
+
239
+ @triton.jit
240
+ def _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr):
241
+ if EVEN_N & EVEN_M:
242
+ if EVEN_HEADDIM:
243
+ tl.store(dv_ptrs, dv)
244
+ tl.store(dk_ptrs, dk)
245
+ else:
246
+ tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
247
+ tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
248
+ elif EVEN_HEADDIM:
249
+ tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
250
+ tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
251
+ else:
252
+ tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
253
+ tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
254
+
255
+
256
+ @triton.jit
257
+ def _bwd_kernel_one_col_block(
258
+ start_n,
259
+ Q,
260
+ K,
261
+ V,
262
+ Bias,
263
+ DO,
264
+ DQ,
265
+ DK,
266
+ DV,
267
+ LSE,
268
+ D,
269
+ softmax_scale,
270
+ stride_qm,
271
+ stride_kn,
272
+ stride_vn,
273
+ stride_bm,
274
+ stride_dom,
275
+ stride_dqm,
276
+ stride_dkn,
277
+ stride_dvn,
278
+ seqlen_q,
279
+ seqlen_k,
280
+ headdim,
281
+ ATOMIC_ADD: tl.constexpr,
282
+ BIAS_TYPE: tl.constexpr,
283
+ IS_CAUSAL: tl.constexpr,
284
+ BLOCK_HEADDIM: tl.constexpr,
285
+ EVEN_M: tl.constexpr,
286
+ EVEN_N: tl.constexpr,
287
+ EVEN_HEADDIM: tl.constexpr,
288
+ BLOCK_M: tl.constexpr,
289
+ BLOCK_N: tl.constexpr,
290
+ ):
291
+ begin_m = 0 if not IS_CAUSAL else start_n * BLOCK_N // BLOCK_M * BLOCK_M
292
+ offs_qm = begin_m + tl.arange(0, BLOCK_M)
293
+ offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
294
+ offs_m = tl.arange(0, BLOCK_M)
295
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
296
+ q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
297
+ k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
298
+ v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
299
+ do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
300
+ dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
301
+ if BIAS_TYPE == "vector":
302
+ b_ptrs = Bias + offs_n
303
+ elif BIAS_TYPE == "matrix":
304
+ b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
305
+ dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
306
+ dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
307
+ if begin_m >= seqlen_q:
308
+ dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
309
+ dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
310
+ _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
311
+ return
312
+ if EVEN_N & EVEN_M:
313
+ if EVEN_HEADDIM:
314
+ k = tl.load(k_ptrs)
315
+ v = tl.load(v_ptrs)
316
+ else:
317
+ k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
318
+ v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
319
+ elif EVEN_HEADDIM:
320
+ k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
321
+ v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
322
+ else:
323
+ k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
324
+ v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
325
+ num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
326
+ for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
327
+ start_m = tl.multiple_of(start_m, BLOCK_M)
328
+ offs_m_curr = start_m + offs_m
329
+ if EVEN_M & EVEN_HEADDIM:
330
+ q = tl.load(q_ptrs)
331
+ elif EVEN_HEADDIM:
332
+ q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
333
+ else:
334
+ q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
335
+ qk = tl.dot(q, k, trans_b=True)
336
+ if not EVEN_N:
337
+ qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
338
+ if IS_CAUSAL:
339
+ qk = tl.where(offs_m_curr[:, None] >= offs_n[None, :], qk, float("-inf"))
340
+ if BIAS_TYPE != "none":
341
+ tl.debug_barrier()
342
+ if BIAS_TYPE == "vector":
343
+ if EVEN_N:
344
+ bias = tl.load(b_ptrs).to(tl.float32)
345
+ else:
346
+ bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
347
+ bias = bias[None, :]
348
+ elif BIAS_TYPE == "matrix":
349
+ if EVEN_M & EVEN_N:
350
+ bias = tl.load(b_ptrs).to(tl.float32)
351
+ else:
352
+ bias = tl.load(b_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), other=0.0).to(tl.float32)
353
+ qk = qk * softmax_scale + bias
354
+ if not EVEN_M & EVEN_HEADDIM:
355
+ tl.debug_barrier()
356
+ lse_i = tl.load(LSE + offs_m_curr)
357
+ if BIAS_TYPE == "none":
358
+ p = tl.exp(qk * softmax_scale - lse_i[:, None])
359
+ else:
360
+ p = tl.exp(qk - lse_i[:, None])
361
+ if EVEN_M & EVEN_HEADDIM:
362
+ do = tl.load(do_ptrs)
363
+ else:
364
+ do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
365
+ dv += tl.dot(p.to(do.dtype), do, trans_a=True)
366
+ if not EVEN_M & EVEN_HEADDIM:
367
+ tl.debug_barrier()
368
+ dp = tl.dot(do, v, trans_b=True)
369
+ if not EVEN_HEADDIM:
370
+ tl.debug_barrier()
371
+ Di = tl.load(D + offs_m_curr)
372
+ ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
373
+ dk += tl.dot(ds, q, trans_a=True)
374
+ if not EVEN_M & EVEN_HEADDIM:
375
+ tl.debug_barrier()
376
+ if not ATOMIC_ADD:
377
+ if EVEN_M & EVEN_HEADDIM:
378
+ dq = tl.load(dq_ptrs, eviction_policy="evict_last")
379
+ dq += tl.dot(ds, k)
380
+ tl.store(dq_ptrs, dq, eviction_policy="evict_last")
381
+ elif EVEN_HEADDIM:
382
+ dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0, eviction_policy="evict_last")
383
+ dq += tl.dot(ds, k)
384
+ tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q, eviction_policy="evict_last")
385
+ else:
386
+ dq = tl.load(dq_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0, eviction_policy="evict_last")
387
+ dq += tl.dot(ds, k)
388
+ tl.store(dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), eviction_policy="evict_last")
389
+ else:
390
+ dq = tl.dot(ds, k)
391
+ if EVEN_M & EVEN_HEADDIM:
392
+ tl.atomic_add(dq_ptrs, dq)
393
+ elif EVEN_HEADDIM:
394
+ tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
395
+ else:
396
+ tl.atomic_add(dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
397
+ dq_ptrs += BLOCK_M * stride_dqm
398
+ q_ptrs += BLOCK_M * stride_qm
399
+ do_ptrs += BLOCK_M * stride_dom
400
+ if BIAS_TYPE == "matrix":
401
+ b_ptrs += BLOCK_M * stride_bm
402
+ dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
403
+ dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
404
+ _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
405
+
406
+
407
+ def init_to_zero(name):
408
+ return lambda nargs: nargs[name].zero_()
409
+
410
+
411
+ @triton.autotune(
412
+ configs=[
413
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero("DQ")),
414
+ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero("DQ")),
415
+ ],
416
+ key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "IS_CAUSAL", "BLOCK_HEADDIM"],
417
+ )
418
+ @triton.heuristics(
419
+ {
420
+ "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
421
+ "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
422
+ "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
423
+ }
424
+ )
425
+ @triton.jit
426
+ def _bwd_kernel(
427
+ Q,
428
+ K,
429
+ V,
430
+ Bias,
431
+ DO,
432
+ DQ,
433
+ DK,
434
+ DV,
435
+ LSE,
436
+ D,
437
+ softmax_scale,
438
+ stride_qb,
439
+ stride_qh,
440
+ stride_qm,
441
+ stride_kb,
442
+ stride_kh,
443
+ stride_kn,
444
+ stride_vb,
445
+ stride_vh,
446
+ stride_vn,
447
+ stride_bb,
448
+ stride_bh,
449
+ stride_bm,
450
+ stride_dob,
451
+ stride_doh,
452
+ stride_dom,
453
+ stride_dqb,
454
+ stride_dqh,
455
+ stride_dqm,
456
+ stride_dkb,
457
+ stride_dkh,
458
+ stride_dkn,
459
+ stride_dvb,
460
+ stride_dvh,
461
+ stride_dvn,
462
+ nheads,
463
+ seqlen_q,
464
+ seqlen_k,
465
+ seqlen_q_rounded,
466
+ headdim,
467
+ CACHE_KEY_SEQLEN_Q,
468
+ CACHE_KEY_SEQLEN_K,
469
+ BIAS_TYPE: tl.constexpr,
470
+ IS_CAUSAL: tl.constexpr,
471
+ BLOCK_HEADDIM: tl.constexpr,
472
+ SEQUENCE_PARALLEL: tl.constexpr,
473
+ EVEN_M: tl.constexpr,
474
+ EVEN_N: tl.constexpr,
475
+ EVEN_HEADDIM: tl.constexpr,
476
+ BLOCK_M: tl.constexpr,
477
+ BLOCK_N: tl.constexpr,
478
+ ):
479
+ off_hb = tl.program_id(1)
480
+ off_b = off_hb // nheads
481
+ off_h = off_hb % nheads
482
+ Q += off_b * stride_qb + off_h * stride_qh
483
+ K += off_b * stride_kb + off_h * stride_kh
484
+ V += off_b * stride_vb + off_h * stride_vh
485
+ DO += off_b * stride_dob + off_h * stride_doh
486
+ DQ += off_b * stride_dqb + off_h * stride_dqh
487
+ DK += off_b * stride_dkb + off_h * stride_dkh
488
+ DV += off_b * stride_dvb + off_h * stride_dvh
489
+ if BIAS_TYPE != "none":
490
+ Bias += off_b * stride_bb + off_h * stride_bh
491
+ D += off_hb * seqlen_q_rounded
492
+ LSE += off_hb * seqlen_q_rounded
493
+ if not SEQUENCE_PARALLEL:
494
+ num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
495
+ for start_n in range(0, num_block_n):
496
+ _bwd_kernel_one_col_block(
497
+ start_n,
498
+ Q,
499
+ K,
500
+ V,
501
+ Bias,
502
+ DO,
503
+ DQ,
504
+ DK,
505
+ DV,
506
+ LSE,
507
+ D,
508
+ softmax_scale,
509
+ stride_qm,
510
+ stride_kn,
511
+ stride_vn,
512
+ stride_bm,
513
+ stride_dom,
514
+ stride_dqm,
515
+ stride_dkn,
516
+ stride_dvn,
517
+ seqlen_q,
518
+ seqlen_k,
519
+ headdim,
520
+ ATOMIC_ADD=False,
521
+ BIAS_TYPE=BIAS_TYPE,
522
+ IS_CAUSAL=IS_CAUSAL,
523
+ BLOCK_HEADDIM=BLOCK_HEADDIM,
524
+ EVEN_M=EVEN_M,
525
+ EVEN_N=EVEN_N,
526
+ EVEN_HEADDIM=EVEN_HEADDIM,
527
+ BLOCK_M=BLOCK_M,
528
+ BLOCK_N=BLOCK_N,
529
+ )
530
+ else:
531
+ start_n = tl.program_id(0)
532
+ _bwd_kernel_one_col_block(
533
+ start_n,
534
+ Q,
535
+ K,
536
+ V,
537
+ Bias,
538
+ DO,
539
+ DQ,
540
+ DK,
541
+ DV,
542
+ LSE,
543
+ D,
544
+ softmax_scale,
545
+ stride_qm,
546
+ stride_kn,
547
+ stride_vn,
548
+ stride_bm,
549
+ stride_dom,
550
+ stride_dqm,
551
+ stride_dkn,
552
+ stride_dvn,
553
+ seqlen_q,
554
+ seqlen_k,
555
+ headdim,
556
+ ATOMIC_ADD=True,
557
+ BIAS_TYPE=BIAS_TYPE,
558
+ IS_CAUSAL=IS_CAUSAL,
559
+ BLOCK_HEADDIM=BLOCK_HEADDIM,
560
+ EVEN_M=EVEN_M,
561
+ EVEN_N=EVEN_N,
562
+ EVEN_HEADDIM=EVEN_HEADDIM,
563
+ BLOCK_M=BLOCK_M,
564
+ BLOCK_N=BLOCK_N,
565
+ )
566
+
567
+
568
+ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
569
+ (batch, seqlen_q, nheads, d) = q.shape
570
+ (_, seqlen_k, _, _) = k.shape
571
+ assert k.shape == (batch, seqlen_k, nheads, d)
572
+ assert v.shape == (batch, seqlen_k, nheads, d)
573
+ assert d <= 128, "FlashAttention only support head dimensions up to 128"
574
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
575
+ assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
576
+ assert q.is_cuda and k.is_cuda and v.is_cuda
577
+ softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
578
+ has_bias = bias is not None
579
+ bias_type = "none"
580
+ if has_bias:
581
+ assert bias.dtype in [q.dtype, torch.float]
582
+ assert bias.is_cuda
583
+ assert bias.dim() == 4
584
+ if bias.stride(-1) != 1:
585
+ bias = bias.contiguous()
586
+ if bias.shape[2:] == (1, seqlen_k):
587
+ bias_type = "vector"
588
+ elif bias.shape[2:] == (seqlen_q, seqlen_k):
589
+ bias_type = "matrix"
590
+ else:
591
+ raise RuntimeError("Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)")
592
+ bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
593
+ bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
594
+ seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
595
+ lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
596
+ tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
597
+ o = torch.empty_like(q)
598
+ BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
599
+ BLOCK = 128
600
+ num_warps = 4 if d <= 64 else 8
601
+ grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
602
+ _fwd_kernel[grid](
603
+ q,
604
+ k,
605
+ v,
606
+ bias,
607
+ o,
608
+ lse,
609
+ tmp,
610
+ softmax_scale,
611
+ q.stride(0),
612
+ q.stride(2),
613
+ q.stride(1),
614
+ k.stride(0),
615
+ k.stride(2),
616
+ k.stride(1),
617
+ v.stride(0),
618
+ v.stride(2),
619
+ v.stride(1),
620
+ *bias_strides,
621
+ o.stride(0),
622
+ o.stride(2),
623
+ o.stride(1),
624
+ nheads,
625
+ seqlen_q,
626
+ seqlen_k,
627
+ seqlen_q_rounded,
628
+ d,
629
+ seqlen_q // 32,
630
+ seqlen_k // 32,
631
+ bias_type,
632
+ causal,
633
+ BLOCK_HEADDIM,
634
+ BLOCK_M=BLOCK,
635
+ BLOCK_N=BLOCK,
636
+ num_warps=num_warps,
637
+ num_stages=1
638
+ )
639
+ return (o, lse, softmax_scale)
640
+
641
+
642
+ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):
643
+ if do.stride(-1) != 1:
644
+ do = do.contiguous()
645
+ (batch, seqlen_q, nheads, d) = q.shape
646
+ (_, seqlen_k, _, _) = k.shape
647
+ assert d <= 128
648
+ seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
649
+ assert lse.shape == (batch, nheads, seqlen_q_rounded)
650
+ assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
651
+ assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
652
+ softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
653
+ dq_accum = torch.empty_like(q, dtype=torch.float32)
654
+ delta = torch.empty_like(lse)
655
+ BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
656
+ grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
657
+ _bwd_preprocess_do_o_dot[grid](
658
+ o,
659
+ do,
660
+ delta,
661
+ o.stride(0),
662
+ o.stride(2),
663
+ o.stride(1),
664
+ do.stride(0),
665
+ do.stride(2),
666
+ do.stride(1),
667
+ nheads,
668
+ seqlen_q,
669
+ seqlen_q_rounded,
670
+ d,
671
+ BLOCK_M=128,
672
+ BLOCK_HEADDIM=BLOCK_HEADDIM,
673
+ )
674
+ has_bias = bias is not None
675
+ bias_type = "none"
676
+ if has_bias:
677
+ assert bias.dtype in [q.dtype, torch.float]
678
+ assert bias.is_cuda
679
+ assert bias.dim() == 4
680
+ assert bias.stride(-1) == 1
681
+ if bias.shape[2:] == (1, seqlen_k):
682
+ bias_type = "vector"
683
+ elif bias.shape[2:] == (seqlen_q, seqlen_k):
684
+ bias_type = "matrix"
685
+ else:
686
+ raise RuntimeError("Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)")
687
+ bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
688
+ bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
689
+ grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, batch * nheads)
690
+ _bwd_kernel[grid](
691
+ q,
692
+ k,
693
+ v,
694
+ bias,
695
+ do,
696
+ dq_accum,
697
+ dk,
698
+ dv,
699
+ lse,
700
+ delta,
701
+ softmax_scale,
702
+ q.stride(0),
703
+ q.stride(2),
704
+ q.stride(1),
705
+ k.stride(0),
706
+ k.stride(2),
707
+ k.stride(1),
708
+ v.stride(0),
709
+ v.stride(2),
710
+ v.stride(1),
711
+ *bias_strides,
712
+ do.stride(0),
713
+ do.stride(2),
714
+ do.stride(1),
715
+ dq_accum.stride(0),
716
+ dq_accum.stride(2),
717
+ dq_accum.stride(1),
718
+ dk.stride(0),
719
+ dk.stride(2),
720
+ dk.stride(1),
721
+ dv.stride(0),
722
+ dv.stride(2),
723
+ dv.stride(1),
724
+ nheads,
725
+ seqlen_q,
726
+ seqlen_k,
727
+ seqlen_q_rounded,
728
+ d,
729
+ seqlen_q // 32,
730
+ seqlen_k // 32,
731
+ bias_type,
732
+ causal,
733
+ BLOCK_HEADDIM
734
+ )
735
+ dq.copy_(dq_accum)
736
+
737
+
738
+ class FlashAttnQKVPackedFunc(torch.autograd.Function):
739
+ @staticmethod
740
+ def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
741
+ """
742
+ qkv: (batch, seqlen, 3, nheads, headdim)
743
+ bias: optional, shape broadcastable to (batch, nheads, seqlen, seqlen).
744
+ For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
745
+ ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
746
+ """
747
+ if qkv.stride(-1) != 1:
748
+ qkv = qkv.contiguous()
749
+ (o, lse, ctx.softmax_scale) = _flash_attn_forward(qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal, softmax_scale=softmax_scale)
750
+ ctx.save_for_backward(qkv, o, lse, bias)
751
+ ctx.causal = causal
752
+ return o
753
+
754
+ @staticmethod
755
+ def backward(ctx, do):
756
+ (qkv, o, lse, bias) = ctx.saved_tensors
757
+ assert not ctx.needs_input_grad[1], "FlashAttention does not support bias gradient yet"
758
+ with torch.inference_mode():
759
+ dqkv = torch.empty_like(qkv)
760
+ _flash_attn_backward(
761
+ do,
762
+ qkv[:, :, 0],
763
+ qkv[:, :, 1],
764
+ qkv[:, :, 2],
765
+ o,
766
+ lse,
767
+ dqkv[:, :, 0],
768
+ dqkv[:, :, 1],
769
+ dqkv[:, :, 2],
770
+ bias=bias,
771
+ causal=ctx.causal,
772
+ softmax_scale=ctx.softmax_scale,
773
+ )
774
+ return (dqkv, None, None, None)
775
+
776
+
777
+ flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
778
+
779
+
780
+ class FlashAttnKVPackedFunc(torch.autograd.Function):
781
+ @staticmethod
782
+ def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
783
+ """
784
+ q: (batch, seqlen_q, nheads, headdim)
785
+ kv: (batch, seqlen_k, 2, nheads, headdim)
786
+ bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
787
+ For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
788
+ ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
789
+ """
790
+ (q, kv) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
791
+ (o, lse, ctx.softmax_scale) = _flash_attn_forward(q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale)
792
+ ctx.save_for_backward(q, kv, o, lse, bias)
793
+ ctx.causal = causal
794
+ return o
795
+
796
+ @staticmethod
797
+ def backward(ctx, do):
798
+ (q, kv, o, lse, bias) = ctx.saved_tensors
799
+ if len(ctx.needs_input_grad) >= 3:
800
+ assert not ctx.needs_input_grad[2], "FlashAttention does not support bias gradient yet"
801
+ with torch.inference_mode():
802
+ dq = torch.empty_like(q)
803
+ dkv = torch.empty_like(kv)
804
+ _flash_attn_backward(
805
+ do, q, kv[:, :, 0], kv[:, :, 1], o, lse, dq, dkv[:, :, 0], dkv[:, :, 1], bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale
806
+ )
807
+ return (dq, dkv, None, None, None)
808
+
809
+
810
+ flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
811
+
812
+
813
+ class FlashAttnFunc(torch.autograd.Function):
814
+ @staticmethod
815
+ def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
816
+ """
817
+ q: (batch_size, seqlen_q, nheads, headdim)
818
+ k, v: (batch_size, seqlen_k, nheads, headdim)
819
+ bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
820
+ For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
821
+ ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
822
+ """
823
+ (q, k, v) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
824
+ (o, lse, ctx.softmax_scale) = _flash_attn_forward(q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale)
825
+ ctx.save_for_backward(q, k, v, o, lse, bias)
826
+ ctx.causal = causal
827
+ return o
828
+
829
+ @staticmethod
830
+ def backward(ctx, do):
831
+ (q, k, v, o, lse, bias) = ctx.saved_tensors
832
+ assert not ctx.needs_input_grad[3], "FlashAttention does not support bias gradient yet"
833
+ with torch.inference_mode():
834
+ dq = torch.empty_like(q)
835
+ dk = torch.empty_like(k)
836
+ dv = torch.empty_like(v)
837
+ _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
838
+ return (dq, dk, dv, None, None, None)
839
+
840
+
841
+ flash_attn_func = FlashAttnFunc.apply
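The three autograd wrappers above share the same calling convention. A minimal usage sketch (not part of the diff), assuming a CUDA device with Triton available; the import path is illustrative:

```python
import torch
from flash_attn_triton import flash_attn_func  # assumed import path

batch, seqlen, nheads, headdim = 2, 128, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# autograd.Function.apply takes positional arguments only: (q, k, v, bias, causal, softmax_scale)
out = flash_attn_func(q, k, v, None, True, None)  # (batch, seqlen, nheads, headdim)
out.sum().backward()                              # gradients flow back to q, k, v
```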
mllm/flamingo/mpt/hf_prefixlm_converter.py ADDED
@@ -0,0 +1,575 @@
1
+ """Converts Huggingface Causal LM to Prefix LM.
2
+
3
+ Conversion does lightweight surgery on a HuggingFace
4
+ Causal LM to convert it to a Prefix LM.
5
+
6
+ Prefix LMs accept a `bidirectional_mask` input in `forward`
7
+ and treat the input prompt as the prefix in `generate`.
8
+ """
9
+ import math
10
+ import warnings
11
+ from types import MethodType
12
+ from typing import Any, Dict, List, Optional, Tuple, Union
13
+ import torch
14
+ from transformers.models.bloom.modeling_bloom import (
15
+ BaseModelOutputWithPastAndCrossAttentions,
16
+ BloomForCausalLM,
17
+ BloomModel,
18
+ CausalLMOutputWithCrossAttentions,
19
+ CrossEntropyLoss,
20
+ )
21
+ from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom
22
+ from transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom
23
+ from transformers.models.bloom.modeling_bloom import logging
24
+ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
25
+ from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
26
+ from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
27
+ from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
28
+ from transformers.models.opt.modeling_opt import OPTForCausalLM
29
+ from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt
30
+ from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt
31
+
32
+ logger = logging.get_logger(__name__)
33
+ _SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM)
34
+ CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM]
35
+
36
+
37
+ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
38
+ """Converts a GPT-style Causal LM to a Prefix LM.
39
+
40
+ Supported HuggingFace model classes:
41
+ - `GPT2LMHeadModel`
42
+ - `GPTNeoForCausalLM`
43
+ - `GPTNeoXForCausalLM`
44
+ - `GPTJForCausalLM`
45
+
46
+ See `convert_hf_causal_lm_to_prefix_lm` for more details.
47
+ """
48
+ if hasattr(model, "_prefix_lm_converted"):
49
+ return model
50
+ assert isinstance(model, _SUPPORTED_GPT_MODELS)
51
+ assert model.config.add_cross_attention == False, "Only supports GPT-style decoder-only models"
52
+
53
+ def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
54
+ """Helper that gets a list of the model's attention modules.
55
+
56
+ Each module has a `bias` buffer used for causal masking. The Prefix LM
57
+ conversion adds logic to dynamically manipulate these biases to support
58
+ Prefix LM attention masking.
59
+ """
60
+ attn_modules = []
61
+ if isinstance(model, GPTNeoXForCausalLM):
62
+ blocks = model.gpt_neox.layers
63
+ else:
64
+ blocks = model.transformer.h
65
+ for block in blocks:
66
+ if isinstance(model, GPTNeoForCausalLM):
67
+ if block.attn.attention_type != "global":
68
+ continue
69
+ attn_module = block.attn.attention
70
+ elif isinstance(model, GPTNeoXForCausalLM):
71
+ attn_module = block.attention
72
+ else:
73
+ attn_module = block.attn
74
+ attn_modules.append(attn_module)
75
+ return attn_modules
76
+
77
+ setattr(model, "_original_forward", getattr(model, "forward"))
78
+ setattr(model, "_original_generate", getattr(model, "generate"))
79
+
80
+ def forward(
81
+ self: CAUSAL_GPT_TYPES,
82
+ input_ids: Optional[torch.LongTensor] = None,
83
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
84
+ attention_mask: Optional[torch.FloatTensor] = None,
85
+ bidirectional_mask: Optional[torch.Tensor] = None,
86
+ token_type_ids: Optional[torch.LongTensor] = None,
87
+ position_ids: Optional[torch.LongTensor] = None,
88
+ head_mask: Optional[torch.FloatTensor] = None,
89
+ inputs_embeds: Optional[torch.FloatTensor] = None,
90
+ labels: Optional[torch.LongTensor] = None,
91
+ use_cache: Optional[bool] = None,
92
+ output_attentions: Optional[bool] = None,
93
+ output_hidden_states: Optional[bool] = None,
94
+ return_dict: Optional[bool] = None,
95
+ ):
96
+ """Wraps original forward to enable PrefixLM attention."""
97
+
98
+ def call_og_forward():
99
+ if isinstance(self, GPTNeoXForCausalLM):
100
+ return self._original_forward(
101
+ input_ids=input_ids,
102
+ past_key_values=past_key_values,
103
+ attention_mask=attention_mask,
104
+ head_mask=head_mask,
105
+ inputs_embeds=inputs_embeds,
106
+ labels=labels,
107
+ use_cache=use_cache,
108
+ output_attentions=output_attentions,
109
+ output_hidden_states=output_hidden_states,
110
+ return_dict=return_dict,
111
+ )
112
+ else:
113
+ return self._original_forward(
114
+ input_ids=input_ids,
115
+ past_key_values=past_key_values,
116
+ attention_mask=attention_mask,
117
+ token_type_ids=token_type_ids,
118
+ position_ids=position_ids,
119
+ head_mask=head_mask,
120
+ inputs_embeds=inputs_embeds,
121
+ labels=labels,
122
+ use_cache=use_cache,
123
+ output_attentions=output_attentions,
124
+ output_hidden_states=output_hidden_states,
125
+ return_dict=return_dict,
126
+ )
127
+
128
+ if bidirectional_mask is None:
129
+ return call_og_forward()
130
+ assert isinstance(bidirectional_mask, torch.Tensor)
131
+ attn_modules = _get_attn_modules(model)
132
+ (b, s) = bidirectional_mask.shape
133
+ max_length = attn_modules[0].bias.shape[-1]
134
+ if s > max_length:
135
+ raise ValueError(f"bidirectional_mask sequence length (={s}) exceeds the " + f"max length allowed by the model ({max_length}).")
136
+ assert s <= max_length
137
+ if s < max_length:
138
+ pad = torch.zeros((int(b), int(max_length - s)), dtype=bidirectional_mask.dtype, device=bidirectional_mask.device)
139
+ bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
140
+ bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
141
+ for attn_module in attn_modules:
142
+ attn_module.bias.data = torch.logical_or(attn_module.bias.data, bidirectional)
143
+ output = call_og_forward()
144
+ for attn_module in attn_modules:
145
+ attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
146
+ return output
147
+
148
+ def generate(self: CAUSAL_GPT_TYPES, *args: tuple, **kwargs: Dict[str, Any]):
149
+ """Wraps original generate to enable PrefixLM attention."""
150
+ attn_modules = _get_attn_modules(model)
151
+ for attn_module in attn_modules:
152
+ attn_module.bias.data[:] = 1
153
+ output = self._original_generate(*args, **kwargs)
154
+ for attn_module in attn_modules:
155
+ attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
156
+ return output
157
+
158
+ setattr(model, "forward", MethodType(forward, model))
159
+ setattr(model, "generate", MethodType(generate, model))
160
+ setattr(model, "_prefix_lm_converted", True)
161
+ return model
162
+
163
+
164
+ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM:
165
+ """Converts a BLOOM Causal LM to a Prefix LM.
166
+
167
+ Supported HuggingFace model classes:
168
+ - `BloomForCausalLM`
169
+
170
+ See `convert_hf_causal_lm_to_prefix_lm` for more details.
171
+ """
172
+ if hasattr(model, "_prefix_lm_converted"):
173
+ return model
174
+ assert isinstance(model, BloomForCausalLM)
175
+ assert model.config.add_cross_attention == False, "Only supports BLOOM decoder-only models"
176
+
177
+ def _prepare_attn_mask(
178
+ self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int
179
+ ) -> torch.BoolTensor:
180
+ combined_attention_mask = None
181
+ device = attention_mask.device
182
+ (_, src_length) = input_shape
183
+ if src_length > 1:
184
+ combined_attention_mask = _make_causal_mask_bloom(input_shape, device=device, past_key_values_length=past_key_values_length)
185
+ if bidirectional_mask is not None:
186
+ assert attention_mask.shape == bidirectional_mask.shape
187
+ expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length)
188
+ combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask)
189
+ expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length)
190
+ combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
191
+ return combined_attention_mask
192
+
193
+ def _build_alibi_tensor(self: BloomModel, batch_size: int, query_length: int, key_length: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
194
+ num_heads = self.config.n_head
195
+ closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
196
+ base = torch.tensor(2 ** (-(2 ** (-(math.log2(closest_power_of_2) - 3)))), device=device, dtype=torch.float32)
197
+ powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
198
+ slopes = torch.pow(base, powers)
199
+ if closest_power_of_2 != num_heads:
200
+ extra_base = torch.tensor(2 ** (-(2 ** (-(math.log2(2 * closest_power_of_2) - 3)))), device=device, dtype=torch.float32)
201
+ num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
202
+ extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
203
+ slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
204
+ qa = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1)
205
+ ka = torch.arange(key_length, device=device, dtype=torch.int32).view(1, -1)
206
+ diffs = qa - ka + key_length - query_length
207
+ diffs = -diffs.abs()
208
+ alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(1, 1, query_length, key_length)
209
+ alibi = alibi.expand(batch_size, -1, -1, -1).reshape(-1, query_length, key_length)
210
+ return alibi.to(dtype)
211
+
212
+ KeyValueT = Tuple[torch.Tensor, torch.Tensor]
213
+
214
+ def forward(
215
+ self: BloomModel,
216
+ input_ids: Optional[torch.LongTensor] = None,
217
+ past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
218
+ attention_mask: Optional[torch.Tensor] = None,
219
+ bidirectional_mask: Optional[torch.Tensor] = None,
220
+ head_mask: Optional[torch.LongTensor] = None,
221
+ inputs_embeds: Optional[torch.LongTensor] = None,
222
+ use_cache: Optional[bool] = None,
223
+ output_attentions: Optional[bool] = None,
224
+ output_hidden_states: Optional[bool] = None,
225
+ return_dict: Optional[bool] = None,
226
+ **deprecated_arguments,
227
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
228
+ if deprecated_arguments.pop("position_ids", False) is not False:
229
+ warnings.warn(
230
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. " + "You can safely ignore passing `position_ids`.", FutureWarning
231
+ )
232
+ if len(deprecated_arguments) > 0:
233
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
234
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
235
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
236
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
237
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
238
+ if input_ids is not None and inputs_embeds is not None:
239
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
240
+ elif input_ids is not None:
241
+ (batch_size, seq_length) = input_ids.shape
242
+ elif inputs_embeds is not None:
243
+ (batch_size, seq_length, _) = inputs_embeds.shape
244
+ else:
245
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
246
+ if past_key_values is None:
247
+ past_key_values = tuple([None] * len(self.h))
248
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
249
+ if inputs_embeds is None:
250
+ inputs_embeds = self.word_embeddings(input_ids)
251
+ hidden_states = self.word_embeddings_layernorm(inputs_embeds)
252
+ presents = () if use_cache else None
253
+ all_self_attentions = () if output_attentions else None
254
+ all_hidden_states = () if output_hidden_states else None
255
+ seq_length_with_past = seq_length
256
+ past_key_values_length = 0
257
+ if past_key_values[0] is not None:
258
+ tmp = past_key_values[0][0]
259
+ past_key_values_length = tmp.shape[2]
260
+ seq_length_with_past = seq_length_with_past + past_key_values_length
261
+ if attention_mask is None:
262
+ attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
263
+ else:
264
+ attention_mask = attention_mask.to(hidden_states.device)
265
+ alibi = self._build_alibi_tensor(
266
+ batch_size=batch_size, query_length=seq_length, key_length=seq_length_with_past, dtype=hidden_states.dtype, device=hidden_states.device
267
+ )
268
+ causal_mask = self._prepare_attn_mask(
269
+ attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length
270
+ )
271
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
272
+ if output_hidden_states:
273
+ hst = (hidden_states,)
274
+ all_hidden_states = all_hidden_states + hst
275
+ if self.gradient_checkpointing and self.training:
276
+ if use_cache:
277
+ logger.warning("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
278
+ use_cache = False
279
+
280
+ def create_custom_forward(module):
281
+ def custom_forward(*inputs):
282
+ return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
283
+
284
+ return custom_forward
285
+
286
+ outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(block), hidden_states, alibi, causal_mask, head_mask[i])
287
+ else:
288
+ outputs = block(
289
+ hidden_states,
290
+ layer_past=layer_past,
291
+ attention_mask=causal_mask,
292
+ head_mask=head_mask[i],
293
+ use_cache=use_cache,
294
+ output_attentions=output_attentions,
295
+ alibi=alibi,
296
+ )
297
+ hidden_states = outputs[0]
298
+ if use_cache is True:
299
+ presents = presents + (outputs[1],)
300
+ if output_attentions:
301
+ oa = (outputs[2 if use_cache else 1],)
302
+ all_self_attentions = all_self_attentions + oa
303
+ hidden_states = self.ln_f(hidden_states)
304
+ if output_hidden_states:
305
+ hst = (hidden_states,)
306
+ all_hidden_states = all_hidden_states + hst
307
+ if not return_dict:
308
+ return tuple((v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None))
309
+ return BaseModelOutputWithPastAndCrossAttentions(
310
+ last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions
311
+ )
312
+
313
+ setattr(model.transformer, "_prepare_attn_mask", MethodType(_prepare_attn_mask, model.transformer))
314
+ setattr(model.transformer, "_build_alibi_tensor", MethodType(_build_alibi_tensor, model.transformer))
315
+ setattr(model.transformer, "forward", MethodType(forward, model.transformer))
316
+ KeyValueT = Tuple[torch.Tensor, torch.Tensor]
317
+
318
+ def forward(
319
+ self: BloomForCausalLM,
320
+ input_ids: Optional[torch.LongTensor] = None,
321
+ past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
322
+ attention_mask: Optional[torch.Tensor] = None,
323
+ bidirectional_mask: Optional[torch.Tensor] = None,
324
+ head_mask: Optional[torch.Tensor] = None,
325
+ inputs_embeds: Optional[torch.Tensor] = None,
326
+ labels: Optional[torch.Tensor] = None,
327
+ use_cache: Optional[bool] = None,
328
+ output_attentions: Optional[bool] = None,
329
+ output_hidden_states: Optional[bool] = None,
330
+ return_dict: Optional[bool] = None,
331
+ **deprecated_arguments,
332
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
333
+ """Replacement forward method for BloomCausalLM."""
334
+ if deprecated_arguments.pop("position_ids", False) is not False:
335
+ warnings.warn(
336
+ "`position_ids` have no functionality in BLOOM and will be removed " + "in v5.0.0. You can safely ignore passing `position_ids`.", FutureWarning
337
+ )
338
+ if len(deprecated_arguments) > 0:
339
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
340
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
341
+ transformer_outputs = self.transformer(
342
+ input_ids,
343
+ past_key_values=past_key_values,
344
+ attention_mask=attention_mask,
345
+ bidirectional_mask=bidirectional_mask,
346
+ head_mask=head_mask,
347
+ inputs_embeds=inputs_embeds,
348
+ use_cache=use_cache,
349
+ output_attentions=output_attentions,
350
+ output_hidden_states=output_hidden_states,
351
+ return_dict=return_dict,
352
+ )
353
+ hidden_states = transformer_outputs[0]
354
+ lm_logits = self.lm_head(hidden_states)
355
+ loss = None
356
+ if labels is not None:
357
+ shift_logits = lm_logits[..., :-1, :].contiguous()
358
+ shift_labels = labels[..., 1:].contiguous()
359
+ (batch_size, seq_length, vocab_size) = shift_logits.shape
360
+ loss_fct = CrossEntropyLoss()
361
+ loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length))
362
+ if not return_dict:
363
+ output = (lm_logits,) + transformer_outputs[1:]
364
+ return (loss,) + output if loss is not None else output
365
+ return CausalLMOutputWithCrossAttentions(
366
+ loss=loss,
367
+ logits=lm_logits,
368
+ past_key_values=transformer_outputs.past_key_values,
369
+ hidden_states=transformer_outputs.hidden_states,
370
+ attentions=transformer_outputs.attentions,
371
+ )
372
+
373
+ def prepare_inputs_for_generation(
374
+ self: BloomForCausalLM, input_ids: torch.LongTensor, past: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, **kwargs
375
+ ) -> dict:
376
+ if past:
377
+ input_ids = input_ids[:, -1].unsqueeze(-1)
378
+ bidirectional_mask = None
379
+ if past[0][0].shape[0] == input_ids.shape[0]:
380
+ past = self._convert_to_bloom_cache(past)
381
+ else:
382
+ bidirectional_mask = torch.ones_like(input_ids)
383
+ return {"input_ids": input_ids, "past_key_values": past, "use_cache": True, "attention_mask": attention_mask, "bidirectional_mask": bidirectional_mask}
384
+
385
+ setattr(model, "forward", MethodType(forward, model))
386
+ setattr(model, "prepare_inputs_for_generation", MethodType(prepare_inputs_for_generation, model))
387
+ setattr(model, "_prefix_lm_converted", True)
388
+ return model
389
+
390
+
391
+ def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM:
392
+ """Converts an OPT Causal LM to a Prefix LM.
393
+
394
+ Supported HuggingFace model classes:
395
+ - `OPTForCausalLM`
396
+
397
+ See `convert_hf_causal_lm_to_prefix_lm` for more details.
398
+ """
399
+ if hasattr(model, "_prefix_lm_converted"):
400
+ return model
401
+ assert isinstance(model, OPTForCausalLM)
402
+ assert model.config.add_cross_attention == False, "Only supports OPT decoder-only models"
403
+ setattr(model, "_original_forward", getattr(model, "forward"))
404
+ setattr(model, "_original_generate", getattr(model, "generate"))
405
+ model.model.decoder.bidirectional_mask = None
406
+
407
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
408
+ combined_attention_mask = None
409
+ if input_shape[-1] > 1:
410
+ if self.bidirectional_mask == "g":
411
+ (bsz, src_length) = input_shape
412
+ combined_attention_mask = torch.zeros(
413
+ (bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device
414
+ )
415
+ else:
416
+ combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(
417
+ inputs_embeds.device
418
+ )
419
+ if self.bidirectional_mask is not None:
420
+ assert attention_mask.shape == self.bidirectional_mask.shape
421
+ expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
422
+ inputs_embeds.device
423
+ )
424
+ combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask)
425
+ if attention_mask is not None:
426
+ expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
427
+ combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
428
+ return combined_attention_mask
429
+
430
+ setattr(model.model.decoder, "_prepare_decoder_attention_mask", MethodType(_prepare_decoder_attention_mask, model.model.decoder))
431
+
432
+ def forward(
433
+ self: OPTForCausalLM,
434
+ input_ids: Optional[torch.LongTensor] = None,
435
+ attention_mask: Optional[torch.Tensor] = None,
436
+ bidirectional_mask: Optional[torch.ByteTensor] = None,
437
+ head_mask: Optional[torch.Tensor] = None,
438
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
439
+ inputs_embeds: Optional[torch.FloatTensor] = None,
440
+ labels: Optional[torch.LongTensor] = None,
441
+ use_cache: Optional[bool] = None,
442
+ output_attentions: Optional[bool] = None,
443
+ output_hidden_states: Optional[bool] = None,
444
+ return_dict: Optional[bool] = None,
445
+ ):
446
+ def call_og_forward():
447
+ return self._original_forward(
448
+ input_ids=input_ids,
449
+ attention_mask=attention_mask,
450
+ head_mask=head_mask,
451
+ past_key_values=past_key_values,
452
+ inputs_embeds=inputs_embeds,
453
+ labels=labels,
454
+ use_cache=use_cache,
455
+ output_attentions=output_attentions,
456
+ output_hidden_states=output_hidden_states,
457
+ return_dict=return_dict,
458
+ )
459
+
460
+ if bidirectional_mask is None:
461
+ return call_og_forward()
462
+ self.model.decoder.bidirectional_mask = bidirectional_mask
463
+ try:
464
+ outputs = call_og_forward()
465
+ except:
466
+ self.model.decoder.bidirectional_mask = None
467
+ raise
468
+ self.model.decoder.bidirectional_mask = None
469
+ return outputs
470
+
471
+ def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[str, Any]):
472
+ """Wraps original generate to enable PrefixLM-style attention."""
473
+ self.model.decoder.bidirectional_mask = "g"
474
+ try:
475
+ output = self._original_generate(*args, **kwargs)
476
+ except:
477
+ self.model.decoder.bidirectional_mask = None
478
+ raise
479
+ self.model.decoder.bidirectional_mask = None
480
+ return output
481
+
482
+ setattr(model, "forward", MethodType(forward, model))
483
+ setattr(model, "generate", MethodType(generate, model))
484
+ setattr(model, "_prefix_lm_converted", True)
485
+ return model
486
+
487
+
488
+ _SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM)
489
+ CAUSAL_LM_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM, BloomForCausalLM, OPTForCausalLM]
490
+
491
+
492
+ def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
493
+ """Converts a HuggingFace Causal LM to a Prefix LM.
494
+
495
+ Supported HuggingFace model classes:
496
+ - `GPT2LMHeadModel`
497
+ - `GPTNeoForCausalLM`
498
+ - `GPTNeoXForCausalLM`
499
+ - `GPTJForCausalLM`
500
+ - `BloomForCausalLM`
501
+ - `OPTForCausalLM`
502
+
503
+ Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the
504
+ `generate` method and/or select underlying methods depending on the model class.
505
+
506
+ These changes preserve the model API, but add a new input to `forward`: "bidirectional_mask".
507
+
508
+ Notes on training:
509
+ To actually train the converted model as a Prefix LM, training batches will need to indicate
510
+ the prefix/target structure by including `bidirectional_mask` as part of the batch inputs.
511
+
512
+ **This is not a standard input and requires custom layers either within or after your dataloader.**
513
+
514
+ In addition to adding `bidirectional_mask` to the batch, this custom code should modify `labels`
515
+ such that `batch['labels'][batch['bidirectional_mask'] == 1] == -100`.
516
+ That is, the prefix portion of the sequence should not generate any loss. Loss should only be
517
+ generated by the target portion of the sequence.
518
+
519
+ Notes on `GPTNeoForCausalLM`:
520
+ To simplify the implementation, "global" and "local" attention layers are handled differently.
521
+ For "global" layers, we handle conversion as described above. For "local" layers, which use a
522
+ causal attention mask within a restricted local window, we do not alter the masking.
523
+
524
+ Notes on `forward` method conversion:
525
+ After conversion, the `forward` method will handle a new input, `bidirectional_mask`,
526
+ which should be a [batch_size, seq_length] byte tensor, where 1 indicates token positions
527
+ belonging to the prefix (prefix tokens can attend to one another bidirectionally), and
528
+ 0 indicates token positions belonging to the target.
529
+
530
+ The new `forward` method will incorporate `bidirectional_mask` (if supplied) into the existing
531
+ causal mask, call the original `forward` method, and (if the causal mask is a buffer) reset
532
+ the causal masks before returning the result.
533
+
534
+ Notes on `generate` method conversion:
535
+ After conversion, the `generate` method will have the same signature but will internally
536
+ convert all causal masks to be purely bidirectional, call the original `generate` method, and
537
+ (where appropriate) reset the causal masks before returning the result.
538
+
539
+ This works thanks to the logic of the HuggingFace `generate` API, which first encodes the token
540
+ "prompt" passed to `generate` (which is treated as the prefix) and then sequentially generates
541
+ each new token. Encodings are cached as generation happens, so all prefix tokens can attend to one
542
+ another (as expected in a Prefix LM) and generated tokens can only attend to prefix tokens and
543
+ previously-generated tokens (also as expected in a Prefix LM).
544
+
545
+ To preserve the API, the original methods are renamed to `_original_forward` and
546
+ `_original_generate`, and replaced with new `forward` and `generate` methods that wrap
547
+ them, respectively, although implementation details vary by model class.
548
+ """
549
+ if isinstance(model, _SUPPORTED_GPT_MODELS):
550
+ return _convert_gpt_causal_lm_to_prefix_lm(model)
551
+ elif isinstance(model, BloomForCausalLM):
552
+ return _convert_bloom_causal_lm_to_prefix_lm(model)
553
+ elif isinstance(model, OPTForCausalLM):
554
+ return _convert_opt_causal_lm_to_prefix_lm(model)
555
+ else:
556
+ raise TypeError(f"Cannot convert model to Prefix LM. " + f"Model does not belong to set of supported HF models:" + f"\n{_SUPPORTED_HF_MODELS}")
557
+
558
+
559
+ def add_bidirectional_mask_if_missing(batch: Dict[str, Any]):
560
+ """Attempts to add bidirectional_mask to batch if missing.
561
+
562
+ Raises:
563
+ KeyError if bidirectional_mask is missing and can't be inferred
564
+ """
565
+ if "bidirectional_mask" not in batch:
566
+ if batch.get("mode", None) == "icl_task":
567
+ batch["bidirectional_mask"] = batch["attention_mask"].clone()
568
+ for i, continuation_indices in enumerate(batch["continuation_indices"]):
569
+ batch["bidirectional_mask"][i, continuation_indices] = 0
570
+ elif "labels" in batch and "attention_mask" in batch:
571
+ batch["bidirectional_mask"] = torch.logical_and(torch.eq(batch["attention_mask"], 1), torch.eq(batch["labels"], -100)).type_as(
572
+ batch["attention_mask"]
573
+ )
574
+ else:
575
+ raise KeyError("No bidirectional_mask in batch and not sure how to construct one.")
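Taken together, the conversion entry point and the mask helper above are meant to be used roughly as follows. This is a sketch only (the import path, model name, and prefix length are illustrative), following the `bidirectional_mask` and label-masking rules described in the docstring of `convert_hf_causal_lm_to_prefix_lm`:

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from hf_prefixlm_converter import convert_hf_causal_lm_to_prefix_lm  # assumed import path

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = convert_hf_causal_lm_to_prefix_lm(GPT2LMHeadModel.from_pretrained("gpt2"))

enc = tokenizer("Translate to French: hello world", return_tensors="pt")
input_ids, attention_mask = enc["input_ids"], enc["attention_mask"]

# 1 marks prefix positions (bidirectional attention), 0 marks target positions.
bidirectional_mask = torch.zeros_like(input_ids)
bidirectional_mask[:, :4] = 1  # treat the first four tokens as the prefix (illustrative)

# The prefix should not generate loss, as the docstring above requires.
labels = input_ids.clone()
labels[bidirectional_mask == 1] = -100

out = model(input_ids=input_ids, attention_mask=attention_mask,
            bidirectional_mask=bidirectional_mask, labels=labels)
print(out.loss)
```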
mllm/flamingo/mpt/meta_init_context.py ADDED
@@ -0,0 +1,98 @@
1
+ from contextlib import contextmanager
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ @contextmanager
7
+ def init_empty_weights(include_buffers: bool = False):
8
+ """Meta initialization context manager.
9
+
10
+ A context manager under which models are initialized with all parameters
11
+ on the meta device, therefore creating an empty model. Useful when just
12
+ initializing the model would blow the available RAM.
13
+
14
+ Args:
15
+ include_buffers (`bool`, *optional*, defaults to `False`): Whether or
16
+ not to also put all buffers on the meta device while initializing.
17
+
18
+ Example:
19
+ ```python
20
+ import torch.nn as nn
21
+
22
+ # Initialize a model with 100 billions parameters in no time and without using any RAM.
23
+ with init_empty_weights():
24
+ tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
25
+ ```
26
+
27
+ <Tip warning={true}>
28
+
29
+ Any model created under this context manager has no weights. As such you can't do something like
30
+ `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
31
+
32
+ </Tip>
33
+ """
34
+ with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
35
+ yield f
36
+
37
+
38
+ @contextmanager
39
+ def init_on_device(device: torch.device, include_buffers: bool = False):
40
+ """Device initialization context manager.
41
+
42
+ A context manager under which models are initialized with all parameters
43
+ on the specified device.
44
+
45
+ Args:
46
+ device (`torch.device`): Device to initialize all parameters on.
47
+ include_buffers (`bool`, *optional*, defaults to `False`): Whether or
48
+ not to also put all buffers on the meta device while initializing.
49
+
50
+ Example:
51
+ ```python
52
+ import torch.nn as nn
53
+
54
+ with init_on_device(device=torch.device("cuda")):
55
+ tst = nn.Liner(100, 100) # on `cuda` device
56
+ ```
57
+ """
58
+ old_register_parameter = nn.Module.register_parameter
59
+ if include_buffers:
60
+ old_register_buffer = nn.Module.register_buffer
61
+
62
+ def register_empty_parameter(module, name, param):
63
+ old_register_parameter(module, name, param)
64
+ if param is not None:
65
+ param_cls = type(module._parameters[name])
66
+ kwargs = module._parameters[name].__dict__
67
+ module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
68
+
69
+ def register_empty_buffer(module, name, buffer):
70
+ old_register_buffer(module, name, buffer)
71
+ if buffer is not None:
72
+ module._buffers[name] = module._buffers[name].to(device)
73
+
74
+ if include_buffers:
75
+ tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ["empty", "zeros", "ones", "full"]}
76
+ else:
77
+ tensor_constructors_to_patch = {}
78
+
79
+ def patch_tensor_constructor(fn):
80
+ def wrapper(*args, **kwargs):
81
+ kwargs["device"] = device
82
+ return fn(*args, **kwargs)
83
+
84
+ return wrapper
85
+
86
+ try:
87
+ nn.Module.register_parameter = register_empty_parameter
88
+ if include_buffers:
89
+ nn.Module.register_buffer = register_empty_buffer
90
+ for torch_function_name in tensor_constructors_to_patch.keys():
91
+ setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
92
+ yield
93
+ finally:
94
+ nn.Module.register_parameter = old_register_parameter
95
+ if include_buffers:
96
+ nn.Module.register_buffer = old_register_buffer
97
+ for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
98
+ setattr(torch, torch_function_name, old_torch_function)
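Both context managers are exercised by the docstring examples above; the `include_buffers` path, which also patches `register_buffer` and the tensor constructors, is not, so here is a small sketch of it (import path is an assumption):

```python
import torch
import torch.nn as nn
from meta_init_context import init_on_device  # assumed import path

# With include_buffers=True, buffers (e.g. BatchNorm running stats) and the patched
# torch.empty/zeros/ones/full constructors also land on the target device.
with init_on_device(torch.device("meta"), include_buffers=True):
    bn = nn.BatchNorm1d(16)

print(bn.weight.device, bn.running_mean.device)  # both report the meta device
```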
mllm/flamingo/mpt/modeling_mpt.py ADDED
@@ -0,0 +1,496 @@
1
+ """A simple, flexible implementation of a GPT model.
2
+
3
+ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
4
+ """
5
+ import math
6
+ import warnings
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+
15
+ from .attention import attn_bias_shape, build_attn_bias
16
+ from .blocks import MPTBlock
17
+ from .configuration_mpt import MPTConfig
18
+ from .custom_embedding import SharedEmbedding
19
+ from .norm import NORM_CLASS_REGISTRY
20
+ from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
21
+
22
+ import torch.distributed as dist
23
+
24
+ try:
25
+ from .flash_attn_triton import flash_attn_func
26
+ except:
27
+ pass
28
+ Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
29
+
30
+
31
+ class MPTPreTrainedModel(PreTrainedModel):
32
+ config_class = MPTConfig
33
+ base_model_prefix = "model"
34
+ _no_split_modules = ["MPTBlock"]
35
+
36
+
37
+ class MPTModel(MPTPreTrainedModel):
38
+ def __init__(self, config: MPTConfig):
39
+ config._validate_config()
40
+ super().__init__(config)
41
+ self.attn_impl = config.attn_config["attn_impl"]
42
+ self.prefix_lm = config.attn_config["prefix_lm"]
43
+ self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
44
+ self.alibi = config.attn_config["alibi"]
45
+ self.alibi_bias_max = config.attn_config["alibi_bias_max"]
46
+ if config.init_device == "mixed":
47
+ if dist.get_local_rank() == 0:
48
+ config.init_device = "cpu"
49
+ else:
50
+ config.init_device = "meta"
51
+ if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
52
+ norm_options = " | ".join(NORM_CLASS_REGISTRY.keys())
53
+ raise NotImplementedError(f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).")
54
+ norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
55
+ self.embedding_fraction = config.embedding_fraction
56
+ self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
57
+ if not self.alibi:
58
+ self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
59
+ self.emb_drop = nn.Dropout(config.emb_pdrop)
60
+ self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
61
+ self.norm_f = norm_class(config.d_model, device=config.init_device)
62
+ if config.init_device != "meta":
63
+ print(
64
+ f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.'
65
+ )
66
+ self.apply(self.param_init_fn)
67
+ self.is_causal = not self.prefix_lm
68
+ self._attn_bias_initialized = False
69
+ self.attn_bias = None
70
+ self.attn_bias_shape = attn_bias_shape(
71
+ self.attn_impl,
72
+ config.n_heads,
73
+ config.max_seq_len,
74
+ self.alibi,
75
+ prefix_lm=self.prefix_lm,
76
+ causal=self.is_causal,
77
+ use_sequence_id=self.attn_uses_sequence_id,
78
+ )
79
+ if config.no_bias:
80
+ for module in self.modules():
81
+ if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
82
+ if config.verbose:
83
+ warnings.warn(f"Removing bias ({module.bias}) from {module}.")
84
+ module.register_parameter("bias", None)
85
+ if config.verbose and config.verbose > 2:
86
+ print(self)
87
+ if "verbose" not in self.config.init_config:
88
+ self.config.init_config["verbose"] = self.config.verbose
89
+ if self.config.init_config["verbose"] > 1:
90
+ init_fn_name = self.config.init_config["name"]
91
+ warnings.warn(f"Using {init_fn_name} initialization.")
92
+
93
+ def get_input_embeddings(self):
94
+ return self.wte
95
+
96
+ def set_input_embeddings(self, value):
97
+ self.wte = value
98
+
99
+ @torch.no_grad()
100
+ def _attn_bias(
101
+ self,
102
+ device,
103
+ dtype,
104
+ attention_mask: Optional[torch.ByteTensor] = None,
105
+ prefix_mask: Optional[torch.ByteTensor] = None,
106
+ sequence_id: Optional[torch.LongTensor] = None,
107
+ ):
108
+ if not self._attn_bias_initialized:
109
+ if self.attn_bias_shape:
110
+ self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
111
+ self.attn_bias = build_attn_bias(
112
+ self.attn_impl,
113
+ self.attn_bias,
114
+ self.config.n_heads,
115
+ self.config.max_seq_len,
116
+ causal=self.is_causal,
117
+ alibi=self.alibi,
118
+ alibi_bias_max=self.alibi_bias_max,
119
+ )
120
+ self._attn_bias_initialized = True
121
+ if self.attn_impl == "flash":
122
+ return (self.attn_bias, attention_mask)
123
+ if self.attn_bias is not None:
124
+ self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
125
+ attn_bias = self.attn_bias
126
+ if self.prefix_lm:
127
+ assert isinstance(attn_bias, torch.Tensor)
128
+ assert isinstance(prefix_mask, torch.Tensor)
129
+ attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
130
+ if self.attn_uses_sequence_id and sequence_id is not None:
131
+ assert isinstance(attn_bias, torch.Tensor)
132
+ attn_bias = self._apply_sequence_id(attn_bias, sequence_id)
133
+ if attention_mask is not None:
134
+ s_k = attention_mask.shape[-1]
135
+ if attn_bias is None:
136
+ attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
137
+ else:
138
+ _s_k = max(0, attn_bias.size(-1) - s_k)
139
+ attn_bias = attn_bias[:, :, :, _s_k:]
140
+ if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
141
+ raise ValueError(f"attention_mask shape={attention_mask.shape} " + f"and prefix_mask shape={prefix_mask.shape} are not equal.")
142
+ min_val = torch.finfo(attn_bias.dtype).min
143
+ attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)
144
+ return (attn_bias, None)
145
+
146
+ def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
147
+ (s_k, s_q) = attn_bias.shape[-2:]
148
+ if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
149
+ raise ValueError(
150
+ "attn_bias does not match the expected shape. "
151
+ + f"The last two dimensions should both be {self.config.max_length} "
152
+ + f"but are {s_k} and {s_q}."
153
+ )
154
+ seq_len = prefix_mask.shape[-1]
155
+ if seq_len > self.config.max_seq_len:
156
+ raise ValueError(f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}")
157
+ attn_bias = attn_bias[..., :seq_len, :seq_len]
158
+ causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
159
+ prefix = prefix_mask.view(-1, 1, 1, seq_len)
160
+ cannot_attend = ~torch.logical_or(causal, prefix.bool())
161
+ min_val = torch.finfo(attn_bias.dtype).min
162
+ attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
163
+ return attn_bias
164
+
165
+ def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor):
166
+ seq_len = sequence_id.shape[-1]
167
+ if seq_len > self.config.max_seq_len:
168
+ raise ValueError(f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}")
169
+ attn_bias = attn_bias[..., :seq_len, :seq_len]
170
+ cannot_attend = torch.logical_not(torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))).unsqueeze(1)
171
+ min_val = torch.finfo(attn_bias.dtype).min
172
+ attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
173
+ return attn_bias
174
+
175
+ def forward(
176
+ self,
177
+ input_ids: torch.LongTensor,
178
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
179
+ attention_mask: Optional[torch.ByteTensor] = None,
180
+ prefix_mask: Optional[torch.ByteTensor] = None,
181
+ sequence_id: Optional[torch.LongTensor] = None,
182
+ return_dict: Optional[bool] = None,
183
+ output_attentions: Optional[bool] = None,
184
+ output_hidden_states: Optional[bool] = None,
185
+ use_cache: Optional[bool] = None,
186
+ inputs_embeds: Optional[torch.Tensor] = None,
187
+ ):
188
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
189
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
190
+
191
+ if attention_mask is not None:
192
+ attention_mask = attention_mask.bool()
193
+
194
+ if prefix_mask is not None:
195
+ prefix_mask = prefix_mask.bool()
196
+
197
+ # These args are passed in by keyword in huggingface's generate function
198
+ # https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206
199
+ # but have not yet been fully implemented in MPTModel
200
+ if not return_dict:
201
+ raise NotImplementedError("return_dict False is not implemented yet for MPT")
202
+ if output_attentions:
203
+ if self.attn_impl != "torch":
204
+ raise NotImplementedError("output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.")
205
+
206
+ if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
207
+ raise NotImplementedError("MPT does not support training with left padding.")
208
+
209
+ if self.prefix_lm and prefix_mask is None:
210
+ raise ValueError("prefix_mask is a required argument when MPT is configured with prefix_lm=True.")
211
+
212
+ # Raise a not implemented error if input_embeds is not None (this is an arg in huggingface transformers and we need to support it for PEFT)
213
+ if inputs_embeds is not None:
214
+ raise NotImplementedError("inputs_embeds is not implemented for MPT.")
215
+
216
+ if self.training:
217
+ if self.attn_uses_sequence_id and sequence_id is None:
218
+ raise ValueError(
219
+ "sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True " + "and the model is in train mode."
220
+ )
221
+ elif (self.attn_uses_sequence_id is False) and (sequence_id is not None):
222
+ warnings.warn(
223
+ "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. "
224
+ + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True."
225
+ )
226
+
227
+ S = input_ids.size(1)
228
+
229
+ assert S <= self.config.max_seq_len, f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}"
230
+
231
+ tok_emb = self.wte(input_ids) # type: ignore
232
+ if self.alibi:
233
+ x = tok_emb
234
+ else:
235
+ past_position = 0
236
+ if past_key_values is not None:
237
+ if len(past_key_values) != self.config.n_layers:
238
+ raise ValueError(
239
+ f"past_key_values must provide a past_key_value for each attention "
240
+ + f"layer in the network ({len(past_key_values)=}; {self.config.n_layers=})."
241
+ )
242
+ # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
243
+ # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
244
+ # Here we shift position embedding using the `seq` dim of the past key
245
+ past_position = past_key_values[0][0].size(1)
246
+ if self.attn_impl == "torch":
247
+ past_position = past_key_values[0][0].size(3)
248
+
249
+ if S + past_position > self.config.max_seq_len:
250
+ raise ValueError(
251
+ f"Cannot forward input with past sequence length {past_position} and current sequence length "
252
+ f"{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}."
253
+ )
254
+ pos = torch.arange(
255
+ past_position,
256
+ S + past_position,
257
+ dtype=torch.long,
258
+ device=input_ids.device,
259
+ ).unsqueeze(0)
260
+ if attention_mask is not None:
261
+ # adjust the position indices to account for padding tokens
262
+ pos = torch.clamp(
263
+ pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:],
264
+ min=0,
265
+ )
266
+
267
+ pos_emb = self.wpe(pos) # type: ignore
268
+ x = tok_emb + pos_emb
269
+
270
+ if self.embedding_fraction == 1:
271
+ x = self.emb_drop(x) # type: ignore
272
+ else:
273
+ # this implementation is proposed on page 7 of the GLM-130B paper https://arxiv.org/abs/2210.02414
274
+ x_shrunk = (x * self.embedding_fraction) + (x.detach() * (1 - self.embedding_fraction))
275
+ assert isinstance(self.emb_drop, nn.Module) # pyright
276
+ x = self.emb_drop(x_shrunk)
277
+
278
+ attn_bias, attention_mask = self._attn_bias(
279
+ device=x.device,
280
+ dtype=torch.float32,
281
+ attention_mask=attention_mask,
282
+ prefix_mask=prefix_mask,
283
+ sequence_id=sequence_id,
284
+ )
285
+
286
+ # initialize the past key values cache if it should be used
287
+ if use_cache and past_key_values is None:
288
+ past_key_values = [() for _ in range(self.config.n_layers)] # type: ignore
289
+
290
+ all_hidden_states = () if output_hidden_states else None
291
+ all_self_attns = () if output_attentions else None
292
+ for b_idx, block in enumerate(self.blocks): # type: ignore
293
+ if output_hidden_states:
294
+ assert all_hidden_states is not None # pyright
295
+ all_hidden_states = all_hidden_states + (x,)
296
+ past_key_value = past_key_values[b_idx] if past_key_values is not None else None
297
+ x, attn_weights, past_key_value = block(
298
+ x,
299
+ past_key_value=past_key_value,
300
+ attn_bias=attn_bias,
301
+ attention_mask=attention_mask,
302
+ is_causal=self.is_causal,
303
+ )
304
+ if past_key_values is not None:
305
+ past_key_values[b_idx] = past_key_value
306
+
307
+ if output_attentions:
308
+ assert all_self_attns is not None # pyright
309
+ all_self_attns = all_self_attns + (attn_weights,)
310
+
311
+ x = self.norm_f(x) # type: ignore
312
+
313
+ # add hidden states from the last decoder layer
314
+ if output_hidden_states:
315
+ assert all_hidden_states is not None # pyright
316
+ all_hidden_states = all_hidden_states + (x,)
317
+
318
+ return BaseModelOutputWithPast(
319
+ last_hidden_state=x,
320
+ past_key_values=past_key_values,
321
+ hidden_states=all_hidden_states,
322
+ attentions=all_self_attns,
323
+ )
324
+
325
+ # Param Initialization, needed for device='meta' fast initialization
326
+ def param_init_fn(self, module):
327
+ init_fn_name = self.config.init_config["name"]
328
+ MODEL_INIT_REGISTRY[init_fn_name](
329
+ module=module,
330
+ n_layers=self.config.n_layers,
331
+ d_model=self.config.d_model,
332
+ **self.config.init_config,
333
+ )
334
+
335
+ def fsdp_wrap_fn(self, module):
336
+ return isinstance(module, MPTBlock)
337
+
338
+ def activation_checkpointing_fn(self, module):
339
+ return isinstance(module, MPTBlock)
340
+
341
+
342
+ class MPTForCausalLM(MPTPreTrainedModel):
343
+ def __init__(self, config: MPTConfig):
344
+ super().__init__(config)
345
+ if not config.tie_word_embeddings:
346
+ raise ValueError("MPTForCausalLM only supports tied word embeddings")
347
+ self.transformer = MPTModel(config)
348
+ for child in self.transformer.children():
349
+ if isinstance(child, torch.nn.ModuleList):
350
+ continue
351
+ if isinstance(child, torch.nn.Module):
352
+ child._fsdp_wrap = True
353
+ self.logit_scale = None
354
+ if config.logit_scale is not None:
355
+ logit_scale = config.logit_scale
356
+ if isinstance(logit_scale, str):
357
+ if logit_scale == "inv_sqrt_d_model":
358
+ logit_scale = 1 / math.sqrt(config.d_model)
359
+ else:
360
+ raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
361
+ self.logit_scale = logit_scale
362
+
363
+ def get_input_embeddings(self):
364
+ return self.transformer.wte
365
+
366
+ def set_input_embeddings(self, value):
367
+ # self.transformer.wte = value
368
+ pseudo_wte = SharedEmbedding(value.weight.shape[0], value.weight.shape[1], device=self.transformer.wte.weight.device)
369
+ pseudo_wte.weight = value.weight
370
+ self.transformer.wte = pseudo_wte
371
+
372
+ def get_output_embeddings(self):
373
+ return self.transformer.wte
374
+
375
+ def set_output_embeddings(self, new_embeddings):
376
+ # self.transformer.wte = new_embeddings
377
+ pseudo_wte = SharedEmbedding(new_embeddings.weight.shape[0], new_embeddings.weight.shape[1], device=self.transformer.wte.weight.device)
378
+ pseudo_wte.weight = new_embeddings.weight
379
+ self.transformer.wte = pseudo_wte
380
+
381
+ def set_decoder(self, decoder):
382
+ self.transformer = decoder
383
+
384
+ def get_decoder(self):
385
+ return self.transformer
386
+
387
+ def forward(
388
+ self,
389
+ input_ids: torch.LongTensor,
390
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
391
+ attention_mask: Optional[torch.ByteTensor] = None,
392
+ prefix_mask: Optional[torch.ByteTensor] = None,
393
+ sequence_id: Optional[torch.LongTensor] = None,
394
+ labels: Optional[torch.LongTensor] = None,
395
+ return_dict: Optional[bool] = None,
396
+ output_attentions: Optional[bool] = None,
397
+ output_hidden_states: Optional[bool] = None,
398
+ use_cache: Optional[bool] = None,
399
+ inputs_embeds: Optional[torch.FloatTensor] = None,
400
+ ):
401
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
402
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
403
+
404
+ # if input_embeds is not none, raise a not implemented error
405
+ if inputs_embeds is not None:
406
+ raise NotImplementedError("inputs_embeds has to be None (for hf/peft support).")
407
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
408
+ outputs = self.transformer(
409
+ input_ids=input_ids,
410
+ past_key_values=past_key_values,
411
+ attention_mask=attention_mask,
412
+ prefix_mask=prefix_mask,
413
+ sequence_id=sequence_id,
414
+ return_dict=return_dict,
415
+ output_attentions=output_attentions,
416
+ output_hidden_states=output_hidden_states,
417
+ use_cache=use_cache,
418
+ )
419
+
420
+ # move outputs to same device as weights for token embedding
421
+ # needed to support HF `device_map`
422
+ logits = self.transformer.wte(
423
+ input=outputs.last_hidden_state.to(self.transformer.wte.weight.device),
424
+ unembed=True,
425
+ )
426
+
427
+ if self.logit_scale is not None:
428
+ if self.logit_scale == 0:
429
+ warnings.warn(f"Multiplying logits by {self.logit_scale=}. This will produce uniform (uninformative) outputs.")
430
+ logits *= self.logit_scale
431
+
432
+ loss = None
433
+ if labels is not None:
434
+ _labels = torch.roll(labels, shifts=-1)
435
+ _labels[:, -1] = -100
436
+ loss = F.cross_entropy(
437
+ logits.view(-1, logits.size(-1)),
438
+ _labels.to(logits.device).view(-1),
439
+ )
440
+
441
+ return CausalLMOutputWithPast(
442
+ loss=loss,
443
+ logits=logits,
444
+ past_key_values=outputs.past_key_values,
445
+ hidden_states=outputs.hidden_states,
446
+ attentions=outputs.attentions,
447
+ )
448
+
449
+ def param_init_fn(self, module):
450
+ init_fn_name = self.config.init_config["name"]
451
+ MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
452
+
453
+ def fsdp_wrap_fn(self, module):
454
+ return isinstance(module, MPTBlock)
455
+
456
+ def activation_checkpointing_fn(self, module):
457
+ return isinstance(module, MPTBlock)
458
+
459
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None, **kwargs):
460
+ if inputs_embeds is not None:
461
+ raise NotImplementedError("inputs_embeds is not implemented for MPT yet")
462
+ attention_mask = attention_mask.bool()
463
+ if attention_mask[:, -1].sum() != attention_mask.shape[0]:
464
+ raise NotImplementedError("MPT does not support generation with right padding.")
465
+ if self.transformer.attn_uses_sequence_id and self.training:
466
+ sequence_id = torch.zeros_like(input_ids[:1])
467
+ else:
468
+ sequence_id = None
469
+ if past_key_values is not None:
470
+ input_ids = input_ids[:, -1].unsqueeze(-1)
471
+ if self.transformer.prefix_lm:
472
+ prefix_mask = torch.ones_like(attention_mask)
473
+ if kwargs.get("use_cache") == False:
474
+ raise NotImplementedError("MPT with prefix_lm=True does not support use_cache=False.")
475
+ else:
476
+ prefix_mask = None
477
+ return {
478
+ "input_ids": input_ids,
479
+ "attention_mask": attention_mask,
480
+ "prefix_mask": prefix_mask,
481
+ "sequence_id": sequence_id,
482
+ "past_key_values": past_key_values,
483
+ "use_cache": kwargs.get("use_cache", True),
484
+ }
485
+
486
+ @staticmethod
487
+ def _reorder_cache(past_key_values, beam_idx):
488
+ """Used by HuggingFace generate when using beam search with kv-caching.
489
+
490
+ See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
491
+ for an example in transformers.
492
+ """
493
+ reordered_past = []
494
+ for layer_past in past_key_values:
495
+ reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
496
+ return reordered_past
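For reference, a minimal sketch (with made-up toy values, not taken from the diff) of the next-token loss computed in forward() above: the labels are rolled left by one position, the final slot is masked with -100, and F.cross_entropy ignores that index by default.

import torch
import torch.nn.functional as F

# Toy shapes for illustration only: batch of 1, sequence of 4 tokens, vocabulary of 5.
logits = torch.randn(1, 4, 5)
labels = torch.tensor([[2, 3, 1, 4]])

# Shift labels left so position t is supervised by token t+1, then mask the last position.
shifted = torch.roll(labels, shifts=-1)
shifted[:, -1] = -100  # -100 is cross_entropy's default ignore_index

loss = F.cross_entropy(logits.view(-1, logits.size(-1)), shifted.view(-1))
print(loss.item())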
mllm/flamingo/mpt/norm.py ADDED
@@ -0,0 +1,60 @@
1
+ import torch
2
+
3
+
4
+ def _cast_if_autocast_enabled(tensor):
5
+ if torch.is_autocast_enabled():
6
+ if tensor.device.type == "cuda":
7
+ dtype = torch.get_autocast_gpu_dtype()
8
+ elif tensor.device.type == "cpu":
9
+ dtype = torch.get_autocast_cpu_dtype()
10
+ else:
11
+ raise NotImplementedError()
12
+ return tensor.to(dtype=dtype)
13
+ return tensor
14
+
15
+
16
+ class LPLayerNorm(torch.nn.LayerNorm):
17
+ def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
18
+ super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
19
+
20
+ def forward(self, x):
21
+ module_device = x.device
22
+ downcast_x = _cast_if_autocast_enabled(x)
23
+ downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
24
+ downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
25
+ with torch.autocast(enabled=False, device_type=module_device.type):
26
+ return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
27
+
28
+
29
+ def rms_norm(x, weight=None, eps=1e-05):
30
+ output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
31
+ if weight is not None:
32
+ return output * weight
33
+ return output
34
+
35
+
36
+ class RMSNorm(torch.nn.Module):
37
+ def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
38
+ super().__init__()
39
+ self.eps = eps
40
+ if weight:
41
+ self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))
42
+ else:
43
+ self.register_parameter("weight", None)
44
+
45
+ def forward(self, x):
46
+ return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
47
+
48
+
49
+ class LPRMSNorm(RMSNorm):
50
+ def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
51
+ super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)
52
+
53
+ def forward(self, x):
54
+ downcast_x = _cast_if_autocast_enabled(x)
55
+ downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
56
+ with torch.autocast(enabled=False, device_type=x.device.type):
57
+ return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
58
+
59
+
60
+ NORM_CLASS_REGISTRY = {"layernorm": torch.nn.LayerNorm, "low_precision_layernorm": LPLayerNorm, "rmsnorm": RMSNorm, "low_precision_rmsnorm": LPRMSNorm}
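For reference, a minimal usage sketch of the registry defined above. The import path and the normalized_shape of 512 are assumptions for illustration; in the model code the class is typically selected by a norm_type-style string from the config.

import torch

# Assumes the mpt directory is importable as a package on the PYTHONPATH.
from mllm.flamingo.mpt.norm import NORM_CLASS_REGISTRY

norm_cls = NORM_CLASS_REGISTRY["rmsnorm"]  # or "layernorm", "low_precision_layernorm", ...
norm = norm_cls(normalized_shape=512)

x = torch.randn(2, 16, 512)
print(norm(x).shape)  # torch.Size([2, 16, 512])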
mllm/flamingo/mpt/param_init_fns.py ADDED
@@ -0,0 +1,369 @@
1
+ import math
2
+ import warnings
3
+ from collections.abc import Sequence
4
+ from functools import partial
5
+ from typing import Optional, Tuple, Union
6
+ import torch
7
+ from torch import nn
8
+ from .norm import NORM_CLASS_REGISTRY
9
+
10
+
11
+ def torch_default_param_init_fn_(module: nn.Module, verbose: int = 0, **kwargs):
12
+ del kwargs
13
+ if verbose > 1:
14
+ warnings.warn(f"Initializing network using module's reset_parameters attribute")
15
+ if hasattr(module, "reset_parameters"):
16
+ module.reset_parameters()
17
+
18
+
19
+ def fused_init_helper_(module: nn.Module, init_fn_):
20
+ _fused = getattr(module, "_fused", None)
21
+ if _fused is None:
22
+ raise RuntimeError(f"Internal logic error")
23
+ (dim, splits) = _fused
24
+ splits = (0, *splits, module.weight.size(dim))
25
+ for s, e in zip(splits[:-1], splits[1:]):
26
+ slice_indices = [slice(None)] * module.weight.ndim
27
+ slice_indices[dim] = slice(s, e)
28
+ init_fn_(module.weight[slice_indices])
29
+
30
+
31
+ def generic_param_init_fn_(
32
+ module: nn.Module,
33
+ init_fn_,
34
+ n_layers: int,
35
+ d_model: Optional[int] = None,
36
+ init_div_is_residual: Union[int, float, str, bool] = True,
37
+ emb_init_std: Optional[float] = None,
38
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
39
+ verbose: int = 0,
40
+ **kwargs,
41
+ ):
42
+ del kwargs
43
+ if verbose > 1:
44
+ warnings.warn(f"If model has bias parameters they are initialized to 0.")
45
+ init_div_is_residual = init_div_is_residual
46
+ if init_div_is_residual is False:
47
+ div_is_residual = 1.0
48
+ elif init_div_is_residual is True:
49
+ div_is_residual = math.sqrt(2 * n_layers)
50
+ elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int):
51
+ div_is_residual = init_div_is_residual
52
+ elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
53
+ div_is_residual = float(init_div_is_residual)
54
+ else:
55
+ div_is_residual = 1.0
56
+ raise ValueError(f"Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}")
57
+ if init_div_is_residual is not False:
58
+ if verbose > 1:
59
+ warnings.warn(
60
+ f"Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. "
61
+ + f"Set `init_div_is_residual: false` in init config to disable this."
62
+ )
63
+ if isinstance(module, nn.Linear):
64
+ if hasattr(module, "_fused"):
65
+ fused_init_helper_(module, init_fn_)
66
+ else:
67
+ init_fn_(module.weight)
68
+ if module.bias is not None:
69
+ torch.nn.init.zeros_(module.bias)
70
+ if init_div_is_residual is not False and getattr(module, "_is_residual", False):
71
+ with torch.no_grad():
72
+ module.weight.div_(div_is_residual)
73
+ elif isinstance(module, nn.Embedding):
74
+ if emb_init_std is not None:
75
+ std = emb_init_std
76
+ if std == 0:
77
+ warnings.warn(f"Embedding layer initialized to 0.")
78
+ emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
79
+ if verbose > 1:
80
+ warnings.warn(f"Embedding layer initialized using normal distribution with mean=0 and std={std!r}.")
81
+ elif emb_init_uniform_lim is not None:
82
+ lim = emb_init_uniform_lim
83
+ if isinstance(lim, Sequence):
84
+ if len(lim) > 2:
85
+ raise ValueError(f"Uniform init requires a min and a max limit. User input: {lim}.")
86
+ if lim[0] == lim[1]:
87
+ warnings.warn(f"Embedding layer initialized to {lim[0]}.")
88
+ else:
89
+ if lim == 0:
90
+ warnings.warn(f"Embedding layer initialized to 0.")
91
+ lim = [-lim, lim]
92
+ (a, b) = lim
93
+ emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
94
+ if verbose > 1:
95
+ warnings.warn(f"Embedding layer initialized using uniform distribution in range {lim}.")
96
+ else:
97
+ emb_init_fn_ = init_fn_
98
+ emb_init_fn_(module.weight)
99
+ elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
100
+ if verbose > 1:
101
+ warnings.warn(f"Norm weights are set to 1. If norm layer has a bias it is initialized to 0.")
102
+ if hasattr(module, "weight") and module.weight is not None:
103
+ torch.nn.init.ones_(module.weight)
104
+ if hasattr(module, "bias") and module.bias is not None:
105
+ torch.nn.init.zeros_(module.bias)
106
+ elif isinstance(module, nn.MultiheadAttention):
107
+ if module._qkv_same_embed_dim:
108
+ assert module.in_proj_weight is not None
109
+ assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None)
110
+ assert d_model is not None
111
+ _d = d_model
112
+ splits = (0, _d, 2 * _d, 3 * _d)
113
+ for s, e in zip(splits[:-1], splits[1:]):
114
+ init_fn_(module.in_proj_weight[s:e])
115
+ else:
116
+ assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None)
117
+ assert module.in_proj_weight is None
118
+ init_fn_(module.q_proj_weight)
119
+ init_fn_(module.k_proj_weight)
120
+ init_fn_(module.v_proj_weight)
121
+ if module.in_proj_bias is not None:
122
+ torch.nn.init.zeros_(module.in_proj_bias)
123
+ if module.bias_k is not None:
124
+ torch.nn.init.zeros_(module.bias_k)
125
+ if module.bias_v is not None:
126
+ torch.nn.init.zeros_(module.bias_v)
127
+ init_fn_(module.out_proj.weight)
128
+ if init_div_is_residual is not False and getattr(module.out_proj, "_is_residual", False):
129
+ with torch.no_grad():
130
+ module.out_proj.weight.div_(div_is_residual)
131
+ if module.out_proj.bias is not None:
132
+ torch.nn.init.zeros_(module.out_proj.bias)
133
+ else:
134
+ for _ in module.parameters(recurse=False):
135
+ raise NotImplementedError(f"{module.__class__.__name__} parameters are not initialized by param_init_fn.")
136
+
137
+
138
+ def _normal_init_(std, mean=0.0):
139
+ return partial(torch.nn.init.normal_, mean=mean, std=std)
140
+
141
+
142
+ def _normal_param_init_fn_(
143
+ module: nn.Module,
144
+ std: float,
145
+ n_layers: int,
146
+ d_model: Optional[int] = None,
147
+ init_div_is_residual: Union[int, float, str, bool] = True,
148
+ emb_init_std: Optional[float] = None,
149
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
150
+ verbose: int = 0,
151
+ **kwargs,
152
+ ):
153
+ del kwargs
154
+ init_fn_ = _normal_init_(std=std)
155
+ if verbose > 1:
156
+ warnings.warn(f"Using torch.nn.init.normal_ init fn mean=0.0, std={std}")
157
+ generic_param_init_fn_(
158
+ module=module,
159
+ init_fn_=init_fn_,
160
+ d_model=d_model,
161
+ n_layers=n_layers,
162
+ init_div_is_residual=init_div_is_residual,
163
+ emb_init_std=emb_init_std,
164
+ emb_init_uniform_lim=emb_init_uniform_lim,
165
+ verbose=verbose,
166
+ )
167
+
168
+
169
+ def baseline_param_init_fn_(
170
+ module: nn.Module,
171
+ init_std: float,
172
+ n_layers: int,
173
+ d_model: Optional[int] = None,
174
+ init_div_is_residual: Union[int, float, str, bool] = True,
175
+ emb_init_std: Optional[float] = None,
176
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
177
+ verbose: int = 0,
178
+ **kwargs,
179
+ ):
180
+ del kwargs
181
+ if init_std is None:
182
+ raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
183
+ _normal_param_init_fn_(
184
+ module=module,
185
+ std=init_std,
186
+ d_model=d_model,
187
+ n_layers=n_layers,
188
+ init_div_is_residual=init_div_is_residual,
189
+ emb_init_std=emb_init_std,
190
+ emb_init_uniform_lim=emb_init_uniform_lim,
191
+ verbose=verbose,
192
+ )
193
+
194
+
195
+ def small_param_init_fn_(
196
+ module: nn.Module,
197
+ n_layers: int,
198
+ d_model: int,
199
+ init_div_is_residual: Union[int, float, str, bool] = True,
200
+ emb_init_std: Optional[float] = None,
201
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
202
+ verbose: int = 0,
203
+ **kwargs,
204
+ ):
205
+ del kwargs
206
+ std = math.sqrt(2 / (5 * d_model))
207
+ _normal_param_init_fn_(
208
+ module=module,
209
+ std=std,
210
+ d_model=d_model,
211
+ n_layers=n_layers,
212
+ init_div_is_residual=init_div_is_residual,
213
+ emb_init_std=emb_init_std,
214
+ emb_init_uniform_lim=emb_init_uniform_lim,
215
+ verbose=verbose,
216
+ )
217
+
218
+
219
+ def neox_param_init_fn_(
220
+ module: nn.Module,
221
+ n_layers: int,
222
+ d_model: int,
223
+ emb_init_std: Optional[float] = None,
224
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
225
+ verbose: int = 0,
226
+ **kwargs,
227
+ ):
228
+ """From section 2.3.1 of GPT-NeoX-20B:
229
+
230
+ An Open-Source Autoregressive Language Model — Black et al. (2022)
231
+ see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
232
+ and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
233
+ """
234
+ del kwargs
235
+ residual_div = n_layers / math.sqrt(10)
236
+ if verbose > 1:
237
+ warnings.warn(f"setting init_div_is_residual to {residual_div}")
238
+ small_param_init_fn_(
239
+ module=module,
240
+ d_model=d_model,
241
+ n_layers=n_layers,
242
+ init_div_is_residual=residual_div,
243
+ emb_init_std=emb_init_std,
244
+ emb_init_uniform_lim=emb_init_uniform_lim,
245
+ verbose=verbose,
246
+ )
247
+
248
+
249
+ def kaiming_uniform_param_init_fn_(
250
+ module: nn.Module,
251
+ n_layers: int,
252
+ d_model: Optional[int] = None,
253
+ init_div_is_residual: Union[int, float, str, bool] = True,
254
+ emb_init_std: Optional[float] = None,
255
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
256
+ init_gain: float = 0,
257
+ fan_mode: str = "fan_in",
258
+ init_nonlinearity: str = "leaky_relu",
259
+ verbose: int = 0,
260
+ **kwargs,
261
+ ):
262
+ del kwargs
263
+ if verbose > 1:
264
+ warnings.warn(f"Using nn.init.kaiming_uniform_ init fn with parameters: " + f"a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}")
265
+ kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
266
+ generic_param_init_fn_(
267
+ module=module,
268
+ init_fn_=kaiming_uniform_,
269
+ d_model=d_model,
270
+ n_layers=n_layers,
271
+ init_div_is_residual=init_div_is_residual,
272
+ emb_init_std=emb_init_std,
273
+ emb_init_uniform_lim=emb_init_uniform_lim,
274
+ verbose=verbose,
275
+ )
276
+
277
+
278
+ def kaiming_normal_param_init_fn_(
279
+ module: nn.Module,
280
+ n_layers: int,
281
+ d_model: Optional[int] = None,
282
+ init_div_is_residual: Union[int, float, str, bool] = True,
283
+ emb_init_std: Optional[float] = None,
284
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
285
+ init_gain: float = 0,
286
+ fan_mode: str = "fan_in",
287
+ init_nonlinearity: str = "leaky_relu",
288
+ verbose: int = 0,
289
+ **kwargs,
290
+ ):
291
+ del kwargs
292
+ if verbose > 1:
293
+ warnings.warn(f"Using nn.init.kaiming_normal_ init fn with parameters: " + f"a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}")
294
+ kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
295
+ generic_param_init_fn_(
296
+ module=module,
297
+ init_fn_=kaiming_normal_,
298
+ d_model=d_model,
299
+ n_layers=n_layers,
300
+ init_div_is_residual=init_div_is_residual,
301
+ emb_init_std=emb_init_std,
302
+ emb_init_uniform_lim=emb_init_uniform_lim,
303
+ verbose=verbose,
304
+ )
305
+
306
+
307
+ def xavier_uniform_param_init_fn_(
308
+ module: nn.Module,
309
+ n_layers: int,
310
+ d_model: Optional[int] = None,
311
+ init_div_is_residual: Union[int, float, str, bool] = True,
312
+ emb_init_std: Optional[float] = None,
313
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
314
+ init_gain: float = 0,
315
+ verbose: int = 0,
316
+ **kwargs,
317
+ ):
318
+ del kwargs
319
+ xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
320
+ if verbose > 1:
321
+ warnings.warn(f"Using torch.nn.init.xavier_uniform_ init fn with parameters: " + f"gain={init_gain}")
322
+ generic_param_init_fn_(
323
+ module=module,
324
+ init_fn_=xavier_uniform_,
325
+ d_model=d_model,
326
+ n_layers=n_layers,
327
+ init_div_is_residual=init_div_is_residual,
328
+ emb_init_std=emb_init_std,
329
+ emb_init_uniform_lim=emb_init_uniform_lim,
330
+ verbose=verbose,
331
+ )
332
+
333
+
334
+ def xavier_normal_param_init_fn_(
335
+ module: nn.Module,
336
+ n_layers: int,
337
+ d_model: Optional[int] = None,
338
+ init_div_is_residual: Union[int, float, str, bool] = True,
339
+ emb_init_std: Optional[float] = None,
340
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
341
+ init_gain: float = 0,
342
+ verbose: int = 0,
343
+ **kwargs,
344
+ ):
345
+ xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
346
+ if verbose > 1:
347
+ warnings.warn(f"Using torch.nn.init.xavier_normal_ init fn with parameters: " + f"gain={init_gain}")
348
+ generic_param_init_fn_(
349
+ module=module,
350
+ init_fn_=xavier_normal_,
351
+ d_model=d_model,
352
+ n_layers=n_layers,
353
+ init_div_is_residual=init_div_is_residual,
354
+ emb_init_std=emb_init_std,
355
+ emb_init_uniform_lim=emb_init_uniform_lim,
356
+ verbose=verbose,
357
+ )
358
+
359
+
360
+ MODEL_INIT_REGISTRY = {
361
+ "default_": torch_default_param_init_fn_,
362
+ "baseline_": baseline_param_init_fn_,
363
+ "kaiming_uniform_": kaiming_uniform_param_init_fn_,
364
+ "kaiming_normal_": kaiming_normal_param_init_fn_,
365
+ "neox_init_": neox_param_init_fn_,
366
+ "small_init_": small_param_init_fn_,
367
+ "xavier_uniform_": xavier_uniform_param_init_fn_,
368
+ "xavier_normal_": xavier_normal_param_init_fn_,
369
+ }
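For reference, a minimal sketch of how MODEL_INIT_REGISTRY is consumed: MPTForCausalLM.param_init_fn above resolves init_config["name"] in this registry and invokes the chosen function on a single module, so a trainer (or model.apply) would call it once per submodule. The toy model, the import path, and the config values below are assumptions for illustration.

import torch.nn as nn

# Assumes the mpt directory is importable as a package on the PYTHONPATH.
from mllm.flamingo.mpt.param_init_fns import MODEL_INIT_REGISTRY

init_config = {"name": "kaiming_normal_", "fan_mode": "fan_in", "init_nonlinearity": "relu"}
model = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))

init_fn = MODEL_INIT_REGISTRY[init_config["name"]]
# Every init fn accepts **kwargs, so passing the whole init_config (including "name") is harmless.
model.apply(lambda module: init_fn(module=module, n_layers=2, d_model=512, **init_config))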
mllm/flamingo/mpt_redpajama/__init__.py ADDED
File without changes
mllm/flamingo/mpt_redpajama/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (216 Bytes).