gray311 committed on
Commit 8443bea · verified · 1 Parent(s): 86caf27

Delete mllm

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. mllm/flamingo/__init__.py +0 -48
  2. mllm/flamingo/config.json +0 -21
  3. mllm/flamingo/configuration_flamingo.py +0 -100
  4. mllm/flamingo/converting_flamingo_to_bf16.py +0 -30
  5. mllm/flamingo/converting_flamingo_to_hf.py +0 -61
  6. mllm/flamingo/converting_flamingo_to_lora.py +0 -68
  7. mllm/flamingo/falcon/__init__.py +0 -0
  8. mllm/flamingo/falcon/__pycache__/__init__.cpython-39.pyc +0 -0
  9. mllm/flamingo/falcon/__pycache__/configuration_RW.cpython-39.pyc +0 -0
  10. mllm/flamingo/falcon/__pycache__/modelling_RW.cpython-39.pyc +0 -0
  11. mllm/flamingo/falcon/configuration_RW.py +0 -79
  12. mllm/flamingo/falcon/modelling_RW.py +0 -1064
  13. mllm/flamingo/flamingo-falcon-7B.json +0 -112
  14. mllm/flamingo/flamingo-llama2-chat-13B.json +0 -114
  15. mllm/flamingo/flamingo-llama2-chat-7B.json +0 -115
  16. mllm/flamingo/flamingo-mpt-1B-redpajama.json +0 -131
  17. mllm/flamingo/flamingo-mpt-30B-bf16.json +0 -195
  18. mllm/flamingo/flamingo-mpt-30B.json +0 -195
  19. mllm/flamingo/flamingo-mpt-7B.json +0 -195
  20. mllm/flamingo/flamingo-vicuna-33B-v1.3.json +0 -111
  21. mllm/flamingo/flamingo-vicuna-7B-v1.3.json +0 -111
  22. mllm/flamingo/injecting_falcon_into_flamingo.py +0 -49
  23. mllm/flamingo/injecting_llama2_into_flamingo.py +0 -95
  24. mllm/flamingo/injecting_mpt-1B-redpajama_into_flamingo.py +0 -97
  25. mllm/flamingo/injecting_mpt_into_flamingo.py +0 -109
  26. mllm/flamingo/injecting_vicuna_into_flamingo.py +0 -100
  27. mllm/flamingo/modeling_flamingo.py +0 -966
  28. mllm/flamingo/mpt/__init__.py +0 -0
  29. mllm/flamingo/mpt/__pycache__/__init__.cpython-39.pyc +0 -0
  30. mllm/flamingo/mpt/__pycache__/attention.cpython-39.pyc +0 -0
  31. mllm/flamingo/mpt/__pycache__/blocks.cpython-39.pyc +0 -0
  32. mllm/flamingo/mpt/__pycache__/configuration_mpt.cpython-39.pyc +0 -0
  33. mllm/flamingo/mpt/__pycache__/custom_embedding.cpython-39.pyc +0 -0
  34. mllm/flamingo/mpt/__pycache__/flash_attn_triton.cpython-39.pyc +0 -0
  35. mllm/flamingo/mpt/__pycache__/modeling_mpt.cpython-39.pyc +0 -0
  36. mllm/flamingo/mpt/__pycache__/norm.cpython-39.pyc +0 -0
  37. mllm/flamingo/mpt/__pycache__/param_init_fns.cpython-39.pyc +0 -0
  38. mllm/flamingo/mpt/adapt_tokenizer.py +0 -44
  39. mllm/flamingo/mpt/attention.py +0 -450
  40. mllm/flamingo/mpt/blocks.py +0 -82
  41. mllm/flamingo/mpt/configuration_mpt.py +0 -161
  42. mllm/flamingo/mpt/custom_embedding.py +0 -11
  43. mllm/flamingo/mpt/flash_attn_triton.py +0 -841
  44. mllm/flamingo/mpt/hf_prefixlm_converter.py +0 -575
  45. mllm/flamingo/mpt/meta_init_context.py +0 -98
  46. mllm/flamingo/mpt/modeling_mpt.py +0 -496
  47. mllm/flamingo/mpt/norm.py +0 -60
  48. mllm/flamingo/mpt/param_init_fns.py +0 -369
  49. mllm/flamingo/mpt_redpajama/__init__.py +0 -0
  50. mllm/flamingo/mpt_redpajama/__pycache__/__init__.cpython-39.pyc +0 -0
mllm/flamingo/__init__.py DELETED
@@ -1,48 +0,0 @@
-from typing import TYPE_CHECKING
-
-from transformers.utils import (
-    OptionalDependencyNotAvailable,
-    _LazyModule,
-    is_torch_available,
-)
-
-
-_import_structure = {
-    "configuration_flamingo": [
-        "FlamingoConfig",
-    ],
-}
-
-try:
-    if not is_torch_available():
-        raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-    pass
-else:
-    _import_structure["modeling_flamingo"] = [
-        "FlamingoModel",
-        "FlamingoPreTrainedModel",
-        "FlamingoForConditionalGeneration",
-    ]
-
-if TYPE_CHECKING:
-    from .configuration_flamingo import FlamingoConfig
-
-    # from .processing_flamingo import FlamingoProcessor
-
-    try:
-        if not is_torch_available():
-            raise OptionalDependencyNotAvailable()
-    except OptionalDependencyNotAvailable:
-        pass
-    else:
-        from .modeling_flamingo import (
-            FlamingoForConditionalGeneration,
-            FlamingoModel,
-            FlamingoPreTrainedModel,
-        )
-
-else:
-    import sys
-
-    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
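For context, this `__init__.py` used the standard `transformers` lazy-module pattern: submodules are imported only when one of their names is first accessed, and torch-dependent symbols are registered only if `is_torch_available()`. A minimal hedged sketch of the mechanism (a simplified stand-in, not the actual `_LazyModule`):

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Simplified stand-in for transformers' _LazyModule: defer submodule imports."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported attribute to the submodule that defines it.
        self._attr_to_module = {attr: mod for mod, attrs in import_structure.items() for attr in attrs}

    def __getattr__(self, attr):
        if attr not in self._attr_to_module:
            raise AttributeError(attr)
        module = importlib.import_module("." + self._attr_to_module[attr], self.__name__)
        value = getattr(module, attr)
        setattr(self, attr, value)  # cache, so each submodule is imported at most once
        return value
```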
mllm/flamingo/config.json DELETED
@@ -1,21 +0,0 @@
-{
-  "model_type": "flamingo",
-  "cross_attn_every_n_layers": 4,
-  "tie_word_embeddings": false,
-  "use_media_placement_augmentation": true,
-  "only_attend_previous": true,
-  "text_config": {
-    "_name_or_path": "luodian/llama-7b-hf",
-    "model_type": "llama"
-  },
-  "vision_config": {
-    "_name_or_path": "openai/clip-vit-large-patch14",
-    "model_type": "clip_vision_model",
-    "hidden_size": 1024,
-    "intermediate_size": 4096,
-    "num_attention_heads": 16,
-    "num_hidden_layers": 24,
-    "image_size": 224,
-    "patch_size": 14
-  }
-}
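The key field here is `cross_attn_every_n_layers: 4`: roughly one gated cross-attention block per 4 decoder layers, so a 32-layer LLaMA-7B decoder would get about 8 such blocks. A hedged sketch of reading the file (local path assumed):

```python
import json

# Hypothetical local copy of the file shown above.
with open("mllm/flamingo/config.json") as f:
    cfg = json.load(f)

num_lm_layers = 32  # LLaMA-7B, matching "luodian/llama-7b-hf" above
print(num_lm_layers // cfg["cross_attn_every_n_layers"])  # -> 8 cross-attention blocks
```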
mllm/flamingo/configuration_flamingo.py DELETED
@@ -1,100 +0,0 @@
-import copy
-
-from transformers.configuration_utils import PretrainedConfig
-from transformers.utils import logging
-
-from transformers.models.auto import CONFIG_MAPPING
-from transformers.models.clip import CLIPVisionConfig
-import sys
-
-from .falcon.configuration_RW import RWConfig
-from .mpt.configuration_mpt import MPTConfig
-from .mpt_redpajama.configuration_mosaic_gpt import MosaicGPTConfig
-
-
-logger = logging.get_logger(__name__)
-
-
-class FlamingoConfig(PretrainedConfig):
-    r"""
-    [`FlamingoConfig`] is the configuration class to store the configuration of a [`FlamingoForConditionalGeneration`]. It is
-    used to instantiate a Flamingo model according to the specified arguments, defining the vision model and language model
-    configs. Instantiating a configuration with the defaults will yield a configuration similar to that of the Flamingo
-    architecture.
-
-    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
-    documentation from [`PretrainedConfig`] for more information.
-
-    Args:
-        vision_config (`dict`, *optional*):
-            Dictionary of configuration options used to initialize [`PretrainedConfig`].
-        text_config (`dict`, *optional*):
-            Dictionary of configuration options used to initialize any [`PretrainedConfig`].
-        cross_attn_every_n_layers (`int`, *optional*, defaults to 4):
-            The number of transformer layers between each added cross-attention layer.
-        kwargs (*optional*):
-            Dictionary of keyword arguments.
-
-    Example:
-
-    ```python
-    >>> from transformers import (
-    ...     PretrainedConfig,
-    ...     OPTConfig,
-    ...     FlamingoConfig,
-    ...     FlamingoForConditionalGeneration,
-    ... )
-
-    >>> # Initializing a FlamingoConfig with Salesforce/Flamingo-opt-2.7b style configuration
-    >>> configuration = FlamingoConfig()
-
-    >>> # Initializing a FlamingoForConditionalGeneration (with random weights) from the Salesforce/Flamingo-opt-2.7b style configuration
-    >>> model = FlamingoForConditionalGeneration(configuration)
-    ```"""
-
-    model_type = "flamingo"
-    is_composition = True
-
-    def __init__(self, vision_config=None, text_config=None, cross_attn_every_n_layers: int = 4, use_media_placement_augmentation: bool = True, **kwargs):
-        super().__init__(**kwargs)
-        if vision_config is None:
-            vision_config = {}
-            logger.info("vision_config is None. Initializing the vision config with default values.")
-
-        if text_config is None:
-            text_config = {}
-            logger.info("text_config is None. Initializing the text config with default values.")
-
-        self.vision_config = CLIPVisionConfig(**vision_config)
-        if "architectures" in text_config.keys() and text_config["architectures"] is not None:
-            if text_config["architectures"][0] == "MPTForCausalLM":
-                self.text_config = MPTConfig(**text_config)
-            elif text_config["architectures"][0] == "MosaicGPT":
-                self.text_config = MosaicGPTConfig(**text_config)
-            elif text_config["architectures"][0] == "RWForCausalLM":
-                self.text_config = RWConfig(**text_config)
-            elif text_config["architectures"][0] == "LlamaForCausalLM":
-                self.text_config = CONFIG_MAPPING[text_config.pop("model_type")](**text_config)
-            else:
-                import pdb
-
-                pdb.set_trace()
-        else:
-            self.text_config = CONFIG_MAPPING[text_config.pop("model_type")](**text_config)
-
-        self.cross_attn_every_n_layers = cross_attn_every_n_layers
-        self.use_media_placement_augmentation = use_media_placement_augmentation
-
-    def to_dict(self):
-        """
-        Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].
-
-        Returns:
-            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
-        """
-        output = copy.deepcopy(self.__dict__)
-        output["vision_config"] = self.vision_config.to_dict()
-        output["text_config"] = self.text_config.to_dict()
-        output["model_type"] = self.__class__.model_type
-        output["cross_attn_every_n_layers"] = self.cross_attn_every_n_layers
-        output["use_media_placement_augmentation"] = self.use_media_placement_augmentation
-        return output
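A hedged usage sketch of the class above, composing the same sub-configs as `config.json` (this assumes the deleted package were still importable under `mllm.flamingo`):

```python
from mllm.flamingo.configuration_flamingo import FlamingoConfig

# Without an "architectures" key, text_config is dispatched through CONFIG_MAPPING["llama"].
config = FlamingoConfig(
    vision_config={
        "hidden_size": 1024, "intermediate_size": 4096, "num_attention_heads": 16,
        "num_hidden_layers": 24, "image_size": 224, "patch_size": 14,
    },
    text_config={"model_type": "llama"},
    cross_attn_every_n_layers=4,
)
assert config.to_dict()["model_type"] == "flamingo"
```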
mllm/flamingo/converting_flamingo_to_bf16.py DELETED
@@ -1,30 +0,0 @@
-import argparse
-import os
-
-import torch
-
-from .configuration_flamingo import FlamingoConfig
-from .modeling_flamingo import FlamingoForConditionalGeneration
-
-parser = argparse.ArgumentParser(description="Load model with precision")
-parser.add_argument("--load_bit", type=str, choices=["fp16", "bf16"], required=True, help="Choose either 'fp16' or 'bf16'")
-parser.add_argument("--pretrained_model_path", type=str, default="/home/luodian/projects/checkpoints/flamingo-mpt-7B-instruct-init", required=True)
-parser.add_argument("--saved_model_path", type=str, default="/home/luodian/projects/checkpoints/flamingo-mpt-7B-instruct-init", required=True)
-args = parser.parse_args()
-
-load_bit = args.load_bit
-pretrained_model_path = args.pretrained_model_path
-
-if load_bit == "fp16":
-    precision = {"torch_dtype": torch.float16}
-elif load_bit == "bf16":
-    precision = {"torch_dtype": torch.bfloat16}
-
-root_dir = os.environ["AZP"]
-print(root_dir)
-device_id = "cpu"
-model = FlamingoForConditionalGeneration.from_pretrained(pretrained_model_path, device_map={"": device_id}, **precision)
-
-# save model to same folder
-checkpoint_path = pretrained_model_path + f"-{load_bit}"
-model.save_pretrained(checkpoint_path, max_shard_size="10GB")
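The script boils down to reloading with a `torch_dtype` override and re-sharding; a hedged standalone equivalent (placeholder paths, with `AutoModelForCausalLM` standing in for `FlamingoForConditionalGeneration`):

```python
import torch
from transformers import AutoModelForCausalLM  # stand-in for FlamingoForConditionalGeneration

src = "/path/to/flamingo-mpt-7B"  # placeholder checkpoint directory
model = AutoModelForCausalLM.from_pretrained(src, torch_dtype=torch.bfloat16, device_map={"": "cpu"})
model.save_pretrained(src + "-bf16", max_shard_size="10GB")  # same "<src>-<load_bit>" naming rule
```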
mllm/flamingo/converting_flamingo_to_hf.py DELETED
@@ -1,61 +0,0 @@
-"""Convert an Otter pt checkpoint to an Otter HF folder. Will be removed once the HF model is used for training.
-"""
-
-import re
-import argparse
-import os
-
-import torch
-import torch.nn as nn
-from transformers import CLIPVisionModel, LlamaForCausalLM, LlamaTokenizer
-
-import sys
-from modeling_flamingo import FlamingoForConditionalGeneration
-
-from configuration_flamingo import FlamingoConfig
-
-
-@torch.no_grad()
-def dump_hf_model(pretrained_model_path: str, old_ckpt_path: str, new_folder_path: str) -> None:
-    old_ckpt = torch.load(old_ckpt_path, map_location="cpu")
-    if old_ckpt.get("model_state_dict", None) is not None:
-        old_ckpt = old_ckpt["model_state_dict"]
-    new_ckpt = old_ckpt
-    folder_path = os.path.dirname(old_ckpt_path)
-    # config_path = os.path.join(folder_path, "config.json") if os.path.exists(os.path.join(folder_path, "config.json")) else "flamingo/config.json"
-    model = FlamingoForConditionalGeneration.from_pretrained(
-        pretrained_model_path,
-        device_map="auto",
-    )
-    _ = model.load_state_dict(new_ckpt, strict=False)
-    print(f"Saving HF model to {new_folder_path}")
-    model.save_pretrained(new_folder_path)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--old_ckpt_path",
-        "-old",
-        type=str,
-        required=True,
-        help="Path to the pt checkpoint",
-    )
-    parser.add_argument(
-        "--new_hf_path",
-        "-new",
-        type=str,
-        required=True,
-        help="Path to the hf folder",
-    )
-    parser.add_argument(
-        "--pretrained_model_path",
-        "-pretrained",
-        type=str,
-        required=True,
-        help="Path to the pretrained model folder",
-    )
-    args = parser.parse_args()
-    if not os.path.exists(os.path.dirname(args.new_hf_path)):
-        os.makedirs(os.path.dirname(args.new_hf_path))
-    dump_hf_model(args.pretrained_model_path, args.old_ckpt_path, args.new_hf_path)
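Because the load above uses `strict=False`, key mismatches are swallowed silently; when adapting such a conversion script, it may be worth surfacing them. A hedged sketch (checkpoint path is a placeholder, and `model` is assumed to be the freshly instantiated model):

```python
import torch

ckpt = torch.load("/path/to/old_checkpoint.pt", map_location="cpu")  # placeholder path
state_dict = ckpt.get("model_state_dict", ckpt)  # same unwrapping rule as dump_hf_model

# `model` is assumed to be the FlamingoForConditionalGeneration loaded above.
result = model.load_state_dict(state_dict, strict=False)
print("missing keys:", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)
```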
mllm/flamingo/converting_flamingo_to_lora.py DELETED
@@ -1,68 +0,0 @@
-import argparse
-import torch
-import sys
-
-from .modeling_flamingo import FlamingoForConditionalGeneration
-from peft import get_peft_model, LoraConfig, TaskType
-
-MODEL_CLASSES = {
-    "LlamaForCausalLM": "llama",
-    "OPTForCausalLM": "opt",
-    "GPTJForCausalLM": "gptj",
-    "GPTNeoXForCausalLM": "gpt_neox",
-    "MPTForCausalLM": "mpt",
-}
-
-# Define argument parser
-parser = argparse.ArgumentParser(description="Load a model with specified precision and save it to a specified path.")
-
-# Add arguments
-parser.add_argument(
-    "--checkpoint_path",
-    type=str,
-    help="Path to the pre-trained model checkpoint.",
-    default="",
-)
-parser.add_argument(
-    "--save_path",
-    type=str,
-    default="",
-    help="Path to the converted model checkpoint.",
-)
-
-# Parse the input arguments
-args = parser.parse_args()
-
-load_bit = "bf16"
-if load_bit == "fp16":
-    precision = {"torch_dtype": torch.float16}
-elif load_bit == "bf16":
-    precision = {"torch_dtype": torch.bfloat16}
-
-# Load the model
-model = FlamingoForConditionalGeneration.from_pretrained(args.checkpoint_path, device_map="auto", **precision)
-
-# adding lora
-standard_modules = ["q_proj", "v_proj"]
-lang_encoder_short_name = MODEL_CLASSES[model.config.text_config.architectures[0]]
-model_to_lora_modules = {
-    "llama": standard_modules,
-    "opt": standard_modules,
-    "gptj": standard_modules,
-    "gpt_neox": ["query_key_value"],
-    "mpt": ["Wqkv"],
-}
-lora_config = LoraConfig(
-    r=16,
-    lora_alpha=32,
-    lora_dropout=0.05,
-    task_type=TaskType.CAUSAL_LM,
-    target_modules=model_to_lora_modules[lang_encoder_short_name],
-)
-model.config.update({"lora_config": {"r": 16, "lora_alpha": 32, "lora_dropout": 0.05}})
-model.lang_encoder = get_peft_model(model.lang_encoder, lora_config)
-model.lang_encoder.print_trainable_parameters()
-
-# Save the model
-checkpoint_path = args.save_path
-model.save_pretrained(checkpoint_path)
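The load-bearing detail above is the per-architecture `target_modules` table: LoRA adapters must be attached to the actual projection-module names, which differ across decoder families. A hedged sketch of the same idea on a plain HF model (module names are the standard ones per family, not taken from this diff):

```python
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder base model
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    task_type=TaskType.CAUSAL_LM,
    target_modules=["c_attn"],  # GPT-2's fused QKV; LLaMA uses ["q_proj", "v_proj"], MPT uses ["Wqkv"]
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```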
mllm/flamingo/falcon/__init__.py DELETED
File without changes
mllm/flamingo/falcon/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (209 Bytes)
 
mllm/flamingo/falcon/__pycache__/configuration_RW.cpython-39.pyc DELETED
Binary file (1.86 kB)
 
mllm/flamingo/falcon/__pycache__/modelling_RW.cpython-39.pyc DELETED
Binary file (28.5 kB)
 
mllm/flamingo/falcon/configuration_RW.py DELETED
@@ -1,79 +0,0 @@
-# coding=utf-8
-# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Falcon RW configuration (adapted from the Bloom configuration)."""
-from transformers.configuration_utils import PretrainedConfig
-from transformers.utils import logging
-
-
-logger = logging.get_logger(__name__)
-
-
-class RWConfig(PretrainedConfig):
-    model_type = "RefinedWebModel"
-    keys_to_ignore_at_inference = ["past_key_values"]
-    attribute_map = {
-        "num_hidden_layers": "n_layer",
-        "num_attention_heads": "n_head",
-    }
-
-    def __init__(
-        self,
-        vocab_size=250880,
-        hidden_size=64,
-        n_layer=2,
-        n_head=8,
-        layer_norm_epsilon=1e-5,
-        initializer_range=0.02,
-        use_cache=True,
-        bos_token_id=1,
-        eos_token_id=2,
-        apply_residual_connection_post_layernorm=False,
-        hidden_dropout=0.0,
-        attention_dropout=0.0,
-        multi_query=False,
-        alibi=False,
-        bias=False,
-        parallel_attn=False,
-        **kwargs,
-    ):
-        self.vocab_size = vocab_size
-        # Backward compatibility with n_embed kwarg
-        n_embed = kwargs.pop("n_embed", None)
-        self.hidden_size = hidden_size if n_embed is None else n_embed
-        self.n_layer = n_layer
-        self.n_head = n_head
-        self.layer_norm_epsilon = layer_norm_epsilon
-        self.initializer_range = initializer_range
-        self.use_cache = use_cache
-        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
-        self.hidden_dropout = hidden_dropout
-        self.attention_dropout = attention_dropout
-
-        self.bos_token_id = bos_token_id
-        self.eos_token_id = eos_token_id
-        self.multi_query = multi_query
-        self.alibi = alibi
-        self.bias = bias
-        self.parallel_attn = parallel_attn
-
-        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
-
-    @property
-    def head_dim(self):
-        return self.hidden_size // self.n_head
-
-    @property
-    def rotary(self):
-        return not self.alibi
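Two derived properties of `RWConfig` matter downstream in `modelling_RW.py`: `head_dim`, and `rotary`, which is simply the negation of `alibi` (the model uses rotary embeddings exactly when ALiBi is off). A hedged check against the defaults (import path as in this deleted tree, so this assumes the package were still present):

```python
from mllm.flamingo.falcon.configuration_RW import RWConfig

cfg = RWConfig()            # defaults: hidden_size=64, n_head=8, alibi=False
assert cfg.head_dim == 8    # 64 // 8
assert cfg.rotary is True   # rotary <=> not alibi
```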
mllm/flamingo/falcon/modelling_RW.py DELETED
@@ -1,1064 +0,0 @@
-# port of models described in RW
-# We use the bloom model as a starting point for these models.
-# Please refer to the bloom models for usage instructions.
-
-import math
-import warnings
-from typing import Optional, Tuple, Union
-
-import torch
-import torch.utils.checkpoint
-from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
-from torch.nn import functional as F
-
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPastAndCrossAttentions,
-    CausalLMOutputWithCrossAttentions,
-    QuestionAnsweringModelOutput,
-    SequenceClassifierOutputWithPast,
-    TokenClassifierOutput,
-)
-from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import logging
-from .configuration_RW import RWConfig
-
-logger = logging.get_logger(__name__)
-
-
-# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, which means that there is one additional quantization to bfloat16 between the operations.
-# In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
-class Linear(nn.Linear):
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        ret = input @ self.weight.T
-        if self.bias is None:
-            return ret
-        else:
-            return ret + self.bias
-
-
-from einops import rearrange
-
-
-# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
-def rotate_half(x):
-    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in torch < 1.8.0
-
-
-class RotaryEmbedding(torch.nn.Module):
-    """Implementation of RotaryEmbedding from GPT-NeoX.
-    This implementation is designed to operate on queries and keys that are compatible with
-    [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format).
-    """
-
-    def __init__(
-        self,
-        head_dim: int,
-        base=10000,
-    ):
-        super().__init__()
-        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.head_dim = head_dim
-        self.seq_len_cached = None
-        self.batch_size_cached = None
-        self.cos_cached: torch.Tensor | None = None
-        self.sin_cached: torch.Tensor | None = None
-
-    def cos_sin(
-        self,
-        seq_len: int,
-        device="cuda",
-        dtype=torch.bfloat16,
-    ) -> torch.Tensor:
-        if seq_len != self.seq_len_cached:
-            self.seq_len_cached = seq_len
-            t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            emb = torch.cat((freqs, freqs), dim=-1).to(device)
-
-            if dtype in [torch.float16, torch.bfloat16]:
-                emb = emb.float()
-
-            self.cos_cached = emb.cos()[None, :, :]
-            self.sin_cached = emb.sin()[None, :, :]
-
-            self.cos_cached = self.cos_cached.type(dtype)
-            self.sin_cached = self.sin_cached.type(dtype)
-
-        return self.cos_cached, self.sin_cached
-
-    def forward(self, q, k):
-        batch, seq_len, head_dim = q.shape
-        cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
-        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
-
-
-def _make_causal_mask(input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int) -> torch.BoolTensor:
-    batch_size, target_length = input_ids_shape
-    mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
-    # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
-    seq_ids = torch.arange(target_length, device=device)
-    mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
-
-    if past_key_values_length > 0:
-        mask[:, :past_key_values_length] = False
-
-    expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
-    return expanded_mask
-
-
-def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
-    batch_size, src_length = mask.shape
-    tgt_length = tgt_length if tgt_length is not None else src_length
-
-    expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
-    return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
-
-
-def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
-    batch_size, seq_length = attention_mask.shape
-    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
-    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32)
-    powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
-    slopes = torch.pow(base, powers)
-
-    if closest_power_of_2 != num_heads:
-        extra_base = torch.tensor(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32)
-        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
-        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
-        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
-
-    # Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
-    # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
-    # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
-    # => the query_length dimension will then be broadcasted correctly
-    # This is more or less identical to T5's relative position bias:
-    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
-    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
-    alibi = slopes[..., None].bfloat16() * arange_tensor
-    return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
-
-
-def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
-    out = F.dropout(x, p=prob, training=training)
-    out = residual + out
-    return out
-
-
-class Attention(nn.Module):
-    def __init__(self, config: RWConfig):
-        super().__init__()
-
-        self.hidden_size = config.hidden_size
-        self.num_heads = config.n_head
-        self.head_dim = self.hidden_size // self.num_heads
-        self.split_size = self.hidden_size
-        self.hidden_dropout = config.hidden_dropout
-
-        if self.head_dim * self.num_heads != self.hidden_size:
-            raise ValueError(f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" f" {self.num_heads}).")
-
-        self.maybe_rotary = RotaryEmbedding(config.head_dim) if config.rotary else lambda q, k: (q, k)
-
-        # Layer-wise attention scaling
-        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
-        self.beta = self.inv_norm_factor
-
-        self.query_key_value = Linear(
-            self.hidden_size,
-            3 * self.hidden_size if not config.multi_query else (self.hidden_size + 2 * self.head_dim),
-            bias=config.bias,
-        )
-        self.multi_query = config.multi_query
-        self.dense = Linear(self.hidden_size, self.hidden_size, bias=config.bias)
-        self.attention_dropout = nn.Dropout(config.attention_dropout)
-        self.num_kv = config.n_head if not self.multi_query else 1
-
-    def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Split the last dimension into (num_heads, head_dim) without making any copies; the results share the same
-        memory storage as `fused_qkv`.
-
-        Args:
-            fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
-
-        Returns:
-            query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
-            value: [batch_size, seq_length, num_heads, head_dim]
-        """
-        if not self.multi_query:
-            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
-            fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
-            return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
-        else:
-            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
-            fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
-            return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
-
-    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Merge heads together over the last dimension.
-
-        Args:
-            x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
-
-        Returns:
-            torch.tensor: [batch_size, seq_length, num_heads * head_dim]
-        """
-        # What we want to achieve is:
-        # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
-        batch_size_and_num_heads, seq_length, _ = x.shape
-        batch_size = batch_size_and_num_heads // self.num_heads
-
-        # First view to decompose the batch size
-        # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
-        x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
-
-        # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
-        x = x.permute(0, 2, 1, 3)
-
-        # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
-        return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        alibi: torch.Tensor,
-        attention_mask: torch.Tensor,
-        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        use_cache: bool = False,
-        output_attentions: bool = False,
-    ):
-        fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
-
-        # 3 x [batch_size, seq_length, num_heads, head_dim]
-        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
-
-        batch_size, q_length, _, _ = query_layer.shape
-
-        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
-        key_layer = key_layer.transpose(1, 2).reshape(
-            batch_size * self.num_kv,
-            q_length,
-            self.head_dim,
-        )
-        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)
-
-        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
-
-        if layer_past is not None:
-            past_key, past_value = layer_past
-            # concatenate along seq_length dimension:
-            #  - key: [batch_size * self.num_heads, head_dim, kv_length]
-            #  - value: [batch_size * self.num_heads, kv_length, head_dim]
-            key_layer = torch.cat((past_key, key_layer), dim=1)
-            value_layer = torch.cat((past_value, value_layer), dim=1)
-
-        _, kv_length, _ = key_layer.shape
-
-        if use_cache is True:
-            present = (key_layer, value_layer)
-        else:
-            present = None
-
-        if alibi is None:
-            query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
-            key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
-            value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
-
-            attn_output = F.scaled_dot_product_attention(query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True)
-
-            x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
-            x = x.permute(0, 2, 1, 3)
-            attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)
-
-            output_tensor = self.dense(attn_output)
-
-            outputs = (output_tensor, present)
-            assert not output_attentions  # not supported.
-            return outputs
-        else:
-            attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
-            matmul_result = query_layer @ key_layer.transpose(-1, -2)
-
-            # change view to [batch_size, num_heads, q_length, kv_length]
-            attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
-
-            # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
-            input_dtype = attention_scores.dtype
-            # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
-            if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
-                attention_scores = attention_scores.to(torch.float32)
-            # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
-            attention_probs = F.softmax(
-                (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor + attention_mask_float,
-                dim=-1,
-                dtype=hidden_states.dtype,
-            )
-            # [batch_size, num_heads, q_length, kv_length]
-            attention_probs = self.attention_dropout(attention_probs)
-
-            if head_mask is not None:
-                attention_probs = attention_probs * head_mask
-
-            # change view [batch_size x num_heads, q_length, kv_length]
-            attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
-
-            # matmul: [batch_size * num_heads, q_length, head_dim]
-            context_layer = attention_probs_reshaped @ value_layer
-
-            # change view [batch_size, num_heads, q_length, head_dim]
-            context_layer = self._merge_heads(context_layer)
-
-            output_tensor = self.dense(context_layer)
-
-            outputs = (output_tensor, present)
-            if output_attentions:
-                outputs += (attention_probs,)
-
-            return outputs
-
-
-class MLP(nn.Module):
-    def __init__(self, config: RWConfig):
-        super().__init__()
-        hidden_size = config.hidden_size
-
-        self.dense_h_to_4h = Linear(hidden_size, 4 * hidden_size, bias=config.bias)
-        self.act = nn.GELU()
-        self.dense_4h_to_h = Linear(4 * hidden_size, hidden_size, bias=config.bias)
-        self.hidden_dropout = config.hidden_dropout
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.act(self.dense_h_to_4h(x))
-        x = self.dense_4h_to_h(x)
-        return x
-
-
-class DecoderLayer(nn.Module):
-    def __init__(self, config: RWConfig):
-        super().__init__()
-        hidden_size = config.hidden_size
-
-        self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.num_heads = config.n_head
-        self.self_attention = Attention(config)
-
-        if not config.parallel_attn:
-            # unused if parallel attn
-            self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-
-        self.mlp = MLP(config)
-
-        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
-        self.hidden_dropout = config.hidden_dropout
-
-        self.config = config
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        alibi: torch.Tensor,
-        attention_mask: torch.Tensor,
-        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        use_cache: bool = False,
-        output_attentions: bool = False,
-    ):
-        layernorm_output = self.input_layernorm(hidden_states)
-        residual = hidden_states
-
-        # Self attention.
-        attn_outputs = self.self_attention(
-            layernorm_output,
-            layer_past=layer_past,
-            attention_mask=attention_mask,
-            alibi=alibi,
-            head_mask=head_mask,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-        )
-
-        attention_output = attn_outputs[0]
-
-        if not self.config.parallel_attn:
-            residual = dropout_add(attention_output, residual, self.config.attention_dropout, training=self.training)
-            layernorm_output = self.post_attention_layernorm(residual)
-
-        outputs = attn_outputs[1:]
-
-        # MLP.
-        mlp_output = self.mlp(layernorm_output)
-
-        if self.config.parallel_attn:
-            mlp_output += attention_output
-
-        output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
-
-        if use_cache:
-            outputs = (output,) + outputs
-        else:
-            outputs = (output,) + outputs[1:]
-
-        return outputs  # hidden_states, present, attentions
-
-
-class RWPreTrainedModel(PreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
-    """
-    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
-    models.
-    """
-
-    config_class = RWConfig
-    base_model_prefix = "transformer"
-    supports_gradient_checkpointing = True
-    _no_split_modules = ["DecoderLayer"]
-
-    def __init__(self, *inputs, **kwargs):
-        super().__init__(*inputs, **kwargs)
-
-    def _init_weights(self, module: nn.Module):
-        """Initialize the weights."""
-        if isinstance(module, nn.Linear) or isinstance(module, Linear):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-
-    def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
-        if isinstance(module, RWModel):
-            module.gradient_checkpointing = value
-
-    @staticmethod
-    def _convert_to_standard_cache(past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
-        """
-        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
-        num_heads, ...]))
-        """
-        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
-        num_heads = batch_size_times_num_heads // batch_size
-        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
-        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
-        return tuple(
-            (
-                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
-                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
-            )
-            for layer_past in past_key_value
-        )
-
-    @staticmethod
-    def _convert_to_rw_cache(past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
-        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
-        batch_size_times_num_heads = batch_size * num_heads
-        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
-        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
-        return tuple(
-            (
-                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
-                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
-            )
-            for layer_past in past_key_value
-        )
-
-
477
- class RWModel(RWPreTrainedModel):
478
- def __init__(self, config: RWConfig):
479
- super().__init__(config)
480
-
481
- self.embed_dim = config.hidden_size
482
- self.num_heads = config.n_head
483
- self.alibi = config.alibi
484
-
485
- # Embedding + LN Embedding
486
- self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
487
-
488
- # Transformer blocks
489
- self.h = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
490
-
491
- # Final Layer Norm
492
- self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
493
-
494
- self.gradient_checkpointing = False
495
-
496
- # Initialize weights and apply final processing
497
- self.post_init()
498
-
499
- def get_input_embeddings(self):
500
- return self.word_embeddings
501
-
502
- def _prepare_attn_mask(self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int) -> torch.BoolTensor:
503
- # create causal mask
504
- # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
505
- combined_attention_mask = None
506
- device = attention_mask.device
507
- _, src_length = input_shape
508
-
509
- if src_length > 1:
510
- combined_attention_mask = _make_causal_mask(input_shape, device=device, past_key_values_length=past_key_values_length)
511
-
512
- # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
513
- expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
514
- combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
515
-
516
- return combined_attention_mask
517
-
518
- def set_input_embeddings(self, new_embeddings: torch.Tensor):
519
- self.word_embeddings = new_embeddings
520
-
521
- def forward(
522
- self,
523
- input_ids: Optional[torch.LongTensor] = None,
524
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
525
- attention_mask: Optional[torch.Tensor] = None,
526
- head_mask: Optional[torch.LongTensor] = None,
527
- inputs_embeds: Optional[torch.LongTensor] = None,
528
- use_cache: Optional[bool] = None,
529
- output_attentions: Optional[bool] = None,
530
- output_hidden_states: Optional[bool] = None,
531
- return_dict: Optional[bool] = None,
532
- **deprecated_arguments,
533
- ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
534
- if deprecated_arguments.pop("position_ids", False) is not False:
535
- # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
536
- warnings.warn(
537
- "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" " passing `position_ids`.",
538
- FutureWarning,
539
- )
540
- if len(deprecated_arguments) > 0:
541
- raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
542
-
543
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
544
- output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
545
- use_cache = use_cache if use_cache is not None else self.config.use_cache
546
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
547
-
548
- if input_ids is not None and inputs_embeds is not None:
549
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
550
- elif input_ids is not None:
551
- batch_size, seq_length = input_ids.shape
552
- elif inputs_embeds is not None:
553
- batch_size, seq_length, _ = inputs_embeds.shape
554
- else:
555
- raise ValueError("You have to specify either input_ids or inputs_embeds")
556
-
557
- if past_key_values is None:
558
- past_key_values = tuple([None] * len(self.h))
559
-
560
- # Prepare head mask if needed
561
- # 1.0 in head_mask indicate we keep the head
562
- # attention_probs has shape batch_size x num_heads x N x N
563
- # head_mask has shape n_layer x batch x num_heads x N x N
564
- head_mask = self.get_head_mask(head_mask, self.config.n_layer)
565
-
566
- if inputs_embeds is None:
567
- inputs_embeds = self.word_embeddings(input_ids)
568
-
569
- hidden_states = inputs_embeds
570
-
571
- presents = () if use_cache else None
572
- all_self_attentions = () if output_attentions else None
573
- all_hidden_states = () if output_hidden_states else None
574
-
575
- # Compute alibi tensor: check build_alibi_tensor documentation
576
- seq_length_with_past = seq_length
577
- past_key_values_length = 0
578
- if past_key_values[0] is not None:
579
- past_key_values_length = past_key_values[0][0].shape[2]
580
- seq_length_with_past = seq_length_with_past + past_key_values_length
581
- if attention_mask is None:
582
- attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
583
- else:
584
- attention_mask = attention_mask.to(hidden_states.device)
585
-
586
- if self.alibi:
587
- alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
588
- else:
589
- alibi = None
590
-
591
- causal_mask = self._prepare_attn_mask(
592
- attention_mask,
593
- input_shape=(batch_size, seq_length),
594
- past_key_values_length=past_key_values_length,
595
- )
596
-
597
- for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
598
- if output_hidden_states:
599
- all_hidden_states = all_hidden_states + (hidden_states,)
600
-
601
- if self.gradient_checkpointing and self.training:
602
- if use_cache:
603
- logger.warning("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
604
- use_cache = False
605
-
606
- def create_custom_forward(module):
607
- def custom_forward(*inputs):
608
- # None for past_key_value
609
- return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
610
-
611
- return custom_forward
612
-
613
- outputs = torch.utils.checkpoint.checkpoint(
614
- create_custom_forward(block),
615
- hidden_states,
616
- alibi,
617
- causal_mask,
618
- head_mask[i],
619
- )
620
- else:
621
- outputs = block(
622
- hidden_states,
623
- layer_past=layer_past,
624
- attention_mask=causal_mask,
625
- head_mask=head_mask[i],
626
- use_cache=use_cache,
627
- output_attentions=output_attentions,
628
- alibi=alibi,
629
- )
630
-
631
- hidden_states = outputs[0]
632
- if use_cache is True:
633
- presents = presents + (outputs[1],)
634
-
635
- if output_attentions:
636
- all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
637
-
638
- # Add last hidden state
639
- hidden_states = self.ln_f(hidden_states)
640
-
641
- if output_hidden_states:
642
- all_hidden_states = all_hidden_states + (hidden_states,)
643
-
644
- if not return_dict:
645
- return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
646
-
647
- return BaseModelOutputWithPastAndCrossAttentions(
648
- last_hidden_state=hidden_states,
649
- past_key_values=presents,
650
- hidden_states=all_hidden_states,
651
- attentions=all_self_attentions,
652
- )
653
-
654
-
655
- class RWForCausalLM(RWPreTrainedModel):
656
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
657
-
658
- def __init__(self, config: RWConfig):
659
- super().__init__(config)
660
- self.transformer = RWModel(config)
661
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
662
-
663
- # Initialize weights and apply final processing
664
- self.post_init()
665
-
666
- def get_output_embeddings(self):
667
- return self.lm_head
668
-
669
- def set_output_embeddings(self, new_embeddings: torch.Tensor):
670
- self.lm_head = new_embeddings
671
-
672
- def prepare_inputs_for_generation(
673
- self,
674
- input_ids: torch.LongTensor,
675
- past: Optional[torch.Tensor] = None,
676
- attention_mask: Optional[torch.Tensor] = None,
677
- **kwargs,
678
- ) -> dict:
679
- # only last token for input_ids if past is not None
680
- if past:
681
- input_ids = input_ids[:, -1].unsqueeze(-1)
682
-
683
- # the cache may be in the stardard format (e.g. in contrastive search), convert to our's format if needed
684
- if past[0][0].shape[0] == input_ids.shape[0]:
685
- past = self._convert_to_rw_cache(past)
686
-
687
- return {
688
- "input_ids": input_ids,
689
- "past_key_values": past,
690
- "use_cache": kwargs.get("use_cache"),
691
- "attention_mask": attention_mask,
692
- }
693
-
694
- def forward(
695
- self,
696
- input_ids: Optional[torch.LongTensor] = None,
697
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
698
- attention_mask: Optional[torch.Tensor] = None,
699
- head_mask: Optional[torch.Tensor] = None,
700
- inputs_embeds: Optional[torch.Tensor] = None,
701
- labels: Optional[torch.Tensor] = None,
702
- use_cache: Optional[bool] = None,
703
- output_attentions: Optional[bool] = None,
704
- output_hidden_states: Optional[bool] = None,
705
- return_dict: Optional[bool] = None,
706
- **deprecated_arguments,
707
- ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
708
- r"""
709
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
710
- Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
711
- `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
712
- are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
713
- """
714
- if deprecated_arguments.pop("position_ids", False) is not False:
715
- # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
716
- warnings.warn(
717
- "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" " passing `position_ids`.",
718
- FutureWarning,
719
- )
720
- if len(deprecated_arguments) > 0:
721
- raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
722
-
723
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
724
-
725
- transformer_outputs = self.transformer(
726
- input_ids,
727
- past_key_values=past_key_values,
728
- attention_mask=attention_mask,
729
- head_mask=head_mask,
730
- inputs_embeds=inputs_embeds,
731
- use_cache=use_cache,
732
- output_attentions=output_attentions,
733
- output_hidden_states=output_hidden_states,
734
- return_dict=return_dict,
735
- )
736
- hidden_states = transformer_outputs[0]
737
-
738
- lm_logits = self.lm_head(hidden_states)
739
-
740
- loss = None
741
- if labels is not None:
742
- # Shift so that tokens < n predict n
743
- shift_logits = lm_logits[..., :-1, :].contiguous()
744
- shift_labels = labels[..., 1:].contiguous()
745
- batch_size, seq_length, vocab_size = shift_logits.shape
746
- # Flatten the tokens
747
- loss_fct = CrossEntropyLoss()
748
- loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length))
749
-
750
- if not return_dict:
751
- output = (lm_logits,) + transformer_outputs[1:]
752
- return ((loss,) + output) if loss is not None else output
753
-
754
- return CausalLMOutputWithCrossAttentions(
755
- loss=loss,
756
- logits=lm_logits,
757
- past_key_values=transformer_outputs.past_key_values,
758
- hidden_states=transformer_outputs.hidden_states,
759
- attentions=transformer_outputs.attentions,
760
- )
761
-
762
- def _reorder_cache(self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
763
- """
764
- This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
765
- [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
766
- beam_idx at every generation step.
767
-
768
- Output shares the same memory storage as `past`.
769
- """
770
- standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
771
-
772
- # Get a copy of `beam_idx` on all the devices where we need those indices.
773
- device_to_beam_idx = {past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past}
774
- reordered_past = tuple(
775
- (
776
- layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
777
- layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
778
- )
779
- for layer_past in standardized_past
780
- )
781
- return self._convert_to_rw_cache(reordered_past)
782
-
783
-
784
- class RWForSequenceClassification(RWPreTrainedModel):
785
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
786
-
787
- def __init__(self, config: RWConfig):
788
- super().__init__(config)
789
- self.num_labels = config.num_labels
790
- self.transformer = RWModel(config)
791
- self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
792
-
793
- # Initialize weights and apply final processing
794
- self.post_init()
795
-
796
- def forward(
797
- self,
798
- input_ids: Optional[torch.LongTensor] = None,
799
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
800
- attention_mask: Optional[torch.Tensor] = None,
801
- head_mask: Optional[torch.Tensor] = None,
802
- inputs_embeds: Optional[torch.Tensor] = None,
803
- labels: Optional[torch.Tensor] = None,
804
- use_cache: Optional[bool] = None,
805
- output_attentions: Optional[bool] = None,
806
- output_hidden_states: Optional[bool] = None,
807
- return_dict: Optional[bool] = None,
808
- **deprecated_arguments,
809
- ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
810
- r"""
811
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
812
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
813
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
814
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
815
- """
816
- if deprecated_arguments.pop("position_ids", False) is not False:
817
- # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
818
- warnings.warn(
819
- "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" " passing `position_ids`.",
820
- FutureWarning,
821
- )
822
- if len(deprecated_arguments) > 0:
823
- raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
824
-
825
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
826
-
827
- transformer_outputs = self.transformer(
828
- input_ids,
829
- past_key_values=past_key_values,
830
- attention_mask=attention_mask,
831
- head_mask=head_mask,
832
- inputs_embeds=inputs_embeds,
833
- use_cache=use_cache,
834
- output_attentions=output_attentions,
835
- output_hidden_states=output_hidden_states,
836
- return_dict=return_dict,
837
- )
838
-
839
- hidden_states = transformer_outputs[0]
840
- logits = self.score(hidden_states)
841
-
842
- if input_ids is not None:
843
- batch_size = input_ids.shape[0]
844
- else:
845
- batch_size = inputs_embeds.shape[0]
846
-
847
- if self.config.pad_token_id is None and batch_size != 1:
848
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
849
- if self.config.pad_token_id is None:
850
- sequence_lengths = -1
851
- else:
852
- if input_ids is not None:
853
- sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1
854
- else:
855
- sequence_lengths = -1
856
- logger.warning(
857
- f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
858
- "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
859
- )
860
-
861
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
862
-
863
- loss = None
864
- if labels is not None:
865
- if self.config.problem_type is None:
866
- if self.num_labels == 1:
867
- self.config.problem_type = "regression"
868
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
869
- self.config.problem_type = "single_label_classification"
870
- else:
871
- self.config.problem_type = "multi_label_classification"
872
-
873
- if self.config.problem_type == "regression":
874
- loss_fct = MSELoss()
875
- if self.num_labels == 1:
876
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
877
- else:
878
- loss = loss_fct(pooled_logits, labels)
879
- elif self.config.problem_type == "single_label_classification":
880
- loss_fct = CrossEntropyLoss()
881
- loss = loss_fct(pooled_logits, labels)
882
- elif self.config.problem_type == "multi_label_classification":
883
- loss_fct = BCEWithLogitsLoss()
884
- loss = loss_fct(pooled_logits, labels)
885
- if not return_dict:
886
- output = (pooled_logits,) + transformer_outputs[1:]
887
- return ((loss,) + output) if loss is not None else output
888
-
889
- return SequenceClassifierOutputWithPast(
890
- loss=loss,
891
- logits=pooled_logits,
892
- past_key_values=transformer_outputs.past_key_values,
893
- hidden_states=transformer_outputs.hidden_states,
894
- attentions=transformer_outputs.attentions,
895
- )
896
-
897
-
898
- class RWForTokenClassification(RWPreTrainedModel):
899
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
900
-
901
- def __init__(self, config: RWConfig):
902
- super().__init__(config)
903
- self.num_labels = config.num_labels
904
-
905
- self.transformer = RWModel(config)
906
- if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
907
- classifier_dropout = config.classifier_dropout
908
- elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
909
- classifier_dropout = config.hidden_dropout
910
- else:
911
- classifier_dropout = 0.1
912
- self.dropout = nn.Dropout(classifier_dropout)
913
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
914
-
915
- # Initialize weights and apply final processing
916
- self.post_init()
917
-
918
- def forward(
919
- self,
920
- input_ids: Optional[torch.LongTensor] = None,
921
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
922
- attention_mask: Optional[torch.Tensor] = None,
923
- head_mask: Optional[torch.Tensor] = None,
924
- inputs_embeds: Optional[torch.Tensor] = None,
925
- labels: Optional[torch.Tensor] = None,
926
- use_cache: Optional[bool] = None,
927
- output_attentions: Optional[bool] = None,
928
- output_hidden_states: Optional[bool] = None,
929
- return_dict: Optional[bool] = None,
930
- **deprecated_arguments,
931
- ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
932
- r"""
933
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
934
- Labels for computing the token classification loss. Indices should be in `[0, ...,
935
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
936
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
937
- """
938
- if deprecated_arguments.pop("position_ids", False) is not False:
939
- # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
940
- warnings.warn(
941
- "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" " passing `position_ids`.",
942
- FutureWarning,
943
- )
944
- if len(deprecated_arguments) > 0:
945
- raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
946
-
947
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
948
-
949
- transformer_outputs = self.transformer(
950
- input_ids,
951
- past_key_values=past_key_values,
952
- attention_mask=attention_mask,
953
- head_mask=head_mask,
954
- inputs_embeds=inputs_embeds,
955
- use_cache=use_cache,
956
- output_attentions=output_attentions,
957
- output_hidden_states=output_hidden_states,
958
- return_dict=return_dict,
959
- )
960
-
961
- hidden_states = transformer_outputs[0]
962
- hidden_states = self.dropout(hidden_states)
963
- logits = self.classifier(hidden_states)
964
-
965
- loss = None
966
- if labels is not None:
967
- batch_size, seq_length = labels.shape
968
- loss_fct = CrossEntropyLoss()
969
- loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length))
970
-
971
- if not return_dict:
972
- output = (logits,) + transformer_outputs[2:]
973
- return ((loss,) + output) if loss is not None else output
974
-
975
- return TokenClassifierOutput(
976
- loss=loss,
977
- logits=logits,
978
- hidden_states=transformer_outputs.hidden_states,
979
- attentions=transformer_outputs.attentions,
980
- )
981
-
982
-
983
- class RWForQuestionAnswering(RWPreTrainedModel):
984
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
985
-
986
- def __init__(self, config):
987
- super().__init__(config)
988
- self.transformer = RWModel(config)
989
- self.qa_outputs = nn.Linear(config.hidden_size, 2)
990
-
991
- # Initialize weights and apply final processing
992
- self.post_init()
993
-
994
- def forward(
995
- self,
996
- input_ids: Optional[torch.LongTensor] = None,
997
- attention_mask: Optional[torch.FloatTensor] = None,
998
- position_ids: Optional[torch.LongTensor] = None,
999
- head_mask: Optional[torch.FloatTensor] = None,
1000
- inputs_embeds: Optional[torch.FloatTensor] = None,
1001
- start_positions: Optional[torch.LongTensor] = None,
1002
- end_positions: Optional[torch.LongTensor] = None,
1003
- output_attentions: Optional[bool] = None,
1004
- output_hidden_states: Optional[bool] = None,
1005
- return_dict: Optional[bool] = None,
1006
- ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1007
- r"""
1008
- start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1009
- Labels for position (index) of the start of the labelled span for computing the token classification loss.
1010
- Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1011
- are not taken into account for computing the loss.
1012
- end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1013
- Labels for position (index) of the end of the labelled span for computing the token classification loss.
1014
- Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1015
- are not taken into account for computing the loss.
1016
- """
1017
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1018
-
1019
- outputs = self.transformer(
1020
- input_ids,
1021
- attention_mask=attention_mask,
1022
- position_ids=position_ids,
1023
- head_mask=head_mask,
1024
- inputs_embeds=inputs_embeds,
1025
- output_attentions=output_attentions,
1026
- output_hidden_states=output_hidden_states,
1027
- return_dict=return_dict,
1028
- )
1029
-
1030
- sequence_output = outputs[0]
1031
-
1032
- logits = self.qa_outputs(sequence_output)
1033
- start_logits, end_logits = logits.split(1, dim=-1)
1034
- start_logits = start_logits.squeeze(-1).contiguous()
1035
- end_logits = end_logits.squeeze(-1).contiguous()
1036
-
1037
- total_loss = None
1038
- if start_positions is not None and end_positions is not None:
1039
- # If we are on multi-GPU, split add a dimension
1040
- if len(start_positions.size()) > 1:
1041
- start_positions = start_positions.squeeze(-1)
1042
- if len(end_positions.size()) > 1:
1043
- end_positions = end_positions.squeeze(-1)
1044
- # sometimes the start/end positions are outside our model inputs, we ignore these terms
1045
- ignored_index = start_logits.size(1)
1046
- start_positions = start_positions.clamp(0, ignored_index)
1047
- end_positions = end_positions.clamp(0, ignored_index)
1048
-
1049
- loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1050
- start_loss = loss_fct(start_logits, start_positions)
1051
- end_loss = loss_fct(end_logits, end_positions)
1052
- total_loss = (start_loss + end_loss) / 2
1053
-
1054
- if not return_dict:
1055
- output = (start_logits, end_logits) + outputs[2:]
1056
- return ((total_loss,) + output) if total_loss is not None else output
1057
-
1058
- return QuestionAnsweringModelOutput(
1059
- loss=total_loss,
1060
- start_logits=start_logits,
1061
- end_logits=end_logits,
1062
- hidden_states=outputs.hidden_states,
1063
- attentions=outputs.attentions,
1064
- )
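A note on the classification head earlier in this file: `RWForSequenceClassification` pools by taking the logits at the last non-padding position of each sequence. A minimal sketch of that pooling rule, using toy shapes and an assumed `pad_token_id` of 0:

```python
import torch

# Toy setup: batch of 2, sequence length 5, 3 labels; pad_token_id assumed to be 0.
pad_token_id = 0
input_ids = torch.tensor([[5, 9, 4, 0, 0],
                          [7, 2, 8, 3, 6]])
logits = torch.randn(2, 5, 3)  # (batch, seq_len, num_labels)

# Index of the last real token = (number of non-pad tokens) - 1,
# exactly as computed in RWForSequenceClassification.forward above.
sequence_lengths = torch.ne(input_ids, pad_token_id).sum(dim=-1) - 1  # tensor([2, 4])
pooled_logits = logits[torch.arange(2), sequence_lengths]
print(pooled_logits.shape)  # torch.Size([2, 3])
```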
mllm/flamingo/flamingo-falcon-7B.json DELETED
@@ -1,112 +0,0 @@
1
- {
2
- "_commit_hash": null,
3
- "architectures": [
4
- "FlamingoModel"
5
- ],
6
- "cross_attn_every_n_layers": 4,
7
- "model_type": "flamingo",
8
- "text_config": {
9
- "architectures": [
10
- "RWForCausalLM"
11
- ],
12
- "apply_residual_connection_post_layernorm": false,
13
- "attention_dropout": 0.0,
14
- "bias": false,
15
- "bos_token_id": 11,
16
- "eos_token_id": 11,
17
- "hidden_dropout": 0.0,
18
- "hidden_size": 4544,
19
- "initializer_range": 0.02,
20
- "layer_norm_epsilon": 1e-05,
21
- "model_type": "RefinedWebModel",
22
- "multi_query": true,
23
- "n_head": 71,
24
- "n_layer": 32,
25
- "parallel_attn": true,
26
- "torch_dtype": "bfloat16",
27
- "transformers_version": "4.27.4",
28
- "use_cache": true,
29
- "vocab_size": 65024
30
- },
31
- "tie_word_embeddings": false,
32
- "torch_dtype": "float32",
33
- "transformers_version": null,
34
- "use_media_placement_augmentation": true,
35
- "vision_config": {
36
- "_name_or_path": "openai/clip-vit-large-patch14",
37
- "add_cross_attention": false,
38
- "architectures": null,
39
- "attention_dropout": 0.0,
40
- "bad_words_ids": null,
41
- "begin_suppress_tokens": null,
42
- "bos_token_id": null,
43
- "chunk_size_feed_forward": 0,
44
- "cross_attention_hidden_size": null,
45
- "decoder_start_token_id": null,
46
- "diversity_penalty": 0.0,
47
- "do_sample": false,
48
- "early_stopping": false,
49
- "encoder_no_repeat_ngram_size": 0,
50
- "eos_token_id": null,
51
- "exponential_decay_length_penalty": null,
52
- "finetuning_task": null,
53
- "forced_bos_token_id": null,
54
- "forced_eos_token_id": null,
55
- "hidden_act": "quick_gelu",
56
- "hidden_size": 1024,
57
- "id2label": {
58
- "0": "LABEL_0",
59
- "1": "LABEL_1"
60
- },
61
- "image_size": 224,
62
- "initializer_factor": 1.0,
63
- "initializer_range": 0.02,
64
- "intermediate_size": 4096,
65
- "is_decoder": false,
66
- "is_encoder_decoder": false,
67
- "label2id": {
68
- "LABEL_0": 0,
69
- "LABEL_1": 1
70
- },
71
- "layer_norm_eps": 1e-05,
72
- "length_penalty": 1.0,
73
- "max_length": 20,
74
- "min_length": 0,
75
- "model_type": "clip_vision_model",
76
- "no_repeat_ngram_size": 0,
77
- "num_attention_heads": 16,
78
- "num_beam_groups": 1,
79
- "num_beams": 1,
80
- "num_channels": 3,
81
- "num_hidden_layers": 24,
82
- "num_return_sequences": 1,
83
- "output_attentions": false,
84
- "output_hidden_states": false,
85
- "output_scores": false,
86
- "pad_token_id": null,
87
- "patch_size": 14,
88
- "prefix": null,
89
- "problem_type": null,
90
- "projection_dim": 512,
91
- "pruned_heads": {},
92
- "remove_invalid_values": false,
93
- "repetition_penalty": 1.0,
94
- "return_dict": true,
95
- "return_dict_in_generate": false,
96
- "sep_token_id": null,
97
- "suppress_tokens": null,
98
- "task_specific_params": null,
99
- "temperature": 1.0,
100
- "tf_legacy_loss": false,
101
- "tie_encoder_decoder": false,
102
- "tie_word_embeddings": true,
103
- "tokenizer_class": null,
104
- "top_k": 50,
105
- "top_p": 1.0,
106
- "torch_dtype": null,
107
- "torchscript": false,
108
- "transformers_version": "4.28.1",
109
- "typical_p": 1.0,
110
- "use_bfloat16": false
111
- }
112
- }
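For reference, config JSONs like the one above are what `FlamingoConfig.from_json_file` consumes in the injection scripts later in this diff. A quick way to inspect one directly, assuming the file is still on disk at its original path:

```python
import json

with open("mllm/flamingo/flamingo-falcon-7B.json") as f:
    cfg = json.load(f)

print(cfg["cross_attn_every_n_layers"])       # 4
print(cfg["text_config"]["hidden_size"])      # 4544 (Falcon-7B)
print(cfg["vision_config"]["_name_or_path"])  # openai/clip-vit-large-patch14
```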
mllm/flamingo/flamingo-llama2-chat-13B.json DELETED
@@ -1,114 +0,0 @@
1
- {
2
- "_commit_hash": null,
3
- "architectures": [
4
- "FlamingoForConditionalGeneration"
5
- ],
6
- "cross_attn_every_n_layers": 8,
7
- "model_type": "flamingo",
8
- "text_config": {
9
- "_name_or_path": "meta-llama/Llama-2-13b-chat-hf",
10
- "architectures": [
11
- "LlamaForCausalLM"
12
- ],
13
- "bos_token_id": 1,
14
- "eos_token_id": 2,
15
- "hidden_act": "silu",
16
- "hidden_size": 5120,
17
- "initializer_range": 0.02,
18
- "intermediate_size": 13824,
19
- "max_position_embeddings": 4096,
20
- "model_type": "llama",
21
- "num_attention_heads": 40,
22
- "num_hidden_layers": 40,
23
- "num_key_value_heads": 40,
24
- "pad_token_id": 0,
25
- "pretraining_tp": 1,
26
- "rms_norm_eps": 1e-05,
27
- "rope_scaling": null,
28
- "tie_word_embeddings": false,
29
- "torch_dtype": "float16",
30
- "transformers_version": "4.30.1",
31
- "use_cache": true,
32
- "vocab_size": 32000
33
- },
34
- "torch_dtype": "float32",
35
- "transformers_version": null,
36
- "use_media_placement_augmentation": true,
37
- "vision_config": {
38
- "_name_or_path": "openai/clip-vit-large-patch14",
39
- "add_cross_attention": false,
40
- "architectures": null,
41
- "attention_dropout": 0.0,
42
- "bad_words_ids": null,
43
- "begin_suppress_tokens": null,
44
- "bos_token_id": null,
45
- "chunk_size_feed_forward": 0,
46
- "cross_attention_hidden_size": null,
47
- "decoder_start_token_id": null,
48
- "diversity_penalty": 0.0,
49
- "do_sample": false,
50
- "early_stopping": false,
51
- "encoder_no_repeat_ngram_size": 0,
52
- "eos_token_id": null,
53
- "exponential_decay_length_penalty": null,
54
- "finetuning_task": null,
55
- "forced_bos_token_id": null,
56
- "forced_eos_token_id": null,
57
- "hidden_act": "quick_gelu",
58
- "hidden_size": 1024,
59
- "id2label": {
60
- "0": "LABEL_0",
61
- "1": "LABEL_1"
62
- },
63
- "image_size": 224,
64
- "initializer_factor": 1.0,
65
- "initializer_range": 0.02,
66
- "intermediate_size": 4096,
67
- "is_decoder": false,
68
- "is_encoder_decoder": false,
69
- "label2id": {
70
- "LABEL_0": 0,
71
- "LABEL_1": 1
72
- },
73
- "layer_norm_eps": 1e-05,
74
- "length_penalty": 1.0,
75
- "max_length": 20,
76
- "min_length": 0,
77
- "model_type": "clip_vision_model",
78
- "no_repeat_ngram_size": 0,
79
- "num_attention_heads": 16,
80
- "num_beam_groups": 1,
81
- "num_beams": 1,
82
- "num_channels": 3,
83
- "num_hidden_layers": 24,
84
- "num_return_sequences": 1,
85
- "output_attentions": false,
86
- "output_hidden_states": false,
87
- "output_scores": false,
88
- "pad_token_id": null,
89
- "patch_size": 14,
90
- "prefix": null,
91
- "problem_type": null,
92
- "projection_dim": 512,
93
- "pruned_heads": {},
94
- "remove_invalid_values": false,
95
- "repetition_penalty": 1.0,
96
- "return_dict": true,
97
- "return_dict_in_generate": false,
98
- "sep_token_id": null,
99
- "suppress_tokens": null,
100
- "task_specific_params": null,
101
- "temperature": 1.0,
102
- "tf_legacy_loss": false,
103
- "tie_encoder_decoder": false,
104
- "tie_word_embeddings": true,
105
- "tokenizer_class": null,
106
- "top_k": 50,
107
- "top_p": 1.0,
108
- "torch_dtype": null,
109
- "torchscript": false,
110
- "transformers_version": "4.30.1",
111
- "typical_p": 1.0,
112
- "use_bfloat16": false
113
- }
114
- }
mllm/flamingo/flamingo-llama2-chat-7B.json DELETED
@@ -1,115 +0,0 @@
1
- {
2
- "_commit_hash": null,
3
- "architectures": [
4
- "FlamingoForConditionalGeneration"
5
- ],
6
- "cross_attn_every_n_layers": 4,
7
- "model_type": "flamingo",
8
- "text_config": {
9
- "_name_or_path": "meta-llama/Llama-2-7b-chat-hf",
10
- "architectures": [
11
- "LlamaForCausalLM"
12
- ],
13
- "bos_token_id": 1,
14
- "eos_token_id": 2,
15
- "hidden_act": "silu",
16
- "hidden_size": 4096,
17
- "initializer_range": 0.02,
18
- "intermediate_size": 11008,
19
- "max_length": 4096,
20
- "max_position_embeddings": 2048,
21
- "model_type": "llama",
22
- "num_attention_heads": 32,
23
- "num_hidden_layers": 32,
24
- "num_key_value_heads": 32,
25
- "pad_token_id": 0,
26
- "pretraining_tp": 1,
27
- "rms_norm_eps": 1e-05,
28
- "rope_scaling": null,
29
- "tie_word_embeddings": false,
30
- "torch_dtype": "float16",
31
- "transformers_version": "4.32.0.dev0",
32
- "use_cache": true,
33
- "vocab_size": 32000
34
- },
35
- "torch_dtype": "float32",
36
- "transformers_version": null,
37
- "use_media_placement_augmentation": true,
38
- "vision_config": {
39
- "_name_or_path": "openai/clip-vit-large-patch14",
40
- "add_cross_attention": false,
41
- "architectures": null,
42
- "attention_dropout": 0.0,
43
- "bad_words_ids": null,
44
- "begin_suppress_tokens": null,
45
- "bos_token_id": null,
46
- "chunk_size_feed_forward": 0,
47
- "cross_attention_hidden_size": null,
48
- "decoder_start_token_id": null,
49
- "diversity_penalty": 0.0,
50
- "do_sample": false,
51
- "early_stopping": false,
52
- "encoder_no_repeat_ngram_size": 0,
53
- "eos_token_id": null,
54
- "exponential_decay_length_penalty": null,
55
- "finetuning_task": null,
56
- "forced_bos_token_id": null,
57
- "forced_eos_token_id": null,
58
- "hidden_act": "quick_gelu",
59
- "hidden_size": 1024,
60
- "id2label": {
61
- "0": "LABEL_0",
62
- "1": "LABEL_1"
63
- },
64
- "image_size": 224,
65
- "initializer_factor": 1.0,
66
- "initializer_range": 0.02,
67
- "intermediate_size": 4096,
68
- "is_decoder": false,
69
- "is_encoder_decoder": false,
70
- "label2id": {
71
- "LABEL_0": 0,
72
- "LABEL_1": 1
73
- },
74
- "layer_norm_eps": 1e-05,
75
- "length_penalty": 1.0,
76
- "max_length": 20,
77
- "min_length": 0,
78
- "model_type": "clip_vision_model",
79
- "no_repeat_ngram_size": 0,
80
- "num_attention_heads": 16,
81
- "num_beam_groups": 1,
82
- "num_beams": 1,
83
- "num_channels": 3,
84
- "num_hidden_layers": 24,
85
- "num_return_sequences": 1,
86
- "output_attentions": false,
87
- "output_hidden_states": false,
88
- "output_scores": false,
89
- "pad_token_id": null,
90
- "patch_size": 14,
91
- "prefix": null,
92
- "problem_type": null,
93
- "projection_dim": 512,
94
- "pruned_heads": {},
95
- "remove_invalid_values": false,
96
- "repetition_penalty": 1.0,
97
- "return_dict": true,
98
- "return_dict_in_generate": false,
99
- "sep_token_id": null,
100
- "suppress_tokens": null,
101
- "task_specific_params": null,
102
- "temperature": 1.0,
103
- "tf_legacy_loss": false,
104
- "tie_encoder_decoder": false,
105
- "tie_word_embeddings": true,
106
- "tokenizer_class": null,
107
- "top_k": 50,
108
- "top_p": 1.0,
109
- "torch_dtype": null,
110
- "torchscript": false,
111
- "transformers_version": "4.30.1",
112
- "typical_p": 1.0,
113
- "use_bfloat16": false
114
- }
115
- }
mllm/flamingo/flamingo-mpt-1B-redpajama.json DELETED
@@ -1,131 +0,0 @@
1
- {
2
- "_commit_hash": null,
3
- "architectures": [
4
- "FlamingoForConditionalGeneration"
5
- ],
6
- "cross_attn_every_n_layers": 1,
7
- "model_type": "flamingo",
8
- "text_config": {
9
- "_name_or_path": "",
10
- "alibi": true,
11
- "alibi_bias_max": 8,
12
- "architectures": [
13
- "MosaicGPT"
14
- ],
15
- "attn_clip_qkv": null,
16
- "attn_impl": "torch",
17
- "attn_pdrop": 0,
18
- "attn_qk_ln": true,
19
- "attn_uses_sequence_id": false,
20
- "d_model": 2048,
21
- "hidden_size": 2048,
22
- "emb_init_std": null,
23
- "emb_init_uniform_lim": null,
24
- "emb_pdrop": 0,
25
- "embedding_fraction": 1.0,
26
- "fan_mode": "fan_in",
27
- "init_device": "cpu",
28
- "init_div_is_residual": true,
29
- "init_gain": 0,
30
- "init_nonlinearity": "relu",
31
- "init_std": 0.02,
32
- "logit_scale": null,
33
- "low_precision_layernorm": true,
34
- "max_seq_len": 2048,
35
- "mlp_ratio": 4,
36
- "model_type": "mosaic_gpt",
37
- "n_heads": 16,
38
- "n_layers": 24,
39
- "no_bias": true,
40
- "param_init_fn": "kaiming_normal_",
41
- "prefix_lm": false,
42
- "resid_pdrop": 0,
43
- "softmax_scale": null,
44
- "tokenizer_name": "EleutherAI/gpt-neox-20b",
45
- "torch_dtype": "float32",
46
- "transformers_version": "4.27.4",
47
- "use_cache": false,
48
- "verbose": 0,
49
- "vocab_size": 50432
50
- },
51
- "torch_dtype": "float32",
52
- "transformers_version": null,
53
- "use_media_placement_augmentation": true,
54
- "vision_config": {
55
- "_name_or_path": "openai/clip-vit-large-patch14",
56
- "add_cross_attention": false,
57
- "architectures": null,
58
- "attention_dropout": 0.0,
59
- "bad_words_ids": null,
60
- "begin_suppress_tokens": null,
61
- "bos_token_id": null,
62
- "chunk_size_feed_forward": 0,
63
- "cross_attention_hidden_size": null,
64
- "decoder_start_token_id": null,
65
- "diversity_penalty": 0.0,
66
- "do_sample": false,
67
- "early_stopping": false,
68
- "encoder_no_repeat_ngram_size": 0,
69
- "eos_token_id": null,
70
- "exponential_decay_length_penalty": null,
71
- "finetuning_task": null,
72
- "forced_bos_token_id": null,
73
- "forced_eos_token_id": null,
74
- "hidden_act": "quick_gelu",
75
- "hidden_size": 1024,
76
- "id2label": {
77
- "0": "LABEL_0",
78
- "1": "LABEL_1"
79
- },
80
- "image_size": 224,
81
- "initializer_factor": 1.0,
82
- "initializer_range": 0.02,
83
- "intermediate_size": 4096,
84
- "is_decoder": false,
85
- "is_encoder_decoder": false,
86
- "label2id": {
87
- "LABEL_0": 0,
88
- "LABEL_1": 1
89
- },
90
- "layer_norm_eps": 1e-05,
91
- "length_penalty": 1.0,
92
- "max_length": 20,
93
- "min_length": 0,
94
- "model_type": "clip_vision_model",
95
- "no_repeat_ngram_size": 0,
96
- "num_attention_heads": 16,
97
- "num_beam_groups": 1,
98
- "num_beams": 1,
99
- "num_channels": 3,
100
- "num_hidden_layers": 24,
101
- "num_return_sequences": 1,
102
- "output_attentions": false,
103
- "output_hidden_states": false,
104
- "output_scores": false,
105
- "pad_token_id": null,
106
- "patch_size": 14,
107
- "prefix": null,
108
- "problem_type": null,
109
- "projection_dim": 512,
110
- "pruned_heads": {},
111
- "remove_invalid_values": false,
112
- "repetition_penalty": 1.0,
113
- "return_dict": true,
114
- "return_dict_in_generate": false,
115
- "sep_token_id": null,
116
- "suppress_tokens": null,
117
- "task_specific_params": null,
118
- "temperature": 1.0,
119
- "tf_legacy_loss": false,
120
- "tie_encoder_decoder": false,
121
- "tie_word_embeddings": true,
122
- "tokenizer_class": null,
123
- "top_k": 50,
124
- "top_p": 1.0,
125
- "torch_dtype": null,
126
- "torchscript": false,
127
- "transformers_version": "4.30.1",
128
- "typical_p": 1.0,
129
- "use_bfloat16": false
130
- }
131
- }
mllm/flamingo/flamingo-mpt-30B-bf16.json DELETED
@@ -1,195 +0,0 @@
1
- {
2
- "_commit_hash": null,
3
- "architectures": [
4
- "FlamingoForConditionalGeneration"
5
- ],
6
- "cross_attn_every_n_layers": 7,
7
- "model_type": "flamingo",
8
- "text_config": {
9
- "_name_or_path": "",
10
- "add_cross_attention": false,
11
- "architectures": [
12
- "MPTForCausalLM"
13
- ],
14
- "attn_config": {
15
- "alibi": true,
16
- "alibi_bias_max": 8,
17
- "attn_impl": "torch",
18
- "attn_pdrop": 0,
19
- "attn_type": "multihead_attention",
20
- "attn_uses_sequence_id": false,
21
- "clip_qkv": null,
22
- "prefix_lm": false,
23
- "qk_ln": false,
24
- "softmax_scale": null
25
- },
26
- "bad_words_ids": null,
27
- "begin_suppress_tokens": null,
28
- "bos_token_id": null,
29
- "chunk_size_feed_forward": 0,
30
- "cross_attention_hidden_size": null,
31
- "d_model": 7168,
32
- "decoder_start_token_id": null,
33
- "diversity_penalty": 0.0,
34
- "do_sample": false,
35
- "early_stopping": false,
36
- "emb_pdrop": 0,
37
- "embedding_fraction": 1.0,
38
- "encoder_no_repeat_ngram_size": 0,
39
- "eos_token_id": null,
40
- "expansion_ratio": 4,
41
- "exponential_decay_length_penalty": null,
42
- "finetuning_task": null,
43
- "forced_bos_token_id": null,
44
- "forced_eos_token_id": null,
45
- "hidden_size": 7168,
46
- "id2label": {
47
- "0": "LABEL_0",
48
- "1": "LABEL_1"
49
- },
50
- "init_config": {
51
- "emb_init_std": null,
52
- "emb_init_uniform_lim": null,
53
- "fan_mode": "fan_in",
54
- "init_div_is_residual": true,
55
- "init_gain": 0.0,
56
- "init_nonlinearity": "relu",
57
- "init_std": null,
58
- "name": "kaiming_normal_",
59
- "verbose": 0
60
- },
61
- "init_device": "cpu",
62
- "is_decoder": false,
63
- "is_encoder_decoder": false,
64
- "label2id": {
65
- "LABEL_0": 0,
66
- "LABEL_1": 1
67
- },
68
- "learned_pos_emb": true,
69
- "length_penalty": 1.0,
70
- "logit_scale": null,
71
- "max_length": 20,
72
- "max_seq_len": 8192,
73
- "min_length": 0,
74
- "model_type": "mpt",
75
- "n_heads": 64,
76
- "n_layers": 48,
77
- "no_bias": true,
78
- "no_repeat_ngram_size": 0,
79
- "norm_type": "low_precision_layernorm",
80
- "num_beam_groups": 1,
81
- "num_beams": 1,
82
- "num_return_sequences": 1,
83
- "output_attentions": false,
84
- "output_hidden_states": false,
85
- "output_scores": false,
86
- "pad_token_id": null,
87
- "prefix": null,
88
- "problem_type": null,
89
- "pruned_heads": {},
90
- "remove_invalid_values": false,
91
- "repetition_penalty": 1.0,
92
- "resid_pdrop": 0,
93
- "return_dict": true,
94
- "return_dict_in_generate": false,
95
- "sep_token_id": null,
96
- "suppress_tokens": null,
97
- "task_specific_params": null,
98
- "temperature": 1.0,
99
- "tf_legacy_loss": false,
100
- "tie_encoder_decoder": false,
101
- "tie_word_embeddings": true,
102
- "tokenizer_class": null,
103
- "tokenizer_name": "EleutherAI/gpt-neox-20b",
104
- "top_k": 50,
105
- "top_p": 1.0,
106
- "torch_dtype": "bfloat16",
107
- "torchscript": false,
108
- "transformers_version": "4.30.1",
109
- "typical_p": 1.0,
110
- "use_bfloat16": false,
111
- "use_cache": false,
112
- "verbose": 0,
113
- "vocab_size": 50432
114
- },
115
- "torch_dtype": "bfloat16",
116
- "transformers_version": null,
117
- "use_media_placement_augmentation": true,
118
- "vision_config": {
119
- "_name_or_path": "openai/clip-vit-large-patch14",
120
- "add_cross_attention": false,
121
- "architectures": null,
122
- "attention_dropout": 0.0,
123
- "bad_words_ids": null,
124
- "begin_suppress_tokens": null,
125
- "bos_token_id": null,
126
- "chunk_size_feed_forward": 0,
127
- "cross_attention_hidden_size": null,
128
- "decoder_start_token_id": null,
129
- "diversity_penalty": 0.0,
130
- "do_sample": false,
131
- "early_stopping": false,
132
- "encoder_no_repeat_ngram_size": 0,
133
- "eos_token_id": null,
134
- "exponential_decay_length_penalty": null,
135
- "finetuning_task": null,
136
- "forced_bos_token_id": null,
137
- "forced_eos_token_id": null,
138
- "hidden_act": "quick_gelu",
139
- "hidden_size": 1024,
140
- "id2label": {
141
- "0": "LABEL_0",
142
- "1": "LABEL_1"
143
- },
144
- "image_size": 224,
145
- "initializer_factor": 1.0,
146
- "initializer_range": 0.02,
147
- "intermediate_size": 4096,
148
- "is_decoder": false,
149
- "is_encoder_decoder": false,
150
- "label2id": {
151
- "LABEL_0": 0,
152
- "LABEL_1": 1
153
- },
154
- "layer_norm_eps": 1e-05,
155
- "length_penalty": 1.0,
156
- "max_length": 20,
157
- "min_length": 0,
158
- "model_type": "clip_vision_model",
159
- "no_repeat_ngram_size": 0,
160
- "num_attention_heads": 16,
161
- "num_beam_groups": 1,
162
- "num_beams": 1,
163
- "num_channels": 3,
164
- "num_hidden_layers": 24,
165
- "num_return_sequences": 1,
166
- "output_attentions": false,
167
- "output_hidden_states": false,
168
- "output_scores": false,
169
- "pad_token_id": null,
170
- "patch_size": 14,
171
- "prefix": null,
172
- "problem_type": null,
173
- "projection_dim": 512,
174
- "pruned_heads": {},
175
- "remove_invalid_values": false,
176
- "repetition_penalty": 1.0,
177
- "return_dict": true,
178
- "return_dict_in_generate": false,
179
- "sep_token_id": null,
180
- "suppress_tokens": null,
181
- "task_specific_params": null,
182
- "temperature": 1.0,
183
- "tf_legacy_loss": false,
184
- "tie_encoder_decoder": false,
185
- "tie_word_embeddings": true,
186
- "tokenizer_class": null,
187
- "top_k": 50,
188
- "top_p": 1.0,
189
- "torch_dtype": null,
190
- "torchscript": false,
191
- "transformers_version": "4.30.1",
192
- "typical_p": 1.0,
193
- "use_bfloat16": false
194
- }
195
- }
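This bf16 variant and the flamingo-mpt-30B.json that follows appear to differ only in the top-level "torch_dtype" ("bfloat16" vs. "float32"). A quick check, assuming both files are on disk:

```python
import json

with open("mllm/flamingo/flamingo-mpt-30B-bf16.json") as f:
    bf16 = json.load(f)
with open("mllm/flamingo/flamingo-mpt-30B.json") as f:
    fp32 = json.load(f)

# Dict equality recurses into the nested text/vision configs, so only
# genuinely differing keys show up here.
diff = {k: (bf16[k], fp32[k]) for k in bf16 if bf16[k] != fp32[k]}
print(diff)  # expected: {'torch_dtype': ('bfloat16', 'float32')}
```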
mllm/flamingo/flamingo-mpt-30B.json DELETED
@@ -1,195 +0,0 @@
1
- {
2
- "_commit_hash": null,
3
- "architectures": [
4
- "FlamingoForConditionalGeneration"
5
- ],
6
- "cross_attn_every_n_layers": 7,
7
- "model_type": "flamingo",
8
- "text_config": {
9
- "_name_or_path": "",
10
- "add_cross_attention": false,
11
- "architectures": [
12
- "MPTForCausalLM"
13
- ],
14
- "attn_config": {
15
- "alibi": true,
16
- "alibi_bias_max": 8,
17
- "attn_impl": "torch",
18
- "attn_pdrop": 0,
19
- "attn_type": "multihead_attention",
20
- "attn_uses_sequence_id": false,
21
- "clip_qkv": null,
22
- "prefix_lm": false,
23
- "qk_ln": false,
24
- "softmax_scale": null
25
- },
26
- "bad_words_ids": null,
27
- "begin_suppress_tokens": null,
28
- "bos_token_id": null,
29
- "chunk_size_feed_forward": 0,
30
- "cross_attention_hidden_size": null,
31
- "d_model": 7168,
32
- "decoder_start_token_id": null,
33
- "diversity_penalty": 0.0,
34
- "do_sample": false,
35
- "early_stopping": false,
36
- "emb_pdrop": 0,
37
- "embedding_fraction": 1.0,
38
- "encoder_no_repeat_ngram_size": 0,
39
- "eos_token_id": null,
40
- "expansion_ratio": 4,
41
- "exponential_decay_length_penalty": null,
42
- "finetuning_task": null,
43
- "forced_bos_token_id": null,
44
- "forced_eos_token_id": null,
45
- "hidden_size": 7168,
46
- "id2label": {
47
- "0": "LABEL_0",
48
- "1": "LABEL_1"
49
- },
50
- "init_config": {
51
- "emb_init_std": null,
52
- "emb_init_uniform_lim": null,
53
- "fan_mode": "fan_in",
54
- "init_div_is_residual": true,
55
- "init_gain": 0.0,
56
- "init_nonlinearity": "relu",
57
- "init_std": null,
58
- "name": "kaiming_normal_",
59
- "verbose": 0
60
- },
61
- "init_device": "cpu",
62
- "is_decoder": false,
63
- "is_encoder_decoder": false,
64
- "label2id": {
65
- "LABEL_0": 0,
66
- "LABEL_1": 1
67
- },
68
- "learned_pos_emb": true,
69
- "length_penalty": 1.0,
70
- "logit_scale": null,
71
- "max_length": 20,
72
- "max_seq_len": 8192,
73
- "min_length": 0,
74
- "model_type": "mpt",
75
- "n_heads": 64,
76
- "n_layers": 48,
77
- "no_bias": true,
78
- "no_repeat_ngram_size": 0,
79
- "norm_type": "low_precision_layernorm",
80
- "num_beam_groups": 1,
81
- "num_beams": 1,
82
- "num_return_sequences": 1,
83
- "output_attentions": false,
84
- "output_hidden_states": false,
85
- "output_scores": false,
86
- "pad_token_id": null,
87
- "prefix": null,
88
- "problem_type": null,
89
- "pruned_heads": {},
90
- "remove_invalid_values": false,
91
- "repetition_penalty": 1.0,
92
- "resid_pdrop": 0,
93
- "return_dict": true,
94
- "return_dict_in_generate": false,
95
- "sep_token_id": null,
96
- "suppress_tokens": null,
97
- "task_specific_params": null,
98
- "temperature": 1.0,
99
- "tf_legacy_loss": false,
100
- "tie_encoder_decoder": false,
101
- "tie_word_embeddings": true,
102
- "tokenizer_class": null,
103
- "tokenizer_name": "EleutherAI/gpt-neox-20b",
104
- "top_k": 50,
105
- "top_p": 1.0,
106
- "torch_dtype": "bfloat16",
107
- "torchscript": false,
108
- "transformers_version": "4.30.1",
109
- "typical_p": 1.0,
110
- "use_bfloat16": false,
111
- "use_cache": false,
112
- "verbose": 0,
113
- "vocab_size": 50432
114
- },
115
- "torch_dtype": "float32",
116
- "transformers_version": null,
117
- "use_media_placement_augmentation": true,
118
- "vision_config": {
119
- "_name_or_path": "openai/clip-vit-large-patch14",
120
- "add_cross_attention": false,
121
- "architectures": null,
122
- "attention_dropout": 0.0,
123
- "bad_words_ids": null,
124
- "begin_suppress_tokens": null,
125
- "bos_token_id": null,
126
- "chunk_size_feed_forward": 0,
127
- "cross_attention_hidden_size": null,
128
- "decoder_start_token_id": null,
129
- "diversity_penalty": 0.0,
130
- "do_sample": false,
131
- "early_stopping": false,
132
- "encoder_no_repeat_ngram_size": 0,
133
- "eos_token_id": null,
134
- "exponential_decay_length_penalty": null,
135
- "finetuning_task": null,
136
- "forced_bos_token_id": null,
137
- "forced_eos_token_id": null,
138
- "hidden_act": "quick_gelu",
139
- "hidden_size": 1024,
140
- "id2label": {
141
- "0": "LABEL_0",
142
- "1": "LABEL_1"
143
- },
144
- "image_size": 224,
145
- "initializer_factor": 1.0,
146
- "initializer_range": 0.02,
147
- "intermediate_size": 4096,
148
- "is_decoder": false,
149
- "is_encoder_decoder": false,
150
- "label2id": {
151
- "LABEL_0": 0,
152
- "LABEL_1": 1
153
- },
154
- "layer_norm_eps": 1e-05,
155
- "length_penalty": 1.0,
156
- "max_length": 20,
157
- "min_length": 0,
158
- "model_type": "clip_vision_model",
159
- "no_repeat_ngram_size": 0,
160
- "num_attention_heads": 16,
161
- "num_beam_groups": 1,
162
- "num_beams": 1,
163
- "num_channels": 3,
164
- "num_hidden_layers": 24,
165
- "num_return_sequences": 1,
166
- "output_attentions": false,
167
- "output_hidden_states": false,
168
- "output_scores": false,
169
- "pad_token_id": null,
170
- "patch_size": 14,
171
- "prefix": null,
172
- "problem_type": null,
173
- "projection_dim": 512,
174
- "pruned_heads": {},
175
- "remove_invalid_values": false,
176
- "repetition_penalty": 1.0,
177
- "return_dict": true,
178
- "return_dict_in_generate": false,
179
- "sep_token_id": null,
180
- "suppress_tokens": null,
181
- "task_specific_params": null,
182
- "temperature": 1.0,
183
- "tf_legacy_loss": false,
184
- "tie_encoder_decoder": false,
185
- "tie_word_embeddings": true,
186
- "tokenizer_class": null,
187
- "top_k": 50,
188
- "top_p": 1.0,
189
- "torch_dtype": null,
190
- "torchscript": false,
191
- "transformers_version": "4.30.1",
192
- "typical_p": 1.0,
193
- "use_bfloat16": false
194
- }
195
- }
mllm/flamingo/flamingo-mpt-7B.json DELETED
@@ -1,195 +0,0 @@
1
- {
2
- "_commit_hash": null,
3
- "architectures": [
4
- "FlamingoForConditionalGeneration"
5
- ],
6
- "cross_attn_every_n_layers": 4,
7
- "model_type": "flamingo",
8
- "text_config": {
9
- "_name_or_path": "",
10
- "add_cross_attention": false,
11
- "architectures": [
12
- "MPTForCausalLM"
13
- ],
14
- "attn_config": {
15
- "alibi": true,
16
- "alibi_bias_max": 8,
17
- "attn_impl": "torch",
18
- "attn_pdrop": 0,
19
- "attn_type": "multihead_attention",
20
- "attn_uses_sequence_id": false,
21
- "clip_qkv": null,
22
- "prefix_lm": false,
23
- "qk_ln": false,
24
- "softmax_scale": null
25
- },
26
- "bad_words_ids": null,
27
- "begin_suppress_tokens": null,
28
- "bos_token_id": null,
29
- "chunk_size_feed_forward": 0,
30
- "cross_attention_hidden_size": null,
31
- "d_model": 4096,
32
- "decoder_start_token_id": null,
33
- "diversity_penalty": 0.0,
34
- "do_sample": false,
35
- "early_stopping": false,
36
- "emb_pdrop": 0,
37
- "embedding_fraction": 1.0,
38
- "encoder_no_repeat_ngram_size": 0,
39
- "eos_token_id": null,
40
- "expansion_ratio": 4,
41
- "exponential_decay_length_penalty": null,
42
- "finetuning_task": null,
43
- "forced_bos_token_id": null,
44
- "forced_eos_token_id": null,
45
- "hidden_size": 4096,
46
- "id2label": {
47
- "0": "LABEL_0",
48
- "1": "LABEL_1"
49
- },
50
- "init_config": {
51
- "emb_init_std": null,
52
- "emb_init_uniform_lim": null,
53
- "fan_mode": "fan_in",
54
- "init_div_is_residual": true,
55
- "init_gain": 0,
56
- "init_nonlinearity": "relu",
57
- "init_std": 0.02,
58
- "name": "kaiming_normal_",
59
- "verbose": 0
60
- },
61
- "init_device": "cpu",
62
- "is_decoder": false,
63
- "is_encoder_decoder": false,
64
- "label2id": {
65
- "LABEL_0": 0,
66
- "LABEL_1": 1
67
- },
68
- "learned_pos_emb": true,
69
- "length_penalty": 1.0,
70
- "logit_scale": null,
71
- "max_length": 20,
72
- "max_seq_len": 2048,
73
- "min_length": 0,
74
- "model_type": "mpt",
75
- "n_heads": 32,
76
- "n_layers": 32,
77
- "no_bias": true,
78
- "no_repeat_ngram_size": 0,
79
- "norm_type": "low_precision_layernorm",
80
- "num_beam_groups": 1,
81
- "num_beams": 1,
82
- "num_return_sequences": 1,
83
- "output_attentions": false,
84
- "output_hidden_states": false,
85
- "output_scores": false,
86
- "pad_token_id": null,
87
- "prefix": null,
88
- "problem_type": null,
89
- "pruned_heads": {},
90
- "remove_invalid_values": false,
91
- "repetition_penalty": 1.0,
92
- "resid_pdrop": 0,
93
- "return_dict": true,
94
- "return_dict_in_generate": false,
95
- "sep_token_id": null,
96
- "suppress_tokens": null,
97
- "task_specific_params": null,
98
- "temperature": 1.0,
99
- "tf_legacy_loss": false,
100
- "tie_encoder_decoder": false,
101
- "tie_word_embeddings": true,
102
- "tokenizer_class": null,
103
- "tokenizer_name": "EleutherAI/gpt-neox-20b",
104
- "top_k": 50,
105
- "top_p": 1.0,
106
- "torch_dtype": "bfloat16",
107
- "torchscript": false,
108
- "transformers_version": "4.30.1",
109
- "typical_p": 1.0,
110
- "use_bfloat16": false,
111
- "use_cache": false,
112
- "verbose": 0,
113
- "vocab_size": 50432
114
- },
115
- "torch_dtype": "float32",
116
- "transformers_version": null,
117
- "use_media_placement_augmentation": true,
118
- "vision_config": {
119
- "_name_or_path": "openai/clip-vit-large-patch14",
120
- "add_cross_attention": false,
121
- "architectures": null,
122
- "attention_dropout": 0.0,
123
- "bad_words_ids": null,
124
- "begin_suppress_tokens": null,
125
- "bos_token_id": null,
126
- "chunk_size_feed_forward": 0,
127
- "cross_attention_hidden_size": null,
128
- "decoder_start_token_id": null,
129
- "diversity_penalty": 0.0,
130
- "do_sample": false,
131
- "early_stopping": false,
132
- "encoder_no_repeat_ngram_size": 0,
133
- "eos_token_id": null,
134
- "exponential_decay_length_penalty": null,
135
- "finetuning_task": null,
136
- "forced_bos_token_id": null,
137
- "forced_eos_token_id": null,
138
- "hidden_act": "quick_gelu",
139
- "hidden_size": 1024,
140
- "id2label": {
141
- "0": "LABEL_0",
142
- "1": "LABEL_1"
143
- },
144
- "image_size": 224,
145
- "initializer_factor": 1.0,
146
- "initializer_range": 0.02,
147
- "intermediate_size": 4096,
148
- "is_decoder": false,
149
- "is_encoder_decoder": false,
150
- "label2id": {
151
- "LABEL_0": 0,
152
- "LABEL_1": 1
153
- },
154
- "layer_norm_eps": 1e-05,
155
- "length_penalty": 1.0,
156
- "max_length": 20,
157
- "min_length": 0,
158
- "model_type": "clip_vision_model",
159
- "no_repeat_ngram_size": 0,
160
- "num_attention_heads": 16,
161
- "num_beam_groups": 1,
162
- "num_beams": 1,
163
- "num_channels": 3,
164
- "num_hidden_layers": 24,
165
- "num_return_sequences": 1,
166
- "output_attentions": false,
167
- "output_hidden_states": false,
168
- "output_scores": false,
169
- "pad_token_id": null,
170
- "patch_size": 14,
171
- "prefix": null,
172
- "problem_type": null,
173
- "projection_dim": 512,
174
- "pruned_heads": {},
175
- "remove_invalid_values": false,
176
- "repetition_penalty": 1.0,
177
- "return_dict": true,
178
- "return_dict_in_generate": false,
179
- "sep_token_id": null,
180
- "suppress_tokens": null,
181
- "task_specific_params": null,
182
- "temperature": 1.0,
183
- "tf_legacy_loss": false,
184
- "tie_encoder_decoder": false,
185
- "tie_word_embeddings": true,
186
- "tokenizer_class": null,
187
- "top_k": 50,
188
- "top_p": 1.0,
189
- "torch_dtype": null,
190
- "torchscript": false,
191
- "transformers_version": "4.30.1",
192
- "typical_p": 1.0,
193
- "use_bfloat16": false
194
- }
195
- }
mllm/flamingo/flamingo-vicuna-33B-v1.3.json DELETED
@@ -1,111 +0,0 @@
1
- {
2
- "_commit_hash": null,
3
- "architectures": [
4
- "FlamingoForConditionalGeneration"
5
- ],
6
- "cross_attn_every_n_layers": 4,
7
- "model_type": "flamingo",
8
- "text_config": {
9
- "_name_or_path": "/home/luodian/projects/checkpoints/vicuna-33b-v1.3",
10
- "architectures": [
11
- "LlamaForCausalLM"
12
- ],
13
- "bos_token_id": 1,
14
- "eos_token_id": 2,
15
- "hidden_act": "silu",
16
- "hidden_size": 6656,
17
- "initializer_range": 0.02,
18
- "intermediate_size": 17920,
19
- "max_position_embeddings": 2048,
20
- "model_type": "llama",
21
- "num_attention_heads": 52,
22
- "num_hidden_layers": 60,
23
- "pad_token_id": 0,
24
- "rms_norm_eps": 1e-06,
25
- "tie_word_embeddings": false,
26
- "torch_dtype": "float16",
27
- "transformers_version": "4.28.1",
28
- "use_cache": false,
29
- "vocab_size": 32000
30
- },
31
- "torch_dtype": "float32",
32
- "transformers_version": null,
33
- "use_media_placement_augmentation": true,
34
- "vision_config": {
35
- "_name_or_path": "openai/clip-vit-large-patch14",
36
- "add_cross_attention": false,
37
- "architectures": null,
38
- "attention_dropout": 0.0,
39
- "bad_words_ids": null,
40
- "begin_suppress_tokens": null,
41
- "bos_token_id": null,
42
- "chunk_size_feed_forward": 0,
43
- "cross_attention_hidden_size": null,
44
- "decoder_start_token_id": null,
45
- "diversity_penalty": 0.0,
46
- "do_sample": false,
47
- "early_stopping": false,
48
- "encoder_no_repeat_ngram_size": 0,
49
- "eos_token_id": null,
50
- "exponential_decay_length_penalty": null,
51
- "finetuning_task": null,
52
- "forced_bos_token_id": null,
53
- "forced_eos_token_id": null,
54
- "hidden_act": "quick_gelu",
55
- "hidden_size": 1024,
56
- "id2label": {
57
- "0": "LABEL_0",
58
- "1": "LABEL_1"
59
- },
60
- "image_size": 224,
61
- "initializer_factor": 1.0,
62
- "initializer_range": 0.02,
63
- "intermediate_size": 4096,
64
- "is_decoder": false,
65
- "is_encoder_decoder": false,
66
- "label2id": {
67
- "LABEL_0": 0,
68
- "LABEL_1": 1
69
- },
70
- "layer_norm_eps": 1e-05,
71
- "length_penalty": 1.0,
72
- "max_length": 20,
73
- "min_length": 0,
74
- "model_type": "clip_vision_model",
75
- "no_repeat_ngram_size": 0,
76
- "num_attention_heads": 16,
77
- "num_beam_groups": 1,
78
- "num_beams": 1,
79
- "num_channels": 3,
80
- "num_hidden_layers": 24,
81
- "num_return_sequences": 1,
82
- "output_attentions": false,
83
- "output_hidden_states": false,
84
- "output_scores": false,
85
- "pad_token_id": null,
86
- "patch_size": 14,
87
- "prefix": null,
88
- "problem_type": null,
89
- "projection_dim": 512,
90
- "pruned_heads": {},
91
- "remove_invalid_values": false,
92
- "repetition_penalty": 1.0,
93
- "return_dict": true,
94
- "return_dict_in_generate": false,
95
- "sep_token_id": null,
96
- "suppress_tokens": null,
97
- "task_specific_params": null,
98
- "temperature": 1.0,
99
- "tf_legacy_loss": false,
100
- "tie_encoder_decoder": false,
101
- "tie_word_embeddings": true,
102
- "tokenizer_class": null,
103
- "top_k": 50,
104
- "top_p": 1.0,
105
- "torch_dtype": null,
106
- "torchscript": false,
107
- "transformers_version": "4.30.1",
108
- "typical_p": 1.0,
109
- "use_bfloat16": false
110
- }
111
- }
mllm/flamingo/flamingo-vicuna-7B-v1.3.json DELETED
@@ -1,111 +0,0 @@
1
- {
2
- "_commit_hash": null,
3
- "architectures": [
4
- "FlamingoForConditionalGeneration"
5
- ],
6
- "cross_attn_every_n_layers": 4,
7
- "model_type": "flamingo",
8
- "text_config": {
9
- "_name_or_path": "/mnt/petrelfs/share_data/zhangyuanhan/vicuna-7b-v1.3",
10
- "architectures": [
11
- "LlamaForCausalLM"
12
- ],
13
- "bos_token_id": 1,
14
- "eos_token_id": 2,
15
- "hidden_act": "silu",
16
- "hidden_size": 4096,
17
- "initializer_range": 0.02,
18
- "intermediate_size": 11008,
19
- "max_position_embeddings": 2048,
20
- "model_type": "llama",
21
- "num_attention_heads": 32,
22
- "num_hidden_layers": 32,
23
- "pad_token_id": 0,
24
- "rms_norm_eps": 1e-06,
25
- "tie_word_embeddings": false,
26
- "torch_dtype": "float16",
27
- "transformers_version": "4.28.1",
28
- "use_cache": false,
29
- "vocab_size": 32000
30
- },
31
- "torch_dtype": "float32",
32
- "transformers_version": null,
33
- "use_media_placement_augmentation": true,
34
- "vision_config": {
35
- "_name_or_path": "openai/clip-vit-large-patch14",
36
- "add_cross_attention": false,
37
- "architectures": null,
38
- "attention_dropout": 0.0,
39
- "bad_words_ids": null,
40
- "begin_suppress_tokens": null,
41
- "bos_token_id": null,
42
- "chunk_size_feed_forward": 0,
43
- "cross_attention_hidden_size": null,
44
- "decoder_start_token_id": null,
45
- "diversity_penalty": 0.0,
46
- "do_sample": false,
47
- "early_stopping": false,
48
- "encoder_no_repeat_ngram_size": 0,
49
- "eos_token_id": null,
50
- "exponential_decay_length_penalty": null,
51
- "finetuning_task": null,
52
- "forced_bos_token_id": null,
53
- "forced_eos_token_id": null,
54
- "hidden_act": "quick_gelu",
55
- "hidden_size": 1024,
56
- "id2label": {
57
- "0": "LABEL_0",
58
- "1": "LABEL_1"
59
- },
60
- "image_size": 224,
61
- "initializer_factor": 1.0,
62
- "initializer_range": 0.02,
63
- "intermediate_size": 4096,
64
- "is_decoder": false,
65
- "is_encoder_decoder": false,
66
- "label2id": {
67
- "LABEL_0": 0,
68
- "LABEL_1": 1
69
- },
70
- "layer_norm_eps": 1e-05,
71
- "length_penalty": 1.0,
72
- "max_length": 20,
73
- "min_length": 0,
74
- "model_type": "clip_vision_model",
75
- "no_repeat_ngram_size": 0,
76
- "num_attention_heads": 16,
77
- "num_beam_groups": 1,
78
- "num_beams": 1,
79
- "num_channels": 3,
80
- "num_hidden_layers": 24,
81
- "num_return_sequences": 1,
82
- "output_attentions": false,
83
- "output_hidden_states": false,
84
- "output_scores": false,
85
- "pad_token_id": null,
86
- "patch_size": 14,
87
- "prefix": null,
88
- "problem_type": null,
89
- "projection_dim": 512,
90
- "pruned_heads": {},
91
- "remove_invalid_values": false,
92
- "repetition_penalty": 1.0,
93
- "return_dict": true,
94
- "return_dict_in_generate": false,
95
- "sep_token_id": null,
96
- "suppress_tokens": null,
97
- "task_specific_params": null,
98
- "temperature": 1.0,
99
- "tf_legacy_loss": false,
100
- "tie_encoder_decoder": false,
101
- "tie_word_embeddings": true,
102
- "tokenizer_class": null,
103
- "top_k": 50,
104
- "top_p": 1.0,
105
- "torch_dtype": null,
106
- "torchscript": false,
107
- "transformers_version": "4.30.1",
108
- "typical_p": 1.0,
109
- "use_bfloat16": false
110
- }
111
- }
mllm/flamingo/injecting_falcon_into_flamingo.py DELETED
@@ -1,49 +0,0 @@
1
- import os
2
- import torch
3
- from .configuration_flamingo import FlamingoConfig
4
- from .modeling_flamingo import FlamingoForConditionalGeneration
5
-
6
- root_dir = os.environ["AZP"]
7
- print(root_dir)
8
-
9
-
10
- config = FlamingoConfig.from_json_file("./flamingo-falcon-7B.json")
11
- model = FlamingoForConditionalGeneration(config=config)
12
-
13
-
14
- state_dict_files = [
15
- f"{root_dir}/otter/checkpoints/falcon-7b/pytorch_model-00001-of-00002.bin",
16
- f"{root_dir}/otter/checkpoints/falcon-7b/pytorch_model-00002-of-00002.bin",
17
- ]
18
-
19
- state_dict = {}
20
- for file in state_dict_files:
21
- state_dict_part = torch.load(file, map_location="cpu")
22
- state_dict.update(state_dict_part)
23
-
24
-
25
- state_dict_3 = torch.load(f"{root_dir}/otter/checkpoints/flamingo_9b_hf/pytorch_model-00004-of-00004.bin", map_location="cpu")
26
- for cur_key in list(state_dict_3.keys()):
27
- if "vision_encoder" not in cur_key:
28
- del state_dict_3[cur_key]
29
-
30
- _ = model.load_state_dict(
31
- state_dict_3,
32
- False,
33
- )
34
- print(_[1])
35
-
36
- save_state_dict_1 = {}
37
- for key in state_dict:
38
- if ".h." in key:
39
- _, _, layer_num, *remain_names = key.split(".")
40
- target_key = f"transformer.h.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
41
- else:
42
- target_key = key
43
- save_state_dict_1[f"{target_key}"] = state_dict[key]
44
- _ = model.lang_encoder.load_state_dict(
45
- save_state_dict_1,
46
- False,
47
- )
48
- print(_[1])
49
- model.save_pretrained(f"{root_dir}/otter/checkpoints/flamingo-falcon-7b/")
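The remapping loop above inserts a `decoder_layer` level into each per-layer Falcon key so the weights line up with Flamingo's wrapped decoder blocks. A standalone sketch of the same transformation:

```python
def remap_falcon_key(key: str) -> str:
    # Per-layer keys look like "transformer.h.<n>.<rest>"; everything after
    # the layer index moves under an extra "decoder_layer" level.
    if ".h." in key:
        _, _, layer_num, *remain_names = key.split(".")
        return f"transformer.h.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
    return key

print(remap_falcon_key("transformer.h.5.self_attention.query_key_value.weight"))
# -> transformer.h.5.decoder_layer.self_attention.query_key_value.weight
print(remap_falcon_key("transformer.word_embeddings.weight"))  # non-layer keys pass through
```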
mllm/flamingo/injecting_llama2_into_flamingo.py DELETED
@@ -1,95 +0,0 @@
1
- import argparse
2
- import os
3
-
4
- import torch
5
- from tqdm import tqdm
6
-
7
- import sys
8
-
9
- from .configuration_flamingo import FlamingoConfig
10
- from .modeling_flamingo import FlamingoForConditionalGeneration
11
-
12
- # from .configuration_flamingo import FlamingoConfig
13
- # from .modeling_flamingo import FlamingoForConditionalGeneration
14
-
15
- parser = argparse.ArgumentParser(description="Convert Vicuna model")
16
- parser.add_argument("--model_choice", type=str, default="13B", help="Choose either '7B' or '13B'")
17
- parser.add_argument("--llama2_root_dir", type=str, default="/home/luodian/projects/checkpoints")
18
- parser.add_argument("--save_root_dir", type=str, default="/home/luodian/projects/checkpoints")
19
- args = parser.parse_args()
20
-
21
- # os.environ["TOKENIZERS_PARALLELISM"] = "false"
22
-
23
- root_dir = args.llama2_root_dir
24
- model_choice = args.model_choice
25
- save_root_dir = args.save_root_dir
26
-
27
- # prepare the Llama-2 chat model first
28
- # you can visit https://huggingface.co/meta-llama/Llama-2-7b-chat-hf (or the 13B variant) to download the chat checkpoints.
29
- if model_choice == "7B":
30
- config_file = "./flamingo/flamingo-llama2-chat-7B.json"
31
- state_dict_files = [
32
- f"{root_dir}/Llama-2-7b-chat-hf/pytorch_model-00001-of-00002.bin",
33
- f"{root_dir}/Llama-2-7b-chat-hf/pytorch_model-00002-of-00002.bin",
34
- ]
35
- save_path = f"{save_root_dir}/flamingo-llama2-chat-7B-init"
36
- elif model_choice == "13B":
37
- config_file = "./flamingo/flamingo-llama2-chat-13B.json"
38
- state_dict_files = [
39
- f"{root_dir}/Llama-2-13b-chat-hf/pytorch_model-00001-of-00003.bin",
40
- f"{root_dir}/Llama-2-13b-chat-hf/pytorch_model-00002-of-00003.bin",
41
- f"{root_dir}/Llama-2-13b-chat-hf/pytorch_model-00003-of-00003.bin",
42
- ]
43
- save_path = f"{save_root_dir}/flamingo-llama2-chat-13B-init"
44
- else:
45
- raise ValueError("Invalid model_choice. Choose either '13B' or '7B'.")
46
-
47
- config = FlamingoConfig.from_json_file(config_file)
48
- model = FlamingoForConditionalGeneration(config=config)
49
-
50
- # load flamingo's vision encoder from last checkpoint.
51
- # you can visit https://huggingface.co/luodian/openflamingo-9b-hf/tree/main to download the checkpoint.
52
- # AZP = "os.environ["AZP"]"
53
- AZP = os.environ["AZP"]
54
- state_dict_3 = torch.load(f"{AZP}/otter/checkpoints/flamingo_9b_hf/pytorch_model-00004-of-00004.bin", map_location="cpu")
55
- for cur_key in list(state_dict_3.keys()):
56
- if "vision_encoder" not in cur_key:
57
- del state_dict_3[cur_key]
58
-
59
- load_msg = model.load_state_dict(
60
- state_dict_3,
61
- False,
62
- )
63
- # print incompatible keys
64
- print(load_msg[1])
65
-
66
- # Load the Llama-2 chat weights
67
- state_dict = {}
68
- for file in tqdm(state_dict_files, desc="Loading state dict"):
69
- state_dict_part = torch.load(file, map_location="cpu")
70
- state_dict.update(state_dict_part)
71
-
72
- save_state_dict_1 = {}
73
- for key in state_dict:
74
- if ".layers." in key:
75
- _, _, layer_num, *remain_names = key.split(".")
76
- target_key = f"model.layers.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
77
- else:
78
- target_key = key
79
- save_state_dict_1[f"{target_key}"] = state_dict[key]
80
-
81
- # Resize the token embeddings to 32000 so the base Llama-2 weights load cleanly
82
- model.lang_encoder.resize_token_embeddings(32000)
83
-
84
- load_msg = model.lang_encoder.load_state_dict(
85
- save_state_dict_1,
86
- False,
87
- )
88
- # Resize the token embeddings to 32002 to make room for the added special tokens
89
- model.lang_encoder.resize_token_embeddings(32002)
90
- # print incompatible keys
91
- print(load_msg[1])
92
-
93
-
94
- print(f"Saving model to {save_path}...")
95
- model.save_pretrained(save_path, max_shard_size="10GB")
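A note on the resize-load-resize sequence above: the embedding table is first shrunk to the base Llama-2 vocabulary (32000) so the pretrained rows load cleanly, then grown to 32002 for the extra tokens added on top. A toy sketch of what that preserves, with a plain `nn.Embedding` standing in for the model's token embeddings:

```python
import torch
import torch.nn as nn

base_vocab, extended_vocab, dim = 32000, 32002, 8  # dim is a toy hidden size

pretrained = nn.Embedding(base_vocab, dim)   # stands in for the Llama-2 table
emb = nn.Embedding(extended_vocab, dim)      # table after the final resize

# Copy the base rows; the last two rows keep their fresh initialization,
# mirroring what resize_token_embeddings preserves across resizes.
with torch.no_grad():
    emb.weight[:base_vocab].copy_(pretrained.weight)

assert torch.equal(emb.weight[:base_vocab], pretrained.weight)
```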
mllm/flamingo/injecting_mpt-1B-redpajama_into_flamingo.py DELETED
@@ -1,97 +0,0 @@
1
- import argparse
2
- import os
3
-
4
- import torch
5
- from tqdm import tqdm
6
-
7
- import sys
8
-
9
- from configuration_flamingo import FlamingoConfig
10
- from modeling_flamingo import FlamingoForConditionalGeneration
11
- from utils import rename_flamingo_checkpoint
12
-
13
-
14
- parser = argparse.ArgumentParser(description="Convert MPT model")
15
- parser.add_argument("--mpt_root_dir", type=str, default="/home/luodian/projects/checkpoints")
16
- parser.add_argument("--save_root_dir", type=str, default="/home/luodian/projects/checkpoints")
17
- parser.add_argument("--flamingo_dir", type=str, default=None, help="If the pretrained flamingo weights also need to be injected")
18
- args = parser.parse_args()
19
-
20
-
21
- root_dir = args.mpt_root_dir
22
- save_root_dir = args.save_root_dir
23
-
24
- # prepare the MPT model first
25
- # you can visit https://huggingface.co/mosaicml to download the MPT-1B RedPajama checkpoint.
26
- config_file = "./flamingo/flamingo-mpt-1B-redpajama.json"
27
- state_dict_file = f"{root_dir}/pytorch_model.bin"
28
- save_path = f"{save_root_dir}/flamingo-mpt-1b-redpajama-200b-dolly"
29
-
30
- config = FlamingoConfig.from_json_file(config_file)
31
-
32
- model = FlamingoForConditionalGeneration(config=config)
33
-
34
- # Loading mpt weights
35
- state_dict = torch.load(state_dict_file, map_location="cpu")
36
- save_state_dict_1 = {}
37
- for key in state_dict:
-     if ".blocks." in key:
-         _, _, layer_num, *remain_names = key.split(".")
-         target_key = f"transformer.blocks.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
-     else:
-         target_key = key
-     save_state_dict_1[target_key] = state_dict[key]
-
- load_msg = model.lang_encoder.load_state_dict(
-     save_state_dict_1,
-     False,
- )
-
- # load flamingo's vision encoder from the last checkpoint shard.
- # visit https://huggingface.co/luodian/openflamingo-9b-hf/tree/main to download the checkpoint.
- AZP = os.environ["AZP"]
- state_dict_3 = torch.load(f"{AZP}/pytorch_model-00004-of-00004.bin", map_location="cpu")
- for cur_key in list(state_dict_3.keys()):
-     if "vision_encoder" not in cur_key:
-         del state_dict_3[cur_key]
-
- load_msg = model.load_state_dict(
-     state_dict_3,
-     False,
- )
- # print incompatible keys
- print(load_msg[1])
-
- # reload the remapped language-model weights after the vision encoder overwrite
- save_state_dict_1 = {}
- for key in state_dict:
-     if ".blocks." in key:
-         _, _, layer_num, *remain_names = key.split(".")
-         target_key = f"transformer.blocks.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
-     else:
-         target_key = key
-     save_state_dict_1[target_key] = state_dict[key]
-
- load_msg = model.lang_encoder.load_state_dict(
-     save_state_dict_1,
-     False,
- )
- # print incompatible keys
- print(load_msg[1])
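- # Optionally overlay pretrained OpenFlamingo weights (perceiver and gated cross-attention).
- # rename_flamingo_checkpoint (from utils) maps the original trainer key names onto this
- # HF module layout; the embeddings are shrunk to the checkpoint's vocab to load, then grown back.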
- if args.flamingo_dir is not None:
-     state_dict_2 = torch.load(f"{args.flamingo_dir}/checkpoint.pt", map_location="cpu")
-     save_state_dict_2 = rename_flamingo_checkpoint(state_dict_2)
-     real_vocab_size = config.text_config.vocab_size
-     # Resize the token embeddings to match the checkpoint's vocab (50280) so the weights load cleanly
-     model.lang_encoder.resize_token_embeddings(save_state_dict_2["lang_encoder.transformer.wte.weight"].shape[0])
-
-     load_msg = model.load_state_dict(
-         save_state_dict_2,
-         False,
-     )
-     # print incompatible keys
-     print(load_msg[1])
-     # Resize the token embeddings back to the configured vocab size (50432)
-     model.lang_encoder.resize_token_embeddings(real_vocab_size)
-
- print(f"Saving model to {save_path}...")
- model.save_pretrained(save_path, max_shard_size="10GB")
mllm/flamingo/injecting_mpt_into_flamingo.py DELETED
@@ -1,109 +0,0 @@
- import argparse
- import os
-
- import torch
- from tqdm import tqdm
-
- import sys
-
- from configuration_flamingo import FlamingoConfig
- from modeling_flamingo import FlamingoForConditionalGeneration
- from utils import rename_flamingo_checkpoint
-
- parser = argparse.ArgumentParser(description="Convert MPT model")
- parser.add_argument("--model_choice", type=str, choices=["7B", "30B"], required=True, help="Choose either '7B' or '30B'")
- parser.add_argument("--mpt_root_dir", type=str, default="/home/luodian/projects/checkpoints")
- parser.add_argument("--save_root_dir", type=str, default="/home/luodian/projects/checkpoints")
- parser.add_argument("--flamingo_dir", type=str, default=None, help="If the pretrained flamingo weights also need to be injected")
- args = parser.parse_args()
-
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
- root_dir = args.mpt_root_dir
- model_choice = args.model_choice
- save_root_dir = args.save_root_dir
-
- # prepare the MPT model first
- # visit https://huggingface.co/mosaicml to download the 7B and 30B instruct checkpoints.
- if model_choice == "30B":
-     config_file = "./flamingo/flamingo-mpt-30B.json"
-     state_dict_files = [
-         f"{root_dir}/mpt-30b-instruct/pytorch_model-00001-of-00007.bin",
-         f"{root_dir}/mpt-30b-instruct/pytorch_model-00002-of-00007.bin",
-         f"{root_dir}/mpt-30b-instruct/pytorch_model-00003-of-00007.bin",
-         f"{root_dir}/mpt-30b-instruct/pytorch_model-00004-of-00007.bin",
-         f"{root_dir}/mpt-30b-instruct/pytorch_model-00005-of-00007.bin",
-         f"{root_dir}/mpt-30b-instruct/pytorch_model-00006-of-00007.bin",
-         f"{root_dir}/mpt-30b-instruct/pytorch_model-00007-of-00007.bin",
-     ]
-     save_path = f"{save_root_dir}/flamingo-mpt-30B-instruct-init"
- elif model_choice == "7B":
-     config_file = "./flamingo/flamingo-mpt-7B.json"
-     state_dict_files = [
-         f"{root_dir}/mpt-7b/pytorch_model-00001-of-00002.bin",
-         f"{root_dir}/mpt-7b/pytorch_model-00002-of-00002.bin",
-     ]
-     save_path = f"{save_root_dir}/flamingo-mpt-7B"
- else:
-     raise ValueError("Invalid model_choice. Choose either '30B' or '7B'.")
-
- config = FlamingoConfig.from_json_file(config_file)
-
- model = FlamingoForConditionalGeneration(config=config)
-
-
- # load flamingo's vision encoder from the last checkpoint shard.
- # visit https://huggingface.co/luodian/openflamingo-9b-hf/tree/main to download the checkpoint.
- AZP = os.environ["AZP"]
- state_dict_3 = torch.load(f"{AZP}/otter/checkpoints/flamingo_9b_hf/pytorch_model-00004-of-00004.bin", map_location="cpu")
- for cur_key in list(state_dict_3.keys()):
-     if "vision_encoder" not in cur_key:
-         del state_dict_3[cur_key]
-
- load_msg = model.load_state_dict(
-     state_dict_3,
-     False,
- )
- # print incompatible keys
- print(load_msg[1])
-
- # Loading MPT weights from the sharded checkpoint files
- state_dict = {}
- for file in tqdm(state_dict_files, desc="Loading state dict"):
-     state_dict_part = torch.load(file, map_location="cpu")
-     state_dict.update(state_dict_part)
-
- save_state_dict_1 = {}
- for key in state_dict:
-     if ".blocks." in key:
-         _, _, layer_num, *remain_names = key.split(".")
-         target_key = f"transformer.blocks.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
-     else:
-         target_key = key
-     save_state_dict_1[target_key] = state_dict[key]
-
- load_msg = model.lang_encoder.load_state_dict(
-     save_state_dict_1,
-     False,
- )
- # print incompatible keys
- print(load_msg[1])
- if args.flamingo_dir is not None:
-     state_dict_2 = torch.load(f"{args.flamingo_dir}/checkpoint.pt", map_location="cpu")
-     save_state_dict_2 = rename_flamingo_checkpoint(state_dict_2)
-
95
- real_vocab_size = config.text_config.vocab_size
96
- # Reshape the token embedding to 50280 for compatible
97
- model.lang_encoder.resize_token_embeddings(save_state_dict_2["lang_encoder.transformer.wte.weight"].shape[0])
98
-
99
- load_msg = model.load_state_dict(
100
- save_state_dict_2,
101
- False,
102
- )
103
- # print incompatible keys
104
- print(load_msg[1])
105
- # Reshape the token embedding to 50432
106
- model.lang_encoder.resize_token_embeddings(real_vocab_size)
107
-
108
- print(f"Saving model to {save_path}...")
109
- model.save_pretrained(save_path, max_shard_size="10GB")
mllm/flamingo/injecting_vicuna_into_flamingo.py DELETED
@@ -1,100 +0,0 @@
- import argparse
- import os
-
- import torch
- from tqdm import tqdm
-
- import sys
-
- from .configuration_flamingo import FlamingoConfig
- from .modeling_flamingo import FlamingoForConditionalGeneration
-
- parser = argparse.ArgumentParser(description="Convert Vicuna model")
- parser.add_argument("--model_choice", type=str, choices=["7B", "33B"], required=True, help="Choose either '7B' or '33B'")
- parser.add_argument("--vicuna_root_dir", type=str, default="/home/luodian/projects/checkpoints")
- parser.add_argument("--save_root_dir", type=str, default="/home/luodian/projects/checkpoints")
- parser.add_argument("--flamingo_dir", type=str, default=None, help="If the pretrained flamingo weights also need to be injected")
- args = parser.parse_args()
-
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
- root_dir = args.vicuna_root_dir
- model_choice = args.model_choice
- save_root_dir = args.save_root_dir
-
- # prepare the Vicuna model first
- # visit https://huggingface.co/lmsys to download the vicuna-7b-v1.3 and vicuna-33b-v1.3 checkpoints.
- if model_choice == "33B":
-     config_file = "./flamingo/flamingo-vicuna-33B-v1.3.json"
-     state_dict_files = [
-         f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00001-of-00007.bin",
-         f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00002-of-00007.bin",
-         f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00003-of-00007.bin",
-         f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00004-of-00007.bin",
-         f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00005-of-00007.bin",
-         f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00006-of-00007.bin",
-         f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00007-of-00007.bin",
-     ]
-     save_path = f"{save_root_dir}/flamingo-vicuna-33B-v1.3-init"
- elif model_choice == "7B":
-     config_file = "./flamingo/flamingo-vicuna-7B-v1.3.json"
-     state_dict_files = [
-         f"{root_dir}/vicuna-7b-v1.3/pytorch_model-00001-of-00002.bin",
-         f"{root_dir}/vicuna-7b-v1.3/pytorch_model-00002-of-00002.bin",
-     ]
-     save_path = f"{save_root_dir}/flamingo-vicuna-7B-v1.3-init"
- else:
-     raise ValueError("Invalid model_choice. Choose either '33B' or '7B'.")
-
- config = FlamingoConfig.from_json_file(config_file)
- model = FlamingoForConditionalGeneration(config=config)
-
- # load flamingo's vision encoder from the last checkpoint shard.
- # visit https://huggingface.co/luodian/openflamingo-9b-hf/tree/main to download the checkpoint.
- AZP = os.environ["AZP"]
- state_dict_3 = torch.load(f"{AZP}/otter/checkpoints/flamingo_9b_hf/pytorch_model-00004-of-00004.bin", map_location="cpu")
- for cur_key in list(state_dict_3.keys()):
-     if "vision_encoder" not in cur_key:
-         del state_dict_3[cur_key]
-
- load_msg = model.load_state_dict(
-     state_dict_3,
-     False,
- )
- # print incompatible keys
- print(load_msg[1])
-
- # Loading Vicuna weights from the sharded checkpoint files
- state_dict = {}
- for file in tqdm(state_dict_files, desc="Loading state dict"):
-     state_dict_part = torch.load(file, map_location="cpu")
-     state_dict.update(state_dict_part)
-
- save_state_dict_1 = {}
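- # LLaMA-style keys move under the FlamingoLayer wrapper, e.g. (illustrative):
- #   "model.layers.0.self_attn.q_proj.weight" -> "model.layers.0.decoder_layer.self_attn.q_proj.weight"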
- for key in state_dict:
-     if ".layers." in key:
-         _, _, layer_num, *remain_names = key.split(".")
-         target_key = f"model.layers.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
-     else:
-         target_key = key
-     save_state_dict_1[target_key] = state_dict[key]
-
- # Resize the token embeddings to 32000 so the pretrained Vicuna weights load cleanly
- model.lang_encoder.resize_token_embeddings(32000)
-
- load_msg = model.lang_encoder.load_state_dict(
-     save_state_dict_1,
-     False,
- )
- # Resize the token embeddings back to 32002 to account for the added special tokens
- model.lang_encoder.resize_token_embeddings(32002)
- # print incompatible keys
- print(load_msg[1])
-
-
- print(f"Saving model to {save_path}...")
- model.save_pretrained(save_path, max_shard_size="10GB")
mllm/flamingo/modeling_flamingo.py DELETED
@@ -1,966 +0,0 @@
- import random
- from dataclasses import dataclass
- from typing import Callable, Optional
-
- import torch
- import torch.nn as nn
- from accelerate.hooks import AlignDevicesHook, add_hook_to_module
- from einops import rearrange, repeat
- from transformers import CLIPVisionModel, LlamaForCausalLM, LlamaTokenizer
- from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.modeling_utils import PreTrainedModel
- from transformers.models.auto import AutoModel, AutoModelForCausalLM, AutoTokenizer
-
- from .configuration_flamingo import FlamingoConfig
- from .falcon.modelling_RW import RWForCausalLM
- from .mpt.modeling_mpt import MPTForCausalLM
- from .mpt_redpajama.mosaic_gpt import MosaicGPT
-
-
- __KNOWN_DECODER_LAYERS_ATTR_NAMES = {
-     "opt": "model.decoder.layers",
-     "gptneo": "transformer.h",
-     "gptj": "transformer.h",
-     "gpt-j": "transformer.h",
-     "pythia": "gpt_neox.layers",
-     "llama": "model.layers",
-     "RWForCausalLM": "transformer.h",
-     "MPTForCausalLM": "transformer.blocks",
-     "MosaicGPT": "transformer.blocks",
- }
-
-
- def _infer_decoder_layers_attr_name(model: nn.Module):
-     for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
-         if k.lower() in model.__class__.__name__.lower():
-             return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
-
-     raise ValueError(
-         "We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
-     )
-
-
- def extend_instance(obj, mixin):
-     """Apply mixins to a class instance after creation."""
-     base_cls = obj.__class__
-     base_cls_name = obj.__class__.__name__
-     obj.__class__ = type(base_cls_name, (mixin, base_cls), {})  # mixin needs to go first for our forward() logic to work
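-     # e.g. extend_instance(lang_encoder, FlamingoLMMixin) rebinds the instance's class to
-     # type("LlamaForCausalLM", (FlamingoLMMixin, LlamaForCausalLM), {}), so the mixin's
-     # forward() runs first via the MRO (illustrative)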
-
-
- def getattr_recursive(obj, att):
-     """
-     Return nested attribute of obj
-     Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
-     """
-     if att == "":
-         return obj
-     i = att.find(".")
-     if i < 0:
-         return getattr(obj, att)
-     else:
-         return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
-
-
- def setattr_recursive(obj, att, val):
-     """
-     Set nested attribute of obj
-     Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
-     """
-     if "." in att:
-         obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
-     setattr(obj, att.split(".")[-1], val)
-
-
- def exists(val):
-     return val is not None
-
-
- class FlamingoPerceiverBlock(nn.Module):
-     def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8, mult: int = 4):
-         super().__init__()
-         self.scale = dim_head**-0.5
-         self.heads = heads
-         inner_dim = dim_head * heads
-         ff_dim = dim * mult
-         self.norm_media = nn.LayerNorm(dim)
-         self.norm_latents = nn.LayerNorm(dim)
-
-         self.to_q = nn.Linear(dim, inner_dim, bias=False)
-         self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
-         self.to_out = nn.Linear(inner_dim, dim, bias=False)
-         self.feed_forward = nn.ModuleList(
-             [
-                 nn.LayerNorm(dim),
-                 nn.Linear(dim, ff_dim, bias=False),
-                 nn.GELU(),
-                 nn.Linear(ff_dim, dim, bias=False),
-             ]
-         )
-
-     def forward(self, x: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
-         """
-         Args:
-             x (torch.Tensor): image features
-                 shape (b, T, n1, D)
-             latents (torch.Tensor): latent features
-                 shape (b, T, n2, D)
-         """
-         x = self.norm_media(x)
-         residual_latents = latents
-         latents = self.norm_latents(latents)
-
-         h = self.heads
-
-         q = self.to_q(latents)
-         kv_input = torch.cat((x, latents), dim=-2)
-         k, v = self.to_kv(kv_input).chunk(2, dim=-1)
-         q = rearrange(q, "b t n (h d) -> b h t n d", h=h)
-         k = rearrange(k, "b t n (h d) -> b h t n d", h=h)
-         v = rearrange(v, "b t n (h d) -> b h t n d", h=h)
-         q = q * self.scale
-
-         # attention
-         sim = torch.einsum("... i d, ... j d -> ... i j", q, k)
-         sim = sim - sim.amax(dim=-1, keepdim=True).detach()
-         attn = sim.softmax(dim=-1)
-
-         out = torch.einsum("... i j, ... j d -> ... i d", attn, v)
-         out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
-         out = self.to_out(out) + residual_latents
-         residual_out = out
-         for layer in self.feed_forward:
-             out = layer(out)
-         return out + residual_out
-
-
- class FlamingoPerceiverResampler(nn.Module):
-     def __init__(
-         self,
-         *,
-         dim: int,
-         depth: int = 6,
-         dim_head: int = 64,
-         heads: int = 8,
-         num_latents: int = 64,
-         max_num_media: Optional[int] = None,
-         max_num_frames: Optional[int] = None,
-         ff_mult: int = 4,
-     ):
-         super().__init__()
-         self.latents = nn.Parameter(torch.randn(num_latents, dim))
-         self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None
-
-         self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None
-
-         self.layers = nn.ModuleList([])
-         for _ in range(depth):
-             self.layers.append(FlamingoPerceiverBlock(dim=dim, dim_head=dim_head, heads=heads, mult=ff_mult))
-
-         self.norm = nn.LayerNorm(dim)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """
-         Args:
-             x (torch.Tensor): image features
-                 shape (b, T, F, v, D)
-         Returns:
-             shape (b, T, n, D) where n is self.num_latents
-         """
-         b, T, F, v = x.shape[:4]
-
-         # frame and media time embeddings
-         if exists(self.frame_embs):
-             frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
-             x = x + frame_embs
-         x = rearrange(x, "b T F v d -> b T (F v) d")  # flatten the frame and spatial dimensions
-         if exists(self.media_time_embs):
-             x = x + self.media_time_embs[:T]
-
-         # blocks
-         latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
-         for block in self.layers:
-             latents = block(x, latents)
-         return self.norm(latents)
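- # Shape check (illustrative): 2 samples, 3 media, 1 frame, 256 patch tokens, dim 1024:
- #   resampler = FlamingoPerceiverResampler(dim=1024)
- #   resampler(torch.randn(2, 3, 1, 256, 1024)).shape  # -> (2, 3, 64, 1024): 64 learned latents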
-
-
- class FlamingoMaskedCrossAttention(nn.Module):
-     def __init__(
-         self,
-         *,
-         dim: int,
-         dim_visual: int,
-         dim_head: int = 64,
-         heads: int = 8,
-         only_attend_immediate_media: bool = True,
-     ):
-         super().__init__()
-         self.scale = dim_head**-0.5
-         self.heads = heads
-         inner_dim = dim_head * heads
-
-         self.norm = nn.LayerNorm(dim)
-
-         self.to_q = nn.Linear(dim, inner_dim, bias=False)
-         self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
-         self.to_out = nn.Linear(inner_dim, dim, bias=False)
-
-         # whether text attends only to the immediately preceding image, or to all previous images
-         self.only_attend_immediate_media = only_attend_immediate_media
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         media: torch.Tensor,
-         media_locations: Optional[torch.BoolTensor] = None,
-         attend_previous: bool = True,
-     ) -> torch.Tensor:
-         """
-         Args:
-             x (torch.Tensor): text features
-                 shape (B, T_txt, D_txt)
-             media (torch.Tensor): image features
-                 shape (B, T_img, n, D_img) where n is the number of latents
-             media_locations: boolean mask identifying the media tokens in x
-                 shape (B, T_txt)
-             attend_previous: bool
-                 If False, ignore the immediately preceding image and only attend from the following image onward
-         """
-         _, T_img, n = media.shape[:3]
-         h = self.heads
-
-         x = self.norm(x)
-
-         q = self.to_q(x)
-         media = rearrange(media, "b t n d -> b (t n) d")
-
-         k, v = self.to_kv(media).chunk(2, dim=-1)
-         q = rearrange(q, "b n (h d) -> b h n d", h=h)
-         k = rearrange(k, "b n (h d) -> b h n d", h=h)
-         v = rearrange(v, "b n (h d) -> b h n d", h=h)
-
-         q = q * self.scale
-
-         sim = torch.einsum("... i d, ... j d -> ... i j", q, k)
-
-         if exists(media_locations):
-             # at each True, increment the time counter (relative to media time)
-             text_time = media_locations.cumsum(dim=-1)
-             media_time = torch.arange(T_img, device=x.device) + 1
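-             # e.g. media_locations = [True, False, False, True, False] gives
-             # text_time = [1, 1, 1, 2, 2]: the first three tokens belong to image 1,
-             # the last two to image 2 (illustrative)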
-
-             if not attend_previous:
-                 text_time[~media_locations] += 1
-                 # make sure the max is still the number of images in the sequence
-                 text_time[
-                     text_time
-                     > repeat(
-                         torch.count_nonzero(media_locations, dim=1),
-                         "b -> b i",
-                         i=text_time.shape[1],
-                     )
-                 ] = 0
-
-             # text time must equal media time if only attending to the most immediate image;
-             # otherwise, text time must be greater than or equal to media time (attending to all previous images)
-             mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
-
-             text_to_media_mask = mask_op(
-                 rearrange(text_time, "b i -> b 1 i 1"),
-                 repeat(media_time, "j -> 1 1 1 (j n)", n=n),
-             )
-             sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
-
-         sim = sim - sim.amax(dim=-1, keepdim=True).detach()
-         attn = sim.softmax(dim=-1)
-
-         if exists(media_locations) and self.only_attend_immediate_media:
-             # any text without a preceding media needs to have its attention zeroed out
-             text_without_media_mask = text_time == 0
-             text_without_media_mask = rearrange(text_without_media_mask, "b i -> b 1 i 1")
-             attn = attn.masked_fill(text_without_media_mask, 0.0)
-
-         out = torch.einsum("... i j, ... j d -> ... i d", attn, v)
-         out = rearrange(out, "b h n d -> b n (h d)")
-         return self.to_out(out)
-
-
- class FlamingoGatedCrossAttentionBlock(nn.Module):
-     def __init__(
-         self,
-         *,
-         dim: int,
-         dim_visual: int,
-         dim_head: int = 64,
-         heads: int = 8,
-         ff_mult: int = 4,
-         only_attend_immediate_media: bool = True,
-     ):
-         super().__init__()
-         self.attn = FlamingoMaskedCrossAttention(
-             dim=dim,
-             dim_visual=dim_visual,
-             dim_head=dim_head,
-             heads=heads,
-             only_attend_immediate_media=only_attend_immediate_media,
-         )
-         self.attn_gate = nn.Parameter(torch.tensor([0.0]))
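-         # the gates start at 0, so tanh(gate) = 0 and the block is an identity map at
-         # initialization, preserving the pretrained LM before cross-attention is trained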
-         self.feed_forward = nn.ModuleList(
-             [
-                 nn.LayerNorm(dim),
-                 nn.Linear(dim, dim * ff_mult, bias=False),
-                 nn.GELU(),
-                 nn.Linear(dim * ff_mult, dim, bias=False),
-             ]
-         )
-         self.ff_gate = nn.Parameter(torch.tensor([0.0]))
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         media: torch.Tensor,
-         media_locations: Optional[torch.BoolTensor] = None,
-         attend_previous: bool = True,
-     ) -> torch.Tensor:
-         x = (
-             self.attn(
-                 x,
-                 media,
-                 media_locations=media_locations,
-                 attend_previous=attend_previous,
-             )
-             * self.attn_gate.tanh()
-             + x
-         )
-         residual_x = x
-         for ff in self.feed_forward:
-             x = ff(x)
-         x = x * self.ff_gate.tanh() + residual_x
-
-         return x
-
-
- class FlamingoLayer(nn.Module):
-     def __init__(self, gated_cross_attn_layer: nn.Module, decoder_layer: nn.Module):
-         super().__init__()
-         self.gated_cross_attn_layer = gated_cross_attn_layer
-         self.decoder_layer = decoder_layer
-         self.vis_x = None
-         self.media_locations = None
-
-     def is_conditioned(self) -> bool:
-         """Check whether the layer is conditioned."""
-         return self.vis_x is not None
-
-     # Borrowed from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/)
-     def condition_vis_x(self, vis_x) -> None:
-         self.vis_x = vis_x
-
-     def condition_media_locations(self, media_locations) -> None:
-         self.media_locations = media_locations
-
-     def condition_attend_previous(self, attend_previous) -> None:
-         self.attend_previous = attend_previous
-
-     def forward(
-         self,
-         lang_x: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         **decoder_layer_kwargs,
-     ):
-         if self.gated_cross_attn_layer is None:
-             return self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs)
-
-         if self.vis_x is None:
-             raise ValueError("vis_x must be conditioned before forward pass")
-
-         if self.media_locations is None:
-             raise ValueError("media_locations must be conditioned before forward pass")
-
-         lang_x = self.gated_cross_attn_layer(
-             lang_x,
-             self.vis_x,
-             media_locations=self.media_locations,
-             attend_previous=self.attend_previous,
-         )
-         lang_x = self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs)
-         return lang_x
-
-
- class FlamingoLMMixin(nn.Module):
-     """
-     Mixin to add cross-attention layers to a language model.
-     """
-
-     def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
-         self.decoder_layers_attr_name = decoder_layers_attr_name
-
-     def _get_decoder_layers(self):
-         return getattr_recursive(self, self.decoder_layers_attr_name)
-
-     def _set_decoder_layers(self, value):
-         setattr_recursive(self, self.decoder_layers_attr_name, value)
-
-     def init_flamingo(
-         self,
-         media_token_id: int,
-         vis_hidden_size: int,
-         cross_attn_every_n_layers: int,
-         use_media_placement_augmentation: bool,
-     ):
-         """
-         Initialize Flamingo by adding gated cross-attention blocks to the decoder. Store the media token id for computing the media locations.
-         """
-
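-         # with cross_attn_every_n_layers=4, decoder layers 3, 7, 11, ... (0-indexed)
-         # receive a gated cross-attention block; all other layers get None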
-         gated_cross_attn_layers = nn.ModuleList(
-             [
-                 FlamingoGatedCrossAttentionBlock(
-                     dim=self.config.hidden_size,
-                     dim_visual=vis_hidden_size,
-                 )
-                 if (layer_idx + 1) % cross_attn_every_n_layers == 0
-                 else None
-                 for layer_idx, _ in enumerate(self._get_decoder_layers())
-             ]
-         )
-         self._set_decoder_layers(
-             nn.ModuleList(
-                 [
-                     FlamingoLayer(gated_cross_attn_layer, decoder_layer)
-                     for gated_cross_attn_layer, decoder_layer in zip(gated_cross_attn_layers, self._get_decoder_layers())
-                 ]
-             )
-         )
-         self.media_token_id = media_token_id
-         self.use_media_placement_augmentation = use_media_placement_augmentation
-         self.initialized_flamingo = True
-
-     def forward(self, *input, **kwargs):
-         """Condition the Flamingo layers on the media locations before forward()"""
-         if not self.initialized_flamingo:
-             raise ValueError("Flamingo layers are not initialized. Please call `init_flamingo` first.")
-
-         input_ids = kwargs["input_ids"] if "input_ids" in kwargs else input[0]
-         media_locations = input_ids == self.media_token_id
-         # IMPORTANT: force `attend_previous` to True when the training data is arranged as <image>caption<|endofchunk|>
-         attend_previous = (random.random() < 0.5) if self.use_media_placement_augmentation else True
-
-         if self.__class__.__name__ == "LlamaForCausalLM":
-             for layer in self.get_decoder().layers:
-                 layer.condition_media_locations(media_locations)
-                 layer.condition_attend_previous(attend_previous)
-         elif self.__class__.__name__ in ["MPTForCausalLM", "MosaicGPT"]:
-             for layer in self.get_decoder().blocks:
-                 layer.condition_media_locations(media_locations)
-                 layer.condition_attend_previous(attend_previous)
-         else:
-             print("unavailable text encoder")
-         return super().forward(*input, **kwargs)  # call the other parent's forward method
-
-     def is_conditioned(self) -> bool:
-         """Check whether all decoder layers are already conditioned."""
-         return all(l.is_conditioned() for l in self._get_decoder_layers())
-
-     def clear_conditioned_layers(self) -> None:
-         for layer in self._get_decoder_layers():
-             layer.condition_vis_x(None)
-             layer.condition_media_locations(None)
-             layer.condition_attend_previous(None)
-
-
- class FlamingoPreTrainedModel(PreTrainedModel):
-     """
-     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
-     models.
-     """
-
-     config_class = FlamingoConfig
-     base_model_prefix = "flamingo"
-     supports_gradient_checkpointing = True
-     _no_split_modules = ["FlamingoPerceiverBlock", "CLIPEncoderLayer", "FlamingoLayer"]
-
-     def _init_weights(self, module):
-         """Flamingo requires no specific initialization"""
-         return super()._init_weights(module)
-
-     def _set_gradient_checkpointing(self, module, value=False):
-         if isinstance(module, FlamingoModel):
-             module.gradient_checkpointing = value
-
-
- class FlamingoModel(FlamingoPreTrainedModel):
-     config_class = FlamingoConfig
-
-     def __init__(
-         self,
-         config: FlamingoConfig,
-     ):
-         super().__init__(config)
-         # TODO: set "LlamaForCausalLM" in text_config.architectures for Llama-based flamingo configs
-         if "llama" not in config.text_config._name_or_path.lower():
-             if config.text_config.architectures[0] == "MPTForCausalLM":
-                 text_tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b-instruct")
-                 lang_encoder = MPTForCausalLM(config=config.text_config)
-             elif config.text_config.architectures[0] == "MosaicGPT":
-                 text_tokenizer = AutoTokenizer.from_pretrained("mosaicml/mosaic-llama-redpajama-final-candidate")
-                 lang_encoder = MosaicGPT(config=config.text_config)
-             elif config.text_config.architectures[0] == "RWForCausalLM":
-                 text_tokenizer = AutoTokenizer.from_pretrained("PATH-TO-YOUR-FALCON")
-                 lang_encoder = RWForCausalLM(config=config.text_config)
-         else:
-             text_tokenizer = LlamaTokenizer.from_pretrained(config.text_config._name_or_path)
-             lang_encoder = LlamaForCausalLM(config=config.text_config)
-
-         vision_encoder = CLIPVisionModel(config=config.vision_config)
-         text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
-         if text_tokenizer.pad_token is None:
-             text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
-         self.text_tokenizer = text_tokenizer
-         self.eoc_token_id = text_tokenizer.encode("<|endofchunk|>")[-1]
-         self.media_token_id = text_tokenizer.encode("<image>")[-1]
-
-         extend_instance(lang_encoder, FlamingoLMMixin)
-         decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
-         lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
-         if lang_encoder.__class__.__name__ == "LlamaForCausalLM":
-             lang_encoder.resize_token_embeddings(len(text_tokenizer))
-         self.lang_encoder = lang_encoder
-
-         self.cross_attn_every_n_layers = config.cross_attn_every_n_layers if hasattr(config, "cross_attn_every_n_layers") else 4
-         self.use_media_placement_augmentation = config.use_media_placement_augmentation
-
-         vision_encoder.output_tokens = True
-         self.vision_encoder = vision_encoder
-
-         self.vis_dim = 1024
-         self.perceiver = FlamingoPerceiverResampler(dim=self.vis_dim)
-
-         self.lang_encoder.init_flamingo(
-             media_token_id=self.media_token_id,
-             vis_hidden_size=self.vis_dim,
-             cross_attn_every_n_layers=self.cross_attn_every_n_layers,
-             use_media_placement_augmentation=self.use_media_placement_augmentation,
-         )
-         self.post_init()
-
-     def get_input_embeddings(self) -> nn.Module:
-         return self.lang_encoder.get_input_embeddings()
-
-     def set_input_embeddings(self, new_embeddings):
-         self.lang_encoder.set_input_embeddings(new_embeddings)
-
-     def get_output_embeddings(self) -> nn.Module:
-         return self.lang_encoder.get_output_embeddings()
-
-     def set_output_embeddings(self, new_embeddings):
-         self.lang_encoder.set_output_embeddings(new_embeddings)
-
-     def get_image_encoder(self) -> nn.Module:
-         return self.vision_encoder
-
-     def get_lang_encoder(self) -> nn.Module:
-         return self.lang_encoder
-
-     def init_weights(self):
-         # Freeze all parameters in the vision encoder
-         for param in self.vision_encoder.parameters():
-             param.requires_grad = False
-
-         if "lora_config" in self.config.__dict__:
-             print(f"LoRA trainable param: {(sum(p.numel() for p in self.lang_encoder.parameters() if p.requires_grad)) / 1e9:.3f} B")
-             # Unfreeze the gated cross-attention layers
-             for layer in self.lang_encoder._get_decoder_layers():
-                 if layer.gated_cross_attn_layer is not None:
-                     for param in layer.gated_cross_attn_layer.parameters():
-                         param.requires_grad = True
-         else:
-             # Freeze all language-encoder parameters except the gated cross-attention layers
-             for name, param in self.lang_encoder.named_parameters():
-                 if "gated_cross_attn_layer" not in name:
-                     param.requires_grad = False
-             # Unfreeze the LM input and output embeddings
-             self.lang_encoder.get_input_embeddings().requires_grad_(True)
-             # MPTForCausalLM ties its word embeddings, so only Llama needs the lm_head unfrozen
-             if self.lang_encoder.__class__.__name__ == "LlamaForCausalLM":
-                 self.lang_encoder.lm_head.requires_grad_(True)
-
-         print(f"Total trainable param: {(sum(p.numel() for p in self.parameters() if p.requires_grad)) / 1e9:.3f} B")
-
-     def forward(
-         self,
-         vision_x: torch.Tensor,
-         lang_x: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         labels: Optional[torch.Tensor] = None,
-         use_cached_vision_x: bool = False,
-         clear_conditioned_layers: bool = True,
-         past_key_values: Optional[torch.Tensor] = None,
-         use_cache: bool = False,
-         **kwargs,
-     ) -> CausalLMOutputWithPast:
-         """
-         Forward pass of Flamingo.
-
-         Args:
-             vision_x (torch.Tensor): Vision input
-                 shape (B, T_img, F, C, H, W) with F=1
-             lang_x (torch.Tensor): Language input ids
-                 shape (B, T_txt)
-             attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
-             labels (torch.Tensor, optional): Labels. Defaults to None.
-             clear_conditioned_layers: if True, clear the conditioned layers
-                 once the forward pass is completed. Set this to False if the
-                 same set of images will be reused in another subsequent
-                 forward pass.
-             past_key_values: pre-computed values to pass to the language model.
-                 See the past_key_values documentation in Hugging Face
-                 CausalLM models.
-             use_cache: whether to use cached key values. See the use_cache
-                 documentation in Hugging Face CausalLM models.
-         """
-         assert (vision_x is not None) or use_cached_vision_x, "Must provide either vision_x or set use_cached_vision_x to True."
-
-         if use_cached_vision_x:
-             # Case: use cached; vision_x should be cached and other
-             # vision-related inputs should not be provided.
-             assert vision_x is None, "Expect vision_x to be None when use_cached_vision_x is True."
-             assert self.lang_encoder.is_conditioned()
-         else:
-             # Case: do not use caching (i.e. this is a standard forward pass)
-             self._encode_vision_x(vision_x=vision_x)
-
-         output = self.lang_encoder(
-             input_ids=lang_x,
-             attention_mask=attention_mask,
-             labels=labels,
-             past_key_values=past_key_values,
-             use_cache=use_cache,
-             **kwargs,
-         )
-
-         if clear_conditioned_layers:
-             self.lang_encoder.clear_conditioned_layers()
-
-         return output
-
-     def _encode_vision_x(self, vision_x: torch.Tensor):
-         """
-         Compute media tokens from vision input by passing it through the vision encoder and conditioning the language model.
-         Args:
-             vision_x (torch.Tensor): Vision input
-                 shape (B, T_img, F, C, H, W)
-                 Images in the same chunk are collated along T_img, and frames are collated along F.
-                 Currently only F=1 is supported (single-frame videos).
-
-         rearrange code based on https://github.com/dhansmair/flamingo-mini
-         """
-
-         assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
-         b, T, F = vision_x.shape[:3]
-         assert F == 1, "Only single frame supported"
-
-         vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
-         with torch.no_grad():
-             vision_x = self.vision_encoder(vision_x)[0][:, 1:, :]
-         vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
-
-         vision_x = self.perceiver(vision_x)  # reshapes to (b, T, n, d)
-
-         for layer in self.lang_encoder._get_decoder_layers():
-             layer.condition_vis_x(vision_x)
-
-
- class FlamingoForConditionalGeneration(FlamingoPreTrainedModel):
-     config_class = FlamingoConfig
-
-     def __init__(
-         self,
-         config: FlamingoConfig,
-     ):
-         super().__init__(config)
-         # TODO: the encoders are hardcoded below because Auto* instantiation is too slow
-         # vision_encoder = AutoModel.from_config(config.vision_config).vision_model
-         # lang_encoder = AutoModelForCausalLM.from_config(config.text_config)
-         # text_tokenizer = AutoTokenizer.from_pretrained(config.text_config._name_or_path)
-
-         # TODO: set "LlamaForCausalLM" in text_config.architectures for Llama-based flamingo configs
-         if config.text_config.architectures[0] == "MPTForCausalLM":
-             text_tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b-instruct")
-             lang_encoder = MPTForCausalLM(config=config.text_config)
-         elif config.text_config.architectures[0] == "MosaicGPT":
-             text_tokenizer = AutoTokenizer.from_pretrained("mosaicml/mosaic-llama-redpajama-final-candidate")
-             lang_encoder = MosaicGPT(config=config.text_config)
-         elif config.text_config.architectures[0] == "RWForCausalLM":
-             text_tokenizer = AutoTokenizer.from_pretrained("PATH-TO-YOUR-FALCON")
-             lang_encoder = RWForCausalLM(config=config.text_config)
-         elif config.text_config.architectures[0] == "LlamaForCausalLM":
-             text_tokenizer = LlamaTokenizer.from_pretrained(config.text_config._name_or_path)
-             lang_encoder = LlamaForCausalLM(config=config.text_config)
-         else:
-             raise ValueError(f"Unsupported text_config architecture: {config.text_config.architectures[0]}")
-
-         vision_encoder = CLIPVisionModel(config=config.vision_config)
-         text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
-         if text_tokenizer.pad_token is None:
-             text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
-         self.text_tokenizer = text_tokenizer
-         self.eoc_token_id = text_tokenizer.encode("<|endofchunk|>")[-1]
-         self.media_token_id = text_tokenizer.encode("<image>")[-1]
-
-         extend_instance(lang_encoder, FlamingoLMMixin)
-         decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
-         lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
-         if "LlamaForCausalLM" in lang_encoder.__class__.__name__:
-             lang_encoder.resize_token_embeddings(len(text_tokenizer))
-         self.lang_encoder = lang_encoder
-
-         self.cross_attn_every_n_layers = config.cross_attn_every_n_layers if hasattr(config, "cross_attn_every_n_layers") else 4
-         self.use_media_placement_augmentation = config.use_media_placement_augmentation
-
-         vision_encoder.output_tokens = True
-         self.vision_encoder = vision_encoder
-
-         self.vis_dim = 1024
-         self.perceiver = FlamingoPerceiverResampler(dim=self.vis_dim)
-
-         self.lang_encoder.init_flamingo(
-             media_token_id=self.media_token_id,
-             vis_hidden_size=self.vis_dim,
-             cross_attn_every_n_layers=self.cross_attn_every_n_layers,
-             use_media_placement_augmentation=self.use_media_placement_augmentation,
-         )
-         self.post_init()
-
-     def get_input_embeddings(self) -> nn.Module:
-         return self.lang_encoder.get_input_embeddings()
-
-     def set_input_embeddings(self, new_embeddings):
-         self.lang_encoder.set_input_embeddings(new_embeddings)
-
-     def get_output_embeddings(self) -> nn.Module:
-         return self.lang_encoder.get_output_embeddings()
-
-     def set_output_embeddings(self, new_embeddings):
-         self.lang_encoder.set_output_embeddings(new_embeddings)
-
-     def get_image_encoder(self) -> nn.Module:
-         return self.vision_encoder
-
-     def get_lang_encoder(self) -> nn.Module:
-         return self.lang_encoder
-
-     def init_weights(self):
-         # Freeze all parameters in the vision encoder
-         for param in self.vision_encoder.parameters():
-             param.requires_grad = False
-         # Freeze all language-encoder parameters except the gated cross-attention layers
-         for name, param in self.lang_encoder.named_parameters():
-             if "gated_cross_attn_layer" not in name:
-                 param.requires_grad = False
-         # Unfreeze the LM input embeddings
-         self.lang_encoder.get_input_embeddings().requires_grad_(True)
-         # MPTForCausalLM ties its word embeddings, so only Llama needs the lm_head unfrozen
-         if "LlamaForCausalLM" in self.lang_encoder.__class__.__name__:
-             self.lang_encoder.lm_head.requires_grad_(True)
-
-         print("====================Model Grad Part====================")
-         total_params = 0
-         for name, param in self.named_parameters():
-             if param.requires_grad:
-                 total_params += param.numel()
-                 print(f"Parameter: {name}, Size: {param.numel() / 1e6:.6f} M")
-         print(f"Total trainable param: {total_params / 1e9:.4f} B")
-
-     def forward(
-         self,
-         vision_x: torch.Tensor,
-         lang_x: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         labels: Optional[torch.Tensor] = None,
-         use_cached_vision_x: bool = False,
-         clear_conditioned_layers: bool = True,
-         past_key_values: Optional[torch.Tensor] = None,
-         use_cache: bool = False,
-         **kwargs,
-     ) -> CausalLMOutputWithPast:
-         """
-         Forward pass of Flamingo.
-
-         Args:
-             vision_x (torch.Tensor): Vision input
-                 shape (B, T_img, F, C, H, W) with F=1
-             lang_x (torch.Tensor): Language input ids
-                 shape (B, T_txt)
-             attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
-             labels (torch.Tensor, optional): Labels. Defaults to None.
-             clear_conditioned_layers: if True, clear the conditioned layers
-                 once the forward pass is completed. Set this to False if the
-                 same set of images will be reused in another subsequent
-                 forward pass.
-             past_key_values: pre-computed values to pass to the language model.
-                 See the past_key_values documentation in Hugging Face
-                 CausalLM models.
-             use_cache: whether to use cached key values. See the use_cache
-                 documentation in Hugging Face CausalLM models.
-         """
-         assert (vision_x is not None) or use_cached_vision_x, "Must provide either vision_x or set use_cached_vision_x to True."
-
-         if use_cached_vision_x:
-             # Case: use cached; vision_x should be cached and other
-             # vision-related inputs should not be provided.
-             assert vision_x is None, "Expect vision_x to be None when use_cached_vision_x is True."
-             assert self.lang_encoder.is_conditioned()
-         else:
-             # Case: do not use caching (i.e. this is a standard forward pass)
-             self._encode_vision_x(vision_x=vision_x)
-
-         output = self.lang_encoder(
-             input_ids=lang_x,
-             attention_mask=attention_mask,
-             labels=labels,
-             past_key_values=past_key_values,
-             use_cache=use_cache,
-             **kwargs,
-         )
-
-         if clear_conditioned_layers:
-             self.lang_encoder.clear_conditioned_layers()
-
-         return output
-
-     def _encode_vision_x(self, vision_x: torch.Tensor):
-         """
-         Compute media tokens from vision input by passing it through the vision encoder and conditioning the language model.
-         Args:
-             vision_x (torch.Tensor): Vision input
-                 shape (B, T_img, F, C, H, W)
-                 Images in the same chunk are collated along T_img, and frames are collated along F.
-                 Currently only F=1 is supported (single-frame videos).
-
-         rearrange code based on https://github.com/dhansmair/flamingo-mini
-         """
-
-         assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
-         b, T, F = vision_x.shape[:3]
-         # assert F == 1, "Only single frame supported"
-
-         vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
-         with torch.no_grad():
-             vision_x = self.vision_encoder(vision_x)[0][:, 1:, :]
-         vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
-
-         vision_x = self.perceiver(vision_x)  # reshapes to (b, T, n, d)
-
-         for layer in self.lang_encoder._get_decoder_layers():
-             layer.condition_vis_x(vision_x)
-
-     @torch.no_grad()
-     def generate(
-         self,
-         vision_x: torch.Tensor,
-         lang_x: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         num_beams: int = 1,
-         max_new_tokens: Optional[int] = None,
-         temperature: float = 1.0,
-         top_k: int = 0,
-         top_p: float = 1.0,
-         no_repeat_ngram_size: int = 0,
-         prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
-         length_penalty: float = 1.0,
-         num_return_sequences: int = 1,
-         do_sample: bool = False,
-         early_stopping: bool = False,
-         **kwargs,
-     ):
-         """
-         Generate text conditioned on vision and language inputs.
-
-         Args:
-             vision_x (torch.Tensor): Vision input
-                 shape (B, T_img, F, C, H, W)
-                 images in the same chunk are collated along T_img, and frames are collated along F
-                 currently only F=1 is supported (single-frame videos)
-             lang_x (torch.Tensor): Language input
-                 shape (B, T_txt)
-             attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
-             num_beams (int, optional): Number of beams. Defaults to 1.
-             max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
-             temperature (float, optional): Temperature. Defaults to 1.0.
-             top_k (int, optional): Top k. Defaults to 0.
-             top_p (float, optional): Top p. Defaults to 1.0.
-             no_repeat_ngram_size (int, optional): No-repeat ngram size. Defaults to 0.
-             length_penalty (float, optional): Length penalty. Defaults to 1.0.
-             num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
-             do_sample (bool, optional): Whether to sample. Defaults to False.
-             early_stopping (bool, optional): Early stopping. Defaults to False.
-         Returns:
-             torch.Tensor: lang_x with generated tokens appended
-         """
-         if hasattr(self, "_hf_hook"):
-             # add a hook to make sure that the output of lang_encoder is mapped to the same device as lang_x
-             hook = AlignDevicesHook(
-                 execution_device=lang_x.device,
-                 io_same_device=True,
-                 place_submodules=False,
-             )
-             add_hook_to_module(self.lang_encoder, hook)
-         if num_beams > 1:
-             vision_x = vision_x.repeat_interleave(num_beams, dim=0)
-         self._encode_vision_x(vision_x=vision_x)
-         output = self.lang_encoder.generate(
-             lang_x,
-             attention_mask=attention_mask,
-             eos_token_id=self.eoc_token_id,
-             num_beams=num_beams,
-             max_new_tokens=max_new_tokens,
-             temperature=temperature,
-             top_k=top_k,
-             top_p=top_p,
-             prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
-             no_repeat_ngram_size=no_repeat_ngram_size,
-             length_penalty=length_penalty,
-             num_return_sequences=num_return_sequences,
-             do_sample=do_sample,
-             early_stopping=early_stopping,
-             **kwargs,
-         )
-
-         self.lang_encoder.clear_conditioned_layers()
-         return output
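-
-
- # Hedged usage sketch (paths and shapes are illustrative, not part of this file):
- #   config = FlamingoConfig.from_json_file("flamingo/flamingo-mpt-7B.json")
- #   model = FlamingoForConditionalGeneration(config=config)
- #   vision_x = torch.randn(1, 1, 1, 3, 224, 224)  # (B, T_img, F, C, H, W)
- #   lang_x = model.text_tokenizer("<image>An image of", return_tensors="pt").input_ids
- #   out = model.generate(vision_x=vision_x, lang_x=lang_x, max_new_tokens=20)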
mllm/flamingo/mpt/__init__.py DELETED
File without changes
mllm/flamingo/mpt/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (206 Bytes)
 
mllm/flamingo/mpt/__pycache__/attention.cpython-39.pyc DELETED
Binary file (12.2 kB)
 
mllm/flamingo/mpt/__pycache__/blocks.cpython-39.pyc DELETED
Binary file (2.81 kB)
 
mllm/flamingo/mpt/__pycache__/configuration_mpt.cpython-39.pyc DELETED
Binary file (8.76 kB)
 
mllm/flamingo/mpt/__pycache__/custom_embedding.cpython-39.pyc DELETED
Binary file (797 Bytes)
 
mllm/flamingo/mpt/__pycache__/flash_attn_triton.cpython-39.pyc DELETED
Binary file (20.9 kB)
 
mllm/flamingo/mpt/__pycache__/modeling_mpt.cpython-39.pyc DELETED
Binary file (15.3 kB)
 
mllm/flamingo/mpt/__pycache__/norm.cpython-39.pyc DELETED
Binary file (3.03 kB)
 
mllm/flamingo/mpt/__pycache__/param_init_fns.cpython-39.pyc DELETED
Binary file (9.14 kB)
 
mllm/flamingo/mpt/adapt_tokenizer.py DELETED
@@ -1,44 +0,0 @@
- from typing import Union
- from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
-
- Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
- NUM_SENTINEL_TOKENS: int = 100
-
-
- def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
-     """Adds sentinel tokens and a padding token (if missing).
-
-     Expands the tokenizer vocabulary to include the sentinel tokens
-     used in mixture-of-denoiser tasks as well as a padding token.
-
-     All added tokens are added as special tokens. No tokens are
-     added if the sentinel tokens and padding token already exist.
-     """
-     sentinels_to_add = [f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)]
-     tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
-     if tokenizer.pad_token is None:
-         tokenizer.add_tokens("<pad>", special_tokens=True)
-         tokenizer.pad_token = "<pad>"
-         assert tokenizer.pad_token_id is not None
-     sentinels = "".join([f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)])
-     _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
-     tokenizer.sentinel_token_ids = _sentinel_token_ids
-
-
- class AutoTokenizerForMOD(AutoTokenizer):
-     """AutoTokenizer + adaptation for MOD (mixture of denoisers).
-
-     A simple wrapper around AutoTokenizer to make instantiating
-     an MOD-adapted tokenizer a bit easier.
-
-     MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>),
-     a padding token, and a property to get the token ids of the
-     sentinel tokens.
-     """
-
-     @classmethod
-     def from_pretrained(cls, *args, **kwargs):
-         """See `AutoTokenizer.from_pretrained` docstring."""
-         tokenizer = super().from_pretrained(*args, **kwargs)
-         adapt_tokenizer_for_denoising(tokenizer)
-         return tokenizer
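-
-
- # Hedged usage sketch (model name illustrative):
- #   tokenizer = AutoTokenizerForMOD.from_pretrained("mosaicml/mpt-7b")
- #   assert tokenizer.pad_token_id is not None
- #   assert len(tokenizer.sentinel_token_ids) == NUM_SENTINEL_TOKENS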
mllm/flamingo/mpt/attention.py DELETED
@@ -1,450 +0,0 @@
- """Attention layers."""
- import math
- import warnings
- from typing import Optional
- import torch
- import torch.nn as nn
- from einops import rearrange
- from packaging import version
- from .norm import LPLayerNorm
-
-
- def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool):
-     if original_is_causal and num_query_tokens != num_key_tokens:
-         if num_query_tokens != 1:
-             raise NotImplementedError("MPT does not support query and key with different numbers of tokens, unless the number of query tokens is 1.")
-         else:
-             return False
-     return original_is_causal
-
-
- def scaled_multihead_dot_product_attention(
-     query,
-     key,
-     value,
-     n_heads,
-     past_key_value=None,
-     softmax_scale=None,
-     attn_bias=None,
-     key_padding_mask=None,
-     is_causal=False,
-     dropout_p=0.0,
-     training=False,
-     needs_weights=False,
-     multiquery=False,
- ):
-     q = rearrange(query, "b s (h d) -> b h s d", h=n_heads)
-     kv_n_heads = 1 if multiquery else n_heads
-     k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads)
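-     # note: k is laid out (b, h, d, s) rather than (b, h, s, d), so q.matmul(k)
-     # below directly yields the (b, h, s_q, s_k) attention scores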
- v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads)
41
- if past_key_value is not None:
42
- if len(past_key_value) != 0:
43
- k = torch.cat([past_key_value[0], k], dim=3)
44
- v = torch.cat([past_key_value[1], v], dim=2)
45
- past_key_value = (k, v)
46
- (b, _, s_q, d) = q.shape
47
- s_k = k.size(-1)
48
- if softmax_scale is None:
49
- softmax_scale = 1 / math.sqrt(d)
50
- attn_weight = q.matmul(k) * softmax_scale
51
- if attn_bias is not None:
52
- _s_q = max(0, attn_bias.size(2) - s_q)
53
- _s_k = max(0, attn_bias.size(3) - s_k)
54
- attn_bias = attn_bias[:, :, _s_q:, _s_k:]
55
- if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
56
- raise RuntimeError(f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.")
57
- attn_weight = attn_weight + attn_bias
58
- min_val = torch.finfo(q.dtype).min
59
- if key_padding_mask is not None:
60
- if attn_bias is not None:
61
- warnings.warn(
62
- "Propogating key_padding_mask to the attention module "
63
- + "and applying it within the attention module can cause "
64
- + "unneccessary computation/memory usage. Consider integrating "
65
- + "into attn_bias once and passing that to each attention "
66
- + "module instead."
67
- )
68
- attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
69
- if is_causal and (not q.size(2) == 1):
70
- s = max(s_q, s_k)
71
- causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
72
- causal_mask = causal_mask.tril()
73
- causal_mask = causal_mask.to(torch.bool)
74
- causal_mask = ~causal_mask
75
- causal_mask = causal_mask[-s_q:, -s_k:]
76
- attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
77
- attn_weight = torch.softmax(attn_weight, dim=-1)
78
- if dropout_p:
79
- attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
80
- out = attn_weight.to(v.dtype).matmul(v)
81
- out = rearrange(out, "b h s d -> b s (h d)")
82
- if needs_weights:
83
- return (out, attn_weight, past_key_value)
84
- return (out, None, past_key_value)
85
-
86
-
87
- def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
88
- for tensor in tensors:
89
- if tensor.dtype not in valid_dtypes:
90
- raise TypeError(f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.")
91
- if not tensor.is_cuda:
92
- raise TypeError(f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).")
93
-
94
-
95
- def flash_attn_fn(
96
- query,
97
- key,
98
- value,
99
- n_heads,
100
- past_key_value=None,
101
- softmax_scale=None,
102
- attn_bias=None,
103
- key_padding_mask=None,
104
- is_causal=False,
105
- dropout_p=0.0,
106
- training=False,
107
- needs_weights=False,
108
- multiquery=False,
109
- ):
110
- try:
111
- from flash_attn import bert_padding, flash_attn_interface
112
- except ImportError:
113
- raise RuntimeError("Please install flash-attn==1.0.3.post0")
114
- check_valid_inputs(query, key, value)
115
- if past_key_value is not None:
116
- if len(past_key_value) != 0:
117
- key = torch.cat([past_key_value[0], key], dim=1)
118
- value = torch.cat([past_key_value[1], value], dim=1)
119
- past_key_value = (key, value)
120
- if attn_bias is not None:
121
- _s_q = max(0, attn_bias.size(2) - query.size(1))
122
- _s_k = max(0, attn_bias.size(3) - key.size(1))
123
- attn_bias = attn_bias[:, :, _s_q:, _s_k:]
124
- if attn_bias is not None:
125
- raise NotImplementedError("attn_bias not implemented for flash attn.")
126
- (batch_size, seqlen) = query.shape[:2]
127
- if key_padding_mask is None:
128
- key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
129
- query_padding_mask = key_padding_mask[:, -query.size(1) :]
130
- (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask)
131
- query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads)
132
- (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask)
133
- key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads)
134
- (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
135
- value_unpad = rearrange(value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads)
136
- if multiquery:
137
- key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
138
- value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1))
139
- dropout_p = dropout_p if training else 0.0
140
- reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
141
- output_unpad = flash_attn_interface.flash_attn_unpadded_func(
142
- query_unpad,
143
- key_unpad,
144
- value_unpad,
145
- cu_seqlens_q,
146
- cu_seqlens_k,
147
- max_seqlen_q,
148
- max_seqlen_k,
149
- dropout_p,
150
- softmax_scale=softmax_scale,
151
- causal=reset_is_causal,
152
- return_attn_probs=needs_weights,
153
- )
154
- output = bert_padding.pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen)
155
- return (output, None, past_key_value)
156
-
157
-
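flash_attn_fn above leans on bert_padding.unpad_input / pad_input from flash-attn to pack the valid tokens into a single "nnz" dimension before the kernel and scatter the output back afterwards. A pure-torch sketch of that round-trip (illustrative only, not the library's implementation):

    import torch

    b, s, d = 2, 4, 8
    x = torch.randn(b, s, d)
    mask = torch.tensor([[1, 1, 1, 0],
                         [1, 1, 0, 0]], dtype=torch.bool)  # valid (non-pad) tokens
    indices = mask.flatten().nonzero(as_tuple=False).flatten()
    x_unpad = x.reshape(b * s, d)[indices]  # (nnz, d); here nnz == 5
    cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32)  # prefix sums of lengths, as passed to the kernel
    x_repad = x.new_zeros(b * s, d)
    x_repad[indices] = x_unpad              # scatter back into the padded layout
    assert torch.equal(x_repad.reshape(b, s, d)[mask], x[mask])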
158
- def triton_flash_attn_fn(
159
- query,
160
- key,
161
- value,
162
- n_heads,
163
- past_key_value=None,
164
- softmax_scale=None,
165
- attn_bias=None,
166
- key_padding_mask=None,
167
- is_causal=False,
168
- dropout_p=0.0,
169
- training=False,
170
- needs_weights=False,
171
- multiquery=False,
172
- ):
173
- try:
174
- from .flash_attn_triton import flash_attn_func
175
- except ImportError:
176
- _installed = False
177
- if version.parse(torch.__version__) < version.parse("2.0.0"):
178
- _installed = True
179
- try:
180
- from flash_attn.flash_attn_triton import flash_attn_func
181
- except ImportError:
182
- _installed = False
183
- if not _installed:
184
- raise RuntimeError(
185
- "Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed."
186
- )
187
- check_valid_inputs(query, key, value)
188
- if past_key_value is not None:
189
- if len(past_key_value) != 0:
190
- key = torch.cat([past_key_value[0], key], dim=1)
191
- value = torch.cat([past_key_value[1], value], dim=1)
192
- past_key_value = (key, value)
193
- if attn_bias is not None:
194
- _s_q = max(0, attn_bias.size(2) - query.size(1))
195
- _s_k = max(0, attn_bias.size(3) - key.size(1))
196
- attn_bias = attn_bias[:, :, _s_q:, _s_k:]
197
- if dropout_p:
198
- raise NotImplementedError("Dropout not implemented for attn_impl: triton.")
199
- if needs_weights:
200
- raise NotImplementedError("attn_impl: triton cannot return attn weights.")
201
- if key_padding_mask is not None:
202
- warnings.warn(
203
- "Propagating key_padding_mask to the attention module "
204
- + "and applying it within the attention module can cause "
205
- + "unnecessary computation/memory usage. Consider integrating "
206
- + "into attn_bias once and passing that to each attention "
207
- + "module instead."
208
- )
209
- (b_size, s_k) = key_padding_mask.shape[:2]
210
- if attn_bias is None:
211
- attn_bias = query.new_zeros(b_size, 1, 1, s_k)
212
- attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min)
213
- query = rearrange(query, "b s (h d) -> b s h d", h=n_heads)
214
- key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
215
- value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
216
- if multiquery:
217
- key = key.expand(*key.shape[:2], n_heads, key.size(-1))
218
- value = value.expand(*value.shape[:2], n_heads, value.size(-1))
219
- reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
220
- attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
221
- output = attn_output.view(*attn_output.shape[:2], -1)
222
- return (output, None, past_key_value)
223
-
224
-
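In the multiquery case above, the single shared key/value head is broadcast to n_heads with expand, which returns a zero-copy view rather than a materialized copy:

    import torch

    nnz, n_heads, head_dim = 10, 8, 64
    k_single = torch.randn(nnz, 1, head_dim)         # one shared K head
    k_all = k_single.expand(nnz, n_heads, head_dim)  # viewed as n_heads heads
    assert k_all.shape == (nnz, n_heads, head_dim)
    assert k_all.data_ptr() == k_single.data_ptr()   # same storage, stride tricks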
225
- class MultiheadAttention(nn.Module):
226
- """Multi-head self attention.
227
-
228
- Using the torch or triton attention implementation enables the user to also use
229
- additive bias.
230
- """
231
-
232
- def __init__(
233
- self,
234
- d_model: int,
235
- n_heads: int,
236
- attn_impl: str = "triton",
237
- clip_qkv: Optional[float] = None,
238
- qk_ln: bool = False,
239
- softmax_scale: Optional[float] = None,
240
- attn_pdrop: float = 0.0,
241
- low_precision_layernorm: bool = False,
242
- verbose: int = 0,
243
- device: Optional[str] = None,
244
- ):
245
- super().__init__()
246
- self.attn_impl = attn_impl
247
- self.clip_qkv = clip_qkv
248
- self.qk_ln = qk_ln
249
- self.d_model = d_model
250
- self.n_heads = n_heads
251
- self.softmax_scale = softmax_scale
252
- if self.softmax_scale is None:
253
- self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
254
- self.attn_dropout_p = attn_pdrop
255
- self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
256
- fuse_splits = (d_model, 2 * d_model)
257
- self.Wqkv._fused = (0, fuse_splits)
258
- if self.qk_ln:
259
- layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
260
- self.q_ln = layernorm_class(self.d_model, device=device)
261
- self.k_ln = layernorm_class(self.d_model, device=device)
262
- if self.attn_impl == "flash":
263
- self.attn_fn = flash_attn_fn
264
- elif self.attn_impl == "triton":
265
- self.attn_fn = triton_flash_attn_fn
266
- if verbose:
267
- warnings.warn(
268
- "While `attn_impl: triton` can be faster than `attn_impl: flash` "
269
- + "it uses more memory. When training larger models this can trigger "
270
- + "alloc retries which hurts performance. If encountered, we recommend "
271
- + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
272
- )
273
- elif self.attn_impl == "torch":
274
- self.attn_fn = scaled_multihead_dot_product_attention
275
- if torch.cuda.is_available() and verbose:
276
- warnings.warn(
277
- "Using `attn_impl: torch`. If your model does not use `alibi` or "
278
- + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
279
- + "we recommend using `attn_impl: triton`."
280
- )
281
- else:
282
- raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
283
- self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
284
- self.out_proj._is_residual = True
285
-
286
- def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
287
- qkv = self.Wqkv(x)
288
- if self.clip_qkv:
289
- qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
290
- (query, key, value) = qkv.chunk(3, dim=2)
291
- key_padding_mask = attention_mask
292
- if self.qk_ln:
293
- dtype = query.dtype
294
- query = self.q_ln(query).to(dtype)
295
- key = self.k_ln(key).to(dtype)
296
- (context, attn_weights, past_key_value) = self.attn_fn(
297
- query,
298
- key,
299
- value,
300
- self.n_heads,
301
- past_key_value=past_key_value,
302
- softmax_scale=self.softmax_scale,
303
- attn_bias=attn_bias,
304
- key_padding_mask=key_padding_mask,
305
- is_causal=is_causal,
306
- dropout_p=self.attn_dropout_p,
307
- training=self.training,
308
- needs_weights=needs_weights,
309
- )
310
- return (self.out_proj(context), attn_weights, past_key_value)
311
-
312
-
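A minimal forward-pass sketch for the class above (assuming it is importable; attn_impl="torch" sidesteps the flash/triton dependencies and runs on CPU):

    import torch

    attn = MultiheadAttention(d_model=512, n_heads=8, attn_impl="torch")
    x = torch.randn(2, 16, 512)
    out, attn_weights, past_kv = attn(x, is_causal=True)
    assert out.shape == (2, 16, 512)
    assert past_kv is None  # caching starts once past_key_value is passed in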
313
- class MultiQueryAttention(nn.Module):
314
- """Multi-Query self attention.
315
-
316
- Using the torch or triton attention implementation enables the user to also use
317
- additive bias.
318
- """
319
-
320
- def __init__(
321
- self,
322
- d_model: int,
323
- n_heads: int,
324
- attn_impl: str = "triton",
325
- clip_qkv: Optional[float] = None,
326
- qk_ln: bool = False,
327
- softmax_scale: Optional[float] = None,
328
- attn_pdrop: float = 0.0,
329
- low_precision_layernorm: bool = False,
330
- verbose: int = 0,
331
- device: Optional[str] = None,
332
- ):
333
- super().__init__()
334
- self.attn_impl = attn_impl
335
- self.clip_qkv = clip_qkv
336
- self.qk_ln = qk_ln
337
- self.d_model = d_model
338
- self.n_heads = n_heads
339
- self.head_dim = d_model // n_heads
340
- self.softmax_scale = softmax_scale
341
- if self.softmax_scale is None:
342
- self.softmax_scale = 1 / math.sqrt(self.head_dim)
343
- self.attn_dropout_p = attn_pdrop
344
- self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
345
- fuse_splits = (d_model, d_model + self.head_dim)
346
- self.Wqkv._fused = (0, fuse_splits)
347
- if self.qk_ln:
348
- layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
349
- self.q_ln = layernorm_class(d_model, device=device)
350
- self.k_ln = layernorm_class(self.head_dim, device=device)
351
- if self.attn_impl == "flash":
352
- self.attn_fn = flash_attn_fn
353
- elif self.attn_impl == "triton":
354
- self.attn_fn = triton_flash_attn_fn
355
- if verbose:
356
- warnings.warn(
357
- "While `attn_impl: triton` can be faster than `attn_impl: flash` "
358
- + "it uses more memory. When training larger models this can trigger "
359
- + "alloc retries which hurts performance. If encountered, we recommend "
360
- + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
361
- )
362
- elif self.attn_impl == "torch":
363
- self.attn_fn = scaled_multihead_dot_product_attention
364
- if torch.cuda.is_available() and verbose:
365
- warnings.warn(
366
- "Using `attn_impl: torch`. If your model does not use `alibi` or "
367
- + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
368
- + "we recommend using `attn_impl: triton`."
369
- )
370
- else:
371
- raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
372
- self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
373
- self.out_proj._is_residual = True
374
-
375
- def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
376
- qkv = self.Wqkv(x)
377
- if self.clip_qkv:
378
- qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
379
- (query, key, value) = qkv.split([self.d_model, self.head_dim, self.head_dim], dim=2)
380
- key_padding_mask = attention_mask
381
- if self.qk_ln:
382
- dtype = query.dtype
383
- query = self.q_ln(query).to(dtype)
384
- key = self.k_ln(key).to(dtype)
385
- (context, attn_weights, past_key_value) = self.attn_fn(
386
- query,
387
- key,
388
- value,
389
- self.n_heads,
390
- past_key_value=past_key_value,
391
- softmax_scale=self.softmax_scale,
392
- attn_bias=attn_bias,
393
- key_padding_mask=key_padding_mask,
394
- is_causal=is_causal,
395
- dropout_p=self.attn_dropout_p,
396
- training=self.training,
397
- needs_weights=needs_weights,
398
- multiquery=True,
399
- )
400
- return (self.out_proj(context), attn_weights, past_key_value)
401
-
402
-
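The only structural change from MultiheadAttention is the fused projection: queries keep one slice per head while keys and values share a single head. A back-of-the-envelope comparison with illustrative sizes (d_model=4096, n_heads=32, so head_dim=128):

    d_model, n_heads = 4096, 32
    head_dim = d_model // n_heads
    # MultiheadAttention.Wqkv:  d_model -> 3 * d_model            (4096 -> 12288)
    # MultiQueryAttention.Wqkv: d_model -> d_model + 2 * head_dim (4096 -> 4352)
    assert 3 * d_model == 12288
    assert d_model + 2 * head_dim == 4352  # K/V projections (and the KV cache) shrink roughly n_heads-fold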
403
- def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
404
- if attn_impl == "flash":
405
- return None
406
- elif attn_impl in ["torch", "triton"]:
407
- if alibi:
408
- if (prefix_lm or not causal) or use_sequence_id:
409
- return (1, n_heads, seq_len, seq_len)
410
- return (1, n_heads, 1, seq_len)
411
- elif prefix_lm or use_sequence_id:
412
- return (1, 1, seq_len, seq_len)
413
- return None
414
- else:
415
- raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
416
-
417
-
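A few sanity checks on the shape logic above (a sketch, assuming the function is importable; the positional arguments are attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):

    assert attn_bias_shape("flash", 8, 128, False, False, True, False) is None
    assert attn_bias_shape("torch", 8, 128, True, False, True, False) == (1, 8, 1, 128)    # alibi, causal
    assert attn_bias_shape("torch", 8, 128, True, True, True, False) == (1, 8, 128, 128)   # alibi + prefix_lm
    assert attn_bias_shape("torch", 8, 128, False, True, True, False) == (1, 1, 128, 128)  # prefix_lm only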
418
- def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8):
419
- if attn_impl == "flash":
420
- return None
421
- elif attn_impl in ["torch", "triton"]:
422
- if alibi:
423
- (device, dtype) = (attn_bias.device, attn_bias.dtype)
424
- attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype))
425
- return attn_bias
426
- else:
427
- raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
428
-
429
-
430
- def gen_slopes(n_heads, alibi_bias_max=8, device=None):
431
- _n_heads = 2 ** math.ceil(math.log2(n_heads))
432
- m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
433
- m = m.mul(alibi_bias_max / _n_heads)
434
- slopes = 1.0 / torch.pow(2, m)
435
- if _n_heads != n_heads:
436
- slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
437
- return slopes.view(1, n_heads, 1, 1)
438
-
439
-
440
- def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None):
441
- alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len)
442
- if full:
443
- alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, seq_len, 1)
444
- alibi_bias = alibi_bias.abs().mul(-1)
445
- slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
446
- alibi_bias = alibi_bias * slopes
447
- return alibi_bias.to(dtype=dtype)
448
-
449
-
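A worked example for the two helpers above (a sketch, assuming they are importable): with n_heads=8 and the default alibi_bias_max=8, the exponents m are simply 1..8, so the per-head slopes are 2^-1 through 2^-8; build_alibi_bias then scales relative key distances by those slopes.

    import torch

    slopes = gen_slopes(8)  # shape (1, 8, 1, 1)
    expected = torch.tensor([2.0 ** -i for i in range(1, 9)])
    assert torch.allclose(slopes.flatten(), expected)

    bias = build_alibi_bias(8, 4)  # causal form: distances [-3, -2, -1, 0] per head
    assert bias.shape == (1, 8, 1, 4)
    assert torch.allclose(bias[0, 0, 0], 0.5 * torch.tensor([-3.0, -2.0, -1.0, 0.0]))  # head 0 has slope 1/2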
450
- ATTN_CLASS_REGISTRY = {"multihead_attention": MultiheadAttention, "multiquery_attention": MultiQueryAttention}
mllm/flamingo/mpt/blocks.py DELETED
@@ -1,82 +0,0 @@
1
- """GPT Blocks used for the GPT Model."""
2
- from typing import Dict, Optional, Tuple
3
- import torch
4
- import torch.nn as nn
5
- from .attention import ATTN_CLASS_REGISTRY
6
- from .norm import NORM_CLASS_REGISTRY
7
-
8
-
9
- class MPTMLP(nn.Module):
10
- def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str] = None):
11
- super().__init__()
12
- self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
13
- ## yh: hard code
14
- # self.act = nn.GELU(approximate='none')
15
- self.act = nn.GELU()
16
- self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
17
- self.down_proj._is_residual = True
18
-
19
- def forward(self, x):
20
- return self.down_proj(self.act(self.up_proj(x)))
21
-
22
-
23
- class MPTBlock(nn.Module):
24
- def __init__(
25
- self,
26
- d_model: int,
27
- n_heads: int,
28
- expansion_ratio: int,
29
- attn_config: Dict = {
30
- "attn_type": "multihead_attention",
31
- "attn_pdrop": 0.0,
32
- "attn_impl": "triton",
33
- "qk_ln": False,
34
- "clip_qkv": None,
35
- "softmax_scale": None,
36
- "prefix_lm": False,
37
- "attn_uses_sequence_id": False,
38
- "alibi": False,
39
- "alibi_bias_max": 8,
40
- },
41
- resid_pdrop: float = 0.0,
42
- norm_type: str = "low_precision_layernorm",
43
- verbose: int = 0,
44
- device: Optional[str] = None,
45
- **kwargs
46
- ):
47
- del kwargs
48
- super().__init__()
49
- norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
50
- attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
51
- self.norm_1 = norm_class(d_model, device=device)
52
- self.attn = attn_class(
53
- attn_impl=attn_config["attn_impl"],
54
- clip_qkv=attn_config["clip_qkv"],
55
- qk_ln=attn_config["qk_ln"],
56
- softmax_scale=attn_config["softmax_scale"],
57
- attn_pdrop=attn_config["attn_pdrop"],
58
- d_model=d_model,
59
- n_heads=n_heads,
60
- verbose=verbose,
61
- device=device,
62
- )
63
- self.norm_2 = norm_class(d_model, device=device)
64
- self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
65
- self.resid_attn_dropout = nn.Dropout(resid_pdrop)
66
- self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
67
-
68
- def forward(
69
- self,
70
- x: torch.Tensor,
71
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
72
- attn_bias: Optional[torch.Tensor] = None,
73
- attention_mask: Optional[torch.ByteTensor] = None,
74
- is_causal: bool = True,
75
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
76
- a = self.norm_1(x)
77
- (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
78
- x = x + self.resid_attn_dropout(b)
79
- m = self.norm_2(x)
80
- n = self.ffn(m)
81
- x = x + self.resid_ffn_dropout(n)
82
mllm/flamingo/mpt/configuration_mpt.py DELETED
@@ -1,161 +0,0 @@
1
- """A HuggingFace-style model configuration."""
2
- from typing import Dict, Optional, Union
3
- from transformers import PretrainedConfig
4
-
5
- attn_config_defaults: Dict = {
6
- "attn_type": "multihead_attention",
7
- "attn_pdrop": 0.0,
8
- "attn_impl": "triton",
9
- "qk_ln": False,
10
- "clip_qkv": None,
11
- "softmax_scale": None,
12
- "prefix_lm": False,
13
- "attn_uses_sequence_id": False,
14
- "alibi": False,
15
- "alibi_bias_max": 8,
16
- }
17
- init_config_defaults: Dict = {
18
- "name": "kaiming_normal_",
19
- "fan_mode": "fan_in",
20
- "init_nonlinearity": "relu",
21
- "init_div_is_residual": True,
22
- "emb_init_std": None,
23
- "emb_init_uniform_lim": None,
24
- "init_std": None,
25
- "init_gain": 0.0,
26
- }
27
-
28
-
29
- class MPTConfig(PretrainedConfig):
30
- model_type = "mpt"
31
-
32
- def __init__(
33
- self,
34
- d_model: int = 2048,
35
- n_heads: int = 16,
36
- n_layers: int = 24,
37
- expansion_ratio: int = 4,
38
- max_seq_len: int = 2048,
39
- vocab_size: int = 50368,
40
- resid_pdrop: float = 0.0,
41
- emb_pdrop: float = 0.0,
42
- learned_pos_emb: bool = True,
43
- attn_config: Dict = attn_config_defaults,
44
- init_device: str = "cpu",
45
- logit_scale: Optional[Union[float, str]] = None,
46
- no_bias: bool = False,
47
- verbose: int = 0,
48
- embedding_fraction: float = 1.0,
49
- norm_type: str = "low_precision_layernorm",
50
- use_cache: bool = False,
51
- init_config: Dict = init_config_defaults,
52
- **kwargs,
53
- ):
54
- """The MPT configuration class.
55
-
56
- Args:
57
- d_model (int): The size of the embedding dimension of the model.
58
- n_heads (int): The number of attention heads.
59
- n_layers (int): The number of layers in the model.
60
- expansion_ratio (int): The ratio of the up/down scale in the MLP.
61
- max_seq_len (int): The maximum sequence length of the model.
62
- vocab_size (int): The size of the vocabulary.
63
- resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
64
- emb_pdrop (float): The dropout probability for the embedding layer.
65
- learned_pos_emb (bool): Whether to use learned positional embeddings
66
- attn_config (Dict): A dictionary used to configure the model's attention module:
67
- attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention
68
- attn_pdrop (float): The dropout probability for the attention layers.
69
- attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
70
- qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
71
- clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
72
- this value.
73
- softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
74
- use the default scale of ``1/sqrt(d_keys)``.
75
- prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
76
- extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
77
- can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
78
- attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
79
- When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
80
- which sub-sequence each token belongs to.
81
- Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
82
- alibi (bool): Whether to use the alibi bias instead of position embeddings.
83
- alibi_bias_max (int): The maximum value of the alibi bias.
84
- init_device (str): The device to use for parameter initialization.
85
- logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
86
- no_bias (bool): Whether to use bias in all layers.
87
- verbose (int): The verbosity level. 0 is silent.
88
- embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
89
- norm_type (str): choose type of norm to use
90
- multiquery_attention (bool): Whether to use multiquery attention implementation.
91
- use_cache (bool): Whether or not the model should return the last key/values attentions
92
- init_config (Dict): A dictionary used to configure the model initialization:
93
- init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
94
- 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
95
- 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
96
- init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
97
- emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
98
- emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
99
- used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
100
- init_std (float): The standard deviation of the normal distribution used to initialize the model,
101
- if using the baseline_ parameter initialization scheme.
102
- init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
103
- fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
104
- init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
105
- ---
106
- See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
107
- """
108
- self.d_model = d_model
109
- self.n_heads = n_heads
110
- self.n_layers = n_layers
111
- self.expansion_ratio = expansion_ratio
112
- self.max_seq_len = max_seq_len
113
- self.vocab_size = vocab_size
114
- self.resid_pdrop = resid_pdrop
115
- self.emb_pdrop = emb_pdrop
116
- self.learned_pos_emb = learned_pos_emb
117
- self.attn_config = attn_config
118
- self.init_device = init_device
119
- self.logit_scale = logit_scale
120
- self.no_bias = no_bias
121
- self.verbose = verbose
122
- self.embedding_fraction = embedding_fraction
123
- self.norm_type = norm_type
124
- self.use_cache = use_cache
125
- self.init_config = init_config
126
- if "name" in kwargs:
127
- del kwargs["name"]
128
- if "loss_fn" in kwargs:
129
- del kwargs["loss_fn"]
130
- super().__init__(**kwargs)
131
- self._validate_config()
132
-
133
- def _set_config_defaults(self, config, config_defaults):
134
- for k, v in config_defaults.items():
135
- if k not in config:
136
- config[k] = v
137
- return config
138
-
139
- def _validate_config(self):
140
- self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
141
- self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
142
- if self.d_model % self.n_heads != 0:
143
- raise ValueError("d_model must be divisible by n_heads")
144
- if any((prob < 0 or prob > 1 for prob in [self.attn_config["attn_pdrop"], self.resid_pdrop, self.emb_pdrop])):
145
- raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
146
- if self.attn_config["attn_impl"] not in ["torch", "flash", "triton"]:
147
- raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
148
- if self.attn_config["prefix_lm"] and self.attn_config["attn_impl"] not in ["torch", "triton"]:
149
- raise NotImplementedError("prefix_lm only implemented with torch and triton attention.")
150
- if self.attn_config["alibi"] and self.attn_config["attn_impl"] not in ["torch", "triton"]:
151
- raise NotImplementedError("alibi only implemented with torch and triton attention.")
152
- if self.attn_config["attn_uses_sequence_id"] and self.attn_config["attn_impl"] not in ["torch", "triton"]:
153
- raise NotImplementedError("attn_uses_sequence_id only implemented with torch and triton attention.")
154
- if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
155
- raise ValueError("model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!")
156
- if isinstance(self.logit_scale, str) and self.logit_scale != "inv_sqrt_d_model":
157
- raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
158
- if self.init_config.get("name", None) is None:
159
- raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
160
- if not self.learned_pos_emb and (not self.attn_config["alibi"]):
161
- raise ValueError("Positional information must be provided to the model using either learned_pos_emb or alibi.")
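A construction sketch (assuming the class above is importable); the config validates itself in __init__, so invalid settings fail fast:

    cfg = MPTConfig(d_model=2048, n_heads=16, n_layers=24)
    assert cfg.d_model % cfg.n_heads == 0

    try:
        MPTConfig(d_model=2048, n_heads=7)  # 2048 is not divisible by 7
    except ValueError as err:
        print(err)  # "d_model must be divisible by n_heads"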
mllm/flamingo/mpt/custom_embedding.py DELETED
@@ -1,11 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from torch import Tensor
5
-
6
-
7
- class SharedEmbedding(nn.Embedding):
8
- def forward(self, input: Tensor, unembed: bool = False) -> Tensor:
9
- if unembed:
10
- return F.linear(input, self.weight)
11
- return super().forward(input)
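SharedEmbedding implements weight tying: the same matrix both embeds token ids and, with unembed=True, projects hidden states back to vocabulary logits. A short sketch (assuming the class is importable):

    import torch

    emb = SharedEmbedding(num_embeddings=100, embedding_dim=32)
    ids = torch.randint(0, 100, (2, 8))
    h = emb(ids)                   # (2, 8, 32): embedding lookup
    logits = emb(h, unembed=True)  # (2, 8, 100): h @ weight.T
    assert logits.shape == (2, 8, 100)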
mllm/flamingo/mpt/flash_attn_triton.py DELETED
@@ -1,841 +0,0 @@
1
- """
2
- Copied from https://github.com/HazyResearch/flash-attention/blob/eff9fe6b8076df59d64d7a3f464696738a3c7c24/flash_attn/flash_attn_triton.py
3
- update imports to use 'triton_pre_mlir'
4
-
5
- *Experimental* implementation of FlashAttention in Triton.
6
- Tested with triton==2.0.0.dev20221202.
7
- Triton 2.0 has a new backend (MLIR), but it does not yet seem to work for head dimensions
8
- other than 64:
9
- https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
10
- We'll update this implementation with the new Triton backend once this is fixed.
11
-
12
- We use the FlashAttention implementation from Phil Tillet as a starting point.
13
- https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
14
-
15
- Changes:
16
- - Implement both causal and non-causal attention.
17
- - Implement both self-attention and cross-attention.
18
- - Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
19
- - Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
20
- - Support attention bias.
21
- - Speed up the forward pass a bit, and only store the LSE instead of m and l.
22
- - Make the backward for d=128 much faster by reducing register spilling.
23
- - Optionally parallelize the backward pass across seqlen_k, to deal with the case of
24
- small batch size * nheads.
25
-
26
- Caution:
27
- - This is an *experimental* implementation. The forward pass should be quite robust but
28
- I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
29
- - This implementation has only been tested on A100.
30
- - If you plan to use headdim other than 64 and 128, you should test for race conditions
31
- (due to the Triton compiler), as done in tests/test_flash_attn.py
32
- "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
33
- for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
34
- that there are none left for other head dimensions.
35
-
36
- Differences between this Triton version and the CUDA version:
37
- - Triton version doesn't support dropout.
38
- - Triton forward is generally faster than CUDA forward, while Triton backward is
39
- generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
40
- than CUDA forward + backward.
41
- - Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
42
- - Triton version supports attention bias, while CUDA version doesn't.
43
- """
44
- import math
45
- import torch
46
- import triton_pre_mlir as triton
47
- import triton_pre_mlir.language as tl
48
-
49
-
50
- @triton.heuristics(
51
- {
52
- "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
53
- "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
54
- "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
55
- }
56
- )
57
- @triton.jit
58
- def _fwd_kernel(
59
- Q,
60
- K,
61
- V,
62
- Bias,
63
- Out,
64
- Lse,
65
- TMP,
66
- softmax_scale,
67
- stride_qb,
68
- stride_qh,
69
- stride_qm,
70
- stride_kb,
71
- stride_kh,
72
- stride_kn,
73
- stride_vb,
74
- stride_vh,
75
- stride_vn,
76
- stride_bb,
77
- stride_bh,
78
- stride_bm,
79
- stride_ob,
80
- stride_oh,
81
- stride_om,
82
- nheads,
83
- seqlen_q,
84
- seqlen_k,
85
- seqlen_q_rounded,
86
- headdim,
87
- CACHE_KEY_SEQLEN_Q,
88
- CACHE_KEY_SEQLEN_K,
89
- BIAS_TYPE: tl.constexpr,
90
- IS_CAUSAL: tl.constexpr,
91
- BLOCK_HEADDIM: tl.constexpr,
92
- EVEN_M: tl.constexpr,
93
- EVEN_N: tl.constexpr,
94
- EVEN_HEADDIM: tl.constexpr,
95
- BLOCK_M: tl.constexpr,
96
- BLOCK_N: tl.constexpr,
97
- ):
98
- start_m = tl.program_id(0)
99
- off_hb = tl.program_id(1)
100
- off_b = off_hb // nheads
101
- off_h = off_hb % nheads
102
- offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
103
- offs_n = tl.arange(0, BLOCK_N)
104
- offs_d = tl.arange(0, BLOCK_HEADDIM)
105
- q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
106
- k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
107
- v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
108
- if BIAS_TYPE == "vector":
109
- b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
110
- elif BIAS_TYPE == "matrix":
111
- b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
112
- t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
113
- lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
114
- m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
115
- acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
116
- if EVEN_M & EVEN_N:
117
- if EVEN_HEADDIM:
118
- q = tl.load(q_ptrs)
119
- else:
120
- q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
121
- elif EVEN_HEADDIM:
122
- q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
123
- else:
124
- q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
125
- end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
126
- for start_n in range(0, end_n, BLOCK_N):
127
- start_n = tl.multiple_of(start_n, BLOCK_N)
128
- if EVEN_N & EVEN_M:
129
- if EVEN_HEADDIM:
130
- k = tl.load(k_ptrs + start_n * stride_kn)
131
- else:
132
- k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
133
- elif EVEN_HEADDIM:
134
- k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
135
- else:
136
- k = tl.load(k_ptrs + start_n * stride_kn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
137
- qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
138
- qk += tl.dot(q, k, trans_b=True)
139
- if not EVEN_N:
140
- qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
141
- if IS_CAUSAL:
142
- qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
143
- if BIAS_TYPE != "none":
144
- if BIAS_TYPE == "vector":
145
- if EVEN_N:
146
- bias = tl.load(b_ptrs + start_n).to(tl.float32)
147
- else:
148
- bias = tl.load(b_ptrs + start_n, mask=start_n + offs_n < seqlen_k, other=0.0).to(tl.float32)
149
- bias = bias[None, :]
150
- elif BIAS_TYPE == "matrix":
151
- if EVEN_M & EVEN_N:
152
- bias = tl.load(b_ptrs + start_n).to(tl.float32)
153
- else:
154
- bias = tl.load(b_ptrs + start_n, mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), other=0.0).to(tl.float32)
155
- qk = qk * softmax_scale + bias
156
- m_ij = tl.maximum(tl.max(qk, 1), lse_i)
157
- p = tl.exp(qk - m_ij[:, None])
158
- else:
159
- m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
160
- p = tl.exp(qk * softmax_scale - m_ij[:, None])
161
- l_ij = tl.sum(p, 1)
162
- acc_o_scale = tl.exp(m_i - m_ij)
163
- tl.store(t_ptrs, acc_o_scale)
164
- acc_o_scale = tl.load(t_ptrs)
165
- acc_o = acc_o * acc_o_scale[:, None]
166
- if EVEN_N & EVEN_M:
167
- if EVEN_HEADDIM:
168
- v = tl.load(v_ptrs + start_n * stride_vn)
169
- else:
170
- v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
171
- elif EVEN_HEADDIM:
172
- v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
173
- else:
174
- v = tl.load(v_ptrs + start_n * stride_vn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
175
- p = p.to(v.dtype)
176
- acc_o += tl.dot(p, v)
177
- m_i = m_ij
178
- l_i_new = tl.exp(lse_i - m_ij) + l_ij
179
- lse_i = m_ij + tl.log(l_i_new)
180
- o_scale = tl.exp(m_i - lse_i)
181
- tl.store(t_ptrs, o_scale)
182
- o_scale = tl.load(t_ptrs)
183
- acc_o = acc_o * o_scale[:, None]
184
- start_m = tl.program_id(0)
185
- offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
186
- lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
187
- tl.store(lse_ptrs, lse_i)
188
- offs_d = tl.arange(0, BLOCK_HEADDIM)
189
- out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])
190
- if EVEN_M:
191
- if EVEN_HEADDIM:
192
- tl.store(out_ptrs, acc_o)
193
- else:
194
- tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
195
- elif EVEN_HEADDIM:
196
- tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
197
- else:
198
- tl.store(out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
199
-
200
-
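The forward kernel above uses the streaming ("online") softmax: a running max m_i and log-sum-exp lse_i are updated per key block and the accumulator is rescaled, so the full (seqlen_q, seqlen_k) score matrix is never materialized. A plain-torch reference sketch of the same recurrence (illustrative, not the kernel itself):

    import torch

    def streaming_softmax_matmul(scores, v, block=4):
        m = torch.full(scores.shape[:-1], float("-inf"))    # running max
        lse = torch.full(scores.shape[:-1], float("-inf"))  # running log-sum-exp
        acc = torch.zeros(*scores.shape[:-1], v.shape[-1])  # unnormalized output
        for start in range(0, scores.shape[-1], block):
            s_blk = scores[..., start:start + block]
            m_new = torch.maximum(m, s_blk.max(dim=-1).values)
            p = torch.exp(s_blk - m_new[..., None])
            l_new = torch.exp(lse - m_new) + p.sum(dim=-1)
            # rescale previous accumulator to the new max, then add this block
            acc = acc * torch.exp(m - m_new)[..., None] + p @ v[start:start + block]
            m, lse = m_new, m_new + torch.log(l_new)
        return acc * torch.exp(m - lse)[..., None]  # final softmax normalization

    scores = torch.randn(2, 8, 16)  # (batch, q, k) pre-softmax scores
    v = torch.randn(16, 32)
    ref = torch.softmax(scores, dim=-1) @ v
    assert torch.allclose(streaming_softmax_matmul(scores, v), ref, atol=1e-5)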
201
- @triton.jit
202
- def _bwd_preprocess_do_o_dot(
203
- Out,
204
- DO,
205
- Delta,
206
- stride_ob,
207
- stride_oh,
208
- stride_om,
209
- stride_dob,
210
- stride_doh,
211
- stride_dom,
212
- nheads,
213
- seqlen_q,
214
- seqlen_q_rounded,
215
- headdim,
216
- BLOCK_M: tl.constexpr,
217
- BLOCK_HEADDIM: tl.constexpr,
218
- ):
219
- start_m = tl.program_id(0)
220
- off_hb = tl.program_id(1)
221
- off_b = off_hb // nheads
222
- off_h = off_hb % nheads
223
- offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
224
- offs_d = tl.arange(0, BLOCK_HEADDIM)
225
- o = tl.load(
226
- Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
227
- mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
228
- other=0.0,
229
- ).to(tl.float32)
230
- do = tl.load(
231
- DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :],
232
- mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
233
- other=0.0,
234
- ).to(tl.float32)
235
- delta = tl.sum(o * do, axis=1)
236
- tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
237
-
238
-
239
- @triton.jit
240
- def _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr):
241
- if EVEN_N & EVEN_M:
242
- if EVEN_HEADDIM:
243
- tl.store(dv_ptrs, dv)
244
- tl.store(dk_ptrs, dk)
245
- else:
246
- tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
247
- tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
248
- elif EVEN_HEADDIM:
249
- tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
250
- tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
251
- else:
252
- tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
253
- tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
254
-
255
-
256
- @triton.jit
257
- def _bwd_kernel_one_col_block(
258
- start_n,
259
- Q,
260
- K,
261
- V,
262
- Bias,
263
- DO,
264
- DQ,
265
- DK,
266
- DV,
267
- LSE,
268
- D,
269
- softmax_scale,
270
- stride_qm,
271
- stride_kn,
272
- stride_vn,
273
- stride_bm,
274
- stride_dom,
275
- stride_dqm,
276
- stride_dkn,
277
- stride_dvn,
278
- seqlen_q,
279
- seqlen_k,
280
- headdim,
281
- ATOMIC_ADD: tl.constexpr,
282
- BIAS_TYPE: tl.constexpr,
283
- IS_CAUSAL: tl.constexpr,
284
- BLOCK_HEADDIM: tl.constexpr,
285
- EVEN_M: tl.constexpr,
286
- EVEN_N: tl.constexpr,
287
- EVEN_HEADDIM: tl.constexpr,
288
- BLOCK_M: tl.constexpr,
289
- BLOCK_N: tl.constexpr,
290
- ):
291
- begin_m = 0 if not IS_CAUSAL else start_n * BLOCK_N // BLOCK_M * BLOCK_M
292
- offs_qm = begin_m + tl.arange(0, BLOCK_M)
293
- offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
294
- offs_m = tl.arange(0, BLOCK_M)
295
- offs_d = tl.arange(0, BLOCK_HEADDIM)
296
- q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
297
- k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
298
- v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
299
- do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
300
- dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
301
- if BIAS_TYPE == "vector":
302
- b_ptrs = Bias + offs_n
303
- elif BIAS_TYPE == "matrix":
304
- b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
305
- dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
306
- dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
307
- if begin_m >= seqlen_q:
308
- dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
309
- dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
310
- _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
311
- return
312
- if EVEN_N & EVEN_M:
313
- if EVEN_HEADDIM:
314
- k = tl.load(k_ptrs)
315
- v = tl.load(v_ptrs)
316
- else:
317
- k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
318
- v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
319
- elif EVEN_HEADDIM:
320
- k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
321
- v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
322
- else:
323
- k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
324
- v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
325
- num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
326
- for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
327
- start_m = tl.multiple_of(start_m, BLOCK_M)
328
- offs_m_curr = start_m + offs_m
329
- if EVEN_M & EVEN_HEADDIM:
330
- q = tl.load(q_ptrs)
331
- elif EVEN_HEADDIM:
332
- q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
333
- else:
334
- q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
335
- qk = tl.dot(q, k, trans_b=True)
336
- if not EVEN_N:
337
- qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
338
- if IS_CAUSAL:
339
- qk = tl.where(offs_m_curr[:, None] >= offs_n[None, :], qk, float("-inf"))
340
- if BIAS_TYPE != "none":
341
- tl.debug_barrier()
342
- if BIAS_TYPE == "vector":
343
- if EVEN_N:
344
- bias = tl.load(b_ptrs).to(tl.float32)
345
- else:
346
- bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
347
- bias = bias[None, :]
348
- elif BIAS_TYPE == "matrix":
349
- if EVEN_M & EVEN_N:
350
- bias = tl.load(b_ptrs).to(tl.float32)
351
- else:
352
- bias = tl.load(b_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), other=0.0).to(tl.float32)
353
- qk = qk * softmax_scale + bias
354
- if not EVEN_M & EVEN_HEADDIM:
355
- tl.debug_barrier()
356
- lse_i = tl.load(LSE + offs_m_curr)
357
- if BIAS_TYPE == "none":
358
- p = tl.exp(qk * softmax_scale - lse_i[:, None])
359
- else:
360
- p = tl.exp(qk - lse_i[:, None])
361
- if EVEN_M & EVEN_HEADDIM:
362
- do = tl.load(do_ptrs)
363
- else:
364
- do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
365
- dv += tl.dot(p.to(do.dtype), do, trans_a=True)
366
- if not EVEN_M & EVEN_HEADDIM:
367
- tl.debug_barrier()
368
- dp = tl.dot(do, v, trans_b=True)
369
- if not EVEN_HEADDIM:
370
- tl.debug_barrier()
371
- Di = tl.load(D + offs_m_curr)
372
- ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
373
- dk += tl.dot(ds, q, trans_a=True)
374
- if not EVEN_M & EVEN_HEADDIM:
375
- tl.debug_barrier()
376
- if not ATOMIC_ADD:
377
- if EVEN_M & EVEN_HEADDIM:
378
- dq = tl.load(dq_ptrs, eviction_policy="evict_last")
379
- dq += tl.dot(ds, k)
380
- tl.store(dq_ptrs, dq, eviction_policy="evict_last")
381
- elif EVEN_HEADDIM:
382
- dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0, eviction_policy="evict_last")
383
- dq += tl.dot(ds, k)
384
- tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q, eviction_policy="evict_last")
385
- else:
386
- dq = tl.load(dq_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0, eviction_policy="evict_last")
387
- dq += tl.dot(ds, k)
388
- tl.store(dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), eviction_policy="evict_last")
389
- else:
390
- dq = tl.dot(ds, k)
391
- if EVEN_M & EVEN_HEADDIM:
392
- tl.atomic_add(dq_ptrs, dq)
393
- elif EVEN_HEADDIM:
394
- tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
395
- else:
396
- tl.atomic_add(dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
397
- dq_ptrs += BLOCK_M * stride_dqm
398
- q_ptrs += BLOCK_M * stride_qm
399
- do_ptrs += BLOCK_M * stride_dom
400
- if BIAS_TYPE == "matrix":
401
- b_ptrs += BLOCK_M * stride_bm
402
- dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
403
- dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
404
- _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
405
-
406
-
407
- def init_to_zero(name):
408
- return lambda nargs: nargs[name].zero_()
409
-
410
-
411
- @triton.autotune(
412
- configs=[
413
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero("DQ")),
414
- triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero("DQ")),
415
- ],
416
- key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "IS_CAUSAL", "BLOCK_HEADDIM"],
417
- )
418
- @triton.heuristics(
419
- {
420
- "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
421
- "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
422
- "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
423
- }
424
- )
425
- @triton.jit
426
- def _bwd_kernel(
427
- Q,
428
- K,
429
- V,
430
- Bias,
431
- DO,
432
- DQ,
433
- DK,
434
- DV,
435
- LSE,
436
- D,
437
- softmax_scale,
438
- stride_qb,
439
- stride_qh,
440
- stride_qm,
441
- stride_kb,
442
- stride_kh,
443
- stride_kn,
444
- stride_vb,
445
- stride_vh,
446
- stride_vn,
447
- stride_bb,
448
- stride_bh,
449
- stride_bm,
450
- stride_dob,
451
- stride_doh,
452
- stride_dom,
453
- stride_dqb,
454
- stride_dqh,
455
- stride_dqm,
456
- stride_dkb,
457
- stride_dkh,
458
- stride_dkn,
459
- stride_dvb,
460
- stride_dvh,
461
- stride_dvn,
462
- nheads,
463
- seqlen_q,
464
- seqlen_k,
465
- seqlen_q_rounded,
466
- headdim,
467
- CACHE_KEY_SEQLEN_Q,
468
- CACHE_KEY_SEQLEN_K,
469
- BIAS_TYPE: tl.constexpr,
470
- IS_CAUSAL: tl.constexpr,
471
- BLOCK_HEADDIM: tl.constexpr,
472
- SEQUENCE_PARALLEL: tl.constexpr,
473
- EVEN_M: tl.constexpr,
474
- EVEN_N: tl.constexpr,
475
- EVEN_HEADDIM: tl.constexpr,
476
- BLOCK_M: tl.constexpr,
477
- BLOCK_N: tl.constexpr,
478
- ):
479
- off_hb = tl.program_id(1)
480
- off_b = off_hb // nheads
481
- off_h = off_hb % nheads
482
- Q += off_b * stride_qb + off_h * stride_qh
483
- K += off_b * stride_kb + off_h * stride_kh
484
- V += off_b * stride_vb + off_h * stride_vh
485
- DO += off_b * stride_dob + off_h * stride_doh
486
- DQ += off_b * stride_dqb + off_h * stride_dqh
487
- DK += off_b * stride_dkb + off_h * stride_dkh
488
- DV += off_b * stride_dvb + off_h * stride_dvh
489
- if BIAS_TYPE != "none":
490
- Bias += off_b * stride_bb + off_h * stride_bh
491
- D += off_hb * seqlen_q_rounded
492
- LSE += off_hb * seqlen_q_rounded
493
- if not SEQUENCE_PARALLEL:
494
- num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
495
- for start_n in range(0, num_block_n):
496
- _bwd_kernel_one_col_block(
497
- start_n,
498
- Q,
499
- K,
500
- V,
501
- Bias,
502
- DO,
503
- DQ,
504
- DK,
505
- DV,
506
- LSE,
507
- D,
508
- softmax_scale,
509
- stride_qm,
510
- stride_kn,
511
- stride_vn,
512
- stride_bm,
513
- stride_dom,
514
- stride_dqm,
515
- stride_dkn,
516
- stride_dvn,
517
- seqlen_q,
518
- seqlen_k,
519
- headdim,
520
- ATOMIC_ADD=False,
521
- BIAS_TYPE=BIAS_TYPE,
522
- IS_CAUSAL=IS_CAUSAL,
523
- BLOCK_HEADDIM=BLOCK_HEADDIM,
524
- EVEN_M=EVEN_M,
525
- EVEN_N=EVEN_N,
526
- EVEN_HEADDIM=EVEN_HEADDIM,
527
- BLOCK_M=BLOCK_M,
528
- BLOCK_N=BLOCK_N,
529
- )
530
- else:
531
- start_n = tl.program_id(0)
532
- _bwd_kernel_one_col_block(
533
- start_n,
534
- Q,
535
- K,
536
- V,
537
- Bias,
538
- DO,
539
- DQ,
540
- DK,
541
- DV,
542
- LSE,
543
- D,
544
- softmax_scale,
545
- stride_qm,
546
- stride_kn,
547
- stride_vn,
548
- stride_bm,
549
- stride_dom,
550
- stride_dqm,
551
- stride_dkn,
552
- stride_dvn,
553
- seqlen_q,
554
- seqlen_k,
555
- headdim,
556
- ATOMIC_ADD=True,
557
- BIAS_TYPE=BIAS_TYPE,
558
- IS_CAUSAL=IS_CAUSAL,
559
- BLOCK_HEADDIM=BLOCK_HEADDIM,
560
- EVEN_M=EVEN_M,
561
- EVEN_N=EVEN_N,
562
- EVEN_HEADDIM=EVEN_HEADDIM,
563
- BLOCK_M=BLOCK_M,
564
- BLOCK_N=BLOCK_N,
565
- )
566
-
567
-
568
- def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
569
- (batch, seqlen_q, nheads, d) = q.shape
570
- (_, seqlen_k, _, _) = k.shape
571
- assert k.shape == (batch, seqlen_k, nheads, d)
572
- assert v.shape == (batch, seqlen_k, nheads, d)
573
- assert d <= 128, "FlashAttention only supports head dimensions up to 128"
574
- assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
575
- assert q.dtype in [torch.float16, torch.bfloat16], "Only fp16 and bf16 are supported"
576
- assert q.is_cuda and k.is_cuda and v.is_cuda
577
- softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
578
- has_bias = bias is not None
579
- bias_type = "none"
580
- if has_bias:
581
- assert bias.dtype in [q.dtype, torch.float]
582
- assert bias.is_cuda
583
- assert bias.dim() == 4
584
- if bias.stride(-1) != 1:
585
- bias = bias.contiguous()
586
- if bias.shape[2:] == (1, seqlen_k):
587
- bias_type = "vector"
588
- elif bias.shape[2:] == (seqlen_q, seqlen_k):
589
- bias_type = "matrix"
590
- else:
591
- raise RuntimeError("Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)")
592
- bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
593
- bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
594
- seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
595
- lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
596
- tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
597
- o = torch.empty_like(q)
598
- BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
599
- BLOCK = 128
600
- num_warps = 4 if d <= 64 else 8
601
- grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
602
- _fwd_kernel[grid](
603
- q,
604
- k,
605
- v,
606
- bias,
607
- o,
608
- lse,
609
- tmp,
610
- softmax_scale,
611
- q.stride(0),
612
- q.stride(2),
613
- q.stride(1),
614
- k.stride(0),
615
- k.stride(2),
616
- k.stride(1),
617
- v.stride(0),
618
- v.stride(2),
619
- v.stride(1),
620
- *bias_strides,
621
- o.stride(0),
622
- o.stride(2),
623
- o.stride(1),
624
- nheads,
625
- seqlen_q,
626
- seqlen_k,
627
- seqlen_q_rounded,
628
- d,
629
- seqlen_q // 32,
630
- seqlen_k // 32,
631
- bias_type,
632
- causal,
633
- BLOCK_HEADDIM,
634
- BLOCK_M=BLOCK,
635
- BLOCK_N=BLOCK,
636
- num_warps=num_warps,
637
- num_stages=1
638
- )
639
- return (o, lse, softmax_scale)
640
-
641
-
642
- def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):
643
- if do.stride(-1) != 1:
644
- do = do.contiguous()
645
- (batch, seqlen_q, nheads, d) = q.shape
646
- (_, seqlen_k, _, _) = k.shape
647
- assert d <= 128
648
- seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
649
- assert lse.shape == (batch, nheads, seqlen_q_rounded)
650
- assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
651
- assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
652
- softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
653
- dq_accum = torch.empty_like(q, dtype=torch.float32)
654
- delta = torch.empty_like(lse)
655
- BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
656
- grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
657
- _bwd_preprocess_do_o_dot[grid](
658
- o,
659
- do,
660
- delta,
661
- o.stride(0),
662
- o.stride(2),
663
- o.stride(1),
664
- do.stride(0),
665
- do.stride(2),
666
- do.stride(1),
667
- nheads,
668
- seqlen_q,
669
- seqlen_q_rounded,
670
- d,
671
- BLOCK_M=128,
672
- BLOCK_HEADDIM=BLOCK_HEADDIM,
673
- )
674
- has_bias = bias is not None
675
- bias_type = "none"
676
- if has_bias:
677
- assert bias.dtype in [q.dtype, torch.float]
678
- assert bias.is_cuda
679
- assert bias.dim() == 4
680
- assert bias.stride(-1) == 1
681
- if bias.shape[2:] == (1, seqlen_k):
682
- bias_type = "vector"
683
- elif bias.shape[2:] == (seqlen_q, seqlen_k):
684
- bias_type = "matrix"
685
- else:
686
- raise RuntimeError("Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)")
687
- bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
688
- bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
689
- grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, batch * nheads)
690
- _bwd_kernel[grid](
691
- q,
692
- k,
693
- v,
694
- bias,
695
- do,
696
- dq_accum,
697
- dk,
698
- dv,
699
- lse,
700
- delta,
701
- softmax_scale,
702
- q.stride(0),
703
- q.stride(2),
704
- q.stride(1),
705
- k.stride(0),
706
- k.stride(2),
707
- k.stride(1),
708
- v.stride(0),
709
- v.stride(2),
710
- v.stride(1),
711
- *bias_strides,
712
- do.stride(0),
713
- do.stride(2),
714
- do.stride(1),
715
- dq_accum.stride(0),
716
- dq_accum.stride(2),
717
- dq_accum.stride(1),
718
- dk.stride(0),
719
- dk.stride(2),
720
- dk.stride(1),
721
- dv.stride(0),
722
- dv.stride(2),
723
- dv.stride(1),
724
- nheads,
725
- seqlen_q,
726
- seqlen_k,
727
- seqlen_q_rounded,
728
- d,
729
- seqlen_q // 32,
730
- seqlen_k // 32,
731
- bias_type,
732
- causal,
733
- BLOCK_HEADDIM
734
- )
735
- dq.copy_(dq_accum)
736
-
737
-
738
- class FlashAttnQKVPackedFunc(torch.autograd.Function):
739
- @staticmethod
740
- def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
741
- """
742
- qkv: (batch, seqlen, 3, nheads, headdim)
743
- bias: optional, shape broadcastable to (batch, nheads, seqlen, seqlen).
744
- For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
745
- ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
746
- """
747
- if qkv.stride(-1) != 1:
748
- qkv = qkv.contiguous()
749
- (o, lse, ctx.softmax_scale) = _flash_attn_forward(qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal, softmax_scale=softmax_scale)
750
- ctx.save_for_backward(qkv, o, lse, bias)
751
- ctx.causal = causal
752
- return o
753
-
754
- @staticmethod
755
- def backward(ctx, do):
756
- (qkv, o, lse, bias) = ctx.saved_tensors
757
- assert not ctx.needs_input_grad[1], "FlashAttention does not support bias gradient yet"
758
- with torch.inference_mode():
759
- dqkv = torch.empty_like(qkv)
760
- _flash_attn_backward(
761
- do,
762
- qkv[:, :, 0],
763
- qkv[:, :, 1],
764
- qkv[:, :, 2],
765
- o,
766
- lse,
767
- dqkv[:, :, 0],
768
- dqkv[:, :, 1],
769
- dqkv[:, :, 2],
770
- bias=bias,
771
- causal=ctx.causal,
772
- softmax_scale=ctx.softmax_scale,
773
- )
774
- return (dqkv, None, None, None)
775
-
776
-
777
- flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
778
-
779
-
780
- class FlashAttnKVPackedFunc(torch.autograd.Function):
781
- @staticmethod
782
- def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
783
- """
784
- q: (batch, seqlen_q, nheads, headdim)
785
- kv: (batch, seqlen_k, 2, nheads, headdim)
786
- bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
787
- For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
788
- ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
789
- """
790
- (q, kv) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
791
- (o, lse, ctx.softmax_scale) = _flash_attn_forward(q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale)
792
- ctx.save_for_backward(q, kv, o, lse, bias)
793
- ctx.causal = causal
794
- return o
795
-
796
- @staticmethod
797
- def backward(ctx, do):
798
- (q, kv, o, lse, bias) = ctx.saved_tensors
799
- if len(ctx.needs_input_grad) >= 3:
800
- assert not ctx.needs_input_grad[2], "FlashAttention does not support bias gradient yet"
801
- with torch.inference_mode():
802
- dq = torch.empty_like(q)
803
- dkv = torch.empty_like(kv)
804
- _flash_attn_backward(
805
- do, q, kv[:, :, 0], kv[:, :, 1], o, lse, dq, dkv[:, :, 0], dkv[:, :, 1], bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale
806
- )
807
- return (dq, dkv, None, None, None)
808
-
809
-
810
- flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
811
-
812
-
813
- class FlashAttnFunc(torch.autograd.Function):
814
- @staticmethod
815
- def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
816
- """
817
- q: (batch_size, seqlen_q, nheads, headdim)
818
- k, v: (batch_size, seqlen_k, nheads, headdim)
819
- bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
820
- For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
821
- ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
822
- """
823
- (q, k, v) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
824
- (o, lse, ctx.softmax_scale) = _flash_attn_forward(q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale)
825
- ctx.save_for_backward(q, k, v, o, lse, bias)
826
- ctx.causal = causal
827
- return o
828
-
829
- @staticmethod
830
- def backward(ctx, do):
831
- (q, k, v, o, lse, bias) = ctx.saved_tensors
832
- assert not ctx.needs_input_grad[3], "FlashAttention does not support bias gradient yet"
833
- with torch.inference_mode():
834
- dq = torch.empty_like(q)
835
- dk = torch.empty_like(k)
836
- dv = torch.empty_like(v)
837
- _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
838
- return (dq, dk, dv, None, None, None)
839
-
840
-
841
- flash_attn_func = FlashAttnFunc.apply
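
For reference, a minimal usage sketch of the wrappers deleted above. This is illustrative only, not code from this repo: it assumes a CUDA device with a working triton install, fp16 inputs, and the pre-deletion import path. Because these names are `torch.autograd.Function.apply` bindings, arguments must be passed positionally.

    import torch
    from mllm.flamingo.mpt.flash_attn_triton import flash_attn_func  # path as it existed before this commit

    batch, seqlen, nheads, headdim = 2, 128, 8, 64
    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16, requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)
    out = flash_attn_func(q, k, v, None, True, None)  # bias=None, causal=True, default softmax scale
    out.sum().backward()  # runs FlashAttnFunc.backward via the Triton kernels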
mllm/flamingo/mpt/hf_prefixlm_converter.py DELETED
@@ -1,575 +0,0 @@
- """Converts Huggingface Causal LM to Prefix LM.
-
- Conversion does lightweight surgery on a HuggingFace
- Causal LM to convert it to a Prefix LM.
-
- Prefix LMs accept a `bidirectional_mask` input in `forward`
- and treat the input prompt as the prefix in `generate`.
- """
- import math
- import warnings
- from types import MethodType
- from typing import Any, Dict, List, Optional, Tuple, Union
- import torch
- from transformers.models.bloom.modeling_bloom import (
-     BaseModelOutputWithPastAndCrossAttentions,
-     BloomForCausalLM,
-     BloomModel,
-     CausalLMOutputWithCrossAttentions,
-     CrossEntropyLoss,
- )
- from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom
- from transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom
- from transformers.models.bloom.modeling_bloom import logging
- from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
- from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
- from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
- from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
- from transformers.models.opt.modeling_opt import OPTForCausalLM
- from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt
- from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt
-
- logger = logging.get_logger(__name__)
- _SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM)
- CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM]
-
-
- def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
-     """Converts a GPT-style Causal LM to a Prefix LM.
-
-     Supported HuggingFace model classes:
-         - `GPT2LMHeadModel`
-         - `GPTNeoForCausalLM`
-         - `GPTNeoXForCausalLM`
-         - `GPTJForCausalLM`
-
-     See `convert_hf_causal_lm_to_prefix_lm` for more details.
-     """
-     if hasattr(model, "_prefix_lm_converted"):
-         return model
-     assert isinstance(model, _SUPPORTED_GPT_MODELS)
-     assert model.config.add_cross_attention == False, "Only supports GPT-style decoder-only models"
-
-     def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
-         """Helper that gets a list of the model's attention modules.
-
-         Each module has a `bias` buffer used for causal masking. The Prefix LM
-         conversion adds logic to dynamically manipulate these biases to support
-         Prefix LM attention masking.
-         """
-         attn_modules = []
-         if isinstance(model, GPTNeoXForCausalLM):
-             blocks = model.gpt_neox.layers
-         else:
-             blocks = model.transformer.h
-         for block in blocks:
-             if isinstance(model, GPTNeoForCausalLM):
-                 if block.attn.attention_type != "global":
-                     continue
-                 attn_module = block.attn.attention
-             elif isinstance(model, GPTNeoXForCausalLM):
-                 attn_module = block.attention
-             else:
-                 attn_module = block.attn
-             attn_modules.append(attn_module)
-         return attn_modules
-
-     setattr(model, "_original_forward", getattr(model, "forward"))
-     setattr(model, "_original_generate", getattr(model, "generate"))
-
-     def forward(
-         self: CAUSAL_GPT_TYPES,
-         input_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
-         attention_mask: Optional[torch.FloatTensor] = None,
-         bidirectional_mask: Optional[torch.Tensor] = None,
-         token_type_ids: Optional[torch.LongTensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         head_mask: Optional[torch.FloatTensor] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         labels: Optional[torch.LongTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ):
-         """Wraps original forward to enable PrefixLM attention."""
-
-         def call_og_forward():
-             if isinstance(self, GPTNeoXForCausalLM):
-                 return self._original_forward(
-                     input_ids=input_ids,
-                     past_key_values=past_key_values,
-                     attention_mask=attention_mask,
-                     head_mask=head_mask,
-                     inputs_embeds=inputs_embeds,
-                     labels=labels,
-                     use_cache=use_cache,
-                     output_attentions=output_attentions,
-                     output_hidden_states=output_hidden_states,
-                     return_dict=return_dict,
-                 )
-             else:
-                 return self._original_forward(
-                     input_ids=input_ids,
-                     past_key_values=past_key_values,
-                     attention_mask=attention_mask,
-                     token_type_ids=token_type_ids,
-                     position_ids=position_ids,
-                     head_mask=head_mask,
-                     inputs_embeds=inputs_embeds,
-                     labels=labels,
-                     use_cache=use_cache,
-                     output_attentions=output_attentions,
-                     output_hidden_states=output_hidden_states,
-                     return_dict=return_dict,
-                 )
-
-         if bidirectional_mask is None:
-             return call_og_forward()
-         assert isinstance(bidirectional_mask, torch.Tensor)
-         attn_modules = _get_attn_modules(model)
-         (b, s) = bidirectional_mask.shape
-         max_length = attn_modules[0].bias.shape[-1]
-         if s > max_length:
-             raise ValueError(f"bidirectional_mask sequence length (={s}) exceeds the " + f"max length allowed by the model ({max_length}).")
-         assert s <= max_length
-         if s < max_length:
-             pad = torch.zeros((int(b), int(max_length - s)), dtype=bidirectional_mask.dtype, device=bidirectional_mask.device)
-             bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
-         bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
-         for attn_module in attn_modules:
-             attn_module.bias.data = torch.logical_or(attn_module.bias.data, bidirectional)
-         output = call_og_forward()
-         for attn_module in attn_modules:
-             attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
-         return output
-
-     def generate(self: CAUSAL_GPT_TYPES, *args: tuple, **kwargs: Dict[str, Any]):
-         """Wraps original generate to enable PrefixLM attention."""
-         attn_modules = _get_attn_modules(model)
-         for attn_module in attn_modules:
-             attn_module.bias.data[:] = 1
-         output = self._original_generate(*args, **kwargs)
-         for attn_module in attn_modules:
-             attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
-         return output
-
-     setattr(model, "forward", MethodType(forward, model))
-     setattr(model, "generate", MethodType(generate, model))
-     setattr(model, "_prefix_lm_converted", True)
-     return model
-
-
- def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM:
-     """Converts a BLOOM Causal LM to a Prefix LM.
-
-     Supported HuggingFace model classes:
-         - `BloomForCausalLM`
-
-     See `convert_hf_causal_lm_to_prefix_lm` for more details.
-     """
-     if hasattr(model, "_prefix_lm_converted"):
-         return model
-     assert isinstance(model, BloomForCausalLM)
-     assert model.config.add_cross_attention == False, "Only supports BLOOM decoder-only models"
-
-     def _prepare_attn_mask(
-         self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int
-     ) -> torch.BoolTensor:
-         combined_attention_mask = None
-         device = attention_mask.device
-         (_, src_length) = input_shape
-         if src_length > 1:
-             combined_attention_mask = _make_causal_mask_bloom(input_shape, device=device, past_key_values_length=past_key_values_length)
-             if bidirectional_mask is not None:
-                 assert attention_mask.shape == bidirectional_mask.shape
-                 expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length)
-                 combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask)
-         expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length)
-         combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
-         return combined_attention_mask
-
-     def _build_alibi_tensor(self: BloomModel, batch_size: int, query_length: int, key_length: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
-         num_heads = self.config.n_head
-         closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
-         base = torch.tensor(2 ** (-(2 ** (-(math.log2(closest_power_of_2) - 3)))), device=device, dtype=torch.float32)
-         powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
-         slopes = torch.pow(base, powers)
-         if closest_power_of_2 != num_heads:
-             extra_base = torch.tensor(2 ** (-(2 ** (-(math.log2(2 * closest_power_of_2) - 3)))), device=device, dtype=torch.float32)
-             num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
-             extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
-             slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
-         qa = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1)
-         ka = torch.arange(key_length, device=device, dtype=torch.int32).view(1, -1)
-         diffs = qa - ka + key_length - query_length
-         diffs = -diffs.abs()
-         alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(1, 1, query_length, key_length)
-         alibi = alibi.expand(batch_size, -1, -1, -1).reshape(-1, query_length, key_length)
-         return alibi.to(dtype)
-
-     KeyValueT = Tuple[torch.Tensor, torch.Tensor]
-
-     def forward(
-         self: BloomModel,
-         input_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         bidirectional_mask: Optional[torch.Tensor] = None,
-         head_mask: Optional[torch.LongTensor] = None,
-         inputs_embeds: Optional[torch.LongTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-         **deprecated_arguments,
-     ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
-         if deprecated_arguments.pop("position_ids", False) is not False:
-             warnings.warn(
-                 "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. " + "You can safely ignore passing `position_ids`.", FutureWarning
-             )
-         if len(deprecated_arguments) > 0:
-             raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         use_cache = use_cache if use_cache is not None else self.config.use_cache
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-         if input_ids is not None and inputs_embeds is not None:
-             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-         elif input_ids is not None:
-             (batch_size, seq_length) = input_ids.shape
-         elif inputs_embeds is not None:
-             (batch_size, seq_length, _) = inputs_embeds.shape
-         else:
-             raise ValueError("You have to specify either input_ids or inputs_embeds")
-         if past_key_values is None:
-             past_key_values = tuple([None] * len(self.h))
-         head_mask = self.get_head_mask(head_mask, self.config.n_layer)
-         if inputs_embeds is None:
-             inputs_embeds = self.word_embeddings(input_ids)
-         hidden_states = self.word_embeddings_layernorm(inputs_embeds)
-         presents = () if use_cache else None
-         all_self_attentions = () if output_attentions else None
-         all_hidden_states = () if output_hidden_states else None
-         seq_length_with_past = seq_length
-         past_key_values_length = 0
-         if past_key_values[0] is not None:
-             tmp = past_key_values[0][0]
-             past_key_values_length = tmp.shape[2]
-             seq_length_with_past = seq_length_with_past + past_key_values_length
-         if attention_mask is None:
-             attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
-         else:
-             attention_mask = attention_mask.to(hidden_states.device)
-         alibi = self._build_alibi_tensor(
-             batch_size=batch_size, query_length=seq_length, key_length=seq_length_with_past, dtype=hidden_states.dtype, device=hidden_states.device
-         )
-         causal_mask = self._prepare_attn_mask(
-             attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length
-         )
-         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-             if output_hidden_states:
-                 hst = (hidden_states,)
-                 all_hidden_states = all_hidden_states + hst
-             if self.gradient_checkpointing and self.training:
-                 if use_cache:
-                     logger.warning("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
-                     use_cache = False
-
-                 def create_custom_forward(module):
-                     def custom_forward(*inputs):
-                         return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                     return custom_forward
-
-                 outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(block), hidden_states, alibi, causal_mask, head_mask[i])
-             else:
-                 outputs = block(
-                     hidden_states,
-                     layer_past=layer_past,
-                     attention_mask=causal_mask,
-                     head_mask=head_mask[i],
-                     use_cache=use_cache,
-                     output_attentions=output_attentions,
-                     alibi=alibi,
-                 )
-             hidden_states = outputs[0]
-             if use_cache is True:
-                 presents = presents + (outputs[1],)
-             if output_attentions:
-                 oa = (outputs[2 if use_cache else 1],)
-                 all_self_attentions = all_self_attentions + oa
-         hidden_states = self.ln_f(hidden_states)
-         if output_hidden_states:
-             hst = (hidden_states,)
-             all_hidden_states = all_hidden_states + hst
-         if not return_dict:
-             return tuple((v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None))
-         return BaseModelOutputWithPastAndCrossAttentions(
-             last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions
-         )
-
-     setattr(model.transformer, "_prepare_attn_mask", MethodType(_prepare_attn_mask, model.transformer))
-     setattr(model.transformer, "_build_alibi_tensor", MethodType(_build_alibi_tensor, model.transformer))
-     setattr(model.transformer, "forward", MethodType(forward, model.transformer))
-     KeyValueT = Tuple[torch.Tensor, torch.Tensor]
-
-     def forward(
-         self: BloomForCausalLM,
-         input_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         bidirectional_mask: Optional[torch.Tensor] = None,
-         head_mask: Optional[torch.Tensor] = None,
-         inputs_embeds: Optional[torch.Tensor] = None,
-         labels: Optional[torch.Tensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-         **deprecated_arguments,
-     ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
-         """Replacement forward method for BloomCausalLM."""
-         if deprecated_arguments.pop("position_ids", False) is not False:
-             warnings.warn(
-                 "`position_ids` have no functionality in BLOOM and will be removed " + "in v5.0.0. You can safely ignore passing `position_ids`.", FutureWarning
-             )
-         if len(deprecated_arguments) > 0:
-             raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-         transformer_outputs = self.transformer(
-             input_ids,
-             past_key_values=past_key_values,
-             attention_mask=attention_mask,
-             bidirectional_mask=bidirectional_mask,
-             head_mask=head_mask,
-             inputs_embeds=inputs_embeds,
-             use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-         )
-         hidden_states = transformer_outputs[0]
-         lm_logits = self.lm_head(hidden_states)
-         loss = None
-         if labels is not None:
-             shift_logits = lm_logits[..., :-1, :].contiguous()
-             shift_labels = labels[..., 1:].contiguous()
-             (batch_size, seq_length, vocab_size) = shift_logits.shape
-             loss_fct = CrossEntropyLoss()
-             loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length))
-         if not return_dict:
-             output = (lm_logits,) + transformer_outputs[1:]
-             return (loss,) + output if loss is not None else output
-         return CausalLMOutputWithCrossAttentions(
-             loss=loss,
-             logits=lm_logits,
-             past_key_values=transformer_outputs.past_key_values,
-             hidden_states=transformer_outputs.hidden_states,
-             attentions=transformer_outputs.attentions,
-         )
-
-     def prepare_inputs_for_generation(
-         self: BloomForCausalLM, input_ids: torch.LongTensor, past: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, **kwargs
-     ) -> dict:
-         if past:
-             input_ids = input_ids[:, -1].unsqueeze(-1)
-             bidirectional_mask = None
-             if past[0][0].shape[0] == input_ids.shape[0]:
-                 past = self._convert_to_bloom_cache(past)
-         else:
-             bidirectional_mask = torch.ones_like(input_ids)
-         return {"input_ids": input_ids, "past_key_values": past, "use_cache": True, "attention_mask": attention_mask, "bidirectional_mask": bidirectional_mask}
-
-     setattr(model, "forward", MethodType(forward, model))
-     setattr(model, "prepare_inputs_for_generation", MethodType(prepare_inputs_for_generation, model))
-     setattr(model, "_prefix_lm_converted", True)
-     return model
-
-
- def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM:
-     """Converts an OPT Causal LM to a Prefix LM.
-
-     Supported HuggingFace model classes:
-         - `OPTForCausalLM`
-
-     See `convert_hf_causal_lm_to_prefix_lm` for more details.
-     """
-     if hasattr(model, "_prefix_lm_converted"):
-         return model
-     assert isinstance(model, OPTForCausalLM)
-     assert model.config.add_cross_attention == False, "Only supports OPT decoder-only models"
-     setattr(model, "_original_forward", getattr(model, "forward"))
-     setattr(model, "_original_generate", getattr(model, "generate"))
-     model.model.decoder.bidirectional_mask = None
-
-     def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
-         combined_attention_mask = None
-         if input_shape[-1] > 1:
-             if self.bidirectional_mask == "g":
-                 (bsz, src_length) = input_shape
-                 combined_attention_mask = torch.zeros(
-                     (bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device
-                 )
-             else:
-                 combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(
-                     inputs_embeds.device
-                 )
-                 if self.bidirectional_mask is not None:
-                     assert attention_mask.shape == self.bidirectional_mask.shape
-                     expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
-                         inputs_embeds.device
-                     )
-                     combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask)
-         if attention_mask is not None:
-             expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
-             combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
-         return combined_attention_mask
-
-     setattr(model.model.decoder, "_prepare_decoder_attention_mask", MethodType(_prepare_decoder_attention_mask, model.model.decoder))
-
-     def forward(
-         self: OPTForCausalLM,
-         input_ids: Optional[torch.LongTensor] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         bidirectional_mask: Optional[torch.ByteTensor] = None,
-         head_mask: Optional[torch.Tensor] = None,
-         past_key_values: Optional[List[torch.FloatTensor]] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         labels: Optional[torch.LongTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ):
-         def call_og_forward():
-             return self._original_forward(
-                 input_ids=input_ids,
-                 attention_mask=attention_mask,
-                 head_mask=head_mask,
-                 past_key_values=past_key_values,
-                 inputs_embeds=inputs_embeds,
-                 labels=labels,
-                 use_cache=use_cache,
-                 output_attentions=output_attentions,
-                 output_hidden_states=output_hidden_states,
-                 return_dict=return_dict,
-             )
-
-         if bidirectional_mask is None:
-             return call_og_forward()
-         self.model.decoder.bidirectional_mask = bidirectional_mask
-         try:
-             outputs = call_og_forward()
-         except:
-             self.model.decoder.bidirectional_mask = None
-             raise
-         self.model.decoder.bidirectional_mask = None
-         return outputs
-
-     def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[str, Any]):
-         """Wraps original generate to enable PrefixLM-style attention."""
-         self.model.decoder.bidirectional_mask = "g"
-         try:
-             output = self._original_generate(*args, **kwargs)
-         except:
-             self.model.decoder.bidirectional_mask = None
-             raise
-         self.model.decoder.bidirectional_mask = None
-         return output
-
-     setattr(model, "forward", MethodType(forward, model))
-     setattr(model, "generate", MethodType(generate, model))
-     setattr(model, "_prefix_lm_converted", True)
-     return model
-
-
- _SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM)
- CAUSAL_LM_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM, BloomForCausalLM, OPTForCausalLM]
-
-
- def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
-     """Converts a HuggingFace Causal LM to a Prefix LM.
-
-     Supported HuggingFace model classes:
-         - `GPT2LMHeadModel`
-         - `GPTNeoForCausalLM`
-         - `GPTNeoXForCausalLM`
-         - `GPTJForCausalLM`
-         - `BloomForCausalLM`
-         - `OPTForCausalLM`
-
-     Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the
-     `generate` method and/or select underlying methods depending on the model class.
-
-     These changes preserve the model API, but add a new input to `forward`: "bidirectional_mask".
-
-     Notes on training:
-         To actually train the converted model as a Prefix LM, training batches will need to indicate
-         the prefix/target structure by including `bidirectional_mask` as part of the batch inputs.
-
-         **This is not a standard input and requires custom layers either within or after your dataloader.**
-
-         In addition to adding `bidirectional_mask` to the batch, this custom code should modify `labels`
-         such that `batch['labels'][batch['bidirectional_mask'] == 1] == -100`.
-         That is, the prefix portion of the sequence should not generate any loss. Loss should only be
-         generated by the target portion of the sequence.
-
-     Notes on `GPTNeoForCausalLM`:
-         To simplify the implementation, "global" and "local" attention layers are handled differently.
-         For "global" layers, we handle conversion as described above. For "local" layers, which use a
-         causal attention mask within a restricted local window, we do not alter the masking.
-
-     Notes on `forward` method conversion:
-         After conversion, the `forward` method will handle a new input, `bidirectional_mask`,
-         which should be a [batch_size, seq_length] byte tensor, where 1 indicates token positions
-         belonging to the prefix (prefix tokens can attend to one another bidirectionally), and
-         0 indicates token positions belonging to the target.
-
-         The new `forward` method will incorporate `bidirectional_mask` (if supplied) into the existing
-         causal mask, call the original `forward` method, and (if the causal mask is a buffer) reset
-         the causal masks before returning the result.
-
-     Notes on `generate` method conversion:
-         After conversion, the `generate` method will have the same signature but will internally
-         convert all causal masks to be purely bidirectional, call the original `generate` method, and
-         (where appropriate) reset the causal masks before returning the result.
-
-         This works thanks to the logic of the HuggingFace `generate` API, which first encodes the token
-         "prompt" passed to `generate` (which is treated as the prefix) and then sequentially generates
-         each new token. Encodings are cached as generation happens, so all prefix tokens can attend to one
-         another (as expected in a Prefix LM) and generated tokens can only attend to prefix tokens and
-         previously-generated tokens (also as expected in a Prefix LM).
-
-     To preserve the API, the original methods are renamed to `_original_forward` and
-     `_original_generate`, and replaced with new `forward` and `generate` methods that wrap
-     them, respectively; implementation details vary by model class.
-     """
-     if isinstance(model, _SUPPORTED_GPT_MODELS):
-         return _convert_gpt_causal_lm_to_prefix_lm(model)
-     elif isinstance(model, BloomForCausalLM):
-         return _convert_bloom_causal_lm_to_prefix_lm(model)
-     elif isinstance(model, OPTForCausalLM):
-         return _convert_opt_causal_lm_to_prefix_lm(model)
-     else:
-         raise TypeError("Cannot convert model to Prefix LM. " + "Model does not belong to the set of supported HF models:" + f"\n{_SUPPORTED_HF_MODELS}")
-
-
- def add_bidirectional_mask_if_missing(batch: Dict[str, Any]):
-     """Attempts to add bidirectional_mask to batch if missing.
-
-     Raises:
-         KeyError if bidirectional_mask is missing and can't be inferred.
-     """
-     if "bidirectional_mask" not in batch:
-         if batch.get("mode", None) == "icl_task":
-             batch["bidirectional_mask"] = batch["attention_mask"].clone()
-             for i, continuation_indices in enumerate(batch["continuation_indices"]):
-                 batch["bidirectional_mask"][i, continuation_indices] = 0
-         elif "labels" in batch and "attention_mask" in batch:
-             batch["bidirectional_mask"] = torch.logical_and(torch.eq(batch["attention_mask"], 1), torch.eq(batch["labels"], -100)).type_as(
-                 batch["attention_mask"]
-             )
-         else:
-             raise KeyError("No bidirectional_mask in batch and not sure how to construct one.")
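
A hedged sketch of how the converter and the mask conventions documented above fit together at training time; the checkpoint name, prefix length, and batch construction below are illustrative only, not taken from this repo's training code:

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    model = convert_hf_causal_lm_to_prefix_lm(GPT2LMHeadModel.from_pretrained("gpt2"))
    tok = GPT2Tokenizer.from_pretrained("gpt2")
    batch = tok(["Translate to French: Hello world"], return_tensors="pt")
    bidirectional_mask = torch.zeros_like(batch["input_ids"])
    bidirectional_mask[:, :5] = 1  # first 5 tokens form the bidirectionally visible prefix
    labels = batch["input_ids"].clone()
    labels[bidirectional_mask == 1] = -100  # no loss on the prefix, per the docstring above
    out = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        bidirectional_mask=bidirectional_mask,
        labels=labels,
    )
    out.loss.backward()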
mllm/flamingo/mpt/meta_init_context.py DELETED
@@ -1,98 +0,0 @@
- from contextlib import contextmanager
- import torch
- import torch.nn as nn
-
-
- @contextmanager
- def init_empty_weights(include_buffers: bool = False):
-     """Meta initialization context manager.
-
-     A context manager under which models are initialized with all parameters
-     on the meta device, therefore creating an empty model. Useful when just
-     initializing the model would blow the available RAM.
-
-     Args:
-         include_buffers (`bool`, *optional*, defaults to `False`): Whether or
-             not to also put all buffers on the meta device while initializing.
-
-     Example:
-     ```python
-     import torch.nn as nn
-
-     # Initialize a model with 100 billion parameters in no time and without using any RAM.
-     with init_empty_weights():
-         tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
-     ```
-
-     <Tip warning={true}>
-
-     Any model created under this context manager has no weights. As such you can't do something like
-     `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
-
-     </Tip>
-     """
-     with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
-         yield f
-
-
- @contextmanager
- def init_on_device(device: torch.device, include_buffers: bool = False):
-     """Device initialization context manager.
-
-     A context manager under which models are initialized with all parameters
-     on the specified device.
-
-     Args:
-         device (`torch.device`): Device to initialize all parameters on.
-         include_buffers (`bool`, *optional*, defaults to `False`): Whether or
-             not to also put all buffers on the specified device while initializing.
-
-     Example:
-     ```python
-     import torch.nn as nn
-
-     with init_on_device(device=torch.device("cuda")):
-         tst = nn.Linear(100, 100)  # on `cuda` device
-     ```
-     """
-     old_register_parameter = nn.Module.register_parameter
-     if include_buffers:
-         old_register_buffer = nn.Module.register_buffer
-
-     def register_empty_parameter(module, name, param):
-         old_register_parameter(module, name, param)
-         if param is not None:
-             param_cls = type(module._parameters[name])
-             kwargs = module._parameters[name].__dict__
-             module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
-
-     def register_empty_buffer(module, name, buffer):
-         old_register_buffer(module, name, buffer)
-         if buffer is not None:
-             module._buffers[name] = module._buffers[name].to(device)
-
-     if include_buffers:
-         tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ["empty", "zeros", "ones", "full"]}
-     else:
-         tensor_constructors_to_patch = {}
-
-     def patch_tensor_constructor(fn):
-         def wrapper(*args, **kwargs):
-             kwargs["device"] = device
-             return fn(*args, **kwargs)
-
-         return wrapper
-
-     try:
-         nn.Module.register_parameter = register_empty_parameter
-         if include_buffers:
-             nn.Module.register_buffer = register_empty_buffer
-         for torch_function_name in tensor_constructors_to_patch.keys():
-             setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
-         yield
-     finally:
-         nn.Module.register_parameter = old_register_parameter
-         if include_buffers:
-             nn.Module.register_buffer = old_register_buffer
-         for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
-             setattr(torch, torch_function_name, old_torch_function)
mllm/flamingo/mpt/modeling_mpt.py DELETED
@@ -1,496 +0,0 @@
- """A simple, flexible implementation of a GPT model.
-
- Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
- """
- import math
- import warnings
- from typing import List, Optional, Tuple, Union
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-
- from .attention import attn_bias_shape, build_attn_bias
- from .blocks import MPTBlock
- from .configuration_mpt import MPTConfig
- from .custom_embedding import SharedEmbedding
- from .norm import NORM_CLASS_REGISTRY
- from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
-
- import torch.distributed as dist
-
- try:
-     from .flash_attn_triton import flash_attn_func
- except:
-     pass
- Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
-
-
- class MPTPreTrainedModel(PreTrainedModel):
-     config_class = MPTConfig
-     base_model_prefix = "model"
-     _no_split_modules = ["MPTBlock"]
-
-
- class MPTModel(MPTPreTrainedModel):
-     def __init__(self, config: MPTConfig):
-         config._validate_config()
-         super().__init__(config)
-         self.attn_impl = config.attn_config["attn_impl"]
-         self.prefix_lm = config.attn_config["prefix_lm"]
-         self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
-         self.alibi = config.attn_config["alibi"]
-         self.alibi_bias_max = config.attn_config["alibi_bias_max"]
-         if config.init_device == "mixed":
-             if dist.get_local_rank() == 0:
-                 config.init_device = "cpu"
-             else:
-                 config.init_device = "meta"
-         if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
-             norm_options = " | ".join(NORM_CLASS_REGISTRY.keys())
-             raise NotImplementedError(f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).")
-         norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
-         self.embedding_fraction = config.embedding_fraction
-         self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
-         if not self.alibi:
-             self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
-         self.emb_drop = nn.Dropout(config.emb_pdrop)
-         self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
-         self.norm_f = norm_class(config.d_model, device=config.init_device)
-         if config.init_device != "meta":
-             print(
-                 f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.'
-             )
-             self.apply(self.param_init_fn)
-         self.is_causal = not self.prefix_lm
-         self._attn_bias_initialized = False
-         self.attn_bias = None
-         self.attn_bias_shape = attn_bias_shape(
-             self.attn_impl,
-             config.n_heads,
-             config.max_seq_len,
-             self.alibi,
-             prefix_lm=self.prefix_lm,
-             causal=self.is_causal,
-             use_sequence_id=self.attn_uses_sequence_id,
-         )
-         if config.no_bias:
-             for module in self.modules():
-                 if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
-                     if config.verbose:
-                         warnings.warn(f"Removing bias ({module.bias}) from {module}.")
-                     module.register_parameter("bias", None)
-         if config.verbose and config.verbose > 2:
-             print(self)
-         if "verbose" not in self.config.init_config:
-             self.config.init_config["verbose"] = self.config.verbose
-         if self.config.init_config["verbose"] > 1:
-             init_fn_name = self.config.init_config["name"]
-             warnings.warn(f"Using {init_fn_name} initialization.")
-
-     def get_input_embeddings(self):
-         return self.wte
-
-     def set_input_embeddings(self, value):
-         self.wte = value
-
-     @torch.no_grad()
-     def _attn_bias(
-         self,
-         device,
-         dtype,
-         attention_mask: Optional[torch.ByteTensor] = None,
-         prefix_mask: Optional[torch.ByteTensor] = None,
-         sequence_id: Optional[torch.LongTensor] = None,
-     ):
-         if not self._attn_bias_initialized:
-             if self.attn_bias_shape:
-                 self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
-                 self.attn_bias = build_attn_bias(
-                     self.attn_impl,
-                     self.attn_bias,
-                     self.config.n_heads,
-                     self.config.max_seq_len,
-                     causal=self.is_causal,
-                     alibi=self.alibi,
-                     alibi_bias_max=self.alibi_bias_max,
-                 )
-             self._attn_bias_initialized = True
-         if self.attn_impl == "flash":
-             return (self.attn_bias, attention_mask)
-         if self.attn_bias is not None:
-             self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
-         attn_bias = self.attn_bias
-         if self.prefix_lm:
-             assert isinstance(attn_bias, torch.Tensor)
-             assert isinstance(prefix_mask, torch.Tensor)
-             attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
-         if self.attn_uses_sequence_id and sequence_id is not None:
-             assert isinstance(attn_bias, torch.Tensor)
-             attn_bias = self._apply_sequence_id(attn_bias, sequence_id)
-         if attention_mask is not None:
-             s_k = attention_mask.shape[-1]
-             if attn_bias is None:
-                 attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
-             else:
-                 _s_k = max(0, attn_bias.size(-1) - s_k)
-                 attn_bias = attn_bias[:, :, :, _s_k:]
-             if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
-                 raise ValueError(f"attention_mask shape={attention_mask.shape} " + f"and prefix_mask shape={prefix_mask.shape} are not equal.")
-             min_val = torch.finfo(attn_bias.dtype).min
-             attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)
-         return (attn_bias, None)
-
-     def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
-         (s_k, s_q) = attn_bias.shape[-2:]
-         if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
-             raise ValueError(
-                 "attn_bias does not match the expected shape. "
-                 + f"The last two dimensions should both be {self.config.max_seq_len} "
-                 + f"but are {s_k} and {s_q}."
-             )
-         seq_len = prefix_mask.shape[-1]
-         if seq_len > self.config.max_seq_len:
-             raise ValueError(f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}")
-         attn_bias = attn_bias[..., :seq_len, :seq_len]
-         causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
-         prefix = prefix_mask.view(-1, 1, 1, seq_len)
-         cannot_attend = ~torch.logical_or(causal, prefix.bool())
-         min_val = torch.finfo(attn_bias.dtype).min
-         attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
-         return attn_bias
-
-     def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor):
-         seq_len = sequence_id.shape[-1]
-         if seq_len > self.config.max_seq_len:
-             raise ValueError(f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}")
-         attn_bias = attn_bias[..., :seq_len, :seq_len]
-         cannot_attend = torch.logical_not(torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))).unsqueeze(1)
-         min_val = torch.finfo(attn_bias.dtype).min
-         attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
-         return attn_bias
-
-     def forward(
-         self,
-         input_ids: torch.LongTensor,
-         past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
-         attention_mask: Optional[torch.ByteTensor] = None,
-         prefix_mask: Optional[torch.ByteTensor] = None,
-         sequence_id: Optional[torch.LongTensor] = None,
-         return_dict: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         use_cache: Optional[bool] = None,
-         inputs_embeds: Optional[torch.Tensor] = None,
-     ):
-         return_dict = return_dict if return_dict is not None else self.config.return_dict
-         use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-         if attention_mask is not None:
-             attention_mask = attention_mask.bool()
-
-         if prefix_mask is not None:
-             prefix_mask = prefix_mask.bool()
-
-         # These args are passed in by keyword in huggingface's generate function
-         # https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206
-         # but have not yet been fully implemented in MPTModel
-         if not return_dict:
-             raise NotImplementedError("return_dict False is not implemented yet for MPT")
-         if output_attentions:
-             if self.attn_impl != "torch":
-                 raise NotImplementedError("output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.")
-
-         if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
-             raise NotImplementedError("MPT does not support training with left padding.")
-
-         if self.prefix_lm and prefix_mask is None:
-             raise ValueError("prefix_mask is a required argument when MPT is configured with prefix_lm=True.")
-
-         # Raise a not implemented error if inputs_embeds is not None (this is an arg in huggingface transformers and we need to support it for PEFT)
-         if inputs_embeds is not None:
-             raise NotImplementedError("inputs_embeds is not implemented for MPT.")
-
-         if self.training:
-             if self.attn_uses_sequence_id and sequence_id is None:
-                 raise ValueError(
-                     "sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True " + "and the model is in train mode."
-                 )
-             elif (self.attn_uses_sequence_id is False) and (sequence_id is not None):
-                 warnings.warn(
-                     "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. "
-                     + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True."
-                 )
-
-         S = input_ids.size(1)
-
-         assert S <= self.config.max_seq_len, f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}"
-
-         tok_emb = self.wte(input_ids)  # type: ignore
-         if self.alibi:
-             x = tok_emb
-         else:
-             past_position = 0
-             if past_key_values is not None:
-                 if len(past_key_values) != self.config.n_layers:
-                     raise ValueError(
-                         f"past_key_values must provide a past_key_value for each attention "
-                         + f"layer in the network ({len(past_key_values)=}; {self.config.n_layers=})."
-                     )
-                 # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
-                 # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
-                 # Here we shift position embedding using the `seq` dim of the past key
-                 past_position = past_key_values[0][0].size(1)
-                 if self.attn_impl == "torch":
-                     past_position = past_key_values[0][0].size(3)
-
-             if S + past_position > self.config.max_seq_len:
-                 raise ValueError(
-                     f"Cannot forward input with past sequence length {past_position} and current sequence length "
-                     f"{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}."
-                 )
-             pos = torch.arange(
-                 past_position,
-                 S + past_position,
-                 dtype=torch.long,
-                 device=input_ids.device,
-             ).unsqueeze(0)
-             if attention_mask is not None:
-                 # adjust the position indices to account for padding tokens
-                 pos = torch.clamp(
-                     pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:],
-                     min=0,
-                 )
-
-             pos_emb = self.wpe(pos)  # type: ignore
-             x = tok_emb + pos_emb
-
-         if self.embedding_fraction == 1:
-             x = self.emb_drop(x)  # type: ignore
-         else:
-             # this implementation is proposed on page 7 of the GLM-130B paper https://arxiv.org/abs/2210.02414
-             x_shrunk = (x * self.embedding_fraction) + (x.detach() * (1 - self.embedding_fraction))
-             assert isinstance(self.emb_drop, nn.Module)  # pyright
-             x = self.emb_drop(x_shrunk)
-
-         attn_bias, attention_mask = self._attn_bias(
-             device=x.device,
-             dtype=torch.float32,
-             attention_mask=attention_mask,
-             prefix_mask=prefix_mask,
-             sequence_id=sequence_id,
-         )
-
-         # initialize the past key values cache if it should be used
-         if use_cache and past_key_values is None:
-             past_key_values = [() for _ in range(self.config.n_layers)]  # type: ignore
-
-         all_hidden_states = () if output_hidden_states else None
-         all_self_attns = () if output_attentions else None
-         for b_idx, block in enumerate(self.blocks):  # type: ignore
-             if output_hidden_states:
-                 assert all_hidden_states is not None  # pyright
-                 all_hidden_states = all_hidden_states + (x,)
-             past_key_value = past_key_values[b_idx] if past_key_values is not None else None
-             x, attn_weights, past_key_value = block(
-                 x,
-                 past_key_value=past_key_value,
-                 attn_bias=attn_bias,
-                 attention_mask=attention_mask,
-                 is_causal=self.is_causal,
-             )
-             if past_key_values is not None:
-                 past_key_values[b_idx] = past_key_value
-
-             if output_attentions:
-                 assert all_self_attns is not None  # pyright
-                 all_self_attns = all_self_attns + (attn_weights,)
-
-         x = self.norm_f(x)  # type: ignore
-
-         # add hidden states from the last decoder layer
-         if output_hidden_states:
-             assert all_hidden_states is not None  # pyright
-             all_hidden_states = all_hidden_states + (x,)
-
-         return BaseModelOutputWithPast(
-             last_hidden_state=x,
-             past_key_values=past_key_values,
-             hidden_states=all_hidden_states,
-             attentions=all_self_attns,
-         )
-
-     # Param Initialization, needed for device='meta' fast initialization
-     def param_init_fn(self, module):
-         init_fn_name = self.config.init_config["name"]
-         MODEL_INIT_REGISTRY[init_fn_name](
-             module=module,
-             n_layers=self.config.n_layers,
-             d_model=self.config.d_model,
-             **self.config.init_config,
-         )
-
-     def fsdp_wrap_fn(self, module):
-         return isinstance(module, MPTBlock)
-
-     def activation_checkpointing_fn(self, module):
-         return isinstance(module, MPTBlock)
-
-
- class MPTForCausalLM(MPTPreTrainedModel):
-     def __init__(self, config: MPTConfig):
-         super().__init__(config)
-         if not config.tie_word_embeddings:
-             raise ValueError("MPTForCausalLM only supports tied word embeddings")
-         self.transformer = MPTModel(config)
-         for child in self.transformer.children():
-             if isinstance(child, torch.nn.ModuleList):
-                 continue
-             if isinstance(child, torch.nn.Module):
-                 child._fsdp_wrap = True
-         self.logit_scale = None
-         if config.logit_scale is not None:
-             logit_scale = config.logit_scale
-             if isinstance(logit_scale, str):
-                 if logit_scale == "inv_sqrt_d_model":
-                     logit_scale = 1 / math.sqrt(config.d_model)
-                 else:
-                     raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
-             self.logit_scale = logit_scale
-
-     def get_input_embeddings(self):
-         return self.transformer.wte
-
-     def set_input_embeddings(self, value):
-         # self.transformer.wte = value
-         pseudo_wte = SharedEmbedding(value.weight.shape[0], value.weight.shape[1], device=self.transformer.wte.weight.device)
-         pseudo_wte.weight = value.weight
-         self.transformer.wte = pseudo_wte
-
-     def get_output_embeddings(self):
-         return self.transformer.wte
-
-     def set_output_embeddings(self, new_embeddings):
-         # self.transformer.wte = new_embeddings
-         pseudo_wte = SharedEmbedding(new_embeddings.weight.shape[0], new_embeddings.weight.shape[1], device=self.transformer.wte.weight.device)
-         pseudo_wte.weight = new_embeddings.weight
-         self.transformer.wte = pseudo_wte
-
-     def set_decoder(self, decoder):
-         self.transformer = decoder
-
-     def get_decoder(self):
-         return self.transformer
-
-     def forward(
-         self,
-         input_ids: torch.LongTensor,
-         past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
-         attention_mask: Optional[torch.ByteTensor] = None,
-         prefix_mask: Optional[torch.ByteTensor] = None,
-         sequence_id: Optional[torch.LongTensor] = None,
-         labels: Optional[torch.LongTensor] = None,
-         return_dict: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         use_cache: Optional[bool] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-     ):
-         return_dict = return_dict if return_dict is not None else self.config.return_dict
-         use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-         # if inputs_embeds is not None, raise a not implemented error
-         if inputs_embeds is not None:
-             raise NotImplementedError("inputs_embeds has to be None (for hf/peft support).")
-         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-         outputs = self.transformer(
-             input_ids=input_ids,
-             past_key_values=past_key_values,
-             attention_mask=attention_mask,
-             prefix_mask=prefix_mask,
-             sequence_id=sequence_id,
-             return_dict=return_dict,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             use_cache=use_cache,
-         )
-
-         # move outputs to same device as weights for token embedding
-         # needed to support HF `device_map`
-         logits = self.transformer.wte(
-             input=outputs.last_hidden_state.to(self.transformer.wte.weight.device),
-             unembed=True,
-         )
-
-         if self.logit_scale is not None:
-             if self.logit_scale == 0:
-                 warnings.warn(f"Multiplying logits by {self.logit_scale=}. This will produce uniform (uninformative) outputs.")
-             logits *= self.logit_scale
-
-         loss = None
-         if labels is not None:
-             _labels = torch.roll(labels, shifts=-1)
-             _labels[:, -1] = -100
-             loss = F.cross_entropy(
-                 logits.view(-1, logits.size(-1)),
-                 _labels.to(logits.device).view(-1),
-             )
-
-         return CausalLMOutputWithPast(
-             loss=loss,
-             logits=logits,
-             past_key_values=outputs.past_key_values,
-             hidden_states=outputs.hidden_states,
-             attentions=outputs.attentions,
-         )
-
-     def param_init_fn(self, module):
-         init_fn_name = self.config.init_config["name"]
-         MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
-
-     def fsdp_wrap_fn(self, module):
-         return isinstance(module, MPTBlock)
-
-     def activation_checkpointing_fn(self, module):
-         return isinstance(module, MPTBlock)
-
-     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None, **kwargs):
-         if inputs_embeds is not None:
-             raise NotImplementedError("inputs_embeds is not implemented for MPT yet")
-         attention_mask = attention_mask.bool()
-         if attention_mask[:, -1].sum() != attention_mask.shape[0]:
-             raise NotImplementedError("MPT does not support generation with right padding.")
-         if self.transformer.attn_uses_sequence_id and self.training:
-             sequence_id = torch.zeros_like(input_ids[:1])
-         else:
-             sequence_id = None
-         if past_key_values is not None:
-             input_ids = input_ids[:, -1].unsqueeze(-1)
-         if self.transformer.prefix_lm:
-             prefix_mask = torch.ones_like(attention_mask)
-             if kwargs.get("use_cache") == False:
-                 raise NotImplementedError("MPT with prefix_lm=True does not support use_cache=False.")
-         else:
-             prefix_mask = None
-         return {
-             "input_ids": input_ids,
-             "attention_mask": attention_mask,
-             "prefix_mask": prefix_mask,
-             "sequence_id": sequence_id,
-             "past_key_values": past_key_values,
-             "use_cache": kwargs.get("use_cache", True),
-         }
-
-     @staticmethod
-     def _reorder_cache(past_key_values, beam_idx):
-         """Used by HuggingFace generate when using beam search with kv-caching.
-
-         See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
-         for an example in transformers.
-         """
-         reordered_past = []
-         for layer_past in past_key_values:
-             reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
-         return reordered_past
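
A standalone illustration of the label handling in `MPTForCausalLM.forward`: `torch.roll(labels, shifts=-1)` plus masking the wrapped-around last position with -100 is equivalent to the usual `logits[:, :-1]` vs `labels[:, 1:]` shift. The toy sizes below are arbitrary:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 5, 11)        # (batch, seq, vocab)
    labels = torch.randint(0, 11, (2, 5))
    _labels = torch.roll(labels, shifts=-1)
    _labels[:, -1] = -100                 # wrapped position must not contribute loss
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), _labels.view(-1))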
mllm/flamingo/mpt/norm.py DELETED
@@ -1,60 +0,0 @@
- import torch
-
-
- def _cast_if_autocast_enabled(tensor):
-     if torch.is_autocast_enabled():
-         if tensor.device.type == "cuda":
-             dtype = torch.get_autocast_gpu_dtype()
-         elif tensor.device.type == "cpu":
-             dtype = torch.get_autocast_cpu_dtype()
-         else:
-             raise NotImplementedError()
-         return tensor.to(dtype=dtype)
-     return tensor
-
-
- class LPLayerNorm(torch.nn.LayerNorm):
-     def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
-         super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
-
-     def forward(self, x):
-         module_device = x.device
-         downcast_x = _cast_if_autocast_enabled(x)
-         downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
-         downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
-         with torch.autocast(enabled=False, device_type=module_device.type):
-             return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
-
-
- def rms_norm(x, weight=None, eps=1e-05):
-     output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
-     if weight is not None:
-         return output * weight
-     return output
-
-
- class RMSNorm(torch.nn.Module):
-     def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
-         super().__init__()
-         self.eps = eps
-         if weight:
-             self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))
-         else:
-             self.register_parameter("weight", None)
-
-     def forward(self, x):
-         return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
-
-
- class LPRMSNorm(RMSNorm):
-     def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
-         super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)
-
-     def forward(self, x):
-         downcast_x = _cast_if_autocast_enabled(x)
-         downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
-         with torch.autocast(enabled=False, device_type=x.device.type):
-             return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
-
-
- NORM_CLASS_REGISTRY = {"layernorm": torch.nn.LayerNorm, "low_precision_layernorm": LPLayerNorm, "rmsnorm": RMSNorm, "low_precision_rmsnorm": LPRMSNorm}
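The registry on the last line is how the rest of the package selects a norm implementation by config string; the lookup elsewhere in the modeling code (not shown in this hunk) is presumably along the lines of NORM_CLASS_REGISTRY[config.norm_type.lower()]. The LP* variants run the norm in the autocast dtype while disabling autocast inside the call, which avoids an fp32 upcast of the normalization under mixed precision. As a quick numeric sanity check on rms_norm, the x * torch.rsqrt(...) form is exactly division by the root-mean-square and leaves each row with approximately unit RMS:

import torch

x = torch.randn(2, 8)
eps = 1e-05
out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
ref = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
assert torch.allclose(out, ref)        # multiplying by rsqrt == dividing by sqrt
print(out.pow(2).mean(-1))             # ~1.0 per row after normalization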
mllm/flamingo/mpt/param_init_fns.py DELETED
@@ -1,369 +0,0 @@
- import math
- import warnings
- from collections.abc import Sequence
- from functools import partial
- from typing import Optional, Tuple, Union
- import torch
- from torch import nn
- from .norm import NORM_CLASS_REGISTRY
-
-
- def torch_default_param_init_fn_(module: nn.Module, verbose: int = 0, **kwargs):
-     del kwargs
-     if verbose > 1:
-         warnings.warn("Initializing network using module's reset_parameters attribute")
-     if hasattr(module, "reset_parameters"):
-         module.reset_parameters()
-
-
- def fused_init_helper_(module: nn.Module, init_fn_):
-     _fused = getattr(module, "_fused", None)
-     if _fused is None:
-         raise RuntimeError("Internal logic error")
-     (dim, splits) = _fused
-     splits = (0, *splits, module.weight.size(dim))
-     for s, e in zip(splits[:-1], splits[1:]):
-         slice_indices = [slice(None)] * module.weight.ndim
-         slice_indices[dim] = slice(s, e)
-         init_fn_(module.weight[slice_indices])
-
-
- def generic_param_init_fn_(
-     module: nn.Module,
-     init_fn_,
-     n_layers: int,
-     d_model: Optional[int] = None,
-     init_div_is_residual: Union[int, float, str, bool] = True,
-     emb_init_std: Optional[float] = None,
-     emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
-     verbose: int = 0,
-     **kwargs,
- ):
-     del kwargs
-     if verbose > 1:
-         warnings.warn("If model has bias parameters they are initialized to 0.")
-     if init_div_is_residual is False:
-         div_is_residual = 1.0
-     elif init_div_is_residual is True:
-         div_is_residual = math.sqrt(2 * n_layers)
-     elif isinstance(init_div_is_residual, (float, int)):
-         div_is_residual = init_div_is_residual
-     elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
-         div_is_residual = float(init_div_is_residual)
-     else:
-         raise ValueError(f"Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}")
-     if init_div_is_residual is not False:
-         if verbose > 1:
-             warnings.warn(
-                 f"Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. "
-                 "Set `init_div_is_residual: false` in init config to disable this."
-             )
-     if isinstance(module, nn.Linear):
-         if hasattr(module, "_fused"):
-             fused_init_helper_(module, init_fn_)
-         else:
-             init_fn_(module.weight)
-         if module.bias is not None:
-             torch.nn.init.zeros_(module.bias)
-         if init_div_is_residual is not False and getattr(module, "_is_residual", False):
-             with torch.no_grad():
-                 module.weight.div_(div_is_residual)
-     elif isinstance(module, nn.Embedding):
-         if emb_init_std is not None:
-             std = emb_init_std
-             if std == 0:
-                 warnings.warn("Embedding layer initialized to 0.")
-             emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
-             if verbose > 1:
-                 warnings.warn(f"Embedding layer initialized using normal distribution with mean=0 and std={std!r}.")
-         elif emb_init_uniform_lim is not None:
-             lim = emb_init_uniform_lim
-             if isinstance(lim, Sequence):
-                 if len(lim) > 2:
-                     raise ValueError(f"Uniform init requires a min and a max limit. User input: {lim}.")
-                 if lim[0] == lim[1]:
-                     warnings.warn(f"Embedding layer initialized to {lim[0]}.")
-             else:
-                 if lim == 0:
-                     warnings.warn("Embedding layer initialized to 0.")
-                 lim = [-lim, lim]
-             (a, b) = lim
-             emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
-             if verbose > 1:
-                 warnings.warn(f"Embedding layer initialized using uniform distribution in range {lim}.")
-         else:
-             emb_init_fn_ = init_fn_
-         emb_init_fn_(module.weight)
-     elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
-         if verbose > 1:
-             warnings.warn("Norm weights are set to 1. If norm layer has a bias it is initialized to 0.")
-         if hasattr(module, "weight") and module.weight is not None:
-             torch.nn.init.ones_(module.weight)
-         if hasattr(module, "bias") and module.bias is not None:
-             torch.nn.init.zeros_(module.bias)
-     elif isinstance(module, nn.MultiheadAttention):
-         if module._qkv_same_embed_dim:
-             assert module.in_proj_weight is not None
-             assert module.q_proj_weight is None and module.k_proj_weight is None and module.v_proj_weight is None
-             assert d_model is not None
-             _d = d_model
-             splits = (0, _d, 2 * _d, 3 * _d)
-             for s, e in zip(splits[:-1], splits[1:]):
-                 init_fn_(module.in_proj_weight[s:e])
-         else:
-             assert module.q_proj_weight is not None and module.k_proj_weight is not None and module.v_proj_weight is not None
-             assert module.in_proj_weight is None
-             init_fn_(module.q_proj_weight)
-             init_fn_(module.k_proj_weight)
-             init_fn_(module.v_proj_weight)
-         if module.in_proj_bias is not None:
-             torch.nn.init.zeros_(module.in_proj_bias)
-         if module.bias_k is not None:
-             torch.nn.init.zeros_(module.bias_k)
-         if module.bias_v is not None:
-             torch.nn.init.zeros_(module.bias_v)
-         init_fn_(module.out_proj.weight)
-         if init_div_is_residual is not False and getattr(module.out_proj, "_is_residual", False):
-             with torch.no_grad():
-                 module.out_proj.weight.div_(div_is_residual)
-         if module.out_proj.bias is not None:
-             torch.nn.init.zeros_(module.out_proj.bias)
-     else:
-         for _ in module.parameters(recurse=False):
-             raise NotImplementedError(f"{module.__class__.__name__} parameters are not initialized by param_init_fn.")
-
-
- def _normal_init_(std, mean=0.0):
-     return partial(torch.nn.init.normal_, mean=mean, std=std)
-
-
- def _normal_param_init_fn_(
-     module: nn.Module,
-     std: float,
-     n_layers: int,
-     d_model: Optional[int] = None,
-     init_div_is_residual: Union[int, float, str, bool] = True,
-     emb_init_std: Optional[float] = None,
-     emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
-     verbose: int = 0,
-     **kwargs,
- ):
-     del kwargs
-     init_fn_ = _normal_init_(std=std)
-     if verbose > 1:
-         warnings.warn(f"Using torch.nn.init.normal_ init fn mean=0.0, std={std}")
-     generic_param_init_fn_(
-         module=module,
-         init_fn_=init_fn_,
-         d_model=d_model,
-         n_layers=n_layers,
-         init_div_is_residual=init_div_is_residual,
-         emb_init_std=emb_init_std,
-         emb_init_uniform_lim=emb_init_uniform_lim,
-         verbose=verbose,
-     )
-
-
- def baseline_param_init_fn_(
-     module: nn.Module,
-     init_std: float,
-     n_layers: int,
-     d_model: Optional[int] = None,
-     init_div_is_residual: Union[int, float, str, bool] = True,
-     emb_init_std: Optional[float] = None,
-     emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
-     verbose: int = 0,
-     **kwargs,
- ):
-     del kwargs
-     if init_std is None:
-         raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
-     _normal_param_init_fn_(
-         module=module,
-         std=init_std,
-         d_model=d_model,
-         n_layers=n_layers,
-         init_div_is_residual=init_div_is_residual,
-         emb_init_std=emb_init_std,
-         emb_init_uniform_lim=emb_init_uniform_lim,
-         verbose=verbose,
-     )
-
-
- def small_param_init_fn_(
-     module: nn.Module,
-     n_layers: int,
-     d_model: int,
-     init_div_is_residual: Union[int, float, str, bool] = True,
-     emb_init_std: Optional[float] = None,
-     emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
-     verbose: int = 0,
-     **kwargs,
- ):
-     del kwargs
-     std = math.sqrt(2 / (5 * d_model))
-     _normal_param_init_fn_(
-         module=module,
-         std=std,
-         d_model=d_model,
-         n_layers=n_layers,
-         init_div_is_residual=init_div_is_residual,
-         emb_init_std=emb_init_std,
-         emb_init_uniform_lim=emb_init_uniform_lim,
-         verbose=verbose,
-     )
-
-
- def neox_param_init_fn_(
-     module: nn.Module,
-     n_layers: int,
-     d_model: int,
-     emb_init_std: Optional[float] = None,
-     emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
-     verbose: int = 0,
-     **kwargs,
- ):
-     """From section 2.3.1 of GPT-NeoX-20B:
-
-     An Open-Source Autoregressive Language Model, Black et al. (2022);
-     see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
-     and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
-     """
-     del kwargs
-     residual_div = n_layers / math.sqrt(10)
-     if verbose > 1:
-         warnings.warn(f"setting init_div_is_residual to {residual_div}")
-     small_param_init_fn_(
-         module=module,
-         d_model=d_model,
-         n_layers=n_layers,
-         init_div_is_residual=residual_div,
-         emb_init_std=emb_init_std,
-         emb_init_uniform_lim=emb_init_uniform_lim,
-         verbose=verbose,
-     )
-
-
- def kaiming_uniform_param_init_fn_(
-     module: nn.Module,
-     n_layers: int,
-     d_model: Optional[int] = None,
-     init_div_is_residual: Union[int, float, str, bool] = True,
-     emb_init_std: Optional[float] = None,
-     emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
-     init_gain: float = 0,
-     fan_mode: str = "fan_in",
-     init_nonlinearity: str = "leaky_relu",
-     verbose: int = 0,
-     **kwargs,
- ):
-     del kwargs
-     if verbose > 1:
-         warnings.warn(f"Using nn.init.kaiming_uniform_ init fn with parameters: a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}")
-     kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
-     generic_param_init_fn_(
-         module=module,
-         init_fn_=kaiming_uniform_,
-         d_model=d_model,
-         n_layers=n_layers,
-         init_div_is_residual=init_div_is_residual,
-         emb_init_std=emb_init_std,
-         emb_init_uniform_lim=emb_init_uniform_lim,
-         verbose=verbose,
-     )
-
-
- def kaiming_normal_param_init_fn_(
-     module: nn.Module,
-     n_layers: int,
-     d_model: Optional[int] = None,
-     init_div_is_residual: Union[int, float, str, bool] = True,
-     emb_init_std: Optional[float] = None,
-     emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
-     init_gain: float = 0,
-     fan_mode: str = "fan_in",
-     init_nonlinearity: str = "leaky_relu",
-     verbose: int = 0,
-     **kwargs,
- ):
-     del kwargs
-     if verbose > 1:
-         warnings.warn(f"Using nn.init.kaiming_normal_ init fn with parameters: a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}")
-     kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
-     generic_param_init_fn_(
-         module=module,
-         init_fn_=kaiming_normal_,
-         d_model=d_model,
-         n_layers=n_layers,
-         init_div_is_residual=init_div_is_residual,
-         emb_init_std=emb_init_std,
-         emb_init_uniform_lim=emb_init_uniform_lim,
-         verbose=verbose,
-     )
-
-
- def xavier_uniform_param_init_fn_(
-     module: nn.Module,
-     n_layers: int,
-     d_model: Optional[int] = None,
-     init_div_is_residual: Union[int, float, str, bool] = True,
-     emb_init_std: Optional[float] = None,
-     emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
-     init_gain: float = 0,
-     verbose: int = 0,
-     **kwargs,
- ):
-     del kwargs
-     xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
-     if verbose > 1:
-         warnings.warn(f"Using torch.nn.init.xavier_uniform_ init fn with parameters: gain={init_gain}")
-     generic_param_init_fn_(
-         module=module,
-         init_fn_=xavier_uniform_,
-         d_model=d_model,
-         n_layers=n_layers,
-         init_div_is_residual=init_div_is_residual,
-         emb_init_std=emb_init_std,
-         emb_init_uniform_lim=emb_init_uniform_lim,
-         verbose=verbose,
-     )
-
-
- def xavier_normal_param_init_fn_(
-     module: nn.Module,
-     n_layers: int,
-     d_model: Optional[int] = None,
-     init_div_is_residual: Union[int, float, str, bool] = True,
-     emb_init_std: Optional[float] = None,
-     emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
-     init_gain: float = 0,
-     verbose: int = 0,
-     **kwargs,
- ):
-     del kwargs
-     xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
-     if verbose > 1:
-         warnings.warn(f"Using torch.nn.init.xavier_normal_ init fn with parameters: gain={init_gain}")
-     generic_param_init_fn_(
-         module=module,
-         init_fn_=xavier_normal_,
-         d_model=d_model,
-         n_layers=n_layers,
-         init_div_is_residual=init_div_is_residual,
-         emb_init_std=emb_init_std,
-         emb_init_uniform_lim=emb_init_uniform_lim,
-         verbose=verbose,
-     )
-
-
- MODEL_INIT_REGISTRY = {
-     "default_": torch_default_param_init_fn_,
-     "baseline_": baseline_param_init_fn_,
-     "kaiming_uniform_": kaiming_uniform_param_init_fn_,
-     "kaiming_normal_": kaiming_normal_param_init_fn_,
-     "neox_init_": neox_param_init_fn_,
-     "small_init_": small_param_init_fn_,
-     "xavier_uniform_": xavier_uniform_param_init_fn_,
-     "xavier_normal_": xavier_normal_param_init_fn_,
- }
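MODEL_INIT_REGISTRY is consumed by param_init_fn in the modeling_mpt.py hunk earlier in this diff, which looks up init_config["name"] and calls the matching entry on every submodule. The same dispatch can be reproduced standalone with nn.Module.apply; in the sketch below the tiny model and dimensions are hypothetical, and small_init_ just re-derives the std = sqrt(2 / (5 * d_model)) rule from small_param_init_fn_ above for the Linear case only:

import math
from functools import partial

from torch import nn

d_model, n_layers = 16, 2
model = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU(), nn.Linear(d_model, d_model))

def small_init_(module, n_layers, d_model, **kwargs):
    # Mirrors small_param_init_fn_ for nn.Linear; other module types are skipped.
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2 / (5 * d_model)))
        if module.bias is not None:
            nn.init.zeros_(module.bias)

# apply() visits every submodule, the same traversal param_init_fn relies on.
model.apply(partial(small_init_, n_layers=n_layers, d_model=d_model))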
mllm/flamingo/mpt_redpajama/__init__.py DELETED
File without changes
mllm/flamingo/mpt_redpajama/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (216 Bytes)