update mllm
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- mllm/flamingo/__init__.py +48 -0
- mllm/flamingo/config.json +21 -0
- mllm/flamingo/configuration_flamingo.py +100 -0
- mllm/flamingo/converting_flamingo_to_bf16.py +30 -0
- mllm/flamingo/converting_flamingo_to_hf.py +61 -0
- mllm/flamingo/converting_flamingo_to_lora.py +68 -0
- mllm/flamingo/falcon/__init__.py +0 -0
- mllm/flamingo/falcon/__pycache__/__init__.cpython-39.pyc +0 -0
- mllm/flamingo/falcon/__pycache__/configuration_RW.cpython-39.pyc +0 -0
- mllm/flamingo/falcon/__pycache__/modelling_RW.cpython-39.pyc +0 -0
- mllm/flamingo/falcon/configuration_RW.py +79 -0
- mllm/flamingo/falcon/modelling_RW.py +1064 -0
- mllm/flamingo/flamingo-falcon-7B.json +112 -0
- mllm/flamingo/flamingo-llama2-chat-13B.json +114 -0
- mllm/flamingo/flamingo-llama2-chat-7B.json +115 -0
- mllm/flamingo/flamingo-mpt-1B-redpajama.json +131 -0
- mllm/flamingo/flamingo-mpt-30B-bf16.json +195 -0
- mllm/flamingo/flamingo-mpt-30B.json +195 -0
- mllm/flamingo/flamingo-mpt-7B.json +195 -0
- mllm/flamingo/flamingo-vicuna-33B-v1.3.json +111 -0
- mllm/flamingo/flamingo-vicuna-7B-v1.3.json +111 -0
- mllm/flamingo/injecting_falcon_into_flamingo.py +49 -0
- mllm/flamingo/injecting_llama2_into_flamingo.py +95 -0
- mllm/flamingo/injecting_mpt-1B-redpajama_into_flamingo.py +97 -0
- mllm/flamingo/injecting_mpt_into_flamingo.py +109 -0
- mllm/flamingo/injecting_vicuna_into_flamingo.py +100 -0
- mllm/flamingo/modeling_flamingo.py +966 -0
- mllm/flamingo/mpt/__init__.py +0 -0
- mllm/flamingo/mpt/__pycache__/__init__.cpython-39.pyc +0 -0
- mllm/flamingo/mpt/__pycache__/attention.cpython-39.pyc +0 -0
- mllm/flamingo/mpt/__pycache__/blocks.cpython-39.pyc +0 -0
- mllm/flamingo/mpt/__pycache__/configuration_mpt.cpython-39.pyc +0 -0
- mllm/flamingo/mpt/__pycache__/custom_embedding.cpython-39.pyc +0 -0
- mllm/flamingo/mpt/__pycache__/flash_attn_triton.cpython-39.pyc +0 -0
- mllm/flamingo/mpt/__pycache__/modeling_mpt.cpython-39.pyc +0 -0
- mllm/flamingo/mpt/__pycache__/norm.cpython-39.pyc +0 -0
- mllm/flamingo/mpt/__pycache__/param_init_fns.cpython-39.pyc +0 -0
- mllm/flamingo/mpt/adapt_tokenizer.py +44 -0
- mllm/flamingo/mpt/attention.py +450 -0
- mllm/flamingo/mpt/blocks.py +82 -0
- mllm/flamingo/mpt/configuration_mpt.py +161 -0
- mllm/flamingo/mpt/custom_embedding.py +11 -0
- mllm/flamingo/mpt/flash_attn_triton.py +841 -0
- mllm/flamingo/mpt/hf_prefixlm_converter.py +575 -0
- mllm/flamingo/mpt/meta_init_context.py +98 -0
- mllm/flamingo/mpt/modeling_mpt.py +496 -0
- mllm/flamingo/mpt/norm.py +60 -0
- mllm/flamingo/mpt/param_init_fns.py +369 -0
- mllm/flamingo/mpt_redpajama/__init__.py +0 -0
- mllm/flamingo/mpt_redpajama/__pycache__/__init__.cpython-39.pyc +0 -0
mllm/flamingo/__init__.py
ADDED
@@ -0,0 +1,48 @@
from typing import TYPE_CHECKING

from transformers.utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
)


_import_structure = {
    "configuration_flamingo": [
        "FlamingoConfig",
    ],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_flamingo"] = [
        "FlamingoModel",
        "FlamingoPreTrainedModel",
        "FlamingoForConditionalGeneration",
    ]

if TYPE_CHECKING:
    from .configuration_flamingo import FlamingoConfig

    # from .processing_flamingo import FlamingoProcessor

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_flamingo import (
            FlamingoForConditionalGeneration,
            FlamingoModel,
            FlamingoPreTrainedModel,
        )

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
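This is the standard transformers lazy-module pattern: importing the package only registers `_import_structure`, and the torch-dependent modeling module is loaded on first attribute access. A minimal sketch of the resulting behavior, assuming the package is importable as `mllm.flamingo` (the import path is an assumption, not stated in this diff):

import importlib

# Importing the package is cheap: _LazyModule defers loading modeling_flamingo.
flamingo = importlib.import_module("mllm.flamingo")

# First attribute access triggers the real import (and the torch-availability
# check wired up above); after this, FlamingoConfig is a normal class.
FlamingoConfig = flamingo.FlamingoConfig
print(FlamingoConfig.model_type)  # "flamingo"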
mllm/flamingo/config.json
ADDED
@@ -0,0 +1,21 @@
{
    "model_type": "flamingo",
    "cross_attn_every_n_layers": 4,
    "tie_word_embeddings": false,
    "use_media_placement_augmentation": true,
    "only_attend_previous": true,
    "text_config": {
        "_name_or_path": "luodian/llama-7b-hf",
        "model_type": "llama"
    },
    "vision_config": {
        "_name_or_path": "openai/clip-vit-large-patch14",
        "model_type": "clip_vision_model",
        "hidden_size": 1024,
        "intermediate_size": 4096,
        "num_attention_heads": 16,
        "num_hidden_layers": 24,
        "image_size": 224,
        "patch_size": 14
    }
}
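For reference, this default config can be loaded back through `FlamingoConfig`; a short sketch, assuming the JSON sits in a local directory (directory name hypothetical). Since the nested `text_config` carries no `architectures` key, it is dispatched through transformers' `CONFIG_MAPPING` on `model_type`:

from mllm.flamingo import FlamingoConfig  # import path is an assumption

# "checkpoints/flamingo-base" is a hypothetical directory holding config.json.
config = FlamingoConfig.from_pretrained("checkpoints/flamingo-base")
print(type(config.text_config).__name__)  # LlamaConfig (via CONFIG_MAPPING["llama"])
print(config.vision_config.image_size, config.vision_config.patch_size)  # 224 14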
mllm/flamingo/configuration_flamingo.py
ADDED
@@ -0,0 +1,100 @@
import copy

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

from transformers.models.auto import CONFIG_MAPPING
from transformers.models.clip import CLIPVisionConfig
import sys

from .falcon.configuration_RW import RWConfig
from .mpt.configuration_mpt import MPTConfig
from .mpt_redpajama.configuration_mosaic_gpt import MosaicGPTConfig


logger = logging.get_logger(__name__)


class FlamingoConfig(PretrainedConfig):
    r"""
    [`FlamingoConfig`] is the configuration class to store the configuration of a [`FlamingoForConditionalGeneration`]. It is
    used to instantiate a Flamingo model according to the specified arguments, defining the vision model and language model configs. Instantiating a configuration with the defaults will yield a similar configuration to
    that of the Flamingo architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`PretrainedConfig`].
        text_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize any [`PretrainedConfig`].
        cross_attn_every_n_layers (`int`, *optional*, defaults to 4):
            Insert a cross-attention layer after every `cross_attn_every_n_layers` language-model layers.

        kwargs (*optional*):
            Dictionary of keyword arguments.

    Example:

    ```python
    >>> from transformers import (
    ...     PretrainedConfig,
    ...     OPTConfig,
    ...     FlamingoConfig,
    ...     FlamingoForConditionalGeneration,
    ... )

    >>> # Initializing a FlamingoConfig with Salesforce/Flamingo-opt-2.7b style configuration
    >>> configuration = FlamingoConfig()

    >>> # Initializing a FlamingoForConditionalGeneration (with random weights) from the Salesforce/Flamingo-opt-2.7b style configuration
    >>> model = FlamingoForConditionalGeneration(configuration)
    ```"""

    model_type = "flamingo"
    is_composition = True

    def __init__(self, vision_config=None, text_config=None, cross_attn_every_n_layers: int = 4, use_media_placement_augmentation: bool = True, **kwargs):
        super().__init__(**kwargs)
        if vision_config is None:
            vision_config = {}
            logger.info("vision_config is None. Initializing the vision config with default values.")

        if text_config is None:
            text_config = {}
            logger.info("text_config is None. Initializing the text config with default values.")

        self.vision_config = CLIPVisionConfig(**vision_config)
        if "architectures" in text_config.keys() and text_config["architectures"] is not None:
            if text_config["architectures"][0] == "MPTForCausalLM":
                self.text_config = MPTConfig(**text_config)
            elif text_config["architectures"][0] == "MosaicGPT":
                self.text_config = MosaicGPTConfig(**text_config)
            elif text_config["architectures"][0] == "RWForCausalLM":
                self.text_config = RWConfig(**text_config)
            elif text_config["architectures"][0] == "LlamaForCausalLM":
                self.text_config = CONFIG_MAPPING[text_config.pop("model_type")](**text_config)
            else:
                # unrecognized architecture: drop into the debugger
                import pdb

                pdb.set_trace()
        else:
            self.text_config = CONFIG_MAPPING[text_config.pop("model_type")](**text_config)

        self.cross_attn_every_n_layers = cross_attn_every_n_layers
        self.use_media_placement_augmentation = use_media_placement_augmentation

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        output["vision_config"] = self.vision_config.to_dict()
        output["text_config"] = self.text_config.to_dict()
        output["model_type"] = self.__class__.model_type
        output["cross_attn_every_n_layers"] = self.cross_attn_every_n_layers
        output["use_media_placement_augmentation"] = self.use_media_placement_augmentation
        return output
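A sketch of how the `architectures`-based dispatch in `__init__` behaves (the import path is assumed, not stated in this diff): a text config naming `MPTForCausalLM` is routed to `MPTConfig`, while one without `architectures` falls through to `CONFIG_MAPPING`.

from mllm.flamingo.configuration_flamingo import FlamingoConfig  # path assumed

# No "architectures" key: the else-branch pops "model_type" and builds a
# LlamaConfig through transformers' CONFIG_MAPPING.
config = FlamingoConfig(
    vision_config={"image_size": 224, "patch_size": 14},
    text_config={"model_type": "llama"},
    cross_attn_every_n_layers=4,
)
print(type(config.text_config).__name__)  # LlamaConfig

# With "architectures" set, the first entry selects the config class directly.
mpt_config = FlamingoConfig(text_config={"model_type": "mpt", "architectures": ["MPTForCausalLM"]})
print(type(mpt_config.text_config).__name__)  # MPTConfig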
mllm/flamingo/converting_flamingo_to_bf16.py
ADDED
@@ -0,0 +1,30 @@
import argparse
import os

import torch

from .configuration_flamingo import FlamingoConfig
from .modeling_flamingo import FlamingoForConditionalGeneration

parser = argparse.ArgumentParser(description="Load model with precision")
parser.add_argument("--load_bit", type=str, choices=["fp16", "bf16"], required=True, help="Choose either 'fp16' or 'bf16'")
parser.add_argument("--pretrained_model_path", type=str, default="/home/luodian/projects/checkpoints/flamingo-mpt-7B-instruct-init", required=True)
parser.add_argument("--saved_model_path", type=str, default="/home/luodian/projects/checkpoints/flamingo-mpt-7B-instruct-init", required=True)
args = parser.parse_args()

load_bit = args.load_bit
pretrained_model_path = args.pretrained_model_path

if load_bit == "fp16":
    precision = {"torch_dtype": torch.float16}
elif load_bit == "bf16":
    precision = {"torch_dtype": torch.bfloat16}

root_dir = os.environ["AZP"]
print(root_dir)
device_id = "cpu"
model = FlamingoForConditionalGeneration.from_pretrained(pretrained_model_path, device_map={"": device_id}, **precision)

# save model to same folder
checkpoint_path = pretrained_model_path + f"-{load_bit}"
model.save_pretrained(checkpoint_path, max_shard_size="10GB")
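Note that `--saved_model_path` is parsed but never used: the output directory is always the input path suffixed with the dtype. The relative imports mean the script must run as a module with the repo root on the path, and it reads the `AZP` environment variable unconditionally. A hypothetical invocation (module path and checkpoint paths are placeholders):

# AZP=/tmp python -m mllm.flamingo.converting_flamingo_to_bf16 \
#     --load_bit bf16 \
#     --pretrained_model_path /checkpoints/flamingo-mpt-7B \
#     --saved_model_path /checkpoints/unused
#
# -> writes the converted shards to /checkpoints/flamingo-mpt-7B-bf16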
mllm/flamingo/converting_flamingo_to_hf.py
ADDED
@@ -0,0 +1,61 @@
"""Convert an Otter .pt checkpoint to an Otter HF model. Will be removed once
training uses the HF model directly.
"""

import re
import argparse
import os

import torch
import torch.nn as nn
from transformers import CLIPVisionModel, LlamaForCausalLM, LlamaTokenizer

import sys
from modeling_flamingo import FlamingoForConditionalGeneration

from configuration_flamingo import FlamingoConfig


@torch.no_grad()
def dump_hf_model(pretrained_model_path: str, old_ckpt_path: str, new_folder_path: str) -> None:
    old_ckpt = torch.load(old_ckpt_path, map_location="cpu")
    if old_ckpt.get("model_state_dict", None) is not None:
        old_ckpt = old_ckpt["model_state_dict"]
    new_ckpt = old_ckpt
    folder_path = os.path.dirname(old_ckpt_path)
    # config_path = os.path.join(folder_path, "config.json") if os.path.exists(os.path.join(folder_path, "config.json")) else "flamingo/config.json"
    model = FlamingoForConditionalGeneration.from_pretrained(
        pretrained_model_path,
        device_map="auto",
    )
    _ = model.load_state_dict(new_ckpt, strict=False)
    print(f"Saving HF model to {new_folder_path}")
    model.save_pretrained(new_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--old_ckpt_path",
        "-old",
        type=str,
        required=True,
        help="Path to the pt checkpoint",
    )
    parser.add_argument(
        "--new_hf_path",
        "-new",
        type=str,
        required=True,
        help="Path to the hf folder",
    )
    parser.add_argument(
        "--pretrained_model_path",
        "-pretrained",
        type=str,
        required=True,
        help="Path to the pretrained model folder",
    )
    args = parser.parse_args()
    if not os.path.exists(os.path.dirname(args.new_hf_path)):
        os.makedirs(os.path.dirname(args.new_hf_path))
    dump_hf_model(args.pretrained_model_path, args.old_ckpt_path, args.new_hf_path)
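Unlike the bf16 converter, this script imports `modeling_flamingo` without a package prefix, so it is meant to be launched from inside `mllm/flamingo`. A hypothetical invocation (paths are placeholders):

# cd mllm/flamingo
# python converting_flamingo_to_hf.py \
#     -old /checkpoints/otter/final_weights.pt \
#     -new /checkpoints/otter-hf \
#     -pretrained /checkpoints/flamingo-mpt-7B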
mllm/flamingo/converting_flamingo_to_lora.py
ADDED
@@ -0,0 +1,68 @@
import argparse
import torch
import sys

from .modeling_flamingo import FlamingoForConditionalGeneration
from peft import get_peft_model, LoraConfig, TaskType

MODEL_CLASSES = {
    "LlamaForCausalLM": "llama",
    "OPTForCausalLM": "opt",
    "GPTJForCausalLM": "gptj",
    "GPTNeoXForCausalLM": "gpt_neox",
    "MPTForCausalLM": "mpt",
}

# Define argument parser
parser = argparse.ArgumentParser(description="Load a model with specified precision and save it to a specified path.")

# Add arguments
parser.add_argument(
    "--checkpoint_path",
    type=str,
    help="Path to the pre-trained model checkpoint.",
    default="",
)
parser.add_argument(
    "--save_path",
    type=str,
    default="",
    help="Path to the converted model checkpoint.",
)

# Parse the input arguments
args = parser.parse_args()

load_bit = "bf16"
if load_bit == "fp16":
    precision = {"torch_dtype": torch.float16}
elif load_bit == "bf16":
    precision = {"torch_dtype": torch.bfloat16}

# Load the model
model = FlamingoForConditionalGeneration.from_pretrained(args.checkpoint_path, device_map="auto", **precision)

# adding lora
standard_modules = ["q_proj", "v_proj"]
lang_encoder_short_name = MODEL_CLASSES[model.config.text_config.architectures[0]]
model_to_lora_modules = {
    "llama": standard_modules,
    "opt": standard_modules,
    "gptj": standard_modules,
    "gpt_neox": ["query_key_value"],
    "mpt": ["Wqkv"],
}
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    task_type=TaskType.CAUSAL_LM,
    target_modules=model_to_lora_modules[lang_encoder_short_name],
)
model.config.update({"lora_config": {"r": 16, "lora_alpha": 32, "lora_dropout": 0.05}})
model.lang_encoder = get_peft_model(model.lang_encoder, lora_config)
model.lang_encoder.print_trainable_parameters()

# Save the model
checkpoint_path = args.save_path
model.save_pretrained(checkpoint_path)
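The target-module table encodes where each backbone keeps its attention projections; a small sketch of the lookup, reusing the `MODEL_CLASSES` and `model_to_lora_modules` dicts defined in the script above:

# MPT fuses Q, K and V into a single `Wqkv` projection, so LoRA attaches
# there rather than to the separate q_proj/v_proj used by llama-style models.
arch = "MPTForCausalLM"                      # e.g. model.config.text_config.architectures[0]
short_name = MODEL_CLASSES[arch]             # "mpt"
print(model_to_lora_modules[short_name])     # ['Wqkv']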
mllm/flamingo/falcon/__init__.py
ADDED
File without changes

mllm/flamingo/falcon/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (209 Bytes)

mllm/flamingo/falcon/__pycache__/configuration_RW.cpython-39.pyc
ADDED
Binary file (1.86 kB)

mllm/flamingo/falcon/__pycache__/modelling_RW.cpython-39.pyc
ADDED
Binary file (28.5 kB)
mllm/flamingo/falcon/configuration_RW.py
ADDED
@@ -0,0 +1,79 @@
# coding=utf-8
# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RW (Falcon) configuration, adapted from the Bloom configuration."""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)


class RWConfig(PretrainedConfig):
    model_type = "RefinedWebModel"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "num_hidden_layers": "n_layer",
        "num_attention_heads": "n_head",
    }

    def __init__(
        self,
        vocab_size=250880,
        hidden_size=64,
        n_layer=2,
        n_head=8,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
        bos_token_id=1,
        eos_token_id=2,
        apply_residual_connection_post_layernorm=False,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        multi_query=False,
        alibi=False,
        bias=False,
        parallel_attn=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        # Backward compatibility with n_embed kwarg
        n_embed = kwargs.pop("n_embed", None)
        self.hidden_size = hidden_size if n_embed is None else n_embed
        self.n_layer = n_layer
        self.n_head = n_head
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.multi_query = multi_query
        self.alibi = alibi
        self.bias = bias
        self.parallel_attn = parallel_attn

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

    @property
    def head_dim(self):
        return self.hidden_size // self.n_head

    @property
    def rotary(self):
        return not self.alibi
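The two properties at the end derive everything the attention code needs from `hidden_size`, `n_head`, and `alibi`. A quick sketch with Falcon-7B-like values (4544 hidden units, 71 heads; these numbers are illustrative, not part of this diff):

from mllm.flamingo.falcon.configuration_RW import RWConfig  # path assumed

cfg = RWConfig(hidden_size=4544, n_head=71, n_layer=32)
print(cfg.head_dim)  # 64  (4544 // 71)
print(cfg.rotary)    # True: rotary embeddings are used whenever alibi is False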
mllm/flamingo/falcon/modelling_RW.py
ADDED
@@ -0,0 +1,1064 @@
# port of models described in RW
# We use the bloom model as a starting point for these models.
# Please refer to the bloom models for usage instructions.

import math
import warnings
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from torch.nn import functional as F

from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from .configuration_RW import RWConfig

logger = logging.get_logger(__name__)


# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
# In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
class Linear(nn.Linear):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        ret = input @ self.weight.T
        if self.bias is None:
            return ret
        else:
            return ret + self.bias


from einops import rearrange


# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in torch < 1.8.0


class RotaryEmbedding(torch.nn.Module):
    """Implementation of RotaryEmbedding from GPT-NeoX.
    This implementation is designed to operate on queries and keys that are compatible with
    [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format).
    """

    def __init__(
        self,
        head_dim: int,
        base=10000,
    ):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.head_dim = head_dim
        self.seq_len_cached = None
        self.batch_size_cached = None
        self.cos_cached: torch.Tensor | None = None
        self.sin_cached: torch.Tensor | None = None

    def cos_sin(
        self,
        seq_len: int,
        device="cuda",
        dtype=torch.bfloat16,
    ) -> torch.Tensor:
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(device)

            if dtype in [torch.float16, torch.bfloat16]:
                emb = emb.float()

            self.cos_cached = emb.cos()[None, :, :]
            self.sin_cached = emb.sin()[None, :, :]

            self.cos_cached = self.cos_cached.type(dtype)
            self.sin_cached = self.sin_cached.type(dtype)

        return self.cos_cached, self.sin_cached

    def forward(self, q, k):
        batch, seq_len, head_dim = q.shape
        cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


def _make_causal_mask(input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int) -> torch.BoolTensor:
    batch_size, target_length = input_ids_shape
    mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
    # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
    seq_ids = torch.arange(target_length, device=device)
    mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]

    if past_key_values_length > 0:
        mask[:, :past_key_values_length] = False

    expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
    return expanded_mask


def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
    batch_size, src_length = mask.shape
    tgt_length = tgt_length if tgt_length is not None else src_length

    expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
    return expanded_mask.expand(batch_size, 1, tgt_length, src_length)


def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
    batch_size, seq_length = attention_mask.shape
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32)
    powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != num_heads:
        extra_base = torch.tensor(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32)
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

    # Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
    # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
    # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
    # => the query_length dimension will then be broadcasted correctly
    # This is more or less identical to T5's relative position bias:
    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[..., None].bfloat16() * arange_tensor
    return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)


def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    out = F.dropout(x, p=prob, training=training)
    out = residual + out
    return out


class Attention(nn.Module):
    def __init__(self, config: RWConfig):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_heads = config.n_head
        self.head_dim = self.hidden_size // self.num_heads
        self.split_size = self.hidden_size
        self.hidden_dropout = config.hidden_dropout

        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError(f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" f" {self.num_heads}).")

        self.maybe_rotary = RotaryEmbedding(config.head_dim) if config.rotary else lambda q, k: (q, k)

        # Layer-wise attention scaling
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
        self.beta = self.inv_norm_factor

        self.query_key_value = Linear(
            self.hidden_size,
            3 * self.hidden_size if not config.multi_query else (self.hidden_size + 2 * self.head_dim),
            bias=config.bias,
        )
        self.multi_query = config.multi_query
        self.dense = Linear(self.hidden_size, self.hidden_size, bias=config.bias)
        self.attention_dropout = nn.Dropout(config.attention_dropout)
        self.num_kv = config.n_head if not self.multi_query else 1

    def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
        storage as `fused_qkv`

        Args:
            fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]

        Returns:
            query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
            value: [batch_size, seq_length, num_heads, head_dim]
        """
        if not self.multi_query:
            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
            fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
            return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
        else:
            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
            fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
            return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """
        Merge heads together over the last dimension

        Args:
            x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]

        Returns:
            torch.tensor: [batch_size, seq_length, num_heads * head_dim]
        """
        # What we want to achieve is:
        # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
        batch_size_and_num_heads, seq_length, _ = x.shape
        batch_size = batch_size_and_num_heads // self.num_heads

        # First view to decompose the batch size
        # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
        x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)

        # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
        x = x.permute(0, 2, 1, 3)

        # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
        return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]

        # 3 x [batch_size, seq_length, num_heads, head_dim]
        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

        batch_size, q_length, _, _ = query_layer.shape

        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
        key_layer = key_layer.transpose(1, 2).reshape(
            batch_size * self.num_kv,
            q_length,
            self.head_dim,
        )
        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)

        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)

        if layer_past is not None:
            past_key, past_value = layer_past
            # concatenate along seq_length dimension:
            # - key: [batch_size * self.num_heads, head_dim, kv_length]
            # - value: [batch_size * self.num_heads, kv_length, head_dim]
            key_layer = torch.cat((past_key, key_layer), dim=1)
            value_layer = torch.cat((past_value, value_layer), dim=1)

        _, kv_length, _ = key_layer.shape

        if use_cache is True:
            present = (key_layer, value_layer)
        else:
            present = None

        if alibi is None:
            query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
            key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
            value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)

            attn_output = F.scaled_dot_product_attention(query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True)

            x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
            x = x.permute(0, 2, 1, 3)
            attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)

            output_tensor = self.dense(attn_output)

            outputs = (output_tensor, present)
            assert not output_attentions  # not supported.
            return outputs
        else:
            attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
            matmul_result = query_layer @ key_layer.transpose(-1, -2)

            # change view to [batch_size, num_heads, q_length, kv_length]
            attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)

            # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
            input_dtype = attention_scores.dtype
            # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
            if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
                attention_scores = attention_scores.to(torch.float32)
            # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
            attention_probs = F.softmax(
                (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor + attention_mask_float,
                dim=-1,
                dtype=hidden_states.dtype,
            )
            # [batch_size, num_heads, q_length, kv_length]
            attention_probs = self.attention_dropout(attention_probs)

            if head_mask is not None:
                attention_probs = attention_probs * head_mask

            # change view [batch_size x num_heads, q_length, kv_length]
            attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)

            # matmul: [batch_size * num_heads, q_length, head_dim]
            context_layer = attention_probs_reshaped @ value_layer

            # change view [batch_size, num_heads, q_length, head_dim]
            context_layer = self._merge_heads(context_layer)

            output_tensor = self.dense(context_layer)

            outputs = (output_tensor, present)
            if output_attentions:
                outputs += (attention_probs,)

            return outputs


class MLP(nn.Module):
    def __init__(self, config: RWConfig):
        super().__init__()
        hidden_size = config.hidden_size

        self.dense_h_to_4h = Linear(hidden_size, 4 * hidden_size, bias=config.bias)
        self.act = nn.GELU()
        self.dense_4h_to_h = Linear(4 * hidden_size, hidden_size, bias=config.bias)
        self.hidden_dropout = config.hidden_dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.act(self.dense_h_to_4h(x))
        x = self.dense_4h_to_h(x)
        return x


class DecoderLayer(nn.Module):
    def __init__(self, config: RWConfig):
        super().__init__()
        hidden_size = config.hidden_size

        self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.num_heads = config.n_head
        self.self_attention = Attention(config)

        if not config.parallel_attn:
            # unused if parallel attn
            self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = MLP(config)

        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
        self.hidden_dropout = config.hidden_dropout

        self.config = config

    def forward(
        self,
        hidden_states: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        layernorm_output = self.input_layernorm(hidden_states)
        residual = hidden_states

        # Self attention.
        attn_outputs = self.self_attention(
            layernorm_output,
            layer_past=layer_past,
            attention_mask=attention_mask,
            alibi=alibi,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        attention_output = attn_outputs[0]

        if not self.config.parallel_attn:
            residual = dropout_add(attention_output, residual, self.config.attention_dropout, training=self.training)
            layernorm_output = self.post_attention_layernorm(residual)

        outputs = attn_outputs[1:]

        # MLP.
        mlp_output = self.mlp(layernorm_output)

        if self.config.parallel_attn:
            mlp_output += attention_output

        output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)

        if use_cache:
            outputs = (output,) + outputs
        else:
            outputs = (output,) + outputs[1:]

        return outputs  # hidden_states, present, attentions


class RWPreTrainedModel(PreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = RWConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DecoderLayer"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear) or isinstance(module, Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
        if isinstance(module, RWModel):
            module.gradient_checkpointing = value

    @staticmethod
    def _convert_to_standard_cache(past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
        num_heads, ...]))
        """
        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
        num_heads = batch_size_times_num_heads // batch_size
        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
        return tuple(
            (
                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
            )
            for layer_past in past_key_value
        )

    @staticmethod
    def _convert_to_rw_cache(past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
        batch_size_times_num_heads = batch_size * num_heads
        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
        return tuple(
            (
                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
            )
            for layer_past in past_key_value
        )


class RWModel(RWPreTrainedModel):
    def __init__(self, config: RWConfig):
        super().__init__(config)

        self.embed_dim = config.hidden_size
        self.num_heads = config.n_head
        self.alibi = config.alibi

        # Embedding + LN Embedding
        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)

        # Transformer blocks
        self.h = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])

        # Final Layer Norm
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.word_embeddings

    def _prepare_attn_mask(self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int) -> torch.BoolTensor:
        # create causal mask
        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        combined_attention_mask = None
        device = attention_mask.device
        _, src_length = input_shape

        if src_length > 1:
            combined_attention_mask = _make_causal_mask(input_shape, device=device, past_key_values_length=past_key_values_length)

        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
        combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask

        return combined_attention_mask

    def set_input_embeddings(self, new_embeddings: torch.Tensor):
        self.word_embeddings = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
        if deprecated_arguments.pop("position_ids", False) is not False:
            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
            warnings.warn(
                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" " passing `position_ids`.",
                FutureWarning,
            )
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.h))

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape batch_size x num_heads x N x N
        # head_mask has shape n_layer x batch x num_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        hidden_states = inputs_embeds

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        # Compute alibi tensor: check build_alibi_tensor documentation
        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values[0] is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        if self.alibi:
            alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
        else:
            alibi = None

        causal_mask = self._prepare_attn_mask(
            attention_mask,
            input_shape=(batch_size, seq_length),
            past_key_values_length=past_key_values_length,
        )

        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                if use_cache:
                    logger.warning("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    alibi,
                    causal_mask,
                    head_mask[i],
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=causal_mask,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    alibi=alibi,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
|
| 650 |
+
hidden_states=all_hidden_states,
|
| 651 |
+
attentions=all_self_attentions,
|
| 652 |
+
)
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
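For orientation, a minimal standalone sketch of the mask convention used by `_prepare_attn_mask` above, assuming boolean masks where `True` marks positions that must not be attended to (hence the `|` union of the causal mask and the expanded padding mask); the names and shapes below are illustrative, not the module's own `_make_causal_mask`/`_expand_mask` helpers:

import torch

# True = masked (do not attend), matching the boolean-mask convention above.
seq_len = 4
causal = torch.ones(seq_len, seq_len).triu(diagonal=1).bool()   # future positions masked

padding = torch.tensor([1, 1, 1, 0], dtype=torch.bool)          # last position is padding
pad_mask = (~padding).expand(seq_len, seq_len)                  # [src] broadcast to [tgt, src]

combined = causal | pad_mask   # a position is masked if either rule applies
print(combined.int())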
class RWForCausalLM(RWPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

    def __init__(self, config: RWConfig):
        super().__init__(config)
        self.transformer = RWModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: torch.Tensor):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        # only keep the last token of input_ids if past is not None
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)

            # the cache may be in the standard format (e.g. in contrastive search); convert to this model's format if needed
            if past[0][0].shape[0] == input_ids.shape[0]:
                past = self._convert_to_rw_cache(past)

        return {
            "input_ids": input_ids,
            "past_key_values": past,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        """
        if deprecated_arguments.pop("position_ids", False) is not False:
            # `position_ids` could have been `torch.Tensor` or `None`, so defaulting pop to `False` allows detecting whether users were explicitly passing `None`
            warnings.warn(
                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore passing `position_ids`.",
                FutureWarning,
            )
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            batch_size, seq_length, vocab_size = shift_logits.shape
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def _reorder_cache(self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        Output shares the same memory storage as `past`.
        """
        standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))

        # Get a copy of `beam_idx` on all the devices where we need those indices.
        device_to_beam_idx = {past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past}
        reordered_past = tuple(
            (
                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
            )
            for layer_past in standardized_past
        )
        return self._convert_to_rw_cache(reordered_past)
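A hedged usage sketch of the causal-LM head above: the checkpoint name is a placeholder for any model shipping this RW modelling code, and `trust_remote_code=True` is assumed because the file lives with the checkpoint rather than inside transformers itself. Beam search here exercises both `prepare_inputs_for_generation` (cache-format conversion) and `_reorder_cache`:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "tiiuae/falcon-7b"  # placeholder; any checkpoint bundling this modelling_RW.py
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

inputs = tokenizer("The quick brown fox", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=20, num_beams=4)
print(tokenizer.decode(out[0], skip_special_tokens=True))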
class RWForSequenceClassification(RWPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

    def __init__(self, config: RWConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = RWModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if deprecated_arguments.pop("position_ids", False) is not False:
            # `position_ids` could have been `torch.Tensor` or `None`, so defaulting pop to `False` allows detecting whether users were explicitly passing `None`
            warnings.warn(
                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore passing `position_ids`.",
                FutureWarning,
            )
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1
            else:
                sequence_lengths = -1
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
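The pooling step above selects the logits at each sequence's last non-pad token. A self-contained sketch of that indexing with toy values and a hypothetical pad id:

import torch

pad_token_id = 0  # hypothetical
input_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])   # two right-padded sequences
logits = torch.randn(2, 4, 3)                             # [batch, seq, num_labels]

sequence_lengths = torch.ne(input_ids, pad_token_id).sum(dim=-1) - 1  # index of last real token
pooled = logits[torch.arange(2), sequence_lengths]        # [batch, num_labels]
print(sequence_lengths, pooled.shape)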
class RWForTokenClassification(RWPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

    def __init__(self, config: RWConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = RWModel(config)
        if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`.
        """
        if deprecated_arguments.pop("position_ids", False) is not False:
            # `position_ids` could have been `torch.Tensor` or `None`, so defaulting pop to `False` allows detecting whether users were explicitly passing `None`
            warnings.warn(
                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore passing `position_ids`.",
                FutureWarning,
            )
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            batch_size, seq_length = labels.shape
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length))

        if not return_dict:
            output = (logits,) + transformer_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
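The token-classification loss above simply flattens `[batch, seq, num_labels]` logits against `[batch, seq]` labels. A sketch with toy shapes; note CrossEntropyLoss ignores `-100` labels by default, which is the usual way padding positions are excluded:

import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_length, num_labels = 2, 5, 4
logits = torch.randn(batch_size, seq_length, num_labels)
labels = torch.randint(0, num_labels, (batch_size, seq_length))
labels[:, -1] = -100  # e.g. mask out padding positions

loss = CrossEntropyLoss()(logits.view(batch_size * seq_length, num_labels),
                          labels.view(batch_size * seq_length))
print(loss)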
class RWForQuestionAnswering(RWPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = RWModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for the position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for the position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,  # absorbed by RWModel's deprecated-arguments handling; has no positional effect
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split adds a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
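To round out the QA head, a sketch of how start/end logits are typically decoded into an answer span at inference time. This is a naive greedy argmax; production decoders score start/end combinations jointly and enforce valid ordering:

import torch

start_logits = torch.randn(1, 12)   # [batch, seq]
end_logits = torch.randn(1, 12)

start = int(start_logits.argmax(dim=-1))
end = int(end_logits.argmax(dim=-1))
if end < start:
    end = start  # naive fix-up; real decoders search combinations instead
print((start, end))  # token indices of the predicted answer span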
mllm/flamingo/flamingo-falcon-7B.json
ADDED
@@ -0,0 +1,112 @@
{
  "_commit_hash": null,
  "architectures": [
    "FlamingoModel"
  ],
  "cross_attn_every_n_layers": 4,
  "model_type": "flamingo",
  "text_config": {
    "architectures": [
      "RWForCausalLM"
    ],
    "apply_residual_connection_post_layernorm": false,
    "attention_dropout": 0.0,
    "bias": false,
    "bos_token_id": 11,
    "eos_token_id": 11,
    "hidden_dropout": 0.0,
    "hidden_size": 4544,
    "initializer_range": 0.02,
    "layer_norm_epsilon": 1e-05,
    "model_type": "RefinedWebModel",
    "multi_query": true,
    "n_head": 71,
    "n_layer": 32,
    "parallel_attn": true,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.27.4",
    "use_cache": true,
    "vocab_size": 65024
  },
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": null,
  "use_media_placement_augmentation": true,
  "vision_config": {
    "_name_or_path": "openai/clip-vit-large-patch14",
    "add_cross_attention": false,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "quick_gelu",
    "hidden_size": 1024,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "image_size": 224,
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 4096,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "min_length": 0,
    "model_type": "clip_vision_model",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 16,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_channels": 3,
    "num_hidden_layers": 24,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "patch_size": 14,
    "prefix": null,
    "problem_type": null,
    "projection_dim": 512,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.28.1",
    "typical_p": 1.0,
    "use_bfloat16": false
  }
}
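These JSON files are composite FlamingoConfig payloads: a text_config for the language backbone, a CLIP vision_config, and Flamingo-specific keys such as cross_attn_every_n_layers. A stdlib-only sketch for inspecting one, assuming the relative path below points into a checkout of this repo:

import json

# hypothetical path; adjust to where the repo is checked out
with open("mllm/flamingo/flamingo-falcon-7B.json") as f:
    cfg = json.load(f)

print(cfg["model_type"], "cross-attn every", cfg["cross_attn_every_n_layers"], "layers")
print("text:", cfg["text_config"]["model_type"], cfg["text_config"]["hidden_size"])
print("vision:", cfg["vision_config"]["model_type"], cfg["vision_config"]["hidden_size"])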
mllm/flamingo/flamingo-llama2-chat-13B.json
ADDED
@@ -0,0 +1,114 @@
{
  "_commit_hash": null,
  "architectures": [
    "FlamingoForConditionalGeneration"
  ],
  "cross_attn_every_n_layers": 8,
  "model_type": "flamingo",
  "text_config": {
    "_name_or_path": "meta-llama/Llama-2-13b-chat-hf",
    "architectures": [
      "LlamaForCausalLM"
    ],
    "bos_token_id": 1,
    "eos_token_id": 2,
    "hidden_act": "silu",
    "hidden_size": 5120,
    "initializer_range": 0.02,
    "intermediate_size": 13824,
    "max_position_embeddings": 4096,
    "model_type": "llama",
    "num_attention_heads": 40,
    "num_hidden_layers": 40,
    "num_key_value_heads": 40,
    "pad_token_id": 0,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": null,
    "tie_word_embeddings": false,
    "torch_dtype": "float16",
    "transformers_version": "4.30.1",
    "use_cache": true,
    "vocab_size": 32000
  },
  "torch_dtype": "float32",
  "transformers_version": null,
  "use_media_placement_augmentation": true,
  "vision_config": {
    "_name_or_path": "openai/clip-vit-large-patch14",
    "add_cross_attention": false,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "quick_gelu",
    "hidden_size": 1024,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "image_size": 224,
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 4096,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "min_length": 0,
    "model_type": "clip_vision_model",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 16,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_channels": 3,
    "num_hidden_layers": 24,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "patch_size": 14,
    "prefix": null,
    "problem_type": null,
    "projection_dim": 512,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.30.1",
    "typical_p": 1.0,
    "use_bfloat16": false
  }
}
mllm/flamingo/flamingo-llama2-chat-7B.json
ADDED
@@ -0,0 +1,115 @@
{
  "_commit_hash": null,
  "architectures": [
    "FlamingoForConditionalGeneration"
  ],
  "cross_attn_every_n_layers": 4,
  "model_type": "flamingo",
  "text_config": {
    "_name_or_path": "meta-llama/Llama-2-7b-chat-hf",
    "architectures": [
      "LlamaForCausalLM"
    ],
    "bos_token_id": 1,
    "eos_token_id": 2,
    "hidden_act": "silu",
    "hidden_size": 4096,
    "initializer_range": 0.02,
    "intermediate_size": 11008,
    "max_length": 4096,
    "max_position_embeddings": 2048,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 32,
    "num_key_value_heads": 32,
    "pad_token_id": 0,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": null,
    "tie_word_embeddings": false,
    "torch_dtype": "float16",
    "transformers_version": "4.32.0.dev0",
    "use_cache": true,
    "vocab_size": 32000
  },
  "torch_dtype": "float32",
  "transformers_version": null,
  "use_media_placement_augmentation": true,
  "vision_config": {
    "_name_or_path": "openai/clip-vit-large-patch14",
    "add_cross_attention": false,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "quick_gelu",
    "hidden_size": 1024,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "image_size": 224,
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 4096,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "min_length": 0,
    "model_type": "clip_vision_model",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 16,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_channels": 3,
    "num_hidden_layers": 24,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "patch_size": 14,
    "prefix": null,
    "problem_type": null,
    "projection_dim": 512,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.30.1",
    "typical_p": 1.0,
    "use_bfloat16": false
  }
}
mllm/flamingo/flamingo-mpt-1B-redpajama.json
ADDED
@@ -0,0 +1,131 @@
{
  "_commit_hash": null,
  "architectures": [
    "FlamingoForConditionalGeneration"
  ],
  "cross_attn_every_n_layers": 1,
  "model_type": "flamingo",
  "text_config": {
    "_name_or_path": "",
    "alibi": true,
    "alibi_bias_max": 8,
    "architectures": [
      "MosaicGPT"
    ],
    "attn_clip_qkv": null,
    "attn_impl": "torch",
    "attn_pdrop": 0,
    "attn_qk_ln": true,
    "attn_uses_sequence_id": false,
    "d_model": 2048,
    "hidden_size": 2048,
    "emb_init_std": null,
    "emb_init_uniform_lim": null,
    "emb_pdrop": 0,
    "embedding_fraction": 1.0,
    "fan_mode": "fan_in",
    "init_device": "cpu",
    "init_div_is_residual": true,
    "init_gain": 0,
    "init_nonlinearity": "relu",
    "init_std": 0.02,
    "logit_scale": null,
    "low_precision_layernorm": true,
    "max_seq_len": 2048,
    "mlp_ratio": 4,
    "model_type": "mosaic_gpt",
    "n_heads": 16,
    "n_layers": 24,
    "no_bias": true,
    "param_init_fn": "kaiming_normal_",
    "prefix_lm": false,
    "resid_pdrop": 0,
    "softmax_scale": null,
    "tokenizer_name": "EleutherAI/gpt-neox-20b",
    "torch_dtype": "float32",
    "transformers_version": "4.27.4",
    "use_cache": false,
    "verbose": 0,
    "vocab_size": 50432
  },
  "torch_dtype": "float32",
  "transformers_version": null,
  "use_media_placement_augmentation": true,
  "vision_config": {
    "_name_or_path": "openai/clip-vit-large-patch14",
    "add_cross_attention": false,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "quick_gelu",
    "hidden_size": 1024,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "image_size": 224,
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 4096,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "min_length": 0,
    "model_type": "clip_vision_model",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 16,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_channels": 3,
    "num_hidden_layers": 24,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "patch_size": 14,
    "prefix": null,
    "problem_type": null,
    "projection_dim": 512,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.30.1",
    "typical_p": 1.0,
    "use_bfloat16": false
  }
}
mllm/flamingo/flamingo-mpt-30B-bf16.json
ADDED
@@ -0,0 +1,195 @@
{
  "_commit_hash": null,
  "architectures": [
    "FlamingoForConditionalGeneration"
  ],
  "cross_attn_every_n_layers": 7,
  "model_type": "flamingo",
  "text_config": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": [
      "MPTForCausalLM"
    ],
    "attn_config": {
      "alibi": true,
      "alibi_bias_max": 8,
      "attn_impl": "torch",
      "attn_pdrop": 0,
      "attn_type": "multihead_attention",
      "attn_uses_sequence_id": false,
      "clip_qkv": null,
      "prefix_lm": false,
      "qk_ln": false,
      "softmax_scale": null
    },
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "d_model": 7168,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "emb_pdrop": 0,
    "embedding_fraction": 1.0,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "expansion_ratio": 4,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_size": 7168,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "init_config": {
      "emb_init_std": null,
      "emb_init_uniform_lim": null,
      "fan_mode": "fan_in",
      "init_div_is_residual": true,
      "init_gain": 0.0,
      "init_nonlinearity": "relu",
      "init_std": null,
      "name": "kaiming_normal_",
      "verbose": 0
    },
    "init_device": "cpu",
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "learned_pos_emb": true,
    "length_penalty": 1.0,
    "logit_scale": null,
    "max_length": 20,
    "max_seq_len": 8192,
    "min_length": 0,
    "model_type": "mpt",
    "n_heads": 64,
    "n_layers": 48,
    "no_bias": true,
    "no_repeat_ngram_size": 0,
    "norm_type": "low_precision_layernorm",
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "resid_pdrop": 0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "tokenizer_name": "EleutherAI/gpt-neox-20b",
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": "bfloat16",
    "torchscript": false,
    "transformers_version": "4.30.1",
    "typical_p": 1.0,
    "use_bfloat16": false,
    "use_cache": false,
    "verbose": 0,
    "vocab_size": 50432
  },
  "torch_dtype": "bfloat16",
  "transformers_version": null,
  "use_media_placement_augmentation": true,
  "vision_config": {
    "_name_or_path": "openai/clip-vit-large-patch14",
    "add_cross_attention": false,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "quick_gelu",
    "hidden_size": 1024,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "image_size": 224,
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 4096,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "min_length": 0,
    "model_type": "clip_vision_model",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 16,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_channels": 3,
    "num_hidden_layers": 24,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "patch_size": 14,
    "prefix": null,
    "problem_type": null,
    "projection_dim": 512,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.30.1",
    "typical_p": 1.0,
    "use_bfloat16": false
  }
}
mllm/flamingo/flamingo-mpt-30B.json
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_commit_hash": null,
|
| 3 |
+
"architectures": [
|
| 4 |
+
"FlamingoForConditionalGeneration"
|
| 5 |
+
],
|
| 6 |
+
"cross_attn_every_n_layers": 7,
|
| 7 |
+
"model_type": "flamingo",
|
| 8 |
+
"text_config": {
|
| 9 |
+
"_name_or_path": "",
|
| 10 |
+
"add_cross_attention": false,
|
| 11 |
+
"architectures": [
|
| 12 |
+
"MPTForCausalLM"
|
| 13 |
+
],
|
| 14 |
+
"attn_config": {
|
| 15 |
+
"alibi": true,
|
| 16 |
+
"alibi_bias_max": 8,
|
| 17 |
+
"attn_impl": "torch",
|
| 18 |
+
"attn_pdrop": 0,
|
| 19 |
+
"attn_type": "multihead_attention",
|
| 20 |
+
"attn_uses_sequence_id": false,
|
| 21 |
+
"clip_qkv": null,
|
| 22 |
+
"prefix_lm": false,
|
| 23 |
+
"qk_ln": false,
|
| 24 |
+
"softmax_scale": null
|
| 25 |
+
},
|
| 26 |
+
"bad_words_ids": null,
|
| 27 |
+
"begin_suppress_tokens": null,
|
| 28 |
+
"bos_token_id": null,
|
| 29 |
+
"chunk_size_feed_forward": 0,
|
| 30 |
+
"cross_attention_hidden_size": null,
|
| 31 |
+
"d_model": 7168,
|
| 32 |
+
"decoder_start_token_id": null,
|
| 33 |
+
"diversity_penalty": 0.0,
|
| 34 |
+
"do_sample": false,
|
| 35 |
+
"early_stopping": false,
|
| 36 |
+
"emb_pdrop": 0,
|
| 37 |
+
"embedding_fraction": 1.0,
|
| 38 |
+
"encoder_no_repeat_ngram_size": 0,
|
| 39 |
+
"eos_token_id": null,
|
| 40 |
+
"expansion_ratio": 4,
|
| 41 |
+
"exponential_decay_length_penalty": null,
|
| 42 |
+
"finetuning_task": null,
|
| 43 |
+
"forced_bos_token_id": null,
|
| 44 |
+
"forced_eos_token_id": null,
|
| 45 |
+
"hidden_size": 7168,
|
| 46 |
+
"id2label": {
|
| 47 |
+
"0": "LABEL_0",
|
| 48 |
+
"1": "LABEL_1"
|
| 49 |
+
},
|
| 50 |
+
"init_config": {
|
| 51 |
+
"emb_init_std": null,
|
| 52 |
+
"emb_init_uniform_lim": null,
|
| 53 |
+
"fan_mode": "fan_in",
|
| 54 |
+
"init_div_is_residual": true,
|
| 55 |
+
"init_gain": 0.0,
|
| 56 |
+
"init_nonlinearity": "relu",
|
| 57 |
+
"init_std": null,
|
| 58 |
+
"name": "kaiming_normal_",
|
| 59 |
+
"verbose": 0
|
| 60 |
+
},
|
| 61 |
+
"init_device": "cpu",
|
| 62 |
+
"is_decoder": false,
|
| 63 |
+
"is_encoder_decoder": false,
|
| 64 |
+
"label2id": {
|
| 65 |
+
"LABEL_0": 0,
|
| 66 |
+
"LABEL_1": 1
|
| 67 |
+
},
|
| 68 |
+
"learned_pos_emb": true,
|
| 69 |
+
"length_penalty": 1.0,
|
| 70 |
+
"logit_scale": null,
|
| 71 |
+
"max_length": 20,
|
| 72 |
+
"max_seq_len": 8192,
|
| 73 |
+
"min_length": 0,
|
| 74 |
+
"model_type": "mpt",
|
| 75 |
+
"n_heads": 64,
|
| 76 |
+
"n_layers": 48,
|
| 77 |
+
"no_bias": true,
|
| 78 |
+
"no_repeat_ngram_size": 0,
|
| 79 |
+
"norm_type": "low_precision_layernorm",
|
| 80 |
+
"num_beam_groups": 1,
|
| 81 |
+
"num_beams": 1,
|
| 82 |
+
"num_return_sequences": 1,
|
| 83 |
+
"output_attentions": false,
|
| 84 |
+
"output_hidden_states": false,
|
| 85 |
+
"output_scores": false,
|
| 86 |
+
"pad_token_id": null,
|
| 87 |
+
"prefix": null,
|
| 88 |
+
"problem_type": null,
|
| 89 |
+
"pruned_heads": {},
|
| 90 |
+
"remove_invalid_values": false,
|
| 91 |
+
"repetition_penalty": 1.0,
|
| 92 |
+
"resid_pdrop": 0,
|
| 93 |
+
"return_dict": true,
|
| 94 |
+
"return_dict_in_generate": false,
|
| 95 |
+
"sep_token_id": null,
|
| 96 |
+
"suppress_tokens": null,
|
| 97 |
+
"task_specific_params": null,
|
| 98 |
+
"temperature": 1.0,
|
| 99 |
+
"tf_legacy_loss": false,
|
| 100 |
+
"tie_encoder_decoder": false,
|
| 101 |
+
"tie_word_embeddings": true,
|
| 102 |
+
"tokenizer_class": null,
|
| 103 |
+
"tokenizer_name": "EleutherAI/gpt-neox-20b",
|
| 104 |
+
"top_k": 50,
|
| 105 |
+
"top_p": 1.0,
|
| 106 |
+
"torch_dtype": "bfloat16",
|
| 107 |
+
"torchscript": false,
|
| 108 |
+
"transformers_version": "4.30.1",
|
| 109 |
+
"typical_p": 1.0,
|
| 110 |
+
"use_bfloat16": false,
|
| 111 |
+
"use_cache": false,
|
| 112 |
+
"verbose": 0,
|
| 113 |
+
"vocab_size": 50432
|
| 114 |
+
},
|
| 115 |
+
"torch_dtype": "float32",
|
| 116 |
+
"transformers_version": null,
|
| 117 |
+
"use_media_placement_augmentation": true,
|
| 118 |
+
"vision_config": {
|
| 119 |
+
"_name_or_path": "openai/clip-vit-large-patch14",
|
| 120 |
+
"add_cross_attention": false,
|
| 121 |
+
"architectures": null,
|
| 122 |
+
"attention_dropout": 0.0,
|
| 123 |
+
"bad_words_ids": null,
|
| 124 |
+
"begin_suppress_tokens": null,
|
| 125 |
+
"bos_token_id": null,
|
| 126 |
+
"chunk_size_feed_forward": 0,
|
| 127 |
+
"cross_attention_hidden_size": null,
|
| 128 |
+
"decoder_start_token_id": null,
|
| 129 |
+
"diversity_penalty": 0.0,
|
| 130 |
+
"do_sample": false,
|
| 131 |
+
"early_stopping": false,
|
| 132 |
+
"encoder_no_repeat_ngram_size": 0,
|
| 133 |
+
"eos_token_id": null,
|
| 134 |
+
"exponential_decay_length_penalty": null,
|
| 135 |
+
"finetuning_task": null,
|
| 136 |
+
"forced_bos_token_id": null,
|
| 137 |
+
"forced_eos_token_id": null,
|
| 138 |
+
"hidden_act": "quick_gelu",
|
| 139 |
+
"hidden_size": 1024,
|
| 140 |
+
"id2label": {
|
| 141 |
+
"0": "LABEL_0",
|
| 142 |
+
"1": "LABEL_1"
|
| 143 |
+
},
|
| 144 |
+
"image_size": 224,
|
| 145 |
+
"initializer_factor": 1.0,
|
| 146 |
+
"initializer_range": 0.02,
|
| 147 |
+
"intermediate_size": 4096,
|
| 148 |
+
"is_decoder": false,
|
| 149 |
+
"is_encoder_decoder": false,
|
| 150 |
+
"label2id": {
|
| 151 |
+
"LABEL_0": 0,
|
| 152 |
+
"LABEL_1": 1
|
| 153 |
+
},
|
| 154 |
+
"layer_norm_eps": 1e-05,
|
| 155 |
+
"length_penalty": 1.0,
|
| 156 |
+
"max_length": 20,
|
| 157 |
+
"min_length": 0,
|
| 158 |
+
"model_type": "clip_vision_model",
|
| 159 |
+
"no_repeat_ngram_size": 0,
|
| 160 |
+
"num_attention_heads": 16,
|
| 161 |
+
"num_beam_groups": 1,
|
| 162 |
+
"num_beams": 1,
|
| 163 |
+
"num_channels": 3,
|
| 164 |
+
"num_hidden_layers": 24,
|
| 165 |
+
"num_return_sequences": 1,
|
| 166 |
+
"output_attentions": false,
|
| 167 |
+
"output_hidden_states": false,
|
| 168 |
+
"output_scores": false,
|
| 169 |
+
"pad_token_id": null,
|
| 170 |
+
"patch_size": 14,
|
| 171 |
+
"prefix": null,
|
| 172 |
+
"problem_type": null,
|
| 173 |
+
"projection_dim": 512,
|
| 174 |
+
"pruned_heads": {},
|
| 175 |
+
"remove_invalid_values": false,
|
| 176 |
+
"repetition_penalty": 1.0,
|
| 177 |
+
"return_dict": true,
|
| 178 |
+
"return_dict_in_generate": false,
|
| 179 |
+
"sep_token_id": null,
|
| 180 |
+
"suppress_tokens": null,
|
| 181 |
+
"task_specific_params": null,
|
| 182 |
+
"temperature": 1.0,
|
| 183 |
+
"tf_legacy_loss": false,
|
| 184 |
+
"tie_encoder_decoder": false,
|
| 185 |
+
"tie_word_embeddings": true,
|
| 186 |
+
"tokenizer_class": null,
|
| 187 |
+
"top_k": 50,
|
| 188 |
+
"top_p": 1.0,
|
| 189 |
+
"torch_dtype": null,
|
| 190 |
+
"torchscript": false,
|
| 191 |
+
"transformers_version": "4.30.1",
|
| 192 |
+
"typical_p": 1.0,
|
| 193 |
+
"use_bfloat16": false
|
| 194 |
+
}
|
| 195 |
+
}
|
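
These JSON files are composed configs: a top-level Flamingo section (cross_attn_every_n_layers, use_media_placement_augmentation) wrapping a text_config for the language model and a vision_config for the CLIP encoder. A minimal sketch of loading one, assuming the package layout of this diff; the nested sub-configs appear to be attribute-accessible, since the injection scripts below read config.text_config.vocab_size directly:

from mllm.flamingo.configuration_flamingo import FlamingoConfig

config = FlamingoConfig.from_json_file("mllm/flamingo/flamingo-mpt-30B.json")
print(config.model_type)                 # "flamingo"
print(config.cross_attn_every_n_layers)  # 4
print(config.text_config.model_type)     # "mpt"
print(config.vision_config.model_type)   # "clip_vision_model"
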
mllm/flamingo/flamingo-mpt-7B.json
ADDED
@@ -0,0 +1,195 @@
{
  "_commit_hash": null,
  "architectures": [
    "FlamingoForConditionalGeneration"
  ],
  "cross_attn_every_n_layers": 4,
  "model_type": "flamingo",
  "text_config": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": [
      "MPTForCausalLM"
    ],
    "attn_config": {
      "alibi": true,
      "alibi_bias_max": 8,
      "attn_impl": "torch",
      "attn_pdrop": 0,
      "attn_type": "multihead_attention",
      "attn_uses_sequence_id": false,
      "clip_qkv": null,
      "prefix_lm": false,
      "qk_ln": false,
      "softmax_scale": null
    },
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "d_model": 4096,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "emb_pdrop": 0,
    "embedding_fraction": 1.0,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "expansion_ratio": 4,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_size": 4096,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "init_config": {
      "emb_init_std": null,
      "emb_init_uniform_lim": null,
      "fan_mode": "fan_in",
      "init_div_is_residual": true,
      "init_gain": 0,
      "init_nonlinearity": "relu",
      "init_std": 0.02,
      "name": "kaiming_normal_",
      "verbose": 0
    },
    "init_device": "cpu",
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "learned_pos_emb": true,
    "length_penalty": 1.0,
    "logit_scale": null,
    "max_length": 20,
    "max_seq_len": 2048,
    "min_length": 0,
    "model_type": "mpt",
    "n_heads": 32,
    "n_layers": 32,
    "no_bias": true,
    "no_repeat_ngram_size": 0,
    "norm_type": "low_precision_layernorm",
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "resid_pdrop": 0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "tokenizer_name": "EleutherAI/gpt-neox-20b",
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": "bfloat16",
    "torchscript": false,
    "transformers_version": "4.30.1",
    "typical_p": 1.0,
    "use_bfloat16": false,
    "use_cache": false,
    "verbose": 0,
    "vocab_size": 50432
  },
  "torch_dtype": "float32",
  "transformers_version": null,
  "use_media_placement_augmentation": true,
  "vision_config": {
    "_name_or_path": "openai/clip-vit-large-patch14",
    "add_cross_attention": false,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "quick_gelu",
    "hidden_size": 1024,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "image_size": 224,
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 4096,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "min_length": 0,
    "model_type": "clip_vision_model",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 16,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_channels": 3,
    "num_hidden_layers": 24,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "patch_size": 14,
    "prefix": null,
    "problem_type": null,
    "projection_dim": 512,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.30.1",
    "typical_p": 1.0,
    "use_bfloat16": false
  }
}
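
As a quick orientation for the scripts that follow: instantiating the model class directly from one of these configs yields a randomly initialized skeleton, and the injecting_* scripts then overwrite its language-model and vision-encoder weights shard by shard. A sketch under the same layout assumption as above:

from mllm.flamingo.configuration_flamingo import FlamingoConfig
from mllm.flamingo.modeling_flamingo import FlamingoForConditionalGeneration

config = FlamingoConfig.from_json_file("mllm/flamingo/flamingo-mpt-7B.json")
model = FlamingoForConditionalGeneration(config=config)  # random init, no pretrained weights yet
print(f"{sum(p.numel() for p in model.parameters()) / 1e9:.2f}B parameters")
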
mllm/flamingo/flamingo-vicuna-33B-v1.3.json
ADDED
@@ -0,0 +1,111 @@
{
  "_commit_hash": null,
  "architectures": [
    "FlamingoForConditionalGeneration"
  ],
  "cross_attn_every_n_layers": 4,
  "model_type": "flamingo",
  "text_config": {
    "_name_or_path": "/home/luodian/projects/checkpoints/vicuna-33b-v1.3",
    "architectures": [
      "LlamaForCausalLM"
    ],
    "bos_token_id": 1,
    "eos_token_id": 2,
    "hidden_act": "silu",
    "hidden_size": 6656,
    "initializer_range": 0.02,
    "intermediate_size": 17920,
    "max_position_embeddings": 2048,
    "model_type": "llama",
    "num_attention_heads": 52,
    "num_hidden_layers": 60,
    "pad_token_id": 0,
    "rms_norm_eps": 1e-06,
    "tie_word_embeddings": false,
    "torch_dtype": "float16",
    "transformers_version": "4.28.1",
    "use_cache": false,
    "vocab_size": 32000
  },
  "torch_dtype": "float32",
  "transformers_version": null,
  "use_media_placement_augmentation": true,
  "vision_config": {
    "_name_or_path": "openai/clip-vit-large-patch14",
    "add_cross_attention": false,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "quick_gelu",
    "hidden_size": 1024,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "image_size": 224,
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 4096,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "min_length": 0,
    "model_type": "clip_vision_model",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 16,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_channels": 3,
    "num_hidden_layers": 24,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "patch_size": 14,
    "prefix": null,
    "problem_type": null,
    "projection_dim": 512,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.30.1",
    "typical_p": 1.0,
    "use_bfloat16": false
  }
}
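
The LLaMA-family numbers in this text_config are internally consistent and worth a quick sanity check: 6656 hidden units across 52 heads gives the usual 128-dimensional heads at the 60-layer depth of the 33B model.

hidden_size, num_attention_heads, num_hidden_layers = 6656, 52, 60  # values from the config above
assert hidden_size % num_attention_heads == 0
print(hidden_size // num_attention_heads)  # 128, the per-head dimension
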
mllm/flamingo/flamingo-vicuna-7B-v1.3.json
ADDED
@@ -0,0 +1,111 @@
{
  "_commit_hash": null,
  "architectures": [
    "FlamingoForConditionalGeneration"
  ],
  "cross_attn_every_n_layers": 4,
  "model_type": "flamingo",
  "text_config": {
    "_name_or_path": "/mnt/petrelfs/share_data/zhangyuanhan/vicuna-7b-v1.3",
    "architectures": [
      "LlamaForCausalLM"
    ],
    "bos_token_id": 1,
    "eos_token_id": 2,
    "hidden_act": "silu",
    "hidden_size": 4096,
    "initializer_range": 0.02,
    "intermediate_size": 11008,
    "max_position_embeddings": 2048,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 32,
    "pad_token_id": 0,
    "rms_norm_eps": 1e-06,
    "tie_word_embeddings": false,
    "torch_dtype": "float16",
    "transformers_version": "4.28.1",
    "use_cache": false,
    "vocab_size": 32000
  },
  "torch_dtype": "float32",
  "transformers_version": null,
  "use_media_placement_augmentation": true,
  "vision_config": {
    "_name_or_path": "openai/clip-vit-large-patch14",
    "add_cross_attention": false,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "quick_gelu",
    "hidden_size": 1024,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "image_size": 224,
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 4096,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "min_length": 0,
    "model_type": "clip_vision_model",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 16,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_channels": 3,
    "num_hidden_layers": 24,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "patch_size": 14,
    "prefix": null,
    "problem_type": null,
    "projection_dim": 512,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.30.1",
    "typical_p": 1.0,
    "use_bfloat16": false
  }
}
mllm/flamingo/injecting_falcon_into_flamingo.py
ADDED
@@ -0,0 +1,49 @@
import os

import torch

from .configuration_flamingo import FlamingoConfig
from .modeling_flamingo import FlamingoForConditionalGeneration

root_dir = os.environ["AZP"]
print(root_dir)


config = FlamingoConfig.from_json_file("./flamingo/flamingo-falcon-7B.json")
model = FlamingoForConditionalGeneration(config=config)


# merge the sharded Falcon checkpoint into a single state dict
state_dict_files = [
    f"{root_dir}/otter/checkpoints/falcon-7b/pytorch_model-00001-of-00002.bin",
    f"{root_dir}/otter/checkpoints/falcon-7b/pytorch_model-00002-of-00002.bin",
]

state_dict = {}
for file in state_dict_files:
    state_dict_part = torch.load(file, map_location="cpu")
    state_dict.update(state_dict_part)


# keep only the vision-encoder weights from the OpenFlamingo reference shard
state_dict_3 = torch.load(f"{root_dir}/otter/checkpoints/flamingo_9b_hf/pytorch_model-00004-of-00004.bin", map_location="cpu")
for cur_key in list(state_dict_3.keys()):
    if "vision_encoder" not in cur_key:
        del state_dict_3[cur_key]

load_msg = model.load_state_dict(
    state_dict_3,
    False,
)
# print unexpected keys
print(load_msg[1])

# move each Falcon block's parameters under the ".decoder_layer" sub-module
save_state_dict_1 = {}
for key in state_dict:
    if ".h." in key:
        _, _, layer_num, *remain_names = key.split(".")
        target_key = f"transformer.h.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
    else:
        target_key = key
    save_state_dict_1[f"{target_key}"] = state_dict[key]
load_msg = model.lang_encoder.load_state_dict(
    save_state_dict_1,
    False,
)
# print unexpected keys
print(load_msg[1])
model.save_pretrained(f"{root_dir}/otter/checkpoints/flamingo-falcon-7b/")
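
The key-renaming loop above is the heart of every injection script in this diff: each decoder block's parameters are re-homed under a ".decoder_layer" sub-module, matching how the Flamingo wrapper apparently nests the original block (the target key pattern is the evidence; see modeling_flamingo.py below). A self-contained toy version with a made-up key:

def rename_block_keys(state_dict: dict, marker: str = ".h.") -> dict:
    # hypothetical helper mirroring the loop above; not part of the repo
    renamed = {}
    for key, value in state_dict.items():
        if marker in key:
            _, _, layer_num, *rest = key.split(".")
            renamed[f"transformer.h.{layer_num}.decoder_layer.{'.'.join(rest)}"] = value
        else:
            renamed[key] = value
    return renamed

demo = {"transformer.h.3.self_attention.dense.weight": "tensor..."}
print(rename_block_keys(demo))
# {'transformer.h.3.decoder_layer.self_attention.dense.weight': 'tensor...'}
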
mllm/flamingo/injecting_llama2_into_flamingo.py
ADDED
@@ -0,0 +1,95 @@
import argparse
import os

import torch
from tqdm import tqdm

from .configuration_flamingo import FlamingoConfig
from .modeling_flamingo import FlamingoForConditionalGeneration

parser = argparse.ArgumentParser(description="Convert a Llama-2 chat model")
parser.add_argument("--model_choice", type=str, default="13B", help="Choose either '7B' or '13B'")
parser.add_argument("--llama2_root_dir", type=str, default="/home/luodian/projects/checkpoints")
parser.add_argument("--save_root_dir", type=str, default="/home/luodian/projects/checkpoints")
args = parser.parse_args()

# os.environ["TOKENIZERS_PARALLELISM"] = "false"

root_dir = args.llama2_root_dir
model_choice = args.model_choice
save_root_dir = args.save_root_dir

# prepare the Llama-2 chat model first
# you can visit https://huggingface.co/meta-llama to download the 7B and 13B chat checkpoints.
if model_choice == "7B":
    config_file = "./flamingo/flamingo-llama2-chat-7B.json"
    state_dict_files = [
        f"{root_dir}/Llama-2-7b-chat-hf/pytorch_model-00001-of-00002.bin",
        f"{root_dir}/Llama-2-7b-chat-hf/pytorch_model-00002-of-00002.bin",
    ]
    save_path = f"{save_root_dir}/flamingo-llama2-chat-7B-init"
elif model_choice == "13B":
    config_file = "./flamingo/flamingo-llama2-chat-13B.json"
    state_dict_files = [
        f"{root_dir}/Llama-2-13b-chat-hf/pytorch_model-00001-of-00003.bin",
        f"{root_dir}/Llama-2-13b-chat-hf/pytorch_model-00002-of-00003.bin",
        f"{root_dir}/Llama-2-13b-chat-hf/pytorch_model-00003-of-00003.bin",
    ]
    save_path = f"{save_root_dir}/flamingo-llama2-chat-13B-init"
else:
    raise ValueError("Invalid model_choice. Choose either '13B' or '7B'.")

config = FlamingoConfig.from_json_file(config_file)
model = FlamingoForConditionalGeneration(config=config)

# load flamingo's vision encoder from the released checkpoint.
# you can visit https://huggingface.co/luodian/openflamingo-9b-hf/tree/main to download the checkpoint.
AZP = os.environ["AZP"]
state_dict_3 = torch.load(f"{AZP}/otter/checkpoints/flamingo_9b_hf/pytorch_model-00004-of-00004.bin", map_location="cpu")
for cur_key in list(state_dict_3.keys()):
    if "vision_encoder" not in cur_key:
        del state_dict_3[cur_key]

load_msg = model.load_state_dict(
    state_dict_3,
    False,
)
# print unexpected keys
print(load_msg[1])

# Loading Llama-2 weights, re-homed under ".decoder_layer"
state_dict = {}
for file in tqdm(state_dict_files, desc="Loading state dict"):
    state_dict_part = torch.load(file, map_location="cpu")
    state_dict.update(state_dict_part)

save_state_dict_1 = {}
for key in state_dict:
    if ".layers." in key:
        _, _, layer_num, *remain_names = key.split(".")
        target_key = f"model.layers.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
    else:
        target_key = key
    save_state_dict_1[f"{target_key}"] = state_dict[key]

# Resize the token embeddings to 32000 so they match the Llama-2 checkpoint
model.lang_encoder.resize_token_embeddings(32000)

load_msg = model.lang_encoder.load_state_dict(
    save_state_dict_1,
    False,
)
# Resize the token embeddings back to 32002 for the two extra media special tokens
model.lang_encoder.resize_token_embeddings(32002)
# print unexpected keys
print(load_msg[1])


print(f"Saving model to {save_path}...")
model.save_pretrained(save_path, max_shard_size="10GB")
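
The resize-load-resize dance above exists because load_state_dict checks tensor shapes even with strict=False: the embedding is first shrunk to the checkpoint's 32000-row vocabulary so the Llama-2 weights copy in, then grown to 32002 (presumably for the two extra media special tokens), leaving the new rows freshly initialized. The same mechanics on a bare embedding:

import torch
import torch.nn as nn

ckpt = {"weight": torch.randn(32000, 8)}     # checkpoint rows: 32000

emb = nn.Embedding(32000, 8)
emb.load_state_dict(ckpt)                    # shapes match, so the copy succeeds

grown = nn.Embedding(32002, 8)               # vocab after adding 2 tokens
grown.weight.data[:32000] = emb.weight.data  # pretrained rows preserved
# rows 32000 and 32001 keep their fresh random initialization
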
mllm/flamingo/injecting_mpt-1B-redpajama_into_flamingo.py
ADDED
@@ -0,0 +1,97 @@
import argparse
import os

import torch
from tqdm import tqdm

from configuration_flamingo import FlamingoConfig
from modeling_flamingo import FlamingoForConditionalGeneration
from utils import rename_flamingo_checkpoint


parser = argparse.ArgumentParser(description="Convert MPT model")
parser.add_argument("--mpt_root_dir", type=str, default="/home/luodian/projects/checkpoints")
parser.add_argument("--save_root_dir", type=str, default="/home/luodian/projects/checkpoints")
parser.add_argument("--flamingo_dir", type=str, default=None, help="If the pretrained flamingo weights also need to be injected")
args = parser.parse_args()


root_dir = args.mpt_root_dir
save_root_dir = args.save_root_dir

# prepare the MPT model first
# you can visit https://huggingface.co/mosaicml to download the mpt-1b-redpajama-200b-dolly checkpoint.
config_file = "./flamingo/flamingo-mpt-1B-redpajama.json"
state_dict_file = f"{root_dir}/pytorch_model.bin"
save_path = f"{save_root_dir}/flamingo-mpt-1b-redpajama-200b-dolly"

config = FlamingoConfig.from_json_file(config_file)

model = FlamingoForConditionalGeneration(config=config)

# Loading MPT weights: move each block's parameters under ".decoder_layer"
state_dict = torch.load(state_dict_file, map_location="cpu")
save_state_dict_1 = {}
for key in state_dict:
    if ".blocks." in key:
        _, _, layer_num, *remain_names = key.split(".")
        target_key = f"transformer.blocks.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
    else:
        target_key = key
    save_state_dict_1[f"{target_key}"] = state_dict[key]

load_msg = model.lang_encoder.load_state_dict(
    save_state_dict_1,
    False,
)

# load flamingo's vision encoder from the released checkpoint.
# you can visit https://huggingface.co/luodian/openflamingo-9b-hf/tree/main to download the checkpoint.
AZP = os.environ["AZP"]
state_dict_3 = torch.load(f"{AZP}/pytorch_model-00004-of-00004.bin", map_location="cpu")
for cur_key in list(state_dict_3.keys()):
    if "vision_encoder" not in cur_key:
        del state_dict_3[cur_key]

load_msg = model.load_state_dict(
    state_dict_3,
    False,
)
# print unexpected keys
print(load_msg[1])

# (this repeats the MPT weight load above, now printing the unexpected keys)
save_state_dict_1 = {}
for key in state_dict:
    if ".blocks." in key:
        _, _, layer_num, *remain_names = key.split(".")
        target_key = f"transformer.blocks.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
    else:
        target_key = key
    save_state_dict_1[f"{target_key}"] = state_dict[key]

load_msg = model.lang_encoder.load_state_dict(
    save_state_dict_1,
    False,
)
# print unexpected keys
print(load_msg[1])
if args.flamingo_dir is not None:
    state_dict_2 = torch.load(f"{args.flamingo_dir}/checkpoint.pt", map_location="cpu")
    save_state_dict_2 = rename_flamingo_checkpoint(state_dict_2)
    real_vocab_size = config.text_config.vocab_size
    # Resize the token embeddings to match the Flamingo checkpoint before loading
    model.lang_encoder.resize_token_embeddings(save_state_dict_2["lang_encoder.transformer.wte.weight"].shape[0])

    load_msg = model.load_state_dict(
        save_state_dict_2,
        False,
    )
    # print unexpected keys
    print(load_msg[1])
    # Resize the token embeddings back to the config's full vocab size
    model.lang_encoder.resize_token_embeddings(real_vocab_size)

print(f"Saving model to {save_path}...")
model.save_pretrained(save_path, max_shard_size="10GB")
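
A hypothetical way to drive the script above (all paths illustrative): it reads AZP from the environment to find the reference vision-encoder shard, and, since this file uses plain rather than package-relative imports, it is run with mllm/flamingo as the working directory.

import os
import subprocess

env = dict(os.environ, AZP="/mnt/storage")  # directory holding pytorch_model-00004-of-00004.bin
subprocess.run(
    [
        "python", "injecting_mpt-1B-redpajama_into_flamingo.py",
        "--mpt_root_dir", "/mnt/storage/checkpoints/mpt-1b-redpajama-200b-dolly",
        "--save_root_dir", "/mnt/storage/checkpoints",
    ],
    cwd="mllm/flamingo",
    env=env,
    check=True,
)
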
mllm/flamingo/injecting_mpt_into_flamingo.py
ADDED
@@ -0,0 +1,109 @@
import argparse
import os

import torch
from tqdm import tqdm

from configuration_flamingo import FlamingoConfig
from modeling_flamingo import FlamingoForCondition alGeneration
from utils import rename_flamingo_checkpoint

parser = argparse.ArgumentParser(description="Convert MPT model")
parser.add_argument("--model_choice", type=str, choices=["7B", "30B"], required=True, help="Choose either '7B' or '30B'")
parser.add_argument("--mpt_root_dir", type=str, default="/home/luodian/projects/checkpoints")
parser.add_argument("--save_root_dir", type=str, default="/home/luodian/projects/checkpoints")
parser.add_argument("--flamingo_dir", type=str, default=None, help="If the pretrained flamingo weights also need to be injected")
args = parser.parse_args()

os.environ["TOKENIZERS_PARALLELISM"] = "false"

root_dir = args.mpt_root_dir
model_choice = args.model_choice
save_root_dir = args.save_root_dir

# prepare the MPT model first
# you can visit https://huggingface.co/mosaicml to download the 7B and 30B instruct checkpoints.
if model_choice == "30B":
    config_file = "./flamingo/flamingo-mpt-30B.json"
    state_dict_files = [
        f"{root_dir}/mpt-30b-instruct/pytorch_model-00001-of-00007.bin",
        f"{root_dir}/mpt-30b-instruct/pytorch_model-00002-of-00007.bin",
        f"{root_dir}/mpt-30b-instruct/pytorch_model-00003-of-00007.bin",
        f"{root_dir}/mpt-30b-instruct/pytorch_model-00004-of-00007.bin",
        f"{root_dir}/mpt-30b-instruct/pytorch_model-00005-of-00007.bin",
        f"{root_dir}/mpt-30b-instruct/pytorch_model-00006-of-00007.bin",
        f"{root_dir}/mpt-30b-instruct/pytorch_model-00007-of-00007.bin",
    ]
    save_path = f"{save_root_dir}/flamingo-mpt-30B-instruct-init"
elif model_choice == "7B":
    config_file = "./flamingo/flamingo-mpt-7B.json"
    state_dict_files = [
        f"{root_dir}/mpt-7b/pytorch_model-00001-of-00002.bin",
        f"{root_dir}/mpt-7b/pytorch_model-00002-of-00002.bin",
    ]
    save_path = f"{save_root_dir}/flamingo-mpt-7B"
else:
    raise ValueError("Invalid model_choice. Choose either '30B' or '7B'.")

config = FlamingoConfig.from_json_file(config_file)

model = FlamingoForConditionalGeneration(config=config)


# load flamingo's vision encoder from the released checkpoint.
# you can visit https://huggingface.co/luodian/openflamingo-9b-hf/tree/main to download the checkpoint.
AZP = os.environ["AZP"]
state_dict_3 = torch.load(f"{AZP}/otter/checkpoints/flamingo_9b_hf/pytorch_model-00004-of-00004.bin", map_location="cpu")
for cur_key in list(state_dict_3.keys()):
    if "vision_encoder" not in cur_key:
        del state_dict_3[cur_key]

load_msg = model.load_state_dict(
    state_dict_3,
    False,
)
# print unexpected keys
print(load_msg[1])

# Loading MPT weights, re-homed under ".decoder_layer"
state_dict = {}
for file in tqdm(state_dict_files, desc="Loading state dict"):
    state_dict_part = torch.load(file, map_location="cpu")
    state_dict.update(state_dict_part)

save_state_dict_1 = {}
for key in state_dict:
    if ".blocks." in key:
        _, _, layer_num, *remain_names = key.split(".")
        target_key = f"transformer.blocks.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
    else:
        target_key = key
    save_state_dict_1[f"{target_key}"] = state_dict[key]

load_msg = model.lang_encoder.load_state_dict(
    save_state_dict_1,
    False,
)
# print unexpected keys
print(load_msg[1])
if args.flamingo_dir is not None:
    state_dict_2 = torch.load(f"{args.flamingo_dir}/checkpoint.pt", map_location="cpu")
    save_state_dict_2 = rename_flamingo_checkpoint(state_dict_2)

    real_vocab_size = config.text_config.vocab_size
    # Resize the token embeddings to match the Flamingo checkpoint (50280) before loading
    model.lang_encoder.resize_token_embeddings(save_state_dict_2["lang_encoder.transformer.wte.weight"].shape[0])

    load_msg = model.load_state_dict(
        save_state_dict_2,
        False,
    )
    # print unexpected keys
    print(load_msg[1])
    # Resize the token embeddings back to the full vocab size (50432)
    model.lang_encoder.resize_token_embeddings(real_vocab_size)

print(f"Saving model to {save_path}...")
model.save_pretrained(save_path, max_shard_size="10GB")
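
Each print(load_msg[1]) above relies on load_state_dict(strict=False) returning a (missing_keys, unexpected_keys) named tuple; index 1 lists checkpoint keys that found no matching parameter, so an empty list there is the success signal. A tiny demonstration:

import torch.nn as nn

m = nn.Linear(2, 2)
msg = m.load_state_dict({"weight": m.weight.detach().clone(), "extra": m.bias.detach().clone()}, strict=False)
print(msg.missing_keys)     # ['bias']  (not supplied, left as-is)
print(msg.unexpected_keys)  # ['extra'] (no matching parameter)
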
mllm/flamingo/injecting_vicuna_into_flamingo.py
ADDED
@@ -0,0 +1,100 @@
import argparse
import os

import torch
from tqdm import tqdm

from .configuration_flamingo import FlamingoConfig
from .modeling_flamingo import FlamingoForConditionalGeneration

parser = argparse.ArgumentParser(description="Convert Vicuna model")
parser.add_argument("--model_choice", type=str, choices=["7B", "33B"], required=True, help="Choose either '7B' or '33B'")
parser.add_argument("--vicuna_root_dir", type=str, default="/home/luodian/projects/checkpoints")
parser.add_argument("--save_root_dir", type=str, default="/home/luodian/projects/checkpoints")
parser.add_argument("--flamingo_dir", type=str, default=None, help="If the pretrained flamingo weights also need to be injected")
args = parser.parse_args()

os.environ["TOKENIZERS_PARALLELISM"] = "false"

root_dir = args.vicuna_root_dir
model_choice = args.model_choice
save_root_dir = args.save_root_dir

# prepare the Vicuna model first
# you can visit https://huggingface.co/lmsys/vicuna-33b-v1.3 to download the 7B and 33B checkpoints.
if model_choice == "33B":
    config_file = "./flamingo/flamingo-vicuna-33B-v1.3.json"
    state_dict_files = [
        f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00001-of-00007.bin",
        f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00002-of-00007.bin",
        f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00003-of-00007.bin",
        f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00004-of-00007.bin",
        f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00005-of-00007.bin",
        f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00006-of-00007.bin",
        f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00007-of-00007.bin",
    ]
    save_path = f"{save_root_dir}/flamingo-vicuna-33B-v1.3-init"
elif model_choice == "7B":
    config_file = "./flamingo/flamingo-vicuna-7B-v1.3.json"
    state_dict_files = [
        f"{root_dir}/vicuna-7b-v1.3/pytorch_model-00001-of-00002.bin",
        f"{root_dir}/vicuna-7b-v1.3/pytorch_model-00002-of-00002.bin",
    ]
    save_path = f"{save_root_dir}/flamingo-vicuna-7B-v1.3-init"
else:
    raise ValueError("Invalid model_choice. Choose either '33B' or '7B'.")

config = FlamingoConfig.from_json_file(config_file)
model = FlamingoForConditionalGeneration(config=config)

# load flamingo's vision encoder from the released checkpoint.
# you can visit https://huggingface.co/luodian/openflamingo-9b-hf/tree/main to download the checkpoint.
AZP = os.environ["AZP"]
state_dict_3 = torch.load(f"{AZP}/otter/checkpoints/flamingo_9b_hf/pytorch_model-00004-of-00004.bin", map_location="cpu")
for cur_key in list(state_dict_3.keys()):
    if "vision_encoder" not in cur_key:
        del state_dict_3[cur_key]

load_msg = model.load_state_dict(
    state_dict_3,
    False,
)
# print unexpected keys
print(load_msg[1])

# Loading Vicuna weights, re-homed under ".decoder_layer"
state_dict = {}
for file in tqdm(state_dict_files, desc="Loading state dict"):
    state_dict_part = torch.load(file, map_location="cpu")
    state_dict.update(state_dict_part)

save_state_dict_1 = {}
for key in state_dict:
    if ".layers." in key:
        _, _, layer_num, *remain_names = key.split(".")
        target_key = f"model.layers.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
    else:
        target_key = key
    save_state_dict_1[f"{target_key}"] = state_dict[key]

# Resize the token embeddings to 32000 so they match the Vicuna checkpoint
model.lang_encoder.resize_token_embeddings(32000)

load_msg = model.lang_encoder.load_state_dict(
    save_state_dict_1,
    False,
)
# Resize the token embeddings back to 32002 for the two extra media special tokens
model.lang_encoder.resize_token_embeddings(32002)
# print unexpected keys
print(load_msg[1])


print(f"Saving model to {save_path}...")
model.save_pretrained(save_path, max_shard_size="10GB")
mllm/flamingo/modeling_flamingo.py
ADDED
|
@@ -0,0 +1,966 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Callable, Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from accelerate.hooks import AlignDevicesHook, add_hook_to_module
|
| 8 |
+
from einops import rearrange, repeat
|
| 9 |
+
from transformers import CLIPVisionModel, LlamaForCausalLM, LlamaTokenizer
|
| 10 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 11 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 12 |
+
from transformers.models.auto import AutoModel, AutoModelForCausalLM, AutoTokenizer
|
| 13 |
+
|
| 14 |
+
from .configuration_flamingo import FlamingoConfig
|
| 15 |
+
from .falcon.modelling_RW import RWForCausalLM
|
| 16 |
+
from .mpt.modeling_mpt import MPTForCausalLM
|
| 17 |
+
from .mpt_redpajama.mosaic_gpt import MosaicGPT
|
| 18 |
+
|
| 19 |
+
# from .configuration_flamingo import FlamingoConfig
|
| 20 |
+
|
| 21 |
+
__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
|
| 22 |
+
"opt": "model.decoder.layers",
|
| 23 |
+
"gptneo": "transformer.h",
|
| 24 |
+
"gptj": "transformer.h",
|
| 25 |
+
"gpt-j": "transformer.h",
|
| 26 |
+
"pythia": "gpt_neox.layers",
|
| 27 |
+
"llama": "model.layers",
|
| 28 |
+
"RWForCausalLM": "transformer.h",
|
| 29 |
+
"MPTForCausalLM": "transformer.blocks",
|
| 30 |
+
"MosaicGPT": "transformer.blocks",
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _infer_decoder_layers_attr_name(model: nn.Module):
|
| 35 |
+
for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
|
| 36 |
+
if k.lower() in model.__class__.__name__.lower():
|
| 37 |
+
return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
|
| 38 |
+
|
| 39 |
+
raise ValueError(
|
| 40 |
+
f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def extend_instance(obj, mixin):
|
| 45 |
+
"""Apply mixins to a class instance after creation"""
|
| 46 |
+
base_cls = obj.__class__
|
| 47 |
+
base_cls_name = obj.__class__.__name__
|
| 48 |
+
obj.__class__ = type(base_cls_name, (mixin, base_cls), {}) # mixin needs to go first for our forward() logic to work
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def getattr_recursive(obj, att):
|
| 52 |
+
"""
|
| 53 |
+
Return nested attribute of obj
|
| 54 |
+
Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
|
| 55 |
+
"""
|
| 56 |
+
if att == "":
|
| 57 |
+
return obj
|
| 58 |
+
i = att.find(".")
|
| 59 |
+
if i < 0:
|
| 60 |
+
return getattr(obj, att)
|
| 61 |
+
else:
|
| 62 |
+
return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def setattr_recursive(obj, att, val):
|
| 66 |
+
"""
|
| 67 |
+
Set nested attribute of obj
|
| 68 |
+
Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
|
| 69 |
+
"""
|
| 70 |
+
if "." in att:
|
| 71 |
+
obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
|
| 72 |
+
setattr(obj, att.split(".")[-1], val)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def exists(val):
|
| 76 |
+
return val is not None
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class FlamingoPerceiverBlock(nn.Module):
|
| 80 |
+
def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8, mult: int = 4):
|
| 81 |
+
super().__init__()
|
| 82 |
+
self.scale = dim_head**-0.5
|
| 83 |
+
self.heads = heads
|
| 84 |
+
inner_dim = dim_head * heads
|
| 85 |
+
ff_dim = dim * mult
|
| 86 |
+
self.norm_media = nn.LayerNorm(dim)
|
| 87 |
+
self.norm_latents = nn.LayerNorm(dim)
|
| 88 |
+
|
| 89 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
| 90 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
| 91 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
| 92 |
+
self.feed_forward = nn.ModuleList(
|
| 93 |
+
[
|
| 94 |
+
nn.LayerNorm(dim),
|
| 95 |
+
nn.Linear(dim, ff_dim, bias=False),
|
| 96 |
+
nn.GELU(),
|
| 97 |
+
nn.Linear(ff_dim, dim, bias=False),
|
| 98 |
+
]
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
def forward(self, x: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
|
| 102 |
+
"""
|
| 103 |
+
Args:
|
| 104 |
+
x (torch.Tensor): image features
|
| 105 |
+
shape (b, T, n1, D)
|
| 106 |
+
latent (torch.Tensor): latent features
|
| 107 |
+
shape (b, T, n2, D)
|
| 108 |
+
"""
|
| 109 |
+
x = self.norm_media(x)
|
| 110 |
+
residual_latents = latents
|
| 111 |
+
latents = self.norm_latents(latents)
|
| 112 |
+
|
| 113 |
+
h = self.heads
|
| 114 |
+
|
| 115 |
+
q = self.to_q(latents)
|
| 116 |
+
kv_input = torch.cat((x, latents), dim=-2)
|
| 117 |
+
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
| 118 |
+
q = rearrange(q, "b t n (h d) -> b h t n d", h=h)
|
| 119 |
+
k = rearrange(k, "b t n (h d) -> b h t n d", h=h)
|
| 120 |
+
v = rearrange(v, "b t n (h d) -> b h t n d", h=h)
|
| 121 |
+
q = q * self.scale
|
| 122 |
+
|
| 123 |
+
# attention
|
| 124 |
+
sim = torch.einsum("... i d, ... j d -> ... i j", q, k)
|
| 125 |
+
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
| 126 |
+
attn = sim.softmax(dim=-1)
|
| 127 |
+
|
| 128 |
+
out = torch.einsum("... i j, ... j d -> ... i d", attn, v)
|
| 129 |
+
out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
|
| 130 |
+
out = self.to_out(out) + residual_latents
|
| 131 |
+
residual_out = out
|
| 132 |
+
for layer in self.feed_forward:
|
| 133 |
+
out = layer(out)
|
| 134 |
+
return out + residual_out
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class FlamingoPerceiverResampler(nn.Module):
|
| 138 |
+
def __init__(
|
| 139 |
+
self,
|
| 140 |
+
*,
|
| 141 |
+
dim: int,
|
| 142 |
+
depth: int = 6,
|
| 143 |
+
dim_head: int = 64,
|
| 144 |
+
heads: int = 8,
|
| 145 |
+
num_latents: int = 64,
|
| 146 |
+
# max_num_frames: int = 128,
|
| 147 |
+
max_num_media: Optional[int] = None,
|
| 148 |
+
max_num_frames: Optional[int] = None,
|
| 149 |
+
ff_mult: int = 4,
|
| 150 |
+
):
|
| 151 |
+
super().__init__()
|
| 152 |
+
self.latents = nn.Parameter(torch.randn(num_latents, dim))
|
| 153 |
+
self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None
|
| 154 |
+
# self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim))
|
| 155 |
+
|
| 156 |
+
self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None
|
| 157 |
+
|
| 158 |
+
self.layers = nn.ModuleList([])
|
| 159 |
+
for _ in range(depth):
|
| 160 |
+
self.layers.append(FlamingoPerceiverBlock(dim=dim, dim_head=dim_head, heads=heads, mult=ff_mult))
|
| 161 |
+
|
| 162 |
+
self.norm = nn.LayerNorm(dim)
|
| 163 |
+
|
| 164 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 165 |
+
"""
|
| 166 |
+
Args:
|
| 167 |
+
x (torch.Tensor): image features
|
| 168 |
+
shape (b, T, F, v, D)
|
| 169 |
+
Returns:
|
| 170 |
+
shape (b, T, n, D) where n is self.num_latents
|
| 171 |
+
"""
|
| 172 |
+
b, T, F, v = x.shape[:4]
|
| 173 |
+
|
| 174 |
+
# frame and media time embeddings
|
| 175 |
+
if exists(self.frame_embs):
|
| 176 |
+
frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
|
| 177 |
+
x = x + frame_embs
|
| 178 |
+
x = rearrange(x, "b T F v d -> b T (F v) d") # flatten the frame and spatial dimensions
|
| 179 |
+
if exists(self.media_time_embs):
|
| 180 |
+
x = x + self.media_time_embs[:T]
|
| 181 |
+
|
| 182 |
+
# blocks
|
| 183 |
+
latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
|
| 184 |
+
for block in self.layers:
|
| 185 |
+
latents = block(x, latents)
|
| 186 |
+
return self.norm(latents)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class FlamingoMaskedCrossAttention(nn.Module):
|
| 190 |
+
def __init__(
|
| 191 |
+
self,
|
| 192 |
+
*,
|
| 193 |
+
dim: int,
|
| 194 |
+
dim_visual: int,
|
| 195 |
+
dim_head: int = 64,
|
| 196 |
+
heads: int = 8,
|
| 197 |
+
only_attend_immediate_media: bool = True,
|
| 198 |
+
):
|
| 199 |
+
super().__init__()
|
| 200 |
+
self.scale = dim_head**-0.5
|
| 201 |
+
self.heads = heads
|
| 202 |
+
inner_dim = dim_head * heads
|
| 203 |
+
|
| 204 |
+
self.norm = nn.LayerNorm(dim)
|
| 205 |
+
|
| 206 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
| 207 |
+
self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
|
| 208 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
| 209 |
+
|
| 210 |
+
# whether for text to only attend to immediate preceding image, or all previous images
|
| 211 |
+
self.only_attend_immediate_media = only_attend_immediate_media
|
| 212 |
+
|
| 213 |
+
def forward(
|
| 214 |
+
self,
|
| 215 |
+
x: torch.Tensor,
|
| 216 |
+
media: torch.Tensor,
|
| 217 |
+
media_locations: Optional[torch.BoolTensor] = None,
|
| 218 |
+
attend_previous: bool = True,
|
| 219 |
+
) -> torch.Tensor:
|
| 220 |
+
"""
|
| 221 |
+
Args:
|
| 222 |
+
x (torch.Tensor): text features
|
| 223 |
+
shape (B, T_txt, D_txt)
|
| 224 |
+
media (torch.Tensor): image features
|
| 225 |
+
shape (B, T_img, n, D_img) where n is the number of latents
|
| 226 |
+
media_locations: boolean mask identifying the media tokens in x
|
| 227 |
+
shape (B, T_txt)
|
| 228 |
+
attend_previous: bool
|
| 229 |
+
If False, ignore the immediately preceding image and only begin attending at the next image
|
| 230 |
+
"""
|
| 231 |
+
_, T_img, n = media.shape[:3]
|
| 232 |
+
h = self.heads
|
| 233 |
+
|
| 234 |
+
x = self.norm(x)
|
| 235 |
+
|
| 236 |
+
q = self.to_q(x)
|
| 237 |
+
media = rearrange(media, "b t n d -> b (t n) d")
|
| 238 |
+
|
| 239 |
+
k, v = self.to_kv(media).chunk(2, dim=-1)
|
| 240 |
+
q = rearrange(q, "b n (h d) -> b h n d", h=h)
|
| 241 |
+
k = rearrange(k, "b n (h d) -> b h n d", h=h)
|
| 242 |
+
v = rearrange(v, "b n (h d) -> b h n d", h=h)
|
| 243 |
+
|
| 244 |
+
q = q * self.scale
|
| 245 |
+
|
| 246 |
+
sim = torch.einsum("... i d, ... j d -> ... i j", q, k)
|
| 247 |
+
|
| 248 |
+
if exists(media_locations):
|
| 249 |
+
# at each boolean of True, increment the time counter (relative to media time)
|
| 250 |
+
text_time = media_locations.cumsum(dim=-1)
|
| 251 |
+
media_time = torch.arange(T_img, device=x.device) + 1
|
| 252 |
+
|
| 253 |
+
if not attend_previous:
|
| 254 |
+
text_time[~media_locations] += 1
|
| 255 |
+
# make sure max is still the number of images in the sequence
|
| 256 |
+
text_time[
|
| 257 |
+
text_time
|
| 258 |
+
> repeat(
|
| 259 |
+
torch.count_nonzero(media_locations, dim=1),
|
| 260 |
+
"b -> b i",
|
| 261 |
+
i=text_time.shape[1],
|
| 262 |
+
)
|
| 263 |
+
] = 0
|
| 264 |
+
|
| 265 |
+
# text time must equal media time if only attending to most immediate image
|
| 266 |
+
# otherwise, as long as text time is greater than media time (if attending to all previous images / media)
|
| 267 |
+
mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
|
| 268 |
+
|
| 269 |
+
text_to_media_mask = mask_op(
|
| 270 |
+
rearrange(text_time, "b i -> b 1 i 1"),
|
| 271 |
+
repeat(media_time, "j -> 1 1 1 (j n)", n=n),
|
| 272 |
+
)
|
| 273 |
+
sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
|
| 274 |
+
|
| 275 |
+
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
| 276 |
+
attn = sim.softmax(dim=-1)
|
| 277 |
+
|
| 278 |
+
if exists(media_locations) and self.only_attend_immediate_media:
|
| 279 |
+
# any text without a preceding media needs to have attention zeroed out
|
| 280 |
+
text_without_media_mask = text_time == 0
|
| 281 |
+
text_without_media_mask = rearrange(text_without_media_mask, "b i -> b 1 i 1")
|
| 282 |
+
attn = attn.masked_fill(text_without_media_mask, 0.0)
|
| 283 |
+
|
| 284 |
+
out = torch.einsum("... i j, ... j d -> ... i d", attn, v)
|
| 285 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
| 286 |
+
return self.to_out(out)
|
| 287 |
+
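To make the time-alignment masking above concrete, here is a small standalone sketch of how `media_locations.cumsum` pairs each text position with a media slot (toy values, no model required):

import torch

# Toy example: 1 sequence, 2 images; True marks an <image> token.
media_locations = torch.tensor([[True, False, False, True, False]])
text_time = media_locations.cumsum(dim=-1)   # tensor([[1, 1, 1, 2, 2]])
media_time = torch.arange(2) + 1             # tensor([1, 2])

# only_attend_immediate_media=True -> text position i may attend image j iff text_time[i] == media_time[j]
mask = text_time.unsqueeze(-1) == media_time.view(1, 1, -1)
print(mask[0])
# tensor([[ True, False],
#         [ True, False],
#         [ True, False],
#         [False,  True],
#         [False,  True]])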
|
| 288 |
+
|
| 289 |
+
class FlamingoGatedCrossAttentionBlock(nn.Module):
|
| 290 |
+
def __init__(
|
| 291 |
+
self,
|
| 292 |
+
*,
|
| 293 |
+
dim: int,
|
| 294 |
+
dim_visual: int,
|
| 295 |
+
dim_head: int = 64,
|
| 296 |
+
heads: int = 8,
|
| 297 |
+
ff_mult: int = 4,
|
| 298 |
+
only_attend_immediate_media: bool = True,
|
| 299 |
+
):
|
| 300 |
+
super().__init__()
|
| 301 |
+
self.attn = FlamingoMaskedCrossAttention(
|
| 302 |
+
dim=dim,
|
| 303 |
+
dim_visual=dim_visual,
|
| 304 |
+
dim_head=dim_head,
|
| 305 |
+
heads=heads,
|
| 306 |
+
only_attend_immediate_media=only_attend_immediate_media,
|
| 307 |
+
)
|
| 308 |
+
self.attn_gate = nn.Parameter(torch.tensor([0.0]))
|
| 309 |
+
self.feed_forward = nn.ModuleList(
|
| 310 |
+
[
|
| 311 |
+
nn.LayerNorm(dim),
|
| 312 |
+
nn.Linear(dim, dim * ff_mult, bias=False),
|
| 313 |
+
nn.GELU(),
|
| 314 |
+
nn.Linear(dim * ff_mult, dim, bias=False),
|
| 315 |
+
]
|
| 316 |
+
)
|
| 317 |
+
self.ff_gate = nn.Parameter(torch.tensor([0.0]))
|
| 318 |
+
|
| 319 |
+
def forward(
|
| 320 |
+
self,
|
| 321 |
+
x: torch.Tensor,
|
| 322 |
+
media: torch.Tensor,
|
| 323 |
+
media_locations: Optional[torch.BoolTensor] = None,
|
| 324 |
+
attend_previous: bool = True,
|
| 325 |
+
) -> torch.Tensor:
|
| 326 |
+
x = (
|
| 327 |
+
self.attn(
|
| 328 |
+
x,
|
| 329 |
+
media,
|
| 330 |
+
media_locations=media_locations,
|
| 331 |
+
attend_previous=attend_previous,
|
| 332 |
+
)
|
| 333 |
+
* self.attn_gate.tanh()
|
| 334 |
+
+ x
|
| 335 |
+
)
|
| 336 |
+
residual_x = x
|
| 337 |
+
for ff in self.feed_forward:
|
| 338 |
+
x = ff(x)
|
| 339 |
+
x = x * self.ff_gate.tanh() + residual_x
|
| 340 |
+
|
| 341 |
+
return x
|
| 342 |
+
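Because `attn_gate` and `ff_gate` are initialized to zero and tanh(0) == 0, the block starts out as an identity over x; a one-line check in plain PyTorch:

import torch

gate = torch.tensor([0.0])
print(gate.tanh())   # tensor([0.]) -> at init, the cross-attention and feed-forward branches contribute nothing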
|
| 343 |
+
|
| 344 |
+
class FlamingoLayer(nn.Module):
|
| 345 |
+
def __init__(self, gated_cross_attn_layer: nn.Module, decoder_layer: nn.Module):
|
| 346 |
+
super().__init__()
|
| 347 |
+
self.gated_cross_attn_layer = gated_cross_attn_layer
|
| 348 |
+
self.decoder_layer = decoder_layer
|
| 349 |
+
self.vis_x = None
|
| 350 |
+
self.media_locations = None
self.attend_previous = None
|
| 351 |
+
|
| 352 |
+
def is_conditioned(self) -> bool:
|
| 353 |
+
"""Check whether the layer is conditioned."""
|
| 354 |
+
return self.vis_x is not None
|
| 355 |
+
|
| 356 |
+
# Borrowed this great conditioning idea from the flamingo-mini implementation (https://github.com/dhansmair/flamingo-mini/)
|
| 357 |
+
def condition_vis_x(self, vis_x) -> None:
|
| 358 |
+
self.vis_x = vis_x
|
| 359 |
+
|
| 360 |
+
def condition_media_locations(self, media_locations) -> None:
|
| 361 |
+
self.media_locations = media_locations
|
| 362 |
+
|
| 363 |
+
def condition_attend_previous(self, attend_previous) -> None:
|
| 364 |
+
self.attend_previous = attend_previous
|
| 365 |
+
|
| 366 |
+
def forward(
|
| 367 |
+
self,
|
| 368 |
+
lang_x: torch.Tensor,
|
| 369 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 370 |
+
**decoder_layer_kwargs,
|
| 371 |
+
):
|
| 372 |
+
if self.gated_cross_attn_layer is None:
|
| 373 |
+
return self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs)
|
| 374 |
+
|
| 375 |
+
if self.vis_x is None:
|
| 376 |
+
raise ValueError("vis_x must be conditioned before forward pass")
|
| 377 |
+
|
| 378 |
+
if self.media_locations is None:
|
| 379 |
+
raise ValueError("media_locations must be conditioned before forward pass")
|
| 380 |
+
|
| 381 |
+
lang_x = self.gated_cross_attn_layer(
|
| 382 |
+
lang_x,
|
| 383 |
+
self.vis_x,
|
| 384 |
+
media_locations=self.media_locations,
|
| 385 |
+
attend_previous=self.attend_previous,
|
| 386 |
+
)
|
| 387 |
+
lang_x = self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs)
|
| 388 |
+
return lang_x
|
| 389 |
+
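A minimal sketch of the pass-through behaviour when no cross-attention layer is attached; `StubDecoderLayer` is a hypothetical stand-in for a real HF decoder layer, not part of this commit:

import torch
import torch.nn as nn

class StubDecoderLayer(nn.Module):
    """Illustrative stand-in that just echoes its hidden states."""
    def forward(self, hidden_states, attention_mask=None, **kwargs):
        return hidden_states

layer = FlamingoLayer(gated_cross_attn_layer=None, decoder_layer=StubDecoderLayer())
out = layer(torch.randn(1, 4, 8))   # delegates straight to the decoder layer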
|
| 390 |
+
|
| 391 |
+
class FlamingoLMMixin(nn.Module):
|
| 392 |
+
"""
|
| 393 |
+
Mixin to add cross-attention layers to a language model.
|
| 394 |
+
"""
|
| 395 |
+
|
| 396 |
+
def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
|
| 397 |
+
self.decoder_layers_attr_name = decoder_layers_attr_name
|
| 398 |
+
|
| 399 |
+
def _get_decoder_layers(self):
|
| 400 |
+
return getattr_recursive(self, self.decoder_layers_attr_name)
|
| 401 |
+
|
| 402 |
+
def _set_decoder_layers(self, value):
|
| 403 |
+
setattr_recursive(self, self.decoder_layers_attr_name, value)
|
| 404 |
+
|
| 405 |
+
def init_flamingo(
|
| 406 |
+
self,
|
| 407 |
+
media_token_id: int,
|
| 408 |
+
vis_hidden_size: int,
|
| 409 |
+
cross_attn_every_n_layers: int,
|
| 410 |
+
use_media_placement_augmentation: bool,
|
| 411 |
+
):
|
| 412 |
+
"""
|
| 413 |
+
Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.
|
| 414 |
+
"""
|
| 415 |
+
|
| 416 |
+
gated_cross_attn_layers = nn.ModuleList(
|
| 417 |
+
[
|
| 418 |
+
FlamingoGatedCrossAttentionBlock(
|
| 419 |
+
dim=self.config.hidden_size,
|
| 420 |
+
dim_visual=vis_hidden_size,
|
| 421 |
+
)
|
| 422 |
+
if (layer_idx + 1) % cross_attn_every_n_layers == 0
|
| 423 |
+
else None
|
| 424 |
+
for layer_idx, _ in enumerate(self._get_decoder_layers())
|
| 425 |
+
]
|
| 426 |
+
)
|
| 427 |
+
self._set_decoder_layers(
|
| 428 |
+
nn.ModuleList(
|
| 429 |
+
[
|
| 430 |
+
FlamingoLayer(gated_cross_attn_layer, decoder_layer)
|
| 431 |
+
for gated_cross_attn_layer, decoder_layer in zip(gated_cross_attn_layers, self._get_decoder_layers())
|
| 432 |
+
]
|
| 433 |
+
)
|
| 434 |
+
)
|
| 435 |
+
self.media_token_id = media_token_id
|
| 436 |
+
self.use_media_placement_augmentation = use_media_placement_augmentation
|
| 437 |
+
self.initialized_flamingo = True
|
| 438 |
+
|
| 439 |
+
def forward(self, *input, **kwargs):
|
| 440 |
+
"""Condition the Flamingo layers on the media locations before forward()"""
|
| 441 |
+
if not self.initialized_flamingo:
|
| 442 |
+
raise ValueError("Flamingo layers are not initialized. Please call `init_flamingo` first.")
|
| 443 |
+
|
| 444 |
+
input_ids = kwargs["input_ids"] if "input_ids" in kwargs else input[0]
|
| 445 |
+
media_locations = input_ids == self.media_token_id
|
| 446 |
+
# IMPORTANT: Force `attend_previous` to True when we place training data as <image>caption<|endofchunk|>
|
| 447 |
+
# attend_previous = (
|
| 448 |
+
# (random.random() < 0.5) if self.use_media_placement_augmentation else False
|
| 449 |
+
# )
|
| 450 |
+
attend_previous = (random.random() < 0.5) if self.use_media_placement_augmentation else True
|
| 451 |
+
# attend_previous = self.only_attend_previous
|
| 452 |
+
|
| 453 |
+
if self.__class__.__name__ == "LlamaForCausalLM":
|
| 454 |
+
for layer in self.get_decoder().layers:
|
| 455 |
+
layer.condition_media_locations(media_locations)
|
| 456 |
+
layer.condition_attend_previous(attend_previous)
|
| 457 |
+
elif self.__class__.__name__ in ["MPTForCausalLM", "MosaicGPT"]:
|
| 458 |
+
for layer in self.get_decoder().blocks:
|
| 459 |
+
layer.condition_media_locations(media_locations)
|
| 460 |
+
layer.condition_attend_previous(attend_previous)
|
| 461 |
+
else:
|
| 462 |
+
print("inavaliable text encoder")
|
| 463 |
+
return super().forward(*input, **kwargs) # Call the other parent's forward method
|
| 464 |
+
|
| 465 |
+
def is_conditioned(self) -> bool:
|
| 466 |
+
"""Check whether all decoder layers are already conditioned."""
|
| 467 |
+
return all(l.is_conditioned() for l in self._get_decoder_layers())
|
| 468 |
+
|
| 469 |
+
def clear_conditioned_layers(self) -> None:
|
| 470 |
+
for layer in self._get_decoder_layers():
|
| 471 |
+
layer.condition_vis_x(None)
|
| 472 |
+
layer.condition_media_locations(None)
|
| 473 |
+
layer.condition_attend_previous(None)
|
| 474 |
+
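As a quick illustration of the placement rule in `init_flamingo`, for an assumed 32-layer decoder with cross_attn_every_n_layers=4, only every fourth layer receives a gated block:

cross_attn_every_n_layers = 4
placed = [i for i in range(32) if (i + 1) % cross_attn_every_n_layers == 0]
print(placed)   # [3, 7, 11, 15, 19, 23, 27, 31] -> 8 gated cross-attention blocks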
|
| 475 |
+
|
| 476 |
+
class FlamingoPreTrainedModel(PreTrainedModel):
|
| 477 |
+
"""
|
| 478 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 479 |
+
models.
|
| 480 |
+
"""
|
| 481 |
+
|
| 482 |
+
config_class = FlamingoConfig
|
| 483 |
+
base_model_prefix = "flamingo"
|
| 484 |
+
supports_gradient_checkpointing = True
|
| 485 |
+
_no_split_modules = ["FlamingoPerceiverBlock", "CLIPEncoderLayer", "FlamingoLayer"]
|
| 486 |
+
|
| 487 |
+
def _init_weights(self, module):
|
| 488 |
+
"""Flamingo requires no specific initialization"""
|
| 489 |
+
return super()._init_weights(module)
|
| 490 |
+
|
| 491 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 492 |
+
if isinstance(module, FlamingoModel):
|
| 493 |
+
module.gradient_checkpointing = value
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
class FlamingoModel(FlamingoPreTrainedModel):
|
| 497 |
+
config_class = FlamingoConfig
|
| 498 |
+
|
| 499 |
+
def __init__(
|
| 500 |
+
self,
|
| 501 |
+
config: FlamingoConfig,
|
| 502 |
+
):
|
| 503 |
+
super().__init__(config)
|
| 504 |
+
### TODO: give "LlamaForCausalLM" as the name of text_config.architectures of Llama_based flamingo
|
| 505 |
+
if "llama" not in config.text_config._name_or_path:
|
| 506 |
+
if config.text_config.architectures[0] == "MPTForCausalLM":
|
| 507 |
+
text_tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b-instruct")
|
| 508 |
+
lang_encoder = MPTForCausalLM(config=config.text_config)
|
| 509 |
+
elif config.text_config.architectures[0] == "MosaicGPT":
|
| 510 |
+
text_tokenizer = AutoTokenizer.from_pretrained("mosaicml/mosaic-llama-redpajama-final-candidate")
|
| 511 |
+
lang_encoder = MosaicGPT(config=config.text_config)
|
| 512 |
+
elif config.text_config.architectures[0] == "RWForCausalLM":
|
| 513 |
+
text_tokenizer = AutoTokenizer.from_pretrained("PATH-TO-YOUR-FALCON")
|
| 514 |
+
lang_encoder = RWForCausalLM(config=config.text_config)
|
| 515 |
+
else:
|
| 516 |
+
text_tokenizer = LlamaTokenizer.from_pretrained(config.text_config._name_or_path)
|
| 517 |
+
lang_encoder = LlamaForCausalLM(config=config.text_config)
|
| 518 |
+
|
| 519 |
+
vision_encoder = CLIPVisionModel(config=config.vision_config)
|
| 520 |
+
text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
|
| 521 |
+
if text_tokenizer.pad_token is None:
|
| 522 |
+
text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
|
| 523 |
+
self.text_tokenizer = text_tokenizer
|
| 524 |
+
self.eoc_token_id = text_tokenizer.encode("<|endofchunk|>")[-1]
|
| 525 |
+
self.media_token_id = text_tokenizer.encode("<image>")[-1]
|
| 526 |
+
|
| 527 |
+
extend_instance(lang_encoder, FlamingoLMMixin)
|
| 528 |
+
decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
|
| 529 |
+
lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
|
| 530 |
+
if lang_encoder.__class__.__name__ == "LlamaForCausalLM":
|
| 531 |
+
lang_encoder.resize_token_embeddings(len(text_tokenizer))
|
| 532 |
+
self.lang_encoder = lang_encoder
|
| 533 |
+
|
| 534 |
+
self.cross_attn_every_n_layers = config.cross_attn_every_n_layers if hasattr(config, "cross_attn_every_n_layers") else 4
|
| 535 |
+
self.use_media_placement_augmentation = config.use_media_placement_augmentation
|
| 536 |
+
|
| 537 |
+
vision_encoder.output_tokens = True
|
| 538 |
+
self.vision_encoder = vision_encoder
|
| 539 |
+
|
| 540 |
+
self.vis_dim = 1024
|
| 541 |
+
self.perceiver = FlamingoPerceiverResampler(dim=self.vis_dim)
|
| 542 |
+
|
| 543 |
+
self.lang_encoder.init_flamingo(
|
| 544 |
+
media_token_id=self.media_token_id,
|
| 545 |
+
vis_hidden_size=self.vis_dim,
|
| 546 |
+
cross_attn_every_n_layers=self.cross_attn_every_n_layers,
|
| 547 |
+
use_media_placement_augmentation=self.use_media_placement_augmentation,
|
| 548 |
+
)
|
| 549 |
+
self.post_init()
|
| 550 |
+
|
| 551 |
+
def get_input_embeddings(self) -> nn.Module:
|
| 552 |
+
return self.lang_encoder.get_input_embeddings()
|
| 553 |
+
|
| 554 |
+
def set_input_embeddings(self, new_embeddings):
|
| 555 |
+
self.lang_encoder.set_input_embeddings(new_embeddings)
|
| 556 |
+
|
| 557 |
+
def get_output_embeddings(self) -> nn.Module:
|
| 558 |
+
return self.lang_encoder.get_output_embeddings()
|
| 559 |
+
|
| 560 |
+
def set_output_embeddings(self, new_embeddings):
|
| 561 |
+
self.lang_encoder.set_output_embeddings(new_embeddings)
|
| 562 |
+
|
| 563 |
+
def get_image_encoder(self) -> nn.Module:
|
| 564 |
+
return self.vision_encoder
|
| 565 |
+
|
| 566 |
+
def get_lang_encoder(self) -> nn.Module:
|
| 567 |
+
return self.lang_encoder
|
| 568 |
+
|
| 569 |
+
# def init_weights(self):
|
| 570 |
+
# # Freeze all parameters in vision encoder
|
| 571 |
+
# for param in self.vision_encoder.parameters():
|
| 572 |
+
# param.requires_grad = False
|
| 573 |
+
# # Freeze all parameters in lang encoders except gated_cross_attn_layers
|
| 574 |
+
# for name, param in self.lang_encoder.named_parameters():
|
| 575 |
+
# if "gated_cross_attn_layer" not in name:
|
| 576 |
+
# param.requires_grad = False
|
| 577 |
+
# # Unfreeze LM input embeddings
|
| 578 |
+
# self.lang_encoder.get_input_embeddings().requires_grad_(True)
|
| 579 |
+
# ## MPTForCausalLM is tied word embedding
|
| 580 |
+
# if self.lang_encoder.__class__.__name__ == "LlamaForCausalLM":
|
| 581 |
+
# self.lang_encoder.lm_head.requires_grad_(True)
|
| 582 |
+
# # assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
|
| 583 |
+
# # print model size in billions of parameters in 2 decimal places
|
| 584 |
+
# print(f"Trainable param: {(sum(p.numel() for p in self.parameters() if p.requires_grad)) / 1e9:.2f} B")
|
| 585 |
+
|
| 586 |
+
def init_weights(self):
|
| 587 |
+
# Freeze all parameters in vision encoder
|
| 588 |
+
for param in self.vision_encoder.parameters():
|
| 589 |
+
param.requires_grad = False
|
| 590 |
+
|
| 591 |
+
if "lora_config" in self.config.__dict__:
|
| 592 |
+
print(f"LoRA trainable param: {(sum(p.numel() for p in self.lang_encoder.parameters() if p.requires_grad)) / 1e9:.3f} B")
|
| 593 |
+
# Unfreeze gated_cross_attn_layers
|
| 594 |
+
for layer in self.lang_encoder._get_decoder_layers():
|
| 595 |
+
if layer.gated_cross_attn_layer is not None:
|
| 596 |
+
for param in layer.gated_cross_attn_layer.parameters():
|
| 597 |
+
param.requires_grad = True
|
| 598 |
+
else:
|
| 599 |
+
# Freeze all parameters in lang encoders except gated_cross_attn_layers
|
| 600 |
+
for name, param in self.lang_encoder.named_parameters():
|
| 601 |
+
if "gated_cross_attn_layer" not in name:
|
| 602 |
+
param.requires_grad = False
|
| 603 |
+
# Unfreeze LM input and output embeddings
|
| 604 |
+
self.lang_encoder.get_input_embeddings().requires_grad_(True)
|
| 605 |
+
## MPTForCausalLM is tied word embedding
|
| 606 |
+
if self.lang_encoder.__class__.__name__ == "LlamaForCausalLM":
|
| 607 |
+
self.lang_encoder.lm_head.requires_grad_(True)
|
| 608 |
+
# assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
|
| 609 |
+
# print model size in billions of parameters in 2 decimal places
|
| 610 |
+
print(f"Total Trainable param: {(sum(p.numel() for p in self.parameters() if p.requires_grad)) / 1e9:.3f} B")
|
| 611 |
+
|
| 612 |
+
def forward(
|
| 613 |
+
self,
|
| 614 |
+
vision_x: torch.Tensor,
|
| 615 |
+
lang_x: torch.Tensor,
|
| 616 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 617 |
+
labels: Optional[torch.Tensor] = None,
|
| 618 |
+
use_cached_vision_x: bool = False,
|
| 619 |
+
clear_conditioned_layers: bool = True,
|
| 620 |
+
past_key_values: Optional[torch.Tensor] = None,
|
| 621 |
+
use_cache: bool = False,
|
| 622 |
+
**kwargs,
|
| 623 |
+
) -> CausalLMOutputWithPast:
|
| 624 |
+
"""
|
| 625 |
+
Forward pass of Flamingo.
|
| 626 |
+
|
| 627 |
+
Args:
|
| 628 |
+
vision_x (torch.Tensor): Vision input
|
| 629 |
+
shape (B, T_img, F, C, H, W) with F=1
|
| 630 |
+
lang_x (torch.Tensor): Language input ids
|
| 631 |
+
shape (B, T_txt)
|
| 632 |
+
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
|
| 633 |
+
labels (torch.Tensor, optional): Labels. Defaults to None.
|
| 634 |
+
clear_conditioned_layers: if True, clear the conditioned layers
|
| 635 |
+
once the forward pass is completed. Set this to False if the
|
| 636 |
+
same set of images will be reused in another subsequent
|
| 637 |
+
forward pass.
|
| 638 |
+
past_key_values: pre-computed values to pass to language model.
|
| 639 |
+
See past_key_values documentation in Hugging Face
|
| 640 |
+
CausalLM models.
|
| 641 |
+
use_cache: whether to use cached key values. See use_cache
|
| 642 |
+
documentation in Hugging Face CausalLM models.
|
| 643 |
+
"""
|
| 644 |
+
assert (vision_x is not None) or use_cached_vision_x, "Must provide vision_x or set use_cached_vision_x to True."
|
| 645 |
+
|
| 646 |
+
if use_cached_vision_x:
|
| 647 |
+
# Case: use cached; vision_x should be cached and other
|
| 648 |
+
# vision-related inputs should not be provided.
|
| 649 |
+
assert vision_x is None, "Expect vision_x to be None when use_cached_vision_x is True."
|
| 650 |
+
assert self.lang_encoder.is_conditioned()
|
| 651 |
+
|
| 652 |
+
else:
|
| 653 |
+
# Case: do not use caching (i.e. this is a standard forward pass);
|
| 654 |
+
self._encode_vision_x(vision_x=vision_x)
|
| 655 |
+
|
| 656 |
+
output = self.lang_encoder(
|
| 657 |
+
input_ids=lang_x,
|
| 658 |
+
attention_mask=attention_mask,
|
| 659 |
+
labels=labels,
|
| 660 |
+
past_key_values=past_key_values,
|
| 661 |
+
use_cache=use_cache,
|
| 662 |
+
**kwargs,
|
| 663 |
+
)
|
| 664 |
+
|
| 665 |
+
if clear_conditioned_layers:
|
| 666 |
+
self.lang_encoder.clear_conditioned_layers()
|
| 667 |
+
|
| 668 |
+
return output
|
| 669 |
+
|
| 670 |
+
def _encode_vision_x(self, vision_x: torch.Tensor):
|
| 671 |
+
"""
|
| 672 |
+
Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
|
| 673 |
+
Args:
|
| 674 |
+
vision_x (torch.Tensor): Vision input
|
| 675 |
+
shape (B, T_img, F, C, H, W)
|
| 676 |
+
Images in the same chunk are collated along T_img, and frames are collated along F
|
| 677 |
+
Currently only F=1 is supported (single-frame videos)
|
| 678 |
+
|
| 679 |
+
rearrange code based on https://github.com/dhansmair/flamingo-mini
|
| 680 |
+
"""
|
| 681 |
+
|
| 682 |
+
assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
|
| 683 |
+
b, T, F = vision_x.shape[:3]
|
| 684 |
+
assert F == 1, "Only single frame supported"
|
| 685 |
+
|
| 686 |
+
vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
|
| 687 |
+
with torch.no_grad():
|
| 688 |
+
vision_x = self.vision_encoder(vision_x)[0][:, 1:, :]
|
| 689 |
+
vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
|
| 690 |
+
|
| 691 |
+
vision_x = self.perceiver(vision_x) # reshapes to (b, T, n, d)
|
| 692 |
+
|
| 693 |
+
for layer in self.lang_encoder._get_decoder_layers():
|
| 694 |
+
layer.condition_vis_x(vision_x)
|
| 695 |
+
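A sketch of how a batch of single-frame images can be reshaped into the 6-D `vision_x` layout `_encode_vision_x` expects; the 224x224 CLIP resolution is an assumption:

import torch

pixel_values = torch.randn(2, 3, 224, 224)           # (B*T_img*F, C, H, W) from an image processor
vision_x = pixel_values.view(2, 1, 1, 3, 224, 224)   # -> (B, T_img, F, C, H, W) with T_img = F = 1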
|
| 696 |
+
|
| 697 |
+
class FlamingoForConditionalGeneration(FlamingoPreTrainedModel):
|
| 698 |
+
config_class = FlamingoConfig
|
| 699 |
+
|
| 700 |
+
def __init__(
|
| 701 |
+
self,
|
| 702 |
+
config: FlamingoConfig,
|
| 703 |
+
):
|
| 704 |
+
super().__init__(config)
|
| 705 |
+
# TODO: hardcoded here because the Auto* factory classes are too slow
|
| 706 |
+
# vision_encoder = AutoModel.from_config(config.vision_config).vision_model
|
| 707 |
+
# lang_encoder = AutoModelForCausalLM.from_config(config.text_config)
|
| 708 |
+
# text_tokenizer = AutoTokenizer.from_pretrained(config.text_config._name_or_path)
|
| 709 |
+
|
| 710 |
+
### TODO: give "LlamaForCausalLM" as the name of text_config.architectures of Llama_based flamingo
|
| 711 |
+
# assert hasattr(config.text_config, "_name_or_path")
|
| 712 |
+
# if "llama" not in config.text_config._name_or_path.lower():
|
| 713 |
+
if config.text_config.architectures[0] == "MPTForCausalLM":
|
| 714 |
+
text_tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b-instruct")
|
| 715 |
+
lang_encoder = MPTForCausalLM(config=config.text_config)
|
| 716 |
+
elif config.text_config.architectures[0] == "MosaicGPT":
|
| 717 |
+
text_tokenizer = AutoTokenizer.from_pretrained("mosaicml/mosaic-llama-redpajama-final-candidate")
|
| 718 |
+
lang_encoder = MosaicGPT(config=config.text_config)
|
| 719 |
+
elif config.text_config.architectures[0] == "RWForCausalLM":
|
| 720 |
+
text_tokenizer = AutoTokenizer.from_pretrained("PATH-TO-YOUR-FALCON")
|
| 721 |
+
lang_encoder = RWForCausalLM(config=config.text_config)
|
| 722 |
+
# TODO: what's the logic here?
|
| 723 |
+
elif config.text_config.architectures[0] == "LlamaForCausalLM":
|
| 724 |
+
text_tokenizer = LlamaTokenizer.from_pretrained(config.text_config._name_or_path)
|
| 725 |
+
lang_encoder = LlamaForCausalLM(config=config.text_config)
|
| 726 |
+
else:
|
| 727 |
+
raise ValueError(f"Unsupported text architecture: {config.text_config.architectures[0]}")
|
| 728 |
+
|
| 729 |
+
|
| 730 |
+
# else:
|
| 731 |
+
# text_tokenizer = LlamaTokenizer.from_pretrained(config.text_config._name_or_path)
|
| 732 |
+
# lang_encoder = LlamaForCausalLM(config=config.text_config)
|
| 733 |
+
|
| 734 |
+
vision_encoder = CLIPVisionModel(config=config.vision_config)
|
| 735 |
+
text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
|
| 736 |
+
if text_tokenizer.pad_token is None:
|
| 737 |
+
text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
|
| 738 |
+
self.text_tokenizer = text_tokenizer
|
| 739 |
+
self.eoc_token_id = text_tokenizer.encode("<|endofchunk|>")[-1]
|
| 740 |
+
self.media_token_id = text_tokenizer.encode("<image>")[-1]
|
| 741 |
+
|
| 742 |
+
extend_instance(lang_encoder, FlamingoLMMixin)
|
| 743 |
+
decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
|
| 744 |
+
lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
|
| 745 |
+
if "LlamaForCausalLM" in lang_encoder.__class__.__name__:
|
| 746 |
+
lang_encoder.resize_token_embeddings(len(text_tokenizer))
|
| 747 |
+
self.lang_encoder = lang_encoder
|
| 748 |
+
|
| 749 |
+
self.cross_attn_every_n_layers = config.cross_attn_every_n_layers if hasattr(config, "cross_attn_every_n_layers") else 4
|
| 750 |
+
self.use_media_placement_augmentation = config.use_media_placement_augmentation
|
| 751 |
+
|
| 752 |
+
vision_encoder.output_tokens = True
|
| 753 |
+
self.vision_encoder = vision_encoder
|
| 754 |
+
|
| 755 |
+
self.vis_dim = 1024
|
| 756 |
+
self.perceiver = FlamingoPerceiverResampler(dim=self.vis_dim)
|
| 757 |
+
|
| 758 |
+
self.lang_encoder.init_flamingo(
|
| 759 |
+
media_token_id=self.media_token_id,
|
| 760 |
+
vis_hidden_size=self.vis_dim,
|
| 761 |
+
cross_attn_every_n_layers=self.cross_attn_every_n_layers,
|
| 762 |
+
use_media_placement_augmentation=self.use_media_placement_augmentation,
|
| 763 |
+
)
|
| 764 |
+
self.post_init()
|
| 765 |
+
|
| 766 |
+
def get_input_embeddings(self) -> nn.Module:
|
| 767 |
+
return self.lang_encoder.get_input_embeddings()
|
| 768 |
+
|
| 769 |
+
def set_input_embeddings(self, new_embeddings):
|
| 770 |
+
self.lang_encoder.set_input_embeddings(new_embeddings)
|
| 771 |
+
|
| 772 |
+
def get_output_embeddings(self) -> nn.Module:
|
| 773 |
+
return self.lang_encoder.get_output_embeddings()
|
| 774 |
+
|
| 775 |
+
def set_output_embeddings(self, new_embeddings):
|
| 776 |
+
self.lang_encoder.set_output_embeddings(new_embeddings)
|
| 777 |
+
|
| 778 |
+
def get_image_encoder(self) -> nn.Module:
|
| 779 |
+
return self.vision_encoder
|
| 780 |
+
|
| 781 |
+
def get_lang_encoder(self) -> nn.Module:
|
| 782 |
+
return self.lang_encoder
|
| 783 |
+
|
| 784 |
+
def init_weights(self):
|
| 785 |
+
# Freeze all parameters in vision encoder
|
| 786 |
+
for param in self.vision_encoder.parameters():
|
| 787 |
+
param.requires_grad = False
|
| 788 |
+
# Freeze all parameters in lang encoders except gated_cross_attn_layers
|
| 789 |
+
for name, param in self.lang_encoder.named_parameters():
|
| 790 |
+
if "gated_cross_attn_layer" not in name:
|
| 791 |
+
param.requires_grad = False
|
| 792 |
+
# Unfreeze LM input embeddings
|
| 793 |
+
self.lang_encoder.get_input_embeddings().requires_grad_(True)
|
| 794 |
+
## MPTForCausalLM is tied word embedding
|
| 795 |
+
if "LlamaForCausalLM" in self.lang_encoder.__class__.__name__:
|
| 796 |
+
self.lang_encoder.lm_head.requires_grad_(True)
|
| 797 |
+
# assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
|
| 798 |
+
# print model size in billions of parameters in 2 decimal places
|
| 799 |
+
print("====================Model Grad Part====================")
|
| 800 |
+
total_params = 0
|
| 801 |
+
for name, param in self.named_parameters():
|
| 802 |
+
if param.requires_grad:
|
| 803 |
+
total_params += param.numel()
|
| 804 |
+
print(f"Parameter: {name}, Size: {param.numel() / 1e6:.6f} M")
|
| 805 |
+
print(f"Total Trainable param: {total_params / 1e9:.4f} B")
|
| 806 |
+
print(f"Total Trainable param: {(sum(p.numel() for p in self.parameters() if p.requires_grad)) / 1e9:.3f} B")
|
| 807 |
+
|
| 808 |
+
def forward(
|
| 809 |
+
self,
|
| 810 |
+
vision_x: torch.Tensor,
|
| 811 |
+
lang_x: torch.Tensor,
|
| 812 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 813 |
+
labels: Optional[torch.Tensor] = None,
|
| 814 |
+
use_cached_vision_x: bool = False,
|
| 815 |
+
clear_conditioned_layers: bool = True,
|
| 816 |
+
past_key_values: Optional[torch.Tensor] = None,
|
| 817 |
+
use_cache: bool = False,
|
| 818 |
+
**kwargs,
|
| 819 |
+
) -> CausalLMOutputWithPast:
|
| 820 |
+
"""
|
| 821 |
+
Forward pass of Flamingo.
|
| 822 |
+
|
| 823 |
+
Args:
|
| 824 |
+
vision_x (torch.Tensor): Vision input
|
| 825 |
+
shape (B, T_img, F, C, H, W) with F=1
|
| 826 |
+
lang_x (torch.Tensor): Language input ids
|
| 827 |
+
shape (B, T_txt)
|
| 828 |
+
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
|
| 829 |
+
labels (torch.Tensor, optional): Labels. Defaults to None.
|
| 830 |
+
clear_conditioned_layers: if True, clear the conditioned layers
|
| 831 |
+
once the forward pass is completed. Set this to False if the
|
| 832 |
+
same set of images will be reused in another subsequent
|
| 833 |
+
forward pass.
|
| 834 |
+
past_key_values: pre-computed values to pass to language model.
|
| 835 |
+
See past_key_values documentation in Hugging Face
|
| 836 |
+
CausalLM models.
|
| 837 |
+
use_cache: whether to use cached key values. See use_cache
|
| 838 |
+
documentation in Hugging Face CausalLM models.
|
| 839 |
+
"""
|
| 840 |
+
assert (vision_x is not None) or use_cached_vision_x, "Must provide vision_x or set use_cached_vision_x to True."
|
| 841 |
+
|
| 842 |
+
if use_cached_vision_x:
|
| 843 |
+
# Case: use cached; vision_x should be cached and other
|
| 844 |
+
# vision-related inputs should not be provided.
|
| 845 |
+
assert vision_x is None, "Expect vision_x to be None when use_cached_vision_x is True."
|
| 846 |
+
assert self.lang_encoder.is_conditioned()
|
| 847 |
+
|
| 848 |
+
else:
|
| 849 |
+
# Case: do not use caching (i.e. this is a standard forward pass);
|
| 850 |
+
self._encode_vision_x(vision_x=vision_x)
|
| 851 |
+
|
| 852 |
+
output = self.lang_encoder(
|
| 853 |
+
input_ids=lang_x,
|
| 854 |
+
attention_mask=attention_mask,
|
| 855 |
+
labels=labels,
|
| 856 |
+
past_key_values=past_key_values,
|
| 857 |
+
use_cache=use_cache,
|
| 858 |
+
**kwargs,
|
| 859 |
+
)
|
| 860 |
+
|
| 861 |
+
if clear_conditioned_layers:
|
| 862 |
+
self.lang_encoder.clear_conditioned_layers()
|
| 863 |
+
|
| 864 |
+
return output
|
| 865 |
+
|
| 866 |
+
def _encode_vision_x(self, vision_x: torch.Tensor):
|
| 867 |
+
"""
|
| 868 |
+
Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
|
| 869 |
+
Args:
|
| 870 |
+
vision_x (torch.Tensor): Vision input
|
| 871 |
+
shape (B, T_img, F, C, H, W)
|
| 872 |
+
Images in the same chunk are collated along T_img, and frames are collated along F
|
| 873 |
+
Currently only F=1 is supported (single-frame videos)
|
| 874 |
+
|
| 875 |
+
rearrange code based on https://github.com/dhansmair/flamingo-mini
|
| 876 |
+
"""
|
| 877 |
+
|
| 878 |
+
assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
|
| 879 |
+
b, T, F = vision_x.shape[:3]
|
| 880 |
+
# assert F == 1, "Only single frame supported"
|
| 881 |
+
|
| 882 |
+
vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
|
| 883 |
+
with torch.no_grad():
|
| 884 |
+
vision_x = self.vision_encoder(vision_x)[0][:, 1:, :]
|
| 885 |
+
vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
|
| 886 |
+
|
| 887 |
+
vision_x = self.perceiver(vision_x) # reshapes to (b, T, n, d)
|
| 888 |
+
|
| 889 |
+
for layer in self.lang_encoder._get_decoder_layers():
|
| 890 |
+
layer.condition_vis_x(vision_x)
|
| 891 |
+
|
| 892 |
+
@torch.no_grad()
|
| 893 |
+
def generate(
|
| 894 |
+
self,
|
| 895 |
+
vision_x: torch.Tensor,
|
| 896 |
+
lang_x: torch.Tensor,
|
| 897 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 898 |
+
num_beams: int = 1,
|
| 899 |
+
max_new_tokens: Optional[int] = None,
|
| 900 |
+
temperature: float = 1.0,
|
| 901 |
+
top_k: int = 0,
|
| 902 |
+
top_p: float = 1.0,
|
| 903 |
+
no_repeat_ngram_size: int = 0,
|
| 904 |
+
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
|
| 905 |
+
length_penalty: float = 1.0,
|
| 906 |
+
num_return_sequences: int = 1,
|
| 907 |
+
do_sample: bool = False,
|
| 908 |
+
early_stopping: bool = False,
|
| 909 |
+
**kwargs,
|
| 910 |
+
):
|
| 911 |
+
"""
|
| 912 |
+
Generate text conditioned on vision and language inputs.
|
| 913 |
+
|
| 914 |
+
Args:
|
| 915 |
+
vision_x (torch.Tensor): Vision input
|
| 916 |
+
shape (B, T_img, F, C, H, W)
|
| 917 |
+
images in the same chunk are collated along T_img, and frames are collated along F
|
| 918 |
+
currently only F=1 is supported (single-frame videos)
|
| 919 |
+
lang_x (torch.Tensor): Language input
|
| 920 |
+
shape (B, T_txt)
|
| 921 |
+
max_length (int, optional): Maximum length of the output. Defaults to None.
|
| 922 |
+
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
|
| 923 |
+
num_beams (int, optional): Number of beams. Defaults to 1.
|
| 924 |
+
max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
|
| 925 |
+
temperature (float, optional): Temperature. Defaults to 1.0.
|
| 926 |
+
top_k (int, optional): Top k. Defaults to 0.
|
| 927 |
+
top_p (float, optional): Top p. Defaults to 1.0.
|
| 928 |
+
no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
|
| 929 |
+
length_penalty (float, optional): Length penalty. Defaults to 1.0.
|
| 930 |
+
num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
|
| 931 |
+
do_sample (bool, optional): Do sample. Defaults to False.
|
| 932 |
+
early_stopping (bool, optional): Early stopping. Defaults to False.
|
| 933 |
+
Returns:
|
| 934 |
+
torch.Tensor: lang_x with generated tokens appended to it
|
| 935 |
+
"""
|
| 936 |
+
if hasattr(self, "_hf_hook"):
|
| 937 |
+
# add a hook to make sure that the output of lang_encoder is mapped to the same device as the lang_x
|
| 938 |
+
hook = AlignDevicesHook(
|
| 939 |
+
execution_device=lang_x.device,
|
| 940 |
+
io_same_device=True,
|
| 941 |
+
place_submodules=False,
|
| 942 |
+
)
|
| 943 |
+
add_hook_to_module(self.lang_encoder, hook)
|
| 944 |
+
if num_beams > 1:
|
| 945 |
+
vision_x = vision_x.repeat_interleave(num_beams, dim=0)
|
| 946 |
+
self._encode_vision_x(vision_x=vision_x)
|
| 947 |
+
output = self.lang_encoder.generate(
|
| 948 |
+
lang_x,
|
| 949 |
+
attention_mask=attention_mask,
|
| 950 |
+
eos_token_id=self.eoc_token_id,
|
| 951 |
+
num_beams=num_beams,
|
| 952 |
+
max_new_tokens=max_new_tokens,
|
| 953 |
+
temperature=temperature,
|
| 954 |
+
top_k=top_k,
|
| 955 |
+
top_p=top_p,
|
| 956 |
+
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
| 957 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
| 958 |
+
length_penalty=length_penalty,
|
| 959 |
+
num_return_sequences=num_return_sequences,
|
| 960 |
+
do_sample=do_sample,
|
| 961 |
+
early_stopping=early_stopping,
|
| 962 |
+
**kwargs,
|
| 963 |
+
)
|
| 964 |
+
|
| 965 |
+
self.lang_encoder.clear_conditioned_layers()
|
| 966 |
+
return output
|
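For context, a hypothetical end-to-end generation sketch against the class above; `config` is assumed to be a fully populated FlamingoConfig, and the prompt and random pixels are purely illustrative:

import torch

model = FlamingoForConditionalGeneration(config)   # config: an assumed FlamingoConfig instance
tok = model.text_tokenizer
lang_x = tok("<image>A photo of", return_tensors="pt").input_ids
vision_x = torch.randn(1, 1, 1, 3, 224, 224)       # (B, T_img, F, C, H, W)
out = model.generate(vision_x=vision_x, lang_x=lang_x, max_new_tokens=16)
print(tok.decode(out[0]))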
mllm/flamingo/mpt/__init__.py
ADDED
|
File without changes
|
mllm/flamingo/mpt/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (206 Bytes)
|
mllm/flamingo/mpt/__pycache__/attention.cpython-39.pyc
ADDED
|
Binary file (12.2 kB)
|
mllm/flamingo/mpt/__pycache__/blocks.cpython-39.pyc
ADDED
|
Binary file (2.81 kB)
|
mllm/flamingo/mpt/__pycache__/configuration_mpt.cpython-39.pyc
ADDED
|
Binary file (8.76 kB)
|
mllm/flamingo/mpt/__pycache__/custom_embedding.cpython-39.pyc
ADDED
|
Binary file (797 Bytes)
|
mllm/flamingo/mpt/__pycache__/flash_attn_triton.cpython-39.pyc
ADDED
|
Binary file (20.9 kB)
|
mllm/flamingo/mpt/__pycache__/modeling_mpt.cpython-39.pyc
ADDED
|
Binary file (15.3 kB)
|
mllm/flamingo/mpt/__pycache__/norm.cpython-39.pyc
ADDED
|
Binary file (3.03 kB)
|
mllm/flamingo/mpt/__pycache__/param_init_fns.cpython-39.pyc
ADDED
|
Binary file (9.14 kB)
|
mllm/flamingo/mpt/adapt_tokenizer.py
ADDED
|
@@ -0,0 +1,44 @@
|
| 1 |
+
from typing import Union
|
| 2 |
+
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
|
| 3 |
+
|
| 4 |
+
Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
|
| 5 |
+
NUM_SENTINEL_TOKENS: int = 100
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
|
| 9 |
+
"""Adds sentinel tokens and padding token (if missing).
|
| 10 |
+
|
| 11 |
+
Expands the tokenizer vocabulary to include sentinel tokens
|
| 12 |
+
used in mixture-of-denoiser tasks as well as a padding token.
|
| 13 |
+
|
| 14 |
+
All added tokens are added as special tokens. No tokens are
|
| 15 |
+
added if sentinel tokens and padding token already exist.
|
| 16 |
+
"""
|
| 17 |
+
sentinels_to_add = [f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)]
|
| 18 |
+
tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
|
| 19 |
+
if tokenizer.pad_token is None:
|
| 20 |
+
tokenizer.add_tokens("<pad>", special_tokens=True)
|
| 21 |
+
tokenizer.pad_token = "<pad>"
|
| 22 |
+
assert tokenizer.pad_token_id is not None
|
| 23 |
+
sentinels = "".join([f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)])
|
| 24 |
+
_sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
|
| 25 |
+
tokenizer.sentinel_token_ids = _sentinel_token_ids
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class AutoTokenizerForMOD(AutoTokenizer):
|
| 29 |
+
"""AutoTokenizer + Adaptation for MOD.
|
| 30 |
+
|
| 31 |
+
A simple wrapper around AutoTokenizer to make instantiating
|
| 32 |
+
an MOD-adapted tokenizer a bit easier.
|
| 33 |
+
|
| 34 |
+
MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>),
|
| 35 |
+
a padding token, and a property to get the token ids of the
|
| 36 |
+
sentinel tokens.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
@classmethod
|
| 40 |
+
def from_pretrained(cls, *args, **kwargs):
|
| 41 |
+
"""See `AutoTokenizer.from_pretrained` docstring."""
|
| 42 |
+
tokenizer = super().from_pretrained(*args, **kwargs)
|
| 43 |
+
adapt_tokenizer_for_denoising(tokenizer)
|
| 44 |
+
return tokenizer
|
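A usage sketch for the adapter above; "gpt2" is just an example checkpoint that ships without a pad token:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
adapt_tokenizer_for_denoising(tok)
print(len(tok.sentinel_token_ids))   # 100, one id per sentinel (NUM_SENTINEL_TOKENS)
print(tok.pad_token)                 # "<pad>", added because gpt2 has none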
mllm/flamingo/mpt/attention.py
ADDED
|
@@ -0,0 +1,450 @@
|
| 1 |
+
"""Attention layers."""
|
| 2 |
+
import math
|
| 3 |
+
import warnings
|
| 4 |
+
from typing import Optional
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
from packaging import version
|
| 9 |
+
|
| 10 |
+
from .norm import LPLayerNorm
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool):
|
| 14 |
+
if original_is_causal and num_query_tokens != num_key_tokens:
|
| 15 |
+
if num_query_tokens != 1:
|
| 16 |
+
raise NotImplementedError("MPT does not support query and key with different number of tokens, unless number of query tokens is 1.")
|
| 17 |
+
else:
|
| 18 |
+
return False
|
| 19 |
+
return original_is_causal
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def scaled_multihead_dot_product_attention(
|
| 23 |
+
query,
|
| 24 |
+
key,
|
| 25 |
+
value,
|
| 26 |
+
n_heads,
|
| 27 |
+
past_key_value=None,
|
| 28 |
+
softmax_scale=None,
|
| 29 |
+
attn_bias=None,
|
| 30 |
+
key_padding_mask=None,
|
| 31 |
+
is_causal=False,
|
| 32 |
+
dropout_p=0.0,
|
| 33 |
+
training=False,
|
| 34 |
+
needs_weights=False,
|
| 35 |
+
multiquery=False,
|
| 36 |
+
):
|
| 37 |
+
q = rearrange(query, "b s (h d) -> b h s d", h=n_heads)
|
| 38 |
+
kv_n_heads = 1 if multiquery else n_heads
|
| 39 |
+
k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads)
|
| 40 |
+
v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads)
|
| 41 |
+
if past_key_value is not None:
|
| 42 |
+
if len(past_key_value) != 0:
|
| 43 |
+
k = torch.cat([past_key_value[0], k], dim=3)
|
| 44 |
+
v = torch.cat([past_key_value[1], v], dim=2)
|
| 45 |
+
past_key_value = (k, v)
|
| 46 |
+
(b, _, s_q, d) = q.shape
|
| 47 |
+
s_k = k.size(-1)
|
| 48 |
+
if softmax_scale is None:
|
| 49 |
+
softmax_scale = 1 / math.sqrt(d)
|
| 50 |
+
attn_weight = q.matmul(k) * softmax_scale
|
| 51 |
+
if attn_bias is not None:
|
| 52 |
+
_s_q = max(0, attn_bias.size(2) - s_q)
|
| 53 |
+
_s_k = max(0, attn_bias.size(3) - s_k)
|
| 54 |
+
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
|
| 55 |
+
if (attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k) or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
|
| 56 |
+
raise RuntimeError(f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.")
|
| 57 |
+
attn_weight = attn_weight + attn_bias
|
| 58 |
+
min_val = torch.finfo(q.dtype).min
|
| 59 |
+
if key_padding_mask is not None:
|
| 60 |
+
if attn_bias is not None:
|
| 61 |
+
warnings.warn(
|
| 62 |
+
"Propogating key_padding_mask to the attention module "
|
| 63 |
+
+ "and applying it within the attention module can cause "
|
| 64 |
+
+ "unneccessary computation/memory usage. Consider integrating "
|
| 65 |
+
+ "into attn_bias once and passing that to each attention "
|
| 66 |
+
+ "module instead."
|
| 67 |
+
)
|
| 68 |
+
attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
|
| 69 |
+
if is_causal and (not q.size(2) == 1):
|
| 70 |
+
s = max(s_q, s_k)
|
| 71 |
+
causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
|
| 72 |
+
causal_mask = causal_mask.tril()
|
| 73 |
+
causal_mask = causal_mask.to(torch.bool)
|
| 74 |
+
causal_mask = ~causal_mask
|
| 75 |
+
causal_mask = causal_mask[-s_q:, -s_k:]
|
| 76 |
+
attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
|
| 77 |
+
attn_weight = torch.softmax(attn_weight, dim=-1)
|
| 78 |
+
if dropout_p:
|
| 79 |
+
attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
|
| 80 |
+
out = attn_weight.to(v.dtype).matmul(v)
|
| 81 |
+
out = rearrange(out, "b h s d -> b s (h d)")
|
| 82 |
+
if needs_weights:
|
| 83 |
+
return (out, attn_weight, past_key_value)
|
| 84 |
+
return (out, None, past_key_value)
|
| 85 |
+
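A minimal CPU-side call sketch for the torch implementation above (all sizes assumed):

import torch

b, s, d_model, n_heads = 2, 5, 64, 4
q = k = v = torch.randn(b, s, d_model)
out, _, _ = scaled_multihead_dot_product_attention(q, k, v, n_heads, is_causal=True)
print(out.shape)   # torch.Size([2, 5, 64])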
|
| 86 |
+
|
| 87 |
+
def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
|
| 88 |
+
for tensor in tensors:
|
| 89 |
+
if tensor.dtype not in valid_dtypes:
|
| 90 |
+
raise TypeError(f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.")
|
| 91 |
+
if not tensor.is_cuda:
|
| 92 |
+
raise TypeError(f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).")
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def flash_attn_fn(
|
| 96 |
+
query,
|
| 97 |
+
key,
|
| 98 |
+
value,
|
| 99 |
+
n_heads,
|
| 100 |
+
past_key_value=None,
|
| 101 |
+
softmax_scale=None,
|
| 102 |
+
attn_bias=None,
|
| 103 |
+
key_padding_mask=None,
|
| 104 |
+
is_causal=False,
|
| 105 |
+
dropout_p=0.0,
|
| 106 |
+
training=False,
|
| 107 |
+
needs_weights=False,
|
| 108 |
+
multiquery=False,
|
| 109 |
+
):
|
| 110 |
+
try:
|
| 111 |
+
from flash_attn import bert_padding, flash_attn_interface
|
| 112 |
+
except ImportError:
|
| 113 |
+
raise RuntimeError("Please install flash-attn==1.0.3.post0")
|
| 114 |
+
check_valid_inputs(query, key, value)
|
| 115 |
+
if past_key_value is not None:
|
| 116 |
+
if len(past_key_value) != 0:
|
| 117 |
+
key = torch.cat([past_key_value[0], key], dim=1)
|
| 118 |
+
value = torch.cat([past_key_value[1], value], dim=1)
|
| 119 |
+
past_key_value = (key, value)
|
| 120 |
+
if attn_bias is not None:
|
| 121 |
+
_s_q = max(0, attn_bias.size(2) - query.size(1))
|
| 122 |
+
_s_k = max(0, attn_bias.size(3) - key.size(1))
|
| 123 |
+
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
|
| 124 |
+
if attn_bias is not None:
|
| 125 |
+
raise NotImplementedError(f"attn_bias not implemented for flash attn.")
|
| 126 |
+
(batch_size, seqlen) = query.shape[:2]
|
| 127 |
+
if key_padding_mask is None:
|
| 128 |
+
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
|
| 129 |
+
query_padding_mask = key_padding_mask[:, -query.size(1) :]
|
| 130 |
+
(query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask)
|
| 131 |
+
query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads)
|
| 132 |
+
(key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask)
|
| 133 |
+
key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads)
|
| 134 |
+
(value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
|
| 135 |
+
value_unpad = rearrange(value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads)
|
| 136 |
+
if multiquery:
|
| 137 |
+
key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
|
| 138 |
+
value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1))
|
| 139 |
+
dropout_p = dropout_p if training else 0.0
|
| 140 |
+
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
| 141 |
+
output_unpad = flash_attn_interface.flash_attn_unpadded_func(
|
| 142 |
+
query_unpad,
|
| 143 |
+
key_unpad,
|
| 144 |
+
value_unpad,
|
| 145 |
+
cu_seqlens_q,
|
| 146 |
+
cu_seqlens_k,
|
| 147 |
+
max_seqlen_q,
|
| 148 |
+
max_seqlen_k,
|
| 149 |
+
dropout_p,
|
| 150 |
+
softmax_scale=softmax_scale,
|
| 151 |
+
causal=reset_is_causal,
|
| 152 |
+
return_attn_probs=needs_weights,
|
| 153 |
+
)
|
| 154 |
+
output = bert_padding.pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen)
|
| 155 |
+
return (output, None, past_key_value)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def triton_flash_attn_fn(
|
| 159 |
+
query,
|
| 160 |
+
key,
|
| 161 |
+
value,
|
| 162 |
+
n_heads,
|
| 163 |
+
past_key_value=None,
|
| 164 |
+
softmax_scale=None,
|
| 165 |
+
attn_bias=None,
|
| 166 |
+
key_padding_mask=None,
|
| 167 |
+
is_causal=False,
|
| 168 |
+
dropout_p=0.0,
|
| 169 |
+
training=False,
|
| 170 |
+
needs_weights=False,
|
| 171 |
+
multiquery=False,
|
| 172 |
+
):
|
| 173 |
+
try:
|
| 174 |
+
from .flash_attn_triton import flash_attn_func
|
| 175 |
+
except ImportError:
|
| 176 |
+
_installed = False
|
| 177 |
+
if version.parse(torch.__version__) < version.parse("2.0.0"):
|
| 178 |
+
_installed = True
|
| 179 |
+
try:
|
| 180 |
+
from flash_attn.flash_attn_triton import flash_attn_func
|
| 181 |
+
except ImportError:
|
| 182 |
+
_installed = False
|
| 183 |
+
if not _installed:
|
| 184 |
+
raise RuntimeError(
|
| 185 |
+
"Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed."
|
| 186 |
+
)
|
| 187 |
+
check_valid_inputs(query, key, value)
|
| 188 |
+
if past_key_value is not None:
|
| 189 |
+
if len(past_key_value) != 0:
|
| 190 |
+
key = torch.cat([past_key_value[0], key], dim=1)
|
| 191 |
+
value = torch.cat([past_key_value[1], value], dim=1)
|
| 192 |
+
past_key_value = (key, value)
|
| 193 |
+
if attn_bias is not None:
|
| 194 |
+
_s_q = max(0, attn_bias.size(2) - query.size(1))
|
| 195 |
+
_s_k = max(0, attn_bias.size(3) - key.size(1))
|
| 196 |
+
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
|
| 197 |
+
if dropout_p:
|
| 198 |
+
raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.")
|
| 199 |
+
if needs_weights:
|
| 200 |
+
raise NotImplementedError(f"attn_impl: triton cannot return attn weights.")
|
| 201 |
+
if key_padding_mask is not None:
|
| 202 |
+
warnings.warn(
|
| 203 |
+
"Propagating key_padding_mask to the attention module "
|
| 204 |
+
+ "and applying it within the attention module can cause "
|
| 205 |
+
+ "unnecessary computation/memory usage. Consider integrating "
|
| 206 |
+
+ "into attn_bias once and passing that to each attention "
|
| 207 |
+
+ "module instead."
|
| 208 |
+
)
|
| 209 |
+
(b_size, s_k) = key_padding_mask.shape[:2]
|
| 210 |
+
if attn_bias is None:
|
| 211 |
+
attn_bias = query.new_zeros(b_size, 1, 1, s_k)
|
| 212 |
+
attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min)
|
| 213 |
+
query = rearrange(query, "b s (h d) -> b s h d", h=n_heads)
|
| 214 |
+
key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
|
| 215 |
+
value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
|
| 216 |
+
if multiquery:
|
| 217 |
+
key = key.expand(*key.shape[:2], n_heads, key.size(-1))
|
| 218 |
+
value = value.expand(*value.shape[:2], n_heads, value.size(-1))
|
| 219 |
+
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
| 220 |
+
attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
|
| 221 |
+
output = attn_output.view(*attn_output.shape[:2], -1)
|
| 222 |
+
return (output, None, past_key_value)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class MultiheadAttention(nn.Module):
|
| 226 |
+
"""Multi-head self attention.
|
| 227 |
+
|
| 228 |
+
Using torch or triton attention implemetation enables user to also use
|
| 229 |
+
additive bias.
|
| 230 |
+
"""
|
| 231 |
+
|
| 232 |
+
def __init__(
|
| 233 |
+
self,
|
| 234 |
+
d_model: int,
|
| 235 |
+
n_heads: int,
|
| 236 |
+
attn_impl: str = "triton",
|
| 237 |
+
clip_qkv: Optional[float] = None,
|
| 238 |
+
qk_ln: bool = False,
|
| 239 |
+
softmax_scale: Optional[float] = None,
|
| 240 |
+
attn_pdrop: float = 0.0,
|
| 241 |
+
low_precision_layernorm: bool = False,
|
| 242 |
+
verbose: int = 0,
|
| 243 |
+
device: Optional[str] = None,
|
| 244 |
+
):
|
| 245 |
+
super().__init__()
|
| 246 |
+
self.attn_impl = attn_impl
|
| 247 |
+
self.clip_qkv = clip_qkv
|
| 248 |
+
self.qk_ln = qk_ln
|
| 249 |
+
self.d_model = d_model
|
| 250 |
+
self.n_heads = n_heads
|
| 251 |
+
self.softmax_scale = softmax_scale
|
| 252 |
+
if self.softmax_scale is None:
|
| 253 |
+
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
|
| 254 |
+
self.attn_dropout_p = attn_pdrop
|
| 255 |
+
self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
|
| 256 |
+
fuse_splits = (d_model, 2 * d_model)
|
| 257 |
+
self.Wqkv._fused = (0, fuse_splits)
|
| 258 |
+
if self.qk_ln:
|
| 259 |
+
layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
|
| 260 |
+
self.q_ln = layernorm_class(self.d_model, device=device)
|
| 261 |
+
self.k_ln = layernorm_class(self.d_model, device=device)
|
| 262 |
+
if self.attn_impl == "flash":
|
| 263 |
+
self.attn_fn = flash_attn_fn
|
| 264 |
+
elif self.attn_impl == "triton":
|
| 265 |
+
self.attn_fn = triton_flash_attn_fn
|
| 266 |
+
if verbose:
|
| 267 |
+
warnings.warn(
|
| 268 |
+
"While `attn_impl: triton` can be faster than `attn_impl: flash` "
|
| 269 |
+
+ "it uses more memory. When training larger models this can trigger "
|
| 270 |
+
+ "alloc retries which hurts performance. If encountered, we recommend "
|
| 271 |
+
+ "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
|
| 272 |
+
)
|
| 273 |
+
elif self.attn_impl == "torch":
|
| 274 |
+
self.attn_fn = scaled_multihead_dot_product_attention
|
| 275 |
+
if torch.cuda.is_available() and verbose:
|
| 276 |
+
warnings.warn(
|
| 277 |
+
"Using `attn_impl: torch`. If your model does not use `alibi` or "
|
| 278 |
+
+ "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
|
| 279 |
+
+ "we recommend using `attn_impl: triton`."
|
| 280 |
+
)
|
| 281 |
+
else:
|
| 282 |
+
raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
|
| 283 |
+
self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
|
| 284 |
+
self.out_proj._is_residual = True
|
| 285 |
+
|
| 286 |
+
def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
|
| 287 |
+
qkv = self.Wqkv(x)
|
| 288 |
+
if self.clip_qkv:
|
| 289 |
+
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
|
| 290 |
+
(query, key, value) = qkv.chunk(3, dim=2)
|
| 291 |
+
key_padding_mask = attention_mask
|
| 292 |
+
if self.qk_ln:
|
| 293 |
+
dtype = query.dtype
|
| 294 |
+
query = self.q_ln(query).to(dtype)
|
| 295 |
+
key = self.k_ln(key).to(dtype)
|
| 296 |
+
(context, attn_weights, past_key_value) = self.attn_fn(
|
| 297 |
+
query,
|
| 298 |
+
key,
|
| 299 |
+
value,
|
| 300 |
+
self.n_heads,
|
| 301 |
+
past_key_value=past_key_value,
|
| 302 |
+
softmax_scale=self.softmax_scale,
|
| 303 |
+
attn_bias=attn_bias,
|
| 304 |
+
key_padding_mask=key_padding_mask,
|
| 305 |
+
is_causal=is_causal,
|
| 306 |
+
dropout_p=self.attn_dropout_p,
|
| 307 |
+
training=self.training,
|
| 308 |
+
needs_weights=needs_weights,
|
| 309 |
+
)
|
| 310 |
+
return (self.out_proj(context), attn_weights, past_key_value)
|
| 311 |
+
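A sketch of instantiating the module with the dependency-free torch backend (sizes assumed):

import torch

mha = MultiheadAttention(d_model=64, n_heads=4, attn_impl="torch")
x = torch.randn(2, 5, 64)
out, attn_weights, past_kv = mha(x)
print(out.shape)   # torch.Size([2, 5, 64]); attn_weights is None unless needs_weights=True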
|
| 312 |
+
|
| 313 |
+
class MultiQueryAttention(nn.Module):
|
| 314 |
+
"""Multi-Query self attention.
|
| 315 |
+
|
| 316 |
+
Using the torch or triton attention implementation enables the user to also use
|
| 317 |
+
additive bias.
|
| 318 |
+
"""
|
| 319 |
+
|
| 320 |
+
def __init__(
|
| 321 |
+
self,
|
| 322 |
+
d_model: int,
|
| 323 |
+
n_heads: int,
|
| 324 |
+
attn_impl: str = "triton",
|
| 325 |
+
clip_qkv: Optional[float] = None,
|
| 326 |
+
qk_ln: bool = False,
|
| 327 |
+
softmax_scale: Optional[float] = None,
|
| 328 |
+
attn_pdrop: float = 0.0,
|
| 329 |
+
low_precision_layernorm: bool = False,
|
| 330 |
+
verbose: int = 0,
|
| 331 |
+
device: Optional[str] = None,
|
| 332 |
+
):
|
| 333 |
+
super().__init__()
|
| 334 |
+
self.attn_impl = attn_impl
|
| 335 |
+
self.clip_qkv = clip_qkv
|
| 336 |
+
self.qk_ln = qk_ln
|
| 337 |
+
self.d_model = d_model
|
| 338 |
+
self.n_heads = n_heads
|
| 339 |
+
self.head_dim = d_model // n_heads
|
| 340 |
+
self.softmax_scale = softmax_scale
|
| 341 |
+
if self.softmax_scale is None:
|
| 342 |
+
self.softmax_scale = 1 / math.sqrt(self.head_dim)
|
| 343 |
+
self.attn_dropout_p = attn_pdrop
|
| 344 |
+
self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
|
| 345 |
+
fuse_splits = (d_model, d_model + self.head_dim)
|
| 346 |
+
self.Wqkv._fused = (0, fuse_splits)
|
| 347 |
+
if self.qk_ln:
|
| 348 |
+
layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
|
| 349 |
+
self.q_ln = layernorm_class(d_model, device=device)
|
| 350 |
+
self.k_ln = layernorm_class(self.head_dim, device=device)
|
| 351 |
+
if self.attn_impl == "flash":
|
| 352 |
+
self.attn_fn = flash_attn_fn
|
| 353 |
+
elif self.attn_impl == "triton":
|
| 354 |
+
self.attn_fn = triton_flash_attn_fn
|
| 355 |
+
if verbose:
|
| 356 |
+
warnings.warn(
|
| 357 |
+
"While `attn_impl: triton` can be faster than `attn_impl: flash` "
|
| 358 |
+
+ "it uses more memory. When training larger models this can trigger "
|
| 359 |
+
+ "alloc retries which hurts performance. If encountered, we recommend "
|
| 360 |
+
+ "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
|
| 361 |
+
)
|
| 362 |
+
elif self.attn_impl == "torch":
|
| 363 |
+
self.attn_fn = scaled_multihead_dot_product_attention
|
| 364 |
+
if torch.cuda.is_available() and verbose:
|
| 365 |
+
warnings.warn(
|
| 366 |
+
"Using `attn_impl: torch`. If your model does not use `alibi` or "
|
| 367 |
+
+ "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
|
| 368 |
+
+ "we recommend using `attn_impl: triton`."
|
| 369 |
+
)
|
| 370 |
+
else:
|
| 371 |
+
raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
|
| 372 |
+
self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
|
| 373 |
+
self.out_proj._is_residual = True
|
| 374 |
+
|
| 375 |
+
def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
|
| 376 |
+
qkv = self.Wqkv(x)
|
| 377 |
+
if self.clip_qkv:
|
| 378 |
+
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
|
| 379 |
+
(query, key, value) = qkv.split([self.d_model, self.head_dim, self.head_dim], dim=2)
|
| 380 |
+
key_padding_mask = attention_mask
|
| 381 |
+
if self.qk_ln:
|
| 382 |
+
dtype = query.dtype
|
| 383 |
+
query = self.q_ln(query).to(dtype)
|
| 384 |
+
key = self.k_ln(key).to(dtype)
|
| 385 |
+
(context, attn_weights, past_key_value) = self.attn_fn(
|
| 386 |
+
query,
|
| 387 |
+
key,
|
| 388 |
+
value,
|
| 389 |
+
self.n_heads,
|
| 390 |
+
past_key_value=past_key_value,
|
| 391 |
+
softmax_scale=self.softmax_scale,
|
| 392 |
+
attn_bias=attn_bias,
|
| 393 |
+
key_padding_mask=key_padding_mask,
|
| 394 |
+
is_causal=is_causal,
|
| 395 |
+
dropout_p=self.attn_dropout_p,
|
| 396 |
+
training=self.training,
|
| 397 |
+
needs_weights=needs_weights,
|
| 398 |
+
multiquery=True,
|
| 399 |
+
)
|
| 400 |
+
return (self.out_proj(context), attn_weights, past_key_value)
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
|
| 404 |
+
if attn_impl == "flash":
|
| 405 |
+
return None
|
| 406 |
+
elif attn_impl in ["torch", "triton"]:
|
| 407 |
+
if alibi:
|
| 408 |
+
if (prefix_lm or not causal) or use_sequence_id:
|
| 409 |
+
return (1, n_heads, seq_len, seq_len)
|
| 410 |
+
return (1, n_heads, 1, seq_len)
|
| 411 |
+
elif prefix_lm or use_sequence_id:
|
| 412 |
+
return (1, 1, seq_len, seq_len)
|
| 413 |
+
return None
|
| 414 |
+
else:
|
| 415 |
+
raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8):
|
| 419 |
+
if attn_impl == "flash":
|
| 420 |
+
return None
|
| 421 |
+
elif attn_impl in ["torch", "triton"]:
|
| 422 |
+
if alibi:
|
| 423 |
+
(device, dtype) = (attn_bias.device, attn_bias.dtype)
|
| 424 |
+
attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype))
|
| 425 |
+
return attn_bias
|
| 426 |
+
else:
|
| 427 |
+
raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def gen_slopes(n_heads, alibi_bias_max=8, device=None):
|
| 431 |
+
_n_heads = 2 ** math.ceil(math.log2(n_heads))
|
| 432 |
+
m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
|
| 433 |
+
m = m.mul(alibi_bias_max / _n_heads)
|
| 434 |
+
slopes = 1.0 / torch.pow(2, m)
|
| 435 |
+
if _n_heads != n_heads:
|
| 436 |
+
slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
|
| 437 |
+
return slopes.view(1, n_heads, 1, 1)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None):
|
| 441 |
+
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len)
|
| 442 |
+
if full:
|
| 443 |
+
alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, seq_len, 1)
|
| 444 |
+
alibi_bias = alibi_bias.abs().mul(-1)
|
| 445 |
+
slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
|
| 446 |
+
alibi_bias = alibi_bias * slopes
|
| 447 |
+
return alibi_bias.to(dtype=dtype)
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
ATTN_CLASS_REGISTRY = {"multihead_attention": MultiheadAttention, "multiquery_attention": MultiQueryAttention}
|
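A minimal usage sketch (not part of the commit) of how the registry, `attn_bias_shape`, and `build_attn_bias` compose. It assumes the `MultiheadAttention` constructor defined earlier in this file and uses `attn_impl: torch` so it runs without CUDA or triton; all dimensions are arbitrary:

import torch

# Pick the attention class by name, mirroring how MPTBlock does it.
attn = ATTN_CLASS_REGISTRY["multihead_attention"](d_model=64, n_heads=4, attn_impl="torch")
x = torch.randn(2, 8, 64)  # (batch, seq_len, d_model)

# For alibi + causal attention the bias collapses to (1, n_heads, 1, seq_len).
shape = attn_bias_shape("torch", n_heads=4, seq_len=8, alibi=True,
                        prefix_lm=False, causal=True, use_sequence_id=False)
bias = build_attn_bias("torch", torch.zeros(shape), n_heads=4, seq_len=8,
                       causal=True, alibi=True)
out, attn_weights, past_kv = attn(x, attn_bias=bias, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 64])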
mllm/flamingo/mpt/blocks.py
ADDED
@@ -0,0 +1,82 @@
"""GPT Blocks used for the GPT Model."""
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
from .attention import ATTN_CLASS_REGISTRY
from .norm import NORM_CLASS_REGISTRY


class MPTMLP(nn.Module):
    def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str] = None):
        super().__init__()
        self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
        ## yh: hard code
        # self.act = nn.GELU(approximate='none')
        self.act = nn.GELU()
        self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
        self.down_proj._is_residual = True

    def forward(self, x):
        return self.down_proj(self.act(self.up_proj(x)))


class MPTBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expansion_ratio: int,
        attn_config: Dict = {
            "attn_type": "multihead_attention",
            "attn_pdrop": 0.0,
            "attn_impl": "triton",
            "qk_ln": False,
            "clip_qkv": None,
            "softmax_scale": None,
            "prefix_lm": False,
            "attn_uses_sequence_id": False,
            "alibi": False,
            "alibi_bias_max": 8,
        },
        resid_pdrop: float = 0.0,
        norm_type: str = "low_precision_layernorm",
        verbose: int = 0,
        device: Optional[str] = None,
        **kwargs
    ):
        del kwargs
        super().__init__()
        norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
        attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
        self.norm_1 = norm_class(d_model, device=device)
        self.attn = attn_class(
            attn_impl=attn_config["attn_impl"],
            clip_qkv=attn_config["clip_qkv"],
            qk_ln=attn_config["qk_ln"],
            softmax_scale=attn_config["softmax_scale"],
            attn_pdrop=attn_config["attn_pdrop"],
            d_model=d_model,
            n_heads=n_heads,
            verbose=verbose,
            device=device,
        )
        self.norm_2 = norm_class(d_model, device=device)
        self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_bias: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        is_causal: bool = True,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
        a = self.norm_1(x)
        (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
        x = x + self.resid_attn_dropout(b)
        m = self.norm_2(x)
        n = self.ffn(m)
        x = x + self.resid_ffn_dropout(n)
        return (x, attn_weights, past_key_value)
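A minimal sketch (not part of the commit) of driving one `MPTBlock` on CPU. The dict mirrors the `attn_config` default above with `attn_impl` switched to `torch`, and all sizes are arbitrary:

import torch

block = MPTBlock(
    d_model=64,
    n_heads=4,
    expansion_ratio=4,
    attn_config={
        "attn_type": "multihead_attention", "attn_pdrop": 0.0, "attn_impl": "torch",
        "qk_ln": False, "clip_qkv": None, "softmax_scale": None, "prefix_lm": False,
        "attn_uses_sequence_id": False, "alibi": False, "alibi_bias_max": 8,
    },
)
x = torch.randn(2, 8, 64)
y, attn_weights, past_kv = block(x, is_causal=True)
assert y.shape == x.shape  # pre-norm residual block preserves the shape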
mllm/flamingo/mpt/configuration_mpt.py
ADDED
@@ -0,0 +1,161 @@
"""A HuggingFace-style model configuration."""
from typing import Dict, Optional, Union
from transformers import PretrainedConfig

attn_config_defaults: Dict = {
    "attn_type": "multihead_attention",
    "attn_pdrop": 0.0,
    "attn_impl": "triton",
    "qk_ln": False,
    "clip_qkv": None,
    "softmax_scale": None,
    "prefix_lm": False,
    "attn_uses_sequence_id": False,
    "alibi": False,
    "alibi_bias_max": 8,
}
init_config_defaults: Dict = {
    "name": "kaiming_normal_",
    "fan_mode": "fan_in",
    "init_nonlinearity": "relu",
    "init_div_is_residual": True,
    "emb_init_std": None,
    "emb_init_uniform_lim": None,
    "init_std": None,
    "init_gain": 0.0,
}


class MPTConfig(PretrainedConfig):
    model_type = "mpt"

    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 16,
        n_layers: int = 24,
        expansion_ratio: int = 4,
        max_seq_len: int = 2048,
        vocab_size: int = 50368,
        resid_pdrop: float = 0.0,
        emb_pdrop: float = 0.0,
        learned_pos_emb: bool = True,
        attn_config: Dict = attn_config_defaults,
        init_device: str = "cpu",
        logit_scale: Optional[Union[float, str]] = None,
        no_bias: bool = False,
        verbose: int = 0,
        embedding_fraction: float = 1.0,
        norm_type: str = "low_precision_layernorm",
        use_cache: bool = False,
        init_config: Dict = init_config_defaults,
        **kwargs,
    ):
        """The MPT configuration class.

        Args:
            d_model (int): The size of the embedding dimension of the model.
            n_heads (int): The number of attention heads.
            n_layers (int): The number of layers in the model.
            expansion_ratio (int): The ratio of the up/down scale in the MLP.
            max_seq_len (int): The maximum sequence length of the model.
            vocab_size (int): The size of the vocabulary.
            resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
            emb_pdrop (float): The dropout probability for the embedding layer.
            learned_pos_emb (bool): Whether to use learned positional embeddings.
            attn_config (Dict): A dictionary used to configure the model's attention module:
                attn_type (str): The type of attention to use. Options: multihead_attention, multiquery_attention.
                attn_pdrop (float): The dropout probability for the attention layers.
                attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
                qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
                clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
                    this value.
                softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
                    use the default scale of ``1/sqrt(d_keys)``.
                prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
                    extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
                    can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
                attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
                    When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
                    which sub-sequence each token belongs to.
                    Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
                alibi (bool): Whether to use the alibi bias instead of position embeddings.
                alibi_bias_max (int): The maximum value of the alibi bias.
            init_device (str): The device to use for parameter initialization.
            logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
            no_bias (bool): Whether to remove the bias terms from all linear layers.
            verbose (int): The verbosity level. 0 is silent.
            embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
            norm_type (str): The type of norm layer to use.
            multiquery_attention (bool): Whether to use the multiquery attention implementation.
            use_cache (bool): Whether or not the model should return the last key/value attentions.
            init_config (Dict): A dictionary used to configure the model initialization:
                init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
                    'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
                    'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
                init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
                emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
                emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
                    used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
                init_std (float): The standard deviation of the normal distribution used to initialize the model,
                    if using the baseline_ parameter initialization scheme.
                init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
                fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
                init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
            ---
            See llmfoundry.models.utils.param_init_fns.py for info on other param init config options.
        """
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.expansion_ratio = expansion_ratio
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.learned_pos_emb = learned_pos_emb
        self.attn_config = attn_config
        self.init_device = init_device
        self.logit_scale = logit_scale
        self.no_bias = no_bias
        self.verbose = verbose
        self.embedding_fraction = embedding_fraction
        self.norm_type = norm_type
        self.use_cache = use_cache
        self.init_config = init_config
        if "name" in kwargs:
            del kwargs["name"]
        if "loss_fn" in kwargs:
            del kwargs["loss_fn"]
        super().__init__(**kwargs)
        self._validate_config()

    def _set_config_defaults(self, config, config_defaults):
        for k, v in config_defaults.items():
            if k not in config:
                config[k] = v
        return config

    def _validate_config(self):
        self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
        self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
        if self.d_model % self.n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        if any(prob < 0 or prob > 1 for prob in [self.attn_config["attn_pdrop"], self.resid_pdrop, self.emb_pdrop]):
            raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
        if self.attn_config["attn_impl"] not in ["torch", "flash", "triton"]:
            raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
        if self.attn_config["prefix_lm"] and self.attn_config["attn_impl"] not in ["torch", "triton"]:
            raise NotImplementedError("prefix_lm only implemented with torch and triton attention.")
        if self.attn_config["alibi"] and self.attn_config["attn_impl"] not in ["torch", "triton"]:
            raise NotImplementedError("alibi only implemented with torch and triton attention.")
        if self.attn_config["attn_uses_sequence_id"] and self.attn_config["attn_impl"] not in ["torch", "triton"]:
            raise NotImplementedError("attn_uses_sequence_id only implemented with torch and triton attention.")
        if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
            raise ValueError("model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!")
        if isinstance(self.logit_scale, str) and self.logit_scale != "inv_sqrt_d_model":
            raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
        if self.init_config.get("name", None) is None:
            raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
        if not self.learned_pos_emb and (not self.attn_config["alibi"]):
            raise ValueError("Positional information must be provided to the model using either learned_pos_emb or alibi.")
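A short sketch (not part of the commit) of how `MPTConfig` and `_validate_config` behave; the sizes are arbitrary:

config = MPTConfig(d_model=256, n_heads=8, n_layers=4, max_seq_len=512)
print(config.attn_config["attn_impl"])  # 'triton' (the default)

# alibi is only implemented for the torch/triton paths, so this raises:
try:
    MPTConfig(attn_config={**attn_config_defaults, "alibi": True, "attn_impl": "flash"})
except NotImplementedError as e:
    print(e)  # alibi only implemented with torch and triton attention.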
mllm/flamingo/mpt/custom_embedding.py
ADDED
@@ -0,0 +1,11 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class SharedEmbedding(nn.Embedding):
    def forward(self, input: Tensor, unembed: bool = False) -> Tensor:
        if unembed:
            return F.linear(input, self.weight)
        return super().forward(input)
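A short sketch (not part of the commit) of the weight-tying pattern `SharedEmbedding` enables: the same matrix embeds token ids and, with `unembed=True`, projects hidden states back to vocabulary logits:

import torch

emb = SharedEmbedding(num_embeddings=100, embedding_dim=16)
tokens = torch.randint(0, 100, (2, 5))
h = emb(tokens)                # (2, 5, 16): ordinary embedding lookup
logits = emb(h, unembed=True)  # (2, 5, 100): h @ weight.T via F.linear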
mllm/flamingo/mpt/flash_attn_triton.py
ADDED
@@ -0,0 +1,841 @@
| 1 |
+
"""
|
| 2 |
+
Copied from https://github.com/HazyResearch/flash-attention/blob/eff9fe6b8076df59d64d7a3f464696738a3c7c24/flash_attn/flash_attn_triton.py
|
| 3 |
+
update imports to use 'triton_pre_mlir'
|
| 4 |
+
|
| 5 |
+
*Experimental* implementation of FlashAttention in Triton.
|
| 6 |
+
Tested with triton==2.0.0.dev20221202.
|
| 7 |
+
Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions
|
| 8 |
+
other than 64:
|
| 9 |
+
https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
|
| 10 |
+
We'll update this implementation with the new Triton backend once this is fixed.
|
| 11 |
+
|
| 12 |
+
We use the FlashAttention implementation from Phil Tillet a starting point.
|
| 13 |
+
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
|
| 14 |
+
|
| 15 |
+
Changes:
|
| 16 |
+
- Implement both causal and non-causal attention.
|
| 17 |
+
- Implement both self-attention and cross-attention.
|
| 18 |
+
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
|
| 19 |
+
- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
|
| 20 |
+
- Support attention bias.
|
| 21 |
+
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
|
| 22 |
+
- Make the backward for d=128 much faster by reducing register spilling.
|
| 23 |
+
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
|
| 24 |
+
small batch size * nheads.
|
| 25 |
+
|
| 26 |
+
Caution:
|
| 27 |
+
- This is an *experimental* implementation. The forward pass should be quite robust but
|
| 28 |
+
I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
|
| 29 |
+
- This implementation has only been tested on A100.
|
| 30 |
+
- If you plan to use headdim other than 64 and 128, you should test for race conditions
|
| 31 |
+
(due to the Triton compiler), as done in tests/test_flash_attn.py
|
| 32 |
+
"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
|
| 33 |
+
for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
|
| 34 |
+
that there are none left for other head dimensions.
|
| 35 |
+
|
| 36 |
+
Differences between this Triton version and the CUDA version:
|
| 37 |
+
- Triton version doesn't support dropout.
|
| 38 |
+
- Triton forward is generally faster than CUDA forward, while Triton backward is
|
| 39 |
+
generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
|
| 40 |
+
than CUDA forward + backward.
|
| 41 |
+
- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
|
| 42 |
+
- Triton version supports attention bias, while CUDA version doesn't.
|
| 43 |
+
"""
|
| 44 |
+
import math
|
| 45 |
+
import torch
|
| 46 |
+
import triton_pre_mlir as triton
|
| 47 |
+
import triton_pre_mlir.language as tl
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@triton.heuristics(
|
| 51 |
+
{
|
| 52 |
+
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
|
| 53 |
+
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
|
| 54 |
+
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
|
| 55 |
+
}
|
| 56 |
+
)
|
| 57 |
+
@triton.jit
|
| 58 |
+
def _fwd_kernel(
|
| 59 |
+
Q,
|
| 60 |
+
K,
|
| 61 |
+
V,
|
| 62 |
+
Bias,
|
| 63 |
+
Out,
|
| 64 |
+
Lse,
|
| 65 |
+
TMP,
|
| 66 |
+
softmax_scale,
|
| 67 |
+
stride_qb,
|
| 68 |
+
stride_qh,
|
| 69 |
+
stride_qm,
|
| 70 |
+
stride_kb,
|
| 71 |
+
stride_kh,
|
| 72 |
+
stride_kn,
|
| 73 |
+
stride_vb,
|
| 74 |
+
stride_vh,
|
| 75 |
+
stride_vn,
|
| 76 |
+
stride_bb,
|
| 77 |
+
stride_bh,
|
| 78 |
+
stride_bm,
|
| 79 |
+
stride_ob,
|
| 80 |
+
stride_oh,
|
| 81 |
+
stride_om,
|
| 82 |
+
nheads,
|
| 83 |
+
seqlen_q,
|
| 84 |
+
seqlen_k,
|
| 85 |
+
seqlen_q_rounded,
|
| 86 |
+
headdim,
|
| 87 |
+
CACHE_KEY_SEQLEN_Q,
|
| 88 |
+
CACHE_KEY_SEQLEN_K,
|
| 89 |
+
BIAS_TYPE: tl.constexpr,
|
| 90 |
+
IS_CAUSAL: tl.constexpr,
|
| 91 |
+
BLOCK_HEADDIM: tl.constexpr,
|
| 92 |
+
EVEN_M: tl.constexpr,
|
| 93 |
+
EVEN_N: tl.constexpr,
|
| 94 |
+
EVEN_HEADDIM: tl.constexpr,
|
| 95 |
+
BLOCK_M: tl.constexpr,
|
| 96 |
+
BLOCK_N: tl.constexpr,
|
| 97 |
+
):
|
| 98 |
+
start_m = tl.program_id(0)
|
| 99 |
+
off_hb = tl.program_id(1)
|
| 100 |
+
off_b = off_hb // nheads
|
| 101 |
+
off_h = off_hb % nheads
|
| 102 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 103 |
+
offs_n = tl.arange(0, BLOCK_N)
|
| 104 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 105 |
+
q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
|
| 106 |
+
k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
|
| 107 |
+
v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
|
| 108 |
+
if BIAS_TYPE == "vector":
|
| 109 |
+
b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
|
| 110 |
+
elif BIAS_TYPE == "matrix":
|
| 111 |
+
b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
|
| 112 |
+
t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
|
| 113 |
+
lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
| 114 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
| 115 |
+
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
|
| 116 |
+
if EVEN_M & EVEN_N:
|
| 117 |
+
if EVEN_HEADDIM:
|
| 118 |
+
q = tl.load(q_ptrs)
|
| 119 |
+
else:
|
| 120 |
+
q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
|
| 121 |
+
elif EVEN_HEADDIM:
|
| 122 |
+
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
|
| 123 |
+
else:
|
| 124 |
+
q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
|
| 125 |
+
end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
|
| 126 |
+
for start_n in range(0, end_n, BLOCK_N):
|
| 127 |
+
start_n = tl.multiple_of(start_n, BLOCK_N)
|
| 128 |
+
if EVEN_N & EVEN_M:
|
| 129 |
+
if EVEN_HEADDIM:
|
| 130 |
+
k = tl.load(k_ptrs + start_n * stride_kn)
|
| 131 |
+
else:
|
| 132 |
+
k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
|
| 133 |
+
elif EVEN_HEADDIM:
|
| 134 |
+
k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
|
| 135 |
+
else:
|
| 136 |
+
k = tl.load(k_ptrs + start_n * stride_kn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
|
| 137 |
+
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
| 138 |
+
qk += tl.dot(q, k, trans_b=True)
|
| 139 |
+
if not EVEN_N:
|
| 140 |
+
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
|
| 141 |
+
if IS_CAUSAL:
|
| 142 |
+
qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
|
| 143 |
+
if BIAS_TYPE != "none":
|
| 144 |
+
if BIAS_TYPE == "vector":
|
| 145 |
+
if EVEN_N:
|
| 146 |
+
bias = tl.load(b_ptrs + start_n).to(tl.float32)
|
| 147 |
+
else:
|
| 148 |
+
bias = tl.load(b_ptrs + start_n, mask=start_n + offs_n < seqlen_k, other=0.0).to(tl.float32)
|
| 149 |
+
bias = bias[None, :]
|
| 150 |
+
elif BIAS_TYPE == "matrix":
|
| 151 |
+
if EVEN_M & EVEN_N:
|
| 152 |
+
bias = tl.load(b_ptrs + start_n).to(tl.float32)
|
| 153 |
+
else:
|
| 154 |
+
bias = tl.load(b_ptrs + start_n, mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), other=0.0).to(tl.float32)
|
| 155 |
+
qk = qk * softmax_scale + bias
|
| 156 |
+
m_ij = tl.maximum(tl.max(qk, 1), lse_i)
|
| 157 |
+
p = tl.exp(qk - m_ij[:, None])
|
| 158 |
+
else:
|
| 159 |
+
m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
|
| 160 |
+
p = tl.exp(qk * softmax_scale - m_ij[:, None])
|
| 161 |
+
l_ij = tl.sum(p, 1)
|
| 162 |
+
acc_o_scale = tl.exp(m_i - m_ij)
|
| 163 |
+
tl.store(t_ptrs, acc_o_scale)
|
| 164 |
+
acc_o_scale = tl.load(t_ptrs)
|
| 165 |
+
acc_o = acc_o * acc_o_scale[:, None]
|
| 166 |
+
if EVEN_N & EVEN_M:
|
| 167 |
+
if EVEN_HEADDIM:
|
| 168 |
+
v = tl.load(v_ptrs + start_n * stride_vn)
|
| 169 |
+
else:
|
| 170 |
+
v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
|
| 171 |
+
elif EVEN_HEADDIM:
|
| 172 |
+
v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
|
| 173 |
+
else:
|
| 174 |
+
v = tl.load(v_ptrs + start_n * stride_vn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
|
| 175 |
+
p = p.to(v.dtype)
|
| 176 |
+
acc_o += tl.dot(p, v)
|
| 177 |
+
m_i = m_ij
|
| 178 |
+
l_i_new = tl.exp(lse_i - m_ij) + l_ij
|
| 179 |
+
lse_i = m_ij + tl.log(l_i_new)
|
| 180 |
+
o_scale = tl.exp(m_i - lse_i)
|
| 181 |
+
tl.store(t_ptrs, o_scale)
|
| 182 |
+
o_scale = tl.load(t_ptrs)
|
| 183 |
+
acc_o = acc_o * o_scale[:, None]
|
| 184 |
+
start_m = tl.program_id(0)
|
| 185 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 186 |
+
lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
|
| 187 |
+
tl.store(lse_ptrs, lse_i)
|
| 188 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 189 |
+
out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])
|
| 190 |
+
if EVEN_M:
|
| 191 |
+
if EVEN_HEADDIM:
|
| 192 |
+
tl.store(out_ptrs, acc_o)
|
| 193 |
+
else:
|
| 194 |
+
tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
|
| 195 |
+
elif EVEN_HEADDIM:
|
| 196 |
+
tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
|
| 197 |
+
else:
|
| 198 |
+
tl.store(out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
@triton.jit
|
| 202 |
+
def _bwd_preprocess_do_o_dot(
|
| 203 |
+
Out,
|
| 204 |
+
DO,
|
| 205 |
+
Delta,
|
| 206 |
+
stride_ob,
|
| 207 |
+
stride_oh,
|
| 208 |
+
stride_om,
|
| 209 |
+
stride_dob,
|
| 210 |
+
stride_doh,
|
| 211 |
+
stride_dom,
|
| 212 |
+
nheads,
|
| 213 |
+
seqlen_q,
|
| 214 |
+
seqlen_q_rounded,
|
| 215 |
+
headdim,
|
| 216 |
+
BLOCK_M: tl.constexpr,
|
| 217 |
+
BLOCK_HEADDIM: tl.constexpr,
|
| 218 |
+
):
|
| 219 |
+
start_m = tl.program_id(0)
|
| 220 |
+
off_hb = tl.program_id(1)
|
| 221 |
+
off_b = off_hb // nheads
|
| 222 |
+
off_h = off_hb % nheads
|
| 223 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 224 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 225 |
+
o = tl.load(
|
| 226 |
+
Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
|
| 227 |
+
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
| 228 |
+
other=0.0,
|
| 229 |
+
).to(tl.float32)
|
| 230 |
+
do = tl.load(
|
| 231 |
+
DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :],
|
| 232 |
+
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
| 233 |
+
other=0.0,
|
| 234 |
+
).to(tl.float32)
|
| 235 |
+
delta = tl.sum(o * do, axis=1)
|
| 236 |
+
tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
@triton.jit
|
| 240 |
+
def _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr):
|
| 241 |
+
if EVEN_N & EVEN_M:
|
| 242 |
+
if EVEN_HEADDIM:
|
| 243 |
+
tl.store(dv_ptrs, dv)
|
| 244 |
+
tl.store(dk_ptrs, dk)
|
| 245 |
+
else:
|
| 246 |
+
tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
|
| 247 |
+
tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
|
| 248 |
+
elif EVEN_HEADDIM:
|
| 249 |
+
tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
|
| 250 |
+
tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
|
| 251 |
+
else:
|
| 252 |
+
tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
|
| 253 |
+
tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
@triton.jit
|
| 257 |
+
def _bwd_kernel_one_col_block(
|
| 258 |
+
start_n,
|
| 259 |
+
Q,
|
| 260 |
+
K,
|
| 261 |
+
V,
|
| 262 |
+
Bias,
|
| 263 |
+
DO,
|
| 264 |
+
DQ,
|
| 265 |
+
DK,
|
| 266 |
+
DV,
|
| 267 |
+
LSE,
|
| 268 |
+
D,
|
| 269 |
+
softmax_scale,
|
| 270 |
+
stride_qm,
|
| 271 |
+
stride_kn,
|
| 272 |
+
stride_vn,
|
| 273 |
+
stride_bm,
|
| 274 |
+
stride_dom,
|
| 275 |
+
stride_dqm,
|
| 276 |
+
stride_dkn,
|
| 277 |
+
stride_dvn,
|
| 278 |
+
seqlen_q,
|
| 279 |
+
seqlen_k,
|
| 280 |
+
headdim,
|
| 281 |
+
ATOMIC_ADD: tl.constexpr,
|
| 282 |
+
BIAS_TYPE: tl.constexpr,
|
| 283 |
+
IS_CAUSAL: tl.constexpr,
|
| 284 |
+
BLOCK_HEADDIM: tl.constexpr,
|
| 285 |
+
EVEN_M: tl.constexpr,
|
| 286 |
+
EVEN_N: tl.constexpr,
|
| 287 |
+
EVEN_HEADDIM: tl.constexpr,
|
| 288 |
+
BLOCK_M: tl.constexpr,
|
| 289 |
+
BLOCK_N: tl.constexpr,
|
| 290 |
+
):
|
| 291 |
+
begin_m = 0 if not IS_CAUSAL else start_n * BLOCK_N // BLOCK_M * BLOCK_M
|
| 292 |
+
offs_qm = begin_m + tl.arange(0, BLOCK_M)
|
| 293 |
+
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 294 |
+
offs_m = tl.arange(0, BLOCK_M)
|
| 295 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 296 |
+
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
|
| 297 |
+
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
|
| 298 |
+
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
|
| 299 |
+
do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
|
| 300 |
+
dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
|
| 301 |
+
if BIAS_TYPE == "vector":
|
| 302 |
+
b_ptrs = Bias + offs_n
|
| 303 |
+
elif BIAS_TYPE == "matrix":
|
| 304 |
+
b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
|
| 305 |
+
dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
|
| 306 |
+
dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
|
| 307 |
+
if begin_m >= seqlen_q:
|
| 308 |
+
dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
|
| 309 |
+
dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
|
| 310 |
+
_bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
|
| 311 |
+
return
|
| 312 |
+
if EVEN_N & EVEN_M:
|
| 313 |
+
if EVEN_HEADDIM:
|
| 314 |
+
k = tl.load(k_ptrs)
|
| 315 |
+
v = tl.load(v_ptrs)
|
| 316 |
+
else:
|
| 317 |
+
k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
|
| 318 |
+
v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
|
| 319 |
+
elif EVEN_HEADDIM:
|
| 320 |
+
k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
|
| 321 |
+
v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
|
| 322 |
+
else:
|
| 323 |
+
k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
|
| 324 |
+
v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
|
| 325 |
+
num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
|
| 326 |
+
for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
|
| 327 |
+
start_m = tl.multiple_of(start_m, BLOCK_M)
|
| 328 |
+
offs_m_curr = start_m + offs_m
|
| 329 |
+
if EVEN_M & EVEN_HEADDIM:
|
| 330 |
+
q = tl.load(q_ptrs)
|
| 331 |
+
elif EVEN_HEADDIM:
|
| 332 |
+
q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
|
| 333 |
+
else:
|
| 334 |
+
q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
|
| 335 |
+
qk = tl.dot(q, k, trans_b=True)
|
| 336 |
+
if not EVEN_N:
|
| 337 |
+
qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
|
| 338 |
+
if IS_CAUSAL:
|
| 339 |
+
qk = tl.where(offs_m_curr[:, None] >= offs_n[None, :], qk, float("-inf"))
|
| 340 |
+
if BIAS_TYPE != "none":
|
| 341 |
+
tl.debug_barrier()
|
| 342 |
+
if BIAS_TYPE == "vector":
|
| 343 |
+
if EVEN_N:
|
| 344 |
+
bias = tl.load(b_ptrs).to(tl.float32)
|
| 345 |
+
else:
|
| 346 |
+
bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
|
| 347 |
+
bias = bias[None, :]
|
| 348 |
+
elif BIAS_TYPE == "matrix":
|
| 349 |
+
if EVEN_M & EVEN_N:
|
| 350 |
+
bias = tl.load(b_ptrs).to(tl.float32)
|
| 351 |
+
else:
|
| 352 |
+
bias = tl.load(b_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), other=0.0).to(tl.float32)
|
| 353 |
+
qk = qk * softmax_scale + bias
|
| 354 |
+
if not EVEN_M & EVEN_HEADDIM:
|
| 355 |
+
tl.debug_barrier()
|
| 356 |
+
lse_i = tl.load(LSE + offs_m_curr)
|
| 357 |
+
if BIAS_TYPE == "none":
|
| 358 |
+
p = tl.exp(qk * softmax_scale - lse_i[:, None])
|
| 359 |
+
else:
|
| 360 |
+
p = tl.exp(qk - lse_i[:, None])
|
| 361 |
+
if EVEN_M & EVEN_HEADDIM:
|
| 362 |
+
do = tl.load(do_ptrs)
|
| 363 |
+
else:
|
| 364 |
+
do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
|
| 365 |
+
dv += tl.dot(p.to(do.dtype), do, trans_a=True)
|
| 366 |
+
if not EVEN_M & EVEN_HEADDIM:
|
| 367 |
+
tl.debug_barrier()
|
| 368 |
+
dp = tl.dot(do, v, trans_b=True)
|
| 369 |
+
if not EVEN_HEADDIM:
|
| 370 |
+
tl.debug_barrier()
|
| 371 |
+
Di = tl.load(D + offs_m_curr)
|
| 372 |
+
ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
|
| 373 |
+
dk += tl.dot(ds, q, trans_a=True)
|
| 374 |
+
if not EVEN_M & EVEN_HEADDIM:
|
| 375 |
+
tl.debug_barrier()
|
| 376 |
+
if not ATOMIC_ADD:
|
| 377 |
+
if EVEN_M & EVEN_HEADDIM:
|
| 378 |
+
dq = tl.load(dq_ptrs, eviction_policy="evict_last")
|
| 379 |
+
dq += tl.dot(ds, k)
|
| 380 |
+
tl.store(dq_ptrs, dq, eviction_policy="evict_last")
|
| 381 |
+
elif EVEN_HEADDIM:
|
| 382 |
+
dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0, eviction_policy="evict_last")
|
| 383 |
+
dq += tl.dot(ds, k)
|
| 384 |
+
tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q, eviction_policy="evict_last")
|
| 385 |
+
else:
|
| 386 |
+
dq = tl.load(dq_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0, eviction_policy="evict_last")
|
| 387 |
+
dq += tl.dot(ds, k)
|
| 388 |
+
tl.store(dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), eviction_policy="evict_last")
|
| 389 |
+
else:
|
| 390 |
+
dq = tl.dot(ds, k)
|
| 391 |
+
if EVEN_M & EVEN_HEADDIM:
|
| 392 |
+
tl.atomic_add(dq_ptrs, dq)
|
| 393 |
+
elif EVEN_HEADDIM:
|
| 394 |
+
tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
|
| 395 |
+
else:
|
| 396 |
+
tl.atomic_add(dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
|
| 397 |
+
dq_ptrs += BLOCK_M * stride_dqm
|
| 398 |
+
q_ptrs += BLOCK_M * stride_qm
|
| 399 |
+
do_ptrs += BLOCK_M * stride_dom
|
| 400 |
+
if BIAS_TYPE == "matrix":
|
| 401 |
+
b_ptrs += BLOCK_M * stride_bm
|
| 402 |
+
dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
|
| 403 |
+
dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
|
| 404 |
+
_bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def init_to_zero(name):
|
| 408 |
+
return lambda nargs: nargs[name].zero_()
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
@triton.autotune(
|
| 412 |
+
configs=[
|
| 413 |
+
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero("DQ")),
|
| 414 |
+
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero("DQ")),
|
| 415 |
+
],
|
| 416 |
+
key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "IS_CAUSAL", "BLOCK_HEADDIM"],
|
| 417 |
+
)
|
| 418 |
+
@triton.heuristics(
|
| 419 |
+
{
|
| 420 |
+
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
|
| 421 |
+
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
|
| 422 |
+
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
|
| 423 |
+
}
|
| 424 |
+
)
|
| 425 |
+
@triton.jit
|
| 426 |
+
def _bwd_kernel(
|
| 427 |
+
Q,
|
| 428 |
+
K,
|
| 429 |
+
V,
|
| 430 |
+
Bias,
|
| 431 |
+
DO,
|
| 432 |
+
DQ,
|
| 433 |
+
DK,
|
| 434 |
+
DV,
|
| 435 |
+
LSE,
|
| 436 |
+
D,
|
| 437 |
+
softmax_scale,
|
| 438 |
+
stride_qb,
|
| 439 |
+
stride_qh,
|
| 440 |
+
stride_qm,
|
| 441 |
+
stride_kb,
|
| 442 |
+
stride_kh,
|
| 443 |
+
stride_kn,
|
| 444 |
+
stride_vb,
|
| 445 |
+
stride_vh,
|
| 446 |
+
stride_vn,
|
| 447 |
+
stride_bb,
|
| 448 |
+
stride_bh,
|
| 449 |
+
stride_bm,
|
| 450 |
+
stride_dob,
|
| 451 |
+
stride_doh,
|
| 452 |
+
stride_dom,
|
| 453 |
+
stride_dqb,
|
| 454 |
+
stride_dqh,
|
| 455 |
+
stride_dqm,
|
| 456 |
+
stride_dkb,
|
| 457 |
+
stride_dkh,
|
| 458 |
+
stride_dkn,
|
| 459 |
+
stride_dvb,
|
| 460 |
+
stride_dvh,
|
| 461 |
+
stride_dvn,
|
| 462 |
+
nheads,
|
| 463 |
+
seqlen_q,
|
| 464 |
+
seqlen_k,
|
| 465 |
+
seqlen_q_rounded,
|
| 466 |
+
headdim,
|
| 467 |
+
CACHE_KEY_SEQLEN_Q,
|
| 468 |
+
CACHE_KEY_SEQLEN_K,
|
| 469 |
+
BIAS_TYPE: tl.constexpr,
|
| 470 |
+
IS_CAUSAL: tl.constexpr,
|
| 471 |
+
BLOCK_HEADDIM: tl.constexpr,
|
| 472 |
+
SEQUENCE_PARALLEL: tl.constexpr,
|
| 473 |
+
EVEN_M: tl.constexpr,
|
| 474 |
+
EVEN_N: tl.constexpr,
|
| 475 |
+
EVEN_HEADDIM: tl.constexpr,
|
| 476 |
+
BLOCK_M: tl.constexpr,
|
| 477 |
+
BLOCK_N: tl.constexpr,
|
| 478 |
+
):
|
| 479 |
+
off_hb = tl.program_id(1)
|
| 480 |
+
off_b = off_hb // nheads
|
| 481 |
+
off_h = off_hb % nheads
|
| 482 |
+
Q += off_b * stride_qb + off_h * stride_qh
|
| 483 |
+
K += off_b * stride_kb + off_h * stride_kh
|
| 484 |
+
V += off_b * stride_vb + off_h * stride_vh
|
| 485 |
+
DO += off_b * stride_dob + off_h * stride_doh
|
| 486 |
+
DQ += off_b * stride_dqb + off_h * stride_dqh
|
| 487 |
+
DK += off_b * stride_dkb + off_h * stride_dkh
|
| 488 |
+
DV += off_b * stride_dvb + off_h * stride_dvh
|
| 489 |
+
if BIAS_TYPE != "none":
|
| 490 |
+
Bias += off_b * stride_bb + off_h * stride_bh
|
| 491 |
+
D += off_hb * seqlen_q_rounded
|
| 492 |
+
LSE += off_hb * seqlen_q_rounded
|
| 493 |
+
if not SEQUENCE_PARALLEL:
|
| 494 |
+
num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
|
| 495 |
+
for start_n in range(0, num_block_n):
|
| 496 |
+
_bwd_kernel_one_col_block(
|
| 497 |
+
start_n,
|
| 498 |
+
Q,
|
| 499 |
+
K,
|
| 500 |
+
V,
|
| 501 |
+
Bias,
|
| 502 |
+
DO,
|
| 503 |
+
DQ,
|
| 504 |
+
DK,
|
| 505 |
+
DV,
|
| 506 |
+
LSE,
|
| 507 |
+
D,
|
| 508 |
+
softmax_scale,
|
| 509 |
+
stride_qm,
|
| 510 |
+
stride_kn,
|
| 511 |
+
stride_vn,
|
| 512 |
+
stride_bm,
|
| 513 |
+
stride_dom,
|
| 514 |
+
stride_dqm,
|
| 515 |
+
stride_dkn,
|
| 516 |
+
stride_dvn,
|
| 517 |
+
seqlen_q,
|
| 518 |
+
seqlen_k,
|
| 519 |
+
headdim,
|
| 520 |
+
ATOMIC_ADD=False,
|
| 521 |
+
BIAS_TYPE=BIAS_TYPE,
|
| 522 |
+
IS_CAUSAL=IS_CAUSAL,
|
| 523 |
+
BLOCK_HEADDIM=BLOCK_HEADDIM,
|
| 524 |
+
EVEN_M=EVEN_M,
|
| 525 |
+
EVEN_N=EVEN_N,
|
| 526 |
+
EVEN_HEADDIM=EVEN_HEADDIM,
|
| 527 |
+
BLOCK_M=BLOCK_M,
|
| 528 |
+
BLOCK_N=BLOCK_N,
|
| 529 |
+
)
|
| 530 |
+
else:
|
| 531 |
+
start_n = tl.program_id(0)
|
| 532 |
+
_bwd_kernel_one_col_block(
|
| 533 |
+
start_n,
|
| 534 |
+
Q,
|
| 535 |
+
K,
|
| 536 |
+
V,
|
| 537 |
+
Bias,
|
| 538 |
+
DO,
|
| 539 |
+
DQ,
|
| 540 |
+
DK,
|
| 541 |
+
DV,
|
| 542 |
+
LSE,
|
| 543 |
+
D,
|
| 544 |
+
softmax_scale,
|
| 545 |
+
stride_qm,
|
| 546 |
+
stride_kn,
|
| 547 |
+
stride_vn,
|
| 548 |
+
stride_bm,
|
| 549 |
+
stride_dom,
|
| 550 |
+
stride_dqm,
|
| 551 |
+
stride_dkn,
|
| 552 |
+
stride_dvn,
|
| 553 |
+
seqlen_q,
|
| 554 |
+
seqlen_k,
|
| 555 |
+
headdim,
|
| 556 |
+
ATOMIC_ADD=True,
|
| 557 |
+
BIAS_TYPE=BIAS_TYPE,
|
| 558 |
+
IS_CAUSAL=IS_CAUSAL,
|
| 559 |
+
BLOCK_HEADDIM=BLOCK_HEADDIM,
|
| 560 |
+
EVEN_M=EVEN_M,
|
| 561 |
+
EVEN_N=EVEN_N,
|
| 562 |
+
EVEN_HEADDIM=EVEN_HEADDIM,
|
| 563 |
+
BLOCK_M=BLOCK_M,
|
| 564 |
+
BLOCK_N=BLOCK_N,
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
|
| 569 |
+
(batch, seqlen_q, nheads, d) = q.shape
|
| 570 |
+
(_, seqlen_k, _, _) = k.shape
|
| 571 |
+
assert k.shape == (batch, seqlen_k, nheads, d)
|
| 572 |
+
assert v.shape == (batch, seqlen_k, nheads, d)
|
| 573 |
+
assert d <= 128, "FlashAttention only support head dimensions up to 128"
|
| 574 |
+
assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
|
| 575 |
+
assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
|
| 576 |
+
assert q.is_cuda and k.is_cuda and v.is_cuda
|
| 577 |
+
softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
|
| 578 |
+
has_bias = bias is not None
|
| 579 |
+
bias_type = "none"
|
| 580 |
+
if has_bias:
|
| 581 |
+
assert bias.dtype in [q.dtype, torch.float]
|
| 582 |
+
assert bias.is_cuda
|
| 583 |
+
assert bias.dim() == 4
|
| 584 |
+
if bias.stride(-1) != 1:
|
| 585 |
+
bias = bias.contiguous()
|
| 586 |
+
if bias.shape[2:] == (1, seqlen_k):
|
| 587 |
+
bias_type = "vector"
|
| 588 |
+
elif bias.shape[2:] == (seqlen_q, seqlen_k):
|
| 589 |
+
bias_type = "matrix"
|
| 590 |
+
else:
|
| 591 |
+
raise RuntimeError("Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)")
|
| 592 |
+
bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
|
| 593 |
+
bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
|
| 594 |
+
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
|
| 595 |
+
lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
|
| 596 |
+
tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
|
| 597 |
+
o = torch.empty_like(q)
|
| 598 |
+
BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
|
| 599 |
+
BLOCK = 128
|
| 600 |
+
num_warps = 4 if d <= 64 else 8
|
| 601 |
+
grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
|
| 602 |
+
_fwd_kernel[grid](
|
| 603 |
+
q,
|
| 604 |
+
k,
|
| 605 |
+
v,
|
| 606 |
+
bias,
|
| 607 |
+
o,
|
| 608 |
+
lse,
|
| 609 |
+
tmp,
|
| 610 |
+
softmax_scale,
|
| 611 |
+
q.stride(0),
|
| 612 |
+
q.stride(2),
|
| 613 |
+
q.stride(1),
|
| 614 |
+
k.stride(0),
|
| 615 |
+
k.stride(2),
|
| 616 |
+
k.stride(1),
|
| 617 |
+
v.stride(0),
|
| 618 |
+
v.stride(2),
|
| 619 |
+
v.stride(1),
|
| 620 |
+
*bias_strides,
|
| 621 |
+
o.stride(0),
|
| 622 |
+
o.stride(2),
|
| 623 |
+
o.stride(1),
|
| 624 |
+
nheads,
|
| 625 |
+
seqlen_q,
|
| 626 |
+
seqlen_k,
|
| 627 |
+
seqlen_q_rounded,
|
| 628 |
+
d,
|
| 629 |
+
seqlen_q // 32,
|
| 630 |
+
seqlen_k // 32,
|
| 631 |
+
bias_type,
|
| 632 |
+
causal,
|
| 633 |
+
BLOCK_HEADDIM,
|
| 634 |
+
BLOCK_M=BLOCK,
|
| 635 |
+
BLOCK_N=BLOCK,
|
| 636 |
+
num_warps=num_warps,
|
| 637 |
+
num_stages=1
|
| 638 |
+
)
|
| 639 |
+
return (o, lse, softmax_scale)
|
| 640 |
+
|
| 641 |
+
|
| 642 |
+
def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):
|
| 643 |
+
if do.stride(-1) != 1:
|
| 644 |
+
do = do.contiguous()
|
| 645 |
+
(batch, seqlen_q, nheads, d) = q.shape
|
| 646 |
+
(_, seqlen_k, _, _) = k.shape
|
| 647 |
+
assert d <= 128
|
| 648 |
+
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
|
| 649 |
+
assert lse.shape == (batch, nheads, seqlen_q_rounded)
|
| 650 |
+
assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
|
| 651 |
+
assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
|
| 652 |
+
softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
|
| 653 |
+
dq_accum = torch.empty_like(q, dtype=torch.float32)
|
| 654 |
+
delta = torch.empty_like(lse)
|
| 655 |
+
BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
|
| 656 |
+
grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
|
| 657 |
+
_bwd_preprocess_do_o_dot[grid](
|
| 658 |
+
o,
|
| 659 |
+
do,
|
| 660 |
+
delta,
|
| 661 |
+
o.stride(0),
|
| 662 |
+
o.stride(2),
|
| 663 |
+
o.stride(1),
|
| 664 |
+
do.stride(0),
|
+        do.stride(2),
+        do.stride(1),
+        nheads,
+        seqlen_q,
+        seqlen_q_rounded,
+        d,
+        BLOCK_M=128,
+        BLOCK_HEADDIM=BLOCK_HEADDIM,
+    )
+    has_bias = bias is not None
+    bias_type = "none"
+    if has_bias:
+        assert bias.dtype in [q.dtype, torch.float]
+        assert bias.is_cuda
+        assert bias.dim() == 4
+        assert bias.stride(-1) == 1
+        if bias.shape[2:] == (1, seqlen_k):
+            bias_type = "vector"
+        elif bias.shape[2:] == (seqlen_q, seqlen_k):
+            bias_type = "matrix"
+        else:
+            raise RuntimeError("Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)")
+        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
+    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
+    grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, batch * nheads)
+    _bwd_kernel[grid](
+        q,
+        k,
+        v,
+        bias,
+        do,
+        dq_accum,
+        dk,
+        dv,
+        lse,
+        delta,
+        softmax_scale,
+        q.stride(0),
+        q.stride(2),
+        q.stride(1),
+        k.stride(0),
+        k.stride(2),
+        k.stride(1),
+        v.stride(0),
+        v.stride(2),
+        v.stride(1),
+        *bias_strides,
+        do.stride(0),
+        do.stride(2),
+        do.stride(1),
+        dq_accum.stride(0),
+        dq_accum.stride(2),
+        dq_accum.stride(1),
+        dk.stride(0),
+        dk.stride(2),
+        dk.stride(1),
+        dv.stride(0),
+        dv.stride(2),
+        dv.stride(1),
+        nheads,
+        seqlen_q,
+        seqlen_k,
+        seqlen_q_rounded,
+        d,
+        seqlen_q // 32,
+        seqlen_k // 32,
+        bias_type,
+        causal,
+        BLOCK_HEADDIM,
+    )
+    dq.copy_(dq_accum)
+
+
+class FlashAttnQKVPackedFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
+        """
+        qkv: (batch, seqlen, 3, nheads, headdim)
+        bias: optional, shape broadcastable to (batch, nheads, seqlen, seqlen).
+            For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
+            ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen).
+        """
+        if qkv.stride(-1) != 1:
+            qkv = qkv.contiguous()
+        (o, lse, ctx.softmax_scale) = _flash_attn_forward(qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal, softmax_scale=softmax_scale)
+        ctx.save_for_backward(qkv, o, lse, bias)
+        ctx.causal = causal
+        return o
+
+    @staticmethod
+    def backward(ctx, do):
+        (qkv, o, lse, bias) = ctx.saved_tensors
+        assert not ctx.needs_input_grad[1], "FlashAttention does not support bias gradient yet"
+        with torch.inference_mode():
+            dqkv = torch.empty_like(qkv)
+            _flash_attn_backward(
+                do,
+                qkv[:, :, 0],
+                qkv[:, :, 1],
+                qkv[:, :, 2],
+                o,
+                lse,
+                dqkv[:, :, 0],
+                dqkv[:, :, 1],
+                dqkv[:, :, 2],
+                bias=bias,
+                causal=ctx.causal,
+                softmax_scale=ctx.softmax_scale,
+            )
+        return (dqkv, None, None, None)
+
+
+flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
+
+
+class FlashAttnKVPackedFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
+        """
+        q: (batch, seqlen_q, nheads, headdim)
+        kv: (batch, seqlen_k, 2, nheads, headdim)
+        bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
+            For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
+            ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k).
+        """
+        (q, kv) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
+        (o, lse, ctx.softmax_scale) = _flash_attn_forward(q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale)
+        ctx.save_for_backward(q, kv, o, lse, bias)
+        ctx.causal = causal
+        return o
+
+    @staticmethod
+    def backward(ctx, do):
+        (q, kv, o, lse, bias) = ctx.saved_tensors
+        if len(ctx.needs_input_grad) >= 3:
+            assert not ctx.needs_input_grad[2], "FlashAttention does not support bias gradient yet"
+        with torch.inference_mode():
+            dq = torch.empty_like(q)
+            dkv = torch.empty_like(kv)
+            _flash_attn_backward(
+                do, q, kv[:, :, 0], kv[:, :, 1], o, lse, dq, dkv[:, :, 0], dkv[:, :, 1], bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale
+            )
+        return (dq, dkv, None, None, None)
+
+
+flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
+
+
+class FlashAttnFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
+        """
+        q: (batch_size, seqlen_q, nheads, headdim)
+        k, v: (batch_size, seqlen_k, nheads, headdim)
+        bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
+            For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
+            ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k).
+        """
+        (q, k, v) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
+        (o, lse, ctx.softmax_scale) = _flash_attn_forward(q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale)
+        ctx.save_for_backward(q, k, v, o, lse, bias)
+        ctx.causal = causal
+        return o
+
+    @staticmethod
+    def backward(ctx, do):
+        (q, k, v, o, lse, bias) = ctx.saved_tensors
+        assert not ctx.needs_input_grad[3], "FlashAttention does not support bias gradient yet"
+        with torch.inference_mode():
+            dq = torch.empty_like(q)
+            dk = torch.empty_like(k)
+            dv = torch.empty_like(v)
+            _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
+        return (dq, dk, dv, None, None, None)
+
+
+flash_attn_func = FlashAttnFunc.apply
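Editorial note, not part of the commit: the three `.apply` aliases above are the public entry points of this file. Below is a minimal usage sketch for `flash_attn_qkvpacked_func`, assuming a CUDA device with `triton` installed; the causal ALiBi bias of shape `(1, nheads, 1, seqlen)` follows the docstring, and the slope formula `2^(-8i/nheads)` mirrors the ALiBi construction used elsewhere in this repo (an assumption about `mpt/attention.py`, not something this file defines).

```python
import torch

# Sketch only: requires CUDA + triton, and the functions defined above.
batch, seqlen, nheads, headdim = 2, 128, 8, 64
qkv = torch.randn(batch, seqlen, 3, nheads, headdim,
                  device="cuda", dtype=torch.float16, requires_grad=True)

# Causal ALiBi bias, shape (1, nheads, 1, seqlen): per-head slope times the
# relative key position (0 for the current token, negative further back).
slopes = torch.tensor([2 ** (-8 * (i + 1) / nheads) for i in range(nheads)],
                      device="cuda", dtype=torch.float16).view(1, nheads, 1, 1)
rel_pos = torch.arange(1 - seqlen, 1, device="cuda",
                       dtype=torch.float16).view(1, 1, 1, seqlen)
bias = slopes * rel_pos  # broadcasting yields a contiguous (1, nheads, 1, seqlen) tensor

# autograd.Function.apply takes positional args: (qkv, bias, causal, softmax_scale)
out = flash_attn_qkvpacked_func(qkv, bias, True, None)
out.sum().backward()  # populates qkv.grad; bias gradients are unsupported (see the assert above)
```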
mllm/flamingo/mpt/hf_prefixlm_converter.py
ADDED
@@ -0,0 +1,575 @@
+"""Converts Huggingface Causal LM to Prefix LM.
+
+Conversion does lightweight surgery on a HuggingFace
+Causal LM to convert it to a Prefix LM.
+
+Prefix LMs accept a `bidirectional_mask` input in `forward`
+and treat the input prompt as the prefix in `generate`.
+"""
+import math
+import warnings
+from types import MethodType
+from typing import Any, Dict, List, Optional, Tuple, Union
+import torch
+from transformers.models.bloom.modeling_bloom import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    BloomForCausalLM,
+    BloomModel,
+    CausalLMOutputWithCrossAttentions,
+    CrossEntropyLoss,
+)
+from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom
+from transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom
+from transformers.models.bloom.modeling_bloom import logging
+from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
+from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
+from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
+from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
+from transformers.models.opt.modeling_opt import OPTForCausalLM
+from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt
+from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt
+
+logger = logging.get_logger(__name__)
+_SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM)
+CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM]
+
+
+def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
+    """Converts a GPT-style Causal LM to a Prefix LM.
+
+    Supported HuggingFace model classes:
+        - `GPT2LMHeadModel`
+        - `GPTNeoForCausalLM`
+        - `GPTNeoXForCausalLM`
+        - `GPTJForCausalLM`
+
+    See `convert_hf_causal_lm_to_prefix_lm` for more details.
+    """
+    if hasattr(model, "_prefix_lm_converted"):
+        return model
+    assert isinstance(model, _SUPPORTED_GPT_MODELS)
+    assert model.config.add_cross_attention == False, "Only supports GPT-style decoder-only models"
+
+    def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
+        """Helper that gets a list of the model's attention modules.
+
+        Each module has a `bias` buffer used for causal masking. The Prefix LM
+        conversion adds logic to dynamically manipulate these biases to support
+        Prefix LM attention masking.
+        """
+        attn_modules = []
+        if isinstance(model, GPTNeoXForCausalLM):
+            blocks = model.gpt_neox.layers
+        else:
+            blocks = model.transformer.h
+        for block in blocks:
+            if isinstance(model, GPTNeoForCausalLM):
+                if block.attn.attention_type != "global":
+                    continue
+                attn_module = block.attn.attention
+            elif isinstance(model, GPTNeoXForCausalLM):
+                attn_module = block.attention
+            else:
+                attn_module = block.attn
+            attn_modules.append(attn_module)
+        return attn_modules
+
+    setattr(model, "_original_forward", getattr(model, "forward"))
+    setattr(model, "_original_generate", getattr(model, "generate"))
+
+    def forward(
+        self: CAUSAL_GPT_TYPES,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        bidirectional_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        """Wraps original forward to enable PrefixLM attention."""
+
+        def call_og_forward():
+            if isinstance(self, GPTNeoXForCausalLM):
+                return self._original_forward(
+                    input_ids=input_ids,
+                    past_key_values=past_key_values,
+                    attention_mask=attention_mask,
+                    head_mask=head_mask,
+                    inputs_embeds=inputs_embeds,
+                    labels=labels,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                    output_hidden_states=output_hidden_states,
+                    return_dict=return_dict,
+                )
+            else:
+                return self._original_forward(
+                    input_ids=input_ids,
+                    past_key_values=past_key_values,
+                    attention_mask=attention_mask,
+                    token_type_ids=token_type_ids,
+                    position_ids=position_ids,
+                    head_mask=head_mask,
+                    inputs_embeds=inputs_embeds,
+                    labels=labels,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                    output_hidden_states=output_hidden_states,
+                    return_dict=return_dict,
+                )
+
+        if bidirectional_mask is None:
+            return call_og_forward()
+        assert isinstance(bidirectional_mask, torch.Tensor)
+        attn_modules = _get_attn_modules(model)
+        (b, s) = bidirectional_mask.shape
+        max_length = attn_modules[0].bias.shape[-1]
+        if s > max_length:
+            raise ValueError(f"bidirectional_mask sequence length (={s}) exceeds the " + f"max length allowed by the model ({max_length}).")
+        assert s <= max_length
+        if s < max_length:
+            pad = torch.zeros((int(b), int(max_length - s)), dtype=bidirectional_mask.dtype, device=bidirectional_mask.device)
+            bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
+        bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
+        for attn_module in attn_modules:
+            attn_module.bias.data = torch.logical_or(attn_module.bias.data, bidirectional)
+        output = call_og_forward()
+        for attn_module in attn_modules:
+            attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
+        return output
+
+    def generate(self: CAUSAL_GPT_TYPES, *args: tuple, **kwargs: Dict[str, Any]):
+        """Wraps original generate to enable PrefixLM attention."""
+        attn_modules = _get_attn_modules(model)
+        for attn_module in attn_modules:
+            attn_module.bias.data[:] = 1
+        output = self._original_generate(*args, **kwargs)
+        for attn_module in attn_modules:
+            attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
+        return output
+
+    setattr(model, "forward", MethodType(forward, model))
+    setattr(model, "generate", MethodType(generate, model))
+    setattr(model, "_prefix_lm_converted", True)
+    return model
+
+
+def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM:
+    """Converts a BLOOM Causal LM to a Prefix LM.
+
+    Supported HuggingFace model classes:
+        - `BloomForCausalLM`
+
+    See `convert_hf_causal_lm_to_prefix_lm` for more details.
+    """
+    if hasattr(model, "_prefix_lm_converted"):
+        return model
+    assert isinstance(model, BloomForCausalLM)
+    assert model.config.add_cross_attention == False, "Only supports BLOOM decoder-only models"
+
+    def _prepare_attn_mask(
+        self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int
+    ) -> torch.BoolTensor:
+        combined_attention_mask = None
+        device = attention_mask.device
+        (_, src_length) = input_shape
+        if src_length > 1:
+            combined_attention_mask = _make_causal_mask_bloom(input_shape, device=device, past_key_values_length=past_key_values_length)
+            if bidirectional_mask is not None:
+                assert attention_mask.shape == bidirectional_mask.shape
+                expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length)
+                combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask)
+        expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length)
+        combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
+        return combined_attention_mask
+
+    def _build_alibi_tensor(self: BloomModel, batch_size: int, query_length: int, key_length: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
+        num_heads = self.config.n_head
+        closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
+        base = torch.tensor(2 ** (-(2 ** (-(math.log2(closest_power_of_2) - 3)))), device=device, dtype=torch.float32)
+        powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
+        slopes = torch.pow(base, powers)
+        if closest_power_of_2 != num_heads:
+            extra_base = torch.tensor(2 ** (-(2 ** (-(math.log2(2 * closest_power_of_2) - 3)))), device=device, dtype=torch.float32)
+            num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
+            extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
+            slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
+        qa = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1)
+        ka = torch.arange(key_length, device=device, dtype=torch.int32).view(1, -1)
+        diffs = qa - ka + key_length - query_length
+        diffs = -diffs.abs()
+        alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(1, 1, query_length, key_length)
+        alibi = alibi.expand(batch_size, -1, -1, -1).reshape(-1, query_length, key_length)
+        return alibi.to(dtype)
+
+    KeyValueT = Tuple[torch.Tensor, torch.Tensor]
+
+    def forward(
+        self: BloomModel,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        bidirectional_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **deprecated_arguments,
+    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
+        if deprecated_arguments.pop("position_ids", False) is not False:
+            warnings.warn(
+                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. " + "You can safely ignore passing `position_ids`.", FutureWarning
+            )
+        if len(deprecated_arguments) > 0:
+            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            (batch_size, seq_length) = input_ids.shape
+        elif inputs_embeds is not None:
+            (batch_size, seq_length, _) = inputs_embeds.shape
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+        if past_key_values is None:
+            past_key_values = tuple([None] * len(self.h))
+        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+        presents = () if use_cache else None
+        all_self_attentions = () if output_attentions else None
+        all_hidden_states = () if output_hidden_states else None
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+        if past_key_values[0] is not None:
+            tmp = past_key_values[0][0]
+            past_key_values_length = tmp.shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+        if attention_mask is None:
+            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+        else:
+            attention_mask = attention_mask.to(hidden_states.device)
+        alibi = self._build_alibi_tensor(
+            batch_size=batch_size, query_length=seq_length, key_length=seq_length_with_past, dtype=hidden_states.dtype, device=hidden_states.device
+        )
+        causal_mask = self._prepare_attn_mask(
+            attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length
+        )
+        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+            if output_hidden_states:
+                hst = (hidden_states,)
+                all_hidden_states = all_hidden_states + hst
+            if self.gradient_checkpointing and self.training:
+                if use_cache:
+                    logger.warning("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+                    use_cache = False
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
+
+                    return custom_forward
+
+                outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(block), hidden_states, alibi, causal_mask, head_mask[i])
+            else:
+                outputs = block(
+                    hidden_states,
+                    layer_past=layer_past,
+                    attention_mask=causal_mask,
+                    head_mask=head_mask[i],
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                    alibi=alibi,
+                )
+            hidden_states = outputs[0]
+            if use_cache is True:
+                presents = presents + (outputs[1],)
+            if output_attentions:
+                oa = (outputs[2 if use_cache else 1],)
+                all_self_attentions = all_self_attentions + oa
+        hidden_states = self.ln_f(hidden_states)
+        if output_hidden_states:
+            hst = (hidden_states,)
+            all_hidden_states = all_hidden_states + hst
+        if not return_dict:
+            return tuple((v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None))
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions
+        )
+
+    setattr(model.transformer, "_prepare_attn_mask", MethodType(_prepare_attn_mask, model.transformer))
+    setattr(model.transformer, "_build_alibi_tensor", MethodType(_build_alibi_tensor, model.transformer))
+    setattr(model.transformer, "forward", MethodType(forward, model.transformer))
+    KeyValueT = Tuple[torch.Tensor, torch.Tensor]
+
+    def forward(
+        self: BloomForCausalLM,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        bidirectional_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **deprecated_arguments,
+    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
+        """Replacement forward method for BloomCausalLM."""
+        if deprecated_arguments.pop("position_ids", False) is not False:
+            warnings.warn(
+                "`position_ids` have no functionality in BLOOM and will be removed " + "in v5.0.0. You can safely ignore passing `position_ids`.", FutureWarning
+            )
+        if len(deprecated_arguments) > 0:
+            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        transformer_outputs = self.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            bidirectional_mask=bidirectional_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = transformer_outputs[0]
+        lm_logits = self.lm_head(hidden_states)
+        loss = None
+        if labels is not None:
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            (batch_size, seq_length, vocab_size) = shift_logits.shape
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length))
+        if not return_dict:
+            output = (lm_logits,) + transformer_outputs[1:]
+            return (loss,) + output if loss is not None else output
+        return CausalLMOutputWithCrossAttentions(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self: BloomForCausalLM, input_ids: torch.LongTensor, past: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, **kwargs
+    ) -> dict:
+        if past:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+            bidirectional_mask = None
+            if past[0][0].shape[0] == input_ids.shape[0]:
+                past = self._convert_to_bloom_cache(past)
+        else:
+            bidirectional_mask = torch.ones_like(input_ids)
+        return {"input_ids": input_ids, "past_key_values": past, "use_cache": True, "attention_mask": attention_mask, "bidirectional_mask": bidirectional_mask}
+
+    setattr(model, "forward", MethodType(forward, model))
+    setattr(model, "prepare_inputs_for_generation", MethodType(prepare_inputs_for_generation, model))
+    setattr(model, "_prefix_lm_converted", True)
+    return model
+
+
+def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM:
+    """Converts an OPT Causal LM to a Prefix LM.
+
+    Supported HuggingFace model classes:
+        - `OPTForCausalLM`
+
+    See `convert_hf_causal_lm_to_prefix_lm` for more details.
+    """
+    if hasattr(model, "_prefix_lm_converted"):
+        return model
+    assert isinstance(model, OPTForCausalLM)
+    assert model.config.add_cross_attention == False, "Only supports OPT decoder-only models"
+    setattr(model, "_original_forward", getattr(model, "forward"))
+    setattr(model, "_original_generate", getattr(model, "generate"))
+    model.model.decoder.bidirectional_mask = None
+
+    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+        combined_attention_mask = None
+        if input_shape[-1] > 1:
+            if self.bidirectional_mask == "g":
+                (bsz, src_length) = input_shape
+                combined_attention_mask = torch.zeros(
+                    (bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device
+                )
+            else:
+                combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(
+                    inputs_embeds.device
+                )
+                if self.bidirectional_mask is not None:
+                    assert attention_mask.shape == self.bidirectional_mask.shape
+                    expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                        inputs_embeds.device
+                    )
+                    combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask)
+        if attention_mask is not None:
+            expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
+            combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+        return combined_attention_mask
+
+    setattr(model.model.decoder, "_prepare_decoder_attention_mask", MethodType(_prepare_decoder_attention_mask, model.model.decoder))
+
+    def forward(
+        self: OPTForCausalLM,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        bidirectional_mask: Optional[torch.ByteTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        def call_og_forward():
+            return self._original_forward(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                labels=labels,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+
+        if bidirectional_mask is None:
+            return call_og_forward()
+        self.model.decoder.bidirectional_mask = bidirectional_mask
+        try:
+            outputs = call_og_forward()
+        except:
+            self.model.decoder.bidirectional_mask = None
+            raise
+        self.model.decoder.bidirectional_mask = None
+        return outputs
+
+    def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[str, Any]):
+        """Wraps original generate to enable PrefixLM-style attention."""
+        self.model.decoder.bidirectional_mask = "g"
+        try:
+            output = self._original_generate(*args, **kwargs)
+        except:
+            self.model.decoder.bidirectional_mask = None
+            raise
+        self.model.decoder.bidirectional_mask = None
+        return output
+
+    setattr(model, "forward", MethodType(forward, model))
+    setattr(model, "generate", MethodType(generate, model))
+    setattr(model, "_prefix_lm_converted", True)
+    return model
+
+
+_SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM)
+CAUSAL_LM_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM, BloomForCausalLM, OPTForCausalLM]
+
+
+def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
+    """Converts a HuggingFace Causal LM to a Prefix LM.
+
+    Supported HuggingFace model classes:
+        - `GPT2LMHeadModel`
+        - `GPTNeoForCausalLM`
+        - `GPTNeoXForCausalLM`
+        - `GPTJForCausalLM`
+        - `BloomForCausalLM`
+        - `OPTForCausalLM`
+
+    Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the
+    `generate` method and/or select underlying methods depending on the model class.
+
+    These changes preserve the model API, but add a new input to `forward`: "bidirectional_mask".
+
+    Notes on training:
+    To actually train the converted model as a Prefix LM, training batches will need to indicate
+    the prefix/target structure by including `bidirectional_mask` as part of the batch inputs.
+
+    **This is not a standard input and requires custom layers either within or after your dataloader.**
+
+    In addition to adding `bidirectional_mask` to the batch, this custom code should modify `labels`
+    such that `batch['labels'][batch['bidirectional_mask'] == 1] == -100`.
+    That is, the prefix portion of the sequence should not generate any loss. Loss should only be
+    generated by the target portion of the sequence.
+
+    Notes on `GPTNeoForCausalLM`:
+    To simplify the implementation, "global" and "local" attention layers are handled differently.
+    For "global" layers, we handle conversion as described above. For "local" layers, which use a
+    causal attention mask within a restricted local window, we do not alter the masking.
+
+    Notes on `forward` method conversion:
+    After conversion, the `forward` method will handle a new input, `bidirectional_mask`,
+    which should be a [batch_size, seq_length] byte tensor, where 1 indicates token positions
+    belonging to the prefix (prefix tokens can attend to one another bidirectionally), and
+    0 indicates token positions belonging to the target.
+
+    The new `forward` method will incorporate `bidirectional_mask` (if supplied) into the existing
+    causal mask, call the original `forward` method, and (if the causal mask is a buffer) reset
+    the causal masks before returning the result.
+
+    Notes on `generate` method conversion:
+    After conversion, the `generate` method will have the same signature but will internally
+    convert all causal masks to be purely bidirectional, call the original `generate` method, and
+    (where appropriate) reset the causal masks before returning the result.
+
+    This works thanks to the logic of the HuggingFace `generate` API, which first encodes the token
+    "prompt" passed to `generate` (which is treated as the prefix) and then sequentially generates
+    each new token. Encodings are cached as generation happens, so all prefix tokens can attend to one
+    another (as expected in a Prefix LM) and generated tokens can only attend to prefix tokens and
+    previously-generated tokens (also as expected in a Prefix LM).
+
+    To preserve the API, the original methods are renamed to `_original_forward` and
+    `_original_generate`, and replaced with new `forward` and `generate` methods that wrap
+    them, respectively; implementation details vary by model class.
+    """
+    if isinstance(model, _SUPPORTED_GPT_MODELS):
+        return _convert_gpt_causal_lm_to_prefix_lm(model)
+    elif isinstance(model, BloomForCausalLM):
+        return _convert_bloom_causal_lm_to_prefix_lm(model)
+    elif isinstance(model, OPTForCausalLM):
+        return _convert_opt_causal_lm_to_prefix_lm(model)
+    else:
+        raise TypeError(f"Cannot convert model to Prefix LM. " + f"Model does not belong to set of supported HF models:" + f"\n{_SUPPORTED_HF_MODELS}")
+
+
+def add_bidirectional_mask_if_missing(batch: Dict[str, Any]):
+    """Attempts to add bidirectional_mask to batch if missing.
+
+    Raises:
+        KeyError if bidirectional_mask is missing and can't be inferred.
+    """
+    if "bidirectional_mask" not in batch:
+        if batch.get("mode", None) == "icl_task":
+            batch["bidirectional_mask"] = batch["attention_mask"].clone()
+            for i, continuation_indices in enumerate(batch["continuation_indices"]):
+                batch["bidirectional_mask"][i, continuation_indices] = 0
+        elif "labels" in batch and "attention_mask" in batch:
+            batch["bidirectional_mask"] = torch.logical_and(torch.eq(batch["attention_mask"], 1), torch.eq(batch["labels"], -100)).type_as(
+                batch["attention_mask"]
+            )
+        else:
+            raise KeyError("No bidirectional_mask in batch and not sure how to construct one.")
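Editorial note, not part of the commit: the docstring of `convert_hf_causal_lm_to_prefix_lm` prescribes how a training batch should look, so here is a minimal sketch of one training step on a converted model. The `"gpt2"` checkpoint and the prompt/target split are illustrative only, and the GPT-style path assumes a transformers version whose attention modules still expose the `bias` causal-mask buffer that `_get_attn_modules` relies on.

```python
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel

# Illustrative checkpoint; any supported class from the docstring works.
model = convert_hf_causal_lm_to_prefix_lm(GPT2LMHeadModel.from_pretrained("gpt2"))
tok = AutoTokenizer.from_pretrained("gpt2")

prefix = tok("Translate to French: Hello", return_tensors="pt").input_ids
target = tok(" Bonjour", return_tensors="pt").input_ids

input_ids = torch.cat([prefix, target], dim=1)
# 1 marks prefix tokens (bidirectional attention), 0 marks target tokens.
bidirectional_mask = torch.cat(
    [torch.ones_like(prefix), torch.zeros_like(target)], dim=1
)
labels = input_ids.clone()
labels[bidirectional_mask == 1] = -100  # no loss on the prefix, per the docstring

out = model(input_ids=input_ids, labels=labels, bidirectional_mask=bidirectional_mask)
print(out.loss)  # loss comes only from the target portion
```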
mllm/flamingo/mpt/meta_init_context.py
ADDED
@@ -0,0 +1,98 @@
+from contextlib import contextmanager
+import torch
+import torch.nn as nn
+
+
+@contextmanager
+def init_empty_weights(include_buffers: bool = False):
+    """Meta initialization context manager.
+
+    A context manager under which models are initialized with all parameters
+    on the meta device, therefore creating an empty model. Useful when just
+    initializing the model would blow the available RAM.
+
+    Args:
+        include_buffers (`bool`, *optional*, defaults to `False`): Whether or
+            not to also put all buffers on the meta device while initializing.
+
+    Example:
+    ```python
+    import torch.nn as nn
+
+    # Initialize a model with 100 billion parameters in no time and without using any RAM.
+    with init_empty_weights():
+        tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
+    ```
+
+    <Tip warning={true}>
+
+    Any model created under this context manager has no weights. As such you can't do something like
+    `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
+
+    </Tip>
+    """
+    with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
+        yield f
+
+
+@contextmanager
+def init_on_device(device: torch.device, include_buffers: bool = False):
+    """Device initialization context manager.
+
+    A context manager under which models are initialized with all parameters
+    on the specified device.
+
+    Args:
+        device (`torch.device`): Device to initialize all parameters on.
+        include_buffers (`bool`, *optional*, defaults to `False`): Whether or
+            not to also put all buffers on the specified device while initializing.
+
+    Example:
+    ```python
+    import torch.nn as nn
+
+    with init_on_device(device=torch.device("cuda")):
+        tst = nn.Linear(100, 100)  # on `cuda` device
+    ```
+    """
+    old_register_parameter = nn.Module.register_parameter
+    if include_buffers:
+        old_register_buffer = nn.Module.register_buffer
+
+    def register_empty_parameter(module, name, param):
+        old_register_parameter(module, name, param)
+        if param is not None:
+            param_cls = type(module._parameters[name])
+            kwargs = module._parameters[name].__dict__
+            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
+
+    def register_empty_buffer(module, name, buffer):
+        old_register_buffer(module, name, buffer)
+        if buffer is not None:
+            module._buffers[name] = module._buffers[name].to(device)
+
+    if include_buffers:
+        tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ["empty", "zeros", "ones", "full"]}
+    else:
+        tensor_constructors_to_patch = {}
+
+    def patch_tensor_constructor(fn):
+        def wrapper(*args, **kwargs):
+            kwargs["device"] = device
+            return fn(*args, **kwargs)
+
+        return wrapper
+
+    try:
+        nn.Module.register_parameter = register_empty_parameter
+        if include_buffers:
+            nn.Module.register_buffer = register_empty_buffer
+        for torch_function_name in tensor_constructors_to_patch.keys():
+            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
+        yield
+    finally:
+        nn.Module.register_parameter = old_register_parameter
+        if include_buffers:
+            nn.Module.register_buffer = old_register_buffer
+        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
+            setattr(torch, torch_function_name, old_torch_function)
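Editorial note, not part of the commit: a quick sketch of how these context managers behave. The import path is an assumption (point it at wherever this module lives in your package), and the `to_empty`/`reset_parameters` step at the end is just one possible way to re-materialize real tensors; this file itself does not prescribe it.

```python
import torch.nn as nn

# Import path is an assumption; adjust to your package layout.
from mllm.flamingo.mpt.meta_init_context import init_empty_weights

with init_empty_weights(include_buffers=True):
    model = nn.Sequential(nn.Linear(512, 512), nn.LayerNorm(512))

print(model[0].weight.device)  # device(type='meta'): no storage was allocated

# One way to materialize real tensors afterwards:
model = model.to_empty(device="cpu")
for m in model.modules():
    if hasattr(m, "reset_parameters"):
        m.reset_parameters()
```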
mllm/flamingo/mpt/modeling_mpt.py
ADDED
@@ -0,0 +1,496 @@
+"""A simple, flexible implementation of a GPT model.
+
+Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
+"""
+import math
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+
+from .attention import attn_bias_shape, build_attn_bias
+from .blocks import MPTBlock
+from .configuration_mpt import MPTConfig
+from .custom_embedding import SharedEmbedding
+from .norm import NORM_CLASS_REGISTRY
+from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
+
+import torch.distributed as dist
+
+try:
+    from .flash_attn_triton import flash_attn_func
+except:
+    pass
+Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
+
+
+class MPTPreTrainedModel(PreTrainedModel):
+    config_class = MPTConfig
+    base_model_prefix = "model"
+    _no_split_modules = ["MPTBlock"]
+
+
+class MPTModel(MPTPreTrainedModel):
+    def __init__(self, config: MPTConfig):
+        config._validate_config()
+        super().__init__(config)
+        self.attn_impl = config.attn_config["attn_impl"]
+        self.prefix_lm = config.attn_config["prefix_lm"]
+        self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
+        self.alibi = config.attn_config["alibi"]
+        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
+        if config.init_device == "mixed":
+            if dist.get_local_rank() == 0:
+                config.init_device = "cpu"
+            else:
+                config.init_device = "meta"
+        if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
+            norm_options = " | ".join(NORM_CLASS_REGISTRY.keys())
+            raise NotImplementedError(f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).")
+        norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
+        self.embedding_fraction = config.embedding_fraction
+        self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
+        if not self.alibi:
+            self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
+        self.emb_drop = nn.Dropout(config.emb_pdrop)
+        self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
+        self.norm_f = norm_class(config.d_model, device=config.init_device)
+        if config.init_device != "meta":
+            print(
+                f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.'
+            )
+            self.apply(self.param_init_fn)
+        self.is_causal = not self.prefix_lm
+        self._attn_bias_initialized = False
+        self.attn_bias = None
+        self.attn_bias_shape = attn_bias_shape(
+            self.attn_impl,
+            config.n_heads,
+            config.max_seq_len,
+            self.alibi,
+            prefix_lm=self.prefix_lm,
+            causal=self.is_causal,
+            use_sequence_id=self.attn_uses_sequence_id,
+        )
+        if config.no_bias:
+            for module in self.modules():
+                if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
+                    if config.verbose:
+                        warnings.warn(f"Removing bias ({module.bias}) from {module}.")
+                    module.register_parameter("bias", None)
+        if config.verbose and config.verbose > 2:
+            print(self)
+        if "verbose" not in self.config.init_config:
+            self.config.init_config["verbose"] = self.config.verbose
+        if self.config.init_config["verbose"] > 1:
+            init_fn_name = self.config.init_config["name"]
+            warnings.warn(f"Using {init_fn_name} initialization.")
+
+    def get_input_embeddings(self):
+        return self.wte
+
+    def set_input_embeddings(self, value):
+        self.wte = value
+
+    @torch.no_grad()
+    def _attn_bias(
+        self,
+        device,
+        dtype,
+        attention_mask: Optional[torch.ByteTensor] = None,
+        prefix_mask: Optional[torch.ByteTensor] = None,
+        sequence_id: Optional[torch.LongTensor] = None,
+    ):
+        if not self._attn_bias_initialized:
+            if self.attn_bias_shape:
+                self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
+                self.attn_bias = build_attn_bias(
+                    self.attn_impl,
+                    self.attn_bias,
+                    self.config.n_heads,
+                    self.config.max_seq_len,
+                    causal=self.is_causal,
+                    alibi=self.alibi,
+                    alibi_bias_max=self.alibi_bias_max,
+                )
+            self._attn_bias_initialized = True
+        if self.attn_impl == "flash":
+            return (self.attn_bias, attention_mask)
+        if self.attn_bias is not None:
+            self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
+        attn_bias = self.attn_bias
+        if self.prefix_lm:
+            assert isinstance(attn_bias, torch.Tensor)
+            assert isinstance(prefix_mask, torch.Tensor)
+            attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
+        if self.attn_uses_sequence_id and sequence_id is not None:
+            assert isinstance(attn_bias, torch.Tensor)
+            attn_bias = self._apply_sequence_id(attn_bias, sequence_id)
+        if attention_mask is not None:
+            s_k = attention_mask.shape[-1]
+            if attn_bias is None:
+                attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
+            else:
+                _s_k = max(0, attn_bias.size(-1) - s_k)
+                attn_bias = attn_bias[:, :, :, _s_k:]
+            if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
+                raise ValueError(f"attention_mask shape={attention_mask.shape} " + f"and prefix_mask shape={prefix_mask.shape} are not equal.")
+            min_val = torch.finfo(attn_bias.dtype).min
+            attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)
+        return (attn_bias, None)
+
+    def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
+        (s_k, s_q) = attn_bias.shape[-2:]
+        if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
+            raise ValueError(
+                "attn_bias does not match the expected shape. "
+                + f"The last two dimensions should both be {self.config.max_seq_len} "
+                + f"but are {s_k} and {s_q}."
+            )
+        seq_len = prefix_mask.shape[-1]
+        if seq_len > self.config.max_seq_len:
+            raise ValueError(f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}")
+        attn_bias = attn_bias[..., :seq_len, :seq_len]
+        causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
+        prefix = prefix_mask.view(-1, 1, 1, seq_len)
+        cannot_attend = ~torch.logical_or(causal, prefix.bool())
+        min_val = torch.finfo(attn_bias.dtype).min
+        attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
+        return attn_bias
+
+    def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor):
+        seq_len = sequence_id.shape[-1]
+        if seq_len > self.config.max_seq_len:
+            raise ValueError(f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}")
+        attn_bias = attn_bias[..., :seq_len, :seq_len]
+        cannot_attend = torch.logical_not(torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))).unsqueeze(1)
+        min_val = torch.finfo(attn_bias.dtype).min
+        attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
+        return attn_bias
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
+        attention_mask: Optional[torch.ByteTensor] = None,
+        prefix_mask: Optional[torch.ByteTensor] = None,
+        sequence_id: Optional[torch.LongTensor] = None,
+        return_dict: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        use_cache: Optional[bool] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ):
+        return_dict = return_dict if return_dict is not None else self.config.return_dict
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        if attention_mask is not None:
+            attention_mask = attention_mask.bool()
+
+        if prefix_mask is not None:
+            prefix_mask = prefix_mask.bool()
+
+        # These args are passed in by keyword in huggingface's generate function
+        # https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206
+        # but have not yet been fully implemented in MPTModel
+        if not return_dict:
+            raise NotImplementedError("return_dict False is not implemented yet for MPT")
+        if output_attentions:
+            if self.attn_impl != "torch":
+                raise NotImplementedError("output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.")
+
+        if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
+            raise NotImplementedError("MPT does not support training with left padding.")
+
+        if self.prefix_lm and prefix_mask is None:
+            raise ValueError("prefix_mask is a required argument when MPT is configured with prefix_lm=True.")
+
+        # Raise a not implemented error if inputs_embeds is not None (this is an arg in huggingface transformers and we need to support it for PEFT)
+        if inputs_embeds is not None:
+            raise NotImplementedError("inputs_embeds is not implemented for MPT.")
+
+        if self.training:
+            if self.attn_uses_sequence_id and sequence_id is None:
+                raise ValueError(
+                    "sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True " + "and the model is in train mode."
+                )
+            elif (self.attn_uses_sequence_id is False) and (sequence_id is not None):
+                warnings.warn(
+                    "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. "
+                    + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True."
+                )
+
+        S = input_ids.size(1)
+
+        assert S <= self.config.max_seq_len, f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}"
+
+        tok_emb = self.wte(input_ids)  # type: ignore
+        if self.alibi:
+            x = tok_emb
+        else:
+            past_position = 0
+            if past_key_values is not None:
+                if len(past_key_values) != self.config.n_layers:
+                    raise ValueError(
+                        f"past_key_values must provide a past_key_value for each attention "
+                        + f"layer in the network ({len(past_key_values)=}; {self.config.n_layers=})."
+                    )
+                # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
+                # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
+                # Here we shift position embedding using the `seq` dim of the past key
+                past_position = past_key_values[0][0].size(1)
+                if self.attn_impl == "torch":
+                    past_position = past_key_values[0][0].size(3)
+
+            if S + past_position > self.config.max_seq_len:
+                raise ValueError(
+                    f"Cannot forward input with past sequence length {past_position} and current sequence length "
+                    f"{S}, this model only supports total sequence length <= {self.config.max_seq_len}."
+                )
+            pos = torch.arange(
+                past_position,
+                S + past_position,
+                dtype=torch.long,
+                device=input_ids.device,
+            ).unsqueeze(0)
+            if attention_mask is not None:
+                # adjust the position indices to account for padding tokens
+                pos = torch.clamp(
+                    pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:],
+                    min=0,
+                )
+
+            pos_emb = self.wpe(pos)  # type: ignore
+            x = tok_emb + pos_emb
+
+        if self.embedding_fraction == 1:
+            x = self.emb_drop(x)  # type: ignore
+        else:
+            # this implementation is proposed on page 7 of the GLM-130B paper https://arxiv.org/abs/2210.02414
+            x_shrunk = (x * self.embedding_fraction) + (x.detach() * (1 - self.embedding_fraction))
+            assert isinstance(self.emb_drop, nn.Module)  # pyright
+            x = self.emb_drop(x_shrunk)
+
+        attn_bias, attention_mask = self._attn_bias(
+            device=x.device,
+            dtype=torch.float32,
+            attention_mask=attention_mask,
+            prefix_mask=prefix_mask,
+            sequence_id=sequence_id,
+        )
+
+        # initialize the past key values cache if it should be used
|
| 287 |
+
if use_cache and past_key_values is None:
|
| 288 |
+
past_key_values = [() for _ in range(self.config.n_layers)] # type: ignore
|
| 289 |
+
|
| 290 |
+
all_hidden_states = () if output_hidden_states else None
|
| 291 |
+
all_self_attns = () if output_attentions else None
|
| 292 |
+
for b_idx, block in enumerate(self.blocks): # type: ignore
|
| 293 |
+
if output_hidden_states:
|
| 294 |
+
assert all_hidden_states is not None # pyright
|
| 295 |
+
all_hidden_states = all_hidden_states + (x,)
|
| 296 |
+
past_key_value = past_key_values[b_idx] if past_key_values is not None else None
|
| 297 |
+
x, attn_weights, past_key_value = block(
|
| 298 |
+
x,
|
| 299 |
+
past_key_value=past_key_value,
|
| 300 |
+
attn_bias=attn_bias,
|
| 301 |
+
attention_mask=attention_mask,
|
| 302 |
+
is_causal=self.is_causal,
|
| 303 |
+
)
|
| 304 |
+
if past_key_values is not None:
|
| 305 |
+
past_key_values[b_idx] = past_key_value
|
| 306 |
+
|
| 307 |
+
if output_attentions:
|
| 308 |
+
assert all_self_attns is not None # pyright
|
| 309 |
+
all_self_attns = all_self_attns + (attn_weights,)
|
| 310 |
+
|
| 311 |
+
x = self.norm_f(x) # type: ignore
|
| 312 |
+
|
| 313 |
+
# add hidden states from the last decoder layer
|
| 314 |
+
if output_hidden_states:
|
| 315 |
+
assert all_hidden_states is not None # pyright
|
| 316 |
+
all_hidden_states = all_hidden_states + (x,)
|
| 317 |
+
|
| 318 |
+
return BaseModelOutputWithPast(
|
| 319 |
+
last_hidden_state=x,
|
| 320 |
+
past_key_values=past_key_values,
|
| 321 |
+
hidden_states=all_hidden_states,
|
| 322 |
+
attentions=all_self_attns,
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
# Param Initialization, needed for device='meta' fast initialization
|
| 326 |
+
def param_init_fn(self, module):
|
| 327 |
+
init_fn_name = self.config.init_config["name"]
|
| 328 |
+
MODEL_INIT_REGISTRY[init_fn_name](
|
| 329 |
+
module=module,
|
| 330 |
+
n_layers=self.config.n_layers,
|
| 331 |
+
d_model=self.config.d_model,
|
| 332 |
+
**self.config.init_config,
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
def fsdp_wrap_fn(self, module):
|
| 336 |
+
return isinstance(module, MPTBlock)
|
| 337 |
+
|
| 338 |
+
def activation_checkpointing_fn(self, module):
|
| 339 |
+
return isinstance(module, MPTBlock)
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
class MPTForCausalLM(MPTPreTrainedModel):
|
| 343 |
+
def __init__(self, config: MPTConfig):
|
| 344 |
+
super().__init__(config)
|
| 345 |
+
if not config.tie_word_embeddings:
|
| 346 |
+
raise ValueError("MPTForCausalLM only supports tied word embeddings")
|
| 347 |
+
self.transformer = MPTModel(config)
|
| 348 |
+
for child in self.transformer.children():
|
| 349 |
+
if isinstance(child, torch.nn.ModuleList):
|
| 350 |
+
continue
|
| 351 |
+
if isinstance(child, torch.nn.Module):
|
| 352 |
+
child._fsdp_wrap = True
|
| 353 |
+
self.logit_scale = None
|
| 354 |
+
if config.logit_scale is not None:
|
| 355 |
+
logit_scale = config.logit_scale
|
| 356 |
+
if isinstance(logit_scale, str):
|
| 357 |
+
if logit_scale == "inv_sqrt_d_model":
|
| 358 |
+
logit_scale = 1 / math.sqrt(config.d_model)
|
| 359 |
+
else:
|
| 360 |
+
raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
|
| 361 |
+
self.logit_scale = logit_scale
|
| 362 |
+
|
| 363 |
+
def get_input_embeddings(self):
|
| 364 |
+
return self.transformer.wte
|
| 365 |
+
|
| 366 |
+
def set_input_embeddings(self, value):
|
| 367 |
+
# self.transformer.wte = value
|
| 368 |
+
peudo_wte = SharedEmbedding(value.weight.shape[0], value.weight.shape[1], device=self.transformer.wte.weight.device)
|
| 369 |
+
peudo_wte.weight = value.weight
|
| 370 |
+
self.transformer.wte = peudo_wte
|
| 371 |
+
|
| 372 |
+
def get_output_embeddings(self):
|
| 373 |
+
return self.transformer.wte
|
| 374 |
+
|
| 375 |
+
def set_output_embeddings(self, new_embeddings):
|
| 376 |
+
# self.transformer.wte = new_embeddings
|
| 377 |
+
peudo_wte = SharedEmbedding(new_embeddings.weight.shape[0], new_embeddings.weight.shape[1], device=self.transformer.wte.weight.device)
|
| 378 |
+
peudo_wte.weight = new_embeddings.weight
|
| 379 |
+
self.transformer.wte = peudo_wte
|
| 380 |
+
|
| 381 |
+
def set_decoder(self, decoder):
|
| 382 |
+
self.transformer = decoder
|
| 383 |
+
|
| 384 |
+
def get_decoder(self):
|
| 385 |
+
return self.transformer
|
| 386 |
+
|
| 387 |
+
def forward(
|
| 388 |
+
self,
|
| 389 |
+
input_ids: torch.LongTensor,
|
| 390 |
+
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
|
| 391 |
+
attention_mask: Optional[torch.ByteTensor] = None,
|
| 392 |
+
prefix_mask: Optional[torch.ByteTensor] = None,
|
| 393 |
+
sequence_id: Optional[torch.LongTensor] = None,
|
| 394 |
+
labels: Optional[torch.LongTensor] = None,
|
| 395 |
+
return_dict: Optional[bool] = None,
|
| 396 |
+
output_attentions: Optional[bool] = None,
|
| 397 |
+
output_hidden_states: Optional[bool] = None,
|
| 398 |
+
use_cache: Optional[bool] = None,
|
| 399 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 400 |
+
):
|
| 401 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
| 402 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 403 |
+
|
| 404 |
+
# if input_embeds is not none, raise a not implemented error
|
| 405 |
+
if inputs_embeds is not None:
|
| 406 |
+
raise NotImplementedError("inputs_embeds has to be None (for hf/peft support).")
|
| 407 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 408 |
+
outputs = self.transformer(
|
| 409 |
+
input_ids=input_ids,
|
| 410 |
+
past_key_values=past_key_values,
|
| 411 |
+
attention_mask=attention_mask,
|
| 412 |
+
prefix_mask=prefix_mask,
|
| 413 |
+
sequence_id=sequence_id,
|
| 414 |
+
return_dict=return_dict,
|
| 415 |
+
output_attentions=output_attentions,
|
| 416 |
+
output_hidden_states=output_hidden_states,
|
| 417 |
+
use_cache=use_cache,
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
# move outputs to same device as weights for token embedding
|
| 421 |
+
# needed to support HF `device_map`
|
| 422 |
+
logits = self.transformer.wte(
|
| 423 |
+
input=outputs.last_hidden_state.to(self.transformer.wte.weight.device),
|
| 424 |
+
unembed=True,
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
if self.logit_scale is not None:
|
| 428 |
+
if self.logit_scale == 0:
|
| 429 |
+
warnings.warn(f"Multiplying logits by {self.logit_scale=}. This will produce uniform (uninformative) outputs.")
|
| 430 |
+
logits *= self.logit_scale
|
| 431 |
+
|
| 432 |
+
loss = None
|
| 433 |
+
if labels is not None:
|
| 434 |
+
_labels = torch.roll(labels, shifts=-1)
|
| 435 |
+
_labels[:, -1] = -100
|
| 436 |
+
loss = F.cross_entropy(
|
| 437 |
+
logits.view(-1, logits.size(-1)),
|
| 438 |
+
_labels.to(logits.device).view(-1),
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
return CausalLMOutputWithPast(
|
| 442 |
+
loss=loss,
|
| 443 |
+
logits=logits,
|
| 444 |
+
past_key_values=outputs.past_key_values,
|
| 445 |
+
hidden_states=outputs.hidden_states,
|
| 446 |
+
attentions=outputs.attentions,
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
def param_init_fn(self, module):
|
| 450 |
+
init_fn_name = self.config.init_config["name"]
|
| 451 |
+
MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
|
| 452 |
+
|
| 453 |
+
def fsdp_wrap_fn(self, module):
|
| 454 |
+
return isinstance(module, MPTBlock)
|
| 455 |
+
|
| 456 |
+
def activation_checkpointing_fn(self, module):
|
| 457 |
+
return isinstance(module, MPTBlock)
|
| 458 |
+
|
| 459 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None, **kwargs):
|
| 460 |
+
if inputs_embeds is not None:
|
| 461 |
+
raise NotImplementedError("inputs_embeds is not implemented for MPT yet")
|
| 462 |
+
attention_mask = attention_mask.bool()
|
| 463 |
+
if attention_mask[:, -1].sum() != attention_mask.shape[0]:
|
| 464 |
+
raise NotImplementedError("MPT does not support generation with right padding.")
|
| 465 |
+
if self.transformer.attn_uses_sequence_id and self.training:
|
| 466 |
+
sequence_id = torch.zeros_like(input_ids[:1])
|
| 467 |
+
else:
|
| 468 |
+
sequence_id = None
|
| 469 |
+
if past_key_values is not None:
|
| 470 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
| 471 |
+
if self.transformer.prefix_lm:
|
| 472 |
+
prefix_mask = torch.ones_like(attention_mask)
|
| 473 |
+
if kwargs.get("use_cache") == False:
|
| 474 |
+
raise NotImplementedError("MPT with prefix_lm=True does not support use_cache=False.")
|
| 475 |
+
else:
|
| 476 |
+
prefix_mask = None
|
| 477 |
+
return {
|
| 478 |
+
"input_ids": input_ids,
|
| 479 |
+
"attention_mask": attention_mask,
|
| 480 |
+
"prefix_mask": prefix_mask,
|
| 481 |
+
"sequence_id": sequence_id,
|
| 482 |
+
"past_key_values": past_key_values,
|
| 483 |
+
"use_cache": kwargs.get("use_cache", True),
|
| 484 |
+
}
|
| 485 |
+
|
| 486 |
+
@staticmethod
|
| 487 |
+
def _reorder_cache(past_key_values, beam_idx):
|
| 488 |
+
"""Used by HuggingFace generate when using beam search with kv-caching.
|
| 489 |
+
|
| 490 |
+
See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
|
| 491 |
+
for an example in transformers.
|
| 492 |
+
"""
|
| 493 |
+
reordered_past = []
|
| 494 |
+
for layer_past in past_key_values:
|
| 495 |
+
reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
|
| 496 |
+
return reordered_past
|
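
The prefix-LM masking above lets a position attend wherever either the causal (lower-triangular) mask or the prefix mask permits, and only the remaining positions receive the `min_val` bias. A minimal standalone sketch of that step, using a hypothetical toy sequence of length 4 whose first two tokens form the bidirectional prefix (the toy values are illustrative, not from this diff):

import torch

seq_len = 4
prefix_mask = torch.tensor([1, 1, 0, 0], dtype=torch.bool)  # first two tokens are the prefix

causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool)).view(1, 1, seq_len, seq_len)
prefix = prefix_mask.view(-1, 1, 1, seq_len)
cannot_attend = ~torch.logical_or(causal, prefix)

print(cannot_attend[0, 0].int())
# tensor([[0, 0, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]])

Rows 0 and 1 (the prefix tokens) see each other in both directions; rows 2 and 3 fall back to strictly causal attention.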
mllm/flamingo/mpt/norm.py
ADDED
@@ -0,0 +1,60 @@
import torch


def _cast_if_autocast_enabled(tensor):
    if torch.is_autocast_enabled():
        if tensor.device.type == "cuda":
            dtype = torch.get_autocast_gpu_dtype()
        elif tensor.device.type == "cpu":
            dtype = torch.get_autocast_cpu_dtype()
        else:
            raise NotImplementedError()
        return tensor.to(dtype=dtype)
    return tensor


class LPLayerNorm(torch.nn.LayerNorm):
    def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
        super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)

    def forward(self, x):
        module_device = x.device
        downcast_x = _cast_if_autocast_enabled(x)
        downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
        downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
        with torch.autocast(enabled=False, device_type=module_device.type):
            return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)


def rms_norm(x, weight=None, eps=1e-05):
    # scale each vector by the reciprocal of its root mean square along the last dimension
    output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    if weight is not None:
        return output * weight
    return output


class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
        super().__init__()
        self.eps = eps
        if weight:
            self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))
        else:
            self.register_parameter("weight", None)

    def forward(self, x):
        return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)


class LPRMSNorm(RMSNorm):
    def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
        super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)

    def forward(self, x):
        downcast_x = _cast_if_autocast_enabled(x)
        downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
        with torch.autocast(enabled=False, device_type=x.device.type):
            return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)


NORM_CLASS_REGISTRY = {"layernorm": torch.nn.LayerNorm, "low_precision_layernorm": LPLayerNorm, "rmsnorm": RMSNorm, "low_precision_rmsnorm": LPRMSNorm}
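
A quick numerical check of `rms_norm` above: multiplying each vector by the reciprocal of its root mean square should leave the output with RMS close to 1 along the last dimension. This sketch inlines the function so it runs standalone; the tensor shapes are hypothetical:

import torch

def rms_norm(x, weight=None, eps=1e-05):
    # scale each vector by the reciprocal of its root mean square
    output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    if weight is not None:
        return output * weight
    return output

x = torch.randn(2, 8, 16)
y = rms_norm(x)
print(y.pow(2).mean(-1).sqrt())  # all entries close to 1.0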
mllm/flamingo/mpt/param_init_fns.py
ADDED
@@ -0,0 +1,369 @@
import math
import warnings
from collections.abc import Sequence
from functools import partial
from typing import Optional, Tuple, Union
import torch
from torch import nn
from .norm import NORM_CLASS_REGISTRY


def torch_default_param_init_fn_(module: nn.Module, verbose: int = 0, **kwargs):
    del kwargs
    if verbose > 1:
        warnings.warn("Initializing network using module's reset_parameters attribute")
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()


def fused_init_helper_(module: nn.Module, init_fn_):
    # initialize each split of a fused parameter (e.g. fused qkv weights) separately
    _fused = getattr(module, "_fused", None)
    if _fused is None:
        raise RuntimeError("Internal logic error: fused_init_helper_ requires a module with a `_fused` attribute")
    (dim, splits) = _fused
    splits = (0, *splits, module.weight.size(dim))
    for s, e in zip(splits[:-1], splits[1:]):
        slice_indices = [slice(None)] * module.weight.ndim
        slice_indices[dim] = slice(s, e)
        init_fn_(module.weight[slice_indices])


def generic_param_init_fn_(
    module: nn.Module,
    init_fn_,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    del kwargs
    if verbose > 1:
        warnings.warn("If model has bias parameters they are initialized to 0.")
    if init_div_is_residual is False:
        div_is_residual = 1.0
    elif init_div_is_residual is True:
        div_is_residual = math.sqrt(2 * n_layers)
    elif isinstance(init_div_is_residual, (float, int)):
        div_is_residual = init_div_is_residual
    elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
        div_is_residual = float(init_div_is_residual)
    else:
        raise ValueError(f"Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}")
    if init_div_is_residual is not False:
        if verbose > 1:
            warnings.warn(
                f"Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. "
                "Set `init_div_is_residual: false` in init config to disable this."
            )
    if isinstance(module, nn.Linear):
        if hasattr(module, "_fused"):
            fused_init_helper_(module, init_fn_)
        else:
            init_fn_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
        if init_div_is_residual is not False and getattr(module, "_is_residual", False):
            with torch.no_grad():
                module.weight.div_(div_is_residual)
    elif isinstance(module, nn.Embedding):
        if emb_init_std is not None:
            std = emb_init_std
            if std == 0:
                warnings.warn("Embedding layer initialized to 0.")
            emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
            if verbose > 1:
                warnings.warn(f"Embedding layer initialized using normal distribution with mean=0 and std={std!r}.")
        elif emb_init_uniform_lim is not None:
            lim = emb_init_uniform_lim
            if isinstance(lim, Sequence):
                if len(lim) > 2:
                    raise ValueError(f"Uniform init requires a min and a max limit. User input: {lim}.")
                if lim[0] == lim[1]:
                    warnings.warn(f"Embedding layer initialized to {lim[0]}.")
            else:
                if lim == 0:
                    warnings.warn("Embedding layer initialized to 0.")
                lim = [-lim, lim]
            (a, b) = lim
            emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
            if verbose > 1:
                warnings.warn(f"Embedding layer initialized using uniform distribution in range {lim}.")
        else:
            emb_init_fn_ = init_fn_
        emb_init_fn_(module.weight)
    elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
        if verbose > 1:
            warnings.warn("Norm weights are set to 1. If norm layer has a bias it is initialized to 0.")
        if hasattr(module, "weight") and module.weight is not None:
            torch.nn.init.ones_(module.weight)
        if hasattr(module, "bias") and module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.MultiheadAttention):
        if module._qkv_same_embed_dim:
            assert module.in_proj_weight is not None
            assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None)
            assert d_model is not None
            # initialize the fused q, k, v projections separately
            _d = d_model
            splits = (0, _d, 2 * _d, 3 * _d)
            for s, e in zip(splits[:-1], splits[1:]):
                init_fn_(module.in_proj_weight[s:e])
        else:
            assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None)
            assert module.in_proj_weight is None
            init_fn_(module.q_proj_weight)
            init_fn_(module.k_proj_weight)
            init_fn_(module.v_proj_weight)
        if module.in_proj_bias is not None:
            torch.nn.init.zeros_(module.in_proj_bias)
        if module.bias_k is not None:
            torch.nn.init.zeros_(module.bias_k)
        if module.bias_v is not None:
            torch.nn.init.zeros_(module.bias_v)
        init_fn_(module.out_proj.weight)
        if init_div_is_residual is not False and getattr(module.out_proj, "_is_residual", False):
            with torch.no_grad():
                module.out_proj.weight.div_(div_is_residual)
        if module.out_proj.bias is not None:
            torch.nn.init.zeros_(module.out_proj.bias)
    else:
        for _ in module.parameters(recurse=False):
            raise NotImplementedError(f"{module.__class__.__name__} parameters are not initialized by param_init_fn.")


def _normal_init_(std, mean=0.0):
    return partial(torch.nn.init.normal_, mean=mean, std=std)


def _normal_param_init_fn_(
    module: nn.Module,
    std: float,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    del kwargs
    init_fn_ = _normal_init_(std=std)
    if verbose > 1:
        warnings.warn(f"Using torch.nn.init.normal_ init fn mean=0.0, std={std}")
    generic_param_init_fn_(
        module=module,
        init_fn_=init_fn_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )


def baseline_param_init_fn_(
    module: nn.Module,
    init_std: float,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    del kwargs
    if init_std is None:
        raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
    _normal_param_init_fn_(
        module=module,
        std=init_std,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )


def small_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: int,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    del kwargs
    # std = sqrt(2 / (5 * d_model)), close to kaiming normal
    std = math.sqrt(2 / (5 * d_model))
    _normal_param_init_fn_(
        module=module,
        std=std,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )


def neox_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: int,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    """From section 2.3.1 of GPT-NeoX-20B:

    An Open-Source Autoregressive Language Model, Black et al. (2022)
    see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
    and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
    """
    del kwargs
    residual_div = n_layers / math.sqrt(10)
    if verbose > 1:
        warnings.warn(f"setting init_div_is_residual to {residual_div}")
    small_param_init_fn_(
        module=module,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=residual_div,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )


def kaiming_uniform_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    fan_mode: str = "fan_in",
    init_nonlinearity: str = "leaky_relu",
    verbose: int = 0,
    **kwargs,
):
    del kwargs
    if verbose > 1:
        warnings.warn(f"Using nn.init.kaiming_uniform_ init fn with parameters: a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}")
    kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
    generic_param_init_fn_(
        module=module,
        init_fn_=kaiming_uniform_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )


def kaiming_normal_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    fan_mode: str = "fan_in",
    init_nonlinearity: str = "leaky_relu",
    verbose: int = 0,
    **kwargs,
):
    del kwargs
    if verbose > 1:
        warnings.warn(f"Using nn.init.kaiming_normal_ init fn with parameters: a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}")
    kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
    generic_param_init_fn_(
        module=module,
        init_fn_=kaiming_normal_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )


def xavier_uniform_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    verbose: int = 0,
    **kwargs,
):
    del kwargs
    xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
    if verbose > 1:
        warnings.warn(f"Using torch.nn.init.xavier_uniform_ init fn with parameters: gain={init_gain}")
    generic_param_init_fn_(
        module=module,
        init_fn_=xavier_uniform_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )


def xavier_normal_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    verbose: int = 0,
    **kwargs,
):
    del kwargs
    xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
    if verbose > 1:
        warnings.warn(f"Using torch.nn.init.xavier_normal_ init fn with parameters: gain={init_gain}")
    generic_param_init_fn_(
        module=module,
        init_fn_=xavier_normal_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
        verbose=verbose,
    )


MODEL_INIT_REGISTRY = {
    "default_": torch_default_param_init_fn_,
    "baseline_": baseline_param_init_fn_,
    "kaiming_uniform_": kaiming_uniform_param_init_fn_,
    "kaiming_normal_": kaiming_normal_param_init_fn_,
    "neox_init_": neox_param_init_fn_,
    "small_init_": small_param_init_fn_,
    "xavier_uniform_": xavier_uniform_param_init_fn_,
    "xavier_normal_": xavier_normal_param_init_fn_,
}
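
`MPTModel.param_init_fn` dispatches through this registry using `config.init_config["name"]`. A minimal sketch of that dispatch on a single linear layer; the import path is assumed from this diff's layout, and the layer sizes are hypothetical:

import math
import torch
from mllm.flamingo.mpt.param_init_fns import MODEL_INIT_REGISTRY

linear = torch.nn.Linear(256, 256)
# "small_init_" draws weights from N(0, sqrt(2 / (5 * d_model))) and zeros the bias
MODEL_INIT_REGISTRY["small_init_"](module=linear, n_layers=12, d_model=256)

expected_std = math.sqrt(2 / (5 * 256))
print(f"weight std {linear.weight.std().item():.4f} vs expected {expected_std:.4f}")
print(linear.bias.abs().sum().item())  # 0.0: biases are zero-initialized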
mllm/flamingo/mpt_redpajama/__init__.py
ADDED
File without changes
mllm/flamingo/mpt_redpajama/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (216 Bytes).